2016-12-08 3 views
3

나는 tensorflow 도구를 사용하여 신경망을 작성했습니다. 모든 것이 작동하고 이제는 단일 예측 방법을 만들기 위해 신경망의 최종 가중치를 내보내려고합니다. 어떻게해야합니까?tensorflow를 사용한 신경망의 가중치 내보내기

+0

https://nathanbrixius.wordpress.com/2016/05/24/checkpointing-and-reusing-tensorflow-models/ – martianwars

답변

2

tf.train.Saver 클래스를 사용하여 교육이 끝나면 모델을 저장해야합니다.

개체를 초기화하는 동안 저장하려는 모든 변수 목록을 전달해야합니다. 가장 중요한 부분은 다른 계산 그래프에서 이러한 저장된 변수를 사용할 수 있다는 것입니다!

은 물론

saver.save(sess, 'filename'); 

,

# Assume you want to save 2 variables `v1` and `v2` 
saver = tf.train.Saver([v1, v2]) 

tf.Session 객체를 사용하여 변수를 저장, 사용하여 Saver 객체를 생성하면 global_step 같은 추가 정보를 추가 할 수 있습니다.

나중에 restore() 기능을 사용하여 변수를 복원 할 수 있습니다. 복원 된 변수는 자동으로이 값으로 초기화됩니다.

+1

매개 변수의 원시 데이터를 가져올 수 있습니까? 다른 플랫폼에서 tensorflow-trained-model을 실행하고 싶습니다. 어떻게해야합니까? –

+2

'sess.run (weights)'을 사용하여 가중치의 최종 값을 얻고이를 numpy 배열 (예 : – martianwars

+0

)로 내보낼 수 있습니다. 또 다른 문제점 : 네트에서'tf.nn.rnn_cell.LSTMCell'을 사용했는데 어떻게 'LSTMCell' 객체의 가중치/바이어스에 접근 할 수 있습니까? –

0

위의 대답은 세션 스냅 샷을 저장/복원하는 표준 방법입니다. 그러나 네트워크를 다른 바이너리 파일로 내보내 다른 텐서 흐름 도구와 함께 사용하려면 몇 가지 단계를 더 수행해야합니다.

먼저, freeze the graph. TF는 해당 도구를 제공합니다. 나는이처럼 사용

#!/bin/bash -x 

# The script combines graph definition and trained weights into 
# a single binary protobuf with constant holders for the weights. 
# The resulting graph is suitable for the processing with other tools. 


TF_HOME=~/tensorflow/ 

if [ $# -lt 4 ]; then 
    echo "Usage: $0 graph_def snapshot output_nodes output.pb" 
    exit 0 
fi 

proto=$1 
snapshot=$2 
out_nodes=$3 
out=$4 

$TF_HOME/bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=$proto \ 
    --input_checkpoint=$snapshot \ 
    --output_graph=$out \ 
    --output_node_names=$out_nodes 

당신이 optimize it for inference을 할 수있는, 또는 any other tool를 사용, 그 일을 가졌어요.

관련 문제