yym68686 commited on
Commit
888a669
·
1 Parent(s): 17409c4

✨ Feature: Add support for v1/moderations endpoint.

Browse files
Files changed (5) hide show
  1. main.py +16 -5
  2. models.py +5 -0
  3. request.py +19 -1
  4. response.py +0 -1
  5. utils.py +1 -0
main.py CHANGED
@@ -12,7 +12,7 @@ from fastapi.responses import StreamingResponse, JSONResponse
12
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
13
  from fastapi.exceptions import RequestValidationError
14
 
15
- from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest
16
  from request import get_payload
17
  from response import fetch_response, fetch_response_stream
18
  from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder
@@ -191,7 +191,7 @@ app.add_middleware(
191
  app.add_middleware(StatsMiddleware)
192
 
193
  # 在 process_request 函数中更新成功和失败计数
194
- async def process_request(request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest], provider: Dict, endpoint=None, token=None):
195
  url = provider['base_url']
196
  parsed_url = urlparse(url)
197
  # print("parsed_url", parsed_url)
@@ -237,6 +237,10 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A
237
  engine = "whisper"
238
  request.stream = False
239
 
 
 
 
 
240
  if provider.get("engine"):
241
  engine = provider["engine"]
242
 
@@ -363,7 +367,7 @@ class ModelRequestHandler:
363
  print(json.dumps(provider, indent=4, ensure_ascii=False, default=circular_list_encoder))
364
  return provider_list
365
 
366
- async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest], token: str, endpoint=None):
367
  config = app.state.config
368
  # api_keys_db = app.state.api_keys_db
369
  api_list = app.state.api_list
@@ -406,7 +410,7 @@ class ModelRequestHandler:
406
  return await self.try_all_providers(request, matching_providers, use_round_robin, auto_retry, endpoint, token)
407
 
408
  # 在 try_all_providers 函数中处理失败的情况
409
- async def try_all_providers(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest], providers: List[Dict], use_round_robin: bool, auto_retry: bool, endpoint: str = None, token: str = None):
410
  status_code = 500
411
  error_message = None
412
  num_providers = len(providers)
@@ -533,7 +537,7 @@ def verify_admin_api_key(credentials: HTTPAuthorizationCredentials = Depends(sec
533
  return token
534
 
535
  @app.post("/v1/chat/completions", dependencies=[Depends(rate_limit_dependency)])
536
- async def request_model(request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest], token: str = Depends(verify_api_key)):
537
  # logger.info(f"Request received: {request}")
538
  return await model_handler.request_model(request, token)
539
 
@@ -556,6 +560,13 @@ async def images_generations(
556
  ):
557
  return await model_handler.request_model(request, token, endpoint="/v1/images/generations")
558
 
 
 
 
 
 
 
 
559
  from fastapi import UploadFile, File, Form, HTTPException
560
  import io
561
  @app.post("/v1/audio/transcriptions", dependencies=[Depends(rate_limit_dependency)])
 
12
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
13
  from fastapi.exceptions import RequestValidationError
14
 
15
+ from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest
16
  from request import get_payload
17
  from response import fetch_response, fetch_response_stream
18
  from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder
 
191
  app.add_middleware(StatsMiddleware)
192
 
193
  # 在 process_request 函数中更新成功和失败计数
194
+ async def process_request(request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest], provider: Dict, endpoint=None, token=None):
195
  url = provider['base_url']
196
  parsed_url = urlparse(url)
197
  # print("parsed_url", parsed_url)
 
237
  engine = "whisper"
238
  request.stream = False
239
 
240
+ if endpoint == "/v1/moderations":
241
+ engine = "moderation"
242
+ request.stream = False
243
+
244
  if provider.get("engine"):
245
  engine = provider["engine"]
246
 
 
367
  print(json.dumps(provider, indent=4, ensure_ascii=False, default=circular_list_encoder))
368
  return provider_list
369
 
370
+ async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest], token: str, endpoint=None):
371
  config = app.state.config
372
  # api_keys_db = app.state.api_keys_db
373
  api_list = app.state.api_list
 
410
  return await self.try_all_providers(request, matching_providers, use_round_robin, auto_retry, endpoint, token)
411
 
412
  # 在 try_all_providers 函数中处理失败的情况
413
+ async def try_all_providers(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest], providers: List[Dict], use_round_robin: bool, auto_retry: bool, endpoint: str = None, token: str = None):
414
  status_code = 500
415
  error_message = None
416
  num_providers = len(providers)
 
537
  return token
538
 
539
  @app.post("/v1/chat/completions", dependencies=[Depends(rate_limit_dependency)])
540
+ async def request_model(request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest], token: str = Depends(verify_api_key)):
541
  # logger.info(f"Request received: {request}")
542
  return await model_handler.request_model(request, token)
543
 
 
560
  ):
561
  return await model_handler.request_model(request, token, endpoint="/v1/images/generations")
562
 
563
+ @app.post("/v1/moderations", dependencies=[Depends(rate_limit_dependency)])
564
+ async def images_generations(
565
+ request: ModerationRequest,
566
+ token: str = Depends(verify_api_key)
567
+ ):
568
+ return await model_handler.request_model(request, token, endpoint="/v1/moderations")
569
+
570
  from fastapi import UploadFile, File, Form, HTTPException
571
  import io
572
  @app.post("/v1/audio/transcriptions", dependencies=[Depends(rate_limit_dependency)])
models.py CHANGED
@@ -21,6 +21,11 @@ class AudioTranscriptionRequest(BaseModel):
21
  class Config:
22
  arbitrary_types_allowed = True
23
 
 
 
 
 
 
24
  class FunctionParameter(BaseModel):
25
  type: str
26
  properties: Dict[str, Dict[str, Union[str, Dict[str, str]]]]
 
21
  class Config:
22
  arbitrary_types_allowed = True
23
 
24
+ class ModerationRequest(BaseModel):
25
+ input: str
26
+ model: Optional[str] = "text-moderation-latest"
27
+ stream: bool = False
28
+
29
  class FunctionParameter(BaseModel):
30
  type: str
31
  properties: Dict[str, Dict[str, Union[str, Dict[str, str]]]]
request.py CHANGED
@@ -1043,7 +1043,7 @@ async def get_dalle_payload(request, engine, provider):
1043
  async def get_whisper_payload(request, engine, provider):
1044
  model = provider['model'][request.model]
1045
  headers = {
1046
- "Content-Type": "application/json",
1047
  }
1048
  if provider.get("api"):
1049
  headers['Authorization'] = f"Bearer {provider['api'].next()}"
@@ -1066,6 +1066,22 @@ async def get_whisper_payload(request, engine, provider):
1066
 
1067
  return url, headers, payload
1068
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1069
  async def get_payload(request: RequestModel, engine, provider):
1070
  if engine == "gemini":
1071
  return await get_gemini_payload(request, engine, provider)
@@ -1089,5 +1105,7 @@ async def get_payload(request: RequestModel, engine, provider):
1089
  return await get_dalle_payload(request, engine, provider)
1090
  elif engine == "whisper":
1091
  return await get_whisper_payload(request, engine, provider)
 
 
1092
  else:
1093
  raise ValueError("Unknown payload")
 
1043
  async def get_whisper_payload(request, engine, provider):
1044
  model = provider['model'][request.model]
1045
  headers = {
1046
+ "Content-Type": "multipart/form-data",
1047
  }
1048
  if provider.get("api"):
1049
  headers['Authorization'] = f"Bearer {provider['api'].next()}"
 
1066
 
1067
  return url, headers, payload
1068
 
1069
+ async def get_moderation_payload(request, engine, provider):
1070
+ model = provider['model'][request.model]
1071
+ headers = {
1072
+ "Content-Type": "application/json",
1073
+ }
1074
+ if provider.get("api"):
1075
+ headers['Authorization'] = f"Bearer {provider['api'].next()}"
1076
+ url = provider['base_url']
1077
+ url = BaseAPI(url).moderations
1078
+
1079
+ payload = {
1080
+ "input": request.input,
1081
+ }
1082
+
1083
+ return url, headers, payload
1084
+
1085
  async def get_payload(request: RequestModel, engine, provider):
1086
  if engine == "gemini":
1087
  return await get_gemini_payload(request, engine, provider)
 
1105
  return await get_dalle_payload(request, engine, provider)
1106
  elif engine == "whisper":
1107
  return await get_whisper_payload(request, engine, provider)
1108
+ elif engine == "moderation":
1109
+ return await get_moderation_payload(request, engine, provider)
1110
  else:
1111
  raise ValueError("Unknown payload")
response.py CHANGED
@@ -273,7 +273,6 @@ async def fetch_response(client, url, headers, payload):
273
  response = None
274
  if payload.get("file"):
275
  file = payload.pop("file")
276
- headers.pop("Content-Type")
277
  response = await client.post(url, headers=headers, data=payload, files={"file": file})
278
  else:
279
  response = await client.post(url, headers=headers, json=payload)
 
273
  response = None
274
  if payload.get("file"):
275
  file = payload.pop("file")
 
276
  response = await client.post(url, headers=headers, data=payload, files={"file": file})
277
  else:
278
  response = await client.post(url, headers=headers, json=payload)
utils.py CHANGED
@@ -308,6 +308,7 @@ class BaseAPI:
308
  self.chat_url: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/chat/completions",) + ("",) * 3)
309
  self.image_url: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/images/generations",) + ("",) * 3)
310
  self.audio_transcriptions: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/audio/transcriptions",) + ("",) * 3)
 
311
 
312
  def safe_get(data, *keys):
313
  for key in keys:
 
308
  self.chat_url: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/chat/completions",) + ("",) * 3)
309
  self.image_url: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/images/generations",) + ("",) * 3)
310
  self.audio_transcriptions: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/audio/transcriptions",) + ("",) * 3)
311
+ self.moderations: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/moderations",) + ("",) * 3)
312
 
313
  def safe_get(data, *keys):
314
  for key in keys: