You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ML-For-Beginners/2-Regression/4-Logistic/solution/notebook.ipynb

1249 lines
235 KiB

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Logistic Regression - Lesson 4\n",
"\n",
"Load the required libraries and the dataset, then build a dataframe containing only the subset of the data we need:"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>City Name</th>\n",
" <th>Type</th>\n",
" <th>Package</th>\n",
" <th>Variety</th>\n",
" <th>Sub Variety</th>\n",
" <th>Grade</th>\n",
" <th>Date</th>\n",
" <th>Low Price</th>\n",
" <th>High Price</th>\n",
" <th>Mostly Low</th>\n",
" <th>...</th>\n",
" <th>Unit of Sale</th>\n",
" <th>Quality</th>\n",
" <th>Condition</th>\n",
" <th>Appearance</th>\n",
" <th>Storage</th>\n",
" <th>Crop</th>\n",
" <th>Repack</th>\n",
" <th>Trans Mode</th>\n",
" <th>Unnamed: 24</th>\n",
" <th>Unnamed: 25</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>BALTIMORE</td>\n",
" <td>NaN</td>\n",
" <td>24 inch bins</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>4/29/17</td>\n",
" <td>270.0</td>\n",
" <td>280.0</td>\n",
" <td>270.0</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>E</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>BALTIMORE</td>\n",
" <td>NaN</td>\n",
" <td>24 inch bins</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>5/6/17</td>\n",
" <td>270.0</td>\n",
" <td>280.0</td>\n",
" <td>270.0</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>E</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>BALTIMORE</td>\n",
" <td>NaN</td>\n",
" <td>24 inch bins</td>\n",
" <td>HOWDEN TYPE</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>9/24/16</td>\n",
" <td>160.0</td>\n",
" <td>160.0</td>\n",
" <td>160.0</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>N</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>BALTIMORE</td>\n",
" <td>NaN</td>\n",
" <td>24 inch bins</td>\n",
" <td>HOWDEN TYPE</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>9/24/16</td>\n",
" <td>160.0</td>\n",
" <td>160.0</td>\n",
" <td>160.0</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>N</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>BALTIMORE</td>\n",
" <td>NaN</td>\n",
" <td>24 inch bins</td>\n",
" <td>HOWDEN TYPE</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>11/5/16</td>\n",
" <td>90.0</td>\n",
" <td>100.0</td>\n",
" <td>90.0</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>N</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 26 columns</p>\n",
"</div>"
],
"text/plain": [
" City Name Type Package Variety Sub Variety Grade Date \n",
"0 BALTIMORE NaN 24 inch bins NaN NaN NaN 4/29/17 \\\n",
"1 BALTIMORE NaN 24 inch bins NaN NaN NaN 5/6/17 \n",
"2 BALTIMORE NaN 24 inch bins HOWDEN TYPE NaN NaN 9/24/16 \n",
"3 BALTIMORE NaN 24 inch bins HOWDEN TYPE NaN NaN 9/24/16 \n",
"4 BALTIMORE NaN 24 inch bins HOWDEN TYPE NaN NaN 11/5/16 \n",
"\n",
" Low Price High Price Mostly Low ... Unit of Sale Quality Condition \n",
"0 270.0 280.0 270.0 ... NaN NaN NaN \\\n",
"1 270.0 280.0 270.0 ... NaN NaN NaN \n",
"2 160.0 160.0 160.0 ... NaN NaN NaN \n",
"3 160.0 160.0 160.0 ... NaN NaN NaN \n",
"4 90.0 100.0 90.0 ... NaN NaN NaN \n",
"\n",
" Appearance Storage Crop Repack Trans Mode Unnamed: 24 Unnamed: 25 \n",
"0 NaN NaN NaN E NaN NaN NaN \n",
"1 NaN NaN NaN E NaN NaN NaN \n",
"2 NaN NaN NaN N NaN NaN NaN \n",
"3 NaN NaN NaN N NaN NaN NaN \n",
"4 NaN NaN NaN N NaN NaN NaN \n",
"\n",
"[5 rows x 26 columns]"
]
},
"execution_count": 63,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"# Read the full pumpkin dataset; the next cell keeps only a subset of its 26 columns\n",
"full_pumpkins = pd.read_csv('../../data/US-pumpkins.csv')\n",
"full_pumpkins.head()"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>City Name</th>\n",
" <th>Package</th>\n",
" <th>Variety</th>\n",
" <th>Origin</th>\n",
" <th>Item Size</th>\n",
" <th>Color</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>BALTIMORE</td>\n",
" <td>24 inch bins</td>\n",
" <td>HOWDEN TYPE</td>\n",
" <td>DELAWARE</td>\n",
" <td>med</td>\n",
" <td>ORANGE</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>BALTIMORE</td>\n",
" <td>24 inch bins</td>\n",
" <td>HOWDEN TYPE</td>\n",
" <td>VIRGINIA</td>\n",
" <td>med</td>\n",
" <td>ORANGE</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>BALTIMORE</td>\n",
" <td>24 inch bins</td>\n",
" <td>HOWDEN TYPE</td>\n",
" <td>MARYLAND</td>\n",
" <td>lge</td>\n",
" <td>ORANGE</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>BALTIMORE</td>\n",
" <td>24 inch bins</td>\n",
" <td>HOWDEN TYPE</td>\n",
" <td>MARYLAND</td>\n",
" <td>lge</td>\n",
" <td>ORANGE</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>BALTIMORE</td>\n",
" <td>36 inch bins</td>\n",
" <td>HOWDEN TYPE</td>\n",
" <td>MARYLAND</td>\n",
" <td>med</td>\n",
" <td>ORANGE</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" City Name Package Variety Origin Item Size Color\n",
"2 BALTIMORE 24 inch bins HOWDEN TYPE DELAWARE med ORANGE\n",
"3 BALTIMORE 24 inch bins HOWDEN TYPE VIRGINIA med ORANGE\n",
"4 BALTIMORE 24 inch bins HOWDEN TYPE MARYLAND lge ORANGE\n",
"5 BALTIMORE 24 inch bins HOWDEN TYPE MARYLAND lge ORANGE\n",
"6 BALTIMORE 36 inch bins HOWDEN TYPE MARYLAND med ORANGE"
]
},
"execution_count": 64,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Select the columns we want to use, then drop rows with missing values.\n",
"# Chaining .dropna() (instead of dropna(inplace=True)) builds the cleaned frame in one\n",
"# expression straight from full_pumpkins, so its lineage is explicit.\n",
"columns_to_select = ['City Name', 'Package', 'Variety', 'Origin', 'Item Size', 'Color']\n",
"pumpkins = full_pumpkins.loc[:, columns_to_select].dropna()\n",
"\n",
"pumpkins.head()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Let's have a look at our data!\n",
"\n",
"We visualise it with Seaborn."
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<seaborn.axisgrid.FacetGrid at 0x7f8c56d0c650>"
]
},
"execution_count": 65,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjQAAAHpCAYAAACVw6ZvAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAABU3klEQVR4nO3deVRU5f8H8PeFkQFZZXNQ2RQBwy3NNRVGMTCz3JW0JJcyjdwXLJcwBSszTcU0wKxccl9KyoVxS0VTEhXXRM1A+7qwmOz394eH+/M6A7IKV9+vc+7Jee6zfO7IkXfP3JkRRFEUQURERKRgRlVdABEREVF5MdAQERGR4jHQEBERkeIx0BAREZHiMdAQERGR4jHQEBERkeIx0BAREZHiMdDQc0sURaSnp4MfxUREpHwMNPTcysjIgLW1NTIyMqq6FCIiKicGGiIiIlI8BhoiIiJSPAYaIiIiUjwGGiIiIlI8VVUXQFTVrq5qAkszZnsipXIbdqWqS6BqgP+KExERkeIx0BAREZHiMdAQERGR4jHQEBERkeIx0BAREZHiMdAQERGR4jHQEBERkeIx0BAREZHiMdAQERGR4jHQEBERkeIx0BAREZHiMdAQERGR4jHQEBERkeIx0BAREZHiMdCUU3BwMHr27KnXrtPpIAgC7t27J7Xl5+djwYIFaNKkCUxNTVGrVi1069YNhw4dkvqcO3cOgiDgyJEjsvnatm0LU1NTZGVlSW1ZWVkwNTVFVFSUVIsgCBAEATVq1EDt2rXRtWtXREdHo6CgQDafm5ub1PfRIyIiAgCQnJwMQRDg6OiIjIwM2djmzZtj1qxZBp8PPz8/g/MWHo0bN4ZGo8HcuXP1xvbv3x9t27ZFfn4+Zs2aJY1RqVRwc3PDuHHjkJmZKavP0PH4c0dERM8+BpqnRBRFDBw4EGFhYRgzZgySkpKg0+ng7OwMPz8/bNmyBQDg7e0NjUYDnU4njc3IyMCJEyfg4OAg+2V9+PBhZGdno3PnzlJbYGAgUlJSkJycjJ07d0Kr1WLMmDF47bXXkJeXJ6spLCwMKSkpsiMkJETWJyMjA1988UWJr3PTpk3SXPHx8QCA3bt3S2379+/H8uXL8cknnyAxMVEat379euzYsQPfffcdjI2NAQA+Pj7StcybNw/Lly/HhAkTZOs9Onfh0bJlyxLXS0REzwZVVRfwvPjpp5+wYcMGbNu2DT169JDaly9fjtu3b2P48OHo2rUrzM3NodVqodPpMHXqVADAwYMH4enpiU6dOkGn08HPzw/Aw10gV1dXuLu7S/Op1WpoNBoAQN26ddGiRQu0bdsWXbp0wcqVKzF8+HCpr6WlpdS3KCEhIfjyyy8xevRoODo6PvE6bW1tpT8X7ibZ2dnJ1nn99dfx5ptvYsiQITh69Cju3buH0aNHIyIiAl5eXlI/lUoljRswYAD27NmDbdu24ZtvvpH6PD43ERE9n7hD85SsXr0anp6esjBTaMKECbh9+zZ27doFANBqtTh48KC0oxIXFwc/Pz/4+voiLi5OGhcXFwetVvvEtTt37oxmzZph06ZNpa47KCgIHh4eCAsLK/XY4ixcuBC3b9/G7NmzMWrUKDRu3Fhvd+hxZmZmyMnJKfOa2dnZSE9Plx1ERPRs4A5NBdixYwcsLCxkbfn5+bLHFy5cQKNGjQyOL2y/cOECgIeB5v79+zh27BjatWsHnU6HSZMmoUOHDhgyZAiysrIgiiLi4+NlOy7F8fb2xqlTp2RtU6ZMwccffyxr27lzJzp27Cg9LryvpkePHhg3bhwaNGhQovWexMrKCjExMXjllVdgbm6OU6dOQRCEIvv/8ccfWL16tezlNQBo3749jIzkubzwPpvHhYeH45NPPtFrd307EVZWVmW4CiIiqi4YaCqAVqtFZGSkrO3o0aMYPHiwrE0UxRLN5+HhgXr16kGn08HHxwcnT56Er68vHB0d4eLigsOHD0MURWRnZ5doh6
Zw7ccDw6RJkxAcHCxrq1u3rt7YgIAAdOjQAdOnT8fq1atLtF5JdO7cGW3btkXz5s3h6uqqdz4xMREWFhbIz89HTk4OunfvjsWLF8v6rFu3rsig+LjQ0FCMHz9eepyeng5nZ+fyXQQREVULDDQVwNzcHB4eHrK2v//+W/bY09MTSUlJBscXtnt6ekptfn5+iIuLQ9OmTdGwYUPp/pXCl51EUYSHh0eJfyEnJSXJ7rUBAHt7e726ixIREYF27dph0qRJJepfUiqVCiqV4R9DLy8vbNu2DSqVCnXq1IGJiYleH2dn5xJfg1qthlqtLle9RERUPfEemqdk4MCBuHjxIrZv3653bv78+bCzs0PXrl2lNq1Wi99//x27du2SbgIGIN0YrNPpSrw7s3fvXiQmJqJPnz5lrr9169bo3bu3dKPy02BiYgIPDw+4ubkZDDNERESFuEPzlAwcOBDr16/HkCFD8Pnnn6NLly5IT0/HkiVLsG3bNqxfvx7m5uZS/8L7aKKjo7FixQqp3dfXV7pvZtSoUXrrZGdnIzU1Ffn5+bh58yZiY2MRHh6O1157DW+//basb0ZGBlJTU2VtNWvWLPJ+kjlz5sDHx6fIHZWqcPv2bb1rsLGxgampaRVVREREVYE7NE+JIAj46aefMG3aNCxYsABeXl7o2LEjrl69Cp1Op/fhfO7u7nB1dUVGRgZ8fX2ldhcXF9SpUwc5OTmynZtCsbGxcHJygpubGwIDAxEXF4dFixZh69at0ue7FJoxYwacnJxkx+TJk4u8Bk9PTwwdOlT24X5Vzd/fX+8aCj/Th4iInh+CWNI7VYmeMenp6bC2tkZaWhrf5UREpHDcoSEiIiLFY6AhIiIixWOgISIiIsVjoCEiIiLFY6AhIiIixWOgISIiIsVjoCEiIiLFY6AhIiIixWOgISIiIsVjoCEiIiLFY6AhIiIixWOgISIiIsVjoCEiIiLFY6AhIiIixWOgISIiIsVjoCEiIiLFY6AhIiIixWOgISIiIsVjoCEiIiLFY6AhIiIixWOgISIiIsVTVXUBRFXt6qomsDRjtq8obsOuVHUJRPQc4r/iREREpHgMNERERKR4DDRERESkeAw0REREpHgMNERERKR4DDRERESkeAw0REREpHgMNERERKR4DDRERESkeAw0REREpHgMNERERKR4DDRERESkeAw0REREpHgMNERERKR4DDRERESkeAw0z5ng4GAIggBBEGBiYgIPDw+EhYUhLy8PAKDT6SAIAu7duyd7bOhITU3Vm3/WrFlF9i88/P39ERAQoDd26dKlsLGxwd9//623bu3atdGnTx/89ddfUn83NzeD80dERFTOk0dERNWWqqoLoKcvMDAQMTExyM7Oxi+//ILRo0ejRo0aCA0NLXLM+fPnYWVlJWtzdHTU6zdx4kSMHDlSetyqVSu8++67GDFihNSWm5uLJk2a4JtvvsF7770HALhy5QomT56MyMhI1KtXD5cuXZLWtbS0xMWLF/Huu++iR48eOHXqFIyNjQEAYWFhsrkBwNLSspTPCBERKR0DzXNIrVZDo9EAAN5//31s3rwZ27ZtKzbQODo6wsbG5olzW1hYwMLCQnpsbGwMS0tLab1CCxcuxAcffIBXXnkFbm5uGDZsGF555RW89dZbBtd1cnLCjBkzMGjQIFy6dAleXl4AYHDuomRnZyM7O1t6nJ6eXqJxRERU/THQEMzMzHD79u2nuuaQIUOwefNmDB06FL1798bp06dx5syZYseYmZkBAHJycsq0Znh4OD755BO9dssOi2BlaQ7bBv5lmpeIiKoe76F5jomiiN27d+PXX39F586di+1br149affFwsICPj4+5V5/+fLlOH36NMaOHYvly5fDwcGhyL4pKSn44osvULduXWl3BgCmTJkiq8vCwgIHDhwwOEdoaCjS0tKk4/r16+W+BiIiqh64Q/Mc2rFjBywsLJCbm4uCggK8+eabmDVrVrFjDhw4ILs3pU
aNGuWuw9HREe+99x62bNmCnj17GuxTr149iKKI//77D82aNcPGjRthYmIinZ80aRKCg4NlY+rWrWtwLrVaDbVaXe6
"text/plain": [
"<Figure size 609.375x500 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import seaborn as sns\n",
"\n",
"# Colour used to draw each value of the hue variable ('Color')\n",
"palette = {'ORANGE': 'orange', 'WHITE': 'wheat'}\n",
"\n",
"# Bar plot counting how many pumpkins of each variety are orange or white\n",
"sns.catplot(data=pumpkins, y=\"Variety\", hue=\"Color\", kind=\"count\", palette=palette)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Data pre-processing\n",
"\n",
"Let's encode the features and the label so that we can plot the data more easily and train the model."
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array(['med', 'lge', 'sml', 'xlge', 'med-lge', 'jbo', 'exjbo'],\n",
" dtype=object)"
]
},
"execution_count": 66,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Distinct values of the 'Item Size' column, in order of first appearance\n",
"pumpkins.loc[:, 'Item Size'].unique()"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.preprocessing import OrdinalEncoder\n",
"\n",
"# 'Item Size' is ordinal: each size is encoded as its index in the list below\n",
"# ('sml' -> 0, 'med' -> 1, ..., 'exjbo' -> 6)\n",
"ordinal_features = ['Item Size']\n",
"item_size_categories = [['sml', 'med', 'med-lge', 'lge', 'xlge', 'jbo', 'exjbo']]\n",
"ordinal_encoder = OrdinalEncoder(categories=item_size_categories)"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.preprocessing import OneHotEncoder\n",
"\n",
"# One-hot encode the remaining (nominal) features.\n",
"# sparse_output=False returns a dense array; the next cell asks ColumnTransformer for pandas output.\n",
"categorical_features = ['City Name', 'Package', 'Variety', 'Origin']\n",
"categorical_encoder = OneHotEncoder(sparse_output=False)"
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>ord__Item Size</th>\n",
" <th>cat__City Name_ATLANTA</th>\n",
" <th>cat__City Name_BALTIMORE</th>\n",
" <th>cat__City Name_BOSTON</th>\n",
" <th>cat__City Name_CHICAGO</th>\n",
" <th>cat__City Name_COLUMBIA</th>\n",
" <th>cat__City Name_DALLAS</th>\n",
" <th>cat__City Name_DETROIT</th>\n",
" <th>cat__City Name_LOS ANGELES</th>\n",
" <th>cat__City Name_MIAMI</th>\n",
" <th>...</th>\n",
" <th>cat__Origin_MICHIGAN</th>\n",
" <th>cat__Origin_NEW JERSEY</th>\n",
" <th>cat__Origin_NEW YORK</th>\n",
" <th>cat__Origin_NORTH CAROLINA</th>\n",
" <th>cat__Origin_OHIO</th>\n",
" <th>cat__Origin_PENNSYLVANIA</th>\n",
" <th>cat__Origin_TENNESSEE</th>\n",
" <th>cat__Origin_TEXAS</th>\n",
" <th>cat__Origin_VERMONT</th>\n",
" <th>cat__Origin_VIRGINIA</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>3.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>3.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 48 columns</p>\n",
"</div>"
],
"text/plain": [
" ord__Item Size cat__City Name_ATLANTA cat__City Name_BALTIMORE \n",
"2 1.0 0.0 1.0 \\\n",
"3 1.0 0.0 1.0 \n",
"4 3.0 0.0 1.0 \n",
"5 3.0 0.0 1.0 \n",
"6 1.0 0.0 1.0 \n",
"\n",
" cat__City Name_BOSTON cat__City Name_CHICAGO cat__City Name_COLUMBIA \n",
"2 0.0 0.0 0.0 \\\n",
"3 0.0 0.0 0.0 \n",
"4 0.0 0.0 0.0 \n",
"5 0.0 0.0 0.0 \n",
"6 0.0 0.0 0.0 \n",
"\n",
" cat__City Name_DALLAS cat__City Name_DETROIT cat__City Name_LOS ANGELES \n",
"2 0.0 0.0 0.0 \\\n",
"3 0.0 0.0 0.0 \n",
"4 0.0 0.0 0.0 \n",
"5 0.0 0.0 0.0 \n",
"6 0.0 0.0 0.0 \n",
"\n",
" cat__City Name_MIAMI ... cat__Origin_MICHIGAN cat__Origin_NEW JERSEY \n",
"2 0.0 ... 0.0 0.0 \\\n",
"3 0.0 ... 0.0 0.0 \n",
"4 0.0 ... 0.0 0.0 \n",
"5 0.0 ... 0.0 0.0 \n",
"6 0.0 ... 0.0 0.0 \n",
"\n",
" cat__Origin_NEW YORK cat__Origin_NORTH CAROLINA cat__Origin_OHIO \n",
"2 0.0 0.0 0.0 \\\n",
"3 0.0 0.0 0.0 \n",
"4 0.0 0.0 0.0 \n",
"5 0.0 0.0 0.0 \n",
"6 0.0 0.0 0.0 \n",
"\n",
" cat__Origin_PENNSYLVANIA cat__Origin_TENNESSEE cat__Origin_TEXAS \n",
"2 0.0 0.0 0.0 \\\n",
"3 0.0 0.0 0.0 \n",
"4 0.0 0.0 0.0 \n",
"5 0.0 0.0 0.0 \n",
"6 0.0 0.0 0.0 \n",
"\n",
" cat__Origin_VERMONT cat__Origin_VIRGINIA \n",
"2 0.0 0.0 \n",
"3 0.0 1.0 \n",
"4 0.0 0.0 \n",
"5 0.0 0.0 \n",
"6 0.0 0.0 \n",
"\n",
"[5 rows x 48 columns]"
]
},
"execution_count": 69,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.compose import ColumnTransformer\n",
"\n",
"# Ordinal-encode 'Item Size' and one-hot encode the other features.\n",
"# Columns not listed here (i.e. 'Color') are dropped by the default remainder='drop'.\n",
"ct = ColumnTransformer(\n",
"    transformers=[\n",
"        ('ord', ordinal_encoder, ordinal_features),\n",
"        ('cat', categorical_encoder, categorical_features),\n",
"    ]\n",
")\n",
"# Return a pandas DataFrame (columns prefixed 'ord__' / 'cat__') rather than a NumPy array\n",
"ct.set_output(transform='pandas')\n",
"encoded_features = ct.fit_transform(pumpkins)\n",
"encoded_features.head()"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>ord__Item Size</th>\n",
" <th>cat__City Name_ATLANTA</th>\n",
" <th>cat__City Name_BALTIMORE</th>\n",
" <th>cat__City Name_BOSTON</th>\n",
" <th>cat__City Name_CHICAGO</th>\n",
" <th>cat__City Name_COLUMBIA</th>\n",
" <th>cat__City Name_DALLAS</th>\n",
" <th>cat__City Name_DETROIT</th>\n",
" <th>cat__City Name_LOS ANGELES</th>\n",
" <th>cat__City Name_MIAMI</th>\n",
" <th>...</th>\n",
" <th>cat__Origin_NEW JERSEY</th>\n",
" <th>cat__Origin_NEW YORK</th>\n",
" <th>cat__Origin_NORTH CAROLINA</th>\n",
" <th>cat__Origin_OHIO</th>\n",
" <th>cat__Origin_PENNSYLVANIA</th>\n",
" <th>cat__Origin_TENNESSEE</th>\n",
" <th>cat__Origin_TEXAS</th>\n",
" <th>cat__Origin_VERMONT</th>\n",
" <th>cat__Origin_VIRGINIA</th>\n",
" <th>Color</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>3.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>3.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 49 columns</p>\n",
"</div>"
],
"text/plain": [
" ord__Item Size cat__City Name_ATLANTA cat__City Name_BALTIMORE \n",
"2 1.0 0.0 1.0 \\\n",
"3 1.0 0.0 1.0 \n",
"4 3.0 0.0 1.0 \n",
"5 3.0 0.0 1.0 \n",
"6 1.0 0.0 1.0 \n",
"\n",
" cat__City Name_BOSTON cat__City Name_CHICAGO cat__City Name_COLUMBIA \n",
"2 0.0 0.0 0.0 \\\n",
"3 0.0 0.0 0.0 \n",
"4 0.0 0.0 0.0 \n",
"5 0.0 0.0 0.0 \n",
"6 0.0 0.0 0.0 \n",
"\n",
" cat__City Name_DALLAS cat__City Name_DETROIT cat__City Name_LOS ANGELES \n",
"2 0.0 0.0 0.0 \\\n",
"3 0.0 0.0 0.0 \n",
"4 0.0 0.0 0.0 \n",
"5 0.0 0.0 0.0 \n",
"6 0.0 0.0 0.0 \n",
"\n",
" cat__City Name_MIAMI ... cat__Origin_NEW JERSEY cat__Origin_NEW YORK \n",
"2 0.0 ... 0.0 0.0 \\\n",
"3 0.0 ... 0.0 0.0 \n",
"4 0.0 ... 0.0 0.0 \n",
"5 0.0 ... 0.0 0.0 \n",
"6 0.0 ... 0.0 0.0 \n",
"\n",
" cat__Origin_NORTH CAROLINA cat__Origin_OHIO cat__Origin_PENNSYLVANIA \n",
"2 0.0 0.0 0.0 \\\n",
"3 0.0 0.0 0.0 \n",
"4 0.0 0.0 0.0 \n",
"5 0.0 0.0 0.0 \n",
"6 0.0 0.0 0.0 \n",
"\n",
" cat__Origin_TENNESSEE cat__Origin_TEXAS cat__Origin_VERMONT \n",
"2 0.0 0.0 0.0 \\\n",
"3 0.0 0.0 0.0 \n",
"4 0.0 0.0 0.0 \n",
"5 0.0 0.0 0.0 \n",
"6 0.0 0.0 0.0 \n",
"\n",
" cat__Origin_VIRGINIA Color \n",
"2 0.0 0 \n",
"3 1.0 0 \n",
"4 0.0 0 \n",
"5 0.0 0 \n",
"6 0.0 0 \n",
"\n",
"[5 rows x 49 columns]"
]
},
"execution_count": 70,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.preprocessing import LabelEncoder\n",
"\n",
"# Turn the 'Color' label into integer codes (classes are numbered in sorted order)\n",
"label_encoder = LabelEncoder()\n",
"encoded_label = label_encoder.fit_transform(pumpkins['Color'])\n",
"\n",
"# Attach the encoded label as a 'Color' column alongside the encoded features\n",
"encoded_pumpkins = encoded_features.assign(Color=encoded_label)\n",
"encoded_pumpkins.head()"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['ORANGE', 'WHITE']"
]
},
"execution_count": 71,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Let's look at the mapping between the encoded values and the original values:\n",
"# classes_[i] is the original value encoded as i, so no codes need to be hard-coded\n",
"list(label_encoder.classes_)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Analysing relationships between the features and the label"
]
},
{
"cell_type": "code",
"execution_count": 81,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<seaborn.axisgrid.FacetGrid at 0x7f8c56322210>"
]
},
"execution_count": 81,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAArEAAAYpCAYAAABBoEQQAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeVxVdf7H8ff1IuAI4oqAIqiIuORS5lqCpqnTkE2NOo6ZYlk5uOVYaf1m1MbCpkUrNbMhW8a0Tdtm1FwAzTRFpVxJDVMTRBNZHEGB8/vDh3e8gXpZT6f7ej4e5zH3nu853/vhjMGbL9/zPTbDMAwBAAAAFlLD7AIAAACAsiLEAgAAwHIIsQAAALAcQiwAAAAshxALAAAAyyHEAgAAwHIIsQAAALAcQiwAAAAshxALAAAAyyHEAgAAwHIIsQAAALAcD7MLAAAAwK9fTEyMS8ctWbLEpeMIsQAAAKhy2dnZTu/PnTunDRs2KDo6ulz92QzDMCqjMAAAAMBVaWlp6tChg3Jzc8t1PnNiAQAAUO0qOo5KiAUAAIDlEGIBAABgOdzYBQAAgCqXlJTk9P7HH39UUVGREhMTZbPZHPsjIyNd6o8buwAAAFDl7Ha7DMNwCqw/ZxiGiouLXeqPkVgAAABUuaysrErtj5FYAAAAWA43dgEAAKDaLF++XHfddZfatm2rtm3b6q677tJ7771X5n4YiQUAAECVKy4u1pAhQ/Txxx+rVatWatOmjWw2m/bv36/U1FTdc889eu+991SjhmtjrMyJBQAAQJWbN2+ekpKS9Omnn+qOO+5wavvPf/6jkSNH6qWXXtIjjzziUn+MxAIAAKDKdejQQZMnT9aYMWNKbV+yZInmzp2rb7/91qX+CLEAAACocrVq1dKBAwcUEhJSavsPP/ygiIgInT9/3qX+uLELAAAAVc7b21vZ2dlXbc/JyVGtWrVc7o8QCwAAgCrXo0cPLViw4Krt8+fPV/fu3V3ujxu7AAAAUOX+9re/KTIyUqdPn9Zf/vIXtW3bVpK0f/9+vfDCC/r000+VmJjocn/MiQUAAEC1+Oyzz3T//ffr9OnTTvsbNmyof/7zn7rzzjtd7osQCwAAgGpz/vx5rV27Vt99950kKTw8XP379y/TfFiJEAsAAIBfiNTUVLVu3dqlY5kTCwAAAFMcPnxYCQkJji0jI0PFxcUunUuIBQAAQLU4cuSII7AmJibq+PHj8vHx0S233KLJkycrKirK5b6YTgAAAIAq17x5c/3www+qXbu2evXqpT59+igqKko333yzatQo+6qvhFgAAABUOQ8PD/n4+CgmJkb9+/fXrbfeKl9f33L3R4gFAABAlcvMzFRSUpKSkpKUmJio7777Tp07d1ZUVJT69OmjW265RT4+Pi73R4gtJ8MwlJubK19fX9lsNrPLAQAAsJSffvpJiYmJjlCbmpqqzp07a+vWrS6dz41d5ZSbmys/Pz9lZ2erTp06ZpfjVgzDUEFBgdllAKa48t+/l5cXv0SbhGsPVFyDBg3Uq1cvFRcXq7i4WNnZ2frmm29cPp8QC8spKCjQkCFDzC4DgBv74IMP5O3tbXYZgOUcO3ZMSUlJ2rhxozZu3KgffvhB3bp1U9++ffXOO++oe/fuLvdFiAUAAECVa9GihU6cOKFu3bopKipKr732mnr06CFPT89y9UeIhaU9M2i/PO2uLYoM/BoUFNr05Oq2kqSnB+6Tlwe3NVSXC0U19MSqNmaXAVjW0aNHVbNmTRmGIcMwHNMIyosQC0vztBfzQxxuy8vD4N9/teIXZqAifvzxRyUmJiohIUHvvfeenn76aXl6eqpr167q06ePIiMj1bNnT3l5ebnUHyG2gljcAQAA/Npdvqm0InPBGzdurGHDhmnYsGGSLoXay0/vevvtt/X3v/9dnp6eOn/+vEv9EWIriLvkAQDAr93lm6o/++yzSuuzSZMmuvfee3XvvfdKujTdYP369S6fX/
ZnfOnSnWVjxoxRUFCQPD09FRISokmTJumnn35yHBMVFSWbzSabzSZvb2+Fh4crLi6u1JHLLVu2yG6364477ijRduTIEdlsNvn7+ys3N9eprVOnTpo5c6bTvkOHDmnMmDFq1qyZvLy81KRJE912221aunSpCgsLHcddru3n2/Lly8tzSQAAAFABzZo1U0xMjMvHl3kk9vvvv1ePHj0UHh6uZcuWqXnz5tq7d68effRRrVq1Slu3blX9+vUlSWPHjtVTTz2lgoICbdiwQQ8++KDq1q2rcePGOfUZHx+vCRMmKD4+XidOnFBQUFCJz83NzdXzzz+vWbNmXbW2bdu2qV+/fmrXrp0WLFigiIgISVJycrIWLFig9u3bq2PHjo7jlyxZooEDBzr1Ubdu3bJeEgAAAFyHKwHVMAy9+eabLvVX5hAbGxsrT09PffHFF6pVq5akS8m5c+fOatmypZ588km9+uqrkqTf/OY3CggIcBQ+f/58rV271inE5uXl6b333lNycrIyMjL05ptv6oknnijxuRMmTNCLL76o2NhY+fv7l/pFjx49WuHh4dq8ebNq1PjfIHOrVq00fPjwEqPAdevWddRXXvn5+crPz69QHyibK683U5IBVJcrv9/wfR/upjL+zb/99tsaOHDgVW/cKigo0KpVq6omxJ45c0Zr1qzR008/7QiwlwUEBGjEiBF67733tHDhQqc2wzD05Zdf6sCBA2rVqpVT2/vvv6+IiAi1bt1a9957ryZPnqzp06eXeBLK8OHDtXbtWj311FOaP39+idpSUlK0f/9+LVu2zCnAXqkiT1cpKChwmv+ak5MjSXrwwQdVs2bNcveLirlQZJN3TZIsgKp3oeh/P0NGjhxpYiWAdb3xxhtq3LhxqW2nTp0q0+BimebEHjx4UIZhqE2b0tfJa9OmjbKysnTq1ClJ0sKFC+Xj4yMvLy/17t1bxcXFmjhxotM58fHxjgm9AwcOVHZ2tpKSkkr0bbPZNGfOHC1evFiHDx8u0f7dd99Jklq3bu3Yl5mZKR8fH8f283A9fPhwp3YfHx8dPXq01K8tLi5Ofn5+ji04OPhqlwkAAAA/4+HhoaKioqu2FxYWym63u95feYpwdVmpESNG6Mknn1RWVpZmzJihnj17qmfPno721NRUbdu2TStXrrxUjIeHhg0bpvj4eEVFRZXob8CAAbrlllv017/+Ve++++51P79BgwZKSUmRdOlGswsXLji1z507V/369XPaV9p8XEmaPn26pkyZ4nifk5Oj4OBgLV68uNTpDag6+fn5jlEQTzujsACqx5Xfb9555x0eOwu3cuXP3vKqV6+eTp48edWsdfLkScd9Va4oU4gNCwuTzWbT/v379fvf/75E+/79+1WvXj01atRIkuTn56ewsDBJl6YNhIWFqXv37o7gGB8fr8LCQqcvxjAMeXl5af78+fLz8yvxGXPmzFGPHj306KOPOu2/PE0hNTVVnTt3liTZ7XbH53t4lPxSAwICHO3X4+XlVeocDm9vb76RmagCM0QAoEyu/H7D936g7Dp27KhVq1Y5ctrPrV69Wh06dHC5vzJNJ2jQoIH69++vhQsXlliINiMjQ0uXLtWwYcNKnXvq4+OjSZMmaerUqTIMQ4WFhXr77bf1wgsvKCUlxbF98803CgoK0rJly0qtoWvXrrr77rs1bdo0p/2dO3dWRESEnn/++Qo9wgwAAACVb8SIEZozZ442bNhQoi0hIUHPPPOMhg8f7nJ/ZZ5OMH/+fPXs2VMDBgzQ7NmznZbYatKkiZ5++umrnvvQQw/p73//uz766CN5eHgoKytL999/f4kR13vuuUfx8fF6+OGHS+3n6aefVrt27ZxGV202m5YsWaL+/furV69emj59utq0aaOLFy9q48aNOnXqVIl5FmfPnlVGRobTPl9fX9WuXbuslwUAAADXcN9992nFihXq16+fbrjhBrVp00Y2m00HDhzQN998o9/+9rcaPXq0y/2V+W
EHrVq1UnJyslq0aKGhQ4eqZcuWevDBB9WnTx9t2bLlmnMZ6tevr/vuu08zZ85UfHy8+vXrV+qUgXvuuUfJycn69tt
"text/plain": [
"<Figure size 720x1620 with 9 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"palette = {\n",
"    'ORANGE': 'orange',\n",
"    'WHITE': 'wheat',\n",
"}\n",
"# The plot needs the encoded (numeric) Item Size as x-axis values. Put them in a copy\n",
"# instead of overwriting pumpkins['Item Size']: the ColumnTransformer cell expects the\n",
"# original size strings and would fail if re-run after an in-place overwrite.\n",
"pumpkins_plot = pumpkins.assign(**{'Item Size': encoded_pumpkins['ord__Item Size']})\n",
"\n",
"g = sns.catplot(\n",
"    data=pumpkins_plot,\n",
"    x=\"Item Size\", y=\"Color\", row='Variety',\n",
"    kind=\"box\", orient=\"h\",\n",
"    sharex=False, margin_titles=True,\n",
"    height=1.8, aspect=4, palette=palette,\n",
")\n",
"# Defining axis labels; x-limits cover the 7 ordinal codes (0 = 'sml' ... 6 = 'exjbo')\n",
"g.set(xlabel=\"Item Size\", ylabel=\"\").set(xlim=(0,6))\n",
"g.set_titles(row_template=\"{row_name}\")\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's now focus on a specific relationship: Item Size and Color!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import warnings\n",
"\n",
"# Ignore UserWarnings raised from seaborn modules for the rest of the kernel session\n",
"# NOTE(review): presumably meant to hide plotting warnings from the next cell - confirm\n",
"warnings.filterwarnings('ignore', category=UserWarning, module='seaborn')"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Axes: xlabel='Color', ylabel='ord__Item Size'>"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAioAAAGwCAYAAACHJU4LAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAB9+0lEQVR4nO3deXQc1Z33/3dV9aatZcnaF1tesTEYL3gLdmzAxDAOkEDCPgES8jwTiJMZMkxgzu8MhFmAMUlIgkOWYSDJQIAwLE54MAEvbLHBbGFzAjZeZFuLV+1qqZffH1dSd6m7Zcu2UBs+r3N0wP3tunVv1e263666V7JisVgMERERkQxkD3cFRERERNJRoiIiIiIZS4mKiIiIZCwlKiIiIpKxlKiIiIhIxlKiIiIiIhlLiYqIiIhkLM9wV+BoRKNRdu/eTV5eHpZlDXd1RERE5DDEYjFaWlqoqKjAtge+Z3JcJyq7d++murp6uKshIiIiR6C2tpaqqqoB33NcJyp5eXmAaWgwGBzm2oiIiMjhaG5uprq6um8cH8hxnaj0Pu4JBoNKVERERI4zhzNtQ5NpRUREJGMpUREREZGMpURFREREMpYSFREREclYSlREREQkYylRERERkYylREVEREQylhIVERERyVhKVERERCRjKVERERGRjDXsv0J/165dfPe73+Xpp5+mvb2d8ePHc99993HqqacOd9U+2SIhqP1fOPg25E2E0ZeAJzseb3wR6laBNwg1l0N2wh+Nat0G238L4Tao/DwUzY3HultNrHULFMyA6i+C7TWxWBR2r4I9L0KgDMZcAf6R8W0Pvge1j5r/r/4SjJgSj4X2wbYHoKMOihdAxdlg9eTZ0W6ofRwOvAG542D0peDNjW+7dwPs+gN4ckwstyYea99pyu1uhvKzoWRBPBZuh+0PQ8tfYcRUqL4QHH9PW2JQ/xw0rAF/sWlLoGSwZ0FERA7BisViseHa+YEDB5g+fTqnn3463/jGNyguLubDDz9k3LhxjBs37pDbNzc3k5+fT1NTk/7Wz2B07oXVi6DpvfhrOTWweB1kj4JXroGP/jses31w2sNQ/QXY9hCs/1uIhePxCdfBrLuhZTM8twg6dsVjBdPhzNXg5MAL55vkp5c3HxY9DcXzYNOd8OYN7npO+0848QaTaKw9G7qb4rHyJfDZlRBph9VnmiSlV1YlnLkWghPgtWXwwd3xmOWBeb+Gmkth55Pw0kUQ7YrHx14Nc+6F9h2mLW3b4rH8KaZcXwG89GXY+UQ85smBhX+A0kUpD7mIiMQNZvwe1kTlxhtv5OWXX+bFF188ou2VqByhjd+ED1ckvz7qyzDmKnh+aXLMPxKWboKV4yDckhw/cx1sWg67n0qOTb7BJEKvXZccy58CC38Pvx9v7rgksmw4dzM8f647qep16gqTSGxanhyrWGr2u3pRcsyTB+dtgacmmzs1/S18CrbeDzt+lxybcB2MnAUbrkqO5Y6Fcz+M3+kREZGUBjN+D+sVdeXKlZx66ql8+ctfpqSkhOnTp/PLX/4y7ftDoRDNzc2uHzkCiXcCXK8/CbueTB0L7YMP70mdpIB5jFT3dPr97UxTbtN7sOW/k5MUMK9tvjd1knKocuuehh3/mzoWbjFtSZWk9JW7Mk3s8fT7bP0IDr6TOiYiIkdkWBOVjz76iHvuuYcJEybwzDPP8I1vfINvfetb/OpXv0r5/ttuu438/Py+n+rq6o+5xp8QvfMs+rP95iftdtkDxAJgedPH0u0TwMlKH/McYp/pyrW84AkMsO0hyrV9g4/BwMdPREQGbVgTlWg0yowZM/iP//gPpk+fzv/5P/+Hr3/96/zsZz9L+f6bbrqJpqamvp/a2tqPucafEDWXp3n9MvOTSk4NTPymmQSbxDJljroo9bajL0u/z+IFMP7rJgHoz/abWPGC5Nihyh11EdRcYerWX6DMtCWnJvW2NQOUW3N5+ljBDMiflDomIi
JHZFgTlfLyck488UTXa5MnT2bHjh0p3+/3+wkGg64fOQJT/hkqPu9+rXgBTLvDrOCZdkd8pQ6YgX3+I+YOxYJHwV8Uj9k+mHkXFJwCM38II+e4y62+ECZdD6MvhonLcCUOeRNg3v0QKIbPPGAmpPby5JjXAsXmPXkTEwq1TKJRc4kpu/pL7n2OnGPqUjAVZv7IfQfEXwTzf2faMv937sTL8sC023uOwe3JCVLFUnPsqs6Fyf/knouSU2PqKyIix9SwTqa97LLLqK2tdU2m/Yd/+AdeeeUV/vSnPx1ye02mPUr73+hZnnyCWXmTqKPOLL/1BqH8HHASBvtIJ+x+2ixPLv9c8rLcxpegdbO5w1Aw1R1r2QJ7XoKscihb7B7su5tNuQAV55h994pFTX066qB4PuT1WxV24O348uSSfglGZyPU/dEkPxXnuO/eRLrMSqTuJlOfrHL3tnvWx5cnF85wx1q3QePzpv1lZ4E97Kv9RUSOC8fNqp+NGzfymc98hu9973tcdNFFvPrqq3z961/nF7/4BZdfnub2egIlKiIiIsef42bVz6xZs3j88cf57W9/y0knncS//uu/ctdddx1WkiIiIiKffMN6R+Vo6Y6KiIjI8ee4uaMiIiIiMhAlKiIiIpKxlKiIiIhIxlKiIiIiIhlLiYqIiIhkLCUqIiIikrGUqIiIiEjGUqIiIiIiGUuJioiIiGQsJSoiIiKSsZSoiIiISMZSoiIiIiIZS4mKiIiIZCwlKiIiIpKxlKiIiIhIxlKiIiIiIhlLiYqIiIhkLCUqIiIikrGUqIiIiEjGUqIiIiIiGUuJioiIiGQsJSoiIiKSsZSoiIiISMZSoiIiIiIZS4mKiIiIZCwlKiIiIpKxlKiIiIhIxlKiIiIiIhlLiYqIiIhkLCUqIiIikrGUqIiIiEjGUqIiIiIiGUuJioiIiGQsJSoiIiKSsZSoiIiISMZSoiIiIiIZS4mKiIiIZCwlKiIiIpKxlKiIiIhIxlKiIiIiIhlLiYqIiIhkLCUqIiIikrGUqIiIiEjGUqIiIiIiGUuJioiIiGQsz3Du/JZbbuF73/ue67UTTjiBv/zlL8NUo34e9ANd8X9f3AHbH4Gm9yB/Coy+CJxAPN6wFuqeBX8h1FwOWeXxWMsW2P4QRDqh+gtQODMe626GbQ9C2zYYORsqzwO759TEorDrKdj7MmRVwZjLwVcQ3/bAn6H2MbAcGH0xBE+IxzobYev/QGgPlJ4OZWeBZZlYJAS1/wsH34a8iTD6EvBkx7dtfBHqVoE3aNqSXRWPtW6D7b+FcBtUfh6K5ia0pdXEWrdAwQyo/iLY3nhbdq+CPS9CoAzGXAH+kfFtD74HtY+a/6/+EoyYEo+F9sG2B6CjDooXQMXZYPXk2dFuqH0cDrwBueNg9KXgzY1vu3cD7PoDeHJMLLcmHmvfacrtbobys6FkQTwWboftD0PLX2HEVKi+EBx/T1tiUP8cNKwBf7FpS6Akvm3zX822sQhUXwAFp8RjXQdg6wPQsROKPmOOYV9bwrBrJex7FXJqoOYycw567X8dap8w9Rh9KeSNi8c66kxbQvuh/CxzzntFOo9d3930X7D9Z/H4uO/AnDv5pIt07CXSsRfL8eLJqcTyxI9ftLuNSNtuYrEonuxSbP+Ivlgs2k24dTexcDu2fwROdilWz/mOxWJEOhqJdu7H8gRMuY4vXm6omXB7HRYWTk4Fti/er2OREOHWXcQiIZxAEXZWEVbP5zsWixBpqyfa1YztzcHJqcSynXhbOvcRad+DZXtwciuxPVkJbWk3bYmGcbJLcQLx600sGibSuotouB3bl4+TU+ZqS7SjkUjnfizHjye3yt2WrhYibXUAODnl2L68hLZ09bSlEydQiJ1VktCWaE9bmrA92Ti5lVh2fOiKdB4g0t5g2pJTge2NX8ei4Q4irbt62lKME4hfb2LRCJG23US7W7F9wZ62OPG2dPaebz+e3Eqs3s8+EO1qNceIGJ
7scmx//DMai3QRbttFLNyJ7S/oOd8JbWlvIBo6iOXJxpNbgdV7fQSioYOE2xuwsHFyK7C9OfFyw52m3Eg3TlYRTlZ
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Suppressing warning message claiming that a portion of points cannot be placed into the plot due to the high number of data points\n",
"import warnings\n",
"warnings.filterwarnings(action='ignore', category=UserWarning, module='seaborn')\n",
"\n",
"palette = {\n",
" 0: 'orange',\n",
" 1: 'wheat'\n",
"}\n",
"sns.swarmplot(x=\"Color\", y=\"ord__Item Size\", hue=\"Color\", data=encoded_pumpkins, palette=palette)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"**Watch out**: Ignoring warnings is NOT a best practice and should be avoided whenever possible. Warnings often contain useful messages that help us improve our code and resolve issues.\n",
"We ignore this specific warning only to keep the plot readable: shrinking the marker size so that every data point fits, while keeping the palette colors consistent, produces a cluttered and unclear visualization."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Build your model"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"# X is the encoded features\n",
"X = encoded_pumpkins[encoded_pumpkins.columns.difference(['Color'])]\n",
"# y is the encoded label\n",
"y = encoded_pumpkins['Color']\n",
"\n",
"# Split the data into training and test sets\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)"
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" 0 0.94 0.98 0.96 166\n",
" 1 0.85 0.67 0.75 33\n",
"\n",
" accuracy 0.92 199\n",
" macro avg 0.89 0.82 0.85 199\n",
"weighted avg 0.92 0.92 0.92 199\n",
"\n",
"Predicted labels: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0\n",
" 0 0 0 0 0 1 0 1 0 0 1 0 0 0 0 0 1 0 1 0 1 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
" 1 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 0 1 0\n",
" 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 1 1 0\n",
" 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1\n",
" 0 0 0 1 0 0 0 0 0 0 0 0 1 1]\n",
"F1-score: 0.7457627118644068\n"
]
}
],
"source": [
"from sklearn.metrics import f1_score, classification_report \n",
"from sklearn.linear_model import LogisticRegression\n",
"\n",
"# Train a logistic regression model on the pumpkin dataset\n",
"model = LogisticRegression()\n",
"model.fit(X_train, y_train)\n",
"predictions = model.predict(X_test)\n",
"\n",
"# Evaluate the model and print the results\n",
"print(classification_report(y_test, predictions))\n",
"print('Predicted labels: ', predictions)\n",
"print('F1-score: ', f1_score(y_test, predictions))"
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[162, 4],\n",
" [ 11, 22]])"
]
},
"execution_count": 76,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.metrics import confusion_matrix\n",
"confusion_matrix(y_test, predictions)"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAhgAAAIjCAYAAABBOWJ+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAABgUElEQVR4nO3dd1gUV8MF8LOUZelqsItBjTV2jcYKKgoWFDWKJUKIvUc0tqjYoibWxKDGFqwRNGqMjQiC3WhE7CX2BiixgHR27/eHL/tJBGVxl9lyfs/DE3aY2T07ETjcuTMjE0IIEBEREWmRmdQBiIiIyPiwYBAREZHWsWAQERGR1rFgEBERkdaxYBAREZHWsWAQERGR1rFgEBERkdaxYBAREZHWsWAQERGR1rFgEBERkdaxYBCZgODgYMhkMvWHhYUFypYtiy+++AIPHz7MdRshBDZs2ICWLVuiSJEisLGxQa1atTBz5kwkJyfn+Vo7duxA+/bt4eTkBLlcjjJlyqBnz544ePBgvrKmpaVh8eLFaNy4MRwdHaFQKFClShWMGDEC169fL9D7J6LCJ+O9SIiMX3BwMPz9/TFz5kxUqFABaWlpOHnyJIKDg+Hi4oKLFy9CoVCo11cqlejTpw9CQ0PRokULdOvWDTY2Njhy5Ag2b96MGjVqIDw8HCVLllRvI4TAl19+ieDgYNSrVw+fffYZSpUqhdjYWOzYsQNnzpzBsWPH0LRp0zxzJiQkwNPTE2fOnEGnTp3g7u4OOzs7XLt2DVu2bEFcXBwyMjJ0uq+ISEsEERm9X375RQAQp0+fzrF8woQJAoAICQnJsXzOnDkCgBg3btwbz7Vr1y5hZmYmPD09cyyfP3++ACC++uoroVKp3thu/fr14q+//nprzo4dOwozMzOxbdu2N76WlpYmxo4d+9bt8yszM1Okp6dr5bmIKHcsGEQmIK+CsXv3bgFAzJkzR70sJSVFFC1aVFSpUkVkZmbm+nz+/v4CgDhx4oR6m2LFiolq1aqJrKysAmU8efKkACAGDhyYr/VdXV2Fq6vrG8v9/PzEhx9+qH58+/ZtAUDMnz9fLF68WFSsWFGYmZmJkydPCnNzczF9+vQ3nuPq1asCgFi6dKl62bNnz8To0aNFuXLlhFwuF5UqVRLz5s0TSqVS4/dKZAo4B4PIhN25cwcAULRoUfWyo0eP4tmzZ+jTpw8sLCxy3c7X1xcAsHv3bvU2T58+RZ8+fWBubl6gLLt27QIA9OvXr0Dbv8svv/yCpUuXYtCgQVi4cCFKly4NV1dXhIaGvrFuSEgIzM3N0aNHDwBASkoKXF1dsXHjRvj6+uLHH39Es2bNMGnSJAQEBOgkL5Ghy/2nBxEZpRcvXiAhIQFpaWn466+/MGPGDFhZWaFTp07qdS5fvgwAqFOnTp7Pk/21K1eu5PhvrVq1CpxNG8/xNg8ePMCNGzdQvHhx9TIfHx8MHjwYFy9eRM2aNdXLQ0JC4Orqqp5jsmjRIty8eRNnz55F5cqVAQCDBw9GmTJlMH/+fIwdOxbOzs46yU1kqDiCQWRC3N3dUbx4cTg7O+Ozzz6Dra0tdu3ahXLlyqnXSUpKAgDY29vn+TzZX0tMTMzx37dt8y7aeI636d69e45yAQDdunWDhYUFQkJC1MsuXryIy5cvw8fHR71s69ataNGiBYoWLYqEhAT1h7u7O5RKJQ4fPqyTzESGjCMYRCYkKCgIVapUwYsXL7B27VocPnwYVlZWOdbJ/gWfXTRy898S4uDg8M5t3uX15yhSpEiBnycvFSpUeGOZk5MT2rRpg9DQUMyaNQvAq9ELCwsLdOvWTb3eP//8g/Pnz79RULI9fvxY63mJDB0LBpEJadSoERo2bAgA8Pb2RvPmzdGnTx9cu3YNdnZ2AIDq1asDAM6fPw9vb+9cn+f8+fMAgBo1agAAqlWrBgC4cOFCntu8y+vP0aJFi3euL5PJIHI5y16pVOa6vrW1da7Le/
XqBX9/f8TExKBu3boIDQ1FmzZt4OTkpF5HpVKhbdu2GD9+fK7PUaVKlXfmJTI1PERCZKLMzc0xd+5cPHr0CD/99JN6efPmzVGkSBFs3rw5z1/W69evBwD13I3mzZujaNGi+PXXX/Pc5l28vLwAABs3bszX+kWLFsXz58/fWH737l2NXtfb2xtyuRwhISGIiYnB9evX0atXrxzrVKpUCS9fvoS7u3uuH+XLl9foNYlMAQsGkQlzc3NDo0aNsGTJEqSlpQEAbGxsMG7cOFy7dg3ffPPNG9vs2bMHwcHB8PDwwKeffqreZsKECbhy5QomTJiQ68jCxo0bcerUqTyzNGnSBJ6enli9ejV27tz5xtczMjIwbtw49eNKlSrh6tWrePLkiXrZuXPncOzYsXy/fwAoUqQIPDw8EBoaii1btkAul78xCtOzZ0+cOHECYWFhb2z//PlzZGVlafSaRKaAV/IkMgHZV/I8ffq0+hBJtm3btqFHjx5Yvnw5hgwZAuDVYQYfHx/89ttvaNmyJbp37w5ra2scPXoUGzduRPXq1REREZHjSp4qlQpffPEFNmzYgPr166uv5BkXF4edO3fi1KlTOH78OJo0aZJnzidPnqBdu3Y4d+4cvLy80KZNG9ja2uKff/7Bli1bEBsbi/T0dACvzjqpWbMm6tSpg/79++Px48dYsWIFSpYsicTERPUpuHfu3EGFChUwf/78HAXldZs2bcLnn38Oe3t7uLm5qU+ZzZaSkoIWLVrg/Pnz+OKLL9CgQQMkJyfjwoUL2LZtG+7cuZPjkAoRgVfyJDIFeV1oSwghlEqlqFSpkqhUqVKOi2QplUrxyy+/iGbNmgkHBwehUCjExx9/LGbMmCFevnyZ52tt27ZNtGvXThQrVkxYWFiI0qVLCx8fHxEVFZWvrCkpKWLBggXik08+EXZ2dkIul4vKlSuLkSNHihs3buRYd+PGjaJixYpCLpeLunXrirCwsLdeaCsviYmJwtraWgAQGzduzHWdpKQkMWnSJPHRRx8JuVwunJycRNOmTcWCBQtERkZGvt4bkSnhCAYRERFpHedgEBERkdaxYBAREZHWsWAQERGR1rFgEBERkdaxYBAREZHWsWAQERGR1pncvUhUKhUePXoEe3t7yGQyqeMQEREZDCEEkpKSUKZMGZiZvX2MwuQKxqNHj+Ds7Cx1DCIiIoN1//59lCtX7q3rmFzByL699P3799W3hyYiIqJ3S0xMhLOzs/p36duYXMHIPizi4ODAgkFERFQA+ZliwEmeREREpHUsGERERKR1LBhERESkdSwYREREpHUsGERERKR1LBhERESkdSwYREREpHUsGERERKR1LBhERESkdSwYREREpHUsGERERKR1LBhERESkdSwYREREpHUsGERERKR1khaMw4cPw8vLC2XKlIFMJsPOnTvfuU1UVBTq168PKysrfPTRRwgODtZ5TiIiItKMpAUjOTkZderUQVBQUL7Wv337Njp27IhWrVohJiYGX331FQYMGICwsDAdJyUiIiJNWEj54u3bt0f79u3zvf6KFStQoUIFLFy4EABQvXp1HD16FIsXL4aHh4euYhoVIQSi7z3Dk6R0qaMQEVEhaFLJCY7WloX+upIWDE2dOHEC7u7uOZZ5eHjgq6++ynOb9PR0pKf//y/TxMREXcUzCCdvPUXvVSeljkFERIVk76gWLBjvEhcXh5IlS+ZYVrJkSSQmJiI1NRXW1tZvbDN37lzMmDGjsCLqvQsPnwMAnOys4PKBjbRhiIhI52zk5pK8rkEVjIKYNGkSAgIC1I8TExPh7OwsYSJp3XuaAgDo9YkzxnlUlTgNERFp05kzZ/DTTz9h5cqVsLQs/FGL1xlUwShVqhTi4+NzLIuPj4eDg0OuoxcAYGVlBSsrq8KIZxDuPU0FAJQvxtELIiJjcvr0abRr1w7Pnz9H+fLlJR+9N6jrYDRp0gQRERE5lh04cABNmjSRKJHhuf+/EQxnFgwiIq
Nx6tQptG3bFs+fP0ezZs0wbtw4qSNJWzBevnyJmJgYxMTEAHh1GmpMTAzu3bsH4NXhDV9fX/X6Q4YMwa1btzB+/Hh
"text/plain": [
"<Figure size 600x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.metrics import roc_curve, roc_auc_score\n",
"import matplotlib\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"\n",
"y_scores = model.predict_proba(X_test)\n",
"# calculate ROC curve\n",
"fpr, tpr, thresholds = roc_curve(y_test, y_scores[:,1])\n",
"\n",
"# plot ROC curve\n",
"fig = plt.figure(figsize=(6, 6))\n",
"# Plot the diagonal 50% line\n",
"plt.plot([0, 1], [0, 1], 'k--')\n",
"# Plot the FPR and TPR achieved by our model\n",
"plt.plot(fpr, tpr)\n",
"plt.xlabel('False Positive Rate')\n",
"plt.ylabel('True Positive Rate')\n",
"plt.title('ROC Curve')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.9749908725812341\n"
]
}
],
"source": [
"# Calculate AUC score\n",
"auc = roc_auc_score(y_test,y_scores[:,1])\n",
"print(auc)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
},
"metadata": {
"interpreter": {
"hash": "70b38d7a306a849643e446cd70466270a13445e5987dfa1344ef2b127438fa4d"
}
},
"orig_nbformat": 2,
"vscode": {
"interpreter": {
"hash": "949777d72b0d2535278d3dc13498b2535136f6dfe0678499012e853ee9abcab1"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}