Spaces:
Running
Running
File size: 5,052 Bytes
54f5afe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
"""RAGLite typing."""
import io
import pickle
from collections.abc import Callable
from typing import Any, Protocol
import numpy as np
from sqlalchemy.engine import Dialect
from sqlalchemy.sql.operators import Operators
from sqlalchemy.types import Float, LargeBinary, TypeDecorator, TypeEngine, UserDefinedType
from raglite._config import RAGLiteConfig
# Dtype shared by the floating-point aliases below (any floating precision is accepted).
_FloatDType = np.dtype[np.floating[Any]]
FloatMatrix = np.ndarray[tuple[int, int], _FloatDType]  # 2-D array of floats.
FloatVector = np.ndarray[tuple[int], _FloatDType]  # 1-D array of floats.
IntVector = np.ndarray[tuple[int], np.dtype[np.intp]]  # 1-D array of np.intp integers.
class SearchMethod(Protocol):
    """Structural type for search callables.

    Any callable accepting a query string plus keyword-only ``num_results`` and ``config``
    arguments, and returning a pair of lists (strings and floats), satisfies this protocol.
    """

    def __call__(
        self,
        query: str,
        *,
        num_results: int = 3,
        config: RAGLiteConfig | None = None,
    ) -> tuple[list[str], list[float]]:
        """Run a search for ``query``.

        Presumably the two returned lists are parallel (result identifiers and their
        scores) — confirm against the concrete implementations.
        """
        ...
class NumpyArray(TypeDecorator[np.ndarray[Any, np.dtype[np.floating[Any]]]]):
    """A NumPy array column type for SQLAlchemy.

    Arrays are serialized with ``np.save`` (the ``.npy`` format) into a ``LargeBinary``
    column and restored with ``np.load``, so dtype and shape survive the round trip.
    Pickling is disabled in both directions (``allow_pickle=False``), so object-dtype
    arrays cannot be stored.
    """

    impl = LargeBinary
    # This type holds no per-instance state, so SQLAlchemy may safely cache compiled
    # statements that use it. Leaving cache_ok unset makes SQLAlchemy emit a warning and
    # skip statement caching for every query touching this type.
    cache_ok = True

    def process_bind_param(
        self, value: np.ndarray[Any, np.dtype[np.floating[Any]]] | None, dialect: Dialect
    ) -> bytes | None:
        """Convert a NumPy array to ``.npy`` bytes; ``None`` is passed through unchanged."""
        if value is None:
            return None
        buffer = io.BytesIO()
        np.save(buffer, value, allow_pickle=False, fix_imports=False)
        return buffer.getvalue()

    def process_result_value(
        self, value: bytes | None, dialect: Dialect
    ) -> np.ndarray[Any, np.dtype[np.floating[Any]]] | None:
        """Convert ``.npy`` bytes back to a NumPy array; ``None`` is passed through unchanged."""
        if value is None:
            return None
        return np.load(io.BytesIO(value), allow_pickle=False, fix_imports=False)  # type: ignore[no-any-return]
class PickledObject(TypeDecorator[object]):
    """A pickled object column type for SQLAlchemy.

    Values are serialized with :mod:`pickle` (highest protocol) into a ``LargeBinary``
    column and deserialized with ``pickle.loads`` on read.

    Warning: unpickling can execute arbitrary code. Only store data in this column that was
    written by trusted code, and never read it from a database untrusted parties can write to.
    """

    impl = LargeBinary
    # This type holds no per-instance state, so SQLAlchemy may safely cache compiled
    # statements that use it. Leaving cache_ok unset makes SQLAlchemy emit a warning and
    # skip statement caching for every query touching this type.
    cache_ok = True

    def process_bind_param(self, value: object | None, dialect: Dialect) -> bytes | None:
        """Convert a Python object to pickle bytes; ``None`` is passed through unchanged."""
        if value is None:
            return None
        return pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL, fix_imports=False)

    def process_result_value(self, value: bytes | None, dialect: Dialect) -> object | None:
        """Convert pickle bytes back to a Python object; ``None`` is passed through unchanged."""
        if value is None:
            return None
        # SECURITY: pickle.loads runs arbitrary code embedded in the payload. This is only
        # acceptable because the column is assumed to hold data written by this application.
        return pickle.loads(value, fix_imports=False)  # type: ignore[no-any-return]  # noqa: S301
class HalfVecComparatorMixin(UserDefinedType.Comparator[FloatVector]):
    """Comparator mixin that exposes pgvector distance operators on halfvec expressions.

    Every method renders ``<this expression> <operator> <other>`` as a ``Float``-typed SQL
    expression. ``euclidean_distance`` and ``l2_distance`` are aliases (both use ``<->``).
    """

    def _distance(self, operator: str, other: FloatVector) -> Operators:
        """Apply the binary SQL ``operator`` to this expression and ``other``."""
        apply_operator = self.op(operator, return_type=Float)
        return apply_operator(other)

    def cosine_distance(self, other: FloatVector) -> Operators:
        """Return the cosine distance to ``other`` (pgvector ``<=>``)."""
        return self._distance("<=>", other)

    def dot_distance(self, other: FloatVector) -> Operators:
        """Return the dot-product distance to ``other`` (pgvector ``<#>``, negative inner product)."""
        return self._distance("<#>", other)

    def euclidean_distance(self, other: FloatVector) -> Operators:
        """Return the Euclidean distance to ``other`` (pgvector ``<->``)."""
        return self._distance("<->", other)

    def l1_distance(self, other: FloatVector) -> Operators:
        """Return the L1 (taxicab) distance to ``other`` (pgvector ``<+>``)."""
        return self._distance("<+>", other)

    def l2_distance(self, other: FloatVector) -> Operators:
        """Return the L2 distance to ``other``; identical to ``euclidean_distance``."""
        return self._distance("<->", other)
class HalfVec(UserDefinedType[FloatVector]):
    """A PostgreSQL half-precision vector column type for SQLAlchemy.

    Values are bound as pgvector text literals (``"[x1,x2,...]"``) and parsed back into
    ``float16`` NumPy arrays.
    """

    cache_ok = True  # HalfVec is immutable.

    def __init__(self, dim: int | None = None) -> None:
        super().__init__()
        # Number of dimensions of the column, or None for an unconstrained halfvec.
        self.dim = dim

    def get_col_spec(self, **kwargs: Any) -> str:
        """Return the DDL type, e.g. ``halfvec(384)``, or bare ``halfvec`` when no dim is set."""
        # Without this guard a dimensionless column would render as the invalid DDL
        # ``halfvec(None)``; pgvector accepts a bare ``halfvec`` for unconstrained vectors.
        if self.dim is None:
            return "halfvec"
        return f"halfvec({self.dim})"

    def bind_processor(self, dialect: Dialect) -> Callable[[FloatVector | None], str | None]:
        """Process NumPy ndarray to PostgreSQL halfvec format for bound parameters."""

        def process(value: FloatVector | None) -> str | None:
            # Flatten to 1-D and render each element with str(); None is passed through.
            return f"[{','.join(str(x) for x in np.ravel(value))}]" if value is not None else None

        return process

    def result_processor(
        self, dialect: Dialect, coltype: Any
    ) -> Callable[[str | None], FloatVector | None]:
        """Process PostgreSQL halfvec format to NumPy ndarray."""

        def process(value: str | None) -> FloatVector | None:
            if value is None:
                return None
            # Drop the surrounding brackets and parse the comma-separated values as float16.
            return np.fromstring(value.strip("[]"), sep=",", dtype=np.float16)

        return process

    class comparator_factory(HalfVecComparatorMixin):  # noqa: N801
        """Give HalfVec columns the pgvector distance operators."""
class Embedding(TypeDecorator[FloatVector]):
    """An embedding column type for SQLAlchemy.

    Stored as a pgvector ``halfvec`` column on PostgreSQL and as a binary ``NumpyArray``
    column on every other dialect.
    """

    impl = NumpyArray
    cache_ok = True  # Embedding is immutable.

    def __init__(self, dim: int = -1):
        super().__init__()
        # Dimension handed to HalfVec on PostgreSQL. The -1 default presumably means
        # "not yet known" — TODO confirm how callers set it before DDL is emitted.
        self.dim = dim

    def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[FloatVector]:
        """Pick the storage type for ``dialect``: halfvec on PostgreSQL, NumpyArray elsewhere."""
        if dialect.name != "postgresql":
            return dialect.type_descriptor(NumpyArray())
        return dialect.type_descriptor(HalfVec(self.dim))

    class comparator_factory(HalfVecComparatorMixin):  # noqa: N801
        """Give Embedding columns the halfvec distance operators."""
|