【问题标题】:Defining custom gradient as a class method in Tensorflow将自定义渐变定义为 Tensorflow 中的类方法
【发布时间】:2019-02-22 04:09:47
【问题描述】:

我需要将一个方法定义为自定义渐变,如下所示:

class CustGradClass:

    def __init__(self):
        pass

    @tf.custom_gradient
    def f(self,x):
      fx = x
      def grad(dy):
        return dy * 1
      return fx, grad

我收到以下错误:

ValueError:尝试将类型 () 不受支持的值(ma​​in.CustGradClass object at 0x12ed91710>)转换为 Tensor。

原因是自定义渐变接受一个函数 f(*x),其中 x 是一个张量序列。传递的第一个参数是对象本身,即self

来自documentation

f:函数 f(*x) 返回元组 (y, grad_fn) 其中:
x 是函数的张量输入序列。 y 是将 f 中的 TensorFlow 操作应用于 x 的张量或张量输出序列。 grad_fn 是一个函数,签名为 g(*grad_ys)

如何让它发挥作用?我需要继承一些 python tensorflow 类吗?

我正在使用 tf 版本 1.12.0 和渴望模式。

【问题讨论】:

    标签: python tensorflow gradient autodiff


    【解决方案1】:

    这是一种可能的简单解决方法:

    import tensorflow as tf
    
    class CustGradClass:
    
        def __init__(self):
            self.f = tf.custom_gradient(lambda x: CustGradClass._f(self, x))
    
        @staticmethod
        def _f(self, x):
            fx = x * 1
            def grad(dy):
                return dy * 1
            return fx, grad
    
    with tf.Graph().as_default(), tf.Session() as sess:
        x = tf.constant(1.0)
        c = CustGradClass()
        y = c.f(x)
        print(tf.gradients(y, x))
        # [<tf.Tensor 'gradients/IdentityN_grad/mul:0' shape=() dtype=float32>]
    

    编辑:

    如果你想在不同的类上多次这样做,或者只是想要一个更可重用的解决方案,你可以使用像这样的一些装饰器,例如:

    import functools
    import tensorflow as tf
    
    def tf_custom_gradient_method(f):
        @functools.wraps(f)
        def wrapped(self, *args, **kwargs):
            if not hasattr(self, '_tf_custom_gradient_wrappers'):
                self._tf_custom_gradient_wrappers = {}
            if f not in self._tf_custom_gradient_wrappers:
                self._tf_custom_gradient_wrappers[f] = tf.custom_gradient(lambda *a, **kw: f(self, *a, **kw))
            return self._tf_custom_gradient_wrappers[f](*args, **kwargs)
        return wrapped
    

    那么你可以这样做:

    class CustGradClass:
    
        def __init__(self):
            pass
    
        @tf_custom_gradient_method
        def f(self, x):
            fx = x * 1
            def grad(dy):
                return dy * 1
            return fx, grad
    
        @tf_custom_gradient_method
        def f2(self, x):
            fx = x * 2
            def grad(dy):
                return dy * 2
            return fx, grad
    

    【讨论】:

    • 很好的解决方法!一个建议:如果你要用lambda 包装该方法,则无需将该方法声明为静态。就说:self.f = tf.custom_gradient(lambda x: self._f(x))
    • 另一个注意事项:TF 2.1+ 支持将tf.custom_gradients 与开箱即用的类方法一起使用。 (请参阅here。)但是,TF 2.0 不支持它。
    【解决方案2】:

    在您的示例中,您没有使用任何成员变量,因此您可以将该方法设为静态方法。如果使用成员变量,则从成员函数调用静态方法并将成员变量作为参数传递。

    class CustGradClass:
    
      def __init__(self):
        self.some_var = ...
    
      @staticmethod
      @tf.custom_gradient
      def _f(x):
        fx = x
        def grad(dy):
          return dy * 1
    
        return fx, grad
    
      def f(self):
        return CustGradClass._f(self.some_var)
    

    【讨论】:

    • 我假设 OP 正在寻找一种解决方案,他们可以在方法中使用成员变量。
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2019-02-10
    • 1970-01-01
    • 2019-12-14
    • 1970-01-01
    • 2021-03-25
    • 1970-01-01
    • 2022-01-18
    相关资源
    最近更新 更多