yym68686 committed on
Commit
a064fa5
·
1 Parent(s): d780861

Support images, plugins, search Q&A

Browse files
Files changed (1) hide show
  1. main.py +55 -24
main.py CHANGED
@@ -51,12 +51,34 @@ def load_config():
51
  config = load_config()
52
  # print(config)
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  class ContentItem(BaseModel):
55
  type: str
56
- text: str
 
57
 
58
  class Message(BaseModel):
59
  role: str
 
60
  content: Union[str, List[ContentItem]]
61
 
62
  class RequestModel(BaseModel):
@@ -66,10 +88,20 @@ class RequestModel(BaseModel):
66
  top_logprobs: Optional[int] = None
67
  stream: Optional[bool] = None
68
  include_usage: Optional[bool] = None
 
 
 
 
 
 
 
 
 
69
 
70
  async def fetch_response_stream(client, url, headers, payload):
71
  async with client.stream('POST', url, headers=headers, json=payload) as response:
72
  async for chunk in response.aiter_bytes():
 
73
  yield chunk
74
 
75
  async def fetch_response(client, url, headers, payload):
@@ -88,37 +120,36 @@ async def process_request(request: RequestModel, provider: Dict):
88
  messages = []
89
  for msg in request.messages:
90
  if isinstance(msg.content, list):
91
- content = " ".join([item.text for item in msg.content if item.type == "text"])
 
 
 
 
 
92
  else:
93
  content = msg.content
94
- messages.append({"role": msg.role, "content": content})
 
 
 
 
 
95
 
96
  payload = {
97
  "model": request.model,
98
  "messages": messages
99
  }
100
 
101
- # 只有当相应参数存在且不为None时,才添加到payload中
102
- # print("request: ", request)
103
- if request.stream is not None:
104
- payload["stream"] = request.stream
105
- if request.include_usage is not None:
106
- payload["include_usage"] = request.include_usage
107
 
108
- if provider['provider'] == 'anthropic':
109
- payload["max_tokens"] = 1000 # 您可能想让这个可配置
110
- else:
111
- if request.logprobs is not None:
112
- payload["logprobs"] = request.logprobs
113
- if request.top_logprobs is not None:
114
- payload["top_logprobs"] = request.top_logprobs
115
-
116
- # request_info = {
117
- # "url": url,
118
- # "headers": headers,
119
- # "payload": payload
120
- # }
121
- # print(f"Request details: {json.dumps(request_info, indent=2, ensure_ascii=False)}")
122
  if request.stream:
123
  return StreamingResponse(fetch_response_stream(app.state.client, url, headers, payload), media_type="text/event-stream")
124
  else:
@@ -134,7 +165,7 @@ class ModelRequestHandler:
134
  async def request_model(self, request: RequestModel, token: str):
135
  model_name = request.model
136
  matching_providers = self.get_matching_providers(model_name)
137
- print("matching_providers", json.dumps(matching_providers, indent=2, ensure_ascii=False))
138
 
139
  if not matching_providers:
140
  raise HTTPException(status_code=404, detail="No matching model found")
 
51
  config = load_config()
52
  # print(config)
53
 
54
+ # 定义 Function 参数模型
55
+ class FunctionParameter(BaseModel):
56
+ type: str
57
+ properties: Dict[str, Dict[str, str]]
58
+ required: List[str]
59
+
60
+ # 定义 Function 模型
61
+ class Function(BaseModel):
62
+ name: str
63
+ description: str
64
+ parameters: FunctionParameter
65
+
66
+ # 定义 Tool 模型
67
+ class Tool(BaseModel):
68
+ type: str
69
+ function: Function
70
+
71
+ class ImageUrl(BaseModel):
72
+ url: str
73
+
74
  class ContentItem(BaseModel):
75
  type: str
76
+ text: Optional[str] = None
77
+ image_url: Optional[ImageUrl] = None
78
 
79
  class Message(BaseModel):
80
  role: str
81
+ name: Optional[str] = None
82
  content: Union[str, List[ContentItem]]
83
 
84
  class RequestModel(BaseModel):
 
88
  top_logprobs: Optional[int] = None
89
  stream: Optional[bool] = None
90
  include_usage: Optional[bool] = None
91
+ temperature: Optional[float] = 0.5
92
+ top_p: Optional[float] = 1.0
93
+ max_tokens: Optional[int] = None
94
+ presence_penalty: Optional[float] = 0.0
95
+ frequency_penalty: Optional[float] = 0.0
96
+ n: Optional[int] = 1
97
+ user: Optional[str] = None
98
+ tool_choice: Optional[str] = None
99
+ tools: Optional[List[Tool]] = None
100
 
101
  async def fetch_response_stream(client, url, headers, payload):
102
  async with client.stream('POST', url, headers=headers, json=payload) as response:
103
  async for chunk in response.aiter_bytes():
104
+ print(chunk.decode('utf-8'))
105
  yield chunk
106
 
107
  async def fetch_response(client, url, headers, payload):
 
120
  messages = []
121
  for msg in request.messages:
122
  if isinstance(msg.content, list):
123
+ content = []
124
+ for item in msg.content:
125
+ if item.type == "text":
126
+ content.append({"type": "text", "text": item.text})
127
+ elif item.type == "image_url":
128
+ content.append({"type": "image_url", "image_url": item.image_url.dict()})
129
  else:
130
  content = msg.content
131
+ name = msg.name
132
+ if name:
133
+ messages.append({"role": msg.role, "name": name, "content": content})
134
+ else:
135
+ messages.append({"role": msg.role, "content": content})
136
+
137
 
138
  payload = {
139
  "model": request.model,
140
  "messages": messages
141
  }
142
 
143
+ for field, value in request.dict(exclude_unset=True).items():
144
+ if field not in ['model', 'messages'] and value is not None:
145
+ payload[field] = value
 
 
 
146
 
147
+ request_info = {
148
+ "url": url,
149
+ "headers": headers,
150
+ "payload": payload
151
+ }
152
+ print(f"Request details: {json.dumps(request_info, indent=2, ensure_ascii=False)}")
 
 
 
 
 
 
 
 
153
  if request.stream:
154
  return StreamingResponse(fetch_response_stream(app.state.client, url, headers, payload), media_type="text/event-stream")
155
  else:
 
165
  async def request_model(self, request: RequestModel, token: str):
166
  model_name = request.model
167
  matching_providers = self.get_matching_providers(model_name)
168
+ # print("matching_providers", json.dumps(matching_providers, indent=2, ensure_ascii=False))
169
 
170
  if not matching_providers:
171
  raise HTTPException(status_code=404, detail="No matching model found")