【发布时间】:2017-12-28 20:05:57
【问题描述】:
我正在尝试为我的自定义 TF 操作定义渐变方法。我在网上找到的大多数解决方案似乎都是基于harpone 的gist。我不愿意使用这种方法,因为它使用了py_func,它不能在 GPU 上运行。我找到了另一个解决方案here,它使用了看起来更优雅的tf.identity(),我认为 将在GPU 上运行。但是,我在访问自定义渐变函数中的操作输入时遇到了一些问题。这是我的代码:
@tf.RegisterGradient('MyCustomGradient')
def _custom_gradient(op, gradients):
x = op.inputs[0]
return(x)
def my_op(w):
return tf.pow(w,3)
var_foo = tf.Variable(5, dtype=tf.float32)
bar = my_op(var_foo)
g = tf.get_default_graph()
with g.gradient_override_map({'Identity': 'MyCustomGradient'}):
bar = tf.identity(bar)
g = tf.gradients(bar, var_foo)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(g))
我期待 _custom_gradient() 将输入返回给 op(本例中为 5),但它似乎返回 op output x gradient。我的自定义 my_op 将具有不可微分的操作,例如 tf.sign ,我想根据输入定义我的自定义渐变。我究竟做错了什么?
【问题讨论】:
-
我认为发生的事情是自定义渐变附加到
identity()op 而不是我希望的my_op()函数。
标签: python tensorflow