|
from sklearn.metrics.pairwise import cosine_similarity |
|
import numpy as np |
|
from sqlalchemy.orm import Session |
|
from users.models import UserEmbeddings |
|
from core.config import get_settings |
|
import logging |
|
|
|
# Module-level logger named after this module, per the standard logging idiom.
logger = logging.getLogger(name=__name__)

# Application settings, fetched once when this module is imported;
# FaceMatch reads FACE_RECOGNITION_THRESHOLD from it.
settings = get_settings()
|
|
|
class FaceMatch:
    """Match a probe face embedding against the embeddings stored in the database.

    Stored embeddings are read from the ``UserEmbeddings`` table. A probe is
    accepted as a match for a user when its cosine similarity to that user's
    stored embedding is greater than ``threshold`` AND exceeds the runner-up's
    similarity by more than ``ambiguity_margin``; otherwise no identity is
    returned.
    """

    def __init__(self, db: Session, ambiguity_margin: float = 0.1):
        """Create a matcher bound to a database session.

        Args:
            db: SQLAlchemy session used to query ``UserEmbeddings``.
            ambiguity_margin: Amount by which the best similarity must exceed
                the second-best one for a match to be accepted. Defaults to
                0.1, the value that was previously hard-coded.
        """
        self.db = db
        # Minimum cosine similarity for a match, taken from application settings.
        self.threshold = settings.FACE_RECOGNITION_THRESHOLD
        self.ambiguity_margin = ambiguity_margin
        # NOTE(review): not read anywhere in this class; kept so external code
        # that may inspect it keeps working.
        self.max_matches = 1
        # Reference shape, fixed by the first usable stored embedding loaded;
        # stored and probe embeddings must match it exactly.
        self.embedding_shape = None

    def load_embeddings_from_db(self):
        """Load every usable stored embedding, keyed by user id.

        Rows whose embedding is missing, empty, scalar or ragged are skipped
        with a warning, so a bad row can never become the reference shape.
        The first usable row sets ``self.embedding_shape`` (if it is not
        already set from an earlier call); rows with a different shape are
        skipped with a warning.

        Returns:
            dict mapping ``user_id`` to a numpy array; empty when no usable
            rows exist. If several rows share a ``user_id``, the last one
            returned by the query wins (NOTE(review): confirm that is intended).
        """
        embeddings_dict = {}
        for ue in self.db.query(UserEmbeddings).all():
            try:
                embedding = np.asarray(ue.embeddings)
            except ValueError:
                # Ragged nested lists cannot form a regular array (NumPy >= 1.24 raises).
                logger.warning("Malformed embedding for user %s; skipping", ue.user_id)
                continue

            # np.asarray(None) and non-sequence values yield 0-d arrays; those
            # (and empty arrays) must not be used as embeddings or as the
            # reference shape.
            if embedding.ndim == 0 or embedding.size == 0:
                logger.warning("Missing or empty embedding for user %s; skipping", ue.user_id)
                continue

            if self.embedding_shape is None:
                self.embedding_shape = embedding.shape
            elif embedding.shape != self.embedding_shape:
                logger.warning(
                    "Inconsistent embedding shape for user %s. Expected %s, got %s",
                    ue.user_id, self.embedding_shape, embedding.shape,
                )
                continue

            embeddings_dict[ue.user_id] = embedding
        return embeddings_dict

    def validate_embedding(self, embedding):
        """Return True if *embedding* has the reference stored-embedding shape.

        Returns False (and logs a warning) when no reference shape is known
        yet, i.e. ``load_embeddings_from_db`` has not found a usable row, or
        when the shapes differ.
        """
        if self.embedding_shape is None:
            logger.warning("No reference embedding shape available")
            return False

        shape = np.asarray(embedding).shape
        if shape != self.embedding_shape:
            logger.warning(
                "Invalid embedding shape. Expected %s, got %s", self.embedding_shape, shape
            )
            return False
        return True

    def match_faces(self, new_embeddings, saved_embeddings):
        """Find the user whose stored embedding best matches *new_embeddings*.

        Args:
            new_embeddings: Probe embedding (array-like); must pass
                ``validate_embedding``.
            saved_embeddings: Mapping of user id to stored embedding, as
                returned by ``load_embeddings_from_db``.

        Returns:
            ``(user_id, similarity)`` for an unambiguous match, otherwise
            ``(None, 0)``. A match needs the best cosine similarity to be
            greater than ``self.threshold`` and to exceed the runner-up by
            more than ``self.ambiguity_margin``.
        """
        if not self.validate_embedding(new_embeddings) or not saved_embeddings:
            return None, 0

        user_ids = list(saved_embeddings)
        # A single cosine_similarity call over an (n_users, n_features) matrix
        # instead of one call per stored user.
        stored = np.vstack(
            [np.asarray(saved_embeddings[uid]).reshape(1, -1) for uid in user_ids]
        )
        probe = np.asarray(new_embeddings).reshape(1, -1)
        similarities = cosine_similarity(probe, stored)[0]

        # argmax returns the first maximum, like the stable sort it replaces.
        best_idx = int(np.argmax(similarities))
        best = similarities[best_idx]
        if best <= self.threshold:
            return None, 0

        if similarities.size > 1:
            # np.partition puts the second-largest value at index -2 in O(n);
            # with a tie for first place it equals `best`, so the match is
            # rejected as ambiguous.
            runner_up = np.partition(similarities, -2)[-2]
            if best - runner_up <= self.ambiguity_margin:
                return None, 0

        return user_ids[best_idx], best

    def new_face_matching(self, new_embeddings):
        """Identify the user matching *new_embeddings* and build a response dict.

        Returns:
            ``{'status': 'Success', 'message': 'Match Found', 'user_id': ...,
            'similarity': float}`` on a match; otherwise
            ``{'status': 'Error', 'message': ...}`` when the database has no
            usable embeddings, the probe shape is wrong, or no unambiguous
            match exists.
        """
        embeddings_dict = self.load_embeddings_from_db()
        if not embeddings_dict:
            return {'status': 'Error', 'message': 'No valid embeddings available in the database'}

        if not self.validate_embedding(new_embeddings):
            return {'status': 'Error', 'message': 'Invalid embedding shape'}

        identity, similarity = self.match_faces(new_embeddings, embeddings_dict)
        # Compare against None: a falsy user id such as 0 is still a valid match.
        if identity is not None:
            return {
                'status': 'Success',
                'message': 'Match Found',
                'user_id': identity,
                'similarity': float(similarity),
            }

        return {
            'status': 'Error',
            'message': 'No matching face found or multiple potential matches detected',
        }