import asyncio
import json
import os
import traceback
from typing import Optional
from uuid import uuid4

import httpx
from fastapi import File, UploadFile
from fastapi import APIRouter, Response, Request, Depends, HTTPException
from fastapi.responses import StreamingResponse

from core.auth import verify_app_secret
from core.config import get_settings
from core.logger import setup_logger
from core.models import ChatRequest
from core.utils import process_streaming_response
from playsound import playsound  # audio playback (currently disabled; see speech())

# from rich import print

logger = setup_logger(__name__)
router = APIRouter()
ALLOWED_MODELS = get_settings().ALLOWED_MODELS

# Round-robin cursor into the comma-separated TOKEN list, shared by every
# endpoint in this module. Each upstream request consumes one position.
current_index = 0


def _next_token(label: str) -> str:
    """Return the bearer token at the rotation cursor and advance the cursor.

    Tokens come from the comma-separated ``TOKEN`` environment variable, which
    is re-read on every call. Surrounding whitespace and empty entries (e.g.
    from a trailing comma) are ignored.

    Args:
        label: Endpoint name, used only in the log line.

    Raises:
        HTTPException: 500 if ``TOKEN`` contains no usable entry.
    """
    global current_index
    tokens = [t.strip() for t in os.getenv("TOKEN", "").split(",") if t.strip()]
    # str.split() always yields at least one element, so a plain length check
    # can never catch an unset TOKEN; check for real entries instead.
    if not tokens:
        raise HTTPException(status_code=500, detail="TOKEN 未配置")
    current_index %= len(tokens)
    token = tokens[current_index]
    logger.info(f"{label} current index is {current_index}")
    current_index += 1
    return token


def _upstream_headers(token: str, *, multipart: bool = False) -> dict:
    """Build ThinkBuddy request headers from the configured defaults.

    Args:
        token: Bearer token for the ``authorization`` header.
        multipart: When true, drop any ``Content-Type`` so httpx sets
            ``multipart/form-data`` with the boundary of the body it encodes.
    """
    headers = {
        **get_settings().HEADERS,
        "authorization": f"Bearer {token}",
        "Accept": "application/json, text/plain, */*",
    }
    if multipart:
        headers = {k: v for k, v in headers.items() if k.lower() != "content-type"}
    return headers


@router.get("/models")
async def list_models():
    """List the models this proxy accepts (``ALLOWED_MODELS`` from settings)."""
    return {"object": "list", "data": ALLOWED_MODELS, "success": True}


@router.options("/chat/completions")
async def chat_completions_options():
    """Answer CORS preflight requests for ``/chat/completions``."""
    return Response(
        status_code=200,
        headers={
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Methods": "POST, OPTIONS",
            "Access-Control-Allow-Headers": "Content-Type, Authorization",
        },
    )


# Image recognition: no endpoint implemented here yet.


@router.post("/audio/speech")
async def speech(request: Request):
    """Text-to-speech: forward the client's JSON body to the ThinkBuddy TTS API.

    Example body: ``{"input": "...", "voice": "nova"}`` (voices listed by the
    original author: alloy, echo, fable, onyx, nova, shimmer).

    On a 200 reply the audio bytes are written to ``output.mp3`` in the
    current working directory and ``True`` is returned; the audio itself is
    NOT sent back to the client. Network errors, non-200 replies and other
    failures are logged and yield ``False``.

    Raises:
        HTTPException: 400 if the body is not valid JSON; 500 if no upstream
            token is configured.
    """
    url = "https://api.thinkbuddy.ai/v1/content/speech/tts"
    try:
        body = await request.json()
    except ValueError as e:  # covers JSONDecodeError and UnicodeDecodeError
        raise HTTPException(status_code=400, detail="请求体不是有效的 JSON") from e

    request_headers = _upstream_headers(_next_token("speech"))
    try:
        async with httpx.AsyncClient(http2=True) as client:
            response = await client.post(url, headers=request_headers, json=body)
            response.raise_for_status()
        if response.status_code != 200:
            # raise_for_status() lets any 2xx through; only 200 is treated as audio.
            logger.error(f"请求失败,状态码: {response.status_code}")
            logger.error(f"响应内容: {response.text}")
            return False
        # Deliberately synchronous: all requests share this single path, and a
        # blocking write cannot interleave with another request's write.
        with open("output.mp3", "wb") as f:
            f.write(response.content)
        logger.info("音频文件已保存为 output.mp3")
        # Optional local playback without blocking the event loop:
        # await asyncio.to_thread(playsound, "output.mp3")
        return True
    except httpx.HTTPStatusError as e:
        logger.error(f"HTTP 错误: {e}")
        traceback.print_exc()
        return False
    except Exception as e:
        logger.error(f"发生错误: {e}")
        traceback.print_exc()
        return False


async def safe_read_file(file: UploadFile) -> Optional[bytes]:
    """Read an uploaded file completely; log and return ``None`` on failure."""
    try:
        return await file.read()
    except Exception as e:
        logger.error(f"Error reading file: {str(e)}")
        return None


@router.post("/audio/transcriptions")
async def transcriptions(request: Request, file: UploadFile = File(...)):
    """Speech-to-text: forward an uploaded audio file to ThinkBuddy's transcriber.

    The upload is always sent upstream as ``file.mp4`` with type
    ``audio/mp4``, together with the form field ``model=whisper-1`` and the
    query parameter ``enhance=true``. Returns the upstream JSON unchanged.

    Raises:
        HTTPException: 400 if the upload cannot be read; 500 if no token is
            configured or on unexpected errors; 504 on upstream timeout; the
            upstream status code for non-2xx replies.
    """
    url = "https://api.thinkbuddy.ai/v1/content/transcribe"
    params = {"enhance": "true"}
    # Long read timeout (300 s) so slow transcriptions are not cut off.
    timeout = httpx.Timeout(connect=30.0, read=300.0, write=30.0, pool=30.0)
    try:
        content = await safe_read_file(file)
        if content is None:
            raise HTTPException(status_code=400, detail="无法读取上传的文件")

        logger.info(f"Received upload request for file: {file.filename}")
        logger.info(f"Content-Type: {request.headers.get('content-type')}")

        files = {
            "file": ("file.mp4", content, "audio/mp4"),
            "model": (None, "whisper-1"),
        }
        request_headers = _upstream_headers(_next_token("transcriptions"), multipart=True)

        async with httpx.AsyncClient(http2=True, timeout=timeout) as client:
            response = await client.post(url, params=params, headers=request_headers, files=files)
        response.raise_for_status()
        return response.json()
    except HTTPException:
        # Already a deliberate HTTP error; don't re-wrap it as a 500 below.
        raise
    except httpx.TimeoutException as e:
        raise HTTPException(status_code=504, detail="请求目标服务器超时") from e
    except httpx.HTTPStatusError as e:
        raise HTTPException(status_code=e.response.status_code, detail=str(e)) from e
    except Exception as e:
        traceback.print_tb(e.__traceback__)
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        await file.close()


@router.post("/upload")
async def upload_file(request: Request, file: UploadFile = File(...)):
    """Forward an uploaded image to ThinkBuddy's image-upload endpoint.

    The client's original filename and content type are passed through.
    Returns the upstream JSON payload.

    Raises:
        HTTPException: 400 if the upload cannot be read; 500 if no token is
            configured or on unexpected errors; 504 on upstream timeout; the
            upstream status code for non-2xx replies.
    """
    url = "https://api.thinkbuddy.ai/v1/uploads/images"
    try:
        content = await safe_read_file(file)
        if content is None:
            raise HTTPException(status_code=400, detail="无法读取上传的文件")

        logger.info(f"Received upload request for file: {file.filename}")
        logger.info(f"Content-Type: {request.headers.get('content-type')}")

        files = {"file": (file.filename, content, file.content_type)}
        request_headers = _upstream_headers(_next_token("upload_file"), multipart=True)

        async with httpx.AsyncClient() as client:
            response = await client.post(url, headers=request_headers, files=files, timeout=100)
        response.raise_for_status()
        return response.json()
    except HTTPException:
        raise
    except httpx.TimeoutException as e:
        raise HTTPException(status_code=504, detail="请求目标服务器超时") from e
    except httpx.HTTPStatusError as e:
        # Previously logged and then returned None (HTTP 200 with a null body),
        # hiding the failure from the client.
        logger.error(f"HTTPStatusError发生错误: {e}")
        traceback.print_exc()
        raise HTTPException(status_code=e.response.status_code, detail=str(e)) from e
    except Exception as e:
        logger.error(f"发生错误: {e}")
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        await file.close()


@router.post("/chat/completions")
async def chat_completions(
    request: ChatRequest, app_secret: str = Depends(verify_app_secret)
):
    """OpenAI-style chat completion, streamed via ``process_streaming_response``.

    Only streaming requests are supported; the reply is a
    ``text/event-stream`` response.

    Raises:
        HTTPException: 400 if ``request.model`` is not in ``ALLOWED_MODELS``;
            501 if ``request.stream`` is false.
    """
    global current_index
    logger.info("Entering chat_completions route")
    request_json = json.dumps(
        request.model_dump(),
        indent=4,
        ensure_ascii=False,
        sort_keys=True,
        separators=(",", ": "),
    )
    logger.info(f"Received request json format: {request_json}")
    # app_secret comes from the verify_app_secret dependency; never log its value.
    logger.info("App secret received")
    logger.info(f"Received chat completion request for model: {request.model}")

    allowed_ids = [model["id"] for model in ALLOWED_MODELS]
    if request.model not in allowed_ids:
        raise HTTPException(
            status_code=400,
            detail=f"Model {request.model} is not allowed. Allowed models are: {', '.join(allowed_ids)}",
        )

    if not request.stream:
        logger.info("Non-streaming response")
        # Non-streaming mode was never implemented; the route used to fall
        # through and answer HTTP 200 with a null body.
        raise HTTPException(status_code=501, detail="Non-streaming responses are not supported")

    logger.info("Streaming response")
    # Capture the cursor now: the generator body only runs once the response
    # starts streaming, after the global below has already been advanced.
    token_index = current_index
    current_index += 1

    async def content_generator():
        try:
            async for item in process_streaming_response(request, app_secret, token_index):
                yield item
        except Exception as e:
            logger.error(f"Error in streaming response: {e}")
            raise

    return StreamingResponse(
        content_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Transfer-Encoding": "chunked",
        },
    )


@router.route("/")
@router.route("/healthz")
@router.route("/ready")
@router.route("/alive")
@router.route("/status")
@router.get("/health")
async def health_check(request: Request):
    """Health probe served on several common paths; always ``{"status": "ok"}``."""
    return Response(content=json.dumps({"status": "ok"}), media_type="application/json")


@router.post("/env")
async def environment(app_secret: str = Depends(verify_app_secret)):
    """Return the upstream credentials currently held in the environment.

    SECURITY(review): this hands TOKEN, REFRESH_TOKEN and FIREBASE_API_KEY to
    any caller that passes ``verify_app_secret``. Consider removing it or
    limiting it to debug deployments.
    """
    return Response(
        content=json.dumps(
            {
                "token": os.getenv("TOKEN", ""),
                "refresh_token": os.getenv("REFRESH_TOKEN", ""),
                "key": os.getenv("FIREBASE_API_KEY", ""),
            }
        ),
        media_type="application/json",
    )