【问题标题】:Why does my SRGAN (using PyTorch) result look similar to SRResNet results?为什么我的 SRGAN(使用 PyTorch)结果看起来与 SRResNet 结果相似?
【发布时间】:2018-10-26 03:49:05
【问题描述】:

SRGAN 是使用 PyTorch 实现的。

生成器预训练进行了 100 次,SRGAN 训练进行了 200 次。

代码是现有github代码的组合。

对于内容损失,使用 PyTorch 中的 MSELoss() 和 PyTorch 中的 BCELoss() 用于对抗性损失。

当我运行代码时,LossD 收敛到 0,而 LossG 在某个值附近振荡。所以我停止了训练,因为我认为它不再是训练了。

如果训练是论文中的1e5,结果会改变吗?还是损失函数的问题?

下面是 SRGAN 训练代码。

print('Adversarial training')
for epoch in range(NUM_EPOCHS):
    train_bar = tqdm(train_loader)
    running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 'd_score': 0, 'g_score': 0}
    # train_bar = tqdm(train_loader)
    for data, target in train_bar:
        batch_size = data.size(0)
        running_results['batch_sizes'] += batch_size

        target_real = Variable(torch.ones(batch_size, 1))
        target_fake = Variable(torch.zeros(batch_size, 1))

        if torch.cuda.is_available():
            target_real = target_real.cuda()
            target_fake = target_fake.cuda()

        real_img = Variable(target)
        z = Variable(data)

        # Generate real and fake inputs
        if torch.cuda.is_available():
            inputsD_real = real_img.cuda()
            inputsD_fake = netG(z.cuda())
        else:
            inputsD_real = real_img
            inputsD_fake = netG(z)

        ######### Train discriminator #########
        netD.zero_grad()

        # With real data
        outputs = netD(inputsD_real)
        D_real = outputs.data.mean()

        lossD_real = adversarial_criterion(outputs, target_real)

        # With fake data
        outputs = netD(inputsD_fake.detach()) # Don't need to compute gradients wrt weights of netG (for efficiency)
        D_fake = outputs.data.mean()

        lossD_fake = adversarial_criterion(outputs, target_fake)

        lossD_total = lossD_real + lossD_fake

        lossD_total.backward()

        # Update discriminator weights
        optimizerD.step()

        ######### Train generator #########
        netG.zero_grad()

        real_features = Variable(feature_extractor(inputsD_real).data)
        fake_features = feature_extractor(inputsD_fake)

        lossG_vgg19 = content_criterion(fake_features, real_features)
        lossG_adversarial = adversarial_criterion(netD(inputsD_fake).detach(), target_real)
        lossG_mse = content_criterion(inputsD_fake, inputsD_real)

        lossG_total = lossG_mse + 2e-6 * lossG_vgg19 + 0.001 * lossG_adversarial
        lossG_total.backward()

        # Update generator weights
        optimizerG.step()

【问题讨论】:

  • GANs 的训练可能相当容易,如果你使用不同的学习率或任何与论文中完全不同的东西,那么所有关于其性能的赌注都将落空。即使您使用与论文中相同的代码,结果仍然会因运行而异,具体取决于随机种子。

标签: python tensorflow neural-network deep-learning pytorch


【解决方案1】:

如果使用经典的GAN模型进行交替训练,这种情况在所难免,尝试改变训练方式

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2011-11-15
    • 1970-01-01
    • 2016-02-02
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2018-08-13
    相关资源
    最近更新 更多