|
from abc import ABC, abstractmethod |
|
from typing import List, Optional, Tuple, Union |
|
|
|
import torch |
|
from PIL import Image |
|
from transformers import BatchEncoding, BatchFeature |
|
|
|
from .torch_utils import get_torch_device |
|
|
|
|
|
class BaseVisualRetrieverProcessor(ABC):
    """
    Base class for visual retriever processors.

    Subclasses must implement `process_images`, `process_queries`, `score` and
    `get_n_patches`. Two ready-made scoring helpers are provided as static methods:
    `score_single_vector` (dot product) and `score_multi_vector` (MaxSim).
    """

    @abstractmethod
    def process_images(
        self,
        images: List[Image.Image],
    ) -> Union[BatchFeature, BatchEncoding]:
        """
        Convert a list of PIL images into a `BatchFeature` or `BatchEncoding`.

        The exact preprocessing is defined by the subclass.
        """
        pass

    @abstractmethod
    def process_queries(
        self,
        queries: List[str],
        max_length: int = 50,
        suffix: Optional[str] = None,
    ) -> Union[BatchFeature, BatchEncoding]:
        """
        Convert a list of text queries into a `BatchFeature` or `BatchEncoding`.

        How `max_length` and `suffix` are applied is defined by the subclass.
        """
        pass

    @abstractmethod
    def score(
        self,
        qs: List[torch.Tensor],
        ps: List[torch.Tensor],
        device: Optional[Union[str, torch.device]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Score query embeddings `qs` against passage embeddings `ps`.

        The scoring rule is defined by the subclass.
        """
        pass

    @staticmethod
    def score_single_vector(
        qs: List[torch.Tensor],
        ps: List[torch.Tensor],
        device: Optional[Union[str, torch.device]] = None,
    ) -> torch.Tensor:
        """
        Compute the dot product score for the given single-vector query and passage embeddings.

        Args:
            qs: Query embeddings, one 1-D tensor of size `dim` per query.
            ps: Passage embeddings, one 1-D tensor of size `dim` per passage.
            device: Device on which to compute the scores. When falsy, it is resolved
                with `get_torch_device("auto")`.

        Returns:
            A float32 tensor of shape `(len(qs), len(ps))`. The result is left on
            `device` (unlike `score_multi_vector`, it is not moved to the CPU).

        Raises:
            ValueError: If `qs` or `ps` is empty.
        """
        device = device or get_torch_device("auto")

        if len(qs) == 0:
            raise ValueError("No queries provided")
        if len(ps) == 0:
            raise ValueError("No passages provided")

        qs_stacked = torch.stack(qs).to(device)
        ps_stacked = torch.stack(ps).to(device)

        # (b, d) x (c, d) -> (b, c): one dot product per (query, passage) pair.
        scores = torch.einsum("bd,cd->bc", qs_stacked, ps_stacked)
        assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}"

        scores = scores.to(torch.float32)
        return scores

    @staticmethod
    def score_multi_vector(
        qs: List[torch.Tensor],
        ps: List[torch.Tensor],
        batch_size: int = 128,
        device: Optional[Union[str, torch.device]] = None,
    ) -> torch.Tensor:
        """
        Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings.

        For every (query, passage) pair, each query token is matched with its most similar
        passage token (dot product) and these maxima are summed over the query tokens.

        Queries and passages are zero-padded to a common length inside each batch. Padded
        passage positions are excluded from the maximum so that a score does not depend on
        which other passages share the batch. Padded query tokens contribute 0 to the sum.

        Args:
            qs: Query embeddings, one `(n_query_tokens, dim)` tensor per query.
            ps: Passage embeddings, one `(n_passage_tokens, dim)` tensor per passage.
            batch_size: Number of queries, and of passages, scored together at once.
            device: Device on which to compute the scores. When falsy, it is resolved
                with `get_torch_device("auto")`.

        Returns:
            A float32 CPU tensor of shape `(len(qs), len(ps))`.

        Raises:
            ValueError: If `qs` or `ps` is empty.
        """
        device = device or get_torch_device("auto")

        if len(qs) == 0:
            raise ValueError("No queries provided")
        if len(ps) == 0:
            raise ValueError("No passages provided")

        scores_list: List[torch.Tensor] = []

        for q_start in range(0, len(qs), batch_size):
            qs_batch = torch.nn.utils.rnn.pad_sequence(
                qs[q_start : q_start + batch_size], batch_first=True, padding_value=0
            ).to(device)

            row_scores: List[torch.Tensor] = []
            for p_start in range(0, len(ps), batch_size):
                p_chunk = ps[p_start : p_start + batch_size]
                p_lengths = [p.shape[0] for p in p_chunk]
                ps_batch = torch.nn.utils.rnn.pad_sequence(p_chunk, batch_first=True, padding_value=0).to(device)

                # (b, n, d) x (c, s, d) -> (b, c, n, s): token-level similarities.
                sim = torch.einsum("bnd,csd->bcns", qs_batch, ps_batch)

                # Zero padding yields a similarity of exactly 0, which would win the max
                # whenever all real similarities are negative. Mask it out instead.
                # Only needed when the passages of this chunk have different lengths.
                if min(p_lengths) != max(p_lengths):
                    lengths = torch.tensor(p_lengths, device=sim.device)
                    positions = torch.arange(ps_batch.shape[1], device=sim.device)
                    # Zero-length passages keep their padding, so they still score 0
                    # rather than -inf.
                    is_pad = (positions[None, :] >= lengths[:, None]) & (lengths[:, None] > 0)
                    sim.masked_fill_(is_pad[None, :, None, :], float("-inf"))

                row_scores.append(sim.max(dim=3)[0].sum(dim=2))

            # Move each finished row block to the CPU to bound device memory usage.
            scores_list.append(torch.cat(row_scores, dim=1).cpu())

        scores = torch.cat(scores_list, dim=0)
        assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}"

        scores = scores.to(torch.float32)
        return scores

    @abstractmethod
    def get_n_patches(
        self,
        image_size: Tuple[int, int],
        patch_size: int = 14,
        *args,
        **kwargs,
    ) -> Tuple[int, int]:
        """
        Get the number of patches (n_patches_x, n_patches_y) that will be used to process an
        image of size (height, width) with the given patch size.
        """
        pass
|
|