나는 link을 따라 마스크라는 사용자 정의 op를 작성했습니다. tensorflow 연산의 본체는 실제로 꽤 많이 인용 된 링크를 따라ValueError : op 이름에 대해 생성 된 그라디언트 1의 수 : "마스크/마스크"
def tf_mask(x, labels, epoch_, name=None): # add "labels" to the input
with ops.name_scope(name, "Mask", [x, labels, epoch_]) as name:
z = py_func(np_mask,
[x, labels, epoch_], # add "labels, epoch_" to the input list
[tf.float32],
name=name,
grad=our_grad)
z = z[0]
z.set_shape(x.get_shape())
return z
입니다. 그러나, 나는이 오류로 실행 : I 그라데이션을 계산하는 our_grad
함수를 정의하는 방법을
ValueError: Num gradients 1 generated for op name: "mask/Mask"
op: "PyFunc"
input: "conv2/Relu"
input: "Placeholder_2"
input: "Placeholder_3"
attr {
key: "Tin"
value {
list {
type: DT_FLOAT
type: DT_FLOAT
type: DT_FLOAT
}
}
}
attr {
key: "Tout"
value {
list {
type: DT_FLOAT
}
}
}
attr {
key: "_gradient_op_type"
value {
s: "PyFuncGrad302636"
}
}
attr {
key: "token"
value {
s: "pyfunc_0"
}
}
do not match num inputs 3
이 필요한 경우,이입니다.
def our_grad(cus_op, grad):
"""Compute gradients of our custom operation.
Args:
param cus_op: our custom op tf_mask
param grad: the previous gradients before the operation
Returns:
gradient that can be sent down to next layer in back propagation
it's an n-tuple, where n is the number of arguments of the operation
"""
x = cus_op.inputs[0]
labels = cus_op.inputs[1]
epoch_ = cus_op.inputs[2]
n_gr1 = tf_d_mask(x)
n_gr2 = tf_gradient2(x, labels, epoch_)
return tf.multiply(grad, n_gr1) + n_gr2
그리고 py_func
기능 (인용 링크와 같은)은
def py_func(func, inp, tout, stateful=True, name=None, grad=None):
"""
I omitted the introduction to parameters that are not of interest
:param func: a numpy function
:param inp: input tensors
:param grad: a tensorflow function to get the gradients (used in bprop, should be able to receive previous
gradients and send gradients down.)
:return: a tensorflow op with a registered bprop method
"""
# Need to generate a unique name to avoid duplicates:
rnd_name = 'PyFuncGrad' + str(np.random.randint(0, 1000000))
tf.RegisterGradient(rnd_name)(grad)
g = tf.get_default_graph()
with g.gradient_override_map({"PyFunc": rnd_name}):
return tf.py_func(func, inp, tout, stateful=stateful, name=name)
은 정말 지역 사회의 도움이 필요합니다!
감사합니다.