【发布时间】:2019-03-27 13:36:32
【问题描述】:
我将图像日期保存到 tfrecord 中,但无法使用 tensorflow dataset api 解析它。
我的环境
- Ubuntu 18.04
- Python 3.6.8
- Jupyter 笔记本
- 张量流 1.12.0
我通过以下代码保存了图像数据,
writer = tf.python_io.TFRecordWriter('training.tfrecord')
# X_train: paths to the image, y_train: labels (0 or 1)
for image_path, label in zip(X_train, y_train):
image = cv2.imread(image_path)
image = cv2.resize(image, (150, 150)) / 255.0
ex = tf.train.Example(
features = tf.train.Features(
feature={
'image' : tf.train.Feature(float_list = tf.train.FloatList(value=image.ravel())),
'label' : tf.train.Feature(int64_list = tf.train.Int64List(value=[label]))
}
)
)
writer.write(ex.SerializeToString())
writer.close()
我尝试从 tfrecord 文件中获取图像。
for record in tf.python_io.tf_record_iterator('test.tfrecord'):
example = tf.train.Example()
example.ParseFromString(record)
img = example.features.feature['image'].float_list.value
label = example.features.feature['label'].int64_list.value[0]
此方法有效。
但当我使用 Dataset API 为我的 ML 模型获取图像时,它不会。
def _parse_function(example_proto):
features = {
'label' : tf.FixedLenFeature((), tf.int64),
'image' : tf.FixedLenFeature((), tf.float32)
}
parsed_features = tf.parse_single_example(example_proto, features)
return parsed_features['image'], parsed_features['label']
def read_image(images, labels):
label = tf.cast(labels, tf.int32)
images = tf.cast(images, tf.float32)
image = tf.reshape(images, [150, 150, 3])
# read the data
dataset = tf.data.TFRecordDataset('training.tfrecord')
dataset = dataset.map(_parse_function)
dataset = dataset.map(read_image) # <- ERROR!
错误消息是
ValueError: Cannot reshape a tensor with 1 elements to shape [150,150,3] (67500 elements) for 'Reshape' (op: 'Reshape') with input shapes: [], [3] and with input tensors computed as partial shapes: input[1] = [150,150,3].
我虽然这个错误的原因是数组的形状不对,所以我确认了“数据集”的元素
<MapDataset shapes: ((), ()), types: (tf.float32, tf.int64)>
“数据集”变量没有数据。我不知道为什么会这样。
后记
我尝试了 Sharky 的解决方案,结果,
def parse(example_proto):
features = {
'label' : tf.FixedLenFeature((), tf.string, ''),
'image' : tf.FixedLenFeature((), tf.string, '')
}
parsed_features = tf.parse_single_example(example_proto, features)
img_shape = tf.stack([150, 150, 3])
image = tf.decode_raw(parsed_features['image'], tf.float32)
image = tf.reshape(image, img_shape)
label = tf.decode_raw(parsed_features['label'], tf.int32)
label = tf.reshape(label, tf.stack([1]))
return image, label
我认为有效。但我无法从此 MapDataset 类型对象中获取数组。该怎么做?
【问题讨论】:
标签: python tensorflow machine-learning