"""Manifest test.""" | |
import asyncio | |
import os | |
from typing import Iterator, cast | |
from unittest.mock import MagicMock, Mock, patch | |
import numpy as np | |
import pytest | |
import requests | |
from requests import HTTPError | |
from manifest import Manifest, Response | |
from manifest.caches.noop import NoopCache | |
from manifest.caches.sqlite import SQLiteCache | |
from manifest.clients.dummy import DummyClient | |
from manifest.connections.client_pool import ClientConnection | |
URL = "http://localhost:6000" | |
try: | |
_ = requests.post(URL + "/params").json() | |
MODEL_ALIVE = True | |
except Exception: | |
MODEL_ALIVE = False | |
OPENAI_ALIVE = os.environ.get("OPENAI_API_KEY") is not None | |
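
# These two flags gate the live-service tests below: MODEL_ALIVE probes a
# locally hosted model server at URL, and OPENAI_ALIVE checks for an API key.
# The pytest.mark.skipif markers on those tests keep the suite green when the
# external services are unavailable.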


def test_init(sqlite_cache: str) -> None:
    """Test manifest initialization."""
    with pytest.raises(ValueError) as exc_info:
        Manifest(
            client_name="dummy",
            cache_name="sqlite",
            cache_connection=sqlite_cache,
            sep_tok="",
        )
    assert str(exc_info.value) == "[('sep_tok', '')] arguments are not recognized."

    manifest = Manifest(
        client_name="dummy",
        cache_name="sqlite",
        cache_connection=sqlite_cache,
    )
    assert len(manifest.client_pool.client_pool) == 1
    client = manifest.client_pool.get_next_client()
    assert isinstance(client, DummyClient)
    assert isinstance(manifest.cache, SQLiteCache)
    assert client.n == 1  # type: ignore
    assert manifest.stop_token == ""

    manifest = Manifest(
        client_name="dummy",
        cache_name="noop",
        n=3,
        stop_token="\n",
    )
    assert len(manifest.client_pool.client_pool) == 1
    client = manifest.client_pool.get_next_client()
    assert isinstance(client, DummyClient)
    assert isinstance(manifest.cache, NoopCache)
    assert client.n == 3  # type: ignore
    assert manifest.stop_token == "\n"


# Exercised for n in {1, 2} and with and without Response objects returned.
@pytest.mark.parametrize("n", [1, 2])
@pytest.mark.parametrize("return_response", [True, False])
def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
"""Test manifest run.""" | |
manifest = Manifest( | |
client_name="dummy", | |
cache_name="sqlite", | |
cache_connection=sqlite_cache, | |
n=n, | |
temperature=0.0, | |
) | |
prompt = "This is a prompt" | |
with pytest.raises(ValueError) as exc_info: | |
result = manifest.run(prompt, return_response=return_response, bad_input=5) | |
assert str(exc_info.value) == "[('bad_input', 5)] arguments are not recognized." | |
result = manifest.run(prompt, return_response=return_response, top_k=5) | |
assert result is not None | |
prompt = "This is a prompt" | |
result = manifest.run(prompt, return_response=return_response) | |
if return_response: | |
assert isinstance(result, Response) | |
result = cast(Response, result) | |
assert len(result.get_usage_obj().usages) == len( | |
result.get_response_obj().choices | |
) | |
res = result.get_response(manifest.stop_token) | |
else: | |
res = cast(str, result) | |
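
    # The cache key checked below is the fully canonicalized request: the
    # prompt plus every resolved default (best_of, top_p, request class, and
    # so on), so a cache hit requires an exact match on all of them.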
    assert (
        manifest.cache.get(
            {
                "best_of": 1,
                "engine": "dummy",
                "max_tokens": 10,
                "model": "text-davinci-003",
                "n": n,
                "prompt": "This is a prompt",
                "request_cls": "LMRequest",
                "temperature": 0.0,
                "top_p": 1.0,
            }
        )
        is not None
    )
    if n == 1:
        assert res == "Nice Employ NFCYouryms“Inwarn\ttemplate europ Moines"
    else:
        assert res == [
            "Nice Employ NFCYouryms“Inwarn\ttemplate europ Moines",
            "Nice Employ NFCYouryms“Inwarn\ttemplate europ Moines",
        ]

    prompt = "This is a prompt"
    result = manifest.run(prompt, run_id="34", return_response=return_response)
    if return_response:
        assert isinstance(result, Response)
        result = cast(Response, result)
        assert len(result.get_usage_obj().usages) == len(
            result.get_response_obj().choices
        )
        res = result.get_response(manifest.stop_token)
    else:
        res = cast(str, result)
    assert (
        manifest.cache.get(
            {
                "best_of": 1,
                "engine": "dummy",
                "max_tokens": 10,
                "model": "text-davinci-003",
                "n": n,
                "prompt": "This is a prompt",
                "request_cls": "LMRequest",
                "temperature": 0.0,
                "top_p": 1.0,
                "run_id": "34",
            }
        )
        is not None
    )
    if n == 1:
        assert res == "Nice Employ NFCYouryms“Inwarn\ttemplate europ Moines"
    else:
        assert res == [
            "Nice Employ NFCYouryms“Inwarn\ttemplate europ Moines",
            "Nice Employ NFCYouryms“Inwarn\ttemplate europ Moines",
        ]

    prompt = "Hello is a prompt"
    result = manifest.run(prompt, return_response=return_response)
    if return_response:
        assert isinstance(result, Response)
        result = cast(Response, result)
        assert len(result.get_usage_obj().usages) == len(
            result.get_response_obj().choices
        )
        res = result.get_response(manifest.stop_token)
    else:
        res = cast(str, result)
    assert (
        manifest.cache.get(
            {
                "best_of": 1,
                "engine": "dummy",
                "max_tokens": 10,
                "model": "text-davinci-003",
                "n": n,
                "prompt": "Hello is a prompt",
                "request_cls": "LMRequest",
                "temperature": 0.0,
                "top_p": 1.0,
            }
        )
        is not None
    )
    if n == 1:
        assert res == "appersstoff210 currentNodeleh norm unified_voice DIYHam"
    else:
        assert res == [
            "appersstoff210 currentNodeleh norm unified_voice DIYHam",
            "appersstoff210 currentNodeleh norm unified_voice DIYHam",
        ]

    prompt = "Hello is a prompt"
    result = manifest.run(
        prompt, stop_token=" current", return_response=return_response
    )
    if return_response:
        assert isinstance(result, Response)
        result = cast(Response, result)
        assert len(result.get_usage_obj().usages) == len(
            result.get_response_obj().choices
        )
        res = result.get_response(stop_token=" current")
    else:
        res = cast(str, result)
    assert (
        manifest.cache.get(
            {
                "best_of": 1,
                "engine": "dummy",
                "max_tokens": 10,
                "model": "text-davinci-003",
                "n": n,
                "prompt": "Hello is a prompt",
                "request_cls": "LMRequest",
                "temperature": 0.0,
                "top_p": 1.0,
            }
        )
        is not None
    )
    if n == 1:
        assert res == "appersstoff210"
    else:
        assert res == ["appersstoff210", "appersstoff210"]


# Parametrized like test_run; batch mode additionally rejects n > 1.
@pytest.mark.parametrize("n", [1, 2])
@pytest.mark.parametrize("return_response", [True, False])
def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None:
    """Test manifest batch run."""
    manifest = Manifest(
        client_name="dummy",
        cache_name="sqlite",
        cache_connection=sqlite_cache,
        n=n,
        temperature=0.0,
    )
    prompt = ["This is a prompt"]
    if n == 2:
        with pytest.raises(ValueError) as exc_info:
            result = manifest.run(prompt, return_response=return_response)
        assert str(exc_info.value) == "Batch mode does not support n > 1."
    else:
        result = manifest.run(prompt, return_response=return_response)
        if return_response:
            assert isinstance(result, Response)
            result = cast(Response, result)
            assert len(result.get_usage_obj().usages) == len(
                result.get_response_obj().choices
            )
            res = result.get_response(manifest.stop_token, is_batch=True)
        else:
            res = cast(str, result)
        assert res == ["Nice Employ NFCYouryms“Inwarn\ttemplate europ Moines"]
        assert (
            manifest.cache.get(
                {
                    "best_of": 1,
                    "engine": "dummy",
                    "max_tokens": 10,
                    "model": "text-davinci-003",
                    "n": n,
                    "prompt": "This is a prompt",
                    "request_cls": "LMRequest",
                    "temperature": 0.0,
                    "top_p": 1.0,
                }
            )
            is not None
        )

        prompt = ["Hello is a prompt", "Hello is a prompt"]
        result = manifest.run(prompt, return_response=return_response)
        if return_response:
            assert isinstance(result, Response)
            result = cast(Response, result)
            assert len(result.get_usage_obj().usages) == len(
                result.get_response_obj().choices
            )
            res = result.get_response(manifest.stop_token, is_batch=True)
        else:
            res = cast(str, result)
        assert res == [
            "appersstoff210 currentNodeleh norm unified_voice DIYHam",
            "appersstoff210 currentNodeleh norm unified_voice DIYHam",
        ]
        assert (
            manifest.cache.get(
                {
                    "best_of": 1,
                    "engine": "dummy",
                    "max_tokens": 10,
                    "model": "text-davinci-003",
                    "n": n,
                    "prompt": "Hello is a prompt",
                    "request_cls": "LMRequest",
                    "temperature": 0.0,
                    "top_p": 1.0,
                }
            )
            is not None
        )

        result = manifest.run(prompt, return_response=True)
        res = cast(Response, result).get_response(manifest.stop_token, is_batch=True)
        assert cast(Response, result).is_cached()
        assert (
            manifest.cache.get(
                {
                    "best_of": 1,
                    "engine": "dummy",
                    "max_tokens": 10,
                    "model": "text-davinci-003",
                    "n": n,
                    "prompt": "New prompt",
                    "request_cls": "LMRequest",
                    "temperature": 0.0,
                    "top_p": 1.0,
                }
            )
            is None
        )

        prompt = ["This is a prompt", "New prompt"]
        result = manifest.run(prompt, return_response=return_response)
        if return_response:
            assert isinstance(result, Response)
            result = cast(Response, result)
            assert len(result.get_usage_obj().usages) == len(
                result.get_response_obj().choices
            )
            res = result.get_response(manifest.stop_token, is_batch=True)
            # Cached because one item is in cache
            assert result.is_cached()
        else:
            res = cast(str, result)
        assert res == [
            "Nice Employ NFCYouryms“Inwarn\ttemplate europ Moines",
            ".vol.deserializebigmnchantment ROTıl='')\najsС",
        ]

        prompt = ["Hello is a prompt", "Hello is a prompt"]
        result = manifest.run(
            prompt, stop_token=" current", return_response=return_response
        )
        if return_response:
            assert isinstance(result, Response)
            result = cast(Response, result)
            assert len(result.get_usage_obj().usages) == len(
                result.get_response_obj().choices
            )
            res = result.get_response(stop_token=" current", is_batch=True)
        else:
            res = cast(str, result)
        assert res == ["appersstoff210", "appersstoff210"]


def test_abatch_run(sqlite_cache: str) -> None:
    """Test manifest async batch run."""
    manifest = Manifest(
        client_name="dummy",
        cache_name="sqlite",
        cache_connection=sqlite_cache,
        temperature=0.0,
    )
    prompt = ["This is a prompt"]
    result = cast(
        Response, asyncio.run(manifest.arun_batch(prompt, return_response=True))
    )
    assert len(result.get_usage_obj().usages) == len(result.get_response_obj().choices)
    res = result.get_response(manifest.stop_token, is_batch=True)
    assert res == ["Nice Employ NFCYouryms“Inwarn\ttemplate europ Moines"]
    assert (
        manifest.cache.get(
            {
                "best_of": 1,
                "engine": "dummy",
                "max_tokens": 10,
                "model": "text-davinci-003",
                "n": 1,
                "prompt": "This is a prompt",
                "request_cls": "LMRequest",
                "temperature": 0.0,
                "top_p": 1.0,
            }
        )
        is not None
    )

    prompt = ["Hello is a prompt", "Hello is a prompt"]
    result = cast(
        Response, asyncio.run(manifest.arun_batch(prompt, return_response=True))
    )
    assert len(result.get_usage_obj().usages) == len(result.get_response_obj().choices)
    res = result.get_response(manifest.stop_token, is_batch=True)
    assert res == [
        "appersstoff210 currentNodeleh norm unified_voice DIYHam",
        "appersstoff210 currentNodeleh norm unified_voice DIYHam",
    ]
    assert (
        manifest.cache.get(
            {
                "best_of": 1,
                "engine": "dummy",
                "max_tokens": 10,
                "model": "text-davinci-003",
                "n": 1,
                "prompt": "Hello is a prompt",
                "request_cls": "LMRequest",
                "temperature": 0.0,
                "top_p": 1.0,
            }
        )
        is not None
    )

    result = cast(
        Response, asyncio.run(manifest.arun_batch(prompt, return_response=True))
    )
    assert len(result.get_usage_obj().usages) == len(result.get_response_obj().choices)
    res = result.get_response(manifest.stop_token, is_batch=True)
    assert result.is_cached()
    assert (
        manifest.cache.get(
            {
                "best_of": 1,
                "engine": "dummy",
                "max_tokens": 10,
                "model": "text-davinci-003",
                "n": 1,
                "prompt": "New prompt",
                "request_cls": "LMRequest",
                "temperature": 0.0,
                "top_p": 1.0,
            }
        )
        is None
    )

    prompt = ["This is a prompt", "New prompt"]
    result = cast(
        Response, asyncio.run(manifest.arun_batch(prompt, return_response=True))
    )
    assert len(result.get_usage_obj().usages) == len(result.get_response_obj().choices)
    res = result.get_response(manifest.stop_token, is_batch=True)
    # Cached because one item is in cache
    assert result.is_cached()
    assert res == [
        "Nice Employ NFCYouryms“Inwarn\ttemplate europ Moines",
        ".vol.deserializebigmnchantment ROTıl='')\najsС",
    ]

    prompt = ["Hello is a prompt", "Hello is a prompt"]
    result = cast(
        Response, asyncio.run(manifest.arun_batch(prompt, return_response=True))
    )
    assert len(result.get_usage_obj().usages) == len(result.get_response_obj().choices)
    res = result.get_response(stop_token=" current", is_batch=True)
    assert res == ["appersstoff210", "appersstoff210"]


def test_run_chat(sqlite_cache: str) -> None:
    """Test manifest run with chat prompts."""
    manifest = Manifest(
        client_name="dummy",
        cache_name="sqlite",
        cache_connection=sqlite_cache,
        temperature=0.0,
    )
    # Set CHAT to be true for this model
    manifest.client_pool.client_pool[0].IS_CHAT = True
    prompt = [
        {"role": "system", "content": "Hello."},
    ]
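
    # Chat-style prompts are lists of role/content dicts; note the cache
    # lookups below key them under LMChatRequest rather than LMRequest.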
    result = manifest.run(prompt, return_response=False)
    assert (
        result
        == "ectors WortGo ré_sg|--------------------------------------------------------------------------\n contradictory Aad \u200b getUserId"  # noqa: E501
    )
    assert (
        manifest.cache.get(
            {
                "best_of": 1,
                "engine": "dummy",
                "max_tokens": 10,
                "model": "text-davinci-003",
                "n": 1,
                "prompt": [{"content": "Hello.", "role": "system"}],
                "request_cls": "LMChatRequest",
                "temperature": 0.0,
                "top_p": 1.0,
            }
        )
        is not None
    )

    prompt = [
        {"role": "system", "content": "Hello."},
        {"role": "user", "content": "Goodbye?"},
    ]
    result = manifest.run(prompt, return_response=True)
    assert isinstance(result, Response)
    result = cast(Response, result)
    assert len(result.get_usage_obj().usages) == len(result.get_response_obj().choices)
    res = result.get_response()
    assert res == "_deploy_age_gp hora Plus Scheduler EisenhowerRF视 chemotherapy"
    assert (
        manifest.cache.get(
            {
                "best_of": 1,
                "engine": "dummy",
                "max_tokens": 10,
                "model": "text-davinci-003",
                "n": 1,
                "prompt": [
                    {"role": "system", "content": "Hello."},
                    {"role": "user", "content": "Goodbye?"},
                ],
                "request_cls": "LMChatRequest",
                "temperature": 0.0,
                "top_p": 1.0,
            }
        )
        is not None
    )


def test_score_run(sqlite_cache: str) -> None:
    """Test manifest prompt scoring."""
    manifest = Manifest(
        client_name="dummy",
        cache_name="sqlite",
        cache_connection=sqlite_cache,
        temperature=0.0,
    )
    prompt = "This is a prompt"
    result = manifest.score_prompt(prompt)
    assert (
        manifest.cache.get(
            {
                "best_of": 1,
                "engine": "dummy",
                "max_tokens": 10,
                "model": "text-davinci-003",
                "n": 1,
                "prompt": "This is a prompt",
                "request_cls": "LMScoreRequest",
                "temperature": 0.0,
                "top_p": 1.0,
            }
        )
        is not None
    )
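
    # The dummy client is deterministic, so the fully serialized Response
    # (choices with per-token logprobs, usage stats, and the canonical
    # request) can be compared verbatim.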
    assert result == {
        "response": {
            "choices": [
                {
                    "text": "Nice Employ NFCYouryms“Inwarn\ttemplate europ Moines",
                    "token_logprobs": [
                        -1.827188890438529,
                        -1.6981601736417915,
                        -0.24606708391178755,
                        -1.9209383499010613,
                        -0.8833563758318617,
                        -1.4121369466920703,
                        -0.376352908076236,
                        -1.3200064558188096,
                        -0.813028447207917,
                        -0.5977255311239729,
                    ],
                    "tokens": [
                        "46078",
                        "21445",
                        "48305",
                        "7927",
                        "76125",
                        "46233",
                        "34581",
                        "23679",
                        "63021",
                        "78158",
                    ],
                }
            ]
        },
        "usages": {
            "usages": [
                {"completion_tokens": 10, "prompt_tokens": 4, "total_tokens": 14}
            ]
        },
        "cached": False,
        "request": {
            "prompt": "This is a prompt",
            "engine": "text-davinci-003",
            "n": 1,
            "client_timeout": 60,
            "run_id": None,
            "batch_size": 20,
            "temperature": 0.0,
            "max_tokens": 10,
            "top_p": 1.0,
            "top_k": 1,
            "logprobs": None,
            "stop_sequences": None,
            "num_beams": 1,
            "do_sample": False,
            "repetition_penalty": 1.0,
            "length_penalty": 1.0,
            "presence_penalty": 0.0,
            "frequency_penalty": 0.0,
        },
        "response_type": "text",
        "request_type": "LMScoreRequest",
        "item_dtype": None,
    }

    prompt_list = ["Hello is a prompt", "Hello is another prompt"]
    result = manifest.score_prompt(prompt_list)
    assert (
        manifest.cache.get(
            {
                "best_of": 1,
                "engine": "dummy",
                "max_tokens": 10,
                "model": "text-davinci-003",
                "n": 1,
                "prompt": "Hello is a prompt",
                "request_cls": "LMScoreRequest",
                "temperature": 0.0,
                "top_p": 1.0,
            }
        )
        is not None
    )
    assert (
        manifest.cache.get(
            {
                "best_of": 1,
                "engine": "dummy",
                "max_tokens": 10,
                "model": "text-davinci-003",
                "n": 1,
                "prompt": "Hello is another prompt",
                "request_cls": "LMScoreRequest",
                "temperature": 0.0,
                "top_p": 1.0,
            }
        )
        is not None
    )
    assert result == {
        "response": {
            "choices": [
                {
                    "text": "appersstoff210 currentNodeleh norm unified_voice DIYHam",
                    "token_logprobs": [
                        -0.5613340599860608,
                        -1.2822870706137146,
                        -1.9909319620162806,
                        -0.6312373658222814,
                        -1.9066239705571664,
                        -1.2420939968397082,
                        -0.7208735169940805,
                        -1.9144266963723062,
                        -0.041181937860757856,
                        -0.5356282450367043,
                    ],
                    "tokens": [
                        "28921",
                        "81056",
                        "8848",
                        "47399",
                        "74890",
                        "7617",
                        "43790",
                        "77865",
                        "32558",
                        "41041",
                    ],
                },
                {
                    "text": ".addAttribute_size DE imageUrl_datas\tapFixed(hour setups\tcomment",  # noqa: E501
                    "token_logprobs": [
                        -1.1142500072582333,
                        -0.819706434396527,
                        -1.9956443391600693,
                        -0.8425896744807639,
                        -1.8398050571245623,
                        -1.912564137256891,
                        -1.6677665162080606,
                        -1.1579612203844727,
                        -1.9876114502998343,
                        -0.2698297864722319,
                    ],
                    "tokens": [
                        "26300",
                        "2424",
                        "3467",
                        "40749",
                        "47630",
                        "70998",
                        "13829",
                        "72135",
                        "84823",
                        "97368",
                    ],
                },
            ]
        },
        "usages": {
            "usages": [
                {"completion_tokens": 10, "prompt_tokens": 4, "total_tokens": 14},
                {"completion_tokens": 10, "prompt_tokens": 4, "total_tokens": 14},
            ]
        },
        "cached": False,
        "request": {
            "prompt": ["Hello is a prompt", "Hello is another prompt"],
            "engine": "text-davinci-003",
            "n": 1,
            "client_timeout": 60,
            "run_id": None,
            "batch_size": 20,
            "temperature": 0.0,
            "max_tokens": 10,
            "top_p": 1.0,
            "top_k": 1,
            "logprobs": None,
            "stop_sequences": None,
            "num_beams": 1,
            "do_sample": False,
            "repetition_penalty": 1.0,
            "length_penalty": 1.0,
            "presence_penalty": 0.0,
            "frequency_penalty": 0.0,
        },
        "response_type": "text",
        "request_type": "LMScoreRequest",
        "item_dtype": None,
    }


@pytest.mark.skipif(not MODEL_ALIVE, reason=f"No model at {URL}")
def test_local_huggingface(sqlite_cache: str) -> None:
    """Test local huggingface client."""
    client = Manifest(
        client_name="huggingface",
        client_connection=URL,
        cache_name="sqlite",
        cache_connection=sqlite_cache,
    )
    res = client.run("Why are there apples?")
    assert isinstance(res, str) and len(res) > 0
    response = cast(Response, client.run("Why are there apples?", return_response=True))
    assert isinstance(response.get_response(), str) and len(response.get_response()) > 0
    assert response.is_cached() is True
    response = cast(Response, client.run("Why are there apples?", return_response=True))
    assert response.is_cached() is True

    res_list = client.run(["Why are there apples?", "Why are there bananas?"])
    assert isinstance(res_list, list) and len(res_list) == 2
    response = cast(
        Response, client.run("Why are there bananas?", return_response=True)
    )
    assert response.is_cached() is True

    res_list = asyncio.run(
        client.arun_batch(["Why are there pears?", "Why are there oranges?"])
    )
    assert isinstance(res_list, list) and len(res_list) == 2
    response = cast(
        Response, client.run("Why are there oranges?", return_response=True)
    )
    assert response.is_cached() is True
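
    # score_prompt returns a serialized Response dict with one logprob per
    # generated token, so the token and logprob lists must stay aligned.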
    scores = client.score_prompt("Why are there apples?")
    assert isinstance(scores, dict) and len(scores) > 0
    assert scores["cached"] is False
    assert len(scores["response"]["choices"][0]["token_logprobs"]) == len(
        scores["response"]["choices"][0]["tokens"]
    )

    scores = client.score_prompt(["Why are there apples?", "Why are there bananas?"])
    assert isinstance(scores, dict) and len(scores) > 0
    assert scores["cached"] is True
    assert len(scores["response"]["choices"][0]["token_logprobs"]) == len(
        scores["response"]["choices"][0]["tokens"]
    )
    assert len(scores["response"]["choices"][1]["token_logprobs"]) == len(
        scores["response"]["choices"][1]["tokens"]
    )


@pytest.mark.skipif(not MODEL_ALIVE, reason=f"No model at {URL}")
def test_local_huggingfaceembedding(sqlite_cache: str) -> None:
    """Test local huggingface embedding client."""
    client = Manifest(
        client_name="huggingfaceembedding",
        client_connection=URL,
        cache_name="sqlite",
        cache_connection=sqlite_cache,
    )
    res = client.run("Why are there carrots?")
    assert isinstance(res, np.ndarray)
    response = cast(
        Response, client.run("Why are there carrots?", return_response=True)
    )
    assert isinstance(response.get_response(), np.ndarray)
    assert np.allclose(response.get_response(), res)

    client = Manifest(
        client_name="huggingfaceembedding",
        client_connection=URL,
        cache_name="sqlite",
        cache_connection=sqlite_cache,
    )
    res = client.run("Why are there apples?")
    assert isinstance(res, np.ndarray)
    response = cast(Response, client.run("Why are there apples?", return_response=True))
    assert isinstance(response.get_response(), np.ndarray)
    assert np.allclose(response.get_response(), res)
    assert response.is_cached() is True
    response = cast(Response, client.run("Why are there apples?", return_response=True))
    assert response.is_cached() is True

    res_list = client.run(["Why are there apples?", "Why are there bananas?"])
    assert (
        isinstance(res_list, list)
        and len(res_list) == 2
        and isinstance(res_list[0], np.ndarray)
    )
    response = cast(
        Response,
        client.run(
            ["Why are there apples?", "Why are there mangos?"], return_response=True
        ),
    )
    assert (
        isinstance(response.get_response(), list) and len(response.get_response()) == 2
    )
    response = cast(
        Response, client.run("Why are there bananas?", return_response=True)
    )
    assert response.is_cached() is True
    response = cast(
        Response, client.run("Why are there oranges?", return_response=True)
    )
    assert response.is_cached() is False

    res_list = asyncio.run(
        client.arun_batch(["Why are there pears?", "Why are there oranges?"])
    )
    assert (
        isinstance(res_list, list)
        and len(res_list) == 2
        and isinstance(res_list[0], np.ndarray)
    )
    response = cast(
        Response,
        asyncio.run(
            client.arun_batch(
                ["Why are there pinenuts?", "Why are there cocoa?"],
                return_response=True,
            )
        ),
    )
    assert (
        isinstance(response.get_response(), list)
        and len(response.get_response()) == 2
        and isinstance(response.get_response()[0], np.ndarray)
    )
    response = cast(
        Response, client.run("Why are there oranges?", return_response=True)
    )
    assert response.is_cached() is True


@pytest.mark.skipif(not OPENAI_ALIVE, reason="No openai key set")
def test_openai(sqlite_cache: str) -> None:
    """Test openai client."""
    client = Manifest(
        client_name="openai",
        engine="text-ada-001",
        cache_name="sqlite",
        cache_connection=sqlite_cache,
        temperature=0.0,
    )
    res = client.run("Why are there apples?")
    assert isinstance(res, str) and len(res) > 0
    response = cast(Response, client.run("Why are there apples?", return_response=True))
    assert isinstance(response.get_response(), str) and len(response.get_response()) > 0
    assert response.get_response() == res
    assert response.is_cached() is True
    assert response.get_usage_obj().usages
    assert response.get_usage_obj().usages[0].total_tokens == 15
    response = cast(Response, client.run("Why are there apples?", return_response=True))
    assert response.is_cached() is True

    res_list = client.run(["Why are there apples?", "Why are there bananas?"])
    assert isinstance(res_list, list) and len(res_list) == 2
    response = cast(
        Response,
        client.run(
            ["Why are there apples?", "Why are there mangos?"], return_response=True
        ),
    )
    assert (
        isinstance(response.get_response(), list) and len(response.get_response()) == 2
    )
    assert response.get_usage_obj().usages and len(response.get_usage_obj().usages) == 2
    assert response.get_usage_obj().usages[0].total_tokens == 15
    assert response.get_usage_obj().usages[1].total_tokens == 16
    response = cast(
        Response, client.run("Why are there bananas?", return_response=True)
    )
    assert response.is_cached() is True

    res_list = asyncio.run(
        client.arun_batch(["Why are there pears?", "Why are there oranges?"])
    )
    assert isinstance(res_list, list) and len(res_list) == 2
    response = cast(
        Response,
        asyncio.run(
            client.arun_batch(
                ["Why are there pinenuts?", "Why are there cocoa?"],
                return_response=True,
            )
        ),
    )
    assert (
        isinstance(response.get_response(), list) and len(response.get_response()) == 2
    )
    assert response.get_usage_obj().usages and len(response.get_usage_obj().usages) == 2
    assert response.get_usage_obj().usages[0].total_tokens == 17
    assert response.get_usage_obj().usages[1].total_tokens == 15
    response = cast(
        Response, client.run("Why are there oranges?", return_response=True)
    )
    assert response.is_cached() is True

    # Test streaming
    num_responses = 0
    streaming_response_text = cast(
        Iterator[str], client.run("Why are there oranges?", stream=True)
    )
    for res_text in streaming_response_text:
        num_responses += 1
        assert isinstance(res_text, str) and len(res_text) > 0
    assert num_responses == 8
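
    # The exact chunk counts asserted in this streaming section depend on how
    # the live API happens to split this completion into stream events, so
    # they are inherently brittle and may drift as the model changes.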
    streaming_response = cast(
        Iterator[Response],
        client.run("Why are there mandarines?", return_response=True, stream=True),
    )
    num_responses = 0
    merged_res = []
    for res in streaming_response:
        num_responses += 1
        assert isinstance(res, Response) and len(res.get_response()) > 0
        merged_res.append(cast(str, res.get_response()))
        assert not res.is_cached()
    assert num_responses == 10

    # Make sure cached
    streaming_response = cast(
        Iterator[Response],
        client.run("Why are there mandarines?", return_response=True, stream=True),
    )
    num_responses = 0
    merged_res_cached = []
    for res in streaming_response:
        num_responses += 1
        assert isinstance(res, Response) and len(res.get_response()) > 0
        merged_res_cached.append(cast(str, res.get_response()))
        assert res.is_cached()
    # OpenAI stream does not return logprobs, so this is by number of words
    assert num_responses == 7
    assert "".join(merged_res) == "".join(merged_res_cached)


@pytest.mark.skipif(not OPENAI_ALIVE, reason="No openai key set")
def test_openaichat(sqlite_cache: str) -> None:
    """Test openaichat client."""
    client = Manifest(
        client_name="openaichat",
        cache_name="sqlite",
        cache_connection=sqlite_cache,
        temperature=0.0,
    )
    res = client.run("Why are there apples?")
    assert isinstance(res, str) and len(res) > 0
    response = cast(Response, client.run("Why are there apples?", return_response=True))
    assert isinstance(response.get_response(), str) and len(response.get_response()) > 0
    assert response.get_response() == res
    assert response.is_cached() is True
    assert response.get_usage_obj().usages
    assert response.get_usage_obj().usages[0].total_tokens == 23
    response = cast(Response, client.run("Why are there apples?", return_response=True))
    assert response.is_cached() is True
    response = cast(
        Response, client.run("Why are there oranges?", return_response=True)
    )
    assert response.is_cached() is False

    res_list = asyncio.run(
        client.arun_batch(["Why are there pears?", "Why are there oranges?"])
    )
    assert isinstance(res_list, list) and len(res_list) == 2
    response = cast(
        Response,
        asyncio.run(
            client.arun_batch(
                ["Why are there pinenuts?", "Why are there cocoa?"],
                return_response=True,
            )
        ),
    )
    assert (
        isinstance(response.get_response(), list) and len(response.get_response()) == 2
    )
    assert response.get_usage_obj().usages and len(response.get_usage_obj().usages) == 2
    assert response.get_usage_obj().usages[0].total_tokens == 25
    assert response.get_usage_obj().usages[1].total_tokens == 23
    response = cast(
        Response, client.run("Why are there oranges?", return_response=True)
    )
    assert response.is_cached() is True

    chat_dict = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Who won the world series in 2020?"},
        {
            "role": "assistant",
            "content": "The Los Angeles Dodgers won the World Series in 2020.",
        },
        {"role": "user", "content": "Where was it played?"},
    ]
    res = client.run(chat_dict)
    assert isinstance(res, str) and len(res) > 0
    response = cast(Response, client.run(chat_dict, return_response=True))
    assert response.is_cached() is True
    assert response.get_usage_obj().usages[0].total_tokens == 67
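
    # Changing the system prompt (the misspelling below is deliberate) changes
    # the cache key, so the next call must be a cache miss.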
    chat_dict = [
        {"role": "system", "content": "You are a helpful assistanttttt."},
        {"role": "user", "content": "Who won the world series in 2020?"},
        {
            "role": "assistant",
            "content": "The Los Angeles Dodgers won the World Series in 2020.",
        },
        {"role": "user", "content": "Where was it played?"},
    ]
    response = cast(Response, client.run(chat_dict, return_response=True))
    assert response.is_cached() is False

    # Test streaming
    num_responses = 0
    streaming_response_text = cast(
        Iterator[str], client.run("Why are there oranges?", stream=True)
    )
    for res_text in streaming_response_text:
        num_responses += 1
        assert isinstance(res_text, str) and len(res_text) > 0
    assert num_responses == 9
    streaming_response = cast(
        Iterator[Response],
        client.run("Why are there mandarines?", return_response=True, stream=True),
    )
    num_responses = 0
    merged_res = []
    for res in streaming_response:
        num_responses += 1
        assert isinstance(res, Response) and len(res.get_response()) > 0
        merged_res.append(cast(str, res.get_response()))
        assert not res.is_cached()
    assert num_responses == 10

    # Make sure cached
    streaming_response = cast(
        Iterator[Response],
        client.run("Why are there mandarines?", return_response=True, stream=True),
    )
    num_responses = 0
    merged_res_cached = []
    for res in streaming_response:
        num_responses += 1
        assert isinstance(res, Response) and len(res.get_response()) > 0
        merged_res_cached.append(cast(str, res.get_response()))
        assert res.is_cached()
    # OpenAI stream does not return logprobs, so this is by number of words
    assert num_responses == 7
    assert "".join(merged_res) == "".join(merged_res_cached)


@pytest.mark.skipif(not OPENAI_ALIVE, reason="No openai key set")
def test_openaiembedding(sqlite_cache: str) -> None:
    """Test openai embedding client."""
    client = Manifest(
        client_name="openaiembedding",
        cache_name="sqlite",
        cache_connection=sqlite_cache,
        array_serializer="local_file",
    )
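
    # array_serializer chooses how numpy embeddings are persisted in the
    # cache; "local_file" appears to write arrays to on-disk files rather
    # than inlining them (the second client below uses the default
    # serializer).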
    res = client.run("Why are there carrots?")
    assert isinstance(res, np.ndarray)
    response = cast(
        Response, client.run("Why are there carrots?", return_response=True)
    )
    assert isinstance(response.get_response(), np.ndarray)
    assert np.allclose(response.get_response(), res)

    client = Manifest(
        client_name="openaiembedding",
        cache_name="sqlite",
        cache_connection=sqlite_cache,
    )
    res = client.run("Why are there apples?")
    assert isinstance(res, np.ndarray)
    response = cast(Response, client.run("Why are there apples?", return_response=True))
    assert isinstance(response.get_response(), np.ndarray)
    assert np.allclose(response.get_response(), res)
    assert response.is_cached() is True
    assert response.get_usage_obj().usages
    assert response.get_usage_obj().usages[0].total_tokens == 5
    response = cast(Response, client.run("Why are there apples?", return_response=True))
    assert response.is_cached() is True

    res_list = client.run(["Why are there apples?", "Why are there bananas?"])
    assert (
        isinstance(res_list, list)
        and len(res_list) == 2
        and isinstance(res_list[0], np.ndarray)
    )
    response = cast(
        Response,
        client.run(
            ["Why are there apples?", "Why are there mangos?"], return_response=True
        ),
    )
    assert (
        isinstance(response.get_response(), list) and len(response.get_response()) == 2
    )
    assert response.get_usage_obj().usages and len(response.get_usage_obj().usages) == 2
    assert response.get_usage_obj().usages[0].total_tokens == 5
    assert response.get_usage_obj().usages[1].total_tokens == 6
    response = cast(
        Response, client.run("Why are there bananas?", return_response=True)
    )
    assert response.is_cached() is True
    response = cast(
        Response, client.run("Why are there oranges?", return_response=True)
    )
    assert response.is_cached() is False

    res_list = asyncio.run(
        client.arun_batch(["Why are there pears?", "Why are there oranges?"])
    )
    assert (
        isinstance(res_list, list)
        and len(res_list) == 2
        and isinstance(res_list[0], np.ndarray)
    )
    response = cast(
        Response,
        asyncio.run(
            client.arun_batch(
                ["Why are there pinenuts?", "Why are there cocoa?"],
                return_response=True,
            )
        ),
    )
    assert (
        isinstance(response.get_response(), list)
        and len(response.get_response()) == 2
        and isinstance(response.get_response()[0], np.ndarray)
    )
    assert response.get_usage_obj().usages and len(response.get_usage_obj().usages) == 2
    assert response.get_usage_obj().usages[0].total_tokens == 7
    assert response.get_usage_obj().usages[1].total_tokens == 5
    response = cast(
        Response, client.run("Why are there oranges?", return_response=True)
    )
    assert response.is_cached() is True


@pytest.mark.skipif(not OPENAI_ALIVE, reason="No openai key set")
def test_openai_pool(sqlite_cache: str) -> None:
    """Test pooled openai and openaichat clients."""
    client_connection1 = ClientConnection(
        client_name="openaichat",
    )
    client_connection2 = ClientConnection(client_name="openai", engine="text-ada-001")
    client = Manifest(
        client_pool=[client_connection1, client_connection2],
        cache_name="sqlite",
        cache_connection=sqlite_cache,
    )
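
    # The pool is expected to rotate requests across its clients, so
    # back-to-back identical prompts land on different engines and produce
    # different completions, while each engine keeps its own cache entries.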
    res = client.run("Why are there apples?")
    assert isinstance(res, str) and len(res) > 0
    res2 = client.run("Why are there apples?")
    assert isinstance(res2, str) and len(res2) > 0
    # Different models
    assert res != res2
    assert cast(
        Response, client.run("Why are there apples?", return_response=True)
    ).is_cached()

    res_list = asyncio.run(
        client.arun_batch(["Why are there pears?", "Why are there oranges?"])
    )
    assert isinstance(res_list, list) and len(res_list) == 2
    res_list2 = asyncio.run(
        client.arun_batch(["Why are there pears?", "Why are there oranges?"])
    )
    assert isinstance(res_list2, list) and len(res_list2) == 2
    # Different models
    assert res_list != res_list2
    assert cast(
        Response,
        asyncio.run(
            client.arun_batch(
                ["Why are there pears?", "Why are there oranges?"], return_response=True
            )
        ),
    ).is_cached()

    # Test chunk size of 1
    res_list = asyncio.run(
        client.arun_batch(
            ["Why are there pineapples?", "Why are there pinecones?"], chunk_size=1
        )
    )
    assert isinstance(res_list, list) and len(res_list) == 2
    res_list2 = asyncio.run(
        client.arun_batch(
            ["Why are there pineapples?", "Why are there pinecones?"], chunk_size=1
        )
    )
    # Because we split across both models exactly in first run,
    # we will get the same result
    assert res_list == res_list2


@pytest.mark.skipif(not MODEL_ALIVE, reason=f"No model at {URL}")
@pytest.mark.skipif(not OPENAI_ALIVE, reason="No openai key set")
def test_mixed_pool(sqlite_cache: str) -> None:
    """Test a mixed pool of huggingface and openai clients."""
    client_connection1 = ClientConnection(
        client_name="huggingface",
        client_connection=URL,
    )
    client_connection2 = ClientConnection(client_name="openai", engine="text-ada-001")
    client = Manifest(
        client_pool=[client_connection1, client_connection2],
        cache_name="sqlite",
        cache_connection=sqlite_cache,
    )
    res = client.run("Why are there apples?")
    assert isinstance(res, str) and len(res) > 0
    res2 = client.run("Why are there apples?")
    assert isinstance(res2, str) and len(res2) > 0
    # Different models
    assert res != res2
    assert cast(
        Response, client.run("Why are there apples?", return_response=True)
    ).is_cached()

    res_list = asyncio.run(
        client.arun_batch(["Why are there pears?", "Why are there oranges?"])
    )
    assert isinstance(res_list, list) and len(res_list) == 2
    res_list2 = asyncio.run(
        client.arun_batch(["Why are there pears?", "Why are there oranges?"])
    )
    assert isinstance(res_list2, list) and len(res_list2) == 2
    # Different models
    assert res_list != res_list2
    assert cast(
        Response,
        asyncio.run(
            client.arun_batch(
                ["Why are there pears?", "Why are there oranges?"], return_response=True
            )
        ),
    ).is_cached()

    # Test chunk size of 1
    res_list = asyncio.run(
        client.arun_batch(
            ["Why are there pineapples?", "Why are there pinecones?"], chunk_size=1
        )
    )
    assert isinstance(res_list, list) and len(res_list) == 2
    res_list2 = asyncio.run(
        client.arun_batch(
            ["Why are there pineapples?", "Why are there pinecones?"], chunk_size=1
        )
    )
    # Because we split across both models exactly in first run,
    # we will get the same result
    assert res_list == res_list2


def test_retry_handling() -> None:
    """Test retry handling."""
    # We'll mock the response so we won't need a real connection
    client = Manifest(client_name="openai", client_connection="fake")
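
    # The first mocked POST raises a retryable 429; the second returns a valid
    # 200 payload, so the client should succeed after a single retry.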
    mock_create = MagicMock(
        side_effect=[
            # raise a 429 error
            HTTPError(
                response=Mock(status_code=429, json=Mock(return_value={})),
                request=Mock(),
            ),
            # get a valid http response with a 200 status code
            Mock(
                status_code=200,
                json=Mock(
                    return_value={
                        "choices": [
                            {
                                "finish_reason": "length",
                                "index": 0,
                                "logprobs": None,
                                "text": " WHATTT.",
                            },
                            {
                                "finish_reason": "length",
                                "index": 1,
                                "logprobs": None,
                                "text": " UH OH.",
                            },
                            {
                                "finish_reason": "length",
                                "index": 2,
                                "logprobs": None,
                                "text": " HARG",
                            },
                        ],
                        "created": 1679469056,
                        "id": "cmpl-6wmuWfmyuzi68B6gfeNC0h5ywxXL5",
                        "model": "text-ada-001",
                        "object": "text_completion",
                        "usage": {
                            "completion_tokens": 30,
                            "prompt_tokens": 24,
                            "total_tokens": 54,
                        },
                    }
                ),
            ),
        ]
    )
    prompts = [
        "The sky is purple. This is because",
        "The sky is magnet. This is because",
        "The sky is fuzzy. This is because",
    ]
    with patch("manifest.clients.client.requests.post", mock_create):
        # Run manifest
        result = client.run(prompts, temperature=0, overwrite_cache=True)
        assert result == [" WHATTT.", " UH OH.", " HARG"]
    # Assert that OpenAI client was called twice
    assert mock_create.call_count == 2

    # Now make sure it errors when not a 429 or 500
    mock_create = MagicMock(
        side_effect=[
            # raise a 505 error
            HTTPError(
                response=Mock(status_code=505, json=Mock(return_value={})),
                request=Mock(),
            ),
        ]
    )
    with patch("manifest.clients.client.requests.post", mock_create):
        # Run manifest
        with pytest.raises(HTTPError):
            client.run(prompts, temperature=0, overwrite_cache=True)
    # Assert that OpenAI client was called once
    assert mock_create.call_count == 1