2017-12-21 4 views
0

글로벌 단계를 포함하여 여러 변수가 포함 된 모델이 있습니다. MonitoredSession을 사용하여 100 단계마다 검사 점 및 요약을 성공적으로 저장할 수있었습니다. Session이 다중 패스 (this 문서를 기반으로 함)로 실행될 때 MonitoredSession이 모든 변수를 자동으로 복원 할 것으로 예상했지만, 이는 발생하지 않습니다. 교육 세션을 다시 실행 한 후 글로벌 단계를 살펴 본다면, 0에서 다시 시작한다는 것을 알았습니다. 이것은 실제 모델이없는 간단한 코드입니다. 더 많은 코드가이 코드를 처음 실행하면TensorFlow : MonitoredSession에서 모델 복원

train_graph = tf.Graph() 
with train_graph.as_default(): 
    # I create some datasets using the Dataset API 
    # ... 

    global_step = tf.train.create_global_step() 

    # Create all the other variables and the model here 
    # ... 

    saver_hook = tf.train.CheckpointSaverHook(
     checkpoint_dir='checkpoint/', 
     save_secs=None, 
     save_steps=100, 
     saver=tf.train.Saver(), 
     checkpoint_basename='model.ckpt', 
     scaffold=None) 
    summary_hook = tf.train.SummarySaverHook(
     save_steps=100, 
     save_secs=None, 
     output_dir='summaries/', 
     summary_writer=None, 
     scaffold=None, 
     summary_op=train_step_summary) 
    num_steps_hook = tf.train.StopAtStepHook(num_steps=500) # Just for testing 


    with tf.train.MonitoredSession(
     hooks=[saver_hook, summary_hook, num_steps_hook]) as sess: 
    while not sess.should_stop(): 
     step = sess.run(global_step) 
     if (step % 100 == 0): 
     print(step) 
     sess.run(optimizer) 

, 나는이 시점에서 검사 점 폴더 모든 백에 대한 체크 포인트를 가지고

0 
100 
200 
300 
400 

이 출력을 얻을이 문제를 해결하기 위해 필요한 경우 알려줘 500 단계로 올라갑니다. 프로그램을 다시 실행하면 카운터 시작이 500으로 증가하고 900까지 증가 할 것으로 예상되지만, 그 대신에 같은 점을 얻고 모든 체크 포인트를 덮어 씁니다. 어떤 아이디어?

답변

0

알았어, 알아 냈어. 사실 아주 간단했습니다. 첫째, MonitoredSession() 대신 MonitoredTraningSession()을 사용하는 것이 더 쉽습니다. 이 랩퍼 세션은 'checkpoint_dir'인수로 사용됩니다. saver_hook이 복원 작업을 처리 할 것이라고 생각했지만 그렇지 않습니다. 내 문제를 해결하기 위해 난 그냥과 같이 세션을 정의하는 라인을 변경했다 :

with tf.train.MonitoredTrainingSession(hooks=[saver_hook, summary_hook], checkpoint_dir='checkpoint'): 

또한 직접 MonitoredSession 수행 할 수 있습니다,하지만 당신은 대신 session_creator을 설정해야합니다.