import yaml |
import json |
from fastapi import HTTPException, Depends |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials |
def load_config():
    """Load ``./api.yaml`` and normalise every provider's model list.

    Each element of ``conf['providers'][i]['model']`` may be a plain string
    or a mapping ``{old: new}``. The list is replaced in place by a single
    dict built as follows:

    * a string ``name`` contributes ``name -> name``;
    * a mapping ``{old: new}`` contributes ``new -> old`` and ``old -> old``
      (presumably ``new`` is an alias for the provider's own model name;
      confirm against the code that consumes this dict).

    Returns:
        A 3-tuple ``(conf, api_keys_db, api_list)``: the parsed document,
        ``conf['api_keys']``, and the list of each key entry's ``"api"``
        value. If the file does not exist or is not valid YAML, a message
        is printed and ``([], [], [])`` is returned, so ``conf`` is a list
        rather than a dict in that case.
    """
    # Keep the try body down to the I/O and parse steps it actually guards.
    try:
        # Read as UTF-8 explicitly instead of relying on the platform's
        # locale encoding, which would garble non-ASCII content on some hosts.
        with open('./api.yaml', 'r', encoding='utf-8') as f:
            conf = yaml.safe_load(f)
    except FileNotFoundError:
        print("配置文件 'api.yaml' 未找到。请确保文件存在于正确的位置。")
        return [], [], []
    except yaml.YAMLError:
        print("配置文件 'api.yaml' 格式不正确。请检查 YAML 格式。")
        return [], [], []

    for provider in conf['providers']:
        model_dict = {}
        for model in provider['model']:
            if isinstance(model, str):
                model_dict[model] = model
            elif isinstance(model, dict):
                # New-name entries first, then identity entries, so that an
                # identity entry wins if a new name collides with an old one.
                model_dict.update({new: old for old, new in model.items()})
                model_dict.update({old: old for old in model})
        # ``provider`` is the same object stored in conf['providers'], so
        # this assignment updates the configuration in place.
        provider['model'] = model_dict

    api_keys_db = conf['api_keys']
    api_list = [item["api"] for item in api_keys_db]
    return conf, api_keys_db, api_list
# Load the configuration once, at import time, and keep the three results as
# module-level globals read by the functions in this module.
# NOTE(review): if loading fails, all three names are bound to empty lists.
config, api_keys_db, api_list = load_config()
async def error_handling_wrapper(generator, status_code=200):
    """Peek at the first item of an async stream and fail fast on an error.

    The first item is pulled from ``generator`` and inspected. Bytes are
    decoded as UTF-8, a leading ``data:`` / ``data: `` prefix is removed from
    strings, and the remainder is parsed as JSON. If the result is a dict
    containing an ``'error'`` key, it is printed and raised as an
    ``HTTPException`` instead of being streamed.

    Args:
        generator: An async generator yielding ``str``, ``bytes`` or
            already-parsed items.
        status_code: Status code used for the raised ``HTTPException``.

    Returns:
        A new async generator that yields the (unmodified) first item
        followed by the rest of ``generator``; or an empty list if
        ``generator`` yields nothing.

    Raises:
        HTTPException: If the first item is an error object, with the parsed
            object as ``detail``.
    """
    try:
        first_item = await generator.__anext__()
    except StopAsyncIteration:
        # The source stream produced no items at all.
        return []

    payload = first_item
    if isinstance(payload, (bytes, bytearray)):
        payload = payload.decode("utf-8")
    if isinstance(payload, str):
        if payload.startswith("data: "):
            payload = payload[6:]
        elif payload.startswith("data:"):
            payload = payload[5:]
        try:
            payload = json.loads(payload)
        except json.JSONDecodeError:
            # A first chunk that is not JSON (for example "[DONE]" or plain
            # text) cannot be an error object, so let the stream through
            # rather than crashing on the parse.
            payload = None

    if isinstance(payload, dict) and 'error' in payload:
        # Surround the dump with ANSI red / reset so it stands out in logs.
        print('\033[31m')
        print(f"first_item_str: {payload}")
        print('\033[0m')
        raise HTTPException(status_code=status_code, detail=payload)

    async def new_generator():
        # Replay the item consumed above, then forward everything else.
        yield first_item
        async for item in generator:
            yield item

    return new_generator()
def _model_entry(model_id, owner):
    """Build one model-list entry for ``model_id`` attributed to ``owner``."""
    return {
        "id": model_id,
        "object": "model",
        # Hard-coded constant carried over unchanged for every entry.
        "created": 1720524448858,
        "owned_by": owner,
    }


def get_all_models(token):
    """Return the de-duplicated list of model entries visible to ``token``.

    The key's configuration entry (``config['api_keys'][i]``) may carry a
    ``model`` list of rules:

    * ``"provider/*"``: every model of the provider with that name;
    * ``"provider/name"``: that one model of that provider, if it exists.
      Only the first ``/`` separates provider from model, so model names
      that themselves contain ``/`` are matched in full;
    * ``"name"`` (no slash): listed as-is, with ``owned_by`` set to the
      name itself.

    If the entry has no ``model`` field, or it is empty, every model of every
    provider is listed. Each model id appears at most once; the first
    occurrence wins.

    Args:
        token: The API key presented by the client.

    Returns:
        A list of dicts with keys ``id``, ``object``, ``created`` and
        ``owned_by``.

    Raises:
        HTTPException: 403 if ``token`` is not in ``api_list``.
    """
    if token not in api_list:
        raise HTTPException(status_code=403, detail="Invalid or missing API Key")
    key_entry = config['api_keys'][api_list.index(token)]

    all_models = []
    seen_ids = set()

    def add(model_id, owner):
        # Record a model once; later duplicates are ignored.
        if model_id not in seen_ids:
            seen_ids.add(model_id)
            all_models.append(_model_entry(model_id, owner))

    # A missing 'model' field is treated like an empty one: no restriction.
    rules = key_entry.get('model')
    if not rules:
        for provider_item in config["providers"]:
            for model_id in provider_item['model']:
                add(model_id, provider_item['provider'])
        return all_models

    for rule in rules:
        if "/" not in rule:
            add(rule, rule)
            continue

        # Split on the first "/" only, keeping any further slashes in the
        # model name.
        provider_name, _, wanted = rule.partition("/")
        for provider_item in config["providers"]:
            if provider_item['provider'] != provider_name:
                continue
            for model_id in provider_item['model']:
                if wanted == "*" or model_id == wanted:
                    add(model_id, provider_item['provider'])

    return all_models
# HTTP Bearer scheme instance; verify_api_key receives its result via Depends().
security = HTTPBearer()
def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Check the bearer token of a request against the configured API keys.

    Args:
        credentials: Credentials object supplied through ``Depends(security)``;
            its ``credentials`` attribute is the presented token.

    Returns:
        The token, when it is one of the entries in ``api_list``.

    Raises:
        HTTPException: 403 when the token is not in ``api_list``.
    """
    presented = credentials.credentials
    if presented in api_list:
        return presented
    raise HTTPException(status_code=403, detail="Invalid or missing API Key")