# SnapFeast / services/face_match.py
# Last commit by Testys: "Adding migrations from alembic" (d6866b9)
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from sqlalchemy.orm import Session
from users.models import UserEmbeddings
from core.config import get_settings
import logging
# Module-level singletons shared by every FaceMatch instance.
settings = get_settings()  # application configuration; FaceMatch reads FACE_RECOGNITION_THRESHOLD from it
logger = logging.getLogger(__name__)
class FaceMatch:
    """Identify a user by comparing a face embedding with stored embeddings.

    Every stored embedding is scored against the query with cosine
    similarity. A match is reported only when the best score exceeds
    ``settings.FACE_RECOGNITION_THRESHOLD`` *and* beats the runner-up by
    more than ``min_margin``; weak or ambiguous results count as no match.
    """

    def __init__(self, db: Session, *, min_margin: float = 0.1):
        """Create a matcher bound to a database session.

        Args:
            db: SQLAlchemy session used to read ``UserEmbeddings`` rows.
            min_margin: Amount by which the best similarity must exceed the
                second-best one for the match to be accepted. Defaults to
                0.1, the value this class previously hard-coded.
        """
        self.db = db
        self.threshold = settings.FACE_RECOGNITION_THRESHOLD
        self.min_margin = min_margin
        self.max_matches = 1  # Only allow one match (not consulted by the methods below)
        # Reference shape taken from the first usable stored embedding; set by
        # load_embeddings_from_db and kept for the lifetime of this instance.
        self.embedding_shape = None

    def load_embeddings_from_db(self) -> dict:
        """Load every stored embedding whose shape matches the reference shape.

        The first usable row fixes ``self.embedding_shape`` (if not already
        set); later rows with a different shape are skipped with a warning.
        Rows whose embedding is missing, scalar, empty, or cannot be turned
        into a regular array are skipped too, so a single corrupt row can
        neither become the reference shape nor abort loading for everyone.

        Returns:
            Mapping of ``user_id`` to its embedding as a NumPy array. If a
            user_id appears in several rows, the last usable one wins.
        """
        embeddings_dict = {}
        for ue in self.db.query(UserEmbeddings).all():
            try:
                embedding = np.array(ue.embeddings)
            except ValueError:
                # Ragged nested sequences cannot form a regular array.
                logger.warning("Malformed embedding for user %s; skipping", ue.user_id)
                continue
            # np.array(None) is a 0-d array, so this also catches NULL columns.
            if embedding.ndim == 0 or embedding.size == 0:
                logger.warning("Missing or empty embedding for user %s; skipping", ue.user_id)
                continue
            if self.embedding_shape is None:
                self.embedding_shape = embedding.shape
            elif embedding.shape != self.embedding_shape:
                logger.warning(
                    "Inconsistent embedding shape for user %s. Expected %s, got %s",
                    ue.user_id, self.embedding_shape, embedding.shape,
                )
                continue  # Skip this embedding
            embeddings_dict[ue.user_id] = embedding
        return embeddings_dict

    def validate_embedding(self, embedding) -> bool:
        """Return True if *embedding* has the reference ``embedding_shape``.

        Returns False (with a warning) when no reference shape has been
        established yet, when *embedding* cannot be converted to a regular
        array, or when its shape differs from the reference.
        """
        if self.embedding_shape is None:
            logger.warning("No reference embedding shape available")
            return False
        try:
            shape = np.shape(embedding)
        except ValueError:
            logger.warning("Invalid embedding: cannot be converted to a regular array")
            return False
        if shape != self.embedding_shape:
            logger.warning("Invalid embedding shape. Expected %s, got %s", self.embedding_shape, shape)
            return False
        return True

    def match_faces(self, new_embeddings, saved_embeddings):
        """Find the single stored embedding that clearly matches the query.

        Args:
            new_embeddings: Query embedding; must have ``self.embedding_shape``.
            saved_embeddings: Mapping of user_id to stored embedding (arrays
                or array-likes with the same number of elements as the query).

        Returns:
            ``(user_id, similarity)`` for an accepted match, otherwise
            ``(None, 0)``.
        """
        if not self.validate_embedding(new_embeddings) or not saved_embeddings:
            return None, 0
        user_ids = list(saved_embeddings)
        query = np.asarray(new_embeddings).reshape(1, -1)
        # Score every candidate in a single call rather than one call per user.
        gallery = np.stack([np.asarray(saved_embeddings[uid]).reshape(-1) for uid in user_ids])
        scores = cosine_similarity(query, gallery)[0]
        return self._pick_unambiguous(user_ids, scores)

    def _pick_unambiguous(self, user_ids, scores):
        """Return ``(user_id, score)`` of a clear winner in *scores*, else ``(None, 0)``.

        The winner must exceed ``self.threshold`` and, when a runner-up
        exists, lead it by more than ``self.min_margin``.
        """
        # Stable descending order: ties keep their original (dict) order.
        order = np.argsort(-scores, kind="stable")
        best_score = float(scores[order[0]])
        clear_lead = len(order) == 1 or best_score - float(scores[order[1]]) > self.min_margin
        if best_score > self.threshold and clear_lead:
            return user_ids[order[0]], best_score
        return None, 0

    def new_face_matching(self, new_embeddings) -> dict:
        """Match *new_embeddings* against all stored embeddings.

        Returns:
            A status dict. On success: ``status``, ``message``, ``user_id``
            and ``similarity`` (a Python float). On failure: ``status`` and
            ``message`` describing why no match was produced.
        """
        embeddings_dict = self.load_embeddings_from_db()
        if not embeddings_dict:
            return {'status': 'Error', 'message': 'No valid embeddings available in the database'}
        if not self.validate_embedding(new_embeddings):
            return {'status': 'Error', 'message': 'Invalid embedding shape'}
        identity, similarity = self.match_faces(new_embeddings, embeddings_dict)
        # Compare against None: a falsy user_id (e.g. 0) is still a real match.
        if identity is not None:
            return {
                'status': 'Success',
                'message': 'Match Found',
                'user_id': identity,
                'similarity': float(similarity)  # Convert numpy float to Python float
            }
        return {
            'status': 'Error',
            'message': 'No matching face found or multiple potential matches detected'
        }