【问题标题】:issue while exporting torch model to onnx format将火炬模型导出为 onnx 格式时出现问题
【发布时间】:2021-08-19 12:12:01
【问题描述】:

我正在尝试将我的 PyTorch 模型导出为 ONNX 格式,但我不断收到此错误:

TypeError: forward() 缺少 1 个必需的位置参数:'text'

这是我的代码:

model = Model(opt)
dummy_input = torch.randn(1, 3, 224, 224)
file_path='/content/drive/MyDrive/VitSTR/vitstr_tiny_patch16_224_aug.pth'
torch.save(model.state_dict(), file_path)
model.load_state_dict(torch.load(file_path))
#model = torch.nn.DataParallel(model).to(device)
#print(model)
torch.onnx.export(model, dummy_input, "vitstr.onnx", verbose=True)

【问题讨论】:

    标签: python deep-learning pytorch onnx


    【解决方案1】:

    ViTSTR forward 需要两个位置参数,inputtext

    def forward(self, input, text, is_train=True, seqlen=25):
        # ...
    

    因此,你需要传递一个额外的参数:

    # ...
    dummy_text = # create a dummy_text as well, with the appropriate shape
    torch.onnx.export(model, (dummy_input, dummy_text), "vitstr.onnx", verbose=True)
    

    【讨论】:

    • @RouaRouatbi 这不是 SO 的工作方式:) 这是一个不同的问题。如果对您有帮助,请考虑将其标记为答案或投票。您的新问题很容易解决,但它应该是一个不同的问题。请随时在此处发布指向您的新问题的链接。此外,请始终发布完整的回溯。
    猜你喜欢
    • 2022-08-04
    • 2018-11-05
    • 2020-11-23
    • 2022-01-09
    • 2019-08-06
    • 2018-09-02
    • 2017-07-20
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多