ML-For-Beginners/8-Reinforcement/1-QLearning/solution/notebook.ipynb


{
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
},
"orig_nbformat": 2,
"kernelspec": {
"name": "python3",
"display_name": "Python 3.7.0 64-bit ('3.7')"
},
"interpreter": {
"hash": "70b38d7a306a849643e446cd70466270a13445e5987dfa1344ef2b127438fa4d"
}
},
"nbformat": 4,
"nbformat_minor": 2,
"cells": [
{
"source": [
"# Peter and the Wolf: Reinforcement Learning Primer\n",
"\n",
"In this tutorial, we will learn how to apply reinforcement learning to a path-finding problem. The setting is inspired by the musical fairy tale [Peter and the Wolf](https://en.wikipedia.org/wiki/Peter_and_the_Wolf) by the Russian composer [Sergei Prokofiev](https://en.wikipedia.org/wiki/Sergei_Prokofiev). It is a story about the young pioneer Peter, who bravely leaves his house for the forest clearing to chase the wolf. We will train machine learning algorithms that help Peter explore the surrounding area and build an optimal navigation map.\n",
"\n",
"First, let's import a few useful libraries:"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import random\n",
"import math"
]
},
{
"source": [
"## Overview of Reinforcement Learning\n",
"\n",
"**Reinforcement Learning** (RL) is a learning technique that allows us to learn the optimal behaviour of an **agent** in some **environment** by running many experiments. An agent in this environment should have some **goal**, defined by a **reward function**.\n",
"\n",
"## The Environment\n",
"\n",
"For simplicity, let's consider Peter's world to be a square board of size `width` x `height`. Each cell in this board can be one of the following:\n",
"* **ground**, on which Peter and other creatures can walk\n",
"* **water**, on which you obviously cannot walk\n",
"* **a tree** or **grass** - a place where you can take some rest\n",
"* **an apple**, which represents something Peter would be glad to find in order to feed himself\n",
"* **a wolf**, which is dangerous and should be avoided\n",
"\n",
"To work with the environment, we will define a class called `Board`. In order not to clutter this notebook too much, we have moved all the code for working with the board into a separate `rlboard` module, which we will now import. You may look inside this module to get more details about the internals of the implementation."
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from rlboard import *"
]
},
{
"source": [
"Let's now create a random board and see how it looks:"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 792x432 with 1 Axes>"
},
"metadata": {
"needs_background": "light"
}
}
],
"source": [
"width, height = 8,8\n",
"m = Board(width,height)\n",
"m.randomize(seed=13)\n",
"m.plot()"
]
},
{
"source": [
"## Actions and Policy\n",
"\n",
"In our example, Peter's goal is to find an apple while avoiding the wolf and other obstacles. To do this, he can essentially walk around until he finds an apple. Therefore, at any position he can choose one of the following actions: up, down, left and right. We will define those actions as a dictionary, and map them to pairs of corresponding coordinate changes. For example, moving right (`R`) would correspond to the pair `(1,0)`."
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"actions = { \"U\" : (0,-1), \"D\" : (0,1), \"L\" : (-1,0), \"R\" : (1,0) }\n",
"action_idx = { a : i for i,a in enumerate(actions.keys()) }"
]
},
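{
"source": [
"As a quick illustration of how this dictionary can be used, the sketch below applies a sequence of actions to a position. It is self-contained, so it re-declares `actions`; the `step` helper and the starting position are hypothetical names and values chosen for this example and are not part of `rlboard`."
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch: same action table as above, re-declared so the cell runs on its own\n",
"actions = { \"U\" : (0,-1), \"D\" : (0,1), \"L\" : (-1,0), \"R\" : (1,0) }\n",
"\n",
"def step(pos, action):\n",
"    # Hypothetical helper: add the action's (dx, dy) change to the current position\n",
"    dx, dy = actions[action]\n",
"    return (pos[0] + dx, pos[1] + dy)\n",
"\n",
"pos = (3, 3)\n",
"for a in [\"R\", \"R\", \"U\"]:\n",
"    pos = step(pos, a)\n",
"print(pos)  # two steps right, then one step up (y decreases)"
]
},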
{
"source": [
"The strategy of our agent (Peter) is defined by a so-called **policy**. Let's consider the simplest policy called **random walk**.\n",
"\n",
"## Random walk\n",
"\n",
"Let's first solve our problem by implementing a random walk strategy."
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"tags": []
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"18"
]
},
"metadata": {},
"execution_count": 5
}
],
"source": [
"def random_policy(m):\n",
"    return random.choice(list(actions))\n",
"\n",
"def walk(m,policy,start_position=None):\n",
"    n = 0 # number of steps\n",
"    # set initial position\n",
"    if start_position:\n",
"        m.human = start_position\n",
"    else:\n",
"        m.random_start()\n",
"    while True:\n",
"        if m.at() == Board.Cell.apple:\n",
"            return n # success!\n",
"        if m.at() in [Board.Cell.wolf, Board.Cell.water]:\n",
"            return -1 # eaten by wolf or drowned\n",
"        while True:\n",
"            a = actions[policy(m)]\n",
"            new_pos = m.move_pos(m.human,a)\n",
"            if m.is_valid(new_pos) and m.at(new_pos)!=Board.Cell.water:\n",
"                m.move(a) # do the actual move\n",
"                break\n",
"        n+=1\n",
"\n",
"walk(m,random_policy)"
]
},
{
"source": [
"Let's run the random walk experiment 100 times and see the average number of steps taken:"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Average path length = 32.87096774193548, eaten by wolf: 7 times\n"
]
}
],
"source": [
"def print_statistics(policy):\n",
"    s,w,n = 0,0,0\n",
"    for _ in range(100):\n",
"        z = walk(m,policy)\n",
"        if z<0:\n",
"            w+=1\n",
"        else:\n",
"            s += z\n",
"            n += 1\n",
"    print(f\"Average path length = {s/n}, eaten by wolf: {w} times\")\n",
"\n",
"print_statistics(random_policy)"
]
},
{
"source": [
"## Reward Function\n",
"\n",
"To make our policy more intelligent, we need to understand which moves are \"better\" than others.\n",
"\n"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"move_reward = -0.1\n",
"goal_reward = 10\n",
"end_reward = -10\n",
"\n",
"def reward(m,pos=None):\n",
"    pos = pos or m.human\n",
"    if not m.is_valid(pos):\n",
"        return end_reward\n",
"    x = m.at(pos)\n",
"    if x==Board.Cell.water or x == Board.Cell.wolf:\n",
"        return end_reward\n",
"    if x==Board.Cell.apple:\n",
"        return goal_reward\n",
"    return move_reward"
]
},
{
"source": [
"## Q-Learning\n",
"\n",
"Build a Q-Table, or multi-dimensional array. Since our board has dimensions `width` x `height`, we can represent the Q-Table by a NumPy array with shape `width` x `height` x `len(actions)`:"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"Q = np.ones((width,height,len(actions)),dtype=float)*1.0/len(actions)"
]
},
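{
"source": [
"A small self-contained sketch of how such a table is indexed may help here. The 8 x 8 x 4 shape mirrors the board size and the number of actions used above; the name `Q_demo` and the cell `(2, 3)` are arbitrary choices for illustration only, so the real `Q` is left untouched."
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Illustrative copy of the initialisation above: every action starts equally attractive\n",
"Q_demo = np.ones((8, 8, 4), dtype=float) / 4\n",
"\n",
"print(Q_demo.shape)        # one value per (x, y, action) triple\n",
"print(Q_demo[2, 3])        # the four action values stored for cell (2, 3)\n",
"print(Q_demo[2, 3].sum())  # the values of one cell sum to 1, so they can be read as probabilities"
]
},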
{
"source": [
"Pass the Q-Table to the plot function in order to visualize the table on the board:"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 792x432 with 1 Axes>"
},
"metadata": {
"needs_background": "light"
}
}
],
"source": [
"m.plot(Q)"
]
},
{
"source": [
"## Essence of Q-Learning: Bellman Equation and Learning Algorithm\n",
"\n",
"Write pseudo-code for our learning algorithm:\n",
"\n",
"* Initialize Q-Table Q with equal numbers for all states and actions\n",
"* Set learning rate $\\alpha\\leftarrow 1$\n",
"* Repeat simulation many times\n",
"    1. Start at random position\n",
"    1. Repeat\n",
"        1. Select an action $a$ at state $s$\n",
"        2. Execute action by moving to a new state $s'$\n",
"        3. If we encounter end-of-game condition, or total reward is too small - exit simulation\n",
"        4. Compute reward $r$ at the new state\n",
"        5. Update Q-Function according to Bellman equation: $Q(s,a)\\leftarrow (1-\\alpha)Q(s,a)+\\alpha(r+\\gamma\\max_{a'}Q(s',a'))$\n",
"        6. $s\\leftarrow s'$\n",
"        7. Update total reward and decrease $\\alpha$.\n",
"\n",
"## Exploit vs. Explore\n",
"\n",
"The best approach is to balance exploration and exploitation. As we learn more about our environment, we become more likely to follow the optimal route, while still choosing an unexplored path once in a while.\n",
"\n",
"## Python Implementation\n",
"\n",
"Now we are ready to implement the learning algorithm. Before that, we also need a function that converts arbitrary numbers in the Q-Table into a vector of probabilities for the corresponding actions:"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def probs(v,eps=1e-4):\n",
"    v = v-v.min()+eps\n",
"    v = v/v.sum()\n",
"    return v"
]
},
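{
"source": [
"To see what this conversion does in practice, here is a short self-contained sketch. It repeats the function so the cell runs on its own, and the two sample vectors are made up for illustration: one matches the uniform initial state of the Q-Table, the other stands in for values after some training."
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def probs(v, eps=1e-4):\n",
"    # Same logic as above: shift so the minimum becomes eps, then normalise to sum 1\n",
"    v = v - v.min() + eps\n",
"    return v / v.sum()\n",
"\n",
"uniform = np.array([0.25, 0.25, 0.25, 0.25])  # a row of the freshly initialised Q-Table\n",
"learned = np.array([-1.0, 0.5, 2.0, 0.5])     # made-up values, for illustration only\n",
"\n",
"print(probs(uniform))  # all entries equal, so every action is equally likely\n",
"print(probs(learned))  # the largest value gets the largest probability"
]
},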
{
"source": [
"We add a small amount `eps` to the original vector in order to avoid division by 0 in the initial case, when all components of the vector are identical.\n",
"\n",
"We will run the actual learning algorithm for 10000 experiments, also called **epochs**: "
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
""
]
}
],
"source": [
"\n",
"from IPython.display import clear_output\n",
"\n",
"lpath = []\n",
"\n",
"for epoch in range(10000):\n",
"    clear_output(wait=True)\n",
"    print(f\"Epoch = {epoch}\",end='')\n",
"\n",
"    # Pick initial point\n",
"    m.random_start()\n",
"\n",
"    # Start travelling\n",
"    n=0\n",
"    cum_reward = 0\n",
"    while True:\n",
"        x,y = m.human\n",
"        v = probs(Q[x,y])\n",
"        a = random.choices(list(actions),weights=v)[0]\n",
"        dpos = actions[a]\n",
"        m.move(dpos,check_correctness=False) # we allow player to move outside the board, which terminates episode\n",
"        r = reward(m)\n",
"        cum_reward += r\n",
"        if r==end_reward or cum_reward < -1000:\n",
"            print(f\" {n} steps\",end='\\r')\n",
"            lpath.append(n)\n",
"            break\n",
"        alpha = np.exp(-n / 3000)\n",
"        gamma = 0.5\n",
"        ai = action_idx[a]\n",
"        Q[x,y,ai] = (1 - alpha) * Q[x,y,ai] + alpha * (r + gamma * Q[x+dpos[0], y+dpos[1]].max())\n",
"        n+=1"
]
},
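{
"source": [
"To make the update inside the loop concrete, the sketch below performs a single Bellman update by hand. All numbers (the starting estimate, the learning rate and the values at the next state) are illustrative assumptions, not values taken from the training run above. With these numbers the target `r + gamma * max` equals 0.9, and because the learning rate is 0.5 the estimate moves halfway from 0.25 toward it."
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"alpha, gamma = 0.5, 0.5   # illustrative learning rate and discount factor\n",
"r = -0.1                  # reward for an ordinary move, same value as move_reward\n",
"\n",
"q_sa = 0.25                                 # current estimate Q(s, a)\n",
"q_next = np.array([0.25, 0.25, 2.0, 0.25])  # made-up action values Q(s', a') at the new state\n",
"\n",
"# Q(s,a) <- (1 - alpha) * Q(s,a) + alpha * (r + gamma * max over a' of Q(s',a'))\n",
"q_new = (1 - alpha) * q_sa + alpha * (r + gamma * q_next.max())\n",
"print(q_new)"
]
},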
{
"source": [
"After executing this algorithm, the Q-Table should be updated with values that define the attractiveness of different actions at each step. Visualize the table here:"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 792x432 with 1 Axes>"
JFi5i4OUkwGOTDDz9k2LBhRkeJWQ2qcAO0aNGCBQsWMGbMGHJzc42OozQAS5cu5ZJLLjE6Rp3x4IMP8tFHH8XsAIB
},
"metadata": {
"needs_background": "light"
}
}
],
"source": [
"m.plot(Q)"
]
},
{
"source": [
"## Checking the Policy\n",
"\n",
"Since the Q-Table lists the \"attractiveness\" of each action at each state, it is easy to use it to navigate our world efficiently. In the simplest case, we can select the action with the highest Q-Table value:"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"2"
]
},
"metadata": {},
"execution_count": 13
}
],
"source": [
"def qpolicy_strict(m):\n",
" x,y = m.human\n",
" v = probs(Q[x,y])\n",
" a = list(actions)[np.argmax(v)]\n",
" return a\n",
"\n",
"walk(m,qpolicy_strict)"
]
},
{
"source": [
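"If you try the code above several times, you may notice that sometimes it just \"hangs\", and you need to press the STOP button in the notebook to interrupt it. This can happen when two states \"point\" to each other in terms of the highest Q-Table value, so the agent keeps moving back and forth between them indefinitely.\n",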
"\n",
"> **Task 1:** Modify the `walk` function to limit the maximum path length to a certain number of steps (say, 100), and watch the code above return this value from time to time.\n",
"\n",
"> **Task 2:** Modify the `walk` function so that it does not go back to places where it has already been. This will prevent `walk` from looping; however, the agent can still end up \"trapped\" in a location from which it is unable to escape."
],
"cell_type": "markdown",
"metadata": {}
},
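{
"source": [
"Both tasks can be sketched on a toy example. The cell below is a minimal, self-contained illustration and not the notebook's actual `walk` function: the one-dimensional corridor, `toy_policy`, `toy_walk`, `max_steps` and `avoid_revisits` are hypothetical names introduced here. The policy is deliberately built so that two cells point at each other, which is one way a greedy walk can loop forever."
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical toy example (not the notebook's Board or walk):\n",
"# a corridor of cells 0..4 with the goal at cell 4.\n",
"TOY_GOAL = 4\n",
"\n",
"# A deliberately bad greedy policy: cell 1 points right and cell 2 points left,\n",
"# so an agent that follows it bounces between cells 1 and 2 forever.\n",
"toy_policy = {0: +1, 1: +1, 2: -1, 3: +1}\n",
"\n",
"def toy_walk(start, max_steps=100, avoid_revisits=False):\n",
"    # Follow toy_policy from `start` and return the number of steps taken.\n",
"    # Task 1: stop after `max_steps` steps and return `max_steps`.\n",
"    # Task 2: with `avoid_revisits`, never step onto a visited cell and\n",
"    #         return -1 if the agent is trapped with no unvisited neighbour.\n",
"    pos, n = start, 0\n",
"    visited = {start}\n",
"    while pos != TOY_GOAL:\n",
"        if n >= max_steps:\n",
"            return max_steps  # Task 1: give up instead of hanging\n",
"        step = toy_policy[pos]\n",
"        if avoid_revisits and pos + step in visited:\n",
"            # Try the opposite direction before declaring the agent trapped.\n",
"            step = -step\n",
"            if pos + step in visited or not 0 <= pos + step <= TOY_GOAL:\n",
"                return -1\n",
"        pos += step\n",
"        visited.add(pos)\n",
"        n += 1\n",
"    return n\n",
"\n",
"print(toy_walk(0))                       # the plain greedy walk hits the step limit\n",
"print(toy_walk(0, avoid_revisits=True))  # skipping visited cells reaches the goal"
]
},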
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Average path length = 3.45, eaten by wolf: 0 times\n"
]
}
],
"source": [
"\n",
"def qpolicy(m):\n",
" x,y = m.human\n",
" v = probs(Q[x,y])\n",
" a = random.choices(list(actions),weights=v)[0]\n",
" return a\n",
"\n",
"print_statistics(qpolicy)"
]
},
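{
"source": [
"The `qpolicy` cell above samples an action in proportion to its Q-Table weight instead of always taking the best one. The sketch below contrasts the two choices on a single made-up row of Q-values. `toy_actions`, `toy_q_row` and `toy_probs` are hypothetical stand-ins: `toy_probs` only assumes that the notebook's `probs` helper turns Q-values into positive weights that sum to one, and it may differ in detail from the actual definition."
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"import numpy as np\n",
"\n",
"toy_actions = ['U', 'D', 'L', 'R']               # hypothetical action names\n",
"toy_q_row = np.array([0.10, 0.50, 0.05, 0.35])   # made-up Q-values for one state\n",
"\n",
"def toy_probs(v, eps=1e-4):\n",
"    # Shift the values so they are all positive, then normalise them to sum to 1.\n",
"    v = v - v.min() + eps\n",
"    return v / v.sum()\n",
"\n",
"toy_weights = toy_probs(toy_q_row)\n",
"\n",
"# Greedy: always the same action for this state.\n",
"toy_greedy = toy_actions[np.argmax(toy_weights)]\n",
"\n",
"# Sampled: usually the best action, but occasionally another one,\n",
"# which lets the agent leave a pair of states that point at each other.\n",
"random.seed(0)\n",
"toy_sampled = random.choices(toy_actions, weights=toy_weights, k=1000)\n",
"\n",
"print(toy_greedy)\n",
"print({a: toy_sampled.count(a) for a in toy_actions})"
]
},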
{
"source": [
"## Investigating the Learning Process"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7fbab852fd68>]"
]
},
"metadata": {},
"execution_count": 15
},
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 432x288 with 1 Axes>"
},
"metadata": {
"needs_background": "light"
}
}
],
"source": [
"plt.plot(lpath)"
]
},
{
"source": [
"What we see here is that at first the average path length increased. This is probably because when we know nothing about the environment, we are likely to get trapped in bad states (water or the wolf) early. As we learn more and start using this knowledge, we can explore the environment for longer, but we still do not know well where the apples are.\n",
"\n",
"Once we learn enough, it becomes easier for the agent to achieve the goal, and the path length starts to decrease. However, we are still open to exploration, so we often diverge from the best path and explore new options, making the path longer than optimal.\n",
"\n",
"What we also observe on this graph is that at some point the length increased abruptly. This indicates the stochastic nature of the process, and that we can at some point \"spoil\" the Q-Table coefficients by overwriting them with new values. Ideally, this should be minimized by decreasing the learning rate (i.e. towards the end of training we only adjust Q-Table values by a small amount).\n",
"\n",
"Overall, it is important to remember that the success and quality of the learning process depend significantly on parameters such as the learning rate, learning rate decay and discount factor. Those are often called **hyperparameters**, to distinguish them from **parameters**, which we optimize during training (e.g. Q-Table coefficients). The process of finding the best hyperparameter values is called **hyperparameter optimization**, and it deserves a separate topic."
],
"cell_type": "markdown",
"metadata": {}
},
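{
"source": [
"The effect of a decaying learning rate can be illustrated with a small calculation. The cell below is a sketch under assumptions: the exponential schedule, the constants, and the names `toy_alpha`, `toy_gamma` and `toy_update` are illustrative choices, not the values used to train the Q-Table above. It applies the standard Q-Learning update to a single value and shows that the same unexpected penalty moves the value much less late in training than at the start."
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"toy_gamma = 0.5       # assumed discount factor\n",
"toy_epochs = 10000    # assumed number of training epochs\n",
"\n",
"def toy_alpha(epoch, decay=2000.0):\n",
"    # Assumed schedule: the learning rate decays exponentially from 1 towards 0.\n",
"    return math.exp(-epoch / decay)\n",
"\n",
"def toy_update(q_old, reward, q_next_max, alpha):\n",
"    # Standard Q-Learning update:\n",
"    # Q <- (1 - alpha) * Q + alpha * (reward + gamma * max Q of the next state)\n",
"    return (1 - alpha) * q_old + alpha * (reward + toy_gamma * q_next_max)\n",
"\n",
"# A well-learned value of 0.8 is hit by an unexpected penalty of -10.\n",
"toy_q_old, toy_penalty, toy_q_next_max = 0.8, -10.0, 0.0\n",
"for toy_epoch in (0, toy_epochs // 2, toy_epochs - 1):\n",
"    a_now = toy_alpha(toy_epoch)\n",
"    toy_q_new = toy_update(toy_q_old, toy_penalty, toy_q_next_max, a_now)\n",
"    print(f'epoch {toy_epoch:5d}: alpha = {a_now:.4f}, Q: {toy_q_old:.2f} -> {toy_q_new:.2f}')"
]
},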
{
"source": [
"## Exercise\n",
"#### A More Realistic Peter and the Wolf World\n",
"\n",
"In our situation, Peter was able to move around almost without getting tired or hungry. In a more realistic world, he has to sit down and rest from time to time, and also to feed himself. Let's make our world more realistic by implementing the following rules:\n",
"\n",
"1. By moving from one place to another, Peter loses **energy** and gains some **fatigue**.\n",
"2. Peter can gain more energy by eating apples.\n",
"3. Peter can get rid of fatigue by resting under a tree or on the grass (i.e. walking into a board location with a tree or grass, shown as a green field).\n",
"4. Peter needs to find and kill the wolf.\n",
"5. In order to kill the wolf, Peter needs to have certain levels of energy and fatigue, otherwise he loses the battle.\n",
"\n",
"Modify the reward function above according to the rules of the game, run the reinforcement learning algorithm to learn the best strategy for winning the game, and compare the results of random walk with your algorithm in terms of number of games won and lost.\n",
"\n",
"\n",
"> **Note**: You may need to adjust hyperparameters to make it work, especially the number of epochs. Because the success of the game (fighting the wolf) is a rare event, you can expect much longer training time.\n",
"\n"
],
"cell_type": "markdown",
"metadata": {}
}
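,
{
"source": [
"As a possible starting point for the exercise, the cell below sketches a reward function that follows the five rules. Everything in it is an assumption made for illustration: the `ToyPeter` class, the `toy_step_reward` function, the cell names passed as strings, and all numeric constants are hypothetical and would need tuning. Note that energy and fatigue would also have to become part of the state that the Q-Table is indexed by, which this sketch does not cover."
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from dataclasses import dataclass\n",
"\n",
"# All names and numbers below are assumptions made for illustration.\n",
"MOVE_ENERGY_COST = 1.0\n",
"MOVE_FATIGUE_GAIN = 1.0\n",
"APPLE_ENERGY = 5.0\n",
"REST_FATIGUE_RELIEF = 3.0\n",
"MIN_ENERGY_TO_FIGHT = 8.0\n",
"MAX_FATIGUE_TO_FIGHT = 4.0\n",
"\n",
"@dataclass\n",
"class ToyPeter:\n",
"    energy: float = 10.0\n",
"    fatigue: float = 0.0\n",
"\n",
"def toy_step_reward(peter, cell):\n",
"    # Rule 1: every move costs energy and adds fatigue.\n",
"    peter.energy -= MOVE_ENERGY_COST\n",
"    peter.fatigue += MOVE_FATIGUE_GAIN\n",
"    if cell == 'apple':\n",
"        # Rule 2: apples restore energy.\n",
"        peter.energy += APPLE_ENERGY\n",
"        return 1.0\n",
"    if cell in ('tree', 'grass'):\n",
"        # Rule 3: resting removes fatigue, but never below zero.\n",
"        peter.fatigue = max(0.0, peter.fatigue - REST_FATIGUE_RELIEF)\n",
"        return 0.0\n",
"    if cell == 'wolf':\n",
"        # Rules 4 and 5: the fight is won only by a fed and rested Peter.\n",
"        fit = (peter.energy >= MIN_ENERGY_TO_FIGHT\n",
"               and peter.fatigue <= MAX_FATIGUE_TO_FIGHT)\n",
"        return 10.0 if fit else -10.0\n",
"    if peter.energy <= 0:\n",
"        return -10.0  # assumed: running out of energy also loses the game\n",
"    return -0.1  # small cost for an ordinary move\n",
"\n",
"toy_rested = ToyPeter(energy=10.0, fatigue=0.0)\n",
"toy_tired = ToyPeter(energy=3.0, fatigue=9.0)\n",
"print(toy_step_reward(toy_rested, 'wolf'))  # rested Peter wins the fight\n",
"print(toy_step_reward(toy_tired, 'wolf'))   # tired Peter loses it"
]
}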
]
}