"""RAGLite typing.""" import io import pickle from collections.abc import Callable from typing import Any, Protocol import numpy as np from sqlalchemy.engine import Dialect from sqlalchemy.sql.operators import Operators from sqlalchemy.types import Float, LargeBinary, TypeDecorator, TypeEngine, UserDefinedType from raglite._config import RAGLiteConfig FloatMatrix = np.ndarray[tuple[int, int], np.dtype[np.floating[Any]]] FloatVector = np.ndarray[tuple[int], np.dtype[np.floating[Any]]] IntVector = np.ndarray[tuple[int], np.dtype[np.intp]] class SearchMethod(Protocol): def __call__( self, query: str, *, num_results: int = 3, config: RAGLiteConfig | None = None ) -> tuple[list[str], list[float]]: ... class NumpyArray(TypeDecorator[np.ndarray[Any, np.dtype[np.floating[Any]]]]): """A NumPy array column type for SQLAlchemy.""" impl = LargeBinary def process_bind_param( self, value: np.ndarray[Any, np.dtype[np.floating[Any]]] | None, dialect: Dialect ) -> bytes | None: """Convert a NumPy array to bytes.""" if value is None: return None buffer = io.BytesIO() np.save(buffer, value, allow_pickle=False, fix_imports=False) return buffer.getvalue() def process_result_value( self, value: bytes | None, dialect: Dialect ) -> np.ndarray[Any, np.dtype[np.floating[Any]]] | None: """Convert bytes to a NumPy array.""" if value is None: return None return np.load(io.BytesIO(value), allow_pickle=False, fix_imports=False) # type: ignore[no-any-return] class PickledObject(TypeDecorator[object]): """A pickled object column type for SQLAlchemy.""" impl = LargeBinary def process_bind_param(self, value: object | None, dialect: Dialect) -> bytes | None: """Convert a Python object to bytes.""" if value is None: return None return pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL, fix_imports=False) def process_result_value(self, value: bytes | None, dialect: Dialect) -> object | None: """Convert bytes to a Python object.""" if value is None: return None return pickle.loads(value, fix_imports=False) # type: ignore[no-any-return] # noqa: S301 class HalfVecComparatorMixin(UserDefinedType.Comparator[FloatVector]): """A mixin that provides comparison operators for halfvecs.""" def cosine_distance(self, other: FloatVector) -> Operators: """Compute the cosine distance.""" return self.op("<=>", return_type=Float)(other) def dot_distance(self, other: FloatVector) -> Operators: """Compute the dot product distance.""" return self.op("<#>", return_type=Float)(other) def euclidean_distance(self, other: FloatVector) -> Operators: """Compute the Euclidean distance.""" return self.op("<->", return_type=Float)(other) def l1_distance(self, other: FloatVector) -> Operators: """Compute the L1 distance.""" return self.op("<+>", return_type=Float)(other) def l2_distance(self, other: FloatVector) -> Operators: """Compute the L2 distance.""" return self.op("<->", return_type=Float)(other) class HalfVec(UserDefinedType[FloatVector]): """A PostgreSQL half-precision vector column type for SQLAlchemy.""" cache_ok = True # HalfVec is immutable. def __init__(self, dim: int | None = None) -> None: super().__init__() self.dim = dim def get_col_spec(self, **kwargs: Any) -> str: return f"halfvec({self.dim})" def bind_processor(self, dialect: Dialect) -> Callable[[FloatVector | None], str | None]: """Process NumPy ndarray to PostgreSQL halfvec format for bound parameters.""" def process(value: FloatVector | None) -> str | None: return f"[{','.join(str(x) for x in np.ravel(value))}]" if value is not None else None return process def result_processor( self, dialect: Dialect, coltype: Any ) -> Callable[[str | None], FloatVector | None]: """Process PostgreSQL halfvec format to NumPy ndarray.""" def process(value: str | None) -> FloatVector | None: if value is None: return None return np.fromstring(value.strip("[]"), sep=",", dtype=np.float16) return process class comparator_factory(HalfVecComparatorMixin): # noqa: N801 ... class Embedding(TypeDecorator[FloatVector]): """An embedding column type for SQLAlchemy.""" cache_ok = True # Embedding is immutable. impl = NumpyArray def __init__(self, dim: int = -1): super().__init__() self.dim = dim def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[FloatVector]: if dialect.name == "postgresql": return dialect.type_descriptor(HalfVec(self.dim)) return dialect.type_descriptor(NumpyArray()) class comparator_factory(HalfVecComparatorMixin): # noqa: N801 ...