[Posted]: 2020-10-22 01:29:49
[Question]:
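Is it really necessary to store the image dimensions in a TFRecord file? I'm working with a dataset of images at different scales, and I did not store the width, height, or number of channels of the images I processed. Now I can't reshape them back to their original shape after loading the TFRecord, which I need to do before running the rest of the preprocessing pipeline, such as data augmentation.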
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

# Create dataset (DATA_DIR is the dataset root directory, defined elsewhere)
records_path = DATA_DIR + 'TFRecords/train_0.tfrecords'
dataset = tf.data.TFRecordDataset(filenames=records_path)
# Parse dataset
parsed_dataset = dataset.map(parsing_fn)
# Take the first (image, label) pair (eager iteration in TF 2.x)
image, label = next(iter(parsed_dataset))
# Get the numpy array from the tensor, convert to uint8 and plot it.
# Note: img_array is 1-D here, because parsing_fn only applies decode_raw.
img_array = image.numpy()
img_array = img_array.astype(np.uint8)
plt.imshow(img_array)
plt.show()
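Output: TypeError: Invalid dimensions for image data
Should I reshape the image back to its original shape before converting it to uint8? If so, how can I do that when I haven't stored the dimensions?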
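The error is about shape, not dtype: tf.io.decode_raw returns a flat 1-D vector, while plt.imshow only accepts arrays shaped (H, W), (H, W, 3) or (H, W, 4). A minimal sketch with a synthetic 4x5x3 array (made up for illustration, not taken from the dataset) shows the round trip and why the flat length alone cannot give the shape back:

```python
import numpy as np
import tensorflow as tf

# Synthetic uint8 "image" of shape (height=4, width=5, channels=3).
original = np.arange(4 * 5 * 3, dtype=np.uint8).reshape(4, 5, 3)

# This is what convert() stores: raw bytes with no shape information.
raw_bytes = original.tobytes()

# This is what parsing_fn() gets back: a flat vector of 60 values.
flat = tf.io.decode_raw(raw_bytes, tf.uint8)
print(flat.shape)  # (60,)

# Reshaping restores the image only if (height, width, channels) is known.
restored = tf.reshape(flat, [4, 5, 3])
print(np.array_equal(restored.numpy(), original))  # True

# The same 60 values also fit (5, 4, 3), (2, 10, 3), (60, 1, 1), ...,
# so the flat length alone cannot tell which shape was original.
wrong = tf.reshape(flat, [5, 4, 3])
print(wrong.shape)  # (5, 4, 3)
```

In other words, with raw bytes the dimensions have to come from somewhere: either stored in the record or known in advance (e.g. if every image has the same fixed size).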
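The pipeline below shows an example transformation I'd like to apply to the images read from the TFRecord, but I believe these Keras augmentation methods need a properly shaped array with defined dimensions to operate on. (I don't necessarily need to plot the images.)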
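import numpy as np
import tensorflow as tf
from matplotlib import pyplot

def brightness(brightness_range, image_path):
    # Load the JPEG from disk as a PIL image, then convert it to a
    # float32 array of shape (height, width, channels).
    img = tf.keras.preprocessing.image.load_img(image_path)
    data = tf.keras.preprocessing.image.img_to_array(img)
    # Add a batch dimension: (1, height, width, channels).
    samples = np.expand_dims(data, 0)
    print(samples.shape)
    # Generator that randomly rescales brightness within brightness_range.
    datagen = tf.keras.preprocessing.image.ImageDataGenerator(brightness_range=brightness_range)
    iterator = datagen.flow(samples, batch_size=1)
    # Plot nine augmented versions in a 3x3 grid.
    for i in range(9):
        pyplot.subplot(330 + 1 + i)
        batch = next(iterator)
        augmented = batch[0].astype('uint8')
        pyplot.imshow(augmented)
    pyplot.show()

brightness([0.2, 1.0], DATA_DIR + "183350/5c3e30f1706244e9f199d5a0c5a5ec00d1cbf473.jpg")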
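ImageDataGenerator works on NumPy arrays with a known shape, which is exactly what the raw-bytes records lack. As a sketch of a different technique (not what the post uses), brightness augmentation can instead run inside the tf.data pipeline once each image has been reshaped to (H, W, C). The hypothetical helper random_brightness_scale below only roughly mimics brightness_range=[0.2, 1.0] by multiplying pixel values by a random factor; the toy dataset stands in for the parsed TFRecords.

```python
import tensorflow as tf

def random_brightness_scale(image, low=0.2, high=1.0):
    """Multiply pixel values by a random factor drawn from [low, high).

    Hypothetical helper approximating Keras' brightness_range; expects a
    float32 image of shape (H, W, C) with values in [0, 255].
    """
    factor = tf.random.uniform([], minval=low, maxval=high)
    return tf.clip_by_value(image * factor, 0.0, 255.0)

# Toy dataset of two 8x6 RGB images standing in for the parsed TFRecords.
images = tf.fill([2, 8, 6, 3], 200.0)
labels = tf.constant([0, 1], dtype=tf.int64)
dataset = tf.data.Dataset.from_tensor_slices((images, labels))

# The augmentation is applied per element inside the input pipeline.
augmented = dataset.map(lambda img, lbl: (random_brightness_scale(img), lbl))

for img, lbl in augmented:
    print(img.shape, float(tf.reduce_max(img)))
```

Because the map operates on tensors, nothing needs to be converted to NumPy or uint8 before augmenting; a cast to uint8 is only needed for plotting.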
Helper functions for writing and reading the TFRecord format
Converting to TFRecord:
import tensorflow as tf
from matplotlib.pyplot import imread

# print_progress, wrap_bytes and wrap_int64 are helper functions
# defined elsewhere in the original code (not shown here).

def convert(image_paths, labels, out_path):
    # Args:
    # image_paths    List of file-paths for the images.
    # labels         Class-labels for the images (unused below; the
    #                label is parsed from the path instead).
    # out_path       File-path for the TFRecords output file.
    print("Converting: " + out_path)
    # Number of images. Used when printing the progress.
    num_images = len(image_paths)
    # Open a TFRecordWriter for the output file. In TF 2.x this is
    # tf.io.TFRecordWriter; tf.python_io only exists in TF 1.x.
    with tf.io.TFRecordWriter(out_path) as writer:
        # Iterate over all the image-paths and class-labels.
        for i in range(num_images):
            # Print the percentage-progress.
            print_progress(count=i, total=num_images - 1)
            # Load the image file using matplotlib's imread function
            # (for JPEGs: a uint8 array of shape (height, width, channels)).
            path = image_paths[i]
            img = imread(path)
            path = path.split('/')
            # Convert the image to raw bytes (tobytes() replaces the
            # deprecated tostring()). This flattens the array: the
            # height, width and channel count are NOT stored.
            img_bytes = img.tobytes()
            # Get the label index from the directory name.
            label = int(path[4])
            # Create a dict with the data we want to save in the
            # TFRecords file. You can add more relevant data here.
            data = {
                'image': wrap_bytes(img_bytes),
                'label': wrap_int64(label),
            }
            # Wrap the data as TensorFlow Features.
            feature = tf.train.Features(feature=data)
            # Wrap again as a TensorFlow Example.
            example = tf.train.Example(features=feature)
            # Serialize the data.
            serialized = example.SerializeToString()
            # Write the serialized data to the TFRecords file.
            writer.write(serialized)
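One common way to keep the shape is to write it next to the pixels. A minimal sketch, assuming the same raw-bytes approach as convert(); the feature names 'height', 'width', 'channels' and the helper write_example are illustrative choices, not part of the original code, and wrap_bytes/wrap_int64 are written out here as plausible stand-ins for the post's unshown helpers of the same names.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

def wrap_int64(value):
    # Wrap a single integer as a TensorFlow Feature.
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def wrap_bytes(value):
    # Wrap a bytes object as a TensorFlow Feature.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def write_example(writer, img, label):
    # img: uint8 array of shape (height, width, channels).
    height, width, channels = img.shape
    data = {
        'image': wrap_bytes(img.tobytes()),
        'label': wrap_int64(label),
        # Store the dimensions so the flat bytes can be reshaped later.
        'height': wrap_int64(height),
        'width': wrap_int64(width),
        'channels': wrap_int64(channels),
    }
    example = tf.train.Example(features=tf.train.Features(feature=data))
    writer.write(example.SerializeToString())

# Usage: two images of different sizes written to a temporary file.
out_path = os.path.join(tempfile.mkdtemp(), 'demo.tfrecords')
with tf.io.TFRecordWriter(out_path) as writer:
    write_example(writer, np.zeros((4, 5, 3), dtype=np.uint8), label=0)
    write_example(writer, np.zeros((7, 2, 3), dtype=np.uint8), label=1)

# Read the stored dimensions back to check them.
for record in tf.data.TFRecordDataset(out_path):
    example = tf.train.Example.FromString(record.numpy())
    feats = example.features.feature
    print([feats[k].int64_list.value[0] for k in ('height', 'width', 'channels')])
```

The extra int64 features cost a few bytes per record and make each record self-describing, which matters when, as here, images come in different sizes.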
Parsing function:
import tensorflow as tf

def parsing_fn(serialized):
    # Define a dict with the data-names and types we expect to
    # find in the TFRecords file.
    # It is a bit awkward that this needs to be specified again,
    # because it could have been written in the header of the
    # TFRecords file instead.
    features = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64),
    }
    # Parse the serialized data so we get a dict with our data.
    parsed_example = tf.io.parse_single_example(serialized=serialized,
                                                features=features)
    # Get the image as raw bytes.
    image_raw = parsed_example['image']
    # Decode the raw bytes so it becomes a tensor with type.
    # The result is a flat 1-D uint8 vector, because the original
    # (height, width, channels) shape was never stored.
    image = tf.io.decode_raw(image_raw, tf.uint8)
    # The type is now uint8 but we need it to be float.
    image = tf.cast(image, tf.float32)
    # Get the label associated with the image.
    label = parsed_example['label']
    # The image and label are now correct TensorFlow types.
    return image, label
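On the reading side, stored dimensions let the parser restore the shape before any augmentation. A minimal sketch, assuming records that carry 'height', 'width' and 'channels' int64 features (illustrative names, not from the original post); parsing_fn_with_shape is a hypothetical variant of parsing_fn, and for self-containment the example is serialized in memory instead of being read from a file.

```python
import numpy as np
import tensorflow as tf

def parsing_fn_with_shape(serialized):
    # Hypothetical variant of parsing_fn that also reads the dimensions.
    features = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64),
        'height': tf.io.FixedLenFeature([], tf.int64),
        'width': tf.io.FixedLenFeature([], tf.int64),
        'channels': tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(serialized=serialized, features=features)
    # decode_raw gives a flat uint8 vector ...
    image = tf.io.decode_raw(parsed['image'], tf.uint8)
    # ... which is reshaped using the stored dimensions.
    shape = tf.stack([parsed['height'], parsed['width'], parsed['channels']])
    image = tf.reshape(image, shape)
    image = tf.cast(image, tf.float32)
    return image, parsed['label']

def int64_feature(value):
    # Wrap a single integer as a TensorFlow Feature.
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

# Build one serialized example in memory (a 3x4 RGB image).
pixels = np.arange(3 * 4 * 3, dtype=np.uint8).reshape(3, 4, 3)
example = tf.train.Example(features=tf.train.Features(feature={
    'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[pixels.tobytes()])),
    'label': int64_feature(7),
    'height': int64_feature(3),
    'width': int64_feature(4),
    'channels': int64_feature(3),
}))
dataset = tf.data.Dataset.from_tensors(example.SerializeToString())

image, label = next(iter(dataset.map(parsing_fn_with_shape)))
print(image.shape, int(label))  # (3, 4, 3) 7
```

Inside map the static shape is (None, None, None), so images of different sizes still need resizing (for example with tf.image.resize) or padded_batch before they can be batched together.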
[Comments]:
- Could you post some code snippets showing what you tried, and the unexpected behavior or exception it leads to?
Tags: python tensorflow2.0 tfrecord