Spaces:
Sleeping
Sleeping
File size: 8,047 Bytes
5564ecb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 |
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from fastapi import FastAPI, Request
from fastapi.exceptions import RequestValidationError
from pydantic import ValidationError
from starlette.exceptions import HTTPException
from starlette.middleware.cors import CORSMiddleware
from uvicorn.protocols.http.h11_impl import STATUS_PHRASES
from common.exception.errors import BaseExceptionMixin
from common.response.response_code import CustomResponseCode, StandardResponseCode
from common.response.response_schema import response_base
from common.schema import (
CUSTOM_VALIDATION_ERROR_MESSAGES,
)
from core.conf import settings
from utils.serializers import MsgSpecJSONResponse
from utils.trace_id import get_request_trace_id
def _get_exception_code(status_code: int) -> int:
    """
    Map an exception's code to a status code that can be put on a response.

    ``status_code`` is returned unchanged when it is a key of uvicorn's
    ``STATUS_PHRASES`` table. Any other value (unknown or unhashable) falls
    back to ``StandardResponseCode.HTTP_400``.

    :param status_code: the status/code carried by the exception
    :return: ``status_code`` if recognized, otherwise ``StandardResponseCode.HTTP_400``
    """
    try:
        STATUS_PHRASES[status_code]
    except (KeyError, TypeError):
        # KeyError: not in the status table (e.g. a non-HTTP code);
        # TypeError: the value is not hashable and cannot be a status at all.
        return StandardResponseCode.HTTP_400
    return status_code
def _apply_custom_validation_messages(raw_errors) -> list[dict]:
    """
    Rewrite validation error dicts with the project's custom messages.

    For every error whose ``type`` has an entry in
    ``CUSTOM_VALIDATION_ERROR_MESSAGES``, ``msg`` is replaced by that template,
    formatted with the error's ``ctx`` when one is present. An exception stored
    under ``ctx['error']`` is replaced by its string form (single quotes turned
    into double quotes). Any other truthy non-exception value is replaced by
    ``None``.

    :param raw_errors: the error dicts returned by ``exc.errors()``
    :return: the (mutated) error dicts, in their original order
    """
    errors = []
    for error in raw_errors:
        custom_message = CUSTOM_VALIDATION_ERROR_MESSAGES.get(error['type'])
        if custom_message:
            ctx = error.get('ctx')
            if not ctx:
                error['msg'] = custom_message
            else:
                error['msg'] = custom_message.format(**ctx)
                ctx_error = ctx.get('error')
                if ctx_error:
                    error['ctx']['error'] = (
                        str(ctx_error).replace("'", '"') if isinstance(ctx_error, Exception) else None
                    )
        errors.append(error)
    return errors


def _summarize_validation_error(error: dict) -> str:
    """
    Build the human-readable detail for one validation error.

    Outside the ``dev`` environment only the error ``msg`` is returned. In
    ``dev`` the last ``loc`` element (the field) and the offending input are
    included as well.

    :param error: a single validation error dict
    :return: the detail text appended to the 422 message
    """
    if error.get('type') == 'json_invalid':
        return 'json解析失败'
    error_msg = error.get('msg')
    if settings.ENVIRONMENT != 'dev':
        return error_msg
    error_input = error.get('input')
    loc = error.get('loc') or ()
    if not loc:
        # ``loc`` can be empty (e.g. errors raised by a model-level validator);
        # indexing ``loc[-1]`` would raise IndexError inside the handler.
        return f'{error_msg},输入:{error_input}'
    field = str(loc[-1])
    return f'{field} {error_msg},输入:{error_input}'


async def _validation_exception_handler(request: Request, exc: RequestValidationError | ValidationError):
    """
    Shared handler for FastAPI request validation and pydantic validation errors.

    Only the first error is summarized in ``msg``. The full (customized) error
    list is returned under ``data`` only in the ``dev`` environment. The body is
    also stored on ``request.state.__request_validation_exception__``; because
    the same dict is then updated with the trace id, the stored value includes it.

    :param request: the current request
    :param exc: the validation error raised by FastAPI or pydantic
    :return: a 422 ``MsgSpecJSONResponse``
    """
    errors = _apply_custom_validation_messages(exc.errors())
    # NOTE(review): assumes at least one error, which FastAPI/pydantic
    # guarantee when raising these exceptions.
    message = _summarize_validation_error(errors[0])
    msg = f'请求参数非法: {message}'
    data = {'errors': errors} if settings.ENVIRONMENT == 'dev' else None
    content = {
        'code': StandardResponseCode.HTTP_422,
        'msg': msg,
        'data': data,
    }
    request.state.__request_validation_exception__ = content
    content.update(trace_id=get_request_trace_id(request))
    return MsgSpecJSONResponse(status_code=422, content=content)
def _server_error_content(detail: str) -> dict:
    """
    Build the body of a 500 response.

    In the ``dev`` environment ``detail`` is returned as ``msg``. Otherwise the
    generic ``CustomResponseCode.HTTP_500`` failure payload is returned instead.

    :param detail: description of the error, shown only in ``dev``
    :return: the response body dict
    """
    if settings.ENVIRONMENT == 'dev':
        return {
            'code': StandardResponseCode.HTTP_500,
            'msg': detail,
            'data': None,
        }
    return response_base.fail(res=CustomResponseCode.HTTP_500).model_dump()


def _assertion_message(exc: AssertionError) -> str:
    """
    Return a readable message for an ``AssertionError``.

    Each argument is converted with ``str`` before joining, so a non-string
    assertion message (e.g. ``assert ok, 42``) does not raise ``TypeError``
    inside the handler. Without arguments the exception's ``__doc__`` is used.

    :param exc: the assertion error
    :return: the message text
    """
    if exc.args:
        return ''.join(str(arg) for arg in exc.args)
    return str(exc.__doc__)


def _base_exception_content(exc: BaseExceptionMixin) -> dict:
    """
    Serialize a project ``BaseExceptionMixin`` into the standard body shape.

    :param exc: the project exception
    :return: a dict with ``code``, stringified ``msg`` and ``data`` (falsy -> ``None``)
    """
    return {
        'code': exc.code,
        'msg': str(exc.msg),
        'data': exc.data if exc.data else None,
    }


def register_exception(app: FastAPI):
    """
    Register the global exception handlers on ``app``.

    Each handler stores its response body on a ``request.state`` attribute and
    then adds the request trace id to that same dict (presumably so middleware
    can read it later — confirm). When ``settings.MIDDLEWARE_CORS`` is enabled,
    an additional handler keyed on status 500 re-applies CORS headers to error
    responses (see the links in that handler's docstring).

    :param app: the FastAPI application to configure
    :return:
    """

    @app.exception_handler(HTTPException)
    async def http_exception_handler(request: Request, exc: HTTPException):
        """Handle ``HTTPException``; the original detail is exposed only in dev."""
        if settings.ENVIRONMENT == 'dev':
            content = {
                'code': exc.status_code,
                'msg': exc.detail,
                'data': None,
            }
        else:
            content = response_base.fail(res=CustomResponseCode.HTTP_400).model_dump()
        request.state.__request_http_exception__ = content
        content.update(trace_id=get_request_trace_id(request))
        return MsgSpecJSONResponse(
            status_code=_get_exception_code(exc.status_code),
            content=content,
            headers=exc.headers,
        )

    @app.exception_handler(RequestValidationError)
    async def fastapi_validation_exception_handler(request: Request, exc: RequestValidationError):
        """Delegate FastAPI request validation errors to the shared handler."""
        return await _validation_exception_handler(request, exc)

    @app.exception_handler(ValidationError)
    async def pydantic_validation_exception_handler(request: Request, exc: ValidationError):
        """Delegate pydantic validation errors to the shared handler."""
        return await _validation_exception_handler(request, exc)

    @app.exception_handler(AssertionError)
    async def assertion_error_handler(request: Request, exc: AssertionError):
        """Turn an ``AssertionError`` into a 500 response."""
        content = _server_error_content(_assertion_message(exc))
        request.state.__request_assertion_error__ = content
        content.update(trace_id=get_request_trace_id(request))
        return MsgSpecJSONResponse(
            status_code=StandardResponseCode.HTTP_500,
            content=content,
        )

    @app.exception_handler(BaseExceptionMixin)
    async def custom_exception_handler(request: Request, exc: BaseExceptionMixin):
        """Serialize a project exception using its own code, message and data."""
        content = _base_exception_content(exc)
        request.state.__request_custom_exception__ = content
        content.update(trace_id=get_request_trace_id(request))
        return MsgSpecJSONResponse(
            status_code=_get_exception_code(exc.code),
            content=content,
            background=exc.background,
        )

    @app.exception_handler(Exception)
    async def all_unknown_exception_handler(request: Request, exc: Exception):
        """Fallback for any exception without a more specific handler."""
        content = _server_error_content(str(exc))
        request.state.__request_all_unknown_exception__ = content
        content.update(trace_id=get_request_trace_id(request))
        return MsgSpecJSONResponse(
            status_code=StandardResponseCode.HTTP_500,
            content=content,
        )

    if settings.MIDDLEWARE_CORS:
        # Its arguments come only from ``settings``, so build it once here
        # instead of on every failing request (assumes settings are static).
        cors = CORSMiddleware(
            app=app,
            allow_origins=settings.CORS_ALLOWED_ORIGINS,
            allow_credentials=True,
            allow_methods=['*'],
            allow_headers=['*'],
            expose_headers=settings.CORS_EXPOSE_HEADERS,
        )

        @app.exception_handler(StandardResponseCode.HTTP_500)
        async def cors_custom_code_500_exception_handler(request, exc):
            """
            Handle status-500 errors and re-attach CORS headers to the response.

            `Related issue <https://github.com/encode/starlette/issues/1175>`_
            `Solution <https://github.com/fastapi/fastapi/discussions/7847#discussioncomment-5144709>`_

            :param request: the incoming request
            :param exc: the exception that led to the 500 response
            :return: the error response, with CORS headers when an ``Origin`` header is present
            """
            if isinstance(exc, BaseExceptionMixin):
                content = _base_exception_content(exc)
                # Same status mapping as ``custom_exception_handler``: a code
                # that is not a known HTTP status must not become the response status.
                status_code = _get_exception_code(exc.code)
                background = exc.background
            else:
                content = _server_error_content(str(exc))
                status_code = StandardResponseCode.HTTP_500
                background = None
            request.state.__request_cors_500_exception__ = content
            content.update(trace_id=get_request_trace_id(request))
            response = MsgSpecJSONResponse(
                status_code=status_code,
                content=content,
                background=background,
            )
            origin = request.headers.get('origin')
            if origin:
                response.headers.update(cors.simple_headers)
                has_cookie = 'cookie' in request.headers
                # Echo the exact Origin when credentials (cookies) are sent while
                # all origins are allowed, or when the origin is in the allow-list.
                # Either way the header depends on Origin, so caches must vary on it.
                mirror_origin = (cors.allow_all_origins and has_cookie) or (
                    not cors.allow_all_origins and cors.is_allowed_origin(origin=origin)
                )
                if mirror_origin:
                    response.headers['Access-Control-Allow-Origin'] = origin
                    response.headers.add_vary_header('Origin')
            return response
|