5

내 텐션 흐름 버전은 0.11입니다. 나는 훈련 후에 그래프를 저장하거나 tensorflow가로드 할 수있는 다른 것을 저장하려고합니다.tensorflow에서 훈련 후 모델을 사용하는 방법 (그래프 저장 /로드)

I/난 이미이 게시물 읽어 내보내기를 사용하고 MetaGraph

가져 오기 : Tensorflow: how to save/restore a model?

Save.py 파일 :

X = tf.placeholder("float", [None, 28, 28, 1], name='X') 
Y = tf.placeholder("float", [None, 10], name='Y') 

tf.train.Saver() 
with tf.Session() as sess: 
    ...run something ... 
    final_tensor = tf.nn.softmax(py_x, name='final_result') 
    tf.add_to_collection("final_tensor", final_tensor) 

    predict_op = tf.argmax(py_x, 1) 
    tf.add_to_collection("predict_op", predict_op) 

saver.save(sess, 'my_project') 

가 그럼 난 실행 load.py :

with tf.Session() as sess: 
    new_saver = tf.train.import_meta_graph('my_project.meta') 
    new_saver.restore(sess, 'my_project') 
    predict_op = tf.get_collection("predict_op")[0] 
    for i in range(2): 
     test_indices = np.arange(len(teX)) # Get A Test Batch 
     np.random.shuffle(test_indices) 
     test_indices = test_indices[0:test_size] 

     print(i, np.mean(np.argmax(teY[test_indices], axis=1) == 
         sess.run(predict_op, feed_dict={"X:0": teX[test_indices], 
                 "p_keep_conv:0": 1.0, 
                 "p_keep_hidden:0": 1.0}))) 
,

하지만 오류 반환

Traceback (most recent call last): 
    File "load_05_convolution.py", line 62, in <module> 
    "p_keep_hidden:0": 1.0}))) 
    File "/home/khoa/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 717, in run 
    run_metadata_ptr) 
    File "/home/khoa/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 894, in _run 
    % (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape()))) 
ValueError: Cannot feed value of shape (256, 784) for Tensor u'X:0', which has shape '(?, 28, 28, 1)' 

난 정말 왜 몰라? 이 tf.add_to_collection 만 한 곳에서만 홀더를 포함하고 있기 때문에

Traceback (most recent call last): 
    File "load_05_convolution.py", line 46, in <module> 
    final_tensor = tf.get_collection("final_result")[0] 
IndexError: list index out of range 

이다 : 나는 final_tensor = tf.get_collection("final_result")[0]

그것은 또 다른 오류를 반환 추가하는 경우

? tf.train.write_graph

를 사용

II

는/I이 성공적으로 파일 'train.pb'

을 만들어 tf.train.write_graph(graph, 'folder', 'train.pb')

save.py의 끝에 다음 행을 추가 내 load.py :

with tf.gfile.FastGFile('folder/train.pb', 'rb') as f: 
    graph_def = tf.GraphDef() 
    graph_def.ParseFromString(f.read()) 
    _ = tf.import_graph_def(graph_def, name='') 

with tf.Session() as sess: 
    predict_op = sess.graph.get_tensor_by_name('predict_op:0') 
    for i in range(2): 
     test_indices = np.arange(len(teX)) # Get A Test Batch 
     np.random.shuffle(test_indices) 
     test_indices = test_indices[0:test_size] 

     print(i, np.mean(np.argmax(teY[test_indices], axis=1) == 
         sess.run(predict_op, feed_dict={"X:0": teX[test_indices], 
                 "p_keep_conv:0": 1.0, 
                 "p_keep_hidden:0": 1.0}))) 

그런 다음 오류를 반환 :

Traceback (most recent call last): 
    File "load_05_convolution.py", line 22, in <module> 
    graph_def.ParseFromString(f.read()) 
    File "/home/khoa/tensorflow/lib/python2.7/site-packages/google/protobuf/message.py", line 185, in ParseFromString 
    self.MergeFromString(serialized) 
    File "/home/khoa/tensorflow/lib/python2.7/site-packages/google/protobuf/internal/python_message.py", line 1085, in MergeFromString 
    raise message_mod.DecodeError('Unexpected end-group tag.') 
google.protobuf.message.DecodeError: Unexpected end-group tag. 

표준 방법, 코드 또는 튜토리얼을 공유하여 모델을 저장 /로드 하시겠습니까? 나는 정말로 혼란 스럽다. 합니다 (MetaGraph 사용)

+0

인가를? load.py의'new_saver.restore (sess, 'my_projec')' 경로를 올바르게 점검하십시오. –

+0

죄송합니다. 타이핑 할 때의 실수. 로드 중.py it 'tich_chap'하지만 'project'로 변경하면 이해하기 쉽습니다. –

+0

@AayushKumarSingha, 아이디어가 있습니까 –

답변

2

첫 번째 솔루션은 거의 작동하지만, 당신이 의 배치와 4-D 텐서로 MNIST 훈련 예제의 배치를 예상하는 tf.placeholder() MNIST 훈련 예를 평평하게 공급되기 때문에 오류가 발생한다 모양 batch_sizexheight (= 28) x width (= 28) x channels (= 1). 이를 해결하는 가장 쉬운 방법은 입력 데이터의 모양을 바꾸는 것입니다. 대신이 문 :

print(i, np.mean(np.argmax(teY[test_indices], axis=1) == 
       sess.run(predict_op, feed_dict={ 
        "X:0": teX[test_indices], 
        "p_keep_conv:0": 1.0, 
        "p_keep_hidden:0": 1.0}))) 

... 대신에, 적절하게 입력 데이터를 고쳐 다음과 같은 성명, 시도 : 그것은 오타

print(i, np.mean(np.argmax(teY[test_indices], axis=1) == 
       sess.run(predict_op, feed_dict={ 
        "X:0": teX[test_indices].reshape(-1, 28, 28, 1), 
        "p_keep_conv:0": 1.0, 
        "p_keep_hidden:0": 1.0}))) 
+0

정말 작동하지 않습니다. –

+0

@ZHANGJuenjie 더 구체적으로 할 수 있습니까? 같은 코드를 실행하고 오류가 발생하려고합니까? 그렇다면 어느 것입니까? – mrry

관련 문제