2016-09-26 3 views
3

내 트리에 정보를 추가하고 싶습니다. 나는 나무를 실행할 수 있습니다트리에 정보 추가하기 - Rpart

library(rpart) 
library(rpart.plot) 
set.seed(1) 
mydb<-data.frame(results=rnorm(1000,0,1),expo=runif(1000),var1=sample(LETTERS[1:4],1000,replace=T), 
       var2=sample(LETTERS[5:6],1000,replace=T),var3=sample(LETTERS[20:25], 1000,replace=T)) 

: 이제 나는이 같은 데이터베이스가 예를 들어 가정 해 봅시다

mytree<-rpart(results~var1+var2+var3,data=mydb,cp=0) 
pfit<- prune(mytree, cp=mytree$cptable[4,"CP"]) 
prp(pfit,type=1,extra=100,fallen.leaves=F,shadow.col="darkgray",box.col=rgb(0.8,0.9,0.8)) 

결과는 다음과 같습니다 enter image description here

을 그리고 그것은 나를 위해 확인하지만 각 잎의 평균 노출량을 알고 싶다고합시다.

node.fun1 <- function(x, labs, digits, varlen) 
{ 
    paste("Weight \n",x$frame$wt) 
} 

prp(pfit,type=1,extra=100,fallen.leaves=F,shadow.col="darkgray",box.col=rgb(0.8,0.9,0.8),node.fun = node.fun1) 

enter image description here

그러나이 프레임에서 계산 된 결과를 것 경우에만 작동합니다

나는 내가하는 기능, 예를 들어, PRP 각 잎의 무게를 몇 가지 정보를 추가 할 수 있습니다 알고 rpart 함수.

내 질문 :

어떻게 평균 노출, 또는 사용자 정의 지표를 계산하고 표 frame에 추가 할 다른 기능처럼, 플롯에 사용자 지정 정보를 추가 할 수 있습니까?

+1

당신은 rpart 부분을 잊고, 또는 사용자가 만든 그러나'네 말이 맞아 – rawr

+0

mytree'! 편집 됨 : – Arault

답변

1

이것은 정말 좋은 일이며, 이것이 옵션인지 몰랐습니다.

모든 작업이 각 노드에서 사용 된 원본 데이터의 부분 집합을 얻는 것 같습니다. 이것은 터미널 노드에서 쉽지만, 잎뿐만 아니라 모든 노드에서 사용 된 데이터 행을 식별하는 직접적인 방법을 찾지 못했습니다. 누군가 쉬운 방법을 알고 있다면, 나는 그것을 듣고 싶습니다.

enter image description here

library('rpart.plot') 
set.seed(1) 
mydb<-data.frame(results=rnorm(1000,0,1),expo=runif(1000),var1=sample(LETTERS[1:4],1000,replace=T), 
       var2=sample(LETTERS[5:6],1000,replace=T),var3=sample(LETTERS[20:25], 1000,replace=T)) 
mytree<-rpart(results~var1+var2+var3,data=mydb,cp=0) 
pfit<- prune(mytree, cp=mytree$cptable[4,"CP"]) 

rpart.plot(pfit) 
x 소요 새로운 기능, 피팅 rpart (나는 다른 인수로 보지 못했지만, 네트 도움이 될한다)의 결과를 정의합니다.

x$frame의 모든 행에 대해 요약 통계를 계산하는 데 필요한 데이터를 가져와야합니다. 불행하게도 x$where은 각 관찰이 속한 터미널 노드 만 알려줍니다. 그래서 각 노드 번호를, 우리는 기본 데이터를 얻고, 당신이 그것으로 원하는 건 뭐든지 할 subset.rpart를 사용

f <- function(x, labs, digits, varlen) { 
    nodes <- as.integer(rownames(x$frame)) 
    z <- sapply(nodes, function(y) { 
    data <- subset.rpart(x, y) 
    c(mean = mean(data$expo), nrow(data), nrow(data)/length(x$where) * 100) 
    }) 
    sprintf('Mean expo: %.2f\nn=%.0f (%.0f%%)', z[1, ], z[2, ], z[3, ]) 
} 

prp(pfit, type=1, extra=100, fallen.leaves=FALSE, 
    shadow.col="darkgray", box.col=rgb(0.8,0.9,0.8), 
    node.fun = f) 
노드 수를 소요의 하위 집합을 반환 작업이 subset.rpart에 의해 수행되었다

enter image description here

data이 노드에서 사용되었습니다.

subset.rpart <- function(tree, node = 1L) { 
    ## returns subset of tree$call$data used on any node 
    data <- eval(tree$call$data, parent.frame(1L)) 
    wh <- sapply(as.integer(rownames(tree$frame)), parent) 
    wh <- unique(unlist(wh[sapply(wh, function(x) node %in% x)])) 
    data[rownames(tree$frame)[tree$where] %in% wh[wh >= node], ] 
} 

parent <- function(x) { 
    ## returns vector of parent nodes 
    if (x[1] != 1) 
    c(Recall(if (x %% 2 == 0L) x/2 else (x - 1)/2), x) else x 
} 

테스트

## tests 
dim(subset.rpart(pfit, 1)) == dim(mydb) 
# [1] TRUE TRUE 

## terminal nodes 
nodes <- as.integer(rownames(pfit$frame[pfit$frame$var %in% '<leaf>', ])) 
sum(sapply(nodes, function(x) nrow(subset.rpart(pfit, x)))) == nrow(mydb) 
# [1] TRUE 
+0

감사합니다. subset.rpart & parent 코드를 작성 했습니까? 그건 내가 뭔가 찾고 있었어 – Arault

+0

예, 불행히도 나는 [이 질문/답변] (http : // stackoverflow.com/questions/36748531/getting-observations-in-a-rparts-node-i-e-cart)가 너무 늦었습니다. 당신은 당신이 더 좋아하는 대안을 찾을 수 있습니다. – rawr

+0

실제로 나는 당신의 버전을 선호합니다. 당신의 코드는 정말로 멋지고 읽기 쉽습니다. 당신의 모든 일에 다시 감사합니다. – Arault