【问题标题】:RuntimeError: Expected object of scalar type Double but got scalar type Float for argument #2RuntimeError: 标量类型 Double 的预期对象,但参数 #2 得到标量类型 Float
【发布时间】:2020-06-15 02:44:16
【问题描述】:

我有一个PyTorch LSTM 模型,我的forward 函数看起来像:

    def forward(self, x, hidden):
        print('in forward', x.dtype, hidden[0].dtype, hidden[1].dtype)
        lstm_out, hidden = self.lstm(x, hidden)
        return lstm_out, hidden

所有print 语句都显示torch.float64,我认为这是一个双倍。那我为什么会遇到这个问题呢?

我已经在所有相关的地方投了double

【问题讨论】:

    标签: python pytorch dtype


    【解决方案1】:

    确保你的数据和模型都在dtypedouble

    对于模型:

    net = net.double()
    

    对于数据:

    net(x.double())
    

    一直是discussed on PyTorch forum

    【讨论】:

      猜你喜欢
      • 2020-05-31
      • 2020-12-02
      • 2019-11-06
      • 2021-07-10
      • 1970-01-01
      • 2020-10-21
      • 2020-01-29
      • 2021-08-08
      • 2021-09-16
      相关资源
      最近更新 更多