yym68686 commited on
Commit
3ec7a0b
·
1 Parent(s): 73a667f

✨ Feature: Add feature: Add support for weighted load balancing.

Browse files
Files changed (6) hide show
  1. .gitignore +2 -1
  2. README.md +12 -1
  3. main.py +47 -6
  4. test/test_matplotlib.py +49 -0
  5. test/test_weights.py +33 -0
  6. utils.py +40 -2
.gitignore CHANGED
@@ -7,4 +7,5 @@ node_modules
7
  .wrangler
8
  .pytest_cache
9
  *.jpg
10
- *.json
 
 
7
  .wrangler
8
  .pytest_cache
9
  *.jpg
10
+ *.json
11
+ *.png
README.md CHANGED
@@ -21,7 +21,7 @@
21
  - 同时支持 Anthropic、Gemini、Vertex API。Vertex 同时支持 Claude 和 Gemini API。
22
  - 支持 OpenAI、 Anthropic、Gemini、Vertex 原生 tool use 函数调用。
23
  - 支持 OpenAI、Anthropic、Gemini、Vertex 原生识图 API。
24
- - 支持三种负载均衡,默认同时开启。1. 支持单个渠道多个 API Key 自动开启 API key 级别的轮训负载均衡。2. 支持 Vertex 区域级负载均衡,支持 Vertex 高并发,最高可将 Gemini,Claude 并发提高 (API数量 * 区域数量) 倍。3. 除了 Vertex 区域级负载均衡,所有 API 均支持渠道级负载均衡,提高沉浸式翻译体验。
25
  - 支持自动重试,当一个 API 渠道响应失败时,自动重试下一个 API 渠道。
26
  - 支持细粒度的权限控制。支持使用通配符设置 API key 可用渠道的特定模型。
27
 
@@ -93,6 +93,17 @@ api_keys:
93
  preferences:
94
  USE_ROUND_ROBIN: true # 是否使用轮询负载均衡,true 为使用,false 为不使用,默认为 true。开启轮训后每次请求模型按照 model 配置的顺序依次请求。与 providers 里面原始的渠道顺序无关。因此你可以设置每个 API key 请求顺序不一样。
95
  AUTO_RETRY: true # 是否自动重试,自动重试下一个提供商,true 为自动重试,false 为不自动重试,默认为 true
 
 
 
 
 
 
 
 
 
 
 
96
  ```
97
 
98
  ## 环境变量
 
21
  - 同时支持 Anthropic、Gemini、Vertex API。Vertex 同时支持 Claude 和 Gemini API。
22
  - 支持 OpenAI、 Anthropic、Gemini、Vertex 原生 tool use 函数调用。
23
  - 支持 OpenAI、Anthropic、Gemini、Vertex 原生识图 API。
24
+ - 支持四种负载均衡。1. 支持渠道级加权负载均衡,可以根据不同的渠道权重分配请求。默认不开启,需要配置渠道权重。2. 支持 Vertex 区域级负载均衡,支持 Vertex 高并发,最高可将 Gemini,Claude 并发提高 (API数量 * 区域数量) 倍。自动开启不需要额外配置。3. 除了 Vertex 区域级负载均衡,所有 API 均支持渠道级顺序负载均衡,提高沉浸式翻译体验。自动开启不需要额外配置。4. 支持单个渠道多个 API Key 自动开启 API key 级别的轮训负载均衡。
25
  - 支持自动重试,当一个 API 渠道响应失败时,自动重试下一个 API 渠道。
26
  - 支持细粒度的权限控制。支持使用通配符设置 API key 可用渠道的特定模型。
27
 
 
93
  preferences:
94
  USE_ROUND_ROBIN: true # 是否使用轮询负载均衡,true 为使用,false 为不使用,默认为 true。开启轮训后每次请求模型按照 model 配置的顺序依次请求。与 providers 里面原始的渠道顺序无关。因此你可以设置每个 API key 请求顺序不一样。
95
  AUTO_RETRY: true # 是否自动重试,自动重试下一个提供商,true 为自动重试,false 为不自动重试,默认为 true
96
+
97
+ # 渠道级加权负载均衡配置示例
98
+ - api: sk-KjjI60Yf0JFWtxxxxxxxxxxxxxxwmRWpWpQRo
99
+ model:
100
+ - gcp1/*: 5 # 冒号后面就是权重,权重仅支持正整数。
101
+ - gcp2/*: 3 # 数字的大小代表权重,数字越大,请求的概率越大。
102
+ - gcp3/*: 2 # 在该示例中,所有渠道加起来一共有 10 个权重,及 10 个请求里面有 5 个请求会请求 gcp1/* 模型,2 个请求会请求 gcp2/* 模型,3 个请求会请求 gcp3/* 模型。
103
+
104
+ preferences:
105
+ USE_ROUND_ROBIN: true # 当 USE_ROUND_ROBIN 必须为 true 并且上面的渠道后面没有权重时,会按照原始的渠道顺序请求,如果有权重,会按照加权后的顺序请求。
106
+ AUTO_RETRY: true
107
  ```
108
 
109
  ## 环境变量
main.py CHANGED
@@ -12,7 +12,7 @@ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
12
  from models import RequestModel, ImageGenerationRequest
13
  from request import get_payload
14
  from response import fetch_response, fetch_response_stream
15
- from utils import error_handling_wrapper, post_all_models, load_config
16
 
17
  from typing import List, Dict, Union
18
  from urllib.parse import urlparse
@@ -224,6 +224,29 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest],
224
 
225
  raise e
226
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
  import asyncio
228
  class ModelRequestHandler:
229
  def __init__(self):
@@ -297,13 +320,31 @@ class ModelRequestHandler:
297
 
298
  # 检查是否启用轮询
299
  api_index = api_list.index(token)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300
  use_round_robin = True
301
  auto_retry = True
302
- if config['api_keys'][api_index].get("preferences"):
303
- if config['api_keys'][api_index]["preferences"].get("USE_ROUND_ROBIN") == False:
304
- use_round_robin = False
305
- if config['api_keys'][api_index]["preferences"].get("AUTO_RETRY") == False:
306
- auto_retry = False
307
 
308
  return await self.try_all_providers(request, matching_providers, use_round_robin, auto_retry, endpoint)
309
 
 
12
  from models import RequestModel, ImageGenerationRequest
13
  from request import get_payload
14
  from response import fetch_response, fetch_response_stream
15
+ from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder
16
 
17
  from typing import List, Dict, Union
18
  from urllib.parse import urlparse
 
224
 
225
  raise e
226
 
227
+ def weighted_round_robin(weights):
228
+ provider_names = list(weights.keys())
229
+ current_weights = {name: 0 for name in provider_names}
230
+ num_selections = total_weight = sum(weights.values())
231
+ weighted_provider_list = []
232
+
233
+ for _ in range(num_selections):
234
+ max_ratio = -1
235
+ selected_letter = None
236
+
237
+ for name in provider_names:
238
+ current_weights[name] += weights[name]
239
+ ratio = current_weights[name] / weights[name]
240
+
241
+ if ratio > max_ratio:
242
+ max_ratio = ratio
243
+ selected_letter = name
244
+
245
+ weighted_provider_list.append(selected_letter)
246
+ current_weights[selected_letter] -= total_weight
247
+
248
+ return weighted_provider_list
249
+
250
  import asyncio
251
  class ModelRequestHandler:
252
  def __init__(self):
 
320
 
321
  # 检查是否启用轮询
322
  api_index = api_list.index(token)
323
+ weights = safe_get(config, 'api_keys', api_index, "weights")
324
+ if weights:
325
+ # 步骤 1: 提取 matching_providers 中的所有 provider 值
326
+ providers = set(provider['provider'] for provider in matching_providers)
327
+ weight_keys = set(weights.keys())
328
+
329
+ # 步骤 3: 计算交集
330
+ intersection = providers.intersection(weight_keys)
331
+ weights = dict(filter(lambda item: item[0] in intersection, weights.items()))
332
+ weighted_provider_name_list = weighted_round_robin(weights)
333
+ new_matching_providers = []
334
+ for provider_name in weighted_provider_name_list:
335
+ for provider in matching_providers:
336
+ if provider['provider'] == provider_name:
337
+ new_matching_providers.append(provider)
338
+ matching_providers = new_matching_providers
339
+ # import json
340
+ # print("matching_providers", json.dumps(matching_providers, indent=4, ensure_ascii=False, default=circular_list_encoder))
341
+
342
  use_round_robin = True
343
  auto_retry = True
344
+ if safe_get(config, 'api_keys', api_index, "preferences", "USE_ROUND_ROBIN") == False:
345
+ use_round_robin = False
346
+ if safe_get(config, 'api_keys', api_index, "preferences", "AUTO_RETRY") == False:
347
+ auto_retry = False
 
348
 
349
  return await self.try_all_providers(request, matching_providers, use_round_robin, auto_retry, endpoint)
350
 
test/test_matplotlib.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import matplotlib.pyplot as plt
3
+ from datetime import datetime, timedelta
4
+ from collections import defaultdict
5
+
6
+ import matplotlib.font_manager as fm
7
+ font_path = '/System/Library/Fonts/PingFang.ttc'
8
+ prop = fm.FontProperties(fname=font_path)
9
+ plt.rcParams['font.family'] = prop.get_name()
10
+
11
+ with open('./test/states.json') as f:
12
+ data = json.load(f)
13
+ request_arrivals = data["request_arrivals"]
14
+
15
+ def create_pic(request_arrivals, key):
16
+ request_arrivals = request_arrivals[key]
17
+ # 将字符串转换为datetime对象
18
+ datetimes = [datetime.fromisoformat(t) for t in request_arrivals]
19
+ # 获取最新的时间
20
+ latest_time = max(datetimes)
21
+
22
+ # 创建24小时的时间范围
23
+ time_range = [latest_time - timedelta(hours=i) for i in range(24, 0, -1)]
24
+ # 统计每小时的请求数
25
+ hourly_counts = defaultdict(int)
26
+ for dt in datetimes:
27
+ for t in time_range[::-1]:
28
+ if dt >= t:
29
+ hourly_counts[t] += 1
30
+ break
31
+
32
+ # 准备绘图数据
33
+ hours = [t.strftime('%Y-%m-%d %H:00') for t in time_range]
34
+ counts = [hourly_counts[t] for t in time_range]
35
+
36
+ # 创建柱状图
37
+ plt.figure(figsize=(15, 6))
38
+ plt.bar(hours, counts)
39
+ plt.title(f'{key} 端点请求量 (过去24小时)')
40
+ plt.xlabel('时间')
41
+ plt.ylabel('请求数')
42
+ plt.xticks(rotation=45, ha='right')
43
+ plt.tight_layout()
44
+
45
+ # 保存图片
46
+ plt.savefig(f'{key.replace("/", "")}.png')
47
+
48
+ if __name__ == '__main__':
49
+ create_pic(request_arrivals, 'POST /v1/chat/completions')
test/test_weights.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def weighted_round_robin(weights):
2
+ provider_names = list(weights.keys())
3
+ current_weights = {name: 0 for name in provider_names}
4
+ num_selections = total_weight = sum(weights.values())
5
+ weighted_provider_list = []
6
+
7
+ for _ in range(num_selections):
8
+ max_ratio = -1
9
+ selected_letter = None
10
+
11
+ for name in provider_names:
12
+ current_weights[name] += weights[name]
13
+ ratio = current_weights[name] / weights[name]
14
+
15
+ if ratio > max_ratio:
16
+ max_ratio = ratio
17
+ selected_letter = name
18
+
19
+ weighted_provider_list.append(selected_letter)
20
+ current_weights[selected_letter] -= total_weight
21
+
22
+ return weighted_provider_list
23
+
24
+ # 权重和选择次数
25
+ weights = {'a': 5, 'b': 3, 'c': 2}
26
+ index = {'a', 'c'}
27
+
28
+ result = dict(filter(lambda item: item[0] in index, weights.items()))
29
+ print(result)
30
+ # result = {k: weights[k] for k in index if k in weights}
31
+ # print(result)
32
+ weighted_provider_list = weighted_round_robin(weights)
33
+ print(weighted_provider_list)
utils.py CHANGED
@@ -25,8 +25,25 @@ def update_config(config_data):
25
  config_data['providers'][index] = provider
26
 
27
  api_keys_db = config_data['api_keys']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  api_list = [item["api"] for item in api_keys_db]
29
- # logger.info(json.dumps(config_data, indent=4, ensure_ascii=False))
30
  return config_data, api_keys_db, api_list
31
 
32
  # 读取YAML配置文件
@@ -214,6 +231,12 @@ def get_all_models(config):
214
  # us-central1
215
  # europe-west1
216
  # europe-west4
 
 
 
 
 
 
217
  from collections import deque
218
  class CircularList:
219
  def __init__(self, items):
@@ -226,6 +249,13 @@ class CircularList:
226
  self.queue.append(item)
227
  return item
228
 
 
 
 
 
 
 
 
229
  c35s = CircularList(["us-east5", "europe-west1"])
230
  c3s = CircularList(["us-east5", "us-central1", "asia-southeast1"])
231
  c3o = CircularList(["us-east5"])
@@ -256,4 +286,12 @@ class BaseAPI:
256
  else:
257
  self.chat_url: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/chat/completions",) + ("",) * 3)
258
  self.image_url: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/images/generations",) + ("",) * 3)
259
- self.audio_transcriptions: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/audio/transcriptions",) + ("",) * 3)
 
 
 
 
 
 
 
 
 
25
  config_data['providers'][index] = provider
26
 
27
  api_keys_db = config_data['api_keys']
28
+
29
+ for index, api_key in enumerate(config_data['api_keys']):
30
+ weights_dict = {}
31
+ models = []
32
+ for model in api_key.get('model'):
33
+ if isinstance(model, dict):
34
+ key, value = list(model.items())[0]
35
+ provider_name = key.split("/")[0]
36
+ if "/" in key:
37
+ weights_dict.update({provider_name: int(value)})
38
+ models.append(key)
39
+ if isinstance(model, str):
40
+ models.append(model)
41
+ config_data['api_keys'][index]['weights'] = weights_dict
42
+ config_data['api_keys'][index]['model'] = models
43
+ api_keys_db[index]['model'] = models
44
+
45
  api_list = [item["api"] for item in api_keys_db]
46
+ # logger.info(json.dumps(config_data, indent=4, ensure_ascii=False, default=circular_list_encoder))
47
  return config_data, api_keys_db, api_list
48
 
49
  # 读取YAML配置文件
 
231
  # us-central1
232
  # europe-west1
233
  # europe-west4
234
+
235
+ def circular_list_encoder(obj):
236
+ if isinstance(obj, CircularList):
237
+ return obj.to_dict()
238
+ raise TypeError(f'Object of type {obj.__class__.__name__} is not JSON serializable')
239
+
240
  from collections import deque
241
  class CircularList:
242
  def __init__(self, items):
 
249
  self.queue.append(item)
250
  return item
251
 
252
+ def to_dict(self):
253
+ return {
254
+ 'queue': list(self.queue)
255
+ }
256
+
257
+
258
+
259
  c35s = CircularList(["us-east5", "europe-west1"])
260
  c3s = CircularList(["us-east5", "us-central1", "asia-southeast1"])
261
  c3o = CircularList(["us-east5"])
 
286
  else:
287
  self.chat_url: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/chat/completions",) + ("",) * 3)
288
  self.image_url: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/images/generations",) + ("",) * 3)
289
+ self.audio_transcriptions: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/audio/transcriptions",) + ("",) * 3)
290
+
291
+ def safe_get(data, *keys):
292
+ for key in keys:
293
+ try:
294
+ data = data[key] if isinstance(data, (dict, list)) else data.get(key)
295
+ except (KeyError, IndexError, AttributeError, TypeError):
296
+ return None
297
+ return data