✨ Feature: Add support for v1/moderations endpoint.
Browse files- main.py +16 -5
- models.py +5 -0
- request.py +19 -1
- response.py +0 -1
- utils.py +1 -0
main.py
CHANGED
@@ -12,7 +12,7 @@ from fastapi.responses import StreamingResponse, JSONResponse
|
|
12 |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
13 |
from fastapi.exceptions import RequestValidationError
|
14 |
|
15 |
-
from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest
|
16 |
from request import get_payload
|
17 |
from response import fetch_response, fetch_response_stream
|
18 |
from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder
|
@@ -191,7 +191,7 @@ app.add_middleware(
|
|
191 |
app.add_middleware(StatsMiddleware)
|
192 |
|
193 |
# 在 process_request 函数中更新成功和失败计数
|
194 |
-
async def process_request(request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest], provider: Dict, endpoint=None, token=None):
|
195 |
url = provider['base_url']
|
196 |
parsed_url = urlparse(url)
|
197 |
# print("parsed_url", parsed_url)
|
@@ -237,6 +237,10 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A
|
|
237 |
engine = "whisper"
|
238 |
request.stream = False
|
239 |
|
|
|
|
|
|
|
|
|
240 |
if provider.get("engine"):
|
241 |
engine = provider["engine"]
|
242 |
|
@@ -363,7 +367,7 @@ class ModelRequestHandler:
|
|
363 |
print(json.dumps(provider, indent=4, ensure_ascii=False, default=circular_list_encoder))
|
364 |
return provider_list
|
365 |
|
366 |
-
async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest], token: str, endpoint=None):
|
367 |
config = app.state.config
|
368 |
# api_keys_db = app.state.api_keys_db
|
369 |
api_list = app.state.api_list
|
@@ -406,7 +410,7 @@ class ModelRequestHandler:
|
|
406 |
return await self.try_all_providers(request, matching_providers, use_round_robin, auto_retry, endpoint, token)
|
407 |
|
408 |
# 在 try_all_providers 函数中处理失败的情况
|
409 |
-
async def try_all_providers(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest], providers: List[Dict], use_round_robin: bool, auto_retry: bool, endpoint: str = None, token: str = None):
|
410 |
status_code = 500
|
411 |
error_message = None
|
412 |
num_providers = len(providers)
|
@@ -533,7 +537,7 @@ def verify_admin_api_key(credentials: HTTPAuthorizationCredentials = Depends(sec
|
|
533 |
return token
|
534 |
|
535 |
@app.post("/v1/chat/completions", dependencies=[Depends(rate_limit_dependency)])
|
536 |
-
async def request_model(request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest], token: str = Depends(verify_api_key)):
|
537 |
# logger.info(f"Request received: {request}")
|
538 |
return await model_handler.request_model(request, token)
|
539 |
|
@@ -556,6 +560,13 @@ async def images_generations(
|
|
556 |
):
|
557 |
return await model_handler.request_model(request, token, endpoint="/v1/images/generations")
|
558 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
559 |
from fastapi import UploadFile, File, Form, HTTPException
|
560 |
import io
|
561 |
@app.post("/v1/audio/transcriptions", dependencies=[Depends(rate_limit_dependency)])
|
|
|
12 |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
13 |
from fastapi.exceptions import RequestValidationError
|
14 |
|
15 |
+
from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest
|
16 |
from request import get_payload
|
17 |
from response import fetch_response, fetch_response_stream
|
18 |
from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder
|
|
|
191 |
app.add_middleware(StatsMiddleware)
|
192 |
|
193 |
# 在 process_request 函数中更新成功和失败计数
|
194 |
+
async def process_request(request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest], provider: Dict, endpoint=None, token=None):
|
195 |
url = provider['base_url']
|
196 |
parsed_url = urlparse(url)
|
197 |
# print("parsed_url", parsed_url)
|
|
|
237 |
engine = "whisper"
|
238 |
request.stream = False
|
239 |
|
240 |
+
if endpoint == "/v1/moderations":
|
241 |
+
engine = "moderation"
|
242 |
+
request.stream = False
|
243 |
+
|
244 |
if provider.get("engine"):
|
245 |
engine = provider["engine"]
|
246 |
|
|
|
367 |
print(json.dumps(provider, indent=4, ensure_ascii=False, default=circular_list_encoder))
|
368 |
return provider_list
|
369 |
|
370 |
+
async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest], token: str, endpoint=None):
|
371 |
config = app.state.config
|
372 |
# api_keys_db = app.state.api_keys_db
|
373 |
api_list = app.state.api_list
|
|
|
410 |
return await self.try_all_providers(request, matching_providers, use_round_robin, auto_retry, endpoint, token)
|
411 |
|
412 |
# 在 try_all_providers 函数中处理失败的情况
|
413 |
+
async def try_all_providers(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest], providers: List[Dict], use_round_robin: bool, auto_retry: bool, endpoint: str = None, token: str = None):
|
414 |
status_code = 500
|
415 |
error_message = None
|
416 |
num_providers = len(providers)
|
|
|
537 |
return token
|
538 |
|
539 |
@app.post("/v1/chat/completions", dependencies=[Depends(rate_limit_dependency)])
|
540 |
+
async def request_model(request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest], token: str = Depends(verify_api_key)):
|
541 |
# logger.info(f"Request received: {request}")
|
542 |
return await model_handler.request_model(request, token)
|
543 |
|
|
|
560 |
):
|
561 |
return await model_handler.request_model(request, token, endpoint="/v1/images/generations")
|
562 |
|
563 |
+
@app.post("/v1/moderations", dependencies=[Depends(rate_limit_dependency)])
|
564 |
+
async def images_generations(
|
565 |
+
request: ModerationRequest,
|
566 |
+
token: str = Depends(verify_api_key)
|
567 |
+
):
|
568 |
+
return await model_handler.request_model(request, token, endpoint="/v1/moderations")
|
569 |
+
|
570 |
from fastapi import UploadFile, File, Form, HTTPException
|
571 |
import io
|
572 |
@app.post("/v1/audio/transcriptions", dependencies=[Depends(rate_limit_dependency)])
|
models.py
CHANGED
@@ -21,6 +21,11 @@ class AudioTranscriptionRequest(BaseModel):
|
|
21 |
class Config:
|
22 |
arbitrary_types_allowed = True
|
23 |
|
|
|
|
|
|
|
|
|
|
|
24 |
class FunctionParameter(BaseModel):
|
25 |
type: str
|
26 |
properties: Dict[str, Dict[str, Union[str, Dict[str, str]]]]
|
|
|
21 |
class Config:
|
22 |
arbitrary_types_allowed = True
|
23 |
|
24 |
+
class ModerationRequest(BaseModel):
|
25 |
+
input: str
|
26 |
+
model: Optional[str] = "text-moderation-latest"
|
27 |
+
stream: bool = False
|
28 |
+
|
29 |
class FunctionParameter(BaseModel):
|
30 |
type: str
|
31 |
properties: Dict[str, Dict[str, Union[str, Dict[str, str]]]]
|
request.py
CHANGED
@@ -1043,7 +1043,7 @@ async def get_dalle_payload(request, engine, provider):
|
|
1043 |
async def get_whisper_payload(request, engine, provider):
|
1044 |
model = provider['model'][request.model]
|
1045 |
headers = {
|
1046 |
-
"Content-Type": "
|
1047 |
}
|
1048 |
if provider.get("api"):
|
1049 |
headers['Authorization'] = f"Bearer {provider['api'].next()}"
|
@@ -1066,6 +1066,22 @@ async def get_whisper_payload(request, engine, provider):
|
|
1066 |
|
1067 |
return url, headers, payload
|
1068 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1069 |
async def get_payload(request: RequestModel, engine, provider):
|
1070 |
if engine == "gemini":
|
1071 |
return await get_gemini_payload(request, engine, provider)
|
@@ -1089,5 +1105,7 @@ async def get_payload(request: RequestModel, engine, provider):
|
|
1089 |
return await get_dalle_payload(request, engine, provider)
|
1090 |
elif engine == "whisper":
|
1091 |
return await get_whisper_payload(request, engine, provider)
|
|
|
|
|
1092 |
else:
|
1093 |
raise ValueError("Unknown payload")
|
|
|
1043 |
async def get_whisper_payload(request, engine, provider):
|
1044 |
model = provider['model'][request.model]
|
1045 |
headers = {
|
1046 |
+
"Content-Type": "multipart/form-data",
|
1047 |
}
|
1048 |
if provider.get("api"):
|
1049 |
headers['Authorization'] = f"Bearer {provider['api'].next()}"
|
|
|
1066 |
|
1067 |
return url, headers, payload
|
1068 |
|
1069 |
+
async def get_moderation_payload(request, engine, provider):
|
1070 |
+
model = provider['model'][request.model]
|
1071 |
+
headers = {
|
1072 |
+
"Content-Type": "application/json",
|
1073 |
+
}
|
1074 |
+
if provider.get("api"):
|
1075 |
+
headers['Authorization'] = f"Bearer {provider['api'].next()}"
|
1076 |
+
url = provider['base_url']
|
1077 |
+
url = BaseAPI(url).moderations
|
1078 |
+
|
1079 |
+
payload = {
|
1080 |
+
"input": request.input,
|
1081 |
+
}
|
1082 |
+
|
1083 |
+
return url, headers, payload
|
1084 |
+
|
1085 |
async def get_payload(request: RequestModel, engine, provider):
|
1086 |
if engine == "gemini":
|
1087 |
return await get_gemini_payload(request, engine, provider)
|
|
|
1105 |
return await get_dalle_payload(request, engine, provider)
|
1106 |
elif engine == "whisper":
|
1107 |
return await get_whisper_payload(request, engine, provider)
|
1108 |
+
elif engine == "moderation":
|
1109 |
+
return await get_moderation_payload(request, engine, provider)
|
1110 |
else:
|
1111 |
raise ValueError("Unknown payload")
|
response.py
CHANGED
@@ -273,7 +273,6 @@ async def fetch_response(client, url, headers, payload):
|
|
273 |
response = None
|
274 |
if payload.get("file"):
|
275 |
file = payload.pop("file")
|
276 |
-
headers.pop("Content-Type")
|
277 |
response = await client.post(url, headers=headers, data=payload, files={"file": file})
|
278 |
else:
|
279 |
response = await client.post(url, headers=headers, json=payload)
|
|
|
273 |
response = None
|
274 |
if payload.get("file"):
|
275 |
file = payload.pop("file")
|
|
|
276 |
response = await client.post(url, headers=headers, data=payload, files={"file": file})
|
277 |
else:
|
278 |
response = await client.post(url, headers=headers, json=payload)
|
utils.py
CHANGED
@@ -308,6 +308,7 @@ class BaseAPI:
|
|
308 |
self.chat_url: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/chat/completions",) + ("",) * 3)
|
309 |
self.image_url: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/images/generations",) + ("",) * 3)
|
310 |
self.audio_transcriptions: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/audio/transcriptions",) + ("",) * 3)
|
|
|
311 |
|
312 |
def safe_get(data, *keys):
|
313 |
for key in keys:
|
|
|
308 |
self.chat_url: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/chat/completions",) + ("",) * 3)
|
309 |
self.image_url: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/images/generations",) + ("",) * 3)
|
310 |
self.audio_transcriptions: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/audio/transcriptions",) + ("",) * 3)
|
311 |
+
self.moderations: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/moderations",) + ("",) * 3)
|
312 |
|
313 |
def safe_get(data, *keys):
|
314 |
for key in keys:
|