yym68686 committed on
Commit
52bcfe4
·
1 Parent(s): 8866240

Fixed the bug where the Claude role could not be obtained and the SSE format was incorrect.

Browse files
Files changed (4) hide show
  1. json_str/gpt/mess_sse.json +12 -0
  2. main.py +38 -11
  3. request.py +6 -4
  4. response.py +7 -1
json_str/gpt/mess_sse.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}],"usage":null}
2
+ data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}],"usage":null}
3
+ data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[{"index":0,"delta":{"content":"!"},"logprobs":null,"finish_reason":null}],"usage":null}
4
+ data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[{"index":0,"delta":{"content":" How"},"logprobs":null,"finish_reason":null}],"usage":null}
5
+ data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[{"index":0,"delta":{"content":" can"},"logprobs":null,"finish_reason":null}],"usage":null}
6
+ data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[{"index":0,"delta":{"content":" I"},"logprobs":null,"finish_reason":null}],"usage":null}
7
+ data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[{"index":0,"delta":{"content":" assist"},"logprobs":null,"finish_reason":null}],"usage":null}
8
+ data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}],"usage":null}
9
+ data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[{"index":0,"delta":{"content":" today"},"logprobs":null,"finish_reason":null}],"usage":null}
10
+ data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[{"index":0,"delta":{"content":"?"},"logprobs":null,"finish_reason":null}],"usage":null}
11
+ data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":null}
12
+ data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[],"usage":{"prompt_tokens":178,"completion_tokens":10,"total_tokens":188}}
main.py CHANGED
@@ -1,11 +1,12 @@
1
  import os
2
  import json
3
  import httpx
 
4
  import yaml
5
  import traceback
6
  from contextlib import asynccontextmanager
7
 
8
- from fastapi import FastAPI, HTTPException, Depends
9
  from fastapi.responses import StreamingResponse
10
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
11
 
@@ -48,7 +49,16 @@ def load_config():
48
  return []
49
 
50
  config = load_config()
51
- # print(config)
 
 
 
 
 
 
 
 
 
52
 
53
  async def process_request(request: RequestModel, provider: Dict):
54
  print("provider: ", provider['provider'])
@@ -64,15 +74,16 @@ async def process_request(request: RequestModel, provider: Dict):
64
 
65
  url, headers, payload = await get_payload(request, engine, provider)
66
 
67
- request_info = {
68
- "url": url,
69
- "headers": headers,
70
- "payload": payload
71
- }
72
- print(f"Request details: {json.dumps(request_info, indent=4, ensure_ascii=False)}")
73
 
74
  if request.stream:
75
- return StreamingResponse(fetch_response_stream(app.state.client, url, headers, payload, engine, request.model), media_type="text/event-stream")
 
76
  else:
77
  return await fetch_response(app.state.client, url, headers, payload)
78
 
@@ -81,7 +92,11 @@ class ModelRequestHandler:
81
  self.last_provider_index = -1
82
 
83
  def get_matching_providers(self, model_name):
84
- return [provider for provider in config if model_name in provider['model']]
 
 
 
 
85
 
86
  async def request_model(self, request: RequestModel, token: str):
87
  model_name = request.model
@@ -122,6 +137,18 @@ class ModelRequestHandler:
122
 
123
  model_handler = ModelRequestHandler()
124
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
126
  token = credentials.credentials
127
  if token not in api_keys_db:
@@ -137,7 +164,7 @@ def get_all_models():
137
  unique_models = set()
138
 
139
  for provider in config:
140
- for model in provider['model']:
141
  if model not in unique_models:
142
  unique_models.add(model)
143
  model_info = {
 
1
  import os
2
  import json
3
  import httpx
4
+ import logging
5
  import yaml
6
  import traceback
7
  from contextlib import asynccontextmanager
8
 
9
+ from fastapi import FastAPI, Request, HTTPException, Depends
10
  from fastapi.responses import StreamingResponse
11
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
12
 
 
49
  return []
50
 
51
  config = load_config()
52
+ for index, provider in enumerate(config):
53
+ model_dict = {}
54
+ for model in provider['model']:
55
+ if type(model) == str:
56
+ model_dict[model] = model
57
+ if type(model) == dict:
58
+ model_dict.update({value: key for key, value in model.items()})
59
+ provider['model'] = model_dict
60
+ config[index] = provider
61
+ # print(json.dumps(config, indent=4, ensure_ascii=False))
62
 
63
  async def process_request(request: RequestModel, provider: Dict):
64
  print("provider: ", provider['provider'])
 
74
 
75
  url, headers, payload = await get_payload(request, engine, provider)
76
 
77
+ # request_info = {
78
+ # "url": url,
79
+ # "headers": headers,
80
+ # "payload": payload
81
+ # }
82
+ # print(f"Request details: {json.dumps(request_info, indent=4, ensure_ascii=False)}")
83
 
84
  if request.stream:
85
+ model = provider['model'][request.model]
86
+ return StreamingResponse(fetch_response_stream(app.state.client, url, headers, payload, engine, model), media_type="text/event-stream")
87
  else:
88
  return await fetch_response(app.state.client, url, headers, payload)
89
 
 
92
  self.last_provider_index = -1
93
 
94
  def get_matching_providers(self, model_name):
95
+ # for provider in config:
96
+ # print("provider", model_name, list(provider['model'].keys()))
97
+ # if model_name in provider['model'].keys():
98
+ # print("provider", provider)
99
+ return [provider for provider in config if model_name in provider['model'].keys()]
100
 
101
  async def request_model(self, request: RequestModel, token: str):
102
  model_name = request.model
 
137
 
138
  model_handler = ModelRequestHandler()
139
 
140
+ @app.middleware("http")
141
+ async def log_requests(request: Request, call_next):
142
+ # 打印请求信息
143
+ logging.info(f"Request: {request.method} {request.url}")
144
+ # 打印请求体(如果有)
145
+ if request.method in ["POST", "PUT", "PATCH"]:
146
+ body = await request.body()
147
+ logging.info(f"Request Body: {body.decode('utf-8')}")
148
+
149
+ response = await call_next(request)
150
+ return response
151
+
152
  def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
153
  token = credentials.credentials
154
  if token not in api_keys_db:
 
164
  unique_models = set()
165
 
166
  for provider in config:
167
+ for model in provider['model'].keys():
168
  if model not in unique_models:
169
  unique_models.add(model)
170
  model_info = {
request.py CHANGED
@@ -38,9 +38,10 @@ async def get_gemini_payload(request, engine, provider):
38
  'Content-Type': 'application/json'
39
  }
40
  url = provider['base_url']
 
41
  if request.stream:
42
  gemini_stream = "streamGenerateContent"
43
- url = url.format(model=request.model, stream=gemini_stream, api_key=provider['api'])
44
 
45
  messages = []
46
  for msg in request.messages:
@@ -112,7 +113,6 @@ async def get_gpt_payload(request, engine, provider):
112
  'Content-Type': 'application/json'
113
  }
114
  url = provider['base_url']
115
- url = url.format(model=request.model, stream=request.stream, api_key=provider['api'])
116
 
117
  messages = []
118
  for msg in request.messages:
@@ -133,8 +133,9 @@ async def get_gpt_payload(request, engine, provider):
133
  else:
134
  messages.append({"role": msg.role, "content": content})
135
 
 
136
  payload = {
137
- "model": request.model,
138
  "messages": messages,
139
  }
140
 
@@ -222,8 +223,9 @@ async def get_claude_payload(request, engine, provider):
222
  elif msg.role == "system":
223
  system_prompt = content
224
 
 
225
  payload = {
226
- "model": request.model,
227
  "messages": messages,
228
  "system": system_prompt,
229
  }
 
38
  'Content-Type': 'application/json'
39
  }
40
  url = provider['base_url']
41
+ model = provider['model'][request.model]
42
  if request.stream:
43
  gemini_stream = "streamGenerateContent"
44
+ url = url.format(model=model, stream=gemini_stream, api_key=provider['api'])
45
 
46
  messages = []
47
  for msg in request.messages:
 
113
  'Content-Type': 'application/json'
114
  }
115
  url = provider['base_url']
 
116
 
117
  messages = []
118
  for msg in request.messages:
 
133
  else:
134
  messages.append({"role": msg.role, "content": content})
135
 
136
+ model = provider['model'][request.model]
137
  payload = {
138
+ "model": model,
139
  "messages": messages,
140
  }
141
 
 
223
  elif msg.role == "system":
224
  system_prompt = content
225
 
226
+ model = provider['model'][request.model]
227
  payload = {
228
+ "model": model,
229
  "messages": messages,
230
  "system": system_prompt,
231
  }
response.py CHANGED
@@ -2,7 +2,7 @@ from datetime import datetime
2
  import json
3
  import httpx
4
 
5
- async def generate_sse_response(timestamp, model, content=None, tools_id=None, function_call_name=None, function_call_content=None):
6
  sample_data = {
7
  "id": "chatcmpl-9ijPeRHa0wtyA2G8wq5z8FC3wGMzc",
8
  "object": "chat.completion.chunk",
@@ -24,6 +24,8 @@ async def generate_sse_response(timestamp, model, content=None, tools_id=None, f
24
  if tools_id and function_call_name:
25
  sample_data["choices"][0]["delta"] = {"tool_calls":[{"index":0,"id":tools_id,"type":"function","function":{"name":function_call_name,"arguments":""}}]}
26
  # sample_data["choices"][0]["delta"] = {"tool_calls":[{"index":0,"function":{"id": tools_id, "name": function_call_name}}]}
 
 
27
  json_data = json.dumps(sample_data, ensure_ascii=False)
28
 
29
  # 构建SSE响应
@@ -91,6 +93,10 @@ async def fetch_claude_response_stream(client, url, headers, payload, model):
91
  message = resp.get("message")
92
  if message:
93
  tokens_use = resp.get("usage")
 
 
 
 
94
  if tokens_use:
95
  total_tokens = tokens_use["input_tokens"] + tokens_use["output_tokens"]
96
  # print("\n\rtotal_tokens", total_tokens)
 
2
  import json
3
  import httpx
4
 
5
+ async def generate_sse_response(timestamp, model, content=None, tools_id=None, function_call_name=None, function_call_content=None, role=None, tokens_use=None, total_tokens=None):
6
  sample_data = {
7
  "id": "chatcmpl-9ijPeRHa0wtyA2G8wq5z8FC3wGMzc",
8
  "object": "chat.completion.chunk",
 
24
  if tools_id and function_call_name:
25
  sample_data["choices"][0]["delta"] = {"tool_calls":[{"index":0,"id":tools_id,"type":"function","function":{"name":function_call_name,"arguments":""}}]}
26
  # sample_data["choices"][0]["delta"] = {"tool_calls":[{"index":0,"function":{"id": tools_id, "name": function_call_name}}]}
27
+ if role:
28
+ sample_data["choices"][0]["delta"] = {"role": role, "content": ""}
29
  json_data = json.dumps(sample_data, ensure_ascii=False)
30
 
31
  # 构建SSE响应
 
93
  message = resp.get("message")
94
  if message:
95
  tokens_use = resp.get("usage")
96
+ role = message.get("role")
97
+ if role:
98
+ sse_string = await generate_sse_response(timestamp, model, None, None, None, None, role)
99
+ yield sse_string
100
  if tokens_use:
101
  total_tokens = tokens_use["input_tokens"] + tokens_use["output_tokens"]
102
  # print("\n\rtotal_tokens", total_tokens)