【发布时间】:2021-11-21 01:57:13
【问题描述】:
您好,我正在阅读有关迁移学习的 pytorch 教程。 (https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html)
model.training 是干什么用的?
enter def visualize_model(model,num_images=6):
was_training=model.training
model.eval()
images_so_far=0
fig=plt.figure()
with torch.no_grad():
for i, (inputs,labels) in enumerate(dataloaders['val']):
inputs=inputs.to(device)
labels=labels.to(device)
outputs=model(inputs)
_,pred=torch.max(outputs,1)
for j in range(inputs.size()[0]):
images_so_far+=1
ax=plt.subplot(num_images//2,2,images_so_far)
ax.axis('off')
ax.set_title('predicted: {}'.format(class_names[preds[j]]))
imshow(inputs.cpu().data[j])
if images_so_far==num_images:
model.train(mode=was_training)
return
model.train(mode=was_training)code here
我无法理解“model.train(model=was_training)”。有什么帮助吗??非常感谢
【问题讨论】:
-
这能回答你的问题吗? What does model.train() do in PyTorch?
-
哦,谢谢!但现在我想知道他们为什么在测试会话中使用 model.train。为什么他们把代码放在“with torch.no_grad()”里面?? was_training=false 不是很明显吗??
标签: pytorch