TheoVincent
commited on
Commit
·
4cb4fc3
1
Parent(s):
e14ba50
minimal example - code
Browse files- .gitattributes +1 -0
- README.md +38 -0
- atari.py +126 -0
- evaluate.ipynb +93 -0
- networks.py +156 -0
- performances.png +0 -0
- requirements.txt +98 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*best_online_params filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Model parameters trained with `i-DQN` and `i-IQN`
|
2 |
+
This repository contains the model parameters trained with `i-DQN` on [$57$ Atari games](#list-of-games-for-i-dqn) and trained with `i-IQN` on [$20$ Atari games](#list-of-games-for-i-iqn) 🎮. $5$ seeds are available for each configuration which makes a total of $385$ available models 📈.
|
3 |
+
|
4 |
+
The [evaluate.ipynb](./evaluate.ipynb) notebook contains a minimal example showing how to evaluate the model parameters 🧑🏫. It uses JAX 🚀.
|
5 |
+
|
6 |
+
PS: The set of [$20$ Atari games](#list-of-games-for-i-iqn) is included in the set of [$57$ Atari games](#list-of-games-for-i-dqn).
|
7 |
+
|
8 |
+
### Model performances
|
9 |
+
`i-DQN` and `i-IQN` are improvements made over [`DQN`](https://www.nature.com/articles/nature14236.pdf) and [`IQN`](https://arxiv.org/abs/1806.06923) ✨. Check it out on [arXiv](https://arxiv.org/abs/2403.02107)! | <img src="performances.png" alt="drawing" width="600"/>
|
10 |
+
:-:|:-:
|
11 |
+
|
12 |
+
|
13 |
+
### List of games for `i-DQN`
|
14 |
+
Alien, Amidar, Assault, Asterix, Asteroids, Atlantis, BankHeist, BattleZone, BeamRider, Berzerk, Bowling, Boxing, Breakout, Centipede, ChopperCommand, CrazyClimber, DemonAttack, DoubleDunk, Enduro, FishingDerby, Freeway, Frostbite, Gopher, Gravitar, Hero, IceHockey, Jamesbond, Kangaroo, Krull, KungFuMaster, MontezumaRevenge, MsPacman, NameThisGame, Phoenix, Pitfall, Pong, Pooyan, PrivateEye, Qbert, Riverraid, RoadRunner, Robotank, Seaquest, Skiing, Solaris, SpaceInvaders, StarGunner, Tennis, TimePilot, Tutankham, UpNDown, Venture, VideoPinball, WizardOfWor, YarsRevenge, Zaxxon.
|
15 |
+
|
16 |
+
### List of games for `i-IQN`
|
17 |
+
Alien, Assault, BankHeist, Berzerk, Breakout, Centipede, ChopperCommand, DemonAttack, Enduro, Frostbite, Gopher, Gravitar, IceHockey, Jamesbond, Krull, KungFuMaster, Riverraid, Seaquest, Skiing, StarGunner.
|
18 |
+
|
19 |
+
## User installation
|
20 |
+
Python 3.10 is recommended. Create a Python virtual environment, activate it, update pip and install the dependencies:
|
21 |
+
```bash
|
22 |
+
python3.10 -m venv env
|
23 |
+
source env/bin/activate
|
24 |
+
pip install --upgrade pip
|
25 |
+
pip install numpy==1.23.5 # to avoid numpy==2.XX
|
26 |
+
pip install -r requirements.txt
|
27 |
+
pip install --upgrade "jax[cuda12_pip]==0.4.13" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
28 |
+
```
|
29 |
+
|
30 |
+
## Citing `i-QN`
|
31 |
+
```
|
32 |
+
@article{vincent2024iterated,
|
33 |
+
title={Iterated $ Q $-Network: Beyond the One-Step Bellman Operator},
|
34 |
+
author={Vincent, Th{\'e}o and Palenicek, Daniel and Belousov, Boris and Peters, Jan and D'Eramo, Carlo},
|
35 |
+
journal={arXiv preprint arXiv:2403.02107},
|
36 |
+
year={2024}
|
37 |
+
}
|
38 |
+
```
|
atari.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
The environment is inspired by https://github.com/google/dopamine/blob/master/dopamine/discrete_domains/atari_lib.py
|
3 |
+
"""
|
4 |
+
|
5 |
+
import os
|
6 |
+
from typing import Tuple, Dict
|
7 |
+
from gym.wrappers.monitoring import video_recorder
|
8 |
+
import gym
|
9 |
+
import numpy as np
|
10 |
+
import jax
|
11 |
+
import jax.numpy as jnp
|
12 |
+
import cv2
|
13 |
+
|
14 |
+
|
15 |
+
class AtariEnv:
    """Atari environment producing stacked, down-sampled grayscale frames.

    The wrapper repeats each action for ``n_skipped_frames`` emulator frames,
    max-pools the last two grayscale screens, resizes the result to
    ``state_height x state_width`` and keeps the ``n_stacked_frames`` most
    recent processed frames in ``state_`` (shape: height, width, frames).
    """

    def __init__(
        self,
        name: str,
    ) -> None:
        """Create the ALE environment ``ALE/{name}-v5`` and the screen buffers."""
        self.name = name
        self.state_height, self.state_width = (84, 84)
        self.n_stacked_frames = 4
        self.n_skipped_frames = 4

        # frameskip=1: frame skipping is done manually in `step`.
        # `.env` drops the outermost wrapper returned by gym.make
        # (NOTE(review): presumably to bypass a wrapper-imposed limit -- confirm).
        self.env = gym.make(
            f"ALE/{self.name}-v5",
            full_action_space=False,
            frameskip=1,
            repeat_action_probability=0.25,
            render_mode="rgb_array",
        ).env

        self.n_actions = self.env.action_space.n
        # Use the public `shape` attribute rather than the private `_shape`.
        self.original_state_height, self.original_state_width, _ = self.env.observation_space.shape
        # Two raw grayscale screens, max-pooled together in `pool_and_resize`.
        self.screen_buffer = [
            np.empty((self.original_state_height, self.original_state_width), dtype=np.uint8),
            np.empty((self.original_state_height, self.original_state_width), dtype=np.uint8),
        ]

    @property
    def observation(self) -> np.ndarray:
        """Return a copy of the most recent processed frame."""
        return np.copy(self.state_[:, :, -1])

    @property
    def state(self) -> jnp.ndarray:
        """Return the stacked frames as a float32 JAX array (values not rescaled)."""
        return jnp.array(self.state_, dtype=jnp.float32)

    def reset(self) -> None:
        """Reset the emulator, the step counter and the frame stack."""
        self.env.reset()

        self.n_steps = 0

        self.env.ale.getScreenGrayscale(self.screen_buffer[0])
        self.screen_buffer[1].fill(0)

        # All frames are blank except the newest one, which holds the first screen.
        self.state_ = np.zeros((self.state_height, self.state_width, self.n_stacked_frames), dtype=np.uint8)
        self.state_[:, :, -1] = self.resize()

    def step(self, action: jnp.int8) -> Tuple[float, bool, Dict]:
        """Repeat ``action`` for up to ``n_skipped_frames`` frames.

        Returns:
            The summed reward, the terminal flag of the last emulator step and
            the info dictionary of the last emulator step.
        """
        reward = 0

        for idx_frame in range(self.n_skipped_frames):
            _, reward_, terminal, info = self.env.step(action)

            reward += reward_

            # Only the last two frames of the skip window are captured for pooling.
            if idx_frame >= self.n_skipped_frames - 2:
                t = idx_frame - (self.n_skipped_frames - 2)
                self.env.ale.getScreenGrayscale(self.screen_buffer[t])

            # On early termination the buffers keep whatever they held before.
            if terminal:
                break

        # Drop the oldest frame and append the newly processed one.
        self.state_ = np.roll(self.state_, -1, axis=-1)
        self.state_[:, :, -1] = self.pool_and_resize()

        self.n_steps += 1

        return reward, terminal, info

    def pool_and_resize(self) -> np.ndarray:
        """Max-pool the two screen buffers (in place, into buffer 0) and resize."""
        np.maximum(self.screen_buffer[0], self.screen_buffer[1], out=self.screen_buffer[0])

        return self.resize()

    def resize(self) -> np.ndarray:
        """Resize screen buffer 0 to (state_height, state_width) as uint8."""
        return np.asarray(
            cv2.resize(self.screen_buffer[0], (self.state_width, self.state_height), interpolation=cv2.INTER_AREA),
            dtype=np.uint8,
        )

    def evaluate_one_simulation(
        self,
        q,
        q_params: Dict,
        horizon: int,
        eps_eval: float,
        exploration_key: jax.random.PRNGKey,
        video_path: str,
    ) -> Tuple[float, bool]:
        """Run one epsilon-greedy episode and return its undiscounted return.

        Args:
            q: object exposing ``best_action(params, state, key)``.
            q_params: parameters forwarded to ``q.best_action``.
            horizon: maximum number of agent steps.
            eps_eval: probability of taking a uniformly random action.
            exploration_key: JAX PRNG key driving exploration.
            video_path: path prefix of the recorded video (``.mp4`` is
                appended); pass ``None`` to disable recording.

        Returns:
            A tuple ``(sum of rewards, terminal flag of the last step)``.
        """
        record_video = video_path is not None
        video = video_recorder.VideoRecorder(
            self.env,
            # Avoid handing the literal path "None.mp4" to a disabled recorder.
            path=f"{video_path}.mp4" if record_video else None,
            enabled=record_video,
        )
        sum_reward = 0
        terminal = False
        self.reset()

        try:
            while not terminal and self.n_steps < horizon:
                self.env.render(mode="rgb_array")
                video.capture_frame()

                # One independent key per random decision: a JAX key must not be
                # consumed twice. NOTE: this changes the random stream compared
                # to reusing a single key for the draw and the action.
                exploration_key, eps_key, action_key = jax.random.split(exploration_key, 3)
                if jax.random.uniform(eps_key) < eps_eval:
                    action = jax.random.choice(action_key, jnp.arange(self.n_actions)).astype(jnp.int8)
                else:
                    action = q.best_action(q_params, self.state, action_key)

                reward, terminal, _ = self.step(action)

                sum_reward += reward
        finally:
            # Always release the recorder, even if the rollout raised.
            video.close()

        if record_video:
            os.remove(f"{video_path}.meta.json")

        return sum_reward, terminal
|
evaluate.ipynb
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"%load_ext autoreload\n",
|
10 |
+
"%autoreload 2"
|
11 |
+
]
|
12 |
+
},
|
13 |
+
{
|
14 |
+
"cell_type": "code",
|
15 |
+
"execution_count": null,
|
16 |
+
"metadata": {},
|
17 |
+
"outputs": [],
|
18 |
+
"source": [
|
19 |
+
"import jax\n",
|
20 |
+
"import pickle\n",
|
21 |
+
"from atari import AtariEnv\n",
|
22 |
+
"from networks import AtariiDQN, AtariiIQN\n",
|
23 |
+
"from networks import AtariiIQN\n",
|
24 |
+
"\n",
|
25 |
+
"# ------- START TO MODIFY ------- #\n",
|
26 |
+
"IDQN_ALGO = True # if False then i-IQN is evaluated\n",
|
27 |
+
"GAME = \"Alien\"\n",
|
28 |
+
"NETWORK_SEED = 1 # seed in [1, 2, 3, 4, 5]\n",
|
29 |
+
"EVALUATION_SEED = 0\n",
|
30 |
+
"HORIZON = 27000\n",
|
31 |
+
"ENDING_EPS = 0.01\n",
|
32 |
+
"RECORD_VIDEO = False\n",
|
33 |
+
"\n",
|
34 |
+
"### 56 games are available for i-DQN with 5 seeds each:\n",
|
35 |
+
"# Alien, Amidar, Assault, Asterix, Asteroids, Atlantis, \n",
|
36 |
+
"# BankHeist, BattleZone, BeamRider, Berzerk, Bowling, Boxing, Breakout, Centipede, \n",
|
37 |
+
"# ChopperCommand, CrazyClimber, DemonAttack, DoubleDunk, Enduro, FishingDerby, \n",
|
38 |
+
"# Freeway, Frostbite, Gopher, Gravitar, Hero, IceHockey, Jamesbond, Kangaroo, \n",
|
39 |
+
"# Krull, KungFuMaster, MontezumaRevenge, MsPacman, NameThisGame, Phoenix, Pitfall, \n",
|
40 |
+
"# Pong, Pooyan, PrivateEye, Qbert, Riverraid, RoadRunner, Robotank, Seaquest, Skiing, \n",
|
41 |
+
"# Solaris, SpaceInvaders, StarGunner, Tennis, TimePilot, Tutankham, UpNDown, Venture, \n",
|
42 |
+
"# VideoPinball, WizardOfWor, YarsRevenge, Zaxxon\n",
|
43 |
+
"\n",
|
44 |
+
"## 20 games are available for i-IQN with 5 seeds each:\n",
|
45 |
+
"# Alien, Assault, BankHeist, Berzerk, Breakout, Centipede, \n",
|
46 |
+
"# ChopperCommand, DemonAttack, Enduro, Frostbite, Gopher, \n",
|
47 |
+
"# Gravitar, IceHockey, Jamesbond, Krull, KungFuMaster, \n",
|
48 |
+
"# Riverraid, Seaquest, Skiing, StarGunner\n",
|
49 |
+
"# ------- END TO MODIFY ------- #\n",
|
50 |
+
"\n",
|
51 |
+
"\n",
|
52 |
+
"params_path = f\"parameters/{GAME}/{'iDQN' if IDQN_ALGO else 'iIQN'}/{5 if IDQN_ALGO else 3}_Q_{NETWORK_SEED}_best_online_params\"\n",
|
53 |
+
"\n",
|
54 |
+
"env = AtariEnv(GAME)\n",
|
55 |
+
"\n",
|
56 |
+
"if IDQN_ALGO:\n",
|
57 |
+
" q = AtariiDQN(env.n_actions, idx_head=0) # idx_head in [0, 1, 2, 3, 4, 5]\n",
|
58 |
+
"else:\n",
|
59 |
+
" q = AtariiIQN(env.n_actions, idx_head=0) # idx_head in [0, 1, 2, 3]\n",
|
60 |
+
"\n",
|
61 |
+
"with open(params_path, \"rb\") as handle:\n",
|
62 |
+
" q_params = pickle.load(handle)\n",
|
63 |
+
"\n",
|
64 |
+
"reward, absorbing = env.evaluate_one_simulation(\n",
|
65 |
+
" q, q_params, HORIZON, ENDING_EPS, jax.random.PRNGKey(EVALUATION_SEED), params_path if RECORD_VIDEO else None\n",
|
66 |
+
")\n",
|
67 |
+
"print(\"Undiscounted reward:\", reward)\n",
|
68 |
+
"print(\"N steps\", env.n_steps, \"; Horizon\", HORIZON, \"; Absorbing\", absorbing)"
|
69 |
+
]
|
70 |
+
}
|
71 |
+
],
|
72 |
+
"metadata": {
|
73 |
+
"kernelspec": {
|
74 |
+
"display_name": "env",
|
75 |
+
"language": "python",
|
76 |
+
"name": "python3"
|
77 |
+
},
|
78 |
+
"language_info": {
|
79 |
+
"codemirror_mode": {
|
80 |
+
"name": "ipython",
|
81 |
+
"version": 3
|
82 |
+
},
|
83 |
+
"file_extension": ".py",
|
84 |
+
"mimetype": "text/x-python",
|
85 |
+
"name": "python",
|
86 |
+
"nbconvert_exporter": "python",
|
87 |
+
"pygments_lexer": "ipython3",
|
88 |
+
"version": "3.10.12"
|
89 |
+
}
|
90 |
+
},
|
91 |
+
"nbformat": 4,
|
92 |
+
"nbformat_minor": 2
|
93 |
+
}
|
networks.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from flax.core import FrozenDict
|
2 |
+
import flax.linen as nn
|
3 |
+
import jax
|
4 |
+
import jax.numpy as jnp
|
5 |
+
from functools import partial
|
6 |
+
|
7 |
+
|
8 |
+
# --- Base functions ---
|
9 |
+
|
10 |
+
|
11 |
+
def scale(state: jnp.ndarray) -> jnp.ndarray:
    """Divide raw pixel intensities by 255."""
    max_pixel_value = 255.0
    return state / max_pixel_value
|
13 |
+
|
14 |
+
|
15 |
+
class Torso(nn.Module):
    """Convolutional feature extractor: three Conv + ReLU layers, then flatten.

    ``initialization_type`` selects the kernel initializer: ``"dqn"`` or ``"iqn"``.
    """

    initialization_type: str

    @nn.compact
    def __call__(self, state):
        """Return the flattened convolutional features of ``state``.

        Raises:
            ValueError: if ``initialization_type`` is neither "dqn" nor "iqn".
        """
        if self.initialization_type == "dqn":
            initializer = nn.initializers.variance_scaling(scale=1.0, mode="fan_avg", distribution="truncated_normal")
        elif self.initialization_type == "iqn":
            initializer = nn.initializers.variance_scaling(
                scale=1.0 / jnp.sqrt(3.0), mode="fan_in", distribution="uniform"
            )
        else:
            # Fail with a clear message instead of an UnboundLocalError further down.
            raise ValueError(
                f"Unknown initialization_type {self.initialization_type!r}; expected 'dqn' or 'iqn'."
            )

        x = nn.Conv(features=32, kernel_size=(8, 8), strides=(4, 4), kernel_init=initializer)(state)
        x = nn.relu(x)
        x = nn.Conv(features=64, kernel_size=(4, 4), strides=(2, 2), kernel_init=initializer)(x)
        x = nn.relu(x)
        x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1), kernel_init=initializer)(x)
        x = nn.relu(x)

        # Flattens every axis: assumes a single, unbatched state -- TODO confirm.
        return x.flatten()
|
35 |
+
|
36 |
+
|
37 |
+
class Head(nn.Module):
    """Two-layer dense head: 512 hidden units with ReLU, then ``n_actions`` outputs.

    ``initialization_type`` selects the kernel initializer: ``"dqn"`` or ``"iqn"``.
    """

    n_actions: int
    initialization_type: str

    @nn.compact
    def __call__(self, x):
        """Map features ``x`` to one output per action.

        Raises:
            ValueError: if ``initialization_type`` is neither "dqn" nor "iqn".
        """
        if self.initialization_type == "dqn":
            initializer = nn.initializers.variance_scaling(scale=1.0, mode="fan_avg", distribution="truncated_normal")
        elif self.initialization_type == "iqn":
            initializer = nn.initializers.variance_scaling(
                scale=1.0 / jnp.sqrt(3.0), mode="fan_in", distribution="uniform"
            )
        else:
            # Fail with a clear message instead of an UnboundLocalError further down.
            raise ValueError(
                f"Unknown initialization_type {self.initialization_type!r}; expected 'dqn' or 'iqn'."
            )

        x = nn.Dense(features=512, kernel_init=initializer)(x)
        x = nn.relu(x)

        return nn.Dense(features=self.n_actions, kernel_init=initializer)(x)
|
54 |
+
|
55 |
+
|
56 |
+
class QuantileEmbedding(nn.Module):
    """Embed randomly sampled quantile levels with a cosine basis and a dense layer."""

    n_features: int = 7744
    quantile_embedding_dim: int = 64

    @nn.compact
    def __call__(self, key, n_quantiles):
        """Sample ``n_quantiles`` levels in [0, 1) and embed them.

        Returns:
            A tuple ``(embedding, quantiles)`` of shapes
            ``(n_quantiles, n_features)`` and ``(n_quantiles,)``.
        """
        kernel_init = nn.initializers.variance_scaling(
            scale=1.0 / jnp.sqrt(3.0), mode="fan_in", distribution="uniform"
        )

        # One uniformly sampled level per row: shape (n_quantiles, 1).
        sampled_levels = jax.random.uniform(key, shape=(n_quantiles, 1))
        # Integer multipliers 1..quantile_embedding_dim: shape (1, quantile_embedding_dim).
        multipliers = jnp.arange(1, self.quantile_embedding_dim + 1).reshape((1, self.quantile_embedding_dim))

        cosine_basis = jnp.cos(jnp.pi * sampled_levels @ multipliers)
        embedding = nn.Dense(features=self.n_features, kernel_init=kernel_init)(cosine_basis)

        return (nn.relu(embedding), jnp.squeeze(sampled_levels, axis=1))
|
72 |
+
|
73 |
+
|
74 |
+
# --- i-DQN networks ---
|
75 |
+
|
76 |
+
|
77 |
+
class AtariSharediDQNNet:
    """i-DQN network made of a ``Torso`` and a ``Head`` with per-head parameters.

    Head 0 reads ``torso_params_0``; every other head reads ``torso_params_1``.
    """

    def __init__(self, n_actions: int) -> None:
        self.n_actions = n_actions
        self.n_heads = 5
        self.torso = Torso("dqn")
        self.head = Head(self.n_actions, "dqn")

    def apply(self, params: FrozenDict, idx_head: int, state: jnp.ndarray) -> jnp.ndarray:
        """Return the outputs of head ``idx_head`` for ``state``."""
        # Torso index is clamped to 1: heads >= 1 share the second torso.
        torso_key = f"torso_params_{min(idx_head, 1)}"
        feature = self.torso.apply(params[torso_key], state)

        head_key = f"head_params_{idx_head}"
        return self.head.apply(params[head_key], feature)
|
91 |
+
|
92 |
+
|
93 |
+
class AtariiDQN:
    """Greedy policy over one head of the shared i-DQN network."""

    def __init__(self, n_actions: int, idx_head: int) -> None:
        self.network = AtariSharediDQNNet(n_actions)
        self.idx_head = idx_head

    # `self` is static: the instance is hashed and the method is re-traced per instance.
    @partial(jax.jit, static_argnames="self")
    def best_action(self, params: FrozenDict, state: jnp.ndarray, key: jax.random.PRNGKey) -> jnp.int8:
        """Return the arg-max action of head ``idx_head`` for ``state``.

        ``state`` is divided by 255 before being fed to the network. ``key`` is
        unused here; it is accepted so the signature matches ``AtariiIQN.best_action``.
        The ``key`` annotation uses ``jax.random.PRNGKey`` (as elsewhere in this
        code base) rather than ``jax.random.PRNGKeyArray``, which later JAX
        releases no longer provide and which is evaluated at definition time.
        """
        return jnp.argmax(self.network.apply(params, self.idx_head, scale(state))).astype(jnp.int8)
|
101 |
+
|
102 |
+
|
103 |
+
# --- i-IQN networks ---
|
104 |
+
|
105 |
+
|
106 |
+
class AtariSharediIQNNet:
    """i-IQN network: torso, quantile embedding and head with stacked parameters.

    Each entry of ``params`` holds leaves stacked along their first axis; head 0
    uses index 0 of the torso / quantile parameters, all other heads use index 1.
    """

    def __init__(self, n_actions: int) -> None:
        self.n_actions = n_actions
        self.n_heads = 4
        self.torso = Torso("iqn")
        self.quantile_embedding = QuantileEmbedding()
        self.head = Head(self.n_actions, "iqn")

    def apply(
        self, params: FrozenDict, idx_head: int, state: jnp.ndarray, key: jax.random.PRNGKey, n_quantiles: int
    ) -> jnp.ndarray:
        """Return the quantile outputs of head ``idx_head``, shape (n_quantiles, n_actions)."""

        def select(stacked_params, index):
            # Pick one slice along the leading axis of every parameter leaf.
            return jax.tree_util.tree_map(lambda leaf: leaf[index], stacked_params)

        def modulate(quantile_feature_, state_feature_):
            # Element-wise product of one quantile embedding with the state feature.
            return quantile_feature_ * state_feature_

        # 0 for the first head, 1 for every later head.
        shared_index = jax.lax.cond(idx_head >= 1, lambda: 1, lambda: 0)

        # output (n_features)
        state_feature = self.torso.apply(select(params["torso_params"], shared_index), state)

        # output (n_quantiles, n_features)
        quantiles_feature, _ = self.quantile_embedding.apply(
            select(params["quantiles_params"], shared_index),
            key,
            n_quantiles,
        )

        # mapping over the quantiles | output (n_quantiles, n_features)
        feature = jax.vmap(modulate, in_axes=(0, None))(quantiles_feature, state_feature)

        # output (n_quantiles, n_actions)
        return self.head.apply(select(params["head_params"], idx_head), feature)
|
142 |
+
|
143 |
+
|
144 |
+
class AtariiIQN:
    """Greedy policy over one head of the shared i-IQN network."""

    def __init__(self, n_actions: int, idx_head: int) -> None:
        self.network = AtariSharediIQNNet(n_actions)
        self.idx_head = idx_head
        # Number of quantile samples averaged when selecting an action.
        self.n_quantiles_policy = 32

    # `self` is static: the instance is hashed and the method is re-traced per instance.
    @partial(jax.jit, static_argnames="self")
    def best_action(self, params: FrozenDict, state: jnp.ndarray, key: jax.random.PRNGKey) -> jnp.int8:
        """Return the action maximising the mean over sampled quantiles.

        ``state`` is divided by 255 before being fed to the network and ``key``
        drives the quantile sampling. The ``key`` annotation uses
        ``jax.random.PRNGKey`` (as elsewhere in this code base) rather than
        ``jax.random.PRNGKeyArray``, which later JAX releases no longer provide
        and which is evaluated at definition time.
        """
        # output (n_quantiles, n_actions)
        q_quantiles = self.network.apply(params, self.idx_head, scale(state), key, self.n_quantiles_policy)
        # Average over the quantile axis to get one value per action.
        q_values = jnp.mean(q_quantiles, axis=0)

        return jnp.argmax(q_values).astype(jnp.int8)
|
performances.png
ADDED
requirements.txt
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
absl-py==1.4.0
|
2 |
+
ale-py==0.7.5
|
3 |
+
arch==6.2.0
|
4 |
+
asttokens==2.2.1
|
5 |
+
AutoROM==0.4.2
|
6 |
+
AutoROM.accept-rom-license==0.6.1
|
7 |
+
backcall==0.2.0
|
8 |
+
cached-property==1.5.2
|
9 |
+
certifi==2023.5.7
|
10 |
+
charset-normalizer==3.1.0
|
11 |
+
chex==0.1.8
|
12 |
+
click==8.1.3
|
13 |
+
cloudpickle==2.2.1
|
14 |
+
comm==0.1.3
|
15 |
+
contourpy==1.1.0
|
16 |
+
cycler==0.11.0
|
17 |
+
debugpy==1.6.7
|
18 |
+
decorator==5.1.1
|
19 |
+
dm-tree==0.1.8
|
20 |
+
etils==1.3.0
|
21 |
+
exceptiongroup==1.1.3
|
22 |
+
executing==1.2.0
|
23 |
+
flax==0.6.11
|
24 |
+
fonttools==4.40.0
|
25 |
+
fsspec==2023.9.2
|
26 |
+
gym==0.25.2
|
27 |
+
gym-notices==0.0.8
|
28 |
+
idna==3.4
|
29 |
+
importlib-resources==5.12.0
|
30 |
+
iniconfig==2.0.0
|
31 |
+
ipykernel==6.25.0
|
32 |
+
ipython==8.14.0
|
33 |
+
jax==0.4.13
|
34 |
+
jaxlib==0.4.13
|
35 |
+
jedi==0.18.2
|
36 |
+
jupyter_client==8.3.0
|
37 |
+
jupyter_core==5.3.1
|
38 |
+
kiwisolver==1.4.5
|
39 |
+
markdown-it-py==3.0.0
|
40 |
+
matplotlib==3.7.1
|
41 |
+
matplotlib-inline==0.1.6
|
42 |
+
mdurl==0.1.2
|
43 |
+
ml-dtypes==0.2.0
|
44 |
+
msgpack==1.0.5
|
45 |
+
nest-asyncio==1.5.6
|
46 |
+
numpy==1.23.5
|
47 |
+
nvidia-cublas-cu12==12.2.5.6
|
48 |
+
nvidia-cuda-cupti-cu12==12.2.142
|
49 |
+
nvidia-cuda-nvcc-cu12==12.2.140
|
50 |
+
nvidia-cuda-nvrtc-cu12==12.2.140
|
51 |
+
nvidia-cuda-runtime-cu12==12.2.140
|
52 |
+
nvidia-cudnn-cu12==8.9.4.25
|
53 |
+
nvidia-cufft-cu12==11.0.8.103
|
54 |
+
nvidia-cusolver-cu12==11.5.2.141
|
55 |
+
nvidia-cusparse-cu12==12.1.2.141
|
56 |
+
nvidia-nvjitlink-cu12==12.2.140
|
57 |
+
opencv-python==4.7.0.72
|
58 |
+
opt-einsum==3.3.0
|
59 |
+
optax==0.1.5
|
60 |
+
orbax-checkpoint==0.2.6
|
61 |
+
packaging==23.1
|
62 |
+
pandas==2.0.2
|
63 |
+
parso==0.8.3
|
64 |
+
patsy==0.5.3
|
65 |
+
pexpect==4.8.0
|
66 |
+
pickleshare==0.7.5
|
67 |
+
Pillow==9.5.0
|
68 |
+
platformdirs==3.9.1
|
69 |
+
pluggy==1.3.0
|
70 |
+
prompt-toolkit==3.0.38
|
71 |
+
protobuf==4.23.3
|
72 |
+
psutil==5.9.5
|
73 |
+
ptyprocess==0.7.0
|
74 |
+
pure-eval==0.2.2
|
75 |
+
Pygments==2.15.1
|
76 |
+
pyparsing==3.1.0
|
77 |
+
pytest==7.4.0
|
78 |
+
python-dateutil==2.8.2
|
79 |
+
pytz==2023.3
|
80 |
+
PyYAML==6.0
|
81 |
+
pyzmq==25.1.0
|
82 |
+
requests==2.31.0
|
83 |
+
rich==13.4.2
|
84 |
+
scipy==1.11.0
|
85 |
+
six==1.16.0
|
86 |
+
stack-data==0.6.2
|
87 |
+
statsmodels==0.14.0
|
88 |
+
tensorstore==0.1.39
|
89 |
+
tomli==2.0.1
|
90 |
+
toolz==0.12.0
|
91 |
+
tornado==6.3.2
|
92 |
+
tqdm==4.65.0
|
93 |
+
traitlets==5.9.0
|
94 |
+
typing_extensions==4.6.3
|
95 |
+
tzdata==2023.3
|
96 |
+
urllib3==1.26.16
|
97 |
+
wcwidth==0.2.6
|
98 |
+
zipp==3.17.0
|