Created
August 27, 2017 11:19
-
-
Save ishuca/cb762106cbc48ebed5a8f9ed8404f69c to your computer and use it in GitHub Desktop.
Why we use Stacking with Nested CV
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Nested CV(Cross Validation) Stacking" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### 이 튜토리얼은 아래 링크 20 슬라이드부터 23 슬라이드의 있는 \n", | |
| "### 2 stage 이상의 Stacking 에서의 Information leakage (정보 노출) 에 대해 논의하고 있습니다.\n", | |
| "https://www.slideshare.net/odsc/owen-zhangopen-sourcetoolsanddscompetitions1" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "import pandas as pd\n", | |
| "import numpy as np\n", | |
| "import matplotlib.pyplot as plt\n", | |
| "import seaborn as sns\n", | |
| "%matplotlib inline" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "예제 데이터로는 iris 데이터를 사용하겠습니다." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "from sklearn.datasets import load_iris" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "tqdm 은 Progress bar를 볼 수 있는 유용한 라이브러리 입니다." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "from tqdm import tqdm" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "데이터를 불러옵니다." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "dict_keys(['target', 'DESCR', 'feature_names', 'target_names', 'data'])" | |
| ] | |
| }, | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "data = load_iris()\n", | |
| "data.keys()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "타겟과 특성을 나눠줍니다." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": { | |
| "collapsed": false, | |
| "scrolled": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "target = data['target']" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "features = data['data']" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "iris 데이터는 3가지 붓꽃의 분류를 위한 데이터로써, Trainset 과 Testset 을 나눌 때 3 클래스의 비율이 비슷하도록 층화추출 하도록 하겠습니다." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "from sklearn.model_selection import StratifiedKFold, train_test_split, cross_val_score" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "재현성을 위해 random_state를 고정시켰습니다.\n", | |
| "\n", | |
| "iris 데이터는 총 150 개 관찰치로써, 이 중 20개의 관찰치를 학습에 사용하고, 130개 데이터에 테스트 해봄으로써\n", | |
| "\n", | |
| "일반화 성능이 높은 Stacking의 효과를 좀 더 명확히 보여줍니다." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "Xtrain, Xtest, ytrain, ytest = train_test_split(features, target, stratify=target, random_state=1234, test_size = 130)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=1234)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "일반적인 CV 는 StratifiedKFold 를 이용해 split 한 후 여기서 나오는 각각의 Fold Index 를 이용해\n", | |
| "\n", | |
| "Trainset 으로 학습시킨 후, Testset 에 적용을 하게 됩니다.\n", | |
| "\n", | |
| "다만 여기서는 아래와 같은 방법을 취하기 위해 각 폴드를 따로 따로 기억하겠습니다." | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "<img src=\"https://image.slidesharecdn.com/owenzhangopensourcetoolsanddscompetitions1-150604182708-lva1-app6891/95/open-source-tools-data-science-competitions-20-638.jpg?cb=1433444974\">\n", | |
| "### Stacking은 위와 같이 Base Model 들의 Prediction 을 Meta Feature 로 생각하고, 그 예측값들 사이의 최적 결합을 학습하는 또다른 모델 (Stacker) 를 두는 것입니다.\n", | |
| "\n", | |
| "<img src=\"https://image.slidesharecdn.com/owenzhangopensourcetoolsanddscompetitions1-150604182708-lva1-app6891/95/open-source-tools-data-science-competitions-21-638.jpg?cb=1433444974\">\n", | |
| "### Overfitting 을 막기 위해서는 Out of sample 을 사용해 학습한 Prediction 을 사용해야 하는데 일반적으로 CV는 이를 만족합니다.\n", | |
| "\n", | |
| "<img src=\"https://image.slidesharecdn.com/owenzhangopensourcetoolsanddscompetitions1-150604182708-lva1-app6891/95/open-source-tools-data-science-competitions-22-638.jpg?cb=1433444974\">\n", | |
| "### 하지만 최근 추세는 2 stage 이상의 모델들이 점점 많아지고 있습니다.\n", | |
| "\n", | |
| "<img src=\"https://image.slidesharecdn.com/owenzhangopensourcetoolsanddscompetitions1-150604182708-lva1-app6891/95/open-source-tools-data-science-competitions-23-638.jpg?cb=1433444974\">\n", | |
| "\n", | |
| "### 2 stage 이상 Stacking에서 위와 같이 일반적인 CV로 Stacking 구현을 하게 되면 아래와 같은 Information Leakage 가 발생하게 됩니다.\n", | |
| "\n", | |
| "<img src=\"https://image.slidesharecdn.com/owenzhangopensourcetoolsanddscompetitions1-150604182708-lva1-app6891/95/open-source-tools-data-science-competitions-24-638.jpg?cb=1433444974\">\n", | |
| "\n", | |
| "### 4번째 fold를 예측하기 위해 1,2,3 folds 가 사용되었는데 이 1번째 fold 를 예측하기 위해 2,3,4 folds 를 사용해 학습했던 것이 2 stage 에서 4 fold 를 사용할 때 간접적인 정보 누출 (indirected information leakage)가 발생한다고 합니다.\n", | |
| "\n", | |
| "### 그래서 예측 값을 만들기 위해 사용하는 fold에는 다음 예측 fold가 사용되지 않도록 학습하는 Nested CV 가 필요하게 됩니다." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "[array([ 0, 2, 4, 6, 8, 14]),\n", | |
| " array([ 3, 10, 11, 13, 17]),\n", | |
| " array([ 7, 12, 19]),\n", | |
| " array([ 5, 16, 18]),\n", | |
| " array([ 1, 9, 15])]" | |
| ] | |
| }, | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "folds = [x[1] for x in list(skf.split(Xtrain,ytrain))]\n", | |
| "folds" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Stacking 에 사용할 Base Learner 로는 Randomforest, ExtraTrees, GradientBoosting 을 사용합니다" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier, GradientBoostingClassifier" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "clfs = [RandomForestClassifier(n_estimators=100, n_jobs=-1,\n", | |
| "criterion='gini'),\n", | |
| " RandomForestClassifier(n_estimators=100, n_jobs=-1,\n", | |
| "criterion='entropy'),\n", | |
| " ExtraTreesClassifier(n_estimators=100, n_jobs=-1,\n", | |
| "criterion='gini'),\n", | |
| " ExtraTreesClassifier(n_estimators=100, n_jobs=-1,\n", | |
| "criterion='entropy'),\n", | |
| " GradientBoostingClassifier(learning_rate=0.05,\n", | |
| "subsample=0.5, max_depth=6, n_estimators=50)]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "아래의 주석처리 된 코드는 Nested CV를 하지 않는 일반적인 Stacking 구현입니다." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "0it [00:00, ?it/s]\n", | |
| " 0%| | 0/5 [00:00<?, ?it/s]\n", | |
| " 20%|████████████████▊ | 1/5 [00:00<00:01, 3.22it/s]\n", | |
| " 40%|█████████████████████████████████▌ | 2/5 [00:00<00:00, 3.21it/s]\n", | |
| " 60%|██████████████████████████████████████████████████▍ | 3/5 [00:00<00:00, 3.14it/s]\n", | |
| " 80%|███████████████████████████████████████████████████████████████████▏ | 4/5 [00:01<00:00, 3.09it/s]\n", | |
| "100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:01<00:00, 3.11it/s]\n", | |
| "1it [00:01, 1.62s/it]\n", | |
| " 0%| | 0/5 [00:00<?, ?it/s]\n", | |
| " 20%|████████████████▊ | 1/5 [00:00<00:01, 3.09it/s]\n", | |
| " 40%|█████████████████████████████████▌ | 2/5 [00:00<00:00, 3.20it/s]\n", | |
| " 60%|██████████████████████████████████████████████████▍ | 3/5 [00:00<00:00, 3.22it/s]\n", | |
| " 80%|███████████████████████████████████████████████████████████████████▏ | 4/5 [00:01<00:00, 3.31it/s]\n", | |
| "100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:01<00:00, 3.36it/s]\n", | |
| "2it [00:03, 1.58s/it]\n", | |
| " 0%| | 0/5 [00:00<?, ?it/s]\n", | |
| " 20%|████████████████▊ | 1/5 [00:00<00:01, 3.57it/s]\n", | |
| " 40%|█████████████████████████████████▌ | 2/5 [00:00<00:00, 3.58it/s]\n", | |
| " 60%|██████████████████████████████████████████████████▍ | 3/5 [00:00<00:00, 3.58it/s]\n", | |
| " 80%|███████████████████████████████████████████████████████████████████▏ | 4/5 [00:01<00:00, 3.56it/s]\n", | |
| "100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:01<00:00, 3.55it/s]\n", | |
| "3it [00:04, 1.53s/it]\n", | |
| " 0%| | 0/5 [00:00<?, ?it/s]\n", | |
| " 20%|████████████████▊ | 1/5 [00:00<00:01, 3.36it/s]\n", | |
| " 40%|█████████████████████████████████▌ | 2/5 [00:00<00:00, 3.41it/s]\n", | |
| " 60%|██████████████████████████████████████████████████▍ | 3/5 [00:00<00:00, 3.47it/s]\n", | |
| " 80%|███████████████████████████████████████████████████████████████████▏ | 4/5 [00:01<00:00, 3.50it/s]\n", | |
| "100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:01<00:00, 3.53it/s]\n", | |
| "4it [00:05, 1.50s/it]\n", | |
| " 0%| | 0/5 [00:00<?, ?it/s]\n", | |
| " 40%|█████████████████████████████████▌ | 2/5 [00:00<00:00, 15.14it/s]\n", | |
| " 80%|███████████████████████████████████████████████████████████████████▏ | 4/5 [00:00<00:00, 15.07it/s]\n", | |
| "5it [00:06, 1.15s/it]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# not nested\n", | |
| "\n", | |
| "class_num = len(np.unique(ytrain))\n", | |
| "\n", | |
| "dataset_blend_train_not_nested = np.zeros((Xtrain.shape[0], len(clfs)*class_num))\n", | |
| "dataset_blend_test = np.zeros((Xtest.shape[0], len(clfs)*class_num))\n", | |
| "\n", | |
| "for k, clf in tqdm(enumerate(clfs)):\n", | |
| " for i in tqdm(range(0,len(folds))):\n", | |
| " target_fold = folds[i]\n", | |
| " inner_folds = folds[0:i]+folds[i+1:]\n", | |
| " clf.fit(Xtrain[np.concatenate(inner_folds).ravel()],ytrain[np.concatenate(inner_folds).ravel()])\n", | |
| " pred = clf.predict_proba(Xtrain[target_fold])\n", | |
| " dataset_blend_train_not_nested[target_fold, k*class_num:(k*class_num+class_num)] = pred" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "아래는 Nested CV를 적용한 Stacking 구현입니다." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 14, | |
| "metadata": { | |
| "collapsed": false, | |
| "scrolled": true | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "0it [00:00, ?it/s]\n", | |
| " 0%| | 0/5 [00:00<?, ?it/s]\n", | |
| "\n", | |
| " 0%| | 0/4 [00:00<?, ?it/s]\n", | |
| "\n", | |
| " 25%|█████████████████████ | 1/4 [00:00<00:00, 3.54it/s]\n", | |
| "\n", | |
| " 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 3.53it/s]\n", | |
| "\n", | |
| " 75%|███████████████████████████████████████████████████████████████ | 3/4 [00:00<00:00, 3.51it/s]\n", | |
| "\n", | |
| "100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 3.46it/s]\n", | |
| "\n", | |
| "\n", | |
| " 20%|████████████████▊ | 1/5 [00:01<00:04, 1.16s/it]\n", | |
| "\n", | |
| " 0%| | 0/4 [00:00<?, ?it/s]\n", | |
| "\n", | |
| " 25%|█████████████████████ | 1/4 [00:00<00:00, 3.51it/s]\n", | |
| "\n", | |
| " 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 3.47it/s]\n", | |
| "\n", | |
| " 75%|███████████████████████████████████████████████████████████████ | 3/4 [00:00<00:00, 3.49it/s]\n", | |
| "\n", | |
| "100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 3.48it/s]\n", | |
| "\n", | |
| "\n", | |
| " 40%|█████████████████████████████████▌ | 2/5 [00:02<00:03, 1.16s/it]\n", | |
| "\n", | |
| " 0%| | 0/4 [00:00<?, ?it/s]\n", | |
| "\n", | |
| " 25%|█████████████████████ | 1/4 [00:00<00:00, 3.47it/s]\n", | |
| "\n", | |
| " 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 3.50it/s]\n", | |
| "\n", | |
| " 75%|███████████████████████████████████████████████████████████████ | 3/4 [00:00<00:00, 3.52it/s]\n", | |
| "\n", | |
| "100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 3.52it/s]\n", | |
| "\n", | |
| "\n", | |
| " 60%|██████████████████████████████████████████████████▍ | 3/5 [00:03<00:02, 1.15s/it]\n", | |
| "\n", | |
| " 0%| | 0/4 [00:00<?, ?it/s]\n", | |
| "\n", | |
| " 25%|█████████████████████ | 1/4 [00:00<00:00, 3.47it/s]\n", | |
| "\n", | |
| " 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 3.49it/s]\n", | |
| "\n", | |
| " 75%|███████████████████████████████████████████████████████████████ | 3/4 [00:00<00:00, 3.49it/s]\n", | |
| "\n", | |
| "100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 3.52it/s]\n", | |
| "\n", | |
| "\n", | |
| " 80%|███████████████████████████████████████████████████████████████████▏ | 4/5 [00:04<00:01, 1.15s/it]\n", | |
| "\n", | |
| " 0%| | 0/4 [00:00<?, ?it/s]\n", | |
| "\n", | |
| " 25%|█████████████████████ | 1/4 [00:00<00:00, 3.62it/s]\n", | |
| "\n", | |
| " 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 3.61it/s]\n", | |
| "\n", | |
| " 75%|███████████████████████████████████████████████████████████████ | 3/4 [00:00<00:00, 3.59it/s]\n", | |
| "\n", | |
| "100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 3.58it/s]\n", | |
| "\n", | |
| "\n", | |
| "100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:05<00:00, 1.14s/it]\n", | |
| "1it [00:05, 5.71s/it]\n", | |
| " 0%| | 0/5 [00:00<?, ?it/s]\n", | |
| "\n", | |
| " 0%| | 0/4 [00:00<?, ?it/s]\n", | |
| "\n", | |
| " 25%|█████████████████████ | 1/4 [00:00<00:00, 3.53it/s]\n", | |
| "\n", | |
| " 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 3.53it/s]\n", | |
| "\n", | |
| " 75%|███████████████████████████████████████████████████████████████ | 3/4 [00:00<00:00, 3.53it/s]\n", | |
| "\n", | |
| "100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 3.53it/s]\n", | |
| "\n", | |
| "\n", | |
| " 20%|████████████████▊ | 1/5 [00:01<00:04, 1.14s/it]\n", | |
| "\n", | |
| " 0%| | 0/4 [00:00<?, ?it/s]\n", | |
| "\n", | |
| " 25%|█████████████████████ | 1/4 [00:00<00:00, 3.51it/s]\n", | |
| "\n", | |
| " 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 3.53it/s]\n", | |
| "\n", | |
| " 75%|███████████████████████████████████████████████████████████████ | 3/4 [00:00<00:00, 3.55it/s]\n", | |
| "\n", | |
| "100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 3.56it/s]\n", | |
| "\n", | |
| "\n", | |
| " 40%|█████████████████████████████████▌ | 2/5 [00:02<00:03, 1.13s/it]\n", | |
| "\n", | |
| " 0%| | 0/4 [00:00<?, ?it/s]\n", | |
| "\n", | |
| " 25%|█████████████████████ | 1/4 [00:00<00:00, 3.51it/s]\n", | |
| "\n", | |
| " 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 3.53it/s]\n", | |
| "\n", | |
| " 75%|███████████████████████████████████████████████████████████████ | 3/4 [00:00<00:00, 3.48it/s]\n", | |
| "\n", | |
| "100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 3.50it/s]\n", | |
| "\n", | |
| "\n", | |
| " 60%|██████████████████████████████████████████████████▍ | 3/5 [00:03<00:02, 1.14s/it]\n", | |
| "\n", | |
| " 0%| | 0/4 [00:00<?, ?it/s]\n", | |
| "\n", | |
| " 25%|█████████████████████ | 1/4 [00:00<00:00, 3.41it/s]\n", | |
| "\n", | |
| " 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 3.40it/s]\n", | |
| "\n", | |
| " 75%|███████████████████████████████████████████████████████████████ | 3/4 [00:00<00:00, 3.45it/s]\n", | |
| "\n", | |
| "100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 3.41it/s]\n", | |
| "\n", | |
| "\n", | |
| " 80%|███████████████████████████████████████████████████████████████████▏ | 4/5 [00:04<00:01, 1.15s/it]\n", | |
| "\n", | |
| " 0%| | 0/4 [00:00<?, ?it/s]\n", | |
| "\n", | |
| " 25%|█████████████████████ | 1/4 [00:00<00:00, 3.21it/s]\n", | |
| "\n", | |
| " 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 3.28it/s]\n", | |
| "\n", | |
| " 75%|███████████████████████████████████████████████████████████████ | 3/4 [00:00<00:00, 3.27it/s]\n", | |
| "\n", | |
| "100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 3.34it/s]\n", | |
| "\n", | |
| "\n", | |
| "100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:05<00:00, 1.16s/it]\n", | |
| "2it [00:11, 5.74s/it]\n", | |
| " 0%| | 0/5 [00:00<?, ?it/s]\n", | |
| "\n", | |
| " 0%| | 0/4 [00:00<?, ?it/s]\n", | |
| "\n", | |
| " 25%|█████████████████████ | 1/4 [00:00<00:00, 3.51it/s]\n", | |
| "\n", | |
| " 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 3.53it/s]\n", | |
| "\n", | |
| " 75%|███████████████████████████████████████████████████████████████ | 3/4 [00:00<00:00, 3.57it/s]\n", | |
| "\n", | |
| "100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 3.57it/s]\n", | |
| "\n", | |
| "\n", | |
| " 20%|████████████████▊ | 1/5 [00:01<00:04, 1.12s/it]\n", | |
| "\n", | |
| " 0%| | 0/4 [00:00<?, ?it/s]\n", | |
| "\n", | |
| " 25%|█████████████████████ | 1/4 [00:00<00:00, 3.12it/s]\n", | |
| "\n", | |
| " 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 3.20it/s]\n", | |
| "\n", | |
| " 75%|███████████████████████████████████████████████████████████████ | 3/4 [00:00<00:00, 3.32it/s]\n", | |
| "\n", | |
| "100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 3.27it/s]\n", | |
| "\n", | |
| "\n", | |
| " 40%|█████████████████████████████████▌ | 2/5 [00:02<00:03, 1.15s/it]\n", | |
| "\n", | |
| " 0%| | 0/4 [00:00<?, ?it/s]\n", | |
| "\n", | |
| " 25%|█████████████████████ | 1/4 [00:00<00:00, 3.29it/s]\n", | |
| "\n", | |
| " 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 3.34it/s]\n", | |
| "\n", | |
| " 75%|███████████████████████████████████████████████████████████████ | 3/4 [00:00<00:00, 3.40it/s]\n", | |
| "\n", | |
| "100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 3.25it/s]\n", | |
| "\n", | |
| "\n", | |
| " 60%|██████████████████████████████████████████████████▍ | 3/5 [00:03<00:02, 1.17s/it]\n", | |
| "\n", | |
| " 0%| | 0/4 [00:00<?, ?it/s]\n", | |
| "\n", | |
| " 25%|█████████████████████ | 1/4 [00:00<00:00, 3.13it/s]\n", | |
| "\n", | |
| " 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 3.16it/s]\n", | |
| "\n", | |
| " 75%|███████████████████████████████████████████████████████████████ | 3/4 [00:00<00:00, 3.17it/s]\n", | |
| "\n", | |
| "100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 3.07it/s]\n", | |
| "\n", | |
| "\n", | |
| " 80%|███████████████████████████████████████████████████████████████████▏ | 4/5 [00:04<00:01, 1.21s/it]\n", | |
| "\n", | |
| " 0%| | 0/4 [00:00<?, ?it/s]\n", | |
| "\n", | |
| " 25%|█████████████████████ | 1/4 [00:00<00:00, 3.33it/s]\n", | |
| "\n", | |
| " 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 3.35it/s]\n", | |
| "\n", | |
| " 75%|███████████████████████████████████████████████████████████████ | 3/4 [00:00<00:00, 3.29it/s]\n", | |
| "\n", | |
| "100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 3.22it/s]\n", | |
| "\n", | |
| "\n", | |
| "100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:06<00:00, 1.22s/it]\n", | |
| "3it [00:17, 5.85s/it]\n", | |
| " 0%| | 0/5 [00:00<?, ?it/s]\n", | |
| "\n", | |
| " 0%| | 0/4 [00:00<?, ?it/s]\n", | |
| "\n", | |
| " 25%|█████████████████████ | 1/4 [00:00<00:00, 3.35it/s]\n", | |
| "\n", | |
| " 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 3.34it/s]\n", | |
| "\n", | |
| " 75%|███████████████████████████████████████████████████████████████ | 3/4 [00:00<00:00, 3.31it/s]\n", | |
| "\n", | |
| "100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 3.30it/s]\n", | |
| "\n", | |
| "\n", | |
| " 20%|████████████████▊ | 1/5 [00:01<00:04, 1.22s/it]\n", | |
| "\n", | |
| " 0%| | 0/4 [00:00<?, ?it/s]\n", | |
| "\n", | |
| " 25%|█████████████████████ | 1/4 [00:00<00:00, 3.32it/s]\n", | |
| "\n", | |
| " 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 3.35it/s]\n", | |
| "\n", | |
| " 75%|███████████████████████████████████████████████████████████████ | 3/4 [00:00<00:00, 3.39it/s]\n", | |
| "\n", | |
| "100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 3.35it/s]\n", | |
| "\n", | |
| "\n", | |
| " 40%|█████████████████████████████████▌ | 2/5 [00:02<00:03, 1.21s/it]\n", | |
| "\n", | |
| " 0%| | 0/4 [00:00<?, ?it/s]\n", | |
| "\n", | |
| " 25%|█████████████████████ | 1/4 [00:00<00:00, 3.63it/s]\n", | |
| "\n", | |
| " 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 3.46it/s]\n", | |
| "\n", | |
| " 75%|███████████████████████████████████████████████████████████████ | 3/4 [00:00<00:00, 3.50it/s]\n", | |
| "\n", | |
| "100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 3.53it/s]\n", | |
| "\n", | |
| "\n", | |
| " 60%|██████████████████████████████████████████████████▍ | 3/5 [00:03<00:02, 1.19s/it]\n", | |
| "\n", | |
| " 0%| | 0/4 [00:00<?, ?it/s]\n", | |
| "\n", | |
| " 25%|█████████████████████ | 1/4 [00:00<00:00, 3.25it/s]\n", | |
| "\n", | |
| " 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 3.27it/s]\n", | |
| "\n", | |
| " 75%|███████████████████████████████████████████████████████████████ | 3/4 [00:00<00:00, 3.21it/s]\n", | |
| "\n", | |
| "100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 3.22it/s]\n", | |
| "\n", | |
| "\n", | |
| " 80%|███████████████████████████████████████████████████████████████████▏ | 4/5 [00:04<00:01, 1.21s/it]\n", | |
| "\n", | |
| " 0%| | 0/4 [00:00<?, ?it/s]\n", | |
| "\n", | |
| " 25%|█████████████████████ | 1/4 [00:00<00:00, 3.27it/s]\n", | |
| "\n", | |
| " 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 3.37it/s]\n", | |
| "\n", | |
| " 75%|███████████████████████████████████████████████████████████████ | 3/4 [00:00<00:00, 3.39it/s]\n", | |
| "\n", | |
| "100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 3.38it/s]\n", | |
| "\n", | |
| "\n", | |
| "100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:05<00:00, 1.20s/it]\n", | |
| "4it [00:23, 5.89s/it]\n", | |
| " 0%| | 0/5 [00:00<?, ?it/s]\n", | |
| "\n", | |
| " 0%| | 0/4 [00:00<?, ?it/s]\n", | |
| "\n", | |
| " 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 11.69it/s]\n", | |
| "\n", | |
| "100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 11.67it/s]\n", | |
| "\n", | |
| "\n", | |
| " 20%|████████████████▊ | 1/5 [00:00<00:01, 2.88it/s]\n", | |
| "\n", | |
| " 0%| | 0/4 [00:00<?, ?it/s]\n", | |
| "\n", | |
| " 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 9.34it/s]\n", | |
| "\n", | |
| " 75%|███████████████████████████████████████████████████████████████ | 3/4 [00:00<00:00, 9.31it/s]\n", | |
| "\n", | |
| "100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 9.27it/s]\n", | |
| "\n", | |
| "\n", | |
| " 40%|█████████████████████████████████▌ | 2/5 [00:00<00:01, 2.67it/s]\n", | |
| "\n", | |
| " 0%| | 0/4 [00:00<?, ?it/s]\n", | |
| "\n", | |
| " 25%|█████████████████████ | 1/4 [00:00<00:00, 9.99it/s]\n", | |
| "\n", | |
| " 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 9.65it/s]\n", | |
| "\n", | |
| " 75%|███████████████████████████████████████████████████████████████ | 3/4 [00:00<00:00, 9.52it/s]\n", | |
| "\n", | |
| "100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 9.68it/s]\n", | |
| " 60%|██████████████████████████████████████████████████▍ | 3/5 [00:01<00:00, 2.59it/s]\n", | |
| "\n", | |
| " 0%| | 0/4 [00:00<?, ?it/s]\n", | |
| "\n", | |
| " 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 12.04it/s]\n", | |
| "\n", | |
| "100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 11.70it/s]\n", | |
| "\n", | |
| "\n", | |
| " 80%|███████████████████████████████████████████████████████████████████▏ | 4/5 [00:01<00:00, 2.66it/s]\n", | |
| "\n", | |
| " 0%| | 0/4 [00:00<?, ?it/s]\n", | |
| "\n", | |
| " 50%|██████████████████████████████████████████ | 2/4 [00:00<00:00, 13.50it/s]\n", | |
| "\n", | |
| "100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 13.42it/s]\n", | |
| "\n", | |
| "\n", | |
| "100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:01<00:00, 2.82it/s]\n", | |
| "5it [00:25, 4.68s/it]\n", | |
| "0it [00:00, ?it/s]\n", | |
| " 0%| | 0/5 [00:00<?, ?it/s]\n", | |
| " 20%|████████████████▊ | 1/5 [00:00<00:01, 3.30it/s]\n", | |
| " 40%|█████████████████████████████████▌ | 2/5 [00:00<00:00, 3.35it/s]\n", | |
| " 60%|██████████████████████████████████████████████████▍ | 3/5 [00:00<00:00, 3.36it/s]\n", | |
| " 80%|███████████████████████████████████████████████████████████████████▏ | 4/5 [00:01<00:00, 3.34it/s]\n", | |
| "100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:01<00:00, 3.30it/s]\n", | |
| "1it [00:01, 1.51s/it]\n", | |
| " 0%| | 0/5 [00:00<?, ?it/s]\n", | |
| " 20%|████████████████▊ | 1/5 [00:00<00:01, 3.19it/s]\n", | |
| " 40%|█████████████████████████████████▌ | 2/5 [00:00<00:00, 3.16it/s]\n", | |
| " 60%|██████████████████████████████████████████████████▍ | 3/5 [00:00<00:00, 3.20it/s]\n", | |
| " 80%|███████████████████████████████████████████████████████████████████▏ | 4/5 [00:01<00:00, 3.24it/s]\n", | |
| "100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:01<00:00, 3.33it/s]\n", | |
| "2it [00:03, 1.51s/it]\n", | |
| " 0%| | 0/5 [00:00<?, ?it/s]\n", | |
| " 20%|████████████████▊ | 1/5 [00:00<00:01, 3.23it/s]\n", | |
| " 40%|█████████████████████████████████▌ | 2/5 [00:00<00:00, 3.19it/s]\n", | |
| " 60%|██████████████████████████████████████████████████▍ | 3/5 [00:00<00:00, 3.22it/s]\n", | |
| " 80%|███████████████████████████████████████████████████████████████████▏ | 4/5 [00:01<00:00, 3.32it/s]\n", | |
| "100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:01<00:00, 3.39it/s]\n", | |
| "3it [00:04, 1.51s/it]\n", | |
| " 0%| | 0/5 [00:00<?, ?it/s]\n", | |
| " 20%|████████████████▊ | 1/5 [00:00<00:01, 3.45it/s]\n", | |
| " 40%|█████████████████████████████████▌ | 2/5 [00:00<00:00, 3.46it/s]\n", | |
| " 60%|██████████████████████████████████████████████████▍ | 3/5 [00:00<00:00, 3.51it/s]\n", | |
| " 80%|███████████████████████████████████████████████████████████████████▏ | 4/5 [00:01<00:00, 3.54it/s]\n", | |
| "100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:01<00:00, 3.54it/s]\n", | |
| "4it [00:05, 1.48s/it]\n", | |
| " 0%| | 0/5 [00:00<?, ?it/s]\n", | |
| " 40%|█████████████████████████████████▌ | 2/5 [00:00<00:00, 12.65it/s]\n", | |
| " 80%|███████████████████████████████████████████████████████████████████▏ | 4/5 [00:00<00:00, 12.58it/s]\n", | |
| "5it [00:06, 1.16s/it]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# nested\n", | |
| "class_num = len(np.unique(ytrain))\n", | |
| "\n", | |
| "dataset_blend_train_nested = np.zeros((Xtrain.shape[0], len(clfs)*class_num))\n", | |
| "\n", | |
| "dataset_inner_train = np.zeros((Xtrain.shape[0], len(clfs)*class_num))\n", | |
| "dataset_inner_test = np.zeros((Xtest.shape[0], len(clfs)*class_num))\n", | |
| "\n", | |
| "for k, clf in tqdm(enumerate(clfs)):\n", | |
| " for i in tqdm(range(0,len(folds))):\n", | |
| " # outer CV 부분\n", | |
| " target_fold = folds[i]\n", | |
| " inner_folds = folds[0:i]+folds[i+1:]\n", | |
| " for j in tqdm(range(0,len(inner_folds))):\n", | |
| " # inner CV 부분\n", | |
| " inner_target_fold = inner_folds[j]\n", | |
| " inner_train_fold = np.concatenate(inner_folds[0:j]+inner_folds[j+1:]).ravel()\n", | |
| " clf.fit(Xtrain[inner_train_fold],ytrain[inner_train_fold])\n", | |
| " inner_pred = clf.predict_proba(Xtrain[inner_target_fold])\n", | |
| " dataset_inner_train[inner_target_fold,k*class_num:(k*class_num+class_num)] = inner_pred\n", | |
| "\n", | |
| "for k, clf in tqdm(enumerate(clfs)):\n", | |
| " for i in tqdm(range(0,len(folds))):\n", | |
| " target_fold = folds[i]\n", | |
| " inner_folds = folds[0:i]+folds[i+1:]\n", | |
| " # 학습에 사용하는 feature 가 innerCV 에서 만들어진 prediction 을 사용해 학습이 됨\n", | |
| " clf.fit(dataset_inner_train[np.concatenate(inner_folds).ravel()],ytrain[np.concatenate(inner_folds).ravel()])\n", | |
| " pred = clf.predict_proba(dataset_inner_train[target_fold])\n", | |
| " dataset_blend_train_nested[target_fold, k*class_num:(k*class_num+class_num)] = pred" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "source": [ | |
| "Testset 에 대해 Prediction을 만드는 과정입니다." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 15, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "for k, clf in enumerate(clfs):\n", | |
| " clf.fit(Xtrain, ytrain)\n", | |
| " pred = clf.predict_proba(Xtest)\n", | |
| " dataset_blend_test[:, k*class_num:(k*class_num+class_num)] = pred" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "stacker로는 LogisticRegression 을 사용합니다." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 16, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "from sklearn.linear_model import LogisticRegressionCV\n", | |
| "from pandas_confusion import ConfusionMatrix\n", | |
| "from sklearn.metrics import accuracy_score" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Nested CV 를 하지 않은 Stacking의 결과입니다." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 17, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Predicted 0 1 2 __all__\n", | |
| "Actual \n", | |
| "0 43 0 0 43\n", | |
| "1 0 40 3 43\n", | |
| "2 0 3 41 44\n", | |
| "__all__ 43 43 44 130\n", | |
| "\n", | |
| "정확도 : 0.953846153846\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "stacker = LogisticRegressionCV(refit=False)\n", | |
| "stacker.fit(dataset_blend_train_not_nested,ytrain)\n", | |
| "print(ConfusionMatrix(ytest,stacker.predict(dataset_blend_test)))\n", | |
| "print()\n", | |
| "print(\"정확도 : \", accuracy_score(ytest,stacker.predict(dataset_blend_test)))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Nested CV 를 한 Stacking의 결과입니다." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 18, | |
| "metadata": { | |
| "collapsed": false, | |
| "scrolled": true | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Predicted 0 1 2 __all__\n", | |
| "Actual \n", | |
| "0 43 0 0 43\n", | |
| "1 0 41 2 43\n", | |
| "2 0 3 41 44\n", | |
| "__all__ 43 44 43 130\n", | |
| "\n", | |
| "정확도 : 0.961538461538\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "stacker = LogisticRegressionCV(refit=False)\n", | |
| "stacker.fit(dataset_blend_train_nested,ytrain)\n", | |
| "print(ConfusionMatrix(ytest,stacker.predict(dataset_blend_test)))\n", | |
| "print()\n", | |
| "print(\"정확도 : \", accuracy_score(ytest,stacker.predict(dataset_blend_test)))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Stacking을 하지 않은 LogisticRegression 의 결과입니다." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 19, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Predicted 0 1 2 __all__\n", | |
| "Actual \n", | |
| "0 43 0 0 43\n", | |
| "1 2 37 4 43\n", | |
| "2 0 0 44 44\n", | |
| "__all__ 45 37 48 130\n", | |
| "\n", | |
| "정확도 : 0.953846153846\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "lr = LogisticRegressionCV(refit=False)\n", | |
| "lr.fit(Xtrain,ytrain)\n", | |
| "accuracy_score(ytest,lr.predict(Xtest))\n", | |
| "print(ConfusionMatrix(ytest,lr.predict(Xtest)))\n", | |
| "print()\n", | |
| "print(\"정확도 : \", accuracy_score(ytest,lr.predict(Xtest)))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Conclusion\n", | |
| "\n", | |
| "Testset 을 Holdout Set 으로 두었을 때의 정확도를 비교해보면,\n", | |
| "Nested CV 가 좀 더 일반화를 잘 하고 있는 것을 알 수 있습니다.\n", | |
| "\n", | |
| "이 얘기는 Nested CV 하지 않은 Stacking이 좀 더 Trainset 에 Overfitting 한다고 볼 수 있습니다.\n", | |
| "\n", | |
| "예시를 위해 1 Stage 만을 해봤습니다." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "anaconda-cloud": {}, | |
| "kernelspec": { | |
| "display_name": "Python [conda root]", | |
| "language": "python", | |
| "name": "conda-root-py" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.5.2" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 1 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment