그래디언트 하강을 사용한 선형 회귀를 구현하기 위해 다음 Java 프로그램을 작성했습니다. 코드가 실행되지만 결과가 정확하지 않습니다. y의 예측 된 값은 y의 실제 값에 가깝지 않습니다. 예를 들어, x = 75 일 때 예상 y = 208이지만 출력은 y = 193.784입니다.그래디언트 디센트가있는 선형 회귀
class LinReg {
double theta0, theta1;
void buildModel(double[] x, double[] y) {
double x_avg, y_avg, x_sum = 0.0, y_sum = 0.0;
double xy_sum = 0.0, xx_sum = 0.0;
int n = x.length, i;
for(i = 0; i < n; i++) {
x_sum += x[i];
y_sum += y[i];
}
x_avg = x_sum/n;
y_avg = y_sum/n;
for(i = 0; i < n; i++) {
xx_sum += (x[i] - x_avg) * (x[i] - x_avg);
xy_sum += (x[i] - x_avg) * (y[i] - y_avg);
}
theta1 = xy_sum/xx_sum;
theta0 = y_avg - (theta1 * x_avg);
System.out.println(theta0);
System.out.println(theta1);
gradientDescent(x, y, 0.1, 1500);
}
void gradientDescent(double x[], double y[], double alpha, int maxIter) {
double oldtheta0, oldtheta1;
oldtheta0 = 0.0;
oldtheta1 = 0.0;
int n = x.length;
for(int i = 0; i < maxIter; i++) {
if(hasConverged(oldtheta0, theta0) && hasConverged(oldtheta1, theta1))
break;
oldtheta0 = theta0;
oldtheta1 = theta1;
theta0 = oldtheta0 - (alpha * (summ0(x, y, oldtheta0, oldtheta1)/(double)n));
theta1 = oldtheta1 - (alpha * (summ1(x, y, oldtheta0, oldtheta1)/(double)n));
System.out.println(theta0);
System.out.println(theta1);
}
}
double summ0(double x[], double y[], double theta0, double theta1) {
double sum = 0.0;
int n = x.length, i;
for(i = 0; i < n; i++) {
sum += (hypothesis(theta0, theta1, x[i]) - y[i]);
}
return sum;
}
double summ1(double x[], double y[], double theta0, double theta1) {
double sum = 0.0;
int n = x.length, i;
for(i = 0; i < n; i++) {
sum += (((hypothesis(theta0, theta1, x[i]) - y[i]))*x[i]);
}
return sum;
}
boolean hasConverged(double oldTheta, double newTheta) {
return ((newTheta - oldTheta) < (double)0);
}
double predict(double x) {
return hypothesis(theta0, theta1, x);
}
double hypothesis(double thta0, double thta1, double x) {
return (thta0 + thta1 * x);
}
}
public class LinearRegression {
public static void main(String[] args) {
//Height data
double x[] = {63.0, 64.0, 66.0, 69.0, 69.0, 71.0, 71.0, 72.0, 73.0, 75.0};
//Weight data
double y[] = {127.0, 121.0, 142.0, 157.0, 162.0, 156.0, 169.0, 165.0, 181.0, 208.0};
LinReg model = new LinReg();
model.buildModel(x, y);
System.out.println("----------------------");
System.out.println(model.theta0);
System.out.println(model.theta1);
System.out.println(model.predict(75.0));
}
}
은 네는 회귀 및 그라데이션 하강의 개념과 혼동했다. 답변 해주셔서 감사합니다. – noober
@noober 도와 드리겠습니다. ML 및 회귀 분석을위한 훌륭한 Java 라이브러리를 원한다면 Weka를 통해 Weka를 추천합니다. –