【发布时间】:2019-09-24 14:56:51
【问题描述】:
import tensorflow as tf
from tf.contrib import rnn
lstm_f = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
lstm_b = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
blstm_out, state_f, state_b = rnn.static_bidirectional_rnn(lstm_f, lstm_b, x, dtype=tf.float32)
上面的代码适用于 tensorflow 1.x,但是我觉得很难找到使用 tensorflow 2.0 API 重写此代码的方法。
我知道我应该从 tf.keras.layers.LSTMCell() 开始,但我不知道什么 API 函数适合 2 个 LSTMCell 实例作为输入。
【问题讨论】:
标签: tensorflow deep-learning lstm recurrent-neural-network tensorflow2.0