update llama_action model
Browse files
- modeling_llama_action.py +4 -14
modeling_llama_action.py
CHANGED
@@ -200,29 +200,19 @@ class LlamaActionForCausalLM(LlamaForCausalLM):
|
|
200 |
past_key_values=None,
|
201 |
attention_mask=None,
|
202 |
use_cache=None,
|
203 |
-
|
204 |
-
prefix="",
|
205 |
-
total=0,
|
206 |
**kwargs):
|
207 |
batch_size = input_ids.size(0)
|
208 |
seq_length = input_ids.size(1)
|
209 |
n_frames = seq_length // self.num_image_patches
|
210 |
attention_mask_length = n_frames * (self.num_image_patches + self.num_action_embeddings)
|
211 |
-
if
|
212 |
-
|
213 |
-
pbar = tqdm(total=total - len(input_ids[0]), desc=prefix, leave=False)
|
214 |
-
postfix = f"Frame [{n_frames + 1}/{total // self.num_image_patches}]"
|
215 |
-
pbar.set_postfix_str(postfix)
|
216 |
-
else:
|
217 |
-
pbar.update()
|
218 |
|
219 |
if seq_length % self.num_image_patches != 0:
|
220 |
n_last_frame_tokens = seq_length % self.num_image_patches
|
221 |
attention_mask_length += n_last_frame_tokens
|
222 |
-
|
223 |
-
if show_progress:
|
224 |
-
postfix = f"Frame [{n_frames + 1}/{total // self.num_image_patches}]"
|
225 |
-
pbar.set_postfix_str(postfix)
|
226 |
attention_mask = torch.ones((batch_size, attention_mask_length), device=input_ids.device, dtype=torch.long)
|
227 |
# cut decoder_input_ids if past_key_values is used
|
228 |
if past_key_values is not None and len(past_key_values) > 0:
|
|
|
200 |
past_key_values=None,
|
201 |
attention_mask=None,
|
202 |
use_cache=None,
|
203 |
+
progress_bar=None,
|
|
|
|
|
204 |
**kwargs):
|
205 |
batch_size = input_ids.size(0)
|
206 |
seq_length = input_ids.size(1)
|
207 |
n_frames = seq_length // self.num_image_patches
|
208 |
attention_mask_length = n_frames * (self.num_image_patches + self.num_action_embeddings)
|
209 |
+
if progress_bar is not None:
|
210 |
+
progress_bar.update()
|
|
|
|
|
|
|
|
|
|
|
211 |
|
212 |
if seq_length % self.num_image_patches != 0:
|
213 |
n_last_frame_tokens = seq_length % self.num_image_patches
|
214 |
attention_mask_length += n_last_frame_tokens
|
215 |
+
|
|
|
|
|
|
|
216 |
attention_mask = torch.ones((batch_size, attention_mask_length), device=input_ids.device, dtype=torch.long)
|
217 |
# cut decoder_input_ids if past_key_values is used
|
218 |
if past_key_values is not None and len(past_key_values) > 0:
|