【问题标题】:roc_curve in sklearn: why doesn't it work correctly?sklearn 中的 roc_curve:为什么它不能正常工作?
【发布时间】:2019-11-21 08:29:35
【问题描述】:

我正在解决多类分类的任务,并希望使用 sklearn 中的 roc 曲线来估计结果。据我所知,如果我设置了一个正标签,它允许在这种情况下绘制一条曲线。 我尝试使用正标签绘制 roc 曲线并得到奇怪的结果:类的“正标签”越大,roc 曲线越靠近左上角。 然后我用数组的先前二进制标记绘制一条 roc 曲线。这2个地块是不同的!我认为第二个是正确构建的,但在二元类的情况下,情节只有 3 个点,这不是信息。

我想了解,为什么二进制类的 roc 曲线和带有“正标签”的 roc 曲线看起来不同,以及如何正确绘制带有正标签的 roc 曲线。

代码如下:

from sklearn.metrics import roc_curve, auc
y_pred = [1,2,2,2,3,3,1,1,1,1,1,2,1,2,3,2,2,1,1]
y_test = [1,3,2,2,1,3,2,1,2,2,1,2,2,2,1,1,1,1,1]
fp, tp, _ = roc_curve(y_test, y_pred, pos_label = 2)
from sklearn.preprocessing import label_binarize
y_pred = label_binarize(y_pred, classes=[1, 2, 3])
y_test = label_binarize(y_test, classes=[1, 2, 3])
fpb, tpb, _b = roc_curve(y_test[:,1], y_pred[:,1])
plt.plot(fp, tp, 'ro-', fpb, tpb, 'bo-', alpha = 0.5)
plt.show()
print('AUC with pos_label', auc(fp,tp))
print('AUC binary variant', auc(fpb,tpb))

这是example of the plot

红色曲线表示带有pos_label的roc_curve,蓝色曲线表示“二进制情况”下的roc_curve

【问题讨论】:

  • 似乎您处于多类设置(超过 2 个类),而不是多标签设置(单个实例可以属于多个类) - 已编辑问题和标签。
  • @desertnaut 你是对的,我有 3 个不同的课程。据我所知,如果我有很多类,设置 pos_label 将允许构建“one vs all”曲线。与 2 个二元类的 roc 曲线相同。
  • 一般来说,请记住 ROC 曲线需要y_pred 中的概率预测,而不是“硬”类。
  • @desertnaut 这种情况无疑更适合 ROC 曲线,但它写在 sklearn 文档中,我们可以使用非阈值度量:“目标分数,可以是正类的概率估计,置信度值或决策的非阈值度量(由某些分类器上的“decision_function”返回)”
  • 是的 - 重点是非阈值;但是您的 y_pred 确实是阈值,因此与您列出的措施相比,它提供了“硬”类。

标签: python scikit-learn roc multiclass-classification auc


【解决方案1】:

正如 cmets 中所解释的,ROC 曲线适合评估 阈值 预测(即硬类),如您的 y_pred;此外,在使用 AUC 时,记住一些对许多从业者来说并不明显的限制是很有用的 - 有关更多详细信息,请参阅 Getting a low ROC AUC score but a high accuracy 中自己的答案的最后部分。

您能否给我一些建议,我可以使用哪些指标来评估这种具有“硬”类的多类分类的质量?

最直接的方法是混淆矩阵和 scikit-learn 提供的分类报告:

from sklearn.metrics import confusion_matrix, classification_report

y_pred = [1,2,2,2,3,3,1,1,1,1,1,2,1,2,3,2,2,1,1]
y_test = [1,3,2,2,1,3,2,1,2,2,1,2,2,2,1,1,1,1,1]

print(classification_report(y_test, y_pred)) # caution - order of arguments matters!
# result:
             precision    recall  f1-score   support

          1       0.56      0.56      0.56         9
          2       0.57      0.50      0.53         8
          3       0.33      0.50      0.40         2

avg / total       0.54      0.53      0.53        19

cm = confusion_matrix(y_test, y_pred) # again, order of arguments matters
cm
# result:
array([[5, 2, 2],
       [4, 4, 0],
       [0, 1, 1]], dtype=int64)

从混淆矩阵中,您可以提取其他感兴趣的数量,例如每类的真假阳性等 - 有关详细信息,请参阅How to get precision, recall and f-measure from confusion matrix in Python中的答案

【讨论】:

  • 谢谢!至于地块——precision_recall_curve 能否具有代表性,否则我将面临与 ROC 曲线相同的问题?
  • @svetlana 不客气;精确召回曲线也需要非阈值预测
猜你喜欢
  • 2017-08-11
  • 1970-01-01
  • 2016-07-16
  • 2019-01-04
  • 2020-09-03
  • 2016-10-10
  • 2016-10-24
  • 2017-02-27
  • 2017-07-08
相关资源
最近更新 更多