2017-04-07 6 views
1

기계 학습 및 Tensorflow의 초보자이며 예제 튜토리얼 소스 코드를 사용하여 모델이 교육 받고 정확도가 인쇄되지만 모델을 내보내는 소스 코드는 포함되어 있지 않습니다. 변수 및 가져 오기 새 이미지를 예측합니다.복원 된 교육 모델의 Tensorflow 보고서 오류

그래서 모델을 내보내고 테스트 데이터 세트를 사용하여 예측할 새 파이썬 스크립트를 만들 소스 코드를 수정했습니다. 여기

이 교육 모델을 내보낼 수있는 소스 코드 : 새로운 파이썬 스크립트에서

mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True) 
print("run here3") 
# Create the model 
x = tf.placeholder(tf.float32, [None, 784], name="x") 
W = tf.Variable(tf.zeros([784, 10]), name="W") 
b = tf.Variable(tf.zeros([10])) 
y = tf.matmul(x, W) + b 
saver = tf.train.Saver() 
sess = tf.InteractiveSession() 
... ignore the source code for the cost function definition and train the model 
#after the model get trained, save the variables and y 
tf.add_to_collection('W', W) 
tf.add_to_collection('b', b) 
tf.add_to_collection('y', y) 

saver.save(sess, 'result') 

을, 나는 Y 기능을

sess = tf.Session() 
saver = tf.train.import_meta_graph('result.meta') 
saver.restore(sess, tf.train.latest_checkpoint('./')) 
W = tf.get_collection('W')[0] 
b = tf.get_collection('b')[0] 
y = tf.get_collection('y')[0] 


mnist = input_data.read_data_sets('/tmp/tensorflow/mnist/input_data', one_hot=True) 
img = mnist.test.images[0] 
x = tf.placeholder(tf.float32, [None, 784]) 
sess.run(y, feed_dict={x: mnist.test.images}) 

모든 작품을 모델을 복원하고 다시 실행하려고 올바르게, 나는 그들을 인쇄 할 경우 W 및 b 값을 얻을 수 있지만 마지막 문 (실행 y 함수)을 실행하는 동안 오류가 발생합니다.

Caused by op u'x', defined at: 
File "predict.py", line 58, in <module> 
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) 
File "/Users/zhouqi/git/machine-learning/tensorflow/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 44, in run 
_sys.exit(main(_sys.argv[:1] + flags_passthrough)) 
File "predict.py", line 25, in main 
saver = tf.train.import_meta_graph('result.meta') 
File "/Users/zhouqi/git/machine-learning/tensorflow/lib/python2.7/site- packages/tensorflow/python/training/saver.py", line 1566, in import_meta_graph 
**kwargs) 
File "/Users/zhouqi/git/machine-learning/tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/meta_graph.py", line 498, in import_scoped_meta_graph 
producer_op_list=producer_op_list) 
File "/Users/zhouqi/git/machine-learning/tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/importer.py", line 288, in import_graph_def 
op_def=op_def) 
File "/Users/zhouqi/git/machine-learning/tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2327, in create_op 
original_op=self._default_original_op, op_def=op_def) 
File "/Users/zhouqi/git/machine-learning/tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1226, in __init__ 
self._traceback = _extract_stack() 

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'x' with dtype float 
[[Node: x = Placeholder[dtype=DT_FLOAT, shape=[], _device="/job:localhost/replica:0/task:0/cpu:0"]()]] 

그것의 나는 X를 정의하고, Y 기능을 실행하는 동안 작동하지 않는 이유, 나도 몰라 동일한 방법을 사용하여 x를 공급하는 동일한 문을 사용 할 이상한 원인?

+0

왜 mnist = input_data.read_data_sets ('/ tmp/tensorflow/mnist/input_data', one_hot = True)와 함께 'mnist = input_data.read_data_sets (FLAGS.data_dir, one_hot = True) ? – tagoma

+0

아, 모델을 복원하기 위해 만든 새 스크립트의 경우 하드 코드 된 데이터 폴더를 사용하여 단순화하고 FLAGS.data_dir은/tmp/tensorflow/mnist/input_data와 동일합니다. – mailme365

답변

1

문제는 새로운 자리 표시 자입니다 :

x = tf.placeholder(tf.float32, [None, 784]) 

같은 이름의 자리를 만드는 것은 충분하지 않습니다. 실제로 모델을 만들 때 사용한 것과 동일한 자리 표시자가 필요합니다.

tf.add_to_collection('x', x) 

를 새 파일에로드 : 따라서 당신은 또한 첫 번째 파일의 컬렉션에 X를 추가해야합니다

x = tf.get_collection('x')[0] 

을 대신 새로운 하나를 만드는.

관련 문제