【发布时间】:2018-01-30 21:14:51
【问题描述】:
最近我尝试使用 float16 在 TF 中训练 CNN。令我惊讶的是,即使 TF 声称支持它一段时间,它也以各种方式被破坏。例如,无论网络如何,float16 优化都会在第二步导致 NaN 损失。
import tensorflow as tf
import numpy as np
slim = tf.contrib.slim
dtype = tf.float16
shape = (4, 16, 16, 3)
inpt = tf.placeholder(dtype, shape, name='input')
net = slim.conv2d(inpt, 16, [3, 3], scope='conv',
weights_initializer=tf.zeros_initializer(),
# normalizer_fn=slim.batch_norm
)
loss = tf.reduce_mean(net)
opt = tf.train.AdamOptimizer(1e-3)
train_op = slim.learning.create_train_op(loss, opt)
val = np.zeros(shape)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for i in range(2):
print(sess.run(train_op, feed_dict={inpt: val}))
据我了解,这显然是一个错误:我在零输入上应用零卷积,我应该得到不会改变零损失的零梯度。就是不能分道扬镳。如果 dtype 是 float32 它可以工作。在 CPU 和 GPU 版本上都会发生 NaN 丢失。
但是,我在 GH 问题中被解雇了,一个随机的家伙关闭了这个问题,说这是预期的行为:https://github.com/tensorflow/tensorflow/issues/7226
如果您取消注释带有 BN 的行,它将在图形构建时中断,因为 BN 假定移动平均线(以及 beta、gamma)始终为 float32 并且没有正确转换它们。这个问题也被关闭了,显然被忽略了:https://github.com/tensorflow/tensorflow/issues/7164
我觉得我正在与 ISP 的一线 IT 支持人员交谈。
当这样一个简单的“网络”严重失败时,谁能解释我应该如何使用 float16 进行训练?现在报告错误的推荐方式是什么?
【问题讨论】:
标签: tensorflow