Spaces:
Running
Running
File size: 3,607 Bytes
8ff63e4 315ec00 8ff63e4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
from __future__ import annotations
import json
import contextlib
from uuid import uuid4, UUID
from typing import Generator, Literal
import requests
import gradio as gr
from spitfight.colosseum.common import (
COLOSSEUM_PROMPT_ROUTE,
COLOSSEUM_RESP_VOTE_ROUTE,
COLOSSEUM_ENERGY_VOTE_ROUTE,
PromptRequest,
ResponseVoteRequest,
ResponseVoteResponse,
EnergyVoteRequest,
EnergyVoteResponse,
)
class ControllerClient:
    """Client for the Colosseum controller, to be used by Gradio.

    Every request an instance sends carries the same ``request_id``;
    :meth:`fork` returns a copy with a freshly generated one.
    """

    def __init__(self, controller_addr: str, timeout: int = 15, request_id: UUID | None = None) -> None:
        """Initialize the controller client.

        Args:
            controller_addr: Address of the controller; requests are sent to
                ``http://{controller_addr}{route}``.
            timeout: Timeout in seconds passed to every HTTP request.
            request_id: ID attached to every request. A random UUID4 is
                generated when omitted. Stored as a string either way.
        """
        self.controller_addr = controller_addr
        self.timeout = timeout
        self.request_id = str(uuid4()) if request_id is None else str(request_id)

    def fork(self) -> ControllerClient:
        """Return a copy of the client with a new request ID."""
        return ControllerClient(
            controller_addr=self.controller_addr,
            timeout=self.timeout,
            request_id=uuid4(),
        )

    def prompt(self, prompt: str, index: Literal[0, 1]) -> Generator[str, None, None]:
        """Stream the response of the `index`th model to `prompt`.

        Yields:
            Each non-empty NUL-delimited chunk of the response body, JSON-decoded.

        Raises:
            gr.Error: If connecting to the controller fails or times out
                (also while the body is being streamed), or if the controller
                answers with a 4xx/5xx status.
        """
        prompt_request = PromptRequest(request_id=self.request_id, prompt=prompt, model_index=index)
        # Guard the whole stream, not just the initial POST: requests can raise
        # ConnectionError while iterating a streamed body (e.g. a read timeout).
        with _catch_requests_exceptions():
            # `stream=True` holds the connection open until the body is consumed
            # or the response is closed; the `with` closes it even when the caller
            # stops iterating early and the generator is closed.
            with requests.post(
                f"http://{self.controller_addr}{COLOSSEUM_PROMPT_ROUTE}",
                json=prompt_request.dict(),
                stream=True,
                timeout=self.timeout,
            ) as resp:
                _check_response(resp)
                # XXX: Why can't the server just yield `text + "\n"` and here we just iter_lines?
                for chunk in resp.iter_lines(decode_unicode=False, delimiter=b"\0"):
                    if chunk:
                        yield json.loads(chunk.decode("utf-8"))

    def response_vote(self, victory_index: Literal[0, 1]) -> ResponseVoteResponse:
        """Notify the controller of the user's vote for the response.

        Args:
            victory_index: Index (0 or 1) of the response the user voted for.

        Returns:
            The controller's reply, parsed into a ``ResponseVoteResponse``.

        Raises:
            gr.Error: If connecting fails or times out, or the controller
                answers with a 4xx/5xx status.
        """
        response_vote_request = ResponseVoteRequest(request_id=self.request_id, victory_index=victory_index)
        with _catch_requests_exceptions():
            resp = requests.post(
                f"http://{self.controller_addr}{COLOSSEUM_RESP_VOTE_ROUTE}",
                json=response_vote_request.dict(),
                # Without a timeout, requests may block forever on an unresponsive server.
                timeout=self.timeout,
            )
        _check_response(resp)
        return ResponseVoteResponse(**resp.json())

    def energy_vote(self, is_worth: bool) -> EnergyVoteResponse:
        """Notify the controller of the user's vote for energy.

        Args:
            is_worth: The user's energy vote, sent to the controller as ``is_worth``.

        Returns:
            The controller's reply, parsed into an ``EnergyVoteResponse``.

        Raises:
            gr.Error: If connecting fails or times out, or the controller
                answers with a 4xx/5xx status.
        """
        energy_vote_request = EnergyVoteRequest(request_id=self.request_id, is_worth=is_worth)
        with _catch_requests_exceptions():
            resp = requests.post(
                f"http://{self.controller_addr}{COLOSSEUM_ENERGY_VOTE_ROUTE}",
                json=energy_vote_request.dict(),
                # Without a timeout, requests may block forever on an unresponsive server.
                timeout=self.timeout,
            )
        _check_response(resp)
        return EnergyVoteResponse(**resp.json())
@contextlib.contextmanager
def _catch_requests_exceptions() -> Generator[None, None, None]:
    """Translate `requests` connection failures into a user-facing `gr.Error`.

    Only ``requests.exceptions.ConnectionError`` and
    ``requests.exceptions.Timeout`` are translated; every other exception
    (including a ``gr.Error`` raised inside the block) propagates unchanged.

    Raises:
        gr.Error: If the wrapped code raises a connection error or timeout.
    """
    try:
        yield
    except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as err:
        # Chain the original error so server logs keep the underlying cause.
        raise gr.Error("Failed to connect to our backend server. Please try again later.") from err
def _error_detail(response: requests.Response) -> str | None:
    """Return the ``detail`` field of a JSON error body, or None if unavailable.

    Error bodies are not guaranteed to be JSON (e.g. an HTML page from a proxy,
    or an empty body), and ``detail`` may be missing or not a string, so every
    such case yields None (or a stringified value) instead of raising.
    """
    try:
        body = response.json()
    except ValueError:
        # Not JSON. JSON decode errors (stdlib's and requests') subclass ValueError.
        return None
    if not isinstance(body, dict):
        return None
    detail = body.get("detail")
    if detail is None:
        return None
    # gr.Error expects a message string; e.g. validation errors may be a list.
    return str(detail)


def _check_response(response: requests.Response) -> None:
    """Raise a user-facing `gr.Error` if `response` has an error status code.

    - 4xx: the body's ``detail`` message is shown when present; otherwise a
      generic message including the status code.
    - 5xx: a generic "try again later" message.
    - Any other status: returns None.
    """
    status = response.status_code
    if 400 <= status < 500:
        detail = _error_detail(response)
        if detail is None:
            detail = f"Our backend server rejected the request (HTTP {status})."
        raise gr.Error(detail)
    if status >= 500:
        raise gr.Error("Failed to talk to our backend server. Please try again later.")
|