【问题标题】:PyTorch weights not updating after softmaxPyTorch 权重在 softmax 之后没有更新
【发布时间】:2020-12-17 23:36:05
【问题描述】:

我正在使用PyTorch 执行优化问题,即找到一组权重w,这样x (sum(w * x) / sum(w)) 的加权平均值可用于估计一些变量,例如y .

下面是我pytorch的“模型”,

dtype = torch.float
device = torch.device('cpu')

class WAvg(nn.Module):
    def __init__(self, p):
        super(WAvg, self).__init__()
        self.p = p
        self.q = nn.Parameter(torch.randn(self.p, 1, device=device, dtype=dtype))
        self.w = nn.functional.softmax(self.q, dim=0)
    def forward(self, x):
        w_avg = nn.functional.linear(x, self.w.T)
        return w_avg

训练代码,

x_tr = np.array([
    [1, 1, 1],
    [1, 4, 1],
    [2, 4, 6], 
    [1, 2, 3], 
    [4, 2, -3], 
    [2, 2, 2] 
])
y_tr = np.array([1, 2.1, 3.9, 2, 1.2, 1.8])

x_tr = torch.from_numpy(x_tr).float()
y_tr = torch.from_numpy(y_tr).float()


wa = WAvg(3)

criterion = nn.MSELoss()
optimizer = optim.Adam(wa.parameters(), lr=0.01)

for epoch in range(10):
    # Set running loss
    running_loss_tr = 0.0
    # zero the parameter gradients
    optimizer.zero_grad()
    # forward + backward + optimize
    y_pred_tr = wa(x_tr)
    loss_tr = criterion(y_pred_tr, y_tr)
    loss_tr.backward()
    optimizer.step()
    # print statistics
    print(epoch, loss_tr.item())

这会报错

RuntimeError: 试图第二次向后遍历图形,但缓冲区已被释放。第一次向后调用时指定retain_graph=True。

loss_tr.backward() 中添加了参数retain_graph=True(如此post 中所建议的那样),但参数qw 似乎没有更新。我认为问题应该是由softmax 引起的,它对权重进行了限制,使它们的总和为一,有什么解决办法吗?

输出:

0 1.305460810661316
1 1.305460810661316
2 1.305460810661316
3 1.305460810661316
4 1.305460810661316
5 1.305460810661316
6 1.305460810661316
7 1.305460810661316
8 1.305460810661316
9 1.305460810661316

【问题讨论】:

  • 您在初始化期间计算w。它不是叶节点,因此不会获得渐变,也不会更新。网络中唯一的叶节点是 q,但更改 q 不会对 w 产生任何影响,因为在转发期间您不会重新计算 w。应该起作用的一件事是将self.w = 行移动到forward 中的第一行。

标签: pytorch softmax weighted-average


【解决方案1】:
optimizer = optim.Adam(ma.parameters(), lr=0.01)

'ma'.parameters()....你不是说 wa 吗?

这就是问题所在。

它解释了为什么您没有更新正确的参数,并且由于您没有将正确的梯度归零,这表明您在同一个图表中反向传播了两次。

【讨论】:

  • 感谢您发现我的错字,但将ma 更改为wa 后参数仍未更新。
猜你喜欢
  • 1970-01-01
  • 2018-12-08
  • 2018-11-19
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2019-04-09
  • 2019-05-06
相关资源
最近更新 更多