ahmed-masry commited on
Commit
9f9c2cc
1 Parent(s): e8ad0b4

Create processing_utils.py

Browse files
Files changed (1) hide show
  1. processing_utils.py +121 -0
processing_utils.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import List, Optional, Union
3
+
4
+ import torch
5
+ from PIL import Image
6
+ from transformers import BatchEncoding, BatchFeature
7
+
8
+ def get_torch_device(device: str = "auto") -> str:
9
+ """
10
+ Returns the device (string) to be used by PyTorch.
11
+
12
+ `device` arg defaults to "auto" which will use:
13
+ - "cuda:0" if available
14
+ - else "mps" if available
15
+ - else "cpu".
16
+ """
17
+
18
+ if device == "auto":
19
+ if torch.cuda.is_available():
20
+ device = "cuda:0"
21
+ elif torch.backends.mps.is_available(): # for Apple Silicon
22
+ device = "mps"
23
+ else:
24
+ device = "cpu"
25
+ logger.info(f"Using device: {device}")
26
+
27
+ return device
28
+
29
+ class BaseVisualRetrieverProcessor(ABC):
30
+ """
31
+ Base class for visual retriever processors.
32
+ """
33
+
34
+ @abstractmethod
35
+ def process_images(
36
+ self,
37
+ images: List[Image.Image],
38
+ ) -> Union[BatchFeature, BatchEncoding]:
39
+ pass
40
+
41
+ @abstractmethod
42
+ def process_queries(
43
+ self,
44
+ queries: List[str],
45
+ max_length: int = 50,
46
+ suffix: Optional[str] = None,
47
+ ) -> Union[BatchFeature, BatchEncoding]:
48
+ pass
49
+
50
+ @abstractmethod
51
+ def score(
52
+ self,
53
+ qs: List[torch.Tensor],
54
+ ps: List[torch.Tensor],
55
+ device: Optional[Union[str, torch.device]] = None,
56
+ **kwargs,
57
+ ) -> torch.Tensor:
58
+ pass
59
+
60
+ @staticmethod
61
+ def score_single_vector(
62
+ qs: List[torch.Tensor],
63
+ ps: List[torch.Tensor],
64
+ device: Optional[Union[str, torch.device]] = None,
65
+ ) -> torch.Tensor:
66
+ """
67
+ Compute the dot product score for the given single-vector query and passage embeddings.
68
+ """
69
+ device = device or get_torch_device("auto")
70
+
71
+ if len(qs) == 0:
72
+ raise ValueError("No queries provided")
73
+ if len(ps) == 0:
74
+ raise ValueError("No passages provided")
75
+
76
+ qs_stacked = torch.stack(qs).to(device)
77
+ ps_stacked = torch.stack(ps).to(device)
78
+
79
+ scores = torch.einsum("bd,cd->bc", qs_stacked, ps_stacked)
80
+ assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}"
81
+
82
+ scores = scores.to(torch.float32)
83
+ return scores
84
+
85
+ @staticmethod
86
+ def score_multi_vector(
87
+ qs: List[torch.Tensor],
88
+ ps: List[torch.Tensor],
89
+ batch_size: int = 128,
90
+ device: Optional[Union[str, torch.device]] = None,
91
+ ) -> torch.Tensor:
92
+ """
93
+ Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings.
94
+ """
95
+ device = device or get_torch_device("auto")
96
+
97
+ if len(qs) == 0:
98
+ raise ValueError("No queries provided")
99
+ if len(ps) == 0:
100
+ raise ValueError("No passages provided")
101
+
102
+ scores_list: List[torch.Tensor] = []
103
+
104
+ for i in range(0, len(qs), batch_size):
105
+ scores_batch = []
106
+ qs_batch = torch.nn.utils.rnn.pad_sequence(qs[i : i + batch_size], batch_first=True, padding_value=0).to(
107
+ device
108
+ )
109
+ for j in range(0, len(ps), batch_size):
110
+ ps_batch = torch.nn.utils.rnn.pad_sequence(
111
+ ps[j : j + batch_size], batch_first=True, padding_value=0
112
+ ).to(device)
113
+ scores_batch.append(torch.einsum("bnd,csd->bcns", qs_batch, ps_batch).max(dim=3)[0].sum(dim=2))
114
+ scores_batch = torch.cat(scores_batch, dim=1).cpu()
115
+ scores_list.append(scores_batch)
116
+
117
+ scores = torch.cat(scores_list, dim=0)
118
+ assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}"
119
+
120
+ scores = scores.to(torch.float32)
121
+ return scores