【发布时间】:2021-06-20 15:12:40
【问题描述】:
为什么 tensorflow 会在 tf.keras.Model 的 predict_step 函数内禁用 Eager Execution?也许我错了,但这里有一个例子:
from __future__ import annotations
from functools import wraps
import tensorflow as tf
def print_execution(func):
@wraps(func)
def wrapper(self: SimpleModel, data):
print(tf.executing_eagerly()) # Prints False
return func(self, data)
return wrapper
class SimpleModel(tf.keras.Model):
def __init__(self):
super().__init__()
def call(self, inputs, training=None, mask=None):
return inputs
@print_execution
def predict_step(self, data):
return super().predict_step(data)
if __name__ == "__main__":
x = tf.random.uniform((2, 2))
print(tf.executing_eagerly()) # Prints True
model = SimpleModel()
pred = model.predict(x)
这是预期的行为吗?有没有办法强制predict_step 以急切模式运行?
【问题讨论】:
标签: tensorflow machine-learning keras tensorflow2.0 tf.keras