【问题标题】:Batch Normalization in tensorflow张量流中的批量标准化
【发布时间】:2016-05-02 09:08:20
【问题描述】:

我注意到 tensorflow 的 api 中已经有批量标准化函数。但我不明白的一件事是如何改变训练和测试之间的程序?

批量标准化在测试期间的作用与在训练期间的作用不同。具体来说,在训练期间使用固定的均值和方差。

在某处有一些很好的示例代码吗?我看到了一些,但是对于范围变量,它变得令人困惑

【问题讨论】:

  • 考虑使用来自高级 api 的预定义层,例如 tf.contrib.layers

标签: tensorflow recurrent-neural-network


【解决方案1】:

你说得对,tf.nn.batch_normalization 只提供了实现批量标准化的基本功能。您必须添加额外的逻辑来跟踪训练期间的移动均值和方差,并在推理过程中使用经过训练的均值和方差。您可以查看此example 以获得非常通用的实现,但这里有一个不使用gamma 的快速版本:

  beta = tf.Variable(tf.zeros(shape), name='beta')
  moving_mean = tf.Variable(tf.zeros(shape), name='moving_mean',
                                 trainable=False)
  moving_variance = tf.Variable(tf.ones(shape),
                                     name='moving_variance',
                                     trainable=False)
  control_inputs = []
  if is_training:
    mean, variance = tf.nn.moments(image, [0, 1, 2])
    update_moving_mean = moving_averages.assign_moving_average(
        moving_mean, mean, self.decay)
    update_moving_variance = moving_averages.assign_moving_average(
        moving_variance, variance, self.decay)
    control_inputs = [update_moving_mean, update_moving_variance]
  else:
    mean = moving_mean
    variance = moving_variance
  with tf.control_dependencies(control_inputs):
    return tf.nn.batch_normalization(
        image, mean=mean, variance=variance, offset=beta,
        scale=None, variance_epsilon=0.001)

【讨论】:

  • 非常感谢。另一个快速的问题。带有伽玛的版本真的更复杂吗?似乎您只需要为它初始化另一个 tf.Variable 吗?其余的代码应该是一样的吧?
  • 是的,您可以按照我提供的链接中更通用的实现添加gamma
猜你喜欢
  • 2020-04-03
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2017-04-14
  • 1970-01-01
  • 2018-04-09
  • 2020-11-30
相关资源
最近更新 更多