【问题标题】:How does TensorFlow/Keras's class_weight parameter of the fit() function work?TensorFlow/Keras 的 fit() 函数的 class_weight 参数如何工作?
【发布时间】:2020-01-15 23:31:02
【问题描述】:

我使用 TensorFlow 1.12 和 Keras 进行语义分割。我使用class_weight 参数向tf.keras.Model.fit() 提供权重向量(大小等于类数)。我想知道这是如何在内部工作的。我使用自定义损失函数(骰子损失和焦点损失等),并且在将权重输入到损失函数之前,不能将权重与预测或单热基础事实相乘,因为这不会产生任何影响感觉。我的损失函数输出一个标量值,因此它也不能与函数输出相乘。那么究竟在哪里以及如何考虑类权重呢?

我的自定义损失函数是:

def cross_entropy_loss(onehots_true, logits): # Inputs are [BATCH_SIZE, height, width, num_classes]
    logits, onehots_true = mask_pixels(onehots_true, logits) # Removes pixels for which no ground truth exists, and returns shape [num_gt_pixels, num_classes]
    return tf.losses.softmax_cross_entropy(onehots_true, logits)

【问题讨论】:

  • 你检查过我的答案吗,如果你想要的话请告诉我
  • 抱歉我反应迟了。你的回答很有帮助!我还是不明白class_sample_weights是什么时候开始应用的,但是我还没有时间进一步探索源代码。

标签: tensorflow keras loss-function


【解决方案1】:

正如Keras Official Docs中提到的,

class_weight: 可选字典映射类索引(整数) 权重(浮点)值,用于加权损失函数 (仅在训练期间)。这对于告诉模型“支付 更多地关注”来自代表性不足的班级的样本。

基本上,我们在有 类不平衡e 的情况下提供类权重。这意味着,训练样本并非均匀分布在所有类别中。有些类的样本较少,而有些类的样本较多。

我们需要分类器更多地注意数量较少的类。一种方法可能是增加低样本类的损失值。更高的损失意味着更高的优化,从而导致有效的分类。

就 Keras 而言,我们将 dict 映射类索引传递给它们的权重(损失值将乘以的因素)。举个例子吧,

class_weights = { 0 : 1.2 , 1 : 0.9 }

在内部,第 0 类和第 1 类的损失值将乘以它们对应的权重值。

weighed_loss_class0 = loss0 * class_weights[0]
weighed_loss_class1 = loss1 * class_weights[1]

现在,the weighed_loss_class0weighed_loss_class1 将用于反向传播。

参见thisthis

【讨论】:

  • 感谢您的回复。我知道什么类权重有好处以及为什么使用它们,我只是想知道它们是如何实际应用的。您提到它们与每个单独类的损失相乘,但是在哪里会发生这种情况?我的损失函数输出一个标量值,那么如何对特定类别的损失进行加权平均?
  • 如果你能分享损失函数的代码就好了。
  • 酷。当您有多个输出时,您提供给 class_weight 的字典结构应该是什么?
【解决方案2】:

您可以从Github中的Keras源代码中引用以下代码:

    class_sample_weight = np.asarray(
        [class_weight[cls] for cls in y_classes if cls in class_weight])

    if len(class_sample_weight) != len(y_classes):
      # subtract the sets to pick all missing classes
      existing_classes = set(y_classes)
      existing_class_weight = set(class_weight.keys())
      raise ValueError(
          '`class_weight` must contain all classes in the data.'
          ' The classes %s exist in the data but not in '
          '`class_weight`.' % (existing_classes - existing_class_weight))

  if class_sample_weight is not None and sample_weight is not None:
    # Multiply weights if both are provided.
    return class_sample_weight * sample_weight

所以如你所见,首先class_weight被转换为numpy数组class_sample_weight,然后它乘以Sample_weight。

来源:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/engine/training_utils.py

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 2019-02-05
    • 1970-01-01
    • 2015-09-07
    • 2019-12-26
    • 1970-01-01
    • 2018-09-23
    • 2019-11-14
    • 1970-01-01
    相关资源
    最近更新 更多