✨ Feature: Add feature: Add support for weighted load balancing.
Browse files- .gitignore +2 -1
- README.md +12 -1
- main.py +47 -6
- test/test_matplotlib.py +49 -0
- test/test_weights.py +33 -0
- utils.py +40 -2
.gitignore
CHANGED
@@ -7,4 +7,5 @@ node_modules
|
|
7 |
.wrangler
|
8 |
.pytest_cache
|
9 |
*.jpg
|
10 |
-
*.json
|
|
|
|
7 |
.wrangler
|
8 |
.pytest_cache
|
9 |
*.jpg
|
10 |
+
*.json
|
11 |
+
*.png
|
README.md
CHANGED
@@ -21,7 +21,7 @@
|
|
21 |
- 同时支持 Anthropic、Gemini、Vertex API。Vertex 同时支持 Claude 和 Gemini API。
|
22 |
- 支持 OpenAI、 Anthropic、Gemini、Vertex 原生 tool use 函数调用。
|
23 |
- 支持 OpenAI、Anthropic、Gemini、Vertex 原生识图 API。
|
24 |
-
-
|
25 |
- 支持自动重试,当一个 API 渠道响应失败时,自动重试下一个 API 渠道。
|
26 |
- 支持细粒度的权限控制。支持使用通配符设置 API key 可用渠道的特定模型。
|
27 |
|
@@ -93,6 +93,17 @@ api_keys:
|
|
93 |
preferences:
|
94 |
USE_ROUND_ROBIN: true # 是否使用轮询负载均衡,true 为使用,false 为不使用,默认为 true。开启轮训后每次请求模型按照 model 配置的顺序依次请求。与 providers 里面原始的渠道顺序无关。因此你可以设置每个 API key 请求顺序不一样。
|
95 |
AUTO_RETRY: true # 是否自动重试,自动重试下一个提供商,true 为自动重试,false 为不自动重试,默认为 true
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
```
|
97 |
|
98 |
## 环境变量
|
|
|
21 |
- 同时支持 Anthropic、Gemini、Vertex API。Vertex 同时支持 Claude 和 Gemini API。
|
22 |
- 支持 OpenAI、 Anthropic、Gemini、Vertex 原生 tool use 函数调用。
|
23 |
- 支持 OpenAI、Anthropic、Gemini、Vertex 原生识图 API。
|
24 |
+
- 支持四种负载均衡。1. 支持渠道级加权负载均衡,可以根据不同的渠道权重分配请求。默认不开启,需要配置渠道权重。2. 支持 Vertex 区域级负载均衡,支持 Vertex 高并发,最高可将 Gemini,Claude 并发提高 (API数量 * 区域数量) 倍。自动开启不需要额外配置。3. 除了 Vertex 区域级负载均衡,所有 API 均支持渠道级顺序负载均衡,提高沉浸式翻译体验。自动开启不需要额外配置。4. 支持单个渠道多个 API Key 自动开启 API key 级别的轮训负载均衡。
|
25 |
- 支持自动重试,当一个 API 渠道响应失败时,自动重试下一个 API 渠道。
|
26 |
- 支持细粒度的权限控制。支持使用通配符设置 API key 可用渠道的特定模型。
|
27 |
|
|
|
93 |
preferences:
|
94 |
USE_ROUND_ROBIN: true # 是否使用轮询负载均衡,true 为使用,false 为不使用,默认为 true。开启轮训后每次请求模型按照 model 配置的顺序依次请求。与 providers 里面原始的渠道顺序无关。因此你可以设置每个 API key 请求顺序不一样。
|
95 |
AUTO_RETRY: true # 是否自动重试,自动重试下一个提供商,true 为自动重试,false 为不自动重试,默认为 true
|
96 |
+
|
97 |
+
# 渠道级加权负载均衡配置示例
|
98 |
+
- api: sk-KjjI60Yf0JFWtxxxxxxxxxxxxxxwmRWpWpQRo
|
99 |
+
model:
|
100 |
+
- gcp1/*: 5 # 冒号后面就是权重,权重仅支持正整数。
|
101 |
+
- gcp2/*: 3 # 数字的大小代表权重,数字越大,请求的概率越大。
|
102 |
+
- gcp3/*: 2 # 在该示例中,所有渠道加起来一共有 10 个权重,及 10 个请求里面有 5 个请求会请求 gcp1/* 模型,2 个请求会请求 gcp2/* 模型,3 个请求会请求 gcp3/* 模型。
|
103 |
+
|
104 |
+
preferences:
|
105 |
+
USE_ROUND_ROBIN: true # 当 USE_ROUND_ROBIN 必须为 true 并且上面的渠道后面没有权重时,会按照原始的渠道顺序请求,如果有权重,会按照加权后的顺序请求。
|
106 |
+
AUTO_RETRY: true
|
107 |
```
|
108 |
|
109 |
## 环境变量
|
main.py
CHANGED
@@ -12,7 +12,7 @@ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
|
12 |
from models import RequestModel, ImageGenerationRequest
|
13 |
from request import get_payload
|
14 |
from response import fetch_response, fetch_response_stream
|
15 |
-
from utils import error_handling_wrapper, post_all_models, load_config
|
16 |
|
17 |
from typing import List, Dict, Union
|
18 |
from urllib.parse import urlparse
|
@@ -224,6 +224,29 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest],
|
|
224 |
|
225 |
raise e
|
226 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
227 |
import asyncio
|
228 |
class ModelRequestHandler:
|
229 |
def __init__(self):
|
@@ -297,13 +320,31 @@ class ModelRequestHandler:
|
|
297 |
|
298 |
# 检查是否启用轮询
|
299 |
api_index = api_list.index(token)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
300 |
use_round_robin = True
|
301 |
auto_retry = True
|
302 |
-
if config
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
auto_retry = False
|
307 |
|
308 |
return await self.try_all_providers(request, matching_providers, use_round_robin, auto_retry, endpoint)
|
309 |
|
|
|
12 |
from models import RequestModel, ImageGenerationRequest
|
13 |
from request import get_payload
|
14 |
from response import fetch_response, fetch_response_stream
|
15 |
+
from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder
|
16 |
|
17 |
from typing import List, Dict, Union
|
18 |
from urllib.parse import urlparse
|
|
|
224 |
|
225 |
raise e
|
226 |
|
227 |
+
def weighted_round_robin(weights):
|
228 |
+
provider_names = list(weights.keys())
|
229 |
+
current_weights = {name: 0 for name in provider_names}
|
230 |
+
num_selections = total_weight = sum(weights.values())
|
231 |
+
weighted_provider_list = []
|
232 |
+
|
233 |
+
for _ in range(num_selections):
|
234 |
+
max_ratio = -1
|
235 |
+
selected_letter = None
|
236 |
+
|
237 |
+
for name in provider_names:
|
238 |
+
current_weights[name] += weights[name]
|
239 |
+
ratio = current_weights[name] / weights[name]
|
240 |
+
|
241 |
+
if ratio > max_ratio:
|
242 |
+
max_ratio = ratio
|
243 |
+
selected_letter = name
|
244 |
+
|
245 |
+
weighted_provider_list.append(selected_letter)
|
246 |
+
current_weights[selected_letter] -= total_weight
|
247 |
+
|
248 |
+
return weighted_provider_list
|
249 |
+
|
250 |
import asyncio
|
251 |
class ModelRequestHandler:
|
252 |
def __init__(self):
|
|
|
320 |
|
321 |
# 检查是否启用轮询
|
322 |
api_index = api_list.index(token)
|
323 |
+
weights = safe_get(config, 'api_keys', api_index, "weights")
|
324 |
+
if weights:
|
325 |
+
# 步骤 1: 提取 matching_providers 中的所有 provider 值
|
326 |
+
providers = set(provider['provider'] for provider in matching_providers)
|
327 |
+
weight_keys = set(weights.keys())
|
328 |
+
|
329 |
+
# 步骤 3: 计算交集
|
330 |
+
intersection = providers.intersection(weight_keys)
|
331 |
+
weights = dict(filter(lambda item: item[0] in intersection, weights.items()))
|
332 |
+
weighted_provider_name_list = weighted_round_robin(weights)
|
333 |
+
new_matching_providers = []
|
334 |
+
for provider_name in weighted_provider_name_list:
|
335 |
+
for provider in matching_providers:
|
336 |
+
if provider['provider'] == provider_name:
|
337 |
+
new_matching_providers.append(provider)
|
338 |
+
matching_providers = new_matching_providers
|
339 |
+
# import json
|
340 |
+
# print("matching_providers", json.dumps(matching_providers, indent=4, ensure_ascii=False, default=circular_list_encoder))
|
341 |
+
|
342 |
use_round_robin = True
|
343 |
auto_retry = True
|
344 |
+
if safe_get(config, 'api_keys', api_index, "preferences", "USE_ROUND_ROBIN") == False:
|
345 |
+
use_round_robin = False
|
346 |
+
if safe_get(config, 'api_keys', api_index, "preferences", "AUTO_RETRY") == False:
|
347 |
+
auto_retry = False
|
|
|
348 |
|
349 |
return await self.try_all_providers(request, matching_providers, use_round_robin, auto_retry, endpoint)
|
350 |
|
test/test_matplotlib.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
from datetime import datetime, timedelta
|
4 |
+
from collections import defaultdict
|
5 |
+
|
6 |
+
import matplotlib.font_manager as fm
|
7 |
+
font_path = '/System/Library/Fonts/PingFang.ttc'
|
8 |
+
prop = fm.FontProperties(fname=font_path)
|
9 |
+
plt.rcParams['font.family'] = prop.get_name()
|
10 |
+
|
11 |
+
with open('./test/states.json') as f:
|
12 |
+
data = json.load(f)
|
13 |
+
request_arrivals = data["request_arrivals"]
|
14 |
+
|
15 |
+
def create_pic(request_arrivals, key):
|
16 |
+
request_arrivals = request_arrivals[key]
|
17 |
+
# 将字符串转换为datetime对象
|
18 |
+
datetimes = [datetime.fromisoformat(t) for t in request_arrivals]
|
19 |
+
# 获取最新的时间
|
20 |
+
latest_time = max(datetimes)
|
21 |
+
|
22 |
+
# 创建24小时的时间范围
|
23 |
+
time_range = [latest_time - timedelta(hours=i) for i in range(24, 0, -1)]
|
24 |
+
# 统计每小时的请求数
|
25 |
+
hourly_counts = defaultdict(int)
|
26 |
+
for dt in datetimes:
|
27 |
+
for t in time_range[::-1]:
|
28 |
+
if dt >= t:
|
29 |
+
hourly_counts[t] += 1
|
30 |
+
break
|
31 |
+
|
32 |
+
# 准备绘图数据
|
33 |
+
hours = [t.strftime('%Y-%m-%d %H:00') for t in time_range]
|
34 |
+
counts = [hourly_counts[t] for t in time_range]
|
35 |
+
|
36 |
+
# 创建柱状图
|
37 |
+
plt.figure(figsize=(15, 6))
|
38 |
+
plt.bar(hours, counts)
|
39 |
+
plt.title(f'{key} 端点请求量 (过去24小时)')
|
40 |
+
plt.xlabel('时间')
|
41 |
+
plt.ylabel('请求数')
|
42 |
+
plt.xticks(rotation=45, ha='right')
|
43 |
+
plt.tight_layout()
|
44 |
+
|
45 |
+
# 保存图片
|
46 |
+
plt.savefig(f'{key.replace("/", "")}.png')
|
47 |
+
|
48 |
+
if __name__ == '__main__':
|
49 |
+
create_pic(request_arrivals, 'POST /v1/chat/completions')
|
test/test_weights.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def weighted_round_robin(weights):
|
2 |
+
provider_names = list(weights.keys())
|
3 |
+
current_weights = {name: 0 for name in provider_names}
|
4 |
+
num_selections = total_weight = sum(weights.values())
|
5 |
+
weighted_provider_list = []
|
6 |
+
|
7 |
+
for _ in range(num_selections):
|
8 |
+
max_ratio = -1
|
9 |
+
selected_letter = None
|
10 |
+
|
11 |
+
for name in provider_names:
|
12 |
+
current_weights[name] += weights[name]
|
13 |
+
ratio = current_weights[name] / weights[name]
|
14 |
+
|
15 |
+
if ratio > max_ratio:
|
16 |
+
max_ratio = ratio
|
17 |
+
selected_letter = name
|
18 |
+
|
19 |
+
weighted_provider_list.append(selected_letter)
|
20 |
+
current_weights[selected_letter] -= total_weight
|
21 |
+
|
22 |
+
return weighted_provider_list
|
23 |
+
|
24 |
+
# 权重和选择次数
|
25 |
+
weights = {'a': 5, 'b': 3, 'c': 2}
|
26 |
+
index = {'a', 'c'}
|
27 |
+
|
28 |
+
result = dict(filter(lambda item: item[0] in index, weights.items()))
|
29 |
+
print(result)
|
30 |
+
# result = {k: weights[k] for k in index if k in weights}
|
31 |
+
# print(result)
|
32 |
+
weighted_provider_list = weighted_round_robin(weights)
|
33 |
+
print(weighted_provider_list)
|
utils.py
CHANGED
@@ -25,8 +25,25 @@ def update_config(config_data):
|
|
25 |
config_data['providers'][index] = provider
|
26 |
|
27 |
api_keys_db = config_data['api_keys']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
api_list = [item["api"] for item in api_keys_db]
|
29 |
-
# logger.info(json.dumps(config_data, indent=4, ensure_ascii=False))
|
30 |
return config_data, api_keys_db, api_list
|
31 |
|
32 |
# 读取YAML配置文件
|
@@ -214,6 +231,12 @@ def get_all_models(config):
|
|
214 |
# us-central1
|
215 |
# europe-west1
|
216 |
# europe-west4
|
|
|
|
|
|
|
|
|
|
|
|
|
217 |
from collections import deque
|
218 |
class CircularList:
|
219 |
def __init__(self, items):
|
@@ -226,6 +249,13 @@ class CircularList:
|
|
226 |
self.queue.append(item)
|
227 |
return item
|
228 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
229 |
c35s = CircularList(["us-east5", "europe-west1"])
|
230 |
c3s = CircularList(["us-east5", "us-central1", "asia-southeast1"])
|
231 |
c3o = CircularList(["us-east5"])
|
@@ -256,4 +286,12 @@ class BaseAPI:
|
|
256 |
else:
|
257 |
self.chat_url: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/chat/completions",) + ("",) * 3)
|
258 |
self.image_url: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/images/generations",) + ("",) * 3)
|
259 |
-
self.audio_transcriptions: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/audio/transcriptions",) + ("",) * 3)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
config_data['providers'][index] = provider
|
26 |
|
27 |
api_keys_db = config_data['api_keys']
|
28 |
+
|
29 |
+
for index, api_key in enumerate(config_data['api_keys']):
|
30 |
+
weights_dict = {}
|
31 |
+
models = []
|
32 |
+
for model in api_key.get('model'):
|
33 |
+
if isinstance(model, dict):
|
34 |
+
key, value = list(model.items())[0]
|
35 |
+
provider_name = key.split("/")[0]
|
36 |
+
if "/" in key:
|
37 |
+
weights_dict.update({provider_name: int(value)})
|
38 |
+
models.append(key)
|
39 |
+
if isinstance(model, str):
|
40 |
+
models.append(model)
|
41 |
+
config_data['api_keys'][index]['weights'] = weights_dict
|
42 |
+
config_data['api_keys'][index]['model'] = models
|
43 |
+
api_keys_db[index]['model'] = models
|
44 |
+
|
45 |
api_list = [item["api"] for item in api_keys_db]
|
46 |
+
# logger.info(json.dumps(config_data, indent=4, ensure_ascii=False, default=circular_list_encoder))
|
47 |
return config_data, api_keys_db, api_list
|
48 |
|
49 |
# 读取YAML配置文件
|
|
|
231 |
# us-central1
|
232 |
# europe-west1
|
233 |
# europe-west4
|
234 |
+
|
235 |
+
def circular_list_encoder(obj):
|
236 |
+
if isinstance(obj, CircularList):
|
237 |
+
return obj.to_dict()
|
238 |
+
raise TypeError(f'Object of type {obj.__class__.__name__} is not JSON serializable')
|
239 |
+
|
240 |
from collections import deque
|
241 |
class CircularList:
|
242 |
def __init__(self, items):
|
|
|
249 |
self.queue.append(item)
|
250 |
return item
|
251 |
|
252 |
+
def to_dict(self):
|
253 |
+
return {
|
254 |
+
'queue': list(self.queue)
|
255 |
+
}
|
256 |
+
|
257 |
+
|
258 |
+
|
259 |
c35s = CircularList(["us-east5", "europe-west1"])
|
260 |
c3s = CircularList(["us-east5", "us-central1", "asia-southeast1"])
|
261 |
c3o = CircularList(["us-east5"])
|
|
|
286 |
else:
|
287 |
self.chat_url: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/chat/completions",) + ("",) * 3)
|
288 |
self.image_url: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/images/generations",) + ("",) * 3)
|
289 |
+
self.audio_transcriptions: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/audio/transcriptions",) + ("",) * 3)
|
290 |
+
|
291 |
+
def safe_get(data, *keys):
|
292 |
+
for key in keys:
|
293 |
+
try:
|
294 |
+
data = data[key] if isinstance(data, (dict, list)) else data.get(key)
|
295 |
+
except (KeyError, IndexError, AttributeError, TypeError):
|
296 |
+
return None
|
297 |
+
return data
|