You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ML-For-Beginners/8-Reinforcement/2-Gym/solution/notebook.ipynb

520 lines
482 KiB

{
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
},
"orig_nbformat": 4,
"kernelspec": {
"name": "python3",
"display_name": "Python 3.7.0 64-bit ('3.7')"
},
"interpreter": {
"hash": "70b38d7a306a849643e446cd70466270a13445e5987dfa1344ef2b127438fa4d"
}
},
"nbformat": 4,
"nbformat_minor": 2,
"cells": [
{
"source": [
"## CartPole Skating\n",
"\n",
"> **Problem**: If Peter wants to escape from the wolf, he needs to be able to move faster than him. We will see how Peter can learn to skate, in particular, to keep balance, using Q-Learning.\n",
"\n",
"First, let's install the gym and import required libraries:"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
4 years ago
"Requirement already satisfied: gym in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (0.18.3)\n",
"Requirement already satisfied: Pillow<=8.2.0 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from gym) (7.0.0)\n",
"Requirement already satisfied: scipy in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from gym) (1.4.1)\n",
"Requirement already satisfied: numpy>=1.10.4 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from gym) (1.19.2)\n",
4 years ago
"Requirement already satisfied: cloudpickle<1.7.0,>=1.2.0 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from gym) (1.6.0)\n",
"Requirement already satisfied: pyglet<=1.5.15,>=1.4.0 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from gym) (1.5.15)\n",
"\u001b[33mWARNING: You are using pip version 20.2.3; however, version 21.1.2 is available.\n",
"You should consider upgrading via the '/Library/Frameworks/Python.framework/Versions/3.7/bin/python3.7 -m pip install --upgrade pip' command.\u001b[0m\n"
]
}
],
"source": [
"import sys\n",
"!{sys.executable} -m pip install gym \n",
"\n",
"import gym\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import random"
]
},
{
"source": [
"## Create a cartpole environment"
],
"cell_type": "markdown",
"metadata": {}
},
{
"source": [
"env = gym.make(\"CartPole-v1\")\n",
"print(env.action_space)\n",
"print(env.observation_space)\n",
"print(env.action_space.sample())"
],
"cell_type": "code",
"metadata": {},
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
4 years ago
"Discrete(2)\nBox(-3.4028234663852886e+38, 3.4028234663852886e+38, (4,), float32)\n0\n"
]
}
]
},
{
"source": [
"To see how the environment works, let's run a short simulation for 100 steps."
],
"cell_type": "markdown",
"metadata": {}
},
{
"source": [
"env.reset()\n",
"\n",
"for i in range(100):\n",
" env.render()\n",
" env.step(env.action_space.sample())\n",
"env.close()"
],
"cell_type": "code",
"metadata": {},
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/gym/logger.py:30: UserWarning: \u001b[33mWARN: You are calling 'step()' even though this environment has already returned done = True. You should always call 'reset()' once you receive 'done = True' -- any further steps are undefined behavior.\u001b[0m\n warnings.warn(colorize('%s: %s'%('WARN', msg % args), 'yellow'))\n"
]
}
]
},
{
"source": [
"During simulation, we need to get observations in order to decide how to act. In fact, `step` function returns us back current observations, reward function, and the `done` flag that indicates whether it makes sense to continue the simulation or not:"
],
"cell_type": "markdown",
"metadata": {}
},
{
"source": [
"env.reset()\n",
"\n",
"done = False\n",
"while not done:\n",
" env.render()\n",
" obs, rew, done, info = env.step(env.action_space.sample())\n",
" print(f\"{obs} -> {rew}\")\n",
"env.close()"
],
"cell_type": "code",
"metadata": {},
"execution_count": 4,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
4 years ago
"[ 0.03044442 -0.19543914 -0.04496216 0.28125618] -> 1.0\n",
"[ 0.02653564 -0.38989186 -0.03933704 0.55942606] -> 1.0\n",
"[ 0.0187378 -0.19424049 -0.02814852 0.25461393] -> 1.0\n",
"[ 0.01485299 -0.38894946 -0.02305624 0.53828712] -> 1.0\n",
"[ 0.007074 -0.19351108 -0.0122905 0.23842953] -> 1.0\n",
"[ 0.00320378 0.00178427 -0.00752191 -0.05810469] -> 1.0\n",
"[ 0.00323946 0.19701326 -0.008684 -0.35315131] -> 1.0\n",
"[ 0.00717973 0.00201587 -0.01574703 -0.06321931] -> 1.0\n",
"[ 0.00722005 0.19736001 -0.01701141 -0.36082863] -> 1.0\n",
"[ 0.01116725 0.39271958 -0.02422798 -0.65882671] -> 1.0\n",
"[ 0.01902164 0.19794307 -0.03740452 -0.37387001] -> 1.0\n",
"[ 0.0229805 0.39357584 -0.04488192 -0.67810827] -> 1.0\n",
"[ 0.03085202 0.58929164 -0.05844408 -0.98457719] -> 1.0\n",
"[ 0.04263785 0.78514572 -0.07813563 -1.2950295 ] -> 1.0\n",
"[ 0.05834076 0.98116859 -0.10403622 -1.61111521] -> 1.0\n",
"[ 0.07796413 0.78741784 -0.13625852 -1.35259196] -> 1.0\n",
"[ 0.09371249 0.98396202 -0.16331036 -1.68461179] -> 1.0\n",
"[ 0.11339173 0.79106371 -0.1970026 -1.44691436] -> 1.0\n",
"[ 0.12921301 0.59883361 -0.22594088 -1.22169133] -> 1.0\n"
]
}
]
},
{
"source": [
"We can get min and max value of those numbers:"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38]\n[4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38]\n"
]
}
],
"source": [
"print(env.observation_space.low)\n",
"print(env.observation_space.high)"
]
},
{
"source": [
"## State Discretization"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def discretize(x):\n",
" return tuple((x/np.array([0.25, 0.25, 0.01, 0.1])).astype(np.int))"
]
},
{
"source": [
"Let's also explore other discretization method using bins:"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Sample bins for interval (-5,5) with 10 bins\n [-5. -4. -3. -2. -1. 0. 1. 2. 3. 4. 5.]\n"
]
}
],
"source": [
"def create_bins(i,num):\n",
" return np.arange(num+1)*(i[1]-i[0])/num+i[0]\n",
"\n",
"print(\"Sample bins for interval (-5,5) with 10 bins\\n\",create_bins((-5,5),10))\n",
"\n",
"ints = [(-5,5),(-2,2),(-0.5,0.5),(-2,2)] # intervals of values for each parameter\n",
"nbins = [20,20,10,10] # number of bins for each parameter\n",
"bins = [create_bins(ints[i],nbins[i]) for i in range(4)]\n",
"\n",
"def discretize_bins(x):\n",
" return tuple(np.digitize(x[i],bins[i]) for i in range(4))"
]
},
{
"source": [
"Let's now run a short simulation and observe those discrete environment values."
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
4 years ago
"(0, 0, -1, -3)\n(0, 0, -2, 0)\n(0, 0, -2, -3)\n(0, 1, -3, -6)\n(0, 2, -4, -9)\n(0, 3, -6, -12)\n(0, 2, -8, -9)\n(0, 3, -10, -13)\n(0, 4, -13, -16)\n(0, 4, -16, -19)\n(0, 4, -20, -17)\n(0, 4, -24, -20)\n"
]
}
],
"source": [
"env.reset()\n",
"\n",
"done = False\n",
"while not done:\n",
" #env.render()\n",
" obs, rew, done, info = env.step(env.action_space.sample())\n",
" #print(discretize_bins(obs))\n",
" print(discretize(obs))\n",
"env.close()"
]
},
{
"source": [
"## Q-Table Structure"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"Q = {}\n",
"actions = (0,1)\n",
"\n",
"def qvalues(state):\n",
" return [Q.get((state,a),0) for a in actions]"
]
},
{
"source": [
"## Let's Start Q-Learning!"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# hyperparameters\n",
"alpha = 0.3\n",
"gamma = 0.9\n",
"epsilon = 0.90"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"0: 108.0, alpha=0.3, epsilon=0.9\n"
]
}
],
"source": [
"def probs(v,eps=1e-4):\n",
" v = v-v.min()+eps\n",
" v = v/v.sum()\n",
" return v\n",
"\n",
"Qmax = 0\n",
"cum_rewards = []\n",
"rewards = []\n",
"for epoch in range(100000):\n",
" obs = env.reset()\n",
" done = False\n",
" cum_reward=0\n",
" # == do the simulation ==\n",
" while not done:\n",
" s = discretize(obs)\n",
" if random.random()<epsilon:\n",
" # exploitation - chose the action according to Q-Table probabilities\n",
" v = probs(np.array(qvalues(s)))\n",
" a = random.choices(actions,weights=v)[0]\n",
" else:\n",
" # exploration - randomly chose the action\n",
" a = np.random.randint(env.action_space.n)\n",
"\n",
" obs, rew, done, info = env.step(a)\n",
" cum_reward+=rew\n",
" ns = discretize(obs)\n",
" Q[(s,a)] = (1 - alpha) * Q.get((s,a),0) + alpha * (rew + gamma * max(qvalues(ns)))\n",
" cum_rewards.append(cum_reward)\n",
" rewards.append(cum_reward)\n",
" # == Periodically print results and calculate average reward ==\n",
" if epoch%5000==0:\n",
" print(f\"{epoch}: {np.average(cum_rewards)}, alpha={alpha}, epsilon={epsilon}\")\n",
" if np.average(cum_rewards) > Qmax:\n",
" Qmax = np.average(cum_rewards)\n",
" Qbest = Q\n",
" cum_rewards=[]"
]
},
{
"source": [
"## Plotting Training Progress"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x2814bb79788>]"
]
},
"metadata": {},
"execution_count": 20
},
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 432x288 with 1 Axes>",
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\r\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\r\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\r\n<!-- Created with matplotlib (https://matplotlib.org/) -->\r\n<svg height=\"248.518125pt\" version=\"1.1\" viewBox=\"0 0 379.159862 248.518125\" width=\"379.159862pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\r\n <defs>\r\n <style type=\"text/css\">\r\n*{stroke-linecap:butt;stroke-linejoin:round;white-space:pre;}\r\n </style>\r\n </defs>\r\n <g id=\"figure_1\">\r\n <g id=\"patch_1\">\r\n <path d=\"M 0 248.518125 \r\nL 379.159862 248.518125 \r\nL 379.159862 0 \r\nL 0 0 \r\nz\r\n\" style=\"fill:none;\"/>\r\n </g>\r\n <g id=\"axes_1\">\r\n <g id=\"patch_2\">\r\n <path d=\"M 33.2875 224.64 \r\nL 368.0875 224.64 \r\nL 368.0875 7.2 \r\nL 33.2875 7.2 \r\nz\r\n\" style=\"fill:#ffffff;\"/>\r\n </g>\r\n <g id=\"matplotlib.axis_1\">\r\n <g id=\"xtick_1\">\r\n <g id=\"line2d_1\">\r\n <defs>\r\n <path d=\"M 0 0 \r\nL 0 3.5 \r\n\" id=\"ma0dcf872fc\" style=\"stroke:#000000;stroke-width:0.8;\"/>\r\n </defs>\r\n <g>\r\n <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"48.505682\" xlink:href=\"#ma0dcf872fc\" y=\"224.64\"/>\r\n </g>\r\n </g>\r\n <g id=\"text_1\">\r\n <!-- 0 -->\r\n <defs>\r\n <path d=\"M 31.78125 66.40625 \r\nQ 24.171875 66.40625 20.328125 58.90625 \r\nQ 16.5 51.421875 16.5 36.375 \r\nQ 16.5 21.390625 20.328125 13.890625 \r\nQ 24.171875 6.390625 31.78125 6.390625 \r\nQ 39.453125 6.390625 43.28125 13.890625 \r\nQ 47.125 21.390625 47.125 36.375 \r\nQ 47.125 51.421875 43.28125 58.90625 \r\nQ 39.453125 66.40625 31.78125 66.40625 \r\nz\r\nM 31.78125 74.21875 \r\nQ 44.046875 74.21875 50.515625 64.515625 \r\nQ 56.984375 54.828125 56.984375 36.375 \r\nQ 56.984375 17.96875 50.515625 8.265625 \r\nQ 44.046875 -1.421875 31.78125 -1.421875 \r\nQ 19.53125 -1.421875 13.0625 8.265625 \r\nQ 6.59375 17.96875 6.59375 36.375 \r\nQ 6.59375 54.828125 13.0625 64.515625 \r\nQ 19.53125 74.21875 31.78125 74.21875 \r\nz\r\n\" id=\"DejaVuSans-48\"/>\r\n </defs>\r\n <g transform=\"translate(45.324432 239.238437)scale(0.1 -0.1)\">\r\n <use xlink:href=\"#DejaVuSans-48\"/>\r\n </g>\r\n </g>\r\n </g>\r\n <g id=\"xtick_2\">\r\n <g id=\"line2d_2\">\r\n <g>\r\n <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"109.379018\" xlink:href=\"#ma0dcf872fc\" y=\"224.64\"/>\r\n </g>\r\n </g>\r\n <g id=\"text_2\">\r\n <!-- 20000 -->\r\n <defs>\r\n <path d=\"M 19.1875 8.296875 \r\nL 53.609375 8.296875 \r\nL 53.609375 0 \r\nL 7.328125 0 \r\nL 7.328125 8.296875 \r\nQ 12.9375 14.109375 22.625 23.890625 \r\nQ 32.328125 33.6875 34.8125 36.53125 \r\nQ 39.546875 41.84375 41.421875 45.53125 \r\nQ 43.3125 49.21875 43.3125 52.78125 \r\nQ 43.3125 58.59375 39.234375 62.25 \r\nQ 35.15625 65.921875 28.609375 65.921875 \r\nQ 23.96875 65.921875 18.8125 64.3125 \r\nQ 13.671875 62.703125 7.8125 59.421875 \r\nL 7.8125 69.390625 \r\nQ 13.765625 71.78125 18.9375 73 \r\nQ 24.125 74.21875 28.421875 74.21875 \r\nQ 39.75 74.21875 46.484375 68.546875 \r\nQ 53.21875 62.890625 53.21875 53.421875 \r\nQ 53.21875 48.921875 51.53125 44.890625 \r\nQ 49.859375 40.875 45.40625 35.40625 \r\nQ 44.1875 33.984375 37.640625 27.21875 \r\nQ 31.109375 20.453125 19.1875 8.296875 \r\nz\r\n\" id=\"DejaVuSans-50\"/>\r\n </defs>\r\n <g transform=\"translate(93.472768 239.238437)scale(0.1 -0.1)\">\r\n <use xlink:href=\"#DejaVuSans-50\"/>\r\n <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\r\n <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\r\n <use x=\"190.869141\" xlink:href=\"#DejaVuSans-48\"/>\r\n <use x=\"254.492188\" xlink:href=\"#DejaVuSans-48\"/>\r\n </g>\r\n </g>\r\n </g>\r\n <g id=\"xtick_3\">\r\n <g id=\"line2d_3\">\r\n <g>\r\n <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"170.252354\" xlink:href=\"#ma0dcf872fc\" y=\"224.64\"/>\
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+17YcXAAAgAElEQVR4nO3deXxU9b3/8dcnCSTsa8CQgAEJIKIIBGSXTUWiYqu0Lq2o3MvV6nWhVlGrtbdasddq9dqfy9W2tr22WpdKXYu4W0VBRVBAQFACCEF2kCXk+/tjvkkm+yTMZCZn3s/HI4+c853vzPmenMl7vud7zpxjzjlERCS4UuLdABERiS0FvYhIwCnoRUQCTkEvIhJwCnoRkYBLi3cDADp37uxyc3Pj3QwRkSZl0aJFW5xzmXXVS4igz83NZeHChfFuhohIk2JmX0ZST0M3IiIBp6AXEQk4Bb2ISMAp6EVEAk5BLyIScBEFvZmtNbMlZvaxmS30ZR3NbJ6ZrfS/O/hyM7N7zWyVmX1iZoNjuQIiIlK7+vToxzvnjnfO5fv52cB851weMN/PA5wK5PmfmcD90WqsiIjU3+GcRz8VGOenHwVeB67z5X90oesfv2dm7c0syzm38XAa2pjWbd3Lj/+2mG7tMvjpaf3p3DqdbXsO8K/V31BwXBbOOf7+8XpO6NmJj77azsSjuzD5N2/yo3G9eXbxev484wTumb+S/lltOaJdBobxwBureePzIpbccjJmBsCLSzby+ooiphyXxYl9Kn7nYUnhDv7+8Xq6tk1n5tijyso/XredtBRjQHY7nHM8uaiQCf26cOvzy+jYqjnLNu7kX6u/4d5zB5HdvgWri3bz3UHZpKWm8PKnX3NMt7bc/sJybig4mpPueoNHpg+leZrxwdptzHlxOU9dOoIN2/dx+sBuPLbgK254Zglt0tO4cFQu//PqKm46rT+/eO4zlv9iMjf9fSl/W1RIn66tGdazI51apXPP/JX87PT+9OzcivfXbGXJ+h28tXJLg7fFJScexQNvrKZ3l9as2ry7rNwMUs0oLonfZbbbZKSxa19xxPVHHtWJpet3sHNfMcdmt2PJ+h306tyKL7bsAeDu7w/k6scXx6q53HJ6f7buPcgTH6zj6537YraccENzO/DB2m38YHgP1m/7luZpKbz86aZ6vcb5J/TgolG5TLrrzXov/4qJeaSacfcrn9da79YzB/DlN3v437fW8B9je/Hgm19UqfPdwdk8/eH6Gl9jcI/2PHXpSNZs2cOEX78RUftyOrTglVknktEsNaL6DWWRXI/ezNYA2wAHPOice8jMtjvn2ofV2eac62BmzwFznHNv+/L5wHXOuYWVXnMmoR4/PXr0GPLllxGd998ocmc/Xzad06EFb183ge8/+C4L1mzlvesnsm7bXqY98G5ZnYn9ujB/+eay+cqhFO73Fw1lfN8u7Nx3kONu+WdZ+do5BTW2Ifyx0vK1cwp4aelGLvnzh3Wuz7WT+3LxqJ70u+mlOuuWKv0HrUnz1BQOHCqJ+PVEksFvzxvMZY/V/T8Z7ofDj+QXZw5o0PLMbFHYKEuNIu3Rj3LObTCzLsA8M1te27KrKavyaeKcewh4CCA/Pz9h735SuO1bANZvD/0+eKiE3fsr9uLWbdtbYb6mkAfKeoDFhw5/lXd+G1lvcuvuA5TU8wYztYU8oJAXqcaufQfr/ZyiXftj0JKKIhqjd85t8L83A88Aw4BNZpYF4H+XdmkLge5hT88BNkSrwUGVO/t5LvnTong3Q0QCqM6gN7NWZtamdBo4GVgKzAWm+2rTgWf99FzgAn/2zXBgR1Man4+nlz79Ot5NEJEAimTopivwjD+AmAY85px7ycw+AJ4wsxnAV8A0X/8FYAqwCtgLXBT1VouIBETR7tgP3dQZ9M65L4CB1ZR/A0ysptwBl0WldQG0/+Ah5i7ewKijOjXaMuN4YoqI1OHLb/bWXekwJcRlipPJr15eQdGu/dxzzvGNtkyr7vC4iETdI2+viXcTqqVLIHhffrOHRV/WfqZJNJQeYd/5bf2PzotIYltZyxl38aQevXfif78OVD2fXUQklrY0whi9evRRYNV+dSBxOJf4bRSR2FHQi4gEnIZuoqAhBzsb80SY372zhjYZ2tQiyUo9+giMv/P1CvNNcRDknvkr490EEYkTBX0E1virC4qINEUK+nqq57XBouqv738Vv4WLSJOloI9QInzpaPbTS+LdBBFpghT0IiIBp6AH9h6o+7ru0R6yiecQkIgkFwU9sOCLrXXWKb0BSayHcL49cCi2CxCRpKOgr2TfwbqDNpad8UPq6otIlCnoK9l/sHFukVfbnkFjXPtCRJKHgr6eHv9gXUzv8bh19wHyb30lZq8vIslHQV9P9722it+/szZmr79174EqZbmzn+eQ7h4iIg2koAdcPUfdDxQf/gHTmobiP9uws9ryg4caZ0hJRIJHQd8Aq4sO/5IIroakX7l512G/tohIOAV9gvmwhrtcrdyUmHeuEZHEp6CvpL7DONG2uHBHteWn3/d2I7dERIJCQR8nlggXzxGRpKCgFxEJOAU98bnuTE0HY0VEok1BHwUahhGRRKagr8Sa5I0CRURqpqCvpCFn3SzbWP2XnEREEoGCnviM0Wu4R0Qai4JeRCTgFPTE9vryNS5TZ92ISCNR0MfJLf/4LN5NEJEkEXHQm1mqmX1kZs/5+Z5mtsDMVprZ42bW3Jen+/lV/vHc2DQ9etS7FpEgq0+P/kpgWdj8HcDdzrk8YBsww5fPALY553oDd/t6Ce3xD9bFuwn11uenL/L655vj3QwRaQIiCnozywEKgIf9vAETgCd9lUeBM/30VD+Pf3yiJfgpJis3l18Zsql07g8Ul/DCkq/j3QwRaQIi7dH/BrgWKL37RSdgu3Ou2M8XAtl+OhtYB+Af3+HrV2BmM81soZktLCoqamDzRUSkLnUGvZmdBmx2zi0KL66mqovgsfIC5x5yzuU75/IzMzMjamysxPvSxCIisZQWQZ1RwBlmNgXIANoS6uG3N7M032vPATb4+oVAd6DQzNKAdsDWqLc8SkpKHCVhd+lL7EEmEZH6q7NH75y73jmX45zLBc4BXnXOnQ+8Bpztq00HnvXTc/08/vFXXQKf1nLqPW+xfvu38W6GiEjMHM559NcBs8xsFaEx+Ed8+SNAJ18+C5h9eE2MrRWbdI9WEQm2SIZuyjjnXgde99NfAMOqqbMPmBaFtsVF4u57iIg0jL4ZKyIScAp6EZGAU9CLiARcvcbog+JAcQl9fvoiV0zoHe+miIjEXFL26PcVHwLg9++sjW9DREQaQVIGvYhIMlHQV6KzK0UkaBT0legKCCISNAp6EZGAU9CLiAScgr4SjdGLSNAo6CuZ95nu2iQiwZLUQb9rf3GVsi27D8ShJSIisZPUQS8ikgwU9JX8+p8r4t0EEZGoUtBXUqKjsSISMEkX9J9v2sW0+9+NdzNERBpN0gX9L19YptsHikhSSbqgFxFJNgp6EZGAU9CLiAScgl5EJOAU9CIiAaegFxEJuKQLet1YRESSTdIFvYhIslHQi4gEnIJeRCTgFPQiIgGnoBcRCbikC3oznXcjIsmlzqA3swwze9/MFpvZp2b2c1/e08wWmNlKM3vczJr78nQ/v8o/nhvbVRARkdpE0qPfD0xwzg0Ejgcmm9lw4A7gbudcHrANmOHrzwC2Oed6A3f7eglD/XkRSTZ1Br0L2e1nm/kfB0wAnvTljwJn+umpfh7/+ETTeImISNxENEZvZqlm9jGwGZgHrAa2O+eKfZVCINtPZwPrAPzjO4BO0Wy0iIhELqKgd84dcs4dD+QAw4Cjq6vmf1fXe69yJ1Yzm2lmC81sYVFRUaTtFRGReqrXWTfOue3A68BwoL2ZpfmHcoANfroQ6A7gH28HbK3mtR5yzuU75/IzMzMb1voG0CCSiCSbSM66yTSz9n66BTAJWAa8Bpztq00HnvXTc/08/vFXnXNVevQiItI40uquQhbwqJmlEvpgeMI595yZfQb81cxuBT4CHvH1HwH+ZGa
},
"metadata": {
"needs_background": "light"
}
}
],
"source": [
"plt.plot(rewards)"
]
},
{
"source": [
"From this graph, it is not possible to tell anything, because due to the nature of stochastic training process the length of training sessions varies greatly. To make more sense of this graph, we can calculate **running average** over series of experiments, let's say 100. This can be done conveniently using `np.convolve`:"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x2814cc63488>]"
]
},
"metadata": {},
"execution_count": 22
},
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 432x288 with 1 Axes>",
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\r\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\r\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\r\n<!-- Created with matplotlib (https://matplotlib.org/) -->\r\n<svg height=\"248.518125pt\" version=\"1.1\" viewBox=\"0 0 379.461486 248.518125\" width=\"379.461486pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\r\n <defs>\r\n <style type=\"text/css\">\r\n*{stroke-linecap:butt;stroke-linejoin:round;white-space:pre;}\r\n </style>\r\n </defs>\r\n <g id=\"figure_1\">\r\n <g id=\"patch_1\">\r\n <path d=\"M 0 248.518125 \r\nL 379.461486 248.518125 \r\nL 379.461486 0 \r\nL 0 0 \r\nz\r\n\" style=\"fill:none;\"/>\r\n </g>\r\n <g id=\"axes_1\">\r\n <g id=\"patch_2\">\r\n <path d=\"M 33.2875 224.64 \r\nL 368.0875 224.64 \r\nL 368.0875 7.2 \r\nL 33.2875 7.2 \r\nz\r\n\" style=\"fill:#ffffff;\"/>\r\n </g>\r\n <g id=\"matplotlib.axis_1\">\r\n <g id=\"xtick_1\">\r\n <g id=\"line2d_1\">\r\n <defs>\r\n <path d=\"M 0 0 \r\nL 0 3.5 \r\n\" id=\"mb0fbef4acd\" style=\"stroke:#000000;stroke-width:0.8;\"/>\r\n </defs>\r\n <g>\r\n <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"48.505682\" xlink:href=\"#mb0fbef4acd\" y=\"224.64\"/>\r\n </g>\r\n </g>\r\n <g id=\"text_1\">\r\n <!-- 0 -->\r\n <defs>\r\n <path d=\"M 31.78125 66.40625 \r\nQ 24.171875 66.40625 20.328125 58.90625 \r\nQ 16.5 51.421875 16.5 36.375 \r\nQ 16.5 21.390625 20.328125 13.890625 \r\nQ 24.171875 6.390625 31.78125 6.390625 \r\nQ 39.453125 6.390625 43.28125 13.890625 \r\nQ 47.125 21.390625 47.125 36.375 \r\nQ 47.125 51.421875 43.28125 58.90625 \r\nQ 39.453125 66.40625 31.78125 66.40625 \r\nz\r\nM 31.78125 74.21875 \r\nQ 44.046875 74.21875 50.515625 64.515625 \r\nQ 56.984375 54.828125 56.984375 36.375 \r\nQ 56.984375 17.96875 50.515625 8.265625 \r\nQ 44.046875 -1.421875 31.78125 -1.421875 \r\nQ 19.53125 -1.421875 13.0625 8.265625 \r\nQ 6.59375 17.96875 6.59375 36.375 \r\nQ 6.59375 54.828125 13.0625 64.515625 \r\nQ 19.53125 74.21875 31.78125 74.21875 \r\nz\r\n\" id=\"DejaVuSans-48\"/>\r\n </defs>\r\n <g transform=\"translate(45.324432 239.238437)scale(0.1 -0.1)\">\r\n <use xlink:href=\"#DejaVuSans-48\"/>\r\n </g>\r\n </g>\r\n </g>\r\n <g id=\"xtick_2\">\r\n <g id=\"line2d_2\">\r\n <g>\r\n <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"109.439343\" xlink:href=\"#mb0fbef4acd\" y=\"224.64\"/>\r\n </g>\r\n </g>\r\n <g id=\"text_2\">\r\n <!-- 20000 -->\r\n <defs>\r\n <path d=\"M 19.1875 8.296875 \r\nL 53.609375 8.296875 \r\nL 53.609375 0 \r\nL 7.328125 0 \r\nL 7.328125 8.296875 \r\nQ 12.9375 14.109375 22.625 23.890625 \r\nQ 32.328125 33.6875 34.8125 36.53125 \r\nQ 39.546875 41.84375 41.421875 45.53125 \r\nQ 43.3125 49.21875 43.3125 52.78125 \r\nQ 43.3125 58.59375 39.234375 62.25 \r\nQ 35.15625 65.921875 28.609375 65.921875 \r\nQ 23.96875 65.921875 18.8125 64.3125 \r\nQ 13.671875 62.703125 7.8125 59.421875 \r\nL 7.8125 69.390625 \r\nQ 13.765625 71.78125 18.9375 73 \r\nQ 24.125 74.21875 28.421875 74.21875 \r\nQ 39.75 74.21875 46.484375 68.546875 \r\nQ 53.21875 62.890625 53.21875 53.421875 \r\nQ 53.21875 48.921875 51.53125 44.890625 \r\nQ 49.859375 40.875 45.40625 35.40625 \r\nQ 44.1875 33.984375 37.640625 27.21875 \r\nQ 31.109375 20.453125 19.1875 8.296875 \r\nz\r\n\" id=\"DejaVuSans-50\"/>\r\n </defs>\r\n <g transform=\"translate(93.533093 239.238437)scale(0.1 -0.1)\">\r\n <use xlink:href=\"#DejaVuSans-50\"/>\r\n <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\r\n <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\r\n <use x=\"190.869141\" xlink:href=\"#DejaVuSans-48\"/>\r\n <use x=\"254.492188\" xlink:href=\"#DejaVuSans-48\"/>\r\n </g>\r\n </g>\r\n </g>\r\n <g id=\"xtick_3\">\r\n <g id=\"line2d_3\">\r\n <g>\r\n <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"170.373004\" xlink:href=\"#mb0fbef4acd\" y=\"224.64\"/>\
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXsAAAD4CAYAAAANbUbJAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+17YcXAAAgAElEQVR4nO2dd3gVZfbHvycdAiGUAKEZelGqkY4gICDo4rr6U3dVVKxrWdeKde2ylnXX1bWiYu8FpYmAKCol9AABAgQIBAglQALp7++PO3Mzd+70O7fk3vN5njyZeeedmXfu3HvmzHlPISEEGIZhmOgmLtwDYBiGYYIPC3uGYZgYgIU9wzBMDMDCnmEYJgZgYc8wDBMDJIR7AADQokULkZWVFe5hMAzD1CtWrVp1SAiRYaVvRAj7rKws5OTkhHsYDMMw9Qoi2mW1L5txGIZhYgAW9gzDMDEAC3uGYZgYgIU9wzBMDMDCnmEYJgZgYc8wDBMDsLBnGIaJAUyFPRGlENEKIlpHRBuJ6DGp/V0i2klEa6W/flI7EdFLRJRPROuJaECwL4JhwoUQAp/n7EFldW24h8IwhlgJqqoAMFoIUUpEiQCWEtFcads9QogvVP3PA9BV+hsE4FXpP8NEHXNz9+OeL9aj4HAZ7hnfI9zDYRhdTDV74aFUWk2U/owqnkwG8J603zIA6USUGfhQGSbyOHaqCgBw6ERlmEfCMMZYstkTUTwRrQVwEMACIcRyadNTkqnmRSJKltraAtij2L1QalMf8wYiyiGinOLi4gAugWHCB4V7AAxjEUvCXghRI4ToB6AdgIFEdAaA+wH0AHAWgGYA7pO6a33//d4EhBBvCCGyhRDZGRmW8vgwDMMwDrHljSOEKAHwE4AJQogiyVRTAeAdAAOlboUA2it2awdgnwtjZRiGYRxixRsng4jSpeUGAMYCyJPt8EREAC4EkCvtMgvAVZJXzmAAx4QQRUEZPcOEmVrpnVUYTmPFBt+v34dVu46GexiMDlY0+0wAi4loPYCV8NjsvwfwIRFtALABQAsAT0r95wDYASAfwJsA/ur6qBkmQnjn150AgFnr+OX11o/W4E+v/hbuYTA6mLpeCiHWA+iv0T5ap78AcEvgQ2OYyGfbQY+jWnmVr599eVUNVhYcwYiuPB/FRAYcQcswQeDRWRtx5YwVyNt/PNxDYRgALOyZKOPD5bvw2co95h2DjKzxnyivDtk5K6pr8I9vc1FVw9G8jD8s7Jmo4sGvc3Hvl+tt7fPBsl3YeajM1XHIE5Wh9MO/94v1mPn7Lpz97OIQnpWpL7CwZ2KSkpOVKD5RASEEHvomFxe+8mtAx5twemvNdgqhtP9l2yEAQNGxchwqrbC0z9wNRciaNtsbCcxELyzsmZhkwBMLcNZTP3pdJwMVdh2aN9Rs31tSHtBx7XCkrC5lw/8WbzfsW3yiAmUV1XhtiaffzkNluPWj1fh4xe6gjpEJHyzsmZjE6x8vAvOPv3poFgCgVVqK5vZtB04EdHynmD28znrqR5z/36UorfDMKdQKge/XF+H+rzb49T1cWoGsabOxcPMBS+cuq3A2T3Gqsiak8w3Xv5cTUw83FvZMTBNoKJRsptETrtW14Qm2+s6C3//OQ2XYXuyZqyg1mEjeVOTxKHrn1wJL5z79H/ORf9D+Q67nI/PQ9cG55h1dYsGmA5oPt2Bj1cTmNizsmZimNkDN/rf8wwCAlxZu09xeFaY895U2NeSftribjHDjPnsup1v2h+cNKNQs33EY2U/+iHm5oU8qwMKeiXpeXrQNd362VnNbgLIeW0zMNG5r9vtKTiFr2mx8u3avq8dtmZbsXe720FzNCetgpoQ4WRk6F1XA13zn5C3EKTmSl9a6wmMhO6cMC3sm6nn+h634arW5cHzw6w2ue6XY1bDNeG7+FgDA3z7Rfng5JbNJ3ZxDZXUt1u4p8a6T5ED6a/5h3dw3gc591ITY3KU83dh//YwCl11v9ZAfag0S40NyPiUs7JmoJGvabNz3hbm/vdKM8+Hy3brmGKe4bcZJSdT/yd43wXmlrMR4a6JAby5ALazJps9pqOc21Oa7/cdD4zU1e73HfPP1GnffzKzAwp6JWj7N8Y2k3V5c6tdHrZAGatZR47YMS06o0wgHPf0j8vYfR8GhMmRNm42PVuxyfNz4OGvCOU5HiKuFtcXDeQl11K/64ZSU4FwU7iguxUPfbECthZtdcPgkAGD/sdC55MqwsGdihjEvLPFrU/88F2856Oo50xsmunq8FMXr/4HjFXjz550Y9fxPAIA9R045Pq6eEAeA2RvqJhP1uqmFNdmMHS456dx8tufISaxTmJ2soH6ot0hN1u5ogZs+WIUPlu3GVhPbf/GJCnTKSAUAnKqqcXw+p7CwZ2Ia9eu8kfWhsroWk176Bb/mH7J8/GqXNdZklQYaqDeRTFqKdgLcE+VVPr7oeh9PpcpcdeSkvZq8pQ598wFgxLOLMdlmBLRbnxsAlFV4BHd1jf4xtx44gbOe+hE7ikMzN6AFC3sm5thz5KR3Wf2bH9lNPyXxvpJT2LjvuC3f7CqX7TjJicER9o00hP2O4lJMfTfHp03vYaiM3gX0Hx56lIdY01V/biWnnBeM31vieaP6YlWhbp/cvaH3vlHDwp6JOUYoEoXd+tFqn21G5gdZPOxWPCzMqDHQ9pyg9uLQe5b8aUA7W8fVemaMfmEJVhQc8WnTM/esUZlRGibZE/bqCeK1e0rwyLe5ul4+xSc8Ub2L8qxF9cpU1dTio+W7/Wz26jcTJ5yq1H9gac0XhRoW9kxMIycPk3n7151+Oej3HDmJDQ79oqtq3fbG8RX2esLQonONl2cll04zCo9qzwscV7ms2nXFVNv8L3zlV7z3+y6/ojAysqZ8rerNw4zXl2zHA19vwOc5vlp4WoPA51aM5me27GdhzzARx6SXlvqsj3h2MS54ealO7zpOf2SenznCyI7rBPVEpJ4Zx6p3jczPW61F0Cona5WotdpaIXD9eznIUb0Z6KGnWS/bcdjS/laRbeay6UXG7n06dqoKX632fWAcLtM3BVVUh35CVg0LeybieG5+HuZv3B+047dpop20TMZKgI+W5lpWWYN3fyvwaat2WbP/RFWYRe/wahu6kkADoLR4YcFWn/XiExVYsOkArn/PXPP+bt0+fLNW23//pI5ppOCw/0RnTa3Aw9/k+szJKDl4vBxfSf7tB1R+9XbnPvo+9gPu/GwdNu6re+NTH1NJhc7DbF/JqaDcDy1Y2DMRxyuLt+PG91f5tdfUCld+GOkNkxztpzy3nvfI9Ll5PutVLmv2TVTmBr0UBvM36tuyQyFb5DQSRy24VN728RpsLtLOpaPnovjYd5v82tbsPor3l+3C3z/Vji6+Q9HevplvSuqaWoH8g6XImjYby228TWzcWzfuHq0b6/ar0LiO938vwNDpi3Drx2ssny8QWNgz9YbOD8zBzR+sNu9owiYdwWIHPWHfv0O6z7rbrpcNk6xN0KrZUHgMD3/jmfB0w4PnyhnLMeCJBbrbW0spnxPsRlepsOOPLkft1kjX98rifOQfrLOV/7a9TojLkawyNULgmTmbAQCv/7xD9xzPzN2MborMnD9trYvLGNalhe5+WuUpP1qxR3MswcJU2BNRChGtIKJ1RLSRiB6T2jsS0XIi2kZEnxJRktSeLK3nS9uzgnsJTCwxL0DzTiBeF0rBeqqyRtNdsEtGI591ZWRpba0wfNW3NgZfQW1Vbl/y+m94f9kulFZUY6mNOAE9ftl2CEfKKnVt0bLg7dgiNaDznLKRIE1+rtTWCuwrOYXn5m/BJa/9ptlXbbOvqRVYmOcR3Ivy9APrXl+ywyff0dDOdQLeaJ5kh0buncY23VMDxYpmXwFgtBCiL4B+ACYQ0WAA/wTwohCiK4CjAKZK/acCOCqE6ALgRakfw0QE3wSQk0QpaEe/sASXvv67Rh/fdeXE3yuL8zHo6YW6NmUr2H1Y1dQKHC+v8nq1LNx8EFe/s9Lx+dVco3Msed7DKDLXCqcqrV+vLGxrhMDlby7z7C89kI+
},
"metadata": {
"needs_background": "light"
}
}
],
"source": [
"def running_average(x,window):\n",
" return np.convolve(x,np.ones(window)/window,mode='valid')\n",
"\n",
"plt.plot(running_average(rewards,100))"
]
},
{
"source": [
"## Varying Hyperparameters and Seeing the Result in Action\n",
"\n",
"Now it would be interesting to actually see how the trained model behaves. Let's run the simulation, and we will be following the same action selection strategy as during training: sampling according to the probability distribution in Q-Table: "
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"obs = env.reset()\n",
"done = False\n",
"while not done:\n",
" s = discretize(obs)\n",
" env.render()\n",
" v = probs(np.array(qvalues(s)))\n",
" a = random.choices(actions,weights=v)[0]\n",
" obs,_,done,_ = env.step(a)\n",
"env.close()"
]
},
{
"source": [
"\n",
"## Saving result to an animated GIF\n",
"\n",
"If you want to impress your friends, you may want to send them the animated GIF picture of the balancing pole. To do this, we can invoke `env.render` to produce an image frame, and then save those to animated GIF using PIL library:"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"360\n"
]
}
],
"source": [
"from PIL import Image\n",
"obs = env.reset()\n",
"done = False\n",
"i=0\n",
"ims = []\n",
"while not done:\n",
" s = discretize(obs)\n",
" img=env.render(mode='rgb_array')\n",
" ims.append(Image.fromarray(img))\n",
" v = probs(np.array([Qbest.get((s,a),0) for a in actions]))\n",
" a = random.choices(actions,weights=v)[0]\n",
" obs,_,done,_ = env.step(a)\n",
" i+=1\n",
"env.close()\n",
"ims[0].save('images/cartpole-balance.gif',save_all=True,append_images=ims[1::2],loop=0,duration=5)\n",
"print(i)"
]
}
]
}