【问题标题】:Can't load trained PyTorch model无法加载经过训练的 PyTorch 模型
【发布时间】:2022-02-10 06:04:58
【问题描述】:

我已经在自定义数据集上训练了 ResNet152。 当我尝试以这种方式加载它时:

trained_model = torch.nn.Module.load_state_dict(torch.load('/content/drive/My Drive/X-Ray-pneumonia-with-CV/X-ray-pytorch-model.pth'))
trained_model.eval()

我收到一个错误: RuntimeError:尝试反序列化 CUDA 设备上的对象,但 torch.cuda.is_available() 为 False。如果您在仅 CPU 的机器上运行,请使用带有 map_location=torch.device('cpu') 的 torch.load 将您的存储映射到 CPU。

当我添加 map_location 时:

trained_model = torch.nn.Module.load_state_dict(torch.load('/content/drive/My Drive/X-Ray-pneumonia-with-CV/X-ray-pytorch-model.pth',
                                                           map_location = torch.device('cpu')))
trained_model.eval()

我遇到了另一个错误: TypeError: load_state_dict() 缺少 1 个必需的位置参数:'state_dict'

那么我做错了什么?请帮忙

【问题讨论】:

    标签: pytorch


    【解决方案1】:

    而不是调用torch.nn.Module.load_state_dict,您应该首先实例化您要加载的模块类的对象。否则,load_state_dict 的参数 self 不会绑定到任何东西。这样,您通过torch.load 加载的状态字典将作为self 而不是state_dict 传递。看看this 的回答就明白其中的区别了。

    【讨论】:

    • 添加了trained_model = models.resnet152() 并且仍然不起作用
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2021-04-05
    • 2017-07-28
    • 2020-02-07
    • 2019-07-17
    • 2019-11-06
    • 2018-04-21
    • 2021-07-16
    相关资源
    最近更新 更多