【问题标题】:PyTorch broadcasting: how this worked? [duplicate]PyTorch 广播:这是如何工作的? [复制]
【发布时间】:2019-11-19 03:44:57
【问题描述】:

我是深度学习的新手。我在 Udacity 学习。

我遇到了一个构建神经网络的代码,其中添加了 2 个张量,特别是带有张量乘积输出的“偏差”张量。

有点……

def activation(x):
return (1/(1+torch.exp(-x)))

inputs = images.view(images.shape[0], -1)
w1 = torch.randn(784, 256)
b1 = torch.randn(256)
h = activation(torch.mm(inputs,w1) + b1)

在展平 MNIST 后,结果为 [64,784](输入)。

我不知道如何将维度 [256] 的偏置张量 (b1) 添加到“输入”和“w1”的乘积中,结果是 [256, 64] 的维度。

【问题讨论】:

标签: python machine-learning pytorch


【解决方案1】:

简单来说,每当我们使用 Python 库(Numpy 或 PyTorch)中的“广播”时,我们所做的就是处理我们的数组(权重、偏差)在维度上兼容。

换句话说,如果您使用形状为 [256,64] 的 W 进行操作,并且您的偏差仅为 [256]。然后,广播将完成那个缺乏的维度。

如上图所示,左边的维度正在被填充,以便我们的操作可以成功完成。希望对你有帮助

【讨论】:

  • 但我尝试添加大小为 [3,3] 和 [2,3] 的张量,但没有成功。它仅适用于单维,即行张量或列张量?
【解决方案2】:

这就是 PyTorch 广播。

It 与 NumPy 广播非常相似,如果您使用该库的话。 这是将标量添加到 2D 张量 m 的示例。

m = torch.rand(3,3)
print(m)
s=1
print(m+s)

# tensor([[0.2616, 0.4726, 0.1077],
#         [0.0097, 0.1070, 0.7539],
#         [0.9406, 0.1967, 0.1249]])
# tensor([[1.2616, 1.4726, 1.1077],
#         [1.0097, 1.1070, 1.7539],
#         [1.9406, 1.1967, 1.1249]])

这是另一个添加一维张量和二维张量的示例。

v = torch.rand(3)
print(v)
print(m+v)

# tensor([0.2346, 0.9966, 0.0266])
# tensor([[0.4962, 1.4691, 0.1343],
#         [0.2442, 1.1035, 0.7805],
#         [1.1752, 1.1932, 0.1514]])

我重写了你的例子:

def activation(x):
    return (1/(1+torch.exp(-x)))

images = torch.randn(3,28,28)
inputs = images.view(images.shape[0], -1)
print("INPUTS:", inputs.shape)

W1 = torch.randn(784, 256)
print("W1:", w1.shape)
B1 = torch.randn(256)
print("B1:", b1.shape)
h = activation(torch.mm(inputs,W1) + B1)

出来

INPUTS: torch.Size([3, 784])
W1: torch.Size([784, 256])
B1: torch.Size([256])

解释一下:

INPUTS: of size [3, 784] @ W1: of size [784, 256] 将创建大小为 [3, 256] 的张量

然后加法:

After mm: [3, 256] + B1: [256] 已完成,因为B1 将根据广播的形式形成[3, 256]

【讨论】:

    【解决方案3】:

    64 是您的批次大小,这意味着偏差张量将添加到您批次中的 64 个示例中的每一个中。基本上就像你取了 64 个大小为 256 的张量并为每个张量添加了偏差。 Pytorch 会自然地将 256 张量广播到 64*256 大小,可以添加到先例层的 64*256 输出中。

    【讨论】:

      猜你喜欢
      • 2018-12-24
      • 2017-04-02
      • 2023-04-04
      • 2020-05-01
      • 2018-04-22
      • 2014-12-04
      • 1970-01-01
      • 1970-01-01
      • 2017-10-27
      相关资源
      最近更新 更多