【问题标题】:Proper way of writing a custom layer in keras?在 keras 中编写自定义层的正确方法?
【发布时间】:2021-01-05 04:55:22
【问题描述】:

我看到至少三种在 keras 中创建自定义层的方法。

import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model

def reset_seed(seed=313):
    tf.keras.backend.clear_session()
    tf.random.set_seed(seed)
    np.random.seed(313)

class Method1MLP(tf.keras.layers.Layer):

    def __init__(self, in_units, out_units, **kwargs):
        self.dense = Dense(in_units)
        self.out = Dense(out_units)
        super().__init__(**kwargs)
    
    def call(self, x):
        temp = self.dense(x)
        return self.out(temp)


class Method2MLP(tf.keras.layers.Layer):

    def __init__(self, in_units, out_units, **kwargs):
        self.dense = Dense(in_units)
        self.out = Dense(out_units)
        super().__init__(**kwargs)

    def __call__(self, x):

        temp = self.dense(x)
        return self.out(temp)


class Method3MLP(tf.keras.layers.Layer):

    def __init__(self, in_units, out_units, **kwargs): 
        self.in_units = in_units
        self.out_units = out_units
        super().__init__(**kwargs)

    def __call__(self, x):

        temp = Dense(self.in_units)(x)
        return Dense(self.out_units)(temp)

# define dummy inputs and outputs
x = np.random.random((100, 10,5))
y = np.random.random((100, 1))

现在首先构建没有自定义层的模型

reset_seed()

inp = Input(shape=(10,5))
temp = Dense(5)(inp)
out = Dense(1)(temp)
model =  Model(inputs=inp, outputs=out)
model.compile(optimizer=tf.keras.optimizers.Adam(), loss='mse')
print(model.summary())

model.fit(x=x,y=y, epochs=5)

使用方法一

reset_seed()

inp = Input(shape=(10,5))
out = Method1MLP(5,1)(inp)
model =  Model(inputs=inp, outputs=out)
model.compile(optimizer=tf.keras.optimizers.Adam(), loss='mse')
print(model.summary())

model.fit(x=x,y=y, epochs=5)

使用方法二

reset_seed()

inp = Input(shape=(10,5))
out = Method2MLP(5,1)(inp)
model =  Model(inputs=inp, outputs=out)
model.compile(optimizer=tf.keras.optimizers.Adam(), loss='mse')
print(model.summary())

model.fit(x=x,y=y, epochs=5)

使用方法3

reset_seed()

inp = Input(shape=(10,5))
out = Method3MLP(5,1)(inp)
model =  Model(inputs=inp, outputs=out)
model.compile(optimizer=tf.keras.optimizers.Adam(), loss='mse')
print(model.summary())

model.fit(x=x,y=y, epochs=5)

以上所有代码 sn-ps 给出相同的结果。 虽然官方文档推荐方法 1,但方法 2 和方法 3 的优点是它们暴露了中间输出,即自定义层内的输出。这使得在训练后很容易将这些输出作为 numpy 数组获取。我也想了解方法二和方法三的区别。在__init__方法中是否启动层无关?

方法2(当我们显式写__call__方法时)和方法1(当我们让keras Layer的__call__方法调用我们的call方法时)有什么区别吗?

【问题讨论】:

    标签: python tensorflow keras neural-network keras-layer


    【解决方案1】:

    我看不出方法 1 和 2 有什么不同。也许你忘记了什么?

    我认为方法 3 更慢,因为它每次都创建Dense 层的额外开销。

    你就不能两全其美,就像下面的 sn-p 一样吗?

    class Method4MLP(tf.keras.layers.Layer):
    
        def __init__(self, in_units, out_units, **kwargs):
            self.dense = Dense(in_units)
            self.out = Dense(out_units)
            self.in_units = in_units
            self.out_units = out_units
            super().__init__(**kwargs)
    
        def __call__(self, x):
            temp = self.dense(x)
            return self.out(temp)
    

    【讨论】:

      【解决方案2】:

      对我来说,这些都没有意义。尤其是在 __call__ dunder 方法中定义密集层的地方。但是,当类定义中有两个这样的对象时,子类化一个对象的目的是什么?看起来您可以简单地使用两层顺序模型。

      无论哪种方式,子类化 Keras 层的正确方法都在documentation 中进行了概述。

      from tensorflow import keras
      import tensorflow as tf
      
      class Linear(keras.layers.Layer):
          def __init__(self, units=32, input_dim=32):
              super(Linear, self).__init__()
              w_init = tf.random_normal_initializer()
              self.w = tf.Variable(
                  initial_value=w_init(shape=(input_dim, units), dtype="float32"),
                  trainable=True,
              )
              b_init = tf.zeros_initializer()
              self.b = tf.Variable(
                  initial_value=b_init(shape=(units,), dtype="float32"), trainable=True
              )
      
          def call(self, inputs):
              return tf.matmul(inputs, self.w) + self.b
      

      【讨论】:

      • 示例非常简单,但适用于例如call 内部有一个 softmax,我们想查看这个 softmax 的激活,如果我们使用 call() 而不是 `__call__(),我们如何查看呢?
      猜你喜欢
      • 2018-06-03
      • 1970-01-01
      • 2018-07-21
      • 2020-10-06
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2021-08-02
      • 2020-01-04
      相关资源
      最近更新 更多