zeta commited on
Commit
47c28b9
·
1 Parent(s): bdd215f

revert: try_all_providers method in ModelRequestHandler

Browse files

- Remove class-level last_provider_index variable
- Revert try_all_providers method to previous implementation

Files changed (1) hide show
  1. main.py +8 -9
main.py CHANGED
@@ -126,8 +126,6 @@ async def process_request(request: RequestModel, provider: Dict):
126
 
127
  import asyncio
128
  class ModelRequestHandler:
129
- last_provider_index = -1 # 类变量
130
-
131
  def __init__(self):
132
  self.last_provider_index = -1
133
 
@@ -196,21 +194,22 @@ class ModelRequestHandler:
196
 
197
  async def try_all_providers(self, request: RequestModel, providers: List[Dict], use_round_robin: bool, auto_retry: bool):
198
  num_providers = len(providers)
199
- start_index = (ModelRequestHandler.last_provider_index + 1) % num_providers if use_round_robin else 0
200
 
201
- for i in range(num_providers):
202
- index = (start_index + i) % num_providers
203
- provider = providers[index]
204
  try:
205
  response = await process_request(request, provider)
206
- if use_round_robin:
207
- ModelRequestHandler.last_provider_index = index
208
  return response
209
  except (Exception, HTTPException, asyncio.CancelledError, httpx.ReadError) as e:
210
  logger.error(f"Error with provider {provider['provider']}: {str(e)}")
211
- if not auto_retry:
 
 
212
  raise HTTPException(status_code=500, detail="Error: Current provider response failed!")
213
 
 
214
  raise HTTPException(status_code=500, detail=f"All providers failed: {request.model}")
215
 
216
  model_handler = ModelRequestHandler()
 
126
 
127
  import asyncio
128
  class ModelRequestHandler:
 
 
129
  def __init__(self):
130
  self.last_provider_index = -1
131
 
 
194
 
195
  async def try_all_providers(self, request: RequestModel, providers: List[Dict], use_round_robin: bool, auto_retry: bool):
196
  num_providers = len(providers)
197
+ start_index = self.last_provider_index + 1 if use_round_robin else 0
198
 
199
+ for i in range(num_providers + 1):
200
+ self.last_provider_index = (start_index + i) % num_providers
201
+ provider = providers[self.last_provider_index]
202
  try:
203
  response = await process_request(request, provider)
 
 
204
  return response
205
  except (Exception, HTTPException, asyncio.CancelledError, httpx.ReadError) as e:
206
  logger.error(f"Error with provider {provider['provider']}: {str(e)}")
207
+ if auto_retry:
208
+ continue
209
+ else:
210
  raise HTTPException(status_code=500, detail="Error: Current provider response failed!")
211
 
212
+
213
  raise HTTPException(status_code=500, detail=f"All providers failed: {request.model}")
214
 
215
  model_handler = ModelRequestHandler()