from typing import Dict, List, Optional, Union

import numpy as np

from ..embedding_provider import EmbeddingProvider


class SentenceTransformerEmbedding(EmbeddingProvider):
    """Embedding provider backed by a ``sentence_transformers`` model."""

    def __init__(
        self,
        model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
        device: Optional[str] = None,
        batch_size: int = 32,
        normalize_embeddings: bool = True
    ) -> None:
        """Initialize sentence transformer embedding provider

        Args:
            model_name (str, optional): Name of the sentence transformer model.
                Defaults to "sentence-transformers/all-MiniLM-L6-v2".
            device (str, optional): Device to load the model on (for example
                "cpu" or "cuda"). When None, "cuda" is used if
                ``torch.cuda.is_available()`` and "cpu" otherwise.
            batch_size (int, optional): Batch size passed to ``encode`` when
                embedding documents. Defaults to 32.
            normalize_embeddings (bool, optional): Value forwarded to
                ``encode`` as ``normalize_embeddings``. Defaults to True.
        """
        # Imported lazily so that importing this module does not require
        # sentence_transformers to be installed.
        from sentence_transformers import SentenceTransformer

        if device is None:
            # torch is only needed for auto-detection, so import it on demand.
            import torch
            device = "cuda" if torch.cuda.is_available() else "cpu"

        # Always record the resolved device: previously ``self.device`` was
        # only set when ``device`` was None, so get_model_info() raised
        # AttributeError whenever a device was passed explicitly.
        self.device = device
        # Load the model on the same device that is reported by
        # get_model_info(), so the two can never disagree.
        self.model = SentenceTransformer(model_name, device=self.device)
        self.model_name = model_name
        self.batch_size = batch_size
        self.normalize_embeddings = normalize_embeddings

    def embed_documents(self, documents: List[str]) -> np.ndarray:
        """Embed a list of documents

        Args:
            documents (List[str]): List of documents to embed

        Returns:
            np.ndarray: Result of ``self.model.encode`` for the documents
        """
        return self.model.encode(
            documents,
            batch_size=self.batch_size,
            normalize_embeddings=self.normalize_embeddings
        )

    def embed_query(self, query: str) -> np.ndarray:
        """Embed a single query

        Args:
            query (str): Query to embed

        Returns:
            np.ndarray: Embedding vector
        """
        return self.model.encode(
            query,
            normalize_embeddings=self.normalize_embeddings
        )

    def get_model_info(self) -> Dict[str, Union[str, int]]:
        """
        Retrieve information about the current embedding model

        Returns:
            Dict: Model information (model name, device, batch size,
                normalization flag and embedding dimension)
        """
        return {
            "model_name": self.model_name,
            "device": self.device,
            "batch_size": self.batch_size,
            "normalize_embeddings": self.normalize_embeddings,
            "embedding_dim": self.model.get_sentence_embedding_dimension()
        }

    def list_available_models(self) -> List[str]:
        """
        List some popular Sentence Transformer models

        Returns:
            List[str]: Available model names
        """
        popular_models = [
            "sentence-transformers/all-MiniLM-L6-v2",  # Small and fast
            "sentence-transformers/all-mpnet-base-v2",  # High performance
            "sentence-transformers/all-distilroberta-v1",  # Lightweight
            "sentence-transformers/multi-qa-MiniLM-L6-cos-v1",  # Question Answering
            "sentence-transformers/multi-qa-mpnet-base-cos-v1",  # Question Answering (larger model)
            "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"  # Multilingual
        ]
        return popular_models