Support assigning different models to different API keys
Browse files- main.py +45 -21
- response.py +17 -11
main.py
CHANGED
@@ -27,12 +27,6 @@ async def lifespan(app: FastAPI):
|
|
27 |
|
28 |
app = FastAPI(lifespan=lifespan)
|
29 |
|
30 |
-
# 模拟存储API Key的数据库
|
31 |
-
api_keys_db = {
|
32 |
-
"sk-KjjI60Yf0JFcsvgRmXqFwgGmWUd9GZnmi3KlvowmRWpWpQRo": "user1",
|
33 |
-
# 可以添加更多的API Key
|
34 |
-
}
|
35 |
-
|
36 |
# 安全性依赖
|
37 |
security = HTTPBearer()
|
38 |
|
@@ -49,7 +43,7 @@ def load_config():
|
|
49 |
return []
|
50 |
|
51 |
config = load_config()
|
52 |
-
for index, provider in enumerate(config):
|
53 |
model_dict = {}
|
54 |
for model in provider['model']:
|
55 |
if type(model) == str:
|
@@ -57,8 +51,10 @@ for index, provider in enumerate(config):
|
|
57 |
if type(model) == dict:
|
58 |
model_dict.update({value: key for key, value in model.items()})
|
59 |
provider['model'] = model_dict
|
60 |
-
config[index] = provider
|
61 |
-
|
|
|
|
|
62 |
|
63 |
async def process_request(request: RequestModel, provider: Dict):
|
64 |
print("provider: ", provider['provider'])
|
@@ -93,17 +89,30 @@ class ModelRequestHandler:
|
|
93 |
def __init__(self):
|
94 |
self.last_provider_index = -1
|
95 |
|
96 |
-
def get_matching_providers(self, model_name):
|
97 |
# for provider in config:
|
98 |
# print("provider", model_name, list(provider['model'].keys()))
|
99 |
# if model_name in provider['model'].keys():
|
100 |
# print("provider", provider)
|
101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
|
103 |
async def request_model(self, request: RequestModel, token: str):
|
104 |
model_name = request.model
|
105 |
-
matching_providers = self.get_matching_providers(model_name)
|
106 |
-
|
107 |
|
108 |
if not matching_providers:
|
109 |
raise HTTPException(status_code=404, detail="No matching model found")
|
@@ -153,7 +162,7 @@ async def log_requests(request: Request, call_next):
|
|
153 |
|
154 |
def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
|
155 |
token = credentials.credentials
|
156 |
-
if token not in
|
157 |
raise HTTPException(status_code=403, detail="Invalid or missing API Key")
|
158 |
return token
|
159 |
|
@@ -161,27 +170,42 @@ def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)
|
|
161 |
async def request_model(request: RequestModel, token: str = Depends(verify_api_key)):
|
162 |
return await model_handler.request_model(request, token)
|
163 |
|
164 |
-
def get_all_models():
|
165 |
all_models = []
|
166 |
unique_models = set()
|
167 |
|
168 |
-
|
169 |
-
|
|
|
|
|
|
|
170 |
if model not in unique_models:
|
171 |
unique_models.add(model)
|
172 |
model_info = {
|
173 |
"id": model,
|
174 |
"object": "model",
|
175 |
"created": 1720524448858,
|
176 |
-
"owned_by":
|
177 |
}
|
178 |
all_models.append(model_info)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
|
180 |
return all_models
|
181 |
|
182 |
-
@app.
|
183 |
-
async def list_models():
|
184 |
-
models = get_all_models()
|
185 |
return {
|
186 |
"object": "list",
|
187 |
"data": models
|
|
|
27 |
|
28 |
app = FastAPI(lifespan=lifespan)
|
29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
# 安全性依赖
|
31 |
security = HTTPBearer()
|
32 |
|
|
|
43 |
return []
|
44 |
|
45 |
config = load_config()
|
46 |
+
for index, provider in enumerate(config['providers']):
|
47 |
model_dict = {}
|
48 |
for model in provider['model']:
|
49 |
if type(model) == str:
|
|
|
51 |
if type(model) == dict:
|
52 |
model_dict.update({value: key for key, value in model.items()})
|
53 |
provider['model'] = model_dict
|
54 |
+
config['providers'][index] = provider
|
55 |
+
api_keys_db = config['api_keys']
|
56 |
+
api_list = [item["api"] for item in api_keys_db]
|
57 |
+
print(json.dumps(config, indent=4, ensure_ascii=False))
|
58 |
|
59 |
async def process_request(request: RequestModel, provider: Dict):
|
60 |
print("provider: ", provider['provider'])
|
|
|
89 |
def __init__(self):
|
90 |
self.last_provider_index = -1
|
91 |
|
92 |
+
def get_matching_providers(self, model_name, token):
|
93 |
# for provider in config:
|
94 |
# print("provider", model_name, list(provider['model'].keys()))
|
95 |
# if model_name in provider['model'].keys():
|
96 |
# print("provider", provider)
|
97 |
+
api_index = api_list.index(token)
|
98 |
+
provider_rules = {}
|
99 |
+
|
100 |
+
for model in config['api_keys'][api_index]['model']:
|
101 |
+
if "/" in model:
|
102 |
+
provider_name = model.split("/")[0]
|
103 |
+
model = model.split("/")[1]
|
104 |
+
if model_name == model:
|
105 |
+
provider_rules[provider_name] = model
|
106 |
+
provider_list = []
|
107 |
+
for provider in config['providers']:
|
108 |
+
if model_name in provider['model'].keys() and ((provider_rules != {} and provider['provider'] in provider_rules.keys()) or provider_rules == {}):
|
109 |
+
provider_list.append(provider)
|
110 |
+
return provider_list
|
111 |
|
112 |
async def request_model(self, request: RequestModel, token: str):
|
113 |
model_name = request.model
|
114 |
+
matching_providers = self.get_matching_providers(model_name, token)
|
115 |
+
print("matching_providers", json.dumps(matching_providers, indent=4, ensure_ascii=False))
|
116 |
|
117 |
if not matching_providers:
|
118 |
raise HTTPException(status_code=404, detail="No matching model found")
|
|
|
162 |
|
163 |
def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
|
164 |
token = credentials.credentials
|
165 |
+
if token not in api_list:
|
166 |
raise HTTPException(status_code=403, detail="Invalid or missing API Key")
|
167 |
return token
|
168 |
|
|
|
170 |
async def request_model(request: RequestModel, token: str = Depends(verify_api_key)):
|
171 |
return await model_handler.request_model(request, token)
|
172 |
|
173 |
+
def get_all_models(token):
|
174 |
all_models = []
|
175 |
unique_models = set()
|
176 |
|
177 |
+
if token not in api_list:
|
178 |
+
raise HTTPException(status_code=403, detail="Invalid or missing API Key")
|
179 |
+
api_index = api_list.index(token)
|
180 |
+
if config['api_keys'][api_index]['model']:
|
181 |
+
for model in config['api_keys'][api_index]['model']:
|
182 |
if model not in unique_models:
|
183 |
unique_models.add(model)
|
184 |
model_info = {
|
185 |
"id": model,
|
186 |
"object": "model",
|
187 |
"created": 1720524448858,
|
188 |
+
"owned_by": model
|
189 |
}
|
190 |
all_models.append(model_info)
|
191 |
+
else:
|
192 |
+
for provider in config["providers"]:
|
193 |
+
for model in provider['model'].keys():
|
194 |
+
if model not in unique_models:
|
195 |
+
unique_models.add(model)
|
196 |
+
model_info = {
|
197 |
+
"id": model,
|
198 |
+
"object": "model",
|
199 |
+
"created": 1720524448858,
|
200 |
+
"owned_by": provider['provider']
|
201 |
+
}
|
202 |
+
all_models.append(model_info)
|
203 |
|
204 |
return all_models
|
205 |
|
206 |
+
@app.post("/v1/models")
|
207 |
+
async def list_models(token: str = Depends(verify_api_key)):
|
208 |
+
models = get_all_models(token)
|
209 |
return {
|
210 |
"object": "list",
|
211 |
"data": models
|
response.py
CHANGED
@@ -136,14 +136,20 @@ async def fetch_response(client, url, headers, payload):
|
|
136 |
return response.json()
|
137 |
|
138 |
async def fetch_response_stream(client, url, headers, payload, engine, model):
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
136 |
return response.json()
|
137 |
|
138 |
async def fetch_response_stream(client, url, headers, payload, engine, model):
|
139 |
+
for _ in range(2):
|
140 |
+
try:
|
141 |
+
if engine == "gemini":
|
142 |
+
async for chunk in fetch_gemini_response_stream(client, url, headers, payload, model):
|
143 |
+
yield chunk
|
144 |
+
elif engine == "claude":
|
145 |
+
async for chunk in fetch_claude_response_stream(client, url, headers, payload, model):
|
146 |
+
yield chunk
|
147 |
+
elif engine == "gpt":
|
148 |
+
async for chunk in fetch_gpt_response_stream(client, url, headers, payload):
|
149 |
+
yield chunk
|
150 |
+
else:
|
151 |
+
raise ValueError("Unknown response")
|
152 |
+
break
|
153 |
+
except httpx.ConnectError as e:
|
154 |
+
print(f"连接错误: {e}")
|
155 |
+
continue
|