v1-chat / app.py
Abhaykoul's picture
Update app.py
11a81ee verified
raw
history blame
11.1 kB
from fastapi import FastAPI, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from sse_starlette.sse import EventSourceResponse
from pydantic import BaseModel, Field
from typing import AsyncGenerator, Optional, List, Dict, Any
from enum import Enum
from datetime import datetime
import json
import aiohttp
from functools import lru_cache
import os
class Role(str, Enum):
    """Author of a chat message.

    The ``str`` mix-in makes each member compare equal to its plain string
    value, so members can be used wherever the raw role string is expected.
    """

    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
class ChatMessage(BaseModel):
    """A single chat message, also used as a streaming delta.

    Both fields are optional so that a partial message (role only, content
    only, or neither) can be represented.
    """

    role: Optional[Role] = None
    content: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return a dict holding only the fields that are not ``None``.

        Keys appear in the order ``role`` then ``content``.
        """
        candidates = (("role", self.role), ("content", self.content))
        return {key: value for key, value in candidates if value is not None}
class UsageInfo(BaseModel):
    """Optional usage figures for a completion.

    Every field defaults to ``None`` when not supplied.
    """

    # Token counters.
    prompt_tokens: Optional[int] = None
    completion_tokens: Optional[int] = None
    total_tokens: Optional[int] = None

    # Cost figure, kept as a float.
    estimated_cost: Optional[float] = None
class ChatCompletionChoice(BaseModel):
    """One choice of a non-streaming chat completion."""

    index: int
    message: ChatMessage
    finish_reason: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return the choice as a plain dict.

        The nested message is serialised through ``ChatMessage.to_dict``;
        ``finish_reason`` is always present, even when it is ``None``.
        """
        serialized: Dict[str, Any] = {}
        serialized['index'] = self.index
        serialized['message'] = self.message.to_dict()
        serialized['finish_reason'] = self.finish_reason
        return serialized
class ChatCompletionResponse(BaseModel):
    """Body of a non-streaming chat completion response."""

    id: str
    object: str
    created: int
    model: str
    choices: List[ChatCompletionChoice]
    # Passed through as a free-form mapping rather than a typed model.
    usage: Optional[Dict[str, Any]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return the response as a plain dict.

        Each choice is serialised through its own ``to_dict``; ``usage`` is
        copied by reference and may be ``None``.
        """
        serialized_choices = [item.to_dict() for item in self.choices]
        return dict(
            id=self.id,
            object=self.object,
            created=self.created,
            model=self.model,
            choices=serialized_choices,
            usage=self.usage,
        )
class ChatCompletionChunkChoice(BaseModel):
    """One choice inside a streamed completion chunk."""

    index: int
    delta: ChatMessage
    finish_reason: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return the choice as a plain dict.

        The delta is serialised through ``ChatMessage.to_dict``, so unset
        ``role``/``content`` are omitted from it; ``finish_reason`` is always
        included.
        """
        serialized: Dict[str, Any] = {}
        serialized['index'] = self.index
        serialized['delta'] = self.delta.to_dict()
        serialized['finish_reason'] = self.finish_reason
        return serialized
class ChatCompletionChunk(BaseModel):
    """One server-sent chunk of a streamed chat completion."""

    id: str
    object: str
    created: int
    model: str
    choices: List[ChatCompletionChunkChoice]

    def to_dict(self) -> Dict[str, Any]:
        """Return the chunk as a plain dict, serialising each choice in turn."""
        serialized_choices = [item.to_dict() for item in self.choices]
        return dict(
            id=self.id,
            object=self.object,
            created=self.created,
            model=self.model,
            choices=serialized_choices,
        )
class ChatRequest(BaseModel):
    """Incoming request body for the chat completions endpoint."""

    # Conversation so far; the only field without a default.
    messages: List[ChatMessage]
    # Model identifier forwarded upstream.
    model: str = Field("meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo")
    # Validated to lie within [0.0, 2.0].
    temperature: float = Field(0.7, ge=0.0, le=2.0)
    max_tokens: Optional[int] = Field(2048)
    # When true the endpoint answers with server-sent events.
    stream: bool = Field(False)
    # Forwarded upstream only when set to a non-empty mapping.
    response_format: Optional[Dict[str, str]] = None
class DeepInfraClient:
    """Async client for DeepInfra's OpenAI-compatible chat completions API.

    A new ``aiohttp.ClientSession`` is opened for every request, with a total
    timeout of 300 seconds. Failures are reported as ``HTTPException``:
    upstream non-200 replies keep the upstream status code, everything else
    becomes a 500.
    """

    def __init__(self, api_key: Optional[str] = None):
        """Store the endpoint URL and request headers.

        Args:
            api_key: Optional bearer token. When falsy, no ``Authorization``
                header is sent.
        """
        self.url = "https://api.deepinfra.com/v1/openai/chat/completions"
        self.headers = {
            "Accept": "text/event-stream, application/json",
            "Content-Type": "application/json"
        }
        if api_key:
            self.headers["Authorization"] = f"Bearer {api_key}"

    def _prepare_messages(self, messages: List[ChatMessage]) -> List[Dict[str, Any]]:
        """Convert message models into the plain dicts sent upstream."""
        return [message.to_dict() for message in messages]

    def _build_payload(self, request: ChatRequest, stream: bool) -> Dict[str, Any]:
        """Build the JSON body shared by the streaming and non-streaming calls."""
        payload: Dict[str, Any] = {
            "model": request.model,
            "messages": self._prepare_messages(request.messages),
            "temperature": request.temperature,
            "max_tokens": request.max_tokens,
            "stream": stream
        }
        # Only forwarded when the caller supplied a non-empty value.
        if request.response_format:
            payload["response_format"] = request.response_format
        return payload

    async def generate_stream(self, request: ChatRequest) -> AsyncGenerator:
        """Stream a completion as server-sent-event payload dicts.

        Yields:
            ``{"data": <json string>}`` for each upstream chunk, and finally
            ``{"data": "[DONE]"}`` when the upstream sentinel arrives.

        Raises:
            HTTPException: with the upstream status code when the upstream
                reply is not 200, or with status 500 on connection errors,
                chunk-processing errors and any other unexpected failure.
        """
        payload = self._build_payload(request, stream=True)
        timeout = aiohttp.ClientTimeout(total=300)
        try:
            async with aiohttp.ClientSession(timeout=timeout) as session:
                async with session.post(
                    self.url,
                    headers=self.headers,
                    json=payload,
                    chunked=True
                ) as response:
                    if response.status != 200:
                        error_msg = await response.text()
                        raise HTTPException(
                            status_code=response.status,
                            detail=f"API request failed: {error_msg}"
                        )
                    async for line in response.content:
                        if not line:
                            continue
                        try:
                            line = line.decode('utf-8').strip()
                            if not line:
                                continue
                            # Lines without the SSE "data: " prefix are ignored.
                            if line.startswith("data: "):
                                json_str = line[6:]
                                if json_str == "[DONE]":
                                    yield {"data": "[DONE]"}
                                    break
                                chunk = json.loads(json_str)
                                # Rebuild the chunk: object type, timestamp and
                                # model come from this side, not from upstream.
                                chunk_obj = ChatCompletionChunk(
                                    id=chunk["id"],
                                    object="chat.completion.chunk",
                                    created=int(datetime.now().timestamp()),
                                    model=request.model,
                                    choices=[
                                        ChatCompletionChunkChoice(
                                            index=choice["index"],
                                            delta=ChatMessage(**choice.get("delta", {})),
                                            finish_reason=choice.get("finish_reason")
                                        )
                                        for choice in chunk["choices"]
                                    ]
                                )
                                yield {"data": json.dumps(chunk_obj.to_dict())}
                        except json.JSONDecodeError:
                            # A data line that is not valid JSON is skipped.
                            continue
                        except Exception as e:
                            raise HTTPException(
                                status_code=500,
                                detail=f"Stream processing error: {str(e)}"
                            ) from e
        except HTTPException:
            # Already carries the intended status and detail; the generic
            # handler below must not rewrap it as a 500.
            raise
        except aiohttp.ClientError as e:
            raise HTTPException(
                status_code=500,
                detail=f"Connection error: {str(e)}"
            ) from e
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"Unexpected error: {str(e)}"
            ) from e

    async def generate(self, request: ChatRequest) -> ChatCompletionResponse:
        """Request a complete (non-streaming) chat completion.

        Returns:
            The upstream JSON body validated as ``ChatCompletionResponse``,
            with ``id``, ``object``, ``created`` and ``model`` filled in when
            upstream omitted them.

        Raises:
            HTTPException: with the upstream status code when the upstream
                reply is not 200, or with status 500 when the body is not a
                JSON object, cannot be parsed, the connection fails, or any
                other unexpected error occurs.
        """
        payload = self._build_payload(request, stream=False)
        timeout = aiohttp.ClientTimeout(total=300)
        try:
            async with aiohttp.ClientSession(timeout=timeout) as session:
                async with session.post(
                    self.url,
                    headers=self.headers,
                    json=payload
                ) as response:
                    if response.status != 200:
                        error_msg = await response.text()
                        raise HTTPException(
                            status_code=response.status,
                            detail=f"API request failed: {error_msg}"
                        )
                    try:
                        response_data = await response.json()
                    except json.JSONDecodeError as e:
                        raise HTTPException(
                            status_code=500,
                            detail=f"Failed to parse API response: {str(e)}"
                        ) from e
                    if not isinstance(response_data, dict):
                        raise HTTPException(
                            status_code=500,
                            detail="Invalid response format from API"
                        )
                    # Token counts that arrive as floats are coerced to int.
                    # A missing or null "usage" is left untouched.
                    usage_data = response_data.get('usage')
                    if isinstance(usage_data, dict):
                        for key in ('prompt_tokens', 'completion_tokens', 'total_tokens'):
                            if isinstance(usage_data.get(key), float):
                                usage_data[key] = int(usage_data[key])
                    # Ensure required fields are present
                    response_data.setdefault('id', str(datetime.now().timestamp()))
                    response_data.setdefault('object', 'chat.completion')
                    response_data.setdefault('created', int(datetime.now().timestamp()))
                    response_data.setdefault('model', request.model)
                    return ChatCompletionResponse(**response_data)
        except HTTPException:
            # Keep the status code and detail chosen above.
            raise
        except aiohttp.ClientError as e:
            raise HTTPException(
                status_code=500,
                detail=f"Connection error: {str(e)}"
            ) from e
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"Unexpected error: {str(e)}"
            ) from e
# ASGI application object; the route decorators further down attach to it.
app = FastAPI(title="DeepInfra OpenAI Compatible API")

# Fully permissive CORS policy.
# NOTE(review): wildcard origins together with allow_credentials=True is a
# very open configuration - confirm that this is intended for deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
@lru_cache()
def get_client():
    """Return the shared ``DeepInfraClient``.

    The function takes no arguments, so ``lru_cache`` hands back the same
    instance on every call after the first. The client is built without an
    API key.
    """
    shared_client = DeepInfraClient()
    return shared_client
@app.post("/v1/chat/completions")
async def create_chat_completion(
    request: ChatRequest,
    client: DeepInfraClient = Depends(get_client)
):
    """Handle a chat completion request.

    Args:
        request: Validated request body.
        client: Upstream client, injected through ``get_client``.

    Returns:
        An ``EventSourceResponse`` wrapping the client's stream when
        ``request.stream`` is true, otherwise the awaited completion.

    Raises:
        HTTPException: re-raised unchanged when the client raised one (so its
            status code is preserved); any other exception becomes a 500
            whose detail is the exception text.
    """
    try:
        if request.stream:
            # The generator is only created here; it runs while the response
            # is being sent, outside this try block.
            return EventSourceResponse(client.generate_stream(request))
        return await client.generate(request)
    except HTTPException:
        # Do not flatten a deliberate status code into a generic 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/v1/models")
async def list_models():
    """List the models named in the ``MODELS`` environment variable.

    ``MODELS`` is read as a comma-separated list. Surrounding whitespace is
    stripped and empty entries are dropped, so an unset or empty variable
    yields an empty list rather than a single model with an empty id.

    Returns:
        ``{"data": [...]}`` with one entry per model id; every entry carries
        the same ``created`` timestamp, taken at request time.
    """
    raw_models = os.getenv("MODELS", "")
    # "".split(",") returns [""], so blank entries must be filtered out.
    model_ids = [
        candidate.strip()
        for candidate in raw_models.split(",")
        if candidate.strip()
    ]
    current_timestamp = int(datetime.now().timestamp())
    return {
        "data": [
            {
                "id": model_id,
                "object": "model",
                "created": current_timestamp,
                "owned_by": "deepinfra"
            }
            for model_id in model_ids
        ]
    }
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when the file is run as a script.
    import uvicorn

    # Listen on every interface, port 8000.
    bind_host = "0.0.0.0"
    bind_port = 8000
    uvicorn.run(app, host=bind_host, port=bind_port)