yym68686 committed
Commit 73319d1 · 1 Parent(s): bf80b6a

🐛 Bug: Fix the bug where Gemini cannot use non-streaming output.

Files changed (3):
  1. main.py (+4 -3)
  2. response.py (+69 -4)
  3. test/test_json.py (+26 -0)
main.py CHANGED
@@ -531,7 +531,7 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A
     if provider.get("engine"):
         engine = provider["engine"]
 
-    logger.info(f"provider: {provider['provider']:<10} model: {request.model:<10} engine: {engine}")
+    logger.info(f"provider: {provider['provider']:<11} model: {request.model:<22} engine: {engine}")
 
     url, headers, payload = await get_payload(request, engine, provider)
     if is_debug:
@@ -542,16 +542,17 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A
         logger.info(json.dumps(payload, indent=4, ensure_ascii=False))
     current_info = request_info.get()
     try:
+        model = model_dict[request.model]
         if request.stream:
-            model = model_dict[request.model]
             generator = fetch_response_stream(app.state.client, url, headers, payload, engine, model)
             wrapped_generator, first_response_time = await error_handling_wrapper(generator)
             response = StarletteStreamingResponse(wrapped_generator, media_type="text/event-stream")
         else:
-            generator = fetch_response(app.state.client, url, headers, payload)
+            generator = fetch_response(app.state.client, url, headers, payload, engine, model)
             wrapped_generator, first_response_time = await error_handling_wrapper(generator)
             first_element = await anext(wrapped_generator)
             first_element = first_element.lstrip("data: ")
+            print("first_element", first_element)
             first_element = json.loads(first_element)
             response = StarletteStreamingResponse(iter([json.dumps(first_element)]), media_type="application/json")
             # response = JSONResponse(first_element)
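
Note on the non-streaming branch: the handler still returns a Starlette streaming response, wrapping the single JSON document in a one-element iterator so it is emitted in one chunk. A minimal, self-contained sketch of that pattern (the route and payload below are illustrative, not from this repo):

import json

from starlette.applications import Starlette
from starlette.responses import StreamingResponse
from starlette.routing import Route

async def demo(request):
    body = {"object": "chat.completion", "choices": []}
    # A one-element iterator makes StreamingResponse send the whole
    # JSON document as a single chunk, mirroring the branch above.
    return StreamingResponse(iter([json.dumps(body)]), media_type="application/json")

app = Starlette(routes=[Route("/demo", demo)])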
response.py CHANGED
@@ -4,6 +4,8 @@ from datetime import datetime
 
 from log_config import logger
 
+from utils import safe_get
+
 # end_of_line = "\n\r\n"
 # end_of_line = "\r\n"
 # end_of_line = "\n\r"
@@ -17,7 +19,6 @@ async def generate_sse_response(timestamp, model, content=None, tools_id=None, f
         "object": "chat.completion.chunk",
         "created": timestamp,
         "model": model,
-        "system_fingerprint": "fp_d576307f90",
         "choices": [
             {
                 "index": 0,
@@ -26,7 +27,8 @@ async def generate_sse_response(timestamp, model, content=None, tools_id=None, f
                 "finish_reason": None
             }
         ],
-        "usage": None
+        "usage": None,
+        "system_fingerprint": "fp_d576307f90",
     }
     if function_call_content:
         sample_data["choices"][0]["delta"] = {"tool_calls":[{"index":0,"function":{"arguments": function_call_content}}]}
@@ -46,6 +48,34 @@ async def generate_sse_response(timestamp, model, content=None, tools_id=None, f
 
     return sse_response
 
+async def generate_no_stream_response(timestamp, model, content=None, tools_id=None, function_call_name=None, function_call_content=None, role=None, total_tokens=0, prompt_tokens=0, completion_tokens=0):
+    sample_data = {
+        "id": "chatcmpl-ALGS9hpJBb8xVAe62DRriY2SpoT4L",
+        "object": "chat.completion",
+        "created": timestamp,
+        "model": model,
+        "choices": [
+            {
+                "index": 0,
+                "message": {
+                    "role": role,
+                    "content": content,
+                    "refusal": None
+                },
+                "logprobs": None,
+                "finish_reason": "stop"
+            }
+        ],
+        "usage": None,
+        "system_fingerprint": "fp_a7d06e42a7"
+    }
+    if total_tokens:
+        total_tokens = prompt_tokens + completion_tokens
+        sample_data["usage"] = {"prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, "total_tokens": total_tokens}
+    json_data = json.dumps(sample_data, ensure_ascii=False)
+
+    return json_data
+
 async def check_response(response, error_log):
     if response and response.status_code != 200:
         error_message = await response.aread()
@@ -274,7 +304,7 @@ async def fetch_claude_response_stream(client, url, headers, payload, model):
         yield sse_string
     yield "data: [DONE]" + end_of_line
 
-async def fetch_response(client, url, headers, payload):
+async def fetch_response(client, url, headers, payload, engine, model):
     response = None
     if payload.get("file"):
         file = payload.pop("file")
@@ -285,7 +315,42 @@ async def fetch_response(client, url, headers, payload):
     if error_message:
         yield error_message
         return
-    yield response.json()
+    response_json = response.json()
+    if engine == "gemini" or engine == "vertex-gemini":
+
+        if isinstance(response_json, str):
+            import ast
+            parsed_data = ast.literal_eval(str(response_json))
+        elif isinstance(response_json, list):
+            parsed_data = response_json
+        else:
+            logger.error(f"error fetch_response: Unknown response_json type: {type(response_json)}")
+            parsed_data = response_json
+
+        content = ""
+        for item in parsed_data:
+            chunk = safe_get(item, "candidates", 0, "content", "parts", 0, "text")
+            # logger.info(f"chunk: {repr(chunk)}")
+            if chunk:
+                content += chunk
+
+        usage_metadata = safe_get(parsed_data, -1, "usageMetadata")
+        prompt_tokens = usage_metadata.get("promptTokenCount", 0)
+        candidates_tokens = usage_metadata.get("candidatesTokenCount", 0)
+        total_tokens = usage_metadata.get("totalTokenCount", 0)
+
+        role = safe_get(parsed_data, -1, "candidates", 0, "content", "role")
+        if role == "model":
+            role = "assistant"
+        else:
+            logger.error(f"Unknown role: {role}")
+            role = "assistant"
+
+        timestamp = int(datetime.timestamp(datetime.now()))
+        yield await generate_no_stream_response(timestamp, model, content=content, tools_id=None, function_call_name=None, function_call_content=None, role=role, total_tokens=total_tokens, prompt_tokens=prompt_tokens, completion_tokens=candidates_tokens)
+
+    else:
+        yield response_json
 
 async def fetch_response_stream(client, url, headers, payload, engine, model):
     try:
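
The Gemini branch leans on safe_get, which response.py imports from utils but which is not shown in this diff. A minimal sketch of what such a helper might look like, assuming it walks nested dicts and lists by key or index and returns a default instead of raising (an illustrative stand-in, not the repo's implementation):

def safe_get(data, *keys, default=None):
    # Illustrative stand-in for utils.safe_get, not the repo's code.
    # Each key is either a dict key or a sequence index (including -1).
    for key in keys:
        try:
            data = data[key]
        except (KeyError, IndexError, TypeError):
            return default
    return data

With a helper like this, safe_get(item, "candidates", 0, "content", "parts", 0, "text") returns None instead of raising when Gemini omits a part, which is why the aggregation loop above can simply skip falsy chunks.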
test/test_json.py ADDED
@@ -0,0 +1,26 @@
+import ast
+import json
+
+import os
+import sys
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from utils import safe_get
+# Read the file contents
+with open('test/states.json', 'r', encoding='utf-8') as file:
+    content = file.read()
+
+# Parse the non-standard JSON with ast.literal_eval
+parsed_data = ast.literal_eval(content)
+
+for item in parsed_data:
+    print(safe_get(item, "candidates", 0, "content", "parts", 0, "text"))
+    print(safe_get(item, "candidates", 0, "content", "role"))
+
+# Convert the parsed data to standard JSON
+standard_json = json.dumps(parsed_data, ensure_ascii=False, indent=2)
+
+# Write the standard JSON to a new file
+with open('test/standard_states.json', 'w', encoding='utf-8') as file:
+    file.write(standard_json)
+
+print("Conversion complete; standard JSON saved to 'test/standard_states.json'")