【问题标题】:How to `dot` weights to batch data in PyTorch?如何在 PyTorch 中为批处理数据添加“点”权重?
【发布时间】:2017-11-14 07:55:57
【问题描述】:

我有批量数据,想dot() 到数据。 W 是可训练的参数。 批量数据和权重之间如何打点?

hid_dim = 32
data = torch.randn(10, 2, 3, hid_dim)
data = data.view(10, 2*3, hid_dim)
W = torch.randn(hid_dim) # assume trainable parameters via nn.Parameter
result = torch.bmm(data, W).squeeze() # error, want (N, 6)
result = result.view(10, 2, 3)

更新

这个怎么样?

hid_dim = 32
data = torch.randn(10, 2, 3, hid_dim)
data = tdata.view(10, 2*3, hid_dim)
W = torch.randn(hid_dim, 1) # assume trainable parameters via nn.Parameter
W = W.unsqueeze(0).expand(10, hid_dim, 1)
result = torch.bmm(data, W).squeeze() # error, want (N, 6)
result = result.view(10, 2, 3)

【问题讨论】:

    标签: pytorch


    【解决方案1】:

    扩展W张量以匹配data张量的形状。以下应该可以工作。

    hid_dim = 32
    data = torch.randn(10, 2, 3, hid_dim)
    data = data.view(10, 2*3, hid_dim)
    W = torch.randn(hid_dim)
    W = W.unsqueeze(0).unsqueeze(0).expand(*data.size())
    result = torch.sum(data * W, 2)
    result = result.view(10, 2, 3)
    

    编辑:您更新的代码是正确的。由于您正在将W 转换为Bxhid_dimx1 并且您的数据的形状为Bxdxhid_dim,因此进行批量矩阵乘法将导致Bxdx1 这本质上是W 参数和所有行向量之间的点积data (dxhid_dim)。

    【讨论】:

    • 我更新了我的帖子。但是你的代码看起来比我好。我的代码也正确吗?
    猜你喜欢
    • 2020-03-14
    • 2020-10-09
    • 2021-05-28
    • 2022-01-23
    • 2021-05-05
    • 1970-01-01
    • 2020-11-05
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多