【发布时间】:2019-02-22 04:09:47
【问题描述】:
我需要将一个方法定义为自定义渐变,如下所示:
class CustGradClass:
def __init__(self):
pass
@tf.custom_gradient
def f(self,x):
fx = x
def grad(dy):
return dy * 1
return fx, grad
我收到以下错误:
ValueError:尝试将类型 () 不受支持的值(main.CustGradClass object at 0x12ed91710>)转换为 Tensor。
原因是自定义渐变接受一个函数 f(*x),其中 x 是一个张量序列。传递的第一个参数是对象本身,即self。
f:函数 f(*x) 返回元组 (y, grad_fn) 其中:
x 是函数的张量输入序列。 y 是将 f 中的 TensorFlow 操作应用于 x 的张量或张量输出序列。 grad_fn 是一个函数,签名为 g(*grad_ys)
如何让它发挥作用?我需要继承一些 python tensorflow 类吗?
我正在使用 tf 版本 1.12.0 和渴望模式。
【问题讨论】:
标签: python tensorflow gradient autodiff