liruiw commited on
Commit
eeb6b8a
·
verified ·
1 Parent(s): 387453d

Upload 3 files

Browse files
Files changed (3) hide show
  1. README.md +6 -21
  2. config.json +1068 -0
  3. model.safetensors +3 -0
README.md CHANGED
@@ -1,24 +1,9 @@
1
  ---
2
- pipeline_tag: robotics
 
 
3
  ---
4
- # 🦾 Heterogenous Masked Autoregression
5
-
6
 
7
-
8
- Paper:
9
-
10
- You can find more details on our [project page](https://liruiw.github.io/hma).
11
-
12
-
13
- **TL;DR:** HMA can generate diverse embodied video-action dynamics with real-time speed.
14
-
15
-
16
- If you find HMA useful in your research, please consider citing:
17
- ```
18
- }
19
- ```
20
-
21
-
22
- ## Contact
23
-
24
- If you have any questions, feel free to contact me through email ([email protected]). Enjoy!
 
1
  ---
2
+ tags:
3
+ - model_hub_mixin
4
+ - pytorch_model_hub_mixin
5
  ---
 
 
6
 
7
+ This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
8
+ - Library: [More Information Needed]
9
+ - Docs: [More Information Needed]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
config.json ADDED
@@ -0,0 +1,1068 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "S": 256,
3
+ "T": 12,
4
+ "action_contrastive_loss": false,
5
+ "action_domains": [
6
+ "bridge_data_v2",
7
+ "fractal20220817_data",
8
+ "language_table",
9
+ "ucsd_pick_and_place_dataset_converted_externally_to_rlds",
10
+ "kaist_nonprehensile_converted_externally_to_rlds",
11
+ "ucsd_kitchen_dataset_converted_externally_to_rlds",
12
+ "utokyo_xarm_bimanual_converted_externally_to_rlds",
13
+ "stanford_hydra_dataset_converted_externally_to_rlds",
14
+ "austin_sirius_dataset_converted_externally_to_rlds",
15
+ "berkeley_fanuc_manipulation",
16
+ "berkeley_mvp_converted_externally_to_rlds",
17
+ "berkeley_rpt_converted_externally_to_rlds",
18
+ "cmu_play_fusion",
19
+ "iamlab_cmu_pickup_insert_converted_externally_to_rlds",
20
+ "qut_dexterous_manpulation",
21
+ "robo_net",
22
+ "furniture_bench_dataset_converted_externally_to_rlds",
23
+ "dlr_sara_grid_clamp_converted_externally_to_rlds",
24
+ "cmu_stretch",
25
+ "spoc",
26
+ "columbia_cairlab_pusht_real",
27
+ "droid",
28
+ "toto",
29
+ "io_ai_tech",
30
+ "conq_hose_manipulation",
31
+ "dobbe",
32
+ "berkeley_gnm_cory_hall",
33
+ "plex_robosuite",
34
+ "usc_cloth_sim_converted_externally_to_rlds",
35
+ "berkeley_cable_routing",
36
+ "imperial_wrist_dataset",
37
+ "bc_z",
38
+ "kuka",
39
+ "roboturk",
40
+ "metaworld",
41
+ "robomimic",
42
+ "epic_kitchen",
43
+ "ego4d",
44
+ "nyu_door_opening_surprising_effectiveness"
45
+ ],
46
+ "action_loss_weight": 0.5,
47
+ "action_network": "concat+modulate",
48
+ "action_stats": [
49
+ [
50
+ [
51
+ 0.0005091051571071148,
52
+ 0.00015454298409167677,
53
+ 0.0008503405260853469,
54
+ 2.549389682826586e-05,
55
+ -0.0016934372251853347,
56
+ 0.0001974346669157967,
57
+ 0.5945848822593689
58
+ ],
59
+ [
60
+ 0.01032705046236515,
61
+ 0.015463403426110744,
62
+ 0.014627186581492424,
63
+ 0.027947312220931053,
64
+ 0.02963864989578724,
65
+ 0.10752595961093903,
66
+ 0.48598647117614746
67
+ ]
68
+ ],
69
+ [
70
+ [
71
+ 7.725676368863788e-06,
72
+ 1.0291715035748439e-08,
73
+ 1.0607171134324744e-05,
74
+ 0.021814998239278793,
75
+ 0.04335036501288414,
76
+ -0.005775672383606434,
77
+ 0.0009348704479634762,
78
+ 0.02302502654492855,
79
+ 0.951511561870575,
80
+ 0.0024383722338825464,
81
+ 0.006974320858716965,
82
+ 0.006260569207370281,
83
+ -0.012636600062251091
84
+ ],
85
+ [
86
+ 0.002023447770625353,
87
+ 5.1052771596005186e-05,
88
+ 0.005440520588308573,
89
+ 0.36035439372062683,
90
+ 0.15620605647563934,
91
+ 0.131776362657547,
92
+ 0.14597612619400024,
93
+ 0.14867420494556427,
94
+ 0.21220634877681732,
95
+ 0.04925995320081711,
96
+ 0.06934574991464615,
97
+ 0.06012243032455444,
98
+ 0.07359173148870468
99
+ ]
100
+ ],
101
+ [
102
+ [
103
+ 0.00014842326345387846,
104
+ -0.0005635050474666059
105
+ ],
106
+ [
107
+ 0.030163198709487915,
108
+ 0.042305462062358856
109
+ ]
110
+ ],
111
+ [
112
+ [
113
+ 0.14784614741802216,
114
+ -0.12238457798957825,
115
+ 0.05175986886024475,
116
+ -0.0750587061047554
117
+ ],
118
+ [
119
+ 0.4859539568424225,
120
+ 0.46922266483306885,
121
+ 0.5474159717559814,
122
+ 0.892055869102478
123
+ ]
124
+ ],
125
+ [
126
+ [
127
+ 0.0018891734071075916,
128
+ 0.00020449502335395664,
129
+ 0.0008909617899917066,
130
+ -0.002328932750970125,
131
+ 0.00014602410374209285,
132
+ -0.0036141453310847282,
133
+ 180.55873107910156,
134
+ 170.33209228515625,
135
+ 185.4281768798828,
136
+ 151.96224975585938,
137
+ 174.7725372314453,
138
+ 82.36817169189453,
139
+ 31.219266891479492,
140
+ 1.0021777153015137,
141
+ 0.6951844692230225,
142
+ 0.33550527691841125,
143
+ 0.48643627762794495,
144
+ 0.49088379740715027,
145
+ 0.7387751340866089,
146
+ 1.2315938472747803
147
+ ],
148
+ [
149
+ 0.014649114571511745,
150
+ 0.016541698947548866,
151
+ 0.012851938605308533,
152
+ 0.023206548765301704,
153
+ 0.020458899438381195,
154
+ 0.01998721994459629,
155
+ 28.153791427612305,
156
+ 34.738853454589844,
157
+ 25.903417587280273,
158
+ 53.904666900634766,
159
+ 32.97977066040039,
160
+ 50.42049789428711,
161
+ 31.806129455566406,
162
+ 0.532600998878479,
163
+ 0.4410288631916046,
164
+ 0.1271168440580368,
165
+ 0.375587522983551,
166
+ 0.33406388759613037,
167
+ 0.39596208930015564,
168
+ 0.5636153817176819
169
+ ]
170
+ ],
171
+ [
172
+ [
173
+ 411.012451171875,
174
+ 115.89311218261719,
175
+ 192.7315216064453,
176
+ -122.11846160888672,
177
+ -34.10114288330078,
178
+ 50.14624786376953,
179
+ 0.7442169785499573,
180
+ 0.03775591775774956
181
+ ],
182
+ [
183
+ 123.2818832397461,
184
+ 108.88812255859375,
185
+ 129.3837127685547,
186
+ 115.44219970703125,
187
+ 27.88010025024414,
188
+ 40.964500427246094,
189
+ 0.4363022744655609,
190
+ 0.1906062811613083
191
+ ]
192
+ ],
193
+ [
194
+ [
195
+ 0.46005550026893616,
196
+ 0.10042445361614227,
197
+ 0.2565317451953888,
198
+ -1.49611496925354,
199
+ -0.05020594596862793,
200
+ -0.49960270524024963,
201
+ 0.3615054190158844,
202
+ 0.4448537826538086,
203
+ 0.5080803632736206,
204
+ 0.26795846223831177,
205
+ 1.0428096055984497,
206
+ 0.13593977689743042,
207
+ 0.07739049196243286,
208
+ 0.33761394023895264
209
+ ],
210
+ [
211
+ 0.05466218292713165,
212
+ 0.057073455303907394,
213
+ 0.04310256987810135,
214
+ 2.5955734252929688,
215
+ 0.18967270851135254,
216
+ 0.24594734609127045,
217
+ 0.4793280363082886,
218
+ 0.06319230049848557,
219
+ 0.042976442724466324,
220
+ 0.04913073033094406,
221
+ 2.780755043029785,
222
+ 0.29540929198265076,
223
+ 0.7207639217376709,
224
+ 0.4693118929862976
225
+ ]
226
+ ],
227
+ [
228
+ [
229
+ 0.0007806111825630069,
230
+ 0.00014432442549150437,
231
+ -0.0002542313886806369,
232
+ 0.0012392610078677535,
233
+ -0.004758258815854788,
234
+ 0.002720650052651763,
235
+ 0.5114527940750122
236
+ ],
237
+ [
238
+ 0.008004522882401943,
239
+ 0.009107149206101894,
240
+ 0.009566642343997955,
241
+ 0.04117538407444954,
242
+ 0.038478001952171326,
243
+ 0.04605141654610634,
244
+ 0.4997573792934418
245
+ ]
246
+ ],
247
+ [
248
+ [
249
+ 0.07752787321805954,
250
+ 0.03168660029768944,
251
+ 0.04253425449132919,
252
+ 0.0,
253
+ 0.0,
254
+ -0.01578887365758419,
255
+ 0.13820601999759674
256
+ ],
257
+ [
258
+ 0.39060941338539124,
259
+ 0.2989002466201782,
260
+ 0.27805662155151367,
261
+ 0.0,
262
+ 0.0,
263
+ 0.08069407194852829,
264
+ 0.9916107058525085
265
+ ]
266
+ ],
267
+ [
268
+ [
269
+ 0.0007780246087349951,
270
+ -0.0003209571004845202,
271
+ -0.0015096653951331973,
272
+ -0.0007919073104858398,
273
+ -0.0001698181586107239,
274
+ 0.0001668699405854568
275
+ ],
276
+ [
277
+ 0.0033910737838596106,
278
+ 0.0049757580272853374,
279
+ 0.005356575828045607,
280
+ 0.007685699500143528,
281
+ 0.004107887391000986,
282
+ 0.008493276312947273
283
+ ]
284
+ ],
285
+ [
286
+ [
287
+ -6.816067616455257e-05,
288
+ 0.0032774528954178095,
289
+ -0.00014815537724643946,
290
+ -0.0009022713056765497,
291
+ -2.8410413506207988e-05,
292
+ -0.0028941172640770674,
293
+ -0.0007027793908491731,
294
+ 0.48355934023857117
295
+ ],
296
+ [
297
+ 0.002599249128252268,
298
+ 0.012856663204729557,
299
+ 0.005673616658896208,
300
+ 0.018063854426145554,
301
+ 0.0016916567692533135,
302
+ 0.021029461175203323,
303
+ 0.005813282448798418,
304
+ 0.49965402483940125
305
+ ]
306
+ ],
307
+ [
308
+ [
309
+ 0.00013238923565950245,
310
+ -0.0002748446131590754,
311
+ -1.063202307705069e-05,
312
+ -0.0003201523795723915,
313
+ 1.9412083929637447e-05,
314
+ 2.8738231776515022e-05,
315
+ 4.167715087532997e-05,
316
+ 0.47832271456718445
317
+ ],
318
+ [
319
+ 0.001529942499473691,
320
+ 0.00453608063980937,
321
+ 0.00077804084867239,
322
+ 0.0030119596049189568,
323
+ 0.0010617133229970932,
324
+ 0.005115122999995947,
325
+ 0.0041729300282895565,
326
+ 0.4998096525669098
327
+ ]
328
+ ],
329
+ [
330
+ [
331
+ 0.0005282792262732983,
332
+ 3.792543429881334e-05,
333
+ -0.00017229391960427165,
334
+ -0.0002936247910838574,
335
+ -0.0003607820253819227,
336
+ -0.00015651936701033264,
337
+ 3.999160526291234e-06,
338
+ 0.5698255300521851,
339
+ 0.0024680662900209427
340
+ ],
341
+ [
342
+ 0.0017995948437601328,
343
+ 0.0023995412047952414,
344
+ 0.0018815182847902179,
345
+ 0.03787316754460335,
346
+ 0.036895327270030975,
347
+ 0.0051568918861448765,
348
+ 0.0063850656151771545,
349
+ 0.49463561177253723,
350
+ 0.04957345128059387
351
+ ]
352
+ ],
353
+ [
354
+ [
355
+ 0.5278156995773315,
356
+ 0.027324192225933075,
357
+ 0.18753373622894287,
358
+ -0.013062607496976852,
359
+ 0.9998879432678223,
360
+ 0.0036541626323014498,
361
+ 0.016103636473417282,
362
+ 0.5530186891555786
363
+ ],
364
+ [
365
+ 0.08116760849952698,
366
+ 0.11193274706602097,
367
+ 0.07751200348138809,
368
+ 0.016174238175153732,
369
+ 0.0006333293858915567,
370
+ 0.007865803316235542,
371
+ 0.013844785280525684,
372
+ 0.4971791207790375
373
+ ]
374
+ ],
375
+ [
376
+ [
377
+ -1.0947237569780555e-05,
378
+ -0.0005107750184834003,
379
+ -0.0003358577669132501,
380
+ 0.0009847612818703055,
381
+ -4.268335032975301e-05,
382
+ -0.002284084912389517,
383
+ 0.00044375265133567154,
384
+ -0.0035538373049348593
385
+ ],
386
+ [
387
+ 0.012695086188614368,
388
+ 0.015595865435898304,
389
+ 0.012070356868207455,
390
+ 0.0009854151867330074,
391
+ 0.002212227089330554,
392
+ 0.02389119192957878,
393
+ 0.9998002648353577,
394
+ 0.027204347774386406
395
+ ]
396
+ ],
397
+ [
398
+ [
399
+ 0.00024059691349975765,
400
+ -0.0002765821700450033,
401
+ -0.007495959755033255,
402
+ 0.001606731559149921,
403
+ -0.006666666828095913
404
+ ],
405
+ [
406
+ 0.03234556317329407,
407
+ 0.03240759298205376,
408
+ 0.09059711545705795,
409
+ 0.17567689716815948,
410
+ 0.9999619126319885
411
+ ]
412
+ ],
413
+ [
414
+ [
415
+ 0.00014585924509447068,
416
+ 0.0010840297909453511,
417
+ 0.0006219727802090347,
418
+ -0.0011805534595623612,
419
+ -0.0002698474272619933,
420
+ 0.009135263971984386,
421
+ 0.9928883910179138,
422
+ 0.023449113592505455
423
+ ],
424
+ [
425
+ 0.016102593392133713,
426
+ 0.014900527894496918,
427
+ 0.014007486402988434,
428
+ 0.028777308762073517,
429
+ 0.05673017352819443,
430
+ 0.1638830006122589,
431
+ 0.02685263752937317,
432
+ 1.0003585815429688
433
+ ]
434
+ ],
435
+ [
436
+ [
437
+ -1.758056714606937e-05,
438
+ -3.7838817661395296e-05,
439
+ -0.00038337393198162317,
440
+ 2.7642881832434796e-05,
441
+ 8.176987466868013e-05,
442
+ 6.723385740770027e-05,
443
+ 1.0
444
+ ],
445
+ [
446
+ 0.0004286598414182663,
447
+ 0.0005066184094175696,
448
+ 0.0012515239650383592,
449
+ 0.0005647603538818657,
450
+ 0.0007578085060231388,
451
+ 0.000697559560649097,
452
+ 0.0
453
+ ]
454
+ ],
455
+ [
456
+ [
457
+ 0.0003578192263375968,
458
+ 0.0,
459
+ 0.0016306436154991388,
460
+ 0.0,
461
+ 0.0,
462
+ 0.0,
463
+ 0.3986235558986664,
464
+ 0.005404492374509573
465
+ ],
466
+ [
467
+ 0.004071093630045652,
468
+ 0.0,
469
+ 0.003779628314077854,
470
+ 0.0,
471
+ 0.0,
472
+ 0.0,
473
+ 0.4895978569984436,
474
+ 0.07332254946231842
475
+ ]
476
+ ],
477
+ [
478
+ [
479
+ -0.6983144283294678,
480
+ 0.13209794461727142,
481
+ 0.0039050180930644274,
482
+ 0.0022984833922237158,
483
+ 1.080454707145691,
484
+ 0.007018118165433407,
485
+ 0.0022595408372581005,
486
+ 0.011918018572032452,
487
+ 0.002541287336498499
488
+ ],
489
+ [
490
+ 23.353342056274414,
491
+ 0.30039435625076294,
492
+ 0.05812933295965195,
493
+ 0.04016513004899025,
494
+ 8.097066879272461,
495
+ 0.08314091712236404,
496
+ 0.047434594482183456,
497
+ 0.1077399030327797,
498
+ 0.05028529092669487
499
+ ]
500
+ ],
501
+ [
502
+ [
503
+ 0.0,
504
+ 0.0,
505
+ 0.0,
506
+ 0.0,
507
+ 0.0097132483497262,
508
+ -0.0013603067491203547,
509
+ 0.0015324887353926897,
510
+ 0.0
511
+ ],
512
+ [
513
+ 0.0,
514
+ 0.0,
515
+ 0.0,
516
+ 0.0,
517
+ 0.09806635230779648,
518
+ 0.01324226800352335,
519
+ 0.0157295111566782,
520
+ 0.0
521
+ ]
522
+ ],
523
+ [
524
+ [
525
+ 0.5393299460411072,
526
+ 0.0013205910800024867,
527
+ 0.3156873285770416,
528
+ 0.31605127453804016,
529
+ -0.09015484154224396,
530
+ -0.04938371106982231,
531
+ 0.4098437428474426
532
+ ],
533
+ [
534
+ 0.11740713566541672,
535
+ 0.1749163120985031,
536
+ 0.16180050373077393,
537
+ 2.7441937923431396,
538
+ 0.3496316969394684,
539
+ 0.7599275708198547,
540
+ 0.430283784866333
541
+ ]
542
+ ],
543
+ [
544
+ [
545
+ 0.0,
546
+ -0.664889395236969,
547
+ 0.19378533959388733,
548
+ 0.03492932394146919,
549
+ 0.006112735718488693,
550
+ 0.3858254551887512,
551
+ 0.006408770103007555,
552
+ 0.3634810745716095
553
+ ],
554
+ [
555
+ 0.0,
556
+ 0.5742393732070923,
557
+ 0.29617878794670105,
558
+ 0.32670700550079346,
559
+ 0.07788320630788803,
560
+ 0.12196606397628784,
561
+ 0.1935708373785019,
562
+ 0.10228537768125534
563
+ ]
564
+ ],
565
+ [
566
+ [
567
+ 2.9126644221832976e-05,
568
+ 0.00012177458847872913,
569
+ -0.00010258024121867493,
570
+ -7.230440678540617e-05,
571
+ 0.00022300859563983977,
572
+ 7.248757901834324e-05,
573
+ 0.08678723871707916
574
+ ],
575
+ [
576
+ 0.0028375224210321903,
577
+ 0.002654739422723651,
578
+ 0.0025536646135151386,
579
+ 0.017998138442635536,
580
+ 0.024928707629442215,
581
+ 0.02371523156762123,
582
+ 0.5916166305541992
583
+ ]
584
+ ],
585
+ [
586
+ [
587
+ 0.010949725285172462,
588
+ -0.0022855771239846945,
589
+ -0.005606506951153278,
590
+ 0.005102947819977999,
591
+ 0.028554478660225868,
592
+ 0.010207906365394592,
593
+ 0.6073462963104248
594
+ ],
595
+ [
596
+ 0.037630267441272736,
597
+ 0.022229386493563652,
598
+ 0.04602375999093056,
599
+ 0.057648736983537674,
600
+ 0.14510659873485565,
601
+ 0.07890969514846802,
602
+ 0.4869178235530853
603
+ ]
604
+ ],
605
+ [
606
+ [
607
+ -0.00010904707596637309,
608
+ 0.001118879416026175,
609
+ -0.000113105088530574,
610
+ -7.428864046232775e-05,
611
+ -0.0006643814849667251,
612
+ -5.145822433405556e-05,
613
+ 0.6308034658432007
614
+ ],
615
+ [
616
+ 0.04345196858048439,
617
+ 0.0452333502471447,
618
+ 0.12535300850868225,
619
+ 0.00519426679238677,
620
+ 0.011263296939432621,
621
+ 0.006335486192256212,
622
+ 0.39740097522735596
623
+ ]
624
+ ],
625
+ [
626
+ [
627
+ 0.06118948012590408,
628
+ 0.0038687449414283037
629
+ ],
630
+ [
631
+ 0.02585630863904953,
632
+ 0.00309689249843359
633
+ ]
634
+ ],
635
+ [
636
+ [
637
+ 0.06141459941864014,
638
+ 0.052588578313589096,
639
+ -0.041776858270168304,
640
+ -0.0006756950169801712,
641
+ -0.00152083788998425,
642
+ 0.003654239699244499,
643
+ -0.05188782885670662
644
+ ],
645
+ [
646
+ 0.3462270498275757,
647
+ 0.46296226978302,
648
+ 0.43901339173316956,
649
+ 0.023543866351246834,
650
+ 0.01941801607608795,
651
+ 0.161599263548851,
652
+ 0.9985517263412476
653
+ ]
654
+ ],
655
+ [
656
+ [
657
+ 0.10499999672174454,
658
+ 0.03899981826543808,
659
+ 3.1370865107016588e-12,
660
+ 0.28803354501724243
661
+ ],
662
+ [
663
+ 0.20360185205936432,
664
+ 0.22257499396800995,
665
+ 0.36332041025161743,
666
+ 0.3840075135231018
667
+ ]
668
+ ],
669
+ [
670
+ [
671
+ 0.0,
672
+ 0.0,
673
+ 0.04837876185774803,
674
+ 0.07777562737464905,
675
+ -0.07173224538564682,
676
+ 0.02389858104288578,
677
+ 0.1029987707734108
678
+ ],
679
+ [
680
+ 0.0,
681
+ 0.0,
682
+ 0.3481692373752594,
683
+ 0.2678722143173218,
684
+ 0.18070226907730103,
685
+ 0.1805194914340973,
686
+ 0.21225926280021667
687
+ ]
688
+ ],
689
+ [
690
+ [
691
+ 0.00021432287758216262,
692
+ -0.001017225906252861,
693
+ 0.0009490080992691219,
694
+ 0.0012771538458764553,
695
+ -4.03298472519964e-05,
696
+ 1.6517331459908746e-05,
697
+ 0.5683804750442505,
698
+ 0.023855386301875114
699
+ ],
700
+ [
701
+ 0.003075462533161044,
702
+ 0.006757265888154507,
703
+ 0.010920294560492039,
704
+ 0.024084405973553658,
705
+ 0.0031898680608719587,
706
+ 0.0039015202783048153,
707
+ 0.4952875077724457,
708
+ 0.15259981155395508
709
+ ]
710
+ ],
711
+ [
712
+ [
713
+ 0.0002901389671023935,
714
+ -0.008735175244510174,
715
+ -0.03063584677875042,
716
+ -0.0008425223059020936,
717
+ -0.01699526235461235,
718
+ -0.057361308485269547,
719
+ -0.0026800602208822966,
720
+ -0.024219293147325516,
721
+ -0.07954999059438705,
722
+ -0.004751862026751041,
723
+ -0.030450690537691116,
724
+ -0.09751348197460175,
725
+ -0.006432248279452324,
726
+ -0.03592115268111229,
727
+ -0.11248139292001724,
728
+ -0.007074976805597544,
729
+ -0.04055371135473251,
730
+ -0.12592929601669312,
731
+ -0.007033170200884342,
732
+ -0.04448538273572922,
733
+ -0.13845883309841156,
734
+ -0.006582007277756929,
735
+ -0.04794967174530029,
736
+ -0.15048547089099884,
737
+ -0.005856209900230169,
738
+ -0.05115331709384918,
739
+ -0.16212087869644165,
740
+ -0.004947068635374308,
741
+ -0.05432483181357384,
742
+ -0.17349986732006073,
743
+ 0.1654718667268753,
744
+ 0.15334908664226532,
745
+ 0.14446938037872314,
746
+ 0.1381329596042633,
747
+ 0.14069639146327972,
748
+ 0.15441010892391205,
749
+ 0.16613994538784027,
750
+ 0.17573139071464539,
751
+ 0.18380865454673767,
752
+ 0.19018495082855225,
753
+ -0.009925751015543938,
754
+ 0.000903706590179354,
755
+ 0.004995723720639944,
756
+ -0.01862962171435356,
757
+ 0.0023904081899672747,
758
+ 0.009501785971224308,
759
+ -0.025607934221625328,
760
+ 0.004396800417453051,
761
+ 0.013722752220928669,
762
+ -0.0311029814183712,
763
+ 0.006277366541326046,
764
+ 0.017449310049414635,
765
+ -0.035394586622714996,
766
+ 0.007309201173484325,
767
+ 0.0201618243008852,
768
+ -0.038875941187143326,
769
+ 0.006901388987898827,
770
+ 0.021353811025619507,
771
+ -0.04197891056537628,
772
+ 0.005615231115370989,
773
+ 0.021482987329363823,
774
+ -0.044972267001867294,
775
+ 0.0038469291757792234,
776
+ 0.02095557004213333,
777
+ -0.04788981378078461,
778
+ 0.0018653281731531024,
779
+ 0.020031819120049477,
780
+ -0.05081408843398094,
781
+ -0.00014616991393268108,
782
+ 0.018887164071202278
783
+ ],
784
+ [
785
+ 0.041665900498628616,
786
+ 0.04644826427102089,
787
+ 0.07724159955978394,
788
+ 0.06889741122722626,
789
+ 0.07856184244155884,
790
+ 0.13755755126476288,
791
+ 0.08888623118400574,
792
+ 0.10273010283708572,
793
+ 0.18770478665828705,
794
+ 0.10409166663885117,
795
+ 0.12153571844100952,
796
+ 0.2295888066291809,
797
+ 0.11617765575647354,
798
+ 0.13659057021141052,
799
+ 0.2621166706085205,
800
+ 0.12616446614265442,
801
+ 0.14888973534107208,
802
+ 0.2901250720024109,
803
+ 0.13466258347034454,
804
+ 0.15917454659938812,
805
+ 0.3126468062400818,
806
+ 0.14190249145030975,
807
+ 0.16750091314315796,
808
+ 0.3314359188079834,
809
+ 0.14818254113197327,
810
+ 0.17438767850399017,
811
+ 0.34811294078826904,
812
+ 0.1538163274526596,
813
+ 0.18024739623069763,
814
+ 0.3633021414279938,
815
+ 0.365497350692749,
816
+ 0.3581995368003845,
817
+ 0.35088542103767395,
818
+ 0.34727540612220764,
819
+ 0.3497294485569,
820
+ 0.3590857684612274,
821
+ 0.366073876619339,
822
+ 0.3722895681858063,
823
+ 0.3952684998512268,
824
+ 0.39616790413856506,
825
+ 0.0306396521627903,
826
+ 0.02315478026866913,
827
+ 0.020674167200922966,
828
+ 0.0541083849966526,
829
+ 0.039156123995780945,
830
+ 0.03597668185830116,
831
+ 0.07291583716869354,
832
+ 0.05113911256194115,
833
+ 0.048078753054142,
834
+ 0.08785141259431839,
835
+ 0.06023723632097244,
836
+ 0.05775951221585274,
837
+ 0.09968066215515137,
838
+ 0.0672508031129837,
839
+ 0.06552024185657501,
840
+ 0.10900193452835083,
841
+ 0.07301993668079376,
842
+ 0.07219257205724716,
843
+ 0.11649126559495926,
844
+ 0.0778062641620636,
845
+ 0.07776165008544922,
846
+ 0.12322770804166794,
847
+ 0.08162706345319748,
848
+ 0.08214502781629562,
849
+ 0.12892073392868042,
850
+ 0.08458641171455383,
851
+ 0.08537810295820236,
852
+ 0.1341981142759323,
853
+ 0.08686579018831253,
854
+ 0.08767224848270416
855
+ ]
856
+ ],
857
+ [
858
+ [
859
+ 0.0,
860
+ 0.0,
861
+ 0.0,
862
+ 0.06573251634836197,
863
+ 0.0,
864
+ 0.0,
865
+ -0.04171482101082802,
866
+ 0.055990710854530334,
867
+ 0.8763954043388367,
868
+ 0.0,
869
+ -0.0006664131069555879,
870
+ 0.0005193803808651865,
871
+ -0.002541124587878585
872
+ ],
873
+ [
874
+ 0.0,
875
+ 0.0,
876
+ 0.0,
877
+ 0.3761773705482483,
878
+ 0.0,
879
+ 0.0,
880
+ 0.15819795429706573,
881
+ 0.22519806027412415,
882
+ 0.31361594796180725,
883
+ 0.0,
884
+ 0.02349717728793621,
885
+ 0.035984065383672714,
886
+ 0.057680174708366394
887
+ ]
888
+ ],
889
+ [
890
+ [
891
+ -0.15317592024803162,
892
+ 0.0022776976693421602,
893
+ -0.0009320594253949821,
894
+ -0.00012546770449262112,
895
+ 0.0,
896
+ 0.0014520460972562432,
897
+ -0.0015851958887651563,
898
+ -0.0011786657851189375
899
+ ],
900
+ [
901
+ 0.9882915019989014,
902
+ 0.09682300686836243,
903
+ 0.08487523347139359,
904
+ 0.06578352302312851,
905
+ 0.0,
906
+ 0.04951652139425278,
907
+ 0.06362643837928772,
908
+ 0.061268579214811325
909
+ ]
910
+ ],
911
+ [
912
+ [
913
+ 0.05171596631407738,
914
+ 0.2228323519229889,
915
+ -0.14441145956516266,
916
+ 0.3582628071308136
917
+ ],
918
+ [
919
+ 0.5066923499107361,
920
+ 0.5439659953117371,
921
+ 0.6763138771057129,
922
+ 0.7307645678520203
923
+ ]
924
+ ],
925
+ [
926
+ [
927
+ 0.17030708491802216,
928
+ 0.05116971209645271,
929
+ -0.09447753429412842,
930
+ 0.0008643743931315839,
931
+ 0.0091375932097435,
932
+ 0.0027675393503159285,
933
+ -0.056102462112903595
934
+ ],
935
+ [
936
+ 0.3859453797340393,
937
+ 0.24690480530261993,
938
+ 0.4153101444244385,
939
+ 0.02985573559999466,
940
+ 0.06587829440832138,
941
+ 0.14103403687477112,
942
+ 0.9984298348426819
943
+ ]
944
+ ],
945
+ [
946
+ [
947
+ 1.1466649993963074e-05,
948
+ -6.314022175502032e-05,
949
+ -3.7018828152213246e-05,
950
+ -6.448457861552015e-05
951
+ ],
952
+ [
953
+ 0.018885519355535507,
954
+ 0.01785615272819996,
955
+ 0.019242558628320694,
956
+ 0.01844671182334423
957
+ ]
958
+ ],
959
+ [
960
+ [
961
+ 0.4066544473171234,
962
+ 0.6006699204444885,
963
+ 0.655235230922699,
964
+ 0.5799534320831299
965
+ ],
966
+ [
967
+ 0.12175972759723663,
968
+ 0.1993582844734192,
969
+ 0.11979881674051285,
970
+ 0.19911807775497437
971
+ ]
972
+ ],
973
+ [
974
+ [
975
+ 0.022955210879445076,
976
+ -0.00010563172691036016,
977
+ -0.011366250924766064,
978
+ -0.0015581168700009584,
979
+ 0.04793407768011093,
980
+ -0.006341158412396908,
981
+ 0.00134763540700078,
982
+ 0.001134222256951034
983
+ ],
984
+ [
985
+ 0.09014830738306046,
986
+ 0.00810029823333025,
987
+ 0.033557258546352386,
988
+ 0.013230394572019577,
989
+ 0.2136397510766983,
990
+ 0.012250976637005806,
991
+ 0.019736526533961296,
992
+ 0.007935167290270329
993
+ ]
994
+ ]
995
+ ],
996
+ "action_token_size": 64,
997
+ "arch": "STTransformerDecoder",
998
+ "attn_drop": 0.0,
999
+ "d_action": 28,
1000
+ "d_actions": [
1001
+ 7,
1002
+ 13,
1003
+ 2,
1004
+ 4,
1005
+ 100,
1006
+ 8,
1007
+ 14,
1008
+ 35,
1009
+ 70,
1010
+ 6,
1011
+ 16,
1012
+ 120,
1013
+ 18,
1014
+ 80,
1015
+ 8,
1016
+ 5,
1017
+ 40,
1018
+ 7,
1019
+ 8,
1020
+ 9,
1021
+ 40,
1022
+ 49,
1023
+ 120,
1024
+ 7,
1025
+ 105,
1026
+ 7,
1027
+ 2,
1028
+ 7,
1029
+ 20,
1030
+ 35,
1031
+ 8,
1032
+ 350,
1033
+ 13,
1034
+ 40,
1035
+ 12,
1036
+ 21,
1037
+ 60,
1038
+ 4,
1039
+ 8
1040
+ ],
1041
+ "d_model": 256,
1042
+ "dataloader_apply_corruption": true,
1043
+ "dataloader_apply_mask": true,
1044
+ "dataloader_mask_ratio_min": 0.2,
1045
+ "drop_action_ratio": 0.0,
1046
+ "factored_vocab_size": 512,
1047
+ "image_vocab_size": 262144,
1048
+ "init_actions": true,
1049
+ "jointly_predict_actions": false,
1050
+ "jointly_predict_states": true,
1051
+ "label_drop_prob": 0.5,
1052
+ "max_corrupt_rate": 0.2,
1053
+ "mlp_bias": true,
1054
+ "mlp_drop": 0.0,
1055
+ "mlp_ratio": 4.0,
1056
+ "non_mlm_ratio": 0.2,
1057
+ "num_factored_vocabs": 2,
1058
+ "num_heads": 8,
1059
+ "num_layers": 32,
1060
+ "num_prompt_frames": 4,
1061
+ "proj_bias": true,
1062
+ "qk_norm": false,
1063
+ "qkv_bias": false,
1064
+ "random_dummy_action": true,
1065
+ "shared_action_mlps": true,
1066
+ "use_actions": true,
1067
+ "use_mup": true
1068
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fcfcc359c49c1c40c9dcd9557f6a7cad76180a6bce3e017794d7d44e69eaae41
3
+ size 1469035472