import multiprocessing
import os
import time
from json.decoder import JSONDecodeError

import tiktoken
import uvicorn
from fastapi import FastAPI, HTTPException, Request

app = FastAPI()

encoding = tiktoken.get_encoding("cl100k_base")
debug = True


def general_id(zimu=4, num=6):
    """Generate an OpenAI-style completion id: 'chatcmpl-' + random letters + random digits."""
    import random
    import string
    letters = ''.join(random.choices(string.ascii_lowercase, k=zimu))
    numbers = ''.join(random.choices(string.digits, k=num))
    return f"chatcmpl-{letters}{numbers}"
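
# Example output (illustrative only; the letters and digits are random):
#   general_id()     -> "chatcmpl-abcd123456"
#   general_id(6, 8) -> "chatcmpl-abcdef12345678"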


def handle_exception(e):
    """Map low-level errors to HTTP errors: invalid JSON -> 400, anything else -> 500."""
    if isinstance(e, JSONDecodeError):
        raise HTTPException(status_code=400, detail="Invalid JSON format")
    else:
        raise HTTPException(status_code=500, detail="Internal server error")


def is_chatgpt_format(data):
    """Return True if data already looks like an OpenAI chat.completion response."""
    try:
        if isinstance(data, dict) and "choices" in data and "message" in data["choices"][0]:
            return True
    except Exception:
        pass
    return False
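
# Illustrative checks (hypothetical inputs):
#   is_chatgpt_format({"choices": [{"message": {"content": "hi"}}]})  -> True
#   is_chatgpt_format("plain string result")                          -> False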


def get_workers_count() -> int:
    """
    Calculate the optimal number of workers.
    Default: 4, Maximum: 8
    Formula: min(max(4, (2 * CPU cores) + 1), 8)
    """
    try:
        cpu_cores = multiprocessing.cpu_count()
        recommended_workers = (2 * cpu_cores) + 1
        return min(max(4, recommended_workers), 8)
    except Exception as e:
        if debug:
            print(f"Worker count calculation failed: {e}, using default 4")
        return 4
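
# Illustrative values of min(max(4, 2 * cores + 1), 8):
#   1 core  -> 4 workers (floor)
#   2 cores -> 5 workers
#   4 cores -> 8 workers (cap)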


def generate_response(headers, data):
    try:
        # Placeholder result; a real implementation would call the upstream model here.
        result = "This is a test result."

        if is_chatgpt_format(result):
            # The upstream already returned an OpenAI-style payload; pass it through.
            response_data = result
        else:
            # "created" is a Unix timestamp in seconds, as in the OpenAI response format.
            current_timestamp = int(time.time())

            prompt_tokens = len(encoding.encode(str(data)))
            completion_tokens = len(encoding.encode(result))
            total_tokens = prompt_tokens + completion_tokens

            response_data = {
                "id": general_id(),
                "object": "chat.completion",
                "created": current_timestamp,
                "model": data.get("model", "gpt-4o"),
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": total_tokens,
                    "completion_tokens_details": {
                        "reasoning_tokens": 0,
                        "accepted_prediction_tokens": 0,
                        "rejected_prediction_tokens": 0
                    }
                },
                "choices": [
                    {
                        "message": {
                            "role": "assistant",
                            "content": result
                        },
                        "logprobs": None,
                        "finish_reason": "stop",
                        "index": 0
                    }
                ]
            }

        print(f"Response Data: {response_data}")
        return response_data
    except Exception as e:
        handle_exception(e)
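
# A minimal sketch of how the placeholder result above might be replaced with a real
# upstream call (hypothetical: assumes an OpenAI-compatible backend reachable at
# UPSTREAM_URL and the `requests` package; not part of the original code):
#
#   import requests
#
#   def call_upstream(headers, data):
#       resp = requests.post(
#           os.getenv("UPSTREAM_URL", "https://api.example.com/v1/chat/completions"),
#           json=data,
#           headers={"Authorization": headers.get("authorization", "")},
#           timeout=60,
#       )
#       resp.raise_for_status()
#       return resp.json()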


def get_dynamic_routes():
    """
    Build the route paths dynamically from environment variables.

    Priority: REPLACE_CHAT > PREFIX_CHAT > APPEND_CHAT
    - REPLACE_CHAT: replaces the default path entirely; multiple values allowed, comma-separated
    - PREFIX_CHAT: prepends a prefix to the default path; multiple values allowed, comma-separated
    - APPEND_CHAT: adds extra paths alongside the default one; multiple values allowed, comma-separated

    :return: a list of the dynamically generated route paths
    """
    default_path = "/v1/chat/completions"

    replace_chat = os.getenv("REPLACE_CHAT", "")
    prefix_chat = os.getenv("PREFIX_CHAT", "")
    append_chat = os.getenv("APPEND_CHAT", "")

    if replace_chat:
        return [path.strip() for path in replace_chat.split(",") if path.strip()]

    routes = []

    if prefix_chat:
        prefixes = [prefix.strip() for prefix in prefix_chat.split(",") if prefix.strip()]
        for prefix in prefixes:
            routes.append(prefix + default_path)
        return routes

    if append_chat:
        append_paths = [path.strip() for path in append_chat.split(",") if path.strip()]
        routes = [default_path] + append_paths

    if not routes:
        routes.append(default_path)

    return routes
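
# Illustrative examples (hypothetical values):
#   REPLACE_CHAT="/api/chat"   -> ["/api/chat"]
#   PREFIX_CHAT="/a,/b"        -> ["/a/v1/chat/completions", "/b/v1/chat/completions"]
#   APPEND_CHAT="/alias"       -> ["/v1/chat/completions", "/alias"]
#   (none set)                 -> ["/v1/chat/completions"]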


def register_route(path: str):
    print(f"register route path: {path}")

    @app.post(path)
    async def dynamic_chat_endpoint(request: Request):
        try:
            headers = request.headers
            # request.json() is a coroutine and must be awaited.
            data = await request.json()
            print(f"Received Request Header: {headers}\nData: {data}")
            result = generate_response(headers, data)
            return result
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))
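
# Example request against the default route (hypothetical payload):
#   curl -X POST http://localhost:7860/v1/chat/completions \
#        -H "Content-Type: application/json" \
#        -d '{"model": "gpt-4o", "messages": [{"role": "user", "content": "hi"}]}'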


def setup_dynamic_routes():
    routes = get_dynamic_routes()
    print(f"Registering routes: {routes}")
    for path in routes:
        register_route(path)


@app.get("/")
def read_root():
    return "working..."


if __name__ == "__main__":
    setup_dynamic_routes()

    port = int(os.getenv("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=port)
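
# Note (illustrative): to actually use get_workers_count(), uvicorn needs the app
# as an import string; assuming this file is saved as main.py, that would look like:
#   uvicorn.run("main:app", host="0.0.0.0", port=port, workers=get_workers_count())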