2017-10-15 1 views
1

값 비싼 연산에 조건부가있는 경우 지연된 동작 (예 : 선택된 분기 만 평가할 수 있음)이 필요할 수 있습니다.Tensorflow에서 게으른 조건문을 수행하는 경우

다음 작품, 그리고 게으른 :

>>> a. tf.zeros(0) 
>>> tf.cond(tf.equal(tf.size(a), tf.constant(0)), lambda: tf.constant(-1, dtype=tf.int64), lambda: tf.argmax(a)).eval() 
-1 

argmax 평가되지 않기 때문에이 오류가 발생할 것이기 때문에 당신은, 그것은 게으른 것을 볼 수 있습니다. argmax를 취한 텐서가 비어 있기 때문입니다.

>>> am = tf.argmax(a) 
>>> tf.cond(tf.equal(tf.size(a), tf.constant(0)), lambda: tf.constant(-1, dtype=tf.int64), lambda: tf.add(am, 1)).eval() 
... Reduction axis 0 is empty in shape [0] 

tf.add 작동에 의한되지 않은 : 당신은 람다 밖으로 argmax를 이동하는 경우, 그것은 바로이 오류를 얻을 수 있습니다. 인라인으로 이동하면 다시 작동합니다.

>>> tf.cond(tf.equal(tf.size(a), tf.constant(0)), lambda: tf.constant(-1, dtype=tf.int64), lambda: tf.add(tf.argmax(a), 1)).eval() 
-1 

그런 다음 질문은 클리너 방식으로 게으른 조건문을 수행하는 방법입니까?

답변

1

조건부 함수가 길어지면 위의 방법이 약간 엉망이됩니다. 당신이 할 수있는 것은 조건부 밖에서 람다 식을 정의하는 것입니다. 다음은 파이썬 대화 형 REPL에서 작동하지 않습니다. 결과는 ValueError: Operation 'cond_14/Merge' has been marked as not fetchable.입니다.

코드를 파이썬 파일에 넣고 정상적인 방법으로 실행하면이 작동합니다.

import tensorflow as tf 

sess = tf.InteractiveSession() 

a = tf.zeros(0) 
fn = lambda: tf.argmax(a) 

res = tf.cond(
    tf.equal(tf.size(a), tf.constant(0)), 
    lambda: tf.constant(-1, dtype=tf.int64), 
    fn 
    ).eval() 
print(res) 

res2 = tf.cond(
    tf.equal(tf.size(a), tf.constant(0)), 
    lambda: tf.constant(-1, dtype=tf.int64), 
    lambda: tf.add(fn(), tf.constant(1, dtype=tf.int64)) 
    ).eval() 
print(res2) 
# Output: 
# -1 
# -1 
관련 문제