【问题标题】:Tensorflow tfrecord not being read correctlyTensorflow tfrecord 没有被正确读取
【发布时间】:2018-01-16 01:49:34
【问题描述】:

我正在尝试使用 Tensorflow 在我自己的分割数据集上训练一个 CNN。根据我的研究,tfRecords 似乎是最好的方法。我已经想出了如何写入和读取 tfRecord 数据库,但我绝对没有尝试在 Tensorflow 图中成功读取它。这是一个从我的数据库中成功重建图像和地面实况的 sn-p:

data_path = 'Training/train.tfrecords'  # address to save the hdf5 file

record_iterator = tf.python_io.tf_record_iterator(path=data_path)
reconstructed_images = []
reconstructed_groundtruths = []
count = 0
for string_record in record_iterator:
  example = tf.train.Example()
  example.ParseFromString(string_record)
  height = int(example.features.feature['height']
                             .int64_list
                             .value[0])

  width = int(example.features.feature['width']
                            .int64_list
                            .value[0])

  gt_string = (example.features.feature['train/groundTruth']
                              .bytes_list
                              .value[0])

  image_string = (example.features.feature['train/image']
                            .bytes_list
                            .value[0])

  img_1d = np.fromstring(image_string, dtype=np.uint8)
  reconstructed_img = img_1d.reshape((height, width))
  gt_1d = np.fromstring(gt_string, dtype=np.uint8)
  reconstructed_gt = gt_1d.reshape((height, width))

  reconstructed_images.append(reconstructed_img)
  reconstructed_groundtruths.append(reconstructed_gt)
  count += 1

此代码成功地为我提供了数据库中图像和基本事实标签的 numpy 数组列表。现在,为了尝试实际训练某些东西,我正在使用 MNIST 示例,您可以找到 here

我已将解码功能替换为以下内容:

def decode(serialized_example):

  features = tf.parse_single_example(
    serialized_example,
    # Defaults are not specified since both keys are required.
    features={
      'height': tf.FixedLenFeature([1],tf.int64),
      'width': tf.FixedLenFeature([1],tf.int64),
      'train/image': tf.FixedLenFeature([], tf.string),
      'train/groundTruth': tf.FixedLenFeature([], tf.string),
    })


  height = tf.cast(features['height'], tf.int64)
  width = tf.cast(features['width'], tf.int64)
  image = tf.decode_raw(features['train/image'], tf.uint8)
  image.set_shape((height,width))
  gt = tf.decode_raw(features['train/groundTruth'], tf.uint8)
  gt.set_shape((height,width))


  return image, gt

当我运行它时,有多个问题表明纯代码无法读取数据库。如上所述,我会在解析height 的行上得到一个错误,它指出

int() 参数必须是字符串、类似字节的对象或数字,而不是 “张量”

如果我暂时只是将 heightwidth 设置为文字,我会在图像解析行上收到错误提示

形状 (?,) 和 (512, 512) 不兼容

很明显,这意味着图像没有从数据库中正确读取,但我完全不知道为什么或如何修复它。有人可以告诉我我做错了什么吗?

【问题讨论】:

    标签: python tensorflow tfrecord


    【解决方案1】:

    我完全靠运气找到了解决方案。显然,

    image.set_shape((height,width)) 
    

    应该是

    image = tf.reshape(image,(height,width,1))
    

    和 gt 类似。我不知道为什么我正在关注的 Tensorflow 教程使用 set_shape ......我只能猜测它适用于 1d 但不适用于 2d 或更多?我现在可以看到它也不是张量函数,所以它不能使用像高度这样的图形相关变量,但这本身并不能解释为什么当我用全局替换 (height,width) 时它不起作用常数。如果有人知道,将不胜感激。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2011-04-07
      • 1970-01-01
      • 2012-01-06
      • 1970-01-01
      • 2019-06-04
      • 1970-01-01
      相关资源
      最近更新 更多