sgoodfriend
commited on
Commit
·
9d36d7e
1
Parent(s):
0589ae3
PPO playing MicrortsDefeatCoacAIShaped-v3 from https://github.com/sgoodfriend/rl-algo-impls/tree/9ba0ab50894e5cea207289f4af8b53cbafa47748
Browse files- README.md +16 -15
- pyproject.toml +1 -1
- replay.meta.json +1 -1
- replay.mp4 +0 -0
- rl_algo_impls/a2c/optimize.py +9 -5
- rl_algo_impls/huggingface_publish.py +1 -1
- rl_algo_impls/hyperparams/a2c.yml +13 -12
- rl_algo_impls/hyperparams/ppo.yml +5 -5
- rl_algo_impls/optimize.py +2 -2
- rl_algo_impls/ppo/ppo.py +1 -1
- rl_algo_impls/runner/config.py +1 -0
- rl_algo_impls/runner/evaluate.py +1 -1
- rl_algo_impls/runner/selfplay_evaluate.py +142 -0
- rl_algo_impls/runner/train.py +6 -8
- rl_algo_impls/selfplay_enjoy.py +53 -0
- rl_algo_impls/shared/actor/state_dependent_noise.py +5 -5
- rl_algo_impls/shared/callbacks/eval_callback.py +13 -1
- rl_algo_impls/shared/policy/actor_critic.py +1 -1
- rl_algo_impls/shared/vec_env/make_env.py +7 -6
- rl_algo_impls/shared/vec_env/microrts.py +6 -2
- rl_algo_impls/shared/vec_env/procgen.py +2 -0
- rl_algo_impls/shared/vec_env/vec_env.py +2 -0
- rl_algo_impls/wrappers/action_mask_wrapper.py +2 -2
- rl_algo_impls/wrappers/microrts_stats_recorder.py +26 -2
- rl_algo_impls/wrappers/self_play_wrapper.py +42 -3
- rl_algo_impls/wrappers/vec_episode_recorder.py +16 -5
- saved_models/ppo-Microrts-selfplay-unet-decay-S1-best/model.pth +1 -1
- selfplay_enjoy.py +4 -0
README.md
CHANGED
@@ -10,7 +10,7 @@ model-index:
|
|
10 |
results:
|
11 |
- metrics:
|
12 |
- type: mean_reward
|
13 |
-
value: 0.
|
14 |
name: mean_reward
|
15 |
task:
|
16 |
type: reinforcement-learning
|
@@ -27,13 +27,13 @@ All models trained at this commit can be found at https://api.wandb.ai/links/sgo
|
|
27 |
|
28 |
## Training Results
|
29 |
|
30 |
-
This model was trained from 3 trainings of **PPO** agents using different initial seeds. These agents were trained by checking out [
|
31 |
|
32 |
| algo | env | seed | reward_mean | reward_std | eval_episodes | best | wandb_url |
|
33 |
|:-------|:------------------------------|-------:|--------------:|-------------:|----------------:|:-------|:-----------------------------------------------------------------------------|
|
34 |
-
| ppo | MicrortsDefeatCoacAIShaped-v3 | 1 | 0.
|
35 |
-
| ppo | MicrortsDefeatCoacAIShaped-v3 | 2 | 0.
|
36 |
-
| ppo | MicrortsDefeatCoacAIShaped-v3 | 3 | 0.
|
37 |
|
38 |
|
39 |
### Prerequisites: Weights & Biases (WandB)
|
@@ -53,10 +53,10 @@ login`.
|
|
53 |
Note: While the model state dictionary and hyperaparameters are saved, the latest
|
54 |
implementation could be sufficiently different to not be able to reproduce similar
|
55 |
results. You might need to checkout the commit the agent was trained on:
|
56 |
-
[
|
57 |
```
|
58 |
# Downloads the model, sets hyperparameters, and runs agent for 3 episodes
|
59 |
-
python enjoy.py --wandb-run-path=sgoodfriend/rl-algo-impls-benchmarks/
|
60 |
```
|
61 |
|
62 |
Setup hasn't been completely worked out yet, so you might be best served by using Google
|
@@ -68,7 +68,7 @@ notebook.
|
|
68 |
|
69 |
## Training
|
70 |
If you want the highest chance to reproduce these results, you'll want to checkout the
|
71 |
-
commit the agent was trained on: [
|
72 |
training is deterministic, different hardware will give different results.
|
73 |
|
74 |
```
|
@@ -107,6 +107,7 @@ close and has some additional data:
|
|
107 |
```
|
108 |
additional_keys_to_log:
|
109 |
- microrts_stats
|
|
|
110 |
algo: ppo
|
111 |
algo_hyperparams:
|
112 |
batch_size: 3072
|
@@ -129,7 +130,7 @@ env_hyperparams:
|
|
129 |
make_kwargs:
|
130 |
map_paths:
|
131 |
- maps/16x16/basesWorkers16x16.xml
|
132 |
-
max_steps:
|
133 |
num_selfplay_envs: 36
|
134 |
render_theme: 2
|
135 |
reward_weight:
|
@@ -142,10 +143,10 @@ env_hyperparams:
|
|
142 |
n_envs: 24
|
143 |
self_play_kwargs:
|
144 |
num_old_policies: 12
|
145 |
-
save_steps:
|
146 |
-
swap_steps:
|
147 |
swap_window_size: 4
|
148 |
-
window:
|
149 |
env_id: MicrortsDefeatCoacAIShaped-v3
|
150 |
eval_hyperparams:
|
151 |
deterministic: false
|
@@ -199,9 +200,9 @@ wandb_entity: null
|
|
199 |
wandb_group: null
|
200 |
wandb_project_name: rl-algo-impls-benchmarks
|
201 |
wandb_tags:
|
202 |
-
-
|
203 |
-
- host_192-9-
|
204 |
-
-
|
205 |
- v0.0.9
|
206 |
|
207 |
```
|
|
|
10 |
results:
|
11 |
- metrics:
|
12 |
- type: mean_reward
|
13 |
+
value: 0.77 +/- 0.64
|
14 |
name: mean_reward
|
15 |
task:
|
16 |
type: reinforcement-learning
|
|
|
27 |
|
28 |
## Training Results
|
29 |
|
30 |
+
This model was trained from 3 trainings of **PPO** agents using different initial seeds. These agents were trained by checking out [9ba0ab5](https://github.com/sgoodfriend/rl-algo-impls/tree/9ba0ab50894e5cea207289f4af8b53cbafa47748). The best and last models were kept from each training. This submission has loaded the best models from each training, reevaluates them, and selects the best model from these latest evaluations (mean - std).
|
31 |
|
32 |
| algo | env | seed | reward_mean | reward_std | eval_episodes | best | wandb_url |
|
33 |
|:-------|:------------------------------|-------:|--------------:|-------------:|----------------:|:-------|:-----------------------------------------------------------------------------|
|
34 |
+
| ppo | MicrortsDefeatCoacAIShaped-v3 | 1 | 0.769231 | 0.638971 | 26 | * | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/a0smxvhw) |
|
35 |
+
| ppo | MicrortsDefeatCoacAIShaped-v3 | 2 | 0.692308 | 0.721602 | 26 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/8ees317u) |
|
36 |
+
| ppo | MicrortsDefeatCoacAIShaped-v3 | 3 | 0.423077 | 0.884615 | 26 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/ifj50v2t) |
|
37 |
|
38 |
|
39 |
### Prerequisites: Weights & Biases (WandB)
|
|
|
53 |
Note: While the model state dictionary and hyperaparameters are saved, the latest
|
54 |
implementation could be sufficiently different to not be able to reproduce similar
|
55 |
results. You might need to checkout the commit the agent was trained on:
|
56 |
+
[9ba0ab5](https://github.com/sgoodfriend/rl-algo-impls/tree/9ba0ab50894e5cea207289f4af8b53cbafa47748).
|
57 |
```
|
58 |
# Downloads the model, sets hyperparameters, and runs agent for 3 episodes
|
59 |
+
python enjoy.py --wandb-run-path=sgoodfriend/rl-algo-impls-benchmarks/a0smxvhw
|
60 |
```
|
61 |
|
62 |
Setup hasn't been completely worked out yet, so you might be best served by using Google
|
|
|
68 |
|
69 |
## Training
|
70 |
If you want the highest chance to reproduce these results, you'll want to checkout the
|
71 |
+
commit the agent was trained on: [9ba0ab5](https://github.com/sgoodfriend/rl-algo-impls/tree/9ba0ab50894e5cea207289f4af8b53cbafa47748). While
|
72 |
training is deterministic, different hardware will give different results.
|
73 |
|
74 |
```
|
|
|
107 |
```
|
108 |
additional_keys_to_log:
|
109 |
- microrts_stats
|
110 |
+
- microrts_results
|
111 |
algo: ppo
|
112 |
algo_hyperparams:
|
113 |
batch_size: 3072
|
|
|
130 |
make_kwargs:
|
131 |
map_paths:
|
132 |
- maps/16x16/basesWorkers16x16.xml
|
133 |
+
max_steps: 4000
|
134 |
num_selfplay_envs: 36
|
135 |
render_theme: 2
|
136 |
reward_weight:
|
|
|
143 |
n_envs: 24
|
144 |
self_play_kwargs:
|
145 |
num_old_policies: 12
|
146 |
+
save_steps: 300000
|
147 |
+
swap_steps: 6000
|
148 |
swap_window_size: 4
|
149 |
+
window: 33
|
150 |
env_id: MicrortsDefeatCoacAIShaped-v3
|
151 |
eval_hyperparams:
|
152 |
deterministic: false
|
|
|
200 |
wandb_group: null
|
201 |
wandb_project_name: rl-algo-impls-benchmarks
|
202 |
wandb_tags:
|
203 |
+
- benchmark_9ba0ab5
|
204 |
+
- host_192-9-155-233
|
205 |
+
- branch_main
|
206 |
- v0.0.9
|
207 |
|
208 |
```
|
pyproject.toml
CHANGED
@@ -26,7 +26,7 @@ dependencies = [
|
|
26 |
"stable-baselines3[extra] >= 1.7.0, < 1.8",
|
27 |
"gym[box2d] >= 0.21.0, < 0.22",
|
28 |
"pyglet == 1.5.27",
|
29 |
-
"wandb",
|
30 |
"pyvirtualdisplay",
|
31 |
"pybullet",
|
32 |
"tabulate",
|
|
|
26 |
"stable-baselines3[extra] >= 1.7.0, < 1.8",
|
27 |
"gym[box2d] >= 0.21.0, < 0.22",
|
28 |
"pyglet == 1.5.27",
|
29 |
+
"wandb == 0.13.10",
|
30 |
"pyvirtualdisplay",
|
31 |
"pybullet",
|
32 |
"tabulate",
|
replay.meta.json
CHANGED
@@ -1 +1 @@
|
|
1 |
-
{"content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version 4.2.7-0ubuntu0.1 Copyright (c) 2000-2022 the FFmpeg developers\\nbuilt with gcc 9 (Ubuntu 9.4.0-1ubuntu1~20.04.1)\\nconfiguration: --prefix=/usr --extra-version=0ubuntu0.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-avresample --disable-filter=resample --enable-avisynth --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librsvg --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enable-libwavpack --enable-libwebp --enable-libx265 --enable-libxml2 --enable-libxvid --enable-libzmq --enable-libzvbi --enable-lv2 --enable-omx --enable-openal --enable-opencl --enable-opengl --enable-sdl2 --enable-libdc1394 --enable-libdrm --enable-libiec61883 --enable-nvenc --enable-chromaprint --enable-frei0r --enable-libx264 --enable-shared\\nlibavutil 56. 31.100 / 56. 31.100\\nlibavcodec 58. 54.100 / 58. 54.100\\nlibavformat 58. 29.100 / 58. 29.100\\nlibavdevice 58. 8.100 / 58. 8.100\\nlibavfilter 7. 57.100 / 7. 57.100\\nlibavresample 4. 0. 0 / 4. 0. 0\\nlibswscale 5. 5.100 / 5. 5.100\\nlibswresample 3. 5.100 / 3. 5.100\\nlibpostproc 55. 5.100 / 55. 5.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-f", "rawvideo", "-s:v", "640x640", "-pix_fmt", "rgb24", "-framerate", "150", "-i", "-", "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "-r", "150", "/tmp/
|
|
|
1 |
+
{"content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version 4.2.7-0ubuntu0.1 Copyright (c) 2000-2022 the FFmpeg developers\\nbuilt with gcc 9 (Ubuntu 9.4.0-1ubuntu1~20.04.1)\\nconfiguration: --prefix=/usr --extra-version=0ubuntu0.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-avresample --disable-filter=resample --enable-avisynth --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librsvg --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enable-libwavpack --enable-libwebp --enable-libx265 --enable-libxml2 --enable-libxvid --enable-libzmq --enable-libzvbi --enable-lv2 --enable-omx --enable-openal --enable-opencl --enable-opengl --enable-sdl2 --enable-libdc1394 --enable-libdrm --enable-libiec61883 --enable-nvenc --enable-chromaprint --enable-frei0r --enable-libx264 --enable-shared\\nlibavutil 56. 31.100 / 56. 31.100\\nlibavcodec 58. 54.100 / 58. 54.100\\nlibavformat 58. 29.100 / 58. 29.100\\nlibavdevice 58. 8.100 / 58. 8.100\\nlibavfilter 7. 57.100 / 7. 57.100\\nlibavresample 4. 0. 0 / 4. 0. 0\\nlibswscale 5. 5.100 / 5. 5.100\\nlibswresample 3. 5.100 / 3. 5.100\\nlibpostproc 55. 5.100 / 55. 5.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-f", "rawvideo", "-s:v", "640x640", "-pix_fmt", "rgb24", "-framerate", "150", "-i", "-", "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "-r", "150", "/tmp/tmpo54rvbdq/ppo-Microrts-selfplay-unet-decay/replay.mp4"]}, "episodes": [{"r": 1.0, "l": 740, "t": 9.992017}]}
|
replay.mp4
CHANGED
Binary files a/replay.mp4 and b/replay.mp4 differ
|
|
rl_algo_impls/a2c/optimize.py
CHANGED
@@ -1,10 +1,10 @@
|
|
1 |
-
import optuna
|
2 |
-
|
3 |
from copy import deepcopy
|
4 |
|
5 |
-
|
6 |
-
|
|
|
7 |
from rl_algo_impls.shared.policy.optimize_on_policy import sample_on_policy_hyperparams
|
|
|
8 |
from rl_algo_impls.tuning.optimize_env import sample_env_hyperparams
|
9 |
|
10 |
|
@@ -16,7 +16,11 @@ def sample_params(
|
|
16 |
hyperparams = deepcopy(base_hyperparams)
|
17 |
|
18 |
base_env_hyperparams = EnvHyperparams(**hyperparams.env_hyperparams)
|
19 |
-
env = make_eval_env(
|
|
|
|
|
|
|
|
|
20 |
|
21 |
# env_hyperparams
|
22 |
env_hyperparams = sample_env_hyperparams(trial, hyperparams.env_hyperparams, env)
|
|
|
|
|
|
|
1 |
from copy import deepcopy
|
2 |
|
3 |
+
import optuna
|
4 |
+
|
5 |
+
from rl_algo_impls.runner.config import Config, EnvHyperparams, Hyperparams
|
6 |
from rl_algo_impls.shared.policy.optimize_on_policy import sample_on_policy_hyperparams
|
7 |
+
from rl_algo_impls.shared.vec_env import make_eval_env
|
8 |
from rl_algo_impls.tuning.optimize_env import sample_env_hyperparams
|
9 |
|
10 |
|
|
|
16 |
hyperparams = deepcopy(base_hyperparams)
|
17 |
|
18 |
base_env_hyperparams = EnvHyperparams(**hyperparams.env_hyperparams)
|
19 |
+
env = make_eval_env(
|
20 |
+
base_config,
|
21 |
+
base_env_hyperparams,
|
22 |
+
override_hparams={"n_envs": 1},
|
23 |
+
)
|
24 |
|
25 |
# env_hyperparams
|
26 |
env_hyperparams = sample_env_hyperparams(trial, hyperparams.env_hyperparams, env)
|
rl_algo_impls/huggingface_publish.py
CHANGED
@@ -133,7 +133,7 @@ def publish(
|
|
133 |
make_eval_env(
|
134 |
config,
|
135 |
EnvHyperparams(**config.env_hyperparams),
|
136 |
-
|
137 |
normalize_load_path=model_path,
|
138 |
),
|
139 |
os.path.join(repo_dir_path, "replay"),
|
|
|
133 |
make_eval_env(
|
134 |
config,
|
135 |
EnvHyperparams(**config.env_hyperparams),
|
136 |
+
override_hparams={"n_envs": 1},
|
137 |
normalize_load_path=model_path,
|
138 |
),
|
139 |
os.path.join(repo_dir_path, "replay"),
|
rl_algo_impls/hyperparams/a2c.yml
CHANGED
@@ -101,31 +101,32 @@ HopperBulletEnv-v0:
|
|
101 |
CarRacing-v0:
|
102 |
n_timesteps: !!float 4e6
|
103 |
env_hyperparams:
|
104 |
-
n_envs:
|
105 |
frame_stack: 4
|
106 |
normalize: true
|
107 |
normalize_kwargs:
|
108 |
norm_obs: false
|
109 |
norm_reward: true
|
110 |
policy_hyperparams:
|
111 |
-
use_sde:
|
112 |
-
log_std_init: -
|
113 |
init_layers_orthogonal: true
|
114 |
activation_fn: tanh
|
115 |
share_features_extractor: false
|
116 |
cnn_flatten_dim: 256
|
117 |
hidden_sizes: [256]
|
118 |
algo_hyperparams:
|
119 |
-
n_steps:
|
120 |
-
learning_rate: 0.
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
vf_coef: 0.
|
126 |
-
max_grad_norm:
|
127 |
-
normalize_advantage:
|
128 |
use_rms_prop: false
|
|
|
129 |
|
130 |
_atari: &atari-defaults
|
131 |
n_timesteps: !!float 1e7
|
|
|
101 |
CarRacing-v0:
|
102 |
n_timesteps: !!float 4e6
|
103 |
env_hyperparams:
|
104 |
+
n_envs: 4
|
105 |
frame_stack: 4
|
106 |
normalize: true
|
107 |
normalize_kwargs:
|
108 |
norm_obs: false
|
109 |
norm_reward: true
|
110 |
policy_hyperparams:
|
111 |
+
use_sde: true
|
112 |
+
log_std_init: -4.839609092563
|
113 |
init_layers_orthogonal: true
|
114 |
activation_fn: tanh
|
115 |
share_features_extractor: false
|
116 |
cnn_flatten_dim: 256
|
117 |
hidden_sizes: [256]
|
118 |
algo_hyperparams:
|
119 |
+
n_steps: 64
|
120 |
+
learning_rate: 0.000018971962220405576
|
121 |
+
gamma: 0.9942776405534832
|
122 |
+
gae_lambda: 0.9549244758833236
|
123 |
+
ent_coef: 0.0000015666550584860516
|
124 |
+
ent_coef_decay: linear
|
125 |
+
vf_coef: 0.12164696385898476
|
126 |
+
max_grad_norm: 2.2574480552177127
|
127 |
+
normalize_advantage: false
|
128 |
use_rms_prop: false
|
129 |
+
sde_sample_freq: 16
|
130 |
|
131 |
_atari: &atari-defaults
|
132 |
n_timesteps: !!float 1e7
|
rl_algo_impls/hyperparams/ppo.yml
CHANGED
@@ -252,13 +252,13 @@ MicrortsRandomEnemyShapedReward3-v1-NoMask:
|
|
252 |
_microrts_ai: µrts-ai-defaults
|
253 |
<<: *microrts-defaults
|
254 |
n_timesteps: !!float 100e6
|
255 |
-
additional_keys_to_log: ["microrts_stats"]
|
256 |
env_hyperparams: µrts-ai-env-defaults
|
257 |
n_envs: 24
|
258 |
env_type: microrts
|
259 |
make_kwargs: µrts-ai-env-make-kwargs-defaults
|
260 |
num_selfplay_envs: 0
|
261 |
-
max_steps:
|
262 |
render_theme: 2
|
263 |
map_paths: [maps/16x16/basesWorkers16x16.xml]
|
264 |
reward_weight: [10.0, 1.0, 1.0, 0.2, 1.0, 4.0]
|
@@ -399,10 +399,10 @@ Microrts-selfplay-unet: µrts-selfplay-defaults
|
|
399 |
num_selfplay_envs: 36
|
400 |
self_play_kwargs:
|
401 |
num_old_policies: 12
|
402 |
-
save_steps:
|
403 |
-
swap_steps:
|
404 |
swap_window_size: 4
|
405 |
-
window:
|
406 |
eval_hyperparams: µrts-selfplay-eval-defaults
|
407 |
<<: *microrts-coacai-eval-defaults
|
408 |
env_overrides: µrts-selfplay-eval-env-overrides
|
|
|
252 |
_microrts_ai: µrts-ai-defaults
|
253 |
<<: *microrts-defaults
|
254 |
n_timesteps: !!float 100e6
|
255 |
+
additional_keys_to_log: ["microrts_stats", "microrts_results"]
|
256 |
env_hyperparams: µrts-ai-env-defaults
|
257 |
n_envs: 24
|
258 |
env_type: microrts
|
259 |
make_kwargs: µrts-ai-env-make-kwargs-defaults
|
260 |
num_selfplay_envs: 0
|
261 |
+
max_steps: 4000
|
262 |
render_theme: 2
|
263 |
map_paths: [maps/16x16/basesWorkers16x16.xml]
|
264 |
reward_weight: [10.0, 1.0, 1.0, 0.2, 1.0, 4.0]
|
|
|
399 |
num_selfplay_envs: 36
|
400 |
self_play_kwargs:
|
401 |
num_old_policies: 12
|
402 |
+
save_steps: 300000
|
403 |
+
swap_steps: 6000
|
404 |
swap_window_size: 4
|
405 |
+
window: 33
|
406 |
eval_hyperparams: µrts-selfplay-eval-defaults
|
407 |
<<: *microrts-coacai-eval-defaults
|
408 |
env_overrides: µrts-selfplay-eval-env-overrides
|
rl_algo_impls/optimize.py
CHANGED
@@ -211,7 +211,7 @@ def simple_optimize(trial: optuna.Trial, args: RunArgs, study_args: StudyArgs) -
|
|
211 |
eval_env = make_eval_env(
|
212 |
config,
|
213 |
EnvHyperparams(**config.env_hyperparams),
|
214 |
-
|
215 |
)
|
216 |
optimize_callback = OptimizeCallback(
|
217 |
policy,
|
@@ -331,7 +331,7 @@ def stepwise_optimize(
|
|
331 |
config,
|
332 |
EnvHyperparams(**config.env_hyperparams),
|
333 |
normalize_load_path=config.model_dir_path() if i > 0 else None,
|
334 |
-
|
335 |
)
|
336 |
|
337 |
start_timesteps = int(i * config.n_timesteps / study_args.n_evaluations)
|
|
|
211 |
eval_env = make_eval_env(
|
212 |
config,
|
213 |
EnvHyperparams(**config.env_hyperparams),
|
214 |
+
override_hparams={"n_envs": study_args.n_eval_envs},
|
215 |
)
|
216 |
optimize_callback = OptimizeCallback(
|
217 |
policy,
|
|
|
331 |
config,
|
332 |
EnvHyperparams(**config.env_hyperparams),
|
333 |
normalize_load_path=config.model_dir_path() if i > 0 else None,
|
334 |
+
override_hparams={"n_envs": study_args.n_eval_envs},
|
335 |
)
|
336 |
|
337 |
start_timesteps = int(i * config.n_timesteps / study_args.n_evaluations)
|
rl_algo_impls/ppo/ppo.py
CHANGED
@@ -110,7 +110,7 @@ class PPO(Algorithm):
|
|
110 |
) -> None:
|
111 |
super().__init__(policy, env, device, tb_writer)
|
112 |
self.policy = policy
|
113 |
-
self.get_action_mask = getattr(env, "get_action_mask")
|
114 |
|
115 |
self.gamma_schedule = (
|
116 |
linear_schedule(gamma, gamma_end)
|
|
|
110 |
) -> None:
|
111 |
super().__init__(policy, env, device, tb_writer)
|
112 |
self.policy = policy
|
113 |
+
self.get_action_mask = getattr(env, "get_action_mask", None)
|
114 |
|
115 |
self.gamma_schedule = (
|
116 |
linear_schedule(gamma, gamma_end)
|
rl_algo_impls/runner/config.py
CHANGED
@@ -52,6 +52,7 @@ class EnvHyperparams:
|
|
52 |
mask_actions: bool = False
|
53 |
bots: Optional[Dict[str, int]] = None
|
54 |
self_play_kwargs: Optional[Dict[str, Any]] = None
|
|
|
55 |
|
56 |
|
57 |
HyperparamsSelf = TypeVar("HyperparamsSelf", bound="Hyperparams")
|
|
|
52 |
mask_actions: bool = False
|
53 |
bots: Optional[Dict[str, int]] = None
|
54 |
self_play_kwargs: Optional[Dict[str, Any]] = None
|
55 |
+
selfplay_bots: Optional[Dict[str, int]] = None
|
56 |
|
57 |
|
58 |
HyperparamsSelf = TypeVar("HyperparamsSelf", bound="Hyperparams")
|
rl_algo_impls/runner/evaluate.py
CHANGED
@@ -70,7 +70,7 @@ def evaluate_model(args: EvalArgs, root_dir: str) -> Evaluation:
|
|
70 |
env = make_eval_env(
|
71 |
config,
|
72 |
EnvHyperparams(**config.env_hyperparams),
|
73 |
-
|
74 |
render=args.render,
|
75 |
normalize_load_path=model_path,
|
76 |
)
|
|
|
70 |
env = make_eval_env(
|
71 |
config,
|
72 |
EnvHyperparams(**config.env_hyperparams),
|
73 |
+
override_hparams={"n_envs": args.n_envs} if args.n_envs else None,
|
74 |
render=args.render,
|
75 |
normalize_load_path=model_path,
|
76 |
)
|
rl_algo_impls/runner/selfplay_evaluate.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import dataclasses
|
3 |
+
import os
|
4 |
+
import shutil
|
5 |
+
from dataclasses import dataclass
|
6 |
+
from typing import List, NamedTuple, Optional
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
import wandb
|
11 |
+
from rl_algo_impls.runner.config import Config, EnvHyperparams, Hyperparams, RunArgs
|
12 |
+
from rl_algo_impls.runner.evaluate import Evaluation
|
13 |
+
from rl_algo_impls.runner.running_utils import (
|
14 |
+
get_device,
|
15 |
+
load_hyperparams,
|
16 |
+
make_policy,
|
17 |
+
set_seeds,
|
18 |
+
)
|
19 |
+
from rl_algo_impls.shared.callbacks.eval_callback import evaluate
|
20 |
+
from rl_algo_impls.shared.vec_env import make_eval_env
|
21 |
+
from rl_algo_impls.wrappers.vec_episode_recorder import VecEpisodeRecorder
|
22 |
+
|
23 |
+
|
24 |
+
@dataclass
|
25 |
+
class SelfplayEvalArgs(RunArgs):
|
26 |
+
# Either wandb_run_paths or model_file_paths must have 2 elements in it.
|
27 |
+
wandb_run_paths: List[str] = dataclasses.field(default_factory=list)
|
28 |
+
model_file_paths: List[str] = dataclasses.field(default_factory=list)
|
29 |
+
render: bool = False
|
30 |
+
best: bool = True
|
31 |
+
n_envs: int = 1
|
32 |
+
n_episodes: int = 1
|
33 |
+
deterministic_eval: Optional[bool] = None
|
34 |
+
no_print_returns: bool = False
|
35 |
+
video_path: Optional[str] = None
|
36 |
+
|
37 |
+
|
38 |
+
def selfplay_evaluate(args: SelfplayEvalArgs, root_dir: str) -> Evaluation:
|
39 |
+
if args.wandb_run_paths:
|
40 |
+
api = wandb.Api()
|
41 |
+
args, config, player_1_model_path = load_player(
|
42 |
+
api, args.wandb_run_paths[0], args, root_dir
|
43 |
+
)
|
44 |
+
_, _, player_2_model_path = load_player(
|
45 |
+
api, args.wandb_run_paths[1], args, root_dir
|
46 |
+
)
|
47 |
+
elif args.model_file_paths:
|
48 |
+
hyperparams = load_hyperparams(args.algo, args.env)
|
49 |
+
|
50 |
+
config = Config(args, hyperparams, root_dir)
|
51 |
+
player_1_model_path, player_2_model_path = args.model_file_paths
|
52 |
+
else:
|
53 |
+
raise ValueError("Must specify 2 wandb_run_paths or 2 model_file_paths")
|
54 |
+
|
55 |
+
print(args)
|
56 |
+
|
57 |
+
set_seeds(args.seed, args.use_deterministic_algorithms)
|
58 |
+
|
59 |
+
env_make_kwargs = (
|
60 |
+
config.eval_hyperparams.get("env_overrides", {}).get("make_kwargs", {}).copy()
|
61 |
+
)
|
62 |
+
env_make_kwargs["num_selfplay_envs"] = args.n_envs * 2
|
63 |
+
env = make_eval_env(
|
64 |
+
config,
|
65 |
+
EnvHyperparams(**config.env_hyperparams),
|
66 |
+
override_hparams={
|
67 |
+
"n_envs": args.n_envs,
|
68 |
+
"selfplay_bots": {
|
69 |
+
player_2_model_path: args.n_envs,
|
70 |
+
},
|
71 |
+
"self_play_kwargs": {
|
72 |
+
"num_old_policies": 0,
|
73 |
+
"save_steps": np.inf,
|
74 |
+
"swap_steps": np.inf,
|
75 |
+
"bot_always_player_2": True,
|
76 |
+
},
|
77 |
+
"bots": None,
|
78 |
+
"make_kwargs": env_make_kwargs,
|
79 |
+
},
|
80 |
+
render=args.render,
|
81 |
+
normalize_load_path=player_1_model_path,
|
82 |
+
)
|
83 |
+
if args.video_path:
|
84 |
+
env = VecEpisodeRecorder(
|
85 |
+
env, args.video_path, max_video_length=18000, num_episodes=args.n_episodes
|
86 |
+
)
|
87 |
+
device = get_device(config, env)
|
88 |
+
policy = make_policy(
|
89 |
+
args.algo,
|
90 |
+
env,
|
91 |
+
device,
|
92 |
+
load_path=player_1_model_path,
|
93 |
+
**config.policy_hyperparams,
|
94 |
+
).eval()
|
95 |
+
|
96 |
+
deterministic = (
|
97 |
+
args.deterministic_eval
|
98 |
+
if args.deterministic_eval is not None
|
99 |
+
else config.eval_hyperparams.get("deterministic", True)
|
100 |
+
)
|
101 |
+
return Evaluation(
|
102 |
+
policy,
|
103 |
+
evaluate(
|
104 |
+
env,
|
105 |
+
policy,
|
106 |
+
args.n_episodes,
|
107 |
+
render=args.render,
|
108 |
+
deterministic=deterministic,
|
109 |
+
print_returns=not args.no_print_returns,
|
110 |
+
),
|
111 |
+
config,
|
112 |
+
)
|
113 |
+
|
114 |
+
|
115 |
+
class PlayerData(NamedTuple):
|
116 |
+
args: SelfplayEvalArgs
|
117 |
+
config: Config
|
118 |
+
model_path: str
|
119 |
+
|
120 |
+
|
121 |
+
def load_player(
|
122 |
+
api: wandb.Api, run_path: str, args: SelfplayEvalArgs, root_dir: str
|
123 |
+
) -> PlayerData:
|
124 |
+
args = copy.copy(args)
|
125 |
+
|
126 |
+
run = api.run(run_path)
|
127 |
+
params = run.config
|
128 |
+
args.algo = params["algo"]
|
129 |
+
args.env = params["env"]
|
130 |
+
args.seed = params.get("seed", None)
|
131 |
+
args.use_deterministic_algorithms = params.get("use_deterministic_algorithms", True)
|
132 |
+
config = Config(args, Hyperparams.from_dict_with_extra_fields(params), root_dir)
|
133 |
+
model_path = config.model_dir_path(best=args.best, downloaded=True)
|
134 |
+
|
135 |
+
model_archive_name = config.model_dir_name(best=args.best, extension=".zip")
|
136 |
+
run.file(model_archive_name).download()
|
137 |
+
if os.path.isdir(model_path):
|
138 |
+
shutil.rmtree(model_path)
|
139 |
+
shutil.unpack_archive(model_archive_name, model_path)
|
140 |
+
os.remove(model_archive_name)
|
141 |
+
|
142 |
+
return PlayerData(args, config, model_path)
|
rl_algo_impls/runner/train.py
CHANGED
@@ -49,7 +49,7 @@ def train(args: TrainArgs):
|
|
49 |
print(hyperparams)
|
50 |
config = Config(args, hyperparams, os.getcwd())
|
51 |
|
52 |
-
wandb_enabled = args.wandb_project_name
|
53 |
if wandb_enabled:
|
54 |
wandb.tensorboard.patch(
|
55 |
root_logdir=config.tensorboard_summary_path, pytorch=True
|
@@ -100,12 +100,15 @@ def train(args: TrainArgs):
|
|
100 |
best_model_path=config.model_dir_path(best=True),
|
101 |
**config.eval_callback_params(),
|
102 |
video_env=make_eval_env(
|
103 |
-
config,
|
|
|
|
|
104 |
)
|
105 |
if record_best_videos
|
106 |
else None,
|
107 |
best_video_dir=config.best_videos_dir,
|
108 |
additional_keys_to_log=config.additional_keys_to_log,
|
|
|
109 |
)
|
110 |
callbacks: List[Callback] = [eval_callback]
|
111 |
if config.hyperparams.microrts_reward_decay_callback:
|
@@ -149,13 +152,8 @@ def train(args: TrainArgs):
|
|
149 |
|
150 |
if wandb_enabled:
|
151 |
shutil.make_archive(
|
152 |
-
os.path.join(wandb.run.dir, config.model_dir_name()),
|
153 |
"zip",
|
154 |
config.model_dir_path(),
|
155 |
)
|
156 |
-
shutil.make_archive(
|
157 |
-
os.path.join(wandb.run.dir, config.model_dir_name(best=True)),
|
158 |
-
"zip",
|
159 |
-
config.model_dir_path(best=True),
|
160 |
-
)
|
161 |
wandb.finish()
|
|
|
49 |
print(hyperparams)
|
50 |
config = Config(args, hyperparams, os.getcwd())
|
51 |
|
52 |
+
wandb_enabled = bool(args.wandb_project_name)
|
53 |
if wandb_enabled:
|
54 |
wandb.tensorboard.patch(
|
55 |
root_logdir=config.tensorboard_summary_path, pytorch=True
|
|
|
100 |
best_model_path=config.model_dir_path(best=True),
|
101 |
**config.eval_callback_params(),
|
102 |
video_env=make_eval_env(
|
103 |
+
config,
|
104 |
+
EnvHyperparams(**config.env_hyperparams),
|
105 |
+
override_hparams={"n_envs": 1},
|
106 |
)
|
107 |
if record_best_videos
|
108 |
else None,
|
109 |
best_video_dir=config.best_videos_dir,
|
110 |
additional_keys_to_log=config.additional_keys_to_log,
|
111 |
+
wandb_enabled=wandb_enabled,
|
112 |
)
|
113 |
callbacks: List[Callback] = [eval_callback]
|
114 |
if config.hyperparams.microrts_reward_decay_callback:
|
|
|
152 |
|
153 |
if wandb_enabled:
|
154 |
shutil.make_archive(
|
155 |
+
os.path.join(wandb.run.dir, config.model_dir_name()), # type: ignore
|
156 |
"zip",
|
157 |
config.model_dir_path(),
|
158 |
)
|
|
|
|
|
|
|
|
|
|
|
159 |
wandb.finish()
|
rl_algo_impls/selfplay_enjoy.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Support for PyTorch mps mode (https://pytorch.org/docs/stable/notes/mps.html)
|
2 |
+
import os
|
3 |
+
|
4 |
+
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
5 |
+
|
6 |
+
from rl_algo_impls.runner.running_utils import base_parser
|
7 |
+
from rl_algo_impls.runner.selfplay_evaluate import SelfplayEvalArgs, selfplay_evaluate
|
8 |
+
|
9 |
+
|
10 |
+
def selfplay_enjoy() -> None:
|
11 |
+
parser = base_parser(multiple=False)
|
12 |
+
parser.add_argument(
|
13 |
+
"--wandb-run-paths",
|
14 |
+
type=str,
|
15 |
+
nargs="*",
|
16 |
+
help="WandB run paths to load players from. Must be 0 or 2",
|
17 |
+
)
|
18 |
+
parser.add_argument(
|
19 |
+
"--model-file-paths",
|
20 |
+
type=str,
|
21 |
+
help="File paths to load players from. Must be 0 or 2",
|
22 |
+
)
|
23 |
+
parser.add_argument("--render", action="store_true")
|
24 |
+
parser.add_argument("--n-envs", default=1, type=int)
|
25 |
+
parser.add_argument("--n-episodes", default=1, type=int)
|
26 |
+
parser.add_argument("--deterministic-eval", default=None, type=bool)
|
27 |
+
parser.add_argument(
|
28 |
+
"--no-print-returns", action="store_true", help="Limit printing"
|
29 |
+
)
|
30 |
+
parser.add_argument(
|
31 |
+
"--video-path", type=str, help="Path to save video of all plays"
|
32 |
+
)
|
33 |
+
# parser.set_defaults(
|
34 |
+
# algo=["ppo"],
|
35 |
+
# env=["Microrts-selfplay-unet-decay"],
|
36 |
+
# n_episodes=10,
|
37 |
+
# model_file_paths=[
|
38 |
+
# "downloaded_models/ppo-Microrts-selfplay-unet-decay-S3-best",
|
39 |
+
# "downloaded_models/ppo-Microrts-selfplay-unet-decay-S2-best",
|
40 |
+
# ],
|
41 |
+
# video_path="/Users/sgoodfriend/Desktop/decay3-vs-decay2",
|
42 |
+
# )
|
43 |
+
args = parser.parse_args()
|
44 |
+
args.algo = args.algo[0]
|
45 |
+
args.env = args.env[0]
|
46 |
+
args.seed = args.seed[0]
|
47 |
+
args = SelfplayEvalArgs(**vars(args))
|
48 |
+
|
49 |
+
selfplay_evaluate(args, os.getcwd())
|
50 |
+
|
51 |
+
|
52 |
+
if __name__ == "__main__":
|
53 |
+
selfplay_enjoy()
|
rl_algo_impls/shared/actor/state_dependent_noise.py
CHANGED
@@ -172,7 +172,7 @@ class StateDependentNoiseActorHead(Actor):
|
|
172 |
not action_masks
|
173 |
), f"{self.__class__.__name__} does not support action_masks"
|
174 |
pi = self._distribution(obs)
|
175 |
-
return pi_forward(pi, actions)
|
176 |
|
177 |
def sample_weights(self, batch_size: int = 1) -> None:
|
178 |
std = self._get_std()
|
@@ -187,13 +187,13 @@ class StateDependentNoiseActorHead(Actor):
|
|
187 |
|
188 |
|
189 |
def pi_forward(
|
190 |
-
distribution: Distribution,
|
|
|
|
|
191 |
) -> PiForward:
|
192 |
logp_a = None
|
193 |
entropy = None
|
194 |
if actions is not None:
|
195 |
logp_a = distribution.log_prob(actions)
|
196 |
-
entropy = (
|
197 |
-
-logp_a if self.bijector else sum_independent_dims(distribution.entropy())
|
198 |
-
)
|
199 |
return PiForward(distribution, logp_a, entropy)
|
|
|
172 |
not action_masks
|
173 |
), f"{self.__class__.__name__} does not support action_masks"
|
174 |
pi = self._distribution(obs)
|
175 |
+
return pi_forward(pi, actions, self.bijector)
|
176 |
|
177 |
def sample_weights(self, batch_size: int = 1) -> None:
|
178 |
std = self._get_std()
|
|
|
187 |
|
188 |
|
189 |
def pi_forward(
|
190 |
+
distribution: Distribution,
|
191 |
+
actions: Optional[torch.Tensor] = None,
|
192 |
+
bijector: Optional[TanhBijector] = None,
|
193 |
) -> PiForward:
|
194 |
logp_a = None
|
195 |
entropy = None
|
196 |
if actions is not None:
|
197 |
logp_a = distribution.log_prob(actions)
|
198 |
+
entropy = -logp_a if bijector else sum_independent_dims(distribution.entropy())
|
|
|
|
|
199 |
return PiForward(distribution, logp_a, entropy)
|
rl_algo_impls/shared/callbacks/eval_callback.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import itertools
|
2 |
import os
|
|
|
3 |
from time import perf_counter
|
4 |
from typing import Dict, List, Optional, Union
|
5 |
|
@@ -94,7 +95,7 @@ def evaluate(
|
|
94 |
)
|
95 |
|
96 |
obs = env.reset()
|
97 |
-
get_action_mask = getattr(env, "get_action_mask")
|
98 |
while not episodes.is_done():
|
99 |
act = policy.act(
|
100 |
obs,
|
@@ -132,6 +133,7 @@ class EvalCallback(Callback):
|
|
132 |
ignore_first_episode: bool = False,
|
133 |
additional_keys_to_log: Optional[List[str]] = None,
|
134 |
score_function: str = "mean-std",
|
|
|
135 |
) -> None:
|
136 |
super().__init__()
|
137 |
self.policy = policy
|
@@ -157,6 +159,7 @@ class EvalCallback(Callback):
|
|
157 |
self.ignore_first_episode = ignore_first_episode
|
158 |
self.additional_keys_to_log = additional_keys_to_log
|
159 |
self.score_function = score_function
|
|
|
160 |
|
161 |
def on_step(self, timesteps_elapsed: int = 1) -> bool:
|
162 |
super().on_step(timesteps_elapsed)
|
@@ -196,6 +199,15 @@ class EvalCallback(Callback):
|
|
196 |
assert self.best_model_path
|
197 |
self.policy.save(self.best_model_path)
|
198 |
print("Saved best model")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
199 |
self.best.write_to_tensorboard(
|
200 |
self.tb_writer, "best_eval", self.timesteps_elapsed
|
201 |
)
|
|
|
1 |
import itertools
|
2 |
import os
|
3 |
+
import shutil
|
4 |
from time import perf_counter
|
5 |
from typing import Dict, List, Optional, Union
|
6 |
|
|
|
95 |
)
|
96 |
|
97 |
obs = env.reset()
|
98 |
+
get_action_mask = getattr(env, "get_action_mask", None)
|
99 |
while not episodes.is_done():
|
100 |
act = policy.act(
|
101 |
obs,
|
|
|
133 |
ignore_first_episode: bool = False,
|
134 |
additional_keys_to_log: Optional[List[str]] = None,
|
135 |
score_function: str = "mean-std",
|
136 |
+
wandb_enabled: bool = False,
|
137 |
) -> None:
|
138 |
super().__init__()
|
139 |
self.policy = policy
|
|
|
159 |
self.ignore_first_episode = ignore_first_episode
|
160 |
self.additional_keys_to_log = additional_keys_to_log
|
161 |
self.score_function = score_function
|
162 |
+
self.wandb_enabled = wandb_enabled
|
163 |
|
164 |
def on_step(self, timesteps_elapsed: int = 1) -> bool:
|
165 |
super().on_step(timesteps_elapsed)
|
|
|
199 |
assert self.best_model_path
|
200 |
self.policy.save(self.best_model_path)
|
201 |
print("Saved best model")
|
202 |
+
if self.wandb_enabled:
|
203 |
+
import wandb
|
204 |
+
|
205 |
+
best_model_name = os.path.split(self.best_model_path)[-1]
|
206 |
+
shutil.make_archive(
|
207 |
+
os.path.join(wandb.run.dir, best_model_name), # type: ignore
|
208 |
+
"zip",
|
209 |
+
self.best_model_path,
|
210 |
+
)
|
211 |
self.best.write_to_tensorboard(
|
212 |
self.tb_writer, "best_eval", self.timesteps_elapsed
|
213 |
)
|
rl_algo_impls/shared/policy/actor_critic.py
CHANGED
@@ -93,7 +93,7 @@ class ActorCritic(OnPolicy):
|
|
93 |
|
94 |
observation_space = single_observation_space(env)
|
95 |
action_space = single_action_space(env)
|
96 |
-
action_plane_space = getattr(env, "action_plane_space")
|
97 |
|
98 |
self.action_space = action_space
|
99 |
self.squash_output = squash_output
|
|
|
93 |
|
94 |
observation_space = single_observation_space(env)
|
95 |
action_space = single_action_space(env)
|
96 |
+
action_plane_space = getattr(env, "action_plane_space", None)
|
97 |
|
98 |
self.action_space = action_space
|
99 |
self.squash_output = squash_output
|
rl_algo_impls/shared/vec_env/make_env.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
from dataclasses import asdict
|
2 |
-
from typing import Optional
|
3 |
|
4 |
from torch.utils.tensorboard.writer import SummaryWriter
|
5 |
|
@@ -52,7 +52,7 @@ def make_env(
|
|
52 |
def make_eval_env(
|
53 |
config: Config,
|
54 |
hparams: EnvHyperparams,
|
55 |
-
|
56 |
**kwargs,
|
57 |
) -> VecEnv:
|
58 |
kwargs = kwargs.copy()
|
@@ -62,10 +62,11 @@ def make_eval_env(
|
|
62 |
hparams_kwargs = asdict(hparams)
|
63 |
hparams_kwargs.update(env_overrides)
|
64 |
hparams = EnvHyperparams(**hparams_kwargs)
|
65 |
-
if
|
66 |
hparams_kwargs = asdict(hparams)
|
67 |
-
|
68 |
-
|
69 |
-
|
|
|
70 |
hparams = EnvHyperparams(**hparams_kwargs)
|
71 |
return make_env(config, hparams, **kwargs)
|
|
|
1 |
from dataclasses import asdict
|
2 |
+
from typing import Any, Dict, Optional
|
3 |
|
4 |
from torch.utils.tensorboard.writer import SummaryWriter
|
5 |
|
|
|
52 |
def make_eval_env(
|
53 |
config: Config,
|
54 |
hparams: EnvHyperparams,
|
55 |
+
override_hparams: Optional[Dict[str, Any]] = None,
|
56 |
**kwargs,
|
57 |
) -> VecEnv:
|
58 |
kwargs = kwargs.copy()
|
|
|
62 |
hparams_kwargs = asdict(hparams)
|
63 |
hparams_kwargs.update(env_overrides)
|
64 |
hparams = EnvHyperparams(**hparams_kwargs)
|
65 |
+
if override_hparams:
|
66 |
hparams_kwargs = asdict(hparams)
|
67 |
+
for k, v in override_hparams.items():
|
68 |
+
hparams_kwargs[k] = v
|
69 |
+
if k == "n_envs" and v == 1:
|
70 |
+
hparams_kwargs["vec_env_class"] = "sync"
|
71 |
hparams = EnvHyperparams(**hparams_kwargs)
|
72 |
return make_env(config, hparams, **kwargs)
|
rl_algo_impls/shared/vec_env/microrts.py
CHANGED
@@ -50,6 +50,7 @@ def make_microrts_env(
|
|
50 |
_, # mask_actions
|
51 |
bots,
|
52 |
self_play_kwargs,
|
|
|
53 |
) = astuple(hparams)
|
54 |
|
55 |
seed = config.seed(training=training)
|
@@ -65,6 +66,7 @@ def make_microrts_env(
|
|
65 |
n_envs
|
66 |
- make_kwargs["num_selfplay_envs"]
|
67 |
+ self_play_kwargs.get("num_old_policies", 0)
|
|
|
68 |
)
|
69 |
else:
|
70 |
num_bot_envs = n_envs
|
@@ -100,14 +102,16 @@ def make_microrts_env(
|
|
100 |
envs = MicrortsMaskWrapper(envs)
|
101 |
|
102 |
if self_play_kwargs:
|
103 |
-
|
|
|
|
|
104 |
|
105 |
if seed is not None:
|
106 |
envs.action_space.seed(seed)
|
107 |
envs.observation_space.seed(seed)
|
108 |
|
109 |
envs = gym.wrappers.RecordEpisodeStatistics(envs)
|
110 |
-
envs = MicrortsStatsRecorder(envs, config.algo_hyperparams.get("gamma", 0.99))
|
111 |
if training:
|
112 |
assert tb_writer
|
113 |
envs = EpisodeStatsWriter(
|
|
|
50 |
_, # mask_actions
|
51 |
bots,
|
52 |
self_play_kwargs,
|
53 |
+
selfplay_bots,
|
54 |
) = astuple(hparams)
|
55 |
|
56 |
seed = config.seed(training=training)
|
|
|
66 |
n_envs
|
67 |
- make_kwargs["num_selfplay_envs"]
|
68 |
+ self_play_kwargs.get("num_old_policies", 0)
|
69 |
+
+ (len(selfplay_bots) if selfplay_bots else 0)
|
70 |
)
|
71 |
else:
|
72 |
num_bot_envs = n_envs
|
|
|
102 |
envs = MicrortsMaskWrapper(envs)
|
103 |
|
104 |
if self_play_kwargs:
|
105 |
+
if selfplay_bots:
|
106 |
+
self_play_kwargs["selfplay_bots"] = selfplay_bots
|
107 |
+
envs = SelfPlayWrapper(envs, config, **self_play_kwargs)
|
108 |
|
109 |
if seed is not None:
|
110 |
envs.action_space.seed(seed)
|
111 |
envs.observation_space.seed(seed)
|
112 |
|
113 |
envs = gym.wrappers.RecordEpisodeStatistics(envs)
|
114 |
+
envs = MicrortsStatsRecorder(envs, config.algo_hyperparams.get("gamma", 0.99), bots)
|
115 |
if training:
|
116 |
assert tb_writer
|
117 |
envs = EpisodeStatsWriter(
|
rl_algo_impls/shared/vec_env/procgen.py
CHANGED
@@ -41,6 +41,8 @@ def make_procgen_env(
|
|
41 |
_, # normalize_type
|
42 |
_, # mask_actions
|
43 |
_, # bots
|
|
|
|
|
44 |
) = astuple(hparams)
|
45 |
|
46 |
seed = config.seed(training=training)
|
|
|
41 |
_, # normalize_type
|
42 |
_, # mask_actions
|
43 |
_, # bots
|
44 |
+
_, # self_play_kwargs
|
45 |
+
_, # selfplay_bots
|
46 |
) = astuple(hparams)
|
47 |
|
48 |
seed = config.seed(training=training)
|
rl_algo_impls/shared/vec_env/vec_env.py
CHANGED
@@ -73,6 +73,8 @@ def make_vec_env(
|
|
73 |
normalize_type,
|
74 |
mask_actions,
|
75 |
_, # bots
|
|
|
|
|
76 |
) = astuple(hparams)
|
77 |
|
78 |
import_for_env_id(config.env_id)
|
|
|
73 |
normalize_type,
|
74 |
mask_actions,
|
75 |
_, # bots
|
76 |
+
_, # self_play_kwargs
|
77 |
+
_, # selfplay_bots
|
78 |
) = astuple(hparams)
|
79 |
|
80 |
import_for_env_id(config.env_id)
|
rl_algo_impls/wrappers/action_mask_wrapper.py
CHANGED
@@ -16,11 +16,11 @@ class IncompleteArrayError(Exception):
|
|
16 |
|
17 |
class SingleActionMaskWrapper(VecotarableWrapper):
|
18 |
def get_action_mask(self) -> Optional[np.ndarray]:
|
19 |
-
envs = getattr(self.env.unwrapped, "envs") # type: ignore
|
20 |
assert (
|
21 |
envs
|
22 |
), f"{self.__class__.__name__} expects to wrap synchronous vectorized env"
|
23 |
-
masks = [getattr(e.unwrapped, "action_mask") for e in envs]
|
24 |
assert all(m is not None for m in masks)
|
25 |
return np.array(masks, dtype=np.bool_)
|
26 |
|
|
|
16 |
|
17 |
class SingleActionMaskWrapper(VecotarableWrapper):
|
18 |
def get_action_mask(self) -> Optional[np.ndarray]:
|
19 |
+
envs = getattr(self.env.unwrapped, "envs", None) # type: ignore
|
20 |
assert (
|
21 |
envs
|
22 |
), f"{self.__class__.__name__} expects to wrap synchronous vectorized env"
|
23 |
+
masks = [getattr(e.unwrapped, "action_mask", None) for e in envs]
|
24 |
assert all(m is not None for m in masks)
|
25 |
return np.array(masks, dtype=np.bool_)
|
26 |
|
rl_algo_impls/wrappers/microrts_stats_recorder.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
from typing import Any, Dict, List
|
2 |
|
3 |
import numpy as np
|
4 |
|
@@ -10,10 +10,19 @@ from rl_algo_impls.wrappers.vectorable_wrapper import (
|
|
10 |
|
11 |
|
12 |
class MicrortsStatsRecorder(VecotarableWrapper):
|
13 |
-
def __init__(
|
|
|
|
|
14 |
super().__init__(env)
|
15 |
self.gamma = gamma
|
16 |
self.raw_rewards = [[] for _ in range(self.num_envs)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
def reset(self) -> VecEnvObs:
|
19 |
obs = super().reset()
|
@@ -33,4 +42,19 @@ class MicrortsStatsRecorder(VecotarableWrapper):
|
|
33 |
raw_rewards = np.array(self.raw_rewards[idx]).sum(0)
|
34 |
raw_names = [str(rf) for rf in self.env.unwrapped.rfs]
|
35 |
info["microrts_stats"] = dict(zip(raw_names, raw_rewards))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
self.raw_rewards[idx] = []
|
|
|
1 |
+
from typing import Any, Dict, List, Optional
|
2 |
|
3 |
import numpy as np
|
4 |
|
|
|
10 |
|
11 |
|
12 |
class MicrortsStatsRecorder(VecotarableWrapper):
|
13 |
+
def __init__(
|
14 |
+
self, env, gamma: float, bots: Optional[Dict[str, int]] = None
|
15 |
+
) -> None:
|
16 |
super().__init__(env)
|
17 |
self.gamma = gamma
|
18 |
self.raw_rewards = [[] for _ in range(self.num_envs)]
|
19 |
+
self.bots = bots
|
20 |
+
if self.bots:
|
21 |
+
self.bot_at_index = [None] * (env.num_envs - sum(self.bots.values()))
|
22 |
+
for b, n in self.bots.items():
|
23 |
+
self.bot_at_index.extend([b] * n)
|
24 |
+
else:
|
25 |
+
self.bot_at_index = [None] * env.num_envs
|
26 |
|
27 |
def reset(self) -> VecEnvObs:
|
28 |
obs = super().reset()
|
|
|
42 |
raw_rewards = np.array(self.raw_rewards[idx]).sum(0)
|
43 |
raw_names = [str(rf) for rf in self.env.unwrapped.rfs]
|
44 |
info["microrts_stats"] = dict(zip(raw_names, raw_rewards))
|
45 |
+
|
46 |
+
winloss = raw_rewards[raw_names.index("WinLossRewardFunction")]
|
47 |
+
microrts_results = {
|
48 |
+
"win": int(winloss == 1),
|
49 |
+
"draw": int(winloss == 0),
|
50 |
+
"loss": int(winloss == -1),
|
51 |
+
}
|
52 |
+
bot = self.bot_at_index[idx]
|
53 |
+
if bot:
|
54 |
+
microrts_results.update(
|
55 |
+
{f"{k}_{bot}": v for k, v in microrts_results.items()}
|
56 |
+
)
|
57 |
+
|
58 |
+
info["microrts_results"] = microrts_results
|
59 |
+
|
60 |
self.raw_rewards[idx] = []
|
rl_algo_impls/wrappers/self_play_wrapper.py
CHANGED
@@ -1,10 +1,11 @@
|
|
1 |
import copy
|
2 |
import random
|
3 |
from collections import deque
|
4 |
-
from typing import Deque, List, Optional
|
5 |
|
6 |
import numpy as np
|
7 |
|
|
|
8 |
from rl_algo_impls.shared.policy.policy import Policy
|
9 |
from rl_algo_impls.wrappers.action_mask_wrapper import find_action_masker
|
10 |
from rl_algo_impls.wrappers.vectorable_wrapper import (
|
@@ -21,11 +22,14 @@ class SelfPlayWrapper(VecotarableWrapper):
|
|
21 |
def __init__(
|
22 |
self,
|
23 |
env,
|
|
|
24 |
num_old_policies: int = 0,
|
25 |
save_steps: int = 20_000,
|
26 |
swap_steps: int = 10_000,
|
27 |
window: int = 10,
|
28 |
swap_window_size: int = 2,
|
|
|
|
|
29 |
) -> None:
|
30 |
super().__init__(env)
|
31 |
assert num_old_policies % 2 == 0, f"num_old_policies must be even"
|
@@ -33,17 +37,26 @@ class SelfPlayWrapper(VecotarableWrapper):
|
|
33 |
num_old_policies % swap_window_size == 0
|
34 |
), f"num_old_policies must be a multiple of swap_window_size"
|
35 |
|
|
|
36 |
self.num_old_policies = num_old_policies
|
37 |
self.save_steps = save_steps
|
38 |
self.swap_steps = swap_steps
|
39 |
self.swap_window_size = swap_window_size
|
|
|
|
|
40 |
|
41 |
self.policies: Deque[Policy] = deque(maxlen=window)
|
42 |
self.policy_assignments: List[Optional[Policy]] = [None] * env.num_envs
|
43 |
self.steps_since_swap = np.zeros(env.num_envs)
|
44 |
|
|
|
|
|
45 |
self.num_envs = env.num_envs - num_old_policies
|
46 |
|
|
|
|
|
|
|
|
|
47 |
def get_action_mask(self) -> Optional[np.ndarray]:
|
48 |
return self.env.get_action_mask()[self.learner_indexes()]
|
49 |
|
@@ -54,10 +67,12 @@ class SelfPlayWrapper(VecotarableWrapper):
|
|
54 |
copied_policy.train(False)
|
55 |
self.policies.append(copied_policy)
|
56 |
|
57 |
-
if all(p is None for p in self.policy_assignments):
|
58 |
for i in range(self.num_old_policies):
|
59 |
# Switch between player 1 and 2
|
60 |
-
self.policy_assignments[
|
|
|
|
|
61 |
|
62 |
def swap_policy(self, idx: int, swap_window_size: int = 1) -> None:
|
63 |
policy = random.choice(self.policies)
|
@@ -69,6 +84,30 @@ class SelfPlayWrapper(VecotarableWrapper):
|
|
69 |
swap_window_size * 2
|
70 |
)
|
71 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
def step(self, actions: np.ndarray) -> VecEnvStepReturn:
|
73 |
env = self.env # type: ignore
|
74 |
all_actions = np.zeros((env.num_envs,) + actions.shape[1:], dtype=actions.dtype)
|
|
|
1 |
import copy
|
2 |
import random
|
3 |
from collections import deque
|
4 |
+
from typing import Any, Deque, Dict, List, Optional
|
5 |
|
6 |
import numpy as np
|
7 |
|
8 |
+
from rl_algo_impls.runner.config import Config
|
9 |
from rl_algo_impls.shared.policy.policy import Policy
|
10 |
from rl_algo_impls.wrappers.action_mask_wrapper import find_action_masker
|
11 |
from rl_algo_impls.wrappers.vectorable_wrapper import (
|
|
|
22 |
def __init__(
|
23 |
self,
|
24 |
env,
|
25 |
+
config: Config,
|
26 |
num_old_policies: int = 0,
|
27 |
save_steps: int = 20_000,
|
28 |
swap_steps: int = 10_000,
|
29 |
window: int = 10,
|
30 |
swap_window_size: int = 2,
|
31 |
+
selfplay_bots: Optional[Dict[str, Any]] = None,
|
32 |
+
bot_always_player_2: bool = False,
|
33 |
) -> None:
|
34 |
super().__init__(env)
|
35 |
assert num_old_policies % 2 == 0, f"num_old_policies must be even"
|
|
|
37 |
num_old_policies % swap_window_size == 0
|
38 |
), f"num_old_policies must be a multiple of swap_window_size"
|
39 |
|
40 |
+
self.config = config
|
41 |
self.num_old_policies = num_old_policies
|
42 |
self.save_steps = save_steps
|
43 |
self.swap_steps = swap_steps
|
44 |
self.swap_window_size = swap_window_size
|
45 |
+
self.selfplay_bots = selfplay_bots
|
46 |
+
self.bot_always_player_2 = bot_always_player_2
|
47 |
|
48 |
self.policies: Deque[Policy] = deque(maxlen=window)
|
49 |
self.policy_assignments: List[Optional[Policy]] = [None] * env.num_envs
|
50 |
self.steps_since_swap = np.zeros(env.num_envs)
|
51 |
|
52 |
+
self.selfplay_policies: Dict[str, Policy] = {}
|
53 |
+
|
54 |
self.num_envs = env.num_envs - num_old_policies
|
55 |
|
56 |
+
if self.selfplay_bots:
|
57 |
+
self.num_envs -= sum(self.selfplay_bots.values())
|
58 |
+
self.initialize_selfplay_bots()
|
59 |
+
|
60 |
def get_action_mask(self) -> Optional[np.ndarray]:
|
61 |
return self.env.get_action_mask()[self.learner_indexes()]
|
62 |
|
|
|
67 |
copied_policy.train(False)
|
68 |
self.policies.append(copied_policy)
|
69 |
|
70 |
+
if all(p is None for p in self.policy_assignments[: 2 * self.num_old_policies]):
|
71 |
for i in range(self.num_old_policies):
|
72 |
# Switch between player 1 and 2
|
73 |
+
self.policy_assignments[
|
74 |
+
2 * i + (i % 2 if not self.bot_always_player_2 else 1)
|
75 |
+
] = copied_policy
|
76 |
|
77 |
def swap_policy(self, idx: int, swap_window_size: int = 1) -> None:
|
78 |
policy = random.choice(self.policies)
|
|
|
84 |
swap_window_size * 2
|
85 |
)
|
86 |
|
87 |
+
def initialize_selfplay_bots(self) -> None:
|
88 |
+
if not self.selfplay_bots:
|
89 |
+
return
|
90 |
+
from rl_algo_impls.runner.running_utils import get_device, make_policy
|
91 |
+
|
92 |
+
env = self.env # Type: ignore
|
93 |
+
device = get_device(self.config, env)
|
94 |
+
start_idx = 2 * self.num_old_policies
|
95 |
+
for model_path, n in self.selfplay_bots.items():
|
96 |
+
policy = make_policy(
|
97 |
+
self.config.algo,
|
98 |
+
env,
|
99 |
+
device,
|
100 |
+
load_path=model_path,
|
101 |
+
**self.config.policy_hyperparams,
|
102 |
+
).eval()
|
103 |
+
self.selfplay_policies["model_path"] = policy
|
104 |
+
for idx in range(start_idx, start_idx + 2 * n, 2):
|
105 |
+
bot_idx = (
|
106 |
+
(idx + 1) if self.bot_always_player_2 else (idx + idx // 2 % 2)
|
107 |
+
)
|
108 |
+
self.policy_assignments[bot_idx] = policy
|
109 |
+
start_idx += 2 * n
|
110 |
+
|
111 |
def step(self, actions: np.ndarray) -> VecEnvStepReturn:
|
112 |
env = self.env # type: ignore
|
113 |
all_actions = np.zeros((env.num_envs,) + actions.shape[1:], dtype=actions.dtype)
|
rl_algo_impls/wrappers/vec_episode_recorder.py
CHANGED
@@ -1,21 +1,24 @@
|
|
1 |
import numpy as np
|
2 |
-
|
3 |
from gym.wrappers.monitoring.video_recorder import VideoRecorder
|
4 |
|
5 |
from rl_algo_impls.wrappers.vectorable_wrapper import (
|
6 |
-
VecotarableWrapper,
|
7 |
VecEnvObs,
|
8 |
VecEnvStepReturn,
|
|
|
9 |
)
|
10 |
|
11 |
|
12 |
class VecEpisodeRecorder(VecotarableWrapper):
|
13 |
-
def __init__(
|
|
|
|
|
14 |
super().__init__(env)
|
15 |
self.base_path = base_path
|
16 |
self.max_video_length = max_video_length
|
|
|
17 |
self.video_recorder = None
|
18 |
self.recorded_frames = 0
|
|
|
19 |
|
20 |
def step(self, actions: np.ndarray) -> VecEnvStepReturn:
|
21 |
obs, rew, dones, infos = self.env.step(actions)
|
@@ -23,13 +26,21 @@ class VecEpisodeRecorder(VecotarableWrapper):
|
|
23 |
if self.video_recorder:
|
24 |
self.video_recorder.capture_frame()
|
25 |
self.recorded_frames += 1
|
|
|
|
|
26 |
if dones[0] and infos[0].get("episode"):
|
27 |
episode_info = {
|
28 |
k: v.item() if hasattr(v, "item") else v
|
29 |
for k, v in infos[0]["episode"].items()
|
30 |
}
|
31 |
-
|
32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
self._close_video_recorder()
|
34 |
return obs, rew, dones, infos
|
35 |
|
|
|
1 |
import numpy as np
|
|
|
2 |
from gym.wrappers.monitoring.video_recorder import VideoRecorder
|
3 |
|
4 |
from rl_algo_impls.wrappers.vectorable_wrapper import (
|
|
|
5 |
VecEnvObs,
|
6 |
VecEnvStepReturn,
|
7 |
+
VecotarableWrapper,
|
8 |
)
|
9 |
|
10 |
|
11 |
class VecEpisodeRecorder(VecotarableWrapper):
|
12 |
+
def __init__(
|
13 |
+
self, env, base_path: str, max_video_length: int = 3600, num_episodes: int = 1
|
14 |
+
):
|
15 |
super().__init__(env)
|
16 |
self.base_path = base_path
|
17 |
self.max_video_length = max_video_length
|
18 |
+
self.num_episodes = num_episodes
|
19 |
self.video_recorder = None
|
20 |
self.recorded_frames = 0
|
21 |
+
self.num_completed = 0
|
22 |
|
23 |
def step(self, actions: np.ndarray) -> VecEnvStepReturn:
|
24 |
obs, rew, dones, infos = self.env.step(actions)
|
|
|
26 |
if self.video_recorder:
|
27 |
self.video_recorder.capture_frame()
|
28 |
self.recorded_frames += 1
|
29 |
+
if dones[0]:
|
30 |
+
self.num_completed += 1
|
31 |
if dones[0] and infos[0].get("episode"):
|
32 |
episode_info = {
|
33 |
k: v.item() if hasattr(v, "item") else v
|
34 |
for k, v in infos[0]["episode"].items()
|
35 |
}
|
36 |
+
|
37 |
+
if "episodes" not in self.video_recorder.metadata:
|
38 |
+
self.video_recorder.metadata["episodes"] = []
|
39 |
+
self.video_recorder.metadata["episodes"].append(episode_info)
|
40 |
+
if (
|
41 |
+
self.num_completed == self.num_episodes
|
42 |
+
or self.recorded_frames > self.max_video_length
|
43 |
+
):
|
44 |
self._close_video_recorder()
|
45 |
return obs, rew, dones, infos
|
46 |
|
saved_models/ppo-Microrts-selfplay-unet-decay-S1-best/model.pth
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 15323895
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7bee4122bcaffdea46193740a46d983e016e5b71d837ee1221fbc4b21f15cc39
|
3 |
size 15323895
|
selfplay_enjoy.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from rl_algo_impls.selfplay_enjoy import selfplay_enjoy
|
2 |
+
|
3 |
+
if __name__ == "__main__":
|
4 |
+
selfplay_enjoy()
|