File size: 6,756 Bytes
68515e1 778d7a6 68515e1 a487faf 68515e1 778d7a6 68515e1 a487faf 68515e1 a487faf 68515e1 a487faf 68515e1 a487faf 68515e1 a487faf 68515e1 a487faf 68515e1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 |
import os
import uvicorn
import time
from fastapi import FastAPI, Request, HTTPException
import tiktoken
from json.decoder import JSONDecodeError
import multiprocessing
# Create the FastAPI application instance
app = FastAPI()
# Load the tokenizer used for token counting
encoding = tiktoken.get_encoding("cl100k_base")
# Module-wide switch for diagnostic prints
debug = True
# Generate a random completion ID
def general_id(zimu=4, num=6):
    """Build a pseudo-random ID of the form ``chatcmpl-<letters><digits>``.

    Args:
        zimu: How many lowercase ASCII letters follow the prefix.
        num: How many decimal digits follow the letters.

    Returns:
        The string ``"chatcmpl-"`` followed by ``zimu`` random lowercase
        letters and then ``num`` random digits.
    """
    from random import choices
    from string import ascii_lowercase, digits

    # Letters are drawn before digits, each with replacement.
    pieces = ["chatcmpl-"]
    pieces.extend(choices(ascii_lowercase, k=zimu))
    pieces.extend(choices(digits, k=num))
    return "".join(pieces)
# Exception handling helper
def handle_exception(e):
    """Convert a caught exception into a FastAPI ``HTTPException`` and raise it.

    This function never returns normally.

    Args:
        e: The exception caught by the caller.

    Raises:
        HTTPException: ``e`` itself when it already is an ``HTTPException``;
            status 400 when ``e`` is a ``JSONDecodeError``; status 500 for
            anything else. In the latter two cases ``e`` is attached as the
            cause so the original traceback is not lost.
    """
    if isinstance(e, HTTPException):
        # Already an HTTP error: keep its status code and detail intact
        # instead of flattening it into a generic 500.
        raise e
    if isinstance(e, JSONDecodeError):
        raise HTTPException(status_code=400, detail="无效的JSON格式") from e
    raise HTTPException(status_code=500, detail="服务器内部错误") from e
# ChatGPT format check
def is_chatgpt_format(data):
    """Return True when ``data`` is shaped like a ChatGPT chat-completion payload.

    The payload must be a dict whose ``"choices"`` entry is a non-empty list
    (or tuple) and whose first choice is a dict containing a ``"message"`` key.
    Requiring the first choice to be a dict avoids a substring match on a
    string element (e.g. ``{"choices": ["some message"]}``).

    Args:
        data: Any object; non-dict values are simply reported as not matching.

    Returns:
        ``True`` if the structure matches, ``False`` otherwise. Never raises.
    """
    if not isinstance(data, dict):
        return False
    choices = data.get("choices")
    if not isinstance(choices, (list, tuple)) or not choices:
        return False
    first_choice = choices[0]
    return isinstance(first_choice, dict) and "message" in first_choice
def get_workers_count() -> int:
    """
    Work out how many worker processes to use.

    The suggestion is ``(2 * CPU cores) + 1``, clamped to the range 4..8.
    If the CPU count cannot be determined, 4 is returned (and the failure
    is printed when the module-level ``debug`` flag is set).
    """
    lower_bound, upper_bound = 4, 8
    try:
        suggested = 2 * multiprocessing.cpu_count() + 1
    except Exception as e:
        if debug:
            print(f"Worker count calculation failed: {e}, using default 4")
        return 4
    # Clamp the suggestion into [lower_bound, upper_bound].
    if suggested < lower_bound:
        return lower_bound
    if suggested > upper_bound:
        return upper_bound
    return suggested
# Simulate a ChatGPT response
def generate_response(headers, data):
    """Build a mock chat-completion response for a request.

    Args:
        headers: Request headers. Currently unused by this function.
        data: Parsed request body. Its ``"model"`` entry (default
            ``"gpt-4o"``) is echoed back, and ``str(data)`` is tokenized to
            produce the prompt token count.

    Returns:
        The canned result unchanged if it already passes
        ``is_chatgpt_format``; otherwise a dict in chat-completion format
        wrapping the canned result text together with token usage figures.

    Raises:
        HTTPException: Raised through ``handle_exception`` when anything in
            the body fails.
    """
    try:
        # TODO: parse headers/request here; the comment in the original code
        # says the main service is meant to require an authorized key later.
        # Canned result returned for every request.
        result = "This is a test result."

        # If the result is already in ChatGPT format, pass it through as-is.
        if is_chatgpt_format(result):
            response_data = result
        else:
            # The chat-completion "created" field is a Unix timestamp in
            # seconds, not milliseconds.
            current_timestamp = int(time.time())
            # Token accounting: the stringified request body counts as the
            # prompt, the canned result as the completion.
            prompt_tokens = len(encoding.encode(str(data)))
            completion_tokens = len(encoding.encode(result))
            total_tokens = prompt_tokens + completion_tokens
            # Assemble the response in ChatGPT format.
            response_data = {
                "id": general_id(),
                "object": "chat.completion",
                "created": current_timestamp,
                "model": data.get("model", "gpt-4o"),
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": total_tokens,
                    "completion_tokens_details": {
                        "reasoning_tokens": 0,
                        "accepted_prediction_tokens": 0,
                        "rejected_prediction_tokens": 0
                    }
                },
                "choices": [
                    {
                        "message": {
                            "role": "assistant",
                            "content": result
                        },
                        "logprobs": None,
                        "finish_reason": "stop",
                        "index": 0
                    }
                ]
            }
        # Log the final payload before returning it.
        print(f"Response Data: {response_data}" )
        return response_data
    except Exception as e:
        handle_exception(e)  # always raises an HTTPException
# Resolve the dynamic route paths
def get_dynamic_routes():
    """
    Build the list of chat endpoint paths from environment variables.

    Priority: REPLACE_CHAT > PREFIX_CHAT > APPEND_CHAT
    - REPLACE_CHAT: comma-separated paths that replace the default path.
    - PREFIX_CHAT: comma-separated prefixes placed in front of the default path.
    - APPEND_CHAT: comma-separated extra paths served alongside the default path.

    Every entry is stripped of surrounding whitespace and blank entries are
    ignored. A variable that contains no usable entry (for example ``","``)
    is treated as unset, so the next priority level applies. Returned paths
    always start with ``/`` and duplicates are dropped while keeping order.

    :return: a non-empty list of route paths
    """
    default_path = "/v1/chat/completions"

    def _split(raw):
        # Split a comma-separated value, dropping blanks and outer whitespace.
        return [item.strip() for item in raw.split(",") if item.strip()]

    def _rooted(path):
        # Route paths must begin with "/"; add it when the user omitted it.
        return path if path.startswith("/") else "/" + path

    def _unique(paths):
        # Remove duplicates while preserving first-seen order.
        return list(dict.fromkeys(paths))

    # Highest priority: explicit replacement paths.
    replace_paths = _split(os.getenv("REPLACE_CHAT", ""))
    if replace_paths:
        return _unique(_rooted(path) for path in replace_paths)

    # Next: prefixes in front of the default path. Slashes around the prefix
    # are normalised so exactly one "/" joins the prefix and the default path.
    prefixes = _split(os.getenv("PREFIX_CHAT", ""))
    if prefixes:
        routes = []
        for prefix in prefixes:
            trimmed = prefix.strip("/")
            routes.append(f"/{trimmed}{default_path}" if trimmed else default_path)
        return _unique(routes)

    # Lowest priority: the default path plus any appended extras. With
    # nothing configured this is just the default path.
    append_paths = _split(os.getenv("APPEND_CHAT", ""))
    return _unique([default_path] + [_rooted(path) for path in append_paths])
# Register a single dynamic route
def register_route(path: str):
    """Register a POST chat endpoint on ``app`` at ``path``.

    The endpoint parses the JSON request body, logs the headers and body,
    and returns whatever ``generate_response`` produces.

    Args:
        path: URL path to serve the chat endpoint on.
    """
    print(f"register route path: {path}")

    @app.post(path)
    async def dynamic_chat_endpoint(request: Request):
        headers = request.headers
        try:
            # request.json() is a coroutine; without ``await`` the handler
            # would pass a coroutine object on instead of the parsed body.
            data = await request.json()
        except (JSONDecodeError, UnicodeDecodeError) as e:
            # A malformed body is a client error, not a server error.
            raise HTTPException(status_code=400, detail="无效的JSON格式") from e

        # NOTE(review): this prints every request header, which may include
        # credentials such as Authorization — confirm this is acceptable.
        print(f"Received Request Header: {headers}\nData: {data}")
        try:
            return generate_response(headers, data)
        except HTTPException:
            # Keep the status code chosen further down the call chain.
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e)) from e
# Set up the dynamic routes
def setup_dynamic_routes():
    """Register a chat endpoint for every path from ``get_dynamic_routes``."""
    routes = get_dynamic_routes()
    # f-string so the actual route list is logged rather than the literal
    # text "{routes}".
    print(f"Registering routes: {routes}")
    for path in routes:
        register_route(path)
@app.get("/")
def read_root():
    """Root endpoint: return a fixed string showing the service is running."""
    status_text = "working..."
    return status_text
if __name__ == "__main__":
    # Chat routes are registered only when the file is run as a script.
    setup_dynamic_routes()
    # Listening port comes from the PORT environment variable (default 7860).
    listen_port = int(os.environ.get("PORT", "7860"))
    uvicorn.run(app, port=listen_port, host="0.0.0.0")
|