2017-02-08 2 views
0

디스크에서 텐류 흐름 모델을로드하고 값을 예측하려고합니다. 나는 튜플 (rowkey, input_vector)로 구성 RDD이pyspark를 사용하여 맵을 mapPartition으로 변환

코드

def get_value(row): 
    print("**********************************************") 
    graph = tf.Graph() 
    rowkey = row[0] 
    checkpoint_file = "/home/sahil/Desktop/Relation_Extraction/data/1485336002/checkpoints/model-300" 
    print("Loading model................................") 
    with graph.as_default(): 
     session_conf = tf.ConfigProto(
      allow_soft_placement=allow_soft_placement, 
      log_device_placement=log_device_placement) 
     sess = tf.Session(config=session_conf) 
     with sess.as_default(): 
      # Load the saved meta graph and restore variables 
      saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file)) 
      saver.restore(sess, checkpoint_file) 
      input_x = graph.get_operation_by_name("X_train").outputs[0] 
      dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0] 
      predictions = graph.get_operation_by_name("output/predictions").outputs[0] 
      batch_predictions = sess.run(predictions, {input_x: [row[1]], dropout_keep_prob: 1.0}) 
      print(batch_predictions) 
      return (rowkey, batch_predictions) 

. 로드 된 모델을 사용하여 입력의 점수/클래스를 예측하고 싶습니다.

코드 모델은 각 튜플에 대해 매번로드되고 시간이 많이 소요

result = data_rdd.map(lambda iter: get_value(iter)) 
result.foreach(print) 

문제는 내가지도를 호출 할 때마다입니다) (get_value를 호출합니다.

나는 mapPartitions를 사용하여 모델을로드 생각하고 get_value에게 함수를 호출하는 맵을 사용하고 있습니다. 코드를 parition으로 한 번만로드하고 실행 시간을 줄이는 mapPartition으로 코드를 변환하는 방법에 대한 단서가 없습니다.

미리 감사드립니다.

답변

0

아래 코드는 mapPartitions를 사용하기 때문에 큰 발전이라고 생각합니다. session_pickle =의 cPickle에, 파일 "/home/sahil/Desktop/Relation_Extraction/temp.py", 줄 465 :

코드

def predict(rows): 
    graph = tf.Graph() 
    checkpoint_file = "/home/sahil/Desktop/Relation_Extraction/data/1485336002/checkpoints/model-300" 
    print("Loading model................................") 
    with graph.as_default(): 
     session_conf = tf.ConfigProto(
      allow_soft_placement=allow_soft_placement, 
      log_device_placement=log_device_placement) 
     sess = tf.Session(config=session_conf) 
     with sess.as_default(): 
      # Load the saved meta graph and restore variables 
      saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file)) 
      saver.restore(sess, checkpoint_file) 
     print("**********************************************") 
     # Get the placeholders from the graph by name 
     input_x = graph.get_operation_by_name("X_train").outputs[0] 
     dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0] 
     # Tensors we want to evaluate 
     predictions = graph.get_operation_by_name("output/predictions").outputs[0] 

     # Generate batches for one epoch 
     for row in rows: 
      X_test = [row[1]] 
      batch_predictions = sess.run(predictions, {input_x: X_test, dropout_keep_prob: 
      yield (row[0], batch_predictions) 


result = data_rdd.mapPartitions(lambda iter: predict(iter)) 
result.foreach(print) 
1

귀하의 질문에 올바르게 답변을 드릴지 확실하지 않지만 여기에서 귀하의 코드를 최적화 할 수 있습니다.

graph = tf.Graph() 

checkpoint_file = "/home/sahil/Desktop/Relation_Extraction/data/1485336002/checkpoints/model-300" 

with graph.as_default(): 
     session_conf = tf.ConfigProto(
      allow_soft_placement=allow_soft_placement, 
      log_device_placement=log_device_placement) 
     sess = tf.Session(config=session_conf) 

s = sess.as_default() 
saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file)) 
saver.restore(sess, checkpoint_file) 


input_x = graph.get_operation_by_name("X_train").outputs[0] 
dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0] 
predictions = graph.get_operation_by_name("output/predictions").outputs[0] 

session_pickle = cPickle.dumps(sess) 

def get_value(key, vector, session_pickle): 
    sess = cPickle.loads(session_pickle) 
    rowkey = key 
    batch_predictions = sess.run(predictions, {input_x: [vector], dropout_keep_prob: 1.0}) 
    print(batch_predictions) 
    return (rowkey, batch_predictions 



result = data_rdd.map(lambda (key, row): get_value(key=key, vector = row , session_pickle = session_pickle)) 
result.foreach(print) 

따라서 텐 토류 흐름 세션을 직렬화 할 수 있습니다. 비록 내가 여기에 귀하의 코드를 테스트하지 않았습니다. 이것을 실행하고 코멘트를 남겨주세요.

+0

오류가 '역 추적 (가장 최근 통화 마지막) 팝업 .dumps (sess) TypeError : SwigPyObject 객체를 피클링 할 수 없습니다. ' – wadhwasahil

관련 문제