【问题标题】:Gradient accumulation in tensorflow 2.x / kerastensorflow 2.x / keras 中的梯度累积
【发布时间】:2021-06-08 16:47:42
【问题描述】:

我正在尝试在 TF2.x 上实现梯度累积。我发现的所有实现都适用于 TF1.x 或旧的 keras 接口。我不认为那里有实现(尽管我很高兴被证明是错误的)。

这是我正在使用的:

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Flatten, Dense
from tqdm import tqdm
import matplotlib.pyplot as plt


class SimpleTrainStepModel(Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            (x, y), sample_weight = data, None


        # FIRST GRADIENT
        with tf.GradientTape() as tape:
            y_pred = self(x, training = True)  # Forward pass
            loss = self.compiled_loss(y, y_pred, sample_weight = sample_weight, regularization_losses = self.losses)
        gradients = tape.gradient(loss, self.trainable_variables)
        self.compiled_metrics.update_state(y, y_pred)

        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        return {m.name: m.result() for m in self.metrics}


class GradAccumModel(Model):
    def fit(self, *args, batch_size = 32, grad_accum = 1, **kwargs):
        self.train_function = None
        if batch_size % grad_accum != 0:
            raise ValueError('Batch size must be divisible by the Gradient accumulation steps, dummy!')
        self.grad_accum = grad_accum
        self.batch_size = batch_size
        return super(GradAccumModel, self).fit(*args,
                                               batch_size = self.batch_size,
                                               #validation_batch_size = validation_batch_size,#self.batch_size//grad_accum if validation_batch_size is None else validation_batch_size,
                                               **kwargs)

    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            (x, y), sample_weight = data, None

        step = self.batch_size // self.grad_accum

        # def _slice_nested(obj, i, j):
        #     if type(obj) is list:
        #         return [o[i:j] for o in obj]
        #     else:
        #         return obj[i:j]

        # FIRST GRADIENT
        with tf.GradientTape() as tape:
            y_pred = self(x[:step], training = True)  # Forward pass
            loss = self.compiled_loss(y[:step], y_pred, sample_weight = sample_weight, regularization_losses = self.losses)
        gradients = tape.gradient(loss, self.trainable_variables)
        self.compiled_metrics.update_state(y[:step], y_pred)

        i = tf.constant(step)
        # tf.print('TF - HERE!')
        def cond(i, *args):
            return i < self.batch_size
        def body(i, grad):
            # tf.print('\tTF - HERE!')
            with tf.GradientTape() as tape:
                y_pred = self(x[i:i + step], training = True) # Forward pass
                loss = self.compiled_loss(y[i:i + step], y_pred, sample_weight = sample_weight, regularization_losses = self.losses)
            _grad = tape.gradient(loss, self.trainable_variables)

            for g,_g in zip(grad, _grad):
                g += _g

            self.compiled_metrics.update_state(y[i:i + step], y_pred)
            return [i + step, grad]

        i, gradients = tf.while_loop(cond, body, [i, gradients], parallel_iterations = 1)


        # for g in gradients:        # I tried with and without division co calculate the mean
        #     g *= 1/self.grad_accum #


        # Update weights
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        # Update metrics (includes the metric that tracks the loss)

        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}


if __name__ == '__main__':
    (x_train, y_train), (x_valid, y_valid) = tf.keras.datasets.mnist.load_data()

    for MODEL, ga_kwarg, colour in list(zip([Model, SimpleTrainStepModel, GradAccumModel, GradAccumModel],
                                            [{}, {}, {'grad_accum': 1}, {'grad_accum': 6}],
                                            ['blue', 'green', 'yellow', 'red'])):

        for _ in tqdm(range(10)):
            # tf.random.set_seed(0)
            x = Input((28, 28))
            y = x
            y = Flatten()(y)
            y = Dense(128, activation = 'sigmoid')(y)
            y = Dense(10, activation = 'softmax')(y)

            model = MODEL(x, y)
            model.compile(loss = tf.keras.losses.SparseCategoricalCrossentropy(),
                          optimizer = tf.keras.optimizers.Adam(1e-4),
                          metrics = ['acc'])

            hist = model.fit(x_train, y_train, validation_data = (x_valid, y_valid), verbose = 0, batch_size = 6000, epochs = 100, **ga_kwarg)
            plt.plot(hist.history['val_acc'], color = colour, alpha = .25)

    plt.title('')
    plt.xscale('symlog')
    plt.yscale('logit')
    plt.show()

我已经能够验证它确实节省了 gpu 内存。但是最终的结果和普通的Model.fit不一样。

如您所见,前三个Model.fits 聚类良好,并给出相同的结果。但是当while 循环开始发挥作用时,训练就完全不同了。

有人知道为什么会这样吗?

【问题讨论】:

    标签: python keras tensorflow2.0 gradienttape


    【解决方案1】:

    经过多次尝试,我找到了解决方案,似乎主要问题是渐变的复合分配,它不像我预期的那样工作。对于任何可能感兴趣的人,这是我的最终解决方案。它包括用于分布式、混合精度训练和嵌套输入/输出的额外内容。

    from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer as lso
    from tensorflow.python.distribute import parameter_server_strategy
    from tensorflow.python.distribute import distribution_strategy_context as ds_context
    from tensorflow.python.util import nest
    from tensorflow.keras.models import Model as _Model
    
    
    class Model(_Model):
        def fit(self, *args, batch_size: int = 32, grad_accum_steps: int = 1, **kwargs):
            """
            Shallow wrapper of Model.fit that captures batch_size and additional kwarg: grad_accum.
    
            Parameters
            ----------
            batch_size : int
                same as in Model.fit
            grad_accum_steps : int
                Number of steps to split batch_size into. The `batch_size` should be divisible by `grad_accum` (defaults to 1).
            """
            if grad_accum_steps == 1:
                super().fit(*args, batch_size = batch_size, **kwargs)
    
            self.train_function = None
            num_workers = ds_context.get_strategy().num_replicas_in_sync
            if batch_size % (grad_accum_steps * num_workers) != 0:
                raise ValueError(f'Batch size ({batch_size}) must be divisible by the Gradient accumulation steps ({grad_accum_steps}), and the number of replicas ({num_workers}), dummy!')
    
            self._grad_accum_ = grad_accum_steps
            self._batch_size_ = batch_size
            self._num_workers_ = num_workers
            train_step_backup = self.train_step
            self.train_step = self._train_step_
            out = super(self).fit(*args,
                                  batch_size = self._batch_size_, # TODO maybe consider validation batch size
                                  **kwargs)
    
            del self._grad_accum_
            del self._batch_size_
            del self._num_workers_
            self.train_step = train_step_backup
            return out
    
        def _train_step_(self, data):
            """
            Custom training step taking into account gradient accumulation for low memory training
            """
    
            if len(data) == 3:
                x, y, sample_weight = data
            else:
                (x, y), sample_weight = data, None
    
    
            def slice_map(struct, start, stop): # dealing with nasty nested structures
                if struct is None:
                    return None # special case for sample_weight
    
                return nest.map_structure(lambda x: x[start:stop], struct)
    
    
    
            # ---------- GRAD ACCUM STUFF ----------------------------------------------------------------------------------
            step = self._batch_size_ // self._num_workers_ // self._grad_accum_
            x_ = slice_map(x, 0, step)
            y_ = slice_map(y, 0, step)
            w_ = slice_map(sample_weight, 0, step)
    
            with tf.GradientTape() as tape:
    
                y_pred = self(x_, training = True)  # Forward pass
                loss = self.compiled_loss(y_, y_pred, sample_weight = w_, regularization_losses = self.losses)
                if isinstance(self.optimizer, lso.LossScaleOptimizer):
                    loss = self.optimizer.get_scaled_loss(loss)
    
            gradients = tape.gradient(loss, self.trainable_variables)
            gradients = [gradient * (1./self._grad_accum_) for gradient in gradients]
            self.compiled_metrics.update_state(y_, y_pred)
    
            i = tf.constant(step)
            def cond(i, *args):
                return i < self._batch_size_
    
            def body(i, grad):
                x_ = slice_map(x, i, i + step)
                y_ = slice_map(y, i, i + step)
                w_ = slice_map(sample_weight, i, i + step)
    
                with tf.GradientTape() as tape:
                    y_pred = self(x_, training = True) # Forward pass
                    loss = self.compiled_loss(y_, y_pred, sample_weight = w_, regularization_losses = self.losses)
                    if isinstance(self.optimizer, lso.LossScaleOptimizer):
                        loss = self.optimizer.get_scaled_loss(loss)
    
                _grad = tape.gradient(loss, self.trainable_variables)
                _grad = [_g * (1./self._grad_accum_) for _g in _grad]
    
                grad = [g + _g for g,_g in zip(grad, _grad)]
    
                self.compiled_metrics.update_state(y_, y_pred)
                return [i + step, grad]
    
            i, gradients = tf.while_loop(cond, body, [i, gradients], parallel_iterations = 1)
            # --------------------------------------------------------------------------------------------------------------
    
    
    
            # ---------- STUFF FROM Model._minimize ------------------------------------------------------------------------
            aggregate_grads_outside_optimizer = (self.optimizer._HAS_AGGREGATE_GRAD and not isinstance(self.distribute_strategy.extended, parameter_server_strategy.ParameterServerStrategyExtended))
    
            if aggregate_grads_outside_optimizer: # TODO there might be some issues with the scaling, due to the extra accumulation steps
                gradients = self.optimizer._aggregate_gradients(zip(gradients, self.trainable_variables))
    
            if isinstance(self.optimizer, lso.LossScaleOptimizer):
                gradients = self.optimizer.get_unscaled_gradients(gradients)
    
            gradients = self.optimizer._clip_gradients(gradients)
            if self.trainable_variables:
                if aggregate_grads_outside_optimizer:
                    self.optimizer.apply_gradients(zip(gradients, self.trainable_variables), experimental_aggregate_gradients = False)
                else:
                    self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
            # --------------------------------------------------------------------------------------------------------------
    
    
            return {m.name: m.result() for m in self.metrics}
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 2022-01-13
      • 2019-04-19
      • 1970-01-01
      • 2021-01-04
      • 1970-01-01
      • 1970-01-01
      • 2020-09-15
      相关资源
      最近更新 更多