【问题标题】:Numpy to TFrecords: Is there a more simple way to handle batch inputs from tfrecords?Numpy to TFrecords:有没有更简单的方法来处理来自 tfrecords 的批量输入?
【发布时间】:2018-01-07 17:15:57
【问题描述】:

我的问题是关于如何从多个(或分片)tfrecord 中获取批量输入。我已阅读示例https://github.com/tensorflow/models/blob/master/inception/inception/image_processing.py#L410。基本管道是,以训练集为例,(1)首先生成一系列tfrecords(例如,train-000-of-005train-001-of-005,...),(2)从这些文件名中,生成一个列表并输入它们进入tf.train.string_input_producer 以获取队列,(3)同时生成tf.RandomShuffleQueue 来做其他事情,(4)使用tf.train.batch_join 生成批量输入。

我认为这很复杂,我不确定这个过程的逻辑。就我而言,我有一个.npy 文件列表,我想生成分片 tfrecords(多个单独的 tfrecords,而不仅仅是一个大文件)。这些.npy 文件中的每一个都包含不同数量的正样本和负样本(2 类)。一种基本方法是生成一个大型 tfrecord 文件。但文件太大(~20Gb)。所以我求助于分片 tfrecords。有没有更简单的方法来做到这一点?谢谢。

【问题讨论】:

    标签: python tensorflow tensorflow-datasets tfrecord


    【解决方案1】:

    使用Dataset API 简化了整个过程。以下是这两个部分:(1): Convert numpy array to tfrecords(2,3,4): read the tfrecords to generate batches

    1。 从 numpy 数组创建 tfrecords:

        def npy_to_tfrecords(...):
           # write records to a tfrecords file
           writer = tf.python_io.TFRecordWriter(output_file)
    
           # Loop through all the features you want to write
           for ... :
              let say X is of np.array([[...][...]])
              let say y is of np.array[[0/1]]
    
             # Feature contains a map of string to feature proto objects
             feature = {}
             feature['X'] = tf.train.Feature(float_list=tf.train.FloatList(value=X.flatten()))
             feature['y'] = tf.train.Feature(int64_list=tf.train.Int64List(value=y))
    
             # Construct the Example proto object
             example = tf.train.Example(features=tf.train.Features(feature=feature))
    
             # Serialize the example to a string
             serialized = example.SerializeToString()
    
             # write the serialized objec to the disk
             writer.write(serialized)
          writer.close()
    

    2。 使用 Dataset API (tensorflow >=1.2) 读取 tfrecord:

        # Creates a dataset that reads all of the examples from filenames.
        filenames = ["file1.tfrecord", "file2.tfrecord", ..."fileN.tfrecord"]
        dataset = tf.contrib.data.TFRecordDataset(filenames)
        # for version 1.5 and above use tf.data.TFRecordDataset
    
        # example proto decode
        def _parse_function(example_proto):
          keys_to_features = {'X':tf.FixedLenFeature((shape_of_npy_array), tf.float32),
                              'y': tf.FixedLenFeature((), tf.int64, default_value=0)}
          parsed_features = tf.parse_single_example(example_proto, keys_to_features)
         return parsed_features['X'], parsed_features['y']
    
        # Parse the record into tensors.
        dataset = dataset.map(_parse_function)  
    
        # Shuffle the dataset
        dataset = dataset.shuffle(buffer_size=10000)
    
        # Repeat the input indefinitly
        dataset = dataset.repeat()  
    
        # Generate batches
        dataset = dataset.batch(batch_size)
    
        # Create a one-shot iterator
        iterator = dataset.make_one_shot_iterator()
    
        # Get batch X and y
        X, y = iterator.get_next()
    

    【讨论】:

    • 嗨,先生,这个 api 是否支持 num_threadscapacity 就像 tf.train.shuffle_batch api 中的那样?在我的情况下,如果网络很小,那么 GPU 中的执行速度比数据加载快,这会导致 GPU 时间空闲。所以我想获取数据的队列总是满的。谢谢。
    • 非常感谢!
    • 感谢这个很好的例子 - 使用 reader = tf.TFRecordReader(); key, value = reader.read(filename_queue) 我得到一个键值对(值对应于代码中的 example_proto)。如何使用dataset = tf.contrib.data.TFRecordDataset(filenames) 获取密钥?
    • 是否可以将“shapeofnparray”存储在 TFRecord 中,然后使用类似于stackoverflow.com/a/42603692/2184122 的方式进行整形?我无法在旧方式和数据集方式之间进行映射。
    • example_proto 到底是什么?字符串还是字节数据?该变量分配在哪里?它分配给什么?
    猜你喜欢
    • 2014-04-17
    • 1970-01-01
    • 2021-10-21
    • 1970-01-01
    • 2014-01-18
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多