【问题标题】:pytorch doesn't give expected outputpytorch 没有给出预期的输出
【发布时间】:2018-09-18 06:24:41
【问题描述】:

首先,一堆数据被CNN模型分类。然后,我尝试对第一步中正确分类的数据进行预测,预计准确度为 100%。但是,我发现结果不稳定,有时是 99+%,但不是 100%。有人知道我的代码有什么问题吗?提前多谢了,困扰我好几天了~~

torch.版本

'0.3.1.post2'

import numpy as np
import torch 
import torch.nn as nn
from torch.autograd import Variable

n = 2000
data = np.random.randn(n, 1, 10, 10)
label = np.random.randint(2, size=(n, ))

def test_pred(model, data_test, label_test):

    data_batch = data_test
    labels_batch = label_test

    images = torch.autograd.Variable(torch.FloatTensor(data_batch))
    labels = torch.autograd.Variable(torch.FloatTensor(labels_batch))

    outputs = model(images)

    _, predicted = torch.max(outputs.data, 1)

    correct = (np.array(predicted) == labels_batch).sum()

    label_pred = np.array(predicted)

    acc = correct/len(label_test)
    print(" acc:", acc)

    return acc, label_pred

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2))
        self.fc = nn.Linear(128, 2)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

cnn = CNN()

[_, label_pred] = test_pred(cnn, data, label)

print("Acc:", np.mean(label_pred==label))
# Given the correctly classified data in previous step, expect to get 100% accuracy
# Why it sometimes doesn't give a 100% accuracy ?
print("Using selected data size {}:".format(data[label_pred==label].shape))
_, _ = test_pred(cnn, data[label_pred==label], label[label_pred==label])

输出:

acc:0.482

加速度:0.482

使用选定的数据大小(964、1、10、10):

acc:0.9979253112033195

【问题讨论】:

    标签: pytorch


    【解决方案1】:

    似乎您没有将网络设置为评估模式,这可能会导致一些问题,尤其是 BatchNorm 层。做

    cnn = CNN()
    cnn.eval()
    

    它应该可以工作。

    【讨论】:

    • 就是这样,我试了一下,完美解决了问题。非常感谢。
    猜你喜欢
    • 2019-09-07
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2018-02-22
    • 2019-07-16
    • 1970-01-01
    相关资源
    最近更新 更多