ksridhar commited on
Commit
8e8a936
1 Parent(s): b5d2331

Upload model

Browse files
config.json ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "ONLY_RL_TASKS": true,
3
+ "_name_or_path": "checkpoints/jat-regent-medium-10.0lamda-1.0MDM-1.0ADM-p95DN-resnet18_512ADT_embeddings/checkpoint-27726",
4
+ "action_loss_coef": 1.0,
5
+ "action_vocab_size": 18,
6
+ "activation_function": "gelu_new",
7
+ "architectures": [
8
+ "JatRegentModel"
9
+ ],
10
+ "atari_dist_multiplier": 1.0,
11
+ "atari_dist_type": "resnet18_512",
12
+ "attention_dropout": 0.0,
13
+ "attention_layers": [
14
+ "global",
15
+ "local",
16
+ "global",
17
+ "local",
18
+ "global",
19
+ "local",
20
+ "global",
21
+ "local",
22
+ "global",
23
+ "local",
24
+ "global",
25
+ "local"
26
+ ],
27
+ "attention_types": [
28
+ [
29
+ [
30
+ "global",
31
+ "local"
32
+ ],
33
+ 6
34
+ ]
35
+ ],
36
+ "auto_map": {
37
+ "AutoConfig": "configuration_jat.JatConfig",
38
+ "AutoModelForCausalLM": "modeling_jat_regent.JatRegentModel"
39
+ },
40
+ "bos_token_id": 50256,
41
+ "classifier_dropout": 0.1,
42
+ "dist_normalizer": "p95",
43
+ "embed_dropout": 0.0,
44
+ "eos_token_id": 50256,
45
+ "finetune_num_demos": null,
46
+ "hidden_size": 768,
47
+ "image_size": 224,
48
+ "initializer_range": 0.02,
49
+ "intermediate_size": null,
50
+ "lamda": 10.0,
51
+ "layer_norm_epsilon": 1e-05,
52
+ "max_continuous_size": 513,
53
+ "max_discrete_value": 212,
54
+ "max_position_embeddings": 40,
55
+ "model_type": "jat",
56
+ "mujoco_dist_multiplier": 1.0,
57
+ "num_channels": 3,
58
+ "num_contexts": 20,
59
+ "num_heads": 12,
60
+ "num_layers": 12,
61
+ "observation_loss_coef": 0.0,
62
+ "patch_size": 16,
63
+ "resid_dropout": 0.0,
64
+ "tokenizer_class": "GPT2TokenizerFast",
65
+ "torch_dtype": "float32",
66
+ "transformers_version": "4.41.2",
67
+ "use_atari_embeddings": true,
68
+ "use_cache": true,
69
+ "use_global_atari_actions": true,
70
+ "vocab_size": 50257,
71
+ "window_size": 256
72
+ }
configuration_jat.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import GPTNeoConfig
2
+
3
+
4
+ class JatConfig(GPTNeoConfig):
5
+ r"""
6
+ This is the configuration class to store the configuration of a [`JatModel`]. It is used to instantiate a Jat
7
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with
8
+ the defaults will yield a similar configuration to that of the ... (TODO)
9
+
10
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
11
+ documentation from [`PretrainedConfig`] for more information.
12
+
13
+
14
+ Args:
15
+ vocab_size (`int`, *optional*, defaults to 50257):
16
+ Vocabulary size of the GPT Neo model. Defines the number of different tokens that can be represented by the
17
+ `inputs_ids` passed when calling [`GPTNeoModel`]. Vocabulary size of the model. Defines the different
18
+ tokens that can be represented by the *inputs_ids* passed to the forward method of [`GPTNeoModel`].
19
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
20
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
21
+ just in case (e.g., 512 or 1024 or 2048).
22
+ hidden_size (`int`, *optional*, defaults to 2048):
23
+ Dimensionality of the encoder layers and the pooler layer.
24
+ num_layers (`int`, *optional*, defaults to 24):
25
+ Number of hidden layers in the Transformer encoder.
26
+ attention_types (`List`, *optional*, defaults to `[[["global", "local"], 12]]`):
27
+ The type of attention for each layer in a `List` of the following format `[[["attention_type"],
28
+ num_layerss]]` e.g. for a 24 layer model `[[["global"], 24]]` or `[[["global", "local"], 12]]` Choose the
29
+ value of `attention_type` from `["global", "local"]`
30
+ num_heads (`int`, *optional*, defaults to 16):
31
+ Number of attention heads for each attention layer in the Transformer encoder.
32
+ intermediate_size (`int`, *optional*, defaults to 8192):
33
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
34
+ window_size (`int`, *optional*, defaults to 256):
35
+ The size of the sliding window for local attention.
36
+ activation_function (`str` or `function`, *optional*, defaults to `"gelu_new"`):
37
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
38
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
39
+ resid_dropout (`float`, *optional*, defaults to 0.0):
40
+ Residual dropout used in the attention pattern.
41
+ embed_dropout (`float`, *optional*, defaults to 0.0):
42
+ The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
43
+ attention_dropout (`float`, *optional*, defaults to 0.0):
44
+ The dropout ratio for the attention probabilities.
45
+ classifier_dropout (`float`, *optional*, defaults to 0.1):
46
+ Argument used when doing token classification, used in the model [`GPTNeoForTokenClassification`]. The
47
+ dropout ratio for the hidden layer.
48
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
49
+ The epsilon used by the layer normalization layers.
50
+ initializer_range (`float`, *optional*, defaults to 0.02):
51
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
52
+ use_cache (`bool`, *optional*, defaults to `True`):
53
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
54
+ relevant if `config.is_decoder=True`.
55
+ bos_token_id (`int`, *optional*, defaults to 50256):
56
+ The id of the beginning of sentence token in the vocabulary.
57
+ eos_token_id (`int`, *optional*, defaults to 50256):
58
+ The id of the end of sentence token in the vocabulary.
59
+ max_continuous_size (`int`, *optional*, default to 376):
60
+ The maximum size of the continuous values.
61
+ max_discrete_value (`int`, *optional*, default to 18):
62
+ The maximum value of the discrete values.
63
+ image_size (`int`, *optional*, defaults to 224):
64
+ The size (resolution) of each image.
65
+ patch_size (`int`, *optional*, defaults to 16):
66
+ The size (resolution) of each patch.
67
+ observation_loss_coef (`float`, *optional*, defaults to 0.005):
68
+ The coefficient for the observation loss. When set to 0.0, the observation is not even predicted.
69
+ action_loss_coef (`float`, *optional*, defaults to 0.995):
70
+ The coefficient for the action loss.
71
+ """
72
+
73
+ model_type = "jat"
74
+
75
+ def __init__(
76
+ self,
77
+ vocab_size=50257,
78
+ max_position_embeddings=2048,
79
+ hidden_size=2048,
80
+ num_layers=24,
81
+ attention_types=[[["global", "local"], 12]],
82
+ num_heads=16,
83
+ intermediate_size=None,
84
+ window_size=256,
85
+ activation_function="gelu_new",
86
+ resid_dropout=0.0,
87
+ embed_dropout=0.0,
88
+ attention_dropout=0.0,
89
+ classifier_dropout=0.1,
90
+ layer_norm_epsilon=1e-5,
91
+ initializer_range=0.02,
92
+ use_cache=True,
93
+ bos_token_id=50256,
94
+ eos_token_id=50256,
95
+ max_continuous_size=377,
96
+ max_discrete_value=18,
97
+ image_size=224,
98
+ num_channels=3,
99
+ patch_size=16,
100
+ observation_loss_coef=0.005,
101
+ action_loss_coef=0.995,
102
+ **kwargs,
103
+ ):
104
+ super().__init__(
105
+ vocab_size,
106
+ max_position_embeddings,
107
+ hidden_size,
108
+ num_layers,
109
+ attention_types,
110
+ num_heads,
111
+ intermediate_size,
112
+ window_size,
113
+ activation_function,
114
+ resid_dropout,
115
+ embed_dropout,
116
+ attention_dropout,
117
+ classifier_dropout,
118
+ layer_norm_epsilon,
119
+ initializer_range,
120
+ use_cache,
121
+ bos_token_id,
122
+ eos_token_id,
123
+ **kwargs,
124
+ )
125
+ self.max_continuous_size = max_continuous_size
126
+ self.max_discrete_value = max_discrete_value
127
+ self.image_size = image_size
128
+ self.num_channels = num_channels
129
+ self.patch_size = patch_size
130
+ self.observation_loss_coef = observation_loss_coef
131
+ self.action_loss_coef = action_loss_coef
132
+
133
+
134
+ JatConfig.register_for_auto_class()
generation_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 50256,
4
+ "eos_token_id": 50256,
5
+ "transformers_version": "4.41.2"
6
+ }
modeling_jat_regent.py ADDED
@@ -0,0 +1,745 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ from dataclasses import dataclass
3
+ from typing import List, Optional, Tuple, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from gymnasium import spaces
9
+ from torch import BoolTensor, FloatTensor, LongTensor, Tensor, nn
10
+ from transformers import GPTNeoModel, GPTNeoPreTrainedModel
11
+ from transformers.modeling_outputs import ModelOutput
12
+ from transformers.models.vit.modeling_vit import ViTPatchEmbeddings
13
+ import torch.nn.functional as F
14
+
15
+
16
+ from jat.configuration_jat import JatConfig
17
+ from jat.processing_jat import JatProcessor
18
+ from jat.modeling_jat import JatModel, compute_mse_loss, cyclic_expand_dim, JatOutput
19
+ from jat_regent.utils import build_index_vector, get_task_info, collect_all_data, process_row_of_obs_atari_full_without_mask, retrieve_vector, myprint, L2dist, get_dist_stats, get_images_of_retrieved_obs, get_emb_transform_model_dim, get_optional_suffix
20
+ from jat_regent.atari_utils import convert_local_to_global_action, convert_global_to_local_action
21
+ from jat_regent.eval.rl import SEEN_TASK_NAME_TO_ENV_ID, UNSEEN_TASK_NAME_TO_ENV_ID
22
+ from PIL import Image
23
+ import os
24
+ from copy import deepcopy
25
+ from pytorch_msssim import ssim
26
+ import json
27
+
28
+
29
+ def cross_entropy_from_softmax(softmax_probs, targets, reduction="mean", epsilon=1e-9):
30
+ """
31
+ Calculate the cross entropy loss given softmax_probs and targets.
32
+
33
+ :param softmax_probs: tensor containing softmax probabilities
34
+ :param targets: tensor containing the target classes (not one-hot encoded)
35
+ :return: cross entropy loss
36
+ """
37
+ assert len(softmax_probs.shape) == 2, "softmax_probs should be of shape (batch_size, num_classes)"
38
+ assert len(targets.shape) == 1, "targets should be of shape (batch_size,)"
39
+
40
+ # Convert targets to one-hot encoding
41
+ targets_one_hot = F.one_hot(targets, num_classes=softmax_probs.shape[1]).float() # shape: (batch_size, num_classes)
42
+
43
+ # Calculate the cross entropy loss
44
+ softmax_probs = softmax_probs.clamp(min=epsilon, max=1-epsilon) # to avoid NaNs from log(0) and instabilities from log(1)
45
+ log_softmax_probs = softmax_probs.log() # safe to take log as softmax_probs are non-zero
46
+ loss = -torch.sum(targets_one_hot * log_softmax_probs, dim=1)
47
+
48
+ if reduction == "mean":
49
+ return loss.mean()
50
+ elif reduction == "sum":
51
+ return loss.sum()
52
+ elif reduction == "none":
53
+ return loss
54
+ else:
55
+ raise ValueError("reduction should be one of 'mean', 'sum', or 'none'")
56
+
57
+
58
+ def compute_ce_loss_from_softmax(
59
+ logits: FloatTensor, labels: torch.LongTensor, mask: Optional[BoolTensor], weights: Optional[FloatTensor] = None
60
+ ) -> FloatTensor:
61
+ """
62
+ Compute the Cross Entropy (CE) loss between predicted logits and true class labels, considering valid timesteps.
63
+
64
+ Args:
65
+ logits (`FloatTensor` of shape `(batch_size, max_seq_len, [inner_size,] num_classes)`):
66
+ Predicted logits at the output of the model.
67
+ labels (`torch.LongTensor` of shape `(batch_size, max_seq_len, [inner_size,])`):
68
+ Ground truth class labels.
69
+ mask (`BoolTensor` of shape `(batch_size, max_seq_len)`, *optional*):
70
+ Boolean mask indicating valid timesteps.
71
+ weights (`FloatTensor` of shape `(batch_size, max_seq_len)`, *optional*):
72
+ Weights to be applied to the loss.
73
+
74
+ Returns:
75
+ loss (`FloatTensor` of shape `(,)`):
76
+ CE loss between predicted logits and true class labels.
77
+ """
78
+ if mask is not None:
79
+ logits = logits[mask.bool()] # (Y, X, C)
80
+ labels = labels[mask.bool()] # (Y, X)
81
+ if weights is not None:
82
+ weights = weights[mask.bool()] # (Y,)
83
+ else:
84
+ logits = logits.flatten(end_dim=2) # (B, L, X, C) -> (B*L, X, C)
85
+ labels = labels.flatten(end_dim=1) # (B, L, X) -> (B*L, X)
86
+ if weights is not None:
87
+ weights = weights.flatten(end_dim=1) # (B, L) -> (B*L,)
88
+
89
+ loss = cross_entropy_from_softmax(logits.view(-1, logits.size(-1)), labels.view(-1), reduction="none") # (Y*X,) # we don't use F.cross_entropy here to avoid double softmax
90
+ loss = loss.view(labels.size()) # (Y, X)
91
+ loss = loss.mean(-1) # (Y,)
92
+
93
+ # Multiply the loss by the weights
94
+ if weights is not None:
95
+ loss = loss * weights # (Y,)
96
+
97
+ # Average the loss
98
+ loss = loss.mean()
99
+
100
+ return loss
101
+
102
+
103
+ def crazy_relu(x, beta):
104
+ return nn.LeakyReLU(beta)(x) - (1-beta) * nn.ReLU()(x-1)
105
+
106
+
107
+ class JatRegentModel(JatModel):
108
+ """
109
+ Jat Regent model.
110
+ """
111
+ def __init__(self, config: JatConfig) -> None:
112
+ super().__init__(config)
113
+ hidden_size = config.hidden_size
114
+ action_vocab_size = config.action_vocab_size
115
+
116
+ if config.ONLY_RL_TASKS:
117
+ self.single_discrete_decoder = nn.Linear(hidden_size, action_vocab_size, bias=False)
118
+ self.N = config.action_vocab_size
119
+ else:
120
+ self.N = config.vocab_size
121
+ self.multi_discrete_decoder = None # not needed
122
+ self.image_decoder = None # not needed
123
+ self.num_contexts = config.num_contexts # used in get_next_action() at evaluation in an env only
124
+ self.lamda = config.lamda # used in get_next_action() at evaluation in an env only
125
+ self.use_global_atari_actions = config.use_global_atari_actions
126
+ self.dist_multipliers = {'mujoco': config.mujoco_dist_multiplier, 'atari': config.atari_dist_multiplier}
127
+ self.dist_normalizer = config.dist_normalizer
128
+ self.atari_dist_type = config.atari_dist_type
129
+ self.use_atari_embeddings = config.use_atari_embeddings
130
+ self.finetune_num_demos = config.finetune_num_demos if hasattr(config, 'finetune_num_demos') else None
131
+ if self.use_atari_embeddings:
132
+ self.image_encoder = None
133
+ self.emb_dim_full = (512,)
134
+
135
+ # print number of parameters
136
+ num_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
137
+ myprint(f"number of parameters: {num_params / 1e6:.4f}M")
138
+
139
+ def retrieval_setup(self,
140
+ task,
141
+ dataset,
142
+ num_demos, # to retrieve from
143
+ device,
144
+ batch_size_retrieval=16, # for atari envs on gpu
145
+ nb_cores_autofaiss=8, # for vector obs envs on cpu cores
146
+ ):
147
+ # setup
148
+ rew_key, attn_key, obs_key, act_key, B, obs_dim, act_dim = get_task_info(task)
149
+ extra_key = 'discrete_RandP_action_logits' if task.startswith("atari") or task.startswith("babyai") else 'continuous_RandP_actions'
150
+ optional_suffix = get_optional_suffix(task, self.atari_dist_type, self.finetune_num_demos)
151
+ mean_dist, std_dist, max_dist, p80, p85, p90, p95, p99 = get_dist_stats(task=task, optional_suffix=optional_suffix)
152
+
153
+ # get embedding model
154
+ if task.startswith("atari"):
155
+ self.emb_transform, self.emb_model, emb_dim, self.emb_model_full = get_emb_transform_model_dim(self.atari_dist_type, self.device, return_emb_weights=True)
156
+ obs_dim = emb_dim # overwrite for atari_dist_type
157
+
158
+ kwargs = {'B': B,
159
+ 'obs_dim': obs_dim,
160
+ 'attn_key': attn_key,
161
+ 'obs_key': obs_key,
162
+ 'device': device,
163
+ 'task': task,
164
+ 'batch_size_retrieval': batch_size_retrieval,
165
+ 'nb_cores_autofaiss': nb_cores_autofaiss,
166
+ 'verbose': False,
167
+ 'atari_dist_type': self.atari_dist_type,
168
+ }
169
+ raw_obs_dim = obs_dim
170
+ if task.startswith("atari"): # overwrite raw_obs_dim because raw obs in atari are (4, 84, 84) and raw obs in babyai have 64 extra dim
171
+ raw_obs_dim = (4, 84, 84)
172
+ elif task.startswith("babyai"):
173
+ raw_obs_dim = (obs_dim[0]+64,)
174
+
175
+ # save
176
+ self.task = task
177
+ self.dataset = dataset
178
+ self.obs_key = obs_key
179
+ self.act_key = act_key
180
+ self.rew_key = rew_key
181
+ self.attn_key = attn_key
182
+ self.obs_dim = obs_dim
183
+ self.act_dim = act_dim
184
+ self.extra_key = extra_key
185
+ self.kwargs = kwargs
186
+ self.raw_obs_dim = raw_obs_dim
187
+ self.max_dist = max_dist
188
+ self.mean_dist = mean_dist
189
+ self.std_dist = std_dist
190
+ self.p80, self.p85, self.p90, self.p95, self.p99 = p80, p85, p90, p95, p99
191
+ self.dist_normalizer_value = {'std': std_dist, 'max': max_dist, 'p80': p80, 'p85': p85, 'p90': p90, 'p95': p95, 'p99': p99}[self.dist_normalizer]
192
+ if self.dist_normalizer_value == 0.0: self.dist_normalizer_value = 1.0
193
+
194
+ # for retrieval,
195
+ all_rows_of_obs_OG, all_attn_masks_OG, all_row_idxs, all_datarows_dict = collect_all_data(dataset, task, obs_key, num_demos, return_datarows_dict=True, atari_dist_type=self.atari_dist_type)
196
+ if task.startswith("babyai"):
197
+ # for each mission in task,
198
+ self.all_indices = {}
199
+ self.knn_index = {}
200
+ for mission_idx, mission in enumerate(all_row_idxs.keys()):
201
+ # create index, collect subset of data that we can retrieve from
202
+ myprint(('*'*50) + f'{mission=} - {mission_idx+1}/{len(all_row_idxs.keys())}')
203
+ self.all_indices[mission], self.knn_index[mission] = build_index_vector(all_rows_of_obs_OG=all_rows_of_obs_OG[mission],
204
+ all_attn_masks_OG=all_attn_masks_OG[mission],
205
+ all_row_idxs=all_row_idxs[mission],
206
+ kwargs=kwargs)
207
+ else:
208
+ # create index, collect subset of data that we can retrieve from
209
+ self.all_indices, self.knn_index = build_index_vector(all_rows_of_obs_OG=all_rows_of_obs_OG,
210
+ all_attn_masks_OG=all_attn_masks_OG,
211
+ all_row_idxs=all_row_idxs,
212
+ kwargs=kwargs)
213
+
214
+ # for retrieval inside retrieve()
215
+ self.datarows = all_datarows_dict
216
+
217
+
218
+ # # for checking if first env state is similar to retrieval episode's first states
219
+ # if task.startswith("mujoco"):
220
+ # local_path = f"dataset_jat_regent/{task}"
221
+ # with open(f"{local_path}/eps_2_rows_tokenized.json", 'r') as f:
222
+ # eps_2_rows_tokenized = json.load(f)
223
+ # eps_2_rows_tokenized = {int(k): v for k, v in eps_2_rows_tokenized.items()}
224
+ # row_idxs_of_first_state_of_demos = [eps_2_rows_tokenized[eps][0] for eps in range(num_demos)]
225
+ # self.first_states_of_demos = [np.array(dataset['train'][row_idx][obs_key][0]) for row_idx in row_idxs_of_first_state_of_demos]
226
+ # else:
227
+ # self.first_states_of_demos = None
228
+
229
+ def output_rl(
230
+ self,
231
+ transformer_outputs,
232
+ continuous_observations: Optional[FloatTensor] = None,
233
+ discrete_observations: Optional[LongTensor] = None,
234
+ image_observations: Optional[FloatTensor] = None,
235
+ continuous_actions: Optional[FloatTensor] = None,
236
+ discrete_actions: Optional[LongTensor] = None,
237
+ rewards: Optional[FloatTensor] = None,
238
+ attention_mask: Optional[BoolTensor] = None,
239
+ return_loss: bool = True,
240
+ return_dict: Optional[bool] = None,
241
+ loss_weight: Optional[FloatTensor] = None,
242
+ exp_lamda_distances: Optional[FloatTensor] = None,
243
+ continuous_RandP_actions: Optional[FloatTensor] = None,
244
+ discrete_RandP_action_logits: Optional[FloatTensor] = None,
245
+ ):
246
+ hidden_states = transformer_outputs.last_hidden_state
247
+ loss, observation_loss, action_loss = None, None, None
248
+
249
+ # Observations
250
+ assert rewards is not None
251
+ observations_mask = attention_mask[:, 1::2] if attention_mask is not None else None
252
+ assert self.observation_loss_coef == 0.0, f'{self.observation_loss_coef=} should be 0.0 as we are not predicting observations!'
253
+ # warnings.warn("observation_loss_coef is 0.0, skipping memory-intensive observations prediction.")
254
+ pred_observations = None
255
+ observation_loss = 0.0
256
+
257
+ # Actions
258
+ actions_mask = attention_mask[:, ::2] if attention_mask is not None else None
259
+ if continuous_actions is not None:
260
+ act_size = continuous_actions.shape[-1]
261
+ continuous_actions = cyclic_expand_dim(continuous_actions, self.config.max_continuous_size)
262
+ continuous_RandP_actions = cyclic_expand_dim(continuous_RandP_actions, self.config.max_continuous_size)
263
+ init_pred_actions = self.continuous_decoder(hidden_states[:, ::2])
264
+ pred_actions = self.continuous_action_interpolation(init_pred_actions, exp_lamda_distances, continuous_RandP_actions, beta=0.0)
265
+ if return_loss:
266
+ action_loss = compute_mse_loss(pred_actions, continuous_actions, actions_mask, weights=loss_weight) # loss_weight is usually 50 for metaworld, 10 for mujoco (except two tasks where it is 20, 50), 1 for the rest!
267
+ pred_actions = pred_actions[..., :act_size]
268
+ elif discrete_actions is not None:
269
+ init_pred_actions = self.single_discrete_decoder(hidden_states[:, ::2])
270
+ pred_actions = self.discrete_action_interpolation(init_pred_actions, exp_lamda_distances, discrete_RandP_action_logits, beta=0.0)
271
+ if return_loss:
272
+ action_loss = compute_ce_loss_from_softmax(pred_actions, discrete_actions, actions_mask, weights=loss_weight)
273
+
274
+ # Return output
275
+ if return_loss:
276
+ loss = self.observation_loss_coef * observation_loss + self.action_loss_coef * action_loss
277
+
278
+ if not return_dict:
279
+ output = (pred_observations, pred_actions) + transformer_outputs[1:]
280
+ return ((loss, observation_loss, action_loss) + output) if loss is not None else output
281
+
282
+ return JatOutput(
283
+ loss=loss,
284
+ observation_loss=observation_loss,
285
+ action_loss=action_loss,
286
+ pred_observations=pred_observations,
287
+ pred_actions=pred_actions,
288
+ past_key_values=transformer_outputs.past_key_values,
289
+ hidden_states=transformer_outputs.hidden_states,
290
+ attentions=transformer_outputs.attentions,
291
+ )
292
+
293
+ def shifted_crazy_relu(self, x, beta):
294
+ return 2 * crazy_relu(0.5*(x+1), beta) - 1
295
+
296
+ def continuous_action_interpolation(self, init_pred_actions, exp_lamda_distances, continuous_RandP_actions, beta=0.0):
297
+ batch_size, max_seq_len, act_size = init_pred_actions.shape
298
+ assert (init_pred_actions.shape == (batch_size, max_seq_len, act_size) and
299
+ exp_lamda_distances.shape == (batch_size, max_seq_len, 1) and
300
+ continuous_RandP_actions.shape == (batch_size, max_seq_len, act_size)), f'{init_pred_actions.shape=}, {exp_lamda_distances.shape=}, {continuous_RandP_actions.shape=}, {(batch_size, max_seq_len, act_size)=}'
301
+
302
+ """ MCNN interpolation (https://arxiv.org/abs/2310.06171) """
303
+ act_fn = self.shifted_crazy_relu
304
+ final_actions = exp_lamda_distances * continuous_RandP_actions + 10.0 * (1 - exp_lamda_distances) * act_fn(init_pred_actions, beta=beta)
305
+ return final_actions
306
+
307
+ def discrete_action_interpolation(self, init_pred_actions, exp_lamda_distances, discrete_RandP_action_logits, beta=0.0):
308
+ batch_size, max_seq_len, action_vocab_size = init_pred_actions.shape
309
+ assert (init_pred_actions.shape == (batch_size, max_seq_len, action_vocab_size) and
310
+ exp_lamda_distances.shape == (batch_size, max_seq_len, 1) and
311
+ discrete_RandP_action_logits.shape == (batch_size, max_seq_len, action_vocab_size)), f'{init_pred_actions.shape=}, {exp_lamda_distances.shape=}, {discrete_RandP_action_logits.shape=}, {(batch_size, max_seq_len, action_vocab_size)=}'
312
+
313
+ """ MCNN-like interpolation """
314
+ # print(f'{torch.round(discrete_RandP_action_logits[:, -1],decimals=2)=}')
315
+ # print(f'{torch.round(F.softmax(init_pred_actions, dim=-1)[:, -1],decimals=2)=}')
316
+ # print(f'{torch.round(exp_lamda_distances[:, -1],decimals=2)=}')
317
+ # print(f'first term: {torch.round((exp_lamda_distances * discrete_RandP_action_logits)[:, -1],decimals=2)}')
318
+ # print(f'second term: {torch.round(((1 - exp_lamda_distances) * F.softmax(init_pred_actions, dim=-1))[:, -1],decimals=2)}')
319
+ final_actions = exp_lamda_distances * discrete_RandP_action_logits + (1 - exp_lamda_distances) * F.softmax(init_pred_actions, dim=-1)
320
+ return final_actions
321
+
322
+ # Copied the forward function from the Parent class with the addition of the last 3 args in the input args and in output_rl args
323
+ def forward(
324
+ self,
325
+ input_ids: Optional[LongTensor] = None,
326
+ pixel_values: Optional[FloatTensor] = None,
327
+ continuous_observations: Optional[FloatTensor] = None,
328
+ discrete_observations: Optional[LongTensor] = None,
329
+ image_observations: Optional[FloatTensor] = None,
330
+ continuous_actions: Optional[FloatTensor] = None,
331
+ discrete_actions: Optional[LongTensor] = None,
332
+ rewards: Optional[FloatTensor] = None,
333
+ past_key_values: Optional[Tuple[Tuple[FloatTensor]]] = None,
334
+ attention_mask: Optional[BoolTensor] = None,
335
+ token_type_ids: Optional[LongTensor] = None,
336
+ position_ids: Optional[LongTensor] = None,
337
+ return_loss: bool = True,
338
+ use_cache: Optional[bool] = None,
339
+ output_attentions: Optional[bool] = None,
340
+ output_hidden_states: Optional[bool] = None,
341
+ return_dict: Optional[bool] = None,
342
+ loss_weight: Optional[FloatTensor] = None,
343
+ exp_lamda_distances: Optional[FloatTensor] = None,
344
+ continuous_RandP_actions: Optional[FloatTensor] = None,
345
+ discrete_RandP_action_logits: Optional[FloatTensor] = None,
346
+ ) -> JatOutput:
347
+
348
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
349
+
350
+ # Textual tasks
351
+ if input_ids is not None or pixel_values is not None:
352
+ inputs_embeds, attention_mask = self.embed_textual(input_ids, pixel_values, attention_mask)
353
+ # RL tasks
354
+ elif (
355
+ continuous_observations is not None or discrete_observations is not None or image_observations is not None
356
+ ):
357
+ inputs_embeds, attention_mask = self.embed_rl(
358
+ continuous_observations,
359
+ discrete_observations,
360
+ image_observations,
361
+ continuous_actions,
362
+ discrete_actions,
363
+ rewards,
364
+ attention_mask,
365
+ )
366
+ else:
367
+ raise ValueError("Input not provided.")
368
+
369
+ # Pass through transformer
370
+ transformer_outputs = self.transformer(
371
+ past_key_values=past_key_values,
372
+ attention_mask=attention_mask,
373
+ token_type_ids=token_type_ids,
374
+ position_ids=position_ids,
375
+ inputs_embeds=inputs_embeds,
376
+ use_cache=use_cache,
377
+ output_attentions=output_attentions,
378
+ output_hidden_states=output_hidden_states,
379
+ return_dict=return_dict,
380
+ )
381
+
382
+ if input_ids is not None or pixel_values is not None:
383
+ return self.output_textual(transformer_outputs, input_ids, attention_mask, return_loss, return_dict)
384
+ else:
385
+ return self.output_rl(
386
+ transformer_outputs,
387
+ continuous_observations,
388
+ discrete_observations,
389
+ image_observations,
390
+ continuous_actions,
391
+ discrete_actions,
392
+ rewards,
393
+ attention_mask,
394
+ return_loss,
395
+ return_dict,
396
+ loss_weight,
397
+ exp_lamda_distances,
398
+ continuous_RandP_actions,
399
+ discrete_RandP_action_logits,
400
+ )
401
+
402
+
403
+ def reset_rl(self):
404
+ self.steps = 0
405
+
406
+ def process(
407
+ self,
408
+ processor: JatProcessor,
409
+ continuous_observation: Optional[List[float]] = None,
410
+ discrete_observation: Optional[List[int]] = None,
411
+ text_observation: Optional[str] = None,
412
+ image_observation: Optional[np.ndarray] = None,
413
+ action_space: Union[spaces.Box, spaces.Discrete] = None,
414
+ reward: Optional[float] = None,
415
+ deterministic: bool = True,
416
+ context_window: Optional[int] = None,
417
+ ):
418
+ # Get the maximum sequence length
419
+ max_length = self.config.max_position_embeddings // 2
420
+
421
+ # Get the maximum sequence length
422
+ ### see script/train_jat.py > L161.
423
+ ### None ==> value set to 512 in jat/processing_jat.py > L354 and then // 2 in L355.
424
+ ### weirdly, the value in script/eval_jat.py is set as 256 so it will be // 2 again in L355.
425
+ # max_length = 64 if self.task.startswith("atari") else None
426
+
427
+ # Convert everything to lists
428
+ def to_list(x):
429
+ return x.tolist() if isinstance(x, np.ndarray) else x
430
+
431
+ continuous_observation = to_list(continuous_observation)
432
+ discrete_observation = to_list(discrete_observation)
433
+
434
+ # get babyai mission within task
435
+ if self.task.startswith("babyai"):
436
+ mission = deepcopy(text_observation)
437
+ assert mission in self.knn_index.keys(), f'{mission=} should be in {self.knn_index.keys()=}'
438
+
439
+ # Add a fake action to the end of the sequence
440
+ if isinstance(action_space, spaces.Box):
441
+ fake_continuous_action = [0.0 for _ in range(action_space.shape[0])]
442
+ fake_discrete_action = None
443
+ elif isinstance(action_space, spaces.Discrete):
444
+ fake_continuous_action = None
445
+ fake_discrete_action = 0
446
+
447
+ continuous_observations = [continuous_observation] if continuous_observation is not None else None
448
+ discrete_observations = [discrete_observation] if discrete_observation is not None else None
449
+ text_observations = [text_observation] if text_observation is not None else None
450
+ image_observations = [image_observation] if image_observation is not None else None
451
+ continuous_actions = [fake_continuous_action] if fake_continuous_action is not None else None
452
+ discrete_actions = [fake_discrete_action] if fake_discrete_action is not None else None
453
+ rewards = [reward] if reward is not None else [0.0]
454
+
455
+ # Add the batch dimension
456
+ continuous_observations = [continuous_observations] if continuous_observations is not None else None
457
+ discrete_observations = [discrete_observations] if discrete_observations is not None else None
458
+ text_observations = [text_observations] if text_observations is not None else None
459
+ image_observations = [image_observations] if image_observations is not None else None
460
+ continuous_actions = [continuous_actions] if continuous_actions is not None else None
461
+ discrete_actions = [discrete_actions] if discrete_actions is not None else None
462
+ rewards = [rewards]
463
+
464
+ # Process the inputs
465
+ processed = processor(
466
+ continuous_observations=continuous_observations,
467
+ discrete_observations=discrete_observations,
468
+ text_observations=text_observations,
469
+ image_observations=image_observations,
470
+ continuous_actions=continuous_actions,
471
+ discrete_actions=discrete_actions,
472
+ rewards=rewards,
473
+ truncation=True,
474
+ truncation_side="left",
475
+ max_length=max_length,
476
+ return_tensors="pt",
477
+ )
478
+
479
+ assert (((self.act_key == 'continuous_actions' and processed[self.act_key].shape == (1, 1, self.act_dim)) or # zeros
480
+ (self.act_key == 'discrete_actions' and processed[self.act_key].shape == (1, 1))) and
481
+ processed[self.obs_key].shape == (1, 1, *self.raw_obs_dim) and
482
+ processed[self.rew_key].shape == (1, 1)), f'{processed[self.act_key].shape=}, {processed[self.obs_key].shape=}, {processed[self.rew_key].shape=}, {self.act_dim=}, {self.raw_obs_dim=}'
483
+
484
+ # save babyai mission
485
+ if self.task.startswith("babyai"):
486
+ processed['mission'] = mission
487
+
488
+ # save action_space and deterministic
489
+ processed['action_space'] = action_space
490
+ processed['deterministic'] = deterministic
491
+
492
+ return processed
493
+
494
+ def retrieve(
495
+ self,
496
+ all_processed: List[dict],
497
+ num_to_retrieve: int,
498
+ ):
499
+ self.steps += 1
500
+ # Set num envs
501
+ num_envs = len(all_processed)
502
+
503
+ # Get obs from processed and make batch
504
+ row_of_obs = [all_processed[idx][self.obs_key][0].numpy() for idx in range(num_envs)]
505
+ row_of_obs = np.concatenate(row_of_obs)
506
+ assert row_of_obs.shape == (num_envs, *self.raw_obs_dim) and isinstance(row_of_obs, np.ndarray)
507
+ if self.task.startswith("atari"):
508
+ row_of_obs = process_row_of_obs_atari_full_without_mask(row_of_obs)
509
+ row_of_obs = torch.from_numpy(row_of_obs).to(self.device)
510
+ with torch.no_grad():
511
+ row_of_obs = self.emb_model(self.emb_transform(row_of_obs)).cpu().numpy()
512
+ elif self.task.startswith("babyai"):
513
+ row_of_obs = row_of_obs[:, :148] # removing last 64 text tokens
514
+ assert row_of_obs.shape == (num_envs, *self.obs_dim) and isinstance(row_of_obs, np.ndarray)
515
+
516
+ # Retrieve indices
517
+ if self.task.startswith("babyai"):
518
+ retrieved_indices = []
519
+ for idx in range(num_envs):
520
+ mission = all_processed[idx]['mission']
521
+ retrieved_indices_mission = retrieve_vector(row_of_obs=row_of_obs[idx:idx+1],
522
+ knn_index=self.knn_index[mission],
523
+ all_indices=self.all_indices[mission],
524
+ num_to_retrieve=num_to_retrieve,
525
+ kwargs=self.kwargs)
526
+ retrieved_indices.append(retrieved_indices_mission) # appending (1, 1, 2)
527
+ retrieved_indices = np.concatenate(retrieved_indices, axis=0)
528
+ assert retrieved_indices.shape == (num_envs, num_to_retrieve, 2)
529
+ else:
530
+ retrieved_indices = retrieve_vector(row_of_obs=row_of_obs,
531
+ knn_index=self.knn_index,
532
+ all_indices=self.all_indices,
533
+ num_to_retrieve=num_to_retrieve,
534
+ kwargs=self.kwargs)
535
+
536
+ # Return action
537
+ all_retrieved_act = []
538
+ all_retrieved_obs = []
539
+ all_retrieved_rew = []
540
+ for all_row_idx_and_i in retrieved_indices:
541
+ all_retrieved_act.append([])
542
+ all_retrieved_obs.append([])
543
+ all_retrieved_rew.append([])
544
+ for row_idx, i in all_row_idx_and_i:
545
+ datarow = self.datarows[int(row_idx)]
546
+ temp_a = datarow[self.act_key][int(i)]
547
+ if self.task.startswith("atari") and self.use_global_atari_actions:
548
+ temp_a = convert_local_to_global_action( temp_a, self.task )
549
+ all_retrieved_act[-1].append(temp_a)
550
+ all_retrieved_obs[-1].append(datarow[self.obs_key][int(i)])
551
+ all_retrieved_rew[-1].append(datarow[self.rew_key][int(i)])
552
+
553
+ return all_retrieved_act, all_retrieved_obs, all_retrieved_rew, row_of_obs
554
+
555
+ def get_distances(
556
+ self,
557
+ all_retrieved_obs: np.ndarray,
558
+ all_processed: List[dict],
559
+ query_obs: np.ndarray,
560
+ ):
561
+ num_envs = len(all_processed)
562
+
563
+ # Process retrieved obs like in retrieve
564
+ num_contexts = all_retrieved_obs.shape[1] + 1
565
+ assert all_retrieved_obs.shape == (num_envs, num_contexts - 1, *self.raw_obs_dim) and isinstance(all_retrieved_obs, np.ndarray)
566
+ if self.task.startswith("atari"):
567
+ all_retrieved_obs = all_retrieved_obs.reshape(num_envs * (num_contexts - 1), *self.raw_obs_dim)
568
+ all_retrieved_obs = process_row_of_obs_atari_full_without_mask(all_retrieved_obs)
569
+ all_retrieved_obs = torch.from_numpy(all_retrieved_obs).to(self.device)
570
+ with torch.no_grad():
571
+ all_retrieved_obs = self.emb_model(self.emb_transform(all_retrieved_obs)).cpu().numpy()
572
+ all_retrieved_obs = all_retrieved_obs.reshape(num_envs, num_contexts - 1, *self.obs_dim)
573
+ elif self.task.startswith("babyai"):
574
+ all_retrieved_obs = all_retrieved_obs[:, :, :148]
575
+ assert all_retrieved_obs.shape == (num_envs, num_contexts - 1, *self.obs_dim) and isinstance(all_retrieved_obs, np.ndarray)
576
+
577
+ # Compute distances
578
+ all_distances = []
579
+ for idx in range(num_envs):
580
+ first_state = all_retrieved_obs[idx, 0:1]
581
+ distances = [0.0]
582
+ for i in range(1, num_contexts - 1):
583
+ curr_state = all_retrieved_obs[idx, i:i+1]
584
+ dist = L2dist(first_state, curr_state)
585
+ distances.append(dist)
586
+ curr_state = query_obs[idx:idx+1]
587
+ dist = L2dist(first_state, curr_state)
588
+ distances.append(dist)
589
+ all_distances.append(distances)
590
+ all_distances = np.array(all_distances)
591
+ assert all_distances.shape == (num_envs, num_contexts), f'{all_distances.shape=}, {num_envs=}, {num_contexts=}'
592
+
593
+ # distances: divide by std
594
+ all_distances = all_distances / self.dist_normalizer_value
595
+ if self.task.startswith("mujoco"):
596
+ all_distances = all_distances * self.dist_multipliers['mujoco']
597
+ elif self.task.startswith("atari"):
598
+ all_distances = all_distances * self.dist_multipliers['atari']
599
+ print(f'{self.dist_normalizer_value=}')
600
+ print(f'{all_distances=}')
601
+
602
+ return all_distances
603
+
604
+ @torch.no_grad()
605
+ def get_next_action(
606
+ self,
607
+ all_processed: List[dict],
608
+ return_retrieved_obs: bool = False,
609
+ ):
610
+ num_envs = len(all_processed)
611
+ num_contexts = self.num_contexts
612
+
613
+ # Get the retrieved data
614
+ all_retrieved_act, all_retrieved_obs, all_retrieved_rew, row_of_obs = self.retrieve(all_processed, num_to_retrieve=num_contexts - 1)
615
+ if return_retrieved_obs:
616
+ all_retrieved_images = get_images_of_retrieved_obs(deepcopy(all_retrieved_obs), self.task)
617
+
618
+ # Get the distances
619
+ all_retrieved_obs = np.stack(all_retrieved_obs).astype(np.int32 if self.obs_key == 'discrete_observations' else np.float32)
620
+ assert all_retrieved_obs.shape == (num_envs, num_contexts - 1, *self.raw_obs_dim), f'{all_retrieved_obs.shape=}, {num_envs=}, {self.raw_obs_dim=}, {num_contexts-1=}'
621
+ all_distances = self.get_distances(all_retrieved_obs=all_retrieved_obs, all_processed=all_processed, query_obs=row_of_obs)
622
+
623
+ # Batch retrieved data
624
+ all_retrieved_act = np.stack(all_retrieved_act).astype(np.int32 if self.act_key == 'discrete_actions' else np.float32)
625
+ all_retrieved_rew = np.stack(all_retrieved_rew).astype(np.float32)
626
+ assert (((self.act_key == 'continuous_actions' and all_retrieved_act.shape == (num_envs, num_contexts - 1, self.act_dim)) or
627
+ (self.act_key == 'discrete_actions' and all_retrieved_act.shape == (num_envs, num_contexts - 1))) and
628
+ all_retrieved_rew.shape == (num_envs, num_contexts - 1)), f'{all_retrieved_act.shape=}, {all_retrieved_rew.shape=}, {num_envs=}, {self.act_dim=}, {self.raw_obs_dim=}, {num_contexts-1=}'
629
+
630
+ # Batch query data (already tensors) # query data is already int32/float32 after processing
631
+ all_query_act = torch.stack([all_processed[idx][self.act_key][0] for idx in range(num_envs)])
632
+ all_query_obs = np.stack([all_processed[idx][self.obs_key][0] for idx in range(num_envs)])
633
+ all_query_rew = torch.stack([all_processed[idx][self.rew_key][0] for idx in range(num_envs)])
634
+ assert (((self.act_key == 'continuous_actions' and all_query_act.shape == (num_envs, 1, self.act_dim)) or
635
+ (self.act_key == 'discrete_actions' and all_query_act.shape == (num_envs, 1))) and
636
+ all_query_obs.shape == (num_envs, 1, *self.raw_obs_dim) and
637
+ all_query_rew.shape == (num_envs, 1)), f'{all_query_act.shape=}, {all_query_obs.shape=}, {all_query_rew.shape=}, {num_envs=}, {self.act_dim=}, {self.raw_obs_dim=}'
638
+
639
+ # Collect attn
640
+ attn_weights = np.ones((num_envs, num_contexts)).astype(np.float32)
641
+
642
+ # Compute exp_lamda_distances
643
+ exp_lamda_distances = np.exp(-self.lamda * all_distances)[:, :, np.newaxis]
644
+ assert exp_lamda_distances.shape == (num_envs, num_contexts, 1), f'{exp_lamda_distances.shape=}, {num_envs=}, {num_contexts=}'
645
+
646
+ # Compute extra_key
647
+ all_extra_key = []
648
+ for idx in range(num_envs):
649
+ RandP_action = all_retrieved_act[idx, 0]
650
+ if self.extra_key == 'continuous_RandP_actions':
651
+ extra_key = [RandP_action for _ in range(num_contexts)]
652
+ elif self.extra_key == 'discrete_RandP_action_logits':
653
+ extra_key = []
654
+ for d in all_distances[idx]:
655
+ d = min(1.0, max(0.0, d))
656
+ curr_logits = [1.0/self.N * d for _ in range(self.N)]
657
+ curr_logits[RandP_action] = (1.0 + (self.N - 1.0)*(1.0 - d))/self.N
658
+ extra_key.append(curr_logits)
659
+ extra_key = np.stack(extra_key)
660
+ all_extra_key.append(extra_key)
661
+ all_extra_key = np.stack(all_extra_key).astype(np.float32)
662
+
663
+ if self.extra_key == 'continuous_RandP_actions':
664
+ assert all_extra_key.shape == (num_envs, num_contexts, self.act_dim), f'{all_extra_key.shape=}, {num_envs=}, {num_contexts=}, {self.act_dim=}'
665
+ elif self.extra_key == 'discrete_RandP_action_logits':
666
+ assert all_extra_key.shape == (num_envs, num_contexts, self.N), f'{all_extra_key.shape=}, {num_envs=}, {num_contexts=}, {self.N=}'
667
+
668
+ # Tensorify
669
+ all_retrieved_act = torch.from_numpy(all_retrieved_act)
670
+ all_retrieved_rew = torch.from_numpy(all_retrieved_rew)
671
+ attn_weights = torch.from_numpy(attn_weights).to(self.device)
672
+ exp_lamda_distances = torch.from_numpy(exp_lamda_distances).to(self.device)
673
+ all_extra_key = torch.from_numpy(all_extra_key).to(self.device)
674
+
675
+ # Concat retrieved and query batches
676
+ all_act = torch.cat([all_retrieved_act, all_query_act], dim=1).to(self.device)
677
+ all_obs = np.concatenate([all_retrieved_obs, all_query_obs], axis=1)
678
+ if self.use_atari_embeddings and self.task.startswith("atari"):
679
+ all_obs = all_obs.reshape(num_envs * num_contexts, *self.raw_obs_dim)
680
+ all_obs = process_row_of_obs_atari_full_without_mask(all_obs)
681
+ all_obs = torch.from_numpy(all_obs).to(self.device)
682
+ with torch.no_grad():
683
+ all_obs = self.emb_model_full(self.emb_transform(all_obs)).reshape(num_envs, num_contexts, *self.emb_dim_full)
684
+ else:
685
+ all_obs = torch.from_numpy(all_obs).to(self.device)
686
+ all_rew = torch.cat([all_retrieved_rew, all_query_rew], dim=1).to(self.device)
687
+
688
+ # Collect action_space, deterministic from all_processed
689
+ all_action_space = [all_processed[idx]['action_space'] for idx in range(num_envs)]
690
+ all_deterministic = [all_processed[idx]['deterministic'] for idx in range(num_envs)]
691
+ ## assert that all action_space and deterministic are same for all envs
692
+ assert all([action_space == all_action_space[0] for action_space in all_action_space]), f'{all_action_space=}'
693
+ assert all([deterministic == all_deterministic[0] for deterministic in all_deterministic]), f'{all_deterministic=}'
694
+ ## then just use first one!
695
+ action_space = all_action_space[0]
696
+ deterministic = all_deterministic[0]
697
+
698
+ # Forward pass
699
+ if self.use_atari_embeddings and self.task.startswith("atari"):
700
+ final_obs_key = 'continuous_observations'
701
+ else:
702
+ final_obs_key = self.obs_key
703
+ outputs = self.forward(**{final_obs_key: all_obs,
704
+ self.act_key: all_act,
705
+ self.rew_key: all_rew,
706
+ self.attn_key: attn_weights,
707
+ 'exp_lamda_distances': exp_lamda_distances,
708
+ self.extra_key: all_extra_key,
709
+ }, return_loss=False)
710
+
711
+ # Return the predicted action
712
+ if self.act_key == 'continuous_actions':
713
+ self.last_continuous_action = outputs.pred_actions[:, -1].cpu().numpy()
714
+
715
+ assert self.last_continuous_action.shape == (num_envs, self.act_dim), f'{self.last_continuous_action.shape=}, {num_envs=}, {self.act_dim=}'
716
+
717
+ myprint(f'L2dist(RandP action, Pred action): {[L2dist(all_retrieved_act[idx, 0].cpu().numpy(), self.last_continuous_action[idx]) for idx in range(num_envs)]}')
718
+ self.last_continuous_action = list(self.last_continuous_action) # list of arrays
719
+ return self.last_continuous_action if not return_retrieved_obs else (self.last_continuous_action, all_retrieved_images)
720
+
721
+ elif self.act_key == 'discrete_actions':
722
+ act_n = self.config.action_vocab_size if (self.task.startswith('atari') and self.use_global_atari_actions) else action_space.n
723
+ logits = outputs.pred_actions[:, -1, : act_n]
724
+ assert logits.shape == (num_envs, act_n), f'{logits.shape=}, {num_envs=}, {act_n=}'
725
+ if deterministic:
726
+ # myprint(f'{all_extra_key[:, -1, : action_space.n]=}')
727
+ # myprint(f'{logits=}')
728
+ self.last_discrete_action = logits.argmax(dim=-1, keepdim=True).cpu().numpy().reshape(-1)
729
+ else: # sample
730
+ self.last_discrete_action = torch.multinomial(logits.softmax(dim=-1), num_samples=1).cpu().numpy().reshape(-1)
731
+
732
+ assert self.last_discrete_action.shape == (num_envs,), f'{self.last_discrete_action.shape=}, {num_envs=}'
733
+
734
+ self.last_discrete_action = list(self.last_discrete_action) # list of ints
735
+ myprint(f'RandP action: {all_retrieved_act[:, 0].cpu().numpy().tolist()} vs Pred action: {self.last_discrete_action}')
736
+
737
+ if self.task.startswith("atari") and self.use_global_atari_actions:
738
+ self.last_discrete_action = [convert_global_to_local_action(a, self.task) for a in self.last_discrete_action]
739
+ myprint(f'[IN LOCAL ACTION] RandP action: {[convert_global_to_local_action(a, self.task) for a in all_retrieved_act[:, 0].cpu().numpy().tolist()]} vs Pred action: {self.last_discrete_action}')
740
+ myprint(f'[IN LOCAL ACTION] diff: {[convert_global_to_local_action(a, self.task) - b for a, b in zip(all_retrieved_act[:, 0].cpu().numpy().tolist(), self.last_discrete_action)]}')
741
+
742
+ return self.last_discrete_action if not return_retrieved_obs else (self.last_discrete_action, all_retrieved_images)
743
+
744
+
745
+ JatRegentModel.register_for_auto_class("AutoModelForCausalLM")
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f64fcecd190c7ed6c8c913e44d0ecc47ceaca4d52a84d1e96d18ebe985db8ef5
3
+ size 510060470