首先贴一份在 CPU 上运行的代码:
# Train and evaluate a fully connected MNIST classifier on the CPU.
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms

batch_size = 64

# Convert images to tensors, then normalize the single channel.
# NOTE(review): 0.1307 / 0.3081 are presumably the MNIST training-set
# mean and standard deviation -- confirm before reusing on other data.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

train_dataset = datasets.MNIST(
    root='../dataset/mnist/', train=True, download=True, transform=transform
)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)

test_dataset = datasets.MNIST(
    root='../dataset/mnist/', train=False, download=True, transform=transform
)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)


class Net(torch.nn.Module):
    """Five-layer fully connected network: 784 -> 512 -> 256 -> 128 -> 64 -> 10."""

    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(784, 512)
        self.l2 = torch.nn.Linear(512, 256)
        self.l3 = torch.nn.Linear(256, 128)
        self.l4 = torch.nn.Linear(128, 64)
        self.l5 = torch.nn.Linear(64, 10)

    def forward(self, x):
        """Return the 10 raw class scores (logits) for each input row.

        The input is reshaped to rows of 784 features before the first
        layer. No activation is applied to the last layer because the
        loss used below (``CrossEntropyLoss``) expects unnormalized logits.
        """
        x = x.view(-1, 784)
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        x = F.relu(self.l3(x))
        x = F.relu(self.l4(x))
        return self.l5(x)


model = Net()

criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)


def train(epoch, log_interval=300):
    """Run one pass over ``train_loader`` and update ``model`` in place.

    Args:
        epoch: Zero-based epoch index; only used in the progress output.
        log_interval: Number of batches between progress lines. The
            printed loss is the mean over the batches since the last line.
    """
    # Make sure layers with mode-dependent behavior (dropout, batch norm)
    # are in training mode, since test() switches the model to eval mode.
    model.train()
    running_loss = 0.0
    for batch_idx, (inputs, target) in enumerate(train_loader):
        optimizer.zero_grad()

        # forward + backward + update
        outputs = model(inputs)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if batch_idx % log_interval == log_interval - 1:
            print('[%d, %5d] loss: %.3f'
                  % (epoch + 1, batch_idx + 1, running_loss / log_interval))
            running_loss = 0.0


def test():
    """Print the classification accuracy of ``model`` over ``test_loader``."""
    # Evaluation mode: a no-op for the current layers, but required for
    # correct results if dropout or batch norm is ever added to Net.
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for images, labels in test_loader:
            outputs = model(images)
            # Index of the largest logit along dim 1 is the predicted class.
            _, predicted = torch.max(outputs, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('Accuracy on test set: %d %%' % (100 * correct / total))


if __name__ == '__main__':
    for epoch in range(10):
        train(epoch)
        test()