【发布时间】:2020-02-25 04:40:15
【问题描述】:
交叉发帖my question from the PyTorch forum:
我开始在目标 Dirichlet 分布和我的模型的输出 Dirichlet 分布之间收到负 KL 散度。网上有人说,这可能说明狄利克雷分布的参数和不等于1。我觉得这很可笑,因为模型的输出是通过的
output = F.softmax(self.weights(x), dim=1)
但仔细研究后,我发现torch.all(torch.sum(output, dim=1) == 1.) 返回 False!查看有问题的行,我看到它是tensor([0.0085, 0.9052, 0.0863], grad_fn=<SelectBackward>)。但是torch.sum(output[5]) == 1. 产生tensor(False)。
我对 softmax 有什么误解,以至于输出概率之和不等于 1?
这是 PyTorch 版本 1.2.0+cpu。完整模型复制如下:
import torch
import torch.nn as nn
import torch.nn.functional as F
def assert_no_nan_no_inf(x):
assert not torch.isnan(x).any()
assert not torch.isinf(x).any()
class Network(nn.Module):
def __init__(self):
super().__init__()
self.weights = nn.Linear(
in_features=2,
out_features=3)
def forward(self, x):
output = F.softmax(self.weights(x), dim=1)
assert torch.all(torch.sum(output, dim=1) == 1.)
assert_no_nan_no_inf(x)
return output
【问题讨论】:
-
我在这里遗漏了什么吗? “有问题”的行总和为 1:
sum([0.0085, 0.9052, 0.0863]) -
也许你应该考虑一个能容忍小的精度误差的断言,例如:
s = torch.sum(output, dim=1); torch.allclose(s, torch.ones_like(s)) -
行。但是 Python + Pytorch 说该行没有。
-
不幸的是,似乎容忍浮点错误可能会产生负的 KL 散度(如果密度不再等于 1)。
-
我明白了。那么,我建议删除并发布一个具有不同视角的新内容。现在它具有误导性。