leonett's picture
Create app.py
a6cd9b7 verified
raw
history blame contribute delete
2.68 kB
import gradio as gr
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from facenet_pytorch import MTCNN
from model import CAAE # Modelo del repositorio Face-Aging-CAAE
# Model configuration (CPU only).
device = torch.device("cpu")
model_path = "model/CAAE_MORPH.pth"
model = CAAE(latent_dim=128).to(device)
# weights_only=True: the checkpoint is a file downloaded from the internet, and a
# plain torch.load would unpickle arbitrary Python objects from it. Loading a
# state_dict only requires tensors/containers, so restrict deserialization to those.
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
# Put layers such as dropout/batch-norm (if the model has any) in inference mode.
model.eval()
# Face detector used by the preprocessing step.
mtcnn = MTCNN(image_size=128, margin=0, min_face_size=20, device=device)
# Preprocessing applied to the cropped face: 128x128, float tensor, values in [-1, 1].
_FACE_TRANSFORM = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])


def align_and_preprocess(image):
    """Detect the face in *image*, crop it and convert it into a model input.

    Args:
        image: Input picture, either a PIL image (what ``gr.Image(type="pil")``
            delivers) or an ``H x W x C`` NumPy array.

    Returns:
        A float tensor of shape ``(1, 3, 128, 128)`` on ``device`` whose values
        are normalized to ``[-1, 1]``.

    Raises:
        ValueError: If no face is detected in the image.
    """
    # Accept both PIL images and NumPy arrays.
    if isinstance(image, Image.Image):
        img = image.convert("RGB")
    else:
        img = Image.fromarray(np.asarray(image)).convert("RGB")

    # Use MTCNN only for detection. Calling mtcnn(img) returns an already
    # post-processed torch.Tensor, which transforms.ToTensor() rejects; cropping
    # the PIL image here keeps the torchvision pipeline above valid.
    boxes, _ = mtcnn.detect(img)
    if boxes is None or len(boxes) == 0:
        raise ValueError("No se detectó una cara en la imagen.")

    # Use the first box. NOTE(review): facenet-pytorch orders boxes largest-first
    # when select_largest=True (its default) — confirm if the MTCNN config changes.
    left, top, right, bottom = boxes[0]
    # Boxes are float coordinates that may fall outside the image; clamp them.
    left, top = max(0, int(left)), max(0, int(top))
    right, bottom = min(img.width, int(right)), min(img.height, int(bottom))
    if right <= left or bottom <= top:
        raise ValueError("No se detectó una cara en la imagen.")

    face = img.crop((left, top, right, bottom))
    return _FACE_TRANSFORM(face).unsqueeze(0).to(device)
def generate_aged_image(input_image, target_age):
    """Generate an aged version of the face found in *input_image*.

    Args:
        input_image: Picture from the Gradio input (a PIL image; NumPy arrays
            are accepted as well).
        target_age: Desired age from the 0-100 slider; scaled to ``[0, 1]``
            before being passed to the model.

    Returns:
        A PIL image of the generated face, resized to the input's width/height.

    Raises:
        gr.Error: On missing input or any failure during processing, so Gradio
            shows the message in the UI. Returning a string into the
            ``gr.Image`` output, as before, cannot be rendered.
    """
    if input_image is None:
        raise gr.Error("Error: no se proporcionó ninguna imagen.")
    try:
        # PIL exposes (width, height); NumPy arrays are (height, width, channels).
        if isinstance(input_image, np.ndarray):
            height, width = input_image.shape[:2]
        else:
            width, height = input_image.size

        input_tensor = align_and_preprocess(input_image)
        age_tensor = torch.tensor([[target_age / 100.0]], dtype=torch.float32, device=device)
        with torch.no_grad():
            output = model(input_tensor, age_tensor)

        # Drop the batch dimension only (squeeze(0) is a no-op if it is absent),
        # then CHW -> HWC for PIL.
        output_image = output.squeeze(0).permute(1, 2, 0).cpu().numpy()
        # Undo the Normalize(mean=0.5, std=0.5) applied to the input: [-1, 1] -> [0, 1].
        # NOTE(review): assumes the decoder outputs the same [-1, 1] range it was
        # fed (e.g. a tanh head) — confirm against the CAAE model definition.
        output_image = np.clip((output_image * 0.5 + 0.5) * 255.0, 0, 255).astype(np.uint8)
        return Image.fromarray(output_image).resize((width, height))
    except Exception as e:
        # Top-level UI boundary: surface every failure as a visible Gradio error.
        raise gr.Error(f"Error: {e}") from e
# Gradio interface
def app():
    """Build the Gradio interface for the face-aging demo.

    Returns:
        A ``gr.Interface`` that wires an input image and a target-age slider
        (0-100) to ``generate_aged_image``; the caller must call ``.launch()``.
    """
    interface = gr.Interface(
        fn=generate_aged_image,
        inputs=[
            gr.Image(label="Imagen de entrada", type="pil"),
            gr.Slider(0, 100, value=30, step=1, label="Edad objetivo"),
        ],
        outputs=gr.Image(label="Resultado", type="pil"),
        # NOTE(review): these example paths must exist relative to the working directory.
        examples=[
            ["example_images/input.jpg", 40],
            ["example_images/input2.jpg", 60],
        ],
        title="Envejecimiento Facial",
        description="Carga una imagen, elige una edad y genera su versión envejecida.",
    )
    return interface
if __name__ == "__main__":
    # Setup before running this script:
    #   1. Install the dependencies:
    #        pip install gradio torch torchvision facenet-pytorch numpy
    #   2. Fetch the Face-Aging-CAAE code and weights:
    #        git clone https://github.com/ZZUTK/Face-Aging-CAAE.git
    #        wget -P Face-Aging-CAAE/model/ https://raw.githubusercontent.com/ZZUTK/Face-Aging-CAAE/master/model/CAAE_MORPH.pth
    # NOTE(review): the model is loaded from "model/CAAE_MORPH.pth" relative to the
    # current working directory, but step 2 saves it under Face-Aging-CAAE/model/.
    # Run from that directory or move the file. Also confirm the .pth URL actually
    # exists upstream.
    demo = app()
    demo.launch()