【发布时间】:2019-07-19 10:13:42
【问题描述】:
我得到了混淆矩阵,但由于我的实际数据集有很多分类类别,所以很难理解。
例子-
>>> from sklearn.metrics import confusion_matrix
>>> y_test
['a', 'a', 'b', 'c', 'd', 'd', 'e', 'a', 'c']
>>> y_pred
['b', 'a', 'b', 'c', 'a', 'd', 'e', 'a', 'c']
>>>
>>>
>>> confusion_matrix(y_test, y_pred)
array([[2, 1, 0, 0, 0],
[0, 1, 0, 0, 0],
[0, 0, 2, 0, 0],
[1, 0, 0, 1, 0],
[0, 0, 0, 0, 1]], dtype=int64)
但是如何打印标签/列名以便更好地理解?
我什至试过这个 -
>>> pd.factorize(y_test)
(array([0, 0, 1, 2, 3, 3, 4, 0, 2], dtype=int64), array(['a', 'b', 'c', 'd', 'e'], dtype=object))
>>> pd.factorize(y_pred)
(array([0, 1, 0, 2, 1, 3, 4, 1, 2], dtype=int64), array(['b', 'a', 'c', 'd', 'e'], dtype=object))
有什么帮助吗?
【问题讨论】:
-
The docs for
confusion_matrix在labels参数下说:如果没有给出,则在y_true或y_pred中至少出现一次的那些将按排序顺序使用。 所以标签就是a b c d e,按顺序排列。
标签: classification python scikit-learn confusion-matrix