2016-09-05 2 views
0

업데이트 : 다음과 같은 방법으로 신뢰 점수를 얻으려고했지만 예외가 발생했습니다.Spark MLLib Logistic Regression에서 신뢰 점수를 얻는 방법

double point = BLAS.dot(logisticregressionmodel.weights(), datavector); 
double confScore = 1.0/(1.0 + Math.exp(-point)); 

내가 얻을 예외 :

Caused by: java.lang.IllegalArgumentException: requirement failed: BLAS.dot(x: Vector, y:Vector) was given Vectors with non-matching sizes: x.size = 198, y.size = 18 
    at scala.Predef$.require(Predef.scala:233) 
    at org.apache.spark.mllib.linalg.BLAS$.dot(BLAS.scala:99) 
    at org.apache.spark.mllib.linalg.BLAS.dot(BLAS.scala) 

당신이 도와 주실 수 있습니까 아래의 코드를 사용할 수 있습니까? 가중치 벡터가 데이터 벡터보다 많은 요소 (198)를 가지고있는 것처럼 보입니다 (저는 18 개의 피쳐를 생성하고 있습니다). 그들은 dot() 기능에서 동일한 길이 여야합니다.

Java에 프로그램을 구현하여 Spark MLLib (1.5.0)에서 사용할 수있는 로지스틱 회귀 알고리즘을 사용하여 기존 데이터 집합에서 훈련하고 새로운 데이터 집합을 예측하려고합니다. 내 열차와 예상 프로그램은 다음과 같으며 여러 가지 구현을 사용하고 있습니다. 문제는 내가 model.predict(vector) (예측 프로그램에서 lrmodel.predict()에 주목할 때)입니다. 예측 된 레이블을 얻었습니다. 그러나 신뢰 점수가 필요한 경우에는 어떻게해야합니까? 어떻게해야합니까? 나는 API를 거쳐 신뢰 점수를주는 특정 API를 찾을 수 없었다. 누구든지 나를 도울 수 있습니까?

기차 프로그램은

public static void main(final String[] args) throws Exception { 
     JavaSparkContext jsc = null; 
     int salesIndex = 1; 

     try { 
      ... 
     SparkConf sparkConf = 
        new SparkConf().setAppName("Hackathon Train").setMaster(
          sparkMaster); 
      jsc = new JavaSparkContext(sparkConf); 
      ... 

      JavaRDD<String> trainRDD = jsc.textFile(basePath + "old-leads.csv").cache(); 

      final String firstRdd = trainRDD.first().trim(); 
      JavaRDD<String> tempRddFilter = 
        trainRDD.filter(new org.apache.spark.api.java.function.Function<String, Boolean>() { 
         private static final long serialVersionUID = 
           11111111111111111L; 

         public Boolean call(final String arg0) { 
          return !arg0.trim().equalsIgnoreCase(firstRdd); 
         } 
        }); 

      ... 
      JavaRDD<String> featureRDD = 
        tempRddFilter 
          .map(new org.apache.spark.api.java.function.Function() { 
           private static final long serialVersionUID = 
             6948900080648474074L; 

           public Object call(final Object arg0) 
             throws Exception { 
            ... 
            StringBuilder featureSet = 
              new StringBuilder(); 
            ... 
             featureSet.append(i - 2); 
             featureSet.append(COLON); 
             featureSet.append(strVal); 
             featureSet.append(SPACE); 
            } 

            return featureSet.toString().trim(); 
           } 
          }); 

      List<String> featureList = featureRDD.collect(); 
      String featureOutput = StringUtils.join(featureList, NEW_LINE); 
      String filePath = basePath + "lr.arff"; 
      FileUtils.writeStringToFile(new File(filePath), featureOutput, 
        "UTF-8"); 

      JavaRDD<LabeledPoint> trainingData = 
        MLUtils.loadLibSVMFile(jsc.sc(), filePath).toJavaRDD().cache(); 

      final LogisticRegressionModel model = 
        new LogisticRegressionWithLBFGS().setNumClasses(18).run(
          trainingData.rdd()); 
      ByteArrayOutputStream baos = new ByteArrayOutputStream(); 
      ObjectOutputStream oos = new ObjectOutputStream(baos); 
      oos.writeObject(model); 
      oos.flush(); 
      oos.close(); 
      FileUtils.writeByteArrayToFile(new File(basePath + "lr.model"), 
        baos.toByteArray()); 
      baos.close(); 

     } catch (Exception e) { 
      e.printStackTrace(); 
     } finally { 
      if (jsc != null) { 
       jsc.close(); 
      } 
     } 

public static void main(final String[] args) throws Exception { 
     JavaSparkContext jsc = null; 
     int salesIndex = 1; 
     try { 
      ... 
     SparkConf sparkConf = 
        new SparkConf().setAppName("Hackathon Predict").setMaster(sparkMaster); 
      jsc = new JavaSparkContext(sparkConf); 

      ObjectInputStream objectInputStream = 
        new ObjectInputStream(new FileInputStream(basePath 
          + "lr.model")); 
      LogisticRegressionModel lrmodel = 
        (LogisticRegressionModel) objectInputStream.readObject(); 
      objectInputStream.close(); 

      ... 

      JavaRDD<String> trainRDD = jsc.textFile(basePath + "new-leads.csv").cache(); 

      final String firstRdd = trainRDD.first().trim(); 
      JavaRDD<String> tempRddFilter = 
        trainRDD.filter(new org.apache.spark.api.java.function.Function<String, Boolean>() { 
         private static final long serialVersionUID = 
           11111111111111111L; 

         public Boolean call(final String arg0) { 
          return !arg0.trim().equalsIgnoreCase(firstRdd); 
         } 
        }); 

      ... 
      final Broadcast<LogisticRegressionModel> broadcastModel = 
        jsc.broadcast(lrmodel); 

      JavaRDD<String> featureRDD = 
        tempRddFilter 
          .map(new org.apache.spark.api.java.function.Function() { 
           private static final long serialVersionUID = 
             6948900080648474074L; 

           public Object call(final Object arg0) 
             throws Exception { 
            ... 
            LogisticRegressionModel lrModel = 
              broadcastModel.value(); 
            String row = ((String) arg0); 
            String[] featureSetArray = 
              row.split(CSV_SPLITTER); 
            ... 
            final Vector vector = 
              Vectors.dense(doubleArr); 
            double score = lrModel.predict(vector); 
            ... 
            return csvString; 
           } 
          }); 

      String outputContent = 
        featureRDD 
          .reduce(new org.apache.spark.api.java.function.Function2() { 

           private static final long serialVersionUID = 
             1212970144641935082L; 

           public Object call(Object arg0, Object arg1) 
             throws Exception { 
            ... 
           } 

          }); 
      ... 
      FileUtils.writeStringToFile(new File(basePath 
        + "predicted-sales-data.csv"), sb.toString()); 
     } catch (Exception e) { 
      e.printStackTrace(); 
     } finally { 
      if (jsc != null) { 
       jsc.close(); 
      } 
     } 
    } 
} 

답변

0

많은 시도 끝에 마침내 신뢰 점수를 생성하는 사용자 지정 함수를 작성했습니다. 완벽하지는 않지만 지금 당장 나를 위해 일합니다!

private static double getConfidenceScore(
      final LogisticRegressionModel lrModel, final Vector vector) { 
     /* Approach to get confidence scores starts */ 
     Vector weights = lrModel.weights(); 
     int numClasses = lrModel.numClasses(); 
     int dataWithBiasSize = weights.size()/(numClasses - 1); 
     boolean withBias = (vector.size() + 1) == dataWithBiasSize; 
     double maxMargin = 0.0; 
     double margin = 0.0; 
     for (int j = 0; j < (numClasses - 1); j++) { 
      margin = 0.0; 
      for (int k = 0; k < vector.size(); k++) { 
       double value = vector.toArray()[k]; 
       if (value != 0.0) { 
        margin += value 
          * weights.toArray()[(j * dataWithBiasSize) + k]; 
       } 
      } 
      if (withBias) { 
       margin += weights.toArray()[(j * dataWithBiasSize) 
         + vector.size()]; 
      } 
      if (margin > maxMargin) { 
       maxMargin = margin; 
      } 
     } 
     double conf = 1.0/(1.0 + Math.exp(-maxMargin)); 
     DecimalFormat twoDForm = new DecimalFormat("#.##"); 
     double confidenceScore = Double.valueOf(twoDForm.format(conf * 100)); 
     /* Approach to get confidence scores ends */ 
     return confidenceScore; 
    } 
0

가 실제로는하지 않습니다 (기차 프로그램에서 생성 된 lr.model를 사용하여) 프로그램을 예측 (A .MODEL 파일을 생성합니다) 가능할 것 같다. 소스 코드를 살펴보면이 확률을 반환하기 위해 소스 코드를 확장 할 수 있습니다.

if (numClasses == 2) { 
    val margin = dot(weightMatrix, dataMatrix) + intercept 
    val score = 1.0/(1.0 + math.exp(-margin)) 
    threshold match { 
    case Some(t) => if (score > t) 1.0 else 0.0 
    case None => score 
    } 

https://github.com/apache/spark/blob/branch-1.5/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala

나는 그것이 해결 방법을 찾기 시작하는 데 도움이 될 수 있기를 바랍니다.

+0

예를 들어 주시겠습니까? spark docs에서 LogisticRegression.java를 통해 그 메소드를 찾을 수 없었습니다. – ArinCool

+0

** raw2probabilityInPlace ** 및 ** raw2prediction ** 기능을 찾을 수 없습니다. 좀 도와 줄 수있어? – ArinCool

+0

org.apache.spark.ml.classificationLogisticRegressionModel 클래스에 있습니다. 더 간단하다면 다른 이름으로 복사본을 만들고이 함수를 공개 할 수 있습니다. –

관련 문제