【问题标题】:tf.contrib.data.TFRecordDataset not able to read from *.tfrecordtf.contrib.data.TFRecordDataset 无法从 *.tfrecord 读取
【发布时间】:2018-03-18 02:05:34
【问题描述】:

在创建和加载 .tfrecord 文件的上下文中,我遇到了以下问题:

生成 dataset.tfrecord 文件

文件夹 /Batch_manager/assets 包含一些用于生成 dataset.tfrecord 文件的 *.tif 图像:

def _save_as_tfrecord(self, path, name):
    self.__filename = os.path.join(path, name + '.tfrecord')
    writer = tf.python_io.TFRecordWriter(self.__filename)
    print('Writing', self.__filename)
    for index, img in enumerate(self.load(get_iterator=True, n_images=1)):
        img = img[0]
        image_raw = img.tostring()
        rows = img.shape[0]
        cols = img.shape[1]
        try:
            depth = img.shape[2]
        except IndexError:
            depth = 1
        example = tf.train.Example(features=tf.train.Features(feature={
            'height': self._int64_feature(rows), 
            'width': self._int64_feature(cols), 
            'depth': self._int64_feature(depth), 
            'label': self._int64_feature(int(self.target[index])), 
            'image_raw': self._bytes_feature(image_raw)
                }))
        writer.write(example.SerializeToString())
    writer.close()

从 dataset.tfrecord 文件中读取

接下来我尝试使用指向 dataset.tfrecord 文件的路径从该文件中读取:

def dataset_input_fn(self, path):
    dataset = tf.contrib.data.TFRecordDataset(path)

    def parser(record):
        keys_to_features = {
            "height": tf.FixedLenFeature((), tf.int64, default_value=""),
            "width": tf.FixedLenFeature((), tf.int64, default_value=""),
            "depth": tf.FixedLenFeature((), tf.int64, default_value=""),
            "label": tf.FixedLenFeature((), tf.int64, default_value=""),
            "image_raw": tf.FixedLenFeature((), tf.string, default_value=""),
        }
        print(record)
        features = tf.parse_single_example(record, features=keys_to_features)
        print(features)
        label = features['label']
        height = features['height']
        width = features['width']
        depth = features['depth']
        image = tf.decode_raw(features['image_raw'], tf.float32) 
        image = tf.reshape(image, [height, width, -1])
        label = tf.cast(features["label"], tf.int32)

        return {"image_raw": image, "height": height, "width": width, "depth":depth, "label":label}

    dataset = dataset.map(parser)
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.batch(32)
    iterator = dataset.make_one_shot_iterator()

    # `features` is a dictionary in which each value is a batch of values for
    # that feature; `labels` is a batch of labels.
    features = iterator.get_next()

    return Features

错误信息:

TypeError: 应为 int64,得到的是 'str' 类型的 ''。

这段代码有什么问题?我成功验证了 dataset.tfrecord 实际上包含正确的图像和元数据!

【问题讨论】:

  • self.load(...) 简单地返回一个迭代器,可用于基于每个图像加载。我很确定这个问题要么是因为我构建 example 变量并将其写入 dataset.tfrecord 的方式,要么是因为使用 tf.contrib.data.TFRecordDataset(path)parser 函数解析它的方式.map(func)

标签: tensorflow dataset read-data tfrecord


【解决方案1】:

发生错误是因为我复制并粘贴了这个示例,该示例将所有键值对的值设置为空字符串,由default_value="" 引起。从所有tf.FixedLenFeature 中删除它解决了这个问题。

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2018-06-29
    • 1970-01-01
    • 1970-01-01
    • 2018-11-06
    相关资源
    最近更新 更多