josesantorcuato commited on
Commit
e39e3b9
1 Parent(s): c40261a

Se crea y sube repositorio

Browse files
Files changed (4) hide show
  1. README.md +4 -4
  2. app.py +82 -0
  3. models/efiB2_27_12_24_f1.pt +3 -0
  4. requirements.txt +8 -0
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
- title: Polidata
3
- emoji: 馃敟
4
- colorFrom: red
5
- colorTo: yellow
6
  sdk: gradio
7
  sdk_version: 4.12.0
8
  app_file: app.py
 
1
  ---
2
+ title: Polidata Gradio
3
+ emoji: 馃悹
4
+ colorFrom: indigo
5
+ colorTo: pink
6
  sdk: gradio
7
  sdk_version: 4.12.0
8
  app_file: app.py
app.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ import torchvision.models as models
5
+ import torchvision.transforms as transforms
6
+ import gradio as gr
7
+ from PIL import Image
8
+
9
+
10
+ checkpoint_path = 'models/efiB2_27_12_24_f1.pt'
11
+
12
+ # Simulaci贸n de nombres de clases
13
+ CLASSES = ['audio_recorder', 'card_grid_md', 'card_grid_sm', 'card_grid_xl', 'conversational', 'crypto', 'date_range', 'image_filter', 'list_md', 'list_profile', 'list_sm', 'list_xl', 'map', 'music', 'nav_drawer', 'notification', 'rate', 'reel', 'setting', 'sign', 'splashscreen', 'video_fullscreen', 'walktrough', 'weather']
14
+
15
+ def load_model(checkpoint_path: str) -> nn.Module:
16
+ # Crear el modelo original
17
+ model = models.efficientnet_b2(weights='DEFAULT')
18
+
19
+ # Modificar el clasificador para tener 24 clases
20
+ num_ftrs = model.classifier[-1].in_features
21
+ model.classifier[-1] = nn.Linear(num_ftrs, len(CLASSES))
22
+
23
+ # Cargar los pesos y los checkpoints desde un archivo de checkpoint
24
+ checkpoint = torch.load(checkpoint_path, map_location='cpu') # Asegurarse de cargar en la CPU
25
+ model.load_state_dict(checkpoint['model_state_dict'])
26
+
27
+ # Mover el modelo al dispositivo adecuado (que ahora es la CPU, pero no es necesario)
28
+ device = torch.device('cpu')
29
+ model.to(device)
30
+
31
+ model.eval()
32
+
33
+ return model
34
+ model = load_model(checkpoint_path)
35
+ # Cargar el modelo utilizando la funci贸n
36
+
37
+
38
+ # Funci贸n para hacer una predicci贸n con el modelo cargado
39
+ def predict_image(image):
40
+ # Redimensionar la imagen a 300x300
41
+ image = Image.fromarray(image)
42
+ transform = transforms.Compose([
43
+ transforms.Resize((260, 260)),
44
+ transforms.ToTensor(),
45
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
46
+ ])
47
+ image = transform(image).unsqueeze(0)
48
+
49
+ # Mover la imagen a la CPU
50
+ device = torch.device('cpu')
51
+ image = image.to(device)
52
+
53
+ # Obtener la predicci贸n del modelo
54
+ with torch.no_grad():
55
+ model.eval()
56
+ output = model(image)
57
+
58
+ # Obtener las probabilidades de las clases y sus 铆ndices
59
+ probabilities, indices = torch.topk(torch.softmax(output, dim=1), k=3)
60
+ probabilities = probabilities.tolist()[0]
61
+ indices = indices.tolist()[0]
62
+
63
+ # Obtener las clases y las confianzas correspondientes
64
+ top_classes = [CLASSES[idx] for idx in indices]
65
+ confidences = [round(prob * 1, 2) for prob in probabilities]
66
+
67
+ # Crear un diccionario que contenga las etiquetas y sus confianzas
68
+ label_dict = {cls: conf for cls, conf in zip(top_classes, confidences)}
69
+
70
+ # Devolver el resultado como un diccionario
71
+ return label_dict
72
+
73
+ # Gradio Interface
74
+ iface = gr.Interface(
75
+ fn=predict_image,
76
+ inputs="image",
77
+ outputs=gr.Label(num_top_classes=3), # Mostrar las 3 clases m谩s probables con sus confianzas
78
+ title="POLIDATA | Modelo de evaluaci贸n de interfaz de usuario",
79
+ description="Jos茅 Luis Santorcuato Tapia.",
80
+ )
81
+
82
+ iface.launch(share=True)
models/efiB2_27_12_24_f1.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:31477e012996decb1f7fe450e56ad0bf428a5b14116cf340930cae1b2a9001e4
3
+ size 31670309
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchdata
3
+ torchaudio
4
+ torchinfo
5
+ torchtext
6
+ torchvision
7
+ gradio
8
+ Pillow