【问题标题】:How to implement Batchnorm2d in Pytorch myself?如何自己在 Pytorch 中实现 Batchnorm2d?
【发布时间】:2021-01-29 12:14:00
【问题描述】:

我正在尝试通过以下方式实现 Batchnorm2d() 层:

class BatchNorm2d(nn.Module):

    def __init__(self, num_features):
        super(BatchNorm2d, self).__init__()
        self.num_features = num_features
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.eps = 1e-5
        self.momentum = 0.1
        self.first_run = True

    def forward(self, input):
        # input: [batch_size, num_feature_map, height, width]
        device = input.device
        if self.training:
            mean = torch.mean(input, dim=0, keepdim=True).to(device)  # [1, num_feature, height, width]
            var = torch.var(input, dim=0, unbiased=False, keepdim=True).to(device)  # [1, num_feature, height, width]
            if self.first_run:
                self.weight = Parameter(torch.randn(input.shape, dtype=torch.float32, device=device), requires_grad=True)
                self.bias = Parameter(torch.randn(input.shape, dtype=torch.float32, device=device), requires_grad=True)
                self.register_buffer('running_mean', torch.zeros(input.shape).to(input.device))
                self.register_buffer('running_var', torch.ones(input.shape).to(input.device))
                self.first_run = False
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
            bn_init = (input - mean) / torch.sqrt(var + self.eps)
        else:
            bn_init = (input - self.running_mean) / torch.sqrt(self.running_var + self.eps)
        return self.weight * bn_init + self.bias

但经过训练和测试后,我发现使用我的层的结果与使用nn.Batchnorm2d() 的结果无法比拟。肯定有问题,我猜这个问题与forward()中的初始化参数有关?我这样做是因为我不知道如何知道__init__() 中输入的形状,也许有更好的方法。我不知道如何解决它,请帮助。谢谢!!

【问题讨论】:

标签: python pytorch batch-normalization


【解决方案1】:

HERE得到答案!\
所以weight(bias)的形状是(1, num_features, 1, 1),而不是(1, num_features, width, height)。

【讨论】:

    【解决方案2】:

    如果有人偶然发现这个, 您实际上不必像上面那样在模型中设置“设备”。在模型之外,您可以这样做

    device = torch.device('cuda:0') model = model.to(设备)

    不确定这是否比手动设置模块内的权重和偏差设备更好,但我认为肯定更标准

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2021-08-30
      • 2023-03-14
      • 2021-04-09
      • 2020-11-09
      • 1970-01-01
      • 1970-01-01
      • 2021-06-03
      • 2020-12-14
      相关资源
      最近更新 更多