tf.data API(从 tensorflow 1.4 开始)非常适合这样的事情。管道将如下所示:
- 创建一个初始的
tf.data.Dataset 对象以迭代所有示例
- (如果是训练)
shuffle/repeat 数据集;
-
map它通过一些使所有图像大小相同的功能;
-
batch;
- (可选)
prefetch 告诉您的程序在网络处理当前批次时收集预处理后续批次的数据;和
- 并获取输入。
有多种方法可以创建您的初始数据集(请参阅here 以获得更深入的答案)
带有 TensorFlow 数据集的 TFRecords
支持 tensorflow 1.12 及以上版本,Tensorflow datasets 提供了一个相对简单的 API 用于创建 tfrecord 数据集,并自动处理数据下载、分片、统计生成等功能。
参见例如this image classification dataset implementation。里面有很多记账的东西(下载网址、引用等),但技术部分归结为指定 features 并编写 _generate_examples 函数
features = tfds.features.FeaturesDict({
"image": tfds.features.Image(shape=(_TILES_SIZE,) * 2 + (3,)),
"label": tfds.features.ClassLabel(
names=_CLASS_NAMES),
"filename": tfds.features.Text(),
})
...
def _generate_examples(self, root_dir):
root_dir = os.path.join(root_dir, _TILES_SUBDIR)
for i, class_name in enumerate(_CLASS_NAMES):
class_dir = os.path.join(root_dir, _class_subdir(i, class_name))
fns = tf.io.gfile.listdir(class_dir)
for fn in sorted(fns):
image = _load_tif(os.path.join(class_dir, fn))
yield {
"image": image,
"label": class_name,
"filename": fn,
}
您还可以使用较低级别的操作生成tfrecords。
通过tf.data.Dataset.map和tf.py_func(tion)加载图片
或者,您可以从tf.data.Dataset.map 中的文件名加载图像文件,如下所示。
image_paths, labels = load_base_data(...)
epoch_size = len(image_paths)
image_paths = tf.convert_to_tensor(image_paths, dtype=tf.string)
labels = tf.convert_to_tensor(labels)
dataset = tf.data.Dataset.from_tensor_slices((image_paths, labels))
if mode == 'train':
dataset = dataset.repeat().shuffle(epoch_size)
def map_fn(path, label):
# path/label represent values for a single example
image = tf.image.decode_jpeg(tf.read_file(path))
# some mapping to constant size - be careful with distorting aspec ratios
image = tf.image.resize_images(out_shape)
# color normalization - just an example
image = tf.to_float(image) * (2. / 255) - 1
return image, label
# num_parallel_calls > 1 induces intra-batch shuffling
dataset = dataset.map(map_fn, num_parallel_calls=8)
dataset = dataset.batch(batch_size)
# try one of the following
dataset = dataset.prefetch(1)
# dataset = dataset.apply(
# tf.contrib.data.prefetch_to_device('/gpu:0'))
images, labels = dataset.make_one_shot_iterator().get_next()
我从未在分布式环境中工作过,但我从未注意到使用这种方法对tfrecords 的性能影响。如果您需要更多自定义加载功能,也请查看tf.py_func。
更多一般信息here,以及性能说明here