请在下面找到一个可能的解决方案。
为了演示,我使用 Python 生成器而不是 TFRecords 作为输入(我假设您已经知道如何使用 TF Dataset 读取并解析每个文件夹中的文件;这一点在其他问题中已有介绍,例如这里链接的回答)。
import tensorflow as tf
import numpy as np
def get_class_generator(class_id, num_el, el_shape=(32, 32), el_dtype=np.int32):
    """Return a dummy generator factory for a single class.

    The returned callable is a generator function yielding ``num_el``
    pairs ``(input_array, class_id)``, where ``input_array`` has shape
    ``el_shape`` and is filled with the sample index (0, 1, 2, ...), so
    individual samples are distinguishable when inspecting batches.

    Args:
        class_id: Integer label yielded alongside every element.
        num_el: Number of elements the generator produces.
        el_shape: Shape of each dummy input array.
        el_dtype: NumPy dtype of each dummy input array.

    Returns:
        A zero-argument generator function.
    """
    def class_generator():
        # Note: the original had a dead `x = 0` here; the loop variable
        # supplies the value, so it was removed.
        for idx in range(num_el):
            yield np.ones(el_shape, dtype=el_dtype) * idx, class_id
    return class_generator
def concatenate_datasets(datasets):
    """Concatenate a non-empty sequence of tensors into one dataset.

    Despite the name, each element of ``datasets`` is a tensor (or tensor
    tuple, e.g. the ``*args`` received inside a ``flat_map``): each is
    wrapped into a single-element dataset via ``from_tensors`` and the
    results are chained in order.

    Snippet by user2781994 (https://stackoverflow.com/a/49069420/624547)

    Args:
        datasets: Non-empty sequence of tensors / tensor tuples.

    Returns:
        A ``tf.data.Dataset`` yielding the inputs in order.

    Raises:
        IndexError: If ``datasets`` is empty.
    """
    ds0 = tf.data.Dataset.from_tensors(datasets[0])
    for ds1 in datasets[1:]:
        ds0 = ds0.concatenate(tf.data.Dataset.from_tensors(ds1))
    return ds0
num_classes = 11
class_batch_size = 3
num_classes_per_batch = 5
# note: using 3 instead of 5 for class_batch_size in this example
# just to distinguish between the 2 vars.

# Initializing per-class datasets:
# (note: replace tf.data.Dataset.from_generator(...) to suit your use-case
#  e.g. tf.contrib.data.TFRecordDataset(glob.glob(perclass_tfrecords_path))
#       .map(your_parsing_function)
class_datasets = [tf.data.Dataset
                  .from_generator(get_class_generator(
                      class_id, num_el=np.random.randint(1, 60)
                      # ^ simulating unequal number of samples per class
                  ), (tf.int32, tf.int32), ([32, 32], []))
                  .repeat(-1)
                  .batch(class_batch_size)
                  for class_id in range(num_classes)]

# Initializing complete dataset:
dataset = (tf.data.Dataset
           # Concatenating all the class datasets together:
           .zip(tuple(class_datasets))
           .flat_map(lambda *args: concatenate_datasets(args))
           # Shuffling the class datasets:
           .shuffle(buffer_size=num_classes)
           # Flattening batches from shape (num_classes_per_batch, class_batch_size, ...)
           # into (num_classes_per_batch * class_batch_size, ...):
           .flat_map(lambda *args: tf.data.Dataset.from_tensor_slices(args))
           # Returning correct number of el. (num_classes_per_batch * class_batch_size):
           .batch(num_classes_per_batch * class_batch_size))

# Visualizing results (TF 1.x graph-mode session):
next_batch = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
    for i in range(10):
        batch = sess.run(next_batch)
        print(">> batch {}".format(i))
        print("- inputs shape: {} ; label shape: {}".format(batch[0].shape, batch[1].shape))
        print("- class values: {}".format(batch[1]))
输出:
>> batch 0
- inputs shape: (15, 32, 32) ; label shape: (15,)
- class values: [ 1 1 1 0 0 0 10 10 10 2 2 2 9 9 9]
>> batch 1
- inputs shape: (15, 32, 32) ; label shape: (15,)
- class values: [0 0 0 2 2 2 3 3 3 5 5 5 6 6 6]
>> batch 2
- inputs shape: (15, 32, 32) ; label shape: (15,)
- class values: [ 9 9 9 8 8 8 4 4 4 3 3 3 10 10 10]
>> batch 3
- inputs shape: (15, 32, 32) ; label shape: (15,)
- class values: [7 7 7 8 8 8 6 6 6 6 6 6 2 2 2]
>> batch 4
- inputs shape: (15, 32, 32) ; label shape: (15,)
- class values: [1 1 1 0 0 0 1 1 1 8 8 8 5 5 5]
>> batch 5
- inputs shape: (15, 32, 32) ; label shape: (15,)
- class values: [2 2 2 4 4 4 9 9 9 5 5 5 5 5 5]
>> batch 6
- inputs shape: (15, 32, 32) ; label shape: (15,)
- class values: [0 0 0 7 7 7 3 3 3 9 9 9 7 7 7]
>> batch 7
- inputs shape: (15, 32, 32) ; label shape: (15,)
- class values: [10 10 10 10 10 10 1 1 1 6 6 6 7 7 7]
>> batch 8
- inputs shape: (15, 32, 32) ; label shape: (15,)
- class values: [4 4 4 3 3 3 5 5 5 6 6 6 3 3 3]
>> batch 9
- inputs shape: (15, 32, 32) ; label shape: (15,)
- class values: [8 8 8 9 9 9 2 2 2 8 8 8 0 0 0]