yym68686 commited on
Commit
f2794a6
·
1 Parent(s): 656d0b5

✨ Feature: Add feature: Support model-level cooling. When a model under a channel reports an error, it does not affect other models under the same channel, only the model that reported the error is cooled.

Browse files
Files changed (1) hide show
  1. main.py +23 -13
main.py CHANGED
@@ -161,35 +161,43 @@ async def parse_request_body(request: Request):
161
 
162
  class ChannelManager:
163
  def __init__(self, cooldown_period: int = 300): # 默认冷却时间5分钟
164
- self._excluded_channels: Dict[str, datetime] = {}
165
  self._lock = asyncio.Lock()
166
  self.cooldown_period = cooldown_period
167
 
168
- async def exclude_channel(self, channel_id: str):
169
- """将渠道添加到排除列表"""
170
  async with self._lock:
171
- self._excluded_channels[channel_id] = datetime.now()
 
172
 
173
- async def is_channel_excluded(self, channel_id: str) -> bool:
174
- """检查渠道是否被排除"""
175
  async with self._lock:
176
- if channel_id not in self._excluded_channels:
 
177
  return False
178
 
179
- excluded_time = self._excluded_channels[channel_id]
180
  if datetime.now() - excluded_time > timedelta(seconds=self.cooldown_period):
181
  # 已超过冷却时间,移除限制
182
- del self._excluded_channels[channel_id]
183
  return False
184
  return True
185
 
186
  async def get_available_providers(self, providers: list) -> list:
187
- """过滤出可用的providers"""
188
  available_providers = []
189
  for provider in providers:
190
- channel_id = f"{provider['provider']}"
191
- if not await self.is_channel_excluded(channel_id):
 
 
 
 
 
192
  available_providers.append(provider)
 
193
  return available_providers
194
 
195
  from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
@@ -992,7 +1000,9 @@ class ModelRequestHandler:
992
 
993
  channel_id = f"{provider['provider']}"
994
  if app.state.channel_manager.cooldown_period > 0:
995
- await app.state.channel_manager.exclude_channel(channel_id)
 
 
996
  matching_providers = await get_right_order_providers(request_model, config, api_index, scheduling_algorithm)
997
  num_matching_providers = len(matching_providers)
998
  index = 0
 
161
 
162
  class ChannelManager:
163
  def __init__(self, cooldown_period: int = 300): # 默认冷却时间5分钟
164
+ self._excluded_models: Dict[str, datetime] = {}
165
  self._lock = asyncio.Lock()
166
  self.cooldown_period = cooldown_period
167
 
168
+ async def exclude_model(self, provider: str, model: str):
169
+ """将特定渠道下的特定模型添加到排除列表"""
170
  async with self._lock:
171
+ model_key = f"{provider}/{model}"
172
+ self._excluded_models[model_key] = datetime.now()
173
 
174
+ async def is_model_excluded(self, provider: str, model: str) -> bool:
175
+ """检查特定渠道下的特定模型是否被排除"""
176
  async with self._lock:
177
+ model_key = f"{provider}/{model}"
178
+ if model_key not in self._excluded_models:
179
  return False
180
 
181
+ excluded_time = self._excluded_models[model_key]
182
  if datetime.now() - excluded_time > timedelta(seconds=self.cooldown_period):
183
  # 已超过冷却时间,移除限制
184
+ del self._excluded_models[model_key]
185
  return False
186
  return True
187
 
188
  async def get_available_providers(self, providers: list) -> list:
189
+ """过滤出可用的providers,仅排除不可用的模型"""
190
  available_providers = []
191
  for provider in providers:
192
+ provider_name = provider['provider']
193
+ model_dict = provider['model'][0] # 获取唯一的模型字典
194
+ source_model = list(model_dict.keys())[0] # 源模型名称
195
+ # target_model = list(model_dict.values())[0] # 目标模型名称
196
+
197
+ # 检查该模型是否被排除
198
+ if not await self.is_model_excluded(provider_name, source_model):
199
  available_providers.append(provider)
200
+
201
  return available_providers
202
 
203
  from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
 
1000
 
1001
  channel_id = f"{provider['provider']}"
1002
  if app.state.channel_manager.cooldown_period > 0:
1003
+ # 获取源模型名称(实际配置的模型名)
1004
+ source_model = list(provider['model'][0].keys())[0]
1005
+ await app.state.channel_manager.exclude_model(channel_id, source_model)
1006
  matching_providers = await get_right_order_providers(request_model, config, api_index, scheduling_algorithm)
1007
  num_matching_providers = len(matching_providers)
1008
  index = 0