kevin commited on
Commit
38f0a99
·
1 Parent(s): e663277

多token轮询

Browse files
Files changed (2) hide show
  1. core/router.py +122 -77
  2. core/utils.py +84 -71
core/router.py CHANGED
@@ -23,6 +23,8 @@ logger = setup_logger(__name__)
23
  router = APIRouter()
24
  ALLOWED_MODELS = get_settings().ALLOWED_MODELS
25
 
 
 
26
  @router.get("/models")
27
  async def list_models():
28
  return {"object": "list", "data": ALLOWED_MODELS, "success": True}
@@ -43,56 +45,65 @@ async def chat_completions_options():
43
  # 文本转语音
44
  @router.post("/audio/speech")
45
  async def speech(request: Request):
 
46
  url = 'https://api.thinkbuddy.ai/v1/content/speech/tts'
47
- request_headers = {**get_settings().HEADERS,
48
- 'authorization': f"Bearer {os.getenv('TOKEN', '')}",
49
- 'Accept': 'application/json, text/plain, */*',
50
- }
51
- # data = {
52
- # "input": "这是一张插图,显示了一杯饮料,可能是奶昔、冰沙或其他冷饮。杯子上有一个盖子和一根吸管,表明这是一种便于携带和饮用的饮品。这种设计通常用于提供咖啡、冰茶或果汁等饮品。杯子颜色简约,可能用于说明饮品的内容或品牌。",
53
- # "voice": "nova" # alloy echo fable onyx nova shimmer
54
- # }
55
- body = await request.json()
56
- try:
57
- async with httpx.AsyncClient(http2=True) as client:
58
- response = await client.post(url, headers=request_headers, json=body)
59
- response.raise_for_status()
60
-
61
- # 假设响应是音频数据,保存为文件
62
- if response.status_code == 200:
63
- # 保存音频文件
64
- with open('output.mp3', 'wb') as f:
65
- f.write(response.content)
66
- print("音频文件已保存为 output.mp3")
67
-
68
- # 异步播放音频
69
- # 使用 asyncio.to_thread 来避免阻塞事件循环
70
- # await asyncio.to_thread(playsound, 'output.mp3')
71
- return True
72
- else:
73
- print(f"请求失败,状态码: {response.status_code}")
74
- print(f"响应内容: {response.text}")
75
- return False
76
-
77
- except httpx.RequestError as e:
78
- print(f"请求错误: {e}")
79
- print("错误堆栈:")
80
- traceback.print_exc()
81
- return False
82
- except httpx.HTTPStatusError as e:
83
- print(f"HTTP 错误: {e}")
84
- print("错误堆栈:")
85
- traceback.print_exc()
86
- return False
87
- except Exception as e:
88
- print(f"发生错误: {e}")
89
- print("错误堆栈:")
90
- traceback.print_exc()
91
- return False
 
 
 
 
 
 
 
92
 
93
  # 语音转文本
94
  @router.post("/audio/transcriptions")
95
  async def transcriptions(request: Request, file: UploadFile = File(...)):
 
96
  url = 'https://api.thinkbuddy.ai/v1/content/transcribe'
97
  params = {'enhance': 'true'}
98
  try:
@@ -112,26 +123,32 @@ async def transcriptions(request: Request, file: UploadFile = File(...)):
112
  # 记录请求信息
113
  logger.info(f"Received upload request for file: {file.filename}")
114
  logger.info(f"Content-Type: {request.headers.get('content-type')}")
115
- request_headers = {**get_settings().HEADERS,
116
- 'authorization': f"Bearer {os.getenv('TOKEN', '')}",
117
- 'Accept': 'application/json, text/plain, */*',
118
- 'Content-Type': content_type,
119
- }
120
- # 设置较长的超时时间
121
- timeout = httpx.Timeout(
122
- connect=30.0, # 连接超时
123
- read=300.0, # 读取超时
124
- write=30.0, # 写入超时
125
- pool=30.0 # 连接池超时
126
- )
127
- # 使用httpx发送异步请求
128
- async with httpx.AsyncClient(http2=True, timeout=timeout) as client:
129
- response = await client.post(url,
130
- params=params,
131
- headers=request_headers,
132
- files=files)
133
- response.raise_for_status()
134
- return response.json()
 
 
 
 
 
 
135
 
136
  except httpx.TimeoutException:
137
  raise HTTPException(status_code=504, detail="请求目标服务器超时")
@@ -153,6 +170,7 @@ async def safe_read_file(file: UploadFile) -> Optional[bytes]:
153
  # 文件上传
154
  @router.post("/upload")
155
  async def upload_file(request: Request, file: UploadFile = File(...)):
 
156
  try:
157
  # 读取文件内容
158
  content = await safe_read_file(file)
@@ -168,16 +186,22 @@ async def upload_file(request: Request, file: UploadFile = File(...)):
168
  # 记录请求信息
169
  logger.info(f"Received upload request for file: {file.filename}")
170
  logger.info(f"Content-Type: {request.headers.get('content-type')}")
171
- request_headers = {**get_settings().HEADERS,
172
- 'authorization': f"Bearer {os.getenv('TOKEN', '')}",
173
- 'Accept': 'application/json, text/plain, */*',
174
- 'Content-Type': content_type,
175
- }
176
- # 使用httpx发送异步请求
177
- async with httpx.AsyncClient() as client:
178
- response = await client.post(f"https://api.thinkbuddy.ai/v1/uploads/images", headers=request_headers,files=files, timeout=100)
179
- response.raise_for_status()
180
- return response.json()
 
 
 
 
 
 
181
 
182
  except httpx.TimeoutException:
183
  raise HTTPException(status_code=504, detail="请求目标服务器超时")
@@ -200,6 +224,7 @@ async def upload_file(request: Request, file: UploadFile = File(...)):
200
  async def chat_completions(
201
  request: ChatRequest, app_secret: str = Depends(verify_app_secret)
202
  ):
 
203
  logger.info("Entering chat_completions route")
204
  # logger.info(f"Received request: {request}")
205
  # logger.info(f"Received request json format: {json.dumps(request.dict(), indent=4)}")
@@ -216,8 +241,22 @@ async def chat_completions(
216
 
217
  if request.stream:
218
  logger.info("Streaming response")
219
- return StreamingResponse(
220
- process_streaming_response(request, app_secret),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
  media_type="text/event-stream",
222
  headers={
223
  "Cache-Control": "no-cache",
@@ -225,6 +264,12 @@ async def chat_completions(
225
  "Transfer-Encoding": "chunked"
226
  }
227
  )
 
 
 
 
 
 
228
  else:
229
  logger.info("Non-streaming response")
230
  # return await process_non_streaming_response(request)
 
23
  router = APIRouter()
24
  ALLOWED_MODELS = get_settings().ALLOWED_MODELS
25
 
26
+ current_index = 0
27
+
28
  @router.get("/models")
29
  async def list_models():
30
  return {"object": "list", "data": ALLOWED_MODELS, "success": True}
 
45
  # 文本转语音
46
  @router.post("/audio/speech")
47
  async def speech(request: Request):
48
+ global current_index
49
  url = 'https://api.thinkbuddy.ai/v1/content/speech/tts'
50
+ token_str = os.getenv('TOKEN', '')
51
+ token_array = token_str.split(',')
52
+ if len(token_array) > 0:
53
+ current_index = current_index % len(token_array)
54
+ print('speech current index is ', current_index)
55
+ request_headers = {**get_settings().HEADERS,
56
+ 'authorization': f"Bearer {token_array[current_index]}",
57
+ 'Accept': 'application/json, text/plain, */*',
58
+ }
59
+ # data = {
60
+ # "input": "这是一张插图,显示了一杯饮料,可能是奶昔、冰沙或其他冷饮。杯子上有一个盖子和一根吸管,表明这是一种便于携带和饮用的饮品。这种设计通常用于提供咖啡、冰茶或果汁等饮品。杯子颜色简约,可能用于说明饮品的内容或品牌。",
61
+ # "voice": "nova" # alloy echo fable onyx nova shimmer
62
+ # }
63
+ body = await request.json()
64
+ try:
65
+ async with httpx.AsyncClient(http2=True) as client:
66
+ response = await client.post(url, headers=request_headers, json=body)
67
+ response.raise_for_status()
68
+
69
+ # 假设响应是音频数据,保存为文件
70
+ if response.status_code == 200:
71
+ # 保存音频文件
72
+ with open('output.mp3', 'wb') as f:
73
+ f.write(response.content)
74
+ print("音频文件已保存为 output.mp3")
75
+
76
+ # 异步播放音频
77
+ # 使用 asyncio.to_thread 来避免阻塞事件循环
78
+ # await asyncio.to_thread(playsound, 'output.mp3')
79
+ return True
80
+ else:
81
+ print(f"请求失败,状态码: {response.status_code}")
82
+ print(f"响应内容: {response.text}")
83
+ return False
84
+
85
+ except httpx.RequestError as e:
86
+ print(f"请求错误: {e}")
87
+ print("错误堆栈:")
88
+ traceback.print_exc()
89
+ return False
90
+ except httpx.HTTPStatusError as e:
91
+ print(f"HTTP 错误: {e}")
92
+ print("错误堆栈:")
93
+ traceback.print_exc()
94
+ return False
95
+ except Exception as e:
96
+ print(f"发生错误: {e}")
97
+ print("错误堆栈:")
98
+ traceback.print_exc()
99
+ return False
100
+ finally:
101
+ current_index += 1
102
 
103
  # 语音转文本
104
  @router.post("/audio/transcriptions")
105
  async def transcriptions(request: Request, file: UploadFile = File(...)):
106
+ global current_index
107
  url = 'https://api.thinkbuddy.ai/v1/content/transcribe'
108
  params = {'enhance': 'true'}
109
  try:
 
123
  # 记录请求信息
124
  logger.info(f"Received upload request for file: {file.filename}")
125
  logger.info(f"Content-Type: {request.headers.get('content-type')}")
126
+ token_str = os.getenv('TOKEN', '')
127
+ token_array = token_str.split(',')
128
+ if len(token_array) > 0:
129
+ current_index = current_index % len(token_array)
130
+ print('transcriptions current index is ', current_index)
131
+ request_headers = {**get_settings().HEADERS,
132
+ 'authorization': f"Bearer {token_array[current_index]}",
133
+ 'Accept': 'application/json, text/plain, */*',
134
+ 'Content-Type': content_type,
135
+ }
136
+ # 设置较长的超时时间
137
+ timeout = httpx.Timeout(
138
+ connect=30.0, # 连接超时
139
+ read=300.0, # 读取超时
140
+ write=30.0, # 写入超时
141
+ pool=30.0 # 连接池超时
142
+ )
143
+ # 使用httpx发送异步请求
144
+ async with httpx.AsyncClient(http2=True, timeout=timeout) as client:
145
+ response = await client.post(url,
146
+ params=params,
147
+ headers=request_headers,
148
+ files=files)
149
+ current_index += 1
150
+ response.raise_for_status()
151
+ return response.json()
152
 
153
  except httpx.TimeoutException:
154
  raise HTTPException(status_code=504, detail="请求目标服务器超时")
 
170
  # 文件上传
171
  @router.post("/upload")
172
  async def upload_file(request: Request, file: UploadFile = File(...)):
173
+ global current_index
174
  try:
175
  # 读取文件内容
176
  content = await safe_read_file(file)
 
186
  # 记录请求信息
187
  logger.info(f"Received upload request for file: {file.filename}")
188
  logger.info(f"Content-Type: {request.headers.get('content-type')}")
189
+ token_str = os.getenv('TOKEN', '')
190
+ token_array = token_str.split(',')
191
+ if len(token_array) > 0:
192
+ current_index = current_index % len(token_array)
193
+ print('upload_file current index is ', current_index)
194
+ request_headers = {**get_settings().HEADERS,
195
+ 'authorization': f"Bearer {token_array[current_index]}",
196
+ 'Accept': 'application/json, text/plain, */*',
197
+ 'Content-Type': content_type,
198
+ }
199
+ # 使用httpx发送异步请求
200
+ async with httpx.AsyncClient() as client:
201
+ response = await client.post(f"https://api.thinkbuddy.ai/v1/uploads/images", headers=request_headers,files=files, timeout=100)
202
+ current_index += 1
203
+ response.raise_for_status()
204
+ return response.json()
205
 
206
  except httpx.TimeoutException:
207
  raise HTTPException(status_code=504, detail="请求目标服务器超时")
 
224
  async def chat_completions(
225
  request: ChatRequest, app_secret: str = Depends(verify_app_secret)
226
  ):
227
+ global current_index
228
  logger.info("Entering chat_completions route")
229
  # logger.info(f"Received request: {request}")
230
  # logger.info(f"Received request json format: {json.dumps(request.dict(), indent=4)}")
 
241
 
242
  if request.stream:
243
  logger.info("Streaming response")
244
+
245
+ # 创建一个标志来追踪是否有响应
246
+ has_response = False
247
+
248
+ async def content_generator():
249
+ nonlocal has_response
250
+ try:
251
+ async for item in process_streaming_response(request, app_secret, current_index):
252
+ has_response = True
253
+ yield item
254
+ except Exception as e:
255
+ logger.error(f"Error in streaming response: {e}")
256
+ raise
257
+
258
+ response = StreamingResponse(
259
+ content_generator(),
260
  media_type="text/event-stream",
261
  headers={
262
  "Cache-Control": "no-cache",
 
264
  "Transfer-Encoding": "chunked"
265
  }
266
  )
267
+
268
+ # 在返回响应之前增加 current_index
269
+ # if has_response:
270
+ # current_index += 1
271
+ current_index += 1
272
+ return response
273
  else:
274
  logger.info("Non-streaming response")
275
  # return await process_non_streaming_response(request)
core/utils.py CHANGED
@@ -35,32 +35,39 @@ def decode_unicode_escape(s):
35
 
36
  FIREBASE_API_KEY = settings.FIREBASE_API_KEY
37
  async def refresh_token_via_rest(refresh_token):
38
- # Firebase Auth REST API endpoint
39
- url = f"https://securetoken.googleapis.com/v1/token?key={FIREBASE_API_KEY}"
40
-
41
- payload = {
42
- 'grant_type': 'refresh_token',
43
- 'refresh_token': refresh_token
44
- }
 
 
 
 
45
 
46
- try:
47
- async with httpx.AsyncClient() as client:
48
- response = await client.post(url, json=payload)
49
- if response.status_code == 200:
50
- data = response.json()
51
- print(json.dumps(data, indent=2))
52
- # return {
53
- # 'id_token': data['id_token'],
54
- # 'refresh_token': data.get('refresh_token'),
55
- # 'expires_in': data['expires_in']
56
- # }
57
- return data['id_token']
58
- else:
59
- print(f"刷新失败: {response.text}")
 
 
 
 
60
  return None
61
- except Exception as e:
62
- print(f"请求异常: {e}")
63
- return None
64
 
65
 
66
  async def sign_in_with_idp():
@@ -185,7 +192,7 @@ def create_chat_completion_data(
185
  "usage": None,
186
  }
187
 
188
- async def process_streaming_response(request: ChatRequest, app_secret: str):
189
  # 创建自定义 SSL 上下文
190
  ssl_context = ssl.create_default_context()
191
  ssl_context.check_hostname = True
@@ -196,57 +203,63 @@ async def process_streaming_response(request: ChatRequest, app_secret: str):
196
  # http2=True # 启用 HTTP/2
197
  ) as client:
198
  try:
199
- request_headers = {**settings.HEADERS, 'authorization': f"Bearer {os.getenv('TOKEN', '')}"} # 从环境变量中获取新的TOKEN
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
 
201
- # 直接使用 request.model_dump() 或 request.dict() 获取字典格式的数据
202
- request_data = request.model_dump() # 如果使用较新版本的 Pydantic
203
- # # 获取请求数据
204
- # request_data = {
205
- # "model": request.model,
206
- # "messages": [msg.dict() for msg in request.messages],
207
- # "temperature": request.temperature,
208
- # "top_p": request.top_p,
209
- # "max_tokens": request.max_tokens,
210
- # "stream": request.stream
211
- # }
212
- # print("Request Headers:", json.dumps(request_headers, indent=2)) # 格式化打印
213
- # print("Request Body:", json.dumps(request.json(), indent=4, ensure_ascii=False)) # 格式化打印
214
- print("Request Headers:", json.dumps(request_headers, indent=4, ensure_ascii=False, sort_keys=True, separators=(',', ': '))) # 格式化打印
215
- print("Request Body:", json.dumps(request.model_dump(), indent=4, ensure_ascii=False, sort_keys=True, separators=(',', ': '))) # 格式化打印
 
 
216
 
217
- async with client.stream(
218
- "POST",
219
- f"https://api.thinkbuddy.ai/v1/chat/completions",
220
- headers=request_headers,
221
- json=request_data,
222
- timeout=100,
223
- ) as response:
224
- response.raise_for_status()
225
- timestamp = int(datetime.now().timestamp())
226
- async for line in response.aiter_lines():
227
- # print(f"{type(line)}: {line}")
228
- if line and line.startswith("data: "):
229
- try:
230
- if line.strip() == 'data: [DONE]':
231
- await response.aclose()
232
- break
233
- data_str = line[6:] # 去掉 'data: ' 前缀
234
 
235
- # 解析JSON
236
- json_data = json.loads(data_str)
237
- if 'choices' in json_data and len(json_data['choices']) > 0:
238
- delta = json_data['choices'][0].get('delta', {})
239
- if 'content' in delta:
240
- print(delta['content'], end='', flush=True)
241
- yield f"data: {json.dumps(create_chat_completion_data(delta['content'], request.model, timestamp))}\n\n"
242
 
243
- except json.JSONDecodeError as e:
244
- print(f"JSON解析错误: {e}")
245
- print(f"原始数据: {line}")
246
- continue
247
 
248
- yield f"data: {json.dumps(create_chat_completion_data('', request.model, timestamp, 'stop'))}\n\n"
249
- yield "data: [DONE]\n\n"
250
  except ConnectError as e:
251
  logger.error(f"Connection error details: {str(e)}")
252
  raise HTTPException(
 
35
 
36
  FIREBASE_API_KEY = settings.FIREBASE_API_KEY
37
  async def refresh_token_via_rest(refresh_token):
38
+ refresh_token_array = [x.strip() for x in refresh_token.split(',')]
39
+ token_array = []
40
+ if len(refresh_token_array) > 0:
41
+ print('refresh token length is ', len(refresh_token_array))
42
+ for e in refresh_token_array:
43
+ # Firebase Auth REST API endpoint
44
+ url = f"https://securetoken.googleapis.com/v1/token?key={FIREBASE_API_KEY}"
45
+ payload = {
46
+ 'grant_type': 'refresh_token',
47
+ 'refresh_token': e
48
+ }
49
 
50
+ try:
51
+ async with httpx.AsyncClient() as client:
52
+ response = await client.post(url, json=payload)
53
+ if response.status_code == 200:
54
+ data = response.json()
55
+ print(json.dumps(data, indent=2))
56
+ # return {
57
+ # 'id_token': data['id_token'],
58
+ # 'refresh_token': data.get('refresh_token'),
59
+ # 'expires_in': data['expires_in']
60
+ # }
61
+ # return data['id_token']
62
+ token_array.append(data['id_token'])
63
+ else:
64
+ print(f"刷新失败: {response.text}")
65
+ return None
66
+ except Exception as e:
67
+ print(f"请求异常: {e}")
68
  return None
69
+
70
+ return ','.join(token_array)
 
71
 
72
 
73
  async def sign_in_with_idp():
 
192
  "usage": None,
193
  }
194
 
195
+ async def process_streaming_response(request: ChatRequest, app_secret: str, current_index: int):
196
  # 创建自定义 SSL 上下文
197
  ssl_context = ssl.create_default_context()
198
  ssl_context.check_hostname = True
 
203
  # http2=True # 启用 HTTP/2
204
  ) as client:
205
  try:
206
+ token_str = os.getenv('TOKEN', '')
207
+ token_array = token_str.split(',')
208
+ if len(token_array) > 0:
209
+ current_index = current_index % len(token_array)
210
+ print('completions current index is ', current_index)
211
+ request_headers = {**settings.HEADERS, 'authorization': f"Bearer {token_array[current_index]}"} # 从环境变量中获取新的TOKEN
212
+
213
+ # 直接使用 request.model_dump() 或 request.dict() 获取字典格式的数据
214
+ request_data = request.model_dump() # 如果使用较新版本的 Pydantic
215
+ # # 获取请求数据
216
+ # request_data = {
217
+ # "model": request.model,
218
+ # "messages": [msg.dict() for msg in request.messages],
219
+ # "temperature": request.temperature,
220
+ # "top_p": request.top_p,
221
+ # "max_tokens": request.max_tokens,
222
+ # "stream": request.stream
223
+ # }
224
+ # print("Request Headers:", json.dumps(request_headers, indent=2)) # 格式化打印
225
+ # print("Request Body:", json.dumps(request.json(), indent=4, ensure_ascii=False)) # 格式化打印
226
+ print("Request Headers:", json.dumps(request_headers, indent=4, ensure_ascii=False, sort_keys=True, separators=(',', ': '))) # 格式化打印
227
+ print("Request Body:", json.dumps(request.model_dump(), indent=4, ensure_ascii=False, sort_keys=True, separators=(',', ': '))) # 格式化打印
228
 
229
+ async with client.stream(
230
+ "POST",
231
+ f"https://api.thinkbuddy.ai/v1/chat/completions",
232
+ headers=request_headers,
233
+ json=request_data,
234
+ timeout=100,
235
+ ) as response:
236
+ response.raise_for_status()
237
+ timestamp = int(datetime.now().timestamp())
238
+ async for line in response.aiter_lines():
239
+ # print(f"{type(line)}: {line}")
240
+ if line and line.startswith("data: "):
241
+ try:
242
+ if line.strip() == 'data: [DONE]':
243
+ await response.aclose()
244
+ break
245
+ data_str = line[6:] # 去掉 'data: ' 前缀
246
 
247
+ # 解析JSON
248
+ json_data = json.loads(data_str)
249
+ if 'choices' in json_data and len(json_data['choices']) > 0:
250
+ delta = json_data['choices'][0].get('delta', {})
251
+ if 'content' in delta:
252
+ print(delta['content'], end='', flush=True)
253
+ yield f"data: {json.dumps(create_chat_completion_data(delta['content'], request.model, timestamp))}\n\n"
 
 
 
 
 
 
 
 
 
 
254
 
255
+ except json.JSONDecodeError as e:
256
+ print(f"JSON解析错误: {e}")
257
+ print(f"原始数据: {line}")
258
+ continue
 
 
 
259
 
260
+ yield f"data: {json.dumps(create_chat_completion_data('', request.model, timestamp, 'stop'))}\n\n"
261
+ yield "data: [DONE]\n\n"
 
 
262
 
 
 
263
  except ConnectError as e:
264
  logger.error(f"Connection error details: {str(e)}")
265
  raise HTTPException(