【问题标题】:Tensor indexing in custom loss function自定义损失函数中的张量索引
【发布时间】:2018-02-22 07:33:36
【问题描述】:

基本上,我希望我的自定义损失函数在通常的 MSE 和从不同索引中减去值的自定义 MSE 之间交替。

为了澄清,假设我有一个 y_pred 张量是 [1, 2, 4, 5] 和一个 y_true 张量是 [2, 5, 1, 3]。在通常的 MSE 中,我们应该得到:

return K.mean(K.squared(y_pred - y_true))

这将执行以下操作:

[1, 2, 4, 5] - [2, 5, 1, 3] = [-1, -3, 3, 2]

[-1, -3, 3, 2]² = [1, 9, 9, 4]

平均([1, 9, 9, 4]) = 5.75

我需要我的自定义损失函数来选择这个平均值和其他从 y_pred 张量切换索引 1 和 3 的最小值,即:

[1, 5, 4, 2] - [2, 5, 1, 3] = [-1, 0, 3, 1]

[-1, 0, 3, 1]² = [1, 0, 9, 1]

平均([1, 0, 9, 1]) = 2.75

因此,我的自定义损失将返回 2.75,这是两种方法之间的最小值。为此,我尝试在 numpy 数组中转换 y_true 和 y_pred 张量,进行所有相关的数学运算,如下所示:

def new_mse(y_true, y_pred):
    sess = tf.Session()

    with sess.as_default():
        np_y_true = y_true.eval()
        np_y_pred = y_pred.eval()

        np_err_mse = np.empty(np_y_true.shape)
        np_err_mse = np.square(np_y_pred - np_y_true)

        np_err_new_mse = np.empty(np_y_true.shape)
        l0 = np.square(np_y_pred[:, 2] - np_y_true[:, 0])   
        l1 = np.square(np_y_pred[:, 3] - np_y_true[:, 1])
        l2 = np.square(np_y_pred[:, 0] - np_y_true[:, 2])
        l3 = np.square(np_y_pred[:, 1] - np_y_true[:, 3])   
        l4 = np.square(np_y_pred[:, 4] - np_y_true[:, 4])
        l5 = np.square(np_y_pred[:, 5] - np_y_true[:, 5])
        np_err_new_mse = np.transpose(np.vstack(l0, l1, l2, l3, l4, l5))

        np_err_mse = np.mean(np_err_mse)
        np_err_new_mse = np.mean(np_err_new_mse)

        return np.amin([np_err_mse, np_err_new_mse])

问题是我不能对 y_true 和 y_pred 张量使用 eval() 方法,不知道为什么。最后,我的问题是:

  1. 是否有更简单的方法来处理张量和损失函数内部的索引?总的来说,我是 Tensorflow 和 Keras 的新手,我坚信在 numpy 数组中转换所有内容根本不是最佳方法。
  2. 与问题不完全相关,但是当我尝试使用 K.shape(y_true) 打印 y_true 张量的形状时,我得到了“Tensor("Shape_1:0", shape=(2,), dtype=int32)"。这让我很困惑,因为我使用的 y.shape 等于 (7032, 6),即 7032 个图像,每个图像有 6 个标签。可能是与损失函数使用的我的 y 和 y_pred 相关的一些误解。

【问题讨论】:

    标签: tensorflow neural-network keras


    【解决方案1】:

    您通常只使用backend functions,而您从不尝试知道张量的实际值。

    from keras.losses import mean_square_error
    
    def new_mse(y_true,y_pred): 
    
        #swapping elements 1 and 3 - concatenate slices of the original tensor
        swapped = K.concatenate([y_pred[:1],y_pred[3:],y_pred[2:3],y_pred[1:2]])
        #actually, if the tensors are shaped like (batchSize,4), use this:
        #swapped = K.concatenate([y_pred[:,:1],y_pred[:,3:],y_pred[:,2:3],Y_pred[:,1:2])
    
        #losses
        regularLoss = mean_squared_error(y_true,y_pred)
        swappedLoss = mean_squared_error(y_true,swapped)
    
        #concat them for taking a min value
        concat = K.concatenate([regularLoss,swappedLoss])
    
        #take the minimum
        return K.min(concat)
    

    所以,对于您的物品:

    1. 你完全正确。在张量操作(损失函数、激活、自定义层等)中不惜一切代价避免使用 numpy。

    2. K.shape() 也是张量。它可能具有形状 (2,),因为它有两个值,一个值为 7032,另一个值为 6。但您只能在评估此张量时看到这些值。在损失函数中这样做通常是个坏主意。

    【讨论】:

    • 所以,我在 K.concatenate() 方法中得到“TypeError: Expected int32, got list contains Tensors of type '_Message' instead”。由于'axis'参数,我试图改变参数的顺序,但这似乎不是问题。目前使用 Keras 1.2.2 和 Tensorflow 1.1.0。关于如何解决这个问题的任何想法?
    • Hmmmm...我已经阅读过,keras 1 和 keras 2 之间存在一些显着差异。我使用的是 keras 2.0.x。 --- 请注意我在答案的swapped 行中添加的小评论。您可能需要使用注释行。
    • 您也可以尝试将regularLossswappedLoss 这两行替换为K.mean(K.square(y_pred - y_true))K.mean(K.square(swapped - y_true)),以防这些message 张量来自导入的损失函数。
    • 它可以是一个张量,但K.min() 将是一个只有一个元素的张量。
    • (不过我建议你迁移到 keras 2)
    【解决方案2】:

    如果使用 Keras 2,您应该只使用 K.gather 函数来执行索引。

    Daniel Möller 的回答变成:

    from keras.losses import mean_square_error
    
    def reindex(t, perm):
        K.concatenate([K.gather(t, i) for i in perm])
    
    def my_loss(y_true,y_pred):
    
        #losses
        regularLoss = mean_squared_error(y_true, y_pred)
        swappedLoss = mean_squared_error(y_true, reindex(y_pred, [0,3,2,1]))
    
        #concat them for taking a min value
        concat = K.concatenate([regularLoss,swappedLoss])
    
        #take the minimum
        return K.min(concat)
    

    【讨论】:

      猜你喜欢
      • 2020-05-30
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2019-09-20
      • 2017-09-02
      • 2020-01-08
      • 1970-01-01
      • 2021-01-03
      相关资源
      最近更新 更多