sgoodfriend commited on
Commit
68e589c
·
1 Parent(s): 250e2c3

PPO playing MountainCar-v0 from https://github.com/sgoodfriend/rl-algo-impls/tree/5598ebc4b03054f16eebe76792486ba7bcacfc5c

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
.gitignore ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Environments
105
+ .env
106
+ .venv
107
+ env/
108
+ venv/
109
+ ENV/
110
+ env.bak/
111
+ venv.bak/
112
+
113
+ # Spyder project settings
114
+ .spyderproject
115
+ .spyproject
116
+
117
+ # Rope project settings
118
+ .ropeproject
119
+
120
+ # mkdocs documentation
121
+ /site
122
+
123
+ # mypy
124
+ .mypy_cache/
125
+ .dmypy.json
126
+ dmypy.json
127
+
128
+ # Pyre type checker
129
+ .pyre/
130
+
131
+ # Logging into tensorboard and wandb
132
+ runs/*
133
+ wandb
134
+
135
+ # macOS
136
+ .DS_STORE
137
+
138
+ # Local scratch work
139
+ scratch/*
140
+
141
+ # vscode
142
+ .vscode/
143
+
144
+ # Don't bother tracking saved_models or videos
145
+ saved_models/*
146
+ downloaded_models/*
147
+ videos/*
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Scott Goodfriend
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: rl-algo-impls
3
+ tags:
4
+ - MountainCar-v0
5
+ - ppo
6
+ - deep-reinforcement-learning
7
+ - reinforcement-learning
8
+ model-index:
9
+ - name: ppo
10
+ results:
11
+ - metrics:
12
+ - type: mean_reward
13
+ value: -112.94 +/- 1.52
14
+ name: mean_reward
15
+ task:
16
+ type: reinforcement-learning
17
+ name: reinforcement-learning
18
+ dataset:
19
+ name: MountainCar-v0
20
+ type: MountainCar-v0
21
+ ---
22
+ # **PPO** Agent playing **MountainCar-v0**
23
+
24
+ This is a trained model of a **PPO** agent playing **MountainCar-v0** using the [/sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) repo.
25
+
26
+ All models trained at this commit can be found at https://api.wandb.ai/links/sgoodfriend/6p2sjqtn.
27
+
28
+ ## Training Results
29
+
30
+ This model was trained from 3 trainings of **PPO** agents using different initial seeds. These agents were trained by checking out [5598ebc](https://github.com/sgoodfriend/rl-algo-impls/tree/5598ebc4b03054f16eebe76792486ba7bcacfc5c). The best and last models were kept from each training. This submission has loaded the best models from each training, reevaluates them, and selects the best model from these latest evaluations (mean - std).
31
+
32
+ | algo | env | seed | reward_mean | reward_std | eval_episodes | best | wandb_url |
33
+ |:-------|:---------------|-------:|--------------:|-------------:|----------------:|:-------|:-----------------------------------------------------------------------------|
34
+ | ppo | MountainCar-v0 | 4 | -113.75 | 1.47902 | 16 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/8yurshm7) |
35
+ | ppo | MountainCar-v0 | 5 | -111.688 | 10.4146 | 16 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/0s0kikzh) |
36
+ | ppo | MountainCar-v0 | 6 | -112.938 | 1.51941 | 16 | * | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/jvwz1vhg) |
37
+
38
+
39
+ ### Prerequisites: Weights & Biases (WandB)
40
+ Training and benchmarking assumes you have a Weights & Biases project to upload runs to.
41
+ By default training goes to a rl-algo-impls project while benchmarks go to
42
+ rl-algo-impls-benchmarks. During training and benchmarking runs, videos of the best
43
+ models and the model weights are uploaded to WandB.
44
+
45
+ Before doing anything below, you'll need to create a wandb account and run `wandb
46
+ login`.
47
+
48
+
49
+
50
+ ## Usage
51
+ /sgoodfriend/rl-algo-impls: https://github.com/sgoodfriend/rl-algo-impls
52
+
53
+ Note: While the model state dictionary and hyperaparameters are saved, the latest
54
+ implementation could be sufficiently different to not be able to reproduce similar
55
+ results. You might need to checkout the commit the agent was trained on:
56
+ [5598ebc](https://github.com/sgoodfriend/rl-algo-impls/tree/5598ebc4b03054f16eebe76792486ba7bcacfc5c).
57
+ ```
58
+ # Downloads the model, sets hyperparameters, and runs agent for 3 episodes
59
+ python enjoy.py --wandb-run-path=sgoodfriend/rl-algo-impls-benchmarks/jvwz1vhg
60
+ ```
61
+
62
+ Setup hasn't been completely worked out yet, so you might be best served by using Google
63
+ Colab starting from the
64
+ [colab_enjoy.ipynb](https://github.com/sgoodfriend/rl-algo-impls/blob/main/colab_enjoy.ipynb)
65
+ notebook.
66
+
67
+
68
+
69
+ ## Training
70
+ If you want the highest chance to reproduce these results, you'll want to checkout the
71
+ commit the agent was trained on: [5598ebc](https://github.com/sgoodfriend/rl-algo-impls/tree/5598ebc4b03054f16eebe76792486ba7bcacfc5c). While
72
+ training is deterministic, different hardware will give different results.
73
+
74
+ ```
75
+ python train.py --algo ppo --env MountainCar-v0 --seed 6
76
+ ```
77
+
78
+ Setup hasn't been completely worked out yet, so you might be best served by using Google
79
+ Colab starting from the
80
+ [colab_train.ipynb](https://github.com/sgoodfriend/rl-algo-impls/blob/main/colab_train.ipynb)
81
+ notebook.
82
+
83
+
84
+
85
+ ## Benchmarking (with Lambda Labs instance)
86
+ This and other models from https://api.wandb.ai/links/sgoodfriend/6p2sjqtn were generated by running a script on a Lambda
87
+ Labs instance. In a Lambda Labs instance terminal:
88
+ ```
89
+ git clone [email protected]:sgoodfriend/rl-algo-impls.git
90
+ cd rl-algo-impls
91
+ bash ./lambda_labs/setup.sh
92
+ wandb login
93
+ bash ./lambda_labs/benchmark.sh
94
+ ```
95
+
96
+ ### Alternative: Google Colab Pro+
97
+ As an alternative,
98
+ [colab_benchmark.ipynb](https://github.com/sgoodfriend/rl-algo-impls/tree/main/benchmarks#:~:text=colab_benchmark.ipynb),
99
+ can be used. However, this requires a Google Colab Pro+ subscription and running across
100
+ 4 separate instances because otherwise running all jobs will exceed the 24-hour limit.
101
+
102
+
103
+
104
+ ## Hyperparameters
105
+ This isn't exactly the format of hyperparams in hyperparams/ppo.yml, but instead the Wandb Run Config. However, it's very
106
+ close and has some additional data:
107
+ ```
108
+ algo: ppo
109
+ algo_hyperparams:
110
+ ent_coef: 0
111
+ gae_lambda: 0.98
112
+ gamma: 0.99
113
+ n_epochs: 4
114
+ n_steps: 16
115
+ env: MountainCar-v0
116
+ env_hyperparams:
117
+ n_envs: 16
118
+ normalize: true
119
+ n_timesteps: 1000000
120
+ seed: 6
121
+ use_deterministic_algorithms: true
122
+ wandb_entity: null
123
+ wandb_project_name: rl-algo-impls-benchmarks
124
+ wandb_tags:
125
+ - benchmark_5598ebc
126
+ - host_192-9-145-26
127
+
128
+ ```
benchmark_publish.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import subprocess
3
+ import wandb
4
+ import wandb.apis.public
5
+
6
+ from collections import defaultdict
7
+ from multiprocessing.pool import ThreadPool
8
+ from typing import List, NamedTuple
9
+
10
+
11
+ class RunGroup(NamedTuple):
12
+ algo: str
13
+ env_id: str
14
+
15
+
16
+ if __name__ == "__main__":
17
+ parser = argparse.ArgumentParser()
18
+ parser.add_argument(
19
+ "--wandb-project-name",
20
+ type=str,
21
+ default="rl-algo-impls-benchmarks",
22
+ help="WandB project name to load runs from",
23
+ )
24
+ parser.add_argument(
25
+ "--wandb-entity",
26
+ type=str,
27
+ default=None,
28
+ help="WandB team of project. None uses default entity",
29
+ )
30
+ parser.add_argument("--wandb-tags", type=str, nargs="+", help="WandB tags")
31
+ parser.add_argument("--wandb-report-url", type=str, help="Link to WandB report")
32
+ parser.add_argument(
33
+ "--envs", type=str, nargs="*", help="Optional filter down to these envs"
34
+ )
35
+ parser.add_argument(
36
+ "--huggingface-user",
37
+ type=str,
38
+ default=None,
39
+ help="Huggingface user or team to upload model cards. Defaults to huggingface-cli login user",
40
+ )
41
+ parser.add_argument(
42
+ "--pool-size",
43
+ type=int,
44
+ default=3,
45
+ help="How many publish jobs can run in parallel",
46
+ )
47
+ parser.set_defaults(
48
+ wandb_tags=["benchmark_5598ebc", "host_192-9-145-26"],
49
+ wandb_report_url="https://api.wandb.ai/links/sgoodfriend/6p2sjqtn",
50
+ envs=["CartPole-v1", "Acrobot-v1"],
51
+ )
52
+ args = parser.parse_args()
53
+ print(args)
54
+
55
+ api = wandb.Api()
56
+ all_runs = api.runs(
57
+ f"{args.wandb_entity or api.default_entity}/{args.wandb_project_name}"
58
+ )
59
+
60
+ required_tags = set(args.wandb_tags)
61
+ runs: List[wandb.apis.public.Run] = [
62
+ r
63
+ for r in all_runs
64
+ if required_tags.issubset(set(r.config.get("wandb_tags", [])))
65
+ ]
66
+
67
+ runs_paths_by_group = defaultdict(list)
68
+ for r in runs:
69
+ algo = r.config["algo"]
70
+ env = r.config["env"]
71
+ if args.envs and env not in args.envs:
72
+ continue
73
+ run_group = RunGroup(algo, env)
74
+ runs_paths_by_group[run_group].append("/".join(r.path))
75
+
76
+ def run(run_paths: List[str]) -> None:
77
+ publish_args = ["python", "huggingface_publish.py"]
78
+ publish_args.append("--wandb-run-paths")
79
+ publish_args.extend(run_paths)
80
+ publish_args.append("--wandb-report-url")
81
+ publish_args.append(args.wandb_report_url)
82
+ if args.huggingface_user:
83
+ publish_args.append("--huggingface-user")
84
+ publish_args.append(args.huggingface_user)
85
+ subprocess.run(publish_args)
86
+
87
+ tp = ThreadPool(args.pool_size)
88
+ for run_paths in runs_paths_by_group.values():
89
+ tp.apply_async(run, (run_paths,))
90
+ tp.close()
91
+ tp.join()
benchmarks/colab_atari1.sh ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ source benchmarks/train_loop.sh
2
+ ALGOS="ppo"
3
+ ENVS="PongNoFrameskip-v4 BreakoutNoFrameskip-v4"
4
+ BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
5
+ train_loop $ALGOS "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
benchmarks/colab_atari2.sh ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ source benchmarks/train_loop.sh
2
+ ALGOS="ppo"
3
+ ENVS="SpaceInvadersNoFrameskip-v4 QbertNoFrameskip-v4"
4
+ BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
5
+ train_loop $ALGOS "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
benchmarks/colab_basic.sh ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ source benchmarks/train_loop.sh
2
+ ALGOS="ppo"
3
+ ENVS="CartPole-v1 MountainCar-v0 MountainCarContinuous-v0 Acrobot-v1 LunarLander-v2"
4
+ BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
5
+ train_loop $ALGOS "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
benchmarks/colab_benchmark.ipynb ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "machine_shape": "hm",
8
+ "authorship_tag": "ABX9TyOGIH7rqgasim3Sz7b1rpoE",
9
+ "include_colab_link": true
10
+ },
11
+ "kernelspec": {
12
+ "name": "python3",
13
+ "display_name": "Python 3"
14
+ },
15
+ "language_info": {
16
+ "name": "python"
17
+ },
18
+ "gpuClass": "standard",
19
+ "accelerator": "GPU"
20
+ },
21
+ "cells": [
22
+ {
23
+ "cell_type": "markdown",
24
+ "metadata": {
25
+ "id": "view-in-github",
26
+ "colab_type": "text"
27
+ },
28
+ "source": [
29
+ "<a href=\"https://colab.research.google.com/github/sgoodfriend/rl-algo-impls/blob/main/benchmarks/colab_benchmark.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "markdown",
34
+ "source": [
35
+ "# [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) in Google Colaboratory\n",
36
+ "## Parameters\n",
37
+ "\n",
38
+ "\n",
39
+ "1. Wandb\n",
40
+ "\n"
41
+ ],
42
+ "metadata": {
43
+ "id": "S-tXDWP8WTLc"
44
+ }
45
+ },
46
+ {
47
+ "cell_type": "code",
48
+ "source": [
49
+ "from getpass import getpass\n",
50
+ "import os\n",
51
+ "os.environ[\"WANDB_API_KEY\"] = getpass(\"Wandb API key to upload metrics, videos, and models: \")"
52
+ ],
53
+ "metadata": {
54
+ "id": "1ZtdYgxWNGwZ"
55
+ },
56
+ "execution_count": null,
57
+ "outputs": []
58
+ },
59
+ {
60
+ "cell_type": "markdown",
61
+ "source": [
62
+ "## Setup\n",
63
+ "Clone [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) "
64
+ ],
65
+ "metadata": {
66
+ "id": "bsG35Io0hmKG"
67
+ }
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "source": [
72
+ "%%capture\n",
73
+ "!git clone https://github.com/sgoodfriend/rl-algo-impls.git"
74
+ ],
75
+ "metadata": {
76
+ "id": "k5ynTV25hdAf"
77
+ },
78
+ "execution_count": null,
79
+ "outputs": []
80
+ },
81
+ {
82
+ "cell_type": "markdown",
83
+ "source": [
84
+ "Installing the correct packages:\n",
85
+ "\n",
86
+ "While conda and poetry are generally used for package management, the mismatch in Python versions (3.10 in the project file vs 3.8 in Colab) makes using the package yml files difficult to use. For now, instead I'm going to specify the list of requirements manually below:"
87
+ ],
88
+ "metadata": {
89
+ "id": "jKxGok-ElYQ7"
90
+ }
91
+ },
92
+ {
93
+ "cell_type": "code",
94
+ "source": [
95
+ "%%capture\n",
96
+ "!apt install python-opengl\n",
97
+ "!apt install ffmpeg\n",
98
+ "!apt install xvfb\n",
99
+ "!apt install swig"
100
+ ],
101
+ "metadata": {
102
+ "id": "nn6EETTc2Ewf"
103
+ },
104
+ "execution_count": null,
105
+ "outputs": []
106
+ },
107
+ {
108
+ "cell_type": "code",
109
+ "source": [
110
+ "%%capture\n",
111
+ "%cd /content/rl-algo-impls\n",
112
+ "!pip install -r colab_requirements.txt"
113
+ ],
114
+ "metadata": {
115
+ "id": "AfZh9rH3yQii"
116
+ },
117
+ "execution_count": null,
118
+ "outputs": []
119
+ },
120
+ {
121
+ "cell_type": "markdown",
122
+ "source": [
123
+ "## Run Once Per Runtime"
124
+ ],
125
+ "metadata": {
126
+ "id": "4o5HOLjc4wq7"
127
+ }
128
+ },
129
+ {
130
+ "cell_type": "code",
131
+ "source": [
132
+ "import wandb\n",
133
+ "wandb.login()"
134
+ ],
135
+ "metadata": {
136
+ "id": "PCXa5tdS2qFX"
137
+ },
138
+ "execution_count": null,
139
+ "outputs": []
140
+ },
141
+ {
142
+ "cell_type": "markdown",
143
+ "source": [
144
+ "## Restart Session beteween runs"
145
+ ],
146
+ "metadata": {
147
+ "id": "AZBZfSUV43JQ"
148
+ }
149
+ },
150
+ {
151
+ "cell_type": "code",
152
+ "source": [
153
+ "%%capture\n",
154
+ "from pyvirtualdisplay import Display\n",
155
+ "\n",
156
+ "virtual_display = Display(visible=0, size=(1400, 900))\n",
157
+ "virtual_display.start()"
158
+ ],
159
+ "metadata": {
160
+ "id": "VzemeQJP2NO9"
161
+ },
162
+ "execution_count": null,
163
+ "outputs": []
164
+ },
165
+ {
166
+ "cell_type": "markdown",
167
+ "source": [
168
+ "The below 5 bash scripts train agents on environments with 3 seeds each:\n",
169
+ "- colab_basic.sh and colab_pybullet.sh test on a set of basic gym environments and 4 PyBullet environments. Running both together will likely take about 18 hours. This is likely to run into runtime limits for free Colab and Colab Pro, but is fine for Colab Pro+.\n",
170
+ "- colab_carracing.sh only trains 3 seeds on CarRacing-v0, which takes almost 22 hours on Colab Pro+ on high-RAM, standard GPU.\n",
171
+ "- colab_atari1.sh and colab_atari2.sh likely need to be run separately because each takes about 19 hours on high-RAM, standard GPU."
172
+ ],
173
+ "metadata": {
174
+ "id": "nSHfna0hLlO1"
175
+ }
176
+ },
177
+ {
178
+ "cell_type": "code",
179
+ "source": [
180
+ "%cd /content/rl-algo-impls\n",
181
+ "os.environ[\"BENCHMARK_MAX_PROCS\"] = str(1) # Can't reliably raise this to 2+, but would make it faster.\n",
182
+ "!./benchmarks/colab_basic.sh\n",
183
+ "!./benchmarks/colab_pybullet.sh\n",
184
+ "# !./benchmarks/colab_carracing.sh\n",
185
+ "# !./benchmarks/colab_atari1.sh\n",
186
+ "# !./benchmarks/colab_atari2.sh"
187
+ ],
188
+ "metadata": {
189
+ "id": "07aHYFH1zfXa"
190
+ },
191
+ "execution_count": null,
192
+ "outputs": []
193
+ }
194
+ ]
195
+ }
benchmarks/colab_carracing.sh ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ source benchmarks/train_loop.sh
2
+ ALGOS="ppo"
3
+ ENVS="CarRacing-v0"
4
+ BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
5
+ train_loop $ALGOS "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
benchmarks/colab_pybullet.sh ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ source benchmarks/train_loop.sh
2
+ ALGOS="ppo"
3
+ ENVS="HalfCheetahBulletEnv-v0 AntBulletEnv-v0 Walker2DBulletEnv-v0 HopperBulletEnv-v0"
4
+ BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
5
+ train_loop $ALGOS "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
benchmarks/train_loop.sh ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ train_loop () {
2
+ local WANDB_TAGS="benchmark_$(git rev-parse --short HEAD) host_$(hostname)"
3
+ local algo
4
+ local env
5
+ local seed
6
+ local WANDB_PROJECT_NAME="${WANDB_PROJECT_NAME:-rl-algo-impls-benchmarks}"
7
+ local args=()
8
+ (( VIRTUAL_DISPLAY == 1)) && args+=("--virtual-display")
9
+ local SEEDS="${SEEDS:-1 2 3}"
10
+ for algo in $(echo $1); do
11
+ for env in $(echo $2); do
12
+ for seed in $SEEDS; do
13
+ echo python train.py --algo $algo --env $env --seed $seed --pool-size 1 --wandb-tags $WANDB_TAGS --wandb-project-name $WANDB_PROJECT_NAME ${args[@]}
14
+ done
15
+ done
16
+ done
17
+ }
colab_enjoy.ipynb ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "machine_shape": "hm",
8
+ "authorship_tag": "ABX9TyN6S7kyJKrM5x0OOiN+CgTc",
9
+ "include_colab_link": true
10
+ },
11
+ "kernelspec": {
12
+ "name": "python3",
13
+ "display_name": "Python 3"
14
+ },
15
+ "language_info": {
16
+ "name": "python"
17
+ },
18
+ "gpuClass": "standard",
19
+ "accelerator": "GPU"
20
+ },
21
+ "cells": [
22
+ {
23
+ "cell_type": "markdown",
24
+ "metadata": {
25
+ "id": "view-in-github",
26
+ "colab_type": "text"
27
+ },
28
+ "source": [
29
+ "<a href=\"https://colab.research.google.com/github/sgoodfriend/rl-algo-impls/blob/main/colab_enjoy.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "markdown",
34
+ "source": [
35
+ "# [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) in Google Colaboratory\n",
36
+ "## Parameters\n",
37
+ "\n",
38
+ "\n",
39
+ "1. Wandb\n",
40
+ "\n"
41
+ ],
42
+ "metadata": {
43
+ "id": "S-tXDWP8WTLc"
44
+ }
45
+ },
46
+ {
47
+ "cell_type": "code",
48
+ "source": [
49
+ "from getpass import getpass\n",
50
+ "import os\n",
51
+ "os.environ[\"WANDB_API_KEY\"] = getpass(\"Wandb API key to upload metrics, videos, and models: \")"
52
+ ],
53
+ "metadata": {
54
+ "id": "1ZtdYgxWNGwZ"
55
+ },
56
+ "execution_count": null,
57
+ "outputs": []
58
+ },
59
+ {
60
+ "cell_type": "markdown",
61
+ "source": [
62
+ "2. enjoy.py parameters"
63
+ ],
64
+ "metadata": {
65
+ "id": "ao0nAh3MOdN7"
66
+ }
67
+ },
68
+ {
69
+ "cell_type": "code",
70
+ "source": [
71
+ "WANDB_RUN_PATH=\"sgoodfriend/rl-algo-impls-benchmarks/rd0lisee\""
72
+ ],
73
+ "metadata": {
74
+ "id": "jKL_NFhVOjSc"
75
+ },
76
+ "execution_count": 2,
77
+ "outputs": []
78
+ },
79
+ {
80
+ "cell_type": "markdown",
81
+ "source": [
82
+ "## Setup\n",
83
+ "Clone [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) "
84
+ ],
85
+ "metadata": {
86
+ "id": "bsG35Io0hmKG"
87
+ }
88
+ },
89
+ {
90
+ "cell_type": "code",
91
+ "source": [
92
+ "%%capture\n",
93
+ "!git clone https://github.com/sgoodfriend/rl-algo-impls.git"
94
+ ],
95
+ "metadata": {
96
+ "id": "k5ynTV25hdAf"
97
+ },
98
+ "execution_count": 3,
99
+ "outputs": []
100
+ },
101
+ {
102
+ "cell_type": "markdown",
103
+ "source": [
104
+ "Installing the correct packages:\n",
105
+ "\n",
106
+ "While conda and poetry are generally used for package management, the mismatch in Python versions (3.10 in the project file vs 3.8 in Colab) makes using the package yml files difficult to use. For now, instead I'm going to specify the list of requirements manually below:"
107
+ ],
108
+ "metadata": {
109
+ "id": "jKxGok-ElYQ7"
110
+ }
111
+ },
112
+ {
113
+ "cell_type": "code",
114
+ "source": [
115
+ "%%capture\n",
116
+ "!apt install python-opengl\n",
117
+ "!apt install ffmpeg\n",
118
+ "!apt install xvfb\n",
119
+ "!apt install swig"
120
+ ],
121
+ "metadata": {
122
+ "id": "nn6EETTc2Ewf"
123
+ },
124
+ "execution_count": 4,
125
+ "outputs": []
126
+ },
127
+ {
128
+ "cell_type": "code",
129
+ "source": [
130
+ "%%capture\n",
131
+ "%cd /content/rl-algo-impls\n",
132
+ "!pip install -r colab_requirements.txt"
133
+ ],
134
+ "metadata": {
135
+ "id": "AfZh9rH3yQii"
136
+ },
137
+ "execution_count": 5,
138
+ "outputs": []
139
+ },
140
+ {
141
+ "cell_type": "markdown",
142
+ "source": [
143
+ "## Run Once Per Runtime"
144
+ ],
145
+ "metadata": {
146
+ "id": "4o5HOLjc4wq7"
147
+ }
148
+ },
149
+ {
150
+ "cell_type": "code",
151
+ "source": [
152
+ "import wandb\n",
153
+ "wandb.login()"
154
+ ],
155
+ "metadata": {
156
+ "id": "PCXa5tdS2qFX"
157
+ },
158
+ "execution_count": null,
159
+ "outputs": []
160
+ },
161
+ {
162
+ "cell_type": "markdown",
163
+ "source": [
164
+ "## Restart Session beteween runs"
165
+ ],
166
+ "metadata": {
167
+ "id": "AZBZfSUV43JQ"
168
+ }
169
+ },
170
+ {
171
+ "cell_type": "code",
172
+ "source": [
173
+ "%%capture\n",
174
+ "from pyvirtualdisplay import Display\n",
175
+ "\n",
176
+ "virtual_display = Display(visible=0, size=(1400, 900))\n",
177
+ "virtual_display.start()"
178
+ ],
179
+ "metadata": {
180
+ "id": "VzemeQJP2NO9"
181
+ },
182
+ "execution_count": 7,
183
+ "outputs": []
184
+ },
185
+ {
186
+ "cell_type": "code",
187
+ "source": [
188
+ "%cd /content/rl-algo-impls\n",
189
+ "!python enjoy.py --wandb-run-path={WANDB_RUN_PATH}"
190
+ ],
191
+ "metadata": {
192
+ "id": "07aHYFH1zfXa"
193
+ },
194
+ "execution_count": null,
195
+ "outputs": []
196
+ }
197
+ ]
198
+ }
colab_requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ AutoROM.accept-rom-license >= 0.4.2, < 0.5
2
+ stable-baselines3[extra] >= 1.7.0, < 1.8
3
+ gym[box2d] >= 0.21.0, < 0.22
4
+ pyglet == 1.5.27
5
+ wandb >= 0.13.9, < 0.14
6
+ pyvirtualdisplay == 3.0
7
+ pybullet >= 3.2.5, < 3.3
8
+ tabulate >= 0.9.0, < 0.10
9
+ huggingface-hub >= 0.12.0, < 0.13
colab_train.ipynb ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "machine_shape": "hm",
8
+ "authorship_tag": "ABX9TyMmemQnx6G7GOnn6XBdjgxY",
9
+ "include_colab_link": true
10
+ },
11
+ "kernelspec": {
12
+ "name": "python3",
13
+ "display_name": "Python 3"
14
+ },
15
+ "language_info": {
16
+ "name": "python"
17
+ },
18
+ "gpuClass": "standard",
19
+ "accelerator": "GPU"
20
+ },
21
+ "cells": [
22
+ {
23
+ "cell_type": "markdown",
24
+ "metadata": {
25
+ "id": "view-in-github",
26
+ "colab_type": "text"
27
+ },
28
+ "source": [
29
+ "<a href=\"https://colab.research.google.com/github/sgoodfriend/rl-algo-impls/blob/main/colab_train.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "markdown",
34
+ "source": [
35
+ "# [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) in Google Colaboratory\n",
36
+ "## Parameters\n",
37
+ "\n",
38
+ "\n",
39
+ "1. Wandb\n",
40
+ "\n"
41
+ ],
42
+ "metadata": {
43
+ "id": "S-tXDWP8WTLc"
44
+ }
45
+ },
46
+ {
47
+ "cell_type": "code",
48
+ "source": [
49
+ "from getpass import getpass\n",
50
+ "import os\n",
51
+ "os.environ[\"WANDB_API_KEY\"] = getpass(\"Wandb API key to upload metrics, videos, and models: \")"
52
+ ],
53
+ "metadata": {
54
+ "id": "1ZtdYgxWNGwZ"
55
+ },
56
+ "execution_count": null,
57
+ "outputs": []
58
+ },
59
+ {
60
+ "cell_type": "markdown",
61
+ "source": [
62
+ "2. train run parameters"
63
+ ],
64
+ "metadata": {
65
+ "id": "ao0nAh3MOdN7"
66
+ }
67
+ },
68
+ {
69
+ "cell_type": "code",
70
+ "source": [
71
+ "ALGO = \"ppo\"\n",
72
+ "ENV = \"CartPole-v1\"\n",
73
+ "SEED = 1"
74
+ ],
75
+ "metadata": {
76
+ "id": "jKL_NFhVOjSc"
77
+ },
78
+ "execution_count": null,
79
+ "outputs": []
80
+ },
81
+ {
82
+ "cell_type": "markdown",
83
+ "source": [
84
+ "## Setup\n",
85
+ "Clone [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) "
86
+ ],
87
+ "metadata": {
88
+ "id": "bsG35Io0hmKG"
89
+ }
90
+ },
91
+ {
92
+ "cell_type": "code",
93
+ "source": [
94
+ "%%capture\n",
95
+ "!git clone https://github.com/sgoodfriend/rl-algo-impls.git"
96
+ ],
97
+ "metadata": {
98
+ "id": "k5ynTV25hdAf"
99
+ },
100
+ "execution_count": null,
101
+ "outputs": []
102
+ },
103
+ {
104
+ "cell_type": "markdown",
105
+ "source": [
106
+ "Installing the correct packages:\n",
107
+ "\n",
108
+ "While conda and poetry are generally used for package management, the mismatch in Python versions (3.10 in the project file vs 3.8 in Colab) makes using the package yml files difficult to use. For now, instead I'm going to specify the list of requirements manually below:"
109
+ ],
110
+ "metadata": {
111
+ "id": "jKxGok-ElYQ7"
112
+ }
113
+ },
114
+ {
115
+ "cell_type": "code",
116
+ "source": [
117
+ "%%capture\n",
118
+ "!apt install python-opengl\n",
119
+ "!apt install ffmpeg\n",
120
+ "!apt install xvfb\n",
121
+ "!apt install swig"
122
+ ],
123
+ "metadata": {
124
+ "id": "nn6EETTc2Ewf"
125
+ },
126
+ "execution_count": null,
127
+ "outputs": []
128
+ },
129
+ {
130
+ "cell_type": "code",
131
+ "source": [
132
+ "%%capture\n",
133
+ "%cd /content/rl-algo-impls\n",
134
+ "!pip install -r colab_requirements.txt"
135
+ ],
136
+ "metadata": {
137
+ "id": "AfZh9rH3yQii"
138
+ },
139
+ "execution_count": null,
140
+ "outputs": []
141
+ },
142
+ {
143
+ "cell_type": "markdown",
144
+ "source": [
145
+ "## Run Once Per Runtime"
146
+ ],
147
+ "metadata": {
148
+ "id": "4o5HOLjc4wq7"
149
+ }
150
+ },
151
+ {
152
+ "cell_type": "code",
153
+ "source": [
154
+ "import wandb\n",
155
+ "wandb.login()"
156
+ ],
157
+ "metadata": {
158
+ "id": "PCXa5tdS2qFX"
159
+ },
160
+ "execution_count": null,
161
+ "outputs": []
162
+ },
163
+ {
164
+ "cell_type": "markdown",
165
+ "source": [
166
+ "## Restart Session beteween runs"
167
+ ],
168
+ "metadata": {
169
+ "id": "AZBZfSUV43JQ"
170
+ }
171
+ },
172
+ {
173
+ "cell_type": "code",
174
+ "source": [
175
+ "%%capture\n",
176
+ "from pyvirtualdisplay import Display\n",
177
+ "\n",
178
+ "virtual_display = Display(visible=0, size=(1400, 900))\n",
179
+ "virtual_display.start()"
180
+ ],
181
+ "metadata": {
182
+ "id": "VzemeQJP2NO9"
183
+ },
184
+ "execution_count": null,
185
+ "outputs": []
186
+ },
187
+ {
188
+ "cell_type": "code",
189
+ "source": [
190
+ "%cd /content/rl-algo-impls\n",
191
+ "!python train.py --algo {ALGO} --env {ENV} --seed {SEED}"
192
+ ],
193
+ "metadata": {
194
+ "id": "07aHYFH1zfXa"
195
+ },
196
+ "execution_count": null,
197
+ "outputs": []
198
+ }
199
+ ]
200
+ }
dqn/dqn.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import numpy as np
3
+ import random
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from collections import deque
9
+ from torch.optim import Adam
10
+ from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs
11
+ from torch.utils.tensorboard.writer import SummaryWriter
12
+ from typing import List, NamedTuple, Optional, TypeVar
13
+
14
+ from dqn.policy import DQNPolicy
15
+ from shared.algorithm import Algorithm
16
+ from shared.callbacks.callback import Callback
17
+ from shared.schedule import linear_schedule
18
+
19
+
20
+ class Transition(NamedTuple):
21
+ obs: np.ndarray
22
+ action: np.ndarray
23
+ reward: float
24
+ done: bool
25
+ next_obs: np.ndarray
26
+
27
+
28
+ class Batch(NamedTuple):
29
+ obs: np.ndarray
30
+ actions: np.ndarray
31
+ rewards: np.ndarray
32
+ dones: np.ndarray
33
+ next_obs: np.ndarray
34
+
35
+
36
+ class ReplayBuffer:
37
+ def __init__(self, num_envs: int, maxlen: int) -> None:
38
+ self.num_envs = num_envs
39
+ self.buffer = deque(maxlen=maxlen)
40
+
41
+ def add(
42
+ self,
43
+ obs: VecEnvObs,
44
+ action: np.ndarray,
45
+ reward: np.ndarray,
46
+ done: np.ndarray,
47
+ next_obs: VecEnvObs,
48
+ ) -> None:
49
+ assert isinstance(obs, np.ndarray)
50
+ assert isinstance(next_obs, np.ndarray)
51
+ for i in range(self.num_envs):
52
+ self.buffer.append(
53
+ Transition(obs[i], action[i], reward[i], done[i], next_obs[i])
54
+ )
55
+
56
+ def sample(self, batch_size: int) -> Batch:
57
+ ts = random.sample(self.buffer, batch_size)
58
+ return Batch(
59
+ obs=np.array([t.obs for t in ts]),
60
+ actions=np.array([t.action for t in ts]),
61
+ rewards=np.array([t.reward for t in ts]),
62
+ dones=np.array([t.done for t in ts]),
63
+ next_obs=np.array([t.next_obs for t in ts]),
64
+ )
65
+
66
+ def __len__(self) -> int:
67
+ return len(self.buffer)
68
+
69
+
70
+ DQNSelf = TypeVar("DQNSelf", bound="DQN")
71
+
72
+
73
+ class DQN(Algorithm):
74
+ def __init__(
75
+ self,
76
+ policy: DQNPolicy,
77
+ env: VecEnv,
78
+ device: torch.device,
79
+ tb_writer: SummaryWriter,
80
+ learning_rate: float = 1e-4,
81
+ buffer_size: int = 1_000_000,
82
+ learning_starts: int = 50_000,
83
+ batch_size: int = 32,
84
+ tau: float = 1.0,
85
+ gamma: float = 0.99,
86
+ train_freq: int = 4,
87
+ gradient_steps: int = 1,
88
+ target_update_interval: int = 10_000,
89
+ exploration_fraction: float = 0.1,
90
+ exploration_initial_eps: float = 1.0,
91
+ exploration_final_eps: float = 0.05,
92
+ max_grad_norm: float = 10.0,
93
+ ) -> None:
94
+ super().__init__(policy, env, device, tb_writer)
95
+ self.policy = policy
96
+
97
+ self.optimizer = Adam(self.policy.q_net.parameters(), lr=learning_rate)
98
+
99
+ self.target_q_net = copy.deepcopy(self.policy.q_net).to(self.device)
100
+ self.target_q_net.train(False)
101
+ self.tau = tau
102
+ self.target_update_interval = target_update_interval
103
+
104
+ self.replay_buffer = ReplayBuffer(self.env.num_envs, buffer_size)
105
+ self.batch_size = batch_size
106
+
107
+ self.learning_starts = learning_starts
108
+ self.train_freq = train_freq
109
+ self.gradient_steps = gradient_steps
110
+
111
+ self.gamma = gamma
112
+ self.exploration_eps_schedule = linear_schedule(
113
+ exploration_initial_eps,
114
+ exploration_final_eps,
115
+ end_fraction=exploration_fraction,
116
+ )
117
+
118
+ self.max_grad_norm = max_grad_norm
119
+
120
+ def learn(
121
+ self: DQNSelf, total_timesteps: int, callback: Optional[Callback] = None
122
+ ) -> DQNSelf:
123
+ self.policy.train(True)
124
+ obs = self.env.reset()
125
+ obs = self._collect_rollout(self.learning_starts, obs, 1)
126
+ learning_steps = total_timesteps - self.learning_starts
127
+ timesteps_elapsed = 0
128
+ steps_since_target_update = 0
129
+ while timesteps_elapsed < learning_steps:
130
+ progress = timesteps_elapsed / learning_steps
131
+ eps = self.exploration_eps_schedule(progress)
132
+ obs = self._collect_rollout(self.train_freq, obs, eps)
133
+ rollout_steps = self.train_freq
134
+ timesteps_elapsed += rollout_steps
135
+ for _ in range(
136
+ self.gradient_steps if self.gradient_steps > 0 else self.train_freq
137
+ ):
138
+ self.train()
139
+ steps_since_target_update += rollout_steps
140
+ if steps_since_target_update >= self.target_update_interval:
141
+ self._update_target()
142
+ steps_since_target_update = 0
143
+ if callback:
144
+ callback.on_step(timesteps_elapsed=rollout_steps)
145
+ return self
146
+
147
+ def train(self) -> None:
148
+ if len(self.replay_buffer) < self.batch_size:
149
+ return
150
+ o, a, r, d, next_o = self.replay_buffer.sample(self.batch_size)
151
+ o = torch.as_tensor(o, device=self.device)
152
+ a = torch.as_tensor(a, device=self.device).unsqueeze(1)
153
+ r = torch.as_tensor(r, dtype=torch.float32, device=self.device)
154
+ d = torch.as_tensor(d, dtype=torch.long, device=self.device)
155
+ next_o = torch.as_tensor(next_o, device=self.device)
156
+
157
+ with torch.no_grad():
158
+ target = r + (1 - d) * self.gamma * self.target_q_net(next_o).max(1).values
159
+ current = self.policy.q_net(o).gather(dim=1, index=a).squeeze(1)
160
+ loss = F.smooth_l1_loss(current, target)
161
+
162
+ self.optimizer.zero_grad()
163
+ loss.backward()
164
+ if self.max_grad_norm:
165
+ nn.utils.clip_grad_norm_(self.policy.q_net.parameters(), self.max_grad_norm)
166
+ self.optimizer.step()
167
+
168
+ def _collect_rollout(self, timesteps: int, obs: VecEnvObs, eps: float) -> VecEnvObs:
169
+ for _ in range(0, timesteps, self.env.num_envs):
170
+ action = self.policy.act(obs, eps, deterministic=False)
171
+ next_obs, reward, done, _ = self.env.step(action)
172
+ self.replay_buffer.add(obs, action, reward, done, next_obs)
173
+ obs = next_obs
174
+ return obs
175
+
176
+ def _update_target(self) -> None:
177
+ for target_param, param in zip(
178
+ self.target_q_net.parameters(), self.policy.q_net.parameters()
179
+ ):
180
+ target_param.data.copy_(
181
+ self.tau * param.data + (1 - self.tau) * target_param.data
182
+ )
dqn/policy.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import os
3
+ import torch
4
+
5
+ from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs
6
+ from typing import Sequence, TypeVar
7
+
8
+ from dqn.q_net import QNetwork
9
+ from shared.policy.policy import Policy
10
+
11
+ DQNPolicySelf = TypeVar("DQNPolicySelf", bound="DQNPolicy")
12
+
13
+
14
+ class DQNPolicy(Policy):
15
+ def __init__(
16
+ self,
17
+ env: VecEnv,
18
+ hidden_sizes: Sequence[int],
19
+ **kwargs,
20
+ ) -> None:
21
+ super().__init__(env, **kwargs)
22
+ self.q_net = QNetwork(env.observation_space, env.action_space, hidden_sizes)
23
+
24
+ def act(
25
+ self, obs: VecEnvObs, eps: float = 0, deterministic: bool = True
26
+ ) -> np.ndarray:
27
+ assert eps == 0 if deterministic else eps >= 0
28
+ if not deterministic and np.random.random() < eps:
29
+ return np.array(
30
+ [self.env.action_space.sample() for _ in range(self.env.num_envs)]
31
+ )
32
+ else:
33
+ with torch.no_grad():
34
+ obs_th = torch.as_tensor(np.array(obs))
35
+ if self.device:
36
+ obs_th = obs_th.to(self.device)
37
+ return self.q_net(obs_th).argmax(axis=1).cpu().numpy()
dqn/q_net.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gym
2
+ import torch as th
3
+ import torch.nn as nn
4
+
5
+ from gym.spaces import Discrete
6
+ from typing import Sequence, Type
7
+
8
+ from shared.module import FeatureExtractor, mlp
9
+
10
+
11
+ class QNetwork(nn.Module):
12
+ def __init__(
13
+ self,
14
+ observation_space: gym.Space,
15
+ action_space: gym.Space,
16
+ hidden_sizes: Sequence[int],
17
+ activation: Type[nn.Module] = nn.ReLU, # Used by stable-baselines3
18
+ ) -> None:
19
+ super().__init__()
20
+ assert isinstance(action_space, Discrete)
21
+ self._feature_extractor = FeatureExtractor(observation_space, activation)
22
+ layer_sizes = (
23
+ (self._feature_extractor.out_dim,) + tuple(hidden_sizes) + (action_space.n,)
24
+ )
25
+ self._fc = mlp(layer_sizes, activation)
26
+
27
+ def forward(self, obs: th.Tensor) -> th.Tensor:
28
+ x = self._feature_extractor(obs)
29
+ return self._fc(x)
enjoy.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Support for PyTorch mps mode (https://pytorch.org/docs/stable/notes/mps.html)
2
+ import os
3
+
4
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
5
+
6
+ from runner.evaluate import EvalArgs, evaluate_model
7
+ from runner.running_utils import base_parser
8
+
9
+
10
+ if __name__ == "__main__":
11
+ parser = base_parser(multiple=False)
12
+ parser.add_argument("--render", default=True, type=bool)
13
+ parser.add_argument("--best", default=True, type=bool)
14
+ parser.add_argument("--n_envs", default=1, type=int)
15
+ parser.add_argument("--n_episodes", default=3, type=int)
16
+ parser.add_argument("--deterministic-eval", default=None, type=bool)
17
+ parser.add_argument(
18
+ "--no-print-returns", action="store_true", help="Limit printing"
19
+ )
20
+ # wandb-run-path overrides base RunArgs
21
+ parser.add_argument("--wandb-run-path", default=None, type=str)
22
+ parser.set_defaults(
23
+ algo=["ppo"],
24
+ )
25
+ args = parser.parse_args()
26
+ args.algo = args.algo[0]
27
+ args.env = args.env[0]
28
+ args = EvalArgs(**vars(args))
29
+
30
+ evaluate_model(args, os.path.dirname(__file__))
environment.yml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: rl_algo_impls
2
+ channels:
3
+ - pytorch
4
+ - conda-forge
5
+ - nodefaults
6
+ dependencies:
7
+ - python=3.10.*
8
+ - mamba
9
+ - pip
10
+ - poetry
11
+ - pytorch
12
+ - torchvision
13
+ - torchaudio
14
+ - cmake
15
+ - swig
16
+ - ipywidgets
17
+ - black
huggingface_publish.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
4
+
5
+ import argparse
6
+ import requests
7
+ import shutil
8
+ import subprocess
9
+ import tempfile
10
+ import wandb
11
+ import wandb.apis.public
12
+
13
+ from typing import List, Optional
14
+
15
+ from huggingface_hub.hf_api import HfApi, upload_folder
16
+ from huggingface_hub.repocard import metadata_save
17
+ from publish.markdown_format import EvalTableData, model_card_text
18
+ from runner.evaluate import EvalArgs, evaluate_model
19
+ from runner.env import make_eval_env
20
+ from shared.callbacks.eval_callback import evaluate
21
+ from wrappers.vec_episode_recorder import VecEpisodeRecorder
22
+
23
+
24
+ def publish(
25
+ wandb_run_paths: List[str],
26
+ wandb_report_url: str,
27
+ huggingface_user: Optional[str] = None,
28
+ huggingface_token: Optional[str] = None,
29
+ ) -> None:
30
+ api = wandb.Api()
31
+ runs = [api.run(rp) for rp in wandb_run_paths]
32
+ algo = runs[0].config["algo"]
33
+ env = runs[0].config["env"]
34
+ evaluations = [
35
+ evaluate_model(
36
+ EvalArgs(
37
+ algo,
38
+ env,
39
+ seed=r.config.get("seed", None),
40
+ render=False,
41
+ best=True,
42
+ n_envs=None,
43
+ n_episodes=10,
44
+ no_print_returns=True,
45
+ wandb_run_path="/".join(r.path),
46
+ ),
47
+ os.path.dirname(__file__),
48
+ )
49
+ for r in runs
50
+ ]
51
+ run_metadata = requests.get(runs[0].file("wandb-metadata.json").url).json()
52
+ table_data = list(EvalTableData(r, e) for r, e in zip(runs, evaluations))
53
+ best_eval = sorted(
54
+ table_data, key=lambda d: d.evaluation.stats.score, reverse=True
55
+ )[0]
56
+
57
+ with tempfile.TemporaryDirectory() as tmpdirname:
58
+ _, (policy, stats, config) = best_eval
59
+
60
+ repo_name = config.model_name(include_seed=False)
61
+ repo_dir_path = os.path.join(tmpdirname, repo_name)
62
+ # Locally clone this repo to a temp directory
63
+ subprocess.run(["git", "clone", ".", repo_dir_path])
64
+ shutil.rmtree(os.path.join(repo_dir_path, ".git"))
65
+ model_path = config.model_dir_path(best=True, downloaded=True)
66
+ shutil.copytree(
67
+ model_path,
68
+ os.path.join(
69
+ repo_dir_path, "saved_models", config.model_dir_name(best=True)
70
+ ),
71
+ )
72
+
73
+ github_url = "https://github.com/sgoodfriend/rl-algo-impls"
74
+ commit_hash = run_metadata.get("git", {}).get("commit", None)
75
+ card_text = model_card_text(
76
+ algo,
77
+ env,
78
+ github_url,
79
+ commit_hash,
80
+ wandb_report_url,
81
+ table_data,
82
+ best_eval,
83
+ )
84
+ readme_filepath = os.path.join(repo_dir_path, "README.md")
85
+ os.remove(readme_filepath)
86
+ with open(readme_filepath, "w") as f:
87
+ f.write(card_text)
88
+
89
+ metadata = {
90
+ "library_name": "rl-algo-impls",
91
+ "tags": [
92
+ env,
93
+ algo,
94
+ "deep-reinforcement-learning",
95
+ "reinforcement-learning",
96
+ ],
97
+ "model-index": [
98
+ {
99
+ "name": algo,
100
+ "results": [
101
+ {
102
+ "metrics": [
103
+ {
104
+ "type": "mean_reward",
105
+ "value": str(stats.score),
106
+ "name": "mean_reward",
107
+ }
108
+ ],
109
+ "task": {
110
+ "type": "reinforcement-learning",
111
+ "name": "reinforcement-learning",
112
+ },
113
+ "dataset": {
114
+ "name": env,
115
+ "type": env,
116
+ },
117
+ }
118
+ ],
119
+ }
120
+ ],
121
+ }
122
+ metadata_save(readme_filepath, metadata)
123
+
124
+ video_env = VecEpisodeRecorder(
125
+ make_eval_env(
126
+ config,
127
+ override_n_envs=1,
128
+ normalize_load_path=model_path,
129
+ **config.env_hyperparams,
130
+ ),
131
+ os.path.join(repo_dir_path, "replay"),
132
+ max_video_length=3600,
133
+ )
134
+ evaluate(
135
+ video_env,
136
+ policy,
137
+ 1,
138
+ deterministic=config.eval_params.get("deterministic", True),
139
+ )
140
+
141
+ api = HfApi()
142
+ huggingface_user = huggingface_user or api.whoami()["name"]
143
+ huggingface_repo = f"{huggingface_user}/{repo_name}"
144
+ api.create_repo(
145
+ token=huggingface_token,
146
+ repo_id=huggingface_repo,
147
+ private=False,
148
+ exist_ok=True,
149
+ )
150
+ repo_url = upload_folder(
151
+ repo_id=huggingface_repo,
152
+ folder_path=repo_dir_path,
153
+ path_in_repo="",
154
+ commit_message=f"{algo.upper()} playing {env} from {github_url}/tree/{commit_hash}",
155
+ token=huggingface_token,
156
+ )
157
+ print(f"Pushed model to the hub: {repo_url}")
158
+
159
+
160
+ if __name__ == "__main__":
161
+ parser = argparse.ArgumentParser()
162
+ parser.add_argument(
163
+ "--wandb-run-paths",
164
+ type=str,
165
+ nargs="+",
166
+ help="Run paths of the form entity/project/run_id",
167
+ )
168
+ parser.add_argument("--wandb-report-url", type=str, help="Link to WandB report")
169
+ parser.add_argument(
170
+ "--huggingface-user",
171
+ type=str,
172
+ help="Huggingface user or team to upload model cards",
173
+ default=None,
174
+ )
175
+ args = parser.parse_args()
176
+ print(args)
177
+ publish(**vars(args))
hyperparams/dqn.yml ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CartPole-v1: &cartpole-defaults
2
+ n_timesteps: !!float 5e4
3
+ env_hyperparams:
4
+ n_envs: 1
5
+ rolling_length: 50
6
+ policy_hyperparams:
7
+ hidden_sizes: [256, 256]
8
+ algo_hyperparams:
9
+ learning_rate: !!float 2.3e-3
10
+ batch_size: 64
11
+ buffer_size: 100000
12
+ learning_starts: 1000
13
+ gamma: 0.99
14
+ target_update_interval: 10
15
+ train_freq: 256
16
+ gradient_steps: 128
17
+ exploration_fraction: 0.16
18
+ exploration_final_eps: 0.04
19
+ eval_params:
20
+ step_freq: !!float 1e4
21
+ n_episodes: 10
22
+ save_best: true
23
+
24
+ CartPole-v0:
25
+ <<: *cartpole-defaults
26
+ n_timesteps: !!float 4e4
27
+
28
+ MountainCar-v0:
29
+ n_timesteps: !!float 1.2e5
30
+ env_hyperparams:
31
+ rolling_length: 50
32
+ policy_hyperparams:
33
+ hidden_sizes: [256, 256]
34
+ algo_hyperparams:
35
+ learning_rate: !!float 4e-3
36
+ batch_size: 128
37
+ buffer_size: 10000
38
+ learning_starts: 1000
39
+ gamma: 0.98
40
+ target_update_interval: 600
41
+ train_freq: 16
42
+ gradient_steps: 8
43
+ exploration_fraction: 0.2
44
+ exploration_final_eps: 0.07
45
+
46
+ Acrobot-v1:
47
+ n_timesteps: !!float 1e5
48
+ env_hyperparams:
49
+ rolling_length: 10
50
+ policy_hyperparams:
51
+ hidden_sizes: [256, 256]
52
+ algo_hyperparams:
53
+ learning_rate: !!float 6.3e-4
54
+ batch_size: 128
55
+ buffer_size: 50000
56
+ learning_starts: 0
57
+ gamma: 0.99
58
+ target_update_interval: 250
59
+ train_freq: 4
60
+ gradient_steps: -1
61
+ exploration_fraction: 0.12
62
+ exploration_final_eps: 0.1
63
+
64
+ LunarLander-v2:
65
+ n_timesteps: !!float 5e5
66
+ env_hyperparams:
67
+ rolling_length: 10
68
+ policy_hyperparams:
69
+ hidden_sizes: [256, 256]
70
+ algo_hyperparams:
71
+ learning_rate: !!float 1e-4
72
+ batch_size: 256
73
+ buffer_size: 100000
74
+ learning_starts: 10000
75
+ gamma: 0.99
76
+ target_update_interval: 250
77
+ train_freq: 8
78
+ gradient_steps: -1
79
+ exploration_fraction: 0.12
80
+ exploration_final_eps: 0.1
81
+ max_grad_norm: 0.5
82
+ eval_params:
83
+ step_freq: 25_000
84
+ n_episodes: 10
85
+ save_best: true
86
+
87
+ SpaceInvadersNoFrameskip-v4: &atari-defaults
88
+ n_timesteps: !!float 1e7
89
+ env_hyperparams:
90
+ frame_stack: 4
91
+ no_reward_timeout_steps: 1_000
92
+ n_envs: 8
93
+ vec_env_class: "subproc"
94
+ rolling_length: 20
95
+ policy_hyperparams:
96
+ hidden_sizes: [512]
97
+ algo_hyperparams:
98
+ buffer_size: 100000
99
+ learning_rate: !!float 1e-4
100
+ batch_size: 32
101
+ learning_starts: 100000
102
+ target_update_interval: 1000
103
+ train_freq: 8
104
+ gradient_steps: 2
105
+ exploration_fraction: 0.1
106
+ exploration_final_eps: 0.01
107
+ eval_params:
108
+ step_freq: 100_000
109
+ n_episodes: 10
110
+ save_best: true
111
+
112
+ BreakoutNoFrameskip-v4:
113
+ <<: *atari-defaults
114
+
115
+ PongNoFrameskip-v4:
116
+ <<: *atari-defaults
117
+ n_timesteps: !!float 2.5e6
hyperparams/ppo.yml ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CartPole-v1: &cartpole-defaults
2
+ n_timesteps: !!float 1e5
3
+ env_hyperparams:
4
+ n_envs: 8
5
+ algo_hyperparams:
6
+ n_steps: 32
7
+ batch_size: 256
8
+ n_epochs: 20
9
+ gae_lambda: 0.8
10
+ gamma: 0.98
11
+ ent_coef: 0.0
12
+ learning_rate: 0.001
13
+ learning_rate_decay: linear
14
+ clip_range: 0.2
15
+ clip_range_decay: linear
16
+ eval_params:
17
+ step_freq: !!float 2.5e4
18
+ n_episodes: 10
19
+ save_best: true
20
+
21
+ CartPole-v0:
22
+ <<: *cartpole-defaults
23
+ n_timesteps: !!float 5e4
24
+
25
+ MountainCar-v0:
26
+ n_timesteps: !!float 1e6
27
+ env_hyperparams:
28
+ normalize: true
29
+ n_envs: 16
30
+ algo_hyperparams:
31
+ n_steps: 16
32
+ n_epochs: 4
33
+ gae_lambda: 0.98
34
+ gamma: 0.99
35
+ ent_coef: 0.0
36
+
37
+ MountainCarContinuous-v0:
38
+ n_timesteps: !!float 1e5
39
+ env_hyperparams:
40
+ normalize: true
41
+ n_envs: 4
42
+ policy_hyperparams:
43
+ init_layers_orthogonal: false
44
+ # log_std_init: -3.29
45
+ algo_hyperparams:
46
+ n_steps: 512
47
+ batch_size: 256
48
+ n_epochs: 10
49
+ learning_rate: !!float 7.77e-5
50
+ ent_coef: 0.01 # 0.00429
51
+ ent_coef_decay: linear
52
+ clip_range: 0.1
53
+ gae_lambda: 0.9
54
+ max_grad_norm: 5
55
+ vf_coef: 0.19
56
+ # use_sde: true
57
+ eval_params:
58
+ step_freq: 5000
59
+ n_episodes: 10
60
+ save_best: true
61
+
62
+ Acrobot-v1:
63
+ n_timesteps: !!float 1e6
64
+ env_hyperparams:
65
+ n_envs: 16
66
+ normalize: true
67
+ algo_hyperparams:
68
+ n_steps: 256
69
+ n_epochs: 4
70
+ gae_lambda: 0.94
71
+ gamma: 0.99
72
+ ent_coef: 0.0
73
+
74
+ LunarLander-v2:
75
+ n_timesteps: !!float 1e6
76
+ env_hyperparams:
77
+ n_envs: 16
78
+ algo_hyperparams:
79
+ n_steps: 1024
80
+ batch_size: 64
81
+ n_epochs: 4
82
+ gae_lambda: 0.98
83
+ gamma: 0.999
84
+ ent_coef: 0.01
85
+ ent_coef_decay: linear
86
+ normalize_advantage: false
87
+ eval_params:
88
+ step_freq: !!float 5e4
89
+ n_episodes: 10
90
+ save_best: true
91
+
92
+ CarRacing-v0:
93
+ n_timesteps: !!float 4e6
94
+ env_hyperparams:
95
+ n_envs: 8
96
+ frame_stack: 4
97
+ policy_hyperparams:
98
+ use_sde: true
99
+ log_std_init: -2
100
+ init_layers_orthogonal: false
101
+ activation_fn: relu
102
+ share_features_extractor: false
103
+ cnn_feature_dim: 256
104
+ algo_hyperparams:
105
+ n_steps: 512
106
+ batch_size: 128
107
+ n_epochs: 10
108
+ learning_rate: !!float 1e-4
109
+ learning_rate_decay: linear
110
+ gamma: 0.99
111
+ gae_lambda: 0.95
112
+ ent_coef: 0.0
113
+ sde_sample_freq: 4
114
+ max_grad_norm: 0.5
115
+ vf_coef: 0.5
116
+ clip_range: 0.2
117
+
118
+ # BreakoutNoFrameskip-v4
119
+ # PongNoFrameskip-v4
120
+ # SpaceInvadersNoFrameskip-v4
121
+ # QbertNoFrameskip-v4
122
+ atari: &atari-defaults
123
+ n_timesteps: !!float 1e7
124
+ policy_hyperparams:
125
+ activation_fn: relu
126
+ env_hyperparams: &atari-env-defaults
127
+ n_envs: 8
128
+ frame_stack: 4
129
+ no_reward_timeout_steps: 1000
130
+ no_reward_fire_steps: 500
131
+ vec_env_class: subproc
132
+ algo_hyperparams:
133
+ n_steps: 128
134
+ batch_size: 256
135
+ n_epochs: 4
136
+ learning_rate: !!float 2.5e-4
137
+ learning_rate_decay: linear
138
+ clip_range: 0.1
139
+ clip_range_decay: linear
140
+ vf_coef: 0.5
141
+ ent_coef: 0.01
142
+ eval_params:
143
+ deterministic: false
144
+
145
+ HalfCheetahBulletEnv-v0: &pybullet-defaults
146
+ n_timesteps: !!float 2e6
147
+ env_hyperparams: &pybullet-env-defaults
148
+ n_envs: 16
149
+ normalize: true
150
+ policy_hyperparams: &pybullet-policy-defaults
151
+ pi_hidden_sizes: [256, 256]
152
+ v_hidden_sizes: [256, 256]
153
+ activation_fn: relu
154
+ algo_hyperparams: &pybullet-algo-defaults
155
+ n_steps: 512
156
+ batch_size: 128
157
+ n_epochs: 20
158
+ gamma: 0.99
159
+ gae_lambda: 0.9
160
+ ent_coef: 0.0
161
+ sde_sample_freq: 4
162
+ max_grad_norm: 0.5
163
+ vf_coef: 0.5
164
+ learning_rate: !!float 3e-5
165
+ clip_range: 0.4
166
+
167
+ AntBulletEnv-v0:
168
+ <<: *pybullet-defaults
169
+ policy_hyperparams:
170
+ <<: *pybullet-policy-defaults
171
+ algo_hyperparams:
172
+ <<: *pybullet-algo-defaults
173
+
174
+ Walker2DBulletEnv-v0:
175
+ <<: *pybullet-defaults
176
+ algo_hyperparams:
177
+ <<: *pybullet-algo-defaults
178
+ clip_range_decay: linear
179
+
180
+ HopperBulletEnv-v0:
181
+ <<: *pybullet-defaults
182
+ algo_hyperparams:
183
+ <<: *pybullet-algo-defaults
184
+ clip_range_decay: linear
185
+
186
+ HumanoidBulletEnv-v0:
187
+ <<: *pybullet-defaults
188
+ n_timesteps: !!float 1e7
189
+ env_hyperparams:
190
+ <<: *pybullet-env-defaults
191
+ n_envs: 8
192
+ policy_hyperparams:
193
+ <<: *pybullet-policy-defaults
194
+ # log_std_init: -1
195
+ algo_hyperparams:
196
+ <<: *pybullet-algo-defaults
197
+ n_steps: 2048
198
+ batch_size: 64
199
+ n_epochs: 10
200
+ gae_lambda: 0.95
201
+ learning_rate: !!float 2.5e-4
202
+ clip_range: 0.2
hyperparams/vpg.yml ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CartPole-v1: &cartpole-defaults
2
+ n_timesteps: !!float 4e5
3
+ policy_hyperparams:
4
+ hidden_sizes: [32]
5
+ algo_hyperparams:
6
+ steps_per_epoch: 4096
7
+ pi_lr: 0.01
8
+ gamma: 0.99
9
+ lam: 1
10
+ val_lr: 0.01
11
+ train_v_iters: 80
12
+ eval_params:
13
+ step_freq: !!float 2.5e4
14
+ n_episodes: 10
15
+ save_best: true
16
+
17
+ CartPole-v0:
18
+ <<: *cartpole-defaults
19
+ n_timesteps: !!float 1e5
20
+ algo_hyperparams:
21
+ steps_per_epoch: 1024
22
+ pi_lr: 0.01
23
+ gamma: 0.99
24
+ lam: 1
25
+ val_lr: 0.01
26
+ train_v_iters: 80
27
+
28
+ Acrobot-v1:
29
+ n_timesteps: !!float 2e5
30
+ policy_hyperparams:
31
+ hidden_sizes: [32, 32]
32
+ algo_hyperparams:
33
+ steps_per_epoch: 2048
34
+ pi_lr: 0.005
35
+ gamma: 0.99
36
+ lam: 0.97
37
+ val_lr: 0.01
38
+ train_v_iters: 80
39
+ max_grad_norm: 0.5
40
+ eval_params:
41
+ step_freq: !!float 4e4
42
+ n_episodes: 10
43
+ save_best: true
44
+
45
+ LunarLander-v2:
46
+ n_timesteps: !!float 4e6
47
+ policy_hyperparams:
48
+ hidden_sizes: [256, 256]
49
+ algo_hyperparams:
50
+ steps_per_epoch: 2048
51
+ pi_lr: 0.0001
52
+ gamma: 0.999
53
+ lam: 0.97
54
+ val_lr: 0.0001
55
+ train_v_iters: 80
56
+ max_grad_norm: 0.5
57
+ eval_params:
58
+ step_freq: !!float 5e4
59
+ n_episodes: 10
60
+ save_best: true
61
+
62
+ CarRacing-v0:
63
+ n_timesteps: !!float 4e6
64
+ env_hyperparams:
65
+ frame_stack: 4
66
+ n_envs: 4
67
+ vec_env_class: "dummy"
68
+ policy_hyperparams:
69
+ hidden_sizes: [256, 256]
70
+ algo_hyperparams:
71
+ steps_per_epoch: 4000
72
+ pi_lr: !!float 7e-5
73
+ gamma: 0.99
74
+ lam: 0.95
75
+ val_lr: !!float 1e-4
76
+ train_v_iters: 40
77
+ max_grad_norm: 0.5
78
+ eval_params:
79
+ step_freq: !!float 5e4
80
+ n_episodes: 10
81
+ save_best: true
82
+
83
+ HalfCheetahBulletEnv-v0: &pybullet-defaults
84
+ n_timesteps: !!float 2e6
85
+ policy_hyperparams:
86
+ hidden_sizes: [64, 64]
87
+ init_layers_orthogonal: true
88
+ algo_hyperparams:
89
+ steps_per_epoch: 4000
90
+ pi_lr: !!float 3e-4
91
+ gamma: 0.99
92
+ lam: 0.97
93
+ val_lr: !!float 1e-3
94
+ train_v_iters: 80
95
+ max_grad_norm: 0.5
96
+ eval_params:
97
+ step_freq: !!float 1e5
98
+ n_episodes: 10
99
+ save_best: true
100
+
101
+ HopperBulletEnv-v0:
102
+ <<: *pybullet-defaults
103
+
104
+ AntBulletEnv-v0:
105
+ <<: *pybullet-defaults
106
+ policy_hyperparams:
107
+ hidden_sizes: [400, 300]
108
+ algo_hyperparams:
109
+ pi_lr: !!float 7e-4
110
+ gamma: 0.99
111
+ lam: 0.97
112
+ val_lr: !!float 7e-3
113
+ train_v_iters: 80
114
+ max_grad_norm: 0.5
115
+
116
+ FrozenLake-v1:
117
+ n_timesteps: !!float 8e5
118
+ env_params:
119
+ make_kwargs:
120
+ map_name: 8x8
121
+ is_slippery: true
122
+ policy_hyperparams:
123
+ hidden_sizes: [64]
124
+ algo_hyperparams:
125
+ steps_per_epoch: 2048
126
+ pi_lr: 0.01
127
+ gamma: 0.99
128
+ lam: 0.98
129
+ val_lr: 0.01
130
+ train_v_iters: 80
131
+ max_grad_norm: 0.5
132
+ eval_params:
133
+ step_freq: !!float 5e4
134
+ n_episodes: 10
135
+ save_best: true
136
+
137
+ SpaceInvadersNoFrameskip-v4: &atari-defaults
138
+ n_timesteps: !!float 1e7
139
+ env_hyperparams:
140
+ frame_stack: 4
141
+ no_reward_timeout_steps: 1_000
142
+ n_envs: 8
143
+ vec_env_class: "subproc"
144
+ policy_hyperparams:
145
+ hidden_sizes: [256, 256]
146
+ algo_hyperparams:
147
+ steps_per_epoch: 4096
148
+ pi_lr: !!float 1e-4
149
+ gamma: 0.99
150
+ lam: 0.95
151
+ val_lr: !!float 2e-4
152
+ train_v_iters: 80
153
+ max_grad_norm: 0.5
154
+ eval_params:
155
+ step_freq: !!float 1e5
156
+ n_episodes: 10
157
+ save_best: true
lambda_labs/benchmark.sh ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ source benchmarks/train_loop.sh
2
+
3
+ # export WANDB_PROJECT_NAME="rl-algo-impls"
4
+ export VIRTUAL_DISPLAY=1
5
+
6
+ BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-6}"
7
+
8
+ ALGOS=(
9
+ # "vpg"
10
+ # "dqn"
11
+ "ppo"
12
+ )
13
+ ENVS=(
14
+ # Basic
15
+ "CartPole-v1"
16
+ "MountainCar-v0"
17
+ "MountainCarContinuous-v0"
18
+ "Acrobot-v1"
19
+ "LunarLander-v2"
20
+ # PyBullet
21
+ "HalfCheetahBulletEnv-v0"
22
+ "AntBulletEnv-v0"
23
+ "Walker2DBulletEnv-v0"
24
+ "HopperBulletEnv-v0"
25
+ # CarRacing
26
+ "CarRacing-v0"
27
+ # Atari
28
+ "PongNoFrameskip-v4"
29
+ "BreakoutNoFrameskip-v4"
30
+ "SpaceInvadersNoFrameskip-v4"
31
+ "QbertNoFrameskip-v4"
32
+ )
33
+ train_loop "${ALGOS[*]}" "${ENVS[*]}" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
lambda_labs/lambda_requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ scipy >= 1.10.0, < 1.11
2
+ tensorboard >= ^2.11.0, < 2.12
3
+ AutoROM.accept-rom-license >= 0.4.2, < 0.5
4
+ stable-baselines3[extra] >= 1.7.0, < 1.8
5
+ gym[box2d] >= 0.21.0, < 0.22
6
+ pyglet == 1.5.27
7
+ wandb >= 0.13.9, < 0.14
8
+ pyvirtualdisplay == 3.0
9
+ pybullet >= 3.2.5, < 3.3
10
+ tabulate >= 0.9.0, < 0.10
11
+ huggingface-hub >= 0.12.0, < 0.13
lambda_labs/setup.sh ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ sudo apt update
2
+ sudo apt install -y python-opengl
3
+ sudo apt install -y ffmpeg
4
+ sudo apt install -y xvfb
5
+ sudo apt install -y swig
6
+
7
+ python3 -m pip install --upgrade pip
8
+ pip install --upgrade torch torchvision torchaudio
9
+
10
+ pip install --upgrade -r ~/rl-algo-impls/lambda_labs/lambda_requirements.txt
poetry.lock ADDED
The diff for this file is too large to render. See raw diff
 
ppo/policy.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from stable_baselines3.common.vec_env.base_vec_env import VecEnv
2
+ from typing import Optional, Sequence
3
+
4
+ from gym.spaces import Box, Discrete
5
+ from shared.policy.on_policy import ActorCritic
6
+
7
+
8
+ class PPOActorCritic(ActorCritic):
9
+ def __init__(
10
+ self,
11
+ env: VecEnv,
12
+ pi_hidden_sizes: Optional[Sequence[int]] = None,
13
+ v_hidden_sizes: Optional[Sequence[int]] = None,
14
+ **kwargs,
15
+ ) -> None:
16
+ obs_space = env.observation_space
17
+ if isinstance(obs_space, Box):
18
+ if len(obs_space.shape) == 3:
19
+ pi_hidden_sizes = pi_hidden_sizes or []
20
+ v_hidden_sizes = v_hidden_sizes or []
21
+ elif len(obs_space.shape) == 1:
22
+ pi_hidden_sizes = pi_hidden_sizes or [64, 64]
23
+ v_hidden_sizes = v_hidden_sizes or [64, 64]
24
+ else:
25
+ raise ValueError(f"Unsupported observation space: {obs_space}")
26
+ elif isinstance(obs_space, Discrete):
27
+ pi_hidden_sizes = pi_hidden_sizes or [64]
28
+ v_hidden_sizes = v_hidden_sizes or [64]
29
+ else:
30
+ raise ValueError(f"Unsupported observation space: {obs_space}")
31
+ super().__init__(
32
+ env,
33
+ pi_hidden_sizes,
34
+ v_hidden_sizes,
35
+ **kwargs,
36
+ )
ppo/ppo.py ADDED
@@ -0,0 +1,367 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from dataclasses import asdict, dataclass
6
+ from torch.optim import Adam
7
+ from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs
8
+ from torch.utils.tensorboard.writer import SummaryWriter
9
+ from typing import List, Optional, Sequence, NamedTuple, TypeVar
10
+
11
+ from shared.algorithm import Algorithm
12
+ from shared.callbacks.callback import Callback
13
+ from shared.policy.on_policy import ActorCritic
14
+ from shared.schedule import constant_schedule, linear_schedule
15
+ from shared.trajectory import Trajectory as BaseTrajectory
16
+ from shared.utils import discounted_cumsum
17
+
18
+
19
+ @dataclass
20
+ class PPOTrajectory(BaseTrajectory):
21
+ logp_a: List[float]
22
+ next_obs: Optional[np.ndarray]
23
+
24
+ def __init__(self) -> None:
25
+ super().__init__()
26
+ self.logp_a = []
27
+ self.next_obs = None
28
+
29
+ def add(
30
+ self,
31
+ obs: np.ndarray,
32
+ act: np.ndarray,
33
+ next_obs: np.ndarray,
34
+ rew: float,
35
+ terminated: bool,
36
+ v: float,
37
+ logp_a: float,
38
+ ):
39
+ super().add(obs, act, rew, v)
40
+ self.next_obs = next_obs if not terminated else None
41
+ self.terminated = terminated
42
+ self.logp_a.append(logp_a)
43
+
44
+
45
+ class TrajectoryAccumulator:
46
+ def __init__(self, num_envs: int) -> None:
47
+ self.num_envs = num_envs
48
+
49
+ self.trajectories_ = []
50
+ self.current_trajectories_ = [PPOTrajectory() for _ in range(num_envs)]
51
+
52
+ def step(
53
+ self,
54
+ obs: VecEnvObs,
55
+ action: np.ndarray,
56
+ next_obs: VecEnvObs,
57
+ reward: np.ndarray,
58
+ done: np.ndarray,
59
+ val: np.ndarray,
60
+ logp_a: np.ndarray,
61
+ ) -> None:
62
+ assert isinstance(obs, np.ndarray)
63
+ assert isinstance(next_obs, np.ndarray)
64
+ for i, trajectory in enumerate(self.current_trajectories_):
65
+ # TODO: Eventually take advantage of terminated/truncated differentiation in
66
+ # later versions of gym.
67
+ trajectory.add(
68
+ obs[i], action[i], next_obs[i], reward[i], done[i], val[i], logp_a[i]
69
+ )
70
+ if done[i]:
71
+ self.trajectories_.append(trajectory)
72
+ self.current_trajectories_[i] = PPOTrajectory()
73
+
74
+ @property
75
+ def all_trajectories(self) -> List[PPOTrajectory]:
76
+ return self.trajectories_ + list(
77
+ filter(lambda t: len(t), self.current_trajectories_)
78
+ )
79
+
80
+
81
+ class RtgAdvantage(NamedTuple):
82
+ rewards_to_go: torch.Tensor
83
+ advantage: torch.Tensor
84
+
85
+
86
+ class TrainStepStats(NamedTuple):
87
+ loss: float
88
+ pi_loss: float
89
+ v_loss: float
90
+ entropy_loss: float
91
+ approx_kl: float
92
+ clipped_frac: float
93
+
94
+
95
+ @dataclass
96
+ class TrainStats:
97
+ loss: float
98
+ pi_loss: float
99
+ v_loss: float
100
+ entropy_loss: float
101
+ approx_kl: float
102
+ clipped_frac: float
103
+
104
+ def __init__(self, step_stats: List[TrainStepStats]) -> None:
105
+ self.loss = np.mean([s.loss for s in step_stats]).item()
106
+ self.pi_loss = np.mean([s.pi_loss for s in step_stats]).item()
107
+ self.v_loss = np.mean([s.v_loss for s in step_stats]).item()
108
+ self.entropy_loss = np.mean([s.entropy_loss for s in step_stats]).item()
109
+ self.approx_kl = np.mean([s.approx_kl for s in step_stats]).item()
110
+ self.clipped_frac = np.mean([s.clipped_frac for s in step_stats]).item()
111
+
112
+ def write_to_tensorboard(self, tb_writer: SummaryWriter, global_step: int) -> None:
113
+ tb_writer.add_scalars("losses", asdict(self), global_step=global_step)
114
+
115
+ def __repr__(self) -> str:
116
+ return " | ".join(
117
+ [
118
+ f"Loss: {round(self.loss, 2)}",
119
+ f"Pi L: {round(self.pi_loss, 2)}",
120
+ f"V L: {round(self.v_loss, 2)}",
121
+ f"E L: {round(self.entropy_loss, 2)}",
122
+ f"Apx KL Div: {round(self.approx_kl, 2)}",
123
+ f"Clip Frac: {round(self.clipped_frac, 2)}",
124
+ ]
125
+ )
126
+
127
+
128
+ PPOSelf = TypeVar("PPOSelf", bound="PPO")
129
+
130
+
131
+ class PPO(Algorithm):
132
+ def __init__(
133
+ self,
134
+ policy: ActorCritic,
135
+ env: VecEnv,
136
+ device: torch.device,
137
+ tb_writer: SummaryWriter,
138
+ learning_rate: float = 3e-4,
139
+ learning_rate_decay: str = "none",
140
+ n_steps: int = 2048,
141
+ batch_size: int = 64,
142
+ n_epochs: int = 10,
143
+ gamma: float = 0.99,
144
+ gae_lambda: float = 0.95,
145
+ clip_range: float = 0.2,
146
+ clip_range_decay: str = "none",
147
+ clip_range_vf: Optional[float] = None,
148
+ clip_range_vf_decay: str = "none",
149
+ normalize_advantage: bool = True,
150
+ ent_coef: float = 0.0,
151
+ ent_coef_decay: str = "none",
152
+ vf_coef: float = 0.5,
153
+ max_grad_norm: float = 0.5,
154
+ update_rtg_between_epochs: bool = False,
155
+ sde_sample_freq: int = -1,
156
+ ) -> None:
157
+ super().__init__(policy, env, device, tb_writer)
158
+ self.policy = policy
159
+
160
+ self.gamma = gamma
161
+ self.gae_lambda = gae_lambda
162
+ self.optimizer = Adam(self.policy.parameters(), lr=learning_rate)
163
+ self.lr_schedule = (
164
+ linear_schedule(learning_rate, 0)
165
+ if learning_rate_decay == "linear"
166
+ else constant_schedule(learning_rate)
167
+ )
168
+ self.max_grad_norm = max_grad_norm
169
+ self.clip_range_schedule = (
170
+ linear_schedule(clip_range, 0)
171
+ if clip_range_decay == "linear"
172
+ else constant_schedule(clip_range)
173
+ )
174
+ self.clip_range_vf_schedule = None
175
+ if clip_range_vf:
176
+ self.clip_range_vf_schedule = (
177
+ linear_schedule(clip_range_vf, 0)
178
+ if clip_range_vf_decay == "linear"
179
+ else constant_schedule(clip_range_vf)
180
+ )
181
+ self.normalize_advantage = normalize_advantage
182
+ self.ent_coef_schedule = (
183
+ linear_schedule(ent_coef, 0)
184
+ if ent_coef_decay == "linear"
185
+ else constant_schedule(ent_coef)
186
+ )
187
+ self.vf_coef = vf_coef
188
+
189
+ self.n_steps = n_steps
190
+ self.batch_size = batch_size
191
+ self.n_epochs = n_epochs
192
+ self.sde_sample_freq = sde_sample_freq
193
+
194
+ self.update_rtg_between_epochs = update_rtg_between_epochs
195
+
196
+ def learn(
197
+ self: PPOSelf,
198
+ total_timesteps: int,
199
+ callback: Optional[Callback] = None,
200
+ ) -> PPOSelf:
201
+ obs = self.env.reset()
202
+ ts_elapsed = 0
203
+ while ts_elapsed < total_timesteps:
204
+ accumulator = self._collect_trajectories(obs)
205
+ progress = ts_elapsed / total_timesteps
206
+ train_stats = self.train(accumulator.all_trajectories, progress)
207
+ rollout_steps = self.n_steps * self.env.num_envs
208
+ ts_elapsed += rollout_steps
209
+ train_stats.write_to_tensorboard(self.tb_writer, ts_elapsed)
210
+ if callback:
211
+ callback.on_step(timesteps_elapsed=rollout_steps)
212
+
213
+ return self
214
+
215
+ def _collect_trajectories(self, obs: VecEnvObs) -> TrajectoryAccumulator:
216
+ self.policy.eval()
217
+ accumulator = TrajectoryAccumulator(self.env.num_envs)
218
+ self.policy.reset_noise()
219
+ for i in range(self.n_steps):
220
+ if self.sde_sample_freq > 0 and i > 0 and i % self.sde_sample_freq == 0:
221
+ self.policy.reset_noise()
222
+ action, value, logp_a, clamped_action = self.policy.step(obs)
223
+ next_obs, reward, done, _ = self.env.step(clamped_action)
224
+ accumulator.step(obs, action, next_obs, reward, done, value, logp_a)
225
+ obs = next_obs
226
+ return accumulator
227
+
228
+ def train(self, trajectories: List[PPOTrajectory], progress: float) -> TrainStats:
229
+ self.policy.train()
230
+ learning_rate = self.lr_schedule(progress)
231
+ self.optimizer.param_groups[0]["lr"] = learning_rate
232
+
233
+ pi_clip = self.clip_range_schedule(progress)
234
+ v_clip = (
235
+ self.clip_range_vf_schedule(progress)
236
+ if self.clip_range_vf_schedule
237
+ else None
238
+ )
239
+ ent_coef = self.ent_coef_schedule(progress)
240
+
241
+ obs = torch.as_tensor(
242
+ np.concatenate([np.array(t.obs) for t in trajectories]), device=self.device
243
+ )
244
+ act = torch.as_tensor(
245
+ np.concatenate([np.array(t.act) for t in trajectories]), device=self.device
246
+ )
247
+ rtg, adv = self._compute_rtg_and_advantage(trajectories)
248
+ orig_v = torch.as_tensor(
249
+ np.concatenate([np.array(t.v) for t in trajectories]), device=self.device
250
+ )
251
+ orig_logp_a = torch.as_tensor(
252
+ np.concatenate([np.array(t.logp_a) for t in trajectories]),
253
+ device=self.device,
254
+ )
255
+
256
+ step_stats = []
257
+ for _ in range(self.n_epochs):
258
+ if self.update_rtg_between_epochs:
259
+ rtg, adv = self._compute_rtg_and_advantage(trajectories)
260
+ else:
261
+ adv = self._compute_advantage(trajectories)
262
+ idxs = torch.randperm(len(obs))
263
+ for i in range(0, len(obs), self.batch_size):
264
+ mb_idxs = idxs[i : i + self.batch_size]
265
+ mb_adv = adv[mb_idxs]
266
+ if self.normalize_advantage:
267
+ mb_adv = (mb_adv - mb_adv.mean(-1)) / (mb_adv.std(-1) + 1e-8)
268
+ step_stats.append(
269
+ self._train_step(
270
+ pi_clip,
271
+ v_clip,
272
+ ent_coef,
273
+ obs[mb_idxs],
274
+ act[mb_idxs],
275
+ rtg[mb_idxs],
276
+ mb_adv,
277
+ orig_v[mb_idxs],
278
+ orig_logp_a[mb_idxs],
279
+ )
280
+ )
281
+
282
+ return TrainStats(step_stats)
283
+
284
+ def _train_step(
285
+ self,
286
+ pi_clip: float,
287
+ v_clip: Optional[float],
288
+ ent_coef: float,
289
+ obs: torch.Tensor,
290
+ act: torch.Tensor,
291
+ rtg: torch.Tensor,
292
+ adv: torch.Tensor,
293
+ orig_v: torch.Tensor,
294
+ orig_logp_a: torch.Tensor,
295
+ ) -> TrainStepStats:
296
+ logp_a, entropy, v = self.policy(obs, act)
297
+ logratio = logp_a - orig_logp_a
298
+ ratio = torch.exp(logratio)
299
+ clip_ratio = torch.clamp(ratio, min=1 - pi_clip, max=1 + pi_clip)
300
+ pi_loss = torch.maximum(-ratio * adv, -clip_ratio * adv).mean()
301
+
302
+ v_loss = (v - rtg).pow(2)
303
+ if v_clip:
304
+ v_clipped = (torch.clamp(v, orig_v - v_clip, orig_v + v_clip) - rtg).pow(2)
305
+ v_loss = torch.maximum(v_loss, v_clipped)
306
+ v_loss = v_loss.mean()
307
+
308
+ entropy_loss = entropy.mean()
309
+
310
+ loss = pi_loss - ent_coef * entropy_loss + self.vf_coef * v_loss
311
+
312
+ self.optimizer.zero_grad()
313
+ loss.backward()
314
+ nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
315
+ self.optimizer.step()
316
+
317
+ with torch.no_grad():
318
+ approx_kl = ((ratio - 1) - logratio).mean().cpu().numpy().item()
319
+ clipped_frac = (
320
+ ((ratio - 1).abs() > pi_clip).float().mean().cpu().numpy().item()
321
+ )
322
+ return TrainStepStats(
323
+ loss.item(),
324
+ pi_loss.item(),
325
+ v_loss.item(),
326
+ entropy_loss.item(),
327
+ approx_kl,
328
+ clipped_frac,
329
+ )
330
+
331
+ def _compute_advantage(self, trajectories: Sequence[PPOTrajectory]) -> torch.Tensor:
332
+ advantage = []
333
+ for traj in trajectories:
334
+ last_val = 0
335
+ if not traj.terminated and traj.next_obs is not None:
336
+ last_val = self.policy.value(np.array(traj.next_obs))
337
+ rew = np.append(np.array(traj.rew), last_val)
338
+ v = np.append(np.array(traj.v), last_val)
339
+ deltas = rew[:-1] + self.gamma * v[1:] - v[:-1]
340
+ advantage.append(discounted_cumsum(deltas, self.gamma * self.gae_lambda))
341
+ return torch.as_tensor(
342
+ np.concatenate(advantage), dtype=torch.float32, device=self.device
343
+ )
344
+
345
+ def _compute_rtg_and_advantage(
346
+ self, trajectories: Sequence[PPOTrajectory]
347
+ ) -> RtgAdvantage:
348
+ rewards_to_go = []
349
+ advantages = []
350
+ for traj in trajectories:
351
+ last_val = 0
352
+ if not traj.terminated and traj.next_obs is not None:
353
+ last_val = self.policy.value(np.array(traj.next_obs))
354
+ rew = np.append(np.array(traj.rew), last_val)
355
+ v = np.append(np.array(traj.v), last_val)
356
+ deltas = rew[:-1] + self.gamma * v[1:] - v[:-1]
357
+ adv = discounted_cumsum(deltas, self.gamma * self.gae_lambda)
358
+ advantages.append(adv)
359
+ rewards_to_go.append(v[:-1] + adv)
360
+ return RtgAdvantage(
361
+ torch.as_tensor(
362
+ np.concatenate(rewards_to_go), dtype=torch.float32, device=self.device
363
+ ),
364
+ torch.as_tensor(
365
+ np.concatenate(advantages), dtype=torch.float32, device=self.device
366
+ ),
367
+ )
publish/markdown_format.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ import wandb.apis.public
4
+ import yaml
5
+
6
+ from collections import defaultdict
7
+ from dataclasses import dataclass, asdict
8
+ from typing import Any, Dict, Iterable, List, NamedTuple, Optional, TypeVar
9
+ from urllib.parse import urlparse
10
+
11
+ from runner.evaluate import Evaluation
12
+
13
+ EvaluationRowSelf = TypeVar("EvaluationRowSelf", bound="EvaluationRow")
14
+
15
+
16
+ @dataclass
17
+ class EvaluationRow:
18
+ algo: str
19
+ env: str
20
+ seed: Optional[int]
21
+ reward_mean: float
22
+ reward_std: float
23
+ eval_episodes: int
24
+ best: str
25
+ wandb_url: str
26
+
27
+ @staticmethod
28
+ def data_frame(rows: List[EvaluationRowSelf]) -> pd.DataFrame:
29
+ results = defaultdict(list)
30
+ for r in rows:
31
+ for k, v in asdict(r).items():
32
+ results[k].append(v)
33
+ return pd.DataFrame(results)
34
+
35
+
36
+ class EvalTableData(NamedTuple):
37
+ run: wandb.apis.public.Run
38
+ evaluation: Evaluation
39
+
40
+
41
+ def evaluation_table(table_data: Iterable[EvalTableData]) -> str:
42
+ best_stats = sorted(
43
+ [d.evaluation.stats for d in table_data], key=lambda r: r.score, reverse=True
44
+ )[0]
45
+ table_data = sorted(table_data, key=lambda d: d.evaluation.config.seed() or 0)
46
+ rows = [
47
+ EvaluationRow(
48
+ config.algo,
49
+ config.env_id,
50
+ config.seed(),
51
+ stats.score.mean,
52
+ stats.score.std,
53
+ len(stats),
54
+ "*" if stats == best_stats else "",
55
+ f"[wandb]({r.url})",
56
+ )
57
+ for (r, (_, stats, config)) in table_data
58
+ ]
59
+ df = EvaluationRow.data_frame(rows)
60
+ return df.to_markdown(index=False)
61
+
62
+
63
+ def github_project_link(github_url: str) -> str:
64
+ return f"[{urlparse(github_url).path}]({github_url})"
65
+
66
+
67
+ def header_section(algo: str, env: str, github_url: str, wandb_report_url: str) -> str:
68
+ algo_caps = algo.upper()
69
+ lines = [
70
+ f"# **{algo_caps}** Agent playing **{env}**",
71
+ f"This is a trained model of a **{algo_caps}** agent playing **{env}** using "
72
+ f"the {github_project_link(github_url)} repo.",
73
+ f"All models trained at this commit can be found at {wandb_report_url}.",
74
+ ]
75
+ return "\n\n".join(lines)
76
+
77
+
78
+ def github_tree_link(github_url: str, commit_hash: Optional[str]) -> str:
79
+ if not commit_hash:
80
+ return github_project_link(github_url)
81
+ return f"[{commit_hash[:7]}]({github_url}/tree/{commit_hash})"
82
+
83
+
84
+ def results_section(
85
+ table_data: List[EvalTableData], algo: str, github_url: str, commit_hash: str
86
+ ) -> str:
87
+ # type: ignore
88
+ lines = [
89
+ "## Training Results",
90
+ f"This model was trained from {len(table_data)} trainings of **{algo.upper()}** "
91
+ + "agents using different initial seeds. "
92
+ + f"These agents were trained by checking out "
93
+ + f"{github_tree_link(github_url, commit_hash)}. "
94
+ + "The best and last models were kept from each training. "
95
+ + "This submission has loaded the best models from each training, reevaluates "
96
+ + "them, and selects the best model from these latest evaluations (mean - std).",
97
+ ]
98
+ lines.append(evaluation_table(table_data))
99
+ return "\n\n".join(lines)
100
+
101
+
102
+ def prerequisites_section() -> str:
103
+ return """
104
+ ### Prerequisites: Weights & Biases (WandB)
105
+ Training and benchmarking assumes you have a Weights & Biases project to upload runs to.
106
+ By default training goes to a rl-algo-impls project while benchmarks go to
107
+ rl-algo-impls-benchmarks. During training and benchmarking runs, videos of the best
108
+ models and the model weights are uploaded to WandB.
109
+
110
+ Before doing anything below, you'll need to create a wandb account and run `wandb
111
+ login`.
112
+ """
113
+
114
+
115
+ def usage_section(github_url: str, run_path: str, commit_hash: str) -> str:
116
+ return f"""
117
+ ## Usage
118
+ {urlparse(github_url).path}: {github_url}
119
+
120
+ Note: While the model state dictionary and hyperaparameters are saved, the latest
121
+ implementation could be sufficiently different to not be able to reproduce similar
122
+ results. You might need to checkout the commit the agent was trained on:
123
+ {github_tree_link(github_url, commit_hash)}.
124
+ ```
125
+ # Downloads the model, sets hyperparameters, and runs agent for 3 episodes
126
+ python enjoy.py --wandb-run-path={run_path}
127
+ ```
128
+
129
+ Setup hasn't been completely worked out yet, so you might be best served by using Google
130
+ Colab starting from the
131
+ [colab_enjoy.ipynb](https://github.com/sgoodfriend/rl-algo-impls/blob/main/colab_enjoy.ipynb)
132
+ notebook.
133
+ """
134
+
135
+
136
+ def training_setion(
137
+ github_url: str, commit_hash: str, algo: str, env: str, seed: Optional[int]
138
+ ) -> str:
139
+ return f"""
140
+ ## Training
141
+ If you want the highest chance to reproduce these results, you'll want to checkout the
142
+ commit the agent was trained on: {github_tree_link(github_url, commit_hash)}. While
143
+ training is deterministic, different hardware will give different results.
144
+
145
+ ```
146
+ python train.py --algo {algo} --env {env} {'--seed ' + str(seed) if seed is not None else ''}
147
+ ```
148
+
149
+ Setup hasn't been completely worked out yet, so you might be best served by using Google
150
+ Colab starting from the
151
+ [colab_train.ipynb](https://github.com/sgoodfriend/rl-algo-impls/blob/main/colab_train.ipynb)
152
+ notebook.
153
+ """
154
+
155
+
156
+ def benchmarking_section(report_url: str) -> str:
157
+ return f"""
158
+ ## Benchmarking (with Lambda Labs instance)
159
+ This and other models from {report_url} were generated by running a script on a Lambda
160
+ Labs instance. In a Lambda Labs instance terminal:
161
+ ```
162
+ git clone [email protected]:sgoodfriend/rl-algo-impls.git
163
+ cd rl-algo-impls
164
+ bash ./lambda_labs/setup.sh
165
+ wandb login
166
+ bash ./lambda_labs/benchmark.sh
167
+ ```
168
+
169
+ ### Alternative: Google Colab Pro+
170
+ As an alternative,
171
+ [colab_benchmark.ipynb](https://github.com/sgoodfriend/rl-algo-impls/tree/main/benchmarks#:~:text=colab_benchmark.ipynb),
172
+ can be used. However, this requires a Google Colab Pro+ subscription and running across
173
+ 4 separate instances because otherwise running all jobs will exceed the 24-hour limit.
174
+ """
175
+
176
+
177
+ def hyperparams_section(run_config: Dict[str, Any]) -> str:
178
+ return f"""
179
+ ## Hyperparameters
180
+ This isn't exactly the format of hyperparams in {os.path.join("hyperparams",
181
+ run_config["algo"] + ".yml")}, but instead the Wandb Run Config. However, it's very
182
+ close and has some additional data:
183
+ ```
184
+ {yaml.dump(run_config)}
185
+ ```
186
+ """
187
+
188
+
189
+ def model_card_text(
190
+ algo: str,
191
+ env: str,
192
+ github_url: str,
193
+ commit_hash: str,
194
+ wandb_report_url: str,
195
+ table_data: List[EvalTableData],
196
+ best_eval: EvalTableData,
197
+ ) -> str:
198
+ run, (_, _, config) = best_eval
199
+ run_path = "/".join(run.path)
200
+ return "\n\n".join(
201
+ [
202
+ header_section(algo, env, github_url, wandb_report_url),
203
+ results_section(table_data, algo, github_url, commit_hash),
204
+ prerequisites_section(),
205
+ usage_section(github_url, run_path, commit_hash),
206
+ training_setion(github_url, commit_hash, algo, env, config.seed()),
207
+ benchmarking_section(wandb_report_url),
208
+ hyperparams_section(run.config),
209
+ ]
210
+ )
pyproject.toml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tool.poetry]
2
+ name = "rl-algo-impls"
3
+ version = "0.1.0"
4
+ description = "Implementations of reinforcement learning algorithms"
5
+ authors = ["Scott Goodfriend <[email protected]>"]
6
+ license = "MIT License"
7
+ readme = "README.md"
8
+ packages = [{include = "rl_algo_impls"}]
9
+
10
+ [tool.poetry.dependencies]
11
+ python = "~3.10"
12
+ "AutoROM.accept-rom-license" = "^0.4.2"
13
+ stable-baselines3 = {extras = ["extra"], version = "^1.7.0"}
14
+ scipy = "^1.10.0"
15
+ gym = {extras = ["box2d"], version = "^0.21.0"}
16
+ pyglet = "1.5.27"
17
+ PyYAML = "^6.0"
18
+ tensorboard = "^2.11.0"
19
+ pybullet = "^3.2.5"
20
+ wandb = "^0.13.9"
21
+ conda-lock = "^1.3.0"
22
+ torch-tb-profiler = "^0.4.1"
23
+ jupyter = "^1.0.0"
24
+ tabulate = "^0.9.0"
25
+ huggingface-hub = "^0.12.0"
26
+
27
+ [build-system]
28
+ requires = ["poetry-core"]
29
+ build-backend = "poetry.core.masonry.api"
replay.meta.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version 5.1.2 Copyright (c) 2000-2022 the FFmpeg developers\\nbuilt with clang version 14.0.6\\nconfiguration: --prefix=/Users/runner/miniforge3/conda-bld/ffmpeg_1671040513231/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_pl --cc=arm64-apple-darwin20.0.0-clang --cxx=arm64-apple-darwin20.0.0-clang++ --nm=arm64-apple-darwin20.0.0-nm --ar=arm64-apple-darwin20.0.0-ar --disable-doc --disable-openssl --enable-demuxer=dash --enable-hardcoded-tables --enable-libfreetype --enable-libfontconfig --enable-libopenh264 --enable-cross-compile --arch=arm64 --target-os=darwin --cross-prefix=arm64-apple-darwin20.0.0- --host-cc=/Users/runner/miniforge3/conda-bld/ffmpeg_1671040513231/_build_env/bin/x86_64-apple-darwin13.4.0-clang --enable-neon --enable-gnutls --enable-libmp3lame --enable-libvpx --enable-pthreads --enable-gpl --enable-libx264 --enable-libx265 --enable-libaom --enable-libsvtav1 --enable-libxml2 --enable-pic --enable-shared --disable-static --enable-version3 --enable-zlib --pkg-config=/Users/runner/miniforge3/conda-bld/ffmpeg_1671040513231/_build_env/bin/pkg-config\\nlibavutil 57. 28.100 / 57. 28.100\\nlibavcodec 59. 37.100 / 59. 37.100\\nlibavformat 59. 27.100 / 59. 27.100\\nlibavdevice 59. 7.100 / 59. 7.100\\nlibavfilter 8. 44.100 / 8. 44.100\\nlibswscale 6. 7.100 / 6. 7.100\\nlibswresample 4. 7.100 / 4. 7.100\\nlibpostproc 56. 6.100 / 56. 6.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-f", "rawvideo", "-s:v", "600x400", "-pix_fmt", "rgb24", "-framerate", "30", "-i", "-", "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "-r", "30", "/var/folders/9g/my5557_91xddp6lx00nkzly80000gn/T/tmpkxbc37m_/ppo-MountainCar-v0/replay.mp4"]}, "episode": {"r": -111.0, "l": 111, "t": 1.394619}}
replay.mp4 ADDED
Binary file (33.8 kB). View file
 
runner/config.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from datetime import datetime
4
+ from dataclasses import dataclass
5
+ from typing import Any, Dict, Optional, TypedDict, Union
6
+
7
+
8
+ @dataclass
9
+ class RunArgs:
10
+ algo: str
11
+ env: str
12
+ seed: Optional[int] = None
13
+ use_deterministic_algorithms: bool = True
14
+
15
+
16
+ class Hyperparams(TypedDict, total=False):
17
+ device: str
18
+ n_timesteps: Union[int, float]
19
+ env_hyperparams: Dict[str, Any]
20
+ policy_hyperparams: Dict[str, Any]
21
+ algo_hyperparams: Dict[str, Any]
22
+ eval_params: Dict[str, Any]
23
+
24
+
25
+ @dataclass
26
+ class Config:
27
+ args: RunArgs
28
+ hyperparams: Hyperparams
29
+ root_dir: str
30
+ run_id: str = datetime.now().isoformat()
31
+
32
+ def seed(self, training: bool = True) -> Optional[int]:
33
+ seed = self.args.seed
34
+ if training or seed is None:
35
+ return seed
36
+ return seed + self.env_hyperparams.get("n_envs", 1)
37
+
38
+ @property
39
+ def device(self) -> str:
40
+ return self.hyperparams.get("device", "auto")
41
+
42
+ @property
43
+ def n_timesteps(self) -> int:
44
+ return int(self.hyperparams.get("n_timesteps", 100_000))
45
+
46
+ @property
47
+ def env_hyperparams(self) -> Dict[str, Any]:
48
+ return self.hyperparams.get("env_hyperparams", {})
49
+
50
+ @property
51
+ def policy_hyperparams(self) -> Dict[str, Any]:
52
+ return self.hyperparams.get("policy_hyperparams", {})
53
+
54
+ @property
55
+ def algo_hyperparams(self) -> Dict[str, Any]:
56
+ return self.hyperparams.get("algo_hyperparams", {})
57
+
58
+ @property
59
+ def eval_params(self) -> Dict[str, Any]:
60
+ return self.hyperparams.get("eval_params", {})
61
+
62
+ @property
63
+ def algo(self) -> str:
64
+ return self.args.algo
65
+
66
+ @property
67
+ def env_id(self) -> str:
68
+ return self.args.env
69
+
70
+ def model_name(self, include_seed: bool = True) -> str:
71
+ parts = [self.algo, self.env_id]
72
+ if include_seed and self.args.seed is not None:
73
+ parts.append(f"S{self.args.seed}")
74
+ make_kwargs = self.env_hyperparams.get("make_kwargs", {})
75
+ if make_kwargs:
76
+ for k, v in make_kwargs.items():
77
+ if type(v) == bool and v:
78
+ parts.append(k)
79
+ elif type(v) == int and v:
80
+ parts.append(f"{k}{v}")
81
+ else:
82
+ parts.append(str(v))
83
+ return "-".join(parts)
84
+
85
+ @property
86
+ def run_name(self) -> str:
87
+ parts = [self.model_name(), self.run_id]
88
+ return "-".join(parts)
89
+
90
+ @property
91
+ def saved_models_dir(self) -> str:
92
+ return os.path.join(self.root_dir, "saved_models")
93
+
94
+ @property
95
+ def downloaded_models_dir(self) -> str:
96
+ return os.path.join(self.root_dir, "downloaded_models")
97
+
98
+ def model_dir_name(
99
+ self,
100
+ best: bool = False,
101
+ extension: str = "",
102
+ ) -> str:
103
+ return self.model_name() + ("-best" if best else "") + extension
104
+
105
+ def model_dir_path(self, best: bool = False, downloaded: bool = False) -> str:
106
+ return os.path.join(
107
+ self.saved_models_dir if not downloaded else self.downloaded_models_dir,
108
+ self.model_dir_name(best=best),
109
+ )
110
+
111
+ @property
112
+ def runs_dir(self) -> str:
113
+ return os.path.join(self.root_dir, "runs")
114
+
115
+ @property
116
+ def tensorboard_summary_path(self) -> str:
117
+ return os.path.join(self.runs_dir, self.run_name)
118
+
119
+ @property
120
+ def logs_path(self) -> str:
121
+ return os.path.join(self.runs_dir, f"log.yml")
122
+
123
+ @property
124
+ def videos_dir(self) -> str:
125
+ return os.path.join(self.root_dir, "videos")
126
+
127
+ @property
128
+ def video_prefix(self) -> str:
129
+ return os.path.join(self.videos_dir, self.model_name())
130
+
131
+ @property
132
+ def best_videos_dir(self) -> str:
133
+ return os.path.join(self.videos_dir, f"{self.model_name()}-best")
runner/env.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gym
2
+ import os
3
+
4
+ from gym.wrappers.resize_observation import ResizeObservation
5
+ from gym.wrappers.gray_scale_observation import GrayScaleObservation
6
+ from gym.wrappers.frame_stack import FrameStack
7
+ from stable_baselines3.common.atari_wrappers import (
8
+ MaxAndSkipEnv,
9
+ NoopResetEnv,
10
+ )
11
+ from stable_baselines3.common.vec_env.base_vec_env import VecEnv
12
+ from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
13
+ from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
14
+ from stable_baselines3.common.vec_env.vec_normalize import VecNormalize
15
+ from torch.utils.tensorboard.writer import SummaryWriter
16
+ from typing import Any, Callable, Dict, Optional, Union
17
+
18
+ from runner.config import Config
19
+ from shared.policy.policy import VEC_NORMALIZE_FILENAME
20
+ from wrappers.atari_wrappers import EpisodicLifeEnv, FireOnLifeStarttEnv, ClipRewardEnv
21
+ from wrappers.episode_record_video import EpisodeRecordVideo
22
+ from wrappers.episode_stats_writer import EpisodeStatsWriter
23
+ from wrappers.initial_step_truncate_wrapper import InitialStepTruncateWrapper
24
+ from wrappers.video_compat_wrapper import VideoCompatWrapper
25
+
26
+
27
+ def make_env(
28
+ config: Config,
29
+ training: bool = True,
30
+ render: bool = False,
31
+ normalize_load_path: Optional[str] = None,
32
+ n_envs: int = 1,
33
+ frame_stack: int = 1,
34
+ make_kwargs: Optional[Dict[str, Any]] = None,
35
+ no_reward_timeout_steps: Optional[int] = None,
36
+ no_reward_fire_steps: Optional[int] = None,
37
+ vec_env_class: str = "dummy",
38
+ normalize: bool = False,
39
+ normalize_kwargs: Optional[Dict[str, Any]] = None,
40
+ tb_writer: Optional[SummaryWriter] = None,
41
+ rolling_length: int = 100,
42
+ train_record_video: bool = False,
43
+ video_step_interval: Union[int, float] = 1_000_000,
44
+ initial_steps_to_truncate: Optional[int] = None,
45
+ ) -> VecEnv:
46
+ if "BulletEnv" in config.env_id:
47
+ import pybullet_envs
48
+
49
+ make_kwargs = make_kwargs if make_kwargs is not None else {}
50
+ if "BulletEnv" in config.env_id and render:
51
+ make_kwargs["render"] = True
52
+ if "CarRacing" in config.env_id:
53
+ make_kwargs["verbose"] = 0
54
+
55
+ spec = gym.spec(config.env_id)
56
+
57
+ def make(idx: int) -> Callable[[], gym.Env]:
58
+ def _make() -> gym.Env:
59
+ env = gym.make(config.env_id, **make_kwargs)
60
+ env = gym.wrappers.RecordEpisodeStatistics(env)
61
+ env = VideoCompatWrapper(env)
62
+ if training and train_record_video and idx == 0:
63
+ env = EpisodeRecordVideo(
64
+ env,
65
+ config.video_prefix,
66
+ step_increment=n_envs,
67
+ video_step_interval=int(video_step_interval),
68
+ )
69
+ if training and initial_steps_to_truncate:
70
+ env = InitialStepTruncateWrapper(
71
+ env, idx * initial_steps_to_truncate // n_envs
72
+ )
73
+ if "AtariEnv" in spec.entry_point: # type: ignore
74
+ env = NoopResetEnv(env, noop_max=30)
75
+ env = MaxAndSkipEnv(env, skip=4)
76
+ env = EpisodicLifeEnv(env, training=training)
77
+ action_meanings = env.unwrapped.get_action_meanings()
78
+ if "FIRE" in action_meanings: # type: ignore
79
+ env = FireOnLifeStarttEnv(env, action_meanings.index("FIRE"))
80
+ env = ClipRewardEnv(env, training=training)
81
+ env = ResizeObservation(env, (84, 84))
82
+ env = GrayScaleObservation(env, keep_dim=False)
83
+ env = FrameStack(env, frame_stack)
84
+ elif "CarRacing" in config.env_id:
85
+ env = ResizeObservation(env, (64, 64))
86
+ env = GrayScaleObservation(env, keep_dim=False)
87
+ env = FrameStack(env, frame_stack)
88
+
89
+ if no_reward_timeout_steps:
90
+ from wrappers.no_reward_timeout import NoRewardTimeout
91
+
92
+ env = NoRewardTimeout(
93
+ env, no_reward_timeout_steps, n_fire_steps=no_reward_fire_steps
94
+ )
95
+
96
+ seed = config.seed(training=training)
97
+ if seed is not None:
98
+ env.seed(seed + idx)
99
+ env.action_space.seed(seed + idx)
100
+ env.observation_space.seed(seed + idx)
101
+
102
+ return env
103
+
104
+ return _make
105
+
106
+ VecEnvClass = {"dummy": DummyVecEnv, "subproc": SubprocVecEnv}[vec_env_class]
107
+ venv = VecEnvClass([make(i) for i in range(n_envs)])
108
+ if training:
109
+ assert tb_writer
110
+ venv = EpisodeStatsWriter(
111
+ venv, tb_writer, training=training, rolling_length=rolling_length
112
+ )
113
+ if normalize:
114
+ if normalize_load_path:
115
+ venv = VecNormalize.load(
116
+ os.path.join(normalize_load_path, VEC_NORMALIZE_FILENAME), venv
117
+ )
118
+ else:
119
+ venv = VecNormalize(venv, training=training, **(normalize_kwargs or {}))
120
+ if not training:
121
+ venv.norm_reward = False
122
+ return venv
123
+
124
+
125
+ def make_eval_env(
126
+ config: Config, override_n_envs: Optional[int] = None, **kwargs
127
+ ) -> VecEnv:
128
+ kwargs = kwargs.copy()
129
+ kwargs["training"] = False
130
+ if override_n_envs is not None:
131
+ kwargs["n_envs"] = override_n_envs
132
+ if override_n_envs == 1:
133
+ kwargs["vec_env_class"] = "dummy"
134
+ return make_env(config, **kwargs)
runner/evaluate.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+
4
+ from dataclasses import dataclass
5
+ from typing import NamedTuple, Optional
6
+
7
+ from runner.env import make_eval_env
8
+ from runner.config import Config, RunArgs
9
+ from runner.running_utils import (
10
+ load_hyperparams,
11
+ set_seeds,
12
+ get_device,
13
+ make_policy,
14
+ )
15
+ from shared.callbacks.eval_callback import evaluate
16
+ from shared.policy.policy import Policy
17
+ from shared.stats import EpisodesStats
18
+
19
+
20
+ @dataclass
21
+ class EvalArgs(RunArgs):
22
+ render: bool = True
23
+ best: bool = True
24
+ n_envs: Optional[int] = 1
25
+ n_episodes: int = 3
26
+ deterministic_eval: Optional[bool] = None
27
+ no_print_returns: bool = False
28
+ wandb_run_path: Optional[str] = None
29
+
30
+
31
+ class Evaluation(NamedTuple):
32
+ policy: Policy
33
+ stats: EpisodesStats
34
+ config: Config
35
+
36
+
37
+ def evaluate_model(args: EvalArgs, root_dir: str) -> Evaluation:
38
+ if args.wandb_run_path:
39
+ import wandb
40
+
41
+ api = wandb.Api()
42
+ run = api.run(args.wandb_run_path)
43
+ hyperparams = run.config
44
+
45
+ args.algo = hyperparams["algo"]
46
+ args.env = hyperparams["env"]
47
+ args.seed = hyperparams.get("seed", None)
48
+ args.use_deterministic_algorithms = hyperparams.get(
49
+ "use_deterministic_algorithms", True
50
+ )
51
+
52
+ config = Config(args, hyperparams, root_dir)
53
+ model_path = config.model_dir_path(best=args.best, downloaded=True)
54
+
55
+ model_archive_name = config.model_dir_name(best=args.best, extension=".zip")
56
+ run.file(model_archive_name).download()
57
+ if os.path.isdir(model_path):
58
+ shutil.rmtree(model_path)
59
+ shutil.unpack_archive(model_archive_name, model_path)
60
+ os.remove(model_archive_name)
61
+ else:
62
+ hyperparams = load_hyperparams(args.algo, args.env, root_dir)
63
+
64
+ config = Config(args, hyperparams, root_dir)
65
+ model_path = config.model_dir_path(best=args.best)
66
+
67
+ print(args)
68
+
69
+ set_seeds(args.seed, args.use_deterministic_algorithms)
70
+
71
+ env = make_eval_env(
72
+ config,
73
+ override_n_envs=args.n_envs,
74
+ render=args.render,
75
+ normalize_load_path=model_path,
76
+ **config.env_hyperparams,
77
+ )
78
+ device = get_device(config.device, env)
79
+ policy = make_policy(
80
+ args.algo,
81
+ env,
82
+ device,
83
+ load_path=model_path,
84
+ **config.policy_hyperparams,
85
+ ).eval()
86
+
87
+ deterministic = (
88
+ args.deterministic_eval
89
+ if args.deterministic_eval is not None
90
+ else config.eval_params.get("deterministic", True)
91
+ )
92
+ return Evaluation(
93
+ policy,
94
+ evaluate(
95
+ env,
96
+ policy,
97
+ args.n_episodes,
98
+ render=args.render,
99
+ deterministic=deterministic,
100
+ print_returns=not args.no_print_returns,
101
+ ),
102
+ config,
103
+ )
runner/running_utils.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import gym
3
+ import json
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import os
7
+ import random
8
+ import torch
9
+ import torch.backends.cudnn
10
+ import yaml
11
+
12
+ from gym.spaces import Box, Discrete
13
+ from stable_baselines3.common.vec_env.base_vec_env import VecEnv
14
+ from torch.utils.tensorboard.writer import SummaryWriter
15
+ from typing import Dict, Optional, Type, Union
16
+
17
+ from runner.config import Hyperparams
18
+ from shared.algorithm import Algorithm
19
+ from shared.callbacks.eval_callback import EvalCallback
20
+ from shared.policy.policy import Policy
21
+
22
+ from dqn.dqn import DQN
23
+ from dqn.policy import DQNPolicy
24
+ from vpg.vpg import VanillaPolicyGradient
25
+ from vpg.policy import VPGActorCritic
26
+ from ppo.ppo import PPO
27
+ from ppo.policy import PPOActorCritic
28
+
29
+ ALGOS: Dict[str, Type[Algorithm]] = {
30
+ "dqn": DQN,
31
+ "vpg": VanillaPolicyGradient,
32
+ "ppo": PPO,
33
+ }
34
+ POLICIES: Dict[str, Type[Policy]] = {
35
+ "dqn": DQNPolicy,
36
+ "vpg": VPGActorCritic,
37
+ "ppo": PPOActorCritic,
38
+ }
39
+
40
+ HYPERPARAMS_PATH = "hyperparams"
41
+
42
+
43
+ def base_parser(multiple: bool = True) -> argparse.ArgumentParser:
44
+ parser = argparse.ArgumentParser()
45
+ parser.add_argument(
46
+ "--algo",
47
+ default=["dqn"],
48
+ type=str,
49
+ choices=list(ALGOS.keys()),
50
+ nargs="+" if multiple else 1,
51
+ help="Abbreviation(s) of algorithm(s)",
52
+ )
53
+ parser.add_argument(
54
+ "--env",
55
+ default=["CartPole-v1"],
56
+ type=str,
57
+ nargs="+" if multiple else 1,
58
+ help="Name of environment(s) in gym",
59
+ )
60
+ parser.add_argument(
61
+ "--seed",
62
+ default=[1],
63
+ type=int,
64
+ nargs="*" if multiple else "?",
65
+ help="Seeds to run experiment. Unset will do one run with no set seed",
66
+ )
67
+ parser.add_argument(
68
+ "--use-deterministic-algorithms",
69
+ default=True,
70
+ type=bool,
71
+ help="If seed set, set torch.use_deterministic_algorithms",
72
+ )
73
+ return parser
74
+
75
+
76
+ def load_hyperparams(algo: str, env_id: str, root_path: str) -> Hyperparams:
77
+ hyperparams_path = os.path.join(root_path, HYPERPARAMS_PATH, f"{algo}.yml")
78
+ with open(hyperparams_path, "r") as f:
79
+ hyperparams_dict = yaml.safe_load(f)
80
+ if "BulletEnv" in env_id:
81
+ import pybullet_envs
82
+ spec = gym.spec(env_id)
83
+ if env_id in hyperparams_dict:
84
+ return hyperparams_dict[env_id]
85
+ elif "AtariEnv" in str(spec.entry_point) and "atari" in hyperparams_dict:
86
+ return hyperparams_dict["atari"]
87
+ else:
88
+ raise ValueError(f"{env_id} not specified in {algo} hyperparameters file")
89
+
90
+
91
+ def get_device(device: str, env: VecEnv) -> torch.device:
92
+ # cuda by default
93
+ if device == "auto":
94
+ device = "cuda"
95
+ # Apple MPS is a second choice (sometimes)
96
+ if device == "cuda" and not torch.cuda.is_available():
97
+ device = "mps"
98
+ # If no MPS, fallback to cpu
99
+ if device == "mps" and not torch.backends.mps.is_available():
100
+ device = "cpu"
101
+ # Simple environments like Discreet and 1-D Boxes might also be better
102
+ # served with the CPU.
103
+ if device == "mps":
104
+ obs_space = env.observation_space
105
+ if isinstance(obs_space, Discrete):
106
+ device = "cpu"
107
+ elif isinstance(obs_space, Box) and len(obs_space.shape) == 1:
108
+ device = "cpu"
109
+ print(f"Device: {device}")
110
+ return torch.device(device)
111
+
112
+
113
+ def set_seeds(seed: Optional[int], use_deterministic_algorithms: bool) -> None:
114
+ if seed is None:
115
+ return
116
+ random.seed(seed)
117
+ np.random.seed(seed)
118
+ torch.manual_seed(seed)
119
+ torch.backends.cudnn.benchmark = False
120
+ torch.use_deterministic_algorithms(use_deterministic_algorithms)
121
+ os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
122
+
123
+
124
+ def make_policy(
125
+ algo: str,
126
+ env: VecEnv,
127
+ device: torch.device,
128
+ load_path: Optional[str] = None,
129
+ **kwargs,
130
+ ) -> Policy:
131
+ policy = POLICIES[algo](env, **kwargs).to(device)
132
+ if load_path:
133
+ policy.load(load_path)
134
+ return policy
135
+
136
+
137
+ def plot_eval_callback(callback: EvalCallback, tb_writer: SummaryWriter, run_name: str):
138
+ figure = plt.figure()
139
+ cumulative_steps = [
140
+ (idx + 1) * callback.step_freq for idx in range(len(callback.stats))
141
+ ]
142
+ plt.plot(
143
+ cumulative_steps,
144
+ [s.score.mean for s in callback.stats],
145
+ "b-",
146
+ label="mean",
147
+ )
148
+ plt.plot(
149
+ cumulative_steps,
150
+ [s.score.mean - s.score.std for s in callback.stats],
151
+ "g--",
152
+ label="mean-std",
153
+ )
154
+ plt.fill_between(
155
+ cumulative_steps,
156
+ [s.score.min for s in callback.stats], # type: ignore
157
+ [s.score.max for s in callback.stats], # type: ignore
158
+ facecolor="cyan",
159
+ label="range",
160
+ )
161
+ plt.xlabel("Steps")
162
+ plt.ylabel("Score")
163
+ plt.legend()
164
+ plt.title(f"Eval {run_name}")
165
+ tb_writer.add_figure("eval", figure)
166
+
167
+
168
+ Scalar = Union[bool, str, float, int, None]
169
+
170
+
171
+ def flatten_hyperparameters(
172
+ hyperparams: Hyperparams, args: Dict[str, Union[Scalar, list]]
173
+ ) -> Dict[str, Scalar]:
174
+ flattened = args.copy()
175
+ for k, v in flattened.items():
176
+ if isinstance(v, list):
177
+ flattened[k] = json.dumps(v)
178
+ for k, v in hyperparams.items():
179
+ if isinstance(v, dict):
180
+ for sk, sv in v.items():
181
+ key = f"{k}/{sk}"
182
+ if isinstance(sv, dict) or isinstance(sv, list):
183
+ flattened[key] = str(sv)
184
+ else:
185
+ flattened[key] = sv
186
+ else:
187
+ flattened[k] = v # type: ignore
188
+ return flattened # type: ignore
runner/train.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Support for PyTorch mps mode (https://pytorch.org/docs/stable/notes/mps.html)
2
+ import os
3
+
4
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
5
+
6
+ import dataclasses
7
+ import shutil
8
+ import wandb
9
+ import yaml
10
+
11
+ from dataclasses import dataclass
12
+ from torch.utils.tensorboard.writer import SummaryWriter
13
+ from typing import Any, Dict, Optional, Sequence
14
+
15
+ from shared.callbacks.eval_callback import EvalCallback
16
+ from runner.env import make_env, make_eval_env
17
+ from runner.config import Config, RunArgs
18
+ from runner.running_utils import (
19
+ ALGOS,
20
+ load_hyperparams,
21
+ set_seeds,
22
+ get_device,
23
+ make_policy,
24
+ plot_eval_callback,
25
+ flatten_hyperparameters,
26
+ )
27
+ from shared.stats import EpisodesStats
28
+
29
+
30
+ @dataclass
31
+ class TrainArgs(RunArgs):
32
+ wandb_project_name: Optional[str] = None
33
+ wandb_entity: Optional[str] = None
34
+ wandb_tags: Sequence[str] = dataclasses.field(default_factory=list)
35
+
36
+
37
+ def train(args: TrainArgs):
38
+ print(args)
39
+ hyperparams = load_hyperparams(args.algo, args.env, os.getcwd())
40
+ print(hyperparams)
41
+ config = Config(args, hyperparams, os.getcwd())
42
+
43
+ wandb_enabled = args.wandb_project_name
44
+ if wandb_enabled:
45
+ wandb.tensorboard.patch(
46
+ root_logdir=config.tensorboard_summary_path, pytorch=True
47
+ )
48
+ wandb.init(
49
+ project=args.wandb_project_name,
50
+ entity=args.wandb_entity,
51
+ config=hyperparams, # type: ignore
52
+ name=config.run_name,
53
+ monitor_gym=True,
54
+ save_code=True,
55
+ tags=args.wandb_tags,
56
+ )
57
+ wandb.config.update(args)
58
+
59
+ tb_writer = SummaryWriter(config.tensorboard_summary_path)
60
+
61
+ set_seeds(args.seed, args.use_deterministic_algorithms)
62
+
63
+ env = make_env(config, tb_writer=tb_writer, **config.env_hyperparams)
64
+ device = get_device(config.device, env)
65
+ policy = make_policy(args.algo, env, device, **config.policy_hyperparams)
66
+ algo = ALGOS[args.algo](policy, env, device, tb_writer, **config.algo_hyperparams)
67
+
68
+ eval_env = make_eval_env(config, **config.env_hyperparams)
69
+ record_best_videos = config.eval_params.get("record_best_videos", True)
70
+ callback = EvalCallback(
71
+ policy,
72
+ eval_env,
73
+ tb_writer,
74
+ best_model_path=config.model_dir_path(best=True),
75
+ **config.eval_params,
76
+ video_env=make_eval_env(config, override_n_envs=1, **config.env_hyperparams)
77
+ if record_best_videos
78
+ else None,
79
+ best_video_dir=config.best_videos_dir,
80
+ )
81
+ algo.learn(config.n_timesteps, callback=callback)
82
+
83
+ policy.save(config.model_dir_path(best=False))
84
+
85
+ eval_stats = callback.evaluate(n_episodes=10, print_returns=True)
86
+
87
+ plot_eval_callback(callback, tb_writer, config.run_name)
88
+
89
+ log_dict: Dict[str, Any] = {
90
+ "eval": eval_stats._asdict(),
91
+ }
92
+ if callback.best:
93
+ log_dict["best_eval"] = callback.best._asdict()
94
+ log_dict.update(hyperparams)
95
+ log_dict.update(vars(args))
96
+ with open(config.logs_path, "a") as f:
97
+ yaml.dump({config.run_name: log_dict}, f)
98
+
99
+ best_eval_stats: EpisodesStats = callback.best # type: ignore
100
+ tb_writer.add_hparams(
101
+ flatten_hyperparameters(hyperparams, vars(args)),
102
+ {
103
+ "hparam/best_mean": best_eval_stats.score.mean,
104
+ "hparam/best_result": best_eval_stats.score.mean
105
+ - best_eval_stats.score.std,
106
+ "hparam/last_mean": eval_stats.score.mean,
107
+ "hparam/last_result": eval_stats.score.mean - eval_stats.score.std,
108
+ },
109
+ None,
110
+ config.run_name,
111
+ )
112
+
113
+ tb_writer.close()
114
+
115
+ if wandb_enabled:
116
+ shutil.make_archive(
117
+ os.path.join(wandb.run.dir, config.model_dir_name()),
118
+ "zip",
119
+ config.model_dir_path(),
120
+ )
121
+ shutil.make_archive(
122
+ os.path.join(wandb.run.dir, config.model_dir_name(best=True)),
123
+ "zip",
124
+ config.model_dir_path(best=True),
125
+ )
126
+ wandb.finish()
saved_models/ppo-MountainCar-v0-S6-best/model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f3c581e0366f696864c5708c9640cdac1449335c75e277dbaef88188c4d1affb
3
+ size 39461
saved_models/ppo-MountainCar-v0-S6-best/vecnormalize.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6a8b4d4c4f7920087254434562e1d8044da94c432aea18b4ec95b7f658fa695a
3
+ size 6650
shared/algorithm.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gym
2
+ import torch
3
+
4
+ from abc import ABC, abstractmethod
5
+ from stable_baselines3.common.vec_env.base_vec_env import VecEnv
6
+ from torch.utils.tensorboard.writer import SummaryWriter
7
+ from typing import List, Optional, TypeVar
8
+
9
+ from shared.callbacks.callback import Callback
10
+ from shared.policy.policy import Policy
11
+ from shared.stats import EpisodesStats
12
+
13
+ AlgorithmSelf = TypeVar("AlgorithmSelf", bound="Algorithm")
14
+
15
+ class Algorithm(ABC):
16
+ @abstractmethod
17
+ def __init__(
18
+ self,
19
+ policy: Policy,
20
+ env: VecEnv,
21
+ device: torch.device,
22
+ tb_writer: SummaryWriter,
23
+ **kwargs,
24
+ ) -> None:
25
+ super().__init__()
26
+ self.policy = policy
27
+ self.env = env
28
+ self.device = device
29
+ self.tb_writer = tb_writer
30
+
31
+ @abstractmethod
32
+ def learn(
33
+ self: AlgorithmSelf, total_timesteps: int, callback: Optional[Callback] = None
34
+ ) -> AlgorithmSelf:
35
+ ...
shared/callbacks/callback.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+
3
+
4
+ class Callback(ABC):
5
+
6
+ def __init__(self) -> None:
7
+ super().__init__()
8
+ self.timesteps_elapsed = 0
9
+
10
+ def on_step(self, timesteps_elapsed: int = 1) -> bool:
11
+ self.timesteps_elapsed += timesteps_elapsed
12
+ return True
shared/callbacks/eval_callback.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ import numpy as np
3
+ import os
4
+
5
+ from copy import deepcopy
6
+ from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper
7
+ from stable_baselines3.common.vec_env.vec_normalize import VecNormalize
8
+ from torch.utils.tensorboard.writer import SummaryWriter
9
+ from typing import List, Optional, Union
10
+
11
+ from shared.callbacks.callback import Callback
12
+ from shared.policy.policy import Policy
13
+ from shared.stats import Episode, EpisodeAccumulator, EpisodesStats
14
+ from wrappers.vec_episode_recorder import VecEpisodeRecorder
15
+
16
+
17
+ class EvaluateAccumulator(EpisodeAccumulator):
18
+ def __init__(self, num_envs: int, goal_episodes: int, print_returns: bool = True):
19
+ super().__init__(num_envs)
20
+ self.completed_episodes_by_env_idx = [[] for _ in range(num_envs)]
21
+ self.goal_episodes_per_env = int(np.ceil(goal_episodes / num_envs))
22
+ self.print_returns = print_returns
23
+
24
+ def on_done(self, ep_idx: int, episode: Episode) -> None:
25
+ if (
26
+ len(self.completed_episodes_by_env_idx[ep_idx])
27
+ >= self.goal_episodes_per_env
28
+ ):
29
+ return
30
+ self.completed_episodes_by_env_idx[ep_idx].append(episode)
31
+ if self.print_returns:
32
+ print(
33
+ f"Episode {len(self)} | "
34
+ f"Score {episode.score} | "
35
+ f"Length {episode.length}"
36
+ )
37
+
38
+ def __len__(self) -> int:
39
+ return sum(len(ce) for ce in self.completed_episodes_by_env_idx)
40
+
41
+ @property
42
+ def episodes(self) -> List[Episode]:
43
+ return list(itertools.chain(*self.completed_episodes_by_env_idx))
44
+
45
+ def is_done(self) -> bool:
46
+ return all(
47
+ len(ce) == self.goal_episodes_per_env
48
+ for ce in self.completed_episodes_by_env_idx
49
+ )
50
+
51
+
52
+ def evaluate(
53
+ env: VecEnv,
54
+ policy: Policy,
55
+ n_episodes: int,
56
+ render: bool = False,
57
+ deterministic: bool = True,
58
+ print_returns: bool = True,
59
+ ) -> EpisodesStats:
60
+ policy.eval()
61
+ episodes = EvaluateAccumulator(env.num_envs, n_episodes, print_returns)
62
+
63
+ obs = env.reset()
64
+ while not episodes.is_done():
65
+ act = policy.act(obs, deterministic=deterministic)
66
+ obs, rew, done, _ = env.step(act)
67
+ episodes.step(rew, done)
68
+ if render:
69
+ env.render()
70
+ stats = EpisodesStats(episodes.episodes)
71
+ if print_returns:
72
+ print(stats)
73
+ return stats
74
+
75
+
76
+ class EvalCallback(Callback):
77
+ def __init__(
78
+ self,
79
+ policy: Policy,
80
+ env: VecEnv,
81
+ tb_writer: SummaryWriter,
82
+ best_model_path: Optional[str] = None,
83
+ step_freq: Union[int, float] = 50_000,
84
+ n_episodes: int = 10,
85
+ save_best: bool = True,
86
+ deterministic: bool = True,
87
+ record_best_videos: bool = True,
88
+ video_env: Optional[VecEnv] = None,
89
+ best_video_dir: Optional[str] = None,
90
+ max_video_length: int = 3600,
91
+ ) -> None:
92
+ super().__init__()
93
+ self.policy = policy
94
+ self.env = env
95
+ self.tb_writer = tb_writer
96
+ self.best_model_path = best_model_path
97
+ self.step_freq = int(step_freq)
98
+ self.n_episodes = n_episodes
99
+ self.save_best = save_best
100
+ self.deterministic = deterministic
101
+ self.stats: List[EpisodesStats] = []
102
+ self.best = None
103
+
104
+ self.record_best_videos = record_best_videos
105
+ assert video_env or not record_best_videos
106
+ self.video_env = video_env
107
+ assert best_video_dir or not record_best_videos
108
+ self.best_video_dir = best_video_dir
109
+ if best_video_dir:
110
+ os.makedirs(best_video_dir, exist_ok=True)
111
+ self.max_video_length = max_video_length
112
+ self.best_video_base_path = None
113
+
114
+ def on_step(self, timesteps_elapsed: int = 1) -> bool:
115
+ super().on_step(timesteps_elapsed)
116
+ if self.timesteps_elapsed // self.step_freq >= len(self.stats):
117
+ sync_vec_normalize(self.policy.vec_normalize, self.env)
118
+ self.evaluate()
119
+ return True
120
+
121
+ def evaluate(
122
+ self, n_episodes: Optional[int] = None, print_returns: Optional[bool] = None
123
+ ) -> EpisodesStats:
124
+ eval_stat = evaluate(
125
+ self.env,
126
+ self.policy,
127
+ n_episodes or self.n_episodes,
128
+ deterministic=self.deterministic,
129
+ print_returns=print_returns or False,
130
+ )
131
+ self.policy.train(True)
132
+ print(f"Eval Timesteps: {self.timesteps_elapsed} | {eval_stat}")
133
+
134
+ self.stats.append(eval_stat)
135
+
136
+ if not self.best or eval_stat >= self.best:
137
+ strictly_better = not self.best or eval_stat > self.best
138
+ self.best = eval_stat
139
+ if self.save_best:
140
+ assert self.best_model_path
141
+ self.policy.save(self.best_model_path)
142
+ print("Saved best model")
143
+ self.best.write_to_tensorboard(
144
+ self.tb_writer, "best_eval", self.timesteps_elapsed
145
+ )
146
+ if strictly_better and self.record_best_videos:
147
+ assert self.video_env and self.best_video_dir
148
+ sync_vec_normalize(self.policy.vec_normalize, self.video_env)
149
+ self.best_video_base_path = os.path.join(
150
+ self.best_video_dir, str(self.timesteps_elapsed)
151
+ )
152
+ video_wrapped = VecEpisodeRecorder(
153
+ self.video_env,
154
+ self.best_video_base_path,
155
+ max_video_length=self.max_video_length,
156
+ )
157
+ video_stats = evaluate(
158
+ video_wrapped,
159
+ self.policy,
160
+ 1,
161
+ deterministic=self.deterministic,
162
+ print_returns=False,
163
+ )
164
+ print(f"Saved best video: {video_stats}")
165
+
166
+ eval_stat.write_to_tensorboard(self.tb_writer, "eval", self.timesteps_elapsed)
167
+
168
+ return eval_stat
169
+
170
+
171
+ def sync_vec_normalize(
172
+ origin_vec_normalize: Optional[VecNormalize], destination_env: VecEnv
173
+ ) -> None:
174
+ if origin_vec_normalize is not None:
175
+ eval_env_wrapper = destination_env
176
+ while isinstance(eval_env_wrapper, VecEnvWrapper):
177
+ if isinstance(eval_env_wrapper, VecNormalize):
178
+ if hasattr(origin_vec_normalize, "obs_rms"):
179
+ eval_env_wrapper.obs_rms = deepcopy(origin_vec_normalize.obs_rms)
180
+ eval_env_wrapper.ret_rms = deepcopy(origin_vec_normalize.ret_rms)
181
+ eval_env_wrapper = eval_env_wrapper.venv
shared/module.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gym
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from gym.spaces import Box, Discrete
8
+ from stable_baselines3.common.preprocessing import get_flattened_obs_dim
9
+ from typing import Sequence, Type
10
+
11
+
12
+ class FeatureExtractor(nn.Module):
13
+ def __init__(
14
+ self,
15
+ obs_space: gym.Space,
16
+ activation: Type[nn.Module],
17
+ init_layers_orthogonal: bool = False,
18
+ cnn_feature_dim: int = 512,
19
+ ) -> None:
20
+ super().__init__()
21
+ if isinstance(obs_space, Box):
22
+ # Conv2D: (channels, height, width)
23
+ if len(obs_space.shape) == 3:
24
+ # CNN from DQN Nature paper: Mnih, Volodymyr, et al.
25
+ # "Human-level control through deep reinforcement learning."
26
+ # Nature 518.7540 (2015): 529-533.
27
+ cnn = nn.Sequential(
28
+ layer_init(
29
+ nn.Conv2d(obs_space.shape[0], 32, kernel_size=8, stride=4),
30
+ init_layers_orthogonal,
31
+ ),
32
+ activation(),
33
+ layer_init(
34
+ nn.Conv2d(32, 64, kernel_size=4, stride=2),
35
+ init_layers_orthogonal,
36
+ ),
37
+ activation(),
38
+ layer_init(
39
+ nn.Conv2d(64, 64, kernel_size=3, stride=1),
40
+ init_layers_orthogonal,
41
+ ),
42
+ activation(),
43
+ nn.Flatten(),
44
+ )
45
+
46
+ def preprocess(obs: torch.Tensor) -> torch.Tensor:
47
+ if len(obs.shape) == 3:
48
+ obs = obs.unsqueeze(0)
49
+ return obs.float() / 255.0
50
+
51
+ with torch.no_grad():
52
+ cnn_out = cnn(preprocess(torch.as_tensor(obs_space.sample())))
53
+ self.preprocess = preprocess
54
+ self.feature_extractor = nn.Sequential(
55
+ cnn,
56
+ layer_init(
57
+ nn.Linear(cnn_out.shape[1], cnn_feature_dim),
58
+ init_layers_orthogonal,
59
+ ),
60
+ activation(),
61
+ )
62
+ self.out_dim = cnn_feature_dim
63
+ elif len(obs_space.shape) == 1:
64
+
65
+ def preprocess(obs: torch.Tensor) -> torch.Tensor:
66
+ if len(obs.shape) == 1:
67
+ obs = obs.unsqueeze(0)
68
+ return obs.float()
69
+
70
+ self.preprocess = preprocess
71
+ self.feature_extractor = nn.Flatten()
72
+ self.out_dim = get_flattened_obs_dim(obs_space)
73
+ else:
74
+ raise ValueError(f"Unsupported observation space: {obs_space}")
75
+ elif isinstance(obs_space, Discrete):
76
+ self.preprocess = lambda x: F.one_hot(x, obs_space.n).float()
77
+ self.feature_extractor = nn.Flatten()
78
+ self.out_dim = obs_space.n
79
+ else:
80
+ raise NotImplementedError
81
+
82
+ def forward(self, obs: torch.Tensor) -> torch.Tensor:
83
+ if self.preprocess:
84
+ obs = self.preprocess(obs)
85
+ return self.feature_extractor(obs)
86
+
87
+
88
+ def mlp(
89
+ layer_sizes: Sequence[int],
90
+ activation: Type[nn.Module],
91
+ output_activation: Type[nn.Module] = nn.Identity,
92
+ init_layers_orthogonal: bool = False,
93
+ final_layer_gain: float = np.sqrt(2),
94
+ ) -> nn.Module:
95
+ layers = []
96
+ for i in range(len(layer_sizes) - 2):
97
+ layers.append(
98
+ layer_init(
99
+ nn.Linear(layer_sizes[i], layer_sizes[i + 1]), init_layers_orthogonal
100
+ )
101
+ )
102
+ layers.append(activation())
103
+ layers.append(
104
+ layer_init(
105
+ nn.Linear(layer_sizes[-2], layer_sizes[-1]),
106
+ init_layers_orthogonal,
107
+ std=final_layer_gain,
108
+ )
109
+ )
110
+ layers.append(output_activation())
111
+ return nn.Sequential(*layers)
112
+
113
+
114
+ def layer_init(
115
+ layer: nn.Module, init_layers_orthogonal: bool, std: float = np.sqrt(2)
116
+ ) -> nn.Module:
117
+ if not init_layers_orthogonal:
118
+ return layer
119
+ nn.init.orthogonal_(layer.weight, std) # type: ignore
120
+ nn.init.constant_(layer.bias, 0.0) # type: ignore
121
+ return layer
shared/policy/actor.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gym
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from abc import ABC, abstractmethod
6
+ from gym.spaces import Box, Discrete
7
+ from torch.distributions import Categorical, Distribution, Normal
8
+ from typing import NamedTuple, Optional, Sequence, Type, TypeVar, Union
9
+
10
+ from shared.module import FeatureExtractor, mlp
11
+
12
+
13
+ class PiForward(NamedTuple):
14
+ pi: Distribution
15
+ logp_a: Optional[torch.Tensor]
16
+ entropy: Optional[torch.Tensor]
17
+
18
+
19
+ class Actor(nn.Module, ABC):
20
+ @abstractmethod
21
+ def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward:
22
+ ...
23
+
24
+
25
+ class CategoricalActorHead(Actor):
26
+ def __init__(
27
+ self,
28
+ act_dim: int,
29
+ hidden_sizes: Sequence[int] = (32,),
30
+ activation: Type[nn.Module] = nn.Tanh,
31
+ init_layers_orthogonal: bool = True,
32
+ ) -> None:
33
+ super().__init__()
34
+ layer_sizes = tuple(hidden_sizes) + (act_dim,)
35
+ self._fc = mlp(
36
+ layer_sizes,
37
+ activation,
38
+ init_layers_orthogonal=init_layers_orthogonal,
39
+ final_layer_gain=0.01,
40
+ )
41
+
42
+ def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward:
43
+ logits = self._fc(obs)
44
+ pi = Categorical(logits=logits)
45
+ logp_a = None
46
+ entropy = None
47
+ if a is not None:
48
+ logp_a = pi.log_prob(a)
49
+ entropy = pi.entropy()
50
+ return PiForward(pi, logp_a, entropy)
51
+
52
+
53
+ class GaussianDistribution(Normal):
54
+ def log_prob(self, a: torch.Tensor) -> torch.Tensor:
55
+ return super().log_prob(a).sum(axis=-1)
56
+
57
+ def sample(self) -> torch.Tensor:
58
+ return self.rsample()
59
+
60
+
61
+ class GaussianActorHead(Actor):
62
+ def __init__(
63
+ self,
64
+ act_dim: int,
65
+ hidden_sizes: Sequence[int] = (32,),
66
+ activation: Type[nn.Module] = nn.Tanh,
67
+ init_layers_orthogonal: bool = True,
68
+ log_std_init: float = -0.5,
69
+ ) -> None:
70
+ super().__init__()
71
+ layer_sizes = tuple(hidden_sizes) + (act_dim,)
72
+ self.mu_net = mlp(
73
+ layer_sizes,
74
+ activation,
75
+ init_layers_orthogonal=init_layers_orthogonal,
76
+ final_layer_gain=0.01,
77
+ )
78
+ self.log_std = nn.Parameter(
79
+ torch.ones(act_dim, dtype=torch.float32) * log_std_init
80
+ )
81
+
82
+ def _distribution(self, obs: torch.Tensor) -> Distribution:
83
+ mu = self.mu_net(obs)
84
+ std = torch.exp(self.log_std)
85
+ return GaussianDistribution(mu, std)
86
+
87
+ def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward:
88
+ pi = self._distribution(obs)
89
+ logp_a = None
90
+ entropy = None
91
+ if a is not None:
92
+ logp_a = pi.log_prob(a)
93
+ entropy = pi.entropy()
94
+ return PiForward(pi, logp_a, entropy)
95
+
96
+
97
+ class TanhBijector:
98
+ def __init__(self, epsilon: float = 1e-6) -> None:
99
+ self.epsilon = epsilon
100
+
101
+ @staticmethod
102
+ def forward(x: torch.Tensor) -> torch.Tensor:
103
+ return torch.tanh(x)
104
+
105
+ @staticmethod
106
+ def inverse(y: torch.Tensor) -> torch.Tensor:
107
+ eps = torch.finfo(y.dtype).eps
108
+ clamped_y = y.clamp(min=-1.0 + eps, max=1.0 - eps)
109
+ return torch.atanh(clamped_y)
110
+
111
+ def log_prob_correction(self, x: torch.Tensor) -> torch.Tensor:
112
+ return torch.log(1.0 - torch.tanh(x) ** 2 + self.epsilon)
113
+
114
+
115
+ class StateDependentNoiseDistribution(Normal):
116
+ def __init__(
117
+ self,
118
+ loc,
119
+ scale,
120
+ latent_sde: torch.Tensor,
121
+ exploration_mat: torch.Tensor,
122
+ exploration_matrices: torch.Tensor,
123
+ bijector: Optional[TanhBijector] = None,
124
+ validate_args=None,
125
+ ):
126
+ super().__init__(loc, scale, validate_args)
127
+ self.latent_sde = latent_sde
128
+ self.exploration_mat = exploration_mat
129
+ self.exploration_matrices = exploration_matrices
130
+ self.bijector = bijector
131
+
132
+ def log_prob(self, a: torch.Tensor) -> torch.Tensor:
133
+ gaussian_a = self.bijector.inverse(a) if self.bijector else a
134
+ log_prob = super().log_prob(gaussian_a).sum(axis=-1)
135
+ if self.bijector:
136
+ log_prob -= torch.sum(self.bijector.log_prob_correction(gaussian_a), dim=1)
137
+ return log_prob
138
+
139
+ def sample(self) -> torch.Tensor:
140
+ noise = self._get_noise()
141
+ actions = self.mean + noise
142
+ return self.bijector.forward(actions) if self.bijector else actions
143
+
144
+ def _get_noise(self) -> torch.Tensor:
145
+ if len(self.latent_sde) == 1 or len(self.latent_sde) != len(
146
+ self.exploration_matrices
147
+ ):
148
+ return torch.mm(self.latent_sde, self.exploration_mat)
149
+ # (batch_size, n_features) -> (batch_size, 1, n_features)
150
+ latent_sde = self.latent_sde.unsqueeze(dim=1)
151
+ # (batch_size, 1, n_actions)
152
+ noise = torch.bmm(latent_sde, self.exploration_matrices)
153
+ return noise.squeeze(dim=1)
154
+
155
+ @property
156
+ def mode(self) -> torch.Tensor:
157
+ mean = super().mode
158
+ return self.bijector.forward(mean) if self.bijector else mean
159
+
160
+
161
+ StateDependentNoiseActorHeadSelf = TypeVar(
162
+ "StateDependentNoiseActorHeadSelf", bound="StateDependentNoiseActorHead"
163
+ )
164
+
165
+
166
+ class StateDependentNoiseActorHead(Actor):
167
+ def __init__(
168
+ self,
169
+ act_dim: int,
170
+ hidden_sizes: Sequence[int] = (32,),
171
+ activation: Type[nn.Module] = nn.Tanh,
172
+ init_layers_orthogonal: bool = True,
173
+ log_std_init: float = -0.5,
174
+ full_std: bool = True,
175
+ squash_output: bool = False,
176
+ learn_std: bool = False,
177
+ ) -> None:
178
+ super().__init__()
179
+ self.act_dim = act_dim
180
+ layer_sizes = tuple(hidden_sizes) + (self.act_dim,)
181
+ if len(layer_sizes) == 2:
182
+ self.latent_net = nn.Identity()
183
+ elif len(layer_sizes) > 2:
184
+ self.latent_net = mlp(
185
+ layer_sizes[:-1],
186
+ activation,
187
+ output_activation=activation,
188
+ init_layers_orthogonal=init_layers_orthogonal,
189
+ )
190
+ else:
191
+ raise ValueError("hidden_sizes must be of at least length 1")
192
+ self.mu_net = mlp(
193
+ layer_sizes[-2:],
194
+ activation,
195
+ init_layers_orthogonal=init_layers_orthogonal,
196
+ final_layer_gain=0.01,
197
+ )
198
+ self.full_std = full_std
199
+ std_dim = (hidden_sizes[-1], act_dim if self.full_std else 1)
200
+ self.log_std = nn.Parameter(
201
+ torch.ones(std_dim, dtype=torch.float32) * log_std_init
202
+ )
203
+ self.bijector = TanhBijector() if squash_output else None
204
+ self.learn_std = learn_std
205
+ self.device = None
206
+
207
+ self.exploration_mat = None
208
+ self.exploration_matrices = None
209
+ self.sample_weights()
210
+
211
+ def to(
212
+ self: StateDependentNoiseActorHeadSelf,
213
+ device: Optional[torch.device] = None,
214
+ dtype: Optional[Union[torch.dtype, str]] = None,
215
+ non_blocking: bool = False,
216
+ ) -> StateDependentNoiseActorHeadSelf:
217
+ super().to(device, dtype, non_blocking)
218
+ self.device = device
219
+ return self
220
+
221
+ def _distribution(self, obs: torch.Tensor) -> Distribution:
222
+ latent = self.latent_net(obs)
223
+ mu = self.mu_net(latent)
224
+ latent_sde = latent if self.learn_std else latent.detach()
225
+ variance = torch.mm(latent_sde**2, self._get_std() ** 2)
226
+ assert self.exploration_mat is not None
227
+ assert self.exploration_matrices is not None
228
+ return StateDependentNoiseDistribution(
229
+ mu,
230
+ torch.sqrt(variance + 1e-6),
231
+ latent_sde,
232
+ self.exploration_mat,
233
+ self.exploration_matrices,
234
+ self.bijector,
235
+ )
236
+
237
+ def _get_std(self) -> torch.Tensor:
238
+ std = torch.exp(self.log_std)
239
+ if self.full_std:
240
+ return std
241
+ ones = torch.ones(self.log_std.shape[0], self.act_dim)
242
+ if self.device:
243
+ ones = ones.to(self.device)
244
+ return ones * std
245
+
246
+ def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward:
247
+ pi = self._distribution(obs)
248
+ logp_a = None
249
+ entropy = None
250
+ if a is not None:
251
+ logp_a = pi.log_prob(a)
252
+ entropy = -logp_a
253
+ return PiForward(pi, logp_a, entropy)
254
+
255
+ def sample_weights(self, batch_size: int = 1) -> None:
256
+ std = self._get_std()
257
+ weights_dist = Normal(torch.zeros_like(std), std)
258
+ # Reparametrization trick to pass gradients
259
+ self.exploration_mat = weights_dist.rsample()
260
+ self.exploration_matrices = weights_dist.rsample(torch.Size((batch_size,)))
261
+
262
+
263
+ def actor_head(
264
+ action_space: gym.Space,
265
+ hidden_sizes: Sequence[int],
266
+ init_layers_orthogonal: bool,
267
+ activation: Type[nn.Module],
268
+ log_std_init: float = -0.5,
269
+ use_sde: bool = False,
270
+ full_std: bool = True,
271
+ squash_output: bool = False,
272
+ ) -> Actor:
273
+ assert not use_sde or isinstance(
274
+ action_space, Box
275
+ ), "use_sde only valid if Box action_space"
276
+ assert not squash_output or use_sde, "squash_output only valid if use_sde"
277
+ if isinstance(action_space, Discrete):
278
+ return CategoricalActorHead(
279
+ action_space.n,
280
+ hidden_sizes=hidden_sizes,
281
+ activation=activation,
282
+ init_layers_orthogonal=init_layers_orthogonal,
283
+ )
284
+ elif isinstance(action_space, Box):
285
+ if use_sde:
286
+ return StateDependentNoiseActorHead(
287
+ action_space.shape[0],
288
+ hidden_sizes=hidden_sizes,
289
+ activation=activation,
290
+ init_layers_orthogonal=init_layers_orthogonal,
291
+ log_std_init=log_std_init,
292
+ full_std=full_std,
293
+ squash_output=squash_output,
294
+ )
295
+ else:
296
+ return GaussianActorHead(
297
+ action_space.shape[0],
298
+ hidden_sizes=hidden_sizes,
299
+ activation=activation,
300
+ init_layers_orthogonal=init_layers_orthogonal,
301
+ log_std_init=log_std_init,
302
+ )
303
+ else:
304
+ raise ValueError(f"Unsupported action space: {action_space}")
shared/policy/critic.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gym
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from typing import Sequence, Type
6
+ from shared.module import FeatureExtractor, mlp
7
+
8
+
9
+ class CriticHead(nn.Module):
10
+ def __init__(
11
+ self,
12
+ hidden_sizes: Sequence[int] = (32,),
13
+ activation: Type[nn.Module] = nn.Tanh,
14
+ init_layers_orthogonal: bool = True,
15
+ ) -> None:
16
+ super().__init__()
17
+ layer_sizes = tuple(hidden_sizes) + (1,)
18
+ self._fc = mlp(
19
+ layer_sizes,
20
+ activation,
21
+ init_layers_orthogonal=init_layers_orthogonal,
22
+ final_layer_gain=1.0,
23
+ )
24
+
25
+ def forward(self, obs: torch.Tensor) -> torch.Tensor:
26
+ v = self._fc(obs)
27
+ return v.squeeze(-1)
shared/policy/on_policy.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gym
2
+ import numpy as np
3
+ import torch
4
+
5
+ from gym.spaces import Box
6
+ from pathlib import Path
7
+ from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs
8
+ from typing import NamedTuple, Optional, Sequence, Tuple, TypeVar
9
+
10
+ from shared.module import FeatureExtractor
11
+ from shared.policy.actor import PiForward, StateDependentNoiseActorHead, actor_head
12
+ from shared.policy.critic import CriticHead
13
+ from shared.policy.policy import ACTIVATION, Policy
14
+
15
+
16
+ class Step(NamedTuple):
17
+ a: np.ndarray
18
+ v: np.ndarray
19
+ logp_a: np.ndarray
20
+ clamped_a: np.ndarray
21
+
22
+
23
+ class ACForward(NamedTuple):
24
+ logp_a: torch.Tensor
25
+ entropy: torch.Tensor
26
+ v: torch.Tensor
27
+
28
+
29
+ FEAT_EXT_FILE_NAME = "feat_ext.pt"
30
+ V_FEAT_EXT_FILE_NAME = "v_feat_ext.pt"
31
+ PI_FILE_NAME = "pi.pt"
32
+ V_FILE_NAME = "v.pt"
33
+ ActorCriticSelf = TypeVar("ActorCriticSelf", bound="ActorCritic")
34
+
35
+
36
+ def clamp_actions(
37
+ actions: np.ndarray, action_space: gym.Space, squash_output: bool
38
+ ) -> np.ndarray:
39
+ if isinstance(action_space, Box):
40
+ low, high = action_space.low, action_space.high # type: ignore
41
+ if squash_output:
42
+ # Squashed output is already between -1 and 1. Rescale if the actual
43
+ # output needs to something other than -1 and 1
44
+ return low + 0.5 * (actions + 1) * (high - low)
45
+ else:
46
+ return np.clip(actions, low, high)
47
+ return actions
48
+
49
+
50
+ class ActorCritic(Policy):
51
+ def __init__(
52
+ self,
53
+ env: VecEnv,
54
+ pi_hidden_sizes: Sequence[int],
55
+ v_hidden_sizes: Sequence[int],
56
+ init_layers_orthogonal: bool = True,
57
+ activation_fn: str = "tanh",
58
+ log_std_init: float = -0.5,
59
+ use_sde: bool = False,
60
+ full_std: bool = True,
61
+ squash_output: bool = False,
62
+ share_features_extractor: bool = True,
63
+ cnn_feature_dim: int = 512,
64
+ **kwargs,
65
+ ) -> None:
66
+ super().__init__(env, **kwargs)
67
+ activation = ACTIVATION[activation_fn]
68
+ observation_space = env.observation_space
69
+ self.action_space = env.action_space
70
+ self.squash_output = squash_output
71
+ self.share_features_extractor = share_features_extractor
72
+ self._feature_extractor = FeatureExtractor(
73
+ observation_space,
74
+ activation,
75
+ init_layers_orthogonal=init_layers_orthogonal,
76
+ cnn_feature_dim=cnn_feature_dim,
77
+ )
78
+ self._pi = actor_head(
79
+ self.action_space,
80
+ (self._feature_extractor.out_dim,) + tuple(pi_hidden_sizes),
81
+ init_layers_orthogonal,
82
+ activation,
83
+ log_std_init=log_std_init,
84
+ use_sde=use_sde,
85
+ full_std=full_std,
86
+ squash_output=squash_output,
87
+ )
88
+
89
+ if not share_features_extractor:
90
+ self._v_feature_extractor = FeatureExtractor(
91
+ observation_space,
92
+ activation,
93
+ init_layers_orthogonal=init_layers_orthogonal,
94
+ cnn_feature_dim=cnn_feature_dim,
95
+ )
96
+ v_hidden_sizes = (self._v_feature_extractor.out_dim,) + tuple(
97
+ v_hidden_sizes
98
+ )
99
+ else:
100
+ self._v_feature_extractor = None
101
+ v_hidden_sizes = (self._feature_extractor.out_dim,) + tuple(v_hidden_sizes)
102
+ self._v = CriticHead(
103
+ hidden_sizes=v_hidden_sizes,
104
+ activation=activation,
105
+ init_layers_orthogonal=init_layers_orthogonal,
106
+ )
107
+
108
+ def _pi_forward(
109
+ self, obs: torch.Tensor, action: Optional[torch.Tensor] = None
110
+ ) -> Tuple[PiForward, torch.Tensor]:
111
+ p_fe = self._feature_extractor(obs)
112
+ pi_forward = self._pi(p_fe, action)
113
+
114
+ return pi_forward, p_fe
115
+
116
+ def _v_forward(self, obs: torch.Tensor, p_fc: torch.Tensor) -> torch.Tensor:
117
+ v_fe = self._v_feature_extractor(obs) if self._v_feature_extractor else p_fc
118
+ return self._v(v_fe)
119
+
120
+ def forward(self, obs: torch.Tensor, action: torch.Tensor) -> ACForward:
121
+ (_, logp_a, entropy), p_fc = self._pi_forward(obs, action)
122
+ v = self._v_forward(obs, p_fc)
123
+
124
+ assert logp_a is not None
125
+ assert entropy is not None
126
+ return ACForward(logp_a, entropy, v)
127
+
128
+ def _as_tensor(self, obs: VecEnvObs) -> torch.Tensor:
129
+ assert isinstance(obs, np.ndarray)
130
+ o = torch.as_tensor(obs)
131
+ if self.device is not None:
132
+ o = o.to(self.device)
133
+ return o
134
+
135
+ def value(self, obs: VecEnvObs) -> np.ndarray:
136
+ o = self._as_tensor(obs)
137
+ with torch.no_grad():
138
+ fe = (
139
+ self._v_feature_extractor(o)
140
+ if self._v_feature_extractor
141
+ else self._feature_extractor(o)
142
+ )
143
+ v = self._v(fe)
144
+ return v.cpu().numpy()
145
+
146
+ def step(self, obs: VecEnvObs) -> Step:
147
+ o = self._as_tensor(obs)
148
+ with torch.no_grad():
149
+ (pi, _, _), p_fc = self._pi_forward(o)
150
+ a = pi.sample()
151
+ logp_a = pi.log_prob(a)
152
+
153
+ v = self._v_forward(o, p_fc)
154
+
155
+ a_np = a.cpu().numpy()
156
+ clamped_a_np = clamp_actions(a_np, self.action_space, self.squash_output)
157
+ return Step(a_np, v.cpu().numpy(), logp_a.cpu().numpy(), clamped_a_np)
158
+
159
+ def act(self, obs: np.ndarray, deterministic: bool = True) -> np.ndarray:
160
+ if not deterministic:
161
+ return self.step(obs).clamped_a
162
+ else:
163
+ o = self._as_tensor(obs)
164
+ with torch.no_grad():
165
+ (pi, _, _), _ = self._pi_forward(o)
166
+ a = pi.mode
167
+ return clamp_actions(a.cpu().numpy(), self.action_space, self.squash_output)
168
+
169
+ def load(self, path: str) -> None:
170
+ super().load(path)
171
+ self.reset_noise()
172
+
173
+ def reset_noise(self, batch_size: Optional[int] = None) -> None:
174
+ if isinstance(self._pi, StateDependentNoiseActorHead):
175
+ self._pi.sample_weights(
176
+ batch_size=batch_size if batch_size else self.env.num_envs
177
+ )
shared/policy/policy.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import os
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from abc import ABC, abstractmethod
7
+ from stable_baselines3.common.vec_env import unwrap_vec_normalize
8
+ from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs
9
+ from typing import Dict, Optional, Type, TypeVar, Union
10
+
11
+ ACTIVATION: Dict[str, Type[nn.Module]] = {
12
+ "tanh": nn.Tanh,
13
+ "relu": nn.ReLU,
14
+ }
15
+
16
+ VEC_NORMALIZE_FILENAME = "vecnormalize.pkl"
17
+ MODEL_FILENAME = "model.pth"
18
+
19
+ PolicySelf = TypeVar("PolicySelf", bound="Policy")
20
+
21
+
22
+ class Policy(nn.Module, ABC):
23
+ @abstractmethod
24
+ def __init__(self, env: VecEnv, **kwargs) -> None:
25
+ super().__init__()
26
+ self.env = env
27
+ self.vec_normalize = unwrap_vec_normalize(env)
28
+ self.device = None
29
+
30
+ def to(
31
+ self: PolicySelf,
32
+ device: Optional[torch.device] = None,
33
+ dtype: Optional[Union[torch.dtype, str]] = None,
34
+ non_blocking: bool = False,
35
+ ) -> PolicySelf:
36
+ super().to(device, dtype, non_blocking)
37
+ self.device = device
38
+ return self
39
+
40
+ @abstractmethod
41
+ def act(self, obs: VecEnvObs, deterministic: bool = True) -> np.ndarray:
42
+ ...
43
+
44
+ def save(self, path: str) -> None:
45
+ os.makedirs(path, exist_ok=True)
46
+
47
+ if self.vec_normalize:
48
+ self.vec_normalize.save(os.path.join(path, VEC_NORMALIZE_FILENAME))
49
+ torch.save(
50
+ self.state_dict(),
51
+ os.path.join(path, MODEL_FILENAME),
52
+ )
53
+
54
+ @abstractmethod
55
+ def load(self, path: str) -> None:
56
+ # VecNormalize load occurs in env.py
57
+ self.load_state_dict(
58
+ torch.load(os.path.join(path, MODEL_FILENAME), map_location=self.device)
59
+ )
60
+
61
+ def reset_noise(self) -> None:
62
+ pass
shared/schedule.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable
2
+
3
+ Schedule = Callable[[float], float]
4
+
5
+
6
+ def linear_schedule(
7
+ start_val: float, end_val: float, end_fraction: float = 1.0
8
+ ) -> Schedule:
9
+ def func(progress_fraction: float) -> float:
10
+ if progress_fraction >= end_fraction:
11
+ return end_val
12
+ else:
13
+ return start_val + (end_val - start_val) * progress_fraction / end_fraction
14
+
15
+ return func
16
+
17
+
18
+ def constant_schedule(val: float) -> Schedule:
19
+ return lambda f: val
shared/stats.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ from dataclasses import dataclass
4
+ from torch.utils.tensorboard.writer import SummaryWriter
5
+ from typing import Dict, List, Optional, Sequence, TypeVar
6
+
7
+
8
+ @dataclass
9
+ class Episode:
10
+ score: float = 0
11
+ length: int = 0
12
+
13
+
14
+ StatisticSelf = TypeVar("StatisticSelf", bound="Statistic")
15
+
16
+
17
+ @dataclass
18
+ class Statistic:
19
+ values: np.ndarray
20
+ round_digits: int = 2
21
+
22
+ @property
23
+ def mean(self) -> float:
24
+ return np.mean(self.values).item()
25
+
26
+ @property
27
+ def std(self) -> float:
28
+ return np.std(self.values).item()
29
+
30
+ @property
31
+ def min(self) -> float:
32
+ return np.min(self.values).item()
33
+
34
+ @property
35
+ def max(self) -> float:
36
+ return np.max(self.values).item()
37
+
38
+ def sum(self) -> float:
39
+ return np.sum(self.values).item()
40
+
41
+ def __len__(self) -> int:
42
+ return len(self.values)
43
+
44
+ def _diff(self: StatisticSelf, o: StatisticSelf) -> float:
45
+ return (self.mean - self.std) - (o.mean - o.std)
46
+
47
+ def __gt__(self: StatisticSelf, o: StatisticSelf) -> bool:
48
+ return self._diff(o) > 0
49
+
50
+ def __ge__(self: StatisticSelf, o: StatisticSelf) -> bool:
51
+ return self._diff(o) >= 0
52
+
53
+ def __repr__(self) -> str:
54
+ mean = round(self.mean, self.round_digits)
55
+ std = round(self.std, self.round_digits)
56
+ if self.round_digits == 0:
57
+ mean = int(mean)
58
+ std = int(std)
59
+ return f"{mean} +/- {std}"
60
+
61
+ def to_dict(self) -> Dict[str, float]:
62
+ return {
63
+ "mean": self.mean,
64
+ "std": self.std,
65
+ "min": self.min,
66
+ "max": self.max,
67
+ }
68
+
69
+
70
+ EpisodesStatsSelf = TypeVar("EpisodesStatsSelf", bound="EpisodesStats")
71
+
72
+
73
+ class EpisodesStats:
74
+ episodes: Sequence[Episode]
75
+ simple: bool
76
+ score: Statistic
77
+ length: Statistic
78
+
79
+ def __init__(self, episodes: Sequence[Episode], simple: bool = False) -> None:
80
+ self.episodes = episodes
81
+ self.simple = simple
82
+ self.score = Statistic(np.array([e.score for e in episodes]))
83
+ self.length = Statistic(np.array([e.length for e in episodes]), round_digits=0)
84
+
85
+ def __gt__(self: EpisodesStatsSelf, o: EpisodesStatsSelf) -> bool:
86
+ return self.score > o.score
87
+
88
+ def __ge__(self: EpisodesStatsSelf, o: EpisodesStatsSelf) -> bool:
89
+ return self.score >= o.score
90
+
91
+ def __repr__(self) -> str:
92
+ return (
93
+ f"Score: {self.score} ({round(self.score.mean - self.score.std, 2)}) | "
94
+ f"Length: {self.length}"
95
+ )
96
+
97
+ def __len__(self) -> int:
98
+ return len(self.episodes)
99
+
100
+ def _asdict(self) -> dict:
101
+ return {
102
+ "n_episodes": len(self.episodes),
103
+ "score": self.score.to_dict(),
104
+ "length": self.length.to_dict(),
105
+ }
106
+
107
+ def write_to_tensorboard(
108
+ self, tb_writer: SummaryWriter, main_tag: str, global_step: Optional[int] = None
109
+ ) -> None:
110
+ stats = {"mean": self.score.mean}
111
+ if not self.simple:
112
+ stats.update(
113
+ {
114
+ "min": self.score.min,
115
+ "max": self.score.max,
116
+ "result": self.score.mean - self.score.std,
117
+ "n_episodes": len(self.episodes),
118
+ }
119
+ )
120
+ tb_writer.add_scalars(
121
+ main_tag,
122
+ stats,
123
+ global_step=global_step,
124
+ )
125
+
126
+
127
+ class EpisodeAccumulator:
128
+ def __init__(self, num_envs: int):
129
+ self._episodes = []
130
+ self.current_episodes = [Episode() for _ in range(num_envs)]
131
+
132
+ @property
133
+ def episodes(self) -> List[Episode]:
134
+ return self._episodes
135
+
136
+ def step(self, reward: np.ndarray, done: np.ndarray) -> None:
137
+ for idx, current in enumerate(self.current_episodes):
138
+ current.score += reward[idx]
139
+ current.length += 1
140
+ if done[idx]:
141
+ self._episodes.append(current)
142
+ self.on_done(idx, current)
143
+ self.current_episodes[idx] = Episode()
144
+
145
+ def __len__(self) -> int:
146
+ return len(self.episodes)
147
+
148
+ def on_done(self, ep_idx: int, episode: Episode) -> None:
149
+ pass
150
+
151
+ def stats(self) -> EpisodesStats:
152
+ return EpisodesStats(self.episodes)