File size: 3,622 Bytes
b247dc4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
"""Setup for all tests."""
import os
import shutil
from pathlib import Path
from typing import Generator

import numpy as np
import pytest
import redis

from manifest.request import DiffusionRequest, EmbeddingRequest, LMRequest
from manifest.response import ArrayModelChoice, LMModelChoice, ModelChoices


@pytest.fixture
def model_choice() -> ModelChoices:
    """Build a two-choice LM result used as dummy test data.

    The first choice is ``"hello"`` split into the tokens ``"hel"``/``"lo"``;
    the second is ``"bye"`` as a single token.
    """
    hello_choice = LMModelChoice(
        text="hello", token_logprobs=[0.1, 0.2], tokens=["hel", "lo"]
    )
    bye_choice = LMModelChoice(text="bye", token_logprobs=[0.3], tokens=["bye"])
    return ModelChoices(choices=[hello_choice, bye_choice])


@pytest.fixture
def model_choice_single() -> ModelChoices:
    """Build a one-choice LM result: ``"helloo"`` split into two tokens."""
    sole_choice = LMModelChoice(
        text="helloo", token_logprobs=[0.1, 0.2], tokens=["hel", "loo"]
    )
    return ModelChoices(choices=[sole_choice])


@pytest.fixture
def model_choice_arr() -> ModelChoices:
    """Get dummy model choice.

    Returns two ``ArrayModelChoice`` entries, each holding a 4x4 array of
    standard-normal floats, with ``token_logprobs`` of different lengths.
    """
    # A private RandomState seeded with 0 yields exactly the same draws as
    # ``np.random.seed(0)`` followed by global ``np.random.randn`` calls, but
    # does not reseed NumPy's global generator as a side effect that would
    # leak into whichever test runs next.
    rng = np.random.RandomState(0)
    model_choices = ModelChoices(
        choices=[
            ArrayModelChoice(array=rng.randn(4, 4), token_logprobs=[0.1, 0.2]),
            ArrayModelChoice(array=rng.randn(4, 4), token_logprobs=[0.3]),
        ]
    )
    return model_choices


@pytest.fixture
def model_choice_arr_int() -> ModelChoices:
    """Get dummy model choice.

    Returns two ``ArrayModelChoice`` entries, each holding a 4x4 integer array
    with values drawn from ``[0, 20)``, with ``token_logprobs`` of different
    lengths.
    """
    # Same draws as ``np.random.seed(0)`` + global ``np.random.randint``, but
    # without reseeding NumPy's global generator as a side effect.
    rng = np.random.RandomState(0)
    model_choices = ModelChoices(
        choices=[
            ArrayModelChoice(
                array=rng.randint(20, size=(4, 4)), token_logprobs=[0.1, 0.2]
            ),
            ArrayModelChoice(
                array=rng.randint(20, size=(4, 4)), token_logprobs=[0.3]
            ),
        ]
    )
    return model_choices


@pytest.fixture
def request_lm() -> LMRequest:
    """Return an ``LMRequest`` whose prompt is the two-item batch ``["what", "cat"]``."""
    return LMRequest(prompt=["what", "cat"])


@pytest.fixture
def request_lm_single() -> LMRequest:
    """Return an ``LMRequest`` with the single prompt ``"monkey"`` and engine ``"dummy"``."""
    return LMRequest(prompt="monkey", engine="dummy")


@pytest.fixture
def request_array() -> EmbeddingRequest:
    """Return an ``EmbeddingRequest`` for the prompt ``"hello"``."""
    return EmbeddingRequest(prompt="hello")


@pytest.fixture
def request_diff() -> DiffusionRequest:
    """Return a ``DiffusionRequest`` for the prompt ``"hello"``."""
    return DiffusionRequest(prompt="hello")


@pytest.fixture
def sqlite_cache(tmp_path: Path) -> Generator[str, None, None]:
    """Sqlite Cache.

    Yields the path (as a string) of a SQLite database file inside pytest's
    per-test ``tmp_path`` directory, and deletes that file after the test.
    """
    cache = str(tmp_path / "sqlite_cache.sqlite")
    yield cache
    # The cache is a regular file. ``shutil.rmtree`` only removes directories,
    # so with ``ignore_errors=True`` it silently left the file behind. Remove it
    # as a file instead. This stays best-effort: the file may never have been
    # created, or may still be locked on some platforms. pytest's own
    # ``tmp_path`` cleanup remains the backstop.
    try:
        os.remove(cache)
    except OSError:
        pass


@pytest.fixture
def redis_cache() -> Generator[str, None, None]:
    """Redis cache.

    Yields a ``"host:port"`` string built from the ``REDIS_HOST`` and
    ``REDIS_PORT`` environment variables (defaulting to ``localhost`` and
    ``6379``), then issues ``FLUSHDB`` against that server on teardown.
    """
    redis_host = os.environ.get("REDIS_HOST", "localhost")
    redis_port = int(os.environ.get("REDIS_PORT", 6379))
    yield f"{redis_host}:{redis_port}"
    # Teardown: wipe whatever the test stored. A Redis server that is not
    # running is tolerated, so the suite can still run locally without one.
    try:
        redis.Redis(host=redis_host, port=redis_port).flushdb()
    except redis.exceptions.ConnectionError:
        pass


@pytest.fixture
def postgres_cache(monkeypatch: pytest.MonkeyPatch) -> Generator[str, None, None]:
    """Postgres cache.

    This is a stand-in for a real Postgres server. It builds an in-memory
    SQLite engine, then monkeypatches ``sqlalchemy.create_engine`` so that,
    for the duration of the test, every call made through that module
    attribute returns this same engine regardless of its arguments.

    Yields:
        The SQLAlchemy engine itself. This is not a connection string; the
        ``str`` annotation is kept only for backward compatibility. The
        engine is disposed on teardown.
    """
    import sqlalchemy  # type: ignore

    # Build the engine with the real create_engine *before* patching it.
    url = sqlalchemy.engine.url.URL.create("sqlite", database=":memory:")
    engine = sqlalchemy.create_engine(url)
    monkeypatch.setattr(sqlalchemy, "create_engine", lambda *args, **kwargs: engine)
    yield engine  # type: ignore
    # Release the engine's pooled connections once the test is done.
    engine.dispose()