【问题标题】:Imbalanced Dataset for Multi Label Classification多标签分类的不平衡数据集
【发布时间】:2017-11-01 07:17:27
【问题描述】:

所以我在我创建的多标签数据集(大约 20000 个样本)上训练了一个深度神经网络。我将 softmax 切换为 sigmoid 并尝试最小化(使用 Adam 优化器):

tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=y_, logits=y_pred)

我最终得到了这个预测之王(相当“恒定”):

Prediction for Im1 : [ 0.59275776  0.08751075  0.37567005  0.1636796   0.42361438  0.08701646 0.38991812  0.54468459  0.34593087  0.82790571]

Prediction for Im2 : [ 0.52609032  0.07885984  0.45780018  0.04995904  0.32828355  0.07349177 0.35400775  0.36479294  0.30002621  0.84438241]

Prediction for Im3 : [ 0.58714485  0.03258472  0.3349618   0.03199361  0.54665488  0.02271551 0.43719986  0.54638696  0.20344526  0.88144571]

起初,我以为我只需要为每个类找到一个阈值。

但我注意到,例如,在我的 20000 个样本中,第一类出现大约 10800 个,因此比率为 0.54,它是我每次预测的值。所以我认为我需要找到一种方法来解决 tuis“不平衡数据集”问题。

我考虑减少我的数据集(欠采样)以使每个类的出现次数大致相同,但只有 26 个样本对应于我的一个类...这会让我丢失很多样本...

我读到过关于过采样或惩罚更多罕见但并没有真正理解其工作原理的类。

有人可以分享一些关于这些方法的解释吗?

在实践中,在 Tensorflow 上,是否有函数可以帮助做到这一点?

还有其他建议吗?

谢谢你:)

PS:Neural Network for Imbalanced Multi-Class Multi-Label Classification这个帖子提出了同样的问题,但没有答案!

【问题讨论】:

  • 为什么不使用您拥有的所有样本并使用不平衡数据来使用异常检测算法?
  • 如果我理解得很好,你的建议是在我的 (9) 个类上训练我的网络(在我的数据集中“很好”表示),然后在我的“表现不佳”的类上训练另一个网络(比如在这个类上做二元分类)?
  • 没有。我建议使用算法来检测非常小的少数,你的绝大多数数据的差异。它们通常被称为异常检测算法,因为通常当您尝试检测异常时,您有很多“好”样本但很少有“异常”样本。然而,这些算法通常用于在两个类别之间进行分类。所以也许这对你不好,但它可能是更复杂的分类过程的一部分
  • 好的,谢谢你的想法(和你的额外解释)!

标签: tensorflow deep-learning multilabel-classification


【解决方案1】:

嗯,在一个类中有 10000 个样本,而在一个稀有类中只有 26 个确实是个问题。

但是,在我看来,您所经历的更像是“输出甚至看不到输入”,因此网络只会了解您的输出分布。

为了调试它,我会创建一个简化的集合(仅用于此调试目的),每个类有 26 个样本,然后尝试严重过度拟合。如果你得到正确的预测,我的想法是错误的。但是,如果网络甚至无法检测到那些采样不足的过拟合样本,那么这确实是一个架构/实现问题,而不是由于不规则的分布(然后你需要修复它。但它不会像你当前的结果那么糟糕)。

【讨论】:

  • 一开始我认为这可能是我的网络的问题,但它适用于单标签分类,例如 MNIST(当我有 Softmax 时)。但我仍然会尝试过拟合每类 26 个样本!谢谢你的回答!
  • 好吧,你是对的 .. 不幸的是我!但是,正如我之前所说,用于在 MNIST 数据集和我创建的数据集(多类单标签)上学习和执行非常好的完全相同的架构!唯一改变的是我用 Sigmoid 替换了 Softmax ..
【解决方案2】:

您的问题不在于类别不平衡,而只是缺少数据。对于几乎任何真正的机器学习任务,26 个样本被认为是一个非常小的数据集。可以通过确保每个 minibatch 至少有一个来自每个类的样本来轻松处理类不平衡(这会导致某些样本比另一个更频繁地使用的情况,但谁在乎)。

但是,在只有 26 个样本的情况下,这种方法(以及任何其他方法)将很快导致过度拟合。这个问题可以通过某种形式的数据增强来部分解决,但是样本仍然太少而无法构建合理的东西。

所以,我的建议是收集更多数据。

【讨论】:

  • 26 不是我的数据集的大小,只是我的一个班级在整个数据集(即 20000 个样本)中出现的次数。感谢您提出“确保每个 minibatch 至少有一个来自每个班级的样本”的想法。这与 Oersampling 的想法相同吗? :)
猜你喜欢
  • 1970-01-01
  • 2020-10-23
  • 1970-01-01
  • 2017-11-03
  • 2015-01-28
  • 2021-04-19
  • 1970-01-01
  • 2013-12-26
  • 2020-02-01
相关资源
最近更新 更多