|
import httpx |
|
|
|
from fastapi import APIRouter, HTTPException, Header, Request, Response |
|
from httpx_sse import EventSource |
|
from sse_starlette.sse import EventSourceResponse |
|
from starlette.background import BackgroundTask |
|
from app.internal.constants import DUCKDUCKGO_CHAT_ENDPOINT, DUCKDUCKGO_STATUS_ENDPOINT, SESSION_KEY |
|
from typing import Annotated, AsyncIterator, List, Literal |
|
from pydantic import BaseModel |
|
|
|
# Sentinel payload that terminates an event stream: the upstream stream is
# considered finished when an event carries it, and the same value is emitted
# as the final event of the stream this service produces.
DONE = '[DONE]'
|
|
|
|
|
# One turn of a conversation.
# Documented with comments rather than a docstring so the generated JSON
# schema (which picks up model docstrings as descriptions) stays unchanged.
class Message(BaseModel):
    # Author of the turn; only these two roles are accepted.
    role: Literal['assistant', 'user']
    # Text of the turn.
    content: str
|
|
|
|
|
# Request body accepted by the chat completions endpoint.
# Documented with comments rather than a docstring so the generated JSON
# schema (which picks up model docstrings as descriptions) stays unchanged.
class Chat(BaseModel):
    # Model identifier forwarded upstream; defaults to the GPT-3.5 variant.
    model: Literal['gpt-3.5-turbo-0125', 'claude-3-haiku-20240307'] = 'gpt-3.5-turbo-0125'
    # Conversation history, in the order supplied by the caller.
    messages: list[Message]
    # Selects between an event-stream reply and a single JSON reply.
    stream: bool = False

    # Example payload attached to the JSON schema of this model.
    model_config = {
        'json_schema_extra': {
            'examples': [
                {
                    'model': 'claude-3-haiku-20240307',
                    'messages': [{'role': 'user', 'content': 'Hello'}],
                },
            ],
        },
    }
|
|
|
|
|
# A single completion alternative; wraps exactly one message.
# (Comments instead of a docstring keep the generated JSON schema unchanged.)
class Choice(BaseModel):
    # The message produced for this alternative.
    message: Message
|
|
|
|
|
# JSON body returned for a non-streaming completion request.
# (Comments instead of a docstring keep the generated JSON schema unchanged.)
class CompletionsResult(BaseModel):
    # Completion alternatives; spelled with the builtin generic, like
    # ``Chat.messages``.
    choices: list[Choice]
|
|
|
|
|
# Router that collects the endpoints declared in this module.
router = APIRouter()
|
|
|
|
|
@router.post('/ddg/chat/completions',
             response_model=CompletionsResult,
             responses={
                 200: {
                     'content': {'text/event-stream': {}},
                     'description': 'Return the JSON completions result or an event stream.',
                 }
             })
async def chat(input: Chat, request: Request, response: Response, x_session_id: Annotated[str | None, Header()] = None):
    # Proxy a chat request to the DuckDuckGo chat endpoint.
    #
    # (Kept as comments rather than a docstring: FastAPI publishes an endpoint
    # docstring as the operation description in the OpenAPI schema.)
    #
    # Parameters:
    #   input         validated request body (model, messages, stream flag).
    #   request       incoming request; ``request.state.http_client`` must hold
    #                 the shared ``httpx.AsyncClient``.
    #   response      response object used to attach the ``x-session-id`` header.
    #   x_session_id  optional session id supplied by the caller as a header;
    #                 when absent a fresh one is requested from the status
    #                 endpoint.
    #
    # Returns an ``EventSourceResponse`` when ``input.stream`` is true, otherwise
    # a ``CompletionsResult`` holding the concatenated assistant reply.
    #
    # Raises ``HTTPException(400)`` when no session id can be obtained or when
    # the upstream chat endpoint answers with a non-200 status.
    http_client: httpx.AsyncClient = request.state.http_client

    session_id = x_session_id
    if not session_id:
        status_resp = await http_client.get(DUCKDUCKGO_STATUS_ENDPOINT, headers={'x-vqd-accept': '1'})
        session_id = status_resp.headers.get(SESSION_KEY)
    if not session_id:
        # A missing header would otherwise be forwarded as a ``None`` header
        # value and blow up while building the upstream request.
        raise HTTPException(status_code=400)

    req = http_client.build_request('POST', DUCKDUCKGO_CHAT_ENDPOINT,
                                    json=input.model_dump(exclude={'stream'}),
                                    headers={SESSION_KEY: session_id})
    resp = await http_client.send(req, stream=True)

    if resp.status_code != 200:
        # The response was opened in streaming mode, so it has to be closed
        # explicitly to release the connection before bailing out.
        await resp.aclose()
        raise HTTPException(status_code=400)

    async def agenerator() -> AsyncIterator[str]:
        # Yield the ``message`` field of every upstream event until the DONE
        # sentinel arrives; events without that field are skipped.
        async for event in EventSource(resp).aiter_sse():
            if event.data == DONE:
                return

            if 'message' in (decoded := event.json()):
                yield decoded['message']

    async def event_generator():
        # Re-wrap every upstream chunk as an outgoing event and finish with
        # the DONE sentinel.
        # NOTE(review): sse-starlette appears to render non-string ``data``
        # with ``str()``, so this dict would reach clients as a Python repr
        # rather than JSON - confirm, and consider ``json.dumps`` if clients
        # expect JSON.
        async for chunk in agenerator():
            yield {
                'data': {
                    'choices': [{
                        'delta': chunk
                    }]
                }
            }

        yield DONE

    # Hand the (possibly rotated) session id back to the caller, but only when
    # upstream actually sent one: a header value cannot be ``None``.
    upstream_session = resp.headers.get(SESSION_KEY)
    if upstream_session is not None:
        response.headers['x-session-id'] = upstream_session

    if input.stream:
        # The upstream response is closed by the background task once the
        # event stream has been sent.
        return EventSourceResponse(event_generator(), background=BackgroundTask(resp.aclose), headers=response.headers)

    try:
        parts = [chunk async for chunk in agenerator()]
    finally:
        # Nothing else closes the streamed upstream response on this path.
        await resp.aclose()

    return CompletionsResult(choices=[Choice(message=Message(role='assistant', content=''.join(parts)))])
|
|