{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Linear Regression for Diabetes dataset - Lesson 1\n",
    "\n",
    "Fit a one-feature linear regression (BMI vs. disease progression) on scikit-learn's diabetes dataset and plot the fitted line against held-out test data."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Import needed libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from sklearn import datasets, linear_model, model_selection\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Configuration: everything a reader might want to tune lives here"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "FEATURE_INDEX = 2   # column 2 of the diabetes feature matrix is BMI\n",
    "TEST_SIZE = 0.33    # fraction of samples held out for testing\n",
    "RANDOM_STATE = 42   # fixed seed so Restart & Run All reproduces the same split"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load the diabetes dataset, divided into `X_all` (the feature matrix) and `y` (the target)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(442, 10)\n",
      "[ 0.03807591 0.05068012 0.06169621 0.02187239 -0.0442235 -0.03482076\n",
      " -0.04340085 -0.00259226 0.01990749 -0.01764613]\n"
     ]
    }
   ],
   "source": [
    "X_all, y = datasets.load_diabetes(return_X_y=True)\n",
    "print(X_all.shape)\n",
    "print(X_all[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Select just one feature to use as the predictor for this exercise"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(442,)\n"
     ]
    }
   ],
   "source": [
    "# Selecting the 3rd feature; the result is still a 1D array here\n",
    "X_bmi = X_all[:, FEATURE_INDEX]\n",
    "print(X_bmi.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Reshape the feature into a 2D column, because scikit-learn estimators expect `X` with shape `(n_samples, n_features)`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(442, 1)\n",
      "[[ 0.06169621]\n",
      " [-0.05147406]\n",
      " [ 0.04445121]\n",
      " [-0.01159501]\n",
      " [-0.03638469]]\n"
     ]
    }
   ],
   "source": [
    "# Reshaping to get a 2D array; writing to a new name keeps this cell safe to re-run\n",
    "X = X_bmi.reshape(-1, 1)\n",
    "assert X.shape == (y.shape[0], 1), 'expected one feature column with one row per target'\n",
    "print(X.shape)\n",
    "print(X[:5])  # preview only: printing every row would bury the narrative"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Split the training and test data for both `X` and `y`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train, X_test, y_train, y_test = model_selection.train_test_split(\n",
    "    X, y, test_size=TEST_SIZE, random_state=RANDOM_STATE\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Select the model and fit it with the training data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LinearRegression()"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = linear_model.LinearRegression()\n",
    "model.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Use the fitted model to predict the target for the test data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_pred = model.predict(X_test)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Display the results in a plot: test points in black, the fitted regression line in blue"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots()\n",
    "ax.scatter(X_test, y_test, color='black', label='Test data')\n",
    "ax.plot(X_test, y_pred, color='blue', linewidth=3, label='Linear fit')\n",
    "ax.set(\n",
    "    title='Diabetes progression vs. BMI (test set)',\n",
    "    xlabel='BMI (mean-centered, scaled)',\n",
    "    ylabel='Disease progression after one year',\n",
    ")\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.1"
  },
  "metadata": {
   "interpreter": {
    "hash": "70b38d7a306a849643e446cd70466270a13445e5987dfa1344ef2b127438fa4d"
   }
  },
  "orig_nbformat": 2
 },
 "nbformat": 4,
 "nbformat_minor": 2
}