from fastapi import FastAPI, HTTPException, Header
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import openai
from typing import List, Optional, Union
import logging
from itertools import cycle
import asyncio

import uvicorn

from app import config
import requests
from datetime import datetime, timezone


logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

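# Upstream keys are handed out round-robin; key_lock serializes access to the
# shared iterator. The settings surface assumed here (a sketch; the real
# app.config may expose more) is:
#   settings.API_KEYS       - upstream Gemini API keys to rotate through
#   settings.ALLOWED_TOKENS - bearer tokens accepted from this proxy's clients
#   settings.BASE_URL       - OpenAI-compatible base URL of the upstream API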
API_KEYS = config.settings.API_KEYS

key_cycle = cycle(API_KEYS)
key_lock = asyncio.Lock()

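# OpenAI-style request schemas accepted by the proxy endpoints.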
class ChatRequest(BaseModel):
    messages: List[dict]
    model: str = "llama-3.2-90b-text-preview"
    temperature: Optional[float] = 0.7
    stream: Optional[bool] = False
    tools: Optional[List[dict]] = None
    tool_choice: Optional[str] = "auto"

class EmbeddingRequest(BaseModel):
    input: Union[str, List[str]]
    model: str = "text-embedding-004"
    encoding_format: Optional[str] = "float"

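# Shared auth guard: every endpoint validates the client's bearer token
# before an upstream key is consumed.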
async def verify_authorization(authorization: str = Header(None)):
    if not authorization:
        logger.error("Missing Authorization header")
        raise HTTPException(status_code=401, detail="Missing Authorization header")
    if not authorization.startswith("Bearer "):
        logger.error("Invalid Authorization header format")
        raise HTTPException(
            status_code=401, detail="Invalid Authorization header format"
        )
    # Strip only the leading scheme from the header value.
    token = authorization.removeprefix("Bearer ")
    if token not in config.settings.ALLOWED_TOKENS:
        logger.error("Invalid token")
        raise HTTPException(status_code=401, detail="Invalid token")
    return token

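# The model list comes from Gemini's native REST API (not the
# OpenAI-compatible endpoint) and is then reshaped into the OpenAI schema.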
def get_gemini_models(api_key):
    base_url = "https://generativelanguage.googleapis.com/v1beta"
    url = f"{base_url}/models?key={api_key}"

    try:
        response = requests.get(url, timeout=10)
        if response.status_code == 200:
            gemini_models = response.json()
            return convert_to_openai_format(gemini_models)
        logger.error(f"Error fetching models: {response.status_code} {response.text}")
        return None
    except requests.RequestException as e:
        logger.error(f"Request failed: {e}")
        return None

def convert_to_openai_format(gemini_models):
    """Map Gemini's `models` payload onto OpenAI's `/v1/models` list schema."""
    openai_format = {"object": "list", "data": []}

    for model in gemini_models.get("models", []):
        openai_model = {
            "id": model["name"].split("/")[-1],
            "object": "model",
            "created": int(datetime.now(timezone.utc).timestamp()),
            "owned_by": "google",
            "permission": [],
            "root": model["name"],
            "parent": None,
        }
        openai_format["data"].append(openai_model)

    return openai_format

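# Each route is mirrored under a /hf/v1 prefix, presumably for deployments
# whose clients are hard-coded to that path.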
@app.get("/v1/models")
@app.get("/hf/v1/models")
async def list_models(authorization: str = Header(None)):
    await verify_authorization(authorization)
    async with key_lock:
        api_key = next(key_cycle)
        # Log only a key suffix to avoid leaking secrets.
        logger.info(f"Using API key: ...{api_key[-4:]}")
    try:
        response = get_gemini_models(api_key)
        if response is None:
            raise HTTPException(status_code=502, detail="Failed to fetch models from upstream")
        logger.info("Successfully retrieved models list")
        return response
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error listing models: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

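# Chat completions proxy. Streamed responses are relayed as OpenAI-style
# server-sent events. Illustrative client call (token and model are
# placeholders):
#
#   curl -N http://localhost:8000/v1/chat/completions \
#     -H "Authorization: Bearer <token>" \
#     -H "Content-Type: application/json" \
#     -d '{"model": "gemini-1.5-flash", "messages": [{"role": "user", "content": "hi"}], "stream": true}'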
@app.post("/v1/chat/completions")
@app.post("/hf/v1/chat/completions")
async def chat_completion(request: ChatRequest, authorization: str = Header(None)):
    await verify_authorization(authorization)
    async with key_lock:
        api_key = next(key_cycle)
        logger.info(f"Using API key: ...{api_key[-4:]}")

    try:
        logger.info(f"Chat completion request - Model: {request.model}")
        # The async client keeps the upstream call from blocking the event loop.
        client = openai.AsyncOpenAI(api_key=api_key, base_url=config.settings.BASE_URL)
        # Forward tool definitions only when the caller supplies them.
        extra = {}
        if request.tools:
            extra["tools"] = request.tools
            extra["tool_choice"] = request.tool_choice
        response = await client.chat.completions.create(
            model=request.model,
            messages=request.messages,
            temperature=request.temperature,
            stream=request.stream,
            **extra,
        )

        if request.stream:
            logger.info("Streaming response enabled")

            async def generate():
                async for chunk in response:
                    yield f"data: {chunk.model_dump_json()}\n\n"
                # OpenAI-style streams terminate with a [DONE] sentinel.
                yield "data: [DONE]\n\n"

            return StreamingResponse(content=generate(), media_type="text/event-stream")

        logger.info("Chat completion successful")
        return response

    except Exception as e:
        logger.error(f"Error in chat completion: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

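# Embeddings proxy for the OpenAI-compatible upstream.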
@app.post("/v1/embeddings")
@app.post("/hf/v1/embeddings")
async def embedding(request: EmbeddingRequest, authorization: str = Header(None)):
    await verify_authorization(authorization)
    async with key_lock:
        api_key = next(key_cycle)
        logger.info(f"Using API key: ...{api_key[-4:]}")

    try:
        client = openai.AsyncOpenAI(api_key=api_key, base_url=config.settings.BASE_URL)
        response = await client.embeddings.create(
            input=request.input,
            model=request.model,
            encoding_format=request.encoding_format,
        )
        logger.info("Embedding successful")
        return response
    except Exception as e:
        logger.error(f"Error in embedding: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/health")
@app.get("/")
async def health_check():
    logger.info("Health check endpoint called")
    return {"status": "healthy"}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)