You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ML-For-Beginners/4-Classification/1-Introduction/solution/data-prep-visual.ipynb

1521 lines
102 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"\r\n",
"import pandas as pd\r\n",
"import matplotlib.pyplot as plt\r\n",
"import matplotlib as mpl\r\n",
"import numpy as np\r\n",
"from imblearn.over_sampling import SMOTE"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"df = pd.read_csv('.data/asian_indian_recipes.csv')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Unnamed: 0</th>\n",
" <th>cuisine</th>\n",
" <th>almond</th>\n",
" <th>angelica</th>\n",
" <th>anise</th>\n",
" <th>anise_seed</th>\n",
" <th>apple</th>\n",
" <th>apple_brandy</th>\n",
" <th>apricot</th>\n",
" <th>armagnac</th>\n",
" <th>...</th>\n",
" <th>whiskey</th>\n",
" <th>white_bread</th>\n",
" <th>white_wine</th>\n",
" <th>whole_grain_wheat_flour</th>\n",
" <th>wine</th>\n",
" <th>wood</th>\n",
" <th>yam</th>\n",
" <th>yeast</th>\n",
" <th>yogurt</th>\n",
" <th>zucchini</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>65</td>\n",
" <td>indian</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>66</td>\n",
" <td>indian</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>67</td>\n",
" <td>indian</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>68</td>\n",
" <td>indian</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>69</td>\n",
" <td>indian</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 385 columns</p>\n",
"</div>"
],
"text/plain": [
" Unnamed: 0 cuisine almond angelica anise anise_seed apple \\\n",
"0 65 indian 0 0 0 0 0 \n",
"1 66 indian 1 0 0 0 0 \n",
"2 67 indian 0 0 0 0 0 \n",
"3 68 indian 0 0 0 0 0 \n",
"4 69 indian 0 0 0 0 0 \n",
"\n",
" apple_brandy apricot armagnac ... whiskey white_bread white_wine \\\n",
"0 0 0 0 ... 0 0 0 \n",
"1 0 0 0 ... 0 0 0 \n",
"2 0 0 0 ... 0 0 0 \n",
"3 0 0 0 ... 0 0 0 \n",
"4 0 0 0 ... 0 0 0 \n",
"\n",
" whole_grain_wheat_flour wine wood yam yeast yogurt zucchini \n",
"0 0 0 0 0 0 0 0 \n",
"1 0 0 0 0 0 0 0 \n",
"2 0 0 0 0 0 0 0 \n",
"3 0 0 0 0 0 0 0 \n",
"4 0 0 0 0 0 1 0 \n",
"\n",
"[5 rows x 385 columns]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 2448 entries, 0 to 2447\n",
"Columns: 385 entries, Unnamed: 0 to zucchini\n",
"dtypes: int64(384), object(1)\n",
"memory usage: 7.2+ MB\n"
]
}
],
"source": [
"df.info()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"korean 799\n",
"indian 598\n",
"chinese 442\n",
"japanese 320\n",
"thai 289\n",
"Name: cuisine, dtype: int64"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.cuisine.value_counts()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"#df.keys().values"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<AxesSubplot:>"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZEAAAD4CAYAAAAtrdtxAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAR0ElEQVR4nO3de5CddX3H8fenAYIRSERSuqKyYCMWQQEDFeMFlXop1EvLdKS1QscaW2sVnVHjSL10qhMvtVSpTuOlOmqpBVEpdIrUVqx4gQ0EEm5SNQp4AazGOzrw7R/nSTmGTcj+ds+eJ+z7NbNznv09v+c5n7PZzWef5znnbKoKSZJa/Mq4A0iSdl2WiCSpmSUiSWpmiUiSmlkikqRmu407wHzab7/9anJyctwxJGmXsn79+tuqavl06xZUiUxOTjI1NTXuGJK0S0ny9e2t83SWJKmZJSJJamaJSJKaWSKSpGaWiCSpmSUiSWpmiUiSmlkikqRmlogkqdmCesX6xpu3MLnmgnHH0BzZvPaEcUeQFjyPRCRJzSwRSVIzS0SS1MwSkSQ1s0QkSc0sEUlSM0tEktSsFyWSZFmSF3XLxyU5f4bb/1WS40eTTpK0Pb0oEWAZ8KLWjavqtVX1H3MXR5K0M/pSImuBhyTZALwV2CvJOUmuS/KRJAFI8toklyXZlGTd0PgHkpw0vviStDD1pUTWAF+pqiOAVwBHAqcBhwIHA6u6eWdW1dFVdRhwH+DEe9pxktVJppJM3fGTLaPILkkLVl9KZFuXVtVNVXUnsAGY7MafmORLSTYCTwIefk87qqp1VbWyqlYuWrJ0ZIElaSHq6xsw3j60fAewW5I9gXcBK6vqxiSvB/YcRzhJ0kBfjkR+COx9D3O2FsZtSfYCvAYiSWPWiyORqvpukkuSbAJ+CnxnmjnfT/IeYCOwGbhsflNKkrbVixIBqKo/2M74i4eWTwdOn2bOqaNLJknanr6czpIk7YIsEUlSM0tEktTMEpEkNbNEJEnNevPsrPlw+AFLmVp7wrhjSNK9hkcikqRmlogkqZklIklqZolIkppZIpKkZpaIJKmZJSJJamaJSJKaWSKSpGaWiCSpmSUiSWpmiUiSmlkikqRmlogkqZklIklqZolIkppZIpKkZpaIJKmZJSJJamaJSJKaWSKSpGa7jTvAfNp48xYm11ww7hgak81rTxh3BOlexyMRSVIzS0SS1MwSkSQ1s0QkSc0sEUlSM0tEktRsp0okyedHHUSStOvZqRKpqseMOogkadezs0ciP0qyV5JPJ7k8ycYkz+zWTSa5LskHk1yV5JwkS7p1r01yWZJNSdYlSTf+mSRvTnJpki8neVw3vijJW7ttrkrywm58Islnk2zo9rV1/lOSfKHLdHaSvUbxRZIkTW8m10R+Bjy7qo4Cngj8zdZSAA4B1lXVI4AfAC/qxs+sqqOr6jDgPsCJQ/vbraqOAU4DXteNPR/YUlVHA0cDL0hyEPAHwIVVdQTwSGBDkv2A04Hju0xTwMtn8HgkSbM0k7c9CfCmJI8H7gQOAPbv1t1YVZd0yx8GXgK8DXhiklcCS4B9gauBf+3mndvdrgcmu+WnAI9IclL3+VJgBXAZ8P4kuwOfqKoNSZ4AHApc0nXZHsAX7hY6WQ2sBli0z/IZPFxJ0j2ZSYn8IbAceFRV/SLJZmDPbl1tM7eS7Am8C1hZVTcmef3QfIDbu9s7hnIE+IuqunDbO+/K6wTgQ0neCnwPuKiqTt5R6KpaB6wDWDyxYtuckqRZmMnprKXALV2BPBE4cGjdg5Mc2y2fDHyOuwrjtu5axUncswuBP+uOOEjy0CT3TXJgd9/vAd4HHAV8EViV5Ne7uUuSPHQGj0eSNEs7eyRSwEeAf00yBWwArhtafy1wSpJ/AG4A3l1VP0nyHmAjsJnBKal78l4Gp7Yu76633Ao8CzgOeEWSXwA/Ap5XVbcmORU4K8nibvvTgS/v5GOSJM1Sqn
Z8hifJ/YHLq+rA7ayfBM7vLp732uKJFTVxyhnjjqEx8a3gpTZJ1lfVyunW7fB0VpIHMLhY/bZRBJMk7dp2eDqrqr4J7PA6Q1VtBnp/FCJJmnu+d5YkqZklIklqZolIkprN5MWGu7zDD1jKlM/QkaQ545GIJKmZJSJJamaJSJKaWSKSpGaWiCSpmSUiSWpmiUiSmlkikqRmlogkqZklIklqZolIkppZIpKkZpaIJKmZJSJJamaJSJKaWSKSpGaWiCSpmSUiSWpmiUiSmlkikqRmlogkqdlu4w4wnzbevIXJNReMO4bUbPPaE8YdQfolHolIkppZIpKkZpaIJKmZJSJJamaJSJKaWSKSpGaWiCSp2ZyWSJIPJDlpmvEHJDlnLu9LkjR+8/Jiw6r6JnC3cpEk7dpmdSSS5HlJrkpyZZIPdcOPT/L5JF/delSSZDLJpm751CTnJvn3JDckecvQ/p6S5AtJLk9ydpK9uvG1Sa7p7utt3djyJB9Lcln3sWo2j0WSNHPNRyJJHg68BlhVVbcl2Rd4OzABPBZ4GHAeMN1prCOAI4HbgeuTvBP4KXA6cHxV/TjJq4CXJzkTeDbwsKqqJMu6ffwd8LdV9bkkDwYuBH5jmpyrgdUAi/ZZ3vpwJUnTmM3prCcB51TVbQBV9b9JAD5RVXcC1yTZfzvbfrqqtgAkuQY4EFgGHApc0u1nD+ALwA+AnwHvTXIBcH63j+OBQ7u5APsk2buqfjh8R1W1DlgHsHhiRc3i8UqStjGbEgkw3X/Kt28zZzrDc+7ocgS4qKpOvtsdJccATwaeA7yYQYH9CnBsVf105tElSXNhNtdEPg38fpL7A3Sns2bji8CqJL/e7W9Jkod210WWVtW/AacxOBUG8CkGhUI3/wgkSfOq+Uikqq5O8kbg4iR3AFfMJkhV3ZrkVOCsJIu74dOBHwKfTLIng6OVl3XrXgL8fZKrGDyOzwJ/OpsMkqSZSdXCuUyweGJFTZxyxrhjSM38eyIahyTrq2rldOt8xbokqZklIklqZolIkppZIpKkZpaIJKnZvLwBY18cfsBSpnx2iyTNGY9EJEnNLBFJUjNLRJLUzBKRJDWzRCRJzSwRSVIzS0SS1MwSkSQ1s0QkSc0sEUlSM0tEktTMEpEkNbNEJEnNLBFJUjNLRJLUzBKRJDWzRCRJzSwRSVIzS0SS1MwSkSQ1s0QkSc12G3eA+bTx5i1Mrrlg3DEkzdDmtSeMO4K2wyMRSVIzS0SS1MwSkSQ1s0QkSc0sEUlSM0tEktRsZCWS5PMznH9ckvO75WckWTOaZJKkuTKy14lU1WNmse15wHlzGEeSNAKjPBL5UXd7XJLPJDknyXVJPpIk3bqndWOfA353aNtTk5zZLf9Oki8luSLJfyTZvxt/fZL3d/v+apKXjOqxSJKmN1/XRI4ETgMOBQ4GViXZE3gP8DvA44Bf2862nwMeXVVHAv8MvHJo3cOApwLHAK9LsvtI0kuSpjVfb3tyaVXdBJBkAzAJ/Aj4WlXd0I1/GFg9zbYPBD6aZALYA/ja0LoLqup24PYktwD7AzcNb5xk9db9Ltpn+Rw+JEnSfB2J3D60fAd3lVftxLbvBM6sqsOBFwJ77sR+/19VrauqlVW1ctGSpTNLLUnaoXE+xfc64KAkD+k+P3k785YCN3fLp4w8lSRpp42tRKrqZwxOM13QXVj/+namvh44O8l/A7fNUzxJ0k5I1c6cUbp3WDyxoiZOOWPcMSTNkG8FP15J1lfVyunW+Yp1SVIzS0SS1MwSkSQ1s0QkSc0sEUlSs/l6xXovHH7AUqZ8lockzRmPRCRJzSwRSVIzS0SS1MwSkSQ1s0QkSc0sEUlSM0tEktTMEpEkNbNEJEnNLBFJUjNLRJLUzBKRJDWzRCRJzSwRSVIzS0SS1MwSkSQ1s0QkSc0sEUlSM0tEktTMEpEkNbNEJEnNdht3gPm08eYtTK65YNwxJGlebV57wsj27ZGIJKmZJSJJamaJSJ
KaWSKSpGaWiCSpmSUiSWpmiUiSms1piSSZTLJpLvcpSeqvXhyJJFlQL3qUpHuLkZVIkoOTXJHk6CRfTHJVko8nuV+3/jNJ3pTkYuClSR6V5OIk65NcmGSim/eCJJcluTLJx5Is6cY/kOQdST6f5KtJThrVY5EkTW8kJZLkEOBjwB8D7wNeVVWPADYCrxuauqyqngC8A3gncFJVPQp4P/DGbs65VXV0VT0SuBZ4/tD2E8BjgROBtdvJsjrJVJKpO36yZc4eoyRpNO+dtRz4JPB7wE0MiuLibt0HgbOH5n60uz0EOAy4KAnAIuBb3brDkvw1sAzYC7hwaPtPVNWdwDVJ9p8uTFWtA9YBLJ5YUbN6ZJKkXzKKEtkC3Ais4q6S2J4fd7cBrq6qY6eZ8wHgWVV1ZZJTgeOG1t0+tJyWsJKkdqM4nfVz4FnA84ATgO8leVy37o+Ai6fZ5npgeZJjAZLsnuTh3bq9gW8l2R34wxHklSQ1Gsmzoqrqx0lOBC4CzgXe2l0Q/yqD6yTbzv95d2H8HUmWdrnOAK4G/hL4EvB1BtdU9h5FZknSzKVq4VwmWDyxoiZOOWPcMSRpXs3274kkWV9VK6db14vXiUiSdk2WiCSpmSUiSWpmiUiSmlkikqRmC+qNDw8/YClTs3yWgiTpLh6JSJKaWSKSpGaWiCSpmSUiSWpmiUiSmlkikqRmlogkqZklIklqZolIkppZIpKkZgvqj1Il+SGDP8XbV/sBt407xA6Yb3bMNzvmm53Z5DuwqpZPt2JBvXcWcP32/jpXHySZMl87882O+WZnoebzdJYkqZklIklqttBKZN24A9wD882O+WbHfLOzIPMtqAvrkqS5tdCORCRJc8gSkSQ1WzAlkuRpSa5P8j9J1owpw/uT3JJk09DYvkkuSnJDd3u/oXWv7vJen+SpI872oCT/leTaJFcneWnP8u2Z5NIkV3b53tCnfEP3uSjJFUnO71u+JJuTbEyyIclUD/MtS3JOkuu678Nj+5IvySHd123rxw+SnNaXfN39vaz72diU5KzuZ2b0+arqXv8BLAK+AhwM7AFcCRw6hhyPB44CNg2NvQVY0y2vAd7cLR/a5VwMHNTlXzTCbBPAUd3y3sCXuwx9yRdgr255d+BLwKP7km8o58uBfwLO79O/b3efm4H9thnrU74PAn/SLe8BLOtTvqGci4BvAwf2JR9wAPA14D7d5/8CnDof+Ub+Be/DB3AscOHQ568GXj2mLJP8colcD0x0yxMMXhB5t4zAhcCx85jzk8Bv9TEfsAS4HPjNPuUDHgh8GngSd5VIn/Jt5u4l0ot8wD7df4LpY75tMj0FuKRP+RiUyI3AvgxeRH5+l3Pk+RbK6aytX+CtburG+mD/qvoWQHf7q9342DInmQSOZPDbfm/ydaeKNgC3ABdVVa/yAWcArwTuHBrrU74CPpVkfZLVPct3MHAr8I/d6cD3Jrlvj/INew5wVrfci3xVdTPwNuAbwLeALVX1qfnIt1BKJNOM9f25zWPJnGQv4GPAaVX1gx1NnWZspPmq6o6qOoLBb/zHJDlsB9PnNV+SE4Fbqmr9zm4yzdio/31XVdVRwNOBP0/y+B3Mne98uzE41fvuqjoS+DGD0y/bM66fjz2AZwBn39PUacZG+f13P+CZDE5NPQC4b5Ln7miTacaa8i2UErkJeNDQ5w8EvjmmLNv6TpIJgO72lm583jMn2Z1BgXykqs7tW76tqur7wGeAp/Uo3yrgGUk2A/8MPCnJh3uUj6r6Znd7C/Bx4Jge5bsJuKk7ugQ4h0Gp9CXfVk8HLq+q73Sf9yXf8cDXqurWqvoFcC7wmPnIt1BK5DJgRZKDut8kngOcN+ZMW50HnNItn8LgWsTW8eckWZzkIGAFcOmoQiQJ8D7g2qp6ew/zLU+yrFu+D4Mfmuv6kq+qXl1VD6yqSQbfX/9ZVc/tS74k902y99ZlBufLN/UlX1V9G7gxySHd0JOBa/qSb8jJ3HUqa2
uOPuT7BvDoJEu6n+UnA9fOS775uBDVhw/gtxk84+grwGvGlOEsBucrf8HgN4HnA/dncDH2hu5236H5r+nyXg88fcTZHsvgcPYqYEP38ds9yvcI4Iou3ybgtd14L/Jtk/U47rqw3ot8DK45XNl9XL31Z6Av+br7OwKY6v6NPwHcr2f5lgDfBZYOjfUp3xsY/GK1CfgQg2dejTyfb3siSWq2UE5nSZJGwBKRJDWzRCRJzSwRSVIzS0SS1MwSkSQ1s0QkSc3+DwlMP+/hPKDCAAAAAElFTkSuQmCC",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"#display classes in bar graph\r\n",
"df.cuisine.value_counts().plot.barh()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"thai df: (289, 385)\n",
"japanese df: (320, 385)\n",
"chinese df: (442, 385)\n",
"indian df: (598, 385)\n",
"korean df: (799, 385)\n"
]
}
],
"source": [
"# ingredient counts by class\r\n",
"# build one dataframe per cuisine, each holding only that cuisine's recipes\r\n",
"\r\n",
"thai_df = df.loc[df.cuisine == \"thai\"]\r\n",
"japanese_df = df.loc[df.cuisine == \"japanese\"]\r\n",
"chinese_df = df.loc[df.cuisine == \"chinese\"]\r\n",
"indian_df = df.loc[df.cuisine == \"indian\"]\r\n",
"korean_df = df.loc[df.cuisine == \"korean\"]\r\n",
"\r\n",
"# report the (rows, columns) shape of every subset\r\n",
"for cuisine_name, cuisine_df in (('thai', thai_df), ('japanese', japanese_df), ('chinese', chinese_df), ('indian', indian_df), ('korean', korean_df)):\r\n",
"    print(f'{cuisine_name} df: {cuisine_df.shape}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## What are the top ingredients by class?"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"def create_ingredient_df(df):\r\n",
" #transpose df, drop cuisine and unnamed rows, sum the row to get total for ingredient and add value header to new df\r\n",
" ingredient_df = df.T.drop(['cuisine','Unnamed: 0']).sum(axis=1).to_frame('value')\r\n",
" # drop ingredients that have a 0 sum\r\n",
" ingredient_df = ingredient_df[(ingredient_df.T != 0).any()]\r\n",
" # sort df\r\n",
" ingredient_df = ingredient_df.sort_values(by='value', ascending=False, inplace=False)\r\n",
" return ingredient_df\r\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<AxesSubplot:>"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"thai_ingredient_df = create_ingredient_df(thai_df)\r\n",
"thai_ingredient_df.head(10).plot.barh()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<AxesSubplot:>"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"japanese_ingredient_df = create_ingredient_df(japanese_df)\r\n",
"japanese_ingredient_df.head(10).plot.barh()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<AxesSubplot:>"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"chinese_ingredient_df = create_ingredient_df(chinese_df)\r\n",
"chinese_ingredient_df.head(10).plot.barh()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<AxesSubplot:>"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"indian_ingredient_df = create_ingredient_df(indian_df)\r\n",
"indian_ingredient_df.head(10).plot.barh()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<AxesSubplot:>"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"korean_ingredient_df = create_ingredient_df(korean_df)\r\n",
"korean_ingredient_df.head(10).plot.barh()"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# TODO add categorical labels to food items - calculated columns to improve accuracy"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>almond</th>\n",
" <th>angelica</th>\n",
" <th>anise</th>\n",
" <th>anise_seed</th>\n",
" <th>apple</th>\n",
" <th>apple_brandy</th>\n",
" <th>apricot</th>\n",
" <th>armagnac</th>\n",
" <th>artemisia</th>\n",
" <th>artichoke</th>\n",
" <th>...</th>\n",
" <th>whiskey</th>\n",
" <th>white_bread</th>\n",
" <th>white_wine</th>\n",
" <th>whole_grain_wheat_flour</th>\n",
" <th>wine</th>\n",
" <th>wood</th>\n",
" <th>yam</th>\n",
" <th>yeast</th>\n",
" <th>yogurt</th>\n",
" <th>zucchini</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 380 columns</p>\n",
"</div>"
],
"text/plain": [
" almond angelica anise anise_seed apple apple_brandy apricot \\\n",
"0 0 0 0 0 0 0 0 \n",
"1 1 0 0 0 0 0 0 \n",
"2 0 0 0 0 0 0 0 \n",
"3 0 0 0 0 0 0 0 \n",
"4 0 0 0 0 0 0 0 \n",
"\n",
" armagnac artemisia artichoke ... whiskey white_bread white_wine \\\n",
"0 0 0 0 ... 0 0 0 \n",
"1 0 0 0 ... 0 0 0 \n",
"2 0 0 0 ... 0 0 0 \n",
"3 0 0 0 ... 0 0 0 \n",
"4 0 0 0 ... 0 0 0 \n",
"\n",
" whole_grain_wheat_flour wine wood yam yeast yogurt zucchini \n",
"0 0 0 0 0 0 0 0 \n",
"1 0 0 0 0 0 0 0 \n",
"2 0 0 0 0 0 0 0 \n",
"3 0 0 0 0 0 0 0 \n",
"4 0 0 0 0 0 1 0 \n",
"\n",
"[5 rows x 380 columns]"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# set x and y features\r\n",
"# dropping common ingredients to improve accuracy\r\n",
"#feature_df= df.drop(['cuisine','Unnamed: 0'], axis=1)\r\n",
"feature_df= df.drop(['cuisine','Unnamed: 0','rice','garlic','ginger'], axis=1)\r\n",
"labels_df = df.cuisine #.unique()\r\n",
"feature_df.head()\r\n"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"# balance data with SMOTE oversampling to the highest class. Read more here: https://imbalanced-learn.org/dev/references/generated/imblearn.over_sampling.SMOTE.html\r\n",
"# random_state is fixed so the synthetic samples (and every score computed from them) are the same on each run\r\n",
"oversample = SMOTE(random_state=0)\r\n",
"transformed_feature_df, transformed_label_df = oversample.fit_resample(feature_df, labels_df)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>almond</th>\n",
" <th>angelica</th>\n",
" <th>anise</th>\n",
" <th>anise_seed</th>\n",
" <th>apple</th>\n",
" <th>apple_brandy</th>\n",
" <th>apricot</th>\n",
" <th>armagnac</th>\n",
" <th>artemisia</th>\n",
" <th>artichoke</th>\n",
" <th>...</th>\n",
" <th>whiskey</th>\n",
" <th>white_bread</th>\n",
" <th>white_wine</th>\n",
" <th>whole_grain_wheat_flour</th>\n",
" <th>wine</th>\n",
" <th>wood</th>\n",
" <th>yam</th>\n",
" <th>yeast</th>\n",
" <th>yogurt</th>\n",
" <th>zucchini</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 380 columns</p>\n",
"</div>"
],
"text/plain": [
" almond angelica anise anise_seed apple apple_brandy apricot \\\n",
"0 0 0 0 0 0 0 0 \n",
"1 1 0 0 0 0 0 0 \n",
"2 0 0 0 0 0 0 0 \n",
"3 0 0 0 0 0 0 0 \n",
"4 0 0 0 0 0 0 0 \n",
"\n",
" armagnac artemisia artichoke ... whiskey white_bread white_wine \\\n",
"0 0 0 0 ... 0 0 0 \n",
"1 0 0 0 ... 0 0 0 \n",
"2 0 0 0 ... 0 0 0 \n",
"3 0 0 0 ... 0 0 0 \n",
"4 0 0 0 ... 0 0 0 \n",
"\n",
" whole_grain_wheat_flour wine wood yam yeast yogurt zucchini \n",
"0 0 0 0 0 0 0 0 \n",
"1 0 0 0 0 0 0 0 \n",
"2 0 0 0 0 0 0 0 \n",
"3 0 0 0 0 0 0 0 \n",
"4 0 0 0 0 0 1 0 \n",
"\n",
"[5 rows x 380 columns]"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"transformed_feature_df.head()"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"new label count: chinese 799\n",
"korean 799\n",
"indian 799\n",
"thai 799\n",
"japanese 799\n",
"Name: cuisine, dtype: int64\n",
"old label count: korean 799\n",
"indian 598\n",
"chinese 442\n",
"japanese 320\n",
"thai 289\n",
"Name: cuisine, dtype: int64\n"
]
}
],
"source": [
"print(f'new label count: {transformed_label_df.value_counts()}')\r\n",
"print(f'old label count: {df.cuisine.value_counts()}')"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>cuisine</th>\n",
" <th>almond</th>\n",
" <th>angelica</th>\n",
" <th>anise</th>\n",
" <th>anise_seed</th>\n",
" <th>apple</th>\n",
" <th>apple_brandy</th>\n",
" <th>apricot</th>\n",
" <th>armagnac</th>\n",
" <th>artemisia</th>\n",
" <th>...</th>\n",
" <th>whiskey</th>\n",
" <th>white_bread</th>\n",
" <th>white_wine</th>\n",
" <th>whole_grain_wheat_flour</th>\n",
" <th>wine</th>\n",
" <th>wood</th>\n",
" <th>yam</th>\n",
" <th>yeast</th>\n",
" <th>yogurt</th>\n",
" <th>zucchini</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>indian</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>indian</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>indian</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>indian</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>indian</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 381 columns</p>\n",
"</div>"
],
"text/plain": [
" cuisine almond angelica anise anise_seed apple apple_brandy apricot \\\n",
"0 indian 0 0 0 0 0 0 0 \n",
"1 indian 1 0 0 0 0 0 0 \n",
"2 indian 0 0 0 0 0 0 0 \n",
"3 indian 0 0 0 0 0 0 0 \n",
"4 indian 0 0 0 0 0 0 0 \n",
"\n",
" armagnac artemisia ... whiskey white_bread white_wine \\\n",
"0 0 0 ... 0 0 0 \n",
"1 0 0 ... 0 0 0 \n",
"2 0 0 ... 0 0 0 \n",
"3 0 0 ... 0 0 0 \n",
"4 0 0 ... 0 0 0 \n",
"\n",
" whole_grain_wheat_flour wine wood yam yeast yogurt zucchini \n",
"0 0 0 0 0 0 0 0 \n",
"1 0 0 0 0 0 0 0 \n",
"2 0 0 0 0 0 0 0 \n",
"3 0 0 0 0 0 0 0 \n",
"4 0 0 0 0 0 1 0 \n",
"\n",
"[5 rows x 381 columns]"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# recombine the resampled labels and features into one dataframe, ready to be exported to csv for classification\r\n",
"transformed_df = pd.concat([transformed_label_df,transformed_feature_df],axis=1, join='outer')\r\n",
"transformed_df.head()"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
"transformed_feature_df.to_csv(\".data/features_dataset.csv\")\r\n",
"transformed_label_df.to_csv(\".data/labels_dataset.csv\")\r\n",
"transformed_df.to_csv(\".data/processed.csv\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Build classification model"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.linear_model import LogisticRegression\r\n",
"from sklearn.model_selection import train_test_split, cross_val_score\r\n",
"from sklearn.metrics import accuracy_score,precision_score,confusion_matrix,classification_report, precision_recall_curve\r\n",
"from sklearn.svm import SVC"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"# hold out 30% of the balanced data for testing; random_state is fixed so the split is the same on every run\r\n",
"X_train, X_test, y_train, y_test = train_test_split(transformed_feature_df, transformed_label_df, test_size=0.3, random_state=0)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy is 0.7973311092577148\n"
]
}
],
"source": [
"lr = LogisticRegression(multi_class='ovr',solver='lbfgs')\r\n",
"model = lr.fit(X_train, y_train)\r\n",
"\r\n",
"accuracy = model.score(X_test, y_test)\r\n",
"print (\"Accuracy is {}\".format(accuracy))"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ingredients: Index(['bean', 'carrot', 'cayenne', 'pea', 'sake', 'scallion', 'sesame_oil',\n",
" 'shrimp', 'starch', 'vegetable_oil', 'vinegar'],\n",
" dtype='object')\n",
"cusine: chinese\n"
]
}
],
"source": [
"# test an item\r\n",
"print(f'ingredients: {X_test.iloc[20][X_test.iloc[20]!=0].keys()}')\r\n",
"print(f'cusine: {y_test.iloc[20]}')"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>0</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>chinese</th>\n",
" <td>0.848156</td>\n",
" </tr>\n",
" <tr>\n",
" <th>japanese</th>\n",
" <td>0.110072</td>\n",
" </tr>\n",
" <tr>\n",
" <th>korean</th>\n",
" <td>0.033688</td>\n",
" </tr>\n",
" <tr>\n",
" <th>thai</th>\n",
" <td>0.005013</td>\n",
" </tr>\n",
" <tr>\n",
" <th>indian</th>\n",
" <td>0.003072</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" 0\n",
"chinese 0.848156\n",
"japanese 0.110072\n",
"korean 0.033688\n",
"thai 0.005013\n",
"indian 0.003072"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# select row 20 as a one-row DataFrame: already 2d and still carrying the column names the model was fitted with\r\n",
"test = X_test.iloc[[20]]\r\n",
"# predict with score\r\n",
"proba = model.predict_proba(test)\r\n",
"classes = model.classes_\r\n",
"# create df with classes and scores\r\n",
"resultdf = pd.DataFrame(data=proba, columns=classes)\r\n",
"\r\n",
"# create df to show results: one row per class, highest probability first\r\n",
"topPrediction = resultdf.T.sort_values(by=[0], ascending = [False])\r\n",
"topPrediction.head()"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"y_pred = model.predict(X_test)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" chinese 0.78 0.75 0.76 243\n",
" indian 0.91 0.92 0.91 233\n",
" japanese 0.67 0.77 0.71 244\n",
" korean 0.84 0.78 0.81 241\n",
" thai 0.82 0.77 0.80 238\n",
"\n",
" accuracy 0.80 1199\n",
" macro avg 0.80 0.80 0.80 1199\n",
"weighted avg 0.80 0.80 0.80 1199\n",
"\n"
]
}
],
"source": [
"print(classification_report(y_test,y_pred))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Try different classifiers"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"\r\n",
"C = 10\r\n",
"# Create different classifiers.\r\n",
"classifiers = {\r\n",
" 'L1 logistic': LogisticRegression(C=C, penalty='l1',\r\n",
" solver='saga',\r\n",
" multi_class='multinomial',\r\n",
" max_iter=10000),\r\n",
" 'L2 logistic (Multinomial)': LogisticRegression(C=C, penalty='l2',\r\n",
" solver='saga',\r\n",
" multi_class='multinomial',\r\n",
" max_iter=10000),\r\n",
" 'L2 logistic (OvR)': LogisticRegression(C=C, penalty='l2',\r\n",
" solver='saga',\r\n",
" multi_class='ovr',\r\n",
" max_iter=10000),\r\n",
" 'Linear SVC': SVC(kernel='linear', C=C, probability=True,\r\n",
" random_state=0)\r\n",
"}\r\n"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy (train) for L1 logistic: 80.7% \n",
"Accuracy (train) for L2 logistic (Multinomial): 80.9% \n",
"Accuracy (train) for L2 logistic (OvR): 80.7% \n",
"Accuracy (train) for Linear SVC: 79.9% \n"
]
}
],
"source": [
"# fit each candidate classifier on the training split and report its accuracy on the held-out test split\r\n",
"for name, classifier in classifiers.items():\r\n",
"    classifier.fit(X_train, y_train)\r\n",
"\r\n",
"    y_pred = classifier.predict(X_test)\r\n",
"    accuracy = accuracy_score(y_test, y_pred)\r\n",
"    # the score is computed from X_test / y_test, so it is labelled as test accuracy\r\n",
"    print(\"Accuracy (test) for %s: %0.1f%% \" % (name, accuracy * 100))\r\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"interpreter": {
"hash": "dd61f40108e2a19f4ef0d3ebbc6b6eea57ab3c4bc13b15fe6f390d3d86442534"
},
"kernelspec": {
"display_name": "Python 3.8.5 64-bit ('onnxwine': conda)",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}