leonett's picture
Update app.py
10c4b17 verified
import gradio as gr
import torch
import numpy as np
from PIL import Image
import cv2
from torchvision import transforms
from facenet_pytorch import MTCNN
from io import BytesIO
import logging
# Configuración
logging.basicConfig(level=logging.DEBUG)
device = torch.device("cpu")
CANVAS_SIZE = (512, 512) # Tamaño original del lienzo
# Modelo de envejecimiento
class AgingModel(torch.nn.Module):
def __init__(self):
super(AgingModel, self).__init__()
self.conv1 = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1)
self.conv2 = torch.nn.Conv2d(16, 3, kernel_size=3, padding=1)
self.relu = torch.nn.ReLU()
torch.nn.init.kaiming_normal_(self.conv1.weight)
torch.nn.init.kaiming_normal_(self.conv2.weight)
def forward(self, x, age_factor):
x = self.conv1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.relu(x)
x = x * (1 + age_factor / 100.0)
return torch.clamp(x, 0, 1)
# Inicializar modelo y MTCNN
model = AgingModel().to(device)
model.eval()
mtcnn = MTCNN(image_size=128, margin=0, min_face_size=10, device=device) # Más sensible
def resize_to_canvas(image, canvas_size=CANVAS_SIZE):
"""Redimensionar imagen para ajustarse al lienzo manteniendo proporciones."""
img = image.convert("RGB")
img.thumbnail(canvas_size, Image.LANCZOS)
canvas = Image.new("RGB", canvas_size, (255, 255, 255)) # Fondo blanco
offset = ((canvas_size[0] - img.size[0]) // 2, (canvas_size[1] - img.size[1]) // 2)
canvas.paste(img, offset)
return canvas
def align_and_preprocess(image):
"""Alinear y preprocesar la imagen para el modelo."""
logging.debug("Starting face detection")
if isinstance(image, np.ndarray):
img = Image.fromarray(image).convert("RGB")
elif isinstance(image, Image.Image):
img = image.convert("RGB")
else:
return None, f"Tipo de imagen no válido: {type(image)}"
detected_face, _ = mtcnn.detect(img)
logging.debug(f"Caras detectadas: {detected_face}")
if detected_face is None:
return None, "No se detectó una cara. Por favor, sube una imagen con un rostro claro."
box = detected_face[0].astype(int)
face_img = img.crop((box[0], box[1], box[2], box[3]))
transform = transforms.Compose([
transforms.Resize((128, 128)),
transforms.ToTensor()
])
return transform(face_img).unsqueeze(0).to(device), None
def generate_aged_image(input_image, target_age):
"""Generar imagen envejecida."""
logging.debug("Iniciando generación de imagen")
if input_image is None:
return None, "Por favor, sube una imagen válida."
input_image = resize_to_canvas(input_image)
input_tensor, error = align_and_preprocess(input_image)
if error:
return None, error
with torch.no_grad():
output = model(input_tensor, target_age)
logging.debug(f"Tensor de salida min: {output.min()}, max: {output.max()}")
output_image = output.squeeze().permute(1, 2, 0).cpu().numpy()
output_image = np.clip(output_image * 255, 0, 255).astype(np.uint8)
output_image = cv2.cvtColor(output_image, cv2.COLOR_RGB2BGR)
gray = cv2.cvtColor(output_image, cv2.COLOR_BGR2GRAY)
edges = cv2.Canny(gray, 50, 150) # Umbrales ajustados para mejor detección
edges = cv2.dilate(edges, None, iterations=1)
output_image[edges > 0] = output_image[edges > 0] * 0.6 # Arrugas más marcadas
output_image = cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB)
output_image = output_image.astype(np.float32)
output_image = np.clip(output_image * (1 - target_age / 300.0), 0, 255).astype(np.uint8)
result = Image.fromarray(output_image)
result = resize_to_canvas(result) # Ajustar al lienzo final
buffer = BytesIO()
result.save(buffer, format="JPEG", quality=95)
jpeg_image = Image.open(buffer)
logging.debug("Generación de imagen completada")
return jpeg_image, None
# Interfaz de Gradio
def create_interface():
with gr.Blocks(css=".container {max-width: 1200px; margin: auto; padding: 20px; background-color: #f5f5f5;} .canvas {width: 512px !important; height: 512px !important; object-fit: contain;}") as interface:
gr.Markdown(
"""
# Sistema de Envejecimiento Facial para Personas Desaparecidas
Sube una imagen y selecciona la edad objetivo para generar una versión envejecida.
Este sistema está diseñado para apoyar a unidades policiales en la búsqueda de personas desaparecidas.
"""
)
with gr.Row():
with gr.Column(scale=1):
input_image = gr.Image(label="Imagen de entrada", type="pil", image_mode="RGB", height=512, width=512, elem_classes=["container", "canvas"])
age_slider = gr.Slider(0, 100, value=30, step=1, label="Edad objetivo")
submit_button = gr.Button("Generar imagen envejecida")
with gr.Column(scale=1):
output_image = gr.Image(label="Resultado envejecido", type="pil", image_mode="RGB", format="jpeg", height=512, width=512, elem_classes=["container", "canvas"])
error_message = gr.Textbox(label="Mensajes", interactive=False)
submit_button.click(
fn=generate_aged_image,
inputs=[input_image, age_slider],
outputs=[output_image, error_message]
)
return interface
if __name__ == "__main__":
interface = create_interface()
interface.launch()