【问题标题】:How do you apply layer normalization in an RNN using tf.keras?如何使用 tf.keras 在 RNN 中应用层规范化?
【发布时间】:2019-08-20 13:51:00
【问题描述】:

我想将layer normalization 应用于使用 tf.keras 的循环神经网络。在 TensorFlow 2.0 中,tf.layers.experimental 中有一个 LayerNormalization 类,但不清楚如何在每个时间步(因为它被设计为用过的)。我应该创建一个自定义单元格,还是有更简单的方法?

例如,在每个时间步应用 dropout 就像在创建 LSTM 层时设置 recurrent_dropout 参数一样简单,但没有 recurrent_layer_normalization 参数。

【问题讨论】:

    标签: python tf.keras tensorflow2.0


    【解决方案1】:

    您可以通过从SimpleRNNCell 类继承来创建自定义单元格,如下所示:

    import numpy as np
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.activations import get as get_activation
    from tensorflow.keras.layers import SimpleRNNCell, RNN, Layer
    from tensorflow.keras.layers.experimental import LayerNormalization
    
    class SimpleRNNCellWithLayerNorm(SimpleRNNCell):
        def __init__(self, units, **kwargs):
            self.activation = get_activation(kwargs.get("activation", "tanh"))
            kwargs["activation"] = None
            super().__init__(units, **kwargs)
            self.layer_norm = LayerNormalization()
        def call(self, inputs, states):
            outputs, new_states = super().call(inputs, states)
            norm_out = self.activation(self.layer_norm(outputs))
            return norm_out, [norm_out]
    

    此实现在没有任何activation 的情况下运行常规SimpleRNN 单元一步,然后将层范数应用于结果输出,然后应用activation。然后你可以这样使用它:

    model = Sequential([
        RNN(SimpleRNNCellWithLayerNorm(20), return_sequences=True,
            input_shape=[None, 20]),
        RNN(SimpleRNNCellWithLayerNorm(5)),
    ])
    
    model.compile(loss="mse", optimizer="sgd")
    X_train = np.random.randn(100, 50, 20)
    Y_train = np.random.randn(100, 5)
    history = model.fit(X_train, Y_train, epochs=2)
    

    对于 GRU 和 LSTM 单元,人们通常在门上应用层范数(在输入和状态的线性组合之后,在 sigmoid 激活之前),所以实现起来有点棘手。或者,您可以通过在应用activationrecurrent_activation 之前应用层规范来获得良好的结果,这将更容易实现。

    【讨论】:

      【解决方案2】:

      在 tensorflow 插件中,有一个开箱即用的预构建 LayerNormLSTMCell

      请参阅this doc 了解更多详情。您可能必须先安装tensorflow-addons,然后才能导入此单元格。

      pip install tensorflow-addons
      

      【讨论】:

        猜你喜欢
        • 2019-12-13
        • 2020-06-12
        • 2021-09-17
        • 2011-05-09
        • 2011-01-08
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 2016-06-20
        相关资源
        最近更新 更多