【问题标题】:L2 regularization keep increasing during trainingL2 正则化在训练期间不断增加
【发布时间】:2018-09-20 14:24:49
【问题描述】:

我正在 TensorFlow 上微调 InceptionResnetV2。训练时,正则化损失保持线性增加,甚至远大于训练后期的交叉熵损失。我已经检查了训练过程,并确保我正在优化交叉熵损失和 L2 损失的组合。

有没有人稍微解释一下这个奇怪的事情?任何反馈表示赞赏。

这是代码和一些 TensorBoard 图。

import tensorflow as tf
from tensorflow.python.platform import tf_logging as logging
from inception_resnet_v2 import inception_resnet_v2, inception_resnet_v2_arg_scope
import os
import time
from preprocessing import aug_parallel_v2
import numpy as np

slim = tf.contrib.slim

# total training data number
sample_num = 625020

data_path = 'iNaturalist_train.tfrecords'

# State where your log file is at. If it doesn't exist, create it.
log_dir = './log_v5'
# tensorboard visualization path
filewriter_path = './filewriter_v5_Logits'

# State where your checkpoint file is
checkpoint_file = './inception_resnet_v2_2016_08_30.ckpt'
checkpoint_save_addr = './log_v5/fine-tuning_v5.ckpt'
# State the image size you're resizing your images to. We will use the default inception size of 299.
image_size = 299

# State the number of classes to predict:
num_classes = 8142

# ================= TRAINING INFORMATION ==================
# State the number of epochs to train
num_epochs = 5

# State your batch size
batch_size = 60

# Learning rate information and configuration
initial_learning_rate = 0.0005
learning_rate_decay_factor = 0.8
num_epochs_before_decay = 2

# put weight on different classes inversely proportional
# to total number of their image samples
label_count = np.loadtxt('label_count.txt', dtype=int)
inverse = lambda t: 1 / t
vfunc = np.vectorize(inverse)
multiplier = vfunc(label_count)
multiplier /= np.mean(multiplier)

def run():

    if not os.path.exists(log_dir):
        os.mkdir(log_dir)

    feature = {'train/height': tf.FixedLenFeature([], tf.int64),
               'train/width': tf.FixedLenFeature([], tf.int64),
               'train/image': tf.FixedLenFeature([], tf.string),
               'train/label': tf.FixedLenFeature([], tf.int64),
               'train/sup_label': tf.FixedLenFeature([], tf.int64),
               'train/aug_level': tf.FixedLenFeature([], tf.int64)}

    # create a list of file names
    filename_queue = tf.train.string_input_producer([data_path], num_epochs=None)
    print(filename_queue)

    reader = tf.TFRecordReader()
    _, tfrecord_serialized = reader.read(filename_queue)

    features = tf.parse_single_example(tfrecord_serialized, features=feature)

    # Convert the image data from string back to the numbers
    height = tf.cast(features['train/height'], tf.int64)
    width = tf.cast(features['train/width'], tf.int64)

    # change this line for your TFrecord version
    tf_image = tf.image.decode_jpeg(features['train/image'])

    tf_label = tf.cast(features['train/label'], tf.int32)
    aug_level = tf.cast(features['train/aug_level'], tf.int32)
    # tf_sup_label = tf.cast(features['train/sup_label'], tf.int64)

    tf_image = tf.reshape(tf_image, tf.stack([height, width, 3]))
    tf_label = tf.reshape(tf_label, [1])
    aug_level = tf.reshape(aug_level, [1])

    resized_image = tf.image.resize_images(images=tf_image, size=tf.constant([400, 400]), method=2)
    resized_image = tf.cast(resized_image, tf.uint8)
    tf_images, tf_labels, tf_aug = tf.train.shuffle_batch([resized_image, tf_label, aug_level], batch_size=batch_size,
                                                      capacity=2048, num_threads=16, allow_smaller_final_batch=False,
                                                      min_after_dequeue=256)


    tf.logging.set_verbosity(tf.logging.INFO)  # Set the verbosity to INFO level

    IMAGE_HEIGHT = 299
    IMAGE_WIDTH = 299

    images = tf.placeholder(dtype=tf.float32, shape=[None, 299, 299, 3])
    labels = tf.placeholder(dtype=tf.int32, shape=[None, 1])
    weighted_level = tf.placeholder(dtype=tf.float32, shape=[None, 1])

    # Know the number steps to take before decaying the learning rate and batches per epoch
    num_batches_per_epoch = int(sample_num / batch_size)
    num_steps_per_epoch = num_batches_per_epoch  # Because one step is one batch processed
    decay_steps = int(num_epochs_before_decay * num_steps_per_epoch)

    # Create the model inference
    with slim.arg_scope(inception_resnet_v2_arg_scope()):
        logits, end_points = inception_resnet_v2(images, num_classes=num_classes, is_training=True)

    # Define the scopes that you want to exclude for restoration
    exclude = ['InceptionResnetV2/Logits', 'InceptionResnetV2/AuxLogits']
    variables_to_restore = slim.get_variables_to_restore(exclude=exclude)

    print("label test")
    print(labels)
    print(logits)

    # Perform one-hot-encoding of the labels (Try one-hot-encoding within the load_batch function!)
    one_hot_labels = tf.squeeze(tf.one_hot(labels, num_classes), [1])

    print(one_hot_labels)
    print(logits)

    weighted_onehot = tf.multiply(one_hot_labels, weighted_level)

    # Performs the equivalent to tf.nn.sparse_softmax_cross_entropy_with_logits but enhanced with checks
    digits_loss = tf.losses.softmax_cross_entropy(onehot_labels=weighted_onehot, logits=logits)

    reg_loss = tf.losses.get_regularization_loss()

    total_loss = digits_loss + reg_loss

    # Define your exponentially decaying learning rate
    lr = tf.train.exponential_decay(
        learning_rate=initial_learning_rate,
        global_step=global_step,
        decay_steps=decay_steps,
        decay_rate=learning_rate_decay_factor,
        staircase=True)

    # train_vars = []
    # Now we can define the optimizer that takes on the learning rate
    train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          "InceptionResnetV2/Logits")

    # RMSProp or Adam

    optimizer = tf.train.AdamOptimizer(learning_rate=lr)

    # Create the train_op.
    train_op = slim.learning.create_train_op(total_loss, optimizer, variables_to_train=train_vars)

    predictions = tf.argmax(end_points['Predictions'], 1)
    probabilities = end_points['Predictions']
    accuracy, accuracy_update = tf.metrics.accuracy(predictions, labels)
    metrics_op = tf.group(accuracy_update, probabilities)

    tf.summary.scalar('losses/Reg_Loss', reg_loss)
    tf.summary.scalar('losses/Digit_Loss', digits_loss)
    tf.summary.scalar('losses/Total_Loss', total_loss)
    tf.summary.scalar('accuracy', accuracy)
    tf.summary.scalar('learning_rate', lr)
    writer = tf.summary.FileWriter(filewriter_path)
    writer.add_graph(tf.get_default_graph())

    my_summary_op = tf.summary.merge_all()

    def train_step(sess, train_op, global_step, imgs, lbls, weight):
        '''
        Simply runs a session for the three arguments provided and gives a logging on the time elapsed
        for each global step
        '''
        # Check the time for each sess run
        start_time = time.time()

        total_loss, global_step_count, _ = sess.run([train_op, global_step, metrics_op],
                                                    feed_dict={images: imgs, labels: lbls, weighted_level: weight})

        time_elapsed = time.time() - start_time

        # Run the logging to print some results
        logging.info('global step %s: digit_loss: %.4f (%.2f sec/step)',
                     global_step_count, total_loss, time_elapsed)

        return total_loss, global_step_count

    saver_pretrain = tf.train.Saver(variables_to_restore)
    saver_train = tf.train.Saver(train_vars)

    with tf.Session() as sess:

        init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
        sess.run(init_op)

        # Create a coordinator and run all QueueRunner objects
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        saver_pretrain.restore(sess, checkpoint_file)

        start_time = time.time()

        for step in range(int(num_steps_per_epoch * num_epochs)):

            imgs, lbls, augs = sess.run([tf_images, tf_labels, tf_aug])

            imgs, lbls = aug_parallel_v2(imgs, lbls, augs)

            imgs = imgs[:, 50:349, 50:349, :]

            imgs = 2*(imgs.astype(np.float32)) - 1

            lbls = lbls.astype(np.int32)

            weight = multiplier[lbls]

            weight = np.array(weight).reshape((batch_size, 1))

            # print(imgs[0, 0:10, 0:10, 0:2])

            if step % num_batches_per_epoch == 0:
                logging.info('Epoch %s/%s', step / num_batches_per_epoch + 1, num_epochs)

                learning_rate_value, accuracy_value = sess.run([lr, accuracy],
                                                feed_dict={images: imgs, labels: lbls, weighted_level: weight})

                logging.info('Current Learning Rate: %s', learning_rate_value)
                logging.info('Current Streaming Accuracy: %s', accuracy_value)

                # optionally, print your logits and predictions for a sanity check that things are going fine.
                logits_value, probabilities_value, predictions_value, labels_value = sess.run(
                    [logits, probabilities, predictions, labels],
                    feed_dict={images: imgs, labels: lbls, weighted_level: weight})

                print('logits: \n', logits_value)

                print('Probabilities: \n', probabilities_value)

                print('predictions: \n', predictions_value)

                print('Labels:\n:', labels_value)

            # Log the summaries every 10 step.
            if step % 20 == 0:

                loss, global_step_count = train_step(sess, train_op, global_step, imgs, lbls, weight)

                summaries = sess.run(my_summary_op, feed_dict={images: imgs, labels: lbls, weighted_level: weight})

                writer.add_summary(summaries, global_step_count)
                # sess.summary_computed(sess, summaries)

            # If not, simply run the training step

            else:
                loss, _ = train_step(sess, train_op, global_step, imgs, lbls, weight)

            if step % 2000 == 0:

                logging.info('Saving model to disk now.')
                saver_train.save(sess, checkpoint_save_addr, global_step=global_step)

            print('one batch time: ', time.time() - start_time)

            start_time = time.time()

        # We log the final training loss and accuracy
        logging.info('Final Loss: %s', loss)
        logging.info('Final Accuracy: %s', sess.run(accuracy))

        # Once all the training has been done, save the log files and checkpoint model
        logging.info('Finished training! Saving model to disk now.')
        saver_train.save(sess, checkpoint_save_addr, global_step=global_step)

        # Stop the threads
        coord.request_stop()

        # Wait for threads to stop
        coord.join(threads)
        sess.close()

if __name__ == '__main__':
    run()

我是新来的,没有足够的声誉来发布图片。 这是准确度图和损失图的两个链接。您可以很容易地看出正则化损失处于主导地位。

【问题讨论】:

    标签: tensorflow deep-learning


    【解决方案1】:

    这是一个很难回答的问题。不过我可以指点一下。

    一般来说,当您尝试最小化digits_loss,即使您的模型适合您的数据时,您会慢慢改变层中的权重。为了应对潜在的过拟合,L2 正则化损失(所有权重的平方和,代码中的reg_loss)通常被添加到整体损失(代码中的total_loss)中。这两种力通常相互对抗如果平衡正确,你就可以训练出一个好的模型。

    在您的情况下,您正在使用一个为 1,001 个类开发的网络 (resnet_v2),并尝试预测 8,142 个类。这本身没有问题,但你正在破坏平衡。所以我相信你需要将 resnet v2 的默认权重衰减 0.00004 覆盖到更高的值,在这一行中(注意小数点中只有 3 个零,增加 10 倍):

    with slim.arg_scope( inception_resnet_v2_arg_scope( weight_decay = 0.0004 ) ):
    

    更高的weight_decay 参数将迫使 L2 损失更快地减少。问题是这个数字只是一个猜测,我不知道理想值是多少。您需要尝试多个值并找出答案。

    【讨论】:

    • 我试过你的方法来使用更大的 weight_decay。 reg_loss 上升非常快,逐渐缓慢下降。但是重量衰减可能太大了。我的模型的性能(准确性)保持在非常低的水平。你知道训练中 reg loss 总是上升的原因吗?
    • 所以基本上它必须随着你的模型在初始化后的前几百步中开始成形而开始上升。一些权重变大,一些衰减,但现在你有一个非常任意的权重分布,而不是原来的正态分布,这是应该的。所以最初增加是很正常的。不幸的是,您必须尝试找出哪个weight_decay 值可以为您提供最佳结果。这称为超参数探索。如果您有更多硬件可供使用,您可以同时运行多个具有不同值的副本。
    • 我知道如果我们从头开始训练,L2 损失会增加。但我正在微调 InceptionResnetV2,它的变量已经很合理了。我认为 L2 损失不应该增加几千次迭代。而最奇怪的是,digit_loss 一开始也是在增加的。也许我的代码中存在错误...
    • 您正在从头开始训练最后一层。或者你的学习率太高了。但是,是的,您也可能有其他性质的错误。欢迎来到深度学习... :)
    • 哈哈哈谢谢你,先生!我尝试了几个重量衰减值,也尝试将所有变量一起微调,但无法摆脱这个问题。 digit_loss先增大后逐渐减小;而 reg_loss 会随着时间的推移而减少然后线性增加。线性是如此奇怪和烦人。它完全阻止了 digit_loss 的优化。对于超过 20, 000 次迭代,reg_loss 是 digit_loss 的三倍......
    猜你喜欢
    • 2021-04-06
    • 2018-11-04
    • 1970-01-01
    • 2018-06-18
    • 2017-07-30
    • 1970-01-01
    • 1970-01-01
    • 2013-04-04
    • 1970-01-01
    相关资源
    最近更新 更多