【发布时间】:2021-04-29 09:10:08
【问题描述】:
我正在使用 PyTorch 使用 REINFORCE 算法。我注意到我使用 Softmax 的简单网络的批量推理/预测总和不等于 1(甚至不接近 1)。我附上了最低限度的工作代码,以便您可以重现它。我在这里错过了什么?
import numpy as np
import torch
obs_size = 9
HIDDEN_SIZE = 9
n_actions = 2
np.random.seed(0)
model = torch.nn.Sequential(
torch.nn.Linear(obs_size, HIDDEN_SIZE),
torch.nn.ReLU(),
torch.nn.Linear(HIDDEN_SIZE, n_actions),
torch.nn.Softmax(dim=0)
)
state_transitions = np.random.rand(3, obs_size)
state_batch = torch.Tensor(state_transitions)
pred_batch = model(state_batch) # WRONG PREDICTIONS!
print('wrong predictions:\n', *pred_batch.detach().numpy())
# [0.34072137 0.34721774] [0.30972624 0.30191955] [0.3495524 0.3508627]
# DOES NOT SUM TO 1 !!!
pred_batch = [model(s).detach().numpy() for s in state_batch] # CORRECT PREDICTIONS
print('correct predictions:\n', *pred_batch)
# [0.5955179 0.40448207] [0.6574412 0.34255883] [0.624833 0.37516695]
# DOES SUM TO 1 AS EXPECTED
【问题讨论】:
标签: python pytorch inference softmax