【Question Title】: Keras model asks for compiling after loading the model using `load_model`
【Posted】: 2020-06-29 15:53:15
【Question】:

I am trying to save my GAN model. I haven't found much information online, and I'm quite confused about how to save each piece — the generator, the discriminator, and the combined GAN. This is the issue I found, and I coded mine accordingly. But even after saving and loading the models this way, I get the error `You must compile a model before training/testing.` Note that the issue asks for `tensorflow.keras` rather than `keras`, which I don't fully understand.

How do I save and load a GAN so that I can deliberately interrupt training between epochs (e.g. with a keyboard interrupt) and resume it later?

Function that saves the GAN models:

def save_model_to_file(gan, generator, discriminator, epoch):
    discriminator.trainable = False
    gan.save('facegan-gannet-epoch:%02d.h5' % epoch)
    discriminator.trainable = True
    generator.save('facegan-generator-epoch:%02d.h5' % epoch)
    discriminator.save('facegan-discriminator-epoch:%02d.h5' % epoch)

The models are loaded this way (all parameters trainable):

discriminator = load_model(models_paths["discriminator"])
discriminator.trainable = False
generator = load_model(models_paths["generator"])
gan = load_model(models_paths["gan"])
gan.summary()
discriminator.summary()
generator.summary()
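
(One common workaround — a sketch, not from the original post — is to keep `load_model` but pass `compile=False` and then compile explicitly, so every model is guaranteed to have an optimizer before `train_on_batch` is called. The tiny stand-in architecture below is hypothetical; substitute your own `get_discriminator()` and checkpoint paths.)

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model

# Hypothetical stand-in for the real discriminator architecture
disc = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(4,)),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
disc.compile(optimizer="adam", loss="binary_crossentropy")
disc.save("discriminator.h5")

# Load without compiling, then compile explicitly:
# train_on_batch now always finds an optimizer.
disc2 = load_model("discriminator.h5", compile=False)
disc2.compile(optimizer="adam", loss="binary_crossentropy")

X = np.random.normal(size=(2, 4)).astype("float32")
y = np.array([0.9, 0.0], dtype="float32")
loss = disc2.train_on_batch(X, y)  # no RuntimeError
```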

Main training section:

    generator = get_generator()
    discriminator = get_discriminator()
    gan = get_gan_network(discriminator, LATENT_DIM, generator, optimizer)

    for epoch in range(1, epochs + 1):
        print('\n', '\t' * 3, '-' * 4, 'Epoch %d' % epoch, '-' * 4)

        for batch_count, image_batch in tqdm(enumerate(datagen)):

            if batch_count == len(datagen):  # len(datagen)
                break

            # Get a random set of input noise and images
            noise = np.random.normal(0, 1, size=[BATCH_SIZE, LATENT_DIM])

            # Generate fake images
            generated_images = generator.predict(noise)
            X = np.concatenate([image_batch, generated_images])

            # Labels for generated and real data
            y_dis = np.zeros(2 * BATCH_SIZE)
            # One-sided label smoothing
            y_dis[:BATCH_SIZE] = 0.9

            # Train discriminator
            discriminator.trainable = True
            discriminator_loss = discriminator.train_on_batch(X, y_dis)

            # Train generator
            noise = np.random.normal(0, 1, size=[BATCH_SIZE, LATENT_DIM])
            y_gen = np.ones(BATCH_SIZE)
            discriminator.trainable = False
            gannet_loss = gan.train_on_batch(noise, y_gen)

        save_model_to_file(gan, generator, discriminator, epoch)
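
(Since the stated goal is a deliberate keyboard interrupt between epochs, the loop above can be wrapped so that Ctrl-C still leaves a checkpoint on disk. A minimal, framework-free sketch; `run_epoch` and `save_checkpoint` are hypothetical stand-ins for the batch loop and `save_model_to_file` above.)

```python
def train_with_checkpoints(epochs, run_epoch, save_checkpoint):
    """Run training epoch by epoch; on Ctrl-C, save the last state and exit cleanly."""
    completed = 0
    try:
        for epoch in range(1, epochs + 1):
            run_epoch(epoch)        # e.g. the batch loop in the post
            save_checkpoint(epoch)  # e.g. save_model_to_file(...)
            completed = epoch
    except KeyboardInterrupt:
        save_checkpoint(completed)  # re-save the last finished epoch
    return completed

# Usage with stand-in callables:
saved = []
last = train_with_checkpoints(3, lambda e: None, lambda e: saved.append(e))
# last == 3, saved == [1, 2, 3]
```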

Full error message:

RuntimeError
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-29-740ccd719c22> in <module>
    245 
    246 print("[INFO] Training started...")
--> 247 train(3, BATCH_SIZE, model_paths)
    248 print("[INFO] Training completed.")

<ipython-input-29-740ccd719c22> in train(epochs, batch_size, models_paths)
    225             # Train discriminator
    226             discriminator.trainable = True
--> 227             discriminator_loss = discriminator.train_on_batch(X, y_dis)
    228 
    229             # Train generator

/opt/conda/lib/python3.7/site-packages/keras/engine/training.py in train_on_batch(self, x, y, sample_weight, class_weight, reset_metrics)
   1506             x, y,
   1507             sample_weight=sample_weight,
-> 1508             class_weight=class_weight)
   1509         if self._uses_dynamic_learning_phase():
   1510             ins = x + y + sample_weights + [1]

/opt/conda/lib/python3.7/site-packages/keras/engine/training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, check_array_lengths, batch_size)
    506         if y is not None:
    507             if not self.optimizer:
--> 508                 raise RuntimeError('You must compile a model before '
    509                                    'training/testing. '
    510                                    'Use `model.compile(optimizer, loss)`.')

RuntimeError: You must compile a model before training/testing. Use `model.compile(optimizer, loss)`.

【Question Discussion】:

Tags: python tensorflow keras generative-adversarial-network


    【Solution 1】:

    If you want to save and resume later, I recommend using `save_weights` and `load_weights` instead. That means you always create the models in code and compile them, then check whether a weights file exists; if it does, load the weights.

    Something like this:

    import os

    g_file_name = 'your_generator.h5'
    d_file_name = 'your_discriminator.h5'

    # Always build and compile first, then restore weights if a checkpoint exists
    (generator, discriminator) = create_your_models()
    compile_your_models(generator, discriminator)
    if os.path.isfile(g_file_name):
        generator.load_weights(g_file_name)
    if os.path.isfile(d_file_name):
        discriminator.load_weights(d_file_name)

    for epoch in range(num_epochs):
        gan_training()
        generator.save_weights(g_file_name)
        discriminator.save_weights(d_file_name)  # note: d_file_name, not g_file_name
    
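
    The save/load round trip above can be checked end-to-end. A minimal sketch assuming `tensorflow.keras`; the two-layer model is a hypothetical stand-in for the real generator:

```python
import numpy as np
import tensorflow as tf

def build_generator():
    # Stand-in generator: 8-dim noise -> 4-dim output
    return tf.keras.Sequential([
        tf.keras.layers.Input(shape=(8,)),
        tf.keras.layers.Dense(16, activation="relu"),
        tf.keras.layers.Dense(4, activation="tanh"),
    ])

g = build_generator()
g.compile(optimizer="adam", loss="mse")
g.save_weights("generator.weights.h5")

# Resume: rebuild the architecture in code, compile, then load weights.
# Because the model is compiled here, the RuntimeError cannot occur.
g2 = build_generator()
g2.compile(optimizer="adam", loss="mse")
g2.load_weights("generator.weights.h5")

noise = np.random.normal(0, 1, size=(2, 8)).astype("float32")
same = np.allclose(g(noise).numpy(), g2(noise).numpy())
```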

    【Discussion】:
