from fastapi import FastAPI, File, UploadFile, HTTPException import cv2 import numpy as np from PIL import Image import io import base64 from transformers import ViTFeatureExtractor, ViTForImageClassification import torch app = FastAPI() # Cargar el modelo de clasificación de edad y el extractor model = ViTForImageClassification.from_pretrained('nateraw/vit-age-classifier') transforms = ViTFeatureExtractor.from_pretrained('nateraw/vit-age-classifier') @app.post("/detect/") async def detect_face(file: UploadFile = File(...)): try: # Leer y procesar la imagen cargada image_bytes = await file.read() image = Image.open(io.BytesIO(image_bytes)) img_np = np.array(image) if img_np.shape[2] == 4: img_np = cv2.cvtColor(img_np, cv2.COLOR_BGRA2BGR) # Cargar el clasificador Haar para detección de rostros face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml') gray = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY) faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30)) if len(faces) == 0: raise HTTPException(status_code=404, detail="No se detectaron rostros en la imagen.") # Procesar cada rostro detectado results = [] for (x, y, w, h) in faces: # Extraer el rostro de la imagen face_img = img_np[y:y+h, x:x+w] pil_face_img = Image.fromarray(cv2.cvtColor(face_img, cv2.COLOR_BGR2RGB)) # Realizar la predicción de edad inputs = transforms(pil_face_img, return_tensors='pt') output = model(**inputs) proba = output.logits.softmax(1) preds = proba.argmax(1) # Asumimos que la predicción está representando un rango de edad (esto puede adaptarse más tarde) predicted_age_range = str(preds.item()) # Dibujar un rectángulo alrededor del rostro y añadir la edad predicha cv2.rectangle(img_np, (x, y), (x+w, y+h), (255, 0, 0), 2) cv2.putText(img_np, f"Edad: {predicted_age_range}", (x, y-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (255, 0, 0), 2) results.append({ "edad_predicha": predicted_age_range, "coordenadas_rostro": (x, y, w, h) }) # Convertir la imagen procesada a base64 result_image = Image.fromarray(cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB)) img_byte_arr = io.BytesIO() result_image.save(img_byte_arr, format='JPEG') img_byte_arr = img_byte_arr.getvalue() return { "message": "Rostros detectados y edad predicha", "rostros": len(faces), "resultados": results, "imagen_base64": base64.b64encode(img_byte_arr).decode('utf-8') } except Exception as e: raise HTTPException(status_code=500, detail=str(e))