【问题标题】:Shape error when reading tfrecords with tf.data.TFRecordDataset?使用 tf.data.TFRecordDataset 读取 tfrecord 时出现形状错误?
【发布时间】:2018-11-29 13:36:20
【问题描述】:

我用自己的图像创建了一个 tfrecords 文件,当我尝试使用 tf.data.TFRecordDataset 读取它时,出现了一个形状错误: 我用代码创建了 tfrecord:

def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def img_to_tfrecord(data_path):
writer = tf.python_io.TFRecordWriter('test_imgs/test.tfrecords')

file = open('test_imgs/test.txt')
for line in file.readlines():
    img_name = line.split(' ')[0]
    label = int(line.split(' ')[1])
    img_path = data_path + '/test_imgs/' + img_name

    img = Image.open(img_path)
    img = img.resize((224, 224))
    img_bytes = img.tobytes()

    feature={'train_img': _bytes_feature(img_bytes),
             'train_label': _int64_feature(label)}
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    writer.write(example.SerializeToString())

writer.close()

并用代码阅读:

def parser(record):
parsed = tf.parse_single_example(record, {'train_img': tf.FixedLenFeature((), tf.string),
                                          'train_label': tf.FixedLenFeature((), tf.int64)})

image = tf.decode_raw(parsed['train_img'], tf.uint8)
image = tf.reshape(image, [224, 224, 3])
label = tf.cast(parsed['train_label'], tf.int32)

return image, label

if __name__ == '__main__':
    os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
    tf.reset_default_graph()

    dataset = tf.data.TFRecordDataset('test_imgs/test.tfrecords')
    dataset = dataset.map(parser)
    dataset = dataset.shuffle(buffer_size=6).batch(4).repeat(2)
    iterator = dataset.make_one_shot_iterator()
    img, label = iterator.get_next()

    with tf.Session() as sess:
       a,b=sess.run([img, label])
       print(a.shape)

既然150528=224*224*3,那么200704是怎么来的呢?我看了很多教程还是解决不了这个问题。我已经关注了代码的解析类型:image = tf.decode_raw(parsed['train_img'], tf.uint8)。谁能帮帮我,我快崩溃了。 用于创建 tfrecord 的图像如下所示:

【问题讨论】:

    标签: python tensorflow reshape tfrecord


    【解决方案1】:

    好的,我解决了。我的上帝。这是关于图片本身的。其中一张图片的位深度为 32,则其形状为 (224,224,4)。

    【讨论】:

      猜你喜欢
      • 2018-06-29
      • 2020-02-03
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2019-05-19
      • 1970-01-01
      • 2019-08-31
      • 2019-05-18
      相关资源
      最近更新 更多