🐛 Bug: Fixed the bug where the channel cooldown could not take effect within the current request.
Browse files
main.py
CHANGED
@@ -866,48 +866,28 @@ def get_matching_providers(request_model, config, api_index):
|
|
866 |
# print("provider_list", provider_list)
|
867 |
return provider_list
|
868 |
|
869 |
-
|
870 |
-
|
871 |
-
def __init__(self):
|
872 |
-
self.last_provider_indices = defaultdict(lambda: -1)
|
873 |
-
self.locks = defaultdict(asyncio.Lock)
|
874 |
-
|
875 |
-
async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest], token: str, endpoint=None):
|
876 |
-
config = app.state.config
|
877 |
-
api_list = app.state.api_list
|
878 |
-
api_index = api_list.index(token)
|
879 |
|
880 |
-
|
881 |
-
|
882 |
-
|
883 |
-
request_model = request.model
|
884 |
-
matching_providers = get_matching_providers(request_model, config, api_index)
|
885 |
|
|
|
|
|
886 |
if not matching_providers:
|
887 |
-
raise HTTPException(status_code=
|
888 |
-
|
889 |
-
if app.state.channel_manager.cooldown_period > 0:
|
890 |
-
matching_providers = await app.state.channel_manager.get_available_providers(matching_providers)
|
891 |
-
if not matching_providers:
|
892 |
-
raise HTTPException(status_code=503, detail="No available providers at the moment")
|
893 |
|
|
|
|
|
894 |
num_matching_providers = len(matching_providers)
|
|
|
895 |
|
|
|
896 |
|
897 |
-
|
898 |
-
scheduling_algorithm = safe_get(config, 'api_keys', api_index, "preferences", "SCHEDULING_ALGORITHM", default="fixed_priority")
|
899 |
-
if scheduling_algorithm == "random":
|
900 |
-
matching_providers = random.sample(matching_providers, num_matching_providers)
|
901 |
-
|
902 |
-
weights = safe_get(config, 'api_keys', api_index, "weights")
|
903 |
-
|
904 |
-
# 步骤 1: 提取 matching_providers 中的所有 provider 值
|
905 |
-
# print("matching_providers", matching_providers)
|
906 |
-
# print(type(matching_providers[0]['model'][0].keys()), list(matching_providers[0]['model'][0].keys())[0], matching_providers[0]['model'][0].keys())
|
907 |
-
all_providers = set(provider['provider'] + "/" + list(provider['model'][0].keys())[0] for provider in matching_providers)
|
908 |
-
|
909 |
intersection = None
|
910 |
-
|
|
|
911 |
weight_keys = set(weights.keys())
|
912 |
provider_rules = []
|
913 |
for model_rule in weight_keys:
|
@@ -922,7 +902,7 @@ class ModelRequestHandler:
|
|
922 |
intersection = all_providers.intersection(weight_keys)
|
923 |
# print("intersection", intersection)
|
924 |
|
925 |
-
if
|
926 |
filtered_weights = {k.split("/")[0]: v for k, v in weights.items() if k in intersection}
|
927 |
# print("filtered_weights", filtered_weights)
|
928 |
|
@@ -941,9 +921,31 @@ class ModelRequestHandler:
|
|
941 |
new_matching_providers.append(provider)
|
942 |
matching_providers = new_matching_providers
|
943 |
|
944 |
-
|
945 |
-
|
946 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
947 |
|
948 |
status_code = 500
|
949 |
error_message = None
|
@@ -956,8 +958,12 @@ class ModelRequestHandler:
|
|
956 |
|
957 |
auto_retry = safe_get(config, 'api_keys', api_index, "preferences", "AUTO_RETRY", default=True)
|
958 |
|
959 |
-
|
960 |
-
|
|
|
|
|
|
|
|
|
961 |
provider = matching_providers[current_index]
|
962 |
try:
|
963 |
response = await process_request(request, provider, endpoint, token)
|
@@ -987,6 +993,10 @@ class ModelRequestHandler:
|
|
987 |
channel_id = f"{provider['provider']}"
|
988 |
if app.state.channel_manager.cooldown_period > 0:
|
989 |
await app.state.channel_manager.exclude_channel(channel_id)
|
|
|
|
|
|
|
|
|
990 |
logger.error(f"Error {status_code} with provider {channel_id}: {error_message}")
|
991 |
if is_debug:
|
992 |
import traceback
|
|
|
866 |
# print("provider_list", provider_list)
|
867 |
return provider_list
|
868 |
|
869 |
+
async def get_right_order_providers(request_model, config, api_index, scheduling_algorithm):
|
870 |
+
matching_providers = get_matching_providers(request_model, config, api_index)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
871 |
|
872 |
+
if not matching_providers:
|
873 |
+
raise HTTPException(status_code=404, detail="No matching model found")
|
|
|
|
|
|
|
874 |
|
875 |
+
if app.state.channel_manager.cooldown_period > 0:
|
876 |
+
matching_providers = await app.state.channel_manager.get_available_providers(matching_providers)
|
877 |
if not matching_providers:
|
878 |
+
raise HTTPException(status_code=503, detail="No available providers at the moment")
|
|
|
|
|
|
|
|
|
|
|
879 |
|
880 |
+
# 检查是否启用轮询
|
881 |
+
if scheduling_algorithm == "random":
|
882 |
num_matching_providers = len(matching_providers)
|
883 |
+
matching_providers = random.sample(matching_providers, num_matching_providers)
|
884 |
|
885 |
+
weights = safe_get(config, 'api_keys', api_index, "weights")
|
886 |
|
887 |
+
if weights:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
888 |
intersection = None
|
889 |
+
all_providers = set(provider['provider'] + "/" + list(provider['model'][0].keys())[0] for provider in matching_providers)
|
890 |
+
if all_providers:
|
891 |
weight_keys = set(weights.keys())
|
892 |
provider_rules = []
|
893 |
for model_rule in weight_keys:
|
|
|
902 |
intersection = all_providers.intersection(weight_keys)
|
903 |
# print("intersection", intersection)
|
904 |
|
905 |
+
if intersection:
|
906 |
filtered_weights = {k.split("/")[0]: v for k, v in weights.items() if k in intersection}
|
907 |
# print("filtered_weights", filtered_weights)
|
908 |
|
|
|
921 |
new_matching_providers.append(provider)
|
922 |
matching_providers = new_matching_providers
|
923 |
|
924 |
+
if is_debug:
|
925 |
+
for provider in matching_providers:
|
926 |
+
logger.info("available provider: %s", json.dumps(provider, indent=4, ensure_ascii=False, default=circular_list_encoder))
|
927 |
+
|
928 |
+
return matching_providers
|
929 |
+
|
930 |
+
import asyncio
|
931 |
+
class ModelRequestHandler:
|
932 |
+
def __init__(self):
|
933 |
+
self.last_provider_indices = defaultdict(lambda: -1)
|
934 |
+
self.locks = defaultdict(asyncio.Lock)
|
935 |
+
|
936 |
+
async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest], token: str, endpoint=None):
|
937 |
+
config = app.state.config
|
938 |
+
api_list = app.state.api_list
|
939 |
+
api_index = api_list.index(token)
|
940 |
+
|
941 |
+
if not safe_get(config, 'api_keys', api_index, 'model'):
|
942 |
+
raise HTTPException(status_code=404, detail="No matching model found")
|
943 |
+
|
944 |
+
request_model = request.model
|
945 |
+
scheduling_algorithm = safe_get(config, 'api_keys', api_index, "preferences", "SCHEDULING_ALGORITHM", default="fixed_priority")
|
946 |
+
|
947 |
+
matching_providers = await get_right_order_providers(request_model, config, api_index, scheduling_algorithm)
|
948 |
+
num_matching_providers = len(matching_providers)
|
949 |
|
950 |
status_code = 500
|
951 |
error_message = None
|
|
|
958 |
|
959 |
auto_retry = safe_get(config, 'api_keys', api_index, "preferences", "AUTO_RETRY", default=True)
|
960 |
|
961 |
+
index = 0
|
962 |
+
while True:
|
963 |
+
if index >= num_matching_providers:
|
964 |
+
break
|
965 |
+
current_index = (start_index + index) % num_matching_providers
|
966 |
+
index += 1
|
967 |
provider = matching_providers[current_index]
|
968 |
try:
|
969 |
response = await process_request(request, provider, endpoint, token)
|
|
|
993 |
channel_id = f"{provider['provider']}"
|
994 |
if app.state.channel_manager.cooldown_period > 0:
|
995 |
await app.state.channel_manager.exclude_channel(channel_id)
|
996 |
+
matching_providers = await get_right_order_providers(request_model, config, api_index, scheduling_algorithm)
|
997 |
+
num_matching_providers = len(matching_providers)
|
998 |
+
index = 0
|
999 |
+
|
1000 |
logger.error(f"Error {status_code} with provider {channel_id}: {error_message}")
|
1001 |
if is_debug:
|
1002 |
import traceback
|