✨ Feature: Add feature: Support model-level cooling. When a model under a channel reports an error, it does not affect other models under the same channel, only the model that reported the error is cooled.
Browse files
main.py
CHANGED
@@ -161,35 +161,43 @@ async def parse_request_body(request: Request):
|
|
161 |
|
162 |
class ChannelManager:
|
163 |
def __init__(self, cooldown_period: int = 300): # 默认冷却时间5分钟
|
164 |
-
self.
|
165 |
self._lock = asyncio.Lock()
|
166 |
self.cooldown_period = cooldown_period
|
167 |
|
168 |
-
async def
|
169 |
-
"""
|
170 |
async with self._lock:
|
171 |
-
|
|
|
172 |
|
173 |
-
async def
|
174 |
-
"""
|
175 |
async with self._lock:
|
176 |
-
|
|
|
177 |
return False
|
178 |
|
179 |
-
excluded_time = self.
|
180 |
if datetime.now() - excluded_time > timedelta(seconds=self.cooldown_period):
|
181 |
# 已超过冷却时间,移除限制
|
182 |
-
del self.
|
183 |
return False
|
184 |
return True
|
185 |
|
186 |
async def get_available_providers(self, providers: list) -> list:
|
187 |
-
"""过滤出可用的providers"""
|
188 |
available_providers = []
|
189 |
for provider in providers:
|
190 |
-
|
191 |
-
|
|
|
|
|
|
|
|
|
|
|
192 |
available_providers.append(provider)
|
|
|
193 |
return available_providers
|
194 |
|
195 |
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
@@ -992,7 +1000,9 @@ class ModelRequestHandler:
|
|
992 |
|
993 |
channel_id = f"{provider['provider']}"
|
994 |
if app.state.channel_manager.cooldown_period > 0:
|
995 |
-
|
|
|
|
|
996 |
matching_providers = await get_right_order_providers(request_model, config, api_index, scheduling_algorithm)
|
997 |
num_matching_providers = len(matching_providers)
|
998 |
index = 0
|
|
|
161 |
|
162 |
class ChannelManager:
|
163 |
def __init__(self, cooldown_period: int = 300): # 默认冷却时间5分钟
|
164 |
+
self._excluded_models: Dict[str, datetime] = {}
|
165 |
self._lock = asyncio.Lock()
|
166 |
self.cooldown_period = cooldown_period
|
167 |
|
168 |
+
async def exclude_model(self, provider: str, model: str):
|
169 |
+
"""将特定渠道下的特定模型添加到排除列表"""
|
170 |
async with self._lock:
|
171 |
+
model_key = f"{provider}/{model}"
|
172 |
+
self._excluded_models[model_key] = datetime.now()
|
173 |
|
174 |
+
async def is_model_excluded(self, provider: str, model: str) -> bool:
|
175 |
+
"""检查特定渠道下的特定模型是否被排除"""
|
176 |
async with self._lock:
|
177 |
+
model_key = f"{provider}/{model}"
|
178 |
+
if model_key not in self._excluded_models:
|
179 |
return False
|
180 |
|
181 |
+
excluded_time = self._excluded_models[model_key]
|
182 |
if datetime.now() - excluded_time > timedelta(seconds=self.cooldown_period):
|
183 |
# 已超过冷却时间,移除限制
|
184 |
+
del self._excluded_models[model_key]
|
185 |
return False
|
186 |
return True
|
187 |
|
188 |
async def get_available_providers(self, providers: list) -> list:
|
189 |
+
"""过滤出可用的providers,仅排除不可用的模型"""
|
190 |
available_providers = []
|
191 |
for provider in providers:
|
192 |
+
provider_name = provider['provider']
|
193 |
+
model_dict = provider['model'][0] # 获取唯一的模型字典
|
194 |
+
source_model = list(model_dict.keys())[0] # 源模型名称
|
195 |
+
# target_model = list(model_dict.values())[0] # 目标模型名称
|
196 |
+
|
197 |
+
# 检查该模型是否被排除
|
198 |
+
if not await self.is_model_excluded(provider_name, source_model):
|
199 |
available_providers.append(provider)
|
200 |
+
|
201 |
return available_providers
|
202 |
|
203 |
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
|
|
1000 |
|
1001 |
channel_id = f"{provider['provider']}"
|
1002 |
if app.state.channel_manager.cooldown_period > 0:
|
1003 |
+
# 获取源模型名称(实际配置的模型名)
|
1004 |
+
source_model = list(provider['model'][0].keys())[0]
|
1005 |
+
await app.state.channel_manager.exclude_model(channel_id, source_model)
|
1006 |
matching_providers = await get_right_order_providers(request_model, config, api_index, scheduling_algorithm)
|
1007 |
num_matching_providers = len(matching_providers)
|
1008 |
index = 0
|