{
"cells": [
{
"cell_type": "markdown",
"id": "720f2b62",
"metadata": {},
"source": [
"# Stacking"
]
},
{
"cell_type": "markdown",
"id": "0b365a02",
"metadata": {},
"source": [
"## 先说结论,该数据集(fetch_covtype)Stacking的方法相比Blending和线性加权更好\n",
"比赛中我们常用线性加权作为最终的融合方式,我们同样也会好奇怎样的线性加权权重更好,下面也会举例子\n",
"参考:https://github.com/rushter/heamy/tree/master/examples"
]
},
{
"cell_type": "markdown",
"id": "cc8fecb1",
"metadata": {},
"source": [
"通过对训练集进行五折验证,将验证结果作为第二层的训练和测试集合\n",
""
]
},
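{
"cell_type": "markdown",
"id": "3f9a1c2e",
"metadata": {},
"source": [
"The out-of-fold idea described above can be sketched with scikit-learn alone. This is a minimal illustration on synthetic data, not the heamy implementation: the helper `out_of_fold_features`, the toy dataset, and the choice to average the test predictions over the folds are all assumptions made for this example.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.datasets import make_classification\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.model_selection import StratifiedKFold, train_test_split\n",
"\n",
"\n",
"def out_of_fold_features(model, X_train, y_train, X_test, n_classes, k=5, seed=42):\n",
"    # Hypothetical helper: build second-layer features for one base model\n",
"    oof = np.zeros((X_train.shape[0], n_classes))\n",
"    test_pred = np.zeros((X_test.shape[0], n_classes))\n",
"    folds = StratifiedKFold(n_splits=k, shuffle=True, random_state=seed)\n",
"    for fit_idx, val_idx in folds.split(X_train, y_train):\n",
"        model.fit(X_train[fit_idx], y_train[fit_idx])\n",
"        # Each training row is predicted by a model that never saw it\n",
"        oof[val_idx] = model.predict_proba(X_train[val_idx])\n",
"        # Test-set features are averaged over the k fold models (an assumption)\n",
"        test_pred += model.predict_proba(X_test) / k\n",
"    return oof, test_pred\n",
"\n",
"\n",
"X, y = make_classification(n_samples=600, n_features=10, n_informative=6,\n",
"                           n_classes=3, random_state=0)\n",
"X_tr, X_te, y_tr, y_te = train_test_split(X, y, random_state=0)\n",
"\n",
"base = RandomForestClassifier(n_estimators=20, random_state=0)\n",
"train_meta, test_meta = out_of_fold_features(base, X_tr, y_tr, X_te, n_classes=3)\n",
"\n",
"# Second layer: a simple model trained on the out-of-fold probabilities\n",
"meta = LogisticRegression(max_iter=1000).fit(train_meta, y_tr)\n",
"print(train_meta.shape, test_meta.shape, meta.score(test_meta, y_te))\n",
"```\n",
"\n",
"Because every row of `train_meta` comes from a model that was not fitted on that row, the second layer is not trained on predictions that have already seen its labels."
]
},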
{
"cell_type": "code",
"execution_count": 1,
"id": "18a12000",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\n",
"Collecting heamy\n",
" Downloading https://pypi.tuna.tsinghua.edu.cn/packages/20/32/2f3e1efa38a8e34f790d90b6d49ef06ab812181ae896c50e89b8750fa5a0/heamy-0.0.7.tar.gz (30 kB)\n",
"Requirement already satisfied: scikit-learn>=0.17.0 in d:\\programdata\\anaconda3\\lib\\site-packages (from heamy) (0.24.1)\n",
"Requirement already satisfied: pandas>=0.17.0 in d:\\programdata\\anaconda3\\lib\\site-packages (from heamy) (1.2.4)\n",
"Requirement already satisfied: six>=1.10.0 in d:\\programdata\\anaconda3\\lib\\site-packages (from heamy) (1.15.0)\n",
"Requirement already satisfied: scipy>=0.16.0 in d:\\programdata\\anaconda3\\lib\\site-packages (from heamy) (1.6.2)\n",
"Requirement already satisfied: numpy>=1.7.0 in d:\\programdata\\anaconda3\\lib\\site-packages (from heamy) (1.19.5)\n",
"Requirement already satisfied: pytz>=2017.3 in d:\\programdata\\anaconda3\\lib\\site-packages (from pandas>=0.17.0->heamy) (2021.1)\n",
"Requirement already satisfied: python-dateutil>=2.7.3 in d:\\programdata\\anaconda3\\lib\\site-packages (from pandas>=0.17.0->heamy) (2.8.1)\n",
"Requirement already satisfied: threadpoolctl>=2.0.0 in d:\\programdata\\anaconda3\\lib\\site-packages (from scikit-learn>=0.17.0->heamy) (2.1.0)\n",
"Requirement already satisfied: joblib>=0.11 in d:\\programdata\\anaconda3\\lib\\site-packages (from scikit-learn>=0.17.0->heamy) (1.0.1)\n",
"Building wheels for collected packages: heamy\n",
" Building wheel for heamy (setup.py): started\n",
" Building wheel for heamy (setup.py): finished with status 'done'\n",
" Created wheel for heamy: filename=heamy-0.0.7-py2.py3-none-any.whl size=15353 sha256=e3ba65b34e2bdee3b90b45b637e28836afdbdb0c9547f76b36fe10d17f8aba8f\n",
" Stored in directory: c:\\users\\administrator\\appdata\\local\\pip\\cache\\wheels\\6e\\f1\\7d\\048e558da94f495a0ed0d9c09d312e73eb176a092e36774ec2\n",
"Successfully built heamy\n",
"Installing collected packages: heamy\n",
"Successfully installed heamy-0.0.7\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"pip install heamy # 安装相关包"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "69632c6a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3.8.8 (default, Apr 13 2021, 15:08:03) [MSC v.1916 64 bit (AMD64)]\n"
]
}
],
"source": [
"import sys\n",
"print(sys.version) # 版本信息"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "ca421279",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import time\n",
"\n",
"from heamy.dataset import Dataset\n",
"from heamy.estimator import Classifier \n",
"from heamy.pipeline import ModelsPipeline\n",
"# 导入相关模型,没有的pip install xxx 即可\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.linear_model import LogisticRegression\n",
"import xgboost as xgb \n",
"import lightgbm as lgb \n",
"\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import accuracy_score\n",
"from sklearn.preprocessing import OrdinalEncoder\n",
"from sklearn.metrics import log_loss"
]
},
{
"cell_type": "markdown",
"id": "2592fbbd",
"metadata": {},
"source": [
"## 准备数据集"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "9a0fabe1",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import fetch_covtype\n",
"data = fetch_covtype()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "5bd75178",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"七分类任务,处理前: [1 2 3 4 5 6 7]\n",
"[5 5 2 ... 3 3 3]\n",
"七分类任务,处理后: [0. 1. 2. 3. 4. 5. 6.]\n",
"[4. 4. 1. ... 2. 2. 2.]\n"
]
}
],
"source": [
"# 预处理\n",
"X, y = data['data'], data['target']\n",
"# 由于模型标签需要从0开始,所以数字需要全部减1\n",
"print('七分类任务,处理前:',np.unique(y))\n",
"print(y)\n",
"ord = OrdinalEncoder()\n",
"y = ord.fit_transform(y.reshape(-1, 1))\n",
"y = y.reshape(-1, )\n",
"print('七分类任务,处理后:',np.unique(y))\n",
"print(y)"
]
},
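{
"cell_type": "markdown",
"id": "7b2e4d10",
"metadata": {},
"source": [
"As an aside, `LabelEncoder` is the scikit-learn encoder intended for target vectors: it returns integer labels and needs no reshape, whereas `OrdinalEncoder` returns floats. A minimal sketch on a made-up label array (the array is an assumption for illustration; this is not what the cell above runs):\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.preprocessing import LabelEncoder\n",
"\n",
"y_raw = np.array([5, 5, 2, 1, 7, 3, 4, 6])  # made-up labels covering 1..7\n",
"encoder = LabelEncoder()\n",
"y_enc = encoder.fit_transform(y_raw)  # sorted unique labels are mapped to 0..n-1\n",
"print(y_enc)\n",
"print(encoder.inverse_transform(y_enc))  # recovers the original labels\n",
"```\n",
"\n",
"When every label from 1 to 7 is present, the result equals `y_raw - 1`."
]
},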
{
"cell_type": "code",
"execution_count": 7,
"id": "23d9778c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(435759, 54)\n",
"(145253, 54)\n"
]
}
],
"source": [
"# 切分训练和测试集\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y,random_state=42)\n",
"print(X_train.shape)\n",
"print(X_test.shape)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "eac48668",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dataset(5c3ccfb5c81451d098565ef5e7e36ac5)\n"
]
}
],
"source": [
"# 创建数据集\n",
"'''use_cache : bool, default True\n",
" If use_cache=True then preprocessing step will be cached until function codeis changed.'''\n",
"dataset = Dataset(X_train=X_train, y_train=y_train, X_test=X_test, y_test=None,use_cache=True) # 注意这里的y_test=None,即不存在数据泄露\n",
"print(dataset)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "fba3f975",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[2.833e+03, 2.580e+02, 2.600e+01, ..., 0.000e+00, 0.000e+00,\n",
" 0.000e+00],\n",
" [3.008e+03, 4.500e+01, 2.000e+00, ..., 0.000e+00, 0.000e+00,\n",
" 0.000e+00],\n",
" [2.949e+03, 0.000e+00, 1.100e+01, ..., 0.000e+00, 0.000e+00,\n",
" 0.000e+00],\n",
" ...,\n",
" [3.153e+03, 2.870e+02, 1.700e+01, ..., 0.000e+00, 0.000e+00,\n",
" 0.000e+00],\n",
" [3.065e+03, 3.480e+02, 2.100e+01, ..., 0.000e+00, 0.000e+00,\n",
" 0.000e+00],\n",
" [3.021e+03, 2.600e+01, 1.600e+01, ..., 0.000e+00, 0.000e+00,\n",
" 0.000e+00]])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 处理后的数据集\n",
"dataset.X_train"
]
},
{
"cell_type": "markdown",
"id": "d4517ea1",
"metadata": {},
"source": [
"## 定义模型"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "e8393e73",
"metadata": {},
"outputs": [],
"source": [
"def xgb_model(X_train, y_train, X_test, y_test):\n",
" \"\"\"参数必须为X_train,y_train,X_test,y_test\"\"\"\n",
" # 可以内置参数\n",
" params = {'objective': 'multi:softprob',\n",
" \"eval_metric\": 'mlogloss',\n",
" \"verbosity\": 0,\n",
" 'num_class': 7,\n",
" 'nthread': -1}\n",
" dtrain = xgb.DMatrix(X_train, y_train)\n",
" dtest = xgb.DMatrix(X_test)\n",
" model = xgb.train(params, dtrain, num_boost_round=300)\n",
" predict = model.predict(dtest)\n",
" return predict # 返回值必须为X_test的预测\n",
"\n",
"\n",
"def lgb_model(X_train, y_train, X_test, y_test,**parameters):\n",
" # 也可以开放参数接口\n",
" if parameters is None:\n",
" parameters = {}\n",
" lgb_train = lgb.Dataset(X_train, y_train)\n",
" model = lgb.train(params=parameters, train_set=lgb_train,num_boost_round=300)\n",
" predict = model.predict(X_test)\n",
" return predict\n",
"\n",
"\n",
"def rf_model(X_train, y_train, X_test, y_test):\n",
" params = {\"n_estimators\": 100, \"n_jobs\": -1}\n",
" model = RandomForestClassifier(**params).fit(X_train, y_train)\n",
" predict = model.predict_proba(X_test)\n",
" return predict"
]
},
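{
"cell_type": "markdown",
"id": "c5d8e6a1",
"metadata": {},
"source": [
"Any function that follows the same contract (take `X_train, y_train, X_test, y_test`, return predictions for `X_test`) can serve as a base model. The sketch below wraps `LogisticRegression` in that form and calls it directly on synthetic data; the name `lr_model` and the toy data are assumptions for illustration and are not part of the pipeline built in this notebook.\n",
"\n",
"```python\n",
"from sklearn.datasets import make_classification\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"\n",
"def lr_model(X_train, y_train, X_test, y_test):\n",
"    # Hypothetical base model following the four-argument contract\n",
"    model = LogisticRegression(max_iter=1000).fit(X_train, y_train)\n",
"    # Return class probabilities for X_test, one column per class\n",
"    return model.predict_proba(X_test)\n",
"\n",
"\n",
"X, y = make_classification(n_samples=300, n_features=8, n_informative=5,\n",
"                           n_classes=3, random_state=0)\n",
"X_tr, X_te, y_tr, y_te = train_test_split(X, y, random_state=0)\n",
"proba = lr_model(X_tr, y_tr, X_te, None)  # y_test is unused, so None is fine\n",
"print(proba.shape)\n",
"```\n",
"\n",
"Returning probabilities rather than hard labels keeps the output usable both for log-loss scoring and as second-layer features."
]
},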
{
"cell_type": "markdown",
"id": "0715cf6e",
"metadata": {},
"source": [
"## 构建和训练模型"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "78ab0083",
"metadata": {},
"outputs": [],
"source": [
"params = {\"objective\": \"multiclass\",\n",
" \"num_class\": 7,\n",
" \"n_jobs\": -1,\n",
" \"verbose\": -4, \n",
" \"metric\": (\"multi_logloss\",)}\n",
"\n",
"model_xgb = Classifier(dataset=dataset, estimator=xgb_model, name='xgb',use_cache=False)\n",
"model_lgb = Classifier(dataset=dataset, estimator=lgb_model, name='lgb',parameters=params,use_cache=False)\n",
"model_rf = Classifier(dataset=dataset, estimator=rf_model,name='rf',use_cache=False)\n",
"\n",
"pipeline = ModelsPipeline(model_xgb, model_lgb, model_rf)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "173ef0f0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best Score (log_loss): 0.18646435443714865\n",
"Best Weights: [2.53464919e-01 1.48562205e-20 7.46535081e-01]\n",
"Wall time: 14min 19s\n"
]
},
{
"data": {
"text/plain": [
"array([2.53464919e-01, 1.48562205e-20, 7.46535081e-01])"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"pipeline.find_weights(scorer=log_loss) # 输出最优权重组合"
]
},
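{
"cell_type": "markdown",
"id": "9e41b7f3",
"metadata": {},
"source": [
"What a weight search of this kind does can be sketched with `scipy.optimize.minimize`: minimise the log loss of the weighted average of the models' probabilities, with the weights constrained to lie in [0, 1] and sum to 1. This is an illustration under those assumptions, run on simulated predictions; it is not necessarily how `find_weights` is implemented, and `noisy_model` and `blend_loss` are hypothetical names.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from scipy.optimize import minimize\n",
"\n",
"rng = np.random.default_rng(0)\n",
"n_samples, n_classes = 500, 3\n",
"y_true = rng.integers(0, n_classes, size=n_samples)\n",
"one_hot = np.eye(n_classes)[y_true]\n",
"\n",
"\n",
"def noisy_model(noise):\n",
"    # Simulated model output: one-hot labels blurred by uniform noise, rows renormalised\n",
"    raw = one_hot + noise * rng.random((n_samples, n_classes))\n",
"    return raw / raw.sum(axis=1, keepdims=True)\n",
"\n",
"\n",
"preds = [noisy_model(0.5), noisy_model(1.0), noisy_model(3.0)]\n",
"\n",
"\n",
"def blend_loss(weights):\n",
"    # Multiclass log loss of the weighted average of the model probabilities\n",
"    blended = sum(w * p for w, p in zip(weights, preds))\n",
"    true_class_prob = blended[np.arange(n_samples), y_true]\n",
"    return -np.mean(np.log(np.clip(true_class_prob, 1e-15, None)))\n",
"\n",
"\n",
"start = np.full(len(preds), 1.0 / len(preds))  # begin from equal weights\n",
"result = minimize(blend_loss, start, method='SLSQP',\n",
"                  bounds=[(0.0, 1.0)] * len(preds),\n",
"                  constraints={'type': 'eq', 'fun': lambda w: w.sum() - 1.0})\n",
"print('best weights:', np.round(result.x, 3))\n",
"print('best log loss:', result.fun)\n",
"```\n",
"\n",
"A weight that collapses to nearly zero, as one of the weights in the output above does, means that model adds nothing to the blend under the chosen metric."
]
},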
{
"cell_type": "code",
"execution_count": 22,
"id": "80726d19",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Wall time: 1h 39min 20s\n"
]
}
],
"source": [
"%%time\n",
"# 5折训练构建5折模型特征集,这里比较耗时\n",
"stack_ds = pipeline.stack(k=5,seed=42)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "b25bba3c",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n", " | xgb_0 | \n", "xgb_1 | \n", "xgb_2 | \n", "xgb_3 | \n", "xgb_4 | \n", "xgb_5 | \n", "xgb_6 | \n", "lgb_0 | \n", "lgb_1 | \n", "lgb_2 | \n", "... | \n", "lgb_4 | \n", "lgb_5 | \n", "lgb_6 | \n", "rf_0 | \n", "rf_1 | \n", "rf_2 | \n", "rf_3 | \n", "rf_4 | \n", "rf_5 | \n", "rf_6 | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "0.177179 | \n", "0.818728 | \n", "2.185222e-07 | \n", "9.264143e-09 | \n", "4.090067e-03 | \n", "1.725062e-06 | \n", "1.048052e-06 | \n", "0.179625 | \n", "0.804684 | \n", "0.000003 | \n", "... | \n", "1.562435e-02 | \n", "6.370849e-05 | \n", "1.004259e-08 | \n", "0.03 | \n", "0.96 | \n", "0.0 | \n", "0.0 | \n", "0.01 | \n", "0.0 | \n", "0.0 | \n", "
1 | \n", "0.005155 | \n", "0.994845 | \n", "7.055579e-10 | \n", "1.326343e-08 | \n", "6.331572e-09 | \n", "1.435787e-09 | \n", "1.603579e-10 | \n", "0.008114 | \n", "0.991886 | \n", "0.000000 | \n", "... | \n", "0.000000e+00 | \n", "0.000000e+00 | \n", "0.000000e+00 | \n", "0.13 | \n", "0.87 | \n", "0.0 | \n", "0.0 | \n", "0.00 | \n", "0.0 | \n", "0.0 | \n", "
2 | \n", "0.293492 | \n", "0.706508 | \n", "3.650662e-10 | \n", "1.017633e-09 | \n", "8.823530e-09 | \n", "6.384080e-10 | \n", "2.823794e-08 | \n", "0.831445 | \n", "0.168555 | \n", "0.000000 | \n", "... | \n", "4.999034e-07 | \n", "4.015190e-09 | \n", "4.997854e-09 | \n", "0.63 | \n", "0.37 | \n", "0.0 | \n", "0.0 | \n", "0.00 | \n", "0.0 | \n", "0.0 | \n", "
3 | \n", "0.478112 | \n", "0.521816 | \n", "3.207779e-06 | \n", "2.878019e-08 | \n", "1.076500e-08 | \n", "2.230641e-06 | \n", "6.630235e-05 | \n", "0.465733 | \n", "0.534184 | \n", "0.000000 | \n", "... | \n", "0.000000e+00 | \n", "0.000000e+00 | \n", "8.245405e-05 | \n", "0.55 | \n", "0.45 | \n", "0.0 | \n", "0.0 | \n", "0.00 | \n", "0.0 | \n", "0.0 | \n", "
4 | \n", "0.992430 | \n", "0.006652 | \n", "1.233117e-05 | \n", "1.887496e-07 | \n", "1.569583e-06 | \n", "5.604260e-07 | \n", "9.037877e-04 | \n", "0.932050 | \n", "0.043451 | \n", "0.000000 | \n", "... | \n", "0.000000e+00 | \n", "0.000000e+00 | \n", "2.449972e-02 | \n", "0.97 | \n", "0.03 | \n", "0.0 | \n", "0.0 | \n", "0.00 | \n", "0.0 | \n", "0.0 | \n", "
5 rows × 21 columns
\n", "\n", " | xgb_0 | \n", "xgb_1 | \n", "xgb_2 | \n", "xgb_3 | \n", "xgb_4 | \n", "xgb_5 | \n", "xgb_6 | \n", "lgb_0 | \n", "lgb_1 | \n", "lgb_2 | \n", "... | \n", "lgb_4 | \n", "lgb_5 | \n", "lgb_6 | \n", "rf_0 | \n", "rf_1 | \n", "rf_2 | \n", "rf_3 | \n", "rf_4 | \n", "rf_5 | \n", "rf_6 | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "0.991493 | \n", "0.000210 | \n", "4.796058e-06 | \n", "6.178684e-08 | \n", "6.947614e-07 | \n", "2.410490e-09 | \n", "8.291459e-03 | \n", "0.000000 | \n", "0.000000 | \n", "0.000000 | \n", "... | \n", "0.000000e+00 | \n", "0.000000 | \n", "1.000000e+00 | \n", "0.99 | \n", "0.00 | \n", "0.00 | \n", "0.0 | \n", "0.00 | \n", "0.00 | \n", "0.01 | \n", "
1 | \n", "0.024731 | \n", "0.964372 | \n", "6.387765e-04 | \n", "4.205048e-08 | \n", "1.006575e-02 | \n", "1.879628e-04 | \n", "4.830114e-06 | \n", "0.065073 | \n", "0.877354 | \n", "0.001678 | \n", "... | \n", "5.530599e-02 | \n", "0.000589 | \n", "8.601484e-11 | \n", "0.09 | \n", "0.80 | \n", "0.05 | \n", "0.0 | \n", "0.06 | \n", "0.00 | \n", "0.00 | \n", "
2 | \n", "0.000780 | \n", "0.979776 | \n", "8.593459e-04 | \n", "1.267791e-07 | \n", "1.710379e-02 | \n", "1.477527e-03 | \n", "2.521008e-06 | \n", "0.005164 | \n", "0.933849 | \n", "0.016355 | \n", "... | \n", "3.657553e-02 | \n", "0.008057 | \n", "0.000000e+00 | \n", "0.01 | \n", "0.97 | \n", "0.00 | \n", "0.0 | \n", "0.01 | \n", "0.01 | \n", "0.00 | \n", "
3 | \n", "0.042695 | \n", "0.957304 | \n", "2.283268e-08 | \n", "4.387427e-08 | \n", "4.175481e-07 | \n", "4.406019e-08 | \n", "6.909629e-10 | \n", "0.054392 | \n", "0.945608 | \n", "0.000000 | \n", "... | \n", "4.285638e-08 | \n", "0.000000 | \n", "0.000000e+00 | \n", "0.04 | \n", "0.96 | \n", "0.00 | \n", "0.0 | \n", "0.00 | \n", "0.00 | \n", "0.00 | \n", "
4 | \n", "0.000457 | \n", "0.999334 | \n", "3.366338e-06 | \n", "4.893879e-08 | \n", "2.045808e-04 | \n", "7.889498e-07 | \n", "1.415576e-09 | \n", "0.001367 | \n", "0.995857 | \n", "0.000000 | \n", "... | \n", "2.776106e-03 | \n", "0.000000 | \n", "0.000000e+00 | \n", "0.00 | \n", "0.99 | \n", "0.00 | \n", "0.0 | \n", "0.01 | \n", "0.00 | \n", "0.00 | \n", "
5 rows × 21 columns
\n", "