{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "### Permutation Importance" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.ensemble import RandomForestClassifier" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
DateTeamOpponentGoal ScoredBall Possession %AttemptsOn-TargetOff-TargetBlockedCorners...Yellow CardYellow & RedRedMan of the Match1st GoalRoundPSOGoals in PSOOwn goalsOwn goal Time
014-06-2018RussiaSaudi Arabia540137336...000Yes12.0Group StageNo0NaNNaN
114-06-2018Saudi ArabiaRussia06060332...000NoNaNGroup StageNo0NaNNaN
215-06-2018EgyptUruguay04383320...200NoNaNGroup StageNo0NaNNaN
315-06-2018UruguayEgypt157144645...000Yes89.0Group StageNo0NaNNaN
415-06-2018MoroccoIran064133645...100NoNaNGroup StageNo01.090.0
\n", "

5 rows × 27 columns

\n", "
" ], "text/plain": [ " Date Team Opponent Goal Scored Ball Possession % \\\n", "0 14-06-2018 Russia Saudi Arabia 5 40 \n", "1 14-06-2018 Saudi Arabia Russia 0 60 \n", "2 15-06-2018 Egypt Uruguay 0 43 \n", "3 15-06-2018 Uruguay Egypt 1 57 \n", "4 15-06-2018 Morocco Iran 0 64 \n", "\n", " Attempts On-Target Off-Target Blocked Corners ... Yellow Card \\\n", "0 13 7 3 3 6 ... 0 \n", "1 6 0 3 3 2 ... 0 \n", "2 8 3 3 2 0 ... 2 \n", "3 14 4 6 4 5 ... 0 \n", "4 13 3 6 4 5 ... 1 \n", "\n", " Yellow & Red Red Man of the Match 1st Goal Round PSO \\\n", "0 0 0 Yes 12.0 Group Stage No \n", "1 0 0 No NaN Group Stage No \n", "2 0 0 No NaN Group Stage No \n", "3 0 0 Yes 89.0 Group Stage No \n", "4 0 0 No NaN Group Stage No \n", "\n", " Goals in PSO Own goals Own goal Time \n", "0 0 NaN NaN \n", "1 0 NaN NaN \n", "2 0 NaN NaN \n", "3 0 NaN NaN \n", "4 0 1.0 90.0 \n", "\n", "[5 rows x 27 columns]" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data = pd.read_csv('data/FIFA 2018 Statistics.csv') # 足球赛事数据集\n", "data.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "怎么知道每个特征对预测结果的影响有多大?\n", "\n", "破坏特征:比如把验证集中 Goal Scored 这一列的取值顺序随机打乱,其它列和标签保持不变,然后比较打乱前和打乱后模型的准确率\n", "\n", "- 如果打乱后 ≈ 打乱前,说明模型基本不依赖这个特征,它对结果没影响\n", "- 如果打乱后 < 打乱前,说明这个特征对预测有贡献,下降得越多越重要\n", "- 如果打乱后 > 打乱前(重要性为负),说明打乱后的预测碰巧更准,通常只是随机波动,在小数据集上很常见,这类特征基本不重要" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0 True\n", "1 False\n", "2 False\n", "3 True\n", "4 False\n", "Name: Man of the Match, dtype: bool" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y = (data['Man of the Match'] == 'Yes') # 转换标签\n", "y[:5]" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Goal ScoredBall Possession %AttemptsOn-TargetOff-TargetBlockedCornersOffsidesFree KicksSavesPass Accuracy %PassesDistance Covered (Kms)Fouls CommittedYellow CardYellow & RedRedGoals in PSO
0540137336311078306118220000
106060332125286511105100000
20438332017378395112122000
315714464511338658911160000
4064133645014286433101221000
\n", "
" ], "text/plain": [ " Goal Scored Ball Possession % Attempts On-Target Off-Target Blocked \\\n", "0 5 40 13 7 3 3 \n", "1 0 60 6 0 3 3 \n", "2 0 43 8 3 3 2 \n", "3 1 57 14 4 6 4 \n", "4 0 64 13 3 6 4 \n", "\n", " Corners Offsides Free Kicks Saves Pass Accuracy % Passes \\\n", "0 6 3 11 0 78 306 \n", "1 2 1 25 2 86 511 \n", "2 0 1 7 3 78 395 \n", "3 5 1 13 3 86 589 \n", "4 5 0 14 2 86 433 \n", "\n", " Distance Covered (Kms) Fouls Committed Yellow Card Yellow & Red Red \\\n", "0 118 22 0 0 0 \n", "1 105 10 0 0 0 \n", "2 112 12 2 0 0 \n", "3 111 6 0 0 0 \n", "4 101 22 1 0 0 \n", "\n", " Goals in PSO \n", "0 0 \n", "1 0 \n", "2 0 \n", "3 0 \n", "4 0 " ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Use only the int64 columns of the raw table as model features.\n", "feature_names = data.select_dtypes(include=[np.int64]).columns.tolist()\n", "X = data[feature_names]\n", "X.head()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Hold out a validation split; the permutation importance below is computed on it.\n", "train_X, val_X, train_y, val_y = train_test_split(X, y, random_state=1)\n", "\n", "# Pin n_estimators explicitly: the library default changes from 10 to 100 in\n", "# scikit-learn 0.22, which would silently change the fitted forest on a re-run.\n", "# 10 is the value this notebook's recorded results were produced with.\n", "my_model = RandomForestClassifier(n_estimators=10, random_state=0).fit(train_X, train_y)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", "\n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
WeightFeature
\n", " 0.0750\n", " \n", " ± 0.1159\n", " \n", " \n", " Goal Scored\n", "
\n", " 0.0625\n", " \n", " ± 0.0791\n", " \n", " \n", " Corners\n", "
\n", " 0.0437\n", " \n", " ± 0.0500\n", " \n", " \n", " Distance Covered (Kms)\n", "
\n", " 0.0375\n", " \n", " ± 0.0729\n", " \n", " \n", " On-Target\n", "
\n", " 0.0375\n", " \n", " ± 0.0468\n", " \n", " \n", " Free Kicks\n", "
\n", " 0.0187\n", " \n", " ± 0.0306\n", " \n", " \n", " Blocked\n", "
\n", " 0.0125\n", " \n", " ± 0.0750\n", " \n", " \n", " Pass Accuracy %\n", "
\n", " 0.0125\n", " \n", " ± 0.0500\n", " \n", " \n", " Yellow Card\n", "
\n", " 0.0063\n", " \n", " ± 0.0468\n", " \n", " \n", " Saves\n", "
\n", " 0.0063\n", " \n", " ± 0.0250\n", " \n", " \n", " Offsides\n", "
\n", " 0.0063\n", " \n", " ± 0.1741\n", " \n", " \n", " Off-Target\n", "
\n", " 0.0000\n", " \n", " ± 0.1046\n", " \n", " \n", " Passes\n", "
\n", " 0\n", " \n", " ± 0.0000\n", " \n", " \n", " Red\n", "
\n", " 0\n", " \n", " ± 0.0000\n", " \n", " \n", " Yellow & Red\n", "
\n", " 0\n", " \n", " ± 0.0000\n", " \n", " \n", " Goals in PSO\n", "
\n", " -0.0312\n", " \n", " ± 0.0884\n", " \n", " \n", " Fouls Committed\n", "
\n", " -0.0375\n", " \n", " ± 0.0919\n", " \n", " \n", " Attempts\n", "
\n", " -0.0500\n", " \n", " ± 0.0500\n", " \n", " \n", " Ball Possession %\n", "
\n", " \n", "\n", " \n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import eli5 # pip install eli5\n", "from eli5.sklearn import PermutationImportance\n", "\n", "perm = PermutationImportance(my_model,random_state=1).fit(val_X, val_y)\n", "\n", "eli5.show_weights(perm, feature_names=val_X.columns.tolist())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Partial Dependence Plots\n", "\n", "特征重要性展示了每个特征发挥的作用情况,partial dependence plots 可以展示一个特征怎么影响预测结果\n", "\n", "模型建立完成后进行使用,概述如下:\n", "