【问题标题】:Pytorch MNIST code is returning IndexErrorPytorch MNIST 代码返回 IndexError
【发布时间】:2020-07-21 20:39:31
【问题描述】:

我遵循了 Pytorch 文档,并为 MNIST 数据集制作了一个非常简单的分类器。以下是我的代码:

import numpy as np

import torch
import torchvision
from torchvision import transforms, datasets

import torch.nn as nn
import torch.nn.functional as F

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
    ])

train = datasets.MNIST('', train=True, download=True, transform=transform)
test = datasets.MNIST('', train=False, download=True, transform=transform)
trainset = torch.utils.data.DataLoader(train, batch_size=1, shuffle=True)
testset = torch.utils.data.DataLoader(test, batch_size=1, shuffle=False)

class Classifier(nn.Module):
    def __init__(self, D_in, H, D_out):
        super(Classifier, self).__init__()
        self.linear_1 = torch.nn.Linear(D_in, H)
        self.linear_2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        x = self.linear_1(x).clamp(min=0)
        x = self.linear_2(x)
        return F.log_softmax(x, dim=1)


net = Classifier(28*28, 128, 10)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

for epoch in range(3):
    running_loss = 0.0
    for X, label in iter(trainset):
        X = X.view(28*28, -1)

        optimizer.zero_grad()

        output = net(torch.flatten(X))
        loss = nn.CrossEntropyLoss(output, label)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 2000}')
            running_loss = 0.0
print("Finished training.")

torch.save(net.state_dict(), './classifier.pth')

由于某种原因,我得到了输出

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

在线:output = net(torch.flatten(X)

提前感谢您的帮助!

【问题讨论】:

    标签: python pytorch


    【解决方案1】:

    flatten() 你删除所有维度包括批量维度!

    试试:

    output = net(x.view(x.shape[0], -1))
    

    【讨论】:

    • 谢谢,这行得通。对于阅读本文的任何人,我还必须删除行 X = X.view(784, -1)
    猜你喜欢
    • 2019-08-01
    • 2019-07-23
    • 2016-04-21
    • 2020-12-16
    • 1970-01-01
    • 2021-06-14
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多