【发布时间】:2016-08-24 16:09:41
【问题描述】:
我正在尝试根据 tensorflow 提供的 CIFAR-10 示例的修改版本绘制 ROC 曲线。现在是 2 个班级而不是 10 个班级。
网络的输出称为 logits 并采用以下形式:
[[-2.57313061 2.57966399] [ 0.04221377 -0.04033273] [-1.42880082 1.43337202] [-2.7692945 2.78173304] [-2.48195744 2.49331546] [2.0941515 -2.10268974] [-3.51670194 3.53267646] [-2.74760485 2.75617]
首先,这些logits实际上代表什么?网络中的最后一层是 WX+b 形式的“softmax linear”。
模型可以通过调用计算准确率
top_k_op = tf.nn.in_top_k(logits, labels, 1)
然后一旦图被初始化:
predictions = sess.run([top_k_op])
predictions_int = np.array(predictions).astype(int)
true_count += np.sum(predictions)
...
precision = true_count / total_sample_count
这很好用。
但是现在我如何从中绘制 ROC 曲线?
我一直在尝试“sklearn.metrics.roc_curve()”函数 (http://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_curve.html#sklearn.metrics.roc_curve),但我不知道用什么作为“y_score”参数。
任何帮助将不胜感激!
【问题讨论】:
-
查看此处link 获取计算和绘制 ROC 曲线的代码。
标签: python scikit-learn tensorflow roc