You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ML-For-Beginners/4-Classification/1-Introduction/solution/intro-classification.ipynb

563 lines
16 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Build Classification Model"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.linear_model import LogisticRegression\r\n",
"from sklearn.model_selection import train_test_split, cross_val_score\r\n",
"from sklearn.metrics import accuracy_score,precision_score,confusion_matrix,classification_report, precision_recall_curve\r\n",
"from sklearn.svm import SVC\r\n",
"import pandas as pd\r\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>almond</th>\n",
" <th>angelica</th>\n",
" <th>anise</th>\n",
" <th>anise_seed</th>\n",
" <th>apple</th>\n",
" <th>apple_brandy</th>\n",
" <th>apricot</th>\n",
" <th>armagnac</th>\n",
" <th>artemisia</th>\n",
" <th>artichoke</th>\n",
" <th>...</th>\n",
" <th>whiskey</th>\n",
" <th>white_bread</th>\n",
" <th>white_wine</th>\n",
" <th>whole_grain_wheat_flour</th>\n",
" <th>wine</th>\n",
" <th>wood</th>\n",
" <th>yam</th>\n",
" <th>yeast</th>\n",
" <th>yogurt</th>\n",
" <th>zucchini</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 380 columns</p>\n",
"</div>"
],
"text/plain": [
" almond angelica anise anise_seed apple apple_brandy apricot \\\n",
"0 0 0 0 0 0 0 0 \n",
"1 1 0 0 0 0 0 0 \n",
"2 0 0 0 0 0 0 0 \n",
"3 0 0 0 0 0 0 0 \n",
"4 0 0 0 0 0 0 0 \n",
"\n",
" armagnac artemisia artichoke ... whiskey white_bread white_wine \\\n",
"0 0 0 0 ... 0 0 0 \n",
"1 0 0 0 ... 0 0 0 \n",
"2 0 0 0 ... 0 0 0 \n",
"3 0 0 0 ... 0 0 0 \n",
"4 0 0 0 ... 0 0 0 \n",
"\n",
" whole_grain_wheat_flour wine wood yam yeast yogurt zucchini \n",
"0 0 0 0 0 0 0 0 \n",
"1 0 0 0 0 0 0 0 \n",
"2 0 0 0 0 0 0 0 \n",
"3 0 0 0 0 0 0 0 \n",
"4 0 0 0 0 0 1 0 \n",
"\n",
"[5 rows x 380 columns]"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"transformed_feature_df = pd.read_csv(\".data/features_dataset.csv\")\r\n",
"transformed_feature_df= transformed_feature_df.drop(['Unnamed: 0'], axis=1)\r\n",
"transformed_feature_df.head()"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>cuisine</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>indian</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>indian</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>indian</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>indian</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>indian</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" cuisine\n",
"0 indian\n",
"1 indian\n",
"2 indian\n",
"3 indian\n",
"4 indian"
]
},
"execution_count": 49,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"transformed_label_df = pd.read_csv(\".data/labels_dataset.csv\")\r\n",
"transformed_label_df= transformed_label_df.drop(['Unnamed: 0'], axis=1)\r\n",
"transformed_label_df.head()"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [],
"source": [
"X_train, X_test, y_train, y_test = train_test_split(transformed_feature_df, transformed_label_df, test_size=0.3)"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy is 0.8031693077564637\n"
]
}
],
"source": [
"lr = LogisticRegression(multi_class='ovr',solver='lbfgs')\r\n",
"model = lr.fit(X_train, np.ravel(y_train))\r\n",
"\r\n",
"accuracy = model.score(X_test, y_test)\r\n",
"print (\"Accuracy is {}\".format(accuracy))"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ingredients: Index(['corn'], dtype='object')\n",
"cusine: cuisine thai\n",
"Name: 3816, dtype: object\n"
]
}
],
"source": [
"# test an item\r\n",
"print(f'ingredients: {X_test.iloc[20][X_test.iloc[20]!=0].keys()}')\r\n",
"print(f'cusine: {y_test.iloc[20]}')"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>0</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>thai</th>\n",
" <td>0.475724</td>\n",
" </tr>\n",
" <tr>\n",
" <th>chinese</th>\n",
" <td>0.201912</td>\n",
" </tr>\n",
" <tr>\n",
" <th>japanese</th>\n",
" <td>0.152046</td>\n",
" </tr>\n",
" <tr>\n",
" <th>korean</th>\n",
" <td>0.110980</td>\n",
" </tr>\n",
" <tr>\n",
" <th>indian</th>\n",
" <td>0.059338</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" 0\n",
"thai 0.475724\n",
"chinese 0.201912\n",
"japanese 0.152046\n",
"korean 0.110980\n",
"indian 0.059338"
]
},
"execution_count": 53,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#rehsape to 2d array and transpose\r\n",
"test= X_test.iloc[20].values.reshape(-1, 1).T\r\n",
"# predict with score\r\n",
"proba = model.predict_proba(test)\r\n",
"classes = model.classes_\r\n",
"# create df with classes and scores\r\n",
"resultdf = pd.DataFrame(data=proba, columns=classes)\r\n",
"\r\n",
"# create df to show results\r\n",
"topPrediction = resultdf.T.sort_values(by=[0], ascending = [False])\r\n",
"topPrediction.head()"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" chinese 0.77 0.70 0.74 239\n",
" indian 0.88 0.88 0.88 240\n",
" japanese 0.76 0.79 0.77 227\n",
" korean 0.86 0.78 0.82 240\n",
" thai 0.75 0.86 0.80 253\n",
"\n",
" accuracy 0.80 1199\n",
" macro avg 0.81 0.80 0.80 1199\n",
"weighted avg 0.81 0.80 0.80 1199\n",
"\n"
]
}
],
"source": [
"y_pred = model.predict(X_test)\r\n",
"print(classification_report(y_test,y_pred))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Try different classifiers"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [],
"source": [
"\r\n",
"C = 10\r\n",
"# Create different classifiers.\r\n",
"classifiers = {\r\n",
" 'L1 logistic': LogisticRegression(C=C, penalty='l1',\r\n",
" solver='saga',\r\n",
" multi_class='multinomial',\r\n",
" max_iter=10000),\r\n",
" 'L2 logistic (Multinomial)': LogisticRegression(C=C, penalty='l2',\r\n",
" solver='saga',\r\n",
" multi_class='multinomial',\r\n",
" max_iter=10000),\r\n",
" 'L2 logistic (OvR)': LogisticRegression(C=C, penalty='l2',\r\n",
" solver='saga',\r\n",
" multi_class='ovr',\r\n",
" max_iter=10000),\r\n",
" 'Linear SVC': SVC(kernel='linear', C=C, probability=True,\r\n",
" random_state=0)\r\n",
"}\r\n"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy (train) for L1 logistic: 79.9% \n",
"Accuracy (train) for L2 logistic (Multinomial): 79.7% \n",
"Accuracy (train) for L2 logistic (OvR): 79.8% \n",
"Accuracy (train) for Linear SVC: 77.9% \n"
]
}
],
"source": [
"n_classifiers = len(classifiers)\r\n",
"\r\n",
"for index, (name, classifier) in enumerate(classifiers.items()):\r\n",
" classifier.fit(X_train, np.ravel(y_train))\r\n",
"\r\n",
" y_pred = classifier.predict(X_test)\r\n",
" accuracy = accuracy_score(y_test, y_pred)\r\n",
" print(\"Accuracy (train) for %s: %0.1f%% \" % (name, accuracy * 100))\r\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"interpreter": {
"hash": "dd61f40108e2a19f4ef0d3ebbc6b6eea57ab3c4bc13b15fe6f390d3d86442534"
},
"kernelspec": {
"display_name": "Python 3.8.5 64-bit ('onnxwine': conda)",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}