yym68686 committed on
Commit
bdb98ff
·
1 Parent(s): 5ab6b69

🐛 Bug: Fixed a bug where a channel's cooldown did not take effect within the current request.

Browse files
Files changed (1) hide show
  1. main.py +50 -40
main.py CHANGED
@@ -866,48 +866,28 @@ def get_matching_providers(request_model, config, api_index):
866
  # print("provider_list", provider_list)
867
  return provider_list
868
 
869
- import asyncio
870
- class ModelRequestHandler:
871
- def __init__(self):
872
- self.last_provider_indices = defaultdict(lambda: -1)
873
- self.locks = defaultdict(asyncio.Lock)
874
-
875
- async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest], token: str, endpoint=None):
876
- config = app.state.config
877
- api_list = app.state.api_list
878
- api_index = api_list.index(token)
879
 
880
- if not safe_get(config, 'api_keys', api_index, 'model'):
881
- raise HTTPException(status_code=404, detail="No matching model found")
882
-
883
- request_model = request.model
884
- matching_providers = get_matching_providers(request_model, config, api_index)
885
 
 
 
886
  if not matching_providers:
887
- raise HTTPException(status_code=404, detail="No matching model found")
888
-
889
- if app.state.channel_manager.cooldown_period > 0:
890
- matching_providers = await app.state.channel_manager.get_available_providers(matching_providers)
891
- if not matching_providers:
892
- raise HTTPException(status_code=503, detail="No available providers at the moment")
893
 
 
 
894
  num_matching_providers = len(matching_providers)
 
895
 
 
896
 
897
- # 检查是否启用轮询
898
- scheduling_algorithm = safe_get(config, 'api_keys', api_index, "preferences", "SCHEDULING_ALGORITHM", default="fixed_priority")
899
- if scheduling_algorithm == "random":
900
- matching_providers = random.sample(matching_providers, num_matching_providers)
901
-
902
- weights = safe_get(config, 'api_keys', api_index, "weights")
903
-
904
- # 步骤 1: 提取 matching_providers 中的所有 provider 值
905
- # print("matching_providers", matching_providers)
906
- # print(type(matching_providers[0]['model'][0].keys()), list(matching_providers[0]['model'][0].keys())[0], matching_providers[0]['model'][0].keys())
907
- all_providers = set(provider['provider'] + "/" + list(provider['model'][0].keys())[0] for provider in matching_providers)
908
-
909
  intersection = None
910
- if weights and all_providers:
 
911
  weight_keys = set(weights.keys())
912
  provider_rules = []
913
  for model_rule in weight_keys:
@@ -922,7 +902,7 @@ class ModelRequestHandler:
922
  intersection = all_providers.intersection(weight_keys)
923
  # print("intersection", intersection)
924
 
925
- if weights and intersection:
926
  filtered_weights = {k.split("/")[0]: v for k, v in weights.items() if k in intersection}
927
  # print("filtered_weights", filtered_weights)
928
 
@@ -941,9 +921,31 @@ class ModelRequestHandler:
941
  new_matching_providers.append(provider)
942
  matching_providers = new_matching_providers
943
 
944
- if is_debug:
945
- for provider in matching_providers:
946
- logger.info("available provider: %s", json.dumps(provider, indent=4, ensure_ascii=False, default=circular_list_encoder))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
947
 
948
  status_code = 500
949
  error_message = None
@@ -956,8 +958,12 @@ class ModelRequestHandler:
956
 
957
  auto_retry = safe_get(config, 'api_keys', api_index, "preferences", "AUTO_RETRY", default=True)
958
 
959
- for i in range(num_matching_providers + 1):
960
- current_index = (start_index + i) % num_matching_providers
 
 
 
 
961
  provider = matching_providers[current_index]
962
  try:
963
  response = await process_request(request, provider, endpoint, token)
@@ -987,6 +993,10 @@ class ModelRequestHandler:
987
  channel_id = f"{provider['provider']}"
988
  if app.state.channel_manager.cooldown_period > 0:
989
  await app.state.channel_manager.exclude_channel(channel_id)
 
 
 
 
990
  logger.error(f"Error {status_code} with provider {channel_id}: {error_message}")
991
  if is_debug:
992
  import traceback
 
866
  # print("provider_list", provider_list)
867
  return provider_list
868
 
869
+ async def get_right_order_providers(request_model, config, api_index, scheduling_algorithm):
870
+ matching_providers = get_matching_providers(request_model, config, api_index)
 
 
 
 
 
 
 
 
871
 
872
+ if not matching_providers:
873
+ raise HTTPException(status_code=404, detail="No matching model found")
 
 
 
874
 
875
+ if app.state.channel_manager.cooldown_period > 0:
876
+ matching_providers = await app.state.channel_manager.get_available_providers(matching_providers)
877
  if not matching_providers:
878
+ raise HTTPException(status_code=503, detail="No available providers at the moment")
 
 
 
 
 
879
 
880
+ # 检查是否启用轮询
881
+ if scheduling_algorithm == "random":
882
  num_matching_providers = len(matching_providers)
883
+ matching_providers = random.sample(matching_providers, num_matching_providers)
884
 
885
+ weights = safe_get(config, 'api_keys', api_index, "weights")
886
 
887
+ if weights:
 
 
 
 
 
 
 
 
 
 
 
888
  intersection = None
889
+ all_providers = set(provider['provider'] + "/" + list(provider['model'][0].keys())[0] for provider in matching_providers)
890
+ if all_providers:
891
  weight_keys = set(weights.keys())
892
  provider_rules = []
893
  for model_rule in weight_keys:
 
902
  intersection = all_providers.intersection(weight_keys)
903
  # print("intersection", intersection)
904
 
905
+ if intersection:
906
  filtered_weights = {k.split("/")[0]: v for k, v in weights.items() if k in intersection}
907
  # print("filtered_weights", filtered_weights)
908
 
 
921
  new_matching_providers.append(provider)
922
  matching_providers = new_matching_providers
923
 
924
+ if is_debug:
925
+ for provider in matching_providers:
926
+ logger.info("available provider: %s", json.dumps(provider, indent=4, ensure_ascii=False, default=circular_list_encoder))
927
+
928
+ return matching_providers
929
+
930
+ import asyncio
931
+ class ModelRequestHandler:
932
+ def __init__(self):
933
+ self.last_provider_indices = defaultdict(lambda: -1)
934
+ self.locks = defaultdict(asyncio.Lock)
935
+
936
+ async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest], token: str, endpoint=None):
937
+ config = app.state.config
938
+ api_list = app.state.api_list
939
+ api_index = api_list.index(token)
940
+
941
+ if not safe_get(config, 'api_keys', api_index, 'model'):
942
+ raise HTTPException(status_code=404, detail="No matching model found")
943
+
944
+ request_model = request.model
945
+ scheduling_algorithm = safe_get(config, 'api_keys', api_index, "preferences", "SCHEDULING_ALGORITHM", default="fixed_priority")
946
+
947
+ matching_providers = await get_right_order_providers(request_model, config, api_index, scheduling_algorithm)
948
+ num_matching_providers = len(matching_providers)
949
 
950
  status_code = 500
951
  error_message = None
 
958
 
959
  auto_retry = safe_get(config, 'api_keys', api_index, "preferences", "AUTO_RETRY", default=True)
960
 
961
+ index = 0
962
+ while True:
963
+ if index >= num_matching_providers:
964
+ break
965
+ current_index = (start_index + index) % num_matching_providers
966
+ index += 1
967
  provider = matching_providers[current_index]
968
  try:
969
  response = await process_request(request, provider, endpoint, token)
 
993
  channel_id = f"{provider['provider']}"
994
  if app.state.channel_manager.cooldown_period > 0:
995
  await app.state.channel_manager.exclude_channel(channel_id)
996
+ matching_providers = await get_right_order_providers(request_model, config, api_index, scheduling_algorithm)
997
+ num_matching_providers = len(matching_providers)
998
+ index = 0
999
+
1000
  logger.error(f"Error {status_code} with provider {channel_id}: {error_message}")
1001
  if is_debug:
1002
  import traceback