TheoVincent committed
Commit 4cb4fc3 · 1 Parent(s): e14ba50

minimal example - code

Files changed (7)
  1. .gitattributes +1 -0
  2. README.md +38 -0
  3. atari.py +126 -0
  4. evaluate.ipynb +93 -0
  5. networks.py +156 -0
  6. performances.png +0 -0
  7. requirements.txt +98 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *best_online_params filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,38 @@
+ # Model parameters trained with `i-DQN` and `i-IQN`
+ This repository contains the model parameters trained with `i-DQN` on [$57$ Atari games](#list-of-games-for-i-dqn) and with `i-IQN` on [$20$ Atari games](#list-of-games-for-i-iqn) 🎮. $5$ seeds are available for each configuration, which makes a total of $385$ available models 📈.
+
+ The [evaluate.ipynb](./evaluate.ipynb) notebook contains a minimal example showing how to evaluate the model parameters 🧑‍🏫. It uses JAX 🚀.
+
+ P.S.: The set of [$20$ Atari games](#list-of-games-for-i-iqn) is a subset of the set of [$57$ Atari games](#list-of-games-for-i-dqn).
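+
+ As a preview, here is a minimal sketch of what the notebook does (the parameter path below follows the notebook's layout; adjust it to your local copy):
+ ```python
+ import pickle
+
+ import jax
+
+ from atari import AtariEnv
+ from networks import AtariiDQN
+
+ env = AtariEnv("Alien")
+ q = AtariiDQN(env.n_actions, idx_head=0)
+
+ # Path layout taken from evaluate.ipynb: parameters/<game>/<algo>/<K>_Q_<seed>_best_online_params
+ with open("parameters/Alien/iDQN/5_Q_1_best_online_params", "rb") as handle:
+     q_params = pickle.load(handle)
+
+ reward, _ = env.evaluate_one_simulation(
+     q, q_params, horizon=27000, eps_eval=0.01, exploration_key=jax.random.PRNGKey(0), video_path=None
+ )
+ print("Undiscounted reward:", reward)
+ ```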
+
+ ### Model performances
+ `i-DQN` and `i-IQN` are improvements over [`DQN`](https://www.nature.com/articles/nature14236.pdf) and [`IQN`](https://arxiv.org/abs/1806.06923) ✨. Check it out on [arXiv](https://arxiv.org/abs/2403.02107)! | <img src="performances.png" alt="drawing" width="600"/>
+ :-:|:-:
+
+
+ ### List of games for `i-DQN`
+ Alien, Amidar, Assault, Asterix, Asteroids, Atlantis, BankHeist, BattleZone, BeamRider, Berzerk, Bowling, Boxing, Breakout, Centipede, ChopperCommand, CrazyClimber, DemonAttack, DoubleDunk, Enduro, FishingDerby, Freeway, Frostbite, Gopher, Gravitar, Hero, IceHockey, Jamesbond, Kangaroo, Krull, KungFuMaster, MontezumaRevenge, MsPacman, NameThisGame, Phoenix, Pitfall, Pong, Pooyan, PrivateEye, Qbert, Riverraid, RoadRunner, Robotank, Seaquest, Skiing, Solaris, SpaceInvaders, StarGunner, Tennis, TimePilot, Tutankham, UpNDown, Venture, VideoPinball, WizardOfWor, YarsRevenge, Zaxxon.
+
+ ### List of games for `i-IQN`
+ Alien, Assault, BankHeist, Berzerk, Breakout, Centipede, ChopperCommand, DemonAttack, Enduro, Frostbite, Gopher, Gravitar, IceHockey, Jamesbond, Krull, KungFuMaster, Riverraid, Seaquest, Skiing, StarGunner.
+
+ ## User installation
+ Python 3.10 is recommended. Create a Python virtual environment, activate it, upgrade pip, and install the dependencies:
+ ```bash
+ python3.10 -m venv env
+ source env/bin/activate
+ pip install --upgrade pip
+ pip install numpy==1.23.5  # to avoid numpy==2.XX
+ pip install -r requirements.txt
+ pip install --upgrade "jax[cuda12_pip]==0.4.13" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+ ```
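+
+ As a quick sanity check (not part of the original instructions), you can verify that the CUDA build of JAX is picked up:
+ ```python
+ import jax
+
+ print(jax.devices())  # expect a GPU device; JAX falls back to CPU otherwise
+ ```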
+
+ ## Citing `i-QN`
+ ```
+ @article{vincent2024iterated,
+   title={Iterated $Q$-Network: Beyond the One-Step Bellman Operator},
+   author={Vincent, Th{\'e}o and Palenicek, Daniel and Belousov, Boris and Peters, Jan and D'Eramo, Carlo},
+   journal={arXiv preprint arXiv:2403.02107},
+   year={2024}
+ }
+ ```
atari.py ADDED
@@ -0,0 +1,126 @@
+ """
+ The environment is inspired by https://github.com/google/dopamine/blob/master/dopamine/discrete_domains/atari_lib.py
+ """
+
+ import os
+ from typing import Tuple, Dict
+ from gym.wrappers.monitoring import video_recorder
+ import gym
+ import numpy as np
+ import jax
+ import jax.numpy as jnp
+ import cv2
+
+
+ class AtariEnv:
+     def __init__(
+         self,
+         name: str,
+     ) -> None:
+         self.name = name
+         self.state_height, self.state_width = (84, 84)
+         self.n_stacked_frames = 4
+         self.n_skipped_frames = 4
+
+         self.env = gym.make(
+             f"ALE/{self.name}-v5",
+             full_action_space=False,
+             frameskip=1,
+             repeat_action_probability=0.25,
+             render_mode="rgb_array",
+         ).env
+
+         self.n_actions = self.env.action_space.n
+         self.original_state_height, self.original_state_width, _ = self.env.observation_space.shape
+         # Two buffers so that the two most recent raw frames can be max-pooled.
+         self.screen_buffer = [
+             np.empty((self.original_state_height, self.original_state_width), dtype=np.uint8),
+             np.empty((self.original_state_height, self.original_state_width), dtype=np.uint8),
+         ]
+
+     @property
+     def observation(self) -> np.ndarray:
+         return np.copy(self.state_[:, :, -1])
+
+     @property
+     def state(self) -> np.ndarray:
+         return jnp.array(self.state_, dtype=jnp.float32)
+
+     def reset(self) -> None:
+         self.env.reset()
+
+         self.n_steps = 0
+
+         self.env.ale.getScreenGrayscale(self.screen_buffer[0])
+         self.screen_buffer[1].fill(0)
+
+         self.state_ = np.zeros((self.state_height, self.state_width, self.n_stacked_frames), dtype=np.uint8)
+         self.state_[:, :, -1] = self.resize()
+
+     def step(self, action: jnp.int8) -> Tuple[float, bool, Dict]:
+         reward = 0
+
+         # Repeat the action over the skipped frames and accumulate the reward.
+         for idx_frame in range(self.n_skipped_frames):
+             _, reward_, terminal, info = self.env.step(action)
+
+             reward += reward_
+
+             # Only the last two raw frames are kept for max-pooling.
+             if idx_frame >= self.n_skipped_frames - 2:
+                 t = idx_frame - (self.n_skipped_frames - 2)
+                 self.env.ale.getScreenGrayscale(self.screen_buffer[t])
+
+             if terminal:
+                 break
+
+         # Shift the frame stack and append the new max-pooled, resized frame.
+         self.state_ = np.roll(self.state_, -1, axis=-1)
+         self.state_[:, :, -1] = self.pool_and_resize()
+
+         self.n_steps += 1
+
+         return reward, terminal, info
+
+     def pool_and_resize(self) -> np.ndarray:
+         # Max over the two most recent raw frames to remove Atari sprite flickering.
+         np.maximum(self.screen_buffer[0], self.screen_buffer[1], out=self.screen_buffer[0])
+
+         return self.resize()
+
+     def resize(self) -> np.ndarray:
+         return np.asarray(
+             cv2.resize(self.screen_buffer[0], (self.state_width, self.state_height), interpolation=cv2.INTER_AREA),
+             dtype=np.uint8,
+         )
+
+     def evaluate_one_simulation(
+         self,
+         q,
+         q_params: Dict,
+         horizon: int,
+         eps_eval: float,
+         exploration_key: jax.random.PRNGKey,
+         video_path: str,
+     ) -> Tuple[float, bool]:
+         # The recorder is disabled when no video path is given.
+         video = video_recorder.VideoRecorder(self.env, path=f"{video_path}.mp4", enabled=video_path is not None)
+         sum_reward = 0
+         terminal = False
+         self.reset()
+
+         while not terminal and self.n_steps < horizon:
+             self.env.render(mode="rgb_array")
+             video.capture_frame()
+
+             # Epsilon-greedy action selection during evaluation.
+             exploration_key, key = jax.random.split(exploration_key)
+             if jax.random.uniform(key) < eps_eval:
+                 action = jax.random.choice(key, jnp.arange(self.n_actions)).astype(jnp.int8)
+             else:
+                 action = q.best_action(q_params, self.state, key)
+
+             reward, terminal, _ = self.step(action)
+
+             sum_reward += reward
+
+         video.close()
+         if video_path is not None:
+             os.remove(f"{video_path}.meta.json")
+
+         return sum_reward, terminal
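
In `step`, each action is repeated over `n_skipped_frames` emulator frames, the last two grayscale screens are max-pooled to remove sprite flickering, and the result is appended to a stack of `n_stacked_frames` processed frames. A standalone sketch to sanity-check the wrapper with random actions (assuming only the `AtariEnv` class above and an installed ALE ROM set):

```python
import numpy as np

from atari import AtariEnv

env = AtariEnv("Pong")
env.reset()

# Random-policy rollout: each step covers 4 emulator frames.
rng = np.random.default_rng(0)
terminal = False
while not terminal and env.n_steps < 100:
    action = rng.integers(env.n_actions)
    reward, terminal, _ = env.step(action)

print(env.state.shape)  # expected: (84, 84, 4)
```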
evaluate.ipynb ADDED
@@ -0,0 +1,93 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "%load_ext autoreload\n",
+     "%autoreload 2"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "import jax\n",
+     "import pickle\n",
+     "from atari import AtariEnv\n",
+     "from networks import AtariiDQN, AtariiIQN\n",
+     "\n",
+     "# ------- START TO MODIFY ------- #\n",
+     "IDQN_ALGO = True  # if False then i-IQN is evaluated\n",
+     "GAME = \"Alien\"\n",
+     "NETWORK_SEED = 1  # seed in [1, 2, 3, 4, 5]\n",
+     "EVALUATION_SEED = 0\n",
+     "HORIZON = 27000\n",
+     "ENDING_EPS = 0.01\n",
+     "RECORD_VIDEO = False\n",
+     "\n",
+     "## 56 games are available for i-DQN with 5 seeds each:\n",
+     "# Alien, Amidar, Assault, Asterix, Asteroids, Atlantis, \n",
+     "# BankHeist, BattleZone, BeamRider, Berzerk, Bowling, Boxing, Breakout, Centipede, \n",
+     "# ChopperCommand, CrazyClimber, DemonAttack, DoubleDunk, Enduro, FishingDerby, \n",
+     "# Freeway, Frostbite, Gopher, Gravitar, Hero, IceHockey, Jamesbond, Kangaroo, \n",
+     "# Krull, KungFuMaster, MontezumaRevenge, MsPacman, NameThisGame, Phoenix, Pitfall, \n",
+     "# Pong, Pooyan, PrivateEye, Qbert, Riverraid, RoadRunner, Robotank, Seaquest, Skiing, \n",
+     "# Solaris, SpaceInvaders, StarGunner, Tennis, TimePilot, Tutankham, UpNDown, Venture, \n",
+     "# VideoPinball, WizardOfWor, YarsRevenge, Zaxxon\n",
+     "\n",
+     "## 20 games are available for i-IQN with 5 seeds each:\n",
+     "# Alien, Assault, BankHeist, Berzerk, Breakout, Centipede, \n",
+     "# ChopperCommand, DemonAttack, Enduro, Frostbite, Gopher, \n",
+     "# Gravitar, IceHockey, Jamesbond, Krull, KungFuMaster, \n",
+     "# Riverraid, Seaquest, Skiing, StarGunner\n",
+     "# ------- END TO MODIFY ------- #\n",
+     "\n",
+     "\n",
+     "params_path = f\"parameters/{GAME}/{'iDQN' if IDQN_ALGO else 'iIQN'}/{5 if IDQN_ALGO else 3}_Q_{NETWORK_SEED}_best_online_params\"\n",
+     "\n",
+     "env = AtariEnv(GAME)\n",
+     "\n",
+     "if IDQN_ALGO:\n",
+     "    q = AtariiDQN(env.n_actions, idx_head=0)  # idx_head in [0, 1, 2, 3, 4, 5]\n",
+     "else:\n",
+     "    q = AtariiIQN(env.n_actions, idx_head=0)  # idx_head in [0, 1, 2, 3]\n",
+     "\n",
+     "with open(params_path, \"rb\") as handle:\n",
+     "    q_params = pickle.load(handle)\n",
+     "\n",
+     "reward, absorbing = env.evaluate_one_simulation(\n",
+     "    q, q_params, HORIZON, ENDING_EPS, jax.random.PRNGKey(EVALUATION_SEED), params_path if RECORD_VIDEO else None\n",
+     ")\n",
+     "print(\"Undiscounted reward:\", reward)\n",
+     "print(\"N steps:\", env.n_steps, \"; Horizon:\", HORIZON, \"; Absorbing:\", absorbing)"
+    ]
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "env",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.10.12"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 2
+ }
networks.py ADDED
@@ -0,0 +1,156 @@
+ from flax.core import FrozenDict
+ import flax.linen as nn
+ import jax
+ import jax.numpy as jnp
+ from functools import partial
+
+
+ # --- Base functions ---
+
+
+ def scale(state: jnp.ndarray) -> jnp.ndarray:
+     # Rescale raw uint8 pixels to [0, 1].
+     return state / 255.0
+
+
+ class Torso(nn.Module):
+     initialization_type: str
+
+     @nn.compact
+     def __call__(self, state):
+         if self.initialization_type == "dqn":
+             initializer = nn.initializers.variance_scaling(scale=1.0, mode="fan_avg", distribution="truncated_normal")
+         elif self.initialization_type == "iqn":
+             initializer = nn.initializers.variance_scaling(
+                 scale=1.0 / jnp.sqrt(3.0), mode="fan_in", distribution="uniform"
+             )
+         else:
+             raise ValueError(f"Unknown initialization type: {self.initialization_type}")
+
+         x = nn.Conv(features=32, kernel_size=(8, 8), strides=(4, 4), kernel_init=initializer)(state)
+         x = nn.relu(x)
+         x = nn.Conv(features=64, kernel_size=(4, 4), strides=(2, 2), kernel_init=initializer)(x)
+         x = nn.relu(x)
+         x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1), kernel_init=initializer)(x)
+         x = nn.relu(x)
+
+         return x.flatten()
+
+
+ class Head(nn.Module):
+     n_actions: int
+     initialization_type: str
+
+     @nn.compact
+     def __call__(self, x):
+         if self.initialization_type == "dqn":
+             initializer = nn.initializers.variance_scaling(scale=1.0, mode="fan_avg", distribution="truncated_normal")
+         elif self.initialization_type == "iqn":
+             initializer = nn.initializers.variance_scaling(
+                 scale=1.0 / jnp.sqrt(3.0), mode="fan_in", distribution="uniform"
+             )
+         else:
+             raise ValueError(f"Unknown initialization type: {self.initialization_type}")
+
+         x = nn.Dense(features=512, kernel_init=initializer)(x)
+         x = nn.relu(x)
+
+         return nn.Dense(features=self.n_actions, kernel_init=initializer)(x)
+
+
+ class QuantileEmbedding(nn.Module):
+     n_features: int = 7744
+     quantile_embedding_dim: int = 64
+
+     @nn.compact
+     def __call__(self, key, n_quantiles):
+         initializer = nn.initializers.variance_scaling(scale=1.0 / jnp.sqrt(3.0), mode="fan_in", distribution="uniform")
+
+         # Sample quantile levels tau ~ U(0, 1) and build cosine features cos(pi * i * tau).
+         quantiles = jax.random.uniform(key, shape=(n_quantiles, 1))
+         arange = jnp.arange(1, self.quantile_embedding_dim + 1).reshape((1, self.quantile_embedding_dim))
+
+         quantile_embedding = nn.Dense(features=self.n_features, kernel_init=initializer)(
+             jnp.cos(jnp.pi * quantiles @ arange)
+         )
+
+         # output (n_quantiles, n_features) | (n_quantiles)
+         return (nn.relu(quantile_embedding), jnp.squeeze(quantiles, axis=1))
+
+
+ # --- i-DQN networks ---
+
+
+ class AtariSharediDQNNet:
+     def __init__(self, n_actions: int) -> None:
+         self.n_heads = 5
+         self.n_actions = n_actions
+         self.torso = Torso("dqn")
+         self.head = Head(self.n_actions, "dqn")
+
+     def apply(self, params: FrozenDict, idx_head: int, state: jnp.ndarray) -> jnp.ndarray:
+         # Head 0 uses its own torso; heads 1 onward share the second torso.
+         feature = self.torso.apply(
+             params[f"torso_params_{min(idx_head, 1)}"],
+             state,
+         )
+
+         return self.head.apply(params[f"head_params_{idx_head}"], feature)
+
+
+ class AtariiDQN:
+     def __init__(self, n_actions: int, idx_head: int) -> None:
+         self.network = AtariSharediDQNNet(n_actions)
+         self.idx_head = idx_head
+
+     @partial(jax.jit, static_argnames="self")
+     def best_action(self, params: FrozenDict, state: jnp.ndarray, key: jax.random.PRNGKeyArray) -> jnp.int8:
+         # The key is unused for i-DQN; it is kept to match the i-IQN interface.
+         return jnp.argmax(self.network.apply(params, self.idx_head, scale(state))).astype(jnp.int8)
+
+
+ # --- i-IQN networks ---
+
+
+ class AtariSharediIQNNet:
+     def __init__(self, n_actions: int) -> None:
+         self.n_heads = 4
+         self.n_actions = n_actions
+         self.torso = Torso("iqn")
+         self.quantile_embedding = QuantileEmbedding()
+         self.head = Head(self.n_actions, "iqn")
+
+     def apply(
+         self, params: FrozenDict, idx_head: int, state: jnp.ndarray, key: jax.random.PRNGKey, n_quantiles: int
+     ) -> jnp.ndarray:
+         # output (n_features) | head 0 uses the first torso, heads 1 onward share the second one
+         state_feature = self.torso.apply(
+             jax.tree_util.tree_map(
+                 lambda param: param[jax.lax.cond(idx_head >= 1, lambda: 1, lambda: 0)], params["torso_params"]
+             ),
+             state,
+         )
+
+         # output (n_quantiles, n_features)
+         quantiles_feature, _ = self.quantile_embedding.apply(
+             jax.tree_util.tree_map(
+                 lambda param: param[jax.lax.cond(idx_head >= 1, lambda: 1, lambda: 0)], params["quantiles_params"]
+             ),
+             key,
+             n_quantiles,
+         )
+
+         # mapping over the quantiles | output (n_quantiles, n_features)
+         feature = jax.vmap(
+             lambda quantile_feature_, state_feature_: quantile_feature_ * state_feature_, in_axes=(0, None)
+         )(quantiles_feature, state_feature)
+
+         return self.head.apply(
+             jax.tree_util.tree_map(lambda param: param[idx_head], params["head_params"]), feature
+         )  # output (n_quantiles, n_actions)
+
+
+ class AtariiIQN:
+     def __init__(self, n_actions: int, idx_head: int) -> None:
+         self.network = AtariSharediIQNNet(n_actions)
+         self.idx_head = idx_head
+         self.n_quantiles_policy = 32
+
+     @partial(jax.jit, static_argnames="self")
+     def best_action(self, params: FrozenDict, state: jnp.ndarray, key: jax.random.PRNGKeyArray) -> jnp.int8:
+         # output (n_quantiles, n_actions)
+         q_quantiles = self.network.apply(params, self.idx_head, scale(state), key, self.n_quantiles_policy)
+         # Q-values are the mean over the quantiles.
+         q_values = jnp.mean(q_quantiles, axis=0)
+
+         return jnp.argmax(q_values).astype(jnp.int8)
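
`QuantileEmbedding` implements the IQN cosine embedding: each sampled quantile level $\tau$ is mapped to features $\cos(\pi i \tau)$ for $i = 1, \dots, 64$ and passed through a Dense layer. A standalone sketch of that computation in plain `jax.numpy` (the random matrix `w` is a hypothetical stand-in for the Dense layer's learned kernel):

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
n_quantiles, embedding_dim, n_features = 32, 64, 7744

# Sample the quantile levels tau ~ U(0, 1), shape (n_quantiles, 1).
quantiles = jax.random.uniform(key, shape=(n_quantiles, 1))

# Cosine features cos(pi * i * tau) for i = 1..embedding_dim.
arange = jnp.arange(1, embedding_dim + 1).reshape((1, embedding_dim))
cos_features = jnp.cos(jnp.pi * quantiles @ arange)  # (n_quantiles, embedding_dim)

# Hypothetical stand-in for the learned Dense kernel.
w = jax.random.normal(key, (embedding_dim, n_features)) / jnp.sqrt(embedding_dim)
quantile_embedding = jax.nn.relu(cos_features @ w)  # (n_quantiles, n_features)

print(quantile_embedding.shape)  # (32, 7744)
```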
performances.png ADDED
requirements.txt ADDED
@@ -0,0 +1,98 @@
+ absl-py==1.4.0
+ ale-py==0.7.5
+ arch==6.2.0
+ asttokens==2.2.1
+ AutoROM==0.4.2
+ AutoROM.accept-rom-license==0.6.1
+ backcall==0.2.0
+ cached-property==1.5.2
+ certifi==2023.5.7
+ charset-normalizer==3.1.0
+ chex==0.1.8
+ click==8.1.3
+ cloudpickle==2.2.1
+ comm==0.1.3
+ contourpy==1.1.0
+ cycler==0.11.0
+ debugpy==1.6.7
+ decorator==5.1.1
+ dm-tree==0.1.8
+ etils==1.3.0
+ exceptiongroup==1.1.3
+ executing==1.2.0
+ flax==0.6.11
+ fonttools==4.40.0
+ fsspec==2023.9.2
+ gym==0.25.2
+ gym-notices==0.0.8
+ idna==3.4
+ importlib-resources==5.12.0
+ iniconfig==2.0.0
+ ipykernel==6.25.0
+ ipython==8.14.0
+ jax==0.4.13
+ jaxlib==0.4.13
+ jedi==0.18.2
+ jupyter_client==8.3.0
+ jupyter_core==5.3.1
+ kiwisolver==1.4.5
+ markdown-it-py==3.0.0
+ matplotlib==3.7.1
+ matplotlib-inline==0.1.6
+ mdurl==0.1.2
+ ml-dtypes==0.2.0
+ msgpack==1.0.5
+ nest-asyncio==1.5.6
+ numpy==1.23.5
+ nvidia-cublas-cu12==12.2.5.6
+ nvidia-cuda-cupti-cu12==12.2.142
+ nvidia-cuda-nvcc-cu12==12.2.140
+ nvidia-cuda-nvrtc-cu12==12.2.140
+ nvidia-cuda-runtime-cu12==12.2.140
+ nvidia-cudnn-cu12==8.9.4.25
+ nvidia-cufft-cu12==11.0.8.103
+ nvidia-cusolver-cu12==11.5.2.141
+ nvidia-cusparse-cu12==12.1.2.141
+ nvidia-nvjitlink-cu12==12.2.140
+ opencv-python==4.7.0.72
+ opt-einsum==3.3.0
+ optax==0.1.5
+ orbax-checkpoint==0.2.6
+ packaging==23.1
+ pandas==2.0.2
+ parso==0.8.3
+ patsy==0.5.3
+ pexpect==4.8.0
+ pickleshare==0.7.5
+ Pillow==9.5.0
+ platformdirs==3.9.1
+ pluggy==1.3.0
+ prompt-toolkit==3.0.38
+ protobuf==4.23.3
+ psutil==5.9.5
+ ptyprocess==0.7.0
+ pure-eval==0.2.2
+ Pygments==2.15.1
+ pyparsing==3.1.0
+ pytest==7.4.0
+ python-dateutil==2.8.2
+ pytz==2023.3
+ PyYAML==6.0
+ pyzmq==25.1.0
+ requests==2.31.0
+ rich==13.4.2
+ scipy==1.11.0
+ six==1.16.0
+ stack-data==0.6.2
+ statsmodels==0.14.0
+ tensorstore==0.1.39
+ tomli==2.0.1
+ toolz==0.12.0
+ tornado==6.3.2
+ tqdm==4.65.0
+ traitlets==5.9.0
+ typing_extensions==4.6.3
+ tzdata==2023.3
+ urllib3==1.26.16
+ wcwidth==0.2.6
+ zipp==3.17.0