2017-10-22 1 views
1

나는 들어오는 그라데이션을 두 배로하는 커스텀 그라디언트 계산 기능을 가지고 있습니다. 내가 할TensorFlow 커스텀 그라디언트

import tensorflow as tf 

@tf.RegisterGradient("CustomSquare") 
def _custom_square_grad(op, grad): 
    return grad*2.0 

c = tf.constant(3.) 

s1 = tf.square(c) 
grad1 = tf.gradients(s1, c)[0] 

g = tf.get_default_graph() 
with g.gradient_override_map({"Square": "CustomSquare"}): 
    s2 = tf.square(c) 
    grad2 = tf.gradients(s2, c)[0] 

with tf.Session() as sess: 
    print(sess.run([c, s1, grad1])) 
    print(sess.run([c, s2, grad2])) 

결과는 놀라운 일이다 : 나는 두 번째 결과를 기다리고 있었다

[3.0, 9.0, 6.0] 
[3.0, 9.0, 2.0] 

[3.0, 9.0, 12.0] 될 수 있습니다. 내가 뭘 놓치고 있니?

감사합니다.

답변

1

즉, _custom_square_grad의 올바른 버전은 다음과 같아야합니다 코드를 이해하기 위해

@tf.RegisterGradient("CustomSquare")            
def _custom_square_grad(op, grad):            
    x = op.inputs[0]                
    return 2.0 * (grad * 2.0 * x) 

, 당신은 어떻게 gradient 작품을 알 필요가있다. tf.RegisterGradient을 정의하면 출력에서 ​​입력으로 그라디언트를 BACK-PROPAGATE합니다. tf.squre의 경우, 기본 그라데이션 기능은 다음과 같이이다 : 당신이 당신의 정의 그라데이션 기능에 그라데이션을 두 배로 원하기 때문에

# Given y = tf.square(x) => y' = 2x 
grad_x = grad_y * 2.0 * x 

, 당신은 단순히 grad_x = 2.0 * (grad_y * 2.0 * x)으로 변경할 수 있습니다.