그래프 및 변수 값을 포함하여 Tensorflow (0.12.0) 모델을 저장 한 다음 나중에로드하고 실행하려고합니다. 나는 이것에 대한 문서 및 기타 게시물을 읽었지만 작동하도록 기본을 얻을 수 없습니다. 나는 this page in the Tensorflow docs에서 기술을 사용하고있다. 코드 :Tensorflow 모델 저장 및로드
저장 간단한 모델 :
myVar = tf.Variable(7.1)
tf.add_to_collection('modelVariables', myVar) # why?
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init_op)
print sess.run(myVar)
saver0 = tf.train.Saver()
saver0.save(sess, './myModel.ckpt')
saver0.export_meta_graph('./myModel.meta')
나중에 부하 모델을 실행합니다
with tf.Session() as sess:
saver1 = tf.train.import_meta_graph('./myModel.meta')
saver1.restore(sess, './myModel.meta')
print sess.run(myVar)
질문 1 : 구원의 코드가 보인다 작동 할 수 있지만, 로딩 코드는이를 생산 오류 :
W tensorflow/core/util/tensor_slice_reader.cc:95] Could not open ./myModel.meta: Data loss: not an sstable (bad magic number): perhaps your file is in a different file format and you need to use a different restore operator?
해결 방법.
질문 2 : 나는 TF의 문서에서 패턴을 따르지이 줄을 포함 ...
tf.add_to_collection('modelVariables', myVar)
...하지만 왜 그 라인이 필요하다?
expert_meta_graph
전체 그래프를 기본적으로 내보내지 않습니까? 그렇지 않다면 저장하기 전에 그래프의 모든 변수를 콜렉션에 추가해야합니까? 또는 복원 후에 액세스 할 변수를 콜렉션에 추가하기 만합니까?
---------------------- 업데이트 1 월 12 일 2017 ------------------ -----------
아래의 Kashyap의 제안을 바탕으로 부분 성공했지만 여전히 신비가 존재합니다. 아래 코드는 tf.add_to_collection
및 tf.get_collection
이 포함 된 행을 포함하는 경우에만 이지만을 처리합니다. 이 줄이 없으면 '로드'모드에서 마지막 줄에 오류가 발생합니다. NameError: name 'myVar' is not defined
. 기본적으로 Saver.save
은 그래프의 모든 변수를 저장하고 복원하므로 컬렉션에 사용될 변수의 이름을 지정해야하는 이유는 무엇입니까? 이것은 Tensorflow의 변수 이름을 파이썬 이름으로 매핑하는 것과 관련이 있다고 가정합니다. 그러나 여기 게임의 규칙은 무엇입니까? 이것이 필요한 변수는 무엇입니까?
mode = 'load' # or 'save'
if mode == 'save':
myVar = tf.Variable(7.1)
init_op = tf.global_variables_initializer()
saver0 = tf.train.Saver()
tf.add_to_collection('myVar', myVar) ### WHY NECESSARY?
with tf.Session() as sess:
sess.run(init_op)
print sess.run(myVar)
saver0.save(sess, './myModel')
if mode == 'load':
with tf.Session() as sess:
saver1 = tf.train.import_meta_graph('./myModel.meta')
saver1.restore(sess, tf.train.latest_checkpoint('./'))
myVar = tf.get_collection('myVar')[0] ### WHY NECESSARY?
print sess.run(myVar)
[Tensorflow : 이전에 저장된 모델 (Python)을 복원하는 방법] (http://stackoverflow.com/questions/33759623/tensorflow-how-to-restore-a-previously-save-model-python) – Kashyap