【发布时间】:2020-03-23 15:18:27
【问题描述】:
我正在为图像语义分割计算如下所示的混淆矩阵,这是一种非常冗长的方法:
def confusion_matrix(preds, labels, conf_m, sample_size):
preds = normalize(preds,0.9) # returns [0,1] tensor
preds = preds.flatten()
labels = labels.flatten()
for i in range(len(preds)):
if preds[i]==1 and labels[i]==1:
conf_m[0,0] += 1/(len(preds)*sample_size) # TP
elif preds[i]==1 and labels[i]==0:
conf_m[0,1] += 1/(len(preds)*sample_size) # FP
elif preds[i]==0 and labels[i]==0:
conf_m[1,0] += 1/(len(preds)*sample_size) # TN
elif preds[i]==0 and labels[i]==1:
conf_m[1,1] += 1/(len(preds)*sample_size) # FN
return conf_m
在预测循环中:
conf_m = torch.zeros(2,2) # two classes (object or no-object)
for img,label in enumerate(data):
...
out = Net(img)
conf_m = confusion_matrix(out, label, len(data))
...
是否有更快的方法(在 PyTorch 中)来有效地计算图像语义分割输入样本的混淆矩阵?
【问题讨论】:
标签: python-3.x pytorch metrics confusion-matrix