【问题标题】:How do you load an LMDB file into TensorFlow?如何将 LMDB 文件加载到 TensorFlow 中?
【发布时间】:2016-05-20 03:36:30
【问题描述】:

我有一大组 (1 TB) 数据,分为大约 3,000 个 CSV 文件。我的计划是将其转换为一个大型 LMDB 文件,以便可以快速读取它以训练神经网络。但是,我找不到任何关于如何将 LMDB 文件加载到 TensorFlow 中的文档。有谁知道如何做到这一点?我知道 TensorFlow 可以读取 CSV 文件,但我认为这太慢了。

【问题讨论】:

    标签: machine-learning tensorflow


    【解决方案1】:

    根据this,TensorFlow中有几种读取数据的方法。

    最简单的一种是通过 占位符 提供您的数据。使用 占位符 时 - 改组和批处理的责任在您身上。

    如果您想将混洗和批处理委托给框架,那么您需要创建一个输入管道。问题是——如何将 lmdb 数据注入符号输入管道。一种可能的解决方案是使用tf.py_func 操作。这是一个例子:

    def create_input_pipeline(lmdb_env, keys, num_epochs=10, batch_size=64):
       key_producer = tf.train.string_input_producer(keys, 
                                                     num_epochs=num_epochs,
                                                     shuffle=True)
       single_key = key_producer.dequeue()
    
       def get_bytes_from_lmdb(key):
          with lmdb_env.begin() as txn:
             lmdb_val = txn.get(key)
          example = get_example_from_val(lmdb_val) # A single example (numpy array)
          label = get_label_from_val(lmdb_val)     # The label, could be a scalar
          return example, label
    
       single_example, single_label = tf.py_func(get_bytes_from_lmdb,
                                                 [single_key], [tf.float32, tf.float32])
       # if you know the shapes of the tensors you can set them here:
       # single_example.set_shape([224,224,3])
    
       batch_examples, batch_labels = tf.train.batch([single_example, single_label],
                                                     batch_size)
       return batch_examples, batch_labels
    

    tf.py_func 操作在 TensorFlow 图中插入了对常规 python 代码的调用,我们需要指定输入以及输出的数量和类型。 tf.train.string_input_producer 使用给定的键创建一个混洗队列。 tf.train.batch 操作创建另一个包含批量数据的队列。训练时,batch_examplesbatch_labels 的每次评估都会从该队列中取出另一个批次。

    因为我们创建了队列,所以在开始训练之前我们需要小心并运行QueueRunner 对象。这样做是这样的(来自 TensorFlow 文档):

    # Create the graph, etc.
    init_op = tf.initialize_all_variables()
    
    # Create a session for running operations in the Graph.
    sess = tf.Session()
    
    # Initialize the variables (like the epoch counter).
    sess.run(init_op)
    
    # Start input enqueue threads.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    
    try:
        while not coord.should_stop():
            # Run training steps or whatever
            sess.run(train_op)
    
    except tf.errors.OutOfRangeError:
        print('Done training -- epoch limit reached')
    finally:
        # When done, ask the threads to stop.
        coord.request_stop()
    
    # Wait for threads to finish.
    coord.join(threads)
    sess.close()
    

    【讨论】:

      猜你喜欢
      • 2018-04-13
      • 2019-10-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2016-04-21
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多