2017-05-23 6 views
1

나는 우리가이 작업을 수행 할 수있는 tutorial에서 볼 : TensorFlow에서 protos의 값에 액세스하는 방법?

for node in tf.get_default_graph().as_graph_def().node: print node

임의의 네트워크에서 수행

, 우리는 많은 키 값 쌍을 얻을. 예 :

name: "conv2d_2/convolution" 
op: "Conv2D" 
input: "max_pooling2d/MaxPool" 
input: "conv2d_1/kernel/read" 
device: "/device:GPU:0" 
attr { 
    key: "T" 
    value { 
    type: DT_FLOAT 
    } 
} 
attr { 
    key: "data_format" 
    value { 
    s: "NHWC" 
    } 
} 
attr { 
    key: "padding" 
    value { 
    s: "SAME" 
    } 
} 
attr { 
    key: "strides" 
    value { 
    list { 
     i: 1 
     i: 1 
     i: 1 
     i: 1 
    } 
    } 
} 
attr { 
    key: "use_cudnn_on_gpu" 
    value { 
    b: true 
    } 
} 

이 값들을 모두 파이썬리스트에 어떻게 넣을 수 있습니까? 구체적으로, 우리는 "strides"속성을 얻고 거기에있는 모든 1을 [1, 1, 1, 1]로 어떻게 변환 할 수 있습니까?

답변

1

TLDR : 아래의 코드를 사용할 수있는 것입니다 :

for n in tf.get_default_graph().as_graph_def().node: if 'strides' in n.attr.keys(): print n.name, [int(a) for a in n.attr['strides'].list.i] if 'shape' in n.attr.keys(): print n.name, [int(a.size) for a in n.attr['shape'].shape.dim]

이 일에 트릭 protobufs이 무엇인지 이해하는 것입니다. 위에서 언급 한 tutorial을 살펴 보겠습니다. 모든

첫째는 성명있다 :

for node in graph_def.node

각 노드가 tensorflow/코어/프레임 워크/node_def.proto에서 정의하는 NodeDef 개체입니다. 이들은 각각 빌딩 블록의 TensorFlow 그래프이며 각각은 입력 연결과 함께 단일 작업을 정의합니다. 다음은 NodeDef의 구성원과 그 의미입니다.

참고는 node_def.proto에서 다음

  • 그것은 attr_value.proto를 가져옵니다.
  • name, op, input, device, attr과 같은 속성이 있습니다. 특히 입력 앞에 repeated이라는 용어가 있습니다. 우리는 지금 이것을 무시할 수 있습니다.

이 정확히 파이썬 클래스처럼 작동하고 우리가 이렇게 등 node.name, node.op, node.input, node.device, node.attr를 호출 할 수 있습니다 우리가 액세스하고자하는 어떤

이제 node.attr의 내용이됩니다. 튜토리얼을 다시 한번 참조하면 다음과 같이 지정됩니다.

노드의 모든 속성을 포함하는 키/값 저장소입니다. 이 은 컨볼 루션 필터의 크기와 같이 런타임에서 변경되지 않는 노드의 영구 속성이거나 constant ops의 값입니다. 문자열, 정수, 텐서 값 배열에 이르기까지 많은 다른 유형의 속성 값이있을 수 있으므로 에있는 데이터 구조를 정의하는 별도의 protobuf 파일이 tenorflow/core/framework/attr_value.proto에 있습니다. .

각 특성에는 고유 한 이름 문자열이 있으며 작업이 정의 될 때 예상 특성 인 이 나열됩니다.노드에있는 속성이 이 아니지만 작업 정의에 나열된 기본값이 있으면 그래프를 만들 때 기본값이 사용됩니다.

파이썬에서 node.name, node.op, 등을 호출하여이 모든 멤버에 액세스 할 수 있습니다. GraphDef에 저장된 노드 목록은 전체 모델 아키텍처의 정의입니다.

이 키 - 값 저장소이므로 n.attr.keys()을 호출하여이 속성의 키 목록을 볼 수 있습니다. 그러한 키가 사용 가능하다면 스트라이드에 액세스하려면 아마도 n.attr['strides']으로 전화하십시오. 우리는이를 인쇄하려고 할 때, 우리는 다음을 얻을 :

list { 
    i: 1 
    i: 2 
    i: 2 
    i: 1 
} 

을 우리가 list(n.attr['strides']) 또는 이런 종류의 뭔가를 시도 할 수 있기 때문에 혼란 스러울 시작 곳이다. attr_value.proto를 보면, 우리는 무슨 일이 일어나고 있는지 이해할 수 있습니다. oneof value이고이 경우는 ListValue list이므로 n.attr['strides'].list으로 전화 할 수 있습니다. 우리는이를 인쇄 할 경우, 우리는 다음을 얻을 :

i: 1 
i: 1 
i: 1 
i: 1 

우리는 다음이 작업을 수행하려고 할 수 있습니다 [a for a in n.attr['strides'].list] 또는 [a.i for a in n.attr['strides'].list]. 그러나 아무것도 작동하지 않습니다. 이것은 repeated이 이해해야 할 중요한 용어입니다. 기본적으로 int64 목록이 있다는 것을 의미하므로 i 속성을 사용하여 액세스해야합니다. 그러면 [int(a) for a in n.attr['strides'].list.i]을 사용하면 우리가 사용할 수있는 Python 목록 인 우리가 원하는 것을 얻을 수 있습니다 :

[1, 1, 1, 1] 
관련 문제