|
from fastapi import APIRouter, status, Depends, Request, UploadFile, File, HTTPException |
|
from fastapi.responses import JSONResponse |
|
from sqlalchemy.orm import Session |
|
from core.database import get_db |
|
from users.models import UserEmbeddings |
|
from core.security import get_current_user, create_refresh_token, create_access_token |
|
from users.schemas import UserCreate, UserBase, UserEmbeddingsBase, User, UserUpdate |
|
from users.services import create_user_account, create_user_embeddings, update_user, update_user_embeddings, get_user_by_id, get_user_by_email |
|
from services.facial_processing import FacialProcessing |
|
from services.face_match import FaceMatch |
|
import os |
|
import tempfile |
|
from datetime import timedelta |
|
from dotenv import load_dotenv |
|
from auth.services import get_token |
|
|
|
# Load variables from a local .env file into the process environment so the
# os.getenv(...) lookups in the endpoints below can see them.
load_dotenv()

# Every route in this module is mounted under /users and grouped under the
# "Users" tag; 404 is advertised as a possible response for all of them.
router = APIRouter(
    responses={404: {"description": "Not found"}},
    tags=["Users"],
    prefix="/users",
)
|
|
|
@router.post("/", status_code=status.HTTP_201_CREATED, response_model=UserBase)
async def create_user(data: UserCreate, db: Session = Depends(get_db)):
    """Register a new user and immediately issue an access/refresh token pair.

    Args:
        data: Registration payload.
        db: Database session injected via ``get_db``.

    Returns:
        A 201 JSON response containing ``access_token``, ``refresh_token``,
        ``token_type`` and ``expires_in`` (access-token lifetime in seconds).
    """
    # NOTE(review): response_model=UserBase does not describe this token payload.
    # Because a JSONResponse is returned directly, FastAPI does not apply the
    # response_model here, so only the OpenAPI schema is misleading. Consider a
    # dedicated token schema.
    new_user = await create_user_account(data, db)

    access_token_expires = timedelta(
        minutes=int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "30"))
    )
    payload = {"id": new_user.id, "sub": new_user.email}
    access_token = await create_access_token(data=payload, expiry=access_token_expires)
    refresh_token = await create_refresh_token(data=payload)

    return JSONResponse(
        content={
            "access_token": access_token,
            "refresh_token": refresh_token,
            "token_type": "Bearer",
            # total_seconds(), not .seconds: .seconds is only the sub-day
            # component, so lifetimes of 24h or more were under-reported.
            "expires_in": int(access_token_expires.total_seconds()),
        },
        # Returning a Response directly overrides the decorator's status_code,
        # so repeat the declared 201 here instead of the previous 200.
        status_code=status.HTTP_201_CREATED,
    )
|
|
|
@router.get("/me/", response_model=UserBase)
async def read_users_me(current_user: User = Depends(get_current_user)):
    """Return the authenticated user (resolved by ``get_current_user``), serialized as ``UserBase``."""
    authenticated_user = current_user
    return authenticated_user
|
|
|
@router.put("/me/", response_model=UserBase)
async def update_user_me(
    user: UserUpdate,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Apply the ``user`` update payload to the authenticated user's record.

    Returns whatever ``update_user`` returns, serialized through ``UserBase``.
    """
    # NOTE(review): unlike create_user_account / create_user_embeddings, update_user
    # is not awaited and takes ``db`` as its first argument. Confirm it is a
    # synchronous function with this argument order.
    return update_user(db, current_user.id, user)
|
|
|
@router.post("/me/face/", status_code=status.HTTP_200_OK)
async def create_face_embeddings(
    file: UploadFile = File(...),
    user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Extract face embeddings from an uploaded image and store them for the current user.

    The upload is written to a temporary file because ``extract_embeddings_vgg``
    is given a filesystem path. The file is always removed afterwards.

    Raises:
        HTTPException: 400 if no embeddings could be extracted from the image.
    """
    face_processor = FacialProcessing()
    contents = await file.read()

    image_path = None
    try:
        # delete=False keeps the file on disk after the handle is closed, so the
        # extractor can open it by path. It is removed explicitly in `finally`.
        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
            image_path = temp_file.name
            temp_file.write(contents)
        # NOTE(review): synchronous model inference inside an async endpoint
        # blocks the event loop; consider offloading to a thread pool.
        embeddings = face_processor.extract_embeddings_vgg(image_path)
    finally:
        # Previously the temp image leaked whenever extraction (or the DB write)
        # raised. The image is not needed after extraction, so delete it now.
        if image_path is not None:
            os.remove(image_path)

    if not embeddings:
        raise HTTPException(status_code=400, detail="Failed to process face")

    user_embeddings = UserEmbeddingsBase(embeddings=embeddings)
    await create_user_embeddings(user.id, user_embeddings, db)
    return {"message": "Face embeddings created successfully"}
|
|
|
@router.get("/me/face/", status_code=status.HTTP_200_OK)
async def get_face_embeddings(
    user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Return the stored face embeddings of the authenticated user.

    Raises:
        HTTPException: 404 when no ``UserEmbeddings`` row exists for the user.
    """
    stored = (
        db.query(UserEmbeddings)
        .filter(UserEmbeddings.user_id == user.id)
        .first()
    )
    if stored:
        return JSONResponse(
            content={"embeddings": stored.embeddings},
            status_code=status.HTTP_200_OK,
        )
    raise HTTPException(status_code=404, detail="Face embeddings not found")
|
|
|
@router.put("/me/face/", status_code=status.HTTP_200_OK)
async def update_face_embeddings(
    file: UploadFile = File(...),
    user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Re-extract face embeddings from an uploaded image and replace the user's stored ones.

    The upload is written to a temporary file because ``extract_embeddings_vgg``
    is given a filesystem path. The file is always removed afterwards.

    Raises:
        HTTPException: 400 if no embeddings could be extracted from the image.
    """
    face_processor = FacialProcessing()
    contents = await file.read()

    image_path = None
    try:
        # delete=False keeps the file on disk after the handle is closed, so the
        # extractor can open it by path. It is removed explicitly in `finally`.
        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
            image_path = temp_file.name
            temp_file.write(contents)
        # NOTE(review): synchronous model inference inside an async endpoint
        # blocks the event loop; consider offloading to a thread pool.
        embeddings = face_processor.extract_embeddings_vgg(image_path)
    finally:
        # Previously the temp image leaked whenever extraction (or the DB write)
        # raised. The image is not needed after extraction, so delete it now.
        if image_path is not None:
            os.remove(image_path)

    if not embeddings:
        raise HTTPException(status_code=400, detail="Failed to process face")

    user_embeddings = UserEmbeddingsBase(embeddings=embeddings)
    await update_user_embeddings(user.id, user_embeddings, db)
    return {"message": "Face embeddings updated successfully"}
|
|
|
@router.post("/login/face/")
async def face_login(file: UploadFile = File(...), db: Session = Depends(get_db)):
    """Authenticate a user by matching an uploaded face image against stored embeddings.

    Returns:
        On a successful match, a JSON response with ``access_token``,
        ``refresh_token``, ``token_type`` and ``expires_in`` (access-token
        lifetime in seconds).

    Raises:
        HTTPException: 400 if no embeddings could be extracted from the image,
            404 if the matched user id has no user record, and 401 if
            ``FaceMatch`` does not report ``'Success'``.
    """
    # NOTE(review): no liveness / anti-spoofing check is visible here; confirm
    # FaceMatch (or an upstream layer) guards against photo replay.
    face_processor = FacialProcessing()
    face_matcher = FaceMatch(db)
    contents = await file.read()

    image_path = None
    try:
        # delete=False keeps the file on disk after the handle is closed, so the
        # extractor can open it by path. It is removed explicitly in `finally`.
        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
            image_path = temp_file.name
            temp_file.write(contents)
        embeddings = face_processor.extract_embeddings_vgg(image_path)
    finally:
        # Previously the temp image leaked if extraction, matching or user lookup
        # raised. Only extraction needs the file, so delete it right away.
        if image_path is not None:
            os.remove(image_path)

    if not embeddings:
        raise HTTPException(status_code=400, detail="Failed to process face")

    match_result = face_matcher.new_face_matching(embeddings)
    if match_result['status'] != 'Success':
        raise HTTPException(status_code=401, detail="Face not recognized")

    # NOTE(review): get_user_by_id is called without `await`. If it were a
    # coroutine function, `not user` would never be true and `user.id` would fail.
    # Confirm it is synchronous.
    user = get_user_by_id(match_result['user_id'], db)
    if not user:
        raise HTTPException(status_code=404, detail="User not found")

    access_token_expires = timedelta(
        minutes=int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "30"))
    )
    payload = {"id": user.id, "sub": user.email}
    access_token = await create_access_token(data=payload, expiry=access_token_expires)
    refresh_token = await create_refresh_token(data=payload)

    return JSONResponse(
        content={
            "access_token": access_token,
            "refresh_token": refresh_token,
            # NOTE(review): create_user returns "Bearer"; consider unifying the casing.
            "token_type": "bearer",
            # total_seconds(), not .seconds: .seconds is only the sub-day
            # component, so lifetimes of 24h or more were under-reported.
            "expires_in": int(access_token_expires.total_seconds()),
        },
        status_code=status.HTTP_200_OK,
    )