Spaces:
Running
Running
import codecs | |
import hashlib | |
import json | |
import os | |
import ssl | |
import uuid | |
from datetime import datetime | |
# from http.client import HTTPException | |
from typing import Dict, Any, Optional | |
import httpx | |
from fastapi import HTTPException | |
from httpx import ConnectError, TransportError | |
from starlette import status | |
from core.config import get_settings | |
from core.logger import setup_logger | |
from core.models import ChatRequest | |
# Module-wide singletons, created once when this module is imported:
# the project settings object and this module's logger.
settings, logger = get_settings(), setup_logger(__name__)
def decode_unicode_escape(s):
    r"""Turn backslash escape sequences such as ``\u4e2d``, ``\xe9`` or ``\n`` into real characters.

    * ``dict`` inputs are returned unchanged.
    * Other non-``str``/``bytes`` inputs are converted with ``str()`` first.
    * ``bytes`` are decoded as UTF-8; bytes that are not valid UTF-8 end up as the
      Latin-1 code point of the same value (what the codec produced for them before).

    Characters that are already non-ASCII (for example Chinese text mixed with
    escapes) are preserved. The ``unicode_escape`` codec reads its input as
    Latin-1, so feeding it UTF-8 bytes turned ``"中"`` into mojibake (``"ä¸\xad"``).
    Instead, characters outside Latin-1 are first rewritten as ``\uXXXX`` escapes,
    which the codec then restores.

    Limitation: a literal backslash immediately before a non-Latin-1 character
    is not preserved.
    """
    if isinstance(s, dict):
        return s
    if isinstance(s, bytes):
        # backslashreplace turns invalid UTF-8 bytes into literal \xNN text,
        # which unicode_escape maps back to chr(0xNN).
        s = s.decode('utf-8', errors='backslashreplace')
    elif not isinstance(s, str):
        s = str(s)
    return s.encode('latin-1', errors='backslashreplace').decode('unicode_escape')
# Firebase Web API key, read from settings once at import time and sent as the
# `key` query parameter to the Firebase REST endpoints in this module.
FIREBASE_API_KEY = settings.FIREBASE_API_KEY


async def refresh_token_via_rest(refresh_token):
    """Exchange Firebase refresh tokens for new ID tokens via the Secure Token REST API.

    Args:
        refresh_token: One or more refresh tokens separated by commas; whitespace
            around each entry is stripped.

    Returns:
        The new ID tokens joined by ``','`` in input order, or ``None`` when the
        input is empty/None or when any single refresh fails (non-200 status,
        network error, or a response without ``id_token``). A partial result is
        never returned.
    """
    if not refresh_token or not refresh_token.strip():
        # An empty token cannot be refreshed; skip the network round-trip.
        logger.error("No refresh token provided")
        return None
    refresh_token_array = [x.strip() for x in refresh_token.split(',')]
    logger.info(f"refresh token length is {len(refresh_token_array)}")

    # Firebase Auth REST API endpoint (identical for every token, so built once).
    url = f"https://securetoken.googleapis.com/v1/token?key={FIREBASE_API_KEY}"
    token_array = []
    try:
        # One client (and connection pool) for all refresh requests.
        async with httpx.AsyncClient() as client:
            for token in refresh_token_array:
                payload = {
                    'grant_type': 'refresh_token',
                    'refresh_token': token,
                }
                response = await client.post(url, json=payload)
                if response.status_code != 200:
                    logger.error(f"刷新失败: {response.text}")
                    return None
                # Never dump the response body: it contains live id/refresh tokens.
                token_array.append(response.json()['id_token'])
    except Exception as exc:
        logger.error(f"请求异常: {exc}")
        return None
    return ','.join(token_array)
async def sign_in_with_idp():
    """Sign in to Firebase with a Google ID token through the Identity Toolkit REST API.

    The Google token comes from ``settings.AUTHORIZATION_TOKEN`` and is sent as the
    ``id_token`` of a ``providerId=google.com`` ``postBody``; ``FIREBASE_API_KEY``
    is sent as the ``key`` query parameter.

    Returns:
        The parsed JSON body of a 200 response.

    Raises:
        Exception: when the endpoint answers with any other status code.
    """
    url = "https://identitytoolkit.googleapis.com/v1/accounts:signInWithIdp"
    # Query parameters (carry the API key).
    params = {
        "key": FIREBASE_API_KEY
    }
    # Request headers; they identify the client as the Firebase JS SDK on node-fetch.
    headers = {
        "X-Client-Version": "Node/JsCore/10.5.2/FirebaseCore-web",
        "X-Firebase-gmpid": "1:123807869619:web:43b278a622ed6322789ec6",
        "Content-Type": "application/json",
        "User-Agent": "node-fetch/1.0 (+https://github.com/bitinn/node-fetch)"
    }
    # Request body (carries the Google ID token).
    data = {
        "requestUri": "http://localhost",
        "returnSecureToken": True,
        "postBody": f"&id_token={settings.AUTHORIZATION_TOKEN}&providerId=google.com"
    }
    # Only the headers are logged: params hold the API key and the body holds
    # the Google ID token, neither of which belongs in logs.
    logger.debug(f"Request Headers: {json.dumps(headers, indent=2)}")
    async with httpx.AsyncClient() as client:
        response = await client.post(
            url,
            params=params,
            headers=headers,
            json=data
        )
        if response.status_code == 200:
            return response.json()
        # Log the response body to aid diagnosis of rejected sign-ins.
        logger.error(f"signInWithIdp failed: {response.text}")
        raise Exception(f"Request failed with status code: {response.status_code}")
def _store_refresh_token_and_get_id_token(data, missing_message):
    """Save ``data['refreshToken']`` (if present) to REFRESH_TOKEN and return ``data['idToken']``.

    Raises:
        ValueError: with *missing_message* when ``idToken`` is absent.
    """
    if 'refreshToken' in data:
        # NOTE(review): other code in this module treats REFRESH_TOKEN as a
        # comma-separated list; this overwrites it with a single token — confirm intended.
        os.environ["REFRESH_TOKEN"] = data['refreshToken']
    if 'idToken' not in data:
        raise ValueError(missing_message)
    return data['idToken']


async def handle_firebase_response(response) -> str:
    """Extract the Firebase ID token from an auth response.

    Args:
        response: Either an already-parsed JSON ``dict`` or an object exposing
            ``status_code`` and ``json()`` (e.g. an httpx response).

    Returns:
        The ``idToken`` value. As a side effect, a ``refreshToken`` in the payload
        is written to the ``REFRESH_TOKEN`` environment variable.

    Raises:
        ValueError: for error payloads, non-200 statuses, a missing ``idToken``,
            invalid JSON, or an unsupported response type. Errors raised here keep
            their specific message; unexpected failures are wrapped as
            "Error processing response: ..." with the original chained.
    """
    try:
        # Already-parsed JSON body.
        if isinstance(response, dict):
            error = response.get('error', {})
            if error.get('code') == 400:
                raise ValueError(f"Bad Request: {error.get('message', 'Unknown error')}")
            return _store_refresh_token_and_get_id_token(
                response, "dict case Response does not contain idToken")

        # Response-like object.
        if not hasattr(response, 'status_code'):
            raise ValueError(f"Unexpected response type: {type(response)}")
        status_code = response.status_code
        if status_code == 200:
            return _store_refresh_token_and_get_id_token(
                response.json(), "response case Response does not contain idToken")
        if status_code == 400:
            error_data = response.json()
            raise ValueError(
                f"Bad Request: {error_data.get('error', {}).get('message', 'Unknown error')}")
        if status_code == 401:
            raise ValueError("Unauthorized: Invalid credentials")
        if status_code == 403:
            raise ValueError("Forbidden: Insufficient permissions")
        if status_code == 404:
            raise ValueError("Not Found: Resource doesn't exist")
        raise ValueError(f"Unexpected status code: {status_code}")
    except json.JSONDecodeError as e:
        # Must precede `except ValueError`: JSONDecodeError is a ValueError subclass.
        raise ValueError("Invalid JSON response") from e
    except ValueError:
        # Already descriptive; do not re-wrap our own errors.
        raise
    except Exception as e:
        raise ValueError(f"Error processing response: {str(e)}") from e
def _sha256_hash(text):
    """Return the hexadecimal SHA-256 digest of *text* after UTF-8 encoding."""
    encoded = text.encode('utf-8')
    return hashlib.sha256(encoded).hexdigest()
def sha256_hash_messages(messages):
    """Return a SHA-256 hex digest identifying the user turns of a conversation.

    Only messages whose ``role`` is ``"user"`` contribute. Each ``content`` is
    stringified, the resulting list is JSON-encoded with default ``json.dumps``
    settings, and the UTF-8 bytes of that string are hashed. The serialization
    matches the earlier implementation, so digests are unchanged.

    Args:
        messages: Iterable of mappings with ``'role'`` and ``'content'`` keys.

    Raises:
        KeyError: if a message lacks ``'role'`` or a user message lacks ``'content'``.
    """
    user_contents = [str(msg['content']) for msg in messages if msg['role'] == "user"]
    # No debug printing here: these strings are the user's raw conversation text.
    # (The old sort_keys=True had no effect on a list of strings, so output is identical.)
    payload = json.dumps(user_contents)
    return hashlib.sha256(payload.encode('utf-8')).hexdigest()
def create_chat_completion_data(
    content: str,
    model: str,
    timestamp: int,
    finish_reason: Optional[str] = None,
    completion_id: Optional[str] = None,
) -> Dict[str, Any]:
    """Build one OpenAI-compatible ``chat.completion.chunk`` payload.

    Args:
        content: Text delta carried by this chunk.
        model: Model name echoed back to the client.
        timestamp: Value for the ``created`` field (Unix seconds).
        finish_reason: ``None`` for intermediate chunks; e.g. ``'stop'`` on the last one.
        completion_id: Value for the ``id`` field. OpenAI's streaming format uses
            the same id for every chunk of one completion, so a streaming caller can
            generate it once and pass it to every call. When omitted, a fresh
            ``chatcmpl-<uuid4>`` is generated for this chunk (the original behavior).

    Returns:
        A dict ready for ``json.dumps`` with keys in the order
        id, object, created, model, choices, usage.
    """
    if completion_id is None:
        completion_id = f"chatcmpl-{uuid.uuid4()}"
    return {
        "id": completion_id,
        "object": "chat.completion.chunk",
        "created": timestamp,
        "model": model,
        "choices": [
            {
                "index": 0,
                "delta": {"content": content, "role": "assistant"},
                "finish_reason": finish_reason,
            }
        ],
        "usage": None,
    }
def _select_token(current_index: int):
    """Return ``(token, index)`` for this request from the comma-separated TOKEN env var.

    ``index`` is *current_index* wrapped into range with ``%`` (round-robin).
    """
    token_array = os.getenv('TOKEN', '').split(',')
    # str.split always yields at least one element, so the modulo is safe.
    index = current_index % len(token_array)
    return token_array[index], index


def _drop_token_at(index: int) -> None:
    """Remove position *index* from both the TOKEN and REFRESH_TOKEN env lists.

    The two comma-separated lists are assumed to be index-aligned. An index outside
    a list's range leaves that list untouched rather than raising IndexError from
    inside the caller's exception handler.
    """
    for env_name in ('TOKEN', 'REFRESH_TOKEN'):
        values = os.getenv(env_name, '').split(',')
        if 0 <= index < len(values):
            values.pop(index)
            os.environ[env_name] = ','.join(values)


async def process_streaming_response(request: ChatRequest, app_secret: str, current_index: int):
    """Proxy a chat request to thinkbuddy.ai and re-emit its stream as OpenAI-style SSE.

    Args:
        request: Incoming chat request; ``request.model_dump()`` is sent as the JSON body
            and ``request.model`` is echoed in every emitted chunk.
        app_secret: Not used by this function (kept for interface compatibility).
        current_index: Round-robin counter selecting the bearer token from the
            comma-separated TOKEN env var (taken modulo the number of tokens).

    Yields:
        ``"data: <chunk json>\\n\\n"`` strings, one per upstream content delta, then a
        final chunk with ``finish_reason='stop'`` and ``"data: [DONE]\\n\\n"``.

    Raises:
        HTTPException: 503 on connection errors, 502 on other transport errors, the
            upstream status on HTTP errors (a 429 also evicts the token pair at this
            index from TOKEN/REFRESH_TOKEN), 500 on any other httpx request error.
            NOTE(review): if raised after chunks were already yielded, the response
            status has presumably been sent already — confirm how callers handle this.
    """
    # create_default_context() already requires valid certificates and hostname checks.
    ssl_context = ssl.create_default_context()
    # The response stream is closed by its own `async with`; no manual aclose() is
    # needed (the old `finally` raised UnboundLocalError when the request never started).
    async with httpx.AsyncClient(verify=ssl_context) as client:
        try:
            token, current_index = _select_token(current_index)
            logger.info(f"completions current index is {current_index}")
            request_headers = {**settings.HEADERS, 'authorization': f"Bearer {token}"}
            request_data = request.model_dump()
            # Log header names only: the values include the bearer token.
            logger.debug(f"Request header names: {sorted(request_headers)}")
            async with client.stream(
                "POST",
                "https://api.thinkbuddy.ai/v1/chat/completions",
                headers=request_headers,
                json=request_data,
                timeout=100,
            ) as response:
                response.raise_for_status()
                logger.info(f"Response status code: {response.status_code}")
                timestamp = int(datetime.now().timestamp())
                async for line in response.aiter_lines():
                    if not line.startswith("data: "):
                        continue
                    if line.strip() == 'data: [DONE]':
                        break
                    try:
                        json_data = json.loads(line[6:])  # drop the 'data: ' prefix
                    except json.JSONDecodeError as e:
                        logger.warning(f"JSON解析错误: {e}")
                        logger.warning(f"原始数据: {line}")
                        continue
                    # Ignore payloads without a usable choices list (e.g. keep-alives).
                    choices = json_data.get('choices') if isinstance(json_data, dict) else None
                    if not choices:
                        continue
                    delta = choices[0].get('delta') or {}
                    if 'content' in delta:
                        chunk = create_chat_completion_data(delta['content'], request.model, timestamp)
                        yield f"data: {json.dumps(chunk)}\n\n"
            yield f"data: {json.dumps(create_chat_completion_data('', request.model, timestamp, 'stop'))}\n\n"
            yield "data: [DONE]\n\n"
        except ConnectError as e:
            logger.error(f"Connection error details: {str(e)}")
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="Service temporarily unavailable. Please try again later."
            ) from e
        except TransportError as e:
            logger.error(f"Transport error details: {str(e)}")
            raise HTTPException(
                status_code=status.HTTP_502_BAD_GATEWAY,
                detail="Network transport error occurred."
            ) from e
        except httpx.HTTPStatusError as e:
            # TODO: 401 (expired/invalid token) is not handled specially yet.
            if e.response.status_code == 429:
                # Rate-limited: remove this token pair, presumably so later requests skip it.
                _drop_token_at(current_index)
            logger.error(f"HTTP error occurred: {e}")
            raise HTTPException(status_code=e.response.status_code, detail=str(e)) from e
        except httpx.RequestError as e:
            logger.error(f"Error occurred during request: {e}")
            raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) from e