You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ML-For-Beginners/8-Reinforcement/2-Gym/notebook.ipynb

388 lines
476 KiB

{
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
},
"orig_nbformat": 4,
"kernelspec": {
"name": "python3",
"display_name": "Python 3.7.4 64-bit ('base': conda)"
},
"interpreter": {
"hash": "86193a1ab0ba47eac1c69c1756090baa3b420b3eea7d4aafab8b85f8b312f0c5"
}
},
"nbformat": 4,
"nbformat_minor": 2,
"cells": [
{
"source": [
"## CartPole Skating\n",
"\n",
"> **Problem**: If Peter wants to escape from the wolf, he needs to be able to move faster than him. We will see how Peter can learn to skate, in particular, to keep balance, using Q-Learning.\n",
"\n",
"First, let's install the gym and import required libraries:"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"#code block 1"
]
},
{
"source": [
"## Create a cartpole environment"
],
"cell_type": "markdown",
"metadata": {}
},
{
"source": [
"#code block 2"
],
"cell_type": "code",
"metadata": {},
"execution_count": null,
"outputs": []
},
{
"source": [
"To see how the environment works, let's run a short simulation for 100 steps."
],
"cell_type": "markdown",
"metadata": {}
},
{
"source": [
"#code block 3"
],
"cell_type": "code",
"metadata": {},
"execution_count": null,
"outputs": []
},
{
"source": [
"During simulation, we need to get observations in order to decide how to act. In fact, `step` function returns us back current observations, reward function, and the `done` flag that indicates whether it makes sense to continue the simulation or not:"
],
"cell_type": "markdown",
"metadata": {}
},
{
"source": [
"#code block 4"
],
"cell_type": "code",
"metadata": {},
"execution_count": null,
"outputs": []
},
{
"source": [
"We can get min and max value of those numbers:"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38]\n[4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38]\n"
]
}
],
"source": [
"#code block 5"
]
},
{
"source": [
"## State Discretization"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"#code block 6"
]
},
{
"source": [
"Let's also explore other discretization method using bins:"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Sample bins for interval (-5,5) with 10 bins\n [-5. -4. -3. -2. -1. 0. 1. 2. 3. 4. 5.]\n"
]
}
],
"source": [
"#code block 7"
]
},
{
"source": [
"Let's now run a short simulation and observe those discrete environment values."
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"(0, 0, -2, -2)\n(0, 1, -2, -5)\n(0, 2, -3, -8)\n(0, 3, -5, -11)\n(0, 3, -7, -14)\n(0, 4, -10, -17)\n(0, 3, -14, -15)\n(0, 3, -17, -12)\n(0, 3, -20, -16)\n(0, 4, -23, -19)\n"
]
}
],
"source": [
"#code block 8"
]
},
{
"source": [
"## Q-Table Structure"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"#code block 9"
]
},
{
"source": [
"## Let's Start Q-Learning!"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"#code block 10"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"0: 22.0, alpha=0.3, epsilon=0.9\n",
"5000: 70.1384, alpha=0.3, epsilon=0.9\n",
"10000: 121.8586, alpha=0.3, epsilon=0.9\n",
"15000: 149.6368, alpha=0.3, epsilon=0.9\n",
"20000: 168.2782, alpha=0.3, epsilon=0.9\n",
"25000: 196.7356, alpha=0.3, epsilon=0.9\n",
"30000: 220.7614, alpha=0.3, epsilon=0.9\n",
"35000: 233.2138, alpha=0.3, epsilon=0.9\n",
"40000: 248.22, alpha=0.3, epsilon=0.9\n",
"45000: 264.636, alpha=0.3, epsilon=0.9\n",
"50000: 276.926, alpha=0.3, epsilon=0.9\n",
"55000: 277.9438, alpha=0.3, epsilon=0.9\n",
"60000: 248.881, alpha=0.3, epsilon=0.9\n",
"65000: 272.529, alpha=0.3, epsilon=0.9\n",
"70000: 281.7972, alpha=0.3, epsilon=0.9\n",
"75000: 284.2844, alpha=0.3, epsilon=0.9\n",
"80000: 269.667, alpha=0.3, epsilon=0.9\n",
"85000: 273.8652, alpha=0.3, epsilon=0.9\n",
"90000: 278.2466, alpha=0.3, epsilon=0.9\n",
"95000: 269.1736, alpha=0.3, epsilon=0.9\n"
]
}
],
"source": [
"#code block 11"
]
},
{
"source": [
"## Plotting Training Progress"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x2814bb79788>]"
]
},
"metadata": {},
"execution_count": 20
},
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 432x288 with 1 Axes>",
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\r\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\r\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\r\n<!-- Created with matplotlib (https://matplotlib.org/) -->\r\n<svg height=\"248.518125pt\" version=\"1.1\" viewBox=\"0 0 379.159862 248.518125\" width=\"379.159862pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\r\n <defs>\r\n <style type=\"text/css\">\r\n*{stroke-linecap:butt;stroke-linejoin:round;white-space:pre;}\r\n </style>\r\n </defs>\r\n <g id=\"figure_1\">\r\n <g id=\"patch_1\">\r\n <path d=\"M 0 248.518125 \r\nL 379.159862 248.518125 \r\nL 379.159862 0 \r\nL 0 0 \r\nz\r\n\" style=\"fill:none;\"/>\r\n </g>\r\n <g id=\"axes_1\">\r\n <g id=\"patch_2\">\r\n <path d=\"M 33.2875 224.64 \r\nL 368.0875 224.64 \r\nL 368.0875 7.2 \r\nL 33.2875 7.2 \r\nz\r\n\" style=\"fill:#ffffff;\"/>\r\n </g>\r\n <g id=\"matplotlib.axis_1\">\r\n <g id=\"xtick_1\">\r\n <g id=\"line2d_1\">\r\n <defs>\r\n <path d=\"M 0 0 \r\nL 0 3.5 \r\n\" id=\"ma0dcf872fc\" style=\"stroke:#000000;stroke-width:0.8;\"/>\r\n </defs>\r\n <g>\r\n <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"48.505682\" xlink:href=\"#ma0dcf872fc\" y=\"224.64\"/>\r\n </g>\r\n </g>\r\n <g id=\"text_1\">\r\n <!-- 0 -->\r\n <defs>\r\n <path d=\"M 31.78125 66.40625 \r\nQ 24.171875 66.40625 20.328125 58.90625 \r\nQ 16.5 51.421875 16.5 36.375 \r\nQ 16.5 21.390625 20.328125 13.890625 \r\nQ 24.171875 6.390625 31.78125 6.390625 \r\nQ 39.453125 6.390625 43.28125 13.890625 \r\nQ 47.125 21.390625 47.125 36.375 \r\nQ 47.125 51.421875 43.28125 58.90625 \r\nQ 39.453125 66.40625 31.78125 66.40625 \r\nz\r\nM 31.78125 74.21875 \r\nQ 44.046875 74.21875 50.515625 64.515625 \r\nQ 56.984375 54.828125 56.984375 36.375 \r\nQ 56.984375 17.96875 50.515625 8.265625 \r\nQ 44.046875 -1.421875 31.78125 -1.421875 \r\nQ 19.53125 -1.421875 13.0625 8.265625 \r\nQ 6.59375 17.96875 6.59375 36.375 \r\nQ 6.59375 54.828125 13.0625 64.515625 \r\nQ 19.53125 74.21875 31.78125 74.21875 \r\nz\r\n\" id=\"DejaVuSans-48\"/>\r\n </defs>\r\n <g transform=\"translate(45.324432 239.238437)scale(0.1 -0.1)\">\r\n <use xlink:href=\"#DejaVuSans-48\"/>\r\n </g>\r\n </g>\r\n </g>\r\n <g id=\"xtick_2\">\r\n <g id=\"line2d_2\">\r\n <g>\r\n <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"109.379018\" xlink:href=\"#ma0dcf872fc\" y=\"224.64\"/>\r\n </g>\r\n </g>\r\n <g id=\"text_2\">\r\n <!-- 20000 -->\r\n <defs>\r\n <path d=\"M 19.1875 8.296875 \r\nL 53.609375 8.296875 \r\nL 53.609375 0 \r\nL 7.328125 0 \r\nL 7.328125 8.296875 \r\nQ 12.9375 14.109375 22.625 23.890625 \r\nQ 32.328125 33.6875 34.8125 36.53125 \r\nQ 39.546875 41.84375 41.421875 45.53125 \r\nQ 43.3125 49.21875 43.3125 52.78125 \r\nQ 43.3125 58.59375 39.234375 62.25 \r\nQ 35.15625 65.921875 28.609375 65.921875 \r\nQ 23.96875 65.921875 18.8125 64.3125 \r\nQ 13.671875 62.703125 7.8125 59.421875 \r\nL 7.8125 69.390625 \r\nQ 13.765625 71.78125 18.9375 73 \r\nQ 24.125 74.21875 28.421875 74.21875 \r\nQ 39.75 74.21875 46.484375 68.546875 \r\nQ 53.21875 62.890625 53.21875 53.421875 \r\nQ 53.21875 48.921875 51.53125 44.890625 \r\nQ 49.859375 40.875 45.40625 35.40625 \r\nQ 44.1875 33.984375 37.640625 27.21875 \r\nQ 31.109375 20.453125 19.1875 8.296875 \r\nz\r\n\" id=\"DejaVuSans-50\"/>\r\n </defs>\r\n <g transform=\"translate(93.472768 239.238437)scale(0.1 -0.1)\">\r\n <use xlink:href=\"#DejaVuSans-50\"/>\r\n <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\r\n <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\r\n <use x=\"190.869141\" xlink:href=\"#DejaVuSans-48\"/>\r\n <use x=\"254.492188\" xlink:href=\"#DejaVuSans-48\"/>\r\n </g>\r\n </g>\r\n </g>\r\n <g id=\"xtick_3\">\r\n <g id=\"line2d_3\">\r\n <g>\r\n <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"170.252354\" xlink:href=\"#ma0dcf872fc\" y=\"224.64\"/>\
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+17YcXAAAgAElEQVR4nO3deXxU9b3/8dcnCSTsa8CQgAEJIKIIBGSXTUWiYqu0Lq2o3MvV6nWhVlGrtbdasddq9dqfy9W2tr22WpdKXYu4W0VBRVBAQFACCEF2kCXk+/tjvkkm+yTMZCZn3s/HI4+c853vzPmenMl7vud7zpxjzjlERCS4UuLdABERiS0FvYhIwCnoRUQCTkEvIhJwCnoRkYBLi3cDADp37uxyc3Pj3QwRkSZl0aJFW5xzmXXVS4igz83NZeHChfFuhohIk2JmX0ZST0M3IiIBp6AXEQk4Bb2ISMAp6EVEAk5BLyIScBEFvZmtNbMlZvaxmS30ZR3NbJ6ZrfS/O/hyM7N7zWyVmX1iZoNjuQIiIlK7+vToxzvnjnfO5fv52cB851weMN/PA5wK5PmfmcD90WqsiIjU3+GcRz8VGOenHwVeB67z5X90oesfv2dm7c0syzm38XAa2pjWbd3Lj/+2mG7tMvjpaf3p3DqdbXsO8K/V31BwXBbOOf7+8XpO6NmJj77azsSjuzD5N2/yo3G9eXbxev484wTumb+S/lltOaJdBobxwBureePzIpbccjJmBsCLSzby+ooiphyXxYl9Kn7nYUnhDv7+8Xq6tk1n5tijyso/XredtBRjQHY7nHM8uaiQCf26cOvzy+jYqjnLNu7kX6u/4d5zB5HdvgWri3bz3UHZpKWm8PKnX3NMt7bc/sJybig4mpPueoNHpg+leZrxwdptzHlxOU9dOoIN2/dx+sBuPLbgK254Zglt0tO4cFQu//PqKm46rT+/eO4zlv9iMjf9fSl/W1RIn66tGdazI51apXPP/JX87PT+9OzcivfXbGXJ+h28tXJLg7fFJScexQNvrKZ3l9as2ry7rNwMUs0oLonfZbbbZKSxa19xxPVHHtWJpet3sHNfMcdmt2PJ+h306tyKL7bsAeDu7w/k6scXx6q53HJ6f7buPcgTH6zj6537YraccENzO/DB2m38YHgP1m/7luZpKbz86aZ6vcb5J/TgolG5TLrrzXov/4qJeaSacfcrn9da79YzB/DlN3v437fW8B9je/Hgm19UqfPdwdk8/eH6Gl9jcI/2PHXpSNZs2cOEX78RUftyOrTglVknktEsNaL6DWWRXI/ezNYA2wAHPOice8jMtjvn2ofV2eac62BmzwFznHNv+/L5wHXOuYWVXnMmoR4/PXr0GPLllxGd998ocmc/Xzad06EFb183ge8/+C4L1mzlvesnsm7bXqY98G5ZnYn9ujB/+eay+cqhFO73Fw1lfN8u7Nx3kONu+WdZ+do5BTW2Ifyx0vK1cwp4aelGLvnzh3Wuz7WT+3LxqJ70u+mlOuuWKv0HrUnz1BQOHCqJ+PVEksFvzxvMZY/V/T8Z7ofDj+QXZw5o0PLMbFHYKEuNIu3Rj3LObTCzLsA8M1te27KrKavyaeKcewh4CCA/Pz9h735SuO1bANZvD/0+eKiE3fsr9uLWbdtbYb6mkAfKeoDFhw5/lXd+G1lvcuvuA5TU8wYztYU8oJAXqcaufQfr/ZyiXftj0JKKIhqjd85t8L83A88Aw4BNZpYF4H+XdmkLge5hT88BNkSrwUGVO/t5LvnTong3Q0QCqM6gN7NWZtamdBo4GVgKzAWm+2rTgWf99FzgAn/2zXBgR1Man4+nlz79Ot5NEJEAimTopivwjD+AmAY85px7ycw+AJ4wsxnAV8A0X/8FYAqwCtgLXBT1VouIBETR7tgP3dQZ9M65L4CB1ZR/A0ysptwBl0WldQG0/+Ah5i7ewKijOjXaMuN4YoqI1OHLb/bWXekwJcRlipPJr15eQdGu/dxzzvGNtkyr7vC4iETdI2+viXcTqqVLIHhffrOHRV/WfqZJNJQeYd/5bf2PzotIYltZyxl38aQevXfif78OVD2fXUQklrY0whi9evRRYNV+dSBxOJf4bRSR2FHQi4gEnIZuoqAhBzsb80SY372zhjYZ2tQiyUo9+giMv/P1CvNNcRDknvkr490EEYkTBX0E1virC4qINEUK+nqq57XBouqv738Vv4WLSJOloI9QInzpaPbTS+LdBBFpghT0IiIBp6AH9h6o+7ru0R6yiecQkIgkFwU9sOCLrXXWKb0BSayHcL49cCi2CxCRpKOgr2TfwbqDNpad8UPq6otIlCnoK9l/sHFukVfbnkFjXPtCRJKHgr6eHv9gXUzv8bh19wHyb30lZq8vIslHQV9P9722it+/szZmr79174EqZbmzn+eQ7h4iIg2koAdcPUfdDxQf/gHTmobiP9uws9ryg4caZ0hJRIJHQd8Aq4sO/5IIroakX7l512G/tohIOAV9gvmwhrtcrdyUmHeuEZHEp6CvpL7DONG2uHBHteWn3/d2I7dERIJCQR8nlggXzxGRpKCgFxEJOAU98bnuTE0HY0VEok1BHwUahhGRRKagr8Sa5I0CRURqpqCvpCFn3SzbWP2XnEREEoGCnviM0Wu4R0Qai4JeRCTgFPTE9vryNS5TZ92ISCNR0MfJLf/4LN5NEJEkEXHQm1mqmX1kZs/5+Z5mtsDMVprZ42bW3Jen+/lV/vHc2DQ9etS7FpEgq0+P/kpgWdj8HcDdzrk8YBsww5fPALY553oDd/t6Ce3xD9bFuwn11uenL/L655vj3QwRaQIiCnozywEKgIf9vAETgCd9lUeBM/30VD+Pf3yiJfgpJis3l18Zsql07g8Ul/DCkq/j3QwRaQIi7dH/BrgWKL37RSdgu3Ou2M8XAtl+OhtYB+Af3+HrV2BmM81soZktLCoqamDzRUSkLnUGvZmdBmx2zi0KL66mqovgsfIC5x5yzuU75/IzMzMjamysxPvSxCIisZQWQZ1RwBlmNgXIANoS6uG3N7M032vPATb4+oVAd6DQzNKAdsDWqLc8SkpKHCVhd+lL7EEmEZH6q7NH75y73jmX45zLBc4BXnXOnQ+8Bpztq00HnvXTc/08/vFXXQKf1nLqPW+xfvu38W6GiEjMHM559NcBs8xsFaEx+Ed8+SNAJ18+C5h9eE2MrRWbdI9WEQm2SIZuyjjnXgde99NfAMOqqbMPmBaFtsVF4u57iIg0jL4ZKyIScAp6EZGAU9CLiARcvcbog+JAcQl9fvoiV0zoHe+miIjEXFL26PcVHwLg9++sjW9DREQaQVIGvYhIMlHQV6KzK0UkaBT0legKCCISNAp6EZGAU9CLiAScgr4SjdGLSNAo6CuZ95nu2iQiwZLUQb9rf3GVsi27D8ShJSIisZPUQS8ikgwU9JX8+p8r4t0EEZGoUtBXUqKjsSISMEkX9J9v2sW0+9+NdzNERBpN0gX9L19YptsHikhSSbqgFxFJNgp6EZGAU9CLiAScgl5EJOAU9CIiAaegFxEJuKQLet1YRESSTdIFvYhIslHQi4gEnIJeRCTgFPQiIgGnoBcRCbikC3oznXcjIsmlzqA3swwze9/MFpvZp2b2c1/e08wWmNlKM3vczJr78nQ/v8o/nhvbVRARkdpE0qPfD0xwzg0Ejgcmm9lw4A7gbudcHrANmOHrzwC2Oed6A3f7eglD/XkRSTZ1Br0L2e1nm/kfB0wAnvTljwJn+umpfh7/+ETTeImISNxENEZvZqlm9jGwGZgHrAa2O+eKfZVCINtPZwPrAPzjO4BO0Wy0iIhELqKgd84dcs4dD+QAw4Cjq6vmf1fXe69yJ1Yzm2lmC81sYVFRUaTtFRGReqrXWTfOue3A68BwoL2ZpfmHcoANfroQ6A7gH28HbK3mtR5yzuU75/IzMzMb1voG0CCSiCSbSM66yTSz9n66BTAJWAa8Bpztq00HnvXTc/08/vFXnXNVevQiItI40uquQhbwqJmlEvpgeMI595yZfQb81cxuBT4CHvH1HwH+ZGa
},
"metadata": {
"needs_background": "light"
}
}
],
"source": [
"plt.plot(rewards)"
]
},
{
"source": [
"From this graph, it is not possible to tell anything, because due to the nature of stochastic training process the length of training sessions varies greatly. To make more sense of this graph, we can calculate **running average** over series of experiments, let's say 100. This can be done conveniently using `np.convolve`:"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x2814cc63488>]"
]
},
"metadata": {},
"execution_count": 22
},
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 432x288 with 1 Axes>",
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\r\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\r\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\r\n<!-- Created with matplotlib (https://matplotlib.org/) -->\r\n<svg height=\"248.518125pt\" version=\"1.1\" viewBox=\"0 0 379.461486 248.518125\" width=\"379.461486pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\r\n <defs>\r\n <style type=\"text/css\">\r\n*{stroke-linecap:butt;stroke-linejoin:round;white-space:pre;}\r\n </style>\r\n </defs>\r\n <g id=\"figure_1\">\r\n <g id=\"patch_1\">\r\n <path d=\"M 0 248.518125 \r\nL 379.461486 248.518125 \r\nL 379.461486 0 \r\nL 0 0 \r\nz\r\n\" style=\"fill:none;\"/>\r\n </g>\r\n <g id=\"axes_1\">\r\n <g id=\"patch_2\">\r\n <path d=\"M 33.2875 224.64 \r\nL 368.0875 224.64 \r\nL 368.0875 7.2 \r\nL 33.2875 7.2 \r\nz\r\n\" style=\"fill:#ffffff;\"/>\r\n </g>\r\n <g id=\"matplotlib.axis_1\">\r\n <g id=\"xtick_1\">\r\n <g id=\"line2d_1\">\r\n <defs>\r\n <path d=\"M 0 0 \r\nL 0 3.5 \r\n\" id=\"mb0fbef4acd\" style=\"stroke:#000000;stroke-width:0.8;\"/>\r\n </defs>\r\n <g>\r\n <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"48.505682\" xlink:href=\"#mb0fbef4acd\" y=\"224.64\"/>\r\n </g>\r\n </g>\r\n <g id=\"text_1\">\r\n <!-- 0 -->\r\n <defs>\r\n <path d=\"M 31.78125 66.40625 \r\nQ 24.171875 66.40625 20.328125 58.90625 \r\nQ 16.5 51.421875 16.5 36.375 \r\nQ 16.5 21.390625 20.328125 13.890625 \r\nQ 24.171875 6.390625 31.78125 6.390625 \r\nQ 39.453125 6.390625 43.28125 13.890625 \r\nQ 47.125 21.390625 47.125 36.375 \r\nQ 47.125 51.421875 43.28125 58.90625 \r\nQ 39.453125 66.40625 31.78125 66.40625 \r\nz\r\nM 31.78125 74.21875 \r\nQ 44.046875 74.21875 50.515625 64.515625 \r\nQ 56.984375 54.828125 56.984375 36.375 \r\nQ 56.984375 17.96875 50.515625 8.265625 \r\nQ 44.046875 -1.421875 31.78125 -1.421875 \r\nQ 19.53125 -1.421875 13.0625 8.265625 \r\nQ 6.59375 17.96875 6.59375 36.375 \r\nQ 6.59375 54.828125 13.0625 64.515625 \r\nQ 19.53125 74.21875 31.78125 74.21875 \r\nz\r\n\" id=\"DejaVuSans-48\"/>\r\n </defs>\r\n <g transform=\"translate(45.324432 239.238437)scale(0.1 -0.1)\">\r\n <use xlink:href=\"#DejaVuSans-48\"/>\r\n </g>\r\n </g>\r\n </g>\r\n <g id=\"xtick_2\">\r\n <g id=\"line2d_2\">\r\n <g>\r\n <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"109.439343\" xlink:href=\"#mb0fbef4acd\" y=\"224.64\"/>\r\n </g>\r\n </g>\r\n <g id=\"text_2\">\r\n <!-- 20000 -->\r\n <defs>\r\n <path d=\"M 19.1875 8.296875 \r\nL 53.609375 8.296875 \r\nL 53.609375 0 \r\nL 7.328125 0 \r\nL 7.328125 8.296875 \r\nQ 12.9375 14.109375 22.625 23.890625 \r\nQ 32.328125 33.6875 34.8125 36.53125 \r\nQ 39.546875 41.84375 41.421875 45.53125 \r\nQ 43.3125 49.21875 43.3125 52.78125 \r\nQ 43.3125 58.59375 39.234375 62.25 \r\nQ 35.15625 65.921875 28.609375 65.921875 \r\nQ 23.96875 65.921875 18.8125 64.3125 \r\nQ 13.671875 62.703125 7.8125 59.421875 \r\nL 7.8125 69.390625 \r\nQ 13.765625 71.78125 18.9375 73 \r\nQ 24.125 74.21875 28.421875 74.21875 \r\nQ 39.75 74.21875 46.484375 68.546875 \r\nQ 53.21875 62.890625 53.21875 53.421875 \r\nQ 53.21875 48.921875 51.53125 44.890625 \r\nQ 49.859375 40.875 45.40625 35.40625 \r\nQ 44.1875 33.984375 37.640625 27.21875 \r\nQ 31.109375 20.453125 19.1875 8.296875 \r\nz\r\n\" id=\"DejaVuSans-50\"/>\r\n </defs>\r\n <g transform=\"translate(93.533093 239.238437)scale(0.1 -0.1)\">\r\n <use xlink:href=\"#DejaVuSans-50\"/>\r\n <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\r\n <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\r\n <use x=\"190.869141\" xlink:href=\"#DejaVuSans-48\"/>\r\n <use x=\"254.492188\" xlink:href=\"#DejaVuSans-48\"/>\r\n </g>\r\n </g>\r\n </g>\r\n <g id=\"xtick_3\">\r\n <g id=\"line2d_3\">\r\n <g>\r\n <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"170.373004\" xlink:href=\"#mb0fbef4acd\" y=\"224.64\"/>\
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXsAAAD4CAYAAAANbUbJAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+17YcXAAAgAElEQVR4nO2dd3gVZfbHvycdAiGUAKEZelGqkY4gICDo4rr6U3dVVKxrWdeKde2ylnXX1bWiYu8FpYmAKCol9AABAgQIBAglQALp7++PO3Mzd+70O7fk3vN5njyZeeedmXfu3HvmzHlPISEEGIZhmOgmLtwDYBiGYYIPC3uGYZgYgIU9wzBMDMDCnmEYJgZgYc8wDBMDJIR7AADQokULkZWVFe5hMAzD1CtWrVp1SAiRYaVvRAj7rKws5OTkhHsYDMMw9Qoi2mW1L5txGIZhYgAW9gzDMDEAC3uGYZgYgIU9wzBMDMDCnmEYJgZgYc8wDBMDsLBnGIaJAUyFPRGlENEKIlpHRBuJ6DGp/V0i2klEa6W/flI7EdFLRJRPROuJaECwL4JhwoUQAp/n7EFldW24h8IwhlgJqqoAMFoIUUpEiQCWEtFcads9QogvVP3PA9BV+hsE4FXpP8NEHXNz9+OeL9aj4HAZ7hnfI9zDYRhdTDV74aFUWk2U/owqnkwG8J603zIA6USUGfhQGSbyOHaqCgBw6ERlmEfCMMZYstkTUTwRrQVwEMACIcRyadNTkqnmRSJKltraAtij2L1QalMf8wYiyiGinOLi4gAugWHCB4V7AAxjEUvCXghRI4ToB6AdgIFEdAaA+wH0AHAWgGYA7pO6a33//d4EhBBvCCGyhRDZGRmW8vgwDMMwDrHljSOEKAHwE4AJQogiyVRTAeAdAAOlboUA2it2awdgnwtjZRiGYRxixRsng4jSpeUGAMYCyJPt8EREAC4EkCvtMgvAVZJXzmAAx4QQRUEZPcOEmVrpnVUYTmPFBt+v34dVu46GexiMDlY0+0wAi4loPYCV8NjsvwfwIRFtALABQAsAT0r95wDYASAfwJsA/ur6qBkmQnjn150AgFnr+OX11o/W4E+v/hbuYTA6mLpeCiHWA+iv0T5ap78AcEvgQ2OYyGfbQY+jWnmVr599eVUNVhYcwYiuPB/FRAYcQcswQeDRWRtx5YwVyNt/PNxDYRgALOyZKOPD5bvw2co95h2DjKzxnyivDtk5K6pr8I9vc1FVw9G8jD8s7Jmo4sGvc3Hvl+tt7fPBsl3YeajM1XHIE5Wh9MO/94v1mPn7Lpz97OIQnpWpL7CwZ2KSkpOVKD5RASEEHvomFxe+8mtAx5twemvNdgqhtP9l2yEAQNGxchwqrbC0z9wNRciaNtsbCcxELyzsmZhkwBMLcNZTP3pdJwMVdh2aN9Rs31tSHtBx7XCkrC5lw/8WbzfsW3yiAmUV1XhtiaffzkNluPWj1fh4xe6gjpEJHyzsmZjE6x8vAvOPv3poFgCgVVqK5vZtB04EdHynmD28znrqR5z/36UorfDMKdQKge/XF+H+rzb49T1cWoGsabOxcPMBS+cuq3A2T3Gqsiak8w3Xv5cTUw83FvZMTBNoKJRsptETrtW14Qm2+s6C3//OQ2XYXuyZqyg1mEjeVOTxKHrn1wJL5z79H/ORf9D+Q67nI/PQ9cG55h1dYsGmA5oPt2Bj1cTmNizsmZimNkDN/rf8wwCAlxZu09xeFaY895U2NeSftribjHDjPnsup1v2h+cNKNQs33EY2U/+iHm5oU8qwMKeiXpeXrQNd362VnNbgLIeW0zMNG5r9vtKTiFr2mx8u3avq8dtmZbsXe720FzNCetgpoQ4WRk6F1XA13zn5C3EKTmSl9a6wmMhO6cMC3sm6nn+h634arW5cHzw6w2ue6XY1bDNeG7+FgDA3z7Rfng5JbNJ3ZxDZXUt1u4p8a6T5ED6a/5h3dw3gc591ITY3KU83dh//YwCl11v9ZAfag0S40NyPiUs7JmoJGvabNz3hbm/vdKM8+Hy3brmGKe4bcZJSdT/yd43wXmlrMR4a6JAby5ALazJps9pqOc21Oa7/cdD4zU1e73HfPP1GnffzKzAwp6JWj7N8Y2k3V5c6tdHrZAGatZR47YMS06o0wgHPf0j8vYfR8GhMmRNm42PVuxyfNz4OGvCOU5HiKuFtcXDeQl11K/64ZSU4FwU7iguxUPfbECthZtdcPgkAGD/sdC55MqwsGdihjEvLPFrU/88F2856Oo50xsmunq8FMXr/4HjFXjz550Y9fxPAIA9R045Pq6eEAeA2RvqJhP1uqmFNdmMHS456dx8tufISaxTmJ2soH6ot0hN1u5ogZs+WIUPlu3GVhPbf/GJCnTKSAUAnKqqcXw+p7CwZ2Ia9eu8kfWhsroWk176Bb/mH7J8/GqXNdZklQYaqDeRTFqKdgLcE+VVPr7oeh9PpcpcdeSkvZq8pQ598wFgxLOLMdlmBLRbnxsAlFV4BHd1jf4xtx44gbOe+hE7ikMzN6AFC3sm5thz5KR3Wf2bH9lNPyXxvpJT2LjvuC3f7CqX7TjJicER9o00hP2O4lJMfTfHp03vYaiM3gX0Hx56lIdY01V/biWnnBeM31vieaP6YlWhbp/cvaH3vlHDwp6JOUYoEoXd+tFqn21G5gdZPOxWPCzMqDHQ9pyg9uLQe5b8aUA7W8fVemaMfmEJVhQc8WnTM/esUZlRGibZE/bqCeK1e0rwyLe5ul4+xSc8Ub2L8qxF9cpU1dTio+W7/Wz26jcTJ5yq1H9gac0XhRoW9kxMIycPk3n7151+Oej3HDmJDQ79oqtq3fbG8RX2esLQonONl2cll04zCo9qzwscV7ms2nXFVNv8L3zlV7z3+y6/ojAysqZ8rerNw4zXl2zHA19vwOc5vlp4WoPA51aM5me27GdhzzARx6SXlvqsj3h2MS54ealO7zpOf2SenznCyI7rBPVEpJ4Zx6p3jczPW61F0Cona5WotdpaIXD9eznIUb0Z6KGnWS/bcdjS/laRbeay6UXG7n06dqoKX632fWAcLtM3BVVUh35CVg0LeybieG5+HuZv3B+047dpop20TMZKgI+W5lpWWYN3fyvwaat2WbP/RFWYRe/wahu6kkADoLR4YcFWn/XiExVYsOkArn/PXPP+bt0+fLNW23//pI5ppOCw/0RnTa3Aw9/k+szJKDl4vBxfSf7tB1R+9XbnPvo+9gPu/GwdNu6re+NTH1NJhc7DbF/JqaDcDy1Y2DMRxyuLt+PG91f5tdfUCld+GOkNkxztpzy3nvfI9Ll5PutVLmv2TVTmBr0UBvM36tuyQyFb5DQSRy24VN728RpsLtLOpaPnovjYd5v82tbsPor3l+3C3z/Vji6+Q9HevplvSuqaWoH8g6XImjYby228TWzcWzfuHq0b6/ar0LiO938vwNDpi3Drx2ssny8QWNgz9YbOD8zBzR+sNu9owiYdwWIHPWHfv0O6z7rbrpcNk6xN0KrZUHgMD3/jmfB0w4PnyhnLMeCJBbrbW0spnxPsRlepsOOPLkft1kjX98rifOQfrLOV/7a9TojLkawyNULgmTmbAQCv/7xD9xzPzN2MborMnD9trYvLGNalhe5+WuUpP1qxR3MswcJU2BNRChGtIKJ1RLSRiB6T2jsS0XIi2kZEnxJRktSeLK3nS9uzgnsJTCwxL0DzTiBeF0rBeqqyRtNdsEtGI591ZWRpba0wfNW3NgZfQW1Vbl/y+m94f9kulFZUY6mNOAE9ftl2CEfKKnVt0bLg7dgiNaDznLKRIE1+rtTWCuwrOYXn5m/BJa/9ptlXbbOvqRVYmOcR3Ivy9APrXl+ywyff0dDOdQLeaJ5kh0buncY23VMDxYpmXwFgtBCiL4B+ACYQ0WAA/wTwohCiK4CjAKZK/acCOCqE6ALgRakfw0QE3wSQk0QpaEe/sASXvv67Rh/fdeXE3yuL8zHo6YW6NmUr2H1Y1dQKHC+v8nq1LNx8EFe/s9Lx+dVco3Msed7DKDLXCqcqrV+vLGxrhMDlby7z7C89kI+
},
"metadata": {
"needs_background": "light"
}
}
],
"source": [
"#code block 12"
]
},
{
"source": [
"## Varying Hyperparameters and Seeing the Result in Action\n",
"\n",
"Now it would be interesting to actually see how the trained model behaves. Let's run the simulation, and we will be following the same action selection strategy as during training: sampling according to the probability distribution in Q-Table: "
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"# code block 13"
]
},
{
"source": [
"\n",
"## Saving result to an animated GIF\n",
"\n",
"If you want to impress your friends, you may want to send them the animated GIF picture of the balancing pole. To do this, we can invoke `env.render` to produce an image frame, and then save those to animated GIF using PIL library:"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"360\n"
]
}
],
"source": [
"from PIL import Image\n",
"obs = env.reset()\n",
"done = False\n",
"i=0\n",
"ims = []\n",
"while not done:\n",
" s = discretize(obs)\n",
" img=env.render(mode='rgb_array')\n",
" ims.append(Image.fromarray(img))\n",
" v = probs(np.array([Qbest.get((s,a),0) for a in actions]))\n",
" a = random.choices(actions,weights=v)[0]\n",
" obs,_,done,_ = env.step(a)\n",
" i+=1\n",
"env.close()\n",
"ims[0].save('images/cartpole-balance.gif',save_all=True,append_images=ims[1::2],loop=0,duration=5)\n",
"print(i)"
]
}
]
}