【问题标题】:Unable to understand tf.nn.raw_rnn无法理解 tf.nn.raw_rnn
【发布时间】:2018-05-08 22:45:02
【问题描述】:
在tf.nn.raw_rnn 的official documentation 中,当loop_fn 第一次运行时,我们将发射结构作为loop_fn 的第三个输出。
稍后使用 emit_structure 将tf.zeros_like(emit_structure) 复制到由emit = tf.where(finished, tf.zeros_like(emit_structure), emit) 完成的小批量条目中。
我对谷歌缺乏理解或糟糕的文档是:发射结构是None 所以tf.where(finished, tf.zeros_like(emit_structure), emit) 会抛出一个ValueError,因为tf.zeros_like(None) 会这样做。有人可以在这里填写我所缺少的吗?
【问题讨论】:
标签:
python-3.x
tensorflow
recurrent-neural-network
rnn
tensorflow-slim
【解决方案1】:
是的,这个地方的文档相当混乱。如果你看tf.nn.raw_rnn的内部结构,关键词是“in pseudo-code”,所以文档中的例子并不准确。
确切的源代码如下所示(可能因您的 tensorflow 版本而异):
if emit_structure is not None:
flat_emit_structure = nest.flatten(emit_structure)
flat_emit_size = [emit.shape if emit.shape.is_fully_defined() else
array_ops.shape(emit) for emit in flat_emit_structure]
flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure]
else:
emit_structure = cell.output_size
flat_emit_size = nest.flatten(emit_structure)
flat_emit_dtypes = [flat_state[0].dtype] * len(flat_emit_size)
所以它处理emit_structure is None 时的情况,并简单地取值cell.output_size。这就是为什么没有什么东西真的坏了。