【问题标题】:Group Normalization and Weight Standardization in Keras（Keras 中的组归一化和权重标准化）
【发布时间】:2021-05-24 02:24:45
【问题描述】:

按照原始论文 https://arxiv.org/pdf/1903.10520v1.pdf ，我正在 TensorFlow 中用 Keras 为 ResNet-50 实现权重标准化（Weight Standardization）和组归一化（Group Normalization）。

权重标准化在所有卷积层上似乎都能正常工作，但在某些 Conv2D 层之后使用组归一化时似乎会出问题：在这些位置，损失不再下降，准确率也不再提高。我在 CIFAR-10 上用 model.fit 训练返回的模型，批量大小试过 16 到 512。代码中已用 ###ISSUE 标出导致问题的组归一化层。有人能解释为什么会出现这个问题吗？

!pip install -q -U tensorflow-addons

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.layers import Conv2D, Activation, BatchNormalization, Dense, Add, Concatenate
from tensorflow.keras.layers import ZeroPadding2D, Input, MaxPooling2D, AveragePooling2D, Flatten
from tensorflow.keras import activations
from tensorflow.keras import Model
from tensorflow_addons.layers import GroupNormalization

def ws_reg(kernel):
    """Compute a weight-standardized copy of ``kernel``, then throw it away.

    The math mirrors Weight Standardization: subtract the mean and divide by
    the standard deviation taken over axes 0, 1 and 2 (presumably per output
    channel, assuming Keras's (kh, kw, in_channels, out_channels) kernel
    layout -- TODO confirm). However, the original ``return kernel`` was
    commented out, so this function always returns ``None``.

    It is passed to ``Conv2D`` as ``kernel_regularizer``. A regularizer's
    output is only added to the loss and is never applied to the kernel, so
    this cannot standardize the weights the convolution uses.
    NOTE(review): returning None presumably adds no penalty at all. This
    relies on Keras ignoring a None regularization loss; confirm this for the
    TF version in use.
    """
    reduce_axes = [0, 1, 2]
    centered = kernel - tf.math.reduce_mean(
        kernel, axis=reduce_axes, keepdims=True, name='kernel_mean')
    spread = tf.keras.backend.std(centered, axis=reduce_axes, keepdims=True)
    _standardized = centered / (spread + 1e-5)  # computed but never used
    return None
    
def res_identity(x, filters):
  """Bottleneck residual block with an identity shortcut (shape preserved).

  The residual branch is 1x1 conv -> 3x3 conv -> 1x1 conv, all stride 1.
  Each conv is followed by GroupNormalization (16 groups). ReLU follows the
  first two normalizations, and the final ReLU is applied only after the
  unchanged input has been added back.

  Args:
    x: input tensor. For the Add to succeed, its channel count must equal
      ``filters[1]``; this is not checked here.
    filters: pair ``(f1, f2)``. ``f1`` is the width of the first two convs;
      ``f2`` is the width of the last conv.

  Returns:
    The block's output tensor.
  """
  shortcut = x
  bottleneck_channels, out_channels = filters

  def conv_gn(tensor, channels, size, padding):
    # Stride-1 Conv2D (with the ws_reg kernel regularizer), then GroupNorm.
    tensor = Conv2D(channels, kernel_size=size, strides=(1, 1),
                    padding=padding, kernel_regularizer=ws_reg)(tensor)
    return GroupNormalization(groups=16, axis=-1)(tensor)

  y = conv_gn(x, bottleneck_channels, (1, 1), 'valid')
  y = Activation(activations.relu)(y)

  # 'same' padding keeps the spatial size through the 3x3 bottleneck conv.
  y = conv_gn(y, bottleneck_channels, (3, 3), 'same')
  y = Activation(activations.relu)(y)

  # No ReLU here: the activation is applied after the skip addition.
  y = conv_gn(y, out_channels, (1, 1), 'valid')  ###ISSUE (GroupNormalization in conv_gn)

  y = Add()([y, shortcut])
  return Activation(activations.relu)(y)


def res_conv(x, s, filters):
  """Bottleneck residual block with a projection (1x1 conv) shortcut.

  Used where the tensor shape changes. The first 1x1 conv of the branch and
  the shortcut 1x1 conv both use stride ``s`` (``s=2`` halves the spatial
  size). The shortcut is projected to ``filters[1]`` channels so it can be
  added to the branch output.

  Args:
    x: input tensor.
    s: integer stride applied to both downsampling 1x1 convs.
    filters: pair ``(f1, f2)``. ``f1`` is the width of the first two branch
      convs; ``f2`` is the width of the last branch conv and the shortcut.

  Returns:
    The block's output tensor.
  """
  shortcut = x
  bottleneck_channels, out_channels = filters

  def conv_gn(tensor, channels, size, stride, padding):
    # Conv2D (with the ws_reg kernel regularizer), then GroupNorm.
    tensor = Conv2D(channels, kernel_size=size, strides=stride,
                    padding=padding, kernel_regularizer=ws_reg)(tensor)
    return GroupNormalization(groups=16, axis=-1)(tensor)

  # Residual branch; only its first conv downsamples.
  y = conv_gn(x, bottleneck_channels, (1, 1), (s, s), 'valid')
  y = Activation(activations.relu)(y)
  y = conv_gn(y, bottleneck_channels, (3, 3), (1, 1), 'same')
  y = Activation(activations.relu)(y)
  y = conv_gn(y, out_channels, (1, 1), (1, 1), 'valid')  ###ISSUE (GroupNormalization in conv_gn)

  # Projection shortcut: same stride as the branch, f2 output channels.
  shortcut = conv_gn(shortcut, out_channels, (1, 1), (s, s), 'valid')  ###ISSUE (GroupNormalization in conv_gn)

  y = Add()([y, shortcut])
  return Activation(activations.relu)(y)

  
def resnet50(train_im):
  """Build a ResNet-50-style classifier using GroupNorm and 10 softmax outputs.

  Args:
    train_im: indexable whose entries 0, 1 and 2 give the input
      (height, width, channels). It is presumably the CIFAR-10 image shape,
      e.g. ``x_train.shape[1:]``; TODO confirm with the caller.

  Returns:
    An uncompiled ``Model`` named 'Resnet50'.
  """
  height, width, channels = train_im[0], train_im[1], train_im[2]
  input_im = Input(shape=(height, width, channels))

  # Stem: pad by 3, 7x7 stride-2 conv, GroupNorm, ReLU, 3x3 stride-2 max-pool.
  x = ZeroPadding2D(padding=(3, 3))(input_im)
  x = Conv2D(64, kernel_size=(7, 7), strides=(2, 2), kernel_regularizer=ws_reg)(x)
  x = GroupNormalization(groups=16, axis=-1)(x) ###ISSUE
  x = Activation(activations.relu)(x)
  x = MaxPooling2D((3, 3), strides=(2, 2))(x)

  # Stages 2-5. Each entry is (stride of the res_conv block,
  # (bottleneck, output) channels, number of res_identity blocks after it).
  stage_plan = (
      (1, (64, 256), 2),
      (2, (128, 512), 3),
      (2, (256, 1024), 5),
      (2, (512, 2048), 2),
  )
  for stride, stage_filters, n_identity in stage_plan:
    x = res_conv(x, s=stride, filters=stage_filters)
    for _ in range(n_identity):
      x = res_identity(x, filters=stage_filters)

  # Head: average-pool, flatten, 10-way softmax.
  x = AveragePooling2D((2, 2), padding='same')(x)
  x = Flatten()(x)
  outputs = Dense(10, activation='softmax', kernel_initializer='he_normal')(x)

  return Model(inputs=input_im, outputs=outputs, name='Resnet50')

【问题讨论】:

    标签: python python-3.x tensorflow keras normalization


    【解决方案1】:

    在我看来，把 ws_reg 函数作为 Conv2D 的 kernel_regularizer 传入，并不是实现权重标准化的正确方式。

    根据 TF 和 Keras 文档，kernel_regularizer 的输出只会被加到损失（loss）中，而不会作用于卷积核本身。（此外，问题中 ws_reg 的 return 语句被注释掉了，它实际返回的是 None，标准化后的核被直接丢弃。）正确的做法可能是这样的：

    class WSConv2D(tf.keras.layers.Conv2D):
        """Conv2D whose kernel is (scaled) weight-standardized on every call.

        The trainable ``kernel`` variable keeps the raw weights. Each forward
        pass standardizes a *copy* of it and convolves with that copy, so
        gradients flow through the standardization and the optimizer keeps
        updating the raw weights. A trainable ``gain`` with one entry per
        output channel (initialized to ones) rescales the standardized kernel.
        """

        def __init__(self, *args, **kwargs):
            # Default to He-normal init, but let callers override it. Passing
            # kernel_initializer explicitly used to raise a duplicate-keyword
            # TypeError.
            kwargs.setdefault("kernel_initializer", "he_normal")
            super(WSConv2D, self).__init__(*args, **kwargs)

        def build(self, input_shape):
            super(WSConv2D, self).build(input_shape)
            # Create the gain once here, not inside call(). Calling add_weight
            # on every forward pass would create a fresh, never-trained
            # variable each time.
            self.gain = self.add_weight(
                name="gain",
                shape=(self.filters,),
                initializer="ones",
                trainable=True,
                dtype=self.dtype,
            )

        def standardize_weight(self, weight, eps):
            """Return a standardized copy of ``weight`` (the input is not modified).

            Computes ``(weight - mean) * gain / sqrt(max(var * fan_in, eps))``.
            Here ``mean`` and ``var`` are reduced over axes 0, 1 and 2, i.e. per
            last-axis slice (the output channel in Keras's Conv2D kernel
            layout), and ``fan_in`` is the product of all but the last kernel
            dimension.
            """
            mean = tf.math.reduce_mean(weight, axis=(0, 1, 2), keepdims=True)
            var = tf.math.reduce_variance(weight, axis=(0, 1, 2), keepdims=True)
            # Static kernel shape, so plain Python avoids needing numpy here.
            fan_in = 1
            for dim in weight.shape[:-1]:
                fan_in *= int(dim)
            floor = tf.convert_to_tensor(eps, dtype=var.dtype)
            scale = tf.math.rsqrt(tf.math.maximum(var * fan_in, floor))
            scale = scale * tf.cast(self.gain, var.dtype)
            return weight * scale - mean * scale

        def call(self, inputs, eps=1e-4):
            # Use the standardized copy directly. Never assign() it back into
            # self.kernel: that would overwrite the raw weights every step and
            # cut the gradient path through the standardization.
            kernel = self.standardize_weight(self.kernel, eps)
            data_format = "NCHW" if self.data_format == "channels_first" else "NHWC"
            outputs = tf.nn.convolution(
                inputs,
                kernel,
                strides=list(self.strides),
                padding=self.padding.upper(),
                dilations=list(self.dilation_rate),
                data_format=data_format,
            )
            if self.use_bias:
                outputs = tf.nn.bias_add(outputs, self.bias, data_format=data_format)
            if self.activation is not None:
                outputs = self.activation(outputs)
            return outputs
    

    【讨论】:

      猜你喜欢
      • 2020-01-15
      • 1970-01-01
      • 1970-01-01
      • 2017-07-31
      • 1970-01-01
      • 2015-11-13
      • 2017-10-21
      • 1970-01-01
      • 2018-05-22
      相关资源
      最近更新 更多