【发布时间】:2017-05-21 14:49:26
【问题描述】:
我正在使用 TensorFlow 的 python API 来训练 LSTM 的一个变体。
为此,我使用tf.while_loop 函数迭代时间步长。
在 cpu 上运行我的脚本时,它不会产生任何错误消息,但在 gpu 上 python 崩溃由于:
...tensorflow/tensorflow/core/framework/tensor.cc:885] Check failed: nullptr != b.buf_ (nullptr vs. 00...)
导致此失败的代码部分(注释掉它时,它可以工作)位于 while 循环的主体中:
...
h_gathered = h_ta.gather(tf.range(time))
h_gathered = tf.transpose(h_gathered, [1, 0, 2])
syn_t = self.syntactic_weights_ta.read(time)[:, :time]
syn_t = tf.expand_dims(syn_t, 1)
syn_state_t = tf.squeeze(tf.tanh(tf.matmul(syn_t, h_gathered)), 1)
...
其中time 是从零开始并在每一步之后递增,h_ta 是一个 TensorArray
h_ta = tf.TensorArray(
dtype=dtype,
size=max_seq_len,
clear_after_read=False,
element_shape=[batch_size, num_hidden],
tensor_array_name="fw_output")
而self.syntactic_weights_ta 也是一个 TensorArray
self.syntactic_weights_ta = tf.TensorArray(
dtype=dtype,
size=max_seq_len,
tensor_array_name="fw_syntactic_weights")
self.syntactic_weights_ta = self.syntactic_weights_ta.unstack(syntactic_weights)
我试图在代码 sn-p 中实现的基本上是过去输出的加权和,存储在 h_ta 中。
最后我用tf.train.AdamOptimizer训练网络。
我再次测试了脚本,但这次将 while 循环中的 swap_memory 参数设置为 False 并且它也适用于 GPU,但我真的很想知道为什么它不适用于 @ 987654333@.
【问题讨论】:
标签: python tensorflow nullptr