【发布时间】:2019-05-27 03:07:48
【问题描述】:
我正在使用 Keras 2.2.4,我正在尝试实现像素分类的损失函数,如here 中所述,但我遇到了这里提出的一些困难。我正在做 3D 分割,因此我的目标向量是(b_size, width_x, width_y, width_z, nb_classes)。我实现了以下损失函数,其中权重图与目标和预测向量的形状相同:
def dice_xent_loss(y_true, y_pred, weight_map):
"""Adaptation of https://arxiv.org/pdf/1809.10486.pdf for multilabel
classification with overlapping pixels between classes. Dec 2018.
"""
loss_dice = weighted_dice(y_true, y_pred, weight_map)
loss_xent = weighted_binary_crossentropy(y_true, y_pred, weight_map)
return loss_dice + loss_xent
def weighted_binary_crossentropy(y_true, y_pred, weight_map):
return tf.reduce_mean((K.binary_crossentropy(y_true,
y_pred)*weight_map)) / (tf.reduce_sum(weight_map) + K.epsilon())
def weighted_dice(y_true, y_pred, weight_map):
if weight_map is None:
raise ValueError("Weight map cannot be None")
if y_true.shape != weight_map.shape:
raise ValueError("Weight map must be the same size as target vector")
dice_numerator = 2.0 * K.sum(y_pred * y_true * weight_map, axis=[1,2,3])
dice_denominator = K.sum(weight_map * y_true, axis=[1,2,3]) + \
K.sum(y_pred * weight_map, axis=[1,2,3])
loss_dice = (dice_numerator) / (dice_denominator + K.epsilon())
h1=tf.square(tf.minimum(0.1,loss_dice)*10-1)
h2=tf.square(tf.minimum(0.01,loss_dice)*100-1)
return 1.0 - tf.reduce_mean(loss_dice) + \
tf.reduce_mean(h1)*10 + \
tf.reduce_mean(h2)*10
我正在按照建议使用sample_weights=temporal 编译模型,并将权重作为sample_weight=weights 传递给model.fit。我仍然收到以下错误:
File "overfit_one_case.py", line 153, in <module>
main()
File "overfit_one_case.py", line 81, in main
sample_weight_mode="temporal")
File "/home/igt/anaconda2/envs/niftynet/lib/python2.7/site-packages/keras/engine/training.py", line 342, in compile
sample_weight, mask)
File "/home/igt/anaconda2/envs/niftynet/lib/python2.7/site-packages/keras/engine/training_utils.py", line 404, in weighted
score_array = fn(y_true, y_pred)
TypeError: dice_xent_loss() takes exactly 3 arguments (2 given)
在training_utils.py Keras 调用我的自定义损失,没有任何权重。关于如何解决这个问题的任何想法?我的另一个限制是我正在尝试对这个特定模型进行迁移学习。因此,我无法按照建议的here 将weight_map 添加到Input 层。
【问题讨论】:
标签: tensorflow keras deep-learning