File size: 3,583 Bytes
8ff63e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from __future__ import annotations

import json
import contextlib
from uuid import uuid4, UUID
from typing import Generator, Literal

import requests
import gradio as gr

from spitfight.colosseum.common import (
    COLOSSEUM_PROMPT_ROUTE,
    COLOSSEUM_RESP_VOTE_ROUTE,
    COLOSSEUM_ENERGY_VOTE_ROUTE,
    PromptRequest,
    ResponseVoteRequest,
    ResponseVoteResponse,
    EnergyVoteRequest,
    EnergyVoteResponse,
)


class ControllerClient:
    """Client for the Colosseum controller, to be used by Gradio.

    Every request sent by a client carries that client's `request_id`.
    Use `fork` to obtain a client with the same address and timeout but a
    fresh request ID.
    """

    def __init__(self, controller_addr: str, timeout: int = 15, request_id: UUID | None = None) -> None:
        """Initialize the controller client.

        Args:
            controller_addr: Address of the controller without the scheme; it is
                interpolated as `http://{controller_addr}{route}`.
            timeout: Timeout passed to every `requests.post` call.
            request_id: ID attached to every request. A random UUID4 is
                generated when this is None.
        """
        self.controller_addr = controller_addr
        self.timeout = timeout
        # `str(None)` is the truthy string "None", so the fallback has to be
        # chosen before converting to a string.
        self.request_id = str(uuid4() if request_id is None else request_id)

    def fork(self) -> ControllerClient:
        """Return a copy of the client with a new request ID."""
        return ControllerClient(
            controller_addr=self.controller_addr,
            timeout=self.timeout,
            request_id=uuid4(),
        )

    def prompt(self, prompt: str, index: Literal[0, 1]) -> Generator[str, None, None]:
        """Generate the response of the `index`th model with the prompt.

        Yields:
            Each non-empty NUL-delimited chunk of the streamed response body,
            decoded as UTF-8 and parsed as JSON.

        Raises:
            gr.Error: If the controller cannot be reached, the request times
                out, or the controller answers with a 4xx/5xx status.
        """
        prompt_request = PromptRequest(request_id=self.request_id, prompt=prompt, model_index=index)
        with _catch_requests_exceptions():
            resp = requests.post(
                f"http://{self.controller_addr}{COLOSSEUM_PROMPT_ROUTE}",
                json=prompt_request.dict(),
                stream=True,
                timeout=self.timeout,
            )
        _check_response(resp)
        # XXX: Why can't the server just yield `text + "\n"` and here we just iter_lines?
        for chunk in resp.iter_lines(decode_unicode=False, delimiter=b"\0"):
            # Skip empty chunks; only non-empty ones are parsed as JSON.
            if chunk:
                yield json.loads(chunk.decode("utf-8"))

    def response_vote(self, victory_index: Literal[0, 1]) -> ResponseVoteResponse:
        """Notify the controller of the user's vote for the response.

        Returns:
            The controller's JSON reply unpacked into a `ResponseVoteResponse`.

        Raises:
            gr.Error: If the controller cannot be reached, the request times
                out, or the controller answers with a 4xx/5xx status.
        """
        response_vote_request = ResponseVoteRequest(request_id=self.request_id, victory_index=victory_index)
        with _catch_requests_exceptions():
            resp = requests.post(
                f"http://{self.controller_addr}{COLOSSEUM_RESP_VOTE_ROUTE}",
                json=response_vote_request.dict(),
                # Without a timeout, `requests` may wait forever on a stalled server.
                timeout=self.timeout,
            )
        _check_response(resp)
        return ResponseVoteResponse(**resp.json())

    def energy_vote(self, is_worth: bool) -> EnergyVoteResponse:
        """Notify the controller of the user's vote for energy.

        Returns:
            The controller's JSON reply unpacked into an `EnergyVoteResponse`.

        Raises:
            gr.Error: If the controller cannot be reached, the request times
                out, or the controller answers with a 4xx/5xx status.
        """
        energy_vote_request = EnergyVoteRequest(request_id=self.request_id, is_worth=is_worth)
        with _catch_requests_exceptions():
            resp = requests.post(
                f"http://{self.controller_addr}{COLOSSEUM_ENERGY_VOTE_ROUTE}",
                json=energy_vote_request.dict(),
                # Without a timeout, `requests` may wait forever on a stalled server.
                timeout=self.timeout,
            )
        _check_response(resp)
        return EnergyVoteResponse(**resp.json())


@contextlib.contextmanager
def _catch_requests_exceptions():
    """Catch requests exceptions and raise gr.Error instead."""
    try:
        yield
    except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
        raise gr.Error("Failed to connect to our the backend server. Please try again later.")


def _check_response(response: requests.Response) -> None:
    if 400 <= response.status_code < 500:
        raise gr.Error(response.json()["detail"])
    elif response.status_code >= 500:
        raise gr.Error("Failed to talk to our backend server. Please try again later.")