【发布时间】:2020-08-25 19:49:16
【问题描述】:
我有一个很长的数据集,我需要分块训练它。我读过可以多次调用model.fit,但最好使用model.train_on_batch。那是真实的?为什么?
【问题讨论】:
标签: python machine-learning keras neural-network conv-neural-network
我有一个很长的数据集,我需要分块训练它。我读过可以多次调用model.fit,但最好使用model.train_on_batch。那是真实的?为什么?
【问题讨论】:
标签: python machine-learning keras neural-network conv-neural-network
而不是使用 model.fit 多个。您可以在 tensorflow 中使用 make_csv_dataset 函数并将数据集传递给您的 fit 命令。 假设您的数据是 csv 格式。此功能的优点是它会在需要时加载数据,而不是将所有内容都加载到主内存中。
tf.data.experimental.make_csv_dataset(
file_pattern,
batch_size,
label_name=None,
select_columns=None,
shuffle=True,
)
这里的文件模式是单个字符串,即。如果要加载多个文件,则为文件名或字符串模式。见Documentation
如果您有图像数据集,则可以使用名为 flow_from_directory 的内容。这以类似的方式工作。它只加载处理所需的图像。
# this is preprocessing step where you define preprocessing on images.
train_datagen = ImageDataGenerator(
rescale=1./255,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True)
test_datagen = ImageDataGenerator(rescale=1./255)
# this is where you create iterator.
train_generator = train_datagen.flow_from_directory(
'data/train',
target_size=(150, 150),
batch_size=32,
class_mode='binary')
validation_generator = test_datagen.flow_from_directory(
'data/validation',
target_size=(150, 150),
batch_size=32,
class_mode='binary')
model.fit(
train_generator,
steps_per_epoch=2000,
epochs=50,
validation_data=validation_generator,
validation_steps=800)
【讨论】: