yym68686 committed on
Commit
3f0a3dd
·
1 Parent(s): 1ddc959

Add polling

Browse files
Files changed (3) hide show
  1. .env.example +1 -0
  2. .gitignore +2 -1
  3. main.py +85 -46
.env.example ADDED
@@ -0,0 +1 @@
 
 
1
+ USE_ROUND_ROBIN=true
.gitignore CHANGED
@@ -1,2 +1,3 @@
1
  api.json
2
- api.yaml
 
 
1
  api.json
2
+ api.yaml
3
+ .env
main.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import httpx
2
  import yaml
3
  from contextlib import asynccontextmanager
@@ -82,56 +83,94 @@ async def fetch_response(client, url, headers, payload):
82
  # print(response.text)
83
  return response.json()
84
 
85
- @app.post("/v1/chat/completions")
86
- async def request_model(request: RequestModel, token: str = Depends(verify_api_key)):
87
- model_name = request.model
88
-
89
- for provider in config:
90
- if model_name in provider['model']:
91
- print("provider: ", provider['provider'])
92
- url = provider['base_url']
93
- headers = {
94
- 'Authorization': f"Bearer {provider['api']}",
95
- 'Content-Type': 'application/json'
96
- }
97
-
98
- # 转换消息格式
99
- messages = []
100
- for msg in request.messages:
101
- if isinstance(msg.content, list):
102
- content = " ".join([item.text for item in msg.content if item.type == "text"])
103
- else:
104
- content = msg.content
105
- messages.append({"role": msg.role, "content": content})
106
-
107
- payload = {
108
- "model": model_name,
109
- "messages": messages
110
- }
111
-
112
- # 只有当相应参数存在且不为None时,才添加到payload中
113
- if request.stream is not None:
114
- payload["stream"] = request.stream
115
- if request.include_usage is not None:
116
- payload["include_usage"] = request.include_usage
117
-
118
- if provider['provider'] == 'anthropic':
119
- payload["max_tokens"] = 1000 # 您可能想让这个可配置
120
- else:
121
- if request.logprobs is not None:
122
- payload["logprobs"] = request.logprobs
123
- if request.top_logprobs is not None:
124
- payload["top_logprobs"] = request.top_logprobs
125
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  try:
127
- if request.stream:
128
- return StreamingResponse(fetch_response_stream(app.state.client, url, headers, payload), media_type="text/event-stream")
129
- else:
130
- return await fetch_response(app.state.client, url, headers, payload)
131
  except Exception as e:
132
  raise HTTPException(status_code=500, detail=f"Error calling API: {str(e)}")
133
 
134
- raise HTTPException(status_code=404, detail="No matching model found")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
  if __name__ == '__main__':
137
  import uvicorn
 
1
+ import os
2
  import httpx
3
  import yaml
4
  from contextlib import asynccontextmanager
 
83
  # print(response.text)
84
  return response.json()
85
 
86
async def process_request(request: RequestModel, provider: Dict):
    """Forward *request* to a single upstream *provider*.

    Builds an OpenAI-style payload from the incoming request, then either
    streams the upstream response (when ``request.stream`` is truthy) or
    awaits and returns the parsed JSON body.
    """
    print("provider: ", provider['provider'])

    base_url = provider['base_url']
    request_headers = {
        'Authorization': f"Bearer {provider['api']}",
        'Content-Type': 'application/json',
    }

    def _flatten(raw):
        # Multi-part message content is collapsed into one space-joined
        # string of its text segments; plain strings pass through as-is.
        if isinstance(raw, list):
            return " ".join(part.text for part in raw if part.type == "text")
        return raw

    payload = {
        "model": request.model,
        "messages": [
            {"role": msg.role, "content": _flatten(msg.content)}
            for msg in request.messages
        ],
    }

    # Optional fields are only forwarded when the caller actually set them.
    if request.stream is not None:
        payload["stream"] = request.stream
    if request.include_usage is not None:
        payload["include_usage"] = request.include_usage

    if provider['provider'] == 'anthropic':
        # NOTE(review): hard-coded cap — original comment suggests making
        # this configurable.
        payload["max_tokens"] = 1000
    else:
        for field in ("logprobs", "top_logprobs"):
            value = getattr(request, field)
            if value is not None:
                payload[field] = value

    if not request.stream:
        return await fetch_response(app.state.client, base_url, request_headers, payload)
    return StreamingResponse(
        fetch_response_stream(app.state.client, base_url, request_headers, payload),
        media_type="text/event-stream",
    )
126
+
127
class ModelRequestHandler:
    """Routes chat-completion requests to a provider selected from ``config``.

    Keeps a single rotation index so that, when the USE_ROUND_ROBIN
    environment flag is enabled, successive requests are spread across the
    providers that serve the requested model.
    """

    def __init__(self):
        # Index of the provider used last; -1 means "none yet".
        # NOTE(review): shared across all models and not guarded against
        # concurrent requests, so rotation is best-effort, not strict.
        self.last_provider_index = -1

    def get_matching_providers(self, model_name):
        """Return every configured provider whose model list contains *model_name*."""
        return [provider for provider in config if model_name in provider['model']]

    async def request_model(self, request: RequestModel, token: str):
        """Dispatch *request* to one matching provider.

        Raises HTTP 404 when no provider serves the model, HTTP 500 when the
        chosen provider (or, in round-robin mode, all of them) fails.
        """
        model_name = request.model
        matching_providers = self.get_matching_providers(model_name)
        # Log provider names only — the full config entries contain API keys
        # and must not be written to stdout.
        print("matching_providers", [p['provider'] for p in matching_providers])

        if not matching_providers:
            raise HTTPException(status_code=404, detail="No matching model found")

        # Polling (round-robin) is toggled via the environment, re-read on
        # every request so it can be flipped without a restart.
        use_round_robin = os.environ.get('USE_ROUND_ROBIN', 'false').lower() == 'true'

        if use_round_robin:
            return await self.round_robin_request(request, matching_providers)

        # Single-provider mode: always use the first match.
        provider = matching_providers[0]
        try:
            return await process_request(request, provider)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error calling API: {str(e)}")

    async def round_robin_request(self, request: RequestModel, providers: List[Dict]):
        """Try each provider in rotation order; return the first success.

        Raises HTTP 500 only after every matching provider has failed.
        """
        num_providers = len(providers)
        for _ in range(num_providers):
            self.last_provider_index = (self.last_provider_index + 1) % num_providers
            provider = providers[self.last_provider_index]
            try:
                return await process_request(request, provider)
            except Exception as e:
                # Best-effort failover: report the failure and move on.
                print(f"Error with provider {provider['provider']}: {str(e)}")
                continue
        raise HTTPException(status_code=500, detail="All providers failed")
168
+
169
# Module-level singleton so the round-robin rotation index survives requests.
model_handler = ModelRequestHandler()


@app.post("/v1/chat/completions")
async def request_model(request: RequestModel, token: str = Depends(verify_api_key)):
    """Chat-completions endpoint: delegates to the shared ModelRequestHandler."""
    result = await model_handler.request_model(request, token)
    return result
174
 
175
  if __name__ == '__main__':
176
  import uvicorn