tt1225 committed
Commit 70f1502 (parent: 1af01cc)

Update modeling_internvl_chat.py

Files changed (1):
  1. modeling_internvl_chat.py (+14, -9)
modeling_internvl_chat.py CHANGED

@@ -16,6 +16,8 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import ModelOutput, logging
 
+from einops import rearrange
+
 from .configuration_internvl_chat import InternVLChatConfig
 from .conversation import get_conv_template
 from .modeling_intern_vit import InternVisionModel
@@ -280,13 +282,6 @@ class InternVLChatModel(PreTrainedModel):
         model_inputs = tokenizer(query, return_tensors='pt')
         input_ids = model_inputs['input_ids'].cuda()
         attention_mask = model_inputs['attention_mask'].cuda()
-
-        if verbose:
-            print(f"hehe: {self.num_image_token * num_patches}")
-            print(f"hehe: {query}")
-            print(f"hehe: {input_ids.shape}")
-            print(f"hehe: {attention_mask.shape}")
-
         generation_config['eos_token_id'] = eos_token_id
         generation_output = self.generate(
             pixel_values=pixel_values,
@@ -324,8 +319,18 @@ class InternVLChatModel(PreTrainedModel):
         if visual_features is not None:
             vit_embeds = visual_features
         else:
-            vit_embeds = self.extract_feature(pixel_values)
-
+            vit_embeds = []
+            num_chunks = 2
+            pixel_values_splitted = pixel_values.chunk(num_chunks)
+
+            for pixel_values_ in pixel_values_splitted:
+                pixel_values_ = pixel_values_.flatten(0, 1)
+                v_feats = self.extract_feature(pixel_values_)  # examples: (18, 256, 896) = 2 * (9, 256, 896)
+                v_feats = rearrange(v_feats, "(b n) t c -> b n t c", n=num_chunks)  # b: batch_size, n: num_patches, t: num_tokens, c: hidden_size
+                vit_embeds.append(v_feats)
+
+            vit_embeds = torch.cat(vit_embeds)
+
         print(vit_embeds.shape)
 
         input_embeds = self.language_model.get_input_embeddings()(input_ids)
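
For reference, here is a minimal standalone sketch of the pattern the new else branch implements: split pixel_values along its leading dimension, run the vision encoder on each flattened chunk, then restore the batch/patch axes with einops.rearrange before concatenating. extract_feature is stubbed out here, and the (batch, patches, C, H, W) input layout, the 448x448 tile size, and the (256, 896) per-tile output shape are assumptions read off the in-diff comment, not guarantees from the model.

import torch
from einops import rearrange

def extract_feature(pixel_values: torch.Tensor) -> torch.Tensor:
    # Stand-in for InternVLChatModel.extract_feature: maps each input
    # tile to num_tokens feature vectors of size hidden_size.
    num_tokens, hidden_size = 256, 896  # assumed from the in-diff comment
    return torch.zeros(pixel_values.shape[0], num_tokens, hidden_size)

# Hypothetical input: 4 images, 9 tiles each, 448x448 pixels per tile.
pixel_values = torch.zeros(4, 9, 3, 448, 448)
num_chunks = 2

vit_embeds = []
for pixel_values_ in pixel_values.chunk(num_chunks):           # two (2, 9, 3, 448, 448) chunks
    pixel_values_ = pixel_values_.flatten(0, 1)                # (18, 3, 448, 448) for the ViT
    v_feats = extract_feature(pixel_values_)                   # (18, 256, 896)
    v_feats = rearrange(v_feats, "(b n) t c -> b n t c", n=9)  # (2, 9, 256, 896)
    vit_embeds.append(v_feats)

vit_embeds = torch.cat(vit_embeds)  # (4, 9, 256, 896): full batch restored
print(vit_embeds.shape)

The trade-off is the usual one for chunked inference: running the encoder num_chunks times sequentially lowers peak activation memory at the cost of some latency. Note that this sketch passes n=9 (tiles per image) to rearrange, while the commit passes n=num_chunks; which value is correct depends on how the caller lays out pixel_values, so treat the grouping here as illustrative rather than a drop-in replacement.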