Spaces:
Sleeping
Sleeping
Marcos12886
committed on
Commit
•
abdf62b
1
Parent(s):
2567e73
Todo bien esta noche muack
Browse files- app.py +115 -171
- interfaz.py +101 -0
- model.py +38 -24
app.py
CHANGED
@@ -2,25 +2,82 @@ import os
|
|
2 |
import torch
|
3 |
import gradio as gr
|
4 |
from huggingface_hub import InferenceClient
|
5 |
-
from model import
|
|
|
6 |
|
7 |
token = os.getenv("HF_TOKEN")
|
8 |
client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct", token=token)
|
|
|
9 |
|
10 |
-
def
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
with torch.no_grad():
|
18 |
outputs = model(**inputs)
|
19 |
-
|
20 |
-
|
|
|
|
|
|
|
|
|
21 |
return label
|
22 |
|
23 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
messages = [{"role": "system", "content": system_message}]
|
25 |
for val in history:
|
26 |
if val[0]:
|
@@ -29,102 +86,22 @@ def respond(message, history: list[tuple[str, str]], system_message, max_tokens,
|
|
29 |
messages.append({"role": "assistant", "content": val[1]})
|
30 |
messages.append({"role": "user", "content": message})
|
31 |
response = ""
|
32 |
-
for
|
33 |
-
token =
|
34 |
response += token
|
35 |
yield response
|
36 |
|
37 |
def cambiar_pestaña():
|
38 |
return gr.update(visible=False), gr.update(visible=True)
|
39 |
|
40 |
-
my_theme = gr.themes.Soft(
|
41 |
-
primary_hue="emerald",
|
42 |
-
secondary_hue="green",
|
43 |
-
neutral_hue="slate",
|
44 |
-
text_size="sm",
|
45 |
-
spacing_size="sm",
|
46 |
-
font=[gr.themes.GoogleFont('Nunito'), 'ui-sans-serif', 'system-ui', 'sans-serif'],
|
47 |
-
font_mono=[gr.themes.GoogleFont('Nunito'), 'ui-monospace', 'Consolas', 'monospace'],
|
48 |
-
).set(
|
49 |
-
body_background_fill='*neutral_50',
|
50 |
-
body_text_color='*neutral_600',
|
51 |
-
body_text_size='*text_sm',
|
52 |
-
embed_radius='*radius_md',
|
53 |
-
shadow_drop='*shadow_spread',
|
54 |
-
shadow_spread='*button_shadow_active'
|
55 |
-
)
|
56 |
-
|
57 |
with gr.Blocks(theme=my_theme) as demo:
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
h1 {
|
66 |
-
font-family: 'Lobster', cursive;
|
67 |
-
font-size: 5em !important;
|
68 |
-
text-align: center;
|
69 |
-
margin: 0;
|
70 |
-
}
|
71 |
-
|
72 |
-
.gr-button {
|
73 |
-
background-color: #4CAF50 !important;
|
74 |
-
color: white !important;
|
75 |
-
border: none;
|
76 |
-
padding: 15px 32px;
|
77 |
-
text-align: center;
|
78 |
-
text-decoration: none;
|
79 |
-
display: inline-block;
|
80 |
-
font-size: 16px;
|
81 |
-
margin: 4px 2px;
|
82 |
-
cursor: pointer;
|
83 |
-
border-radius: 12px;
|
84 |
-
}
|
85 |
-
|
86 |
-
.gr-button:hover {
|
87 |
-
background-color: #45a049;
|
88 |
-
}
|
89 |
-
h2 {
|
90 |
-
font-family: 'Lobster', cursive;
|
91 |
-
font-size: 3em !important;
|
92 |
-
text-align: center;
|
93 |
-
margin: 0;
|
94 |
-
}
|
95 |
-
p.slogan, h4, p, h3 {
|
96 |
-
font-family: 'Roboto', sans-serif;
|
97 |
-
text-align: center;
|
98 |
-
}
|
99 |
-
</style>
|
100 |
-
<h1>Iremia</h1>
|
101 |
-
<h4 style='text-align: center; font-size: 1.5em'>Tu aliado para el bienestar de tu bebé</h4>
|
102 |
-
"""
|
103 |
-
)
|
104 |
-
gr.Markdown(
|
105 |
-
"<h4 style='text-align: left; font-size: 1.5em;'>¿Qué es Iremia?</h4>"
|
106 |
-
)
|
107 |
-
gr.Markdown(
|
108 |
-
"<p style='text-align: left'>Iremia es un proyecto llevado a cabo por un grupo de estudiantes interesados en el desarrollo de modelos de inteligencia artificial, enfocados específicamente en casos de uso relevantes para ayudar a cuidar a los más pequeños de la casa.</p>"
|
109 |
-
)
|
110 |
-
gr.Markdown(
|
111 |
-
"<h4 style='text-align: left; font-size: 1.5em;'>Nuestra misión</h4>"
|
112 |
-
)
|
113 |
-
gr.Markdown(
|
114 |
-
"<p style='text-align: left'>Sabemos que la paternidad puede suponer un gran desafío. Nuestra misión es brindarles a todos los padres unas herramientas de última tecnología que los ayuden a navegar esos primeros meses de vida tan cruciales en el desarrollo de sus pequeños.</p>"
|
115 |
-
)
|
116 |
-
gr.Markdown(
|
117 |
-
"<h4 style='text-align: left; font-size: 1.5em;'>¿Qué ofrece Iremia?</h4>"
|
118 |
-
)
|
119 |
-
gr.Markdown(
|
120 |
-
"<p style='text-align: left'>Iremia ofrece dos funcionalidades muy interesantes:</p>"
|
121 |
-
)
|
122 |
-
gr.Markdown(
|
123 |
-
"<p style='text-align: left'>Predictor: Con nuestro modelo de inteligencia artificial, somos capaces de predecir por qué tu hijo de menos de 2 años está llorando. Además, tendrás acceso a un asistente personal para consultar cualquier duda que tengas sobre el cuidado de tu pequeño.</p>"
|
124 |
-
)
|
125 |
-
gr.Markdown(
|
126 |
-
"<p style='text-align: left'>Monitor: Nuestro monitor no es como otros que hay en el mercado, ya que es capaz de reconocer si un sonido es un llanto del bebé o no, y si está llorando, predice automáticamente la causa, lo cual te brindará la tranquilidad de saber siempre qué pasa con tu pequeño y te ahorrará tiempo y muchas horas de sueño.</p>"
|
127 |
-
)
|
128 |
with gr.Row():
|
129 |
with gr.Column():
|
130 |
gr.Markdown("<h2>Predictor</h2>")
|
@@ -134,74 +111,41 @@ with gr.Blocks(theme=my_theme) as demo:
|
|
134 |
gr.Markdown("<h2>Monitor</h2>")
|
135 |
boton_pagina_2 = gr.Button("Prueba el monitor")
|
136 |
gr.Markdown("<p>Un monitor inteligente que detecta si tu hijo está llorando y te indica el motivo antes de que puedas levantarte del sofá</p>")
|
137 |
-
with gr.Column(visible=False) as
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
audio_input = gr.Audio(
|
176 |
-
min_length=1.0,
|
177 |
-
format="wav",
|
178 |
-
label="Baby recorder",
|
179 |
-
type="filepath", # Para no usar numpy y preprocesar siempre igual
|
180 |
-
)
|
181 |
-
classify_btn = gr.Button("¿Por qué llora?")
|
182 |
-
classification_output = gr.Textbox(label="Tu bebé está:")
|
183 |
-
classify_btn.click(
|
184 |
-
lambda audio: predict(audio, dataset_path="data/baby_cry_detection"),
|
185 |
-
inputs=audio_input,
|
186 |
-
outputs=classification_output
|
187 |
-
)
|
188 |
-
with gr.Column():
|
189 |
-
gr.Markdown("<h2>Assistant</h2>")
|
190 |
-
system_message = "You are a Chatbot specialized in baby health and care."
|
191 |
-
max_tokens = 512
|
192 |
-
temperature = 0.7
|
193 |
-
top_p = 0.95
|
194 |
-
chatbot = gr.ChatInterface(
|
195 |
-
respond, # TODO: Cambiar para que argumentos estén aquí metidos
|
196 |
-
additional_inputs=[
|
197 |
-
gr.State(value=system_message),
|
198 |
-
gr.State(value=max_tokens),
|
199 |
-
gr.State(value=temperature),
|
200 |
-
gr.State(value=top_p)
|
201 |
-
],
|
202 |
-
)
|
203 |
-
gr.Markdown("Este chatbot no sustituye a un profesional de la salud. Ante cualquier preocupación o duda, consulta con tu pediatra.")
|
204 |
-
boton_volver_inicio_2 = gr.Button("Volver a la pantalla inicial").click(cambiar_pestaña, outputs=[pagina_2, pantalla_inicial])
|
205 |
-
boton_pagina_1.click(cambiar_pestaña, outputs=[pantalla_inicial, pagina_1])
|
206 |
-
boton_pagina_2.click(cambiar_pestaña, outputs=[pantalla_inicial, pagina_2])
|
207 |
-
demo.launch()
|
|
|
2 |
import torch
|
3 |
import gradio as gr
|
4 |
from huggingface_hub import InferenceClient
|
5 |
+
from model import predict_params, AudioDataset
|
6 |
+
from interfaz import estilo, my_theme
|
7 |
|
8 |
token = os.getenv("HF_TOKEN")
|
9 |
client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct", token=token)
|
10 |
+
model_cache = {}
|
11 |
|
12 |
+
def load_model_and_dataset(model_path, dataset_path, filter_white_noise):
    """Return the cached ``(model, id2label)`` pair for this configuration.

    The three arguments together form the cache key in ``model_cache``. On a
    miss, the pair is obtained from ``predict_params`` and stored before being
    returned, so each configuration is loaded at most once per process.
    """
    cache_key = (model_path, dataset_path, filter_white_noise)
    if cache_key in model_cache:
        return model_cache[cache_key]
    # predict_params returns a 4-tuple; only the model and id2label are kept.
    model, _, _, id2label = predict_params(dataset_path, model_path, filter_white_noise)
    model_cache[cache_key] = (model, id2label)
    return model_cache[cache_key]
|
17 |
+
|
18 |
+
def predict(audio_path, model_path, dataset_path, filter_white_noise):
    """Classify one audio file and return the predicted label.

    The model and its ``id2label`` mapping come from
    ``load_model_and_dataset``; the audio is preprocessed with
    ``AudioDataset.preprocess_audio``. When ``dataset_path`` is
    ``"data/mixed_data"`` the class id is translated to a Spanish label,
    falling back to the ``id2label`` entry for unknown ids.
    """
    classifier, id2label = load_model_and_dataset(model_path, dataset_path, filter_white_noise)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    classifier.to(device)
    classifier.eval()

    # An empty label2id is passed: the dataset object is only used here for
    # its preprocessing method.
    preprocessor = AudioDataset(dataset_path, {}, filter_white_noise)
    waveform = preprocessor.preprocess_audio(audio_path)
    batch = waveform.to(device).unsqueeze(0)  # add a leading batch dimension

    with torch.no_grad():
        logits = classifier(input_values=batch).logits

    class_id = torch.argmax(logits, dim=-1).item()
    label = id2label[class_id]
    if dataset_path != "data/mixed_data":
        return label

    spanish_labels = {0: 'Hambre', 1: 'Problemas para respirar', 2: 'Dolor', 3: 'Cansancio/Incomodidad'}
    return spanish_labels.get(class_id, label)
|
34 |
|
35 |
+
def predict_stream(audio_path):
    """Run the cry detector on an audio file and, if triggered, the classifier.

    First the ``distilhubert-finetuned-cry-detector`` model scores the audio;
    the softmax probability of class index 1 is averaged and scaled to a
    percentage. Depending on that value, either a "not crying" message is
    returned or the ``distilhubert-finetuned-mixed-data`` classifier is run to
    name the cause.

    Returns:
        A user-facing Spanish message that includes the percentage.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Stage 1: cry detection.
    model_mon, _ = load_model_and_dataset(
        model_path="distilhubert-finetuned-cry-detector",
        dataset_path="data/baby_cry_detection",
        filter_white_noise=False,
    )
    model_mon.to(device)
    model_mon.eval()
    # Empty label2id: the dataset object is only used for preprocessing.
    audio_dataset = AudioDataset(dataset_path="data/baby_cry_detection", label2id={}, filter_white_noise=False)
    processed_audio = audio_dataset.preprocess_audio(audio_path)
    inputs = {"input_values": processed_audio.to(device).unsqueeze(0)}
    with torch.no_grad():
        outputs = model_mon(**inputs)
    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
    crying_probabilities = probabilities[:, 1]
    avg_crying_probability = crying_probabilities.mean().item() * 100

    # NOTE(review): the classifier runs when this value is BELOW 25, which
    # reads as inverted if index 1 really is the "crying" class. The label
    # order of the detector is not visible here, so the original branch
    # direction is kept -- confirm against the detector's id2label.
    if avg_crying_probability >= 25:
        return f"No está llorando. Probabilidad: {avg_crying_probability:.1f}%"

    # Stage 2: cause classification.
    model_class, id2label = load_model_and_dataset(
        model_path="distilhubert-finetuned-mixed-data",
        dataset_path="data/mixed_data",
        filter_white_noise=True,
    )
    model_class.to(device)
    model_class.eval()
    audio_dataset_class = AudioDataset(dataset_path="data/mixed_data", label2id={}, filter_white_noise=True)
    processed_audio_class = audio_dataset_class.preprocess_audio(audio_path)
    inputs_class = {"input_values": processed_audio_class.to(device).unsqueeze(0)}
    with torch.no_grad():
        outputs_class = model_class(**inputs_class)
    predicted_class_ids_class = torch.argmax(outputs_class.logits, dim=-1).item()
    label_class = id2label[predicted_class_ids_class]
    # Translate the class id to Spanish, falling back to the model's label.
    label_mapping = {0: 'Hambre', 1: 'Problemas para respirar', 2: 'Dolor', 3: 'Cansancio/Incomodidad'}
    label_class = label_mapping.get(predicted_class_ids_class, label_class)
    return f"Bebé llorando por {label_class}. Probabilidad: {avg_crying_probability:.1f}%"
|
75 |
+
|
76 |
+
def chatbot_config(message, history: list[tuple[str, str]]):
|
77 |
+
system_message = "You are a Chatbot specialized in baby health and care."
|
78 |
+
max_tokens = 512
|
79 |
+
temperature = 0.7
|
80 |
+
top_p = 0.95
|
81 |
messages = [{"role": "system", "content": system_message}]
|
82 |
for val in history:
|
83 |
if val[0]:
|
|
|
86 |
messages.append({"role": "assistant", "content": val[1]})
|
87 |
messages.append({"role": "user", "content": message})
|
88 |
response = ""
|
89 |
+
for message_response in client.chat_completion(messages, max_tokens=max_tokens, stream=True, temperature=temperature, top_p=top_p):
|
90 |
+
token = message_response.choices[0].delta.content
|
91 |
response += token
|
92 |
yield response
|
93 |
|
94 |
def cambiar_pestaña():
    """Return two Gradio updates: hide the first output, show the second."""
    hide_first = gr.update(visible=False)
    show_second = gr.update(visible=True)
    return hide_first, show_second
|
96 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
with gr.Blocks(theme=my_theme) as demo:
|
98 |
+
estilo()
|
99 |
+
with gr.Column(visible=True) as chatbot:
|
100 |
+
gr.Markdown("<h2>Asistente</h2>")
|
101 |
+
gr.ChatInterface(
|
102 |
+
chatbot_config # TODO: Mirar argumentos
|
103 |
+
)
|
104 |
+
gr.Markdown("Este chatbot no sustituye a un profesional de la salud. Ante cualquier preocupación o duda, consulta con tu pediatra.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
with gr.Row():
|
106 |
with gr.Column():
|
107 |
gr.Markdown("<h2>Predictor</h2>")
|
|
|
111 |
gr.Markdown("<h2>Monitor</h2>")
|
112 |
boton_pagina_2 = gr.Button("Prueba el monitor")
|
113 |
gr.Markdown("<p>Un monitor inteligente que detecta si tu hijo está llorando y te indica el motivo antes de que puedas levantarte del sofá</p>")
|
114 |
+
with gr.Column(visible=False) as pag_predictor:
|
115 |
+
gr.Markdown("<h2>Predictor</h2>")
|
116 |
+
audio_input = gr.Audio(
|
117 |
+
min_length=1.0,
|
118 |
+
format="wav",
|
119 |
+
label="Baby recorder",
|
120 |
+
type="filepath",
|
121 |
+
)
|
122 |
+
classify_btn = gr.Button("¿Por qué llora?")
|
123 |
+
classify_btn.click(
|
124 |
+
lambda audio: predict( # Mirar porque usar lambda
|
125 |
+
audio,
|
126 |
+
model_path="distilhubert-finetuned-mixed-data",
|
127 |
+
dataset_path="data/mixed_data",
|
128 |
+
filter_white_noise=True
|
129 |
+
),
|
130 |
+
inputs=audio_input,
|
131 |
+
outputs=gr.Textbox(label="Tu bebé llora por:")
|
132 |
+
)
|
133 |
+
gr.Button("Volver a la pantalla inicial").click(cambiar_pestaña, outputs=[pag_predictor, chatbot])
|
134 |
+
with gr.Column(visible=False) as pag_monitor:
|
135 |
+
gr.Markdown("<h2>Monitor</h2>")
|
136 |
+
audio_stream = gr.Audio(
|
137 |
+
# min_length=1.0, # mirar por qué no va esto
|
138 |
+
format="wav",
|
139 |
+
label="Baby recorder",
|
140 |
+
type="filepath",
|
141 |
+
streaming=True
|
142 |
+
)
|
143 |
+
audio_stream.stream(
|
144 |
+
predict_stream,
|
145 |
+
inputs=audio_stream,
|
146 |
+
outputs=gr.Textbox(label="Tu bebé está:"),
|
147 |
+
)
|
148 |
+
gr.Button("Volver a la pantalla inicial").click(cambiar_pestaña, outputs=[pag_monitor, chatbot])
|
149 |
+
boton_pagina_1.click(cambiar_pestaña, outputs=[chatbot, pag_predictor])
|
150 |
+
boton_pagina_2.click(cambiar_pestaña, outputs=[chatbot, pag_monitor])
|
151 |
+
demo.launch(share=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
interfaz.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
|
3 |
+
# Base theme: Soft preset with green hues, small text/spacing and Nunito fonts.
my_theme = gr.themes.Soft(
    primary_hue="emerald",
    secondary_hue="green",
    neutral_hue="slate",
    text_size="sm",
    spacing_size="sm",
    font=[gr.themes.GoogleFont('Nunito'), 'ui-sans-serif', 'system-ui', 'sans-serif'],
    font_mono=[gr.themes.GoogleFont('Nunito'), 'ui-monospace', 'Consolas', 'monospace'],
)
# Override individual theme variables on top of the preset.
my_theme = my_theme.set(
    body_background_fill='*neutral_50',
    body_text_color='*neutral_600',
    body_text_size='*text_sm',
    embed_radius='*radius_md',
    shadow_drop='*shadow_spread',
    shadow_spread='*button_shadow_active',
)
|
19 |
+
def estilo():
    """Render the page stylesheet and the "Iremia" header, then return the theme.

    Must be called inside a ``gr.Blocks`` context, since it creates a
    ``gr.HTML`` component. Returns the module-level ``my_theme``.
    """
    header_markup = """
        <style>
        @import url('https://fonts.googleapis.com/css2?family=Lobster&display=swap');
        @import url('https://fonts.googleapis.com/css2?family=Roboto&display=swap');

        h1 {
            font-family: 'Lobster', cursive;
            font-size: 5em !important;
            text-align: center;
            margin: 0;
        }

        .gr-button {
            background-color: #4CAF50 !important;
            color: white !important;
            border: none;
            padding: 15px 32px;
            text-align: center;
            text-decoration: none;
            display: inline-block;
            font-size: 16px;
            margin: 4px 2px;
            cursor: pointer;
            border-radius: 12px;
        }

        .gr-button:hover {
            background-color: #45a049;
        }
        h2 {
            font-family: 'Lobster', cursive;
            font-size: 3em !important;
            text-align: center;
            margin: 0;
        }
        p.slogan, h4, p, h3 {
            font-family: 'Roboto', sans-serif;
            text-align: center;
        }
        </style>
        <h1>Iremia</h1>
        <h4 style='text-align: center; font-size: 1.5em'>Tu aliado para el bienestar de tu bebé</h4>
        """
    gr.HTML(header_markup)
    return my_theme
|
66 |
+
def inicio():
    """Build the landing page and return its two navigation buttons.

    Renders the shared header via ``estilo()``, a sequence of descriptive
    Markdown sections, and a row with one column per feature. Must be called
    inside a ``gr.Blocks`` context.

    Returns:
        ``(boton_pagina_1, boton_pagina_2)``: the "predictor" and "monitor"
        buttons, so the caller can attach click handlers.
    """
    estilo()

    # Descriptive sections, rendered in order as individual Markdown components.
    secciones = (
        "<h4 style='text-align: left; font-size: 1.5em;'>¿Qué es Iremia?</h4>",
        "<p style='text-align: left'>Iremia es un proyecto llevado a cabo por un grupo de estudiantes interesados en el desarrollo de modelos de inteligencia artificial, enfocados específicamente en casos de uso relevantes para ayudar a cuidar a los más pequeños de la casa.</p>",
        "<h4 style='text-align: left; font-size: 1.5em;'>Nuestra misión</h4>",
        "<p style='text-align: left'>Sabemos que la paternidad puede suponer un gran desafío. Nuestra misión es brindarles a todos los padres unas herramientas de última tecnología que los ayuden a navegar esos primeros meses de vida tan cruciales en el desarrollo de sus pequeños.</p>",
        "<h4 style='text-align: left; font-size: 1.5em;'>¿Qué ofrece Iremia?</h4>",
        "<p style='text-align: left'>Iremia ofrece dos funcionalidades muy interesantes:</p>",
        "<p style='text-align: left'>Predictor: Con nuestro modelo de inteligencia artificial, somos capaces de predecir por qué tu hijo de menos de 2 años está llorando. Además, tendrás acceso a un asistente personal para consultar cualquier duda que tengas sobre el cuidado de tu pequeño.</p>",
        "<p style='text-align: left'>Monitor: Nuestro monitor no es como otros que hay en el mercado, ya que es capaz de reconocer si un sonido es un llanto del bebé o no, y si está llorando, predice automáticamente la causa, lo cual te brindará la tranquilidad de saber siempre qué pasa con tu pequeño y te ahorrará tiempo y muchas horas de sueño.</p>",
    )
    for seccion in secciones:
        gr.Markdown(seccion)

    # Two side-by-side feature cards, each with its own entry button.
    with gr.Row():
        with gr.Column():
            gr.Markdown("<h2>Predictor</h2>")
            boton_pagina_1 = gr.Button("Prueba el predictor")
            gr.Markdown("<p>Descubre por qué llora tu bebé y resuelve dudas sobre su cuidado con nuestro Iremia assistant</p>")
        with gr.Column():
            gr.Markdown("<h2>Monitor</h2>")
            boton_pagina_2 = gr.Button("Prueba el monitor")
            gr.Markdown("<p>Un monitor inteligente que detecta si tu hijo está llorando y te indica el motivo antes de que puedas levantarte del sofá</p>")

    return boton_pagina_1, boton_pagina_2
|
model.py
CHANGED
@@ -23,11 +23,13 @@ config_file = "models_config.json"
|
|
23 |
clasificador = "class"
|
24 |
monitor = "mon"
|
25 |
batch_size = 16
|
|
|
26 |
|
27 |
class AudioDataset(Dataset):
|
28 |
-
def __init__(self, dataset_path, label2id):
|
29 |
self.dataset_path = dataset_path
|
30 |
self.label2id = label2id
|
|
|
31 |
self.file_paths = []
|
32 |
self.labels = []
|
33 |
for label_dir, label_id in self.label2id.items():
|
@@ -37,7 +39,7 @@ class AudioDataset(Dataset):
|
|
37 |
audio_path = os.path.join(label_path, file_name)
|
38 |
self.file_paths.append(audio_path)
|
39 |
self.labels.append(label_id)
|
40 |
-
self.file_paths.sort(key=lambda x: x.split('_part')[0])
|
41 |
|
42 |
def __len__(self):
|
43 |
return len(self.file_paths)
|
@@ -55,29 +57,33 @@ class AudioDataset(Dataset):
|
|
55 |
waveform, sample_rate = torchaudio.load(
|
56 |
audio_path,
|
57 |
normalize=True, # Convierte a float32
|
58 |
-
# num_frames= # TODO: Probar para que no haga falta recortar los audios
|
59 |
)
|
60 |
if sample_rate != SAMPLING_RATE: # Resamplear si no es 16kHz
|
61 |
resampler = torchaudio.transforms.Resample(sample_rate, SAMPLING_RATE)
|
62 |
waveform = resampler(waveform)
|
63 |
if waveform.shape[0] > 1: # Si es stereo, convertir a mono
|
64 |
waveform = waveform.mean(dim=0, keepdim=True)
|
65 |
-
waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-6) #
|
66 |
max_length = int(SAMPLING_RATE * MAX_DURATION)
|
67 |
if waveform.shape[1] > max_length:
|
68 |
-
waveform = waveform[:, :max_length]
|
69 |
else:
|
70 |
-
waveform = torch.nn.functional.pad(waveform, (0, max_length - waveform.shape[1]))
|
71 |
inputs = FEATURE_EXTRACTOR(
|
72 |
waveform.squeeze(),
|
73 |
-
sampling_rate=SAMPLING_RATE,
|
74 |
return_tensors="pt",
|
75 |
# max_length=int(SAMPLING_RATE * MAX_DURATION),
|
76 |
-
# truncation=True,
|
77 |
-
padding=True,
|
78 |
)
|
79 |
return inputs.input_values.squeeze()
|
80 |
|
|
|
|
|
|
|
|
|
|
|
81 |
def seed_everything():
|
82 |
torch.manual_seed(seed)
|
83 |
torch.cuda.manual_seed(seed)
|
@@ -96,9 +102,9 @@ def build_label_mappings(dataset_path):
|
|
96 |
label_id += 1
|
97 |
return label2id, id2label
|
98 |
|
99 |
-
def create_dataloader(dataset_path, test_size=0.2,
|
100 |
label2id, id2label = build_label_mappings(dataset_path)
|
101 |
-
dataset = AudioDataset(dataset_path, label2id)
|
102 |
dataset_size = len(dataset)
|
103 |
indices = list(range(dataset_size))
|
104 |
random.shuffle(indices)
|
@@ -115,9 +121,9 @@ def create_dataloader(dataset_path, test_size=0.2, num_workers=12, shuffle=True,
|
|
115 |
)
|
116 |
return train_dataloader, test_dataloader, label2id, id2label
|
117 |
|
118 |
-
def load_model(num_labels, label2id, id2label):
|
119 |
config = HubertConfig.from_pretrained(
|
120 |
-
|
121 |
num_labels=num_labels,
|
122 |
label2id=label2id,
|
123 |
id2label=id2label,
|
@@ -125,18 +131,23 @@ def load_model(num_labels, label2id, id2label):
|
|
125 |
)
|
126 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
127 |
model = HubertForSequenceClassification.from_pretrained( # TODO: mirar parámetros. Posibles optimizaciones
|
128 |
-
|
129 |
config=config,
|
130 |
-
torch_dtype=torch.float32,
|
131 |
)
|
132 |
model.to(device)
|
133 |
return model
|
134 |
|
135 |
-
def
|
136 |
-
train_dataloader, test_dataloader, label2id, id2label = create_dataloader(dataset_path)
|
137 |
-
model = load_model(num_labels=len(id2label), label2id=label2id, id2label=id2label)
|
138 |
return model, train_dataloader, test_dataloader, id2label
|
139 |
|
|
|
|
|
|
|
|
|
|
|
140 |
def compute_metrics(eval_pred):
|
141 |
predictions = torch.argmax(torch.tensor(eval_pred.predictions), dim=-1)
|
142 |
references = torch.tensor(eval_pred.label_ids)
|
@@ -149,9 +160,9 @@ def compute_metrics(eval_pred):
|
|
149 |
"f1": f1,
|
150 |
}
|
151 |
|
152 |
-
def main(training_args, output_dir, dataset_path):
|
153 |
seed_everything()
|
154 |
-
model, train_dataloader, test_dataloader, _ =
|
155 |
trainer = Trainer(
|
156 |
model=model,
|
157 |
args=training_args,
|
@@ -162,9 +173,10 @@ def main(training_args, output_dir, dataset_path):
|
|
162 |
)
|
163 |
torch.cuda.empty_cache() # liberar memoria de la GPU
|
164 |
trainer.train() # se pueden modificar los parámetros para continuar el train
|
165 |
-
|
166 |
-
|
167 |
-
|
|
|
168 |
|
169 |
def load_config(model_name):
|
170 |
with open(config_file, 'r') as f:
|
@@ -176,8 +188,10 @@ def load_config(model_name):
|
|
176 |
|
177 |
if __name__ == "__main__":
|
178 |
config = load_config(clasificador) # PARA CAMBIAR MODELOS
|
|
|
179 |
# config = load_config(monitor) # PARA CAMBIAR MODELOS
|
|
|
180 |
training_args = config["training_args"]
|
181 |
output_dir = config["output_dir"]
|
182 |
dataset_path = config["dataset_path"]
|
183 |
-
main(training_args, output_dir, dataset_path)
|
|
|
23 |
clasificador = "class"
|
24 |
monitor = "mon"
|
25 |
batch_size = 16
|
26 |
+
num_workers = 12
|
27 |
|
28 |
class AudioDataset(Dataset):
|
29 |
+
def __init__(self, dataset_path, label2id, filter_white_noise):
|
30 |
self.dataset_path = dataset_path
|
31 |
self.label2id = label2id
|
32 |
+
self.filter_white_noise = filter_white_noise
|
33 |
self.file_paths = []
|
34 |
self.labels = []
|
35 |
for label_dir, label_id in self.label2id.items():
|
|
|
39 |
audio_path = os.path.join(label_path, file_name)
|
40 |
self.file_paths.append(audio_path)
|
41 |
self.labels.append(label_id)
|
42 |
+
self.file_paths.sort(key=lambda x: x.split('_part')[0]) # no sé si influye
|
43 |
|
44 |
def __len__(self):
|
45 |
return len(self.file_paths)
|
|
|
57 |
waveform, sample_rate = torchaudio.load(
|
58 |
audio_path,
|
59 |
normalize=True, # Convierte a float32
|
|
|
60 |
)
|
61 |
if sample_rate != SAMPLING_RATE: # Resamplear si no es 16kHz
|
62 |
resampler = torchaudio.transforms.Resample(sample_rate, SAMPLING_RATE)
|
63 |
waveform = resampler(waveform)
|
64 |
if waveform.shape[0] > 1: # Si es stereo, convertir a mono
|
65 |
waveform = waveform.mean(dim=0, keepdim=True)
|
66 |
+
waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-6) # Normalizar, sin 1e-6 el accuracy es pésimo!!
|
67 |
max_length = int(SAMPLING_RATE * MAX_DURATION)
|
68 |
if waveform.shape[1] > max_length:
|
69 |
+
waveform = waveform[:, :max_length] # Truncar
|
70 |
else:
|
71 |
+
waveform = torch.nn.functional.pad(waveform, (0, max_length - waveform.shape[1])) # Padding
|
72 |
inputs = FEATURE_EXTRACTOR(
|
73 |
waveform.squeeze(),
|
74 |
+
sampling_rate=SAMPLING_RATE, # Hecho a mano, por si acaso
|
75 |
return_tensors="pt",
|
76 |
# max_length=int(SAMPLING_RATE * MAX_DURATION),
|
77 |
+
# truncation=True, # Hecho a mano
|
78 |
+
# padding=True, # Hecho a mano
|
79 |
)
|
80 |
return inputs.input_values.squeeze()
|
81 |
|
82 |
+
def is_white_noise(audio, mean_threshold=0.001, std_threshold=0.01):
    """Report whether ``audio`` has a near-zero mean and a small spread.

    Args:
        audio: Tensor of audio samples.
        mean_threshold: Upper bound (exclusive) on the absolute mean.
            Defaults to the previously hard-coded 0.001.
        std_threshold: Upper bound (exclusive) on the standard deviation.
            Defaults to the previously hard-coded 0.01.

    Returns:
        The result of the two comparisons combined with ``and``; for tensor
        input this is a zero-dimensional boolean tensor, usable directly in
        an ``if``.
    """
    mean = torch.mean(audio)
    std = torch.std(audio)
    return torch.abs(mean) < mean_threshold and std < std_threshold
|
86 |
+
|
87 |
def seed_everything():
|
88 |
torch.manual_seed(seed)
|
89 |
torch.cuda.manual_seed(seed)
|
|
|
102 |
label_id += 1
|
103 |
return label2id, id2label
|
104 |
|
105 |
+
def create_dataloader(dataset_path, filter_white_noise, test_size=0.2, shuffle=True, pin_memory=True):
|
106 |
label2id, id2label = build_label_mappings(dataset_path)
|
107 |
+
dataset = AudioDataset(dataset_path, label2id, filter_white_noise)
|
108 |
dataset_size = len(dataset)
|
109 |
indices = list(range(dataset_size))
|
110 |
random.shuffle(indices)
|
|
|
121 |
)
|
122 |
return train_dataloader, test_dataloader, label2id, id2label
|
123 |
|
124 |
+
def load_model(model_path, num_labels, label2id, id2label):
|
125 |
config = HubertConfig.from_pretrained(
|
126 |
+
pretrained_model_name_or_path=model_path,
|
127 |
num_labels=num_labels,
|
128 |
label2id=label2id,
|
129 |
id2label=id2label,
|
|
|
131 |
)
|
132 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
133 |
model = HubertForSequenceClassification.from_pretrained( # TODO: mirar parámetros. Posibles optimizaciones
|
134 |
+
pretrained_model_name_or_path=model_path,
|
135 |
config=config,
|
136 |
+
torch_dtype=torch.float32,
|
137 |
)
|
138 |
model.to(device)
|
139 |
return model
|
140 |
|
141 |
+
def train_params(dataset_path, filter_white_noise):
    """Assemble everything needed for training.

    Builds the dataloaders and label mappings with ``create_dataloader`` and
    loads the base checkpoint named by the module-level ``MODEL`` with a
    classification head sized to the number of labels.

    Returns:
        ``(model, train_dataloader, test_dataloader, id2label)``.
    """
    loaders_and_mappings = create_dataloader(dataset_path, filter_white_noise)
    train_dataloader, test_dataloader, label2id, id2label = loaders_and_mappings
    model = load_model(
        model_path=MODEL,
        num_labels=len(id2label),
        label2id=label2id,
        id2label=id2label,
    )
    return model, train_dataloader, test_dataloader, id2label
|
145 |
|
146 |
+
def predict_params(dataset_path, model_path, filter_white_noise):
    """Load a fine-tuned model and its label mapping for inference.

    Args:
        dataset_path: Dataset directory the label mappings are derived from.
        model_path: Checkpoint name or path passed to ``load_model``.
        filter_white_noise: Unused here; accepted so the signature stays
            parallel to ``train_params`` and existing callers keep working.

    Returns:
        ``(model, None, None, id2label)``. The two ``None`` slots stand in
        for the train/test dataloaders that ``train_params`` returns.
    """
    # Inference only needs the label mappings. Going through
    # create_dataloader would also build the dataset and both dataloaders
    # just to throw them away, so the mappings are built directly.
    label2id, id2label = build_label_mappings(dataset_path)
    model = load_model(model_path, num_labels=len(id2label), label2id=label2id, id2label=id2label)
    return model, None, None, id2label
|
150 |
+
|
151 |
def compute_metrics(eval_pred):
|
152 |
predictions = torch.argmax(torch.tensor(eval_pred.predictions), dim=-1)
|
153 |
references = torch.tensor(eval_pred.label_ids)
|
|
|
160 |
"f1": f1,
|
161 |
}
|
162 |
|
163 |
+
def main(training_args, output_dir, dataset_path, filter_white_noise):
|
164 |
seed_everything()
|
165 |
+
model, train_dataloader, test_dataloader, _ = train_params(dataset_path, filter_white_noise)
|
166 |
trainer = Trainer(
|
167 |
model=model,
|
168 |
args=training_args,
|
|
|
173 |
)
|
174 |
torch.cuda.empty_cache() # liberar memoria de la GPU
|
175 |
trainer.train() # se pueden modificar los parámetros para continuar el train
|
176 |
+
# trainer.save_model(output_dir) # Guardar modelo local.
|
177 |
+
os.makedirs(output_dir, exist_ok=True) # Crear carpeta
|
178 |
+
trainer.push_to_hub(token=token) # Subir modelo a perfil
|
179 |
+
upload_folder(repo_id=f"A-POR-LOS-8000/{output_dir}", folder_path=output_dir, token=token) # subir a organización y local
|
180 |
|
181 |
def load_config(model_name):
|
182 |
with open(config_file, 'r') as f:
|
|
|
188 |
|
189 |
if __name__ == "__main__":
|
190 |
config = load_config(clasificador) # PARA CAMBIAR MODELOS
|
191 |
+
filter_white_noise = True
|
192 |
# config = load_config(monitor) # PARA CAMBIAR MODELOS
|
193 |
+
# filter_white_noise = False
|
194 |
training_args = config["training_args"]
|
195 |
output_dir = config["output_dir"]
|
196 |
dataset_path = config["dataset_path"]
|
197 |
+
main(training_args, output_dir, dataset_path, filter_white_noise)
|