khang119966 committed on
Commit 7e559c4
1 Parent(s): 0d06e91

Update processing_colinternvl2.py

Files changed (1)
  1. processing_colinternvl2.py +58 -0
processing_colinternvl2.py CHANGED
@@ -16,6 +16,25 @@ from transformers import AutoModel, AutoTokenizer
 from .conversation import get_conv_template
 from transformers import BatchFeature, ProcessorMixin
 
+def get_torch_device(device: str = "auto") -> str:
+    """
+    Returns the device (string) to be used by PyTorch.
+
+    `device` arg defaults to "auto" which will use:
+    - "cuda:0" if available
+    - else "mps" if available
+    - else "cpu".
+    """
+
+    if device == "auto":
+        if torch.cuda.is_available():
+            device = "cuda:0"
+        elif torch.backends.mps.is_available():  # for Apple Silicon
+            device = "mps"
+        else:
+            device = "cpu"
+    return device
+
 class ColInternVL2Processor(BaseVisualRetrieverProcessor, ProcessorMixin):
     """
     Processor for ColInternVL2.
@@ -205,3 +224,42 @@ class ColInternVL2Processor(BaseVisualRetrieverProcessor, ProcessorMixin):
         patch_size: int,
     ) -> Tuple[int, int]:
         raise NotImplementedError("This method is not implemented for ColInternVL2.")
+
+    def score_multi_vector(
+        self,
+        qs: List[torch.Tensor],
+        ps: List[torch.Tensor],
+        batch_size: int = 128,
+        device: Optional[Union[str, torch.device]] = None,
+    ) -> torch.Tensor:
+        """
+        Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings.
+        """
+        device = device or get_torch_device("auto")
+
+        if len(qs) == 0:
+            raise ValueError("No queries provided")
+        if len(ps) == 0:
+            raise ValueError("No passages provided")
+
+        scores_list: List[torch.Tensor] = []
+
+        for i in range(0, len(qs), batch_size):
+            scores_batch = []
+            qs_batch = torch.nn.utils.rnn.pad_sequence(qs[i : i + batch_size], batch_first=True, padding_value=0).float().to(
+                device
+            )
+            for j in range(0, len(ps), batch_size):
+                ps_batch = torch.nn.utils.rnn.pad_sequence(
+                    ps[j : j + batch_size], batch_first=True, padding_value=0
+                ).float().to(device)
+                scores_batch.append(torch.einsum("bnd,csd->bcns", qs_batch, ps_batch).max(dim=3)[0].sum(dim=2))
+            scores_batch = torch.cat(scores_batch, dim=1).cpu()
+            scores_list.append(scores_batch)
+
+        scores = torch.cat(scores_list, dim=0)
+        assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}"
+
+        scores = scores.to(torch.float32)
+        return scores
+
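Usage note (not part of the commit): the added `score_multi_vector` method implements ColBERT-style MaxSim scoring over per-token embeddings. The sketch below reproduces just that core computation on dummy tensors, so the expected shapes are concrete; the embedding size of 128 and the token counts are illustrative assumptions, not values fixed by this diff.

```python
import torch

# Dummy multi-vector embeddings: one (num_tokens, dim) tensor per query / passage.
# dim = 128 is assumed here for illustration only.
qs = [torch.randn(12, 128), torch.randn(20, 128)]                            # 2 queries
ps = [torch.randn(300, 128), torch.randn(512, 128), torch.randn(256, 128)]   # 3 passages

# Pad each side to a common token length, as score_multi_vector does per batch.
# Zero-padded tokens contribute 0 to the dot products, so they do not change the score.
qs_batch = torch.nn.utils.rnn.pad_sequence(qs, batch_first=True, padding_value=0).float()
ps_batch = torch.nn.utils.rnn.pad_sequence(ps, batch_first=True, padding_value=0).float()

# MaxSim: for every (query, passage) pair, take the best-matching passage token
# for each query token, then sum over query tokens.
scores = torch.einsum("bnd,csd->bcns", qs_batch, ps_batch).max(dim=3)[0].sum(dim=2)
print(scores.shape)  # torch.Size([2, 3]) -- one score per (query, passage) pair
```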