【问题标题】:Understanding multi-label classifier using confusion matrix使用混淆矩阵理解多标签分类器
【发布时间】:2018-07-30 01:49:32
【问题描述】:

我有一个包含 12 个类别的多标签分类问题。我正在使用Tensorflow 中的slim 来使用在ImageNet 上预训练的模型来训练模型。以下是训练和验证中每个班级出现的百分比

            Training     Validation
  class0      44.4          25
  class1      55.6          50
  class2      50            25
  class3      55.6          50
  class4      44.4          50
  class5      50            75
  class6      50            75
  class7      55.6          50
  class8      88.9          50
  class9     88.9           50
  class10     50            25
  class11     72.2          25

问题是模型没有收敛,验证集上ROC 曲线 (Az) 的下方很差,类似于:

               Az 
  class0      0.99
  class1      0.44
  class2      0.96  
  class3      0.9
  class4      0.99
  class5      0.01
  class6      0.52
  class7      0.65
  class8      0.97
  class9     0.82
  class10     0.09
  class11     0.5
  Average     0.65

我不知道为什么它对某些课程有效而对其他课程无效。我决定深入研究细节,看看神经网络在学习什么。我知道混淆矩阵仅适用于二元或多类分类。因此,为了能够绘制它,我必须将问题转换为多类分类对。尽管该模型是使用sigmoid 为每个类别提供预测的,但对于下面混淆矩阵中的每个单元格,我显示了概率的平均值(通过将sigmoid 函数应用于tensorflow)的图像,其中矩阵的行中的类存在并且列中的类不存在。这应用于验证集图像。通过这种方式,我认为我可以获得有关模型学习内容的更多详细信息。我只是圈出了对角线元素以供显示。

我的解释是:

  1. 类别 0 和 4 在存在时被检测到,而在不存在时则不存在。这意味着可以很好地检测到这些类。
  2. 2、6 和 7 类始终被检测为不存在。这不是我要找的。​​li>
  3. 3、8 和 9 类始终被检测为存在。这不是我要找的。这可以应用于 11 类。
  4. 5 类在不存在时被检测为存在,在它存在时被检测为不存在。它被反向检测。
  5. 第 3 类和第 10 类:我认为我们不能为这 2 个类提取太多信息。

我的问题是解释。我不确定问题出在哪里,也不确定数据集中是否存在产生此类结果的偏差。我还想知道是否有一些指标可以帮助解决多标签分类问题?你能和我分享你对这种混淆矩阵的解释吗?以及接下来看什么/去哪里?对其他指标的一些建议会很棒。

谢谢。

编辑:

我将问题转换为多类分类,因此对于每对类(例如 0,1)来计算概率(类 0,类 1),表示为 p(0,1): 我对存在工具 0 且不存在工具 1 的图像进行工具 1 的预测,并通过应用 sigmoid 函数将它们转换为概率,然后显示这些概率的平均值。对于p(1, 0),我对工具 0 执行相同操作,但现在使用工具 1 存在且工具 0 不存在的图像。对于p(0, 0),我使用存在工具 0 的所有图像。考虑上图中的p(0,4),N/A 表示没有工具 0 存在且工具 4 不存在的图像。

以下是 2 个子集的图像数量:

  1. 169320张图片用于训练
  2. 37440 张图片用于验证

这是在训练集上计算的混淆矩阵(计算方式与前面描述的验证集相同),但这次颜色代码是用于计算每个概率的图像数量:

已编辑: 对于数据增强,我对网络的每个输入图像进行随机平移、旋转和缩放。此外,以下是有关这些工具的一些信息:

class 0 shape is completely different than the other objects.
class 1 resembles strongly to class 4.
class 2 shape resembles to class 1 & 4 but it's always accompanied by an object different than the others objects in the scene. As a whole, it is different than the other objects.
class 3 shape is completely different than the other objects.
class 4 resembles strongly to class 1
class 5 have common shape with classes 6 & 7 (we can say that they are all from the same category of objects)
class 6 resembles strongly to class 7
class 7 resembles strongly to class 6
class 8 shape is completely different than the other objects.
class 9 resembles strongly to class 10
class 10 resembles strongly to class 9
class 11 shape is completely different than the other objects.

已编辑: 以下是下面为训练集提出的代码的输出:

Avg. num labels per image =  6.892700212615167
On average, images with label  0  also have  6.365296803652968  other labels.
On average, images with label  1  also have  6.601033718926901  other labels.
On average, images with label  2  also have  6.758548914659531  other labels.
On average, images with label  3  also have  6.131520940484937  other labels.
On average, images with label  4  also have  6.219187208527648  other labels.
On average, images with label  5  also have  6.536933407946279  other labels.
On average, images with label  6  also have  6.533908387864367  other labels.
On average, images with label  7  also have  6.485973817793214  other labels.
On average, images with label  8  also have  6.1241642788920725  other labels.
On average, images with label  9  also have  5.94092288040875  other labels.
On average, images with label  10  also have  6.983303518187239  other labels.
On average, images with label  11  also have  6.1974066621953945  other labels.

对于验证集:

Avg. num labels per image =  6.001282051282051
On average, images with label  0  also have  6.0  other labels.
On average, images with label  1  also have  3.987080103359173  other labels.
On average, images with label  2  also have  6.0  other labels.
On average, images with label  3  also have  5.507731958762887  other labels.
On average, images with label  4  also have  5.506459948320414  other labels.
On average, images with label  5  also have  5.00169779286927  other labels.
On average, images with label  6  also have  5.6729452054794525  other labels.
On average, images with label  7  also have  6.0  other labels.
On average, images with label  8  also have  6.0  other labels.
On average, images with label  9  also have  5.506459948320414  other labels.
On average, images with label  10  also have  3.0  other labels.
On average, images with label  11  also have  4.666095890410959  other labels.

评论: 我认为这不仅与分布之间的差异有关,因为如果模型能够很好地概括第 10 类(意味着对象在训练过程中被正确识别,如第 0 类),那么验证集的准确性会很好足够的。我的意思是,问题在于训练集本身以及它是如何构建的,而不是两种分布之间的差异。它可以是:类或对象的存在频率非常相似(如类 10 与类 9 非常相似)或数据集中的偏差或薄对象(可能代表输入中 1% 或 2% 的像素像类 2 的图像)。我并不是说问题是其中之一,但我只是想指出,我认为这不仅仅是两种分布之间的差异。

【问题讨论】:

  • 您能否更详细地解释一下矩阵中的值是如何计算的? N/As 是什么意思?除以 0?你的训练和测试集有多大?您是否还有关于哪些类在训练数据中同时出现的频率的任何信息(例如,如果您绘制热图,它最终看起来是否类似于您的混淆矩阵)?
  • @DennisSoemers,我编辑了我的问题以包含更多细节。
  • 我对目标类感到困惑。每个图像可以有几个目标类?我认为这是一个“多标签分类问题”。您在神经网络中使用什么损失函数?在这里查看一些不同的选项:en.wikipedia.org/wiki/…
  • 你从你的网络得到什么输出?它不是已经为每个标签提供了一个“概率”([0, 1] 中的数字)吗?如果是这样,我想我不明白你为什么要应用额外的 sigmoid 来获取混淆矩阵中的数字。你不能直接取平均值吗?
  • @KPLauritzen。是的,这是一个多标签分类问题。每个图像可以有零到 n 个类别。我使用 sigmoid 作为损失函数。

标签: python tensorflow machine-learning deep-learning confusion-matrix


【解决方案1】:

输出校准

我认为首先要意识到的重要一点是,神经网络的输出可能校准不佳。我的意思是,它提供给不同实例的输出可能会产生良好的排名(带有标签 L 的图像在该标签上的得分往往高于没有标签 L 的图像),但这些分数并不总是可以可靠地解释为概率(它可能会给没有标签的实例提供非常高的分数,例如0.9,而只会给带有标签的实例更高的分数,例如0.99)。我想这是否会发生取决于你选择的损失函数。

有关这方面的更多信息,请参阅例如:https://arxiv.org/abs/1706.04599


一个接一个地完成所有课程

0 级: AUC(曲线下面积)= 0.99。这是一个非常好的分数。混淆矩阵中的第 0 列看起来也不错,所以这里没有错。

1 类: AUC = 0.44。这非常糟糕,低于 0.5,如果我没记错的话,这几乎意味着您最好故意与您的网络对此标签的预测相反相反

查看混淆矩阵中的第 1 列,它在各处的得分几乎相同。对我来说,这表明网络并没有从这个类中学到很多东西,而且几乎只是根据训练集中包含这个标签的图像的百分比(55.6%)来“猜测”。由于这个百分比在验证集中下降到 50%,这个策略确实意味着它会比随机做的稍微差一些。尽管如此,第 1 行仍然是该列中所有行数最多的,所以它似乎至少学到了一点点,但并不多。

第 2 类: AUC = 0.96。这是非常好的。

您对这个类的解释是,根据整个列的光线阴影,它总是被预测为不存在。我不认为这种解释是正确的。看看它是如何在对角线上的得分 > 0 的,而在该列的其​​他任何地方都只有 0。它在该行中的分数可能相对较低,但很容易与同一列中的其他行分开。您可能只需要设置阈值来选择该标签是否相对较低。我怀疑这是由于上面提到的校准问题。

这也是 AUC 实际上非常好的原因;可以选择一个阈值,使得分数高于阈值的大多数实例正确地具有标签,而低于阈值的大多数实例正确地没有。但是,该阈值可能不是 0.5,如果您假设校准良好,这是您可能期望的阈值。绘制此特定标签的 ROC 曲线可以帮助您准确确定阈值应该在哪里。

3 级: AUC = 0.9,相当好。

您将其解释为始终被检测为存在,并且混淆矩阵在列中确实有很多高数字,但 AUC 很好,对角线上的单元格确实具有足够高的值,它可能是很容易与其他人分开。我怀疑这与第 2 类情况类似(只是颠倒过来,到处都是高预测,因此正确决策所需的阈值很高)。

如果您希望能够确定一个精心选择的阈值是否确实可以正确地将大多数“阳性”(第 3 类的实例)与大多数“阴性”(没有第 3 类的实例)分开,您需要根据标签 3 的预测分数对所有实例进行排序,然后遍历整个列表,并在每对连续条目之间计算如果您决定将阈值放在那里,您将获得的验证集的准确性,并选择最佳阈值。

4 类:与 0 类相同。

5 级: AUC = 0.01,显然很糟糕。也同意你对混淆矩阵的解释。很难确定为什么它在这里表现如此糟糕。也许这是一种难以识别的物体?可能还会发生一些过度拟合(从第二个矩阵中的列判断,训练数据中的 0 误报,尽管也有其他类发生这种情况)。

标签 5 图像的比例从训练数据到验证数据的增加可能也无济于事。这意味着网络在训练期间在此标签上的表现不如在验证期间重要。

第 6 类: AUC = 0.52,仅略好于随机。

从第一个矩阵的第 6 列来看,这实际上可能与第 2 类的情况相似。如果我们也考虑 AUC,它看起来也没有很好地学习对实例进行排名。类似于 5 级,只是没有那么糟糕。同样,训练和验证分布也完全不同。

第 7 类: AUC = 0.65,相当平均。例如,显然不如第 2 类好,但也没有您仅从矩阵中解释的那么差。

8 级: AUC = 0.97,非常好,类似于 3 级。

第 9 类: AUC = 0.82,没有那么好,但仍然很好。矩阵中的列有很多暗单元,数字非常接近,我认为 AUC 非常好。它几乎出现在训练数据的每张图像中,因此预测它经常出现也就不足为奇了。也许其中一些非常暗的单元仅基于少量绝对数量的图像?这会很有趣。

10 级: AUC = 0.09,很糟糕。对角线上的 0 非常令人担忧(您的数据标记是否正确?)。根据第一个矩阵的第 10 行,第 3 类和第 9 类似乎经常被混淆(棉花和 primary_incision_knives 看起来很像 secondary_incision_knives 吗?)。也许对训练数据也有一些过拟合。

第 11 类: AUC = 0.5,不比随机好。性能不佳(并且矩阵中的得分明显过高)可能是因为该标签存在于大多数训练图像中,但只有少数验证图像中存在。


还有什么要绘制/测量的?

为了更深入地了解您的数据,我首先绘制了每个类共同出现频率的热图(一个用于训练,一个用于验证数据)。单元格 (i, j) 将根据包含标签 i 和 j 的图像的比例进行着色。这将是一个对称图,对角线单元格根据您问题中的第一个数字列表着色。比较这两个热图,看看它们在哪里有很大不同,看看这是否有助于解释您的模型的性能。

此外,了解(对于两个数据集)每个图像平均有多少个不同的标签,以及对于每个单独的标签,它平均与多少个其他标签共享一个图像可能很有用。例如,我怀疑标签为 10 的图像在训练数据中的其他标签相对较少。这可能会阻止网络在识别其他事物时预测标签 10,并且如果标签 10 在验证数据中突然更频繁地与其他对象共享图像,则会导致性能下降。由于伪代码可能比文字更容易理解重点,因此打印如下内容可能会很有趣:

# Do all of the following once for training data, AND once for validation data    
tot_num_labels = 0
for image in images:
    tot_num_labels += len(image.get_all_labels())
avg_labels_per_image = tot_num_labels / float(num_images)
print("Avg. num labels per image = ", avg_labels_per_image)

for label in range(num_labels):
    tot_shared_labels = 0
    for image in images_with_label(label):
        tot_shared_labels += (len(image.get_all_labels()) - 1)
    avg_shared_labels = tot_shared_labels / float(len(images_with_label(label)))
    print("On average, images with label ", label, " also have ", avg_shared_labels, " other labels.")

对于单个数据集,这并不能提供太多有用的信息,但如果你为训练和验证集这样做,你可以看出它们的分布是完全不同的,如果数字非常不同

最后,我有点担心您的第一个矩阵中的某些列如何完全在许多不同的行上出现相同的平均预测。我不太确定是什么原因造成的,但这可能有助于调查。


如何改进?

如果您还没有这样做,我建议您研究一下数据增强以获取您的训练数据。由于您使用的是图像,因此您可以尝试将现有图像的旋转版本添加到您的数据中。

对于您的多标签案例,目标是检测不同类型的对象,尝试简单地将一堆不同的图像(例如两个或四个图像)连接在一起也可能很有趣。然后,您可以将它们缩小到原始图像大小,并作为标签分配原始标签集的并集。合并图像的边缘会出现有趣的不连续性,我不知道这是否有害。也许它不适用于您的多目标检测案例,我认为值得一试。

【讨论】:

  • 感谢您的详细回答。我编辑了我的问题以添加更多细节,但我只有几个 cmets:我确实为训练/验证集生成了热图,但它们没有帮助。您能否在“还有什么要绘制的内容”部分中详细说明您的第二个建议?
  • 我还有一个问题是关于您在训练/验证集(例如第 5 类和第 6 类)中每个对象的出现频率之间所做的相关性,以给出您的解释。从我的角度来看,我只是检查了训练集中每个对象出现的频率,因为这是模型用来继续前进的。
  • @Maystro 为第一个问题编辑了更多信息。至于训练/验证集中标签的频率。假设,作为一个极端的例子,某个标签在训练集中出现 100%(或 0%)的时间。然后模型不会学习任何东西,它只会预测 100% 或 0% 而不管图像是什么样子,这在测试数据中可能是错误的。这对您来说不会那么极端,但是当您的训练集和验证集具有非常不同的分布时,您仍然可以观察到类似的效果
  • 我明白你的意思。编辑了我的问题以包含您提供的代码的输出。对我来说似乎没问题..不确定您是否对此类结果有任何意见?
  • @Maystro 对于某些类(但不是全部),它可以帮助诊断性能不佳。例如,在训练中,10 类的图像平均还包含 7 个其他对象。在验证中,这突然只有 3。也许你的网络没有学会识别第 10 类的对象,也许它只是学会了识别“有很多对象的图像”。一般来说,这确实表明您的训练集和验证集确实具有显着不同的分布,这通常意味着您无法合理地期望机器学习的惊人性能
猜你喜欢
  • 2021-11-11
  • 2020-01-22
  • 2014-10-19
  • 2019-05-22
  • 2020-10-24
  • 2018-11-06
  • 2019-03-01
  • 2021-05-14
相关资源
最近更新 更多