khang119966 committed
Commit c39b2dc
1 Parent(s): 573c459

Upload 4 files

modeling_colinternvl2.py ADDED
@@ -0,0 +1,115 @@
+ import math
+ from typing import Any, ClassVar, List, Optional, Tuple, Union
+
+ import torch
+ from torch import nn
+
+ from .modeling_internvl_chat import InternVLChatConfig, InternVLChatModel
+
+
+ class ColInternVL2(InternVLChatModel):
+     """
+     ColInternVL2: InternVL2 with a ColBERT-style multi-vector head, following the
+     "ColPali: Efficient Document Retrieval with Vision Language Models" paper.
+     """
+
+     # main_input_name: ClassVar[str] = "doc_input_ids"  # transformers-related
+
+     def __init__(self, config: InternVLChatConfig):
+         super().__init__(config=config)
+         self.dim = 128
+         self.custom_text_proj = nn.Linear(self.language_model.model.config.hidden_size, self.dim)  # , bias=False)
+         self.padding_side = "left"
+         self.img_context_token_id = 151648
+         # self.post_init()
+         self.init_linear()
+
+     def init_linear(self):
+         # Re-initialize the projection head uniformly in [-1/sqrt(fan_in), 1/sqrt(fan_in)],
+         # mirroring nn.Linear's default initialization.
+         stdv = 1.0 / math.sqrt(self.custom_text_proj.weight.size(1))
+         self.custom_text_proj.weight.data.uniform_(-stdv, stdv)
+         if self.custom_text_proj.bias is not None:
+             self.custom_text_proj.bias.data.uniform_(-stdv, stdv)
+
+     def forward(
+         self,
+         pixel_values: torch.FloatTensor = None,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         image_flags: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         statistics: Optional[torch.LongTensor] = None,
+         loss_weight: Optional[List] = None,
+         loss_reduction_all_gather: Optional[bool] = False,
+         **kwargs,
+     ) -> torch.Tensor:
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         input_embeds = self.language_model.get_input_embeddings()(input_ids).clone()
+         B, N, C = input_embeds.shape
+
+         if pixel_values is not None:
+             pixel_values = pixel_values.type(self.vision_model.embeddings.patch_embedding.weight.dtype)
+             vit_embeds = self.extract_feature(pixel_values)
+             # image_flags = image_flags.squeeze(-1)
+             # vit_embeds = vit_embeds[image_flags == 1]
+             vit_batch_size = pixel_values.shape[0]
+
+             input_embeds = input_embeds.reshape(B * N, C)
+
+             if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
+                 print(f'dynamic ViT batch size: {vit_batch_size}, images per sample: {vit_batch_size / B}, dynamic token length: {N}')
+                 if statistics is not None:
+                     num_samples, num_padding_tokens, num_padding_images = statistics.tolist()
+                     self.num_samples += num_samples
+                     print(f'total_samples={self.num_samples}, {num_samples=}, {num_padding_tokens=}, {num_padding_images=}')
+
+             # Scatter the ViT patch embeddings into the <IMG_CONTEXT> token positions.
+             input_ids = input_ids.reshape(B * N)
+             selected = (input_ids == self.img_context_token_id)
+             try:
+                 input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
+                 ignore_flag = False
+             except Exception as e:
+                 # Shape-mismatch fallback: truncate the ViT embeddings to the number of
+                 # <IMG_CONTEXT> tokens actually present in the batch.
+                 vit_embeds = vit_embeds.reshape(-1, C)
+                 print(f'warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, '
+                       f'vit_embeds.shape={vit_embeds.shape}')
+                 n_token = selected.sum()
+                 input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]
+                 ignore_flag = True
+
+             input_embeds = input_embeds.reshape(B, N, C)
+
+         outputs = self.language_model.model(
+             inputs_embeds=input_embeds,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=True,
+             return_dict=return_dict,
+         )
+
+         last_hidden_states = outputs[0].type(self.custom_text_proj.weight.dtype)
+         proj = self.custom_text_proj(last_hidden_states)  # (batch_size, sequence_length, dim)
+
+         # L2-normalize each token embedding, then zero out padding positions.
+         proj = proj / proj.norm(dim=-1, keepdim=True)  # (batch_size, sequence_length, dim)
+         proj = proj * attention_mask.unsqueeze(-1)  # (batch_size, sequence_length, dim)
+         return proj
+
+     @property
+     def get_patch_size(self) -> int:
+         # InternVLChatModel stores its vision tower as `vision_model`; there is no
+         # `self.visual` attribute on this class (these helpers mirror the ColQwen2 code).
+         return self.vision_model.config.patch_size
+
+     @property
+     def spatial_merge_size(self) -> int:
+         return self.vision_model.config.spatial_merge_size
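
The forward pass above returns one L2-normalized 128-dim embedding per token rather than logits. A minimal usage sketch (not part of this upload; the checkpoint id is a placeholder, and loading assumes the repo maps its architecture to ColInternVL2 via trust_remote_code):

```python
import torch
from transformers import AutoModel

# Placeholder repo id; substitute the actual ColInternVL2 checkpoint.
model = AutoModel.from_pretrained(
    "<colinternvl2-checkpoint>",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).eval()

with torch.no_grad():
    # batch_doc / batch_query come from ColInternVL2Processor (next file).
    doc_embeddings = model(**batch_doc)      # (n_docs, seq_len, 128)
    query_embeddings = model(**batch_query)  # (n_queries, seq_len, 128)
```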
processing_colinternvl2.py ADDED
@@ -0,0 +1,211 @@
+ from typing import ClassVar, List, Optional, Tuple, Union
+
+ import torch
+ import torchvision.transforms as T
+ from PIL import Image
+ from torchvision.transforms.functional import InterpolationMode
+ from transformers import AutoTokenizer, BatchFeature, ProcessorMixin
+
+ from .conversation import get_conv_template
+ from .processing_utils import BaseVisualRetrieverProcessor
+
+
+ class ColInternVL2Processor(BaseVisualRetrieverProcessor, ProcessorMixin):
+     """
+     Processor for ColInternVL2.
+     """
+     attributes = ["tokenizer"]
+     image_processor_class = "InternVL2ImageProcessor"
+     tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
+
+     def __init__(self, tokenizer, **kwargs):
+         self.template = "Hermes-2"
+         self.num_image_token = 256
+         # self.max_num = 6
+         self.max_num = 4
+
+         if isinstance(tokenizer, str):
+             self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True, use_fast=False)
+         else:
+             self.tokenizer = tokenizer
+
+         self.tokenizer.padding_side = 'left'
+         self.IMAGENET_MEAN = (0.485, 0.456, 0.406)
+         self.IMAGENET_STD = (0.229, 0.224, 0.225)
+         self.IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
+         self.IMG_START_TOKEN = '<img>'
+         self.IMG_END_TOKEN = '</img>'
+         self.img_context_token_id = self.tokenizer.convert_tokens_to_ids(self.IMG_CONTEXT_TOKEN)
+         # InternVL's default Chinese system prompt ("You are InternVL, a multimodal large model
+         # developed jointly by Shanghai AI Laboratory and SenseTime; a helpful, harmless AI assistant."):
+         # self.system_message = '你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。'
+         # Vietnamese system prompt: "You are a Vietnamese multimodal AI model named Vintern,
+         # developed by Vietnamese people. You are a helpful and harmless AI assistant."
+         self.system_message = 'Bạn là một mô hình trí tuệ nhân tạo đa phương thức Tiếng Việt có tên gọi là Vintern, được phát triển bởi người Việt. Bạn là một trợ lý trí tuệ nhân tạo hữu ích và không gây hại.'
+         # Pass the resolved tokenizer object (not the raw argument, which may be a string path).
+         super().__init__(self.tokenizer)
+
+     # def from_pretrained(pretrained_model_name_or_path, template="Hermes-2", **kwargs):
+     #     return ColInternVL2Processor(pretrained_model_name_or_path, template=template, **kwargs)
+
+     def build_transform(self, input_size):
+         MEAN, STD = self.IMAGENET_MEAN, self.IMAGENET_STD
+         transform = T.Compose([
+             T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+             T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
+             T.ToTensor(),
+             T.Normalize(mean=MEAN, std=STD)
+         ])
+         return transform
+
+     def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size):
+         best_ratio_diff = float('inf')
+         best_ratio = (1, 1)
+         area = width * height
+         for ratio in target_ratios:
+             target_aspect_ratio = ratio[0] / ratio[1]
+             ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+             if ratio_diff < best_ratio_diff:
+                 best_ratio_diff = ratio_diff
+                 best_ratio = ratio
+             elif ratio_diff == best_ratio_diff:
+                 # Tie-break: prefer the larger tiling when the image is big enough to fill it.
+                 if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+                     best_ratio = ratio
+         return best_ratio
+
+     def dynamic_preprocess(self, image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
+         orig_width, orig_height = image.size
+         aspect_ratio = orig_width / orig_height
+
+         # enumerate candidate tilings (i columns x j rows) with min_num <= i*j <= max_num
+         target_ratios = set(
+             (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
+             i * j <= max_num and i * j >= min_num)
+         target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+         # find the closest aspect ratio to the target
+         target_aspect_ratio = self.find_closest_aspect_ratio(
+             aspect_ratio, target_ratios, orig_width, orig_height, image_size)
+
+         # calculate the target width and height
+         target_width = image_size * target_aspect_ratio[0]
+         target_height = image_size * target_aspect_ratio[1]
+         blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+         # resize the image, then split it into image_size x image_size tiles
+         resized_img = image.resize((target_width, target_height))
+         processed_images = []
+         for i in range(blocks):
+             box = (
+                 (i % (target_width // image_size)) * image_size,
+                 (i // (target_width // image_size)) * image_size,
+                 ((i % (target_width // image_size)) + 1) * image_size,
+                 ((i // (target_width // image_size)) + 1) * image_size
+             )
+             split_img = resized_img.crop(box)
+             processed_images.append(split_img)
+         assert len(processed_images) == blocks
+         if use_thumbnail and len(processed_images) != 1:
+             thumbnail_img = image.resize((image_size, image_size))
+             processed_images.append(thumbnail_img)
+         return processed_images
+
+     def load_image(self, image, input_size=448, max_num=12):
+         transform = self.build_transform(input_size=input_size)
+         images = self.dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
+         pixel_values = [transform(image) for image in images]
+         pixel_values = torch.stack(pixel_values)
+         return pixel_values
+
+     def process_images(
+         self,
+         images: List[Image.Image],
+     ) -> BatchFeature:
+         """
+         Process images for ColInternVL2.
+         """
+         pixel_values = [self.load_image(image, max_num=self.max_num) for image in images]
+
+         num_patches_list = [pixel_.size(0) for pixel_ in pixel_values]
+         image_flags = [torch.tensor([1] * pixel_.shape[0], dtype=torch.long) for pixel_ in pixel_values]
+
+         queries = []
+         for idx, num_patches in enumerate(num_patches_list):
+             question = "<image>\nDescribe the image."
+
+             template = get_conv_template(self.template)
+             template.system_message = self.system_message
+             template.append_message(template.roles[0], question)
+             template.append_message(template.roles[1], None)
+             query = template.get_prompt()
+             # expand the <image> placeholder into one <IMG_CONTEXT> token per visual token
+             image_tokens = self.IMG_START_TOKEN + self.IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + self.IMG_END_TOKEN
+             query = query.replace('<image>', image_tokens, 1)
+             queries.append(query)
+
+         model_inputs = self.tokenizer(queries, return_tensors='pt', padding=True)
+         input_ids = model_inputs['input_ids']  # .to(self.device)
+         attention_mask = model_inputs['attention_mask']  # .to(self.device)
+         pixel_values = torch.cat(pixel_values)
+
+         batch_doc = BatchFeature({
+             "pixel_values": pixel_values,
+             "input_ids": input_ids,
+             "attention_mask": attention_mask,
+             # "image_flags": image_flags
+         })
+         return batch_doc
+
+     def process_queries(
+         self,
+         queries: List[str],
+         max_length: int = 100,
+         suffix: Optional[str] = None,
+     ) -> BatchFeature:
+         """
+         Process queries for ColInternVL2.
+         """
+         texts_query: List[str] = []
+
+         for query in queries:
+             query = f"Query: {query}"
+             template = get_conv_template(self.template)
+             template.system_message = self.system_message
+             template.append_message(template.roles[0], query)
+             template.append_message(template.roles[1], None)
+             query = template.get_prompt()
+             texts_query.append(query)
+
+         # truncation=True is required for max_length to take effect
+         model_inputs = self.tokenizer(texts_query, return_tensors='pt', max_length=max_length,
+                                       padding="longest", truncation=True)
+         input_ids = model_inputs['input_ids']  # .to(self.device)
+         attention_mask = model_inputs['attention_mask']  # .to(self.device)
+
+         batch_query = BatchFeature({
+             "pixel_values": None,
+             "input_ids": input_ids,
+             "attention_mask": attention_mask,
+         })
+         return batch_query
+
+     def score(
+         self,
+         qs: List[torch.Tensor],
+         ps: List[torch.Tensor],
+         device: Optional[Union[str, torch.device]] = None,
+         **kwargs,
+     ) -> torch.Tensor:
+         """
+         Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings.
+         """
+         return self.score_multi_vector(qs, ps, device=device, **kwargs)
+
+     def get_n_patches(
+         self,
+         image_size: Tuple[int, int],
+         patch_size: int,
+     ) -> Tuple[int, int]:
+         raise NotImplementedError("This method is not implemented for ColInternVL2.")
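
A short sketch of how this processor is typically driven (not part of the upload; the checkpoint id and image file names are placeholders):

```python
from PIL import Image

# Passing a string path loads the tokenizer via AutoTokenizer.from_pretrained.
processor = ColInternVL2Processor("<colinternvl2-checkpoint>")

images = [Image.open("page_0.png"), Image.open("page_1.png")]
batch_doc = processor.process_images(images)   # tiles pages, expands <IMG_CONTEXT> tokens
batch_query = processor.process_queries(["What is the invoice total?"])
```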
processing_utils.py ADDED
@@ -0,0 +1,117 @@
+ from abc import ABC, abstractmethod
+ from typing import List, Optional, Tuple, Union
+
+ import torch
+ from PIL import Image
+ from transformers import BatchEncoding, BatchFeature
+
+ from .torch_utils import get_torch_device
+
+
+ class BaseVisualRetrieverProcessor(ABC):
+     """
+     Base class for visual retriever processors.
+     """
+
+     @abstractmethod
+     def process_images(
+         self,
+         images: List[Image.Image],
+     ) -> Union[BatchFeature, BatchEncoding]:
+         pass
+
+     @abstractmethod
+     def process_queries(
+         self,
+         queries: List[str],
+         max_length: int = 50,
+         suffix: Optional[str] = None,
+     ) -> Union[BatchFeature, BatchEncoding]:
+         pass
+
+     @abstractmethod
+     def score(
+         self,
+         qs: List[torch.Tensor],
+         ps: List[torch.Tensor],
+         device: Optional[Union[str, torch.device]] = None,
+         **kwargs,
+     ) -> torch.Tensor:
+         pass
+
+     @staticmethod
+     def score_single_vector(
+         qs: List[torch.Tensor],
+         ps: List[torch.Tensor],
+         device: Optional[Union[str, torch.device]] = None,
+     ) -> torch.Tensor:
+         """
+         Compute the dot product score for the given single-vector query and passage embeddings.
+         """
+         device = device or get_torch_device("auto")
+
+         if len(qs) == 0:
+             raise ValueError("No queries provided")
+         if len(ps) == 0:
+             raise ValueError("No passages provided")
+
+         qs_stacked = torch.stack(qs).to(device)
+         ps_stacked = torch.stack(ps).to(device)
+
+         scores = torch.einsum("bd,cd->bc", qs_stacked, ps_stacked)
+         assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}"
+
+         scores = scores.to(torch.float32)
+         return scores
+
+     @staticmethod
+     def score_multi_vector(
+         qs: List[torch.Tensor],
+         ps: List[torch.Tensor],
+         batch_size: int = 128,
+         device: Optional[Union[str, torch.device]] = None,
+     ) -> torch.Tensor:
+         """
+         Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings.
+         """
+         device = device or get_torch_device("auto")
+
+         if len(qs) == 0:
+             raise ValueError("No queries provided")
+         if len(ps) == 0:
+             raise ValueError("No passages provided")
+
+         scores_list: List[torch.Tensor] = []
+
+         for i in range(0, len(qs), batch_size):
+             scores_batch = []
+             qs_batch = torch.nn.utils.rnn.pad_sequence(qs[i : i + batch_size], batch_first=True, padding_value=0).to(
+                 device
+             )
+             for j in range(0, len(ps), batch_size):
+                 ps_batch = torch.nn.utils.rnn.pad_sequence(
+                     ps[j : j + batch_size], batch_first=True, padding_value=0
+                 ).to(device)
+                 # MaxSim: max over passage tokens (dim 3), then sum over query tokens (dim 2)
+                 scores_batch.append(torch.einsum("bnd,csd->bcns", qs_batch, ps_batch).max(dim=3)[0].sum(dim=2))
+             scores_batch = torch.cat(scores_batch, dim=1).cpu()
+             scores_list.append(scores_batch)
+
+         scores = torch.cat(scores_list, dim=0)
+         assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}"
+
+         scores = scores.to(torch.float32)
+         return scores
+
+     @abstractmethod
+     def get_n_patches(
+         self,
+         image_size: Tuple[int, int],
+         patch_size: int = 14,
+         *args,
+         **kwargs,
+     ) -> Tuple[int, int]:
+         """
+         Get the number of patches (n_patches_x, n_patches_y) that will be used to process an
+         image of size (height, width) with the given patch size.
+         """
+         pass
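
For intuition, the einsum in score_multi_vector implements ColBERT-style MaxSim: every query token is matched to its most similar passage token, and those maxima are summed over the query tokens. A tiny self-contained check:

```python
import torch

q = torch.randn(2, 3, 128)  # 2 queries, 3 tokens each, dim 128
p = torch.randn(4, 5, 128)  # 4 passages, 5 tokens each

sim = torch.einsum("bnd,csd->bcns", q, p)  # (2, 4, 3, 5) token-level similarities
scores = sim.max(dim=3)[0].sum(dim=2)      # max over passage tokens, sum over query tokens
assert scores.shape == (2, 4)              # one MaxSim score per (query, passage) pair
```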
torch_utils.py ADDED
@@ -0,0 +1,52 @@
+ import gc
+ import logging
+ from typing import List, TypeVar
+
+ import torch
+ from torch.utils.data import Dataset
+
+ logger = logging.getLogger(__name__)
+ T = TypeVar("T")
+
+
+ def get_torch_device(device: str = "auto") -> str:
+     """
+     Returns the device (string) to be used by PyTorch.
+
+     `device` arg defaults to "auto" which will use:
+     - "cuda:0" if available
+     - else "mps" if available
+     - else "cpu".
+     """
+     if device == "auto":
+         if torch.cuda.is_available():
+             device = "cuda:0"
+         elif torch.backends.mps.is_available():  # for Apple Silicon
+             device = "mps"
+         else:
+             device = "cpu"
+         logger.info(f"Using device: {device}")
+
+     return device
+
+
+ def tear_down_torch():
+     """
+     Teardown for PyTorch.
+     Clears GPU cache for both CUDA and MPS.
+     """
+     gc.collect()
+     torch.cuda.empty_cache()
+     torch.mps.empty_cache()
+
+
+ class ListDataset(Dataset[T]):
+     def __init__(self, elements: List[T]):
+         self.elements = elements
+
+     def __len__(self) -> int:
+         return len(self.elements)
+
+     def __getitem__(self, idx: int) -> T:
+         return self.elements[idx]
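
Tying the four files together, a hedged end-to-end sketch (reusing the model, processor, and batches from the sketches above; BatchFeature.to moves the tensor fields to the target device):

```python
device = get_torch_device("auto")  # "cuda:0", "mps", or "cpu"
model = model.to(device)

with torch.no_grad():
    doc_embeddings = model(**batch_doc.to(device))      # (n_docs, seq_len, 128)
    query_embeddings = model(**batch_query.to(device))  # (n_queries, seq_len, 128)

# ColBERT-style late-interaction scores: (n_queries, n_docs)
scores = processor.score(list(query_embeddings), list(doc_embeddings), device=device)

tear_down_torch()  # release cached GPU memory when done
```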