yym68686 commited on
Commit
d62232c
·
1 Parent(s): 2cdadb6

Refactor code

Browse files
Files changed (2) hide show
  1. main.py +5 -150
  2. utils.py +136 -0
main.py CHANGED
@@ -1,21 +1,20 @@
1
  import json
2
- import yaml
 
3
  import httpx
4
- import logging
5
  import secrets
6
- import traceback
7
  from contextlib import asynccontextmanager
8
 
9
- from fastapi import FastAPI, Request, HTTPException, Depends
10
- from fastapi.responses import StreamingResponse
11
- from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
12
 
13
  from models import RequestModel
 
14
  from request import get_payload
15
  from response import fetch_response, fetch_response_stream
16
 
17
  from typing import List, Dict
18
  from urllib.parse import urlparse
 
19
 
20
  @asynccontextmanager
21
  async def lifespan(app: FastAPI):
@@ -28,66 +27,6 @@ async def lifespan(app: FastAPI):
28
 
29
  app = FastAPI(lifespan=lifespan)
30
 
31
- # 安全性依赖
32
- security = HTTPBearer()
33
-
34
- # 读取YAML配置文件
35
- def load_config():
36
- try:
37
- with open('./api.yaml', 'r') as f:
38
- conf = yaml.safe_load(f)
39
- for index, provider in enumerate(conf['providers']):
40
- model_dict = {}
41
- for model in provider['model']:
42
- if type(model) == str:
43
- model_dict[model] = model
44
- if type(model) == dict:
45
- model_dict.update({value: key for key, value in model.items()})
46
- provider['model'] = model_dict
47
- conf['providers'][index] = provider
48
- api_keys_db = conf['api_keys']
49
- api_list = [item["api"] for item in api_keys_db]
50
- # print(json.dumps(conf, indent=4, ensure_ascii=False))
51
- return conf, api_keys_db, api_list
52
- except FileNotFoundError:
53
- print("配置文件 'api.yaml' 未找到。请确保文件存在于正确的位置。")
54
- return [], [], []
55
- except yaml.YAMLError:
56
- print("配置文件 'api.yaml' 格式不正确。请检查 YAML 格式。")
57
- return [], [], []
58
-
59
- config, api_keys_db, api_list = load_config()
60
-
61
- async def error_handling_wrapper(generator, status_code=200):
62
- try:
63
- first_item = await generator.__anext__()
64
- first_item_str = first_item
65
- if isinstance(first_item_str, (bytes, bytearray)):
66
- first_item_str = first_item_str.decode("utf-8")
67
- if isinstance(first_item_str, str):
68
- if first_item_str.startswith("data: "):
69
- first_item_str = first_item_str[6:]
70
- elif first_item_str.startswith("data:"):
71
- first_item_str = first_item_str[5:]
72
- first_item_str = json.loads(first_item_str)
73
- if isinstance(first_item_str, dict) and 'error' in first_item_str:
74
- print('\033[31m')
75
- print(f"first_item_str: {first_item_str}")
76
- print('\033[0m')
77
- # 如果第一个 yield 的项是错误信息,抛出 HTTPException
78
- raise HTTPException(status_code=status_code, detail=first_item_str)
79
-
80
- # 如果不是错误,创建一个新的生成器,首先yield第一个项,然后yield剩余的项
81
- async def new_generator():
82
- yield first_item
83
- async for item in generator:
84
- yield item
85
-
86
- return new_generator()
87
- except StopAsyncIteration:
88
- # 处理生成器为空的情况
89
- return []
90
-
91
  async def process_request(request: RequestModel, provider: Dict):
92
  print("provider: ", provider['provider'])
93
  url = provider['base_url']
@@ -201,94 +140,10 @@ class ModelRequestHandler:
201
 
202
  model_handler = ModelRequestHandler()
203
 
204
- @app.middleware("http")
205
- async def log_requests(request: Request, call_next):
206
- # 打印请求信息
207
- logging.info(f"Request: {request.method} {request.url}")
208
- # 打印请求体(如果有)
209
- if request.method in ["POST", "PUT", "PATCH"]:
210
- body = await request.body()
211
- logging.info(f"Request Body: {body.decode('utf-8')}")
212
-
213
- response = await call_next(request)
214
- return response
215
-
216
- def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
217
- token = credentials.credentials
218
- if token not in api_list:
219
- raise HTTPException(status_code=403, detail="Invalid or missing API Key")
220
- return token
221
-
222
  @app.post("/v1/chat/completions")
223
  async def request_model(request: RequestModel, token: str = Depends(verify_api_key)):
224
  return await model_handler.request_model(request, token)
225
 
226
- def get_all_models(token):
227
- all_models = []
228
- unique_models = set()
229
-
230
- if token not in api_list:
231
- raise HTTPException(status_code=403, detail="Invalid or missing API Key")
232
- api_index = api_list.index(token)
233
- if config['api_keys'][api_index]['model']:
234
- for model in config['api_keys'][api_index]['model']:
235
- if "/" in model:
236
- provider = model.split("/")[0]
237
- model = model.split("/")[1]
238
- if model == "*":
239
- for provider_item in config["providers"]:
240
- if provider_item['provider'] != provider:
241
- continue
242
- for model_item in provider_item['model'].keys():
243
- if model_item not in unique_models:
244
- unique_models.add(model_item)
245
- model_info = {
246
- "id": model_item,
247
- "object": "model",
248
- "created": 1720524448858,
249
- "owned_by": provider_item['provider']
250
- }
251
- all_models.append(model_info)
252
- else:
253
- for provider_item in config["providers"]:
254
- if provider_item['provider'] != provider:
255
- continue
256
- for model_item in provider_item['model'].keys() :
257
- if model_item not in unique_models and model_item == model:
258
- unique_models.add(model_item)
259
- model_info = {
260
- "id": model_item,
261
- "object": "model",
262
- "created": 1720524448858,
263
- "owned_by": provider_item['provider']
264
- }
265
- all_models.append(model_info)
266
- continue
267
-
268
- if model not in unique_models:
269
- unique_models.add(model)
270
- model_info = {
271
- "id": model,
272
- "object": "model",
273
- "created": 1720524448858,
274
- "owned_by": model
275
- }
276
- all_models.append(model_info)
277
- else:
278
- for provider in config["providers"]:
279
- for model in provider['model'].keys():
280
- if model not in unique_models:
281
- unique_models.add(model)
282
- model_info = {
283
- "id": model,
284
- "object": "model",
285
- "created": 1720524448858,
286
- "owned_by": provider['provider']
287
- }
288
- all_models.append(model_info)
289
-
290
- return all_models
291
-
292
  @app.post("/v1/models")
293
  async def list_models(token: str = Depends(verify_api_key)):
294
  models = get_all_models(token)
 
1
  import json
2
+ import traceback
3
+
4
  import httpx
 
5
  import secrets
 
6
  from contextlib import asynccontextmanager
7
 
8
+ from fastapi import FastAPI, HTTPException, Depends
 
 
9
 
10
  from models import RequestModel
11
+ from utils import config, api_keys_db, api_list, error_handling_wrapper, get_all_models, verify_api_key
12
  from request import get_payload
13
  from response import fetch_response, fetch_response_stream
14
 
15
  from typing import List, Dict
16
  from urllib.parse import urlparse
17
+ from fastapi.responses import StreamingResponse
18
 
19
  @asynccontextmanager
20
  async def lifespan(app: FastAPI):
 
27
 
28
  app = FastAPI(lifespan=lifespan)
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  async def process_request(request: RequestModel, provider: Dict):
31
  print("provider: ", provider['provider'])
32
  url = provider['base_url']
 
140
 
141
  model_handler = ModelRequestHandler()
142
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  @app.post("/v1/chat/completions")
144
  async def request_model(request: RequestModel, token: str = Depends(verify_api_key)):
145
  return await model_handler.request_model(request, token)
146
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  @app.post("/v1/models")
148
  async def list_models(token: str = Depends(verify_api_key)):
149
  models = get_all_models(token)
utils.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import yaml
2
+ import json
3
+ from fastapi import HTTPException, Depends
4
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
5
+
6
+ # 读取YAML配置文件
7
+ def load_config():
8
+ try:
9
+ with open('./api.yaml', 'r') as f:
10
+ conf = yaml.safe_load(f)
11
+ for index, provider in enumerate(conf['providers']):
12
+ model_dict = {}
13
+ for model in provider['model']:
14
+ if type(model) == str:
15
+ model_dict[model] = model
16
+ if type(model) == dict:
17
+ model_dict.update({value: key for key, value in model.items()})
18
+ provider['model'] = model_dict
19
+ conf['providers'][index] = provider
20
+ api_keys_db = conf['api_keys']
21
+ api_list = [item["api"] for item in api_keys_db]
22
+ # print(json.dumps(conf, indent=4, ensure_ascii=False))
23
+ return conf, api_keys_db, api_list
24
+ except FileNotFoundError:
25
+ print("配置文件 'api.yaml' 未找到。请确保文件存在于正确的位置。")
26
+ return [], [], []
27
+ except yaml.YAMLError:
28
+ print("配置文件 'api.yaml' 格式不正确。请检查 YAML 格式。")
29
+ return [], [], []
30
+
31
+ config, api_keys_db, api_list = load_config()
32
+
33
+ async def error_handling_wrapper(generator, status_code=200):
34
+ try:
35
+ first_item = await generator.__anext__()
36
+ first_item_str = first_item
37
+ if isinstance(first_item_str, (bytes, bytearray)):
38
+ first_item_str = first_item_str.decode("utf-8")
39
+ if isinstance(first_item_str, str):
40
+ if first_item_str.startswith("data: "):
41
+ first_item_str = first_item_str[6:]
42
+ elif first_item_str.startswith("data:"):
43
+ first_item_str = first_item_str[5:]
44
+ first_item_str = json.loads(first_item_str)
45
+ if isinstance(first_item_str, dict) and 'error' in first_item_str:
46
+ print('\033[31m')
47
+ print(f"first_item_str: {first_item_str}")
48
+ print('\033[0m')
49
+ # 如果第一个 yield 的项是错误信息,抛出 HTTPException
50
+ raise HTTPException(status_code=status_code, detail=first_item_str)
51
+
52
+ # 如果不是错误,创建一个新的生成器,首先yield第一个项,然后yield剩余的项
53
+ async def new_generator():
54
+ yield first_item
55
+ async for item in generator:
56
+ yield item
57
+
58
+ return new_generator()
59
+ except StopAsyncIteration:
60
+ # 处理生成器为空的情况
61
+ return []
62
+
63
+ def get_all_models(token):
64
+ all_models = []
65
+ unique_models = set()
66
+
67
+ if token not in api_list:
68
+ raise HTTPException(status_code=403, detail="Invalid or missing API Key")
69
+ api_index = api_list.index(token)
70
+ if config['api_keys'][api_index]['model']:
71
+ for model in config['api_keys'][api_index]['model']:
72
+ if "/" in model:
73
+ provider = model.split("/")[0]
74
+ model = model.split("/")[1]
75
+ if model == "*":
76
+ for provider_item in config["providers"]:
77
+ if provider_item['provider'] != provider:
78
+ continue
79
+ for model_item in provider_item['model'].keys():
80
+ if model_item not in unique_models:
81
+ unique_models.add(model_item)
82
+ model_info = {
83
+ "id": model_item,
84
+ "object": "model",
85
+ "created": 1720524448858,
86
+ "owned_by": provider_item['provider']
87
+ }
88
+ all_models.append(model_info)
89
+ else:
90
+ for provider_item in config["providers"]:
91
+ if provider_item['provider'] != provider:
92
+ continue
93
+ for model_item in provider_item['model'].keys() :
94
+ if model_item not in unique_models and model_item == model:
95
+ unique_models.add(model_item)
96
+ model_info = {
97
+ "id": model_item,
98
+ "object": "model",
99
+ "created": 1720524448858,
100
+ "owned_by": provider_item['provider']
101
+ }
102
+ all_models.append(model_info)
103
+ continue
104
+
105
+ if model not in unique_models:
106
+ unique_models.add(model)
107
+ model_info = {
108
+ "id": model,
109
+ "object": "model",
110
+ "created": 1720524448858,
111
+ "owned_by": model
112
+ }
113
+ all_models.append(model_info)
114
+ else:
115
+ for provider in config["providers"]:
116
+ for model in provider['model'].keys():
117
+ if model not in unique_models:
118
+ unique_models.add(model)
119
+ model_info = {
120
+ "id": model,
121
+ "object": "model",
122
+ "created": 1720524448858,
123
+ "owned_by": provider['provider']
124
+ }
125
+ all_models.append(model_info)
126
+
127
+ return all_models
128
+
129
+ # 安全性依赖
130
+ security = HTTPBearer()
131
+
132
+ def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
133
+ token = credentials.credentials
134
+ if token not in api_list:
135
+ raise HTTPException(status_code=403, detail="Invalid or missing API Key")
136
+ return token