2017-04-19 3 views
1

언어 모델링을위한 기본 lstm 코드를 실행하고 있습니다. 하지만 난하고 싶지 않아. BPTT. 나는 stateLSTMStateTuple이다, 그러나 tf.stop_gradient(state)tensorflow에서 LSTMStateTuple의 그래디언트를 중지하는 방법

with tf.variable_scope("RNN"): 
    for time_step in range(N): 
    if time_step > 0: tf.get_variable_scope().reuse_variables() 
    (cell_output, state) = cell(inputs[:, time_step, :], state) 

같은 것을하고 싶은, 그래서 시도 :

for lli in range(len(state)): 
    print(state[lli].c, state[lli].h) 
    state[lli].c = tf.stop_gradient(state[lli].c) 
    state[lli].h = tf.stop_gradient(state[lli].h) 

그러나 나는 AttributeError: can't set attribute 오류가있어 : 나는 또한 사용하려고

File "/home/liyu-iri/IRRNNL/word-rnn/ptb/models/decoupling.py", line 182, in __init__ 
state[lli].c = tf.stop_gradient(state[lli].c) 
AttributeError: can't set attribute 

tf.assign이지만 state[lli].c은 변수가 아닙니다.

그래, 어떻게 그라데이션을 멈출 수 있습니까? LSTMStateTuple? 아니면 BPTT를 중단 할 수 있습니까? 나는 단 하나 액자에서 BP를하고 싶다.

고맙습니다.

답변

0

이것은 순수한 python 질문이라고 생각합니다. LSTMStateTuple은 collections.namedtuple이며 python에서는 다른 튜플과 같이 요소를 할당 할 수 없습니다. 해결책은 완전히 새로운 것을 만드는 것입니다. stopped_state = LSTMStateTuple(tf.stop_gradient(old_tuple.c), tf.stop_gradient(old_tuple.h))에서와 같이 사용하고이 상태 (또는 그 목록)를 상태로 사용하십시오. 기존 튜플을 바꾸라고한다면, namedtuple에는 _replace 메서드가 있다고 생각합니다. old_tuple._replace(c=tf.stop_gradient(...))과 같이 here을 참조하십시오. 희망이 도움이됩니다!

+0

대단히 감사합니다! 그게 많은 도움이됩니다! –

관련 문제