【发布时间】:2020-02-07 14:42:58
【问题描述】:
在 Keras 中,使用Flatten() 层会保留批量大小。例如,如果 Flatten 的输入形状是 (32, 100, 100),在 Keras 中,Flatten 的输出是 (32, 10000),但在 PyTorch 中是 320000。为什么会这样?
【问题讨论】:
在 Keras 中,使用Flatten() 层会保留批量大小。例如,如果 Flatten 的输入形状是 (32, 100, 100),在 Keras 中,Flatten 的输出是 (32, 10000),但在 PyTorch 中是 320000。为什么会这样?
【问题讨论】:
正如 OP 在他们的回答中已经指出的那样,张量操作不会默认考虑批量维度。您可以使用torch.flatten() 或Tensor.flatten() 和start_dim=1 在批量维度之后开始展平操作。
或者,从 PyTorch 1.2.0 开始,您可以在模型中定义一个 nn.Flatten() 层,默认为 start_dim=1。
【讨论】:
DataLoader 对象中,batch_size 与输入的第一维混合(在我的 GNN 中,从DataLoader 对象解压缩的输入大小为[batch_size*node_num, attribute_num]),所以如果我使用torch.flatten(),则样本是混合的一起,该网络将只有 1 个输出,而我期望 #batch_size 输出。如果我使用nn.Flatten(),似乎什么也没有发生,这一层的输出仍然是[batch_size*node_num, attribute_num]。我该如何处理?
是的,正如this thread 中提到的,PyTorch 操作如 Flatten、view、reshape。
一般来说,当使用像Conv2d 这样的模块时,你不需要担心批量大小。 PyTorch 负责处理它。但是当直接处理张量时,你需要注意批量大小。
在 Keras 中,Flatten() 是一个层。但在 PyTorch 中,flatten() 是对张量的操作。因此,批量大小需要手动处理。
【讨论】: