from fastapi import FastAPI, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from sse_starlette.sse import EventSourceResponse
from pydantic import BaseModel, Field
from typing import AsyncGenerator, Optional, List, Dict, Any
from enum import Enum
from datetime import datetime
import json
import aiohttp
from functools import lru_cache
import os


class Role(str, Enum):
    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"


class ChatMessage(BaseModel):
    role: Optional[Role] = None
    content: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        # Emit only the fields that are set, so streaming deltas
        # (which may carry only "content") stay compact.
        message_dict: Dict[str, Any] = {}
        if self.role is not None:
            message_dict['role'] = self.role
        if self.content is not None:
            message_dict['content'] = self.content
        return message_dict


class UsageInfo(BaseModel):
    prompt_tokens: Optional[int] = None
    completion_tokens: Optional[int] = None
    total_tokens: Optional[int] = None
    estimated_cost: Optional[float] = None


class ChatCompletionChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        return {
            'index': self.index,
            'message': self.message.to_dict(),
            'finish_reason': self.finish_reason
        }


class ChatCompletionResponse(BaseModel):
    id: str
    object: str
    created: int
    model: str
    choices: List[ChatCompletionChoice]
    usage: Optional[Dict[str, Any]] = None

    def to_dict(self) -> Dict[str, Any]:
        return {
            'id': self.id,
            'object': self.object,
            'created': self.created,
            'model': self.model,
            'choices': [choice.to_dict() for choice in self.choices],
            'usage': self.usage
        }


class ChatCompletionChunkChoice(BaseModel):
    index: int
    delta: ChatMessage
    finish_reason: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        return {
            'index': self.index,
            'delta': self.delta.to_dict(),
            'finish_reason': self.finish_reason
        }


class ChatCompletionChunk(BaseModel):
    id: str
    object: str
    created: int
    model: str
    choices: List[ChatCompletionChunkChoice]

    def to_dict(self) -> Dict[str, Any]:
        return {
            'id': self.id,
            'object': self.object,
            'created': self.created,
            'model': self.model,
            'choices': [choice.to_dict() for choice in self.choices]
        }


class ChatRequest(BaseModel):
    messages: List[ChatMessage]
    model: str = Field(default="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo")
    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
    max_tokens: Optional[int] = Field(default=2048)
    stream: bool = Field(default=False)
    response_format: Optional[Dict[str, str]] = None


class DeepInfraClient:
    """Thin async client for DeepInfra's OpenAI-compatible chat endpoint."""

    def __init__(self, api_key: Optional[str] = None):
        self.url = "https://api.deepinfra.com/v1/openai/chat/completions"
        self.headers = {
            "Accept": "text/event-stream, application/json",
            "Content-Type": "application/json"
        }
        if api_key:
            self.headers["Authorization"] = f"Bearer {api_key}"

    def _prepare_messages(self, messages: List[ChatMessage]) -> List[Dict[str, Any]]:
        return [message.to_dict() for message in messages]

    async def generate_stream(self, request: ChatRequest) -> AsyncGenerator:
        payload = {
            "model": request.model,
            "messages": self._prepare_messages(request.messages),
            "temperature": request.temperature,
            "max_tokens": request.max_tokens,
            "stream": True
        }
        if request.response_format:
            payload["response_format"] = request.response_format

        timeout = aiohttp.ClientTimeout(total=300)
        try:
            async with aiohttp.ClientSession(timeout=timeout) as session:
                # aiohttp sets a Content-Length header for json payloads, so
                # chunked transfer encoding must not also be requested here
                # (passing chunked=True alongside json= raises ValueError).
                async with session.post(
                    self.url,
                    headers=self.headers,
                    json=payload
                ) as response:
                    if response.status != 200:
                        error_msg = await response.text()
                        raise HTTPException(
                            status_code=response.status,
                            detail=f"API request failed: {error_msg}"
                        )
                    # Relay the upstream SSE stream, re-emitting each chunk
                    # in OpenAI's chat.completion.chunk shape.
                    async for line in response.content:
                        if not line:
                            continue
                        try:
                            line = line.decode('utf-8').strip()
                            if not line:
                                continue
                            if line.startswith("data: "):
                                json_str = line[6:]
                                if json_str == "[DONE]":
                                    yield {"data": "[DONE]"}
                                    break
                                chunk = json.loads(json_str)
                                chunk_obj = ChatCompletionChunk(
                                    id=chunk["id"],
                                    object="chat.completion.chunk",
                                    created=int(datetime.now().timestamp()),
                                    model=request.model,
                                    choices=[
                                        ChatCompletionChunkChoice(
                                            index=choice["index"],
                                            delta=ChatMessage(**choice.get("delta", {})),
                                            finish_reason=choice.get("finish_reason")
                                        )
                                        for choice in chunk["choices"]
                                    ]
                                )
                                yield {"data": json.dumps(chunk_obj.to_dict())}
                        except json.JSONDecodeError:
                            # Skip partial or malformed SSE lines.
                            continue
                        except Exception as e:
                            raise HTTPException(
                                status_code=500,
                                detail=f"Stream processing error: {str(e)}"
                            )
        except HTTPException:
            # Preserve the original status code instead of re-wrapping as 500.
            raise
        except aiohttp.ClientError as e:
            raise HTTPException(
                status_code=500,
                detail=f"Connection error: {str(e)}"
            )
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"Unexpected error: {str(e)}"
            )

    async def generate(self, request: ChatRequest) -> ChatCompletionResponse:
        payload = {
            "model": request.model,
            "messages": self._prepare_messages(request.messages),
            "temperature": request.temperature,
            "max_tokens": request.max_tokens,
            "stream": False
        }
        if request.response_format:
            payload["response_format"] = request.response_format

        timeout = aiohttp.ClientTimeout(total=300)
        try:
            async with aiohttp.ClientSession(timeout=timeout) as session:
                async with session.post(
                    self.url,
                    headers=self.headers,
                    json=payload
                ) as response:
                    if response.status != 200:
                        error_msg = await response.text()
                        raise HTTPException(
                            status_code=response.status,
                            detail=f"API request failed: {error_msg}"
                        )
                    try:
                        response_data = await response.json()
                        if not isinstance(response_data, dict):
                            raise HTTPException(
                                status_code=500,
                                detail="Invalid response format from API"
                            )
                        # Coerce float token counts to int so the usage
                        # block round-trips cleanly.
                        if 'usage' in response_data:
                            usage_data = response_data['usage']
                            for key in ['prompt_tokens', 'completion_tokens', 'total_tokens']:
                                if key in usage_data and isinstance(usage_data[key], float):
                                    usage_data[key] = int(usage_data[key])
                        # Fill in any required fields the upstream response omitted.
                        response_data.setdefault('id', str(datetime.now().timestamp()))
                        response_data.setdefault('object', 'chat.completion')
                        response_data.setdefault('created', int(datetime.now().timestamp()))
                        response_data.setdefault('model', request.model)
                        return ChatCompletionResponse(**response_data)
                    except json.JSONDecodeError as e:
                        raise HTTPException(
                            status_code=500,
                            detail=f"Failed to parse API response: {str(e)}"
                        )
        except HTTPException:
            raise
        except aiohttp.ClientError as e:
            raise HTTPException(
                status_code=500,
                detail=f"Connection error: {str(e)}"
            )
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"Unexpected error: {str(e)}"
            )


app = FastAPI(title="DeepInfra OpenAI Compatible API")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@lru_cache()
def get_client():
    # Cache a single client so the header set is built once per process.
    # DEEPINFRA_API_KEY is an assumed variable name; if it is unset,
    # requests go out unauthenticated.
    return DeepInfraClient(api_key=os.getenv("DEEPINFRA_API_KEY"))


@app.post("/v1/chat/completions")
async def create_chat_completion(
    request: ChatRequest,
    client: DeepInfraClient = Depends(get_client)
):
    try:
        if request.stream:
            return EventSourceResponse(client.generate_stream(request))
        return await client.generate(request)
    except HTTPException:
        # Let upstream errors keep their status codes.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/v1/models")
async def list_models():
    # MODELS is a comma-separated list of model ids; blank entries are
    # dropped so an unset variable yields an empty list, not one empty id.
    models = [m.strip() for m in os.getenv("MODELS", "").split(",") if m.strip()]
    current_timestamp = int(datetime.now().timestamp())
    return {
        "data": [
            {
                "id": model_id,
                "object": "model",
                "created": current_timestamp,
"deepinfra" } for model_id in models ] } if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=8000)