【Question title】: batch_first in PyTorch LSTM
【Posted】: 2021-11-16 11:11:04
【Question】:

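I'm new to this field, so I still don't understand batch_first in PyTorch's LSTM. I tried some code that was suggested to me. With batch_first = False it works on my training data: the official LSTM and my manual LSTM produce the same output. When I switch to batch_first = True, however, the values no longer match, and I need batch_first = True because my dataset is a tensor of shape (batch, sequence, input size). Which part of the manual LSTM has to change so that it produces the same output as the official LSTM when batch_first = True? Here is the code snippet: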

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

train_x = torch.tensor([[[0.14285755], [0], [0.04761982], [0.04761982], [0.04761982],
  [0.04761982], [0.04761982], [0.09523869], [0.09523869], [0.09523869], 
  [0.09523869], [0.09523869], [0.04761982], [0.04761982], [0.04761982],
  [0.04761982], [0.09523869], [0.        ], [0.        ], [0.        ],
  [0.        ], [0.09523869], [0.09523869], [0.09523869], [0.09523869],
  [0.09523869], [0.09523869], [0.09523869],[0.14285755], [0.14285755]]], 
  requires_grad=True)

seed = 23
torch.manual_seed(seed)
np.random.seed(seed)

pytorch_lstm = torch.nn.LSTM(1, 1, bidirectional=False, num_layers=1, batch_first=True)
weights = torch.randn(pytorch_lstm.weight_ih_l0.shape,dtype = torch.float)
pytorch_lstm.weight_ih_l0 = torch.nn.Parameter(weights)
# Set the input-hidden bias to zero
pytorch_lstm.bias_ih_l0 = torch.nn.Parameter(torch.zeros(pytorch_lstm.bias_ih_l0.shape))
# Set every hidden-hidden weight to one
pytorch_lstm.weight_hh_l0 = torch.nn.Parameter(torch.ones(pytorch_lstm.weight_hh_l0.shape))
# Set the hidden-hidden bias to zero
pytorch_lstm.bias_hh_l0 = torch.nn.Parameter(torch.zeros(pytorch_lstm.bias_hh_l0.shape))
pytorch_lstm_out = pytorch_lstm(train_x)

batch_size=1

# Manual Calculation
W_ii, W_if, W_ig, W_io = pytorch_lstm.weight_ih_l0.split(1, dim=0)
b_ii, b_if, b_ig, b_io = pytorch_lstm.bias_ih_l0.split(1, dim=0)

W_hi, W_hf, W_hg, W_ho = pytorch_lstm.weight_hh_l0.split(1, dim=0)
b_hi, b_hf, b_hg, b_ho = pytorch_lstm.bias_hh_l0.split(1, dim=0)

prev_h = torch.zeros((batch_size,1))
prev_c = torch.zeros((batch_size,1))

i_t = torch.sigmoid(F.linear(train_x, W_ii, b_ii) + F.linear(prev_h, W_hi, b_hi))
f_t = torch.sigmoid(F.linear(train_x, W_if, b_if) + F.linear(prev_h, W_hf, b_hf))
g_t = torch.tanh(F.linear(train_x, W_ig, b_ig) + F.linear(prev_h, W_hg, b_hg))
o_t = torch.sigmoid(F.linear(train_x, W_io, b_io) + F.linear(prev_h, W_ho, b_ho))
c_t = f_t * prev_c + i_t * g_t
h_t = o_t * torch.tanh(c_t)

print('nn.LSTM output {}, manual output {}'.format(pytorch_lstm_out[0], h_t))
print('nn.LSTM hidden {}, manual hidden {}'.format(pytorch_lstm_out[1][0], h_t))
print('nn.LSTM state {}, manual state {}'.format(pytorch_lstm_out[1][1], c_t))

【Question comments】:

    Tags: python pytorch lstm


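    【Answer 1】:

    You have to iterate over the sequence one element at a time and feed the computed hidden state and cell state into the next time step. With batch_first = False, the (1, 30, 1) tensor is read as a sequence of length 1 with a batch of 30, which is why a single step happened to match; with batch_first = True it is one sequence of 30 steps, so the recurrence has to be unrolled...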

    # Initial hidden and cell state, shape (batch, hidden_size)
    h_t = torch.zeros((batch_size,1))
    c_t = torch.zeros((batch_size,1))
    
    hidden_seq = []
    
    # With batch_first=True the time axis is dim 1, so step along train_x.size(1)
    for t in range(train_x.size(1)):
      x_t = train_x[:, t, :]  # input at time step t, shape (batch, input_size)
      i_t = torch.sigmoid(F.linear(x_t, W_ii, b_ii) + F.linear(h_t, W_hi, b_hi))
      f_t = torch.sigmoid(F.linear(x_t, W_if, b_if) + F.linear(h_t, W_hf, b_hf))
      g_t = torch.tanh(F.linear(x_t, W_ig, b_ig) + F.linear(h_t, W_hg, b_hg))
      o_t = torch.sigmoid(F.linear(x_t, W_io, b_io) + F.linear(h_t, W_ho, b_ho))
      # Carry the updated cell and hidden state into the next step
      c_t = f_t * c_t + i_t * g_t
      h_t = o_t * torch.tanh(c_t)
      hidden_seq.append(h_t.unsqueeze(0))
    
    # (seq, batch, hidden) -> (batch, seq, hidden), matching the batch_first=True output
    hidden_seq = torch.cat(hidden_seq, dim=0)
    hidden_seq = hidden_seq.transpose(0, 1).contiguous()
    
    print('nn.LSTM output {}, manual output {}'.format(pytorch_lstm_out[0], hidden_seq))
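
For anyone who wants to check the loop above without the question's setup, here is a self-contained sketch of the same idea. The helper name `manual_lstm` and the sizes (batch 2, 5 time steps, input size 3, hidden size 4) are assumptions made for illustration and are not part of the original post. It assumes a single-layer, unidirectional `nn.LSTM` with zero initial states. The only thing `batch_first` changes in the manual version is which axis is treated as time.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def manual_lstm(x, lstm, batch_first=True):
    """Hypothetical helper: replay a 1-layer, unidirectional nn.LSTM step by step."""
    # nn.LSTM stacks the gate parameters in the order input, forget, cell, output.
    W_ii, W_if, W_ig, W_io = lstm.weight_ih_l0.chunk(4, dim=0)
    W_hi, W_hf, W_hg, W_ho = lstm.weight_hh_l0.chunk(4, dim=0)
    b_ii, b_if, b_ig, b_io = lstm.bias_ih_l0.chunk(4, dim=0)
    b_hi, b_hf, b_hg, b_ho = lstm.bias_hh_l0.chunk(4, dim=0)

    # batch_first only decides which axis holds the time steps.
    time_dim = 1 if batch_first else 0
    batch_dim = 0 if batch_first else 1

    # Zero initial states, shape (batch, hidden_size).
    h_t = torch.zeros(x.size(batch_dim), lstm.hidden_size)
    c_t = torch.zeros(x.size(batch_dim), lstm.hidden_size)

    outputs = []
    for x_t in x.unbind(dim=time_dim):  # each x_t has shape (batch, input_size)
        i_t = torch.sigmoid(F.linear(x_t, W_ii, b_ii) + F.linear(h_t, W_hi, b_hi))
        f_t = torch.sigmoid(F.linear(x_t, W_if, b_if) + F.linear(h_t, W_hf, b_hf))
        g_t = torch.tanh(F.linear(x_t, W_ig, b_ig) + F.linear(h_t, W_hg, b_hg))
        o_t = torch.sigmoid(F.linear(x_t, W_io, b_io) + F.linear(h_t, W_ho, b_ho))
        c_t = f_t * c_t + i_t * g_t
        h_t = o_t * torch.tanh(c_t)
        outputs.append(h_t)

    # Put the time axis back where the input had it.
    return torch.stack(outputs, dim=time_dim), h_t, c_t


torch.manual_seed(23)
lstm = nn.LSTM(input_size=3, hidden_size=4, num_layers=1, batch_first=True)
x = torch.rand(2, 5, 3)  # (batch, seq, input_size)

with torch.no_grad():
    ref_out, (ref_h, ref_c) = lstm(x)
    out, h, c = manual_lstm(x, lstm, batch_first=True)

# h_n and c_n keep the shape (num_layers, batch, hidden) even when batch_first=True.
print(torch.allclose(out, ref_out, atol=1e-5))
print(torch.allclose(h, ref_h[0], atol=1e-5))
print(torch.allclose(c, ref_c[0], atol=1e-5))
```

Passing `x.transpose(0, 1)` together with `batch_first=False` should give the same values with the first two output axes swapped, since the per-step arithmetic is identical in both layouts.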
    

    【Comments】:
