Upload ddpg_pendulum.ipynb
Browse files- ddpg_pendulum.ipynb +1077 -0
ddpg_pendulum.ipynb
ADDED
|
@@ -0,0 +1,1077 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {
|
| 6 |
+
"id": "W5ut4-Uo_wFL"
|
| 7 |
+
},
|
| 8 |
+
"source": [
|
| 9 |
+
"# Deep Deterministic Policy Gradient (DDPG)\n",
|
| 10 |
+
"\n",
|
| 11 |
+
"**Author:** [amifunny](https://github.com/amifunny)<br>\n",
|
| 12 |
+
"**Date created:** 2020/06/04<br>\n",
|
| 13 |
+
"**Last modified:** 2020/09/21<br>\n",
|
| 14 |
+
"**Description:** Implementing DDPG algorithm on the Inverted Pendulum Problem."
|
| 15 |
+
]
|
| 16 |
+
},
|
| 17 |
+
{
|
| 18 |
+
"cell_type": "markdown",
|
| 19 |
+
"metadata": {
|
| 20 |
+
"id": "1eX-gAYp_wFP"
|
| 21 |
+
},
|
| 22 |
+
"source": [
|
| 23 |
+
"## Introduction\n",
|
| 24 |
+
"\n",
|
| 25 |
+
"**Deep Deterministic Policy Gradient (DDPG)** is a model-free off-policy algorithm for\n",
|
| 26 |
+
"learning continous actions.\n",
|
| 27 |
+
"\n",
|
| 28 |
+
"It combines ideas from DPG (Deterministic Policy Gradient) and DQN (Deep Q-Network).\n",
|
| 29 |
+
"It uses Experience Replay and slow-learning target networks from DQN, and it is based on\n",
|
| 30 |
+
"DPG,\n",
|
| 31 |
+
"which can operate over continuous action spaces.\n",
|
| 32 |
+
"\n",
|
| 33 |
+
"This tutorial closely follow this paper -\n",
|
| 34 |
+
"[Continuous control with deep reinforcement learning](https://arxiv.org/pdf/1509.02971.pdf)\n",
|
| 35 |
+
"\n",
|
| 36 |
+
"## Problem\n",
|
| 37 |
+
"\n",
|
| 38 |
+
"We are trying to solve the classic **Inverted Pendulum** control problem.\n",
|
| 39 |
+
"In this setting, we can take only two actions: swing left or swing right.\n",
|
| 40 |
+
"\n",
|
| 41 |
+
"What make this problem challenging for Q-Learning Algorithms is that actions\n",
|
| 42 |
+
"are **continuous** instead of being **discrete**. That is, instead of using two\n",
|
| 43 |
+
"discrete actions like `-1` or `+1`, we have to select from infinite actions\n",
|
| 44 |
+
"ranging from `-2` to `+2`.\n",
|
| 45 |
+
"\n",
|
| 46 |
+
"## Quick theory\n",
|
| 47 |
+
"\n",
|
| 48 |
+
"Just like the Actor-Critic method, we have two networks:\n",
|
| 49 |
+
"\n",
|
| 50 |
+
"1. Actor - It proposes an action given a state.\n",
|
| 51 |
+
"2. Critic - It predicts if the action is good (positive value) or bad (negative value)\n",
|
| 52 |
+
"given a state and an action.\n",
|
| 53 |
+
"\n",
|
| 54 |
+
"DDPG uses two more techniques not present in the original DQN:\n",
|
| 55 |
+
"\n",
|
| 56 |
+
"**First, it uses two Target networks.**\n",
|
| 57 |
+
"\n",
|
| 58 |
+
"**Why?** Because it add stability to training. In short, we are learning from estimated\n",
|
| 59 |
+
"targets and Target networks are updated slowly, hence keeping our estimated targets\n",
|
| 60 |
+
"stable.\n",
|
| 61 |
+
"\n",
|
| 62 |
+
"Conceptually, this is like saying, \"I have an idea of how to play this well,\n",
|
| 63 |
+
"I'm going to try it out for a bit until I find something better\",\n",
|
| 64 |
+
"as opposed to saying \"I'm going to re-learn how to play this entire game after every\n",
|
| 65 |
+
"move\".\n",
|
| 66 |
+
"See this [StackOverflow answer](https://stackoverflow.com/a/54238556/13475679).\n",
|
| 67 |
+
"\n",
|
| 68 |
+
"**Second, it uses Experience Replay.**\n",
|
| 69 |
+
"\n",
|
| 70 |
+
"We store list of tuples `(state, action, reward, next_state)`, and instead of\n",
|
| 71 |
+
"learning only from recent experience, we learn from sampling all of our experience\n",
|
| 72 |
+
"accumulated so far.\n",
|
| 73 |
+
"\n",
|
| 74 |
+
"Now, let's see how is it implemented."
|
| 75 |
+
]
|
| 76 |
+
},
|
| 77 |
+
{
|
| 78 |
+
"cell_type": "code",
|
| 79 |
+
"execution_count": 1,
|
| 80 |
+
"metadata": {
|
| 81 |
+
"id": "EhtEA5C1_wFR"
|
| 82 |
+
},
|
| 83 |
+
"outputs": [],
|
| 84 |
+
"source": [
|
| 85 |
+
"import gym\n",
|
| 86 |
+
"import tensorflow as tf\n",
|
| 87 |
+
"from tensorflow.keras import layers\n",
|
| 88 |
+
"import numpy as np\n",
|
| 89 |
+
"import matplotlib.pyplot as plt"
|
| 90 |
+
]
|
| 91 |
+
},
|
| 92 |
+
{
|
| 93 |
+
"cell_type": "markdown",
|
| 94 |
+
"metadata": {
|
| 95 |
+
"id": "vvhqTnJ8_wFT"
|
| 96 |
+
},
|
| 97 |
+
"source": [
|
| 98 |
+
"We use [OpenAIGym](http://gym.openai.com/docs) to create the environment.\n",
|
| 99 |
+
"We will use the `upper_bound` parameter to scale our actions later."
|
| 100 |
+
]
|
| 101 |
+
},
|
| 102 |
+
{
|
| 103 |
+
"cell_type": "code",
|
| 104 |
+
"execution_count": 2,
|
| 105 |
+
"metadata": {
|
| 106 |
+
"id": "6limWVE-_wFU",
|
| 107 |
+
"outputId": "8d672186-664b-40c5-ce82-e3450bf42221",
|
| 108 |
+
"colab": {
|
| 109 |
+
"base_uri": "https://localhost:8080/"
|
| 110 |
+
}
|
| 111 |
+
},
|
| 112 |
+
"outputs": [
|
| 113 |
+
{
|
| 114 |
+
"output_type": "stream",
|
| 115 |
+
"name": "stdout",
|
| 116 |
+
"text": [
|
| 117 |
+
"Size of State Space -> 3\n",
|
| 118 |
+
"Size of Action Space -> 1\n",
|
| 119 |
+
"Max Value of Action -> 2.0\n",
|
| 120 |
+
"Min Value of Action -> -2.0\n"
|
| 121 |
+
]
|
| 122 |
+
}
|
| 123 |
+
],
|
| 124 |
+
"source": [
|
| 125 |
+
"problem = \"Pendulum-v0\"\n",
|
| 126 |
+
"env = gym.make(problem)\n",
|
| 127 |
+
"\n",
|
| 128 |
+
"num_states = env.observation_space.shape[0]\n",
|
| 129 |
+
"print(\"Size of State Space -> {}\".format(num_states))\n",
|
| 130 |
+
"num_actions = env.action_space.shape[0]\n",
|
| 131 |
+
"print(\"Size of Action Space -> {}\".format(num_actions))\n",
|
| 132 |
+
"\n",
|
| 133 |
+
"upper_bound = env.action_space.high[0]\n",
|
| 134 |
+
"lower_bound = env.action_space.low[0]\n",
|
| 135 |
+
"\n",
|
| 136 |
+
"print(\"Max Value of Action -> {}\".format(upper_bound))\n",
|
| 137 |
+
"print(\"Min Value of Action -> {}\".format(lower_bound))"
|
| 138 |
+
]
|
| 139 |
+
},
|
| 140 |
+
{
|
| 141 |
+
"cell_type": "markdown",
|
| 142 |
+
"metadata": {
|
| 143 |
+
"id": "SxQKZi35_wFU"
|
| 144 |
+
},
|
| 145 |
+
"source": [
|
| 146 |
+
"To implement better exploration by the Actor network, we use noisy perturbations,\n",
|
| 147 |
+
"specifically\n",
|
| 148 |
+
"an **Ornstein-Uhlenbeck process** for generating noise, as described in the paper.\n",
|
| 149 |
+
"It samples noise from a correlated normal distribution."
|
| 150 |
+
]
|
| 151 |
+
},
|
| 152 |
+
{
|
| 153 |
+
"cell_type": "code",
|
| 154 |
+
"execution_count": 3,
|
| 155 |
+
"metadata": {
|
| 156 |
+
"id": "0u9tVI2J_wFV"
|
| 157 |
+
},
|
| 158 |
+
"outputs": [],
|
| 159 |
+
"source": [
|
| 160 |
+
"\n",
|
| 161 |
+
"class OUActionNoise:\n",
|
| 162 |
+
" def __init__(self, mean, std_deviation, theta=0.15, dt=1e-2, x_initial=None):\n",
|
| 163 |
+
" self.theta = theta\n",
|
| 164 |
+
" self.mean = mean\n",
|
| 165 |
+
" self.std_dev = std_deviation\n",
|
| 166 |
+
" self.dt = dt\n",
|
| 167 |
+
" self.x_initial = x_initial\n",
|
| 168 |
+
" self.reset()\n",
|
| 169 |
+
"\n",
|
| 170 |
+
" def __call__(self):\n",
|
| 171 |
+
" # Formula taken from https://www.wikipedia.org/wiki/Ornstein-Uhlenbeck_process.\n",
|
| 172 |
+
" x = (\n",
|
| 173 |
+
" self.x_prev\n",
|
| 174 |
+
" + self.theta * (self.mean - self.x_prev) * self.dt\n",
|
| 175 |
+
" + self.std_dev * np.sqrt(self.dt) * np.random.normal(size=self.mean.shape)\n",
|
| 176 |
+
" )\n",
|
| 177 |
+
" # Store x into x_prev\n",
|
| 178 |
+
" # Makes next noise dependent on current one\n",
|
| 179 |
+
" self.x_prev = x\n",
|
| 180 |
+
" return x\n",
|
| 181 |
+
"\n",
|
| 182 |
+
" def reset(self):\n",
|
| 183 |
+
" if self.x_initial is not None:\n",
|
| 184 |
+
" self.x_prev = self.x_initial\n",
|
| 185 |
+
" else:\n",
|
| 186 |
+
" self.x_prev = np.zeros_like(self.mean)\n"
|
| 187 |
+
]
|
| 188 |
+
},
|
| 189 |
+
{
|
| 190 |
+
"cell_type": "markdown",
|
| 191 |
+
"metadata": {
|
| 192 |
+
"id": "aiaIXtYc_wFW"
|
| 193 |
+
},
|
| 194 |
+
"source": [
|
| 195 |
+
"The `Buffer` class implements Experience Replay.\n",
|
| 196 |
+
"\n",
|
| 197 |
+
"---\n",
|
| 198 |
+
"\n",
|
| 199 |
+
"---\n",
|
| 200 |
+
"\n",
|
| 201 |
+
"\n",
|
| 202 |
+
"**Critic loss** - Mean Squared Error of `y - Q(s, a)`\n",
|
| 203 |
+
"where `y` is the expected return as seen by the Target network,\n",
|
| 204 |
+
"and `Q(s, a)` is action value predicted by the Critic network. `y` is a moving target\n",
|
| 205 |
+
"that the critic model tries to achieve; we make this target\n",
|
| 206 |
+
"stable by updating the Target model slowly.\n",
|
| 207 |
+
"\n",
|
| 208 |
+
"**Actor loss** - This is computed using the mean of the value given by the Critic network\n",
|
| 209 |
+
"for the actions taken by the Actor network. We seek to maximize this quantity.\n",
|
| 210 |
+
"\n",
|
| 211 |
+
"Hence we update the Actor network so that it produces actions that get\n",
|
| 212 |
+
"the maximum predicted value as seen by the Critic, for a given state."
|
| 213 |
+
]
|
| 214 |
+
},
|
| 215 |
+
{
|
| 216 |
+
"cell_type": "code",
|
| 217 |
+
"execution_count": 4,
|
| 218 |
+
"metadata": {
|
| 219 |
+
"id": "HmrqnrR3_wFX"
|
| 220 |
+
},
|
| 221 |
+
"outputs": [],
|
| 222 |
+
"source": [
|
| 223 |
+
"\n",
|
| 224 |
+
"class Buffer:\n",
|
| 225 |
+
" def __init__(self, buffer_capacity=100000, batch_size=64):\n",
|
| 226 |
+
" # Number of \"experiences\" to store at max\n",
|
| 227 |
+
" self.buffer_capacity = buffer_capacity\n",
|
| 228 |
+
" # Num of tuples to train on.\n",
|
| 229 |
+
" self.batch_size = batch_size\n",
|
| 230 |
+
"\n",
|
| 231 |
+
" # Its tells us num of times record() was called.\n",
|
| 232 |
+
" self.buffer_counter = 0\n",
|
| 233 |
+
"\n",
|
| 234 |
+
" # Instead of list of tuples as the exp.replay concept go\n",
|
| 235 |
+
" # We use different np.arrays for each tuple element\n",
|
| 236 |
+
" self.state_buffer = np.zeros((self.buffer_capacity, num_states))\n",
|
| 237 |
+
" self.action_buffer = np.zeros((self.buffer_capacity, num_actions))\n",
|
| 238 |
+
" self.reward_buffer = np.zeros((self.buffer_capacity, 1))\n",
|
| 239 |
+
" self.next_state_buffer = np.zeros((self.buffer_capacity, num_states))\n",
|
| 240 |
+
"\n",
|
| 241 |
+
" # Takes (s,a,r,s') obervation tuple as input\n",
|
| 242 |
+
" def record(self, obs_tuple):\n",
|
| 243 |
+
" # Set index to zero if buffer_capacity is exceeded,\n",
|
| 244 |
+
" # replacing old records\n",
|
| 245 |
+
" index = self.buffer_counter % self.buffer_capacity\n",
|
| 246 |
+
"\n",
|
| 247 |
+
" self.state_buffer[index] = obs_tuple[0]\n",
|
| 248 |
+
" self.action_buffer[index] = obs_tuple[1]\n",
|
| 249 |
+
" self.reward_buffer[index] = obs_tuple[2]\n",
|
| 250 |
+
" self.next_state_buffer[index] = obs_tuple[3]\n",
|
| 251 |
+
"\n",
|
| 252 |
+
" self.buffer_counter += 1\n",
|
| 253 |
+
"\n",
|
| 254 |
+
" # Eager execution is turned on by default in TensorFlow 2. Decorating with tf.function allows\n",
|
| 255 |
+
" # TensorFlow to build a static graph out of the logic and computations in our function.\n",
|
| 256 |
+
" # This provides a large speed up for blocks of code that contain many small TensorFlow operations such as this one.\n",
|
| 257 |
+
" @tf.function\n",
|
| 258 |
+
" def update(\n",
|
| 259 |
+
" self, state_batch, action_batch, reward_batch, next_state_batch,\n",
|
| 260 |
+
" ):\n",
|
| 261 |
+
" # Training and updating Actor & Critic networks.\n",
|
| 262 |
+
" # See Pseudo Code.\n",
|
| 263 |
+
" with tf.GradientTape() as tape:\n",
|
| 264 |
+
" target_actions = target_actor(next_state_batch, training=True)\n",
|
| 265 |
+
" y = reward_batch + gamma * target_critic(\n",
|
| 266 |
+
" [next_state_batch, target_actions], training=True\n",
|
| 267 |
+
" )\n",
|
| 268 |
+
" critic_value = critic_model([state_batch, action_batch], training=True)\n",
|
| 269 |
+
" critic_loss = tf.math.reduce_mean(tf.math.square(y - critic_value))\n",
|
| 270 |
+
"\n",
|
| 271 |
+
" critic_grad = tape.gradient(critic_loss, critic_model.trainable_variables)\n",
|
| 272 |
+
" critic_optimizer.apply_gradients(\n",
|
| 273 |
+
" zip(critic_grad, critic_model.trainable_variables)\n",
|
| 274 |
+
" )\n",
|
| 275 |
+
"\n",
|
| 276 |
+
" with tf.GradientTape() as tape:\n",
|
| 277 |
+
" actions = actor_model(state_batch, training=True)\n",
|
| 278 |
+
" critic_value = critic_model([state_batch, actions], training=True)\n",
|
| 279 |
+
" # Used `-value` as we want to maximize the value given\n",
|
| 280 |
+
" # by the critic for our actions\n",
|
| 281 |
+
" actor_loss = -tf.math.reduce_mean(critic_value)\n",
|
| 282 |
+
"\n",
|
| 283 |
+
" actor_grad = tape.gradient(actor_loss, actor_model.trainable_variables)\n",
|
| 284 |
+
" actor_optimizer.apply_gradients(\n",
|
| 285 |
+
" zip(actor_grad, actor_model.trainable_variables)\n",
|
| 286 |
+
" )\n",
|
| 287 |
+
"\n",
|
| 288 |
+
" # We compute the loss and update parameters\n",
|
| 289 |
+
" def learn(self):\n",
|
| 290 |
+
" # Get sampling range\n",
|
| 291 |
+
" record_range = min(self.buffer_counter, self.buffer_capacity)\n",
|
| 292 |
+
" # Randomly sample indices\n",
|
| 293 |
+
" batch_indices = np.random.choice(record_range, self.batch_size)\n",
|
| 294 |
+
"\n",
|
| 295 |
+
" # Convert to tensors\n",
|
| 296 |
+
" state_batch = tf.convert_to_tensor(self.state_buffer[batch_indices])\n",
|
| 297 |
+
" action_batch = tf.convert_to_tensor(self.action_buffer[batch_indices])\n",
|
| 298 |
+
" reward_batch = tf.convert_to_tensor(self.reward_buffer[batch_indices])\n",
|
| 299 |
+
" reward_batch = tf.cast(reward_batch, dtype=tf.float32)\n",
|
| 300 |
+
" next_state_batch = tf.convert_to_tensor(self.next_state_buffer[batch_indices])\n",
|
| 301 |
+
"\n",
|
| 302 |
+
" self.update(state_batch, action_batch, reward_batch, next_state_batch)\n",
|
| 303 |
+
"\n",
|
| 304 |
+
"\n",
|
| 305 |
+
"# This update target parameters slowly\n",
|
| 306 |
+
"# Based on rate `tau`, which is much less than one.\n",
|
| 307 |
+
"@tf.function\n",
|
| 308 |
+
"def update_target(target_weights, weights, tau):\n",
|
| 309 |
+
" for (a, b) in zip(target_weights, weights):\n",
|
| 310 |
+
" a.assign(b * tau + a * (1 - tau))\n"
|
| 311 |
+
]
|
| 312 |
+
},
|
| 313 |
+
{
|
| 314 |
+
"cell_type": "markdown",
|
| 315 |
+
"metadata": {
|
| 316 |
+
"id": "yuatLEJ3_wFY"
|
| 317 |
+
},
|
| 318 |
+
"source": [
|
| 319 |
+
"Here we define the Actor and Critic networks. These are basic Dense models\n",
|
| 320 |
+
"with `ReLU` activation.\n",
|
| 321 |
+
"\n",
|
| 322 |
+
"Note: We need the initialization for last layer of the Actor to be between\n",
|
| 323 |
+
"`-0.003` and `0.003` as this prevents us from getting `1` or `-1` output values in\n",
|
| 324 |
+
"the initial stages, which would squash our gradients to zero,\n",
|
| 325 |
+
"as we use the `tanh` activation."
|
| 326 |
+
]
|
| 327 |
+
},
|
| 328 |
+
{
|
| 329 |
+
"cell_type": "code",
|
| 330 |
+
"execution_count": 5,
|
| 331 |
+
"metadata": {
|
| 332 |
+
"id": "OCCV2VAQ_wFY"
|
| 333 |
+
},
|
| 334 |
+
"outputs": [],
|
| 335 |
+
"source": [
|
| 336 |
+
"\n",
|
| 337 |
+
"def get_actor():\n",
|
| 338 |
+
" # Initialize weights between -3e-3 and 3-e3\n",
|
| 339 |
+
" last_init = tf.random_uniform_initializer(minval=-0.003, maxval=0.003)\n",
|
| 340 |
+
"\n",
|
| 341 |
+
" inputs = layers.Input(shape=(num_states,))\n",
|
| 342 |
+
" out = layers.Dense(256, activation=\"relu\")(inputs)\n",
|
| 343 |
+
" out = layers.Dense(256, activation=\"relu\")(out)\n",
|
| 344 |
+
" outputs = layers.Dense(1, activation=\"tanh\", kernel_initializer=last_init)(out)\n",
|
| 345 |
+
"\n",
|
| 346 |
+
" # Our upper bound is 2.0 for Pendulum.\n",
|
| 347 |
+
" outputs = outputs * upper_bound\n",
|
| 348 |
+
" model = tf.keras.Model(inputs, outputs)\n",
|
| 349 |
+
" return model\n",
|
| 350 |
+
"\n",
|
| 351 |
+
"\n",
|
| 352 |
+
"def get_critic():\n",
|
| 353 |
+
" # State as input\n",
|
| 354 |
+
" state_input = layers.Input(shape=(num_states))\n",
|
| 355 |
+
" state_out = layers.Dense(16, activation=\"relu\")(state_input)\n",
|
| 356 |
+
" state_out = layers.Dense(32, activation=\"relu\")(state_out)\n",
|
| 357 |
+
"\n",
|
| 358 |
+
" # Action as input\n",
|
| 359 |
+
" action_input = layers.Input(shape=(num_actions))\n",
|
| 360 |
+
" action_out = layers.Dense(32, activation=\"relu\")(action_input)\n",
|
| 361 |
+
"\n",
|
| 362 |
+
" # Both are passed through seperate layer before concatenating\n",
|
| 363 |
+
" concat = layers.Concatenate()([state_out, action_out])\n",
|
| 364 |
+
"\n",
|
| 365 |
+
" out = layers.Dense(256, activation=\"relu\")(concat)\n",
|
| 366 |
+
" out = layers.Dense(256, activation=\"relu\")(out)\n",
|
| 367 |
+
" outputs = layers.Dense(1)(out)\n",
|
| 368 |
+
"\n",
|
| 369 |
+
" # Outputs single value for give state-action\n",
|
| 370 |
+
" model = tf.keras.Model([state_input, action_input], outputs)\n",
|
| 371 |
+
"\n",
|
| 372 |
+
" return model\n"
|
| 373 |
+
]
|
| 374 |
+
},
|
| 375 |
+
{
|
| 376 |
+
"cell_type": "markdown",
|
| 377 |
+
"metadata": {
|
| 378 |
+
"id": "gkg29m65_wFZ"
|
| 379 |
+
},
|
| 380 |
+
"source": [
|
| 381 |
+
"`policy()` returns an action sampled from our Actor network plus some noise for\n",
|
| 382 |
+
"exploration."
|
| 383 |
+
]
|
| 384 |
+
},
|
| 385 |
+
{
|
| 386 |
+
"cell_type": "code",
|
| 387 |
+
"execution_count": 6,
|
| 388 |
+
"metadata": {
|
| 389 |
+
"id": "KmHbyy8l_wFZ"
|
| 390 |
+
},
|
| 391 |
+
"outputs": [],
|
| 392 |
+
"source": [
|
| 393 |
+
"\n",
|
| 394 |
+
"def policy(state, noise_object):\n",
|
| 395 |
+
" sampled_actions = tf.squeeze(actor_model(state))\n",
|
| 396 |
+
" noise = noise_object()\n",
|
| 397 |
+
" # Adding noise to action\n",
|
| 398 |
+
" sampled_actions = sampled_actions.numpy() + noise\n",
|
| 399 |
+
"\n",
|
| 400 |
+
" # We make sure action is within bounds\n",
|
| 401 |
+
" legal_action = np.clip(sampled_actions, lower_bound, upper_bound)\n",
|
| 402 |
+
"\n",
|
| 403 |
+
" return [np.squeeze(legal_action)]\n"
|
| 404 |
+
]
|
| 405 |
+
},
|
| 406 |
+
{
|
| 407 |
+
"cell_type": "markdown",
|
| 408 |
+
"metadata": {
|
| 409 |
+
"id": "r2EUVRZA_wFa"
|
| 410 |
+
},
|
| 411 |
+
"source": [
|
| 412 |
+
"## Training hyperparameters"
|
| 413 |
+
]
|
| 414 |
+
},
|
| 415 |
+
{
|
| 416 |
+
"cell_type": "code",
|
| 417 |
+
"execution_count": 7,
|
| 418 |
+
"metadata": {
|
| 419 |
+
"id": "8FELtxWr_wFa"
|
| 420 |
+
},
|
| 421 |
+
"outputs": [],
|
| 422 |
+
"source": [
|
| 423 |
+
"std_dev = 0.2\n",
|
| 424 |
+
"ou_noise = OUActionNoise(mean=np.zeros(1), std_deviation=float(std_dev) * np.ones(1))\n",
|
| 425 |
+
"\n",
|
| 426 |
+
"actor_model = get_actor()\n",
|
| 427 |
+
"critic_model = get_critic()\n",
|
| 428 |
+
"\n",
|
| 429 |
+
"target_actor = get_actor()\n",
|
| 430 |
+
"target_critic = get_critic()\n",
|
| 431 |
+
"\n",
|
| 432 |
+
"# Making the weights equal initially\n",
|
| 433 |
+
"target_actor.set_weights(actor_model.get_weights())\n",
|
| 434 |
+
"target_critic.set_weights(critic_model.get_weights())\n",
|
| 435 |
+
"\n",
|
| 436 |
+
"# Learning rate for actor-critic models\n",
|
| 437 |
+
"critic_lr = 0.002\n",
|
| 438 |
+
"actor_lr = 0.001\n",
|
| 439 |
+
"\n",
|
| 440 |
+
"critic_optimizer = tf.keras.optimizers.Adam(critic_lr)\n",
|
| 441 |
+
"actor_optimizer = tf.keras.optimizers.Adam(actor_lr)\n",
|
| 442 |
+
"\n",
|
| 443 |
+
"total_episodes = 100\n",
|
| 444 |
+
"# Discount factor for future rewards\n",
|
| 445 |
+
"gamma = 0.99\n",
|
| 446 |
+
"# Used to update target networks\n",
|
| 447 |
+
"tau = 0.005\n",
|
| 448 |
+
"\n",
|
| 449 |
+
"buffer = Buffer(50000, 64)"
|
| 450 |
+
]
|
| 451 |
+
},
|
| 452 |
+
{
|
| 453 |
+
"cell_type": "markdown",
|
| 454 |
+
"metadata": {
|
| 455 |
+
"id": "4RDsrs-U_wFa"
|
| 456 |
+
},
|
| 457 |
+
"source": [
|
| 458 |
+
"Now we implement our main training loop, and iterate over episodes.\n",
|
| 459 |
+
"We sample actions using `policy()` and train with `learn()` at each time step,\n",
|
| 460 |
+
"along with updating the Target networks at a rate `tau`."
|
| 461 |
+
]
|
| 462 |
+
},
|
| 463 |
+
{
|
| 464 |
+
"cell_type": "code",
|
| 465 |
+
"execution_count": 8,
|
| 466 |
+
"metadata": {
|
| 467 |
+
"id": "ytHAvwkZ_wFb",
|
| 468 |
+
"outputId": "3bb12bba-bd57-4e17-d456-f3812dbab54c",
|
| 469 |
+
"colab": {
|
| 470 |
+
"base_uri": "https://localhost:8080/",
|
| 471 |
+
"height": 1000
|
| 472 |
+
}
|
| 473 |
+
},
|
| 474 |
+
"outputs": [
|
| 475 |
+
{
|
| 476 |
+
"output_type": "stream",
|
| 477 |
+
"name": "stdout",
|
| 478 |
+
"text": [
|
| 479 |
+
"Episode * 0 * Avg Reward is ==> -1562.2726156328647\n",
|
| 480 |
+
"Episode * 1 * Avg Reward is ==> -1483.9644029814049\n",
|
| 481 |
+
"Episode * 2 * Avg Reward is ==> -1505.4752484518995\n",
|
| 482 |
+
"Episode * 3 * Avg Reward is ==> -1502.139312848798\n",
|
| 483 |
+
"Episode * 4 * Avg Reward is ==> -1499.098591143217\n",
|
| 484 |
+
"Episode * 5 * Avg Reward is ==> -1498.9744436801384\n",
|
| 485 |
+
"Episode * 6 * Avg Reward is ==> -1513.1644515271164\n",
|
| 486 |
+
"Episode * 7 * Avg Reward is ==> -1468.1716524312988\n",
|
| 487 |
+
"Episode * 8 * Avg Reward is ==> -1427.4514428693592\n",
|
| 488 |
+
"Episode * 9 * Avg Reward is ==> -1386.0628554198065\n",
|
| 489 |
+
"Episode * 10 * Avg Reward is ==> -1330.6460640495982\n",
|
| 490 |
+
"Episode * 11 * Avg Reward is ==> -1296.9259472439123\n",
|
| 491 |
+
"Episode * 12 * Avg Reward is ==> -1266.648683865007\n",
|
| 492 |
+
"Episode * 13 * Avg Reward is ==> -1217.2723090173156\n",
|
| 493 |
+
"Episode * 14 * Avg Reward is ==> -1163.0702894847986\n",
|
| 494 |
+
"Episode * 15 * Avg Reward is ==> -1105.9657062118963\n",
|
| 495 |
+
"Episode * 16 * Avg Reward is ==> -1056.4251588131688\n",
|
| 496 |
+
"Episode * 17 * Avg Reward is ==> -1004.7175789706645\n",
|
| 497 |
+
"Episode * 18 * Avg Reward is ==> -958.4439292235802\n",
|
| 498 |
+
"Episode * 19 * Avg Reward is ==> -916.8559819842148\n",
|
| 499 |
+
"Episode * 20 * Avg Reward is ==> -879.2971938851208\n",
|
| 500 |
+
"Episode * 21 * Avg Reward is ==> -839.4309276948444\n",
|
| 501 |
+
"Episode * 22 * Avg Reward is ==> -813.1273589718702\n",
|
| 502 |
+
"Episode * 23 * Avg Reward is ==> -784.5041398737862\n",
|
| 503 |
+
"Episode * 24 * Avg Reward is ==> -765.0508430770639\n",
|
| 504 |
+
"Episode * 25 * Avg Reward is ==> -740.464676744745\n",
|
| 505 |
+
"Episode * 26 * Avg Reward is ==> -721.947957211692\n",
|
| 506 |
+
"Episode * 27 * Avg Reward is ==> -705.225509729946\n",
|
| 507 |
+
"Episode * 28 * Avg Reward is ==> -685.144228863127\n",
|
| 508 |
+
"Episode * 29 * Avg Reward is ==> -670.6879188788478\n",
|
| 509 |
+
"Episode * 30 * Avg Reward is ==> -653.0154864082411\n",
|
| 510 |
+
"Episode * 31 * Avg Reward is ==> -643.4128610660125\n",
|
| 511 |
+
"Episode * 32 * Avg Reward is ==> -635.5798183939222\n",
|
| 512 |
+
"Episode * 33 * Avg Reward is ==> -623.9639787229108\n",
|
| 513 |
+
"Episode * 34 * Avg Reward is ==> -616.205090622738\n",
|
| 514 |
+
"Episode * 35 * Avg Reward is ==> -606.1140412258295\n",
|
| 515 |
+
"Episode * 36 * Avg Reward is ==> -603.6670876160974\n",
|
| 516 |
+
"Episode * 37 * Avg Reward is ==> -600.9921602699909\n",
|
| 517 |
+
"Episode * 38 * Avg Reward is ==> -591.8512444239832\n",
|
| 518 |
+
"Episode * 39 * Avg Reward is ==> -580.1600306375576\n",
|
| 519 |
+
"Episode * 40 * Avg Reward is ==> -553.821931002297\n",
|
| 520 |
+
"Episode * 41 * Avg Reward is ==> -521.7188034600143\n",
|
| 521 |
+
"Episode * 42 * Avg Reward is ==> -486.36319375176225\n",
|
| 522 |
+
"Episode * 43 * Avg Reward is ==> -453.6710442310697\n",
|
| 523 |
+
"Episode * 44 * Avg Reward is ==> -425.8450281985057\n",
|
| 524 |
+
"Episode * 45 * Avg Reward is ==> -400.74408779723456\n",
|
| 525 |
+
"Episode * 46 * Avg Reward is ==> -366.61738270546164\n",
|
| 526 |
+
"Episode * 47 * Avg Reward is ==> -345.13626004307355\n",
|
| 527 |
+
"Episode * 48 * Avg Reward is ==> -323.61757746366766\n",
|
| 528 |
+
"Episode * 49 * Avg Reward is ==> -301.23857698979566\n",
|
| 529 |
+
"Episode * 50 * Avg Reward is ==> -284.8999331286917\n",
|
| 530 |
+
"Episode * 51 * Avg Reward is ==> -264.84457621322116\n",
|
| 531 |
+
"Episode * 52 * Avg Reward is ==> -248.26764695916563\n",
|
| 532 |
+
"Episode * 53 * Avg Reward is ==> -237.25723863370771\n",
|
| 533 |
+
"Episode * 54 * Avg Reward is ==> -230.53260988021324\n",
|
| 534 |
+
"Episode * 55 * Avg Reward is ==> -236.8247039675385\n",
|
| 535 |
+
"Episode * 56 * Avg Reward is ==> -242.88089725188564\n",
|
| 536 |
+
"Episode * 57 * Avg Reward is ==> -249.99625421933737\n",
|
| 537 |
+
"Episode * 58 * Avg Reward is ==> -256.24104876179\n",
|
| 538 |
+
"Episode * 59 * Avg Reward is ==> -259.44539205532294\n",
|
| 539 |
+
"Episode * 60 * Avg Reward is ==> -259.771282922727\n",
|
| 540 |
+
"Episode * 61 * Avg Reward is ==> -266.78264262398795\n",
|
| 541 |
+
"Episode * 62 * Avg Reward is ==> -264.49970490719676\n",
|
| 542 |
+
"Episode * 63 * Avg Reward is ==> -264.7907401035075\n",
|
| 543 |
+
"Episode * 64 * Avg Reward is ==> -263.65884770574297\n",
|
| 544 |
+
"Episode * 65 * Avg Reward is ==> -263.8187150138804\n",
|
| 545 |
+
"Episode * 66 * Avg Reward is ==> -263.88096070288253\n",
|
| 546 |
+
"Episode * 67 * Avg Reward is ==> -263.68977140982696\n",
|
| 547 |
+
"Episode * 68 * Avg Reward is ==> -272.91279743733804\n",
|
| 548 |
+
"Episode * 69 * Avg Reward is ==> -272.777443352942\n",
|
| 549 |
+
"Episode * 70 * Avg Reward is ==> -283.87400325047287\n",
|
| 550 |
+
"Episode * 71 * Avg Reward is ==> -278.45777238816385\n",
|
| 551 |
+
"Episode * 72 * Avg Reward is ==> -272.09964609736335\n",
|
| 552 |
+
"Episode * 73 * Avg Reward is ==> -269.45733302243724\n",
|
| 553 |
+
"Episode * 74 * Avg Reward is ==> -263.91679852075515\n",
|
| 554 |
+
"Episode * 75 * Avg Reward is ==> -264.0434345954452\n",
|
| 555 |
+
"Episode * 76 * Avg Reward is ==> -260.102765623681\n",
|
| 556 |
+
"Episode * 77 * Avg Reward is ==> -253.51808301808424\n",
|
| 557 |
+
"Episode * 78 * Avg Reward is ==> -250.83738958549662\n",
|
| 558 |
+
"Episode * 79 * Avg Reward is ==> -254.1812329126542\n",
|
| 559 |
+
"Episode * 80 * Avg Reward is ==> -250.125238569467\n",
|
| 560 |
+
"Episode * 81 * Avg Reward is ==> -250.27037014579363\n",
|
| 561 |
+
"Episode * 82 * Avg Reward is ==> -250.1389180516676\n",
|
| 562 |
+
"Episode * 83 * Avg Reward is ==> -245.8142236436616\n",
|
| 563 |
+
"Episode * 84 * Avg Reward is ==> -245.1797777642314\n",
|
| 564 |
+
"Episode * 85 * Avg Reward is ==> -236.02398746977263\n",
|
| 565 |
+
"Episode * 86 * Avg Reward is ==> -239.18889403843315\n",
|
| 566 |
+
"Episode * 87 * Avg Reward is ==> -247.41187644346664\n",
|
| 567 |
+
"Episode * 88 * Avg Reward is ==> -247.82499330593242\n",
|
| 568 |
+
"Episode * 89 * Avg Reward is ==> -250.9072749126738\n",
|
| 569 |
+
"Episode * 90 * Avg Reward is ==> -263.1470922715929\n",
|
| 570 |
+
"Episode * 91 * Avg Reward is ==> -278.58573644976707\n",
|
| 571 |
+
"Episode * 92 * Avg Reward is ==> -280.6476742795351\n",
|
| 572 |
+
"Episode * 93 * Avg Reward is ==> -280.70748492063154\n",
|
| 573 |
+
"Episode * 94 * Avg Reward is ==> -280.565226725522\n",
|
| 574 |
+
"Episode * 95 * Avg Reward is ==> -268.37926234836465\n",
|
| 575 |
+
"Episode * 96 * Avg Reward is ==> -256.03979865280746\n",
|
| 576 |
+
"Episode * 97 * Avg Reward is ==> -255.01505149822543\n",
|
| 577 |
+
"Episode * 98 * Avg Reward is ==> -245.98157845584518\n",
|
| 578 |
+
"Episode * 99 * Avg Reward is ==> -245.54148137920984\n"
|
| 579 |
+
]
|
| 580 |
+
},
|
| 581 |
+
{
|
| 582 |
+
"output_type": "display_data",
|
| 583 |
+
"data": {
|
| 584 |
+
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZMAAAEGCAYAAACgt3iRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3dd3gc1fX/8feR5N57k3uvuAjTuwHTYorpxIQklAA/0kiAkATS+CakQiAkDpgaYgjVYMBgmik2Ri64GwtXucpNtmWr7vn9saOgGEleS9qdlfR5Pc8+2rkzu3uGMXv2lrnX3B0REZHqSAk7ABERqf2UTEREpNqUTEREpNqUTEREpNqUTEREpNrSwg4gLO3bt/devXqFHYaISK0yb9687e7e4eDyeptMevXqRWZmZthhiIjUKma2rrxyNXOJiEi1JV0yMbPfm9kKM1tkZi+aWesy++4wsywzW2lmZ5YpHx+UZZnZ7eFELiJSfyVdMgHeAoa5+wjgc+AOADMbAlwGDAXGA38zs1QzSwUeBM4ChgCXB8eKiEiCJF0ycfc33b042JwDpAfPJwBT3b3A3dcAWcDY4JHl7qvdvRCYGhwrIiIJknTJ5CDfBF4PnncDNpTZlx2UVVT+FWZ2nZllmllmTk5OHMIVEamfQhnNZWYzgc7l7LrT3V8OjrkTKAb+VVOf6+6TgckAGRkZmuFSRKSGhJJM3H1cZfvN7BvAucBp/uW0xhuB7mUOSw/KqKRcREQSIOnuMzGz8cCPgZPcfX+ZXdOAp83sT0BXoD8wFzCgv5n1JppELgOuSGzUIuHJ3V/EOyu3kl8U4eSBHejSqknYIUk9lHTJBHgAaAS8ZWYAc9z9BndfambPAsuINn/d5O4lAGZ2MzADSAWmuPvScEIXSYxIxHl18Waem5fNx1nbKY582Wo7rFtLThvUiVMHdWR4t1akpFiIkdZt+UUlTFu4iQ+zttO1dRP6dmjGwM4tGN6tFcH3V71h9XVxrIyMDNcd8FLbuDvvrczhd2+sYMWWvXRv24Szh3Vh/LDONG+Uxszl25i5fCvz1+/CHdo3b8Spgzpw7oiuHNu3HWmpyT7mJnltyc1nzfY8CopLKCiOsGzTHp6as44deYV0aNGI3fsLKSqJfp/2ateUS4/swcQx6XRo0SjUuAuLIyzbvIeCohKKI05hSYSjerelacOq1SXMbJ67Z3ylXMlEpHbYsa+A7z2zkA9Wbadnu6b88IyBnDu8S7k1j515hbz/+TbeWZHDeyu2sbegmLbNGjJ+WGdOH9KJY/q0o3GD1BDOIvH25hexdvt+Nuzaz8ZdB2jaKJX+HVvQr2Nz8gqKWbIxlyWbctmZV0SKgRkYRsQdB3L2FvDZht1s21vwlfc+dVBHvn18b47p246SiLNh1wHmr9vFM5kbmLtmJwApBmmpKTRKSyG9TVP6tG9Gnw7NuGh0Or3aN4vbee/YV8DTn6zniTnryDko9pk/OIl+HZtX6X2VTA6iZCK1yeLsXK5/MpMdeYXccdYgrjy6Jw1irGXkF5Xw/uc5vPLZJt5evo0DRSU0bZjKMX3aMbBzC/p0aE7Pdk1plJZCaorRtGEavdo1rdXNNJtzDzBz+TZmLNnCnNU7/qcZsDypKUabpg0BJ+LRGmBqigFGqyZpHJHemhHprejfqQWNG6TSKC2F9s0b0blV4wrfM2vbPmYs3cL+wmKKS5z8ohI27DrAmu15rN+5n1Qzrj2xNzed0q/KtYRSRSURnpy9jrdXbGVffjF5hSWs37mfwuIIJw7owCUZ6bRp2pC0FKNBWgpDurSs8o8JJZODKJlIbfHSgo3c9vwi2jVryD++nsHw9FZVfq/8ohJmr97B28u3MvuLHazbsb/cL9r0Nk04Z0QXzhjSmYg7m3Pz2bYnn5JI9Es2NcXIL4qwN7+IfQXFRNxJMSMlSEClv+4dJxKJfkH3bt+MiRnptGzcoMrxl9qbX8Qnq3cyb/0uiksipJhRHHGytu1j6aY9bN8X/SXep30zzhzWmSPSW9O9bRPS2zRlb34Rq7bt44tt+2jSMJVhXVsxsHOLhNbUtu3J57evr+CFBRvp0qox91wwnFMGdazSe83+Ygd3TVvC51v3MbhLSzq2aETzRml0adWYy8Z2p1/HFjUau5LJQZRMpDZ4YvZafv7yUo7q3Za/XTmads1rtv29qCTChp372bDrAEXFEYojzs68Qt5ctoUPV22P6Rd980ZppKVEm4VKItGmIRycaFJJMcMMdu8vonmjNC7J6M74YZ1JsegxnVo0pke7phV+xtY9+cz+YkcQ536ytu1jUXYuxREnLcVomJZCxB3D6NW+GUO7tmRIl5ac0L89/To2T+oaVubandz54hJWbt3L9Sf24dYzB8Zc49y0+wD3vLacVxdtJr1NE35+7hBOH9Ip7uerZHIQJRNJdg9/sJpfT1/O6UM68cAVo2iUltg+jl15hcxevYNmwa/cTi0bk5ZilLhTUuI0aRht7on1y2txdi6PfLiaVxdt/p8kZQYXjOzG908fQPe2TSmJOEs35fL+yhzeWr6VRdm5/z22Y4tG9GjblLG923J8//aM7tGm1vf95BeV8KtXl/GvT9Yzukdrrj+pL51aNqZji0Y0K23+MoIE7RRHnGc+3cAD72QRcef6k/py48l9E/bfQcnkIEomkqzcnQfeyeKPb33OOSO68JdLR8b8a7U22JKbz4ote/5bY/lw1XYe/XgtOIzt3ZZF2bvZkx+dnm9Uj9aMG9yJUwZ2pE+HZrU+cVTm1UWbuOP5xewtKD70wcD4oZ2585zBdG9bca0uHpRMDqJkIslof2Extz+/mGmfbeLCUd24d+KIejGcd9PuA/xl5ufMX7+bMT3acGy/dhzbt33ow2oTbW9+Eet27Gfrnny27ikgv6gEJ/oDw8yio82AwV1aclSfdqHEqGRyECUTSTZrt+dxw1PzWLl1L7eeMZDvnNRXNxxK0qkomSTjHfAi9c6yTXu4/J9zMIPHrhnLSQO+ssS2SFJTMhEJ2drteUyaMpemDVOZet3R9GwXvxvZROKl7jfGiiSxrXvyueqRTyiJRHjyW2OVSKTWUjIRCcmuvEImPTKXXXmFPHbN2Bq/uUwkkdTMJRKCPflFTJoylzU78njsG0dyRPfWYYckUi2qmYgkWF5BMdc8+ikrtuzhH1eN4dh+7cMOSaTaVDMRSaD8ohK+/XgmCzfs5sErRlV5PiaRZKOaiUiCFBSXcN2T85izZgd/uuQIxg/rEnZIIjVGyUQkAYpKItz89AJmfZ7D7y4cwYSR3cIOSaRGKZmIxFlJxPneMwt5a9lWfjlhKJcc2T3skERqnJKJSJxN+XAN0xdt5idnD2LSMb3CDkckLpRMROJoc250AsPTBnXkuhP7hh2OSNwkbTIxsx+amZtZ+2DbzOx+M8sys0VmNrrMsVeb2argcXV4UYv8r1+/upziiHP314aGHYpIXCXl0GAz6w6cAawvU3wW0D94HAU8BBxlZm2Bu4AMogu3zTOzae6+K7FRi/yvWZ/nMH3xZn4YLPokUpcla83kz8CPiSaHUhOAJzxqDtDazLoAZwJvufvOIIG8BYxPeMQiZRQUl3DXtKX0bt+M607qE3Y4InGXdMnEzCYAG939s4N2dQM2lNnODsoqKi/vva8zs0wzy8zJyanBqEX+12+mL2fN9jx+8bWhCV9uVyQMoTRzmdlMoHM5u+4EfkK0iavGuftkYDJEF8eKx2eIPDVnHU/MXsd1J/bhRK1LIvVEKMnE3ceVV25mw4HewGdmBpAOzDezscBGoOwA/fSgbCNw8kHl79V40CIxmP3FDu6etpRTBnbgtvGDwg5HJGGSqpnL3Re7e0d37+XuvYg2WY129y3ANGBSMKrraCDX3TcDM4AzzKyNmbUhWquZEdY5SP21Yed+vvOvefRq34z7Lh9FqpbclXokKUdzVeA14GwgC9gPXAPg7jvN7FfAp8Fxv3T3neGEKPVVQXEJN/5rPpGI8/CkDFo2bhB2SCIJldTJJKidlD534KYKjpsCTElQWCJf8dvXV7B4Yy6Tvz6GXu21WqLUP0nVzCVSG81YuoVHP1rLNcf14oyh5Y0rEan7lExEqiF7135+9J/PGJHeijvOGhx2OCKhUTIRqSJ357bnFxFx+Ovlo2iYpv+dpP7Sv36RKnph/kY+ytrB7WcNomc79ZNI/aZkIlIFO/YV8OvpyxjTsw1XjO0RdjgioVMyEamC30xfzr6CYv7vwuGk6H4SESUTkcP1waocXliwke+c1JcBnVqEHY5IUlAyETkMhcUR7no5Ohvwjaf0CzsckaShZCJyGB79aA2rt+fx8/OG0LiBZgMWKaVkIhKjbXvyuf/tVZw2qCOnDOwYdjgiSUXJRCRGv31jBUUlzs/OHRJ2KCJJR8lEJAbz1u3ihfkb+dYJvTX3lkg5lExEDqGoJMLPXlpCp5aNuFmd7iLlSupZg0WSweRZq1m2eQ9/v2oMzRrpfxmR8qhmIlKJL3L2cd/bqzhrWGfGD9OMwCIVUTIRqUAk4tzx/GIap6XwiwlDww5HJKkpmYhU4F9z1zN37U5+es4QOrZoHHY4IkmtwgZgM/sr4BXtd/db4hKRSBJYsjGXX7+6jBP6t+fijPSwwxFJepXVTDKBeUBjYDSwKniMBBrGPzSRcOzeX8gNT82jTdOG/PnSkZhpIkeRQ6mwZuLujwOY2XeA4929ONj+O/BBYsITSaySiHPL1IVs21PAM9cfTfvmjcIOSaRWiKXPpA3Qssx286BMpE7JLyrhl68sZdbnOdz1tSGM6qF/5iKxiiWZ/BZYYGaPmdnjwHzgnngGZWb/z8xWmNlSM7u3TPkdZpZlZivN7Mwy5eODsiwzuz2esUnd4+68vHAjp/3xfR6fvY6rj+mpBa9EDlOld2CZWQqwEjgqeADc5u5b4hWQmZ0CTACOcPcCM+sYlA8BLgOGAl2BmWY2IHjZg8DpQDbwqZlNc/dl8YpR6oaikgivLd7Mwx+sYfHGXIZ0acnvLx7BsX3bhx2aSK1TaTJx94iZPejuo4CXExTTd4DfuntBEMO2oHwCMDUoX2NmWcDYYF+Wu68GMLOpwbFKJvJfkYizbW8BG3btZ8PO/XyRs48X5m9kc24+vds3496JI7hodDqpWjVRpEpimRvibTO7CHjB3SscKlyDBgAnmNlvgHzgVnf/FOgGzClzXHZQBrDhoPKjKIeZXQdcB9Cjh5ox6jJ3572VOTz28VpWb9/Hltx8ikr+95/vsX3b8evzh3HKwI5aelekmmJJJtcDPwCKzSwfMMDdvWXlL6uYmc0Eypub4s4gprbA0cCRwLNm1qeqn1WWu08GJgNkZGQkIjFKghWXRJi5fBsPvpvF4o25dG3VmCN7t6Vr6yZ0bdWY9LZN6d6mKeltmmhxK5EadMhk4u41vsi1u4+raF8wFLm0FjTXzCJAe2Aj0L3MoelBGZWUSz2xautenpuXzYsLNrJtbwE92zXl3otGcP6objRM00QPIvEW0xSoZtYG6E/0BkYA3H1WnGJ6CTgFeDfoYG8IbAemAU+b2Z+IdsD3B+YSrSn1N7PeRJPIZcAVcYpNksyGnfv57RsrmL5oM6kpxikDOzJxTDrjBnckLVVJRCRRDplMzOzbwHeJ/uJfSLT5aTZwapximgJMMbMlQCFwdVBLWWpmzxLtWC8GbnL3kiDGm4EZQCowxd2Xxik2SRJbcvOZ8tEaHvtoLSkpcMup/fj6Mb3o0EI3GYqEwQ7Vp25mi4n2Xcxx95FmNgi4x90vTESA8ZKRkeGZmZlhhyGHobA4wpvLtvCfzGw+WJWDAxeNTufWMwbSuZUmYhRJBDOb5+4ZB5fH0syV7+75ZoaZNXL3FWY2MA4xipRrX0ExU+eu55EP17A5N58urRpz48n9mDgmXUvoiiSJWJJJtpm1JtqX8ZaZ7QLWxTcskajn5mXzy1eWsie/mKP7tOWeC4Zz4oAOuh9EJMnEMprrguDp3Wb2LtAKeCOuUUm9F4k4f3hzJX977wuO6t2WO84ezMjurcMOS0QqEEsH/K+AWcDH7v5+/EOS+i6/qIQf/uczpi/azOVju/PLCcNooJFZIkktlmau1cDlwP1mtpfo9POz3D1R06tIPbJ88x6+/8xCVm7dy0/OHsS1J/TReiIitUAszVyPAo+aWWfgEuBWolOS1PjNjFJ/lUSchz9YzR/f/JxWTRvw6DeO5OSBHcMOS0RiFEsz18PAEGAr0VrJRKLT0IvUiPyiEm5+egEzl29l/NDO3HPhcNo202KeIrVJLM1c7YjeDLgb2AlsL111UaS6DhSWcN2TmXywajt3nTeEbxzbS81aIrVQzKO5zGwwcCbRaU5S3T093sFJ3ZZXUMy3H89kzpod3HvRCC45svuhXyQiSSmWZq5zgROAE4HWwDtoDXipps25B7jhyXks2bSHP18ykvNHdTv0i0QkacXSzDWeaPK4z903xTkeqQcy1+7khqfmc6CwmL9fNYbTh3QKOyQRqaZYmrluNrOeRDvhN5lZEyDN3ffGPTqpc575dD0/fWkJ3Vo34d/XHkX/ThoUKFIXxNLMdS3RocBtgb5EZw/+O3BafEOTusTd+fPMVdz/9ipOHNCBv142ilZNG4QdlojUkFiauW4iutb6JwDuvsrMdAOAxKy4JMKdLy7hmcwNXJKRzj0XDNdaIyJ1TCzJpMDdC0uHa5pZGqAlbyUmJRHnxn/N581lW7nl1H58//QBGvorUgfFkkzeN7OfAE3M7HTgRuCV+IYldcVvpi/nzWVb+fm5Q/jm8b3DDkdE4iSWtobbgRxgMXA98Jq73xnXqKROeGrOOqZ8tIZrjuulRCJSxx0ymbh7xN3/6e4Xu/tEYJ2ZvZWA2KQW+3DVdu6atpRTBnbgp+cMCTscEYmzCpOJmZ1qZp+b2T4ze8rMhptZJvB/wEOJC1Fqm09W7+CGp+bRr0Nz7r98lBayEqkHKquZ/JHokOB2wHPAbOAxdx/j7i8kIjipfd5dsY1JU+bSqWUjHvvmkbRorOG/IvVBZR3w7u7vBc9fMrON7v5AAmKSWmr6os18d+oCBnVpwePXjKVd80ZhhyQiCVJZzaS1mV1Y+gDSDtqOCzMbaWZzzGyhmWWa2dig3MzsfjPLMrNFZja6zGuuNrNVwePqeMUmFZu/fhe3TF3AqB6tefrao5VIROqZymom7wPnldmeVWbbgXg1dd0L/MLdXzezs4Ptk4GzgP7B4yii/TZHmVlb4C4gI4hrnplNc/ddcYpPDrInv4jvTl1A55aNefjqI2mppi2ReqfCZOLu1yQykLIfDbQMnrcCSieXnAA84e4OzDGz1mbWhWiiecvddwIEI83GA/9OaNT1lLvzs5eWsGl3Ps9efzStmiiRiNRHsdy0mGjfA2aY2R+INsMdG5R3AzaUOS47KKuo/CvM7Dqigwro0aNHzUZdT70wfyMvL9zED04fwJiebcMOR0RCEkoyMbOZQOdydt1JdALJ77v782Z2CfAIMK4mPtfdJwOTATIyMjQlTDVlbdvLz19ewtjebbnplH5hhyMiIQolmbh7hcnBzJ4Avhts/gd4OHi+ESi7FF96ULaRaFNX2fL3aihUqcDe/CKue3IejRukct9lI3UviUg9d8g74M3sJjNrXWa7jZndGMeYNgEnBc9PBVYFz6cBk4JRXUcDue6+GZgBnBHE1QY4IyiTOHF3bv3PZ6zbsZ8HrhhNl1ZNwg5JREIWS83kWnd/sHTD3XcFa5z8LU4xXQvcF8xOnE/QxwG8BpwNZAH7gWuCeHaa2a+AT4PjflnaGS/x8dD7XzBj6VZ+es5gjunbLuxwRCQJxJJMUs3MglFUmFkq0DBeAbn7h8CYcsqd6Noq5b1mCjAlXjHJl5Zt2sMfZqzk3BFd+JYmbxSRQCzJ5A3gGTP7R7B9fVAm9dAf31xJ80Zp/OaC4VqXRET+K5ZkchvRBPKdYPstvuwUl3pk/vpdvL1iGz86c6DuJxGR/3HIZOLuEaJ3m2um4HruDzNW0r55Q75xbK+wQxGRJFNhMjGzZ939EjNbTDnL9Lr7iLhGJknl46ztfPzFDn527hCaNUrGe11FJEyVfSuU3utxbiICkeTl7vz+zZV0adWYK4/SzAEi8lWVzc21Ofi7LnHhSDKauXwbC9bv5p4LhtO4QWrY4YhIEqqsmWsv5TRvlXL3lhXtk7qjqCTC/72+nD4dmnFxRnrY4YhIkqqsZtICILghcDPwJGDAlUCXhEQnoZv66QZW5+Txz0kZNEg95IQJIlJPxfLt8DV3/5u773X3Pe7+ENHp4KWO21dQzH0zP2ds77aMG9wx7HBEJInFkkzyzOxKM0s1sxQzuxLIi3dgEr5/vP8F2/cVcufZg3WDoohUKpZkcgVwCbAV2AZcHJRJHbZ1Tz7//GA15x3RlSO6tz70C0SkXovlpsW1qFmr3vnbu1kUlzg/OmNg2KGISC0QyxT06Wb2opltCx7Pm5mG9dRhW3Lz+ffcDUwck06Pdk3DDkdEaoFYmrkeJbqWSNfg8UpQJnXU39//goi7Vk8UkZjFkkw6uPuj7l4cPB4DOsQ5LgnJ1j35PD13PRPHpNO9rWolIhKbWJLJDjO7KhjNlWpmVwE74h2YhOOh974gElGtREQOTyzJ5JtER3NtIXrz4kSCVQ6lbimtlVw0WrUSETk8sYzmWgd8LQGxSMimfLiG4pKIaiUictgqm5vrx+5+r5n9lfKnoL8lrpFJQu0rKObpues5a3gXjeASkcNWWc1kefA3MxGBSLie/XQDe/OLufaEPmGHIiK1UIV9Ju7+SvD38dIH0ckeXwyeV5mZXWxmS80sYmYZB+27w8yyzGylmZ1Zpnx8UJZlZreXKe9tZp8E5c+YWcPqxFYflUScKR+tIaNnG0bqbncRqYJYblp82sxamlkzYAmwzMx+VM3PXQJcCMw66LOGAJcBQ4HxwN9KR5EBDwJnAUOAy4NjAX4H/Nnd+wG7gG9VM7Z6Z8bSLWTvOsC3VSsRkSqKZTTXEHffA5wPvA70Br5enQ919+XuvrKcXROAqe5e4O5rgCxgbPDIcvfV7l4ITAUmWHT2wVOB54LXPx7EKYfhnx+spme7ppw+pFPYoYhILRVLMmlgZg2IfklPc/ciKlk0q5q6ARvKbGcHZRWVtwN2u3vxQeXlMrPrzCzTzDJzcnJqNPDaat66XSxYv5tvHteb1BTNDCwiVRNLMvkHsBZoBswys57AnkO9yMxmmtmSch6hTRrp7pPdPcPdMzp00E38AE/OXkuLRmlMHKPp1kSk6mK5z+R+4P4yRevM7JQYXjeuCvFsBLqX2U4PyqigfAfQ2szSgtpJ2ePlEHblFfLaki1cdmR3mjU65D8FEZEKxdIB387M7jez+WY2z8zuA1rFKZ5pwGVm1sjMegP9gbnAp0D/YORWQ6Kd9NPc3YF3id6VD3A18HKcYqtznp+fTWFxhCuO6hF2KCJSy8XSzDUVyAEuIvqlnQM8U50PNbMLzCwbOAaYbmYzANx9KfAssAx4A7jJ3UuCWsfNwAyi9788GxwLcBvwAzPLItqH8kh1Yqsv3J2n565nVI/WDOrcMuxwRKSWs+iP+0oOMFvi7sMOKlvs7sPjGlmcZWRkeGZm/b0f85PVO7h08hx+P3EEF2d0P/QLREQAM5vn7hkHl8dSM3nTzC4L1n9PMbNLiNYQpBZ7eu56WjRO49wRXcMORUTqgFiSybXA00BB8JgKXG9me83skKO6JPnsyivk9cVbuHBUN5o0TA07HBGpA2IZzdUiEYFI4jw/P5vCkgiXq+NdRGpIhTWTYBGs0ufHHbTv5ngGJfETiThPzllHRs826ngXkRpTWTPXD8o8/+tB+74Zh1gkAWatymHdjv18/ZieYYciInVIZcnEKnhe3rbUEk/OXkf75o04a1iXsEMRkTqksmTiFTwvb1tqgQ079/POym1cPrY7DdNiGXshIhKbyjrgB5nZIqK1kL7Bc4JtzVVeCz31yTpSzHTHu4jUuMqSyeCERSFxl19UwrOfbuD0wZ3o0qpJ2OGISB1TYTJx93WJDETi69VFm9m1v4hJ6ngXkThQw3k98WzmBvq0b8YxfduFHYqI1EFKJvXAhp37mbtmJxeO7kZ0cUoRkZqlZFIPvLwwusTLhJEVLkIpIlItVUomZnZ3DcchceLuvLBgI2N7t6V726ZhhyMidVRVaybzajQKiZtF2bmszsnjglGqlYhI/FQpmbj7KzUdiMTHiws20jAthbOH6453EYmfQ84abGb3l1OcC2S6u5bITWJFJRFe+WwT4wZ3pFWTBmGHIyJ1WCw1k8bASGBV8BgBpAPfMrO/xDE2qaYPVuWwI6+QC0alhx2KiNRxh6yZEE0ex7l7CYCZPQR8ABwPLI5jbFJNz8/fSJumDThpQIewQxGROi6WmkkboHmZ7WZA2yC5FMQlKqm23fsLeWvpViaM7KZJHUUk7mKpmdwLLDSz94hO8ngicI+ZNQNmxjE2qYaXF26isCTCJRndww5FROqBQ/5kdfdHgGOBl4AXgePd/WF3z3P3H1XlQ83sYjNbamYRM8soU366mc0zs8XB31PL7BsTlGeZ2f0W3MptZm3N7C0zWxX8bVOVmOqaZzM3MKxbS4Z01WqKIhJ/h0wmZvYKcDIw091fdvdNNfC5S4ALgVkHlW8HznP34cDVwJNl9j0EXAv0Dx7jg/LbgbfdvT/wdrBdry3ZmMvSTXtUKxGRhImlMf0PwAnAMjN7zswmmlnj6nyouy9395XllC8ok6yWAk3MrJGZdQFauvscd3fgCeD84LgJwOPB88fLlNdbz83LpmFqCl87omvYoYhIPRFLM9f77n4j0QWx/gFcAmyLd2DARcB8dy8AugHZZfZlB2UAndx9c/B8C9Cpojc0s+vMLNPMMnNycuIRc+jyi0p4ccFGzhjaidZNG4YdjojUE7F0wGNmTYDzgEuB0XxZE6jsNTOBzuXsuvNQNzua2VDgd8AZscRXyt3dzCpcUtjdJwOTATIyMurk0sMzl28l90CRmrhEJKFiuQP+WWAs8AbwAPC+u0cO9Tp3H1eVgMwsnWhH/yR3/yIo3rihxVYAAA7vSURBVEj0RslS6UEZwFYz6+Lum4PmsETUmpLWM59uoGurxhzXr33YoYhIPRJLn8kjQF93v8Hd3wWONbMH4xGMmbUGpgO3u/tHpeVBM9YeMzs6GMU1CSit3Uwj2llP8LfeTvGyOmcfH6zazqVH9iA1ReuWiEjixNJnMgMYYWb3mtla4FfAiup8qJldYGbZwDHAdDObEey6GegH/NzMFgaPjsG+G4GHgSzgC+D1oPy3wOlmtgoYF2zXS0/NWU+DVOPyo9TEJSKJVWEzl5kNAC4PHtuBZwBz91Oq+6Hu/iLRpqyDy38N/LqC12QCw8op3wGcVt2Yarv9hcX8Z94Gxg/rQscW1RpsJyJy2CrrM1lBdA6uc909C8DMvp+QqOSwvbRgE3vzi5l0TM+wQxGReqiyZq4Lgc3Au2b2TzM7jeh0KpJk3J0nZq9lcJeWZPTUBAAikngVJhN3f8ndLwMGAe8C3wM6mtlDZnZYQ3Ylvj5du4sVW/Yy6ZieBLPMiIgkVCwd8Hnu/rS7n0d0SO4C4La4RyYxe+zjNbRonMaEkbrjXUTCcVhzk7v7Lnef7O71vsM7Wby9fCuvLd7CNcf2omnDmO5BFRGpcVroohbbmVfIbc8vZlDnFtx0ar+wwxGRekw/ZWspd+dnLy0h90AhT3xzLI3SUsMOSUTqMdVMaqlpn21i+uLNfG/cAK1ZIiKhUzKphXbsK+CuaUsZ2b0115/YJ+xwRESUTGqj/3t9Bfvyi7l34gjSUnUJRSR8+iaqZeau2clz87L59gl9GNCpRdjhiIgASia1SlFJhJ++tJhurZtwy2kavSUiyUOjuWqRRz5cw+db9/HPSRm6p0REkopqJrVEXkExf317FeMGd+T0IRWuTCwiEgolk1ritcWbySss4Tsn9w07FBGRr1AyqSWem5dN7/bNGN1DswKLSPJRMqkFNuzczydrdnLR6G6aFVhEkpKSSS3w/PxszOCC0elhhyIiUi4lkyQXiTjPz8/muL7t6da6SdjhiIiUS8kkyc1du5MNOw8wcYxqJSKSvJRMktxz87Jp3iiNM4d2DjsUEZEKhZJMzOxiM1tqZhEzyyhnfw8z22dmt5YpG29mK80sy8xuL1Pe28w+CcqfMbOGiTqPeMs9UMTrizdzzvAuNGmoKeZFJHmFVTNZAlwIzKpg/5+A10s3zCwVeBA4CxgCXG5mQ4LdvwP+7O79gF3At+IVdKJN+XANeYUlTDq2Z9ihiIhUKpRk4u7L3X1lefvM7HxgDbC0TPFYIMvdV7t7ITAVmGDRcbKnAs8Fxz0OnB+/yBMn90ARUz5aw5lDOzG0a6uwwxERqVRS9ZmYWXPgNuAXB+3qBmwos50dlLUDdrt78UHlFb3/dWaWaWaZOTk5NRd4HEz5cA1784u55bT+YYciInJIcUsmZjbTzJaU85hQycvuJtpktS8eMbn7ZHfPcPeMDh06xOMjaoRqJSJS28Rt6ll3H1eFlx0FTDSze4HWQMTM8oF5QPcyx6UDG4EdQGszSwtqJ6XltZpqJSJS2yTVPObufkLpczO7G9jn7g+YWRrQ38x6E00WlwFXuLub2bvARKL9KFcDLyc+8pqTu1+1EhGpfcIaGnyBmWUDxwDTzWxGZccHtY6bgRnAcuBZdy/toL8N+IGZZRHtQ3kkfpHH3yMfrmZvfjHfPW1A2KGIiMQslJqJu78IvHiIY+4+aPs14LVyjltNdLRXrbcrr5ApH63l7OGdGdK1ZdjhiIjELKlGc9V3//xgNXmFxXxvnGolIlK7KJkkiR37Cnjs47WcN6IrAzq1CDscEZHDomSSJP4xazX5RSUawSUitZKSSRLI2VvAE7PXcv7IbvTr2DzscEREDpuSSRJ4/OO1FBRHuPnUfmGHIiJSJUom1bBh537mrtlZrffIKyjmidlrOXNIZ/p0UK1ERGonJZNq+OlLS7h08mxemJ9d5feY+ukG9uQXc/1JfWowMhGRxEqqO+Brk4LiEuau2UmD1BRu/c9nNExL4dwRXf/nGHcnZ28Bm3LzKYk44DRrlMagztF7SIpKIjzywWrG9m7LqB5tQjgLEZGaoWRSRQvW7+ZAUQn3XTaSp+as43tTF7Jp9wHyiyJkbdvHFzn7WLs9j7zCkq+8dsLIrvzya8N4Z+VWNuXm8+sLhoVwBiIiNUfJpIo+ytpOisEpgzpy6qCOXPXIXO55bQUA6W2a0LdDc47s1ZY+HZrRtVUT0lINM2Peul387d0s5qzeQaO0VAZ0as7JAzqGfDYiItWjZFJFH2Vt54jurWnZuAEA/7n+GNZsz6N72yY0bVjxf9aTBnTg9MGd+MGzC1m1bR+/nziClBRLVNgiInGhZFIFe/KL+Cw7l++c1Pe/ZQ3TUhjYObY714ent+KV/3c8n67dyXF928crTBGRhFEyqYJPVu+kJOIc16/qiaBxg1RO6J+8C3SJiBwODQ2ugo+yttO4QQqje7YOOxQRkaSgZFIFH2Vt58hebWmUlhp2KCIiSUHJ5DBt3ZPPqm37OL4aTVwiInWNkslh+ihrO0C1+ktEROoaJZPD9FHWDto0bcCQLloJUUSklEZzHaa+HZvRoUUP3RsiIlKGkslhuvFkTRMvInKwUJq5zOxiM1tqZhEzyzho3wgzmx3sX2xmjYPyMcF2lpndb2YWlLc1s7fMbFXwVzMmiogkWFh9JkuAC4FZZQvNLA14CrjB3YcCJwNFwe6HgGuB/sFjfFB+O/C2u/cH3g62RUQkgUJJJu6+3N1XlrPrDGCRu38WHLfD3UvMrAvQ0t3nuLsDTwDnB6+ZADwePH+8TLmIiCRIso3mGgC4mc0ws/lm9uOgvBtQdgWq7KAMoJO7bw6ebwE6JSZUEREpFbcOeDObCXQuZ9ed7v5yJfEcDxwJ7AfeNrN5QG4sn+nubmZeSUzXAdcB9OjRI5a3FBGRGMQtmbj7uCq8LBuY5e7bAczsNWA00X6U9DLHpQMbg+dbzayLu28OmsO2VRLTZGAyQEZGRoVJR0REDk+yNXPNAIabWdOgM/4kYFnQjLXHzI4ORnFNAkprN9OAq4PnV5cpFxGRBAlraPAFZpYNHANMN7MZAO6+C/gT8CmwEJjv7tODl90IPAxkAV8ArwflvwVON7NVwLhgW0REEsiig6PqHzPLAdZV8eXtge01GE5tUR/Puz6eM9TP89Y5x6anu39lMaZ6m0yqw8wy3T3j0EfWLfXxvOvjOUP9PG+dc/UkW5+JiIjUQkomIiJSbUomVTM57ABCUh/Puz6eM9TP89Y5V4P6TEREpNpUMxERkWpTMhERkWpTMjlMZjbezFYG66rUyenuzay7mb1rZsuCdWW+G5TX+bVjzCzVzBaY2avBdm8z+yS43s+YWcOwY6xpZtbazJ4zsxVmttzMjqnr19rMvh/8215iZv82s8Z18Vqb2RQz22ZmS8qUlXttLer+4PwXmdnow/ksJZPDYGapwIPAWcAQ4HIzGxJuVHFRDPzQ3YcARwM3BedZH9aO+S6wvMz274A/u3s/YBfwrVCiiq/7gDfcfRBwBNHzr7PX2sy6AbcAGe4+DEgFLqNuXuvH+HLtp1IVXduz+HK9qOuIriEVMyWTwzMWyHL31e5eCEwlup5KneLum919fvB8L9Evl27U8bVjzCwdOIfotD0E88CdCjwXHFIXz7kVcCLwCIC7F7r7bur4tSY6yW2TYA7ApsBm6uC1dvdZwM6Diiu6thOAJzxqDtA6mDw3Jkomh6cbsKHMdtl1VeokM+sFjAI+oe6vHfMX4MdAJNhuB+x29+Jguy5e795ADvBo0Lz3sJk1ow5fa3ffCPwBWE80ieQC86j717pURde2Wt9vSiZSITNrDjwPfM/d95TdF6x4WWfGlZvZucA2d58XdiwJlkZ0mYeH3H0UkMdBTVp18Fq3IforvDfQFWjGV5uC6oWavLZKJodnI9C9zHbZdVXqFDNrQDSR/MvdXwiKt5ZWew+1dkwtdBzwNTNbS7T58lSifQmtg6YQqJvXOxvIdvdPgu3niCaXunytxwFr3D3H3YuAF4he/7p+rUtVdG2r9f2mZHJ4PgX6B6M+GhLttJsWckw1LugreARY7u5/KrOrzq4d4+53uHu6u/ciel3fcfcrgXeBicFhdeqcAdx9C7DBzAYGRacBy6jD15po89bRwbpJxpfnXKevdRkVXdtpwKRgVNfRQG6Z5rBD0h3wh8nMzibatp4KTHH334QcUo0zs+OBD4DFfNl/8BOi/SbPAj2ITt9/ibsf3LlX65nZycCt7n6umfUhWlNpCywArnL3gjDjq2lmNpLooIOGwGrgGqI/NOvstTazXwCXEh25uAD4NtH+gTp1rc3s38DJRKea3wrcBbxEOdc2SKwPEG3y2w9c4+6ZMX+WkomIiFSXmrlERKTalExERKTalExERKTalExERKTalExERKTalExEaoiZlZjZwjKPSidHNLMbzGxSDXzuWjNrX933EakODQ0WqSFmts/dm4fwuWuJzoC7PdGfLVJKNROROAtqDvea2WIzm2tm/YLyu83s1uD5LcH6MYvMbGpQ1tbMXgrK5pjZiKC8nZm9GazH8TBgZT7rquAzFprZP4JlE0TiTslEpOY0OaiZ69Iy+3LdfTjRO4z/Us5rbwdGufsI4Iag7BfAgqDsJ8ATQfldwIfuPhR4keidzJjZYKJ3dR/n7iOBEuDKmj1FkfKlHfoQEYnRgeBLvDz/LvP3z+XsXwT8y8xeIjrdBcDxwEUA7v5OUCNpSXT9kQuD8ulmtis4/jRgDPBpdGYMmlC3JmiUJKZkIpIYXsHzUucQTRLnAXea2fAqfIYBj7v7HVV4rUi1qJlLJDEuLfN3dtkdZpYCdHf3d4HbgFZAc6KTbV4ZHHMysD1YV2YWcEVQfhZQuj7728BEM+sY7GtrZj3jeE4i/6WaiUjNaWJmC8tsv+HupcOD25jZIqAAuPyg16UCTwVL6Bpwv7vvNrO7gSnB6/bz5bThvwD+bWZLgY+JTqmOuy8zs58CbwYJqgi4iejMsCJxpaHBInGmobtSH6iZS0REqk01ExERqTbVTEREpNqUTEREpNqUTEREpNqUTEREpNqUTEREpNr+PzuwM5S9NvHoAAAAAElFTkSuQmCC\n",
|
| 585 |
+
"text/plain": [
|
| 586 |
+
"<Figure size 432x288 with 1 Axes>"
|
| 587 |
+
]
|
| 588 |
+
},
|
| 589 |
+
"metadata": {
|
| 590 |
+
"needs_background": "light"
|
| 591 |
+
}
|
| 592 |
+
}
|
| 593 |
+
],
|
| 594 |
+
"source": [
|
| 595 |
+
"# To store reward history of each episode\n",
|
| 596 |
+
"ep_reward_list = []\n",
|
| 597 |
+
"# To store average reward history of last few episodes\n",
|
| 598 |
+
"avg_reward_list = []\n",
|
| 599 |
+
"\n",
|
| 600 |
+
"# Takes about 4 min to train\n",
|
| 601 |
+
"for ep in range(total_episodes):\n",
|
| 602 |
+
"\n",
|
| 603 |
+
" prev_state = env.reset()\n",
|
| 604 |
+
" episodic_reward = 0\n",
|
| 605 |
+
"\n",
|
| 606 |
+
" while True:\n",
|
| 607 |
+
" # Uncomment this to see the Actor in action\n",
|
| 608 |
+
" # But not in a python notebook.\n",
|
| 609 |
+
" # env.render()\n",
|
| 610 |
+
"\n",
|
| 611 |
+
" tf_prev_state = tf.expand_dims(tf.convert_to_tensor(prev_state), 0)\n",
|
| 612 |
+
"\n",
|
| 613 |
+
" action = policy(tf_prev_state, ou_noise)\n",
|
| 614 |
+
" # Recieve state and reward from environment.\n",
|
| 615 |
+
" state, reward, done, info = env.step(action)\n",
|
| 616 |
+
"\n",
|
| 617 |
+
" buffer.record((prev_state, action, reward, state))\n",
|
| 618 |
+
" episodic_reward += reward\n",
|
| 619 |
+
"\n",
|
| 620 |
+
" buffer.learn()\n",
|
| 621 |
+
" update_target(target_actor.variables, actor_model.variables, tau)\n",
|
| 622 |
+
" update_target(target_critic.variables, critic_model.variables, tau)\n",
|
| 623 |
+
"\n",
|
| 624 |
+
" # End this episode when `done` is True\n",
|
| 625 |
+
" if done:\n",
|
| 626 |
+
" break\n",
|
| 627 |
+
"\n",
|
| 628 |
+
" prev_state = state\n",
|
| 629 |
+
"\n",
|
| 630 |
+
" ep_reward_list.append(episodic_reward)\n",
|
| 631 |
+
"\n",
|
| 632 |
+
" # Mean of last 40 episodes\n",
|
| 633 |
+
" avg_reward = np.mean(ep_reward_list[-40:])\n",
|
| 634 |
+
" print(\"Episode * {} * Avg Reward is ==> {}\".format(ep, avg_reward))\n",
|
| 635 |
+
" avg_reward_list.append(avg_reward)\n",
|
| 636 |
+
"\n",
|
| 637 |
+
"# Plotting graph\n",
|
| 638 |
+
"# Episodes versus Avg. Rewards\n",
|
| 639 |
+
"plt.plot(avg_reward_list)\n",
|
| 640 |
+
"plt.xlabel(\"Episode\")\n",
|
| 641 |
+
"plt.ylabel(\"Avg. Epsiodic Reward\")\n",
|
| 642 |
+
"plt.show()"
|
| 643 |
+
]
|
| 644 |
+
},
|
| 645 |
+
{
|
| 646 |
+
"cell_type": "markdown",
|
| 647 |
+
"metadata": {
|
| 648 |
+
"id": "XY85n6_l_wFb"
|
| 649 |
+
},
|
| 650 |
+
"source": [
|
| 651 |
+
"If training proceeds correctly, the average episodic reward will increase with time.\n",
|
| 652 |
+
"\n",
|
| 653 |
+
"Feel free to try different learning rates, `tau` values, and architectures for the\n",
|
| 654 |
+
"Actor and Critic networks.\n",
|
| 655 |
+
"\n",
|
| 656 |
+
"The Inverted Pendulum problem has low complexity, but DDPG work great on many other\n",
|
| 657 |
+
"problems.\n",
|
| 658 |
+
"\n",
|
| 659 |
+
"Another great environment to try this on is `LunarLandingContinuous-v2`, but it will take\n",
|
| 660 |
+
"more episodes to obtain good results."
|
| 661 |
+
]
|
| 662 |
+
},
|
| 663 |
+
{
|
| 664 |
+
"cell_type": "code",
|
| 665 |
+
"execution_count": 9,
|
| 666 |
+
"metadata": {
|
| 667 |
+
"id": "fDayimW0_wFb"
|
| 668 |
+
},
|
| 669 |
+
"outputs": [],
|
| 670 |
+
"source": [
|
| 671 |
+
"# Save the weights\n",
|
| 672 |
+
"actor_model.save_weights(\"pendulum_actor.h5\")\n",
|
| 673 |
+
"critic_model.save_weights(\"pendulum_critic.h5\")\n",
|
| 674 |
+
"\n",
|
| 675 |
+
"target_actor.save_weights(\"pendulum_target_actor.h5\")\n",
|
| 676 |
+
"target_critic.save_weights(\"pendulum_target_critic.h5\")"
|
| 677 |
+
]
|
| 678 |
+
},
|
| 679 |
+
{
|
| 680 |
+
"cell_type": "markdown",
|
| 681 |
+
"metadata": {
|
| 682 |
+
"id": "hYiCdLyE_wFb"
|
| 683 |
+
},
|
| 684 |
+
"source": [
|
| 685 |
+
"Before Training:\n",
|
| 686 |
+
"\n",
|
| 687 |
+
""
|
| 688 |
+
]
|
| 689 |
+
},
|
| 690 |
+
{
|
| 691 |
+
"cell_type": "markdown",
|
| 692 |
+
"metadata": {
|
| 693 |
+
"id": "D1lklgTJ_wFc"
|
| 694 |
+
},
|
| 695 |
+
"source": [
|
| 696 |
+
"After 100 episodes:\n",
|
| 697 |
+
"\n",
|
| 698 |
+
""
|
| 699 |
+
]
|
| 700 |
+
},
|
| 701 |
+
{
|
| 702 |
+
"cell_type": "code",
|
| 703 |
+
"source": [
|
| 704 |
+
"!pip install huggingface-hub\n",
|
| 705 |
+
"!curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash\n",
|
| 706 |
+
"!sudo apt-get install git-lfs\n",
|
| 707 |
+
"!git-lfs install"
|
| 708 |
+
],
|
| 709 |
+
"metadata": {
|
| 710 |
+
"id": "c6Ao1vi4_zwE",
|
| 711 |
+
"outputId": "a2aa4ade-a162-432f-92d3-1d1358c2ead6",
|
| 712 |
+
"colab": {
|
| 713 |
+
"base_uri": "https://localhost:8080/"
|
| 714 |
+
}
|
| 715 |
+
},
|
| 716 |
+
"execution_count": 10,
|
| 717 |
+
"outputs": [
|
| 718 |
+
{
|
| 719 |
+
"output_type": "stream",
|
| 720 |
+
"name": "stdout",
|
| 721 |
+
"text": [
|
| 722 |
+
"Collecting huggingface-hub\n",
|
| 723 |
+
" Downloading huggingface_hub-0.2.1-py3-none-any.whl (61 kB)\n",
|
| 724 |
+
"\u001b[?25l\r\u001b[K |█████▎ | 10 kB 21.9 MB/s eta 0:00:01\r\u001b[K |██████████▋ | 20 kB 14.7 MB/s eta 0:00:01\r\u001b[K |███████████████▉ | 30 kB 11.1 MB/s eta 0:00:01\r\u001b[K |█████████████████████▏ | 40 kB 9.7 MB/s eta 0:00:01\r\u001b[K |██████████████████████████▌ | 51 kB 5.2 MB/s eta 0:00:01\r\u001b[K |███████████████████████████████▊| 61 kB 5.8 MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 61 kB 446 kB/s \n",
|
| 725 |
+
"\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from huggingface-hub) (3.4.0)\n",
|
| 726 |
+
"Requirement already satisfied: packaging>=20.9 in /usr/local/lib/python3.7/dist-packages (from huggingface-hub) (21.3)\n",
|
| 727 |
+
"Requirement already satisfied: pyyaml in /usr/local/lib/python3.7/dist-packages (from huggingface-hub) (3.13)\n",
|
| 728 |
+
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.7/dist-packages (from huggingface-hub) (3.10.0.2)\n",
|
| 729 |
+
"Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from huggingface-hub) (4.8.2)\n",
|
| 730 |
+
"Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from huggingface-hub) (4.62.3)\n",
|
| 731 |
+
"Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from huggingface-hub) (2.23.0)\n",
|
| 732 |
+
"Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.9->huggingface-hub) (3.0.6)\n",
|
| 733 |
+
"Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->huggingface-hub) (3.6.0)\n",
|
| 734 |
+
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->huggingface-hub) (2021.10.8)\n",
|
| 735 |
+
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->huggingface-hub) (2.10)\n",
|
| 736 |
+
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->huggingface-hub) (3.0.4)\n",
|
| 737 |
+
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->huggingface-hub) (1.24.3)\n",
|
| 738 |
+
"Installing collected packages: huggingface-hub\n",
|
| 739 |
+
"Successfully installed huggingface-hub-0.2.1\n",
|
| 740 |
+
"Detected operating system as Ubuntu/bionic.\n",
|
| 741 |
+
"Checking for curl...\n",
|
| 742 |
+
"Detected curl...\n",
|
| 743 |
+
"Checking for gpg...\n",
|
| 744 |
+
"Detected gpg...\n",
|
| 745 |
+
"Running apt-get update... done.\n",
|
| 746 |
+
"Installing apt-transport-https... done.\n",
|
| 747 |
+
"Installing /etc/apt/sources.list.d/github_git-lfs.list...done.\n",
|
| 748 |
+
"Importing packagecloud gpg key... done.\n",
|
| 749 |
+
"Running apt-get update... done.\n",
|
| 750 |
+
"\n",
|
| 751 |
+
"The repository is setup! You can now install packages.\n",
|
| 752 |
+
"Reading package lists... Done\n",
|
| 753 |
+
"Building dependency tree \n",
|
| 754 |
+
"Reading state information... Done\n",
|
| 755 |
+
"The following NEW packages will be installed:\n",
|
| 756 |
+
" git-lfs\n",
|
| 757 |
+
"0 upgraded, 1 newly installed, 0 to remove and 67 not upgraded.\n",
|
| 758 |
+
"Need to get 6,526 kB of archives.\n",
|
| 759 |
+
"After this operation, 14.7 MB of additional disk space will be used.\n",
|
| 760 |
+
"Get:1 https://packagecloud.io/github/git-lfs/ubuntu bionic/main amd64 git-lfs amd64 3.0.2 [6,526 kB]\n",
|
| 761 |
+
"Fetched 6,526 kB in 1s (6,123 kB/s)\n",
|
| 762 |
+
"debconf: unable to initialize frontend: Dialog\n",
|
| 763 |
+
"debconf: (No usable dialog-like program is installed, so the dialog based frontend cannot be used. at /usr/share/perl5/Debconf/FrontEnd/Dialog.pm line 76, <> line 1.)\n",
|
| 764 |
+
"debconf: falling back to frontend: Readline\n",
|
| 765 |
+
"debconf: unable to initialize frontend: Readline\n",
|
| 766 |
+
"debconf: (This frontend requires a controlling tty.)\n",
|
| 767 |
+
"debconf: falling back to frontend: Teletype\n",
|
| 768 |
+
"dpkg-preconfigure: unable to re-open stdin: \n",
|
| 769 |
+
"Selecting previously unselected package git-lfs.\n",
|
| 770 |
+
"(Reading database ... 155226 files and directories currently installed.)\n",
|
| 771 |
+
"Preparing to unpack .../git-lfs_3.0.2_amd64.deb ...\n",
|
| 772 |
+
"Unpacking git-lfs (3.0.2) ...\n",
|
| 773 |
+
"Setting up git-lfs (3.0.2) ...\n",
|
| 774 |
+
"Git LFS initialized.\n",
|
| 775 |
+
"Processing triggers for man-db (2.8.3-2ubuntu0.1) ...\n",
|
| 776 |
+
"Git LFS initialized.\n"
|
| 777 |
+
]
|
| 778 |
+
}
|
| 779 |
+
]
|
| 780 |
+
},
|
| 781 |
+
{
|
| 782 |
+
"cell_type": "code",
|
| 783 |
+
"source": [
|
| 784 |
+
"!huggingface-cli login"
|
| 785 |
+
],
|
| 786 |
+
"metadata": {
|
| 787 |
+
"id": "mBqbC9OLBIzY",
|
| 788 |
+
"outputId": "e213d779-fd78-49d6-affd-cfb0869e5624",
|
| 789 |
+
"colab": {
|
| 790 |
+
"base_uri": "https://localhost:8080/"
|
| 791 |
+
}
|
| 792 |
+
},
|
| 793 |
+
"execution_count": 11,
|
| 794 |
+
"outputs": [
|
| 795 |
+
{
|
| 796 |
+
"output_type": "stream",
|
| 797 |
+
"name": "stdout",
|
| 798 |
+
"text": [
|
| 799 |
+
"\n",
|
| 800 |
+
" _| _| _| _| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _|_|_|_| _|_| _|_|_| _|_|_|_|\n",
|
| 801 |
+
" _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|\n",
|
| 802 |
+
" _|_|_|_| _| _| _| _|_| _| _|_| _| _| _| _| _| _|_| _|_|_| _|_|_|_| _| _|_|_|\n",
|
| 803 |
+
" _| _| _| _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|\n",
|
| 804 |
+
" _| _| _|_| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _| _| _| _|_|_| _|_|_|_|\n",
|
| 805 |
+
"\n",
|
| 806 |
+
" To login, `huggingface_hub` now requires a token generated from https://huggingface.co/settings/token.\n",
|
| 807 |
+
" (Deprecated, will be removed in v0.3.0) To login with username and password instead, interrupt with Ctrl+C.\n",
|
| 808 |
+
" \n",
|
| 809 |
+
"Token: \n",
|
| 810 |
+
"Login successful\n",
|
| 811 |
+
"Your token has been saved to /root/.huggingface/token\n",
|
| 812 |
+
"\u001b[1m\u001b[31mAuthenticated through git-credential store but this isn't the helper defined on your machine.\n",
|
| 813 |
+
"You might have to re-authenticate when pushing to the Hugging Face Hub. Run the following command in your terminal in case you want to set this credential helper as the default\n",
|
| 814 |
+
"\n",
|
| 815 |
+
"git config --global credential.helper store\u001b[0m\n"
|
| 816 |
+
]
|
| 817 |
+
}
|
| 818 |
+
]
|
| 819 |
+
},
|
| 820 |
+
{
|
| 821 |
+
"cell_type": "code",
|
| 822 |
+
"source": [
|
| 823 |
+
"\n",
|
| 824 |
+
"from huggingface_hub.keras_mixin import push_to_hub_keras\n",
|
| 825 |
+
"push_to_hub_keras(model = actor_model, repo_url = \"https://huggingface.co/keras-io/deep-deterministic-policy-gradient\", organization = \"keras-io\")"
|
| 826 |
+
],
|
| 827 |
+
"metadata": {
|
| 828 |
+
"id": "B6pop1vc_4yZ",
|
| 829 |
+
"outputId": "f1635dd0-ac6c-4375-8054-3c873f259ef5",
|
| 830 |
+
"colab": {
|
| 831 |
+
"base_uri": "https://localhost:8080/",
|
| 832 |
+
"height": 141
|
| 833 |
+
}
|
| 834 |
+
},
|
| 835 |
+
"execution_count": 12,
|
| 836 |
+
"outputs": [
|
| 837 |
+
{
|
| 838 |
+
"output_type": "stream",
|
| 839 |
+
"name": "stderr",
|
| 840 |
+
"text": [
|
| 841 |
+
"Cloning https://huggingface.co/keras-io/deep-deterministic-policy-gradient into local empty directory.\n"
|
| 842 |
+
]
|
| 843 |
+
},
|
| 844 |
+
{
|
| 845 |
+
"output_type": "stream",
|
| 846 |
+
"name": "stdout",
|
| 847 |
+
"text": [
|
| 848 |
+
"WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n",
|
| 849 |
+
"INFO:tensorflow:Assets written to: deep-deterministic-policy-gradient/assets\n"
|
| 850 |
+
]
|
| 851 |
+
},
|
| 852 |
+
{
|
| 853 |
+
"output_type": "stream",
|
| 854 |
+
"name": "stderr",
|
| 855 |
+
"text": [
|
| 856 |
+
"To https://huggingface.co/keras-io/deep-deterministic-policy-gradient\n",
|
| 857 |
+
" 0e015ab..e37f692 main -> main\n",
|
| 858 |
+
"\n"
|
| 859 |
+
]
|
| 860 |
+
},
|
| 861 |
+
{
|
| 862 |
+
"output_type": "execute_result",
|
| 863 |
+
"data": {
|
| 864 |
+
"application/vnd.google.colaboratory.intrinsic+json": {
|
| 865 |
+
"type": "string"
|
| 866 |
+
},
|
| 867 |
+
"text/plain": [
|
| 868 |
+
"'https://huggingface.co/keras-io/deep-deterministic-policy-gradient/commit/e37f69227324cae395ac0075b8bee416685d2c54'"
|
| 869 |
+
]
|
| 870 |
+
},
|
| 871 |
+
"metadata": {},
|
| 872 |
+
"execution_count": 12
|
| 873 |
+
}
|
| 874 |
+
]
|
| 875 |
+
},
|
| 876 |
+
{
|
| 877 |
+
"cell_type": "code",
|
| 878 |
+
"source": [
|
| 879 |
+
"push_to_hub_keras(model = critic_model, repo_url = \"https://huggingface.co/keras-io/deep-deterministic-policy-gradient\", organization = \"keras-io\")"
|
| 880 |
+
],
|
| 881 |
+
"metadata": {
|
| 882 |
+
"id": "89Cj-m50BQKv",
|
| 883 |
+
"outputId": "56899efb-3dda-4ca6-8656-f9060741b9b3",
|
| 884 |
+
"colab": {
|
| 885 |
+
"base_uri": "https://localhost:8080/",
|
| 886 |
+
"height": 161
|
| 887 |
+
}
|
| 888 |
+
},
|
| 889 |
+
"execution_count": 13,
|
| 890 |
+
"outputs": [
|
| 891 |
+
{
|
| 892 |
+
"output_type": "stream",
|
| 893 |
+
"name": "stderr",
|
| 894 |
+
"text": [
|
| 895 |
+
"/content/deep-deterministic-policy-gradient is already a clone of https://huggingface.co/keras-io/deep-deterministic-policy-gradient. Make sure you pull the latest changes with `repo.git_pull()`.\n"
|
| 896 |
+
]
|
| 897 |
+
},
|
| 898 |
+
{
|
| 899 |
+
"output_type": "stream",
|
| 900 |
+
"name": "stdout",
|
| 901 |
+
"text": [
|
| 902 |
+
"WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n",
|
| 903 |
+
"INFO:tensorflow:Assets written to: deep-deterministic-policy-gradient/assets\n"
|
| 904 |
+
]
|
| 905 |
+
},
|
| 906 |
+
{
|
| 907 |
+
"output_type": "stream",
|
| 908 |
+
"name": "stderr",
|
| 909 |
+
"text": [
|
| 910 |
+
"To https://huggingface.co/keras-io/deep-deterministic-policy-gradient\n",
|
| 911 |
+
" e37f692..fc4c3b0 main -> main\n",
|
| 912 |
+
"\n"
|
| 913 |
+
]
|
| 914 |
+
},
|
| 915 |
+
{
|
| 916 |
+
"output_type": "execute_result",
|
| 917 |
+
"data": {
|
| 918 |
+
"application/vnd.google.colaboratory.intrinsic+json": {
|
| 919 |
+
"type": "string"
|
| 920 |
+
},
|
| 921 |
+
"text/plain": [
|
| 922 |
+
"'https://huggingface.co/keras-io/deep-deterministic-policy-gradient/commit/fc4c3b0eadf2d9d2e6ff7a59f4e1f99763d973fe'"
|
| 923 |
+
]
|
| 924 |
+
},
|
| 925 |
+
"metadata": {},
|
| 926 |
+
"execution_count": 13
|
| 927 |
+
}
|
| 928 |
+
]
|
| 929 |
+
},
|
| 930 |
+
{
|
| 931 |
+
"cell_type": "code",
|
| 932 |
+
"source": [
|
| 933 |
+
"push_to_hub_keras(model = target_actor, repo_url = \"https://huggingface.co/keras-io/deep-deterministic-policy-gradient\", organization = \"keras-io\")"
|
| 934 |
+
],
|
| 935 |
+
"metadata": {
|
| 936 |
+
"id": "wv-epAixBYAJ",
|
| 937 |
+
"outputId": "9c62ad0a-1523-4ba3-ced9-b6d8d3e1cbdc",
|
| 938 |
+
"colab": {
|
| 939 |
+
"base_uri": "https://localhost:8080/",
|
| 940 |
+
"height": 161
|
| 941 |
+
}
|
| 942 |
+
},
|
| 943 |
+
"execution_count": 14,
|
| 944 |
+
"outputs": [
|
| 945 |
+
{
|
| 946 |
+
"output_type": "stream",
|
| 947 |
+
"name": "stderr",
|
| 948 |
+
"text": [
|
| 949 |
+
"/content/deep-deterministic-policy-gradient is already a clone of https://huggingface.co/keras-io/deep-deterministic-policy-gradient. Make sure you pull the latest changes with `repo.git_pull()`.\n"
|
| 950 |
+
]
|
| 951 |
+
},
|
| 952 |
+
{
|
| 953 |
+
"output_type": "stream",
|
| 954 |
+
"name": "stdout",
|
| 955 |
+
"text": [
|
| 956 |
+
"WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n",
|
| 957 |
+
"INFO:tensorflow:Assets written to: deep-deterministic-policy-gradient/assets\n"
|
| 958 |
+
]
|
| 959 |
+
},
|
| 960 |
+
{
|
| 961 |
+
"output_type": "stream",
|
| 962 |
+
"name": "stderr",
|
| 963 |
+
"text": [
|
| 964 |
+
"To https://huggingface.co/keras-io/deep-deterministic-policy-gradient\n",
|
| 965 |
+
" fc4c3b0..e34067a main -> main\n",
|
| 966 |
+
"\n"
|
| 967 |
+
]
|
| 968 |
+
},
|
| 969 |
+
{
|
| 970 |
+
"output_type": "execute_result",
|
| 971 |
+
"data": {
|
| 972 |
+
"application/vnd.google.colaboratory.intrinsic+json": {
|
| 973 |
+
"type": "string"
|
| 974 |
+
},
|
| 975 |
+
"text/plain": [
|
| 976 |
+
"'https://huggingface.co/keras-io/deep-deterministic-policy-gradient/commit/e34067a57c76c29bf60d924f352e7d72708bec82'"
|
| 977 |
+
]
|
| 978 |
+
},
|
| 979 |
+
"metadata": {},
|
| 980 |
+
"execution_count": 14
|
| 981 |
+
}
|
| 982 |
+
]
|
| 983 |
+
},
|
| 984 |
+
{
|
| 985 |
+
"cell_type": "code",
|
| 986 |
+
"source": [
|
| 987 |
+
"push_to_hub_keras(model = target_critic, repo_url = \"https://huggingface.co/keras-io/deep-deterministic-policy-gradient\", organization = \"keras-io\")"
|
| 988 |
+
],
|
| 989 |
+
"metadata": {
|
| 990 |
+
"id": "3LVvvq2hBcfv",
|
| 991 |
+
"outputId": "c9b37d03-3f98-46d3-ee91-37e271bb5fbe",
|
| 992 |
+
"colab": {
|
| 993 |
+
"base_uri": "https://localhost:8080/",
|
| 994 |
+
"height": 161
|
| 995 |
+
}
|
| 996 |
+
},
|
| 997 |
+
"execution_count": 15,
|
| 998 |
+
"outputs": [
|
| 999 |
+
{
|
| 1000 |
+
"output_type": "stream",
|
| 1001 |
+
"name": "stderr",
|
| 1002 |
+
"text": [
|
| 1003 |
+
"/content/deep-deterministic-policy-gradient is already a clone of https://huggingface.co/keras-io/deep-deterministic-policy-gradient. Make sure you pull the latest changes with `repo.git_pull()`.\n"
|
| 1004 |
+
]
|
| 1005 |
+
},
|
| 1006 |
+
{
|
| 1007 |
+
"output_type": "stream",
|
| 1008 |
+
"name": "stdout",
|
| 1009 |
+
"text": [
|
| 1010 |
+
"WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n",
|
| 1011 |
+
"INFO:tensorflow:Assets written to: deep-deterministic-policy-gradient/assets\n"
|
| 1012 |
+
]
|
| 1013 |
+
},
|
| 1014 |
+
{
|
| 1015 |
+
"output_type": "stream",
|
| 1016 |
+
"name": "stderr",
|
| 1017 |
+
"text": [
|
| 1018 |
+
"To https://huggingface.co/keras-io/deep-deterministic-policy-gradient\n",
|
| 1019 |
+
" e34067a..10b396f main -> main\n",
|
| 1020 |
+
"\n"
|
| 1021 |
+
]
|
| 1022 |
+
},
|
| 1023 |
+
{
|
| 1024 |
+
"output_type": "execute_result",
|
| 1025 |
+
"data": {
|
| 1026 |
+
"application/vnd.google.colaboratory.intrinsic+json": {
|
| 1027 |
+
"type": "string"
|
| 1028 |
+
},
|
| 1029 |
+
"text/plain": [
|
| 1030 |
+
"'https://huggingface.co/keras-io/deep-deterministic-policy-gradient/commit/10b396f3c297b2359d5b5e96f2b78a03943ec833'"
|
| 1031 |
+
]
|
| 1032 |
+
},
|
| 1033 |
+
"metadata": {},
|
| 1034 |
+
"execution_count": 15
|
| 1035 |
+
}
|
| 1036 |
+
]
|
| 1037 |
+
},
|
| 1038 |
+
{
|
| 1039 |
+
"cell_type": "code",
|
| 1040 |
+
"source": [
|
| 1041 |
+
""
|
| 1042 |
+
],
|
| 1043 |
+
"metadata": {
|
| 1044 |
+
"id": "yzwDvkqZBfFJ"
|
| 1045 |
+
},
|
| 1046 |
+
"execution_count": null,
|
| 1047 |
+
"outputs": []
|
| 1048 |
+
}
|
| 1049 |
+
],
|
| 1050 |
+
"metadata": {
|
| 1051 |
+
"colab": {
|
| 1052 |
+
"collapsed_sections": [],
|
| 1053 |
+
"name": "ddpg_pendulum",
|
| 1054 |
+
"provenance": [],
|
| 1055 |
+
"toc_visible": true
|
| 1056 |
+
},
|
| 1057 |
+
"kernelspec": {
|
| 1058 |
+
"display_name": "Python 3",
|
| 1059 |
+
"language": "python",
|
| 1060 |
+
"name": "python3"
|
| 1061 |
+
},
|
| 1062 |
+
"language_info": {
|
| 1063 |
+
"codemirror_mode": {
|
| 1064 |
+
"name": "ipython",
|
| 1065 |
+
"version": 3
|
| 1066 |
+
},
|
| 1067 |
+
"file_extension": ".py",
|
| 1068 |
+
"mimetype": "text/x-python",
|
| 1069 |
+
"name": "python",
|
| 1070 |
+
"nbconvert_exporter": "python",
|
| 1071 |
+
"pygments_lexer": "ipython3",
|
| 1072 |
+
"version": "3.7.0"
|
| 1073 |
+
}
|
| 1074 |
+
},
|
| 1075 |
+
"nbformat": 4,
|
| 1076 |
+
"nbformat_minor": 0
|
| 1077 |
+
}
|