Spaces:
Running
Running
import asyncio | |
import json | |
import os | |
import traceback | |
from typing import Optional | |
from uuid import uuid4 | |
import httpx | |
from fastapi import File, UploadFile | |
from fastapi import APIRouter, Response, Request, Depends, HTTPException | |
from fastapi.responses import StreamingResponse | |
from core.auth import verify_app_secret | |
from core.config import get_settings | |
from core.logger import setup_logger | |
from core.models import ChatRequest | |
from core.utils import process_streaming_response | |
from playsound import playsound # 用于播放音频 | |
# Logger shared by the handlers in this module.
logger = setup_logger(__name__)
# NOTE(review): no `@router.<method>(...)` decorators are visible on the
# handlers below in this copy of the file; confirm how they are registered.
router = APIRouter()
# Model catalogue from settings. list_models returns it verbatim, and
# chat_completions reads each entry's "id" key to validate request.model.
ALLOWED_MODELS = get_settings().ALLOWED_MODELS
# Round-robin cursor into the comma-separated TOKEN environment variable.
# Handlers mutate it via `global`: speech, transcriptions and upload_file
# reduce it modulo the token count before use, while chat_completions passes
# it to process_streaming_response without reducing it.
current_index: int = 0
async def list_models():
    """Return the allowed models wrapped in a list envelope.

    The envelope has the form ``{"object": "list", "data": [...], "success": True}``,
    where ``data`` is ``ALLOWED_MODELS`` exactly as loaded from settings.
    """
    return dict(object="list", data=ALLOWED_MODELS, success=True)
async def chat_completions_options():
    """Answer a CORS preflight request for the chat completions endpoint.

    The response is an empty 200 that allows any origin to send POST and
    OPTIONS requests with Content-Type and Authorization headers.
    """
    cors_headers = {
        "Access-Control-Allow-Origin": "*",
        "Access-Control-Allow-Methods": "POST, OPTIONS",
        "Access-Control-Allow-Headers": "Content-Type, Authorization",
    }
    return Response(headers=cors_headers, status_code=200)
# Image recognition (识图)
# Text-to-speech (文本转语音)
async def speech(request: Request):
    """Proxy a text-to-speech request to ThinkBuddy and save the reply to disk.

    The JSON body of *request* is forwarded unchanged to the upstream TTS
    endpoint. A sample payload the original author left in this function had
    the shape ``{"input": <text>, "voice": <name>}``, with voices ``alloy``,
    ``echo``, ``fable``, ``onyx``, ``nova`` and ``shimmer``; confirm this
    against the upstream API. On an HTTP 200 reply, the raw response bytes are
    written to ``output.mp3`` in the current working directory. The content
    type of the reply is not checked.

    Upstream bearer tokens come from the comma-separated ``TOKEN``
    environment variable. They are rotated round-robin through the
    module-level ``current_index`` counter, which is advanced once per
    attempted upstream call.

    NOTE(review): concurrent calls all write the same ``output.mp3``.

    Args:
        request: Incoming request whose JSON body is forwarded.

    Returns:
        ``True`` if the audio was fetched and saved, otherwise ``False``.
        Errors are logged and never raised.
    """
    global current_index
    url = 'https://api.thinkbuddy.ai/v1/content/speech/tts'
    # ''.split(',') is [''], so a bare split never yields an empty list.
    # Strip entries and drop blanks so an unset or trailing-comma TOKEN is
    # detected instead of sending "Bearer " with no token.
    tokens = [t.strip() for t in os.getenv('TOKEN', '').split(',') if t.strip()]
    if not tokens:
        logger.error("speech: TOKEN environment variable has no usable tokens")
        return False
    current_index = current_index % len(tokens)
    logger.info(f"speech current index is {current_index}")
    request_headers = {
        **get_settings().HEADERS,
        'authorization': f"Bearer {tokens[current_index]}",
        'Accept': 'application/json, text/plain, */*',
    }
    try:
        # Parse the body inside the try so that a malformed body takes the
        # same log-and-return-False path as every other failure.
        body = await request.json()
        async with httpx.AsyncClient(http2=True) as client:
            response = await client.post(url, headers=request_headers, json=body)
            response.raise_for_status()
            # raise_for_status accepts any 2xx; only a 200 is treated as audio.
            if response.status_code != 200:
                logger.warning(f"请求失败,状态码: {response.status_code}")
                logger.warning(f"响应内容: {response.text}")
                return False
            audio = response.content

        def _save_audio() -> None:
            # Write the audio bytes to a fixed path in the working directory.
            with open('output.mp3', 'wb') as f:
                f.write(audio)

        # Run the blocking file write in the default executor so the event
        # loop keeps serving other requests meanwhile.
        await asyncio.get_running_loop().run_in_executor(None, _save_audio)
        logger.info("音频文件已保存为 output.mp3")
        # Local playback is left disabled:
        # await asyncio.to_thread(playsound, 'output.mp3')
        return True
    except httpx.RequestError as e:
        logger.exception(f"请求错误: {e}")
        return False
    except httpx.HTTPStatusError as e:
        logger.exception(f"HTTP 错误: {e}")
        return False
    except Exception as e:
        logger.exception(f"发生错误: {e}")
        return False
    finally:
        # Rotate to the next token whether or not this attempt succeeded.
        current_index += 1
# Speech-to-text (语音转文本)
async def transcriptions(request: Request, file: UploadFile = File(...)):
    """Proxy an uploaded audio file to ThinkBuddy's transcription endpoint.

    The upload is re-sent as multipart/form-data with two parts:

    - a ``file`` part, always labelled ``file.mp4`` / ``audio/mp4`` whatever
      the client sent;
    - a ``model`` field set to ``whisper-1``.

    ``enhance=true`` is passed as a query parameter. Upstream bearer tokens
    are rotated round-robin through the comma-separated ``TOKEN`` environment
    variable, using the module-level ``current_index`` counter, which is
    advanced once per call.

    Args:
        request: Incoming request. Only its Content-Type header is read, for
            logging.
        file: The uploaded audio file. It is always closed before returning.

    Returns:
        The upstream response body decoded as JSON.

    Raises:
        HTTPException:
            - 400 if the upload could not be read.
            - 500 if ``TOKEN`` holds no usable token, or on any unexpected error.
            - 504 if the upstream call times out.
            - The upstream status code if the upstream replies with an HTTP error.
    """
    global current_index
    url = 'https://api.thinkbuddy.ai/v1/content/transcribe'
    params = {'enhance': 'true'}
    try:
        content = await safe_read_file(file)
        if content is None:
            # safe_read_file logs the failure and returns None. Do not forward
            # an empty part upstream.
            raise HTTPException(status_code=400, detail="Failed to read uploaded file")
        files = {
            'file': ('file.mp4', content, 'audio/mp4'),
            'model': (None, 'whisper-1'),
        }
        logger.info(f"Received upload request for file: {file.filename}")
        logger.info(f"Content-Type: {request.headers.get('content-type')}")

        # ''.split(',') is [''], so drop blank entries to detect a missing TOKEN.
        tokens = [t.strip() for t in os.getenv('TOKEN', '').split(',') if t.strip()]
        if not tokens:
            raise HTTPException(
                status_code=500,
                detail="TOKEN environment variable has no usable tokens",
            )
        current_index = current_index % len(tokens)
        logger.info(f"transcriptions current index is {current_index}")
        request_headers = {
            **get_settings().HEADERS,
            'authorization': f"Bearer {tokens[current_index]}",
            'Accept': 'application/json, text/plain, */*',
        }
        # Let httpx write the multipart Content-Type (with the boundary of the
        # body it actually encodes). The client's header describes the
        # incoming body and may be absent, and any Content-Type in
        # settings.HEADERS would not match a multipart body.
        request_headers = {
            k: v for k, v in request_headers.items() if k.lower() != 'content-type'
        }
        # Transcription can be slow, so allow a long read timeout.
        timeout = httpx.Timeout(connect=30.0, read=300.0, write=30.0, pool=30.0)
        async with httpx.AsyncClient(http2=True, timeout=timeout) as client:
            response = await client.post(url, params=params, headers=request_headers, files=files)
        response.raise_for_status()
        return response.json()
    except HTTPException:
        # Already a proper client-facing error; do not rewrap it as a 500.
        raise
    except httpx.TimeoutException as e:
        raise HTTPException(status_code=504, detail="请求目标服务器超时") from e
    except httpx.HTTPStatusError as e:
        raise HTTPException(status_code=e.response.status_code, detail=str(e)) from e
    except Exception as e:
        logger.exception(f"transcriptions failed: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Rotate tokens even when the upstream call failed or timed out.
        current_index += 1
        await file.close()
async def safe_read_file(file: UploadFile) -> Optional[bytes]:
    """Read the contents of *file*, returning ``None`` if reading fails.

    Any exception raised by ``file.read()`` is logged at error level and
    swallowed. Callers receive ``None`` instead of an exception.
    """
    data: Optional[bytes] = None
    try:
        data = await file.read()
    except Exception as exc:
        logger.error(f"Error reading file: {str(exc)}")
    return data
# File upload (文件上传)
async def upload_file(request: Request, file: UploadFile = File(...)):
    """Forward an uploaded file to ThinkBuddy's image-upload endpoint.

    The file is re-sent as multipart/form-data under the ``file`` part, with
    the client's original filename and content type. Upstream bearer tokens
    are rotated round-robin through the comma-separated ``TOKEN`` environment
    variable, using the module-level ``current_index`` counter, which is
    advanced once per call.

    Args:
        request: Incoming request. Only its Content-Type header is read, for
            logging.
        file: The uploaded file. It is always closed before returning.

    Returns:
        The upstream response body decoded as JSON. Returns ``None`` if the
        upstream replies with an HTTP error or any other unexpected error
        occurs; those errors are logged, not raised.
        NOTE(review): the client presumably then receives a 200 ``null``;
        consider raising as ``transcriptions`` does.

    Raises:
        HTTPException:
            - 400 if the upload could not be read.
            - 500 if ``TOKEN`` holds no usable token.
            - 504 if the upstream call times out.
    """
    global current_index
    try:
        content = await safe_read_file(file)
        if content is None:
            # safe_read_file logs the failure and returns None. Do not forward
            # an empty part upstream.
            raise HTTPException(status_code=400, detail="Failed to read uploaded file")
        # Keep the client's original filename rather than a generated one.
        files = {'file': (file.filename, content, file.content_type)}
        logger.info(f"Received upload request for file: {file.filename}")
        logger.info(f"Content-Type: {request.headers.get('content-type')}")

        # ''.split(',') is [''], so drop blank entries to detect a missing TOKEN.
        tokens = [t.strip() for t in os.getenv('TOKEN', '').split(',') if t.strip()]
        if not tokens:
            raise HTTPException(
                status_code=500,
                detail="TOKEN environment variable has no usable tokens",
            )
        current_index = current_index % len(tokens)
        logger.info(f"upload_file current index is {current_index}")
        request_headers = {
            **get_settings().HEADERS,
            'authorization': f"Bearer {tokens[current_index]}",
            'Accept': 'application/json, text/plain, */*',
        }
        # Let httpx write the multipart Content-Type (with the boundary of the
        # body it actually encodes). The client's header describes the
        # incoming body, and any Content-Type in settings.HEADERS would not
        # match a multipart body.
        request_headers = {
            k: v for k, v in request_headers.items() if k.lower() != 'content-type'
        }
        async with httpx.AsyncClient() as client:
            response = await client.post(
                "https://api.thinkbuddy.ai/v1/uploads/images",
                headers=request_headers,
                files=files,
                timeout=100,
            )
        response.raise_for_status()
        return response.json()
    except HTTPException:
        # Already a proper client-facing error; propagate unchanged.
        raise
    except httpx.TimeoutException as e:
        raise HTTPException(status_code=504, detail="请求目标服务器超时") from e
    except httpx.HTTPStatusError as e:
        logger.exception(f"HTTPStatusError发生错误: {e}")
        return None
    except Exception as e:
        logger.exception(f"发生错误: {e}")
        return None
    finally:
        # Rotate tokens even when the upstream call failed or timed out.
        current_index += 1
        await file.close()
async def chat_completions(
    request: ChatRequest, app_secret: str = Depends(verify_app_secret)
):
    """Serve a chat completion as a ``text/event-stream`` response.

    Args:
        request: Parsed chat request. ``request.model`` must be one of the
            ``"id"`` values in ``ALLOWED_MODELS``, and ``request.stream``
            must be true.
        app_secret: Value resolved by the ``verify_app_secret`` dependency.
            It is forwarded to ``process_streaming_response`` and never logged.

    Returns:
        A ``StreamingResponse`` that relays every item yielded by
        ``process_streaming_response``.

    Raises:
        HTTPException: 400 if the model is not allowed, or if a non-streaming
            response is requested (that mode is not implemented here).
    """
    global current_index
    logger.info("Entering chat_completions route")
    # The app secret is a credential, so its value is deliberately not logged.
    logger.info(f"Received chat completion request for model: {request.model}")

    allowed_ids = [model["id"] for model in ALLOWED_MODELS]
    if request.model not in allowed_ids:
        raise HTTPException(
            status_code=400,
            detail=f"Model {request.model} is not allowed. Allowed models are: {', '.join(allowed_ids)}",
        )

    if not request.stream:
        logger.info("Non-streaming response")
        # Non-streaming mode is not implemented. Reject it explicitly rather
        # than returning an empty (None) body.
        raise HTTPException(
            status_code=400,
            detail="Non-streaming responses are not supported; set stream to true",
        )

    logger.info("Streaming response")
    # Capture the token index now. The generator body below only runs once the
    # response starts being iterated, which is after this function returns.
    # Reading the global there would see the increment below, plus any
    # increments made by concurrent requests in the meantime.
    token_index = current_index
    current_index += 1

    async def content_generator():
        # Relay upstream items unchanged, logging any failure before re-raising.
        try:
            async for item in process_streaming_response(request, app_secret, token_index):
                yield item
        except Exception as e:
            logger.exception(f"Error in streaming response: {e}")
            raise

    return StreamingResponse(
        content_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Transfer-Encoding": "chunked",
        },
    )
async def health_check(request: Request):
    """Liveness endpoint that always answers with the JSON ``{"status": "ok"}``."""
    body = json.dumps({"status": "ok"})
    return Response(media_type="application/json", content=body)
async def environment(app_secret: str = Depends(verify_app_secret)):
    """Report the upstream credentials configured through environment variables.

    The JSON body has four fields:

    - ``token``: the raw ``TOKEN`` string.
    - ``length``: the number of usable, non-blank comma-separated tokens in it.
    - ``refresh_token``: the value of ``REFRESH_TOKEN``.
    - ``key``: the value of ``FIREBASE_API_KEY``.

    Unset variables are reported as empty strings.

    NOTE(review): this endpoint returns secrets in plaintext and is gated only
    by ``verify_app_secret``. Consider masking these values or removing the
    endpoint.
    """
    token_str = os.getenv("TOKEN", "")
    # str.split never returns None, and ''.split(',') == [''], so count only
    # non-blank entries. An unset TOKEN therefore reports 0, not 1.
    length = sum(1 for token in token_str.split(",") if token.strip())
    payload = {
        "token": token_str,
        "length": length,
        "refresh_token": os.getenv("REFRESH_TOKEN", ""),
        "key": os.getenv("FIREBASE_API_KEY", ""),
    }
    return Response(content=json.dumps(payload), media_type="application/json")