【发布时间】:2019-01-19 12:34:53
【问题描述】:
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
torch.save(state, filename)
if is_best:
shutil.copyfile(filename, 'model_best.pth.tar')
save_checkpoint({
'epoch': epoch + 1,
'arch': args.arch,
'state_dict': model.state_dict(),
'best_prec1': best_prec1,
'optimizer': optimizer.state_dict()
}, is_best)
我正在这样保存我的模型。如何加载模型以便在其他地方使用它,例如 cnn 可视化?
这就是我现在加载模型的方式:
torch.load('model_best.pth.tar')
但是当我这样做时,我得到了这个错误:
AttributeError: 'dict' 对象没有属性 'eval'
我在这里错过了什么???
编辑:我想使用我训练的模型来可视化过滤器和毕业生。我正在使用这个repo 进行可视化。我将第 179 行替换为 torch.load('model_best.pth.tar')
【问题讨论】:
-
当我像你一样保存时没有收到错误。你的 pytorch 版本是多少?
-
@SalihKaragoz pytorch 版本:0.4.1
-
它与我存储模型的方式有关吗?我的自定义字典?
-
我认为您应该提供更多信息。没有语法错误。您是否尝试加载 multigpus 或类似的东西?
-
不,不。只是尝试加载模型,以便我可以测试它,然后我想可视化毕业生和过滤器。
标签: python deep-learning conv-neural-network pytorch