[Question title]: Backpropagating through multiple forward passes
[Posted]: 2020-11-23 03:02:15
[Question]:

In ordinary backpropagation, we forward-propagate once, compute the gradients, and apply them to update the weights. But suppose we wish to forward-propagate twice, backpropagate through both passes, and only then apply the gradients (skipping the update after the first pass).
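
The intent, as a minimal sketch (hypothetical toy Dense model, not part of the original question): two forward passes recorded under one tape, with gradients applied only once at the end; this works fine as long as no stateful assignment is involved.

import tensorflow as tf

dense = tf.keras.layers.Dense(4)      # hypothetical toy model
opt = tf.keras.optimizers.SGD(0.01)
x1 = tf.random.normal((2, 4))
x2 = tf.random.normal((2, 4))

with tf.GradientTape() as tape:
    out1 = dense(x1)                  # first forward pass
    out2 = dense(out1 + x2)           # second pass consumes the first pass's output
    loss = tf.reduce_mean(out2 ** 2)

grads = tape.gradient(loss, dense.trainable_weights)  # backprop through both passes
opt.apply_gradients(zip(grads, dense.trainable_weights))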

Suppose the following:

import tensorflow as tf

x = tf.Variable([2.])
w = tf.Variable([4.])

with tf.GradientTape(persistent=True) as tape:
    w.assign(w * x)
    y = w * w  # w^2 * x
print(tape.gradient(y, x))  # >>None

Per the docs, tf.Variable is a stateful object, which blocks gradients, and the weights are tf.Variables.
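
For contrast, a minimal sketch (an illustration added here, not from the original question): if the intermediate result is kept as a plain tensor instead of being assigned back into the tf.Variable, the gradient does flow.

import tensorflow as tf

x = tf.Variable([2.])
w = tf.Variable([4.])

with tf.GradientTape() as tape:
    w_t = w * x          # keep a tensor result instead of w.assign(w * x)
    y = w_t * w_t        # (w * x)^2
print(tape.gradient(y, x))  # >> [64.]  (= 2 * w^2 * x)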

Example use cases are differentiable hard attention (as opposed to RL-based), or simply passing a hidden state between layers across subsequent forward passes, as in the figure below. Neither TF nor Keras has API-level support for stateful gradients; this includes RNNs, which keep only a stateful state tensor: gradients do not flow past a batch.

How can this be done?

[Question comments]:

    Tags: python tensorflow keras tensorflow2.0 tf.keras


    [Solution 1]:

    We need a careful application of tf.while_loop; from help(TensorArray):

    This class is meant to be used with dynamic iteration primitives such as while_loop and map_fn. It supports gradient back-propagation via special "flow" control flow dependencies.

    So we seek to write a loop in which all outputs we want to backpropagate through are written to a TensorArray. The code accomplishing this, along with a high-level description, is below; a validation example is at the bottom.
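
    Before the full implementation, here is a minimal sketch (not part of the original answer) of the mechanism it relies on: outputs written to a TensorArray inside a tf.while_loop stay differentiable, so gradients reach every recorded step.

    import tensorflow as tf

    w = tf.Variable(2.)
    x = tf.constant([1., 2., 3.])

    with tf.GradientTape() as tape:
        ta = tf.TensorArray(tf.float32, size=3)

        def body(t, ta):
            return t + 1, ta.write(t, w * x[t])  # record each step's output

        _, ta = tf.while_loop(lambda t, _: t < 3, body, (tf.constant(0), ta))
        loss = tf.reduce_sum(ta.stack())          # backprop through all written steps

    print(tape.gradient(loss, w))  # >> 6.0  (= 1 + 2 + 3)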


    Explanation

    • The code is borrowed from K.rnn, rewritten for simplicity and relevance
    • For a better understanding, I suggest inspecting K.rnn, SimpleRNNCell.call, and RNN.call
    • model_rnn has a few needlessly-included checks for the sake of Case 3; will link a cleaner version
    • The idea is as follows: we traverse the network bottom-to-top and left-to-right, writing the entire forward pass to a single TensorArray under a single tf.while_loop; this ensures TF caches the tensor operations throughout for backpropagation.

    import tensorflow as tf
    from tensorflow.python.util import nest
    from tensorflow.python.ops import array_ops, tensor_array_ops
    from tensorflow.python.framework import ops
    
    
    def model_rnn(model, inputs, states=None, swap_batch_timestep=True):
        def step_function(inputs, states):
            out = model([inputs, *states], training=True)
            output, new_states = (out if isinstance(out, (tuple, list)) else
                                  (out, states))
            return output, new_states
    
        def _swap_batch_timestep(input_t):
            # (samples, timesteps, channels) -> (timesteps, samples, channels)
            # iterating dim0 to feed (samples, channels) slices expected by RNN
            axes = list(range(len(input_t.shape)))
            axes[0], axes[1] = 1, 0
            return array_ops.transpose(input_t, axes)
    
        if swap_batch_timestep:
            inputs = nest.map_structure(_swap_batch_timestep, inputs)
    
        if states is None:
            states = (tf.zeros(model.inputs[0].shape, dtype='float32'),)
        initial_states = states
        input_ta, output_ta, time, time_steps_t = _process_args(model, inputs)
    
        # one loop iteration: read a timestep slice, run the model, record the output
        def _step(time, output_ta_t, *states):
            current_input = input_ta.read(time)
            output, new_states = step_function(current_input, tuple(states))
    
            flat_state = nest.flatten(states)
            flat_new_state = nest.flatten(new_states)
            for state, new_state in zip(flat_state, flat_new_state):
                if isinstance(new_state, ops.Tensor):
                    new_state.set_shape(state.shape)
    
            output_ta_t = output_ta_t.write(time, output)
            new_states = nest.pack_sequence_as(initial_states, flat_new_state)
            return (time + 1, output_ta_t) + tuple(new_states)
    
        # run all steps under a single tf.while_loop so backprop covers the full pass
        final_outputs = tf.while_loop(
            body=_step,
            loop_vars=(time, output_ta) + tuple(initial_states),
            cond=lambda time, *_: tf.math.less(time, time_steps_t))
    
        new_states = final_outputs[2:]
        output_ta = final_outputs[1]
        outputs = output_ta.stack()
        return outputs, new_states
    
    
    def _process_args(model, inputs):
        time_steps_t = tf.constant(inputs.shape[0], dtype='int32')
    
        # assume single-input network (excluding states)
        input_ta = tensor_array_ops.TensorArray(
            dtype=inputs.dtype,
            size=time_steps_t,
            tensor_array_name='input_ta_0').unstack(inputs)
    
        # assume single-input network (excluding states)
        # if having states, infer info from non-state nodes
        output_ta = tensor_array_ops.TensorArray(
            dtype=model.outputs[0].dtype,
            size=time_steps_t,
            element_shape=model.outputs[0].shape,
            tensor_array_name='output_ta_0')
    
        time = tf.constant(0, dtype='int32', name='time')
        return input_ta, output_ta, time, time_steps_t
    

    Examples & validation

    Case design: we feed the same input twice, which allows certain stateful vs. stateless comparisons; the results also hold for differing inputs.

    • Case 0: control case; the other cases must match it.
    • Case 1: fail; gradients don't match, even though outputs and loss do. Backpropagation fails when feeding the halved sequences separately.
    • Case 2: gradients match Case 1. It may look like we've used only one tf.while_loop, but SimpleRNN uses one of its own over the 3 timesteps and writes to a TensorArray that gets discarded; this won't do. A workaround is to implement the SimpleRNN logic ourselves.
    • Case 3: perfect match.

    Note that there is no such thing as a stateful RNN cell; statefulness is implemented in the RNN base class, and we have recreated it in model_rnn. This is likewise how any other layer would be handled: feed one step slice at a time, once per forward pass.

    import random
    import numpy as np
    import tensorflow as tf
    
    from tensorflow.keras.layers import Input, SimpleRNN, SimpleRNNCell
    from tensorflow.keras.models import Model
    
    def reset_seeds():
        random.seed(0)
        np.random.seed(1)
        tf.compat.v1.set_random_seed(2)  # graph-level seed
        tf.random.set_seed(3)  # global seed
    
    def print_report(case, model, outs, loss, tape, idx=1):
        print("\nCASE #%s" % case)
        print("LOSS", loss)
        print("GRADS:\n", tape.gradient(loss, model.layers[idx].weights[0]))
        print("OUTS:\n", outs)
    
    
    #%%# Make data ###############################################################
    reset_seeds()
    x0 = y0 = tf.constant(np.random.randn(2, 3, 4))
    x0_2 = y0_2 = tf.concat([x0, x0], axis=1)
    x00  = y00  = tf.stack([x0, x0], axis=0)
    
    #%%# Case 0: Complete forward pass; control case #############################
    reset_seeds()
    ipt = Input(batch_shape=(2, 6, 4))
    out = SimpleRNN(4, return_sequences=True)(ipt)
    model0 = Model(ipt, out)
    model0.compile('sgd', 'mse')
    #%%#############################################################
    with tf.GradientTape(persistent=True) as tape:
        outs = model0(x0_2, training=True)
        loss = model0.compiled_loss(y0_2, outs)
    print_report(0, model0, outs, loss, tape)
    
    #%%# Case 1: Two passes, stateful RNN, direct feeding ########################
    reset_seeds()
    ipt = Input(batch_shape=(2, 3, 4))
    out = SimpleRNN(4, return_sequences=True, stateful=True)(ipt)
    model1 = Model(ipt, out)
    model1.compile('sgd', 'mse')
    #%%#############################################################
    with tf.GradientTape(persistent=True) as tape:
        outs0 = model1(x0, training=True)
        tape.watch(outs0)  # cannot even diff otherwise
        outs1 = model1(x0, training=True)
        tape.watch(outs1)
        outs = tf.concat([outs0, outs1], axis=1)
        tape.watch(outs)
        loss = model1.compiled_loss(y0_2, outs)
    print_report(1, model1, outs, loss, tape)
    
    #%%# Case 2: Two passes, stateful RNN, model_rnn #############################
    reset_seeds()
    ipt = Input(batch_shape=(2, 3, 4))
    out = SimpleRNN(4, return_sequences=True, stateful=True)(ipt)
    model2 = Model(ipt, out)
    model2.compile('sgd', 'mse')
    #%%#############################################################
    with tf.GradientTape(persistent=True) as tape:
        outs, _ = model_rnn(model2, x00, swap_batch_timestep=False)
        outs = tf.concat(list(outs), axis=1)
        loss = model2.compiled_loss(y0_2, outs)
    print_report(2, model2, outs, loss, tape)
    
    #%%# Case 3: Single pass, stateless RNN, model_rnn ###########################
    reset_seeds()
    ipt  = Input(batch_shape=(2, 4))
    sipt = Input(batch_shape=(2, 4))
    out, state = SimpleRNNCell(4)(ipt, sipt)
    model3 = Model([ipt, sipt], [out, state])
    model3.compile('sgd', 'mse')
    #%%#############################################################
    with tf.GradientTape(persistent=True) as tape:
        outs, _ = model_rnn(model3, x0_2)
        outs = tf.transpose(outs, (1, 0, 2))
        loss = model3.compiled_loss(y0_2, outs)
    print_report(3, model3, outs, loss, tape, idx=2)
    

    Vertical flow: we have verified the horizontal, timewise backpropagation; what about vertical?

    To this end, we implement a stacked stateful RNN; results below. All outputs on my machine: here

    We have hereby verified vertical and horizontal stateful backpropagation. This can be used to implement arbitrarily complex forward-propagation logic with correct backpropagation. An applied example: here

    #%%# Case 4: Complete forward pass; control case ############################
    reset_seeds()
    ipt = Input(batch_shape=(2, 6, 4))
    x   = SimpleRNN(4, return_sequences=True)(ipt)
    out = SimpleRNN(4, return_sequences=True)(x)
    model4 = Model(ipt, out)
    model4.compile('sgd', 'mse')
    #%%
    with tf.GradientTape(persistent=True) as tape:
        outs = model4(x0_2, training=True)
        loss = model4.compiled_loss(y0_2, outs)
    print("=" * 80)
    print_report(4, model4, outs, loss, tape, idx=1)
    print_report(4, model4, outs, loss, tape, idx=2)
    
    #%%# Case 5: Two passes, stateless RNN; model_rnn ############################
    reset_seeds()
    ipt = Input(batch_shape=(2, 6, 4))
    out = SimpleRNN(4, return_sequences=True)(ipt)
    model5a = Model(ipt, out)
    model5a.compile('sgd', 'mse')
    
    ipt  = Input(batch_shape=(2, 4))
    sipt = Input(batch_shape=(2, 4))
    out, state = SimpleRNNCell(4)(ipt, sipt)
    model5b = Model([ipt, sipt], [out, state])
    model5b.compile('sgd', 'mse')
    #%%
    with tf.GradientTape(persistent=True) as tape:
        outs = model5a(x0_2, training=True)
        outs, _ = model_rnn(model5b, outs)
        outs = tf.transpose(outs, (1, 0, 2))
        loss = model5a.compiled_loss(y0_2, outs)
    print_report(5, model5a, outs, loss, tape)
    print_report(5, model5b, outs, loss, tape, idx=2)
    

    [Comments]:
