2017-01-12 1 views
2

이것은 사실 여기에있는 질문이 아닙니다 ... CNTK python api - continue training a model 그들은 관련되어 있지만 동일하지는 않습니다.CNTK python api - 분류 자 ​​훈련 계속

나는 1500 신기원의 모델을 훈련 받았고 평균 67 % 정도의 손실이 발생했습니다. 나는 그 다음 내가 코딩이있는, 훈련을 계속하려면 :

def Create_Trainer(train_reader, minibatch_size, epoch_size, checkpoint_path=None, distributed_after=INFINITE_SAMPLES): 
#Create Model with Params 
lr_per_minibatch = learning_rate_schedule(
    [0.01] * 10 + [0.003] * 10 + [0.001], UnitType.minibatch, epoch_size) 
momentum_time_constant = momentum_as_time_constant_schedule(
    -minibatch_size/np.log(0.9)) 
l2_reg_weight = 0.0001 
input_var = input_variable((num_channels, image_height, image_width)) 
label_var = input_variable((num_classes)) 
feature_scale = 1.0/256.0 
input_var_norm = element_times(feature_scale, input_var) 
z = create_model(input_var_norm, num_classes) 
#Create Error Functions 
if(checkpoint_path): 
    print('Loaded Checkpoint!') 
    z.load_model(checkpoint_path) 
ce = cross_entropy_with_softmax(z, label_var) 
pe = classification_error(z, label_var)  

#Create Learner  
learner = momentum_sgd(z.parameters, 
         lr=lr_per_minibatch, momentum=momentum_time_constant, 
         l2_regularization_weight=l2_reg_weight) 
if(distributed_after != INFINITE_SAMPLES): 
    learner = distributed.data_parallel_distributed_learner(
     learner = learner, 
     num_quantization_bits = 1, 
     distributed_after = distributed_after 
    ) 
input_map = { 
    input_var: train_reader.streams.features, 
    label_var: train_reader.streams.labels 
} 
return Trainer(z, ce, pe, learner), input_map 

통지 코드의 라인 : 경우 (checkpoint_path) : 약 반쯤합니다. 나는이 기능을 통해 저장 이전 훈련에서 .dnn 파일을로드

...

if current_epoch % checkpoint_frequency == 0: 
      trainer.save_checkpoint(os.path.join(checkpoint_path + "_{}.dnn".format(current_epoch))) 

이 실제로 .dnn과 .dnn.ckp 파일을 생성합니다. 분명히 load_model에서 .dnn 파일 만로드합니다.

교육을 다시 시작하고 모델을로드하면 마치 네트워크 아키텍처를로드하는 것처럼 보이지만 가중치가 아닌 것 같습니다. 이 작업을 수행하는 올바른 방법은 무엇입니까?

고맙습니다!

답변

4

대신 trainer.restore_from_checkpoint를 사용해야합니다. 트레이너와 학습자를 다시 만들어야합니다.

곧 트레이너/미니 배너/분산 상태를 돌보는 쉬운 방법으로 원활한 복원을 가능하게하는 교육 세션이 될 것입니다.

중요한 점 : python 스크립트에서 노드를 만드는 네트워크 구조와 순서는 검사 점을 만들 때와 그 시점에서 복원 할 때 동일해야합니다.

관련 문제