dray commited on
Commit
3e27f73
·
1 Parent(s): 331ab5a

Refactor: Add support for retrieving all models when the model name is "*"

Browse files
Files changed (2) hide show
  1. main.py +6 -0
  2. utils.py +5 -1
main.py CHANGED
@@ -419,6 +419,12 @@ class ModelRequestHandler:
419
  provider_rules = []
420
 
421
  for model in config['api_keys'][api_index]['model']:
 
 
 
 
 
 
422
  if "/" in model:
423
  if model.startswith("<") and model.endswith(">"):
424
  model = model[1:-1]
 
419
  provider_rules = []
420
 
421
  for model in config['api_keys'][api_index]['model']:
422
+ if model == "*":
423
+ # 如果模型名为 *,则返回所有模型
424
+ for provider in config["providers"]:
425
+ for model in provider["model"].keys():
426
+ provider_rules.append(provider["provider"] + "/" + model)
427
+ break
428
  if "/" in model:
429
  if model.startswith("<") and model.endswith(">"):
430
  model = model[1:-1]
utils.py CHANGED
@@ -62,7 +62,7 @@ async def load_config(app=None):
62
  # is_quoted = not token.plain
63
  # print(f"值: {value}, 是否被引号包裹: {is_quoted}")
64
 
65
- with open('./api.yaml', 'r') as f:
66
  # 判断是否为空文件
67
  conf = yaml.safe_load(f)
68
  # conf = None
@@ -170,6 +170,10 @@ def post_all_models(token, config, api_list):
170
  api_index = api_list.index(token)
171
  if config['api_keys'][api_index]['model']:
172
  for model in config['api_keys'][api_index]['model']:
 
 
 
 
173
  if "/" in model:
174
  provider = model.split("/")[0]
175
  model = model.split("/")[1]
 
62
  # is_quoted = not token.plain
63
  # print(f"值: {value}, 是否被引号包裹: {is_quoted}")
64
 
65
+ with open("./api.yaml", "r", encoding="utf-8") as f:
66
  # 判断是否为空文件
67
  conf = yaml.safe_load(f)
68
  # conf = None
 
170
  api_index = api_list.index(token)
171
  if config['api_keys'][api_index]['model']:
172
  for model in config['api_keys'][api_index]['model']:
173
+ if model == "*":
174
+ # 如果模型名为 *,则返回所有模型
175
+ all_models = get_all_models(config)
176
+ return all_models
177
  if "/" in model:
178
  provider = model.split("/")[0]
179
  model = model.split("/")[1]