"""
LangChain MongoDB Caches

Functions "_loads_generations" and "_dumps_generations"
are duplicated in this utility from modules:
    - "libs/community/langchain_community/cache.py"
"""

import json
import logging
import time
from importlib.metadata import version
from typing import Any, Callable, Dict, Optional, Union

from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
from langchain_core.embeddings import Embeddings
from langchain_core.load.dump import dumps
from langchain_core.load.load import loads
from langchain_core.outputs import Generation
from pymongo import MongoClient
from pymongo.collection import Collection
from pymongo.database import Database
from pymongo.driver_info import DriverInfo

from langchain_mongodb.vectorstores import MongoDBAtlasVectorSearch

logger = logging.getLogger(__name__)


def _generate_mongo_client(connection_string: str) -> MongoClient:
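    # Tag the connection handshake with DriverInfo so the server can
    # attribute traffic to the langchain-mongodb integration.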
    return MongoClient(
        connection_string,
        driver=DriverInfo(name="Langchain", version=version("langchain-mongodb")),
    )


def _dumps_generations(generations: RETURN_VAL_TYPE) -> str:
    """
    Serialization for generic RETURN_VAL_TYPE, i.e. sequence of `Generation`

    Args:
        generations (RETURN_VAL_TYPE): A list of language model generations.

    Returns:
        str: a single string representing a list of generations.

    This function (and its counterpart `_loads_generations`) relies on
    the dumps/loads pair with Reviver, so it can handle all
    subclasses of `Generation`.

    Each item in the list is `dumps`ed to a string, and the whole list
    of strings is then JSON-dumped.
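
    Example (an illustrative round trip; assumes the stock, serializable
    `Generation` class):

        >>> blob = _dumps_generations([Generation(text="hello")])
        >>> _loads_generations(blob)[0].text
        'hello'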
    """
    return json.dumps([dumps(_item) for _item in generations])


def _loads_generations(generations_str: str) -> Union[RETURN_VAL_TYPE, None]:
    """
    Deserialization of a string into a generic RETURN_VAL_TYPE
    (i.e. a sequence of `Generation`).

    See `_dumps_generations`, the inverse of this function.

    Args:
        generations_str (str): A string representing a list of generations.

    Compatible with the legacy cache-blob format.
    Does not raise exceptions for malformed entries; it logs a warning
    and returns None instead, so the caller should treat such entries
    as cache misses.

    Returns:
        RETURN_VAL_TYPE: A list of generations, or None if the blob
        could not be parsed.
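
    Example (an illustrative legacy blob, i.e. a JSON list of plain
    `Generation` kwargs dicts):

        >>> _loads_generations('[{"text": "hello"}]')[0].text
        'hello'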
    """
    try:
        generations = [loads(_item_str) for _item_str in json.loads(generations_str)]
        return generations
    except (json.JSONDecodeError, TypeError):
        # deferring the (soft) handling to after the legacy-format attempt
        pass

    try:
        gen_dicts = json.loads(generations_str)
        # not relying on `_load_generations_from_json` (which could disappear):
        generations = [Generation(**generation_dict) for generation_dict in gen_dicts]
        logger.warning(
            f"Legacy 'Generation' cached blob encountered: '{generations_str}'"
        )
        return generations
    except (json.JSONDecodeError, TypeError):
        logger.warning(
            f"Malformed/unparsable cached blob encountered: '{generations_str}'"
        )
        return None


def _wait_until(
    predicate: Callable, success_description: Any, timeout: float = 10.0
) -> Any:
    """Wait up to `timeout` seconds (10 by default) for predicate to be true.

    E.g.:

        _wait_until(lambda: client.primary == ('a', 1),
                    'connect to the primary')

    If the predicate isn't true after `timeout` seconds, we raise
    TimeoutError("Didn't ever connect to the primary").

    Returns the predicate's first true value.
    """
    start = time.time()
    interval = min(float(timeout) / 100, 0.1)
    while True:
        retval = predicate()
        if retval:
            return retval

        if time.time() - start > timeout:
            raise TimeoutError("Didn't ever %s" % success_description)

        time.sleep(interval)


class MongoDBCache(BaseCache):
    """MongoDB Atlas cache

    A cache that uses MongoDB Atlas as a backend
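
    Example (a minimal wiring sketch; the URI and the database and
    collection names are placeholders):

        from langchain_core.globals import set_llm_cache

        set_llm_cache(MongoDBCache(
            connection_string="mongodb+srv://<user>:<password>@<cluster>/",
            database_name="langchain",
            collection_name="llm_cache",
        ))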
    """

    PROMPT = "prompt"
    LLM = "llm"
    RETURN_VAL = "return_val"

    def __init__(
        self,
        connection_string: str,
        collection_name: str = "default",
        database_name: str = "default",
        **kwargs: Any,
    ) -> None:
        """
        Initialize Atlas Cache. Creates the collection on instantiation.

        Args:
            connection_string (str): Connection URI to the MongoDB Atlas
                cluster.
            collection_name (str): Name of the collection the cache lives in.
                Defaults to "default".
            database_name (str): Name of the database the cache lives in.
                Defaults to "default".
        """
        self.client = _generate_mongo_client(connection_string)
        self.__database_name = database_name
        self.__collection_name = collection_name

        if self.__collection_name not in self.database.list_collection_names():
            self.database.create_collection(self.__collection_name)
            # Create an index on key and llm_string
            self.collection.create_index([self.PROMPT, self.LLM])

    @property
    def database(self) -> Database:
        """Returns the database used to store cache values."""
        return self.client[self.__database_name]

    @property
    def collection(self) -> Collection:
        """Returns the collection used to store cache values."""
        return self.database[self.__collection_name]

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string."""
        return_doc = (
            self.collection.find_one(self._generate_keys(prompt, llm_string)) or {}
        )
        return_val = return_doc.get(self.RETURN_VAL)
        return _loads_generations(return_val) if return_val else None  # type: ignore

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache based on prompt and llm_string."""
        self.collection.update_one(
            {**self._generate_keys(prompt, llm_string)},
            {"$set": {self.RETURN_VAL: _dumps_generations(return_val)}},
            upsert=True,
        )

    def _generate_keys(self, prompt: str, llm_string: str) -> Dict[str, str]:
        """Create keyed fields for caching layer"""
        return {self.PROMPT: prompt, self.LLM: llm_string}

    def clear(self, **kwargs: Any) -> None:
        """Clear cache that can take additional keyword arguments.
        Any additional arguments will propagate as filtration criteria for
        what gets deleted.

        E.g.
            # Delete only entries whose "llm" field is "fake-model"
            self.clear(llm="fake-model")
        """
        self.collection.delete_many({**kwargs})


class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):
    """MongoDB Atlas Semantic cache.

    A Cache backed by a MongoDB Atlas server with vector-store support
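
    Example (a minimal wiring sketch; the URI and embedding model are
    placeholders, and the collection and its Atlas Vector Search index
    are assumed to exist already):

        from langchain_core.globals import set_llm_cache
        from langchain_openai import OpenAIEmbeddings

        set_llm_cache(MongoDBAtlasSemanticCache(
            connection_string="mongodb+srv://<user>:<password>@<cluster>/",
            embedding=OpenAIEmbeddings(),
            database_name="langchain",
            collection_name="semantic_cache",
            index_name="default",
        ))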
    """

    LLM = "llm_string"
    RETURN_VAL = "return_val"

    def __init__(
        self,
        connection_string: str,
        embedding: Embeddings,
        collection_name: str = "default",
        database_name: str = "default",
        index_name: str = "default",
        wait_until_ready: bool = False,
        score_threshold: Optional[float] = None,
        **kwargs: Any,
    ):
        """
        Initialize Atlas VectorSearch Cache.
        Assumes the collection already exists before instantiation.

        Args:
            connection_string (str): MongoDB URI to connect to MongoDB Atlas cluster.
            embedding (Embeddings): Text embedding model to use.
            collection_name (str): MongoDB Collection to add the texts to.
                Defaults to "default".
            database_name (str): MongoDB Database where to store texts.
                Defaults to "default".
            index_name (str): Name of the Atlas Vector Search index.
                Defaults to "default".
            wait_until_ready (bool): Block until MongoDB Atlas finishes indexing
                the stored text. Hard timeout of 10 seconds. Defaults to False.
            score_threshold (Optional[float]): Minimum similarity score required
                for a lookup to count as a cache hit. Defaults to None
                (no threshold is applied).
        """
        client = _generate_mongo_client(connection_string)
        self.collection = client[database_name][collection_name]
        self.score_threshold = score_threshold
        self._wait_until_ready = wait_until_ready
        super().__init__(
            collection=self.collection,
            embedding=embedding,
            index_name=index_name,
            **kwargs,  # type: ignore
        )

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string."""
        # Apply the score filter only when a threshold was configured;
        # `is not None` keeps a threshold of 0.0 from being ignored.
        post_filter_pipeline = (
            [{"$match": {"score": {"$gte": self.score_threshold}}}]
            if self.score_threshold is not None
            else None
        )

        search_response = self.similarity_search_with_score(
            prompt,
            1,
            pre_filter={self.LLM: {"$eq": llm_string}},
            post_filter_pipeline=post_filter_pipeline,
        )
        if search_response:
            return_val = search_response[0][0].metadata.get(self.RETURN_VAL)
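            # Fall back to the raw stored value if the blob cannot be
            # parsed into a list of Generations.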
            response = _loads_generations(return_val) or return_val  # type: ignore
            return response
        return None

    def update(
        self,
        prompt: str,
        llm_string: str,
        return_val: RETURN_VAL_TYPE,
        wait_until_ready: Optional[bool] = None,
    ) -> None:
        """Update cache based on prompt and llm_string."""
        self.add_texts(
            [prompt],
            [
                {
                    self.LLM: llm_string,
                    self.RETURN_VAL: _dumps_generations(return_val),
                }
            ],
        )
        wait = self._wait_until_ready if wait_until_ready is None else wait_until_ready

        def is_indexed() -> bool:
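            # The write is considered indexed once a lookup round-trips
            # the exact value that was written.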
            return self.lookup(prompt, llm_string) == return_val

        if wait:
            _wait_until(is_indexed, return_val)

    def clear(self, **kwargs: Any) -> None:
        """Clear cache that can take additional keyword arguments.
        Any additional arguments will propagate as filtration criteria for
        what gets deleted. It will delete any locally cached content regardless

        E.g.
            # Delete only entries that have llm_string as "fake-model"
            self.clear(llm_string="fake-model")
        """
        self.collection.delete_many({**kwargs})