File size: 5,052 Bytes
54f5afe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
"""RAGLite typing."""

import io
import pickle
from collections.abc import Callable
from typing import Any, Protocol

import numpy as np
from sqlalchemy.engine import Dialect
from sqlalchemy.sql.operators import Operators
from sqlalchemy.types import Float, LargeBinary, TypeDecorator, TypeEngine, UserDefinedType

from raglite._config import RAGLiteConfig

# Static-typing aliases for NumPy arrays; the first type parameter is the array shape.
# Two-dimensional array of floating-point values (any floating dtype).
FloatMatrix = np.ndarray[tuple[int, int], np.dtype[np.floating[Any]]]
# One-dimensional array of floating-point values (any floating dtype).
FloatVector = np.ndarray[tuple[int], np.dtype[np.floating[Any]]]
# One-dimensional array of `np.intp` integers.
IntVector = np.ndarray[tuple[int], np.dtype[np.intp]]


class SearchMethod(Protocol):
    """Structural type for search callables.

    A conforming callable accepts the query string, the keyword-only options
    `num_results` (default 3) and `config` (default None), and returns a tuple
    holding a list of strings and a list of floats.
    """

    def __call__(
        self,
        query: str,
        *,
        num_results: int = 3,
        config: RAGLiteConfig | None = None,
    ) -> tuple[list[str], list[float]]:
        """Run the search for `query`.

        NOTE(review): the returned lists are presumably result identifiers and
        their matching scores -- confirm against the implementations.
        """
        ...


class NumpyArray(TypeDecorator[np.ndarray[Any, np.dtype[np.floating[Any]]]]):
    """A NumPy array column type for SQLAlchemy.

    Arrays are written to a ``LargeBinary`` column in NumPy's ``.npy`` format
    (via ``np.save``) and read back with ``np.load``. Pickling is disabled in
    both directions, so object arrays are rejected. ``None`` is passed through
    unchanged in both directions.
    """

    # The type has no constructor arguments, so instances are interchangeable and
    # safe to use in SQLAlchemy's statement cache key. Without this flag SQLAlchemy
    # emits a warning and skips caching for every statement that uses the type.
    cache_ok = True

    impl = LargeBinary

    def process_bind_param(
        self, value: np.ndarray[Any, np.dtype[np.floating[Any]]] | None, dialect: Dialect
    ) -> bytes | None:
        """Convert a NumPy array to bytes.

        Returns the ``.npy`` serialisation of `value`, or None when `value` is None.
        """
        if value is None:
            return None
        buffer = io.BytesIO()
        # `fix_imports` is intentionally not passed: it only ever affected pickled
        # payloads (disabled here) and newer NumPy releases deprecate the keyword.
        np.save(buffer, value, allow_pickle=False)
        return buffer.getvalue()

    def process_result_value(
        self, value: bytes | None, dialect: Dialect
    ) -> np.ndarray[Any, np.dtype[np.floating[Any]]] | None:
        """Convert bytes to a NumPy array.

        Returns the array decoded from the ``.npy`` payload, or None when `value` is None.
        """
        if value is None:
            return None
        return np.load(io.BytesIO(value), allow_pickle=False, fix_imports=False)  # type: ignore[no-any-return]


class PickledObject(TypeDecorator[object]):
    """A pickled object column type for SQLAlchemy.

    Values are serialised with ``pickle`` into a ``LargeBinary`` column. ``None``
    is passed through unchanged in both directions.
    """

    # The type has no constructor arguments, so instances are interchangeable and
    # safe to use in SQLAlchemy's statement cache key. Without this flag SQLAlchemy
    # emits a warning and skips caching for every statement that uses the type.
    cache_ok = True

    impl = LargeBinary

    def process_bind_param(self, value: object | None, dialect: Dialect) -> bytes | None:
        """Convert a Python object to bytes.

        Returns the pickle of `value` using the highest available protocol, or
        None when `value` is None.
        """
        if value is None:
            return None
        return pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL, fix_imports=False)

    def process_result_value(self, value: bytes | None, dialect: Dialect) -> object | None:
        """Convert bytes to a Python object.

        Returns the unpickled object, or None when `value` is None.
        """
        if value is None:
            return None
        # NOTE(security): unpickling can execute arbitrary code. This is only safe
        # while every writer to the underlying database column is trusted.
        return pickle.loads(value, fix_imports=False)  # type: ignore[no-any-return]  # noqa: S301


class HalfVecComparatorMixin(UserDefinedType.Comparator[FloatVector]):
    """Comparator mixin exposing vector distance operators as methods.

    Every method builds a SQL expression that applies a custom binary operator
    between this column and `other`, with the result typed as ``Float``.
    """

    def cosine_distance(self, other: FloatVector) -> Operators:
        """Build the cosine distance expression (``<=>`` operator)."""
        cosine_op = self.op("<=>", return_type=Float)
        return cosine_op(other)

    def dot_distance(self, other: FloatVector) -> Operators:
        """Build the dot product distance expression (``<#>`` operator)."""
        dot_op = self.op("<#>", return_type=Float)
        return dot_op(other)

    def euclidean_distance(self, other: FloatVector) -> Operators:
        """Build the Euclidean distance expression (``<->`` operator)."""
        euclidean_op = self.op("<->", return_type=Float)
        return euclidean_op(other)

    def l1_distance(self, other: FloatVector) -> Operators:
        """Build the L1 distance expression (``<+>`` operator)."""
        l1_op = self.op("<+>", return_type=Float)
        return l1_op(other)

    def l2_distance(self, other: FloatVector) -> Operators:
        """Build the L2 distance expression (``<->``, same operator as Euclidean)."""
        l2_op = self.op("<->", return_type=Float)
        return l2_op(other)


class HalfVec(UserDefinedType[FloatVector]):
    """A PostgreSQL half-precision vector column type for SQLAlchemy.

    Values are sent to the database as the text literal ``[x1,x2,...]`` and
    parsed back from that text form into a ``float16`` NumPy array.
    """

    cache_ok = True  # HalfVec is immutable.

    def __init__(self, dim: int | None = None) -> None:
        """Create the type; `dim` is the vector dimension, or None to leave it unspecified."""
        super().__init__()
        self.dim = dim

    def get_col_spec(self, **kwargs: Any) -> str:
        """Return the DDL type name: ``halfvec(dim)``, or bare ``halfvec`` when `dim` is None."""
        # Interpolating a missing dimension would render the invalid DDL `halfvec(None)`.
        if self.dim is None:
            return "halfvec"
        return f"halfvec({self.dim})"

    def bind_processor(self, dialect: Dialect) -> Callable[[FloatVector | None], str | None]:
        """Process NumPy ndarray to PostgreSQL halfvec format for bound parameters."""

        def process(value: FloatVector | None) -> str | None:
            if value is None:
                return None
            # Flatten first so that any array shape is rendered as a single flat vector literal.
            return f"[{','.join(str(x) for x in np.ravel(value))}]"

        return process

    def result_processor(
        self, dialect: Dialect, coltype: Any
    ) -> Callable[[str | None], FloatVector | None]:
        """Process PostgreSQL halfvec format to NumPy ndarray."""

        def process(value: str | None) -> FloatVector | None:
            if value is None:
                return None
            # Drop the surrounding brackets and parse the comma-separated numbers.
            return np.fromstring(value.strip("[]"), sep=",", dtype=np.float16)

        return process

    class comparator_factory(HalfVecComparatorMixin):  # noqa: N801
        ...


class Embedding(TypeDecorator[FloatVector]):
    """An embedding column type for SQLAlchemy.

    The concrete storage type depends on the dialect: ``HalfVec`` on PostgreSQL
    and ``NumpyArray`` on every other dialect.
    """

    cache_ok = True  # Embedding is immutable.

    impl = NumpyArray

    def __init__(self, dim: int = -1):
        """Create the type; `dim` is forwarded to ``HalfVec`` on PostgreSQL."""
        super().__init__()
        self.dim = dim

    def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[FloatVector]:
        """Return the dialect-specific implementation type for this column."""
        storage_type: TypeEngine[Any] = (
            HalfVec(self.dim) if dialect.name == "postgresql" else NumpyArray()
        )
        return dialect.type_descriptor(storage_type)

    class comparator_factory(HalfVecComparatorMixin):  # noqa: N801
        ...