【问题标题】:PyTorch Softmax Output Doesn't Sum to 1PyTorch Softmax 输出总和不等于 1
【发布时间】:2020-02-25 04:40:15
【问题描述】:

交叉发帖my question from the PyTorch forum:

我开始在目标 Dirichlet 分布和我的模型的输出 Dirichlet 分布之间收到负 KL 散度。网上有人说,这可能说明狄利克雷分布的参数和不等于1。我觉得这很可笑,因为模型的输出是通过的

output = F.softmax(self.weights(x), dim=1)

但仔细研究后,我发现torch.all(torch.sum(output, dim=1) == 1.) 返回 False!查看有问题的行,我看到它是tensor([0.0085, 0.9052, 0.0863], grad_fn=<SelectBackward>)。但是torch.sum(output[5]) == 1. 产生tensor(False)

我对 softmax 有什么误解,以至于输出概率之和不等于 1?

这是 PyTorch 版本 1.2.0+cpu。完整模型复制如下:

import torch
import torch.nn as nn
import torch.nn.functional as F



def assert_no_nan_no_inf(x):
    assert not torch.isnan(x).any()
    assert not torch.isinf(x).any()


class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.weights = nn.Linear(
            in_features=2,
            out_features=3)

    def forward(self, x):
        output = F.softmax(self.weights(x), dim=1)
        assert torch.all(torch.sum(output, dim=1) == 1.)
        assert_no_nan_no_inf(x)
        return output

【问题讨论】:

  • 我在这里遗漏了什么吗? “有问题”的行总和为 1:sum([0.0085, 0.9052, 0.0863])
  • 也许你应该考虑一个能容忍小的精度误差的断言,例如:s = torch.sum(output, dim=1); torch.allclose(s, torch.ones_like(s))
  • 行。但是 Python + Pytorch 说该行没有。
  • 不幸的是,似乎容忍浮点错误可能会产生负的 KL 散度(如果密度不再等于 1)。
  • 我明白了。那么,我建议删除并发布一个具有不同视角的新内容。现在它具有误导性。

标签: pytorch softmax


【解决方案1】:

这很可能是由于有限精度导致的浮点数值错误。

您应该检查均方误差或在可接受的范围内,而不是检查严格的不等式。

例如:我得到 torch.norm(output.sum(dim=1)-1)/N 小于 1e-8。 N 是批量大小。

【讨论】:

  • 当然。我的问题是容忍浮点错误可能会产生负的 KL 散度(如果密度不再等于 1)。
  • 我不确定您对 KL 使用的确切公式是什么,但我只能说它可能在数值上更安全。
  • 那我应该告诉 PyTorch 的开发人员吗? :)
  • 我不确定这一点,因为我看不出总和为 1 是问题的原因。你会使用 \sum p_i log p_i/q_i 对吗?计算 KL 对吗?
猜你喜欢
  • 2020-03-19
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2014-11-17
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多