win7win committed on
Commit
65dd154
·
verified ·
1 Parent(s): 2f4cba4

Upload 10 files

Browse files
app/__init__.py ADDED
File without changes
app/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (116 Bytes). View file
 
app/__pycache__/gemini.cpython-39.pyc ADDED
Binary file (3.02 kB). View file
 
app/__pycache__/main.cpython-39.pyc ADDED
Binary file (7.65 kB). View file
 
app/__pycache__/models.cpython-39.pyc ADDED
Binary file (2.32 kB). View file
 
app/__pycache__/utils.cpython-39.pyc ADDED
Binary file (2.42 kB). View file
 
app/gemini.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import json
3
+ import os
4
+ import asyncio
5
+ from app.models import ChatCompletionRequest, Message # 相对导入
6
+ from dataclasses import dataclass
7
+ from typing import Optional, Dict, Any, List
8
+ import httpx
9
+
10
+
11
@dataclass
class GeneratedText:
    """A piece of model-generated text plus the reason generation stopped.

    NOTE(review): not referenced elsewhere in this chunk — possibly kept for
    external callers; confirm before removing.
    """
    text: str
    # Finish reason as reported by the API; None when it supplied no reason.
    finish_reason: Optional[str] = None
15
+
16
+
17
class ResponseWrapper:
    """Read-only view over a raw Gemini ``generateContent`` JSON response.

    All fields (text, thoughts, finish reason, token counters) are extracted
    eagerly in ``__init__`` and exposed via properties.  Missing or malformed
    sections degrade to ``""``/``None`` instead of raising.
    """

    def __init__(self, data: Dict[Any, Any]):
        """Parse *data*, the decoded JSON body of a generateContent response."""
        self._data = data
        self._text = self._extract_text()
        self._finish_reason = self._extract_finish_reason()
        self._prompt_token_count = self._extract_prompt_token_count()
        self._candidates_token_count = self._extract_candidates_token_count()
        self._total_token_count = self._extract_total_token_count()
        self._thoughts = self._extract_thoughts()
        # Pretty-printed copy of the raw payload, handy for debug logging.
        self._json_dumps = json.dumps(self._data, indent=4, ensure_ascii=False)

    def _parts(self) -> List[Dict[str, Any]]:
        """Return the parts list of the first candidate, or [] when absent.

        TypeError is included so a null/non-dict payload degrades gracefully.
        """
        try:
            return self._data['candidates'][0]['content']['parts']
        except (KeyError, IndexError, TypeError):
            return []

    def _extract_thoughts(self) -> Optional[str]:
        # Join ALL "thought" parts. The previous version returned only the
        # first one, silently dropping any additional thought parts.
        return "".join(p.get('text', '') for p in self._parts() if 'thought' in p)

    def _extract_text(self) -> str:
        # Join ALL non-"thought" parts. The previous version returned only
        # the first one, truncating multi-part responses.
        return "".join(p.get('text', '') for p in self._parts() if 'thought' not in p)

    def _extract_finish_reason(self) -> Optional[str]:
        try:
            return self._data['candidates'][0].get('finishReason')
        except (KeyError, IndexError, TypeError):
            return None

    def _usage_field(self, field: str) -> Optional[int]:
        """Read one counter out of ``usageMetadata``; None when missing."""
        try:
            return self._data['usageMetadata'].get(field)
        except (KeyError, TypeError):
            return None

    def _extract_prompt_token_count(self) -> Optional[int]:
        return self._usage_field('promptTokenCount')

    def _extract_candidates_token_count(self) -> Optional[int]:
        return self._usage_field('candidatesTokenCount')

    def _extract_total_token_count(self) -> Optional[int]:
        return self._usage_field('totalTokenCount')

    @property
    def text(self) -> str:
        return self._text

    @property
    def finish_reason(self) -> Optional[str]:
        return self._finish_reason

    @property
    def prompt_token_count(self) -> Optional[int]:
        return self._prompt_token_count

    @property
    def candidates_token_count(self) -> Optional[int]:
        return self._candidates_token_count

    @property
    def total_token_count(self) -> Optional[int]:
        return self._total_token_count

    @property
    def thoughts(self) -> Optional[str]:
        return self._thoughts

    @property
    def json_dumps(self) -> str:
        return self._json_dumps
97
+
98
+
99
class GeminiClient:
    """Minimal client for the Google Generative Language (Gemini) API.

    Provides an async SSE streaming call, a blocking completion call, an
    OpenAI→Gemini message-format converter, and a model-listing helper.
    """

    # Populated at application startup from list_available_models().
    AVAILABLE_MODELS = []
    # Extra model names injected via the EXTRA_MODELS env var (comma separated).
    # Blank entries are filtered out — previously an unset variable produced
    # [""], leaking an empty-string model into the list.
    EXTRA_MODELS = [m.strip() for m in os.environ.get("EXTRA_MODELS", "").split(",") if m.strip()]

    def __init__(self, api_key: str):
        self.api_key = api_key

    def _build_payload(self, request, contents, safety_settings, system_instruction):
        """Assemble the JSON body shared by streaming and non-streaming calls."""
        payload = {
            "contents": contents,
            "generationConfig": {
                "temperature": request.temperature,
                "maxOutputTokens": request.max_tokens,
            },
            "safetySettings": safety_settings,
        }
        if system_instruction:
            payload["system_instruction"] = system_instruction
        return payload

    async def stream_chat(self, request: "ChatCompletionRequest", contents, safety_settings, system_instruction):
        """Yield text fragments from a streaming generateContent (SSE) call."""
        # "Thinking" models are only exposed on the v1alpha endpoint.
        api_version = "v1alpha" if "think" in request.model else "v1beta"
        url = f"https://generativelanguage.googleapis.com/{api_version}/models/{request.model}:streamGenerateContent?key={self.api_key}&alt=sse"
        headers = {
            "Content-Type": "application/json",
        }
        payload = self._build_payload(request, contents, safety_settings, system_instruction)

        async with httpx.AsyncClient() as client:
            async with client.stream("POST", url, headers=headers, json=payload, timeout=600) as response:
                buffer = b""  # accumulates partial JSON spread across SSE lines
                async for line in response.aiter_lines():
                    if not line.startswith("data: "):
                        continue
                    buffer += line[len("data: "):].encode('utf-8')
                    try:
                        # 'chunk' — the previous version reused the name 'data',
                        # shadowing the request payload above.
                        chunk = json.loads(buffer.decode('utf-8'))
                    except json.JSONDecodeError:
                        # Incomplete JSON — keep buffering until it parses.
                        continue
                    except Exception as e:
                        print(f"Error parsing JSON: {e}")
                        continue
                    buffer = b""
                    try:
                        parts = chunk['candidates'][0]['content']['parts']
                    except (KeyError, IndexError, TypeError):
                        continue
                    text = "".join(part.get('text', '') for part in parts)
                    if text:
                        yield text

    def complete_chat(self, request: "ChatCompletionRequest", contents, safety_settings, system_instruction):
        """Blocking (non-streaming) generateContent call.

        Returns a ResponseWrapper; raises requests.HTTPError on non-2xx.
        """
        api_version = "v1alpha" if "think" in request.model else "v1beta"
        url = f"https://generativelanguage.googleapis.com/{api_version}/models/{request.model}:generateContent?key={self.api_key}"
        headers = {
            "Content-Type": "application/json",
        }
        payload = self._build_payload(request, contents, safety_settings, system_instruction)
        # Timeout added so a hung upstream call cannot block a worker thread
        # forever; matches the 600 s budget used by stream_chat.
        response = requests.post(url, headers=headers, json=payload, timeout=600)
        response.raise_for_status()
        return ResponseWrapper(response.json())

    def convert_messages(self, messages, use_system_prompt=False):
        """Convert OpenAI-style messages into Gemini ``contents``.

        Returns a list of error strings when any message was invalid,
        otherwise a tuple ``(gemini_history, system_instruction_dict)``.
        When *use_system_prompt* is true, leading 'system' messages are
        collected into the system instruction instead of the history.
        """
        gemini_history = []
        errors = []
        system_instruction_text = ""
        is_system_phase = use_system_prompt
        for message in messages:
            role = message.role
            content = message.content

            if isinstance(content, str):
                if is_system_phase and role == 'system':
                    # Merge consecutive leading system prompts into one block.
                    if system_instruction_text:
                        system_instruction_text += "\n" + content
                    else:
                        system_instruction_text = content
                else:
                    is_system_phase = False

                    if role in ['user', 'system']:
                        role_to_use = 'user'
                    elif role == 'assistant':
                        role_to_use = 'model'
                    else:
                        errors.append(f"Invalid role: {role}")
                        continue

                    # Gemini expects alternating roles: merge same-role runs.
                    if gemini_history and gemini_history[-1]['role'] == role_to_use:
                        gemini_history[-1]['parts'].append({"text": content})
                    else:
                        gemini_history.append(
                            {"role": role_to_use, "parts": [{"text": content}]})
            elif isinstance(content, list):
                # Multimodal content: list of {"type": "text"|"image_url", ...}.
                parts = []
                for item in content:
                    if item.get('type') == 'text':
                        parts.append({"text": item.get('text')})
                    elif item.get('type') == 'image_url':
                        image_data = item.get('image_url', {}).get('url', '')
                        if image_data.startswith('data:image/'):
                            try:
                                # data URI: "data:image/png;base64,<payload>"
                                mime_type = image_data.split(';')[0].split(':')[1]
                                base64_data = image_data.split(',')[1]
                                parts.append({
                                    "inline_data": {
                                        "mime_type": mime_type,
                                        "data": base64_data
                                    }
                                })
                            except (IndexError, ValueError):
                                errors.append(
                                    f"Invalid data URI for image: {image_data}")
                        else:
                            errors.append(
                                f"Invalid image URL format for item: {item}")

                if parts:
                    if role in ['user', 'system']:
                        role_to_use = 'user'
                    elif role == 'assistant':
                        role_to_use = 'model'
                    else:
                        errors.append(f"Invalid role: {role}")
                        continue
                    if gemini_history and gemini_history[-1]['role'] == role_to_use:
                        gemini_history[-1]['parts'].extend(parts)
                    else:
                        gemini_history.append(
                            {"role": role_to_use, "parts": parts})
        if errors:
            return errors
        else:
            return gemini_history, {"parts": [{"text": system_instruction_text}]}

    @staticmethod
    async def list_available_models(api_key) -> list:
        """Fetch the model catalogue for *api_key*, plus EXTRA_MODELS entries."""
        url = "https://generativelanguage.googleapis.com/v1beta/models?key={}".format(
            api_key)
        async with httpx.AsyncClient() as client:
            response = await client.get(url)
            response.raise_for_status()
            data = response.json()
        models = [model["name"] for model in data.get("models", [])]
        models.extend(GeminiClient.EXTRA_MODELS)
        return models
app/main.py ADDED
@@ -0,0 +1,350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException, Request, Depends, status
2
+ from fastapi.responses import JSONResponse, StreamingResponse
3
+ from .models import ChatCompletionRequest, ChatCompletionResponse, ErrorResponse, ModelList
4
+ from .gemini import GeminiClient, ResponseWrapper
5
+ from .utils import handle_gemini_error, protect_from_abuse, APIKeyManager, test_api_key
6
+ import os
7
+ import json
8
+ import asyncio
9
+ from typing import Literal
10
+ import random
11
+ import requests
12
+ from datetime import datetime, timedelta
13
+ from apscheduler.schedulers.background import BackgroundScheduler
14
+ import sys
15
+
16
+
17
DEBUG = os.environ.get("DEBUG", "false").lower() == "true"
LOG_FORMAT_DEBUG = '%(asctime)s - %(levelname)s - [%(key)s]-%(request_type)s-[%(model)s]-%(status_code)s: %(message)s - %(error_message)s'
LOG_FORMAT_NORMAL = '[%(key)s]-%(request_type)s-[%(model)s]-%(status_code)s: %(message)s'


def format_log_message(level, message, extra=None):
    """Render a log line in the legacy ``logging``-style format.

    *extra* may carry 'key', 'request_type', 'model', 'status_code' and
    'error_message'; missing fields fall back to placeholders.
    """
    def pick(field, default):
        # A falsy extra (None or {}) yields the default — same behaviour as
        # the original per-field conditional expressions.
        return extra.get(field, default) if extra else default

    values = {
        'asctime': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        'levelname': level,
        'key': pick('key', 'N/A'),
        'request_type': pick('request_type', 'N/A'),
        'model': pick('model', 'N/A'),
        'status_code': pick('status_code', 'N/A'),
        'error_message': pick('error_message', ''),
        'message': message,
    }
    # The verbose format is only used when the DEBUG env var is "true".
    return (LOG_FORMAT_DEBUG if DEBUG else LOG_FORMAT_NORMAL) % values
36
+
37
+
38
def translate_error(message: str) -> str:
    """Map well-known Gemini API error phrases to Chinese descriptions.

    Matching is case-insensitive and checked in a fixed order; unknown
    messages are returned unchanged.
    """
    lowered = message.lower()
    translations = (
        ("quota exceeded", "API 密钥配额已用尽"),
        ("invalid argument", "无效参数"),
        ("internal server error", "服务器内部错误"),
        ("service unavailable", "服务不可用"),
    )
    for needle, translated in translations:
        if needle in lowered:
            return translated
    return message
48
+
49
+
50
def handle_exception(exc_type, exc_value, exc_traceback):
    """Top-level excepthook: print uncaught exceptions in the app log format.

    KeyboardInterrupt is delegated to the *default* interpreter hook so
    Ctrl-C still terminates normally.  Bug fix: the original delegated to
    ``sys.excepthook`` — which is rebound to this very function below —
    causing infinite recursion; ``sys.__excepthook__`` is the saved default.
    """
    if issubclass(exc_type, KeyboardInterrupt):
        sys.__excepthook__(exc_type, exc_value, exc_traceback)
        return
    error_message = translate_error(str(exc_value))
    log_msg = format_log_message('ERROR', f"未捕获的异常: {error_message}", extra={'status_code': 500, 'error_message': error_message})
    print(log_msg)
57
+
58
+
59
# Route all uncaught exceptions through the custom hook defined above.
sys.excepthook = handle_exception

app = FastAPI()

# Bearer token clients must present; defaults to "123" when unset.
PASSWORD = os.environ.get("PASSWORD", "123")
# Abuse-protection limits enforced by protect_from_abuse().
MAX_REQUESTS_PER_MINUTE = int(os.environ.get("MAX_REQUESTS_PER_MINUTE", "30"))
MAX_REQUESTS_PER_DAY_PER_IP = int(
    os.environ.get("MAX_REQUESTS_PER_DAY_PER_IP", "600"))
# 'MaxRetries' may be set but blank — fall back to 3 in that case.
MAX_RETRIES = int(os.environ.get('MaxRetries', '3').strip() or '3')
RETRY_DELAY = 1       # NOTE(review): not referenced in this chunk — confirm usage
MAX_RETRY_DELAY = 16  # NOTE(review): not referenced in this chunk — confirm usage
# Default safety settings: do not block any harm category.
safety_settings = [
    {
        "category": "HARM_CATEGORY_HARASSMENT",
        "threshold": "BLOCK_NONE"
    },
    {
        "category": "HARM_CATEGORY_HATE_SPEECH",
        "threshold": "BLOCK_NONE"
    },
    {
        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "threshold": "BLOCK_NONE"
    },
    {
        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
        "threshold": "BLOCK_NONE"
    },
    {
        "category": 'HARM_CATEGORY_CIVIC_INTEGRITY',
        "threshold": 'BLOCK_NONE'
    }
]
# Variant with threshold "OFF", used for the 'gemini-2.0-flash-exp' model
# (see the safety-settings selection in process_request).
safety_settings_g2 = [
    {
        "category": "HARM_CATEGORY_HARASSMENT",
        "threshold": "OFF"
    },
    {
        "category": "HARM_CATEGORY_HATE_SPEECH",
        "threshold": "OFF"
    },
    {
        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "threshold": "OFF"
    },
    {
        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
        "threshold": "OFF"
    },
    {
        "category": 'HARM_CATEGORY_CIVIC_INTEGRITY',
        "threshold": 'OFF'
    }
]

key_manager = APIKeyManager()  # the key stack is initialised inside __init__
current_api_key = key_manager.get_available_key()
117
+
118
+
119
def switch_api_key():
    """Rotate ``current_api_key`` to the next key from the manager's stack.

    Logs an error (and leaves the current key unchanged) when no key is
    available.
    """
    global current_api_key
    key = key_manager.get_available_key()  # stack handling lives in the manager
    if not key:
        log_msg = format_log_message('ERROR', "API key 替换失败,所有API key都已耗尽或被暂时禁用,请重新配置或稍后重试", extra={'key': 'N/A', 'request_type': 'switch_key', 'status_code': 'N/A'})
        print(log_msg)
        return
    current_api_key = key
    log_msg = format_log_message('INFO', f"API key 替换为 → {current_api_key[:8]}...", extra={'key': current_api_key[:8], 'request_type': 'switch_key'})
    print(log_msg)
129
+
130
+
131
async def check_keys():
    """Probe every configured key against the API and return the valid ones.

    Logs a per-key verdict; logs an error when no key passes.
    """
    available_keys = []
    for key in key_manager.api_keys:
        is_valid = await test_api_key(key)
        status_msg = "有效" if is_valid else "无效"
        print(format_log_message('INFO', f"API Key {key[:10]}... {status_msg}."))
        if is_valid:
            available_keys.append(key)
    if not available_keys:
        print(format_log_message('ERROR', "没有可用的 API 密钥!", extra={'key': 'N/A', 'request_type': 'startup', 'status_code': 'N/A'}))
    return available_keys
144
+
145
+
146
@app.on_event("startup")
async def startup_event():
    """Validate configured keys at boot, then load the model catalogue.

    Bug fix: ``MAX_RETRIES`` was assigned without a ``global`` declaration,
    so the "dynamic retry count" only ever created a local variable and the
    module-level value was never updated.
    """
    global MAX_RETRIES
    log_msg = format_log_message('INFO', "Starting Gemini API proxy...")
    print(log_msg)
    available_keys = await check_keys()
    if available_keys:
        key_manager.api_keys = available_keys
        key_manager._reset_key_stack()  # rebuild the shuffled key stack on startup
        key_manager.show_all_keys()
        log_msg = format_log_message('INFO', f"可用 API 密钥数量:{len(key_manager.api_keys)}")
        print(log_msg)
        MAX_RETRIES = len(key_manager.api_keys)  # one retry per configured key
        log_msg = format_log_message('INFO', f"最大重试次数设置为:{MAX_RETRIES}")
        print(log_msg)
        if key_manager.api_keys:
            # Fetch the catalogue with the first key and strip the
            # "models/" prefix the API prepends to each name.
            all_models = await GeminiClient.list_available_models(key_manager.api_keys[0])
            GeminiClient.AVAILABLE_MODELS = [model.replace(
                "models/", "") for model in all_models]
            log_msg = format_log_message('INFO', "Available models loaded.")
            print(log_msg)
166
+
167
@app.get("/v1/models", response_model=ModelList)
def list_models():
    """Return the cached model catalogue in OpenAI list format."""
    print(format_log_message('INFO', "Received request to list models", extra={'request_type': 'list_models', 'status_code': 200}))
    entries = [
        {"id": model, "object": "model", "created": 1678888888, "owned_by": "organization-owner"}
        for model in GeminiClient.AVAILABLE_MODELS
    ]
    return ModelList(data=entries)
172
+
173
+
174
async def verify_password(request: Request):
    """FastAPI dependency enforcing ``Authorization: Bearer <PASSWORD>``.

    Raises HTTPException(401) on a missing, malformed or wrong token.
    A falsy PASSWORD disables the check entirely.
    """
    if not PASSWORD:
        return
    auth_header = request.headers.get("Authorization")
    if not auth_header or not auth_header.startswith("Bearer "):
        raise HTTPException(
            status_code=401, detail="Unauthorized: Missing or invalid token")
    token = auth_header.split(" ")[1]
    if token != PASSWORD:
        raise HTTPException(
            status_code=401, detail="Unauthorized: Invalid token")
184
+
185
+
186
async def process_request(chat_request: ChatCompletionRequest, http_request: Request, request_type: Literal['stream', 'non-stream']):
    """Core handler: convert the request, call Gemini, retry across keys.

    Each configured API key is tried at most once per request
    (``retry_attempts`` equals the key count).  Streams SSE chunks when
    ``chat_request.stream`` is set; otherwise waits for a full completion
    while watching for client disconnect so the upstream call can be
    cancelled.
    """
    global current_api_key
    # Per-minute and per-day/IP rate limiting; raises on abuse.
    protect_from_abuse(
        http_request, MAX_REQUESTS_PER_MINUTE, MAX_REQUESTS_PER_DAY_PER_IP)
    if chat_request.model not in GeminiClient.AVAILABLE_MODELS:
        error_msg = "无效的模型"
        extra_log = {'request_type': request_type, 'model': chat_request.model, 'status_code': 400, 'error_message': error_msg}
        log_msg = format_log_message('ERROR', error_msg, extra=extra_log)
        print(log_msg)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=error_msg)

    key_manager.reset_tried_keys_for_request()  # reset the tried-keys set at the start of each request

    # NOTE(review): convert_messages is an instance method invoked with the
    # class itself passed as ``self``; it uses no instance state, so this
    # works, but it is fragile.  Also, a conversion error makes it return a
    # plain error list, which this unpacking would mishandle — confirm.
    contents, system_instruction = GeminiClient.convert_messages(
        GeminiClient, chat_request.messages)

    # Retry count equals the number of keys; at least one attempt.
    retry_attempts = len(key_manager.api_keys) if key_manager.api_keys else 1
    for attempt in range(1, retry_attempts + 1):
        # Key is unknown until fetched below, hence 'N/A' in this log line.
        extra_log_attempt_start = {'key': 'N/A', 'request_type': request_type, 'model': chat_request.model}
        log_msg_attempt_start = format_log_message('INFO', f"第 {attempt}/{retry_attempts} 次尝试 ...", extra=extra_log_attempt_start)
        print(log_msg_attempt_start)

        # A fresh key every iteration; stack logic lives in get_available_key.
        current_api_key = key_manager.get_available_key()

        if current_api_key is None:  # no key could be obtained
            log_msg_no_key = format_log_message('WARNING', "没有可用的 API 密钥,跳过本次尝试", extra={'request_type': request_type, 'model': chat_request.model, 'status_code': 'N/A'})
            print(log_msg_no_key)
            break  # nothing left to try

        # Log the attempt again, now with the chosen key prefix.
        extra_log = {'key': current_api_key[:8], 'request_type': request_type, 'model': chat_request.model}
        log_msg = format_log_message('INFO', f"第 {attempt}/{retry_attempts} 次尝试 ... 使用密钥: {current_api_key[:8]}...", extra=extra_log)
        print(log_msg)

        gemini_client = GeminiClient(current_api_key)
        try:
            if chat_request.stream:
                async def stream_generator():
                    """Wrap Gemini SSE chunks in OpenAI chunk envelopes."""
                    try:
                        # gemini-2.0-flash-exp takes the "OFF"-threshold settings.
                        async for chunk in gemini_client.stream_chat(chat_request, contents, safety_settings_g2 if 'gemini-2.0-flash-exp' in chat_request.model else safety_settings, system_instruction):
                            formatted_chunk = {"id": "chatcmpl-someid", "object": "chat.completion.chunk", "created": 1234567,
                                               "model": chat_request.model, "choices": [{"delta": {"role": "assistant", "content": chunk}, "index": 0, "finish_reason": None}]}
                            yield f"data: {json.dumps(formatted_chunk)}\n\n"
                        yield "data: [DONE]\n\n"

                    except asyncio.CancelledError:
                        # Client went away mid-stream; just log it.
                        extra_log_cancel = {'key': current_api_key[:8], 'request_type': request_type, 'model': chat_request.model, 'error_message': '客户端已断开连接'}
                        log_msg = format_log_message('INFO', "Client disconnected", extra=extra_log_cancel)
                        print(log_msg)
                    except Exception as e:
                        # Upstream failure mid-stream: emit an error event,
                        # then rotate the key for the next request.
                        error_detail = handle_gemini_error(
                            e, current_api_key, key_manager, switch_api_key)
                        log_message = f"API Key failed: {error_detail}"
                        extra_log_error = {'key': current_api_key[:8], 'request_type': request_type, 'model': chat_request.model, 'status_code': 500, 'error_message': error_detail}
                        log_msg = format_log_message('ERROR', log_message, extra=extra_log_error)
                        print(log_msg)
                        yield f"data: {json.dumps({'error': {'message': error_detail, 'type': 'gemini_error'}})}\n\n"
                        if attempt < retry_attempts:  # still have keys to try
                            switch_api_key()
                return StreamingResponse(stream_generator(), media_type="text/event-stream")
            else:
                async def run_gemini_completion():
                    """Run the blocking completion in a worker thread."""
                    try:
                        response_content = await asyncio.to_thread(gemini_client.complete_chat, chat_request, contents, safety_settings_g2 if 'gemini-2.0-flash-exp' in chat_request.model else safety_settings, system_instruction)
                        return response_content
                    except asyncio.CancelledError:
                        extra_log_gemini_cancel = {'key': current_api_key[:8], 'request_type': request_type, 'model': chat_request.model, 'error_message': 'Gemini API 调用因客户端断开连接而被取消'}
                        log_msg = format_log_message('INFO', "Gemini API call cancelled due to client disconnect", extra=extra_log_gemini_cancel)
                        print(log_msg)
                        raise

                async def check_client_disconnect():
                    """Poll for client disconnect every 0.5 s."""
                    while True:
                        if await http_request.is_disconnected():
                            extra_log_client_disconnect = {'key': current_api_key[:8], 'request_type': request_type, 'model': chat_request.model, 'error_message': '在非流式请求期间检测到客户端断开连接。正在取消 Gemini API 调用。'}
                            log_msg = format_log_message('INFO', "Client disconnected during non-streaming request. Cancelling Gemini API call.", extra=extra_log_client_disconnect)
                            print(log_msg)
                            return True
                        await asyncio.sleep(0.5)

                # Race the completion against client disconnect.
                gemini_task = asyncio.create_task(run_gemini_completion())
                disconnect_task = asyncio.create_task(check_client_disconnect())

                try:
                    done, pending = await asyncio.wait(
                        [gemini_task, disconnect_task],
                        return_when=asyncio.FIRST_COMPLETED
                    )

                    if disconnect_task in done:
                        # Client left first — cancel the upstream call.
                        gemini_task.cancel()
                        try:
                            await gemini_task
                        except asyncio.CancelledError:
                            extra_log_gemini_task_cancel = {'key': current_api_key[:8], 'request_type': request_type, 'model': chat_request.model, 'error_message': '客户端断开连接后,Gemini API 任务已成功取消。'}
                            log_msg = format_log_message('INFO', "Gemini API task successfully cancelled after client disconnect.", extra=extra_log_gemini_task_cancel)
                            print(log_msg)
                            pass
                        raise HTTPException(status_code=status.HTTP_408_REQUEST_TIMEOUT, detail="Client disconnected")

                    if gemini_task in done:
                        # Completion finished first — stop the watcher.
                        disconnect_task.cancel()
                        try:
                            await disconnect_task
                        except asyncio.CancelledError:
                            pass
                        response_content = gemini_task.result()
                        response = ChatCompletionResponse(id="chatcmpl-someid", object="chat.completion", created=1234567890, model=chat_request.model,
                                                          choices=[{"index": 0, "message": {"role": "assistant", "content": response_content.text}, "finish_reason": "stop"}])
                        extra_log_success = {'key': current_api_key[:8], 'request_type': request_type, 'model': chat_request.model, 'status_code': 200}
                        log_msg = format_log_message('INFO', "Request successful", extra=extra_log_success)
                        print(log_msg)
                        return response

                except asyncio.CancelledError:
                    extra_log_request_cancel = {'key': current_api_key[:8], 'request_type': request_type, 'model': chat_request.model, 'error_message': "请求被取消"}
                    log_msg = format_log_message('INFO', "Request cancelled", extra=extra_log_request_cancel)
                    print(log_msg)
                    raise

        except requests.exceptions.RequestException as e:
            # Network-level failure against the Gemini endpoint.
            error_detail = handle_gemini_error(
                e, current_api_key, key_manager, switch_api_key)
            extra_log_request_exception = {'key': current_api_key[:8], 'request_type': request_type, 'model': chat_request.model, 'status_code': 500, 'error_message': error_detail}
            log_msg = format_log_message('ERROR', f"{error_detail}", extra=extra_log_request_exception)
            print(log_msg)
            if attempt < retry_attempts:  # rotate key while attempts remain
                switch_api_key()
            else:
                raise HTTPException(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"{retry_attempts} 次尝试后仍然失败,请修改预设或输入")
        except Exception as e:
            # Any other failure: same rotate-or-give-up policy.
            error_detail = handle_gemini_error(
                e, current_api_key, key_manager, switch_api_key)
            extra_log_exception = {'key': current_api_key[:8], 'request_type': request_type, 'model': chat_request.model, 'status_code': 500, 'error_message': error_detail}
            log_msg = format_log_message('ERROR', f"{error_detail}", extra=extra_log_exception)
            print(log_msg)
            if attempt < retry_attempts:  # rotate key while attempts remain
                switch_api_key()
            else:
                raise HTTPException(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"{retry_attempts} 次尝试后仍然失败,请修改预设或输入")

    # Fell out of the loop without returning: every key/attempt failed.
    msg = "所有API密钥或重试次数均失败"
    extra_log_all_fail = {'key': "ALL", 'request_type': request_type, 'model': chat_request.model, 'status_code': 500, 'error_message': msg}
    log_msg = format_log_message('ERROR', msg, extra=extra_log_all_fail)
    print(log_msg)
    raise HTTPException(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=msg)
337
+
338
+
339
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def chat_completions(request: ChatCompletionRequest, http_request: Request, _: None = Depends(verify_password)):
    """OpenAI-compatible chat endpoint; delegates to process_request."""
    request_type = "stream" if request.stream else "non-stream"
    return await process_request(request, http_request, request_type)
342
+
343
+
344
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Last-resort handler: log the error and return a JSON 500 body."""
    error_message = translate_error(str(exc))
    extra = {'status_code': 500, 'error_message': error_message}
    print(format_log_message('ERROR', f"Unhandled exception: {error_message}", extra=extra))
    body = ErrorResponse(message=str(exc), type="internal_error").dict()
    return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=body)
app/models.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Optional, Union, Literal
2
+ from pydantic import BaseModel, Field
3
+
4
class Message(BaseModel):
    """A single chat message in an OpenAI-style request."""
    role: str
    # Widened from plain ``str``: GeminiClient.convert_messages explicitly
    # handles list-form multimodal content ({"type": "text"|"image_url"}
    # part dicts), which the str-only field rejected at validation time.
    # Plain strings remain valid, so existing callers are unaffected.
    content: Union[str, List[Dict]]
7
+
8
class ChatCompletionRequest(BaseModel):
    """OpenAI-compatible /v1/chat/completions request body.

    Fields beyond model/messages/temperature/max_tokens/stream are accepted
    for API compatibility; their use by the proxy is not visible here.
    """
    model: str
    messages: List[Message]
    temperature: float = 0.7
    top_p: Optional[float] = 1.0
    n: int = 1
    stream: bool = False  # True → SSE streaming response
    stop: Optional[Union[str, List[str]]] = None
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
19
+
20
class Choice(BaseModel):
    """One generated completion within a chat response."""
    index: int
    message: Message
    # e.g. "stop"; None when no reason was reported.
    finish_reason: Optional[str] = None
24
+
25
class Usage(BaseModel):
    """Token accounting for a completion; all counters default to zero."""
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0
29
+
30
class ChatCompletionResponse(BaseModel):
    """OpenAI-compatible /v1/chat/completions response body."""
    id: str
    object: Literal["chat.completion"]
    created: int  # Unix timestamp
    model: str
    choices: List[Choice]
    # Defaults to an all-zero Usage when the caller does not populate it.
    usage: Usage = Field(default_factory=Usage)
37
+
38
class ErrorResponse(BaseModel):
    """OpenAI-style error payload (used by the global exception handler)."""
    message: str
    type: str
    param: Optional[str] = None
    code: Optional[str] = None
43
+
44
class ModelList(BaseModel):
    """OpenAI-compatible /v1/models listing."""
    object: str = "list"
    data: List[Dict]  # one {"id", "object", "created", "owned_by"} dict per model
app/utils.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from fastapi import HTTPException, Request
3
+ import time
4
+ import re
5
+ from datetime import datetime, timedelta
6
+ from apscheduler.schedulers.background import BackgroundScheduler
7
+ import os
8
+ import requests
9
+ import httpx
10
+ from threading import Lock
11
+
12
# DEBUG=true in the environment switches to the verbose log layout below.
DEBUG = os.environ.get("DEBUG", "false").lower() == "true"
LOG_FORMAT_DEBUG = '%(asctime)s - %(levelname)s - [%(key)s]-%(request_type)s-[%(model)s]-%(status_code)s: %(message)s - %(error_message)s'
LOG_FORMAT_NORMAL = '[%(key)s]-%(request_type)s-[%(model)s]-%(status_code)s: %(message)s'


def format_log_message(level, message, extra=None):
    """Format a log line, mimicking the layout previously produced by `logging`.

    Args:
        level: log level name, e.g. 'INFO' (only rendered in DEBUG layout).
        message: the main log text.
        extra: optional dict with 'key', 'request_type', 'model',
            'status_code', 'error_message'; missing entries fall back to
            'N/A' (or '' for error_message).

    Returns:
        The interpolated log string (DEBUG or NORMAL layout per the env flag).
    """
    # Normalize once instead of repeating `... if extra else ...` per field.
    extra = extra or {}
    log_values = {
        'asctime': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),  # mimics logging's asctime
        'levelname': level,
        'key': extra.get('key', 'N/A'),
        'request_type': extra.get('request_type', 'N/A'),
        'model': extra.get('model', 'N/A'),
        'status_code': extra.get('status_code', 'N/A'),
        'error_message': extra.get('error_message', ''),
        'message': message,
    }
    log_format = LOG_FORMAT_DEBUG if DEBUG else LOG_FORMAT_NORMAL
    return log_format % log_values
31
+
32
+
33
class APIKeyManager:
    """Manages the Gemini API key pool.

    Keys are drawn from a randomly shuffled stack; failing keys can be
    blacklisted for a fixed duration (removal is scheduled via APScheduler),
    and keys already tried within the current request attempt are skipped.
    """

    def __init__(self):
        # Extract well-formed Gemini keys (AIzaSy + 33 chars) from the env var.
        self.api_keys = re.findall(
            r"AIzaSy[a-zA-Z0-9_-]{33}", os.environ.get('GEMINI_API_KEYS', ""))
        self.key_stack = []  # draw stack (shuffled copy of api_keys)
        self._reset_key_stack()  # build the initial random stack
        self.api_key_blacklist = set()  # keys temporarily disabled
        self.api_key_blacklist_duration = 60  # seconds a key stays blacklisted
        self.scheduler = BackgroundScheduler()
        self.scheduler.start()
        self.tried_keys_for_request = set()  # keys already tried for the current request

    def _reset_key_stack(self):
        """Rebuild the draw stack as a fresh random permutation of all keys."""
        shuffled_keys = self.api_keys[:]  # copy so api_keys itself stays ordered
        random.shuffle(shuffled_keys)
        self.key_stack = shuffled_keys
        log_msg = format_log_message('INFO', "已重新创建随机密钥栈", extra={'request_type': 'key_stack', 'status_code': 'N/A'})
        print(log_msg)

    def _pop_usable_key(self):
        """Pop keys until one is neither blacklisted nor already tried; None if exhausted.

        A returned key is recorded in tried_keys_for_request as a side effect.
        """
        while self.key_stack:
            key = self.key_stack.pop()
            if key not in self.api_key_blacklist and key not in self.tried_keys_for_request:
                self.tried_keys_for_request.add(key)
                return key
        return None

    def get_available_key(self):
        """Return a usable key, regenerating the stack once if needed.

        Returns None when no keys are configured or every key is
        blacklisted/already tried even after one stack regeneration.
        """
        key = self._pop_usable_key()
        if key is not None:
            return key

        if not self.api_keys:
            log_msg = format_log_message('ERROR', "没有配置任何 API 密钥!")
            print(log_msg)
            return None

        log_msg = format_log_message('WARNING', "密钥栈已用尽或栈内密钥均不可用,重新生成密钥栈")
        print(log_msg)
        self._reset_key_stack()

        # One more pass over the freshly shuffled stack.
        return self._pop_usable_key()

    def show_all_keys(self):
        """Log the configured keys (masked: first 8 + last 3 chars)."""
        log_msg = format_log_message('INFO', f"当前可用API key个数: {len(self.api_keys)} ")
        print(log_msg)
        for i, api_key in enumerate(self.api_keys):
            log_msg = format_log_message('INFO', f"API Key{i}: {api_key[:8]}...{api_key[-3:]}")
            print(log_msg)

    def blacklist_key(self, key):
        """Disable `key` for api_key_blacklist_duration seconds.

        Removal is scheduled as a one-shot 'date' job on the background scheduler.
        """
        log_msg = format_log_message('WARNING', f"{key[:8]} → 暂时禁用 {self.api_key_blacklist_duration} 秒")
        print(log_msg)
        self.api_key_blacklist.add(key)
        self.scheduler.add_job(lambda: self.api_key_blacklist.discard(key), 'date',
                               run_date=datetime.now() + timedelta(seconds=self.api_key_blacklist_duration))

    def reset_tried_keys_for_request(self):
        """Reset the tried-key set at the start of a new request attempt."""
        self.tried_keys_for_request = set()
98
+
99
+
100
def handle_gemini_error(error, current_api_key, key_manager, switch_api_key_func) -> str:
    """Map an upstream Gemini/requests exception to a short error string.

    Side effects: prints a formatted log line for every branch; on 400
    (invalid key), 429, and 403 the current key is blacklisted; on those
    plus 500/503/unknown HTTP statuses `switch_api_key_func` is invoked.

    Args:
        error: the caught exception.
        current_api_key: the key used for the failing request (masked in logs).
        key_manager: APIKeyManager used for blacklisting.
        switch_api_key_func: callback to rotate to another key.

    Returns:
        A human-readable (Chinese) error summary.  NOTE(review): if a 400
        response is valid JSON without an 'error' field, no branch returns
        and the function implicitly yields None despite the `-> str` hint.
    """
    if isinstance(error, requests.exceptions.HTTPError):
        status_code = error.response.status_code
        if status_code == 400:
            try:
                error_data = error.response.json()
                if 'error' in error_data:
                    # Gemini reports an invalid key as a 400 with code "invalid_argument".
                    if error_data['error'].get('code') == "invalid_argument":
                        error_message = "无效的 API 密钥"
                        extra_log_invalid_key = {'key': current_api_key[:8], 'status_code': status_code, 'error_message': error_message}
                        log_msg = format_log_message('ERROR', f"{current_api_key[:8]} ... {current_api_key[-3:]} → 无效,可能已过期或被删除", extra=extra_log_invalid_key)
                        print(log_msg)
                        key_manager.blacklist_key(current_api_key)
                        switch_api_key_func()  # despite the name, get_available_key handles the stack and regeneration
                        return error_message
                    # Any other 400: surface the upstream message verbatim.
                    error_message = error_data['error'].get(
                        'message', 'Bad Request')
                    extra_log_400 = {'key': current_api_key[:8], 'status_code': status_code, 'error_message': error_message}
                    log_msg = format_log_message('WARNING', f"400 错误请求: {error_message}", extra=extra_log_400)
                    print(log_msg)
                    return f"400 错误请求: {error_message}"
            except ValueError:
                # Response body was not JSON at all.
                error_message = "400 错误请求:响应不是有效的JSON格式"
                extra_log_400_json = {'key': current_api_key[:8], 'status_code': status_code, 'error_message': error_message}
                log_msg = format_log_message('WARNING', error_message, extra=extra_log_400_json)
                print(log_msg)
                return error_message

        elif status_code == 429:
            # Quota exhausted for this key: blacklist it and rotate.
            error_message = "API 密钥配额已用尽"
            extra_log_429 = {'key': current_api_key[:8], 'status_code': status_code, 'error_message': error_message}
            log_msg = format_log_message('WARNING', f"{current_api_key[:8]} ... {current_api_key[-3:]} → 429 官方资源耗尽", extra=extra_log_429)
            print(log_msg)
            key_manager.blacklist_key(current_api_key)
            switch_api_key_func()  # despite the name, get_available_key handles the stack and regeneration
            return error_message

        elif status_code == 403:
            # Permission denied: key is unusable, blacklist and rotate.
            error_message = "权限被拒绝"
            extra_log_403 = {'key': current_api_key[:8], 'status_code': status_code, 'error_message': error_message}
            log_msg = format_log_message('ERROR', f"{current_api_key[:8]} ... {current_api_key[-3:]} → 403 权限被拒绝", extra=extra_log_403)
            print(log_msg)
            key_manager.blacklist_key(current_api_key)
            switch_api_key_func()  # despite the name, get_available_key handles the stack and regeneration
            return error_message
        elif status_code == 500:
            # Transient server error: rotate keys but do NOT blacklist.
            error_message = "服务器内部错误"
            extra_log_500 = {'key': current_api_key[:8], 'status_code': status_code, 'error_message': error_message}
            log_msg = format_log_message('WARNING', f"{current_api_key[:8]} ... {current_api_key[-3:]} → 500 服务器内部错误", extra=extra_log_500)
            print(log_msg)
            switch_api_key_func()  # despite the name, get_available_key handles the stack and regeneration
            return "Gemini API 内部错误"

        elif status_code == 503:
            # Service unavailable: rotate keys but do NOT blacklist.
            error_message = "服务不可用"
            extra_log_503 = {'key': current_api_key[:8], 'status_code': status_code, 'error_message': error_message}
            log_msg = format_log_message('WARNING', f"{current_api_key[:8]} ... {current_api_key[-3:]} → 503 服务不可用", extra=extra_log_503)
            print(log_msg)
            switch_api_key_func()  # despite the name, get_available_key handles the stack and regeneration
            return "Gemini API 服务不可用"
        else:
            # Any other HTTP status: log and rotate.
            error_message = f"未知错误: {status_code}"
            extra_log_other = {'key': current_api_key[:8], 'status_code': status_code, 'error_message': error_message}
            log_msg = format_log_message('WARNING', f"{current_api_key[:8]} ... {current_api_key[-3:]} → {status_code} 未知错误", extra=extra_log_other)
            print(log_msg)
            switch_api_key_func()  # despite the name, get_available_key handles the stack and regeneration
            return f"未知错误/模型不可用: {status_code}"

    elif isinstance(error, requests.exceptions.ConnectionError):
        error_message = "连接错误"
        log_msg = format_log_message('WARNING', error_message, extra={'error_message': error_message})
        print(log_msg)
        return error_message

    elif isinstance(error, requests.exceptions.Timeout):
        error_message = "请求超时"
        log_msg = format_log_message('WARNING', error_message, extra={'error_message': error_message})
        print(log_msg)
        return error_message
    else:
        # Non-requests exception: report it generically.
        error_message = f"发生未知错误: {error}"
        log_msg = format_log_message('ERROR', error_message, extra={'error_message': error_message})
        print(log_msg)
        return error_message
184
+
185
+
186
async def test_api_key(api_key: str) -> bool:
    """Check whether `api_key` is valid by listing models; True on HTTP success."""
    url = "https://generativelanguage.googleapis.com/v1beta/models?key={}".format(api_key)
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(url)
            response.raise_for_status()
    except Exception:
        # Any failure (network, non-2xx, etc.) counts as an invalid key.
        return False
    return True
198
+
199
+
200
# In-memory fixed-window counters: key -> (count, window_start_timestamp).
rate_limit_data = {}
rate_limit_lock = Lock()


def protect_from_abuse(request: Request, max_requests_per_minute: int = 30, max_requests_per_day_per_ip: int = 600):
    """Simple in-memory rate limiter.

    Enforces a per-path per-minute cap and a per-client-IP per-day cap;
    raises HTTPException(429) with a JSON detail when either is exceeded.
    """
    ts = int(time.time())
    minute_key = f"{request.url.path}:{ts // 60}"
    day_key = f"{request.client.host}:{ts // (60 * 60 * 24)}"

    def _bump(key, window_seconds):
        # Fixed-window counter: reset when the stored window start expired.
        count, started = rate_limit_data.get(key, (0, ts))
        if ts - started >= window_seconds:
            count, started = 0, ts
        count += 1
        rate_limit_data[key] = (count, started)
        return count

    with rate_limit_lock:
        minute_count = _bump(minute_key, 60)
        day_count = _bump(day_key, 86400)

    if minute_count > max_requests_per_minute:
        raise HTTPException(status_code=429, detail={
            "message": "Too many requests per minute", "limit": max_requests_per_minute})
    if day_count > max_requests_per_day_per_ip:
        raise HTTPException(status_code=429, detail={"message": "Too many requests per day from this IP", "limit": max_requests_per_day_per_ip})