Spaces:
Runtime error
Runtime error
File size: 2,099 Bytes
b6f0f70 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
import base64
from typing import List

from fastapi import Depends, HTTPException, status
from fastapi_jwt_auth import AuthJWT
from fastapi_jwt_auth.exceptions import AuthJWTException, MissingTokenError
from pydantic import BaseModel
from sqlalchemy.orm import Session

from . import models
from .config import settings
from .database import get_db
class Settings(BaseModel):
    """JWT configuration handed to fastapi_jwt_auth.

    Every field uses the ``authjwt_*`` prefix. Values are read once from the
    application ``settings`` object when this class is defined.
    """

    # Algorithm used to sign tokens; the same single algorithm is accepted
    # when decoding.
    authjwt_algorithm: str = settings.JWT_ALGORITHM
    authjwt_decode_algorithms: List[str] = [settings.JWT_ALGORITHM]

    # Tokens may be supplied via cookies or via request headers.
    authjwt_token_location: set = {"headers", "cookies"}

    # Cookie names under which the access and refresh tokens are stored.
    authjwt_access_cookie_key: str = "access_token"
    authjwt_refresh_cookie_key: str = "refresh_token"

    # CSRF protection for cookie-based tokens is disabled.
    authjwt_cookie_csrf_protect: bool = False

    # Key material is stored base64-encoded in the application settings
    # (presumably PEM text -- confirm) and is decoded to a UTF-8 string here.
    authjwt_public_key: str = str(
        base64.b64decode(settings.JWT_PUBLIC_KEY), "utf-8"
    )
    authjwt_private_key: str = str(
        base64.b64decode(settings.JWT_PRIVATE_KEY), "utf-8"
    )
@AuthJWT.load_config
def get_config():
    """Build the ``Settings`` instance that fastapi_jwt_auth loads as config."""
    config = Settings()
    return config
class NotVerified(Exception):
    """Signals that a user account has not completed verification."""
class UserNotFound(Exception):
    """Signals that no user record matches the authenticated identity."""
def require_user(db: Session = Depends(get_db), Authorize: AuthJWT = Depends()):
    """FastAPI dependency requiring an authenticated, verified user.

    Validates the request's access token, looks up the user named by the
    token subject and checks that the account is verified.

    Args:
        db: Database session used to look up the user.
        Authorize: fastapi_jwt_auth handle for the current request.

    Returns:
        The JWT subject (the user id) of the authenticated user.

    Raises:
        HTTPException: 401 when the token is missing, invalid or expired,
            when no user matches the token subject, or when the user has not
            been verified. Any other error (e.g. a database failure) is not
            masked as an auth failure and propagates unchanged.
    """
    # Only the token handling belongs in the try: catching everything here
    # used to turn unrelated server errors into misleading 401 responses.
    try:
        Authorize.jwt_required()
        user_id = Authorize.get_jwt_subject()
    except MissingTokenError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail='You are not logged in') from exc
    except AuthJWTException as exc:
        # Any other fastapi_jwt_auth failure (decode error, bad header,
        # revoked / wrong token type, ...) is reported as an invalid token.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail='Token is invalid or has expired') from exc

    user = db.query(models.User).filter(models.User.id == user_id).first()
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail='User no longer exist')
    if not user.verified:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail='Please verify your account')
    return user_id
|