dan92 committed
Commit 25bf87c · verified · 1 parent: fbea3b4

Upload 7 files

Files changed (7)
  1. .env +3 -0
  2. .gitattributes +35 -35
  3. Dockerfile +19 -19
  4. README.md +12 -12
  5. app/config.py +2 -1
  6. main.py +377 -162
  7. requirements.txt +7 -6
.env CHANGED
@@ -0,0 +1,3 @@
+ API_KEYS=["ccc"]
+ ALLOWED_TOKENS=["xxx"]
+ BASE_URL=https://api.groq.com/openai/v1
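
These values are written as JSON arrays because app/config.py (below) declares API_KEYS and ALLOWED_TOKENS as List[str] on a pydantic BaseSettings class, and pydantic-settings decodes complex field types from JSON. A minimal sketch of the parsing, mirroring app/config.py:

# Sketch mirroring app/config.py: list-typed settings fields are decoded
# from JSON, which is why the .env values above use JSON-array syntax.
from typing import List
from pydantic_settings import BaseSettings

class Settings(BaseSettings):
    API_KEYS: List[str]
    ALLOWED_TOKENS: List[str]
    BASE_URL: str

    class Config:
        env_file = ".env"
        env_file_encoding = "utf-8"

settings = Settings()
print(settings.API_KEYS)  # ['ccc'] with the .env above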
.gitattributes CHANGED
@@ -1,35 +1,35 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text

+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
Dockerfile CHANGED
@@ -1,19 +1,19 @@
- FROM python:3.9-slim
-
- WORKDIR /app
-
- # Copy the required files into the container
- COPY ./app /app/app
- COPY ./main.py /app
- COPY ./requirements.txt /app
-
- RUN pip install --no-cache-dir -r requirements.txt
- ENV API_KEYS=["your_api_key_1"]
- ENV ALLOWED_TOKENS=["your_token_1"]
- ENV BASE_URL=https://api.groq.com/openai/v1
-
- # Expose port
- EXPOSE 8000
-
- # Run the application
- CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]

+ FROM python:3.9-slim
+
+ WORKDIR /app
+
+ # Copy the required files into the container
+ COPY ./app /app/app
+ COPY ./main.py /app
+ COPY ./requirements.txt /app
+
+ RUN pip install --no-cache-dir -r requirements.txt
+ ENV API_KEYS=["your_api_key_1"]
+ ENV ALLOWED_TOKENS=["your_token_1"]
+ ENV BASE_URL=https://api.groq.com/openai/v1
+
+ # Expose port
+ EXPOSE 8000
+
+ # Run the application
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
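
Once an image built from this Dockerfile is running with port 8000 published, the /health route defined in main.py gives a quick liveness check. A hedged verification sketch (the image name and port mapping are assumptions):

# Assumes: docker build -t gemini-balance . && docker run -p 8000:8000 gemini-balance
import requests

resp = requests.get("http://localhost:8000/health", timeout=5)
resp.raise_for_status()
print(resp.json())  # expected: {"status": "healthy"}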
README.md CHANGED
@@ -1,12 +1,12 @@
- ---
- title: Gemini Balance
- emoji: 🐨
- colorFrom: red
- colorTo: indigo
- sdk: docker
- pinned: false
- license: mit
- app_port: 8000
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

+ ---
+ title: Gemini Balance
+ emoji: 🐨
+ colorFrom: red
+ colorTo: indigo
+ sdk: docker
+ pinned: false
+ license: mit
+ app_port: 8000
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app/config.py CHANGED
@@ -6,7 +6,8 @@ class Settings(BaseSettings):
  API_KEYS: List[str]
  ALLOWED_TOKENS: List[str]
  BASE_URL: str
-

  class Config:
      env_file = ".env"
      env_file_encoding = "utf-8"

  API_KEYS: List[str]
  ALLOWED_TOKENS: List[str]
  BASE_URL: str
+ MODEL_SEARCH: List[str] = ["gemini-2.0-flash-exp"]
+
  class Config:
      env_file = ".env"
      env_file_encoding = "utf-8"
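
The new MODEL_SEARCH field lists the model IDs that main.py routes to the native Gemini API with googleSearch enabled; all other models go through the OpenAI-compatible BASE_URL. Like the other list fields it can be overridden with a JSON array from the environment; a hedged sketch (the second model ID is illustrative only):

# Hypothetical override: pydantic-settings decodes List[str] fields from
# JSON, and environment variables take precedence over .env values.
import os
os.environ["MODEL_SEARCH"] = '["gemini-2.0-flash-exp", "gemini-exp-1206"]'

from app.config import settings  # settings is instantiated at import time
print(settings.MODEL_SEARCH)     # ['gemini-2.0-flash-exp', 'gemini-exp-1206']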
main.py CHANGED
@@ -1,163 +1,378 @@
- from fastapi import FastAPI, HTTPException, Header, Request
- from fastapi.middleware.cors import CORSMiddleware
- from fastapi.responses import StreamingResponse
- from pydantic import BaseModel
- import openai
- from typing import List, Optional
- import logging
- from itertools import cycle
- import asyncio
-
- import uvicorn
-
- from app import config
- import requests
- from datetime import datetime, timezone
-
- # Configure logging
- logging.basicConfig(
-     level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
- )
- logger = logging.getLogger(__name__)
-
- app = FastAPI()
-
- # Allow cross-origin requests
- app.add_middleware(
-     CORSMiddleware,
-     allow_origins=["*"],
-     allow_credentials=True,
-     allow_methods=["*"],
-     allow_headers=["*"],
- )
-
- # API key configuration
- API_KEYS = config.settings.API_KEYS
-
- # Create a round-robin iterator over the keys
- key_cycle = cycle(API_KEYS)
- key_lock = asyncio.Lock()
-
-
- class ChatRequest(BaseModel):
-     messages: List[dict]
-     model: str = "llama-3.2-90b-text-preview"
-     temperature: Optional[float] = 0.7
-     stream: Optional[bool] = False
-
-
- async def verify_authorization(authorization: str = Header(None)):
-     if not authorization:
-         logger.error("Missing Authorization header")
-         raise HTTPException(status_code=401, detail="Missing Authorization header")
-     if not authorization.startswith("Bearer "):
-         logger.error("Invalid Authorization header format")
-         raise HTTPException(
-             status_code=401, detail="Invalid Authorization header format"
-         )
-     token = authorization.replace("Bearer ", "")
-     if token not in config.settings.ALLOWED_TOKENS:
-         logger.error("Invalid token")
-         raise HTTPException(status_code=401, detail="Invalid token")
-     return token
-
-
- def get_gemini_models(api_key):
-     base_url = "https://generativelanguage.googleapis.com/v1beta"
-     url = f"{base_url}/models?key={api_key}"
-
-     try:
-         response = requests.get(url)
-         if response.status_code == 200:
-             gemini_models = response.json()
-             return convert_to_openai_format(gemini_models)
-         else:
-             print(f"Error: {response.status_code}")
-             print(response.text)
-             return None
-
-     except requests.RequestException as e:
-         print(f"Request failed: {e}")
-         return None
-
- def convert_to_openai_format(gemini_models):
-     openai_format = {
-         "object": "list",
-         "data": []
-     }
-
-     for model in gemini_models.get('models', []):
-         openai_model = {
-             "id": model['name'].split('/')[-1],  # use the last path segment as the ID
-             "object": "model",
-             "created": int(datetime.now(timezone.utc).timestamp()),  # use the current timestamp
-             "owned_by": "google",  # assume all Gemini models are owned by Google
-             "permission": [],  # the Gemini API may have no direct equivalent of permission info
-             "root": model['name'],
-             "parent": None,  # the Gemini API may have no direct equivalent of parent-model info
-         }
-         openai_format["data"].append(openai_model)
-
-     return openai_format
-
-
- @app.get("/v1/models")
- @app.get("/hf/v1/models")
- async def list_models(authorization: str = Header(None)):
-     await verify_authorization(authorization)
-     async with key_lock:
-         api_key = next(key_cycle)
-         logger.info(f"Using API key: {api_key}")
-     try:
-         response = get_gemini_models(api_key)
-         logger.info("Successfully retrieved models list")
-         return response
-     except Exception as e:
-         logger.error(f"Error listing models: {str(e)}")
-         raise HTTPException(status_code=500, detail=str(e))
-
-
- @app.post("/v1/chat/completions")
- @app.post("/hf/v1/chat/completions")
- async def chat_completion(request: ChatRequest, authorization: str = Header(None)):
-     await verify_authorization(authorization)
-     async with key_lock:
-         api_key = next(key_cycle)
-         logger.info(f"Using API key: {api_key}")
-
-     try:
-         logger.info(f"Chat completion request - Model: {request.model}")
-         client = openai.OpenAI(api_key=api_key, base_url=config.settings.BASE_URL)
-         response = client.chat.completions.create(
-             model=request.model,
-             messages=request.messages,
-             temperature=request.temperature,
-             stream=request.stream if hasattr(request, "stream") else False,
-         )
-
-         if hasattr(request, "stream") and request.stream:
-             logger.info("Streaming response enabled")
-
-             async def generate():
-                 for chunk in response:
-                     yield f"data: {chunk.model_dump_json()}\n\n"
-
-             return StreamingResponse(content=generate(), media_type="text/event-stream")
-
-         logger.info("Chat completion successful")
-         return response
-
-     except Exception as e:
-         logger.error(f"Error in chat completion: {str(e)}")
-         raise HTTPException(status_code=500, detail=str(e))
-
-
- @app.get("/health")
- @app.get("/")
- async def health_check():
-     logger.info("Health check endpoint called")
-     return {"status": "healthy"}
-
-
- if __name__ == "__main__":
  uvicorn.run(app, host="0.0.0.0", port=8000)
 
+ from fastapi import FastAPI, HTTPException, Header
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.responses import StreamingResponse
+ from pydantic import BaseModel
+ import openai
+ from typing import List, Optional, Union
+ import logging
+ from itertools import cycle
+ import asyncio
+
+ import uvicorn
+
+ from app import config
+ import requests
+ from datetime import datetime, timezone
+ import json
+ import httpx
+ import uuid
+ import time
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
+ )
+ logger = logging.getLogger(__name__)
+
+ app = FastAPI()
+
+ # Allow cross-origin requests
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # API key configuration
+ API_KEYS = config.settings.API_KEYS
+
+ # Create a round-robin iterator over the keys
+ key_cycle = cycle(API_KEYS)
+
+ # Two independent locks: one for the cycle, one for the failure counts
+ key_cycle_lock = asyncio.Lock()
+ failure_count_lock = asyncio.Lock()
+
+ # Track a failure count for each key
+ key_failure_counts = {key: 0 for key in API_KEYS}
+ MAX_FAILURES = 10  # failure threshold before a key is skipped
+ MAX_RETRIES = 3  # maximum number of retries per request
+
+
+ async def get_next_key():
+     """Return the next key in the cycle without checking its failure count."""
+     async with key_cycle_lock:
+         return next(key_cycle)
+
+ async def is_key_valid(key):
+     """Check whether the key is still usable."""
+     async with failure_count_lock:
+         return key_failure_counts[key] < MAX_FAILURES
+
+ async def reset_failure_counts():
+     """Reset the failure count of every key."""
+     async with failure_count_lock:
+         for key in key_failure_counts:
+             key_failure_counts[key] = 0
+
+ async def get_next_working_key():
+     """Get the next usable API key."""
+     initial_key = await get_next_key()
+     current_key = initial_key
+
+     while True:
+         if await is_key_valid(current_key):
+             return current_key
+
+         current_key = await get_next_key()
+         if current_key == initial_key:  # we have gone around the whole cycle
+             await reset_failure_counts()
+             return current_key
+
+ async def handle_api_failure(api_key):
+     """Handle a failed API call."""
+     async with failure_count_lock:
+         key_failure_counts[api_key] += 1
+         if key_failure_counts[api_key] >= MAX_FAILURES:
+             logger.warning(f"API key {api_key} has failed {MAX_FAILURES} times, switching to next key")
+
+     # Acquire the replacement key outside the lock
+     return await get_next_working_key()
+
+
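The helpers above implement round-robin key selection with per-key failure counts: keys at or above MAX_FAILURES are skipped, and if a full lap of the cycle finds no healthy key, every counter is reset rather than failing outright. A standalone sketch of the same pattern, with placeholder keys:

# Self-contained sketch of the rotation-with-failure-counts pattern above.
import asyncio
from itertools import cycle

KEYS = ["key-a", "key-b", "key-c"]  # placeholder keys
key_cycle = cycle(KEYS)
failures = {k: 0 for k in KEYS}
MAX_FAILURES = 2
lock = asyncio.Lock()

async def next_working_key():
    async with lock:
        first = current = next(key_cycle)
        while failures[current] >= MAX_FAILURES:
            current = next(key_cycle)
            if current == first:  # full lap: all keys exhausted, start over
                for k in failures:
                    failures[k] = 0
                break
        return current

async def demo():
    failures["key-a"] = MAX_FAILURES   # mark key-a as exhausted
    print(await next_working_key())    # -> key-b (key-a is skipped)

asyncio.run(demo())
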
+ class ChatRequest(BaseModel):
+     messages: List[dict]
+     model: str = "gemini-1.5-flash-002"
+     temperature: Optional[float] = 0.7
+     stream: Optional[bool] = False
+     tools: Optional[List[dict]] = []
+     tool_choice: Optional[str] = "auto"
+
+
+ class EmbeddingRequest(BaseModel):
+     input: Union[str, List[str]]
+     model: str = "text-embedding-004"
+     encoding_format: Optional[str] = "float"
+
+
+ async def verify_authorization(authorization: str = Header(None)):
+     if not authorization:
+         logger.error("Missing Authorization header")
+         raise HTTPException(status_code=401, detail="Missing Authorization header")
+     if not authorization.startswith("Bearer "):
+         logger.error("Invalid Authorization header format")
+         raise HTTPException(
+             status_code=401, detail="Invalid Authorization header format"
+         )
+     token = authorization.replace("Bearer ", "")
+     if token not in config.settings.ALLOWED_TOKENS:
+         logger.error("Invalid token")
+         raise HTTPException(status_code=401, detail="Invalid token")
+     return token
+
+
+ def get_gemini_models(api_key):
+     base_url = "https://generativelanguage.googleapis.com/v1beta"
+     url = f"{base_url}/models?key={api_key}"
+
+     try:
+         response = requests.get(url)
+         if response.status_code == 200:
+             gemini_models = response.json()
+             return convert_to_openai_models_format(gemini_models)
+         else:
+             print(f"Error: {response.status_code}")
+             print(response.text)
+             return None
+
+     except requests.RequestException as e:
+         print(f"Request failed: {e}")
+         return None
+
+
+ def convert_to_openai_models_format(gemini_models):
+     openai_format = {"object": "list", "data": []}
+
+     for model in gemini_models.get("models", []):
+         openai_model = {
+             "id": model["name"].split("/")[-1],  # use the last path segment as the ID
+             "object": "model",
+             "created": int(datetime.now(timezone.utc).timestamp()),  # use the current timestamp
+             "owned_by": "google",  # assume all Gemini models are owned by Google
+             "permission": [],  # the Gemini API may have no direct equivalent of permission info
+             "root": model["name"],
+             "parent": None,  # the Gemini API may have no direct equivalent of parent-model info
+         }
+         openai_format["data"].append(openai_model)
+
+     return openai_format
+
+
+ def convert_messages_to_gemini_format(messages):
+     """Convert OpenAI message format to Gemini format"""
+     gemini_messages = []
+     for message in messages:
+         gemini_message = {
+             "role": "user" if message["role"] == "user" else "model",
+             "parts": [{"text": message["content"]}],
+         }
+         gemini_messages.append(gemini_message)
+     return gemini_messages
+
+
+ def convert_gemini_response_to_openai(response, model, stream=False):
+     """Convert Gemini response to OpenAI format"""
+     if stream:
+         # Handle a streaming chunk
+         chunk = response
+         if not chunk["candidates"]:
+             return None
+
+         return {
+             "id": "chatcmpl-" + str(uuid.uuid4()),
+             "object": "chat.completion.chunk",
+             "created": int(time.time()),
+             "model": model,
+             "choices": [
+                 {
+                     "index": 0,
+                     "delta": {
+                         "content": chunk["candidates"][0]["content"]["parts"][0]["text"]
+                     },
+                     "finish_reason": None,
+                 }
+             ],
+         }
+     else:
+         # Handle a non-streaming response
+         return {
+             "id": "chatcmpl-" + str(uuid.uuid4()),
+             "object": "chat.completion",
+             "created": int(time.time()),
+             "model": model,
+             "choices": [
+                 {
+                     "index": 0,
+                     "message": {
+                         "role": "assistant",
+                         "content": response["candidates"][0]["content"]["parts"][0][
+                             "text"
+                         ],
+                     },
+                     "finish_reason": "stop",
+                 }
+             ],
+             "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
+         }
+
+
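One asymmetry in convert_messages_to_gemini_format is worth noting: every non-"user" role, including "system", is mapped to Gemini's "model" role. A runnable demo of that mapping:

# Demo of the role mapping performed by convert_messages_to_gemini_format.
messages = [
    {"role": "system", "content": "Be brief."},
    {"role": "user", "content": "Hi"},
]
gemini = [
    {"role": "user" if m["role"] == "user" else "model",
     "parts": [{"text": m["content"]}]}
    for m in messages
]
print(gemini)
# [{'role': 'model', 'parts': [{'text': 'Be brief.'}]},
#  {'role': 'user', 'parts': [{'text': 'Hi'}]}]
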
+ @app.get("/v1/models")
+ @app.get("/hf/v1/models")
+ async def list_models(authorization: str = Header(None)):
+     await verify_authorization(authorization)
+     api_key = await get_next_working_key()
+     logger.info(f"Using API key: {api_key}")
+     try:
+         response = get_gemini_models(api_key)
+         logger.info("Successfully retrieved models list")
+         return response
+     except Exception as e:
+         logger.error(f"Error listing models: {str(e)}")
+         raise HTTPException(status_code=500, detail=str(e))
+
+
+ @app.post("/v1/chat/completions")
+ @app.post("/hf/v1/chat/completions")
+ async def chat_completion(request: ChatRequest, authorization: str = Header(None)):
+     await verify_authorization(authorization)
+     api_key = await get_next_working_key()
+     logger.info(f"Chat completion request - Model: {request.model}")
+     retries = 0
+
+     while retries < MAX_RETRIES:
+         try:
+             logger.info(f"Attempt {retries + 1} with API key: {api_key}")
+
+             if request.model in config.settings.MODEL_SEARCH:
+                 # Gemini API branch
+                 gemini_messages = convert_messages_to_gemini_format(request.messages)
+                 # Call the Gemini API
+                 payload = {
+                     "contents": gemini_messages,
+                     "generationConfig": {
+                         "temperature": request.temperature,
+                     },
+                     "tools": [{"googleSearch": {}}],
+                 }
+
+                 if request.stream:
+                     logger.info("Streaming response enabled")
+
+                     async def generate():
+                         nonlocal api_key, retries
+                         while retries < MAX_RETRIES:
+                             try:
+                                 async with httpx.AsyncClient() as client:
+                                     stream_url = f"https://generativelanguage.googleapis.com/v1beta/models/{request.model}:streamGenerateContent?alt=sse&key={api_key}"
+                                     async with client.stream("POST", stream_url, json=payload) as response:
+                                         if response.status_code == 429:
+                                             logger.warning(f"Rate limit reached for key: {api_key}")
+                                             api_key = await handle_api_failure(api_key)
+                                             logger.info(f"Retrying with new API key: {api_key}")
+                                             retries += 1
+                                             if retries >= MAX_RETRIES:
+                                                 yield f"data: {json.dumps({'error': 'Max retries reached'})}\n\n"
+                                                 break
+                                             continue
+
+                                         if response.status_code != 200:
+                                             logger.error(f"Error in streaming response: {response.status_code}")
+                                             yield f"data: {json.dumps({'error': f'API error: {response.status_code}'})}\n\n"
+                                             break
+
+                                         async for line in response.aiter_lines():
+                                             if line.startswith("data: "):
+                                                 try:
+                                                     chunk = json.loads(line[6:])
+                                                     openai_chunk = convert_gemini_response_to_openai(
+                                                         chunk, request.model, stream=True
+                                                     )
+                                                     if openai_chunk:
+                                                         yield f"data: {json.dumps(openai_chunk)}\n\n"
+                                                 except json.JSONDecodeError:
+                                                     continue
+                                         yield "data: [DONE]\n\n"
+                                         return
+                             except Exception as e:
+                                 logger.error(f"Stream error: {str(e)}")
+                                 api_key = await handle_api_failure(api_key)
+                                 retries += 1
+                                 if retries >= MAX_RETRIES:
+                                     yield f"data: {json.dumps({'error': 'Max retries reached'})}\n\n"
+                                     break
+                                 continue
+
+                     return StreamingResponse(content=generate(), media_type="text/event-stream")
+                 else:
+                     # Non-streaming response
+                     async with httpx.AsyncClient() as client:
+                         non_stream_url = f"https://generativelanguage.googleapis.com/v1beta/models/{request.model}:generateContent?key={api_key}"
+                         response = await client.post(non_stream_url, json=payload)
+                         gemini_response = response.json()
+                         logger.info("Chat completion successful")
+                         return convert_gemini_response_to_openai(gemini_response, request.model)
+
+             # OpenAI-compatible API branch
+             client = openai.OpenAI(api_key=api_key, base_url=config.settings.BASE_URL)
+             response = client.chat.completions.create(
+                 model=request.model,
+                 messages=request.messages,
+                 temperature=request.temperature,
+                 stream=request.stream if hasattr(request, "stream") else False,
+             )
+
+             if hasattr(request, "stream") and request.stream:
+                 logger.info("Streaming response enabled")
+
+                 async def generate():
+                     for chunk in response:
+                         yield f"data: {chunk.model_dump_json()}\n\n"
+                     logger.info("Chat completion successful")
+                 return StreamingResponse(content=generate(), media_type="text/event-stream")
+
+             logger.info("Chat completion successful")
+             return response
+
+         except Exception as e:
+             logger.error(f"Error in chat completion: {str(e)}")
+             api_key = await handle_api_failure(api_key)
+             retries += 1
+
+             if retries >= MAX_RETRIES:
+                 logger.error("Max retries reached, giving up")
+                 raise HTTPException(status_code=500, detail="Max retries reached with all available API keys")
+
+             logger.info(f"Retrying with new API key: {api_key}")
+             continue
+
+     raise HTTPException(status_code=500, detail="Unexpected error in chat completion")
+
+
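The streaming branch above emits OpenAI-style SSE chunks and terminates with data: [DONE]. A minimal client-side consumer sketch, assuming the service runs on localhost:8000 and "xxx" is one of ALLOWED_TOKENS:

# Consumes the SSE stream produced by the endpoint above (httpx is
# already a project dependency).
import json
import httpx

payload = {
    "model": "gemini-2.0-flash-exp",  # in MODEL_SEARCH -> Gemini branch
    "messages": [{"role": "user", "content": "Hello"}],
    "stream": True,
}
headers = {"Authorization": "Bearer xxx"}

with httpx.stream("POST", "http://localhost:8000/v1/chat/completions",
                  json=payload, headers=headers, timeout=None) as r:
    for line in r.iter_lines():
        if not line.startswith("data: ") or line == "data: [DONE]":
            continue
        chunk = json.loads(line[6:])
        if "choices" not in chunk:  # error chunks carry an "error" field
            break
        print(chunk["choices"][0]["delta"].get("content", ""), end="", flush=True)
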
+ @app.post("/v1/embeddings")
+ @app.post("/hf/v1/embeddings")
+ async def embedding(request: EmbeddingRequest, authorization: str = Header(None)):
+     await verify_authorization(authorization)
+     api_key = await get_next_working_key()
+     logger.info(f"Using API key: {api_key}")
+
+     try:
+         client = openai.OpenAI(api_key=api_key, base_url=config.settings.BASE_URL)
+         response = client.embeddings.create(input=request.input, model=request.model)
+         logger.info("Embedding successful")
+         return response
+     except Exception as e:
+         logger.error(f"Error in embedding: {str(e)}")
+         raise HTTPException(status_code=500, detail=str(e))
+
+
+ @app.get("/health")
+ @app.get("/")
+ async def health_check():
+     logger.info("Health check endpoint called")
+     return {"status": "healthy"}
+
+
+ if __name__ == "__main__":
      uvicorn.run(app, host="0.0.0.0", port=8000)
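
End to end, the service is an OpenAI-compatible proxy, so the stock openai client can talk to it directly. A hedged usage sketch: the api_key passed here is one of ALLOWED_TOKENS ("xxx" in the sample .env), not an upstream provider key, since upstream keys are rotated server-side from API_KEYS.

import openai

client = openai.OpenAI(api_key="xxx", base_url="http://localhost:8000/v1")
resp = client.chat.completions.create(
    model="gemini-1.5-flash-002",
    messages=[{"role": "user", "content": "Hello"}],
)
print(resp.choices[0].message.content)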
requirements.txt CHANGED
@@ -1,6 +1,7 @@
- fastapi
- openai
- pydantic
- pydantic_settings
- uvicorn
- requests

+ fastapi
+ httpx
+ openai
+ pydantic
+ pydantic_settings
+ requests
+ uvicorn