yym68686 commited on
Commit
3972d74
·
1 Parent(s): 14428d9

🐛 Bug: Fix the bug where error codes are not accurately returned to the client.

Browse files
Files changed (3) hide show
  1. main.py +16 -4
  2. response.py +1 -1
  3. utils.py +5 -3
main.py CHANGED
@@ -218,7 +218,7 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest],
218
  if request.stream:
219
  model = provider['model'][request.model]
220
  generator = fetch_response_stream(app.state.client, url, headers, payload, engine, model)
221
- wrapped_generator = await error_handling_wrapper(generator, status_code=500)
222
  response = StreamingResponse(wrapped_generator, media_type="text/event-stream")
223
  else:
224
  response = await anext(fetch_response(app.state.client, url, headers, payload))
@@ -369,6 +369,8 @@ class ModelRequestHandler:
369
 
370
  # 在 try_all_providers 函数中处理失败的情况
371
  async def try_all_providers(self, request: Union[RequestModel, ImageGenerationRequest], providers: List[Dict], use_round_robin: bool, auto_retry: bool, endpoint: str = None):
 
 
372
  num_providers = len(providers)
373
  start_index = self.last_provider_index + 1 if use_round_robin else 0
374
  for i in range(num_providers + 1):
@@ -377,14 +379,24 @@ class ModelRequestHandler:
377
  try:
378
  response = await process_request(request, provider, endpoint)
379
  return response
380
- except (Exception, HTTPException, asyncio.CancelledError, httpx.ReadError) as e:
381
  logger.error(f"Error with provider {provider['provider']}: {str(e)}")
 
 
 
 
 
 
 
 
 
 
382
  if auto_retry:
383
  continue
384
  else:
385
- raise HTTPException(status_code=500, detail="Error: Current provider response failed!")
386
 
387
- raise HTTPException(status_code=500, detail=f"All providers failed: {request.model}")
388
 
389
  model_handler = ModelRequestHandler()
390
 
 
218
  if request.stream:
219
  model = provider['model'][request.model]
220
  generator = fetch_response_stream(app.state.client, url, headers, payload, engine, model)
221
+ wrapped_generator = await error_handling_wrapper(generator)
222
  response = StreamingResponse(wrapped_generator, media_type="text/event-stream")
223
  else:
224
  response = await anext(fetch_response(app.state.client, url, headers, payload))
 
369
 
370
  # 在 try_all_providers 函数中处理失败的情况
371
  async def try_all_providers(self, request: Union[RequestModel, ImageGenerationRequest], providers: List[Dict], use_round_robin: bool, auto_retry: bool, endpoint: str = None):
372
+ status_code = 500
373
+ error_message = None
374
  num_providers = len(providers)
375
  start_index = self.last_provider_index + 1 if use_round_robin else 0
376
  for i in range(num_providers + 1):
 
379
  try:
380
  response = await process_request(request, provider, endpoint)
381
  return response
382
+ except HTTPException as e:
383
  logger.error(f"Error with provider {provider['provider']}: {str(e)}")
384
+ status_code = e.status_code
385
+ error_message = e.detail
386
+
387
+ if auto_retry:
388
+ continue
389
+ else:
390
+ raise HTTPException(status_code=500, detail=f"Error: Current provider response failed: {error_message}")
391
+ except (Exception, asyncio.CancelledError, httpx.ReadError) as e:
392
+ logger.error(f"Error with provider {provider['provider']}: {str(e)}")
393
+ error_message = str(e)
394
  if auto_retry:
395
  continue
396
  else:
397
+ raise HTTPException(status_code=500, detail=f"Error: Current provider response failed: {error_message}")
398
 
399
+ raise HTTPException(status_code=status_code, detail=f"All {request.model} error: {error_message}")
400
 
401
  model_handler = ModelRequestHandler()
402
 
response.py CHANGED
@@ -48,7 +48,7 @@ async def check_response(response, error_log):
48
  error_json = json.loads(error_str)
49
  except json.JSONDecodeError:
50
  error_json = error_str
51
- return {"error": f"{error_log} HTTP Error {response.status_code}", "details": error_json}
52
  return None
53
 
54
  async def fetch_gemini_response_stream(client, url, headers, payload, model):
 
48
  error_json = json.loads(error_str)
49
  except json.JSONDecodeError:
50
  error_json = error_str
51
+ return {"error": f"{error_log} HTTP Error", "status_code": response.status_code, "details": error_json}
52
  return None
53
 
54
  async def fetch_gemini_response_stream(client, url, headers, payload, model):
utils.py CHANGED
@@ -104,7 +104,7 @@ def ensure_string(item):
104
  return str(item)
105
 
106
  import asyncio
107
- async def error_handling_wrapper(generator, status_code=200):
108
  try:
109
  first_item = await generator.__anext__()
110
  first_item_str = first_item
@@ -126,7 +126,9 @@ async def error_handling_wrapper(generator, status_code=200):
126
  raise StopAsyncIteration
127
  if isinstance(first_item_str, dict) and 'error' in first_item_str:
128
  # 如果第一个 yield 的项是错误信息,抛出 HTTPException
129
- raise HTTPException(status_code=status_code, detail=f"{first_item_str}"[:300])
 
 
130
 
131
  # 如果不是错误,创建一个新的生成器,首先yield第一个项,然后yield剩余的项
132
  async def new_generator():
@@ -141,7 +143,7 @@ async def error_handling_wrapper(generator, status_code=200):
141
  return new_generator()
142
 
143
  except StopAsyncIteration:
144
- raise HTTPException(status_code=status_code, detail="data: {'error': 'No data returned'}")
145
 
146
  def post_all_models(token, config, api_list):
147
  all_models = []
 
104
  return str(item)
105
 
106
  import asyncio
107
+ async def error_handling_wrapper(generator):
108
  try:
109
  first_item = await generator.__anext__()
110
  first_item_str = first_item
 
126
  raise StopAsyncIteration
127
  if isinstance(first_item_str, dict) and 'error' in first_item_str:
128
  # 如果第一个 yield 的项是错误信息,抛出 HTTPException
129
+ status_code = first_item_str.get('status_code', 500)
130
+ detail = first_item_str.get('details', f"{first_item_str}")
131
+ raise HTTPException(status_code=status_code, detail=f"{detail}"[:300])
132
 
133
  # 如果不是错误,创建一个新的生成器,首先yield第一个项,然后yield剩余的项
134
  async def new_generator():
 
143
  return new_generator()
144
 
145
  except StopAsyncIteration:
146
+ raise HTTPException(status_code=400, detail="data: {'error': 'No data returned'}")
147
 
148
  def post_all_models(token, config, api_list):
149
  all_models = []