Update modeling_internvl_chat.py
modeling_internvl_chat.py (+14 -9)
@@ -16,6 +16,8 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import ModelOutput, logging
 
+from einops import rearrange
+
 from .configuration_internvl_chat import InternVLChatConfig
 from .conversation import get_conv_template
 from .modeling_intern_vit import InternVisionModel
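The new einops dependency exists only for the reshape in the chunked feature-extraction hunk below. For reference, a minimal, self-contained illustration of that exact rearrange pattern; the shapes are the example shapes from the diff's own comments, not anything guaranteed by the model:

```python
import torch
from einops import rearrange

# Example extract_feature() output shape taken from the diff's comment:
# (18, 256, 896), i.e. a flattened leading dimension of 2 * 9.
v_feats = torch.randn(18, 256, 896)

# "(b n) t c -> b n t c" splits the flattened leading axis back into two;
# n=2 mirrors num_chunks=2 in the diff.
unflattened = rearrange(v_feats, "(b n) t c -> b n t c", n=2)
print(unflattened.shape)  # torch.Size([9, 2, 256, 896])

# Equivalent pure-PyTorch reshape, since "(b n)" is b-major, n-minor:
assert torch.equal(unflattened, v_feats.view(9, 2, 256, 896))
```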
@@ -280,13 +282,6 @@ class InternVLChatModel(PreTrainedModel):
         model_inputs = tokenizer(query, return_tensors='pt')
         input_ids = model_inputs['input_ids'].cuda()
         attention_mask = model_inputs['attention_mask'].cuda()
-
-        if verbose:
-            print(f"hehe: {self.num_image_token * num_patches}")
-            print(f"hehe: {query}")
-            print(f"hehe: {input_ids.shape}")
-            print(f"hehe: {attention_mask.shape}")
-
         generation_config['eos_token_id'] = eos_token_id
         generation_output = self.generate(
             pixel_values=pixel_values,
@@ -324,8 +319,18 @@ class InternVLChatModel(PreTrainedModel):
         if visual_features is not None:
             vit_embeds = visual_features
         else:
-            vit_embeds = self.extract_feature(pixel_values)
-
+            vit_embeds = []
+            num_chunks = 2
+            pixel_values_splitted = pixel_values.chunk(num_chunks)
+
+            for pixel_values_ in pixel_values_splitted:
+                pixel_values_ = pixel_values_.flatten(0, 1)
+                v_feats = self.extract_feature(pixel_values_)  # e.g. (18, 256, 896) = 2 * (9, 256, 896)
+                v_feats = rearrange(v_feats, "(b n) t c -> b n t c", n=num_chunks)  # b: batch_size, n: num_chunks, t: num_tokens, c: hidden_size
+                vit_embeds.append(v_feats)
+
+            vit_embeds = torch.cat(vit_embeds)  # re-join the per-chunk features along the batch dimension
+
         print(vit_embeds.shape)
 
         input_embeds = self.language_model.get_input_embeddings()(input_ids)
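The rewritten else branch trades one large vision-encoder forward pass for num_chunks smaller ones, lowering peak activation memory without changing the concatenated result. Below is a minimal sketch of the same chunk-encode-concatenate pattern, with a toy encoder standing in for self.extract_feature; the encoder and all shapes are illustrative assumptions, and the sketch skips the flatten/rearrange round-trip the real code needs for its extra patch dimension:

```python
import torch

def encoder(pixel_values: torch.Tensor) -> torch.Tensor:
    # Toy stand-in for self.extract_feature: (B, 3, H, W) -> (B, tokens, hidden).
    b = pixel_values.shape[0]
    pooled = pixel_values.mean(dim=(1, 2, 3))        # one scalar per image
    return pooled.view(b, 1, 1).expand(b, 256, 896)  # fake (B, 256, 896) features

def extract_in_chunks(pixel_values: torch.Tensor, num_chunks: int = 2) -> torch.Tensor:
    # Same pattern as the diff: split the batch, encode each piece, re-concatenate.
    feats = [encoder(chunk) for chunk in pixel_values.chunk(num_chunks)]
    return torch.cat(feats)

pixel_values = torch.randn(18, 3, 64, 64)
assert torch.allclose(encoder(pixel_values), extract_in_chunks(pixel_values))
print(extract_in_chunks(pixel_values).shape)  # torch.Size([18, 256, 896])
```

One caveat worth noting: torch.Tensor.chunk tolerates a batch that num_chunks does not divide evenly (the last chunk simply comes out smaller), while the rearrange(..., n=num_chunks) in the diff assumes each chunk's flattened leading dimension is an exact multiple of num_chunks.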