업데이트 : 다음과 같은 방법으로 신뢰 점수를 얻으려고했지만 예외가 발생했습니다.Spark MLLib Logistic Regression에서 신뢰 점수를 얻는 방법
double point = BLAS.dot(logisticregressionmodel.weights(), datavector);
double confScore = 1.0/(1.0 + Math.exp(-point));
내가 얻을 예외 :
Caused by: java.lang.IllegalArgumentException: requirement failed: BLAS.dot(x: Vector, y:Vector) was given Vectors with non-matching sizes: x.size = 198, y.size = 18
at scala.Predef$.require(Predef.scala:233)
at org.apache.spark.mllib.linalg.BLAS$.dot(BLAS.scala:99)
at org.apache.spark.mllib.linalg.BLAS.dot(BLAS.scala)
당신이 도와 주실 수 있습니까 아래의 코드를 사용할 수 있습니까? 가중치 벡터가 데이터 벡터보다 많은 요소 (198)를 가지고있는 것처럼 보입니다 (저는 18 개의 피쳐를 생성하고 있습니다). 그들은 dot()
기능에서 동일한 길이 여야합니다.
Java에 프로그램을 구현하여 Spark MLLib (1.5.0)에서 사용할 수있는 로지스틱 회귀 알고리즘을 사용하여 기존 데이터 집합에서 훈련하고 새로운 데이터 집합을 예측하려고합니다. 내 열차와 예상 프로그램은 다음과 같으며 여러 가지 구현을 사용하고 있습니다. 문제는 내가 model.predict(vector)
(예측 프로그램에서 lrmodel.predict()에 주목할 때)입니다. 예측 된 레이블을 얻었습니다. 그러나 신뢰 점수가 필요한 경우에는 어떻게해야합니까? 어떻게해야합니까? 나는 API를 거쳐 신뢰 점수를주는 특정 API를 찾을 수 없었다. 누구든지 나를 도울 수 있습니까?
기차 프로그램은
public static void main(final String[] args) throws Exception {
JavaSparkContext jsc = null;
int salesIndex = 1;
try {
...
SparkConf sparkConf =
new SparkConf().setAppName("Hackathon Train").setMaster(
sparkMaster);
jsc = new JavaSparkContext(sparkConf);
...
JavaRDD<String> trainRDD = jsc.textFile(basePath + "old-leads.csv").cache();
final String firstRdd = trainRDD.first().trim();
JavaRDD<String> tempRddFilter =
trainRDD.filter(new org.apache.spark.api.java.function.Function<String, Boolean>() {
private static final long serialVersionUID =
11111111111111111L;
public Boolean call(final String arg0) {
return !arg0.trim().equalsIgnoreCase(firstRdd);
}
});
...
JavaRDD<String> featureRDD =
tempRddFilter
.map(new org.apache.spark.api.java.function.Function() {
private static final long serialVersionUID =
6948900080648474074L;
public Object call(final Object arg0)
throws Exception {
...
StringBuilder featureSet =
new StringBuilder();
...
featureSet.append(i - 2);
featureSet.append(COLON);
featureSet.append(strVal);
featureSet.append(SPACE);
}
return featureSet.toString().trim();
}
});
List<String> featureList = featureRDD.collect();
String featureOutput = StringUtils.join(featureList, NEW_LINE);
String filePath = basePath + "lr.arff";
FileUtils.writeStringToFile(new File(filePath), featureOutput,
"UTF-8");
JavaRDD<LabeledPoint> trainingData =
MLUtils.loadLibSVMFile(jsc.sc(), filePath).toJavaRDD().cache();
final LogisticRegressionModel model =
new LogisticRegressionWithLBFGS().setNumClasses(18).run(
trainingData.rdd());
ByteArrayOutputStream baos = new ByteArrayOutputStream();
ObjectOutputStream oos = new ObjectOutputStream(baos);
oos.writeObject(model);
oos.flush();
oos.close();
FileUtils.writeByteArrayToFile(new File(basePath + "lr.model"),
baos.toByteArray());
baos.close();
} catch (Exception e) {
e.printStackTrace();
} finally {
if (jsc != null) {
jsc.close();
}
}
이
public static void main(final String[] args) throws Exception {
JavaSparkContext jsc = null;
int salesIndex = 1;
try {
...
SparkConf sparkConf =
new SparkConf().setAppName("Hackathon Predict").setMaster(sparkMaster);
jsc = new JavaSparkContext(sparkConf);
ObjectInputStream objectInputStream =
new ObjectInputStream(new FileInputStream(basePath
+ "lr.model"));
LogisticRegressionModel lrmodel =
(LogisticRegressionModel) objectInputStream.readObject();
objectInputStream.close();
...
JavaRDD<String> trainRDD = jsc.textFile(basePath + "new-leads.csv").cache();
final String firstRdd = trainRDD.first().trim();
JavaRDD<String> tempRddFilter =
trainRDD.filter(new org.apache.spark.api.java.function.Function<String, Boolean>() {
private static final long serialVersionUID =
11111111111111111L;
public Boolean call(final String arg0) {
return !arg0.trim().equalsIgnoreCase(firstRdd);
}
});
...
final Broadcast<LogisticRegressionModel> broadcastModel =
jsc.broadcast(lrmodel);
JavaRDD<String> featureRDD =
tempRddFilter
.map(new org.apache.spark.api.java.function.Function() {
private static final long serialVersionUID =
6948900080648474074L;
public Object call(final Object arg0)
throws Exception {
...
LogisticRegressionModel lrModel =
broadcastModel.value();
String row = ((String) arg0);
String[] featureSetArray =
row.split(CSV_SPLITTER);
...
final Vector vector =
Vectors.dense(doubleArr);
double score = lrModel.predict(vector);
...
return csvString;
}
});
String outputContent =
featureRDD
.reduce(new org.apache.spark.api.java.function.Function2() {
private static final long serialVersionUID =
1212970144641935082L;
public Object call(Object arg0, Object arg1)
throws Exception {
...
}
});
...
FileUtils.writeStringToFile(new File(basePath
+ "predicted-sales-data.csv"), sb.toString());
} catch (Exception e) {
e.printStackTrace();
} finally {
if (jsc != null) {
jsc.close();
}
}
}
}
예를 들어 주시겠습니까? spark docs에서 LogisticRegression.java를 통해 그 메소드를 찾을 수 없었습니다. – ArinCool
** raw2probabilityInPlace ** 및 ** raw2prediction ** 기능을 찾을 수 없습니다. 좀 도와 줄 수있어? – ArinCool
org.apache.spark.ml.classificationLogisticRegressionModel 클래스에 있습니다. 더 간단하다면 다른 이름으로 복사본을 만들고이 함수를 공개 할 수 있습니다. –