【发布时间】:2019-02-14 13:34:21
【问题描述】:
我想使用 tf.keras 构建一个自定义层。为简单起见,假设它应该在训练期间返回输入*2,在测试期间返回输入*3。这样做的正确方法是什么?
我试过这种方法:
class CustomLayer(Layer):
@tf.function
def call(self, inputs, training=None):
if training:
return inputs*2
else:
return inputs*3
然后我可以像这样使用这个类:
>>> layer = CustomLayer()
>>> layer(10)
tf.Tensor(30, shape=(), dtype=int32)
>>> layer(10, training=True)
tf.Tensor(20, shape=(), dtype=int32)
效果很好!但是,当我在模型中使用这个类并调用它的fit() 方法时,似乎training 没有设置为True。我尝试在call()方法的开头添加如下代码,但是training总是等于0。
if training is None:
training = K.learning_phase()
我错过了什么?
编辑
我找到了解决方案(请参阅我的答案),但我仍在寻找使用 @tf.function 的更好解决方案(我更喜欢签名而不是 smart_cond() 业务)。不幸的是,看起来K.learning_phase() 与@tf.function 不匹配(我的猜测是当call() 函数被跟踪时,学习阶段会被硬编码到图中:因为这发生在调用@ 987654334@方法,学习阶段始终为0)。这可能是一个错误,或者在使用 @tf.function 时可能有另一种方法可以进入学习阶段。
【问题讨论】:
标签: tensorflow keras tf.keras