sanbo
committed on
Commit
·
49cdbc0
1
Parent(s):
9599fd6
update sth. at 2025-01-08 13:57:45
Browse files- more_core.py +18 -2
- some_base_method/more_core——back.py +317 -0
more_core.py
CHANGED
@@ -6,6 +6,7 @@ import string
|
|
6 |
import time
|
7 |
from json.decoder import JSONDecodeError
|
8 |
from typing import Dict, Any, List
|
|
|
9 |
import tiktoken
|
10 |
import uvicorn
|
11 |
from apscheduler.schedulers.background import BackgroundScheduler
|
@@ -39,7 +40,7 @@ class APIServer:
|
|
39 |
"""Initialize API routes"""
|
40 |
|
41 |
# 修改根路由的重定向实现
|
42 |
-
@self.app.get("/")
|
43 |
async def root():
|
44 |
# 添加状态码和确保完整URL
|
45 |
return RedirectResponse(
|
@@ -47,12 +48,27 @@ class APIServer:
|
|
47 |
status_code=302 # 添加明确的重定向状态码
|
48 |
)
|
49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
# 修改 web 路由的返回类型
|
51 |
@self.app.get("/web")
|
52 |
async def web():
|
53 |
# # 返回 JSONResponse 或 HTML 内容
|
54 |
# return JSONResponse(content={"message": "hello. It's web page."})
|
55 |
-
|
56 |
## 或者返回HTML内容
|
57 |
return HTMLResponse(content="<h1>hello. It's web page.</h1>")
|
58 |
|
|
|
6 |
import time
|
7 |
from json.decoder import JSONDecodeError
|
8 |
from typing import Dict, Any, List
|
9 |
+
|
10 |
import tiktoken
|
11 |
import uvicorn
|
12 |
from apscheduler.schedulers.background import BackgroundScheduler
|
|
|
40 |
"""Initialize API routes"""
|
41 |
|
42 |
# 修改根路由的重定向实现
|
43 |
+
@self.app.get("/", include_in_schema=False)
|
44 |
async def root():
|
45 |
# 添加状态码和确保完整URL
|
46 |
return RedirectResponse(
|
|
|
48 |
status_code=302 # 添加明确的重定向状态码
|
49 |
)
|
50 |
|
51 |
+
@app.exception_handler(404)
|
52 |
+
async def custom_404_handler(request: Request, exc: HTTPException):
|
53 |
+
return JSONResponse(
|
54 |
+
content={"detail": "This path is not defined in the server"},
|
55 |
+
status_code=404,
|
56 |
+
)
|
57 |
+
|
58 |
+
def validate_logs_parameter(logs: str) -> Dict[str, Any]:
|
59 |
+
if logs == "container":
|
60 |
+
return {"message": "Logs feature not implemented"}
|
61 |
+
raise HTTPException(status_code=404, detail="Invalid logs parameter")
|
62 |
+
|
63 |
+
@self.app.get("/logs")
|
64 |
+
async def logs(logs: str = None):
|
65 |
+
return JSONResponse(content=self.validate_logs_parameter(logs))
|
66 |
+
|
67 |
# 修改 web 路由的返回类型
|
68 |
@self.app.get("/web")
|
69 |
async def web():
|
70 |
# # 返回 JSONResponse 或 HTML 内容
|
71 |
# return JSONResponse(content={"message": "hello. It's web page."})
|
|
|
72 |
## 或者返回HTML内容
|
73 |
return HTMLResponse(content="<h1>hello. It's web page.</h1>")
|
74 |
|
some_base_method/more_core——back.py
ADDED
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import multiprocessing
|
3 |
+
import os
|
4 |
+
import random
|
5 |
+
import string
|
6 |
+
import time
|
7 |
+
from json.decoder import JSONDecodeError
|
8 |
+
from typing import Dict, Any, List
|
9 |
+
import tiktoken
|
10 |
+
import uvicorn
|
11 |
+
from apscheduler.schedulers.background import BackgroundScheduler
|
12 |
+
from fastapi import FastAPI, Request, HTTPException
|
13 |
+
from fastapi.responses import RedirectResponse, JSONResponse
|
14 |
+
from starlette.responses import HTMLResponse
|
15 |
+
|
16 |
+
import degpt as dg
|
17 |
+
|
18 |
+
# Module-level FastAPI application; shared by APIServer and create_server().
app = FastAPI(
    title="ones",
    description="High-performance API service",
    version="1.0.0|2025.1.6"
)
# Global debug flag for logging; refreshed from the DEBUG env var every 30s
# by APIServer._reload_routes_if_needed().
debug = False
|
25 |
+
|
26 |
+
|
27 |
+
class APIServer:
    """High-performance API server implementation.

    Wraps a FastAPI app: registers static routes (/, /web, /api/v1/models),
    registers one POST chat-completions endpoint per configured path, and
    runs a background scheduler that periodically reloads routes and models.
    """

    def __init__(self, app: FastAPI):
        # The shared FastAPI application all routes are attached to.
        self.app = app
        # Tokenizer used only for usage accounting in responses.
        self.encoding = tiktoken.get_encoding("cl100k_base")
        self._setup_routes()
        # Background jobs: reload routes every 30s, models every 30min.
        self.scheduler = BackgroundScheduler()
        self._schedule_route_check()
        self.scheduler.start()

    def _setup_routes(self) -> None:
        """Initialize API routes"""

        # Root route: redirect to the /web landing page.
        @self.app.get("/")
        async def root():
            # Use an explicit status code and a complete target URL.
            return RedirectResponse(
                url="/web",
                status_code=302  # explicit redirect status code
            )

        # /web route: serve a minimal HTML landing page.
        @self.app.get("/web")
        async def web():
            # # Could also return JSON content instead:
            # return JSONResponse(content={"message": "hello. It's web page."})

            ## Or return HTML content
            return HTMLResponse(content="<h1>hello. It's web page.</h1>")

        @self.app.get("/api/v1/models")
        async def models() -> JSONResponse:
            # assumes dg.get_models() returns a JSON string — TODO confirm in degpt
            models_str = dg.get_models()  # Get the JSON string
            models_json = json.loads(models_str)  # Convert it back to a Python object
            return JSONResponse(content=models_json)  # Return as JSON response

        # Register one chat-completions POST endpoint per configured path.
        routes = self._get_routes()
        for path in routes:
            self._register_route(path)

    def _get_routes(self) -> List[str]:
        """Get configured API routes.

        Precedence: REPLACE_CHAT (exclusive comma-separated list) >
        PREFIX_CHAT (each prefix prepended to the default path; APPEND_CHAT
        is then ignored) > APPEND_CHAT (default path plus extras) >
        the bare default path.
        """
        default_path = "/api/v1/chat/completions"
        replace_chat = os.getenv("REPLACE_CHAT", "")
        prefix_chat = os.getenv("PREFIX_CHAT", "")
        append_chat = os.getenv("APPEND_CHAT", "")

        if replace_chat:
            return [path.strip() for path in replace_chat.split(",") if path.strip()]

        routes = []
        if prefix_chat:
            # NOTE(review): unlike the other branches, prefixes are not
            # stripped of surrounding whitespace here — confirm intended.
            routes.extend(f"{prefix.rstrip('/')}{default_path}"
                          for prefix in prefix_chat.split(","))
            return routes

        if append_chat:
            append_paths = [path.strip() for path in append_chat.split(",") if path.strip()]
            routes = [default_path] + append_paths
            return routes

        return [default_path]

    def _register_route(self, path: str) -> None:
        """Register a single API route"""
        global debug

        async def chat_endpoint(request: Request) -> Dict[str, Any]:
            # Parse the JSON body and delegate to _generate_response;
            # JSON errors map to 400, everything else to 500.
            try:
                headers = dict(request.headers)
                data = await request.json()
                if debug:
                    print(f"Request received...\r\n\tHeaders: {headers},\r\n\tData: {data}")
                return await self._generate_response(headers, data)
            except JSONDecodeError as e:
                if debug:
                    print(f"JSON decode error: {e}")
                raise HTTPException(status_code=400, detail="Invalid JSON format") from e
            except Exception as e:
                if debug:
                    print(f"Request processing error: {e}")
                raise HTTPException(status_code=500, detail="Internal server error") from e

        # Attach the closure as a POST endpoint at the configured path.
        self.app.post(path)(chat_endpoint)

    def _calculate_tokens(self, text: str) -> int:
        """Calculate token count for text"""
        return len(self.encoding.encode(text))

    def _generate_id(self, letters: int = 4, numbers: int = 6) -> str:
        """Generate unique chat completion ID.

        Format: "chatcmpl-" + `letters` random lowercase letters +
        `numbers` random digits (uses `random`, not `secrets`).
        """
        letters_str = ''.join(random.choices(string.ascii_lowercase, k=letters))
        numbers_str = ''.join(random.choices(string.digits, k=numbers))
        return f"chatcmpl-{letters_str}{numbers_str}"

    def is_chatgpt_format(self, data) -> bool:
        """Check if the data is in the expected ChatGPT format"""
        try:
            # If the data is a string, try to parse it as JSON
            if isinstance(data, str):
                try:
                    data = json.loads(data)
                except json.JSONDecodeError:
                    return False  # If the string can't be parsed, it's not in the expected format

            # Now check if data is a dictionary and contains the necessary structure
            if isinstance(data, dict):
                # Ensure 'choices' is a list and the first item has a 'message' field
                if "choices" in data and isinstance(data["choices"], list) and len(data["choices"]) > 0:
                    if "message" in data["choices"][0]:
                        return True
        except Exception as e:
            print(f"Error checking ChatGPT format: {e}")
        return False

    def process_result(self, result, model):
        """Normalize a ChatGPT-style result: refresh id/object, stamp the model."""
        # If result is a string, try to convert it to JSON
        if isinstance(result, str):
            try:
                result = json.loads(result)  # convert to JSON
            except json.JSONDecodeError:
                # Not JSON: return the raw string unchanged
                return result

        # Ensure result is a dict (JSON object) before mutating it
        if isinstance(result, dict):
            # Overwrite id and object with locally generated values
            result['id'] = self._generate_id()
            result['object'] = "chat.completion"

            # Stamp the resolved model name
            result['model'] = model
        return result

    async def _generate_response(self, headers: Dict[str, str], data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate API response.

        NOTE(review): `headers` is accepted but not read in this body —
        confirm whether it is reserved for future auth/token checks.
        """
        global debug
        try:
            # check model
            model = data.get("model")
            # print(f"model: {model}")
            # if "auto" == model:
            #     model = dg.get_auto_model(model)
            # else:
            #     if not dg.is_model_available(model):
            #         raise HTTPException(status_code=400, detail="Invalid Model")
            ## kuan
            model = dg.get_model_by_autoupdate(model)
            # must has token ? token check

            # call ai
            msgs = data.get("messages")
            if debug:
                print(f"req messages: {msgs}")
            result = dg.chat_completion_messages(messages=msgs, model=model)
            if debug:
                print(f"result: {result}---- {self.is_chatgpt_format(result)}")

            # # Assuming this 'result' comes from your model or some other logic
            # result = "This is a test result."

            # If the request body data already matches ChatGPT format, return it directly
            if self.is_chatgpt_format(result):
                response_data = self.process_result(result,
                                                    model)  # If data already follows ChatGPT format, use it directly
            else:
                # Calculate the current timestamp
                # NOTE(review): milliseconds; the OpenAI `created` field is
                # in epoch seconds — confirm intended.
                current_timestamp = int(time.time() * 1000)
                # Otherwise, calculate the tokens and return a structured response
                prompt_tokens = self._calculate_tokens(str(data))
                completion_tokens = self._calculate_tokens(result)
                total_tokens = prompt_tokens + completion_tokens

                response_data = {
                    "id": self._generate_id(),
                    "object": "chat.completion",
                    "created": current_timestamp,
                    "model": data.get("model", "gpt-4o"),
                    "usage": {
                        "prompt_tokens": prompt_tokens,
                        "completion_tokens": completion_tokens,
                        "total_tokens": total_tokens
                    },
                    "choices": [{
                        "message": {
                            "role": "assistant",
                            "content": result
                        },
                        "finish_reason": "stop",
                        "index": 0
                    }]
                }

            # Print the response for debugging (you may remove this in production)
            if debug:
                print(f"Response Data: {response_data}")

            return response_data
        except Exception as e:
            if debug:
                print(f"Response generation error: {e}")
            raise HTTPException(status_code=500, detail=str(e)) from e

    def _get_workers_count(self) -> int:
        """Calculate optimal worker count.

        Uses the classic (2 * cores) + 1 heuristic, clamped to [4, 8];
        falls back to 4 if the CPU count cannot be determined.
        """
        try:
            cpu_cores = multiprocessing.cpu_count()
            recommended_workers = (2 * cpu_cores) + 1
            return min(max(4, recommended_workers), 8)
        except Exception as e:
            if debug:
                print(f"Worker count calculation failed: {e}, using default 4")
            return 4

    def get_server_config(self, host: str = "0.0.0.0", port: int = 7860) -> uvicorn.Config:
        """Get server configuration"""
        workers = self._get_workers_count()
        if debug:
            print(f"Configuring server with {workers} workers")

        # NOTE(review): uvicorn only honours workers > 1 when the app is
        # given as an import string, not an object — confirm multi-worker
        # mode actually engages here.
        return uvicorn.Config(
            app=self.app,
            host=host,
            port=port,
            workers=workers,
            loop="uvloop",
            limit_concurrency=1000,
            timeout_keep_alive=30,
            access_log=True,
            log_level="info",
            http="httptools"
        )

    def run(self, host: str = "0.0.0.0", port: int = 7860) -> None:
        """Run the API server (blocking call)."""
        config = self.get_server_config(host, port)
        server = uvicorn.Server(config)
        server.run()

    def _reload_check(self) -> None:
        # Delegate the periodic model refresh to the degpt module.
        dg.reload_check()

    def _schedule_route_check(self) -> None:
        """
        Schedule tasks to check and reload routes and models at regular intervals.
        - Reload routes every 30 seconds.
        - Reload models every 30 minutes.
        """
        # Scheduled Task 1: Check and reload routes every 30 seconds
        # Calls _reload_routes_if_needed method to check if routes need to be updated
        self.scheduler.add_job(self._reload_routes_if_needed, 'interval', seconds=30)

        # Scheduled Task 2: Reload models every 30 minutes (1800 seconds)
        # This task will check and update the model data periodically
        self.scheduler.add_job(self._reload_check, 'interval', seconds=60 * 30)
        pass

    def _reload_routes_if_needed(self) -> None:
        """Check if routes need to be reloaded based on environment variables"""
        # Refresh the global debug flag from the environment
        global debug
        debug = os.getenv("DEBUG", "False").lower() in ["true", "1", "t"]
        # Reload routes from the env-var configuration
        new_routes = self._get_routes()
        current_routes = [route for route in self.app.routes if hasattr(route, 'path')]

        # Check if the current routes are different from the new routes
        # NOTE(review): compares ALL app routes (including /, /web and
        # /api/v1/models) against only the chat paths, so this comparison
        # likely always differs — confirm intended.
        if [route.path for route in current_routes] != new_routes:
            if debug:
                print("Routes changed, reloading...")
            self._reload_routes(new_routes)

    def _reload_routes(self, new_routes: List[str]) -> None:
        """Reload the routes based on the updated configuration"""
        # Clear existing routes
        # NOTE(review): this also drops the static routes (/, /web,
        # /api/v1/models); only chat endpoints are re-registered afterwards —
        # confirm intended.
        self.app.routes.clear()
        # Register new routes
        for path in new_routes:
            self._register_route(path)
|
307 |
+
|
308 |
+
|
309 |
+
def create_server() -> APIServer:
    """Build and return an APIServer bound to the module-level FastAPI app."""
    server_instance = APIServer(app)
    return server_instance
|
312 |
+
|
313 |
+
|
314 |
+
if __name__ == "__main__":
    # Honour the PORT environment variable, defaulting to 7860.
    listen_port = int(os.getenv("PORT", "7860"))
    create_server().run(port=listen_port)
|