【Posted】: 2019-12-23 17:59:32
【Question】:
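How do I correctly write the code that plots the predicted values of this linear regression model?
I am learning linear regression with this tutorial: https://www.deeplearningwizard.com/deep_learning/practical_pytorch/pytorch_linear_regression/
I managed to get training running on the GPU. My problem is plotting the predicted values. I searched for a way to handle values that are still tensors, but I don't seem to know the syntax for it.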
Starting from here:
```python
epochs = 100
for epoch in range(epochs):
    epoch += 1
    # Convert the numpy arrays to torch tensors (on the GPU if one is available)
    if torch.cuda.is_available():
        inputs = (torch.from_numpy(x_train).cuda())
        labels = (torch.from_numpy(y_train).cuda())
    else:
        inputs = (torch.from_numpy(x_train))
        labels = (torch.from_numpy(y_train))
    # Clear gradients w.r.t. parameters
    optimizer.zero_grad()
    # Forward to get output
    outputs = model(inputs)
    # Calculate Loss
    loss = criterion(outputs, labels)
    # Getting gradients w.r.t. parameters
    loss.backward()
    # Updating parameters
    optimizer.step()
    # Logging
    print('epoch {}, loss {}'.format(epoch, loss.item()))
```
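The loop above rebuilds `inputs` and `labels` from NumPy on every epoch and branches on `torch.cuda.is_available()` inside the loop. A minimal device-agnostic sketch of the same loop follows. It assumes a setup in the spirit of the tutorial: the toy data (y = 2x + 1), the `nn.Linear(1, 1)` model, `nn.MSELoss` and `torch.optim.SGD` with `lr=0.01` are stand-ins chosen for illustration, not the poster's exact definitions. Selecting a `torch.device` once lets the same code run on either GPU or CPU.

```python
import numpy as np
import torch
import torch.nn as nn

torch.manual_seed(0)  # reproducible initial weights

# Assumption: toy data y = 2x + 1 as float32 column vectors, shape (11, 1).
x_train = np.arange(11, dtype=np.float32).reshape(-1, 1)
y_train = (2 * x_train + 1).astype(np.float32)

# Use the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-ins for the tutorial's model, loss function and optimizer.
model = nn.Linear(1, 1).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Convert the NumPy arrays once, outside the loop, and move them to the device.
inputs = torch.from_numpy(x_train).to(device)
labels = torch.from_numpy(y_train).to(device)

epochs = 100
for epoch in range(1, epochs + 1):     # 1-based epochs, no `epoch += 1` needed
    optimizer.zero_grad()              # clear gradients from the previous step
    outputs = model(inputs)            # forward pass
    loss = criterion(outputs, labels)  # mean squared error
    loss.backward()                    # compute gradients
    optimizer.step()                   # update weight and bias
    if epoch % 20 == 0:
        print('epoch {}, loss {}'.format(epoch, loss.item()))
```

Converting the arrays once avoids a host-to-device copy on every epoch, and `range(1, epochs + 1)` gives the same 1-based numbering as the `epoch += 1` line.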
This is where the prediction is made; I chose to use CUDA:
```python
predicted = model(Variable(torch.from_numpy(x_train).requires_grad_().cuda()))
print("Predicted")
print(predicted)
print("Output")
print(y_train)
plt.clf()
# Get predictions
#predicted = model(Variable(torch.from_numpy(x_train).requires_grad_().cuda()))
# Plot true data
plt.plot(x_train, y_train, 'go', label='True data', alpha=0.5)
```
This is the call that raises the error, so nothing gets plotted:
```python
# Plot predictions
plt.plot(x_train, predicted, '--', label='Predictions', alpha=0.5)
# Legend and plot
plt.legend(loc='best')
plt.show()
```
Resulting error:
```
Traceback (most recent call last):
  File "D:/Test with GPU/Linear regression.py", line 101, in <module>
    plt.plot(x_train, predicted, '--', label='Predictions', alpha=0.5)
  File "D:\Anaconda3\envs\gputest\lib\site-packages\matplotlib\pyplot.py", line 2795, in plot
    is not None else {}), **kwargs)
  File "D:\Anaconda3\envs\gputest\lib\site-packages\matplotlib\axes\_axes.py", line 1666, in plot
    lines = [*self._get_lines(*args, data=data, **kwargs)]
  File "D:\Anaconda3\envs\gputest\lib\site-packages\matplotlib\axes\_base.py", line 225, in __call__
    yield from self._plot_args(this, kwargs)
  File "D:\Anaconda3\envs\gputest\lib\site-packages\matplotlib\axes\_base.py", line 391, in _plot_args
    x, y = self._xy_from_xy(x, y)
  File "D:\Anaconda3\envs\gputest\lib\site-packages\matplotlib\axes\_base.py", line 271, in _xy_from_xy
    if x.ndim > 2 or y.ndim > 2:
AttributeError: 'Tensor' object has no attribute 'ndim'
```
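The traceback shows matplotlib receiving a `torch.Tensor` where it expects array-like data such as a NumPy array. In the code above, `predicted` also lives on the GPU and is tracked by autograd, so it cannot be turned into a NumPy array directly: it has to be copied to host memory with `.cpu()` and, when it requires grad, detached first. A minimal sketch of one common fix follows, assuming the same toy data and an untrained stand-in `nn.Linear(1, 1)` (in practice you would use the trained `model`): run inference under `torch.no_grad()`, then call `.cpu().numpy()`. The Agg backend and the file name `predictions.png` are assumptions so the sketch also runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; drop this line for interactive use
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn

# Assumption: same toy data as before and an untrained stand-in model.
x_train = np.arange(11, dtype=np.float32).reshape(-1, 1)
y_train = (2 * x_train + 1).astype(np.float32)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(1, 1).to(device)

# Inference only: no_grad() stops autograd from tracking this forward pass,
# so neither requires_grad_() nor Variable is needed.
with torch.no_grad():
    predicted = model(torch.from_numpy(x_train).to(device))

# Copy the result back to host memory and convert it to a NumPy array.
predicted = predicted.cpu().numpy()

plt.clf()
plt.plot(x_train, y_train, 'go', label='True data', alpha=0.5)
plt.plot(x_train, predicted, '--', label='Predictions', alpha=0.5)
plt.legend(loc='best')
plt.savefig('predictions.png')  # hypothetical file name; call plt.show() interactively
```

If you keep the original `predicted = model(...)` line unchanged, `predicted.detach().cpu().numpy()` yields the same kind of NumPy array before plotting.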
【Discussion】:
Tags: python numpy matplotlib pytorch