Add traffic middleware
Browse files- log_config.py +2 -1
- main.py +114 -48
log_config.py
CHANGED
@@ -2,4 +2,5 @@ import logging
|
|
2 |
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
3 |
logger = logging.getLogger("uni-api")
|
4 |
|
5 |
-
logging.getLogger("httpx").setLevel(logging.CRITICAL)
|
|
|
|
2 |
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
3 |
logger = logging.getLogger("uni-api")
|
4 |
|
5 |
+
logging.getLogger("httpx").setLevel(logging.CRITICAL)
|
6 |
+
logging.getLogger("watchfiles.main").setLevel(logging.CRITICAL)
|
main.py
CHANGED
@@ -5,14 +5,14 @@ import secrets
|
|
5 |
from contextlib import asynccontextmanager
|
6 |
|
7 |
from fastapi.middleware.cors import CORSMiddleware
|
8 |
-
from fastapi import FastAPI, HTTPException, Depends
|
9 |
from fastapi.responses import StreamingResponse, JSONResponse
|
10 |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
11 |
|
12 |
from models import RequestModel, ImageGenerationRequest
|
13 |
-
from utils import error_handling_wrapper, get_all_models, post_all_models, load_config
|
14 |
from request import get_payload
|
15 |
from response import fetch_response, fetch_response_stream
|
|
|
16 |
|
17 |
from typing import List, Dict, Union
|
18 |
from urllib.parse import urlparse
|
@@ -42,36 +42,91 @@ async def lifespan(app: FastAPI):
|
|
42 |
|
43 |
app = FastAPI(lifespan=lifespan)
|
44 |
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
#
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
|
76 |
# 配置 CORS 中间件
|
77 |
app.add_middleware(
|
@@ -82,6 +137,8 @@ app.add_middleware(
|
|
82 |
allow_headers=["*"], # 允许所有头部字段
|
83 |
)
|
84 |
|
|
|
|
|
85 |
async def process_request(request: Union[RequestModel, ImageGenerationRequest], provider: Dict, endpoint=None):
|
86 |
url = provider['base_url']
|
87 |
parsed_url = urlparse(url)
|
@@ -233,6 +290,17 @@ def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)
|
|
233 |
raise HTTPException(status_code=403, detail="Invalid or missing API Key")
|
234 |
return token
|
235 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
236 |
@app.post("/v1/chat/completions")
|
237 |
async def request_model(request: Union[RequestModel, ImageGenerationRequest], token: str = Depends(verify_api_key)):
|
238 |
return await model_handler.request_model(request, token)
|
@@ -258,24 +326,22 @@ async def images_generations(
|
|
258 |
|
259 |
@app.get("/generate-api-key")
|
260 |
def generate_api_key():
|
261 |
-
api_key = "sk-" + secrets.token_urlsafe(
|
262 |
return JSONResponse(content={"api_key": api_key})
|
263 |
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
# }
|
278 |
-
# }
|
279 |
|
280 |
# async def on_fetch(request, env):
|
281 |
# import asgi
|
|
|
5 |
from contextlib import asynccontextmanager
|
6 |
|
7 |
from fastapi.middleware.cors import CORSMiddleware
|
8 |
+
from fastapi import FastAPI, HTTPException, Depends, Request
|
9 |
from fastapi.responses import StreamingResponse, JSONResponse
|
10 |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
11 |
|
12 |
from models import RequestModel, ImageGenerationRequest
|
|
|
13 |
from request import get_payload
|
14 |
from response import fetch_response, fetch_response_stream
|
15 |
+
from utils import error_handling_wrapper, post_all_models, load_config
|
16 |
|
17 |
from typing import List, Dict, Union
|
18 |
from urllib.parse import urlparse
|
|
|
42 |
|
43 |
app = FastAPI(lifespan=lifespan)
|
44 |
|
45 |
+
import asyncio
|
46 |
+
from time import time
|
47 |
+
from collections import defaultdict
|
48 |
+
from starlette.middleware.base import BaseHTTPMiddleware
|
49 |
+
from datetime import datetime
|
50 |
+
from datetime import timedelta
|
51 |
+
import json
|
52 |
+
import aiofiles
|
53 |
+
|
54 |
+
class StatsMiddleware(BaseHTTPMiddleware):
|
55 |
+
def __init__(self, app, exclude_paths=None, save_interval=3600, filename="stats.json"):
|
56 |
+
super().__init__(app)
|
57 |
+
self.request_counts = defaultdict(int)
|
58 |
+
self.request_times = defaultdict(float)
|
59 |
+
self.ip_counts = defaultdict(lambda: defaultdict(int))
|
60 |
+
self.request_arrivals = defaultdict(list)
|
61 |
+
self.lock = asyncio.Lock()
|
62 |
+
self.exclude_paths = set(exclude_paths or [])
|
63 |
+
self.save_interval = save_interval
|
64 |
+
self.filename = filename
|
65 |
+
self.last_save_time = time()
|
66 |
+
|
67 |
+
# 启动定期保存和清理任务
|
68 |
+
asyncio.create_task(self.periodic_save_and_cleanup())
|
69 |
+
|
70 |
+
async def dispatch(self, request: Request, call_next):
|
71 |
+
arrival_time = datetime.now()
|
72 |
+
start_time = time()
|
73 |
+
response = await call_next(request)
|
74 |
+
process_time = time() - start_time
|
75 |
+
|
76 |
+
endpoint = f"{request.method} {request.url.path}"
|
77 |
+
client_ip = request.client.host
|
78 |
+
|
79 |
+
if request.url.path not in self.exclude_paths:
|
80 |
+
async with self.lock:
|
81 |
+
self.request_counts[endpoint] += 1
|
82 |
+
self.request_times[endpoint] += process_time
|
83 |
+
self.ip_counts[endpoint][client_ip] += 1
|
84 |
+
self.request_arrivals[endpoint].append(arrival_time)
|
85 |
+
|
86 |
+
return response
|
87 |
+
|
88 |
+
async def periodic_save_and_cleanup(self):
|
89 |
+
while True:
|
90 |
+
await asyncio.sleep(self.save_interval)
|
91 |
+
await self.save_stats()
|
92 |
+
await self.cleanup_old_data()
|
93 |
+
|
94 |
+
async def save_stats(self):
|
95 |
+
current_time = time()
|
96 |
+
if current_time - self.last_save_time < self.save_interval:
|
97 |
+
return
|
98 |
+
|
99 |
+
async with self.lock:
|
100 |
+
stats = {
|
101 |
+
"request_counts": dict(self.request_counts),
|
102 |
+
"request_times": dict(self.request_times),
|
103 |
+
"ip_counts": {k: dict(v) for k, v in self.ip_counts.items()},
|
104 |
+
"request_arrivals": {k: [t.isoformat() for t in v] for k, v in self.request_arrivals.items()}
|
105 |
+
}
|
106 |
+
|
107 |
+
filename = self.filename
|
108 |
+
async with aiofiles.open(filename, mode='w') as f:
|
109 |
+
await f.write(json.dumps(stats, indent=2))
|
110 |
+
|
111 |
+
self.last_save_time = current_time
|
112 |
+
# print(f"Stats saved to {filename}")
|
113 |
+
|
114 |
+
async def cleanup_old_data(self):
|
115 |
+
# cutoff_time = datetime.now() - timedelta(seconds=30)
|
116 |
+
cutoff_time = datetime.now() - timedelta(hours=24)
|
117 |
+
async with self.lock:
|
118 |
+
for endpoint in list(self.request_arrivals.keys()):
|
119 |
+
self.request_arrivals[endpoint] = [
|
120 |
+
t for t in self.request_arrivals[endpoint] if t > cutoff_time
|
121 |
+
]
|
122 |
+
if not self.request_arrivals[endpoint]:
|
123 |
+
del self.request_arrivals[endpoint]
|
124 |
+
self.request_counts.pop(endpoint, None)
|
125 |
+
self.request_times.pop(endpoint, None)
|
126 |
+
self.ip_counts.pop(endpoint, None)
|
127 |
+
|
128 |
+
async def cleanup(self):
|
129 |
+
await self.save_stats()
|
130 |
|
131 |
# 配置 CORS 中间件
|
132 |
app.add_middleware(
|
|
|
137 |
allow_headers=["*"], # 允许所有头部字段
|
138 |
)
|
139 |
|
140 |
+
app.add_middleware(StatsMiddleware, exclude_paths=["/stats", "/generate-api-key"])
|
141 |
+
|
142 |
async def process_request(request: Union[RequestModel, ImageGenerationRequest], provider: Dict, endpoint=None):
|
143 |
url = provider['base_url']
|
144 |
parsed_url = urlparse(url)
|
|
|
290 |
raise HTTPException(status_code=403, detail="Invalid or missing API Key")
|
291 |
return token
|
292 |
|
293 |
+
def verify_admin_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
|
294 |
+
api_list = app.state.api_list
|
295 |
+
token = credentials.credentials
|
296 |
+
if token not in api_list:
|
297 |
+
raise HTTPException(status_code=403, detail="Invalid or missing API Key")
|
298 |
+
for api_key in app.state.api_keys_db:
|
299 |
+
if api_key['api'] == token:
|
300 |
+
if api_key.get('role') != "admin":
|
301 |
+
raise HTTPException(status_code=403, detail="Permission denied")
|
302 |
+
return token
|
303 |
+
|
304 |
@app.post("/v1/chat/completions")
|
305 |
async def request_model(request: Union[RequestModel, ImageGenerationRequest], token: str = Depends(verify_api_key)):
|
306 |
return await model_handler.request_model(request, token)
|
|
|
326 |
|
327 |
@app.get("/generate-api-key")
|
328 |
def generate_api_key():
|
329 |
+
api_key = "sk-" + secrets.token_urlsafe(36)
|
330 |
return JSONResponse(content={"api_key": api_key})
|
331 |
|
332 |
+
@app.get("/stats")
|
333 |
+
async def get_stats(request: Request, token: str = Depends(verify_admin_api_key)):
|
334 |
+
middleware = app.middleware_stack.app
|
335 |
+
if isinstance(middleware, StatsMiddleware):
|
336 |
+
async with middleware.lock:
|
337 |
+
stats = {
|
338 |
+
"request_counts": dict(middleware.request_counts),
|
339 |
+
"request_times": dict(middleware.request_times),
|
340 |
+
"ip_counts": {k: dict(v) for k, v in middleware.ip_counts.items()},
|
341 |
+
"request_arrivals": {k: [t.isoformat() for t in v] for k, v in middleware.request_arrivals.items()}
|
342 |
+
}
|
343 |
+
return JSONResponse(content=stats)
|
344 |
+
return {"error": "StatsMiddleware not found"}
|
|
|
|
|
345 |
|
346 |
# async def on_fetch(request, env):
|
347 |
# import asgi
|