2016-11-23 1 views
4

libtensorflow.so 대상을 빌드하여 C API를 구축했습니다. 미리 훈련 된 모델을로드하고 추론을 실행하여 예측을하고 싶습니다. 나는 'c_api.h'헤더 파일 (적절한 파일에 'libtensorflow.so'파일을 복사하는 것과 함께)을 포함 시켜서이 작업을 수행 할 수 있다고 들었지만, 웹상에서 그 파일에 대한 예제를 찾지 못했습니다. 내가 찾을 수있는 것은 Bazel 빌드 시스템을 사용하는 예제이며 다른 빌드 시스템을 사용하고 TensorFlow를 라이브러리로 사용하고 싶습니다. 누군가는 a) 메타 그래프 파일을 가져 오는 방법에 대한 예제를 통해 나를 도울 수 있습니까? b) protobuf 그래프 파일과 검사 점 파일을 사용하여 예측을 수행합니까? 아래의 파이썬 파일에 해당하는 C++ 코드이며 g ++로 빌드 되었습니까?TensorFlow 숙련 된 모델 및 C API로 예측하기

#!/usr/bin/env python 

import tensorflow as tf 
import numpy as np 

with tf.Session() as sess: 
    saver = tf.train.import_meta_graph('./metagraph.meta') 
    saver.restore(sess, './checkpoint.ckpt') 
    x = tf.get_collection("x")[0] 
    yhat = tf.get_collection("yhat")[0] 
    print sess.run(yhat, feed_dict={x : np.array([[2, 3], [4, 5]])}) 

감사합니다.

추신 :

#!/usr/bin/env python 

import tensorflow as tf 
import numpy as np 

x = tf.placeholder(tf.float32, shape=[None, 2], name='x') 
tf.add_to_collection("x", x) 
y = tf.placeholder(tf.float32, shape=[None, 1], name='y') 
w = tf.Variable(np.array([[10.0], [100.0]]), dtype=tf.float32, name='w') 
b = tf.Variable(0.0, dtype=tf.float32, name='b') 
yhat = tf.add(tf.matmul(x, w), b) 
tf.add_to_collection("yhat", yhat) 
mse_loss = tf.sqrt(tf.reduce_mean(tf.square(tf.sub(y, yhat)))) 
step_size = tf.constant(0.01) 
optimizer = tf.train.GradientDescentOptimizer(step_size) 
init_op = tf.initialize_all_variables() 
train_op = optimizer.minimize(mse_loss) 
saver = tf.train.Saver() 
with tf.Session() as sess: 
    sess.run(init_op) 
    for i in xrange(10000): 
     train_x = np.random.random([100, 2]) * 10 
     train_y = np.dot(train_x, np.array([[100.0], [10.0]])) + 1.0 
     sess.run(train_op, feed_dict={x : train_x, y : train_y}) 
    print sess.run(w) 
    print sess.run(b) 
    saver.save(sess, './checkpoint.ckpt') 
    saver.export_meta_graph('./metagraph.meta') 
    tf.train.write_graph(sess.graph_def, './', 'graph') 
+0

안녕하세요, TensorFlow C++ API 문서를 확인하셨습니까? 이 페이지는 세션을 실행하는 데 필요한 것을 제공해야합니다 : https://www.tensorflow.org/versions/r0.11/api_docs/cc/ClassSession.html 그리고 이것은 그래프를 읽는 데 도움이됩니다 : https : //www.tensorflow.org/versions/r0.11/api_docs/cc/index.html – Neal

+0

도와 주셔서 감사합니다. 그 링크를 보았지만 예제를 찾고 있었지만 찾을 수 없었습니다. –

답변

1

나는 이클립스를 사용하고/usr/지방에 내 프로젝트 파일 및 libtensorflow.so c_api.h 추가/: 완성도를 위해서 나는 파일을 작성하려면 다음을했던 한 큰 상자. 그런 다음 libtensorflow 공유 객체에 대한 참조를 GCC C++ 링커의 라이브러리에 추가하여 최종적으로 간단한 프로그램을 만들었습니다.

#include <iostream> 
#include "c_api.h" 

using namespace std; 

int main() { 
    cout << TF_Version(); 
    return 0; 
} 

그러면 원하는 Tensorflow 기능을 컴파일하고 사용할 수있게되었습니다.