2017-11-23 4 views
0

tensorflow를 사용하여 예측할 때 예측과 관련된 클래스의 이름을 어떻게 얻을 수 있습니까? 현재는 일련의 확률 만 반환합니다. 이것은 이미지를 예측하는 데 사용하는 코드입니다.예측 후 tensorflow 클래스 이름을 표시하는 방법

class Prediction: 

def __init__(self, filename, filepath, image_size = 128, number_channels = 3): 

    self.x_batch = [] 
    self.images = [] 
    self.image_size = image_size 
    self.number_channels = number_channels 

    self.image = cv2.imread(filename) 

    self.modelpath = filepath 
    self.modelfilepath = filepath + '/train-model.meta' 

    self.sess = tf.Session() 
    self.graph = None 
    self.y_pred = None 


def resize_image(self): 
    self.image = cv2.resize(self.image, (self.image_size, self.image_size), cv2.INTER_LINEAR) 
    self.images.append(self.image) 
    self.images = np.array(self.images, dtype=np.uint8) 
    self.images = self.images.astype('float32') 
    self.images = np.multiply(self.images, 1.0/255.0) 
    self.x_batch = self.images.reshape(1, self.image_size, self.image_size, self.number_channels) 


def restore_model(self): 

    saver = tf.train.import_meta_graph(self.modelfilepath) 
    saver.restore(self.sess, tf.train.latest_checkpoint(self.modelpath)) 

    self.graph = tf.get_default_graph() 

    self.y_pred = self.graph.get_tensor_by_name("y_pred:0") 


def predict_image(self): 
    x = self.graph.get_tensor_by_name("x:0") 
    y_true = self.graph.get_tensor_by_name("y_true:0") 
    y_test_images = np.zeros((1, 2)) 

    feed_dict_testing = {x: self.x_batch, y_true: y_test_images} 
    result = self.sess.run(self.y_pred, feed_dict=feed_dict_testing) 
    return result 

감사합니다.

답변

0

지상 진실 값에 대한 정확도를 측정하는 방법을 확인하려면 학습 코드를 보는 것이 도움이됩니다. 이것은 재교육 처음에 tensorflow examples에서 바로 제공

 predictions = self.sess.run(self.y_pred, feed_dict=feed_dict_testing) 

     # Format predicted classes for display 
     # use np.squeeze to convert the tensor to a 1-d vector of probability values 
     predictions = np.squeeze(predictions) 

     top_k = predictions.argsort()[-5:][::-1] # Getting the indicies of the top 5 predictions 

     # read the class labels in from the label file 
     f = open(labelPath, 'rb') 
     lines = f.readlines() 
     labels = [str(w).replace("\n", "") for w in lines] 
     print("") 
     print ("Image Classification Probabilities") 
     # Output the class probabilites in descending order 
     for node_id in top_k: 
      human_string = filter_delimiters(labels[node_id]) 
      score = predictions[node_id] 
      print('{0:s} (score = {1:.5f})'.format(human_string, score)) 

- 즉,이처럼 사용할 수있는 레이블 파일이 필요했다. 희망이 도움이

관련 문제