【问题标题】:In GAN, is it necessary to compile generator在GAN中,是否需要编译生成器
【发布时间】:2020-09-29 21:17:37
【问题描述】:

我一直在研究 GAN,让我摸不着头脑的是为什么我们必须编译生成器模型,即使我们编译组合的 GAN 模型,为什么还要单独编译生成器。

def create_generator():
    generator = Sequential()

    generator.add(Dense(256, input_dim=noise_dim))
    generator.add(LeakyReLU(0.2))

    generator.add(Dense(512))
    generator.add(LeakyReLU(0.2))

    generator.add(Dense(1024))
    generator.add(LeakyReLU(0.2))

    generator.add(Dense(img_rows*img_cols*channels, activation='tanh'))

    generator.compile(loss='binary_crossentropy', optimizer=optimizer)
    return generator

def create_descriminator():
    discriminator = Sequential()

    discriminator.add(Dense(1024, input_dim=img_rows*img_cols*channels))
    discriminator.add(LeakyReLU(0.2))

    discriminator.add(Dense(512))
    discriminator.add(LeakyReLU(0.2))

    discriminator.add(Dense(256))
    discriminator.add(LeakyReLU(0.2))

    discriminator.add(Dense(1, activation='sigmoid'))

    discriminator.compile(loss='binary_crossentropy', optimizer=optimizer)
    return discriminator

discriminator = create_descriminator()
generator = create_generator()

# Make the discriminator untrainable when we are training the generator.  This doesn't effect the discriminator by itself
discriminator.trainable = False

# Link the two models to create the GAN
gan_input = Input(shape=(noise_dim,))
fake_image = generator(gan_input)

gan_output = discriminator(fake_image)

gan = Model(gan_input, gan_output)
gan.compile(loss='binary_crossentropy', optimizer=optimizer)

在这段代码中,我们可以看到生成器、判别器和 gan(组合模型)这三个都被编译了。根据我的理解,我们应该只编译鉴别器(训练鉴别器)和 gan(组合模型,训练生成器),因为鉴别器的权重在 GAN 训练期间被冻结,结果只有生成器得到训练。那为什么要编译生成器

【问题讨论】:

  • 确实不需要编译G。在D 训练期间,您只需使用generator.predict()D 提供假样本。确保在 gan 更新期间冻结了 D 权重。
  • 感谢@Slowpoke 的回答,您知道其他人编译生成器的任何可能原因吗,即使我们不需要它。
  • @IrfanDanish 如果你不编译生成器,你会收到来自 TensorFlow 的大量警告。这可能只是摆脱这些警告的一种方式。

标签: python-3.x tensorflow keras deep-learning generative-adversarial-network


【解决方案1】:

在训练期间,generatordiscriminator 有相反的目标: discriminator 试图区分假图像和真实图像,而 生成器尝试生成看起来足够真实的图像来欺骗 鉴别器。
因为 GAN 由两个具有不同目标的网络组成,所以不能像常规神经网络那样进行训练。 每次训练迭代分为两个阶段:

  • 在第一阶段,我们训练判别器。一批真货 图像是从训练集中采样的,并用 生成器生成的假图像数量相等。标签是 假图像设置为 0,真实图像设置为 1,判别器 在这个标记的批次上训练一步,使用二进制 交叉熵损失。重要的是,反向传播只会优化 此阶段鉴别器的权重。
  • 在第二阶段,我们训练生成器。我们首先用它来 产生另一批假图像,再一次 鉴别器用于判断图像是假的还是真实的。 这次我们不批量添加真实图片,所有的标签 设置为 1(实数):换句话说,我们希望生成器产生 鉴别器将(错误地)认为是真实的图像! 至关重要的是,discriminator 的权重在此期间为 frozen 步,所以反向传播只影响生成器的权重。

接下来,我们需要编译这些模型。这 generator 只会通过gan model 训练,所以我们不需要 完全编译它。重要的是,discriminator 不应该 在第二阶段进行训练,所以我们之前将其设为non-trainable compiling甘模型:

【讨论】:

    猜你喜欢
    • 2016-09-23
    • 2018-08-17
    • 2013-09-14
    • 1970-01-01
    • 2011-07-18
    • 2012-01-05
    • 1970-01-01
    • 2014-10-31
    • 1970-01-01
    相关资源
    最近更新 更多