{
"cells": [
{
"cell_type": "markdown",
"id": "720f2b62",
"metadata": {},
"source": [
"# Stacking"
]
},
{
"cell_type": "markdown",
"id": "0b365a02",
"metadata": {},
"source": [
"## 先说结论该数据集fetch_covtypeStacking的方法相比Blending和线性加权更好\n",
"比赛中我们常用线性加权作为最终的融合方式,我们同样也会好奇怎样的线性加权权重更好,下面也会举例子\n",
"参考https://github.com/rushter/heamy/tree/master/examples"
]
},
{
"cell_type": "markdown",
"id": "cc8fecb1",
"metadata": {},
"source": [
"通过对训练集进行五折验证,将验证结果作为第二层的训练和测试集合\n",
"<img src=\"assets/stacking.jpg\" width=\"50%\">"
]
},
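{
"cell_type": "markdown",
"id": "a3f9c210",
"metadata": {},
"source": [
"To make the diagram concrete, below is a minimal sketch of the same idea for a single base model using plain scikit-learn rather than heamy, assuming `X_train`, `y_train`, `X_test` as defined later in this notebook. It only illustrates the out-of-fold scheme, not heamy's exact implementation.\n",
"\n",
"```python\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.model_selection import cross_val_predict\n",
"\n",
"rf = RandomForestClassifier(n_estimators=100, n_jobs=-1)\n",
"\n",
"# Out-of-fold probabilities on the training set -> second-layer training features\n",
"oof_train = cross_val_predict(rf, X_train, y_train, cv=5, method='predict_proba')\n",
"\n",
"# Probabilities on the test set (model refit on the full training set) -> second-layer test features\n",
"oof_test = rf.fit(X_train, y_train).predict_proba(X_test)\n",
"```"
]
},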
{
"cell_type": "code",
"execution_count": 1,
"id": "18a12000",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\n",
"Collecting heamy\n",
" Downloading https://pypi.tuna.tsinghua.edu.cn/packages/20/32/2f3e1efa38a8e34f790d90b6d49ef06ab812181ae896c50e89b8750fa5a0/heamy-0.0.7.tar.gz (30 kB)\n",
"Requirement already satisfied: scikit-learn>=0.17.0 in d:\\programdata\\anaconda3\\lib\\site-packages (from heamy) (0.24.1)\n",
"Requirement already satisfied: pandas>=0.17.0 in d:\\programdata\\anaconda3\\lib\\site-packages (from heamy) (1.2.4)\n",
"Requirement already satisfied: six>=1.10.0 in d:\\programdata\\anaconda3\\lib\\site-packages (from heamy) (1.15.0)\n",
"Requirement already satisfied: scipy>=0.16.0 in d:\\programdata\\anaconda3\\lib\\site-packages (from heamy) (1.6.2)\n",
"Requirement already satisfied: numpy>=1.7.0 in d:\\programdata\\anaconda3\\lib\\site-packages (from heamy) (1.19.5)\n",
"Requirement already satisfied: pytz>=2017.3 in d:\\programdata\\anaconda3\\lib\\site-packages (from pandas>=0.17.0->heamy) (2021.1)\n",
"Requirement already satisfied: python-dateutil>=2.7.3 in d:\\programdata\\anaconda3\\lib\\site-packages (from pandas>=0.17.0->heamy) (2.8.1)\n",
"Requirement already satisfied: threadpoolctl>=2.0.0 in d:\\programdata\\anaconda3\\lib\\site-packages (from scikit-learn>=0.17.0->heamy) (2.1.0)\n",
"Requirement already satisfied: joblib>=0.11 in d:\\programdata\\anaconda3\\lib\\site-packages (from scikit-learn>=0.17.0->heamy) (1.0.1)\n",
"Building wheels for collected packages: heamy\n",
" Building wheel for heamy (setup.py): started\n",
" Building wheel for heamy (setup.py): finished with status 'done'\n",
" Created wheel for heamy: filename=heamy-0.0.7-py2.py3-none-any.whl size=15353 sha256=e3ba65b34e2bdee3b90b45b637e28836afdbdb0c9547f76b36fe10d17f8aba8f\n",
" Stored in directory: c:\\users\\administrator\\appdata\\local\\pip\\cache\\wheels\\6e\\f1\\7d\\048e558da94f495a0ed0d9c09d312e73eb176a092e36774ec2\n",
"Successfully built heamy\n",
"Installing collected packages: heamy\n",
"Successfully installed heamy-0.0.7\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"pip install heamy # 安装相关包"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "69632c6a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3.8.8 (default, Apr 13 2021, 15:08:03) [MSC v.1916 64 bit (AMD64)]\n"
]
}
],
"source": [
"import sys\n",
"print(sys.version) # 版本信息"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "ca421279",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import time\n",
"\n",
"from heamy.dataset import Dataset\n",
"from heamy.estimator import Classifier \n",
"from heamy.pipeline import ModelsPipeline\n",
"# 导入相关模型没有的pip install xxx 即可\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.linear_model import LogisticRegression\n",
"import xgboost as xgb \n",
"import lightgbm as lgb \n",
"\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import accuracy_score\n",
"from sklearn.preprocessing import OrdinalEncoder\n",
"from sklearn.metrics import log_loss"
]
},
{
"cell_type": "markdown",
"id": "2592fbbd",
"metadata": {},
"source": [
"## 准备数据集"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "9a0fabe1",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import fetch_covtype\n",
"data = fetch_covtype()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "5bd75178",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"七分类任务,处理前: [1 2 3 4 5 6 7]\n",
"[5 5 2 ... 3 3 3]\n",
"七分类任务,处理后: [0. 1. 2. 3. 4. 5. 6.]\n",
"[4. 4. 1. ... 2. 2. 2.]\n"
]
}
],
"source": [
"# 预处理\n",
"X, y = data['data'], data['target']\n",
"# 由于模型标签需要从0开始所以数字需要全部减1\n",
"print('七分类任务,处理前:',np.unique(y))\n",
"print(y)\n",
"ord = OrdinalEncoder()\n",
"y = ord.fit_transform(y.reshape(-1, 1))\n",
"y = y.reshape(-1, )\n",
"print('七分类任务,处理后:',np.unique(y))\n",
"print(y)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "23d9778c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(435759, 54)\n",
"(145253, 54)\n"
]
}
],
"source": [
"# 切分训练和测试集\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y,random_state=42)\n",
"print(X_train.shape)\n",
"print(X_test.shape)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "eac48668",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dataset(5c3ccfb5c81451d098565ef5e7e36ac5)\n"
]
}
],
"source": [
"# 创建数据集\n",
"'''use_cache : bool, default True\n",
" If use_cache=True then preprocessing step will be cached until function codeis changed.'''\n",
"dataset = Dataset(X_train=X_train, y_train=y_train, X_test=X_test, y_test=None,use_cache=True) # 注意这里的y_test=None即不存在数据泄露\n",
"print(dataset)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "fba3f975",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[2.833e+03, 2.580e+02, 2.600e+01, ..., 0.000e+00, 0.000e+00,\n",
" 0.000e+00],\n",
" [3.008e+03, 4.500e+01, 2.000e+00, ..., 0.000e+00, 0.000e+00,\n",
" 0.000e+00],\n",
" [2.949e+03, 0.000e+00, 1.100e+01, ..., 0.000e+00, 0.000e+00,\n",
" 0.000e+00],\n",
" ...,\n",
" [3.153e+03, 2.870e+02, 1.700e+01, ..., 0.000e+00, 0.000e+00,\n",
" 0.000e+00],\n",
" [3.065e+03, 3.480e+02, 2.100e+01, ..., 0.000e+00, 0.000e+00,\n",
" 0.000e+00],\n",
" [3.021e+03, 2.600e+01, 1.600e+01, ..., 0.000e+00, 0.000e+00,\n",
" 0.000e+00]])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 处理后的数据集\n",
"dataset.X_train"
]
},
{
"cell_type": "markdown",
"id": "d4517ea1",
"metadata": {},
"source": [
"## 定义模型"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "e8393e73",
"metadata": {},
"outputs": [],
"source": [
"def xgb_model(X_train, y_train, X_test, y_test):\n",
" \"\"\"参数必须为X_train,y_train,X_test,y_test\"\"\"\n",
" # 可以内置参数\n",
" params = {'objective': 'multi:softprob',\n",
" \"eval_metric\": 'mlogloss',\n",
" \"verbosity\": 0,\n",
" 'num_class': 7,\n",
" 'nthread': -1}\n",
" dtrain = xgb.DMatrix(X_train, y_train)\n",
" dtest = xgb.DMatrix(X_test)\n",
" model = xgb.train(params, dtrain, num_boost_round=300)\n",
" predict = model.predict(dtest)\n",
" return predict # 返回值必须为X_test的预测\n",
"\n",
"\n",
"def lgb_model(X_train, y_train, X_test, y_test,**parameters):\n",
" # 也可以开放参数接口\n",
" if parameters is None:\n",
" parameters = {}\n",
" lgb_train = lgb.Dataset(X_train, y_train)\n",
" model = lgb.train(params=parameters, train_set=lgb_train,num_boost_round=300)\n",
" predict = model.predict(X_test)\n",
" return predict\n",
"\n",
"\n",
"def rf_model(X_train, y_train, X_test, y_test):\n",
" params = {\"n_estimators\": 100, \"n_jobs\": -1}\n",
" model = RandomForestClassifier(**params).fit(X_train, y_train)\n",
" predict = model.predict_proba(X_test)\n",
" return predict"
]
},
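{
"cell_type": "markdown",
"id": "b7e45d19",
"metadata": {},
"source": [
"Any additional first-layer model can be plugged in by following the same contract (arguments `X_train, y_train, X_test, y_test`; return the predictions for `X_test`). A hypothetical example, not used in the rest of this notebook:\n",
"\n",
"```python\n",
"from sklearn.ensemble import ExtraTreesClassifier\n",
"\n",
"def et_model(X_train, y_train, X_test, y_test):\n",
"    \"\"\"Same interface as the models above.\"\"\"\n",
"    model = ExtraTreesClassifier(n_estimators=100, n_jobs=-1).fit(X_train, y_train)\n",
"    return model.predict_proba(X_test)\n",
"\n",
"# It would then be wrapped like the others:\n",
"# model_et = Classifier(dataset=dataset, estimator=et_model, name='et', use_cache=False)\n",
"```"
]
},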
{
"cell_type": "markdown",
"id": "0715cf6e",
"metadata": {},
"source": [
"## 构建和训练模型"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "78ab0083",
"metadata": {},
"outputs": [],
"source": [
"params = {\"objective\": \"multiclass\",\n",
" \"num_class\": 7,\n",
" \"n_jobs\": -1,\n",
" \"verbose\": -4, \n",
" \"metric\": (\"multi_logloss\",)}\n",
"\n",
"model_xgb = Classifier(dataset=dataset, estimator=xgb_model, name='xgb',use_cache=False)\n",
"model_lgb = Classifier(dataset=dataset, estimator=lgb_model, name='lgb',parameters=params,use_cache=False)\n",
"model_rf = Classifier(dataset=dataset, estimator=rf_model,name='rf',use_cache=False)\n",
"\n",
"pipeline = ModelsPipeline(model_xgb, model_lgb, model_rf)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "173ef0f0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best Score (log_loss): 0.18646435443714865\n",
"Best Weights: [2.53464919e-01 1.48562205e-20 7.46535081e-01]\n",
"Wall time: 14min 19s\n"
]
},
{
"data": {
"text/plain": [
"array([2.53464919e-01, 1.48562205e-20, 7.46535081e-01])"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"pipeline.find_weights(scorer=log_loss) # 输出最优权重组合"
]
},
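{
"cell_type": "markdown",
"id": "c2d81f0a",
"metadata": {},
"source": [
"For intuition, the weight search amounts to minimizing the scorer over convex combinations of the models' predicted probabilities. Below is a minimal sketch with `scipy.optimize.minimize`; it illustrates the idea only and is not necessarily heamy's exact implementation. `preds` (a list of per-model probability arrays on some validation data) and `y_true` (the matching labels) are assumed inputs.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from scipy.optimize import minimize\n",
"from sklearn.metrics import log_loss\n",
"\n",
"def search_weights(preds, y_true):\n",
"    \"\"\"Find non-negative weights summing to 1 that minimize log_loss of the blend.\"\"\"\n",
"    n = len(preds)\n",
"\n",
"    def objective(w):\n",
"        blended = sum(wi * p for wi, p in zip(w, preds))\n",
"        return log_loss(y_true, blended)\n",
"\n",
"    res = minimize(objective, np.ones(n) / n,\n",
"                   bounds=[(0.0, 1.0)] * n,\n",
"                   constraints=({'type': 'eq', 'fun': lambda w: 1.0 - np.sum(w)},),\n",
"                   method='SLSQP')\n",
"    return res.x\n",
"```"
]
},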
{
"cell_type": "code",
"execution_count": 22,
"id": "80726d19",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Wall time: 1h 39min 20s\n"
]
}
],
"source": [
"%%time\n",
"# 5折训练构建5折模型特征集这里比较耗时\n",
"stack_ds = pipeline.stack(k=5,seed=42)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "b25bba3c",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>xgb_0</th>\n",
" <th>xgb_1</th>\n",
" <th>xgb_2</th>\n",
" <th>xgb_3</th>\n",
" <th>xgb_4</th>\n",
" <th>xgb_5</th>\n",
" <th>xgb_6</th>\n",
" <th>lgb_0</th>\n",
" <th>lgb_1</th>\n",
" <th>lgb_2</th>\n",
" <th>...</th>\n",
" <th>lgb_4</th>\n",
" <th>lgb_5</th>\n",
" <th>lgb_6</th>\n",
" <th>rf_0</th>\n",
" <th>rf_1</th>\n",
" <th>rf_2</th>\n",
" <th>rf_3</th>\n",
" <th>rf_4</th>\n",
" <th>rf_5</th>\n",
" <th>rf_6</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.177179</td>\n",
" <td>0.818728</td>\n",
" <td>2.185222e-07</td>\n",
" <td>9.264143e-09</td>\n",
" <td>4.090067e-03</td>\n",
" <td>1.725062e-06</td>\n",
" <td>1.048052e-06</td>\n",
" <td>0.179625</td>\n",
" <td>0.804684</td>\n",
" <td>0.000003</td>\n",
" <td>...</td>\n",
" <td>1.562435e-02</td>\n",
" <td>6.370849e-05</td>\n",
" <td>1.004259e-08</td>\n",
" <td>0.03</td>\n",
" <td>0.96</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.01</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0.005155</td>\n",
" <td>0.994845</td>\n",
" <td>7.055579e-10</td>\n",
" <td>1.326343e-08</td>\n",
" <td>6.331572e-09</td>\n",
" <td>1.435787e-09</td>\n",
" <td>1.603579e-10</td>\n",
" <td>0.008114</td>\n",
" <td>0.991886</td>\n",
" <td>0.000000</td>\n",
" <td>...</td>\n",
" <td>0.000000e+00</td>\n",
" <td>0.000000e+00</td>\n",
" <td>0.000000e+00</td>\n",
" <td>0.13</td>\n",
" <td>0.87</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.00</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0.293492</td>\n",
" <td>0.706508</td>\n",
" <td>3.650662e-10</td>\n",
" <td>1.017633e-09</td>\n",
" <td>8.823530e-09</td>\n",
" <td>6.384080e-10</td>\n",
" <td>2.823794e-08</td>\n",
" <td>0.831445</td>\n",
" <td>0.168555</td>\n",
" <td>0.000000</td>\n",
" <td>...</td>\n",
" <td>4.999034e-07</td>\n",
" <td>4.015190e-09</td>\n",
" <td>4.997854e-09</td>\n",
" <td>0.63</td>\n",
" <td>0.37</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.00</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0.478112</td>\n",
" <td>0.521816</td>\n",
" <td>3.207779e-06</td>\n",
" <td>2.878019e-08</td>\n",
" <td>1.076500e-08</td>\n",
" <td>2.230641e-06</td>\n",
" <td>6.630235e-05</td>\n",
" <td>0.465733</td>\n",
" <td>0.534184</td>\n",
" <td>0.000000</td>\n",
" <td>...</td>\n",
" <td>0.000000e+00</td>\n",
" <td>0.000000e+00</td>\n",
" <td>8.245405e-05</td>\n",
" <td>0.55</td>\n",
" <td>0.45</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.00</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0.992430</td>\n",
" <td>0.006652</td>\n",
" <td>1.233117e-05</td>\n",
" <td>1.887496e-07</td>\n",
" <td>1.569583e-06</td>\n",
" <td>5.604260e-07</td>\n",
" <td>9.037877e-04</td>\n",
" <td>0.932050</td>\n",
" <td>0.043451</td>\n",
" <td>0.000000</td>\n",
" <td>...</td>\n",
" <td>0.000000e+00</td>\n",
" <td>0.000000e+00</td>\n",
" <td>2.449972e-02</td>\n",
" <td>0.97</td>\n",
" <td>0.03</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.00</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 21 columns</p>\n",
"</div>"
],
"text/plain": [
" xgb_0 xgb_1 xgb_2 xgb_3 xgb_4 xgb_5 \\\n",
"0 0.177179 0.818728 2.185222e-07 9.264143e-09 4.090067e-03 1.725062e-06 \n",
"1 0.005155 0.994845 7.055579e-10 1.326343e-08 6.331572e-09 1.435787e-09 \n",
"2 0.293492 0.706508 3.650662e-10 1.017633e-09 8.823530e-09 6.384080e-10 \n",
"3 0.478112 0.521816 3.207779e-06 2.878019e-08 1.076500e-08 2.230641e-06 \n",
"4 0.992430 0.006652 1.233117e-05 1.887496e-07 1.569583e-06 5.604260e-07 \n",
"\n",
" xgb_6 lgb_0 lgb_1 lgb_2 ... lgb_4 \\\n",
"0 1.048052e-06 0.179625 0.804684 0.000003 ... 1.562435e-02 \n",
"1 1.603579e-10 0.008114 0.991886 0.000000 ... 0.000000e+00 \n",
"2 2.823794e-08 0.831445 0.168555 0.000000 ... 4.999034e-07 \n",
"3 6.630235e-05 0.465733 0.534184 0.000000 ... 0.000000e+00 \n",
"4 9.037877e-04 0.932050 0.043451 0.000000 ... 0.000000e+00 \n",
"\n",
" lgb_5 lgb_6 rf_0 rf_1 rf_2 rf_3 rf_4 rf_5 rf_6 \n",
"0 6.370849e-05 1.004259e-08 0.03 0.96 0.0 0.0 0.01 0.0 0.0 \n",
"1 0.000000e+00 0.000000e+00 0.13 0.87 0.0 0.0 0.00 0.0 0.0 \n",
"2 4.015190e-09 4.997854e-09 0.63 0.37 0.0 0.0 0.00 0.0 0.0 \n",
"3 0.000000e+00 8.245405e-05 0.55 0.45 0.0 0.0 0.00 0.0 0.0 \n",
"4 0.000000e+00 2.449972e-02 0.97 0.03 0.0 0.0 0.00 0.0 0.0 \n",
"\n",
"[5 rows x 21 columns]"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 模型输出的训练集7个特征对应7个标签的预测概率\n",
"stack_ds.X_train.head()"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "835205e9",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>xgb_0</th>\n",
" <th>xgb_1</th>\n",
" <th>xgb_2</th>\n",
" <th>xgb_3</th>\n",
" <th>xgb_4</th>\n",
" <th>xgb_5</th>\n",
" <th>xgb_6</th>\n",
" <th>lgb_0</th>\n",
" <th>lgb_1</th>\n",
" <th>lgb_2</th>\n",
" <th>...</th>\n",
" <th>lgb_4</th>\n",
" <th>lgb_5</th>\n",
" <th>lgb_6</th>\n",
" <th>rf_0</th>\n",
" <th>rf_1</th>\n",
" <th>rf_2</th>\n",
" <th>rf_3</th>\n",
" <th>rf_4</th>\n",
" <th>rf_5</th>\n",
" <th>rf_6</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.991493</td>\n",
" <td>0.000210</td>\n",
" <td>4.796058e-06</td>\n",
" <td>6.178684e-08</td>\n",
" <td>6.947614e-07</td>\n",
" <td>2.410490e-09</td>\n",
" <td>8.291459e-03</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>...</td>\n",
" <td>0.000000e+00</td>\n",
" <td>0.000000</td>\n",
" <td>1.000000e+00</td>\n",
" <td>0.99</td>\n",
" <td>0.00</td>\n",
" <td>0.00</td>\n",
" <td>0.0</td>\n",
" <td>0.00</td>\n",
" <td>0.00</td>\n",
" <td>0.01</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0.024731</td>\n",
" <td>0.964372</td>\n",
" <td>6.387765e-04</td>\n",
" <td>4.205048e-08</td>\n",
" <td>1.006575e-02</td>\n",
" <td>1.879628e-04</td>\n",
" <td>4.830114e-06</td>\n",
" <td>0.065073</td>\n",
" <td>0.877354</td>\n",
" <td>0.001678</td>\n",
" <td>...</td>\n",
" <td>5.530599e-02</td>\n",
" <td>0.000589</td>\n",
" <td>8.601484e-11</td>\n",
" <td>0.09</td>\n",
" <td>0.80</td>\n",
" <td>0.05</td>\n",
" <td>0.0</td>\n",
" <td>0.06</td>\n",
" <td>0.00</td>\n",
" <td>0.00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0.000780</td>\n",
" <td>0.979776</td>\n",
" <td>8.593459e-04</td>\n",
" <td>1.267791e-07</td>\n",
" <td>1.710379e-02</td>\n",
" <td>1.477527e-03</td>\n",
" <td>2.521008e-06</td>\n",
" <td>0.005164</td>\n",
" <td>0.933849</td>\n",
" <td>0.016355</td>\n",
" <td>...</td>\n",
" <td>3.657553e-02</td>\n",
" <td>0.008057</td>\n",
" <td>0.000000e+00</td>\n",
" <td>0.01</td>\n",
" <td>0.97</td>\n",
" <td>0.00</td>\n",
" <td>0.0</td>\n",
" <td>0.01</td>\n",
" <td>0.01</td>\n",
" <td>0.00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0.042695</td>\n",
" <td>0.957304</td>\n",
" <td>2.283268e-08</td>\n",
" <td>4.387427e-08</td>\n",
" <td>4.175481e-07</td>\n",
" <td>4.406019e-08</td>\n",
" <td>6.909629e-10</td>\n",
" <td>0.054392</td>\n",
" <td>0.945608</td>\n",
" <td>0.000000</td>\n",
" <td>...</td>\n",
" <td>4.285638e-08</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000e+00</td>\n",
" <td>0.04</td>\n",
" <td>0.96</td>\n",
" <td>0.00</td>\n",
" <td>0.0</td>\n",
" <td>0.00</td>\n",
" <td>0.00</td>\n",
" <td>0.00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0.000457</td>\n",
" <td>0.999334</td>\n",
" <td>3.366338e-06</td>\n",
" <td>4.893879e-08</td>\n",
" <td>2.045808e-04</td>\n",
" <td>7.889498e-07</td>\n",
" <td>1.415576e-09</td>\n",
" <td>0.001367</td>\n",
" <td>0.995857</td>\n",
" <td>0.000000</td>\n",
" <td>...</td>\n",
" <td>2.776106e-03</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000e+00</td>\n",
" <td>0.00</td>\n",
" <td>0.99</td>\n",
" <td>0.00</td>\n",
" <td>0.0</td>\n",
" <td>0.01</td>\n",
" <td>0.00</td>\n",
" <td>0.00</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 21 columns</p>\n",
"</div>"
],
"text/plain": [
" xgb_0 xgb_1 xgb_2 xgb_3 xgb_4 xgb_5 \\\n",
"0 0.991493 0.000210 4.796058e-06 6.178684e-08 6.947614e-07 2.410490e-09 \n",
"1 0.024731 0.964372 6.387765e-04 4.205048e-08 1.006575e-02 1.879628e-04 \n",
"2 0.000780 0.979776 8.593459e-04 1.267791e-07 1.710379e-02 1.477527e-03 \n",
"3 0.042695 0.957304 2.283268e-08 4.387427e-08 4.175481e-07 4.406019e-08 \n",
"4 0.000457 0.999334 3.366338e-06 4.893879e-08 2.045808e-04 7.889498e-07 \n",
"\n",
" xgb_6 lgb_0 lgb_1 lgb_2 ... lgb_4 lgb_5 \\\n",
"0 8.291459e-03 0.000000 0.000000 0.000000 ... 0.000000e+00 0.000000 \n",
"1 4.830114e-06 0.065073 0.877354 0.001678 ... 5.530599e-02 0.000589 \n",
"2 2.521008e-06 0.005164 0.933849 0.016355 ... 3.657553e-02 0.008057 \n",
"3 6.909629e-10 0.054392 0.945608 0.000000 ... 4.285638e-08 0.000000 \n",
"4 1.415576e-09 0.001367 0.995857 0.000000 ... 2.776106e-03 0.000000 \n",
"\n",
" lgb_6 rf_0 rf_1 rf_2 rf_3 rf_4 rf_5 rf_6 \n",
"0 1.000000e+00 0.99 0.00 0.00 0.0 0.00 0.00 0.01 \n",
"1 8.601484e-11 0.09 0.80 0.05 0.0 0.06 0.00 0.00 \n",
"2 0.000000e+00 0.01 0.97 0.00 0.0 0.01 0.01 0.00 \n",
"3 0.000000e+00 0.04 0.96 0.00 0.0 0.00 0.00 0.00 \n",
"4 0.000000e+00 0.00 0.99 0.00 0.0 0.01 0.00 0.00 \n",
"\n",
"[5 rows x 21 columns]"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 模型输出的测试集7个特征对应7个标签的预测概率\n",
"stack_ds.X_test.head()"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "b9db35dc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Wall time: 1min 53s\n"
]
}
],
"source": [
"%%time\n",
"# 用lr做最后一层\n",
"stacker = Classifier(dataset=stack_ds, estimator=LogisticRegression, parameters={\"solver\": 'lbfgs', \"max_iter\": 1000},use_cache=False)\n",
"predict_stack = stacker.predict()"
]
},
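{
"cell_type": "markdown",
"id": "d91c4b3e",
"metadata": {},
"source": [
"The second layer is just an ordinary classifier trained on the stacked features. A rough plain-scikit-learn equivalent is sketched below, assuming `stack_ds.X_train` / `stack_ds.X_test` hold the stacked features and `stack_ds.y_train` the correspondingly ordered labels (heamy handles this bookkeeping internally):\n",
"\n",
"```python\n",
"from sklearn.linear_model import LogisticRegression\n",
"\n",
"# Fit the meta-model on the out-of-fold features, then predict on the stacked test features\n",
"meta = LogisticRegression(solver='lbfgs', max_iter=1000)\n",
"meta.fit(stack_ds.X_train, stack_ds.y_train)\n",
"manual_predict_stack = meta.predict_proba(stack_ds.X_test)\n",
"```"
]
},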
{
"cell_type": "code",
"execution_count": 32,
"id": "a4a48219",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[9.99137967e-01 4.76574513e-04 3.32764186e-07 ... 3.03237700e-06\n",
" 1.62161211e-06 3.79326620e-04]\n",
" [2.00732175e-02 9.71682830e-01 1.51484266e-03 ... 5.69138554e-03\n",
" 9.40725428e-04 9.60822625e-05]\n",
" [5.54556002e-03 9.91048437e-01 8.04840682e-04 ... 2.11437934e-03\n",
" 4.56463787e-04 3.00919502e-05]\n",
" ...\n",
" [4.60179790e-06 1.78298095e-03 9.91553958e-01 ... 7.26752933e-04\n",
" 3.79135124e-03 1.53584401e-05]\n",
" [9.96307096e-01 2.43558944e-03 1.01596361e-07 ... 3.94985596e-05\n",
" 7.41569805e-06 1.20819024e-03]\n",
" [5.34671504e-05 7.62534718e-04 5.58323657e-03 ... 2.11410908e-04\n",
" 9.91805379e-01 1.69502656e-05]]\n"
]
}
],
"source": [
"print(predict_stack) # stacking后的结果"
]
},
{
"cell_type": "markdown",
"id": "1372d4f8",
"metadata": {},
"source": [
"## 验证结果"
]
},
{
"cell_type": "markdown",
"id": "52ef71d4",
"metadata": {},
"source": [
"### 单模分数"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "a28806a0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.9254473229468583\n",
"0.8412562907478676\n",
"0.9535087055000585\n"
]
}
],
"source": [
"print(accuracy_score(np.argmax(stack_ds.X_test.iloc[:, :7].values, axis=1),y_test)) # XGB\n",
"print(accuracy_score(np.argmax(stack_ds.X_test.iloc[:, 7:14].values, axis=1),y_test)) # LGB\n",
"print(accuracy_score(np.argmax(stack_ds.X_test.iloc[:, 14:].values, axis=1),y_test)) # RF"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "1db92fec",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.9537840870756542\n",
"Wall time: 50.8 s\n"
]
}
],
"source": [
"%%time\n",
"# 测试单模运行结果是否一致\n",
"rf_predict = rf_model(X_train, y_train, X_test, None)\n",
"print(accuracy_score(np.argmax(rf_predict, axis=1),y_test)) # RF"
]
},
{
"cell_type": "markdown",
"id": "2e9423ce",
"metadata": {},
"source": [
"### 线性加权分数"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "d2b50ba4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"主观根据结果blending 0.9504106627746071\n",
"根据最优权重的线性加权: 0.9530749795184953\n"
]
}
],
"source": [
"# blending的分数\n",
"xgb_t = stack_ds.X_test.iloc[:, :7].values\n",
"lgb_t = stack_ds.X_test.iloc[:, 7:14].values\n",
"rf_t = stack_ds.X_test.iloc[:, 14:].values\n",
"\n",
"# 根据分数好坏随机定\n",
"result = 0.2*xgb_t+0.1*lgb_t+0.7*rf_t\n",
"print('主观根据结果blending', accuracy_score(np.argmax(result, axis=1), y_test))\n",
"# 根据上面提供的最优权重 Best Weights: [2.53464919e-01 1.48562205e-20 7.46535081e-01]\n",
"result = 2.53464919e-01*xgb_t+1.48562205e-20*lgb_t+7.46535081e-01*rf_t\n",
"print('根据最优权重的线性加权:',accuracy_score(np.argmax(result, axis=1), y_test))"
]
},
{
"cell_type": "markdown",
"id": "e8ce00b2",
"metadata": {},
"source": [
"可以观察到最优权重比我们主观选权重更优,单对比单模结果反而下降了"
]
},
{
"cell_type": "markdown",
"id": "d4a12d6d",
"metadata": {},
"source": [
"### Blending的分数"
]
},
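{
"cell_type": "markdown",
"id": "e5a7f622",
"metadata": {},
"source": [
"Blending differs from stacking in that each base model is trained on only part of the training data, and its predictions on the held-out remainder become the second-layer training features. A minimal single-model sketch of that scheme with plain scikit-learn (an illustration of the general idea, not heamy's exact `blend` implementation; the holdout proportion is an assumed value):\n",
"\n",
"```python\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Hold out part of the training data for the second layer\n",
"X_base, X_hold, y_base, y_hold = train_test_split(X_train, y_train, test_size=0.2, random_state=111)\n",
"\n",
"rf = RandomForestClassifier(n_estimators=100, n_jobs=-1).fit(X_base, y_base)\n",
"blend_train = rf.predict_proba(X_hold)  # second-layer training features (labels: y_hold)\n",
"blend_test = rf.predict_proba(X_test)   # second-layer test features\n",
"```"
]
},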
{
"cell_type": "code",
"execution_count": 17,
"id": "ffe1cf4d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Wall time: 14min 10s\n"
]
}
],
"source": [
"%%time\n",
"blend_ds = pipeline.blend(seed=111)\n",
"blender = Classifier(dataset=blend_ds, estimator=LogisticRegression, parameters={\"solver\": 'lbfgs', \"max_iter\": 1000},use_cache=False)\n",
"predict_blend = blender.predict()"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "506adffb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.9546859617357301\n"
]
}
],
"source": [
"print(accuracy_score(np.argmax(predict_blend, axis=1), y_test))"
]
},
{
"cell_type": "markdown",
"id": "e84c19ae",
"metadata": {},
"source": [
"使用Blending的分数有所提升"
]
},
{
"cell_type": "markdown",
"id": "e8daf1e3",
"metadata": {},
"source": [
"### Stacking的分数"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "4930b407",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.9589887988544127\n"
]
}
],
"source": [
"print(accuracy_score(np.argmax(predict_stack, axis=1), y_test))"
]
},
{
"cell_type": "markdown",
"id": "d5ad3c4b",
"metadata": {},
"source": [
"可以明显看到提升的效果"
]
},
{
"cell_type": "markdown",
"id": "6311fbd7",
"metadata": {},
"source": [
"## 再说结论该数据集fetch_covtypeStacking的方法相比Blending和线性加权更好"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}