글로벌 단계를 포함하여 여러 변수가 포함 된 모델이 있습니다. MonitoredSession을 사용하여 100 단계마다 검사 점 및 요약을 성공적으로 저장할 수있었습니다. Session이 다중 패스 (this 문서를 기반으로 함)로 실행될 때 MonitoredSession이 모든 변수를 자동으로 복원 할 것으로 예상했지만, 이는 발생하지 않습니다. 교육 세션을 다시 실행 한 후 글로벌 단계를 살펴 본다면, 0에서 다시 시작한다는 것을 알았습니다. 이것은 실제 모델이없는 간단한 코드입니다. 더 많은 코드가이 코드를 처음 실행하면TensorFlow : MonitoredSession에서 모델 복원
train_graph = tf.Graph()
with train_graph.as_default():
# I create some datasets using the Dataset API
# ...
global_step = tf.train.create_global_step()
# Create all the other variables and the model here
# ...
saver_hook = tf.train.CheckpointSaverHook(
checkpoint_dir='checkpoint/',
save_secs=None,
save_steps=100,
saver=tf.train.Saver(),
checkpoint_basename='model.ckpt',
scaffold=None)
summary_hook = tf.train.SummarySaverHook(
save_steps=100,
save_secs=None,
output_dir='summaries/',
summary_writer=None,
scaffold=None,
summary_op=train_step_summary)
num_steps_hook = tf.train.StopAtStepHook(num_steps=500) # Just for testing
with tf.train.MonitoredSession(
hooks=[saver_hook, summary_hook, num_steps_hook]) as sess:
while not sess.should_stop():
step = sess.run(global_step)
if (step % 100 == 0):
print(step)
sess.run(optimizer)
, 나는이 시점에서 검사 점 폴더 모든 백에 대한 체크 포인트를 가지고
0
100
200
300
400
이 출력을 얻을이 문제를 해결하기 위해 필요한 경우 알려줘 500 단계로 올라갑니다. 프로그램을 다시 실행하면 카운터 시작이 500으로 증가하고 900까지 증가 할 것으로 예상되지만, 그 대신에 같은 점을 얻고 모든 체크 포인트를 덮어 씁니다. 어떤 아이디어?