【发布时间】:2020-08-19 04:20:42
【问题描述】:
KL_loss = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())
或
def latent_loss(z_mean, z_stddev):
mean_sq = z_mean * z_mean
stddev_sq = z_stddev * z_stddev
return 0.5 * torch.mean(mean_sq + stddev_sq - torch.log(stddev_sq) - 1)
它们有什么关系?为什么代码中没有“tr”或“.transpose()”?
【问题讨论】:
-
@jodag 非常有帮助,谢谢
-
@jodag 关于torch.sum和torch.mean,你说“这可能意味着你需要不同的学习率”,但是KL损失并不是唯一的损失项,loss=kl_loss+recon_loss,这是否意味着损失实际上是具有不同权重的加权和?
-
是的,如果您使用均值而不是总和,则 kl_loss 分量的权重将隐式低于原始公式,这可能会影响损失函数的最佳点,并可能影响最终结果。
标签: pytorch autoencoder loss-function