【问题标题】:Target and output shape/type for binary classification using PyTorch使用 PyTorch 进行二进制分类的目标和输出形状/类型
【发布时间】:2021-05-30 16:01:26
【问题描述】:

所以我有一些带注释的图像,我想用它们来训练二值图像分类器,但我在创建数据集和实际获取要训练的测试模型时遇到了问题。每个图像要么属于某个类别,要么不属于某个类别,因此我想使用 PyTorch 建立一个二进制分类数据集/模型。我有一些问题:

  1. 标签应该是浮动的还是长的?
  2. 我的标签应该是什么形状?
  3. 我正在使用来自 torchvision 模型的 resnet18 类,我的最终 softmax 层应该有一个还是两个输出?
  4. 如果我的批大小为 200,那么在训练期间,我的目标应该是什么形状?
  5. 我的输出应该是什么形状?

提前致谢

报价 删除

【问题讨论】:

    标签: python deep-learning pytorch classification


    【解决方案1】:

    二元分类与多标签分类略有不同:而对于多标签,您的模型预测每个样本的“logits”向量,并使用 softmax 将 logits 转换为概率;在二进制情况下,模型预测每个样本的 标量“logit”,并使用 sigmoid 函数将其转换为类概率。

    中,softmax 和 sigmoind 被“折叠”到损失层中(出于数值稳定性考虑),因此这两种情况有不同的交叉熵损失层:nn.BCEWithLogitsLoss 用于二进制情况(使用 sigmoid ) 和 nn.CrossEntropyLoss 用于多标签情况(使用 softmax)。

    在您的情况下,您想使用二进制版本(带有 sigmoid):nn.BCEWithLogitsLoss
    因此,您的标签应该是 torch.float32 类型(与网络输出相同的 float 类型)而不是整数。 每个样本应该有一个单个标签。因此,如果您的批量大小为 200,则目标的形状应为 (200,1)


    我将把它留在这里作为练习,以展示训练具有两个输出和 CE+softmax 的模型等效于二进制输出+sigmoid ;)

    【讨论】:

    • 非常感谢,我按照上述步骤操作,网络现在可以运行了。然而!由于某种原因,我的损失根本没有改变。知道为什么会发生这种情况吗?
    • @MohamedMoustafa 这可能有很多原因。尝试改变学习率
    • 我设法解决了这个问题。我将 sigmoid 放在完全连接层的末尾(我从 torchvision 编辑了 resnet18)。从网络中删除 sigmoid 似乎可以解决问题。
    • @MohamedMoustafa 使用BCEWithLogits 使得sigmoid 的显式使用变得多余。
    猜你喜欢
    • 2021-05-09
    • 2021-01-31
    • 2020-12-20
    • 2022-10-01
    • 2020-07-11
    • 2021-05-18
    • 2021-05-02
    • 2018-01-29
    • 1970-01-01
    相关资源
    最近更新 更多