Spaces:
Running
Running
File size: 2,304 Bytes
b247dc4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
"""Test client pool."""
import time
import pytest
from manifest.connections.client_pool import ClientConnection, ClientConnectionPool
from manifest.request import LMRequest
def test_init() -> None:
    """Check construction of a client connection pool.

    An unknown scheduler name and a pool mixing request types must each raise
    a ``ValueError`` with the expected message. A pool built from two clients
    with engines must report ``LMRequest`` as its request type and hold two
    clients and two metrics entries, in the order the clients were given.
    """
    davinci = ClientConnection(
        client_name="openai",
        client_connection="XXX",
        engine="text-davinci-002",
    )
    ada = ClientConnection(
        client_name="openai",
        client_connection="XXX",
        engine="text-ada-001",
    )
    embedder = ClientConnection(
        client_name="openaiembedding",
        client_connection="XXX",
    )

    # An unrecognised scheduler name is rejected.
    with pytest.raises(ValueError) as scheduler_error:
        ClientConnectionPool([davinci, ada], client_pool_scheduler="bad")
    assert str(scheduler_error.value) == "Unknown scheduler: bad."

    # Clients with different request types cannot share a pool.
    mixed_types_message = "All clients in the client pool must use the same request type. You have [\"<class 'manifest.request.EmbeddingRequest'>\", \"<class 'manifest.request.LMRequest'>\"]"  # noqa: E501
    with pytest.raises(ValueError) as mixed_error:
        ClientConnectionPool([davinci, embedder])
    assert str(mixed_error.value) == mixed_types_message

    # A valid pool keeps its clients in the order they were passed in.
    valid_pool = ClientConnectionPool([davinci, ada])
    assert valid_pool.request_type == LMRequest
    assert len(valid_pool.client_pool) == 2
    assert len(valid_pool.client_pool_metrics) == 2
    assert valid_pool.client_pool[0].engine == "text-davinci-002"  # type: ignore
    assert valid_pool.client_pool[1].engine == "text-ada-001"  # type: ignore
def test_timing() -> None:
    """Check that the pool's timer records elapsed time per selected client.

    Two ``dummy`` clients are pooled. For each one in turn the test selects
    the next client, verifies the current client id, and sleeps between
    ``start_timer`` and ``end_timer``. The metrics entry for each client must
    then span at least (almost) the time slept.
    """
    pool = ClientConnectionPool(
        [ClientConnection(client_name="dummy") for _ in range(2)]
    )

    # (client id expected after get_next_client, seconds slept while timed)
    schedule = [(0, 2), (1, 1)]
    for expected_id, pause_seconds in schedule:
        pool.get_next_client()
        assert pool.current_client_id == expected_id
        pool.start_timer()
        time.sleep(pause_seconds)
        pool.end_timer()

    metrics = pool.client_pool_metrics
    # Thresholds sit slightly below the sleep durations to allow timer slack.
    first_elapsed = metrics[0].end - metrics[0].start
    assert first_elapsed > 1.9
    second_elapsed = metrics[1].end - metrics[1].start
    assert second_elapsed > 0.9
|