dray
commited on
Commit
·
3e27f73
1
Parent(s):
331ab5a
Refactor: Add support for retrieving all models when the model name is "*"
Browse files
main.py
CHANGED
@@ -419,6 +419,12 @@ class ModelRequestHandler:
|
|
419 |
provider_rules = []
|
420 |
|
421 |
for model in config['api_keys'][api_index]['model']:
|
|
|
|
|
|
|
|
|
|
|
|
|
422 |
if "/" in model:
|
423 |
if model.startswith("<") and model.endswith(">"):
|
424 |
model = model[1:-1]
|
|
|
419 |
provider_rules = []
|
420 |
|
421 |
for model in config['api_keys'][api_index]['model']:
|
422 |
+
if model == "*":
|
423 |
+
# 如果模型名为 *,则返回所有模型
|
424 |
+
for provider in config["providers"]:
|
425 |
+
for model in provider["model"].keys():
|
426 |
+
provider_rules.append(provider["provider"] + "/" + model)
|
427 |
+
break
|
428 |
if "/" in model:
|
429 |
if model.startswith("<") and model.endswith(">"):
|
430 |
model = model[1:-1]
|
utils.py
CHANGED
@@ -62,7 +62,7 @@ async def load_config(app=None):
|
|
62 |
# is_quoted = not token.plain
|
63 |
# print(f"值: {value}, 是否被引号包裹: {is_quoted}")
|
64 |
|
65 |
-
with open(
|
66 |
# 判断是否为空文件
|
67 |
conf = yaml.safe_load(f)
|
68 |
# conf = None
|
@@ -170,6 +170,10 @@ def post_all_models(token, config, api_list):
|
|
170 |
api_index = api_list.index(token)
|
171 |
if config['api_keys'][api_index]['model']:
|
172 |
for model in config['api_keys'][api_index]['model']:
|
|
|
|
|
|
|
|
|
173 |
if "/" in model:
|
174 |
provider = model.split("/")[0]
|
175 |
model = model.split("/")[1]
|
|
|
62 |
# is_quoted = not token.plain
|
63 |
# print(f"值: {value}, 是否被引号包裹: {is_quoted}")
|
64 |
|
65 |
+
with open("./api.yaml", "r", encoding="utf-8") as f:
|
66 |
# 判断是否为空文件
|
67 |
conf = yaml.safe_load(f)
|
68 |
# conf = None
|
|
|
170 |
api_index = api_list.index(token)
|
171 |
if config['api_keys'][api_index]['model']:
|
172 |
for model in config['api_keys'][api_index]['model']:
|
173 |
+
if model == "*":
|
174 |
+
# 如果模型名为 *,则返回所有模型
|
175 |
+
all_models = get_all_models(config)
|
176 |
+
return all_models
|
177 |
if "/" in model:
|
178 |
provider = model.split("/")[0]
|
179 |
model = model.split("/")[1]
|