# Hugging Face Spaces status-banner residue ("Spaces: Runtime error") captured
# during page extraction — not part of the application code.
# Third-party imports: PyTorch + torchvision for the model and preprocessing,
# Gradio for the web UI, Pillow for image handling.
import gradio as gr
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
# Path to the fine-tuned checkpoint; the filename suggests an EfficientNet-B2
# saved on 27-12-2024 (selected by F1 score) — TODO confirm with training logs.
checkpoint_path = 'models/efiB2_27_12_24_f1.pt'

# The 24 UI-component class labels, in the order the classifier's output
# indices were trained with (do not reorder).
CLASSES = ['audio_recorder', 'card_grid_md', 'card_grid_sm', 'card_grid_xl',
           'conversational', 'crypto', 'date_range', 'image_filter', 'list_md',
           'list_profile', 'list_sm', 'list_xl', 'map', 'music', 'nav_drawer',
           'notification', 'rate', 'reel', 'setting', 'sign', 'splashscreen',
           'video_fullscreen', 'walktrough', 'weather']
def load_model(checkpoint_path: str) -> nn.Module:
    """Build an EfficientNet-B2 classifier and load fine-tuned weights.

    Args:
        checkpoint_path: Path to a ``torch.save`` checkpoint containing a
            ``'model_state_dict'`` entry.

    Returns:
        The model on the CPU, in eval mode, ready for inference.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
        KeyError: If the checkpoint lacks ``'model_state_dict'``.
    """
    # Build the architecture only (weights=None): the strict load_state_dict
    # below replaces every parameter anyway, so downloading the pretrained
    # ImageNet weights would be wasted work and a needless network dependency.
    model = models.efficientnet_b2(weights=None)

    # Replace the final classifier layer so it emits one logit per class.
    num_ftrs = model.classifier[-1].in_features
    model.classifier[-1] = nn.Linear(num_ftrs, len(CLASSES))

    # map_location='cpu' lets a GPU-trained checkpoint load on a CPU-only host.
    # NOTE(review): torch.load unpickles arbitrary objects — if the checkpoint
    # could be untrusted, pass weights_only=True (torch >= 2.0).
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])

    model.to(torch.device('cpu'))
    model.eval()  # inference mode: disables dropout, freezes batch-norm stats
    return model
# Load the model once at import time so every prediction reuses the same
# in-memory instance instead of re-reading the checkpoint per request.
model = load_model(checkpoint_path)
# Preprocessing pipeline, built once at import time rather than per request:
# EfficientNet-B2 input resolution (260x260) and ImageNet normalization stats.
# (The original comment said 300x300, but the code has always resized to 260.)
_TRANSFORM = transforms.Compose([
    transforms.Resize((260, 260)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def predict_image(image):
    """Classify a UI screenshot and return its top-3 classes with confidences.

    Args:
        image: Image as a numpy array (H, W[, C]), as delivered by Gradio's
            "image" input component.

    Returns:
        Dict mapping each of the 3 most probable class names to its softmax
        probability, rounded to two decimals (the shape ``gr.Label`` expects).
    """
    # Force 3 RGB channels so the 3-channel Normalize above also works for
    # grayscale or RGBA uploads instead of raising a channel-mismatch error.
    pil_image = Image.fromarray(image).convert('RGB')

    # Preprocess and add the batch dimension: (1, 3, 260, 260), on the CPU.
    batch = _TRANSFORM(pil_image).unsqueeze(0).to(torch.device('cpu'))

    # Forward pass without gradient tracking; the model is already in eval
    # mode from load_model, so no per-call model.eval() is needed.
    with torch.no_grad():
        logits = model(batch)

    # Softmax over the class dimension, then keep the 3 highest probabilities.
    probabilities, indices = torch.topk(torch.softmax(logits, dim=1), k=3)
    probabilities = probabilities.tolist()[0]
    indices = indices.tolist()[0]

    # Map class indices back to names; round probabilities for display.
    return {
        CLASSES[idx]: round(prob, 2)
        for idx, prob in zip(indices, probabilities)
    }
# Gradio web UI: a single image input mapped to a top-3 label output.
iface = gr.Interface(
    fn=predict_image,
    inputs="image",
    outputs=gr.Label(num_top_classes=3),  # render the 3 most probable classes
    # Mojibake repaired: the original strings were UTF-8 text mis-decoded
    # (e.g. "evaluaci贸n" for "evaluación").
    title="POLIDATA | Modelo de evaluación de interfaz de usuario",
    description="José Luis Santorcuato Tapia.",
)

# share=True additionally publishes a temporary public URL for the demo.
iface.launch(share=True)