R을 이용한 의사결정 트리
1.1 개념
의사결정 트리는 분류 문제이기도 하고, 예측 문제이기도 한데, 주로 사용되는 분야는 지도 분류 학습이다. 분기의 기준이
범주형이냐, 연속형이냐에 따라 결과가 분류나 회귀 모두가 가능하다. terminal node의 값이 범주형이라면 분류, 연속형이라면
예측이라는 것이다.
어찌되었든 의사결정 트리의 메인 아이디어는 데이터를 스무고개처럼 분석하여 최종적 판단에 이르는 패턴을 찾아내는 것이다.
다음의 그림은 의사결정트리를 나타내는 대표적인 예시인 타이타닉호 생존 결정 트리이다.
결정트리에서 중요한 내용으로는 다음과 같은 크게 3가지가 있다.
1. 노드 분기 방법
2. 모델 학습 과정
3. 가지치기
1.2 노드 분기 방법
의사결정트리는 한 분기때마다 변수 영역을 두개로 구분하게 된다. 컴퓨터 자료구조에서의 binary 트리와 유사한 형태이다.
각 depth에서 의사를 결정하는 기준은, 데이터 마이닝이나 통계 영역에서는 순도(homogeneity), 불순도(impurity), 등으로
평가하게 된다.
만약 분기 후에 불순도와 불확실성이 감소한다면, 정보획득이 일어난 것이며, 즉 올바른 분기라고 할 수 있는 것이다.
불순도에 대해 쉽게 설명하자면 만약 100개의 데이터 중에 남녀의 비율이 5:5 라면 굉장히 불순한(?) 데이터인 것이고,
분기의 기준이 성별일 때 분기 후에 각 node에서의 불순도는 0이 될 것이다. 그렇다면 이 분기는 올바르게 진행된 것이다.
의사결정트리에서는 불순도나 불확실성이 최대로 감소하도록 학습을 진행하게 되는데, 이 과정은 뒤에 설명할 재귀적 분기
(recursive partitioning)로 진행된다.
불순도 계산 지표는 대표적으로 엔트로피(entropy), 지니 지수(gini index), 오분류 오차(misclassfication error)이다.
지니 불순도는 어떤 집합에서 한 항목을 뽑아 무작위로 라벨을 추정할 때 틀릴 확률을 말한다. 집합에 있는 항목이 모두 같다면
지니 불순도는 최솟값(0)을 갖게 되며 이 집합은 완전히 순수하다고 할 수 있다.
엔트로피 역시 비슷한 논리의 결과로 나온 수치로, 라벨 추정과 관련된 확률이다.
지니 불순도와 엔트로피는 불순도에 관한 굉장히 비슷한 (거의 동일한)척도이다. 하지만 분류 오차의 경우는 약간 다르다.
분류 오차는 노드들의 분류 확률에 대한 변화는 덜 민감하기 때문에 가지치기를 위해서 유용한 기준이 된다.
즉, 의사결정트리를 성장시키고 학습시키는 데에는 지니, 엔트로피가 적당한 기준이고 가지치기를 위해서는 분류 오차가 유용한
기준이 된다는 것이다.
- 엔트로피 공식
- 지니 지수 공식
1.3 모델 학습 과정
결정 트리에서의 학습은, 학습에 사용되는 자료 집합을 적절하게 분기시키는 것이다. 노드를 분기함에 있어서 '순환 분할'이라
불리는 재귀적 분기(recursive partitioning) 과정이 진행된다. 분할로 인해 더 이상 나은 결과가 나오지 않을 때까지 반복된다.
이 과정은 Greedy한 알고리즘이라고 할 수 있다.
재귀적 분기 과정은 우선 데이터를 한 변수를 기준으로 정렬한 후, 가능한 모든 분기점에 대해 정보획득을 조사한다.
예를 들어 데이터의 개수가 500개라면, 1:499, 2:498, 3:497 ... 등으로 분기점을 선택한 후 각각의 정보획득을 계산.
그 다음, 다른 변수를 기준으로 정렬한 뒤 위 과정을 반복하고, 최종적으로 나온 모든 경우의 수 가운데서 가장 정보획득이 큰
변수와 지점을 선택하여 분기를 하게된다. 이 과정을 하위노드도 재귀적으로 반복한다고 하여 재귀적 분기라고 불리는 것이다.
1.4 가지치기
결정 트리에서 depth가 불필요하게 크거나(일반적인 depth는 3~4개가 적당), leaf node가 많아지는 경우 데이터가 트리에 과적합
되었을 가능성이 매우 높다. 이는 새로운 자료에 적용할 때 예측오차가 매우 클 가능성이 있다는 것을 의미한다.
따라서 가지치기(pruning)을 통해 모든 terminal node의 불순도가 0인 풀 트리(full tree) 상태를 방지해야 한다.
결정 트리에서의 분기 수가 증가하면, 모델 학습 과정에서는 오분류율이 감소할 수 있으나, 일정 수준 이상을 넘어서거나 검증
데이터를 넣기 시작하면 오분류율이 증가하게 된다. 이는 새로운 데이터에 대한 예측 성능인 일반화 능력이 떨어지는 것을
의미한다. 그래서 가지치기를 하기에 적당한 지점은, 분기 수가 증가하는 지점이면서 동시에 input 검증 데이터의 오분류율이
증가하는 시점이다.
가지치기를 하기에 적당한 지점을 찾는 것 역시 Greedy한 알고리즘을 사용한다. 잘라낼 가지들, 즉 subtree를 선택해 나가면서
이때의 adjusted error rate(aer)를 측정한다. 어떤 subtree를 잘라냈을 때, 전체 트리의 aer보다 작거나 같으면, 가지치기가
가능한 후보가 되는 것이다.
### data
data("Vowel")
set.seed(100)
ind2 = sample(1:nrow(Vowel),
nrow(Vowel)*0.7,
replace = F)
train = Vowel[ind2, -1]
test = Vowel[-ind2, -1]
### pruning
install.packages("tree")
library(tree)
library(MASS)
ir.tr = tree(Class~., train) # 가지치기 전
plot(ir.tr)
text(ir.tr, all = T)
plot(prune.misclass(ir.tr))
# 트리의 depth가 적당하면서 misclass가 낮은 지점은 10~15 사이.
ir.tr1 = prune.misclass(ir.tr, best=14) # 가지치기 후
plot(ir.tr1)
text(ir.tr1, all = T)
'Programming & Machine Learning > R X 머신러닝' 카테고리의 다른 글
R을 이용한 간단한 데이터 시각화 (0) | 2017.07.21 |
---|---|
R을 이용한 머신러닝 - 6 (랜덤 포레스트 개념과 적용) (0) | 2017.07.18 |
R을 이용한 머신러닝 - 4 (로지스틱 회귀모델을 이용한 분류 문제) (0) | 2017.07.14 |
R을 이용한 머신러닝 - 3 (분류와 클러스터링 : K-NN, K-means) (0) | 2017.07.13 |
R을 이용한 통계분석 -5 (실제 데이터를 이용한 시계열 분석) (0) | 2017.07.11 |