From b3a639a8e8347bc29c15abe54c73ae5d890379ff Mon Sep 17 00:00:00 2001
From: benjas <909336740@qq.com>
Date: Sun, 22 Aug 2021 14:02:08 +0800
Subject: [PATCH] Add. Stacking better than Blending and linear-weighted
---
.../Stacking-checkpoint.ipynb | 727 ++++++++++++++----
竞赛优胜技巧/Stacking.ipynb | 727 ++++++++++++++----
2 files changed, 1112 insertions(+), 342 deletions(-)
diff --git a/竞赛优胜技巧/.ipynb_checkpoints/Stacking-checkpoint.ipynb b/竞赛优胜技巧/.ipynb_checkpoints/Stacking-checkpoint.ipynb
index ef7f238..4ae86fe 100644
--- a/竞赛优胜技巧/.ipynb_checkpoints/Stacking-checkpoint.ipynb
+++ b/竞赛优胜技巧/.ipynb_checkpoints/Stacking-checkpoint.ipynb
@@ -13,7 +13,7 @@
"id": "0b365a02",
"metadata": {},
"source": [
- "## 先说结论,该数据集(fetch_covtype)Stacking的方法比线性加权更好\n",
+ "## 先说结论,该数据集(fetch_covtype)Stacking的方法相比Blending和线性加权更好\n",
"比赛中我们常用线性加权作为最终的融合方式,我们同样也会好奇怎样的线性加权权重更好,下面也会举例子\n",
"参考:https://github.com/rushter/heamy/tree/master/examples"
]
@@ -67,7 +67,7 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 1,
"id": "69632c6a",
"metadata": {},
"outputs": [
@@ -86,7 +86,7 @@
},
{
"cell_type": "code",
- "execution_count": 61,
+ "execution_count": 2,
"id": "ca421279",
"metadata": {},
"outputs": [],
@@ -119,7 +119,7 @@
},
{
"cell_type": "code",
- "execution_count": 15,
+ "execution_count": 3,
"id": "9a0fabe1",
"metadata": {},
"outputs": [],
@@ -130,7 +130,7 @@
},
{
"cell_type": "code",
- "execution_count": 47,
+ "execution_count": 6,
"id": "5bd75178",
"metadata": {},
"outputs": [
@@ -153,14 +153,14 @@
"print(y)\n",
"ord = OrdinalEncoder()\n",
"y = ord.fit_transform(y.reshape(-1, 1))\n",
- "y = y_enc.reshape(-1, )\n",
+ "y = y.reshape(-1, )\n",
"print('七分类任务,处理后:',np.unique(y))\n",
"print(y)"
]
},
{
"cell_type": "code",
- "execution_count": 48,
+ "execution_count": 7,
"id": "23d9778c",
"metadata": {},
"outputs": [
@@ -182,7 +182,7 @@
},
{
"cell_type": "code",
- "execution_count": 49,
+ "execution_count": 8,
"id": "eac48668",
"metadata": {},
"outputs": [
@@ -204,7 +204,7 @@
},
{
"cell_type": "code",
- "execution_count": 50,
+ "execution_count": 9,
"id": "fba3f975",
"metadata": {},
"outputs": [
@@ -226,7 +226,7 @@
" 0.000e+00]])"
]
},
- "execution_count": 50,
+ "execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
@@ -246,7 +246,7 @@
},
{
"cell_type": "code",
- "execution_count": 40,
+ "execution_count": 10,
"id": "e8393e73",
"metadata": {},
"outputs": [],
@@ -293,7 +293,7 @@
},
{
"cell_type": "code",
- "execution_count": 52,
+ "execution_count": 11,
"id": "78ab0083",
"metadata": {},
"outputs": [],
@@ -313,7 +313,7 @@
},
{
"cell_type": "code",
- "execution_count": 53,
+ "execution_count": 13,
"id": "173ef0f0",
"metadata": {},
"outputs": [
@@ -321,185 +321,485 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Best Score (log_loss): 0.18744137777851164\n",
- "Best Weights: [0.36556831 0.00303401 0.63139768]\n"
+ "Best Score (log_loss): 0.18646435443714865\n",
+ "Best Weights: [2.53464919e-01 1.48562205e-20 7.46535081e-01]\n",
+ "Wall time: 14min 19s\n"
]
},
{
"data": {
"text/plain": [
- "array([0.36556831, 0.00303401, 0.63139768])"
+ "array([2.53464919e-01, 1.48562205e-20, 7.46535081e-01])"
]
},
- "execution_count": 53,
+ "execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "pipeline.find_weights(scorer=log_loss, ) # 输出最优权重组合"
+ "%%time\n",
+ "pipeline.find_weights(scorer=log_loss) # 输出最优权重组合"
]
},
{
"cell_type": "code",
- "execution_count": 55,
+ "execution_count": 22,
"id": "80726d19",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Wall time: 1h 39min 20s\n"
+ ]
+ }
+ ],
"source": [
+ "%%time\n",
"# 5折训练构建5折模型特征集,这里比较耗时\n",
- "stack_ds = pipeline.stack(k=5,stratify=False,seed=42,full_test=False) # full_test指明预测全部还是预测当前折的验证集"
+ "stack_ds = pipeline.stack(k=5,seed=42)"
]
},
{
"cell_type": "code",
- "execution_count": 56,
+ "execution_count": 26,
"id": "b25bba3c",
"metadata": {},
"outputs": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- " xgb_0 xgb_1 xgb_2 xgb_3 xgb_4 \\\n",
- "0 0.177179 0.818728 2.185222e-07 9.264143e-09 4.090067e-03 \n",
- "1 0.005155 0.994845 7.055579e-10 1.326343e-08 6.331572e-09 \n",
- "2 0.293492 0.706508 3.650662e-10 1.017633e-09 8.823530e-09 \n",
- "3 0.478112 0.521816 3.207779e-06 2.878019e-08 1.076500e-08 \n",
- "4 0.992430 0.006652 1.233117e-05 1.887496e-07 1.569583e-06 \n",
- "... ... ... ... ... ... \n",
- "435754 0.988518 0.011477 3.190797e-09 5.645121e-08 2.940739e-09 \n",
- "435755 0.969212 0.030723 2.142020e-08 1.572054e-05 4.321913e-07 \n",
- "435756 0.415850 0.584142 4.283793e-08 7.367601e-08 6.148067e-07 \n",
- "435757 0.602601 0.397399 6.606462e-10 1.015894e-09 7.221973e-08 \n",
- "435758 0.834587 0.165411 3.267833e-09 2.057172e-08 2.078704e-08 \n",
- "\n",
- " xgb_5 xgb_6 lgb_0 lgb_1 lgb_2 ... \\\n",
- "0 1.725062e-06 1.048052e-06 0.172406 0.812678 1.416886e-06 ... \n",
- "1 1.435787e-09 1.603579e-10 0.008114 0.991886 0.000000e+00 ... \n",
- "2 6.384080e-10 2.823794e-08 0.817627 0.182372 0.000000e+00 ... \n",
- "3 2.230641e-06 6.630235e-05 0.465733 0.534184 0.000000e+00 ... \n",
- "4 5.604260e-07 9.037877e-04 0.932050 0.043451 0.000000e+00 ... \n",
- "... ... ... ... ... ... ... \n",
- "435754 1.530261e-08 4.466830e-06 0.970593 0.029399 0.000000e+00 ... \n",
- "435755 5.574208e-10 4.977021e-05 0.862591 0.136644 0.000000e+00 ... \n",
- "435756 2.371389e-06 5.185283e-06 0.466886 0.533039 0.000000e+00 ... \n",
- "435757 2.326313e-09 9.193871e-09 0.674250 0.325750 4.092880e-211 ... \n",
- "435758 5.976972e-08 2.204258e-06 0.709320 0.290680 0.000000e+00 ... \n",
- "\n",
- " lgb_4 lgb_5 lgb_6 rf_0 rf_1 rf_2 rf_3 \\\n",
- "0 1.486358e-02 5.093522e-05 6.805300e-08 0.06 0.92 0.0 0.0 \n",
- "1 0.000000e+00 0.000000e+00 0.000000e+00 0.12 0.88 0.0 0.0 \n",
- "2 4.452850e-07 7.825338e-09 1.012052e-07 0.63 0.37 0.0 0.0 \n",
- "3 0.000000e+00 0.000000e+00 8.245405e-05 0.56 0.44 0.0 0.0 \n",
- "4 0.000000e+00 0.000000e+00 2.449972e-02 0.95 0.04 0.0 0.0 \n",
- "... ... ... ... ... ... ... ... \n",
- "435754 0.000000e+00 0.000000e+00 7.871809e-06 0.97 0.03 0.0 0.0 \n",
- "435755 0.000000e+00 0.000000e+00 7.647430e-04 0.93 0.06 0.0 0.0 \n",
- "435756 0.000000e+00 0.000000e+00 7.493861e-05 0.45 0.55 0.0 0.0 \n",
- "435757 0.000000e+00 0.000000e+00 0.000000e+00 0.52 0.48 0.0 0.0 \n",
- "435758 0.000000e+00 0.000000e+00 0.000000e+00 0.87 0.13 0.0 0.0 \n",
- "\n",
- " rf_4 rf_5 rf_6 \n",
- "0 0.02 0.0 0.00 \n",
- "1 0.00 0.0 0.00 \n",
- "2 0.00 0.0 0.00 \n",
- "3 0.00 0.0 0.00 \n",
- "4 0.00 0.0 0.01 \n",
- "... ... ... ... \n",
- "435754 0.00 0.0 0.00 \n",
- "435755 0.00 0.0 0.01 \n",
- "435756 0.00 0.0 0.00 \n",
- "435757 0.00 0.0 0.00 \n",
- "435758 0.00 0.0 0.00 \n",
- "\n",
- "[435759 rows x 21 columns]\n"
- ]
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " xgb_0 | \n",
+ " xgb_1 | \n",
+ " xgb_2 | \n",
+ " xgb_3 | \n",
+ " xgb_4 | \n",
+ " xgb_5 | \n",
+ " xgb_6 | \n",
+ " lgb_0 | \n",
+ " lgb_1 | \n",
+ " lgb_2 | \n",
+ " ... | \n",
+ " lgb_4 | \n",
+ " lgb_5 | \n",
+ " lgb_6 | \n",
+ " rf_0 | \n",
+ " rf_1 | \n",
+ " rf_2 | \n",
+ " rf_3 | \n",
+ " rf_4 | \n",
+ " rf_5 | \n",
+ " rf_6 | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 0.177179 | \n",
+ " 0.818728 | \n",
+ " 2.185222e-07 | \n",
+ " 9.264143e-09 | \n",
+ " 4.090067e-03 | \n",
+ " 1.725062e-06 | \n",
+ " 1.048052e-06 | \n",
+ " 0.179625 | \n",
+ " 0.804684 | \n",
+ " 0.000003 | \n",
+ " ... | \n",
+ " 1.562435e-02 | \n",
+ " 6.370849e-05 | \n",
+ " 1.004259e-08 | \n",
+ " 0.03 | \n",
+ " 0.96 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.01 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 0.005155 | \n",
+ " 0.994845 | \n",
+ " 7.055579e-10 | \n",
+ " 1.326343e-08 | \n",
+ " 6.331572e-09 | \n",
+ " 1.435787e-09 | \n",
+ " 1.603579e-10 | \n",
+ " 0.008114 | \n",
+ " 0.991886 | \n",
+ " 0.000000 | \n",
+ " ... | \n",
+ " 0.000000e+00 | \n",
+ " 0.000000e+00 | \n",
+ " 0.000000e+00 | \n",
+ " 0.13 | \n",
+ " 0.87 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.00 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.293492 | \n",
+ " 0.706508 | \n",
+ " 3.650662e-10 | \n",
+ " 1.017633e-09 | \n",
+ " 8.823530e-09 | \n",
+ " 6.384080e-10 | \n",
+ " 2.823794e-08 | \n",
+ " 0.831445 | \n",
+ " 0.168555 | \n",
+ " 0.000000 | \n",
+ " ... | \n",
+ " 4.999034e-07 | \n",
+ " 4.015190e-09 | \n",
+ " 4.997854e-09 | \n",
+ " 0.63 | \n",
+ " 0.37 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.00 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.478112 | \n",
+ " 0.521816 | \n",
+ " 3.207779e-06 | \n",
+ " 2.878019e-08 | \n",
+ " 1.076500e-08 | \n",
+ " 2.230641e-06 | \n",
+ " 6.630235e-05 | \n",
+ " 0.465733 | \n",
+ " 0.534184 | \n",
+ " 0.000000 | \n",
+ " ... | \n",
+ " 0.000000e+00 | \n",
+ " 0.000000e+00 | \n",
+ " 8.245405e-05 | \n",
+ " 0.55 | \n",
+ " 0.45 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.00 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.992430 | \n",
+ " 0.006652 | \n",
+ " 1.233117e-05 | \n",
+ " 1.887496e-07 | \n",
+ " 1.569583e-06 | \n",
+ " 5.604260e-07 | \n",
+ " 9.037877e-04 | \n",
+ " 0.932050 | \n",
+ " 0.043451 | \n",
+ " 0.000000 | \n",
+ " ... | \n",
+ " 0.000000e+00 | \n",
+ " 0.000000e+00 | \n",
+ " 2.449972e-02 | \n",
+ " 0.97 | \n",
+ " 0.03 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.00 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
5 rows × 21 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " xgb_0 xgb_1 xgb_2 xgb_3 xgb_4 xgb_5 \\\n",
+ "0 0.177179 0.818728 2.185222e-07 9.264143e-09 4.090067e-03 1.725062e-06 \n",
+ "1 0.005155 0.994845 7.055579e-10 1.326343e-08 6.331572e-09 1.435787e-09 \n",
+ "2 0.293492 0.706508 3.650662e-10 1.017633e-09 8.823530e-09 6.384080e-10 \n",
+ "3 0.478112 0.521816 3.207779e-06 2.878019e-08 1.076500e-08 2.230641e-06 \n",
+ "4 0.992430 0.006652 1.233117e-05 1.887496e-07 1.569583e-06 5.604260e-07 \n",
+ "\n",
+ " xgb_6 lgb_0 lgb_1 lgb_2 ... lgb_4 \\\n",
+ "0 1.048052e-06 0.179625 0.804684 0.000003 ... 1.562435e-02 \n",
+ "1 1.603579e-10 0.008114 0.991886 0.000000 ... 0.000000e+00 \n",
+ "2 2.823794e-08 0.831445 0.168555 0.000000 ... 4.999034e-07 \n",
+ "3 6.630235e-05 0.465733 0.534184 0.000000 ... 0.000000e+00 \n",
+ "4 9.037877e-04 0.932050 0.043451 0.000000 ... 0.000000e+00 \n",
+ "\n",
+ " lgb_5 lgb_6 rf_0 rf_1 rf_2 rf_3 rf_4 rf_5 rf_6 \n",
+ "0 6.370849e-05 1.004259e-08 0.03 0.96 0.0 0.0 0.01 0.0 0.0 \n",
+ "1 0.000000e+00 0.000000e+00 0.13 0.87 0.0 0.0 0.00 0.0 0.0 \n",
+ "2 4.015190e-09 4.997854e-09 0.63 0.37 0.0 0.0 0.00 0.0 0.0 \n",
+ "3 0.000000e+00 8.245405e-05 0.55 0.45 0.0 0.0 0.00 0.0 0.0 \n",
+ "4 0.000000e+00 2.449972e-02 0.97 0.03 0.0 0.0 0.00 0.0 0.0 \n",
+ "\n",
+ "[5 rows x 21 columns]"
+ ]
+ },
+ "execution_count": 26,
+ "metadata": {},
+ "output_type": "execute_result"
}
],
"source": [
"# 模型输出的训练集,7个特征对应7个标签的预测概率\n",
- "print(stack_ds.X_train)"
+ "stack_ds.X_train.head()"
]
},
{
"cell_type": "code",
- "execution_count": 59,
+ "execution_count": 28,
"id": "835205e9",
"metadata": {},
"outputs": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- " xgb_0 xgb_1 xgb_2 xgb_3 xgb_4 \\\n",
- "0 9.876224e-01 0.000789 2.774616e-06 4.129093e-07 1.311387e-06 \n",
- "1 5.139124e-02 0.929659 1.852793e-03 1.518293e-07 1.692924e-02 \n",
- "2 7.695035e-04 0.973729 6.878623e-04 1.573823e-07 2.408167e-02 \n",
- "3 3.376913e-02 0.966229 2.024872e-07 7.321523e-08 1.071163e-06 \n",
- "4 1.013981e-03 0.998553 3.794874e-06 8.755425e-08 4.243054e-04 \n",
- "... ... ... ... ... ... \n",
- "145248 9.615189e-01 0.038480 6.486028e-08 1.744931e-08 1.069370e-06 \n",
- "145249 3.055384e-02 0.969440 2.475371e-07 5.530033e-08 4.299908e-06 \n",
- "145250 8.224608e-06 0.058361 9.212288e-01 9.705171e-08 5.440121e-05 \n",
- "145251 9.183387e-01 0.081601 5.612090e-08 1.088283e-08 5.225256e-07 \n",
- "145252 9.203915e-07 0.003578 2.372825e-01 1.582836e-06 3.307252e-07 \n",
- "\n",
- " xgb_5 xgb_6 lgb_0 lgb_1 lgb_2 ... \\\n",
- "0 7.851924e-09 1.158422e-02 0.962538 0.004222 9.599869e-23 ... \n",
- "1 1.613036e-04 6.763449e-06 0.070947 0.882463 2.232464e-03 ... \n",
- "2 7.296442e-04 1.884172e-06 0.004029 0.945838 1.014722e-02 ... \n",
- "3 7.227585e-08 1.448203e-09 0.066538 0.933450 1.206630e-06 ... \n",
- "4 4.374311e-06 4.566837e-09 0.001334 0.997391 1.580417e-06 ... \n",
- "... ... ... ... ... ... ... \n",
- "145248 5.049759e-08 4.010809e-07 0.917842 0.082153 2.154302e-17 ... \n",
- "145249 1.255851e-08 1.224208e-06 0.058622 0.941370 1.332795e-12 ... \n",
- "145250 2.034389e-02 3.132630e-06 0.000268 0.083680 8.789707e-01 ... \n",
- "145251 2.566383e-07 5.976933e-05 0.875834 0.123030 2.631276e-12 ... \n",
- "145252 7.591362e-01 6.988637e-08 0.000032 0.037757 2.462795e-01 ... \n",
- "\n",
- " lgb_4 lgb_5 lgb_6 rf_0 rf_1 rf_2 rf_3 \\\n",
- "0 1.726260e-240 0.000000e+00 3.324009e-02 0.984 0.000 0.000 0.0 \n",
- "1 4.396802e-02 3.888647e-04 1.077146e-08 0.106 0.816 0.038 0.0 \n",
- "2 3.329283e-02 6.693269e-03 3.213404e-09 0.008 0.950 0.002 0.0 \n",
- "3 9.908823e-06 1.371448e-07 1.280235e-09 0.078 0.922 0.000 0.0 \n",
- "4 1.273086e-03 7.811306e-07 5.594472e-10 0.004 0.988 0.000 0.0 \n",
- "... ... ... ... ... ... ... ... \n",
- "145248 6.191352e-08 0.000000e+00 4.958509e-06 0.968 0.032 0.000 0.0 \n",
- "145249 7.656931e-06 3.415083e-47 2.271880e-07 0.018 0.972 0.000 0.0 \n",
- "145250 2.052535e-04 3.687570e-02 1.393421e-09 0.000 0.040 0.946 0.0 \n",
- "145251 2.521124e-07 5.749375e-08 1.135236e-03 0.992 0.008 0.000 0.0 \n",
- "145252 2.927400e-06 7.159244e-01 8.608624e-140 0.000 0.018 0.110 0.0 \n",
- "\n",
- " rf_4 rf_5 rf_6 \n",
- "0 0.000 0.000 0.016 \n",
- "1 0.034 0.006 0.000 \n",
- "2 0.032 0.008 0.000 \n",
- "3 0.000 0.000 0.000 \n",
- "4 0.008 0.000 0.000 \n",
- "... ... ... ... \n",
- "145248 0.000 0.000 0.000 \n",
- "145249 0.010 0.000 0.000 \n",
- "145250 0.000 0.014 0.000 \n",
- "145251 0.000 0.000 0.000 \n",
- "145252 0.000 0.872 0.000 \n",
- "\n",
- "[145253 rows x 21 columns]\n"
- ]
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " xgb_0 | \n",
+ " xgb_1 | \n",
+ " xgb_2 | \n",
+ " xgb_3 | \n",
+ " xgb_4 | \n",
+ " xgb_5 | \n",
+ " xgb_6 | \n",
+ " lgb_0 | \n",
+ " lgb_1 | \n",
+ " lgb_2 | \n",
+ " ... | \n",
+ " lgb_4 | \n",
+ " lgb_5 | \n",
+ " lgb_6 | \n",
+ " rf_0 | \n",
+ " rf_1 | \n",
+ " rf_2 | \n",
+ " rf_3 | \n",
+ " rf_4 | \n",
+ " rf_5 | \n",
+ " rf_6 | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 0.991493 | \n",
+ " 0.000210 | \n",
+ " 4.796058e-06 | \n",
+ " 6.178684e-08 | \n",
+ " 6.947614e-07 | \n",
+ " 2.410490e-09 | \n",
+ " 8.291459e-03 | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ " ... | \n",
+ " 0.000000e+00 | \n",
+ " 0.000000 | \n",
+ " 1.000000e+00 | \n",
+ " 0.99 | \n",
+ " 0.00 | \n",
+ " 0.00 | \n",
+ " 0.0 | \n",
+ " 0.00 | \n",
+ " 0.00 | \n",
+ " 0.01 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 0.024731 | \n",
+ " 0.964372 | \n",
+ " 6.387765e-04 | \n",
+ " 4.205048e-08 | \n",
+ " 1.006575e-02 | \n",
+ " 1.879628e-04 | \n",
+ " 4.830114e-06 | \n",
+ " 0.065073 | \n",
+ " 0.877354 | \n",
+ " 0.001678 | \n",
+ " ... | \n",
+ " 5.530599e-02 | \n",
+ " 0.000589 | \n",
+ " 8.601484e-11 | \n",
+ " 0.09 | \n",
+ " 0.80 | \n",
+ " 0.05 | \n",
+ " 0.0 | \n",
+ " 0.06 | \n",
+ " 0.00 | \n",
+ " 0.00 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.000780 | \n",
+ " 0.979776 | \n",
+ " 8.593459e-04 | \n",
+ " 1.267791e-07 | \n",
+ " 1.710379e-02 | \n",
+ " 1.477527e-03 | \n",
+ " 2.521008e-06 | \n",
+ " 0.005164 | \n",
+ " 0.933849 | \n",
+ " 0.016355 | \n",
+ " ... | \n",
+ " 3.657553e-02 | \n",
+ " 0.008057 | \n",
+ " 0.000000e+00 | \n",
+ " 0.01 | \n",
+ " 0.97 | \n",
+ " 0.00 | \n",
+ " 0.0 | \n",
+ " 0.01 | \n",
+ " 0.01 | \n",
+ " 0.00 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.042695 | \n",
+ " 0.957304 | \n",
+ " 2.283268e-08 | \n",
+ " 4.387427e-08 | \n",
+ " 4.175481e-07 | \n",
+ " 4.406019e-08 | \n",
+ " 6.909629e-10 | \n",
+ " 0.054392 | \n",
+ " 0.945608 | \n",
+ " 0.000000 | \n",
+ " ... | \n",
+ " 4.285638e-08 | \n",
+ " 0.000000 | \n",
+ " 0.000000e+00 | \n",
+ " 0.04 | \n",
+ " 0.96 | \n",
+ " 0.00 | \n",
+ " 0.0 | \n",
+ " 0.00 | \n",
+ " 0.00 | \n",
+ " 0.00 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.000457 | \n",
+ " 0.999334 | \n",
+ " 3.366338e-06 | \n",
+ " 4.893879e-08 | \n",
+ " 2.045808e-04 | \n",
+ " 7.889498e-07 | \n",
+ " 1.415576e-09 | \n",
+ " 0.001367 | \n",
+ " 0.995857 | \n",
+ " 0.000000 | \n",
+ " ... | \n",
+ " 2.776106e-03 | \n",
+ " 0.000000 | \n",
+ " 0.000000e+00 | \n",
+ " 0.00 | \n",
+ " 0.99 | \n",
+ " 0.00 | \n",
+ " 0.0 | \n",
+ " 0.01 | \n",
+ " 0.00 | \n",
+ " 0.00 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
5 rows × 21 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " xgb_0 xgb_1 xgb_2 xgb_3 xgb_4 xgb_5 \\\n",
+ "0 0.991493 0.000210 4.796058e-06 6.178684e-08 6.947614e-07 2.410490e-09 \n",
+ "1 0.024731 0.964372 6.387765e-04 4.205048e-08 1.006575e-02 1.879628e-04 \n",
+ "2 0.000780 0.979776 8.593459e-04 1.267791e-07 1.710379e-02 1.477527e-03 \n",
+ "3 0.042695 0.957304 2.283268e-08 4.387427e-08 4.175481e-07 4.406019e-08 \n",
+ "4 0.000457 0.999334 3.366338e-06 4.893879e-08 2.045808e-04 7.889498e-07 \n",
+ "\n",
+ " xgb_6 lgb_0 lgb_1 lgb_2 ... lgb_4 lgb_5 \\\n",
+ "0 8.291459e-03 0.000000 0.000000 0.000000 ... 0.000000e+00 0.000000 \n",
+ "1 4.830114e-06 0.065073 0.877354 0.001678 ... 5.530599e-02 0.000589 \n",
+ "2 2.521008e-06 0.005164 0.933849 0.016355 ... 3.657553e-02 0.008057 \n",
+ "3 6.909629e-10 0.054392 0.945608 0.000000 ... 4.285638e-08 0.000000 \n",
+ "4 1.415576e-09 0.001367 0.995857 0.000000 ... 2.776106e-03 0.000000 \n",
+ "\n",
+ " lgb_6 rf_0 rf_1 rf_2 rf_3 rf_4 rf_5 rf_6 \n",
+ "0 1.000000e+00 0.99 0.00 0.00 0.0 0.00 0.00 0.01 \n",
+ "1 8.601484e-11 0.09 0.80 0.05 0.0 0.06 0.00 0.00 \n",
+ "2 0.000000e+00 0.01 0.97 0.00 0.0 0.01 0.01 0.00 \n",
+ "3 0.000000e+00 0.04 0.96 0.00 0.0 0.00 0.00 0.00 \n",
+ "4 0.000000e+00 0.00 0.99 0.00 0.0 0.01 0.00 0.00 \n",
+ "\n",
+ "[5 rows x 21 columns]"
+ ]
+ },
+ "execution_count": 28,
+ "metadata": {},
+ "output_type": "execute_result"
}
],
"source": [
"# 模型输出的测试集,7个特征对应7个标签的预测概率\n",
- "print(stack_ds.X_test) "
+ "stack_ds.X_test.head()"
]
},
{
"cell_type": "code",
- "execution_count": 60,
+ "execution_count": 30,
"id": "b9db35dc",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Wall time: 1min 53s\n"
+ ]
+ }
+ ],
"source": [
+ "%%time\n",
"# 用lr做最后一层\n",
"stacker = Classifier(dataset=stack_ds, estimator=LogisticRegression, parameters={\"solver\": 'lbfgs', \"max_iter\": 1000},use_cache=False)\n",
"predict_stack = stacker.predict()"
@@ -507,7 +807,7 @@
},
{
"cell_type": "code",
- "execution_count": 64,
+ "execution_count": 32,
"id": "a4a48219",
"metadata": {},
"outputs": [
@@ -515,19 +815,19 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "[[9.95173402e-01 2.67623709e-03 4.23846755e-08 ... 3.15435935e-05\n",
- " 5.66194220e-06 2.11044140e-03]\n",
- " [2.23612439e-02 9.70927685e-01 1.23929922e-03 ... 4.49727904e-03\n",
- " 8.73983383e-04 9.97020226e-05]\n",
- " [6.22588197e-03 9.89402233e-01 9.81655972e-04 ... 2.83331258e-03\n",
- " 5.22139184e-04 3.45071569e-05]\n",
+ "[[9.99137967e-01 4.76574513e-04 3.32764186e-07 ... 3.03237700e-06\n",
+ " 1.62161211e-06 3.79326620e-04]\n",
+ " [2.00732175e-02 9.71682830e-01 1.51484266e-03 ... 5.69138554e-03\n",
+ " 9.40725428e-04 9.60822625e-05]\n",
+ " [5.54556002e-03 9.91048437e-01 8.04840682e-04 ... 2.11437934e-03\n",
+ " 4.56463787e-04 3.00919502e-05]\n",
" ...\n",
- " [5.36335125e-06 2.06267200e-03 9.90604140e-01 ... 8.55252386e-04\n",
- " 4.18405061e-03 1.64678945e-05]\n",
- " [9.96602824e-01 2.15991442e-03 7.27481581e-08 ... 3.63552051e-05\n",
- " 6.80942632e-06 1.19199377e-03]\n",
- " [5.89156494e-05 1.15333400e-03 1.09178439e-02 ... 3.09244417e-04\n",
- " 9.85167196e-01 2.21261408e-05]]\n"
+ " [4.60179790e-06 1.78298095e-03 9.91553958e-01 ... 7.26752933e-04\n",
+ " 3.79135124e-03 1.53584401e-05]\n",
+ " [9.96307096e-01 2.43558944e-03 1.01596361e-07 ... 3.94985596e-05\n",
+ " 7.41569805e-06 1.20819024e-03]\n",
+ " [5.34671504e-05 7.62534718e-04 5.58323657e-03 ... 2.11410908e-04\n",
+ " 9.91805379e-01 1.69502656e-05]]\n"
]
}
],
@@ -553,7 +853,7 @@
},
{
"cell_type": "code",
- "execution_count": 65,
+ "execution_count": 33,
"id": "a28806a0",
"metadata": {},
"outputs": [
@@ -561,9 +861,9 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "0.9284696357390209\n",
- "0.8890005714167694\n",
- "0.9511404239499356\n"
+ "0.9254473229468583\n",
+ "0.8412562907478676\n",
+ "0.9535087055000585\n"
]
}
],
@@ -573,6 +873,28 @@
"print(accuracy_score(np.argmax(stack_ds.X_test.iloc[:, 14:].values, axis=1),y_test)) # RF"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": 35,
+ "id": "1db92fec",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "0.9537840870756542\n",
+ "Wall time: 50.8 s\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "# 测试单模运行结果是否一致\n",
+ "rf_predict = rf_model(X_train, y_train, X_test, None)\n",
+ "print(accuracy_score(np.argmax(rf_predict, axis=1),y_test)) # RF"
+ ]
+ },
{
"cell_type": "markdown",
"id": "2e9423ce",
@@ -583,7 +905,7 @@
},
{
"cell_type": "code",
- "execution_count": 70,
+ "execution_count": 37,
"id": "d2b50ba4",
"metadata": {},
"outputs": [
@@ -591,8 +913,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "主观根据结果blending: 0.9425209806337908\n",
- "根据最优权重的blending: 0.9488616414118813\n"
+ "主观根据结果blending: 0.9504106627746071\n",
+ "根据最优权重的线性加权: 0.9530749795184953\n"
]
}
],
@@ -603,19 +925,74 @@
"rf_t = stack_ds.X_test.iloc[:, 14:].values\n",
"\n",
"# 根据分数好坏随机定\n",
- "result = 0.3*xgb_t+0.2*lgb_t+0.5*rf_t\n",
+ "result = 0.2*xgb_t+0.1*lgb_t+0.7*rf_t\n",
"print('主观根据结果blending:', accuracy_score(np.argmax(result, axis=1), y_test))\n",
- "# 根据上面提供的最优权重 Best Weights: [0.36556831 0.00303401 0.63139768]\n",
- "result = 0.36556831*xgb_t+0.00303401*lgb_t+0.63139768*rf_t\n",
- "print('根据最优权重的blending:',accuracy_score(np.argmax(result, axis=1), y_test))"
+ "# 根据上面提供的最优权重 Best Weights: [2.53464919e-01 1.48562205e-20 7.46535081e-01]\n",
+ "result = 2.53464919e-01*xgb_t+1.48562205e-20*lgb_t+7.46535081e-01*rf_t\n",
+ "print('根据最优权重的线性加权:',accuracy_score(np.argmax(result, axis=1), y_test))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e8ce00b2",
+ "metadata": {},
+ "source": [
+ "可以观察到最优权重比我们主观选权重更优,但对比单模结果反而下降了"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d4a12d6d",
+ "metadata": {},
+ "source": [
+ "### Blending的分数"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "ffe1cf4d",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Wall time: 14min 10s\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "blend_ds = pipeline.blend(seed=111)\n",
+ "blender = Classifier(dataset=blend_ds, estimator=LogisticRegression, parameters={\"solver\": 'lbfgs', \"max_iter\": 1000},use_cache=False)\n",
+ "predict_blend = blender.predict()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "506adffb",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "0.9546859617357301\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(accuracy_score(np.argmax(predict_blend, axis=1), y_test))"
]
},
{
"cell_type": "markdown",
- "id": "dfec8968",
+ "id": "e84c19ae",
"metadata": {},
"source": [
- "可以观察到最优权重比我们主观选权重更优"
+ "使用Blending的分数有所提升"
]
},
{
@@ -623,12 +1000,12 @@
"id": "e8daf1e3",
"metadata": {},
"source": [
- "### stacking的分数"
+ "### Stacking的分数"
]
},
{
"cell_type": "code",
- "execution_count": 71,
+ "execution_count": 38,
"id": "4930b407",
"metadata": {},
"outputs": [
@@ -636,7 +1013,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "0.957439777491687\n"
+ "0.9589887988544127\n"
]
}
],
@@ -644,12 +1021,20 @@
"print(accuracy_score(np.argmax(predict_stack, axis=1), y_test))"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "d5ad3c4b",
+ "metadata": {},
+ "source": [
+ "可以明显看到提升的效果"
+ ]
+ },
{
"cell_type": "markdown",
"id": "6311fbd7",
"metadata": {},
"source": [
- "## 再说结论,该数据集(fetch_covtype)Stacking的方法更好"
+ "## 再说结论,该数据集(fetch_covtype)Stacking的方法相比Blending和线性加权更好"
]
}
],
diff --git a/竞赛优胜技巧/Stacking.ipynb b/竞赛优胜技巧/Stacking.ipynb
index ef7f238..4ae86fe 100644
--- a/竞赛优胜技巧/Stacking.ipynb
+++ b/竞赛优胜技巧/Stacking.ipynb
@@ -13,7 +13,7 @@
"id": "0b365a02",
"metadata": {},
"source": [
- "## 先说结论,该数据集(fetch_covtype)Stacking的方法比线性加权更好\n",
+ "## 先说结论,该数据集(fetch_covtype)Stacking的方法相比Blending和线性加权更好\n",
"比赛中我们常用线性加权作为最终的融合方式,我们同样也会好奇怎样的线性加权权重更好,下面也会举例子\n",
"参考:https://github.com/rushter/heamy/tree/master/examples"
]
@@ -67,7 +67,7 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 1,
"id": "69632c6a",
"metadata": {},
"outputs": [
@@ -86,7 +86,7 @@
},
{
"cell_type": "code",
- "execution_count": 61,
+ "execution_count": 2,
"id": "ca421279",
"metadata": {},
"outputs": [],
@@ -119,7 +119,7 @@
},
{
"cell_type": "code",
- "execution_count": 15,
+ "execution_count": 3,
"id": "9a0fabe1",
"metadata": {},
"outputs": [],
@@ -130,7 +130,7 @@
},
{
"cell_type": "code",
- "execution_count": 47,
+ "execution_count": 6,
"id": "5bd75178",
"metadata": {},
"outputs": [
@@ -153,14 +153,14 @@
"print(y)\n",
"ord = OrdinalEncoder()\n",
"y = ord.fit_transform(y.reshape(-1, 1))\n",
- "y = y_enc.reshape(-1, )\n",
+ "y = y.reshape(-1, )\n",
"print('七分类任务,处理后:',np.unique(y))\n",
"print(y)"
]
},
{
"cell_type": "code",
- "execution_count": 48,
+ "execution_count": 7,
"id": "23d9778c",
"metadata": {},
"outputs": [
@@ -182,7 +182,7 @@
},
{
"cell_type": "code",
- "execution_count": 49,
+ "execution_count": 8,
"id": "eac48668",
"metadata": {},
"outputs": [
@@ -204,7 +204,7 @@
},
{
"cell_type": "code",
- "execution_count": 50,
+ "execution_count": 9,
"id": "fba3f975",
"metadata": {},
"outputs": [
@@ -226,7 +226,7 @@
" 0.000e+00]])"
]
},
- "execution_count": 50,
+ "execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
@@ -246,7 +246,7 @@
},
{
"cell_type": "code",
- "execution_count": 40,
+ "execution_count": 10,
"id": "e8393e73",
"metadata": {},
"outputs": [],
@@ -293,7 +293,7 @@
},
{
"cell_type": "code",
- "execution_count": 52,
+ "execution_count": 11,
"id": "78ab0083",
"metadata": {},
"outputs": [],
@@ -313,7 +313,7 @@
},
{
"cell_type": "code",
- "execution_count": 53,
+ "execution_count": 13,
"id": "173ef0f0",
"metadata": {},
"outputs": [
@@ -321,185 +321,485 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Best Score (log_loss): 0.18744137777851164\n",
- "Best Weights: [0.36556831 0.00303401 0.63139768]\n"
+ "Best Score (log_loss): 0.18646435443714865\n",
+ "Best Weights: [2.53464919e-01 1.48562205e-20 7.46535081e-01]\n",
+ "Wall time: 14min 19s\n"
]
},
{
"data": {
"text/plain": [
- "array([0.36556831, 0.00303401, 0.63139768])"
+ "array([2.53464919e-01, 1.48562205e-20, 7.46535081e-01])"
]
},
- "execution_count": 53,
+ "execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "pipeline.find_weights(scorer=log_loss, ) # 输出最优权重组合"
+ "%%time\n",
+ "pipeline.find_weights(scorer=log_loss) # 输出最优权重组合"
]
},
{
"cell_type": "code",
- "execution_count": 55,
+ "execution_count": 22,
"id": "80726d19",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Wall time: 1h 39min 20s\n"
+ ]
+ }
+ ],
"source": [
+ "%%time\n",
"# 5折训练构建5折模型特征集,这里比较耗时\n",
- "stack_ds = pipeline.stack(k=5,stratify=False,seed=42,full_test=False) # full_test指明预测全部还是预测当前折的验证集"
+ "stack_ds = pipeline.stack(k=5,seed=42)"
]
},
{
"cell_type": "code",
- "execution_count": 56,
+ "execution_count": 26,
"id": "b25bba3c",
"metadata": {},
"outputs": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- " xgb_0 xgb_1 xgb_2 xgb_3 xgb_4 \\\n",
- "0 0.177179 0.818728 2.185222e-07 9.264143e-09 4.090067e-03 \n",
- "1 0.005155 0.994845 7.055579e-10 1.326343e-08 6.331572e-09 \n",
- "2 0.293492 0.706508 3.650662e-10 1.017633e-09 8.823530e-09 \n",
- "3 0.478112 0.521816 3.207779e-06 2.878019e-08 1.076500e-08 \n",
- "4 0.992430 0.006652 1.233117e-05 1.887496e-07 1.569583e-06 \n",
- "... ... ... ... ... ... \n",
- "435754 0.988518 0.011477 3.190797e-09 5.645121e-08 2.940739e-09 \n",
- "435755 0.969212 0.030723 2.142020e-08 1.572054e-05 4.321913e-07 \n",
- "435756 0.415850 0.584142 4.283793e-08 7.367601e-08 6.148067e-07 \n",
- "435757 0.602601 0.397399 6.606462e-10 1.015894e-09 7.221973e-08 \n",
- "435758 0.834587 0.165411 3.267833e-09 2.057172e-08 2.078704e-08 \n",
- "\n",
- " xgb_5 xgb_6 lgb_0 lgb_1 lgb_2 ... \\\n",
- "0 1.725062e-06 1.048052e-06 0.172406 0.812678 1.416886e-06 ... \n",
- "1 1.435787e-09 1.603579e-10 0.008114 0.991886 0.000000e+00 ... \n",
- "2 6.384080e-10 2.823794e-08 0.817627 0.182372 0.000000e+00 ... \n",
- "3 2.230641e-06 6.630235e-05 0.465733 0.534184 0.000000e+00 ... \n",
- "4 5.604260e-07 9.037877e-04 0.932050 0.043451 0.000000e+00 ... \n",
- "... ... ... ... ... ... ... \n",
- "435754 1.530261e-08 4.466830e-06 0.970593 0.029399 0.000000e+00 ... \n",
- "435755 5.574208e-10 4.977021e-05 0.862591 0.136644 0.000000e+00 ... \n",
- "435756 2.371389e-06 5.185283e-06 0.466886 0.533039 0.000000e+00 ... \n",
- "435757 2.326313e-09 9.193871e-09 0.674250 0.325750 4.092880e-211 ... \n",
- "435758 5.976972e-08 2.204258e-06 0.709320 0.290680 0.000000e+00 ... \n",
- "\n",
- " lgb_4 lgb_5 lgb_6 rf_0 rf_1 rf_2 rf_3 \\\n",
- "0 1.486358e-02 5.093522e-05 6.805300e-08 0.06 0.92 0.0 0.0 \n",
- "1 0.000000e+00 0.000000e+00 0.000000e+00 0.12 0.88 0.0 0.0 \n",
- "2 4.452850e-07 7.825338e-09 1.012052e-07 0.63 0.37 0.0 0.0 \n",
- "3 0.000000e+00 0.000000e+00 8.245405e-05 0.56 0.44 0.0 0.0 \n",
- "4 0.000000e+00 0.000000e+00 2.449972e-02 0.95 0.04 0.0 0.0 \n",
- "... ... ... ... ... ... ... ... \n",
- "435754 0.000000e+00 0.000000e+00 7.871809e-06 0.97 0.03 0.0 0.0 \n",
- "435755 0.000000e+00 0.000000e+00 7.647430e-04 0.93 0.06 0.0 0.0 \n",
- "435756 0.000000e+00 0.000000e+00 7.493861e-05 0.45 0.55 0.0 0.0 \n",
- "435757 0.000000e+00 0.000000e+00 0.000000e+00 0.52 0.48 0.0 0.0 \n",
- "435758 0.000000e+00 0.000000e+00 0.000000e+00 0.87 0.13 0.0 0.0 \n",
- "\n",
- " rf_4 rf_5 rf_6 \n",
- "0 0.02 0.0 0.00 \n",
- "1 0.00 0.0 0.00 \n",
- "2 0.00 0.0 0.00 \n",
- "3 0.00 0.0 0.00 \n",
- "4 0.00 0.0 0.01 \n",
- "... ... ... ... \n",
- "435754 0.00 0.0 0.00 \n",
- "435755 0.00 0.0 0.01 \n",
- "435756 0.00 0.0 0.00 \n",
- "435757 0.00 0.0 0.00 \n",
- "435758 0.00 0.0 0.00 \n",
- "\n",
- "[435759 rows x 21 columns]\n"
- ]
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " xgb_0 | \n",
+ " xgb_1 | \n",
+ " xgb_2 | \n",
+ " xgb_3 | \n",
+ " xgb_4 | \n",
+ " xgb_5 | \n",
+ " xgb_6 | \n",
+ " lgb_0 | \n",
+ " lgb_1 | \n",
+ " lgb_2 | \n",
+ " ... | \n",
+ " lgb_4 | \n",
+ " lgb_5 | \n",
+ " lgb_6 | \n",
+ " rf_0 | \n",
+ " rf_1 | \n",
+ " rf_2 | \n",
+ " rf_3 | \n",
+ " rf_4 | \n",
+ " rf_5 | \n",
+ " rf_6 | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 0.177179 | \n",
+ " 0.818728 | \n",
+ " 2.185222e-07 | \n",
+ " 9.264143e-09 | \n",
+ " 4.090067e-03 | \n",
+ " 1.725062e-06 | \n",
+ " 1.048052e-06 | \n",
+ " 0.179625 | \n",
+ " 0.804684 | \n",
+ " 0.000003 | \n",
+ " ... | \n",
+ " 1.562435e-02 | \n",
+ " 6.370849e-05 | \n",
+ " 1.004259e-08 | \n",
+ " 0.03 | \n",
+ " 0.96 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.01 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 0.005155 | \n",
+ " 0.994845 | \n",
+ " 7.055579e-10 | \n",
+ " 1.326343e-08 | \n",
+ " 6.331572e-09 | \n",
+ " 1.435787e-09 | \n",
+ " 1.603579e-10 | \n",
+ " 0.008114 | \n",
+ " 0.991886 | \n",
+ " 0.000000 | \n",
+ " ... | \n",
+ " 0.000000e+00 | \n",
+ " 0.000000e+00 | \n",
+ " 0.000000e+00 | \n",
+ " 0.13 | \n",
+ " 0.87 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.00 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.293492 | \n",
+ " 0.706508 | \n",
+ " 3.650662e-10 | \n",
+ " 1.017633e-09 | \n",
+ " 8.823530e-09 | \n",
+ " 6.384080e-10 | \n",
+ " 2.823794e-08 | \n",
+ " 0.831445 | \n",
+ " 0.168555 | \n",
+ " 0.000000 | \n",
+ " ... | \n",
+ " 4.999034e-07 | \n",
+ " 4.015190e-09 | \n",
+ " 4.997854e-09 | \n",
+ " 0.63 | \n",
+ " 0.37 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.00 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.478112 | \n",
+ " 0.521816 | \n",
+ " 3.207779e-06 | \n",
+ " 2.878019e-08 | \n",
+ " 1.076500e-08 | \n",
+ " 2.230641e-06 | \n",
+ " 6.630235e-05 | \n",
+ " 0.465733 | \n",
+ " 0.534184 | \n",
+ " 0.000000 | \n",
+ " ... | \n",
+ " 0.000000e+00 | \n",
+ " 0.000000e+00 | \n",
+ " 8.245405e-05 | \n",
+ " 0.55 | \n",
+ " 0.45 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.00 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.992430 | \n",
+ " 0.006652 | \n",
+ " 1.233117e-05 | \n",
+ " 1.887496e-07 | \n",
+ " 1.569583e-06 | \n",
+ " 5.604260e-07 | \n",
+ " 9.037877e-04 | \n",
+ " 0.932050 | \n",
+ " 0.043451 | \n",
+ " 0.000000 | \n",
+ " ... | \n",
+ " 0.000000e+00 | \n",
+ " 0.000000e+00 | \n",
+ " 2.449972e-02 | \n",
+ " 0.97 | \n",
+ " 0.03 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.00 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
5 rows × 21 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " xgb_0 xgb_1 xgb_2 xgb_3 xgb_4 xgb_5 \\\n",
+ "0 0.177179 0.818728 2.185222e-07 9.264143e-09 4.090067e-03 1.725062e-06 \n",
+ "1 0.005155 0.994845 7.055579e-10 1.326343e-08 6.331572e-09 1.435787e-09 \n",
+ "2 0.293492 0.706508 3.650662e-10 1.017633e-09 8.823530e-09 6.384080e-10 \n",
+ "3 0.478112 0.521816 3.207779e-06 2.878019e-08 1.076500e-08 2.230641e-06 \n",
+ "4 0.992430 0.006652 1.233117e-05 1.887496e-07 1.569583e-06 5.604260e-07 \n",
+ "\n",
+ " xgb_6 lgb_0 lgb_1 lgb_2 ... lgb_4 \\\n",
+ "0 1.048052e-06 0.179625 0.804684 0.000003 ... 1.562435e-02 \n",
+ "1 1.603579e-10 0.008114 0.991886 0.000000 ... 0.000000e+00 \n",
+ "2 2.823794e-08 0.831445 0.168555 0.000000 ... 4.999034e-07 \n",
+ "3 6.630235e-05 0.465733 0.534184 0.000000 ... 0.000000e+00 \n",
+ "4 9.037877e-04 0.932050 0.043451 0.000000 ... 0.000000e+00 \n",
+ "\n",
+ " lgb_5 lgb_6 rf_0 rf_1 rf_2 rf_3 rf_4 rf_5 rf_6 \n",
+ "0 6.370849e-05 1.004259e-08 0.03 0.96 0.0 0.0 0.01 0.0 0.0 \n",
+ "1 0.000000e+00 0.000000e+00 0.13 0.87 0.0 0.0 0.00 0.0 0.0 \n",
+ "2 4.015190e-09 4.997854e-09 0.63 0.37 0.0 0.0 0.00 0.0 0.0 \n",
+ "3 0.000000e+00 8.245405e-05 0.55 0.45 0.0 0.0 0.00 0.0 0.0 \n",
+ "4 0.000000e+00 2.449972e-02 0.97 0.03 0.0 0.0 0.00 0.0 0.0 \n",
+ "\n",
+ "[5 rows x 21 columns]"
+ ]
+ },
+ "execution_count": 26,
+ "metadata": {},
+ "output_type": "execute_result"
}
],
"source": [
"# 模型输出的训练集,7个特征对应7个标签的预测概率\n",
- "print(stack_ds.X_train)"
+ "stack_ds.X_train.head()"
]
},
{
"cell_type": "code",
- "execution_count": 59,
+ "execution_count": 28,
"id": "835205e9",
"metadata": {},
"outputs": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- " xgb_0 xgb_1 xgb_2 xgb_3 xgb_4 \\\n",
- "0 9.876224e-01 0.000789 2.774616e-06 4.129093e-07 1.311387e-06 \n",
- "1 5.139124e-02 0.929659 1.852793e-03 1.518293e-07 1.692924e-02 \n",
- "2 7.695035e-04 0.973729 6.878623e-04 1.573823e-07 2.408167e-02 \n",
- "3 3.376913e-02 0.966229 2.024872e-07 7.321523e-08 1.071163e-06 \n",
- "4 1.013981e-03 0.998553 3.794874e-06 8.755425e-08 4.243054e-04 \n",
- "... ... ... ... ... ... \n",
- "145248 9.615189e-01 0.038480 6.486028e-08 1.744931e-08 1.069370e-06 \n",
- "145249 3.055384e-02 0.969440 2.475371e-07 5.530033e-08 4.299908e-06 \n",
- "145250 8.224608e-06 0.058361 9.212288e-01 9.705171e-08 5.440121e-05 \n",
- "145251 9.183387e-01 0.081601 5.612090e-08 1.088283e-08 5.225256e-07 \n",
- "145252 9.203915e-07 0.003578 2.372825e-01 1.582836e-06 3.307252e-07 \n",
- "\n",
- " xgb_5 xgb_6 lgb_0 lgb_1 lgb_2 ... \\\n",
- "0 7.851924e-09 1.158422e-02 0.962538 0.004222 9.599869e-23 ... \n",
- "1 1.613036e-04 6.763449e-06 0.070947 0.882463 2.232464e-03 ... \n",
- "2 7.296442e-04 1.884172e-06 0.004029 0.945838 1.014722e-02 ... \n",
- "3 7.227585e-08 1.448203e-09 0.066538 0.933450 1.206630e-06 ... \n",
- "4 4.374311e-06 4.566837e-09 0.001334 0.997391 1.580417e-06 ... \n",
- "... ... ... ... ... ... ... \n",
- "145248 5.049759e-08 4.010809e-07 0.917842 0.082153 2.154302e-17 ... \n",
- "145249 1.255851e-08 1.224208e-06 0.058622 0.941370 1.332795e-12 ... \n",
- "145250 2.034389e-02 3.132630e-06 0.000268 0.083680 8.789707e-01 ... \n",
- "145251 2.566383e-07 5.976933e-05 0.875834 0.123030 2.631276e-12 ... \n",
- "145252 7.591362e-01 6.988637e-08 0.000032 0.037757 2.462795e-01 ... \n",
- "\n",
- " lgb_4 lgb_5 lgb_6 rf_0 rf_1 rf_2 rf_3 \\\n",
- "0 1.726260e-240 0.000000e+00 3.324009e-02 0.984 0.000 0.000 0.0 \n",
- "1 4.396802e-02 3.888647e-04 1.077146e-08 0.106 0.816 0.038 0.0 \n",
- "2 3.329283e-02 6.693269e-03 3.213404e-09 0.008 0.950 0.002 0.0 \n",
- "3 9.908823e-06 1.371448e-07 1.280235e-09 0.078 0.922 0.000 0.0 \n",
- "4 1.273086e-03 7.811306e-07 5.594472e-10 0.004 0.988 0.000 0.0 \n",
- "... ... ... ... ... ... ... ... \n",
- "145248 6.191352e-08 0.000000e+00 4.958509e-06 0.968 0.032 0.000 0.0 \n",
- "145249 7.656931e-06 3.415083e-47 2.271880e-07 0.018 0.972 0.000 0.0 \n",
- "145250 2.052535e-04 3.687570e-02 1.393421e-09 0.000 0.040 0.946 0.0 \n",
- "145251 2.521124e-07 5.749375e-08 1.135236e-03 0.992 0.008 0.000 0.0 \n",
- "145252 2.927400e-06 7.159244e-01 8.608624e-140 0.000 0.018 0.110 0.0 \n",
- "\n",
- " rf_4 rf_5 rf_6 \n",
- "0 0.000 0.000 0.016 \n",
- "1 0.034 0.006 0.000 \n",
- "2 0.032 0.008 0.000 \n",
- "3 0.000 0.000 0.000 \n",
- "4 0.008 0.000 0.000 \n",
- "... ... ... ... \n",
- "145248 0.000 0.000 0.000 \n",
- "145249 0.010 0.000 0.000 \n",
- "145250 0.000 0.014 0.000 \n",
- "145251 0.000 0.000 0.000 \n",
- "145252 0.000 0.872 0.000 \n",
- "\n",
- "[145253 rows x 21 columns]\n"
- ]
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " xgb_0 | \n",
+ " xgb_1 | \n",
+ " xgb_2 | \n",
+ " xgb_3 | \n",
+ " xgb_4 | \n",
+ " xgb_5 | \n",
+ " xgb_6 | \n",
+ " lgb_0 | \n",
+ " lgb_1 | \n",
+ " lgb_2 | \n",
+ " ... | \n",
+ " lgb_4 | \n",
+ " lgb_5 | \n",
+ " lgb_6 | \n",
+ " rf_0 | \n",
+ " rf_1 | \n",
+ " rf_2 | \n",
+ " rf_3 | \n",
+ " rf_4 | \n",
+ " rf_5 | \n",
+ " rf_6 | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 0.991493 | \n",
+ " 0.000210 | \n",
+ " 4.796058e-06 | \n",
+ " 6.178684e-08 | \n",
+ " 6.947614e-07 | \n",
+ " 2.410490e-09 | \n",
+ " 8.291459e-03 | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ " ... | \n",
+ " 0.000000e+00 | \n",
+ " 0.000000 | \n",
+ " 1.000000e+00 | \n",
+ " 0.99 | \n",
+ " 0.00 | \n",
+ " 0.00 | \n",
+ " 0.0 | \n",
+ " 0.00 | \n",
+ " 0.00 | \n",
+ " 0.01 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 0.024731 | \n",
+ " 0.964372 | \n",
+ " 6.387765e-04 | \n",
+ " 4.205048e-08 | \n",
+ " 1.006575e-02 | \n",
+ " 1.879628e-04 | \n",
+ " 4.830114e-06 | \n",
+ " 0.065073 | \n",
+ " 0.877354 | \n",
+ " 0.001678 | \n",
+ " ... | \n",
+ " 5.530599e-02 | \n",
+ " 0.000589 | \n",
+ " 8.601484e-11 | \n",
+ " 0.09 | \n",
+ " 0.80 | \n",
+ " 0.05 | \n",
+ " 0.0 | \n",
+ " 0.06 | \n",
+ " 0.00 | \n",
+ " 0.00 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.000780 | \n",
+ " 0.979776 | \n",
+ " 8.593459e-04 | \n",
+ " 1.267791e-07 | \n",
+ " 1.710379e-02 | \n",
+ " 1.477527e-03 | \n",
+ " 2.521008e-06 | \n",
+ " 0.005164 | \n",
+ " 0.933849 | \n",
+ " 0.016355 | \n",
+ " ... | \n",
+ " 3.657553e-02 | \n",
+ " 0.008057 | \n",
+ " 0.000000e+00 | \n",
+ " 0.01 | \n",
+ " 0.97 | \n",
+ " 0.00 | \n",
+ " 0.0 | \n",
+ " 0.01 | \n",
+ " 0.01 | \n",
+ " 0.00 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.042695 | \n",
+ " 0.957304 | \n",
+ " 2.283268e-08 | \n",
+ " 4.387427e-08 | \n",
+ " 4.175481e-07 | \n",
+ " 4.406019e-08 | \n",
+ " 6.909629e-10 | \n",
+ " 0.054392 | \n",
+ " 0.945608 | \n",
+ " 0.000000 | \n",
+ " ... | \n",
+ " 4.285638e-08 | \n",
+ " 0.000000 | \n",
+ " 0.000000e+00 | \n",
+ " 0.04 | \n",
+ " 0.96 | \n",
+ " 0.00 | \n",
+ " 0.0 | \n",
+ " 0.00 | \n",
+ " 0.00 | \n",
+ " 0.00 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.000457 | \n",
+ " 0.999334 | \n",
+ " 3.366338e-06 | \n",
+ " 4.893879e-08 | \n",
+ " 2.045808e-04 | \n",
+ " 7.889498e-07 | \n",
+ " 1.415576e-09 | \n",
+ " 0.001367 | \n",
+ " 0.995857 | \n",
+ " 0.000000 | \n",
+ " ... | \n",
+ " 2.776106e-03 | \n",
+ " 0.000000 | \n",
+ " 0.000000e+00 | \n",
+ " 0.00 | \n",
+ " 0.99 | \n",
+ " 0.00 | \n",
+ " 0.0 | \n",
+ " 0.01 | \n",
+ " 0.00 | \n",
+ " 0.00 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
5 rows × 21 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " xgb_0 xgb_1 xgb_2 xgb_3 xgb_4 xgb_5 \\\n",
+ "0 0.991493 0.000210 4.796058e-06 6.178684e-08 6.947614e-07 2.410490e-09 \n",
+ "1 0.024731 0.964372 6.387765e-04 4.205048e-08 1.006575e-02 1.879628e-04 \n",
+ "2 0.000780 0.979776 8.593459e-04 1.267791e-07 1.710379e-02 1.477527e-03 \n",
+ "3 0.042695 0.957304 2.283268e-08 4.387427e-08 4.175481e-07 4.406019e-08 \n",
+ "4 0.000457 0.999334 3.366338e-06 4.893879e-08 2.045808e-04 7.889498e-07 \n",
+ "\n",
+ " xgb_6 lgb_0 lgb_1 lgb_2 ... lgb_4 lgb_5 \\\n",
+ "0 8.291459e-03 0.000000 0.000000 0.000000 ... 0.000000e+00 0.000000 \n",
+ "1 4.830114e-06 0.065073 0.877354 0.001678 ... 5.530599e-02 0.000589 \n",
+ "2 2.521008e-06 0.005164 0.933849 0.016355 ... 3.657553e-02 0.008057 \n",
+ "3 6.909629e-10 0.054392 0.945608 0.000000 ... 4.285638e-08 0.000000 \n",
+ "4 1.415576e-09 0.001367 0.995857 0.000000 ... 2.776106e-03 0.000000 \n",
+ "\n",
+ " lgb_6 rf_0 rf_1 rf_2 rf_3 rf_4 rf_5 rf_6 \n",
+ "0 1.000000e+00 0.99 0.00 0.00 0.0 0.00 0.00 0.01 \n",
+ "1 8.601484e-11 0.09 0.80 0.05 0.0 0.06 0.00 0.00 \n",
+ "2 0.000000e+00 0.01 0.97 0.00 0.0 0.01 0.01 0.00 \n",
+ "3 0.000000e+00 0.04 0.96 0.00 0.0 0.00 0.00 0.00 \n",
+ "4 0.000000e+00 0.00 0.99 0.00 0.0 0.01 0.00 0.00 \n",
+ "\n",
+ "[5 rows x 21 columns]"
+ ]
+ },
+ "execution_count": 28,
+ "metadata": {},
+ "output_type": "execute_result"
}
],
"source": [
"# 模型输出的测试集,7个特征对应7个标签的预测概率\n",
- "print(stack_ds.X_test) "
+ "stack_ds.X_test.head()"
]
},
{
"cell_type": "code",
- "execution_count": 60,
+ "execution_count": 30,
"id": "b9db35dc",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Wall time: 1min 53s\n"
+ ]
+ }
+ ],
"source": [
+ "%%time\n",
"# 用lr做最后一层\n",
"stacker = Classifier(dataset=stack_ds, estimator=LogisticRegression, parameters={\"solver\": 'lbfgs', \"max_iter\": 1000},use_cache=False)\n",
"predict_stack = stacker.predict()"
@@ -507,7 +807,7 @@
},
{
"cell_type": "code",
- "execution_count": 64,
+ "execution_count": 32,
"id": "a4a48219",
"metadata": {},
"outputs": [
@@ -515,19 +815,19 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "[[9.95173402e-01 2.67623709e-03 4.23846755e-08 ... 3.15435935e-05\n",
- " 5.66194220e-06 2.11044140e-03]\n",
- " [2.23612439e-02 9.70927685e-01 1.23929922e-03 ... 4.49727904e-03\n",
- " 8.73983383e-04 9.97020226e-05]\n",
- " [6.22588197e-03 9.89402233e-01 9.81655972e-04 ... 2.83331258e-03\n",
- " 5.22139184e-04 3.45071569e-05]\n",
+ "[[9.99137967e-01 4.76574513e-04 3.32764186e-07 ... 3.03237700e-06\n",
+ " 1.62161211e-06 3.79326620e-04]\n",
+ " [2.00732175e-02 9.71682830e-01 1.51484266e-03 ... 5.69138554e-03\n",
+ " 9.40725428e-04 9.60822625e-05]\n",
+ " [5.54556002e-03 9.91048437e-01 8.04840682e-04 ... 2.11437934e-03\n",
+ " 4.56463787e-04 3.00919502e-05]\n",
" ...\n",
- " [5.36335125e-06 2.06267200e-03 9.90604140e-01 ... 8.55252386e-04\n",
- " 4.18405061e-03 1.64678945e-05]\n",
- " [9.96602824e-01 2.15991442e-03 7.27481581e-08 ... 3.63552051e-05\n",
- " 6.80942632e-06 1.19199377e-03]\n",
- " [5.89156494e-05 1.15333400e-03 1.09178439e-02 ... 3.09244417e-04\n",
- " 9.85167196e-01 2.21261408e-05]]\n"
+ " [4.60179790e-06 1.78298095e-03 9.91553958e-01 ... 7.26752933e-04\n",
+ " 3.79135124e-03 1.53584401e-05]\n",
+ " [9.96307096e-01 2.43558944e-03 1.01596361e-07 ... 3.94985596e-05\n",
+ " 7.41569805e-06 1.20819024e-03]\n",
+ " [5.34671504e-05 7.62534718e-04 5.58323657e-03 ... 2.11410908e-04\n",
+ " 9.91805379e-01 1.69502656e-05]]\n"
]
}
],
@@ -553,7 +853,7 @@
},
{
"cell_type": "code",
- "execution_count": 65,
+ "execution_count": 33,
"id": "a28806a0",
"metadata": {},
"outputs": [
@@ -561,9 +861,9 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "0.9284696357390209\n",
- "0.8890005714167694\n",
- "0.9511404239499356\n"
+ "0.9254473229468583\n",
+ "0.8412562907478676\n",
+ "0.9535087055000585\n"
]
}
],
@@ -573,6 +873,28 @@
"print(accuracy_score(np.argmax(stack_ds.X_test.iloc[:, 14:].values, axis=1),y_test)) # RF"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": 35,
+ "id": "1db92fec",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "0.9537840870756542\n",
+ "Wall time: 50.8 s\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "# 测试单模运行结果是否一致\n",
+ "rf_predict = rf_model(X_train, y_train, X_test, None)\n",
+ "print(accuracy_score(np.argmax(rf_predict, axis=1),y_test)) # RF"
+ ]
+ },
{
"cell_type": "markdown",
"id": "2e9423ce",
@@ -583,7 +905,7 @@
},
{
"cell_type": "code",
- "execution_count": 70,
+ "execution_count": 37,
"id": "d2b50ba4",
"metadata": {},
"outputs": [
@@ -591,8 +913,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "主观根据结果blending: 0.9425209806337908\n",
- "根据最优权重的blending: 0.9488616414118813\n"
+ "主观根据结果blending: 0.9504106627746071\n",
+ "根据最优权重的线性加权: 0.9530749795184953\n"
]
}
],
@@ -603,19 +925,74 @@
"rf_t = stack_ds.X_test.iloc[:, 14:].values\n",
"\n",
"# 根据分数好坏随机定\n",
- "result = 0.3*xgb_t+0.2*lgb_t+0.5*rf_t\n",
+ "result = 0.2*xgb_t+0.1*lgb_t+0.7*rf_t\n",
"print('主观根据结果blending:', accuracy_score(np.argmax(result, axis=1), y_test))\n",
- "# 根据上面提供的最优权重 Best Weights: [0.36556831 0.00303401 0.63139768]\n",
- "result = 0.36556831*xgb_t+0.00303401*lgb_t+0.63139768*rf_t\n",
- "print('根据最优权重的blending:',accuracy_score(np.argmax(result, axis=1), y_test))"
+ "# 根据上面提供的最优权重 Best Weights: [2.53464919e-01 1.48562205e-20 7.46535081e-01]\n",
+ "result = 2.53464919e-01*xgb_t+1.48562205e-20*lgb_t+7.46535081e-01*rf_t\n",
+ "print('根据最优权重的线性加权:',accuracy_score(np.argmax(result, axis=1), y_test))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e8ce00b2",
+ "metadata": {},
+ "source": [
+ "可以观察到最优权重比我们主观选权重更优,但对比单模结果反而下降了"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d4a12d6d",
+ "metadata": {},
+ "source": [
+ "### Blending的分数"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "ffe1cf4d",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Wall time: 14min 10s\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "blend_ds = pipeline.blend(seed=111)\n",
+ "blender = Classifier(dataset=blend_ds, estimator=LogisticRegression, parameters={\"solver\": 'lbfgs', \"max_iter\": 1000},use_cache=False)\n",
+ "predict_blend = blender.predict()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "506adffb",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "0.9546859617357301\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(accuracy_score(np.argmax(predict_blend, axis=1), y_test))"
]
},
{
"cell_type": "markdown",
- "id": "dfec8968",
+ "id": "e84c19ae",
"metadata": {},
"source": [
- "可以观察到最优权重比我们主观选权重更优"
+ "使用Blending的分数有所提升"
]
},
{
@@ -623,12 +1000,12 @@
"id": "e8daf1e3",
"metadata": {},
"source": [
- "### stacking的分数"
+ "### Stacking的分数"
]
},
{
"cell_type": "code",
- "execution_count": 71,
+ "execution_count": 38,
"id": "4930b407",
"metadata": {},
"outputs": [
@@ -636,7 +1013,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "0.957439777491687\n"
+ "0.9589887988544127\n"
]
}
],
@@ -644,12 +1021,20 @@
"print(accuracy_score(np.argmax(predict_stack, axis=1), y_test))"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "d5ad3c4b",
+ "metadata": {},
+ "source": [
+ "可以明显看到提升的效果"
+ ]
+ },
{
"cell_type": "markdown",
"id": "6311fbd7",
"metadata": {},
"source": [
- "## 再说结论,该数据集(fetch_covtype)Stacking的方法更好"
+ "## 再说结论,该数据集(fetch_covtype)Stacking的方法相比Blending和线性加权更好"
]
}
],