# thinkbuddy2api/core/router.py
import asyncio
import json
import os
import traceback
from typing import Optional
from uuid import uuid4
import httpx
from fastapi import APIRouter, Depends, File, HTTPException, Request, Response, UploadFile
from fastapi.responses import StreamingResponse
from core.auth import verify_app_secret
from core.config import get_settings
from core.logger import setup_logger
from core.models import ChatRequest
from core.utils import process_streaming_response
from playsound import playsound  # audio playback (only needed if the playback call below is re-enabled)
# from rich import print
logger = setup_logger(__name__)
router = APIRouter()
ALLOWED_MODELS = get_settings().ALLOWED_MODELS
current_index = 0
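# Token rotation: TOKEN is expected to be a comma-separated list of bearer tokens.
# The audio and upload routes below wrap the index with `% len(token_array)` before
# use and bump it afterwards; chat_completions passes the raw index to
# process_streaming_response and bumps it once per request. Note that ''.split(',')
# yields [''], so with an unset TOKEN an empty bearer token is sent upstream.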
@router.get("/models")
async def list_models():
return {"object": "list", "data": ALLOWED_MODELS, "success": True}
@router.options("/chat/completions")
async def chat_completions_options():
return Response(
status_code=200,
headers={
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "POST, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, Authorization",
},
)
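# The OPTIONS handler above answers CORS preflight requests from browser clients so
# that POST /chat/completions can be called cross-origin.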
# Image recognition
# Text-to-speech
@router.post("/audio/speech")
async def speech(request: Request):
global current_index
url = 'https://api.thinkbuddy.ai/v1/content/speech/tts'
token_str = os.getenv('TOKEN', '')
token_array = token_str.split(',')
if len(token_array) > 0:
current_index = current_index % len(token_array)
print('speech current index is ', current_index)
request_headers = {**get_settings().HEADERS,
'authorization': f"Bearer {token_array[current_index]}",
'Accept': 'application/json, text/plain, */*',
}
    # data = {
    #     "input": "This is an illustration showing a cup of a drink, possibly a milkshake, smoothie, or other cold beverage. The cup has a lid and a straw, suggesting a portable, ready-to-drink beverage. This design is typically used for coffee, iced tea, or juice. The cup's simple coloring may be meant to indicate the drink's contents or brand.",
    #     "voice": "nova"  # alloy echo fable onyx nova shimmer
    # }
body = await request.json()
try:
async with httpx.AsyncClient(http2=True) as client:
response = await client.post(url, headers=request_headers, json=body)
response.raise_for_status()
            # Assume the response is audio data and save it to a file
            # (raise_for_status() has already raised for error statuses)
            if response.status_code == 200:
                # Save the audio file
                with open('output.mp3', 'wb') as f:
                    f.write(response.content)
                print("Audio saved to output.mp3")
                # Play the audio asynchronously
                # (asyncio.to_thread avoids blocking the event loop)
                # await asyncio.to_thread(playsound, 'output.mp3')
                return True
            else:
                print(f"Request failed with status code: {response.status_code}")
                print(f"Response body: {response.text}")
                return False
    except httpx.RequestError as e:
        print(f"Request error: {e}")
        print("Traceback:")
        traceback.print_exc()
        return False
    except httpx.HTTPStatusError as e:
        print(f"HTTP error: {e}")
        print("Traceback:")
        traceback.print_exc()
        return False
    except Exception as e:
        print(f"Unexpected error: {e}")
        print("Traceback:")
        traceback.print_exc()
        return False
finally:
current_index += 1
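# Illustrative request (a sketch based on the commented-out sample above; the exact
# upstream schema is assumed, not verified):
#   POST /audio/speech
#   {"input": "Text to read aloud", "voice": "nova"}
# The JSON body is forwarded unchanged to the thinkbuddy TTS endpoint; on success the
# returned audio is written to output.mp3 and the route responds with JSON true.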
# Speech-to-text
@router.post("/audio/transcriptions")
async def transcriptions(request: Request, file: UploadFile = File(...)):
global current_index
url = 'https://api.thinkbuddy.ai/v1/content/transcribe'
params = {'enhance': 'true'}
try:
        # Read the uploaded file contents
        content = await safe_read_file(file)
        # Original content-type of the incoming request (kept for reference)
        content_type = request.headers.get('content-type')
        # files = {
        #     'file': (str(uuid4()),
        #              content,
        #              file.content_type or 'application/octet-stream')
        # }
        # The upstream API is sent a fixed filename/MIME type plus a "model" form field
        files = {
            'file': ('file.mp4', content, 'audio/mp4'),
            'model': (None, 'whisper-1')
        }
        # Log request details
logger.info(f"Received upload request for file: {file.filename}")
logger.info(f"Content-Type: {request.headers.get('content-type')}")
token_str = os.getenv('TOKEN', '')
token_array = token_str.split(',')
if len(token_array) > 0:
current_index = current_index % len(token_array)
print('transcriptions current index is ', current_index)
        request_headers = {**get_settings().HEADERS,
                           'authorization': f"Bearer {token_array[current_index]}",
                           'Accept': 'application/json, text/plain, */*',
                           # Don't forward the incoming Content-Type: httpx builds its own
                           # multipart body, and a forwarded boundary may not match it.
                           }
        # Set generous timeouts, since transcription can take a while
        timeout = httpx.Timeout(
            connect=30.0,  # connect timeout
            read=300.0,    # read timeout
            write=30.0,    # write timeout
            pool=30.0      # connection-pool timeout
        )
        # Send the request asynchronously with httpx
async with httpx.AsyncClient(http2=True, timeout=timeout) as client:
response = await client.post(url,
params=params,
headers=request_headers,
files=files)
current_index += 1
response.raise_for_status()
return response.json()
except httpx.TimeoutException:
        raise HTTPException(status_code=504, detail="Request to upstream server timed out")
except httpx.HTTPStatusError as e:
raise HTTPException(status_code=e.response.status_code, detail=str(e))
except Exception as e:
traceback.print_tb(e.__traceback__)
raise HTTPException(status_code=500, detail=str(e))
finally:
        # Clean up resources
await file.close()
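# Illustrative request (a sketch; field names follow the code above):
#   POST /audio/transcriptions  (multipart/form-data with a single "file" part)
# The upload is forwarded upstream as "file.mp4" (audio/mp4) together with a
# "model" form field of "whisper-1", and the upstream JSON transcription is returned.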
async def safe_read_file(file: UploadFile) -> Optional[bytes]:
"""安全地读取文件内容"""
try:
return await file.read()
except Exception as e:
logger.error(f"Error reading file: {str(e)}")
return None
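# Note: when safe_read_file returns None (read failure), the callers above still build
# the multipart body with a None payload, so the upstream request will most likely
# fail; guarding on None before forwarding would be a reasonable hardening.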
# File upload
@router.post("/upload")
async def upload_file(request: Request, file: UploadFile = File(...)):
global current_index
try:
        # Read the uploaded file contents
        content = await safe_read_file(file)
        # Original content-type of the incoming request (kept for reference)
        content_type = request.headers.get('content-type')
        files = {
            'file': (
                # str(uuid4()),
                file.filename,  # keep the original filename instead of a UUID
                content,
                file.content_type)
        }
        # Log request details
logger.info(f"Received upload request for file: {file.filename}")
logger.info(f"Content-Type: {request.headers.get('content-type')}")
token_str = os.getenv('TOKEN', '')
token_array = token_str.split(',')
if len(token_array) > 0:
current_index = current_index % len(token_array)
print('upload_file current index is ', current_index)
        request_headers = {**get_settings().HEADERS,
                           'authorization': f"Bearer {token_array[current_index]}",
                           'Accept': 'application/json, text/plain, */*',
                           # Don't forward the incoming Content-Type: httpx builds its own
                           # multipart body, and a forwarded boundary may not match it.
                           }
        # Send the request asynchronously with httpx
        async with httpx.AsyncClient() as client:
            response = await client.post(
                "https://api.thinkbuddy.ai/v1/uploads/images",
                headers=request_headers,
                files=files,
                timeout=100,
            )
current_index += 1
response.raise_for_status()
return response.json()
except httpx.TimeoutException:
        raise HTTPException(status_code=504, detail="Request to upstream server timed out")
    except httpx.HTTPStatusError as e:
        # raise HTTPException(status_code=e.response.status_code, detail=str(e))
        print(f"HTTPStatusError: {e}")
        print("Traceback:")
        traceback.print_exc()
    except Exception as e:
        # traceback.print_tb(e.__traceback__)
        # raise HTTPException(status_code=500, detail=str(e))
        print(f"Unexpected error: {e}")
        print("Traceback:")
        traceback.print_exc()
finally:
        # Clean up resources
await file.close()
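# Illustrative request (a sketch; the upstream image-upload schema is assumed):
#   POST /upload  (multipart/form-data with a single "file" part)
# The file is forwarded with its original filename to
# https://api.thinkbuddy.ai/v1/uploads/images and the upstream JSON is returned.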
@router.post("/chat/completions")
async def chat_completions(
request: ChatRequest, app_secret: str = Depends(verify_app_secret)
):
global current_index
logger.info("Entering chat_completions route")
# logger.info(f"Received request: {request}")
# logger.info(f"Received request json format: {json.dumps(request.dict(), indent=4)}")
# logger.info(f"Received request json format: {json.dumps(request.model_dump())}")
# logger.info(f"Received request json format: {json.dumps(request.model_dump(), indent=4, ensure_ascii=False, sort_keys=True, separators=(',', ': '))}")
logger.info(f"App secret: {app_secret}")
logger.info(f"Received chat completion request for model: {request.model}")
if request.model not in [model["id"] for model in ALLOWED_MODELS]:
raise HTTPException(
status_code=400,
detail=f"Model {request.model} is not allowed. Allowed models are: {', '.join(model['id'] for model in ALLOWED_MODELS)}",
)
if request.stream:
logger.info("Streaming response")
        # Flag to track whether any streamed chunks were produced
has_response = False
async def content_generator():
nonlocal has_response
try:
async for item in process_streaming_response(request, app_secret, current_index):
has_response = True
yield item
except Exception as e:
logger.error(f"Error in streaming response: {e}")
raise
response = StreamingResponse(
content_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Transfer-Encoding": "chunked"
}
)
    # Advance current_index before returning so successive requests rotate tokens
# if has_response:
# current_index += 1
current_index += 1
return response
    else:
        logger.info("Non-streaming response")
        # Non-streaming handling is currently disabled; fail explicitly rather than
        # silently returning null.
        # return await process_non_streaming_response(request)
        raise HTTPException(status_code=400, detail="Only streaming responses are supported; set stream=true")
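# The streaming branch returns Server-Sent Events (media type text/event-stream);
# process_streaming_response receives the pre-increment current_index, and the index
# is advanced once per request so tokens rotate across calls.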
@router.get('/')
@router.get('/healthz')
@router.get('/ready')
@router.get('/alive')
@router.get('/status')
@router.get("/health")
async def health_check(request: Request):
return Response(content=json.dumps({"status": "ok"}), media_type="application/json")
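# All of the liveness/readiness paths above ('/', '/healthz', '/ready', '/alive',
# '/status', '/health') resolve to the same handler and return {"status": "ok"}.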
@router.post("/env")
async def environment(app_secret: str = Depends(verify_app_secret)):
    # TOKEN is a comma-separated list; report how many tokens are configured
    token_str = os.getenv('TOKEN', '')
    length = len(token_str.split(',')) if token_str else 0
    return Response(
        content=json.dumps({
            "token": token_str,
            "length": length,
            "refresh_token": os.getenv("REFRESH_TOKEN", ""),
            "key": os.getenv("FIREBASE_API_KEY", ""),
        }),
        media_type="application/json",
    )
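# Note: /env echoes the raw TOKEN, REFRESH_TOKEN, and FIREBASE_API_KEY values;
# access is gated by verify_app_secret.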