yym68686 committed
Commit c405f98 · 1 Parent(s): 1af48fa

Supported Claude

Files changed (4)
  1. json_str/claude/request.json +72 -0
  2. main.py +6 -6
  3. request.py +95 -1
  4. response.py +32 -38
json_str/claude/request.json ADDED
@@ -0,0 +1,72 @@
+{
+    "model": "claude-3-5-sonnet-20240620",
+    "messages": [
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "text",
+                    "text": "hi"
+                }
+            ]
+        }
+    ],
+    "temperature": 0.5,
+    "top_p": 0.7,
+    "max_tokens": 4096,
+    "stream": true,
+    "system": "You are Claude, a large language model trained by Anthropic. Use simple characters to represent mathematical symbols. Do not use LaTeX commands. Respond conversationally in English.",
+    "tools": [
+        {
+            "name": "get_search_results",
+            "description": "Search Google to enhance knowledge.",
+            "input_schema": {
+                "type": "object",
+                "properties": {
+                    "prompt": {
+                        "type": "string",
+                        "description": "The prompt to search."
+                    }
+                },
+                "required": [
+                    "prompt"
+                ]
+            }
+        },
+        {
+            "name": "get_url_content",
+            "description": "Get the webpage content of a URL",
+            "input_schema": {
+                "type": "object",
+                "properties": {
+                    "url": {
+                        "type": "string",
+                        "description": "the URL to request"
+                    }
+                },
+                "required": [
+                    "url"
+                ]
+            }
+        },
+        {
+            "name": "download_read_arxiv_pdf",
+            "description": "Get the content of the paper corresponding to the arXiv ID",
+            "input_schema": {
+                "type": "object",
+                "properties": {
+                    "prompt": {
+                        "type": "string",
+                        "description": "the arXiv ID of the paper"
+                    }
+                },
+                "required": [
+                    "prompt"
+                ]
+            }
+        }
+    ],
+    "tool_choice": {
+        "type": "auto"
+    }
+}
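The added file is a ready-made streaming request with three tool definitions. As a quick sanity check, here is a minimal sketch of posting it straight to Anthropic's Messages API with httpx; the endpoint and headers mirror the ones get_claude_payload builds in request.py below, and the ANTHROPIC_API_KEY environment variable is only an assumption for this illustration:

# Minimal sketch: POST json_str/claude/request.json to the Anthropic Messages API.
# Assumes an ANTHROPIC_API_KEY environment variable; endpoint and headers mirror
# the ones built in request.py's get_claude_payload.
import os
import json
import httpx

with open("json_str/claude/request.json") as f:
    payload = json.load(f)

headers = {
    "content-type": "application/json",
    "x-api-key": os.environ["ANTHROPIC_API_KEY"],
    "anthropic-version": "2023-06-01",
    "anthropic-beta": "tools-2024-05-16",
}

# The sample payload sets "stream": true, so read the response as an SSE stream.
with httpx.stream("POST", "https://api.anthropic.com/v1/messages",
                  headers=headers, json=payload, timeout=60) as response:
    for line in response.iter_lines():
        if line.startswith("data:"):
            print(line)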
main.py CHANGED
@@ -64,12 +64,12 @@ async def process_request(request: RequestModel, provider: Dict):
 
     url, headers, payload = await get_payload(request, engine, provider)
 
-    # request_info = {
-    #     "url": url,
-    #     "headers": headers,
-    #     "payload": payload
-    # }
-    # print(f"Request details: {json.dumps(request_info, indent=2, ensure_ascii=False)}")
+    request_info = {
+        "url": url,
+        "headers": headers,
+        "payload": payload
+    }
+    print(f"Request details: {json.dumps(request_info, indent=2, ensure_ascii=False)}")
 
     if request.stream:
         return StreamingResponse(fetch_response_stream(app.state.client, url, headers, payload, engine, request.model), media_type="text/event-stream")
request.py CHANGED
@@ -149,8 +149,102 @@ async def get_gpt_payload(request, engine, provider):
 
     return url, headers, payload
 
+async def gpt2claude_tools_json(json_dict):
+    import copy
+    json_dict = copy.deepcopy(json_dict)
+    keys_to_change = {
+        "parameters": "input_schema",
+    }
+    for old_key, new_key in keys_to_change.items():
+        if old_key in json_dict:
+            if new_key:
+                json_dict[new_key] = json_dict.pop(old_key)
+            else:
+                json_dict.pop(old_key)
+    # if "tools" in json_dict.keys():
+    #     json_dict["tool_choice"] = {
+    #         "type": "auto"
+    #     }
+    return json_dict
+
 async def get_claude_payload(request, engine, provider):
-    pass
+    headers = {
+        "content-type": "application/json",
+        "x-api-key": f"{provider['api']}",
+        "anthropic-version": "2023-06-01",
+        "anthropic-beta": "tools-2024-05-16"
+    }
+    url = provider['base_url']
+
+    messages = []
+    for msg in request.messages:
+        if isinstance(msg.content, list):
+            content = []
+            for item in msg.content:
+                if item.type == "text":
+                    text_message = await get_text_message(msg.role, item.text, engine)
+                    content.append(text_message)
+                elif item.type == "image_url":
+                    image_message = await get_image_message(item.image_url.url, engine)
+                    content.append(image_message)
+        else:
+            content = msg.content
+        name = msg.name
+        if name:
+            messages.append({"role": msg.role, "name": name, "content": content})
+        elif msg.role != "system":
+            messages.append({"role": msg.role, "content": content})
+        elif msg.role == "system":
+            system_prompt = content
+
+    payload = {
+        "model": request.model,
+        "messages": messages,
+        "system": system_prompt,
+    }
+    # json_post = {
+    #     "model": model or self.engine,
+    #     "messages": self.conversation[convo_id] if pass_history else [{
+    #         "role": "user",
+    #         "content": prompt
+    #     }],
+    #     "temperature": kwargs.get("temperature", self.temperature),
+    #     "top_p": kwargs.get("top_p", self.top_p),
+    #     "max_tokens": model_max_tokens,
+    #     "stream": True,
+    # }
+
+    miss_fields = [
+        'model',
+        'messages',
+        'presence_penalty',
+        'frequency_penalty',
+        'n',
+        'user',
+        'include_usage',
+    ]
+
+    for field, value in request.model_dump(exclude_unset=True).items():
+        if field not in miss_fields and value is not None:
+            payload[field] = value
+
+    tools = []
+    for tool in request.tools:
+        print("tool", type(tool), tool)
+
+        json_tool = await gpt2claude_tools_json(tool.dict()["function"])
+        tools.append(json_tool)
+    payload["tools"] = tools
+    # del payload["type"]
+    # del payload["function"]
+    if "tool_choice" in payload:
+        payload["tool_choice"] = {
+            "type": "auto"
+        }
+    import json
+    print("payload", json.dumps(payload, indent=2, ensure_ascii=False))
+
+    return url, headers, payload
 
 async def get_payload(request: RequestModel, engine, provider):
     if engine == "gemini":
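The tool-schema conversion above boils down to renaming OpenAI's "parameters" key to Claude's "input_schema". Below is a standalone sketch of that rename, using a hypothetical OpenAI-style function block of the same shape that tool.dict()["function"] yields:

# Standalone sketch of the "parameters" -> "input_schema" rename performed by
# gpt2claude_tools_json. The tool definition below is a hypothetical
# OpenAI-style "function" block, like the ones request.tools carries.
import asyncio
import copy

async def gpt2claude_tools_json(json_dict):
    json_dict = copy.deepcopy(json_dict)
    keys_to_change = {"parameters": "input_schema"}
    for old_key, new_key in keys_to_change.items():
        if old_key in json_dict:
            if new_key:
                json_dict[new_key] = json_dict.pop(old_key)
            else:
                json_dict.pop(old_key)
    return json_dict

openai_function = {
    "name": "get_search_results",
    "description": "Search Google to enhance knowledge.",
    "parameters": {
        "type": "object",
        "properties": {"prompt": {"type": "string"}},
        "required": ["prompt"],
    },
}

claude_tool = asyncio.run(gpt2claude_tools_json(openai_function))
print(claude_tool["input_schema"])   # the former "parameters" block
assert "parameters" not in claude_tool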
response.py CHANGED
@@ -72,47 +72,42 @@ async def fetch_gpt_response_stream(client, url, headers, payload):
     except httpx.ConnectError as e:
         print(f"Connection error: {e}")
 
-async def fetch_claude_response_stream(client, url, headers, payload, engine, model):
+async def fetch_claude_response_stream(client, url, headers, payload, model):
     try:
         timestamp = datetime.timestamp(datetime.now())
         async with client.stream('POST', url, headers=headers, json=payload) as response:
-            buffer = ""
-            async for chunk in response.aiter_text():
-                buffer += chunk
-                while "\n" in buffer:
-                    line, buffer = buffer.split("\n", 1)
-                    # print(line)
-                    if engine == "gemini":
-                        if line and '"text": "' in line:
-                            try:
-                                json_data = json.loads("{" + line + "}")
-                                content = json_data.get('text', '')
-                                content = "\n".join(content.split("\\n"))
-                                sse_string = await generate_sse_response(timestamp, model, content)
-                                yield sse_string
-                            except json.JSONDecodeError:
-                                print(f"Failed to parse JSON: {line}")
-                    else:
-                        yield line + "\n"
-
-            # Handle whatever is left in the buffer
-            if buffer:
-                # print(buffer)
-                if engine == "gemini":
-                    if '"text": "' in buffer:
-                        try:
-                            json_data = json.loads(buffer)
-                            content = json_data.get('text', '')
-                            content = "\n".join(content.split("\\n"))
-                            sse_string = await generate_sse_response(timestamp, model, content)
-                            yield sse_string
-                        except json.JSONDecodeError:
-                            print(f"Failed to parse JSON: {buffer}")
-                else:
-                    yield buffer
-
-            if engine == "gemini":
-                yield "data: [DONE]\n\n"
+            async for chunk in response.aiter_bytes():
+                chunk_line = chunk.decode('utf-8').split("\n")
+                for chunk in chunk_line:
+                    if chunk.startswith("data:"):
+                        line = chunk[6:]
+                        # print(line)
+                        resp: dict = json.loads(line)
+                        message = resp.get("message")
+                        if message:
+                            tokens_use = resp.get("usage")
+                            if tokens_use:
+                                total_tokens = tokens_use["input_tokens"] + tokens_use["output_tokens"]
+                                # print("\n\rtotal_tokens", total_tokens)
+                        # tool_use = resp.get("content_block")
+                        # if tool_use and "tool_use" == tool_use['type']:
+                        #     # print("tool_use", tool_use)
+                        #     tools_id = tool_use["id"]
+                        #     need_function_call = True
+                        #     if "name" in tool_use:
+                        #         function_call_name = tool_use["name"]
+                        delta = resp.get("delta")
+                        # print("delta", delta)
+                        if not delta:
+                            continue
+                        if "text" in delta:
+                            content = delta["text"]
+                            sse_string = await generate_sse_response(timestamp, model, content)
+                            print(sse_string)
+                            yield sse_string
+                        # if "partial_json" in delta:
+                        #     function_call_content = delta["partial_json"]
+            yield "data: [DONE]\n\n"
     except httpx.ConnectError as e:
         print(f"Connection error: {e}")
 
@@ -121,12 +116,11 @@ async def fetch_response(client, url, headers, payload):
         return response.json()
 
 async def fetch_response_stream(client, url, headers, payload, engine, model):
-    print(f"Engine: {engine}")
     if engine == "gemini":
        async for chunk in fetch_gemini_response_stream(client, url, headers, payload, model):
            yield chunk
    elif engine == "claude":
-        async for chunk in fetch_claude_response_stream(client, url, headers, payload, engine, model):
+        async for chunk in fetch_claude_response_stream(client, url, headers, payload, model):
            yield chunk
    elif engine == "gpt":
        async for chunk in fetch_gpt_response_stream(client, url, headers, payload):
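
The rewritten parser assumes Anthropic-style streaming events arrive as "data:" lines. Here is a small sketch of the same branch logic applied to one representative content_block_delta event (the sample line is illustrative, not captured from a live response):

# Sketch of the branch logic in fetch_claude_response_stream applied to one
# representative Anthropic streaming event. The sample line below is an
# illustrative content_block_delta event, not a captured response.
import json

sample = 'data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello"}}'

if sample.startswith("data:"):
    resp: dict = json.loads(sample[6:])   # same slicing as chunk[6:] above
    delta = resp.get("delta")
    if delta and "text" in delta:
        print(delta["text"])   # -> "Hello", the text that gets wrapped into an SSE chunk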