---
tags:
- deep-reinforcement-learning
- reinforcement-learning
---


This repository contains pretrained model weights for the [Decision Transformer](https://github.com/kzl/decision-transformer).
Weights are available for four Atari games: Breakout, Pong, Qbert, and Seaquest. They can be found in the `checkpoints` directory.
We share models trained with a single seed (123), whereas the paper used three random seeds.


### Usage

```bash
git clone https://huggingface.co/edbeeching/decision_transformer_atari
conda env create -f conda_env.yml
```

Then, you can use the model like this:

```python
import torch

from decision_transformer_atari import GPTConfig, GPT

# Hyperparameters used for the released Atari checkpoints.
vocab_size = 4                      # action vocabulary size for this checkpoint
block_size = 90                     # context length in tokens (3 tokens per timestep)
model_type = "reward_conditioned"
timesteps = 2654                    # maximum timestep index for the timestep embedding

mconf = GPTConfig(
    vocab_size,
    block_size,
    n_layer=6,
    n_head=8,
    n_embd=128,
    model_type=model_type,
    max_timestep=timesteps,
)
model = GPT(mconf)

# Load the pretrained weights for the desired game.
checkpoint_path = "checkpoints/Breakout_123.pth"  # or Pong, Qbert, Seaquest
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint)
```
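
Once the weights are loaded, inference follows the interface of the Atari GPT from the upstream [decision-transformer](https://github.com/kzl/decision-transformer) code. The sketch below is illustrative only: the forward signature (`states, actions, targets, rtgs, timesteps`), the input shapes, the `(logits, loss)` return value, and the return-to-go target of 90 are assumptions to verify against the cloned repository.

```python
import torch

model.eval()

context_length = 30        # 30 timesteps -> block_size of 90 tokens (rtg, state, action)
state_dim = 4 * 84 * 84    # four stacked 84x84 grayscale frames, flattened

# Dummy inputs with the shapes the upstream forward() is assumed to expect.
states = torch.zeros(1, context_length, state_dim)             # float frames in [0, 1]
actions = torch.zeros(1, context_length, 1, dtype=torch.long)  # previous actions
rtgs = torch.full((1, context_length, 1), 90.0)                # target return-to-go (assumed value)
timesteps = torch.zeros(1, 1, 1, dtype=torch.int64)            # timestep of the first token in context

with torch.no_grad():
    logits, _ = model(states, actions, targets=None, rtgs=rtgs, timesteps=timesteps)

# Greedy action for the most recent timestep.
action = logits[:, -1, :].argmax(dim=-1)
```

In an actual evaluation loop, the context is built from the environment's recent frames, actions, and returns-to-go, and an action is re-sampled at every step; see the sampling utilities in the upstream repository for the reference implementation.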