Stoub committed on
Commit 94b78bf · verified · 1 Parent(s): 398bbdb

Upload folder using huggingface_hub

.ipynb_checkpoints/requirements-unit4-checkpoint.txt ADDED
@@ -0,0 +1,5 @@
1
+ # git+https://github.com/ntasfi/PyGame-Learning-Environment.git
2
+ # git+https://github.com/simoninithomas/gym-games
3
+ huggingface_hub
4
+ imageio-ffmpeg
5
+ # pyyaml==6.0
.ipynb_checkpoints/unit4-checkpoint.ipynb ADDED
@@ -0,0 +1,1786 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "colab_type": "text",
7
+ "id": "view-in-github"
8
+ },
9
+ "source": [
10
+ "<a href=\"https://colab.research.google.com/github/huggingface/deep-rl-class/blob/GymnasiumUpdate%2FUnit4/notebooks/unit4/unit4.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "markdown",
15
+ "metadata": {
16
+ "id": "CjRWziAVU2lZ"
17
+ },
18
+ "source": [
19
+ "# Unit 4: Code your first Deep Reinforcement Learning Algorithm with PyTorch: Reinforce. And test its robustness 💪\n",
20
+ "\n",
21
+ "<img src=\"https://huggingface.co/datasets/huggingface-deep-rl-course/course-images/resolve/main/en/unit6/thumbnail.png\" alt=\"thumbnail\"/>\n",
22
+ "\n",
23
+ "\n",
24
+ "In this notebook, you'll code your first Deep Reinforcement Learning algorithm from scratch: Reinforce (also called Monte Carlo Policy Gradient).\n",
25
+ "\n",
26
+ "Reinforce is a *Policy-based method*: a Deep Reinforcement Learning algorithm that tries **to optimize the policy directly without using an action-value function**.\n",
27
+ "\n",
28
+ "More precisely, Reinforce is a *Policy-gradient method*, a subclass of *Policy-based methods* that aims **to optimize the policy directly by estimating the weights of the optimal policy using gradient ascent**.\n",
29
+ "\n",
30
+ "To test its robustness, we're going to train it in 2 different simple environments:\n",
31
+ "- CartPole-v1\n",
32
+ "- PixelcopterEnv\n",
33
+ "\n",
34
+ "⬇️ Here is an example of what **you will achieve at the end of this notebook.** ⬇️"
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "markdown",
39
+ "metadata": {
40
+ "id": "s4rBom2sbo7S"
41
+ },
42
+ "source": [
43
+ " <img src=\"https://huggingface.co/datasets/huggingface-deep-rl-course/course-images/resolve/main/en/unit6/envs.gif\" alt=\"Environments\"/>\n"
44
+ ]
45
+ },
46
+ {
47
+ "cell_type": "markdown",
48
+ "metadata": {
49
+ "id": "BPLwsPajb1f8"
50
+ },
51
+ "source": [
52
+ "### 🎮 Environments: \n",
53
+ "\n",
54
+ "- [CartPole-v1](https://www.gymlibrary.dev/environments/classic_control/cart_pole/)\n",
55
+ "- [PixelCopter](https://pygame-learning-environment.readthedocs.io/en/latest/user/games/pixelcopter.html)\n",
56
+ "\n",
57
+ "### 📚 RL-Library: \n",
58
+ "\n",
59
+ "- Python\n",
60
+ "- PyTorch\n",
61
+ "\n",
62
+ "\n",
63
+ "We're constantly trying to improve our tutorials, so **if you find some issues in this notebook**, please [open an issue on the GitHub Repo](https://github.com/huggingface/deep-rl-class/issues)."
64
+ ]
65
+ },
66
+ {
67
+ "cell_type": "markdown",
68
+ "metadata": {
69
+ "id": "L_WSo0VUV99t"
70
+ },
71
+ "source": [
72
+ "## Objectives of this notebook 🏆\n",
73
+ "At the end of the notebook, you will:\n",
74
+ "- Be able to **code from scratch a Reinforce algorithm using PyTorch.**\n",
75
+ "- Be able to **test the robustness of your agent using simple environments.**\n",
76
+ "- Be able to **push your trained agent to the Hub** with a nice video replay and an evaluation score 🔥."
77
+ ]
78
+ },
79
+ {
80
+ "cell_type": "markdown",
81
+ "metadata": {
82
+ "id": "lEPrZg2eWa4R"
83
+ },
84
+ "source": [
85
+ "## This notebook is from the Deep Reinforcement Learning Course\n",
86
+ "<img src=\"https://huggingface.co/datasets/huggingface-deep-rl-course/course-images/resolve/main/en/notebooks/deep-rl-course-illustration.jpg\" alt=\"Deep RL Course illustration\"/>"
87
+ ]
88
+ },
89
+ {
90
+ "cell_type": "markdown",
91
+ "metadata": {
92
+ "id": "6p5HnEefISCB"
93
+ },
94
+ "source": [
95
+ "In this free course, you will:\n",
96
+ "\n",
97
+ "- 📖 Study Deep Reinforcement Learning in **theory and practice**.\n",
98
+ "- 🧑‍💻 Learn to **use famous Deep RL libraries** such as Stable Baselines3, RL Baselines3 Zoo, CleanRL and Sample Factory 2.0.\n",
99
+ "- 🤖 Train **agents in unique environments** \n",
100
+ "\n",
101
+ "And more! Check 📚 the syllabus 👉 https://simoninithomas.github.io/deep-rl-course\n",
102
+ "\n",
103
+ "Don’t forget to **<a href=\"http://eepurl.com/ic5ZUD\">sign up to the course</a>** (we are collecting your email to be able to **send you the links when each Unit is published and give you information about the challenges and updates).**\n",
104
+ "\n",
105
+ "\n",
106
+ "The best way to keep in touch is to join our discord server to exchange with the community and with us 👉🏻 https://discord.gg/ydHrjt3WP5"
107
+ ]
108
+ },
109
+ {
110
+ "cell_type": "markdown",
111
+ "metadata": {
112
+ "id": "mjY-eq3eWh9O"
113
+ },
114
+ "source": [
115
+ "## Prerequisites 🏗️\n",
116
+ "Before diving into the notebook, you need to:\n",
117
+ "\n",
118
+ "🔲 📚 [Study Policy Gradients by reading Unit 4](https://huggingface.co/deep-rl-course/unit4/introduction)"
119
+ ]
120
+ },
121
+ {
122
+ "cell_type": "markdown",
123
+ "metadata": {
124
+ "id": "Bsh4ZAamchSl"
125
+ },
126
+ "source": [
127
+ "# Let's code Reinforce algorithm from scratch 🔥\n",
128
+ "\n",
129
+ "\n",
130
+ "To validate this hands-on for the certification process, you need to push your trained models to the Hub.\n",
131
+ "\n",
132
+ "- Get a result of >= 350 for `CartPole-v1`.\n",
133
+ "- Get a result of >= 5 for `PixelCopter`.\n",
134
+ "\n",
135
+ "To find your result, go to the leaderboard and find your model; **the result = mean_reward - std of reward**. **If you don't see your model on the leaderboard, go to the bottom of the leaderboard page and click on the refresh button**.\n",
136
+ "\n",
137
+ "For more information about the certification process, check this section 👉 https://huggingface.co/deep-rl-course/en/unit0/introduction#certification-process\n"
138
+ ]
139
+ },
140
+ {
141
+ "cell_type": "markdown",
142
+ "metadata": {
143
+ "id": "JoTC9o2SczNn",
144
+ "jp-MarkdownHeadingCollapsed": true
145
+ },
146
+ "source": [
147
+ "## A Colab advice 💡"
148
+ ]
149
+ },
150
+ {
151
+ "cell_type": "markdown",
152
+ "metadata": {},
153
+ "source": [
154
+ "It's better to run this Colab from a copy on your Google Drive, so that **if it times out** you still have the saved notebook on your Google Drive and do not need to start everything from scratch.\n",
155
+ "\n",
156
+ "To do that you can either do `Ctrl + S` or `File > Save a copy in Google Drive.`"
157
+ ]
158
+ },
159
+ {
160
+ "cell_type": "markdown",
161
+ "metadata": {
162
+ "id": "PU4FVzaoM6fC",
163
+ "jp-MarkdownHeadingCollapsed": true
164
+ },
165
+ "source": [
166
+ "## Set the GPU 💪"
167
+ ]
168
+ },
169
+ {
170
+ "cell_type": "markdown",
171
+ "metadata": {},
172
+ "source": [
173
+ "- To **accelerate the agent's training, we'll use a GPU**. To do that, go to `Runtime > Change Runtime type`\n",
174
+ "\n",
175
+ "<img src=\"https://huggingface.co/datasets/huggingface-deep-rl-course/course-images/resolve/main/en/notebooks/gpu-step1.jpg\" alt=\"GPU Step 1\">"
176
+ ]
177
+ },
178
+ {
179
+ "cell_type": "markdown",
180
+ "metadata": {
181
+ "id": "KV0NyFdQM9ZG"
182
+ },
183
+ "source": [
184
+ "- `Hardware Accelerator > GPU`\n",
185
+ "\n",
186
+ "<img src=\"https://huggingface.co/datasets/huggingface-deep-rl-course/course-images/resolve/main/en/notebooks/gpu-step2.jpg\" alt=\"GPU Step 2\">"
187
+ ]
188
+ },
189
+ {
190
+ "cell_type": "markdown",
191
+ "metadata": {
192
+ "id": "bTpYcVZVMzUI"
193
+ },
194
+ "source": [
195
+ "## Create a virtual display 🖥\n",
196
+ "\n",
197
+ "During the notebook, we'll need to generate a replay video. To do so, with colab, **we need to have a virtual screen to be able to render the environment** (and thus record the frames). \n",
198
+ "\n",
199
+ "Hence, the following cell will install the libraries and create and run a virtual screen 🖥"
200
+ ]
201
+ },
202
+ {
203
+ "cell_type": "code",
204
+ "execution_count": 35,
205
+ "metadata": {
206
+ "id": "jV6wjQ7Be7p5",
207
+ "scrolled": true
208
+ },
209
+ "outputs": [
210
+ {
211
+ "name": "stdout",
212
+ "output_type": "stream",
213
+ "text": [
214
+ "Requirement already satisfied: pyvirtualdisplay in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (3.0)\n",
215
+ "Requirement already satisfied: pyglet==1.5.1 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (1.5.1)\n",
216
+ "Requirement already satisfied: huggingface_hub in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (0.25.1)\n",
217
+ "Requirement already satisfied: filelock in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from huggingface_hub) (3.13.1)\n",
218
+ "Requirement already satisfied: fsspec>=2023.5.0 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from huggingface_hub) (2024.2.0)\n",
219
+ "Requirement already satisfied: packaging>=20.9 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from huggingface_hub) (24.1)\n",
220
+ "Requirement already satisfied: pyyaml>=5.1 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from huggingface_hub) (6.0.1)\n",
221
+ "Requirement already satisfied: requests in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from huggingface_hub) (2.32.3)\n",
222
+ "Requirement already satisfied: tqdm>=4.42.1 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from huggingface_hub) (4.66.5)\n",
223
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from huggingface_hub) (4.11.0)\n",
224
+ "Requirement already satisfied: colorama in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from tqdm>=4.42.1->huggingface_hub) (0.4.6)\n",
225
+ "Requirement already satisfied: charset-normalizer<4,>=2 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from requests->huggingface_hub) (3.3.2)\n",
226
+ "Requirement already satisfied: idna<4,>=2.5 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from requests->huggingface_hub) (3.7)\n",
227
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from requests->huggingface_hub) (2.2.2)\n",
228
+ "Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from requests->huggingface_hub) (2024.8.30)\n",
229
+ "Requirement already satisfied: gym in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (0.26.2)\n",
230
+ "Requirement already satisfied: numpy>=1.18.0 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from gym) (1.26.4)\n",
231
+ "Requirement already satisfied: cloudpickle>=1.2.0 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from gym) (3.0.0)\n",
232
+ "Requirement already satisfied: gym-notices>=0.0.4 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from gym) (0.0.8)\n",
233
+ "Requirement already satisfied: imageio[ffmpeg] in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (2.35.1)\n",
234
+ "Requirement already satisfied: numpy in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from imageio[ffmpeg]) (1.26.4)\n",
235
+ "Requirement already satisfied: pillow>=8.3.2 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from imageio[ffmpeg]) (10.2.0)\n",
236
+ "Requirement already satisfied: imageio-ffmpeg in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from imageio[ffmpeg]) (0.5.1)\n",
237
+ "Requirement already satisfied: psutil in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from imageio[ffmpeg]) (5.9.0)\n",
238
+ "Requirement already satisfied: setuptools in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from imageio-ffmpeg->imageio[ffmpeg]) (75.1.0)\n",
239
+ "Collecting moviepy\n",
240
+ " Downloading moviepy-1.0.3.tar.gz (388 kB)\n",
241
+ " Preparing metadata (setup.py): started\n",
242
+ " Preparing metadata (setup.py): finished with status 'done'\n",
243
+ "Collecting decorator<5.0,>=4.0.2 (from moviepy)\n",
244
+ " Downloading decorator-4.4.2-py2.py3-none-any.whl.metadata (4.2 kB)\n",
245
+ "Requirement already satisfied: imageio<3.0,>=2.5 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from moviepy) (2.35.1)\n",
246
+ "Requirement already satisfied: imageio_ffmpeg>=0.2.0 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from moviepy) (0.5.1)\n",
247
+ "Requirement already satisfied: tqdm<5.0,>=4.11.2 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from moviepy) (4.66.5)\n",
248
+ "Requirement already satisfied: numpy>=1.17.3 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from moviepy) (1.26.4)\n",
249
+ "Requirement already satisfied: requests<3.0,>=2.8.1 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from moviepy) (2.32.3)\n",
250
+ "Collecting proglog<=1.0.0 (from moviepy)\n",
251
+ " Downloading proglog-0.1.10-py3-none-any.whl.metadata (639 bytes)\n",
252
+ "Requirement already satisfied: pillow>=8.3.2 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from imageio<3.0,>=2.5->moviepy) (10.2.0)\n",
253
+ "Requirement already satisfied: setuptools in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from imageio_ffmpeg>=0.2.0->moviepy) (75.1.0)\n",
254
+ "Requirement already satisfied: charset-normalizer<4,>=2 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from requests<3.0,>=2.8.1->moviepy) (3.3.2)\n",
255
+ "Requirement already satisfied: idna<4,>=2.5 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from requests<3.0,>=2.8.1->moviepy) (3.7)\n",
256
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from requests<3.0,>=2.8.1->moviepy) (2.2.2)\n",
257
+ "Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from requests<3.0,>=2.8.1->moviepy) (2024.8.30)\n",
258
+ "Requirement already satisfied: colorama in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from tqdm<5.0,>=4.11.2->moviepy) (0.4.6)\n",
259
+ "Downloading decorator-4.4.2-py2.py3-none-any.whl (9.2 kB)\n",
260
+ "Downloading proglog-0.1.10-py3-none-any.whl (6.1 kB)\n",
261
+ "Building wheels for collected packages: moviepy\n",
262
+ " Building wheel for moviepy (setup.py): started\n",
263
+ " Building wheel for moviepy (setup.py): finished with status 'done'\n",
264
+ " Created wheel for moviepy: filename=moviepy-1.0.3-py3-none-any.whl size=110755 sha256=ac8e10668e9a35dafcbf1df54a5f1f5baf0ad0efea2c754e1cfef089a99b6eb0\n",
265
+ " Stored in directory: c:\\users\\utilisateur\\appdata\\local\\pip\\cache\\wheels\\df\\ba\\4b\\0917fc0c8833c8ba7016565fc975b74c67bc8610806e930272\n",
266
+ "Successfully built moviepy\n",
267
+ "Installing collected packages: decorator, proglog, moviepy\n",
268
+ " Attempting uninstall: decorator\n",
269
+ " Found existing installation: decorator 5.1.1\n",
270
+ " Uninstalling decorator-5.1.1:\n",
271
+ " Successfully uninstalled decorator-5.1.1\n",
272
+ "Successfully installed decorator-4.4.2 moviepy-1.0.3 proglog-0.1.10\n"
273
+ ]
274
+ }
275
+ ],
276
+ "source": [
277
+ "!pip install pyvirtualdisplay\n",
278
+ "!pip install pyglet==1.5.1\n",
279
+ "!pip install huggingface_hub\n",
280
+ "!pip install gym --upgrade\n",
281
+ "!pip install imageio[ffmpeg]\n",
282
+ "!pip install moviepy"
283
+ ]
284
+ },
285
+ {
286
+ "cell_type": "markdown",
287
+ "metadata": {
288
+ "id": "tjrLfPFIW8XK"
289
+ },
290
+ "source": [
291
+ "## Install the dependencies 🔽\n",
292
+ "The first step is to install the dependencies. We’ll install multiple ones:\n",
293
+ "\n",
294
+ "- `gym`\n",
295
+ "- `gym-games`: Extra gym environments made with PyGame.\n",
296
+ "- `huggingface_hub`: 🤗 works as a central place where anyone can share and explore models and datasets. It has versioning, metrics, visualizations, and other features that will allow you to easily collaborate with others.\n",
297
+ "\n",
298
+ "You may be wondering why we install gym and not gymnasium, a more recent version of gym. **It's because the gym-games we are using are not updated yet for gymnasium**.\n",
299
+ "\n",
300
+ "The API differences to be aware of:\n",
301
+ "- In older `gym` versions (< 0.26) there were no `terminated` and `truncated`, only a single `done` flag, and `env.step()` returned `state, reward, done, info`.\n",
302
+ "- Since `gym` 0.26 (the version installed here), the API matches Gymnasium: `env.step()` returns `state, reward, terminated, truncated, info` and `env.reset()` returns `state, info`, which is what the code in this notebook uses (see the short sketch below).\n",
303
+ "\n",
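+ "A minimal sketch of the two step APIs (an added illustration, not from the original course text), using the `CartPole-v1` environment we create later in this notebook:\n",
+ "\n",
+ "```python\n",
+ "import gym\n",
+ "\n",
+ "env = gym.make(\"CartPole-v1\")\n",
+ "\n",
+ "# gym >= 0.26 / Gymnasium API (what this notebook's code uses)\n",
+ "state, info = env.reset()\n",
+ "action = env.action_space.sample()\n",
+ "state, reward, terminated, truncated, info = env.step(action)\n",
+ "done = terminated or truncated\n",
+ "\n",
+ "# Older gym (< 0.26) returned a 4-tuple instead:\n",
+ "# state, reward, done, info = env.step(action)\n",
+ "```\n",
+ "\n",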
304
+ "You can learn more about the differences between Gym and Gymnasium here 👉 https://gymnasium.farama.org/content/migration-guide/\n",
305
+ "\n",
306
+ "\n",
307
+ "You can see here all the Reinforce models available 👉 https://huggingface.co/models?other=reinforce\n",
308
+ "\n",
309
+ "And you can find all the Deep Reinforcement Learning models here 👉 https://huggingface.co/models?pipeline_tag=reinforcement-learning\n"
310
+ ]
311
+ },
312
+ {
313
+ "cell_type": "markdown",
314
+ "metadata": {
315
+ "id": "AAHAq6RZW3rn"
316
+ },
317
+ "source": [
318
+ "## Import the packages 📦\n",
319
+ "In addition to importing the installed libraries, we also import:\n",
320
+ "\n",
321
+ "- `imageio`: A library that will help us to generate a replay video\n",
322
+ "\n"
323
+ ]
324
+ },
325
+ {
326
+ "cell_type": "code",
327
+ "execution_count": 2,
328
+ "metadata": {
329
+ "id": "V8oadoJSWp7C"
330
+ },
331
+ "outputs": [],
332
+ "source": [
333
+ "import numpy as np\n",
334
+ "\n",
335
+ "from collections import deque\n",
336
+ "\n",
337
+ "import matplotlib.pyplot as plt\n",
338
+ "%matplotlib inline\n",
339
+ "\n",
340
+ "# PyTorch\n",
341
+ "import torch\n",
342
+ "import torch.nn as nn\n",
343
+ "import torch.nn.functional as F\n",
344
+ "import torch.optim as optim\n",
345
+ "from torch.distributions import Categorical\n",
346
+ "\n",
347
+ "# Gym\n",
348
+ "import gym\n",
349
+ "import gym_pygame\n",
350
+ "\n",
351
+ "# Hugging Face Hub\n",
352
+ "from huggingface_hub import notebook_login # To log to our Hugging Face account to be able to upload models to the Hub.\n",
353
+ "import imageio"
354
+ ]
355
+ },
356
+ {
357
+ "cell_type": "markdown",
358
+ "metadata": {
359
+ "id": "RfxJYdMeeVgv"
360
+ },
361
+ "source": [
362
+ "## Check if we have a GPU\n",
363
+ "\n",
364
+ "- Let's check if we have a GPU\n",
365
+ "- If that's the case, you should see `cuda:0`"
366
+ ]
367
+ },
368
+ {
369
+ "cell_type": "code",
370
+ "execution_count": 3,
371
+ "metadata": {
372
+ "id": "kaJu5FeZxXGY"
373
+ },
374
+ "outputs": [],
375
+ "source": [
376
+ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
377
+ ]
378
+ },
379
+ {
380
+ "cell_type": "code",
381
+ "execution_count": 4,
382
+ "metadata": {
383
+ "id": "U5TNYa14aRav"
384
+ },
385
+ "outputs": [
386
+ {
387
+ "name": "stdout",
388
+ "output_type": "stream",
389
+ "text": [
390
+ "cuda:0\n"
391
+ ]
392
+ }
393
+ ],
394
+ "source": [
395
+ "print(device)"
396
+ ]
397
+ },
398
+ {
399
+ "cell_type": "markdown",
400
+ "metadata": {
401
+ "id": "PBPecCtBL_pZ"
402
+ },
403
+ "source": [
404
+ "We're now ready to implement our Reinforce algorithm 🔥"
405
+ ]
406
+ },
407
+ {
408
+ "cell_type": "markdown",
409
+ "metadata": {
410
+ "id": "8KEyKYo2ZSC-"
411
+ },
412
+ "source": [
413
+ "# First agent: Playing CartPole-v1 🤖"
414
+ ]
415
+ },
416
+ {
417
+ "cell_type": "markdown",
418
+ "metadata": {
419
+ "id": "haLArKURMyuF"
420
+ },
421
+ "source": [
422
+ "## Create the CartPole environment and understand how it works\n",
423
+ "### [The environment 🎮](https://www.gymlibrary.dev/environments/classic_control/cart_pole/)\n"
424
+ ]
425
+ },
426
+ {
427
+ "cell_type": "markdown",
428
+ "metadata": {
429
+ "id": "AH_TaLKFXo_8"
430
+ },
431
+ "source": [
432
+ "### Why do we use a simple environment like CartPole-v1?\n",
433
+ "As explained in [Reinforcement Learning Tips and Tricks](https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html), when you implement your agent from scratch you need **to be sure that it works correctly and to find bugs with easy environments before going deeper**, since finding bugs will be much easier in simple environments.\n",
434
+ "\n",
435
+ "\n",
436
+ "> Try to have some “sign of life” on toy problems\n",
437
+ "\n",
438
+ "\n",
439
+ "> Validate the implementation by making it run on harder and harder envs (you can compare results against the RL zoo). You usually need to run hyperparameter optimization for that step.\n",
440
+ "___\n",
441
+ "### The CartPole-v1 environment\n",
442
+ "\n",
443
+ "> A pole is attached by an un-actuated joint to a cart, which moves along a frictionless track. The pendulum is placed upright on the cart and the goal is to balance the pole by applying forces in the left and right direction on the cart.\n",
444
+ "\n",
445
+ "\n",
446
+ "\n",
447
+ "So, we start with CartPole-v1. The goal is to push the cart left or right **so that the pole stays in equilibrium.**\n",
448
+ "\n",
449
+ "The episode ends if:\n",
450
+ "- The pole Angle is greater than ±12°\n",
451
+ "- Cart Position is greater than ±2.4\n",
452
+ "- Episode length is greater than 500\n",
453
+ "\n",
454
+ "We get a reward 💰 of +1 for every timestep the pole stays in equilibrium."
455
+ ]
456
+ },
457
+ {
458
+ "cell_type": "code",
459
+ "execution_count": 37,
460
+ "metadata": {
461
+ "id": "POOOk15_K6KA"
462
+ },
463
+ "outputs": [],
464
+ "source": [
465
+ "env_id = \"CartPole-v1\"\n",
466
+ "# Create the env\n",
467
+ "env = gym.make(env_id)\n",
468
+ "\n",
469
+ "# Create the evaluation env\n",
470
+ "eval_env = gym.make(\"CartPole-v1\", render_mode=\"rgb_array\")\n",
471
+ "\n",
472
+ "# Get the state space and action space\n",
473
+ "s_size = env.observation_space.shape[0]\n",
474
+ "a_size = env.action_space.n"
475
+ ]
476
+ },
477
+ {
478
+ "cell_type": "code",
479
+ "execution_count": 6,
480
+ "metadata": {
481
+ "id": "FMLFrjiBNLYJ"
482
+ },
483
+ "outputs": [
484
+ {
485
+ "name": "stdout",
486
+ "output_type": "stream",
487
+ "text": [
488
+ "_____OBSERVATION SPACE_____ \n",
489
+ "\n",
490
+ "The State Space is: 4\n",
491
+ "Sample observation [ 1.6458762e+00 3.1792375e+38 3.1316426e-01 -1.6702383e+38]\n"
492
+ ]
493
+ }
494
+ ],
495
+ "source": [
496
+ "print(\"_____OBSERVATION SPACE_____ \\n\")\n",
497
+ "print(\"The State Space is: \", s_size)\n",
498
+ "print(\"Sample observation\", env.observation_space.sample()) # Get a random observation"
499
+ ]
500
+ },
501
+ {
502
+ "cell_type": "code",
503
+ "execution_count": 7,
504
+ "metadata": {
505
+ "id": "Lu6t4sRNNWkN"
506
+ },
507
+ "outputs": [
508
+ {
509
+ "name": "stdout",
510
+ "output_type": "stream",
511
+ "text": [
512
+ "\n",
513
+ " _____ACTION SPACE_____ \n",
514
+ "\n",
515
+ "The Action Space is: 2\n",
516
+ "Action Space Sample 1\n"
517
+ ]
518
+ }
519
+ ],
520
+ "source": [
521
+ "print(\"\\n _____ACTION SPACE_____ \\n\")\n",
522
+ "print(\"The Action Space is: \", a_size)\n",
523
+ "print(\"Action Space Sample\", env.action_space.sample()) # Take a random action"
524
+ ]
525
+ },
526
+ {
527
+ "cell_type": "markdown",
528
+ "metadata": {
529
+ "id": "7SJMJj3WaFOz"
530
+ },
531
+ "source": [
532
+ "## Let's build the Reinforce Architecture\n",
533
+ "This implementation is based on three implementations:\n",
534
+ "- [PyTorch official Reinforcement Learning example](https://github.com/pytorch/examples/blob/main/reinforcement_learning/reinforce.py)\n",
535
+ "- [Udacity Reinforce](https://github.com/udacity/deep-reinforcement-learning/blob/master/reinforce/REINFORCE.ipynb)\n",
536
+ "- [Improvement of the integration by Chris1nexus](https://github.com/huggingface/deep-rl-class/pull/95)\n",
537
+ "\n",
538
+ "<img src=\"https://huggingface.co/datasets/huggingface-deep-rl-course/course-images/resolve/main/en/unit6/reinforce.png\" alt=\"Reinforce\"/>"
539
+ ]
540
+ },
541
+ {
542
+ "cell_type": "markdown",
543
+ "metadata": {
544
+ "id": "49kogtxBODX8"
545
+ },
546
+ "source": [
547
+ "So we want:\n",
548
+ "- Two fully connected layers (fc1 and fc2).\n",
549
+ "- Using ReLU as activation function of fc1\n",
550
+ "- Using Softmax to output a probability distribution over actions"
551
+ ]
552
+ },
553
+ {
554
+ "cell_type": "code",
555
+ "execution_count": 8,
556
+ "metadata": {
557
+ "id": "w2LHcHhVZvPZ"
558
+ },
559
+ "outputs": [],
560
+ "source": [
561
+ "class Policy(nn.Module):\n",
562
+ " def __init__(self, s_size, a_size, h_size):\n",
563
+ " super(Policy, self).__init__()\n",
564
+ " # Create two fully connected layers\n",
565
+ " self.fc1 = nn.Linear(s_size,h_size)\n",
566
+ " self.fc2 = nn.Linear(h_size,a_size)\n",
567
+ "\n",
568
+ " def forward(self, x):\n",
569
+ " # Define the forward pass\n",
570
+ " # state goes to fc1 then we apply ReLU activation function\n",
571
+ " x = self.fc1(x)\n",
572
+ " x = F.relu(x)\n",
573
+ " # fc1 outputs goes to fc2\n",
574
+ " x = self.fc2(x)\n",
575
+ " # We output the softmax\n",
576
+ " x = F.softmax(x, dim=1)\n",
577
+ " return(x)\n",
578
+ " \n",
579
+ " def act(self, state):\n",
580
+ " \"\"\"\n",
581
+ " Given a state, take action\n",
582
+ " \"\"\"\n",
583
+ " state = torch.from_numpy(state).float().unsqueeze(0).to(device)\n",
584
+ " probs = self.forward(state).cpu()\n",
585
+ " m = Categorical(probs)\n",
586
+ " action = m.sample()\n",
587
+ " return action.item(), m.log_prob(action)"
588
+ ]
589
+ },
590
+ {
591
+ "cell_type": "markdown",
592
+ "metadata": {
593
+ "id": "rOMrdwSYOWSC",
594
+ "jp-MarkdownHeadingCollapsed": true
595
+ },
596
+ "source": [
597
+ "### Solution"
598
+ ]
599
+ },
600
+ {
601
+ "cell_type": "code",
602
+ "execution_count": 81,
603
+ "metadata": {
604
+ "id": "jGdhRSVrOV4K"
605
+ },
606
+ "outputs": [],
607
+ "source": [
608
+ "class Policy(nn.Module):\n",
609
+ " def __init__(self, s_size, a_size, h_size):\n",
610
+ " super(Policy, self).__init__()\n",
611
+ " self.fc1 = nn.Linear(s_size, h_size)\n",
612
+ " self.fc2 = nn.Linear(h_size, a_size)\n",
613
+ "\n",
614
+ " def forward(self, x):\n",
615
+ " x = F.relu(self.fc1(x))\n",
616
+ " x = self.fc2(x)\n",
617
+ " return F.softmax(x, dim=1)\n",
618
+ " \n",
619
+ " def act(self, state):\n",
620
+ " # if isinstance(state, tuple):\n",
621
+ " # state = state[0]\n",
622
+ " state = torch.from_numpy(state).float().unsqueeze(0).to(device)\n",
623
+ " probs = self.forward(state).cpu()\n",
624
+ " m = Categorical(probs)\n",
625
+ " action = np.argmax(m)\n",
626
+ " return action.item(), m.log_prob(action)"
627
+ ]
628
+ },
629
+ {
630
+ "cell_type": "markdown",
631
+ "metadata": {
632
+ "id": "ZTGWL4g2eM5B"
633
+ },
634
+ "source": [
635
+ "I made a mistake, can you guess where?\n",
636
+ "\n",
637
+ "- To find out let's make a forward pass:"
638
+ ]
639
+ },
640
+ {
641
+ "cell_type": "code",
642
+ "execution_count": 84,
643
+ "metadata": {
644
+ "id": "lwnqGBCNePor"
645
+ },
646
+ "outputs": [
647
+ {
648
+ "data": {
649
+ "text/plain": [
650
+ "(1, tensor([-0.6791], grad_fn=<SqueezeBackward1>))"
651
+ ]
652
+ },
653
+ "execution_count": 84,
654
+ "metadata": {},
655
+ "output_type": "execute_result"
656
+ }
657
+ ],
658
+ "source": [
659
+ "debug_policy = Policy(s_size, a_size, 64).to(device)\n",
660
+ "state, _ = env.reset()\n",
661
+ "debug_policy.act(state)"
662
+ ]
663
+ },
664
+ {
665
+ "cell_type": "markdown",
666
+ "metadata": {
667
+ "id": "14UYkoxCPaor"
668
+ },
669
+ "source": [
670
+ "- Here we see that the error says `ValueError: The value argument to log_prob must be a Tensor`\n",
671
+ "\n",
672
+ "- It means that `action` in `m.log_prob(action)` must be a Tensor **but it's not.**\n",
673
+ "\n",
674
+ "- Do you know why? Check the act function and try to see why it does not work. \n",
675
+ "\n",
676
+ "Advice 💡: Something is wrong in this implementation. Remember that in the `act` function **we want to sample an action from the probability distribution over actions**.\n"
677
+ ]
678
+ },
679
+ {
680
+ "cell_type": "markdown",
681
+ "metadata": {
682
+ "id": "gfGJNZBUP7Vn",
683
+ "jp-MarkdownHeadingCollapsed": true
684
+ },
685
+ "source": [
686
+ "### (Real) Solution"
687
+ ]
688
+ },
689
+ {
690
+ "cell_type": "code",
691
+ "execution_count": 83,
692
+ "metadata": {
693
+ "id": "Ho_UHf49N9i4"
694
+ },
695
+ "outputs": [],
696
+ "source": [
697
+ "class Policy(nn.Module):\n",
698
+ " def __init__(self, s_size, a_size, h_size):\n",
699
+ " super(Policy, self).__init__()\n",
700
+ " self.fc1 = nn.Linear(s_size, h_size)\n",
701
+ " self.fc2 = nn.Linear(h_size, a_size)\n",
702
+ "\n",
703
+ " def forward(self, x):\n",
704
+ " x = F.relu(self.fc1(x))\n",
705
+ " x = self.fc2(x)\n",
706
+ " return F.softmax(x, dim=1)\n",
707
+ " \n",
708
+ " def act(self, state):\n",
709
+ " state = torch.from_numpy(state).float().unsqueeze(0).to(device)\n",
710
+ " probs = self.forward(state).cpu()\n",
711
+ " m = Categorical(probs)\n",
712
+ " action = m.sample()\n",
713
+ " return action.item(), m.log_prob(action)"
714
+ ]
715
+ },
716
+ {
717
+ "cell_type": "markdown",
718
+ "metadata": {
719
+ "id": "rgJWQFU_eUYw"
720
+ },
721
+ "source": [
722
+ "By using CartPole, it was easier to debug since **we know that the bug comes from our integration and not from our simple environment**."
723
+ ]
724
+ },
725
+ {
726
+ "cell_type": "markdown",
727
+ "metadata": {
728
+ "id": "c-20i7Pk0l1T"
729
+ },
730
+ "source": [
731
+ "- Since **we want to sample an action from the probability distribution over actions**, we can't use `action = np.argmax(m)`, as it will always output the action that has the highest probability.\n",
732
+ "\n",
733
+ "- We need to replace it with `action = m.sample()`, which will sample an action from the probability distribution P(.|s)"
734
+ ]
735
+ },
736
+ {
737
+ "cell_type": "markdown",
738
+ "metadata": {
739
+ "id": "4MXoqetzfIoW"
740
+ },
741
+ "source": [
742
+ "### Let's build the Reinforce Training Algorithm\n",
743
+ "This is the Reinforce algorithm pseudocode:\n",
744
+ "\n",
745
+ "<img src=\"https://huggingface.co/datasets/huggingface-deep-rl-course/course-images/resolve/main/en/unit6/pg_pseudocode.png\" alt=\"Policy gradient pseudocode\"/>\n",
746
+ " "
747
+ ]
748
+ },
749
+ {
750
+ "cell_type": "markdown",
751
+ "metadata": {
752
+ "id": "QmcXG-9i2Qu2"
753
+ },
754
+ "source": [
755
+ "- When we calculate the return G_t (line 6), we calculate the sum of the discounted rewards **starting at timestep t**.\n",
756
+ "\n",
757
+ "- Why? Because our policy should only **reinforce actions on the basis of their consequences**: rewards obtained before taking an action are useless (since they were not caused by that action), **only the ones that come after the action matter**.\n",
758
+ "\n",
759
+ "- Before coding this you should read this section [don't let the past distract you](https://spinningup.openai.com/en/latest/spinningup/rl_intro3.html#don-t-let-the-past-distract-you) that explains why we use reward-to-go policy gradient.\n",
760
+ "\n",
761
+ "We use an interesting technique coded by [Chris1nexus](https://github.com/Chris1nexus) to **compute the return at each timestep efficiently**. The comments explain the procedure. Also, don't hesitate [to check the PR explanation](https://github.com/huggingface/deep-rl-class/pull/95).\n",
762
+ "But overall the idea is to **compute the return at each timestep efficiently**.\n",
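+ "\n",
+ "As a quick sanity check (an added example, not from the original text): with rewards = [1, 1, 1] and gamma = 0.99, the loop below produces returns[2] = 1.0, returns[1] = 1 + 0.99 * 1 = 1.99 and returns[0] = 1 + 0.99 * 1.99 = 2.9701."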
763
+ ]
764
+ },
765
+ {
766
+ "cell_type": "markdown",
767
+ "metadata": {
768
+ "id": "O554nUGPpcoq"
769
+ },
770
+ "source": [
771
+ "The second question you may ask is **why do we minimize the loss** when we talked about Gradient Ascent, not Gradient Descent.\n",
772
+ "\n",
773
+ "- We want to maximize our utility function $J(\\theta)$, but in PyTorch, as in TensorFlow, it's better to **minimize an objective function.**\n",
774
+ " - So let's say we want to reinforce action 3 at a certain timestep. Before training, this action's probability P is 0.25.\n",
775
+ " - So we want to modify $\\theta$ such that $\\pi_\\theta(a_3|s; \\theta) > 0.25$\n",
776
+ " - Because all P must sum to 1, maximizing $\\pi_\\theta(a_3|s; \\theta)$ will **minimize the other actions' probabilities.**\n",
777
+ " - So we should tell PyTorch **to min $1 - \\pi_\\theta(a_3|s; \\theta)$.**\n",
778
+ " - This loss function approaches 0 as $\\pi_\\theta(a_3|s; \\theta)$ nears 1.\n",
779
+ " - So we are encouraging the gradient to max $\\pi_\\theta(a_3|s; \\theta)$.\n",
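+ "\n",
+ "Concretely, the `reinforce` function below implements this by minimizing $-\\sum_t \\log \\pi_\\theta(a_t|s_t) \\, G_t$ (the sum of the `-log_prob * disc_return` terms): its gradient is the negative of the policy-gradient estimate, so taking a gradient descent step on this loss performs gradient ascent on $J(\\theta)$."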
780
+ ]
781
+ },
782
+ {
783
+ "cell_type": "code",
784
+ "execution_count": 13,
785
+ "metadata": {
786
+ "id": "iOdv8Q9NfLK7"
787
+ },
788
+ "outputs": [],
789
+ "source": [
790
+ "def reinforce(policy, optimizer, n_training_episodes, max_t, gamma, print_every):\n",
791
+ " # Help us to calculate the score during the training\n",
792
+ " scores_deque = deque(maxlen=100)\n",
793
+ " scores = []\n",
794
+ " # Line 3 of pseudocode\n",
795
+ " for i_episode in range(1, n_training_episodes+1):\n",
796
+ " saved_log_probs = []\n",
797
+ " rewards = []\n",
798
+ " state, _ = env.reset()\n",
799
+ " # Line 4 of pseudocode\n",
800
+ " for t in range(max_t):\n",
801
+ " action, log_prob = policy.act(state)# TODO get the action\n",
802
+ " saved_log_probs.append(log_prob)\n",
803
+ " state, reward, terminated, truncated, info = env.step(action)\n",
804
+ " rewards.append(reward)\n",
805
+ " if bool(terminated) or bool(truncated):\n",
806
+ " break \n",
807
+ " scores_deque.append(sum(rewards))\n",
808
+ " scores.append(sum(rewards))\n",
809
+ " \n",
810
+ " # Line 6 of pseudocode: calculate the return\n",
811
+ " returns = deque(maxlen=max_t)\n",
812
+ " n_steps = len(rewards) \n",
813
+ " # Compute the discounted returns at each timestep,\n",
814
+ " # as the sum of the gamma-discounted return at time t (G_t) + the reward at time t\n",
815
+ " \n",
816
+ " # In O(N) time, where N is the number of time steps\n",
817
+ " # (this definition of the discounted return G_t follows the definition of this quantity \n",
818
+ " # shown at page 44 of Sutton&Barto 2017 2nd draft)\n",
819
+ " # G_t = r_(t+1) + r_(t+2) + ...\n",
820
+ " \n",
821
+ " # Given this formulation, the returns at each timestep t can be computed \n",
822
+ " # by re-using the computed future returns G_(t+1) to compute the current return G_t\n",
823
+ " # G_t = r_(t+1) + gamma*G_(t+1)\n",
824
+ " # G_(t-1) = r_t + gamma* G_t\n",
825
+ " # (this follows a dynamic programming approach, with which we memorize solutions in order \n",
826
+ " # to avoid computing them multiple times)\n",
827
+ " \n",
828
+ " # This is correct since the above is equivalent to (see also page 46 of Sutton&Barto 2017 2nd draft)\n",
829
+ " # G_(t-1) = r_t + gamma*r_(t+1) + gamma*gamma*r_(t+2) + ...\n",
830
+ " \n",
831
+ " \n",
832
+ " ## Given the above, we calculate the returns at timestep t as: \n",
833
+ " # gamma[t] * return[t] + reward[t]\n",
834
+ " #\n",
835
+ " ## We compute this starting from the last timestep to the first, in order\n",
836
+ " ## to employ the formula presented above and avoid redundant computations that would be needed \n",
837
+ " ## if we were to do it from first to last.\n",
838
+ " \n",
839
+ " ## Hence, the queue \"returns\" will hold the returns in chronological order, from t=0 to t=n_steps\n",
840
+ " ## thanks to the appendleft() function which allows to append to the position 0 in constant time O(1)\n",
841
+ " ## a normal python list would instead require O(N) to do this.\n",
842
+ " for t in range(n_steps)[::-1]:\n",
843
+ " disc_return_t = (returns[0] if len(returns)>0 else 0)\n",
844
+ " returns.appendleft(gamma * disc_return_t + rewards[t]) \n",
845
+ " \n",
846
+ " ## standardization of the returns is employed to make training more stable\n",
847
+ " eps = np.finfo(np.float32).eps.item()\n",
848
+ " \n",
849
+ " ## eps is the smallest representable float, which is \n",
850
+ " # added to the standard deviation of the returns to avoid numerical instabilities\n",
851
+ " returns = torch.tensor(returns)\n",
852
+ " returns = (returns - returns.mean()) / (returns.std() + eps)\n",
853
+ " \n",
854
+ " # Line 7:\n",
855
+ " policy_loss = []\n",
856
+ " for log_prob, disc_return in zip(saved_log_probs, returns):\n",
857
+ " policy_loss.append(-log_prob * disc_return)\n",
858
+ " policy_loss = torch.cat(policy_loss).sum()\n",
859
+ " \n",
860
+ " # Line 8: PyTorch prefers gradient descent \n",
861
+ " optimizer.zero_grad()\n",
862
+ " policy_loss.backward()\n",
863
+ " optimizer.step()\n",
864
+ " \n",
865
+ " if i_episode % print_every == 0:\n",
866
+ " print('Episode {}\\tAverage Score: {:.2f}'.format(i_episode, np.mean(scores_deque)))\n",
867
+ " \n",
868
+ " return scores"
869
+ ]
870
+ },
871
+ {
872
+ "cell_type": "markdown",
873
+ "metadata": {
874
+ "id": "YB0Cxrw1StrP",
875
+ "jp-MarkdownHeadingCollapsed": true
876
+ },
877
+ "source": [
878
+ "#### Solution"
879
+ ]
880
+ },
881
+ {
882
+ "cell_type": "code",
883
+ "execution_count": 89,
884
+ "metadata": {
885
+ "id": "NCNvyElRStWG"
886
+ },
887
+ "outputs": [],
888
+ "source": [
889
+ "def reinforce(policy, optimizer, n_training_episodes, max_t, gamma, print_every):\n",
890
+ " # Help us to calculate the score during the training\n",
891
+ " scores_deque = deque(maxlen=100)\n",
892
+ " scores = []\n",
893
+ " # Line 3 of pseudocode\n",
894
+ " for i_episode in range(1, n_training_episodes+1):\n",
895
+ " saved_log_probs = []\n",
896
+ " rewards = []\n",
897
+ " state, _ = env.reset()\n",
898
+ " # Line 4 of pseudocode\n",
899
+ " for t in range(max_t):\n",
900
+ " action, log_prob = policy.act(state)\n",
901
+ " saved_log_probs.append(log_prob)\n",
902
+ " state, reward, terminated, truncated, _ = env.step(action)\n",
903
+ " rewards.append(reward)\n",
904
+ " done = terminated or truncated\n",
905
+ " if done:\n",
906
+ " break \n",
907
+ " scores_deque.append(sum(rewards))\n",
908
+ " scores.append(sum(rewards))\n",
909
+ " \n",
910
+ " # Line 6 of pseudocode: calculate the return\n",
911
+ " returns = deque(maxlen=max_t) \n",
912
+ " n_steps = len(rewards) \n",
913
+ " # Compute the discounted returns at each timestep,\n",
914
+ " # as \n",
915
+ " # the sum of the gamma-discounted return at time t (G_t) + the reward at time t\n",
916
+ " #\n",
917
+ " # In O(N) time, where N is the number of time steps\n",
918
+ " # (this definition of the discounted return G_t follows the definition of this quantity \n",
919
+ " # shown at page 44 of Sutton&Barto 2017 2nd draft)\n",
920
+ " # G_t = r_(t+1) + r_(t+2) + ...\n",
921
+ " \n",
922
+ " # Given this formulation, the returns at each timestep t can be computed \n",
923
+ " # by re-using the computed future returns G_(t+1) to compute the current return G_t\n",
924
+ " # G_t = r_(t+1) + gamma*G_(t+1)\n",
925
+ " # G_(t-1) = r_t + gamma* G_t\n",
926
+ " # (this follows a dynamic programming approach, with which we memorize solutions in order \n",
927
+ " # to avoid computing them multiple times)\n",
928
+ " \n",
929
+ " # This is correct since the above is equivalent to (see also page 46 of Sutton&Barto 2017 2nd draft)\n",
930
+ " # G_(t-1) = r_t + gamma*r_(t+1) + gamma*gamma*r_(t+2) + ...\n",
931
+ " \n",
932
+ " \n",
933
+ " ## Given the above, we calculate the returns at timestep t as: \n",
934
+ " # gamma[t] * return[t] + reward[t]\n",
935
+ " #\n",
936
+ " ## We compute this starting from the last timestep to the first, in order\n",
937
+ " ## to employ the formula presented above and avoid redundant computations that would be needed \n",
938
+ " ## if we were to do it from first to last.\n",
939
+ " \n",
940
+ " ## Hence, the queue \"returns\" will hold the returns in chronological order, from t=0 to t=n_steps\n",
941
+ " ## thanks to the appendleft() function which allows to append to the position 0 in constant time O(1)\n",
942
+ " ## a normal python list would instead require O(N) to do this.\n",
943
+ " for t in range(n_steps)[::-1]:\n",
944
+ " disc_return_t = (returns[0] if len(returns)>0 else 0)\n",
945
+ " returns.appendleft( gamma*disc_return_t + rewards[t]) \n",
946
+ " \n",
947
+ " ## standardization of the returns is employed to make training more stable\n",
948
+ " eps = np.finfo(np.float32).eps.item()\n",
949
+ " ## eps is the smallest representable float, which is \n",
950
+ " # added to the standard deviation of the returns to avoid numerical instabilities \n",
951
+ " returns = torch.tensor(returns)\n",
952
+ " returns = (returns - returns.mean()) / (returns.std() + eps)\n",
953
+ " \n",
954
+ " # Line 7:\n",
955
+ " policy_loss = []\n",
956
+ " for log_prob, disc_return in zip(saved_log_probs, returns):\n",
957
+ " policy_loss.append(-log_prob * disc_return)\n",
958
+ " policy_loss = torch.cat(policy_loss).sum()\n",
959
+ " \n",
960
+ " # Line 8: PyTorch prefers gradient descent \n",
961
+ " optimizer.zero_grad()\n",
962
+ " policy_loss.backward()\n",
963
+ " optimizer.step()\n",
964
+ " \n",
965
+ " if i_episode % print_every == 0:\n",
966
+ " print('Episode {}\\tAverage Score: {:.2f}'.format(i_episode, np.mean(scores_deque)))\n",
967
+ " \n",
968
+ " return scores"
969
+ ]
970
+ },
971
+ {
972
+ "cell_type": "markdown",
973
+ "metadata": {
974
+ "id": "RIWhQyJjfpEt"
975
+ },
976
+ "source": [
977
+ "## Train it\n",
978
+ "- We're now ready to train our agent.\n",
979
+ "- But first, we define a variable containing all the training hyperparameters.\n",
980
+ "- You can change the training parameters (and should 😉)"
981
+ ]
982
+ },
983
+ {
984
+ "cell_type": "code",
985
+ "execution_count": 14,
986
+ "metadata": {
987
+ "id": "utRe1NgtVBYF"
988
+ },
989
+ "outputs": [],
990
+ "source": [
991
+ "cartpole_hyperparameters = {\n",
992
+ " \"h_size\": 16,\n",
993
+ " \"n_training_episodes\": 1000,\n",
994
+ " \"n_evaluation_episodes\": 10,\n",
995
+ " \"max_t\": 1000,\n",
996
+ " \"gamma\": 1.0,\n",
997
+ " \"lr\": 1e-2,\n",
998
+ " \"env_id\": env_id,\n",
999
+ " \"state_space\": s_size,\n",
1000
+ " \"action_space\": a_size,\n",
1001
+ "}"
1002
+ ]
1003
+ },
1004
+ {
1005
+ "cell_type": "code",
1006
+ "execution_count": 15,
1007
+ "metadata": {
1008
+ "id": "D3lWyVXBVfl6"
1009
+ },
1010
+ "outputs": [],
1011
+ "source": [
1012
+ "# Create policy and place it to the device\n",
1013
+ "cartpole_policy = Policy(cartpole_hyperparameters[\"state_space\"], cartpole_hyperparameters[\"action_space\"], cartpole_hyperparameters[\"h_size\"]).to(device)\n",
1014
+ "cartpole_optimizer = optim.Adam(cartpole_policy.parameters(), lr=cartpole_hyperparameters[\"lr\"])"
1015
+ ]
1016
+ },
1017
+ {
1018
+ "cell_type": "code",
1019
+ "execution_count": 16,
1020
+ "metadata": {
1021
+ "id": "uGf-hQCnfouB"
1022
+ },
1023
+ "outputs": [
1024
+ {
1025
+ "name": "stdout",
1026
+ "output_type": "stream",
1027
+ "text": [
1028
+ "Episode 100\tAverage Score: 50.60\n",
1029
+ "Episode 200\tAverage Score: 312.39\n",
1030
+ "Episode 300\tAverage Score: 322.14\n",
1031
+ "Episode 400\tAverage Score: 498.43\n",
1032
+ "Episode 500\tAverage Score: 414.31\n",
1033
+ "Episode 600\tAverage Score: 474.40\n",
1034
+ "Episode 700\tAverage Score: 485.19\n",
1035
+ "Episode 800\tAverage Score: 498.00\n",
1036
+ "Episode 900\tAverage Score: 481.82\n",
1037
+ "Episode 1000\tAverage Score: 500.00\n"
1038
+ ]
1039
+ }
1040
+ ],
1041
+ "source": [
1042
+ "scores = reinforce(cartpole_policy,\n",
1043
+ " cartpole_optimizer,\n",
1044
+ " cartpole_hyperparameters[\"n_training_episodes\"], \n",
1045
+ " cartpole_hyperparameters[\"max_t\"],\n",
1046
+ " cartpole_hyperparameters[\"gamma\"], \n",
1047
+ " 100)"
1048
+ ]
1049
+ },
1050
+ {
1051
+ "cell_type": "markdown",
1052
+ "metadata": {
1053
+ "id": "Qajj2kXqhB3g"
1054
+ },
1055
+ "source": [
1056
+ "## Define evaluation method 📝\n",
1057
+ "- Here we define the evaluation method that we're going to use to test our Reinforce agent."
1058
+ ]
1059
+ },
1060
+ {
1061
+ "cell_type": "code",
1062
+ "execution_count": 19,
1063
+ "metadata": {
1064
+ "id": "3FamHmxyhBEU"
1065
+ },
1066
+ "outputs": [],
1067
+ "source": [
1068
+ "def evaluate_agent(env, max_steps, n_eval_episodes, policy):\n",
1069
+ " \"\"\"\n",
1070
+ " Evaluate the agent for ``n_eval_episodes`` episodes and returns average reward and std of reward.\n",
1071
+ " :param env: The evaluation environment\n",
1072
+ " :param n_eval_episodes: Number of episode to evaluate the agent\n",
1073
+ " :param policy: The Reinforce agent\n",
1074
+ " \"\"\"\n",
1075
+ " episode_rewards = []\n",
1076
+ " for episode in range(n_eval_episodes):\n",
1077
+ " state, _ = env.reset()\n",
1078
+ " step = 0\n",
1079
+ " total_rewards_ep = 0\n",
1080
+ " \n",
1081
+ " for step in range(max_steps):\n",
1082
+ " action, _ = policy.act(state)\n",
1083
+ " new_state, reward, terminated, truncated, _ = env.step(action)\n",
1084
+ " total_rewards_ep += reward\n",
1085
+ " \n",
1086
+ " if bool(terminated) or bool(truncated):\n",
1087
+ " break\n",
1088
+ " state = new_state\n",
1089
+ " episode_rewards.append(total_rewards_ep)\n",
1090
+ " mean_reward = np.mean(episode_rewards)\n",
1091
+ " std_reward = np.std(episode_rewards)\n",
1092
+ "\n",
1093
+ " return mean_reward, std_reward"
1094
+ ]
1095
+ },
1096
+ {
1097
+ "cell_type": "markdown",
1098
+ "metadata": {
1099
+ "id": "xdH2QCrLTrlT"
1100
+ },
1101
+ "source": [
1102
+ "## Evaluate our agent 📈"
1103
+ ]
1104
+ },
1105
+ {
1106
+ "cell_type": "code",
1107
+ "execution_count": 20,
1108
+ "metadata": {
1109
+ "id": "ohGSXDyHh0xx"
1110
+ },
1111
+ "outputs": [
1112
+ {
1113
+ "data": {
1114
+ "text/plain": [
1115
+ "(np.float64(500.0), np.float64(0.0))"
1116
+ ]
1117
+ },
1118
+ "execution_count": 20,
1119
+ "metadata": {},
1120
+ "output_type": "execute_result"
1121
+ }
1122
+ ],
1123
+ "source": [
1124
+ "evaluate_agent(eval_env, \n",
1125
+ " cartpole_hyperparameters[\"max_t\"], \n",
1126
+ " cartpole_hyperparameters[\"n_evaluation_episodes\"],\n",
1127
+ " cartpole_policy)"
1128
+ ]
1129
+ },
1130
+ {
1131
+ "cell_type": "markdown",
1132
+ "metadata": {
1133
+ "id": "7CoeLkQ7TpO8"
1134
+ },
1135
+ "source": [
1136
+ "### Publish our trained model on the Hub 🔥\n",
1137
+ "Now that we've seen that we get good results after training, we can publish our trained model on the Hub 🤗 with one line of code.\n",
1138
+ "\n",
1139
+ "Here's an example of a Model Card:\n",
1140
+ "\n",
1141
+ "<img src=\"https://huggingface.co/datasets/huggingface-deep-rl-course/course-images/resolve/main/en/unit6/modelcard.png\"/>"
1142
+ ]
1143
+ },
1144
+ {
1145
+ "cell_type": "markdown",
1146
+ "metadata": {
1147
+ "id": "Jmhs1k-cftIq"
1148
+ },
1149
+ "source": [
1150
+ "### Push to the Hub\n",
1151
+ "#### Do not modify this code"
1152
+ ]
1153
+ },
1154
+ {
1155
+ "cell_type": "code",
1156
+ "execution_count": 21,
1157
+ "metadata": {
1158
+ "id": "LIVsvlW_8tcw"
1159
+ },
1160
+ "outputs": [],
1161
+ "source": [
1162
+ "from huggingface_hub import HfApi, snapshot_download\n",
1163
+ "from huggingface_hub.repocard import metadata_eval_result, metadata_save\n",
1164
+ "\n",
1165
+ "from pathlib import Path\n",
1166
+ "import datetime\n",
1167
+ "import json\n",
1168
+ "import imageio\n",
1169
+ "\n",
1170
+ "import tempfile\n",
1171
+ "\n",
1172
+ "import os"
1173
+ ]
1174
+ },
1175
+ {
1176
+ "cell_type": "code",
1177
+ "execution_count": 41,
1178
+ "metadata": {
1179
+ "id": "Lo4JH45if81z"
1180
+ },
1181
+ "outputs": [],
1182
+ "source": [
1183
+ "def record_video(env, policy, out_directory, fps=30):\n",
1184
+ " \"\"\"\n",
1185
+ " Generate a replay video of the agent's performance.\n",
1186
+ " :param env: the gym environment\n",
1187
+ " :param policy: the policy (or model) of the agent\n",
1188
+ " :param out_directory: path where to save the video\n",
1189
+ " :param fps: frames per second for the video\n",
1190
+ " \"\"\"\n",
1191
+ " # Wrap the environment with RecordVideo to capture video\n",
1192
+ " env = gym.wrappers.RecordVideo(env, out_directory, episode_trigger=lambda episode_id: True)\n",
1193
+ "\n",
1194
+ " state, _ = env.reset()\n",
1195
+ " done = False\n",
1196
+ " while not done:\n",
1197
+ " action, _ = policy.act(state) # Get the action from the policy\n",
1198
+ " state, reward, terminated, truncated, _ = env.step(action)\n",
1199
+ " \n",
1200
+ " # Determine whether the episode has finished\n",
1201
+ " done = terminated or truncated\n",
1202
+ "\n",
1203
+ " env.close()\n",
1204
+ "\n",
1205
+ " print(f\"Video saved at {out_directory}\")"
1206
+ ]
1207
+ },
1208
+ {
1209
+ "cell_type": "code",
1210
+ "execution_count": 44,
1211
+ "metadata": {
1212
+ "id": "_TPdq47D7_f_"
1213
+ },
1214
+ "outputs": [],
1215
+ "source": [
1216
+ "def push_to_hub(repo_id, \n",
1217
+ " model,\n",
1218
+ " hyperparameters,\n",
1219
+ " eval_env,\n",
1220
+ " video_fps=30):\n",
1221
+ " \"\"\"\n",
1222
+ " Evaluate, Generate a video and Upload a model to Hugging Face Hub.\n",
1223
+ " This method does the complete pipeline:\n",
1224
+ " - It evaluates the model\n",
1225
+ " - It generates the model card\n",
1226
+ " - It generates a replay video of the agent\n",
1227
+ " - It pushes everything to the Hub\n",
1228
+ "\n",
1229
+ " :param repo_id: repo_id: id of the model repository from the Hugging Face Hub\n",
1230
+ " :param model: the pytorch model we want to save\n",
1231
+ " :param hyperparameters: training hyperparameters\n",
1232
+ " :param eval_env: evaluation environment\n",
1233
+ " :param video_fps: how many frames per second to record our video replay \n",
1234
+ " \"\"\"\n",
1235
+ "\n",
1236
+ " _, repo_name = repo_id.split(\"/\")\n",
1237
+ " api = HfApi()\n",
1238
+ "\n",
1239
+ " # Step 1: Create the repo\n",
1240
+ " repo_url = api.create_repo(\n",
1241
+ " repo_id=repo_id,\n",
1242
+ " exist_ok=True,\n",
1243
+ " )\n",
1244
+ "\n",
1245
+ " # Define the path to save the video in the current working directory\n",
1246
+ " current_directory = Path(os.getcwd())\n",
1247
+ " print(current_directory)\n",
1248
+ " \n",
1249
+ " # Step 2: Save the model\n",
1250
+ " torch.save(model, current_directory / \"model.pt\")\n",
1251
+ "\n",
1252
+ " # Step 3: Save the hyperparameters to JSON\n",
1253
+ " with open(current_directory / \"hyperparameters.json\", \"w\") as outfile:\n",
1254
+ " json.dump(hyperparameters, outfile)\n",
1255
+ "\n",
1256
+ " # Step 4: Evaluate the model and build JSON\n",
1257
+ " mean_reward, std_reward = evaluate_agent(eval_env, \n",
1258
+ " hyperparameters[\"max_t\"],\n",
1259
+ " hyperparameters[\"n_evaluation_episodes\"], \n",
1260
+ " model)\n",
1261
+ "\n",
1262
+ " # Get datetime\n",
1263
+ " eval_datetime = datetime.datetime.now()\n",
1264
+ " eval_form_datetime = eval_datetime.isoformat()\n",
1265
+ "\n",
1266
+ " evaluate_data = {\n",
1267
+ " \"env_id\": hyperparameters[\"env_id\"], \n",
1268
+ " \"mean_reward\": mean_reward,\n",
1269
+ " \"n_evaluation_episodes\": hyperparameters[\"n_evaluation_episodes\"],\n",
1270
+ " \"eval_datetime\": eval_form_datetime,\n",
1271
+ " }\n",
1272
+ "\n",
1273
+ " # Write a JSON file\n",
1274
+ " with open(current_directory / \"results.json\", \"w\") as outfile:\n",
1275
+ " json.dump(evaluate_data, outfile)\n",
1276
+ "\n",
1277
+ " # Step 5: Create the model card\n",
1278
+ " env_name = hyperparameters[\"env_id\"]\n",
1279
+ " \n",
1280
+ " metadata = {}\n",
1281
+ " metadata[\"tags\"] = [\n",
1282
+ " env_name,\n",
1283
+ " \"reinforce\",\n",
1284
+ " \"reinforcement-learning\",\n",
1285
+ " \"custom-implementation\",\n",
1286
+ " \"deep-rl-class\"\n",
1287
+ " ]\n",
1288
+ "\n",
1289
+ " # Add metrics\n",
1290
+ " eval = metadata_eval_result(\n",
1291
+ " model_pretty_name=repo_name,\n",
1292
+ " task_pretty_name=\"reinforcement-learning\",\n",
1293
+ " task_id=\"reinforcement-learning\",\n",
1294
+ " metrics_pretty_name=\"mean_reward\",\n",
1295
+ " metrics_id=\"mean_reward\",\n",
1296
+ " metrics_value=f\"{mean_reward:.2f} +/- {std_reward:.2f}\",\n",
1297
+ " dataset_pretty_name=env_name,\n",
1298
+ " dataset_id=env_name,\n",
1299
+ " )\n",
1300
+ "\n",
1301
+ " # Merges both dictionaries\n",
1302
+ " metadata = {**metadata, **eval}\n",
1303
+ "\n",
1304
+ " model_card = f\"\"\"\n",
1305
+ "# **Reinforce** Agent playing **{env_name}**\n",
1306
+ "This is a trained model of a **Reinforce** agent playing **{env_name}**.\n",
1307
+ "To learn to use this model and train yours, check Unit 4 of the Deep Reinforcement Learning Course: https://huggingface.co/deep-rl-course/unit4/introduction\n",
1308
+ "\"\"\"\n",
1309
+ "\n",
1310
+ " readme_path = current_directory / \"README.md\"\n",
1311
+ " readme = \"\"\n",
1312
+ " if readme_path.exists():\n",
1313
+ " with readme_path.open(\"r\", encoding=\"utf8\") as f:\n",
1314
+ " readme = f.read()\n",
1315
+ " else:\n",
1316
+ " readme = model_card\n",
1317
+ "\n",
1318
+ " with readme_path.open(\"w\", encoding=\"utf-8\") as f:\n",
1319
+ " f.write(readme)\n",
1320
+ "\n",
1321
+ " # Save our metrics to Readme metadata\n",
1322
+ " metadata_save(readme_path, metadata)\n",
1323
+ "\n",
1324
+ " # Step 6: Record a video and save it in the current working directory\n",
1325
+ " video_path = current_directory / \"replay.mp4\"\n",
1326
+ " record_video(eval_env, model, video_path, video_fps)\n",
1327
+ "\n",
1328
+ " # Step 7: Push everything to the Hub\n",
1329
+ " api.upload_folder(\n",
1330
+ " repo_id=repo_id,\n",
1331
+ " folder_path=current_directory,\n",
1332
+ " path_in_repo=\".\",\n",
1333
+ " )\n",
1334
+ "\n",
1335
+ " print(f\"Your model is pushed to the Hub. You can view your model here: {repo_url}\")\n"
1336
+ ]
1337
+ },
1338
+ {
1339
+ "cell_type": "markdown",
1340
+ "metadata": {
1341
+ "id": "w17w8CxzoURM"
1342
+ },
1343
+ "source": [
1344
+    "### What does `push_to_hub` do?\n",
1345
+ "\n",
1346
+ "By using `push_to_hub` **you evaluate, record a replay, generate a model card of your agent and push it to the Hub**.\n",
1347
+ "\n",
1348
+ "This way:\n",
1349
+    "- You can **showcase your work** 🔥\n",
1350
+ "- You can **visualize your agent playing** 👀\n",
1351
+ "- You can **share with the community an agent that others can use** 💾\n",
1352
+ "- You can **access a leaderboard 🏆 to see how well your agent is performing compared to your classmates** 👉 https://huggingface.co/spaces/huggingface-projects/Deep-Reinforcement-Learning-Leaderboard\n"
1353
+ ]
1354
+ },
1355
+ {
1356
+ "cell_type": "markdown",
1357
+ "metadata": {
1358
+ "id": "cWnFC0iZooTw"
1359
+ },
1360
+ "source": [
1361
+    "To be able to share your model with the community, there are three more steps to follow:\n",
1362
+ "\n",
1363
+    "1️⃣ (If it's not already done) create an account on HF ➡ https://huggingface.co/join\n",
1364
+ "\n",
1365
+    "2️⃣ Sign in, then store your authentication token from the Hugging Face website.\n",
1366
+    "- Create a new token (https://huggingface.co/settings/tokens) **with the write role**\n",
1367
+ "\n",
1368
+ "\n",
1369
+ "<img src=\"https://huggingface.co/datasets/huggingface-deep-rl-course/course-images/resolve/main/en/notebooks/create-token.jpg\" alt=\"Create HF Token\">\n"
1370
+ ]
1371
+ },
1372
+ {
1373
+ "cell_type": "code",
1374
+ "execution_count": 24,
1375
+ "metadata": {
1376
+ "id": "QB5nIcxR8paT"
1377
+ },
1378
+ "outputs": [
1379
+ {
1380
+ "data": {
1381
+ "application/vnd.jupyter.widget-view+json": {
1382
+ "model_id": "ccc4f901779d4a22b2c996cb6e4c9ea8",
1383
+ "version_major": 2,
1384
+ "version_minor": 0
1385
+ },
1386
+ "text/plain": [
1387
+ "VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
1388
+ ]
1389
+ },
1390
+ "metadata": {},
1391
+ "output_type": "display_data"
1392
+ }
1393
+ ],
1394
+ "source": [
1395
+ "notebook_login()"
1396
+ ]
1397
+ },
1398
+ {
1399
+ "cell_type": "markdown",
1400
+ "metadata": {
1401
+ "id": "GyWc1x3-o3xG"
1402
+ },
1403
+ "source": [
1404
+    "If you don't want to use Google Colab or a Jupyter Notebook, use this command instead: `huggingface-cli login` (or `login`). A minimal sketch of the programmatic equivalent follows below."
1405
+ ]
1406
+ },
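Aside (not part of the committed notebook): if you run this code as a plain Python script instead of Colab/Jupyter, a minimal sketch of programmatic authentication with `huggingface_hub.login` could look like the following; the `HF_TOKEN` environment variable name is just an assumption for illustration.

```python
# Minimal sketch, assuming a write token is exported as HF_TOKEN (illustrative name).
import os
from huggingface_hub import login

login(token=os.environ["HF_TOKEN"])  # same effect as `huggingface-cli login` in a terminal
```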
1407
+ {
1408
+ "cell_type": "markdown",
1409
+ "metadata": {
1410
+ "id": "F-D-zhbRoeOm"
1411
+ },
1412
+ "source": [
1413
+    "3️⃣ We're now ready to push our trained agent to the 🤗 Hub 🔥 using the `push_to_hub()` function"
1414
+ ]
1415
+ },
1416
+ {
1417
+ "cell_type": "code",
1418
+ "execution_count": 45,
1419
+ "metadata": {
1420
+ "id": "UNwkTS65Uq3Q"
1421
+ },
1422
+ "outputs": [
1423
+ {
1424
+ "name": "stdout",
1425
+ "output_type": "stream",
1426
+ "text": [
1427
+ "C:\\Users\\UTILIS~1\\AppData\\Local\\Temp\\tmp18ry8l3t\n",
1428
+ "Video saved at C:\\Users\\UTILIS~1\\AppData\\Local\\Temp\\tmp18ry8l3t\\replay.mp4\n",
1429
+ "Your model is pushed to the Hub. You can view your model here: https://huggingface.co/Stoub/Reinforce-Cartpole-v1\n"
1430
+ ]
1431
+ }
1432
+ ],
1433
+ "source": [
1434
+ "repo_id = \"Stoub/Reinforce-Cartpole-v1\" #TODO Define your repo id {username/Reinforce-{model-id}}\n",
1435
+ "push_to_hub(repo_id,\n",
1436
+ " cartpole_policy, # The model we want to save\n",
1437
+ " cartpole_hyperparameters, # Hyperparameters\n",
1438
+ " eval_env, # Evaluation environment\n",
1439
+ " video_fps=30\n",
1440
+ " )"
1441
+ ]
1442
+ },
1443
+ {
1444
+ "cell_type": "markdown",
1445
+ "metadata": {
1446
+ "id": "jrnuKH1gYZSz"
1447
+ },
1448
+ "source": [
1449
+    "Now that we have tested the robustness of our implementation, let's try a more complex environment: PixelCopter 🚁\n",
1450
+ "\n",
1451
+ "\n"
1452
+ ]
1453
+ },
1454
+ {
1455
+ "cell_type": "markdown",
1456
+ "metadata": {
1457
+ "id": "JNLVmKKVKA6j"
1458
+ },
1459
+ "source": [
1460
+ "## Second agent: PixelCopter 🚁\n",
1461
+ "\n",
1462
+ "### Study the PixelCopter environment 👀\n",
1463
+ "- [The Environment documentation](https://pygame-learning-environment.readthedocs.io/en/latest/user/games/pixelcopter.html)\n"
1464
+ ]
1465
+ },
1466
+ {
1467
+ "cell_type": "code",
1468
+ "execution_count": null,
1469
+ "metadata": {
1470
+ "id": "JBSc8mlfyin3"
1471
+ },
1472
+ "outputs": [],
1473
+ "source": [
1474
+ "env_id = \"Pixelcopter-PLE-v0\"\n",
1475
+ "env = gym.make(env_id)\n",
1476
+ "eval_env = gym.make(env_id)\n",
1477
+ "s_size = env.observation_space.shape[0]\n",
1478
+ "a_size = env.action_space.n"
1479
+ ]
1480
+ },
1481
+ {
1482
+ "cell_type": "code",
1483
+ "execution_count": null,
1484
+ "metadata": {
1485
+ "id": "L5u_zAHsKBy7"
1486
+ },
1487
+ "outputs": [],
1488
+ "source": [
1489
+ "print(\"_____OBSERVATION SPACE_____ \\n\")\n",
1490
+ "print(\"The State Space is: \", s_size)\n",
1491
+ "print(\"Sample observation\", env.observation_space.sample()) # Get a random observation"
1492
+ ]
1493
+ },
1494
+ {
1495
+ "cell_type": "code",
1496
+ "execution_count": null,
1497
+ "metadata": {
1498
+ "id": "D7yJM9YXKNbq"
1499
+ },
1500
+ "outputs": [],
1501
+ "source": [
1502
+ "print(\"\\n _____ACTION SPACE_____ \\n\")\n",
1503
+ "print(\"The Action Space is: \", a_size)\n",
1504
+ "print(\"Action Space Sample\", env.action_space.sample()) # Take a random action"
1505
+ ]
1506
+ },
1507
+ {
1508
+ "cell_type": "markdown",
1509
+ "metadata": {
1510
+ "id": "NNWvlyvzalXr"
1511
+ },
1512
+ "source": [
1513
+ "The observation space (7) 👀:\n",
1514
+ "- player y position\n",
1515
+ "- player velocity\n",
1516
+ "- player distance to floor\n",
1517
+ "- player distance to ceiling\n",
1518
+ "- next block x distance to player\n",
1519
+    "- next block's top y location\n",
1520
+    "- next block's bottom y location\n",
1521
+ "\n",
1522
+    "The action space (2) 🎮:\n",
1523
+ "- Up (press accelerator) \n",
1524
+ "- Do nothing (don't press accelerator) \n",
1525
+ "\n",
1526
+ "The reward function 💰: \n",
1527
+    "- For each vertical block it passes through, it gains a positive reward of +1. Each time a terminal state is reached, it receives a negative reward of -1. (A quick random-rollout sanity check is sketched below.)"
1528
+ ]
1529
+ },
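Aside (not part of the committed notebook): a quick random-rollout sanity check of that reward signal could look like the sketch below. It assumes the `env` created a few cells above and the same 5-value `step()` unpacking used elsewhere in this notebook.

```python
# Minimal sketch: random rollout on Pixelcopter to observe the +1 / -1 rewards.
state, _ = env.reset()
total_reward = 0.0
for _ in range(500):
    action = env.action_space.sample()                        # up or do nothing, at random
    state, reward, terminated, truncated, _ = env.step(action)
    total_reward += reward                                     # +1 per block passed, -1 on crash
    if terminated or truncated:
        break
print("Random-policy episode return:", total_reward)
```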
1530
+ {
1531
+ "cell_type": "markdown",
1532
+ "metadata": {
1533
+ "id": "aV1466QP8crz"
1534
+ },
1535
+ "source": [
1536
+ "### Define the new Policy 🧠\n",
1537
+ "- We need to have a deeper neural network since the environment is more complex"
1538
+ ]
1539
+ },
1540
+ {
1541
+ "cell_type": "code",
1542
+ "execution_count": null,
1543
+ "metadata": {
1544
+ "id": "I1eBkCiX2X_S"
1545
+ },
1546
+ "outputs": [],
1547
+ "source": [
1548
+ "class Policy(nn.Module):\n",
1549
+ " def __init__(self, s_size, a_size, h_size):\n",
1550
+ " super(Policy, self).__init__()\n",
1551
+ " # Define the three layers here\n",
1552
+ "\n",
1553
+ " def forward(self, x):\n",
1554
+ " # Define the forward process here\n",
1555
+ " return F.softmax(x, dim=1)\n",
1556
+ " \n",
1557
+ " def act(self, state):\n",
1558
+ " state = torch.from_numpy(state).float().unsqueeze(0).to(device)\n",
1559
+ " probs = self.forward(state).cpu()\n",
1560
+ " m = Categorical(probs)\n",
1561
+ " action = m.sample()\n",
1562
+ " return action.item(), m.log_prob(action)"
1563
+ ]
1564
+ },
1565
+ {
1566
+ "cell_type": "markdown",
1567
+ "metadata": {
1568
+ "id": "47iuAFqV8Ws-"
1569
+ },
1570
+ "source": [
1571
+ "#### Solution"
1572
+ ]
1573
+ },
1574
+ {
1575
+ "cell_type": "code",
1576
+ "execution_count": null,
1577
+ "metadata": {
1578
+ "id": "wrNuVcHC8Xu7"
1579
+ },
1580
+ "outputs": [],
1581
+ "source": [
1582
+ "class Policy(nn.Module):\n",
1583
+ " def __init__(self, s_size, a_size, h_size):\n",
1584
+ " super(Policy, self).__init__()\n",
1585
+ " self.fc1 = nn.Linear(s_size, h_size)\n",
1586
+ " self.fc2 = nn.Linear(h_size, h_size*2)\n",
1587
+ " self.fc3 = nn.Linear(h_size*2, a_size)\n",
1588
+ "\n",
1589
+ " def forward(self, x):\n",
1590
+ " x = F.relu(self.fc1(x))\n",
1591
+ " x = F.relu(self.fc2(x))\n",
1592
+ " x = self.fc3(x)\n",
1593
+ " return F.softmax(x, dim=1)\n",
1594
+ " \n",
1595
+ " def act(self, state):\n",
1596
+ " state = torch.from_numpy(state).float().unsqueeze(0).to(device)\n",
1597
+ " probs = self.forward(state).cpu()\n",
1598
+ " m = Categorical(probs)\n",
1599
+ " action = m.sample()\n",
1600
+ " return action.item(), m.log_prob(action)"
1601
+ ]
1602
+ },
1603
+ {
1604
+ "cell_type": "markdown",
1605
+ "metadata": {
1606
+ "id": "SM1QiGCSbBkM"
1607
+ },
1608
+ "source": [
1609
+ "### Define the hyperparameters ⚙️\n",
1610
+    "- Because this environment is more complex, we need a bigger network.\n",
1611
+    "- In particular, for the hidden size we need more neurons."
1612
+ ]
1613
+ },
1614
+ {
1615
+ "cell_type": "code",
1616
+ "execution_count": null,
1617
+ "metadata": {
1618
+ "id": "y0uujOR_ypB6"
1619
+ },
1620
+ "outputs": [],
1621
+ "source": [
1622
+ "pixelcopter_hyperparameters = {\n",
1623
+ " \"h_size\": 64,\n",
1624
+ " \"n_training_episodes\": 50000,\n",
1625
+ " \"n_evaluation_episodes\": 10,\n",
1626
+ " \"max_t\": 10000,\n",
1627
+ " \"gamma\": 0.99,\n",
1628
+ " \"lr\": 1e-4,\n",
1629
+ " \"env_id\": env_id,\n",
1630
+ " \"state_space\": s_size,\n",
1631
+ " \"action_space\": a_size,\n",
1632
+ "}"
1633
+ ]
1634
+ },
1635
+ {
1636
+ "cell_type": "markdown",
1637
+ "metadata": {
1638
+ "id": "wyvXTJWm9GJG"
1639
+ },
1640
+ "source": [
1641
+ "### Train it\n",
1642
+ "- We're now ready to train our agent 🔥."
1643
+ ]
1644
+ },
1645
+ {
1646
+ "cell_type": "code",
1647
+ "execution_count": null,
1648
+ "metadata": {
1649
+ "id": "7mM2P_ckysFE"
1650
+ },
1651
+ "outputs": [],
1652
+ "source": [
1653
+ "# Create policy and place it to the device\n",
1654
+ "# torch.manual_seed(50)\n",
1655
+ "pixelcopter_policy = Policy(pixelcopter_hyperparameters[\"state_space\"], pixelcopter_hyperparameters[\"action_space\"], pixelcopter_hyperparameters[\"h_size\"]).to(device)\n",
1656
+ "pixelcopter_optimizer = optim.Adam(pixelcopter_policy.parameters(), lr=pixelcopter_hyperparameters[\"lr\"])"
1657
+ ]
1658
+ },
1659
+ {
1660
+ "cell_type": "code",
1661
+ "execution_count": null,
1662
+ "metadata": {
1663
+ "id": "v1HEqP-fy-Rf"
1664
+ },
1665
+ "outputs": [],
1666
+ "source": [
1667
+ "scores = reinforce(pixelcopter_policy,\n",
1668
+ " pixelcopter_optimizer,\n",
1669
+ " pixelcopter_hyperparameters[\"n_training_episodes\"], \n",
1670
+ " pixelcopter_hyperparameters[\"max_t\"],\n",
1671
+ " pixelcopter_hyperparameters[\"gamma\"], \n",
1672
+ " 1000)"
1673
+ ]
1674
+ },
1675
+ {
1676
+ "cell_type": "markdown",
1677
+ "metadata": {
1678
+ "id": "8kwFQ-Ip85BE"
1679
+ },
1680
+ "source": [
1681
+ "### Publish our trained model on the Hub 🔥"
1682
+ ]
1683
+ },
1684
+ {
1685
+ "cell_type": "code",
1686
+ "execution_count": null,
1687
+ "metadata": {
1688
+ "id": "6PtB7LRbTKWK"
1689
+ },
1690
+ "outputs": [],
1691
+ "source": [
1692
+ "repo_id = \"\" #TODO Define your repo id {username/Reinforce-{model-id}}\n",
1693
+ "push_to_hub(repo_id,\n",
1694
+ " pixelcopter_policy, # The model we want to save\n",
1695
+ " pixelcopter_hyperparameters, # Hyperparameters\n",
1696
+ " eval_env, # Evaluation environment\n",
1697
+ " video_fps=30\n",
1698
+ " )"
1699
+ ]
1700
+ },
1701
+ {
1702
+ "cell_type": "markdown",
1703
+ "metadata": {
1704
+ "id": "7VDcJ29FcOyb"
1705
+ },
1706
+ "source": [
1707
+ "## Some additional challenges 🏆\n",
1708
+    "The best way to learn **is to try things on your own**! As you saw, the current agent is not doing great. As a first suggestion, you can train for more steps, but also try to find better hyperparameters.\n",
1709
+ "\n",
1710
+ "In the [Leaderboard](https://huggingface.co/spaces/huggingface-projects/Deep-Reinforcement-Learning-Leaderboard) you will find your agents. Can you get to the top?\n",
1711
+ "\n",
1712
+    "Here are some ideas to achieve that:\n",
1713
+ "* Train more steps\n",
1714
+ "* Try different hyperparameters by looking at what your classmates have done 👉 https://huggingface.co/models?other=reinforce\n",
1715
+ "* **Push your new trained model** on the Hub 🔥\n",
1716
+    "* **Improve the implementation for more complex environments** (for instance, what about changing the network to a Convolutional Neural Network to handle\n",
1717
+    "frames as observations)? (A rough sketch of this idea follows below.)"
1718
+ ]
1719
+ },
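Aside (not part of the committed notebook): a rough sketch of the CNN idea mentioned in the last bullet. The input shape, channel counts and kernel sizes are assumptions for illustration, not a tuned architecture, and the observation is assumed to be a (C, H, W) pixel array.

```python
# Sketch of a convolutional policy for frame observations (assumed shape: C x H x W).
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical


class CnnPolicy(nn.Module):
    def __init__(self, in_channels, a_size, h_size=256):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.LazyLinear(h_size)      # infers the flattened size on the first forward pass
        self.fc2 = nn.Linear(h_size, a_size)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.fc1(self.flatten(x)))
        return F.softmax(self.fc2(x), dim=1)

    def act(self, state):
        # state: numpy array of pixels with shape (C, H, W)
        state = torch.from_numpy(state).float().unsqueeze(0)
        probs = self.forward(state)
        m = Categorical(probs)
        action = m.sample()
        return action.item(), m.log_prob(action)
```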
1720
+ {
1721
+ "cell_type": "markdown",
1722
+ "metadata": {
1723
+ "id": "x62pP0PHdA-y"
1724
+ },
1725
+ "source": [
1726
+ "________________________________________________________________________\n",
1727
+ "\n",
1728
+ "**Congrats on finishing this unit**! There was a lot of information.\n",
1729
+ "And congrats on finishing the tutorial. You've just coded your first Deep Reinforcement Learning agent from scratch using PyTorch and shared it on the Hub 🥳.\n",
1730
+ "\n",
1731
+ "Don't hesitate to iterate on this unit **by improving the implementation for more complex environments** (for instance, what about changing the network to a Convolutional Neural Network to handle\n",
1732
+    "frames as observations)?\n",
1733
+ "\n",
1734
+    "In the next unit, **we're going to learn more about Unity ML-Agents**, by training agents in Unity environments. This way, you will be ready to participate in the **AI vs AI challenges, where you'll train your agents\n",
1735
+ "to compete against other agents in a snowball fight and a soccer game.**\n",
1736
+ "\n",
1737
+ "Sounds fun? See you next time!\n",
1738
+ "\n",
1739
+    "Finally, we would love **to hear what you think of the course and how we can improve it**. If you have some feedback, please 👉 [fill out this form](https://forms.gle/BzKXWzLAGZESGNaE9)\n",
1740
+ "\n",
1741
+ "See you in Unit 5! 🔥\n",
1742
+ "\n",
1743
+ "### Keep Learning, stay awesome 🤗\n",
1744
+ "\n"
1745
+ ]
1746
+ }
1747
+ ],
1748
+ "metadata": {
1749
+ "accelerator": "GPU",
1750
+ "colab": {
1751
+ "collapsed_sections": [
1752
+ "BPLwsPajb1f8",
1753
+ "L_WSo0VUV99t",
1754
+ "mjY-eq3eWh9O",
1755
+ "JoTC9o2SczNn",
1756
+ "gfGJNZBUP7Vn",
1757
+ "YB0Cxrw1StrP",
1758
+ "47iuAFqV8Ws-",
1759
+ "x62pP0PHdA-y"
1760
+ ],
1761
+ "include_colab_link": true,
1762
+ "private_outputs": true,
1763
+ "provenance": []
1764
+ },
1765
+ "gpuClass": "standard",
1766
+ "kernelspec": {
1767
+ "display_name": "Python 3 (ipykernel)",
1768
+ "language": "python",
1769
+ "name": "python3"
1770
+ },
1771
+ "language_info": {
1772
+ "codemirror_mode": {
1773
+ "name": "ipython",
1774
+ "version": 3
1775
+ },
1776
+ "file_extension": ".py",
1777
+ "mimetype": "text/x-python",
1778
+ "name": "python",
1779
+ "nbconvert_exporter": "python",
1780
+ "pygments_lexer": "ipython3",
1781
+ "version": "3.12.5"
1782
+ }
1783
+ },
1784
+ "nbformat": 4,
1785
+ "nbformat_minor": 4
1786
+ }
README.md CHANGED
@@ -21,7 +21,6 @@ model-index:
21
  verified: false
22
  ---
23
 
24
- # **Reinforce** Agent playing **CartPole-v1**
25
- This is a trained model of a **Reinforce** agent playing **CartPole-v1** .
26
- To learn to use this model and train yours check Unit 4 of the Deep Reinforcement Learning Course: https://huggingface.co/deep-rl-course/unit4/introduction
27
-
 
21
  verified: false
22
  ---
23
 
24
+ # **Reinforce** Agent playing **CartPole-v1**
25
+ This is a trained model of a **Reinforce** agent playing **CartPole-v1**.
26
+ To learn to use this model and train yours, check Unit 4 of the Deep Reinforcement Learning Course: https://huggingface.co/deep-rl-course/unit4/introduction
 
replay.mp4 ADDED
Binary file (358 kB). View file
 
requirements-unit4.txt ADDED
@@ -0,0 +1,5 @@
1
+ # git+https://github.com/ntasfi/PyGame-Learning-Environment.git
2
+ # git+https://github.com/simoninithomas/gym-games
3
+ huggingface_hub
4
+ imageio-ffmpeg
5
+ # pyyaml==6.0
results.json CHANGED
@@ -1 +1 @@
1
- {"env_id": "CartPole-v1", "mean_reward": 500.0, "n_evaluation_episodes": 10, "eval_datetime": "2024-09-25T17:32:02.161498"}
 
1
+ {"env_id": "CartPole-v1", "mean_reward": 500.0, "n_evaluation_episodes": 10, "eval_datetime": "2024-09-25T19:18:52.143176"}
unit4.ipynb ADDED
@@ -0,0 +1,2345 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "colab_type": "text",
7
+ "id": "view-in-github"
8
+ },
9
+ "source": [
10
+ "<a href=\"https://colab.research.google.com/github/huggingface/deep-rl-class/blob/GymnasiumUpdate%2FUnit4/notebooks/unit4/unit4.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "markdown",
15
+ "metadata": {
16
+ "id": "CjRWziAVU2lZ"
17
+ },
18
+ "source": [
19
+ "# Unit 4: Code your first Deep Reinforcement Learning Algorithm with PyTorch: Reinforce. And test its robustness 💪\n",
20
+ "\n",
21
+ "<img src=\"https://huggingface.co/datasets/huggingface-deep-rl-course/course-images/resolve/main/en/unit6/thumbnail.png\" alt=\"thumbnail\"/>\n",
22
+ "\n",
23
+ "\n",
24
+ "In this notebook, you'll code your first Deep Reinforcement Learning algorithm from scratch: Reinforce (also called Monte Carlo Policy Gradient).\n",
25
+ "\n",
26
+ "Reinforce is a *Policy-based method*: a Deep Reinforcement Learning algorithm that tries **to optimize the policy directly without using an action-value function**.\n",
27
+ "\n",
28
+ "More precisely, Reinforce is a *Policy-gradient method*, a subclass of *Policy-based methods* that aims **to optimize the policy directly by estimating the weights of the optimal policy using gradient ascent**.\n",
29
+ "\n",
30
+ "To test its robustness, we're going to train it in 2 different simple environments:\n",
31
+ "- Cartpole-v1\n",
32
+ "- PixelcopterEnv\n",
33
+ "\n",
34
+ "⬇️ Here is an example of what **you will achieve at the end of this notebook.** ⬇️"
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "markdown",
39
+ "metadata": {
40
+ "id": "s4rBom2sbo7S"
41
+ },
42
+ "source": [
43
+ " <img src=\"https://huggingface.co/datasets/huggingface-deep-rl-course/course-images/resolve/main/en/unit6/envs.gif\" alt=\"Environments\"/>\n"
44
+ ]
45
+ },
46
+ {
47
+ "cell_type": "markdown",
48
+ "metadata": {
49
+ "id": "BPLwsPajb1f8"
50
+ },
51
+ "source": [
52
+ "### 🎮 Environments: \n",
53
+ "\n",
54
+ "- [CartPole-v1](https://www.gymlibrary.dev/environments/classic_control/cart_pole/)\n",
55
+ "- [PixelCopter](https://pygame-learning-environment.readthedocs.io/en/latest/user/games/pixelcopter.html)\n",
56
+ "\n",
57
+ "### 📚 RL-Library: \n",
58
+ "\n",
59
+ "- Python\n",
60
+ "- PyTorch\n",
61
+ "\n",
62
+ "\n",
63
+ "We're constantly trying to improve our tutorials, so **if you find some issues in this notebook**, please [open an issue on the GitHub Repo](https://github.com/huggingface/deep-rl-class/issues)."
64
+ ]
65
+ },
66
+ {
67
+ "cell_type": "markdown",
68
+ "metadata": {
69
+ "id": "L_WSo0VUV99t"
70
+ },
71
+ "source": [
72
+ "## Objectives of this notebook 🏆\n",
73
+ "At the end of the notebook, you will:\n",
74
+ "- Be able to **code from scratch a Reinforce algorithm using PyTorch.**\n",
75
+ "- Be able to **test the robustness of your agent using simple environments.**\n",
76
+ "- Be able to **push your trained agent to the Hub** with a nice video replay and an evaluation score 🔥."
77
+ ]
78
+ },
79
+ {
80
+ "cell_type": "markdown",
81
+ "metadata": {
82
+ "id": "lEPrZg2eWa4R"
83
+ },
84
+ "source": [
85
+ "## This notebook is from the Deep Reinforcement Learning Course\n",
86
+ "<img src=\"https://huggingface.co/datasets/huggingface-deep-rl-course/course-images/resolve/main/en/notebooks/deep-rl-course-illustration.jpg\" alt=\"Deep RL Course illustration\"/>"
87
+ ]
88
+ },
89
+ {
90
+ "cell_type": "markdown",
91
+ "metadata": {
92
+ "id": "6p5HnEefISCB"
93
+ },
94
+ "source": [
95
+ "In this free course, you will:\n",
96
+ "\n",
97
+ "- 📖 Study Deep Reinforcement Learning in **theory and practice**.\n",
98
+ "- 🧑‍💻 Learn to **use famous Deep RL libraries** such as Stable Baselines3, RL Baselines3 Zoo, CleanRL and Sample Factory 2.0.\n",
99
+ "- 🤖 Train **agents in unique environments** \n",
100
+ "\n",
101
+    "And more! Check 📚 the syllabus 👉 https://simoninithomas.github.io/deep-rl-course\n",
102
+ "\n",
103
+ "Don’t forget to **<a href=\"http://eepurl.com/ic5ZUD\">sign up to the course</a>** (we are collecting your email to be able to **send you the links when each Unit is published and give you information about the challenges and updates).**\n",
104
+ "\n",
105
+ "\n",
106
+ "The best way to keep in touch is to join our discord server to exchange with the community and with us 👉🏻 https://discord.gg/ydHrjt3WP5"
107
+ ]
108
+ },
109
+ {
110
+ "cell_type": "markdown",
111
+ "metadata": {
112
+ "id": "mjY-eq3eWh9O"
113
+ },
114
+ "source": [
115
+ "## Prerequisites 🏗️\n",
116
+ "Before diving into the notebook, you need to:\n",
117
+ "\n",
118
+ "🔲 📚 [Study Policy Gradients by reading Unit 4](https://huggingface.co/deep-rl-course/unit4/introduction)"
119
+ ]
120
+ },
121
+ {
122
+ "cell_type": "markdown",
123
+ "metadata": {
124
+ "id": "Bsh4ZAamchSl"
125
+ },
126
+ "source": [
127
+ "# Let's code Reinforce algorithm from scratch 🔥\n",
128
+ "\n",
129
+ "\n",
130
+ "To validate this hands-on for the certification process, you need to push your trained models to the Hub.\n",
131
+ "\n",
132
+ "- Get a result of >= 350 for `Cartpole-v1`.\n",
133
+ "- Get a result of >= 5 for `PixelCopter`.\n",
134
+ "\n",
135
+ "To find your result, go to the leaderboard and find your model, **the result = mean_reward - std of reward**. **If you don't see your model on the leaderboard, go at the bottom of the leaderboard page and click on the refresh button**.\n",
136
+ "\n",
137
+ "For more information about the certification process, check this section 👉 https://huggingface.co/deep-rl-course/en/unit0/introduction#certification-process\n"
138
+ ]
139
+ },
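In code, the leaderboard score described above is simply the following (a one-line sketch, assuming `mean_reward` and `std_reward` come from the evaluation function defined later in this notebook):

```python
result = mean_reward - std_reward   # needs >= 350 for CartPole-v1, >= 5 for PixelCopter
```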
140
+ {
141
+ "cell_type": "markdown",
142
+ "metadata": {
143
+ "id": "JoTC9o2SczNn",
144
+ "jp-MarkdownHeadingCollapsed": true
145
+ },
146
+ "source": [
147
+    "## Some Colab advice 💡"
148
+ ]
149
+ },
150
+ {
151
+ "cell_type": "markdown",
152
+ "metadata": {},
153
+ "source": [
154
+    "It's better to run this Colab in a copy on your Google Drive, so that **if it times out** you still have the saved notebook on your Google Drive and do not need to redo everything from scratch.\n",
155
+ "\n",
156
+ "To do that you can either do `Ctrl + S` or `File > Save a copy in Google Drive.`"
157
+ ]
158
+ },
159
+ {
160
+ "cell_type": "markdown",
161
+ "metadata": {
162
+ "id": "PU4FVzaoM6fC",
163
+ "jp-MarkdownHeadingCollapsed": true
164
+ },
165
+ "source": [
166
+ "## Set the GPU 💪"
167
+ ]
168
+ },
169
+ {
170
+ "cell_type": "markdown",
171
+ "metadata": {},
172
+ "source": [
173
+ "- To **accelerate the agent's training, we'll use a GPU**. To do that, go to `Runtime > Change Runtime type`\n",
174
+ "\n",
175
+ "<img src=\"https://huggingface.co/datasets/huggingface-deep-rl-course/course-images/resolve/main/en/notebooks/gpu-step1.jpg\" alt=\"GPU Step 1\">"
176
+ ]
177
+ },
178
+ {
179
+ "cell_type": "markdown",
180
+ "metadata": {
181
+ "id": "KV0NyFdQM9ZG"
182
+ },
183
+ "source": [
184
+ "- `Hardware Accelerator > GPU`\n",
185
+ "\n",
186
+ "<img src=\"https://huggingface.co/datasets/huggingface-deep-rl-course/course-images/resolve/main/en/notebooks/gpu-step2.jpg\" alt=\"GPU Step 2\">"
187
+ ]
188
+ },
189
+ {
190
+ "cell_type": "markdown",
191
+ "metadata": {
192
+ "id": "bTpYcVZVMzUI"
193
+ },
194
+ "source": [
195
+ "## Create a virtual display 🖥\n",
196
+ "\n",
197
+ "During the notebook, we'll need to generate a replay video. To do so, with colab, **we need to have a virtual screen to be able to render the environment** (and thus record the frames). \n",
198
+ "\n",
199
+    "Hence, the following cell will install the libraries and create and run a virtual screen 🖥"
200
+ ]
201
+ },
202
+ {
203
+ "cell_type": "code",
204
+ "execution_count": 1,
205
+ "metadata": {
206
+ "id": "jV6wjQ7Be7p5",
207
+ "scrolled": true
208
+ },
209
+ "outputs": [
210
+ {
211
+ "name": "stdout",
212
+ "output_type": "stream",
213
+ "text": [
214
+ "Requirement already satisfied: pyvirtualdisplay in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (3.0)\n",
215
+ "Requirement already satisfied: pyglet==1.5.1 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (1.5.1)\n",
216
+ "Requirement already satisfied: huggingface_hub in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (0.25.1)\n",
217
+ "Requirement already satisfied: filelock in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from huggingface_hub) (3.13.1)\n",
218
+ "Requirement already satisfied: fsspec>=2023.5.0 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from huggingface_hub) (2024.2.0)\n",
219
+ "Requirement already satisfied: packaging>=20.9 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from huggingface_hub) (24.1)\n",
220
+ "Requirement already satisfied: pyyaml>=5.1 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from huggingface_hub) (6.0.1)\n",
221
+ "Requirement already satisfied: requests in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from huggingface_hub) (2.32.3)\n",
222
+ "Requirement already satisfied: tqdm>=4.42.1 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from huggingface_hub) (4.66.5)\n",
223
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from huggingface_hub) (4.11.0)\n",
224
+ "Requirement already satisfied: colorama in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from tqdm>=4.42.1->huggingface_hub) (0.4.6)\n",
225
+ "Requirement already satisfied: charset-normalizer<4,>=2 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from requests->huggingface_hub) (3.3.2)\n",
226
+ "Requirement already satisfied: idna<4,>=2.5 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from requests->huggingface_hub) (3.7)\n",
227
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from requests->huggingface_hub) (2.2.2)\n",
228
+ "Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from requests->huggingface_hub) (2024.8.30)\n",
229
+ "Requirement already satisfied: gym in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (0.26.2)\n",
230
+ "Requirement already satisfied: numpy>=1.18.0 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from gym) (1.26.4)\n",
231
+ "Requirement already satisfied: cloudpickle>=1.2.0 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from gym) (3.0.0)\n",
232
+ "Requirement already satisfied: gym-notices>=0.0.4 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from gym) (0.0.8)\n",
233
+ "Requirement already satisfied: imageio[ffmpeg] in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (2.35.1)\n",
234
+ "Requirement already satisfied: numpy in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from imageio[ffmpeg]) (1.26.4)\n",
235
+ "Requirement already satisfied: pillow>=8.3.2 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from imageio[ffmpeg]) (10.2.0)\n",
236
+ "Requirement already satisfied: imageio-ffmpeg in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from imageio[ffmpeg]) (0.5.1)\n",
237
+ "Requirement already satisfied: psutil in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from imageio[ffmpeg]) (5.9.0)\n",
238
+ "Requirement already satisfied: setuptools in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from imageio-ffmpeg->imageio[ffmpeg]) (75.1.0)\n",
239
+ "Requirement already satisfied: moviepy in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (1.0.3)\n",
240
+ "Requirement already satisfied: decorator<5.0,>=4.0.2 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from moviepy) (4.4.2)\n",
241
+ "Requirement already satisfied: tqdm<5.0,>=4.11.2 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from moviepy) (4.66.5)\n",
242
+ "Requirement already satisfied: requests<3.0,>=2.8.1 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from moviepy) (2.32.3)\n",
243
+ "Requirement already satisfied: proglog<=1.0.0 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from moviepy) (0.1.10)\n",
244
+ "Requirement already satisfied: numpy>=1.17.3 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from moviepy) (1.26.4)\n",
245
+ "Requirement already satisfied: imageio<3.0,>=2.5 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from moviepy) (2.35.1)\n",
246
+ "Requirement already satisfied: imageio-ffmpeg>=0.2.0 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from moviepy) (0.5.1)\n",
247
+ "Requirement already satisfied: pillow>=8.3.2 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from imageio<3.0,>=2.5->moviepy) (10.2.0)\n",
248
+ "Requirement already satisfied: setuptools in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from imageio-ffmpeg>=0.2.0->moviepy) (75.1.0)\n",
249
+ "Requirement already satisfied: charset-normalizer<4,>=2 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from requests<3.0,>=2.8.1->moviepy) (3.3.2)\n",
250
+ "Requirement already satisfied: idna<4,>=2.5 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from requests<3.0,>=2.8.1->moviepy) (3.7)\n",
251
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from requests<3.0,>=2.8.1->moviepy) (2.2.2)\n",
252
+ "Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from requests<3.0,>=2.8.1->moviepy) (2024.8.30)\n",
253
+ "Requirement already satisfied: colorama in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from tqdm<5.0,>=4.11.2->moviepy) (0.4.6)\n",
254
+ "Requirement already satisfied: opencv-python in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (4.10.0.84)\n",
255
+ "Requirement already satisfied: numpy>=1.21.2 in c:\\users\\utilisateur\\anaconda3\\envs\\rl\\lib\\site-packages (from opencv-python) (1.26.4)\n"
256
+ ]
257
+ }
258
+ ],
259
+ "source": [
260
+ "!pip install pyvirtualdisplay\n",
261
+ "!pip install pyglet==1.5.1\n",
262
+ "!pip install huggingface_hub\n",
263
+ "!pip install gym --upgrade\n",
264
+ "!pip install imageio[ffmpeg]\n",
265
+ "!pip install moviepy\n",
266
+ "!pip install opencv-python"
267
+ ]
268
+ },
269
+ {
270
+ "cell_type": "markdown",
271
+ "metadata": {
272
+ "id": "tjrLfPFIW8XK"
273
+ },
274
+ "source": [
275
+ "## Install the dependencies 🔽\n",
276
+ "The first step is to install the dependencies. We’ll install multiple ones:\n",
277
+ "\n",
278
+ "- `gym`\n",
279
+ "- `gym-games`: Extra gym environments made with PyGame.\n",
280
+ "- `huggingface_hub`: 🤗 works as a central place where anyone can share and explore models and datasets. It has versioning, metrics, visualizations, and other features that will allow you to easily collaborate with others.\n",
281
+ "\n",
282
+    "You may be wondering why we install gym and not gymnasium, a more recent version of gym. **Because the gym-games we are using are not yet updated for gymnasium**. \n",
283
+ "\n",
284
+ "The differences you'll encounter here:\n",
285
+ "- In `gym` we don't have `terminated` and `truncated` but only `done`.\n",
286
+    "- In `gym`, `env.step()` returns `state, reward, done, info` (a small compatibility sketch follows after this cell)\n",
287
+ "\n",
288
+ "You can learn more about the differences between Gym and Gymnasium here 👉 https://gymnasium.farama.org/content/migration-guide/\n",
289
+ "\n",
290
+ "\n",
291
+ "You can see here all the Reinforce models available 👉 https://huggingface.co/models?other=reinforce\n",
292
+ "\n",
293
+ "And you can find all the Deep Reinforcement Learning models here 👉 https://huggingface.co/models?pipeline_tag=reinforcement-learning\n"
294
+ ]
295
+ },
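Aside (not part of the committed notebook): a small helper sketch that normalizes the two conventions to a single `done` flag, so downstream code does not care which API it gets. The function name `step_compat` is just an illustration; note that the code cells further down in this notebook unpack five values directly.

```python
# Illustrative helper (not from the course): accept either step() convention.
def step_compat(env, action):
    out = env.step(action)
    if len(out) == 5:                      # gymnasium-style: terminated / truncated
        obs, reward, terminated, truncated, info = out
        done = terminated or truncated
    else:                                  # classic gym-style: single done flag
        obs, reward, done, info = out
    return obs, reward, done, info
```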
296
+ {
297
+ "cell_type": "markdown",
298
+ "metadata": {
299
+ "id": "AAHAq6RZW3rn"
300
+ },
301
+ "source": [
302
+ "## Import the packages 📦\n",
303
+    "In addition to importing the installed libraries, we also import:\n",
304
+ "\n",
305
+ "- `imageio`: A library that will help us to generate a replay video\n",
306
+ "\n"
307
+ ]
308
+ },
309
+ {
310
+ "cell_type": "code",
311
+ "execution_count": 2,
312
+ "metadata": {
313
+ "id": "V8oadoJSWp7C"
314
+ },
315
+ "outputs": [],
316
+ "source": [
317
+ "import numpy as np\n",
318
+ "\n",
319
+ "from collections import deque\n",
320
+ "\n",
321
+ "import matplotlib.pyplot as plt\n",
322
+ "%matplotlib inline\n",
323
+ "\n",
324
+ "# PyTorch\n",
325
+ "import torch\n",
326
+ "import torch.nn as nn\n",
327
+ "import torch.nn.functional as F\n",
328
+ "import torch.optim as optim\n",
329
+ "from torch.distributions import Categorical\n",
330
+ "\n",
331
+ "# Gym\n",
332
+ "import gym\n",
333
+ "import gym_pygame\n",
334
+ "\n",
335
+ "# Hugging Face Hub\n",
336
+ "from huggingface_hub import notebook_login # To log to our Hugging Face account to be able to upload models to the Hub.\n",
337
+ "import imageio"
338
+ ]
339
+ },
340
+ {
341
+ "cell_type": "markdown",
342
+ "metadata": {
343
+ "id": "RfxJYdMeeVgv"
344
+ },
345
+ "source": [
346
+ "## Check if we have a GPU\n",
347
+ "\n",
348
+ "- Let's check if we have a GPU\n",
349
+    "- If that's the case, you should see `device:cuda0`"
350
+ ]
351
+ },
352
+ {
353
+ "cell_type": "code",
354
+ "execution_count": 3,
355
+ "metadata": {
356
+ "id": "kaJu5FeZxXGY"
357
+ },
358
+ "outputs": [],
359
+ "source": [
360
+ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
361
+ ]
362
+ },
363
+ {
364
+ "cell_type": "code",
365
+ "execution_count": 4,
366
+ "metadata": {
367
+ "id": "U5TNYa14aRav"
368
+ },
369
+ "outputs": [
370
+ {
371
+ "name": "stdout",
372
+ "output_type": "stream",
373
+ "text": [
374
+ "cuda:0\n"
375
+ ]
376
+ }
377
+ ],
378
+ "source": [
379
+ "print(device)"
380
+ ]
381
+ },
382
+ {
383
+ "cell_type": "markdown",
384
+ "metadata": {
385
+ "id": "PBPecCtBL_pZ"
386
+ },
387
+ "source": [
388
+ "We're now ready to implement our Reinforce algorithm 🔥"
389
+ ]
390
+ },
391
+ {
392
+ "cell_type": "markdown",
393
+ "metadata": {
394
+ "id": "8KEyKYo2ZSC-"
395
+ },
396
+ "source": [
397
+ "# First agent: Playing CartPole-v1 🤖"
398
+ ]
399
+ },
400
+ {
401
+ "cell_type": "markdown",
402
+ "metadata": {
403
+ "id": "haLArKURMyuF"
404
+ },
405
+ "source": [
406
+ "## Create the CartPole environment and understand how it works\n",
407
+ "### [The environment 🎮](https://www.gymlibrary.dev/environments/classic_control/cart_pole/)\n"
408
+ ]
409
+ },
410
+ {
411
+ "cell_type": "markdown",
412
+ "metadata": {
413
+ "id": "AH_TaLKFXo_8"
414
+ },
415
+ "source": [
416
+ "### Why do we use a simple environment like CartPole-v1?\n",
417
+    "As explained in [Reinforcement Learning Tips and Tricks](https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html), when you implement your agent from scratch you need **to be sure that it works correctly and find bugs with easy environments before going deeper**, since finding bugs will be much easier in simple environments.\n",
418
+ "\n",
419
+ "\n",
420
+ "> Try to have some “sign of life” on toy problems\n",
421
+ "\n",
422
+ "\n",
423
+ "> Validate the implementation by making it run on harder and harder envs (you can compare results against the RL zoo). You usually need to run hyperparameter optimization for that step.\n",
424
+ "___\n",
425
+ "### The CartPole-v1 environment\n",
426
+ "\n",
427
+ "> A pole is attached by an un-actuated joint to a cart, which moves along a frictionless track. The pendulum is placed upright on the cart and the goal is to balance the pole by applying forces in the left and right direction on the cart.\n",
428
+ "\n",
429
+ "\n",
430
+ "\n",
431
+    "So, we start with CartPole-v1. The goal is to push the cart left or right **so that the pole stays in equilibrium.**\n",
432
+ "\n",
433
+ "The episode ends if:\n",
434
+    "- The pole angle is greater than ±12°\n",
435
+ "- Cart Position is greater than ±2.4\n",
436
+ "- Episode length is greater than 500\n",
437
+ "\n",
438
+    "We get a reward 💰 of +1 for every timestep the pole stays in equilibrium. (A quick random-policy sanity check is sketched below.)"
439
+ ]
440
+ },
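Aside (not part of the committed notebook): before training anything, a random policy gives a useful baseline for this reward scheme. A minimal standalone sketch, creating its own throwaway environment:

```python
# Minimal sketch: a random policy on CartPole-v1 to see the +1-per-timestep reward.
import gym

demo_env = gym.make("CartPole-v1")                 # throwaway env just for this check
state, _ = demo_env.reset()
episode_reward = 0.0
while True:
    state, reward, terminated, truncated, _ = demo_env.step(demo_env.action_space.sample())
    episode_reward += reward                       # +1 every timestep the pole stays up
    if terminated or truncated:                    # pole fell, cart left the track, or 500 steps
        break
print("Random-policy episode reward:", episode_reward)   # typically only a few dozen
```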
441
+ {
442
+ "cell_type": "code",
443
+ "execution_count": 5,
444
+ "metadata": {
445
+ "id": "POOOk15_K6KA"
446
+ },
447
+ "outputs": [],
448
+ "source": [
449
+ "env_id = \"CartPole-v1\"\n",
450
+ "# Create the env\n",
451
+ "env = gym.make(env_id)\n",
452
+ "\n",
453
+ "# Create the evaluation env\n",
454
+ "eval_env = gym.make(env_id)\n",
455
+ "\n",
456
+ "# Create the evaluation env\n",
457
+ "render_env = gym.make(env_id, render_mode=\"rgb_array\")\n",
458
+ "\n",
459
+ "# Get the state space and action space\n",
460
+ "s_size = env.observation_space.shape[0]\n",
461
+ "a_size = env.action_space.n"
462
+ ]
463
+ },
464
+ {
465
+ "cell_type": "code",
466
+ "execution_count": 6,
467
+ "metadata": {
468
+ "id": "FMLFrjiBNLYJ"
469
+ },
470
+ "outputs": [
471
+ {
472
+ "name": "stdout",
473
+ "output_type": "stream",
474
+ "text": [
475
+ "_____OBSERVATION SPACE_____ \n",
476
+ "\n",
477
+ "The State Space is: 4\n",
478
+ "Sample observation [ 1.5935603e+00 -1.3312091e+38 3.4518537e-01 3.0665638e+38]\n"
479
+ ]
480
+ }
481
+ ],
482
+ "source": [
483
+ "print(\"_____OBSERVATION SPACE_____ \\n\")\n",
484
+ "print(\"The State Space is: \", s_size)\n",
485
+ "print(\"Sample observation\", env.observation_space.sample()) # Get a random observation"
486
+ ]
487
+ },
488
+ {
489
+ "cell_type": "code",
490
+ "execution_count": 7,
491
+ "metadata": {
492
+ "id": "Lu6t4sRNNWkN"
493
+ },
494
+ "outputs": [
495
+ {
496
+ "name": "stdout",
497
+ "output_type": "stream",
498
+ "text": [
499
+ "\n",
500
+ " _____ACTION SPACE_____ \n",
501
+ "\n",
502
+ "The Action Space is: 2\n",
503
+ "Action Space Sample 1\n"
504
+ ]
505
+ }
506
+ ],
507
+ "source": [
508
+ "print(\"\\n _____ACTION SPACE_____ \\n\")\n",
509
+ "print(\"The Action Space is: \", a_size)\n",
510
+ "print(\"Action Space Sample\", env.action_space.sample()) # Take a random action"
511
+ ]
512
+ },
513
+ {
514
+ "cell_type": "markdown",
515
+ "metadata": {
516
+ "id": "7SJMJj3WaFOz"
517
+ },
518
+ "source": [
519
+ "## Let's build the Reinforce Architecture\n",
520
+    "This implementation is based on three implementations:\n",
521
+ "- [PyTorch official Reinforcement Learning example](https://github.com/pytorch/examples/blob/main/reinforcement_learning/reinforce.py)\n",
522
+ "- [Udacity Reinforce](https://github.com/udacity/deep-reinforcement-learning/blob/master/reinforce/REINFORCE.ipynb)\n",
523
+ "- [Improvement of the integration by Chris1nexus](https://github.com/huggingface/deep-rl-class/pull/95)\n",
524
+ "\n",
525
+ "<img src=\"https://huggingface.co/datasets/huggingface-deep-rl-course/course-images/resolve/main/en/unit6/reinforce.png\" alt=\"Reinforce\"/>"
526
+ ]
527
+ },
528
+ {
529
+ "cell_type": "markdown",
530
+ "metadata": {
531
+ "id": "49kogtxBODX8"
532
+ },
533
+ "source": [
534
+ "So we want:\n",
535
+ "- Two fully connected layers (fc1 and fc2).\n",
536
+    "- Using ReLU as the activation function of fc1\n",
537
+ "- Using Softmax to output a probability distribution over actions"
538
+ ]
539
+ },
540
+ {
541
+ "cell_type": "code",
542
+ "execution_count": 8,
543
+ "metadata": {
544
+ "id": "w2LHcHhVZvPZ"
545
+ },
546
+ "outputs": [],
547
+ "source": [
548
+ "class Policy(nn.Module):\n",
549
+ " def __init__(self, s_size, a_size, h_size):\n",
550
+ " super(Policy, self).__init__()\n",
551
+ " # Create two fully connected layers\n",
552
+ " self.fc1 = nn.Linear(s_size,h_size)\n",
553
+ " self.fc2 = nn.Linear(h_size,a_size)\n",
554
+ "\n",
555
+ " def forward(self, x):\n",
556
+ " # Define the forward pass\n",
557
+ " # state goes to fc1 then we apply ReLU activation function\n",
558
+ " x = self.fc1(x)\n",
559
+ " x = F.relu(x)\n",
560
+ " # fc1 outputs goes to fc2\n",
561
+ " x = self.fc2(x)\n",
562
+ " # We output the softmax\n",
563
+ " x = F.softmax(x, dim=1)\n",
564
+ " return(x)\n",
565
+ " \n",
566
+ " def act(self, state):\n",
567
+ " \"\"\"\n",
568
+ " Given a state, take action\n",
569
+ " \"\"\"\n",
570
+ " state = torch.from_numpy(state).float().unsqueeze(0).to(device)\n",
571
+ " probs = self.forward(state).cpu()\n",
572
+ " m = Categorical(probs)\n",
573
+ " action = m.sample()\n",
574
+ " return action.item(), m.log_prob(action)"
575
+ ]
576
+ },
577
+ {
578
+ "cell_type": "markdown",
579
+ "metadata": {
580
+ "id": "rOMrdwSYOWSC",
581
+ "jp-MarkdownHeadingCollapsed": true
582
+ },
583
+ "source": [
584
+ "### Solution"
585
+ ]
586
+ },
587
+ {
588
+ "cell_type": "code",
589
+ "execution_count": 81,
590
+ "metadata": {
591
+ "id": "jGdhRSVrOV4K"
592
+ },
593
+ "outputs": [],
594
+ "source": [
595
+ "class Policy(nn.Module):\n",
596
+ " def __init__(self, s_size, a_size, h_size):\n",
597
+ " super(Policy, self).__init__()\n",
598
+ " self.fc1 = nn.Linear(s_size, h_size)\n",
599
+ " self.fc2 = nn.Linear(h_size, a_size)\n",
600
+ "\n",
601
+ " def forward(self, x):\n",
602
+ " x = F.relu(self.fc1(x))\n",
603
+ " x = self.fc2(x)\n",
604
+ " return F.softmax(x, dim=1)\n",
605
+ " \n",
606
+ " def act(self, state):\n",
607
+ " # if isinstance(state, tuple):\n",
608
+ " # state = state[0]\n",
609
+ " state = torch.from_numpy(state).float().unsqueeze(0).to(device)\n",
610
+ " probs = self.forward(state).cpu()\n",
611
+ " m = Categorical(probs)\n",
612
+ " action = np.argmax(m)\n",
613
+ " return action.item(), m.log_prob(action)"
614
+ ]
615
+ },
616
+ {
617
+ "cell_type": "markdown",
618
+ "metadata": {
619
+ "id": "ZTGWL4g2eM5B"
620
+ },
621
+ "source": [
622
+    "I made a mistake. Can you guess where?\n",
623
+ "\n",
624
+    "- To find out, let's make a forward pass:"
625
+ ]
626
+ },
627
+ {
628
+ "cell_type": "code",
629
+ "execution_count": 84,
630
+ "metadata": {
631
+ "id": "lwnqGBCNePor"
632
+ },
633
+ "outputs": [
634
+ {
635
+ "data": {
636
+ "text/plain": [
637
+ "(1, tensor([-0.6791], grad_fn=<SqueezeBackward1>))"
638
+ ]
639
+ },
640
+ "execution_count": 84,
641
+ "metadata": {},
642
+ "output_type": "execute_result"
643
+ }
644
+ ],
645
+ "source": [
646
+ "debug_policy = Policy(s_size, a_size, 64).to(device)\n",
647
+ "state, _ = env.reset()\n",
648
+ "debug_policy.act(state)"
649
+ ]
650
+ },
651
+ {
652
+ "cell_type": "markdown",
653
+ "metadata": {
654
+ "id": "14UYkoxCPaor"
655
+ },
656
+ "source": [
657
+ "- Here we see that the error says `ValueError: The value argument to log_prob must be a Tensor`\n",
658
+ "\n",
659
+ "- It means that `action` in `m.log_prob(action)` must be a Tensor **but it's not.**\n",
660
+ "\n",
661
+ "- Do you know why? Check the act function and try to see why it does not work. \n",
662
+ "\n",
663
+    "Advice 💡: Something is wrong in this implementation. Remember that in the `act` function **we want to sample an action from the probability distribution over actions**.\n"
664
+ ]
665
+ },
666
+ {
667
+ "cell_type": "markdown",
668
+ "metadata": {
669
+ "id": "gfGJNZBUP7Vn",
670
+ "jp-MarkdownHeadingCollapsed": true
671
+ },
672
+ "source": [
673
+ "### (Real) Solution"
674
+ ]
675
+ },
676
+ {
677
+ "cell_type": "code",
678
+ "execution_count": 83,
679
+ "metadata": {
680
+ "id": "Ho_UHf49N9i4"
681
+ },
682
+ "outputs": [],
683
+ "source": [
684
+ "class Policy(nn.Module):\n",
685
+ " def __init__(self, s_size, a_size, h_size):\n",
686
+ " super(Policy, self).__init__()\n",
687
+ " self.fc1 = nn.Linear(s_size, h_size)\n",
688
+ " self.fc2 = nn.Linear(h_size, a_size)\n",
689
+ "\n",
690
+ " def forward(self, x):\n",
691
+ " x = F.relu(self.fc1(x))\n",
692
+ " x = self.fc2(x)\n",
693
+ " return F.softmax(x, dim=1)\n",
694
+ " \n",
695
+ " def act(self, state):\n",
696
+ " state = torch.from_numpy(state).float().unsqueeze(0).to(device)\n",
697
+ " probs = self.forward(state).cpu()\n",
698
+ " m = Categorical(probs)\n",
699
+ " action = m.sample()\n",
700
+ " return action.item(), m.log_prob(action)"
701
+ ]
702
+ },
703
+ {
704
+ "cell_type": "markdown",
705
+ "metadata": {
706
+ "id": "rgJWQFU_eUYw"
707
+ },
708
+ "source": [
709
+ "By using CartPole, it was easier to debug since **we know that the bug comes from our integration and not from our simple environment**."
710
+ ]
711
+ },
712
+ {
713
+ "cell_type": "markdown",
714
+ "metadata": {
715
+ "id": "c-20i7Pk0l1T"
716
+ },
717
+ "source": [
718
+    "- Since **we want to sample an action from the probability distribution over actions**, we can't use `action = np.argmax(m)` since it would always output the action that has the highest probability.\n",
719
+ "\n",
720
+    "- We need to replace it with `action = m.sample()`, which samples an action from the probability distribution P(.|s) (a tiny illustration follows after this cell)"
721
+ ]
722
+ },
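Aside (not part of the committed notebook): a tiny standalone illustration of the difference, with made-up probabilities.

```python
# Minimal sketch: greedy argmax vs. sampling from the distribution.
import torch
from torch.distributions import Categorical

probs = torch.tensor([[0.25, 0.75]])     # made-up action probabilities
m = Categorical(probs)

greedy = torch.argmax(probs, dim=1)      # always action 1: no exploration
sampled = m.sample()                     # action 0 about 25% of the time, action 1 about 75%

# log_prob expects a tensor, which m.sample() returns directly
print(greedy.item(), sampled.item(), m.log_prob(sampled).item())
```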
723
+ {
724
+ "cell_type": "markdown",
725
+ "metadata": {
726
+ "id": "4MXoqetzfIoW"
727
+ },
728
+ "source": [
729
+ "### Let's build the Reinforce Training Algorithm\n",
730
+ "This is the Reinforce algorithm pseudocode:\n",
731
+ "\n",
732
+ "<img src=\"https://huggingface.co/datasets/huggingface-deep-rl-course/course-images/resolve/main/en/unit6/pg_pseudocode.png\" alt=\"Policy gradient pseudocode\"/>\n",
733
+ " "
734
+ ]
735
+ },
736
+ {
737
+ "cell_type": "markdown",
738
+ "metadata": {
739
+ "id": "QmcXG-9i2Qu2"
740
+ },
741
+ "source": [
742
+ "- When we calculate the return Gt (line 6) we see that we calculate the sum of discounted rewards **starting at timestep t**.\n",
743
+ "\n",
744
+    "- Why? Because our policy should only **reinforce actions on the basis of their consequences**: rewards obtained before taking an action are useless (since they were not caused by that action), **only the ones that come after the action matter**.\n",
745
+ "\n",
746
+    "- Before coding this, you should read the section [don't let the past distract you](https://spinningup.openai.com/en/latest/spinningup/rl_intro3.html#don-t-let-the-past-distract-you), which explains why we use the reward-to-go policy gradient.\n",
747
+ "\n",
748
+    "We use an interesting technique coded by [Chris1nexus](https://github.com/Chris1nexus) to **compute the return at each timestep efficiently**. The comments explain the procedure. Don't hesitate to also [check the PR explanation](https://github.com/huggingface/deep-rl-class/pull/95).\n",
749
+    "But overall, the idea is to **compute the return at each timestep efficiently**. (A tiny standalone example follows after this cell.)"
750
+ ]
751
+ },
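Aside (not part of the committed notebook): a tiny standalone example of that backward computation, with made-up rewards and gamma = 0.9.

```python
# Minimal sketch of the reward-to-go computation described above.
from collections import deque

rewards = [1.0, 1.0, 1.0]                # made-up episode rewards
gamma = 0.9
returns = deque()
for t in reversed(range(len(rewards))):
    future = returns[0] if returns else 0.0
    returns.appendleft(rewards[t] + gamma * future)

print(list(returns))                     # [2.71, 1.9, 1.0]: G_t = r_{t+1} + gamma * G_{t+1}
```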
752
+ {
753
+ "cell_type": "markdown",
754
+ "metadata": {
755
+ "id": "O554nUGPpcoq"
756
+ },
757
+ "source": [
758
+    "The second question you may ask is: **why do we minimize the loss**? Didn't we talk about Gradient Ascent, not Gradient Descent?\n",
759
+ "\n",
760
+    "- We want to maximize our utility function $J(\\theta)$, but in PyTorch, as in TensorFlow, it's better to **minimize an objective function.**\n",
761
+    " - So let's say we want to reinforce action 3 at a certain timestep. Before training, this action's probability P is 0.25.\n",
762
+ " - So we want to modify $\\theta$ such that $\\pi_\\theta(a_3|s; \\theta) > 0.25$\n",
763
+    " - Because all probabilities must sum to 1, maximizing $\\pi_\\theta(a_3|s; \\theta)$ will **minimize the other actions' probabilities.**\n",
764
+ " - So we should tell PyTorch **to min $1 - \\pi_\\theta(a_3|s; \\theta)$.**\n",
765
+ " - This loss function approaches 0 as $\\pi_\\theta(a_3|s; \\theta)$ nears 1.\n",
766
+    " - So we are encouraging the gradient to maximize $\\pi_\\theta(a_3|s; \\theta)$. (A toy numeric sketch follows after this cell.)\n"
767
+ ]
768
+ },
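Aside (not part of the committed notebook): a toy numeric sketch of how the loss is assembled, with made-up log-probabilities and returns.

```python
# Minimal sketch: negating log-probs weighted by returns turns gradient ASCENT on
# the expected return into gradient DESCENT on this loss.
import torch

log_probs = [torch.tensor([-0.69], requires_grad=True),   # made-up log pi(a_t | s_t)
             torch.tensor([-0.11], requires_grad=True)]
returns = [1.5, 0.5]                                       # made-up (standardized) returns

policy_loss = torch.cat([-lp * G for lp, G in zip(log_probs, returns)]).sum()
policy_loss.backward()                                     # gradients push these log-probs up
print(policy_loss.item())                                  # 1.09 = 0.69*1.5 + 0.11*0.5
```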
769
+ {
770
+ "cell_type": "code",
771
+ "execution_count": 9,
772
+ "metadata": {
773
+ "id": "iOdv8Q9NfLK7"
774
+ },
775
+ "outputs": [],
776
+ "source": [
777
+ "def reinforce(policy, optimizer, n_training_episodes, max_t, gamma, print_every):\n",
778
+ " # Help us to calculate the score during the training\n",
779
+ " scores_deque = deque(maxlen=100)\n",
780
+ " scores = []\n",
781
+ " # Line 3 of pseudocode\n",
782
+ " for i_episode in range(1, n_training_episodes+1):\n",
783
+ " saved_log_probs = []\n",
784
+ " rewards = []\n",
785
+ " state, _ = env.reset()\n",
786
+ " # Line 4 of pseudocode\n",
787
+ " for t in range(max_t):\n",
788
+ " action, log_prob = policy.act(state)# TODO get the action\n",
789
+ " saved_log_probs.append(log_prob)\n",
790
+ " state, reward, terminated, truncated, info = env.step(action)\n",
791
+ " rewards.append(reward)\n",
792
+ " if bool(terminated) or bool(truncated):\n",
793
+ " break \n",
794
+ " scores_deque.append(sum(rewards))\n",
795
+ " scores.append(sum(rewards))\n",
796
+ " \n",
797
+ " # Line 6 of pseudocode: calculate the return\n",
798
+ " returns = deque(maxlen=max_t)\n",
799
+ " n_steps = len(rewards) \n",
800
+ " # Compute the discounted returns at each timestep,\n",
801
+ " # as the sum of the gamma-discounted return at time t (G_t) + the reward at time t\n",
802
+ " \n",
803
+ " # In O(N) time, where N is the number of time steps\n",
804
+ " # (this definition of the discounted return G_t follows the definition of this quantity \n",
805
+ " # shown at page 44 of Sutton&Barto 2017 2nd draft)\n",
806
+ " # G_t = r_(t+1) + r_(t+2) + ...\n",
807
+ " \n",
808
+ " # Given this formulation, the returns at each timestep t can be computed \n",
809
+ " # by re-using the computed future returns G_(t+1) to compute the current return G_t\n",
810
+ " # G_t = r_(t+1) + gamma*G_(t+1)\n",
811
+ " # G_(t-1) = r_t + gamma* G_t\n",
812
+ " # (this follows a dynamic programming approach, with which we memorize solutions in order \n",
813
+ " # to avoid computing them multiple times)\n",
814
+ " \n",
815
+ " # This is correct since the above is equivalent to (see also page 46 of Sutton&Barto 2017 2nd draft)\n",
816
+ " # G_(t-1) = r_t + gamma*r_(t+1) + gamma*gamma*r_(t+2) + ...\n",
817
+ " \n",
818
+ " \n",
819
+ " ## Given the above, we calculate the returns at timestep t as: \n",
820
+ " # gamma[t] * return[t] + reward[t]\n",
821
+ " #\n",
822
+ " ## We compute this starting from the last timestep to the first, in order\n",
823
+ " ## to employ the formula presented above and avoid redundant computations that would be needed \n",
824
+ " ## if we were to do it from first to last.\n",
825
+ " \n",
826
+ " ## Hence, the queue \"returns\" will hold the returns in chronological order, from t=0 to t=n_steps\n",
827
+ " ## thanks to the appendleft() function which allows to append to the position 0 in constant time O(1)\n",
828
+ " ## a normal python list would instead require O(N) to do this.\n",
829
+ " for t in range(n_steps)[::-1]:\n",
830
+ " disc_return_t = (returns[0] if len(returns)>0 else 0)\n",
831
+ " returns.appendleft(gamma * disc_return_t + rewards[t]) \n",
832
+ " \n",
833
+ " ## standardization of the returns is employed to make training more stable\n",
834
+ " eps = np.finfo(np.float32).eps.item()\n",
835
+ " \n",
836
+ " ## eps is the smallest representable float, which is \n",
837
+ " # added to the standard deviation of the returns to avoid numerical instabilities\n",
838
+ " returns = torch.tensor(returns)\n",
839
+ " returns = (returns - returns.mean()) / (returns.std() + eps)\n",
840
+ " \n",
841
+ " # Line 7:\n",
842
+ " policy_loss = []\n",
843
+ " for log_prob, disc_return in zip(saved_log_probs, returns):\n",
844
+ " policy_loss.append(-log_prob * disc_return)\n",
845
+ " policy_loss = torch.cat(policy_loss).sum()\n",
846
+ " \n",
847
+ " # Line 8: PyTorch prefers gradient descent \n",
848
+ " optimizer.zero_grad()\n",
849
+ " policy_loss.backward()\n",
850
+ " optimizer.step()\n",
851
+ " \n",
852
+ " if i_episode % print_every == 0:\n",
853
+ " print('Episode {}\\tAverage Score: {:.2f}'.format(i_episode, np.mean(scores_deque)))\n",
854
+ " \n",
855
+ " return scores"
856
+ ]
857
+ },
858
+ {
859
+ "cell_type": "markdown",
860
+ "metadata": {
861
+ "id": "YB0Cxrw1StrP",
862
+ "jp-MarkdownHeadingCollapsed": true
863
+ },
864
+ "source": [
865
+ "#### Solution"
866
+ ]
867
+ },
868
+ {
869
+ "cell_type": "code",
870
+ "execution_count": 89,
871
+ "metadata": {
872
+ "id": "NCNvyElRStWG"
873
+ },
874
+ "outputs": [],
875
+ "source": [
876
+ "def reinforce(policy, optimizer, n_training_episodes, max_t, gamma, print_every):\n",
877
+ "    # Helps us calculate the average score during training\n",
878
+ " scores_deque = deque(maxlen=100)\n",
879
+ " scores = []\n",
880
+ " # Line 3 of pseudocode\n",
881
+ " for i_episode in range(1, n_training_episodes+1):\n",
882
+ " saved_log_probs = []\n",
883
+ " rewards = []\n",
884
+ " state, _ = env.reset()\n",
885
+ " # Line 4 of pseudocode\n",
886
+ " for t in range(max_t):\n",
887
+ " action, log_prob = policy.act(state)\n",
888
+ " saved_log_probs.append(log_prob)\n",
889
+ " state, reward, terminated, truncated, _ = env.step(action)\n",
890
+ " rewards.append(reward)\n",
891
+ " done = terminated or truncated\n",
892
+ " if done:\n",
893
+ " break \n",
894
+ " scores_deque.append(sum(rewards))\n",
895
+ " scores.append(sum(rewards))\n",
896
+ " \n",
897
+ " # Line 6 of pseudocode: calculate the return\n",
898
+ " returns = deque(maxlen=max_t) \n",
899
+ " n_steps = len(rewards) \n",
900
+ " # Compute the discounted returns at each timestep,\n",
901
+ " # as \n",
902
+ "    # the sum of the reward at time t and the gamma-discounted return at time t+1 (G_(t+1))\n",
903
+ " #\n",
904
+ " # In O(N) time, where N is the number of time steps\n",
905
+ " # (this definition of the discounted return G_t follows the definition of this quantity \n",
906
+ " # shown at page 44 of Sutton&Barto 2017 2nd draft)\n",
907
+ "        # G_t = r_(t+1) + gamma*r_(t+2) + gamma^2*r_(t+3) + ...\n",
908
+ " \n",
909
+ " # Given this formulation, the returns at each timestep t can be computed \n",
910
+ " # by re-using the computed future returns G_(t+1) to compute the current return G_t\n",
911
+ " # G_t = r_(t+1) + gamma*G_(t+1)\n",
912
+ " # G_(t-1) = r_t + gamma* G_t\n",
913
+ " # (this follows a dynamic programming approach, with which we memorize solutions in order \n",
914
+ " # to avoid computing them multiple times)\n",
915
+ " \n",
916
+ " # This is correct since the above is equivalent to (see also page 46 of Sutton&Barto 2017 2nd draft)\n",
917
+ " # G_(t-1) = r_t + gamma*r_(t+1) + gamma*gamma*r_(t+2) + ...\n",
918
+ " \n",
919
+ " \n",
920
+ " ## Given the above, we calculate the returns at timestep t as: \n",
921
+ "        # return[t] = gamma * return[t+1] + reward[t]\n",
922
+ " #\n",
923
+ " ## We compute this starting from the last timestep to the first, in order\n",
924
+ " ## to employ the formula presented above and avoid redundant computations that would be needed \n",
925
+ " ## if we were to do it from first to last.\n",
926
+ " \n",
927
+ " ## Hence, the queue \"returns\" will hold the returns in chronological order, from t=0 to t=n_steps\n",
928
+ "        ## thanks to the appendleft() function, which appends at position 0 in constant time O(1);\n",
929
+ " ## a normal python list would instead require O(N) to do this.\n",
930
+ " for t in range(n_steps)[::-1]:\n",
931
+ " disc_return_t = (returns[0] if len(returns)>0 else 0)\n",
932
+ " returns.appendleft( gamma*disc_return_t + rewards[t]) \n",
933
+ " \n",
934
+ " ## standardization of the returns is employed to make training more stable\n",
935
+ " eps = np.finfo(np.float32).eps.item()\n",
936
+ "        ## eps is the machine epsilon of float32 (the smallest representable increment), which is \n",
937
+ " # added to the standard deviation of the returns to avoid numerical instabilities \n",
938
+ " returns = torch.tensor(returns)\n",
939
+ " returns = (returns - returns.mean()) / (returns.std() + eps)\n",
940
+ " \n",
941
+ " # Line 7:\n",
942
+ " policy_loss = []\n",
943
+ " for log_prob, disc_return in zip(saved_log_probs, returns):\n",
944
+ " policy_loss.append(-log_prob * disc_return)\n",
945
+ " policy_loss = torch.cat(policy_loss).sum()\n",
946
+ " \n",
947
+ "        # Line 8: PyTorch prefers gradient descent, so minimizing this negative loss performs gradient ascent on the expected return\n",
948
+ " optimizer.zero_grad()\n",
949
+ " policy_loss.backward()\n",
950
+ " optimizer.step()\n",
951
+ " \n",
952
+ " if i_episode % print_every == 0:\n",
953
+ " print('Episode {}\\tAverage Score: {:.2f}'.format(i_episode, np.mean(scores_deque)))\n",
954
+ " \n",
955
+ " return scores"
956
+ ]
957
+ },
958
+ {
959
+ "cell_type": "markdown",
960
+ "metadata": {
961
+ "id": "RIWhQyJjfpEt"
962
+ },
963
+ "source": [
964
+ "## Train it\n",
965
+ "- We're now ready to train our agent.\n",
966
+ "- But first, we define a variable containing all the training hyperparameters.\n",
967
+ "- You can change the training parameters (and should 😉)"
968
+ ]
969
+ },
970
+ {
971
+ "cell_type": "code",
972
+ "execution_count": 10,
973
+ "metadata": {
974
+ "id": "utRe1NgtVBYF"
975
+ },
976
+ "outputs": [],
977
+ "source": [
978
+ "cartpole_hyperparameters = {\n",
979
+ " \"h_size\": 16,\n",
980
+ " \"n_training_episodes\": 1000,\n",
981
+ " \"n_evaluation_episodes\": 10,\n",
982
+ " \"max_t\": 1000,\n",
983
+ " \"gamma\": 1.0,\n",
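+ "    # gamma = 1.0 (no discounting) is fine here since CartPole-v1 episodes are truncated (500 steps by default)\n",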
984
+ " \"lr\": 1e-2,\n",
985
+ " \"env_id\": env_id,\n",
986
+ " \"state_space\": s_size,\n",
987
+ " \"action_space\": a_size,\n",
988
+ "}"
989
+ ]
990
+ },
991
+ {
992
+ "cell_type": "code",
993
+ "execution_count": 15,
994
+ "metadata": {
995
+ "id": "D3lWyVXBVfl6"
996
+ },
997
+ "outputs": [],
998
+ "source": [
999
+ "# Create policy and place it to the device\n",
1000
+ "cartpole_policy = Policy(cartpole_hyperparameters[\"state_space\"], cartpole_hyperparameters[\"action_space\"], cartpole_hyperparameters[\"h_size\"]).to(device)\n",
1001
+ "cartpole_optimizer = optim.Adam(cartpole_policy.parameters(), lr=cartpole_hyperparameters[\"lr\"])"
1002
+ ]
1003
+ },
1004
+ {
1005
+ "cell_type": "code",
1006
+ "execution_count": 16,
1007
+ "metadata": {
1008
+ "id": "uGf-hQCnfouB"
1009
+ },
1010
+ "outputs": [
1011
+ {
1012
+ "name": "stdout",
1013
+ "output_type": "stream",
1014
+ "text": [
1015
+ "Episode 100\tAverage Score: 50.60\n",
1016
+ "Episode 200\tAverage Score: 312.39\n",
1017
+ "Episode 300\tAverage Score: 322.14\n",
1018
+ "Episode 400\tAverage Score: 498.43\n",
1019
+ "Episode 500\tAverage Score: 414.31\n",
1020
+ "Episode 600\tAverage Score: 474.40\n",
1021
+ "Episode 700\tAverage Score: 485.19\n",
1022
+ "Episode 800\tAverage Score: 498.00\n",
1023
+ "Episode 900\tAverage Score: 481.82\n",
1024
+ "Episode 1000\tAverage Score: 500.00\n"
1025
+ ]
1026
+ }
1027
+ ],
1028
+ "source": [
1029
+ "scores = reinforce(cartpole_policy,\n",
1030
+ " cartpole_optimizer,\n",
1031
+ " cartpole_hyperparameters[\"n_training_episodes\"], \n",
1032
+ " cartpole_hyperparameters[\"max_t\"],\n",
1033
+ " cartpole_hyperparameters[\"gamma\"], \n",
1034
+ " 100)"
1035
+ ]
1036
+ },
1037
+ {
1038
+ "cell_type": "markdown",
1039
+ "metadata": {
1040
+ "id": "Qajj2kXqhB3g"
1041
+ },
1042
+ "source": [
1043
+ "## Define evaluation method 📝\n",
1044
+ "- Here we define the evaluation method that we're going to use to test our Reinforce agent."
1045
+ ]
1046
+ },
1047
+ {
1048
+ "cell_type": "code",
1049
+ "execution_count": 11,
1050
+ "metadata": {
1051
+ "id": "3FamHmxyhBEU"
1052
+ },
1053
+ "outputs": [],
1054
+ "source": [
1055
+ "def evaluate_agent(env, max_steps, n_eval_episodes, policy):\n",
1056
+ " \"\"\"\n",
1057
+ "    Evaluate the agent for ``n_eval_episodes`` episodes and return the average reward and standard deviation of the reward.\n",
1058
+ " :param env: The evaluation environment\n",
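+ "    :param max_steps: Maximum number of steps per evaluation episode\n",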
1059
+ "    :param n_eval_episodes: Number of episodes to evaluate the agent\n",
1060
+ " :param policy: The Reinforce agent\n",
1061
+ " \"\"\"\n",
1062
+ " episode_rewards = []\n",
1063
+ " for episode in range(n_eval_episodes):\n",
1064
+ " state, _ = env.reset()\n",
1065
+ " step = 0\n",
1066
+ " total_rewards_ep = 0\n",
1067
+ " \n",
1068
+ " for step in range(max_steps):\n",
1069
+ " action, _ = policy.act(state)\n",
1070
+ " new_state, reward, terminated, truncated, _ = env.step(action)\n",
1071
+ " total_rewards_ep += reward\n",
1072
+ " \n",
1073
+ " if bool(terminated) or bool(truncated):\n",
1074
+ " break\n",
1075
+ " state = new_state\n",
1076
+ " episode_rewards.append(total_rewards_ep)\n",
1077
+ " mean_reward = np.mean(episode_rewards)\n",
1078
+ " std_reward = np.std(episode_rewards)\n",
1079
+ "\n",
1080
+ " return mean_reward, std_reward"
1081
+ ]
1082
+ },
1083
+ {
1084
+ "cell_type": "markdown",
1085
+ "metadata": {
1086
+ "id": "xdH2QCrLTrlT"
1087
+ },
1088
+ "source": [
1089
+ "## Evaluate our agent 📈"
1090
+ ]
1091
+ },
1092
+ {
1093
+ "cell_type": "code",
1094
+ "execution_count": 43,
1095
+ "metadata": {
1096
+ "id": "ohGSXDyHh0xx"
1097
+ },
1098
+ "outputs": [
1099
+ {
1100
+ "data": {
1101
+ "text/plain": [
1102
+ "(np.float64(500.0), np.float64(0.0))"
1103
+ ]
1104
+ },
1105
+ "execution_count": 43,
1106
+ "metadata": {},
1107
+ "output_type": "execute_result"
1108
+ }
1109
+ ],
1110
+ "source": [
1111
+ "evaluate_agent(eval_env, \n",
1112
+ " cartpole_hyperparameters[\"max_t\"], \n",
1113
+ " cartpole_hyperparameters[\"n_evaluation_episodes\"],\n",
1114
+ " cartpole_policy)"
1115
+ ]
1116
+ },
1117
+ {
1118
+ "cell_type": "markdown",
1119
+ "metadata": {
1120
+ "id": "7CoeLkQ7TpO8"
1121
+ },
1122
+ "source": [
1123
+ "### Publish our trained model on the Hub 🔥\n",
1124
+ "Now that we've seen good results after training, we can publish our trained model on the Hub 🤗 with one line of code.\n",
1125
+ "\n",
1126
+ "Here's an example of a Model Card:\n",
1127
+ "\n",
1128
+ "<img src=\"https://huggingface.co/datasets/huggingface-deep-rl-course/course-images/resolve/main/en/unit6/modelcard.png\"/>"
1129
+ ]
1130
+ },
1131
+ {
1132
+ "cell_type": "markdown",
1133
+ "metadata": {
1134
+ "id": "Jmhs1k-cftIq"
1135
+ },
1136
+ "source": [
1137
+ "### Push to the Hub\n",
1138
+ "#### Do not modify this code"
1139
+ ]
1140
+ },
1141
+ {
1142
+ "cell_type": "code",
1143
+ "execution_count": 12,
1144
+ "metadata": {
1145
+ "id": "LIVsvlW_8tcw"
1146
+ },
1147
+ "outputs": [],
1148
+ "source": [
1149
+ "from huggingface_hub import HfApi, snapshot_download\n",
1150
+ "from huggingface_hub.repocard import metadata_eval_result, metadata_save\n",
1151
+ "\n",
1152
+ "from gymnasium.wrappers import RecordVideo\n",
1153
+ "\n",
1154
+ "from pathlib import Path\n",
1155
+ "import datetime\n",
1156
+ "import json\n",
1157
+ "import imageio\n",
1158
+ "import cv2\n",
1159
+ "import numpy as np\n",
1160
+ "\n",
1161
+ "import tempfile\n",
1162
+ "\n",
1163
+ "import os"
1164
+ ]
1165
+ },
1166
+ {
1167
+ "cell_type": "code",
1168
+ "execution_count": 13,
1169
+ "metadata": {
1170
+ "id": "Lo4JH45if81z"
1171
+ },
1172
+ "outputs": [],
1173
+ "source": [
1174
+ "def record_video(env, policy, out_directory, fps=30):\n",
1175
+ " \"\"\"\n",
1176
+ " Generate a replay video of the agent\n",
1177
+ " :param env\n",
1178
+ "    :param policy: the Reinforce policy used to select actions\n",
1179
+ " :param out_directory\n",
1180
+ "    :param fps: how many frames per second to record the video at\n",
1181
+ " \"\"\"\n",
1182
+ " done = False\n",
1183
+ " \n",
1184
+ " state, _ = env.reset()\n",
1185
+ " frame = env.render()\n",
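+ "    # Note: this assumes the env was created with render_mode=\"rgb_array\",\n",
+ "    # so that env.render() returns an RGB frame as a numpy array\n",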
1186
+ " height, width, _ = frame.shape\n",
1187
+ "\n",
1188
+ " # Set up the video writer\n",
1189
+ " fourcc = cv2.VideoWriter_fourcc(*'mp4v') # Codec for the video\n",
1190
+ " video_writer = cv2.VideoWriter(out_directory, fourcc, fps, (width, height))\n",
1191
+ " print(out_directory)\n",
1192
+ " step=0\n",
1193
+ " \n",
1194
+ " while not done:\n",
1195
+ " print(step)\n",
1196
+ " step+=1\n",
1197
+ " # Take the action determined by the policy\n",
1198
+ " action, _ = policy.act(state)\n",
1199
+ " state, reward, terminated, truncated, _ = env.step(action)\n",
1200
+ " \n",
1201
+ " # Render the frame in 'rgb_array' mode and convert to BGR for OpenCV\n",
1202
+ " frame = env.render()\n",
1203
+ " bgr_frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)\n",
1204
+ "\n",
1205
+ " # Write the frame to the video\n",
1206
+ " video_writer.write(bgr_frame)\n",
1207
+ "\n",
1208
+ " done = terminated or truncated\n",
1209
+ "\n",
1210
+ " # Release the video writer\n",
1211
+ " video_writer.release()"
1212
+ ]
1213
+ },
1214
+ {
1215
+ "cell_type": "code",
1216
+ "execution_count": 15,
1217
+ "metadata": {
1218
+ "id": "_TPdq47D7_f_"
1219
+ },
1220
+ "outputs": [],
1221
+ "source": [
1222
+ "def push_to_hub(repo_id, \n",
1223
+ " model,\n",
1224
+ " hyperparameters,\n",
1225
+ " eval_env,\n",
1226
+ " video_fps=30):\n",
1227
+ " \"\"\"\n",
1228
+ "    Evaluate, generate a video, and upload a model to the Hugging Face Hub.\n",
1229
+ " This method does the complete pipeline:\n",
1230
+ " - It evaluates the model\n",
1231
+ " - It generates the model card\n",
1232
+ " - It generates a replay video of the agent\n",
1233
+ " - It pushes everything to the Hub\n",
1234
+ "\n",
1235
+ " :param repo_id: repo_id: id of the model repository from the Hugging Face Hub\n",
1236
+ " :param model: the pytorch model we want to save\n",
1237
+ " :param hyperparameters: training hyperparameters\n",
1238
+ " :param eval_env: evaluation environment\n",
1239
+ " :param video_fps: how many frames per second to record our video replay \n",
1240
+ " \"\"\"\n",
1241
+ "\n",
1242
+ " _, repo_name = repo_id.split(\"/\")\n",
1243
+ " api = HfApi()\n",
1244
+ "\n",
1245
+ " # Step 1: Create the repo\n",
1246
+ " repo_url = api.create_repo(\n",
1247
+ " repo_id=repo_id,\n",
1248
+ " exist_ok=True,\n",
1249
+ " )\n",
1250
+ "\n",
1251
+ " # Define the path to save the video in the current working directory\n",
1252
+ " current_directory = Path(os.getcwd())\n",
1253
+ "\n",
1254
+ " # Step 2: Save the model\n",
1255
+ " torch.save(model, current_directory / \"model.pt\")\n",
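+ "    # Note: torch.save on the full module pickles the Policy class itself, so the same class\n",
+ "    # definition must be importable when the file is loaded; saving model.state_dict() instead\n",
+ "    # is the more portable option if you only need the weights\n",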
1256
+ "\n",
1257
+ " # Step 3: Save the hyperparameters to JSON\n",
1258
+ " with open(current_directory / \"hyperparameters.json\", \"w\") as outfile:\n",
1259
+ " json.dump(hyperparameters, outfile)\n",
1260
+ "\n",
1261
+ " # Step 4: Evaluate the model and build JSON\n",
1262
+ " mean_reward, std_reward = evaluate_agent(eval_env, \n",
1263
+ " hyperparameters[\"max_t\"],\n",
1264
+ " hyperparameters[\"n_evaluation_episodes\"], \n",
1265
+ " model)\n",
1266
+ "\n",
1267
+ " # Get datetime\n",
1268
+ " eval_datetime = datetime.datetime.now()\n",
1269
+ " eval_form_datetime = eval_datetime.isoformat()\n",
1270
+ "\n",
1271
+ " evaluate_data = {\n",
1272
+ " \"env_id\": hyperparameters[\"env_id\"], \n",
1273
+ " \"mean_reward\": mean_reward,\n",
1274
+ " \"n_evaluation_episodes\": hyperparameters[\"n_evaluation_episodes\"],\n",
1275
+ " \"eval_datetime\": eval_form_datetime,\n",
1276
+ " }\n",
1277
+ "\n",
1278
+ " # Write a JSON file\n",
1279
+ " with open(current_directory / \"results.json\", \"w\") as outfile:\n",
1280
+ " json.dump(evaluate_data, outfile)\n",
1281
+ "\n",
1282
+ " # Step 5: Create the model card\n",
1283
+ " env_name = hyperparameters[\"env_id\"]\n",
1284
+ " \n",
1285
+ " metadata = {}\n",
1286
+ " metadata[\"tags\"] = [\n",
1287
+ " env_name,\n",
1288
+ " \"reinforce\",\n",
1289
+ " \"reinforcement-learning\",\n",
1290
+ " \"custom-implementation\",\n",
1291
+ " \"deep-rl-class\"\n",
1292
+ " ]\n",
1293
+ "\n",
1294
+ " # Add metrics\n",
1295
+ " eval = metadata_eval_result(\n",
1296
+ " model_pretty_name=repo_name,\n",
1297
+ " task_pretty_name=\"reinforcement-learning\",\n",
1298
+ " task_id=\"reinforcement-learning\",\n",
1299
+ " metrics_pretty_name=\"mean_reward\",\n",
1300
+ " metrics_id=\"mean_reward\",\n",
1301
+ " metrics_value=f\"{mean_reward:.2f} +/- {std_reward:.2f}\",\n",
1302
+ " dataset_pretty_name=env_name,\n",
1303
+ " dataset_id=env_name,\n",
1304
+ " )\n",
1305
+ "\n",
1306
+ " # Merges both dictionaries\n",
1307
+ " metadata = {**metadata, **eval}\n",
1308
+ "\n",
1309
+ " model_card = f\"\"\"\n",
1310
+ "# **Reinforce** Agent playing **{env_name}**\n",
1311
+ "This is a trained model of a **Reinforce** agent playing **{env_name}**.\n",
1312
+ "To learn to use this model and train yours, check Unit 4 of the Deep Reinforcement Learning Course: https://huggingface.co/deep-rl-course/unit4/introduction\n",
1313
+ "\"\"\"\n",
1314
+ "\n",
1315
+ " readme_path = current_directory / \"README.md\"\n",
1316
+ " readme = \"\"\n",
1317
+ " if readme_path.exists():\n",
1318
+ " with readme_path.open(\"r\", encoding=\"utf8\") as f:\n",
1319
+ " readme = f.read()\n",
1320
+ " else:\n",
1321
+ " readme = model_card\n",
1322
+ "\n",
1323
+ " with readme_path.open(\"w\", encoding=\"utf-8\") as f:\n",
1324
+ " f.write(readme)\n",
1325
+ "\n",
1326
+ " # Save our metrics to Readme metadata\n",
1327
+ " metadata_save(readme_path, metadata)\n",
1328
+ "\n",
1329
+ " # Step 6: Record a video\n",
1330
+ " video_path = current_directory / \"replay.mp4\"\n",
1331
+ " record_video(render_env, model, video_path, video_fps)\n",
1332
+ " print(\"video recorded\")\n",
1333
+ "\n",
1334
+ " # Step 7. Push everything to the Hub\n",
1335
+ " api.upload_folder(\n",
1336
+ " repo_id=repo_id,\n",
1337
+ " folder_path=current_directory,\n",
1338
+ " path_in_repo=\".\",\n",
1339
+ " )\n",
1340
+ "\n",
1341
+ " print(f\"Your model is pushed to the Hub. You can view your model here: {repo_url}\")\n"
1342
+ ]
1343
+ },
1344
+ {
1345
+ "cell_type": "markdown",
1346
+ "metadata": {
1347
+ "id": "w17w8CxzoURM"
1348
+ },
1349
+ "source": [
1350
+ "### What `push_to_hub` does\n",
1351
+ "\n",
1352
+ "By using `push_to_hub` **you evaluate, record a replay, generate a model card of your agent and push it to the Hub**.\n",
1353
+ "\n",
1354
+ "This way:\n",
1355
+ "- You can **showcase your work** 🔥\n",
1356
+ "- You can **visualize your agent playing** 👀\n",
1357
+ "- You can **share with the community an agent that others can use** 💾\n",
1358
+ "- You can **access a leaderboard 🏆 to see how well your agent is performing compared to your classmates** 👉 https://huggingface.co/spaces/huggingface-projects/Deep-Reinforcement-Learning-Leaderboard\n"
1359
+ ]
1360
+ },
1361
+ {
1362
+ "cell_type": "markdown",
1363
+ "metadata": {
1364
+ "id": "cWnFC0iZooTw"
1365
+ },
1366
+ "source": [
1367
+ "To be able to share your model with the community there are three more steps to follow:\n",
1368
+ "\n",
1369
+ "1️⃣ (If it's not already done) create an account on HF ➡ https://huggingface.co/join\n",
1370
+ "\n",
1371
+ "2️⃣ Sign in and then store your authentication token from the Hugging Face website.\n",
1372
+ "- Create a new token (https://huggingface.co/settings/tokens) **with write role**\n",
1373
+ "\n",
1374
+ "\n",
1375
+ "<img src=\"https://huggingface.co/datasets/huggingface-deep-rl-course/course-images/resolve/main/en/notebooks/create-token.jpg\" alt=\"Create HF Token\">\n"
1376
+ ]
1377
+ },
1378
+ {
1379
+ "cell_type": "code",
1380
+ "execution_count": 16,
1381
+ "metadata": {
1382
+ "id": "QB5nIcxR8paT"
1383
+ },
1384
+ "outputs": [
1385
+ {
1386
+ "data": {
1387
+ "application/vnd.jupyter.widget-view+json": {
1388
+ "model_id": "5cda3e494c2e4cc3bf833eb935891576",
1389
+ "version_major": 2,
1390
+ "version_minor": 0
1391
+ },
1392
+ "text/plain": [
1393
+ "VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
1394
+ ]
1395
+ },
1396
+ "metadata": {},
1397
+ "output_type": "display_data"
1398
+ }
1399
+ ],
1400
+ "source": [
1401
+ "notebook_login()"
1402
+ ]
1403
+ },
1404
+ {
1405
+ "cell_type": "markdown",
1406
+ "metadata": {
1407
+ "id": "GyWc1x3-o3xG"
1408
+ },
1409
+ "source": [
1410
+ "If you don't want to use a Google Colab or a Jupyter Notebook, you need to use this command instead: `huggingface-cli login` (or `login`)"
1411
+ ]
1412
+ },
1413
+ {
1414
+ "cell_type": "markdown",
1415
+ "metadata": {
1416
+ "id": "F-D-zhbRoeOm"
1417
+ },
1418
+ "source": [
1419
+ "3️⃣ We're now ready to push our trained agent to the 🤗 Hub 🔥 using the `push_to_hub()` function"
1420
+ ]
1421
+ },
1422
+ {
1423
+ "cell_type": "code",
1424
+ "execution_count": 17,
1425
+ "metadata": {},
1426
+ "outputs": [],
1427
+ "source": [
1428
+ "import torch\n",
1429
+ "import os\n",
1430
+ "\n",
1431
+ "def load_model(model_path=\"model.pt\"):\n",
1432
+ " \"\"\"\n",
1433
+ " Load a saved PyTorch model from a file.\n",
1434
+ " \n",
1435
+ " :param model_path: Path to the saved model file (default: \"model.pt\")\n",
1436
+ " :return: The loaded model\n",
1437
+ " \"\"\"\n",
1438
+ " # Check if the file exists\n",
1439
+ " if not os.path.exists(model_path):\n",
1440
+ " raise FileNotFoundError(f\"Model file not found: {model_path}\")\n",
1441
+ " \n",
1442
+ " # Load the model using torch.load\n",
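+ "    # Note: recent PyTorch versions emit a FutureWarning here because the default of\n",
+ "    # weights_only will change; for a trusted local file you can silence it by passing\n",
+ "    # torch.load(model_path, weights_only=False) explicitly\n",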
1443
+ " model = torch.load(model_path)\n",
1444
+ " \n",
1445
+ " return model\n"
1446
+ ]
1447
+ },
1448
+ {
1449
+ "cell_type": "code",
1450
+ "execution_count": 18,
1451
+ "metadata": {},
1452
+ "outputs": [
1453
+ {
1454
+ "name": "stdout",
1455
+ "output_type": "stream",
1456
+ "text": [
1457
+ "Model loaded successfully!\n"
1458
+ ]
1459
+ },
1460
+ {
1461
+ "name": "stderr",
1462
+ "output_type": "stream",
1463
+ "text": [
1464
+ "C:\\Users\\Utilisateur\\AppData\\Local\\Temp\\ipykernel_472\\2361131237.py:16: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
1465
+ " model = torch.load(model_path)\n"
1466
+ ]
1467
+ }
1468
+ ],
1469
+ "source": [
1470
+ "# Load the saved model\n",
1471
+ "cartpole_policy = load_model(\"model.pt\")\n",
1472
+ "\n",
1473
+ "# Now you can use `cartpole_policy` to evaluate or continue training\n",
1474
+ "print(\"Model loaded successfully!\")\n"
1475
+ ]
1476
+ },
1477
+ {
1478
+ "cell_type": "code",
1479
+ "execution_count": null,
1480
+ "metadata": {
1481
+ "id": "UNwkTS65Uq3Q"
1482
+ },
1483
+ "outputs": [
1484
+ {
1485
+ "name": "stdout",
1486
+ "output_type": "stream",
1487
+ "text": [
1488
+ "C:\\Users\\Utilisateur\\OneDrive\\Documents\\distant_env\\Workspace\\Esteban\\HG_RL_courses\\unit4\\replay.mp4\n",
1489
+ "0\n",
1490
+ "1\n",
1491
+ "2\n",
1492
+ "3\n",
1493
+ "4\n",
1494
+ "5\n",
1495
+ "6\n",
1496
+ "7\n",
1497
+ "8\n",
1498
+ "9\n",
1499
+ "10\n",
1500
+ "11\n",
1501
+ "12\n",
1502
+ "13\n",
1503
+ "14\n",
1504
+ "15\n",
1505
+ "16\n",
1506
+ "17\n",
1507
+ "18\n",
1508
+ "19\n",
1509
+ "20\n",
1510
+ "21\n",
1511
+ "22\n",
1512
+ "23\n",
1513
+ "24\n",
1514
+ "25\n",
1515
+ "26\n",
1516
+ "27\n",
1517
+ "28\n",
1518
+ "29\n",
1519
+ "30\n",
1520
+ "31\n",
1521
+ "32\n",
1522
+ "33\n",
1523
+ "34\n",
1524
+ "35\n",
1525
+ "36\n",
1526
+ "37\n",
1527
+ "38\n",
1528
+ "39\n",
1529
+ "40\n",
1530
+ "41\n",
1531
+ "42\n",
1532
+ "43\n",
1533
+ "44\n",
1534
+ "45\n",
1535
+ "46\n",
1536
+ "47\n",
1537
+ "48\n",
1538
+ "49\n",
1539
+ "50\n",
1540
+ "51\n",
1541
+ "52\n",
1542
+ "53\n",
1543
+ "54\n",
1544
+ "55\n",
1545
+ "56\n",
1546
+ "57\n",
1547
+ "58\n",
1548
+ "59\n",
1549
+ "60\n",
1550
+ "61\n",
1551
+ "62\n",
1552
+ "63\n",
1553
+ "64\n",
1554
+ "65\n",
1555
+ "66\n",
1556
+ "67\n",
1557
+ "68\n",
1558
+ "69\n",
1559
+ "70\n",
1560
+ "71\n",
1561
+ "72\n",
1562
+ "73\n",
1563
+ "74\n",
1564
+ "75\n",
1565
+ "76\n",
1566
+ "77\n",
1567
+ "78\n",
1568
+ "79\n",
1569
+ "80\n",
1570
+ "81\n",
1571
+ "82\n",
1572
+ "83\n",
1573
+ "84\n",
1574
+ "85\n",
1575
+ "86\n",
1576
+ "87\n",
1577
+ "88\n",
1578
+ "89\n",
1579
+ "90\n",
1580
+ "91\n",
1581
+ "92\n",
1582
+ "93\n",
1583
+ "94\n",
1584
+ "95\n",
1585
+ "96\n",
1586
+ "97\n",
1587
+ "98\n",
1588
+ "99\n",
1589
+ "100\n",
1590
+ "101\n",
1591
+ "102\n",
1592
+ "103\n",
1593
+ "104\n",
1594
+ "105\n",
1595
+ "106\n",
1596
+ "107\n",
1597
+ "108\n",
1598
+ "109\n",
1599
+ "110\n",
1600
+ "111\n",
1601
+ "112\n",
1602
+ "113\n",
1603
+ "114\n",
1604
+ "115\n",
1605
+ "116\n",
1606
+ "117\n",
1607
+ "118\n",
1608
+ "119\n",
1609
+ "120\n",
1610
+ "121\n",
1611
+ "122\n",
1612
+ "123\n",
1613
+ "124\n",
1614
+ "125\n",
1615
+ "126\n",
1616
+ "127\n",
1617
+ "128\n",
1618
+ "129\n",
1619
+ "130\n",
1620
+ "131\n",
1621
+ "132\n",
1622
+ "133\n",
1623
+ "134\n",
1624
+ "135\n",
1625
+ "136\n",
1626
+ "137\n",
1627
+ "138\n",
1628
+ "139\n",
1629
+ "140\n",
1630
+ "141\n",
1631
+ "142\n",
1632
+ "143\n",
1633
+ "144\n",
1634
+ "145\n",
1635
+ "146\n",
1636
+ "147\n",
1637
+ "148\n",
1638
+ "149\n",
1639
+ "150\n",
1640
+ "151\n",
1641
+ "152\n",
1642
+ "153\n",
1643
+ "154\n",
1644
+ "155\n",
1645
+ "156\n",
1646
+ "157\n",
1647
+ "158\n",
1648
+ "159\n",
1649
+ "160\n",
1650
+ "161\n",
1651
+ "162\n",
1652
+ "163\n",
1653
+ "164\n",
1654
+ "165\n",
1655
+ "166\n",
1656
+ "167\n",
1657
+ "168\n",
1658
+ "169\n",
1659
+ "170\n",
1660
+ "171\n",
1661
+ "172\n",
1662
+ "173\n",
1663
+ "174\n",
1664
+ "175\n",
1665
+ "176\n",
1666
+ "177\n",
1667
+ "178\n",
1668
+ "179\n",
1669
+ "180\n",
1670
+ "181\n",
1671
+ "182\n",
1672
+ "183\n",
1673
+ "184\n",
1674
+ "185\n",
1675
+ "186\n",
1676
+ "187\n",
1677
+ "188\n",
1678
+ "189\n",
1679
+ "190\n",
1680
+ "191\n",
1681
+ "192\n",
1682
+ "193\n",
1683
+ "194\n",
1684
+ "195\n",
1685
+ "196\n",
1686
+ "197\n",
1687
+ "198\n",
1688
+ "199\n",
1689
+ "200\n",
1690
+ "201\n",
1691
+ "202\n",
1692
+ "203\n",
1693
+ "204\n",
1694
+ "205\n",
1695
+ "206\n",
1696
+ "207\n",
1697
+ "208\n",
1698
+ "209\n",
1699
+ "210\n",
1700
+ "211\n",
1701
+ "212\n",
1702
+ "213\n",
1703
+ "214\n",
1704
+ "215\n",
1705
+ "216\n",
1706
+ "217\n",
1707
+ "218\n",
1708
+ "219\n",
1709
+ "220\n",
1710
+ "221\n",
1711
+ "222\n",
1712
+ "223\n",
1713
+ "224\n",
1714
+ "225\n",
1715
+ "226\n",
1716
+ "227\n",
1717
+ "228\n",
1718
+ "229\n",
1719
+ "230\n",
1720
+ "231\n",
1721
+ "232\n",
1722
+ "233\n",
1723
+ "234\n",
1724
+ "235\n",
1725
+ "236\n",
1726
+ "237\n",
1727
+ "238\n",
1728
+ "239\n",
1729
+ "240\n",
1730
+ "241\n",
1731
+ "242\n",
1732
+ "243\n",
1733
+ "244\n",
1734
+ "245\n",
1735
+ "246\n",
1736
+ "247\n",
1737
+ "248\n",
1738
+ "249\n",
1739
+ "250\n",
1740
+ "251\n",
1741
+ "252\n",
1742
+ "253\n",
1743
+ "254\n",
1744
+ "255\n",
1745
+ "256\n",
1746
+ "257\n",
1747
+ "258\n",
1748
+ "259\n",
1749
+ "260\n",
1750
+ "261\n",
1751
+ "262\n",
1752
+ "263\n",
1753
+ "264\n",
1754
+ "265\n",
1755
+ "266\n",
1756
+ "267\n",
1757
+ "268\n",
1758
+ "269\n",
1759
+ "270\n",
1760
+ "271\n",
1761
+ "272\n",
1762
+ "273\n",
1763
+ "274\n",
1764
+ "275\n",
1765
+ "276\n",
1766
+ "277\n",
1767
+ "278\n",
1768
+ "279\n",
1769
+ "280\n",
1770
+ "281\n",
1771
+ "282\n",
1772
+ "283\n",
1773
+ "284\n",
1774
+ "285\n",
1775
+ "286\n",
1776
+ "287\n",
1777
+ "288\n",
1778
+ "289\n",
1779
+ "290\n",
1780
+ "291\n",
1781
+ "292\n",
1782
+ "293\n",
1783
+ "294\n",
1784
+ "295\n",
1785
+ "296\n",
1786
+ "297\n",
1787
+ "298\n",
1788
+ "299\n",
1789
+ "300\n",
1790
+ "301\n",
1791
+ "302\n",
1792
+ "303\n",
1793
+ "304\n",
1794
+ "305\n",
1795
+ "306\n",
1796
+ "307\n",
1797
+ "308\n",
1798
+ "309\n",
1799
+ "310\n",
1800
+ "311\n",
1801
+ "312\n",
1802
+ "313\n",
1803
+ "314\n",
1804
+ "315\n",
1805
+ "316\n",
1806
+ "317\n",
1807
+ "318\n",
1808
+ "319\n",
1809
+ "320\n",
1810
+ "321\n",
1811
+ "322\n",
1812
+ "323\n",
1813
+ "324\n",
1814
+ "325\n",
1815
+ "326\n",
1816
+ "327\n",
1817
+ "328\n",
1818
+ "329\n",
1819
+ "330\n",
1820
+ "331\n",
1821
+ "332\n",
1822
+ "333\n",
1823
+ "334\n",
1824
+ "335\n",
1825
+ "336\n",
1826
+ "337\n",
1827
+ "338\n",
1828
+ "339\n",
1829
+ "340\n",
1830
+ "341\n",
1831
+ "342\n",
1832
+ "343\n",
1833
+ "344\n",
1834
+ "345\n",
1835
+ "346\n",
1836
+ "347\n",
1837
+ "348\n",
1838
+ "349\n",
1839
+ "350\n",
1840
+ "351\n",
1841
+ "352\n",
1842
+ "353\n",
1843
+ "354\n",
1844
+ "355\n",
1845
+ "356\n",
1846
+ "357\n",
1847
+ "358\n",
1848
+ "359\n",
1849
+ "360\n",
1850
+ "361\n",
1851
+ "362\n",
1852
+ "363\n",
1853
+ "364\n",
1854
+ "365\n",
1855
+ "366\n",
1856
+ "367\n",
1857
+ "368\n",
1858
+ "369\n",
1859
+ "370\n",
1860
+ "371\n",
1861
+ "372\n",
1862
+ "373\n",
1863
+ "374\n",
1864
+ "375\n",
1865
+ "376\n",
1866
+ "377\n",
1867
+ "378\n",
1868
+ "379\n",
1869
+ "380\n",
1870
+ "381\n",
1871
+ "382\n",
1872
+ "383\n",
1873
+ "384\n",
1874
+ "385\n",
1875
+ "386\n",
1876
+ "387\n",
1877
+ "388\n",
1878
+ "389\n",
1879
+ "390\n",
1880
+ "391\n",
1881
+ "392\n",
1882
+ "393\n",
1883
+ "394\n",
1884
+ "395\n",
1885
+ "396\n",
1886
+ "397\n",
1887
+ "398\n",
1888
+ "399\n",
1889
+ "400\n",
1890
+ "401\n",
1891
+ "402\n",
1892
+ "403\n",
1893
+ "404\n",
1894
+ "405\n",
1895
+ "406\n",
1896
+ "407\n",
1897
+ "408\n",
1898
+ "409\n",
1899
+ "410\n",
1900
+ "411\n",
1901
+ "412\n",
1902
+ "413\n",
1903
+ "414\n",
1904
+ "415\n",
1905
+ "416\n",
1906
+ "417\n",
1907
+ "418\n",
1908
+ "419\n",
1909
+ "420\n",
1910
+ "421\n",
1911
+ "422\n",
1912
+ "423\n",
1913
+ "424\n",
1914
+ "425\n",
1915
+ "426\n",
1916
+ "427\n",
1917
+ "428\n",
1918
+ "429\n",
1919
+ "430\n",
1920
+ "431\n",
1921
+ "432\n",
1922
+ "433\n",
1923
+ "434\n",
1924
+ "435\n",
1925
+ "436\n",
1926
+ "437\n",
1927
+ "438\n",
1928
+ "439\n",
1929
+ "440\n",
1930
+ "441\n",
1931
+ "442\n",
1932
+ "443\n",
1933
+ "444\n",
1934
+ "445\n",
1935
+ "446\n",
1936
+ "447\n",
1937
+ "448\n",
1938
+ "449\n",
1939
+ "450\n",
1940
+ "451\n",
1941
+ "452\n",
1942
+ "453\n",
1943
+ "454\n",
1944
+ "455\n",
1945
+ "456\n",
1946
+ "457\n",
1947
+ "458\n",
1948
+ "459\n",
1949
+ "460\n",
1950
+ "461\n",
1951
+ "462\n",
1952
+ "463\n",
1953
+ "464\n",
1954
+ "465\n",
1955
+ "466\n",
1956
+ "467\n",
1957
+ "468\n",
1958
+ "469\n",
1959
+ "470\n",
1960
+ "471\n",
1961
+ "472\n",
1962
+ "473\n",
1963
+ "474\n",
1964
+ "475\n",
1965
+ "476\n",
1966
+ "477\n",
1967
+ "478\n",
1968
+ "479\n",
1969
+ "480\n",
1970
+ "481\n",
1971
+ "482\n",
1972
+ "483\n",
1973
+ "484\n",
1974
+ "485\n",
1975
+ "486\n",
1976
+ "487\n",
1977
+ "488\n",
1978
+ "489\n",
1979
+ "490\n",
1980
+ "491\n",
1981
+ "492\n",
1982
+ "493\n",
1983
+ "494\n",
1984
+ "495\n",
1985
+ "496\n",
1986
+ "497\n",
1987
+ "498\n",
1988
+ "499\n"
1989
+ ]
1990
+ }
1991
+ ],
1992
+ "source": [
1993
+ "repo_id = \"Stoub/Reinforce-Cartpole-v1\" # TODO Define your repo id {username/Reinforce-{model-id}}\n",
1994
+ "push_to_hub(repo_id,\n",
1995
+ " cartpole_policy, # The model we want to save\n",
1996
+ " cartpole_hyperparameters, # Hyperparameters\n",
1997
+ " eval_env, # Evaluation environment\n",
1998
+ " video_fps=30\n",
1999
+ " )"
2000
+ ]
2001
+ },
2002
+ {
2003
+ "cell_type": "markdown",
2004
+ "metadata": {
2005
+ "id": "jrnuKH1gYZSz"
2006
+ },
2007
+ "source": [
2008
+ "Now that we've tested the robustness of our implementation, let's try a more complex environment: PixelCopter 🚁\n",
2009
+ "\n",
2010
+ "\n"
2011
+ ]
2012
+ },
2013
+ {
2014
+ "cell_type": "markdown",
2015
+ "metadata": {
2016
+ "id": "JNLVmKKVKA6j"
2017
+ },
2018
+ "source": [
2019
+ "## Second agent: PixelCopter 🚁\n",
2020
+ "\n",
2021
+ "### Study the PixelCopter environment 👀\n",
2022
+ "- [The Environment documentation](https://pygame-learning-environment.readthedocs.io/en/latest/user/games/pixelcopter.html)\n"
2023
+ ]
2024
+ },
2025
+ {
2026
+ "cell_type": "code",
2027
+ "execution_count": null,
2028
+ "metadata": {
2029
+ "id": "JBSc8mlfyin3"
2030
+ },
2031
+ "outputs": [],
2032
+ "source": [
2033
+ "env_id = \"Pixelcopter-PLE-v0\"\n",
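+ "# Note: Pixelcopter comes from the PyGame Learning Environment / gym-games packages\n",
+ "# (see requirements), which need to be installed for this env id to be registered\n",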
2034
+ "env = gym.make(env_id)\n",
2035
+ "eval_env = gym.make(env_id)\n",
2036
+ "s_size = env.observation_space.shape[0]\n",
2037
+ "a_size = env.action_space.n"
2038
+ ]
2039
+ },
2040
+ {
2041
+ "cell_type": "code",
2042
+ "execution_count": null,
2043
+ "metadata": {
2044
+ "id": "L5u_zAHsKBy7"
2045
+ },
2046
+ "outputs": [],
2047
+ "source": [
2048
+ "print(\"_____OBSERVATION SPACE_____ \\n\")\n",
2049
+ "print(\"The State Space is: \", s_size)\n",
2050
+ "print(\"Sample observation\", env.observation_space.sample()) # Get a random observation"
2051
+ ]
2052
+ },
2053
+ {
2054
+ "cell_type": "code",
2055
+ "execution_count": null,
2056
+ "metadata": {
2057
+ "id": "D7yJM9YXKNbq"
2058
+ },
2059
+ "outputs": [],
2060
+ "source": [
2061
+ "print(\"\\n _____ACTION SPACE_____ \\n\")\n",
2062
+ "print(\"The Action Space is: \", a_size)\n",
2063
+ "print(\"Action Space Sample\", env.action_space.sample()) # Take a random action"
2064
+ ]
2065
+ },
2066
+ {
2067
+ "cell_type": "markdown",
2068
+ "metadata": {
2069
+ "id": "NNWvlyvzalXr"
2070
+ },
2071
+ "source": [
2072
+ "The observation space (7) 👀:\n",
2073
+ "- player y position\n",
2074
+ "- player velocity\n",
2075
+ "- player distance to floor\n",
2076
+ "- player distance to ceiling\n",
2077
+ "- next block x distance to player\n",
2078
+ "- next block's top y location\n",
2079
+ "- next block's bottom y location\n",
2080
+ "\n",
2081
+ "The action space (2) 🎮:\n",
2082
+ "- Up (press accelerator) \n",
2083
+ "- Do nothing (don't press accelerator) \n",
2084
+ "\n",
2085
+ "The reward function 💰: \n",
2086
+ "- For each vertical block it passes through, it gains a positive reward of +1. Each time a terminal state is reached, it receives a negative reward of -1."
2087
+ ]
2088
+ },
2089
+ {
2090
+ "cell_type": "markdown",
2091
+ "metadata": {
2092
+ "id": "aV1466QP8crz"
2093
+ },
2094
+ "source": [
2095
+ "### Define the new Policy 🧠\n",
2096
+ "- We need to have a deeper neural network since the environment is more complex"
2097
+ ]
2098
+ },
2099
+ {
2100
+ "cell_type": "code",
2101
+ "execution_count": null,
2102
+ "metadata": {
2103
+ "id": "I1eBkCiX2X_S"
2104
+ },
2105
+ "outputs": [],
2106
+ "source": [
2107
+ "class Policy(nn.Module):\n",
2108
+ " def __init__(self, s_size, a_size, h_size):\n",
2109
+ " super(Policy, self).__init__()\n",
2110
+ " # Define the three layers here\n",
2111
+ "\n",
2112
+ " def forward(self, x):\n",
2113
+ " # Define the forward process here\n",
2114
+ " return F.softmax(x, dim=1)\n",
2115
+ " \n",
2116
+ " def act(self, state):\n",
2117
+ " state = torch.from_numpy(state).float().unsqueeze(0).to(device)\n",
2118
+ " probs = self.forward(state).cpu()\n",
2119
+ " m = Categorical(probs)\n",
2120
+ " action = m.sample()\n",
2121
+ " return action.item(), m.log_prob(action)"
2122
+ ]
2123
+ },
2124
+ {
2125
+ "cell_type": "markdown",
2126
+ "metadata": {
2127
+ "id": "47iuAFqV8Ws-"
2128
+ },
2129
+ "source": [
2130
+ "#### Solution"
2131
+ ]
2132
+ },
2133
+ {
2134
+ "cell_type": "code",
2135
+ "execution_count": null,
2136
+ "metadata": {
2137
+ "id": "wrNuVcHC8Xu7"
2138
+ },
2139
+ "outputs": [],
2140
+ "source": [
2141
+ "class Policy(nn.Module):\n",
2142
+ " def __init__(self, s_size, a_size, h_size):\n",
2143
+ " super(Policy, self).__init__()\n",
2144
+ " self.fc1 = nn.Linear(s_size, h_size)\n",
2145
+ " self.fc2 = nn.Linear(h_size, h_size*2)\n",
2146
+ " self.fc3 = nn.Linear(h_size*2, a_size)\n",
2147
+ "\n",
2148
+ " def forward(self, x):\n",
2149
+ " x = F.relu(self.fc1(x))\n",
2150
+ " x = F.relu(self.fc2(x))\n",
2151
+ " x = self.fc3(x)\n",
2152
+ " return F.softmax(x, dim=1)\n",
2153
+ " \n",
2154
+ " def act(self, state):\n",
2155
+ " state = torch.from_numpy(state).float().unsqueeze(0).to(device)\n",
2156
+ " probs = self.forward(state).cpu()\n",
2157
+ " m = Categorical(probs)\n",
2158
+ " action = m.sample()\n",
2159
+ " return action.item(), m.log_prob(action)"
2160
+ ]
2161
+ },
2162
+ {
2163
+ "cell_type": "markdown",
2164
+ "metadata": {
2165
+ "id": "SM1QiGCSbBkM"
2166
+ },
2167
+ "source": [
2168
+ "### Define the hyperparameters ⚙️\n",
2169
+ "- Because this environment is more complex, we need a larger network.\n",
2170
+ "- In particular, we increase the hidden size to give the policy more neurons."
2171
+ ]
2172
+ },
2173
+ {
2174
+ "cell_type": "code",
2175
+ "execution_count": null,
2176
+ "metadata": {
2177
+ "id": "y0uujOR_ypB6"
2178
+ },
2179
+ "outputs": [],
2180
+ "source": [
2181
+ "pixelcopter_hyperparameters = {\n",
2182
+ " \"h_size\": 64,\n",
2183
+ " \"n_training_episodes\": 50000,\n",
2184
+ " \"n_evaluation_episodes\": 10,\n",
2185
+ " \"max_t\": 10000,\n",
2186
+ " \"gamma\": 0.99,\n",
2187
+ " \"lr\": 1e-4,\n",
2188
+ " \"env_id\": env_id,\n",
2189
+ " \"state_space\": s_size,\n",
2190
+ " \"action_space\": a_size,\n",
2191
+ "}"
2192
+ ]
2193
+ },
2194
+ {
2195
+ "cell_type": "markdown",
2196
+ "metadata": {
2197
+ "id": "wyvXTJWm9GJG"
2198
+ },
2199
+ "source": [
2200
+ "### Train it\n",
2201
+ "- We're now ready to train our agent 🔥."
2202
+ ]
2203
+ },
2204
+ {
2205
+ "cell_type": "code",
2206
+ "execution_count": null,
2207
+ "metadata": {
2208
+ "id": "7mM2P_ckysFE"
2209
+ },
2210
+ "outputs": [],
2211
+ "source": [
2212
+ "# Create policy and place it to the device\n",
2213
+ "# torch.manual_seed(50)\n",
2214
+ "pixelcopter_policy = Policy(pixelcopter_hyperparameters[\"state_space\"], pixelcopter_hyperparameters[\"action_space\"], pixelcopter_hyperparameters[\"h_size\"]).to(device)\n",
2215
+ "pixelcopter_optimizer = optim.Adam(pixelcopter_policy.parameters(), lr=pixelcopter_hyperparameters[\"lr\"])"
2216
+ ]
2217
+ },
2218
+ {
2219
+ "cell_type": "code",
2220
+ "execution_count": null,
2221
+ "metadata": {
2222
+ "id": "v1HEqP-fy-Rf"
2223
+ },
2224
+ "outputs": [],
2225
+ "source": [
2226
+ "scores = reinforce(pixelcopter_policy,\n",
2227
+ " pixelcopter_optimizer,\n",
2228
+ " pixelcopter_hyperparameters[\"n_training_episodes\"], \n",
2229
+ " pixelcopter_hyperparameters[\"max_t\"],\n",
2230
+ " pixelcopter_hyperparameters[\"gamma\"], \n",
2231
+ " 1000)"
2232
+ ]
2233
+ },
2234
+ {
2235
+ "cell_type": "markdown",
2236
+ "metadata": {
2237
+ "id": "8kwFQ-Ip85BE"
2238
+ },
2239
+ "source": [
2240
+ "### Publish our trained model on the Hub 🔥"
2241
+ ]
2242
+ },
2243
+ {
2244
+ "cell_type": "code",
2245
+ "execution_count": null,
2246
+ "metadata": {
2247
+ "id": "6PtB7LRbTKWK"
2248
+ },
2249
+ "outputs": [],
2250
+ "source": [
2251
+ "repo_id = \"\" #TODO Define your repo id {username/Reinforce-{model-id}}\n",
2252
+ "push_to_hub(repo_id,\n",
2253
+ " pixelcopter_policy, # The model we want to save\n",
2254
+ " pixelcopter_hyperparameters, # Hyperparameters\n",
2255
+ " eval_env, # Evaluation environment\n",
2256
+ " video_fps=30\n",
2257
+ " )"
2258
+ ]
2259
+ },
2260
+ {
2261
+ "cell_type": "markdown",
2262
+ "metadata": {
2263
+ "id": "7VDcJ29FcOyb"
2264
+ },
2265
+ "source": [
2266
+ "## Some additional challenges 🏆\n",
2267
+ "The best way to learn **is to try things on your own**! As you saw, the current agent is not doing great. As a first suggestion, you can train for more steps, but you can also try to find better hyperparameters.\n",
2268
+ "\n",
2269
+ "In the [Leaderboard](https://huggingface.co/spaces/huggingface-projects/Deep-Reinforcement-Learning-Leaderboard) you will find your agents. Can you get to the top?\n",
2270
+ "\n",
2271
+ "Here are some ideas to achieve that:\n",
2272
+ "* Train more steps\n",
2273
+ "* Try different hyperparameters by looking at what your classmates have done 👉 https://huggingface.co/models?other=reinforce\n",
2274
+ "* **Push your new trained model** on the Hub 🔥\n",
2275
+ "* **Improve the implementation for more complex environments** (for instance, what about changing the network to a Convolutional Neural Network to handle\n",
2276
+ "frames as observations?)"
2277
+ ]
2278
+ },
2279
+ {
2280
+ "cell_type": "markdown",
2281
+ "metadata": {
2282
+ "id": "x62pP0PHdA-y"
2283
+ },
2284
+ "source": [
2285
+ "________________________________________________________________________\n",
2286
+ "\n",
2287
+ "**Congrats on finishing this unit**! There was a lot of information.\n",
2288
+ "And congrats on finishing the tutorial. You've just coded your first Deep Reinforcement Learning agent from scratch using PyTorch and shared it on the Hub 🥳.\n",
2289
+ "\n",
2290
+ "Don't hesitate to iterate on this unit **by improving the implementation for more complex environments** (for instance, what about changing the network to a Convolutional Neural Network to handle\n",
2291
+ "frames as observations?)\n",
2292
+ "\n",
2293
+ "In the next unit, **we're going to learn more about Unity MLAgents**, by training agents in Unity environments. This way, you will be ready to participate in the **AI vs AI challenges where you'll train your agents\n",
2294
+ "to compete against other agents in a snowball fight and a soccer game.**\n",
2295
+ "\n",
2296
+ "Sounds fun? See you next time!\n",
2297
+ "\n",
2298
+ "Finally, we would love **to hear what you think of the course and how we can improve it**. If you have some feedback, please 👉 [fill this form](https://forms.gle/BzKXWzLAGZESGNaE9)\n",
2299
+ "\n",
2300
+ "See you in Unit 5! 🔥\n",
2301
+ "\n",
2302
+ "### Keep Learning, stay awesome 🤗\n",
2303
+ "\n"
2304
+ ]
2305
+ }
2306
+ ],
2307
+ "metadata": {
2308
+ "accelerator": "GPU",
2309
+ "colab": {
2310
+ "collapsed_sections": [
2311
+ "BPLwsPajb1f8",
2312
+ "L_WSo0VUV99t",
2313
+ "mjY-eq3eWh9O",
2314
+ "JoTC9o2SczNn",
2315
+ "gfGJNZBUP7Vn",
2316
+ "YB0Cxrw1StrP",
2317
+ "47iuAFqV8Ws-",
2318
+ "x62pP0PHdA-y"
2319
+ ],
2320
+ "include_colab_link": true,
2321
+ "private_outputs": true,
2322
+ "provenance": []
2323
+ },
2324
+ "gpuClass": "standard",
2325
+ "kernelspec": {
2326
+ "display_name": "Python 3 (ipykernel)",
2327
+ "language": "python",
2328
+ "name": "python3"
2329
+ },
2330
+ "language_info": {
2331
+ "codemirror_mode": {
2332
+ "name": "ipython",
2333
+ "version": 3
2334
+ },
2335
+ "file_extension": ".py",
2336
+ "mimetype": "text/x-python",
2337
+ "name": "python",
2338
+ "nbconvert_exporter": "python",
2339
+ "pygments_lexer": "ipython3",
2340
+ "version": "3.12.5"
2341
+ }
2342
+ },
2343
+ "nbformat": 4,
2344
+ "nbformat_minor": 4
2345
+ }