Add round-robin provider selection
Browse files- .env.example +1 -0
- .gitignore +2 -1
- main.py +85 -46
.env.example
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
USE_ROUND_ROBIN=true
|
.gitignore
CHANGED
@@ -1,2 +1,3 @@
|
|
1 |
api.json
|
2 |
-
api.yaml
|
|
|
|
1 |
api.json
|
2 |
+
api.yaml
|
3 |
+
.env
|
main.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import httpx
|
2 |
import yaml
|
3 |
from contextlib import asynccontextmanager
|
@@ -82,56 +83,94 @@ async def fetch_response(client, url, headers, payload):
|
|
82 |
# print(response.text)
|
83 |
return response.json()
|
84 |
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
try:
|
127 |
-
|
128 |
-
return StreamingResponse(fetch_response_stream(app.state.client, url, headers, payload), media_type="text/event-stream")
|
129 |
-
else:
|
130 |
-
return await fetch_response(app.state.client, url, headers, payload)
|
131 |
except Exception as e:
|
132 |
raise HTTPException(status_code=500, detail=f"Error calling API: {str(e)}")
|
133 |
|
134 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
135 |
|
136 |
if __name__ == '__main__':
|
137 |
import uvicorn
|
|
|
1 |
+
import os
|
2 |
import httpx
|
3 |
import yaml
|
4 |
from contextlib import asynccontextmanager
|
|
|
83 |
# print(response.text)
|
84 |
return response.json()
|
85 |
|
86 |
+
async def process_request(request: RequestModel, provider: Dict):
    """Forward one chat request to a single upstream provider.

    Builds the upstream headers and JSON payload from ``request`` and the
    ``provider`` entry, then dispatches the call through the shared HTTP
    client stored on ``app.state.client``.

    Args:
        request: Incoming request. The ``model``, ``messages``, ``stream``,
            ``include_usage``, ``logprobs`` and ``top_logprobs`` attributes
            are read.
        provider: Provider entry. The ``provider``, ``base_url`` and ``api``
            keys are read.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream`` when
        ``request.stream`` is truthy, otherwise the awaited result of
        ``fetch_response``.
    """
    provider_name = provider['provider']
    print("provider: ", provider_name)

    endpoint = provider['base_url']
    request_headers = {
        'Authorization': f"Bearer {provider['api']}",
        'Content-Type': 'application/json'
    }

    def _as_text(content):
        # List-style content is collapsed to the space-joined text of its
        # "text" items; any other content is forwarded unchanged.
        if not isinstance(content, list):
            return content
        return " ".join([part.text for part in content if part.type == "text"])

    # Convert the message format.
    converted = []
    for message in request.messages:
        text = _as_text(message.content)
        converted.append({"role": message.role, "content": text})

    payload = {"model": request.model, "messages": converted}

    # Optional parameters are added only when present and not None.
    for field in ("stream", "include_usage"):
        value = getattr(request, field)
        if value is not None:
            payload[field] = value

    if provider_name == 'anthropic':
        payload["max_tokens"] = 1000  # You may want to make this configurable.
    else:
        for field in ("logprobs", "top_logprobs"):
            value = getattr(request, field)
            if value is not None:
                payload[field] = value

    if not request.stream:
        return await fetch_response(app.state.client, endpoint, request_headers, payload)

    event_stream = fetch_response_stream(app.state.client, endpoint, request_headers, payload)
    return StreamingResponse(event_stream, media_type="text/event-stream")
|
126 |
+
|
127 |
+
class ModelRequestHandler:
    """Pick a configured provider for a requested model and dispatch to it.

    Providers come from the module-level ``config`` list. When the
    ``USE_ROUND_ROBIN`` environment variable equals ``"true"`` (compared
    case-insensitively), requests rotate across the matching providers and
    fall over to the next one on failure; otherwise the first matching
    provider is always used.
    """

    def __init__(self):
        # Index of the provider most recently tried by any round-robin call.
        # Kept (and still updated) for backward compatibility.
        self.last_provider_index = -1
        # Rotation cursor per model name. A single shared cursor lets requests
        # for one model disturb the rotation of another whose provider list
        # has a different length, so some providers would never be selected.
        self._last_index_by_model = {}

    def get_matching_providers(self, model_name):
        """Return the entries of ``config`` whose ``'model'`` value contains ``model_name``."""
        return [provider for provider in config if model_name in provider['model']]

    async def request_model(self, request: RequestModel, token: str):
        """Dispatch ``request`` to a provider that serves ``request.model``.

        Args:
            request: Incoming chat request.
            token: Caller's API key; not used by this method.

        Returns:
            Whatever ``process_request`` returns for the chosen provider.

        Raises:
            HTTPException: 404 when no provider matches the model; 500 when
                the provider call fails.
        """
        model_name = request.model
        matching_providers = self.get_matching_providers(model_name)
        # Print provider names only: the full entries carry the upstream API
        # key ('api'), which must not end up in stdout or log files.
        print("matching_providers", [p.get('provider') for p in matching_providers])

        if not matching_providers:
            raise HTTPException(status_code=404, detail="No matching model found")

        # Check whether round-robin selection is enabled.
        use_round_robin = os.environ.get('USE_ROUND_ROBIN', 'false').lower() == 'true'

        if use_round_robin:
            return await self.round_robin_request(request, matching_providers)

        # Otherwise use the first matching provider.
        provider = matching_providers[0]
        try:
            return await process_request(request, provider)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error calling API: {str(e)}")

    async def round_robin_request(self, request: RequestModel, providers: List[Dict]):
        """Try ``providers`` in rotation until one succeeds.

        Each attempt advances the rotation cursor for ``request.model``, so
        the next request for the same model starts after the provider that
        was tried last.

        Args:
            request: Incoming chat request.
            providers: Candidate provider entries for the requested model.

        Returns:
            The first successful result of ``process_request``.

        Raises:
            HTTPException: 500 when every provider raised (or ``providers``
                is empty); the last provider error is chained as the cause.
        """
        # NOTE(review): for streaming requests process_request returns a
        # StreamingResponse without awaiting the upstream call, so upstream
        # errors presumably surface after this method has returned and will
        # not trigger failover - confirm against fetch_response_stream.
        num_providers = len(providers)
        model_name = request.model
        last_error = None
        for _ in range(num_providers):
            next_index = (self._last_index_by_model.get(model_name, -1) + 1) % num_providers
            self._last_index_by_model[model_name] = next_index
            self.last_provider_index = next_index
            provider = providers[next_index]
            try:
                return await process_request(request, provider)
            except Exception as e:
                print(f"Error with provider {provider['provider']}: {str(e)}")
                last_error = e
        raise HTTPException(status_code=500, detail="All providers failed") from last_error
|
168 |
+
|
169 |
+
# Single handler instance shared by every request served by this module.
model_handler = ModelRequestHandler()


@app.post("/v1/chat/completions")
async def request_model(request: RequestModel, token: str = Depends(verify_api_key)):
    """Chat-completions endpoint: delegate to the shared ``model_handler``.

    ``token`` is supplied by the ``verify_api_key`` dependency and passed
    through to the handler unchanged.
    """
    result = await model_handler.request_model(request, token)
    return result
|
174 |
|
175 |
if __name__ == '__main__':
|
176 |
import uvicorn
|