【问题标题】:fit_generator in keras: where is the batch_size specified?keras中的fit_generator:batch_size在哪里指定?
【发布时间】:2017-05-04 10:21:15
【问题描述】:

您好,我不了解 keras fit_generator 文档。

我希望我的困惑是理性的。

batch_size,还有分批训练的概念。使用model_fit(),我指定batch_size 为128。

对我来说,这意味着我的数据集将一次输入 128 个样本,从而大大减轻了内存。只要我有时间等待,它就应该允许训练 1 亿个样本数据集。毕竟,keras 一次只能“处理”128 个样本。对吧?

但我高度怀疑单独指定 batch_size 并不能满足我的任何要求。大量内存仍在使用中。为了我的目标,我需要分批训练 128 个示例。

所以我猜这就是fit_generator 所做的。我真的很想问为什么batch_size 不像它的名字所暗示的那样实际工作?

更重要的是,如果需要fit_generator,我在哪里指定batch_size?文档说无限循环。 生成器对每一行循环一次。如何一次循环超过 128 个样本并记住我上次停止的位置并在下次 keras 要求下一批的起始行号时回忆它(在第一批完成后将是第 129 行)。

【问题讨论】:

    标签: tensorflow keras


    【解决方案1】:

    您需要在生成器内部以某种方式处理批量大小。这是一个生成随机批次的示例:

    import numpy as np
    data = np.arange(100)
    data_lab = data%2
    wholeData = np.array([data, data_lab])
    wholeData = wholeData.T
    
    def data_generator(all_data, batch_size = 20):
    
        while True:        
    
            idx = np.random.randint(len(all_data), size=batch_size)
    
            # Assuming the last column contains labels
            batch_x = all_data[idx, :-1]
            batch_y = all_data[idx, -1]
    
            # Return a tuple of (Xs,Ys) to feed the model
            yield(batch_x, batch_y)
    
    print([x for x in data_generator(wholeData)])
    

    【讨论】:

      【解决方案2】:

      首先,keras batch_size 确实工作得很好。如果您正在使用 GPU,您应该知道使用 keras 的模型可能非常繁重,尤其是在使用循环单元时。如果你在 CPU 上工作,整个程序都加载到内存中,批量大小不会对内存产生太大影响。如果您使用fit(),则整个数据集可能已加载到内存中,keras 在每一步都会生成批次。很难预测将要使用的内存量。

      对于fit_generator() 方法,您应该构建一个python 生成器函数(使用yield 而不是return),每一步产生一个批次。 yield 应该处于无限循环中(我们经常使用while true: ...)。

      你有一些代码来说明你的问题吗?

      【讨论】:

        猜你喜欢
        • 2019-05-08
        • 2020-03-23
        • 1970-01-01
        • 1970-01-01
        • 2012-07-10
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        相关资源
        最近更新 更多