Push new edits
Browse files- core/config.py +19 -8
- core/database.py +1 -0
- core/security.py +27 -43
- main.py +13 -8
- services/face_match.py +9 -6
- services/facial_processing.py +24 -18
- users/models.py +7 -7
- users/routes.py +22 -46
- users/schemas.py +22 -15
- users/services.py +55 -53
core/config.py
CHANGED
@@ -2,31 +2,42 @@ import os
|
|
2 |
from pathlib import Path
|
3 |
from dotenv import load_dotenv
|
4 |
from urllib.parse import quote_plus
|
5 |
-
from pydantic_settings import BaseSettings
|
6 |
|
7 |
load_dotenv()
|
8 |
|
9 |
class Settings(BaseSettings):
|
10 |
-
#
|
11 |
DATABASE_USER: str = os.getenv("PG_USER")
|
12 |
DATABASE_PASSWORD: str = os.getenv("PG_PASSWORD")
|
13 |
DATABASE_HOST: str = os.getenv("PG_HOST")
|
14 |
DATABASE_PORT: str = os.getenv("PG_PORT")
|
15 |
DATABASE_NAME: str = os.getenv("PG_NAME")
|
16 |
|
17 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
@property
|
19 |
def DATABASE_URL(self) -> str:
|
20 |
user = quote_plus(self.DATABASE_USER)
|
21 |
password = quote_plus(self.DATABASE_PASSWORD)
|
22 |
return f"postgresql://{user}:{password}@{self.DATABASE_HOST}:{self.DATABASE_PORT}/{self.DATABASE_NAME}?sslmode=require"
|
23 |
|
24 |
-
|
25 |
-
|
26 |
-
ACCESS_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", default="30")) # Default value as fallback
|
27 |
|
28 |
-
def get_settings():
|
29 |
return Settings()
|
30 |
|
31 |
# Usage example
|
32 |
-
settings = get_settings()
|
|
|
2 |
from pathlib import Path
|
3 |
from dotenv import load_dotenv
|
4 |
from urllib.parse import quote_plus
|
5 |
+
from pydantic_settings import BaseSettings
|
6 |
|
7 |
load_dotenv()
|
8 |
|
9 |
class Settings(BaseSettings):
|
10 |
+
# Database settings
|
11 |
DATABASE_USER: str = os.getenv("PG_USER")
|
12 |
DATABASE_PASSWORD: str = os.getenv("PG_PASSWORD")
|
13 |
DATABASE_HOST: str = os.getenv("PG_HOST")
|
14 |
DATABASE_PORT: str = os.getenv("PG_PORT")
|
15 |
DATABASE_NAME: str = os.getenv("PG_NAME")
|
16 |
|
17 |
+
# JWT settings
|
18 |
+
JWT_SECRET_KEY: str = os.getenv("JWT_SECRET")
|
19 |
+
JWT_ALGORITHM: str = "HS256"
|
20 |
+
ACCESS_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "30"))
|
21 |
+
REFRESH_TOKEN_EXPIRE_DAYS: int = int(os.getenv("REFRESH_TOKEN_EXPIRE_DAYS", "7"))
|
22 |
+
|
23 |
+
# Application settings
|
24 |
+
DEBUG: bool = os.getenv("DEBUG", "False").lower() == "true"
|
25 |
+
ALLOWED_HOSTS: list = os.getenv("ALLOWED_HOSTS", "*").split(",")
|
26 |
+
|
27 |
+
# Face recognition settings
|
28 |
+
FACE_RECOGNITION_THRESHOLD: float = float(os.getenv("FACE_RECOGNITION_THRESHOLD", "0.6"))
|
29 |
+
|
30 |
@property
|
31 |
def DATABASE_URL(self) -> str:
|
32 |
user = quote_plus(self.DATABASE_USER)
|
33 |
password = quote_plus(self.DATABASE_PASSWORD)
|
34 |
return f"postgresql://{user}:{password}@{self.DATABASE_HOST}:{self.DATABASE_PORT}/{self.DATABASE_NAME}?sslmode=require"
|
35 |
|
36 |
+
class Config:
|
37 |
+
env_file = ".env"
|
|
|
38 |
|
39 |
+
def get_settings() -> Settings:
|
40 |
return Settings()
|
41 |
|
42 |
# Usage example
|
43 |
+
settings = get_settings()
|
core/database.py
CHANGED
@@ -14,6 +14,7 @@ engine = create_engine(
|
|
14 |
max_overflow=0
|
15 |
|
16 |
)
|
|
|
17 |
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
18 |
|
19 |
Base = declarative_base()
|
|
|
14 |
max_overflow=0
|
15 |
|
16 |
)
|
17 |
+
|
18 |
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
19 |
|
20 |
Base = declarative_base()
|
core/security.py
CHANGED
@@ -1,75 +1,59 @@
|
|
1 |
from passlib.context import CryptContext
|
2 |
from fastapi.security import OAuth2PasswordBearer
|
3 |
-
from fastapi import Depends, HTTPException
|
4 |
from datetime import timedelta, datetime
|
5 |
from jose import JWTError, jwt
|
6 |
from core.config import get_settings
|
7 |
from sqlalchemy.orm import Session
|
8 |
from core.database import get_db
|
9 |
|
10 |
-
|
11 |
settings = get_settings()
|
12 |
|
13 |
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
14 |
-
|
15 |
|
16 |
-
def get_password_hash(password):
|
17 |
return pwd_context.hash(password)
|
18 |
|
19 |
-
def verify_password(plain_password, hashed_password):
|
20 |
return pwd_context.verify(plain_password, hashed_password)
|
21 |
|
22 |
-
async def create_access_token(data:dict, expiry:timedelta):
|
23 |
payload = data.copy()
|
24 |
expire = datetime.utcnow() + expiry
|
25 |
payload.update({"exp": expire})
|
26 |
-
|
27 |
-
settings.JWT_SECRET_KEY,
|
28 |
-
algorithm=settings.JWT_ALGORITHM)
|
29 |
-
|
30 |
-
return token
|
31 |
|
32 |
-
async def create_refresh_token(data:dict):
|
33 |
payload = data.copy()
|
34 |
-
|
35 |
-
settings.JWT_SECRET_KEY,
|
36 |
-
algorithm=settings.JWT_ALGORITHM)
|
37 |
-
return token
|
38 |
|
39 |
-
def get_token_payload(token:str):
|
40 |
try:
|
41 |
-
|
42 |
-
settings.JWT_SECRET_KEY,
|
43 |
-
algorithms=[settings.JWT_ALGORITHM])
|
44 |
-
return payload
|
45 |
except JWTError:
|
46 |
return None
|
47 |
-
|
48 |
-
async def get_current_user(token: str = Depends(
|
49 |
-
from users.services import get_user_by_email # Local import
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
|
51 |
try:
|
52 |
payload = get_token_payload(token)
|
53 |
-
|
|
|
|
|
54 |
if email is None:
|
55 |
-
raise
|
56 |
-
detail="Invalid Token",
|
57 |
-
headers={"WWW-Authenticate": "Bearer"})
|
58 |
except JWTError:
|
59 |
-
raise
|
60 |
-
|
61 |
-
headers={"WWW-Authenticate": "Bearer"}
|
62 |
-
)
|
63 |
-
|
64 |
user = get_user_by_email(email, db=db)
|
65 |
if user is None:
|
66 |
-
raise
|
67 |
-
|
68 |
-
headers={"WWW-Authenticate": "Bearer"}
|
69 |
-
)
|
70 |
-
return user
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
|
|
1 |
from passlib.context import CryptContext
|
2 |
from fastapi.security import OAuth2PasswordBearer
|
3 |
+
from fastapi import Depends, HTTPException, status
|
4 |
from datetime import timedelta, datetime
|
5 |
from jose import JWTError, jwt
|
6 |
from core.config import get_settings
|
7 |
from sqlalchemy.orm import Session
|
8 |
from core.database import get_db
|
9 |
|
|
|
10 |
settings = get_settings()
|
11 |
|
12 |
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
13 |
+
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/token/")
|
14 |
|
15 |
+
def get_password_hash(password: str) -> str:
|
16 |
return pwd_context.hash(password)
|
17 |
|
18 |
+
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
19 |
return pwd_context.verify(plain_password, hashed_password)
|
20 |
|
21 |
+
async def create_access_token(data: dict, expiry: timedelta) -> str:
|
22 |
payload = data.copy()
|
23 |
expire = datetime.utcnow() + expiry
|
24 |
payload.update({"exp": expire})
|
25 |
+
return jwt.encode(payload, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)
|
|
|
|
|
|
|
|
|
26 |
|
27 |
+
async def create_refresh_token(data: dict) -> str:
|
28 |
payload = data.copy()
|
29 |
+
return jwt.encode(payload, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)
|
|
|
|
|
|
|
30 |
|
31 |
+
def get_token_payload(token: str) -> dict:
|
32 |
try:
|
33 |
+
return jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
|
|
|
|
|
|
|
34 |
except JWTError:
|
35 |
return None
|
36 |
+
|
37 |
+
async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)):
|
38 |
+
from users.services import get_user_by_email # Local import to avoid circular dependency
|
39 |
+
|
40 |
+
credentials_exception = HTTPException(
|
41 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
42 |
+
detail="Could not validate credentials",
|
43 |
+
headers={"WWW-Authenticate": "Bearer"},
|
44 |
+
)
|
45 |
|
46 |
try:
|
47 |
payload = get_token_payload(token)
|
48 |
+
if payload is None:
|
49 |
+
raise credentials_exception
|
50 |
+
email: str = payload.get("sub")
|
51 |
if email is None:
|
52 |
+
raise credentials_exception
|
|
|
|
|
53 |
except JWTError:
|
54 |
+
raise credentials_exception
|
55 |
+
|
|
|
|
|
|
|
56 |
user = get_user_by_email(email, db=db)
|
57 |
if user is None:
|
58 |
+
raise credentials_exception
|
59 |
+
return user
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
main.py
CHANGED
@@ -1,6 +1,5 @@
|
|
1 |
-
# main.py
|
2 |
from fastapi import FastAPI
|
3 |
-
from
|
4 |
from users.routes import router as users_router
|
5 |
from auth.route import router as auth_router
|
6 |
from orders.routes import order_router, meal_router
|
@@ -13,24 +12,30 @@ app = FastAPI(
|
|
13 |
redoc_url="/redoc",
|
14 |
openapi_url="/openapi.json",
|
15 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
app.include_router(users_router)
|
17 |
app.include_router(auth_router)
|
18 |
app.include_router(meal_router)
|
19 |
app.include_router(order_router)
|
20 |
|
21 |
-
|
22 |
@app.get("/", tags=["Home"])
|
23 |
def read_root():
|
24 |
return {"message": "Welcome to SnapFeast API!"}
|
25 |
|
26 |
-
|
27 |
@app.get("/health", tags=["Health"])
|
28 |
def health_check():
|
29 |
return {"status": "ok"}
|
30 |
|
31 |
-
|
32 |
if __name__ == "__main__":
|
33 |
import uvicorn
|
34 |
-
|
35 |
-
uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")
|
36 |
-
|
|
|
|
|
1 |
from fastapi import FastAPI
|
2 |
+
from fastapi.middleware.cors import CORSMiddleware
|
3 |
from users.routes import router as users_router
|
4 |
from auth.route import router as auth_router
|
5 |
from orders.routes import order_router, meal_router
|
|
|
12 |
redoc_url="/redoc",
|
13 |
openapi_url="/openapi.json",
|
14 |
)
|
15 |
+
|
16 |
+
# Configure CORS
|
17 |
+
app.add_middleware(
|
18 |
+
CORSMiddleware,
|
19 |
+
allow_origins=["*"], # Allows all origins
|
20 |
+
allow_credentials=True,
|
21 |
+
allow_methods=["*"], # Allows all methods
|
22 |
+
allow_headers=["*"], # Allows all headers
|
23 |
+
)
|
24 |
+
|
25 |
+
# Include routers
|
26 |
app.include_router(users_router)
|
27 |
app.include_router(auth_router)
|
28 |
app.include_router(meal_router)
|
29 |
app.include_router(order_router)
|
30 |
|
|
|
31 |
@app.get("/", tags=["Home"])
|
32 |
def read_root():
|
33 |
return {"message": "Welcome to SnapFeast API!"}
|
34 |
|
|
|
35 |
@app.get("/health", tags=["Health"])
|
36 |
def health_check():
|
37 |
return {"status": "ok"}
|
38 |
|
|
|
39 |
if __name__ == "__main__":
|
40 |
import uvicorn
|
41 |
+
uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
|
|
|
|
services/face_match.py
CHANGED
@@ -2,16 +2,20 @@ from sklearn.metrics.pairwise import cosine_similarity
|
|
2 |
import numpy as np
|
3 |
from sqlalchemy.orm import Session
|
4 |
from users.models import UserEmbeddings
|
|
|
|
|
|
|
5 |
|
6 |
class FaceMatch:
|
7 |
def __init__(self, db: Session):
|
8 |
self.db = db
|
|
|
9 |
|
10 |
def load_embeddings_from_db(self):
|
11 |
user_embeddings = self.db.query(UserEmbeddings).all()
|
12 |
return {ue.user_id: np.array(ue.embeddings) for ue in user_embeddings}
|
13 |
|
14 |
-
def match_faces(self, new_embeddings, saved_embeddings
|
15 |
new_embeddings = np.array(new_embeddings)
|
16 |
max_similarity = 0
|
17 |
identity = None
|
@@ -22,12 +26,12 @@ class FaceMatch:
|
|
22 |
max_similarity = similarity
|
23 |
identity = user_id
|
24 |
|
25 |
-
return identity, max_similarity if max_similarity > threshold else (None, 0)
|
26 |
|
27 |
def new_face_matching(self, new_embeddings):
|
28 |
embeddings_dict = self.load_embeddings_from_db()
|
29 |
if not embeddings_dict:
|
30 |
-
return {'status': 'Error', 'message': 'No embeddings available'}
|
31 |
|
32 |
identity, similarity = self.match_faces(new_embeddings, embeddings_dict)
|
33 |
if identity:
|
@@ -35,10 +39,9 @@ class FaceMatch:
|
|
35 |
'status': 'Success',
|
36 |
'message': 'Match Found',
|
37 |
'user_id': identity,
|
38 |
-
'similarity': similarity
|
39 |
}
|
40 |
return {
|
41 |
'status': 'Error',
|
42 |
'message': 'No matching face found'
|
43 |
-
}
|
44 |
-
|
|
|
2 |
import numpy as np
|
3 |
from sqlalchemy.orm import Session
|
4 |
from users.models import UserEmbeddings
|
5 |
+
from core.config import get_settings
|
6 |
+
|
7 |
+
settings = get_settings()
|
8 |
|
9 |
class FaceMatch:
|
10 |
def __init__(self, db: Session):
|
11 |
self.db = db
|
12 |
+
self.threshold = settings.FACE_RECOGNITION_THRESHOLD
|
13 |
|
14 |
def load_embeddings_from_db(self):
|
15 |
user_embeddings = self.db.query(UserEmbeddings).all()
|
16 |
return {ue.user_id: np.array(ue.embeddings) for ue in user_embeddings}
|
17 |
|
18 |
+
def match_faces(self, new_embeddings, saved_embeddings):
|
19 |
new_embeddings = np.array(new_embeddings)
|
20 |
max_similarity = 0
|
21 |
identity = None
|
|
|
26 |
max_similarity = similarity
|
27 |
identity = user_id
|
28 |
|
29 |
+
return identity, max_similarity if max_similarity > self.threshold else (None, 0)
|
30 |
|
31 |
def new_face_matching(self, new_embeddings):
|
32 |
embeddings_dict = self.load_embeddings_from_db()
|
33 |
if not embeddings_dict:
|
34 |
+
return {'status': 'Error', 'message': 'No embeddings available in the database'}
|
35 |
|
36 |
identity, similarity = self.match_faces(new_embeddings, embeddings_dict)
|
37 |
if identity:
|
|
|
39 |
'status': 'Success',
|
40 |
'message': 'Match Found',
|
41 |
'user_id': identity,
|
42 |
+
'similarity': float(similarity) # Convert numpy float to Python float
|
43 |
}
|
44 |
return {
|
45 |
'status': 'Error',
|
46 |
'message': 'No matching face found'
|
47 |
+
}
|
|
services/facial_processing.py
CHANGED
@@ -3,35 +3,41 @@ import os
|
|
3 |
import torch
|
4 |
from facenet_pytorch import MTCNN, InceptionResnetV1
|
5 |
import logging
|
|
|
6 |
|
7 |
logger = logging.getLogger(__name__)
|
8 |
|
9 |
-
|
10 |
class FacialProcessing:
|
11 |
def __init__(self):
|
12 |
-
# Set the cache directory to a writable location
|
13 |
os.environ['TORCH_HOME'] = '/tmp/.cache/torch'
|
14 |
-
|
15 |
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
|
|
16 |
|
17 |
-
|
18 |
-
self.resnet = InceptionResnetV1(pretrained='vggface2').eval().to(device)
|
19 |
-
|
20 |
-
|
21 |
-
def extract_embeddings_vgg(self, image):
|
22 |
try:
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
return None
|
29 |
|
30 |
-
# Extract
|
31 |
-
embeddings = self.resnet(
|
32 |
-
|
33 |
-
return embeddings[0]
|
34 |
|
35 |
except Exception as e:
|
36 |
logger.error(f"An error occurred while extracting embeddings: {e}")
|
37 |
-
return None
|
|
|
3 |
import torch
|
4 |
from facenet_pytorch import MTCNN, InceptionResnetV1
|
5 |
import logging
|
6 |
+
from PIL import Image
|
7 |
|
8 |
logger = logging.getLogger(__name__)
|
9 |
|
|
|
10 |
class FacialProcessing:
|
11 |
def __init__(self):
|
|
|
12 |
os.environ['TORCH_HOME'] = '/tmp/.cache/torch'
|
|
|
13 |
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
14 |
+
self.mtcnn = MTCNN(keep_all=True, device=self.device)
|
15 |
+
self.resnet = InceptionResnetV1(pretrained='vggface2').eval().to(self.device)
|
16 |
|
17 |
+
def extract_embeddings_vgg(self, image_path):
|
|
|
|
|
|
|
|
|
18 |
try:
|
19 |
+
img = Image.open(image_path)
|
20 |
+
img = img.convert('RGB')
|
21 |
+
|
22 |
+
# Detect faces
|
23 |
+
boxes, _ = self.mtcnn.detect(img)
|
24 |
+
|
25 |
+
if boxes is None:
|
26 |
+
logger.warning(f"No face detected in image: {image_path}")
|
27 |
+
return None
|
28 |
+
|
29 |
+
# Get the largest face
|
30 |
+
largest_box = max(boxes, key=lambda box: (box[2] - box[0]) * (box[3] - box[1]))
|
31 |
+
face = self.mtcnn(img, return_prob=False)
|
32 |
+
|
33 |
+
if face is None:
|
34 |
+
logger.warning(f"Failed to align face in image: {image_path}")
|
35 |
return None
|
36 |
|
37 |
+
# Extract embeddings
|
38 |
+
embeddings = self.resnet(face).detach().cpu().numpy().flatten()
|
39 |
+
return embeddings.tolist()
|
|
|
40 |
|
41 |
except Exception as e:
|
42 |
logger.error(f"An error occurred while extracting embeddings: {e}")
|
43 |
+
return None
|
users/models.py
CHANGED
@@ -7,16 +7,16 @@ class User(Base):
|
|
7 |
__tablename__ = "users"
|
8 |
|
9 |
id = Column(Integer, primary_key=True, index=True)
|
10 |
-
email = Column(String, unique=True, index=True)
|
11 |
-
username = Column(String, unique=True, index=True)
|
12 |
-
hashed_password = Column(String)
|
13 |
first_name = Column(String)
|
14 |
last_name = Column(String)
|
15 |
age = Column(Integer)
|
16 |
preferences = Column(ARRAY(String))
|
17 |
is_active = Column(Boolean, default=True)
|
18 |
is_admin = Column(Boolean, default=False)
|
19 |
-
updated_at = Column(DateTime, default=datetime.utcnow)
|
20 |
created_at = Column(DateTime, default=datetime.utcnow)
|
21 |
|
22 |
orders = relationship("Order", back_populates="user")
|
@@ -26,7 +26,7 @@ class UserEmbeddings(Base):
|
|
26 |
__tablename__ = "user_embeddings"
|
27 |
|
28 |
id = Column(Integer, primary_key=True, index=True)
|
29 |
-
user_id = Column(Integer, ForeignKey("users.id"))
|
30 |
-
embeddings = Column(ARRAY(Float))
|
31 |
|
32 |
-
user = relationship("User", back_populates="embeddings")
|
|
|
7 |
__tablename__ = "users"
|
8 |
|
9 |
id = Column(Integer, primary_key=True, index=True)
|
10 |
+
email = Column(String, unique=True, index=True, nullable=False)
|
11 |
+
username = Column(String, unique=True, index=True, nullable=False)
|
12 |
+
hashed_password = Column(String, nullable=False)
|
13 |
first_name = Column(String)
|
14 |
last_name = Column(String)
|
15 |
age = Column(Integer)
|
16 |
preferences = Column(ARRAY(String))
|
17 |
is_active = Column(Boolean, default=True)
|
18 |
is_admin = Column(Boolean, default=False)
|
19 |
+
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
20 |
created_at = Column(DateTime, default=datetime.utcnow)
|
21 |
|
22 |
orders = relationship("Order", back_populates="user")
|
|
|
26 |
__tablename__ = "user_embeddings"
|
27 |
|
28 |
id = Column(Integer, primary_key=True, index=True)
|
29 |
+
user_id = Column(Integer, ForeignKey("users.id"), unique=True, nullable=False)
|
30 |
+
embeddings = Column(ARRAY(Float), nullable=False)
|
31 |
|
32 |
+
user = relationship("User", back_populates="embeddings")
|
users/routes.py
CHANGED
@@ -3,7 +3,7 @@ from fastapi.responses import JSONResponse
|
|
3 |
from sqlalchemy.orm import Session
|
4 |
from core.database import get_db
|
5 |
from core.security import get_current_user, create_access_token
|
6 |
-
from users.schemas import UserCreate, UserBase, UserEmbeddingsBase, User
|
7 |
from users.services import create_user_account, create_user_embeddings, update_user, update_user_embeddings, get_user_by_id, get_user_by_email
|
8 |
from services.facial_processing import FacialProcessing
|
9 |
from services.face_match import FaceMatch
|
@@ -12,7 +12,6 @@ from datetime import timedelta
|
|
12 |
from dotenv import load_dotenv
|
13 |
from auth.services import get_token
|
14 |
|
15 |
-
|
16 |
load_dotenv()
|
17 |
|
18 |
router = APIRouter(
|
@@ -21,104 +20,81 @@ router = APIRouter(
|
|
21 |
responses={404: {"description": "Not found"}},
|
22 |
)
|
23 |
|
24 |
-
@router.get("/health", tags=["Health"])
|
25 |
-
def health_check():
|
26 |
-
return {"status": "ok"}
|
27 |
-
|
28 |
@router.post("/", status_code=status.HTTP_201_CREATED, response_model=UserBase)
|
29 |
async def create_user(data: UserCreate, db: Session = Depends(get_db)):
|
30 |
-
await create_user_account(data, db)
|
31 |
-
|
32 |
-
return JSONResponse(content=payload, status_code=status.HTTP_201_CREATED)
|
33 |
-
|
34 |
-
|
35 |
-
@router.post("/me",status_code=status.HTTP_200_OK)
|
36 |
-
def get_user_details(request: Request):
|
37 |
-
user = request.user
|
38 |
-
return user
|
39 |
|
40 |
@router.get("/me", response_model=UserBase)
|
41 |
async def read_users_me(current_user: User = Depends(get_current_user)):
|
42 |
return current_user
|
43 |
|
44 |
@router.put("/me", response_model=UserBase)
|
45 |
-
async def
|
46 |
-
|
47 |
-
return
|
48 |
-
|
49 |
|
50 |
-
@router.post("/me/face",status_code=status.HTTP_200_OK)
|
51 |
-
async def create_face_embeddings(file: UploadFile = File(...), user:User = Depends(get_current_user), db: Session = Depends(get_db)):
|
52 |
face_processor = FacialProcessing()
|
53 |
|
54 |
-
# Process the uploaded image
|
55 |
image_path = f"faces/{user.id}.jpg"
|
56 |
with open(image_path, "wb") as buffer:
|
57 |
buffer.write(await file.read())
|
58 |
|
59 |
-
# Extract embeddings
|
60 |
embeddings = face_processor.extract_embeddings_vgg(image_path)
|
61 |
if embeddings:
|
62 |
-
|
|
|
63 |
return {"message": "Face embeddings created successfully"}
|
64 |
|
65 |
-
raise HTTPException(status_code=400,
|
66 |
-
detail="Failed to process face"
|
67 |
-
)
|
68 |
|
69 |
-
@router.get("/me/face",status_code=status.HTTP_200_OK)
|
70 |
-
async def get_face_embeddings(user:User = Depends(get_current_user), db: Session = Depends(get_db)):
|
71 |
face = db.query(UserEmbeddingsBase).filter(UserEmbeddingsBase.user_id == user.id).first()
|
72 |
if not face:
|
73 |
-
raise HTTPException(status_code=404,
|
74 |
-
detail="Face embeddings not found"
|
75 |
-
)
|
76 |
return JSONResponse(content={"embeddings": face.embeddings}, status_code=status.HTTP_200_OK)
|
77 |
|
78 |
-
@router.put("/me/face",status_code=status.HTTP_200_OK)
|
79 |
-
async def
|
80 |
face_processor = FacialProcessing()
|
81 |
|
82 |
-
# Process the uploaded image
|
83 |
image_path = f"faces/{user.id}.jpg"
|
84 |
with open(image_path, "wb") as buffer:
|
85 |
buffer.write(await file.read())
|
86 |
|
87 |
-
# Extract embeddings
|
88 |
embeddings = face_processor.extract_embeddings_vgg(image_path)
|
89 |
if embeddings:
|
90 |
-
|
|
|
91 |
return {"message": "Face embeddings updated successfully"}
|
92 |
|
93 |
-
raise HTTPException(status_code=400,
|
94 |
-
detail="Failed to process face"
|
95 |
-
)
|
96 |
|
97 |
@router.post("/login/face")
|
98 |
async def face_login(file: UploadFile = File(...), db: Session = Depends(get_db)):
|
99 |
face_processor = FacialProcessing()
|
100 |
face_matcher = FaceMatch(db)
|
101 |
|
102 |
-
# Process the uploaded image
|
103 |
image_path = f"temp_{file.filename}"
|
104 |
with open(image_path, "wb") as buffer:
|
105 |
buffer.write(await file.read())
|
106 |
|
107 |
-
# Extract embeddings
|
108 |
embeddings = await face_processor.extract_embeddings(image_path)
|
109 |
if not embeddings:
|
110 |
raise HTTPException(status_code=400, detail="Failed to process face")
|
111 |
|
112 |
-
# Match face
|
113 |
match_result = face_matcher.new_face_matching(embeddings)
|
114 |
if match_result['status'] == 'Success':
|
115 |
user = get_user_by_id(match_result['user_id'], db)
|
116 |
if not user:
|
117 |
raise HTTPException(status_code=404, detail="User not found")
|
118 |
|
119 |
-
access_token_expires = timedelta(minutes=os.getenv("
|
120 |
-
payload = {"id":user.id, "sub": user.email}
|
121 |
token = get_token(payload, db)
|
122 |
return JSONResponse(content=token.dict(), status_code=status.HTTP_200_OK)
|
123 |
|
124 |
-
raise HTTPException(status_code=401, detail="Face not recognized")
|
|
|
3 |
from sqlalchemy.orm import Session
|
4 |
from core.database import get_db
|
5 |
from core.security import get_current_user, create_access_token
|
6 |
+
from users.schemas import UserCreate, UserBase, UserEmbeddingsBase, User, UserUpdate
|
7 |
from users.services import create_user_account, create_user_embeddings, update_user, update_user_embeddings, get_user_by_id, get_user_by_email
|
8 |
from services.facial_processing import FacialProcessing
|
9 |
from services.face_match import FaceMatch
|
|
|
12 |
from dotenv import load_dotenv
|
13 |
from auth.services import get_token
|
14 |
|
|
|
15 |
load_dotenv()
|
16 |
|
17 |
router = APIRouter(
|
|
|
20 |
responses={404: {"description": "Not found"}},
|
21 |
)
|
22 |
|
|
|
|
|
|
|
|
|
23 |
@router.post("/", status_code=status.HTTP_201_CREATED, response_model=UserBase)
|
24 |
async def create_user(data: UserCreate, db: Session = Depends(get_db)):
|
25 |
+
new_user = await create_user_account(data, db)
|
26 |
+
return new_user
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
|
28 |
@router.get("/me", response_model=UserBase)
|
29 |
async def read_users_me(current_user: User = Depends(get_current_user)):
|
30 |
return current_user
|
31 |
|
32 |
@router.put("/me", response_model=UserBase)
|
33 |
+
async def update_user_me(user: UserUpdate, current_user: User = Depends(get_current_user), db: Session = Depends(get_db)):
|
34 |
+
updated_user = update_user(db, current_user.id, user)
|
35 |
+
return updated_user
|
|
|
36 |
|
37 |
+
@router.post("/me/face", status_code=status.HTTP_200_OK)
|
38 |
+
async def create_face_embeddings(file: UploadFile = File(...), user: User = Depends(get_current_user), db: Session = Depends(get_db)):
|
39 |
face_processor = FacialProcessing()
|
40 |
|
|
|
41 |
image_path = f"faces/{user.id}.jpg"
|
42 |
with open(image_path, "wb") as buffer:
|
43 |
buffer.write(await file.read())
|
44 |
|
|
|
45 |
embeddings = face_processor.extract_embeddings_vgg(image_path)
|
46 |
if embeddings:
|
47 |
+
user_embeddings = UserEmbeddingsBase(embeddings=embeddings)
|
48 |
+
await create_user_embeddings(user.id, user_embeddings, db)
|
49 |
return {"message": "Face embeddings created successfully"}
|
50 |
|
51 |
+
raise HTTPException(status_code=400, detail="Failed to process face")
|
|
|
|
|
52 |
|
53 |
+
@router.get("/me/face", status_code=status.HTTP_200_OK)
|
54 |
+
async def get_face_embeddings(user: User = Depends(get_current_user), db: Session = Depends(get_db)):
|
55 |
face = db.query(UserEmbeddingsBase).filter(UserEmbeddingsBase.user_id == user.id).first()
|
56 |
if not face:
|
57 |
+
raise HTTPException(status_code=404, detail="Face embeddings not found")
|
|
|
|
|
58 |
return JSONResponse(content={"embeddings": face.embeddings}, status_code=status.HTTP_200_OK)
|
59 |
|
60 |
+
@router.put("/me/face", status_code=status.HTTP_200_OK)
|
61 |
+
async def update_face_embeddings(file: UploadFile = File(...), user: User = Depends(get_current_user), db: Session = Depends(get_db)):
|
62 |
face_processor = FacialProcessing()
|
63 |
|
|
|
64 |
image_path = f"faces/{user.id}.jpg"
|
65 |
with open(image_path, "wb") as buffer:
|
66 |
buffer.write(await file.read())
|
67 |
|
|
|
68 |
embeddings = face_processor.extract_embeddings_vgg(image_path)
|
69 |
if embeddings:
|
70 |
+
user_embeddings = UserEmbeddingsBase(embeddings=embeddings)
|
71 |
+
await update_user_embeddings(user.id, user_embeddings, db)
|
72 |
return {"message": "Face embeddings updated successfully"}
|
73 |
|
74 |
+
raise HTTPException(status_code=400, detail="Failed to process face")
|
|
|
|
|
75 |
|
76 |
@router.post("/login/face")
|
77 |
async def face_login(file: UploadFile = File(...), db: Session = Depends(get_db)):
|
78 |
face_processor = FacialProcessing()
|
79 |
face_matcher = FaceMatch(db)
|
80 |
|
|
|
81 |
image_path = f"temp_{file.filename}"
|
82 |
with open(image_path, "wb") as buffer:
|
83 |
buffer.write(await file.read())
|
84 |
|
|
|
85 |
embeddings = await face_processor.extract_embeddings(image_path)
|
86 |
if not embeddings:
|
87 |
raise HTTPException(status_code=400, detail="Failed to process face")
|
88 |
|
|
|
89 |
match_result = face_matcher.new_face_matching(embeddings)
|
90 |
if match_result['status'] == 'Success':
|
91 |
user = get_user_by_id(match_result['user_id'], db)
|
92 |
if not user:
|
93 |
raise HTTPException(status_code=404, detail="User not found")
|
94 |
|
95 |
+
access_token_expires = timedelta(minutes=int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "30")))
|
96 |
+
payload = {"id": user.id, "sub": user.email}
|
97 |
token = get_token(payload, db)
|
98 |
return JSONResponse(content=token.dict(), status_code=status.HTTP_200_OK)
|
99 |
|
100 |
+
raise HTTPException(status_code=401, detail="Face not recognized")
|
users/schemas.py
CHANGED
@@ -1,25 +1,34 @@
|
|
1 |
-
from pydantic import BaseModel, EmailStr
|
2 |
from typing import Optional, List
|
|
|
3 |
|
4 |
class UserBase(BaseModel):
|
5 |
-
username: str
|
6 |
-
first_name: str
|
7 |
-
last_name: str
|
8 |
email: EmailStr
|
9 |
-
age: Optional[int] = None
|
10 |
preferences: Optional[List[str]] = None
|
11 |
-
is_active:
|
12 |
-
|
13 |
|
14 |
class UserCreate(UserBase):
|
15 |
-
password: str
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
class User(UserBase):
|
22 |
id: int
|
|
|
|
|
|
|
|
|
23 |
|
24 |
class Config:
|
25 |
orm_mode = True
|
@@ -30,11 +39,9 @@ class UserEmbeddingsBase(BaseModel):
|
|
30 |
class UserEmbeddingsCreate(UserEmbeddingsBase):
|
31 |
pass
|
32 |
|
33 |
-
|
34 |
class UserEmbeddings(UserEmbeddingsBase):
|
35 |
id: int
|
36 |
user_id: int
|
37 |
|
38 |
class Config:
|
39 |
-
orm_mode = True
|
40 |
-
|
|
|
1 |
+
from pydantic import BaseModel, EmailStr, Field
|
2 |
from typing import Optional, List
|
3 |
+
from datetime import datetime
|
4 |
|
5 |
class UserBase(BaseModel):
|
6 |
+
username: str = Field(..., min_length=3, max_length=50)
|
7 |
+
first_name: str = Field(..., min_length=1, max_length=50)
|
8 |
+
last_name: str = Field(..., min_length=1, max_length=50)
|
9 |
email: EmailStr
|
10 |
+
age: Optional[int] = Field(None, ge=0, le=120)
|
11 |
preferences: Optional[List[str]] = None
|
12 |
+
is_active: bool = True
|
|
|
13 |
|
14 |
class UserCreate(UserBase):
|
15 |
+
password: str = Field(..., min_length=8)
|
16 |
+
|
17 |
+
class UserUpdate(BaseModel):
|
18 |
+
username: Optional[str] = Field(None, min_length=3, max_length=50)
|
19 |
+
first_name: Optional[str] = Field(None, min_length=1, max_length=50)
|
20 |
+
last_name: Optional[str] = Field(None, min_length=1, max_length=50)
|
21 |
+
email: Optional[EmailStr] = None
|
22 |
+
age: Optional[int] = Field(None, ge=0, le=120)
|
23 |
+
preferences: Optional[List[str]] = None
|
24 |
+
password: Optional[str] = Field(None, min_length=8)
|
25 |
|
26 |
class User(UserBase):
|
27 |
id: int
|
28 |
+
is_active: bool
|
29 |
+
is_admin: bool
|
30 |
+
created_at: datetime
|
31 |
+
updated_at: datetime
|
32 |
|
33 |
class Config:
|
34 |
orm_mode = True
|
|
|
39 |
class UserEmbeddingsCreate(UserEmbeddingsBase):
|
40 |
pass
|
41 |
|
|
|
42 |
class UserEmbeddings(UserEmbeddingsBase):
|
43 |
id: int
|
44 |
user_id: int
|
45 |
|
46 |
class Config:
|
47 |
+
orm_mode = True
|
|
users/services.py
CHANGED
@@ -1,26 +1,23 @@
|
|
1 |
from users.models import User, UserEmbeddings
|
2 |
-
from fastapi
|
3 |
from core.security import get_password_hash
|
4 |
from datetime import datetime
|
5 |
from sqlalchemy.orm import Session
|
6 |
-
from users.schemas import UserCreate, UserUpdate,
|
7 |
|
8 |
-
|
9 |
-
|
10 |
-
async def create_user_account(data:UserCreate, db:Session):
|
11 |
user = db.query(User).filter(User.email == data.email).first()
|
12 |
if user:
|
13 |
-
raise HTTPException(status_code=
|
14 |
|
15 |
new_user = User(
|
16 |
email=data.email,
|
|
|
17 |
first_name=data.first_name,
|
18 |
last_name=data.last_name,
|
19 |
age=data.age,
|
20 |
preferences=data.preferences,
|
21 |
-
|
22 |
-
registered_at=datetime.now(),
|
23 |
-
updated_at=datetime.now(),
|
24 |
)
|
25 |
|
26 |
db.add(new_user)
|
@@ -28,58 +25,63 @@ async def create_user_account(data:UserCreate, db:Session):
|
|
28 |
db.refresh(new_user)
|
29 |
return new_user
|
30 |
|
31 |
-
async def create_user_embeddings(user_id:int, embeddings:
|
32 |
user = db.query(User).filter(User.id == user_id).first()
|
33 |
if not user:
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
return db.query(User).filter(User.id == user_id).first()
|
42 |
-
|
43 |
|
44 |
-
def
|
45 |
-
|
|
|
|
|
|
|
46 |
|
|
|
|
|
|
|
|
|
|
|
47 |
|
48 |
-
def get_users(db:Session, skip=0, limit=100):
|
49 |
return db.query(User).offset(skip).limit(limit).all()
|
50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
|
52 |
-
def
|
53 |
user = db.query(User).filter(User.id == user_id).first()
|
54 |
-
if user:
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
del update_data['password']
|
59 |
-
|
60 |
-
for key, value in update_data.items():
|
61 |
-
setattr(user, key, value)
|
62 |
-
db.commit()
|
63 |
-
db.refresh(user)
|
64 |
return user
|
65 |
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
async def update_user_embeddings(user_id:int, embeddings:UserEmbeddings, db:Session):
|
77 |
-
embeddings = db.query(UserEmbeddings).filter(UserEmbeddings.user_id == user_id).first()
|
78 |
-
if embeddings:
|
79 |
-
embeddings.embeddings = embeddings.embeddings
|
80 |
-
db.commit()
|
81 |
-
db.refresh(embeddings)
|
82 |
-
return embeddings
|
83 |
-
return None
|
84 |
-
|
85 |
-
|
|
|
1 |
from users.models import User, UserEmbeddings
|
2 |
+
from fastapi import HTTPException
|
3 |
from core.security import get_password_hash
|
4 |
from datetime import datetime
|
5 |
from sqlalchemy.orm import Session
|
6 |
+
from users.schemas import UserCreate, UserUpdate, UserEmbeddingsBase
|
7 |
|
8 |
+
async def create_user_account(data: UserCreate, db: Session):
|
|
|
|
|
9 |
user = db.query(User).filter(User.email == data.email).first()
|
10 |
if user:
|
11 |
+
raise HTTPException(status_code=400, detail="Email already registered")
|
12 |
|
13 |
new_user = User(
|
14 |
email=data.email,
|
15 |
+
username=data.username,
|
16 |
first_name=data.first_name,
|
17 |
last_name=data.last_name,
|
18 |
age=data.age,
|
19 |
preferences=data.preferences,
|
20 |
+
hashed_password=get_password_hash(data.password),
|
|
|
|
|
21 |
)
|
22 |
|
23 |
db.add(new_user)
|
|
|
25 |
db.refresh(new_user)
|
26 |
return new_user
|
27 |
|
28 |
+
async def create_user_embeddings(user_id: int, embeddings: UserEmbeddingsBase, db: Session):
|
29 |
user = db.query(User).filter(User.id == user_id).first()
|
30 |
if not user:
|
31 |
+
raise HTTPException(status_code=404, detail="User not found")
|
32 |
+
|
33 |
+
db_embeddings = UserEmbeddings(user_id=user_id, embeddings=embeddings.embeddings)
|
34 |
+
db.add(db_embeddings)
|
35 |
+
db.commit()
|
36 |
+
db.refresh(db_embeddings)
|
37 |
+
return db_embeddings
|
|
|
|
|
38 |
|
39 |
+
def get_user_by_id(user_id: int, db: Session):
|
40 |
+
user = db.query(User).filter(User.id == user_id).first()
|
41 |
+
if not user:
|
42 |
+
raise HTTPException(status_code=404, detail="User not found")
|
43 |
+
return user
|
44 |
|
45 |
+
def get_user_by_email(email: str, db: Session):
|
46 |
+
user = db.query(User).filter(User.email == email).first()
|
47 |
+
if not user:
|
48 |
+
raise HTTPException(status_code=404, detail="User not found")
|
49 |
+
return user
|
50 |
|
51 |
+
def get_users(db: Session, skip: int = 0, limit: int = 100):
|
52 |
return db.query(User).offset(skip).limit(limit).all()
|
53 |
|
54 |
+
def update_user(db: Session, user_id: int, user: UserUpdate):
|
55 |
+
db_user = db.query(User).filter(User.id == user_id).first()
|
56 |
+
if not db_user:
|
57 |
+
raise HTTPException(status_code=404, detail="User not found")
|
58 |
+
|
59 |
+
update_data = user.dict(exclude_unset=True)
|
60 |
+
if 'password' in update_data:
|
61 |
+
update_data['hashed_password'] = get_password_hash(update_data['password'])
|
62 |
+
del update_data['password']
|
63 |
+
|
64 |
+
for key, value in update_data.items():
|
65 |
+
setattr(db_user, key, value)
|
66 |
+
|
67 |
+
db.commit()
|
68 |
+
db.refresh(db_user)
|
69 |
+
return db_user
|
70 |
|
71 |
+
def delete_user(user_id: int, db: Session):
|
72 |
user = db.query(User).filter(User.id == user_id).first()
|
73 |
+
if not user:
|
74 |
+
raise HTTPException(status_code=404, detail="User not found")
|
75 |
+
db.delete(user)
|
76 |
+
db.commit()
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
return user
|
78 |
|
79 |
+
async def update_user_embeddings(user_id: int, embeddings: UserEmbeddingsBase, db: Session):
|
80 |
+
db_embeddings = db.query(UserEmbeddings).filter(UserEmbeddings.user_id == user_id).first()
|
81 |
+
if not db_embeddings:
|
82 |
+
raise HTTPException(status_code=404, detail="User embeddings not found")
|
83 |
+
|
84 |
+
db_embeddings.embeddings = embeddings.embeddings
|
85 |
+
db.commit()
|
86 |
+
db.refresh(db_embeddings)
|
87 |
+
return db_embeddings
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|