【问题标题】:Running TensorFlow with XLA tf.function throws error使用 XLA tf.function 运行 TensorFlow 会引发错误
【发布时间】:2020-12-21 09:59:20
【问题描述】:
当我试图编译这个code 时,得到以下错误。
File "xla_test.py", line 25, in <module>
@tf.function(jit_compile=True)
TypeError: function() got an unexpected keyword argument 'jit_compile'
【问题讨论】:
标签:
tensorflow2.0
tensorflow-xla
【解决方案1】:
无需切换到 tf-nightly,只需使用:
@tf.function(experimental_compile=True)
来自tensorflow docs:
experimental_compile 如果为 True,则函数始终由 XLA 编译。 XLA 在某些情况下可能更高效(例如 TPU、XLA_GPU、密集张量计算)。
在我的情况下,没有该参数的 MCMC 采样:~1 分 37 秒,experimental_compile=True:~6 秒。
从源代码构建的 TensorFlow(r2.4 分支)。
【解决方案2】:
安装 tf-nightly 解决了这个问题。
pip install tf-nightly