【问题标题】:Batch inference of softmax does not sum to 1softmax 的批量推断总和不等于 1
【发布时间】:2021-04-29 09:10:08
【问题描述】:

我正在使用 PyTorch 使用 REINFORCE 算法。我注意到我使用 Softmax 的简单网络的批量推理/预测总和不等于 1(甚至不接近 1)。我附上了最低限度的工作代码,以便您可以重现它。我在这里错过了什么?

import numpy as np
import torch

obs_size = 9
HIDDEN_SIZE = 9
n_actions = 2

np.random.seed(0)

model = torch.nn.Sequential(
        torch.nn.Linear(obs_size, HIDDEN_SIZE),
        torch.nn.ReLU(),
        torch.nn.Linear(HIDDEN_SIZE, n_actions),
        torch.nn.Softmax(dim=0)
    )

state_transitions = np.random.rand(3, obs_size)

state_batch = torch.Tensor(state_transitions)
pred_batch = model(state_batch)  # WRONG PREDICTIONS!
print('wrong predictions:\n', *pred_batch.detach().numpy())
# [0.34072137 0.34721774] [0.30972624 0.30191955] [0.3495524 0.3508627]
# DOES NOT SUM TO 1 !!!

pred_batch = [model(s).detach().numpy() for s in state_batch]  # CORRECT PREDICTIONS
print('correct predictions:\n', *pred_batch)
# [0.5955179  0.40448207] [0.6574412  0.34255883] [0.624833   0.37516695]
# DOES SUM TO 1 AS EXPECTED

【问题讨论】:

    标签: python pytorch inference softmax


    【解决方案1】:

    虽然 PyTorch 让我们侥幸逃脱,但我们实际上并没有提供具有正确维度的输入。我们有一个接受一个输入并产生一个输出的模型,但 PyTorch nn.Module 及其子类旨在同时处理多个样本。为了容纳多个样本,模块期望输入的第零维是 batch 中的样本数。

    您的模型适用于每个单独的样本是一种实现方式。您错误地指定了 softmax 的维度(跨批次而不是跨变量),因此当给定批次维度时,它计算的是跨样本而不是样本内的 softmax:

    nn.Softmax 要求我们指定使用 softmax 函数的维度:

    softmax = nn.Softmax(dim=1)
    

    在这种情况下,我们在两行中有两个输入向量(就像我们使用 批次),所以我们初始化 nn.Softmax沿维度 1 操作。

    torch.nn.Softmax(dim=0) 更改为torch.nn.Softmax(dim=1) 以获得适当的结果。

    【讨论】:

      猜你喜欢
      • 2020-02-25
      • 2020-03-19
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2014-11-17
      • 2021-06-03
      • 1970-01-01
      • 2016-03-22
      相关资源
      最近更新 更多