Spaces:
Running
Running
File size: 3,622 Bytes
b247dc4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
"""Setup for all tests."""
import os
import shutil
from pathlib import Path
from typing import Generator
import numpy as np
import pytest
import redis
from manifest.request import DiffusionRequest, EmbeddingRequest, LMRequest
from manifest.response import ArrayModelChoice, LMModelChoice, ModelChoices
@pytest.fixture
def model_choice() -> ModelChoices:
    """Build a two-choice LM fixture: a two-token "hello" and a one-token "bye"."""
    hello_choice = LMModelChoice(
        text="hello", token_logprobs=[0.1, 0.2], tokens=["hel", "lo"]
    )
    bye_choice = LMModelChoice(text="bye", token_logprobs=[0.3], tokens=["bye"])
    return ModelChoices(choices=[hello_choice, bye_choice])
@pytest.fixture
def model_choice_single() -> ModelChoices:
    """Build a one-choice LM fixture holding the two-token text "helloo"."""
    only_choice = LMModelChoice(
        text="helloo", token_logprobs=[0.1, 0.2], tokens=["hel", "loo"]
    )
    return ModelChoices(choices=[only_choice])
@pytest.fixture
def model_choice_arr() -> ModelChoices:
    """Build two array choices, each holding a seeded 4x4 ``randn`` array.

    A private ``np.random.RandomState(0)`` produces exactly the values that
    ``np.random.seed(0)`` followed by ``np.random.randn`` would. Unlike that
    pair of calls, it does not reseed NumPy's process-global RNG, whose state
    would otherwise carry over into other tests.
    """
    rng = np.random.RandomState(0)
    return ModelChoices(
        choices=[
            # List items are evaluated left to right, so this first draw
            # matches the first draw of the original seeded sequence.
            ArrayModelChoice(array=rng.randn(4, 4), token_logprobs=[0.1, 0.2]),
            ArrayModelChoice(array=rng.randn(4, 4), token_logprobs=[0.3]),
        ]
    )
@pytest.fixture
def model_choice_arr_int() -> ModelChoices:
    """Build two array choices, each holding a seeded 4x4 int array in [0, 20).

    A private ``np.random.RandomState(0)`` yields the same integers (and the
    same default dtype) as ``np.random.seed(0)`` plus ``np.random.randint``.
    It does so without reseeding NumPy's process-global RNG, which would leak
    into other tests.
    """
    rng = np.random.RandomState(0)
    return ModelChoices(
        choices=[
            ArrayModelChoice(
                array=rng.randint(20, size=(4, 4)), token_logprobs=[0.1, 0.2]
            ),
            ArrayModelChoice(
                array=rng.randint(20, size=(4, 4)), token_logprobs=[0.3]
            ),
        ]
    )
@pytest.fixture
def request_lm() -> LMRequest:
    """Provide an LM request whose prompt is the two-item batch ["what", "cat"]."""
    return LMRequest(prompt=["what", "cat"])
@pytest.fixture
def request_lm_single() -> LMRequest:
    """Provide a single-prompt LM request ("monkey") aimed at the "dummy" engine."""
    return LMRequest(engine="dummy", prompt="monkey")
@pytest.fixture
def request_array() -> EmbeddingRequest:
    """Provide an embedding request for the prompt "hello"."""
    return EmbeddingRequest(prompt="hello")
@pytest.fixture
def request_diff() -> DiffusionRequest:
    """Provide a diffusion request for the prompt "hello"."""
    return DiffusionRequest(prompt="hello")
@pytest.fixture
def sqlite_cache(tmp_path: Path) -> Generator[str, None, None]:
    """Yield the path of a SQLite cache file inside pytest's ``tmp_path``.

    The file is not created here. After the test, the file is removed if the
    test created it.

    Yields:
        The cache file path as a string.
    """
    cache = str(tmp_path / "sqlite_cache.sqlite")
    yield cache
    # The cache is a single file, not a directory. ``shutil.rmtree`` would
    # raise NotADirectoryError, and ``ignore_errors=True`` would hide that,
    # leaving the file behind. Remove it as a file instead. A missing file is
    # fine: not every test writes to the cache.
    try:
        os.remove(cache)
    except FileNotFoundError:
        pass
@pytest.fixture
def redis_cache() -> Generator[str, None, None]:
    """Yield a ``host:port`` Redis address, then flush that database afterwards.

    Host and port are read from ``REDIS_HOST`` and ``REDIS_PORT``, falling
    back to ``localhost`` and ``6379``.
    """
    redis_host = os.environ.get("REDIS_HOST", "localhost")
    redis_port = int(os.environ.get("REDIS_PORT", 6379))
    yield f"{redis_host}:{redis_port}"
    # Teardown: wipe whatever the test stored. A Redis server that is not
    # running locally is tolerated, so the suite can still run without one.
    try:
        client = redis.Redis(host=redis_host, port=redis_port)
        client.flushdb()
    except redis.exceptions.ConnectionError:
        pass
@pytest.fixture
def postgres_cache(monkeypatch: pytest.MonkeyPatch) -> Generator[object, None, None]:
    """Stand in for a Postgres cache backend with an in-memory SQLite engine.

    ``sqlalchemy.create_engine`` is monkeypatched to ignore its arguments and
    always return one shared in-memory SQLite engine.

    Yields:
        The shared SQLAlchemy engine. It is disposed at teardown; the
        ``create_engine`` patch is undone by pytest's ``monkeypatch`` fixture.
    """
    import sqlalchemy  # type: ignore

    url = sqlalchemy.engine.url.URL.create("sqlite", database=":memory:")
    engine = sqlalchemy.create_engine(url)
    # NOTE(review): this only affects code that looks up
    # ``sqlalchemy.create_engine`` at call time. A
    # ``from sqlalchemy import create_engine`` bound before the patch keeps the
    # real function; confirm against the cache implementation.
    monkeypatch.setattr(sqlalchemy, "create_engine", lambda *args, **kwargs: engine)
    yield engine
    # Release the engine's pooled connections once the test is done.
    engine.dispose()
|