yym68686 commited on
Commit
819abc0
·
1 Parent(s): 2a7fbb2

Add traffic middleware

Browse files
Files changed (2) hide show
  1. log_config.py +2 -1
  2. main.py +114 -48
log_config.py CHANGED
@@ -2,4 +2,5 @@ import logging
2
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
3
  logger = logging.getLogger("uni-api")
4
 
5
- logging.getLogger("httpx").setLevel(logging.CRITICAL)
 
 
2
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
3
  logger = logging.getLogger("uni-api")
4
 
5
+ logging.getLogger("httpx").setLevel(logging.CRITICAL)
6
+ logging.getLogger("watchfiles.main").setLevel(logging.CRITICAL)
main.py CHANGED
@@ -5,14 +5,14 @@ import secrets
5
  from contextlib import asynccontextmanager
6
 
7
  from fastapi.middleware.cors import CORSMiddleware
8
- from fastapi import FastAPI, HTTPException, Depends
9
  from fastapi.responses import StreamingResponse, JSONResponse
10
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
11
 
12
  from models import RequestModel, ImageGenerationRequest
13
- from utils import error_handling_wrapper, get_all_models, post_all_models, load_config
14
  from request import get_payload
15
  from response import fetch_response, fetch_response_stream
 
16
 
17
  from typing import List, Dict, Union
18
  from urllib.parse import urlparse
@@ -42,36 +42,91 @@ async def lifespan(app: FastAPI):
42
 
43
  app = FastAPI(lifespan=lifespan)
44
 
45
- # from time import time
46
- # from collections import defaultdict
47
- # import asyncio
48
-
49
- # class StatsMiddleware:
50
- # def __init__(self):
51
- # self.request_counts = defaultdict(int)
52
- # self.request_times = defaultdict(float)
53
- # self.ip_counts = defaultdict(lambda: defaultdict(int))
54
- # self.lock = asyncio.Lock()
55
-
56
- # async def __call__(self, request: Request, call_next):
57
- # start_time = time()
58
- # response = await call_next(request)
59
- # process_time = time() - start_time
60
-
61
- # endpoint = f"{request.method} {request.url.path}"
62
- # client_ip = request.client.host
63
-
64
- # async with self.lock:
65
- # self.request_counts[endpoint] += 1
66
- # self.request_times[endpoint] += process_time
67
- # self.ip_counts[endpoint][client_ip] += 1
68
-
69
- # return response
70
- # # 创建 StatsMiddleware 实例
71
- # stats_middleware = StatsMiddleware()
72
-
73
- # # 添加 StatsMiddleware
74
- # app.add_middleware(StatsMiddleware)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
  # 配置 CORS 中间件
77
  app.add_middleware(
@@ -82,6 +137,8 @@ app.add_middleware(
82
  allow_headers=["*"], # 允许所有头部字段
83
  )
84
 
 
 
85
  async def process_request(request: Union[RequestModel, ImageGenerationRequest], provider: Dict, endpoint=None):
86
  url = provider['base_url']
87
  parsed_url = urlparse(url)
@@ -233,6 +290,17 @@ def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)
233
  raise HTTPException(status_code=403, detail="Invalid or missing API Key")
234
  return token
235
 
 
 
 
 
 
 
 
 
 
 
 
236
  @app.post("/v1/chat/completions")
237
  async def request_model(request: Union[RequestModel, ImageGenerationRequest], token: str = Depends(verify_api_key)):
238
  return await model_handler.request_model(request, token)
@@ -258,24 +326,22 @@ async def images_generations(
258
 
259
  @app.get("/generate-api-key")
260
  def generate_api_key():
261
- api_key = "sk-" + secrets.token_urlsafe(32)
262
  return JSONResponse(content={"api_key": api_key})
263
 
264
- # @app.get("/stats")
265
- # async def get_stats(token: str = Depends(verify_api_key)):
266
- # async with stats_middleware.lock:
267
- # return {
268
- # "request_counts": dict(stats_middleware.request_counts),
269
- # "average_request_times": {
270
- # endpoint: total_time / count
271
- # for endpoint, total_time in stats_middleware.request_times.items()
272
- # for count in [stats_middleware.request_counts[endpoint]]
273
- # },
274
- # "ip_counts": {
275
- # endpoint: dict(ips)
276
- # for endpoint, ips in stats_middleware.ip_counts.items()
277
- # }
278
- # }
279
 
280
  # async def on_fetch(request, env):
281
  # import asgi
 
5
  from contextlib import asynccontextmanager
6
 
7
  from fastapi.middleware.cors import CORSMiddleware
8
+ from fastapi import FastAPI, HTTPException, Depends, Request
9
  from fastapi.responses import StreamingResponse, JSONResponse
10
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
11
 
12
  from models import RequestModel, ImageGenerationRequest
 
13
  from request import get_payload
14
  from response import fetch_response, fetch_response_stream
15
+ from utils import error_handling_wrapper, post_all_models, load_config
16
 
17
  from typing import List, Dict, Union
18
  from urllib.parse import urlparse
 
42
 
43
  app = FastAPI(lifespan=lifespan)
44
 
45
+ import asyncio
46
+ from time import time
47
+ from collections import defaultdict
48
+ from starlette.middleware.base import BaseHTTPMiddleware
49
+ from datetime import datetime
50
+ from datetime import timedelta
51
+ import json
52
+ import aiofiles
53
+
54
+ class StatsMiddleware(BaseHTTPMiddleware):
55
+ def __init__(self, app, exclude_paths=None, save_interval=3600, filename="stats.json"):
56
+ super().__init__(app)
57
+ self.request_counts = defaultdict(int)
58
+ self.request_times = defaultdict(float)
59
+ self.ip_counts = defaultdict(lambda: defaultdict(int))
60
+ self.request_arrivals = defaultdict(list)
61
+ self.lock = asyncio.Lock()
62
+ self.exclude_paths = set(exclude_paths or [])
63
+ self.save_interval = save_interval
64
+ self.filename = filename
65
+ self.last_save_time = time()
66
+
67
+ # 启动定期保存和清理任务
68
+ asyncio.create_task(self.periodic_save_and_cleanup())
69
+
70
+ async def dispatch(self, request: Request, call_next):
71
+ arrival_time = datetime.now()
72
+ start_time = time()
73
+ response = await call_next(request)
74
+ process_time = time() - start_time
75
+
76
+ endpoint = f"{request.method} {request.url.path}"
77
+ client_ip = request.client.host
78
+
79
+ if request.url.path not in self.exclude_paths:
80
+ async with self.lock:
81
+ self.request_counts[endpoint] += 1
82
+ self.request_times[endpoint] += process_time
83
+ self.ip_counts[endpoint][client_ip] += 1
84
+ self.request_arrivals[endpoint].append(arrival_time)
85
+
86
+ return response
87
+
88
+ async def periodic_save_and_cleanup(self):
89
+ while True:
90
+ await asyncio.sleep(self.save_interval)
91
+ await self.save_stats()
92
+ await self.cleanup_old_data()
93
+
94
+ async def save_stats(self):
95
+ current_time = time()
96
+ if current_time - self.last_save_time < self.save_interval:
97
+ return
98
+
99
+ async with self.lock:
100
+ stats = {
101
+ "request_counts": dict(self.request_counts),
102
+ "request_times": dict(self.request_times),
103
+ "ip_counts": {k: dict(v) for k, v in self.ip_counts.items()},
104
+ "request_arrivals": {k: [t.isoformat() for t in v] for k, v in self.request_arrivals.items()}
105
+ }
106
+
107
+ filename = self.filename
108
+ async with aiofiles.open(filename, mode='w') as f:
109
+ await f.write(json.dumps(stats, indent=2))
110
+
111
+ self.last_save_time = current_time
112
+ # print(f"Stats saved to {filename}")
113
+
114
+ async def cleanup_old_data(self):
115
+ # cutoff_time = datetime.now() - timedelta(seconds=30)
116
+ cutoff_time = datetime.now() - timedelta(hours=24)
117
+ async with self.lock:
118
+ for endpoint in list(self.request_arrivals.keys()):
119
+ self.request_arrivals[endpoint] = [
120
+ t for t in self.request_arrivals[endpoint] if t > cutoff_time
121
+ ]
122
+ if not self.request_arrivals[endpoint]:
123
+ del self.request_arrivals[endpoint]
124
+ self.request_counts.pop(endpoint, None)
125
+ self.request_times.pop(endpoint, None)
126
+ self.ip_counts.pop(endpoint, None)
127
+
128
+ async def cleanup(self):
129
+ await self.save_stats()
130
 
131
  # 配置 CORS 中间件
132
  app.add_middleware(
 
137
  allow_headers=["*"], # 允许所有头部字段
138
  )
139
 
140
+ app.add_middleware(StatsMiddleware, exclude_paths=["/stats", "/generate-api-key"])
141
+
142
  async def process_request(request: Union[RequestModel, ImageGenerationRequest], provider: Dict, endpoint=None):
143
  url = provider['base_url']
144
  parsed_url = urlparse(url)
 
290
  raise HTTPException(status_code=403, detail="Invalid or missing API Key")
291
  return token
292
 
293
+ def verify_admin_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
294
+ api_list = app.state.api_list
295
+ token = credentials.credentials
296
+ if token not in api_list:
297
+ raise HTTPException(status_code=403, detail="Invalid or missing API Key")
298
+ for api_key in app.state.api_keys_db:
299
+ if api_key['api'] == token:
300
+ if api_key.get('role') != "admin":
301
+ raise HTTPException(status_code=403, detail="Permission denied")
302
+ return token
303
+
304
  @app.post("/v1/chat/completions")
305
  async def request_model(request: Union[RequestModel, ImageGenerationRequest], token: str = Depends(verify_api_key)):
306
  return await model_handler.request_model(request, token)
 
326
 
327
  @app.get("/generate-api-key")
328
  def generate_api_key():
329
+ api_key = "sk-" + secrets.token_urlsafe(36)
330
  return JSONResponse(content={"api_key": api_key})
331
 
332
+ @app.get("/stats")
333
+ async def get_stats(request: Request, token: str = Depends(verify_admin_api_key)):
334
+ middleware = app.middleware_stack.app
335
+ if isinstance(middleware, StatsMiddleware):
336
+ async with middleware.lock:
337
+ stats = {
338
+ "request_counts": dict(middleware.request_counts),
339
+ "request_times": dict(middleware.request_times),
340
+ "ip_counts": {k: dict(v) for k, v in middleware.ip_counts.items()},
341
+ "request_arrivals": {k: [t.isoformat() for t in v] for k, v in middleware.request_arrivals.items()}
342
+ }
343
+ return JSONResponse(content=stats)
344
+ return {"error": "StatsMiddleware not found"}
 
 
345
 
346
  # async def on_fetch(request, env):
347
  # import asgi