{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Build Classification Model" ] }, { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [], "source": [ "from sklearn.linear_model import LogisticRegression\r\n", "from sklearn.model_selection import train_test_split, cross_val_score\r\n", "from sklearn.metrics import accuracy_score,precision_score,confusion_matrix,classification_report, precision_recall_curve\r\n", "from sklearn.svm import SVC\r\n", "import pandas as pd\r\n", "import numpy as np" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
almondangelicaaniseanise_seedappleapple_brandyapricotarmagnacartemisiaartichoke...whiskeywhite_breadwhite_winewhole_grain_wheat_flourwinewoodyamyeastyogurtzucchini
00000000000...0000000000
11000000000...0000000000
20000000000...0000000000
30000000000...0000000000
40000000000...0000000010
\n", "

5 rows × 380 columns

\n", "
" ], "text/plain": [ " almond angelica anise anise_seed apple apple_brandy apricot \\\n", "0 0 0 0 0 0 0 0 \n", "1 1 0 0 0 0 0 0 \n", "2 0 0 0 0 0 0 0 \n", "3 0 0 0 0 0 0 0 \n", "4 0 0 0 0 0 0 0 \n", "\n", " armagnac artemisia artichoke ... whiskey white_bread white_wine \\\n", "0 0 0 0 ... 0 0 0 \n", "1 0 0 0 ... 0 0 0 \n", "2 0 0 0 ... 0 0 0 \n", "3 0 0 0 ... 0 0 0 \n", "4 0 0 0 ... 0 0 0 \n", "\n", " whole_grain_wheat_flour wine wood yam yeast yogurt zucchini \n", "0 0 0 0 0 0 0 0 \n", "1 0 0 0 0 0 0 0 \n", "2 0 0 0 0 0 0 0 \n", "3 0 0 0 0 0 0 0 \n", "4 0 0 0 0 0 1 0 \n", "\n", "[5 rows x 380 columns]" ] }, "execution_count": 48, "metadata": {}, "output_type": "execute_result" } ],
"source": [ "# ingredient feature matrix; drop the unnamed leading column (presumably the index saved by to_csv)\r\n", "transformed_feature_df = (\r\n", "    pd.read_csv(\".data/features_dataset.csv\")\r\n", "    .drop(columns=['Unnamed: 0'])\r\n", ")\r\n", "transformed_feature_df.head()" ] },
{ "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
cuisine
0indian
1indian
2indian
3indian
4indian
\n", "
" ], "text/plain": [ " cuisine\n", "0 indian\n", "1 indian\n", "2 indian\n", "3 indian\n", "4 indian" ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" } ],
"source": [ "# cuisine labels; drop the unnamed leading column (presumably the index saved by to_csv)\r\n", "transformed_label_df = (\r\n", "    pd.read_csv(\".data/labels_dataset.csv\")\r\n", "    .drop(columns=['Unnamed: 0'])\r\n", ")\r\n", "transformed_label_df.head()" ] },
{ "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [], "source": [ "# features and labels come from separate files; they must describe the same rows\r\n", "assert len(transformed_feature_df) == len(transformed_label_df), \"feature/label row counts differ\"\r\n", "\r\n", "# fixed seed -> reproducible split; stratify keeps the cuisine mix the same in train and test\r\n", "X_train, X_test, y_train, y_test = train_test_split(\r\n", "    transformed_feature_df,\r\n", "    transformed_label_df,\r\n", "    test_size=0.3,\r\n", "    random_state=0,\r\n", "    stratify=transformed_label_df['cuisine'],\r\n", ")" ] },
{ "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy is 0.8031693077564637\n" ] } ], "source": [ "# baseline: one-vs-rest logistic regression\r\n", "# NOTE: 'multi_class' is deprecated in scikit-learn >= 1.5; revisit when upgrading\r\n", "lr = LogisticRegression(multi_class='ovr', solver='lbfgs')\r\n", "model = lr.fit(X_train, np.ravel(y_train))\r\n", "\r\n", "accuracy = model.score(X_test, np.ravel(y_test))\r\n", "print(f\"Accuracy is {accuracy}\")" ] },
{ "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ingredients: Index(['corn'], dtype='object')\n", "cusine: cuisine thai\n", "Name: 3816, dtype: object\n" ] } ], "source": [ "# inspect one held-out item: its non-zero ingredient columns and its true cuisine\r\n", "sample_features = X_test.iloc[20]\r\n", "print(f'ingredients: {sample_features[sample_features != 0].index.tolist()}')\r\n", "print(f'cuisine: {y_test.iloc[20, 0]}')" ] },
{ "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
0
thai0.475724
chinese0.201912
japanese0.152046
korean0.110980
indian0.059338
\n", "
" ], "text/plain": [ " 0\n", "thai 0.475724\n", "chinese 0.201912\n", "japanese 0.152046\n", "korean 0.110980\n", "indian 0.059338" ] }, "execution_count": 53, "metadata": {}, "output_type": "execute_result" } ],
"source": [ "# select with a list index so the row stays a 1-row DataFrame that keeps its feature names\r\n", "sample = X_test.iloc[[20]]\r\n", "# class probabilities for that row, one column per entry of model.classes_\r\n", "proba = model.predict_proba(sample)\r\n", "result_df = pd.DataFrame(data=proba, columns=model.classes_)\r\n", "\r\n", "# transpose to one row per cuisine and rank by probability\r\n", "top_predictions = result_df.T.sort_values(by=0, ascending=False)\r\n", "top_predictions.head()" ] },
{ "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", " chinese 0.77 0.70 0.74 239\n", " indian 0.88 0.88 0.88 240\n", " japanese 0.76 0.79 0.77 227\n", " korean 0.86 0.78 0.82 240\n", " thai 0.75 0.86 0.80 253\n", "\n", " accuracy 0.80 1199\n", " macro avg 0.81 0.80 0.80 1199\n", "weighted avg 0.81 0.80 0.80 1199\n", "\n" ] } ], "source": [ "# per-class precision / recall / F1 of the baseline model on the test split\r\n", "y_pred = model.predict(X_test)\r\n", "print(classification_report(np.ravel(y_test), y_pred))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Try different classifiers\r\n", "\r\n", "Compare L1- and L2-regularised logistic regression with a linear-kernel SVC, all fitted on the same train/test split." ] },
{ "cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [], "source": [ "# inverse regularisation strength shared by every model below\r\n", "C = 10\r\n", "\r\n", "# NOTE: 'multi_class' is deprecated in scikit-learn >= 1.5; revisit when upgrading\r\n", "classifiers = {\r\n", "    'L1 logistic': LogisticRegression(C=C, penalty='l1', solver='saga',\r\n", "                                      multi_class='multinomial', max_iter=10000),\r\n", "    'L2 logistic (Multinomial)': LogisticRegression(C=C, penalty='l2', solver='saga',\r\n", "                                                    multi_class='multinomial', max_iter=10000),\r\n", "    'L2 logistic (OvR)': LogisticRegression(C=C, penalty='l2', solver='saga',\r\n", "                                            multi_class='ovr', max_iter=10000),\r\n", "    'Linear SVC': SVC(kernel='linear', C=C, probability=True, random_state=0),\r\n", "}" ] },
{ "cell_type": "code", "execution_count": 61, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy (test) for L1 logistic: 79.9%\n", "Accuracy (test) for L2 logistic (Multinomial): 79.7%\n", "Accuracy (test) for L2 logistic (OvR): 79.8%\n", "Accuracy (test) for Linear SVC: 77.9%\n" ] } ], "source": [ "# fit each model on the training split and report accuracy on the held-out test split\r\n", "for name, classifier in classifiers.items():\r\n", "    classifier.fit(X_train, np.ravel(y_train))\r\n", "    clf_pred = classifier.predict(X_test)\r\n", "    accuracy = accuracy_score(np.ravel(y_test), clf_pred)\r\n", "    print(f\"Accuracy (test) for {name}: {accuracy:.1%}\")" ] } ],
"metadata": { "interpreter": { "hash": "dd61f40108e2a19f4ef0d3ebbc6b6eea57ab3c4bc13b15fe6f390d3d86442534" }, "kernelspec": { "display_name": "Python 3.8.5 64-bit ('onnxwine': conda)", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 4 }