hagemi / app /utils.py
win7win's picture
Upload 10 files
65dd154 verified
import random
from fastapi import HTTPException, Request
import time
import re
from datetime import datetime, timedelta
from apscheduler.schedulers.background import BackgroundScheduler
import os
import requests
import httpx
from threading import Lock
DEBUG = os.getenv("DEBUG", "false").lower() == "true"

# Verbose layout used when DEBUG is on: adds timestamp, level and the error detail.
LOG_FORMAT_DEBUG = (
    '%(asctime)s - %(levelname)s - '
    '[%(key)s]-%(request_type)s-[%(model)s]-%(status_code)s: '
    '%(message)s - %(error_message)s'
)
# Compact layout used otherwise.
LOG_FORMAT_NORMAL = '[%(key)s]-%(request_type)s-[%(model)s]-%(status_code)s: %(message)s'


def format_log_message(level, message, extra=None):
    """Render one log line in the layout selected by ``DEBUG``.

    ``extra`` may supply ``key``, ``request_type``, ``model`` and ``status_code``
    (each defaulting to 'N/A') and ``error_message`` (defaulting to ''); a falsy
    ``extra`` means "use all defaults". Entries not named in the layout are ignored.
    The DEBUG layout additionally shows the local time and ``level``.
    """
    context = extra if extra else {}
    values = {
        name: context.get(name, 'N/A')
        for name in ('key', 'request_type', 'model', 'status_code')
    }
    values['error_message'] = context.get('error_message', '')
    values['message'] = message
    values['asctime'] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    values['levelname'] = level
    template = LOG_FORMAT_DEBUG if DEBUG else LOG_FORMAT_NORMAL
    return template % values
class APIKeyManager:
    """Hand out Gemini API keys in random order, skipping temporarily disabled ones.

    Keys are parsed from the ``GEMINI_API_KEYS`` environment variable and served
    from a shuffled stack that is rebuilt when it runs dry. A key is not served
    again until ``reset_tried_keys_for_request()`` is called, and blacklisted keys
    are skipped until a scheduled job re-enables them.
    """

    def __init__(self):
        # Every substring of GEMINI_API_KEYS shaped like "AIzaSy" + 33 key characters.
        self.api_keys = re.findall(
            r"AIzaSy[a-zA-Z0-9_-]{33}", os.environ.get('GEMINI_API_KEYS', ""))
        self.key_stack = []  # shuffled copy of api_keys; keys are popped from the end
        self._reset_key_stack()
        self.api_key_blacklist = set()
        self.api_key_blacklist_duration = 60  # seconds a blacklisted key stays disabled
        self.scheduler = BackgroundScheduler()
        self.scheduler.start()
        # Keys handed out since the last reset_tried_keys_for_request().
        # NOTE(review): this set lives on the manager, so every caller of this
        # instance shares it; confirm requests never interleave on one manager.
        self.tried_keys_for_request = set()

    def _reset_key_stack(self):
        """Replace the key stack with a freshly shuffled copy of all configured keys."""
        shuffled_keys = self.api_keys[:]  # copy so api_keys keeps its original order
        random.shuffle(shuffled_keys)
        self.key_stack = shuffled_keys
        log_msg = format_log_message('INFO', "已重新创建随机密钥栈", extra={'request_type': 'key_stack', 'status_code': 'N/A'})
        print(log_msg)

    def _pop_usable_key(self):
        """Pop keys until one is neither blacklisted nor already tried.

        Unusable keys popped on the way are discarded from the stack. The usable
        key is recorded in ``tried_keys_for_request`` and returned; returns None
        if the stack empties first.
        """
        while self.key_stack:
            key = self.key_stack.pop()
            if key not in self.api_key_blacklist and key not in self.tried_keys_for_request:
                self.tried_keys_for_request.add(key)
                return key
        return None

    def get_available_key(self):
        """Return the next usable key, or None if no key can be served.

        Tries the current stack first; if that yields nothing, rebuilds the stack
        once from all configured keys and tries again.
        """
        key = self._pop_usable_key()
        if key is not None:
            return key
        if not self.api_keys:
            log_msg = format_log_message('ERROR', "没有配置任何 API 密钥!")
            print(log_msg)
            return None
        log_msg = format_log_message('WARNING', "密钥栈已用尽或栈内密钥均不可用,重新生成密钥栈")
        print(log_msg)
        self._reset_key_stack()
        return self._pop_usable_key()

    def show_all_keys(self):
        """Print the number of configured keys and an abbreviated form of each."""
        log_msg = format_log_message('INFO', f"当前可用API key个数: {len(self.api_keys)} ")
        print(log_msg)
        for i, api_key in enumerate(self.api_keys):
            log_msg = format_log_message('INFO', f"API Key{i}: {api_key[:8]}...{api_key[-3:]}")
            print(log_msg)

    def blacklist_key(self, key):
        """Disable ``key`` for ``api_key_blacklist_duration`` seconds.

        Each key has at most one pending re-enable job. Blacklisting a key that is
        already blacklisted replaces that job, which restarts the timer. Otherwise
        the earlier job would re-enable the key before the latest period ended.
        """
        import hashlib

        log_msg = format_log_message('WARNING', f"{key[:8]} → 暂时禁用 {self.api_key_blacklist_duration} 秒")
        print(log_msg)
        self.api_key_blacklist.add(key)
        # Stable per-key job id derived from a digest, so the raw key is not used
        # as a scheduler identifier (identifiers may appear in scheduler logs).
        job_id = "unblacklist:" + hashlib.sha256(key.encode("utf-8")).hexdigest()[:16]
        self.scheduler.add_job(
            lambda: self.api_key_blacklist.discard(key),
            'date',
            run_date=datetime.now() + timedelta(seconds=self.api_key_blacklist_duration),
            id=job_id,
            replace_existing=True,
        )

    def reset_tried_keys_for_request(self):
        """Forget which keys were handed out, e.g. when a new request starts."""
        self.tried_keys_for_request = set()
def _mask_api_key(api_key):
    """Return the '<first 8> ... <last 3>' abbreviation of a key used in log lines."""
    return f"{api_key[:8]} ... {api_key[-3:]}"


def _log_gemini_error(level, message, current_api_key, status_code, error_message):
    """Print ``message`` tagged with the key prefix, HTTP status and error detail."""
    extra = {'key': current_api_key[:8], 'status_code': status_code, 'error_message': error_message}
    print(format_log_message(level, message, extra=extra))


# HTTP status -> (log level, log text after the key, error detail, blacklist key?, message returned).
# 400 is handled by _handle_bad_request because its meaning depends on the response body.
_GEMINI_STATUS_RULES = {
    429: ('WARNING', "429 官方资源耗尽", "API 密钥配额已用尽", True, "API 密钥配额已用尽"),
    403: ('ERROR', "403 权限被拒绝", "权限被拒绝", True, "权限被拒绝"),
    500: ('WARNING', "500 服务器内部错误", "服务器内部错误", False, "Gemini API 内部错误"),
    503: ('WARNING', "503 服务不可用", "服务不可用", False, "Gemini API 服务不可用"),
}


def _handle_bad_request(response, current_api_key, key_manager, switch_api_key_func):
    """Classify an HTTP 400: an invalid key (blacklist and switch) or a plain bad request.

    Always returns a message string, including when the body is not JSON or has
    no usable ``error`` object.
    """
    status_code = response.status_code
    try:
        error_data = response.json()
    except ValueError:
        error_message = "400 错误请求:响应不是有效的JSON格式"
        _log_gemini_error('WARNING', error_message, current_api_key, status_code, error_message)
        return error_message

    error_body = error_data.get('error') if isinstance(error_data, dict) else None
    if not isinstance(error_body, dict):
        # Missing or malformed 'error' object: fall back to the generic message below.
        error_body = {}

    # NOTE(review): Google API error bodies usually carry an integer `code` (e.g. 400)
    # and the enum name in `status` ("INVALID_ARGUMENT"); confirm this string
    # comparison can ever match, otherwise invalid keys are reported as plain 400s.
    if error_body.get('code') == "invalid_argument":
        error_message = "无效的 API 密钥"
        _log_gemini_error(
            'ERROR', f"{_mask_api_key(current_api_key)} → 无效,可能已过期或被删除",
            current_api_key, status_code, error_message)
        key_manager.blacklist_key(current_api_key)
        switch_api_key_func()
        return error_message

    error_message = error_body.get('message', 'Bad Request')
    _log_gemini_error('WARNING', f"400 错误请求: {error_message}", current_api_key, status_code, error_message)
    return f"400 错误请求: {error_message}"


def handle_gemini_error(error, current_api_key, key_manager, switch_api_key_func) -> str:
    """Log a failed Gemini call and return a short description of the failure.

    For ``requests`` HTTP errors:

    - Keys rejected with 429 or 403, or judged invalid on a 400, are blacklisted
      via ``key_manager.blacklist_key``.
    - ``switch_api_key_func()`` is then called for every status except a plain 400.
    - Connection errors, timeouts and any other exception are only logged.

    Returns:
        The message to report to the client; always a string.
    """
    if isinstance(error, requests.exceptions.HTTPError):
        status_code = error.response.status_code
        if status_code == 400:
            return _handle_bad_request(error.response, current_api_key, key_manager, switch_api_key_func)
        level, event, error_message, blacklist, reply = _GEMINI_STATUS_RULES.get(
            status_code,
            ('WARNING', f"{status_code} 未知错误", f"未知错误: {status_code}", False,
             f"未知错误/模型不可用: {status_code}"),
        )
        _log_gemini_error(level, f"{_mask_api_key(current_api_key)} → {event}",
                          current_api_key, status_code, error_message)
        if blacklist:
            key_manager.blacklist_key(current_api_key)
        # Despite its name, key rotation (stack pop / rebuild) is handled by
        # key_manager.get_available_key; this callback just triggers it.
        switch_api_key_func()
        return reply

    if isinstance(error, requests.exceptions.ConnectionError):
        level, error_message = 'WARNING', "连接错误"
    elif isinstance(error, requests.exceptions.Timeout):
        level, error_message = 'WARNING', "请求超时"
    else:
        level, error_message = 'ERROR', f"发生未知错误: {error}"
    print(format_log_message(level, error_message, extra={'error_message': error_message}))
    return error_message
async def test_api_key(api_key: str) -> bool:
    """
    Check whether ``api_key`` is accepted by the Gemini API.

    Sends ``GET /v1beta/models`` with the key as the ``key`` query parameter and
    returns True when the response passes ``raise_for_status()``. Any exception
    (error status, network failure, timeout, ...) yields False, so False does not
    distinguish a rejected key from an unreachable API.
    """
    url = "https://generativelanguage.googleapis.com/v1beta/models"
    try:
        async with httpx.AsyncClient() as client:
            # httpx URL-encodes params, so characters such as '&' or '#' in the
            # supplied key cannot change the query string.
            response = await client.get(url, params={"key": api_key})
        response.raise_for_status()
        return True
    except Exception:
        # Deliberately best-effort: callers only need a yes/no answer.
        return False
rate_limit_data = {}
rate_limit_lock = Lock()
# Minute bucket of the last stale-entry sweep; -1 makes the first call sweep.
_rate_limit_last_sweep_minute = -1
# Longest tracked window (one day, in seconds). Each counter key embeds its
# minute or day bucket, and each entry records when it was first seen. An entry
# first seen this long ago therefore belongs to a bucket that has already ended.
_RATE_LIMIT_MAX_ENTRY_AGE = 60 * 60 * 24


def _sweep_stale_rate_limit_entries(now):
    """Delete counters first seen at least a day before ``now``.

    Must be called with ``rate_limit_lock`` held.
    """
    stale_keys = [
        key for key, (_, first_seen) in rate_limit_data.items()
        if now - first_seen >= _RATE_LIMIT_MAX_ENTRY_AGE
    ]
    for key in stale_keys:
        del rate_limit_data[key]


def protect_from_abuse(request: Request, max_requests_per_minute: int = 30, max_requests_per_day_per_ip: int = 600):
    """Count this request and raise HTTP 429 once a rate limit is exceeded.

    Every call bumps two counters, including calls that end up rejected:

    - per URL path per epoch-aligned minute, shared by all clients, limited by
      ``max_requests_per_minute``;
    - per client host per epoch-aligned (UTC) day, limited by
      ``max_requests_per_day_per_ip``. Requests with no known client address
      share the host "unknown".

    Counters older than a day are swept at most once per minute, so
    ``rate_limit_data`` stays bounded.

    Raises:
        HTTPException: status 429 with a ``detail`` dict holding ``message`` and ``limit``.
    """
    global _rate_limit_last_sweep_minute
    now = int(time.time())
    minute = now // 60
    day = now // (60 * 60 * 24)
    # request.client can be None when the server cannot report the peer address.
    client_host = request.client.host if request.client is not None else "unknown"
    minute_key = f"{request.url.path}:{minute}"
    day_key = f"{client_host}:{day}"
    with rate_limit_lock:
        if minute != _rate_limit_last_sweep_minute:
            # The O(n) scan runs once per minute, not on every request.
            _sweep_stale_rate_limit_entries(now)
            _rate_limit_last_sweep_minute = minute
        # Keys already carry their minute/day bucket, so any existing entry belongs
        # to the current window and needs no age-based reset.
        minute_count, minute_first_seen = rate_limit_data.get(minute_key, (0, now))
        minute_count += 1
        rate_limit_data[minute_key] = (minute_count, minute_first_seen)
        day_count, day_first_seen = rate_limit_data.get(day_key, (0, now))
        day_count += 1
        rate_limit_data[day_key] = (day_count, day_first_seen)
    if minute_count > max_requests_per_minute:
        raise HTTPException(status_code=429, detail={
            "message": "Too many requests per minute", "limit": max_requests_per_minute})
    if day_count > max_requests_per_day_per_ip:
        raise HTTPException(status_code=429, detail={"message": "Too many requests per day from this IP", "limit": max_requests_per_day_per_ip})