# ColVintern-1B-v1 / processing_utils.py
# Uploaded by khang119966 in commit c39b2dc ("Upload 4 files", verified); original size 3.63 kB.
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Union
import torch
from PIL import Image
from transformers import BatchEncoding, BatchFeature
from .torch_utils import get_torch_device
class BaseVisualRetrieverProcessor(ABC):
    """
    Base class for visual retriever processors.

    Concrete subclasses turn images and text queries into model inputs and score
    query embeddings against passage embeddings. Two reusable scoring helpers are
    provided as static methods: a plain dot product for single-vector embeddings
    (``score_single_vector``) and a ColBERT-style MaxSim for multi-vector
    embeddings (``score_multi_vector``).
    """

    @abstractmethod
    def process_images(
        self,
        images: List[Image.Image],
    ) -> Union[BatchFeature, BatchEncoding]:
        """
        Convert a list of PIL images into model inputs.

        Args:
            images: Images to preprocess.

        Returns:
            A ``BatchFeature`` or ``BatchEncoding`` built by the concrete subclass.
        """
        pass

    @abstractmethod
    def process_queries(
        self,
        queries: List[str],
        max_length: int = 50,
        suffix: Optional[str] = None,
    ) -> Union[BatchFeature, BatchEncoding]:
        """
        Convert a list of text queries into model inputs.

        Args:
            queries: Query strings to preprocess.
            max_length: Length limit passed to the subclass implementation
                (presumably a token limit — confirm in the subclass).
            suffix: Optional extra text for the subclass to use (presumably
                appended to each query — TODO confirm against the subclass).

        Returns:
            A ``BatchFeature`` or ``BatchEncoding`` built by the concrete subclass.
        """
        pass

    @abstractmethod
    def score(
        self,
        qs: List[torch.Tensor],
        ps: List[torch.Tensor],
        device: Optional[Union[str, torch.device]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Score the query embeddings ``qs`` against the passage embeddings ``ps``.

        Subclasses can delegate to ``score_single_vector`` or
        ``score_multi_vector``. Both return a ``(len(qs), len(ps))`` float32 tensor.

        Args:
            qs: Query embeddings.
            ps: Passage embeddings.
            device: Device to compute on; ``None`` lets the implementation choose.
            **kwargs: Implementation-specific options.
        """
        pass

    @staticmethod
    def score_single_vector(
        qs: List[torch.Tensor],
        ps: List[torch.Tensor],
        device: Optional[Union[str, torch.device]] = None,
    ) -> torch.Tensor:
        """
        Compute the dot product score for the given single-vector query and passage embeddings.

        Args:
            qs: Query embeddings. Each must be a 1-D tensor of the same size ``d``.
            ps: Passage embeddings. Each must be a 1-D tensor of size ``d``.
            device: Device to compute on. When falsy, ``get_torch_device("auto")``
                is used.

        Returns:
            A float32 tensor of shape ``(len(qs), len(ps))`` that stays on ``device``.
            Entry ``[i, j]`` is the dot product of ``qs[i]`` and ``ps[j]``.

        Raises:
            ValueError: If ``qs`` or ``ps`` is empty.
        """
        device = device or get_torch_device("auto")

        if len(qs) == 0:
            raise ValueError("No queries provided")
        if len(ps) == 0:
            raise ValueError("No passages provided")

        qs_stacked = torch.stack(qs).to(device)  # (b, d)
        ps_stacked = torch.stack(ps).to(device)  # (c, d)
        scores = torch.einsum("bd,cd->bc", qs_stacked, ps_stacked)

        assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}"
        return scores.to(torch.float32)

    @staticmethod
    def score_multi_vector(
        qs: List[torch.Tensor],
        ps: List[torch.Tensor],
        batch_size: int = 128,
        device: Optional[Union[str, torch.device]] = None,
    ) -> torch.Tensor:
        """
        Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings.

        For each query/passage pair, every query token is matched with the passage
        token that has the largest dot product, and those maxima are summed over
        the query tokens.

        Queries and passages are processed in chunks of ``batch_size``. Within a
        chunk they are zero-padded to a common length. Padded passage positions are
        masked out before the max, so a pair's score does not depend on which other
        passages share its chunk. Zero-padded query positions always contribute
        exactly 0 to the sum.

        Args:
            qs: Query embeddings. Each is a ``(n_tokens, d)`` tensor; ``n_tokens``
                may differ between queries.
            ps: Passage embeddings. Each is a ``(n_tokens, d)`` tensor; ``n_tokens``
                may differ between passages. Floating-point dtypes are expected
                because padded positions are filled with ``-inf``.
            batch_size: Number of queries, and separately of passages, scored per chunk.
            device: Device to compute on. When falsy, ``get_torch_device("auto")``
                is used.

        Returns:
            A float32 CPU tensor of shape ``(len(qs), len(ps))``.

        Raises:
            ValueError: If ``qs`` or ``ps`` is empty.
        """
        device = device or get_torch_device("auto")

        if len(qs) == 0:
            raise ValueError("No queries provided")
        if len(ps) == 0:
            raise ValueError("No passages provided")

        scores_list: List[torch.Tensor] = []
        for i in range(0, len(qs), batch_size):
            # (b, n, d). A zero-padded query row scores 0 against every real passage
            # token, so it adds nothing to the final sum.
            qs_batch = torch.nn.utils.rnn.pad_sequence(
                qs[i : i + batch_size], batch_first=True, padding_value=0
            ).to(device)

            row_scores: List[torch.Tensor] = []
            for j in range(0, len(ps), batch_size):
                ps_chunk = ps[j : j + batch_size]
                # (c, s)
                ps_batch = torch.nn.utils.rnn.pad_sequence(
                    ps_chunk, batch_first=True, padding_value=0
                ).to(device)

                ps_lengths = torch.tensor([p.shape[0] for p in ps_chunk], device=device)
                # (c, s): True where the position holds a real (non-padded) passage token.
                ps_valid = (
                    torch.arange(ps_batch.shape[1], device=device)[None, :] < ps_lengths[:, None]
                )

                # (b, c, n, s) token-level similarities.
                sim = torch.einsum("bnd,csd->bcns", qs_batch, ps_batch)
                # Without this mask, padded passage positions compete in the max with
                # a similarity of 0 and override negative true maxima. The fill is
                # done in place to avoid a second copy of this large tensor.
                sim.masked_fill_(~ps_valid[None, :, None, :], float("-inf"))

                max_sim = sim.max(dim=3).values  # (b, c, n)
                # An empty passage has no valid token. Score it 0 rather than -inf.
                max_sim = max_sim.masked_fill((ps_lengths == 0)[None, :, None], 0.0)

                row_scores.append(max_sim.sum(dim=2))  # (b, c)

            scores_list.append(torch.cat(row_scores, dim=1).cpu())

        scores = torch.cat(scores_list, dim=0)

        assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}"
        return scores.to(torch.float32)

    @abstractmethod
    def get_n_patches(
        self,
        image_size: Tuple[int, int],
        patch_size: int = 14,
        *args,
        **kwargs,
    ) -> Tuple[int, int]:
        """
        Get the number of patches (n_patches_x, n_patches_y) that will be used to process an
        image of size (height, width) with the given patch size.

        Args:
            image_size: Image size as ``(height, width)``.
            patch_size: Patch edge length used by the subclass.
            *args: Implementation-specific positional options.
            **kwargs: Implementation-specific keyword options.

        Returns:
            ``(n_patches_x, n_patches_y)``.
        """
        pass