2016-07-24 2 views
7

나는 사람이 읽을 수있는 형식으로 이진 와이어 형식을 변환하는 방법을 가지고 있지만 난 그냥 이것에 대한 파일 이름을 입력해야이protobuf 그래프를 이진선 형식으로 변환하는 방법은 무엇입니까?

import tensorflow as tf 
from tensorflow.python.platform import gfile 

def converter(filename): 
    with gfile.FastGFile(filename,'rb') as f: 
    graph_def = tf.GraphDef() 
    graph_def.ParseFromString(f.read()) 
    tf.import_graph_def(graph_def, name='') 
    tf.train.write_graph(graph_def, 'pbtxt/', 'protobuf.pb', as_text=True) 
    return 

의 역을 할 수 있으며, 그것은 작동합니다. 그러나 반대의 일을 내가

File "pb_to_pbtxt.py", line 16, in <module> 
    converter('protobuf.pb') # here you can write the name of the file to be converted 
    File "pb_to_pbtxt.py", line 11, in converter 
    graph_def.ParseFromString(f.read()) 
    File "/usr/local/lib/python2.7/dist-packages/google/protobuf/message.py", line 185, in ParseFromString 
    self.MergeFromString(serialized) 
    File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/python_message.py", line 1008, in MergeFromString 
    if self._InternalParse(serialized, 0, length) != length: 
    File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/python_message.py", line 1034, in InternalParse 
    new_pos = local_SkipField(buffer, new_pos, end, tag_bytes) 
    File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/decoder.py", line 868, in SkipField 
    return WIRETYPE_TO_SKIPPER[wire_type](buffer, pos, end) 
    File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/decoder.py", line 838, in _RaiseInvalidWireType 
    raise _DecodeError('Tag had invalid wire type.') 

답변

4

당신은 google.protobuf.text_format 모듈을 사용하여 역 변환을 수행 할 수 있습니다 얻을 : 당신은 tf.Graph.as_graph_def()을 사용할 수 있습니다

import tensorflow as tf 
from google.protobuf import text_format 

def convert_pbtxt_to_graphdef(filename): 
    """Returns a `tf.GraphDef` proto representing the data in the given pbtxt file. 

    Args: 
    filename: The name of a file containing a GraphDef pbtxt (text-formatted 
     `tf.GraphDef` protocol buffer data). 

    Returns: 
    A `tf.GraphDef` protocol buffer. 
    """ 
    with tf.gfile.FastGFile(filename, 'r') as f: 
    graph_def = tf.GraphDef() 

    file_content = f.read() 

    # Merges the human-readable string in `file_content` into `graph_def`. 
    text_format.Merge(file_content, graph_def) 
    return graph_def 
2

을 다음 Protobuf의 SerializeToString()과 같이 :

proto_graph = # obtained by calling tf.Graph.as_graph_def() 

with open("my_graph.bin", "wb") as f: 
    f.write(proto_graph.SerializeToString()) 

파일을 쓰고 싶은데 상관 없으면 인코딩에 대해 당신은 또한 tf.train.write_graph()

v = tf.Variable(0, name='my_variable') 
sess = tf.Session() 
tf.train.write_graph(sess.graph_def, '/tmp/my-model', 'train.pbtxt') 

주를 사용할 수 있습니다은 이전 버전에 대해 확실하지 TF 0.10에서 테스트.

관련 문제