2017-10-31 4 views
1

새로운 TF Estimator API으로 전송 학습/마지막 계층 재교육을 사용하는 방법을 파악할 수 없었습니다.TensorFlow Estimators로 학습/재교육

Estimatordocumentation에 정의 된대로 네트워크 아키텍처와 교육 및 평가 작업을 포함하는 model_fn이 필요합니다. CNN 아키텍처를 사용하는 model_fn의 예는 here입니다.

예를 들어 입력 아키텍처의 마지막 계층을 다시 테스트하려면이 전체 모델을이 model_fn에 지정해야하는지 여부를 잘 모르는 경우 사전 훈련 된 가중치를로드하거나 '전통적인'방식 (예 : here)에서 수행 된 것처럼 저장된 그래프를 사용하는 방법입니다.

이것은 issue으로 제기되었지만 아직 열려 있으며 답변이 분명하지 않습니다.

답변

2

모델 정의 중에 메타 문을로드하고 SessionRunHook을 사용하여 ckpt 파일에서 가중치를로드 할 수 있습니다.

def model(features, labels, mode, params): 
    # Create the graph here 

    return tf.estimator.EstimatorSpec(mode, 
      predictions, 
      loss, 
      train_op, 
      training_hooks=[RestoreHook()]) 

SessionRunHook은 다음과 같습니다

class RestoreHook(tf.train.SessionRunHook): 

    def after_create_session(self, session, coord=None): 
     if session.run(tf.train.get_or_create_global_step()) == 0: 
      # load weights here 

이 방법, 가중치는 첫 번째 단계에서로드 및 모델 체크 포인트에서 훈련 기간 동안 저장됩니다.