【Title】: PyTorch - TypeError: forward() takes 1 positional argument but 2 were given
【Posted】: 2020-12-01 18:47:49
【Question】:

I don't understand where this error is coming from; the number of parameters passed to the model seems correct. Here is my model:

class MancalaModel(nn.Module):

    def __init__(self, n_inputs=16, n_outputs=16):
        super().__init__()

        n_neurons = 256

        def create_block(n_in, n_out):
            block = nn.ModuleList()
            block.append(nn.Linear(n_in, n_out))
            block.append(nn.ReLU())
            return block

        self.blocks = nn.ModuleList()
        self.blocks.append(create_block(n_inputs, n_neurons))
        for _ in range(6):
            self.blocks.append(create_block(n_neurons, n_neurons))

        self.actor_block = nn.ModuleList()
        self.critic_block = nn.ModuleList()
        for _ in range(2):
            self.actor_block.append(create_block(n_neurons, n_neurons))
            self.critic_block.append(create_block(n_neurons, n_neurons))

        self.actor_block.append(create_block(n_neurons, n_outputs))
        self.critic_block.append(create_block(n_neurons, 1))

        self.apply(init_weights)

    def forward(self, x):
        x = self.blocks(x)
        actor = F.softmax(self.actor_block(x))
        critics = self.critic_block(x)
        return actor, critics

I then create an instance and do a forward pass with random numbers:

model = MancalaModel()
x = model(torch.rand(1, 16))

I then get a TypeError saying the number of arguments is incorrect:

      2 model = MancalaModel()
----> 3 x = model(torch.rand(1, 16))
      4 # summary(model, (16,), device='cpu')
      5 

d:\environments\python\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

D:\UOM\Year3\AI & Games\KalahPlayer\agents\model_agent.py in forward(self, x)
     54 
     55     def forward(self, x):
---> 56         x = self.blocks(x)
     57         actor = F.softmax(self.actor_block(x))
     58         critics = self.critic_block(x)

d:\environments\python\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

TypeError: forward() takes 1 positional argument but 2 were given

Any help is appreciated, thanks!

【Question discussion】:

    标签: python pytorch


【Solution 1】:

TL;DR
You are trying to call forward on an nn.ModuleList - that is not defined. You need to convert self.blocks to an nn.Sequential:

            def create_block(n_in, n_out):
            # do not use ModuleList here either.
                block = nn.Sequential(
                  nn.Linear(n_in, n_out),
                  nn.ReLU()
                )
                return block
    
        blocks = []  # plain Python list - not a member of self, for temporary use only.
            blocks.append(create_block(n_inputs, n_neurons))
            for _ in range(6):
                blocks.append(create_block(n_neurons, n_neurons))
            self.blocks = nn.Sequential(*blocks)  # convert the simple list to nn.Sequential
    

    I would have expected you to get a NotImplementedError, rather than this TypeError, because your self.blocks is of type nn.ModuleList and its forward method raises NotImplementedError. I just created a pull request to address this confusing issue.
    Update (April 22nd, 2021): the PR was merged. In future versions you should see a NotImplementedError when calling an nn.ModuleList or nn.ModuleDict.
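For completeness, here is a minimal runnable sketch of the fully corrected model. This is my own reconstruction, not code from the question: the same nn.Sequential fix is applied to the actor and critic heads as well, the undefined init_weights helper from the question is omitted, and dim is passed explicitly to F.softmax to avoid the deprecation warning:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MancalaModel(nn.Module):
    def __init__(self, n_inputs=16, n_outputs=16):
        super().__init__()
        n_neurons = 256

        def create_block(n_in, n_out):
            # nn.Sequential is callable; nn.ModuleList is only a container
            return nn.Sequential(nn.Linear(n_in, n_out), nn.ReLU())

        blocks = [create_block(n_inputs, n_neurons)]
        for _ in range(6):
            blocks.append(create_block(n_neurons, n_neurons))
        self.blocks = nn.Sequential(*blocks)

        # the actor and critic heads need the same Sequential treatment
        self.actor_block = nn.Sequential(
            *[create_block(n_neurons, n_neurons) for _ in range(2)],
            create_block(n_neurons, n_outputs),
        )
        self.critic_block = nn.Sequential(
            *[create_block(n_neurons, n_neurons) for _ in range(2)],
            create_block(n_neurons, 1),
        )

    def forward(self, x):
        x = self.blocks(x)
        # pass dim explicitly; implicit dim is deprecated
        actor = F.softmax(self.actor_block(x), dim=-1)
        critic = self.critic_block(x)
        return actor, critic

model = MancalaModel()
actor, critic = model(torch.rand(1, 16))
# actor has shape (1, 16) and sums to 1; critic has shape (1, 1)
```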

    【Comments】:

    • I can confirm this issue exists in torch 1.7.1.
    • @RakshitKothari The status of this pull request is "stuck" - it conflicts with an unrelated torchscript issue. Feel free to comment there to get the pytorch developers' attention.
    • @RakshitKothari This PR was merged on April 22nd. With newer pytorch versions you should get a NotImplementedError; I confirmed this with version 1.9.0.
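As an aside (my own sketch, not part of the original answer): if you want to keep nn.ModuleList, for example to index individual layers later, the alternative is to iterate over it explicitly in forward instead of calling the list itself:

```python
import torch
import torch.nn as nn

class Stack(nn.Module):
    """Keeps an nn.ModuleList so layers register properly,
    but applies them one by one in forward."""

    def __init__(self, n_features=16, n_layers=3):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.Linear(n_features, n_features) for _ in range(n_layers)
        )

    def forward(self, x):
        # ModuleList has no forward of its own; loop over its members
        for layer in self.layers:
            x = layer(x)
        return x

out = Stack()(torch.rand(1, 16))
# out has shape (1, 16)
```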