【问题标题】:Batch Normalization in tf.keras does not calculate average mean and average variancetf.keras 中的批量标准化不计算平均均值和平均方差
【发布时间】:2019-08-20 15:55:26
【问题描述】:

here 提出了一个类似的未回答问题。 我正在测试一种在 tensorflow 中使用 keras 后端的深度强化学习算法。我对 tf.keras 不是很熟悉,但是想添加批量标准化层。因此,我尝试使用tf.keras.layers.BatchNormalization(),但它不会更新平均均值和方差,因为update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 是空的。

使用常规的tf.layers.batch_normalization 似乎可以正常工作。不过因为完整的算法有些复杂,我需要想办法使用tf.keras

标准tfbatch_normed = tf.layers.batch_normalization(hidden, training=True) 更新平均值,因为update_ops 不为空:

[
    <tf.Operation 'batch_normalization/AssignMovingAvg' type=AssignSub>, 
    <tf.Operation 'batch_normalization/AssignMovingAvg_1' type=AssignSub>, 
    <tf.Operation 'batch_normalization_1/AssignMovingAvg' type=AssignSub>, 
    <tf.Operation 'batch_normalization_1/AssignMovingAvg_1' type=AssignSub>
]

不起作用的最小示例:

import tensorflow as tf
import numpy as np

tf.reset_default_graph()
graph = tf.get_default_graph()
tf.keras.backend.set_learning_phase(True)

input_shapes = [(3, )]
hidden_layer_sizes = [16, 16]

inputs = [
    tf.keras.layers.Input(shape=input_shape)
    for input_shape in input_shapes
]

concatenated = tf.keras.layers.Lambda(
    lambda x: tf.concat(x, axis=-1)
)(inputs)

out = concatenated
for units in hidden_layer_sizes:      
    hidden = tf.keras.layers.Dense(
    units, activation=None
    )(out)
    batch_normed = tf.keras.layers.BatchNormalization()(hidden, training=True)
    #batch_normed = tf.layers.batch_normalization(hidden, training=True)
    out = tf.keras.layers.Activation('relu')(batch_normed)

out = tf.keras.layers.Dense(
    units=1, activation='linear'
)(out)


data = np.random.rand(100,3)
with tf.Session(graph=graph) as sess:
    sess.run(tf.global_variables_initializer())

    for i in range(10):

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    sess.run(update_ops,  {inputs[0]: data})
    sess.run(out, {inputs[0]: data})

    variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                          scope='batch_normalization')
    bn_gamma, bn_beta, bn_moving_mean, bn_moving_variance = [], [], [], []
    for variable in variables:
        val = sess.run(variable)
        nv = np.linalg.norm(val)
        if 'gamma' in variable.name:
            bn_gamma.append(nv)
        if 'beta' in variable.name:
            bn_beta.append(nv)
        if 'moving_mean' in variable.name:
            bn_moving_mean.append(nv)
        if 'moving_variance' in variable.name:
            bn_moving_variance.append(nv)

        diagnostics = {
            'bn_Q_gamma': np.mean(bn_gamma),
            'bn_Q_beta': np.mean(bn_beta),
            'bn_Q_moving_mean': np.mean(bn_moving_mean),
            'bn_Q_moving_variance': np.mean(bn_moving_variance),
        }

    print(diagnostics)

输出如下(可以看到moving_mean和moving_variance没有变化):

{'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.0, 'bn_Q_moving_variance': 4.0}
{'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.0, 'bn_Q_moving_variance': 4.0}
{'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.0, 'bn_Q_moving_variance': 4.0}
{'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.0, 'bn_Q_moving_variance': 4.0}
{'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.0, 'bn_Q_moving_variance': 4.0}
{'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.0, 'bn_Q_moving_variance': 4.0}
{'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.0, 'bn_Q_moving_variance': 4.0}
{'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.0, 'bn_Q_moving_variance': 4.0}
{'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.0, 'bn_Q_moving_variance': 4.0}
{'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.0, 'bn_Q_moving_variance': 4.0}

虽然预期的输出类似于以下内容(使用tf.keras 注释batch_normed calculus 行并取消注释它下面的行):

{'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.0148749575, 'bn_Q_moving_variance': 3.966927}
{'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.029601166, 'bn_Q_moving_variance': 3.934192}
{'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.04418011, 'bn_Q_moving_variance': 3.9017918}
{'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.05861327, 'bn_Q_moving_variance': 3.8697228}
{'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.0729021, 'bn_Q_moving_variance': 3.8379822}
{'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.08704803, 'bn_Q_moving_variance': 3.8065662}
{'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.10105251, 'bn_Q_moving_variance': 3.7754717}
{'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.11491694, 'bn_Q_moving_variance': 3.7446957}
{'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.12864274, 'bn_Q_moving_variance': 3.7142346}
{'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.14223127, 'bn_Q_moving_variance': 3.6840856}

注意

即使使用tf.layers.batch_normalization,仍有一些可疑之处。 tf.control_dependencies的标准tf方法:

    with tf.control_dependencies(update_ops):
        sess.run(out, {inputs[0]: data})

我在上面的代码中放置了以下两行代码:

    sess.run(update_ops,  {inputs[0]: data})
    sess.run(out, {inputs[0]: data})

产生bn_Q_moving_mean = 0.0bn_Q_moving_variance = 4.0

【问题讨论】:

  • 那么这个问题怎么没有答案?
  • @Sharky 因为 Matias Valdenegro 给出了一个关于纯 Keras 的答案,而不是关于 Tensorflow+Keras 的答案,请参阅他的答案 cmets。鉴于 Syncopated 和我的经验,tf.keras 不会自动更新移动平均线。所以,问题依然存在:怎么做?

标签: python tensorflow keras batch-normalization


【解决方案1】:
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, bn1.updates[0])
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, bn1.updates[1])
updates_op = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

这个可以解决

tf.control_dependencies(update_ops)

错误问题。

如果使用

tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, batch_normed.updates)

回归

tf.get_collection(tf.GraphKeys.UPDATE_OPS)

是列表中的一个列表,就像 [[something]]

并使用

tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, bn1.updates[0])
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, bn1.updates[1])
updates_op = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

回归

tf.get_collection(tf.GraphKeys.UPDATE_OPS)

是 [something1,something2,...]

我认为这是解决方案。

但输出不同,我不知道哪个是真的。

【讨论】:

    【解决方案2】:

    这是因为tf.keras.layers.BatchNormalization 继承自tf.keras.layers.Layer。 Keras API 将更新操作作为其拟合和评估循环的一部分。这反过来意味着没有它它不会更新tf.GraphKeys.UPDATE_OPS 集合。

    所以为了让它工作,你需要手动更新它

    hidden = tf.keras.layers.Dense(units, activation=None)(out)
    batch_normed = tf.keras.layers.BatchNormalization(trainable=True) 
    layer = batch_normed(hidden)
    

    这会创建单独的类实例

    tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, batch_normed.updates)
    

    而且这个更新需要收集。也看看https://github.com/tensorflow/tensorflow/issues/25525

    【讨论】:

    • 很好,非常感谢!运行sess.run(update_ops, {inputs[0]: data}) 然后sess.run(out, {inputs[0]: data}) 对我来说效果很好。你知道为什么with tf.control_dependencies(update_ops): sess.run(out, {inputs[0]: data}) 还是不行吗?
    • 已更新答案,尽我所能(如 github 链接中所述)keras 根本没有更新图层对象内集合的功能,因此应该明确更新。估算器和低级会话可以做到这一点
    猜你喜欢
    • 2014-03-21
    • 1970-01-01
    • 2018-06-17
    • 2020-12-26
    • 1970-01-01
    • 2012-04-20
    • 2017-01-18
    相关资源
    最近更新 更多