polidata / app.py
josesantorcuato's picture
Se crea y sube repositorio
e39e3b9
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
import gradio as gr
from PIL import Image
checkpoint_path = 'models/efiB2_27_12_24_f1.pt'
# Simulaci贸n de nombres de clases
CLASSES = ['audio_recorder', 'card_grid_md', 'card_grid_sm', 'card_grid_xl', 'conversational', 'crypto', 'date_range', 'image_filter', 'list_md', 'list_profile', 'list_sm', 'list_xl', 'map', 'music', 'nav_drawer', 'notification', 'rate', 'reel', 'setting', 'sign', 'splashscreen', 'video_fullscreen', 'walktrough', 'weather']
def load_model(checkpoint_path: str) -> nn.Module:
# Crear el modelo original
model = models.efficientnet_b2(weights='DEFAULT')
# Modificar el clasificador para tener 24 clases
num_ftrs = model.classifier[-1].in_features
model.classifier[-1] = nn.Linear(num_ftrs, len(CLASSES))
# Cargar los pesos y los checkpoints desde un archivo de checkpoint
checkpoint = torch.load(checkpoint_path, map_location='cpu') # Asegurarse de cargar en la CPU
model.load_state_dict(checkpoint['model_state_dict'])
# Mover el modelo al dispositivo adecuado (que ahora es la CPU, pero no es necesario)
device = torch.device('cpu')
model.to(device)
model.eval()
return model
model = load_model(checkpoint_path)
# Cargar el modelo utilizando la funci贸n
# Funci贸n para hacer una predicci贸n con el modelo cargado
def predict_image(image):
# Redimensionar la imagen a 300x300
image = Image.fromarray(image)
transform = transforms.Compose([
transforms.Resize((260, 260)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
image = transform(image).unsqueeze(0)
# Mover la imagen a la CPU
device = torch.device('cpu')
image = image.to(device)
# Obtener la predicci贸n del modelo
with torch.no_grad():
model.eval()
output = model(image)
# Obtener las probabilidades de las clases y sus 铆ndices
probabilities, indices = torch.topk(torch.softmax(output, dim=1), k=3)
probabilities = probabilities.tolist()[0]
indices = indices.tolist()[0]
# Obtener las clases y las confianzas correspondientes
top_classes = [CLASSES[idx] for idx in indices]
confidences = [round(prob * 1, 2) for prob in probabilities]
# Crear un diccionario que contenga las etiquetas y sus confianzas
label_dict = {cls: conf for cls, conf in zip(top_classes, confidences)}
# Devolver el resultado como un diccionario
return label_dict
# Gradio Interface
iface = gr.Interface(
fn=predict_image,
inputs="image",
outputs=gr.Label(num_top_classes=3), # Mostrar las 3 clases m谩s probables con sus confianzas
title="POLIDATA | Modelo de evaluaci贸n de interfaz de usuario",
description="Jos茅 Luis Santorcuato Tapia.",
)
iface.launch(share=True)