Spaces:
Running
Running
"""Cache test.""" | |
from typing import Dict, Type, cast | |
import numpy as np | |
import pytest | |
from redis import Redis | |
from sqlitedict import SqliteDict | |
from manifest.caches.cache import Cache | |
from manifest.caches.noop import NoopCache | |
from manifest.caches.postgres import PostgresCache | |
from manifest.caches.redis import RedisCache | |
from manifest.caches.sqlite import SQLiteCache | |
from manifest.request import DiffusionRequest, LMRequest, Request | |
from manifest.response import ArrayModelChoice, ModelChoices, Response | |
def _get_postgres_cache( | |
request_type: Type[Request] = LMRequest, cache_args: Dict = {} | |
) -> Cache: # type: ignore | |
"""Get postgres cache.""" | |
cache_args.update({"cache_user": "", "cache_password": "", "cache_db": ""}) | |
return PostgresCache( | |
"postgres", | |
request_type=request_type, | |
cache_args=cache_args, | |
) | |
def test_init( | |
sqlite_cache: str, redis_cache: str, postgres_cache: str, cache_type: str | |
) -> None: | |
"""Test cache initialization.""" | |
if cache_type == "sqlite": | |
sql_cache_obj = SQLiteCache(sqlite_cache) | |
assert isinstance(sql_cache_obj.cache, SqliteDict) | |
elif cache_type == "redis": | |
redis_cache_obj = RedisCache(redis_cache) | |
assert isinstance(redis_cache_obj.redis, Redis) | |
elif cache_type == "postgres": | |
postgres_cache_obj = _get_postgres_cache() | |
isinstance(postgres_cache_obj, PostgresCache) | |
def test_key_get_and_set( | |
sqlite_cache: str, redis_cache: str, postgres_cache: str, cache_type: str | |
) -> None: | |
"""Test cache key get and set.""" | |
if cache_type == "sqlite": | |
cache = cast(Cache, SQLiteCache(sqlite_cache)) | |
elif cache_type == "redis": | |
cache = cast(Cache, RedisCache(redis_cache)) | |
elif cache_type == "postgres": | |
cache = cast(Cache, _get_postgres_cache()) | |
cache.set_key("test", "valueA") | |
cache.set_key("testA", "valueB") | |
assert cache.get_key("test") == "valueA" | |
assert cache.get_key("testA") == "valueB" | |
cache.set_key("testA", "valueC") | |
assert cache.get_key("testA") == "valueC" | |
cache.get_key("test", table="prompt") is None | |
cache.set_key("test", "valueA", table="prompt") | |
cache.get_key("test", table="prompt") == "valueA" | |
def test_get( | |
sqlite_cache: str, | |
redis_cache: str, | |
postgres_cache: str, | |
cache_type: str, | |
model_choice: ModelChoices, | |
model_choice_single: ModelChoices, | |
model_choice_arr_int: ModelChoices, | |
request_lm: LMRequest, | |
request_lm_single: LMRequest, | |
request_diff: DiffusionRequest, | |
) -> None: | |
"""Test cache save prompt.""" | |
if cache_type == "sqlite": | |
cache = cast(Cache, SQLiteCache(sqlite_cache)) | |
elif cache_type == "redis": | |
cache = cast(Cache, RedisCache(redis_cache)) | |
elif cache_type == "postgres": | |
cache = cast(Cache, _get_postgres_cache()) | |
response = Response( | |
response=model_choice_single, | |
cached=False, | |
request=request_lm_single, | |
usages=None, | |
request_type=LMRequest, | |
response_type="text", | |
) | |
cache_response = cache.get(request_lm_single.dict()) | |
assert cache_response is None | |
cache.set(request_lm_single.dict(), response.to_dict(drop_request=True)) | |
cache_response = cache.get(request_lm_single.dict()) | |
assert cache_response.get_response() == "helloo" | |
assert cache_response.is_cached() | |
assert cache_response.get_request_obj() == request_lm_single | |
response = Response( | |
response=model_choice, | |
cached=False, | |
request=request_lm, | |
usages=None, | |
request_type=LMRequest, | |
response_type="text", | |
) | |
cache_response = cache.get(request_lm.dict()) | |
assert cache_response is None | |
cache.set(request_lm.dict(), response.to_dict(drop_request=True)) | |
cache_response = cache.get(request_lm.dict()) | |
assert cache_response.get_response() == ["hello", "bye"] | |
assert cache_response.is_cached() | |
assert cache_response.get_request_obj() == request_lm | |
# Test array | |
response = Response( | |
response=model_choice_arr_int, | |
cached=False, | |
request=request_diff, | |
usages=None, | |
request_type=DiffusionRequest, | |
response_type="array", | |
) | |
if cache_type == "sqlite": | |
cache = SQLiteCache(sqlite_cache, request_type=DiffusionRequest) | |
elif cache_type == "redis": | |
cache = RedisCache(redis_cache, request_type=DiffusionRequest) | |
elif cache_type == "postgres": | |
cache = _get_postgres_cache(request_type=DiffusionRequest) | |
cache_response = cache.get(request_diff.dict()) | |
assert cache_response is None | |
cache.set(request_diff.dict(), response.to_dict(drop_request=True)) | |
cached_response = cache.get(request_diff.dict()) | |
assert np.allclose( | |
cached_response.get_response()[0], | |
cast(ArrayModelChoice, model_choice_arr_int.choices[0]).array, | |
) | |
assert np.allclose( | |
cached_response.get_response()[1], | |
cast(ArrayModelChoice, model_choice_arr_int.choices[1]).array, | |
) | |
assert cached_response.is_cached() | |
assert cached_response.get_request_obj() == request_diff | |
# Test array byte string | |
# Make sure to not hit the cache | |
new_request_diff = DiffusionRequest(**request_diff.dict()) | |
new_request_diff.prompt = ["blahhh", "yayayay"] | |
response = Response( | |
response=model_choice_arr_int, | |
cached=False, | |
request=new_request_diff, | |
usages=None, | |
request_type=DiffusionRequest, | |
response_type="array", | |
) | |
if cache_type == "sqlite": | |
cache = SQLiteCache( | |
sqlite_cache, | |
request_type=DiffusionRequest, | |
cache_args={"array_serializer": "byte_string"}, | |
) | |
elif cache_type == "redis": | |
cache = RedisCache( | |
redis_cache, | |
request_type=DiffusionRequest, | |
cache_args={"array_serializer": "byte_string"}, | |
) | |
elif cache_type == "postgres": | |
cache = _get_postgres_cache( | |
request_type=DiffusionRequest, | |
cache_args={"array_serializer": "byte_string"}, | |
) | |
cached_response = cache.get(new_request_diff.dict()) | |
assert cached_response is None | |
cache.set(new_request_diff.dict(), response.to_dict(drop_request=True)) | |
cached_response = cache.get(new_request_diff.dict()) | |
assert np.allclose( | |
cached_response.get_response()[0], | |
cast(ArrayModelChoice, model_choice_arr_int.choices[0]).array, | |
) | |
assert np.allclose( | |
cached_response.get_response()[1], | |
cast(ArrayModelChoice, model_choice_arr_int.choices[1]).array, | |
) | |
assert cached_response.is_cached() | |
assert cached_response.get_request_obj() == new_request_diff | |
def test_noop_cache() -> None: | |
"""Test cache that is a no-op cache.""" | |
cache = NoopCache(None) | |
cache.set_key("test", "valueA") | |
cache.set_key("testA", "valueB") | |
assert cache.get_key("test") is None | |
assert cache.get_key("testA") is None | |
cache.set_key("testA", "valueC") | |
assert cache.get_key("testA") is None | |
cache.get_key("test", table="prompt") is None | |
cache.set_key("test", "valueA", table="prompt") | |
cache.get_key("test", table="prompt") is None | |
# Assert always not cached | |
test_request = {"test": "hello", "testA": "world"} | |
test_response = {"choices": [{"text": "hello"}]} | |
response = cache.get(test_request) | |
assert response is None | |
cache.set(test_request, test_response) | |
response = cache.get(test_request) | |
assert response is None | |