2011-09-12 6 views
4

나는 회귀 분석을위한 간단한 그래디언트 부스트 알고리즘을 구현하려고합니다. 이것은 지금까지 생각해 낸 것입니다 만, 예상했던 것처럼 오류가 정체되지 않습니다. 어떤 제안?간단한 그래디언트 부스팅 알고리즘

data("OrchardSprays") 
niter <- 10 
learn <- 0.05 
y  <- OrchardSprays$decrease 
yhat <- rep(0,nrow(OrchardSprays)) 
weight <- rep(1,nrow(OrchardSprays)) 
loss <- function(y,yhat) (y - yhat)^2 

for (i in seq(niter)) 
{ 
    model <- lm(decrease~.,weights=weight,data=OrchardSprays) 
    yhat <- yhat + weight * (predict(model) - yhat)/i 
    error <- mean(loss(y,yhat)) 
    weight <- weight + learn * (loss(y,yhat) - error)/error 
    cat(i,"error:",error,"\n") 
} 

출력 :

1 error: 319.5881 
2 error: 318.6175 
3 error: 317.9368 
4 error: 317.6112 
5 error: 317.6369 
6 error: 317.9772 
7 error: 318.5833 
8 error: 319.4047 
9 error: 320.3939 
10 error: 321.5086 

답변

2

나는 나이에 체중 최적화를 작성하는 데하지 인정 것이다, 그래서 기지에서 할 수있다. 모든 반복마다 yhat 벡터를 녹음하여 시작합니다. 값이 흔들리거나 0으로 사라지는 지 확인하십시오. (도움이나 아픔이 i으로 나누어 져 있는지 확실하지 않으므로)
마찬가지로, lm()의 각 반복에서 R^2 값을 살펴보십시오. 이들이 1에 매우 가깝다면, 현재 규정 된 민감도 한계 인 lm()을 사용했을 것입니다.

구현중인 방정식과 코드를 비교할 수 있도록 알고리즘 소스를 제공하면 도움이 될 것입니다.

업데이트 : wikipedia를 살펴보면 "여러 오픈 소스 R 패키지를 사용할 수 있습니다 : gbm, [6] mboost, gbev." 소스 코드를 포함하여 패키지가 사용자의 요구를 충족시키는 지 확인하기 위해이 패키지를 연구 할 것을 강력히 권장합니다.

+0

알고리즘은 회귀 분석을위한 알고리즘을 강화 friedmans 그라데이션, 그 공식적 재귀 피팅 잔류으로 정의하지만 난 경우 가중치를 조정하여 구현해야한다. 나는 내가 무엇을 가지고 있는지에 대해 부스트 된 나무에 사용되는 적응과 가깝다는 것을 확신한다. – darckeen

+0

@ darckeen, 그냥 nitpick되고있다. 그라디언트 증폭 알고리즘이 될 수 없습니다. 다시 가중치를 적용한 선형 함수는 여전히 선형이기 때문에. 귀하의 치료는 훨씬 더 부스트처럼됩니다. –

2

각 단계에서 데이터를 무작위로 샘플링 해 보았습니까? 그렇다면 예제의 절반 만 현재 학습자에게 보여 줬습니까? 매번 전체 샘플을 사용하면 지나치게 좋지 않은 결과를 얻을 수 있다고 생각합니다. 또한 선형 모델을 높이는 것이 (분산이 낮은) 도움이된다고 확신하지 못합니다.

+0

난 그냥 간단하게 유지하고 서브 샘플링을 사용하고 싶지 않았지만 당신이 언급 한 overfitting 문제는 문제가된다. 알 고가 제대로 구현 되었다면 결국 데이터를 거의 완벽하게 채워야하지만 그렇지 않다. – darckeen

4

이것이 도움이되는지 확실하지 않지만 시작 가중치를 낮추고 반복 횟수를 늘리면 오류가 0에 훨씬 가깝습니다. 그러나 여전히 고원 상태는 아닙니다 (반복 103에서 오류가 다시 시작됩니다). 또한 lm 함수로 생성 된 missing or negative weights not allowed 오류를 보완하기 위해 weight <- ifelse(weight < 0.0, 0.0, weight) 문을 추가했습니다.

data("OrchardSprays") 
niter <- 105 
learn <- 0.05 
y  <- OrchardSprays$decrease 
yhat <- rep(0.0,nrow(OrchardSprays)) 
weight <- rep(0.2,nrow(OrchardSprays)) 
loss <- function(y,yhat) (y - yhat)^2 

error <- mean(loss(y,yhat)) 
cat("initial error:",error,"\n") 

for (i in seq(niter)) 
{ 
    model <- lm(decrease~.,weights=weight,data=OrchardSprays) 
    yhat <- yhat + weight * (predict(model) - yhat)/i 
    error <- mean(loss(y,yhat)) 
    weight <- weight + learn * (loss(y,yhat) - error)/error 
    weight <- ifelse(weight < 0.0, 0.0, weight) 
    cat(i,"error:",error,"\n") 
} 

OUPUT :

initial error: 3308.922 

1 error: 2232.762 
2 error: 1707.971 
3 error: 1360.834 
4 error: 1110.503 
5 error: 921.2804 
6 error: 776.4314 
7 error: 663.5947 
8 error: 574.2603 
9 error: 502.2455 
10 error: 443.2639 
11 error: 394.2983 
12 error: 353.1736 
13 error: 318.2869 
14 error: 288.4326 
15 error: 262.6827 
16 error: 240.3086 
17 error: 220.7289 
18 error: 203.4741 
19 error: 188.1632 
20 error: 174.4876 
21 error: 162.1971 
22 error: 151.0889 
23 error: 140.9982 
24 error: 131.7907 
25 error: 123.3567 
26 error: 115.6054 
27 error: 108.4606 
28 error: 101.8571 
29 error: 95.73825 
30 error: 90.05343 
31 error: 84.75755 
32 error: 79.81715 
33 error: 75.19618 
34 error: 70.86006 
35 error: 66.77859 
36 error: 62.92584 
37 error: 59.28014 
38 error: 55.8239 
39 error: 52.54784 
40 error: 49.44272 
41 error: 46.49915 
42 error: 43.71022 
43 error: 41.07119 
44 error: 38.57908 
45 error: 36.23237 
46 error: 34.03907 
47 error: 32.00558 
48 error: 30.12923 
49 error: 28.39891 
50 error: 26.80582 
51 error: 25.33449 
52 error: 23.97077 
53 error: 22.70327 
54 error: 21.52714 
55 error: 20.43589 
56 error: 19.42552 
57 error: 18.48629 
58 error: 17.60916 
59 error: 16.78986 
60 error: 16.02315 
61 error: 15.30303 
62 error: 14.62663 
63 error: 13.99066 
64 error: 13.39205 
65 error: 12.82941 
66 error: 12.30349 
67 error: 11.811 
68 error: 11.34883 
69 error: 10.91418 
70 error: 10.50448 
71 error: 10.11723 
72 error: 9.751116 
73 error: 9.405197 
74 error: 9.076175 
75 error: 8.761231 
76 error: 8.458107 
77 error: 8.165144 
78 error: 7.884295 
79 error: 7.615498 
80 error: 7.356618 
81 error: 7.106186 
82 error: 6.86324 
83 error: 6.627176 
84 error: 6.39777 
85 error: 6.17544 
86 error: 5.961616 
87 error: 5.756781 
88 error: 5.561157 
89 error: 5.375131 
90 error: 5.19945 
91 error: 5.034539 
92 error: 4.880956 
93 error: 4.739453 
94 error: 4.610629 
95 error: 4.495216 
96 error: 4.393571 
97 error: 4.306144 
98 error: 4.233587 
99 error: 4.176799 
100 error: 4.136802 
101 error: 4.114575 
102 error: 4.111308 
103 error: 4.1278 
104 error: 4.164539 
105 error: 4.221389 
관련 문제