【问题标题】:Converting Pytorch model .pth into onnx model将 Pytorch 模型 .pth 转换为 onnx 模型
【发布时间】:2018-10-05 00:49:01
【问题描述】:

我有一个预训练模型,格式为 .pth 扩展名。我想把它转换成 Tensorflow protobuf。但我没有找到任何方法来做到这一点。我已经看到 onnx 可以将模型从 pytorch 转换为 onnx,然后从 onnx 转换为 Tensorflow。但是使用这种方法,我在转换的第一阶段遇到了以下错误。

from torch.autograd import Variable
import torch.onnx
import torchvision
import torch 

dummy_input = Variable(torch.randn(1, 3, 256, 256))
model = torch.load('./my_model.pth')
torch.onnx.export(model, dummy_input, "moment-in-time.onnx")`

它给出了这样的错误。

File "t.py", line 9, in <module>
    torch.onnx.export(model, dummy_input, "moment-in-time.onnx")
  File "/usr/local/lib/python3.5/dist-packages/torch/onnx/__init__.py", line 75, in export
    _export(model, args, f, export_params, verbose, training)
  File "/usr/local/lib/python3.5/dist-packages/torch/onnx/__init__.py", line 108, in _export
    orig_state_dict_keys = model.state_dict().keys()
AttributeError: 'dict' object has no attribute 'state_dict'

什么是可能的解决方案?

【问题讨论】:

  • 您的.pth 文件是状态字典,而不是完整的模型。您首先需要创建一个模型,然后加载该状态字典,然后开始您的转换过程。检查this answer
  • 其中显示的方法需要编写模型。但我有预训练模型,我不知道它的确切架构。所以我不能像在那个答案中那样定义模型。我该怎么办?
  • 那么就很难确定架构了。您可以通过查看参数大小来猜测架构,但是即使在查看大小之后猜测正确的架构也非常困难,因为残差网络将具有与非残差网络相同大小的参数。最好的办法是从预训练的权重源中获取架构定义
  • 好的。让我们看看我能不能得到它。谢谢你的帮助。如果我有 .pth.tar 文件,那么这个过程也会相同还是改变?

标签: python tensorflow deep-learning pytorch


【解决方案1】:

尝试将您的代码更改为此

from torch.autograd import Variable

import torch.onnx
import torchvision
import torch

dummy_input = Variable(torch.randn(1, 3, 256, 256))
state_dict = torch.load('./my_model.pth')
model.load_state_dict(state_dict)
torch.onnx.export(model, dummy_input, "moment-in-time.onnx")

【讨论】:

    【解决方案2】:

    这里的问题是您正在加载模型的权重,但您也需要模型的架构,例如,如果您使用的是 mobilenet:

    import torch
    import torchvision.models as models
    
    model=models.mobilenet_v3_large(weights)#Give your weights here
    torch.onnx.export(model, torch.rand(1,3,640,640), "MobilenetV3.onnx")
    

    更多信息请参考:https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_models_for_inference.html

    【讨论】:

      【解决方案3】:

      这意味着您的模型不是 torch.nn.Modules 类的子类。如果您将其设为子类,则应该可以。

      【讨论】:

      • 怎么做?请提供一些指导。
      • 定义类时,将标题编辑为:class ModelName(nn.Module)
      猜你喜欢
      • 2021-10-20
      • 2022-10-13
      • 2023-01-31
      • 2018-11-25
      • 2019-11-10
      • 1970-01-01
      • 2020-03-31
      • 2020-02-12
      • 2019-12-09
      相关资源
      最近更新 更多