[Posted]: 2023-12-07 00:43:02
[Problem description]:
I have the following code:
import torch
import torch.nn as nn
model = nn.Sequential(
    nn.LSTM(300, 300),
    nn.Linear(300, 100),
    nn.ReLU(),
    nn.Linear(300, 7),
)
s = torch.ones(1, 50, 300)
a = model(s)
I get:
My-MBP:Desktop myname$ python3 testmodel.py
Traceback (most recent call last):
File "testmodel.py", line 12, in <module>
a = model(s)
File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torch/nn/modules/container.py", line 117, in forward
input = module(input)
File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 93, in forward
return F.linear(input, self.weight, self.bias)
File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torch/nn/functional.py", line 1688, in linear
if input.dim() == 2 and bias is not None:
AttributeError: 'tuple' object has no attribute 'dim'
Why? The dimensions should be fine. I have seen related fixes for this error in cases where *input is defined in a custom model.forward, but I haven't even implemented a forward of my own yet.
/edit: wait, there IS a *input!? How can I override it?
[Question discussion]:
- This error occurs because an LSTM returns a tuple containing the output, the hidden state, and the cell state. You cannot pass that tuple to a Linear layer; pass only the LSTM's output, not the hidden and cell states.
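A minimal sketch of that fix: inside nn.Sequential, the LSTM's (output, (h_n, c_n)) tuple is passed straight to the next module, so a small wrapper can keep only the output tensor. The ExtractLSTMOutput name below is made up for illustration, and the last Linear's in_features is changed from 300 to 100 so it matches the previous layer's output size:

import torch
import torch.nn as nn

class ExtractLSTMOutput(nn.Module):
    # Hypothetical helper: keeps only the output tensor from the
    # (output, (h_n, c_n)) tuple that nn.LSTM returns.
    def forward(self, x):
        output, _ = x
        return output

model = nn.Sequential(
    nn.LSTM(300, 300),
    ExtractLSTMOutput(),   # discard hidden and cell state before the Linear layers
    nn.Linear(300, 100),
    nn.ReLU(),
    nn.Linear(100, 7),     # in_features adjusted to match the previous layer's out_features
)

s = torch.ones(1, 50, 300)
a = model(s)
print(a.shape)             # torch.Size([1, 50, 7])

With the default batch_first=False, the input (1, 50, 300) is read as (seq_len=1, batch=50, features=300), so the final output here has shape (1, 50, 7).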