yym68686 commited on
Commit
f2a60ff
·
1 Parent(s): f01693f

🐛 Bug: Fix the bug of incorrect weight allocation

Browse files
Files changed (2) hide show
  1. main.py +7 -7
  2. utils.py +2 -2
main.py CHANGED
@@ -191,11 +191,11 @@ class ChannelManager:
191
  for provider in providers:
192
  provider_name = provider['provider']
193
  model_dict = provider['model'][0] # 获取唯一的模型字典
194
- source_model = list(model_dict.keys())[0] # 源模型名称
195
- # target_model = list(model_dict.values())[0] # 目标模型名称
196
 
197
  # 检查该模型是否被排除
198
- if not await self.is_model_excluded(provider_name, source_model):
199
  available_providers.append(provider)
200
 
201
  return available_providers
@@ -894,14 +894,14 @@ async def get_right_order_providers(request_model, config, api_index, scheduling
894
 
895
  if weights:
896
  intersection = None
897
- all_providers = set(provider['provider'] + "/" + list(provider['model'][0].keys())[0] for provider in matching_providers)
898
  if all_providers:
899
  weight_keys = set(weights.keys())
900
  provider_rules = []
901
  for model_rule in weight_keys:
902
  provider_rules.extend(get_provider_rules(model_rule, config, request_model))
903
  provider_list = get_provider_list(provider_rules, config, request_model)
904
- weight_keys = set([provider['provider'] + "/" + list(provider['model'][0].keys())[0] for provider in provider_list])
905
  # print("all_providers", all_providers)
906
  # print("weights", weights)
907
  # print("weight_keys", weight_keys)
@@ -1001,8 +1001,8 @@ class ModelRequestHandler:
1001
  channel_id = f"{provider['provider']}"
1002
  if app.state.channel_manager.cooldown_period > 0:
1003
  # 获取源模型名称(实际配置的模型名)
1004
- source_model = list(provider['model'][0].keys())[0]
1005
- await app.state.channel_manager.exclude_model(channel_id, source_model)
1006
  matching_providers = await get_right_order_providers(request_model, config, api_index, scheduling_algorithm)
1007
  num_matching_providers = len(matching_providers)
1008
  index = 0
 
191
  for provider in providers:
192
  provider_name = provider['provider']
193
  model_dict = provider['model'][0] # 获取唯一的模型字典
194
+ # source_model = list(model_dict.keys())[0] # 源模型名称
195
+ target_model = list(model_dict.values())[0] # 目标模型名称
196
 
197
  # 检查该模型是否被排除
198
+ if not await self.is_model_excluded(provider_name, target_model):
199
  available_providers.append(provider)
200
 
201
  return available_providers
 
894
 
895
  if weights:
896
  intersection = None
897
+ all_providers = set(provider['provider'] + "/" + request_model for provider in matching_providers)
898
  if all_providers:
899
  weight_keys = set(weights.keys())
900
  provider_rules = []
901
  for model_rule in weight_keys:
902
  provider_rules.extend(get_provider_rules(model_rule, config, request_model))
903
  provider_list = get_provider_list(provider_rules, config, request_model)
904
+ weight_keys = set([provider['provider'] + "/" + request_model for provider in provider_list])
905
  # print("all_providers", all_providers)
906
  # print("weights", weights)
907
  # print("weight_keys", weight_keys)
 
1001
  channel_id = f"{provider['provider']}"
1002
  if app.state.channel_manager.cooldown_period > 0:
1003
  # 获取源模型名称(实际配置的模型名)
1004
+ # source_model = list(provider['model'][0].keys())[0]
1005
+ await app.state.channel_manager.exclude_model(channel_id, request_model)
1006
  matching_providers = await get_right_order_providers(request_model, config, api_index, scheduling_algorithm)
1007
  num_matching_providers = len(matching_providers)
1008
  index = 0
utils.py CHANGED
@@ -117,9 +117,9 @@ def update_config(config_data, use_config_url=False):
117
  continue
118
  model_dict = get_model_dict(provider_item)
119
  if model_name in model_dict.keys():
120
- weights_dict.update({provider_name + "/" + model_dict[model_name]: int(value)})
121
  elif model_name == "*":
122
- weights_dict.update({provider_name + "/" + model_dict[model_item]: int(value) for model_item in model_dict.keys()})
123
 
124
  models.append(key)
125
  if isinstance(model, str):
 
117
  continue
118
  model_dict = get_model_dict(provider_item)
119
  if model_name in model_dict.keys():
120
+ weights_dict.update({provider_name + "/" + model_name: int(value)})
121
  elif model_name == "*":
122
+ weights_dict.update({provider_name + "/" + model_name: int(value) for model_item in model_dict.keys()})
123
 
124
  models.append(key)
125
  if isinstance(model, str):