【Posted】: 2017-11-19 21:54:59
【Question】:
I followed the link to create a custom op named mask. The body of the TensorFlow op is:
def tf_mask(x, labels, epoch_, name=None):  # add "labels" to the input
    with ops.name_scope(name, "Mask", [x, labels, epoch_]) as name:
        z = py_func(np_mask,
                    [x, labels, epoch_],  # add "labels, epoch_" to the input list
                    [tf.float32],
                    name=name,
                    grad=our_grad)
        z = z[0]
        z.set_shape(x.get_shape())
        return z
This follows the referenced link almost exactly. However, I ran into this error:
ValueError: Num gradients 1 generated for op name: "mask/Mask"
op: "PyFunc"
input: "conv2/Relu"
input: "Placeholder_2"
input: "Placeholder_3"
attr {
  key: "Tin"
  value {
    list {
      type: DT_FLOAT
      type: DT_FLOAT
      type: DT_FLOAT
    }
  }
}
attr {
  key: "Tout"
  value {
    list {
      type: DT_FLOAT
    }
  }
}
attr {
  key: "_gradient_op_type"
  value {
    s: "PyFuncGrad302636"
  }
}
attr {
  key: "token"
  value {
    s: "pyfunc_0"
  }
}
 do not match num inputs 3
In case it is needed, this is how I defined the our_grad function to compute the gradients:
def our_grad(cus_op, grad):
    """Compute gradients of our custom operation.
    Args:
        param cus_op: our custom op tf_mask
        param grad: the previous gradients before the operation
    Returns:
        gradient that can be sent down to next layer in back propagation
        it's an n-tuple, where n is the number of arguments of the operation
    """
    x = cus_op.inputs[0]
    labels = cus_op.inputs[1]
    epoch_ = cus_op.inputs[2]
    n_gr1 = tf_d_mask(x)
    n_gr2 = tf_gradient2(x, labels, epoch_)
    return tf.multiply(grad, n_gr1) + n_gr2
And here is the py_func function (the same as in the referenced link):
def py_func(func, inp, tout, stateful=True, name=None, grad=None):
    """
    I omitted the introduction to parameters that are not of interest
    :param func: a numpy function
    :param inp: input tensors
    :param grad: a tensorflow function to get the gradients (used in bprop, should be able to receive previous
                 gradients and send gradients down.)
    :return: a tensorflow op with a registered bprop method
    """
    # Need to generate a unique name to avoid duplicates:
    rnd_name = 'PyFuncGrad' + str(np.random.randint(0, 1000000))
    tf.RegisterGradient(rnd_name)(grad)
    g = tf.get_default_graph()
    with g.gradient_override_map({"PyFunc": rnd_name}):
        return tf.py_func(func, inp, tout, stateful=stateful, name=name)
I would really appreciate help from the community!
Thanks!
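The error message reports that the registered gradient function produced 1 gradient while the PyFunc op has 3 inputs (x, labels, epoch_). A gradient function registered through tf.RegisterGradient has to return one entry per op input, and None is accepted for inputs that need no gradient. Below is a minimal sketch of that return shape, assuming labels and epoch_ are not trained. Since tf_d_mask and tf_gradient2 are not shown in the post, the sketch uses hypothetical NumPy stand-ins (d_mask_stub, gradient2_stub) and a fake op object so that it runs without TensorFlow; it illustrates the contract only and is not the original implementation.

```python
from types import SimpleNamespace

import numpy as np


def d_mask_stub(x):
    """Hypothetical stand-in for tf_d_mask (not shown in the post)."""
    return np.ones_like(x)


def gradient2_stub(x, labels, epoch_):
    """Hypothetical stand-in for tf_gradient2 (not shown in the post)."""
    return np.zeros_like(x)


def our_grad_fixed(cus_op, grad):
    """Return one gradient per op input: (x, labels, epoch_)."""
    x = cus_op.inputs[0]
    labels = cus_op.inputs[1]
    epoch_ = cus_op.inputs[2]
    grad_x = grad * d_mask_stub(x) + gradient2_stub(x, labels, epoch_)
    # Assumption: labels and epoch_ are not trained, so their slots are None.
    return grad_x, None, None


# Fake op with three inputs, mimicking cus_op.inputs in the post.
fake_op = SimpleNamespace(inputs=[
    np.ones((2, 3), dtype=np.float32),   # x
    np.zeros((2,), dtype=np.float32),    # labels
    np.float32(1.0),                     # epoch_
])
upstream = np.full((2, 3), 0.5, dtype=np.float32)
grads = our_grad_fixed(fake_op, upstream)
print(len(grads))  # one entry per op input
```

In the TensorFlow version, the last line of our_grad would presumably become `return tf.multiply(grad, n_gr1) + n_gr2, None, None`, which gives the op three gradient entries to match its three inputs.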
【Discussion】:
Tags: python tensorflow