Skip to content

Instantly share code, notes, and snippets.

@ishuca
Created August 27, 2017 11:19
Show Gist options
  • Select an option

  • Save ishuca/cb762106cbc48ebed5a8f9ed8404f69c to your computer and use it in GitHub Desktop.

Select an option

Save ishuca/cb762106cbc48ebed5a8f9ed8404f69c to your computer and use it in GitHub Desktop.
Why we use Stacking with Nested CV
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Nested CV(Cross Validation) Stacking"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 이 튜토리얼은 아래 링크 20 슬라이드부터 23 슬라이드의 있는 \n",
"### 2 stage 이상의 Stacking 에서의 Information leakage (정보 노출) 에 대해 논의하고 있습니다.\n",
"https://www.slideshare.net/odsc/owen-zhangopen-sourcetoolsanddscompetitions1"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"예제 데이터로는 iris 데이터를 사용하겠습니다."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from sklearn.datasets import load_iris"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"tqdm 은 Progress bar를 볼 수 있는 유용한 라이브러리 입니다."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from tqdm import tqdm"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"데이터를 불러옵니다."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"dict_keys(['target', 'DESCR', 'feature_names', 'target_names', 'data'])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = load_iris()\n",
"data.keys()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"타겟과 특성을 나눠줍니다."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [],
"source": [
"target = data['target']"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"features = data['data']"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"iris 데이터는 3가지 붓꽃의 분류를 위한 데이터로써, Trainset 과 Testset 을 나눌 때 3 클래스의 비율이 비슷하도록 층화추출 하도록 하겠습니다."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from sklearn.model_selection import StratifiedKFold, train_test_split, cross_val_score"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"재현성을 위해 random_state를 고정시켰습니다.\n",
"\n",
"iris 데이터는 총 150 개 관찰치로써, 이 중 20개의 관찰치를 학습에 사용하고, 130개 데이터에 테스트 해봄으로써\n",
"\n",
"일반화 성능이 높은 Stacking의 효과를 좀 더 명확히 보여줍니다."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"Xtrain, Xtest, ytrain, ytest = train_test_split(features, target, stratify=target, random_state=1234, test_size = 130)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=1234)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"일반적인 CV 는 StratifiedKFold 를 이용해 split 한 후 여기서 나오는 각각의 Fold Index 를 이용해\n",
"\n",
"Trainset 으로 학습시킨 후, Testset 에 적용을 하게 됩니다.\n",
"\n",
"다만 여기서는 아래와 같은 방법을 취하기 위해 각 폴드를 따로 따로 기억하겠습니다."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<img src=\"https://image.slidesharecdn.com/owenzhangopensourcetoolsanddscompetitions1-150604182708-lva1-app6891/95/open-source-tools-data-science-competitions-20-638.jpg?cb=1433444974\">\n",
"### Stacking은 위와 같이 Base Model 들의 Prediction 을 Meta Feature 로 생각하고, 그 예측값들 사이의 최적 결합을 학습하는 또다른 모델 (Stacker) 를 두는 것입니다.\n",
"\n",
"<img src=\"https://image.slidesharecdn.com/owenzhangopensourcetoolsanddscompetitions1-150604182708-lva1-app6891/95/open-source-tools-data-science-competitions-21-638.jpg?cb=1433444974\">\n",
"### Overfitting 을 막기 위해서는 Out of sample 을 사용해 학습한 Prediction 을 사용해야 하는데 일반적으로 CV는 이를 만족합니다.\n",
"\n",
"<img src=\"https://image.slidesharecdn.com/owenzhangopensourcetoolsanddscompetitions1-150604182708-lva1-app6891/95/open-source-tools-data-science-competitions-22-638.jpg?cb=1433444974\">\n",
"### 하지만 최근 추세는 2 stage 이상의 모델들이 점점 많아지고 있습니다.\n",
"\n",
"<img src=\"https://image.slidesharecdn.com/owenzhangopensourcetoolsanddscompetitions1-150604182708-lva1-app6891/95/open-source-tools-data-science-competitions-23-638.jpg?cb=1433444974\">\n",
"\n",
"### 2 stage 이상 Stacking에서 위와 같이 일반적인 CV로 Stacking 구현을 하게 되면 아래와 같은 Information Leakage 가 발생하게 됩니다.\n",
"\n",
"<img src=\"https://image.slidesharecdn.com/owenzhangopensourcetoolsanddscompetitions1-150604182708-lva1-app6891/95/open-source-tools-data-science-competitions-24-638.jpg?cb=1433444974\">\n",
"\n",
"### 4번째 fold를 예측하기 위해 1,2,3 folds 가 사용되었는데 이 1번째 fold 를 예측하기 위해 2,3,4 folds 를 사용해 학습했던 것이 2 stage 에서 4 fold 를 사용할 때 간접적인 정보 누출 (indirected information leakage)가 발생한다고 합니다.\n",
"\n",
"### 그래서 예측 값을 만들기 위해 사용하는 fold에는 다음 예측 fold가 사용되지 않도록 학습하는 Nested CV 가 필요하게 됩니다."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"[array([ 0, 2, 4, 6, 8, 14]),\n",
" array([ 3, 10, 11, 13, 17]),\n",
" array([ 7, 12, 19]),\n",
" array([ 5, 16, 18]),\n",
" array([ 1, 9, 15])]"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"folds = [x[1] for x in list(skf.split(Xtrain,ytrain))]\n",
"folds"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Stacking 에 사용할 Base Learner 로는 Randomforest, ExtraTrees, GradientBoosting 을 사용합니다"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier, GradientBoostingClassifier"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"clfs = [RandomForestClassifier(n_estimators=100, n_jobs=-1,\n",
"criterion='gini'),\n",
" RandomForestClassifier(n_estimators=100, n_jobs=-1,\n",
"criterion='entropy'),\n",
" ExtraTreesClassifier(n_estimators=100, n_jobs=-1,\n",
"criterion='gini'),\n",
" ExtraTreesClassifier(n_estimators=100, n_jobs=-1,\n",
"criterion='entropy'),\n",
" GradientBoostingClassifier(learning_rate=0.05,\n",
"subsample=0.5, max_depth=6, n_estimators=50)]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"아래의 주석처리 된 코드는 Nested CV를 하지 않는 일반적인 Stacking 구현입니다."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"0it [00:00, ?it/s]\n",
" 0%| | 0/5 [00:00<?, ?it/s]\n",
" 20%|████████████████▊ | 1/5 [00:00<00:01, 3.22it/s]\n",
" 40%|█████████████████████████████████▌ | 2/5 [00:00<00:00, 3.21it/s]\n",
" 60%|██████████████████████████████████████████████████▍ | 3/5 [00:00<00:00, 3.14it/s]\n",
" 80%|███████████████████████████████████████████████████████████████████▏ | 4/5 [00:01<00:00, 3.09it/s]\n",
"100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:01<00:00, 3.11it/s]\n",
"1it [00:01, 1.62s/it]\n",
" 0%| | 0/5 [00:00<?, ?it/s]\n",
" 20%|████████████████▊ | 1/5 [00:00<00:01, 3.09it/s]\n",
" 40%|█████████████████████████████████▌ | 2/5 [00:00<00:00, 3.20it/s]\n",
" 60%|██████████████████████████████████████████████████▍ | 3/5 [00:00<00:00, 3.22it/s]\n",
" 80%|███████████████████████████████████████████████████████████████████▏ | 4/5 [00:01<00:00, 3.31it/s]\n",
"100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:01<00:00, 3.36it/s]\n",
"2it [00:03, 1.58s/it]\n",
" 0%| | 0/5 [00:00<?, ?it/s]\n",
" 20%|████████████████▊ | 1/5 [00:00<00:01, 3.57it/s]\n",
" 40%|█████████████████████████████████▌ | 2/5 [00:00<00:00, 3.58it/s]\n",
" 60%|██████████████████████████████████████████████████▍ | 3/5 [00:00<00:00, 3.58it/s]\n",
" 80%|███████████████████████████████████████████████████████████████████▏ | 4/5 [00:01<00:00, 3.56it/s]\n",
"100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:01<00:00, 3.55it/s]\n",
"3it [00:04, 1.53s/it]\n",
" 0%| | 0/5 [00:00<?, ?it/s]\n",
" 20%|████████████████▊ | 1/5 [00:00<00:01, 3.36it/s]\n",
" 40%|█████████████████████████████████▌ | 2/5 [00:00<00:00, 3.41it/s]\n",
" 60%|██████████████████████████████████████████████████▍ | 3/5 [00:00<00:00, 3.47it/s]\n",
" 80%|███████████████████████████████████████████████████████████████████▏ | 4/5 [00:01<00:00, 3.50it/s]\n",
"100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:01<00:00, 3.53it/s]\n",
"4it [00:05, 1.50s/it]\n",
" 0%| | 0/5 [00:00<?, ?it/s]\n",
" 40%|█████████████████████████████████▌ | 2/5 [00:00<00:00, 15.14it/s]\n",
" 80%|███████████████████████████████████████████████████████████████████▏ | 4/5 [00:00<00:00, 15.07it/s]\n",
"5it [00:06, 1.15s/it]\n"
]
}
],
"source": [
"# not nested\n",
"\n",
"class_num = len(np.unique(ytrain))\n",
"\n",
"dataset_blend_train_not_nested = np.zeros((Xtrain.shape[0], len(clfs)*class_num))\n",
"dataset_blend_test = np.zeros((Xtest.shape[0], len(clfs)*class_num))\n",
"\n",
"for k, clf in tqdm(enumerate(clfs)):\n",
" for i in tqdm(range(0,len(folds))):\n",
" target_fold = folds[i]\n",
" inner_folds = folds[0:i]+folds[i+1:]\n",
" clf.fit(Xtrain[np.concatenate(inner_folds).ravel()],ytrain[np.concatenate(inner_folds).ravel()])\n",
" pred = clf.predict_proba(Xtrain[target_fold])\n",
" dataset_blend_train_not_nested[target_fold, k*class_num:(k*class_num+class_num)] = pred"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"아래는 Nested CV를 적용한 Stacking 구현입니다."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"0it [00:00, ?it/s]\n",
" 0%| | 0/5 [00:00<?, ?it/s]\n",
"\n",
" 0%| | 0/4 [00:00<?, ?it/s]\n",
"\n",
" 25%|█████████████████████ | 1/4 [00:00<00:00, 3.54it/s]\n",
"\n",
" 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 3.53it/s]\n",
"\n",
" 75%|███████████████████████████████████████████████████████████████ | 3/4 [00:00<00:00, 3.51it/s]\n",
"\n",
"100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 3.46it/s]\n",
"\n",
"\n",
" 20%|████████████████▊ | 1/5 [00:01<00:04, 1.16s/it]\n",
"\n",
" 0%| | 0/4 [00:00<?, ?it/s]\n",
"\n",
" 25%|█████████████████████ | 1/4 [00:00<00:00, 3.51it/s]\n",
"\n",
" 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 3.47it/s]\n",
"\n",
" 75%|███████████████████████████████████████████████████████████████ | 3/4 [00:00<00:00, 3.49it/s]\n",
"\n",
"100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 3.48it/s]\n",
"\n",
"\n",
" 40%|█████████████████████████████████▌ | 2/5 [00:02<00:03, 1.16s/it]\n",
"\n",
" 0%| | 0/4 [00:00<?, ?it/s]\n",
"\n",
" 25%|█████████████████████ | 1/4 [00:00<00:00, 3.47it/s]\n",
"\n",
" 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 3.50it/s]\n",
"\n",
" 75%|███████████████████████████████████████████████████████████████ | 3/4 [00:00<00:00, 3.52it/s]\n",
"\n",
"100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 3.52it/s]\n",
"\n",
"\n",
" 60%|██████████████████████████████████████████████████▍ | 3/5 [00:03<00:02, 1.15s/it]\n",
"\n",
" 0%| | 0/4 [00:00<?, ?it/s]\n",
"\n",
" 25%|█████████████████████ | 1/4 [00:00<00:00, 3.47it/s]\n",
"\n",
" 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 3.49it/s]\n",
"\n",
" 75%|███████████████████████████████████████████████████████████████ | 3/4 [00:00<00:00, 3.49it/s]\n",
"\n",
"100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 3.52it/s]\n",
"\n",
"\n",
" 80%|███████████████████████████████████████████████████████████████████▏ | 4/5 [00:04<00:01, 1.15s/it]\n",
"\n",
" 0%| | 0/4 [00:00<?, ?it/s]\n",
"\n",
" 25%|█████████████████████ | 1/4 [00:00<00:00, 3.62it/s]\n",
"\n",
" 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 3.61it/s]\n",
"\n",
" 75%|███████████████████████████████████████████████████████████████ | 3/4 [00:00<00:00, 3.59it/s]\n",
"\n",
"100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 3.58it/s]\n",
"\n",
"\n",
"100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:05<00:00, 1.14s/it]\n",
"1it [00:05, 5.71s/it]\n",
" 0%| | 0/5 [00:00<?, ?it/s]\n",
"\n",
" 0%| | 0/4 [00:00<?, ?it/s]\n",
"\n",
" 25%|█████████████████████ | 1/4 [00:00<00:00, 3.53it/s]\n",
"\n",
" 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 3.53it/s]\n",
"\n",
" 75%|███████████████████████████████████████████████████████████████ | 3/4 [00:00<00:00, 3.53it/s]\n",
"\n",
"100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 3.53it/s]\n",
"\n",
"\n",
" 20%|████████████████▊ | 1/5 [00:01<00:04, 1.14s/it]\n",
"\n",
" 0%| | 0/4 [00:00<?, ?it/s]\n",
"\n",
" 25%|█████████████████████ | 1/4 [00:00<00:00, 3.51it/s]\n",
"\n",
" 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 3.53it/s]\n",
"\n",
" 75%|███████████████████████████████████████████████████████████████ | 3/4 [00:00<00:00, 3.55it/s]\n",
"\n",
"100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 3.56it/s]\n",
"\n",
"\n",
" 40%|█████████████████████████████████▌ | 2/5 [00:02<00:03, 1.13s/it]\n",
"\n",
" 0%| | 0/4 [00:00<?, ?it/s]\n",
"\n",
" 25%|█████████████████████ | 1/4 [00:00<00:00, 3.51it/s]\n",
"\n",
" 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 3.53it/s]\n",
"\n",
" 75%|███████████████████████████████████████████████████████████████ | 3/4 [00:00<00:00, 3.48it/s]\n",
"\n",
"100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 3.50it/s]\n",
"\n",
"\n",
" 60%|██████████████████████████████████████████████████▍ | 3/5 [00:03<00:02, 1.14s/it]\n",
"\n",
" 0%| | 0/4 [00:00<?, ?it/s]\n",
"\n",
" 25%|█████████████████████ | 1/4 [00:00<00:00, 3.41it/s]\n",
"\n",
" 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 3.40it/s]\n",
"\n",
" 75%|███████████████████████████████████████████████████████████████ | 3/4 [00:00<00:00, 3.45it/s]\n",
"\n",
"100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 3.41it/s]\n",
"\n",
"\n",
" 80%|███████████████████████████████████████████████████████████████████▏ | 4/5 [00:04<00:01, 1.15s/it]\n",
"\n",
" 0%| | 0/4 [00:00<?, ?it/s]\n",
"\n",
" 25%|█████████████████████ | 1/4 [00:00<00:00, 3.21it/s]\n",
"\n",
" 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 3.28it/s]\n",
"\n",
" 75%|███████████████████████████████████████████████████████████████ | 3/4 [00:00<00:00, 3.27it/s]\n",
"\n",
"100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 3.34it/s]\n",
"\n",
"\n",
"100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:05<00:00, 1.16s/it]\n",
"2it [00:11, 5.74s/it]\n",
" 0%| | 0/5 [00:00<?, ?it/s]\n",
"\n",
" 0%| | 0/4 [00:00<?, ?it/s]\n",
"\n",
" 25%|█████████████████████ | 1/4 [00:00<00:00, 3.51it/s]\n",
"\n",
" 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 3.53it/s]\n",
"\n",
" 75%|███████████████████████████████████████████████████████████████ | 3/4 [00:00<00:00, 3.57it/s]\n",
"\n",
"100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 3.57it/s]\n",
"\n",
"\n",
" 20%|████████████████▊ | 1/5 [00:01<00:04, 1.12s/it]\n",
"\n",
" 0%| | 0/4 [00:00<?, ?it/s]\n",
"\n",
" 25%|█████████████████████ | 1/4 [00:00<00:00, 3.12it/s]\n",
"\n",
" 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 3.20it/s]\n",
"\n",
" 75%|███████████████████████████████████████████████████████████████ | 3/4 [00:00<00:00, 3.32it/s]\n",
"\n",
"100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 3.27it/s]\n",
"\n",
"\n",
" 40%|█████████████████████████████████▌ | 2/5 [00:02<00:03, 1.15s/it]\n",
"\n",
" 0%| | 0/4 [00:00<?, ?it/s]\n",
"\n",
" 25%|█████████████████████ | 1/4 [00:00<00:00, 3.29it/s]\n",
"\n",
" 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 3.34it/s]\n",
"\n",
" 75%|███████████████████████████████████████████████████████████████ | 3/4 [00:00<00:00, 3.40it/s]\n",
"\n",
"100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 3.25it/s]\n",
"\n",
"\n",
" 60%|██████████████████████████████████████████████████▍ | 3/5 [00:03<00:02, 1.17s/it]\n",
"\n",
" 0%| | 0/4 [00:00<?, ?it/s]\n",
"\n",
" 25%|█████████████████████ | 1/4 [00:00<00:00, 3.13it/s]\n",
"\n",
" 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 3.16it/s]\n",
"\n",
" 75%|███████████████████████████████████████████████████████████████ | 3/4 [00:00<00:00, 3.17it/s]\n",
"\n",
"100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 3.07it/s]\n",
"\n",
"\n",
" 80%|███████████████████████████████████████████████████████████████████▏ | 4/5 [00:04<00:01, 1.21s/it]\n",
"\n",
" 0%| | 0/4 [00:00<?, ?it/s]\n",
"\n",
" 25%|█████████████████████ | 1/4 [00:00<00:00, 3.33it/s]\n",
"\n",
" 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 3.35it/s]\n",
"\n",
" 75%|███████████████████████████████████████████████████████████████ | 3/4 [00:00<00:00, 3.29it/s]\n",
"\n",
"100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 3.22it/s]\n",
"\n",
"\n",
"100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:06<00:00, 1.22s/it]\n",
"3it [00:17, 5.85s/it]\n",
" 0%| | 0/5 [00:00<?, ?it/s]\n",
"\n",
" 0%| | 0/4 [00:00<?, ?it/s]\n",
"\n",
" 25%|█████████████████████ | 1/4 [00:00<00:00, 3.35it/s]\n",
"\n",
" 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 3.34it/s]\n",
"\n",
" 75%|███████████████████████████████████████████████████████████████ | 3/4 [00:00<00:00, 3.31it/s]\n",
"\n",
"100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 3.30it/s]\n",
"\n",
"\n",
" 20%|████████████████▊ | 1/5 [00:01<00:04, 1.22s/it]\n",
"\n",
" 0%| | 0/4 [00:00<?, ?it/s]\n",
"\n",
" 25%|█████████████████████ | 1/4 [00:00<00:00, 3.32it/s]\n",
"\n",
" 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 3.35it/s]\n",
"\n",
" 75%|███████████████████████████████████████████████████████████████ | 3/4 [00:00<00:00, 3.39it/s]\n",
"\n",
"100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 3.35it/s]\n",
"\n",
"\n",
" 40%|█████████████████████████████████▌ | 2/5 [00:02<00:03, 1.21s/it]\n",
"\n",
" 0%| | 0/4 [00:00<?, ?it/s]\n",
"\n",
" 25%|█████████████████████ | 1/4 [00:00<00:00, 3.63it/s]\n",
"\n",
" 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 3.46it/s]\n",
"\n",
" 75%|███████████████████████████████████████████████████████████████ | 3/4 [00:00<00:00, 3.50it/s]\n",
"\n",
"100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 3.53it/s]\n",
"\n",
"\n",
" 60%|██████████████████████████████████████████████████▍ | 3/5 [00:03<00:02, 1.19s/it]\n",
"\n",
" 0%| | 0/4 [00:00<?, ?it/s]\n",
"\n",
" 25%|█████████████████████ | 1/4 [00:00<00:00, 3.25it/s]\n",
"\n",
" 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 3.27it/s]\n",
"\n",
" 75%|███████████████████████████████████████████████████████████████ | 3/4 [00:00<00:00, 3.21it/s]\n",
"\n",
"100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 3.22it/s]\n",
"\n",
"\n",
" 80%|███████████████████████████████████████████████████████████████████▏ | 4/5 [00:04<00:01, 1.21s/it]\n",
"\n",
" 0%| | 0/4 [00:00<?, ?it/s]\n",
"\n",
" 25%|█████████████████████ | 1/4 [00:00<00:00, 3.27it/s]\n",
"\n",
" 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 3.37it/s]\n",
"\n",
" 75%|███████████████████████████████████████████████████████████████ | 3/4 [00:00<00:00, 3.39it/s]\n",
"\n",
"100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 3.38it/s]\n",
"\n",
"\n",
"100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:05<00:00, 1.20s/it]\n",
"4it [00:23, 5.89s/it]\n",
" 0%| | 0/5 [00:00<?, ?it/s]\n",
"\n",
" 0%| | 0/4 [00:00<?, ?it/s]\n",
"\n",
" 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 11.69it/s]\n",
"\n",
"100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 11.67it/s]\n",
"\n",
"\n",
" 20%|████████████████▊ | 1/5 [00:00<00:01, 2.88it/s]\n",
"\n",
" 0%| | 0/4 [00:00<?, ?it/s]\n",
"\n",
" 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 9.34it/s]\n",
"\n",
" 75%|███████████████████████████████████████████████████████████████ | 3/4 [00:00<00:00, 9.31it/s]\n",
"\n",
"100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 9.27it/s]\n",
"\n",
"\n",
" 40%|█████████████████████████████████▌ | 2/5 [00:00<00:01, 2.67it/s]\n",
"\n",
" 0%| | 0/4 [00:00<?, ?it/s]\n",
"\n",
" 25%|█████████████████████ | 1/4 [00:00<00:00, 9.99it/s]\n",
"\n",
" 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 9.65it/s]\n",
"\n",
" 75%|███████████████████████████████████████████████████████████████ | 3/4 [00:00<00:00, 9.52it/s]\n",
"\n",
"100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 9.68it/s]\n",
" 60%|██████████████████████████████████████████████████▍ | 3/5 [00:01<00:00, 2.59it/s]\n",
"\n",
" 0%| | 0/4 [00:00<?, ?it/s]\n",
"\n",
" 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 12.04it/s]\n",
"\n",
"100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 11.70it/s]\n",
"\n",
"\n",
" 80%|███████████████████████████████████████████████████████████████████▏ | 4/5 [00:01<00:00, 2.66it/s]\n",
"\n",
" 0%| | 0/4 [00:00<?, ?it/s]\n",
"\n",
" 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 13.50it/s]\n",
"\n",
"100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 13.42it/s]\n",
"\n",
"\n",
"100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:01<00:00, 2.82it/s]\n",
"5it [00:25, 4.68s/it]\n",
"0it [00:00, ?it/s]\n",
" 0%| | 0/5 [00:00<?, ?it/s]\n",
" 20%|████████████████▊ | 1/5 [00:00<00:01, 3.30it/s]\n",
" 40%|█████████████████████████████████▌ | 2/5 [00:00<00:00, 3.35it/s]\n",
" 60%|██████████████████████████████████████████████████▍ | 3/5 [00:00<00:00, 3.36it/s]\n",
" 80%|███████████████████████████████████████████████████████████████████▏ | 4/5 [00:01<00:00, 3.34it/s]\n",
"100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:01<00:00, 3.30it/s]\n",
"1it [00:01, 1.51s/it]\n",
" 0%| | 0/5 [00:00<?, ?it/s]\n",
" 20%|████████████████▊ | 1/5 [00:00<00:01, 3.19it/s]\n",
" 40%|█████████████████████████████████▌ | 2/5 [00:00<00:00, 3.16it/s]\n",
" 60%|██████████████████████████████████████████████████▍ | 3/5 [00:00<00:00, 3.20it/s]\n",
" 80%|███████████████████████████████████████████████████████████████████▏ | 4/5 [00:01<00:00, 3.24it/s]\n",
"100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:01<00:00, 3.33it/s]\n",
"2it [00:03, 1.51s/it]\n",
" 0%| | 0/5 [00:00<?, ?it/s]\n",
" 20%|████████████████▊ | 1/5 [00:00<00:01, 3.23it/s]\n",
" 40%|█████████████████████████████████▌ | 2/5 [00:00<00:00, 3.19it/s]\n",
" 60%|██████████████████████████████████████████████████▍ | 3/5 [00:00<00:00, 3.22it/s]\n",
" 80%|███████████████████████████████████████████████████████████████████▏ | 4/5 [00:01<00:00, 3.32it/s]\n",
"100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:01<00:00, 3.39it/s]\n",
"3it [00:04, 1.51s/it]\n",
" 0%| | 0/5 [00:00<?, ?it/s]\n",
" 20%|████████████████▊ | 1/5 [00:00<00:01, 3.45it/s]\n",
" 40%|█████████████████████████████████▌ | 2/5 [00:00<00:00, 3.46it/s]\n",
" 60%|██████████████████████████████████████████████████▍ | 3/5 [00:00<00:00, 3.51it/s]\n",
" 80%|███████████████████████████████████████████████████████████████████▏ | 4/5 [00:01<00:00, 3.54it/s]\n",
"100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:01<00:00, 3.54it/s]\n",
"4it [00:05, 1.48s/it]\n",
" 0%| | 0/5 [00:00<?, ?it/s]\n",
" 40%|█████████████████████████████████▌ | 2/5 [00:00<00:00, 12.65it/s]\n",
" 80%|███████████████████████████████████████████████████████████████████▏ | 4/5 [00:00<00:00, 12.58it/s]\n",
"5it [00:06, 1.16s/it]\n"
]
}
],
"source": [
"# nested\n",
"class_num = len(np.unique(ytrain))\n",
"\n",
"dataset_blend_train_nested = np.zeros((Xtrain.shape[0], len(clfs)*class_num))\n",
"\n",
"dataset_inner_train = np.zeros((Xtrain.shape[0], len(clfs)*class_num))\n",
"dataset_inner_test = np.zeros((Xtest.shape[0], len(clfs)*class_num))\n",
"\n",
"for k, clf in tqdm(enumerate(clfs)):\n",
" for i in tqdm(range(0,len(folds))):\n",
" # outer CV 부분\n",
" target_fold = folds[i]\n",
" inner_folds = folds[0:i]+folds[i+1:]\n",
" for j in tqdm(range(0,len(inner_folds))):\n",
" # inner CV 부분\n",
" inner_target_fold = inner_folds[j]\n",
" inner_train_fold = np.concatenate(inner_folds[0:j]+inner_folds[j+1:]).ravel()\n",
" clf.fit(Xtrain[inner_train_fold],ytrain[inner_train_fold])\n",
" inner_pred = clf.predict_proba(Xtrain[inner_target_fold])\n",
" dataset_inner_train[inner_target_fold,k*class_num:(k*class_num+class_num)] = inner_pred\n",
"\n",
"for k, clf in tqdm(enumerate(clfs)):\n",
" for i in tqdm(range(0,len(folds))):\n",
" target_fold = folds[i]\n",
" inner_folds = folds[0:i]+folds[i+1:]\n",
" # 학습에 사용하는 feature 가 innerCV 에서 만들어진 prediction 을 사용해 학습이 됨\n",
" clf.fit(dataset_inner_train[np.concatenate(inner_folds).ravel()],ytrain[np.concatenate(inner_folds).ravel()])\n",
" pred = clf.predict_proba(dataset_inner_train[target_fold])\n",
" dataset_blend_train_nested[target_fold, k*class_num:(k*class_num+class_num)] = pred"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"Testset 에 대해 Prediction을 만드는 과정입니다."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"for k, clf in enumerate(clfs):\n",
" clf.fit(Xtrain, ytrain)\n",
" pred = clf.predict_proba(Xtest)\n",
" dataset_blend_test[:, k*class_num:(k*class_num+class_num)] = pred"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"stacker로는 LogisticRegression 을 사용합니다."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from sklearn.linear_model import LogisticRegressionCV\n",
"from pandas_confusion import ConfusionMatrix\n",
"from sklearn.metrics import accuracy_score"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Nested CV 를 하지 않은 Stacking의 결과입니다."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Predicted 0 1 2 __all__\n",
"Actual \n",
"0 43 0 0 43\n",
"1 0 40 3 43\n",
"2 0 3 41 44\n",
"__all__ 43 43 44 130\n",
"\n",
"정확도 : 0.953846153846\n"
]
}
],
"source": [
"stacker = LogisticRegressionCV(refit=False)\n",
"stacker.fit(dataset_blend_train_not_nested,ytrain)\n",
"print(ConfusionMatrix(ytest,stacker.predict(dataset_blend_test)))\n",
"print()\n",
"print(\"정확도 : \", accuracy_score(ytest,stacker.predict(dataset_blend_test)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Nested CV 를 한 Stacking의 결과입니다."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Predicted 0 1 2 __all__\n",
"Actual \n",
"0 43 0 0 43\n",
"1 0 41 2 43\n",
"2 0 3 41 44\n",
"__all__ 43 44 43 130\n",
"\n",
"정확도 : 0.961538461538\n"
]
}
],
"source": [
"stacker = LogisticRegressionCV(refit=False)\n",
"stacker.fit(dataset_blend_train_nested,ytrain)\n",
"print(ConfusionMatrix(ytest,stacker.predict(dataset_blend_test)))\n",
"print()\n",
"print(\"정확도 : \", accuracy_score(ytest,stacker.predict(dataset_blend_test)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Stacking을 하지 않은 LogisticRegression 의 결과입니다."
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Predicted 0 1 2 __all__\n",
"Actual \n",
"0 43 0 0 43\n",
"1 2 37 4 43\n",
"2 0 0 44 44\n",
"__all__ 45 37 48 130\n",
"\n",
"정확도 : 0.953846153846\n"
]
}
],
"source": [
"lr = LogisticRegressionCV(refit=False)\n",
"lr.fit(Xtrain,ytrain)\n",
"accuracy_score(ytest,lr.predict(Xtest))\n",
"print(ConfusionMatrix(ytest,lr.predict(Xtest)))\n",
"print()\n",
"print(\"정확도 : \", accuracy_score(ytest,lr.predict(Xtest)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Conclusion\n",
"\n",
"Testset 을 Holdout Set 으로 두었을 때의 정확도를 비교해보면,\n",
"Nested CV 가 좀 더 일반화를 잘 하고 있는 것을 알 수 있습니다.\n",
"\n",
"이 얘기는 Nested CV 하지 않은 Stacking이 좀 더 Trainset 에 Overfitting 한다고 볼 수 있습니다.\n",
"\n",
"예시를 위해 1 Stage 만을 해봤습니다."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python [conda root]",
"language": "python",
"name": "conda-root-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment