【发布时间】:2020-11-16 16:19:49
【问题描述】:
我正在尝试使用自动编码器开发图像着色器。有13000张训练图像。如果我使用 tf.data,每个 epoch 大约需要 45 分钟,如果我使用 tf.utils.keras.Sequence,大约需要 25 分钟。然而,使用 Sequence 存在死锁的风险。如何改进 tf.data?我尝试了几件事,但似乎没有任何改善。
tf.data 1
image_path_list = glob.glob('datasets/imagenette/*')
data = tf.data.Dataset.list_files(image_path_list)
def tf_rgb2lab(image):
im_shape = image.shape
[image,] = tf.py_function(color.rgb2lab, [image], [tf.float32])
image.set_shape(im_shape)
return image
def preprocess(path):
image = tf.io.read_file(path)
image = tf.image.decode_jpeg(image, channels=3)
image = tf.image.convert_image_dtype(image, tf.float32)
image = tf.image.resize(image, [224, 224])
image = tf_rgb2lab(image)
L = image[:,:,0]/100.
ab = image[:,:,1:]/128.
input = tf.stack([L,L,L], axis=2)
return input, ab
train_ds = data.repeat().map(preprocess, AUTOTUNE).batch(32).prefetch(AUTOTUNE)
tf.data 2
AUTOTUNE = tf.data.experimental.AUTOTUNE
def tf_rgb2lab(image):
im_shape = image.shape
[image,] = tf.py_function(color.rgb2lab, [image], [tf.float32])
image.set_shape(im_shape)
return image
def split_for_feed(image):
L = image[:,:,:,0]/100.
ab = image[:,:,:,1:]/128.
input = tf.stack([L,L,L], axis=-1)
return input, ab
def read_images(path):
image = tf.io.read_file(path)
image = tf.image.decode_jpeg(image, channels=3)
image = tf.image.convert_image_dtype(image, tf.float32)
image = tf.image.resize(image, [224, 224])
image = tf_rgb2lab(image)
return image
data2 = data.repeat().map(read_images, AUTOTUNE).batch(32)
train_ds = data2.map(split_for_feed, AUTOTUNE).prefetch(AUTOTUNE)
序列
class ImageGenerator(tf.keras.utils.Sequence):
def __init__(self, image_filenames, batch_size):
self.image_filenames = image_filenames
self.batch_size = batch_size
def __len__(self):
return math.ceil(len(self.image_filenames) / self.batch_size)
def __getitem__(self, idx):
batch = self.image_filenames[idx * self.batch_size : (idx + 1) * self.batch_size]
X_batch = []
y_batch = []
for file_name in batch:
file_name = 'datasets/imagenette/' + file_name
try:
color_image = transform.resize(io.imread(file_name),(224,224))
lab_image = color.rgb2lab(color_image)
L = lab_image[:,:,0]/100.
ab = lab_image[:,:,1:]/128.
X_batch.append(np.stack((L,L,L), axis=2))
y_batch.append(ab)
except:
pass
return np.array(X_batch), np.array(y_batch)
【问题讨论】:
标签: tensorflow keras deep-learning tensorflow-datasets tf.keras