【问题标题】:I want to read data from TFRecord我想从 TFRecord 读取数据
【发布时间】:2019-03-27 13:36:32
【问题描述】:

我将图像日期保存到 tfrecord 中,但无法使用 tensorflow dataset api 解析它。

我的环境

  • Ubuntu 18.04
  • Python 3.6.8
  • Jupyter 笔记本
  • 张量流 1.12.0

我通过以下代码保存了图像数据,

writer = tf.python_io.TFRecordWriter('training.tfrecord')

# X_train: paths to the image, y_train: labels (0 or 1)
for image_path, label in zip(X_train, y_train):
    image = cv2.imread(image_path)
    image = cv2.resize(image, (150, 150)) / 255.0
    ex = tf.train.Example(
        features = tf.train.Features(
            feature={
                'image' : tf.train.Feature(float_list = tf.train.FloatList(value=image.ravel())),
                'label' : tf.train.Feature(int64_list = tf.train.Int64List(value=[label]))
            }
        )
    )
    writer.write(ex.SerializeToString())
writer.close()

我尝试从 tfrecord 文件中获取图像。

for record in tf.python_io.tf_record_iterator('test.tfrecord'):
    example = tf.train.Example()
    example.ParseFromString(record)

    img = example.features.feature['image'].float_list.value
    label = example.features.feature['label'].int64_list.value[0]

此方法有效。

但当我使用 Dataset API 为我的 ML 模型获取图像时,它不会。

def _parse_function(example_proto):
    features = {
        'label' : tf.FixedLenFeature((), tf.int64),
        'image' : tf.FixedLenFeature((), tf.float32)
    }
    parsed_features = tf.parse_single_example(example_proto, features)

    return parsed_features['image'], parsed_features['label']

def read_image(images, labels):
    label = tf.cast(labels, tf.int32)
    images = tf.cast(images, tf.float32)
    image = tf.reshape(images, [150, 150, 3])

# read the data
dataset = tf.data.TFRecordDataset('training.tfrecord')
dataset = dataset.map(_parse_function)
dataset = dataset.map(read_image) # <- ERROR!

错误消息是

ValueError: Cannot reshape a tensor with 1 elements to shape [150,150,3] (67500 elements) for 'Reshape' (op: 'Reshape') with input shapes: [], [3] and with input tensors computed as partial shapes: input[1] = [150,150,3].

我虽然这个错误的原因是数组的形状不对,所以我确认了“数据集”的元素

<MapDataset shapes: ((), ()), types: (tf.float32, tf.int64)>

“数据集”变量没有数据。我不知道为什么会这样。

后记

我尝试了 Sharky 的解决方案,结果,

def parse(example_proto):
    features = {
        'label' : tf.FixedLenFeature((), tf.string, ''),
        'image' : tf.FixedLenFeature((), tf.string, '')
    }
    parsed_features = tf.parse_single_example(example_proto, features)
    img_shape = tf.stack([150, 150, 3])
    image = tf.decode_raw(parsed_features['image'], tf.float32)
    image = tf.reshape(image, img_shape)
    label = tf.decode_raw(parsed_features['label'], tf.int32)
    label = tf.reshape(label, tf.stack([1]))

    return image, label

我认为有效。但我无法从此 MapDataset 类型对象中获取数组。该怎么做?

【问题讨论】:

    标签: python tensorflow machine-learning


    【解决方案1】:

    尝试使用单个解析函数

    def parse(example_proto):
        features = {
            'label' : tf.FixedLenFeature((), tf.int64),
            'image' : tf.FixedLenFeature((), tf.string)
        }
        parsed_features = tf.parse_single_example(example_proto, features)
        img_shape = tf.stack([height, width, channel])
        image = tf.decode_raw(parsed_features['image'], tf.float32)
        image = tf.reshape(image, img_shape)
        label = tf.cast(parsed['label'], tf.int32)
        return image, label
    

    好的,parse_single_example 似乎需要字符串类型而不是浮点数。我建议像这样编码

    def int64_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
    
    
    def bytes_feature(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
    
    writer = tf.python_io.TFRecordWriter('training.tfrecord')
    
    for image_path, label in zip(X_train, y_train):
        image = cv2.imread(image_path)
        image = cv2.resize(image, (150, 150)) / 255.0
        img_raw = image.tostring()
        ex = tf.train.Example(features=tf.train.Features(feature={                                                                     
                            'image': bytes_feature(img_raw),
                            'label': int64_feature(label)
                             }))
        writer.write(ex.SerializeToString())
    writer.close()
    

    【讨论】:

    • 这很奇怪。我能想到的一件事是您使用tf.train.FloatList 而不是BytesList diring 转换为tfrecords。尝试用 float 替换 'image' : tf.FixedLenFeature((), tf.float32) 到 tf.string, ''
    • 感谢您的建议。它似乎运作良好。我在上面添加了代码和注释。
    • 正如我所说,尝试使用字节功能进行编码,这应该可以工作
    • 我尝试使用字节编码(img_raw=image.tostring() ?)并保存。它成功了,但是解析不起作用(在解析函数 [height, width, channel]=[150,150,3] 中?)。错误消息如下,
    • ValueError: 试图将“张量”转换为张量并失败。错误:参数必须是密集张量:FixedLenFeature(shape=(), dtype=tf.float32, default_value=None) - 得到形状 [3],但想要 [3, 0]。
    猜你喜欢
    • 2018-11-06
    • 2018-03-18
    • 1970-01-01
    • 1970-01-01
    • 2020-09-16
    • 2018-04-09
    • 2018-06-29
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多