【问题标题】:Filling queue from python iterator从python迭代器填充队列
【发布时间】:2017-04-05 12:58:44
【问题描述】:

我想创建一个由迭代器填充的队列。但是,在以下 MWE 中,总是将相同的值排入队列:

import tensorflow as tf
import numpy as np

# data
imgs = [np.random.randn(i,i) for i in [2,3,4,5]]

# iterate through data infinitly
def data_iterator():
    while True:
        for img in imgs:
            yield img

it = data_iterator()

# create queue for data
q = tf.FIFOQueue(capacity=5, dtypes=[tf.float64])

# feed next element from iterator
enqueue_op = q.enqueue(list(next(it)))

# setup queue runner
numberOfThreads = 1 
qr = tf.train.QueueRunner(q, [enqueue_op] * numberOfThreads)
tf.train.add_queue_runner(qr) 

# dequeue
dequeue_op  = q.dequeue() 
dequeue_op = tf.Print(dequeue_op, data=[dequeue_op], message="dequeue()")

# We start the session as usual ...
with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    for i in range(10):
        data = sess.run(dequeue_op)
        print(data)
.
    coord.request_stop()
    coord.join(threads)

我必须使用feed_dict 吗?如果是,我必须如何将它与 QueueRunner 结合使用?

【问题讨论】:

    标签: python-3.x tensorflow


    【解决方案1】:

    运行时

    enqueue_op = q.enqueue(list(next(it)))
    

    tensorflow 将只执行一次 list(next(it))。此后,每次运行enqueue_op 时,它都会保存第一个列表并将其添加到 q。为避免这种情况,您必须使用占位符。馈送占位符与tf.train.QueueRunner 不兼容。而是使用这个:

    import tensorflow as tf
    import numpy as np
    import threading
    
    # data
    imgs = [np.random.randn(i,i) for i in [2,3,4,5]]
    
    # iterate through data infinitly
    def data_iterator():
        while True:
            for img in imgs:
                yield img
    
    it = data_iterator()
    
    # create queue for data
    q = tf.FIFOQueue(capacity=5, dtypes=[tf.float64])
    
    # feed next element from iterator
    
    img_p = tf.placeholder(tf.float64, [None, None])
    enqueue_op = q.enqueue(img_p)
    
    dequeue_op  = q.dequeue()
    
    
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
    
        def enqueue_thread():
            with coord.stop_on_exception():
                while not coord.should_stop():
                    sess.run(enqueue_op, feed_dict={img_p: list(next(it))})
    
        numberOfThreads = 1
        for i in range(numberOfThreads):
          threading.Thread(target=enqueue_thread).start()
    
    
    
        for i in range(3):
            data = sess.run(dequeue_op)
            print(data)
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2021-01-04
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2018-08-24
      • 2013-04-23
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多