【发布时间】:2019-01-07 09:47:42
【问题描述】:
我根据教程构建了一个简单的网络,但出现此错误:
RuntimeError:预期的类型为 torch.cuda.FloatTensor 的对象,但已找到 为参数 #4 'mat1' 输入 torch.FloatTensor
有什么帮助吗?谢谢!
import torch
import torchvision
device = torch.device("cuda:0")
root = '.data/'
dataset = torchvision.datasets.MNIST(root, transform=torchvision.transforms.ToTensor(), download=True)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=4)
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.out = torch.nn.Linear(28*28, 10)
def forward(self, x):
x = x.view(x.size(0), -1)
x = self.out(x)
return x
net = Net()
net.to(device)
for i, (inputs, labels) in enumerate(dataloader):
inputs.to(device)
out = net(inputs)
【问题讨论】:
-
这个问题是关于 PyTorch,而不是 CUDA。这就是我删除标签的原因。两次。请不要再添加了
标签: python image-processing machine-learning deep-learning pytorch