【发布时间】:2020-06-15 02:44:16
【问题描述】:
我有一个PyTorch LSTM 模型,我的forward 函数看起来像:
def forward(self, x, hidden):
print('in forward', x.dtype, hidden[0].dtype, hidden[1].dtype)
lstm_out, hidden = self.lstm(x, hidden)
return lstm_out, hidden
所有print 语句都显示torch.float64,我认为这是一个双倍。那我为什么会遇到这个问题呢?
我已经在所有相关的地方投了double。
【问题讨论】: