【发布时间】:2019-03-16 04:58:37
【问题描述】:
我将每个数据点存储在一个 .npy 文件中,并带有shape=(1024,7,8)。我想通过类似于ImageDataGenerator 的方式将它们加载到 Keras 模型中,所以我编写并尝试了不同的自定义生成器,但它们都不起作用,这是我改编自 this 的一个
def find(dirpath, prefix=None, suffix=None, recursive=True):
"""Function to find recursively all files with specific prefix and suffix in a directory
Return a list of paths
"""
l = []
if not prefix:
prefix = ''
if not suffix:
suffix = ''
for (folders, subfolders, files) in os.walk(dirpath):
for filename in [f for f in files if f.startswith(prefix) and f.endswith(suffix)]:
l.append(os.path.join(folders, filename))
if not recursive:
break
l
return l
def generate_data(directory, batch_size):
i = 0
file_list = find(directory)
while True:
array_batch = []
for b in range(batch_size):
if i == len(file_list):
i = 0
random.shuffle(file_list)
sample = file_list[i]
i += 1
array = np.load(sample)
array_batch.append(array)
yield array_batch
我发现这缺少标签,所以它不适合使用 fit_generator 的模型。鉴于我可以将它们存储在一个 numpy 数组中,我如何将标签添加到此生成器中?
【问题讨论】:
-
while 循环何时停止在
while True:??