qgallouedec HF staff commited on
Commit
02b850c
·
1 Parent(s): 294a607

Upload folder using huggingface_hub

Browse files
.summary/0/events.out.tfevents.1688752637.qgallouedec-MS-7C84 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a6b04de5aa4f26e0cfa099ca361578a8c4daf5333f63956d1f1d3200f814d000
3
+ size 714504
README.md CHANGED
@@ -15,7 +15,7 @@ model-index:
15
  type: basketball-v2
16
  metrics:
17
  - type: mean_reward
18
- value: 4555.14 +/- 18.34
19
  name: mean_reward
20
  verified: false
21
  ---
 
15
  type: basketball-v2
16
  metrics:
17
  - type: mean_reward
18
+ value: 272.23 +/- 43.84
19
  name: mean_reward
20
  verified: false
21
  ---
checkpoint_p0/best_000004120_2109440_reward_284.269.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a3b973299e0d9f378459863001c367eabe1dd45e50803abd73eae7cbf7a09747
3
+ size 98239
checkpoint_p0/checkpoint_000019296_9879552.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f90dd56e35eb6ec733e84b355f4d0791e61e57fd5aa4b7fef8d3fe06edb86ac9
3
+ size 98567
checkpoint_p0/checkpoint_000019544_10006528.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:96e94fefc651dd6d6e0df6fd5a3a88bce68b7b80fb068036902d3b65f69cdd91
3
+ size 98567
config.json CHANGED
@@ -65,7 +65,7 @@
65
  "summaries_use_frameskip": true,
66
  "heartbeat_interval": 20,
67
  "heartbeat_reporting_interval": 180,
68
- "train_for_env_steps": 100000000,
69
  "train_for_seconds": 10000000000,
70
  "save_every_sec": 15,
71
  "keep_checkpoints": 2,
@@ -128,7 +128,7 @@
128
  "wandb_user": "qgallouedec",
129
  "wandb_project": "sample_facotry_metaworld"
130
  },
131
- "git_hash": "aed90d9e164e44f91bab1d70c09fac4dee064031",
132
  "git_repo_name": "https://github.com/huggingface/gia",
133
- "wandb_unique_id": "basketball-v2_20230707_155648_448084"
134
  }
 
65
  "summaries_use_frameskip": true,
66
  "heartbeat_interval": 20,
67
  "heartbeat_reporting_interval": 180,
68
+ "train_for_env_steps": 10000000,
69
  "train_for_seconds": 10000000000,
70
  "save_every_sec": 15,
71
  "keep_checkpoints": 2,
 
128
  "wandb_user": "qgallouedec",
129
  "wandb_project": "sample_facotry_metaworld"
130
  },
131
+ "git_hash": "dda7c2cbaa4c60ae8940e37f69d814d32339d2fa",
132
  "git_repo_name": "https://github.com/huggingface/gia",
133
+ "wandb_unique_id": "basketball-v2_20230707_195715_026260"
134
  }
git.diff CHANGED
@@ -1,3 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  diff --git a/data/envs/metaworld/generate_dataset.py b/data/envs/metaworld/generate_dataset.py
2
  index e21b237..c2b1907 100644
3
  --- a/data/envs/metaworld/generate_dataset.py
@@ -13,20 +423,22 @@ index e21b237..c2b1907 100644
13
  dataset["continuous_observations"][-1].append(observations["obs"].cpu().numpy()[0])
14
  dataset["continuous_actions"][-1].append(actions[0])
15
  diff --git a/data/envs/metaworld/generate_dataset_all.sh b/data/envs/metaworld/generate_dataset_all.sh
16
- index cfdae2f..5db8c4b 100755
17
  --- a/data/envs/metaworld/generate_dataset_all.sh
18
  +++ b/data/envs/metaworld/generate_dataset_all.sh
19
- @@ -2,58 +2,58 @@
 
20
 
21
  ENVS=(
22
- assembly
23
- - basketball
24
- - bin-picking
25
- - box-close
26
- - button-press-topdown
27
- - button-press-topdown-wall
28
- - button-press
29
- - button-press-wall
 
30
  - coffee-button
31
  - coffee-pull
32
  - coffee-push
@@ -69,13 +481,6 @@ index cfdae2f..5db8c4b 100755
69
  - sweep
70
  - window-close
71
  - window-open
72
- + # basketball
73
- + # bin-picking
74
- + # box-close
75
- + # button-press-topdown
76
- + # button-press-topdown-wall
77
- + # button-press
78
- + # button-press-wall
79
  + # coffee-button
80
  + # coffee-pull
81
  + # coffee-push
@@ -127,7 +532,7 @@ index cfdae2f..5db8c4b 100755
127
  + python generate_dataset.py --env $ENV-v2 --experiment $ENV-v2 --train_dir=./train_dir
128
  done
129
  diff --git a/data/envs/metaworld/push_all.sh b/data/envs/metaworld/push_all.sh
130
- index 9d71467..5b05c6d 100755
131
  --- a/data/envs/metaworld/push_all.sh
132
  +++ b/data/envs/metaworld/push_all.sh
133
  @@ -2,57 +2,57 @@
@@ -236,30 +641,114 @@ index 9d71467..5b05c6d 100755
236
 
237
  for ENV in "${ENVS[@]}"; do
238
  - python enjoy.py --algo=APPO --env $ENV-v2 --experiment $ENV-v2 --train_dir=./train_dir --max_num_episodes=10 --push_to_hub --hf_repository=qgallouedec/sample-factory-$ENV-v2 --save_video --no_render --enjoy_script=enjoy --train_script=train --load_checkpoint_kind best
239
- + python enjoy.py --algo=APPO --env $ENV-v2 --experiment $ENV-v2 --train_dir=./train_dir --max_num_episodes=10 --push_to_hub --hf_repository=qgallouedec/$ENV-v2 --save_video --no_render --enjoy_script=enjoy --train_script=train --load_checkpoint_kind best
240
  done
241
- diff --git a/data/envs/metaworld/train_all.sh b/data/envs/metaworld/train_all.sh
242
- index dbf328a..1b3c4c8 100755
243
- --- a/data/envs/metaworld/train_all.sh
244
- +++ b/data/envs/metaworld/train_all.sh
245
- @@ -1,7 +1,7 @@
246
- #!/bin/bash
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
 
248
- ENVS=(
249
- - assembly
250
- + # assembly
251
- basketball
252
- bin-picking
253
- box-close
254
  diff --git a/gia/eval/evaluator.py b/gia/eval/evaluator.py
255
- index 91b645c..196a601 100644
256
  --- a/gia/eval/evaluator.py
257
  +++ b/gia/eval/evaluator.py
258
- @@ -2,14 +2,16 @@ import torch
 
 
 
259
 
260
  from gia.config.arguments import Arguments
261
- from gia.model import GiaModel
262
- +from typing import Optional
263
 
264
 
265
  class Evaluator:
@@ -274,38 +763,31 @@ index 91b645c..196a601 100644
274
  def evaluate(self, model: GiaModel) -> float:
275
  return self._evaluate(model)
276
 
277
- diff --git a/gia/eval/mappings.py b/gia/eval/mappings.py
278
- deleted file mode 100644
279
- index e7ba9d3..0000000
280
- --- a/gia/eval/mappings.py
281
- +++ /dev/null
282
- @@ -1,11 +0,0 @@
283
- -TASK_TO_ENV_MAPPING = {
284
- - "mujoco-ant": "Ant-v4",
285
- - "mujoco-halfcheetah": "HalfCheetah-v4",
286
- - "mujoco-hopper": "Hopper-v4",
287
- - "mujoco-doublependulum": "InvertedDoublePendulum-v4",
288
- - "mujoco-pendulum": "InvertedPendulum-v4",
289
- - "mujoco-reacher": "Reacher-v4",
290
- - "mujoco-swimmer": "Swimmer-v4",
291
- - "mujoco-walker": "Walker2d-v4",
292
- - # Atari etc...
293
- -}
294
- diff --git a/gia/eval/rl/__init__.py b/gia/eval/rl/__init__.py
295
- index 36d890b..85a788d 100644
296
- --- a/gia/eval/rl/__init__.py
297
- +++ b/gia/eval/rl/__init__.py
298
- @@ -1,4 +1,4 @@
299
- from .gym_evaluator import GymEvaluator
300
- +from .envs.core import make
301
 
302
- -
303
- -__all__ = ["GymEvaluator"]
304
- +__all__ = ["GymEvaluator", "make"]
 
 
 
305
  diff --git a/gia/eval/rl/gia_agent.py b/gia/eval/rl/gia_agent.py
306
- index f0d0b9b..04b9637 100644
307
  --- a/gia/eval/rl/gia_agent.py
308
  +++ b/gia/eval/rl/gia_agent.py
 
 
 
 
 
 
 
 
 
309
  @@ -75,6 +75,11 @@ class GiaAgent:
310
  ) -> Tuple[Tuple[Tensor, Tensor], ...]:
311
  return tuple((k[:, :, -self._max_length :], v[:, :, -self._max_length :]) for (k, v) in past_key_values)
@@ -332,16 +814,32 @@ index f8531ee..754c05d 100644
332
 
333
 
334
  diff --git a/gia/eval/rl/rl_evaluator.py b/gia/eval/rl/rl_evaluator.py
335
- index c5cc423..ca0c7da 100644
336
  --- a/gia/eval/rl/rl_evaluator.py
337
  +++ b/gia/eval/rl/rl_evaluator.py
338
- @@ -8,6 +8,9 @@ from gia.eval.rl.gia_agent import GiaAgent
339
 
340
 
341
  class RLEvaluator(Evaluator):
342
  + def __init__(self, args, task):
343
  + super().__init__(args, task)
344
  + self.agent = GiaAgent()
 
345
  def _build_env(self) -> VectorEnv: # TODO: maybe just a gym.Env ?
346
  raise NotImplementedError
347
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diff --git a/data/envs/download_expert_scores.py b/data/envs/download_expert_scores.py
2
+ index 4c3f06b..88b6c45 100644
3
+ --- a/data/envs/download_expert_scores.py
4
+ +++ b/data/envs/download_expert_scores.py
5
+ @@ -12,162 +12,162 @@ from tqdm import tqdm
6
+
7
+
8
+ ENV_NAMES = [
9
+ - "atari-alien",
10
+ - "atari-amidar",
11
+ - "atari-assault",
12
+ - "atari-asterix",
13
+ - "atari-asteroids",
14
+ - "atari-atlantis",
15
+ - "atari-bankheist",
16
+ - "atari-battlezone",
17
+ - "atari-beamrider",
18
+ - "atari-berzerk",
19
+ - "atari-bowling",
20
+ - "atari-boxing",
21
+ - "atari-breakout",
22
+ - "atari-centipede",
23
+ - "atari-choppercommand",
24
+ - "atari-crazyclimber",
25
+ - "atari-defender",
26
+ - "atari-demonattack",
27
+ - "atari-doubledunk",
28
+ - "atari-enduro",
29
+ - "atari-fishingderby",
30
+ - "atari-freeway",
31
+ - "atari-frostbite",
32
+ - "atari-gopher",
33
+ - "atari-gravitar",
34
+ - "atari-hero",
35
+ - "atari-icehockey",
36
+ - "atari-jamesbond",
37
+ - "atari-kangaroo",
38
+ - "atari-krull",
39
+ - "atari-kungfumaster",
40
+ - "atari-montezumarevenge",
41
+ - "atari-mspacman",
42
+ - "atari-namethisgame",
43
+ - "atari-phoenix",
44
+ - "atari-pitfall",
45
+ - "atari-pong",
46
+ - "atari-privateeye",
47
+ - "atari-qbert",
48
+ - "atari-riverraid",
49
+ - "atari-roadrunner",
50
+ - "atari-robotank",
51
+ - "atari-seaquest",
52
+ - "atari-skiing",
53
+ - "atari-solaris",
54
+ - "atari-spaceinvaders",
55
+ - "atari-stargunner",
56
+ - # "atari-surround", # Not in the dataset
57
+ - "atari-tennis",
58
+ - "atari-timepilot",
59
+ - "atari-tutankham",
60
+ - "atari-upndown",
61
+ - "atari-venture",
62
+ - "atari-videopinball",
63
+ - "atari-wizardofwor",
64
+ - "atari-yarsrevenge",
65
+ - "atari-zaxxon",
66
+ - "babyai-action-obj-door",
67
+ - "babyai-blocked-unlock-pickup",
68
+ - "babyai-boss-level-no-unlock",
69
+ - "babyai-boss-level",
70
+ - "babyai-find-obj-s5",
71
+ - "babyai-go-to-door",
72
+ - # "babyai-go-to-imp-unlock", # Not in the dataset
73
+ - "babyai-go-to-local",
74
+ - "babyai-go-to-obj-door",
75
+ - "babyai-go-to-obj",
76
+ - "babyai-go-to-red-ball-grey",
77
+ - "babyai-go-to-red-ball-no-dists",
78
+ - "babyai-go-to-red-ball",
79
+ - "babyai-go-to-red-blue-ball",
80
+ - "babyai-go-to-seq",
81
+ - "babyai-go-to",
82
+ - "babyai-key-corridor",
83
+ - "babyai-key-in-box",
84
+ - "babyai-mini-boss-level",
85
+ - "babyai-move-two-across",
86
+ - "babyai-one-room-s8",
87
+ - "babyai-open-door",
88
+ - "babyai-open-doors-order",
89
+ - "babyai-open-red-door",
90
+ - "babyai-open-two-doors",
91
+ - "babyai-open",
92
+ - "babyai-pickup-above",
93
+ - "babyai-pickup-dist",
94
+ - "babyai-pickup-loc",
95
+ - "babyai-pickup",
96
+ - "babyai-synth-loc",
97
+ - "babyai-synth-seq",
98
+ - "babyai-synth",
99
+ - "babyai-unblock-pickup",
100
+ - "babyai-unlock-local",
101
+ - "babyai-unlock-pickup",
102
+ - # "babyai-unlock-to-unlock", # Not in the dataset
103
+ - # "babyai-unlock", # Not in the dataset
104
+ + # "atari-alien",
105
+ + # "atari-amidar",
106
+ + # "atari-assault",
107
+ + # "atari-asterix",
108
+ + # "atari-asteroids",
109
+ + # "atari-atlantis",
110
+ + # "atari-bankheist",
111
+ + # "atari-battlezone",
112
+ + # "atari-beamrider",
113
+ + # "atari-berzerk",
114
+ + # "atari-bowling",
115
+ + # "atari-boxing",
116
+ + # "atari-breakout",
117
+ + # "atari-centipede",
118
+ + # "atari-choppercommand",
119
+ + # "atari-crazyclimber",
120
+ + # "atari-defender",
121
+ + # "atari-demonattack",
122
+ + # "atari-doubledunk",
123
+ + # "atari-enduro",
124
+ + # "atari-fishingderby",
125
+ + # "atari-freeway",
126
+ + # "atari-frostbite",
127
+ + # "atari-gopher",
128
+ + # "atari-gravitar",
129
+ + # "atari-hero",
130
+ + # "atari-icehockey",
131
+ + # "atari-jamesbond",
132
+ + # "atari-kangaroo",
133
+ + # "atari-krull",
134
+ + # "atari-kungfumaster",
135
+ + # "atari-montezumarevenge",
136
+ + # "atari-mspacman",
137
+ + # "atari-namethisgame",
138
+ + # "atari-phoenix",
139
+ + # "atari-pitfall",
140
+ + # "atari-pong",
141
+ + # "atari-privateeye",
142
+ + # "atari-qbert",
143
+ + # "atari-riverraid",
144
+ + # "atari-roadrunner",
145
+ + # "atari-robotank",
146
+ + # "atari-seaquest",
147
+ + # "atari-skiing",
148
+ + # "atari-solaris",
149
+ + # "atari-spaceinvaders",
150
+ + # "atari-stargunner",
151
+ + # # "atari-surround", # Not in the dataset
152
+ + # "atari-tennis",
153
+ + # "atari-timepilot",
154
+ + # "atari-tutankham",
155
+ + # "atari-upndown",
156
+ + # "atari-venture",
157
+ + # "atari-videopinball",
158
+ + # "atari-wizardofwor",
159
+ + # "atari-yarsrevenge",
160
+ + # "atari-zaxxon",
161
+ + # "babyai-action-obj-door",
162
+ + # "babyai-blocked-unlock-pickup",
163
+ + # "babyai-boss-level-no-unlock",
164
+ + # "babyai-boss-level",
165
+ + # "babyai-find-obj-s5",
166
+ + # "babyai-go-to-door",
167
+ + # # "babyai-go-to-imp-unlock", # Not in the dataset
168
+ + # "babyai-go-to-local",
169
+ + # "babyai-go-to-obj-door",
170
+ + # "babyai-go-to-obj",
171
+ + # "babyai-go-to-red-ball-grey",
172
+ + # "babyai-go-to-red-ball-no-dists",
173
+ + # "babyai-go-to-red-ball",
174
+ + # "babyai-go-to-red-blue-ball",
175
+ + # "babyai-go-to-seq",
176
+ + # "babyai-go-to",
177
+ + # "babyai-key-corridor",
178
+ + # "babyai-key-in-box",
179
+ + # "babyai-mini-boss-level",
180
+ + # "babyai-move-two-across",
181
+ + # "babyai-one-room-s8",
182
+ + # "babyai-open-door",
183
+ + # "babyai-open-doors-order",
184
+ + # "babyai-open-red-door",
185
+ + # "babyai-open-two-doors",
186
+ + # "babyai-open",
187
+ + # "babyai-pickup-above",
188
+ + # "babyai-pickup-dist",
189
+ + # "babyai-pickup-loc",
190
+ + # "babyai-pickup",
191
+ + # "babyai-synth-loc",
192
+ + # "babyai-synth-seq",
193
+ + # "babyai-synth",
194
+ + # "babyai-unblock-pickup",
195
+ + # "babyai-unlock-local",
196
+ + # "babyai-unlock-pickup",
197
+ + # # "babyai-unlock-to-unlock", # Not in the dataset
198
+ + # # "babyai-unlock", # Not in the dataset
199
+ "metaworld-assembly",
200
+ - "metaworld-basketball",
201
+ - "metaworld-bin-picking",
202
+ - "metaworld-box-close",
203
+ - "metaworld-button-press-topdown-wall",
204
+ - "metaworld-button-press-topdown",
205
+ - "metaworld-button-press-wall",
206
+ - "metaworld-button-press",
207
+ - "metaworld-coffee-button",
208
+ - "metaworld-coffee-pull",
209
+ - "metaworld-coffee-push",
210
+ - "metaworld-dial-turn",
211
+ - "metaworld-disassemble",
212
+ - "metaworld-door-close",
213
+ - "metaworld-door-lock",
214
+ - "metaworld-door-open",
215
+ - "metaworld-door-unlock",
216
+ - "metaworld-drawer-close",
217
+ - "metaworld-drawer-open",
218
+ - "metaworld-faucet-close",
219
+ - "metaworld-faucet-open",
220
+ - "metaworld-hammer",
221
+ - "metaworld-hand-insert",
222
+ - "metaworld-handle-press-side",
223
+ - "metaworld-handle-press",
224
+ - "metaworld-handle-pull-side",
225
+ - "metaworld-handle-pull",
226
+ - "metaworld-lever-pull",
227
+ - "metaworld-peg-insert-side",
228
+ - "metaworld-peg-unplug-side",
229
+ - "metaworld-pick-out-of-hole",
230
+ - "metaworld-pick-place-wall",
231
+ - "metaworld-pick-place",
232
+ - "metaworld-plate-slide-back-side",
233
+ - "metaworld-plate-slide-back",
234
+ - "metaworld-plate-slide-side",
235
+ - "metaworld-plate-slide",
236
+ - "metaworld-push-back",
237
+ - "metaworld-push-wall",
238
+ - "metaworld-push",
239
+ - "metaworld-reach-wall",
240
+ - "metaworld-reach",
241
+ - "metaworld-shelf-place",
242
+ - "metaworld-soccer",
243
+ - "metaworld-stick-pull",
244
+ - "metaworld-stick-push",
245
+ - "metaworld-sweep-into",
246
+ - "metaworld-sweep",
247
+ - "metaworld-window-close",
248
+ - "metaworld-window-open",
249
+ - "mujoco-ant",
250
+ - "mujoco-doublependulum",
251
+ - "mujoco-halfcheetah",
252
+ - "mujoco-hopper",
253
+ + # "metaworld-basketball",
254
+ + # "metaworld-bin-picking",
255
+ + # "metaworld-box-close",
256
+ + # "metaworld-button-press-topdown-wall",
257
+ + # "metaworld-button-press-topdown",
258
+ + # "metaworld-button-press-wall",
259
+ + # "metaworld-button-press",
260
+ + # "metaworld-coffee-button",
261
+ + # "metaworld-coffee-pull",
262
+ + # "metaworld-coffee-push",
263
+ + # "metaworld-dial-turn",
264
+ + # "metaworld-disassemble",
265
+ + # "metaworld-door-close",
266
+ + # "metaworld-door-lock",
267
+ + # "metaworld-door-open",
268
+ + # "metaworld-door-unlock",
269
+ + # "metaworld-drawer-close",
270
+ + # "metaworld-drawer-open",
271
+ + # "metaworld-faucet-close",
272
+ + # "metaworld-faucet-open",
273
+ + # "metaworld-hammer",
274
+ + # "metaworld-hand-insert",
275
+ + # "metaworld-handle-press-side",
276
+ + # "metaworld-handle-press",
277
+ + # "metaworld-handle-pull-side",
278
+ + # "metaworld-handle-pull",
279
+ + # "metaworld-lever-pull",
280
+ + # "metaworld-peg-insert-side",
281
+ + # "metaworld-peg-unplug-side",
282
+ + # "metaworld-pick-out-of-hole",
283
+ + # "metaworld-pick-place-wall",
284
+ + # "metaworld-pick-place",
285
+ + # "metaworld-plate-slide-back-side",
286
+ + # "metaworld-plate-slide-back",
287
+ + # "metaworld-plate-slide-side",
288
+ + # "metaworld-plate-slide",
289
+ + # "metaworld-push-back",
290
+ + # "metaworld-push-wall",
291
+ + # "metaworld-push",
292
+ + # "metaworld-reach-wall",
293
+ + # "metaworld-reach",
294
+ + # "metaworld-shelf-place",
295
+ + # "metaworld-soccer",
296
+ + # "metaworld-stick-pull",
297
+ + # "metaworld-stick-push",
298
+ + # "metaworld-sweep-into",
299
+ + # "metaworld-sweep",
300
+ + # "metaworld-window-close",
301
+ + # "metaworld-window-open",
302
+ + # "mujoco-ant",
303
+ + # "mujoco-doublependulum",
304
+ + # "mujoco-halfcheetah",
305
+ + # "mujoco-hopper",
306
+ # "mujoco-humanoid", # Not in the dataset
307
+ - "mujoco-pendulum",
308
+ - # "mujoco-pusher", # Not in the dataset
309
+ - "mujoco-reacher",
310
+ + # "mujoco-pendulum",
311
+ + # # "mujoco-pusher", # Not in the dataset
312
+ + # "mujoco-reacher",
313
+ # "mujoco-standup", # Not in the dataset
314
+ - "mujoco-swimmer",
315
+ - "mujoco-walker",
316
+ + # "mujoco-swimmer",
317
+ + # "mujoco-walker",
318
+ ]
319
+
320
+
321
+ diff --git a/data/envs/metaworld/enjoy.py b/data/envs/metaworld/enjoy.py
322
+ deleted file mode 100644
323
+ index 6ec026b..0000000
324
+ --- a/data/envs/metaworld/enjoy.py
325
+ +++ /dev/null
326
+ @@ -1,84 +0,0 @@
327
+ -import sys
328
+ -from typing import Dict, Optional
329
+ -
330
+ -import gym
331
+ -import metaworld # noqa: F401
332
+ -from sample_factory.cfg.arguments import parse_full_cfg, parse_sf_args
333
+ -from sample_factory.enjoy import enjoy
334
+ -from sample_factory.envs.env_utils import register_env
335
+ -
336
+ -
337
+ -ENV_NAMES = [
338
+ - "assembly-v2",
339
+ - "basketball-v2",
340
+ - "bin-picking-v2",
341
+ - "box-close-v2",
342
+ - "button-press-topdown-v2",
343
+ - "button-press-topdown-wall-v2",
344
+ - "button-press-v2",
345
+ - "button-press-wall-v2",
346
+ - "coffee-button-v2",
347
+ - "coffee-pull-v2",
348
+ - "coffee-push-v2",
349
+ - "dial-turn-v2",
350
+ - "disassemble-v2",
351
+ - "door-close-v2",
352
+ - "door-lock-v2",
353
+ - "door-open-v2",
354
+ - "door-unlock-v2",
355
+ - "drawer-close-v2",
356
+ - "drawer-open-v2",
357
+ - "faucet-close-v2",
358
+ - "faucet-open-v2",
359
+ - "hammer-v2",
360
+ - "hand-insert-v2",
361
+ - "handle-press-side-v2",
362
+ - "handle-press-v2",
363
+ - "handle-pull-side-v2",
364
+ - "handle-pull-v2",
365
+ - "lever-pull-v2",
366
+ - "peg-insert-side-v2",
367
+ - "peg-unplug-side-v2",
368
+ - "pick-out-of-hole-v2",
369
+ - "pick-place-v2",
370
+ - "pick-place-wall-v2",
371
+ - "plate-slide-back-side-v2",
372
+ - "plate-slide-back-v2",
373
+ - "plate-slide-side-v2",
374
+ - "plate-slide-v2",
375
+ - "push-back-v2",
376
+ - "push-v2",
377
+ - "push-wall-v2",
378
+ - "reach-v2",
379
+ - "reach-wall-v2",
380
+ - "shelf-place-v2",
381
+ - "soccer-v2",
382
+ - "stick-pull-v2",
383
+ - "stick-push-v2",
384
+ - "sweep-into-v2",
385
+ - "sweep-v2",
386
+ - "window-close-v2",
387
+ - "window-open-v2",
388
+ -]
389
+ -
390
+ -
391
+ -def make_custom_env(
392
+ - full_env_name: str,
393
+ - cfg: Optional[Dict] = None,
394
+ - env_config: Optional[Dict] = None,
395
+ - render_mode: Optional[str] = None,
396
+ -) -> gym.Env:
397
+ - return gym.make(full_env_name, render_mode=render_mode)
398
+ -
399
+ -
400
+ -def main() -> int:
401
+ - for env_name in ENV_NAMES:
402
+ - register_env(env_name, make_custom_env)
403
+ - parser, _ = parse_sf_args(argv=None, evaluation=True)
404
+ - cfg = parse_full_cfg(parser)
405
+ - status = enjoy(cfg)
406
+ - return status
407
+ -
408
+ -
409
+ -if __name__ == "__main__":
410
+ - sys.exit(main())
411
  diff --git a/data/envs/metaworld/generate_dataset.py b/data/envs/metaworld/generate_dataset.py
412
  index e21b237..c2b1907 100644
413
  --- a/data/envs/metaworld/generate_dataset.py
 
423
  dataset["continuous_observations"][-1].append(observations["obs"].cpu().numpy()[0])
424
  dataset["continuous_actions"][-1].append(actions[0])
425
  diff --git a/data/envs/metaworld/generate_dataset_all.sh b/data/envs/metaworld/generate_dataset_all.sh
426
+ index cfdae2f..8720089 100755
427
  --- a/data/envs/metaworld/generate_dataset_all.sh
428
  +++ b/data/envs/metaworld/generate_dataset_all.sh
429
+ @@ -1,7 +1,7 @@
430
+ #!/bin/bash
431
 
432
  ENVS=(
433
+ - assembly
434
+ + # assembly
435
+ basketball
436
+ bin-picking
437
+ box-close
438
+ @@ -9,51 +9,51 @@ ENVS=(
439
+ button-press-topdown-wall
440
+ button-press
441
+ button-press-wall
442
  - coffee-button
443
  - coffee-pull
444
  - coffee-push
 
481
  - sweep
482
  - window-close
483
  - window-open
 
 
 
 
 
 
 
484
  + # coffee-button
485
  + # coffee-pull
486
  + # coffee-push
 
532
  + python generate_dataset.py --env $ENV-v2 --experiment $ENV-v2 --train_dir=./train_dir
533
  done
534
  diff --git a/data/envs/metaworld/push_all.sh b/data/envs/metaworld/push_all.sh
535
+ index 9d71467..4fc1fc2 100755
536
  --- a/data/envs/metaworld/push_all.sh
537
  +++ b/data/envs/metaworld/push_all.sh
538
  @@ -2,57 +2,57 @@
 
641
 
642
  for ENV in "${ENVS[@]}"; do
643
  - python enjoy.py --algo=APPO --env $ENV-v2 --experiment $ENV-v2 --train_dir=./train_dir --max_num_episodes=10 --push_to_hub --hf_repository=qgallouedec/sample-factory-$ENV-v2 --save_video --no_render --enjoy_script=enjoy --train_script=train --load_checkpoint_kind best
644
+ + python push.py --algo=APPO --env $ENV-v2 --experiment $ENV-v2 --train_dir=./train_dir --max_num_episodes=10 --push_to_hub --hf_repository=qgallouedec/$ENV-v2 --save_video --no_render --enjoy_script=enjoy --train_script=train --load_checkpoint_kind best
645
  done
646
+ diff --git a/data/envs/metaworld/train.py b/data/envs/metaworld/train.py
647
+ index 46dc581..095414e 100644
648
+ --- a/data/envs/metaworld/train.py
649
+ +++ b/data/envs/metaworld/train.py
650
+ @@ -2,67 +2,13 @@ import argparse
651
+ import sys
652
+ from typing import Dict, Optional
653
+
654
+ -import gym
655
+ +import gymnasium as gym
656
+ import metaworld # noqa: F401
657
+ from sample_factory.cfg.arguments import parse_full_cfg, parse_sf_args
658
+ from sample_factory.envs.env_utils import register_env
659
+ from sample_factory.train import run_rl
660
+
661
+
662
+ -ENV_NAMES = [
663
+ - "assembly-v2",
664
+ - "basketball-v2",
665
+ - "bin-picking-v2",
666
+ - "box-close-v2",
667
+ - "button-press-topdown-v2",
668
+ - "button-press-topdown-wall-v2",
669
+ - "button-press-v2",
670
+ - "button-press-wall-v2",
671
+ - "coffee-button-v2",
672
+ - "coffee-pull-v2",
673
+ - "coffee-push-v2",
674
+ - "dial-turn-v2",
675
+ - "disassemble-v2",
676
+ - "door-close-v2",
677
+ - "door-lock-v2",
678
+ - "door-open-v2",
679
+ - "door-unlock-v2",
680
+ - "drawer-close-v2",
681
+ - "drawer-open-v2",
682
+ - "faucet-close-v2",
683
+ - "faucet-open-v2",
684
+ - "hammer-v2",
685
+ - "hand-insert-v2",
686
+ - "handle-press-side-v2",
687
+ - "handle-press-v2",
688
+ - "handle-pull-side-v2",
689
+ - "handle-pull-v2",
690
+ - "lever-pull-v2",
691
+ - "peg-insert-side-v2",
692
+ - "peg-unplug-side-v2",
693
+ - "pick-out-of-hole-v2",
694
+ - "pick-place-v2",
695
+ - "pick-place-wall-v2",
696
+ - "plate-slide-back-side-v2",
697
+ - "plate-slide-back-v2",
698
+ - "plate-slide-side-v2",
699
+ - "plate-slide-v2",
700
+ - "push-back-v2",
701
+ - "push-v2",
702
+ - "push-wall-v2",
703
+ - "reach-v2",
704
+ - "reach-wall-v2",
705
+ - "shelf-place-v2",
706
+ - "soccer-v2",
707
+ - "stick-pull-v2",
708
+ - "stick-push-v2",
709
+ - "sweep-into-v2",
710
+ - "sweep-v2",
711
+ - "window-close-v2",
712
+ - "window-open-v2",
713
+ -]
714
+ -
715
+ -
716
+ def make_custom_env(
717
+ full_env_name: str,
718
+ cfg: Optional[Dict] = None,
719
+ @@ -79,7 +25,7 @@ def override_defaults(parser: argparse.ArgumentParser) -> argparse.ArgumentParse
720
+ num_workers=8,
721
+ num_envs_per_worker=8,
722
+ worker_num_splits=2,
723
+ - train_for_env_steps=100_000_000,
724
+ + train_for_env_steps=10_000_000,
725
+ encoder_mlp_layers=[64, 64],
726
+ env_frameskip=1,
727
+ nonlinearity="tanh",
728
+ @@ -116,11 +62,10 @@ def override_defaults(parser: argparse.ArgumentParser) -> argparse.ArgumentParse
729
+
730
+
731
+ def main() -> int:
732
+ - for env_name in ENV_NAMES:
733
+ - register_env(env_name, make_custom_env)
734
+ parser, _ = parse_sf_args(argv=None, evaluation=False)
735
+ parser = override_defaults(parser)
736
+ cfg = parse_full_cfg(parser)
737
+ + register_env(cfg.env, make_custom_env)
738
+ status = run_rl(cfg)
739
+ return status
740
 
 
 
 
 
 
 
741
  diff --git a/gia/eval/evaluator.py b/gia/eval/evaluator.py
742
+ index 91b645c..3e2cae7 100644
743
  --- a/gia/eval/evaluator.py
744
  +++ b/gia/eval/evaluator.py
745
+ @@ -1,3 +1,5 @@
746
+ +from typing import Optional
747
+ +
748
+ import torch
749
 
750
  from gia.config.arguments import Arguments
751
+ @@ -5,11 +7,12 @@ from gia.model import GiaModel
 
752
 
753
 
754
  class Evaluator:
 
763
  def evaluate(self, model: GiaModel) -> float:
764
  return self._evaluate(model)
765
 
766
+ diff --git a/gia/eval/rl/envs/core.py b/gia/eval/rl/envs/core.py
767
+ index f1f83f5..3e8e182 100644
768
+ --- a/gia/eval/rl/envs/core.py
769
+ +++ b/gia/eval/rl/envs/core.py
770
+ @@ -177,6 +177,7 @@ def make(task_name: str, num_envs: int = 1):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
771
 
772
+ elif task_name.startswith("metaworld"):
773
+ import gym
774
+ + import metaworld
775
+
776
+ env_id = TASK_TO_ENV_MAPPING[task_name]
777
+ env = gym.vector.SyncVectorEnv([lambda: gym.make(env_id)] * num_envs)
778
  diff --git a/gia/eval/rl/gia_agent.py b/gia/eval/rl/gia_agent.py
779
+ index f0d0b9b..ca37721 100644
780
  --- a/gia/eval/rl/gia_agent.py
781
  +++ b/gia/eval/rl/gia_agent.py
782
+ @@ -9,7 +9,7 @@ from gia.datasets import GiaDataCollator, Prompter
783
+ from gia.model.gia_model import GiaModel
784
+ from gia.processing import GiaProcessor
785
+
786
+ -
787
+ +import sample_factory.envs.env_utils
788
+ class GiaAgent:
789
+ r"""
790
+ An RL agent that uses Gia to generate actions.
791
  @@ -75,6 +75,11 @@ class GiaAgent:
792
  ) -> Tuple[Tuple[Tensor, Tensor], ...]:
793
  return tuple((k[:, :, -self._max_length :], v[:, :, -self._max_length :]) for (k, v) in past_key_values)
 
814
 
815
 
816
  diff --git a/gia/eval/rl/rl_evaluator.py b/gia/eval/rl/rl_evaluator.py
817
+ index c5cc423..91189f3 100644
818
  --- a/gia/eval/rl/rl_evaluator.py
819
  +++ b/gia/eval/rl/rl_evaluator.py
820
+ @@ -8,6 +8,10 @@ from gia.eval.rl.gia_agent import GiaAgent
821
 
822
 
823
  class RLEvaluator(Evaluator):
824
  + def __init__(self, args, task):
825
  + super().__init__(args, task)
826
  + self.agent = GiaAgent()
827
+ +
828
  def _build_env(self) -> VectorEnv: # TODO: maybe just a gym.Env ?
829
  raise NotImplementedError
830
 
831
+ diff --git a/gia/eval/rl/scores_dict.json b/gia/eval/rl/scores_dict.json
832
+ index 1b8ebee..ff7d030 100644
833
+ --- a/gia/eval/rl/scores_dict.json
834
+ +++ b/gia/eval/rl/scores_dict.json
835
+ @@ -929,8 +929,8 @@
836
+ },
837
+ "metaworld-assembly": {
838
+ "expert": {
839
+ - "mean": 311.29314618777823,
840
+ - "std": 75.04282151450695
841
+ + "mean": 3523.81468486244,
842
+ + "std": 63.22745220327798
843
+ },
844
+ "random": {
845
+ "mean": 220.65601680730813,
replay.mp4 CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:23c9f230036c64e975bc021b96a313ba7f4bc9ad230969145be81a3b4a7e2471
3
- size 2553079
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:53a34a5592b2bcc2c6dd205134780353b33455a576ec6ff7a1b3b6c94844a9ed
3
+ size 889542
sf_log.txt CHANGED
The diff for this file is too large to render. See raw diff