Spaces:
Running
Running
#!/usr/bin/env python3 | |
# -*- coding: utf-8 -*- | |
from fastapi import FastAPI, Request | |
from fastapi.exceptions import RequestValidationError | |
from pydantic import ValidationError | |
from starlette.exceptions import HTTPException | |
from starlette.middleware.cors import CORSMiddleware | |
from uvicorn.protocols.http.h11_impl import STATUS_PHRASES | |
from common.exception.errors import BaseExceptionMixin | |
from common.response.response_code import CustomResponseCode, StandardResponseCode | |
from common.response.response_schema import response_base | |
from common.schema import ( | |
CUSTOM_VALIDATION_ERROR_MESSAGES, | |
) | |
from core.conf import settings | |
from utils.serializers import MsgSpecJSONResponse | |
from utils.trace_id import get_request_trace_id | |
def _get_exception_code(status_code: int) -> int:
    """Return ``status_code`` if it is a known HTTP status code, otherwise ``HTTP_400``.

    Codes carried by exceptions may not be HTTP statuses, and using them as a response
    status would be invalid. "Known" means "present as a key in uvicorn's ``STATUS_PHRASES``".

    :param status_code: candidate status code taken from an exception
    :return: ``status_code`` unchanged, or ``StandardResponseCode.HTTP_400``
    """
    try:
        known = status_code in STATUS_PHRASES
    except TypeError:
        # Unhashable values cannot be dict keys, so they cannot be status codes either.
        known = False
    return status_code if known else StandardResponseCode.HTTP_400
def _normalize_validation_error(error: dict) -> dict:
    """Return a copy of one validation error with its custom message applied and ``ctx['error']`` stringified.

    :param error: one entry of ``exc.errors()``; it is not mutated (it may be the
        exception's own stored dict)
    :return: a new error dict
    """
    error = dict(error)
    ctx = error.get('ctx')
    custom_message = CUSTOM_VALIDATION_ERROR_MESSAGES.get(error['type'])
    if custom_message:
        # ctx entries fill the placeholders of the custom message template.
        error['msg'] = custom_message.format(**ctx) if ctx else custom_message
    if ctx:
        ctx = dict(ctx)
        ctx_error = ctx.get('error')
        if ctx_error:
            # ctx['error'] may hold an exception object, which presumably cannot be encoded by
            # MsgSpecJSONResponse; keep only its text. Other truthy values become None (unchanged behavior).
            ctx['error'] = str(ctx_error).replace("'", '"') if isinstance(ctx_error, Exception) else None
        error['ctx'] = ctx
    return error


def _validation_error_message(error: dict) -> str:
    """Return the detail part of the 422 ``msg`` for a single normalized error.

    In ``dev`` the detail includes the field name (last ``loc`` element) and the input value;
    elsewhere it is just the error message.
    """
    if error.get('type') == 'json_invalid':
        return 'json解析失败'
    error_msg = error.get('msg')
    if settings.ENVIRONMENT != 'dev':
        return error_msg
    error_input = error.get('input')
    # ``loc`` may be empty (e.g. model-level errors) or absent, so the field name is optional.
    loc = error.get('loc') or ()
    field = f'{loc[-1]} ' if loc else ''
    return f'{field}{error_msg},输入:{error_input}'


async def _validation_exception_handler(request: Request, exc: RequestValidationError | ValidationError):
    """Convert a FastAPI or pydantic validation error into a 422 JSON response.

    The message describes the first error only; in ``dev`` all normalized errors are also
    returned under ``data['errors']``.

    :param request: the current request; the response content is stored on its ``state``
    :param exc: the validation error raised while handling the request
    :return: a ``MsgSpecJSONResponse`` with status 422
    """
    errors = [_normalize_validation_error(error) for error in exc.errors()]
    msg = f'请求参数非法: {_validation_error_message(errors[0])}' if errors else '请求参数非法'
    content = {
        'code': StandardResponseCode.HTTP_422,
        'msg': msg,
        'data': {'errors': errors} if settings.ENVIRONMENT == 'dev' else None,
    }
    # Stored on request.state (presumably read by middleware). It is the same dict object,
    # so the trace_id added next is visible there as well.
    request.state.__request_validation_exception__ = content
    content.update(trace_id=get_request_trace_id(request))
    return MsgSpecJSONResponse(status_code=422, content=content)
def _error_content(code, msg, fallback_res) -> dict:
    """Build the response body for an error, depending on the environment.

    :param code: code exposed in the ``dev`` environment
    :param msg: message exposed in the ``dev`` environment
    :param fallback_res: ``CustomResponseCode`` member passed to ``response_base.fail`` outside ``dev``
        (presumably so internal details are not exposed in production)
    :return: a plain dict with ``code``/``msg``/``data`` keys, or the dumped ``fail`` model
    """
    if settings.ENVIRONMENT == 'dev':
        return {'code': code, 'msg': msg, 'data': None}
    return response_base.fail(res=fallback_res).model_dump()


def _record_error_content(request: Request, state_attr: str, content: dict) -> dict:
    """Store ``content`` on ``request.state`` under ``state_attr``, then add the request's ``trace_id``.

    The stored and returned objects are the same dict, so ``trace_id`` is also visible via ``request.state``.
    """
    setattr(request.state, state_attr, content)
    content.update(trace_id=get_request_trace_id(request))
    return content


def _custom_exception_content(exc: BaseExceptionMixin) -> dict:
    """Build the response body for a project ``BaseExceptionMixin`` exception."""
    return {
        'code': exc.code,
        'msg': str(exc.msg),
        'data': exc.data or None,
    }


def register_exception(app: FastAPI) -> None:
    """Register the global exception handlers on ``app``.

    Handlers are registered for ``HTTPException``, ``RequestValidationError``, pydantic
    ``ValidationError``, ``AssertionError``, ``BaseExceptionMixin`` and ``Exception``.
    When ``settings.MIDDLEWARE_CORS`` is set, a status-500 handler that also adds CORS
    headers is registered as well.

    :param app: the FastAPI application to configure
    """

    @app.exception_handler(HTTPException)
    async def http_exception_handler(request: Request, exc: HTTPException):
        """Handle HTTP exceptions; the status falls back to 400 if ``exc.status_code`` is not a known HTTP status."""
        content = _error_content(exc.status_code, exc.detail, CustomResponseCode.HTTP_400)
        return MsgSpecJSONResponse(
            status_code=_get_exception_code(exc.status_code),
            content=_record_error_content(request, '__request_http_exception__', content),
            headers=exc.headers,
        )

    @app.exception_handler(RequestValidationError)
    async def fastapi_validation_exception_handler(request: Request, exc: RequestValidationError):
        """Handle FastAPI request validation errors with the shared 422 handler."""
        return await _validation_exception_handler(request, exc)

    @app.exception_handler(ValidationError)
    async def pydantic_validation_exception_handler(request: Request, exc: ValidationError):
        """Handle pydantic validation errors with the shared 422 handler."""
        return await _validation_exception_handler(request, exc)

    @app.exception_handler(AssertionError)
    async def assertion_error_handler(request: Request, exc: AssertionError):
        """Handle failed assertions as a 500; in ``dev`` the assertion message (or class docstring) is shown."""
        # map(str, ...) because `assert cond, 123` puts a non-str in args, which ''.join would reject.
        detail = ''.join(map(str, exc.args)) if exc.args else exc.__doc__
        content = _error_content(StandardResponseCode.HTTP_500, str(detail), CustomResponseCode.HTTP_500)
        return MsgSpecJSONResponse(
            status_code=StandardResponseCode.HTTP_500,
            content=_record_error_content(request, '__request_assertion_error__', content),
        )

    @app.exception_handler(BaseExceptionMixin)
    async def custom_exception_handler(request: Request, exc: BaseExceptionMixin):
        """Handle project exceptions, passing through their code, message, data and background task."""
        content = _custom_exception_content(exc)
        return MsgSpecJSONResponse(
            status_code=_get_exception_code(exc.code),
            content=_record_error_content(request, '__request_custom_exception__', content),
            background=exc.background,
        )

    @app.exception_handler(Exception)
    async def all_unknown_exception_handler(request: Request, exc: Exception):
        """Handle any otherwise unhandled exception as a 500."""
        content = _error_content(StandardResponseCode.HTTP_500, str(exc), CustomResponseCode.HTTP_500)
        return MsgSpecJSONResponse(
            status_code=StandardResponseCode.HTTP_500,
            content=_record_error_content(request, '__request_all_unknown_exception__', content),
        )

    if settings.MIDDLEWARE_CORS:
        # Built once and used only for its CORS header computations; it never serves a request itself.
        cors = CORSMiddleware(
            app=app,
            allow_origins=settings.CORS_ALLOWED_ORIGINS,
            allow_credentials=True,
            allow_methods=['*'],
            allow_headers=['*'],
            expose_headers=settings.CORS_EXPOSE_HEADERS,
        )

        # NOTE(review): Starlette routes both `Exception` and `500` handlers to ServerErrorMiddleware
        # and keeps the last one registered, so registering this after the `Exception` handler makes
        # it take precedence. Confirm against the pinned Starlette version.
        @app.exception_handler(StandardResponseCode.HTTP_500)
        async def cors_custom_code_500_exception_handler(request, exc):
            """Handle server errors and add the CORS headers the CORS middleware would have added.

            Responses for unhandled errors do not receive CORS headers from the middleware.
            `Related issue <https://github.com/encode/starlette/issues/1175>`_
            `Solution <https://github.com/fastapi/fastapi/discussions/7847#discussioncomment-5144709>`_

            :param request: the current request
            :param exc: the exception that reached the server-error handler
            :return: a JSON error response carrying CORS headers when the request has an ``Origin``
            """
            if isinstance(exc, BaseExceptionMixin):
                content = _custom_exception_content(exc)
                status_code = _get_exception_code(exc.code)
                background = exc.background
            else:
                content = _error_content(StandardResponseCode.HTTP_500, str(exc), CustomResponseCode.HTTP_500)
                status_code = StandardResponseCode.HTTP_500
                background = None
            response = MsgSpecJSONResponse(
                status_code=status_code,
                content=_record_error_content(request, '__request_cors_500_exception__', content),
                background=background,
            )
            origin = request.headers.get('origin')
            if origin:
                response.headers.update(cors.simple_headers)
                # Same rules as Starlette's CORSMiddleware:
                # - echo the origin when all origins are allowed but cookies are sent;
                # - echo it when only specific origins are allowed and this one matches.
                # The response then varies by Origin.
                if cors.allow_all_origins:
                    echo_origin = 'cookie' in request.headers
                else:
                    echo_origin = cors.is_allowed_origin(origin=origin)
                if echo_origin:
                    response.headers['Access-Control-Allow-Origin'] = origin
                    response.headers.add_vary_header('Origin')
            return response