2017-11-15 2 views
0

의 문자열 출력을 구문 분석하는 방법 :여기에 코드를 사용하여 모델을 만든 tensorflow 모델

public static void main(String[] args) { 
    Session session = SavedModelBundle.load("/Users/gagandeep.malhotra/Documents/SampleTF_projects/tf_iris_model/1510707746/", "serve").session(); 

    Tensor x = 
     Tensor.create(
      new long[] {2, 4}, 
      FloatBuffer.wrap(
       new float[] { 
        6.4f, 3.2f, 4.5f, 1.5f, 
        5.8f, 3.1f, 5.0f, 1.7f 
       })); 

    final String xName = "Placeholder:0"; 
    final String scoresName = "dnn/head/predictions/probabilities:0"; 

    List<Tensor<?>> outputs = session.runner() 
     .feed(xName, x) 
     .fetch(scoresName) 
     .run(); 

    // Outer dimension is batch size; inner dimension is number of classes 
    float[][] scores = new float[2][3]; 

    outputs.get(0).copyTo(scores); 
    System.out.println(Arrays.deepToString(scores)); 
    } 

: https://gist.github.com/gaganmalhotra/1424bd3d0617e784976b29d5846b16b1

이 코드 아래 사용하여 수행 할 수 있습니다 자바의 probabilites의 예측을 얻으려면

final String xName = "Placeholder:0"; 
final String className = "dnn/head/predictions/str_classes:0"; 

List<Tensor<?>> outputs = session.runner() 
    .feed(xName, x) 
    .fetch(className) 
    .run(); 

// Outer dimension is batch size; inner dimension is number of classes 
String[][] classes = new String[2][1]; 

outputs.get(0).copyTo(classes); 
System.out.println(Arrays.deepToString(classes)); 

내가 에로와 끝까지 : 우리는 아래의 코드에 대한 예측 클래스 (문자열 형식)을 통해 복사 할 그러나 경우 R이 같은

Exception in thread "main" java.lang.IllegalArgumentException: cannot copy Tensor with 2 dimensions into an object with 1 
    at org.tensorflow.Tensor.throwExceptionIfTypeIsIncompatible(Tensor.java:739) 
    at org.tensorflow.Tensor.copyTo(Tensor.java:450) 
    at deeplearning.IrisTFLoad.main(IrisTFLoad.java:71) 

그러나 차원 출력 텐서와 동일하다 : 도형 [2와 STRING 텐서 1]

PS : 서명 정의는 다음과 같이 찾을 수있다 -

The given SavedModel SignatureDef contains the following input(s): 
    inputs['x'] tensor_info: 
     dtype: DT_FLOAT 
     shape: (-1, 4) 
     name: Placeholder:0 
    The given SavedModel SignatureDef contains the following output(s): 
    outputs['class_ids'] tensor_info: 
     dtype: DT_INT64 
     shape: (-1, 1) 
     name: dnn/head/predictions/ExpandDims:0 
    outputs['classes'] tensor_info: 
     dtype: DT_STRING 
     shape: (-1, 1) 
     name: dnn/head/predictions/str_classes:0 
    outputs['logits'] tensor_info: 
     dtype: DT_FLOAT 
     shape: (-1, 3) 
     name: dnn/head/logits:0 
    outputs['probabilities'] tensor_info: 
     dtype: DT_FLOAT 
     shape: (-1, 3) 
     name: dnn/head/predictions/probabilities:0 
    Method name is: tensorflow/serving/predict 

가지 시도 :

텐서 텐서 = (텐서) outputs.get (0); 바이트 [] [] [] 결과 = tensor.copyTo (새 바이트 [2] [1] []);

하지만 아래와 같이 오류 아웃 :

Exception in thread "main" java.lang.IllegalStateException: invalid DataType(7) 
    at org.tensorflow.Tensor.readNDArray(Native Method) 
    at org.tensorflow.Tensor.copyTo(Tensor.java:451) 
    at deeplearning.IrisTFLoad.main(IrisTFLoad.java:74) 

답변

1

DT_STRING 입력 TensorFlow 텐서 포함 arbitrary byte sequences 요소가 아닌 자바 String의 (문자 순서) 등. 자바에게 String 객체를 얻을하려는 경우

byte[][][] classes = new byte[2][1][]; 
outputs.get(0).copyTo(classes); 

것은, 당신이 모델을 인코딩하는의 클래스를 생성 무엇을 알아야하고 있습니다 :

따라서, 당신이 원하는 것은이 같은 것입니다 (UTF-8 인코딩 가정) :

String[][] classesStrings = new String[2][1]; 
for (int i = 0; i < classes.length; ++i) { 
    for (int j = 0; j < classes[i].length; ++j) { 
    classesString[i][j] = new String(classes[i][j], UTF_8); 
    } 
} 

희망이 있습니다. unittest 유익한 정보를 찾을 수도 있습니다.

+0

감사합니다. @ash –

관련 문제