Fixed the bug where the Claude role could not be obtained and the SSE format was incorrect.
- json_str/gpt/mess_sse.json +12 -0
- main.py +38 -11
- request.py +6 -4
- response.py +7 -1
json_str/gpt/mess_sse.json
ADDED
@@ -0,0 +1,12 @@
+data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}],"usage":null}
+data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}],"usage":null}
+data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[{"index":0,"delta":{"content":"!"},"logprobs":null,"finish_reason":null}],"usage":null}
+data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[{"index":0,"delta":{"content":" How"},"logprobs":null,"finish_reason":null}],"usage":null}
+data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[{"index":0,"delta":{"content":" can"},"logprobs":null,"finish_reason":null}],"usage":null}
+data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[{"index":0,"delta":{"content":" I"},"logprobs":null,"finish_reason":null}],"usage":null}
+data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[{"index":0,"delta":{"content":" assist"},"logprobs":null,"finish_reason":null}],"usage":null}
+data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}],"usage":null}
+data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[{"index":0,"delta":{"content":" today"},"logprobs":null,"finish_reason":null}],"usage":null}
+data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[{"index":0,"delta":{"content":"?"},"logprobs":null,"finish_reason":null}],"usage":null}
+data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":null}
+data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[],"usage":{"prompt_tokens":178,"completion_tokens":10,"total_tokens":188}}
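For reference, the fixture above records the OpenAI chat.completion.chunk streaming shape that this commit now mirrors for Claude: the first chunk carries "role": "assistant" with empty content, subsequent chunks carry content deltas, a finish_reason "stop" chunk closes the choice, and a final chunk with empty choices reports usage. A minimal sketch (a hypothetical helper, not part of this commit) that replays the fixture and reassembles the message:

import json

def replay_sse(path="json_str/gpt/mess_sse.json"):
    # Accumulate the role and the content deltas from each "data:" line.
    role, parts = None, []
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line.startswith("data: "):
                continue
            chunk = json.loads(line[len("data: "):])
            for choice in chunk.get("choices", []):
                delta = choice.get("delta", {})
                role = delta.get("role", role)  # only the first chunk sets the role
                if delta.get("content"):
                    parts.append(delta["content"])
    return role, "".join(parts)

print(replay_sse())  # ('assistant', 'Hello! How can I assist you today?')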
main.py
CHANGED
@@ -1,11 +1,12 @@
 import os
 import json
 import httpx
+import logging
 import yaml
 import traceback
 from contextlib import asynccontextmanager
 
-from fastapi import FastAPI, HTTPException, Depends
+from fastapi import FastAPI, Request, HTTPException, Depends
 from fastapi.responses import StreamingResponse
 from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
 
@@ -48,7 +49,16 @@ def load_config():
         return []
 
 config = load_config()
-
+for index, provider in enumerate(config):
+    model_dict = {}
+    for model in provider['model']:
+        if type(model) == str:
+            model_dict[model] = model
+        if type(model) == dict:
+            model_dict.update({value: key for key, value in model.items()})
+    provider['model'] = model_dict
+    config[index] = provider
+# print(json.dumps(config, indent=4, ensure_ascii=False))
 
 async def process_request(request: RequestModel, provider: Dict):
     print("provider: ", provider['provider'])
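The loop above normalizes each provider's model list into a dict keyed by the client-facing model name: string entries map to themselves, and dict entries of the form {upstream_name: alias} are inverted so the alias resolves to the upstream name. A small sketch under an assumed config shape (the actual config file layout is not shown in this diff):

provider = {
    'provider': 'anthropic',
    'model': [
        'claude-3-5-sonnet-20240620',                         # string: exposed name == upstream name
        {'claude-3-5-sonnet-20240620': 'claude-3-5-sonnet'},  # dict: upstream name -> exposed alias
    ],
}

model_dict = {}
for model in provider['model']:
    if type(model) == str:
        model_dict[model] = model
    if type(model) == dict:
        model_dict.update({value: key for key, value in model.items()})

print(model_dict)
# {'claude-3-5-sonnet-20240620': 'claude-3-5-sonnet-20240620',
#  'claude-3-5-sonnet': 'claude-3-5-sonnet-20240620'}

After this pass, provider['model'][request.model] (used below when streaming and when building payloads) resolves a requested alias to the upstream model name, and get_matching_providers only needs to test membership in provider['model'].keys().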
@@ -64,15 +74,16 @@ async def process_request(request: RequestModel, provider: Dict):
 
     url, headers, payload = await get_payload(request, engine, provider)
 
-    request_info = {
-        "url": url,
-        "headers": headers,
-        "payload": payload
-    }
-    print(f"Request details: {json.dumps(request_info, indent=4, ensure_ascii=False)}")
+    # request_info = {
+    #     "url": url,
+    #     "headers": headers,
+    #     "payload": payload
+    # }
+    # print(f"Request details: {json.dumps(request_info, indent=4, ensure_ascii=False)}")
 
     if request.stream:
-        return StreamingResponse(fetch_response_stream(app.state.client, url, headers, payload, engine), media_type="text/event-stream")
+        model = provider['model'][request.model]
+        return StreamingResponse(fetch_response_stream(app.state.client, url, headers, payload, engine, model), media_type="text/event-stream")
     else:
         return await fetch_response(app.state.client, url, headers, payload)
 
@@ -81,7 +92,11 @@ class ModelRequestHandler:
         self.last_provider_index = -1
 
     def get_matching_providers(self, model_name):
-        return [provider for provider in config if model_name in provider['model']]
+        # for provider in config:
+        #     print("provider", model_name, list(provider['model'].keys()))
+        #     if model_name in provider['model'].keys():
+        #         print("provider", provider)
+        return [provider for provider in config if model_name in provider['model'].keys()]
 
     async def request_model(self, request: RequestModel, token: str):
         model_name = request.model
@@ -122,6 +137,18 @@ class ModelRequestHandler:
 
 model_handler = ModelRequestHandler()
 
+@app.middleware("http")
+async def log_requests(request: Request, call_next):
+    # Log the request line
+    logging.info(f"Request: {request.method} {request.url}")
+    # Log the request body (if any)
+    if request.method in ["POST", "PUT", "PATCH"]:
+        body = await request.body()
+        logging.info(f"Request Body: {body.decode('utf-8')}")
+
+    response = await call_next(request)
+    return response
+
 def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
     token = credentials.credentials
     if token not in api_keys_db:
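One caveat worth flagging on the new log_requests middleware (an observation, not something this diff addresses): it drains the request stream with await request.body() before call_next. Recent Starlette releases cache the body on the Request, but on older versions reading the body inside HTTP middleware could leave the downstream endpoint waiting on an already-consumed stream, so gating this logging behind a debug flag may be safer.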
@@ -137,7 +164,7 @@ def get_all_models():
     unique_models = set()
 
     for provider in config:
-        for model in provider['model']:
+        for model in provider['model'].keys():
             if model not in unique_models:
                 unique_models.add(model)
                 model_info = {
request.py
CHANGED
@@ -38,9 +38,10 @@ async def get_gemini_payload(request, engine, provider):
         'Content-Type': 'application/json'
     }
     url = provider['base_url']
+    model = provider['model'][request.model]
     if request.stream:
         gemini_stream = "streamGenerateContent"
-        url = url.format(model=request.model, stream=gemini_stream, api_key=provider['api'])
+        url = url.format(model=model, stream=gemini_stream, api_key=provider['api'])
 
     messages = []
     for msg in request.messages:
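The change above implies that a Gemini provider's base_url is a format template with {model}, {stream}, and {api_key} placeholders, now filled with the upstream model name instead of the client-facing one. An illustrative template (assumed here; the real value lives in the provider config, which this diff does not show), modeled on Google's Generative Language endpoint:

base_url = "https://generativelanguage.googleapis.com/v1beta/models/{model}:{stream}?key={api_key}"

url = base_url.format(
    model="gemini-1.5-pro",          # provider['model'][request.model], the upstream name
    stream="streamGenerateContent",  # substituted when request.stream is true
    api_key="YOUR_API_KEY",          # provider['api']
)
print(url)
# https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-pro:streamGenerateContent?key=YOUR_API_KEY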
@@ -112,7 +113,6 @@ async def get_gpt_payload(request, engine, provider):
         'Content-Type': 'application/json'
     }
     url = provider['base_url']
-    url = url.format(model=request.model, stream=request.stream, api_key=provider['api'])
 
     messages = []
     for msg in request.messages:
@@ -133,8 +133,9 @@ async def get_gpt_payload(request, engine, provider):
         else:
             messages.append({"role": msg.role, "content": content})
 
+    model = provider['model'][request.model]
     payload = {
-        "model": request.model,
+        "model": model,
         "messages": messages,
     }
 
@@ -222,8 +223,9 @@ async def get_claude_payload(request, engine, provider):
         elif msg.role == "system":
             system_prompt = content
 
+    model = provider['model'][request.model]
     payload = {
-        "model": request.model,
+        "model": model,
         "messages": messages,
         "system": system_prompt,
     }
response.py
CHANGED
@@ -2,7 +2,7 @@ from datetime import datetime
 import json
 import httpx
 
-async def generate_sse_response(timestamp, model, content=None, tools_id=None, function_call_name=None, function_call_content=None):
+async def generate_sse_response(timestamp, model, content=None, tools_id=None, function_call_name=None, function_call_content=None, role=None, tokens_use=None, total_tokens=None):
     sample_data = {
         "id": "chatcmpl-9ijPeRHa0wtyA2G8wq5z8FC3wGMzc",
         "object": "chat.completion.chunk",
@@ -24,6 +24,8 @@ async def generate_sse_response(timestamp, model, content=None, tools_id=None, f
     if tools_id and function_call_name:
         sample_data["choices"][0]["delta"] = {"tool_calls":[{"index":0,"id":tools_id,"type":"function","function":{"name":function_call_name,"arguments":""}}]}
         # sample_data["choices"][0]["delta"] = {"tool_calls":[{"index":0,"function":{"id": tools_id, "name": function_call_name}}]}
+    if role:
+        sample_data["choices"][0]["delta"] = {"role": role, "content": ""}
     json_data = json.dumps(sample_data, ensure_ascii=False)
 
     # Build the SSE response
@@ -91,6 +93,10 @@ async def fetch_claude_response_stream(client, url, headers, payload, model):
             message = resp.get("message")
             if message:
                 tokens_use = resp.get("usage")
+                role = message.get("role")
+                if role:
+                    sse_string = await generate_sse_response(timestamp, model, None, None, None, None, role)
+                    yield sse_string
                 if tokens_use:
                     total_tokens = tokens_use["input_tokens"] + tokens_use["output_tokens"]
                     # print("\n\rtotal_tokens", total_tokens)
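This hunk is the heart of the role fix named in the commit message: Claude's streaming API announces the assistant role in its opening message_start event rather than in a content delta, so the proxy now forwards it as an OpenAI-style first chunk (the kind recorded as line 1 of mess_sse.json). A sketch of the event shape being read, following Anthropic's documented streaming format (field values illustrative):

resp = {
    "type": "message_start",
    "message": {
        "id": "msg_0123456789",  # illustrative id
        "role": "assistant",
        "usage": {"input_tokens": 25, "output_tokens": 1},
    },
}

message = resp.get("message")
role = message.get("role") if message else None
print(role)  # 'assistant' -> emitted via generate_sse_response(..., role) as the first SSE chunk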