yym68686 committed on
Commit
0b0a2f7
·
0 Parent(s):

first commit

Browse files
Files changed (2) hide show
  1. .gitignore +1 -0
  2. main.py +95 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ api.json
main.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import httpx
3
+ import json
4
+ from fastapi import FastAPI, HTTPException, Request
5
+ from pydantic import BaseModel, Field
6
+ from typing import List, Dict, Any, Optional, Union
7
+
8
+ app = FastAPI()  # module-level application object; the route decorator below registers on it
9
+
10
# Read the JSON configuration file
def load_config():
    """Load the provider configuration from ``config.json``.

    The path is relative, so the file is looked up in the current working
    directory of the process.

    Returns:
        The parsed JSON value on success. An empty list when the file does
        not exist or does not contain valid JSON; a message is printed in
        both failure cases.
    """
    # NOTE(review): the .gitignore added alongside this file excludes
    # ``api.json``, not ``config.json``. If this file holds API keys,
    # confirm that it is ignored as well.
    try:
        # Decode explicitly as UTF-8 instead of relying on the platform
        # locale, so non-ASCII values load the same way everywhere.
        with open('config.json', 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        print("配置文件 'config.json' 未找到。请确保文件存在于正确的位置。")
        return []
    except json.JSONDecodeError:
        print("配置文件 'config.json' 格式不正确。请检查JSON格式。")
        return []
21
+
22
+ config = load_config()  # evaluated once at import; an empty list when loading failed
23
+
24
# One element of a list-valued message body. Both fields are required
# strings; the request handler only reads ``text`` from items whose
# ``type`` equals "text".
class ContentItem(BaseModel):
    type: str
    text: str
27
+
28
# A single chat message: a role string plus a body that is either a plain
# string or a list of ContentItem parts.
class Message(BaseModel):
    role: str
    content: Union[str, List[ContentItem]]
31
+
32
# Body accepted by the /request_model endpoint. ``model`` and ``messages``
# are required; the remaining options are optional with the defaults below.
class RequestModel(BaseModel):
    model: str
    messages: List[Message]
    logprobs: Optional[bool] = Field(default=False)
    top_logprobs: Optional[int] = Field(default=None)
    stream: Optional[bool] = Field(default=False)
    include_usage: Optional[bool] = Field(default=False)
39
+
40
async def fetch_response(client, url, headers, payload):
    """POST ``payload`` as JSON to ``url`` and return the decoded JSON reply.

    Args:
        client: Object exposing an awaitable ``post(url, headers=..., json=...)``
            whose result has a ``json()`` method.
        url: Target URL.
        headers: Headers passed through to ``client.post``.
        payload: Value passed as the ``json`` argument of ``client.post``.

    Returns:
        Whatever ``json()`` on the reply returns.
    """
    reply = await client.post(url, json=payload, headers=headers)
    decoded = reply.json()
    return decoded
43
+
44
def _flatten_messages(messages):
    """Convert request messages into plain ``{"role", "content"}`` dicts.

    List-valued content is collapsed to the space-joined ``text`` of its
    items whose ``type`` is ``"text"``; string content is passed through
    unchanged.
    """
    flattened = []
    for msg in messages:
        if isinstance(msg.content, list):
            content = " ".join(item.text for item in msg.content if item.type == "text")
        else:
            content = msg.content
        flattened.append({"role": msg.role, "content": content})
    return flattened


def _build_payload(request, provider, messages):
    """Build the JSON body sent to one provider for ``request``."""
    payload = {
        "model": request.model,
        "messages": messages,
        "stream": request.stream,
        "include_usage": request.include_usage,
    }

    if provider['provider'] == 'anthropic':
        payload["max_tokens"] = 1000  # You may want to make this configurable
    else:
        # The logprobs options are forwarded only when truthy, and never to
        # the 'anthropic' provider.
        if request.logprobs:
            payload["logprobs"] = request.logprobs
        if request.top_logprobs:
            payload["top_logprobs"] = request.top_logprobs

    return payload


@app.post("/request_model")
async def request_model(request: RequestModel):
    """Send the request to every configured provider that lists the model.

    Each entry of the module-level ``config`` whose ``'model'`` value
    contains ``request.model`` receives one POST to its ``'base_url'``,
    authorised with its ``'api'`` value as a Bearer token. All requests are
    awaited concurrently.

    Returns:
        A list with the decoded JSON reply of each matching provider, in
        configuration order.

    Raises:
        HTTPException: 404 when no provider lists the model; 500 when any
            upstream call raises.
    """
    model_name = request.model

    providers = [p for p in config if model_name in p['model']]
    if not providers:
        raise HTTPException(status_code=404, detail="No matching model found")

    # Convert message format once; the result is the same for every provider.
    messages = _flatten_messages(request.messages)

    async with httpx.AsyncClient() as client:
        tasks = []
        for provider in providers:
            headers = {
                'Authorization': f"Bearer {provider['api']}",
                'Content-Type': 'application/json'
            }
            payload = _build_payload(request, provider, messages)
            tasks.append(fetch_response(client, provider['base_url'], headers, payload))

        # Await while the client is still open: a closed AsyncClient refuses
        # to send, so gathering after the ``async with`` would always fail.
        try:
            return await asyncio.gather(*tasks)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error calling API: {str(e)}") from e
92
+
93
if __name__ == "__main__":
    # Start a uvicorn server for the app when this file is run as a script.
    import uvicorn

    bind_host, bind_port = "0.0.0.0", 8000
    uvicorn.run(app, host=bind_host, port=bind_port)