yym68686 commited on
Commit
1ddc959
·
1 Parent(s): 0b0a2f7

update api

Browse files
Files changed (3) hide show
  1. .gitignore +2 -1
  2. README.md +1 -0
  3. main.py +99 -56
.gitignore CHANGED
@@ -1 +1,2 @@
1
- api.json
 
 
1
+ api.json
2
+ api.yaml
README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ # uni-api
main.py CHANGED
@@ -1,25 +1,53 @@
1
- import asyncio
2
  import httpx
3
- import json
4
- from fastapi import FastAPI, HTTPException, Request
5
- from pydantic import BaseModel, Field
 
 
 
 
 
6
  from typing import List, Dict, Any, Optional, Union
7
 
8
- app = FastAPI()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- # 读取JSON配置文件
11
  def load_config():
12
  try:
13
- with open('config.json', 'r') as f:
14
- return json.load(f)
15
  except FileNotFoundError:
16
- print("配置文件 'config.json' 未找到。请确保文件存在于正确的位置。")
17
  return []
18
- except json.JSONDecodeError:
19
- print("配置文件 'config.json' 格式不正确。请检查JSON格式。")
20
  return []
21
 
22
  config = load_config()
 
23
 
24
  class ContentItem(BaseModel):
25
  type: str
@@ -37,59 +65,74 @@ class RequestModel(BaseModel):
37
  stream: Optional[bool] = False
38
  include_usage: Optional[bool] = False
39
 
 
 
 
 
 
40
  async def fetch_response(client, url, headers, payload):
 
 
 
 
 
 
 
41
  response = await client.post(url, headers=headers, json=payload)
 
42
  return response.json()
43
 
44
- @app.post("/request_model")
45
- async def request_model(request: RequestModel):
46
  model_name = request.model
47
 
48
- tasks = []
49
- async with httpx.AsyncClient() as client:
50
- for provider in config:
51
- if model_name in provider['model']:
52
- url = provider['base_url']
53
- headers = {
54
- 'Authorization': f"Bearer {provider['api']}",
55
- 'Content-Type': 'application/json'
56
- }
57
-
58
- # 转换消息格式
59
- messages = []
60
- for msg in request.messages:
61
- if isinstance(msg.content, list):
62
- content = " ".join([item.text for item in msg.content if item.type == "text"])
63
- else:
64
- content = msg.content
65
- messages.append({"role": msg.role, "content": content})
66
-
67
- payload = {
68
- "model": model_name,
69
- "messages": messages,
70
- "stream": request.stream,
71
- "include_usage": request.include_usage
72
- }
73
-
74
- if provider['provider'] == 'anthropic':
75
- payload["max_tokens"] = 1000 # 您可能想让这个可配置
76
  else:
77
- if request.logprobs:
78
- payload["logprobs"] = request.logprobs
79
- if request.top_logprobs:
80
- payload["top_logprobs"] = request.top_logprobs
81
-
82
- tasks.append(fetch_response(client, url, headers, payload))
83
-
84
- if not tasks:
85
- raise HTTPException(status_code=404, detail="No matching model found")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
- try:
88
- responses = await asyncio.gather(*tasks)
89
- return responses
90
- except Exception as e:
91
- raise HTTPException(status_code=500, detail=f"Error calling API: {str(e)}")
92
 
93
  if __name__ == '__main__':
94
  import uvicorn
95
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
 
1
  import httpx
2
+ import yaml
3
+ from contextlib import asynccontextmanager
4
+
5
+ from fastapi import FastAPI, HTTPException, Depends
6
+ from fastapi.responses import StreamingResponse
7
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
8
+
9
+ from pydantic import BaseModel
10
  from typing import List, Dict, Any, Optional, Union
11
 
12
+ # 模拟存储API Key的数据库
13
+ api_keys_db = {
14
+ "sk-KjjI60Yf0JFcsvgRmXqFwgGmWUd9GZnmi3KlvowmRWpWpQRo": "user1",
15
+ # 可以添加更多的API Key
16
+ }
17
+
18
+ # 安全性依赖
19
+ security = HTTPBearer()
20
+
21
+ def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
22
+ token = credentials.credentials
23
+ if token not in api_keys_db:
24
+ raise HTTPException(status_code=403, detail="Invalid or missing API Key")
25
+ return token
26
+
27
+ @asynccontextmanager
28
+ async def lifespan(app: FastAPI):
29
+ # 启动时的代码
30
+ app.state.client = httpx.AsyncClient()
31
+ yield
32
+ # 关闭时的代码
33
+ await app.state.client.aclose()
34
+
35
+ app = FastAPI(lifespan=lifespan)
36
 
37
+ # 读取YAML配置文件
38
  def load_config():
39
  try:
40
+ with open('api.yaml', 'r') as f:
41
+ return yaml.safe_load(f)
42
  except FileNotFoundError:
43
+ print("配置文件 'config.yaml' 未找到。请确保文件存在于正确的位置。")
44
  return []
45
+ except yaml.YAMLError:
46
+ print("配置文件 'config.yaml' 格式不正确。请检查YAML格式。")
47
  return []
48
 
49
  config = load_config()
50
+ # print(config)
51
 
52
  class ContentItem(BaseModel):
53
  type: str
 
65
  stream: Optional[bool] = False
66
  include_usage: Optional[bool] = False
67
 
68
+ async def fetch_response_stream(client, url, headers, payload):
69
+ async with client.stream('POST', url, headers=headers, json=payload) as response:
70
+ async for chunk in response.aiter_bytes():
71
+ yield chunk
72
+
73
  async def fetch_response(client, url, headers, payload):
74
+ # request_info = {
75
+ # "url": url,
76
+ # "headers": headers,
77
+ # "payload": payload
78
+ # }
79
+ # print(f"Request details: {json.dumps(request_info, indent=2, ensure_ascii=False)}")
80
+
81
  response = await client.post(url, headers=headers, json=payload)
82
+ # print(response.text)
83
  return response.json()
84
 
85
+ @app.post("/v1/chat/completions")
86
+ async def request_model(request: RequestModel, token: str = Depends(verify_api_key)):
87
  model_name = request.model
88
 
89
+ for provider in config:
90
+ if model_name in provider['model']:
91
+ print("provider: ", provider['provider'])
92
+ url = provider['base_url']
93
+ headers = {
94
+ 'Authorization': f"Bearer {provider['api']}",
95
+ 'Content-Type': 'application/json'
96
+ }
97
+
98
+ # 转换消息格式
99
+ messages = []
100
+ for msg in request.messages:
101
+ if isinstance(msg.content, list):
102
+ content = " ".join([item.text for item in msg.content if item.type == "text"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  else:
104
+ content = msg.content
105
+ messages.append({"role": msg.role, "content": content})
106
+
107
+ payload = {
108
+ "model": model_name,
109
+ "messages": messages
110
+ }
111
+
112
+ # 只有当相应参数存在且不为None时,才添加到payload中
113
+ if request.stream is not None:
114
+ payload["stream"] = request.stream
115
+ if request.include_usage is not None:
116
+ payload["include_usage"] = request.include_usage
117
+
118
+ if provider['provider'] == 'anthropic':
119
+ payload["max_tokens"] = 1000 # 您可能想让这个可配置
120
+ else:
121
+ if request.logprobs is not None:
122
+ payload["logprobs"] = request.logprobs
123
+ if request.top_logprobs is not None:
124
+ payload["top_logprobs"] = request.top_logprobs
125
+
126
+ try:
127
+ if request.stream:
128
+ return StreamingResponse(fetch_response_stream(app.state.client, url, headers, payload), media_type="text/event-stream")
129
+ else:
130
+ return await fetch_response(app.state.client, url, headers, payload)
131
+ except Exception as e:
132
+ raise HTTPException(status_code=500, detail=f"Error calling API: {str(e)}")
133
 
134
+ raise HTTPException(status_code=404, detail="No matching model found")
 
 
 
 
135
 
136
  if __name__ == '__main__':
137
  import uvicorn
138
+ uvicorn.run("__main__:app", host="0.0.0.0", port=8000, reload=True)