Spaces:
Running
Running
File size: 11,778 Bytes
feb939c eeaab51 feb939c 38f0a99 feb939c 38f0a99 feb939c 38f0a99 feb939c 38f0a99 feb939c 38f0a99 feb939c 38f0a99 feb939c 38f0a99 feb939c 38f0a99 feb939c 3e71e74 eeaab51 0a3d38f feb939c 38f0a99 feb939c 38f0a99 feb939c 0a3d38f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 |
import asyncio
import json
import os
import traceback
from typing import Optional
from uuid import uuid4
import httpx
from fastapi import File, UploadFile
from fastapi import APIRouter, Response, Request, Depends, HTTPException
from fastapi.responses import StreamingResponse
from core.auth import verify_app_secret
from core.config import get_settings
from core.logger import setup_logger
from core.models import ChatRequest
from core.utils import process_streaming_response
from playsound import playsound  # for playing audio locally
# from rich import print
# Module-level logger and router shared by all endpoints below.
logger = setup_logger(__name__)
router = APIRouter()
# Models accepted by /chat/completions, sourced from application settings.
ALLOWED_MODELS = get_settings().ALLOWED_MODELS
# Round-robin cursor into the comma-separated TOKEN env var, mutated by the
# route handlers below. NOTE(review): shared mutable module state with no
# synchronization across concurrent requests — confirm this is acceptable.
current_index = 0
@router.get("/models")
async def list_models():
    """Return the catalogue of models this proxy accepts."""
    payload = {"object": "list", "data": ALLOWED_MODELS, "success": True}
    return payload
@router.options("/chat/completions")
async def chat_completions_options():
    """Answer CORS preflight requests for the chat-completions endpoint."""
    cors_headers = {
        "Access-Control-Allow-Origin": "*",
        "Access-Control-Allow-Methods": "POST, OPTIONS",
        "Access-Control-Allow-Headers": "Content-Type, Authorization",
    }
    return Response(status_code=200, headers=cors_headers)
# Image recognition
# Image recognition
# Text-to-speech
@router.post("/audio/speech")
async def speech(request: Request):
    """Proxy a text-to-speech request to ThinkBuddy's TTS endpoint.

    The incoming JSON body (e.g. {"input": ..., "voice": "nova"}) is
    forwarded verbatim. On success the returned audio is written to
    output.mp3 and the route returns True; on any failure it returns
    False (interface kept for existing callers). The round-robin token
    cursor is advanced in all cases via the finally block.
    """
    global current_index
    url = 'https://api.thinkbuddy.ai/v1/content/speech/tts'
    # str.split(',') always yields at least one element (possibly ''),
    # so the modulo below is safe; the original `if len(...) > 0` guard
    # was a no-op and has been dropped.
    token_array = os.getenv('TOKEN', '').split(',')
    current_index = current_index % len(token_array)
    logger.info("speech current index is %s", current_index)
    request_headers = {
        **get_settings().HEADERS,
        'authorization': f"Bearer {token_array[current_index]}",
        'Accept': 'application/json, text/plain, */*',
    }
    body = await request.json()
    try:
        async with httpx.AsyncClient(http2=True) as client:
            response = await client.post(url, headers=request_headers, json=body)
            # Raises HTTPStatusError on 4xx/5xx, so no extra status check
            # is needed afterwards.
            response.raise_for_status()
            # Persist the upstream audio payload locally.
            with open('output.mp3', 'wb') as f:
                f.write(response.content)
            logger.info("audio saved to output.mp3")
            return True
    except httpx.HTTPStatusError as e:
        logger.error("HTTP error from TTS upstream: %s", e)
        traceback.print_exc()
        return False
    except httpx.RequestError as e:
        logger.error("request error contacting TTS upstream: %s", e)
        traceback.print_exc()
        return False
    except Exception as e:
        logger.error("unexpected error in speech: %s", e)
        traceback.print_exc()
        return False
    finally:
        # Advance the round-robin cursor regardless of outcome
        # (matches the original behavior).
        current_index += 1
# Speech-to-text
@router.post("/audio/transcriptions")
async def transcriptions(request: Request, file: UploadFile = File(...)):
    """Proxy an uploaded audio file to ThinkBuddy's transcription API.

    Returns the upstream JSON on success. Raises HTTPException 504 on
    timeout, the upstream status code on HTTP errors, and 500 otherwise.
    The uploaded file is always closed.
    """
    global current_index
    url = 'https://api.thinkbuddy.ai/v1/content/transcribe'
    params = {'enhance': 'true'}
    try:
        content = await safe_read_file(file)
        # NOTE(review): content may be None if the read failed; upstream
        # would then receive an empty part — consider rejecting with 400.
        content_type = request.headers.get('content-type')
        # Upstream expects the part named 'file'; filename/MIME are fixed.
        files = {
            'file': ('file.mp4', content, 'audio/mp4'),
            'model': (None, 'whisper-1')
        }
        logger.info(f"Received upload request for file: {file.filename}")
        logger.info(f"Content-Type: {request.headers.get('content-type')}")
        # str.split(',') always yields at least one element, so the modulo
        # is safe; the original `if len(...) > 0` guard was a no-op.
        token_array = os.getenv('TOKEN', '').split(',')
        current_index = current_index % len(token_array)
        logger.info("transcriptions current index is %s", current_index)
        request_headers = {
            **get_settings().HEADERS,
            'authorization': f"Bearer {token_array[current_index]}",
            'Accept': 'application/json, text/plain, */*',
            # NOTE(review): forwarding the client's multipart Content-Type
            # (with its original boundary) while httpx re-encodes `files`
            # with a new boundary looks suspect — confirm the upstream
            # actually parses this correctly.
            'Content-Type': content_type,
        }
        # Generous timeouts: transcribing large files can be slow.
        timeout = httpx.Timeout(
            connect=30.0,
            read=300.0,
            write=30.0,
            pool=30.0
        )
        async with httpx.AsyncClient(http2=True, timeout=timeout) as client:
            response = await client.post(url,
                                         params=params,
                                         headers=request_headers,
                                         files=files)
            # NOTE(review): the cursor only advances when the POST itself
            # succeeds (unlike speech(), which advances in finally) —
            # preserved as-is; confirm which policy is intended.
            current_index += 1
            response.raise_for_status()
            return response.json()
    except httpx.TimeoutException:
        raise HTTPException(status_code=504, detail="请求目标服务器超时")
    except httpx.HTTPStatusError as e:
        raise HTTPException(status_code=e.response.status_code, detail=str(e))
    except Exception as e:
        traceback.print_tb(e.__traceback__)
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Always release the upload's underlying file handle.
        await file.close()
async def safe_read_file(file: UploadFile) -> Optional[bytes]:
    """Read the upload's bytes, returning None instead of raising on failure."""
    try:
        data = await file.read()
    except Exception as e:
        logger.error(f"Error reading file: {str(e)}")
        return None
    return data
# File upload
@router.post("/upload")
async def upload_file(request: Request, file: UploadFile = File(...)):
    """Proxy an uploaded image to ThinkBuddy's image-upload endpoint.

    Returns the upstream JSON on success and raises 504 on timeout.
    Other errors are logged and the route implicitly returns None
    (serialized as JSON null) — preserved for existing callers.
    """
    global current_index
    try:
        content = await safe_read_file(file)
        content_type = request.headers.get('content-type')
        # Keep the client's original filename for the upstream part.
        files = {
            'file': (file.filename, content, file.content_type)
        }
        logger.info(f"Received upload request for file: {file.filename}")
        logger.info(f"Content-Type: {request.headers.get('content-type')}")
        # str.split(',') always yields at least one element, so the modulo
        # is safe; the original `if len(...) > 0` guard was a no-op.
        token_array = os.getenv('TOKEN', '').split(',')
        current_index = current_index % len(token_array)
        logger.info("upload_file current index is %s", current_index)
        request_headers = {
            **get_settings().HEADERS,
            'authorization': f"Bearer {token_array[current_index]}",
            'Accept': 'application/json, text/plain, */*',
            # NOTE(review): forwarding the client's multipart boundary while
            # httpx re-encodes `files` looks suspect — verify upstream.
            'Content-Type': content_type,
        }
        async with httpx.AsyncClient() as client:
            response = await client.post(
                "https://api.thinkbuddy.ai/v1/uploads/images",
                headers=request_headers, files=files, timeout=100)
            current_index += 1
            response.raise_for_status()
            return response.json()
    except httpx.TimeoutException:
        raise HTTPException(status_code=504, detail="请求目标服务器超时")
    except httpx.HTTPStatusError as e:
        # NOTE(review): error is swallowed and the route returns null —
        # kept to match original behavior; consider re-raising instead.
        logger.exception("HTTPStatusError while uploading: %s", e)
    except Exception as e:
        logger.exception("Unexpected error while uploading: %s", e)
    finally:
        # Always release the upload's underlying file handle.
        await file.close()
@router.post("/chat/completions")
async def chat_completions(
    request: ChatRequest, app_secret: str = Depends(verify_app_secret)
):
    """Validate the requested model and stream the upstream chat completion.

    Raises 400 when the model is not in ALLOWED_MODELS. Only streaming
    requests are implemented; non-streaming requests currently return
    None (JSON null) — see the note at the bottom.
    """
    global current_index
    logger.info("Entering chat_completions route")
    logger.info(f"App secret: {app_secret}")
    logger.info(f"Received chat completion request for model: {request.model}")
    if request.model not in [model["id"] for model in ALLOWED_MODELS]:
        raise HTTPException(
            status_code=400,
            detail=f"Model {request.model} is not allowed. Allowed models are: {', '.join(model['id'] for model in ALLOWED_MODELS)}",
        )
    if request.stream:
        logger.info("Streaming response")

        async def content_generator():
            # Dead `has_response` bookkeeping removed: the flag was written
            # here but never read anywhere after the conditional increment
            # below was commented out.
            try:
                async for item in process_streaming_response(request, app_secret, current_index):
                    yield item
            except Exception as e:
                logger.error(f"Error in streaming response: {e}")
                raise

        response = StreamingResponse(
            content_generator(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "Transfer-Encoding": "chunked"
            }
        )
        # NOTE(review): the generator body runs only when the response is
        # iterated, i.e. *after* this increment — so it observes the
        # already-advanced cursor. Preserved as-is; confirm intended.
        current_index += 1
        return response
    else:
        logger.info("Non-streaming response")
        # TODO(review): non-streaming support was never wired up (the call
        # to process_non_streaming_response is commented out upstream of
        # this edit); the route currently returns None (JSON null).
@router.route('/')
@router.route('/healthz')
@router.route('/ready')
@router.route('/alive')
@router.route('/status')
@router.get("/health")
async def health_check(request: Request):
    """Liveness/readiness probe: always reports ok on every health path."""
    body = json.dumps({"status": "ok"})
    return Response(content=body, media_type="application/json")
@router.post("/env")
async def environment(app_secret: str = Depends(verify_app_secret)):
    """Expose token-related environment variables (guarded by app secret).

    NOTE(review): this returns raw secrets (tokens, API key) to the
    caller — confirm this endpoint is intentionally exposed.
    """
    # str.split(',') never returns None, so the original `is not None`
    # guard was always true; the length is simply the number of
    # comma-separated tokens (>= 1, even for an empty TOKEN).
    length = len(os.getenv('TOKEN', '').split(','))
    return Response(
        content=json.dumps({
            "token": os.getenv("TOKEN", ""),
            "length": length,
            "refresh_token": os.getenv("REFRESH_TOKEN", ""),
            "key": os.getenv("FIREBASE_API_KEY", ""),
        }),
        media_type="application/json",
    )