From b3a639a8e8347bc29c15abe54c73ae5d890379ff Mon Sep 17 00:00:00 2001 From: benjas <909336740@qq.com> Date: Sun, 22 Aug 2021 14:02:08 +0800 Subject: [PATCH] Add. Stacking better than Blending and linear-weighted --- .../Stacking-checkpoint.ipynb | 727 ++++++++++++++---- 竞赛优胜技巧/Stacking.ipynb | 727 ++++++++++++++---- 2 files changed, 1112 insertions(+), 342 deletions(-) diff --git a/竞赛优胜技巧/.ipynb_checkpoints/Stacking-checkpoint.ipynb b/竞赛优胜技巧/.ipynb_checkpoints/Stacking-checkpoint.ipynb index ef7f238..4ae86fe 100644 --- a/竞赛优胜技巧/.ipynb_checkpoints/Stacking-checkpoint.ipynb +++ b/竞赛优胜技巧/.ipynb_checkpoints/Stacking-checkpoint.ipynb @@ -13,7 +13,7 @@ "id": "0b365a02", "metadata": {}, "source": [ - "## 先说结论,该数据集(fetch_covtype)Stacking的方法比线性加权更好\n", + "## 先说结论,该数据集(fetch_covtype)Stacking的方法相比Blending和线性加权更好\n", "比赛中我们常用线性加权作为最终的融合方式,我们同样也会好奇怎样的线性加权权重更好,下面也会举例子\n", "参考:https://github.com/rushter/heamy/tree/master/examples" ] @@ -67,7 +67,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 1, "id": "69632c6a", "metadata": {}, "outputs": [ @@ -86,7 +86,7 @@ }, { "cell_type": "code", - "execution_count": 61, + "execution_count": 2, "id": "ca421279", "metadata": {}, "outputs": [], @@ -119,7 +119,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 3, "id": "9a0fabe1", "metadata": {}, "outputs": [], @@ -130,7 +130,7 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": 6, "id": "5bd75178", "metadata": {}, "outputs": [ @@ -153,14 +153,14 @@ "print(y)\n", "ord = OrdinalEncoder()\n", "y = ord.fit_transform(y.reshape(-1, 1))\n", - "y = y_enc.reshape(-1, )\n", + "y = y.reshape(-1, )\n", "print('七分类任务,处理后:',np.unique(y))\n", "print(y)" ] }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 7, "id": "23d9778c", "metadata": {}, "outputs": [ @@ -182,7 +182,7 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": 8, "id": "eac48668", "metadata": {}, "outputs": [ @@ -204,7 +204,7 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": 9, "id": "fba3f975", "metadata": {}, "outputs": [ @@ -226,7 +226,7 @@ " 0.000e+00]])" ] }, - "execution_count": 50, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -246,7 +246,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 10, "id": "e8393e73", "metadata": {}, "outputs": [], @@ -293,7 +293,7 @@ }, { "cell_type": "code", - "execution_count": 52, + "execution_count": 11, "id": "78ab0083", "metadata": {}, "outputs": [], @@ -313,7 +313,7 @@ }, { "cell_type": "code", - "execution_count": 53, + "execution_count": 13, "id": "173ef0f0", "metadata": {}, "outputs": [ @@ -321,185 +321,485 @@ "name": "stdout", "output_type": "stream", "text": [ - "Best Score (log_loss): 0.18744137777851164\n", - "Best Weights: [0.36556831 0.00303401 0.63139768]\n" + "Best Score (log_loss): 0.18646435443714865\n", + "Best Weights: [2.53464919e-01 1.48562205e-20 7.46535081e-01]\n", + "Wall time: 14min 19s\n" ] }, { "data": { "text/plain": [ - "array([0.36556831, 0.00303401, 0.63139768])" + "array([2.53464919e-01, 1.48562205e-20, 7.46535081e-01])" ] }, - "execution_count": 53, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "pipeline.find_weights(scorer=log_loss, ) # 输出最优权重组合" + "%%time\n", + "pipeline.find_weights(scorer=log_loss) # 输出最优权重组合" ] }, { "cell_type": "code", - "execution_count": 55, + "execution_count": 22, "id": "80726d19", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Wall time: 1h 39min 20s\n" + ] + } + ], "source": [ + "%%time\n", "# 5折训练构建5折模型特征集,这里比较耗时\n", - "stack_ds = pipeline.stack(k=5,stratify=False,seed=42,full_test=False) # full_test指明预测全部还是预测当前折的验证集" + "stack_ds = pipeline.stack(k=5,seed=42)" ] }, { "cell_type": "code", - "execution_count": 56, + "execution_count": 26, "id": "b25bba3c", "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - " xgb_0 xgb_1 xgb_2 xgb_3 xgb_4 \\\n", - "0 0.177179 0.818728 2.185222e-07 9.264143e-09 4.090067e-03 \n", - "1 0.005155 0.994845 7.055579e-10 1.326343e-08 6.331572e-09 \n", - "2 0.293492 0.706508 3.650662e-10 1.017633e-09 8.823530e-09 \n", - "3 0.478112 0.521816 3.207779e-06 2.878019e-08 1.076500e-08 \n", - "4 0.992430 0.006652 1.233117e-05 1.887496e-07 1.569583e-06 \n", - "... ... ... ... ... ... \n", - "435754 0.988518 0.011477 3.190797e-09 5.645121e-08 2.940739e-09 \n", - "435755 0.969212 0.030723 2.142020e-08 1.572054e-05 4.321913e-07 \n", - "435756 0.415850 0.584142 4.283793e-08 7.367601e-08 6.148067e-07 \n", - "435757 0.602601 0.397399 6.606462e-10 1.015894e-09 7.221973e-08 \n", - "435758 0.834587 0.165411 3.267833e-09 2.057172e-08 2.078704e-08 \n", - "\n", - " xgb_5 xgb_6 lgb_0 lgb_1 lgb_2 ... \\\n", - "0 1.725062e-06 1.048052e-06 0.172406 0.812678 1.416886e-06 ... \n", - "1 1.435787e-09 1.603579e-10 0.008114 0.991886 0.000000e+00 ... \n", - "2 6.384080e-10 2.823794e-08 0.817627 0.182372 0.000000e+00 ... \n", - "3 2.230641e-06 6.630235e-05 0.465733 0.534184 0.000000e+00 ... \n", - "4 5.604260e-07 9.037877e-04 0.932050 0.043451 0.000000e+00 ... \n", - "... ... ... ... ... ... ... \n", - "435754 1.530261e-08 4.466830e-06 0.970593 0.029399 0.000000e+00 ... \n", - "435755 5.574208e-10 4.977021e-05 0.862591 0.136644 0.000000e+00 ... \n", - "435756 2.371389e-06 5.185283e-06 0.466886 0.533039 0.000000e+00 ... \n", - "435757 2.326313e-09 9.193871e-09 0.674250 0.325750 4.092880e-211 ... \n", - "435758 5.976972e-08 2.204258e-06 0.709320 0.290680 0.000000e+00 ... \n", - "\n", - " lgb_4 lgb_5 lgb_6 rf_0 rf_1 rf_2 rf_3 \\\n", - "0 1.486358e-02 5.093522e-05 6.805300e-08 0.06 0.92 0.0 0.0 \n", - "1 0.000000e+00 0.000000e+00 0.000000e+00 0.12 0.88 0.0 0.0 \n", - "2 4.452850e-07 7.825338e-09 1.012052e-07 0.63 0.37 0.0 0.0 \n", - "3 0.000000e+00 0.000000e+00 8.245405e-05 0.56 0.44 0.0 0.0 \n", - "4 0.000000e+00 0.000000e+00 2.449972e-02 0.95 0.04 0.0 0.0 \n", - "... ... ... ... ... ... ... ... \n", - "435754 0.000000e+00 0.000000e+00 7.871809e-06 0.97 0.03 0.0 0.0 \n", - "435755 0.000000e+00 0.000000e+00 7.647430e-04 0.93 0.06 0.0 0.0 \n", - "435756 0.000000e+00 0.000000e+00 7.493861e-05 0.45 0.55 0.0 0.0 \n", - "435757 0.000000e+00 0.000000e+00 0.000000e+00 0.52 0.48 0.0 0.0 \n", - "435758 0.000000e+00 0.000000e+00 0.000000e+00 0.87 0.13 0.0 0.0 \n", - "\n", - " rf_4 rf_5 rf_6 \n", - "0 0.02 0.0 0.00 \n", - "1 0.00 0.0 0.00 \n", - "2 0.00 0.0 0.00 \n", - "3 0.00 0.0 0.00 \n", - "4 0.00 0.0 0.01 \n", - "... ... ... ... \n", - "435754 0.00 0.0 0.00 \n", - "435755 0.00 0.0 0.01 \n", - "435756 0.00 0.0 0.00 \n", - "435757 0.00 0.0 0.00 \n", - "435758 0.00 0.0 0.00 \n", - "\n", - "[435759 rows x 21 columns]\n" - ] + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
xgb_0xgb_1xgb_2xgb_3xgb_4xgb_5xgb_6lgb_0lgb_1lgb_2...lgb_4lgb_5lgb_6rf_0rf_1rf_2rf_3rf_4rf_5rf_6
00.1771790.8187282.185222e-079.264143e-094.090067e-031.725062e-061.048052e-060.1796250.8046840.000003...1.562435e-026.370849e-051.004259e-080.030.960.00.00.010.00.0
10.0051550.9948457.055579e-101.326343e-086.331572e-091.435787e-091.603579e-100.0081140.9918860.000000...0.000000e+000.000000e+000.000000e+000.130.870.00.00.000.00.0
20.2934920.7065083.650662e-101.017633e-098.823530e-096.384080e-102.823794e-080.8314450.1685550.000000...4.999034e-074.015190e-094.997854e-090.630.370.00.00.000.00.0
30.4781120.5218163.207779e-062.878019e-081.076500e-082.230641e-066.630235e-050.4657330.5341840.000000...0.000000e+000.000000e+008.245405e-050.550.450.00.00.000.00.0
40.9924300.0066521.233117e-051.887496e-071.569583e-065.604260e-079.037877e-040.9320500.0434510.000000...0.000000e+000.000000e+002.449972e-020.970.030.00.00.000.00.0
\n", + "

5 rows × 21 columns

\n", + "
" + ], + "text/plain": [ + " xgb_0 xgb_1 xgb_2 xgb_3 xgb_4 xgb_5 \\\n", + "0 0.177179 0.818728 2.185222e-07 9.264143e-09 4.090067e-03 1.725062e-06 \n", + "1 0.005155 0.994845 7.055579e-10 1.326343e-08 6.331572e-09 1.435787e-09 \n", + "2 0.293492 0.706508 3.650662e-10 1.017633e-09 8.823530e-09 6.384080e-10 \n", + "3 0.478112 0.521816 3.207779e-06 2.878019e-08 1.076500e-08 2.230641e-06 \n", + "4 0.992430 0.006652 1.233117e-05 1.887496e-07 1.569583e-06 5.604260e-07 \n", + "\n", + " xgb_6 lgb_0 lgb_1 lgb_2 ... lgb_4 \\\n", + "0 1.048052e-06 0.179625 0.804684 0.000003 ... 1.562435e-02 \n", + "1 1.603579e-10 0.008114 0.991886 0.000000 ... 0.000000e+00 \n", + "2 2.823794e-08 0.831445 0.168555 0.000000 ... 4.999034e-07 \n", + "3 6.630235e-05 0.465733 0.534184 0.000000 ... 0.000000e+00 \n", + "4 9.037877e-04 0.932050 0.043451 0.000000 ... 0.000000e+00 \n", + "\n", + " lgb_5 lgb_6 rf_0 rf_1 rf_2 rf_3 rf_4 rf_5 rf_6 \n", + "0 6.370849e-05 1.004259e-08 0.03 0.96 0.0 0.0 0.01 0.0 0.0 \n", + "1 0.000000e+00 0.000000e+00 0.13 0.87 0.0 0.0 0.00 0.0 0.0 \n", + "2 4.015190e-09 4.997854e-09 0.63 0.37 0.0 0.0 0.00 0.0 0.0 \n", + "3 0.000000e+00 8.245405e-05 0.55 0.45 0.0 0.0 0.00 0.0 0.0 \n", + "4 0.000000e+00 2.449972e-02 0.97 0.03 0.0 0.0 0.00 0.0 0.0 \n", + "\n", + "[5 rows x 21 columns]" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ "# 模型输出的训练集,7个特征对应7个标签的预测概率\n", - "print(stack_ds.X_train)" + "stack_ds.X_train.head()" ] }, { "cell_type": "code", - "execution_count": 59, + "execution_count": 28, "id": "835205e9", "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - " xgb_0 xgb_1 xgb_2 xgb_3 xgb_4 \\\n", - "0 9.876224e-01 0.000789 2.774616e-06 4.129093e-07 1.311387e-06 \n", - "1 5.139124e-02 0.929659 1.852793e-03 1.518293e-07 1.692924e-02 \n", - "2 7.695035e-04 0.973729 6.878623e-04 1.573823e-07 2.408167e-02 \n", - "3 3.376913e-02 0.966229 2.024872e-07 7.321523e-08 1.071163e-06 \n", - "4 1.013981e-03 0.998553 3.794874e-06 8.755425e-08 4.243054e-04 \n", - "... ... ... ... ... ... \n", - "145248 9.615189e-01 0.038480 6.486028e-08 1.744931e-08 1.069370e-06 \n", - "145249 3.055384e-02 0.969440 2.475371e-07 5.530033e-08 4.299908e-06 \n", - "145250 8.224608e-06 0.058361 9.212288e-01 9.705171e-08 5.440121e-05 \n", - "145251 9.183387e-01 0.081601 5.612090e-08 1.088283e-08 5.225256e-07 \n", - "145252 9.203915e-07 0.003578 2.372825e-01 1.582836e-06 3.307252e-07 \n", - "\n", - " xgb_5 xgb_6 lgb_0 lgb_1 lgb_2 ... \\\n", - "0 7.851924e-09 1.158422e-02 0.962538 0.004222 9.599869e-23 ... \n", - "1 1.613036e-04 6.763449e-06 0.070947 0.882463 2.232464e-03 ... \n", - "2 7.296442e-04 1.884172e-06 0.004029 0.945838 1.014722e-02 ... \n", - "3 7.227585e-08 1.448203e-09 0.066538 0.933450 1.206630e-06 ... \n", - "4 4.374311e-06 4.566837e-09 0.001334 0.997391 1.580417e-06 ... \n", - "... ... ... ... ... ... ... \n", - "145248 5.049759e-08 4.010809e-07 0.917842 0.082153 2.154302e-17 ... \n", - "145249 1.255851e-08 1.224208e-06 0.058622 0.941370 1.332795e-12 ... \n", - "145250 2.034389e-02 3.132630e-06 0.000268 0.083680 8.789707e-01 ... \n", - "145251 2.566383e-07 5.976933e-05 0.875834 0.123030 2.631276e-12 ... \n", - "145252 7.591362e-01 6.988637e-08 0.000032 0.037757 2.462795e-01 ... \n", - "\n", - " lgb_4 lgb_5 lgb_6 rf_0 rf_1 rf_2 rf_3 \\\n", - "0 1.726260e-240 0.000000e+00 3.324009e-02 0.984 0.000 0.000 0.0 \n", - "1 4.396802e-02 3.888647e-04 1.077146e-08 0.106 0.816 0.038 0.0 \n", - "2 3.329283e-02 6.693269e-03 3.213404e-09 0.008 0.950 0.002 0.0 \n", - "3 9.908823e-06 1.371448e-07 1.280235e-09 0.078 0.922 0.000 0.0 \n", - "4 1.273086e-03 7.811306e-07 5.594472e-10 0.004 0.988 0.000 0.0 \n", - "... ... ... ... ... ... ... ... \n", - "145248 6.191352e-08 0.000000e+00 4.958509e-06 0.968 0.032 0.000 0.0 \n", - "145249 7.656931e-06 3.415083e-47 2.271880e-07 0.018 0.972 0.000 0.0 \n", - "145250 2.052535e-04 3.687570e-02 1.393421e-09 0.000 0.040 0.946 0.0 \n", - "145251 2.521124e-07 5.749375e-08 1.135236e-03 0.992 0.008 0.000 0.0 \n", - "145252 2.927400e-06 7.159244e-01 8.608624e-140 0.000 0.018 0.110 0.0 \n", - "\n", - " rf_4 rf_5 rf_6 \n", - "0 0.000 0.000 0.016 \n", - "1 0.034 0.006 0.000 \n", - "2 0.032 0.008 0.000 \n", - "3 0.000 0.000 0.000 \n", - "4 0.008 0.000 0.000 \n", - "... ... ... ... \n", - "145248 0.000 0.000 0.000 \n", - "145249 0.010 0.000 0.000 \n", - "145250 0.000 0.014 0.000 \n", - "145251 0.000 0.000 0.000 \n", - "145252 0.000 0.872 0.000 \n", - "\n", - "[145253 rows x 21 columns]\n" - ] + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
xgb_0xgb_1xgb_2xgb_3xgb_4xgb_5xgb_6lgb_0lgb_1lgb_2...lgb_4lgb_5lgb_6rf_0rf_1rf_2rf_3rf_4rf_5rf_6
00.9914930.0002104.796058e-066.178684e-086.947614e-072.410490e-098.291459e-030.0000000.0000000.000000...0.000000e+000.0000001.000000e+000.990.000.000.00.000.000.01
10.0247310.9643726.387765e-044.205048e-081.006575e-021.879628e-044.830114e-060.0650730.8773540.001678...5.530599e-020.0005898.601484e-110.090.800.050.00.060.000.00
20.0007800.9797768.593459e-041.267791e-071.710379e-021.477527e-032.521008e-060.0051640.9338490.016355...3.657553e-020.0080570.000000e+000.010.970.000.00.010.010.00
30.0426950.9573042.283268e-084.387427e-084.175481e-074.406019e-086.909629e-100.0543920.9456080.000000...4.285638e-080.0000000.000000e+000.040.960.000.00.000.000.00
40.0004570.9993343.366338e-064.893879e-082.045808e-047.889498e-071.415576e-090.0013670.9958570.000000...2.776106e-030.0000000.000000e+000.000.990.000.00.010.000.00
\n", + "

5 rows × 21 columns

\n", + "
" + ], + "text/plain": [ + " xgb_0 xgb_1 xgb_2 xgb_3 xgb_4 xgb_5 \\\n", + "0 0.991493 0.000210 4.796058e-06 6.178684e-08 6.947614e-07 2.410490e-09 \n", + "1 0.024731 0.964372 6.387765e-04 4.205048e-08 1.006575e-02 1.879628e-04 \n", + "2 0.000780 0.979776 8.593459e-04 1.267791e-07 1.710379e-02 1.477527e-03 \n", + "3 0.042695 0.957304 2.283268e-08 4.387427e-08 4.175481e-07 4.406019e-08 \n", + "4 0.000457 0.999334 3.366338e-06 4.893879e-08 2.045808e-04 7.889498e-07 \n", + "\n", + " xgb_6 lgb_0 lgb_1 lgb_2 ... lgb_4 lgb_5 \\\n", + "0 8.291459e-03 0.000000 0.000000 0.000000 ... 0.000000e+00 0.000000 \n", + "1 4.830114e-06 0.065073 0.877354 0.001678 ... 5.530599e-02 0.000589 \n", + "2 2.521008e-06 0.005164 0.933849 0.016355 ... 3.657553e-02 0.008057 \n", + "3 6.909629e-10 0.054392 0.945608 0.000000 ... 4.285638e-08 0.000000 \n", + "4 1.415576e-09 0.001367 0.995857 0.000000 ... 2.776106e-03 0.000000 \n", + "\n", + " lgb_6 rf_0 rf_1 rf_2 rf_3 rf_4 rf_5 rf_6 \n", + "0 1.000000e+00 0.99 0.00 0.00 0.0 0.00 0.00 0.01 \n", + "1 8.601484e-11 0.09 0.80 0.05 0.0 0.06 0.00 0.00 \n", + "2 0.000000e+00 0.01 0.97 0.00 0.0 0.01 0.01 0.00 \n", + "3 0.000000e+00 0.04 0.96 0.00 0.0 0.00 0.00 0.00 \n", + "4 0.000000e+00 0.00 0.99 0.00 0.0 0.01 0.00 0.00 \n", + "\n", + "[5 rows x 21 columns]" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ "# 模型输出的测试集,7个特征对应7个标签的预测概率\n", - "print(stack_ds.X_test) " + "stack_ds.X_test.head()" ] }, { "cell_type": "code", - "execution_count": 60, + "execution_count": 30, "id": "b9db35dc", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Wall time: 1min 53s\n" + ] + } + ], "source": [ + "%%time\n", "# 用lr做最后一层\n", "stacker = Classifier(dataset=stack_ds, estimator=LogisticRegression, parameters={\"solver\": 'lbfgs', \"max_iter\": 1000},use_cache=False)\n", "predict_stack = stacker.predict()" @@ -507,7 +807,7 @@ }, { "cell_type": "code", - "execution_count": 64, + "execution_count": 32, "id": "a4a48219", "metadata": {}, "outputs": [ @@ -515,19 +815,19 @@ "name": "stdout", "output_type": "stream", "text": [ - "[[9.95173402e-01 2.67623709e-03 4.23846755e-08 ... 3.15435935e-05\n", - " 5.66194220e-06 2.11044140e-03]\n", - " [2.23612439e-02 9.70927685e-01 1.23929922e-03 ... 4.49727904e-03\n", - " 8.73983383e-04 9.97020226e-05]\n", - " [6.22588197e-03 9.89402233e-01 9.81655972e-04 ... 2.83331258e-03\n", - " 5.22139184e-04 3.45071569e-05]\n", + "[[9.99137967e-01 4.76574513e-04 3.32764186e-07 ... 3.03237700e-06\n", + " 1.62161211e-06 3.79326620e-04]\n", + " [2.00732175e-02 9.71682830e-01 1.51484266e-03 ... 5.69138554e-03\n", + " 9.40725428e-04 9.60822625e-05]\n", + " [5.54556002e-03 9.91048437e-01 8.04840682e-04 ... 2.11437934e-03\n", + " 4.56463787e-04 3.00919502e-05]\n", " ...\n", - " [5.36335125e-06 2.06267200e-03 9.90604140e-01 ... 8.55252386e-04\n", - " 4.18405061e-03 1.64678945e-05]\n", - " [9.96602824e-01 2.15991442e-03 7.27481581e-08 ... 3.63552051e-05\n", - " 6.80942632e-06 1.19199377e-03]\n", - " [5.89156494e-05 1.15333400e-03 1.09178439e-02 ... 3.09244417e-04\n", - " 9.85167196e-01 2.21261408e-05]]\n" + " [4.60179790e-06 1.78298095e-03 9.91553958e-01 ... 7.26752933e-04\n", + " 3.79135124e-03 1.53584401e-05]\n", + " [9.96307096e-01 2.43558944e-03 1.01596361e-07 ... 3.94985596e-05\n", + " 7.41569805e-06 1.20819024e-03]\n", + " [5.34671504e-05 7.62534718e-04 5.58323657e-03 ... 2.11410908e-04\n", + " 9.91805379e-01 1.69502656e-05]]\n" ] } ], @@ -553,7 +853,7 @@ }, { "cell_type": "code", - "execution_count": 65, + "execution_count": 33, "id": "a28806a0", "metadata": {}, "outputs": [ @@ -561,9 +861,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "0.9284696357390209\n", - "0.8890005714167694\n", - "0.9511404239499356\n" + "0.9254473229468583\n", + "0.8412562907478676\n", + "0.9535087055000585\n" ] } ], @@ -573,6 +873,28 @@ "print(accuracy_score(np.argmax(stack_ds.X_test.iloc[:, 14:].values, axis=1),y_test)) # RF" ] }, + { + "cell_type": "code", + "execution_count": 35, + "id": "1db92fec", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.9537840870756542\n", + "Wall time: 50.8 s\n" + ] + } + ], + "source": [ + "%%time\n", + "# 测试单模运行结果是否一致\n", + "rf_predict = rf_model(X_train, y_train, X_test, None)\n", + "print(accuracy_score(np.argmax(rf_predict, axis=1),y_test)) # RF" + ] + }, { "cell_type": "markdown", "id": "2e9423ce", @@ -583,7 +905,7 @@ }, { "cell_type": "code", - "execution_count": 70, + "execution_count": 37, "id": "d2b50ba4", "metadata": {}, "outputs": [ @@ -591,8 +913,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "主观根据结果blending: 0.9425209806337908\n", - "根据最优权重的blending: 0.9488616414118813\n" + "主观根据结果blending: 0.9504106627746071\n", + "根据最优权重的线性加权: 0.9530749795184953\n" ] } ], @@ -603,19 +925,74 @@ "rf_t = stack_ds.X_test.iloc[:, 14:].values\n", "\n", "# 根据分数好坏随机定\n", - "result = 0.3*xgb_t+0.2*lgb_t+0.5*rf_t\n", + "result = 0.2*xgb_t+0.1*lgb_t+0.7*rf_t\n", "print('主观根据结果blending:', accuracy_score(np.argmax(result, axis=1), y_test))\n", - "# 根据上面提供的最优权重 Best Weights: [0.36556831 0.00303401 0.63139768]\n", - "result = 0.36556831*xgb_t+0.00303401*lgb_t+0.63139768*rf_t\n", - "print('根据最优权重的blending:',accuracy_score(np.argmax(result, axis=1), y_test))" + "# 根据上面提供的最优权重 Best Weights: [2.53464919e-01 1.48562205e-20 7.46535081e-01]\n", + "result = 2.53464919e-01*xgb_t+1.48562205e-20*lgb_t+7.46535081e-01*rf_t\n", + "print('根据最优权重的线性加权:',accuracy_score(np.argmax(result, axis=1), y_test))" + ] + }, + { + "cell_type": "markdown", + "id": "e8ce00b2", + "metadata": {}, + "source": [ + "可以观察到最优权重比我们主观选权重更优,单对比单模结果反而下降了" + ] + }, + { + "cell_type": "markdown", + "id": "d4a12d6d", + "metadata": {}, + "source": [ + "### Blending的分数" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "ffe1cf4d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Wall time: 14min 10s\n" + ] + } + ], + "source": [ + "%%time\n", + "blend_ds = pipeline.blend(seed=111)\n", + "blender = Classifier(dataset=blend_ds, estimator=LogisticRegression, parameters={\"solver\": 'lbfgs', \"max_iter\": 1000},use_cache=False)\n", + "predict_blend = blender.predict()" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "506adffb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.9546859617357301\n" + ] + } + ], + "source": [ + "print(accuracy_score(np.argmax(predict_blend, axis=1), y_test))" ] }, { "cell_type": "markdown", - "id": "dfec8968", + "id": "e84c19ae", "metadata": {}, "source": [ - "可以观察到最优权重比我们主观选权重更优" + "使用Blending的分数有所提升" ] }, { @@ -623,12 +1000,12 @@ "id": "e8daf1e3", "metadata": {}, "source": [ - "### stacking的分数" + "### Stacking的分数" ] }, { "cell_type": "code", - "execution_count": 71, + "execution_count": 38, "id": "4930b407", "metadata": {}, "outputs": [ @@ -636,7 +1013,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "0.957439777491687\n" + "0.9589887988544127\n" ] } ], @@ -644,12 +1021,20 @@ "print(accuracy_score(np.argmax(predict_stack, axis=1), y_test))" ] }, + { + "cell_type": "markdown", + "id": "d5ad3c4b", + "metadata": {}, + "source": [ + "可以明显看到提升的效果" + ] + }, { "cell_type": "markdown", "id": "6311fbd7", "metadata": {}, "source": [ - "## 再说结论,该数据集(fetch_covtype)Stacking的方法更好" + "## 再说结论,该数据集(fetch_covtype)Stacking的方法相比Blending和线性加权更好" ] } ], diff --git a/竞赛优胜技巧/Stacking.ipynb b/竞赛优胜技巧/Stacking.ipynb index ef7f238..4ae86fe 100644 --- a/竞赛优胜技巧/Stacking.ipynb +++ b/竞赛优胜技巧/Stacking.ipynb @@ -13,7 +13,7 @@ "id": "0b365a02", "metadata": {}, "source": [ - "## 先说结论,该数据集(fetch_covtype)Stacking的方法比线性加权更好\n", + "## 先说结论,该数据集(fetch_covtype)Stacking的方法相比Blending和线性加权更好\n", "比赛中我们常用线性加权作为最终的融合方式,我们同样也会好奇怎样的线性加权权重更好,下面也会举例子\n", "参考:https://github.com/rushter/heamy/tree/master/examples" ] @@ -67,7 +67,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 1, "id": "69632c6a", "metadata": {}, "outputs": [ @@ -86,7 +86,7 @@ }, { "cell_type": "code", - "execution_count": 61, + "execution_count": 2, "id": "ca421279", "metadata": {}, "outputs": [], @@ -119,7 +119,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 3, "id": "9a0fabe1", "metadata": {}, "outputs": [], @@ -130,7 +130,7 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": 6, "id": "5bd75178", "metadata": {}, "outputs": [ @@ -153,14 +153,14 @@ "print(y)\n", "ord = OrdinalEncoder()\n", "y = ord.fit_transform(y.reshape(-1, 1))\n", - "y = y_enc.reshape(-1, )\n", + "y = y.reshape(-1, )\n", "print('七分类任务,处理后:',np.unique(y))\n", "print(y)" ] }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 7, "id": "23d9778c", "metadata": {}, "outputs": [ @@ -182,7 +182,7 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": 8, "id": "eac48668", "metadata": {}, "outputs": [ @@ -204,7 +204,7 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": 9, "id": "fba3f975", "metadata": {}, "outputs": [ @@ -226,7 +226,7 @@ " 0.000e+00]])" ] }, - "execution_count": 50, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -246,7 +246,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 10, "id": "e8393e73", "metadata": {}, "outputs": [], @@ -293,7 +293,7 @@ }, { "cell_type": "code", - "execution_count": 52, + "execution_count": 11, "id": "78ab0083", "metadata": {}, "outputs": [], @@ -313,7 +313,7 @@ }, { "cell_type": "code", - "execution_count": 53, + "execution_count": 13, "id": "173ef0f0", "metadata": {}, "outputs": [ @@ -321,185 +321,485 @@ "name": "stdout", "output_type": "stream", "text": [ - "Best Score (log_loss): 0.18744137777851164\n", - "Best Weights: [0.36556831 0.00303401 0.63139768]\n" + "Best Score (log_loss): 0.18646435443714865\n", + "Best Weights: [2.53464919e-01 1.48562205e-20 7.46535081e-01]\n", + "Wall time: 14min 19s\n" ] }, { "data": { "text/plain": [ - "array([0.36556831, 0.00303401, 0.63139768])" + "array([2.53464919e-01, 1.48562205e-20, 7.46535081e-01])" ] }, - "execution_count": 53, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "pipeline.find_weights(scorer=log_loss, ) # 输出最优权重组合" + "%%time\n", + "pipeline.find_weights(scorer=log_loss) # 输出最优权重组合" ] }, { "cell_type": "code", - "execution_count": 55, + "execution_count": 22, "id": "80726d19", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Wall time: 1h 39min 20s\n" + ] + } + ], "source": [ + "%%time\n", "# 5折训练构建5折模型特征集,这里比较耗时\n", - "stack_ds = pipeline.stack(k=5,stratify=False,seed=42,full_test=False) # full_test指明预测全部还是预测当前折的验证集" + "stack_ds = pipeline.stack(k=5,seed=42)" ] }, { "cell_type": "code", - "execution_count": 56, + "execution_count": 26, "id": "b25bba3c", "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - " xgb_0 xgb_1 xgb_2 xgb_3 xgb_4 \\\n", - "0 0.177179 0.818728 2.185222e-07 9.264143e-09 4.090067e-03 \n", - "1 0.005155 0.994845 7.055579e-10 1.326343e-08 6.331572e-09 \n", - "2 0.293492 0.706508 3.650662e-10 1.017633e-09 8.823530e-09 \n", - "3 0.478112 0.521816 3.207779e-06 2.878019e-08 1.076500e-08 \n", - "4 0.992430 0.006652 1.233117e-05 1.887496e-07 1.569583e-06 \n", - "... ... ... ... ... ... \n", - "435754 0.988518 0.011477 3.190797e-09 5.645121e-08 2.940739e-09 \n", - "435755 0.969212 0.030723 2.142020e-08 1.572054e-05 4.321913e-07 \n", - "435756 0.415850 0.584142 4.283793e-08 7.367601e-08 6.148067e-07 \n", - "435757 0.602601 0.397399 6.606462e-10 1.015894e-09 7.221973e-08 \n", - "435758 0.834587 0.165411 3.267833e-09 2.057172e-08 2.078704e-08 \n", - "\n", - " xgb_5 xgb_6 lgb_0 lgb_1 lgb_2 ... \\\n", - "0 1.725062e-06 1.048052e-06 0.172406 0.812678 1.416886e-06 ... \n", - "1 1.435787e-09 1.603579e-10 0.008114 0.991886 0.000000e+00 ... \n", - "2 6.384080e-10 2.823794e-08 0.817627 0.182372 0.000000e+00 ... \n", - "3 2.230641e-06 6.630235e-05 0.465733 0.534184 0.000000e+00 ... \n", - "4 5.604260e-07 9.037877e-04 0.932050 0.043451 0.000000e+00 ... \n", - "... ... ... ... ... ... ... \n", - "435754 1.530261e-08 4.466830e-06 0.970593 0.029399 0.000000e+00 ... \n", - "435755 5.574208e-10 4.977021e-05 0.862591 0.136644 0.000000e+00 ... \n", - "435756 2.371389e-06 5.185283e-06 0.466886 0.533039 0.000000e+00 ... \n", - "435757 2.326313e-09 9.193871e-09 0.674250 0.325750 4.092880e-211 ... \n", - "435758 5.976972e-08 2.204258e-06 0.709320 0.290680 0.000000e+00 ... \n", - "\n", - " lgb_4 lgb_5 lgb_6 rf_0 rf_1 rf_2 rf_3 \\\n", - "0 1.486358e-02 5.093522e-05 6.805300e-08 0.06 0.92 0.0 0.0 \n", - "1 0.000000e+00 0.000000e+00 0.000000e+00 0.12 0.88 0.0 0.0 \n", - "2 4.452850e-07 7.825338e-09 1.012052e-07 0.63 0.37 0.0 0.0 \n", - "3 0.000000e+00 0.000000e+00 8.245405e-05 0.56 0.44 0.0 0.0 \n", - "4 0.000000e+00 0.000000e+00 2.449972e-02 0.95 0.04 0.0 0.0 \n", - "... ... ... ... ... ... ... ... \n", - "435754 0.000000e+00 0.000000e+00 7.871809e-06 0.97 0.03 0.0 0.0 \n", - "435755 0.000000e+00 0.000000e+00 7.647430e-04 0.93 0.06 0.0 0.0 \n", - "435756 0.000000e+00 0.000000e+00 7.493861e-05 0.45 0.55 0.0 0.0 \n", - "435757 0.000000e+00 0.000000e+00 0.000000e+00 0.52 0.48 0.0 0.0 \n", - "435758 0.000000e+00 0.000000e+00 0.000000e+00 0.87 0.13 0.0 0.0 \n", - "\n", - " rf_4 rf_5 rf_6 \n", - "0 0.02 0.0 0.00 \n", - "1 0.00 0.0 0.00 \n", - "2 0.00 0.0 0.00 \n", - "3 0.00 0.0 0.00 \n", - "4 0.00 0.0 0.01 \n", - "... ... ... ... \n", - "435754 0.00 0.0 0.00 \n", - "435755 0.00 0.0 0.01 \n", - "435756 0.00 0.0 0.00 \n", - "435757 0.00 0.0 0.00 \n", - "435758 0.00 0.0 0.00 \n", - "\n", - "[435759 rows x 21 columns]\n" - ] + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
xgb_0xgb_1xgb_2xgb_3xgb_4xgb_5xgb_6lgb_0lgb_1lgb_2...lgb_4lgb_5lgb_6rf_0rf_1rf_2rf_3rf_4rf_5rf_6
00.1771790.8187282.185222e-079.264143e-094.090067e-031.725062e-061.048052e-060.1796250.8046840.000003...1.562435e-026.370849e-051.004259e-080.030.960.00.00.010.00.0
10.0051550.9948457.055579e-101.326343e-086.331572e-091.435787e-091.603579e-100.0081140.9918860.000000...0.000000e+000.000000e+000.000000e+000.130.870.00.00.000.00.0
20.2934920.7065083.650662e-101.017633e-098.823530e-096.384080e-102.823794e-080.8314450.1685550.000000...4.999034e-074.015190e-094.997854e-090.630.370.00.00.000.00.0
30.4781120.5218163.207779e-062.878019e-081.076500e-082.230641e-066.630235e-050.4657330.5341840.000000...0.000000e+000.000000e+008.245405e-050.550.450.00.00.000.00.0
40.9924300.0066521.233117e-051.887496e-071.569583e-065.604260e-079.037877e-040.9320500.0434510.000000...0.000000e+000.000000e+002.449972e-020.970.030.00.00.000.00.0
\n", + "

5 rows × 21 columns

\n", + "
" + ], + "text/plain": [ + " xgb_0 xgb_1 xgb_2 xgb_3 xgb_4 xgb_5 \\\n", + "0 0.177179 0.818728 2.185222e-07 9.264143e-09 4.090067e-03 1.725062e-06 \n", + "1 0.005155 0.994845 7.055579e-10 1.326343e-08 6.331572e-09 1.435787e-09 \n", + "2 0.293492 0.706508 3.650662e-10 1.017633e-09 8.823530e-09 6.384080e-10 \n", + "3 0.478112 0.521816 3.207779e-06 2.878019e-08 1.076500e-08 2.230641e-06 \n", + "4 0.992430 0.006652 1.233117e-05 1.887496e-07 1.569583e-06 5.604260e-07 \n", + "\n", + " xgb_6 lgb_0 lgb_1 lgb_2 ... lgb_4 \\\n", + "0 1.048052e-06 0.179625 0.804684 0.000003 ... 1.562435e-02 \n", + "1 1.603579e-10 0.008114 0.991886 0.000000 ... 0.000000e+00 \n", + "2 2.823794e-08 0.831445 0.168555 0.000000 ... 4.999034e-07 \n", + "3 6.630235e-05 0.465733 0.534184 0.000000 ... 0.000000e+00 \n", + "4 9.037877e-04 0.932050 0.043451 0.000000 ... 0.000000e+00 \n", + "\n", + " lgb_5 lgb_6 rf_0 rf_1 rf_2 rf_3 rf_4 rf_5 rf_6 \n", + "0 6.370849e-05 1.004259e-08 0.03 0.96 0.0 0.0 0.01 0.0 0.0 \n", + "1 0.000000e+00 0.000000e+00 0.13 0.87 0.0 0.0 0.00 0.0 0.0 \n", + "2 4.015190e-09 4.997854e-09 0.63 0.37 0.0 0.0 0.00 0.0 0.0 \n", + "3 0.000000e+00 8.245405e-05 0.55 0.45 0.0 0.0 0.00 0.0 0.0 \n", + "4 0.000000e+00 2.449972e-02 0.97 0.03 0.0 0.0 0.00 0.0 0.0 \n", + "\n", + "[5 rows x 21 columns]" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ "# 模型输出的训练集,7个特征对应7个标签的预测概率\n", - "print(stack_ds.X_train)" + "stack_ds.X_train.head()" ] }, { "cell_type": "code", - "execution_count": 59, + "execution_count": 28, "id": "835205e9", "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - " xgb_0 xgb_1 xgb_2 xgb_3 xgb_4 \\\n", - "0 9.876224e-01 0.000789 2.774616e-06 4.129093e-07 1.311387e-06 \n", - "1 5.139124e-02 0.929659 1.852793e-03 1.518293e-07 1.692924e-02 \n", - "2 7.695035e-04 0.973729 6.878623e-04 1.573823e-07 2.408167e-02 \n", - "3 3.376913e-02 0.966229 2.024872e-07 7.321523e-08 1.071163e-06 \n", - "4 1.013981e-03 0.998553 3.794874e-06 8.755425e-08 4.243054e-04 \n", - "... ... ... ... ... ... \n", - "145248 9.615189e-01 0.038480 6.486028e-08 1.744931e-08 1.069370e-06 \n", - "145249 3.055384e-02 0.969440 2.475371e-07 5.530033e-08 4.299908e-06 \n", - "145250 8.224608e-06 0.058361 9.212288e-01 9.705171e-08 5.440121e-05 \n", - "145251 9.183387e-01 0.081601 5.612090e-08 1.088283e-08 5.225256e-07 \n", - "145252 9.203915e-07 0.003578 2.372825e-01 1.582836e-06 3.307252e-07 \n", - "\n", - " xgb_5 xgb_6 lgb_0 lgb_1 lgb_2 ... \\\n", - "0 7.851924e-09 1.158422e-02 0.962538 0.004222 9.599869e-23 ... \n", - "1 1.613036e-04 6.763449e-06 0.070947 0.882463 2.232464e-03 ... \n", - "2 7.296442e-04 1.884172e-06 0.004029 0.945838 1.014722e-02 ... \n", - "3 7.227585e-08 1.448203e-09 0.066538 0.933450 1.206630e-06 ... \n", - "4 4.374311e-06 4.566837e-09 0.001334 0.997391 1.580417e-06 ... \n", - "... ... ... ... ... ... ... \n", - "145248 5.049759e-08 4.010809e-07 0.917842 0.082153 2.154302e-17 ... \n", - "145249 1.255851e-08 1.224208e-06 0.058622 0.941370 1.332795e-12 ... \n", - "145250 2.034389e-02 3.132630e-06 0.000268 0.083680 8.789707e-01 ... \n", - "145251 2.566383e-07 5.976933e-05 0.875834 0.123030 2.631276e-12 ... \n", - "145252 7.591362e-01 6.988637e-08 0.000032 0.037757 2.462795e-01 ... \n", - "\n", - " lgb_4 lgb_5 lgb_6 rf_0 rf_1 rf_2 rf_3 \\\n", - "0 1.726260e-240 0.000000e+00 3.324009e-02 0.984 0.000 0.000 0.0 \n", - "1 4.396802e-02 3.888647e-04 1.077146e-08 0.106 0.816 0.038 0.0 \n", - "2 3.329283e-02 6.693269e-03 3.213404e-09 0.008 0.950 0.002 0.0 \n", - "3 9.908823e-06 1.371448e-07 1.280235e-09 0.078 0.922 0.000 0.0 \n", - "4 1.273086e-03 7.811306e-07 5.594472e-10 0.004 0.988 0.000 0.0 \n", - "... ... ... ... ... ... ... ... \n", - "145248 6.191352e-08 0.000000e+00 4.958509e-06 0.968 0.032 0.000 0.0 \n", - "145249 7.656931e-06 3.415083e-47 2.271880e-07 0.018 0.972 0.000 0.0 \n", - "145250 2.052535e-04 3.687570e-02 1.393421e-09 0.000 0.040 0.946 0.0 \n", - "145251 2.521124e-07 5.749375e-08 1.135236e-03 0.992 0.008 0.000 0.0 \n", - "145252 2.927400e-06 7.159244e-01 8.608624e-140 0.000 0.018 0.110 0.0 \n", - "\n", - " rf_4 rf_5 rf_6 \n", - "0 0.000 0.000 0.016 \n", - "1 0.034 0.006 0.000 \n", - "2 0.032 0.008 0.000 \n", - "3 0.000 0.000 0.000 \n", - "4 0.008 0.000 0.000 \n", - "... ... ... ... \n", - "145248 0.000 0.000 0.000 \n", - "145249 0.010 0.000 0.000 \n", - "145250 0.000 0.014 0.000 \n", - "145251 0.000 0.000 0.000 \n", - "145252 0.000 0.872 0.000 \n", - "\n", - "[145253 rows x 21 columns]\n" - ] + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
xgb_0xgb_1xgb_2xgb_3xgb_4xgb_5xgb_6lgb_0lgb_1lgb_2...lgb_4lgb_5lgb_6rf_0rf_1rf_2rf_3rf_4rf_5rf_6
00.9914930.0002104.796058e-066.178684e-086.947614e-072.410490e-098.291459e-030.0000000.0000000.000000...0.000000e+000.0000001.000000e+000.990.000.000.00.000.000.01
10.0247310.9643726.387765e-044.205048e-081.006575e-021.879628e-044.830114e-060.0650730.8773540.001678...5.530599e-020.0005898.601484e-110.090.800.050.00.060.000.00
20.0007800.9797768.593459e-041.267791e-071.710379e-021.477527e-032.521008e-060.0051640.9338490.016355...3.657553e-020.0080570.000000e+000.010.970.000.00.010.010.00
30.0426950.9573042.283268e-084.387427e-084.175481e-074.406019e-086.909629e-100.0543920.9456080.000000...4.285638e-080.0000000.000000e+000.040.960.000.00.000.000.00
40.0004570.9993343.366338e-064.893879e-082.045808e-047.889498e-071.415576e-090.0013670.9958570.000000...2.776106e-030.0000000.000000e+000.000.990.000.00.010.000.00
\n", + "

5 rows × 21 columns

\n", + "
" + ], + "text/plain": [ + " xgb_0 xgb_1 xgb_2 xgb_3 xgb_4 xgb_5 \\\n", + "0 0.991493 0.000210 4.796058e-06 6.178684e-08 6.947614e-07 2.410490e-09 \n", + "1 0.024731 0.964372 6.387765e-04 4.205048e-08 1.006575e-02 1.879628e-04 \n", + "2 0.000780 0.979776 8.593459e-04 1.267791e-07 1.710379e-02 1.477527e-03 \n", + "3 0.042695 0.957304 2.283268e-08 4.387427e-08 4.175481e-07 4.406019e-08 \n", + "4 0.000457 0.999334 3.366338e-06 4.893879e-08 2.045808e-04 7.889498e-07 \n", + "\n", + " xgb_6 lgb_0 lgb_1 lgb_2 ... lgb_4 lgb_5 \\\n", + "0 8.291459e-03 0.000000 0.000000 0.000000 ... 0.000000e+00 0.000000 \n", + "1 4.830114e-06 0.065073 0.877354 0.001678 ... 5.530599e-02 0.000589 \n", + "2 2.521008e-06 0.005164 0.933849 0.016355 ... 3.657553e-02 0.008057 \n", + "3 6.909629e-10 0.054392 0.945608 0.000000 ... 4.285638e-08 0.000000 \n", + "4 1.415576e-09 0.001367 0.995857 0.000000 ... 2.776106e-03 0.000000 \n", + "\n", + " lgb_6 rf_0 rf_1 rf_2 rf_3 rf_4 rf_5 rf_6 \n", + "0 1.000000e+00 0.99 0.00 0.00 0.0 0.00 0.00 0.01 \n", + "1 8.601484e-11 0.09 0.80 0.05 0.0 0.06 0.00 0.00 \n", + "2 0.000000e+00 0.01 0.97 0.00 0.0 0.01 0.01 0.00 \n", + "3 0.000000e+00 0.04 0.96 0.00 0.0 0.00 0.00 0.00 \n", + "4 0.000000e+00 0.00 0.99 0.00 0.0 0.01 0.00 0.00 \n", + "\n", + "[5 rows x 21 columns]" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ "# 模型输出的测试集,7个特征对应7个标签的预测概率\n", - "print(stack_ds.X_test) " + "stack_ds.X_test.head()" ] }, { "cell_type": "code", - "execution_count": 60, + "execution_count": 30, "id": "b9db35dc", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Wall time: 1min 53s\n" + ] + } + ], "source": [ + "%%time\n", "# 用lr做最后一层\n", "stacker = Classifier(dataset=stack_ds, estimator=LogisticRegression, parameters={\"solver\": 'lbfgs', \"max_iter\": 1000},use_cache=False)\n", "predict_stack = stacker.predict()" @@ -507,7 +807,7 @@ }, { "cell_type": "code", - "execution_count": 64, + "execution_count": 32, "id": "a4a48219", "metadata": {}, "outputs": [ @@ -515,19 +815,19 @@ "name": "stdout", "output_type": "stream", "text": [ - "[[9.95173402e-01 2.67623709e-03 4.23846755e-08 ... 3.15435935e-05\n", - " 5.66194220e-06 2.11044140e-03]\n", - " [2.23612439e-02 9.70927685e-01 1.23929922e-03 ... 4.49727904e-03\n", - " 8.73983383e-04 9.97020226e-05]\n", - " [6.22588197e-03 9.89402233e-01 9.81655972e-04 ... 2.83331258e-03\n", - " 5.22139184e-04 3.45071569e-05]\n", + "[[9.99137967e-01 4.76574513e-04 3.32764186e-07 ... 3.03237700e-06\n", + " 1.62161211e-06 3.79326620e-04]\n", + " [2.00732175e-02 9.71682830e-01 1.51484266e-03 ... 5.69138554e-03\n", + " 9.40725428e-04 9.60822625e-05]\n", + " [5.54556002e-03 9.91048437e-01 8.04840682e-04 ... 2.11437934e-03\n", + " 4.56463787e-04 3.00919502e-05]\n", " ...\n", - " [5.36335125e-06 2.06267200e-03 9.90604140e-01 ... 8.55252386e-04\n", - " 4.18405061e-03 1.64678945e-05]\n", - " [9.96602824e-01 2.15991442e-03 7.27481581e-08 ... 3.63552051e-05\n", - " 6.80942632e-06 1.19199377e-03]\n", - " [5.89156494e-05 1.15333400e-03 1.09178439e-02 ... 3.09244417e-04\n", - " 9.85167196e-01 2.21261408e-05]]\n" + " [4.60179790e-06 1.78298095e-03 9.91553958e-01 ... 7.26752933e-04\n", + " 3.79135124e-03 1.53584401e-05]\n", + " [9.96307096e-01 2.43558944e-03 1.01596361e-07 ... 3.94985596e-05\n", + " 7.41569805e-06 1.20819024e-03]\n", + " [5.34671504e-05 7.62534718e-04 5.58323657e-03 ... 2.11410908e-04\n", + " 9.91805379e-01 1.69502656e-05]]\n" ] } ], @@ -553,7 +853,7 @@ }, { "cell_type": "code", - "execution_count": 65, + "execution_count": 33, "id": "a28806a0", "metadata": {}, "outputs": [ @@ -561,9 +861,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "0.9284696357390209\n", - "0.8890005714167694\n", - "0.9511404239499356\n" + "0.9254473229468583\n", + "0.8412562907478676\n", + "0.9535087055000585\n" ] } ], @@ -573,6 +873,28 @@ "print(accuracy_score(np.argmax(stack_ds.X_test.iloc[:, 14:].values, axis=1),y_test)) # RF" ] }, + { + "cell_type": "code", + "execution_count": 35, + "id": "1db92fec", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.9537840870756542\n", + "Wall time: 50.8 s\n" + ] + } + ], + "source": [ + "%%time\n", + "# 测试单模运行结果是否一致\n", + "rf_predict = rf_model(X_train, y_train, X_test, None)\n", + "print(accuracy_score(np.argmax(rf_predict, axis=1),y_test)) # RF" + ] + }, { "cell_type": "markdown", "id": "2e9423ce", @@ -583,7 +905,7 @@ }, { "cell_type": "code", - "execution_count": 70, + "execution_count": 37, "id": "d2b50ba4", "metadata": {}, "outputs": [ @@ -591,8 +913,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "主观根据结果blending: 0.9425209806337908\n", - "根据最优权重的blending: 0.9488616414118813\n" + "主观根据结果blending: 0.9504106627746071\n", + "根据最优权重的线性加权: 0.9530749795184953\n" ] } ], @@ -603,19 +925,74 @@ "rf_t = stack_ds.X_test.iloc[:, 14:].values\n", "\n", "# 根据分数好坏随机定\n", - "result = 0.3*xgb_t+0.2*lgb_t+0.5*rf_t\n", + "result = 0.2*xgb_t+0.1*lgb_t+0.7*rf_t\n", "print('主观根据结果blending:', accuracy_score(np.argmax(result, axis=1), y_test))\n", - "# 根据上面提供的最优权重 Best Weights: [0.36556831 0.00303401 0.63139768]\n", - "result = 0.36556831*xgb_t+0.00303401*lgb_t+0.63139768*rf_t\n", - "print('根据最优权重的blending:',accuracy_score(np.argmax(result, axis=1), y_test))" + "# 根据上面提供的最优权重 Best Weights: [2.53464919e-01 1.48562205e-20 7.46535081e-01]\n", + "result = 2.53464919e-01*xgb_t+1.48562205e-20*lgb_t+7.46535081e-01*rf_t\n", + "print('根据最优权重的线性加权:',accuracy_score(np.argmax(result, axis=1), y_test))" + ] + }, + { + "cell_type": "markdown", + "id": "e8ce00b2", + "metadata": {}, + "source": [ + "可以观察到最优权重比我们主观选权重更优,单对比单模结果反而下降了" + ] + }, + { + "cell_type": "markdown", + "id": "d4a12d6d", + "metadata": {}, + "source": [ + "### Blending的分数" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "ffe1cf4d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Wall time: 14min 10s\n" + ] + } + ], + "source": [ + "%%time\n", + "blend_ds = pipeline.blend(seed=111)\n", + "blender = Classifier(dataset=blend_ds, estimator=LogisticRegression, parameters={\"solver\": 'lbfgs', \"max_iter\": 1000},use_cache=False)\n", + "predict_blend = blender.predict()" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "506adffb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.9546859617357301\n" + ] + } + ], + "source": [ + "print(accuracy_score(np.argmax(predict_blend, axis=1), y_test))" ] }, { "cell_type": "markdown", - "id": "dfec8968", + "id": "e84c19ae", "metadata": {}, "source": [ - "可以观察到最优权重比我们主观选权重更优" + "使用Blending的分数有所提升" ] }, { @@ -623,12 +1000,12 @@ "id": "e8daf1e3", "metadata": {}, "source": [ - "### stacking的分数" + "### Stacking的分数" ] }, { "cell_type": "code", - "execution_count": 71, + "execution_count": 38, "id": "4930b407", "metadata": {}, "outputs": [ @@ -636,7 +1013,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "0.957439777491687\n" + "0.9589887988544127\n" ] } ], @@ -644,12 +1021,20 @@ "print(accuracy_score(np.argmax(predict_stack, axis=1), y_test))" ] }, + { + "cell_type": "markdown", + "id": "d5ad3c4b", + "metadata": {}, + "source": [ + "可以明显看到提升的效果" + ] + }, { "cell_type": "markdown", "id": "6311fbd7", "metadata": {}, "source": [ - "## 再说结论,该数据集(fetch_covtype)Stacking的方法更好" + "## 再说结论,该数据集(fetch_covtype)Stacking的方法相比Blending和线性加权更好" ] } ],