Moonfanz committed on
Commit
78a1562
·
verified ·
1 Parent(s): c06ccd5

Upload 4 files

Browse files
Files changed (1) hide show
  1. app.py +41 -27
app.py CHANGED
@@ -83,27 +83,23 @@ safety_settings_g2 = [
83
  ]
84
  @dataclass
85
  class GeneratedText:
86
- """用于存储生成的文本片段"""
87
  text: str
88
  finish_reason: Optional[str] = None
89
 
90
 
91
  class ResponseWrapper:
92
- """处理非流式响应的包装类"""
93
  def __init__(self, data: Dict[Any, Any]):
94
  self._data = data
95
  self._text = self._extract_text()
96
  self._finish_reason = self._extract_finish_reason()
97
 
98
  def _extract_text(self) -> str:
99
- """从响应数据中提取文本"""
100
  try:
101
  return self._data['candidates'][0]['content']['parts'][0]['text']
102
  except (KeyError, IndexError):
103
  return ""
104
 
105
  def _extract_finish_reason(self) -> Optional[str]:
106
- """提取完成原因"""
107
  try:
108
  return self._data['candidates'][0].get('finishReason')
109
  except (KeyError, IndexError):
@@ -111,12 +107,10 @@ class ResponseWrapper:
111
 
112
  @property
113
  def text(self) -> str:
114
- """获取响应文本"""
115
  return self._text
116
 
117
  @property
118
  def finish_reason(self) -> Optional[str]:
119
- """获取完成原因"""
120
  return self._finish_reason
121
 
122
  class APIKeyManager:
@@ -345,6 +339,7 @@ def chat_completions():
345
  model = request_data.get('model', 'gemini-2.0-flash-exp')
346
  temperature = request_data.get('temperature', 1)
347
  max_tokens = request_data.get('max_tokens', 8192)
 
348
  stream = request_data.get('stream', False)
349
  hint = "流式" if stream else "非流"
350
  logger.info(f"\n{model} [{hint}] → ...")
@@ -399,7 +394,7 @@ def chat_completions():
399
  return handle_api_error(e, attempt)
400
 
401
  def generate_stream(response):
402
- buffer = b""
403
  try:
404
  for line in response.iter_lines():
405
  if not line:
@@ -418,21 +413,25 @@ def chat_completions():
418
  if 'content' in candidate:
419
  content = candidate['content']
420
  if 'parts' in content and content['parts']:
421
- text = content['parts'][0].get('text', '')
422
- finish_reason = candidate.get('finishReason')
423
-
424
- if text:
425
- data = {
426
- 'choices': [{
427
- 'delta': {
428
- 'content': text
429
- },
430
- 'finish_reason': finish_reason,
431
- 'index': 0
432
- }],
433
- 'object': 'chat.completion.chunk'
434
- }
435
- yield f"data: {json.dumps(data)}\n\n"
 
 
 
 
436
 
437
  except json.JSONDecodeError:
438
  logger.debug(f"JSONDecodeError, buffer now: {buffer}")
@@ -487,7 +486,7 @@ def chat_completions():
487
  try:
488
  text_content = response.text
489
  except (AttributeError, IndexError, TypeError, ValueError) as e:
490
- if "response.candidates" in str(e) or "response.text" in str(e):
491
  logger.error(f"用户输入被AI安全过滤器阻止")
492
  return jsonify({
493
  'error': {
@@ -496,13 +495,28 @@ def chat_completions():
496
  'details': str(e)
497
  }
498
  }), 400
499
- else:
500
  return jsonify({
501
  'error': {
502
- 'message': 'AI响应处理失败',
503
- 'type': 'response_processing_error'
504
- }
505
  }), 500
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
506
 
507
  response_data = {
508
  'id': 'chatcmpl-xxxxxxxxxxxx',
 
83
  ]
84
  @dataclass
85
  class GeneratedText:
 
86
  text: str
87
  finish_reason: Optional[str] = None
88
 
89
 
90
  class ResponseWrapper:
 
91
  def __init__(self, data: Dict[Any, Any]):
92
  self._data = data
93
  self._text = self._extract_text()
94
  self._finish_reason = self._extract_finish_reason()
95
 
96
  def _extract_text(self) -> str:
 
97
  try:
98
  return self._data['candidates'][0]['content']['parts'][0]['text']
99
  except (KeyError, IndexError):
100
  return ""
101
 
102
  def _extract_finish_reason(self) -> Optional[str]:
 
103
  try:
104
  return self._data['candidates'][0].get('finishReason')
105
  except (KeyError, IndexError):
 
107
 
108
  @property
109
  def text(self) -> str:
 
110
  return self._text
111
 
112
  @property
113
  def finish_reason(self) -> Optional[str]:
 
114
  return self._finish_reason
115
 
116
  class APIKeyManager:
 
339
  model = request_data.get('model', 'gemini-2.0-flash-exp')
340
  temperature = request_data.get('temperature', 1)
341
  max_tokens = request_data.get('max_tokens', 8192)
342
+ show_thoughts = request_data.get('show_thoughts', False)
343
  stream = request_data.get('stream', False)
344
  hint = "流式" if stream else "非流"
345
  logger.info(f"\n{model} [{hint}] → ...")
 
394
  return handle_api_error(e, attempt)
395
 
396
  def generate_stream(response):
397
+ buffer = b""
398
  try:
399
  for line in response.iter_lines():
400
  if not line:
 
413
  if 'content' in candidate:
414
  content = candidate['content']
415
  if 'parts' in content and content['parts']:
416
+ parts = content['parts']
417
+ if is_thinking and not show_thoughts:
418
+ parts = [part for part in parts if not part.get('thought')]
419
+ if parts:
420
+ text = parts[0].get('text', '')
421
+ finish_reason = candidate.get('finishReason')
422
+
423
+ if text:
424
+ data = {
425
+ 'choices': [{
426
+ 'delta': {
427
+ 'content': text
428
+ },
429
+ 'finish_reason': finish_reason,
430
+ 'index': 0
431
+ }],
432
+ 'object': 'chat.completion.chunk'
433
+ }
434
+ yield f"data: {json.dumps(data)}\n\n"
435
 
436
  except json.JSONDecodeError:
437
  logger.debug(f"JSONDecodeError, buffer now: {buffer}")
 
486
  try:
487
  text_content = response.text
488
  except (AttributeError, IndexError, TypeError, ValueError) as e:
489
+ if "response.candidates" in str(e) or "response.text" in str(e):
490
  logger.error(f"用户输入被AI安全过滤器阻止")
491
  return jsonify({
492
  'error': {
 
495
  'details': str(e)
496
  }
497
  }), 400
498
+ else:
499
  return jsonify({
500
  'error': {
501
+ 'message': 'AI响应处理失败',
502
+ 'type': 'response_processing_error'
503
+ }
504
  }), 500
505
+ try:
506
+ response_json = json.loads(response.text)
507
+ if 'candidates' in response_json and response_json['candidates']:
508
+ candidate = response_json['candidates'][0]
509
+ if 'content' in candidate:
510
+ content = candidate['content']
511
+ if 'parts' in content and content['parts']:
512
+ parts = content['parts']
513
+ if is_thinking and not show_thoughts:
514
+ parts = [part for part in parts if not part.get('thought')]
515
+
516
+ if parts:
517
+ text_content = "".join(part.get('text', '') for part in parts)
518
+ except json.JSONDecodeError:
519
+ pass
520
 
521
  response_data = {
522
  'id': 'chatcmpl-xxxxxxxxxxxx',