2016-09-06 4 views
0

저는 강화 학습을 위해 노력하고 있으며 학습 속도를 높이기 위해 학습 중에 sess.run()을 통해 피드하는 데이터의 양을 줄이고 싶습니다.TensorFlow : 그래프 내 LSTM 상태 저장/업데이트

CurrentStateOption = tf.Variable(0, trainable=False, name='SavedState') 
with tf.name_scope("LSTMLayer") as scope: 
     initializer = tf.random_uniform_initializer(-.1, .1) 
     lstm_cell_L1 = tf.nn.rnn_cell.LSTMCell(self.input_sizes, forget_bias=1.0, initializer=initializer, state_is_tuple=True) 
     self.cell_L1 = tf.nn.rnn_cell.MultiRNNCell([lstm_cell_L1] *self.NumberLSTMLayers, state_is_tuple=True) 
     self.state = self.cell_L1.zero_state(1,tf.float64) 

     self.SavedState = self.cell_L1.zero_state(1,tf.float64) #tf.Variable(state, trainable=False, name='SavedState') 

     #SaveCond = tf.cond(tf.equal(CurrentStateOption,tf.constant(1)), self.SaveState, self.SameState) 
     #RestoreCond = tf.cond(tf.equal(CurrentStateOption,tf.constant(-1)), self.RestoreState, self.SameState) 
     #ZeroCond = tf.cond(tf.less(CurrentStateOption,tf.constant(-1)), self.ZeroState, self.SameState) 

     self.state = tf.case({tf.equal(CurrentStateOption,tf.constant(1)): self.SaveState, tf.equal(CurrentStateOption,tf.constant(-1)): self.RestoreState, 
      tf.less(CurrentStateOption,tf.constant(-1)): self.ZeroState}, default=self.SameState, exclusive=True) 

     RunConditions = tf.group([SaveCond, RestoreCond, ZeroCond]) 

     self.Xinputs = [tf.concat(1,[Xinputs])] 

     outputs, stateFINAL_L1 = rnn.rnn(self.cell_L1,self.Xinputs, initial_state=self.state, dtype=tf.float32) 
: I가 LSTM 내로 기대 적절한 Q 값을 찾기 위해 재설정 할 필요성 찾던

, 난 tf.case()와 함께 이와 같은 용액을 제작

def RestoreState(self): 
    #self.state = self.state.assign(self.SavedState) 
    self.state = self.SavedState 
    return self.state 
def ZeroState(self): 
    self.state = self.cell_L1.zero_state(1,tf.float64) 
    return self.state 
def SaveState(self): 
    #self.SavedState = self.SavedState.assign(self.state) 
    self.SavedState = self.state 
    return self.SavedState 
def SameState(self): 
    return self.state 

이것은 내가 승 할 무엇 LSTM 그래프를 지시하기 위해 INT를 공급할 수 지금처럼 개념에서 잘 작동하는 것 같다 그 상태. "1"을 통과하면 실행 전에 상태를 저장하고 "-1"을 전달하면 마지막 저장된 상태를 복원합니다. "< -1"을 전달하면 상태가 0이됩니다. "0"이면 마지막 실행 (추론)에서 LSTM에있는 것을 사용합니다. 몇 가지 접근법을 시도해 보았습니다. 더 간단한 tf.cond() 접근법을 포함 시켰습니다.

내가 생각하는 문제는 텐서가 필요한 tf.case() 연산에 기인하지만 LSTM 상태는 튜플입니다 (비 튜플은 가치 하락 될 것입니다). 이것은 그래프 변수에 값을 tf.assign()하려고 할 때 분명 해졌다.

내 최종 목표는 그래프 내에 "상태"를 남기고 INT를 전달하여 상태로 수행 할 작업을 지시합니다. 미래에 나는 다양한 look-back을위한 여러 "store"위치를 갖고 싶습니다.

tf.case() 유형을 처리하는 방법을 튜플 대 텐서로 처리 하시겠습니까?

답변

0

튜플이 단지 파이썬 튜플이기 때문에 상태 튜플에 요소 당 하나의 tf.case()가 있어야한다고 생각합니다.

관련 문제