2017-04-14 4 views
1

Tensorflow 모델을 복원하려고합니다. 선형 회귀 네트워크입니다. 내 예측이 좋지 않아 뭔가 잘못하고 있다고 확신합니다. 내가 훈련 할 때, 나는 시험 세트를 가지고있다. 내 테스트 세트 예측은 훌륭해 보이지만 같은 모델을 복원하려고하면 예측이 좋지 않습니다. 여기Tensorflow 복원 모델

with tf.Session() as sess: 
    saver = tf.train.Saver() 
    init = tf.global_variables_initializer() 
    sess.run(init) 
    training_data, ground_truth = d.get_training_data() 
    testing_data, testing_ground_truth = d.get_testing_data() 

    for iteration in range(config["training_iterations"]): 
     start_pos = np.random.randint(len(training_data) - config["batch_size"]) 
     batch_x = training_data[start_pos:start_pos+config["batch_size"],:,:] 
     batch_y = ground_truth[start_pos:start_pos+config["batch_size"]] 
     sess.run(optimizer, feed_dict={x: batch_x, y: batch_y}) 
     train_acc, train_loss = sess.run([accuracy, cost], feed_dict={x: batch_x, y: batch_y}) 

     sess.run(optimizer, feed_dict={x: testing_data, y: testing_ground_truth}) 
     test_acc, test_loss = sess.run([accuracy, cost], feed_dict={x: testing_data, y: testing_ground_truth}) 
     samples = sess.run(pred, feed_dict={x: testing_data}) 
     # print samples 
     data.compute_acc(samples, testing_ground_truth) 

     print("Training\tAcc: {}\tLoss: {}".format(train_acc, train_loss)) 
     print("Testing\t\tAcc: {}\tLoss: {}".format(test_acc, test_loss)) 
     print("Iteration: {}".format(iteration)) 

     if iteration % config["save_step"] == 0: 
      saver.save(sess, config["save_model_path"]+str(iteration)+".ckpt") 

내 테스트 세트에서 몇 가지 예입니다 : 여기

내가 모델을 저장하는 방법입니다. 여기 그런

My prediction: -12.705 Actual : -10.0 
My prediction: 0.000 Actual : 8.0 
My prediction: -14.313 Actual : -23.0 
My prediction: 17.879 Actual : 13.0 
My prediction: 17.452 Actual : 24.0 
My prediction: 22.886 Actual : 29.0 
Custom accuracy: 5.0159861487 
Training Acc: 5.63836860657 Loss: 25.6545143127 
Testing  Acc: 4.238052845 Loss: 22.2736053467 
Iteration: 6297 

을 내가 모델 복원 방법은 다음과 같습니다 : 당신은 My predictionActual에 상대적으로 가까운 알 수 있습니다 예측 모양을의 여기

with tf.Session() as sess: 
    saver = tf.train.Saver() 
    saver.restore(sess, config["retore_model_path"]+"3000.ckpt") 

    init = tf.global_variables_initializer() 
    sess.run(init) 

    pred = sess.run(pred, feed_dict={x: predict_data})[0] 
    print("Prediction: {:.3f}\tGround truth: {:.3f}".format(pred, ground_truth)) 

을하지만.

Python 2.7.6 (default, Oct 26 2016, 20:30:19) 
[GCC 4.8.4] on linux2 
Type "help", "copyright", "credits" or "license" for more information. 
>>> import tensorflow as tf 
>>> print(tf.__version__) 
0.12.0-rc1 

확실하지 않음이 도움이된다면,하지만 난을 배치하려고 : (예 나는 업데이트해야합니다) 다음

Prediction: 0.355  Ground truth: -22.000 
Prediction: -0.035  Ground truth: 3.000 
Prediction: -1.005  Ground truth: -3.000 
Prediction: -0.184  Ground truth: 1.000 
Prediction: 1.300  Ground truth: 5.000 
Prediction: 0.133  Ground truth: -5.000 

내 tensorflow 버전입니다 : 당신은 Prediction은 항상 옳다 주위 0 인 것을 알 수 있습니다 saver.restore() sess.run(init) 다음에 전화하면 예상되는 결과가 모두 같습니다. 나는 sess.run(init)이 변수를 초기화하기 때문이라고 생각합니다. 주문과 같이

변경 :

sess.run(init) 
saver.restore(sess, config["retore_model_path"]+"6000.ckpt") 

그러나 예측은 다음과 같이 : 당신은 검사 점에서 복원 할 때

Prediction: -15.840  Ground truth: 2.000 
Prediction: -15.840  Ground truth: -7.000 
Prediction: -0.000  Ground truth: 12.000 
Prediction: -15.840  Ground truth: -9.000 
Prediction: -15.175  Ground truth: -27.000 

답변

2

, 당신은 당신의 변수를 초기화하지 않습니다. 당신이 질문 끝에 말했듯이.

init = tf.global_variables_initializer() 
sess.run(init) 

방금 ​​복원 한 변수를 덮어 씁니다. 죄송합니다. :)

댓글 두 라인을 밖으로 나와 당신이 갈 수있을 거라 생각합니다.