【问题标题】:Read TFRecord image data with new TensorFlow Dataset API使用新的 TensorFlow Dataset API 读取 TFRecord 图像数据
【发布时间】:2018-05-11 07:09:54
【问题描述】:

我在使用“新”(TensorFlow v1.4) 数据集 API 读取 TFRecord 格式图像数据时遇到问题。我相信问题在于我在尝试读取时以某种方式消耗了整个数据集而不是单个批次。我在这里有一个使用批处理/文件队列 API 执行此操作的工作示例:https://github.com/gnperdue/TFExperiments/tree/master/conv(好吧,在示例中我正在运行分类器,但读取 TFRecord 图像的代码在 DataReaders.py 类中)。

我相信问题函数是:

def parse_mnist_tfrec(tfrecord, features_shape):
    tfrecord_features = tf.parse_single_example(
        tfrecord,
        features={
            'features': tf.FixedLenFeature([], tf.string),
            'targets': tf.FixedLenFeature([], tf.string)
        }
    )
    features = tf.decode_raw(tfrecord_features['features'], tf.uint8)
    features = tf.reshape(features, features_shape)
    features = tf.cast(features, tf.float32)
    targets = tf.decode_raw(tfrecord_features['targets'], tf.uint8)
    targets = tf.one_hot(indices=targets, depth=10, on_value=1, off_value=0)
    targets = tf.cast(targets, tf.float32)
    return features, targets

class MNISTDataReaderDset:
    def __init__(self, data_reader_dict):
        # doesn't matter here

    def batch_generator(self, num_epochs=1):
        def parse_fn(tfrecord):
            return parse_mnist_tfrec(
                tfrecord, self.name, self.features_shape
            )
        dataset = tf.data.TFRecordDataset(
            self.filenames_list, compression_type=self.compression_type
        )
        dataset = dataset.map(parse_fn)
        dataset = dataset.repeat(num_epochs)
        dataset = dataset.batch(self.batch_size)
        iterator = dataset.make_one_shot_iterator()
        batch_features, batch_labels = iterator.get_next()
        return batch_features, batch_labels

那么,在使用中:

        batch_features, batch_labels = \
            data_reader.batch_generator(num_epochs=1)

        sess.run(tf.local_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            # look at 3 batches only
            for _ in range(3):
                labels, feats = sess.run([
                    batch_labels, batch_features
                ])

这会产生如下错误:

 [[Node: Reshape_1 = Reshape[T=DT_UINT8, Tshape=DT_INT32](DecodeRaw_1, Reshape_1/shape)]]
 Input to reshape is a tensor with 50000 values, but the requested shape has 1
 [[Node: Reshape_1 = Reshape[T=DT_UINT8, Tshape=DT_INT32](DecodeRaw_1, Reshape_1/shape)]]
 [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?,28,28,1], [?,10]], output_types=[DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator)]]

有人有什么想法吗?

我有一个阅读器示例中的完整代码的要点以及 TFRecord 文件的链接(我们的老朋友 MNIST,采用 TFRecord 形式):

https://gist.github.com/gnperdue/56092626d611ae23370a21fdeeb2abe8

谢谢!

编辑 - 我也尝试了flat_map,例如:

def batch_generator(self, num_epochs=1):
    """
    TODO - we can use placeholders for the list of file names and
    init with a feed_dict when we call `sess.run` - give this a
    try with one list for training and one for validation
    """
    def parse_fn(tfrecord):
        return parse_mnist_tfrec(
            tfrecord, self.name, self.features_shape
        )
    dataset = tf.data.Dataset.from_tensor_slices(self.filenames_list)
    dataset = dataset.flat_map(
        lambda filename: (
            tf.data.TFRecordDataset(
                filename, compression_type=self.compression_type
            ).map(parse_fn).batch(self.batch_size)
        )
    )
    dataset = dataset.repeat(num_epochs)
    iterator = dataset.make_one_shot_iterator()
    batch_features, batch_labels = iterator.get_next()
    return batch_features, batch_labels

我也尝试只使用一个文件而不是列表(在我上面的第一种方法中)。无论如何,似乎 TF 总是想将整个文件吃到 TFRecordDataset 中,并且不会对单个记录进行操作。

【问题讨论】:

    标签: python tensorflow dataset


    【解决方案1】:

    好的,我想通了 - 上面的代码很好。问题是我创建 TFRecords 的脚本。基本上,我有一个这样的块

    def write_tfrecord(reader, start_idx, stop_idx, tfrecord_file):
        writer = tf.python_io.TFRecordWriter(tfrecord_file)
        tfeat, ttarg = get_binary_data(reader, start_idx, stop_idx)
        example = tf.train.Example(
            features=tf.train.Features(
                feature={
                    'features': tf.train.Feature(
                        bytes_list=tf.train.BytesList(value=[tfeat])
                    ),
                    'targets': tf.train.Feature(
                        bytes_list=tf.train.BytesList(value=[ttarg])
                    )
                }
            )
        )
        writer.write(example.SerializeToString())
        writer.close()
    

    而我需要这样的块:

    def write_tfrecord(reader, start_idx, stop_idx, tfrecord_file):
        writer = tf.python_io.TFRecordWriter(tfrecord_file)
        for idx in range(start_idx, stop_idx):
            tfeat, ttarg = get_binary_data(reader, idx)
            example = tf.train.Example(
                features=tf.train.Features(
                    feature={
                        'features': tf.train.Feature(
                            bytes_list=tf.train.BytesList(value=[tfeat])
                        ),
                        'targets': tf.train.Feature(
                            bytes_list=tf.train.BytesList(value=[ttarg])
                        )
                    }
                )
            )
            writer.write(example.SerializeToString())
        writer.close()
    

    也就是说 - 当我需要在数据中为每个示例创建一个时,我基本上是将整个数据块编写为一个巨大的 TFRecord。

    事实证明,如果您在旧文件和批处理队列 API 中使用任何一种方式,一切正常 - 像 tf.train.batch 这样的函数自动神奇地“智能”到足以分割大块或连接大量单个-example 根据您给它的内容记录成一个批次。当我修复了创建 TFRecords 文件的代码时,我不需要更改旧文件和批处理队列代码中的任何内容,它仍然可以很好地使用 TFRecords 文件。但是,Dataset API 对这种差异很敏感。这就是为什么在我上面的代码中它似乎总是在消耗整个文件——因为整个文件确实是一个大的 TFRecord。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2018-11-06
      • 2018-02-11
      • 1970-01-01
      • 2018-04-09
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多