【发布时间】:2017-12-28 00:42:16
【问题描述】:
我发现了 tensorflow 的 lstm 单元格(不限于 lstm,但我只用这个检查过)的一个特殊属性,据我所知尚未报告。 我不知道它是否真的有,所以我把这篇文章留在了 SO 中。下面是这个问题的玩具代码:
import tensorflow as tf
import numpy as np
import time
def network(input_list):
input,init_hidden_c,init_hidden_m = input_list
cell = tf.nn.rnn_cell.BasicLSTMCell(256, state_is_tuple=True)
init_hidden = tf.nn.rnn_cell.LSTMStateTuple(init_hidden_c, init_hidden_m)
states, hidden_cm = tf.nn.dynamic_rnn(cell, input, dtype=tf.float32, initial_state=init_hidden)
net = [v for v in tf.trainable_variables()]
return states, hidden_cm, net
def action(x, h_c, h_m):
t0 = time.time()
outputs, output_h = sess.run([rnn_states[:,-1:,:], rnn_hidden_cm], feed_dict={
rnn_input:x,
rnn_init_hidden_c: h_c,
rnn_init_hidden_m: h_m
})
dt = time.time() - t0
return outputs, output_h, dt
rnn_input = tf.placeholder("float", [None, None, 512])
rnn_init_hidden_c = tf.placeholder("float", [None,256])
rnn_init_hidden_m = tf.placeholder("float", [None,256])
rnn_input_list = [rnn_input, rnn_init_hidden_c, rnn_init_hidden_m]
rnn_states, rnn_hidden_cm, rnn_net = network(rnn_input_list)
feed_input = np.random.uniform(low=-1.,high=1.,size=(1,1,512))
feed_init_hidden_c = np.zeros(shape=(1,256))
feed_init_hidden_m = np.zeros(shape=(1,256))
sess = tf.Session()
sess.run(tf.global_variables_initializer())
for i in range(10000):
_, output_hidden_cm, deltat = action(feed_input, feed_init_hidden_c, feed_init_hidden_m)
if i % 10 == 0:
print 'Running time: ' + str(deltat)
(feed_init_hidden_c, feed_init_hidden_m) = output_hidden_cm
feed_input = np.random.uniform(low=-1.,high=1.,size=(1,1,512))
[不重要]此代码的作用是从包含 LSTM 的“network()”函数生成输出,其中输入的时间维度为 1,因此输出的时间维度也为 1,并为每个运行步骤拉入和拉出初始状态。
[重要] 查看 'sess.run()' 部分。由于某些原因,在我的真实代码中,我碰巧将 [:,-1:,:] 用于“rnn_states”。然后发生的事情是每个 'sess.run()' 花费的时间增加。对于我自己的一些检查,我发现这种减速源于 [:,-1:,:]。我只是想在最后一步获得输出。如果您执行 'outputs, output_h = sess.run([rnn_states, rnn_hidden_cm], feed_dict{~' w/o [:,-1:,:] 并采用 'last_output = outputs[:,-1:,:]'在 'sess.run()' 之后,不会出现减速。
我不知道为什么在 [:,-1:,:] 运行时会发生这种指数级的时间增量。这是不是 tensorflow 的性质没有被记录,但速度特别慢(可能会自己添加更多图表?)? 谢谢,希望这篇文章的其他用户不要发生这个错误。
【问题讨论】:
-
您只需将该切片移到 for 循环之外。
-
@Aaron:我猜“for”循环不是重点。 “action()”在“输出”的最后一个时间步骤输出“输出”,我发布了是否可以在“sess.run()”中从“输出”中切出最后一个“输出” ' - 结果证明是有问题的 - 或者没有。
标签: tensorflow slice lstm