🐛 Bug: Fix the bug where Gemini cannot use non-streaming output.
- main.py +4 -3
- response.py +69 -4
- test/test_json.py +26 -0
main.py (CHANGED)

@@ -531,7 +531,7 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A
     if provider.get("engine"):
         engine = provider["engine"]

-    logger.info(f"provider: {provider['provider']:<
+    logger.info(f"provider: {provider['provider']:<11} model: {request.model:<22} engine: {engine}")

     url, headers, payload = await get_payload(request, engine, provider)
     if is_debug:
@@ -542,16 +542,17 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A
         logger.info(json.dumps(payload, indent=4, ensure_ascii=False))
     current_info = request_info.get()
     try:
+        model = model_dict[request.model]
         if request.stream:
-            model = model_dict[request.model]
             generator = fetch_response_stream(app.state.client, url, headers, payload, engine, model)
             wrapped_generator, first_response_time = await error_handling_wrapper(generator)
             response = StarletteStreamingResponse(wrapped_generator, media_type="text/event-stream")
         else:
-            generator = fetch_response(app.state.client, url, headers, payload)
+            generator = fetch_response(app.state.client, url, headers, payload, engine, model)
             wrapped_generator, first_response_time = await error_handling_wrapper(generator)
             first_element = await anext(wrapped_generator)
             first_element = first_element.lstrip("data: ")
+            print("first_element", first_element)
             first_element = json.loads(first_element)
             response = StarletteStreamingResponse(iter([json.dumps(first_element)]), media_type="application/json")
             # response = JSONResponse(first_element)
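Note: the fix above resolves model from model_dict before the stream check and passes engine and model through to fetch_response, so the non-streaming branch can rebuild Gemini's reply into an OpenAI-style completion. A rough way to exercise that branch, assuming the service exposes an OpenAI-compatible /v1/chat/completions endpoint; the host, port, API key, and model name below are placeholders, not values taken from this commit:

# Hypothetical client call to hit the fixed non-streaming branch.
import httpx

resp = httpx.post(
    "http://localhost:8000/v1/chat/completions",
    headers={"Authorization": "Bearer sk-placeholder"},
    json={
        "model": "gemini-1.5-pro",                        # any model routed to the gemini engine
        "messages": [{"role": "user", "content": "Hi"}],
        "stream": False,                                  # takes the else-branch patched above
    },
    timeout=60,
)
print(resp.json()["choices"][0]["message"]["content"])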
response.py (CHANGED)

@@ -4,6 +4,8 @@ from datetime import datetime

 from log_config import logger

+from utils import safe_get
+
 # end_of_line = "\n\r\n"
 # end_of_line = "\r\n"
 # end_of_line = "\n\r"
@@ -17,7 +19,6 @@ async def generate_sse_response(timestamp, model, content=None, tools_id=None, f
         "object": "chat.completion.chunk",
         "created": timestamp,
         "model": model,
-        "system_fingerprint": "fp_d576307f90",
         "choices": [
             {
                 "index": 0,
@@ -26,7 +27,8 @@ async def generate_sse_response(timestamp, model, content=None, tools_id=None, f
                 "finish_reason": None
             }
         ],
-        "usage": None
+        "usage": None,
+        "system_fingerprint": "fp_d576307f90",
     }
     if function_call_content:
         sample_data["choices"][0]["delta"] = {"tool_calls":[{"index":0,"function":{"arguments": function_call_content}}]}
@@ -46,6 +48,34 @@ async def generate_sse_response(timestamp, model, content=None, tools_id=None, f

     return sse_response

+async def generate_no_stream_response(timestamp, model, content=None, tools_id=None, function_call_name=None, function_call_content=None, role=None, total_tokens=0, prompt_tokens=0, completion_tokens=0):
+    sample_data = {
+        "id": "chatcmpl-ALGS9hpJBb8xVAe62DRriY2SpoT4L",
+        "object": "chat.completion",
+        "created": timestamp,
+        "model": model,
+        "choices": [
+            {
+                "index": 0,
+                "message": {
+                    "role": role,
+                    "content": content,
+                    "refusal": None
+                },
+                "logprobs": None,
+                "finish_reason": "stop"
+            }
+        ],
+        "usage": None,
+        "system_fingerprint": "fp_a7d06e42a7"
+    }
+    if total_tokens:
+        total_tokens = prompt_tokens + completion_tokens
+        sample_data["usage"] = {"prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, "total_tokens": total_tokens}
+    json_data = json.dumps(sample_data, ensure_ascii=False)
+
+    return json_data
+
 async def check_response(response, error_log):
     if response and response.status_code != 200:
         error_message = await response.aread()
@@ -274,7 +304,7 @@ async def fetch_claude_response_stream(client, url, headers, payload, model):
             yield sse_string
     yield "data: [DONE]" + end_of_line

-async def fetch_response(client, url, headers, payload):
+async def fetch_response(client, url, headers, payload, engine, model):
     response = None
     if payload.get("file"):
         file = payload.pop("file")
@@ -285,7 +315,42 @@ async def fetch_response(client, url, headers, payload):
     if error_message:
         yield error_message
         return
-    yield response.json()
+    response_json = response.json()
+    if engine == "gemini" or engine == "vertex-gemini":
+
+        if isinstance(response_json, str):
+            import ast
+            parsed_data = ast.literal_eval(str(response_json))
+        elif isinstance(response_json, list):
+            parsed_data = response_json
+        else:
+            logger.error(f"error fetch_response: Unknown response_json type: {type(response_json)}")
+            parsed_data = response_json
+
+        content = ""
+        for item in parsed_data:
+            chunk = safe_get(item, "candidates", 0, "content", "parts", 0, "text")
+            # logger.info(f"chunk: {repr(chunk)}")
+            if chunk:
+                content += chunk
+
+        usage_metadata = safe_get(parsed_data, -1, "usageMetadata")
+        prompt_tokens = usage_metadata.get("promptTokenCount", 0)
+        candidates_tokens = usage_metadata.get("candidatesTokenCount", 0)
+        total_tokens = usage_metadata.get("totalTokenCount", 0)
+
+        role = safe_get(parsed_data, -1, "candidates", 0, "content", "role")
+        if role == "model":
+            role = "assistant"
+        else:
+            logger.error(f"Unknown role: {role}")
+            role = "assistant"
+
+        timestamp = int(datetime.timestamp(datetime.now()))
+        yield await generate_no_stream_response(timestamp, model, content=content, tools_id=None, function_call_name=None, function_call_content=None, role=role, total_tokens=total_tokens, prompt_tokens=prompt_tokens, completion_tokens=candidates_tokens)
+
+    else:
+        yield response_json

 async def fetch_response_stream(client, url, headers, payload, engine, model):
     try:
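To illustrate what the new fetch_response branch produces, the sketch below aggregates a made-up Gemini chunk list the same way the loop above does (concatenating candidates[0].content.parts[0].text, reading usageMetadata from the last chunk, mapping role "model" to "assistant") and feeds the result to the generate_no_stream_response helper added in this commit. The chunk data is invented sample input, and the import assumes response.py is importable from the project root:

# Invented sample of the list-of-chunks body the Gemini endpoint is assumed to return.
import asyncio, json
from datetime import datetime

from response import generate_no_stream_response  # helper added in this commit

gemini_chunks = [
    {"candidates": [{"content": {"parts": [{"text": "Hello, "}], "role": "model"}}]},
    {"candidates": [{"content": {"parts": [{"text": "world!"}], "role": "model"}}],
     "usageMetadata": {"promptTokenCount": 5, "candidatesTokenCount": 3, "totalTokenCount": 8}},
]

# Mirror the aggregation fetch_response performs for the gemini / vertex-gemini engines.
content = "".join(part["text"] for chunk in gemini_chunks
                  for part in chunk["candidates"][0]["content"]["parts"])
usage = gemini_chunks[-1]["usageMetadata"]

body = asyncio.run(generate_no_stream_response(
    int(datetime.now().timestamp()), "gemini-1.5-pro",
    content=content, role="assistant",
    total_tokens=usage["totalTokenCount"],
    prompt_tokens=usage["promptTokenCount"],
    completion_tokens=usage["candidatesTokenCount"],
))
print(json.dumps(json.loads(body), indent=2))  # an OpenAI-style chat.completion object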
test/test_json.py (ADDED)

@@ -0,0 +1,26 @@
+import ast
+import json
+
+import os
+import sys
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from utils import safe_get
+# Read the file contents
+with open('test/states.json', 'r', encoding='utf-8') as file:
+    content = file.read()
+
+# Parse the non-standard JSON with ast.literal_eval
+parsed_data = ast.literal_eval(content)
+
+for item in parsed_data:
+    print(safe_get(item, "candidates", 0, "content", "parts", 0, "text"))
+    print(safe_get(item, "candidates", 0, "content", "role"))
+
+# Convert the parsed data to standard JSON
+standard_json = json.dumps(parsed_data, ensure_ascii=False, indent=2)
+
+# Write the standard JSON to a new file
+with open('test/standard_states.json', 'w', encoding='utf-8') as file:
+    file.write(standard_json)
+
+print("Conversion complete; standard JSON saved to 'test/standard_states.json'")
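Both the new Gemini branch in response.py and this test rely on utils.safe_get, which is not part of this diff. A minimal hypothetical sketch consistent with how it is called here (walking positional dict keys and list indexes through nested data and returning a default when any step is missing) might look like the following; the actual helper in utils.py may differ:

# Hypothetical stand-in for utils.safe_get, inferred only from its call sites; not the project's real code.
def safe_get(data, *keys, default=None):
    value = data
    for key in keys:
        try:
            value = value[key]  # handles dict keys and list indexes (including -1) alike
        except (KeyError, IndexError, TypeError):
            return default
    return value

# Example matching the calls above:
# safe_get(item, "candidates", 0, "content", "parts", 0, "text")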