【问题标题】:Batch Normalization - Tensorflow批量标准化 - Tensorflow
【发布时间】:2017-12-07 23:44:19
【问题描述】:

我查看了一些 BN 示例,但仍然有些困惑。所以我目前正在使用这里调用该函数的函数;

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.layers.batch_norm.md

from tensorflow.contrib.layers.python.layers import batch_norm as batch_norm
import tensorflow as tf

def bn(x,is_training,name):
    bn_train = batch_norm(x, decay=0.9, center=True, scale=True,
    updates_collections=None,
    is_training=True,
    reuse=None, 
    trainable=True,
    scope=name)
    bn_inference = batch_norm(x, decay=1.00, center=True, scale=True,
    updates_collections=None,
    is_training=False,
    reuse=True, 
    trainable=False,
    scope=name)
    z = tf.cond(is_training, lambda: bn_train, lambda: bn_inference)
    return z

以下部分是一个玩具运行,我只是检查该函数是否重用了在两个特征的训练步骤中计算的均值和方差。在测试模式下运行这部分代码,即is_training=False,在训练步骤中计算的运行均值/方差正在发生变化,当我们打印出我调用bnParams得到的BN变量时可以看到这一点

if __name__ == "__main__":
    print("Example")

    import os
    import numpy as np
    import scipy.stats as stats
    np.set_printoptions(suppress=True,linewidth=200,precision=3)
    np.random.seed(1006)
    import pdb
    path = "batchNorm/"
    if not os.path.exists(path):
        os.mkdir(path)
    savePath = path + "bn.model"

    nFeats = 2
    X = tf.placeholder(tf.float32,[None,nFeats])
    is_training = tf.placeholder(tf.bool,name="is_training")
    Y = bn(X,is_training=is_training,name="bn")
    mvn = stats.multivariate_normal([0,100])
    bs = 4
    load = 0
    train = 1
    saver = tf.train.Saver()
    def bnCheck(batch,mu,std):
        # Checking calculation
        return (x - mu)/(std + 0.001)
    with tf.Session() as sess:
        if load == 1:
            saver.restore(sess,savePath)
        else:
            tf.global_variables_initializer().run()
        #### TRAINING #####
        if train == 1:
            for i in xrange(100):
                x = mvn.rvs(bs)
                y = Y.eval(feed_dict={X:x, is_training.name: True})

        def bnParams():
            beta, gamma, mean, var = [v.eval() for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope="bn")]
            return beta, gamma, mean, var

        beta, gamma, mean, var = bnParams()
        #### TESTING #####
        for i in xrange(10):
            x = mvn.rvs(1).reshape(1,-1)
            check = bnCheck(x,mean,np.sqrt(var))
            y = Y.eval(feed_dict={X:x, is_training.name: False})
            print("x = {0}, y = {1}, check = {2}".format(x,y,check))
            beta, gamma, mean, var = bnParams()
            print("BN Params: Beta {0} Gamma {1} mean {2} var{3} \n".format(beta,gamma,mean,var))

        saver.save(sess,savePath)

测试循环的前三个迭代如下所示;

x = [[  -1.782  100.941]], y = [[-1.843  1.388]], check = [[-1.842  1.387]]
BN Params: Beta [ 0.  0.] Gamma [ 1.  1.] mean [ -0.2   99.93] var[ 0.818  0.589] 

x = [[  -1.245  101.126]], y = [[-1.156  1.557]], check = [[-1.155  1.557]]
BN Params: Beta [ 0.  0.] Gamma [ 1.  1.] mean [  -0.304  100.05 ] var[ 0.736  0.53 ] 

x = [[ -0.107  99.349]], y = [[ 0.23  -0.961]], check = [[ 0.23 -0.96]]
BN Params: Beta [ 0.  0.] Gamma [ 1.  1.] mean [ -0.285  99.98 ] var[ 0.662  0.477] 

我没有做 BP,所以 beta 和 gamma 不会改变。然而,我的跑步方式/差异正在改变。我哪里错了?

编辑: 最好知道为什么这些变量需要/不需要在测试和训练之间改变;

updates_collections, reuse, trainable

【问题讨论】:

    标签: tensorflow


    【解决方案1】:

    你的 bn 函数是错误的。改用这个:

    def bn(x,is_training,name):
        return batch_norm(x, decay=0.9, center=True, scale=True,
        updates_collections=None,
        is_training=is_training,
        reuse=None,
        trainable=True,
        scope=name)
    

    is_training 是 bool 0-D 张量,指示是否更新运行均值等。然后只需更改张量 is_training 即可表明您是处于训练阶段还是测试阶段。

    编辑: 张量流中的许多操作都接受张量,而不是恒定的真/假数字参数。

    【讨论】:

    • 那和我发的有什么区别?
    • Yours 在计算图上为初学者创建了不必要的节点。老实说,我没有测试过你的代码,但是我看到很多行不正确(例如导入应该在文件的开头,而不是在 if_main 下缩进等)
    • 说实话……表面上看起来不错。但正如我所说,每 5 行代码就有一些奇怪的程序员选择(例如使用 0/1 表示 is_traning,然后在其他地方使用 bool,在不寻常(非 PEP8)地方导入等)这些事情使代码更难阅读和推理关于。为什么使用 is_training.name 而不是 is_training?
    • 另外,您如何确定您正在编写正确的变量?你从不检查他们的名字,get_collection 我的图片已经设置了抽象,没有顺序。
    • 我并不是真的在寻找关于漂亮编程的课程,我只是想要一些关于批量标准化的有用建议。还是谢谢。
    【解决方案2】:

    当您使用 slim.batch_norm 时,请务必使用 slim.learning.create_train_op 而不是 tf.train.GradientDecentOptimizer(lr).minimize(loss) 或其他优化器。试试看它是否有效!

    【讨论】:

      猜你喜欢
      • 2017-03-03
      • 2018-06-05
      • 2018-04-09
      • 2016-03-03
      • 1970-01-01
      • 2017-07-03
      • 2019-10-30
      • 2016-03-01
      • 2017-10-20
      相关资源
      最近更新 更多