【问题标题】:Using tf.keras in TF 2.0, how can I define a custom layer that depends on the learning phase?在 TF 2.0 中使用 tf.keras,如何定义依赖于学习阶段的自定义层?
【发布时间】:2019-02-14 13:34:21
【问题描述】:

我想使用 tf.keras 构建一个自定义层。为简单起见,假设它应该在训练期间返回输入*2,在测试期间返回输入*3。这样做的正确方法是什么?

我试过这种方法:

class CustomLayer(Layer):
    @tf.function
    def call(self, inputs, training=None):
        if training:
            return inputs*2
        else:
            return inputs*3

然后我可以像这样使用这个类:

>>> layer = CustomLayer()
>>> layer(10)
tf.Tensor(30, shape=(), dtype=int32)
>>> layer(10, training=True)
tf.Tensor(20, shape=(), dtype=int32)

效果很好!但是,当我在模型中使用这个类并调用它的fit() 方法时,似乎training 没有设置为True。我尝试在call()方法的开头添加如下代码,但是training总是等于0。

if training is None:
    training = K.learning_phase()

我错过了什么?

编辑

我找到了解决方案(请参阅我的答案),但我仍在寻找使用 @tf.function 的更好解决方案(我更喜欢签名而不是 smart_cond() 业务)。不幸的是,看起来K.learning_phase()@tf.function 不匹配(我的猜测是当call() 函数被跟踪时,学习阶段会被硬编码到图中:因为这发生在调用@ 987654334@方法,学习阶段始终为0)。这可能是一个错误,或者在使用 @tf.function 时可能有另一种方法可以进入学习阶段。

【问题讨论】:

    标签: tensorflow keras tf.keras


    【解决方案1】:

    François Chollet 确认使用@tf.function 时的正确解决方案是:

    class CustomLayer(Layer):
        @tf.function
        def call(self, inputs, training=None):
            if training is None:
                training = K.learning_phase()
            if training:
                return inputs * 2
            else:
                return inputs * 3
    

    目前有一个错误(截至 2019 年 2 月 15 日)使training 始终等于0,但很快就会修复。

    【讨论】:

      【解决方案2】:

      下面的代码没有使用@tf.function,所以看起来不太好看(因为它没有使用签名),但它工作正常:

      from tensorflow.python.keras.utils.tf_utils import smart_cond
      
      class CustomLayer(Layer):
          def call(self, inputs, training=None):
              if training is None:
                  training = K.learning_phase()
              return smart_cond(training, lambda: inputs * 2, lambda: inputs * 3)
      

      【讨论】:

        猜你喜欢
        • 1970-01-01
        • 1970-01-01
        • 2019-04-27
        • 1970-01-01
        • 1970-01-01
        • 2020-02-27
        • 1970-01-01
        • 1970-01-01
        • 2020-09-11
        相关资源
        最近更新 更多