# thinkbuddy2api / core / utils.py
# kevin
# 精简日志输出 (trim log output)
# 0a3d38f
import codecs
import hashlib
import json
import os
import ssl
import uuid
from datetime import datetime
# from http.client import HTTPException
from typing import Dict, Any, Optional
import httpx
from fastapi import HTTPException
from httpx import ConnectError, TransportError
from starlette import status
from core.config import get_settings
from core.logger import setup_logger
from core.models import ChatRequest
# from rich import print
# Module-wide settings object returned by core.config.get_settings().
settings = get_settings()
# Logger for this module, created by core.logger.setup_logger().
logger = setup_logger(__name__)
def decode_unicode_escape(s):
    r"""Decode backslash escape sequences (``\n``, ``\t``, ``\uXXXX`` ...) in *s*.

    Args:
        s: The value to decode. A ``dict`` is returned unchanged, ``bytes``
            are decoded directly, and any other non-string value is first
            converted with ``str()``.

    Returns:
        The decoded string, or the original object when *s* is a ``dict``.
    """
    # Dictionaries are passed through untouched.
    if isinstance(s, dict):
        return s
    # Raw bytes are handed to the codec as-is.
    if isinstance(s, bytes):
        return codecs.decode(s, 'unicode_escape')
    # Anything that is not text yet is stringified first.
    if not isinstance(s, str):
        s = str(s)
    # The ``unicode_escape`` codec reads its input bytes as Latin-1, so
    # encoding the text as UTF-8 first would turn every non-ASCII character
    # into mojibake. Encoding as Latin-1 with ``backslashreplace`` keeps
    # Latin-1 characters as single bytes and rewrites everything else as
    # ``\uXXXX`` escapes, which the codec then restores losslessly.
    return s.encode('latin-1', 'backslashreplace').decode('unicode_escape')
# Firebase API key read from settings; passed as the `key` query parameter
# in the Google endpoint requests made in this module.
FIREBASE_API_KEY = settings.FIREBASE_API_KEY
async def refresh_token_via_rest(refresh_token):
    """Exchange comma-separated refresh tokens for ID tokens.

    Each refresh token is posted to the ``securetoken.googleapis.com`` token
    endpoint with ``grant_type=refresh_token``.

    Args:
        refresh_token: One or more refresh tokens separated by commas.
            Whitespace around each entry is stripped.

    Returns:
        The ``id_token`` values joined with commas, in input order, or
        ``None`` when there is nothing to refresh or any single refresh
        fails (non-200 answer, network error, malformed body).
    """
    # Nothing to exchange: report failure without a pointless network call.
    if not refresh_token or not refresh_token.strip():
        return None

    refresh_tokens = [part.strip() for part in refresh_token.split(',')]
    logger.info(f"refresh token length is {len(refresh_tokens)}")

    # Firebase Auth REST API endpoint
    url = f"https://securetoken.googleapis.com/v1/token?key={FIREBASE_API_KEY}"
    id_tokens = []
    try:
        # One client is shared by all exchanges instead of one per token.
        async with httpx.AsyncClient() as client:
            for token in refresh_tokens:
                payload = {
                    'grant_type': 'refresh_token',
                    'refresh_token': token,
                }
                response = await client.post(url, json=payload)
                if response.status_code != 200:
                    logger.error(f"刷新失败: {response.text}")
                    return None
                # The response body carries credentials, so it is never logged.
                id_tokens.append(response.json()['id_token'])
    except Exception as exc:
        # Deliberate best-effort: any failure is reported as None.
        logger.error(f"请求异常: {exc}")
        return None
    return ','.join(id_tokens)
async def sign_in_with_idp():
    """Sign in through the Identity Toolkit ``accounts:signInWithIdp`` endpoint.

    Sends ``settings.AUTHORIZATION_TOKEN`` as a ``google.com`` provider
    ``id_token`` in the request's ``postBody``.

    Returns:
        The parsed JSON body of the response.

    Raises:
        Exception: If the endpoint answers with a status code other than 200.
    """
    url = "https://identitytoolkit.googleapis.com/v1/accounts:signInWithIdp"
    # Query parameters
    params = {
        "key": FIREBASE_API_KEY
    }
    # Request headers
    headers = {
        "X-Client-Version": "Node/JsCore/10.5.2/FirebaseCore-web",
        "X-Firebase-gmpid": "1:123807869619:web:43b278a622ed6322789ec6",
        "Content-Type": "application/json",
        "User-Agent": "node-fetch/1.0 (+https://github.com/bitinn/node-fetch)"
    }
    # Request body
    data = {
        "requestUri": "http://localhost",
        "returnSecureToken": True,
        "postBody": f"&id_token={settings.AUTHORIZATION_TOKEN}&providerId=google.com"
    }
    # The body and query string contain credentials (authorization token and
    # API key), so they are intentionally not written to stdout or the log.
    async with httpx.AsyncClient() as client:
        response = await client.post(
            url,
            params=params,
            headers=headers,
            json=data
        )
    # Check the status code
    if response.status_code == 200:
        return response.json()
    logger.error(f"signInWithIdp failed with status code: {response.status_code}")
    raise Exception(f"Request failed with status code: {response.status_code}")
def _store_refresh_token_and_get_id_token(data, missing_message):
    """Save ``refreshToken`` (if present) to the environment and return ``idToken``."""
    # Keep the refresh token in the process environment for later use.
    if 'refreshToken' in data:
        os.environ["REFRESH_TOKEN"] = data['refreshToken']
    if 'idToken' not in data:
        raise ValueError(missing_message)
    return data['idToken']


async def handle_firebase_response(response) -> str:
    """Extract the ``idToken`` from a sign-in response.

    As a side effect, a ``refreshToken`` found in the payload is stored in
    the ``REFRESH_TOKEN`` environment variable.

    Args:
        response: Either an already parsed JSON ``dict`` or an object that
            exposes ``status_code`` and ``json()``.

    Returns:
        The ``idToken`` string contained in the payload.

    Raises:
        ValueError: For every failure: missing ``idToken``, a non-200 status
            code, invalid JSON or an unsupported *response* type. Errors
            raised while processing are wrapped with the prefix
            ``"Error processing response: "``.
    """
    # Token payloads are credentials and are therefore never printed or logged.
    try:
        # Already parsed JSON.
        if isinstance(response, dict):
            if response.get('error', {}).get('code') == 400:
                logger.warning("Invalid id_token in IdP response")
            return _store_refresh_token_and_get_id_token(
                response, "dict case Response does not contain idToken"
            )

        # Anything else must look like an HTTP response object.
        if not hasattr(response, 'status_code'):
            raise ValueError(f"Unexpected response type: {type(response)}")

        status_code = response.status_code
        if status_code == 200:
            return _store_refresh_token_and_get_id_token(
                response.json(), "response case Response does not contain idToken"
            )
        if status_code == 400:
            error_data = response.json()
            raise ValueError(f"Bad Request: {error_data.get('error', {}).get('message', 'Unknown error')}")

        # Remaining status codes map to fixed messages.
        fixed_messages = {
            401: "Unauthorized: Invalid credentials",
            403: "Forbidden: Insufficient permissions",
            404: "Not Found: Resource doesn't exist",
        }
        raise ValueError(fixed_messages.get(status_code, f"Unexpected status code: {status_code}"))
    except json.JSONDecodeError as exc:
        raise ValueError("Invalid JSON response") from exc
    except Exception as exc:
        # Messages are kept as before; chaining preserves the original traceback.
        raise ValueError(f"Error processing response: {str(exc)}") from exc
# SHA-256 helper
def _sha256_hash(text):
    """Return the hexadecimal SHA-256 digest of *text* encoded as UTF-8."""
    encoded = text.encode('utf-8')
    return hashlib.sha256(encoded).hexdigest()
# Hash a list of message dictionaries
def sha256_hash_messages(messages):
    """Return a SHA-256 hex digest computed over the user messages.

    Only entries whose ``role`` equals ``"user"`` contribute; their
    ``content`` values are stringified, serialized as a JSON list and hashed.

    Args:
        messages: Iterable of mappings providing ``role`` and ``content``.

    Returns:
        The hexadecimal SHA-256 digest of the UTF-8 encoded JSON string.
    """
    # Message contents are user data, so they are not echoed to stdout.
    user_contents = [str(msg['content']) for msg in messages if msg['role'] == "user"]
    json_str = json.dumps(user_contents, sort_keys=True)
    return hashlib.sha256(json_str.encode('utf-8')).hexdigest()
def create_chat_completion_data(
    content: str, model: str, timestamp: int, finish_reason: Optional[str] = None
) -> Dict[str, Any]:
    """Build a single ``chat.completion.chunk`` dictionary.

    Args:
        content: Text placed in the delta of the only choice.
        model: Value stored under ``model``.
        timestamp: Value stored under ``created``.
        finish_reason: Value stored under the choice's ``finish_reason``.

    Returns:
        A dictionary with a fresh ``chatcmpl-<uuid4>`` id, one choice at
        index 0 with role ``assistant``, and ``usage`` set to ``None``.
    """
    delta: Dict[str, Any] = {"content": content, "role": "assistant"}
    only_choice: Dict[str, Any] = {
        "index": 0,
        "delta": delta,
        "finish_reason": finish_reason,
    }
    # Keys are inserted in the same order as the serialized layout expects.
    chunk: Dict[str, Any] = {"id": f"chatcmpl-{uuid.uuid4()}"}
    chunk["object"] = "chat.completion.chunk"
    chunk["created"] = timestamp
    chunk["model"] = model
    chunk["choices"] = [only_choice]
    chunk["usage"] = None
    return chunk
def _drop_env_list_entry(name: str, index: int) -> None:
    """Remove entry *index* from the comma-separated list stored in env var *name*."""
    entries = os.getenv(name, '').split(',')
    # The lists may have different lengths; an out-of-range index must not
    # raise from inside an exception handler.
    if 0 <= index < len(entries):
        entries.pop(index)
        os.environ[name] = ','.join(entries)


async def process_streaming_response(request: ChatRequest, app_secret: str, current_index: int):
    """Stream a chat completion from the upstream API as SSE chunks.

    The bearer token is taken from the comma-separated ``TOKEN`` environment
    variable at position ``current_index`` (modulo the number of entries).
    When the upstream answers 429, that entry is removed from both ``TOKEN``
    and ``REFRESH_TOKEN`` before the error is raised.

    Args:
        request: The chat request; its ``model_dump()`` is sent as the JSON body.
        app_secret: Not used by this function.
        current_index: Index used to select the token.

    Yields:
        ``"data: ...\\n\\n"`` strings: one chunk per upstream delta that has
        a ``content`` key, then a final chunk with ``finish_reason`` set to
        ``"stop"``, then ``"data: [DONE]\\n\\n"``.

    Raises:
        HTTPException: 503 on connection errors, 502 on other transport
            errors, the upstream status code on HTTP status errors, and 500
            on any other request error.
    """
    # Custom SSL context with hostname checking and certificate verification
    ssl_context = ssl.create_default_context()
    ssl_context.check_hostname = True
    ssl_context.verify_mode = ssl.CERT_REQUIRED
    async with httpx.AsyncClient(verify=ssl_context) as client:
        try:
            token_array = os.getenv('TOKEN', '').split(',')
            # str.split always returns at least one element, so the modulo is safe.
            current_index = current_index % len(token_array)
            logger.debug(f"completions current index is {current_index}")
            # The headers carry the bearer token and are therefore not logged.
            request_headers = {**settings.HEADERS, 'authorization': f"Bearer {token_array[current_index]}"}
            request_data = request.model_dump()
            # The context manager closes the response on every exit path, so
            # no explicit aclose() is needed (the former ``finally`` block
            # failed with UnboundLocalError when the stream never opened).
            async with client.stream(
                "POST",
                "https://api.thinkbuddy.ai/v1/chat/completions",
                headers=request_headers,
                json=request_data,
                timeout=100,
            ) as response:
                response.raise_for_status()
                logger.debug(f"Response status code: {response.status_code}")
                timestamp = int(datetime.now().timestamp())
                async for line in response.aiter_lines():
                    if not line.startswith("data: "):
                        continue
                    if line.strip() == 'data: [DONE]':
                        break
                    try:
                        # Strip the 'data: ' prefix and parse the JSON payload.
                        json_data = json.loads(line[6:])
                    except json.JSONDecodeError as exc:
                        logger.warning(f"JSON解析错误: {exc}; 原始数据: {line}")
                        continue
                    if 'choices' in json_data and len(json_data['choices']) > 0:
                        delta = json_data['choices'][0].get('delta', {})
                        if 'content' in delta:
                            yield f"data: {json.dumps(create_chat_completion_data(delta['content'], request.model, timestamp))}\n\n"
                yield f"data: {json.dumps(create_chat_completion_data('', request.model, timestamp, 'stop'))}\n\n"
                yield "data: [DONE]\n\n"
        except ConnectError as e:
            logger.error(f"Connection error details: {str(e)}")
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="Service temporarily unavailable. Please try again later."
            ) from e
        except TransportError as e:
            logger.error(f"Transport error details: {str(e)}")
            raise HTTPException(
                status_code=status.HTTP_502_BAD_GATEWAY,
                detail="Network transport error occurred."
            ) from e
        except httpx.HTTPStatusError as e:
            # On 429, drop the token at this index together with its refresh token.
            if e.response.status_code == 429:
                _drop_env_list_entry("TOKEN", current_index)
                _drop_env_list_entry("REFRESH_TOKEN", current_index)
            logger.error(f"HTTP error occurred: {e}")
            raise HTTPException(status_code=e.response.status_code, detail=str(e)) from e
        except httpx.RequestError as e:
            logger.error(f"Error occurred during request: {e}")
            raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) from e