2017-10-25 6 views
0

tf.scatter_add의 매우 이상한 동작을 발견했습니다. tf.Variable 안에 랩핑 된 Tensor를 작성하는 tf.while_loop을 작성했습니다.tf.scatter_add가 루프에서 오류를 발생시킵니다.

루프 밖의 변수에 뭔가를 추가하지 않으면 tensorflow가 변수가 변경 가능하지 않다는 오류 메시지를 표시합니다.

import tensorflow as tf   

m = 25 
batch_num = 32 
num_bus = 50 

C = tf.zeros((m, batch_num, num_bus, m),tf.float64) 
C = tf.Variable(C) 

c = tf.ones((batch_num, num_bus, m), tf.float64) 
#C = tf.scatter_add(C,0,c) 

k = tf.constant(1) 

stop_cond = lambda k,C: k<m 

def construct_C(k, C): 
    upd_c = c+1 
    C = tf.scatter_add(C,k,upd_c) 
    return k+1,C 

k,C = tf.while_loop(stop_cond,construct_C, (k,C)) 

sess = tf.Session() 
sess.run(tf.global_variables_initializer()) 
C1 = sess.run(C) 

이 코드는 오류가 발생합니다 : TypeError: 'ScatterAdd' Op requires that input 'ref' be a mutable tensor (e.g.: a tf.Variable) 여기

은 MWE이다. 그러나 C = tf.scatter_add(C,0,c)의 모든 주석을 제거하면 잘 작동합니다.

이것은 의도 된 것입니까? 내가 도대체 ​​뭘 잘못하고있는 겁니까?

+0

나는 C가 불변의 텐서로 변하는 tf.while_loop일지도 모른다고 생각하지만 너무 확신 할 수 없다. 대신에 파이썬 루프를 사용해보십시오. –

답변

1

일부 while_loop 프리미티브와 같은 사운드는 변수에 대해 알지 못합니다 (대신 ref 유형 인 Tensors에 대해 알고 있습니다). 이것은 코드의 버그처럼 보입니다. github에 문제를 제기하십시오.

관련 문제