【问题标题】:Faster method of computing confusion matrix?计算混淆矩阵的更快方法?
【发布时间】:2020-03-23 15:18:27
【问题描述】:

我正在为图像语义分割计算如下所示的混淆矩阵,这是一种非常冗长的方法:

def confusion_matrix(preds, labels, conf_m, sample_size):
    preds = normalize(preds,0.9) # returns [0,1] tensor
    preds = preds.flatten()
    labels = labels.flatten()
    for i in range(len(preds)):
        if preds[i]==1 and labels[i]==1:
            conf_m[0,0] += 1/(len(preds)*sample_size) # TP
        elif preds[i]==1 and labels[i]==0:
            conf_m[0,1] += 1/(len(preds)*sample_size) # FP
        elif preds[i]==0 and labels[i]==0:
            conf_m[1,0] += 1/(len(preds)*sample_size) # TN
        elif preds[i]==0 and labels[i]==1:
            conf_m[1,1] += 1/(len(preds)*sample_size) # FN 
    return conf_m

在预测循环中:

conf_m = torch.zeros(2,2) # two classes (object or no-object)
for img,label in enumerate(data):
    ...
    out = Net(img)
    conf_m = confusion_matrix(out, label, len(data))
    ...

是否有更快的方法(在 PyTorch 中)来有效地计算图像语义分割输入样本的混淆矩阵?

【问题讨论】:

    标签: python-3.x pytorch metrics confusion-matrix


    【解决方案1】:

    我使用这两个函数来计算混淆矩阵(在sklearn 中定义):

    # rewrite sklearn method to torch
    def confusion_matrix_1(y_true, y_pred):
        N = max(max(y_true), max(y_pred)) + 1
        y_true = torch.tensor(y_true, dtype=torch.long)
        y_pred = torch.tensor(y_pred, dtype=torch.long)
        return torch.sparse.LongTensor(
            torch.stack([y_true, y_pred]), 
            torch.ones_like(y_true, dtype=torch.long),
            torch.Size([N, N])).to_dense()
    
    # weird trick with bincount
    def confusion_matrix_2(y_true, y_pred):
        N = max(max(y_true), max(y_pred)) + 1
        y_true = torch.tensor(y_true, dtype=torch.long)
        y_pred = torch.tensor(y_pred, dtype=torch.long)
        y = N * y_true + y_pred
        y = torch.bincount(y)
        if len(y) < N * N:
            y = torch.cat(y, torch.zeros(N * N - len(y), dtype=torch.long))
        y = y.reshape(N, N)
        return y
    
    y_true = [2, 0, 2, 2, 0, 1]
    y_pred = [0, 0, 2, 2, 0, 2]
    
    confusion_matrix_1(y_true, y_pred)
    # tensor([[2, 0, 0],
    #         [0, 0, 1],
    #         [1, 0, 2]])
    
    

    在类数量较少的情况下,第二个函数更快。

    %%timeit
    confusion_matrix_1(y_true, y_pred)
    # 102 µs ± 30.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
    
    %%timeit
    confusion_matrix_2(y_true, y_pred)
    # 25 µs ± 149 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
    

    【讨论】:

    • 在混淆矩阵 2 中返回 y 而不是 T
    【解决方案2】:

    感谢Grigory Feldman 的回答!我不得不改变一些事情来配合我的实现。

    对于未来的观察者,这是我的最终函数,它总结了每个混淆矩阵在一批输入中的百分比(用于训练或测试循环)

    def confusion_matrix_2(y_true, y_pred, sample_sz, conf_m):
        y_pred = normalize(y_pred,0.9)
        obj = y_true[y_true==1]
        no_obj = y_true[y_true==0]
        N = torch.tensor(torch.max(torch.max(y_true), torch.max(y_pred)) + 1,dtype=torch.int)
        y_true = torch.tensor(y_true, dtype=torch.long)
        y_pred = torch.tensor(y_pred, dtype=torch.long)
        y = N * y_true + y_pred
        y = torch.bincount(y.flatten())
        if len(y) < N * N:
            y = torch.cat((y, torch.zeros(N * N - len(y), dtype=torch.long)))
        y = y.reshape(N.item(), N.item())
        y = y.float()
        conf_m[0,:] += y[0,:]/(len(no_obj)*sample_sz)
        conf_m[1,:] += y[1,:]/(len(obj)*sample_sz)
        return conf_m
    
    ...
    conf_m = torch.zeros((2, 2),dtype=torch.float) # two classes (object or no-object)
    for _, data in enumerate(dataloader):
        for img,label in enumerate(data):
            ...
            out = Net(img)
            conf_m = confusion_matrix(out, label, len(data))
            ...
        ...
    

    【讨论】:

    • normalize 未定义。你是说y_pred = (ypred&gt;0.9).float() 吗?
    • 加上这个实现提供了 NaN。
    【解决方案3】:

    感谢Grigory Feldman的答案!
    o先生和我用numpy制作。

    # weird trick with bincount
    def confusion_matrix_2_numpy(y_true, y_pred, N=None):
        y_true = y_true.reshape(-1)
        y_pred = y_pred.reshape(-1) 
        if (N is None):
            N = max(max(y_true), max(y_pred)) + 1
        y = N * y_true + y_pred
        y = np.bincount(y, minlength=N*N)
        y = y.reshape(N, N)
        return y
    

    请尝试。
    当您使用已知的class_num时,它可能会更快。
    我应该提到的一件事是,最大值不匹配原始类别的情况。
    例如,当批量大小很小并且对每个迭代集成了混淆矩阵。

    【讨论】:

      猜你喜欢
      • 2018-04-01
      • 2017-02-25
      • 2020-07-09
      • 2020-05-17
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多