yym68686 commited on
Commit
9410047
·
1 Parent(s): 3e3ea9a

Support assigning different models to different API keys

Browse files
Files changed (2) hide show
  1. main.py +45 -21
  2. response.py +17 -11
main.py CHANGED
@@ -27,12 +27,6 @@ async def lifespan(app: FastAPI):
27
 
28
  app = FastAPI(lifespan=lifespan)
29
 
30
- # 模拟存储API Key的数据库
31
- api_keys_db = {
32
- "sk-KjjI60Yf0JFcsvgRmXqFwgGmWUd9GZnmi3KlvowmRWpWpQRo": "user1",
33
- # 可以添加更多的API Key
34
- }
35
-
36
  # 安全性依赖
37
  security = HTTPBearer()
38
 
@@ -49,7 +43,7 @@ def load_config():
49
  return []
50
 
51
  config = load_config()
52
- for index, provider in enumerate(config):
53
  model_dict = {}
54
  for model in provider['model']:
55
  if type(model) == str:
@@ -57,8 +51,10 @@ for index, provider in enumerate(config):
57
  if type(model) == dict:
58
  model_dict.update({value: key for key, value in model.items()})
59
  provider['model'] = model_dict
60
- config[index] = provider
61
- # print(json.dumps(config, indent=4, ensure_ascii=False))
 
 
62
 
63
  async def process_request(request: RequestModel, provider: Dict):
64
  print("provider: ", provider['provider'])
@@ -93,17 +89,30 @@ class ModelRequestHandler:
93
  def __init__(self):
94
  self.last_provider_index = -1
95
 
96
- def get_matching_providers(self, model_name):
97
  # for provider in config:
98
  # print("provider", model_name, list(provider['model'].keys()))
99
  # if model_name in provider['model'].keys():
100
  # print("provider", provider)
101
- return [provider for provider in config if model_name in provider['model'].keys()]
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
  async def request_model(self, request: RequestModel, token: str):
104
  model_name = request.model
105
- matching_providers = self.get_matching_providers(model_name)
106
- # print("matching_providers", json.dumps(matching_providers, indent=4, ensure_ascii=False))
107
 
108
  if not matching_providers:
109
  raise HTTPException(status_code=404, detail="No matching model found")
@@ -153,7 +162,7 @@ async def log_requests(request: Request, call_next):
153
 
154
  def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
155
  token = credentials.credentials
156
- if token not in api_keys_db:
157
  raise HTTPException(status_code=403, detail="Invalid or missing API Key")
158
  return token
159
 
@@ -161,27 +170,42 @@ def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)
161
  async def request_model(request: RequestModel, token: str = Depends(verify_api_key)):
162
  return await model_handler.request_model(request, token)
163
 
164
- def get_all_models():
165
  all_models = []
166
  unique_models = set()
167
 
168
- for provider in config:
169
- for model in provider['model'].keys():
 
 
 
170
  if model not in unique_models:
171
  unique_models.add(model)
172
  model_info = {
173
  "id": model,
174
  "object": "model",
175
  "created": 1720524448858,
176
- "owned_by": provider['provider']
177
  }
178
  all_models.append(model_info)
 
 
 
 
 
 
 
 
 
 
 
 
179
 
180
  return all_models
181
 
182
- @app.get("/v1/models")
183
- async def list_models():
184
- models = get_all_models()
185
  return {
186
  "object": "list",
187
  "data": models
 
27
 
28
  app = FastAPI(lifespan=lifespan)
29
 
 
 
 
 
 
 
30
  # 安全性依赖
31
  security = HTTPBearer()
32
 
 
43
  return []
44
 
45
  config = load_config()
46
+ for index, provider in enumerate(config['providers']):
47
  model_dict = {}
48
  for model in provider['model']:
49
  if type(model) == str:
 
51
  if type(model) == dict:
52
  model_dict.update({value: key for key, value in model.items()})
53
  provider['model'] = model_dict
54
+ config['providers'][index] = provider
55
+ api_keys_db = config['api_keys']
56
+ api_list = [item["api"] for item in api_keys_db]
57
+ print(json.dumps(config, indent=4, ensure_ascii=False))
58
 
59
  async def process_request(request: RequestModel, provider: Dict):
60
  print("provider: ", provider['provider'])
 
89
  def __init__(self):
90
  self.last_provider_index = -1
91
 
92
+ def get_matching_providers(self, model_name, token):
93
  # for provider in config:
94
  # print("provider", model_name, list(provider['model'].keys()))
95
  # if model_name in provider['model'].keys():
96
  # print("provider", provider)
97
+ api_index = api_list.index(token)
98
+ provider_rules = {}
99
+
100
+ for model in config['api_keys'][api_index]['model']:
101
+ if "/" in model:
102
+ provider_name = model.split("/")[0]
103
+ model = model.split("/")[1]
104
+ if model_name == model:
105
+ provider_rules[provider_name] = model
106
+ provider_list = []
107
+ for provider in config['providers']:
108
+ if model_name in provider['model'].keys() and ((provider_rules != {} and provider['provider'] in provider_rules.keys()) or provider_rules == {}):
109
+ provider_list.append(provider)
110
+ return provider_list
111
 
112
  async def request_model(self, request: RequestModel, token: str):
113
  model_name = request.model
114
+ matching_providers = self.get_matching_providers(model_name, token)
115
+ print("matching_providers", json.dumps(matching_providers, indent=4, ensure_ascii=False))
116
 
117
  if not matching_providers:
118
  raise HTTPException(status_code=404, detail="No matching model found")
 
162
 
163
  def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
164
  token = credentials.credentials
165
+ if token not in api_list:
166
  raise HTTPException(status_code=403, detail="Invalid or missing API Key")
167
  return token
168
 
 
170
  async def request_model(request: RequestModel, token: str = Depends(verify_api_key)):
171
  return await model_handler.request_model(request, token)
172
 
173
+ def get_all_models(token):
174
  all_models = []
175
  unique_models = set()
176
 
177
+ if token not in api_list:
178
+ raise HTTPException(status_code=403, detail="Invalid or missing API Key")
179
+ api_index = api_list.index(token)
180
+ if config['api_keys'][api_index]['model']:
181
+ for model in config['api_keys'][api_index]['model']:
182
  if model not in unique_models:
183
  unique_models.add(model)
184
  model_info = {
185
  "id": model,
186
  "object": "model",
187
  "created": 1720524448858,
188
+ "owned_by": model
189
  }
190
  all_models.append(model_info)
191
+ else:
192
+ for provider in config["providers"]:
193
+ for model in provider['model'].keys():
194
+ if model not in unique_models:
195
+ unique_models.add(model)
196
+ model_info = {
197
+ "id": model,
198
+ "object": "model",
199
+ "created": 1720524448858,
200
+ "owned_by": provider['provider']
201
+ }
202
+ all_models.append(model_info)
203
 
204
  return all_models
205
 
206
+ @app.post("/v1/models")
207
+ async def list_models(token: str = Depends(verify_api_key)):
208
+ models = get_all_models(token)
209
  return {
210
  "object": "list",
211
  "data": models
response.py CHANGED
@@ -136,14 +136,20 @@ async def fetch_response(client, url, headers, payload):
136
  return response.json()
137
 
138
  async def fetch_response_stream(client, url, headers, payload, engine, model):
139
- if engine == "gemini":
140
- async for chunk in fetch_gemini_response_stream(client, url, headers, payload, model):
141
- yield chunk
142
- elif engine == "claude":
143
- async for chunk in fetch_claude_response_stream(client, url, headers, payload, model):
144
- yield chunk
145
- elif engine == "gpt":
146
- async for chunk in fetch_gpt_response_stream(client, url, headers, payload):
147
- yield chunk
148
- else:
149
- raise ValueError("Unknown response")
 
 
 
 
 
 
 
136
  return response.json()
137
 
138
  async def fetch_response_stream(client, url, headers, payload, engine, model):
139
+ for _ in range(2):
140
+ try:
141
+ if engine == "gemini":
142
+ async for chunk in fetch_gemini_response_stream(client, url, headers, payload, model):
143
+ yield chunk
144
+ elif engine == "claude":
145
+ async for chunk in fetch_claude_response_stream(client, url, headers, payload, model):
146
+ yield chunk
147
+ elif engine == "gpt":
148
+ async for chunk in fetch_gpt_response_stream(client, url, headers, payload):
149
+ yield chunk
150
+ else:
151
+ raise ValueError("Unknown response")
152
+ break
153
+ except httpx.ConnectError as e:
154
+ print(f"连接错误: {e}")
155
+ continue