【发布时间】:2021-10-20 01:29:34
【问题描述】:
我正在尝试将 PyTorch 模型(包含权重的 pth 文件)转换为 onnx 文件,然后再转换为 TensorFlow 模型,因为我在 TensorFlow 上工作。然后对其进行微调。 这是我迄今为止的尝试。然而,我不断收到错误。enter image description here 我认为问题在于权重是针对视觉转换器的。但是我还没有弄清楚要使用什么类型的模型来转换它。我假设是 CRNN,但如果有更简单的方法,我很想知道。 PS:我确实将 pth 文件加载到了我的驱动器中。路径是正确的
from torch.autograd import Variable
import torch.onnx
import torchvision
import torch
import onnx
import torch.nn as nn
dummy_input = torch.randn(1, 3, 224, 224)
file_path='/content/drive/MyDrive/VitSTR/vitstr_base_patch16_224_aug.pth'
model = torchvision.models.vgg16()
model.load_state_dict(torch.load(file_path))
model.eval()
torch.onnx.export(model, dummy_input, "vitstr.onnx")
【问题讨论】:
-
你得到 Roua 的错误是什么?请添加信息以便人们可以帮助您,否则您的问题可能会被关闭。
-
将整个错误堆栈添加到您的描述中。
-
您是否尝试将
Vision Transformer模型的权重和偏差加载到VGG16模型中? -
我已经更新了描述。非常感谢,希望现在清楚了
-
@Kishore 可能是正确的:将视觉转换器的参数加载到新初始化的 VGG 模型中是没有意义的。您看到的错误消息是告诉您参数名称不同。
标签: tensorflow machine-learning deep-learning pytorch onnx