【发布时间】:2018-07-30 01:49:32
【问题描述】:
我有一个包含 12 个类别的多标签分类问题。我正在使用Tensorflow 中的slim 来使用在ImageNet 上预训练的模型来训练模型。以下是训练和验证中每个班级出现的百分比
Training Validation
class0 44.4 25
class1 55.6 50
class2 50 25
class3 55.6 50
class4 44.4 50
class5 50 75
class6 50 75
class7 55.6 50
class8 88.9 50
class9 88.9 50
class10 50 25
class11 72.2 25
问题是模型没有收敛,验证集上ROC 曲线 (Az) 的下方很差,类似于:
Az
class0 0.99
class1 0.44
class2 0.96
class3 0.9
class4 0.99
class5 0.01
class6 0.52
class7 0.65
class8 0.97
class9 0.82
class10 0.09
class11 0.5
Average 0.65
我不知道为什么它对某些课程有效而对其他课程无效。我决定深入研究细节,看看神经网络在学习什么。我知道混淆矩阵仅适用于二元或多类分类。因此,为了能够绘制它,我必须将问题转换为多类分类对。尽管该模型是使用sigmoid 为每个类别提供预测的,但对于下面混淆矩阵中的每个单元格,我显示了概率的平均值(通过将sigmoid 函数应用于tensorflow)的图像,其中矩阵的行中的类存在并且列中的类不存在。这应用于验证集图像。通过这种方式,我认为我可以获得有关模型学习内容的更多详细信息。我只是圈出了对角线元素以供显示。
我的解释是:
- 类别 0 和 4 在存在时被检测到,而在不存在时则不存在。这意味着可以很好地检测到这些类。
- 2、6 和 7 类始终被检测为不存在。这不是我要找的。li>
- 3、8 和 9 类始终被检测为存在。这不是我要找的。这可以应用于 11 类。
- 5 类在不存在时被检测为存在,在它存在时被检测为不存在。它被反向检测。
- 第 3 类和第 10 类:我认为我们不能为这 2 个类提取太多信息。
我的问题是解释。我不确定问题出在哪里,也不确定数据集中是否存在产生此类结果的偏差。我还想知道是否有一些指标可以帮助解决多标签分类问题?你能和我分享你对这种混淆矩阵的解释吗?以及接下来看什么/去哪里?对其他指标的一些建议会很棒。
谢谢。
编辑:
我将问题转换为多类分类,因此对于每对类(例如 0,1)来计算概率(类 0,类 1),表示为 p(0,1):
我对存在工具 0 且不存在工具 1 的图像进行工具 1 的预测,并通过应用 sigmoid 函数将它们转换为概率,然后显示这些概率的平均值。对于p(1, 0),我对工具 0 执行相同操作,但现在使用工具 1 存在且工具 0 不存在的图像。对于p(0, 0),我使用存在工具 0 的所有图像。考虑上图中的p(0,4),N/A 表示没有工具 0 存在且工具 4 不存在的图像。
以下是 2 个子集的图像数量:
- 169320张图片用于训练
- 37440 张图片用于验证
这是在训练集上计算的混淆矩阵(计算方式与前面描述的验证集相同),但这次颜色代码是用于计算每个概率的图像数量:
已编辑: 对于数据增强,我对网络的每个输入图像进行随机平移、旋转和缩放。此外,以下是有关这些工具的一些信息:
class 0 shape is completely different than the other objects.
class 1 resembles strongly to class 4.
class 2 shape resembles to class 1 & 4 but it's always accompanied by an object different than the others objects in the scene. As a whole, it is different than the other objects.
class 3 shape is completely different than the other objects.
class 4 resembles strongly to class 1
class 5 have common shape with classes 6 & 7 (we can say that they are all from the same category of objects)
class 6 resembles strongly to class 7
class 7 resembles strongly to class 6
class 8 shape is completely different than the other objects.
class 9 resembles strongly to class 10
class 10 resembles strongly to class 9
class 11 shape is completely different than the other objects.
已编辑: 以下是下面为训练集提出的代码的输出:
Avg. num labels per image = 6.892700212615167
On average, images with label 0 also have 6.365296803652968 other labels.
On average, images with label 1 also have 6.601033718926901 other labels.
On average, images with label 2 also have 6.758548914659531 other labels.
On average, images with label 3 also have 6.131520940484937 other labels.
On average, images with label 4 also have 6.219187208527648 other labels.
On average, images with label 5 also have 6.536933407946279 other labels.
On average, images with label 6 also have 6.533908387864367 other labels.
On average, images with label 7 also have 6.485973817793214 other labels.
On average, images with label 8 also have 6.1241642788920725 other labels.
On average, images with label 9 also have 5.94092288040875 other labels.
On average, images with label 10 also have 6.983303518187239 other labels.
On average, images with label 11 also have 6.1974066621953945 other labels.
对于验证集:
Avg. num labels per image = 6.001282051282051
On average, images with label 0 also have 6.0 other labels.
On average, images with label 1 also have 3.987080103359173 other labels.
On average, images with label 2 also have 6.0 other labels.
On average, images with label 3 also have 5.507731958762887 other labels.
On average, images with label 4 also have 5.506459948320414 other labels.
On average, images with label 5 also have 5.00169779286927 other labels.
On average, images with label 6 also have 5.6729452054794525 other labels.
On average, images with label 7 also have 6.0 other labels.
On average, images with label 8 also have 6.0 other labels.
On average, images with label 9 also have 5.506459948320414 other labels.
On average, images with label 10 also have 3.0 other labels.
On average, images with label 11 also have 4.666095890410959 other labels.
评论: 我认为这不仅与分布之间的差异有关,因为如果模型能够很好地概括第 10 类(意味着对象在训练过程中被正确识别,如第 0 类),那么验证集的准确性会很好足够的。我的意思是,问题在于训练集本身以及它是如何构建的,而不是两种分布之间的差异。它可以是:类或对象的存在频率非常相似(如类 10 与类 9 非常相似)或数据集中的偏差或薄对象(可能代表输入中 1% 或 2% 的像素像类 2 的图像)。我并不是说问题是其中之一,但我只是想指出,我认为这不仅仅是两种分布之间的差异。
【问题讨论】:
-
您能否更详细地解释一下矩阵中的值是如何计算的? N/As 是什么意思?除以 0?你的训练和测试集有多大?您是否还有关于哪些类在训练数据中同时出现的频率的任何信息(例如,如果您绘制热图,它最终看起来是否类似于您的混淆矩阵)?
-
@DennisSoemers,我编辑了我的问题以包含更多细节。
-
我对目标类感到困惑。每个图像可以有几个目标类?我认为这是一个“多标签分类问题”。您在神经网络中使用什么损失函数?在这里查看一些不同的选项:en.wikipedia.org/wiki/…
-
你从你的网络得到什么输出?它不是已经为每个标签提供了一个“概率”([0, 1] 中的数字)吗?如果是这样,我想我不明白你为什么要应用额外的 sigmoid 来获取混淆矩阵中的数字。你不能直接取平均值吗?
-
@KPLauritzen。是的,这是一个多标签分类问题。每个图像可以有零到 n 个类别。我使用 sigmoid 作为损失函数。
标签: python tensorflow machine-learning deep-learning confusion-matrix