2016-07-13 4 views
0

아래에는 훈련 된 TensorFlow 모델을 배포하는 데 사용하는 몇 가지 코드가 있습니다. 기본적으로 .pb 파일에서 모델을로드하고 모델의 첫 번째 레이어와 마지막 레이어를 가져 와서 이미지를 평가합니다. 이것은 잘 작동하지만, 흐릿 해짐, 미백 및 회전과 같은 다양한 이미지 크기와 왜곡이있는 많은 모델을 배포하려고합니다.TensorFlow : 그래프 정의에 왜곡 저장

내 질문은 : 이미지 왜곡 시퀀스 .pb 파일 내부에 저장할 수 있습니까? 그렇다면 어떻게?

목표는 배포 스크립트의 코드 양을 최소화하는 것입니다.

import base64 
import math 
import os 
import tensorflow as tf 

def get_graph(): 
    if os.path.isfile('./graph.pb'): 
     graph_def = tf.GraphDef() 
     with open('./graph.pb', 'rb') as graph_file: 
      graph_def.ParseFromString(graph_file.read()) 
    else: 
     raise Exception('Graph file \'./graph.pb\' does not exist') 
    return graph_def 

def init(event): 
    graph_def = get_graph() 

    with tf.Session() as session: 
     session.graph.as_default() 
     tf.import_graph_def(graph_def, name = '') 

     stringified = base64.b64decode(event['image'].split(',')[1]) 
     decoded = tf.image.decode_jpeg(stringified, channels = 3) 
     decoded.set_shape([event['height'], event['width'], 3]) 
     image = tf.cast(decoded, tf.float32) 

     evaluation = image.eval(sessions = sess) 

     input_tensor = sess.graph.get_tensor_by_name('input_placeholder:0') 
     output_tensor = sess.graph.get_tensor_by_name('softmax_linear/softmax_linear:0') 

     feed_dict = { input_tensor: evaluation } 
     result = sess.run([output_tensor], feed_dict = feed_dict) 
     return result 

답변

0

대답은 원본 graph.pb이 생성 된 방법에 따라 다릅니다.

원래 graph.pb을 생성 한 스크립트를 수정 한 경우 해당 스크립트에 모양을 변경하고 왜곡 등의 작업을 추가하고 graph.pb을 재생성하면됩니다. 이전 input_placeholder 작업을 제거하고 사전 처리 작업의 출력에 input_placeholder이 제공 한 작업의 입력에 연결해야합니다. 그러면 새 자리 표시자는 stringified을 입력으로 사용합니다.

원래 graph.pb을 생성 한 스크립트를 수정할 수없는 경우 전처리 하위 그래프를 자신의 .pb에 저장하여 배포 스크립트의 코드 양을 줄일 수 있습니다. 당신은 다음 전처리 서브 그래프에 원시 입력을 공급하고, 원래의 그래프에 그 (AN eval 호출을 통해 여전히 획득)의 출력을 공급할 수

raw_input = tf.placeholder(tf.string) 
decoded = tf.image.decode_jpeg(raw_input, channels = 3) 
decoded.set_shape([event['height'], event['width'], 3]) 
image = tf.cast(decoded, tf.float32) 
with open('preprocess.pb', 'w') as f: 
    f.write(tf.get_default_graph().as_graph_def()) 

:처럼 뭔가.

관련 문제