【发布时间】:2018-11-09 02:23:11
【问题描述】:
我有一个包含我的预测的张量和一个包含我的二进制分类问题的实际标签的张量。如何有效地计算混淆矩阵?
【问题讨论】:
标签: pytorch
我有一个包含我的预测的张量和一个包含我的二进制分类问题的实际标签的张量。如何有效地计算混淆矩阵?
【问题讨论】:
标签: pytorch
在我使用 for 循环的第一个版本被证明效率低下后,这是我迄今为止提出的最快的解决方案,用于两个等维张量 prediction 和 truth:
def confusion(prediction, truth):
confusion_vector = prediction / truth
true_positives = torch.sum(confusion_vector == 1).item()
false_positives = torch.sum(confusion_vector == float('inf')).item()
true_negatives = torch.sum(torch.isnan(confusion_vector)).item()
false_negatives = torch.sum(confusion_vector == 0).item()
return true_positives, false_positives, true_negatives, false_negatives
https://gist.github.com/the-bass/cae9f3976866776dea17a5049013258d@https://gist.github.com/the-bass/cae9f3976866776dea17a5049013258d评论版本和测试用例
【讨论】: