qgallouedec HF staff committed on
Commit
51dac85
·
1 Parent(s): 02b850c

Upload folder using huggingface_hub

Browse files
.summary/0/events.out.tfevents.1688812107.qgallouedec-MS-7C84 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:06126be73fe0997e1b04084dc35ecfdc4a118296b435860a4a7cf8de71d7ab38
3
+ size 789833
README.md CHANGED
@@ -15,7 +15,7 @@ model-index:
15
  type: basketball-v2
16
  metrics:
17
  - type: mean_reward
18
- value: 272.23 +/- 43.84
19
  name: mean_reward
20
  verified: false
21
  ---
 
15
  type: basketball-v2
16
  metrics:
17
  - type: mean_reward
18
+ value: 615.55 +/- 1.99
19
  name: mean_reward
20
  verified: false
21
  ---
checkpoint_p0/best_000018776_9613312_reward_616.234.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c54cc819a55cb62efead40dd0754b4cba247173f1d989eae3b1adfc0a0e423af
3
+ size 98239
checkpoint_p0/checkpoint_000019472_9969664.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3c1568eba462f89979296f8791d7135da493d6b7213512e79dc365f6b928cc2b
3
+ size 98567
checkpoint_p0/checkpoint_000019544_10006528.pth CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:96e94fefc651dd6d6e0df6fd5a3a88bce68b7b80fb068036902d3b65f69cdd91
3
  size 98567
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8adddde812daca766e730c2ee83d4bb2742415959dc8ba82bab4fa96d6bfb561
3
  size 98567
config.json CHANGED
@@ -128,7 +128,7 @@
128
  "wandb_user": "qgallouedec",
129
  "wandb_project": "sample_facotry_metaworld"
130
  },
131
- "git_hash": "dda7c2cbaa4c60ae8940e37f69d814d32339d2fa",
132
  "git_repo_name": "https://github.com/huggingface/gia",
133
- "wandb_unique_id": "basketball-v2_20230707_195715_026260"
134
  }
 
128
  "wandb_user": "qgallouedec",
129
  "wandb_project": "sample_facotry_metaworld"
130
  },
131
+ "git_hash": "66db1b7a27030aa65fcfa2d6e3503089a7cff207",
132
  "git_repo_name": "https://github.com/huggingface/gia",
133
+ "wandb_unique_id": "basketball-v2_20230708_122825_084480"
134
  }
git.diff CHANGED
@@ -1,742 +1,18 @@
1
- diff --git a/data/envs/download_expert_scores.py b/data/envs/download_expert_scores.py
2
- index 4c3f06b..88b6c45 100644
3
- --- a/data/envs/download_expert_scores.py
4
- +++ b/data/envs/download_expert_scores.py
5
- @@ -12,162 +12,162 @@ from tqdm import tqdm
6
-
7
-
8
- ENV_NAMES = [
9
- - "atari-alien",
10
- - "atari-amidar",
11
- - "atari-assault",
12
- - "atari-asterix",
13
- - "atari-asteroids",
14
- - "atari-atlantis",
15
- - "atari-bankheist",
16
- - "atari-battlezone",
17
- - "atari-beamrider",
18
- - "atari-berzerk",
19
- - "atari-bowling",
20
- - "atari-boxing",
21
- - "atari-breakout",
22
- - "atari-centipede",
23
- - "atari-choppercommand",
24
- - "atari-crazyclimber",
25
- - "atari-defender",
26
- - "atari-demonattack",
27
- - "atari-doubledunk",
28
- - "atari-enduro",
29
- - "atari-fishingderby",
30
- - "atari-freeway",
31
- - "atari-frostbite",
32
- - "atari-gopher",
33
- - "atari-gravitar",
34
- - "atari-hero",
35
- - "atari-icehockey",
36
- - "atari-jamesbond",
37
- - "atari-kangaroo",
38
- - "atari-krull",
39
- - "atari-kungfumaster",
40
- - "atari-montezumarevenge",
41
- - "atari-mspacman",
42
- - "atari-namethisgame",
43
- - "atari-phoenix",
44
- - "atari-pitfall",
45
- - "atari-pong",
46
- - "atari-privateeye",
47
- - "atari-qbert",
48
- - "atari-riverraid",
49
- - "atari-roadrunner",
50
- - "atari-robotank",
51
- - "atari-seaquest",
52
- - "atari-skiing",
53
- - "atari-solaris",
54
- - "atari-spaceinvaders",
55
- - "atari-stargunner",
56
- - # "atari-surround", # Not in the dataset
57
- - "atari-tennis",
58
- - "atari-timepilot",
59
- - "atari-tutankham",
60
- - "atari-upndown",
61
- - "atari-venture",
62
- - "atari-videopinball",
63
- - "atari-wizardofwor",
64
- - "atari-yarsrevenge",
65
- - "atari-zaxxon",
66
- - "babyai-action-obj-door",
67
- - "babyai-blocked-unlock-pickup",
68
- - "babyai-boss-level-no-unlock",
69
- - "babyai-boss-level",
70
- - "babyai-find-obj-s5",
71
- - "babyai-go-to-door",
72
- - # "babyai-go-to-imp-unlock", # Not in the dataset
73
- - "babyai-go-to-local",
74
- - "babyai-go-to-obj-door",
75
- - "babyai-go-to-obj",
76
- - "babyai-go-to-red-ball-grey",
77
- - "babyai-go-to-red-ball-no-dists",
78
- - "babyai-go-to-red-ball",
79
- - "babyai-go-to-red-blue-ball",
80
- - "babyai-go-to-seq",
81
- - "babyai-go-to",
82
- - "babyai-key-corridor",
83
- - "babyai-key-in-box",
84
- - "babyai-mini-boss-level",
85
- - "babyai-move-two-across",
86
- - "babyai-one-room-s8",
87
- - "babyai-open-door",
88
- - "babyai-open-doors-order",
89
- - "babyai-open-red-door",
90
- - "babyai-open-two-doors",
91
- - "babyai-open",
92
- - "babyai-pickup-above",
93
- - "babyai-pickup-dist",
94
- - "babyai-pickup-loc",
95
- - "babyai-pickup",
96
- - "babyai-synth-loc",
97
- - "babyai-synth-seq",
98
- - "babyai-synth",
99
- - "babyai-unblock-pickup",
100
- - "babyai-unlock-local",
101
- - "babyai-unlock-pickup",
102
- - # "babyai-unlock-to-unlock", # Not in the dataset
103
- - # "babyai-unlock", # Not in the dataset
104
- + # "atari-alien",
105
- + # "atari-amidar",
106
- + # "atari-assault",
107
- + # "atari-asterix",
108
- + # "atari-asteroids",
109
- + # "atari-atlantis",
110
- + # "atari-bankheist",
111
- + # "atari-battlezone",
112
- + # "atari-beamrider",
113
- + # "atari-berzerk",
114
- + # "atari-bowling",
115
- + # "atari-boxing",
116
- + # "atari-breakout",
117
- + # "atari-centipede",
118
- + # "atari-choppercommand",
119
- + # "atari-crazyclimber",
120
- + # "atari-defender",
121
- + # "atari-demonattack",
122
- + # "atari-doubledunk",
123
- + # "atari-enduro",
124
- + # "atari-fishingderby",
125
- + # "atari-freeway",
126
- + # "atari-frostbite",
127
- + # "atari-gopher",
128
- + # "atari-gravitar",
129
- + # "atari-hero",
130
- + # "atari-icehockey",
131
- + # "atari-jamesbond",
132
- + # "atari-kangaroo",
133
- + # "atari-krull",
134
- + # "atari-kungfumaster",
135
- + # "atari-montezumarevenge",
136
- + # "atari-mspacman",
137
- + # "atari-namethisgame",
138
- + # "atari-phoenix",
139
- + # "atari-pitfall",
140
- + # "atari-pong",
141
- + # "atari-privateeye",
142
- + # "atari-qbert",
143
- + # "atari-riverraid",
144
- + # "atari-roadrunner",
145
- + # "atari-robotank",
146
- + # "atari-seaquest",
147
- + # "atari-skiing",
148
- + # "atari-solaris",
149
- + # "atari-spaceinvaders",
150
- + # "atari-stargunner",
151
- + # # "atari-surround", # Not in the dataset
152
- + # "atari-tennis",
153
- + # "atari-timepilot",
154
- + # "atari-tutankham",
155
- + # "atari-upndown",
156
- + # "atari-venture",
157
- + # "atari-videopinball",
158
- + # "atari-wizardofwor",
159
- + # "atari-yarsrevenge",
160
- + # "atari-zaxxon",
161
- + # "babyai-action-obj-door",
162
- + # "babyai-blocked-unlock-pickup",
163
- + # "babyai-boss-level-no-unlock",
164
- + # "babyai-boss-level",
165
- + # "babyai-find-obj-s5",
166
- + # "babyai-go-to-door",
167
- + # # "babyai-go-to-imp-unlock", # Not in the dataset
168
- + # "babyai-go-to-local",
169
- + # "babyai-go-to-obj-door",
170
- + # "babyai-go-to-obj",
171
- + # "babyai-go-to-red-ball-grey",
172
- + # "babyai-go-to-red-ball-no-dists",
173
- + # "babyai-go-to-red-ball",
174
- + # "babyai-go-to-red-blue-ball",
175
- + # "babyai-go-to-seq",
176
- + # "babyai-go-to",
177
- + # "babyai-key-corridor",
178
- + # "babyai-key-in-box",
179
- + # "babyai-mini-boss-level",
180
- + # "babyai-move-two-across",
181
- + # "babyai-one-room-s8",
182
- + # "babyai-open-door",
183
- + # "babyai-open-doors-order",
184
- + # "babyai-open-red-door",
185
- + # "babyai-open-two-doors",
186
- + # "babyai-open",
187
- + # "babyai-pickup-above",
188
- + # "babyai-pickup-dist",
189
- + # "babyai-pickup-loc",
190
- + # "babyai-pickup",
191
- + # "babyai-synth-loc",
192
- + # "babyai-synth-seq",
193
- + # "babyai-synth",
194
- + # "babyai-unblock-pickup",
195
- + # "babyai-unlock-local",
196
- + # "babyai-unlock-pickup",
197
- + # # "babyai-unlock-to-unlock", # Not in the dataset
198
- + # # "babyai-unlock", # Not in the dataset
199
- "metaworld-assembly",
200
- - "metaworld-basketball",
201
- - "metaworld-bin-picking",
202
- - "metaworld-box-close",
203
- - "metaworld-button-press-topdown-wall",
204
- - "metaworld-button-press-topdown",
205
- - "metaworld-button-press-wall",
206
- - "metaworld-button-press",
207
- - "metaworld-coffee-button",
208
- - "metaworld-coffee-pull",
209
- - "metaworld-coffee-push",
210
- - "metaworld-dial-turn",
211
- - "metaworld-disassemble",
212
- - "metaworld-door-close",
213
- - "metaworld-door-lock",
214
- - "metaworld-door-open",
215
- - "metaworld-door-unlock",
216
- - "metaworld-drawer-close",
217
- - "metaworld-drawer-open",
218
- - "metaworld-faucet-close",
219
- - "metaworld-faucet-open",
220
- - "metaworld-hammer",
221
- - "metaworld-hand-insert",
222
- - "metaworld-handle-press-side",
223
- - "metaworld-handle-press",
224
- - "metaworld-handle-pull-side",
225
- - "metaworld-handle-pull",
226
- - "metaworld-lever-pull",
227
- - "metaworld-peg-insert-side",
228
- - "metaworld-peg-unplug-side",
229
- - "metaworld-pick-out-of-hole",
230
- - "metaworld-pick-place-wall",
231
- - "metaworld-pick-place",
232
- - "metaworld-plate-slide-back-side",
233
- - "metaworld-plate-slide-back",
234
- - "metaworld-plate-slide-side",
235
- - "metaworld-plate-slide",
236
- - "metaworld-push-back",
237
- - "metaworld-push-wall",
238
- - "metaworld-push",
239
- - "metaworld-reach-wall",
240
- - "metaworld-reach",
241
- - "metaworld-shelf-place",
242
- - "metaworld-soccer",
243
- - "metaworld-stick-pull",
244
- - "metaworld-stick-push",
245
- - "metaworld-sweep-into",
246
- - "metaworld-sweep",
247
- - "metaworld-window-close",
248
- - "metaworld-window-open",
249
- - "mujoco-ant",
250
- - "mujoco-doublependulum",
251
- - "mujoco-halfcheetah",
252
- - "mujoco-hopper",
253
- + # "metaworld-basketball",
254
- + # "metaworld-bin-picking",
255
- + # "metaworld-box-close",
256
- + # "metaworld-button-press-topdown-wall",
257
- + # "metaworld-button-press-topdown",
258
- + # "metaworld-button-press-wall",
259
- + # "metaworld-button-press",
260
- + # "metaworld-coffee-button",
261
- + # "metaworld-coffee-pull",
262
- + # "metaworld-coffee-push",
263
- + # "metaworld-dial-turn",
264
- + # "metaworld-disassemble",
265
- + # "metaworld-door-close",
266
- + # "metaworld-door-lock",
267
- + # "metaworld-door-open",
268
- + # "metaworld-door-unlock",
269
- + # "metaworld-drawer-close",
270
- + # "metaworld-drawer-open",
271
- + # "metaworld-faucet-close",
272
- + # "metaworld-faucet-open",
273
- + # "metaworld-hammer",
274
- + # "metaworld-hand-insert",
275
- + # "metaworld-handle-press-side",
276
- + # "metaworld-handle-press",
277
- + # "metaworld-handle-pull-side",
278
- + # "metaworld-handle-pull",
279
- + # "metaworld-lever-pull",
280
- + # "metaworld-peg-insert-side",
281
- + # "metaworld-peg-unplug-side",
282
- + # "metaworld-pick-out-of-hole",
283
- + # "metaworld-pick-place-wall",
284
- + # "metaworld-pick-place",
285
- + # "metaworld-plate-slide-back-side",
286
- + # "metaworld-plate-slide-back",
287
- + # "metaworld-plate-slide-side",
288
- + # "metaworld-plate-slide",
289
- + # "metaworld-push-back",
290
- + # "metaworld-push-wall",
291
- + # "metaworld-push",
292
- + # "metaworld-reach-wall",
293
- + # "metaworld-reach",
294
- + # "metaworld-shelf-place",
295
- + # "metaworld-soccer",
296
- + # "metaworld-stick-pull",
297
- + # "metaworld-stick-push",
298
- + # "metaworld-sweep-into",
299
- + # "metaworld-sweep",
300
- + # "metaworld-window-close",
301
- + # "metaworld-window-open",
302
- + # "mujoco-ant",
303
- + # "mujoco-doublependulum",
304
- + # "mujoco-halfcheetah",
305
- + # "mujoco-hopper",
306
- # "mujoco-humanoid", # Not in the dataset
307
- - "mujoco-pendulum",
308
- - # "mujoco-pusher", # Not in the dataset
309
- - "mujoco-reacher",
310
- + # "mujoco-pendulum",
311
- + # # "mujoco-pusher", # Not in the dataset
312
- + # "mujoco-reacher",
313
- # "mujoco-standup", # Not in the dataset
314
- - "mujoco-swimmer",
315
- - "mujoco-walker",
316
- + # "mujoco-swimmer",
317
- + # "mujoco-walker",
318
- ]
319
-
320
-
321
- diff --git a/data/envs/metaworld/enjoy.py b/data/envs/metaworld/enjoy.py
322
- deleted file mode 100644
323
- index 6ec026b..0000000
324
- --- a/data/envs/metaworld/enjoy.py
325
- +++ /dev/null
326
- @@ -1,84 +0,0 @@
327
- -import sys
328
- -from typing import Dict, Optional
329
- -
330
- -import gym
331
- -import metaworld # noqa: F401
332
- -from sample_factory.cfg.arguments import parse_full_cfg, parse_sf_args
333
- -from sample_factory.enjoy import enjoy
334
- -from sample_factory.envs.env_utils import register_env
335
- -
336
- -
337
- -ENV_NAMES = [
338
- - "assembly-v2",
339
- - "basketball-v2",
340
- - "bin-picking-v2",
341
- - "box-close-v2",
342
- - "button-press-topdown-v2",
343
- - "button-press-topdown-wall-v2",
344
- - "button-press-v2",
345
- - "button-press-wall-v2",
346
- - "coffee-button-v2",
347
- - "coffee-pull-v2",
348
- - "coffee-push-v2",
349
- - "dial-turn-v2",
350
- - "disassemble-v2",
351
- - "door-close-v2",
352
- - "door-lock-v2",
353
- - "door-open-v2",
354
- - "door-unlock-v2",
355
- - "drawer-close-v2",
356
- - "drawer-open-v2",
357
- - "faucet-close-v2",
358
- - "faucet-open-v2",
359
- - "hammer-v2",
360
- - "hand-insert-v2",
361
- - "handle-press-side-v2",
362
- - "handle-press-v2",
363
- - "handle-pull-side-v2",
364
- - "handle-pull-v2",
365
- - "lever-pull-v2",
366
- - "peg-insert-side-v2",
367
- - "peg-unplug-side-v2",
368
- - "pick-out-of-hole-v2",
369
- - "pick-place-v2",
370
- - "pick-place-wall-v2",
371
- - "plate-slide-back-side-v2",
372
- - "plate-slide-back-v2",
373
- - "plate-slide-side-v2",
374
- - "plate-slide-v2",
375
- - "push-back-v2",
376
- - "push-v2",
377
- - "push-wall-v2",
378
- - "reach-v2",
379
- - "reach-wall-v2",
380
- - "shelf-place-v2",
381
- - "soccer-v2",
382
- - "stick-pull-v2",
383
- - "stick-push-v2",
384
- - "sweep-into-v2",
385
- - "sweep-v2",
386
- - "window-close-v2",
387
- - "window-open-v2",
388
- -]
389
- -
390
- -
391
- -def make_custom_env(
392
- - full_env_name: str,
393
- - cfg: Optional[Dict] = None,
394
- - env_config: Optional[Dict] = None,
395
- - render_mode: Optional[str] = None,
396
- -) -> gym.Env:
397
- - return gym.make(full_env_name, render_mode=render_mode)
398
- -
399
- -
400
- -def main() -> int:
401
- - for env_name in ENV_NAMES:
402
- - register_env(env_name, make_custom_env)
403
- - parser, _ = parse_sf_args(argv=None, evaluation=True)
404
- - cfg = parse_full_cfg(parser)
405
- - status = enjoy(cfg)
406
- - return status
407
- -
408
- -
409
- -if __name__ == "__main__":
410
- - sys.exit(main())
411
- diff --git a/data/envs/metaworld/generate_dataset.py b/data/envs/metaworld/generate_dataset.py
412
- index e21b237..c2b1907 100644
413
- --- a/data/envs/metaworld/generate_dataset.py
414
- +++ b/data/envs/metaworld/generate_dataset.py
415
- @@ -142,7 +142,8 @@ def create_dataset(cfg: Config, dataset_size: int = 100_000, split: str = "train
416
-
417
- # Actions shape should be [num_agents, num_actions] even if it's [1, 1]
418
- actions = preprocess_actions(env_info, actions)
419
- -
420
- + # Clamp actions to be in the range of the action space
421
- + actions = np.clip(actions, env.action_space.low, env.action_space.high)
422
- rnn_states = policy_outputs["new_rnn_states"]
423
- dataset["continuous_observations"][-1].append(observations["obs"].cpu().numpy()[0])
424
- dataset["continuous_actions"][-1].append(actions[0])
425
- diff --git a/data/envs/metaworld/generate_dataset_all.sh b/data/envs/metaworld/generate_dataset_all.sh
426
- index cfdae2f..8720089 100755
427
- --- a/data/envs/metaworld/generate_dataset_all.sh
428
- +++ b/data/envs/metaworld/generate_dataset_all.sh
429
- @@ -1,7 +1,7 @@
430
- #!/bin/bash
431
-
432
- ENVS=(
433
- - assembly
434
- + # assembly
435
- basketball
436
- bin-picking
437
- box-close
438
- @@ -9,51 +9,51 @@ ENVS=(
439
- button-press-topdown-wall
440
- button-press
441
- button-press-wall
442
- - coffee-button
443
- - coffee-pull
444
- - coffee-push
445
- - dial-turn
446
- - disassemble
447
- - door-close
448
- - door-lock
449
- - door-open
450
- - door-unlock
451
- - drawer-close
452
- - drawer-open
453
- - faucet-close
454
- - faucet-open
455
- - hammer
456
- - hand-insert
457
- - handle-press-side
458
- - handle-press
459
- - handle-pull-side
460
- - handle-pull
461
- - lever-pull
462
- - peg-insert-side
463
- - peg-unplug-side
464
- - pick-out-of-hole
465
- - pick-place
466
- - pick-place-wall
467
- - plate-slide-back-side
468
- - plate-slide-back
469
- - plate-slide-side
470
- - plate-slide
471
- - push-back
472
- - push
473
- - push-wall
474
- - reach
475
- - reach-wall
476
- - shelf-place
477
- - soccer
478
- - stick-pull
479
- - stick-push
480
- - sweep-into
481
- - sweep
482
- - window-close
483
- - window-open
484
- + # coffee-button
485
- + # coffee-pull
486
- + # coffee-push
487
- + # dial-turn
488
- + # disassemble
489
- + # door-close
490
- + # door-lock
491
- + # door-open
492
- + # door-unlock
493
- + # drawer-close
494
- + # drawer-open
495
- + # faucet-close
496
- + # faucet-open
497
- + # hammer
498
- + # hand-insert
499
- + # handle-press-side
500
- + # handle-press
501
- + # handle-pull-side
502
- + # handle-pull
503
- + # lever-pull
504
- + # peg-insert-side
505
- + # peg-unplug-side
506
- + # pick-out-of-hole
507
- + # pick-place
508
- + # pick-place-wall
509
- + # plate-slide-back-side
510
- + # plate-slide-back
511
- + # plate-slide-side
512
- + # plate-slide
513
- + # push-back
514
- + # push
515
- + # push-wall
516
- + # reach
517
- + # reach-wall
518
- + # shelf-place
519
- + # soccer
520
- + # stick-pull
521
- + # stick-push
522
- + # sweep-into
523
- + # sweep
524
- + # window-close
525
- + # window-open
526
- )
527
-
528
- for ENV in "${ENVS[@]}"; do
529
- - python -m sample_factory.huggingface.load_from_hub -r qgallouedec/sample-factory-$ENV-v2
530
- - python generate_dataset.py --env $ENV-v2 --experiment sample-factory-$ENV-v2 --train_dir=./train_dir
531
- + python -m sample_factory.huggingface.load_from_hub -r qgallouedec/$ENV-v2
532
- + python generate_dataset.py --env $ENV-v2 --experiment $ENV-v2 --train_dir=./train_dir
533
- done
534
- diff --git a/data/envs/metaworld/push_all.sh b/data/envs/metaworld/push_all.sh
535
- index 9d71467..4fc1fc2 100755
536
- --- a/data/envs/metaworld/push_all.sh
537
- +++ b/data/envs/metaworld/push_all.sh
538
- @@ -2,57 +2,57 @@
539
-
540
- ENVS=(
541
- assembly
542
- - basketball
543
- - bin-picking
544
- - box-close
545
- - button-press-topdown
546
- - button-press-topdown-wall
547
- - button-press
548
- - button-press-wall
549
- - coffee-button
550
- - coffee-pull
551
- - coffee-push
552
- - dial-turn
553
- - disassemble
554
- - door-close
555
- - door-lock
556
- - door-open
557
- - door-unlock
558
- - drawer-close
559
- - drawer-open
560
- - faucet-close
561
- - faucet-open
562
- - hammer
563
- - hand-insert
564
- - handle-press-side
565
- - handle-press
566
- - handle-pull-side
567
- - handle-pull
568
- - lever-pull
569
- - peg-insert-side
570
- - peg-unplug-side
571
- - pick-out-of-hole
572
- - pick-place
573
- - pick-place-wall
574
- - plate-slide-back-side
575
- - plate-slide-back
576
- - plate-slide-side
577
- - plate-slide
578
- - push-back
579
- - push
580
- - push-wall
581
- - reach
582
- - reach-wall
583
- - shelf-place
584
- - soccer
585
- - stick-pull
586
- - stick-push
587
- - sweep-into
588
- - sweep
589
- - window-close
590
- - window-open
591
- + # basketball
592
- + # bin-picking
593
- + # box-close
594
- + # button-press-topdown
595
- + # button-press-topdown-wall
596
- + # button-press
597
- + # button-press-wall
598
- + # coffee-button
599
- + # coffee-pull
600
- + # coffee-push
601
- + # dial-turn
602
- + # disassemble
603
- + # door-close
604
- + # door-lock
605
- + # door-open
606
- + # door-unlock
607
- + # drawer-close
608
- + # drawer-open
609
- + # faucet-close
610
- + # faucet-open
611
- + # hammer
612
- + # hand-insert
613
- + # handle-press-side
614
- + # handle-press
615
- + # handle-pull-side
616
- + # handle-pull
617
- + # lever-pull
618
- + # peg-insert-side
619
- + # peg-unplug-side
620
- + # pick-out-of-hole
621
- + # pick-place
622
- + # pick-place-wall
623
- + # plate-slide-back-side
624
- + # plate-slide-back
625
- + # plate-slide-side
626
- + # plate-slide
627
- + # push-back
628
- + # push
629
- + # push-wall
630
- + # reach
631
- + # reach-wall
632
- + # shelf-place
633
- + # soccer
634
- + # stick-pull
635
- + # stick-push
636
- + # sweep-into
637
- + # sweep
638
- + # window-close
639
- + # window-open
640
- )
641
-
642
- for ENV in "${ENVS[@]}"; do
643
- - python enjoy.py --algo=APPO --env $ENV-v2 --experiment $ENV-v2 --train_dir=./train_dir --max_num_episodes=10 --push_to_hub --hf_repository=qgallouedec/sample-factory-$ENV-v2 --save_video --no_render --enjoy_script=enjoy --train_script=train --load_checkpoint_kind best
644
- + python push.py --algo=APPO --env $ENV-v2 --experiment $ENV-v2 --train_dir=./train_dir --max_num_episodes=10 --push_to_hub --hf_repository=qgallouedec/$ENV-v2 --save_video --no_render --enjoy_script=enjoy --train_script=train --load_checkpoint_kind best
645
- done
646
- diff --git a/data/envs/metaworld/train.py b/data/envs/metaworld/train.py
647
- index 46dc581..095414e 100644
648
- --- a/data/envs/metaworld/train.py
649
- +++ b/data/envs/metaworld/train.py
650
- @@ -2,67 +2,13 @@ import argparse
651
- import sys
652
- from typing import Dict, Optional
653
-
654
- -import gym
655
- +import gymnasium as gym
656
- import metaworld # noqa: F401
657
- from sample_factory.cfg.arguments import parse_full_cfg, parse_sf_args
658
- from sample_factory.envs.env_utils import register_env
659
- from sample_factory.train import run_rl
660
-
661
-
662
- -ENV_NAMES = [
663
- - "assembly-v2",
664
- - "basketball-v2",
665
- - "bin-picking-v2",
666
- - "box-close-v2",
667
- - "button-press-topdown-v2",
668
- - "button-press-topdown-wall-v2",
669
- - "button-press-v2",
670
- - "button-press-wall-v2",
671
- - "coffee-button-v2",
672
- - "coffee-pull-v2",
673
- - "coffee-push-v2",
674
- - "dial-turn-v2",
675
- - "disassemble-v2",
676
- - "door-close-v2",
677
- - "door-lock-v2",
678
- - "door-open-v2",
679
- - "door-unlock-v2",
680
- - "drawer-close-v2",
681
- - "drawer-open-v2",
682
- - "faucet-close-v2",
683
- - "faucet-open-v2",
684
- - "hammer-v2",
685
- - "hand-insert-v2",
686
- - "handle-press-side-v2",
687
- - "handle-press-v2",
688
- - "handle-pull-side-v2",
689
- - "handle-pull-v2",
690
- - "lever-pull-v2",
691
- - "peg-insert-side-v2",
692
- - "peg-unplug-side-v2",
693
- - "pick-out-of-hole-v2",
694
- - "pick-place-v2",
695
- - "pick-place-wall-v2",
696
- - "plate-slide-back-side-v2",
697
- - "plate-slide-back-v2",
698
- - "plate-slide-side-v2",
699
- - "plate-slide-v2",
700
- - "push-back-v2",
701
- - "push-v2",
702
- - "push-wall-v2",
703
- - "reach-v2",
704
- - "reach-wall-v2",
705
- - "shelf-place-v2",
706
- - "soccer-v2",
707
- - "stick-pull-v2",
708
- - "stick-push-v2",
709
- - "sweep-into-v2",
710
- - "sweep-v2",
711
- - "window-close-v2",
712
- - "window-open-v2",
713
- -]
714
- -
715
- -
716
- def make_custom_env(
717
- full_env_name: str,
718
- cfg: Optional[Dict] = None,
719
- @@ -79,7 +25,7 @@ def override_defaults(parser: argparse.ArgumentParser) -> argparse.ArgumentParse
720
- num_workers=8,
721
- num_envs_per_worker=8,
722
- worker_num_splits=2,
723
- - train_for_env_steps=100_000_000,
724
- + train_for_env_steps=10_000_000,
725
- encoder_mlp_layers=[64, 64],
726
- env_frameskip=1,
727
- nonlinearity="tanh",
728
- @@ -116,11 +62,10 @@ def override_defaults(parser: argparse.ArgumentParser) -> argparse.ArgumentParse
729
-
730
-
731
- def main() -> int:
732
- - for env_name in ENV_NAMES:
733
- - register_env(env_name, make_custom_env)
734
- parser, _ = parse_sf_args(argv=None, evaluation=False)
735
- parser = override_defaults(parser)
736
- cfg = parse_full_cfg(parser)
737
- + register_env(cfg.env, make_custom_env)
738
- status = run_rl(cfg)
739
- return status
740
 
741
  diff --git a/gia/eval/evaluator.py b/gia/eval/evaluator.py
742
  index 91b645c..3e2cae7 100644
@@ -764,30 +40,30 @@ index 91b645c..3e2cae7 100644
764
  return self._evaluate(model)
765
 
766
  diff --git a/gia/eval/rl/envs/core.py b/gia/eval/rl/envs/core.py
767
- index f1f83f5..3e8e182 100644
768
  --- a/gia/eval/rl/envs/core.py
769
  +++ b/gia/eval/rl/envs/core.py
770
- @@ -177,6 +177,7 @@ def make(task_name: str, num_envs: int = 1):
771
 
772
  elif task_name.startswith("metaworld"):
773
- import gym
774
- + import metaworld
775
 
776
  env_id = TASK_TO_ENV_MAPPING[task_name]
777
  env = gym.vector.SyncVectorEnv([lambda: gym.make(env_id)] * num_envs)
778
  diff --git a/gia/eval/rl/gia_agent.py b/gia/eval/rl/gia_agent.py
779
- index f0d0b9b..ca37721 100644
780
  --- a/gia/eval/rl/gia_agent.py
781
  +++ b/gia/eval/rl/gia_agent.py
782
- @@ -9,7 +9,7 @@ from gia.datasets import GiaDataCollator, Prompter
783
- from gia.model.gia_model import GiaModel
784
- from gia.processing import GiaProcessor
785
-
786
- -
787
- +import sample_factory.envs.env_utils
788
- class GiaAgent:
789
- r"""
790
- An RL agent that uses Gia to generate actions.
791
  @@ -75,6 +75,11 @@ class GiaAgent:
792
  ) -> Tuple[Tuple[Tensor, Tensor], ...]:
793
  return tuple((k[:, :, -self._max_length :], v[:, :, -self._max_length :]) for (k, v) in past_key_values)
 
1
+ diff --git a/gia/eval/callback.py b/gia/eval/callback.py
2
+ index 5c3a080..4b6198f 100644
3
+ --- a/gia/eval/callback.py
4
+ +++ b/gia/eval/callback.py
5
+ @@ -2,10 +2,10 @@ import glob
6
+ import json
7
+ import subprocess
8
+
9
+ -import wandb
10
+ from accelerate import Accelerator
11
+ from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments
12
+
13
+ +import wandb
14
+ from gia.config import Arguments
15
+ from gia.eval.utils import is_slurm_available
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  diff --git a/gia/eval/evaluator.py b/gia/eval/evaluator.py
18
  index 91b645c..3e2cae7 100644
 
40
  return self._evaluate(model)
41
 
42
  diff --git a/gia/eval/rl/envs/core.py b/gia/eval/rl/envs/core.py
43
+ index ec5e5b2..eeaf7cb 100644
44
  --- a/gia/eval/rl/envs/core.py
45
  +++ b/gia/eval/rl/envs/core.py
46
+ @@ -177,7 +177,6 @@ def make(task_name: str, num_envs: int = 1):
47
 
48
  elif task_name.startswith("metaworld"):
49
+ import gymnasium as gym
50
+ - import metaworld
51
 
52
  env_id = TASK_TO_ENV_MAPPING[task_name]
53
  env = gym.vector.SyncVectorEnv([lambda: gym.make(env_id)] * num_envs)
54
  diff --git a/gia/eval/rl/gia_agent.py b/gia/eval/rl/gia_agent.py
55
+ index f0d0b9b..39dc0d2 100644
56
  --- a/gia/eval/rl/gia_agent.py
57
  +++ b/gia/eval/rl/gia_agent.py
58
+ @@ -54,7 +54,7 @@ class GiaAgent:
59
+ self.action_space = action_space
60
+ self.deterministic = deterministic
61
+ self.device = next(model.parameters()).device
62
+ - self._max_length = self.model.config.max_position_embeddings - 10
63
+ + self._max_length = self.model.config.max_position_embeddings - 100 # TODO: fix this
64
+
65
+ if isinstance(observation_space, spaces.Box):
66
+ self._observation_key = "continuous_observations"
67
  @@ -75,6 +75,11 @@ class GiaAgent:
68
  ) -> Tuple[Tuple[Tensor, Tensor], ...]:
69
  return tuple((k[:, :, -self._max_length :], v[:, :, -self._max_length :]) for (k, v) in past_key_values)
replay.mp4 CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:53a34a5592b2bcc2c6dd205134780353b33455a576ec6ff7a1b3b6c94844a9ed
3
- size 889542
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cd7603d5f91547822cbba1aee40ee4afc08a005ca6440fa3698075b4c134087d
3
+ size 856078
sf_log.txt CHANGED
The diff for this file is too large to render. See raw diff