【问题标题】:How to create a tf.data pipeline with multiple .npy files如何使用多个 .npy 文件创建 tf.data 管道
【发布时间】:2021-05-12 04:58:42
【问题描述】:

我已经研究过这个问题的其他问题,但找不到确切的答案,所以从头开始尝试:

问题

我有多个 .npy 文件(X_train 文件),每个文件都有一个形状数组 (n, 99, 2) - 只有第一个维度不同,其余两个相同。根据 .npy 文件的名称,我还可以获得相应的标签(y_train 文件)。

每一个这样的文件都可以轻松加载到内存中(多个文件也是如此),但不能一次全部加载。

我构建了一个生成器,它遍历文件列表并为训练批次聚合给定数量的文件:

def tf_data_generator(filelist, directory = [], batch_size = 5):
    i = 0
    x_t = os.listdir(directory[0])
    y_t = os.listdir(directory[1])
    while True:
        file_chunk = filelist[i*batch_size:(i+1)*batch_size] 
        X_a = []
        Y_a = []
        for fname in file_chunk:
            x_info = np.load(path_x_tr+fname)
            y_info = np.load(path_y_tr+fname)
            X_a.append(x_info)
            Y_a.append(y_info)
        X_a = np.concatenate(X_a)
        Y_a = np.concatenate(Y_a)
        yield X_a, Y_a
        i = i + 1

实际上(在 CPU 上)它工作正常,但是如果我尝试在 CUDA 上使用 GPU,它会崩溃,并给出 Failed to call ThenRnnForward with model config: 错误(请参阅:link

所以我正在尝试寻找另一种方法并使用 tf.data API 生成数据。但是,我被卡住了:

def parse_file(name):
    x = np.load('./data/x_train_m/'+name)
    y = np.load('./data/y_train_m/'+name)
    train_dataset = tf.data.Dataset.from_tensor_slices((test1, test2))
    return train_dataset

train_dataset = parse_file('example1.npy')
train_dataset = train_dataset.shuffle(100).batch(64)

model = wtte_rnn()
model.summary()
K.set_value(model.optimizer.lr, 0.01)
model.fit(train_dataset,
          epochs=10)

这很好用,但是我找不到方法:

  1. 混合多个文件(最多一定数量,比如说五个)
  2. 遍历整个文件列表

我已经阅读了 flat_map 和 interleave,但是,我无法更进一步,并且任何尝试使用它们都没有成功。如何制作与代码上部类似的生成器,但使用 tf.data API?

【问题讨论】:

    标签: python tensorflow keras tensorflow-datasets tf.data.dataset


    【解决方案1】:

    您可以尝试将它们连接起来,如下所示:

    train_dataset = parse_file('example1.npy') # initialize train dataset
    
    for file in files[1:]: # concatenate with the remaining files
        train_dataset = train_dataset.concatenate(parse_file(file))
    

    【讨论】:

    • 如果我可以将所有数据加载到内存中,这将是有意义的——但是,我不能一次全部完成。我在这里尝试做的想法是使用 tf.data API 生成数据并将其提供给 .fit 方法。
    • tf.data.Dataset 不会将所有内容都加载到内存中
    • 它可能不是,而是像答案中的那样循环,psutil显示内存使用量确实在增加。
    • 更多实验:只需运行上面显示的代码,释放内存的唯一方法就是实际删除 train_dataset,这首先违背了目的.. 不知道为什么会发生这种情况
    • 文档没有提到从 .npy tensorflow.org/tutorials/load_data/numpy 加载数据的迭代方式
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2014-08-31
    • 2011-05-19
    • 2011-01-18
    • 2018-07-31
    • 1970-01-01
    • 1970-01-01
    • 2021-09-11
    相关资源
    最近更新 更多