koukyo1994 committed (verified)
Commit 29e517a · Parent(s): 3d2e679

update llama_action model

Files changed (1): modeling_llama_action.py (+4, -14)
modeling_llama_action.py CHANGED

@@ -200,29 +200,19 @@ class LlamaActionForCausalLM(LlamaForCausalLM):
             past_key_values=None,
             attention_mask=None,
             use_cache=None,
-            show_progress=False,
-            prefix="",
-            total=0,
+            progress_bar=None,
             **kwargs):
         batch_size = input_ids.size(0)
         seq_length = input_ids.size(1)
         n_frames = seq_length // self.num_image_patches
         attention_mask_length = n_frames * (self.num_image_patches + self.num_action_embeddings)
-        if show_progress:
-            if past_key_values is None or len(past_key_values) == 0:
-                pbar = tqdm(total=total - len(input_ids[0]), desc=prefix, leave=False)
-                postfix = f"Frame [{n_frames + 1}/{total // self.num_image_patches}]"
-                pbar.set_postfix_str(postfix)
-            else:
-                pbar.update()
+        if progress_bar is not None:
+            progress_bar.update()
 
         if seq_length % self.num_image_patches != 0:
             n_last_frame_tokens = seq_length % self.num_image_patches
             attention_mask_length += n_last_frame_tokens
-        else:
-            if show_progress:
-                postfix = f"Frame [{n_frames + 1}/{total // self.num_image_patches}]"
-                pbar.set_postfix_str(postfix)
+
         attention_mask = torch.ones((batch_size, attention_mask_length), device=input_ids.device, dtype=torch.long)
         # cut decoder_input_ids if past_key_values is used
         if past_key_values is not None and len(past_key_values) > 0:
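
For context, a minimal sketch of how a caller might drive the new interface: instead of the model constructing its own tqdm bar from show_progress/prefix/total, the caller now owns the bar and passes it through generate's keyword arguments, which transformers forwards to prepare_inputs_for_generation. The names model, input_ids, and max_new_tokens below are assumptions for illustration, not taken from the commit.

# Hypothetical usage sketch, assuming `model` is a loaded LlamaActionForCausalLM
# and `input_ids` is a prepared prompt tensor.
from tqdm import tqdm

max_new_tokens = 512  # assumed generation budget for the example

# prepare_inputs_for_generation runs once per decoding step, and after this
# commit each call does progress_bar.update(), so the bar advances roughly
# one tick per generated token.
with tqdm(total=max_new_tokens, desc="generating", leave=False) as pbar:
    output_ids = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        progress_bar=pbar,  # forwarded to prepare_inputs_for_generation via **kwargs
    )

This also fixes a latent bug in the removed code: pbar was a local variable, so the pbar.update() branch on a later call referenced a bar that no longer existed in scope; a caller-owned bar persists across steps naturally.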