Skip to content

Instantly share code, notes, and snippets.

@brusangues
Created December 15, 2025 01:48
Show Gist options
  • Select an option

  • Save brusangues/3533b3bc3883bfeeca2c2012e5a47b80 to your computer and use it in GitHub Desktop.

Select an option

Save brusangues/3533b3bc3883bfeeca2c2012e5a47b80 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 50,
"id": "36ebe4c8",
"metadata": {},
"outputs": [],
"source": [
"# Alternative dataset (scikit-learn breast cancer), kept for experimentation.\n",
"# NOTE: when the notebook runs top to bottom, `df` is replaced by the covtype data loaded in the next cell.\n",
"from sklearn.datasets import load_breast_cancer\n",
"\n",
"data = load_breast_cancer(as_frame=True)\n",
"df = data.frame"
]
},
{
"cell_type": "code",
"execution_count": 51,
"id": "0e5ec84c",
"metadata": {},
"outputs": [],
"source": [
"# Load the covtype dataset as a DataFrame (feature columns + `Cover_Type` label column).\n",
"from sklearn.datasets import fetch_covtype\n",
"\n",
"data = fetch_covtype(as_frame=True)\n",
"df = data.frame"
]
},
{
"cell_type": "code",
"execution_count": 61,
"id": "b5c86837",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Index(['Elevation', 'Aspect', 'Slope', 'Horizontal_Distance_To_Hydrology',\n",
" 'Vertical_Distance_To_Hydrology', 'Horizontal_Distance_To_Roadways',\n",
" 'Hillshade_9am', 'Hillshade_Noon', 'Hillshade_3pm',\n",
" 'Horizontal_Distance_To_Fire_Points', 'Wilderness_Area_0',\n",
" 'Wilderness_Area_1', 'Wilderness_Area_2', 'Wilderness_Area_3',\n",
" 'Soil_Type_0', 'Soil_Type_1', 'Soil_Type_2', 'Soil_Type_3',\n",
" 'Soil_Type_4', 'Soil_Type_5', 'Soil_Type_6', 'Soil_Type_7',\n",
" 'Soil_Type_8', 'Soil_Type_9', 'Soil_Type_10', 'Soil_Type_11',\n",
" 'Soil_Type_12', 'Soil_Type_13', 'Soil_Type_14', 'Soil_Type_15',\n",
" 'Soil_Type_16', 'Soil_Type_17', 'Soil_Type_18', 'Soil_Type_19',\n",
" 'Soil_Type_20', 'Soil_Type_21', 'Soil_Type_22', 'Soil_Type_23',\n",
" 'Soil_Type_24', 'Soil_Type_25', 'Soil_Type_26', 'Soil_Type_27',\n",
" 'Soil_Type_28', 'Soil_Type_29', 'Soil_Type_30', 'Soil_Type_31',\n",
" 'Soil_Type_32', 'Soil_Type_33', 'Soil_Type_34', 'Soil_Type_35',\n",
" 'Soil_Type_36', 'Soil_Type_37', 'Soil_Type_38', 'Soil_Type_39',\n",
" 'Cover_Type'],\n",
" dtype='object')"
]
},
"execution_count": 61,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Column names of the loaded frame (DataFrame.keys() returns the columns axis)\n",
"df.keys()"
]
},
{
"cell_type": "code",
"execution_count": 62,
"id": "cce83c33",
"metadata": {},
"outputs": [],
"source": [
"# Give the label column the generic name `target` used throughout the pipeline.\n",
"# Reassigning (instead of inplace=True) keeps the transformation explicit and chainable.\n",
"df = df.rename(columns={\"Cover_Type\": \"target\"})"
]
},
{
"cell_type": "code",
"execution_count": 52,
"id": "37120478",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(581012, 55)"
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# (number of rows, number of columns)\n",
"len(df.index), len(df.columns)"
]
},
{
"cell_type": "code",
"execution_count": 63,
"id": "d3b943be",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"target\n",
"2 283301\n",
"1 211840\n",
"3 35754\n",
"7 20510\n",
"6 17367\n",
"5 9493\n",
"4 2747\n",
"Name: count, dtype: int64"
]
},
"execution_count": 63,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Class distribution of the label\n",
"df[\"target\"].value_counts()"
]
},
{
"cell_type": "code",
"execution_count": 202,
"id": "8e749d46",
"metadata": {},
"outputs": [],
"source": [
"# Optional - turn the target into a binary classification problem:\n",
"# cover types 1-3 -> 0, cover types 4-7 -> 1.\n",
"# Guarded so re-running the cell is a no-op: once the labels are already 0/1,\n",
"# `>= 4` would otherwise map every row to 0.\n",
"if df[\"target\"].max() > 1:\n",
"    df[\"target\"] = (df[\"target\"] >= 4).astype(\"int64\")"
]
},
{
"cell_type": "code",
"execution_count": 209,
"id": "7a486a9e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"target\n",
"0 530895\n",
"1 50117\n",
"Name: count, dtype: int64"
]
},
"execution_count": 209,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Class distribution after the (optional) binarisation step\n",
"df[\"target\"].value_counts()"
]
},
{
"cell_type": "code",
"execution_count": 203,
"id": "d048c16f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Index(['elevation', 'aspect', 'slope', 'horizontal_distance_to_hydrology',\n",
" 'vertical_distance_to_hydrology', 'horizontal_distance_to_roadways',\n",
" 'hillshade_9am', 'hillshade_noon', 'hillshade_3pm',\n",
" 'horizontal_distance_to_fire_points', 'wilderness_area_0',\n",
" 'wilderness_area_1', 'wilderness_area_2', 'wilderness_area_3',\n",
" 'soil_type_0', 'soil_type_1', 'soil_type_2', 'soil_type_3',\n",
" 'soil_type_4', 'soil_type_5', 'soil_type_6', 'soil_type_7',\n",
" 'soil_type_8', 'soil_type_9', 'soil_type_10', 'soil_type_11',\n",
" 'soil_type_12', 'soil_type_13', 'soil_type_14', 'soil_type_15',\n",
" 'soil_type_16', 'soil_type_17', 'soil_type_18', 'soil_type_19',\n",
" 'soil_type_20', 'soil_type_21', 'soil_type_22', 'soil_type_23',\n",
" 'soil_type_24', 'soil_type_25', 'soil_type_26', 'soil_type_27',\n",
" 'soil_type_28', 'soil_type_29', 'soil_type_30', 'soil_type_31',\n",
" 'soil_type_32', 'soil_type_33', 'soil_type_34', 'soil_type_35',\n",
" 'soil_type_36', 'soil_type_37', 'soil_type_38', 'soil_type_39',\n",
" 'target'],\n",
" dtype='object')"
]
},
"execution_count": 203,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Normalise column names: lower case, spaces replaced by underscores\n",
"df.columns = [column.lower().replace(\" \", \"_\") for column in df.columns]\n",
"df.columns"
]
},
{
"cell_type": "code",
"execution_count": 204,
"id": "1952689b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"df_train.shape=(4183, 55)\n",
"df_eval.shape=(465, 55)\n",
"df_test.shape=(1162, 55)\n"
]
}
],
"source": [
"# Split into train / eval / test sets, working on a 1% random sample of the data\n",
"from sklearn.model_selection import train_test_split as tts\n",
"\n",
"df_sample = df.sample(frac=0.01, random_state=42)\n",
"\n",
"# 80% train+eval vs 20% test, stratified on the label\n",
"df_train_eval, df_test = tts(\n",
"    df_sample,\n",
"    stratify=df_sample[\"target\"],\n",
"    test_size=0.2,\n",
"    random_state=42,\n",
")\n",
"# the train+eval part is split again: 90% train vs 10% eval (used for early stopping)\n",
"df_train, df_eval = tts(\n",
"    df_train_eval,\n",
"    stratify=df_train_eval[\"target\"],\n",
"    test_size=0.1,\n",
"    random_state=42,\n",
")\n",
"\n",
"print(f\"{df_train.shape=}\\n{df_eval.shape=}\\n{df_test.shape=}\")"
]
},
{
"cell_type": "code",
"execution_count": 212,
"id": "7db81aeb",
"metadata": {},
"outputs": [],
"source": [
"# CatBoost parameters reused by every model trained in the rest of the pipeline\n",
"params = dict(\n",
"    iterations=2000,\n",
"    loss_function=\"Logloss\",  # multiclass alternative: \"MultiClass\"\n",
"    eval_metric=\"F1\",  # multiclass alternative: \"TotalF1\"\n",
"    random_seed=100,\n",
"    verbose=100,  # log every 100 iterations\n",
"    early_stopping_rounds=100,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e515191b",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "083cecc8fce2400d89bdb652f6ae43ca",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Learning rate set to 0.033379\n",
"0:\tlearn: 0.0976864\ttest: 0.0000000\tbest: 0.0000000 (0)\ttotal: 8.43ms\tremaining: 16.9s\n",
"100:\tlearn: 0.4000000\ttest: 0.3773585\tbest: 0.4150943 (91)\ttotal: 941ms\tremaining: 17.7s\n",
"200:\tlearn: 0.6171429\ttest: 0.4406780\tbest: 0.4482759 (153)\ttotal: 1.75s\tremaining: 15.6s\n",
"300:\tlearn: 0.7033748\ttest: 0.5312500\tbest: 0.5396825 (264)\ttotal: 2.56s\tremaining: 14.5s\n",
"400:\tlearn: 0.7805695\ttest: 0.5454545\tbest: 0.5757576 (318)\ttotal: 3.37s\tremaining: 13.4s\n",
"Stopped by overfitting detector (100 iterations wait)\n",
"\n",
"bestTest = 0.5757575758\n",
"bestIteration = 318\n",
"\n",
"Shrink model to first 319 iterations.\n"
]
},
{
"data": {
"text/plain": [
"<catboost.core.CatBoostClassifier at 0x1a4daf69290>"
]
},
"execution_count": 213,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Train a baseline model using every feature\n",
"from catboost import CatBoostClassifier, Pool, EShapCalcType, EFeaturesSelectionAlgorithm\n",
"\n",
"\n",
"def _to_pool(frame):\n",
"    \"\"\"Build a CatBoost Pool: every column except `target` as features, `target` as label.\"\"\"\n",
"    return Pool(data=frame.drop(columns=[\"target\"]), label=frame[\"target\"])\n",
"\n",
"\n",
"pool_train, pool_eval, pool_test = (_to_pool(part) for part in (df_train, df_eval, df_test))\n",
"\n",
"model = CatBoostClassifier(**params)\n",
"model.fit(pool_train, eval_set=pool_eval, plot=True)"
]
},
{
"cell_type": "code",
"execution_count": 214,
"id": "d760142e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Selecting 1 features out of 54 using 55 steps\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4a77eea297c3474a8489436b08b56f48",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.plotly.v1+json": {
"config": {
"plotlyServerURL": "https://plot.ly"
},
"data": [
{
"line": {
"color": "rgb(51,160,44)"
},
"mode": "lines+markers",
"name": "",
"text": [
"",
"vertical_distance_to_hydrology",
"soil_type_39",
"soil_type_13",
"soil_type_0",
"soil_type_31",
"slope",
"soil_type_2",
"soil_type_15",
"aspect",
"soil_type_12",
"wilderness_area_1",
"soil_type_36",
"soil_type_27",
"soil_type_5",
"soil_type_33",
"soil_type_34",
"soil_type_17",
"hillshade_noon",
"soil_type_26",
"soil_type_6",
"soil_type_10",
"soil_type_19",
"soil_type_8",
"soil_type_7",
"soil_type_14",
"soil_type_20",
"soil_type_24",
"soil_type_18",
"soil_type_16",
"soil_type_25",
"soil_type_35",
"soil_type_22",
"soil_type_28",
"soil_type_23",
"soil_type_21",
"soil_type_1",
"soil_type_11",
"soil_type_4",
"soil_type_32",
"soil_type_30",
"hillshade_9am",
"wilderness_area_2",
"soil_type_3",
"soil_type_29",
"soil_type_37",
"horizontal_distance_to_hydrology",
"soil_type_9",
"horizontal_distance_to_fire_points",
"wilderness_area_3",
"wilderness_area_0",
"soil_type_38",
"horizontal_distance_to_roadways",
"hillshade_3pm"
],
"type": "scatter",
"x": [
0,
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
23,
24,
25,
26,
27,
28,
29,
30,
31,
32,
33,
34,
35,
36,
37,
38,
39,
40,
41,
42,
43,
44,
45,
46,
47,
48,
49,
50,
51,
52,
53
],
"y": [
0.1477282443515885,
0.14360728665267491,
0.1431814290361284,
0.14298460583655714,
0.14282510476834823,
0.138720506542803,
0.1379514449745828,
0.1373899846007138,
0.13718637249828747,
0.1396625950309679,
0.14072820600529518,
0.14097406332773885,
0.13595778276128634,
0.1346676493523341,
0.13361166554323825,
0.13187604508174725,
0.13120878267479325,
0.13104644007439964,
0.1517044460271607,
0.15171273628228008,
0.1517137533959223,
0.13722506984594268,
0.13210261129011047,
0.1443178728641258,
0.14431787286941414,
0.1458053618311109,
0.14729285079280766,
0.14878033975450441,
0.14132961172824576,
0.14132961172824576,
0.14027958275842997,
0.14027958275843014,
0.14037541750914248,
0.13926717055366758,
0.13797897441609677,
0.14212675332306965,
0.14041561086454277,
0.13757352885677757,
0.13510393460777329,
0.14409481128293467,
0.14084513882448946,
0.12712209783860534,
0.1378335582278886,
0.13478887732444192,
0.13403321273113225,
0.14991679743231506,
0.1350558793736135,
0.14042553951270548,
0.15684548881410224,
0.15089630427895054,
0.15390386308252524,
0.15754301616669614,
0.1922590970216916,
0.194306053544715
]
},
{
"marker": {
"size": 10,
"symbol": "square"
},
"mode": "markers",
"name": "",
"text": [
"",
"soil_type_0",
"soil_type_15",
"wilderness_area_1",
"soil_type_5",
"soil_type_17",
"soil_type_6",
"soil_type_19",
"soil_type_7",
"soil_type_24",
"soil_type_16",
"soil_type_25",
"soil_type_22",
"soil_type_23",
"soil_type_21",
"soil_type_11",
"soil_type_4",
"soil_type_32",
"soil_type_30",
"hillshade_9am",
"wilderness_area_2",
"soil_type_3",
"soil_type_29",
"soil_type_29",
"soil_type_37",
"horizontal_distance_to_hydrology",
"horizontal_distance_to_hydrology",
"soil_type_9",
"soil_type_9",
"horizontal_distance_to_fire_points",
"horizontal_distance_to_fire_points",
"wilderness_area_3",
"wilderness_area_3",
"wilderness_area_3",
"wilderness_area_0",
"wilderness_area_0",
"wilderness_area_0",
"soil_type_38",
"soil_type_38",
"soil_type_38",
"soil_type_38",
"horizontal_distance_to_roadways",
"horizontal_distance_to_roadways",
"horizontal_distance_to_roadways",
"horizontal_distance_to_roadways",
"horizontal_distance_to_roadways",
"horizontal_distance_to_roadways",
"horizontal_distance_to_roadways",
"hillshade_3pm",
"hillshade_3pm",
"hillshade_3pm",
"hillshade_3pm",
"hillshade_3pm",
"hillshade_3pm"
],
"type": "scatter",
"x": [
0,
4,
8,
11,
14,
17,
20,
22,
24,
27,
29,
30,
32,
34,
35,
37,
38,
39,
40,
41,
42,
43,
44,
44,
45,
46,
46,
47,
47,
48,
48,
49,
49,
49,
50,
50,
50,
51,
51,
51,
51,
52,
52,
52,
52,
52,
52,
52,
53,
53,
53,
53,
53,
53
],
"y": [
0.1477282443515885,
0.14282510476834823,
0.13718637249828747,
0.14097406332773885,
0.13361166554323825,
0.13104644007439964,
0.1517137533959223,
0.13210261129011047,
0.14431787286941414,
0.14878033975450441,
0.14132961172824576,
0.14027958275842997,
0.14037541750914248,
0.13797897441609677,
0.14212675332306965,
0.13757352885677757,
0.13510393460777329,
0.14409481128293467,
0.14084513882448946,
0.12712209783860534,
0.1378335582278886,
0.13478887732444192,
0.13403321273113225,
0.13403321273113225,
0.14991679743231506,
0.1350558793736135,
0.1350558793736135,
0.14042553951270548,
0.14042553951270548,
0.15684548881410224,
0.15684548881410224,
0.15089630427895054,
0.15089630427895054,
0.15089630427895054,
0.15390386308252524,
0.15390386308252524,
0.15390386308252524,
0.15754301616669614,
0.15754301616669614,
0.15754301616669614,
0.15754301616669614,
0.1922590970216916,
0.1922590970216916,
0.1922590970216916,
0.1922590970216916,
0.1922590970216916,
0.1922590970216916,
0.1922590970216916,
0.194306053544715,
0.194306053544715,
0.194306053544715,
0.194306053544715,
0.194306053544715,
0.194306053544715
]
},
{
"mode": "text",
"name": "",
"text": [
"",
"4",
"53",
"27",
"14",
"45",
"2",
"16",
"29",
"1",
"26",
"11",
"50",
"41",
"19",
"47",
"48",
"31",
"7",
"40",
"20",
"24",
"33",
"22",
"21",
"28",
"34",
"38",
"32",
"30",
"39",
"49",
"36",
"42",
"37",
"35",
"15",
"25",
"18",
"46",
"44",
"6",
"12",
"17",
"43",
"51",
"3",
"23",
"9",
"13",
"10",
"52",
"5",
"8"
],
"textfont": {
"color": "rgb(51,160,44)",
"family": "sans serif",
"size": 18
},
"textposition": "bottom center",
"type": "scatter",
"visible": false,
"x": [
0,
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
23,
24,
25,
26,
27,
28,
29,
30,
31,
32,
33,
34,
35,
36,
37,
38,
39,
40,
41,
42,
43,
44,
45,
46,
47,
48,
49,
50,
51,
52,
53
],
"y": [
0.1477282443515885,
0.14360728665267491,
0.1431814290361284,
0.14298460583655714,
0.14282510476834823,
0.138720506542803,
0.1379514449745828,
0.1373899846007138,
0.13718637249828747,
0.1396625950309679,
0.14072820600529518,
0.14097406332773885,
0.13595778276128634,
0.1346676493523341,
0.13361166554323825,
0.13187604508174725,
0.13120878267479325,
0.13104644007439964,
0.1517044460271607,
0.15171273628228008,
0.1517137533959223,
0.13722506984594268,
0.13210261129011047,
0.1443178728641258,
0.14431787286941414,
0.1458053618311109,
0.14729285079280766,
0.14878033975450441,
0.14132961172824576,
0.14132961172824576,
0.14027958275842997,
0.14027958275843014,
0.14037541750914248,
0.13926717055366758,
0.13797897441609677,
0.14212675332306965,
0.14041561086454277,
0.13757352885677757,
0.13510393460777329,
0.14409481128293467,
0.14084513882448946,
0.12712209783860534,
0.1378335582278886,
0.13478887732444192,
0.13403321273113225,
0.14991679743231506,
0.1350558793736135,
0.14042553951270548,
0.15684548881410224,
0.15089630427895054,
0.15390386308252524,
0.15754301616669614,
0.1922590970216916,
0.194306053544715
]
},
{
"mode": "text",
"name": "",
"text": [
"",
"vertical_distance_to_hydrology",
"soil_type_39",
"soil_type_13",
"soil_type_0",
"soil_type_31",
"slope",
"soil_type_2",
"soil_type_15",
"aspect",
"soil_type_12",
"wilderness_area_1",
"soil_type_36",
"soil_type_27",
"soil_type_5",
"soil_type_33",
"soil_type_34",
"soil_type_17",
"hillshade_noon",
"soil_type_26",
"soil_type_6",
"soil_type_10",
"soil_type_19",
"soil_type_8",
"soil_type_7",
"soil_type_14",
"soil_type_20",
"soil_type_24",
"soil_type_18",
"soil_type_16",
"soil_type_25",
"soil_type_35",
"soil_type_22",
"soil_type_28",
"soil_type_23",
"soil_type_21",
"soil_type_1",
"soil_type_11",
"soil_type_4",
"soil_type_32",
"soil_type_30",
"hillshade_9am",
"wilderness_area_2",
"soil_type_3",
"soil_type_29",
"soil_type_37",
"horizontal_distance_to_hydrology",
"soil_type_9",
"horizontal_distance_to_fire_points",
"wilderness_area_3",
"wilderness_area_0",
"soil_type_38",
"horizontal_distance_to_roadways",
"hillshade_3pm"
],
"textfont": {
"color": "rgb(51,160,44)",
"family": "sans serif",
"size": 18
},
"textposition": "bottom center",
"type": "scatter",
"visible": false,
"x": [
0,
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
23,
24,
25,
26,
27,
28,
29,
30,
31,
32,
33,
34,
35,
36,
37,
38,
39,
40,
41,
42,
43,
44,
45,
46,
47,
48,
49,
50,
51,
52,
53
],
"y": [
0.1477282443515885,
0.14360728665267491,
0.1431814290361284,
0.14298460583655714,
0.14282510476834823,
0.138720506542803,
0.1379514449745828,
0.1373899846007138,
0.13718637249828747,
0.1396625950309679,
0.14072820600529518,
0.14097406332773885,
0.13595778276128634,
0.1346676493523341,
0.13361166554323825,
0.13187604508174725,
0.13120878267479325,
0.13104644007439964,
0.1517044460271607,
0.15171273628228008,
0.1517137533959223,
0.13722506984594268,
0.13210261129011047,
0.1443178728641258,
0.14431787286941414,
0.1458053618311109,
0.14729285079280766,
0.14878033975450441,
0.14132961172824576,
0.14132961172824576,
0.14027958275842997,
0.14027958275843014,
0.14037541750914248,
0.13926717055366758,
0.13797897441609677,
0.14212675332306965,
0.14041561086454277,
0.13757352885677757,
0.13510393460777329,
0.14409481128293467,
0.14084513882448946,
0.12712209783860534,
0.1378335582278886,
0.13478887732444192,
0.13403321273113225,
0.14991679743231506,
0.1350558793736135,
0.14042553951270548,
0.15684548881410224,
0.15089630427895054,
0.15390386308252524,
0.15754301616669614,
0.1922590970216916,
0.194306053544715
]
}
],
"layout": {
"showlegend": false,
"template": {
"data": {
"bar": [
{
"error_x": {
"color": "#2a3f5f"
},
"error_y": {
"color": "#2a3f5f"
},
"marker": {
"line": {
"color": "#E5ECF6",
"width": 0.5
},
"pattern": {
"fillmode": "overlay",
"size": 10,
"solidity": 0.2
}
},
"type": "bar"
}
],
"barpolar": [
{
"marker": {
"line": {
"color": "#E5ECF6",
"width": 0.5
},
"pattern": {
"fillmode": "overlay",
"size": 10,
"solidity": 0.2
}
},
"type": "barpolar"
}
],
"carpet": [
{
"aaxis": {
"endlinecolor": "#2a3f5f",
"gridcolor": "white",
"linecolor": "white",
"minorgridcolor": "white",
"startlinecolor": "#2a3f5f"
},
"baxis": {
"endlinecolor": "#2a3f5f",
"gridcolor": "white",
"linecolor": "white",
"minorgridcolor": "white",
"startlinecolor": "#2a3f5f"
},
"type": "carpet"
}
],
"choropleth": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"type": "choropleth"
}
],
"contour": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "contour"
}
],
"contourcarpet": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"type": "contourcarpet"
}
],
"heatmap": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "heatmap"
}
],
"heatmapgl": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "heatmapgl"
}
],
"histogram": [
{
"marker": {
"pattern": {
"fillmode": "overlay",
"size": 10,
"solidity": 0.2
}
},
"type": "histogram"
}
],
"histogram2d": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "histogram2d"
}
],
"histogram2dcontour": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "histogram2dcontour"
}
],
"mesh3d": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"type": "mesh3d"
}
],
"parcoords": [
{
"line": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "parcoords"
}
],
"pie": [
{
"automargin": true,
"type": "pie"
}
],
"scatter": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatter"
}
],
"scatter3d": [
{
"line": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatter3d"
}
],
"scattercarpet": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scattercarpet"
}
],
"scattergeo": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scattergeo"
}
],
"scattergl": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scattergl"
}
],
"scattermapbox": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scattermapbox"
}
],
"scatterpolar": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatterpolar"
}
],
"scatterpolargl": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatterpolargl"
}
],
"scatterternary": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatterternary"
}
],
"surface": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "surface"
}
],
"table": [
{
"cells": {
"fill": {
"color": "#EBF0F8"
},
"line": {
"color": "white"
}
},
"header": {
"fill": {
"color": "#C8D4E3"
},
"line": {
"color": "white"
}
},
"type": "table"
}
]
},
"layout": {
"annotationdefaults": {
"arrowcolor": "#2a3f5f",
"arrowhead": 0,
"arrowwidth": 1
},
"autotypenumbers": "strict",
"coloraxis": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"colorscale": {
"diverging": [
[
0,
"#8e0152"
],
[
0.1,
"#c51b7d"
],
[
0.2,
"#de77ae"
],
[
0.3,
"#f1b6da"
],
[
0.4,
"#fde0ef"
],
[
0.5,
"#f7f7f7"
],
[
0.6,
"#e6f5d0"
],
[
0.7,
"#b8e186"
],
[
0.8,
"#7fbc41"
],
[
0.9,
"#4d9221"
],
[
1,
"#276419"
]
],
"sequential": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"sequentialminus": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
]
},
"colorway": [
"#636efa",
"#EF553B",
"#00cc96",
"#ab63fa",
"#FFA15A",
"#19d3f3",
"#FF6692",
"#B6E880",
"#FF97FF",
"#FECB52"
],
"font": {
"color": "#2a3f5f"
},
"geo": {
"bgcolor": "white",
"lakecolor": "white",
"landcolor": "#E5ECF6",
"showlakes": true,
"showland": true,
"subunitcolor": "white"
},
"hoverlabel": {
"align": "left"
},
"hovermode": "closest",
"mapbox": {
"style": "light"
},
"paper_bgcolor": "white",
"plot_bgcolor": "#E5ECF6",
"polar": {
"angularaxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
},
"bgcolor": "#E5ECF6",
"radialaxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
}
},
"scene": {
"xaxis": {
"backgroundcolor": "#E5ECF6",
"gridcolor": "white",
"gridwidth": 2,
"linecolor": "white",
"showbackground": true,
"ticks": "",
"zerolinecolor": "white"
},
"yaxis": {
"backgroundcolor": "#E5ECF6",
"gridcolor": "white",
"gridwidth": 2,
"linecolor": "white",
"showbackground": true,
"ticks": "",
"zerolinecolor": "white"
},
"zaxis": {
"backgroundcolor": "#E5ECF6",
"gridcolor": "white",
"gridwidth": 2,
"linecolor": "white",
"showbackground": true,
"ticks": "",
"zerolinecolor": "white"
}
},
"shapedefaults": {
"line": {
"color": "#2a3f5f"
}
},
"ternary": {
"aaxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
},
"baxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
},
"bgcolor": "#E5ECF6",
"caxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
}
},
"title": {
"x": 0.05
},
"xaxis": {
"automargin": true,
"gridcolor": "white",
"linecolor": "white",
"ticks": "",
"title": {
"standoff": 15
},
"zerolinecolor": "white",
"zerolinewidth": 2
},
"yaxis": {
"automargin": true,
"gridcolor": "white",
"linecolor": "white",
"ticks": "",
"title": {
"standoff": 15
},
"zerolinecolor": "white",
"zerolinewidth": 2
}
}
},
"title": {
"text": "Loss by eliminated features"
},
"updatemenus": [
{
"active": 0,
"buttons": [
{
"args": [
{
"visible": [
true,
true,
false,
false
]
}
],
"label": "Hide features",
"method": "update"
},
{
"args": [
{
"visible": [
true,
true,
true,
false
]
}
],
"label": "Show indices",
"method": "update"
},
{
"args": [
{
"visible": [
true,
true,
false,
true
]
}
],
"label": "Show names",
"method": "update"
}
],
"pad": {
"r": 10,
"t": 10
},
"showactive": true,
"x": -0.25,
"xanchor": "left",
"y": 1.03,
"yanchor": "top"
}
],
"xaxis": {
"gridcolor": "rgb(255,255,255)",
"showgrid": true,
"showline": false,
"showticklabels": true,
"tickcolor": "rgb(127,127,127)",
"ticks": "outside",
"title": {
"text": "number of removed features"
},
"zeroline": false
},
"yaxis": {
"gridcolor": "rgb(255,255,255)",
"showgrid": true,
"showline": false,
"showticklabels": true,
"tickcolor": "rgb(127,127,127)",
"tickfont": {
"color": "rgb(51,160,44)"
},
"ticks": "outside",
"title": {
"font": {
"color": "rgb(51,160,44)"
},
"text": "loss value"
},
"zeroline": false
}
}
}
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# CatBoost recursive feature selection (RecursiveByShapValues)\n",
"model = CatBoostClassifier(**params)\n",
"\n",
"num_features_to_select = 1\n",
"features_for_select = list(df_train.columns.drop(\"target\"))\n",
"# Number of elimination steps = number of features to drop (about one feature per step).\n",
"# Do not derive it from len(df_train.columns): that also counts the `target` column and\n",
"# requests more steps than there are features to eliminate.\n",
"steps = len(features_for_select) - num_features_to_select\n",
"\n",
"print(f\"Selecting {num_features_to_select} features out of {len(features_for_select)} using {steps} steps\")\n",
"\n",
"# Capture the feature-selection log so the per-step metrics can be parsed afterwards\n",
"from io import StringIO\n",
"text_results = StringIO()\n",
"\n",
"feature_selection_results = model.select_features(\n",
"    pool_train,\n",
"    eval_set=pool_eval,\n",
"    features_for_select=features_for_select,\n",
"    num_features_to_select=num_features_to_select,\n",
"    steps=steps,\n",
"    plot=True,\n",
"    algorithm=EFeaturesSelectionAlgorithm.RecursiveByShapValues,\n",
"    shap_calc_type=EShapCalcType.Regular,\n",
"    train_final_model=True,\n",
"    log_cout=text_results,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 215,
"id": "7bef49e1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"metric_ids=[0, 4, 8, 11, 14, 17, 20, 22, 24, 27, 29, 30, 32, 34, 35, 37, 38, 39, 40, 41, 42, 43, 44, 44, 45, 46, 46, 47, 47, 48, 48, 49, 49, 49, 50, 50, 50, 51, 51, 51, 51, 52, 52, 52, 52, 52, 52, 52, 53, 53, 53, 53, 53, 53]\n",
"metric_values=[0.5757575758, 0.5396825397, 0.5846153846, 0.5625, 0.6176470588, 0.676056338, 0.5172413793, 0.6060606061, 0.5901639344, 0.5517241379, 0.5625, 0.5714285714, 0.59375, 0.5454545455, 0.59375, 0.5757575758, 0.6268656716, 0.5714285714, 0.5846153846, 0.676056338, 0.6567164179, 0.6764705882, 0.6956521739, 0.6956521739, 0.5714285714, 0.6470588235, 0.6470588235, 0.6376811594, 0.6376811594, 0.5573770492, 0.5573770492, 0.5573770492, 0.5573770492, 0.5573770492, 0.6031746032, 0.6031746032, 0.6031746032, 0.5901639344, 0.5901639344, 0.5901639344, 0.5901639344, 0.2857142857, 0.2857142857, 0.2857142857, 0.2857142857, 0.2857142857, 0.2857142857, 0.2857142857, 0.2909090909, 0.2909090909, 0.2909090909, 0.2909090909, 0.2909090909, 0.2909090909]\n",
"metric_best_iters=[318, 284, 391, 245, 474, 546, 157, 410, 196, 154, 235, 290, 233, 285, 229, 269, 385, 195, 356, 485, 297, 568, 382, 382, 174, 340, 340, 342, 342, 216, 216, 292, 292, 292, 329, 329, 329, 318, 318, 318, 318, 78, 78, 78, 78, 78, 78, 78, 182, 182, 182, 182, 182, 182]\n",
"Highest metric is best.\n",
"best_metric_id=44 best_metric=0.6956521739\n"
]
}
],
"source": [
"# Parse the captured feature-selection log: for every model trained during the selection,\n",
"# record how many features had been eliminated so far, its best eval metric and its best iteration.\n",
"import re\n",
"import numpy as np\n",
"\n",
"# Raw strings: regex patterns must not go through Python string-escape processing.\n",
"LOG_PATTERN = re.compile(\n",
"    r\"Feature #([0-9]+) eliminated\"\n",
"    r\"|bestTest = ([0-9.]+(?:[eE][+-]?[0-9]+)?)\"\n",
"    r\"|bestIteration = ([0-9]+)\"\n",
")\n",
"num_results = LOG_PATTERN.findall(text_results.getvalue())\n",
"\n",
"metric_ids = []  # number of eliminated features when each metric was logged\n",
"metric_values = []\n",
"metric_best_iters = []\n",
"n_eliminated = 0  # named to avoid shadowing the built-in `id`\n",
"for feature_eliminated, best_test, best_iteration in num_results:\n",
"    if feature_eliminated:\n",
"        n_eliminated += 1\n",
"    if best_test:\n",
"        metric_ids.append(n_eliminated)\n",
"        metric_values.append(float(best_test))\n",
"    if best_iteration:\n",
"        metric_best_iters.append(int(best_iteration))\n",
"print(f\"{metric_ids=}\\n{metric_values=}\\n{metric_best_iters=}\")\n",
"\n",
"if not metric_values:\n",
"    raise ValueError(\"No 'bestTest' entries found in the captured feature-selection log.\")\n",
"\n",
"metric_ids = np.array(metric_ids)\n",
"metric_values = np.array(metric_values)\n",
"\n",
"METRIC_LESSER_IS_BETTER = False  # eval metric is F1 (see params): higher is better\n",
"if METRIC_LESSER_IS_BETTER:\n",
"    print(\"Lowest metric is best.\")\n",
"    best_metric_id = int(metric_ids[metric_values.argmin()])\n",
"    best_metric = float(metric_values.min())\n",
"else:\n",
"    print(\"Highest metric is best.\")\n",
"    best_metric_id = int(metric_ids[metric_values.argmax()])\n",
"    best_metric = float(metric_values.max())\n",
"# best_metric_id = number of features eliminated before the best-scoring model was trained\n",
"print(f\"{best_metric_id=} {best_metric=}\")"
]
},
{
"cell_type": "code",
"execution_count": 216,
"id": "1328700b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"len(losses)=54 losses=[0.1477282443515885, 0.14360728665267491, 0.1431814290361284, 0.14298460583655714, 0.14282510476834823, 0.138720506542803, 0.1379514449745828, 0.1373899846007138, 0.13718637249828747, 0.1396625950309679, 0.14072820600529518, 0.14097406332773885, 0.13595778276128634, 0.1346676493523341, 0.13361166554323825, 0.13187604508174725, 0.13120878267479325, 0.13104644007439964, 0.1517044460271607, 0.15171273628228008, 0.1517137533959223, 0.13722506984594268, 0.13210261129011047, 0.1443178728641258, 0.14431787286941414, 0.1458053618311109, 0.14729285079280766, 0.14878033975450441, 0.14132961172824576, 0.14132961172824576, 0.14027958275842997, 0.14027958275843014, 0.14037541750914248, 0.13926717055366758, 0.13797897441609677, 0.14212675332306965, 0.14041561086454277, 0.13757352885677757, 0.13510393460777329, 0.14409481128293467, 0.14084513882448946, 0.12712209783860534, 0.1378335582278886, 0.13478887732444192, 0.13403321273113225, 0.14991679743231506, 0.1350558793736135, 0.14042553951270548, 0.15684548881410224, 0.15089630427895054, 0.15390386308252524, 0.15754301616669614, 0.1922590970216916, 0.194306053544715]\n",
"Lowest loss is best.\n",
"best_loss_id=41 best_loss=0.12712209783860534\n"
]
}
],
"source": [
"# Pick the best point of the feature-selection loss curve\n",
"LOSS_LESSER_IS_BETTER = True\n",
"losses = feature_selection_results[\"loss_graph\"][\"loss_values\"]\n",
"print(f\"{len(losses)=} {losses=}\")\n",
"losses = np.array(losses)\n",
"print(\"Lowest loss is best.\" if LOSS_LESSER_IS_BETTER else \"Highest loss is best.\")\n",
"select_best = np.argmin if LOSS_LESSER_IS_BETTER else np.argmax\n",
"best_loss_id = int(select_best(losses))  # index on the curve = number of removed features\n",
"best_loss = float(losses[best_loss_id])\n",
"print(f\"{best_loss_id=} {best_loss=}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 217,
"id": "ca6525b7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"len(features_ordered)=54 features_ordered=['elevation', 'hillshade_3pm', 'horizontal_distance_to_roadways', 'soil_type_38', 'wilderness_area_0', 'wilderness_area_3', 'horizontal_distance_to_fire_points', 'soil_type_9', 'horizontal_distance_to_hydrology', 'soil_type_37', 'soil_type_29', 'soil_type_3', 'wilderness_area_2', 'hillshade_9am', 'soil_type_30', 'soil_type_32', 'soil_type_4', 'soil_type_11', 'soil_type_1', 'soil_type_21', 'soil_type_23', 'soil_type_28', 'soil_type_22', 'soil_type_35', 'soil_type_25', 'soil_type_16', 'soil_type_18', 'soil_type_24', 'soil_type_20', 'soil_type_14', 'soil_type_7', 'soil_type_8', 'soil_type_19', 'soil_type_10', 'soil_type_6', 'soil_type_26', 'hillshade_noon', 'soil_type_17', 'soil_type_34', 'soil_type_33', 'soil_type_5', 'soil_type_27', 'soil_type_36', 'wilderness_area_1', 'soil_type_12', 'aspect', 'soil_type_15', 'soil_type_2', 'slope', 'soil_type_31', 'soil_type_0', 'soil_type_13', 'soil_type_39', 'vertical_distance_to_hydrology']\n",
"Choosing best number of features based on best metric.\n",
"best_n_features=10\n",
"features_best=['elevation', 'hillshade_3pm', 'horizontal_distance_to_roadways', 'soil_type_38', 'wilderness_area_0', 'wilderness_area_3', 'horizontal_distance_to_fire_points', 'soil_type_9', 'horizontal_distance_to_hydrology', 'soil_type_37']\n"
]
}
],
"source": [
"# Order features from most to least persistent during selection: the kept features first,\n",
"# then the eliminated ones reversed. Assuming eliminated_features_names is listed in\n",
"# elimination order, the first N entries are the features still present after len - N removals.\n",
"features_ordered = [\n",
"    *feature_selection_results['selected_features_names'],\n",
"    *reversed(feature_selection_results['eliminated_features_names']),\n",
"]\n",
"print(f\"{len(features_ordered)=} {features_ordered=}\")\n",
"\n",
"# BEST_N_FEATURES_BY_METRIC: True -> cut at best_metric_id, False -> cut at best_loss_id\n",
"# (both ids are interpreted as the number of features removed).\n",
"BEST_N_FEATURES_BY_METRIC = True\n",
"print(\n",
"    \"Choosing best number of features based on best metric.\"\n",
"    if BEST_N_FEATURES_BY_METRIC\n",
"    else \"Choosing best number of features based on best loss.\"\n",
")\n",
"best_n_features = len(features_ordered) - (\n",
"    best_metric_id if BEST_N_FEATURES_BY_METRIC else best_loss_id\n",
")\n",
"print(f\"{best_n_features=}\")\n",
"features_best = features_ordered[:best_n_features]\n",
"print(f\"{features_best=}\")"
]
},
{
"cell_type": "code",
"execution_count": 219,
"id": "c152cd9f",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Feature-selection summary plot: loss (left axis) and metric (right axis) vs. number of features removed\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"# Normalize inputs to arrays once, before the fancy indexing / argmin / argmax below\n",
"losses = np.asarray(losses)\n",
"metric_ids = np.asarray(metric_ids)\n",
"metric_values = np.asarray(metric_values)\n",
"\n",
"fig, ax1 = plt.subplots()\n",
"\n",
"# Loss curve; 'o' markers flag the loss at the steps that also have a metric value\n",
"ax1.plot(losses, color='orange', label='Loss Values', marker='.')\n",
"ax1.scatter(metric_ids, losses[metric_ids], color='orange', marker=\"o\")\n",
"# Labels are prefixed with 'Loss'/'Metric' because both axes use the same red/blue markers and\n",
"# share one figure legend; 4 decimals because the losses only differ from the 3rd decimal on.\n",
"ax1.scatter(losses.argmin(), losses.min(), color=\"red\", label=f\"Loss min {losses.min():0.4f}\", marker=\"v\")\n",
"ax1.scatter(losses.argmax(), losses.max(), color=\"blue\", label=f\"Loss max {losses.max():0.4f}\", marker=\"^\")\n",
"# Dashed line: the cut implied by best_n_features (features removed = total - best_n_features)\n",
"ax1.axvline(len(features_ordered) - best_n_features, color=\"gray\", linestyle=\"--\", label=f\"Chosen: {best_n_features} features\")\n",
"ax1.grid()\n",
"ax1.set_xlabel(\"Num Features Removed\")\n",
"ax1.set_ylabel(\"Loss Values\", color='orange')\n",
"ax1.set_title(\"Feature selection summary\")\n",
"\n",
"# Metric is only available at the evaluated steps (metric_ids)\n",
"ax2 = ax1.twinx()\n",
"ax2.plot(metric_ids, metric_values, color='green', label='Metric Values', marker='.')\n",
"ax2.scatter(metric_ids[metric_values.argmin()], metric_values.min(), color=\"red\", label=f\"Metric min {metric_values.min():0.4f}\", marker=\"v\")\n",
"ax2.scatter(metric_ids[metric_values.argmax()], metric_values.max(), color=\"blue\", label=f\"Metric max {metric_values.max():0.4f}\", marker=\"^\")\n",
"ax2.set_ylabel(\"Metric Values\", color='green')\n",
"\n",
"fig.legend()\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 220,
"id": "2c91f39d",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a7ec5a1d71624cc7b02dfb1037f8705f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Learning rate set to 0.033379\n",
"0:\tlearn: 0.1695761\ttest: 0.1363636\tbest: 0.1363636 (0)\ttotal: 7.46ms\tremaining: 14.9s\n",
"100:\tlearn: 0.4505263\ttest: 0.4074074\tbest: 0.4074074 (92)\ttotal: 691ms\tremaining: 13s\n",
"200:\tlearn: 0.6068702\ttest: 0.5483871\tbest: 0.5483871 (186)\ttotal: 1.4s\tremaining: 12.5s\n",
"300:\tlearn: 0.6881720\ttest: 0.6363636\tbest: 0.6363636 (266)\ttotal: 2.08s\tremaining: 11.8s\n",
"400:\tlearn: 0.7547170\ttest: 0.6567164\tbest: 0.6764706 (330)\ttotal: 2.94s\tremaining: 11.7s\n",
"Stopped by overfitting detector (100 iterations wait)\n",
"\n",
"bestTest = 0.6764705882\n",
"bestIteration = 330\n",
"\n",
"Shrink model to first 331 iterations.\n"
]
},
{
"data": {
"text/plain": [
"<catboost.core.CatBoostClassifier at 0x1a4dabb9b90>"
]
},
"execution_count": 220,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Train the final model on the selected feature subset (features_best) only.\n",
"# pool_eval is passed as eval_set to fit; pool_test is kept for the evaluation cells.\n",
"pool_train, pool_eval, pool_test = (\n",
"    Pool(data=split[features_best], label=split[\"target\"])\n",
"    for split in (df_train, df_eval, df_test)\n",
")\n",
"\n",
"model_final = CatBoostClassifier(**params)\n",
"model_final.fit(\n",
"    pool_train,\n",
"    eval_set=pool_eval,\n",
"    plot=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 222,
"id": "84d0a81e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" 0 0.95 0.99 0.97 1062\n",
" 1 0.77 0.41 0.54 100\n",
"\n",
" accuracy 0.94 1162\n",
" macro avg 0.86 0.70 0.75 1162\n",
"weighted avg 0.93 0.94 0.93 1162\n",
"\n",
"AUC: 0.9431073446327685\n"
]
}
],
"source": [
"# Classification report and ROC AUC on the held-out test pool\n",
"import numpy as np\n",
"from sklearn.metrics import classification_report, roc_auc_score\n",
"\n",
"y_pred_proba = model_final.predict_proba(pool_test)\n",
"y_pred = model_final.predict(pool_test)\n",
"\n",
"print(classification_report(df_test[\"target\"], y_pred, zero_division=0))\n",
"\n",
"# Binary target: score with the positive-class column (multi_class is ignored in that case).\n",
"# Multiclass target: one-vs-rest needs the full (n_samples, n_classes) probability matrix;\n",
"# passing a single column would raise. NOTE(review): assumes predict_proba columns follow\n",
"# the sorted label order roc_auc_score uses - confirm before relying on multiclass AUC.\n",
"auc_scores = y_pred_proba[:, 1] if y_pred_proba.shape[1] == 2 else y_pred_proba\n",
"print(f\"AUC: {roc_auc_score(df_test['target'], auc_scores, multi_class='ovr')}\")"
]
},
{
"cell_type": "code",
"execution_count": 224,
"id": "9256078e",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# ROC curve of the positive-class probability (column 1 of predict_proba) on the test set\n",
"from sklearn.metrics import auc, roc_curve\n",
"import matplotlib.pyplot as plt\n",
"\n",
"fpr, tpr, thresholds = roc_curve(df_test[\"target\"], y_pred_proba[:, 1])\n",
"\n",
"fig, ax = plt.subplots()\n",
"ax.plot(fpr, tpr, label=f\"Model (AUC = {auc(fpr, tpr):.4f})\")\n",
"ax.plot([0, 1], [0, 1], linestyle='--', color='gray', label=\"Chance\")\n",
"ax.set_xlabel(\"False Positive Rate\")\n",
"ax.set_ylabel(\"True Positive Rate\")\n",
"ax.set_title(\"ROC curve (test set)\")\n",
"ax.legend(loc=\"lower right\")\n",
"ax.grid()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 225,
"id": "9d0f7bd1",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Calibration (reliability) curve of the positive-class probability on the test set\n",
"from sklearn.calibration import calibration_curve\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Uniform bins (the default strategy): bins with no samples are dropped, and with few\n",
"# positives the high-probability bins can be very noisy - consider strategy=\"quantile\".\n",
"prob_true, prob_pred = calibration_curve(df_test[\"target\"], y_pred_proba[:, 1], n_bins=20)\n",
"\n",
"fig, ax = plt.subplots()\n",
"ax.plot(prob_pred, prob_true, marker='o', label=\"Model\")\n",
"ax.plot([0, 1], [0, 1], linestyle='--', color='gray', label=\"Perfectly calibrated\")\n",
"ax.set_xlabel(\"Mean predicted probability\")\n",
"ax.set_ylabel(\"Fraction of positives\")\n",
"ax.set_title(\"Calibration curve (test set)\")\n",
"ax.legend()\n",
"ax.grid()\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "cat",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.14"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment