【发布时间】:2018-02-07 06:45:38
【问题描述】:
我有一个庞大的数据集,我需要以生成器的形式提供给 Keras,因为它不适合内存。但是,使用fit_generator,我无法复制我在使用model.fit 进行常规训练时获得的结果。而且每个纪元的持续时间也相当长。
我实现了一个最小的示例。也许有人可以告诉我问题出在哪里。
import random
import numpy
from keras.layers import Dense
from keras.models import Sequential
random.seed(23465298)
numpy.random.seed(23465298)
no_features = 5
no_examples = 1000
def get_model():
network = Sequential()
network.add(Dense(8, input_dim=no_features, activation='relu'))
network.add(Dense(1, activation='sigmoid'))
network.compile(loss='binary_crossentropy', optimizer='adam')
return network
def get_data():
example_input = [[float(f_i == e_i % no_features) for f_i in range(no_features)] for e_i in range(no_examples)]
example_target = [[float(t_i % 2)] for t_i in range(no_examples)]
return example_input, example_target
def data_gen(all_inputs, all_targets, batch_size=10):
input_batch = numpy.zeros((batch_size, no_features))
target_batch = numpy.zeros((batch_size, 1))
while True:
for example_index, each_example in enumerate(zip(all_inputs, all_targets)):
each_input, each_target = each_example
wrapped = example_index % batch_size
input_batch[wrapped] = each_input
target_batch[wrapped] = each_target
if wrapped == batch_size - 1:
yield input_batch, target_batch
if __name__ == "__main__":
input_data, target_data = get_data()
g = data_gen(input_data, target_data, batch_size=10)
model = get_model()
model.fit(input_data, target_data, epochs=15, batch_size=10) # 15 * (1000 / 10) * 10
# model.fit_generator(g, no_examples // 10, epochs=15) # 15 * (1000 / 10) * 10
在我的电脑上,model.fit 总是以 0.6939 的损失完成第 10 个 epoch,并且在 ca 之后。 2-3 秒。
然而,model.fit_generator 方法运行的时间要长得多,并以 不同 损失 (0.6931) 结束最后一个 epoch。
我一般不明白为什么两种方法的结果不同。这可能看起来差别不大,但我需要确保具有相同网络的相同数据产生相同的结果,独立于常规训练或使用生成器。
更新:@Alex R. 为部分原始问题提供了答案(一些性能问题以及每次运行的结果变化)。然而,由于核心问题仍然存在,我只是相应地调整了问题和标题。
【问题讨论】:
-
我认为您在面向 Python 编程的网站上可能会更好。
-
您的训练数据集有多大?如果你在 fit 生成器中增加批量大小会发生什么?
-
@AlexR。我有大约 250 万个例子。如果我增加批量大小,损失仍然不稳定,并且与我使用
model.fit()得到的损失仍然不同。 -
@mdewey 如果你知道在没有 Python 的情况下使用 Keras 的方法,我期待听到它。
-
Also each epoch lasts considerably longer.其原因显然是与 I/O 操作相关的开销。它随领土而来。为了缩短它,您可能需要一个固态硬盘。