【发布时间】:2021-08-19 12:12:01
【问题描述】:
我正在尝试将我的 PyTorch 模型导出为 ONNX 格式,但我不断收到此错误:
TypeError: forward() 缺少 1 个必需的位置参数:'text'
这是我的代码:
model = Model(opt)
dummy_input = torch.randn(1, 3, 224, 224)
file_path='/content/drive/MyDrive/VitSTR/vitstr_tiny_patch16_224_aug.pth'
torch.save(model.state_dict(), file_path)
model.load_state_dict(torch.load(file_path))
#model = torch.nn.DataParallel(model).to(device)
#print(model)
torch.onnx.export(model, dummy_input, "vitstr.onnx", verbose=True)
【问题讨论】:
标签: python deep-learning pytorch onnx