【问题标题】:How to enable Dict/OrderedDict/NamedTuple support in pytorch 1.1.0 JIT compiler?如何在 pytorch 1.1.0 JIT 编译器中启用 Dict/OrderedDict/NamedTuple 支持?
【发布时间】:2019-11-12 14:03:30
【问题描述】:

来自 pytorch 1.1.0 的发布亮点。似乎最新的 JIT 编译器现在支持 Dict 类型。 (来源:https://jaxenter.com/pytorch-1-1-158332.html

TorchScript 中的字典和列表支持:列表和字典类型的行为类似于 Python 列表和字典。

很遗憾,我找不到使这项改进正常工作的方法。以下代码是将特征金字塔网络 (FPN) 导出到 tensorboard 的简单示例,它使用 JIT 编译器:

from collections import OrderedDict

import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter

torchWriter = SummaryWriter(log_dir=".tensorboard/example1")

m = torchvision.ops.FeaturePyramidNetwork([10, 20, 30], 5)
# get some dummy data
x = OrderedDict()
x['feat0'] = torch.rand(1, 10, 64, 64)
x['feat2'] = torch.rand(1, 20, 16, 16)
x['feat3'] = torch.rand(1, 30, 8, 8)
# compute the FPN on top of x
output = m.forward(x)
print([(k, v.shape) for k, v in output.items()])

torchWriter.add_graph(m, input_to_model=x)

当我运行它时,我得到了以下错误:

Traceback (most recent call last):
  File "/home/shared/virtualenv/dl-torch/lib/python3.7/site-packages/torch/utils/tensorboard/_pytorch_graph.py", line 276, in graph
    trace, _ = torch.jit.get_trace_graph(model, args)
  File "/home/shared/virtualenv/dl-torch/lib/python3.7/site-packages/torch/jit/__init__.py", line 231, in get_trace_graph
    return LegacyTracedModule(f, _force_outplace, return_inputs)(*args, **kwargs)
  File "/home/shared/virtualenv/dl-torch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/shared/virtualenv/dl-torch/lib/python3.7/site-packages/torch/jit/__init__.py", line 284, in forward
    in_vars, in_desc = _flatten(args)
RuntimeError: Only tuples, lists and Variables supported as JIT inputs, but got collections.OrderedDict

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/peng/git-drone/gate_detection/python/gate_detection/errorcase/tb.py", line 36, in <module>
    torchWriter.add_graph(m, input_to_model=x)
  File "/home/shared/virtualenv/dl-torch/lib/python3.7/site-packages/torch/utils/tensorboard/writer.py", line 534, in add_graph
    self._get_file_writer().add_graph(graph(model, input_to_model, verbose, **kwargs))
  File "/home/shared/virtualenv/dl-torch/lib/python3.7/site-packages/torch/utils/tensorboard/_pytorch_graph.py", line 279, in graph
    _ = model(*args)  # don't catch, just print the error message
  File "/home/shared/virtualenv/dl-torch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
TypeError: forward() takes 2 positional arguments but 4 were given

从错误消息看来,支持仍在等待中。我可以相信发布亮点吗?还是我没有正确使用 API?

【问题讨论】:

    标签: pytorch tensorboard torchscript


    【解决方案1】:

    发行说明是准确的,尽管有点模糊。该链接中描述的字典/列表/用户定义的类支持(和official release notes)仅适用于TorchScript compiler(发行说明中有一些代码示例),但SummaryWriter 默认情况下将运行TorchScript 跟踪器在您传递给它的任何模块上,跟踪器仅支持张量和张量的列表/元组。

    所以解决方法是使用 TorchScript 编译器而不是跟踪器,但这需要:

    1. 获取原始代码
    2. 支持 Tensorboard 中的编译输出 (ScriptModule)

    您应该 file an issue 处理 (2),并且有 ongoing work 来修复 (1),但这在短期内对于该模型 afaik 不起作用。

    【讨论】:

    • 非常感谢您的澄清!只是为了确定:你的意思是 TorchScript 跟踪器在调试时不支持它自己的编译输出(ScriptModel)?这应该是跟踪器的唯一功能吗?
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2018-04-25
    • 2018-11-10
    • 2011-07-17
    • 2014-05-12
    • 1970-01-01
    相关资源
    最近更新 更多