import json
import os
from datetime import datetime
from enum import Enum
from functools import lru_cache
from typing import Any, AsyncGenerator, Dict, List, Optional

import aiohttp
from fastapi import Depends, FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse


class Role(str, Enum):
    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"


class ChatMessage(BaseModel):
    role: Optional[Role] = None
    content: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        # Emit only the fields that are set, so streaming deltas
        # (which may omit role or content) serialize compactly.
        message_dict = {}
        if self.role is not None:
            message_dict['role'] = self.role.value
        if self.content is not None:
            message_dict['content'] = self.content
        return message_dict


class UsageInfo(BaseModel):
    # Mirrors the `usage` block of an OpenAI-style response;
    # ChatCompletionResponse below carries usage as a plain dict.
    prompt_tokens: Optional[int] = None
    completion_tokens: Optional[int] = None
    total_tokens: Optional[int] = None
    estimated_cost: Optional[float] = None


class ChatCompletionChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        return {
            'index': self.index,
            'message': self.message.to_dict(),
            'finish_reason': self.finish_reason
        }


class ChatCompletionResponse(BaseModel):
    id: str
    object: str
    created: int
    model: str
    choices: List[ChatCompletionChoice]
    usage: Optional[Dict[str, Any]] = None

    def to_dict(self) -> Dict[str, Any]:
        return {
            'id': self.id,
            'object': self.object,
            'created': self.created,
            'model': self.model,
            'choices': [choice.to_dict() for choice in self.choices],
            'usage': self.usage
        }
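
# For reference, a serialized non-streaming response from these models follows
# the standard OpenAI chat completion shape (values illustrative):
#
#   {
#     "id": "chatcmpl-abc123",
#     "object": "chat.completion",
#     "created": 1700000000,
#     "model": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
#     "choices": [
#       {"index": 0,
#        "message": {"role": "assistant", "content": "Hello!"},
#        "finish_reason": "stop"}
#     ],
#     "usage": {"prompt_tokens": 9, "completion_tokens": 3, "total_tokens": 12}
#   }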


class ChatCompletionChunkChoice(BaseModel):
    index: int
    delta: ChatMessage
    finish_reason: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        return {
            'index': self.index,
            'delta': self.delta.to_dict(),
            'finish_reason': self.finish_reason
        }


class ChatCompletionChunk(BaseModel):
    id: str
    object: str
    created: int
    model: str
    choices: List[ChatCompletionChunkChoice]

    def to_dict(self) -> Dict[str, Any]:
        return {
            'id': self.id,
            'object': self.object,
            'created': self.created,
            'model': self.model,
            'choices': [choice.to_dict() for choice in self.choices]
        }


class ChatRequest(BaseModel):
    # OpenAI-style request body for /v1/chat/completions.
    messages: List[ChatMessage]
    model: str = Field(default="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo")
    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
    max_tokens: Optional[int] = Field(default=2048)
    stream: bool = Field(default=False)
    response_format: Optional[Dict[str, str]] = None
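
    # An illustrative request body this schema accepts (values are examples):
    #
    #   {
    #     "messages": [
    #       {"role": "system", "content": "You are a helpful assistant."},
    #       {"role": "user", "content": "Summarize SSE in one sentence."}
    #     ],
    #     "temperature": 0.7,
    #     "stream": true
    #   }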


class DeepInfraClient:
    def __init__(self, api_key: Optional[str] = None):
        self.url = "https://api.deepinfra.com/v1/openai/chat/completions"
        self.headers = {
            "Accept": "text/event-stream, application/json",
            "Content-Type": "application/json"
        }
        if api_key:
            self.headers["Authorization"] = f"Bearer {api_key}"

    def _prepare_messages(self, messages: List[ChatMessage]) -> List[Dict[str, Any]]:
        return [message.to_dict() for message in messages]

    async def generate_stream(self, request: ChatRequest) -> AsyncGenerator:
        payload = {
            "model": request.model,
            "messages": self._prepare_messages(request.messages),
            "temperature": request.temperature,
            "max_tokens": request.max_tokens,
            "stream": True
        }

        if request.response_format:
            payload["response_format"] = request.response_format

        timeout = aiohttp.ClientTimeout(total=300)

        try:
            async with aiohttp.ClientSession(timeout=timeout) as session:
                async with session.post(
                    self.url,
                    headers=self.headers,
                    json=payload
                ) as response:
                    if response.status != 200:
                        error_msg = await response.text()
                        raise HTTPException(
                            status_code=response.status,
                            detail=f"API request failed: {error_msg}"
                        )

                    # The upstream API streams Server-Sent Events; read the
                    # body line by line and forward each "data:" payload.
                    async for line in response.content:
                        if not line:
                            continue

                        try:
                            line = line.decode('utf-8').strip()
                            if not line:
                                continue

                            if line.startswith("data: "):
                                json_str = line[6:]
                                if json_str == "[DONE]":
                                    yield {"data": "[DONE]"}
                                    break

                                chunk = json.loads(json_str)
                                chunk_obj = ChatCompletionChunk(
                                    id=chunk["id"],
                                    object="chat.completion.chunk",
                                    created=int(datetime.now().timestamp()),
                                    model=request.model,
                                    choices=[
                                        ChatCompletionChunkChoice(
                                            index=choice["index"],
                                            delta=ChatMessage(**choice.get("delta", {})),
                                            finish_reason=choice.get("finish_reason")
                                        )
                                        for choice in chunk["choices"]
                                    ]
                                )
                                yield {"data": json.dumps(chunk_obj.to_dict())}
                        except json.JSONDecodeError:
                            # Skip partial or malformed SSE lines rather than
                            # aborting the whole stream.
                            continue
                        except Exception as e:
                            # Once streaming has begun the status line is already
                            # sent, so this surfaces as a broken stream rather
                            # than a clean 500 response.
                            raise HTTPException(
                                status_code=500,
                                detail=f"Stream processing error: {str(e)}"
                            )

        except HTTPException:
            # Preserve deliberate HTTP errors instead of rewrapping them as 500s.
            raise
        except aiohttp.ClientError as e:
            raise HTTPException(
                status_code=500,
                detail=f"Connection error: {str(e)}"
            )
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"Unexpected error: {str(e)}"
            )
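
    # Each dict yielded above becomes one SSE frame on the wire, e.g.
    # (chunk payload abbreviated):
    #
    #   data: {"id": "...", "object": "chat.completion.chunk", "created": ..., ...}
    #   data: [DONE]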

    async def generate(self, request: ChatRequest) -> ChatCompletionResponse:
        payload = {
            "model": request.model,
            "messages": self._prepare_messages(request.messages),
            "temperature": request.temperature,
            "max_tokens": request.max_tokens,
            "stream": False
        }

        if request.response_format:
            payload["response_format"] = request.response_format

        timeout = aiohttp.ClientTimeout(total=300)

        try:
            async with aiohttp.ClientSession(timeout=timeout) as session:
                async with session.post(
                    self.url,
                    headers=self.headers,
                    json=payload
                ) as response:
                    if response.status != 200:
                        error_msg = await response.text()
                        raise HTTPException(
                            status_code=response.status,
                            detail=f"API request failed: {error_msg}"
                        )

                    try:
                        response_data = await response.json()
                        if not isinstance(response_data, dict):
                            raise HTTPException(
                                status_code=500,
                                detail="Invalid response format from API"
                            )

                        # Some providers report token counts as floats; coerce
                        # them so the OpenAI-style schema stays integral.
                        if 'usage' in response_data:
                            usage_data = response_data['usage']
                            for key in ['prompt_tokens', 'completion_tokens', 'total_tokens']:
                                if key in usage_data and isinstance(usage_data[key], float):
                                    usage_data[key] = int(usage_data[key])

                        # Backfill any metadata the upstream response omitted.
                        response_data.setdefault('id', str(datetime.now().timestamp()))
                        response_data.setdefault('object', 'chat.completion')
                        response_data.setdefault('created', int(datetime.now().timestamp()))
                        response_data.setdefault('model', request.model)

                        return ChatCompletionResponse(**response_data)
                    except json.JSONDecodeError as e:
                        raise HTTPException(
                            status_code=500,
                            detail=f"Failed to parse API response: {str(e)}"
                        )
        except HTTPException:
            # Preserve deliberate HTTP errors instead of rewrapping them as 500s.
            raise
        except aiohttp.ClientError as e:
            raise HTTPException(
                status_code=500,
                detail=f"Connection error: {str(e)}"
            )
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"Unexpected error: {str(e)}"
            )
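

# A minimal sketch of using the client directly, outside the HTTP layer
# (the key value and prompt are placeholders):
#
#   import asyncio
#
#   async def demo():
#       client = DeepInfraClient(api_key="YOUR_KEY")
#       request = ChatRequest(messages=[ChatMessage(role=Role.USER, content="Hi")])
#       print((await client.generate(request)).choices[0].message.content)
#
#   asyncio.run(demo())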


app = FastAPI(title="DeepInfra OpenAI Compatible API")

# Wide-open CORS for development; note that browsers reject the wildcard
# origin when allow_credentials=True, so restrict origins for credentialed use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@lru_cache()
def get_client():
    # Cached so every request shares one client instance. The environment
    # variable name is an assumption; without it the client calls the API
    # unauthenticated.
    return DeepInfraClient(api_key=os.getenv("DEEPINFRA_API_KEY"))


@app.post("/v1/chat/completions")
async def create_chat_completion(
    request: ChatRequest,
    client: DeepInfraClient = Depends(get_client)
):
    try:
        if request.stream:
            return EventSourceResponse(client.generate_stream(request))
        return await client.generate(request)
    except HTTPException:
        # Let upstream status codes pass through instead of flattening to 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/v1/models")
async def list_models():
    # MODELS is a comma-separated list of model ids; drop empty entries so an
    # unset variable yields an empty list instead of a single blank model.
    models = [m.strip() for m in os.getenv("MODELS", "").split(",") if m.strip()]

    current_timestamp = int(datetime.now().timestamp())

    return {
        "object": "list",
        "data": [
            {
                "id": model_id,
                "object": "model",
                "created": current_timestamp,
                "owned_by": "deepinfra"
            }
            for model_id in models
        ]
    }


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
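
# Illustrative requests against a local run (model id and env values are
# examples only):
#
#   export MODELS="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo"
#   curl http://localhost:8000/v1/models
#
#   curl -N http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Hello"}], "stream": true}'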