tf.case
(https://www.tensorflow.org/api_docs/python/tf/case)을 사용하여 Tensor를 조건부로 업데이트하려고합니다. 표시된 바와 같이 global_step == 2
일 때 learning_rate
을 0.01
으로 업데이트하고 global_step == 4
일 때 0.001
으로 업데이트하려고합니다.Tensorflow : tf.case가 왜 나에게 잘못된 결과를주는 이유는 무엇입니까?
그러나 global_step == 2
일 때 이미 learning_rate = 0.001
이 표시됩니다. 추가 검사시 global_step == 2
(0.01
대신 0.001
이 표시됨)을 입력하면 tf.case
이 잘못된 결과를 표시하는 것 같습니다. 0.01
의 술어가 True로 평가되고 0.001
의 술어가 False로 평가되는 경우에도 이러한 상황이 발생합니다.
제가 잘못했거나 버그입니까?
TF 버전 : 1.0.0
코드 :
import tensorflow as tf
global_step = tf.Variable(0, dtype=tf.int64)
train_op = tf.assign(global_step, global_step + 1)
learning_rate = tf.Variable(0.1, dtype=tf.float32, name='learning_rate')
# Update the learning_rate tensor conditionally
# When global_step == 2, update to 0.01
# When global_step == 4, update to 0.001
cases = []
case_tensors = []
for step, new_rate in [(2, 0.01), (4, 0.001)]:
pred = tf.equal(global_step, step)
fn_tensor = tf.constant(new_rate, dtype=tf.float32)
cases.append((pred, lambda: fn_tensor))
case_tensors.append((pred, fn_tensor))
update = tf.case(cases, default=lambda: learning_rate)
updated_learning_rate = tf.assign(learning_rate, update)
print tf.__version__
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for _ in xrange(6):
print sess.run([global_step, case_tensors, update, updated_learning_rate])
sess.run(train_op)
결과 : 이것은에 대답했다
1.0.0
[0, [(False, 0.0099999998), (False, 0.001)], 0.1, 0.1]
[1, [(False, 0.0099999998), (False, 0.001)], 0.1, 0.1]
[2, [(True, 0.0099999998), (False, 0.001)], 0.001, 0.001]
[3, [(False, 0.0099999998), (False, 0.001)], 0.001, 0.001]
[4, [(False, 0.0099999998), (True, 0.001)], 0.001, 0.001]
[5, [(False, 0.0099999998), (False, 0.001)], 0.001, 0.001]