【发布时间】:2020-04-19 02:12:54
【问题描述】:
我正在尝试在 pytorch 中进行批处理。在我下面的代码中,您可能会将x 视为批量大小为 2 的批次(每个样本是一个 10d 向量)。我使用x_sep 表示x 中的第一个样本。
import torch
import torch.nn as nn
class net(nn.Module):
def __init__(self):
super(net, self).__init__()
self.fc1 = nn.Linear(10,10)
def forward(self, x):
x = self.fc1(x)
return x
f = net()
x = torch.randn(2,10)
print(f(x[0])==f(x)[0])
理想情况下,f(x[0])==f(x)[0] 应该给出一个包含所有真实条目的张量。但是我电脑上的输出是
tensor([False, False, True, True, False, False, False, False, True, False])
为什么会这样?是计算错误吗?还是与pytorch中如何实现批处理有关?
更新:我稍微简化了代码。问题还是一样。
我的推理:
我相信f(x)[0]==f(x[0]) 应该有它的所有条目True,因为矩阵乘法是这样说的。让我们将x 视为一个2x10 矩阵,并将线性变换f() 视为由矩阵B 表示(暂时忽略偏差)。然后用我们的符号f(x)=xB。矩阵乘法告诉我们xB等于先将两行分别乘以右边的B,然后再将两行放回一起。翻译回代码,是f(x[0])==f(x)[0]和f(x[1])==f(x)[1]。
即使我们考虑偏差,每一行都应该有相同的偏差,并且相等性应该仍然成立。
另请注意,此处未进行任何培训。因此,如何初始化权重并不重要。
【问题讨论】:
-
你的最后一层返回接收 10 个特征向量并返回一个 10 个特征向量。那么问题是什么以及您在哪里使用批处理?
-
@Green 假设我将样本定义为 10 个分量的向量。那么你可以把这里的
x看作是一批2个样本,x_sep是x中的第一个样本。将线性变换应用于x,你得到y,一批大小为2。y[0]不应该等于f(x_sep)==y_sep吗?但我的结果告诉我没有,为什么? -
@Green 或者松散地说,为什么
f(x[0])==f(x)[0]不成立? -
我详细回答了你。希望你现在清楚了。
标签: python numpy pytorch batch-processing