|
import os |
|
import time |
|
import multiprocessing |
|
from typing import Dict, Any, List |
|
from fastapi import FastAPI, Request, HTTPException |
|
import uvicorn |
|
import tiktoken |
|
from json.decoder import JSONDecodeError |
|
import random |
|
import string |
|
|
|
# Module-level FastAPI application shared by APIServer and create_server().
# The version string combines a semantic version with a date tag ("1.0.0|2025.1.6").
app = FastAPI(
    title="ones",
    description="High-performance API service",
    version="1.0.0|2025.1.6"
)

# Global debug flag: when True, the APIServer handlers print request
# headers/bodies and error details to stdout.
debug = False
|
|
|
class APIServer:
    """High-performance API server implementation.

    Registers one or more chat-completion POST routes on the supplied FastAPI
    application (paths configurable via the REPLACE_CHAT / PREFIX_CHAT /
    APPEND_CHAT environment variables) plus a GET "/" health check, and can
    run itself under uvicorn.
    """

    def __init__(self, app: FastAPI):
        """Bind routes onto *app* and prepare the token encoder.

        Args:
            app: The FastAPI application to register routes on.
        """
        self.app = app
        # cl100k_base is the tokenizer used by gpt-3.5/gpt-4-era models, so
        # the usage counts reported below line up with OpenAI client behavior.
        self.encoding = tiktoken.get_encoding("cl100k_base")
        self._setup_routes()

    def _setup_routes(self) -> None:
        """Initialize API routes: all chat paths plus a root health check."""
        for path in self._get_routes():
            self._register_route(path)

        @self.app.get("/")
        async def health_check() -> str:
            return "你好"

    def _get_routes(self) -> List[str]:
        """Get configured API routes.

        Precedence (first non-empty variable wins):
          1. REPLACE_CHAT -- comma-separated paths that replace the default.
          2. PREFIX_CHAT  -- comma-separated prefixes, each prepended to the
             default path (trailing slashes on prefixes are dropped).
          3. APPEND_CHAT  -- comma-separated extra paths served in addition
             to the default path.
          4. Neither set  -- only the default path.

        Returns:
            List of route paths to register.
        """
        default_path = "/v1/chat/completions"
        replace_chat = os.getenv("REPLACE_CHAT", "")
        prefix_chat = os.getenv("PREFIX_CHAT", "")
        append_chat = os.getenv("APPEND_CHAT", "")

        if replace_chat:
            return [path.strip() for path in replace_chat.split(",") if path.strip()]

        if prefix_chat:
            # BUGFIX: strip whitespace and skip empty entries, consistent with
            # how the REPLACE_CHAT and APPEND_CHAT values are parsed; the
            # original used each comma-separated chunk verbatim.
            return [
                f"{prefix.strip().rstrip('/')}{default_path}"
                for prefix in prefix_chat.split(",")
                if prefix.strip()
            ]

        if append_chat:
            append_paths = [path.strip() for path in append_chat.split(",") if path.strip()]
            return [default_path] + append_paths

        return [default_path]

    def _register_route(self, path: str) -> None:
        """Register a single chat-completion POST route at *path*."""

        async def chat_endpoint(request: Request) -> Dict[str, Any]:
            try:
                headers = dict(request.headers)
                data = await request.json()
                if debug:
                    print(f"Request received...\r\n\tHeaders: {headers},\r\n\tData: {data}")
                return await self._generate_response(headers, data)
            except JSONDecodeError as e:
                if debug:
                    print(f"JSON decode error: {e}")
                raise HTTPException(status_code=400, detail="Invalid JSON format") from e
            except HTTPException:
                # BUGFIX: let deliberate HTTP errors (e.g. the one raised by
                # _generate_response) propagate with their own status/detail
                # instead of being re-wrapped as a generic 500 below.
                raise
            except Exception as e:
                if debug:
                    print(f"Request processing error: {e}")
                raise HTTPException(status_code=500, detail="Internal server error") from e

        # FastAPI's route decorators are plain callables, so apply one directly.
        self.app.post(path)(chat_endpoint)

    def _calculate_tokens(self, text: str) -> int:
        """Return the cl100k_base token count for *text*."""
        return len(self.encoding.encode(text))

    def _generate_id(self, letters: int = 4, numbers: int = 6) -> str:
        """Generate an OpenAI-style chat completion ID.

        Args:
            letters: Number of random lowercase letters in the suffix.
            numbers: Number of random digits following the letters.

        Returns:
            A string of the form "chatcmpl-<letters><digits>".
        """
        letters_str = ''.join(random.choices(string.ascii_lowercase, k=letters))
        numbers_str = ''.join(random.choices(string.digits, k=numbers))
        return f"chatcmpl-{letters_str}{numbers_str}"

    async def _generate_response(self, headers: Dict[str, str], data: Dict[str, Any]) -> Dict[str, Any]:
        """Build an OpenAI-compatible chat.completion response dict.

        Args:
            headers: Request headers (currently unused; kept for parity with
                the endpoint signature).
            data: Parsed JSON request body; only "model" is read from it.

        Returns:
            A chat.completion payload with a fixed test message and token usage
            computed from the stringified request and the canned result.

        Raises:
            HTTPException: 500 with the underlying error text on failure.
        """
        try:
            result = "This is a test result."
            # Prompt tokens are estimated from the whole stringified body,
            # not just the messages -- an approximation, not exact accounting.
            prompt_tokens = self._calculate_tokens(str(data))
            completion_tokens = self._calculate_tokens(result)
            total_tokens = prompt_tokens + completion_tokens

            return {
                "id": self._generate_id(),
                "object": "chat.completion",
                "created": int(time.time()),
                "model": data.get("model", "gpt-3.5-turbo"),
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": total_tokens
                },
                "choices": [{
                    "message": {
                        "role": "assistant",
                        "content": result
                    },
                    "finish_reason": "stop",
                    "index": 0
                }]
            }
        except Exception as e:
            if debug:
                print(f"Response generation error: {e}")
            # NOTE(review): str(e) exposes internal error text to the client;
            # acceptable for a test service, revisit before production use.
            raise HTTPException(status_code=500, detail=str(e)) from e

    def _get_workers_count(self) -> int:
        """Calculate the worker count: (2*cores)+1, clamped to [4, 8].

        Falls back to 4 if the CPU count cannot be determined.
        """
        try:
            cpu_cores = multiprocessing.cpu_count()
            recommended_workers = (2 * cpu_cores) + 1
            return min(max(4, recommended_workers), 8)
        except Exception as e:
            if debug:
                print(f"Worker count calculation failed: {e}, using default 4")
            return 4

    def get_server_config(self, host: str = "0.0.0.0", port: int = 7860) -> uvicorn.Config:
        """Get the uvicorn server configuration.

        Args:
            host: Interface to bind to.
            port: TCP port to listen on.

        Returns:
            A uvicorn.Config for this app.

        NOTE(review): uvicorn honors ``workers`` only when the app is given as
        an import string; passing the app object (as here) runs a single
        worker -- confirm before relying on multi-worker throughput. The
        ``uvloop``/``httptools`` settings also require those extra packages
        to be installed.
        """
        workers = self._get_workers_count()
        if debug:
            print(f"Configuring server with {workers} workers")

        return uvicorn.Config(
            app=self.app,
            host=host,
            port=port,
            workers=workers,
            loop="uvloop",
            limit_concurrency=1000,
            timeout_keep_alive=30,
            access_log=True,
            log_level="info",
            http="httptools"
        )

    def run(self, host: str = "0.0.0.0", port: int = 7860) -> None:
        """Run the API server (blocking) on *host*:*port*."""
        config = self.get_server_config(host, port)
        server = uvicorn.Server(config)
        server.run()
|
|
|
def create_server() -> APIServer:
    """Build and return an APIServer wired to the module-level FastAPI app."""
    server = APIServer(app)
    return server
|
|
|
if __name__ == "__main__":
    # Port comes from the PORT environment variable, defaulting to 7860.
    listen_port = int(os.getenv("PORT", "7860"))
    create_server().run(port=listen_port)