diff --git a/README.md b/README.md
index 98dae91ca6d31a4649b9b1967d84a870f272cecc..1d9ea2daafbac6dc73bc44a80f37a87f9bdf9637 100644
--- a/README.md
+++ b/README.md
@@ -10,7 +10,7 @@ model-index:
results:
- metrics:
- type: mean_reward
- value: 2669.98 +/- 65.52
+ value: 2681.31 +/- 77.63
name: mean_reward
task:
type: reinforcement-learning
@@ -23,17 +23,17 @@ model-index:
This is a trained model of a **PPO** agent playing **AntBulletEnv-v0** using the [/sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) repo.
-All models trained at this commit can be found at https://api.wandb.ai/links/sgoodfriend/6p2sjqtn.
+All models trained at this commit can be found at https://api.wandb.ai/links/sgoodfriend/09frjfcs.
## Training Results
-This model was trained from 3 trainings of **PPO** agents using different initial seeds. These agents were trained by checking out [5598ebc](https://github.com/sgoodfriend/rl-algo-impls/tree/5598ebc4b03054f16eebe76792486ba7bcacfc5c). The best and last models were kept from each training. This submission has loaded the best models from each training, reevaluates them, and selects the best model from these latest evaluations (mean - std).
+This model was trained from 3 trainings of **PPO** agents using different initial seeds. These agents were trained by checking out [2067e21](https://github.com/sgoodfriend/rl-algo-impls/tree/2067e21d62fff5db60168687e7d9e89019a8bfc0). The best and last models were kept from each training. This submission has loaded the best models from each training, reevaluates them, and selects the best model from these latest evaluations (mean - std).
| algo | env | seed | reward_mean | reward_std | eval_episodes | best | wandb_url |
|:-------|:----------------|-------:|--------------:|-------------:|----------------:|:-------|:-----------------------------------------------------------------------------|
-| ppo | AntBulletEnv-v0 | 4 | 2669.98 | 65.5195 | 16 | * | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/vemwk5yn) |
-| ppo | AntBulletEnv-v0 | 5 | 884.068 | 1.61404 | 16 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/iu48fxl2) |
-| ppo | AntBulletEnv-v0 | 6 | 2487.6 | 47.859 | 16 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/af73b738) |
+| ppo | AntBulletEnv-v0 | 1 | 2681.31 | 77.631 | 16 | * | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/2nh30771) |
+| ppo | AntBulletEnv-v0 | 2 | 2515.68 | 15.6691 | 16 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/0hftaqpr) |
+| ppo | AntBulletEnv-v0 | 3 | 2555.11 | 45.9397 | 16 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/kwmmqzbc) |
### Prerequisites: Weights & Biases (WandB)
@@ -53,10 +53,10 @@ login`.
Note: While the model state dictionary and hyperaparameters are saved, the latest
implementation could be sufficiently different to not be able to reproduce similar
results. You might need to checkout the commit the agent was trained on:
-[5598ebc](https://github.com/sgoodfriend/rl-algo-impls/tree/5598ebc4b03054f16eebe76792486ba7bcacfc5c).
+[2067e21](https://github.com/sgoodfriend/rl-algo-impls/tree/2067e21d62fff5db60168687e7d9e89019a8bfc0).
```
# Downloads the model, sets hyperparameters, and runs agent for 3 episodes
-python enjoy.py --wandb-run-path=sgoodfriend/rl-algo-impls-benchmarks/vemwk5yn
+python enjoy.py --wandb-run-path=sgoodfriend/rl-algo-impls-benchmarks/2nh30771
```
Setup hasn't been completely worked out yet, so you might be best served by using Google
@@ -68,11 +68,11 @@ notebook.
## Training
If you want the highest chance to reproduce these results, you'll want to checkout the
-commit the agent was trained on: [5598ebc](https://github.com/sgoodfriend/rl-algo-impls/tree/5598ebc4b03054f16eebe76792486ba7bcacfc5c). While
+commit the agent was trained on: [2067e21](https://github.com/sgoodfriend/rl-algo-impls/tree/2067e21d62fff5db60168687e7d9e89019a8bfc0). While
training is deterministic, different hardware will give different results.
```
-python train.py --algo ppo --env AntBulletEnv-v0 --seed 4
+python train.py --algo ppo --env AntBulletEnv-v0 --seed 1
```
Setup hasn't been completely worked out yet, so you might be best served by using Google
@@ -83,14 +83,14 @@ notebook.
## Benchmarking (with Lambda Labs instance)
-This and other models from https://api.wandb.ai/links/sgoodfriend/6p2sjqtn were generated by running a script on a Lambda
+This and other models from https://api.wandb.ai/links/sgoodfriend/09frjfcs were generated by running a script on a Lambda
Labs instance. In a Lambda Labs instance terminal:
```
git clone git@github.com:sgoodfriend/rl-algo-impls.git
cd rl-algo-impls
bash ./lambda_labs/setup.sh
wandb login
-bash ./lambda_labs/benchmark.sh
+bash ./lambda_labs/benchmark.sh [-a {"ppo a2c dqn vpg"}] [-e ENVS] [-j {6}] [-p {rl-algo-impls-benchmarks}] [-s {"1 2 3"}]
```
### Alternative: Google Colab Pro+
@@ -116,12 +116,14 @@ algo_hyperparams:
max_grad_norm: 0.5
n_epochs: 20
n_steps: 512
- sde_sample_freq: 4
vf_coef: 0.5
+device: auto
env: AntBulletEnv-v0
env_hyperparams:
n_envs: 16
normalize: true
+env_id: null
+eval_params: {}
n_timesteps: 2000000
policy_hyperparams:
activation_fn: relu
@@ -131,12 +133,13 @@ policy_hyperparams:
v_hidden_sizes:
- 256
- 256
-seed: 4
+seed: 1
use_deterministic_algorithms: true
wandb_entity: null
+wandb_group: null
wandb_project_name: rl-algo-impls-benchmarks
wandb_tags:
-- benchmark_5598ebc
-- host_192-9-145-26
+- benchmark_2067e21
+- host_155-248-199-228
```
diff --git a/benchmark_publish.py b/benchmark_publish.py
index 0c09d70350f95f892d93c31b282477d4ed759bca..a8f80bc474fb52f40b408656959b2e991c8145f7 100644
--- a/benchmark_publish.py
+++ b/benchmark_publish.py
@@ -1,91 +1,4 @@
-import argparse
-import subprocess
-import wandb
-import wandb.apis.public
-
-from collections import defaultdict
-from multiprocessing.pool import ThreadPool
-from typing import List, NamedTuple
-
-
-class RunGroup(NamedTuple):
- algo: str
- env_id: str
-
+from rl_algo_impls.benchmark_publish import benchmark_publish
if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--wandb-project-name",
- type=str,
- default="rl-algo-impls-benchmarks",
- help="WandB project name to load runs from",
- )
- parser.add_argument(
- "--wandb-entity",
- type=str,
- default=None,
- help="WandB team of project. None uses default entity",
- )
- parser.add_argument("--wandb-tags", type=str, nargs="+", help="WandB tags")
- parser.add_argument("--wandb-report-url", type=str, help="Link to WandB report")
- parser.add_argument(
- "--envs", type=str, nargs="*", help="Optional filter down to these envs"
- )
- parser.add_argument(
- "--huggingface-user",
- type=str,
- default=None,
- help="Huggingface user or team to upload model cards. Defaults to huggingface-cli login user",
- )
- parser.add_argument(
- "--pool-size",
- type=int,
- default=3,
- help="How many publish jobs can run in parallel",
- )
- parser.set_defaults(
- wandb_tags=["benchmark_5598ebc", "host_192-9-145-26"],
- wandb_report_url="https://api.wandb.ai/links/sgoodfriend/6p2sjqtn",
- envs=["CartPole-v1", "Acrobot-v1"],
- )
- args = parser.parse_args()
- print(args)
-
- api = wandb.Api()
- all_runs = api.runs(
- f"{args.wandb_entity or api.default_entity}/{args.wandb_project_name}"
- )
-
- required_tags = set(args.wandb_tags)
- runs: List[wandb.apis.public.Run] = [
- r
- for r in all_runs
- if required_tags.issubset(set(r.config.get("wandb_tags", [])))
- ]
-
- runs_paths_by_group = defaultdict(list)
- for r in runs:
- algo = r.config["algo"]
- env = r.config["env"]
- if args.envs and env not in args.envs:
- continue
- run_group = RunGroup(algo, env)
- runs_paths_by_group[run_group].append("/".join(r.path))
-
- def run(run_paths: List[str]) -> None:
- publish_args = ["python", "huggingface_publish.py"]
- publish_args.append("--wandb-run-paths")
- publish_args.extend(run_paths)
- publish_args.append("--wandb-report-url")
- publish_args.append(args.wandb_report_url)
- if args.huggingface_user:
- publish_args.append("--huggingface-user")
- publish_args.append(args.huggingface_user)
- subprocess.run(publish_args)
-
- tp = ThreadPool(args.pool_size)
- for run_paths in runs_paths_by_group.values():
- tp.apply_async(run, (run_paths,))
- tp.close()
- tp.join()
+ benchmark_publish()
diff --git a/colab/colab_atari1.sh b/colab/colab_atari1.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d1b6d8baa7091257be27dea2382ebae6a22f6faf
--- /dev/null
+++ b/colab/colab_atari1.sh
@@ -0,0 +1,4 @@
+ALGO="ppo"
+ENVS="PongNoFrameskip-v4 BreakoutNoFrameskip-v4"
+BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
+bash scripts/train_loop.sh -a $ALGO -e "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
\ No newline at end of file
diff --git a/colab/colab_atari2.sh b/colab/colab_atari2.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c45745ee01a77d045232e3c4926fad91a4e3ac98
--- /dev/null
+++ b/colab/colab_atari2.sh
@@ -0,0 +1,4 @@
+ALGO="ppo"
+ENVS="SpaceInvadersNoFrameskip-v4 QbertNoFrameskip-v4"
+BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
+bash scripts/train_loop.sh -a $ALGO -e "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
\ No newline at end of file
diff --git a/colab/colab_basic.sh b/colab/colab_basic.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c7e8ddef70a75d292cb418122c908c5fb947520c
--- /dev/null
+++ b/colab/colab_basic.sh
@@ -0,0 +1,4 @@
+ALGO="ppo"
+ENVS="CartPole-v1 MountainCar-v0 MountainCarContinuous-v0 Acrobot-v1 LunarLander-v2"
+BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
+bash scripts/train_loop.sh -a $ALGO -e "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
diff --git a/colab/colab_benchmark.ipynb b/colab/colab_benchmark.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..3ddea3ce29511c785b05e78eb5fdc2dc51e727f4
--- /dev/null
+++ b/colab/colab_benchmark.ipynb
@@ -0,0 +1,195 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "machine_shape": "hm",
+ "authorship_tag": "ABX9TyOGIH7rqgasim3Sz7b1rpoE",
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "gpuClass": "standard",
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) in Google Colaboratory\n",
+ "## Parameters\n",
+ "\n",
+ "\n",
+ "1. Wandb\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "S-tXDWP8WTLc"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from getpass import getpass\n",
+ "import os\n",
+ "os.environ[\"WANDB_API_KEY\"] = getpass(\"Wandb API key to upload metrics, videos, and models: \")"
+ ],
+ "metadata": {
+ "id": "1ZtdYgxWNGwZ"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Setup\n",
+ "Clone [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) "
+ ],
+ "metadata": {
+ "id": "bsG35Io0hmKG"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "%%capture\n",
+ "!git clone https://github.com/sgoodfriend/rl-algo-impls.git"
+ ],
+ "metadata": {
+ "id": "k5ynTV25hdAf"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Installing the correct packages:\n",
+ "\n",
+ "While conda and poetry are generally used for package management, the mismatch in Python versions (3.10 in the project file vs 3.8 in Colab) makes using the package yml files difficult to use. For now, instead I'm going to specify the list of requirements manually below:"
+ ],
+ "metadata": {
+ "id": "jKxGok-ElYQ7"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "%%capture\n",
+ "!apt install python-opengl\n",
+ "!apt install ffmpeg\n",
+ "!apt install xvfb\n",
+ "!apt install swig"
+ ],
+ "metadata": {
+ "id": "nn6EETTc2Ewf"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "%%capture\n",
+ "%cd /content/rl-algo-impls\n",
+ "python -m pip install ."
+ ],
+ "metadata": {
+ "id": "AfZh9rH3yQii"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Run Once Per Runtime"
+ ],
+ "metadata": {
+ "id": "4o5HOLjc4wq7"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import wandb\n",
+ "wandb.login()"
+ ],
+ "metadata": {
+ "id": "PCXa5tdS2qFX"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Restart Session beteween runs"
+ ],
+ "metadata": {
+ "id": "AZBZfSUV43JQ"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "%%capture\n",
+ "from pyvirtualdisplay import Display\n",
+ "\n",
+ "virtual_display = Display(visible=0, size=(1400, 900))\n",
+ "virtual_display.start()"
+ ],
+ "metadata": {
+ "id": "VzemeQJP2NO9"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "The below 5 bash scripts train agents on environments with 3 seeds each:\n",
+ "- colab_basic.sh and colab_pybullet.sh test on a set of basic gym environments and 4 PyBullet environments. Running both together will likely take about 18 hours. This is likely to run into runtime limits for free Colab and Colab Pro, but is fine for Colab Pro+.\n",
+ "- colab_carracing.sh only trains 3 seeds on CarRacing-v0, which takes almost 22 hours on Colab Pro+ on high-RAM, standard GPU.\n",
+ "- colab_atari1.sh and colab_atari2.sh likely need to be run separately because each takes about 19 hours on high-RAM, standard GPU."
+ ],
+ "metadata": {
+ "id": "nSHfna0hLlO1"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "%cd /content/rl-algo-impls\n",
+ "os.environ[\"BENCHMARK_MAX_PROCS\"] = str(1) # Can't reliably raise this to 2+, but would make it faster.\n",
+ "!./benchmarks/colab_basic.sh\n",
+ "!./benchmarks/colab_pybullet.sh\n",
+ "# !./benchmarks/colab_carracing.sh\n",
+ "# !./benchmarks/colab_atari1.sh\n",
+ "# !./benchmarks/colab_atari2.sh"
+ ],
+ "metadata": {
+ "id": "07aHYFH1zfXa"
+ },
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+}
\ No newline at end of file
diff --git a/colab/colab_carracing.sh b/colab/colab_carracing.sh
new file mode 100644
index 0000000000000000000000000000000000000000..03e7cc8d811204239655b1415e1f551cc819db66
--- /dev/null
+++ b/colab/colab_carracing.sh
@@ -0,0 +1,4 @@
+ALGO="ppo"
+ENVS="CarRacing-v0"
+BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
+bash scripts/train_loop.sh -a $ALGO -e "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
\ No newline at end of file
diff --git a/colab/colab_enjoy.ipynb b/colab/colab_enjoy.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..21a45078865fd8be4f073827d1c0ae8775cc96af
--- /dev/null
+++ b/colab/colab_enjoy.ipynb
@@ -0,0 +1,198 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "machine_shape": "hm",
+ "authorship_tag": "ABX9TyN6S7kyJKrM5x0OOiN+CgTc",
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "gpuClass": "standard",
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) in Google Colaboratory\n",
+ "## Parameters\n",
+ "\n",
+ "\n",
+ "1. Wandb\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "S-tXDWP8WTLc"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from getpass import getpass\n",
+ "import os\n",
+ "os.environ[\"WANDB_API_KEY\"] = getpass(\"Wandb API key to upload metrics, videos, and models: \")"
+ ],
+ "metadata": {
+ "id": "1ZtdYgxWNGwZ"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "2. enjoy.py parameters"
+ ],
+ "metadata": {
+ "id": "ao0nAh3MOdN7"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "WANDB_RUN_PATH=\"sgoodfriend/rl-algo-impls-benchmarks/rd0lisee\""
+ ],
+ "metadata": {
+ "id": "jKL_NFhVOjSc"
+ },
+ "execution_count": 2,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Setup\n",
+ "Clone [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) "
+ ],
+ "metadata": {
+ "id": "bsG35Io0hmKG"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "%%capture\n",
+ "!git clone https://github.com/sgoodfriend/rl-algo-impls.git"
+ ],
+ "metadata": {
+ "id": "k5ynTV25hdAf"
+ },
+ "execution_count": 3,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Installing the correct packages:\n",
+ "\n",
+ "While conda and poetry are generally used for package management, the mismatch in Python versions (3.10 in the project file vs 3.8 in Colab) makes using the package yml files difficult to use. For now, instead I'm going to specify the list of requirements manually below:"
+ ],
+ "metadata": {
+ "id": "jKxGok-ElYQ7"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "%%capture\n",
+ "!apt install python-opengl\n",
+ "!apt install ffmpeg\n",
+ "!apt install xvfb\n",
+ "!apt install swig"
+ ],
+ "metadata": {
+ "id": "nn6EETTc2Ewf"
+ },
+ "execution_count": 4,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "%%capture\n",
+ "%cd /content/rl-algo-impls\n",
+ "python -m pip install ."
+ ],
+ "metadata": {
+ "id": "AfZh9rH3yQii"
+ },
+ "execution_count": 5,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Run Once Per Runtime"
+ ],
+ "metadata": {
+ "id": "4o5HOLjc4wq7"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import wandb\n",
+ "wandb.login()"
+ ],
+ "metadata": {
+ "id": "PCXa5tdS2qFX"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Restart Session beteween runs"
+ ],
+ "metadata": {
+ "id": "AZBZfSUV43JQ"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "%%capture\n",
+ "from pyvirtualdisplay import Display\n",
+ "\n",
+ "virtual_display = Display(visible=0, size=(1400, 900))\n",
+ "virtual_display.start()"
+ ],
+ "metadata": {
+ "id": "VzemeQJP2NO9"
+ },
+ "execution_count": 7,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "%cd /content/rl-algo-impls\n",
+ "!python enjoy.py --wandb-run-path={WANDB_RUN_PATH}"
+ ],
+ "metadata": {
+ "id": "07aHYFH1zfXa"
+ },
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+}
\ No newline at end of file
diff --git a/colab/colab_pybullet.sh b/colab/colab_pybullet.sh
new file mode 100644
index 0000000000000000000000000000000000000000..b8a7ed364a7d7378fe4540196bbfdc89b2304be9
--- /dev/null
+++ b/colab/colab_pybullet.sh
@@ -0,0 +1,4 @@
+ALGO="ppo"
+ENVS="HalfCheetahBulletEnv-v0 AntBulletEnv-v0 HopperBulletEnv-v0 Walker2DBulletEnv-v0"
+BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
+bash scripts/train_loop.sh -a $ALGO -e "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
\ No newline at end of file
diff --git a/colab/colab_train.ipynb b/colab/colab_train.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..d2626d3f8f86b3e57e49eda8cb30f04c211a481c
--- /dev/null
+++ b/colab/colab_train.ipynb
@@ -0,0 +1,200 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "machine_shape": "hm",
+ "authorship_tag": "ABX9TyMmemQnx6G7GOnn6XBdjgxY",
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "gpuClass": "standard",
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) in Google Colaboratory\n",
+ "## Parameters\n",
+ "\n",
+ "\n",
+ "1. Wandb\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "S-tXDWP8WTLc"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from getpass import getpass\n",
+ "import os\n",
+ "os.environ[\"WANDB_API_KEY\"] = getpass(\"Wandb API key to upload metrics, videos, and models: \")"
+ ],
+ "metadata": {
+ "id": "1ZtdYgxWNGwZ"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "2. train run parameters"
+ ],
+ "metadata": {
+ "id": "ao0nAh3MOdN7"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "ALGO = \"ppo\"\n",
+ "ENV = \"CartPole-v1\"\n",
+ "SEED = 1"
+ ],
+ "metadata": {
+ "id": "jKL_NFhVOjSc"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Setup\n",
+ "Clone [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) "
+ ],
+ "metadata": {
+ "id": "bsG35Io0hmKG"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "%%capture\n",
+ "!git clone https://github.com/sgoodfriend/rl-algo-impls.git"
+ ],
+ "metadata": {
+ "id": "k5ynTV25hdAf"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Installing the correct packages:\n",
+ "\n",
+ "While conda and poetry are generally used for package management, the mismatch in Python versions (3.10 in the project file vs 3.8 in Colab) makes using the package yml files difficult to use. For now, instead I'm going to specify the list of requirements manually below:"
+ ],
+ "metadata": {
+ "id": "jKxGok-ElYQ7"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "%%capture\n",
+ "!apt install python-opengl\n",
+ "!apt install ffmpeg\n",
+ "!apt install xvfb\n",
+ "!apt install swig"
+ ],
+ "metadata": {
+ "id": "nn6EETTc2Ewf"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "%%capture\n",
+ "%cd /content/rl-algo-impls\n",
+ "python -m pip install ."
+ ],
+ "metadata": {
+ "id": "AfZh9rH3yQii"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Run Once Per Runtime"
+ ],
+ "metadata": {
+ "id": "4o5HOLjc4wq7"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import wandb\n",
+ "wandb.login()"
+ ],
+ "metadata": {
+ "id": "PCXa5tdS2qFX"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Restart Session beteween runs"
+ ],
+ "metadata": {
+ "id": "AZBZfSUV43JQ"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "%%capture\n",
+ "from pyvirtualdisplay import Display\n",
+ "\n",
+ "virtual_display = Display(visible=0, size=(1400, 900))\n",
+ "virtual_display.start()"
+ ],
+ "metadata": {
+ "id": "VzemeQJP2NO9"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "%cd /content/rl-algo-impls\n",
+ "!python train.py --algo {ALGO} --env {ENV} --seed {SEED}"
+ ],
+ "metadata": {
+ "id": "07aHYFH1zfXa"
+ },
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+}
\ No newline at end of file
diff --git a/compare_runs.py b/compare_runs.py
new file mode 100644
index 0000000000000000000000000000000000000000..e892b04a7e82069b0cab74d999bfff7c61293685
--- /dev/null
+++ b/compare_runs.py
@@ -0,0 +1,4 @@
+from rl_algo_impls.compare_runs import compare_runs
+
+if __name__ == "__main__":
+ compare_runs()
diff --git a/enjoy.py b/enjoy.py
index fd004de69152d374d292d4910e9854364a4a08e5..9d57504e3bb5fc5f841e49caef149eeecfd33a25 100644
--- a/enjoy.py
+++ b/enjoy.py
@@ -1,30 +1,4 @@
-# Support for PyTorch mps mode (https://pytorch.org/docs/stable/notes/mps.html)
-import os
-
-os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
-
-from runner.evaluate import EvalArgs, evaluate_model
-from runner.running_utils import base_parser
-
+from rl_algo_impls.enjoy import enjoy
if __name__ == "__main__":
- parser = base_parser(multiple=False)
- parser.add_argument("--render", default=True, type=bool)
- parser.add_argument("--best", default=True, type=bool)
- parser.add_argument("--n_envs", default=1, type=int)
- parser.add_argument("--n_episodes", default=3, type=int)
- parser.add_argument("--deterministic-eval", default=None, type=bool)
- parser.add_argument(
- "--no-print-returns", action="store_true", help="Limit printing"
- )
- # wandb-run-path overrides base RunArgs
- parser.add_argument("--wandb-run-path", default=None, type=str)
- parser.set_defaults(
- algo=["ppo"],
- )
- args = parser.parse_args()
- args.algo = args.algo[0]
- args.env = args.env[0]
- args = EvalArgs(**vars(args))
-
- evaluate_model(args, os.path.dirname(__file__))
+ enjoy()
diff --git a/environment.yml b/environment.yml
index 969f19408426ae75cf64779f01acd2418adbfd85..a5f2efb65d96b38ee9bdba44fb9b5dcd26857a72 100644
--- a/environment.yml
+++ b/environment.yml
@@ -4,14 +4,9 @@ channels:
- conda-forge
- nodefaults
dependencies:
- - python=3.10.*
+ - python>=3.8, <3.11
- mamba
- pip
- - poetry
- pytorch
- torchvision
- torchaudio
- - cmake
- - swig
- - ipywidgets
- - black
diff --git a/huggingface_publish.py b/huggingface_publish.py
index 2e265cd447d8606fd6f6329ad96dd826acceaff6..86d93e94755fd8eeb035478b037556498980ed00 100644
--- a/huggingface_publish.py
+++ b/huggingface_publish.py
@@ -1,177 +1,4 @@
-import os
-
-os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
-
-import argparse
-import requests
-import shutil
-import subprocess
-import tempfile
-import wandb
-import wandb.apis.public
-
-from typing import List, Optional
-
-from huggingface_hub.hf_api import HfApi, upload_folder
-from huggingface_hub.repocard import metadata_save
-from publish.markdown_format import EvalTableData, model_card_text
-from runner.evaluate import EvalArgs, evaluate_model
-from runner.env import make_eval_env
-from shared.callbacks.eval_callback import evaluate
-from wrappers.vec_episode_recorder import VecEpisodeRecorder
-
-
-def publish(
- wandb_run_paths: List[str],
- wandb_report_url: str,
- huggingface_user: Optional[str] = None,
- huggingface_token: Optional[str] = None,
-) -> None:
- api = wandb.Api()
- runs = [api.run(rp) for rp in wandb_run_paths]
- algo = runs[0].config["algo"]
- env = runs[0].config["env"]
- evaluations = [
- evaluate_model(
- EvalArgs(
- algo,
- env,
- seed=r.config.get("seed", None),
- render=False,
- best=True,
- n_envs=None,
- n_episodes=10,
- no_print_returns=True,
- wandb_run_path="/".join(r.path),
- ),
- os.path.dirname(__file__),
- )
- for r in runs
- ]
- run_metadata = requests.get(runs[0].file("wandb-metadata.json").url).json()
- table_data = list(EvalTableData(r, e) for r, e in zip(runs, evaluations))
- best_eval = sorted(
- table_data, key=lambda d: d.evaluation.stats.score, reverse=True
- )[0]
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- _, (policy, stats, config) = best_eval
-
- repo_name = config.model_name(include_seed=False)
- repo_dir_path = os.path.join(tmpdirname, repo_name)
- # Locally clone this repo to a temp directory
- subprocess.run(["git", "clone", ".", repo_dir_path])
- shutil.rmtree(os.path.join(repo_dir_path, ".git"))
- model_path = config.model_dir_path(best=True, downloaded=True)
- shutil.copytree(
- model_path,
- os.path.join(
- repo_dir_path, "saved_models", config.model_dir_name(best=True)
- ),
- )
-
- github_url = "https://github.com/sgoodfriend/rl-algo-impls"
- commit_hash = run_metadata.get("git", {}).get("commit", None)
- card_text = model_card_text(
- algo,
- env,
- github_url,
- commit_hash,
- wandb_report_url,
- table_data,
- best_eval,
- )
- readme_filepath = os.path.join(repo_dir_path, "README.md")
- os.remove(readme_filepath)
- with open(readme_filepath, "w") as f:
- f.write(card_text)
-
- metadata = {
- "library_name": "rl-algo-impls",
- "tags": [
- env,
- algo,
- "deep-reinforcement-learning",
- "reinforcement-learning",
- ],
- "model-index": [
- {
- "name": algo,
- "results": [
- {
- "metrics": [
- {
- "type": "mean_reward",
- "value": str(stats.score),
- "name": "mean_reward",
- }
- ],
- "task": {
- "type": "reinforcement-learning",
- "name": "reinforcement-learning",
- },
- "dataset": {
- "name": env,
- "type": env,
- },
- }
- ],
- }
- ],
- }
- metadata_save(readme_filepath, metadata)
-
- video_env = VecEpisodeRecorder(
- make_eval_env(
- config,
- override_n_envs=1,
- normalize_load_path=model_path,
- **config.env_hyperparams,
- ),
- os.path.join(repo_dir_path, "replay"),
- max_video_length=3600,
- )
- evaluate(
- video_env,
- policy,
- 1,
- deterministic=config.eval_params.get("deterministic", True),
- )
-
- api = HfApi()
- huggingface_user = huggingface_user or api.whoami()["name"]
- huggingface_repo = f"{huggingface_user}/{repo_name}"
- api.create_repo(
- token=huggingface_token,
- repo_id=huggingface_repo,
- private=False,
- exist_ok=True,
- )
- repo_url = upload_folder(
- repo_id=huggingface_repo,
- folder_path=repo_dir_path,
- path_in_repo="",
- commit_message=f"{algo.upper()} playing {env} from {github_url}/tree/{commit_hash}",
- token=huggingface_token,
- )
- print(f"Pushed model to the hub: {repo_url}")
-
+from rl_algo_impls.huggingface_publish import huggingface_publish
if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--wandb-run-paths",
- type=str,
- nargs="+",
- help="Run paths of the form entity/project/run_id",
- )
- parser.add_argument("--wandb-report-url", type=str, help="Link to WandB report")
- parser.add_argument(
- "--huggingface-user",
- type=str,
- help="Huggingface user or team to upload model cards",
- default=None,
- )
- args = parser.parse_args()
- print(args)
- publish(**vars(args))
+ huggingface_publish()
diff --git a/optimize.py b/optimize.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bed687678f26f9a19b5b43ecf1068f8c2754a28
--- /dev/null
+++ b/optimize.py
@@ -0,0 +1,4 @@
+from rl_algo_impls.optimize import optimize
+
+if __name__ == "__main__":
+ optimize()
diff --git a/pyproject.toml b/pyproject.toml
index 1e8ec4d9f23d23b9eb3363e5e77728a71bc03600..eb996603ea40cffd2969ed65f2c3d2d1ab4516ea 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,29 +1,64 @@
-[tool.poetry]
-name = "rl-algo-impls"
-version = "0.1.0"
+[project]
+name = "rl_algo_impls"
+version = "0.0.4"
description = "Implementations of reinforcement learning algorithms"
-authors = ["Scott Goodfriend "]
-license = "MIT License"
+authors = [
+ {name = "Scott Goodfriend", email = "goodfriend.scott@gmail.com"},
+]
+license = {file = "LICENSE"}
readme = "README.md"
-packages = [{include = "rl_algo_impls"}]
+requires-python = ">= 3.8"
+classifiers = [
+ "License :: OSI Approved :: MIT License",
+ "Development Status :: 3 - Alpha",
+ "Programming Language :: Python :: 3.8",
+ "Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.10",
+]
+dependencies = [
+ "cmake",
+ "swig",
+ "scipy",
+ "torch",
+ "torchvision",
+ "tensorboard >= 2.11.2, < 2.12",
+ "AutoROM.accept-rom-license >= 0.4.2, < 0.5",
+ "stable-baselines3[extra] >= 1.7.0, < 1.8",
+ "gym[box2d] >= 0.21.0, < 0.22",
+ "pyglet == 1.5.27",
+ "wandb",
+ "pyvirtualdisplay",
+ "pybullet",
+ "tabulate",
+ "huggingface-hub",
+ "optuna",
+ "dash",
+ "kaleido",
+ "PyYAML",
+]
-[tool.poetry.dependencies]
-python = "~3.10"
-"AutoROM.accept-rom-license" = "^0.4.2"
-stable-baselines3 = {extras = ["extra"], version = "^1.7.0"}
-scipy = "^1.10.0"
-gym = {extras = ["box2d"], version = "^0.21.0"}
-pyglet = "1.5.27"
-PyYAML = "^6.0"
-tensorboard = "^2.11.0"
-pybullet = "^3.2.5"
-wandb = "^0.13.9"
-conda-lock = "^1.3.0"
-torch-tb-profiler = "^0.4.1"
-jupyter = "^1.0.0"
-tabulate = "^0.9.0"
-huggingface-hub = "^0.12.0"
+[tool.setuptools]
+packages = ["rl_algo_impls"]
+
+[project.optional-dependencies]
+test = [
+ "pytest",
+ "black",
+ "mypy",
+ "flake8",
+ "flake8-bugbear",
+ "isort",
+]
+procgen = [
+ "numexpr >= 2.8.4",
+ "gym3",
+ "glfw >= 1.12.0, < 1.13",
+ "procgen; platform_machine=='x86_64'",
+]
+
+[project.urls]
+"Homepage" = "https://github.com/sgoodfriend/rl-algo-impls"
[build-system]
-requires = ["poetry-core"]
-build-backend = "poetry.core.masonry.api"
+requires = ["setuptools==65.5.0", "setuptools-scm"]
+build-backend = "setuptools.build_meta"
\ No newline at end of file
diff --git a/replay.meta.json b/replay.meta.json
index 58495b3f32ddcb15f6e9abc875a4de4b884e0f2e..7c8df06ec00ef15f94499e2943dd9ff7f003ea83 100644
--- a/replay.meta.json
+++ b/replay.meta.json
@@ -1 +1 @@
-{"content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version 5.1.2 Copyright (c) 2000-2022 the FFmpeg developers\\nbuilt with clang version 14.0.6\\nconfiguration: --prefix=/Users/runner/miniforge3/conda-bld/ffmpeg_1671040513231/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_pl --cc=arm64-apple-darwin20.0.0-clang --cxx=arm64-apple-darwin20.0.0-clang++ --nm=arm64-apple-darwin20.0.0-nm --ar=arm64-apple-darwin20.0.0-ar --disable-doc --disable-openssl --enable-demuxer=dash --enable-hardcoded-tables --enable-libfreetype --enable-libfontconfig --enable-libopenh264 --enable-cross-compile --arch=arm64 --target-os=darwin --cross-prefix=arm64-apple-darwin20.0.0- --host-cc=/Users/runner/miniforge3/conda-bld/ffmpeg_1671040513231/_build_env/bin/x86_64-apple-darwin13.4.0-clang --enable-neon --enable-gnutls --enable-libmp3lame --enable-libvpx --enable-pthreads --enable-gpl --enable-libx264 --enable-libx265 --enable-libaom --enable-libsvtav1 --enable-libxml2 --enable-pic --enable-shared --disable-static --enable-version3 --enable-zlib --pkg-config=/Users/runner/miniforge3/conda-bld/ffmpeg_1671040513231/_build_env/bin/pkg-config\\nlibavutil 57. 28.100 / 57. 28.100\\nlibavcodec 59. 37.100 / 59. 37.100\\nlibavformat 59. 27.100 / 59. 27.100\\nlibavdevice 59. 7.100 / 59. 7.100\\nlibavfilter 8. 44.100 / 8. 44.100\\nlibswscale 6. 7.100 / 6. 7.100\\nlibswresample 4. 7.100 / 4. 7.100\\nlibpostproc 56. 6.100 / 56. 6.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-f", "rawvideo", "-s:v", "320x240", "-pix_fmt", "rgb24", "-framerate", "30", "-i", "-", "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "-r", "30", "/var/folders/9g/my5557_91xddp6lx00nkzly80000gn/T/tmp9sxuaf92/ppo-AntBulletEnv-v0/replay.mp4"]}, "episode": {"r": 2682.6376953125, "l": 1000, "t": 31.604279}}
\ No newline at end of file
+{"content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version 4.2.7-0ubuntu0.1 Copyright (c) 2000-2022 the FFmpeg developers\\nbuilt with gcc 9 (Ubuntu 9.4.0-1ubuntu1~20.04.1)\\nconfiguration: --prefix=/usr --extra-version=0ubuntu0.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-avresample --disable-filter=resample --enable-avisynth --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librsvg --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enable-libwavpack --enable-libwebp --enable-libx265 --enable-libxml2 --enable-libxvid --enable-libzmq --enable-libzvbi --enable-lv2 --enable-omx --enable-openal --enable-opencl --enable-opengl --enable-sdl2 --enable-libdc1394 --enable-libdrm --enable-libiec61883 --enable-nvenc --enable-chromaprint --enable-frei0r --enable-libx264 --enable-shared\\nlibavutil 56. 31.100 / 56. 31.100\\nlibavcodec 58. 54.100 / 58. 54.100\\nlibavformat 58. 29.100 / 58. 29.100\\nlibavdevice 58. 8.100 / 58. 8.100\\nlibavfilter 7. 57.100 / 7. 57.100\\nlibavresample 4. 0. 0 / 4. 0. 0\\nlibswscale 5. 5.100 / 5. 5.100\\nlibswresample 3. 5.100 / 3. 5.100\\nlibpostproc 55. 5.100 / 55. 5.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-f", "rawvideo", "-s:v", "320x240", "-pix_fmt", "rgb24", "-framerate", "60", "-i", "-", "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "-r", "60", "/tmp/tmptp1rp10h/ppo-AntBulletEnv-v0/replay.mp4"]}, "episode": {"r": 2772.951416015625, "l": 1000, "t": 29.059943}}
\ No newline at end of file
diff --git a/replay.mp4 b/replay.mp4
index 40f7537fe0a29c016ba195b19afe348d22e0c418..ca26f6839a4ccdfddb89636c39aafe4581230137 100644
--- a/replay.mp4
+++ b/replay.mp4
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:d5ac924cf833bcf539f1268564c431053fb0cb07bdb2555b3ca33d109eef560d
-size 1880057
+oid sha256:12d036e47daf108be27959b112b5e597bd879940f3735f18dc5a773f45d2e184
+size 1381735
diff --git a/rl_algo_impls/a2c/a2c.py b/rl_algo_impls/a2c/a2c.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5075c31b3a2366ebf299cc74f22691c4c2e66f3
--- /dev/null
+++ b/rl_algo_impls/a2c/a2c.py
@@ -0,0 +1,209 @@
+import logging
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from time import perf_counter
+from torch.utils.tensorboard.writer import SummaryWriter
+from typing import Optional, TypeVar
+
+from rl_algo_impls.shared.algorithm import Algorithm
+from rl_algo_impls.shared.callbacks.callback import Callback
+from rl_algo_impls.shared.policy.on_policy import ActorCritic
+from rl_algo_impls.shared.schedule import schedule, update_learning_rate
+from rl_algo_impls.shared.stats import log_scalars
+from rl_algo_impls.wrappers.vectorable_wrapper import (
+ VecEnv,
+ single_observation_space,
+ single_action_space,
+)
+
+A2CSelf = TypeVar("A2CSelf", bound="A2C")
+
+
+class A2C(Algorithm):
+ def __init__(
+ self,
+ policy: ActorCritic,
+ env: VecEnv,
+ device: torch.device,
+ tb_writer: SummaryWriter,
+ learning_rate: float = 7e-4,
+ learning_rate_decay: str = "none",
+ n_steps: int = 5,
+ gamma: float = 0.99,
+ gae_lambda: float = 1.0,
+ ent_coef: float = 0.0,
+ ent_coef_decay: str = "none",
+ vf_coef: float = 0.5,
+ max_grad_norm: float = 0.5,
+ rms_prop_eps: float = 1e-5,
+ use_rms_prop: bool = True,
+ sde_sample_freq: int = -1,
+ normalize_advantage: bool = False,
+ ) -> None:
+ super().__init__(policy, env, device, tb_writer)
+ self.policy = policy
+
+ self.lr_schedule = schedule(learning_rate_decay, learning_rate)
+ if use_rms_prop:
+ self.optimizer = torch.optim.RMSprop(
+ policy.parameters(), lr=learning_rate, eps=rms_prop_eps
+ )
+ else:
+ self.optimizer = torch.optim.Adam(policy.parameters(), lr=learning_rate)
+
+ self.n_steps = n_steps
+
+ self.gamma = gamma
+ self.gae_lambda = gae_lambda
+
+ self.vf_coef = vf_coef
+ self.ent_coef_schedule = schedule(ent_coef_decay, ent_coef)
+ self.max_grad_norm = max_grad_norm
+
+ self.sde_sample_freq = sde_sample_freq
+ self.normalize_advantage = normalize_advantage
+
+ def learn(
+ self: A2CSelf,
+ train_timesteps: int,
+ callback: Optional[Callback] = None,
+ total_timesteps: Optional[int] = None,
+ start_timesteps: int = 0,
+ ) -> A2CSelf:
+ if total_timesteps is None:
+ total_timesteps = train_timesteps
+ assert start_timesteps + train_timesteps <= total_timesteps
+ epoch_dim = (self.n_steps, self.env.num_envs)
+ step_dim = (self.env.num_envs,)
+ obs_space = single_observation_space(self.env)
+ act_space = single_action_space(self.env)
+
+ obs = np.zeros(epoch_dim + obs_space.shape, dtype=obs_space.dtype)
+ actions = np.zeros(epoch_dim + act_space.shape, dtype=act_space.dtype)
+ rewards = np.zeros(epoch_dim, dtype=np.float32)
+ episode_starts = np.zeros(epoch_dim, dtype=np.byte)
+ values = np.zeros(epoch_dim, dtype=np.float32)
+ logprobs = np.zeros(epoch_dim, dtype=np.float32)
+
+ next_obs = self.env.reset()
+ next_episode_starts = np.ones(step_dim, dtype=np.byte)
+
+ timesteps_elapsed = start_timesteps
+ while timesteps_elapsed < start_timesteps + train_timesteps:
+ start_time = perf_counter()
+
+ progress = timesteps_elapsed / total_timesteps
+ ent_coef = self.ent_coef_schedule(progress)
+ learning_rate = self.lr_schedule(progress)
+ update_learning_rate(self.optimizer, learning_rate)
+ log_scalars(
+ self.tb_writer,
+ "charts",
+ {
+ "ent_coef": ent_coef,
+ "learning_rate": learning_rate,
+ },
+ timesteps_elapsed,
+ )
+
+ self.policy.eval()
+ self.policy.reset_noise()
+ for s in range(self.n_steps):
+ timesteps_elapsed += self.env.num_envs
+ if self.sde_sample_freq > 0 and s > 0 and s % self.sde_sample_freq == 0:
+ self.policy.reset_noise()
+
+ obs[s] = next_obs
+ episode_starts[s] = next_episode_starts
+
+ actions[s], values[s], logprobs[s], clamped_action = self.policy.step(
+ next_obs
+ )
+ next_obs, rewards[s], next_episode_starts, _ = self.env.step(
+ clamped_action
+ )
+
+ advantages = np.zeros(epoch_dim, dtype=np.float32)
+ last_gae_lam = 0
+ for t in reversed(range(self.n_steps)):
+ if t == self.n_steps - 1:
+ next_nonterminal = 1.0 - next_episode_starts
+ next_value = self.policy.value(next_obs)
+ else:
+ next_nonterminal = 1.0 - episode_starts[t + 1]
+ next_value = values[t + 1]
+ delta = (
+ rewards[t] + self.gamma * next_value * next_nonterminal - values[t]
+ )
+ last_gae_lam = (
+ delta
+ + self.gamma * self.gae_lambda * next_nonterminal * last_gae_lam
+ )
+ advantages[t] = last_gae_lam
+ returns = advantages + values
+
+ b_obs = torch.tensor(obs.reshape((-1,) + obs_space.shape)).to(self.device)
+ b_actions = torch.tensor(actions.reshape((-1,) + act_space.shape)).to(
+ self.device
+ )
+ b_advantages = torch.tensor(advantages.reshape(-1)).to(self.device)
+ b_returns = torch.tensor(returns.reshape(-1)).to(self.device)
+
+ if self.normalize_advantage:
+ b_advantages = (b_advantages - b_advantages.mean()) / (
+ b_advantages.std() + 1e-8
+ )
+
+ self.policy.train()
+ logp_a, entropy, v = self.policy(b_obs, b_actions)
+
+ pi_loss = -(b_advantages * logp_a).mean()
+ value_loss = F.mse_loss(b_returns, v)
+ entropy_loss = -entropy.mean()
+
+ loss = pi_loss + self.vf_coef * value_loss + ent_coef * entropy_loss
+
+ self.optimizer.zero_grad()
+ loss.backward()
+ nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
+ self.optimizer.step()
+
+ y_pred = values.reshape(-1)
+ y_true = returns.reshape(-1)
+ var_y = np.var(y_true).item()
+ explained_var = (
+ np.nan if var_y == 0 else 1 - np.var(y_true - y_pred).item() / var_y
+ )
+
+ end_time = perf_counter()
+ rollout_steps = self.n_steps * self.env.num_envs
+ self.tb_writer.add_scalar(
+ "train/steps_per_second",
+ (rollout_steps) / (end_time - start_time),
+ timesteps_elapsed,
+ )
+
+ log_scalars(
+ self.tb_writer,
+ "losses",
+ {
+ "loss": loss.item(),
+ "pi_loss": pi_loss.item(),
+ "v_loss": value_loss.item(),
+ "entropy_loss": entropy_loss.item(),
+ "explained_var": explained_var,
+ },
+ timesteps_elapsed,
+ )
+
+ if callback:
+ if not callback.on_step(timesteps_elapsed=rollout_steps):
+ logging.info(
+ f"Callback terminated training at {timesteps_elapsed} timesteps"
+ )
+ break
+
+ return self
diff --git a/rl_algo_impls/a2c/optimize.py b/rl_algo_impls/a2c/optimize.py
new file mode 100644
index 0000000000000000000000000000000000000000..baf9299ef141ea7c9ae49b0a0c707fed4ac14bc6
--- /dev/null
+++ b/rl_algo_impls/a2c/optimize.py
@@ -0,0 +1,77 @@
+import optuna
+
+from copy import deepcopy
+
+from rl_algo_impls.runner.config import Config, Hyperparams, EnvHyperparams
+from rl_algo_impls.runner.env import make_eval_env
+from rl_algo_impls.shared.policy.optimize_on_policy import sample_on_policy_hyperparams
+from rl_algo_impls.tuning.optimize_env import sample_env_hyperparams
+
+
+def sample_params(
+ trial: optuna.Trial,
+ base_hyperparams: Hyperparams,
+ base_config: Config,
+) -> Hyperparams:
+ hyperparams = deepcopy(base_hyperparams)
+
+ base_env_hyperparams = EnvHyperparams(**hyperparams.env_hyperparams)
+ env = make_eval_env(base_config, base_env_hyperparams, override_n_envs=1)
+
+ # env_hyperparams
+ env_hyperparams = sample_env_hyperparams(trial, hyperparams.env_hyperparams, env)
+
+ # policy_hyperparams
+ policy_hyperparams = sample_on_policy_hyperparams(
+ trial, hyperparams.policy_hyperparams, env
+ )
+
+ # algo_hyperparams
+ algo_hyperparams = hyperparams.algo_hyperparams
+
+ learning_rate = trial.suggest_float("learning_rate", 1e-5, 2e-3, log=True)
+ learning_rate_decay = trial.suggest_categorical(
+ "learning_rate_decay", ["none", "linear"]
+ )
+ n_steps_exp = trial.suggest_int("n_steps_exp", 1, 10)
+ n_steps = 2**n_steps_exp
+ trial.set_user_attr("n_steps", n_steps)
+ gamma = 1.0 - trial.suggest_float("gamma_om", 1e-4, 1e-1, log=True)
+ trial.set_user_attr("gamma", gamma)
+ gae_lambda = 1 - trial.suggest_float("gae_lambda_om", 1e-4, 1e-1)
+ trial.set_user_attr("gae_lambda", gae_lambda)
+ ent_coef = trial.suggest_float("ent_coef", 1e-8, 2.5e-2, log=True)
+ ent_coef_decay = trial.suggest_categorical("ent_coef_decay", ["none", "linear"])
+ vf_coef = trial.suggest_float("vf_coef", 0.1, 0.7)
+ max_grad_norm = trial.suggest_float("max_grad_norm", 1e-1, 1e1, log=True)
+ use_rms_prop = trial.suggest_categorical("use_rms_prop", [True, False])
+ normalize_advantage = trial.suggest_categorical(
+ "normalize_advantage", [True, False]
+ )
+
+ algo_hyperparams.update(
+ {
+ "learning_rate": learning_rate,
+ "learning_rate_decay": learning_rate_decay,
+ "n_steps": n_steps,
+ "gamma": gamma,
+ "gae_lambda": gae_lambda,
+ "ent_coef": ent_coef,
+ "ent_coef_decay": ent_coef_decay,
+ "vf_coef": vf_coef,
+ "max_grad_norm": max_grad_norm,
+ "use_rms_prop": use_rms_prop,
+ "normalize_advantage": normalize_advantage,
+ }
+ )
+
+ if policy_hyperparams.get("use_sde", False):
+ sde_sample_freq = 2 ** trial.suggest_int("sde_sample_freq_exp", 0, n_steps_exp)
+ trial.set_user_attr("sde_sample_freq", sde_sample_freq)
+ algo_hyperparams["sde_sample_freq"] = sde_sample_freq
+ elif "sde_sample_freq" in algo_hyperparams:
+ del algo_hyperparams["sde_sample_freq"]
+
+ env.close()
+
+ return hyperparams
diff --git a/rl_algo_impls/benchmark_publish.py b/rl_algo_impls/benchmark_publish.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d2010efab9a48121f4b79cc7192ac6a7e5524b0
--- /dev/null
+++ b/rl_algo_impls/benchmark_publish.py
@@ -0,0 +1,111 @@
+import argparse
+import subprocess
+import wandb
+import wandb.apis.public
+
+from collections import defaultdict
+from multiprocessing.pool import ThreadPool
+from typing import List, NamedTuple
+
+
+class RunGroup(NamedTuple):
+ algo: str
+ env_id: str
+
+
+def benchmark_publish() -> None:
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--wandb-project-name",
+ type=str,
+ default="rl-algo-impls-benchmarks",
+ help="WandB project name to load runs from",
+ )
+ parser.add_argument(
+ "--wandb-entity",
+ type=str,
+ default=None,
+ help="WandB team of project. None uses default entity",
+ )
+ parser.add_argument("--wandb-tags", type=str, nargs="+", help="WandB tags")
+ parser.add_argument("--wandb-report-url", type=str, help="Link to WandB report")
+ parser.add_argument(
+ "--envs", type=str, nargs="*", help="Optional filter down to these envs"
+ )
+ parser.add_argument(
+ "--exclude-envs",
+ type=str,
+ nargs="*",
+ help="Environments to exclude from publishing",
+ )
+ parser.add_argument(
+ "--huggingface-user",
+ type=str,
+ default=None,
+ help="Huggingface user or team to upload model cards. Defaults to huggingface-cli login user",
+ )
+ parser.add_argument(
+ "--pool-size",
+ type=int,
+ default=3,
+ help="How many publish jobs can run in parallel",
+ )
+ parser.add_argument(
+ "--virtual-display", action="store_true", help="Use headless virtual display"
+ )
+ # parser.set_defaults(
+ # wandb_tags=["benchmark_e47a44c", "host_129-146-2-230"],
+ # wandb_report_url="https://api.wandb.ai/links/sgoodfriend/v4wd7cp5",
+ # envs=[],
+ # exclude_envs=[],
+ # )
+ args = parser.parse_args()
+ print(args)
+
+ api = wandb.Api()
+ all_runs = api.runs(
+ f"{args.wandb_entity or api.default_entity}/{args.wandb_project_name}"
+ )
+
+ required_tags = set(args.wandb_tags)
+ runs: List[wandb.apis.public.Run] = [
+ r
+ for r in all_runs
+ if required_tags.issubset(set(r.config.get("wandb_tags", [])))
+ ]
+
+ runs_paths_by_group = defaultdict(list)
+ for r in runs:
+ if r.state != "finished":
+ continue
+ algo = r.config["algo"]
+ env = r.config["env"]
+ if args.envs and env not in args.envs:
+ continue
+ if args.exclude_envs and env in args.exclude_envs:
+ continue
+ run_group = RunGroup(algo, env)
+ runs_paths_by_group[run_group].append("/".join(r.path))
+
+ def run(run_paths: List[str]) -> None:
+ publish_args = ["python", "huggingface_publish.py"]
+ publish_args.append("--wandb-run-paths")
+ publish_args.extend(run_paths)
+ publish_args.append("--wandb-report-url")
+ publish_args.append(args.wandb_report_url)
+ if args.huggingface_user:
+ publish_args.append("--huggingface-user")
+ publish_args.append(args.huggingface_user)
+ if args.virtual_display:
+ publish_args.append("--virtual-display")
+ subprocess.run(publish_args)
+
+ tp = ThreadPool(args.pool_size)
+ for run_paths in runs_paths_by_group.values():
+ tp.apply_async(run, (run_paths,))
+ tp.close()
+ tp.join()
+
+
+if __name__ == "__main__":
+ benchmark_publish()
diff --git a/rl_algo_impls/compare_runs.py b/rl_algo_impls/compare_runs.py
new file mode 100644
index 0000000000000000000000000000000000000000..18d1341f62eeb6a54ab79a31d76c21148cdf1458
--- /dev/null
+++ b/rl_algo_impls/compare_runs.py
@@ -0,0 +1,198 @@
+import argparse
+import itertools
+import numpy as np
+import pandas as pd
+import wandb
+import wandb.apis.public
+
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Dict, Iterable, List, TypeVar
+
+from rl_algo_impls.benchmark_publish import RunGroup
+
+
+@dataclass
+class Comparison:
+ control_values: List[float]
+ experiment_values: List[float]
+
+ def mean_diff_percentage(self) -> float:
+ return self._diff_percentage(
+ np.mean(self.control_values).item(), np.mean(self.experiment_values).item()
+ )
+
+ def median_diff_percentage(self) -> float:
+ return self._diff_percentage(
+ np.median(self.control_values).item(),
+ np.median(self.experiment_values).item(),
+ )
+
+ def _diff_percentage(self, c: float, e: float) -> float:
+ if c == e:
+ return 0
+ elif c == 0:
+ return float("inf") if e > 0 else float("-inf")
+ return 100 * (e - c) / c
+
+ def score(self) -> float:
+ return (
+ np.sum(
+ np.sign((self.mean_diff_percentage(), self.median_diff_percentage()))
+ ).item()
+ / 2
+ )
+
+
+RunGroupRunsSelf = TypeVar("RunGroupRunsSelf", bound="RunGroupRuns")
+
+
+class RunGroupRuns:
+ def __init__(
+ self,
+ run_group: RunGroup,
+ control: List[str],
+ experiment: List[str],
+ summary_stats: List[str] = ["best_eval", "eval", "train_rolling"],
+ summary_metrics: List[str] = ["mean", "result"],
+ ) -> None:
+ self.algo = run_group.algo
+ self.env = run_group.env_id
+ self.control = set(control)
+ self.experiment = set(experiment)
+
+ self.summary_stats = summary_stats
+ self.summary_metrics = summary_metrics
+
+ self.control_runs = []
+ self.experiment_runs = []
+
+ def add_run(self, run: wandb.apis.public.Run) -> None:
+ wandb_tags = set(run.config.get("wandb_tags", []))
+ if self.control & wandb_tags:
+ self.control_runs.append(run)
+ elif self.experiment & wandb_tags:
+ self.experiment_runs.append(run)
+
+ def comparisons_by_metric(self) -> Dict[str, Comparison]:
+ c_by_m = {}
+ for metric in (
+ f"{s}/{m}"
+ for s, m in itertools.product(self.summary_stats, self.summary_metrics)
+ ):
+ c_by_m[metric] = Comparison(
+ [c.summary[metric] for c in self.control_runs],
+ [e.summary[metric] for e in self.experiment_runs],
+ )
+ return c_by_m
+
+ @staticmethod
+ def data_frame(rows: Iterable[RunGroupRunsSelf]) -> pd.DataFrame:
+ results = defaultdict(list)
+ for r in rows:
+ if not r.control_runs or not r.experiment_runs:
+ continue
+ results["algo"].append(r.algo)
+ results["env"].append(r.env)
+ results["control"].append(r.control)
+ results["expierment"].append(r.experiment)
+ c_by_m = r.comparisons_by_metric()
+ results["score"].append(
+ sum(m.score() for m in c_by_m.values()) / len(c_by_m)
+ )
+ for m, c in c_by_m.items():
+ results[f"{m}_mean"].append(c.mean_diff_percentage())
+ results[f"{m}_median"].append(c.median_diff_percentage())
+ return pd.DataFrame(results)
+
+
+def compare_runs() -> None:
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-p",
+ "--wandb-project-name",
+ type=str,
+ default="rl-algo-impls-benchmarks",
+ help="WandB project name to load runs from",
+ )
+ parser.add_argument(
+ "--wandb-entity",
+ type=str,
+ default=None,
+ help="WandB team. None uses default entity",
+ )
+ parser.add_argument(
+ "-n",
+ "--wandb-hostname-tag",
+ type=str,
+ nargs="*",
+ help="WandB tags for hostname (i.e. host_192-9-145-26)",
+ )
+ parser.add_argument(
+ "-c",
+ "--wandb-control-tag",
+ type=str,
+ nargs="+",
+ help="WandB tag for control commit (i.e. benchmark_5598ebc)",
+ )
+ parser.add_argument(
+ "-e",
+ "--wandb-experiment-tag",
+ type=str,
+ nargs="+",
+ help="WandB tag for experiment commit (i.e. benchmark_5540e1f)",
+ )
+ parser.add_argument(
+ "--envs",
+ type=str,
+ nargs="*",
+ help="If specified, only compare these envs",
+ )
+ parser.add_argument(
+ "--exclude-envs",
+ type=str,
+ nargs="*",
+ help="Environments to exclude from comparison",
+ )
+ # parser.set_defaults(
+ # wandb_hostname_tag=["host_150-230-44-105", "host_155-248-214-128"],
+ # wandb_control_tag=["benchmark_fbc943f"],
+ # wandb_experiment_tag=["benchmark_f59bf74"],
+ # exclude_envs=[],
+ # )
+ args = parser.parse_args()
+ print(args)
+
+ api = wandb.Api()
+ all_runs = api.runs(
+ path=f"{args.wandb_entity or api.default_entity}/{args.wandb_project_name}",
+ order="+created_at",
+ )
+
+ runs_by_run_group: Dict[RunGroup, RunGroupRuns] = {}
+ wandb_hostname_tags = set(args.wandb_hostname_tag)
+ for r in all_runs:
+ if r.state != "finished":
+ continue
+ wandb_tags = set(r.config.get("wandb_tags", []))
+ if not wandb_tags or not wandb_hostname_tags & wandb_tags:
+ continue
+ rg = RunGroup(r.config["algo"], r.config.get("env_id") or r.config["env"])
+ if args.exclude_envs and rg.env_id in args.exclude_envs:
+ continue
+ if args.envs and rg.env_id not in args.envs:
+ continue
+ if rg not in runs_by_run_group:
+ runs_by_run_group[rg] = RunGroupRuns(
+ rg,
+ args.wandb_control_tag,
+ args.wandb_experiment_tag,
+ )
+ runs_by_run_group[rg].add_run(r)
+ df = RunGroupRuns.data_frame(runs_by_run_group.values()).round(decimals=2)
+ print(f"**Total Score: {sum(df.score)}**")
+ df.loc["mean"] = df.mean(numeric_only=True)
+ print(df.to_markdown())
+
+if __name__ == "__main__":
+ compare_runs()
\ No newline at end of file
diff --git a/rl_algo_impls/dqn/dqn.py b/rl_algo_impls/dqn/dqn.py
new file mode 100644
index 0000000000000000000000000000000000000000..57cd3e074444352d003d6f60ca95a5add79467b4
--- /dev/null
+++ b/rl_algo_impls/dqn/dqn.py
@@ -0,0 +1,182 @@
+import copy
+import numpy as np
+import random
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from collections import deque
+from torch.optim import Adam
+from torch.utils.tensorboard.writer import SummaryWriter
+from typing import NamedTuple, Optional, TypeVar
+
+from rl_algo_impls.dqn.policy import DQNPolicy
+from rl_algo_impls.shared.algorithm import Algorithm
+from rl_algo_impls.shared.callbacks.callback import Callback
+from rl_algo_impls.shared.schedule import linear_schedule
+from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv, VecEnvObs
+
+
+class Transition(NamedTuple):
+ obs: np.ndarray
+ action: np.ndarray
+ reward: float
+ done: bool
+ next_obs: np.ndarray
+
+
+class Batch(NamedTuple):
+ obs: np.ndarray
+ actions: np.ndarray
+ rewards: np.ndarray
+ dones: np.ndarray
+ next_obs: np.ndarray
+
+
+class ReplayBuffer:
+ def __init__(self, num_envs: int, maxlen: int) -> None:
+ self.num_envs = num_envs
+ self.buffer = deque(maxlen=maxlen)
+
+ def add(
+ self,
+ obs: VecEnvObs,
+ action: np.ndarray,
+ reward: np.ndarray,
+ done: np.ndarray,
+ next_obs: VecEnvObs,
+ ) -> None:
+ assert isinstance(obs, np.ndarray)
+ assert isinstance(next_obs, np.ndarray)
+ for i in range(self.num_envs):
+ self.buffer.append(
+ Transition(obs[i], action[i], reward[i], done[i], next_obs[i])
+ )
+
+ def sample(self, batch_size: int) -> Batch:
+ ts = random.sample(self.buffer, batch_size)
+ return Batch(
+ obs=np.array([t.obs for t in ts]),
+ actions=np.array([t.action for t in ts]),
+ rewards=np.array([t.reward for t in ts]),
+ dones=np.array([t.done for t in ts]),
+ next_obs=np.array([t.next_obs for t in ts]),
+ )
+
+ def __len__(self) -> int:
+ return len(self.buffer)
+
+
+DQNSelf = TypeVar("DQNSelf", bound="DQN")
+
+
+class DQN(Algorithm):
+ def __init__(
+ self,
+ policy: DQNPolicy,
+ env: VecEnv,
+ device: torch.device,
+ tb_writer: SummaryWriter,
+ learning_rate: float = 1e-4,
+ buffer_size: int = 1_000_000,
+ learning_starts: int = 50_000,
+ batch_size: int = 32,
+ tau: float = 1.0,
+ gamma: float = 0.99,
+ train_freq: int = 4,
+ gradient_steps: int = 1,
+ target_update_interval: int = 10_000,
+ exploration_fraction: float = 0.1,
+ exploration_initial_eps: float = 1.0,
+ exploration_final_eps: float = 0.05,
+ max_grad_norm: float = 10.0,
+ ) -> None:
+ super().__init__(policy, env, device, tb_writer)
+ self.policy = policy
+
+ self.optimizer = Adam(self.policy.q_net.parameters(), lr=learning_rate)
+
+ self.target_q_net = copy.deepcopy(self.policy.q_net).to(self.device)
+ self.target_q_net.train(False)
+ self.tau = tau
+ self.target_update_interval = target_update_interval
+
+ self.replay_buffer = ReplayBuffer(self.env.num_envs, buffer_size)
+ self.batch_size = batch_size
+
+ self.learning_starts = learning_starts
+ self.train_freq = train_freq
+ self.gradient_steps = gradient_steps
+
+ self.gamma = gamma
+ self.exploration_eps_schedule = linear_schedule(
+ exploration_initial_eps,
+ exploration_final_eps,
+ end_fraction=exploration_fraction,
+ )
+
+ self.max_grad_norm = max_grad_norm
+
+ def learn(
+ self: DQNSelf, total_timesteps: int, callback: Optional[Callback] = None
+ ) -> DQNSelf:
+ self.policy.train(True)
+ obs = self.env.reset()
+ obs = self._collect_rollout(self.learning_starts, obs, 1)
+ learning_steps = total_timesteps - self.learning_starts
+ timesteps_elapsed = 0
+ steps_since_target_update = 0
+ while timesteps_elapsed < learning_steps:
+ progress = timesteps_elapsed / learning_steps
+ eps = self.exploration_eps_schedule(progress)
+ obs = self._collect_rollout(self.train_freq, obs, eps)
+ rollout_steps = self.train_freq
+ timesteps_elapsed += rollout_steps
+ for _ in range(
+ self.gradient_steps if self.gradient_steps > 0 else self.train_freq
+ ):
+ self.train()
+ steps_since_target_update += rollout_steps
+ if steps_since_target_update >= self.target_update_interval:
+ self._update_target()
+ steps_since_target_update = 0
+ if callback:
+ callback.on_step(timesteps_elapsed=rollout_steps)
+ return self
+
+ def train(self) -> None:
+ if len(self.replay_buffer) < self.batch_size:
+ return
+ o, a, r, d, next_o = self.replay_buffer.sample(self.batch_size)
+ o = torch.as_tensor(o, device=self.device)
+ a = torch.as_tensor(a, device=self.device).unsqueeze(1)
+ r = torch.as_tensor(r, dtype=torch.float32, device=self.device)
+ d = torch.as_tensor(d, dtype=torch.long, device=self.device)
+ next_o = torch.as_tensor(next_o, device=self.device)
+
+ with torch.no_grad():
+ target = r + (1 - d) * self.gamma * self.target_q_net(next_o).max(1).values
+ current = self.policy.q_net(o).gather(dim=1, index=a).squeeze(1)
+ loss = F.smooth_l1_loss(current, target)
+
+ self.optimizer.zero_grad()
+ loss.backward()
+ if self.max_grad_norm:
+ nn.utils.clip_grad_norm_(self.policy.q_net.parameters(), self.max_grad_norm)
+ self.optimizer.step()
+
+ def _collect_rollout(self, timesteps: int, obs: VecEnvObs, eps: float) -> VecEnvObs:
+ for _ in range(0, timesteps, self.env.num_envs):
+ action = self.policy.act(obs, eps, deterministic=False)
+ next_obs, reward, done, _ = self.env.step(action)
+ self.replay_buffer.add(obs, action, reward, done, next_obs)
+ obs = next_obs
+ return obs
+
+ def _update_target(self) -> None:
+ for target_param, param in zip(
+ self.target_q_net.parameters(), self.policy.q_net.parameters()
+ ):
+ target_param.data.copy_(
+ self.tau * param.data + (1 - self.tau) * target_param.data
+ )
diff --git a/rl_algo_impls/dqn/policy.py b/rl_algo_impls/dqn/policy.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7189c107b3882d785b94b30501ef36b8123ae38
--- /dev/null
+++ b/rl_algo_impls/dqn/policy.py
@@ -0,0 +1,55 @@
+import numpy as np
+import os
+import torch
+
+from typing import Optional, Sequence, TypeVar
+
+from rl_algo_impls.dqn.q_net import QNetwork
+from rl_algo_impls.shared.policy.policy import Policy
+from rl_algo_impls.wrappers.vectorable_wrapper import (
+ VecEnv,
+ VecEnvObs,
+ single_observation_space,
+ single_action_space,
+)
+
+DQNPolicySelf = TypeVar("DQNPolicySelf", bound="DQNPolicy")
+
+
+class DQNPolicy(Policy):
+ def __init__(
+ self,
+ env: VecEnv,
+ hidden_sizes: Sequence[int] = [],
+ cnn_feature_dim: int = 512,
+ cnn_style: str = "nature",
+ cnn_layers_init_orthogonal: Optional[bool] = None,
+ impala_channels: Sequence[int] = (16, 32, 32),
+ **kwargs,
+ ) -> None:
+ super().__init__(env, **kwargs)
+ self.q_net = QNetwork(
+ single_observation_space(env),
+ single_action_space(env),
+ hidden_sizes,
+ cnn_feature_dim=cnn_feature_dim,
+ cnn_style=cnn_style,
+ cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
+ impala_channels=impala_channels,
+ )
+
+ def act(
+ self, obs: VecEnvObs, eps: float = 0, deterministic: bool = True
+ ) -> np.ndarray:
+ assert eps == 0 if deterministic else eps >= 0
+ if not deterministic and np.random.random() < eps:
+ return np.array(
+ [
+ single_action_space(self.env).sample()
+ for _ in range(self.env.num_envs)
+ ]
+ )
+ else:
+ o = self._as_tensor(obs)
+ with torch.no_grad():
+ return self.q_net(o).argmax(axis=1).cpu().numpy()
diff --git a/rl_algo_impls/dqn/q_net.py b/rl_algo_impls/dqn/q_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e9233547f60b7eabdf58633afcd067cdb2ca345
--- /dev/null
+++ b/rl_algo_impls/dqn/q_net.py
@@ -0,0 +1,41 @@
+import gym
+import torch as th
+import torch.nn as nn
+
+from gym.spaces import Discrete
+from typing import Optional, Sequence, Type
+
+from rl_algo_impls.shared.module.feature_extractor import FeatureExtractor
+from rl_algo_impls.shared.module.module import mlp
+
+
+class QNetwork(nn.Module):
+ def __init__(
+ self,
+ observation_space: gym.Space,
+ action_space: gym.Space,
+ hidden_sizes: Sequence[int] = [],
+ activation: Type[nn.Module] = nn.ReLU, # Used by stable-baselines3
+ cnn_feature_dim: int = 512,
+ cnn_style: str = "nature",
+ cnn_layers_init_orthogonal: Optional[bool] = None,
+ impala_channels: Sequence[int] = (16, 32, 32),
+ ) -> None:
+ super().__init__()
+ assert isinstance(action_space, Discrete)
+ self._feature_extractor = FeatureExtractor(
+ observation_space,
+ activation,
+ cnn_feature_dim=cnn_feature_dim,
+ cnn_style=cnn_style,
+ cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
+ impala_channels=impala_channels,
+ )
+ layer_sizes = (
+ (self._feature_extractor.out_dim,) + tuple(hidden_sizes) + (action_space.n,)
+ )
+ self._fc = mlp(layer_sizes, activation)
+
+ def forward(self, obs: th.Tensor) -> th.Tensor:
+ x = self._feature_extractor(obs)
+ return self._fc(x)
diff --git a/rl_algo_impls/enjoy.py b/rl_algo_impls/enjoy.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ba6169d34f531d5a8d77dada1d3a69c2f9fa471
--- /dev/null
+++ b/rl_algo_impls/enjoy.py
@@ -0,0 +1,35 @@
+# Support for PyTorch mps mode (https://pytorch.org/docs/stable/notes/mps.html)
+import os
+
+os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+
+from rl_algo_impls.runner.evaluate import EvalArgs, evaluate_model
+from rl_algo_impls.runner.running_utils import base_parser
+
+
+def enjoy() -> None:
+ parser = base_parser(multiple=False)
+ parser.add_argument("--render", default=True, type=bool)
+ parser.add_argument("--best", default=True, type=bool)
+ parser.add_argument("--n_envs", default=1, type=int)
+ parser.add_argument("--n_episodes", default=3, type=int)
+ parser.add_argument("--deterministic-eval", default=None, type=bool)
+ parser.add_argument(
+ "--no-print-returns", action="store_true", help="Limit printing"
+ )
+ # wandb-run-path overrides base RunArgs
+ parser.add_argument("--wandb-run-path", default=None, type=str)
+ parser.set_defaults(
+ algo=["ppo"],
+ wandb_run_path="sgoodfriend/rl-algo-impls/m5c1t7g5",
+ )
+ args = parser.parse_args()
+ args.algo = args.algo[0]
+ args.env = args.env[0]
+ args = EvalArgs(**vars(args))
+
+ evaluate_model(args, os.getcwd())
+
+
+if __name__ == "__main__":
+ enjoy()
diff --git a/rl_algo_impls/huggingface_publish.py b/rl_algo_impls/huggingface_publish.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cf482357b0c44703a30f88b2e70a92542153781
--- /dev/null
+++ b/rl_algo_impls/huggingface_publish.py
@@ -0,0 +1,193 @@
+import os
+
+os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+
+import argparse
+import requests
+import shutil
+import subprocess
+import tempfile
+import wandb
+import wandb.apis.public
+
+from typing import List, Optional
+
+from huggingface_hub.hf_api import HfApi, upload_folder
+from huggingface_hub.repocard import metadata_save
+from pyvirtualdisplay.display import Display
+
+from rl_algo_impls.publish.markdown_format import EvalTableData, model_card_text
+from rl_algo_impls.runner.config import EnvHyperparams
+from rl_algo_impls.runner.evaluate import EvalArgs, evaluate_model
+from rl_algo_impls.runner.env import make_eval_env
+from rl_algo_impls.shared.callbacks.eval_callback import evaluate
+from rl_algo_impls.wrappers.vec_episode_recorder import VecEpisodeRecorder
+
+
+def publish(
+ wandb_run_paths: List[str],
+ wandb_report_url: str,
+ huggingface_user: Optional[str] = None,
+ huggingface_token: Optional[str] = None,
+ virtual_display: bool = False,
+) -> None:
+ if virtual_display:
+ display = Display(visible=False, size=(1400, 900))
+ display.start()
+
+ api = wandb.Api()
+ runs = [api.run(rp) for rp in wandb_run_paths]
+ algo = runs[0].config["algo"]
+ hyperparam_id = runs[0].config["env"]
+ evaluations = [
+ evaluate_model(
+ EvalArgs(
+ algo,
+ hyperparam_id,
+ seed=r.config.get("seed", None),
+ render=False,
+ best=True,
+ n_envs=None,
+ n_episodes=10,
+ no_print_returns=True,
+ wandb_run_path="/".join(r.path),
+ ),
+ os.getcwd(),
+ )
+ for r in runs
+ ]
+ run_metadata = requests.get(runs[0].file("wandb-metadata.json").url).json()
+ table_data = list(EvalTableData(r, e) for r, e in zip(runs, evaluations))
+ best_eval = sorted(
+ table_data, key=lambda d: d.evaluation.stats.score, reverse=True
+ )[0]
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ _, (policy, stats, config) = best_eval
+
+ repo_name = config.model_name(include_seed=False)
+ repo_dir_path = os.path.join(tmpdirname, repo_name)
+ # Locally clone this repo to a temp directory
+ subprocess.run(["git", "clone", ".", repo_dir_path])
+ shutil.rmtree(os.path.join(repo_dir_path, ".git"))
+ model_path = config.model_dir_path(best=True, downloaded=True)
+ shutil.copytree(
+ model_path,
+ os.path.join(
+ repo_dir_path, "saved_models", config.model_dir_name(best=True)
+ ),
+ )
+
+ github_url = "https://github.com/sgoodfriend/rl-algo-impls"
+ commit_hash = run_metadata.get("git", {}).get("commit", None)
+ env_id = runs[0].config.get("env_id") or runs[0].config["env"]
+ card_text = model_card_text(
+ algo,
+ env_id,
+ github_url,
+ commit_hash,
+ wandb_report_url,
+ table_data,
+ best_eval,
+ )
+ readme_filepath = os.path.join(repo_dir_path, "README.md")
+ os.remove(readme_filepath)
+ with open(readme_filepath, "w") as f:
+ f.write(card_text)
+
+ metadata = {
+ "library_name": "rl-algo-impls",
+ "tags": [
+ env_id,
+ algo,
+ "deep-reinforcement-learning",
+ "reinforcement-learning",
+ ],
+ "model-index": [
+ {
+ "name": algo,
+ "results": [
+ {
+ "metrics": [
+ {
+ "type": "mean_reward",
+ "value": str(stats.score),
+ "name": "mean_reward",
+ }
+ ],
+ "task": {
+ "type": "reinforcement-learning",
+ "name": "reinforcement-learning",
+ },
+ "dataset": {
+ "name": env_id,
+ "type": env_id,
+ },
+ }
+ ],
+ }
+ ],
+ }
+ metadata_save(readme_filepath, metadata)
+
+ video_env = VecEpisodeRecorder(
+ make_eval_env(
+ config,
+ EnvHyperparams(**config.env_hyperparams),
+ override_n_envs=1,
+ normalize_load_path=model_path,
+ ),
+ os.path.join(repo_dir_path, "replay"),
+ max_video_length=3600,
+ )
+ evaluate(
+ video_env,
+ policy,
+ 1,
+ deterministic=config.eval_params.get("deterministic", True),
+ )
+
+ api = HfApi()
+ huggingface_user = huggingface_user or api.whoami()["name"]
+ huggingface_repo = f"{huggingface_user}/{repo_name}"
+ api.create_repo(
+ token=huggingface_token,
+ repo_id=huggingface_repo,
+ private=False,
+ exist_ok=True,
+ )
+ repo_url = upload_folder(
+ repo_id=huggingface_repo,
+ folder_path=repo_dir_path,
+ path_in_repo="",
+ commit_message=f"{algo.upper()} playing {env_id} from {github_url}/tree/{commit_hash}",
+ token=huggingface_token,
+ )
+ print(f"Pushed model to the hub: {repo_url}")
+
+
+def huggingface_publish():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--wandb-run-paths",
+ type=str,
+ nargs="+",
+ help="Run paths of the form entity/project/run_id",
+ )
+ parser.add_argument("--wandb-report-url", type=str, help="Link to WandB report")
+ parser.add_argument(
+ "--huggingface-user",
+ type=str,
+ help="Huggingface user or team to upload model cards",
+ default=None,
+ )
+ parser.add_argument(
+ "--virtual-display", action="store_true", help="Use headless virtual display"
+ )
+ args = parser.parse_args()
+ print(args)
+ publish(**vars(args))
+
+
+if __name__ == "__main__":
+ huggingface_publish()
diff --git a/rl_algo_impls/hyperparams/a2c.yml b/rl_algo_impls/hyperparams/a2c.yml
new file mode 100644
index 0000000000000000000000000000000000000000..217892eb3e72cb326e327ea7061c262899bdce2c
--- /dev/null
+++ b/rl_algo_impls/hyperparams/a2c.yml
@@ -0,0 +1,138 @@
+CartPole-v1: &cartpole-defaults
+ n_timesteps: !!float 5e5
+ env_hyperparams:
+ n_envs: 8
+
+CartPole-v0:
+ <<: *cartpole-defaults
+
+MountainCar-v0:
+ n_timesteps: !!float 1e6
+ env_hyperparams:
+ n_envs: 16
+ normalize: true
+
+MountainCarContinuous-v0:
+ n_timesteps: !!float 1e5
+ env_hyperparams:
+ n_envs: 4
+ normalize: true
+ # policy_hyperparams:
+ # use_sde: true
+ # log_std_init: 0.0
+ # init_layers_orthogonal: false
+ algo_hyperparams:
+ n_steps: 100
+ sde_sample_freq: 16
+
+Acrobot-v1:
+ n_timesteps: !!float 5e5
+ env_hyperparams:
+ normalize: true
+ n_envs: 16
+
+# Tuned
+LunarLander-v2:
+ device: cpu
+ n_timesteps: !!float 1e6
+ env_hyperparams:
+ n_envs: 4
+ normalize: true
+ algo_hyperparams:
+ n_steps: 2
+ gamma: 0.9955517404308908
+ gae_lambda: 0.9875340918797773
+ learning_rate: 0.0013814130817068916
+ learning_rate_decay: linear
+ ent_coef: !!float 3.388369146384422e-7
+ ent_coef_decay: none
+ max_grad_norm: 3.33982095073364
+ normalize_advantage: true
+ vf_coef: 0.1667838310548184
+
+BipedalWalker-v3:
+ n_timesteps: !!float 5e6
+ env_hyperparams:
+ n_envs: 16
+ normalize: true
+ policy_hyperparams:
+ use_sde: true
+ log_std_init: -2
+ init_layers_orthogonal: false
+ algo_hyperparams:
+ ent_coef: 0
+ max_grad_norm: 0.5
+ n_steps: 8
+ gae_lambda: 0.9
+ vf_coef: 0.4
+ gamma: 0.99
+ learning_rate: !!float 9.6e-4
+ learning_rate_decay: linear
+
+HalfCheetahBulletEnv-v0: &pybullet-defaults
+ n_timesteps: !!float 2e6
+ env_hyperparams:
+ n_envs: 4
+ normalize: true
+ policy_hyperparams:
+ use_sde: true
+ log_std_init: -2
+ init_layers_orthogonal: false
+ algo_hyperparams: &pybullet-algo-defaults
+ n_steps: 8
+ ent_coef: 0
+ max_grad_norm: 0.5
+ gae_lambda: 0.9
+ gamma: 0.99
+ vf_coef: 0.4
+ learning_rate: !!float 9.6e-4
+ learning_rate_decay: linear
+
+AntBulletEnv-v0:
+ <<: *pybullet-defaults
+
+Walker2DBulletEnv-v0:
+ <<: *pybullet-defaults
+
+HopperBulletEnv-v0:
+ <<: *pybullet-defaults
+
+CarRacing-v0:
+ n_timesteps: !!float 4e6
+ env_hyperparams:
+ n_envs: 8
+ frame_stack: 4
+ normalize: true
+ normalize_kwargs:
+ norm_obs: false
+ norm_reward: true
+ policy_hyperparams:
+ use_sde: true
+ log_std_init: -2
+ init_layers_orthogonal: false
+ activation_fn: relu
+ share_features_extractor: false
+ cnn_feature_dim: 256
+ hidden_sizes: [256]
+ algo_hyperparams:
+ n_steps: 512
+ learning_rate: !!float 1.62e-5
+ gamma: 0.997
+ gae_lambda: 0.975
+ ent_coef: 0
+ sde_sample_freq: 128
+ vf_coef: 0.64
+
+_atari: &atari-defaults
+ n_timesteps: !!float 1e7
+ env_hyperparams: &atari-env-defaults
+ n_envs: 16
+ frame_stack: 4
+ no_reward_timeout_steps: 1000
+ no_reward_fire_steps: 500
+ vec_env_class: async
+ policy_hyperparams: &atari-policy-defaults
+ activation_fn: relu
+ algo_hyperparams:
+ ent_coef: 0.01
+ vf_coef: 0.25
diff --git a/rl_algo_impls/hyperparams/dqn.yml b/rl_algo_impls/hyperparams/dqn.yml
new file mode 100644
index 0000000000000000000000000000000000000000..66003a67dd8c7865fd9ea269f6fc84b5d95fb428
--- /dev/null
+++ b/rl_algo_impls/hyperparams/dqn.yml
@@ -0,0 +1,130 @@
+CartPole-v1: &cartpole-defaults
+ n_timesteps: !!float 5e4
+ env_hyperparams:
+ rolling_length: 50
+ policy_hyperparams:
+ hidden_sizes: [256, 256]
+ algo_hyperparams:
+ learning_rate: !!float 2.3e-3
+ batch_size: 64
+ buffer_size: 100000
+ learning_starts: 1000
+ gamma: 0.99
+ target_update_interval: 10
+ train_freq: 256
+ gradient_steps: 128
+ exploration_fraction: 0.16
+ exploration_final_eps: 0.04
+ eval_params:
+ step_freq: !!float 1e4
+
+CartPole-v0:
+ <<: *cartpole-defaults
+ n_timesteps: !!float 4e4
+
+MountainCar-v0:
+ n_timesteps: !!float 1.2e5
+ env_hyperparams:
+ rolling_length: 50
+ policy_hyperparams:
+ hidden_sizes: [256, 256]
+ algo_hyperparams:
+ learning_rate: !!float 4e-3
+ batch_size: 128
+ buffer_size: 10000
+ learning_starts: 1000
+ gamma: 0.98
+ target_update_interval: 600
+ train_freq: 16
+ gradient_steps: 8
+ exploration_fraction: 0.2
+ exploration_final_eps: 0.07
+
+Acrobot-v1:
+ n_timesteps: !!float 1e5
+ env_hyperparams:
+ rolling_length: 50
+ policy_hyperparams:
+ hidden_sizes: [256, 256]
+ algo_hyperparams:
+ learning_rate: !!float 6.3e-4
+ batch_size: 128
+ buffer_size: 50000
+ learning_starts: 0
+ gamma: 0.99
+ target_update_interval: 250
+ train_freq: 4
+ gradient_steps: -1
+ exploration_fraction: 0.12
+ exploration_final_eps: 0.1
+
+LunarLander-v2:
+ n_timesteps: !!float 5e5
+ env_hyperparams:
+ rolling_length: 50
+ policy_hyperparams:
+ hidden_sizes: [256, 256]
+ algo_hyperparams:
+ learning_rate: !!float 1e-4
+ batch_size: 256
+ buffer_size: 100000
+ learning_starts: 10000
+ gamma: 0.99
+ target_update_interval: 250
+ train_freq: 8
+ gradient_steps: -1
+ exploration_fraction: 0.12
+ exploration_final_eps: 0.1
+ max_grad_norm: 0.5
+ eval_params:
+ step_freq: 25_000
+
+_atari: &atari-defaults
+ n_timesteps: !!float 1e7
+ env_hyperparams:
+ frame_stack: 4
+ no_reward_timeout_steps: 1_000
+ no_reward_fire_steps: 500
+ n_envs: 8
+ vec_env_class: async
+ algo_hyperparams:
+ buffer_size: 100000
+ learning_rate: !!float 1e-4
+ batch_size: 32
+ learning_starts: 100000
+ target_update_interval: 1000
+ train_freq: 8
+ gradient_steps: 2
+ exploration_fraction: 0.1
+ exploration_final_eps: 0.01
+ eval_params:
+ deterministic: false
+
+PongNoFrameskip-v4:
+ <<: *atari-defaults
+ n_timesteps: !!float 2.5e6
+
+_impala-atari: &impala-atari-defaults
+ <<: *atari-defaults
+ policy_hyperparams:
+ cnn_style: impala
+ cnn_feature_dim: 256
+ init_layers_orthogonal: true
+ cnn_layers_init_orthogonal: false
+
+impala-PongNoFrameskip-v4:
+ <<: *impala-atari-defaults
+ env_id: PongNoFrameskip-v4
+ n_timesteps: !!float 2.5e6
+
+impala-BreakoutNoFrameskip-v4:
+ <<: *impala-atari-defaults
+ env_id: BreakoutNoFrameskip-v4
+
+impala-SpaceInvadersNoFrameskip-v4:
+ <<: *impala-atari-defaults
+ env_id: SpaceInvadersNoFrameskip-v4
+
+impala-QbertNoFrameskip-v4:
+ <<: *impala-atari-defaults
+ env_id: QbertNoFrameskip-v4
diff --git a/rl_algo_impls/hyperparams/ppo.yml b/rl_algo_impls/hyperparams/ppo.yml
new file mode 100644
index 0000000000000000000000000000000000000000..0136fc46b019f9722aaad54bfe93f36ed88f4bc8
--- /dev/null
+++ b/rl_algo_impls/hyperparams/ppo.yml
@@ -0,0 +1,383 @@
+CartPole-v1: &cartpole-defaults
+ n_timesteps: !!float 1e5
+ env_hyperparams:
+ n_envs: 8
+ algo_hyperparams:
+ n_steps: 32
+ batch_size: 256
+ n_epochs: 20
+ gae_lambda: 0.8
+ gamma: 0.98
+ ent_coef: 0.0
+ learning_rate: 0.001
+ learning_rate_decay: linear
+ clip_range: 0.2
+ clip_range_decay: linear
+ eval_params:
+ step_freq: !!float 2.5e4
+
+CartPole-v0:
+ <<: *cartpole-defaults
+ n_timesteps: !!float 5e4
+
+MountainCar-v0:
+ n_timesteps: !!float 1e6
+ env_hyperparams:
+ normalize: true
+ n_envs: 16
+ algo_hyperparams:
+ n_steps: 16
+ n_epochs: 4
+ gae_lambda: 0.98
+ gamma: 0.99
+ ent_coef: 0.0
+
+MountainCarContinuous-v0:
+ n_timesteps: !!float 1e5
+ env_hyperparams:
+ normalize: true
+ n_envs: 4
+ # policy_hyperparams:
+ # init_layers_orthogonal: false
+ # log_std_init: -3.29
+ # use_sde: true
+ algo_hyperparams:
+ n_steps: 512
+ batch_size: 256
+ n_epochs: 10
+ learning_rate: !!float 7.77e-5
+ ent_coef: 0.01 # 0.00429
+ ent_coef_decay: linear
+ clip_range: 0.1
+ gae_lambda: 0.9
+ max_grad_norm: 5
+ vf_coef: 0.19
+ eval_params:
+ step_freq: 5000
+
+Acrobot-v1:
+ n_timesteps: !!float 1e6
+ env_hyperparams:
+ n_envs: 16
+ normalize: true
+ algo_hyperparams:
+ n_steps: 256
+ n_epochs: 4
+ gae_lambda: 0.94
+ gamma: 0.99
+ ent_coef: 0.0
+
+LunarLander-v2:
+ n_timesteps: !!float 4e6
+ env_hyperparams:
+ n_envs: 16
+ algo_hyperparams:
+ n_steps: 1024
+ batch_size: 64
+ n_epochs: 4
+ gae_lambda: 0.98
+ gamma: 0.999
+ learning_rate: !!float 5e-4
+ learning_rate_decay: linear
+ clip_range: 0.2
+ clip_range_decay: linear
+ ent_coef: 0.01
+ normalize_advantage: false
+
+BipedalWalker-v3:
+ n_timesteps: !!float 10e6
+ env_hyperparams:
+ n_envs: 16
+ normalize: true
+ algo_hyperparams:
+ n_steps: 2048
+ batch_size: 64
+ gae_lambda: 0.95
+ gamma: 0.99
+ n_epochs: 10
+ ent_coef: 0.001
+ learning_rate: !!float 2.5e-4
+ learning_rate_decay: linear
+ clip_range: 0.2
+ clip_range_decay: linear
+
+CarRacing-v0: &carracing-defaults
+ n_timesteps: !!float 4e6
+ env_hyperparams:
+ n_envs: 8
+ frame_stack: 4
+ policy_hyperparams: &carracing-policy-defaults
+ use_sde: true
+ log_std_init: -2
+ init_layers_orthogonal: false
+ activation_fn: relu
+ share_features_extractor: false
+ cnn_feature_dim: 256
+ hidden_sizes: [256]
+ algo_hyperparams:
+ n_steps: 512
+ batch_size: 128
+ n_epochs: 10
+ learning_rate: !!float 1e-4
+ learning_rate_decay: linear
+ gamma: 0.99
+ gae_lambda: 0.95
+ ent_coef: 0.0
+ sde_sample_freq: 4
+ max_grad_norm: 0.5
+ vf_coef: 0.5
+ clip_range: 0.2
+
+impala-CarRacing-v0:
+ <<: *carracing-defaults
+ env_id: CarRacing-v0
+ policy_hyperparams:
+ <<: *carracing-policy-defaults
+ cnn_style: impala
+ init_layers_orthogonal: true
+ cnn_layers_init_orthogonal: false
+ hidden_sizes: []
+
+# BreakoutNoFrameskip-v4
+# PongNoFrameskip-v4
+# SpaceInvadersNoFrameskip-v4
+# QbertNoFrameskip-v4
+_atari: &atari-defaults
+ n_timesteps: !!float 1e7
+ env_hyperparams: &atari-env-defaults
+ n_envs: 8
+ frame_stack: 4
+ no_reward_timeout_steps: 1000
+ no_reward_fire_steps: 500
+ vec_env_class: async
+ policy_hyperparams: &atari-policy-defaults
+ activation_fn: relu
+ algo_hyperparams:
+ n_steps: 128
+ batch_size: 256
+ n_epochs: 4
+ learning_rate: !!float 2.5e-4
+ learning_rate_decay: linear
+ clip_range: 0.1
+ clip_range_decay: linear
+ vf_coef: 0.5
+ ent_coef: 0.01
+ eval_params:
+ deterministic: false
+
+_norm-rewards-atari: &norm-rewards-atari-default
+ <<: *atari-defaults
+ env_hyperparams:
+ <<: *atari-env-defaults
+ clip_atari_rewards: false
+ normalize: true
+ normalize_kwargs:
+ norm_obs: false
+ norm_reward: true
+
+norm-rewards-BreakoutNoFrameskip-v4:
+ <<: *norm-rewards-atari-default
+ env_id: BreakoutNoFrameskip-v4
+
+debug-PongNoFrameskip-v4:
+ <<: *atari-defaults
+ device: cpu
+ env_id: PongNoFrameskip-v4
+ env_hyperparams:
+ <<: *atari-env-defaults
+ vec_env_class: sync
+
+_impala-atari: &impala-atari-defaults
+ <<: *atari-defaults
+ policy_hyperparams:
+ <<: *atari-policy-defaults
+ cnn_style: impala
+ cnn_feature_dim: 256
+ init_layers_orthogonal: true
+ cnn_layers_init_orthogonal: false
+
+impala-PongNoFrameskip-v4:
+ <<: *impala-atari-defaults
+ env_id: PongNoFrameskip-v4
+
+impala-BreakoutNoFrameskip-v4:
+ <<: *impala-atari-defaults
+ env_id: BreakoutNoFrameskip-v4
+
+impala-SpaceInvadersNoFrameskip-v4:
+ <<: *impala-atari-defaults
+ env_id: SpaceInvadersNoFrameskip-v4
+
+impala-QbertNoFrameskip-v4:
+ <<: *impala-atari-defaults
+ env_id: QbertNoFrameskip-v4
+
+HalfCheetahBulletEnv-v0: &pybullet-defaults
+ n_timesteps: !!float 2e6
+ env_hyperparams: &pybullet-env-defaults
+ n_envs: 16
+ normalize: true
+ policy_hyperparams: &pybullet-policy-defaults
+ pi_hidden_sizes: [256, 256]
+ v_hidden_sizes: [256, 256]
+ activation_fn: relu
+ algo_hyperparams: &pybullet-algo-defaults
+ n_steps: 512
+ batch_size: 128
+ n_epochs: 20
+ gamma: 0.99
+ gae_lambda: 0.9
+ ent_coef: 0.0
+ max_grad_norm: 0.5
+ vf_coef: 0.5
+ learning_rate: !!float 3e-5
+ clip_range: 0.4
+
+AntBulletEnv-v0:
+ <<: *pybullet-defaults
+ policy_hyperparams:
+ <<: *pybullet-policy-defaults
+ algo_hyperparams:
+ <<: *pybullet-algo-defaults
+
+Walker2DBulletEnv-v0:
+ <<: *pybullet-defaults
+ algo_hyperparams:
+ <<: *pybullet-algo-defaults
+ clip_range_decay: linear
+
+HopperBulletEnv-v0:
+ <<: *pybullet-defaults
+ algo_hyperparams:
+ <<: *pybullet-algo-defaults
+ clip_range_decay: linear
+
+HumanoidBulletEnv-v0:
+ <<: *pybullet-defaults
+ n_timesteps: !!float 1e7
+ env_hyperparams:
+ <<: *pybullet-env-defaults
+ n_envs: 8
+ policy_hyperparams:
+ <<: *pybullet-policy-defaults
+ # log_std_init: -1
+ algo_hyperparams:
+ <<: *pybullet-algo-defaults
+ n_steps: 2048
+ batch_size: 64
+ n_epochs: 10
+ gae_lambda: 0.95
+ learning_rate: !!float 2.5e-4
+ clip_range: 0.2
+
+_procgen: &procgen-defaults
+ env_hyperparams: &procgen-env-defaults
+ env_type: procgen
+ n_envs: 64
+ # grayscale: false
+ # frame_stack: 4
+ normalize: true # procgen only normalizes reward
+ make_kwargs: &procgen-make-kwargs-defaults
+ num_threads: 8
+ policy_hyperparams: &procgen-policy-defaults
+ activation_fn: relu
+ cnn_style: impala
+ cnn_feature_dim: 256
+ init_layers_orthogonal: true
+ cnn_layers_init_orthogonal: false
+ algo_hyperparams: &procgen-algo-defaults
+ gamma: 0.999
+ gae_lambda: 0.95
+ n_steps: 256
+ batch_size: 2048
+ n_epochs: 3
+ ent_coef: 0.01
+ clip_range: 0.2
+ # clip_range_decay: linear
+ clip_range_vf: 0.2
+ learning_rate: !!float 5e-4
+ # learning_rate_decay: linear
+ vf_coef: 0.5
+ eval_params: &procgen-eval-defaults
+ ignore_first_episode: true
+ # deterministic: false
+ step_freq: !!float 1e5
+
+_procgen-easy: &procgen-easy-defaults
+ <<: *procgen-defaults
+ n_timesteps: !!float 25e6
+ env_hyperparams: &procgen-easy-env-defaults
+ <<: *procgen-env-defaults
+ make_kwargs:
+ <<: *procgen-make-kwargs-defaults
+ distribution_mode: easy
+
+procgen-coinrun-easy: &coinrun-easy-defaults
+ <<: *procgen-easy-defaults
+ env_id: coinrun
+
+debug-procgen-coinrun:
+ <<: *coinrun-easy-defaults
+ device: cpu
+
+procgen-starpilot-easy:
+ <<: *procgen-easy-defaults
+ env_id: starpilot
+
+procgen-bossfight-easy:
+ <<: *procgen-easy-defaults
+ env_id: bossfight
+
+procgen-bigfish-easy:
+ <<: *procgen-easy-defaults
+ env_id: bigfish
+
+_procgen-hard: &procgen-hard-defaults
+ <<: *procgen-defaults
+ n_timesteps: !!float 200e6
+ env_hyperparams: &procgen-hard-env-defaults
+ <<: *procgen-env-defaults
+ n_envs: 256
+ make_kwargs:
+ <<: *procgen-make-kwargs-defaults
+ distribution_mode: hard
+ algo_hyperparams: &procgen-hard-algo-defaults
+ <<: *procgen-algo-defaults
+ batch_size: 8192
+ clip_range_decay: linear
+ learning_rate_decay: linear
+ eval_params:
+ <<: *procgen-eval-defaults
+ step_freq: !!float 5e5
+
+procgen-starpilot-hard: &procgen-starpilot-hard-defaults
+ <<: *procgen-hard-defaults
+ env_id: starpilot
+
+procgen-starpilot-hard-2xIMPALA:
+ <<: *procgen-starpilot-hard-defaults
+ policy_hyperparams:
+ <<: *procgen-policy-defaults
+ impala_channels: [32, 64, 64]
+ algo_hyperparams:
+ <<: *procgen-hard-algo-defaults
+ learning_rate: !!float 3.3e-4
+
+procgen-starpilot-hard-2xIMPALA-fat:
+ <<: *procgen-starpilot-hard-defaults
+ policy_hyperparams:
+ <<: *procgen-policy-defaults
+ impala_channels: [32, 64, 64]
+ cnn_feature_dim: 512
+ algo_hyperparams:
+ <<: *procgen-hard-algo-defaults
+ learning_rate: !!float 2.5e-4
+
+procgen-starpilot-hard-4xIMPALA:
+ <<: *procgen-starpilot-hard-defaults
+ policy_hyperparams:
+ <<: *procgen-policy-defaults
+ impala_channels: [64, 128, 128]
+ algo_hyperparams:
+ <<: *procgen-hard-algo-defaults
+ learning_rate: !!float 2.1e-4
diff --git a/rl_algo_impls/hyperparams/vpg.yml b/rl_algo_impls/hyperparams/vpg.yml
new file mode 100644
index 0000000000000000000000000000000000000000..e472a9226b830c127f044718672d6d0c9e8c83dc
--- /dev/null
+++ b/rl_algo_impls/hyperparams/vpg.yml
@@ -0,0 +1,197 @@
+CartPole-v1: &cartpole-defaults
+ n_timesteps: !!float 4e5
+ algo_hyperparams:
+ n_steps: 4096
+ pi_lr: 0.01
+ gamma: 0.99
+ gae_lambda: 1
+ val_lr: 0.01
+ train_v_iters: 80
+ eval_params:
+ step_freq: !!float 2.5e4
+
+CartPole-v0:
+ <<: *cartpole-defaults
+ n_timesteps: !!float 1e5
+ algo_hyperparams:
+ n_steps: 1024
+ pi_lr: 0.01
+ gamma: 0.99
+ gae_lambda: 1
+ val_lr: 0.01
+ train_v_iters: 80
+
+MountainCar-v0:
+ n_timesteps: !!float 1e6
+ env_hyperparams:
+ normalize: true
+ n_envs: 16
+ algo_hyperparams:
+ n_steps: 200
+ pi_lr: 0.005
+ gamma: 0.99
+ gae_lambda: 0.97
+ val_lr: 0.01
+ train_v_iters: 80
+ max_grad_norm: 0.5
+
+MountainCarContinuous-v0:
+ n_timesteps: !!float 3e5
+ env_hyperparams:
+ normalize: true
+ n_envs: 4
+ # policy_hyperparams:
+ # init_layers_orthogonal: false
+ # log_std_init: -3.29
+ # use_sde: true
+ algo_hyperparams:
+ n_steps: 1000
+ pi_lr: !!float 5e-4
+ gamma: 0.99
+ gae_lambda: 0.9
+ val_lr: !!float 1e-3
+ train_v_iters: 80
+ max_grad_norm: 5
+ eval_params:
+ step_freq: 5000
+
+Acrobot-v1:
+ n_timesteps: !!float 2e5
+ algo_hyperparams:
+ n_steps: 2048
+ pi_lr: 0.005
+ gamma: 0.99
+ gae_lambda: 0.97
+ val_lr: 0.01
+ train_v_iters: 80
+ max_grad_norm: 0.5
+
+LunarLander-v2:
+ n_timesteps: !!float 4e6
+ policy_hyperparams:
+ hidden_sizes: [256, 256]
+ algo_hyperparams:
+ n_steps: 2048
+ pi_lr: 0.0001
+ gamma: 0.999
+ gae_lambda: 0.97
+ val_lr: 0.0001
+ train_v_iters: 80
+ max_grad_norm: 0.5
+ eval_params:
+ deterministic: false
+
+BipedalWalker-v3:
+ n_timesteps: !!float 10e6
+ env_hyperparams:
+ n_envs: 16
+ normalize: true
+ policy_hyperparams:
+ hidden_sizes: [256, 256]
+ algo_hyperparams:
+ n_steps: 1600
+ gae_lambda: 0.95
+ gamma: 0.99
+ pi_lr: !!float 1e-4
+ val_lr: !!float 1e-4
+ train_v_iters: 80
+ max_grad_norm: 0.5
+ eval_params:
+ deterministic: false
+
+CarRacing-v0:
+ n_timesteps: !!float 4e6
+ env_hyperparams:
+ frame_stack: 4
+ n_envs: 4
+ vec_env_class: sync
+ policy_hyperparams:
+ use_sde: true
+ log_std_init: -2
+ init_layers_orthogonal: false
+ activation_fn: relu
+ cnn_feature_dim: 256
+ hidden_sizes: [256]
+ algo_hyperparams:
+ n_steps: 1000
+ pi_lr: !!float 5e-5
+ gamma: 0.99
+ gae_lambda: 0.95
+ val_lr: !!float 1e-4
+ train_v_iters: 40
+ max_grad_norm: 0.5
+ sde_sample_freq: 4
+
+HalfCheetahBulletEnv-v0: &pybullet-defaults
+ n_timesteps: !!float 2e6
+ env_hyperparams: &pybullet-env-defaults
+ normalize: true
+ policy_hyperparams: &pybullet-policy-defaults
+ hidden_sizes: [256, 256]
+ algo_hyperparams: &pybullet-algo-defaults
+ n_steps: 4000
+ pi_lr: !!float 3e-4
+ gamma: 0.99
+ gae_lambda: 0.97
+ val_lr: !!float 1e-3
+ train_v_iters: 80
+ max_grad_norm: 0.5
+
+AntBulletEnv-v0:
+ <<: *pybullet-defaults
+ policy_hyperparams:
+ <<: *pybullet-policy-defaults
+ hidden_sizes: [400, 300]
+ algo_hyperparams:
+ <<: *pybullet-algo-defaults
+ pi_lr: !!float 7e-4
+ val_lr: !!float 7e-3
+
+HopperBulletEnv-v0:
+ <<: *pybullet-defaults
+
+Walker2DBulletEnv-v0:
+ <<: *pybullet-defaults
+
+FrozenLake-v1:
+ n_timesteps: !!float 8e5
+ env_params:
+ make_kwargs:
+ map_name: 8x8
+ is_slippery: true
+ policy_hyperparams:
+ hidden_sizes: [64]
+ algo_hyperparams:
+ n_steps: 2048
+ pi_lr: 0.01
+ gamma: 0.99
+ gae_lambda: 0.98
+ val_lr: 0.01
+ train_v_iters: 80
+ max_grad_norm: 0.5
+ eval_params:
+ step_freq: !!float 5e4
+ n_episodes: 10
+ save_best: true
+
+_atari: &atari-defaults
+ n_timesteps: !!float 25e6
+ env_hyperparams:
+ n_envs: 4
+ frame_stack: 4
+ no_reward_timeout_steps: 1000
+ no_reward_fire_steps: 500
+ vec_env_class: async
+ policy_hyperparams:
+ activation_fn: relu
+ algo_hyperparams:
+ n_steps: 2048
+ pi_lr: !!float 5e-5
+ gamma: 0.99
+ gae_lambda: 0.95
+ val_lr: !!float 1e-4
+ train_v_iters: 80
+ max_grad_norm: 0.5
+ ent_coef: 0.01
+ eval_params:
+ deterministic: false
diff --git a/rl_algo_impls/optimize.py b/rl_algo_impls/optimize.py
new file mode 100644
index 0000000000000000000000000000000000000000..1078a9bea61fd49f728d25b7ece323202ad31104
--- /dev/null
+++ b/rl_algo_impls/optimize.py
@@ -0,0 +1,441 @@
+import dataclasses
+import gc
+import inspect
+import logging
+import numpy as np
+import optuna
+import os
+import torch
+import wandb
+
+from dataclasses import asdict, dataclass
+from optuna.pruners import HyperbandPruner
+from optuna.samplers import TPESampler
+from optuna.visualization import plot_optimization_history, plot_param_importances
+from torch.utils.tensorboard.writer import SummaryWriter
+from typing import Callable, List, NamedTuple, Optional, Sequence, Union
+
+from rl_algo_impls.a2c.optimize import sample_params as a2c_sample_params
+from rl_algo_impls.runner.config import Config, EnvHyperparams, RunArgs
+from rl_algo_impls.runner.env import make_env, make_eval_env
+from rl_algo_impls.runner.running_utils import (
+ base_parser,
+ load_hyperparams,
+ set_seeds,
+ get_device,
+ make_policy,
+ ALGOS,
+ hparam_dict,
+)
+from rl_algo_impls.shared.callbacks.optimize_callback import (
+ Evaluation,
+ OptimizeCallback,
+ evaluation,
+)
+from rl_algo_impls.shared.stats import EpisodesStats
+
+
+@dataclass
+class StudyArgs:
+ load_study: bool
+ study_name: Optional[str] = None
+ storage_path: Optional[str] = None
+ n_trials: int = 100
+ n_jobs: int = 1
+ n_evaluations: int = 4
+ n_eval_envs: int = 8
+ n_eval_episodes: int = 16
+ timeout: Union[int, float, None] = None
+ wandb_project_name: Optional[str] = None
+ wandb_entity: Optional[str] = None
+ wandb_tags: Sequence[str] = dataclasses.field(default_factory=list)
+ wandb_group: Optional[str] = None
+ virtual_display: bool = False
+
+
+class Args(NamedTuple):
+ train_args: Sequence[RunArgs]
+ study_args: StudyArgs
+
+
+def parse_args() -> Args:
+ parser = base_parser()
+ parser.add_argument(
+ "--load-study",
+ action="store_true",
+ help="Load a preexisting study, useful for parallelization",
+ )
+ parser.add_argument("--study-name", type=str, help="Optuna study name")
+ parser.add_argument(
+ "--storage-path",
+ type=str,
+ help="Path of database for Optuna to persist to",
+ )
+ parser.add_argument(
+ "--wandb-project-name",
+ type=str,
+ default="rl-algo-impls-tuning",
+ help="WandB project name to upload tuning data to. If none, won't upload",
+ )
+ parser.add_argument(
+ "--wandb-entity",
+ type=str,
+ help="WandB team. None uses the default entity",
+ )
+ parser.add_argument(
+ "--wandb-tags", type=str, nargs="*", help="WandB tags to add to run"
+ )
+ parser.add_argument(
+ "--wandb-group", type=str, help="WandB group to group trials under"
+ )
+ parser.add_argument(
+ "--n-trials", type=int, default=100, help="Maximum number of trials"
+ )
+ parser.add_argument(
+ "--n-jobs", type=int, default=1, help="Number of jobs to run in parallel"
+ )
+ parser.add_argument(
+ "--n-evaluations",
+ type=int,
+ default=4,
+ help="Number of evaluations during the training",
+ )
+ parser.add_argument(
+ "--n-eval-envs",
+ type=int,
+ default=8,
+ help="Number of envs in vectorized eval environment",
+ )
+ parser.add_argument(
+ "--n-eval-episodes",
+ type=int,
+ default=16,
+ help="Number of episodes to complete for evaluation",
+ )
+ parser.add_argument("--timeout", type=int, help="Seconds to timeout optimization")
+ parser.add_argument(
+ "--virtual-display", action="store_true", help="Use headless virtual display"
+ )
+ # parser.set_defaults(
+ # algo=["a2c"],
+ # env=["CartPole-v1"],
+ # seed=[100, 200, 300],
+ # n_trials=5,
+ # virtual_display=True,
+ # )
+ train_dict, study_dict = {}, {}
+ for k, v in vars(parser.parse_args()).items():
+ if k in inspect.signature(StudyArgs).parameters:
+ study_dict[k] = v
+ else:
+ train_dict[k] = v
+
+ study_args = StudyArgs(**study_dict)
+ # Hyperparameter tuning across algos and envs not supported
+ assert len(train_dict["algo"]) == 1
+ assert len(train_dict["env"]) == 1
+ train_args = RunArgs.expand_from_dict(train_dict)
+
+ if not all((study_args.study_name, study_args.storage_path)):
+ hyperparams = load_hyperparams(train_args[0].algo, train_args[0].env)
+ config = Config(train_args[0], hyperparams, os.getcwd())
+ if study_args.study_name is None:
+ study_args.study_name = config.run_name(include_seed=False)
+ if study_args.storage_path is None:
+ study_args.storage_path = (
+ f"sqlite:///{os.path.join(config.runs_dir, 'tuning.db')}"
+ )
+ # Default set group name to study name
+ study_args.wandb_group = study_args.wandb_group or study_args.study_name
+
+ return Args(train_args, study_args)
+
+
+def objective_fn(
+ args: Sequence[RunArgs], study_args: StudyArgs
+) -> Callable[[optuna.Trial], float]:
+ def objective(trial: optuna.Trial) -> float:
+ if len(args) == 1:
+ return simple_optimize(trial, args[0], study_args)
+ else:
+ return stepwise_optimize(trial, args, study_args)
+
+ return objective
+
+
+def simple_optimize(trial: optuna.Trial, args: RunArgs, study_args: StudyArgs) -> float:
+ base_hyperparams = load_hyperparams(args.algo, args.env)
+ base_config = Config(args, base_hyperparams, os.getcwd())
+ if args.algo == "a2c":
+ hyperparams = a2c_sample_params(trial, base_hyperparams, base_config)
+ else:
+ raise ValueError(f"Optimizing {args.algo} isn't supported")
+ config = Config(args, hyperparams, os.getcwd())
+
+ wandb_enabled = bool(study_args.wandb_project_name)
+ if wandb_enabled:
+ wandb.init(
+ project=study_args.wandb_project_name,
+ entity=study_args.wandb_entity,
+ config=asdict(hyperparams),
+ name=f"{config.model_name()}-{str(trial.number)}",
+ tags=study_args.wandb_tags,
+ group=study_args.wandb_group,
+ sync_tensorboard=True,
+ monitor_gym=True,
+ save_code=True,
+ reinit=True,
+ )
+ wandb.config.update(args)
+
+ tb_writer = SummaryWriter(config.tensorboard_summary_path)
+ set_seeds(args.seed, args.use_deterministic_algorithms)
+
+ env = make_env(
+ config, EnvHyperparams(**config.env_hyperparams), tb_writer=tb_writer
+ )
+ device = get_device(config.device, env)
+ policy = make_policy(args.algo, env, device, **config.policy_hyperparams)
+ algo = ALGOS[args.algo](policy, env, device, tb_writer, **config.algo_hyperparams)
+
+ eval_env = make_eval_env(
+ config,
+ EnvHyperparams(**config.env_hyperparams),
+ override_n_envs=study_args.n_eval_envs,
+ )
+ callback = OptimizeCallback(
+ policy,
+ eval_env,
+ trial,
+ tb_writer,
+ step_freq=config.n_timesteps // study_args.n_evaluations,
+ n_episodes=study_args.n_eval_episodes,
+ deterministic=config.eval_params.get("deterministic", True),
+ )
+ try:
+ algo.learn(config.n_timesteps, callback=callback)
+
+ if not callback.is_pruned:
+ callback.evaluate()
+ if not callback.is_pruned:
+ policy.save(config.model_dir_path(best=False))
+
+ eval_stat: EpisodesStats = callback.last_eval_stat # type: ignore
+ train_stat: EpisodesStats = callback.last_train_stat # type: ignore
+
+ tb_writer.add_hparams(
+ hparam_dict(hyperparams, vars(args)),
+ {
+ "hparam/last_mean": eval_stat.score.mean,
+ "hparam/last_result": eval_stat.score.mean - eval_stat.score.std,
+ "hparam/train_mean": train_stat.score.mean,
+ "hparam/train_result": train_stat.score.mean - train_stat.score.std,
+ "hparam/score": callback.last_score,
+ "hparam/is_pruned": callback.is_pruned,
+ },
+ None,
+ config.run_name(),
+ )
+ tb_writer.close()
+
+ if wandb_enabled:
+ wandb.run.summary["state"] = "Pruned" if callback.is_pruned else "Complete"
+ wandb.finish(quiet=True)
+
+ if callback.is_pruned:
+ raise optuna.exceptions.TrialPruned()
+
+ return callback.last_score
+ except AssertionError as e:
+ logging.warning(e)
+ return np.nan
+ finally:
+ env.close()
+ eval_env.close()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+
+def stepwise_optimize(
+ trial: optuna.Trial, args: Sequence[RunArgs], study_args: StudyArgs
+) -> float:
+ algo = args[0].algo
+ env_id = args[0].env
+ base_hyperparams = load_hyperparams(algo, env_id)
+ base_config = Config(args[0], base_hyperparams, os.getcwd())
+ if algo == "a2c":
+ hyperparams = a2c_sample_params(trial, base_hyperparams, base_config)
+ else:
+ raise ValueError(f"Optimizing {algo} isn't supported")
+
+ wandb_enabled = bool(study_args.wandb_project_name)
+ if wandb_enabled:
+ wandb.init(
+ project=study_args.wandb_project_name,
+ entity=study_args.wandb_entity,
+ config=asdict(hyperparams),
+ name=f"{study_args.study_name}-{str(trial.number)}",
+ tags=study_args.wandb_tags,
+ group=study_args.wandb_group,
+ save_code=True,
+ reinit=True,
+ )
+
+ score = -np.inf
+
+ for i in range(study_args.n_evaluations):
+ evaluations: List[Evaluation] = []
+
+ for arg in args:
+ config = Config(arg, hyperparams, os.getcwd())
+
+ tb_writer = SummaryWriter(config.tensorboard_summary_path)
+ set_seeds(arg.seed, arg.use_deterministic_algorithms)
+
+ env = make_env(
+ config,
+ EnvHyperparams(**config.env_hyperparams),
+ normalize_load_path=config.model_dir_path() if i > 0 else None,
+ tb_writer=tb_writer,
+ )
+ device = get_device(config.device, env)
+ policy = make_policy(arg.algo, env, device, **config.policy_hyperparams)
+ if i > 0:
+ policy.load(config.model_dir_path())
+ algo = ALGOS[arg.algo](
+ policy, env, device, tb_writer, **config.algo_hyperparams
+ )
+
+ eval_env = make_eval_env(
+ config,
+ EnvHyperparams(**config.env_hyperparams),
+ normalize_load_path=config.model_dir_path() if i > 0 else None,
+ override_n_envs=study_args.n_eval_envs,
+ )
+
+ start_timesteps = int(i * config.n_timesteps / study_args.n_evaluations)
+ train_timesteps = (
+ int((i + 1) * config.n_timesteps / study_args.n_evaluations)
+ - start_timesteps
+ )
+
+ try:
+ algo.learn(
+ train_timesteps,
+ callback=None,
+ total_timesteps=config.n_timesteps,
+ start_timesteps=start_timesteps,
+ )
+
+ evaluations.append(
+ evaluation(
+ policy,
+ eval_env,
+ tb_writer,
+ study_args.n_eval_episodes,
+ config.eval_params.get("deterministic", True),
+ start_timesteps + train_timesteps,
+ )
+ )
+
+ policy.save(config.model_dir_path())
+
+ tb_writer.close()
+
+ except AssertionError as e:
+ logging.warning(e)
+ if wandb_enabled:
+ wandb_finish("Error")
+ return np.nan
+ finally:
+ env.close()
+ eval_env.close()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ d = {}
+ for idx, e in enumerate(evaluations):
+ d[f"{idx}/eval_mean"] = e.eval_stat.score.mean
+ d[f"{idx}/train_mean"] = e.train_stat.score.mean
+ d[f"{idx}/score"] = e.score
+ d["eval"] = np.mean([e.eval_stat.score.mean for e in evaluations]).item()
+ d["train"] = np.mean([e.train_stat.score.mean for e in evaluations]).item()
+ score = np.mean([e.score for e in evaluations]).item()
+ d["score"] = score
+
+ step = i + 1
+ wandb.log(d, step=step)
+
+ print(f"Trial #{trial.number} Step {step} Score: {round(score, 2)}")
+ trial.report(score, step)
+ if trial.should_prune():
+ if wandb_enabled:
+ wandb_finish("Pruned")
+ raise optuna.exceptions.TrialPruned()
+
+ if wandb_enabled:
+ wandb_finish("Complete")
+ return score
+
+
+def wandb_finish(state: str) -> None:
+ wandb.run.summary["state"] = state
+ wandb.finish(quiet=True)
+
+
+def optimize() -> None:
+ from pyvirtualdisplay.display import Display
+
+ train_args, study_args = parse_args()
+ if study_args.virtual_display:
+ virtual_display = Display(visible=False, size=(1400, 900))
+ virtual_display.start()
+
+ sampler = TPESampler(**TPESampler.hyperopt_parameters())
+ pruner = HyperbandPruner()
+ if study_args.load_study:
+ assert study_args.study_name
+ assert study_args.storage_path
+ study = optuna.load_study(
+ study_name=study_args.study_name,
+ storage=study_args.storage_path,
+ sampler=sampler,
+ pruner=pruner,
+ )
+ else:
+ study = optuna.create_study(
+ study_name=study_args.study_name,
+ storage=study_args.storage_path,
+ sampler=sampler,
+ pruner=pruner,
+ direction="maximize",
+ )
+
+ try:
+ study.optimize(
+ objective_fn(train_args, study_args),
+ n_trials=study_args.n_trials,
+ n_jobs=study_args.n_jobs,
+ timeout=study_args.timeout,
+ )
+ except KeyboardInterrupt:
+ pass
+
+ best = study.best_trial
+ print(f"Best Trial Value: {best.value}")
+ print("Attributes:")
+ for key, value in list(best.params.items()) + list(best.user_attrs.items()):
+ print(f" {key}: {value}")
+
+ df = study.trials_dataframe()
+ df = df[df.state == "COMPLETE"].sort_values(by=["value"], ascending=False)
+ print(df.to_markdown(index=False))
+
+ fig1 = plot_optimization_history(study)
+ fig1.write_image("opt_history.png")
+ fig2 = plot_param_importances(study)
+ fig2.write_image("param_importances.png")
+
+
+if __name__ == "__main__":
+ optimize()
diff --git a/rl_algo_impls/ppo/ppo.py b/rl_algo_impls/ppo/ppo.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1a5850ebe02ed1d946cd399b49e64375c0a6089
--- /dev/null
+++ b/rl_algo_impls/ppo/ppo.py
@@ -0,0 +1,353 @@
+import numpy as np
+import torch
+import torch.nn as nn
+
+from dataclasses import asdict, dataclass, field
+from time import perf_counter
+from torch.optim import Adam
+from torch.utils.tensorboard.writer import SummaryWriter
+from typing import List, Optional, NamedTuple, TypeVar
+
+from rl_algo_impls.shared.algorithm import Algorithm
+from rl_algo_impls.shared.callbacks.callback import Callback
+from rl_algo_impls.shared.gae import compute_advantage, compute_rtg_and_advantage
+from rl_algo_impls.shared.policy.on_policy import ActorCritic
+from rl_algo_impls.shared.schedule import (
+ constant_schedule,
+ linear_schedule,
+ update_learning_rate,
+)
+from rl_algo_impls.shared.trajectory import Trajectory, TrajectoryAccumulator
+from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv, VecEnvObs
+
+
+@dataclass
+class PPOTrajectory(Trajectory):
+ logp_a: List[float] = field(default_factory=list)
+
+ def add(
+ self,
+ obs: np.ndarray,
+ act: np.ndarray,
+ next_obs: np.ndarray,
+ rew: float,
+ terminated: bool,
+ v: float,
+ logp_a: float,
+ ):
+ super().add(obs, act, next_obs, rew, terminated, v)
+ self.logp_a.append(logp_a)
+
+
+class PPOTrajectoryAccumulator(TrajectoryAccumulator):
+ def __init__(self, num_envs: int) -> None:
+ super().__init__(num_envs, PPOTrajectory)
+
+ def step(
+ self,
+ obs: VecEnvObs,
+ action: np.ndarray,
+ next_obs: VecEnvObs,
+ reward: np.ndarray,
+ done: np.ndarray,
+ val: np.ndarray,
+ logp_a: np.ndarray,
+ ) -> None:
+ super().step(obs, action, next_obs, reward, done, val, logp_a)
+
+
+class TrainStepStats(NamedTuple):
+ loss: float
+ pi_loss: float
+ v_loss: float
+ entropy_loss: float
+ approx_kl: float
+ clipped_frac: float
+ val_clipped_frac: float
+
+
+@dataclass
+class TrainStats:
+ loss: float
+ pi_loss: float
+ v_loss: float
+ entropy_loss: float
+ approx_kl: float
+ clipped_frac: float
+ val_clipped_frac: float
+ explained_var: float
+
+ def __init__(self, step_stats: List[TrainStepStats], explained_var: float) -> None:
+ self.loss = np.mean([s.loss for s in step_stats]).item()
+ self.pi_loss = np.mean([s.pi_loss for s in step_stats]).item()
+ self.v_loss = np.mean([s.v_loss for s in step_stats]).item()
+ self.entropy_loss = np.mean([s.entropy_loss for s in step_stats]).item()
+ self.approx_kl = np.mean([s.approx_kl for s in step_stats]).item()
+ self.clipped_frac = np.mean([s.clipped_frac for s in step_stats]).item()
+ self.val_clipped_frac = np.mean([s.val_clipped_frac for s in step_stats]).item()
+ self.explained_var = explained_var
+
+ def write_to_tensorboard(self, tb_writer: SummaryWriter, global_step: int) -> None:
+ for name, value in asdict(self).items():
+ tb_writer.add_scalar(f"losses/{name}", value, global_step=global_step)
+
+ def __repr__(self) -> str:
+ return " | ".join(
+ [
+ f"Loss: {round(self.loss, 2)}",
+ f"Pi L: {round(self.pi_loss, 2)}",
+ f"V L: {round(self.v_loss, 2)}",
+ f"E L: {round(self.entropy_loss, 2)}",
+ f"Apx KL Div: {round(self.approx_kl, 2)}",
+ f"Clip Frac: {round(self.clipped_frac, 2)}",
+ f"Val Clip Frac: {round(self.val_clipped_frac, 2)}",
+ ]
+ )
+
+
+PPOSelf = TypeVar("PPOSelf", bound="PPO")
+
+
+class PPO(Algorithm):
+ def __init__(
+ self,
+ policy: ActorCritic,
+ env: VecEnv,
+ device: torch.device,
+ tb_writer: SummaryWriter,
+ learning_rate: float = 3e-4,
+ learning_rate_decay: str = "none",
+ n_steps: int = 2048,
+ batch_size: int = 64,
+ n_epochs: int = 10,
+ gamma: float = 0.99,
+ gae_lambda: float = 0.95,
+ clip_range: float = 0.2,
+ clip_range_decay: str = "none",
+ clip_range_vf: Optional[float] = None,
+ clip_range_vf_decay: str = "none",
+ normalize_advantage: bool = True,
+ ent_coef: float = 0.0,
+ ent_coef_decay: str = "none",
+ vf_coef: float = 0.5,
+ ppo2_vf_coef_halving: bool = False,
+ max_grad_norm: float = 0.5,
+ update_rtg_between_epochs: bool = False,
+ sde_sample_freq: int = -1,
+ ) -> None:
+ super().__init__(policy, env, device, tb_writer)
+ self.policy = policy
+
+ self.gamma = gamma
+ self.gae_lambda = gae_lambda
+ self.optimizer = Adam(self.policy.parameters(), lr=learning_rate, eps=1e-7)
+ self.lr_schedule = (
+ linear_schedule(learning_rate, 0)
+ if learning_rate_decay == "linear"
+ else constant_schedule(learning_rate)
+ )
+ self.max_grad_norm = max_grad_norm
+ self.clip_range_schedule = (
+ linear_schedule(clip_range, 0)
+ if clip_range_decay == "linear"
+ else constant_schedule(clip_range)
+ )
+ self.clip_range_vf_schedule = None
+ if clip_range_vf:
+ self.clip_range_vf_schedule = (
+ linear_schedule(clip_range_vf, 0)
+ if clip_range_vf_decay == "linear"
+ else constant_schedule(clip_range_vf)
+ )
+ self.normalize_advantage = normalize_advantage
+ self.ent_coef_schedule = (
+ linear_schedule(ent_coef, 0)
+ if ent_coef_decay == "linear"
+ else constant_schedule(ent_coef)
+ )
+ self.vf_coef = vf_coef
+ self.ppo2_vf_coef_halving = ppo2_vf_coef_halving
+
+ self.n_steps = n_steps
+ self.batch_size = batch_size
+ self.n_epochs = n_epochs
+ self.sde_sample_freq = sde_sample_freq
+
+ self.update_rtg_between_epochs = update_rtg_between_epochs
+
+ def learn(
+ self: PPOSelf,
+ total_timesteps: int,
+ callback: Optional[Callback] = None,
+ ) -> PPOSelf:
+ obs = self.env.reset()
+ ts_elapsed = 0
+ while ts_elapsed < total_timesteps:
+ start_time = perf_counter()
+ accumulator = self._collect_trajectories(obs)
+ rollout_steps = self.n_steps * self.env.num_envs
+ ts_elapsed += rollout_steps
+ progress = ts_elapsed / total_timesteps
+ train_stats = self.train(accumulator.all_trajectories, progress, ts_elapsed)
+ train_stats.write_to_tensorboard(self.tb_writer, ts_elapsed)
+ end_time = perf_counter()
+ self.tb_writer.add_scalar(
+ "train/steps_per_second",
+ rollout_steps / (end_time - start_time),
+ ts_elapsed,
+ )
+ if callback:
+ callback.on_step(timesteps_elapsed=rollout_steps)
+
+ return self
+
+ def _collect_trajectories(self, obs: VecEnvObs) -> PPOTrajectoryAccumulator:
+ self.policy.eval()
+ accumulator = PPOTrajectoryAccumulator(self.env.num_envs)
+ self.policy.reset_noise()
+ for i in range(self.n_steps):
+ if self.sde_sample_freq > 0 and i > 0 and i % self.sde_sample_freq == 0:
+ self.policy.reset_noise()
+ action, value, logp_a, clamped_action = self.policy.step(obs)
+ next_obs, reward, done, _ = self.env.step(clamped_action)
+ accumulator.step(obs, action, next_obs, reward, done, value, logp_a)
+ obs = next_obs
+ return accumulator
+
+ def train(
+ self, trajectories: List[PPOTrajectory], progress: float, timesteps_elapsed: int
+ ) -> TrainStats:
+ self.policy.train()
+ learning_rate = self.lr_schedule(progress)
+ update_learning_rate(self.optimizer, learning_rate)
+ self.tb_writer.add_scalar(
+ "charts/learning_rate",
+ self.optimizer.param_groups[0]["lr"],
+ timesteps_elapsed,
+ )
+
+ pi_clip = self.clip_range_schedule(progress)
+ self.tb_writer.add_scalar("charts/pi_clip", pi_clip, timesteps_elapsed)
+ if self.clip_range_vf_schedule:
+ v_clip = self.clip_range_vf_schedule(progress)
+ self.tb_writer.add_scalar("charts/v_clip", v_clip, timesteps_elapsed)
+ else:
+ v_clip = None
+ ent_coef = self.ent_coef_schedule(progress)
+ self.tb_writer.add_scalar("charts/ent_coef", ent_coef, timesteps_elapsed)
+
+ obs = torch.as_tensor(
+ np.concatenate([np.array(t.obs) for t in trajectories]), device=self.device
+ )
+ act = torch.as_tensor(
+ np.concatenate([np.array(t.act) for t in trajectories]), device=self.device
+ )
+ rtg, adv = compute_rtg_and_advantage(
+ trajectories, self.policy, self.gamma, self.gae_lambda, self.device
+ )
+ orig_v = torch.as_tensor(
+ np.concatenate([np.array(t.v) for t in trajectories]), device=self.device
+ )
+ orig_logp_a = torch.as_tensor(
+ np.concatenate([np.array(t.logp_a) for t in trajectories]),
+ device=self.device,
+ )
+
+ step_stats = []
+ for _ in range(self.n_epochs):
+ step_stats.clear()
+ if self.update_rtg_between_epochs:
+ rtg, adv = compute_rtg_and_advantage(
+ trajectories, self.policy, self.gamma, self.gae_lambda, self.device
+ )
+ else:
+ adv = compute_advantage(
+ trajectories, self.policy, self.gamma, self.gae_lambda, self.device
+ )
+ idxs = torch.randperm(len(obs))
+ for i in range(0, len(obs), self.batch_size):
+ mb_idxs = idxs[i : i + self.batch_size]
+ mb_adv = adv[mb_idxs]
+ if self.normalize_advantage:
+ mb_adv = (mb_adv - mb_adv.mean(-1)) / (mb_adv.std(-1) + 1e-8)
+ self.policy.reset_noise(self.batch_size)
+ step_stats.append(
+ self._train_step(
+ pi_clip,
+ v_clip,
+ ent_coef,
+ obs[mb_idxs],
+ act[mb_idxs],
+ rtg[mb_idxs],
+ mb_adv,
+ orig_v[mb_idxs],
+ orig_logp_a[mb_idxs],
+ )
+ )
+
+ y_pred, y_true = orig_v.cpu().numpy(), rtg.cpu().numpy()
+ var_y = np.var(y_true).item()
+ explained_var = (
+ np.nan if var_y == 0 else 1 - np.var(y_true - y_pred).item() / var_y
+ )
+
+ return TrainStats(step_stats, explained_var)
+
+ def _train_step(
+ self,
+ pi_clip: float,
+ v_clip: Optional[float],
+ ent_coef: float,
+ obs: torch.Tensor,
+ act: torch.Tensor,
+ rtg: torch.Tensor,
+ adv: torch.Tensor,
+ orig_v: torch.Tensor,
+ orig_logp_a: torch.Tensor,
+ ) -> TrainStepStats:
+ logp_a, entropy, v = self.policy(obs, act)
+ logratio = logp_a - orig_logp_a
+ ratio = torch.exp(logratio)
+ clip_ratio = torch.clamp(ratio, min=1 - pi_clip, max=1 + pi_clip)
+ pi_loss = torch.maximum(-ratio * adv, -clip_ratio * adv).mean()
+
+ v_loss_unclipped = (v - rtg) ** 2
+ if v_clip:
+ v_loss_clipped = (
+ orig_v + torch.clamp(v - orig_v, -v_clip, v_clip) - rtg
+ ) ** 2
+ v_loss = torch.max(v_loss_unclipped, v_loss_clipped).mean()
+ else:
+ v_loss = v_loss_unclipped.mean()
+ if self.ppo2_vf_coef_halving:
+ v_loss *= 0.5
+
+ entropy_loss = -entropy.mean()
+
+ loss = pi_loss + ent_coef * entropy_loss + self.vf_coef * v_loss
+
+ self.optimizer.zero_grad()
+ loss.backward()
+ nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
+ self.optimizer.step()
+
+ with torch.no_grad():
+ approx_kl = ((ratio - 1) - logratio).mean().cpu().numpy().item()
+ clipped_frac = (
+ ((ratio - 1).abs() > pi_clip).float().mean().cpu().numpy().item()
+ )
+ val_clipped_frac = (
+ (((v - orig_v).abs() > v_clip).float().mean().cpu().numpy().item())
+ if v_clip
+ else 0
+ )
+
+ return TrainStepStats(
+ loss.item(),
+ pi_loss.item(),
+ v_loss.item(),
+ entropy_loss.item(),
+ approx_kl,
+ clipped_frac,
+ val_clipped_frac,
+ )
diff --git a/rl_algo_impls/publish/markdown_format.py b/rl_algo_impls/publish/markdown_format.py
new file mode 100644
index 0000000000000000000000000000000000000000..07202dbbc8fa1820bef99a6864e5e90da1545538
--- /dev/null
+++ b/rl_algo_impls/publish/markdown_format.py
@@ -0,0 +1,210 @@
+import os
+import pandas as pd
+import wandb.apis.public
+import yaml
+
+from collections import defaultdict
+from dataclasses import dataclass, asdict
+from typing import Any, Dict, Iterable, List, NamedTuple, Optional, TypeVar
+from urllib.parse import urlparse
+
+from rl_algo_impls.runner.evaluate import Evaluation
+
+EvaluationRowSelf = TypeVar("EvaluationRowSelf", bound="EvaluationRow")
+
+
+@dataclass
+class EvaluationRow:
+ algo: str
+ env: str
+ seed: Optional[int]
+ reward_mean: float
+ reward_std: float
+ eval_episodes: int
+ best: str
+ wandb_url: str
+
+ @staticmethod
+ def data_frame(rows: List[EvaluationRowSelf]) -> pd.DataFrame:
+ results = defaultdict(list)
+ for r in rows:
+ for k, v in asdict(r).items():
+ results[k].append(v)
+ return pd.DataFrame(results)
+
+
+class EvalTableData(NamedTuple):
+ run: wandb.apis.public.Run
+ evaluation: Evaluation
+
+
+def evaluation_table(table_data: Iterable[EvalTableData]) -> str:
+ best_stats = sorted(
+ [d.evaluation.stats for d in table_data], key=lambda r: r.score, reverse=True
+ )[0]
+ table_data = sorted(table_data, key=lambda d: d.evaluation.config.seed() or 0)
+ rows = [
+ EvaluationRow(
+ config.algo,
+ config.env_id,
+ config.seed(),
+ stats.score.mean,
+ stats.score.std,
+ len(stats),
+ "*" if stats == best_stats else "",
+ f"[wandb]({r.url})",
+ )
+ for (r, (_, stats, config)) in table_data
+ ]
+ df = EvaluationRow.data_frame(rows)
+ return df.to_markdown(index=False)
+
+
+def github_project_link(github_url: str) -> str:
+ return f"[{urlparse(github_url).path}]({github_url})"
+
+
+def header_section(algo: str, env: str, github_url: str, wandb_report_url: str) -> str:
+ algo_caps = algo.upper()
+ lines = [
+ f"# **{algo_caps}** Agent playing **{env}**",
+ f"This is a trained model of a **{algo_caps}** agent playing **{env}** using "
+ f"the {github_project_link(github_url)} repo.",
+ f"All models trained at this commit can be found at {wandb_report_url}.",
+ ]
+ return "\n\n".join(lines)
+
+
+def github_tree_link(github_url: str, commit_hash: Optional[str]) -> str:
+ if not commit_hash:
+ return github_project_link(github_url)
+ return f"[{commit_hash[:7]}]({github_url}/tree/{commit_hash})"
+
+
+def results_section(
+ table_data: List[EvalTableData], algo: str, github_url: str, commit_hash: str
+) -> str:
+ # type: ignore
+ lines = [
+ "## Training Results",
+ f"This model was trained from {len(table_data)} trainings of **{algo.upper()}** "
+ + "agents using different initial seeds. "
+ + f"These agents were trained by checking out "
+ + f"{github_tree_link(github_url, commit_hash)}. "
+ + "The best and last models were kept from each training. "
+ + "This submission has loaded the best models from each training, reevaluates "
+ + "them, and selects the best model from these latest evaluations (mean - std).",
+ ]
+ lines.append(evaluation_table(table_data))
+ return "\n\n".join(lines)
+
+
+def prerequisites_section() -> str:
+ return """
+### Prerequisites: Weights & Biases (WandB)
+Training and benchmarking assumes you have a Weights & Biases project to upload runs to.
+By default training goes to a rl-algo-impls project while benchmarks go to
+rl-algo-impls-benchmarks. During training and benchmarking runs, videos of the best
+models and the model weights are uploaded to WandB.
+
+Before doing anything below, you'll need to create a wandb account and run `wandb
+login`.
+"""
+
+
+def usage_section(github_url: str, run_path: str, commit_hash: str) -> str:
+ return f"""
+## Usage
+{urlparse(github_url).path}: {github_url}
+
+Note: While the model state dictionary and hyperaparameters are saved, the latest
+implementation could be sufficiently different to not be able to reproduce similar
+results. You might need to checkout the commit the agent was trained on:
+{github_tree_link(github_url, commit_hash)}.
+```
+# Downloads the model, sets hyperparameters, and runs agent for 3 episodes
+python enjoy.py --wandb-run-path={run_path}
+```
+
+Setup hasn't been completely worked out yet, so you might be best served by using Google
+Colab starting from the
+[colab_enjoy.ipynb](https://github.com/sgoodfriend/rl-algo-impls/blob/main/colab_enjoy.ipynb)
+notebook.
+"""
+
+
+def training_setion(
+ github_url: str, commit_hash: str, algo: str, env: str, seed: Optional[int]
+) -> str:
+ return f"""
+## Training
+If you want the highest chance to reproduce these results, you'll want to checkout the
+commit the agent was trained on: {github_tree_link(github_url, commit_hash)}. While
+training is deterministic, different hardware will give different results.
+
+```
+python train.py --algo {algo} --env {env} {'--seed ' + str(seed) if seed is not None else ''}
+```
+
+Setup hasn't been completely worked out yet, so you might be best served by using Google
+Colab starting from the
+[colab_train.ipynb](https://github.com/sgoodfriend/rl-algo-impls/blob/main/colab_train.ipynb)
+notebook.
+"""
+
+
+def benchmarking_section(report_url: str) -> str:
+ return f"""
+## Benchmarking (with Lambda Labs instance)
+This and other models from {report_url} were generated by running a script on a Lambda
+Labs instance. In a Lambda Labs instance terminal:
+```
+git clone git@github.com:sgoodfriend/rl-algo-impls.git
+cd rl-algo-impls
+bash ./lambda_labs/setup.sh
+wandb login
+bash ./lambda_labs/benchmark.sh [-a {{"ppo a2c dqn vpg"}}] [-e ENVS] [-j {{6}}] [-p {{rl-algo-impls-benchmarks}}] [-s {{"1 2 3"}}]
+```
+
+### Alternative: Google Colab Pro+
+As an alternative,
+[colab_benchmark.ipynb](https://github.com/sgoodfriend/rl-algo-impls/tree/main/benchmarks#:~:text=colab_benchmark.ipynb),
+can be used. However, this requires a Google Colab Pro+ subscription and running across
+4 separate instances because otherwise running all jobs will exceed the 24-hour limit.
+"""
+
+
+def hyperparams_section(run_config: Dict[str, Any]) -> str:
+ return f"""
+## Hyperparameters
+This isn't exactly the format of hyperparams in {os.path.join("hyperparams",
+run_config["algo"] + ".yml")}, but instead the Wandb Run Config. However, it's very
+close and has some additional data:
+```
+{yaml.dump(run_config)}
+```
+"""
+
+
+def model_card_text(
+ algo: str,
+ env: str,
+ github_url: str,
+ commit_hash: str,
+ wandb_report_url: str,
+ table_data: List[EvalTableData],
+ best_eval: EvalTableData,
+) -> str:
+ run, (_, _, config) = best_eval
+ run_path = "/".join(run.path)
+ return "\n\n".join(
+ [
+ header_section(algo, env, github_url, wandb_report_url),
+ results_section(table_data, algo, github_url, commit_hash),
+ prerequisites_section(),
+ usage_section(github_url, run_path, commit_hash),
+ training_setion(github_url, commit_hash, algo, env, config.seed()),
+ benchmarking_section(wandb_report_url),
+ hyperparams_section(run.config),
+ ]
+ )
diff --git a/rl_algo_impls/runner/config.py b/rl_algo_impls/runner/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..565913e716c537208e7353a13b70cff23ed0eb09
--- /dev/null
+++ b/rl_algo_impls/runner/config.py
@@ -0,0 +1,189 @@
+import dataclasses
+import inspect
+import itertools
+import os
+
+from datetime import datetime
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Type, TypeVar, Union
+
+
+RunArgsSelf = TypeVar("RunArgsSelf", bound="RunArgs")
+
+
+@dataclass
+class RunArgs:
+ algo: str
+ env: str
+ seed: Optional[int] = None
+ use_deterministic_algorithms: bool = True
+
+ @classmethod
+ def expand_from_dict(
+ cls: Type[RunArgsSelf], d: Dict[str, Any]
+ ) -> List[RunArgsSelf]:
+ maybe_listify = lambda v: [v] if isinstance(v, str) or isinstance(v, int) else v
+ algos = maybe_listify(d["algo"])
+ envs = maybe_listify(d["env"])
+ seeds = maybe_listify(d["seed"])
+ args = []
+ for algo, env, seed in itertools.product(algos, envs, seeds):
+ _d = d.copy()
+ _d.update({"algo": algo, "env": env, "seed": seed})
+ args.append(cls(**_d))
+ return args
+
+
+@dataclass
+class EnvHyperparams:
+ env_type: str = "gymvec"
+ n_envs: int = 1
+ frame_stack: int = 1
+ make_kwargs: Optional[Dict[str, Any]] = None
+ no_reward_timeout_steps: Optional[int] = None
+ no_reward_fire_steps: Optional[int] = None
+ vec_env_class: str = "sync"
+ normalize: bool = False
+ normalize_kwargs: Optional[Dict[str, Any]] = None
+ rolling_length: int = 100
+ train_record_video: bool = False
+ video_step_interval: Union[int, float] = 1_000_000
+ initial_steps_to_truncate: Optional[int] = None
+ clip_atari_rewards: bool = True
+
+
+HyperparamsSelf = TypeVar("HyperparamsSelf", bound="Hyperparams")
+
+
+@dataclass
+class Hyperparams:
+ device: str = "auto"
+ n_timesteps: Union[int, float] = 100_000
+ env_hyperparams: Dict[str, Any] = dataclasses.field(default_factory=dict)
+ policy_hyperparams: Dict[str, Any] = dataclasses.field(default_factory=dict)
+ algo_hyperparams: Dict[str, Any] = dataclasses.field(default_factory=dict)
+ eval_params: Dict[str, Any] = dataclasses.field(default_factory=dict)
+ env_id: Optional[str] = None
+
+ @classmethod
+ def from_dict_with_extra_fields(
+ cls: Type[HyperparamsSelf], d: Dict[str, Any]
+ ) -> HyperparamsSelf:
+ return cls(
+ **{k: v for k, v in d.items() if k in inspect.signature(cls).parameters}
+ )
+
+
+@dataclass
+class Config:
+ args: RunArgs
+ hyperparams: Hyperparams
+ root_dir: str
+ run_id: str = datetime.now().isoformat()
+
+ def seed(self, training: bool = True) -> Optional[int]:
+ seed = self.args.seed
+ if training or seed is None:
+ return seed
+ return seed + self.env_hyperparams.get("n_envs", 1)
+
+ @property
+ def device(self) -> str:
+ return self.hyperparams.device
+
+ @property
+ def n_timesteps(self) -> int:
+ return int(self.hyperparams.n_timesteps)
+
+ @property
+ def env_hyperparams(self) -> Dict[str, Any]:
+ return self.hyperparams.env_hyperparams
+
+ @property
+ def policy_hyperparams(self) -> Dict[str, Any]:
+ return self.hyperparams.policy_hyperparams
+
+ @property
+ def algo_hyperparams(self) -> Dict[str, Any]:
+ return self.hyperparams.algo_hyperparams
+
+ @property
+ def eval_params(self) -> Dict[str, Any]:
+ return self.hyperparams.eval_params
+
+ @property
+ def algo(self) -> str:
+ return self.args.algo
+
+ @property
+ def env_id(self) -> str:
+ return self.hyperparams.env_id or self.args.env
+
+ def model_name(self, include_seed: bool = True) -> str:
+ # Use arg env name instead of environment name
+ parts = [self.algo, self.args.env]
+ if include_seed and self.args.seed is not None:
+ parts.append(f"S{self.args.seed}")
+
+ # Assume that the custom arg name already has the necessary information
+ if not self.hyperparams.env_id:
+ make_kwargs = self.env_hyperparams.get("make_kwargs", {})
+ if make_kwargs:
+ for k, v in make_kwargs.items():
+ if type(v) == bool and v:
+ parts.append(k)
+ elif type(v) == int and v:
+ parts.append(f"{k}{v}")
+ else:
+ parts.append(str(v))
+
+ return "-".join(parts)
+
+ def run_name(self, include_seed: bool = True) -> str:
+ parts = [self.model_name(include_seed=include_seed), self.run_id]
+ return "-".join(parts)
+
+ @property
+ def saved_models_dir(self) -> str:
+ return os.path.join(self.root_dir, "saved_models")
+
+ @property
+ def downloaded_models_dir(self) -> str:
+ return os.path.join(self.root_dir, "downloaded_models")
+
+ def model_dir_name(
+ self,
+ best: bool = False,
+ extension: str = "",
+ ) -> str:
+ return self.model_name() + ("-best" if best else "") + extension
+
+ def model_dir_path(self, best: bool = False, downloaded: bool = False) -> str:
+ return os.path.join(
+ self.saved_models_dir if not downloaded else self.downloaded_models_dir,
+ self.model_dir_name(best=best),
+ )
+
+ @property
+ def runs_dir(self) -> str:
+ return os.path.join(self.root_dir, "runs")
+
+ @property
+ def tensorboard_summary_path(self) -> str:
+ return os.path.join(self.runs_dir, self.run_name())
+
+ @property
+ def logs_path(self) -> str:
+ return os.path.join(self.runs_dir, f"log.yml")
+
+ @property
+ def videos_dir(self) -> str:
+ return os.path.join(self.root_dir, "videos")
+
+ @property
+ def video_prefix(self) -> str:
+ return os.path.join(self.videos_dir, self.model_name())
+
+ @property
+ def best_videos_dir(self) -> str:
+ return os.path.join(self.videos_dir, f"{self.model_name()}-best")
diff --git a/rl_algo_impls/runner/env.py b/rl_algo_impls/runner/env.py
new file mode 100644
index 0000000000000000000000000000000000000000..ccdd8343c9f8c853e839a030b858fda50e53a7cb
--- /dev/null
+++ b/rl_algo_impls/runner/env.py
@@ -0,0 +1,292 @@
+import gym
+import numpy as np
+import os
+
+from dataclasses import asdict, astuple
+from gym.vector.async_vector_env import AsyncVectorEnv
+from gym.vector.sync_vector_env import SyncVectorEnv
+from gym.wrappers.resize_observation import ResizeObservation
+from gym.wrappers.gray_scale_observation import GrayScaleObservation
+from gym.wrappers.frame_stack import FrameStack
+from stable_baselines3.common.atari_wrappers import (
+ MaxAndSkipEnv,
+ NoopResetEnv,
+)
+from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
+from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
+from stable_baselines3.common.vec_env.vec_normalize import VecNormalize
+from torch.utils.tensorboard.writer import SummaryWriter
+from typing import Callable, Optional
+
+from rl_algo_impls.runner.config import Config, EnvHyperparams
+from rl_algo_impls.shared.policy.policy import VEC_NORMALIZE_FILENAME
+from rl_algo_impls.wrappers.atari_wrappers import (
+ EpisodicLifeEnv,
+ FireOnLifeStarttEnv,
+ ClipRewardEnv,
+)
+from rl_algo_impls.wrappers.episode_record_video import EpisodeRecordVideo
+from rl_algo_impls.wrappers.episode_stats_writer import EpisodeStatsWriter
+from rl_algo_impls.wrappers.initial_step_truncate_wrapper import (
+ InitialStepTruncateWrapper,
+)
+from rl_algo_impls.wrappers.is_vector_env import IsVectorEnv
+from rl_algo_impls.wrappers.no_reward_timeout import NoRewardTimeout
+from rl_algo_impls.wrappers.noop_env_seed import NoopEnvSeed
+from rl_algo_impls.wrappers.normalize import NormalizeObservation, NormalizeReward
+from rl_algo_impls.wrappers.sync_vector_env_render_compat import (
+ SyncVectorEnvRenderCompat,
+)
+from rl_algo_impls.wrappers.transpose_image_observation import TransposeImageObservation
+from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv
+from rl_algo_impls.wrappers.video_compat_wrapper import VideoCompatWrapper
+
+
+def make_env(
+ config: Config,
+ hparams: EnvHyperparams,
+ training: bool = True,
+ render: bool = False,
+ normalize_load_path: Optional[str] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+) -> VecEnv:
+ if hparams.env_type == "procgen":
+ return _make_procgen_env(
+ config,
+ hparams,
+ training=training,
+ render=render,
+ normalize_load_path=normalize_load_path,
+ tb_writer=tb_writer,
+ )
+ elif hparams.env_type in {"sb3vec", "gymvec"}:
+ return _make_vec_env(
+ config,
+ hparams,
+ training=training,
+ render=render,
+ normalize_load_path=normalize_load_path,
+ tb_writer=tb_writer,
+ )
+ else:
+ raise ValueError(f"env_type {hparams.env_type} not supported")
+
+
+def make_eval_env(
+ config: Config,
+ hparams: EnvHyperparams,
+ override_n_envs: Optional[int] = None,
+ **kwargs,
+) -> VecEnv:
+ kwargs = kwargs.copy()
+ kwargs["training"] = False
+ if override_n_envs is not None:
+ hparams_kwargs = asdict(hparams)
+ hparams_kwargs["n_envs"] = override_n_envs
+ if override_n_envs == 1:
+ hparams_kwargs["vec_env_class"] = "sync"
+ hparams = EnvHyperparams(**hparams_kwargs)
+ return make_env(config, hparams, **kwargs)
+
+
+def _make_vec_env(
+ config: Config,
+ hparams: EnvHyperparams,
+ training: bool = True,
+ render: bool = False,
+ normalize_load_path: Optional[str] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+) -> VecEnv:
+ (
+ env_type,
+ n_envs,
+ frame_stack,
+ make_kwargs,
+ no_reward_timeout_steps,
+ no_reward_fire_steps,
+ vec_env_class,
+ normalize,
+ normalize_kwargs,
+ rolling_length,
+ train_record_video,
+ video_step_interval,
+ initial_steps_to_truncate,
+ clip_atari_rewards,
+ ) = astuple(hparams)
+
+ if "BulletEnv" in config.env_id:
+ import pybullet_envs
+
+ spec = gym.spec(config.env_id)
+ seed = config.seed(training=training)
+
+ make_kwargs = make_kwargs.copy() if make_kwargs is not None else {}
+ if "BulletEnv" in config.env_id and render:
+ make_kwargs["render"] = True
+ if "CarRacing" in config.env_id:
+ make_kwargs["verbose"] = 0
+ if "procgen" in config.env_id:
+ if not render:
+ make_kwargs["render_mode"] = "rgb_array"
+
+ def make(idx: int) -> Callable[[], gym.Env]:
+ def _make() -> gym.Env:
+ env = gym.make(config.env_id, **make_kwargs)
+ env = gym.wrappers.RecordEpisodeStatistics(env)
+ env = VideoCompatWrapper(env)
+ if training and train_record_video and idx == 0:
+ env = EpisodeRecordVideo(
+ env,
+ config.video_prefix,
+ step_increment=n_envs,
+ video_step_interval=int(video_step_interval),
+ )
+ if training and initial_steps_to_truncate:
+ env = InitialStepTruncateWrapper(
+ env, idx * initial_steps_to_truncate // n_envs
+ )
+ if "AtariEnv" in spec.entry_point: # type: ignore
+ env = NoopResetEnv(env, noop_max=30)
+ env = MaxAndSkipEnv(env, skip=4)
+ env = EpisodicLifeEnv(env, training=training)
+ action_meanings = env.unwrapped.get_action_meanings()
+ if "FIRE" in action_meanings: # type: ignore
+ env = FireOnLifeStarttEnv(env, action_meanings.index("FIRE"))
+ if clip_atari_rewards:
+ env = ClipRewardEnv(env, training=training)
+ env = ResizeObservation(env, (84, 84))
+ env = GrayScaleObservation(env, keep_dim=False)
+ env = FrameStack(env, frame_stack)
+ elif "CarRacing" in config.env_id:
+ env = ResizeObservation(env, (64, 64))
+ env = GrayScaleObservation(env, keep_dim=False)
+ env = FrameStack(env, frame_stack)
+ elif "procgen" in config.env_id:
+ # env = GrayScaleObservation(env, keep_dim=False)
+ env = NoopEnvSeed(env)
+ env = TransposeImageObservation(env)
+ if frame_stack > 1:
+ env = FrameStack(env, frame_stack)
+
+ if no_reward_timeout_steps:
+ env = NoRewardTimeout(
+ env, no_reward_timeout_steps, n_fire_steps=no_reward_fire_steps
+ )
+
+ if seed is not None:
+ env.seed(seed + idx)
+ env.action_space.seed(seed + idx)
+ env.observation_space.seed(seed + idx)
+
+ return env
+
+ return _make
+
+ if env_type == "sb3vec":
+ VecEnvClass = {"sync": DummyVecEnv, "async": SubprocVecEnv}[vec_env_class]
+ elif env_type == "gymvec":
+ VecEnvClass = {"sync": SyncVectorEnv, "async": AsyncVectorEnv}[vec_env_class]
+ else:
+ raise ValueError(f"env_type {env_type} unsupported")
+ envs = VecEnvClass([make(i) for i in range(n_envs)])
+ if env_type == "gymvec" and vec_env_class == "sync":
+ envs = SyncVectorEnvRenderCompat(envs)
+ if training:
+ assert tb_writer
+ envs = EpisodeStatsWriter(
+ envs, tb_writer, training=training, rolling_length=rolling_length
+ )
+ if normalize:
+ normalize_kwargs = normalize_kwargs or {}
+ if env_type == "sb3vec":
+ if normalize_load_path:
+ envs = VecNormalize.load(
+ os.path.join(normalize_load_path, VEC_NORMALIZE_FILENAME),
+ envs, # type: ignore
+ )
+ else:
+ envs = VecNormalize(
+ envs, # type: ignore
+ training=training,
+ **normalize_kwargs,
+ )
+ if not training:
+ envs.norm_reward = False
+ else:
+ if normalize_kwargs.get("norm_obs", True):
+ envs = NormalizeObservation(
+ envs, training=training, clip=normalize_kwargs.get("clip_obs", 10.0)
+ )
+ if training and normalize_kwargs.get("norm_reward", True):
+ envs = NormalizeReward(
+ envs,
+ training=training,
+ clip=normalize_kwargs.get("clip_reward", 10.0),
+ )
+ return envs
+
+
+def _make_procgen_env(
+ config: Config,
+ hparams: EnvHyperparams,
+ training: bool = True,
+ render: bool = False,
+ normalize_load_path: Optional[str] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+) -> VecEnv:
+ from gym3 import ViewerWrapper, ExtractDictObWrapper
+ from procgen.env import ProcgenGym3Env, ToBaselinesVecEnv
+
+ (
+ _, # env_type
+ n_envs,
+ _, # frame_stack
+ make_kwargs,
+ _, # no_reward_timeout_steps
+ _, # no_reward_fire_steps
+ _, # vec_env_class
+ normalize,
+ normalize_kwargs,
+ rolling_length,
+ _, # train_record_video
+ _, # video_step_interval
+ _, # initial_steps_to_truncate
+ _, # clip_atari_rewards
+ ) = astuple(hparams)
+
+ seed = config.seed(training=training)
+
+ make_kwargs = make_kwargs or {}
+ make_kwargs["render_mode"] = "rgb_array"
+ if seed is not None:
+ make_kwargs["rand_seed"] = seed
+
+ envs = ProcgenGym3Env(n_envs, config.env_id, **make_kwargs)
+ envs = ExtractDictObWrapper(envs, key="rgb")
+ if render:
+ envs = ViewerWrapper(envs, info_key="rgb")
+ envs = ToBaselinesVecEnv(envs)
+ envs = IsVectorEnv(envs)
+ # TODO: Handle Grayscale and/or FrameStack
+ envs = TransposeImageObservation(envs)
+
+ envs = gym.wrappers.RecordEpisodeStatistics(envs)
+
+ if seed is not None:
+ envs.action_space.seed(seed)
+ envs.observation_space.seed(seed)
+
+ if training:
+ assert tb_writer
+ envs = EpisodeStatsWriter(
+ envs, tb_writer, training=training, rolling_length=rolling_length
+ )
+ if normalize and training:
+ normalize_kwargs = normalize_kwargs or {}
+ envs = gym.wrappers.NormalizeReward(envs)
+ clip_obs = normalize_kwargs.get("clip_reward", 10.0)
+ envs = gym.wrappers.TransformReward(
+ envs, lambda r: np.clip(r, -clip_obs, clip_obs)
+ )
+
+ return envs # type: ignore
diff --git a/rl_algo_impls/runner/evaluate.py b/rl_algo_impls/runner/evaluate.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad85bcac91ba9cfef99ff43b67d71b223c4e7051
--- /dev/null
+++ b/rl_algo_impls/runner/evaluate.py
@@ -0,0 +1,103 @@
+import os
+import shutil
+
+from dataclasses import dataclass
+from typing import NamedTuple, Optional
+
+from rl_algo_impls.runner.env import make_eval_env
+from rl_algo_impls.runner.config import Config, EnvHyperparams, Hyperparams, RunArgs
+from rl_algo_impls.runner.running_utils import (
+ load_hyperparams,
+ set_seeds,
+ get_device,
+ make_policy,
+)
+from rl_algo_impls.shared.callbacks.eval_callback import evaluate
+from rl_algo_impls.shared.policy.policy import Policy
+from rl_algo_impls.shared.stats import EpisodesStats
+
+
+@dataclass
+class EvalArgs(RunArgs):
+ render: bool = True
+ best: bool = True
+ n_envs: Optional[int] = 1
+ n_episodes: int = 3
+ deterministic_eval: Optional[bool] = None
+ no_print_returns: bool = False
+ wandb_run_path: Optional[str] = None
+
+
+class Evaluation(NamedTuple):
+ policy: Policy
+ stats: EpisodesStats
+ config: Config
+
+
+def evaluate_model(args: EvalArgs, root_dir: str) -> Evaluation:
+ if args.wandb_run_path:
+ import wandb
+
+ api = wandb.Api()
+ run = api.run(args.wandb_run_path)
+ params = run.config
+
+ args.algo = params["algo"]
+ args.env = params["env"]
+ args.seed = params.get("seed", None)
+ args.use_deterministic_algorithms = params.get(
+ "use_deterministic_algorithms", True
+ )
+
+ config = Config(args, Hyperparams.from_dict_with_extra_fields(params), root_dir)
+ model_path = config.model_dir_path(best=args.best, downloaded=True)
+
+ model_archive_name = config.model_dir_name(best=args.best, extension=".zip")
+ run.file(model_archive_name).download()
+ if os.path.isdir(model_path):
+ shutil.rmtree(model_path)
+ shutil.unpack_archive(model_archive_name, model_path)
+ os.remove(model_archive_name)
+ else:
+ hyperparams = load_hyperparams(args.algo, args.env)
+
+ config = Config(args, hyperparams, root_dir)
+ model_path = config.model_dir_path(best=args.best)
+
+ print(args)
+
+ set_seeds(args.seed, args.use_deterministic_algorithms)
+
+ env = make_eval_env(
+ config,
+ EnvHyperparams(**config.env_hyperparams),
+ override_n_envs=args.n_envs,
+ render=args.render,
+ normalize_load_path=model_path,
+ )
+ device = get_device(config.device, env)
+ policy = make_policy(
+ args.algo,
+ env,
+ device,
+ load_path=model_path,
+ **config.policy_hyperparams,
+ ).eval()
+
+ deterministic = (
+ args.deterministic_eval
+ if args.deterministic_eval is not None
+ else config.eval_params.get("deterministic", True)
+ )
+ return Evaluation(
+ policy,
+ evaluate(
+ env,
+ policy,
+ args.n_episodes,
+ render=args.render,
+ deterministic=deterministic,
+ print_returns=not args.no_print_returns,
+ ),
+ config,
+ )
diff --git a/rl_algo_impls/runner/running_utils.py b/rl_algo_impls/runner/running_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee1ef1304c175e34f69b205f2bf2b139eafc16f4
--- /dev/null
+++ b/rl_algo_impls/runner/running_utils.py
@@ -0,0 +1,192 @@
+import argparse
+import gym
+import json
+import matplotlib.pyplot as plt
+import numpy as np
+import os
+import random
+import torch
+import torch.backends.cudnn
+import yaml
+
+from dataclasses import asdict
+from gym.spaces import Box, Discrete
+from pathlib import Path
+from torch.utils.tensorboard.writer import SummaryWriter
+from typing import Dict, Optional, Type, Union
+
+from rl_algo_impls.runner.config import Hyperparams
+from rl_algo_impls.shared.algorithm import Algorithm
+from rl_algo_impls.shared.callbacks.eval_callback import EvalCallback
+from rl_algo_impls.shared.policy.on_policy import ActorCritic
+from rl_algo_impls.shared.policy.policy import Policy
+
+from rl_algo_impls.a2c.a2c import A2C
+from rl_algo_impls.dqn.dqn import DQN
+from rl_algo_impls.dqn.policy import DQNPolicy
+from rl_algo_impls.ppo.ppo import PPO
+from rl_algo_impls.vpg.vpg import VanillaPolicyGradient
+from rl_algo_impls.vpg.policy import VPGActorCritic
+from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv, single_observation_space
+
+ALGOS: Dict[str, Type[Algorithm]] = {
+ "dqn": DQN,
+ "vpg": VanillaPolicyGradient,
+ "ppo": PPO,
+ "a2c": A2C,
+}
+POLICIES: Dict[str, Type[Policy]] = {
+ "dqn": DQNPolicy,
+ "vpg": VPGActorCritic,
+ "ppo": ActorCritic,
+ "a2c": ActorCritic,
+}
+
+HYPERPARAMS_PATH = "hyperparams"
+
+
+def base_parser(multiple: bool = True) -> argparse.ArgumentParser:
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--algo",
+ default=["dqn"],
+ type=str,
+ choices=list(ALGOS.keys()),
+ nargs="+" if multiple else 1,
+ help="Abbreviation(s) of algorithm(s)",
+ )
+ parser.add_argument(
+ "--env",
+ default=["CartPole-v1"],
+ type=str,
+ nargs="+" if multiple else 1,
+ help="Name of environment(s) in gym",
+ )
+ parser.add_argument(
+ "--seed",
+ default=[1],
+ type=int,
+ nargs="*" if multiple else "?",
+ help="Seeds to run experiment. Unset will do one run with no set seed",
+ )
+ return parser
+
+
+def load_hyperparams(algo: str, env_id: str) -> Hyperparams:
+ root_path = Path(__file__).parent.parent
+ hyperparams_path = os.path.join(root_path, HYPERPARAMS_PATH, f"{algo}.yml")
+ with open(hyperparams_path, "r") as f:
+ hyperparams_dict = yaml.safe_load(f)
+
+ if env_id in hyperparams_dict:
+ return Hyperparams(**hyperparams_dict[env_id])
+
+ if "BulletEnv" in env_id:
+ import pybullet_envs
+ spec = gym.spec(env_id)
+ if "AtariEnv" in str(spec.entry_point) and "_atari" in hyperparams_dict:
+ return Hyperparams(**hyperparams_dict["_atari"])
+ else:
+ raise ValueError(f"{env_id} not specified in {algo} hyperparameters file")
+
+
+def get_device(device: str, env: VecEnv) -> torch.device:
+ # cuda by default
+ if device == "auto":
+ device = "cuda"
+ # Apple MPS is a second choice (sometimes)
+ if device == "cuda" and not torch.cuda.is_available():
+ device = "mps"
+ # If no MPS, fallback to cpu
+ if device == "mps" and not torch.backends.mps.is_available():
+ device = "cpu"
+ # Simple environments like Discreet and 1-D Boxes might also be better
+ # served with the CPU.
+ if device == "mps":
+ obs_space = single_observation_space(env)
+ if isinstance(obs_space, Discrete):
+ device = "cpu"
+ elif isinstance(obs_space, Box) and len(obs_space.shape) == 1:
+ device = "cpu"
+ print(f"Device: {device}")
+ return torch.device(device)
+
+
+def set_seeds(seed: Optional[int], use_deterministic_algorithms: bool) -> None:
+ if seed is None:
+ return
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.backends.cudnn.benchmark = False
+ torch.use_deterministic_algorithms(use_deterministic_algorithms)
+ os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+ # Stop warning and it would introduce stochasticity if I was using TF
+ os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
+
+
+def make_policy(
+ algo: str,
+ env: VecEnv,
+ device: torch.device,
+ load_path: Optional[str] = None,
+ **kwargs,
+) -> Policy:
+ policy = POLICIES[algo](env, **kwargs).to(device)
+ if load_path:
+ policy.load(load_path)
+ return policy
+
+
+def plot_eval_callback(callback: EvalCallback, tb_writer: SummaryWriter, run_name: str):
+ figure = plt.figure()
+ cumulative_steps = [
+ (idx + 1) * callback.step_freq for idx in range(len(callback.stats))
+ ]
+ plt.plot(
+ cumulative_steps,
+ [s.score.mean for s in callback.stats],
+ "b-",
+ label="mean",
+ )
+ plt.plot(
+ cumulative_steps,
+ [s.score.mean - s.score.std for s in callback.stats],
+ "g--",
+ label="mean-std",
+ )
+ plt.fill_between(
+ cumulative_steps,
+ [s.score.min for s in callback.stats], # type: ignore
+ [s.score.max for s in callback.stats], # type: ignore
+ facecolor="cyan",
+ label="range",
+ )
+ plt.xlabel("Steps")
+ plt.ylabel("Score")
+ plt.legend()
+ plt.title(f"Eval {run_name}")
+ tb_writer.add_figure("eval", figure)
+
+
+Scalar = Union[bool, str, float, int, None]
+
+
+def hparam_dict(
+ hyperparams: Hyperparams, args: Dict[str, Union[Scalar, list]]
+) -> Dict[str, Scalar]:
+ flattened = args.copy()
+ for k, v in flattened.items():
+ if isinstance(v, list):
+ flattened[k] = json.dumps(v)
+ for k, v in asdict(hyperparams).items():
+ if isinstance(v, dict):
+ for sk, sv in v.items():
+ key = f"{k}/{sk}"
+ if isinstance(sv, dict) or isinstance(sv, list):
+ flattened[key] = str(sv)
+ else:
+ flattened[key] = sv
+ else:
+ flattened[k] = v # type: ignore
+ return flattened # type: ignore
diff --git a/rl_algo_impls/runner/train.py b/rl_algo_impls/runner/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..359117263e09ba9d4cf4859fe96901eaf85f49a5
--- /dev/null
+++ b/rl_algo_impls/runner/train.py
@@ -0,0 +1,143 @@
+# Support for PyTorch mps mode (https://pytorch.org/docs/stable/notes/mps.html)
+import os
+
+os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+
+import dataclasses
+import shutil
+import wandb
+import yaml
+
+from dataclasses import asdict, dataclass
+from torch.utils.tensorboard.writer import SummaryWriter
+from typing import Any, Dict, Optional, Sequence
+
+from rl_algo_impls.shared.callbacks.eval_callback import EvalCallback
+from rl_algo_impls.runner.config import Config, EnvHyperparams, RunArgs
+from rl_algo_impls.runner.env import make_env, make_eval_env
+from rl_algo_impls.runner.running_utils import (
+ ALGOS,
+ load_hyperparams,
+ set_seeds,
+ get_device,
+ make_policy,
+ plot_eval_callback,
+ hparam_dict,
+)
+from rl_algo_impls.shared.stats import EpisodesStats
+
+
+@dataclass
+class TrainArgs(RunArgs):
+ wandb_project_name: Optional[str] = None
+ wandb_entity: Optional[str] = None
+ wandb_tags: Sequence[str] = dataclasses.field(default_factory=list)
+ wandb_group: Optional[str] = None
+
+
+def train(args: TrainArgs):
+ print(args)
+ hyperparams = load_hyperparams(args.algo, args.env)
+ print(hyperparams)
+ config = Config(args, hyperparams, os.getcwd())
+
+ wandb_enabled = args.wandb_project_name
+ if wandb_enabled:
+ wandb.tensorboard.patch(
+ root_logdir=config.tensorboard_summary_path, pytorch=True
+ )
+ wandb.init(
+ project=args.wandb_project_name,
+ entity=args.wandb_entity,
+ config=asdict(hyperparams),
+ name=config.run_name(),
+ monitor_gym=True,
+ save_code=True,
+ tags=args.wandb_tags,
+ group=args.wandb_group,
+ )
+ wandb.config.update(args)
+
+ tb_writer = SummaryWriter(config.tensorboard_summary_path)
+
+ set_seeds(args.seed, args.use_deterministic_algorithms)
+
+ env = make_env(
+ config, EnvHyperparams(**config.env_hyperparams), tb_writer=tb_writer
+ )
+ device = get_device(config.device, env)
+ policy = make_policy(args.algo, env, device, **config.policy_hyperparams)
+ algo = ALGOS[args.algo](policy, env, device, tb_writer, **config.algo_hyperparams)
+
+ num_parameters = policy.num_parameters()
+ num_trainable_parameters = policy.num_trainable_parameters()
+ if wandb_enabled:
+ wandb.run.summary["num_parameters"] = num_parameters
+ wandb.run.summary["num_trainable_parameters"] = num_trainable_parameters
+ else:
+ print(
+ f"num_parameters = {num_parameters} ; "
+ f"num_trainable_parameters = {num_trainable_parameters}"
+ )
+
+ eval_env = make_eval_env(config, EnvHyperparams(**config.env_hyperparams))
+ record_best_videos = config.eval_params.get("record_best_videos", True)
+ callback = EvalCallback(
+ policy,
+ eval_env,
+ tb_writer,
+ best_model_path=config.model_dir_path(best=True),
+ **config.eval_params,
+ video_env=make_eval_env(
+ config, EnvHyperparams(**config.env_hyperparams), override_n_envs=1
+ )
+ if record_best_videos
+ else None,
+ best_video_dir=config.best_videos_dir,
+ )
+ algo.learn(config.n_timesteps, callback=callback)
+
+ policy.save(config.model_dir_path(best=False))
+
+ eval_stats = callback.evaluate(n_episodes=10, print_returns=True)
+
+ plot_eval_callback(callback, tb_writer, config.run_name())
+
+ log_dict: Dict[str, Any] = {
+ "eval": eval_stats._asdict(),
+ }
+ if callback.best:
+ log_dict["best_eval"] = callback.best._asdict()
+ log_dict.update(asdict(hyperparams))
+ log_dict.update(vars(args))
+ with open(config.logs_path, "a") as f:
+ yaml.dump({config.run_name(): log_dict}, f)
+
+ best_eval_stats: EpisodesStats = callback.best # type: ignore
+ tb_writer.add_hparams(
+ hparam_dict(hyperparams, vars(args)),
+ {
+ "hparam/best_mean": best_eval_stats.score.mean,
+ "hparam/best_result": best_eval_stats.score.mean
+ - best_eval_stats.score.std,
+ "hparam/last_mean": eval_stats.score.mean,
+ "hparam/last_result": eval_stats.score.mean - eval_stats.score.std,
+ },
+ None,
+ config.run_name(),
+ )
+
+ tb_writer.close()
+
+ if wandb_enabled:
+ shutil.make_archive(
+ os.path.join(wandb.run.dir, config.model_dir_name()),
+ "zip",
+ config.model_dir_path(),
+ )
+ shutil.make_archive(
+ os.path.join(wandb.run.dir, config.model_dir_name(best=True)),
+ "zip",
+ config.model_dir_path(best=True),
+ )
+ wandb.finish()
diff --git a/rl_algo_impls/shared/algorithm.py b/rl_algo_impls/shared/algorithm.py
new file mode 100644
index 0000000000000000000000000000000000000000..f70160aaaeb6a0fb92aaaef473fd0b665999d2f9
--- /dev/null
+++ b/rl_algo_impls/shared/algorithm.py
@@ -0,0 +1,39 @@
+import gym
+import torch
+
+from abc import ABC, abstractmethod
+from torch.utils.tensorboard.writer import SummaryWriter
+from typing import Optional, TypeVar
+
+from rl_algo_impls.shared.callbacks.callback import Callback
+from rl_algo_impls.shared.policy.policy import Policy
+from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv
+
+AlgorithmSelf = TypeVar("AlgorithmSelf", bound="Algorithm")
+
+
+class Algorithm(ABC):
+ @abstractmethod
+ def __init__(
+ self,
+ policy: Policy,
+ env: VecEnv,
+ device: torch.device,
+ tb_writer: SummaryWriter,
+ **kwargs,
+ ) -> None:
+ super().__init__()
+ self.policy = policy
+ self.env = env
+ self.device = device
+ self.tb_writer = tb_writer
+
+ @abstractmethod
+ def learn(
+ self: AlgorithmSelf,
+ train_timesteps: int,
+ callback: Optional[Callback] = None,
+ total_timesteps: Optional[int] = None,
+ start_timesteps: int = 0,
+ ) -> AlgorithmSelf:
+ ...
diff --git a/rl_algo_impls/shared/callbacks/callback.py b/rl_algo_impls/shared/callbacks/callback.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a8fbf77f5527d8099aa46ff3892efa1f1241063
--- /dev/null
+++ b/rl_algo_impls/shared/callbacks/callback.py
@@ -0,0 +1,11 @@
+from abc import ABC
+
+
+class Callback(ABC):
+ def __init__(self) -> None:
+ super().__init__()
+ self.timesteps_elapsed = 0
+
+ def on_step(self, timesteps_elapsed: int = 1) -> bool:
+ self.timesteps_elapsed += timesteps_elapsed
+ return True
diff --git a/rl_algo_impls/shared/callbacks/eval_callback.py b/rl_algo_impls/shared/callbacks/eval_callback.py
new file mode 100644
index 0000000000000000000000000000000000000000..f32b2d1c6dede8bf3c35e6086b26ca636d84958f
--- /dev/null
+++ b/rl_algo_impls/shared/callbacks/eval_callback.py
@@ -0,0 +1,199 @@
+import itertools
+import numpy as np
+import os
+
+from time import perf_counter
+from torch.utils.tensorboard.writer import SummaryWriter
+from typing import List, Optional, Union
+
+from rl_algo_impls.shared.callbacks.callback import Callback
+from rl_algo_impls.shared.policy.policy import Policy
+from rl_algo_impls.shared.stats import Episode, EpisodeAccumulator, EpisodesStats
+from rl_algo_impls.wrappers.vec_episode_recorder import VecEpisodeRecorder
+from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv
+
+
+class EvaluateAccumulator(EpisodeAccumulator):
+ def __init__(
+ self,
+ num_envs: int,
+ goal_episodes: int,
+ print_returns: bool = True,
+ ignore_first_episode: bool = False,
+ ):
+ super().__init__(num_envs)
+ self.completed_episodes_by_env_idx = [[] for _ in range(num_envs)]
+ self.goal_episodes_per_env = int(np.ceil(goal_episodes / num_envs))
+ self.print_returns = print_returns
+ if ignore_first_episode:
+ first_done = set()
+
+ def should_record_done(idx: int) -> bool:
+ has_done_first_episode = idx in first_done
+ first_done.add(idx)
+ return has_done_first_episode
+
+ self.should_record_done = should_record_done
+ else:
+ self.should_record_done = lambda idx: True
+
+ def on_done(self, ep_idx: int, episode: Episode) -> None:
+ if (
+ self.should_record_done(ep_idx)
+ and len(self.completed_episodes_by_env_idx[ep_idx])
+ >= self.goal_episodes_per_env
+ ):
+ return
+ self.completed_episodes_by_env_idx[ep_idx].append(episode)
+ if self.print_returns:
+ print(
+ f"Episode {len(self)} | "
+ f"Score {episode.score} | "
+ f"Length {episode.length}"
+ )
+
+ def __len__(self) -> int:
+ return sum(len(ce) for ce in self.completed_episodes_by_env_idx)
+
+ @property
+ def episodes(self) -> List[Episode]:
+ return list(itertools.chain(*self.completed_episodes_by_env_idx))
+
+ def is_done(self) -> bool:
+ return all(
+ len(ce) == self.goal_episodes_per_env
+ for ce in self.completed_episodes_by_env_idx
+ )
+
+
+def evaluate(
+ env: VecEnv,
+ policy: Policy,
+ n_episodes: int,
+ render: bool = False,
+ deterministic: bool = True,
+ print_returns: bool = True,
+ ignore_first_episode: bool = False,
+) -> EpisodesStats:
+ policy.sync_normalization(env)
+ policy.eval()
+
+ episodes = EvaluateAccumulator(
+ env.num_envs, n_episodes, print_returns, ignore_first_episode
+ )
+
+ obs = env.reset()
+ while not episodes.is_done():
+ act = policy.act(obs, deterministic=deterministic)
+ obs, rew, done, _ = env.step(act)
+ episodes.step(rew, done)
+ if render:
+ env.render()
+ stats = EpisodesStats(episodes.episodes)
+ if print_returns:
+ print(stats)
+ return stats
+
+
+class EvalCallback(Callback):
+ def __init__(
+ self,
+ policy: Policy,
+ env: VecEnv,
+ tb_writer: SummaryWriter,
+ best_model_path: Optional[str] = None,
+ step_freq: Union[int, float] = 50_000,
+ n_episodes: int = 10,
+ save_best: bool = True,
+ deterministic: bool = True,
+ record_best_videos: bool = True,
+ video_env: Optional[VecEnv] = None,
+ best_video_dir: Optional[str] = None,
+ max_video_length: int = 3600,
+ ignore_first_episode: bool = False,
+ ) -> None:
+ super().__init__()
+ self.policy = policy
+ self.env = env
+ self.tb_writer = tb_writer
+ self.best_model_path = best_model_path
+ self.step_freq = int(step_freq)
+ self.n_episodes = n_episodes
+ self.save_best = save_best
+ self.deterministic = deterministic
+ self.stats: List[EpisodesStats] = []
+ self.best = None
+
+ self.record_best_videos = record_best_videos
+ assert video_env or not record_best_videos
+ self.video_env = video_env
+ assert best_video_dir or not record_best_videos
+ self.best_video_dir = best_video_dir
+ if best_video_dir:
+ os.makedirs(best_video_dir, exist_ok=True)
+ self.max_video_length = max_video_length
+ self.best_video_base_path = None
+
+ self.ignore_first_episode = ignore_first_episode
+
+ def on_step(self, timesteps_elapsed: int = 1) -> bool:
+ super().on_step(timesteps_elapsed)
+ if self.timesteps_elapsed // self.step_freq >= len(self.stats):
+ self.evaluate()
+ return True
+
+ def evaluate(
+ self, n_episodes: Optional[int] = None, print_returns: Optional[bool] = None
+ ) -> EpisodesStats:
+ start_time = perf_counter()
+ eval_stat = evaluate(
+ self.env,
+ self.policy,
+ n_episodes or self.n_episodes,
+ deterministic=self.deterministic,
+ print_returns=print_returns or False,
+ ignore_first_episode=self.ignore_first_episode,
+ )
+ end_time = perf_counter()
+ self.tb_writer.add_scalar(
+ "eval/steps_per_second",
+ eval_stat.length.sum() / (end_time - start_time),
+ self.timesteps_elapsed,
+ )
+ self.policy.train(True)
+ print(f"Eval Timesteps: {self.timesteps_elapsed} | {eval_stat}")
+
+ self.stats.append(eval_stat)
+
+ if not self.best or eval_stat >= self.best:
+ strictly_better = not self.best or eval_stat > self.best
+ self.best = eval_stat
+ if self.save_best:
+ assert self.best_model_path
+ self.policy.save(self.best_model_path)
+ print("Saved best model")
+ self.best.write_to_tensorboard(
+ self.tb_writer, "best_eval", self.timesteps_elapsed
+ )
+ if strictly_better and self.record_best_videos:
+ assert self.video_env and self.best_video_dir
+ self.best_video_base_path = os.path.join(
+ self.best_video_dir, str(self.timesteps_elapsed)
+ )
+ video_wrapped = VecEpisodeRecorder(
+ self.video_env,
+ self.best_video_base_path,
+ max_video_length=self.max_video_length,
+ )
+ video_stats = evaluate(
+ video_wrapped,
+ self.policy,
+ 1,
+ deterministic=self.deterministic,
+ print_returns=False,
+ )
+ print(f"Saved best video: {video_stats}")
+
+ eval_stat.write_to_tensorboard(self.tb_writer, "eval", self.timesteps_elapsed)
+
+ return eval_stat
diff --git a/rl_algo_impls/shared/callbacks/optimize_callback.py b/rl_algo_impls/shared/callbacks/optimize_callback.py
new file mode 100644
index 0000000000000000000000000000000000000000..75f1cbcb79b30e905f78ecf9fc38c00f79bf9207
--- /dev/null
+++ b/rl_algo_impls/shared/callbacks/optimize_callback.py
@@ -0,0 +1,117 @@
+import numpy as np
+import optuna
+
+from time import perf_counter
+from torch.utils.tensorboard.writer import SummaryWriter
+from typing import NamedTuple, Union
+
+from rl_algo_impls.shared.callbacks.callback import Callback
+from rl_algo_impls.shared.callbacks.eval_callback import evaluate
+from rl_algo_impls.shared.policy.policy import Policy
+from rl_algo_impls.shared.stats import EpisodesStats
+from rl_algo_impls.wrappers.episode_stats_writer import EpisodeStatsWriter
+from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv, find_wrapper
+
+
+class Evaluation(NamedTuple):
+ eval_stat: EpisodesStats
+ train_stat: EpisodesStats
+ score: float
+
+
+class OptimizeCallback(Callback):
+ def __init__(
+ self,
+ policy: Policy,
+ env: VecEnv,
+ trial: optuna.Trial,
+ tb_writer: SummaryWriter,
+ step_freq: Union[int, float] = 50_000,
+ n_episodes: int = 10,
+ deterministic: bool = True,
+ ) -> None:
+ super().__init__()
+ self.policy = policy
+ self.env = env
+ self.trial = trial
+ self.tb_writer = tb_writer
+ self.step_freq = step_freq
+ self.n_episodes = n_episodes
+ self.deterministic = deterministic
+
+ stats_writer = find_wrapper(policy.env, EpisodeStatsWriter)
+ assert stats_writer
+ self.stats_writer = stats_writer
+
+ self.eval_step = 1
+ self.is_pruned = False
+ self.last_eval_stat = None
+ self.last_train_stat = None
+ self.last_score = -np.inf
+
+ def on_step(self, timesteps_elapsed: int = 1) -> bool:
+ super().on_step(timesteps_elapsed)
+ if self.timesteps_elapsed >= self.eval_step * self.step_freq:
+ self.evaluate()
+ return not self.is_pruned
+ return True
+
+ def evaluate(self) -> None:
+ self.last_eval_stat, self.last_train_stat, score = evaluation(
+ self.policy,
+ self.env,
+ self.tb_writer,
+ self.n_episodes,
+ self.deterministic,
+ self.timesteps_elapsed,
+ )
+ self.last_score = score
+
+ self.trial.report(score, self.eval_step)
+ if self.trial.should_prune():
+ self.is_pruned = True
+
+ self.eval_step += 1
+
+
+def evaluation(
+ policy: Policy,
+ env: VecEnv,
+ tb_writer: SummaryWriter,
+ n_episodes: int,
+ deterministic: bool,
+ timesteps_elapsed: int,
+) -> Evaluation:
+ start_time = perf_counter()
+ eval_stat = evaluate(
+ env,
+ policy,
+ n_episodes,
+ deterministic=deterministic,
+ print_returns=False,
+ )
+ end_time = perf_counter()
+ tb_writer.add_scalar(
+ "eval/steps_per_second",
+ eval_stat.length.sum() / (end_time - start_time),
+ timesteps_elapsed,
+ )
+ policy.train()
+ print(f"Eval Timesteps: {timesteps_elapsed} | {eval_stat}")
+ eval_stat.write_to_tensorboard(tb_writer, "eval", timesteps_elapsed)
+
+ stats_writer = find_wrapper(policy.env, EpisodeStatsWriter)
+ assert stats_writer
+
+ train_stat = EpisodesStats(stats_writer.episodes)
+ print(f" Train Stat: {train_stat}")
+
+ score = (eval_stat.score.mean + train_stat.score.mean) / 2
+ print(f" Score: {round(score, 2)}")
+ tb_writer.add_scalar(
+ "eval/score",
+ score,
+ timesteps_elapsed,
+ )
+
+ return Evaluation(eval_stat, train_stat, score)
diff --git a/rl_algo_impls/shared/gae.py b/rl_algo_impls/shared/gae.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4d02ded6524d99be6eacc79984e5349fa97af55
--- /dev/null
+++ b/rl_algo_impls/shared/gae.py
@@ -0,0 +1,67 @@
+import numpy as np
+import torch
+
+from typing import NamedTuple, Sequence
+
+from rl_algo_impls.shared.policy.on_policy import OnPolicy
+from rl_algo_impls.shared.trajectory import Trajectory
+
+
+class RtgAdvantage(NamedTuple):
+ rewards_to_go: torch.Tensor
+ advantage: torch.Tensor
+
+
+def discounted_cumsum(x: np.ndarray, gamma: float) -> np.ndarray:
+ dc = x.copy()
+ for i in reversed(range(len(x) - 1)):
+ dc[i] += gamma * dc[i + 1]
+ return dc
+
+
+def compute_advantage(
+ trajectories: Sequence[Trajectory],
+ policy: OnPolicy,
+ gamma: float,
+ gae_lambda: float,
+ device: torch.device,
+) -> torch.Tensor:
+ advantage = []
+ for traj in trajectories:
+ last_val = 0
+ if not traj.terminated and traj.next_obs is not None:
+ last_val = policy.value(traj.next_obs)
+ rew = np.append(np.array(traj.rew), last_val)
+ v = np.append(np.array(traj.v), last_val)
+ deltas = rew[:-1] + gamma * v[1:] - v[:-1]
+ advantage.append(discounted_cumsum(deltas, gamma * gae_lambda))
+ return torch.as_tensor(
+ np.concatenate(advantage), dtype=torch.float32, device=device
+ )
+
+
+def compute_rtg_and_advantage(
+ trajectories: Sequence[Trajectory],
+ policy: OnPolicy,
+ gamma: float,
+ gae_lambda: float,
+ device: torch.device,
+) -> RtgAdvantage:
+ rewards_to_go = []
+ advantages = []
+ for traj in trajectories:
+ last_val = 0
+ if not traj.terminated and traj.next_obs is not None:
+ last_val = policy.value(traj.next_obs)
+ rew = np.append(np.array(traj.rew), last_val)
+ v = np.append(np.array(traj.v), last_val)
+ deltas = rew[:-1] + gamma * v[1:] - v[:-1]
+ adv = discounted_cumsum(deltas, gamma * gae_lambda)
+ advantages.append(adv)
+ rewards_to_go.append(v[:-1] + adv)
+ return RtgAdvantage(
+ torch.as_tensor(
+ np.concatenate(rewards_to_go), dtype=torch.float32, device=device
+ ),
+ torch.as_tensor(np.concatenate(advantages), dtype=torch.float32, device=device),
+ )
diff --git a/rl_algo_impls/shared/module/feature_extractor.py b/rl_algo_impls/shared/module/feature_extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..16ccaeefae90377ca93aa67285e28d5bb136977e
--- /dev/null
+++ b/rl_algo_impls/shared/module/feature_extractor.py
@@ -0,0 +1,215 @@
+import gym
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from abc import ABC, abstractmethod
+from gym.spaces import Box, Discrete
+from stable_baselines3.common.preprocessing import get_flattened_obs_dim
+from typing import Dict, Optional, Sequence, Type
+
+from rl_algo_impls.shared.module.module import layer_init
+
+
+class CnnFeatureExtractor(nn.Module, ABC):
+ @abstractmethod
+ def __init__(
+ self,
+ in_channels: int,
+ activation: Type[nn.Module] = nn.ReLU,
+ init_layers_orthogonal: Optional[bool] = None,
+ **kwargs,
+ ) -> None:
+ super().__init__()
+
+
+class NatureCnn(CnnFeatureExtractor):
+ """
+ CNN from DQN Nature paper: Mnih, Volodymyr, et al.
+ "Human-level control through deep reinforcement learning."
+ Nature 518.7540 (2015): 529-533.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ activation: Type[nn.Module] = nn.ReLU,
+ init_layers_orthogonal: Optional[bool] = None,
+ **kwargs,
+ ) -> None:
+ if init_layers_orthogonal is None:
+ init_layers_orthogonal = True
+ super().__init__(in_channels, activation, init_layers_orthogonal)
+ self.cnn = nn.Sequential(
+ layer_init(
+ nn.Conv2d(in_channels, 32, kernel_size=8, stride=4),
+ init_layers_orthogonal,
+ ),
+ activation(),
+ layer_init(
+ nn.Conv2d(32, 64, kernel_size=4, stride=2),
+ init_layers_orthogonal,
+ ),
+ activation(),
+ layer_init(
+ nn.Conv2d(64, 64, kernel_size=3, stride=1),
+ init_layers_orthogonal,
+ ),
+ activation(),
+ nn.Flatten(),
+ )
+
+ def forward(self, obs: torch.Tensor) -> torch.Tensor:
+ return self.cnn(obs)
+
+
+class ResidualBlock(nn.Module):
+ def __init__(
+ self,
+ channels: int,
+ activation: Type[nn.Module] = nn.ReLU,
+ init_layers_orthogonal: bool = False,
+ ) -> None:
+ super().__init__()
+ self.residual = nn.Sequential(
+ activation(),
+ layer_init(
+ nn.Conv2d(channels, channels, 3, padding=1), init_layers_orthogonal
+ ),
+ activation(),
+ layer_init(
+ nn.Conv2d(channels, channels, 3, padding=1), init_layers_orthogonal
+ ),
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return x + self.residual(x)
+
+
+class ConvSequence(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ activation: Type[nn.Module] = nn.ReLU,
+ init_layers_orthogonal: bool = False,
+ ) -> None:
+ super().__init__()
+ self.seq = nn.Sequential(
+ layer_init(
+ nn.Conv2d(in_channels, out_channels, 3, padding=1),
+ init_layers_orthogonal,
+ ),
+ nn.MaxPool2d(3, stride=2, padding=1),
+ ResidualBlock(out_channels, activation, init_layers_orthogonal),
+ ResidualBlock(out_channels, activation, init_layers_orthogonal),
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.seq(x)
+
+
+class ImpalaCnn(CnnFeatureExtractor):
+ """
+ IMPALA-style CNN architecture
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ activation: Type[nn.Module] = nn.ReLU,
+ init_layers_orthogonal: Optional[bool] = None,
+ impala_channels: Sequence[int] = (16, 32, 32),
+ **kwargs,
+ ) -> None:
+ if init_layers_orthogonal is None:
+ init_layers_orthogonal = False
+ super().__init__(in_channels, activation, init_layers_orthogonal)
+ sequences = []
+ for out_channels in impala_channels:
+ sequences.append(
+ ConvSequence(
+ in_channels, out_channels, activation, init_layers_orthogonal
+ )
+ )
+ in_channels = out_channels
+ sequences.extend(
+ [
+ activation(),
+ nn.Flatten(),
+ ]
+ )
+ self.seq = nn.Sequential(*sequences)
+
+ def forward(self, obs: torch.Tensor) -> torch.Tensor:
+ return self.seq(obs)
+
+
+CNN_EXTRACTORS_BY_STYLE: Dict[str, Type[CnnFeatureExtractor]] = {
+ "nature": NatureCnn,
+ "impala": ImpalaCnn,
+}
+
+
+class FeatureExtractor(nn.Module):
+ def __init__(
+ self,
+ obs_space: gym.Space,
+ activation: Type[nn.Module],
+ init_layers_orthogonal: bool = False,
+ cnn_feature_dim: int = 512,
+ cnn_style: str = "nature",
+ cnn_layers_init_orthogonal: Optional[bool] = None,
+ impala_channels: Sequence[int] = (16, 32, 32),
+ ) -> None:
+ super().__init__()
+ if isinstance(obs_space, Box):
+ # Conv2D: (channels, height, width)
+ if len(obs_space.shape) == 3:
+ cnn = CNN_EXTRACTORS_BY_STYLE[cnn_style](
+ obs_space.shape[0],
+ activation,
+ init_layers_orthogonal=cnn_layers_init_orthogonal,
+ impala_channels=impala_channels,
+ )
+
+ def preprocess(obs: torch.Tensor) -> torch.Tensor:
+ if len(obs.shape) == 3:
+ obs = obs.unsqueeze(0)
+ return obs.float() / 255.0
+
+ with torch.no_grad():
+ cnn_out = cnn(preprocess(torch.as_tensor(obs_space.sample())))
+ self.preprocess = preprocess
+ self.feature_extractor = nn.Sequential(
+ cnn,
+ layer_init(
+ nn.Linear(cnn_out.shape[1], cnn_feature_dim),
+ init_layers_orthogonal,
+ ),
+ activation(),
+ )
+ self.out_dim = cnn_feature_dim
+ elif len(obs_space.shape) == 1:
+
+ def preprocess(obs: torch.Tensor) -> torch.Tensor:
+ if len(obs.shape) == 1:
+ obs = obs.unsqueeze(0)
+ return obs.float()
+
+ self.preprocess = preprocess
+ self.feature_extractor = nn.Flatten()
+ self.out_dim = get_flattened_obs_dim(obs_space)
+ else:
+ raise ValueError(f"Unsupported observation space: {obs_space}")
+ elif isinstance(obs_space, Discrete):
+ self.preprocess = lambda x: F.one_hot(x, obs_space.n).float()
+ self.feature_extractor = nn.Flatten()
+ self.out_dim = obs_space.n
+ else:
+ raise NotImplementedError
+
+ def forward(self, obs: torch.Tensor) -> torch.Tensor:
+ if self.preprocess:
+ obs = self.preprocess(obs)
+ return self.feature_extractor(obs)
diff --git a/rl_algo_impls/shared/module/module.py b/rl_algo_impls/shared/module/module.py
new file mode 100644
index 0000000000000000000000000000000000000000..c579fb2a3808de47ec8d4c5233eea947b5cf0d28
--- /dev/null
+++ b/rl_algo_impls/shared/module/module.py
@@ -0,0 +1,40 @@
+import numpy as np
+import torch.nn as nn
+
+from typing import Sequence, Type
+
+
+def mlp(
+ layer_sizes: Sequence[int],
+ activation: Type[nn.Module],
+ output_activation: Type[nn.Module] = nn.Identity,
+ init_layers_orthogonal: bool = False,
+ final_layer_gain: float = np.sqrt(2),
+) -> nn.Module:
+ layers = []
+ for i in range(len(layer_sizes) - 2):
+ layers.append(
+ layer_init(
+ nn.Linear(layer_sizes[i], layer_sizes[i + 1]), init_layers_orthogonal
+ )
+ )
+ layers.append(activation())
+ layers.append(
+ layer_init(
+ nn.Linear(layer_sizes[-2], layer_sizes[-1]),
+ init_layers_orthogonal,
+ std=final_layer_gain,
+ )
+ )
+ layers.append(output_activation())
+ return nn.Sequential(*layers)
+
+
+def layer_init(
+ layer: nn.Module, init_layers_orthogonal: bool, std: float = np.sqrt(2)
+) -> nn.Module:
+ if not init_layers_orthogonal:
+ return layer
+ nn.init.orthogonal_(layer.weight, std) # type: ignore
+ nn.init.constant_(layer.bias, 0.0) # type: ignore
+ return layer
diff --git a/rl_algo_impls/shared/policy/actor.py b/rl_algo_impls/shared/policy/actor.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6ec5c8615bdb2bb24e5a6d37efaec204977cd2a
--- /dev/null
+++ b/rl_algo_impls/shared/policy/actor.py
@@ -0,0 +1,310 @@
+import gym
+import torch
+import torch.nn as nn
+
+from abc import ABC, abstractmethod
+from gym.spaces import Box, Discrete
+from torch.distributions import Categorical, Distribution, Normal
+from typing import NamedTuple, Optional, Sequence, Type, TypeVar, Union
+
+from rl_algo_impls.shared.module.module import mlp
+
+
+class PiForward(NamedTuple):
+ pi: Distribution
+ logp_a: Optional[torch.Tensor]
+ entropy: Optional[torch.Tensor]
+
+
+class Actor(nn.Module, ABC):
+ @abstractmethod
+ def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward:
+ ...
+
+
+class CategoricalActorHead(Actor):
+ def __init__(
+ self,
+ act_dim: int,
+ hidden_sizes: Sequence[int] = (32,),
+ activation: Type[nn.Module] = nn.Tanh,
+ init_layers_orthogonal: bool = True,
+ ) -> None:
+ super().__init__()
+ layer_sizes = tuple(hidden_sizes) + (act_dim,)
+ self._fc = mlp(
+ layer_sizes,
+ activation,
+ init_layers_orthogonal=init_layers_orthogonal,
+ final_layer_gain=0.01,
+ )
+
+ def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward:
+ logits = self._fc(obs)
+ pi = Categorical(logits=logits)
+ logp_a = None
+ entropy = None
+ if a is not None:
+ logp_a = pi.log_prob(a)
+ entropy = pi.entropy()
+ return PiForward(pi, logp_a, entropy)
+
+
+class GaussianDistribution(Normal):
+ def log_prob(self, a: torch.Tensor) -> torch.Tensor:
+ return super().log_prob(a).sum(axis=-1)
+
+ def sample(self) -> torch.Tensor:
+ return self.rsample()
+
+
+class GaussianActorHead(Actor):
+ def __init__(
+ self,
+ act_dim: int,
+ hidden_sizes: Sequence[int] = (32,),
+ activation: Type[nn.Module] = nn.Tanh,
+ init_layers_orthogonal: bool = True,
+ log_std_init: float = -0.5,
+ ) -> None:
+ super().__init__()
+ layer_sizes = tuple(hidden_sizes) + (act_dim,)
+ self.mu_net = mlp(
+ layer_sizes,
+ activation,
+ init_layers_orthogonal=init_layers_orthogonal,
+ final_layer_gain=0.01,
+ )
+ self.log_std = nn.Parameter(
+ torch.ones(act_dim, dtype=torch.float32) * log_std_init
+ )
+
+ def _distribution(self, obs: torch.Tensor) -> Distribution:
+ mu = self.mu_net(obs)
+ std = torch.exp(self.log_std)
+ return GaussianDistribution(mu, std)
+
+ def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward:
+ pi = self._distribution(obs)
+ logp_a = None
+ entropy = None
+ if a is not None:
+ logp_a = pi.log_prob(a)
+ entropy = pi.entropy()
+ return PiForward(pi, logp_a, entropy)
+
+
+class TanhBijector:
+ def __init__(self, epsilon: float = 1e-6) -> None:
+ self.epsilon = epsilon
+
+ @staticmethod
+ def forward(x: torch.Tensor) -> torch.Tensor:
+ return torch.tanh(x)
+
+ @staticmethod
+ def inverse(y: torch.Tensor) -> torch.Tensor:
+ eps = torch.finfo(y.dtype).eps
+ clamped_y = y.clamp(min=-1.0 + eps, max=1.0 - eps)
+ return torch.atanh(clamped_y)
+
+ def log_prob_correction(self, x: torch.Tensor) -> torch.Tensor:
+ return torch.log(1.0 - torch.tanh(x) ** 2 + self.epsilon)
+
+
+def sum_independent_dims(tensor: torch.Tensor) -> torch.Tensor:
+ if len(tensor.shape) > 1:
+ return tensor.sum(dim=1)
+ return tensor.sum()
+
+
+class StateDependentNoiseDistribution(Normal):
+ def __init__(
+ self,
+ loc,
+ scale,
+ latent_sde: torch.Tensor,
+ exploration_mat: torch.Tensor,
+ exploration_matrices: torch.Tensor,
+ bijector: Optional[TanhBijector] = None,
+ validate_args=None,
+ ):
+ super().__init__(loc, scale, validate_args)
+ self.latent_sde = latent_sde
+ self.exploration_mat = exploration_mat
+ self.exploration_matrices = exploration_matrices
+ self.bijector = bijector
+
+ def log_prob(self, a: torch.Tensor) -> torch.Tensor:
+ gaussian_a = self.bijector.inverse(a) if self.bijector else a
+ log_prob = sum_independent_dims(super().log_prob(gaussian_a))
+ if self.bijector:
+ log_prob -= torch.sum(self.bijector.log_prob_correction(gaussian_a), dim=1)
+ return log_prob
+
+ def sample(self) -> torch.Tensor:
+ noise = self._get_noise()
+ actions = self.mean + noise
+ return self.bijector.forward(actions) if self.bijector else actions
+
+ def _get_noise(self) -> torch.Tensor:
+ if len(self.latent_sde) == 1 or len(self.latent_sde) != len(
+ self.exploration_matrices
+ ):
+ return torch.mm(self.latent_sde, self.exploration_mat)
+ # (batch_size, n_features) -> (batch_size, 1, n_features)
+ latent_sde = self.latent_sde.unsqueeze(dim=1)
+ # (batch_size, 1, n_actions)
+ noise = torch.bmm(latent_sde, self.exploration_matrices)
+ return noise.squeeze(dim=1)
+
+ @property
+ def mode(self) -> torch.Tensor:
+ mean = super().mode
+ return self.bijector.forward(mean) if self.bijector else mean
+
+
+StateDependentNoiseActorHeadSelf = TypeVar(
+ "StateDependentNoiseActorHeadSelf", bound="StateDependentNoiseActorHead"
+)
+
+
+class StateDependentNoiseActorHead(Actor):
+ def __init__(
+ self,
+ act_dim: int,
+ hidden_sizes: Sequence[int] = (32,),
+ activation: Type[nn.Module] = nn.Tanh,
+ init_layers_orthogonal: bool = True,
+ log_std_init: float = -0.5,
+ full_std: bool = True,
+ squash_output: bool = False,
+ learn_std: bool = False,
+ ) -> None:
+ super().__init__()
+ self.act_dim = act_dim
+ layer_sizes = tuple(hidden_sizes) + (self.act_dim,)
+ if len(layer_sizes) == 2:
+ self.latent_net = nn.Identity()
+ elif len(layer_sizes) > 2:
+ self.latent_net = mlp(
+ layer_sizes[:-1],
+ activation,
+ output_activation=activation,
+ init_layers_orthogonal=init_layers_orthogonal,
+ )
+ else:
+ raise ValueError("hidden_sizes must be of at least length 1")
+ self.mu_net = mlp(
+ layer_sizes[-2:],
+ activation,
+ init_layers_orthogonal=init_layers_orthogonal,
+ final_layer_gain=0.01,
+ )
+ self.full_std = full_std
+ std_dim = (hidden_sizes[-1], act_dim if self.full_std else 1)
+ self.log_std = nn.Parameter(
+ torch.ones(std_dim, dtype=torch.float32) * log_std_init
+ )
+ self.bijector = TanhBijector() if squash_output else None
+ self.learn_std = learn_std
+ self.device = None
+
+ self.exploration_mat = None
+ self.exploration_matrices = None
+ self.sample_weights()
+
+ def to(
+ self: StateDependentNoiseActorHeadSelf,
+ device: Optional[torch.device] = None,
+ dtype: Optional[Union[torch.dtype, str]] = None,
+ non_blocking: bool = False,
+ ) -> StateDependentNoiseActorHeadSelf:
+ super().to(device, dtype, non_blocking)
+ self.device = device
+ return self
+
+ def _distribution(self, obs: torch.Tensor) -> Distribution:
+ latent = self.latent_net(obs)
+ mu = self.mu_net(latent)
+ latent_sde = latent if self.learn_std else latent.detach()
+ variance = torch.mm(latent_sde**2, self._get_std() ** 2)
+ assert self.exploration_mat is not None
+ assert self.exploration_matrices is not None
+ return StateDependentNoiseDistribution(
+ mu,
+ torch.sqrt(variance + 1e-6),
+ latent_sde,
+ self.exploration_mat,
+ self.exploration_matrices,
+ self.bijector,
+ )
+
+ def _get_std(self) -> torch.Tensor:
+ std = torch.exp(self.log_std)
+ if self.full_std:
+ return std
+ ones = torch.ones(self.log_std.shape[0], self.act_dim)
+ if self.device:
+ ones = ones.to(self.device)
+ return ones * std
+
+ def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward:
+ pi = self._distribution(obs)
+ logp_a = None
+ entropy = None
+ if a is not None:
+ logp_a = pi.log_prob(a)
+ entropy = -logp_a if self.bijector else sum_independent_dims(pi.entropy())
+ return PiForward(pi, logp_a, entropy)
+
+ def sample_weights(self, batch_size: int = 1) -> None:
+ std = self._get_std()
+ weights_dist = Normal(torch.zeros_like(std), std)
+ # Reparametrization trick to pass gradients
+ self.exploration_mat = weights_dist.rsample()
+ self.exploration_matrices = weights_dist.rsample(torch.Size((batch_size,)))
+
+
+def actor_head(
+ action_space: gym.Space,
+ hidden_sizes: Sequence[int],
+ init_layers_orthogonal: bool,
+ activation: Type[nn.Module],
+ log_std_init: float = -0.5,
+ use_sde: bool = False,
+ full_std: bool = True,
+ squash_output: bool = False,
+) -> Actor:
+ assert not use_sde or isinstance(
+ action_space, Box
+ ), "use_sde only valid if Box action_space"
+ assert not squash_output or use_sde, "squash_output only valid if use_sde"
+ if isinstance(action_space, Discrete):
+ return CategoricalActorHead(
+ action_space.n,
+ hidden_sizes=hidden_sizes,
+ activation=activation,
+ init_layers_orthogonal=init_layers_orthogonal,
+ )
+ elif isinstance(action_space, Box):
+ if use_sde:
+ return StateDependentNoiseActorHead(
+ action_space.shape[0],
+ hidden_sizes=hidden_sizes,
+ activation=activation,
+ init_layers_orthogonal=init_layers_orthogonal,
+ log_std_init=log_std_init,
+ full_std=full_std,
+ squash_output=squash_output,
+ )
+ else:
+ return GaussianActorHead(
+ action_space.shape[0],
+ hidden_sizes=hidden_sizes,
+ activation=activation,
+ init_layers_orthogonal=init_layers_orthogonal,
+ log_std_init=log_std_init,
+ )
+ else:
+ raise ValueError(f"Unsupported action space: {action_space}")
diff --git a/rl_algo_impls/shared/policy/critic.py b/rl_algo_impls/shared/policy/critic.py
new file mode 100644
index 0000000000000000000000000000000000000000..8fceb1c3d31d7133fae14878f03662a504eccda7
--- /dev/null
+++ b/rl_algo_impls/shared/policy/critic.py
@@ -0,0 +1,28 @@
+import gym
+import torch
+import torch.nn as nn
+
+from typing import Sequence, Type
+
+from rl_algo_impls.shared.module.module import mlp
+
+
+class CriticHead(nn.Module):
+ def __init__(
+ self,
+ hidden_sizes: Sequence[int] = (32,),
+ activation: Type[nn.Module] = nn.Tanh,
+ init_layers_orthogonal: bool = True,
+ ) -> None:
+ super().__init__()
+ layer_sizes = tuple(hidden_sizes) + (1,)
+ self._fc = mlp(
+ layer_sizes,
+ activation,
+ init_layers_orthogonal=init_layers_orthogonal,
+ final_layer_gain=1.0,
+ )
+
+ def forward(self, obs: torch.Tensor) -> torch.Tensor:
+ v = self._fc(obs)
+ return v.squeeze(-1)
diff --git a/rl_algo_impls/shared/policy/on_policy.py b/rl_algo_impls/shared/policy/on_policy.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c1fc3e11a31dc840903b92ec94c5e361f51ad99
--- /dev/null
+++ b/rl_algo_impls/shared/policy/on_policy.py
@@ -0,0 +1,226 @@
+import gym
+import numpy as np
+import torch
+
+from abc import abstractmethod
+from gym.spaces import Box, Discrete, Space
+from typing import NamedTuple, Optional, Sequence, Tuple, TypeVar
+
+from rl_algo_impls.shared.module.feature_extractor import FeatureExtractor
+from rl_algo_impls.shared.policy.actor import (
+ PiForward,
+ StateDependentNoiseActorHead,
+ actor_head,
+)
+from rl_algo_impls.shared.policy.critic import CriticHead
+from rl_algo_impls.shared.policy.policy import ACTIVATION, Policy
+from rl_algo_impls.wrappers.vectorable_wrapper import (
+ VecEnv,
+ VecEnvObs,
+ single_observation_space,
+ single_action_space,
+)
+
+
+class Step(NamedTuple):
+ a: np.ndarray
+ v: np.ndarray
+ logp_a: np.ndarray
+ clamped_a: np.ndarray
+
+
+class ACForward(NamedTuple):
+ logp_a: torch.Tensor
+ entropy: torch.Tensor
+ v: torch.Tensor
+
+
+FEAT_EXT_FILE_NAME = "feat_ext.pt"
+V_FEAT_EXT_FILE_NAME = "v_feat_ext.pt"
+PI_FILE_NAME = "pi.pt"
+V_FILE_NAME = "v.pt"
+ActorCriticSelf = TypeVar("ActorCriticSelf", bound="ActorCritic")
+
+
+def clamp_actions(
+ actions: np.ndarray, action_space: gym.Space, squash_output: bool
+) -> np.ndarray:
+ if isinstance(action_space, Box):
+ low, high = action_space.low, action_space.high # type: ignore
+ if squash_output:
+ # Squashed output is already between -1 and 1. Rescale if the actual
+ # output needs to something other than -1 and 1
+ return low + 0.5 * (actions + 1) * (high - low)
+ else:
+ return np.clip(actions, low, high)
+ return actions
+
+
+def default_hidden_sizes(obs_space: Space) -> Sequence[int]:
+ if isinstance(obs_space, Box):
+ if len(obs_space.shape) == 3:
+ # By default feature extractor to output has no hidden layers
+ return []
+ elif len(obs_space.shape) == 1:
+ return [64, 64]
+ else:
+ raise ValueError(f"Unsupported observation space: {obs_space}")
+ elif isinstance(obs_space, Discrete):
+ return [64]
+ else:
+ raise ValueError(f"Unsupported observation space: {obs_space}")
+
+
+class OnPolicy(Policy):
+ @abstractmethod
+ def value(self, obs: VecEnvObs) -> np.ndarray:
+ ...
+
+ @abstractmethod
+ def step(self, obs: VecEnvObs) -> Step:
+ ...
+
+
+class ActorCritic(OnPolicy):
+ def __init__(
+ self,
+ env: VecEnv,
+ pi_hidden_sizes: Optional[Sequence[int]] = None,
+ v_hidden_sizes: Optional[Sequence[int]] = None,
+ init_layers_orthogonal: bool = True,
+ activation_fn: str = "tanh",
+ log_std_init: float = -0.5,
+ use_sde: bool = False,
+ full_std: bool = True,
+ squash_output: bool = False,
+ share_features_extractor: bool = True,
+ cnn_feature_dim: int = 512,
+ cnn_style: str = "nature",
+ cnn_layers_init_orthogonal: Optional[bool] = None,
+ impala_channels: Sequence[int] = (16, 32, 32),
+ **kwargs,
+ ) -> None:
+ super().__init__(env, **kwargs)
+
+ observation_space = single_observation_space(env)
+ action_space = single_action_space(env)
+
+ pi_hidden_sizes = (
+ pi_hidden_sizes
+ if pi_hidden_sizes is not None
+ else default_hidden_sizes(observation_space)
+ )
+ v_hidden_sizes = (
+ v_hidden_sizes
+ if v_hidden_sizes is not None
+ else default_hidden_sizes(observation_space)
+ )
+
+ activation = ACTIVATION[activation_fn]
+ self.action_space = action_space
+ self.squash_output = squash_output
+ self.share_features_extractor = share_features_extractor
+ self._feature_extractor = FeatureExtractor(
+ observation_space,
+ activation,
+ init_layers_orthogonal=init_layers_orthogonal,
+ cnn_feature_dim=cnn_feature_dim,
+ cnn_style=cnn_style,
+ cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
+ impala_channels=impala_channels,
+ )
+ self._pi = actor_head(
+ self.action_space,
+ (self._feature_extractor.out_dim,) + tuple(pi_hidden_sizes),
+ init_layers_orthogonal,
+ activation,
+ log_std_init=log_std_init,
+ use_sde=use_sde,
+ full_std=full_std,
+ squash_output=squash_output,
+ )
+
+ if not share_features_extractor:
+ self._v_feature_extractor = FeatureExtractor(
+ observation_space,
+ activation,
+ init_layers_orthogonal=init_layers_orthogonal,
+ cnn_feature_dim=cnn_feature_dim,
+ cnn_style=cnn_style,
+ cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
+ )
+ v_hidden_sizes = (self._v_feature_extractor.out_dim,) + tuple(
+ v_hidden_sizes
+ )
+ else:
+ self._v_feature_extractor = None
+ v_hidden_sizes = (self._feature_extractor.out_dim,) + tuple(v_hidden_sizes)
+ self._v = CriticHead(
+ hidden_sizes=v_hidden_sizes,
+ activation=activation,
+ init_layers_orthogonal=init_layers_orthogonal,
+ )
+
+ def _pi_forward(
+ self, obs: torch.Tensor, action: Optional[torch.Tensor] = None
+ ) -> Tuple[PiForward, torch.Tensor]:
+ p_fe = self._feature_extractor(obs)
+ pi_forward = self._pi(p_fe, action)
+
+ return pi_forward, p_fe
+
+ def _v_forward(self, obs: torch.Tensor, p_fc: torch.Tensor) -> torch.Tensor:
+ v_fe = self._v_feature_extractor(obs) if self._v_feature_extractor else p_fc
+ return self._v(v_fe)
+
+ def forward(self, obs: torch.Tensor, action: torch.Tensor) -> ACForward:
+ (_, logp_a, entropy), p_fc = self._pi_forward(obs, action)
+ v = self._v_forward(obs, p_fc)
+
+ assert logp_a is not None
+ assert entropy is not None
+ return ACForward(logp_a, entropy, v)
+
+ def value(self, obs: VecEnvObs) -> np.ndarray:
+ o = self._as_tensor(obs)
+ with torch.no_grad():
+ fe = (
+ self._v_feature_extractor(o)
+ if self._v_feature_extractor
+ else self._feature_extractor(o)
+ )
+ v = self._v(fe)
+ return v.cpu().numpy()
+
+ def step(self, obs: VecEnvObs) -> Step:
+ o = self._as_tensor(obs)
+ with torch.no_grad():
+ (pi, _, _), p_fc = self._pi_forward(o)
+ a = pi.sample()
+ logp_a = pi.log_prob(a)
+
+ v = self._v_forward(o, p_fc)
+
+ a_np = a.cpu().numpy()
+ clamped_a_np = clamp_actions(a_np, self.action_space, self.squash_output)
+ return Step(a_np, v.cpu().numpy(), logp_a.cpu().numpy(), clamped_a_np)
+
+ def act(self, obs: np.ndarray, deterministic: bool = True) -> np.ndarray:
+ if not deterministic:
+ return self.step(obs).clamped_a
+ else:
+ o = self._as_tensor(obs)
+ with torch.no_grad():
+ (pi, _, _), _ = self._pi_forward(o)
+ a = pi.mode
+ return clamp_actions(a.cpu().numpy(), self.action_space, self.squash_output)
+
+ def load(self, path: str) -> None:
+ super().load(path)
+ self.reset_noise()
+
+ def reset_noise(self, batch_size: Optional[int] = None) -> None:
+ if isinstance(self._pi, StateDependentNoiseActorHead):
+ self._pi.sample_weights(
+ batch_size=batch_size if batch_size else self.env.num_envs
+ )
diff --git a/rl_algo_impls/shared/policy/optimize_on_policy.py b/rl_algo_impls/shared/policy/optimize_on_policy.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ab343131a2db347623bb3208a2478bf168f730b
--- /dev/null
+++ b/rl_algo_impls/shared/policy/optimize_on_policy.py
@@ -0,0 +1,35 @@
+import optuna
+
+from gym.spaces import Box
+from typing import Any, Dict
+
+from rl_algo_impls.wrappers.vectorable_wrapper import (
+ VecEnv,
+ single_action_space,
+)
+
+
+def sample_on_policy_hyperparams(
+ trial: optuna.Trial, policy_hparams: Dict[str, Any], env: VecEnv
+) -> Dict[str, Any]:
+ act_space = single_action_space(env)
+
+ policy_hparams["init_layers_orthogonal"] = trial.suggest_categorical(
+ "init_layers_orthogonal", [True, False]
+ )
+ policy_hparams["activation_fn"] = trial.suggest_categorical(
+ "activation_fn", ["tanh", "relu"]
+ )
+
+ if isinstance(act_space, Box):
+ policy_hparams["log_std_init"] = trial.suggest_float("log_std_init", -5, 0.5)
+ policy_hparams["use_sde"] = trial.suggest_categorical("use_sde", [False, True])
+
+ if policy_hparams.get("use_sde", False):
+ policy_hparams["squash_output"] = trial.suggest_categorical(
+ "squash_output", [False, True]
+ )
+ elif "squash_output" in policy_hparams:
+ del policy_hparams["squash_output"]
+
+ return policy_hparams
diff --git a/rl_algo_impls/shared/policy/policy.py b/rl_algo_impls/shared/policy/policy.py
new file mode 100644
index 0000000000000000000000000000000000000000..41d49004aec251235744341890412db9ef7ce389
--- /dev/null
+++ b/rl_algo_impls/shared/policy/policy.py
@@ -0,0 +1,114 @@
+import numpy as np
+import os
+import torch
+import torch.nn as nn
+
+from abc import ABC, abstractmethod
+from copy import deepcopy
+from stable_baselines3.common.vec_env import unwrap_vec_normalize
+from stable_baselines3.common.vec_env.vec_normalize import VecNormalize
+from typing import Dict, Optional, Type, TypeVar, Union
+
+from rl_algo_impls.wrappers.normalize import NormalizeObservation, NormalizeReward
+from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv, VecEnvObs, find_wrapper
+
+ACTIVATION: Dict[str, Type[nn.Module]] = {
+ "tanh": nn.Tanh,
+ "relu": nn.ReLU,
+}
+
+VEC_NORMALIZE_FILENAME = "vecnormalize.pkl"
+MODEL_FILENAME = "model.pth"
+NORMALIZE_OBSERVATION_FILENAME = "norm_obs.npz"
+NORMALIZE_REWARD_FILENAME = "norm_reward.npz"
+
+PolicySelf = TypeVar("PolicySelf", bound="Policy")
+
+
+class Policy(nn.Module, ABC):
+ @abstractmethod
+ def __init__(self, env: VecEnv, **kwargs) -> None:
+ super().__init__()
+ self.env = env
+ self.vec_normalize = unwrap_vec_normalize(env)
+ self.norm_observation = find_wrapper(env, NormalizeObservation)
+ self.norm_reward = find_wrapper(env, NormalizeReward)
+ self.device = None
+
+ def to(
+ self: PolicySelf,
+ device: Optional[torch.device] = None,
+ dtype: Optional[Union[torch.dtype, str]] = None,
+ non_blocking: bool = False,
+ ) -> PolicySelf:
+ super().to(device, dtype, non_blocking)
+ self.device = device
+ return self
+
+ @abstractmethod
+ def act(self, obs: VecEnvObs, deterministic: bool = True) -> np.ndarray:
+ ...
+
+ def save(self, path: str) -> None:
+ os.makedirs(path, exist_ok=True)
+
+ if self.vec_normalize:
+ self.vec_normalize.save(os.path.join(path, VEC_NORMALIZE_FILENAME))
+ if self.norm_observation:
+ self.norm_observation.save(
+ os.path.join(path, NORMALIZE_OBSERVATION_FILENAME)
+ )
+ if self.norm_reward:
+ self.norm_reward.save(os.path.join(path, NORMALIZE_REWARD_FILENAME))
+ torch.save(
+ self.state_dict(),
+ os.path.join(path, MODEL_FILENAME),
+ )
+
+ def load(self, path: str) -> None:
+ # VecNormalize load occurs in env.py
+ self.load_state_dict(
+ torch.load(os.path.join(path, MODEL_FILENAME), map_location=self.device)
+ )
+ if self.norm_observation:
+ self.norm_observation.load(
+ os.path.join(path, NORMALIZE_OBSERVATION_FILENAME)
+ )
+ if self.norm_reward:
+ self.norm_reward.load(os.path.join(path, NORMALIZE_REWARD_FILENAME))
+
+ def reset_noise(self) -> None:
+ pass
+
+ def _as_tensor(self, obs: VecEnvObs) -> torch.Tensor:
+ assert isinstance(obs, np.ndarray)
+ o = torch.as_tensor(obs)
+ if self.device is not None:
+ o = o.to(self.device)
+ return o
+
+ def num_trainable_parameters(self) -> int:
+ return sum(p.numel() for p in self.parameters() if p.requires_grad)
+
+ def num_parameters(self) -> int:
+ return sum(p.numel() for p in self.parameters())
+
+ def sync_normalization(self, destination_env) -> None:
+ current = destination_env
+ while current != current.unwrapped:
+ if isinstance(current, VecNormalize):
+ assert self.vec_normalize
+ current.ret_rms = deepcopy(self.vec_normalize.ret_rms)
+ if hasattr(self.vec_normalize, "obs_rms"):
+ current.obs_rms = deepcopy(self.vec_normalize.obs_rms)
+ elif isinstance(current, NormalizeObservation):
+ assert self.norm_observation
+ current.rms = deepcopy(self.norm_observation.rms)
+ elif isinstance(current, NormalizeReward):
+ assert self.norm_reward
+ current.rms = deepcopy(self.norm_reward.rms)
+ current = getattr(current, "venv", getattr(current, "env", current))
+ if not current:
+ raise AttributeError(
+ f"{type(current)} doesn't include env or venv attribute"
+ )
diff --git a/rl_algo_impls/shared/schedule.py b/rl_algo_impls/shared/schedule.py
new file mode 100644
index 0000000000000000000000000000000000000000..1461a782341eff5d89a53f16aebdc268bf9f7f52
--- /dev/null
+++ b/rl_algo_impls/shared/schedule.py
@@ -0,0 +1,31 @@
+from torch.optim import Optimizer
+from typing import Callable
+
+Schedule = Callable[[float], float]
+
+
+def linear_schedule(
+ start_val: float, end_val: float, end_fraction: float = 1.0
+) -> Schedule:
+ def func(progress_fraction: float) -> float:
+ if progress_fraction >= end_fraction:
+ return end_val
+ else:
+ return start_val + (end_val - start_val) * progress_fraction / end_fraction
+
+ return func
+
+
+def constant_schedule(val: float) -> Schedule:
+ return lambda f: val
+
+
+def schedule(name: str, start_val: float) -> Schedule:
+ if name == "linear":
+ return linear_schedule(start_val, 0)
+ return constant_schedule(start_val)
+
+
+def update_learning_rate(optimizer: Optimizer, learning_rate: float) -> None:
+ for param_group in optimizer.param_groups:
+ param_group["lr"] = learning_rate
diff --git a/rl_algo_impls/shared/stats.py b/rl_algo_impls/shared/stats.py
new file mode 100644
index 0000000000000000000000000000000000000000..2315e6bb0de04ee56ca577cb7444f17e93e88fc0
--- /dev/null
+++ b/rl_algo_impls/shared/stats.py
@@ -0,0 +1,160 @@
+import numpy as np
+
+from dataclasses import dataclass
+from torch.utils.tensorboard.writer import SummaryWriter
+from typing import Dict, List, Optional, Sequence, Union, TypeVar
+
+
+@dataclass
+class Episode:
+ score: float = 0
+ length: int = 0
+
+
+StatisticSelf = TypeVar("StatisticSelf", bound="Statistic")
+
+
+@dataclass
+class Statistic:
+ values: np.ndarray
+ round_digits: int = 2
+
+ @property
+ def mean(self) -> float:
+ return np.mean(self.values).item()
+
+ @property
+ def std(self) -> float:
+ return np.std(self.values).item()
+
+ @property
+ def min(self) -> float:
+ return np.min(self.values).item()
+
+ @property
+ def max(self) -> float:
+ return np.max(self.values).item()
+
+ def sum(self) -> float:
+ return np.sum(self.values).item()
+
+ def __len__(self) -> int:
+ return len(self.values)
+
+ def _diff(self: StatisticSelf, o: StatisticSelf) -> float:
+ return (self.mean - self.std) - (o.mean - o.std)
+
+ def __gt__(self: StatisticSelf, o: StatisticSelf) -> bool:
+ return self._diff(o) > 0
+
+ def __ge__(self: StatisticSelf, o: StatisticSelf) -> bool:
+ return self._diff(o) >= 0
+
+ def __repr__(self) -> str:
+ mean = round(self.mean, self.round_digits)
+ std = round(self.std, self.round_digits)
+ if self.round_digits == 0:
+ mean = int(mean)
+ std = int(std)
+ return f"{mean} +/- {std}"
+
+ def to_dict(self) -> Dict[str, float]:
+ return {
+ "mean": self.mean,
+ "std": self.std,
+ "min": self.min,
+ "max": self.max,
+ }
+
+
+EpisodesStatsSelf = TypeVar("EpisodesStatsSelf", bound="EpisodesStats")
+
+
+class EpisodesStats:
+ episodes: Sequence[Episode]
+ simple: bool
+ score: Statistic
+ length: Statistic
+
+ def __init__(self, episodes: Sequence[Episode], simple: bool = False) -> None:
+ self.episodes = episodes
+ self.simple = simple
+ self.score = Statistic(np.array([e.score for e in episodes]))
+ self.length = Statistic(np.array([e.length for e in episodes]), round_digits=0)
+
+ def __gt__(self: EpisodesStatsSelf, o: EpisodesStatsSelf) -> bool:
+ return self.score > o.score
+
+ def __ge__(self: EpisodesStatsSelf, o: EpisodesStatsSelf) -> bool:
+ return self.score >= o.score
+
+ def __repr__(self) -> str:
+ return (
+ f"Score: {self.score} ({round(self.score.mean - self.score.std, 2)}) | "
+ f"Length: {self.length}"
+ )
+
+ def __len__(self) -> int:
+ return len(self.episodes)
+
+ def _asdict(self) -> dict:
+ return {
+ "n_episodes": len(self.episodes),
+ "score": self.score.to_dict(),
+ "length": self.length.to_dict(),
+ }
+
+ def write_to_tensorboard(
+ self, tb_writer: SummaryWriter, main_tag: str, global_step: Optional[int] = None
+ ) -> None:
+ stats = {"mean": self.score.mean}
+ if not self.simple:
+ stats.update(
+ {
+ "min": self.score.min,
+ "max": self.score.max,
+ "result": self.score.mean - self.score.std,
+ "n_episodes": len(self.episodes),
+ "length": self.length.mean,
+ }
+ )
+ for name, value in stats.items():
+ tb_writer.add_scalar(f"{main_tag}/{name}", value, global_step=global_step)
+
+
+class EpisodeAccumulator:
+ def __init__(self, num_envs: int):
+ self._episodes = []
+ self.current_episodes = [Episode() for _ in range(num_envs)]
+
+ @property
+ def episodes(self) -> List[Episode]:
+ return self._episodes
+
+ def step(self, reward: np.ndarray, done: np.ndarray) -> None:
+ for idx, current in enumerate(self.current_episodes):
+ current.score += reward[idx]
+ current.length += 1
+ if done[idx]:
+ self._episodes.append(current)
+ self.current_episodes[idx] = Episode()
+ self.on_done(idx, current)
+
+ def __len__(self) -> int:
+ return len(self.episodes)
+
+ def on_done(self, ep_idx: int, episode: Episode) -> None:
+ pass
+
+ def stats(self) -> EpisodesStats:
+ return EpisodesStats(self.episodes)
+
+
+def log_scalars(
+ tb_writer: SummaryWriter,
+ main_tag: str,
+ tag_scalar_dict: Dict[str, Union[int, float]],
+ global_step: int,
+) -> None:
+ for tag, value in tag_scalar_dict.items():
+ tb_writer.add_scalar(f"{main_tag}/{tag}", value, global_step)
diff --git a/rl_algo_impls/shared/trajectory.py b/rl_algo_impls/shared/trajectory.py
new file mode 100644
index 0000000000000000000000000000000000000000..73427dca6429f4e1850b716b7b0967fd3b68ee0d
--- /dev/null
+++ b/rl_algo_impls/shared/trajectory.py
@@ -0,0 +1,81 @@
+import numpy as np
+
+from dataclasses import dataclass, field
+from typing import Generic, List, Optional, Type, TypeVar
+
+from rl_algo_impls.wrappers.vectorable_wrapper import VecEnvObs
+
+
+@dataclass
+class Trajectory:
+ obs: List[np.ndarray] = field(default_factory=list)
+ act: List[np.ndarray] = field(default_factory=list)
+ next_obs: Optional[np.ndarray] = None
+ rew: List[float] = field(default_factory=list)
+ terminated: bool = False
+ v: List[float] = field(default_factory=list)
+
+ def add(
+ self,
+ obs: np.ndarray,
+ act: np.ndarray,
+ next_obs: np.ndarray,
+ rew: float,
+ terminated: bool,
+ v: float,
+ ):
+ self.obs.append(obs)
+ self.act.append(act)
+ self.next_obs = next_obs if not terminated else None
+ self.rew.append(rew)
+ self.terminated = terminated
+ self.v.append(v)
+
+ def __len__(self) -> int:
+ return len(self.obs)
+
+
+T = TypeVar("T", bound=Trajectory)
+
+
+class TrajectoryAccumulator(Generic[T]):
+ def __init__(self, num_envs: int, trajectory_class: Type[T] = Trajectory) -> None:
+ self.num_envs = num_envs
+ self.trajectory_class = trajectory_class
+
+ self._trajectories = []
+ self._current_trajectories = [trajectory_class() for _ in range(num_envs)]
+
+ def step(
+ self,
+ obs: VecEnvObs,
+ action: np.ndarray,
+ next_obs: VecEnvObs,
+ reward: np.ndarray,
+ done: np.ndarray,
+ val: np.ndarray,
+ *args,
+ ) -> None:
+ assert isinstance(obs, np.ndarray)
+ assert isinstance(next_obs, np.ndarray)
+ for i, args in enumerate(zip(obs, action, next_obs, reward, done, val, *args)):
+ trajectory = self._current_trajectories[i]
+ # TODO: Eventually take advantage of terminated/truncated differentiation in
+ # later versions of gym.
+ trajectory.add(*args)
+ if done[i]:
+ self._trajectories.append(trajectory)
+ self._current_trajectories[i] = self.trajectory_class()
+ self.on_done(i, trajectory)
+
+ @property
+ def all_trajectories(self) -> List[T]:
+ return self._trajectories + list(
+ filter(lambda t: len(t), self._current_trajectories)
+ )
+
+ def n_timesteps(self) -> int:
+ return sum(len(t) for t in self.all_trajectories)
+
+ def on_done(self, env_idx: int, trajectory: T) -> None:
+ pass
diff --git a/rl_algo_impls/train.py b/rl_algo_impls/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1518ca570e1aade189c976fcfc654dbfb4ffc18
--- /dev/null
+++ b/rl_algo_impls/train.py
@@ -0,0 +1,67 @@
+# Support for PyTorch mps mode (https://pytorch.org/docs/stable/notes/mps.html)
+import os
+
+os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+
+from multiprocessing import Pool
+
+from rl_algo_impls.runner.running_utils import base_parser
+from rl_algo_impls.runner.train import train as runner_train, TrainArgs
+
+
+def train() -> None:
+ parser = base_parser()
+ parser.add_argument(
+ "--wandb-project-name",
+ type=str,
+ default="rl-algo-impls",
+ help="WandB project name to upload training data to. If none, won't upload.",
+ )
+ parser.add_argument(
+ "--wandb-entity",
+ type=str,
+ default=None,
+ help="WandB team of project. None uses default entity",
+ )
+ parser.add_argument(
+ "--wandb-tags", type=str, nargs="*", help="WandB tags to add to run"
+ )
+ parser.add_argument(
+ "--pool-size", type=int, default=1, help="Simultaneous training jobs to run"
+ )
+ parser.add_argument(
+ "--virtual-display", action="store_true", help="Use headless virtual display"
+ )
+ # parser.set_defaults(
+ # algo=["ppo"],
+ # env=["CartPole-v1"],
+ # seed=[10],
+ # pool_size=3,
+ # )
+ args = parser.parse_args()
+ print(args)
+
+ if args.virtual_display:
+ from pyvirtualdisplay.display import Display
+
+ virtual_display = Display(visible=False, size=(1400, 900))
+ virtual_display.start()
+ # virtual_display isn't a TrainArg so must be removed
+ delattr(args, "virtual_display")
+
+ pool_size = min(args.pool_size, len(args.seed))
+ # pool_size isn't a TrainArg so must be removed from args
+ delattr(args, "pool_size")
+
+ train_args = TrainArgs.expand_from_dict(vars(args))
+ if len(train_args) == 1:
+ runner_train(train_args[0])
+ else:
+ # Force a new process for each job to get around wandb not allowing more than one
+ # wandb.tensorboard.patch call per process.
+ with Pool(pool_size, maxtasksperchild=1) as p:
+ p.map(runner_train, train_args)
+
+
+if __name__ == "__main__":
+ train()
diff --git a/rl_algo_impls/tuning/optimize_env.py b/rl_algo_impls/tuning/optimize_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7a366266857dde52863b454fd4d64664285592c
--- /dev/null
+++ b/rl_algo_impls/tuning/optimize_env.py
@@ -0,0 +1,41 @@
+import optuna
+
+from typing import Any, Dict
+
+from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv, single_observation_space
+
+
+def sample_env_hyperparams(
+ trial: optuna.Trial, env_hparams: Dict[str, Any], env: VecEnv
+) -> Dict[str, Any]:
+ obs_space = single_observation_space(env)
+
+ n_envs = 2 ** trial.suggest_int("n_envs_exp", 1, 5)
+ trial.set_user_attr("n_envs", n_envs)
+ env_hparams["n_envs"] = n_envs
+
+ normalize = trial.suggest_categorical("normalize", [False, True])
+ env_hparams["normalize"] = normalize
+ if normalize:
+ normalize_kwargs = env_hparams.get("normalize_kwargs", {})
+ if len(obs_space.shape) == 3:
+ normalize_kwargs.update(
+ {
+ "norm_obs": False,
+ "norm_reward": True,
+ }
+ )
+ else:
+ norm_obs = trial.suggest_categorical("norm_obs", [True, False])
+ norm_reward = trial.suggest_categorical("norm_reward", [True, False])
+ normalize_kwargs.update(
+ {
+ "norm_obs": norm_obs,
+ "norm_reward": norm_reward,
+ }
+ )
+ env_hparams["normalize_kwargs"] = normalize_kwargs
+ elif "normalize_kwargs" in env_hparams:
+ del env_hparams["normalize_kwargs"]
+
+ return env_hparams
diff --git a/rl_algo_impls/vpg/policy.py b/rl_algo_impls/vpg/policy.py
new file mode 100644
index 0000000000000000000000000000000000000000..e016592fd079ba614ae40f51c70e0119a29113bc
--- /dev/null
+++ b/rl_algo_impls/vpg/policy.py
@@ -0,0 +1,133 @@
+import numpy as np
+import torch
+import torch.nn as nn
+
+from typing import Optional, Sequence
+
+from rl_algo_impls.shared.module.feature_extractor import FeatureExtractor
+from rl_algo_impls.shared.policy.actor import (
+ PiForward,
+ Actor,
+ StateDependentNoiseActorHead,
+ actor_head,
+)
+from rl_algo_impls.shared.policy.critic import CriticHead
+from rl_algo_impls.shared.policy.on_policy import (
+ Step,
+ ACForward,
+ OnPolicy,
+ clamp_actions,
+ default_hidden_sizes,
+)
+from rl_algo_impls.shared.policy.policy import ACTIVATION
+from rl_algo_impls.wrappers.vectorable_wrapper import (
+ VecEnv,
+ VecEnvObs,
+ single_observation_space,
+ single_action_space,
+)
+
+PI_FILE_NAME = "pi.pt"
+V_FILE_NAME = "v.pt"
+
+
+class VPGActor(Actor):
+ def __init__(self, feature_extractor: FeatureExtractor, head: Actor) -> None:
+ super().__init__()
+ self.feature_extractor = feature_extractor
+ self.head = head
+
+ def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward:
+ fe = self.feature_extractor(obs)
+ return self.head(fe, a)
+
+
+class VPGActorCritic(OnPolicy):
+ def __init__(
+ self,
+ env: VecEnv,
+ hidden_sizes: Optional[Sequence[int]] = None,
+ init_layers_orthogonal: bool = True,
+ activation_fn: str = "tanh",
+ log_std_init: float = -0.5,
+ use_sde: bool = False,
+ full_std: bool = True,
+ squash_output: bool = False,
+ **kwargs,
+ ) -> None:
+ super().__init__(env, **kwargs)
+ activation = ACTIVATION[activation_fn]
+ obs_space = single_observation_space(env)
+ self.action_space = single_action_space(env)
+ self.use_sde = use_sde
+ self.squash_output = squash_output
+
+ hidden_sizes = (
+ hidden_sizes
+ if hidden_sizes is not None
+ else default_hidden_sizes(obs_space)
+ )
+
+ pi_feature_extractor = FeatureExtractor(
+ obs_space, activation, init_layers_orthogonal=init_layers_orthogonal
+ )
+ pi_head = actor_head(
+ self.action_space,
+ (pi_feature_extractor.out_dim,) + tuple(hidden_sizes),
+ init_layers_orthogonal,
+ activation,
+ log_std_init=log_std_init,
+ use_sde=use_sde,
+ full_std=full_std,
+ squash_output=squash_output,
+ )
+ self.pi = VPGActor(pi_feature_extractor, pi_head)
+
+ v_feature_extractor = FeatureExtractor(
+ obs_space, activation, init_layers_orthogonal=init_layers_orthogonal
+ )
+ v_head = CriticHead(
+ (v_feature_extractor.out_dim,) + tuple(hidden_sizes),
+ activation=activation,
+ init_layers_orthogonal=init_layers_orthogonal,
+ )
+ self.v = nn.Sequential(v_feature_extractor, v_head)
+
+ def value(self, obs: VecEnvObs) -> np.ndarray:
+ o = self._as_tensor(obs)
+ with torch.no_grad():
+ v = self.v(o)
+ return v.cpu().numpy()
+
+ def step(self, obs: VecEnvObs) -> Step:
+ o = self._as_tensor(obs)
+ with torch.no_grad():
+ pi, _, _ = self.pi(o)
+ a = pi.sample()
+ logp_a = pi.log_prob(a)
+
+ v = self.v(o)
+
+ a_np = a.cpu().numpy()
+ clamped_a_np = clamp_actions(a_np, self.action_space, self.squash_output)
+ return Step(a_np, v.cpu().numpy(), logp_a.cpu().numpy(), clamped_a_np)
+
+ def act(self, obs: np.ndarray, deterministic: bool = True) -> np.ndarray:
+ if not deterministic:
+ return self.step(obs).clamped_a
+ else:
+ o = self._as_tensor(obs)
+ with torch.no_grad():
+ pi, _, _ = self.pi(o)
+ a = pi.mode
+ return clamp_actions(a.cpu().numpy(), self.action_space, self.squash_output)
+
+ def load(self, path: str) -> None:
+ super().load(path)
+ self.reset_noise()
+
+ def reset_noise(self, batch_size: Optional[int] = None) -> None:
+ if isinstance(self.pi.head, StateDependentNoiseActorHead):
+ self.pi.head.sample_weights(
+ batch_size=batch_size if batch_size else self.env.num_envs
+ )
diff --git a/rl_algo_impls/vpg/vpg.py b/rl_algo_impls/vpg/vpg.py
new file mode 100644
index 0000000000000000000000000000000000000000..9605efb23097a11d4cfd50b6e86b67162b10873e
--- /dev/null
+++ b/rl_algo_impls/vpg/vpg.py
@@ -0,0 +1,168 @@
+import numpy as np
+import torch
+import torch.nn as nn
+
+from collections import defaultdict
+from dataclasses import dataclass, asdict
+from torch.optim import Adam
+from torch.utils.tensorboard.writer import SummaryWriter
+from typing import Optional, Sequence, TypeVar
+
+from rl_algo_impls.shared.algorithm import Algorithm
+from rl_algo_impls.shared.callbacks.callback import Callback
+from rl_algo_impls.shared.gae import compute_rtg_and_advantage, compute_advantage
+from rl_algo_impls.shared.trajectory import Trajectory, TrajectoryAccumulator
+from rl_algo_impls.vpg.policy import VPGActorCritic
+from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv
+
+
+@dataclass
+class TrainEpochStats:
+ pi_loss: float
+ entropy_loss: float
+ v_loss: float
+ envs_with_done: int = 0
+ episodes_done: int = 0
+
+ def write_to_tensorboard(self, tb_writer: SummaryWriter, global_step: int) -> None:
+ for name, value in asdict(self).items():
+ tb_writer.add_scalar(f"losses/{name}", value, global_step=global_step)
+
+
+class VPGTrajectoryAccumulator(TrajectoryAccumulator):
+ def __init__(self, num_envs: int) -> None:
+ super().__init__(num_envs, trajectory_class=Trajectory)
+ self.completed_per_env: defaultdict[int, int] = defaultdict(int)
+
+ def on_done(self, env_idx: int, trajectory: Trajectory) -> None:
+ self.completed_per_env[env_idx] += 1
+
+
+VanillaPolicyGradientSelf = TypeVar(
+ "VanillaPolicyGradientSelf", bound="VanillaPolicyGradient"
+)
+
+
+class VanillaPolicyGradient(Algorithm):
+ def __init__(
+ self,
+ policy: VPGActorCritic,
+ env: VecEnv,
+ device: torch.device,
+ tb_writer: SummaryWriter,
+ gamma: float = 0.99,
+ pi_lr: float = 3e-4,
+ val_lr: float = 1e-3,
+ train_v_iters: int = 80,
+ gae_lambda: float = 0.97,
+ max_grad_norm: float = 10.0,
+ n_steps: int = 4_000,
+ sde_sample_freq: int = -1,
+ update_rtg_between_v_iters: bool = False,
+ ent_coef: float = 0.0,
+ ) -> None:
+ super().__init__(policy, env, device, tb_writer)
+ self.policy = policy
+
+ self.gamma = gamma
+ self.gae_lambda = gae_lambda
+ self.pi_optim = Adam(self.policy.pi.parameters(), lr=pi_lr)
+ self.val_optim = Adam(self.policy.v.parameters(), lr=val_lr)
+ self.max_grad_norm = max_grad_norm
+
+ self.n_steps = n_steps
+ self.train_v_iters = train_v_iters
+ self.sde_sample_freq = sde_sample_freq
+ self.update_rtg_between_v_iters = update_rtg_between_v_iters
+
+ self.ent_coef = ent_coef
+
+ def learn(
+ self: VanillaPolicyGradientSelf,
+ total_timesteps: int,
+ callback: Optional[Callback] = None,
+ ) -> VanillaPolicyGradientSelf:
+ timesteps_elapsed = 0
+ epoch_cnt = 0
+ while timesteps_elapsed < total_timesteps:
+ epoch_cnt += 1
+ accumulator = self._collect_trajectories()
+ epoch_stats = self.train(accumulator.all_trajectories)
+ epoch_stats.envs_with_done = len(accumulator.completed_per_env)
+ epoch_stats.episodes_done = sum(accumulator.completed_per_env.values())
+ epoch_steps = accumulator.n_timesteps()
+ timesteps_elapsed += epoch_steps
+ epoch_stats.write_to_tensorboard(
+ self.tb_writer, global_step=timesteps_elapsed
+ )
+ print(
+ " | ".join(
+ [
+ f"Epoch: {epoch_cnt}",
+ f"Pi Loss: {round(epoch_stats.pi_loss, 2)}",
+ f"Epoch Loss: {round(epoch_stats.entropy_loss, 2)}",
+ f"V Loss: {round(epoch_stats.v_loss, 2)}",
+ f"Total Steps: {timesteps_elapsed}",
+ ]
+ )
+ )
+ if callback:
+ callback.on_step(timesteps_elapsed=epoch_steps)
+ return self
+
+ def train(self, trajectories: Sequence[Trajectory]) -> TrainEpochStats:
+ self.policy.train()
+ obs = torch.as_tensor(
+ np.concatenate([np.array(t.obs) for t in trajectories]), device=self.device
+ )
+ act = torch.as_tensor(
+ np.concatenate([np.array(t.act) for t in trajectories]), device=self.device
+ )
+ rtg, adv = compute_rtg_and_advantage(
+ trajectories, self.policy, self.gamma, self.gae_lambda, self.device
+ )
+
+ _, logp, entropy = self.policy.pi(obs, act)
+ pi_loss = -(logp * adv).mean()
+ entropy_loss = entropy.mean()
+
+ actor_loss = pi_loss - self.ent_coef * entropy_loss
+
+ self.pi_optim.zero_grad()
+ actor_loss.backward()
+ nn.utils.clip_grad_norm_(self.policy.pi.parameters(), self.max_grad_norm)
+ self.pi_optim.step()
+
+ v_loss = 0
+ for _ in range(self.train_v_iters):
+ if self.update_rtg_between_v_iters:
+ rtg = compute_advantage(
+ trajectories, self.policy, self.gamma, self.gae_lambda, self.device
+ )
+ v = self.policy.v(obs)
+ v_loss = ((v - rtg) ** 2).mean()
+
+ self.val_optim.zero_grad()
+ v_loss.backward()
+ nn.utils.clip_grad_norm_(self.policy.v.parameters(), self.max_grad_norm)
+ self.val_optim.step()
+
+ return TrainEpochStats(
+ pi_loss.item(),
+ entropy_loss.item(),
+ v_loss.item(), # type: ignore
+ )
+
+ def _collect_trajectories(self) -> VPGTrajectoryAccumulator:
+ self.policy.eval()
+ obs = self.env.reset()
+ accumulator = VPGTrajectoryAccumulator(self.env.num_envs)
+ self.policy.reset_noise()
+ for i in range(self.n_steps):
+ if self.sde_sample_freq > 0 and i > 0 and i % self.sde_sample_freq == 0:
+ self.policy.reset_noise()
+ action, value, _, clamped_action = self.policy.step(obs)
+ next_obs, reward, done, _ = self.env.step(clamped_action)
+ accumulator.step(obs, action, next_obs, reward, done, value)
+ obs = next_obs
+ return accumulator
diff --git a/rl_algo_impls/wrappers/atari_wrappers.py b/rl_algo_impls/wrappers/atari_wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..1fb0b345f9b1bc8a972cfdc635027aef9a45e37a
--- /dev/null
+++ b/rl_algo_impls/wrappers/atari_wrappers.py
@@ -0,0 +1,84 @@
+import gym
+import numpy as np
+
+from typing import Any, Dict, Tuple, Union
+
+from rl_algo_impls.wrappers.vectorable_wrapper import VecotarableWrapper
+
+ObsType = Union[np.ndarray, dict]
+ActType = Union[int, float, np.ndarray, dict]
+
+
+class EpisodicLifeEnv(VecotarableWrapper):
+ def __init__(self, env: gym.Env, training: bool = True, noop_act: int = 0) -> None:
+ super().__init__(env)
+ self.training = training
+ self.noop_act = noop_act
+ self.life_done_continue = False
+ self.lives = 0
+
+ def step(self, action: ActType) -> Tuple[ObsType, float, bool, Dict[str, Any]]:
+ obs, rew, done, info = self.env.step(action)
+ new_lives = self.env.unwrapped.ale.lives()
+ self.life_done_continue = new_lives != self.lives and not done
+ # Only if training should life-end be marked as done
+ if self.training and 0 < new_lives < self.lives:
+ done = True
+ self.lives = new_lives
+ return obs, rew, done, info
+
+ def reset(self, **kwargs) -> ObsType:
+ # If life_done_continue (but not game over), then a reset should just allow the
+ # game to progress to the next life.
+ if self.training and self.life_done_continue:
+ obs, _, _, _ = self.env.step(self.noop_act)
+ else:
+ obs = self.env.reset(**kwargs)
+ self.lives = self.env.unwrapped.ale.lives()
+ return obs
+
+
+class FireOnLifeStarttEnv(VecotarableWrapper):
+ def __init__(self, env: gym.Env, fire_act: int = 1) -> None:
+ super().__init__(env)
+ self.fire_act = fire_act
+ action_meanings = env.unwrapped.get_action_meanings()
+ assert action_meanings[fire_act] == "FIRE"
+ assert len(action_meanings) >= 3
+ self.lives = 0
+ self.fire_on_next_step = True
+
+ def step(self, action: ActType) -> Tuple[ObsType, float, bool, Dict[str, Any]]:
+ if self.fire_on_next_step:
+ action = self.fire_act
+ self.fire_on_next_step = False
+ obs, rew, done, info = self.env.step(action)
+ new_lives = self.env.unwrapped.ale.lives()
+ if 0 < new_lives < self.lives and not done:
+ self.fire_on_next_step = True
+ self.lives = new_lives
+ return obs, rew, done, info
+
+ def reset(self, **kwargs) -> ObsType:
+ self.env.reset(**kwargs)
+ obs, _, done, _ = self.env.step(self.fire_act)
+ if done:
+ self.env.reset(**kwargs)
+ obs, _, done, _ = self.env.step(2)
+ if done:
+ self.env.reset(**kwargs)
+ self.fire_on_next_step = False
+ return obs
+
+
+class ClipRewardEnv(VecotarableWrapper):
+ def __init__(self, env: gym.Env, training: bool = True) -> None:
+ super().__init__(env)
+ self.training = training
+
+ def step(self, action: ActType) -> Tuple[ObsType, float, bool, Dict[str, Any]]:
+ obs, rew, done, info = self.env.step(action)
+ if self.training:
+ info["unclipped_reward"] = rew
+ rew = np.sign(rew)
+ return obs, rew, done, info
diff --git a/rl_algo_impls/wrappers/episode_record_video.py b/rl_algo_impls/wrappers/episode_record_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..f42aba4a537b2698c5f358a5fdf52df526f5d1d1
--- /dev/null
+++ b/rl_algo_impls/wrappers/episode_record_video.py
@@ -0,0 +1,75 @@
+import gym
+import numpy as np
+
+from gym.wrappers.monitoring.video_recorder import VideoRecorder
+from typing import Tuple, Union
+
+from rl_algo_impls.wrappers.vectorable_wrapper import VecotarableWrapper
+
+ObsType = Union[np.ndarray, dict]
+ActType = Union[int, float, np.ndarray, dict]
+
+
+class EpisodeRecordVideo(VecotarableWrapper):
+ def __init__(
+ self,
+ env: gym.Env,
+ video_path_prefix: str,
+ step_increment: int = 1,
+ video_step_interval: int = 1_000_000,
+ max_video_length: int = 3600,
+ ) -> None:
+ super().__init__(env)
+ self.video_path_prefix = video_path_prefix
+ self.step_increment = step_increment
+ self.video_step_interval = video_step_interval
+ self.max_video_length = max_video_length
+ self.total_steps = 0
+ self.next_record_video_step = 0
+ self.video_recorder = None
+ self.recorded_frames = 0
+
+ def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
+ obs, rew, done, info = self.env.step(action)
+ self.total_steps += self.step_increment
+ # Using first env to record episodes
+ if self.video_recorder:
+ self.video_recorder.capture_frame()
+ self.recorded_frames += 1
+ if info.get("episode"):
+ episode_info = {
+ k: v.item() if hasattr(v, "item") else v
+ for k, v in info["episode"].items()
+ }
+ self.video_recorder.metadata["episode"] = episode_info
+ if self.recorded_frames > self.max_video_length:
+ self._close_video_recorder()
+ return obs, rew, done, info
+
+ def reset(self, **kwargs) -> ObsType:
+ obs = self.env.reset(**kwargs)
+ if self.video_recorder:
+ self._close_video_recorder()
+ elif self.total_steps >= self.next_record_video_step:
+ self._start_video_recorder()
+ return obs
+
+ def _start_video_recorder(self) -> None:
+ self._close_video_recorder()
+
+ video_path = f"{self.video_path_prefix}-{self.next_record_video_step}"
+ self.video_recorder = VideoRecorder(
+ self.env,
+ base_path=video_path,
+ metadata={"step": self.total_steps},
+ )
+
+ self.video_recorder.capture_frame()
+ self.recorded_frames = 1
+ self.next_record_video_step += self.video_step_interval
+
+ def _close_video_recorder(self) -> None:
+ if self.video_recorder:
+ self.video_recorder.close()
+ self.video_recorder = None
+ self.recorded_frames = 0
diff --git a/rl_algo_impls/wrappers/episode_stats_writer.py b/rl_algo_impls/wrappers/episode_stats_writer.py
new file mode 100644
index 0000000000000000000000000000000000000000..21018e3e03e9f98a3948df5a93c889bb0e17d42f
--- /dev/null
+++ b/rl_algo_impls/wrappers/episode_stats_writer.py
@@ -0,0 +1,70 @@
+import numpy as np
+
+from collections import deque
+from torch.utils.tensorboard.writer import SummaryWriter
+from typing import Any, Dict, List
+
+from rl_algo_impls.shared.stats import Episode, EpisodesStats
+from rl_algo_impls.wrappers.vectorable_wrapper import (
+ VecotarableWrapper,
+ VecEnvStepReturn,
+ VecEnvObs,
+)
+
+
+class EpisodeStatsWriter(VecotarableWrapper):
+ def __init__(
+ self,
+ env,
+ tb_writer: SummaryWriter,
+ training: bool = True,
+ rolling_length=100,
+ ):
+ super().__init__(env)
+ self.training = training
+ self.tb_writer = tb_writer
+ self.rolling_length = rolling_length
+ self.episodes = deque(maxlen=rolling_length)
+ self.total_steps = 0
+ self.episode_cnt = 0
+ self.last_episode_cnt_print = 0
+
+ def step(self, actions: np.ndarray) -> VecEnvStepReturn:
+ obs, rews, dones, infos = self.env.step(actions)
+ self._record_stats(infos)
+ return obs, rews, dones, infos
+
+ # Support for stable_baselines3.common.vec_env.VecEnvWrapper
+ def step_wait(self) -> VecEnvStepReturn:
+ obs, rews, dones, infos = self.env.step_wait()
+ self._record_stats(infos)
+ return obs, rews, dones, infos
+
+ def _record_stats(self, infos: List[Dict[str, Any]]) -> None:
+ self.total_steps += getattr(self.env, "num_envs", 1)
+ step_episodes = []
+ for info in infos:
+ ep_info = info.get("episode")
+ if ep_info:
+ episode = Episode(ep_info["r"], ep_info["l"])
+ step_episodes.append(episode)
+ self.episodes.append(episode)
+ if step_episodes:
+ tag = "train" if self.training else "eval"
+ step_stats = EpisodesStats(step_episodes, simple=True)
+ step_stats.write_to_tensorboard(self.tb_writer, tag, self.total_steps)
+ rolling_stats = EpisodesStats(self.episodes)
+ rolling_stats.write_to_tensorboard(
+ self.tb_writer, f"{tag}_rolling", self.total_steps
+ )
+ self.episode_cnt += len(step_episodes)
+ if self.episode_cnt >= self.last_episode_cnt_print + self.rolling_length:
+ print(
+ f"Episode: {self.episode_cnt} | "
+ f"Steps: {self.total_steps} | "
+ f"{rolling_stats}"
+ )
+ self.last_episode_cnt_print += self.rolling_length
+
+ def reset(self) -> VecEnvObs:
+ return self.env.reset()
diff --git a/rl_algo_impls/wrappers/initial_step_truncate_wrapper.py b/rl_algo_impls/wrappers/initial_step_truncate_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9a30dd781304251590562168a97875c20e42dc8
--- /dev/null
+++ b/rl_algo_impls/wrappers/initial_step_truncate_wrapper.py
@@ -0,0 +1,27 @@
+import gym
+import numpy as np
+
+from typing import Any, Dict, Tuple, Union
+
+from rl_algo_impls.wrappers.vectorable_wrapper import VecotarableWrapper
+
+ObsType = Union[np.ndarray, dict]
+ActType = Union[int, float, np.ndarray, dict]
+
+
+class InitialStepTruncateWrapper(VecotarableWrapper):
+ def __init__(self, env: gym.Env, initial_steps_to_truncate: int) -> None:
+ super().__init__(env)
+ self.initial_steps_to_truncate = initial_steps_to_truncate
+ self.initialized = initial_steps_to_truncate == 0
+ self.steps = 0
+
+ def step(self, action: ActType) -> Tuple[ObsType, float, bool, Dict[str, Any]]:
+ obs, rew, done, info = self.env.step(action)
+ if not self.initialized:
+ self.steps += 1
+ if self.steps >= self.initial_steps_to_truncate:
+ print(f"Truncation at {self.steps} steps")
+ done = True
+ self.initialized = True
+ return obs, rew, done, info
diff --git a/rl_algo_impls/wrappers/is_vector_env.py b/rl_algo_impls/wrappers/is_vector_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd994c577808c0201dfadb6f2e6c522e250a5970
--- /dev/null
+++ b/rl_algo_impls/wrappers/is_vector_env.py
@@ -0,0 +1,13 @@
+from typing import Any
+
+from rl_algo_impls.wrappers.vectorable_wrapper import VecotarableWrapper
+
+
+class IsVectorEnv(VecotarableWrapper):
+ """
+ Override to set properties to match gym.vector.VectorEnv
+ """
+
+ def __init__(self, env: Any) -> None:
+ super().__init__(env)
+ self.is_vector_env = True
diff --git a/rl_algo_impls/wrappers/no_reward_timeout.py b/rl_algo_impls/wrappers/no_reward_timeout.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ff13e761d5f27b9f1fb262b9b31118932281cb8
--- /dev/null
+++ b/rl_algo_impls/wrappers/no_reward_timeout.py
@@ -0,0 +1,65 @@
+import gym
+import numpy as np
+
+from typing import Optional, Tuple, Union
+
+from rl_algo_impls.wrappers.vectorable_wrapper import VecotarableWrapper
+
+ObsType = Union[np.ndarray, dict]
+ActType = Union[int, float, np.ndarray, dict]
+
+
+class NoRewardTimeout(VecotarableWrapper):
+ def __init__(
+ self, env: gym.Env, n_timeout_steps: int, n_fire_steps: Optional[int] = None
+ ) -> None:
+ super().__init__(env)
+ self.n_timeout_steps = n_timeout_steps
+ self.n_fire_steps = n_fire_steps
+
+ self.fire_act = None
+ if n_fire_steps is not None:
+ action_meanings = env.unwrapped.get_action_meanings()
+ assert "FIRE" in action_meanings
+ self.fire_act = action_meanings.index("FIRE")
+
+ self.steps_since_reward = 0
+
+ self.episode_score = 0
+ self.episode_step_idx = 0
+
+ def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
+ if self.steps_since_reward == self.n_fire_steps:
+ assert self.fire_act is not None
+ self.print_intervention("Force fire action")
+ action = self.fire_act
+ obs, rew, done, info = self.env.step(action)
+
+ self.episode_score += rew
+ self.episode_step_idx += 1
+
+ if rew != 0 or done:
+ self.steps_since_reward = 0
+ else:
+ self.steps_since_reward += 1
+ if self.steps_since_reward >= self.n_timeout_steps:
+ self.print_intervention("Early terminate")
+ done = True
+
+ return obs, rew, done, info
+
+ def reset(self, **kwargs) -> ObsType:
+ self._reset_state()
+ return self.env.reset(**kwargs)
+
+ def _reset_state(self) -> None:
+ self.steps_since_reward = 0
+ self.episode_score = 0
+ self.episode_step_idx = 0
+
+ def print_intervention(self, tag: str) -> None:
+ print(
+ f"{self.__class__.__name__}: {tag} | "
+ f"Score: {self.episode_score} | "
+ f"Length: {self.episode_step_idx}"
+ )
diff --git a/rl_algo_impls/wrappers/noop_env_seed.py b/rl_algo_impls/wrappers/noop_env_seed.py
new file mode 100644
index 0000000000000000000000000000000000000000..b013b960ad5fc0d33b2f3e88b0e91ffc861ba0c8
--- /dev/null
+++ b/rl_algo_impls/wrappers/noop_env_seed.py
@@ -0,0 +1,12 @@
+from typing import List, Optional
+
+from rl_algo_impls.wrappers.vectorable_wrapper import VecotarableWrapper
+
+
+class NoopEnvSeed(VecotarableWrapper):
+ """
+ Wrapper to stop a seed call going to the underlying environment.
+ """
+
+ def seed(self, seed: Optional[int] = None) -> Optional[List[int]]:
+ return None
diff --git a/rl_algo_impls/wrappers/normalize.py b/rl_algo_impls/wrappers/normalize.py
new file mode 100644
index 0000000000000000000000000000000000000000..e48288f450b0ec284b405261ddbd22d8ff3bbe10
--- /dev/null
+++ b/rl_algo_impls/wrappers/normalize.py
@@ -0,0 +1,140 @@
+import gym
+import numpy as np
+
+from numpy.typing import NDArray
+from typing import Tuple
+
+from rl_algo_impls.wrappers.vectorable_wrapper import (
+ VecotarableWrapper,
+ single_observation_space,
+)
+
+
+class RunningMeanStd:
+ def __init__(self, episilon: float = 1e-4, shape: Tuple[int, ...] = ()) -> None:
+ self.mean = np.zeros(shape, np.float64)
+ self.var = np.ones(shape, np.float64)
+ self.count = episilon
+
+ def update(self, x: NDArray) -> None:
+ batch_mean = np.mean(x, axis=0)
+ batch_var = np.var(x, axis=0)
+ batch_count = x.shape[0]
+
+ delta = batch_mean - self.mean
+ total_count = self.count + batch_count
+
+ self.mean += delta * batch_count / total_count
+
+ m_a = self.var * self.count
+ m_b = batch_var * batch_count
+ M2 = m_a + m_b + np.square(delta) * self.count * batch_count / total_count
+ self.var = M2 / total_count
+ self.count = total_count
+
+
+class NormalizeObservation(VecotarableWrapper):
+ def __init__(
+ self,
+ env: gym.Env,
+ training: bool = True,
+ epsilon: float = 1e-8,
+ clip: float = 10.0,
+ ) -> None:
+ super().__init__(env)
+ self.rms = RunningMeanStd(shape=single_observation_space(env).shape)
+ self.training = training
+ self.epsilon = epsilon
+ self.clip = clip
+
+ def step(self, action):
+ obs, reward, done, info = self.env.step(action)
+ return self.normalize(obs), reward, done, info
+
+ def reset(self, **kwargs):
+ obs = self.env.reset(**kwargs)
+ return self.normalize(obs)
+
+ def normalize(self, obs: NDArray) -> NDArray:
+ obs_array = np.array([obs]) if not self.is_vector_env else obs
+ if self.training:
+ self.rms.update(obs_array)
+ normalized = np.clip(
+ (obs_array - self.rms.mean) / np.sqrt(self.rms.var + self.epsilon),
+ -self.clip,
+ self.clip,
+ )
+ return normalized[0] if not self.is_vector_env else normalized
+
+ def save(self, path: str) -> None:
+ np.savez_compressed(
+ path,
+ mean=self.rms.mean,
+ var=self.rms.var,
+ count=self.rms.count,
+ )
+
+ def load(self, path: str) -> None:
+ data = np.load(path)
+ self.rms.mean = data["mean"]
+ self.rms.var = data["var"]
+ self.rms.count = data["count"]
+
+
+class NormalizeReward(VecotarableWrapper):
+ def __init__(
+ self,
+ env: gym.Env,
+ training: bool = True,
+ gamma: float = 0.99,
+ epsilon: float = 1e-8,
+ clip: float = 10.0,
+ ) -> None:
+ super().__init__(env)
+ self.rms = RunningMeanStd(shape=())
+ self.training = training
+ self.gamma = gamma
+ self.epsilon = epsilon
+ self.clip = clip
+
+ self.returns = np.zeros(self.num_envs)
+
+ def step(self, action):
+ obs, reward, done, info = self.env.step(action)
+
+ if not self.is_vector_env:
+ reward = np.array([reward])
+ reward = self.normalize(reward)
+ if not self.is_vector_env:
+ reward = reward[0]
+
+ dones = done if self.is_vector_env else np.array([done])
+ self.returns[dones] = 0
+
+ return obs, reward, done, info
+
+ def reset(self, **kwargs):
+ self.returns = np.zeros(self.num_envs)
+ return self.env.reset(**kwargs)
+
+ def normalize(self, rewards):
+ if self.training:
+ self.returns = self.returns * self.gamma + rewards
+ self.rms.update(self.returns)
+ return np.clip(
+ rewards / np.sqrt(self.rms.var + self.epsilon), -self.clip, self.clip
+ )
+
+ def save(self, path: str) -> None:
+ np.savez_compressed(
+ path,
+ mean=self.rms.mean,
+ var=self.rms.var,
+ count=self.rms.count,
+ )
+
+ def load(self, path: str) -> None:
+ data = np.load(path)
+ self.rms.mean = data["mean"]
+ self.rms.var = data["var"]
+ self.rms.count = data["count"]
diff --git a/rl_algo_impls/wrappers/sync_vector_env_render_compat.py b/rl_algo_impls/wrappers/sync_vector_env_render_compat.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8e9f7456f105fc9606e5479831b0c67263a7bc3
--- /dev/null
+++ b/rl_algo_impls/wrappers/sync_vector_env_render_compat.py
@@ -0,0 +1,31 @@
+import numpy as np
+
+from gym.vector.sync_vector_env import SyncVectorEnv
+from stable_baselines3.common.vec_env.base_vec_env import tile_images
+from typing import Optional
+
+from rl_algo_impls.wrappers.vectorable_wrapper import (
+ VecotarableWrapper,
+)
+
+
+class SyncVectorEnvRenderCompat(VecotarableWrapper):
+ def __init__(self, env) -> None:
+ super().__init__(env)
+
+ def render(self, mode: str = "human") -> Optional[np.ndarray]:
+ base_env = self.env.unwrapped
+ if isinstance(base_env, SyncVectorEnv):
+ imgs = [env.render(mode="rgb_array") for env in base_env.envs]
+ bigimg = tile_images(imgs)
+ if mode == "human":
+ import cv2
+
+ cv2.imshow("vecenv", bigimg[:, :, ::-1])
+ cv2.waitKey(1)
+ elif mode == "rgb_array":
+ return bigimg
+ else:
+ raise NotImplemented(f"Render mode {mode} is not supported")
+ else:
+ return self.env.render(mode=mode)
diff --git a/rl_algo_impls/wrappers/transpose_image_observation.py b/rl_algo_impls/wrappers/transpose_image_observation.py
new file mode 100644
index 0000000000000000000000000000000000000000..7076c9146fa28cd266a2466b58ac6a6b6555d59b
--- /dev/null
+++ b/rl_algo_impls/wrappers/transpose_image_observation.py
@@ -0,0 +1,34 @@
+import gym
+import numpy as np
+
+from gym import ObservationWrapper
+from gym.spaces import Box
+
+
+class TransposeImageObservation(ObservationWrapper):
+ def __init__(self, env: gym.Env) -> None:
+ super().__init__(env)
+
+ assert isinstance(env.observation_space, Box)
+
+ obs_space = env.observation_space
+ axes = tuple(i for i in range(len(obs_space.shape)))
+ self._transpose_axes = axes[:-3] + (axes[-1],) + axes[-3:-1]
+
+ self.observation_space = Box(
+ low=np.transpose(obs_space.low, axes=self._transpose_axes),
+ high=np.transpose(obs_space.high, axes=self._transpose_axes),
+ shape=[obs_space.shape[idx] for idx in self._transpose_axes],
+ dtype=obs_space.dtype,
+ )
+
+ def observation(self, obs: np.ndarray) -> np.ndarray:
+ full_shape = obs.shape
+ obs_shape = self.observation_space.shape
+ addl_dims = len(full_shape) - len(obs_shape)
+ if addl_dims > 0:
+ transpose_axes = list(range(addl_dims))
+ transpose_axes.extend(a + addl_dims for a in self._transpose_axes)
+ else:
+ transpose_axes = self._transpose_axes
+ return np.transpose(obs, axes=transpose_axes)
diff --git a/rl_algo_impls/wrappers/vec_episode_recorder.py b/rl_algo_impls/wrappers/vec_episode_recorder.py
new file mode 100644
index 0000000000000000000000000000000000000000..d86907ab71c3c9930a04fc9c1c6d51fbc083cf54
--- /dev/null
+++ b/rl_algo_impls/wrappers/vec_episode_recorder.py
@@ -0,0 +1,55 @@
+import numpy as np
+
+from gym.wrappers.monitoring.video_recorder import VideoRecorder
+
+from rl_algo_impls.wrappers.vectorable_wrapper import (
+ VecotarableWrapper,
+ VecEnvObs,
+ VecEnvStepReturn,
+)
+
+
+class VecEpisodeRecorder(VecotarableWrapper):
+ def __init__(self, env, base_path: str, max_video_length: int = 3600):
+ super().__init__(env)
+ self.base_path = base_path
+ self.max_video_length = max_video_length
+ self.video_recorder = None
+ self.recorded_frames = 0
+
+ def step(self, actions: np.ndarray) -> VecEnvStepReturn:
+ obs, rew, dones, infos = self.env.step(actions)
+ # Using first env to record episodes
+ if self.video_recorder:
+ self.video_recorder.capture_frame()
+ self.recorded_frames += 1
+ if dones[0] and infos[0].get("episode"):
+ episode_info = {
+ k: v.item() if hasattr(v, "item") else v
+ for k, v in infos[0]["episode"].items()
+ }
+ self.video_recorder.metadata["episode"] = episode_info
+ if dones[0] or self.recorded_frames > self.max_video_length:
+ self._close_video_recorder()
+ return obs, rew, dones, infos
+
+ def reset(self) -> VecEnvObs:
+ obs = self.env.reset()
+ self._start_video_recorder()
+ return obs
+
+ def _start_video_recorder(self) -> None:
+ self._close_video_recorder()
+
+ self.video_recorder = VideoRecorder(
+ self.env,
+ base_path=self.base_path,
+ )
+
+ self.video_recorder.capture_frame()
+ self.recorded_frames = 1
+
+ def _close_video_recorder(self) -> None:
+ if self.video_recorder:
+ self.video_recorder.close()
+ self.video_recorder = None
diff --git a/rl_algo_impls/wrappers/vectorable_wrapper.py b/rl_algo_impls/wrappers/vectorable_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..03df8d1400ab84242353f5dc1288a4394158b941
--- /dev/null
+++ b/rl_algo_impls/wrappers/vectorable_wrapper.py
@@ -0,0 +1,46 @@
+import numpy as np
+from gym import Env, Space, Wrapper
+
+from stable_baselines3.common.vec_env import VecEnv as SB3VecEnv
+from typing import Dict, List, Optional, Type, TypeVar, Tuple, Union
+
+VecEnvObs = Union[np.ndarray, Dict[str, np.ndarray], Tuple[np.ndarray, ...]]
+VecEnvStepReturn = Tuple[VecEnvObs, np.ndarray, np.ndarray, List[Dict]]
+
+
+class VecotarableWrapper(Wrapper):
+ def __init__(self, env: Env) -> None:
+ super().__init__(env)
+ self.num_envs = getattr(env, "num_envs", 1)
+ self.is_vector_env = getattr(env, "is_vector_env", False)
+ self.single_observation_space = single_observation_space(env)
+ self.single_action_space = single_action_space(env)
+
+ def step(self, action) -> VecEnvStepReturn:
+ return self.env.step(action)
+
+ def reset(self) -> VecEnvObs:
+ return self.env.reset()
+
+
+VecEnv = Union[VecotarableWrapper, SB3VecEnv]
+
+
+def single_observation_space(env: Union[VecEnv, Env]) -> Space:
+ return getattr(env, "single_observation_space", env.observation_space)
+
+
+def single_action_space(env: Union[VecEnv, Env]) -> Space:
+ return getattr(env, "single_action_space", env.action_space)
+
+
+W = TypeVar("W", bound=Wrapper)
+
+
+def find_wrapper(env: VecEnv, wrapper_class: Type[W]) -> Optional[W]:
+ current = env
+ while current and current != current.unwrapped:
+ if isinstance(current, wrapper_class):
+ return current
+ current = getattr(current, "env")
+ return None
diff --git a/rl_algo_impls/wrappers/video_compat_wrapper.py b/rl_algo_impls/wrappers/video_compat_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..8fec197d2da3fdc36c6434c3bc6f3fbb9a2d6610
--- /dev/null
+++ b/rl_algo_impls/wrappers/video_compat_wrapper.py
@@ -0,0 +1,15 @@
+import gym
+import numpy as np
+
+from rl_algo_impls.wrappers.vectorable_wrapper import VecotarableWrapper
+
+
+class VideoCompatWrapper(VecotarableWrapper):
+ def __init__(self, env: gym.Env) -> None:
+ super().__init__(env)
+
+ def render(self, mode="human", **kwargs):
+ r = super().render(mode=mode, **kwargs)
+ if mode == "rgb_array" and isinstance(r, np.ndarray) and r.dtype != np.uint8:
+ r = r.astype(np.uint8)
+ return r
diff --git a/saved_models/ppo-AntBulletEnv-v0-S1-best/model.pth b/saved_models/ppo-AntBulletEnv-v0-S1-best/model.pth
new file mode 100644
index 0000000000000000000000000000000000000000..4816f94ae9c858337eee202fc036c4cfd8893527
--- /dev/null
+++ b/saved_models/ppo-AntBulletEnv-v0-S1-best/model.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f4469d1da7d122c13b25bdc93c5884abb452872f0ba3cce2bd29b4ae48b74bd1
+size 598880
diff --git a/saved_models/ppo-AntBulletEnv-v0-S1-best/norm_obs.npz b/saved_models/ppo-AntBulletEnv-v0-S1-best/norm_obs.npz
new file mode 100644
index 0000000000000000000000000000000000000000..09e0df79049118b7edaff5352fd73492e75a73ce
--- /dev/null
+++ b/saved_models/ppo-AntBulletEnv-v0-S1-best/norm_obs.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:597cb0cd06cc5d477cfb99e37915579dd4927d91718740d5e7258e04e1d0856d
+size 1044
diff --git a/saved_models/ppo-AntBulletEnv-v0-S1-best/norm_reward.npz b/saved_models/ppo-AntBulletEnv-v0-S1-best/norm_reward.npz
new file mode 100644
index 0000000000000000000000000000000000000000..1be6b9764034b6205d649a6ff474f92c04a33762
--- /dev/null
+++ b/saved_models/ppo-AntBulletEnv-v0-S1-best/norm_reward.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:888f37626d97c45ef2d3632362a0ac4db433734153a9909ada67e36d4afda429
+size 581
diff --git a/scripts/benchmark.sh b/scripts/benchmark.sh
new file mode 100644
index 0000000000000000000000000000000000000000..edb726cfd0e2e8fcc7cce9aa5cfd03bad82492fd
--- /dev/null
+++ b/scripts/benchmark.sh
@@ -0,0 +1,67 @@
+while test $# != 0
+do
+ case "$1" in
+ -a) algos=$2 ;;
+ -j) n_jobs=$2 ;;
+ -p) project_name=$2 ;;
+ -s) seeds=$2 ;;
+ -e) envs=$2 ;;
+ --procgen) procgen=t
+ esac
+ shift
+done
+
+algos="${algos:-ppo}"
+n_jobs="${n_jobs:-6}"
+project_name="${project_name:-rl-algo-impls-benchmarks}"
+seeds="${seeds:-1 2 3}"
+
+DISCRETE_ENVS=(
+ # Basic
+ "CartPole-v1"
+ "MountainCar-v0"
+ "Acrobot-v1"
+ "LunarLander-v2"
+ # Atari
+ "PongNoFrameskip-v4"
+ "BreakoutNoFrameskip-v4"
+ "SpaceInvadersNoFrameskip-v4"
+ "QbertNoFrameskip-v4"
+)
+BOX_ENVS=(
+ # Basic
+ "MountainCarContinuous-v0"
+ "BipedalWalker-v3"
+ # PyBullet
+ "HalfCheetahBulletEnv-v0"
+ "AntBulletEnv-v0"
+ "HopperBulletEnv-v0"
+ "Walker2DBulletEnv-v0"
+ # CarRacing
+ "CarRacing-v0"
+)
+
+for algo in $(echo $algos); do
+ if [ "$algo" = "dqn" ]; then
+ BENCHMARK_ENVS="${DISCRETE_ENVS[*]}"
+ else
+ BENCHMARK_ENVS="${DISCRETE_ENVS[*]} ${BOX_ENVS[*]}"
+ fi
+ algo_envs=$envs
+ if [ -z $algo_envs ]; then
+ echo "-e unspecified; therefore, benchmark training on ${BENCHMARK_ENVS[*]}"
+ algo_envs=${BENCHMARK_ENVS[*]}
+ fi
+
+ PROCGEN_ENVS=(
+ "procgen-coinrun-easy"
+ "procgen-starpilot-easy"
+ "procgen-bossfight-easy"
+ "procgen-bigfish-easy"
+ )
+ if [ "$procgen" = "t" ]; then
+ algo_envs=${PROCGEN_ENVS[*]}
+ fi
+
+ bash scripts/train_loop.sh -a $algo -e "$algo_envs" -p $project_name -s "$seeds" | xargs -I CMD -P $n_jobs bash -c CMD
+done
\ No newline at end of file
diff --git a/scripts/setup.sh b/scripts/setup.sh
new file mode 100644
index 0000000000000000000000000000000000000000..46862d85eb13670353bca988b85207d3f68bb023
--- /dev/null
+++ b/scripts/setup.sh
@@ -0,0 +1,10 @@
+sudo apt update
+sudo apt install -y python-opengl
+sudo apt install -y ffmpeg
+sudo apt install -y xvfb
+sudo apt install -y swig
+
+python3 -m pip install --upgrade pip
+pip install --upgrade torch torchvision torchaudio
+
+python -m pip install --upgrade '.[test,procgen]'
\ No newline at end of file
diff --git a/scripts/tags_benchmark.sh b/scripts/tags_benchmark.sh
new file mode 100644
index 0000000000000000000000000000000000000000..cfde37478f2663ccc41681906ae25cc550bcb814
--- /dev/null
+++ b/scripts/tags_benchmark.sh
@@ -0,0 +1 @@
+echo "benchmark_$(git rev-parse --short HEAD) host_$(hostname)"
\ No newline at end of file
diff --git a/scripts/train_loop.sh b/scripts/train_loop.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e2e13c176f7227b259688b6cb0265fa2b1cb7167
--- /dev/null
+++ b/scripts/train_loop.sh
@@ -0,0 +1,18 @@
+while getopts a:e:s:p: flag
+do
+ case "${flag}" in
+ a) algo=${OPTARG};;
+ e) envs=${OPTARG};;
+ s) seeds=${OPTARG};;
+ p) project_name=${OPTARG};;
+ esac
+done
+
+WANDB_TAGS=$(bash scripts/tags_benchmark.sh)
+project_name="${project_name:-rl-algo-impls-benchmarks}"
+seeds="${seeds:-1 2 3}"
+for env in $(echo $envs); do
+ for seed in $seeds; do
+ echo python train.py --algo $algo --env $env --seed $seed --pool-size 1 --wandb-tags $WANDB_TAGS --wandb-project-name $project_name --virtual-display
+ done
+done
diff --git a/scripts/tuning.sh b/scripts/tuning.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9c886b832ce125b32497129ecf81bbf8345ee113
--- /dev/null
+++ b/scripts/tuning.sh
@@ -0,0 +1,32 @@
+while getopts a:e:j:n:s:i: flag
+do
+ case "${flag}" in
+ a) algo=${OPTARG};;
+ e) env=${OPTARG};;
+ j) n_jobs=${OPTARG};;
+ n) study_name=${OPTARG};;
+ s) seeds=${OPTARG};;
+ i) increment=${OPTARG};;
+ esac
+done
+
+TZ="America/Los_Angeles"
+NOW=$(date +"%Y-%m-%dT%H:%M:%S")
+study_name="${study_name:-$algo-$env-$NOW}"
+STORAGE_PATH="sqlite:///runs/tuning.db"
+increment="${increment:-100}"
+
+mkdir -p runs
+optuna create-study --study-name $study_name --storage $STORAGE_PATH --direction maximize --skip-if-exists
+
+optimize () {
+ for ((j=$increment;j<=n_jobs*100+$increment;j+=100)); do
+ seed=()
+ for ((s=0;s Dict[str, Any]:
- d = vars(args).copy()
- d.update(
- {
- "algo": algo,
- "env": env,
- "seed": seed,
- }
- )
- return d
-
+from rl_algo_impls.train import train
if __name__ == "__main__":
- parser = base_parser()
- parser.add_argument(
- "--wandb-project-name",
- type=str,
- default="rl-algo-impls",
- help="WandB project namme to upload training data to. If none, won't upload.",
- )
- parser.add_argument(
- "--wandb-entity",
- type=str,
- default=None,
- help="WandB team of project. None uses default entity",
- )
- parser.add_argument(
- "--wandb-tags", type=str, nargs="*", help="WandB tags to add to run"
- )
- parser.add_argument(
- "--pool-size", type=int, default=1, help="Simultaneous training jobs to run"
- )
- parser.add_argument(
- "--virtual-display",
- action="store_true",
- help="Whether to create a virtual display for video rendering",
- )
- parser.set_defaults(algo="ppo", env="CartPole-v1", seed=1)
- args = parser.parse_args()
- print(args)
-
- if args.virtual_display:
- from pyvirtualdisplay import Display
-
- virtual_display = Display(visible=0, size=(1400, 900))
- virtual_display.start()
- delattr(args, "virtual_display")
-
- # pool_size isn't a TrainArg so must be removed from args
- pool_size = args.pool_size
- delattr(args, "pool_size")
-
- algos = args.algo if isinstance(args.algo, list) else [args.algo]
- envs = args.env if isinstance(args.env, list) else [args.env]
- seeds = args.seed if isinstance(args.seed, list) else [args.seed]
- if all(len(arg) == 1 for arg in [algos, envs, seeds]):
- train(TrainArgs(**args_dict(algos[0], envs[0], seeds[0], args)))
- else:
- # Force a new process for each job to get around wandb not allowing more than one
- # wandb.tensorboard.patch call per process.
- with Pool(pool_size, maxtasksperchild=1) as p:
- train_args = [
- TrainArgs(**args_dict(algo, env, seed, args))
- for algo, env, seed in itertools.product(algos, envs, seeds)
- ]
- p.map(train, train_args)
+ train()