这就是 PyTorch 广播。
It 与 NumPy 广播非常相似,如果您使用该库的话。
这是将标量添加到 2D 张量 m 的示例。
m = torch.rand(3,3)
print(m)
s=1
print(m+s)
# tensor([[0.2616, 0.4726, 0.1077],
# [0.0097, 0.1070, 0.7539],
# [0.9406, 0.1967, 0.1249]])
# tensor([[1.2616, 1.4726, 1.1077],
# [1.0097, 1.1070, 1.7539],
# [1.9406, 1.1967, 1.1249]])
这是另一个添加一维张量和二维张量的示例。
v = torch.rand(3)
print(v)
print(m+v)
# tensor([0.2346, 0.9966, 0.0266])
# tensor([[0.4962, 1.4691, 0.1343],
# [0.2442, 1.1035, 0.7805],
# [1.1752, 1.1932, 0.1514]])
我重写了你的例子:
def activation(x):
return (1/(1+torch.exp(-x)))
images = torch.randn(3,28,28)
inputs = images.view(images.shape[0], -1)
print("INPUTS:", inputs.shape)
W1 = torch.randn(784, 256)
print("W1:", w1.shape)
B1 = torch.randn(256)
print("B1:", b1.shape)
h = activation(torch.mm(inputs,W1) + B1)
出来
INPUTS: torch.Size([3, 784])
W1: torch.Size([784, 256])
B1: torch.Size([256])
解释一下:
INPUTS: of size [3, 784] @ W1: of size [784, 256] 将创建大小为 [3, 256] 的张量
然后加法:
After mm: [3, 256] + B1: [256] 已完成,因为B1 将根据广播的形式形成[3, 256]。