yym68686 commited on
Commit
5eb10aa
·
1 Parent(s): bdfeca1

✨ Feature: Add feature: Support for counting model usage in the stats endpoint

Browse files
Files changed (1) hide show
  1. main.py +27 -0
main.py CHANGED
@@ -10,6 +10,7 @@ from fastapi.middleware.cors import CORSMiddleware
10
  from fastapi import FastAPI, HTTPException, Depends, Request
11
  from fastapi.responses import StreamingResponse, JSONResponse
12
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
 
13
 
14
  from models import RequestModel, ImageGenerationRequest
15
  from request import get_payload
@@ -70,6 +71,14 @@ from datetime import timedelta
70
  import json
71
  import aiofiles
72
 
 
 
 
 
 
 
 
 
73
  class StatsMiddleware(BaseHTTPMiddleware):
74
  def __init__(self, app, exclude_paths=None, save_interval=3600, filename="stats.json"):
75
  super().__init__(app)
@@ -78,6 +87,7 @@ class StatsMiddleware(BaseHTTPMiddleware):
78
  self.ip_counts = defaultdict(lambda: defaultdict(int))
79
  self.request_arrivals = defaultdict(list)
80
  self.channel_success_counts = defaultdict(int)
 
81
  self.channel_failure_counts = defaultdict(int)
82
  self.lock = asyncio.Lock()
83
  self.exclude_paths = set(exclude_paths or [])
@@ -91,6 +101,20 @@ class StatsMiddleware(BaseHTTPMiddleware):
91
  async def dispatch(self, request: Request, call_next):
92
  arrival_time = datetime.now()
93
  start_time = time()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  response = await call_next(request)
95
  process_time = time() - start_time
96
 
@@ -103,6 +127,7 @@ class StatsMiddleware(BaseHTTPMiddleware):
103
  self.request_times[endpoint] += process_time
104
  self.ip_counts[endpoint][client_ip] += 1
105
  self.request_arrivals[endpoint].append(arrival_time)
 
106
 
107
  return response
108
 
@@ -121,6 +146,7 @@ class StatsMiddleware(BaseHTTPMiddleware):
121
  stats = {
122
  "request_counts": dict(self.request_counts),
123
  "request_times": dict(self.request_times),
 
124
  "ip_counts": {k: dict(v) for k, v in self.ip_counts.items()},
125
  "request_arrivals": {k: [t.isoformat() for t in v] for k, v in self.request_arrivals.items()},
126
  "channel_success_counts": dict(self.channel_success_counts),
@@ -553,6 +579,7 @@ async def get_stats(request: Request, token: str = Depends(verify_admin_api_key)
553
  stats = {
554
  "channel_success_percentages": middleware.calculate_success_percentages(),
555
  "channel_failure_percentages": middleware.calculate_failure_percentages(),
 
556
  "request_counts": dict(middleware.request_counts),
557
  "request_times": dict(middleware.request_times),
558
  "ip_counts": {k: dict(v) for k, v in middleware.ip_counts.items()},
 
10
  from fastapi import FastAPI, HTTPException, Depends, Request
11
  from fastapi.responses import StreamingResponse, JSONResponse
12
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
13
+ from fastapi.exceptions import RequestValidationError
14
 
15
  from models import RequestModel, ImageGenerationRequest
16
  from request import get_payload
 
71
  import json
72
  import aiofiles
73
 
74
+ async def parse_request_body(request: Request):
75
+ if request.method == "POST" and "application/json" in request.headers.get("content-type", ""):
76
+ try:
77
+ return await request.json()
78
+ except json.JSONDecodeError:
79
+ return None
80
+ return None
81
+
82
  class StatsMiddleware(BaseHTTPMiddleware):
83
  def __init__(self, app, exclude_paths=None, save_interval=3600, filename="stats.json"):
84
  super().__init__(app)
 
87
  self.ip_counts = defaultdict(lambda: defaultdict(int))
88
  self.request_arrivals = defaultdict(list)
89
  self.channel_success_counts = defaultdict(int)
90
+ self.model_counts = defaultdict(int)
91
  self.channel_failure_counts = defaultdict(int)
92
  self.lock = asyncio.Lock()
93
  self.exclude_paths = set(exclude_paths or [])
 
101
  async def dispatch(self, request: Request, call_next):
102
  arrival_time = datetime.now()
103
  start_time = time()
104
+
105
+ # 使用依赖注入获取预解析的请求体
106
+ request.state.parsed_body = await parse_request_body(request)
107
+
108
+ model = "unknown"
109
+ if request.state.parsed_body:
110
+ try:
111
+ request_model = RequestModel(**request.state.parsed_body)
112
+ model = request_model.model
113
+ except RequestValidationError:
114
+ pass
115
+ except Exception as e:
116
+ logger.error(f"Error processing request: {str(e)}")
117
+
118
  response = await call_next(request)
119
  process_time = time() - start_time
120
 
 
127
  self.request_times[endpoint] += process_time
128
  self.ip_counts[endpoint][client_ip] += 1
129
  self.request_arrivals[endpoint].append(arrival_time)
130
+ self.model_counts[model] += 1
131
 
132
  return response
133
 
 
146
  stats = {
147
  "request_counts": dict(self.request_counts),
148
  "request_times": dict(self.request_times),
149
+ "model_counts": dict(self.model_counts),
150
  "ip_counts": {k: dict(v) for k, v in self.ip_counts.items()},
151
  "request_arrivals": {k: [t.isoformat() for t in v] for k, v in self.request_arrivals.items()},
152
  "channel_success_counts": dict(self.channel_success_counts),
 
579
  stats = {
580
  "channel_success_percentages": middleware.calculate_success_percentages(),
581
  "channel_failure_percentages": middleware.calculate_failure_percentages(),
582
+ "model_counts": dict(middleware.model_counts),
583
  "request_counts": dict(middleware.request_counts),
584
  "request_times": dict(middleware.request_times),
585
  "ip_counts": {k: dict(v) for k, v in middleware.ip_counts.items()},