2017-12-09 5 views
1

Tensorflow에서 필자의 모델은 사전 훈련 된 모델을 기반으로하고 있으며 몇 가지 변수를 추가하고 사전 훈련 된 모델에서 일부 모델을 제거했습니다. 검사 점 파일에서 변수를 복원 할 때 제외해야하는 그래프에 추가 한 모든 변수를 명시 적으로 지정해야합니다. 예를 들어, 내가했다Tensorflow에서 검사 점의 변수 만 복원하는 방법?

exclude = # explicitly list all variables to exclude 
variables_to_restore = slim.get_variables_to_restore(exclude=exclude) 
saver = tf.train.Saver(variables_to_restore) 

더 쉬운 방법이 있습니까? 즉, 변수가 검사 점에 있지 않으면 복원을 시도하지 마십시오.

답변

1

먼저 수행 할 수있는 작업은 검사 점에서와 동일한 모델을 사용하고 두 번째로 검사 점 값을 동일한 모델로 복원하는 것입니다. 동일한 모델에 대한 변수를 복원 한 후 새 도면층을 추가하거나 기존 도면층을 삭제하거나 도면층의 두께를 변경할 수 있습니다.

그러나주의해야 할 중요한 사항이 있습니다. 새 레이어를 추가 한 후에는 새 레이어를 초기화해야합니다. tf.global_variables_initializer()을 사용하면 다시로드 된 레이어의 값이 손실됩니다. 따라서 초기화되지 않은 가중치 만 초기화해야합니다.이 경우 다음 함수를 사용할 수 있습니다.

def initialize_uninitialized(sess): 
    global_vars = tf.global_variables() 
    is_not_initialized = sess.run([tf.is_variable_initialized(var) for var in global_vars]) 
    not_initialized_vars = [v for (v, f) in zip(global_vars, is_not_initialized) if not f] 

    # for i in not_initialized_vars: # only for testing 
    # print(i.name) 

    if len(not_initialized_vars): 
     sess.run(tf.variables_initializer(not_initialized_vars)) 
+0

OK 들으! 그러나 처음에'sess.run (tf.global_variables_initializer())'를 실행하고 그래프를 복원하면 어떨까요? 이 방법은 내가 처음에는 모든 것을 초기화한다고 가정하지만 복원은 검사 점에있는 초기화 된 일부 변수를 덮어 씁니다. 잘못 됐나? – user131379

+0

정확하고 정확하게해야 할 일입니다. – eneski

0

는 먼저 (그래프도 의미) 유용하며, 다음의 체크 포인트가 아니라 모두에서 두 개의 교차 공동 세트를 추가 모든 변수를 찾아야한다.

variables_can_be_restored = list(set(tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)).intersection(tf.train.list_variables(checkpoint_dir))) 

다음이 같은 보호기를 정의한 후 복원 :

temp_saver = tf.train.Saver(variables_can_be_restored) 
ckpt_state = tf.train.get_checkpoint_state(checkpoint_dir, lastest_filename) 
print('Loading checkpoint %s' % ckpt_state.model_checkpoint_path) 
temp_saver.restore(sess, ckpt_state.model_checkpoint_path) 
관련 문제