【问题标题】:Nan loss in keras with triplet losskeras中的Nan损失与三重损失
【发布时间】:2020-01-22 10:06:16
【问题描述】:

我正在尝试学习结合 VGG 和 Adrian Ung triplet loss 的 Paris6k 图像嵌入。问题是经过少量迭代,在第一个epoch,loss变成了nan,然后accuracy和validation accuracy都增长到了1。

我已经尝试过降低学习率、增加批量大小(由于内存仅增加到 16)、更改优化器(Adam 和 RMSprop)、检查我的数据集上是否有 None 值、从 'float32 更改数据格式' 到 'float64',给它们添加一点偏差并简化模型。

这是我的代码:

base_model = VGG16(include_top = False, input_shape = (512, 384, 3))
input_images = base_model.input
input_labels = Input(shape=(1,), name='input_label')

embeddings = Flatten()(base_model.output)
labels_plus_embeddings = concatenate([input_labels, embeddings])

model = Model(inputs=[input_images, input_labels], outputs=labels_plus_embeddings)

batch_size = 16
epochs = 2
embedding_size = 64

opt = Adam(lr=0.0001)

model.compile(loss=tl.triplet_loss_adapted_from_tf, optimizer=opt, metrics=['accuracy'])

label_list = np.vstack(label_list)

x_train = image_list[:2500]
x_val = image_list[2500:]

y_train = label_list[:2500]
y_val = label_list[2500:]

dummy_gt_train = np.zeros((len(x_train), embedding_size + 1))
dummy_gt_val = np.zeros((len(x_val), embedding_size + 1))

H = model.fit(
    x=[x_train,y_train],
    y=dummy_gt_train,
    batch_size=batch_size,
    epochs=epochs,
    validation_data=([x_val, y_val], dummy_gt_val),callbacks=callbacks_list)

图像为 3366,其值在 [0, 1] 范围内缩放。 该网络采用虚拟值,因为它试图从图像中学习嵌入,即同一类的图像应该具有较小的距离,而不同类的图像应该具有较大的距离,而真实类是训练的一部分。

我注意到我之前进行了不正确的分类(并保留了应该丢弃的图像),并且我没有遇到 nan loss 问题。

我应该怎么做?

提前感谢我的英语。

【问题讨论】:

    标签: keras loss


    【解决方案1】:

    在某些情况下,随机 NaN 丢失可能是由您的数据引起的,因为如果您的批次中没有正对,您将获得 NaN 丢失。

    正如您在 Adrian Ung 的笔记本中看到的(或在 tensorflow addonstriplet loss 中看到的;代码相同):

    semi_hard_triplet_loss_distance = math_ops.truediv(
            math_ops.reduce_sum(
                math_ops.maximum(
                    math_ops.multiply(loss_mat, mask_positives), 0.0)),
            num_positives,
            name='triplet_semihard_loss')
    

    除以正数对 (num_positives) 的数量,这可能导致 NaN。

    我建议您尝试检查您的数据管道,以确保每个批次中至少有一对阳性。 (例如,您可以修改triplet_loss_adapted_from_tf 中的一些代码以获取您的批次的num_positives,并检查它是否大于0。

    【讨论】:

      【解决方案2】:

      尝试增加批量大小。它也发生在我身上。如上一个答案所述,网络无法找到任何 num_positives。我有 250 节课,一开始就输了。我将它增加到 128/256,然后就没有问题了。

      我看到 Paris6k 有 15 个类或 12 个类。增加批量大小 32,如果出现 GPU 内存,您可以尝试使用较少参数的模型。您可以使用 Efficient B0 模型进行启动。与具有 138M 参数的 VGG16 相比,它具有 5.3M 参数。

      【讨论】:

      • 谢谢...这对我很有帮助
      【解决方案3】:

      我已经实现了一个用于生成三元组的包,这样可以保证每个批次都包含正对。它仅与 TF/Keras 兼容。

      https://github.com/ma7555/kerasgen(免责声明:我是所有者)

      【讨论】:

        猜你喜欢
        • 1970-01-01
        • 2021-07-18
        • 2018-06-25
        • 2021-09-06
        • 1970-01-01
        • 2019-10-06
        • 2019-04-02
        • 1970-01-01
        • 1970-01-01
        相关资源
        最近更新 更多