2016-10-20 5 views
3

Keras 및 TensorFlow에서 비동기 버전의 actor-critic을 구현하려고합니다. Keras를 네트워크 계층을 구축하기위한 프런트 엔드로 사용하고 있습니다 (매개 변수를 직접 tensorflow로 업데이트 중입니다). global_model과 하나의 주요 tensorflow 세션이 있습니다. 하지만 각 스레드 내에서 local_modelglobal_model에서 매개 변수를 복사 만드는 중입니다. 내 코드는 내가 작업이 동일한 그래프에 있어야 말하는 tf.assign 작업에 tensorflow 오류 다음 Keras 및 Tensorflow의 모델을 다중 스레드 설정으로 복제

UserWarning: The default TensorFlow graph is not the graph associated with the TensorFlow session currently registered with Keras, and as such Keras was not able to automatically initialize a variable. You should consider registering the proper session with Keras via K.set_session(sess)

Keras

에서 사용자 경고를이

def main(args): 
    config=tf.ConfigProto(log_device_placement=False,allow_soft_placement=True) 
    sess = tf.Session(config=config) 
    K.set_session(sess) # K is keras backend 
    global_model = ConvNetA3C(84,84,4,num_actions=3) 

    threads = [threading.Thread(target=a3c_thread, args=(i, sess, global_model)) for i in range(NUM_THREADS)] 

    for t in threads: 
     t.start() 

def a3c_thread(i, sess, global_model): 
    K.set_session(sess) # registering a session for each thread (don't know if it matters) 
    local_model = ConvNetA3C(84,84,4,num_actions=3) 
    sync = local_model.get_from(global_model) # I get the error here 

    #in the get_from function I do tf.assign(dest.params[i], src.params[i]) 

같은 것을 보인다.

ValueError: Tensor("conv1_W:0", shape=(8, 8, 4, 16), dtype=float32_ref, device=/device:CPU:0) must be from the same graph as Tensor("conv1_W:0", shape=(8, 8, 4, 16), dtype=float32_ref)

정확히 무엇이 잘못 될지 확신하지 못합니다. tf.get_default_graph() is sess.graphFalse을 반환하기 때문에 오류가 Keras에서 오는

감사

답변

5

. TF 문서에서 나는 tf.get_default_graph()이 현재 스레드에 대한 기본 그래프를 반환하고 있음을 알았습니다. 새 스레드를 시작하고 그래프를 만드는 순간, 해당 스레드와 관련된 별도의 그래프로 작성됩니다. 다음을 수행하여이 문제를 해결할 수 있습니다.

with sess.graph.as_default(): 
    local_model = ConvNetA3C(84,84,4,3)