# filename: service.py
from __future__ import annotations

import asyncio
import logging
from dataclasses import dataclass
from enum import Enum
from hashlib import md5
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, List, NamedTuple, Optional, Union

import httpx
import numpy as np
import torch
from cachetools import LRUCache
from PIL import Image
from sentence_transformers import SentenceTransformer
from transformers import AutoModel, AutoProcessor

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class TextModelType(str, Enum):
    """
    Enumeration of supported text models.
    """

    MULTILINGUAL_E5_SMALL = "multilingual-e5-small"
    MULTILINGUAL_E5_BASE = "multilingual-e5-base"
    MULTILINGUAL_E5_LARGE = "multilingual-e5-large"
    SNOWFLAKE_ARCTIC_EMBED_L_V2 = "snowflake-arctic-embed-l-v2.0"
    PARAPHRASE_MULTILINGUAL_MINILM_L12_V2 = "paraphrase-multilingual-MiniLM-L12-v2"
    PARAPHRASE_MULTILINGUAL_MPNET_BASE_V2 = "paraphrase-multilingual-mpnet-base-v2"
    BGE_M3 = "bge-m3"
    GTE_MULTILINGUAL_BASE = "gte-multilingual-base"


class ImageModelType(str, Enum):
    """
    Enumeration of supported image models.
    """

    SIGLIP_BASE_PATCH16_256_MULTILINGUAL = "siglip-base-patch16-256-multilingual"


class MaxModelLength(int, Enum):
    """
    Enumeration of maximum token lengths for supported text models.
    """

    MULTILINGUAL_E5_SMALL = 512
    MULTILINGUAL_E5_BASE = 512
    MULTILINGUAL_E5_LARGE = 512
    SNOWFLAKE_ARCTIC_EMBED_L_V2 = 8192
    PARAPHRASE_MULTILINGUAL_MINILM_L12_V2 = 128
    PARAPHRASE_MULTILINGUAL_MPNET_BASE_V2 = 128
    BGE_M3 = 8192
    GTE_MULTILINGUAL_BASE = 8192


class ModelInfo(NamedTuple):
    """
    Container mapping a model type to its model identifier and optional ONNX file.
    """

    model_id: str
    onnx_file: Optional[str] = None


@dataclass
class ModelConfig:
    """
    Configuration for text and image models.
    """

    text_model_type: TextModelType = TextModelType.MULTILINGUAL_E5_SMALL
    image_model_type: ImageModelType = (
        ImageModelType.SIGLIP_BASE_PATCH16_256_MULTILINGUAL
    )
    # ln(100): exp(logit_scale) ~= 100, the temperature used to sharpen
    # cross-modal similarity scores before the softmax in rank().
    logit_scale: float = 4.60517
    @property
    def text_model_info(self) -> ModelInfo:
        """
        Return model information for the configured text model.
        """
        text_configs = {
            TextModelType.MULTILINGUAL_E5_SMALL: ModelInfo(
                model_id="Xenova/multilingual-e5-small",
                onnx_file="onnx/model_quantized.onnx",
            ),
            TextModelType.MULTILINGUAL_E5_BASE: ModelInfo(
                model_id="Xenova/multilingual-e5-base",
                onnx_file="onnx/model_quantized.onnx",
            ),
            TextModelType.MULTILINGUAL_E5_LARGE: ModelInfo(
                model_id="Xenova/multilingual-e5-large",
                onnx_file="onnx/model_quantized.onnx",
            ),
            TextModelType.SNOWFLAKE_ARCTIC_EMBED_L_V2: ModelInfo(
                model_id="Snowflake/snowflake-arctic-embed-l-v2.0",
                onnx_file="onnx/model_quantized.onnx",
            ),
            TextModelType.PARAPHRASE_MULTILINGUAL_MINILM_L12_V2: ModelInfo(
                model_id="Xenova/paraphrase-multilingual-MiniLM-L12-v2",
                onnx_file="onnx/model_quantized.onnx",
            ),
            TextModelType.PARAPHRASE_MULTILINGUAL_MPNET_BASE_V2: ModelInfo(
                model_id="Xenova/paraphrase-multilingual-mpnet-base-v2",
                onnx_file="onnx/model_quantized.onnx",
            ),
            TextModelType.BGE_M3: ModelInfo(
                model_id="Xenova/bge-m3",
                onnx_file="onnx/model_quantized.onnx",
            ),
            TextModelType.GTE_MULTILINGUAL_BASE: ModelInfo(
                model_id="onnx-community/gte-multilingual-base",
                onnx_file="onnx/model_quantized.onnx",
            ),
        }
        return text_configs[self.text_model_type]
    @property
    def image_model_info(self) -> ModelInfo:
        """
        Return model information for the configured image model.
        """
        image_configs = {
            ImageModelType.SIGLIP_BASE_PATCH16_256_MULTILINGUAL: ModelInfo(
                model_id="google/siglip-base-patch16-256-multilingual"
            ),
        }
        return image_configs[self.image_model_type]
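
# Example (hypothetical, with values from the mappings above):
#   ModelConfig(text_model_type=TextModelType.BGE_M3).text_model_info
#   -> ModelInfo(model_id="Xenova/bge-m3", onnx_file="onnx/model_quantized.onnx")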


class ModelKind(str, Enum):
    """
    Indicates the type of model: text or image.
    """

    TEXT = "text"
    IMAGE = "image"


def detect_model_kind(model_id: str) -> ModelKind:
    """
    Detect whether the model identifier corresponds to a text or image model.

    Raises:
        ValueError: If the model identifier is unrecognized.
    """
    if model_id in [m.value for m in TextModelType]:
        return ModelKind.TEXT
    elif model_id in [m.value for m in ImageModelType]:
        return ModelKind.IMAGE
    else:
        raise ValueError(
            f"Unrecognized model ID: {model_id}.\n"
            f"Valid text: {[m.value for m in TextModelType]}\n"
            f"Valid image: {[m.value for m in ImageModelType]}"
        )
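
# Example (values taken from the enums above):
#   detect_model_kind("bge-m3") -> ModelKind.TEXT
#   detect_model_kind("siglip-base-patch16-256-multilingual") -> ModelKind.IMAGE
#   detect_model_kind("unknown-model") raises ValueError.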


class EmbeddingsService:
    """
    Service for generating text/image embeddings and performing similarity ranking.
    Asynchronous methods are used to maximize throughput and avoid blocking the event loop.
    """

    def __init__(self, config: Optional[ModelConfig] = None):
        """
        Initialize the service by setting up model caches, device configuration,
        and the asynchronous HTTP client.
        """
        self.lru_cache = LRUCache(maxsize=10_000)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.config = config or ModelConfig()

        # Dictionaries to hold preloaded models.
        self.text_models: Dict[TextModelType, SentenceTransformer] = {}
        self.image_models: Dict[ImageModelType, AutoModel] = {}
        self.image_processors: Dict[ImageModelType, AutoProcessor] = {}

        # Create a persistent asynchronous HTTP client.
        self.async_http_client = httpx.AsyncClient(timeout=10)

        # Preload all models.
        self._load_all_models()
    def _load_all_models(self) -> None:
        """
        Pre-load all text and image models to minimize latency at request time.
        """
        try:
            # Preload text models.
            for t_model_type in TextModelType:
                info = ModelConfig(text_model_type=t_model_type).text_model_info
                logger.info("Loading text model: %s", info.model_id)
                if info.onnx_file:
                    logger.info("Using ONNX file: %s", info.onnx_file)
                    self.text_models[t_model_type] = SentenceTransformer(
                        info.model_id,
                        device=self.device,
                        backend="onnx",
                        model_kwargs={
                            "provider": "CPUExecutionProvider",
                            "file_name": info.onnx_file,
                        },
                        trust_remote_code=True,
                    )
                else:
                    self.text_models[t_model_type] = SentenceTransformer(
                        info.model_id,
                        device=self.device,
                        trust_remote_code=True,
                    )

                # Cap the sequence length at the model's supported maximum.
                max_length = MaxModelLength[t_model_type.name].value
                self.text_models[t_model_type].max_seq_length = max_length
                logger.info(
                    "Set max_seq_length=%d for text model: %s",
                    max_length,
                    info.model_id,
                )

            # Preload image models.
            for i_model_type in ImageModelType:
                model_id = ModelConfig(
                    image_model_type=i_model_type
                ).image_model_info.model_id
                logger.info("Loading image model: %s", model_id)
                model = AutoModel.from_pretrained(model_id).to(self.device)
                model.eval()  # Set the model to evaluation mode.
                processor = AutoProcessor.from_pretrained(model_id)
                self.image_models[i_model_type] = model
                self.image_processors[i_model_type] = processor

            logger.info("All models loaded successfully.")
        except Exception as e:
            msg = f"Error loading models: {str(e)}"
            logger.error(msg)
            raise RuntimeError(msg) from e
    @staticmethod
    def _validate_text_list(input_text: Union[str, List[str]]) -> List[str]:
        """
        Validate and convert text input into a non-empty list of strings.

        Raises:
            ValueError: If the input is invalid.
        """
        if isinstance(input_text, str):
            if not input_text.strip():
                raise ValueError("Text input cannot be empty.")
            return [input_text]
        if not isinstance(input_text, list) or not all(
            isinstance(x, str) for x in input_text
        ):
            raise ValueError("Text input must be a string or a list of strings.")
        if len(input_text) == 0:
            raise ValueError("Text input list cannot be empty.")
        return input_text

    @staticmethod
    def _validate_image_list(input_images: Union[str, List[str]]) -> List[str]:
        """
        Validate and convert image input into a non-empty list of image paths/URLs.

        Raises:
            ValueError: If the input is invalid.
        """
        if isinstance(input_images, str):
            if not input_images.strip():
                raise ValueError("Image input cannot be empty.")
            return [input_images]
        if not isinstance(input_images, list) or not all(
            isinstance(x, str) for x in input_images
        ):
            raise ValueError("Image input must be a string or a list of strings.")
        if len(input_images) == 0:
            raise ValueError("Image input list cannot be empty.")
        return input_images
    def _truncate_text(self, text: str, model: SentenceTransformer) -> str:
        """
        Truncate the input text to the maximum allowed tokens for the given model.

        Args:
            text: The input text.
            model: The SentenceTransformer model used for tokenization.

        Returns:
            The truncated text if the token length exceeds the model's maximum,
            otherwise the original text.
        """
        try:
            # Attempt to get the tokenizer from the first module of the SentenceTransformer.
            module = model._first_module()
            if not hasattr(module, "tokenizer"):
                return text
            tokenizer = module.tokenizer
            # Tokenize without truncation to measure the true token length.
            encoded = tokenizer(text, add_special_tokens=True, truncation=False)
            max_length = model.max_seq_length
            if len(encoded["input_ids"]) > max_length:
                truncated_ids = encoded["input_ids"][:max_length]
                return tokenizer.decode(truncated_ids, skip_special_tokens=True)
        except Exception as e:
            logger.warning("Error during text truncation: %s", str(e))
        # Fall through for untruncated inputs and tokenizer errors alike.
        return text
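
    # For example, a 600-token input to multilingual-e5-small (max_seq_length=512)
    # is clipped to its first 512 token IDs and decoded back to plain text, while
    # shorter inputs pass through unchanged. (Illustrative figures only.)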
    async def _fetch_image(self, path_or_url: str) -> Image.Image:
        """
        Asynchronously fetch an image from a URL or load it from a local path.

        Args:
            path_or_url: The URL or file path of the image.

        Returns:
            A PIL Image in RGB mode.

        Raises:
            ValueError: If image fetching or processing fails.
        """
        try:
            if path_or_url.startswith("http"):
                # Asynchronously fetch the image bytes.
                response = await self.async_http_client.get(path_or_url)
                response.raise_for_status()
                # Offload the blocking PIL decode to a thread.
                img = await asyncio.to_thread(Image.open, BytesIO(response.content))
            else:
                # Offload file I/O to a thread.
                img = await asyncio.to_thread(Image.open, Path(path_or_url))
            return img.convert("RGB")
        except Exception as e:
            raise ValueError(f"Error fetching image '{path_or_url}': {str(e)}") from e
    async def _process_image(self, path_or_url: str) -> Dict[str, torch.Tensor]:
        """
        Asynchronously load and preprocess a single image.

        Args:
            path_or_url: The image URL or local path.

        Returns:
            A dictionary of processed tensors ready for model input.

        Raises:
            ValueError: If image processing fails.
        """
        img = await self._fetch_image(path_or_url)
        processor = self.image_processors[self.config.image_model_type]
        # Note: the processor may perform CPU-intensive work; if needed, offload to a thread.
        processed_data = processor(images=img, return_tensors="pt").to(self.device)
        return processed_data
    def _generate_text_embeddings(
        self, model_id: TextModelType, texts: List[str]
    ) -> np.ndarray:
        """
        Generate text embeddings using the SentenceTransformer model.
        Single-text requests are cached using an LRU cache.

        Args:
            model_id: The text model type.
            texts: A list of input texts.

        Returns:
            A NumPy array of text embeddings.

        Raises:
            RuntimeError: If text embedding generation fails.
        """
        try:
            if len(texts) == 1:
                single_text = texts[0]
                # A truncated MD5 digest keeps cache keys compact.
                key = md5(f"{model_id}:{single_text}".encode("utf-8")).hexdigest()[:8]
                if key in self.lru_cache:
                    return self.lru_cache[key]
                model = self.text_models[model_id]
                emb = model.encode([single_text])
                self.lru_cache[key] = emb
                return emb

            model = self.text_models[model_id]
            return model.encode(texts)
        except Exception as e:
            raise RuntimeError(
                f"Error generating text embeddings with model '{model_id}': {e}"
            ) from e
    async def _async_generate_image_embeddings(
        self, model_id: ImageModelType, images: List[str]
    ) -> np.ndarray:
        """
        Asynchronously generate image embeddings.

        This method concurrently processes multiple images and offloads
        the blocking model inference to a separate thread.

        Args:
            model_id: The image model type.
            images: A list of image URLs or file paths.

        Returns:
            A NumPy array of image embeddings.

        Raises:
            RuntimeError: If image embedding generation fails.
        """
        try:
            # Concurrently process all images.
            processed_tensors = await asyncio.gather(
                *[self._process_image(img_path) for img_path in images]
            )
            # Assume all processed outputs have the same keys.
            keys = processed_tensors[0].keys()
            combined = {
                k: torch.cat([pt[k] for pt in processed_tensors], dim=0) for k in keys
            }

            def infer():
                with torch.no_grad():
                    embeddings = self.image_models[model_id].get_image_features(
                        **combined
                    )
                return embeddings.cpu().numpy()

            return await asyncio.to_thread(infer)
        except Exception as e:
            raise RuntimeError(
                f"Error generating image embeddings with model '{model_id}': {e}"
            ) from e
    async def generate_embeddings(
        self, model: str, inputs: Union[str, List[str]]
    ) -> np.ndarray:
        """
        Asynchronously generate embeddings for text or image inputs based on model type.

        Args:
            model: The model identifier.
            inputs: The text or image input(s).

        Returns:
            A NumPy array of embeddings.
        """
        modality = detect_model_kind(model)
        if modality == ModelKind.TEXT:
            text_model_enum = TextModelType(model)
            text_list = self._validate_text_list(inputs)
            model_instance = self.text_models[text_model_enum]
            # Truncate each text if it exceeds the maximum allowed token length.
            truncated_texts = [
                self._truncate_text(text, model_instance) for text in text_list
            ]
            return await asyncio.to_thread(
                self._generate_text_embeddings, text_model_enum, truncated_texts
            )
        else:  # ModelKind.IMAGE (detect_model_kind raises on anything else).
            image_model_enum = ImageModelType(model)
            image_list = self._validate_image_list(inputs)
            return await self._async_generate_image_embeddings(
                image_model_enum, image_list
            )
    async def rank(
        self,
        model: str,
        queries: Union[str, List[str]],
        candidates: Union[str, List[str]],
    ) -> Dict[str, Any]:
        """
        Asynchronously rank candidate texts/images against the provided queries.
        Embeddings for queries and candidates are generated concurrently.

        Args:
            model: The model identifier.
            queries: The query input(s).
            candidates: The candidate input(s).

        Returns:
            A dictionary containing probabilities, cosine similarities, and usage statistics.
        """
        modality = detect_model_kind(model)

        # Concurrently generate embeddings.
        query_task = asyncio.create_task(self.generate_embeddings(model, queries))
        candidate_task = asyncio.create_task(
            self.generate_embeddings(model, candidates)
        )
        query_embeds, candidate_embeds = await asyncio.gather(
            query_task, candidate_task
        )

        # Compute cosine similarity and scale by the configured temperature.
        sim_matrix = self.cosine_similarity(query_embeds, candidate_embeds)
        scaled = np.exp(self.config.logit_scale) * sim_matrix
        probs = self.softmax(scaled)

        if modality == ModelKind.TEXT:
            query_tokens = self.estimate_tokens(queries)
            candidate_tokens = self.estimate_tokens(candidates)
            total_tokens = query_tokens + candidate_tokens
        else:
            total_tokens = 0

        usage = {
            "prompt_tokens": total_tokens,
            "total_tokens": total_tokens,
        }
        return {
            "probabilities": probs.tolist(),
            "cosine_similarities": sim_matrix.tolist(),
            "usage": usage,
        }
    def estimate_tokens(self, input_data: Union[str, List[str]]) -> int:
        """
        Estimate the token count for the given text input using the tokenizer
        of the configured default text model.

        Note: SentenceTransformer.tokenize pads batches to the longest sequence,
        so this may over-count for mixed-length inputs; treat it as an estimate.

        Args:
            input_data: The text input(s).

        Returns:
            The estimated total number of tokens.
        """
        texts = self._validate_text_list(input_data)
        model = self.text_models[self.config.text_model_type]
        tokenized = model.tokenize(texts)
        return sum(len(ids) for ids in tokenized["input_ids"])
    @staticmethod
    def softmax(scores: np.ndarray) -> np.ndarray:
        """
        Compute the softmax over the last dimension of the input array.

        Args:
            scores: A NumPy array of scores.

        Returns:
            A NumPy array of softmax probabilities.
        """
        # Subtract the row-wise max for numerical stability.
        exps = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
        return exps / np.sum(exps, axis=-1, keepdims=True)

    @staticmethod
    def cosine_similarity(a: np.ndarray, b: np.ndarray) -> np.ndarray:
        """
        Compute the pairwise cosine similarity between all rows of arrays a and b.

        Args:
            a: A NumPy array of shape (N, D).
            b: A NumPy array of shape (M, D).

        Returns:
            An (N x M) matrix of cosine similarities.
        """
        # The small epsilon guards against division by zero for zero vectors.
        a_norm = a / (np.linalg.norm(a, axis=1, keepdims=True) + 1e-9)
        b_norm = b / (np.linalg.norm(b, axis=1, keepdims=True) + 1e-9)
        return np.dot(a_norm, b_norm.T)
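
    # A quick sanity check with hypothetical values (both helpers are static):
    #   a = np.array([[1.0, 0.0]]); b = np.array([[1.0, 0.0], [0.0, 1.0]])
    #   EmbeddingsService.cosine_similarity(a, b) -> [[~1.0, 0.0]]
    #   EmbeddingsService.softmax(np.array([[4.6, 0.0]])) -> [[~0.99, ~0.01]]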
    async def close(self) -> None:
        """
        Close the asynchronous HTTP client.
        """
        await self.async_http_client.aclose()
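

if __name__ == "__main__":
    # Minimal usage sketch. The model name comes from TextModelType above; the
    # sentences are hypothetical inputs. Instantiating EmbeddingsService preloads
    # every model, so the first run downloads weights and may take a while.
    async def _demo() -> None:
        service = EmbeddingsService()
        try:
            embeds = await service.generate_embeddings(
                "multilingual-e5-small", ["hello world", "bonjour le monde"]
            )
            print("Embedding shape:", embeds.shape)

            result = await service.rank(
                "multilingual-e5-small",
                queries="What is the capital of France?",
                candidates=[
                    "Paris is the capital of France.",
                    "Berlin is in Germany.",
                ],
            )
            print("Probabilities:", result["probabilities"])
        finally:
            await service.close()

    asyncio.run(_demo())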