You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ML-For-Beginners/2-Regression/1-Tools/solution/notebook.ipynb

190 lines
44 KiB

4 years ago
{
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3-final"
},
"orig_nbformat": 2,
"kernelspec": {
"name": "python3",
"display_name": "Python 3",
"language": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2,
"cells": [
{
"source": [
"## Linear Regression Solution"
],
"cell_type": "markdown",
"metadata": {}
},
{
"source": [
"Import the needed libraries"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from sklearn import datasets, linear_model, model_selection\n"
]
},
{
"source": [
"Load the diabetes dataset, split into the feature matrix `X` and the target values `y`"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"(442, 10)\n[ 0.03807591 0.05068012 0.06169621 0.02187235 -0.0442235 -0.03482076\n -0.04340085 -0.00259226 0.01990842 -0.01764613]\n"
]
}
],
"source": [
"# Feature matrix X (one row per patient) and regression target y.\n",
"X, y = datasets.load_diabetes(return_X_y=True)\n",
"print(X.shape, X[0], sep='\\n')"
]
},
{
"source": [
"Select just one feature (column 2, body mass index) to use as the model input for this exercise"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Keep only column 2 (body mass index) as a 2-D (n_samples, 1) array, the shape\n",
"# scikit-learn estimators expect for X. Slice from a fresh copy of the data rather\n",
"# than overwriting the X above in place: slicing an already-reduced X a second time\n",
"# (e.g. when re-running this cell) would raise IndexError.\n",
"X = datasets.load_diabetes().data[:, np.newaxis, 2]\n"
]
},
{
"source": [
"Split both `X` and `y` into training and test sets"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# A fixed random_state makes the split, and therefore every result below, reproducible.\n",
"X_train, X_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.33, random_state=42)\n"
]
},
{
"source": [
"Select the model and fit it to the training data"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"LinearRegression()"
]
},
"metadata": {},
"execution_count": 6
}
],
"source": [
"# fit() returns the estimator itself, so construction and fitting can be chained.\n",
"model = linear_model.LinearRegression().fit(X_train, y_train)\n",
"model"
]
},
{
"source": [
"Use the fitted model to predict target values for the test data"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Predicted target values for the held-out test features.\n",
"y_pred = model.predict(X=X_test)\n"
]
},
{
"source": [
"Display the results in a plot"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 432x288 with 1 Axes>",
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<!-- Created with matplotlib (https://matplotlib.org/) -->\n<svg height=\"248.518125pt\" version=\"1.1\" viewBox=\"0 0 375.2875 248.518125\" width=\"375.2875pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n <defs>\n <style type=\"text/css\">\n*{stroke-linecap:butt;stroke-linejoin:round;}\n </style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 248.518125 \nL 375.2875 248.518125 \nL 375.2875 0 \nL 0 0 \nz\n\" style=\"fill:none;\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 33.2875 224.64 \nL 368.0875 224.64 \nL 368.0875 7.2 \nL 33.2875 7.2 \nz\n\" style=\"fill:#ffffff;\"/>\n </g>\n <g id=\"PathCollection_1\">\n <defs>\n <path d=\"M 0 3 \nC 0.795609 3 1.55874 2.683901 2.12132 2.12132 \nC 2.683901 1.55874 3 0.795609 3 0 \nC 3 -0.795609 2.683901 -1.55874 2.12132 -2.12132 \nC 1.55874 -2.683901 0.795609 -3 0 -3 \nC -0.795609 -3 -1.55874 -2.683901 -2.12132 -2.12132 \nC -2.683901 -1.55874 -3 -0.795609 -3 0 \nC -3 0.795609 -2.683901 1.55874 -2.12132 2.12132 \nC -1.55874 2.683901 -0.795609 3 0 3 \nz\n\" id=\"m87062d904c\" style=\"stroke:#000000;\"/>\n </defs>\n <g clip-path=\"url(#pe67797de08)\">\n <use style=\"stroke:#000000;\" x=\"165.453286\" xlink:href=\"#m87062d904c\" y=\"93.400322\"/>\n <use style=\"stroke:#000000;\" x=\"148.960675\" xlink:href=\"#m87062d904c\" y=\"167.214822\"/>\n <use style=\"stroke:#000000;\" x=\"348.371333\" xlink:href=\"#m87062d904c\" y=\"17.083636\"/>\n <use style=\"stroke:#000000;\" x=\"157.956645\" xlink:href=\"#m87062d904c\" y=\"111.541174\"/>\n <use style=\"stroke:#000000;\" x=\"139.964706\" xlink:href=\"#m87062d904c\" y=\"177.849114\"/>\n <use style=\"stroke:#000000;\" x=\"138.465377\" xlink:href=\"#m87062d904c\" y=\"184.10458\"/>\n <use style=\"stroke:#000000;\" 
x=\"81.490903\" xlink:href=\"#m87062d904c\" y=\"182.22794\"/>\n <use style=\"stroke:#000000;\" x=\"277.902905\" xlink:href=\"#m87062d904c\" y=\"58.369712\"/>\n <use style=\"stroke:#000000;\" x=\"205.935149\" xlink:href=\"#m87062d904c\" y=\"85.268216\"/>\n <use style=\"stroke:#000000;\" x=\"120.473438\" xlink:href=\"#m87062d904c\" y=\"205.373165\"/>\n <use style=\"stroke:#000000;\" x=\"171.450599\" xlink:href=\"#m87062d904c\" y=\"114.668907\"/>\n <use style=\"stroke:#000000;\" x=\"202.936492\" xlink:href=\"#m87062d904c\" y=\"54.616433\"/>\n <use style=\"stroke:#000000;\" x=\"171.450599\" xlink:href=\"#m87062d904c\" y=\"162.210449\"/>\n <use style=\"stroke:#000000;\" x=\"258.411638\" xlink:href=\"#m87062d904c\" y=\"78.387204\"/>\n <use style=\"stroke:#000000;\" x=\"193.940523\" xlink:href=\"#m87062d904c\" y=\"92.774776\"/>\n <use style=\"stroke:#000000;\" x=\"193.940523\" xlink:href=\"#m87062d904c\" y=\"141.567411\"/>\n <use style=\"stroke:#000000;\" x=\"183.445225\" xlink:href=\"#m87062d904c\" y=\"148.448423\"/>\n <use style=\"stroke:#000000;\" x=\"159.455973\" xlink:href=\"#m87062d904c\" y=\"169.717008\"/>\n <use style=\"stroke:#000000;\" x=\"238.920371\" xlink:href=\"#m87062d904c\" y=\"121.549919\"/>\n <use style=\"stroke:#000000;\" x=\"130.968736\" xlink:href=\"#m87062d904c\" y=\"195.989965\"/>\n <use style=\"stroke:#000000;\" x=\"133.967393\" xlink:href=\"#m87062d904c\" y=\"175.346928\"/>\n <use style=\"stroke:#000000;\" x=\"166.952614\" xlink:href=\"#m87062d904c\" y=\"157.831623\"/>\n <use style=\"stroke:#000000;\" x=\"163.953958\" xlink:href=\"#m87062d904c\" y=\"197.866605\"/>\n <use style=\"stroke:#000000;\" x=\"351.36999\" xlink:href=\"#m87062d904c\" y=\"37.726674\"/>\n <use style=\"stroke:#000000;\" x=\"141.464034\" xlink:href=\"#m87062d904c\" y=\"140.316318\"/>\n <use style=\"stroke:#000000;\" x=\"121.972766\" xlink:href=\"#m87062d904c\" y=\"159.708262\"/>\n <use style=\"stroke:#000000;\" x=\"177.447912\" xlink:href=\"#m870
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO2dfZBc1Xmnn1ejGeIx5kMtoYiP6cG7tjfISWysxVBOiOOEJahSsWNwyutxNgWsx6FCLWbDpiRPOZBkZ20QzrpcdgXGBVXgGXBRwSkLr2OiCFKp2IAjtpCAJBIfnhmwiVAkWINlsNCc/aPvSD0zfb+678e5t39P1a3pvn3vOeee6f7d977nPe8x5xxCCCHqxaqyGyCEECJ7JO5CCFFDJO5CCFFDJO5CCFFDJO5CCFFDVpfdAIC1a9e60dHRspshhBCV4tFHH/0359y6Tp/FiruZ/Qzw98AJwfF/6Zy73sxuAD4BHAgO/bRz7lvBOVuBK4GjwH9zzt0fVcfo6Ci7du1KeDlCCCEAzGwu7LMklvvrwAecc6+a2SDwD2b218Fn/9s5d/Oyys4BPgpsBE4H/tbM3u6cO9pd84UQQqQl1ufuWrwavB0MtqiZTx8Evuace905933gaeC8nlsqhBAiMYkGVM1swMweA14EdjjnHgk+utrM9pjZ7WZ2arDvDOC5ttOfD/YJIYQoiETi7pw76px7F3AmcJ6ZvRP4C+DfAe8CXgA+HxxunYpYvsPMxs1sl5ntOnDgQIdThBBCdEuqUEjn3MvA3wG/4ZzbH4j+AvAVjrtengfOajvtTOCHHcqacs5tcs5tWreu42CvEEKILokVdzNbZ2anBK/fBPw68C9mtqHtsN8Gnghebwc+amYnmNnZwNuA72XbbCGEqA4zMzOMjo6yatUqRkdHmZmZyb3OJNEyG4A7zGyA1s3gHufcN83sq2b2Lloul1ngkwDOuSfN7B7gn4A3gD9QpIwQol+ZmZlhfHycw4cPAzA3N8f4+DgAY2NjudVrPqT83bRpk1OcuxCijoyOjjI3tzIcvdlsMjs721PZZvaoc25Tp8+UfkAIIXJkfn4+1f6skLgLIUSOjIyMpNqfFRJ3IYTIkcnJSYaHh5fsGx4eZnJyMtd6Je5CCJEjY2NjTE1N0Ww2MTOazSZTU1O5DqaCBlSFEKKyaEBVCCH6DIm7EELUEIm7EELUEIm7EELUEIm7EELUEIm7EELUEIm7EELUEIm7EELUEIm7EELUEIm7EELUEIm7EELUEIm7EELUEIm7EH1AGWt4inKRuAtRcxbX8Jybm8M5d2wNTwl8ORR1o1XKXyFqTp5reIp0LF8sG1oLd3Sb3z0q5a/EXYias2rVKjr9zs2MhYWFElrUv2R9o1U+dyH6mLLW8BQrKXKxbIm7EBnh66Dl5s2bU+0X+VHkjTZW3M3sZ8zse2a228yeNLM/CfavMbMdZvZU8PfUtnO2mtnTZrbXzC7OvNVCeIbPg5bf+ta3Uu0X+VHoYtnOucgNMODE4PUg8AhwPnATsCXYvwW4MXh9DrAbOAE4G3gGGIiq4z3veY8Toso0m00HrNiazWbZTXNm1rFtZlZ20/qS6elp12w2nZm5ZrPppqenuy4L2OVCdDXWcg/KeDV4OxhsDvggcEew/w7gQ8HrDwJfc8697pz7PvA0cF7Ke44QlaJIX2pa5HP3i7GxMWZnZ1lYWGB2drarKJkkJPK5m9mAmT0GvAjscM49Aqx3zr0AEPw9LTj8DOC5ttOfD/YtL3PczHaZ2a4DBw70cg1ClI7PApqXK8DXMQbRIpG4O+eOOufeBZwJnGdm74w43DoV0aHMKefcJufcpnXr1iVrrRCeUqgvNSVjY2NMTU3RbDYxM5rNZtdx1Yv4PMYgWqSOczez64EfA58A3u+ce8HMNgB/55x7h5ltBXDOfTY4/n7gBufcQ2FlKs5d1IGZmRkmJiaYn5
9nZGSEycnJ3B65y0YTo/ygp0lMZrYOOOKce9nM3gT8DXAj8CvAQefc58xsC7DGOfdHZrYRuIuWn/10YCfwNufc0bA6JO5CVAtNjPKDKHFfneD8DcAdZjZAy41zj3Pum2b2EHCPmV0JzAMfAXDOPWlm9wD/BLwB/EGUsAshqseaNWs4ePDgiv0+jDGIFrHi7pzbA7y7w/6DwK+FnDMJlO9sFEJkzszMDK+88sqK/YODg16MMYgWmqEqhEjFxMQEP/3pT1fsP+mkk2o7xlBFJO5CiFSExe4fOnSo4JaIKCTuwhsUN10NfI7pF8eRuAsv8Dlu2vebTtHt8zmmX7QRlpegyE25ZYSvuVmmp6fd8PDwkjYNDw/3lA8kS8pqX5b5UUT3EJFbRot1CC/wNW7a98k6vrdP5IsW6xDe46sf1+eEYOB/+0R5SNyFF/jqx/X1prOIL+1b9PubGatXr8bMCh2fmJmZYe3atZgZZsbatWu9GxspnDB/TZGbfO7COT/9uD753Dv1jw/t69SGItsyPT3thoaGVtQ9ODjoxXcoT4jwuZcu7E7iLjzHh5tOlIiX3b6wwfDFLe9B8aj6yx6Qz5socdeAqhAVwOeB07DB8EXyHhSPqj+sbh8yeL7xBtx5J7z1rfD+93dXhgZUhag4Pg+cxvn38/b/R5Xf6bOy51Ts2QMnnwyDg3DllfCrvwpf/3r29UjchagAvgycdqLTYPgiRQyKT05OMjQ0tGJ/WCKziYkJDh8+vGTf4cOHmZiYyK2NCwvwx38MZvCLvwg/+tHSz3NxXIT5a4rc5HMXvlG2H7tTe8oeOI1isb8ANzAwcMzfXVT7pqenXaPRONY3jUYjtO4iFwzft8+5M89sjW6GbSed5NzCQnflowFVIZLjq5D6dsMpm277I+/Z0AsLzm3bFi3o4NxFFzl36FBvdUnchUiBr6kQxHF6uQHndfOen3du48Z4UZ+Z6amaJUjchUhBkY/tojt6vQFn+RR0663xgn7++c796792XUUoUeKuAVVRKYqYCenz4KVo0Wv00NjYGLOzsywsLDA7O5s6DPLFF+GCC1oDpJ/8ZPhxt9zSkveHHoL165d+lns2zzDVL3KT5S6SUNRMSF997uI4ZbnO7ror3ko/5xzn5uaiy8nqO4bcMqIOFDkTUoOXflPkDfill1qDn3Givm1b8qiXrG5OUeKuGaqiMpQ9E1L4Rd6zTO+7D37rt6KPOess2LkT3va2dGVnleJaM1RFT/iyElHZMyGFX/TqN+/Ej38Ml13W8qVHCftnPgNHj8L8fHphh4LGdcJM+sUNOAt4EPhn4EngmmD/DcAPgMeCbXPbOVuBp4G9wMVxdcgt4y8++Z/Lzj4o6ssDD8S7XU4+2bndu7OpzwufO7ABODd4/RZgH3BOIO7XdTj+HGA3cAJwNvAMMBBVh8TdX4oauErq4y57JqSoD6+95tzll8eL+jXXOHfkSPb1ZzGuEyXuqX3uZvYN4EvA+4BXnXM3L/t8a/BE8Nng/f3ADc65h8LKlM/dX4pY/m4xkVN7vo/h4WGmpqYKz9Qn6s8jj7TCGKOkz6wVvvje9xbXrm7IzOduZqPAu4FHgl1Xm9keM7vdzE4N9p0BPNd22vPBvuVljZvZLjPbdeDAgTTNEAVShG+wjEROvuHLuEZdeeMN+NSnWqJ9/vnhwn755fCTn7QSffku7LGEmfTLN+BE4FHgw8H79cAArRvEJHB7sP/LwMfbzrsNuDSqbLll/KUIn3u/zwj1aVyjbuze3UrMFed62bmz7JZ2B73OUDWzQeBeYMY59/XgprDfOXfUObcAfAU4Lzj8eVqDsIucCfww3S1H+MLY2BhTU1M0m03MjGazmbm7pN9nhJb55FLHJ4a49LqLXHYZvPJKS94/8IFi21gIYarvjlveBtwJfGHZ/g1tr68Fvha83sjSAdVn0YCqiKDfLdeynlzq1u9J0uuCc9u3l93S7KDHaJlfCv7xe2gLewS+Cjwe7N
++TOwnaEXJ7AUuiatD4i76eUZoWVPp65D9Mml63Ysvbs00rRtR4q4ZqkKUTFnRQkVEQuXFc8/BJZfAk09GHzczAx/
},
"metadata": {
"needs_background": "light"
}
}
],
"source": [
"fig, ax = plt.subplots()\n",
"ax.scatter(X_test, y_test, color='black', label='test data')\n",
"# The model is linear in a single feature, so its predictions lie on one straight line.\n",
"ax.plot(X_test, y_pred, color='blue', linewidth=3, label='model prediction')\n",
"ax.set(xlabel='BMI (scaled)', ylabel='disease progression', title='Linear regression on one feature')\n",
"ax.legend()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
]
}