Title: How to view the output of tf.train.batch()
Posted: 2016-08-30 02:14:13
Question:

I have been following the "RNNs in TensorFlow, a Practical Guide and Undocumented Features" post on WILDML, but I cannot view the output of the tf.train.batch() function. The code that stores, loads, and processes the input is as follows:

import tensorflow as tf

sequences = [[1, 2, 3], [4, 5, 1], [1, 2]]
label_sequences = [[0, 1, 0], [1, 0, 0], [1, 1]]

def make_example(sequence, labels):
    # The object we return
    ex = tf.train.SequenceExample()
    # A non-sequential feature of our example
    sequence_length = len(sequence)
    ex.context.feature["length"].int64_list.value.append(sequence_length)
    # Feature lists for the two sequential features of our example
    fl_tokens = ex.feature_lists.feature_list["tokens"]
    fl_labels = ex.feature_lists.feature_list["labels"]
    for token, label in zip(sequence, labels):
        fl_tokens.feature.add().int64_list.value.append(token)
        fl_labels.feature.add().int64_list.value.append(label)
    return ex

fname = "/home/someUser/PycharmProjects/someTensors"
writer = tf.python_io.TFRecordWriter(fname)
for sequence, label_sequence in zip(sequences, label_sequences):
    ex = make_example(sequence, label_sequence)
    print(ex)
    writer.write(ex.SerializeToString())
writer.close()
print("Wrote to {}".format(fname))

# Parsing specs for the context and sequence features (as in the wildml post)
context_features = {
    "length": tf.FixedLenFeature([], dtype=tf.int64)
}
sequence_features = {
    "tokens": tf.FixedLenSequenceFeature([], dtype=tf.int64),
    "labels": tf.FixedLenSequenceFeature([], dtype=tf.int64)
}

reader = tf.TFRecordReader()
filename_queue = tf.train.string_input_producer([fname])
_, serialized_example = reader.read(filename_queue)
context_parsed, sequence_parsed = tf.parse_single_sequence_example(
    serialized=serialized_example,
    context_features=context_features,
    sequence_features=sequence_features)

init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
tf.train.start_queue_runners(sess=sess)

batched_data = tf.train.batch(
    tensors=[context_parsed['length'], sequence_parsed['tokens'],
             sequence_parsed['labels']],
    batch_size=5, dynamic_pad=True)

batched_context_data = tf.train.batch(
    tensors=[context_parsed['length']],
    batch_size=5, dynamic_pad=True)

batched_tokens_data = tf.train.batch(
    tensors=[sequence_parsed['tokens']],
    batch_size=5, dynamic_pad=True)

batched_labels_data = tf.train.batch(
    tensors=[sequence_parsed['labels']],
    batch_size=5, dynamic_pad=True)

According to the post, it should be possible to view the output of a batch as follows:

res = tf.contrib.learn.run_n({"y": batched_data}, n=1, feed_dict=None)
print("Batch shape: {}".format(res[0]["y"].shape))
print(res[0]["y"])

Or, for a more specific case, as follows:

res = tf.contrib.learn.run_n({"y": batched_context_data}, n=1, feed_dict=None)
print("Batch shape: {}".format(res[0]["y"].shape))
print(res[0]["y"])

Unfortunately, TensorFlow hangs for a very long time in both cases, so I eventually killed the process. Can anyone tell me what I am doing wrong?

Thanks very much!

Tags: tensorflow recurrent-neural-network lstm


Solution 1:

I suspect the problem is that this line, which calls tf.train.start_queue_runners():

    tf.train.start_queue_runners(sess=sess)
    

...appears before these lines, which contain the calls to tf.train.batch():

    batched_data = tf.train.batch(...)
    
    batched_context_data = tf.train.batch(...)
    
    batched_tokens_data = tf.train.batch(...)
    
    batched_labels_data = tf.train.batch(...)
    

If you move the call to tf.train.start_queue_runners() so that it comes after the calls to tf.train.batch(), your program should no longer deadlock.


Why does this happen? The tf.train.batch() function internally creates queues to buffer the data being batched, and the usual way to fill these queues in TensorFlow is to create a "queue runner", which is (typically) a background thread that moves elements into the queue. The tf.train.start_queue_runners() function starts background threads for all queue runners that are registered at the time it is called; if it is called before a queue runner is created, that runner's thread will never be started.
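The start-order dependency can be illustrated without TensorFlow at all. Below is a minimal stdlib sketch (the names `start_registered_runners` and `registered` are illustrative, not TensorFlow API): a consumer blocks on a queue whose filler thread was registered after the "start runners" call, just as tf.train.batch()'s internal queue starves when its queue runner is created after start_queue_runners() has already run.

```python
import queue
import threading

def start_registered_runners(runners):
    # Analogue of tf.train.start_queue_runners(): it starts whatever
    # runner threads are registered *at call time* -- no more, no later.
    for r in runners:
        r.start()

registered = []

# Wrong order: "start" is called before any runner exists, so nothing starts.
start_registered_runners(registered)

q = queue.Queue()
filler = threading.Thread(target=lambda: [q.put(i) for i in range(5)],
                          daemon=True)
registered.append(filler)  # registered too late; never started above

# Without the timeout this q.get() would block forever, like the
# hanging tf.contrib.learn.run_n() call in the question.
try:
    q.get(timeout=0.2)
    deadlocked = False
except queue.Empty:
    deadlocked = True

# Correct order: the runner is registered first, then the runners start.
start_registered_runners(registered)
batch = [q.get() for _ in range(5)]  # now succeeds

print(deadlocked)  # True: the consumer timed out under the wrong order
print(batch)       # [0, 1, 2, 3, 4]
```

The same reasoning explains the fix above: once the tf.train.batch() calls have created their queues and registered their queue runners, a subsequent start_queue_runners() call launches the threads that keep those queues fed.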
