【Posted】: 2020-06-29 15:53:15
【Question】:
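I'm trying to save my GAN model. I haven't found much information about this online, and I'm quite confused about how to save the generator, the discriminator, and the GAN. I found this issue and wrote my code accordingly. However, even after saving and loading the models this way, I get the error "You must compile a model before training/testing." Note that the issue says to use tensorflow.keras instead of keras, but I don't really understand why.
How do I save and load the GAN so that I can deliberately stop training with a keyboard interrupt between epochs and resume it later?
Function that saves the GAN models: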
def save_model_to_file(gan, generator, discriminator, epoch):
    # Save the combined model with the discriminator marked non-trainable
    discriminator.trainable = False
    gan.save('facegan-gannet-epoch:%02d.h5' % epoch)
    discriminator.trainable = True
    # Save the two sub-models separately as well
    generator.save('facegan-generator-epoch:%02d.h5' % epoch)
    discriminator.save('facegan-discriminator-epoch:%02d.h5' % epoch)
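Resuming after an interrupt requires knowing which epoch was saved last. Below is a minimal sketch in plain Python, assuming all checkpoint files sit in one directory and follow exactly the naming pattern of save_model_to_file; latest_models_paths is a hypothetical helper, not part of Keras. It returns the newest epoch for which all three files exist, plus a dict with the same keys as the models_paths used for loading.

```python
import os
import re

# Matches names written by save_model_to_file, e.g.
# 'facegan-generator-epoch:07.h5' (pattern assumed from the question).
_NAME_RE = re.compile(r'^facegan-(gannet|generator|discriminator)-epoch:(\d+)\.h5$')


def latest_models_paths(directory='.'):
    """Return (epoch, models_paths) for the newest epoch that has all
    three files saved, or (0, None) if no complete checkpoint exists.

    Hypothetical helper for illustration only.
    """
    found = {}  # epoch -> {kind: path}
    for name in os.listdir(directory):
        match = _NAME_RE.match(name)
        if match:
            kind, epoch = match.group(1), int(match.group(2))
            found.setdefault(epoch, {})[kind] = os.path.join(directory, name)

    # Ignore epochs where saving was interrupted before all files were written
    complete = [e for e, paths in found.items() if len(paths) == 3]
    if not complete:
        return 0, None

    epoch = max(complete)
    paths = found[epoch]
    # Keys follow the question's models_paths dict ("gan", not "gannet")
    return epoch, {
        'gan': paths['gannet'],
        'generator': paths['generator'],
        'discriminator': paths['discriminator'],
    }
```

Training can then continue with `start_epoch, models_paths = latest_models_paths()` followed by `for epoch in range(start_epoch + 1, epochs + 1)`. Epochs are compared as integers rather than by sorting file names, so epoch 100 is correctly treated as newer than epoch 99. Note also that ':' is not a valid character in Windows file names, so this naming scheme only works on Linux or macOS.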
Loading the models this way (all parameters are trainable):
from keras.models import load_model

discriminator = load_model(models_paths["discriminator"])
discriminator.trainable = False
generator = load_model(models_paths["generator"])
gan = load_model(models_paths["gan"])
# Print the architecture and parameter counts of each loaded model
gan.summary()
discriminator.summary()
generator.summary()
Main training part:
import numpy as np
from tqdm import tqdm

generator = get_generator()
discriminator = get_discriminator()
gan = get_gan_network(discriminator, LATENT_DIM, generator, optimizer)
for epoch in range(1, epochs + 1):
    print('\n', '\t' * 3, '-' * 4, 'Epoch %d' % epoch, '-' * 4)
    for batch_count, image_batch in tqdm(enumerate(datagen)):
        # The data iterator may repeat indefinitely (ImageDataGenerator
        # iterators do), so stop after one pass over the data
        if batch_count == len(datagen):
            break
        # Sample a batch of random input noise
        noise = np.random.normal(0, 1, size=[BATCH_SIZE, LATENT_DIM])
        # Generate fake images
        generated_images = generator.predict(noise)
        # Real images first, then generated ones
        X = np.concatenate([image_batch, generated_images])
        # Labels for real and generated data (0 = fake)
        y_dis = np.zeros(2 * BATCH_SIZE)
        # One-sided label smoothing: real images get 0.9 instead of 1
        y_dis[:BATCH_SIZE] = 0.9
        # Train discriminator
        discriminator.trainable = True
        discriminator_loss = discriminator.train_on_batch(X, y_dis)
        # Train generator through the combined model, labelling fakes as real
        noise = np.random.normal(0, 1, size=[BATCH_SIZE, LATENT_DIM])
        y_gen = np.ones(BATCH_SIZE)
        discriminator.trainable = False
        gannet_loss = gan.train_on_batch(noise, y_gen)
    # Save all three models at the end of every epoch
    save_model_to_file(gan, generator, discriminator, epoch)
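The save-every-epoch loop can be wrapped so that Ctrl+C (or a notebook's interrupt button, which raises KeyboardInterrupt in the kernel) stops training cleanly. Below is a minimal sketch, assuming the batch loop is moved into a hypothetical train_one_epoch(epoch) callback and save_model_to_file is wrapped as a hypothetical save_checkpoint(epoch); run_resumable is likewise an illustrative name, not a Keras API.

```python
def run_resumable(train_one_epoch, save_checkpoint, start_epoch, epochs):
    """Run epochs start_epoch+1 .. epochs and return the last epoch
    whose checkpoint was fully saved.

    train_one_epoch(epoch) and save_checkpoint(epoch) are hypothetical
    callbacks standing in for the batch loop and save_model_to_file.
    """
    last_saved = start_epoch
    try:
        for epoch in range(start_epoch + 1, epochs + 1):
            train_one_epoch(epoch)
            save_checkpoint(epoch)
            # Only advance once the save call has returned
            last_saved = epoch
    except KeyboardInterrupt:
        # The half-finished epoch is discarded; a later run can
        # resume from last_saved + 1.
        print('Interrupted; last complete checkpoint is epoch %d' % last_saved)
    return last_saved
```

A later run passes the returned epoch back in as start_epoch after reloading the models. Because an interrupt can also arrive while save_checkpoint is still writing files, that epoch's files may be incomplete; when reloading, check that all three files for an epoch exist before using them.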
Full error message:
RuntimeError
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-29-740ccd719c22> in <module>
245
246 print("[INFO] Training started...")
--> 247 train(3, BATCH_SIZE, model_paths)
248 print("[INFO] Training completed.")
<ipython-input-29-740ccd719c22> in train(epochs, batch_size, models_paths)
225 # Train discriminator
226 discriminator.trainable = True
--> 227 discriminator_loss = discriminator.train_on_batch(X, y_dis)
228
229 # Train generator
/opt/conda/lib/python3.7/site-packages/keras/engine/training.py in train_on_batch(self, x, y, sample_weight, class_weight, reset_metrics)
1506 x, y,
1507 sample_weight=sample_weight,
-> 1508 class_weight=class_weight)
1509 if self._uses_dynamic_learning_phase():
1510 ins = x + y + sample_weights + [1]
/opt/conda/lib/python3.7/site-packages/keras/engine/training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, check_array_lengths, batch_size)
506 if y is not None:
507 if not self.optimizer:
--> 508 raise RuntimeError('You must compile a model before '
509 'training/testing. '
510 'Use `model.compile(optimizer, loss)`.')
RuntimeError: You must compile a model before training/testing. Use `model.compile(optimizer, loss)`.
Tags: python tensorflow keras generative-adversarial-network