# Page header captured with this file (apparently a Hugging Face Space; status shown: "Sleeping").
import gradio as gr | |
import torch | |
import numpy as np | |
from scipy.spatial.distance import cosine | |
import cv2 | |
import os | |
# Cosine-distance cutoff: if the closest registered embedding is farther than
# this from the probe image's embedding, recognize_user reports
# "User not recognized."
RECOGNITION_THRESHOLD = 0.3

# Load the embedding model.
# The checkpoint is a whole pickled model object (it is called and .eval()'d
# directly below, not fed to load_state_dict), so it must be loaded with
# weights_only=False (parameter available since torch 1.13; newer torch
# defaults it to True and would refuse this file). That unpickles arbitrary
# Python objects, so only ever point this at a trusted, locally shipped file.
# map_location="cpu" keeps the model on the CPU so it matches the CPU tensors
# built during preprocessing and the .numpy() conversion of its output, even if
# the checkpoint was saved from a GPU.
# NOTE(review): presumably a ResNet-style embedding network; the architecture
# and training setup are not visible in this file.
embedding_model = torch.load('full_mode2.pth', map_location='cpu', weights_only=False)
embedding_model.eval()  # Inference mode (affects layers such as dropout / batch norm)

# In-memory registry: user ID -> embedding vector. Nothing is persisted, so all
# registrations are lost when the process restarts, and the dict is shared by
# every visitor of the app.
user_embeddings = {}
# Preprocess the image
def preprocess_image(image, *, size=(375, 375)):
    """Convert an image array into a batched float tensor for the embedding model.

    Args:
        image: HxWxC or HxW array, presumably uint8 in [0, 255] as produced by
            ``gr.Image()`` (RGB by default). Single-channel input is replicated
            to three channels and a fourth (alpha) channel is dropped.
        size: ``(width, height)`` passed to ``cv2.resize``; the default keeps
            the original 375x375 model input.

    Returns:
        A ``(1, 3, height, width)`` float32 tensor (for 1-, 3- or 4-channel
        input) with pixel values divided by 255.

    Raises:
        ValueError: if no image was supplied (e.g. the button was clicked
            before a picture was uploaded).
    """
    if image is None:
        raise ValueError("No image provided.")
    image = cv2.resize(image, size)
    # cv2.resize returns a 2-D array for single-channel input, so fix up the
    # channel layout only after resizing.
    if image.ndim == 2:
        image = np.stack([image] * 3, axis=-1)
    elif image.shape[2] == 4:
        image = image[..., :3]
    image = image / 255.0  # Scale pixel values to [0, 1]
    image = np.transpose(image, (2, 0, 1))  # HWC -> CHW
    return torch.tensor(image, dtype=torch.float32).unsqueeze(0)  # Add batch dimension
# Generate embedding
def generate_embedding(image):
    """Run the embedding model on a single image and return a 1-D embedding.

    Args:
        image: image array accepted by ``preprocess_image``.

    Returns:
        A flat ``numpy.ndarray`` containing the model output for the single
        batch item.
    """
    batch = preprocess_image(image)
    with torch.no_grad():  # Inference only: skip autograd bookkeeping
        output = embedding_model(batch)
        # Take the only batch item and flatten it. scipy's cosine() accepts
        # only 1-D vectors, so an output shaped e.g. (1, D, 1, 1) would
        # otherwise break recognition. For a (1, D) output this equals
        # output.numpy()[0].
        return output[0].numpy().reshape(-1)
# Register new user
def register_user(image, user_id):
    """Compute the embedding for *image* and store it under *user_id*.

    Surrounding whitespace is stripped from the ID, and a blank (or missing)
    ID is rejected before any model work is done. Registering an ID that
    already exists replaces its stored embedding.

    Returns:
        A status message for the UI. Failures are reported as text rather
        than raised, so the Gradio output box always shows a result.
    """
    try:
        user_id = (user_id or "").strip()
        if not user_id:
            return "Error during registration: User ID must not be empty."
        user_embeddings[user_id] = generate_embedding(image)
        return f"User {user_id} registered successfully."
    except Exception as e:  # UI boundary: surface any failure as a message
        return f"Error during registration: {e}"
# Recognize user
def recognize_user(image):
    """Find the registered user whose embedding is closest to *image*'s.

    Every stored embedding is compared with the probe by cosine distance. The
    nearest one is accepted only if its distance is at most
    RECOGNITION_THRESHOLD. Per-user and minimum distances are printed for
    debugging.

    Returns:
        A status message for the UI. Errors are reported as text, not raised.
    """
    try:
        probe = generate_embedding(image)
        best_id, best_dist = "Unknown", float('inf')
        for candidate_id, stored in user_embeddings.items():
            dist = cosine(probe, stored)
            print(f"Distance for {candidate_id}: {dist}")  # Debug: per-user distance
            # Strict "<": ties keep the earlier user, and NaN distances never win.
            if dist < best_dist:
                best_id, best_dist = candidate_id, dist
        print(f"Min distance: {best_dist}")  # Debug: closest distance found
        # best_dist is never NaN here, so "<=" is the exact negation of ">".
        if best_dist <= RECOGNITION_THRESHOLD:
            return f"Recognized User: {best_id}"
        return "User not recognized."
    except Exception as e:  # UI boundary: surface any failure as a message
        return f"Error during recognition: {e}"
def main():
    """Build the two-tab Gradio UI (Register / Recognize) and launch it."""
    with gr.Blocks() as demo:
        gr.Markdown("Facial Recognition System")

        # Tab 1: upload a face plus an ID and store its embedding.
        with gr.Tab("Register"):
            with gr.Row():
                face_in = gr.Image()
                id_box = gr.Textbox(label="User ID")
            submit_btn = gr.Button("Register")
            status_box = gr.Textbox()
            submit_btn.click(
                register_user,
                inputs=[face_in, id_box],
                outputs=status_box,
            )

        # Tab 2: upload a face and look up the closest registered user.
        with gr.Tab("Recognize"):
            with gr.Row():
                probe_in = gr.Image()
            identify_btn = gr.Button("Recognize")
            result_box = gr.Textbox()
            identify_btn.click(
                recognize_user,
                inputs=[probe_in],
                outputs=result_box,
            )

    # NOTE(review): share=True requests a public gradio.live link when run
    # locally, which exposes this face database to anyone with the URL. Confirm
    # this is intended.
    demo.launch(share=True)
if __name__ == "__main__":
    # Start the UI only when executed as a script, not when imported.
    main()