Testys commited on
Commit
60cc4ec
1 Parent(s): d76e641

Push new edits

Browse files
core/config.py CHANGED
@@ -2,31 +2,42 @@ import os
2
  from pathlib import Path
3
  from dotenv import load_dotenv
4
  from urllib.parse import quote_plus
5
- from pydantic_settings import BaseSettings # Make sure this import is correct
6
 
7
  load_dotenv()
8
 
9
  class Settings(BaseSettings):
10
- # Use environment variables to configure the database URL components
11
  DATABASE_USER: str = os.getenv("PG_USER")
12
  DATABASE_PASSWORD: str = os.getenv("PG_PASSWORD")
13
  DATABASE_HOST: str = os.getenv("PG_HOST")
14
  DATABASE_PORT: str = os.getenv("PG_PORT")
15
  DATABASE_NAME: str = os.getenv("PG_NAME")
16
 
17
- # Combine components into the full URL using a property
 
 
 
 
 
 
 
 
 
 
 
 
18
  @property
19
  def DATABASE_URL(self) -> str:
20
  user = quote_plus(self.DATABASE_USER)
21
  password = quote_plus(self.DATABASE_PASSWORD)
22
  return f"postgresql://{user}:{password}@{self.DATABASE_HOST}:{self.DATABASE_PORT}/{self.DATABASE_NAME}?sslmode=require"
23
 
24
- JWT_SECRET_KEY: str = os.getenv("JWT_SECRET")
25
- JWT_ALGORITHM: str = os.getenv("JWT_ALGORITHM")
26
- ACCESS_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", default="30")) # Default value as fallback
27
 
28
- def get_settings():
29
  return Settings()
30
 
31
  # Usage example
32
- settings = get_settings()
 
2
  from pathlib import Path
3
  from dotenv import load_dotenv
4
  from urllib.parse import quote_plus
5
+ from pydantic_settings import BaseSettings
6
 
7
  load_dotenv()
8
 
9
  class Settings(BaseSettings):
10
+ # Database settings
11
  DATABASE_USER: str = os.getenv("PG_USER")
12
  DATABASE_PASSWORD: str = os.getenv("PG_PASSWORD")
13
  DATABASE_HOST: str = os.getenv("PG_HOST")
14
  DATABASE_PORT: str = os.getenv("PG_PORT")
15
  DATABASE_NAME: str = os.getenv("PG_NAME")
16
 
17
+ # JWT settings
18
+ JWT_SECRET_KEY: str = os.getenv("JWT_SECRET")
19
+ JWT_ALGORITHM: str = "HS256"
20
+ ACCESS_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "30"))
21
+ REFRESH_TOKEN_EXPIRE_DAYS: int = int(os.getenv("REFRESH_TOKEN_EXPIRE_DAYS", "7"))
22
+
23
+ # Application settings
24
+ DEBUG: bool = os.getenv("DEBUG", "False").lower() == "true"
25
+ ALLOWED_HOSTS: list = os.getenv("ALLOWED_HOSTS", "*").split(",")
26
+
27
+ # Face recognition settings
28
+ FACE_RECOGNITION_THRESHOLD: float = float(os.getenv("FACE_RECOGNITION_THRESHOLD", "0.6"))
29
+
30
  @property
31
  def DATABASE_URL(self) -> str:
32
  user = quote_plus(self.DATABASE_USER)
33
  password = quote_plus(self.DATABASE_PASSWORD)
34
  return f"postgresql://{user}:{password}@{self.DATABASE_HOST}:{self.DATABASE_PORT}/{self.DATABASE_NAME}?sslmode=require"
35
 
36
+ class Config:
37
+ env_file = ".env"
 
38
 
39
+ def get_settings() -> Settings:
40
  return Settings()
41
 
42
  # Usage example
43
+ settings = get_settings()
core/database.py CHANGED
@@ -14,6 +14,7 @@ engine = create_engine(
14
  max_overflow=0
15
 
16
  )
 
17
  SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
18
 
19
  Base = declarative_base()
 
14
  max_overflow=0
15
 
16
  )
17
+
18
  SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
19
 
20
  Base = declarative_base()
core/security.py CHANGED
@@ -1,75 +1,59 @@
1
  from passlib.context import CryptContext
2
  from fastapi.security import OAuth2PasswordBearer
3
- from fastapi import Depends, HTTPException
4
  from datetime import timedelta, datetime
5
  from jose import JWTError, jwt
6
  from core.config import get_settings
7
  from sqlalchemy.orm import Session
8
  from core.database import get_db
9
 
10
-
11
  settings = get_settings()
12
 
13
  pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
14
- oauth2scheme = OAuth2PasswordBearer(tokenUrl="auth/token/")
15
 
16
- def get_password_hash(password):
17
  return pwd_context.hash(password)
18
 
19
- def verify_password(plain_password, hashed_password):
20
  return pwd_context.verify(plain_password, hashed_password)
21
 
22
- async def create_access_token(data:dict, expiry:timedelta):
23
  payload = data.copy()
24
  expire = datetime.utcnow() + expiry
25
  payload.update({"exp": expire})
26
- token = jwt.encode(payload,
27
- settings.JWT_SECRET_KEY,
28
- algorithm=settings.JWT_ALGORITHM)
29
-
30
- return token
31
 
32
- async def create_refresh_token(data:dict):
33
  payload = data.copy()
34
- token = jwt.encode(payload,
35
- settings.JWT_SECRET_KEY,
36
- algorithm=settings.JWT_ALGORITHM)
37
- return token
38
 
39
- def get_token_payload(token:str):
40
  try:
41
- payload = jwt.decode(token,
42
- settings.JWT_SECRET_KEY,
43
- algorithms=[settings.JWT_ALGORITHM])
44
- return payload
45
  except JWTError:
46
  return None
47
-
48
- async def get_current_user(token: str = Depends(oauth2scheme), db: Session = Depends(get_db)):
49
- from users.services import get_user_by_email # Local import
 
 
 
 
 
 
50
 
51
  try:
52
  payload = get_token_payload(token)
53
- email = payload.get("sub")
 
 
54
  if email is None:
55
- raise HTTPException(status_code=401,
56
- detail="Invalid Token",
57
- headers={"WWW-Authenticate": "Bearer"})
58
  except JWTError:
59
- raise HTTPException(status_code=401,
60
- detail="Invalid Token",
61
- headers={"WWW-Authenticate": "Bearer"}
62
- )
63
-
64
  user = get_user_by_email(email, db=db)
65
  if user is None:
66
- raise HTTPException(status_code=401,
67
- detail="User not found",
68
- headers={"WWW-Authenticate": "Bearer"}
69
- )
70
- return user
71
-
72
-
73
-
74
-
75
-
 
1
  from passlib.context import CryptContext
2
  from fastapi.security import OAuth2PasswordBearer
3
+ from fastapi import Depends, HTTPException, status
4
  from datetime import timedelta, datetime
5
  from jose import JWTError, jwt
6
  from core.config import get_settings
7
  from sqlalchemy.orm import Session
8
  from core.database import get_db
9
 
 
10
  settings = get_settings()
11
 
12
  pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
13
+ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/token/")
14
 
15
+ def get_password_hash(password: str) -> str:
16
  return pwd_context.hash(password)
17
 
18
+ def verify_password(plain_password: str, hashed_password: str) -> bool:
19
  return pwd_context.verify(plain_password, hashed_password)
20
 
21
+ async def create_access_token(data: dict, expiry: timedelta) -> str:
22
  payload = data.copy()
23
  expire = datetime.utcnow() + expiry
24
  payload.update({"exp": expire})
25
+ return jwt.encode(payload, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)
 
 
 
 
26
 
27
+ async def create_refresh_token(data: dict) -> str:
28
  payload = data.copy()
29
+ return jwt.encode(payload, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)
 
 
 
30
 
31
+ def get_token_payload(token: str) -> dict:
32
  try:
33
+ return jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
 
 
 
34
  except JWTError:
35
  return None
36
+
37
+ async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)):
38
+ from users.services import get_user_by_email # Local import to avoid circular dependency
39
+
40
+ credentials_exception = HTTPException(
41
+ status_code=status.HTTP_401_UNAUTHORIZED,
42
+ detail="Could not validate credentials",
43
+ headers={"WWW-Authenticate": "Bearer"},
44
+ )
45
 
46
  try:
47
  payload = get_token_payload(token)
48
+ if payload is None:
49
+ raise credentials_exception
50
+ email: str = payload.get("sub")
51
  if email is None:
52
+ raise credentials_exception
 
 
53
  except JWTError:
54
+ raise credentials_exception
55
+
 
 
 
56
  user = get_user_by_email(email, db=db)
57
  if user is None:
58
+ raise credentials_exception
59
+ return user
 
 
 
 
 
 
 
 
main.py CHANGED
@@ -1,6 +1,5 @@
1
- # main.py
2
  from fastapi import FastAPI
3
- from typing import List
4
  from users.routes import router as users_router
5
  from auth.route import router as auth_router
6
  from orders.routes import order_router, meal_router
@@ -13,24 +12,30 @@ app = FastAPI(
13
  redoc_url="/redoc",
14
  openapi_url="/openapi.json",
15
  )
 
 
 
 
 
 
 
 
 
 
 
16
  app.include_router(users_router)
17
  app.include_router(auth_router)
18
  app.include_router(meal_router)
19
  app.include_router(order_router)
20
 
21
-
22
  @app.get("/", tags=["Home"])
23
  def read_root():
24
  return {"message": "Welcome to SnapFeast API!"}
25
 
26
-
27
  @app.get("/health", tags=["Health"])
28
  def health_check():
29
  return {"status": "ok"}
30
 
31
-
32
  if __name__ == "__main__":
33
  import uvicorn
34
-
35
- uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")
36
-
 
 
1
  from fastapi import FastAPI
2
+ from fastapi.middleware.cors import CORSMiddleware
3
  from users.routes import router as users_router
4
  from auth.route import router as auth_router
5
  from orders.routes import order_router, meal_router
 
12
  redoc_url="/redoc",
13
  openapi_url="/openapi.json",
14
  )
15
+
16
+ # Configure CORS
17
+ app.add_middleware(
18
+ CORSMiddleware,
19
+ allow_origins=["*"], # Allows all origins
20
+ allow_credentials=True,
21
+ allow_methods=["*"], # Allows all methods
22
+ allow_headers=["*"], # Allows all headers
23
+ )
24
+
25
+ # Include routers
26
  app.include_router(users_router)
27
  app.include_router(auth_router)
28
  app.include_router(meal_router)
29
  app.include_router(order_router)
30
 
 
31
  @app.get("/", tags=["Home"])
32
  def read_root():
33
  return {"message": "Welcome to SnapFeast API!"}
34
 
 
35
  @app.get("/health", tags=["Health"])
36
  def health_check():
37
  return {"status": "ok"}
38
 
 
39
  if __name__ == "__main__":
40
  import uvicorn
41
+ uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
 
 
services/face_match.py CHANGED
@@ -2,16 +2,20 @@ from sklearn.metrics.pairwise import cosine_similarity
2
  import numpy as np
3
  from sqlalchemy.orm import Session
4
  from users.models import UserEmbeddings
 
 
 
5
 
6
  class FaceMatch:
7
  def __init__(self, db: Session):
8
  self.db = db
 
9
 
10
  def load_embeddings_from_db(self):
11
  user_embeddings = self.db.query(UserEmbeddings).all()
12
  return {ue.user_id: np.array(ue.embeddings) for ue in user_embeddings}
13
 
14
- def match_faces(self, new_embeddings, saved_embeddings, threshold=0.6):
15
  new_embeddings = np.array(new_embeddings)
16
  max_similarity = 0
17
  identity = None
@@ -22,12 +26,12 @@ class FaceMatch:
22
  max_similarity = similarity
23
  identity = user_id
24
 
25
- return identity, max_similarity if max_similarity > threshold else (None, 0)
26
 
27
  def new_face_matching(self, new_embeddings):
28
  embeddings_dict = self.load_embeddings_from_db()
29
  if not embeddings_dict:
30
- return {'status': 'Error', 'message': 'No embeddings available'}
31
 
32
  identity, similarity = self.match_faces(new_embeddings, embeddings_dict)
33
  if identity:
@@ -35,10 +39,9 @@ class FaceMatch:
35
  'status': 'Success',
36
  'message': 'Match Found',
37
  'user_id': identity,
38
- 'similarity': similarity
39
  }
40
  return {
41
  'status': 'Error',
42
  'message': 'No matching face found'
43
- }
44
-
 
2
  import numpy as np
3
  from sqlalchemy.orm import Session
4
  from users.models import UserEmbeddings
5
+ from core.config import get_settings
6
+
7
+ settings = get_settings()
8
 
9
  class FaceMatch:
10
  def __init__(self, db: Session):
11
  self.db = db
12
+ self.threshold = settings.FACE_RECOGNITION_THRESHOLD
13
 
14
  def load_embeddings_from_db(self):
15
  user_embeddings = self.db.query(UserEmbeddings).all()
16
  return {ue.user_id: np.array(ue.embeddings) for ue in user_embeddings}
17
 
18
+ def match_faces(self, new_embeddings, saved_embeddings):
19
  new_embeddings = np.array(new_embeddings)
20
  max_similarity = 0
21
  identity = None
 
26
  max_similarity = similarity
27
  identity = user_id
28
 
29
+ return identity, max_similarity if max_similarity > self.threshold else (None, 0)
30
 
31
  def new_face_matching(self, new_embeddings):
32
  embeddings_dict = self.load_embeddings_from_db()
33
  if not embeddings_dict:
34
+ return {'status': 'Error', 'message': 'No embeddings available in the database'}
35
 
36
  identity, similarity = self.match_faces(new_embeddings, embeddings_dict)
37
  if identity:
 
39
  'status': 'Success',
40
  'message': 'Match Found',
41
  'user_id': identity,
42
+ 'similarity': float(similarity) # Convert numpy float to Python float
43
  }
44
  return {
45
  'status': 'Error',
46
  'message': 'No matching face found'
47
+ }
 
services/facial_processing.py CHANGED
@@ -3,35 +3,41 @@ import os
3
  import torch
4
  from facenet_pytorch import MTCNN, InceptionResnetV1
5
  import logging
 
6
 
7
  logger = logging.getLogger(__name__)
8
 
9
-
10
  class FacialProcessing:
11
  def __init__(self):
12
- # Set the cache directory to a writable location
13
  os.environ['TORCH_HOME'] = '/tmp/.cache/torch'
14
-
15
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 
16
 
17
- self.mtcnn = MTCNN(keep_all=True, device=device)
18
- self.resnet = InceptionResnetV1(pretrained='vggface2').eval().to(device)
19
-
20
-
21
- def extract_embeddings_vgg(self, image):
22
  try:
23
- # Preprocess the image
24
- preprocessed_image = self.mtcnn(image)
25
-
26
- if preprocessed_image is None:
27
- logger.warning(f"No face detected in image")
 
 
 
 
 
 
 
 
 
 
 
28
  return None
29
 
30
- # Extract the face embeddings
31
- embeddings = self.resnet(preprocessed_image.unsqueeze(0)).detach().cpu().numpy().tolist()
32
- if embeddings:
33
- return embeddings[0]
34
 
35
  except Exception as e:
36
  logger.error(f"An error occurred while extracting embeddings: {e}")
37
- return None
 
3
  import torch
4
  from facenet_pytorch import MTCNN, InceptionResnetV1
5
  import logging
6
+ from PIL import Image
7
 
8
  logger = logging.getLogger(__name__)
9
 
 
10
  class FacialProcessing:
11
  def __init__(self):
 
12
  os.environ['TORCH_HOME'] = '/tmp/.cache/torch'
 
13
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14
+ self.mtcnn = MTCNN(keep_all=True, device=self.device)
15
+ self.resnet = InceptionResnetV1(pretrained='vggface2').eval().to(self.device)
16
 
17
+ def extract_embeddings_vgg(self, image_path):
 
 
 
 
18
  try:
19
+ img = Image.open(image_path)
20
+ img = img.convert('RGB')
21
+
22
+ # Detect faces
23
+ boxes, _ = self.mtcnn.detect(img)
24
+
25
+ if boxes is None:
26
+ logger.warning(f"No face detected in image: {image_path}")
27
+ return None
28
+
29
+ # Get the largest face
30
+ largest_box = max(boxes, key=lambda box: (box[2] - box[0]) * (box[3] - box[1]))
31
+ face = self.mtcnn(img, return_prob=False)
32
+
33
+ if face is None:
34
+ logger.warning(f"Failed to align face in image: {image_path}")
35
  return None
36
 
37
+ # Extract embeddings
38
+ embeddings = self.resnet(face).detach().cpu().numpy().flatten()
39
+ return embeddings.tolist()
 
40
 
41
  except Exception as e:
42
  logger.error(f"An error occurred while extracting embeddings: {e}")
43
+ return None
users/models.py CHANGED
@@ -7,16 +7,16 @@ class User(Base):
7
  __tablename__ = "users"
8
 
9
  id = Column(Integer, primary_key=True, index=True)
10
- email = Column(String, unique=True, index=True)
11
- username = Column(String, unique=True, index=True)
12
- hashed_password = Column(String)
13
  first_name = Column(String)
14
  last_name = Column(String)
15
  age = Column(Integer)
16
  preferences = Column(ARRAY(String))
17
  is_active = Column(Boolean, default=True)
18
  is_admin = Column(Boolean, default=False)
19
- updated_at = Column(DateTime, default=datetime.utcnow)
20
  created_at = Column(DateTime, default=datetime.utcnow)
21
 
22
  orders = relationship("Order", back_populates="user")
@@ -26,7 +26,7 @@ class UserEmbeddings(Base):
26
  __tablename__ = "user_embeddings"
27
 
28
  id = Column(Integer, primary_key=True, index=True)
29
- user_id = Column(Integer, ForeignKey("users.id"))
30
- embeddings = Column(ARRAY(Float))
31
 
32
- user = relationship("User", back_populates="embeddings")
 
7
  __tablename__ = "users"
8
 
9
  id = Column(Integer, primary_key=True, index=True)
10
+ email = Column(String, unique=True, index=True, nullable=False)
11
+ username = Column(String, unique=True, index=True, nullable=False)
12
+ hashed_password = Column(String, nullable=False)
13
  first_name = Column(String)
14
  last_name = Column(String)
15
  age = Column(Integer)
16
  preferences = Column(ARRAY(String))
17
  is_active = Column(Boolean, default=True)
18
  is_admin = Column(Boolean, default=False)
19
+ updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
20
  created_at = Column(DateTime, default=datetime.utcnow)
21
 
22
  orders = relationship("Order", back_populates="user")
 
26
  __tablename__ = "user_embeddings"
27
 
28
  id = Column(Integer, primary_key=True, index=True)
29
+ user_id = Column(Integer, ForeignKey("users.id"), unique=True, nullable=False)
30
+ embeddings = Column(ARRAY(Float), nullable=False)
31
 
32
+ user = relationship("User", back_populates="embeddings")
users/routes.py CHANGED
@@ -3,7 +3,7 @@ from fastapi.responses import JSONResponse
3
  from sqlalchemy.orm import Session
4
  from core.database import get_db
5
  from core.security import get_current_user, create_access_token
6
- from users.schemas import UserCreate, UserBase, UserEmbeddingsBase, User
7
  from users.services import create_user_account, create_user_embeddings, update_user, update_user_embeddings, get_user_by_id, get_user_by_email
8
  from services.facial_processing import FacialProcessing
9
  from services.face_match import FaceMatch
@@ -12,7 +12,6 @@ from datetime import timedelta
12
  from dotenv import load_dotenv
13
  from auth.services import get_token
14
 
15
-
16
  load_dotenv()
17
 
18
  router = APIRouter(
@@ -21,104 +20,81 @@ router = APIRouter(
21
  responses={404: {"description": "Not found"}},
22
  )
23
 
24
- @router.get("/health", tags=["Health"])
25
- def health_check():
26
- return {"status": "ok"}
27
-
28
  @router.post("/", status_code=status.HTTP_201_CREATED, response_model=UserBase)
29
  async def create_user(data: UserCreate, db: Session = Depends(get_db)):
30
- await create_user_account(data, db)
31
- payload = {"message": "User created successfully"}
32
- return JSONResponse(content=payload, status_code=status.HTTP_201_CREATED)
33
-
34
-
35
- @router.post("/me",status_code=status.HTTP_200_OK)
36
- def get_user_details(request: Request):
37
- user = request.user
38
- return user
39
 
40
  @router.get("/me", response_model=UserBase)
41
  async def read_users_me(current_user: User = Depends(get_current_user)):
42
  return current_user
43
 
44
  @router.put("/me", response_model=UserBase)
45
- async def updating_user(user: UserBase, current_user: User = Depends(get_current_user), db: Session = Depends(get_db)):
46
- user = update_user(db, current_user.id, user)
47
- return current_user
48
-
49
 
50
- @router.post("/me/face",status_code=status.HTTP_200_OK)
51
- async def create_face_embeddings(file: UploadFile = File(...), user:User = Depends(get_current_user), db: Session = Depends(get_db)):
52
  face_processor = FacialProcessing()
53
 
54
- # Process the uploaded image
55
  image_path = f"faces/{user.id}.jpg"
56
  with open(image_path, "wb") as buffer:
57
  buffer.write(await file.read())
58
 
59
- # Extract embeddings
60
  embeddings = face_processor.extract_embeddings_vgg(image_path)
61
  if embeddings:
62
- create_user_embeddings(user.id, embeddings, db)
 
63
  return {"message": "Face embeddings created successfully"}
64
 
65
- raise HTTPException(status_code=400,
66
- detail="Failed to process face"
67
- )
68
 
69
- @router.get("/me/face",status_code=status.HTTP_200_OK)
70
- async def get_face_embeddings(user:User = Depends(get_current_user), db: Session = Depends(get_db)):
71
  face = db.query(UserEmbeddingsBase).filter(UserEmbeddingsBase.user_id == user.id).first()
72
  if not face:
73
- raise HTTPException(status_code=404,
74
- detail="Face embeddings not found"
75
- )
76
  return JSONResponse(content={"embeddings": face.embeddings}, status_code=status.HTTP_200_OK)
77
 
78
- @router.put("/me/face",status_code=status.HTTP_200_OK)
79
- async def updating_face_embeddings(file: UploadFile = File(...), user:User = Depends(get_current_user), db: Session = Depends(get_db)):
80
  face_processor = FacialProcessing()
81
 
82
- # Process the uploaded image
83
  image_path = f"faces/{user.id}.jpg"
84
  with open(image_path, "wb") as buffer:
85
  buffer.write(await file.read())
86
 
87
- # Extract embeddings
88
  embeddings = face_processor.extract_embeddings_vgg(image_path)
89
  if embeddings:
90
- update_user_embeddings(user.id, embeddings, db)
 
91
  return {"message": "Face embeddings updated successfully"}
92
 
93
- raise HTTPException(status_code=400,
94
- detail="Failed to process face"
95
- )
96
 
97
  @router.post("/login/face")
98
  async def face_login(file: UploadFile = File(...), db: Session = Depends(get_db)):
99
  face_processor = FacialProcessing()
100
  face_matcher = FaceMatch(db)
101
 
102
- # Process the uploaded image
103
  image_path = f"temp_{file.filename}"
104
  with open(image_path, "wb") as buffer:
105
  buffer.write(await file.read())
106
 
107
- # Extract embeddings
108
  embeddings = await face_processor.extract_embeddings(image_path)
109
  if not embeddings:
110
  raise HTTPException(status_code=400, detail="Failed to process face")
111
 
112
- # Match face
113
  match_result = face_matcher.new_face_matching(embeddings)
114
  if match_result['status'] == 'Success':
115
  user = get_user_by_id(match_result['user_id'], db)
116
  if not user:
117
  raise HTTPException(status_code=404, detail="User not found")
118
 
119
- access_token_expires = timedelta(minutes=os.getenv("ACCESS_TOKEN"))
120
- payload = {"id":user.id, "sub": user.email}
121
  token = get_token(payload, db)
122
  return JSONResponse(content=token.dict(), status_code=status.HTTP_200_OK)
123
 
124
- raise HTTPException(status_code=401, detail="Face not recognized")
 
3
  from sqlalchemy.orm import Session
4
  from core.database import get_db
5
  from core.security import get_current_user, create_access_token
6
+ from users.schemas import UserCreate, UserBase, UserEmbeddingsBase, User, UserUpdate
7
  from users.services import create_user_account, create_user_embeddings, update_user, update_user_embeddings, get_user_by_id, get_user_by_email
8
  from services.facial_processing import FacialProcessing
9
  from services.face_match import FaceMatch
 
12
  from dotenv import load_dotenv
13
  from auth.services import get_token
14
 
 
15
  load_dotenv()
16
 
17
  router = APIRouter(
 
20
  responses={404: {"description": "Not found"}},
21
  )
22
 
 
 
 
 
23
  @router.post("/", status_code=status.HTTP_201_CREATED, response_model=UserBase)
24
  async def create_user(data: UserCreate, db: Session = Depends(get_db)):
25
+ new_user = await create_user_account(data, db)
26
+ return new_user
 
 
 
 
 
 
 
27
 
28
  @router.get("/me", response_model=UserBase)
29
  async def read_users_me(current_user: User = Depends(get_current_user)):
30
  return current_user
31
 
32
  @router.put("/me", response_model=UserBase)
33
+ async def update_user_me(user: UserUpdate, current_user: User = Depends(get_current_user), db: Session = Depends(get_db)):
34
+ updated_user = update_user(db, current_user.id, user)
35
+ return updated_user
 
36
 
37
+ @router.post("/me/face", status_code=status.HTTP_200_OK)
38
+ async def create_face_embeddings(file: UploadFile = File(...), user: User = Depends(get_current_user), db: Session = Depends(get_db)):
39
  face_processor = FacialProcessing()
40
 
 
41
  image_path = f"faces/{user.id}.jpg"
42
  with open(image_path, "wb") as buffer:
43
  buffer.write(await file.read())
44
 
 
45
  embeddings = face_processor.extract_embeddings_vgg(image_path)
46
  if embeddings:
47
+ user_embeddings = UserEmbeddingsBase(embeddings=embeddings)
48
+ await create_user_embeddings(user.id, user_embeddings, db)
49
  return {"message": "Face embeddings created successfully"}
50
 
51
+ raise HTTPException(status_code=400, detail="Failed to process face")
 
 
52
 
53
+ @router.get("/me/face", status_code=status.HTTP_200_OK)
54
+ async def get_face_embeddings(user: User = Depends(get_current_user), db: Session = Depends(get_db)):
55
  face = db.query(UserEmbeddingsBase).filter(UserEmbeddingsBase.user_id == user.id).first()
56
  if not face:
57
+ raise HTTPException(status_code=404, detail="Face embeddings not found")
 
 
58
  return JSONResponse(content={"embeddings": face.embeddings}, status_code=status.HTTP_200_OK)
59
 
60
+ @router.put("/me/face", status_code=status.HTTP_200_OK)
61
+ async def update_face_embeddings(file: UploadFile = File(...), user: User = Depends(get_current_user), db: Session = Depends(get_db)):
62
  face_processor = FacialProcessing()
63
 
 
64
  image_path = f"faces/{user.id}.jpg"
65
  with open(image_path, "wb") as buffer:
66
  buffer.write(await file.read())
67
 
 
68
  embeddings = face_processor.extract_embeddings_vgg(image_path)
69
  if embeddings:
70
+ user_embeddings = UserEmbeddingsBase(embeddings=embeddings)
71
+ await update_user_embeddings(user.id, user_embeddings, db)
72
  return {"message": "Face embeddings updated successfully"}
73
 
74
+ raise HTTPException(status_code=400, detail="Failed to process face")
 
 
75
 
76
  @router.post("/login/face")
77
  async def face_login(file: UploadFile = File(...), db: Session = Depends(get_db)):
78
  face_processor = FacialProcessing()
79
  face_matcher = FaceMatch(db)
80
 
 
81
  image_path = f"temp_{file.filename}"
82
  with open(image_path, "wb") as buffer:
83
  buffer.write(await file.read())
84
 
 
85
  embeddings = await face_processor.extract_embeddings(image_path)
86
  if not embeddings:
87
  raise HTTPException(status_code=400, detail="Failed to process face")
88
 
 
89
  match_result = face_matcher.new_face_matching(embeddings)
90
  if match_result['status'] == 'Success':
91
  user = get_user_by_id(match_result['user_id'], db)
92
  if not user:
93
  raise HTTPException(status_code=404, detail="User not found")
94
 
95
+ access_token_expires = timedelta(minutes=int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "30")))
96
+ payload = {"id": user.id, "sub": user.email}
97
  token = get_token(payload, db)
98
  return JSONResponse(content=token.dict(), status_code=status.HTTP_200_OK)
99
 
100
+ raise HTTPException(status_code=401, detail="Face not recognized")
users/schemas.py CHANGED
@@ -1,25 +1,34 @@
1
- from pydantic import BaseModel, EmailStr
2
  from typing import Optional, List
 
3
 
4
  class UserBase(BaseModel):
5
- username: str
6
- first_name: str
7
- last_name: str
8
  email: EmailStr
9
- age: Optional[int] = None
10
  preferences: Optional[List[str]] = None
11
- is_active: Optional[bool] = True
12
-
13
 
14
  class UserCreate(UserBase):
15
- password: str
16
-
17
-
18
- class UserUpdate(UserBase):
19
- password: Optional[str] = None
 
 
 
 
 
20
 
21
  class User(UserBase):
22
  id: int
 
 
 
 
23
 
24
  class Config:
25
  orm_mode = True
@@ -30,11 +39,9 @@ class UserEmbeddingsBase(BaseModel):
30
  class UserEmbeddingsCreate(UserEmbeddingsBase):
31
  pass
32
 
33
-
34
  class UserEmbeddings(UserEmbeddingsBase):
35
  id: int
36
  user_id: int
37
 
38
  class Config:
39
- orm_mode = True
40
-
 
1
+ from pydantic import BaseModel, EmailStr, Field
2
  from typing import Optional, List
3
+ from datetime import datetime
4
 
5
  class UserBase(BaseModel):
6
+ username: str = Field(..., min_length=3, max_length=50)
7
+ first_name: str = Field(..., min_length=1, max_length=50)
8
+ last_name: str = Field(..., min_length=1, max_length=50)
9
  email: EmailStr
10
+ age: Optional[int] = Field(None, ge=0, le=120)
11
  preferences: Optional[List[str]] = None
12
+ is_active: bool = True
 
13
 
14
  class UserCreate(UserBase):
15
+ password: str = Field(..., min_length=8)
16
+
17
+ class UserUpdate(BaseModel):
18
+ username: Optional[str] = Field(None, min_length=3, max_length=50)
19
+ first_name: Optional[str] = Field(None, min_length=1, max_length=50)
20
+ last_name: Optional[str] = Field(None, min_length=1, max_length=50)
21
+ email: Optional[EmailStr] = None
22
+ age: Optional[int] = Field(None, ge=0, le=120)
23
+ preferences: Optional[List[str]] = None
24
+ password: Optional[str] = Field(None, min_length=8)
25
 
26
  class User(UserBase):
27
  id: int
28
+ is_active: bool
29
+ is_admin: bool
30
+ created_at: datetime
31
+ updated_at: datetime
32
 
33
  class Config:
34
  orm_mode = True
 
39
  class UserEmbeddingsCreate(UserEmbeddingsBase):
40
  pass
41
 
 
42
  class UserEmbeddings(UserEmbeddingsBase):
43
  id: int
44
  user_id: int
45
 
46
  class Config:
47
+ orm_mode = True
 
users/services.py CHANGED
@@ -1,26 +1,23 @@
1
  from users.models import User, UserEmbeddings
2
- from fastapi.exceptions import HTTPException
3
  from core.security import get_password_hash
4
  from datetime import datetime
5
  from sqlalchemy.orm import Session
6
- from users.schemas import UserCreate, UserUpdate, UserEmbeddings
7
 
8
-
9
-
10
- async def create_user_account(data:UserCreate, db:Session):
11
  user = db.query(User).filter(User.email == data.email).first()
12
  if user:
13
- raise HTTPException(status_code=422, detail="Email already registered")
14
 
15
  new_user = User(
16
  email=data.email,
 
17
  first_name=data.first_name,
18
  last_name=data.last_name,
19
  age=data.age,
20
  preferences=data.preferences,
21
- password = get_password_hash(data.password),
22
- registered_at=datetime.now(),
23
- updated_at=datetime.now(),
24
  )
25
 
26
  db.add(new_user)
@@ -28,58 +25,63 @@ async def create_user_account(data:UserCreate, db:Session):
28
  db.refresh(new_user)
29
  return new_user
30
 
31
- async def create_user_embeddings(user_id:int, embeddings:UserEmbeddings, db:Session):
32
  user = db.query(User).filter(User.id == user_id).first()
33
  if not user:
34
- embeddings = UserEmbeddings(user_id=user_id, embeddings=embeddings.embeddings)
35
- db.add(embeddings)
36
- db.commit()
37
- db.refresh(embeddings)
38
- return embeddings
39
-
40
- def get_user_by_id(user_id, db:Session):
41
- return db.query(User).filter(User.id == user_id).first()
42
-
43
 
44
- def get_user_by_email(email, db:Session):
45
- return db.query(User).filter(User.email == email).first()
 
 
 
46
 
 
 
 
 
 
47
 
48
- def get_users(db:Session, skip=0, limit=100):
49
  return db.query(User).offset(skip).limit(limit).all()
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
- def update_user(db:Session, user_id:int, user:UserUpdate):
53
  user = db.query(User).filter(User.id == user_id).first()
54
- if user:
55
- update_data = user.dict(exclude_unset=True)
56
- if 'password' in update_data:
57
- update_data['password'] = get_password_hash(update_data['password'])
58
- del update_data['password']
59
-
60
- for key, value in update_data.items():
61
- setattr(user, key, value)
62
- db.commit()
63
- db.refresh(user)
64
  return user
65
 
66
-
67
- def delete_user(user_id, db:Session):
68
- user = db.query(User).filter(User.id == user_id).first()
69
- if user:
70
- db.delete(user)
71
- db.commit()
72
- return user
73
- return None
74
-
75
-
76
- async def update_user_embeddings(user_id:int, embeddings:UserEmbeddings, db:Session):
77
- embeddings = db.query(UserEmbeddings).filter(UserEmbeddings.user_id == user_id).first()
78
- if embeddings:
79
- embeddings.embeddings = embeddings.embeddings
80
- db.commit()
81
- db.refresh(embeddings)
82
- return embeddings
83
- return None
84
-
85
-
 
1
  from users.models import User, UserEmbeddings
2
+ from fastapi import HTTPException
3
  from core.security import get_password_hash
4
  from datetime import datetime
5
  from sqlalchemy.orm import Session
6
+ from users.schemas import UserCreate, UserUpdate, UserEmbeddingsBase
7
 
8
+ async def create_user_account(data: UserCreate, db: Session):
 
 
9
  user = db.query(User).filter(User.email == data.email).first()
10
  if user:
11
+ raise HTTPException(status_code=400, detail="Email already registered")
12
 
13
  new_user = User(
14
  email=data.email,
15
+ username=data.username,
16
  first_name=data.first_name,
17
  last_name=data.last_name,
18
  age=data.age,
19
  preferences=data.preferences,
20
+ hashed_password=get_password_hash(data.password),
 
 
21
  )
22
 
23
  db.add(new_user)
 
25
  db.refresh(new_user)
26
  return new_user
27
 
28
+ async def create_user_embeddings(user_id: int, embeddings: UserEmbeddingsBase, db: Session):
29
  user = db.query(User).filter(User.id == user_id).first()
30
  if not user:
31
+ raise HTTPException(status_code=404, detail="User not found")
32
+
33
+ db_embeddings = UserEmbeddings(user_id=user_id, embeddings=embeddings.embeddings)
34
+ db.add(db_embeddings)
35
+ db.commit()
36
+ db.refresh(db_embeddings)
37
+ return db_embeddings
 
 
38
 
39
+ def get_user_by_id(user_id: int, db: Session):
40
+ user = db.query(User).filter(User.id == user_id).first()
41
+ if not user:
42
+ raise HTTPException(status_code=404, detail="User not found")
43
+ return user
44
 
45
+ def get_user_by_email(email: str, db: Session):
46
+ user = db.query(User).filter(User.email == email).first()
47
+ if not user:
48
+ raise HTTPException(status_code=404, detail="User not found")
49
+ return user
50
 
51
+ def get_users(db: Session, skip: int = 0, limit: int = 100):
52
  return db.query(User).offset(skip).limit(limit).all()
53
 
54
+ def update_user(db: Session, user_id: int, user: UserUpdate):
55
+ db_user = db.query(User).filter(User.id == user_id).first()
56
+ if not db_user:
57
+ raise HTTPException(status_code=404, detail="User not found")
58
+
59
+ update_data = user.dict(exclude_unset=True)
60
+ if 'password' in update_data:
61
+ update_data['hashed_password'] = get_password_hash(update_data['password'])
62
+ del update_data['password']
63
+
64
+ for key, value in update_data.items():
65
+ setattr(db_user, key, value)
66
+
67
+ db.commit()
68
+ db.refresh(db_user)
69
+ return db_user
70
 
71
+ def delete_user(user_id: int, db: Session):
72
  user = db.query(User).filter(User.id == user_id).first()
73
+ if not user:
74
+ raise HTTPException(status_code=404, detail="User not found")
75
+ db.delete(user)
76
+ db.commit()
 
 
 
 
 
 
77
  return user
78
 
79
+ async def update_user_embeddings(user_id: int, embeddings: UserEmbeddingsBase, db: Session):
80
+ db_embeddings = db.query(UserEmbeddings).filter(UserEmbeddings.user_id == user_id).first()
81
+ if not db_embeddings:
82
+ raise HTTPException(status_code=404, detail="User embeddings not found")
83
+
84
+ db_embeddings.embeddings = embeddings.embeddings
85
+ db.commit()
86
+ db.refresh(db_embeddings)
87
+ return db_embeddings