【问题标题】:Wrong ROC curve for multiclass classification多类分类的错误 ROC 曲线
【发布时间】:2021-07-30 07:35:01
【问题描述】:

我已经训练了一个 CNN 来将图像分类为 5 个类别。但是当我尝试为每个类与其他类绘制 ROC 曲线时,所有 5 个类几乎都有一条对角线曲线,AUC 约为 0.5。我不知道出了什么问题。

模型的准确率应该在 86% 左右。

代码如下:

import os, shutil
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras import models, layers, optimizers
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.metrics import plot_confusion_matrix, accuracy_score
from sklearn.metrics import roc_curve, auc, roc_auc_score, RocCurveDisplay
from sklearn.preprocessing import label_binarize
import random

model = tf.keras.models.load_model('G:/Myxoid lesion/Myxoid_EN3_finetune4b')

model.summary()

data_dir='G:/Myxoid lesion/Test/'

batch_size = 64
img_height = 300
img_width = 300

test_ds = tf.keras.preprocessing.image_dataset_from_directory(
  data_dir,
  seed = 123,
  image_size=(img_height, img_width),
  batch_size=batch_size)

model.compile(optimizer = optimizers.Adam(lr=0.00002),
              loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics = ['sparse_categorical_accuracy'])

correct =  np.array([], dtype='int32')

# Get the labels of test_ds
for x, y in test_ds:
    correct = np.concatenate([correct, y.numpy()])

# Get the prediction probabilities for each class for each test image
prediction_prob = tf.nn.softmax(model.predict(test_ds))

num_class = 5
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(num_class):
    fpr[i], tpr[i], _ = roc_curve(correct, prediction_prob[:,i], pos_label=i)
    roc_auc[i] = auc(fpr[i], tpr[i])

plt.figure()
lw = 2
for i in range(num_class):
    plt.plot(fpr[i],tpr[i],
             color=(random.random(),random.random(),random.random()),
             label='{0} (AUC = {1:0.2f})'''.format(labels[i], roc_auc[i]))
plt.plot([0, 1], [0, 1], 'k--', lw=lw)
plt.legend(loc="lower right")
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')

plt.title('ROC analysis')

plt.show()

“prediction_prob”变量包含:

array([[6.3877934e-09, 6.3617526e-06, 5.5736535e-07, 4.9789862e-05,
        9.9994326e-01],
       [6.5260068e-08, 8.8882577e-03, 3.9350948e-06, 9.9110776e-01,
        4.0252076e-11],
       [2.7514220e-04, 2.9315910e-05, 1.6688553e-04, 9.9952865e-01,
        3.5938730e-10],
       ...,
       [1.1131389e-09, 9.8325908e-01, 3.4283744e-06, 1.6737511e-02,
        7.3243338e-12],
       [1.4697845e-08, 4.7125661e-05, 1.4077022e-03, 6.4052530e-02,
        9.3449265e-01],
       [9.9999940e-01, 1.3071107e-07, 4.3149896e-07, 4.7902233e-08,
        9.2861301e-09]], dtype=float32)>

虽然“正确”变量包含每个测试图像的正确标签:

array([0, 1, 4, ..., 4, 2, 4])

我想我遵循scikit-learn 网站上提到的内容。

生成的 tpr[i] 和 fpr[i] 变量变为线性相关,因此 AUC 变为 0.5

我觉得生成tpr[i]和fpr[i]有问题?谁能解决这个问题?

谢谢!

【问题讨论】:

  • 代码应该没问题.. 只是查看你的输出中的概率和真实标签,它们不相符,你能用混淆矩阵检查吗?
  • 看来model.predict会给出与原始test_ds不同顺序的结果。
  • 如果我在 test_ds 中对 x, y 使用 prediction = np.array([]):prediction = np.concatenate([prediction, tf.nn.softmax(model.predict(x))] ) 它将返回:ValueError:所有输入数组必须具有相同的维数,但索引 0 处的数组有 1 个维度,而索引 1 处的数组有 2 个维度
  • 有没有办法获取test_ds的标签?我使用: for x, y in test_ds: correct = np.concatenate([correct, y.numpy()])
  • @StupidWolf 混乱是可以的。当我按如下方式生成预测和标签时:for x, y in test_ds: prediction = np.concatenate([prediction, np.argmax(tf.nn.softmax(model.predict(x)), axis=-1)] ) 正确 = np.concatenate([正确, y.numpy()])

标签: tensorflow scikit-learn roc multiclass-classification


【解决方案1】:

如果我以这种方式生成标签和预测,那么我可以得到正确的ROC曲线:

prediction_prob = np.array([]).reshape(0,5)
correct =  np.array([], dtype='int32')

for x, y in test_ds:
    correct = np.concatenate([correct, y.numpy()])
    prediction_prob = np.vstack([prediction_prob, tf.nn.softmax(model.predict(x))])

但是,如果我从 model.predict(test_ds) 获得预测,则预测的顺序与原始数据集不同,因此它与原始标签不匹配。我不确定这是否是 tensorflow 中的“错误”,或者对此有其他解释。

我也无法获得微平均(尽管这对我的目标并不重要)

fpr["micro"], tpr["micro"], _ = roc_curve(correct.ravel(), prediction_prob.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

它给出了以下错误:

raise ValueError("{0} format is not supported".format(y_type))
ValueError: multiclass format is not supported

【讨论】:

    猜你喜欢
    • 2016-08-06
    • 2021-03-03
    • 2021-03-08
    • 2012-07-10
    • 2018-12-19
    • 2018-12-24
    • 2012-09-04
    • 2022-01-13
    • 2014-09-27
    相关资源
    最近更新 更多