【问题标题】:Truncated Backpropagation Through Time (BPTT) in PytorchPytorch 中的随时间截断反向传播 (BPTT)
【发布时间】:2019-05-23 14:08:03
【问题描述】:

在 pytorch 中,我通过以下方式开始反向传播(通过时间)来训练 RNN/GRU/LSTM 网络:

loss.backward()

当序列很长时,我想通过时间进行截断反向传播,而不是使用整个序列的正常时间反向传播。

但我在 Pytorch API 中找不到任何参数或函数来设置截断的 BPTT。我错过了吗?我应该自己在 Pytorch 中编写代码吗?

【问题讨论】:

  • 只需在要剪切反向传播的位置使用h = h.detach()。请参阅语言建模示例中的repackage_hidden()。它有效地进行截断。
  • 谢谢。在这段代码中,哪个参数控制我想要 BPTT 的序列数?例如,代码中的序列长度 (args.bptt) 为 35,假设我希望 BPTT 仅在最后 5 个序列上完成。 5.使用什么参数。

标签: pytorch backpropagation truncated


【解决方案1】:

这是一个例子:

for t in range(T):
   y = lstm(y)
   if T-t == k:
      out.detach()
out.backward()

所以在本例中,k 是您用来控制要展开的时间步长的参数。

【讨论】:

  • 您是否也应该在 if 子句中添加out.backward()
  • 是的,如果您想在每个有效时间步进行反向传播。所示示例仅在最后一个有效时间步上进行更新。
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 2020-11-04
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2021-11-20
  • 2017-01-20
  • 2011-08-25
相关资源
最近更新 更多