# SnapFeast / users/routes.py
# Author: Testys
# Commit message: Adding migrations from alembic
# Commit: d6866b9
# (page controls: raw, history, blame)
# File size: 5.73 kB
from fastapi import APIRouter, status, Depends, Request, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
from sqlalchemy.orm import Session
from core.database import get_db
from users.models import UserEmbeddings
from core.security import get_current_user, create_refresh_token, create_access_token
from users.schemas import UserCreate, UserBase, UserEmbeddingsBase, User, UserUpdate
from users.services import create_user_account, create_user_embeddings, update_user, update_user_embeddings, get_user_by_id, get_user_by_email
from services.facial_processing import FacialProcessing
from services.face_match import FaceMatch
import os
import tempfile
from datetime import timedelta
from dotenv import load_dotenv
from auth.services import get_token
# Populate os.environ from a .env file via python-dotenv so that settings such
# as ACCESS_TOKEN_EXPIRE_MINUTES can be read with os.getenv in the handlers.
load_dotenv()

# Every endpoint in this module is mounted under /users and grouped under the
# "Users" tag; a default 404 response description is attached to all of them.
router = APIRouter(
    tags=["Users"],
    prefix="/users",
    responses={
        404: {"description": "Not found"},
    },
)
@router.post("/", status_code=status.HTTP_201_CREATED, response_model=UserBase)
async def create_user(data: UserCreate, db: Session = Depends(get_db)):
    """Register a new user account and immediately issue a token pair.

    Args:
        data: Sign-up payload, validated as ``UserCreate``.
        db: Database session supplied by ``get_db``.

    Returns:
        A JSONResponse with ``access_token``, ``refresh_token``,
        ``token_type`` and ``expires_in`` (access-token lifetime in seconds).
    """
    new_user = await create_user_account(data, db)

    access_token_expires = timedelta(
        minutes=int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "30"))
    )
    payload = {"id": new_user.id, "sub": new_user.email}
    access_token = await create_access_token(data=payload, expiry=access_token_expires)
    refresh_token = await create_refresh_token(data=payload)

    # NOTE(review): the decorator declares 201 and response_model=UserBase, but
    # returning a Response directly bypasses response_model and this handler
    # answers 200. Runtime status kept at 200 for existing clients — decide
    # which contract is intended and align the decorator with it.
    return JSONResponse(
        content={
            "access_token": access_token,
            "refresh_token": refresh_token,
            "token_type": "Bearer",
            # total_seconds(), not .seconds: .seconds is only the sub-day
            # component, so an expiry of 24h or more was reported wrongly
            # (e.g. 1440 minutes -> 0).
            "expires_in": int(access_token_expires.total_seconds()),
        },
        status_code=status.HTTP_200_OK,
    )
@router.get("/me/", response_model=UserBase)
async def read_users_me(current_user: User = Depends(get_current_user)):
    """Return the authenticated caller's own user record.

    The user is resolved by the ``get_current_user`` dependency and the
    response is shaped by ``response_model=UserBase``.
    """
    me = current_user
    return me
@router.put("/me/", response_model=UserBase)
async def update_user_me(user: UserUpdate, current_user: User = Depends(get_current_user), db: Session = Depends(get_db)):
    """Apply the ``UserUpdate`` payload to the authenticated caller's record.

    Returns whatever ``update_user`` returns, shaped by ``UserBase``.
    """
    # NOTE(review): unlike create_user_account / create_user_embeddings,
    # update_user is not awaited and takes ``db`` as its first argument —
    # confirm it is synchronous and that this argument order matches its
    # signature in users.services.
    return update_user(db, current_user.id, user)
@router.post("/me/face/", status_code=status.HTTP_200_OK)
async def create_face_embeddings(file: UploadFile = File(...), user: User = Depends(get_current_user), db: Session = Depends(get_db)):
    """Extract face embeddings from an uploaded image and store them for the caller.

    Args:
        file: Uploaded image containing the user's face.
        user: Authenticated user, from ``get_current_user``.
        db: Database session supplied by ``get_db``.

    Returns:
        A success message dict.

    Raises:
        HTTPException: 400 when no embeddings could be extracted.
    """
    face_processor = FacialProcessing()

    # extract_embeddings_vgg takes a filesystem path, so spool the upload to a
    # temp file. The try/finally guarantees the file is deleted even if the
    # read, the write, or extraction raises (previously those paths leaked it).
    temp_file = tempfile.NamedTemporaryFile(delete=False)
    image_path = temp_file.name
    try:
        with temp_file:
            temp_file.write(await file.read())
        embeddings = face_processor.extract_embeddings_vgg(image_path)
    finally:
        os.remove(image_path)

    # NOTE(review): truthiness test kept from the original; if the extractor
    # can return a multi-element numpy array this check would raise — confirm
    # its return type.
    if not embeddings:
        raise HTTPException(status_code=400, detail="Failed to process face")

    user_embeddings = UserEmbeddingsBase(embeddings=embeddings)
    await create_user_embeddings(user.id, user_embeddings, db)
    return {"message": "Face embeddings created successfully"}
@router.get("/me/face/", status_code=status.HTTP_200_OK)
async def get_face_embeddings(user: User = Depends(get_current_user), db: Session = Depends(get_db)):
    """Return the stored face embeddings of the authenticated caller.

    Looks up the first ``UserEmbeddings`` row whose ``user_id`` matches the
    caller.

    Raises:
        HTTPException: 404 when no such row exists.
    """
    stored = (
        db.query(UserEmbeddings)
        .filter(UserEmbeddings.user_id == user.id)
        .first()
    )
    if stored:
        body = {"embeddings": stored.embeddings}
        return JSONResponse(content=body, status_code=status.HTTP_200_OK)
    raise HTTPException(status_code=404, detail="Face embeddings not found")
@router.put("/me/face/", status_code=status.HTTP_200_OK)
async def update_face_embeddings(file: UploadFile = File(...), user: User = Depends(get_current_user), db: Session = Depends(get_db)):
    """Re-extract face embeddings from an uploaded image and update the caller's record.

    Args:
        file: Uploaded image containing the user's face.
        user: Authenticated user, from ``get_current_user``.
        db: Database session supplied by ``get_db``.

    Returns:
        A success message dict.

    Raises:
        HTTPException: 400 when no embeddings could be extracted.
    """
    face_processor = FacialProcessing()

    # extract_embeddings_vgg takes a filesystem path, so spool the upload to a
    # temp file. The try/finally guarantees the file is deleted even if the
    # read, the write, or extraction raises (previously those paths leaked it).
    temp_file = tempfile.NamedTemporaryFile(delete=False)
    image_path = temp_file.name
    try:
        with temp_file:
            temp_file.write(await file.read())
        embeddings = face_processor.extract_embeddings_vgg(image_path)
    finally:
        os.remove(image_path)

    # NOTE(review): truthiness test kept from the original; if the extractor
    # can return a multi-element numpy array this check would raise — confirm
    # its return type.
    if not embeddings:
        raise HTTPException(status_code=400, detail="Failed to process face")

    user_embeddings = UserEmbeddingsBase(embeddings=embeddings)
    await update_user_embeddings(user.id, user_embeddings, db)
    return {"message": "Face embeddings updated successfully"}
@router.post("/login/face/")
async def face_login(file: UploadFile = File(...), db: Session = Depends(get_db)):
    """Authenticate a user by face and issue an access/refresh token pair.

    Embeddings are extracted from the uploaded image and matched with
    ``FaceMatch.new_face_matching``. When the match reports ``'Success'``,
    tokens are issued for the matched user.

    Args:
        file: Uploaded image containing the face to authenticate.
        db: Database session supplied by ``get_db``.

    Returns:
        A JSONResponse with ``access_token``, ``refresh_token``,
        ``token_type`` and ``expires_in`` (access-token lifetime in seconds).

    Raises:
        HTTPException: 400 when no embeddings could be extracted, 401 when the
            face is not recognized, 404 when the matched user id has no user.
    """
    face_processor = FacialProcessing()
    face_matcher = FaceMatch(db)

    # The image is only needed for extraction. Delete it on every path,
    # including exceptions from the read, the write, or the extractor, which
    # previously leaked the temp file.
    temp_file = tempfile.NamedTemporaryFile(delete=False)
    image_path = temp_file.name
    try:
        with temp_file:
            temp_file.write(await file.read())
        embeddings = face_processor.extract_embeddings_vgg(image_path)
    finally:
        os.remove(image_path)

    if not embeddings:
        raise HTTPException(status_code=400, detail="Failed to process face")

    match_result = face_matcher.new_face_matching(embeddings)
    if match_result['status'] != 'Success':
        raise HTTPException(status_code=401, detail="Face not recognized")

    user = get_user_by_id(match_result['user_id'], db)
    if not user:
        raise HTTPException(status_code=404, detail="User not found")

    access_token_expires = timedelta(
        minutes=int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "30"))
    )
    payload = {"id": user.id, "sub": user.email}
    access_token = await create_access_token(data=payload, expiry=access_token_expires)
    refresh_token = await create_refresh_token(data=payload)

    return JSONResponse(
        content={
            "access_token": access_token,
            "refresh_token": refresh_token,
            "token_type": "bearer",
            # total_seconds(), not .seconds: .seconds is only the sub-day
            # component, so an expiry of 24h or more was reported wrongly.
            "expires_in": int(access_token_expires.total_seconds()),
        },
        status_code=status.HTTP_200_OK,
    )