dan92 committed on
Commit 3d0ac92 · verified · 1 Parent(s): 8ce4b70

Update main.py

Files changed (1)
  1. main.py +377 -435
main.py CHANGED
@@ -1,436 +1,378 @@
- from fastapi import FastAPI, HTTPException, Header
- from fastapi.middleware.cors import CORSMiddleware
- from fastapi.responses import StreamingResponse
- from pydantic import BaseModel
- import openai
- from typing import List, Optional, Union
- import logging
- from itertools import cycle
- import asyncio
-
- import uvicorn
-
- from app import config
- import requests
- from datetime import datetime, timezone
- import json
- import httpx
- import uuid
- import time
-
- # Configure logging
- logging.basicConfig(
-     level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
- )
- logger = logging.getLogger(__name__)
-
- app = FastAPI()
-
- # Allow cross-origin requests
- app.add_middleware(
-     CORSMiddleware,
-     allow_origins=["*"],
-     allow_credentials=True,
-     allow_methods=["*"],
-     allow_headers=["*"],
- )
-
- # API key configuration
- API_KEYS = config.settings.API_KEYS
-
- # Create a round-robin iterator over the keys
- key_cycle = cycle(API_KEYS)
-
- # Two independent locks
- key_cycle_lock = asyncio.Lock()
- failure_count_lock = asyncio.Lock()
-
- # Track per-key failure counts
- key_failure_counts = {key: 0 for key in API_KEYS}
- MAX_FAILURES = 10  # maximum failures before a key is skipped
- MAX_RETRIES = 3  # maximum number of retries
-
-
- async def get_next_key():
-     """Get the next key without checking its failure count."""
-     async with key_cycle_lock:
-         return next(key_cycle)
-
- async def is_key_valid(key):
-     """Check whether a key is still usable."""
-     async with failure_count_lock:
-         return key_failure_counts[key] < MAX_FAILURES
-
- async def reset_failure_counts():
-     """Reset the failure count of every key."""
-     async with failure_count_lock:
-         for key in key_failure_counts:
-             key_failure_counts[key] = 0
-
- async def get_next_working_key():
-     """Get the next usable API key."""
-     initial_key = await get_next_key()
-     current_key = initial_key
-
-     while True:
-         if await is_key_valid(current_key):
-             return current_key
-
-         current_key = await get_next_key()
-         if current_key == initial_key:  # we have cycled through every key
-             await reset_failure_counts()
-             return current_key
-
- async def handle_api_failure(api_key):
-     """Handle an API call failure."""
-     async with failure_count_lock:
-         key_failure_counts[api_key] += 1
-         if key_failure_counts[api_key] >= MAX_FAILURES:
-             logger.warning(f"API key {api_key} has failed {MAX_FAILURES} times, switching to next key")
-
-     # Acquire a new key outside the lock
-     return await get_next_working_key()
-
-
- class ChatRequest(BaseModel):
-     messages: List[dict]
-     model: str = "gemini-1.5-flash-002"
-     temperature: Optional[float] = 0.7
-     stream: Optional[bool] = False
-     tools: Optional[List[dict]] = []
-     tool_choice: Optional[str] = "auto"
-     max_tokens: Optional[int] = 1024
-     stop: Optional[List[str]] = []
-     top_p: Optional[float] = 0.9
-     top_k: Optional[int] = 100
-
-
- class EmbeddingRequest(BaseModel):
-     input: Union[str, List[str]]
-     model: str = "text-embedding-004"
-     encoding_format: Optional[str] = "float"
-
-
- async def verify_authorization(authorization: str = Header(None)):
-     if not authorization:
-         logger.error("Missing Authorization header")
-         raise HTTPException(status_code=401, detail="Missing Authorization header")
-     if not authorization.startswith("Bearer "):
-         logger.error("Invalid Authorization header format")
-         raise HTTPException(
-             status_code=401, detail="Invalid Authorization header format"
-         )
-     token = authorization.replace("Bearer ", "")
-     if token not in config.settings.ALLOWED_TOKENS:
-         logger.error("Invalid token")
-         raise HTTPException(status_code=401, detail="Invalid token")
-     return token
-
-
- def get_gemini_models(api_key):
-     url = f"{config.settings.BASE_URL}/models?key={api_key}"
-
-     try:
-         response = requests.get(url)
-         if response.status_code == 200:
-             gemini_models = response.json()
-             return convert_to_openai_models_format(gemini_models)
-         else:
-             print(f"Error: {response.status_code}")
-             print(response.text)
-             return None
-
-     except requests.RequestException as e:
-         print(f"Request failed: {e}")
-         return None
-
-
- def convert_to_openai_models_format(gemini_models):
-     openai_format = {"object": "list", "data": []}
-
-     for model in gemini_models.get("models", []):
-         openai_model = {
-             "id": model["name"].split("/")[-1],  # use the last path segment as the ID
-             "object": "model",
-             "created": int(datetime.now(timezone.utc).timestamp()),  # use the current timestamp
-             "owned_by": "google",  # assume all Gemini models are owned by Google
-             "permission": [],  # the Gemini API has no direct permission info
-             "root": model["name"],
-             "parent": None,  # the Gemini API has no direct parent-model info
-         }
-         openai_format["data"].append(openai_model)
-
-     return openai_format
-
-
- def convert_messages_to_gemini_format(messages):
-     """Convert OpenAI message format to Gemini format"""
-     converted_messages = []
-     for msg in messages:
-         # Map OpenAI roles to Gemini roles
-         if msg["role"] == "user":
-             role = "user"
-         elif msg["role"] == "assistant":
-             role = "model"
-         else:
-             role = "user"  # treat anything else as a user message
-
-         parts = []
-
-         # Handle plain-text content
-         if isinstance(msg["content"], str):
-             parts.append({"text": msg["content"]})
-         # Handle messages that contain images
-         elif isinstance(msg["content"], list):
-             for content in msg["content"]:
-                 if isinstance(content, str):
-                     parts.append({"text": content})
-                 elif isinstance(content, dict) and content["type"] == "text":
-                     parts.append({"text": content["text"]})
-                 elif isinstance(content, dict) and content["type"] == "image_url":
-                     # Handle an image URL
-                     image_url = content["image_url"]["url"]
-                     if image_url.startswith("data:image"):
-                         # Handle a base64-encoded image
-                         parts.append(
-                             {
-                                 "inline_data": {
-                                     "mime_type": "image/jpeg",
-                                     "data": image_url.split(",")[1],
-                                 }
-                             }
-                         )
-                     else:
-                         # Handle a regular image URL
-                         parts.append(
-                             {
-                                 "image_url": {
-                                     "url": image_url,
-                                 }
-                             }
-                         )
-
-         converted_messages.append({"role": role, "parts": parts})
-     return converted_messages
-
-
- def convert_gemini_response_to_openai(response, model, stream=False):
-     """Convert Gemini response to OpenAI format"""
-     if stream:
-         # Handle a streaming chunk
-         chunk = response
-         if not chunk["candidates"]:
-             return None
-
-         return {
-             "id": "chatcmpl-" + str(uuid.uuid4()),
-             "object": "chat.completion.chunk",
-             "created": int(time.time()),
-             "model": model,
-             "choices": [
-                 {
-                     "index": 0,
-                     "delta": {
-                         "content": chunk["candidates"][0]["content"]["parts"][0]["text"]
-                     },
-                     "finish_reason": None,
-                 }
-             ],
-         }
-     else:
-         # Handle a non-streaming response
-         return {
-             "id": "chatcmpl-" + str(uuid.uuid4()),
-             "object": "chat.completion",
-             "created": int(time.time()),
-             "model": model,
-             "choices": [
-                 {
-                     "index": 0,
-                     "message": {
-                         "role": "assistant",
-                         "content": response["candidates"][0]["content"]["parts"][0][
-                             "text"
-                         ],
-                     },
-                     "finish_reason": "stop",
-                 }
-             ],
-             "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
-         }
-
-
- @app.get("/v1/models")
- @app.get("/hf/v1/models")
- async def list_models(authorization: str = Header(None)):
-     await verify_authorization(authorization)
-     api_key = await get_next_working_key()
-     logger.info(f"Using API key: {api_key}")
-     try:
-         response = get_gemini_models(api_key)
-         logger.info("Successfully retrieved models list")
-         return response
-     except Exception as e:
-         logger.error(f"Error listing models: {str(e)}")
-         raise HTTPException(status_code=500, detail=str(e))
-
-
- @app.post("/v1/chat/completions")
- @app.post("/hf/v1/chat/completions")
- async def chat_completion(request: ChatRequest, authorization: str = Header(None)):
-     await verify_authorization(authorization)
-     api_key = await get_next_working_key()
-     logger.info(f"Chat completion request - Model: {request.model}")
-     retries = 0
-
-     while retries < MAX_RETRIES:
-         try:
-             logger.info(f"Attempt {retries + 1} with API key: {api_key}")
-
-             if request.model in config.settings.MODEL_SEARCH:
-                 # Gemini API call path
-                 gemini_messages = convert_messages_to_gemini_format(request.messages)
-                 # Call the Gemini API
-                 payload = {
-                     "contents": gemini_messages,
-                     "generationConfig": {
-                         "temperature": request.temperature,
-                         "maxOutputTokens": request.max_tokens,
-                         "stopSequences": request.stop,
-                         "topP": request.top_p,
-                         "topK": request.top_k,
-                     },
-                     "safetySettings": [
-                         {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
-                         {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
-                         {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
-                         {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
-                     ]
-                 }
-
-                 # Log the request payload
-                 logger.info(f"Request payload: {json.dumps(payload, indent=2)}")
-
-                 if request.stream:
-                     logger.info("Streaming response enabled")
-
-                     async def generate():
-                         nonlocal api_key, retries
-                         while retries < MAX_RETRIES:
-                             try:
-                                 async with httpx.AsyncClient() as client:
-                                     stream_url = f"{config.settings.BASE_URL}/models/{request.model}:streamGenerateContent?alt=sse&key={api_key}"
-                                     logger.info(f"Making request to: {stream_url}")
-                                     async with client.stream("POST", stream_url, json=payload) as response:
-                                         if response.status_code == 429:
-                                             logger.warning(f"Rate limit reached for key: {api_key}")
-                                             api_key = await handle_api_failure(api_key)
-                                             logger.info(f"Retrying with new API key: {api_key}")
-                                             retries += 1
-                                             if retries >= MAX_RETRIES:
-                                                 yield f"data: {json.dumps({'error': 'Max retries reached'})}\n\n"
-                                                 break
-                                             continue
-
-                                         if response.status_code != 200:
-                                             error_content = await response.read()
-                                             error_text = error_content.decode('utf-8')
-                                             logger.error(f"Error in streaming response: {response.status_code}")
-                                             logger.error(f"Error details: {error_text}")
-                                             yield f"data: {json.dumps({'error': f'API error: {response.status_code}, {error_text}'})}\n\n"
-                                             break
-
-                                         async for line in response.aiter_lines():
-                                             if line.startswith("data: "):
-                                                 try:
-                                                     chunk = json.loads(line[6:])
-                                                     openai_chunk = convert_gemini_response_to_openai(
-                                                         chunk, request.model, stream=True
-                                                     )
-                                                     if openai_chunk:
-                                                         yield f"data: {json.dumps(openai_chunk)}\n\n"
-                                                 except json.JSONDecodeError:
-                                                     continue
-                                         yield "data: [DONE]\n\n"
-                                         return
-                             except Exception as e:
-                                 logger.error(f"Stream error: {str(e)}")
-                                 api_key = await handle_api_failure(api_key)
-                                 retries += 1
-                                 if retries >= MAX_RETRIES:
-                                     yield f"data: {json.dumps({'error': 'Max retries reached'})}\n\n"
-                                     break
-                                 continue
-
-                     return StreamingResponse(content=generate(), media_type="text/event-stream")
-                 else:
-                     # Non-streaming response
-                     async with httpx.AsyncClient() as client:
-                         non_stream_url = f"{config.settings.BASE_URL}/models/{request.model}:generateContent?key={api_key}"
-                         response = await client.post(non_stream_url, json=payload)
-                         gemini_response = response.json()
-                         logger.info("Chat completion successful")
-                         return convert_gemini_response_to_openai(gemini_response, request.model)
-
-             # OpenAI API call path
-             client = openai.OpenAI(api_key=api_key, base_url=config.settings.BASE_URL)
-             response = client.chat.completions.create(
-                 model=request.model,
-                 messages=request.messages,
-                 temperature=request.temperature,
-                 stream=request.stream if hasattr(request, "stream") else False,
-             )
-
-             if hasattr(request, "stream") and request.stream:
-                 logger.info("Streaming response enabled")
-
-                 async def generate():
-                     for chunk in response:
-                         yield f"data: {chunk.model_dump_json()}\n\n"
-                     logger.info("Chat completion successful")
-                 return StreamingResponse(content=generate(), media_type="text/event-stream")
-
-             logger.info("Chat completion successful")
-             return response
-
-         except Exception as e:
-             logger.error(f"Error in chat completion: {str(e)}")
-             api_key = await handle_api_failure(api_key)
-             retries += 1
-
-             if retries >= MAX_RETRIES:
-                 logger.error("Max retries reached, giving up")
-                 raise HTTPException(status_code=500, detail="Max retries reached with all available API keys")
-
-             logger.info(f"Retrying with new API key: {api_key}")
-             continue
-
-     raise HTTPException(status_code=500, detail="Unexpected error in chat completion")
-
-
- @app.post("/v1/embeddings")
- @app.post("/hf/v1/embeddings")
- async def embedding(request: EmbeddingRequest, authorization: str = Header(None)):
-     await verify_authorization(authorization)
-     api_key = await get_next_working_key()
-     logger.info(f"Using API key: {api_key}")
-
-     try:
-         client = openai.OpenAI(api_key=api_key, base_url=config.settings.BASE_URL)
-         response = client.embeddings.create(input=request.input, model=request.model)
-         logger.info("Embedding successful")
-         return response
-     except Exception as e:
-         logger.error(f"Error in embedding: {str(e)}")
-         raise HTTPException(status_code=500, detail=str(e))
-
-
- @app.get("/health")
- @app.get("/")
- async def health_check():
-     logger.info("Health check endpoint called")
-     return {"status": "healthy"}
-
-
- if __name__ == "__main__":
+ from fastapi import FastAPI, HTTPException, Header
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.responses import StreamingResponse
+ from pydantic import BaseModel
+ import openai
+ from typing import List, Optional, Union
+ import logging
+ from itertools import cycle
+ import asyncio
+
+ import uvicorn
+
+ from app import config
+ import requests
+ from datetime import datetime, timezone
+ import json
+ import httpx
+ import uuid
+ import time
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
+ )
+ logger = logging.getLogger(__name__)
+
+ app = FastAPI()
+
+ # Allow cross-origin requests
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # API key configuration
+ API_KEYS = config.settings.API_KEYS
+
+ # Create a round-robin iterator over the keys
+ key_cycle = cycle(API_KEYS)
+
+ # Two independent locks
+ key_cycle_lock = asyncio.Lock()
+ failure_count_lock = asyncio.Lock()
+
+ # Track per-key failure counts
+ key_failure_counts = {key: 0 for key in API_KEYS}
+ MAX_FAILURES = 10  # maximum failures before a key is skipped
+ MAX_RETRIES = 3  # maximum number of retries
+
+
+ async def get_next_key():
+     """Get the next key without checking its failure count."""
+     async with key_cycle_lock:
+         return next(key_cycle)
+
+ async def is_key_valid(key):
+     """Check whether a key is still usable."""
+     async with failure_count_lock:
+         return key_failure_counts[key] < MAX_FAILURES
+
+ async def reset_failure_counts():
+     """Reset the failure count of every key."""
+     async with failure_count_lock:
+         for key in key_failure_counts:
+             key_failure_counts[key] = 0
+
+ async def get_next_working_key():
+     """Get the next usable API key."""
+     initial_key = await get_next_key()
+     current_key = initial_key
+
+     while True:
+         if await is_key_valid(current_key):
+             return current_key
+
+         current_key = await get_next_key()
+         if current_key == initial_key:  # we have cycled through every key
+             await reset_failure_counts()
+             return current_key
+
+ async def handle_api_failure(api_key):
+     """Handle an API call failure."""
+     async with failure_count_lock:
+         key_failure_counts[api_key] += 1
+         if key_failure_counts[api_key] >= MAX_FAILURES:
+             logger.warning(f"API key {api_key} has failed {MAX_FAILURES} times, switching to next key")
+
+     # Acquire a new key outside the lock
+     return await get_next_working_key()
+
+
+ class ChatRequest(BaseModel):
+     messages: List[dict]
+     model: str = "gemini-1.5-flash-002"
+     temperature: Optional[float] = 0.7
+     stream: Optional[bool] = False
+     tools: Optional[List[dict]] = []
+     tool_choice: Optional[str] = "auto"
+
+
+ class EmbeddingRequest(BaseModel):
+     input: Union[str, List[str]]
+     model: str = "text-embedding-004"
+     encoding_format: Optional[str] = "float"
+
+
+ async def verify_authorization(authorization: str = Header(None)):
+     if not authorization:
+         logger.error("Missing Authorization header")
+         raise HTTPException(status_code=401, detail="Missing Authorization header")
+     if not authorization.startswith("Bearer "):
+         logger.error("Invalid Authorization header format")
+         raise HTTPException(
+             status_code=401, detail="Invalid Authorization header format"
+         )
+     token = authorization.replace("Bearer ", "")
+     if token not in config.settings.ALLOWED_TOKENS:
+         logger.error("Invalid token")
+         raise HTTPException(status_code=401, detail="Invalid token")
+     return token
+
+
+ def get_gemini_models(api_key):
+     base_url = "https://generativelanguage.googleapis.com/v1beta"
+     url = f"{base_url}/models?key={api_key}"
+
+     try:
+         response = requests.get(url)
+         if response.status_code == 200:
+             gemini_models = response.json()
+             return convert_to_openai_models_format(gemini_models)
+         else:
+             print(f"Error: {response.status_code}")
+             print(response.text)
+             return None
+
+     except requests.RequestException as e:
+         print(f"Request failed: {e}")
+         return None
+
+
+ def convert_to_openai_models_format(gemini_models):
+     openai_format = {"object": "list", "data": []}
+
+     for model in gemini_models.get("models", []):
+         openai_model = {
+             "id": model["name"].split("/")[-1],  # use the last path segment as the ID
+             "object": "model",
+             "created": int(datetime.now(timezone.utc).timestamp()),  # use the current timestamp
+             "owned_by": "google",  # assume all Gemini models are owned by Google
+             "permission": [],  # the Gemini API has no direct permission info
+             "root": model["name"],
+             "parent": None,  # the Gemini API has no direct parent-model info
+         }
+         openai_format["data"].append(openai_model)
+
+     return openai_format
+
+
+ def convert_messages_to_gemini_format(messages):
+     """Convert OpenAI message format to Gemini format"""
+     gemini_messages = []
+     for message in messages:
+         gemini_message = {
+             "role": "user" if message["role"] == "user" else "model",
+             "parts": [{"text": message["content"]}],
+         }
+         gemini_messages.append(gemini_message)
+     return gemini_messages
+
+
+ def convert_gemini_response_to_openai(response, model, stream=False):
+     """Convert Gemini response to OpenAI format"""
+     if stream:
+         # Handle a streaming chunk
+         chunk = response
+         if not chunk["candidates"]:
+             return None
+
+         return {
+             "id": "chatcmpl-" + str(uuid.uuid4()),
+             "object": "chat.completion.chunk",
+             "created": int(time.time()),
+             "model": model,
+             "choices": [
+                 {
+                     "index": 0,
+                     "delta": {
+                         "content": chunk["candidates"][0]["content"]["parts"][0]["text"]
+                     },
+                     "finish_reason": None,
+                 }
+             ],
+         }
+     else:
+         # Handle a non-streaming response
+         return {
+             "id": "chatcmpl-" + str(uuid.uuid4()),
+             "object": "chat.completion",
+             "created": int(time.time()),
+             "model": model,
+             "choices": [
+                 {
+                     "index": 0,
+                     "message": {
+                         "role": "assistant",
+                         "content": response["candidates"][0]["content"]["parts"][0][
+                             "text"
+                         ],
+                     },
+                     "finish_reason": "stop",
+                 }
+             ],
+             "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
+         }
+
+
+ @app.get("/v1/models")
+ @app.get("/hf/v1/models")
+ async def list_models(authorization: str = Header(None)):
+     await verify_authorization(authorization)
+     api_key = await get_next_working_key()
+     logger.info(f"Using API key: {api_key}")
+     try:
+         response = get_gemini_models(api_key)
+         logger.info("Successfully retrieved models list")
+         return response
+     except Exception as e:
+         logger.error(f"Error listing models: {str(e)}")
+         raise HTTPException(status_code=500, detail=str(e))
+
+
+ @app.post("/v1/chat/completions")
+ @app.post("/hf/v1/chat/completions")
+ async def chat_completion(request: ChatRequest, authorization: str = Header(None)):
+     await verify_authorization(authorization)
+     api_key = await get_next_working_key()
+     logger.info(f"Chat completion request - Model: {request.model}")
+     retries = 0
+
+     while retries < MAX_RETRIES:
+         try:
+             logger.info(f"Attempt {retries + 1} with API key: {api_key}")
+
+             if request.model in config.settings.MODEL_SEARCH:
+                 # Gemini API call path
+                 gemini_messages = convert_messages_to_gemini_format(request.messages)
+                 # Call the Gemini API
+                 payload = {
+                     "contents": gemini_messages,
+                     "generationConfig": {
+                         "temperature": request.temperature,
+                     },
+                     "tools": [{"googleSearch": {}}],
+                 }
+
+                 if request.stream:
+                     logger.info("Streaming response enabled")
+
+                     async def generate():
+                         nonlocal api_key, retries
+                         while retries < MAX_RETRIES:
+                             try:
+                                 async with httpx.AsyncClient() as client:
+                                     stream_url = f"https://generativelanguage.googleapis.com/v1beta/models/{request.model}:streamGenerateContent?alt=sse&key={api_key}"
+                                     async with client.stream("POST", stream_url, json=payload) as response:
+                                         if response.status_code == 429:
+                                             logger.warning(f"Rate limit reached for key: {api_key}")
+                                             api_key = await handle_api_failure(api_key)
+                                             logger.info(f"Retrying with new API key: {api_key}")
+                                             retries += 1
+                                             if retries >= MAX_RETRIES:
+                                                 yield f"data: {json.dumps({'error': 'Max retries reached'})}\n\n"
+                                                 break
+                                             continue
+
+                                         if response.status_code != 200:
+                                             logger.error(f"Error in streaming response: {response.status_code}")
+                                             yield f"data: {json.dumps({'error': f'API error: {response.status_code}'})}\n\n"
+                                             break
+
+                                         async for line in response.aiter_lines():
+                                             if line.startswith("data: "):
+                                                 try:
+                                                     chunk = json.loads(line[6:])
+                                                     openai_chunk = convert_gemini_response_to_openai(
+                                                         chunk, request.model, stream=True
+                                                     )
+                                                     if openai_chunk:
+                                                         yield f"data: {json.dumps(openai_chunk)}\n\n"
+                                                 except json.JSONDecodeError:
+                                                     continue
+                                         yield "data: [DONE]\n\n"
+                                         return
+                             except Exception as e:
+                                 logger.error(f"Stream error: {str(e)}")
+                                 api_key = await handle_api_failure(api_key)
+                                 retries += 1
+                                 if retries >= MAX_RETRIES:
+                                     yield f"data: {json.dumps({'error': 'Max retries reached'})}\n\n"
+                                     break
+                                 continue
+
+                     return StreamingResponse(content=generate(), media_type="text/event-stream")
+                 else:
+                     # Non-streaming response
+                     async with httpx.AsyncClient() as client:
+                         non_stream_url = f"https://generativelanguage.googleapis.com/v1beta/models/{request.model}:generateContent?key={api_key}"
+                         response = await client.post(non_stream_url, json=payload)
+                         gemini_response = response.json()
+                         logger.info("Chat completion successful")
+                         return convert_gemini_response_to_openai(gemini_response, request.model)
+
+             # OpenAI API call path
+             client = openai.OpenAI(api_key=api_key, base_url=config.settings.BASE_URL)
+             response = client.chat.completions.create(
+                 model=request.model,
+                 messages=request.messages,
+                 temperature=request.temperature,
+                 stream=request.stream if hasattr(request, "stream") else False,
+             )
+
+             if hasattr(request, "stream") and request.stream:
+                 logger.info("Streaming response enabled")
+
+                 async def generate():
+                     for chunk in response:
+                         yield f"data: {chunk.model_dump_json()}\n\n"
+                     logger.info("Chat completion successful")
+                 return StreamingResponse(content=generate(), media_type="text/event-stream")
+
+             logger.info("Chat completion successful")
+             return response
+
+         except Exception as e:
+             logger.error(f"Error in chat completion: {str(e)}")
+             api_key = await handle_api_failure(api_key)
+             retries += 1
+
+             if retries >= MAX_RETRIES:
+                 logger.error("Max retries reached, giving up")
+                 raise HTTPException(status_code=500, detail="Max retries reached with all available API keys")
+
+             logger.info(f"Retrying with new API key: {api_key}")
+             continue
+
+     raise HTTPException(status_code=500, detail="Unexpected error in chat completion")
+
+
+ @app.post("/v1/embeddings")
+ @app.post("/hf/v1/embeddings")
+ async def embedding(request: EmbeddingRequest, authorization: str = Header(None)):
+     await verify_authorization(authorization)
+     api_key = await get_next_working_key()
+     logger.info(f"Using API key: {api_key}")
+
+     try:
+         client = openai.OpenAI(api_key=api_key, base_url=config.settings.BASE_URL)
+         response = client.embeddings.create(input=request.input, model=request.model)
+         logger.info("Embedding successful")
+         return response
+     except Exception as e:
+         logger.error(f"Error in embedding: {str(e)}")
+         raise HTTPException(status_code=500, detail=str(e))
+
+
+ @app.get("/health")
+ @app.get("/")
+ async def health_check():
+     logger.info("Health check endpoint called")
+     return {"status": "healthy"}
+
+
+ if __name__ == "__main__":
  uvicorn.run(app, host="0.0.0.0", port=8000)
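
For reference, a minimal client sketch against this proxy's OpenAI-compatible chat endpoint. Assumptions not in the diff: the server runs locally on port 8000 (matching the uvicorn.run call above), and the hypothetical token "sk-test" stands in for a value present in config.settings.ALLOWED_TOKENS.

# Minimal usage sketch (hypothetical host and token).
import requests

BASE = "http://localhost:8000"  # where uvicorn serves the app above
TOKEN = "sk-test"               # must be listed in config.settings.ALLOWED_TOKENS

resp = requests.post(
    f"{BASE}/v1/chat/completions",
    headers={"Authorization": f"Bearer {TOKEN}"},
    json={
        "model": "gemini-1.5-flash-002",  # the ChatRequest default model
        "messages": [{"role": "user", "content": "Hello!"}],
        "stream": False,
    },
)
print(resp.json()["choices"][0]["message"]["content"])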