【问题标题】:PyTorch flatten doesn't maintain batch sizePyTorch flatten 不保持批量大小
【发布时间】:2020-02-07 14:42:58
【问题描述】:

在 Keras 中,使用Flatten() 层会保留批量大小。例如,如果 Flatten 的输入形状是 (32, 100, 100),在 Keras 中,Flatten 的输出是 (32, 10000),但在 PyTorch 中是 320000。为什么会这样?

【问题讨论】:

    标签: python pytorch


    【解决方案1】:

    正如 OP 在他们的回答中已经指出的那样,张量操作不会默认考虑批量维度。您可以使用torch.flatten()Tensor.flatten()start_dim=1 在批量维度之后开始展平操作。

    或者,从 PyTorch 1.2.0 开始,您可以在模型中定义一个 nn.Flatten() 层,默认为 start_dim=1

    【讨论】:

    • 我在训练 GNN 时尝试了这两种方法,但我似乎遗漏了一些东西。在我的DataLoader 对象中,batch_size 与输入的第一维混合(在我的 GNN 中,从DataLoader 对象解压缩的输入大小为[batch_size*node_num, attribute_num]),所以如果我使用torch.flatten(),则样本是混合的一起,该网络将只有 1 个输出,而我期望 #batch_size 输出。如果我使用nn.Flatten(),似乎什么也没有发生,这一层的输出仍然是[batch_size*node_num, attribute_num]。我该如何处理?
    【解决方案2】:

    是的,正如this thread 中提到的,PyTorch 操作如 Flatten、view、reshape。

    一般来说,当使用像Conv2d 这样的模块时,你不需要担心批量大小。 PyTorch 负责处理它。但是当直接处理张量时,你需要注意批量大小。

    在 Keras 中,Flatten() 是一个层。但在 PyTorch 中,flatten() 是对张量的操作。因此,批量大小需要手动处理。

    【讨论】:

      猜你喜欢
      • 2021-07-31
      • 2019-02-03
      • 2020-11-08
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2019-02-24
      • 2020-05-20
      • 2020-09-21
      相关资源
      最近更新 更多