【Question Title】:Unable to load PyTorch model for evaluation
【Posted】:2021-10-08 02:37:11
【Question】:

I saved a .pth model and I'm trying to load it for inference with the following code:

import torch

model = GatherModel()
model.load_state_dict(torch.load('/content/CIGIN/weights/cigin.tar'))

I get the error shown below. Why am I getting it?

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-8-3bff0e426886> in <module>()
----> 1 model.load_state_dict(torch.load('/content/CIGIN/weights/cigin.tar'))

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1405         if len(error_msgs) > 0:
   1406             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1407                                self.__class__.__name__, "\n\t".join(error_msgs)))
   1408         return _IncompatibleKeys(missing_keys, unexpected_keys)
   1409 

RuntimeError: Error(s) in loading state_dict for GatherModel:
    Missing key(s) in state_dict: "lin0.weight", "lin0.bias", "set2set.lstm.weight_ih_l0", "set2set.lstm.weight_hh_l0", "set2set.lstm.bias_ih_l0", "set2set.lstm.bias_hh_l0", "message_layer.weight", "message_layer.bias", "conv.bias", "conv.edge_func.0.weight", "conv.edge_func.0.bias", "conv.edge_func.2.weight", "conv.edge_func.2.bias". 
    Unexpected key(s) in state_dict: "solute_pass.U_0.weight", "solute_pass.U_0.bias", "solute_pass.U_1.weight", "solute_pass.U_1.bias", "solute_pass.U_2.weight", "solute_pass.U_2.bias", "solute_pass.M_0.weight", "solute_pass.M_0.bias", "solute_pass.M_1.weight", "solute_pass.M_1.bias", "solute_pass.M_2.weight", "solute_pass.M_2.bias", "solvent_pass.U_0.weight", "solvent_pass.U_0.bias", "solvent_pass.U_1.weight", "solvent_pass.U_1.bias", "solvent_pass.U_2.weight", "solvent_pass.U_2.bias", "solvent_pass.M_0.weight", "solvent_pass.M_0.bias", "solvent_pass.M_1.weight", "solvent_pass.M_1.bias", "solvent_pass.M_2.weight", "solvent_pass.M_2.bias", "lstm_solute.weight_ih_l0", "lstm_solute.weight_hh_l0", "lstm_solute.bias_ih_l0", "lstm_solute.bias_hh_l0", "lstm_solvent.weight_ih_l0", "lstm_solvent.weight_hh_l0", "lstm_solvent.bias_ih_l0", "lstm_solvent.bias_hh_l0", "lstm_gather_solute.weight_ih_l0", "lstm_gather_solute.weight_hh_l0", "lstm_gather_solute.bias_ih_l0", "lstm_gather_solute.bias_hh_l0", "lstm_gather_solvent.weight_ih_l0", "lstm_gather_solvent.weight_hh_l0", "lstm_gather_solvent.bias_ih_l0", "lstm_gather_solvent.bias_hh_l0", "first_layer.weight", "first_layer.bias", "second_layer.weight", "second_layer.bias", "third_layer.weight", "third_layer.bias", "fourth_layer.weight", "fourth_layer.bias". 

I tried passing strict=False to load_state_dict, but got this instead:

_IncompatibleKeys(missing_keys=['lin0.weight', 'lin0.bias', 'set2set.lstm.weight_ih_l0', 'set2set.lstm.weight_hh_l0', 'set2set.lstm.bias_ih_l0', 'set2set.lstm.bias_hh_l0', 'message_layer.weight', 'message_layer.bias', 'conv.bias', 'conv.edge_func.0.weight', 'conv.edge_func.0.bias', 'conv.edge_func.2.weight', 'conv.edge_func.2.bias'], unexpected_keys=['solute_pass.U_0.weight', 'solute_pass.U_0.bias', 'solute_pass.U_1.weight', 'solute_pass.U_1.bias', 'solute_pass.U_2.weight', 'solute_pass.U_2.bias', 'solute_pass.M_0.weight', 'solute_pass.M_0.bias', 'solute_pass.M_1.weight', 'solute_pass.M_1.bias', 'solute_pass.M_2.weight', 'solute_pass.M_2.bias', 'solvent_pass.U_0.weight', 'solvent_pass.U_0.bias', 'solvent_pass.U_1.weight', 'solvent_pass.U_1.bias', 'solvent_pass.U_2.weight', 'solvent_pass.U_2.bias', 'solvent_pass.M_0.weight', 'solvent_pass.M_0.bias', 'solvent_pass.M_1.weight', 'solvent_pass.M_1.bias', 'solvent_pass.M_2.weight', 'solvent_pass.M_2.bias', 'lstm_solute.weight_ih_l0', 'lstm_solute.weight_hh_l0', 'lstm_solute.bias_ih_l0', 'lstm_solute.bias_hh_l0', 'lstm_solvent.weight_ih_l0', 'lstm_solvent.weight_hh_l0', 'lstm_solvent.bias_ih_l0', 'lstm_solvent.bias_hh_l0', 'lstm_gather_solute.weight_ih_l0', 'lstm_gather_solute.weight_hh_l0', 'lstm_gather_solute.bias_ih_l0', 'lstm_gather_solute.bias_hh_l0', 'lstm_gather_solvent.weight_ih_l0', 'lstm_gather_solvent.weight_hh_l0', 'lstm_gather_solvent.bias_ih_l0', 'lstm_gather_solvent.bias_hh_l0', 'first_layer.weight', 'first_layer.bias', 'second_layer.weight', 'second_layer.bias', 'third_layer.weight', 'third_layer.bias', 'fourth_layer.weight', 'fourth_layer.bias'])
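A note on strict=False: it does not fix the mismatch. load_state_dict simply stops raising and returns the _IncompatibleKeys named tuple shown above, and every key listed in missing_keys keeps whatever value it had before loading (for a freshly constructed model, its random initialization). The sketch below is a small, hypothetical helper (summarize_load and LoadSummary are illustrative names, not part of PyTorch) that makes explicit which model keys were actually loaded, so a partial load cannot slip through silently.

```python
from typing import Iterable, List, NamedTuple


class LoadSummary(NamedTuple):
    loaded: List[str]      # model keys that received a value from the checkpoint
    missing: List[str]     # model keys left at their pre-load (e.g. random) values
    unexpected: List[str]  # checkpoint keys that were ignored


def summarize_load(model_keys: Iterable[str],
                   missing_keys: Iterable[str],
                   unexpected_keys: Iterable[str]) -> LoadSummary:
    """Summarize the outcome of load_state_dict(..., strict=False).

    A model key counts as loaded if it is not reported as missing.
    """
    missing = list(missing_keys)
    missing_set = set(missing)
    loaded = [k for k in model_keys if k not in missing_set]
    return LoadSummary(loaded, missing, list(unexpected_keys))


# Hypothetical usage with the model from the question:
# result = model.load_state_dict(torch.load(path, map_location="cpu"), strict=False)
# summary = summarize_load(model.state_dict().keys(),
#                          result.missing_keys, result.unexpected_keys)
# if summary.missing:
#     raise RuntimeError(f"{len(summary.missing)} keys were not loaded")
```

If summary.loaded comes back empty, the model is still entirely at its initial weights and any evaluation results would be meaningless.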

【Discussion】:

    Tags: python pytorch inference


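    【Answer 1】:

    The error means two things: the architecture you are instantiating defines weights that are not in the state_dict, and the state_dict contains weights that your architecture does not define. Are you sure the architecture defined by GatherModel() is the same one that was used to create the state_dict? This error indicates that it is not.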
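One way to check this is to compare the top-level submodule names (the part of each key before the first dot) in the model and in the checkpoint. In the error above, the missing keys start with lin0, set2set, message_layer and conv, while the unexpected keys start with solute_pass, solvent_pass, lstm_solute, first_layer and so on; the two groups share no names, which suggests the checkpoint was saved from a different model class. The helpers below are a minimal, hypothetical sketch of that comparison (top_level_modules and compare_architectures are illustrative names, not PyTorch APIs); they work on any iterable of key strings.

```python
from collections import Counter
from typing import Dict, Iterable, List, Tuple


def top_level_modules(keys: Iterable[str]) -> Dict[str, int]:
    """Count state_dict keys per top-level submodule name.

    For example, "set2set.lstm.weight_ih_l0" is counted under "set2set".
    """
    return dict(Counter(key.split(".", 1)[0] for key in keys))


def compare_architectures(model_keys: Iterable[str],
                          ckpt_keys: Iterable[str]) -> Tuple[List[str], List[str]]:
    """Return (submodules only in the model, submodules only in the checkpoint)."""
    model_mods = set(top_level_modules(model_keys))
    ckpt_mods = set(top_level_modules(ckpt_keys))
    return sorted(model_mods - ckpt_mods), sorted(ckpt_mods - model_mods)


# Hypothetical usage with the objects from the question:
# ckpt = torch.load("/content/CIGIN/weights/cigin.tar", map_location="cpu")
# only_model, only_ckpt = compare_architectures(model.state_dict().keys(), ckpt.keys())
# print("only in GatherModel:", only_model)
# print("only in checkpoint: ", only_ckpt)
```

If nothing is shared, the fix is to instantiate the class whose attributes match the checkpoint's key names before calling load_state_dict, rather than loading with strict=False.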

    【Discussion】:
