【Question Title】: Conditional Batch Normalization in Keras
【Posted】: 2019-01-09 00:15:46
【Question】:

I am trying to implement Conditional Batch Normalization in Keras. I assumed I would have to create a custom layer, so I extended the Normalization source code from the Keras team.

The idea: I will have 3 conditions, so I need to initialize 3 different beta and gamma parameters and then merge in conditional statements wherever they are needed. Note that my condition changes randomly after every iteration; I tried to drive it with 3 global Keras variables c1, c2 and c3.
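
For intuition, the computation amounts to something like the following numpy sketch (illustrative only; the function name `conditional_batchnorm` is made up here, this is not the Keras layer): normalize with shared batch statistics, then scale and shift with the parameters of whichever condition is active.

import numpy as np

def conditional_batchnorm(x, condition, gammas, betas, eps=1e-3):
    # Normalize with shared batch statistics, then scale/shift with the
    # parameters belonging to the currently active condition.
    mean = x.mean(axis=0)
    var = x.var(axis=0)
    x_hat = (x - mean) / np.sqrt(var + eps)
    return gammas[condition] * x_hat + betas[condition]

x = np.random.rand(8, 10)
gammas = np.ones((3, 10))    # one gamma per condition
betas = np.zeros((3, 10))    # one beta per condition
y = conditional_batchnorm(x, condition=1, gammas=gammas, betas=betas)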

Here is the code I have so far. It gives me errors because of the conditional statements. Any ideas on how to improve this, or on how conditional batch normalization should be implemented in Keras?

UPDATE:

import numpy as np
import tensorflow as tf

from keras import regularizers, initializers, constraints
from keras.legacy import interfaces
import keras.backend as K
from keras.layers import Layer, Input, InputSpec
from keras.models import Model
from keras.optimizers import Adam

# Left over from the first attempt: global backend variables for the three
# conditions. The updated code below feeds the conditions as model inputs
# instead.
global c1, c2, c3
c1 = K.variable([0])
c2 = K.variable([0])
c3 = K.variable([0])

class ConditionalBatchNormalization(Layer):
"""Conditional Batch normalization layer.
"""

@interfaces.legacy_batchnorm_support
def __init__(self, 
             axis=-1,
             momentum=0.99,
             epsilon=1e-3,
             center=True,
             scale=True,
             beta_initializer='zeros',
             gamma_initializer='ones',
             moving_mean_initializer='zeros',
             moving_variance_initializer='ones',
             beta_regularizer=None,
             gamma_regularizer=None,
             beta_constraint=None,
             gamma_constraint=None,
             **kwargs):
    super(ConditionalBatchNormalization, self).__init__(**kwargs)
    self.axis = axis
    self.momentum = momentum
    self.epsilon = epsilon
    self.center = center
    self.scale = scale
    self.beta_initializer = initializers.get(beta_initializer)
    self.gamma_initializer = initializers.get(gamma_initializer)
    self.moving_mean_initializer = initializers.get(moving_mean_initializer)
    self.moving_variance_initializer = (
        initializers.get(moving_variance_initializer))
    self.beta_regularizer = regularizers.get(beta_regularizer)
    self.gamma_regularizer = regularizers.get(gamma_regularizer)
    self.beta_constraint = constraints.get(beta_constraint)
    self.gamma_constraint = constraints.get(gamma_constraint)


    def build(self, input_shape):
        # `input_shape` is a list: [activations, condition1, condition2].
        dim = input_shape[0][self.axis]
        if dim is None:
            raise ValueError('Axis ' + str(self.axis) + ' of '
                             'input tensor should have a defined dimension '
                             'but the layer received an input with shape ' +
                             str(input_shape[0]) + '.')

        shape = (dim,)

        # One (gamma, beta) pair per condition. The weight names must be
        # unique, otherwise the three sets cannot be told apart.
        if self.scale:
            self.gamma1 = self.add_weight(shape=shape,
                                          name='gamma1',
                                          initializer=self.gamma_initializer,
                                          regularizer=self.gamma_regularizer,
                                          constraint=self.gamma_constraint)
            self.gamma2 = self.add_weight(shape=shape,
                                          name='gamma2',
                                          initializer=self.gamma_initializer,
                                          regularizer=self.gamma_regularizer,
                                          constraint=self.gamma_constraint)
            self.gamma3 = self.add_weight(shape=shape,
                                          name='gamma3',
                                          initializer=self.gamma_initializer,
                                          regularizer=self.gamma_regularizer,
                                          constraint=self.gamma_constraint)
        else:
            self.gamma1 = None
            self.gamma2 = None
            self.gamma3 = None

        if self.center:
            self.beta1 = self.add_weight(shape=shape,
                                         name='beta1',
                                         initializer=self.beta_initializer,
                                         regularizer=self.beta_regularizer,
                                         constraint=self.beta_constraint)
            self.beta2 = self.add_weight(shape=shape,
                                         name='beta2',
                                         initializer=self.beta_initializer,
                                         regularizer=self.beta_regularizer,
                                         constraint=self.beta_constraint)
            self.beta3 = self.add_weight(shape=shape,
                                         name='beta3',
                                         initializer=self.beta_initializer,
                                         regularizer=self.beta_regularizer,
                                         constraint=self.beta_constraint)
        else:
            self.beta1 = None
            self.beta2 = None
            self.beta3 = None

        # The moving statistics are shared across all conditions.
        self.moving_mean = self.add_weight(
            shape=shape,
            name='moving_mean',
            initializer=self.moving_mean_initializer,
            trainable=False)
        self.moving_variance = self.add_weight(
            shape=shape,
            name='moving_variance',
            initializer=self.moving_variance_initializer,
            trainable=False)

        super(ConditionalBatchNormalization, self).build(input_shape)

    def call(self, inputs, training=None):
        # `inputs` is a list: inputs[0] are the activations; inputs[1] and
        # inputs[2] carry the boolean conditions as extra model inputs.
        input_shape = K.int_shape(inputs[0])
        c1 = inputs[1][0]
        c2 = inputs[2][0]

        # Prepare broadcasting shape.
        ndim = len(input_shape)
        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis]

        # Determines whether broadcasting is needed.
        needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])

        def normalize_inference():
            if needs_broadcasting:
                # In this case we must explicitly broadcast all parameters.
                broadcast_moving_mean = K.reshape(self.moving_mean,
                                                  broadcast_shape)
                broadcast_moving_variance = K.reshape(self.moving_variance,
                                                      broadcast_shape)
                if self.center:
                    broadcast_beta = tf.case(
                        {c1: lambda: K.reshape(self.beta1, broadcast_shape),
                         c2: lambda: K.reshape(self.beta2, broadcast_shape)},
                        default=lambda: K.reshape(self.beta3, broadcast_shape))
                else:
                    broadcast_beta = None

                if self.scale:
                    broadcast_gamma = tf.case(
                        {c1: lambda: K.reshape(self.gamma1, broadcast_shape),
                         c2: lambda: K.reshape(self.gamma2, broadcast_shape)},
                        default=lambda: K.reshape(self.gamma3, broadcast_shape))
                else:
                    broadcast_gamma = None

                return K.batch_normalization(
                    inputs[0],
                    broadcast_moving_mean,
                    broadcast_moving_variance,
                    broadcast_beta,
                    broadcast_gamma,
                    axis=self.axis,
                    epsilon=self.epsilon)
            else:
                return tf.case(
                    {c1: lambda: K.batch_normalization(
                        inputs[0], self.moving_mean, self.moving_variance,
                        self.beta1, self.gamma1,
                        axis=self.axis, epsilon=self.epsilon),
                     c2: lambda: K.batch_normalization(
                        inputs[0], self.moving_mean, self.moving_variance,
                        self.beta2, self.gamma2,
                        axis=self.axis, epsilon=self.epsilon)},
                    default=lambda: K.batch_normalization(
                        inputs[0], self.moving_mean, self.moving_variance,
                        self.beta3, self.gamma3,
                        axis=self.axis, epsilon=self.epsilon))

        # If the learning phase is *static* and set to inference:
        if training in {0, False}:
            return normalize_inference()

        # If the learning phase is either dynamic or set to training,
        # normalize with the (gamma, beta) pair of the active condition.
        normed_training, mean, variance = tf.case(
            {c1: lambda: K.normalize_batch_in_training(
                inputs[0], self.gamma1, self.beta1, reduction_axes,
                epsilon=self.epsilon),
             c2: lambda: K.normalize_batch_in_training(
                inputs[0], self.gamma2, self.beta2, reduction_axes,
                epsilon=self.epsilon)},
            default=lambda: K.normalize_batch_in_training(
                inputs[0], self.gamma3, self.beta3, reduction_axes,
                epsilon=self.epsilon))

        if K.backend() != 'cntk':
            sample_size = K.prod([K.shape(inputs[0])[axis]
                                  for axis in reduction_axes])
            sample_size = K.cast(sample_size, dtype=K.dtype(inputs[0]))
            if K.backend() == 'tensorflow' and sample_size.dtype != 'float32':
                sample_size = K.cast(sample_size, dtype='float32')

            # sample variance - unbiased estimator of population variance
            variance *= sample_size / (sample_size - (1.0 + self.epsilon))

        self.add_update([K.moving_average_update(self.moving_mean,
                                                 mean,
                                                 self.momentum),
                         K.moving_average_update(self.moving_variance,
                                                 variance,
                                                 self.momentum)],
                        inputs[0])

        # Pick the normalized form corresponding to the training phase.
        return K.in_train_phase(normed_training,
                                normalize_inference,
                                training=training)

    def get_config(self):
        config = {
            'axis': self.axis,
            'momentum': self.momentum,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer': initializers.serialize(self.beta_initializer),
            'gamma_initializer': initializers.serialize(self.gamma_initializer),
            'moving_mean_initializer':
                initializers.serialize(self.moving_mean_initializer),
            'moving_variance_initializer':
                initializers.serialize(self.moving_variance_initializer),
            'beta_regularizer': regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': constraints.serialize(self.beta_constraint),
            'gamma_constraint': constraints.serialize(self.gamma_constraint)
        }
        base_config = super(ConditionalBatchNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        # The output has the shape of the activations, inputs[0].
        return input_shape[0]


if __name__ == '__main__':
    x = Input((10,))
    # The conditions enter the model as boolean inputs; their batch size
    # must match the batch size of `x` (100 samples below).
    c1 = Input(batch_shape=(100,), dtype=tf.bool)
    c2 = Input(batch_shape=(100,), dtype=tf.bool)
    h = ConditionalBatchNormalization()([x, c1, c2])
    model = Model([x, c1, c2], h)
    model.compile(optimizer=Adam(1e-4), loss='mse')

    X = np.random.rand(100, 10)
    Y = np.random.rand(100, 10)

    # First batch: condition 2 is active. The conditions are fed as plain
    # numpy arrays, one boolean per sample.
    C1 = np.array([False] * 100)
    C2 = np.array([True] * 100)
    model.train_on_batch(x=[X, C1, C2], y=Y)

    # Second batch: switch to condition 1.
    C1 = np.array([True] * 100)
    C2 = np.array([False] * 100)
    model.train_on_batch(x=[X, C1, C2], y=Y)

【Discussion】:

    Tags: python tensorflow machine-learning keras keras-layer


    【Solution 1】:

    I would use tf.case to express your conditional statements:

    normed_training, mean, variance = \
                tf.case({
                    c1: lambda: K.normalize_batch_in_training(
                        inputs, self.gamma1, self.beta1, reduction_axes,
                        epsilon=self.epsilon),
                    c2: lambda: K.normalize_batch_in_training(
                        inputs, self.gamma2, self.beta2, reduction_axes,
                        epsilon=self.epsilon)
                },
                    default=lambda: K.normalize_batch_in_training(
                        inputs, self.gamma3, self.beta3, reduction_axes,
                        epsilon=self.epsilon)
                )
    

    Also note that tf.case requires the conditions c1 and c2 to be of type tf.Tensor, so I defined them as follows:

    c1 = K.constant(False, dtype=tf.bool)
    c2 = K.constant(False, dtype=tf.bool)
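
    As a minimal, self-contained illustration of how tf.case dispatches on boolean tensors (a sketch assuming TensorFlow 1.x with the standalone Keras backend; the constants here are only for demonstration):

    import tensorflow as tf
    import keras.backend as K

    c1 = K.constant(False, dtype=tf.bool)
    c2 = K.constant(True, dtype=tf.bool)

    # The first predicate that evaluates to True selects the branch;
    # the default branch runs when no predicate is True.
    out = tf.case({c1: lambda: tf.constant(1),
                   c2: lambda: tf.constant(2)},
                  default=lambda: tf.constant(3))

    with tf.Session() as sess:
        print(sess.run(out))  # prints 2, since only c2 is True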
    

    【Comments】:

    • Thank you, this seems to work. Now I have to change the condition during training after every iteration, and I planned to do this by changing the values of c1, c2 and c3. But if we define them as constants, can we not change them during training? How do we deal with this? Would re-initializing them work and change the graph computation accordingly? @rvinas
    • I just tested it, and it seems to always follow the initial condition. For example, if we set c1 to True and the rest to False, only gamma1 and beta1 get updated during training. Is there a workaround to make the backend follow the new values of c1, c2 and c3?
    • I would define the conditions c1 and c2 as placeholders (using Keras Input layers), so that you can choose and feed their values during training (see the sketch after these comments)
    • Thanks, I somehow implemented the custom layer and added 2 extra inputs for it. However, the layer is not working as expected. It has 8 weight arrays (beta1, beta2, beta3, gamma1, gamma2, gamma3, moving_average, moving_variance). With the main function above, moving_average and moving_variance are not being updated. Moreover, the beta and gamma weights are not updated correctly, i.e. when condition c1=True, the weights of gamma2 and gamma3 still change. I am not sure what else is causing the problem. Any ideas? @rvinas
    • That seems to be a different problem. Could you post another question? Also, please make sure your error is reproducible (your code does not work on my machine!) and minimal. Best wishes!
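
    A minimal sketch of the placeholder idea from the comments above, reusing the model, X and Y from the question's main block (the loop and names are illustrative, not a tested fix for the weight-update issue):

    import numpy as np

    for step in range(3):
        active = step % 3                    # rotate the active condition
        C1 = np.array([active == 0] * 100)   # fed into the c1 Input
        C2 = np.array([active == 1] * 100)   # fed into the c2 Input
        # Both False means the default (third) branch is taken.
        model.train_on_batch(x=[X, C1, C2], y=Y)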