랜덤 포레스트
1. 개념
랜덤포레스트는 의사결정트리를 앙상블 기법으로 학습시킨 모델로, 오버피팅을 방지하기 위해 고안된 방법이다.
랜덤포레스트 역시 분류와 회귀분석 모두에 사용될 수 있다.
앙상블 학습 기법이란 쉽게 말해, 하나의 예측에 여러가지 알고리즘을 투표를 거쳐 사용하는 것이다.
즉, 랜덤포레스트에서 사용하는 앙상블은 여러 개의 서로 다른 의사결정트리를 만들고, 투표를 통해 결과를 얻어내는 것이다.
하지만 의사결정트리 개념에서 정리한 것 처럼 의사결정 트리를 생성하게 된다면,
트리들이 모두 똑같거나 비슷한 트리로 생성될 것이다. 배심원 10명을 앉혀놨는데 모두 쌍둥이인 꼴이다(..)
이때 배심원의 출신을 랜덤하게 선정하는 방법이 바로 배깅(bagging == bootstrap aggregating)이다.
배깅은 동일한 전체 데이터를 이용하여 개별 트리들을 학습하는 것이 아니라,
전체 데이터에서 훈련용 데이터를 샘플링하여 개별 트리를 구성하는 것이다.
이 과정에서는 부트스트랩이라 불리는 것을 통해 샘플링을 하게 되는데,
이때 모집단에서 표본집단 추출하듯이 추출하는 것이 아니라 원 데이터와 같은 크기의 데이터를 추출하기 위해 노력한다.
이때 샘플링의 크기를 조절해서 모델을 튜닝할 수 있는데, 원 데이터와 비슷한 크기라고 항상 좋은 것은 아닌것 같다.
(통상적으로는 80% 정도라고 한다.)
랜덤 포레스트에서 분류기 훈련용 데이터의 샘플링 문제는 쉬운 문제가 아니므로 이정도(부트스트랩 과정)에서 정리하겠다.
배깅 외에도 랜덤 포레스트에 중요한 요소가 있는데, 바로 임의 노드 최적화이다.
각 트리의 노드들은 분할 함수(split function)을 토대로, 데이터가 오른쪽으로 향할 지 왼쪽으로 향할 지를 결정한다.
Neural Net의 activate function과 비슷한 개념이라고 생각할 수 있는데,
(이런 걸 보면 데이터 분석 알고리즘은 대개 비스무레하다)
이러한 분할 함수는 수식화된 매개변수를 가진다. 매개변수를 X라고 정의한다면,
X는 feature에 대한 정보(특징 Bagging이라고도 불린다), 임의성, 함수의 기하학적 특성 등을 포함하게 된다.
이때 임의성에 대한 파라미터를 최적화시켜서 최종적으로 파라미터 X를 최적화 시키는 것이 임의 노드 최적화이다.
최적화 방법 역시 고도화된 알고리즘의 영역이므로, 일단은 이 정도의 깊이를 아는것이 중요하다 할 수 있겠다.
추가적으로, 랜덤포레스트에는 부스팅 방법이 적용되기도 하는데,
부스팅 방법이란 회귀분석의 Lasso와 비슷한 역할을 하는 오분류 벌점 부과 방법이다.
벌점 부과 방법은 회귀나 분류 문제에 있어서 오류를 줄이는 하나의 방법론으로,
잘못 회귀되거나 분류된 데이터에 높은 가중치를 부과하여 정확도를 향상시키기 위한 방법이다.
회귀 분석에서는 Gradient Descent에서 미분값의 절대값이 매우 큰 것과 비슷한 상황이라고 생각하면 편리하다.
어찌되었든, 랜덤포레스트에서도 부스팅 방법은 모델링을 통한 예측 변수에 대해 오분류된 객체에 높은 가중치를 부여하는 것.
의사결정트리 학습방법 중, 가지치기의 개념과 비슷(역시 데이터 마이닝 알고리즘은 거기서 거기로 돌고 돈다)하다.
한 가지 활용 예시로, R에서 이 개념을 사용한다.
R에서는 변수 중요도 평가라는 파라미터로 활용할 수 있는데, 높은 가중치를 부여하는 대상이 반대일 뿐, 개념은 동일하다.
randomForest라는 함수에서 importance=T 라는 파라미터를 설정하여 불순도 개선에 기여하는 변수를 선택한다.
랜덤포레스트를 정리하자면, Bagging 계열의 알고리즘으로써, 그 기반을 의사결정트리에 둔 알고리즘이라고 할 수 있겠다.
성능 향상을 위해 배깅의 방식 뿐 아니라, 임의 노드 최적화와 부스팅을 추가적으로 수행하기도 한다.
랜덤포레스트는 상당히 성능이 좋은 알고리즘에 속하지만 의사결정트리와 달리, 분석과정이 철저히 black box 형태이다.
물론 프로그래머가 사용하기에는 매우 쉬운데, 조절해줄 파라미터로
tree depth(or count), data feature, sampling % 정도가 있겠다.
2. 적용
library(randomForest)
### mnist dataset에서 테스트
url = "https://github.com/ozt-ca/tjo.hatenablog.samples/raw/master/r_samples/public_lib/jp/mnist_reproduced"
# prac_test = read.csv(paste(url, "prac_test.csv", sep = "/"))
# prac_train = read.csv(paste(url, "prac_train.csv", sep = "/"))
short_prac_train = read.csv(paste(url, "short_prac_train.csv", sep = "/"))
short_prac_test = read.csv(paste(url, "short_prac_test.csv", sep = "/"))
str(short_prac_train)
# mnist randomforest 학습
train1 = short_prac_train
test1 = short_prac_test
train1$label = factor(short_prac_train$label)
test1$label = factor(short_prac_test$label)
r2 = randomForest(label~., train1)
pred2 = predict(r2, newdata = test1)
t2 = table(test1$label, pred2)
diag(t2)
sum(diag(t2)) / sum(t2)
# 픽셀을 0과 1로만 나누어서 학습
train2 = train1
train2[, -1] = round(train2[, -1]/255)
test2 = test1
test2[, -1] = round(test2[, -1]/255)
start1=Sys.time()
r3 = randomForest(label~., train2)
interval = Sys.time()-start1
pred3 = predict(r3, newdata = test2)
t3 = table(test1$label, pred3)
diag(t3)
sum(diag(t3)) / sum(t3)
# dna 데이터 연습
data("DNA")
View(DNA)
summary(DNA)
ind = sample(1:nrow(DNA)*0.7, nrow(DNA)*0.7, replace = F)
train_dna = DNA[ind,]
test_dna = DNA[-ind,]
start2=Sys.time()
r4 = randomForest(Class~., train_dna)
interval = Sys.time() - start2
pred4 = predict(r4, newdata = test_dna)
t4 = table(test_dna$Class, pred4)
diag(t4)
sum(diag(t4)) / sum(t4)
'Programming & Machine Learning > R X 머신러닝' 카테고리의 다른 글
R을 이용한 간단한 데이터 시각화 (0) | 2017.07.21 |
---|---|
R을 이용한 머신러닝 - 5 (의사결정 트리) (0) | 2017.07.14 |
R을 이용한 머신러닝 - 4 (로지스틱 회귀모델을 이용한 분류 문제) (0) | 2017.07.14 |
R을 이용한 머신러닝 - 3 (분류와 클러스터링 : K-NN, K-means) (0) | 2017.07.13 |
R을 이용한 통계분석 -5 (실제 데이터를 이용한 시계열 분석) (0) | 2017.07.11 |