✨ Feature: Add feature: Support for counting model usage in the stats endpoint
Browse files
main.py
CHANGED
@@ -10,6 +10,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
|
10 |
from fastapi import FastAPI, HTTPException, Depends, Request
|
11 |
from fastapi.responses import StreamingResponse, JSONResponse
|
12 |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
|
|
13 |
|
14 |
from models import RequestModel, ImageGenerationRequest
|
15 |
from request import get_payload
|
@@ -70,6 +71,14 @@ from datetime import timedelta
|
|
70 |
import json
|
71 |
import aiofiles
|
72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
class StatsMiddleware(BaseHTTPMiddleware):
|
74 |
def __init__(self, app, exclude_paths=None, save_interval=3600, filename="stats.json"):
|
75 |
super().__init__(app)
|
@@ -78,6 +87,7 @@ class StatsMiddleware(BaseHTTPMiddleware):
|
|
78 |
self.ip_counts = defaultdict(lambda: defaultdict(int))
|
79 |
self.request_arrivals = defaultdict(list)
|
80 |
self.channel_success_counts = defaultdict(int)
|
|
|
81 |
self.channel_failure_counts = defaultdict(int)
|
82 |
self.lock = asyncio.Lock()
|
83 |
self.exclude_paths = set(exclude_paths or [])
|
@@ -91,6 +101,20 @@ class StatsMiddleware(BaseHTTPMiddleware):
|
|
91 |
async def dispatch(self, request: Request, call_next):
|
92 |
arrival_time = datetime.now()
|
93 |
start_time = time()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
response = await call_next(request)
|
95 |
process_time = time() - start_time
|
96 |
|
@@ -103,6 +127,7 @@ class StatsMiddleware(BaseHTTPMiddleware):
|
|
103 |
self.request_times[endpoint] += process_time
|
104 |
self.ip_counts[endpoint][client_ip] += 1
|
105 |
self.request_arrivals[endpoint].append(arrival_time)
|
|
|
106 |
|
107 |
return response
|
108 |
|
@@ -121,6 +146,7 @@ class StatsMiddleware(BaseHTTPMiddleware):
|
|
121 |
stats = {
|
122 |
"request_counts": dict(self.request_counts),
|
123 |
"request_times": dict(self.request_times),
|
|
|
124 |
"ip_counts": {k: dict(v) for k, v in self.ip_counts.items()},
|
125 |
"request_arrivals": {k: [t.isoformat() for t in v] for k, v in self.request_arrivals.items()},
|
126 |
"channel_success_counts": dict(self.channel_success_counts),
|
@@ -553,6 +579,7 @@ async def get_stats(request: Request, token: str = Depends(verify_admin_api_key)
|
|
553 |
stats = {
|
554 |
"channel_success_percentages": middleware.calculate_success_percentages(),
|
555 |
"channel_failure_percentages": middleware.calculate_failure_percentages(),
|
|
|
556 |
"request_counts": dict(middleware.request_counts),
|
557 |
"request_times": dict(middleware.request_times),
|
558 |
"ip_counts": {k: dict(v) for k, v in middleware.ip_counts.items()},
|
|
|
10 |
from fastapi import FastAPI, HTTPException, Depends, Request
|
11 |
from fastapi.responses import StreamingResponse, JSONResponse
|
12 |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
13 |
+
from fastapi.exceptions import RequestValidationError
|
14 |
|
15 |
from models import RequestModel, ImageGenerationRequest
|
16 |
from request import get_payload
|
|
|
71 |
import json
|
72 |
import aiofiles
|
73 |
|
74 |
+
async def parse_request_body(request: Request):
|
75 |
+
if request.method == "POST" and "application/json" in request.headers.get("content-type", ""):
|
76 |
+
try:
|
77 |
+
return await request.json()
|
78 |
+
except json.JSONDecodeError:
|
79 |
+
return None
|
80 |
+
return None
|
81 |
+
|
82 |
class StatsMiddleware(BaseHTTPMiddleware):
|
83 |
def __init__(self, app, exclude_paths=None, save_interval=3600, filename="stats.json"):
|
84 |
super().__init__(app)
|
|
|
87 |
self.ip_counts = defaultdict(lambda: defaultdict(int))
|
88 |
self.request_arrivals = defaultdict(list)
|
89 |
self.channel_success_counts = defaultdict(int)
|
90 |
+
self.model_counts = defaultdict(int)
|
91 |
self.channel_failure_counts = defaultdict(int)
|
92 |
self.lock = asyncio.Lock()
|
93 |
self.exclude_paths = set(exclude_paths or [])
|
|
|
101 |
async def dispatch(self, request: Request, call_next):
|
102 |
arrival_time = datetime.now()
|
103 |
start_time = time()
|
104 |
+
|
105 |
+
# 使用依赖注入获取预解析的请求体
|
106 |
+
request.state.parsed_body = await parse_request_body(request)
|
107 |
+
|
108 |
+
model = "unknown"
|
109 |
+
if request.state.parsed_body:
|
110 |
+
try:
|
111 |
+
request_model = RequestModel(**request.state.parsed_body)
|
112 |
+
model = request_model.model
|
113 |
+
except RequestValidationError:
|
114 |
+
pass
|
115 |
+
except Exception as e:
|
116 |
+
logger.error(f"Error processing request: {str(e)}")
|
117 |
+
|
118 |
response = await call_next(request)
|
119 |
process_time = time() - start_time
|
120 |
|
|
|
127 |
self.request_times[endpoint] += process_time
|
128 |
self.ip_counts[endpoint][client_ip] += 1
|
129 |
self.request_arrivals[endpoint].append(arrival_time)
|
130 |
+
self.model_counts[model] += 1
|
131 |
|
132 |
return response
|
133 |
|
|
|
146 |
stats = {
|
147 |
"request_counts": dict(self.request_counts),
|
148 |
"request_times": dict(self.request_times),
|
149 |
+
"model_counts": dict(self.model_counts),
|
150 |
"ip_counts": {k: dict(v) for k, v in self.ip_counts.items()},
|
151 |
"request_arrivals": {k: [t.isoformat() for t in v] for k, v in self.request_arrivals.items()},
|
152 |
"channel_success_counts": dict(self.channel_success_counts),
|
|
|
579 |
stats = {
|
580 |
"channel_success_percentages": middleware.calculate_success_percentages(),
|
581 |
"channel_failure_percentages": middleware.calculate_failure_percentages(),
|
582 |
+
"model_counts": dict(middleware.model_counts),
|
583 |
"request_counts": dict(middleware.request_counts),
|
584 |
"request_times": dict(middleware.request_times),
|
585 |
"ip_counts": {k: dict(v) for k, v in middleware.ip_counts.items()},
|