Spaces:
Running
Running
"""Setup for all tests.""" | |
import os | |
import shutil | |
from pathlib import Path | |
from typing import Generator | |
import numpy as np | |
import pytest | |
import redis | |
from manifest.request import DiffusionRequest, EmbeddingRequest, LMRequest | |
from manifest.response import ArrayModelChoice, LMModelChoice, ModelChoices | |
def model_choice() -> ModelChoices:
    """Build a dummy ``ModelChoices`` holding two text choices.

    Returns:
        A ``ModelChoices`` whose ``choices`` list contains, in order, an
        ``LMModelChoice`` for "hello" (two tokens) and one for "bye"
        (a single token), each with its token log-probabilities.
    """
    hello_choice = LMModelChoice(
        text="hello",
        token_logprobs=[0.1, 0.2],
        tokens=["hel", "lo"],
    )
    bye_choice = LMModelChoice(
        text="bye",
        token_logprobs=[0.3],
        tokens=["bye"],
    )
    return ModelChoices(choices=[hello_choice, bye_choice])
def model_choice_single() -> ModelChoices:
    """Build a dummy ``ModelChoices`` holding exactly one text choice.

    Returns:
        A ``ModelChoices`` whose only entry is an ``LMModelChoice`` for
        "helloo", split into the tokens "hel" and "loo".
    """
    only_choice = LMModelChoice(
        text="helloo",
        token_logprobs=[0.1, 0.2],
        tokens=["hel", "loo"],
    )
    return ModelChoices(choices=[only_choice])
def model_choice_arr() -> ModelChoices:
    """Build a dummy ``ModelChoices`` holding two float-array choices.

    The global NumPy random state is reseeded with 0 on every call, so the
    two 4x4 ``np.random.randn`` arrays are identical from call to call.

    Returns:
        A ``ModelChoices`` with two ``ArrayModelChoice`` entries.
    """
    # Seed before drawing: both arrays come, in order, from the global RNG.
    np.random.seed(0)
    first_choice = ArrayModelChoice(
        array=np.random.randn(4, 4),
        token_logprobs=[0.1, 0.2],
    )
    second_choice = ArrayModelChoice(
        array=np.random.randn(4, 4),
        token_logprobs=[0.3],
    )
    return ModelChoices(choices=[first_choice, second_choice])
def model_choice_arr_int() -> ModelChoices:
    """Build a dummy ``ModelChoices`` holding two integer-array choices.

    The global NumPy random state is reseeded with 0 on every call, so the
    two 4x4 ``np.random.randint(20, ...)`` arrays are identical from call to
    call.

    Returns:
        A ``ModelChoices`` with two ``ArrayModelChoice`` entries.
    """
    # Seed before drawing: both arrays come, in order, from the global RNG.
    np.random.seed(0)
    first_choice = ArrayModelChoice(
        array=np.random.randint(20, size=(4, 4)),
        token_logprobs=[0.1, 0.2],
    )
    second_choice = ArrayModelChoice(
        array=np.random.randint(20, size=(4, 4)),
        token_logprobs=[0.3],
    )
    return ModelChoices(choices=[first_choice, second_choice])
def request_lm() -> LMRequest:
    """Build a dummy ``LMRequest`` whose prompt is the list ["what", "cat"]."""
    return LMRequest(prompt=["what", "cat"])
def request_lm_single() -> LMRequest:
    """Build a dummy ``LMRequest`` with the single prompt "monkey".

    The request's ``engine`` is set to "dummy".
    """
    return LMRequest(prompt="monkey", engine="dummy")
def request_array() -> EmbeddingRequest:
    """Build a dummy ``EmbeddingRequest`` with the prompt "hello"."""
    return EmbeddingRequest(prompt="hello")
def request_diff() -> DiffusionRequest:
    """Build a dummy ``DiffusionRequest`` with the prompt "hello"."""
    return DiffusionRequest(prompt="hello")
def sqlite_cache(tmp_path: Path) -> Generator[str, None, None]:
    """Yield a path for a SQLite cache file, then remove it afterwards.

    Args:
        tmp_path: Directory in which the cache path is placed.

    Yields:
        The string path ``<tmp_path>/sqlite_cache.sqlite``. Nothing is
        created at that path here; whoever consumes the path may create it.
    """
    cache_path = tmp_path / "sqlite_cache.sqlite"
    yield str(cache_path)
    # The cache is normally a single file. shutil.rmtree only removes
    # directories and, with ignore_errors=True, silently leaves a file behind,
    # so files are unlinked explicitly. Cleanup stays best-effort: a missing
    # or still-locked file must not turn teardown into a failure.
    if cache_path.is_dir():
        shutil.rmtree(cache_path, ignore_errors=True)
    else:
        try:
            cache_path.unlink()
        except OSError:
            pass
def redis_cache() -> Generator[str, None, None]:
    """Yield a Redis ``host:port`` address, then flush that database.

    The host is read from the ``REDIS_HOST`` environment variable (default
    "localhost") and the port from ``REDIS_PORT`` (default 6379).

    Yields:
        The address formatted as ``"<host>:<port>"``.
    """
    redis_host = os.environ.get("REDIS_HOST", "localhost")
    redis_port = int(os.environ.get("REDIS_PORT", 6379))
    address = f"{redis_host}:{redis_port}"
    yield address
    # Clear out the database
    try:
        redis.Redis(host=redis_host, port=redis_port).flushdb()
    except redis.exceptions.ConnectionError:
        # For better local testing, pass if redis DB not started
        pass
def postgres_cache(monkeypatch: pytest.MonkeyPatch) -> Generator[str, None, None]:
    """Patch ``sqlalchemy.create_engine`` to hand out an in-memory SQLite engine.

    Args:
        monkeypatch: Used to replace ``sqlalchemy.create_engine``.

    Returns:
        The in-memory SQLite engine that the patched ``create_engine`` returns.
        NOTE(review): the annotation says ``Generator[str, None, None]`` but
        the function returns the engine directly -- confirm the intended type.
    """
    import sqlalchemy  # type: ignore

    sqlite_url = sqlalchemy.engine.url.URL.create("sqlite", database=":memory:")
    memory_engine = sqlalchemy.create_engine(sqlite_url)

    def _always_memory_engine(*args, **kwargs):  # type: ignore
        # Arguments are ignored: every caller gets the same shared engine.
        return memory_engine

    # Replace the sqlalchemy.create_engine function with a function that
    # returns the in-memory SQLite engine
    monkeypatch.setattr(sqlalchemy, "create_engine", _always_memory_engine)
    return memory_engine  # type: ignore