{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Credit Card Fraud Detection\n",
"Using credit card transaction records, build a classification model that predicts which transactions are fraudulent and which are normal.\n",
"\n",
"My cleaned copy of the data: https://pan.baidu.com/s/18vPGelYCXGqp5OCWZWz36A (extraction code: de0f)\n",
"\n",
"Kaggle dataset: https://www.kaggle.com/mlg-ulb/creditcardfraud#creditcard.csv\n",
"\n",
"Kesci dataset: https://www.kesci.com/mw/dataset/5b56a592fc7e9000103c0442"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Task goal:\n",
"Classify the transactions in the dataset as normal or fraudulent, and predict a 0/1 label for each record in the test data."
]
},
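{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Workflow:\n",
"* Load the data and look for problems\n",
"* Propose solutions for those problems\n",
"* Split the dataset\n",
"* Compare evaluation methods\n",
"* Logistic regression model\n",
"* Analyze the modeling results\n",
"* Compare the approaches\n",
"\n",
"### Main issues addressed:\n",
" (1) We start by inspecting the data and find that the classes are heavily imbalanced. Before any modeling task, always check the data first, see what is wrong with it, and choose a solution for each problem found.\n",
" (2) We propose two approaches, under-sampling and over-sampling, and run them as side-by-side experiments. Whenever a problem comes up, do not commit to a single path: without comparison there is no optimization. The usual practice is to build a baseline model, compare several methods against it, and keep the one that fits best. Thinking ahead and preparing alternatives before starting is what leaves room to choose among results.\n",
" (3) Before modeling, the data needs preprocessing such as standardization and missing-value imputation; these steps are required. Because the features are already given, feature engineering is not covered here and will be introduced gradually in later projects. Data preprocessing is in fact the most important and most difficult stage of the whole task: the data sets the upper bound, and the model only approaches it.\n",
" (4) Choose the evaluation method before modeling. The purpose of modeling is to get results, but the best result never comes on the first try, so many attempts are needed and a suitable evaluation method is essential. Common choices are AUC, the ROC curve, recall, and precision, or a custom metric defined for the problem at hand.\n",
" (5) Choose a suitable algorithm. Here we use logistic regression. It is used less often today, but it remains a highly representative algorithm in finance, where its simplicity, derivability, and interpretability make it popular.\n",
" (6) Hyperparameter tuning also matters: different settings lead to different results. Later projects go into more tuning detail; consult the library's API documentation to understand each parameter before choosing its value.\n",
" (7) Results must always be tied to the actual task. A model that looks good offline (in development) can perform very differently once deployed, so a test environment is indispensable."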
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# Import the libraries\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"# Render plots inline in this notebook\n",
"%matplotlib inline\n",
"\n",
"import warnings  # suppress routine warnings to keep the output short\n",
"warnings.filterwarnings('ignore')"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Time</th>\n",
" <th>V1</th>\n",
" <th>V2</th>\n",
" <th>V3</th>\n",
" <th>V4</th>\n",
" <th>V5</th>\n",
" <th>V6</th>\n",
" <th>V7</th>\n",
" <th>V8</th>\n",
" <th>V9</th>\n",
" <th>...</th>\n",
" <th>V21</th>\n",
" <th>V22</th>\n",
" <th>V23</th>\n",
" <th>V24</th>\n",
" <th>V25</th>\n",
" <th>V26</th>\n",
" <th>V27</th>\n",
" <th>V28</th>\n",
" <th>Amount</th>\n",
" <th>Class</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.0</td>\n",
" <td>-1.359807</td>\n",
" <td>-0.072781</td>\n",
" <td>2.536347</td>\n",
" <td>1.378155</td>\n",
" <td>-0.338321</td>\n",
" <td>0.462388</td>\n",
" <td>0.239599</td>\n",
" <td>0.098698</td>\n",
" <td>0.363787</td>\n",
" <td>...</td>\n",
" <td>-0.018307</td>\n",
" <td>0.277838</td>\n",
" <td>-0.110474</td>\n",
" <td>0.066928</td>\n",
" <td>0.128539</td>\n",
" <td>-0.189115</td>\n",
" <td>0.133558</td>\n",
" <td>-0.021053</td>\n",
" <td>149.62</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0.0</td>\n",
" <td>1.191857</td>\n",
" <td>0.266151</td>\n",
" <td>0.166480</td>\n",
" <td>0.448154</td>\n",
" <td>0.060018</td>\n",
" <td>-0.082361</td>\n",
" <td>-0.078803</td>\n",
" <td>0.085102</td>\n",
" <td>-0.255425</td>\n",
" <td>...</td>\n",
" <td>-0.225775</td>\n",
" <td>-0.638672</td>\n",
" <td>0.101288</td>\n",
" <td>-0.339846</td>\n",
" <td>0.167170</td>\n",
" <td>0.125895</td>\n",
" <td>-0.008983</td>\n",
" <td>0.014724</td>\n",
" <td>2.69</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1.0</td>\n",
" <td>-1.358354</td>\n",
" <td>-1.340163</td>\n",
" <td>1.773209</td>\n",
" <td>0.379780</td>\n",
" <td>-0.503198</td>\n",
" <td>1.800499</td>\n",
" <td>0.791461</td>\n",
" <td>0.247676</td>\n",
" <td>-1.514654</td>\n",
" <td>...</td>\n",
" <td>0.247998</td>\n",
" <td>0.771679</td>\n",
" <td>0.909412</td>\n",
" <td>-0.689281</td>\n",
" <td>-0.327642</td>\n",
" <td>-0.139097</td>\n",
" <td>-0.055353</td>\n",
" <td>-0.059752</td>\n",
" <td>378.66</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1.0</td>\n",
" <td>-0.966272</td>\n",
" <td>-0.185226</td>\n",
" <td>1.792993</td>\n",
" <td>-0.863291</td>\n",
" <td>-0.010309</td>\n",
" <td>1.247203</td>\n",
" <td>0.237609</td>\n",
" <td>0.377436</td>\n",
" <td>-1.387024</td>\n",
" <td>...</td>\n",
" <td>-0.108300</td>\n",
" <td>0.005274</td>\n",
" <td>-0.190321</td>\n",
" <td>-1.175575</td>\n",
" <td>0.647376</td>\n",
" <td>-0.221929</td>\n",
" <td>0.062723</td>\n",
" <td>0.061458</td>\n",
" <td>123.50</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>2.0</td>\n",
" <td>-1.158233</td>\n",
" <td>0.877737</td>\n",
" <td>1.548718</td>\n",
" <td>0.403034</td>\n",
" <td>-0.407193</td>\n",
" <td>0.095921</td>\n",
" <td>0.592941</td>\n",
" <td>-0.270533</td>\n",
" <td>0.817739</td>\n",
" <td>...</td>\n",
" <td>-0.009431</td>\n",
" <td>0.798278</td>\n",
" <td>-0.137458</td>\n",
" <td>0.141267</td>\n",
" <td>-0.206010</td>\n",
" <td>0.502292</td>\n",
" <td>0.219422</td>\n",
" <td>0.215153</td>\n",
" <td>69.99</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 31 columns</p>\n",
"</div>"
],
"text/plain": [
" Time V1 V2 V3 V4 V5 V6 V7 \\\n",
"0 0.0 -1.359807 -0.072781 2.536347 1.378155 -0.338321 0.462388 0.239599 \n",
"1 0.0 1.191857 0.266151 0.166480 0.448154 0.060018 -0.082361 -0.078803 \n",
"2 1.0 -1.358354 -1.340163 1.773209 0.379780 -0.503198 1.800499 0.791461 \n",
"3 1.0 -0.966272 -0.185226 1.792993 -0.863291 -0.010309 1.247203 0.237609 \n",
"4 2.0 -1.158233 0.877737 1.548718 0.403034 -0.407193 0.095921 0.592941 \n",
"\n",
" V8 V9 ... V21 V22 V23 V24 V25 \\\n",
"0 0.098698 0.363787 ... -0.018307 0.277838 -0.110474 0.066928 0.128539 \n",
"1 0.085102 -0.255425 ... -0.225775 -0.638672 0.101288 -0.339846 0.167170 \n",
"2 0.247676 -1.514654 ... 0.247998 0.771679 0.909412 -0.689281 -0.327642 \n",
"3 0.377436 -1.387024 ... -0.108300 0.005274 -0.190321 -1.175575 0.647376 \n",
"4 -0.270533 0.817739 ... -0.009431 0.798278 -0.137458 0.141267 -0.206010 \n",
"\n",
" V26 V27 V28 Amount Class \n",
"0 -0.189115 0.133558 -0.021053 149.62 0 \n",
"1 0.125895 -0.008983 0.014724 2.69 0 \n",
"2 -0.139097 -0.055353 -0.059752 378.66 0 \n",
"3 -0.221929 0.062723 0.061458 123.50 0 \n",
"4 0.502292 0.219422 0.215153 69.99 0 \n",
"\n",
"[5 rows x 31 columns]"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load the data\n",
"data = pd.read_csv(\"data/creditcard.csv\")\n",
"data.head()"
]
},
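{
"cell_type": "markdown",
"metadata": {},
"source": [
"### About the data:\n",
"The dataset contains transactions made with credit cards by European cardholders in September 2013. It covers transactions that occurred over two days, with 492 frauds out of 284,807 transactions. The dataset is highly imbalanced: the positive class (fraud) accounts for 0.172% of all transactions. Because of confidentiality constraints, the original features and further background information cannot be provided. Features V1, V2, ..., V28 are the principal components obtained with PCA; the only features not transformed with PCA are 'Time' and 'Amount'. 'Time' holds the seconds elapsed between each transaction and the first transaction in the dataset. 'Class' is the response variable: 1 for fraud and 0 otherwise."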
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 284315\n",
"1 492\n",
"Name: Class, dtype: int64\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Distribution of the class labels\n",
"count_classes = data['Class'].value_counts().sort_index()  # count the rows in each class\n",
"count_classes.plot(kind='bar')  # bar chart\n",
"plt.title(\"Fraud class histogram\")\n",
"plt.xlabel(\"Class\")\n",
"plt.ylabel(\"Frequency\")\n",
"print(count_classes)"
]
},
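{
"cell_type": "markdown",
"metadata": {},
"source": [
"The two classes differ enormously in size: there are only 492 positive samples (Class 1), against more than 280,000 negative samples. If the model is trained on this data directly, it can easily learn that predicting every sample as negative already yields an accuracy of about 99.83%.\n",
"\n",
"We must not let the model learn this kind of shortcut.\n",
"\n",
"There are two ways to fix this:\n",
"* Make class 1 as large as class 0, i.e. about 280,000 samples of class 1 as well (over-sampling).\n",
"* Make class 0 as small as class 1, i.e. keep only 492 of the 280,000 (under-sampling).\n",
"\n",
"Comparing the two:\n",
"* The first requires synthesizing data. Synthetic samples are not real, and they can hurt the model's performance when it predicts on real data.\n",
"* The second discards real data, so the model has less to learn from and its capability also suffers."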
]
},
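{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the accuracy trap concrete, here is a minimal sketch that scores a hypothetical model predicting class 0 for every row. It uses only the class counts printed above; the variable names are illustrative and not part of the original notebook.\n",
"\n",
"```python\n",
"# Class counts taken from the value_counts output above\n",
"n_normal = 284315\n",
"n_fraud = 492\n",
"n_total = n_normal + n_fraud\n",
"\n",
"# A model that labels every transaction as normal (class 0)\n",
"baseline_accuracy = n_normal / n_total  # every normal row counts as correct\n",
"baseline_recall = 0 / n_fraud           # no fraud is ever found\n",
"\n",
"print(f'accuracy: {baseline_accuracy:.4%}')\n",
"print(f'recall:   {baseline_recall:.0%}')\n",
"```\n",
"\n",
"Accuracy alone therefore cannot separate a useful model from a useless one on this data."
]
},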
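{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Standardizing the data\n",
"\n",
"The Amount column above still holds raw values, which are far larger than the values in the other columns. This can bias the model into treating Amount as a very important feature (see the earlier chapter on regression analysis), so it needs to be standardized. Standardization is a rescaling: large values stay relatively large and small values stay relatively small.\n",
"\n",
"For logistic regression, all of the training features should be standardized."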
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>V1</th>\n",
" <th>V2</th>\n",
" <th>V3</th>\n",
" <th>V4</th>\n",
" <th>V5</th>\n",
" <th>V6</th>\n",
" <th>V7</th>\n",
" <th>V8</th>\n",
" <th>V9</th>\n",
" <th>V10</th>\n",
" <th>...</th>\n",
" <th>V21</th>\n",
" <th>V22</th>\n",
" <th>V23</th>\n",
" <th>V24</th>\n",
" <th>V25</th>\n",
" <th>V26</th>\n",
" <th>V27</th>\n",
" <th>V28</th>\n",
" <th>Class</th>\n",
" <th>normAmount</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>-1.359807</td>\n",
" <td>-0.072781</td>\n",
" <td>2.536347</td>\n",
" <td>1.378155</td>\n",
" <td>-0.338321</td>\n",
" <td>0.462388</td>\n",
" <td>0.239599</td>\n",
" <td>0.098698</td>\n",
" <td>0.363787</td>\n",
" <td>0.090794</td>\n",
" <td>...</td>\n",
" <td>-0.018307</td>\n",
" <td>0.277838</td>\n",
" <td>-0.110474</td>\n",
" <td>0.066928</td>\n",
" <td>0.128539</td>\n",
" <td>-0.189115</td>\n",
" <td>0.133558</td>\n",
" <td>-0.021053</td>\n",
" <td>0</td>\n",
" <td>0.244964</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1.191857</td>\n",
" <td>0.266151</td>\n",
" <td>0.166480</td>\n",
" <td>0.448154</td>\n",
" <td>0.060018</td>\n",
" <td>-0.082361</td>\n",
" <td>-0.078803</td>\n",
" <td>0.085102</td>\n",
" <td>-0.255425</td>\n",
" <td>-0.166974</td>\n",
" <td>...</td>\n",
" <td>-0.225775</td>\n",
" <td>-0.638672</td>\n",
" <td>0.101288</td>\n",
" <td>-0.339846</td>\n",
" <td>0.167170</td>\n",
" <td>0.125895</td>\n",
" <td>-0.008983</td>\n",
" <td>0.014724</td>\n",
" <td>0</td>\n",
" <td>-0.342475</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>-1.358354</td>\n",
" <td>-1.340163</td>\n",
" <td>1.773209</td>\n",
" <td>0.379780</td>\n",
" <td>-0.503198</td>\n",
" <td>1.800499</td>\n",
" <td>0.791461</td>\n",
" <td>0.247676</td>\n",
" <td>-1.514654</td>\n",
" <td>0.207643</td>\n",
" <td>...</td>\n",
" <td>0.247998</td>\n",
" <td>0.771679</td>\n",
" <td>0.909412</td>\n",
" <td>-0.689281</td>\n",
" <td>-0.327642</td>\n",
" <td>-0.139097</td>\n",
" <td>-0.055353</td>\n",
" <td>-0.059752</td>\n",
" <td>0</td>\n",
" <td>1.160686</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>-0.966272</td>\n",
" <td>-0.185226</td>\n",
" <td>1.792993</td>\n",
" <td>-0.863291</td>\n",
" <td>-0.010309</td>\n",
" <td>1.247203</td>\n",
" <td>0.237609</td>\n",
" <td>0.377436</td>\n",
" <td>-1.387024</td>\n",
" <td>-0.054952</td>\n",
" <td>...</td>\n",
" <td>-0.108300</td>\n",
" <td>0.005274</td>\n",
" <td>-0.190321</td>\n",
" <td>-1.175575</td>\n",
" <td>0.647376</td>\n",
" <td>-0.221929</td>\n",
" <td>0.062723</td>\n",
" <td>0.061458</td>\n",
" <td>0</td>\n",
" <td>0.140534</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>-1.158233</td>\n",
" <td>0.877737</td>\n",
" <td>1.548718</td>\n",
" <td>0.403034</td>\n",
" <td>-0.407193</td>\n",
" <td>0.095921</td>\n",
" <td>0.592941</td>\n",
" <td>-0.270533</td>\n",
" <td>0.817739</td>\n",
" <td>0.753074</td>\n",
" <td>...</td>\n",
" <td>-0.009431</td>\n",
" <td>0.798278</td>\n",
" <td>-0.137458</td>\n",
" <td>0.141267</td>\n",
" <td>-0.206010</td>\n",
" <td>0.502292</td>\n",
" <td>0.219422</td>\n",
" <td>0.215153</td>\n",
" <td>0</td>\n",
" <td>-0.073403</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 30 columns</p>\n",
"</div>"
],
"text/plain": [
" V1 V2 V3 V4 V5 V6 V7 \\\n",
"0 -1.359807 -0.072781 2.536347 1.378155 -0.338321 0.462388 0.239599 \n",
"1 1.191857 0.266151 0.166480 0.448154 0.060018 -0.082361 -0.078803 \n",
"2 -1.358354 -1.340163 1.773209 0.379780 -0.503198 1.800499 0.791461 \n",
"3 -0.966272 -0.185226 1.792993 -0.863291 -0.010309 1.247203 0.237609 \n",
"4 -1.158233 0.877737 1.548718 0.403034 -0.407193 0.095921 0.592941 \n",
"\n",
" V8 V9 V10 ... V21 V22 V23 V24 \\\n",
"0 0.098698 0.363787 0.090794 ... -0.018307 0.277838 -0.110474 0.066928 \n",
"1 0.085102 -0.255425 -0.166974 ... -0.225775 -0.638672 0.101288 -0.339846 \n",
"2 0.247676 -1.514654 0.207643 ... 0.247998 0.771679 0.909412 -0.689281 \n",
"3 0.377436 -1.387024 -0.054952 ... -0.108300 0.005274 -0.190321 -1.175575 \n",
"4 -0.270533 0.817739 0.753074 ... -0.009431 0.798278 -0.137458 0.141267 \n",
"\n",
" V25 V26 V27 V28 Class normAmount \n",
"0 0.128539 -0.189115 0.133558 -0.021053 0 0.244964 \n",
"1 0.167170 0.125895 -0.008983 0.014724 0 -0.342475 \n",
"2 -0.327642 -0.139097 -0.055353 -0.059752 0 1.160686 \n",
"3 0.647376 -0.221929 0.062723 0.061458 0 0.140534 \n",
"4 -0.206010 0.502292 0.219422 0.215153 0 -0.073403 \n",
"\n",
"[5 rows x 30 columns]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# X = (x - μ) / σ, so the rescaled column has mean 0 and variance 1\n",
"# fit_transform computes μ and σ of data['Amount'] and applies them to that same column.\n",
"data['normAmount'] = StandardScaler().fit_transform(data['Amount'].values.reshape(-1, 1))\n",
"data = data.drop(['Time', 'Amount'], axis=1)  # Time is not used here, so drop it as well\n",
"data.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Under-sampling"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Share of normal samples: 0.5\n",
"Share of fraud samples: 0.5\n",
"Total samples after under-sampling: 984\n"
]
}
],
"source": [
"X = data.loc[:, data.columns != \"Class\"]  # features\n",
"y = data.loc[:, data.columns == \"Class\"]  # labels\n",
"\n",
"# Indices of the positive (fraud) samples\n",
"number_records_fraud = len(data[data.Class == 1])\n",
"fraud_indices = np.array(data[data.Class == 1].index)\n",
"\n",
"# Indices of the negative (normal) samples\n",
"normal_indices = data[data.Class == 0].index\n",
"\n",
"# Randomly draw as many normal samples as there are fraud samples, without replacement\n",
"random_normal_indices = np.random.choice(normal_indices, number_records_fraud, replace=False)\n",
"random_normal_indices = np.array(random_normal_indices)\n",
"\n",
"# Combine the fraud indices with the sampled normal indices\n",
"under_sample_indices = np.concatenate([fraud_indices, random_normal_indices])\n",
"\n",
"# Select the under-sampled rows by index label\n",
"under_sample_data = data.loc[under_sample_indices, :]\n",
"\n",
"X_under_sample = under_sample_data.loc[:, under_sample_data.columns != \"Class\"]\n",
"Y_under_sample = under_sample_data.loc[:, under_sample_data.columns == \"Class\"]\n",
"\n",
"print(\"Share of normal samples:\", len(under_sample_data[under_sample_data.Class == 0])/len(under_sample_data))\n",
"print(\"Share of fraud samples:\", len(under_sample_data[under_sample_data.Class == 1])/len(under_sample_data))\n",
"print(\"Total samples after under-sampling:\", len(under_sample_data))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Dataset split for cross-validation"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
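{
"name": "stdout",
"output_type": "stream",
"text": [
"Training samples (original data): 199364\n",
"Test samples (original data): 85443\n",
"Total samples (original data): 284807\n",
"\n",
"\n",
"Training samples (under-sampled): 688\n",
"Test samples (under-sampled): 296\n",
"Total samples (under-sampled): 984\n"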
]
}
],
"source": [
"from sklearn.model_selection import train_test_split  # splits a dataset\n",
"\n",
"# Split the data into 70% training and 30% test\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)\n",
"\n",
"print(\"Training samples (original data):\", len(X_train))\n",
"print(\"Test samples (original data):\", len(X_test))\n",
"print(\"Total samples (original data):\", len(X_train)+len(X_test))\n",
"\n",
"# Split the under-sampled dataset the same way\n",
"X_train_undersample, X_test_undersample, y_train_undersample, y_test_undersample = train_test_split(\n",
"    X_under_sample, Y_under_sample, test_size=0.3, random_state=0)\n",
"\n",
"print(\"\\n\")\n",
"print(\"Training samples (under-sampled):\", len(X_train_undersample))\n",
"print(\"Test samples (under-sampled):\", len(X_test_undersample))\n",
"print(\"Total samples (under-sampled):\", len(X_train_undersample)+len(X_test_undersample))"
]
},
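{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Evaluation metric: recall\n",
"Because the classes are extremely imbalanced, accuracy is misleading: as noted above, a model that labels every sample as normal already reaches about 99.83% accuracy. We use recall instead, i.e. the fraction of fraud samples that the model finds.\n",
"\n",
"Recall = TP/(TP+FN)\n",
"\n",
"* TP (true positive): correctly predicted as positive\n",
"* TN (true negative): correctly predicted as negative\n",
"* FP (false positive): incorrectly predicted as positive\n",
"* FN (false negative): incorrectly predicted as negative"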
]
},
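{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a small illustration of the definitions above, the sketch below computes recall, and the precision that is used later in this notebook, from confusion-matrix counts. The helper name `recall_and_precision` and the counts are hypothetical, chosen only for this example.\n",
"\n",
"```python\n",
"def recall_and_precision(tp, fp, fn):\n",
"    '''Return (recall, precision) from confusion-matrix counts.'''\n",
"    recall = tp / (tp + fn)     # share of actual positives that were found\n",
"    precision = tp / (tp + fp)  # share of predicted positives that are correct\n",
"    return recall, precision\n",
"\n",
"# Hypothetical counts, for illustration only\n",
"recall, precision = recall_and_precision(tp=90, fp=30, fn=10)\n",
"print(recall, precision)\n",
"```\n",
"\n",
"With these counts the model finds 90 of the 100 actual positives, while a quarter of its positive predictions are false alarms."
]
},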
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Logistic regression from sklearn's linear_model module\n",
"from sklearn.linear_model import LogisticRegression\n",
"\n",
"# K-fold cross-validation from the model_selection module.\n",
"# cross_val_score returns the score of each cross-validation fold.\n",
"from sklearn.model_selection import KFold, cross_val_score\n",
"\n",
"# Confusion matrix and recall.\n",
"# classification_report builds a text report of the main classification metrics (precision, recall, F1 score, etc.) for each class.\n",
"from sklearn.metrics import confusion_matrix, recall_score, classification_report\n",
"\n",
"# cross_val_predict is called the same way as cross_val_score, but returns the cross-validated predictions instead of scores.\n",
"from sklearn.model_selection import cross_val_predict"
]
},
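{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Regularization penalty: improving generalization\n",
"Overfitting typically appears when there is little data but many features, which is the situation here: the fraud samples are very few, while there are 29 features (V1 to V28 plus normAmount)."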
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# K-fold helper printing_Kfold_scores (in practice a library routine can be called directly)\n",
"def printing_Kfold_scores(x_train_data, y_train_data):\n",
"    fold = KFold(5, shuffle=False)  # shuffle=False: keep the rows in their original order\n",
"\n",
"    # Candidate regularization strengths; a larger C means a weaker penalty\n",
"    c_param_range = [0.01, 0.1, 1, 10, 100]\n",
"    # Table that collects the results, one row per candidate C\n",
"    results_table = pd.DataFrame(index=range(len(c_param_range)), columns=['C_parameter', 'Mean recall score'])\n",
"    results_table['C_parameter'] = c_param_range\n",
"\n",
"    # fold.split yields two index arrays per fold: training = indices[0], validation = indices[1]\n",
"    j = 0\n",
"    # Loop over the 5 candidate values of C, running 5-fold cross-validation for each\n",
"    for c_param in c_param_range:\n",
"        print('-------------------------------------------')\n",
"        print('C parameter: ', c_param)\n",
"        print('-------------------------------------------')\n",
"\n",
"        # Recall of each fold (5 values per C)\n",
"        recall_accs = []\n",
"\n",
"        # Run the cross-validation one fold at a time\n",
"        for iteration, indices in enumerate(fold.split(x_train_data)):\n",
"\n",
"            # Choose the model and its parameters. L1 regularization guards against overfitting;\n",
"            # liblinear is a solver that supports the L1 penalty.\n",
"            lr = LogisticRegression(C=c_param, penalty='l1', solver='liblinear')\n",
"\n",
"            # Train on the training part of the fold, so both X and y use indices[0]\n",
"            lr.fit(x_train_data.iloc[indices[0], :].values, y_train_data.iloc[indices[0], :].values.ravel())\n",
"\n",
"            # Predict on the validation part of the fold, indices[1]\n",
"            y_pred_undersample = lr.predict(x_train_data.iloc[indices[1], :].values)\n",
"\n",
"            # Evaluate: recall_score takes the true labels first, then the predictions\n",
"            recall_acc = recall_score(y_train_data.iloc[indices[1], :].values, y_pred_undersample)\n",
"            # Keep each fold's result so the mean can be computed afterwards\n",
"            recall_accs.append(recall_acc)\n",
"            print('Iteration ', iteration, ': recall = ', recall_acc)\n",
"\n",
"        # After all folds have run, store the mean recall for this C\n",
"        results_table.loc[j, 'Mean recall score'] = np.mean(recall_accs)\n",
"        j += 1\n",
"        print('')\n",
"        print('Mean recall ', np.mean(recall_accs))\n",
"        print('')\n",
"\n",
"    # Pick the C with the highest mean recall\n",
"    best_c = results_table.loc[results_table['Mean recall score'].astype('float32').idxmax()]['C_parameter']\n",
"\n",
"    # Print the best setting\n",
"    print('***********************************')\n",
"    print('Best C parameter = ', best_c)\n",
"    print('***********************************')\n",
"\n",
"    return best_c"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Cross-validation results for the different values of C"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-------------------------------------------\n",
"C parameter: 0.01\n",
"-------------------------------------------\n",
"Iteration 0 : recall = 0.9315068493150684\n",
"Iteration 1 : recall = 0.9178082191780822\n",
"Iteration 2 : recall = 1.0\n",
"Iteration 3 : recall = 0.972972972972973\n",
"Iteration 4 : recall = 0.9545454545454546\n",
"\n",
"Mean recall 0.9553666992023157\n",
"\n",
"-------------------------------------------\n",
"C parameter: 0.1\n",
"-------------------------------------------\n",
"Iteration 0 : recall = 0.8493150684931506\n",
"Iteration 1 : recall = 0.863013698630137\n",
"Iteration 2 : recall = 0.9322033898305084\n",
"Iteration 3 : recall = 0.9459459459459459\n",
"Iteration 4 : recall = 0.9090909090909091\n",
"\n",
"Mean recall 0.8999138023981302\n",
"\n",
"-------------------------------------------\n",
"C parameter: 1\n",
"-------------------------------------------\n",
"Iteration 0 : recall = 0.863013698630137\n",
"Iteration 1 : recall = 0.8904109589041096\n",
"Iteration 2 : recall = 0.9661016949152542\n",
"Iteration 3 : recall = 0.9459459459459459\n",
"Iteration 4 : recall = 0.8939393939393939\n",
"\n",
"Mean recall 0.9118823384669682\n",
"\n",
"-------------------------------------------\n",
"C parameter: 10\n",
"-------------------------------------------\n",
"Iteration 0 : recall = 0.863013698630137\n",
"Iteration 1 : recall = 0.8767123287671232\n",
"Iteration 2 : recall = 0.9830508474576272\n",
"Iteration 3 : recall = 0.9459459459459459\n",
"Iteration 4 : recall = 0.9090909090909091\n",
"\n",
"Mean recall 0.9155627459783485\n",
"\n",
"-------------------------------------------\n",
"C parameter: 100\n",
"-------------------------------------------\n",
"Iteration 0 : recall = 0.8767123287671232\n",
"Iteration 1 : recall = 0.8767123287671232\n",
"Iteration 2 : recall = 0.9830508474576272\n",
"Iteration 3 : recall = 0.9459459459459459\n",
"Iteration 4 : recall = 0.9393939393939394\n",
"\n",
"Mean recall 0.9243630780663519\n",
"\n",
"***********************************\n",
"Best C parameter = 0.01\n",
"***********************************\n"
]
}
],
"source": [
"best_c = printing_Kfold_scores(X_train_undersample,y_train_undersample)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"import itertools\n",
"\n",
"# Confusion matrix\n",
"def plot_confusion_matrix(cm, classes,\n",
"                          title='Confusion matrix',\n",
"                          cmap=plt.cm.Blues):\n",
"    \"\"\"\n",
"    Plot a confusion matrix.\n",
"    \"\"\"\n",
"    plt.imshow(cm, interpolation='nearest', cmap=cmap)\n",
"    plt.title(title)\n",
"    plt.colorbar()\n",
"    tick_marks = np.arange(len(classes))\n",
"    plt.xticks(tick_marks, classes, rotation=0)\n",
"    plt.yticks(tick_marks, classes)\n",
"\n",
"    # Write each count into its cell, in white on dark cells and black on light ones\n",
"    thresh = cm.max() / 2.\n",
"    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):\n",
"        plt.text(j, i, cm[i, j],\n",
"                 horizontalalignment=\"center\",\n",
"                 color=\"white\" if cm[i, j] > thresh else \"black\")\n",
"\n",
"    plt.tight_layout()\n",
"    plt.ylabel('True label')\n",
"    plt.xlabel('Predicted label')"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Recall: 0.9387755102040817\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import itertools\n",
"# Use the best regularization parameter\n",
"lr = LogisticRegression(C=best_c, penalty='l1', solver='liblinear')\n",
"# Train the model\n",
"lr.fit(X_train_undersample, y_train_undersample.values.ravel())\n",
"# Predict on the under-sampled test set\n",
"y_pred_undersample = lr.predict(X_test_undersample.values)\n",
"# Compute the confusion matrix\n",
"cnf_matrix = confusion_matrix(y_test_undersample, y_pred_undersample)\n",
"np.set_printoptions(precision=2)\n",
"\n",
"print(\"Recall: \", cnf_matrix[1,1]/(cnf_matrix[1,0]+cnf_matrix[1,1]))\n",
"# Plot\n",
"class_names = [0, 1]\n",
"plt.figure()\n",
"plot_confusion_matrix(cnf_matrix\n",
"                      , classes=class_names\n",
"                      , title='Confusion matrix')\n",
"plt.show()"
]
},
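{
"cell_type": "markdown",
"metadata": {},
"source": [
"The top-right cell (19) counts normal transactions that were flagged as fraud; the bottom-right cell counts fraud transactions that were correctly flagged as fraud. The result looks good.\n",
"\n",
"But this is not yet our real requirement, which is to find the 492 frauds among more than 280,000 transactions. The test set used here has a 1:1 class ratio."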
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Recall: 0.9183673469387755\n",
"Precision: 0.011378961564396493\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"lr = LogisticRegression(C=best_c, penalty='l1', solver='liblinear')\n",
"lr.fit(X_train_undersample, y_train_undersample.values.ravel())\n",
"# Same code as above; the only difference is that this test set has the real class ratio\n",
"y_pred = lr.predict(X_test.values)\n",
"\n",
"cnf_matrix = confusion_matrix(y_test, y_pred)\n",
"np.set_printoptions(precision=2)\n",
"\n",
"print(\"Recall: \", cnf_matrix[1,1]/(cnf_matrix[1,0]+cnf_matrix[1,1]))\n",
"print(\"Precision: \", cnf_matrix[1,1]/(cnf_matrix[0,1]+cnf_matrix[1,1]))\n",
"\n",
"class_names = [0, 1]\n",
"plt.figure()\n",
"plot_confusion_matrix(cnf_matrix\n",
"                      , classes=class_names\n",
"                      , title='Confusion matrix')\n",
"plt.show()"
]
},
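{
"cell_type": "markdown",
"metadata": {},
"source": [
"At first glance a recall of 91.8% looks fine, but the top-right cell of the confusion matrix counts the normal transactions that were predicted as fraud, and with the precision printed above that is more than 11,000 false alarms. In other words, precision is alarmingly low."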
]
},
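{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Effect of the decision threshold\n",
"The model outputs a predicted probability for every sample. By default, a probability above 0.5 is classified as fraud. What happens if we only flag samples above 0.6?\n",
"\n",
"The higher the predicted probability, the more confident the model is that the sample is fraudulent, so we can compute recall and precision at a range of thresholds."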
]
},
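{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before applying this to the fitted model, here is a minimal, self-contained sketch of what changing the threshold does. The probabilities, labels, and the helper `predict_with_threshold` are hypothetical and only illustrate the mechanics.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Hypothetical fraud probabilities and true labels, for illustration only\n",
"proba_fraud = np.array([0.05, 0.30, 0.55, 0.65, 0.80, 0.95])\n",
"y_true = np.array([0, 0, 1, 0, 1, 1])\n",
"\n",
"def predict_with_threshold(proba, threshold):\n",
"    '''Label a sample as fraud (1) when its probability exceeds the threshold.'''\n",
"    return (proba > threshold).astype(int)\n",
"\n",
"for threshold in (0.5, 0.7):\n",
"    y_pred = predict_with_threshold(proba_fraud, threshold)\n",
"    tp = int(np.sum((y_pred == 1) & (y_true == 1)))\n",
"    fp = int(np.sum((y_pred == 1) & (y_true == 0)))\n",
"    fn = int(np.sum((y_pred == 0) & (y_true == 1)))\n",
"    print(threshold, 'recall =', tp / (tp + fn), 'precision =', tp / (tp + fp))\n",
"```\n",
"\n",
"On this toy data, raising the threshold from 0.5 to 0.7 removes the one false alarm but also misses one fraud, which is the trade-off examined below."
]
},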
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Threshold: 0.1 test-set recall: 1.00 precision: 0.497\n",
"Threshold: 0.2 test-set recall: 1.00 precision: 0.497\n",
"Threshold: 0.3 test-set recall: 1.00 precision: 0.500\n",
"Threshold: 0.4 test-set recall: 0.97 precision: 0.637\n",
"Threshold: 0.5 test-set recall: 0.94 precision: 0.896\n",
"Threshold: 0.6 test-set recall: 0.89 precision: 0.978\n",
"Threshold: 0.7 test-set recall: 0.82 precision: 0.984\n",
"Threshold: 0.8 test-set recall: 0.77 precision: 0.991\n",
"Threshold: 0.9 test-set recall: 0.60 precision: 1.000\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 576x576 with 18 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Same model and regularization parameter as before\n",
"lr = LogisticRegression(C=best_c, penalty='l1', solver='liblinear')\n",
"\n",
"# Still trained on the under-sampled training set\n",
"lr.fit(X_train_undersample, y_train_undersample.values.ravel())\n",
"\n",
"# Predicted probabilities\n",
"y_pred_undersample_proba = np.array(lr.predict_proba(X_test_undersample.values))\n",
"\n",
"# Thresholds to try\n",
"thresholds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]\n",
"\n",
"plt.figure(figsize=(8, 8))\n",
"\n",
"j = 1\n",
"\n",
"# Show each result as a confusion matrix\n",
"for i in thresholds:\n",
"    y_test_predictions_high_recall = y_pred_undersample_proba[:, 1] > i\n",
"\n",
"    plt.subplot(3, 3, j)\n",
"    j += 1\n",
"\n",
"    cnf_matrix = confusion_matrix(y_test_undersample, y_test_predictions_high_recall)\n",
"    np.set_printoptions(precision=2)\n",
"\n",
"    print(\"Threshold:\", i, \"test-set recall:\", '{0:.2f}'.format(cnf_matrix[1,1]/(cnf_matrix[1,0]+cnf_matrix[1,1])),\n",
"          \"precision:\", '{0:.3f}'.format(cnf_matrix[1,1]/(cnf_matrix[0,1]+cnf_matrix[1,1])))\n",
"\n",
"    class_names = [0, 1]\n",
"    plot_confusion_matrix(cnf_matrix\n",
"                          , classes=class_names\n",
"                          , title='Confusion matrix')"
]
},
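{
"cell_type": "markdown",
"metadata": {},
"source": [
"As the threshold rises, recall falls while precision rises. When the model itself cannot be changed, we pick the threshold that suits the business use case."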
]
},
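{
"cell_type": "markdown",
"metadata": {},
"source": [
"### SMOTE over-sampling\n",
"SMOTE builds synthetic samples from nearest neighbors: using Euclidean distance, it finds the neighbors of a minority-class point and creates new points between them."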
]
},
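{
"cell_type": "markdown",
"metadata": {},
"source": [
"The interpolation idea can be sketched in a few lines of NumPy. This is a simplified illustration under stated assumptions, not the imblearn implementation: the function `smote_like_sample`, its parameters, and the sample points are all hypothetical.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def smote_like_sample(minority, k=2, seed=0):\n",
"    '''Create one synthetic point between a random minority sample and one of its k nearest neighbors.'''\n",
"    rng = np.random.default_rng(seed)\n",
"    i = rng.integers(len(minority))\n",
"    # Euclidean distance from sample i to every minority sample\n",
"    distances = np.linalg.norm(minority - minority[i], axis=1)\n",
"    # Skip position 0 of the sorted order, which is the sample itself (distance 0)\n",
"    neighbors = np.argsort(distances)[1:k + 1]\n",
"    neighbor = minority[rng.choice(neighbors)]\n",
"    # Interpolate at a random position on the segment between the two points\n",
"    gap = rng.random()\n",
"    return minority[i] + gap * (neighbor - minority[i])\n",
"\n",
"# Hypothetical 2-D minority samples: the corners of the unit square\n",
"minority = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])\n",
"new_point = smote_like_sample(minority)\n",
"print(new_point)\n",
"```\n",
"\n",
"Each synthetic point lies on the line segment between two real minority samples, so it stays inside the region those samples already occupy."
]
},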
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"from imblearn.over_sampling import SMOTE\n",
"from sklearn.metrics import confusion_matrix"
]
},
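{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"credit_cards = data.copy()\n",
"columns = credit_cards.columns\n",
"# Drop the label by name. The last column of data is normAmount, not Class,\n",
"# so deleting the last column by position would leave Class among the features.\n",
"features_columns = columns.drop('Class')\n",
"\n",
"features = credit_cards[features_columns]\n",
"labels = credit_cards['Class']"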
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"features_train, features_test, labels_train, labels_test = train_test_split(features, labels, \n",
" test_size = 0.3, \n",
" random_state = 0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Generate samples with the SMOTE algorithm so that the positive and negative classes end up the same size."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"# Initialize the SMOTE model\n",
"oversampler = SMOTE(random_state=0)\n",
"\n",
"# Build the over-sampled training set with SMOTE\n",
"os_features, os_labels = oversampler.fit_resample(features_train, labels_train)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"=====Samples per class=======\n",
"0 199019\n",
"1 199019\n",
"dtype: int64\n"
]
},
{
"data": {
"text/plain": [
"Text(0, 0.5, 'Frequency')"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"os_count_classes = pd.Series(os_labels).value_counts().sort_index()\n",
"print(\"=====Samples per class=======\")\n",
"print(os_count_classes)\n",
"\n",
"os_count_classes.plot(kind=\"bar\")\n",
"\n",
"plt.title(\"Fraud class histogram\")\n",
"plt.xlabel(\"Class\")\n",
"plt.ylabel(\"Frequency\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The two classes are now exactly the same size, so let's run the model."
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-------------------------------------------\n",
"C parameter: 0.01\n",
"-------------------------------------------\n",
"Iteration 0 : recall = 1.0\n",
"Iteration 1 : recall = 1.0\n",
"Iteration 2 : recall = 1.0\n",
"Iteration 3 : recall = 1.0\n",
"Iteration 4 : recall = 1.0\n",
"\n",
"Mean recall 1.0\n",
"\n",
"-------------------------------------------\n",
"C parameter: 0.1\n",
"-------------------------------------------\n",
"Iteration 0 : recall = 1.0\n",
"Iteration 1 : recall = 1.0\n",
"Iteration 2 : recall = 1.0\n",
"Iteration 3 : recall = 1.0\n",
"Iteration 4 : recall = 1.0\n",
"\n",
"Mean recall 1.0\n",
"\n",
"-------------------------------------------\n",
"C parameter: 1\n",
"-------------------------------------------\n",
"Iteration 0 : recall = 1.0\n",
"Iteration 1 : recall = 1.0\n",
"Iteration 2 : recall = 1.0\n",
"Iteration 3 : recall = 1.0\n",
"Iteration 4 : recall = 1.0\n",
"\n",
"Mean recall 1.0\n",
"\n",
"-------------------------------------------\n",
"C parameter: 10\n",
"-------------------------------------------\n",
"Iteration 0 : recall = 1.0\n",
"Iteration 1 : recall = 1.0\n",
"Iteration 2 : recall = 1.0\n",
"Iteration 3 : recall = 1.0\n",
"Iteration 4 : recall = 1.0\n",
"\n",
"Mean recall 1.0\n",
"\n",
"-------------------------------------------\n",
"C parameter: 100\n",
"-------------------------------------------\n",
"Iteration 0 : recall = 1.0\n",
"Iteration 1 : recall = 1.0\n",
"Iteration 2 : recall = 1.0\n",
"Iteration 3 : recall = 1.0\n",
"Iteration 4 : recall = 1.0\n",
"\n",
"Mean recall 1.0\n",
"\n",
"***********************************\n",
"Best C parameter = 0.01\n",
"***********************************\n"
]
}
],
"source": [
"os_features = pd.DataFrame(os_features)\n",
"os_labels = pd.DataFrame(os_labels)\n",
"best_c = printing_Kfold_scores(os_features,os_labels)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Recall: 1.0\n",
"Precision: 0.9735099337748344\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"lr = LogisticRegression(C=best_c, penalty='l1', solver='liblinear')\n",
"lr.fit(os_features, os_labels.values.ravel())\n",
"# Same code as above; the only difference is that this test set has the real class ratio\n",
"y_pred = lr.predict(features_test.values)\n",
"\n",
"cnf_matrix = confusion_matrix(labels_test, y_pred)\n",
"np.set_printoptions(precision=2)\n",
"\n",
"print(\"Recall: \", cnf_matrix[1,1]/(cnf_matrix[1,0]+cnf_matrix[1,1]))\n",
"print(\"Precision: \", cnf_matrix[1,1]/(cnf_matrix[0,1]+cnf_matrix[1,1]))\n",
"\n",
"class_names = [0, 1]\n",
"plt.figure()\n",
"plot_confusion_matrix(cnf_matrix\n",
"                      , classes=class_names\n",
"                      , title='Confusion matrix')\n",
"plt.show()"
]
},
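{
"cell_type": "markdown",
"metadata": {},
"source": [
"The recorded results look excellent, but a recall of exactly 1.0 on every fold is a warning sign of label leakage. Check that `Class` is not among the feature columns (the last column of `data` is `normAmount`, so removing the last column by position leaves `Class` in) and rerun before trusting these numbers. Next we try different thresholds."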
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Threshold: 0.1 test-set recall: 1.00 precision: 0.936\n",
"Threshold: 0.2 test-set recall: 1.00 precision: 0.955\n",
"Threshold: 0.3 test-set recall: 1.00 precision: 0.961\n",
"Threshold: 0.4 test-set recall: 1.00 precision: 0.967\n",
"Threshold: 0.5 test-set recall: 1.00 precision: 0.987\n",
"Threshold: 0.6 test-set recall: 1.00 precision: 0.993\n",
"Threshold: 0.7 test-set recall: 1.00 precision: 1.000\n",
"Threshold: 0.8 test-set recall: 1.00 precision: 1.000\n",
"Threshold: 0.9 test-set recall: 1.00 precision: 1.000\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 576x576 with 18 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Try different thresholds\n",
"lr = LogisticRegression(C=best_c, penalty='l1', solver='liblinear')\n",
"\n",
"lr.fit(os_features, os_labels.values.ravel())\n",
"y_pred_over_proba = np.array(lr.predict_proba(features_test.values))\n",
"\n",
"\n",
"thresholds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]\n",
"plt.figure(figsize=(8, 8))\n",
"j = 1\n",
"\n",
"# Show each result as a confusion matrix\n",
"for i in thresholds:\n",
"    y_test_predictions_high_recall = y_pred_over_proba[:, 1] > i\n",
"\n",
"    plt.subplot(3, 3, j)\n",
"    j += 1\n",
"\n",
"    cnf_matrix = confusion_matrix(labels_test, y_test_predictions_high_recall)\n",
"    np.set_printoptions(precision=2)\n",
"\n",
"    print(\"Threshold:\", i, \"test-set recall:\", '{0:.2f}'.format(cnf_matrix[1,1]/(cnf_matrix[1,0]+cnf_matrix[1,1])),\n",
"          \"precision:\", '{0:.3f}'.format(cnf_matrix[1,1]/(cnf_matrix[0,1]+cnf_matrix[1,1])))\n",
"\n",
"    class_names = [0, 1]\n",
"    plot_confusion_matrix(cnf_matrix\n",
"                          , classes=class_names\n",
"                          , title='Confusion matrix')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the recorded output, raising the threshold leaves recall unchanged while precision improves, so a higher threshold can be chosen."
]
},
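{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Project summary\n",
" (1) We found the class-imbalance problem in the data and devised solutions for it.\n",
" (2) We proposed two approaches, under-sampling and over-sampling, and compared them experimentally to find the more suitable one.\n",
" (3) Modeling relies on preprocessing, such as the standardization used here and, in other cases, missing-value imputation.\n",
" (4) Choose a suitable evaluation method before modeling, then judge the model's results against it.\n",
" (5) The algorithm chosen here is logistic regression, a representative simple model that is also widely used in industry.\n",
" (6) Tuning after training is also necessary, as in the under-sampling experiment above, and the best parameters are kept in the end.\n",
" (7) Tie the results to the actual task and make sure the gap between offline and deployed performance is small."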
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}