【问题标题】:Problem loading Pytourch 3.0 model unexpected key "module.features.0.weight" in state_dict在 state_dict 中加载 Pytourch 3.0 模型意外键“module.features.0.weight”时出现问题
【发布时间】:2019-03-27 10:11:46
【问题描述】:

我正在尝试加载我使用 Pytorch 训练过的模型, 但我不断收到以下错误:

文件“convert.py”,第 12 行,在 model.load_state_dict(torch.load('model/model_vgg2d_2.pth')) 文件 "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py", 第 490 行,在 load_state_dict .format(name)) KeyError: 'state_dict 中的意外键“module.features.0.weight”'

下面是我的代码:

import torch.onnx
import torch.nn as nn

class TempModel(nn.Module):
    def __init__(self):
        super(TempModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 5, (3, 3))
    def forward(self, inp):
        return self.conv1(inp)

model = nn.DataParallel(TempModel())
model.load_state_dict(torch.load('model/model_vgg2d_2.pth'))
dummy_input = Variable(torch.randn(1, 3, 224, 224))
torch.onnx.export(model, dummy_input, "model_onnx/model_vgg2d_0.onnx")

我正在使用用于训练模型的同一台机器(具有多个 GPU)。 任何想法我做错了什么?

【问题讨论】:

  • 您是否有可能尝试加载完全不同模型的 state_dict?您是否尝试在 TempModel 上强制使用 VGG 权重??
  • 哦,这就是问题所在,我需要生成完全相同的模型才能加载,谢谢! (我习惯了 TF,我只是加载了一个 pb 文件,而不管他的基本模型如何)

标签: python-3.x pytorch onnx


【解决方案1】:

加载state_dict 时,您需要它是相同 模型的state_dict:您不能将VGG 模型的state_dict 加载到完全不同的BasicModel 中。


旧答案
您保存了未将 nn.DataParallel 应用于模型的模型,现在您尝试在添加后加载。试试

model = TempModel()
model.load_state_dict(torch.load('model/model_vgg2d_2.pth'))
model = nn.DataParallel(model)  # parallel AFTER load

【讨论】:

  • 相同的结果 - 首先它给了我明显的“NameError: name 'model' is not defined”,当我这样做时:model = TempModel() model.load_state_dict(torch.load('model/model_vgg2d_2 .pth')) model = nn.DataParallel(TempModel()) 它从来没有到达第三行,因为它给了我同样的错误: KeyError: 'unexpected key "module.features.0.weight" in state_dict跨度>
猜你喜欢
  • 1970-01-01
  • 2019-07-29
  • 2022-10-19
  • 2019-05-23
  • 1970-01-01
  • 2021-02-21
  • 2017-10-29
  • 1970-01-01
  • 2020-09-08
相关资源
最近更新 更多