"""
LangChain MongoDB Caches
Functions "_loads_generations" and "_dumps_generations"
are duplicated in this utility from modules:
- "libs/community/langchain_community/cache.py"
"""
import json
import logging
import time
from importlib.metadata import version
from typing import Any, Callable, Dict, Optional, Union

from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
from langchain_core.embeddings import Embeddings
from langchain_core.load.dump import dumps
from langchain_core.load.load import loads
from langchain_core.outputs import Generation
from pymongo import MongoClient
from pymongo.collection import Collection
from pymongo.database import Database
from pymongo.driver_info import DriverInfo

from langchain_mongodb.vectorstores import MongoDBAtlasVectorSearch

logger = logging.getLogger(__name__)


def _generate_mongo_client(connection_string: str) -> MongoClient:
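    # Tag the client with driver metadata so Atlas can attribute traffic
    # to the langchain-mongodb integration.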
return MongoClient(
connection_string,
driver=DriverInfo(name="Langchain", version=version("langchain-mongodb")),
    )


def _dumps_generations(generations: RETURN_VAL_TYPE) -> str:
    """
    Serialization for generic RETURN_VAL_TYPE, i.e. a sequence of `Generation`.

    Args:
        generations (RETURN_VAL_TYPE): A list of language model generations.

    Returns:
        str: A single string representing a list of generations.

    This function (and its counterpart `_loads_generations`) relies on
    the dumps/loads pair with Reviver, and so can handle
    all subclasses of Generation.

    Each item in the list is `dumps`ed to a string,
    then the whole list of strings is JSON-dumped.
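
    Example (illustrative sketch; the exact `Generation` repr depends on
    the installed langchain_core version):
        blob = _dumps_generations([Generation(text="hello")])
        restored = _loads_generations(blob)  # -> [Generation(text="hello")]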
"""
    return json.dumps([dumps(_item) for _item in generations])


def _loads_generations(generations_str: str) -> Union[RETURN_VAL_TYPE, None]:
    """
    Deserialization of a string into a generic RETURN_VAL_TYPE
    (i.e. a sequence of `Generation`).

    See `_dumps_generations`, the inverse of this function.

    Args:
        generations_str (str): A string representing a list of generations.

    Compatible with the legacy cache-blob format.
    Does not raise exceptions for malformed entries, just logs a warning
    and returns None: the caller should be prepared for such a cache miss.

    Returns:
        RETURN_VAL_TYPE: A list of generations.
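
    Example (illustrative; a legacy blob is a plain JSON list of dicts,
    whereas the current format JSON-encodes a list of `dumps`ed strings):
        _loads_generations('[{"text": "hello"}]')
        # -> [Generation(text="hello")], with a legacy-blob warning logged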
"""
try:
generations = [loads(_item_str) for _item_str in json.loads(generations_str)]
return generations
except (json.JSONDecodeError, TypeError):
# deferring the (soft) handling to after the legacy-format attempt
pass
try:
gen_dicts = json.loads(generations_str)
# not relying on `_load_generations_from_json` (which could disappear):
generations = [Generation(**generation_dict) for generation_dict in gen_dicts]
logger.warning(
f"Legacy 'Generation' cached blob encountered: '{generations_str}'"
)
return generations
except (json.JSONDecodeError, TypeError):
logger.warning(
f"Malformed/unparsable cached blob encountered: '{generations_str}'"
)
        return None


def _wait_until(
    predicate: Callable, success_description: Any, timeout: float = 10.0
) -> Any:
    """Wait up to 10 seconds (by default) for predicate to be true.

    E.g.:
        wait_until(lambda: client.primary == ('a', 1),
                   'connect to the primary')

    If the predicate isn't true after 10 seconds, we raise
    TimeoutError("Didn't ever connect to the primary").

    Returns the predicate's first true value.
    """
start = time.time()
interval = min(float(timeout) / 100, 0.1)
while True:
retval = predicate()
if retval:
return retval
if time.time() - start > timeout:
raise TimeoutError("Didn't ever %s" % success_description)
        time.sleep(interval)


class MongoDBCache(BaseCache):
    """MongoDB Atlas cache

    A cache that uses MongoDB Atlas as a backend.
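
    Example (a minimal sketch; the connection string is a placeholder, and
    `set_llm_cache` is the standard langchain_core registration hook):
        from langchain_core.globals import set_llm_cache

        set_llm_cache(MongoDBCache(
            connection_string="mongodb+srv://<user>:<password>@<cluster>/",
            collection_name="llm_cache",
            database_name="langchain",
        ))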
"""

    PROMPT = "prompt"
    LLM = "llm"
    RETURN_VAL = "return_val"

    def __init__(
        self,
        connection_string: str,
        collection_name: str = "default",
        database_name: str = "default",
        **kwargs: Dict[str, Any],
    ) -> None:
        """
        Initialize Atlas Cache. Creates the collection on instantiation.

        Args:
            connection_string (str): Connection URI to the MongoDB Atlas cluster.
            collection_name (str): Name of the collection the cache lives in.
                Defaults to "default".
            database_name (str): Name of the database the cache lives in.
                Defaults to "default".
        """
self.client = _generate_mongo_client(connection_string)
self.__database_name = database_name
self.__collection_name = collection_name
if self.__collection_name not in self.database.list_collection_names():
self.database.create_collection(self.__collection_name)
            # Create an index on the prompt and llm fields
            self.collection.create_index([self.PROMPT, self.LLM])

@property
def database(self) -> Database:
"""Returns the database used to store cache values."""
        return self.client[self.__database_name]

@property
def collection(self) -> Collection:
"""Returns the collection used to store cache values."""
        return self.database[self.__collection_name]

def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
return_doc = (
self.collection.find_one(self._generate_keys(prompt, llm_string)) or {}
)
return_val = return_doc.get(self.RETURN_VAL)
        return _loads_generations(return_val) if return_val else None  # type: ignore

def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""
self.collection.update_one(
{**self._generate_keys(prompt, llm_string)},
{"$set": {self.RETURN_VAL: _dumps_generations(return_val)}},
upsert=True,
        )

def _generate_keys(self, prompt: str, llm_string: str) -> Dict[str, str]:
"""Create keyed fields for caching layer"""
        return {self.PROMPT: prompt, self.LLM: llm_string}

    def clear(self, **kwargs: Any) -> None:
        """Clear the cache, optionally filtered by keyword arguments.

        Any additional arguments propagate as filtration criteria for
        what gets deleted.

        E.g.
            # Delete only entries that have llm_string as "fake-model"
            self.clear(llm_string="fake-model")
        """
        self.collection.delete_many({**kwargs})


class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):
    """MongoDB Atlas Semantic cache.

    A cache backed by a MongoDB Atlas server with vector-store support.
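
    Example (a minimal sketch; OpenAIEmbeddings is a placeholder for any
    Embeddings implementation, and the URI and Atlas Search index are
    assumed to already exist):
        from langchain_core.globals import set_llm_cache
        from langchain_openai import OpenAIEmbeddings

        set_llm_cache(MongoDBAtlasSemanticCache(
            connection_string="mongodb+srv://<user>:<password>@<cluster>/",
            embedding=OpenAIEmbeddings(),
            collection_name="semantic_cache",
            index_name="default",
        ))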
"""

    LLM = "llm_string"
    RETURN_VAL = "return_val"

def __init__(
self,
connection_string: str,
embedding: Embeddings,
collection_name: str = "default",
database_name: str = "default",
index_name: str = "default",
wait_until_ready: bool = False,
score_threshold: Optional[float] = None,
**kwargs: Dict[str, Any],
):
"""
Initialize Atlas VectorSearch Cache.
Assumes collection exists before instantiation
Args:
connection_string (str): MongoDB URI to connect to MongoDB Atlas cluster.
embedding (Embeddings): Text embedding model to use.
collection_name (str): MongoDB Collection to add the texts to.
Defaults to "default".
database_name (str): MongoDB Database where to store texts.
Defaults to "default".
index_name: Name of the Atlas Search index.
defaults to 'default'
wait_until_ready (bool): Block until MongoDB Atlas finishes indexing
the stored text. Hard timeout of 10 seconds. Defaults to False.
"""
client = _generate_mongo_client(connection_string)
self.collection = client[database_name][collection_name]
self.score_threshold = score_threshold
self._wait_until_ready = wait_until_ready
super().__init__(
collection=self.collection,
embedding=embedding,
index_name=index_name,
**kwargs, # type: ignore
        )

def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
post_filter_pipeline = (
[{"$match": {"score": {"$gte": self.score_threshold}}}]
if self.score_threshold
else None
)
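        # Retrieve the single nearest cached prompt for this llm_string; if a
        # score_threshold is set, matches scoring below it are filtered out so
        # that only sufficiently similar prompts count as cache hits.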
search_response = self.similarity_search_with_score(
prompt,
1,
pre_filter={self.LLM: {"$eq": llm_string}},
post_filter_pipeline=post_filter_pipeline,
)
if search_response:
return_val = search_response[0][0].metadata.get(self.RETURN_VAL)
response = _loads_generations(return_val) or return_val # type: ignore
return response
        return None

    def update(
        self,
        prompt: str,
        llm_string: str,
        return_val: RETURN_VAL_TYPE,
        wait_until_ready: Optional[bool] = None,
    ) -> None:
        """Update cache based on prompt and llm_string.

        When `wait_until_ready` is passed, it overrides the instance-level
        setting of the same name for this call.
        """
self.add_texts(
[prompt],
[
{
self.LLM: llm_string,
self.RETURN_VAL: _dumps_generations(return_val),
}
],
)
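
        # Resolve the per-call wait_until_ready override against the
        # instance-level default, then optionally block (hard 10 s timeout)
        # until the new entry is visible to lookup().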
        wait = self._wait_until_ready if wait_until_ready is None else wait_until_ready

        def is_indexed() -> bool:
            return self.lookup(prompt, llm_string) == return_val

        if wait:
            _wait_until(is_indexed, return_val)

    def clear(self, **kwargs: Any) -> None:
        """Clear the cache, optionally filtered by keyword arguments.

        Any additional arguments propagate as filtration criteria for
        what gets deleted. With no arguments, all cached content is deleted.

        E.g.
            # Delete only entries that have llm_string as "fake-model"
            self.clear(llm_string="fake-model")
        """
self.collection.delete_many({**kwargs})