【发布时间】:2019-05-23 14:08:03
【问题描述】:
在 pytorch 中,我通过以下方式开始反向传播(通过时间)来训练 RNN/GRU/LSTM 网络:
loss.backward()
当序列很长时,我想通过时间进行截断反向传播,而不是使用整个序列的正常时间反向传播。
但我在 Pytorch API 中找不到任何参数或函数来设置截断的 BPTT。我错过了吗?我应该自己在 Pytorch 中编写代码吗?
【问题讨论】:
-
只需在要剪切反向传播的位置使用
h = h.detach()。请参阅语言建模示例中的repackage_hidden()。它有效地进行截断。 -
谢谢。在这段代码中,哪个参数控制我想要 BPTT 的序列数?例如,代码中的序列长度 (args.bptt) 为 35,假设我希望 BPTT 仅在最后 5 个序列上完成。 5.使用什么参数。
标签: pytorch backpropagation truncated