【问题标题】:upgrade code rnn.static_bidirectional_rnn to fit with tensorflow 2.0 API升级代码 rnn.static_bidirectional_rnn 以适应 tensorflow 2.0 API
【发布时间】:2019-09-24 14:56:51
【问题描述】:
import tensorflow as tf
from tf.contrib import rnn
lstm_f = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
lstm_b = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
blstm_out, state_f, state_b = rnn.static_bidirectional_rnn(lstm_f, lstm_b, x, dtype=tf.float32)

上面的代码适用于 tensorflow 1.x,但是我觉得很难找到使用 tensorflow 2.0 API 重写此代码的方法。

我知道我应该从 tf.keras.layers.LSTMCell() 开始,但我不知道什么 API 函数适合 2 个 LSTMCell 实例作为输入。

【问题讨论】:

    标签: tensorflow deep-learning lstm recurrent-neural-network tensorflow2.0


    【解决方案1】:

    相当于您的 sn-p 的 Keras 将是

    lstm = keras.layers.LSTM(n_hidden, unit_forget_bias=True, unroll=True)
    keras.layers.Bidirectional(lstm)
    

    请注意,虽然 Keras 有一个 LSTMCell 的实现,但您可能希望使用 LSTM 来代替,它不仅仅是一个单元格,而是一个完全展开的 RNN,一次对整个序列进行操作。默认情况下,RNN 通过 while 循环动态展开,我们通过传递 unroll=True 强制它是静态的(在 TF 1.X 术语中)。最后,keras.layers.Bidirectional 包装器使 RNN 成为双向的。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2018-05-11
      • 2020-10-28
      • 2019-05-24
      • 1970-01-01
      • 1970-01-01
      • 2020-10-28
      • 1970-01-01
      • 2018-11-11
      相关资源
      最近更新 更多