【问题标题】:Tensorflow, negative KL DivergenceTensorflow,负KL散度
【发布时间】:2018-08-10 14:15:40
【问题描述】:

我正在使用变分自动编码器类型模型,我的损失函数的一部分是均值 0 和方差 1 的正态分布与另一个均值和方差由我的模型预测的正态分布之间的 KL 散度。

我是这样定义损失的:

def kl_loss(mean, log_sigma):
    normal=tf.contrib.distributions.MultivariateNormalDiag(tf.zeros(mean.get_shape()),
                                                           tf.ones(log_sigma.get_shape()))
    enc_normal = tf.contrib.distributions.MultivariateNormalDiag(mean,
                                                                     tf.exp(log_sigma),
                                                                     validate_args=True,
                                                                     allow_nan_stats=False,
                                                                     name="encoder_normal")
    kl_div = tf.contrib.distributions.kl_divergence(normal,
                                                    enc_normal,
                                                    allow_nan_stats=False,
                                                    name="kl_divergence")
return kl_div

输入是长度为 N 的无约束向量

log_sigma.get_shape() == mean.get_shape()

现在在训练期间,我观察到在几千次迭代后出现负的 KL 散度,直到值 -10。下面你可以看到 Tensorboard 的训练曲线:

KL divergence curve

Zoom in of KL divergence curve

现在这对我来说似乎很奇怪,因为 KL 背离在某些条件下应该是正的。我知道我们需要“仅当 P 和 Q 都为 1 并且对于任何 i 使得 P(i) > 0 的 Q(i) > 0 时才定义 K-L 散度。” (请参阅https://mathoverflow.net/questions/43849/how-to-ensure-the-non-negativity-of-kullback-leibler-divergence-kld-metric-rela)但我看不出在我的情况下如何违反这一点。非常感谢任何帮助!

【问题讨论】:

  • 你最后一层的激活函数是什么?
  • 最后一层是具有线性(无)激活函数 (tensorflow.org/api_docs/python/tf/layers/conv3d) 和内核大小为 1 的 3D 卷积层。我将生成的张量展平,前半部分是我的平均值,第二部分一半 log_sigma。
  • 那么最后一层的输出可以大于1吗?
  • 嗯,是的,我知道。从代码 sn-p 和我的帖子中可以清楚地看出,我正在初始化两个具有对角协方差矩阵的多元正态分布。这些分布之一的均值和方差由我的网络输出决定。其中唯一的限制是 sigma 是正数。这是通过采用 exp(log-sigma) 来确保的。那你问的是什么?
  • 它与输入无关,但您显然没有阅读我正在写的内容。上面的函数应该为 ANY 输入返回一个正的 KL 散度。我怀疑这是一个数值问题,与此函数在 Tensorflow 中实现的方式有关。我只是想听听偶然发现同一问题的人对此的另一种看法。

标签: python tensorflow machine-learning statistics distribution


【解决方案1】:

面临同样的问题。 这是因为使用了浮点精度。 如果您注意到负值接近 0 并且限制为一个小的负值。为损失添加一个小的正值是一种解决方法。

【讨论】:

  • 嗨!请更详细地回答您的问题并提供解决方案,或者您可以将此答案移至评论部分。
猜你喜欢
  • 2017-11-02
  • 1970-01-01
  • 2017-06-11
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2019-04-01
  • 2017-10-03
  • 2018-09-27
相关资源
最近更新 更多