{
"cells": [
{
"cell_type": "markdown",
"id": "720f2b62",
"metadata": {},
"source": [
"# Stacking"
]
},
{
"cell_type": "markdown",
"id": "0b365a02",
"metadata": {},
"source": [
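"## Conclusion first: on the fetch_covtype dataset, Stacking beats both Blending and linear weighting\n",
"In competitions, a linear weighting of the models' outputs is a common final ensembling step, which raises the question of which weights work best; an example of that is included below as well.\n",
"Reference: https://github.com/rushter/heamy/tree/master/examples"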
]
},
{
"cell_type": "markdown",
"id": "cc8fecb1",
"metadata": {},
"source": [
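"Stacking with 5-fold cross-validation on the training set: each base model's out-of-fold predictions form the second-level training set, and its predictions on the test set form the second-level test set.\n",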
"<img src=\"assets/stacking.jpg\" width=\"50%\">"
]
},
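{
"cell_type": "markdown",
"id": "7c1e5a90",
"metadata": {},
"source": [
"The procedure in the figure can be sketched with scikit-learn alone. This is a minimal illustration, not heamy's implementation. The synthetic data from `make_classification`, the two base models, and the averaging of the five fold models' test predictions are all assumptions made for the example; heamy's `stack` may build the test features differently.\n",
"\n",
"```python\n",
"# Minimal 5-fold stacking sketch (illustrative; not heamy's internals)\n",
"import numpy as np\n",
"from sklearn.datasets import make_classification\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.model_selection import KFold, train_test_split\n",
"\n",
"# Assumed toy data: 600 samples, 3 classes\n",
"X, y = make_classification(n_samples=600, n_features=10, n_informative=5,\n",
"                           n_classes=3, random_state=0)\n",
"X_tr, X_te, y_tr, y_te = train_test_split(X, y, random_state=0)\n",
"\n",
"base_models = [RandomForestClassifier(n_estimators=50, random_state=0),\n",
"               LogisticRegression(max_iter=1000)]\n",
"n_classes = 3\n",
"kf = KFold(n_splits=5, shuffle=True, random_state=42)\n",
"\n",
"# One block of n_classes probability columns per base model\n",
"oof_train = np.zeros((len(X_tr), n_classes * len(base_models)))\n",
"meta_test = np.zeros((len(X_te), n_classes * len(base_models)))\n",
"\n",
"for m, model in enumerate(base_models):\n",
"    cols = slice(m * n_classes, (m + 1) * n_classes)\n",
"    for fit_idx, val_idx in kf.split(X_tr):\n",
"        model.fit(X_tr[fit_idx], y_tr[fit_idx])\n",
"        # Out-of-fold: each training row is predicted by a model that never saw it\n",
"        oof_train[val_idx, cols] = model.predict_proba(X_tr[val_idx])\n",
"        # Test features: average of the 5 fold models' predictions\n",
"        meta_test[:, cols] += model.predict_proba(X_te) / kf.get_n_splits()\n",
"\n",
"# Second level: logistic regression on the stacked probabilities\n",
"stacker = LogisticRegression(max_iter=1000).fit(oof_train, y_tr)\n",
"stack_acc = (stacker.predict(meta_test) == y_te).mean()\n",
"print(stack_acc)\n",
"```\n",
"\n",
"A model that never saw a given training row produces that row's second-level features. As a result, the second-level model learns from honest predictions rather than from in-sample, overfit ones."
]
},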
{
"cell_type": "code",
"execution_count": 1,
"id": "18a12000",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\n",
"Collecting heamy\n",
" Downloading https://pypi.tuna.tsinghua.edu.cn/packages/20/32/2f3e1efa38a8e34f790d90b6d49ef06ab812181ae896c50e89b8750fa5a0/heamy-0.0.7.tar.gz (30 kB)\n",
"Requirement already satisfied: scikit-learn>=0.17.0 in d:\\programdata\\anaconda3\\lib\\site-packages (from heamy) (0.24.1)\n",
"Requirement already satisfied: pandas>=0.17.0 in d:\\programdata\\anaconda3\\lib\\site-packages (from heamy) (1.2.4)\n",
"Requirement already satisfied: six>=1.10.0 in d:\\programdata\\anaconda3\\lib\\site-packages (from heamy) (1.15.0)\n",
"Requirement already satisfied: scipy>=0.16.0 in d:\\programdata\\anaconda3\\lib\\site-packages (from heamy) (1.6.2)\n",
"Requirement already satisfied: numpy>=1.7.0 in d:\\programdata\\anaconda3\\lib\\site-packages (from heamy) (1.19.5)\n",
"Requirement already satisfied: pytz>=2017.3 in d:\\programdata\\anaconda3\\lib\\site-packages (from pandas>=0.17.0->heamy) (2021.1)\n",
"Requirement already satisfied: python-dateutil>=2.7.3 in d:\\programdata\\anaconda3\\lib\\site-packages (from pandas>=0.17.0->heamy) (2.8.1)\n",
"Requirement already satisfied: threadpoolctl>=2.0.0 in d:\\programdata\\anaconda3\\lib\\site-packages (from scikit-learn>=0.17.0->heamy) (2.1.0)\n",
"Requirement already satisfied: joblib>=0.11 in d:\\programdata\\anaconda3\\lib\\site-packages (from scikit-learn>=0.17.0->heamy) (1.0.1)\n",
"Building wheels for collected packages: heamy\n",
" Building wheel for heamy (setup.py): started\n",
" Building wheel for heamy (setup.py): finished with status 'done'\n",
" Created wheel for heamy: filename=heamy-0.0.7-py2.py3-none-any.whl size=15353 sha256=e3ba65b34e2bdee3b90b45b637e28836afdbdb0c9547f76b36fe10d17f8aba8f\n",
" Stored in directory: c:\\users\\administrator\\appdata\\local\\pip\\cache\\wheels\\6e\\f1\\7d\\048e558da94f495a0ed0d9c09d312e73eb176a092e36774ec2\n",
"Successfully built heamy\n",
"Installing collected packages: heamy\n",
"Successfully installed heamy-0.0.7\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"# Install the required package\n",
"%pip install heamy"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "69632c6a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3.8.8 (default, Apr 13 2021, 15:08:03) [MSC v.1916 64 bit (AMD64)]\n"
]
}
],
"source": [
"import sys\n",
"print(sys.version)  # Python version info"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "ca421279",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import time\n",
"\n",
"from heamy.dataset import Dataset\n",
"from heamy.estimator import Classifier\n",
"from heamy.pipeline import ModelsPipeline\n",
"# Import the models; install any that are missing with pip install <package>\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.linear_model import LogisticRegression\n",
"import xgboost as xgb\n",
"import lightgbm as lgb\n",
"\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import accuracy_score\n",
"from sklearn.preprocessing import OrdinalEncoder\n",
"from sklearn.metrics import log_loss"
]
},
{
"cell_type": "markdown",
"id": "2592fbbd",
"metadata": {},
"source": [
"## Prepare the dataset"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "9a0fabe1",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import fetch_covtype\n",
"data = fetch_covtype()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "5bd75178",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
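"7-class task, before encoding: [1 2 3 4 5 6 7]\n",
"[5 5 2 ... 3 3 3]\n",
"7-class task, after encoding: [0. 1. 2. 3. 4. 5. 6.]\n",
"[4. 4. 1. ... 2. 2. 2.]\n"
]
}
],
"source": [
"# Preprocessing\n",
"X, y = data['data'], data['target']\n",
"# The models expect class labels starting at 0, so map labels 1..7 to 0..6 (i.e. subtract 1)\n",
"print('7-class task, before encoding:', np.unique(y))\n",
"print(y)\n",
"encoder = OrdinalEncoder()  # named encoder to avoid shadowing the built-in ord()\n",
"y = encoder.fit_transform(y.reshape(-1, 1))\n",
"y = y.ravel()\n",
"print('7-class task, after encoding:', np.unique(y))\n",
"print(y)"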
]
},
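{
"cell_type": "markdown",
"id": "b4d2e817",
"metadata": {},
"source": [
"All seven labels 1 to 7 occur in `y`, so the `OrdinalEncoder` step above is equivalent to subtracting 1. The check below uses a made-up label array, which is an assumption for illustration only.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.preprocessing import OrdinalEncoder\n",
"\n",
"# Made-up labels that contain every class 1..7\n",
"labels = np.array([5, 5, 2, 1, 7, 3, 3, 4, 6])\n",
"encoded = OrdinalEncoder().fit_transform(labels.reshape(-1, 1)).ravel()\n",
"shifted = labels - 1\n",
"print(np.array_equal(encoded, shifted))  # True\n",
"```\n",
"\n",
"The two approaches diverge if a label is absent. In that case, OrdinalEncoder renumbers the remaining labels consecutively, whereas `labels - 1` leaves a gap."
]
},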
{
"cell_type": "code",
"execution_count": 7,
"id": "23d9778c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(435759, 54)\n",
"(145253, 54)\n"
]
}
],
"source": [
"# Split into training and test sets (default split: 75% / 25%)\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)\n",
"print(X_train.shape)\n",
"print(X_test.shape)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "eac48668",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dataset(5c3ccfb5c81451d098565ef5e7e36ac5)\n"
]
}
],
"source": [
"# Create the heamy Dataset.\n",
"# use_cache (bool, default True): if True, the preprocessing step is cached until the function code changes.\n",
"# y_test=None: the test labels are never passed to heamy, so they cannot leak into training.\n",
"dataset = Dataset(X_train=X_train, y_train=y_train, X_test=X_test, y_test=None, use_cache=True)\n",
"print(dataset)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "fba3f975",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[2.833e+03, 2.580e+02, 2.600e+01, ..., 0.000e+00, 0.000e+00,\n",
" 0.000e+00],\n",
" [3.008e+03, 4.500e+01, 2.000e+00, ..., 0.000e+00, 0.000e+00,\n",
" 0.000e+00],\n",
" [2.949e+03, 0.000e+00, 1.100e+01, ..., 0.000e+00, 0.000e+00,\n",
" 0.000e+00],\n",
" ...,\n",
" [3.153e+03, 2.870e+02, 1.700e+01, ..., 0.000e+00, 0.000e+00,\n",
" 0.000e+00],\n",
" [3.065e+03, 3.480e+02, 2.100e+01, ..., 0.000e+00, 0.000e+00,\n",
" 0.000e+00],\n",
" [3.021e+03, 2.600e+01, 1.600e+01, ..., 0.000e+00, 0.000e+00,\n",
" 0.000e+00]])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# The dataset after processing\n",
"dataset.X_train"
]
},
{
"cell_type": "markdown",
"id": "d4517ea1",
"metadata": {},
"source": [
"## Define the models"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "e8393e73",
"metadata": {},
"outputs": [],
"source": [
"def xgb_model(X_train, y_train, X_test, y_test):\n",
"    \"\"\"The parameters must be named X_train, y_train, X_test, y_test.\"\"\"\n",
"    # Parameters can be hard-coded inside the function\n",
"    params = {'objective': 'multi:softprob',\n",
"              'eval_metric': 'mlogloss',\n",
"              'verbosity': 0,\n",
"              'num_class': 7,\n",
"              'nthread': -1}\n",
"    dtrain = xgb.DMatrix(X_train, y_train)\n",
"    dtest = xgb.DMatrix(X_test)\n",
"    model = xgb.train(params, dtrain, num_boost_round=300)\n",
"    predict = model.predict(dtest)\n",
"    return predict  # must return the predictions for X_test (here: class probabilities)\n",
"\n",
"\n",
"def lgb_model(X_train, y_train, X_test, y_test, **parameters):\n",
"    # Parameters can also be supplied from outside via Classifier(parameters=...);\n",
"    # they arrive here as keyword arguments (an empty dict if none are given)\n",
"    lgb_train = lgb.Dataset(X_train, y_train)\n",
"    model = lgb.train(params=parameters, train_set=lgb_train, num_boost_round=300)\n",
"    predict = model.predict(X_test)\n",
"    return predict\n",
"\n",
"\n",
"def rf_model(X_train, y_train, X_test, y_test):\n",
"    params = {\"n_estimators\": 100, \"n_jobs\": -1}\n",
"    model = RandomForestClassifier(**params).fit(X_train, y_train)\n",
"    predict = model.predict_proba(X_test)\n",
"    return predict"
]
},
{
"cell_type": "markdown",
"id": "0715cf6e",
"metadata": {},
"source": [
"## Build and train the models"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "78ab0083",
"metadata": {},
"outputs": [],
"source": [
"params = {\"objective\": \"multiclass\",\n",
" \"num_class\": 7,\n",
" \"n_jobs\": -1,\n",
" \"verbose\": -4, \n",
" \"metric\": (\"multi_logloss\",)}\n",
"\n",
"model_xgb = Classifier(dataset=dataset, estimator=xgb_model, name='xgb',use_cache=False)\n",
"model_lgb = Classifier(dataset=dataset, estimator=lgb_model, name='lgb',parameters=params,use_cache=False)\n",
"model_rf = Classifier(dataset=dataset, estimator=rf_model,name='rf',use_cache=False)\n",
"\n",
"pipeline = ModelsPipeline(model_xgb, model_lgb, model_rf)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "173ef0f0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best Score (log_loss): 0.18646435443714865\n",
"Best Weights: [2.53464919e-01 1.48562205e-20 7.46535081e-01]\n",
"Wall time: 14min 19s\n"
]
},
{
"data": {
"text/plain": [
"array([2.53464919e-01, 1.48562205e-20, 7.46535081e-01])"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"pipeline.find_weights(scorer=log_loss)  # find the weighted-average weights that minimize log loss"
]
},
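{
"cell_type": "markdown",
"id": "e9f03c6a",
"metadata": {},
"source": [
"`find_weights` searches for the weights of a weighted average of the models' predictions that minimize the scorer. The idea can be sketched with SciPy's `minimize`. This is a hypothetical re-implementation for illustration, not heamy's code. The three synthetic prediction matrices, the SLSQP method, the [0, 1] bounds and the sum-to-one constraint are all assumptions of the sketch.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from scipy.optimize import minimize\n",
"\n",
"rng = np.random.default_rng(0)\n",
"n, k = 200, 3\n",
"y_true = rng.integers(0, k, size=n)\n",
"\n",
"def noisy_probs(strength):\n",
"    # Hypothetical model: softmax of random logits, boosted on the true class\n",
"    logits = rng.normal(size=(n, k))\n",
"    logits[np.arange(n), y_true] += strength\n",
"    e = np.exp(logits)\n",
"    return e / e.sum(axis=1, keepdims=True)\n",
"\n",
"# Three hypothetical models of increasing quality\n",
"preds = [noisy_probs(s) for s in (0.5, 1.5, 3.0)]\n",
"\n",
"def loss(w):\n",
"    # Weighted average of the probability matrices, renormalised row-wise\n",
"    blended = sum(wi * p for wi, p in zip(w, preds))\n",
"    blended = blended / blended.sum(axis=1, keepdims=True)\n",
"    # Multiclass log loss (mean cross-entropy of the true class)\n",
"    return -np.mean(np.log(blended[np.arange(n), y_true]))\n",
"\n",
"w0 = np.full(len(preds), 1 / len(preds))\n",
"res = minimize(loss, w0, method='SLSQP',\n",
"               bounds=[(0, 1)] * len(preds),\n",
"               constraints={'type': 'eq', 'fun': lambda w: np.sum(w) - 1})\n",
"best_w = res.x\n",
"print(best_w, res.fun)\n",
"```\n",
"\n",
"The optimizer tunes the weights for log loss, not accuracy. A weighting that is optimal for log loss therefore need not give the best accuracy."
]
},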
{
"cell_type": "code",
"execution_count": 22,
"id": "80726d19",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Wall time: 1h 39min 20s\n"
]
}
],
"source": [
"%%time\n",
"# Build the second-level feature set with 5-fold CV on the training set (this step is slow)\n",
"stack_ds = pipeline.stack(k=5,seed=42)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "b25bba3c",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>xgb_0</th>\n",
" <th>xgb_1</th>\n",
" <th>xgb_2</th>\n",
" <th>xgb_3</th>\n",
" <th>xgb_4</th>\n",
" <th>xgb_5</th>\n",
" <th>xgb_6</th>\n",
" <th>lgb_0</th>\n",
" <th>lgb_1</th>\n",
" <th>lgb_2</th>\n",
" <th>...</th>\n",
" <th>lgb_4</th>\n",
" <th>lgb_5</th>\n",
" <th>lgb_6</th>\n",
" <th>rf_0</th>\n",
" <th>rf_1</th>\n",
" <th>rf_2</th>\n",
" <th>rf_3</th>\n",
" <th>rf_4</th>\n",
" <th>rf_5</th>\n",
" <th>rf_6</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.177179</td>\n",
" <td>0.818728</td>\n",
" <td>2.185222e-07</td>\n",
" <td>9.264143e-09</td>\n",
" <td>4.090067e-03</td>\n",
" <td>1.725062e-06</td>\n",
" <td>1.048052e-06</td>\n",
" <td>0.179625</td>\n",
" <td>0.804684</td>\n",
" <td>0.000003</td>\n",
" <td>...</td>\n",
" <td>1.562435e-02</td>\n",
" <td>6.370849e-05</td>\n",
" <td>1.004259e-08</td>\n",
" <td>0.03</td>\n",
" <td>0.96</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.01</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0.005155</td>\n",
" <td>0.994845</td>\n",
" <td>7.055579e-10</td>\n",
" <td>1.326343e-08</td>\n",
" <td>6.331572e-09</td>\n",
" <td>1.435787e-09</td>\n",
" <td>1.603579e-10</td>\n",
" <td>0.008114</td>\n",
" <td>0.991886</td>\n",
" <td>0.000000</td>\n",
" <td>...</td>\n",
" <td>0.000000e+00</td>\n",
" <td>0.000000e+00</td>\n",
" <td>0.000000e+00</td>\n",
" <td>0.13</td>\n",
" <td>0.87</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.00</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0.293492</td>\n",
" <td>0.706508</td>\n",
" <td>3.650662e-10</td>\n",
" <td>1.017633e-09</td>\n",
" <td>8.823530e-09</td>\n",
" <td>6.384080e-10</td>\n",
" <td>2.823794e-08</td>\n",
" <td>0.831445</td>\n",
" <td>0.168555</td>\n",
" <td>0.000000</td>\n",
" <td>...</td>\n",
" <td>4.999034e-07</td>\n",
" <td>4.015190e-09</td>\n",
" <td>4.997854e-09</td>\n",
" <td>0.63</td>\n",
" <td>0.37</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.00</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0.478112</td>\n",
" <td>0.521816</td>\n",
" <td>3.207779e-06</td>\n",
" <td>2.878019e-08</td>\n",
" <td>1.076500e-08</td>\n",
" <td>2.230641e-06</td>\n",
" <td>6.630235e-05</td>\n",
" <td>0.465733</td>\n",
" <td>0.534184</td>\n",
" <td>0.000000</td>\n",
" <td>...</td>\n",
" <td>0.000000e+00</td>\n",
" <td>0.000000e+00</td>\n",
" <td>8.245405e-05</td>\n",
" <td>0.55</td>\n",
" <td>0.45</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.00</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0.992430</td>\n",
" <td>0.006652</td>\n",
" <td>1.233117e-05</td>\n",
" <td>1.887496e-07</td>\n",
" <td>1.569583e-06</td>\n",
" <td>5.604260e-07</td>\n",
" <td>9.037877e-04</td>\n",
" <td>0.932050</td>\n",
" <td>0.043451</td>\n",
" <td>0.000000</td>\n",
" <td>...</td>\n",
" <td>0.000000e+00</td>\n",
" <td>0.000000e+00</td>\n",
" <td>2.449972e-02</td>\n",
" <td>0.97</td>\n",
" <td>0.03</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.00</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 21 columns</p>\n",
"</div>"
],
"text/plain": [
" xgb_0 xgb_1 xgb_2 xgb_3 xgb_4 xgb_5 \\\n",
"0 0.177179 0.818728 2.185222e-07 9.264143e-09 4.090067e-03 1.725062e-06 \n",
"1 0.005155 0.994845 7.055579e-10 1.326343e-08 6.331572e-09 1.435787e-09 \n",
"2 0.293492 0.706508 3.650662e-10 1.017633e-09 8.823530e-09 6.384080e-10 \n",
"3 0.478112 0.521816 3.207779e-06 2.878019e-08 1.076500e-08 2.230641e-06 \n",
"4 0.992430 0.006652 1.233117e-05 1.887496e-07 1.569583e-06 5.604260e-07 \n",
"\n",
" xgb_6 lgb_0 lgb_1 lgb_2 ... lgb_4 \\\n",
"0 1.048052e-06 0.179625 0.804684 0.000003 ... 1.562435e-02 \n",
"1 1.603579e-10 0.008114 0.991886 0.000000 ... 0.000000e+00 \n",
"2 2.823794e-08 0.831445 0.168555 0.000000 ... 4.999034e-07 \n",
"3 6.630235e-05 0.465733 0.534184 0.000000 ... 0.000000e+00 \n",
"4 9.037877e-04 0.932050 0.043451 0.000000 ... 0.000000e+00 \n",
"\n",
" lgb_5 lgb_6 rf_0 rf_1 rf_2 rf_3 rf_4 rf_5 rf_6 \n",
"0 6.370849e-05 1.004259e-08 0.03 0.96 0.0 0.0 0.01 0.0 0.0 \n",
"1 0.000000e+00 0.000000e+00 0.13 0.87 0.0 0.0 0.00 0.0 0.0 \n",
"2 4.015190e-09 4.997854e-09 0.63 0.37 0.0 0.0 0.00 0.0 0.0 \n",
"3 0.000000e+00 8.245405e-05 0.55 0.45 0.0 0.0 0.00 0.0 0.0 \n",
"4 0.000000e+00 2.449972e-02 0.97 0.03 0.0 0.0 0.00 0.0 0.0 \n",
"\n",
"[5 rows x 21 columns]"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Second-level training features: for each model, 7 columns holding its predicted probabilities for the 7 classes\n",
"stack_ds.X_train.head()"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "835205e9",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>xgb_0</th>\n",
" <th>xgb_1</th>\n",
" <th>xgb_2</th>\n",
" <th>xgb_3</th>\n",
" <th>xgb_4</th>\n",
" <th>xgb_5</th>\n",
" <th>xgb_6</th>\n",
" <th>lgb_0</th>\n",
" <th>lgb_1</th>\n",
" <th>lgb_2</th>\n",
" <th>...</th>\n",
" <th>lgb_4</th>\n",
" <th>lgb_5</th>\n",
" <th>lgb_6</th>\n",
" <th>rf_0</th>\n",
" <th>rf_1</th>\n",
" <th>rf_2</th>\n",
" <th>rf_3</th>\n",
" <th>rf_4</th>\n",
" <th>rf_5</th>\n",
" <th>rf_6</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.991493</td>\n",
" <td>0.000210</td>\n",
" <td>4.796058e-06</td>\n",
" <td>6.178684e-08</td>\n",
" <td>6.947614e-07</td>\n",
" <td>2.410490e-09</td>\n",
" <td>8.291459e-03</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>...</td>\n",
" <td>0.000000e+00</td>\n",
" <td>0.000000</td>\n",
" <td>1.000000e+00</td>\n",
" <td>0.99</td>\n",
" <td>0.00</td>\n",
" <td>0.00</td>\n",
" <td>0.0</td>\n",
" <td>0.00</td>\n",
" <td>0.00</td>\n",
" <td>0.01</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0.024731</td>\n",
" <td>0.964372</td>\n",
" <td>6.387765e-04</td>\n",
" <td>4.205048e-08</td>\n",
" <td>1.006575e-02</td>\n",
" <td>1.879628e-04</td>\n",
" <td>4.830114e-06</td>\n",
" <td>0.065073</td>\n",
" <td>0.877354</td>\n",
" <td>0.001678</td>\n",
" <td>...</td>\n",
" <td>5.530599e-02</td>\n",
" <td>0.000589</td>\n",
" <td>8.601484e-11</td>\n",
" <td>0.09</td>\n",
" <td>0.80</td>\n",
" <td>0.05</td>\n",
" <td>0.0</td>\n",
" <td>0.06</td>\n",
" <td>0.00</td>\n",
" <td>0.00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0.000780</td>\n",
" <td>0.979776</td>\n",
" <td>8.593459e-04</td>\n",
" <td>1.267791e-07</td>\n",
" <td>1.710379e-02</td>\n",
" <td>1.477527e-03</td>\n",
" <td>2.521008e-06</td>\n",
" <td>0.005164</td>\n",
" <td>0.933849</td>\n",
" <td>0.016355</td>\n",
" <td>...</td>\n",
" <td>3.657553e-02</td>\n",
" <td>0.008057</td>\n",
" <td>0.000000e+00</td>\n",
" <td>0.01</td>\n",
" <td>0.97</td>\n",
" <td>0.00</td>\n",
" <td>0.0</td>\n",
" <td>0.01</td>\n",
" <td>0.01</td>\n",
" <td>0.00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0.042695</td>\n",
" <td>0.957304</td>\n",
" <td>2.283268e-08</td>\n",
" <td>4.387427e-08</td>\n",
" <td>4.175481e-07</td>\n",
" <td>4.406019e-08</td>\n",
" <td>6.909629e-10</td>\n",
" <td>0.054392</td>\n",
" <td>0.945608</td>\n",
" <td>0.000000</td>\n",
" <td>...</td>\n",
" <td>4.285638e-08</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000e+00</td>\n",
" <td>0.04</td>\n",
" <td>0.96</td>\n",
" <td>0.00</td>\n",
" <td>0.0</td>\n",
" <td>0.00</td>\n",
" <td>0.00</td>\n",
" <td>0.00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0.000457</td>\n",
" <td>0.999334</td>\n",
" <td>3.366338e-06</td>\n",
" <td>4.893879e-08</td>\n",
" <td>2.045808e-04</td>\n",
" <td>7.889498e-07</td>\n",
" <td>1.415576e-09</td>\n",
" <td>0.001367</td>\n",
" <td>0.995857</td>\n",
" <td>0.000000</td>\n",
" <td>...</td>\n",
" <td>2.776106e-03</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000e+00</td>\n",
" <td>0.00</td>\n",
" <td>0.99</td>\n",
" <td>0.00</td>\n",
" <td>0.0</td>\n",
" <td>0.01</td>\n",
" <td>0.00</td>\n",
" <td>0.00</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 21 columns</p>\n",
"</div>"
],
"text/plain": [
" xgb_0 xgb_1 xgb_2 xgb_3 xgb_4 xgb_5 \\\n",
"0 0.991493 0.000210 4.796058e-06 6.178684e-08 6.947614e-07 2.410490e-09 \n",
"1 0.024731 0.964372 6.387765e-04 4.205048e-08 1.006575e-02 1.879628e-04 \n",
"2 0.000780 0.979776 8.593459e-04 1.267791e-07 1.710379e-02 1.477527e-03 \n",
"3 0.042695 0.957304 2.283268e-08 4.387427e-08 4.175481e-07 4.406019e-08 \n",
"4 0.000457 0.999334 3.366338e-06 4.893879e-08 2.045808e-04 7.889498e-07 \n",
"\n",
" xgb_6 lgb_0 lgb_1 lgb_2 ... lgb_4 lgb_5 \\\n",
"0 8.291459e-03 0.000000 0.000000 0.000000 ... 0.000000e+00 0.000000 \n",
"1 4.830114e-06 0.065073 0.877354 0.001678 ... 5.530599e-02 0.000589 \n",
"2 2.521008e-06 0.005164 0.933849 0.016355 ... 3.657553e-02 0.008057 \n",
"3 6.909629e-10 0.054392 0.945608 0.000000 ... 4.285638e-08 0.000000 \n",
"4 1.415576e-09 0.001367 0.995857 0.000000 ... 2.776106e-03 0.000000 \n",
"\n",
" lgb_6 rf_0 rf_1 rf_2 rf_3 rf_4 rf_5 rf_6 \n",
"0 1.000000e+00 0.99 0.00 0.00 0.0 0.00 0.00 0.01 \n",
"1 8.601484e-11 0.09 0.80 0.05 0.0 0.06 0.00 0.00 \n",
"2 0.000000e+00 0.01 0.97 0.00 0.0 0.01 0.01 0.00 \n",
"3 0.000000e+00 0.04 0.96 0.00 0.0 0.00 0.00 0.00 \n",
"4 0.000000e+00 0.00 0.99 0.00 0.0 0.01 0.00 0.00 \n",
"\n",
"[5 rows x 21 columns]"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Second-level test features: for each model, 7 columns holding its predicted probabilities for the 7 classes\n",
"stack_ds.X_test.head()"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "b9db35dc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Wall time: 1min 53s\n"
]
}
],
"source": [
"%%time\n",
"# Use logistic regression as the final (second-level) model\n",
"stacker = Classifier(dataset=stack_ds, estimator=LogisticRegression, parameters={\"solver\": 'lbfgs', \"max_iter\": 1000},use_cache=False)\n",
"predict_stack = stacker.predict()"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "a4a48219",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[9.99137967e-01 4.76574513e-04 3.32764186e-07 ... 3.03237700e-06\n",
" 1.62161211e-06 3.79326620e-04]\n",
" [2.00732175e-02 9.71682830e-01 1.51484266e-03 ... 5.69138554e-03\n",
" 9.40725428e-04 9.60822625e-05]\n",
" [5.54556002e-03 9.91048437e-01 8.04840682e-04 ... 2.11437934e-03\n",
" 4.56463787e-04 3.00919502e-05]\n",
" ...\n",
" [4.60179790e-06 1.78298095e-03 9.91553958e-01 ... 7.26752933e-04\n",
" 3.79135124e-03 1.53584401e-05]\n",
" [9.96307096e-01 2.43558944e-03 1.01596361e-07 ... 3.94985596e-05\n",
" 7.41569805e-06 1.20819024e-03]\n",
" [5.34671504e-05 7.62534718e-04 5.58323657e-03 ... 2.11410908e-04\n",
" 9.91805379e-01 1.69502656e-05]]\n"
]
}
],
"source": [
"print(predict_stack)  # class probabilities predicted by the stacked model"
]
},
{
"cell_type": "markdown",
"id": "1372d4f8",
"metadata": {},
"source": [
"## Evaluate the results"
]
},
{
"cell_type": "markdown",
"id": "52ef71d4",
"metadata": {},
"source": [
"### Single-model scores"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "a28806a0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.9254473229468583\n",
"0.8412562907478676\n",
"0.9535087055000585\n"
]
}
],
"source": [
"print(accuracy_score(y_test, np.argmax(stack_ds.X_test.iloc[:, :7].values, axis=1)))  # XGB\n",
"print(accuracy_score(y_test, np.argmax(stack_ds.X_test.iloc[:, 7:14].values, axis=1)))  # LGB\n",
"print(accuracy_score(y_test, np.argmax(stack_ds.X_test.iloc[:, 14:].values, axis=1)))  # RF"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "1db92fec",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.9537840870756542\n",
"Wall time: 50.8 s\n"
]
}
],
"source": [
"%%time\n",
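"# Check that running the RF model on its own gives a consistent score\n",
"# (RandomForestClassifier has no fixed random_state here, so expect small run-to-run differences)\n",
"rf_predict = rf_model(X_train, y_train, X_test, None)\n",
"print(accuracy_score(y_test, np.argmax(rf_predict, axis=1)))  # RF"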
]
},
{
"cell_type": "markdown",
"id": "2e9423ce",
"metadata": {},
"source": [
"### Linear weighting scores"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "d2b50ba4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Hand-picked weights: 0.9504106627746071\n",
"Optimal weights (find_weights): 0.9530749795184953\n"
]
}
],
"source": [
"# Scores of the linear weighting (weighted average of the predicted probabilities)\n",
"xgb_t = stack_ds.X_test.iloc[:, :7].values\n",
"lgb_t = stack_ds.X_test.iloc[:, 7:14].values\n",
"rf_t = stack_ds.X_test.iloc[:, 14:].values\n",
"\n",
"# Weights chosen by hand, roughly according to the single-model scores\n",
"result = 0.2*xgb_t + 0.1*lgb_t + 0.7*rf_t\n",
"print('Hand-picked weights:', accuracy_score(y_test, np.argmax(result, axis=1)))\n",
"# Use the optimal weights found by find_weights above: Best Weights: [2.53464919e-01 1.48562205e-20 7.46535081e-01]\n",
"result = 2.53464919e-01*xgb_t + 1.48562205e-20*lgb_t + 7.46535081e-01*rf_t\n",
"print('Optimal weights (find_weights):', accuracy_score(y_test, np.argmax(result, axis=1)))"
]
},
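{
"cell_type": "markdown",
"id": "2a7d41b5",
"metadata": {},
"source": [
"The weighted average above works row by row. Each class probability is the weighted sum of the models' probabilities for that class, and `np.argmax` then picks the class with the largest blended value. The toy example below uses made-up probabilities for two samples and three classes, purely for illustration.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Made-up predicted probabilities from two models (rows: samples, columns: classes)\n",
"p_a = np.array([[0.6, 0.3, 0.1], [0.2, 0.5, 0.3]])\n",
"p_b = np.array([[0.1, 0.8, 0.1], [0.1, 0.2, 0.7]])\n",
"\n",
"blend = 0.25 * p_a + 0.75 * p_b  # weights sum to 1, so each row still sums to 1\n",
"print(np.argmax(blend, axis=1))  # [1 2]\n",
"```\n",
"\n",
"Model A alone would predict class 0 for the first sample. With most of the weight on model B, the blend predicts class 1 instead."
]
},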
{
"cell_type": "markdown",
"id": "e8ce00b2",
"metadata": {},
"source": [
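"The optimal weights beat the hand-picked ones (0.9531 vs. 0.9504), but the accuracy is still slightly below that of the best single model (RF, 0.9535). Note that find_weights optimized log loss, not accuracy."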
]
},
{
"cell_type": "markdown",
"id": "d4a12d6d",
"metadata": {},
"source": [
"### Blending score"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "ffe1cf4d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Wall time: 14min 10s\n"
]
}
],
"source": [
"%%time\n",
"blend_ds = pipeline.blend(seed=111)\n",
"blender = Classifier(dataset=blend_ds, estimator=LogisticRegression, parameters={\"solver\": 'lbfgs', \"max_iter\": 1000},use_cache=False)\n",
"predict_blend = blender.predict()"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "506adffb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.9546859617357301\n"
]
}
],
"source": [
"print(accuracy_score(y_test, np.argmax(predict_blend, axis=1)))"
]
},
{
"cell_type": "markdown",
"id": "e84c19ae",
"metadata": {},
"source": [
"Blending raises the score to 0.9547, above the best single model (RF, 0.9535) and both linear weightings."
]
},
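{
"cell_type": "markdown",
"id": "c58e9f13",
"metadata": {},
"source": [
"Unlike stacking, blending does not use K folds. It holds out a single split of the training data, fits the base models on the rest, and trains the second-level model only on their predictions for the holdout. The sketch below uses scikit-learn and is not heamy's `blend` implementation. It assumes synthetic data, illustrative base models and a 20% holdout.\n",
"\n",
"```python\n",
"# Minimal holdout-blending sketch (illustrative; not heamy's internals)\n",
"import numpy as np\n",
"from sklearn.datasets import make_classification\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Assumed toy data: 600 samples, 3 classes\n",
"X, y = make_classification(n_samples=600, n_features=10, n_informative=5,\n",
"                           n_classes=3, random_state=0)\n",
"X_tr, X_te, y_tr, y_te = train_test_split(X, y, random_state=0)\n",
"# One split of the training data: base models fit on X_fit, the blender learns on X_hold\n",
"X_fit, X_hold, y_fit, y_hold = train_test_split(X_tr, y_tr, test_size=0.2, random_state=111)\n",
"\n",
"base_models = [RandomForestClassifier(n_estimators=50, random_state=0),\n",
"               LogisticRegression(max_iter=1000)]\n",
"for model in base_models:\n",
"    model.fit(X_fit, y_fit)\n",
"\n",
"# Second-level features: concatenated class probabilities of every base model\n",
"hold_feats = np.hstack([m.predict_proba(X_hold) for m in base_models])\n",
"test_feats = np.hstack([m.predict_proba(X_te) for m in base_models])\n",
"\n",
"blender = LogisticRegression(max_iter=1000).fit(hold_feats, y_hold)\n",
"blend_acc = (blender.predict(test_feats) == y_te).mean()\n",
"print(blend_acc)\n",
"```\n",
"\n",
"Here the second-level model sees only the 90 holdout rows out of 450, whereas stacking produces out-of-fold predictions for every training row. Having less data to learn from is one plausible reason stacking scored higher in this notebook."
]
},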
{
"cell_type": "markdown",
"id": "e8daf1e3",
"metadata": {},
"source": [
"### Stacking score"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "4930b407",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.9589887988544127\n"
]
}
],
"source": [
"print(accuracy_score(y_test, np.argmax(predict_stack, axis=1)))"
]
},
{
"cell_type": "markdown",
"id": "d5ad3c4b",
"metadata": {},
"source": [
"The improvement is clear: 0.9590, compared with 0.9547 for Blending and 0.9535 for the best single model."
]
},
{
"cell_type": "markdown",
"id": "6311fbd7",
"metadata": {},
"source": [
"## Conclusion, restated: on the fetch_covtype dataset, Stacking beats both Blending and linear weighting"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}