2017-04-20 4 views
3

자바에서 훈련 된 모델 (Tensorflow, Python)을 가져오고 사용하려고했습니다.Java로 Tensorflow 모델 가져 오기

모델을 Python으로 저장할 수 있었지만 Java에서 동일한 모델을 사용하여 예측을 시도 할 때 문제가 발생했습니다.

Here을 사용하면 모델 초기화, 교육, 저장을위한 파이썬 코드를 볼 수 있습니다.

Here을 사용하면 입력 값을 가져오고 예측할 수있는 Java 코드를 볼 수 있습니다. 문제가 파이썬 코드에 어딘가에, Exception in thread "main" java.lang.IllegalStateException: Attempting to use uninitialized value Variable_7 [[Node: Variable_7/read = Identity[T=DT_FLOAT, _class=["loc:@Variable_7"], _device="/job:localhost/replica:0/task:0/cpu:0"](Variable_7)]] at org.tensorflow.Session.run(Native Method) at org.tensorflow.Session.access$100(Session.java:48) at org.tensorflow.Session$Runner.runHelper(Session.java:285) at org.tensorflow.Session$Runner.run(Session.java:235) at org.tensorflow.examples.Identity_import.main(Identity_import.java:35)

나는 생각하지만, 나는 그것을 찾을 수 없습니다 :

내가 오류 메시지입니다.

도움을 주시면 감사하겠습니다.

피터

감사
+0

제가 사용 [이 (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java) inspiration as – szi

답변

5

자바 importGraphDef() 함수는 컴퓨터 g raph (파이썬 코드에서 tf.train.write_graph로 작성), 테스트 된 변수 값 (체크 포인트에 저장 됨)을로드하지 않으므로 초기화되지 않은 변수에 대해 불평하는 오류가 발생합니다.

반면에 TensorFlow SavedModel format에는 모델 (그래프, 검사 점 상태, 기타 메타 데이터)에 대한 모든 정보가 포함되어 있으며 훈련 된 변수 값으로 초기화 된 세션을 만들려면 SavedModelBundle.load을 사용하려는 Java에서 사용하십시오.

def save_model(session, input_tensor, output_tensor): 
    signature = tf.saved_model.signature_def_utils.build_signature_def(
    inputs = {'input': tf.saved_model.utils.build_tensor_info(input_tensor)}, 
    outputs = {'output': tf.saved_model.utils.build_tensor_info(output_tensor)}, 
) 
    b = saved_model_builder.SavedModelBuilder('/tmp/model') 
    b.add_meta_graph_and_variables(session, 
           [tf.saved_model.tag_constants.SERVING], 
           signature_def_map={tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature}) 
    b.save() 
:

파이썬에서이 형식의 모델을 내보내려면, 당신은 관련된 질문 귀하의 경우에는 Deploy retrained inception SavedModel to google cloud ml engine

를 살펴 할 수 있습니다, 이것은 파이썬에서 다음과 같이 뭔가에 달할한다

그리고 호출 그 사용하여 모델을로드

save_model(session, x, yhat)를 통해 그리고 자바 :

try (SavedModelBundle b = SavedModelBundle.load("/tmp/mymodel", "serve")) { 
    // b.session().run(...) 
} 

희망이 도움이됩니다.

+0

경고 : 이것은 Java에서 작동하지만 TF는 현재 Android에서 SavedModel을로드하는 것을 지원하지 않습니다. 이것을 힘든 방법으로 찾아 냈습니다. :/ – Keilaron

+0

대신 # 12750 또는 # 13079 문제를 보거나 https://www.tensorflow.org/mobile/prepare_models – Keilaron

1

확실히이 실패합니다 파이썬 모델 :

sess.run(init) #<---this will fail 
save_model(sess) 
error = tf.reduce_mean(tf.square(prediction - y)) 

#accuracy = tf.reduce_mean(tf.cast(error, 'float')) 
print('Error:', error) 

init 모델에 정의되지 않은 - 당신이에 달성 원하는 것을 확실 해요 이 장소,하지만 당신에게 출발점을 제공해야합니다

1

Fwiw, Deeplearning4j를 사용하면 TensorFlow에서 Keras 1.0으로 교육 한 모델을 가져올 수 있습니다 (Keras 2.0 지원이 진행 중입니다).

https://deeplearning4j.org/model-import-keras

우리는 또한 텐서를 처리 할 때 Py4j 것보다 더 효율적 포인터 대신 복제 데이터를 사용 NumPy와 배열 및 Pyjnius 래퍼가있다 예민라는 라이브러리를 만들었습니다.

https://deeplearning4j.org/jumpy

+0

을 참조하십시오. 고마워요! 저는 Tensorflow에 GAN을 구축하려고했는데, DL4J가 GAN을 일반적으로 지원하지 않는다는 것을 알고있는 한, Tensorflow와 함께 훈련 된 GAN을 사용할 수있는 해결 방법을 찾고 JVM을 실행했습니다. – szi

관련 문제