【Question Title】: TensorFlow GAN discriminator doesn't learn
【Posted】: 2023-04-09 11:36:01
【Question】:

I'm trying to build a basic GAN that learns to produce a simple 3 x 3 matrix containing a plus sign.

However, for some reason the discriminator loss does not change at all.

For example:

[[0.0,  0.98, 0.01],
 [0.95, 0.97, 0.99],
 [0.02, 0.99, 0.02]]

The code is below.

Generator and discriminator:

(The snippets assume the usual imports, which the question omits:)

import time
import random
import numpy as np
import tensorflow as tf
from tensorflow import keras


def make_generator():
    model = keras.Sequential()
    model.add(keras.layers.Dense(10, activation='relu', input_shape=(5, )))
    model.add(keras.layers.Dense(20, activation='relu'))
    model.add(keras.layers.Dense(9, activation='relu'))
    model.add(keras.layers.Reshape((3, 3)))
    return model


def make_discriminator():
    model = keras.Sequential()
    model.add(keras.layers.Dense(10, activation='relu', input_shape=[3, 3]))
    model.add(keras.layers.Dropout(0.2))
    model.add(keras.layers.Dense(20, activation='relu'))
    model.add(keras.layers.Dropout(0.2))
    model.add(keras.layers.Dense(9, activation='relu'))
    model.add(keras.layers.Dropout(0.2))
    model.add(keras.layers.Flatten())
    model.add(keras.layers.Dense(1, activation='softmax'))
    return model

generator = make_generator()
discriminator = make_discriminator()


I think the problem is in the training, but I'm not sure.

Training procedure:

generator_optimizer = tf.keras.optimizers.SGD(learning_rate = 0.1)
discriminator_optimizer = tf.keras.optimizers.SGD(learning_rate = 0.1)

cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)


def generator_loss(generated_im):
    loss = cross_entropy(tf.ones_like(generated_im), generated_im)
    return loss


def discriminator_loss(real_im_pred, generated_im_pred):
    loss_on_real = cross_entropy(tf.ones_like(real_im_pred), real_im_pred)
    loss_on_generated = cross_entropy(tf.zeros_like(generated_im_pred), generated_im_pred)
    loss = loss_on_generated + loss_on_real
    return loss

@tf.function
def train_step(images, batch_size):
        
    noise = tf.random.normal([batch_size, 5])
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)
        classification_on_real = discriminator(images, training=True)
        classification_on_fake = discriminator(generated_images, training=True)
        
        gen_loss = generator_loss(generated_images)
        disc_loss = discriminator_loss(classification_on_real, classification_on_fake)

        gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
        gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
    
    return gen_loss, disc_loss
        
def train(data, epochs, batch_size):
    for epoch in range(epochs):
        start = time.time()
        
        # Keep track of the total loss and accuracy
        total_gen_loss = 0
        total_disc_loss = 0
        
        for image_batch in data:
            gen_loss, disc_loss = train_step(image_batch, batch_size)
            
            total_gen_loss += gen_loss
            total_disc_loss += disc_loss
            
        print ('Time for epoch {} is {} sec, generator loss: {}, discriminator loss: {}'
               .format(epoch + 1, round(time.time()-start), round(float(total_gen_loss), 2), round(float(total_disc_loss), 2)))

The output I get when running this code is:

Time for epoch 1 is 3 sec, generator loss: 346.15, discriminator loss: 3252.97
Time for epoch 2 is 2 sec, generator loss: 308.61, discriminator loss: 3252.97
Time for epoch 3 is 2 sec, generator loss: 308.33, discriminator loss: 3252.97
Time for epoch 4 is 2 sec, generator loss: 308.24, discriminator loss: 3252.97
Time for epoch 5 is 2 sec, generator loss: 308.19, discriminator loss: 3252.97
Time for epoch 6 is 2 sec, generator loss: 308.16, discriminator loss: 3252.97
Time for epoch 7 is 2 sec, generator loss: 308.14, discriminator loss: 3252.97
Time for epoch 8 is 2 sec, generator loss: 308.13, discriminator loss: 3252.97
Time for epoch 9 is 2 sec, generator loss: 308.12, discriminator loss: 3252.97
Time for epoch 10 is 2 sec, generator loss: 308.11, discriminator loss: 3252.97
Time for epoch 11 is 2 sec, generator loss: 308.11, discriminator loss: 3252.97
Time for epoch 12 is 2 sec, generator loss: 308.11, discriminator loss: 3252.97
Time for epoch 13 is 2 sec, generator loss: 308.1, discriminator loss: 3252.97
Time for epoch 14 is 2 sec, generator loss: 308.1, discriminator loss: 3252.97
Time for epoch 15 is 2 sec, generator loss: 308.1, discriminator loss: 3252.97
Time for epoch 16 is 2 sec, generator loss: 308.1, discriminator loss: 3252.97
Time for epoch 17 is 2 sec, generator loss: 308.09, discriminator loss: 3252.97
Time for epoch 18 is 2 sec, generator loss: 308.09, discriminator loss: 3252.97
Time for epoch 19 is 2 sec, generator loss: 308.09, discriminator loss: 3252.97
Time for epoch 20 is 2 sec, generator loss: 308.09, discriminator loss: 3252.97

If you're interested in the code that generates the data, here it is:

def plus():
    array = np.array([[np.random.normal(0.05, 0.01, 1)[0], np.random.normal(0.95, 0.01, 1)[0], np.random.normal(0.05, 0.01, 1)[0]],
                     [np.random.normal(0.95, 0.01, 1)[0], np.random.normal(0.95, 0.01, 1)[0], np.random.normal(0.95, 0.01, 1)[0]],
                     [np.random.normal(0.05, 0.01, 1)[0], np.random.normal(0.95, 0.01, 1)[0], np.random.normal(0.05, 0.01, 1)[0]]])
    return array


def dataset(size):
    X = []
    
    for _ in range(size):
        x = plus()
        X.append(x)
    return np.array(X)


def get_batches(x, batch_size):
    batches = []
    for i in range(0, x.shape[0], batch_size):
        batch = x[i:i + batch_size]
        batches.append(batch)
    
    random.shuffle(batches)
    return np.array(batches)

BATCH_SIZE = 10
data = dataset(20000)
data = get_batches(data, BATCH_SIZE)

I hope you can help! Thanks a lot.

【Comments】:

  • The discriminator output needs to use a sigmoid activation, not softmax. Softmax over a single output unit is meaningless.

Tags: python tensorflow keras deep-learning generative-adversarial-network


【Solution 1】:

From the comments:

The activation function used on the discriminator's last layer should be sigmoid, not softmax, because the discriminator's final Dense layer has only 1 node/neuron/unit. (From xdurch0.)
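
The root cause can be demonstrated without TensorFlow: softmax normalizes a vector of logits so that it sums to 1, so over a single output unit it always returns exactly 1.0 regardless of the logit value. The discriminator's output is therefore constant, its gradients are zero, and its loss never moves, which matches the unchanging 3252.97 in the question's log. A minimal NumPy sketch (the `softmax` helper here is written from the standard definition, not taken from the question):

```python
import numpy as np

def softmax(logits):
    # Standard softmax: exp(x_i) / sum_j exp(x_j),
    # shifted by the max for numerical stability.
    e = np.exp(logits - np.max(logits))
    return e / e.sum()

# With a single output unit, softmax always returns 1.0,
# no matter what the logit is -- so the discriminator's
# "probability" is constant and carries no gradient.
for logit in [-5.0, 0.0, 3.7]:
    print(softmax(np.array([logit])))  # always [1.]
```

The fix is a single line: use `activation='sigmoid'` on the final `Dense(1)` layer and set `from_logits=False` in `BinaryCrossentropy`, or drop the activation entirely and keep `from_logits=True` as the question's loss already does.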

【Discussion】:

  • If you're going to lift content wholesale from the comments, you should mark your answer as Community Wiki and embed the copied answer in a blockquote. I'll do the latter, but the former has to be done by you.
  • @TylerH, marked the answer as Community Wiki. Thanks for the feedback.