Marcos12886 commited on
Commit
abdf62b
1 Parent(s): 2567e73

Todo bien esta noche muack

Browse files
Files changed (3) hide show
  1. app.py +115 -171
  2. interfaz.py +101 -0
  3. model.py +38 -24
app.py CHANGED
@@ -2,25 +2,82 @@ import os
2
  import torch
3
  import gradio as gr
4
  from huggingface_hub import InferenceClient
5
- from model import model_params, AudioDataset
 
6
 
7
  token = os.getenv("HF_TOKEN")
8
  client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct", token=token)
 
9
 
10
- def predict(audio_path, dataset_path):
11
- model, _, _, id2label = model_params(dataset_path)
12
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# Usar a GPU o CPU
13
- model.to(device)# Usar a GPU o CPU
14
- audio_dataset = AudioDataset(dataset_path, {})
15
- inputs = audio_dataset.preprocess_audio(audio_path)
16
- inputs = {"input_values": inputs.to(device).unsqueeze(0)} # Add batch dimension
 
 
 
 
 
 
17
  with torch.no_grad():
18
  outputs = model(**inputs)
19
- predicted_class_ids = outputs.logits.argmax(-1)
20
- label = id2label[predicted_class_ids.item()]
 
 
 
 
21
  return label
22
 
23
- def respond(message, history: list[tuple[str, str]], system_message, max_tokens, temperature, top_p):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  messages = [{"role": "system", "content": system_message}]
25
  for val in history:
26
  if val[0]:
@@ -29,102 +86,22 @@ def respond(message, history: list[tuple[str, str]], system_message, max_tokens,
29
  messages.append({"role": "assistant", "content": val[1]})
30
  messages.append({"role": "user", "content": message})
31
  response = ""
32
- for message in client.chat_completion(messages, max_tokens=max_tokens, stream=True, temperature=temperature, top_p=top_p): # Creo que lo importante para el modelo
33
- token = message.choices[0].delta.content
34
  response += token
35
  yield response
36
 
37
  def cambiar_pestaña():
38
  return gr.update(visible=False), gr.update(visible=True)
39
 
40
- my_theme = gr.themes.Soft(
41
- primary_hue="emerald",
42
- secondary_hue="green",
43
- neutral_hue="slate",
44
- text_size="sm",
45
- spacing_size="sm",
46
- font=[gr.themes.GoogleFont('Nunito'), 'ui-sans-serif', 'system-ui', 'sans-serif'],
47
- font_mono=[gr.themes.GoogleFont('Nunito'), 'ui-monospace', 'Consolas', 'monospace'],
48
- ).set(
49
- body_background_fill='*neutral_50',
50
- body_text_color='*neutral_600',
51
- body_text_size='*text_sm',
52
- embed_radius='*radius_md',
53
- shadow_drop='*shadow_spread',
54
- shadow_spread='*button_shadow_active'
55
- )
56
-
57
  with gr.Blocks(theme=my_theme) as demo:
58
- with gr.Column(visible=True, elem_id="pantalla-inicial") as pantalla_inicial:
59
- gr.HTML(
60
- """
61
- <style>
62
- @import url('https://fonts.googleapis.com/css2?family=Lobster&display=swap');
63
- @import url('https://fonts.googleapis.com/css2?family=Roboto&display=swap');
64
-
65
- h1 {
66
- font-family: 'Lobster', cursive;
67
- font-size: 5em !important;
68
- text-align: center;
69
- margin: 0;
70
- }
71
-
72
- .gr-button {
73
- background-color: #4CAF50 !important;
74
- color: white !important;
75
- border: none;
76
- padding: 15px 32px;
77
- text-align: center;
78
- text-decoration: none;
79
- display: inline-block;
80
- font-size: 16px;
81
- margin: 4px 2px;
82
- cursor: pointer;
83
- border-radius: 12px;
84
- }
85
-
86
- .gr-button:hover {
87
- background-color: #45a049;
88
- }
89
- h2 {
90
- font-family: 'Lobster', cursive;
91
- font-size: 3em !important;
92
- text-align: center;
93
- margin: 0;
94
- }
95
- p.slogan, h4, p, h3 {
96
- font-family: 'Roboto', sans-serif;
97
- text-align: center;
98
- }
99
- </style>
100
- <h1>Iremia</h1>
101
- <h4 style='text-align: center; font-size: 1.5em'>Tu aliado para el bienestar de tu bebé</h4>
102
- """
103
- )
104
- gr.Markdown(
105
- "<h4 style='text-align: left; font-size: 1.5em;'>¿Qué es Iremia?</h4>"
106
- )
107
- gr.Markdown(
108
- "<p style='text-align: left'>Iremia es un proyecto llevado a cabo por un grupo de estudiantes interesados en el desarrollo de modelos de inteligencia artificial, enfocados específicamente en casos de uso relevantes para ayudar a cuidar a los más pequeños de la casa.</p>"
109
- )
110
- gr.Markdown(
111
- "<h4 style='text-align: left; font-size: 1.5em;'>Nuestra misión</h4>"
112
- )
113
- gr.Markdown(
114
- "<p style='text-align: left'>Sabemos que la paternidad puede suponer un gran desafío. Nuestra misión es brindarles a todos los padres unas herramientas de última tecnología que los ayuden a navegar esos primeros meses de vida tan cruciales en el desarrollo de sus pequeños.</p>"
115
- )
116
- gr.Markdown(
117
- "<h4 style='text-align: left; font-size: 1.5em;'>¿Qué ofrece Iremia?</h4>"
118
- )
119
- gr.Markdown(
120
- "<p style='text-align: left'>Iremia ofrece dos funcionalidades muy interesantes:</p>"
121
- )
122
- gr.Markdown(
123
- "<p style='text-align: left'>Predictor: Con nuestro modelo de inteligencia artificial, somos capaces de predecir por qué tu hijo de menos de 2 años está llorando. Además, tendrás acceso a un asistente personal para consultar cualquier duda que tengas sobre el cuidado de tu pequeño.</p>"
124
- )
125
- gr.Markdown(
126
- "<p style='text-align: left'>Monitor: Nuestro monitor no es como otros que hay en el mercado, ya que es capaz de reconocer si un sonido es un llanto del bebé o no, y si está llorando, predice automáticamente la causa, lo cual te brindará la tranquilidad de saber siempre qué pasa con tu pequeño y te ahorrará tiempo y muchas horas de sueño.</p>"
127
- )
128
  with gr.Row():
129
  with gr.Column():
130
  gr.Markdown("<h2>Predictor</h2>")
@@ -134,74 +111,41 @@ with gr.Blocks(theme=my_theme) as demo:
134
  gr.Markdown("<h2>Monitor</h2>")
135
  boton_pagina_2 = gr.Button("Prueba el monitor")
136
  gr.Markdown("<p>Un monitor inteligente que detecta si tu hijo está llorando y te indica el motivo antes de que puedas levantarte del sofá</p>")
137
- with gr.Column(visible=False) as pagina_1:
138
- with gr.Row():
139
- with gr.Column():
140
- gr.Markdown("<h2>Predictor</h2>")
141
- audio_input = gr.Audio(
142
- min_length=1.0,
143
- format="wav",
144
- label="Baby recorder",
145
- type="filepath", # Para no usar numpy y preprocesar siempre igual
146
- )
147
- classify_btn = gr.Button("¿Por qué llora?")
148
- classification_output = gr.Textbox(label="Tu bebé llora por:")
149
- classify_btn.click(
150
- lambda audio: predict(audio, dataset_path="data/mixed_data"),
151
- inputs=audio_input,
152
- outputs=classification_output
153
- )
154
- with gr.Column():
155
- gr.Markdown("<h2>Assistant</h2>")
156
- system_message = "You are a Chatbot specialized in baby health and care."
157
- max_tokens = 512
158
- temperature = 0.7
159
- top_p = 0.95
160
- chatbot = gr.ChatInterface(
161
- respond, # TODO: Cambiar para que argumentos estén aquí metidos
162
- additional_inputs=[
163
- gr.State(value=system_message),
164
- gr.State(value=max_tokens),
165
- gr.State(value=temperature),
166
- gr.State(value=top_p)
167
- ],
168
- )
169
- gr.Markdown("Este chatbot no sustituye a un profesional de la salud. Ante cualquier preocupación o duda, consulta con tu pediatra.")
170
- boton_volver_inicio_1 = gr.Button("Volver a la pantalla inicial").click(cambiar_pestaña, outputs=[pagina_1, pantalla_inicial])
171
- with gr.Column(visible=False) as pagina_2:
172
- with gr.Row():
173
- with gr.Column():
174
- gr.Markdown("<h2>Monitor</h2>")
175
- audio_input = gr.Audio(
176
- min_length=1.0,
177
- format="wav",
178
- label="Baby recorder",
179
- type="filepath", # Para no usar numpy y preprocesar siempre igual
180
- )
181
- classify_btn = gr.Button("¿Por qué llora?")
182
- classification_output = gr.Textbox(label="Tu bebé está:")
183
- classify_btn.click(
184
- lambda audio: predict(audio, dataset_path="data/baby_cry_detection"),
185
- inputs=audio_input,
186
- outputs=classification_output
187
- )
188
- with gr.Column():
189
- gr.Markdown("<h2>Assistant</h2>")
190
- system_message = "You are a Chatbot specialized in baby health and care."
191
- max_tokens = 512
192
- temperature = 0.7
193
- top_p = 0.95
194
- chatbot = gr.ChatInterface(
195
- respond, # TODO: Cambiar para que argumentos estén aquí metidos
196
- additional_inputs=[
197
- gr.State(value=system_message),
198
- gr.State(value=max_tokens),
199
- gr.State(value=temperature),
200
- gr.State(value=top_p)
201
- ],
202
- )
203
- gr.Markdown("Este chatbot no sustituye a un profesional de la salud. Ante cualquier preocupación o duda, consulta con tu pediatra.")
204
- boton_volver_inicio_2 = gr.Button("Volver a la pantalla inicial").click(cambiar_pestaña, outputs=[pagina_2, pantalla_inicial])
205
- boton_pagina_1.click(cambiar_pestaña, outputs=[pantalla_inicial, pagina_1])
206
- boton_pagina_2.click(cambiar_pestaña, outputs=[pantalla_inicial, pagina_2])
207
- demo.launch()
 
2
  import torch
3
  import gradio as gr
4
  from huggingface_hub import InferenceClient
5
+ from model import predict_params, AudioDataset
6
+ from interfaz import estilo, my_theme
7
 
8
  token = os.getenv("HF_TOKEN")
9
  client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct", token=token)
10
+ model_cache = {}
11
 
12
+ def load_model_and_dataset(model_path, dataset_path, filter_white_noise):
13
+ if (model_path, dataset_path, filter_white_noise) not in model_cache:
14
+ model, _, _, id2label = predict_params(dataset_path, model_path, filter_white_noise)
15
+ model_cache[(model_path, dataset_path, filter_white_noise)] = (model, id2label)
16
+ return model_cache[(model_path, dataset_path, filter_white_noise)]
17
+
18
+ def predict(audio_path, model_path, dataset_path, filter_white_noise):
19
+ model, id2label = load_model_and_dataset(model_path, dataset_path, filter_white_noise)
20
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
+ model.to(device)
22
+ model.eval()
23
+ audios = AudioDataset(dataset_path, {}, filter_white_noise).preprocess_audio(audio_path)
24
+ inputs = {"input_values": audios.to(device).unsqueeze(0)}
25
  with torch.no_grad():
26
  outputs = model(**inputs)
27
+ logits = outputs.logits
28
+ predicted_class_ids = torch.argmax(logits, dim=-1).item()
29
+ label = id2label[predicted_class_ids]
30
+ if dataset_path == "data/mixed_data":
31
+ label_mapping = {0: 'Hambre', 1: 'Problemas para respirar', 2: 'Dolor', 3: 'Cansancio/Incomodidad'}
32
+ label = label_mapping.get(predicted_class_ids, label)
33
  return label
34
 
35
+ def predict_stream(audio_path):
36
+ model_mon, _ = load_model_and_dataset(
37
+ model_path="distilhubert-finetuned-cry-detector",
38
+ dataset_path="data/baby_cry_detection",
39
+ filter_white_noise=False
40
+ )
41
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42
+ model_mon.to(device)
43
+ model_mon.eval()
44
+ audio_dataset = AudioDataset(dataset_path="data/baby_cry_detection", label2id={}, filter_white_noise=False)
45
+ processed_audio = audio_dataset.preprocess_audio(audio_path)
46
+ inputs = {"input_values": processed_audio.to(device).unsqueeze(0)}
47
+ with torch.no_grad():
48
+ outputs = model_mon(**inputs)
49
+ logits = outputs.logits
50
+ probabilities = torch.nn.functional.softmax(logits, dim=-1)
51
+ crying_probabilities = probabilities[:, 1]
52
+ avg_crying_probability = crying_probabilities.mean().item()*100
53
+ if avg_crying_probability < 25:
54
+ model_class, id2label = load_model_and_dataset(
55
+ model_path="distilhubert-finetuned-mixed-data",
56
+ dataset_path="data/mixed_data",
57
+ filter_white_noise=True
58
+ )
59
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
60
+ model_class.to(device)
61
+ model_class.eval()
62
+ audio_dataset_class = AudioDataset(dataset_path="data/mixed_data", label2id={}, filter_white_noise=True)
63
+ processed_audio_class = audio_dataset_class.preprocess_audio(audio_path)
64
+ inputs_class = {"input_values": processed_audio_class.to(device).unsqueeze(0)}
65
+ with torch.no_grad():
66
+ outputs_class = model_class(**inputs_class)
67
+ logits_class = outputs_class.logits
68
+ predicted_class_ids_class = torch.argmax(logits_class, dim=-1).item()
69
+ label_class = id2label[predicted_class_ids_class]
70
+ label_mapping = {0: 'Hambre', 1: 'Problemas para respirar', 2: 'Dolor', 3: 'Cansancio/Incomodidad'}
71
+ label_class = label_mapping.get(predicted_class_ids_class, label_class)
72
+ return f"Bebé llorando por {label_class}. Probabilidad: {avg_crying_probability:.1f})"
73
+ else:
74
+ return f"No está llorando. Proabilidad: {avg_crying_probability:.1f})"
75
+
76
+ def chatbot_config(message, history: list[tuple[str, str]]):
77
+ system_message = "You are a Chatbot specialized in baby health and care."
78
+ max_tokens = 512
79
+ temperature = 0.7
80
+ top_p = 0.95
81
  messages = [{"role": "system", "content": system_message}]
82
  for val in history:
83
  if val[0]:
 
86
  messages.append({"role": "assistant", "content": val[1]})
87
  messages.append({"role": "user", "content": message})
88
  response = ""
89
+ for message_response in client.chat_completion(messages, max_tokens=max_tokens, stream=True, temperature=temperature, top_p=top_p):
90
+ token = message_response.choices[0].delta.content
91
  response += token
92
  yield response
93
 
94
  def cambiar_pestaña():
95
  return gr.update(visible=False), gr.update(visible=True)
96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  with gr.Blocks(theme=my_theme) as demo:
98
+ estilo()
99
+ with gr.Column(visible=True) as chatbot:
100
+ gr.Markdown("<h2>Asistente</h2>")
101
+ gr.ChatInterface(
102
+ chatbot_config # TODO: Mirar argumentos
103
+ )
104
+ gr.Markdown("Este chatbot no sustituye a un profesional de la salud. Ante cualquier preocupación o duda, consulta con tu pediatra.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  with gr.Row():
106
  with gr.Column():
107
  gr.Markdown("<h2>Predictor</h2>")
 
111
  gr.Markdown("<h2>Monitor</h2>")
112
  boton_pagina_2 = gr.Button("Prueba el monitor")
113
  gr.Markdown("<p>Un monitor inteligente que detecta si tu hijo está llorando y te indica el motivo antes de que puedas levantarte del sofá</p>")
114
+ with gr.Column(visible=False) as pag_predictor:
115
+ gr.Markdown("<h2>Predictor</h2>")
116
+ audio_input = gr.Audio(
117
+ min_length=1.0,
118
+ format="wav",
119
+ label="Baby recorder",
120
+ type="filepath",
121
+ )
122
+ classify_btn = gr.Button("¿Por qué llora?")
123
+ classify_btn.click(
124
+ lambda audio: predict( # Mirar porque usar lambda
125
+ audio,
126
+ model_path="distilhubert-finetuned-mixed-data",
127
+ dataset_path="data/mixed_data",
128
+ filter_white_noise=True
129
+ ),
130
+ inputs=audio_input,
131
+ outputs=gr.Textbox(label="Tu bebé llora por:")
132
+ )
133
+ gr.Button("Volver a la pantalla inicial").click(cambiar_pestaña, outputs=[pag_predictor, chatbot])
134
+ with gr.Column(visible=False) as pag_monitor:
135
+ gr.Markdown("<h2>Monitor</h2>")
136
+ audio_stream = gr.Audio(
137
+ # min_length=1.0, # mirar por qué no va esto
138
+ format="wav",
139
+ label="Baby recorder",
140
+ type="filepath",
141
+ streaming=True
142
+ )
143
+ audio_stream.stream(
144
+ predict_stream,
145
+ inputs=audio_stream,
146
+ outputs=gr.Textbox(label="Tu bebé está:"),
147
+ )
148
+ gr.Button("Volver a la pantalla inicial").click(cambiar_pestaña, outputs=[pag_monitor, chatbot])
149
+ boton_pagina_1.click(cambiar_pestaña, outputs=[chatbot, pag_predictor])
150
+ boton_pagina_2.click(cambiar_pestaña, outputs=[chatbot, pag_monitor])
151
+ demo.launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
interfaz.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ my_theme = gr.themes.Soft(
4
+ primary_hue="emerald",
5
+ secondary_hue="green",
6
+ neutral_hue="slate",
7
+ text_size="sm",
8
+ spacing_size="sm",
9
+ font=[gr.themes.GoogleFont('Nunito'), 'ui-sans-serif', 'system-ui', 'sans-serif'],
10
+ font_mono=[gr.themes.GoogleFont('Nunito'), 'ui-monospace', 'Consolas', 'monospace'],
11
+ ).set(
12
+ body_background_fill='*neutral_50',
13
+ body_text_color='*neutral_600',
14
+ body_text_size='*text_sm',
15
+ embed_radius='*radius_md',
16
+ shadow_drop='*shadow_spread',
17
+ shadow_spread='*button_shadow_active'
18
+ )
19
+ def estilo():
20
+ gr.HTML(
21
+ """
22
+ <style>
23
+ @import url('https://fonts.googleapis.com/css2?family=Lobster&display=swap');
24
+ @import url('https://fonts.googleapis.com/css2?family=Roboto&display=swap');
25
+
26
+ h1 {
27
+ font-family: 'Lobster', cursive;
28
+ font-size: 5em !important;
29
+ text-align: center;
30
+ margin: 0;
31
+ }
32
+
33
+ .gr-button {
34
+ background-color: #4CAF50 !important;
35
+ color: white !important;
36
+ border: none;
37
+ padding: 15px 32px;
38
+ text-align: center;
39
+ text-decoration: none;
40
+ display: inline-block;
41
+ font-size: 16px;
42
+ margin: 4px 2px;
43
+ cursor: pointer;
44
+ border-radius: 12px;
45
+ }
46
+
47
+ .gr-button:hover {
48
+ background-color: #45a049;
49
+ }
50
+ h2 {
51
+ font-family: 'Lobster', cursive;
52
+ font-size: 3em !important;
53
+ text-align: center;
54
+ margin: 0;
55
+ }
56
+ p.slogan, h4, p, h3 {
57
+ font-family: 'Roboto', sans-serif;
58
+ text-align: center;
59
+ }
60
+ </style>
61
+ <h1>Iremia</h1>
62
+ <h4 style='text-align: center; font-size: 1.5em'>Tu aliado para el bienestar de tu bebé</h4>
63
+ """
64
+ )
65
+ return my_theme
66
+ def inicio():
67
+ estilo()
68
+ gr.Markdown(
69
+ "<h4 style='text-align: left; font-size: 1.5em;'>¿Qué es Iremia?</h4>"
70
+ )
71
+ gr.Markdown(
72
+ "<p style='text-align: left'>Iremia es un proyecto llevado a cabo por un grupo de estudiantes interesados en el desarrollo de modelos de inteligencia artificial, enfocados específicamente en casos de uso relevantes para ayudar a cuidar a los más pequeños de la casa.</p>"
73
+ )
74
+ gr.Markdown(
75
+ "<h4 style='text-align: left; font-size: 1.5em;'>Nuestra misión</h4>"
76
+ )
77
+ gr.Markdown(
78
+ "<p style='text-align: left'>Sabemos que la paternidad puede suponer un gran desafío. Nuestra misión es brindarles a todos los padres unas herramientas de última tecnología que los ayuden a navegar esos primeros meses de vida tan cruciales en el desarrollo de sus pequeños.</p>"
79
+ )
80
+ gr.Markdown(
81
+ "<h4 style='text-align: left; font-size: 1.5em;'>¿Qué ofrece Iremia?</h4>"
82
+ )
83
+ gr.Markdown(
84
+ "<p style='text-align: left'>Iremia ofrece dos funcionalidades muy interesantes:</p>"
85
+ )
86
+ gr.Markdown(
87
+ "<p style='text-align: left'>Predictor: Con nuestro modelo de inteligencia artificial, somos capaces de predecir por qué tu hijo de menos de 2 años está llorando. Además, tendrás acceso a un asistente personal para consultar cualquier duda que tengas sobre el cuidado de tu pequeño.</p>"
88
+ )
89
+ gr.Markdown(
90
+ "<p style='text-align: left'>Monitor: Nuestro monitor no es como otros que hay en el mercado, ya que es capaz de reconocer si un sonido es un llanto del bebé o no, y si está llorando, predice automáticamente la causa, lo cual te brindará la tranquilidad de saber siempre qué pasa con tu pequeño y te ahorrará tiempo y muchas horas de sueño.</p>"
91
+ )
92
+ with gr.Row():
93
+ with gr.Column():
94
+ gr.Markdown("<h2>Predictor</h2>")
95
+ boton_pagina_1 = gr.Button("Prueba el predictor")
96
+ gr.Markdown("<p>Descubre por qué llora tu bebé y resuelve dudas sobre su cuidado con nuestro Iremia assistant</p>")
97
+ with gr.Column():
98
+ gr.Markdown("<h2>Monitor</h2>")
99
+ boton_pagina_2 = gr.Button("Prueba el monitor")
100
+ gr.Markdown("<p>Un monitor inteligente que detecta si tu hijo está llorando y te indica el motivo antes de que puedas levantarte del sofá</p>")
101
+ return boton_pagina_1, boton_pagina_2
model.py CHANGED
@@ -23,11 +23,13 @@ config_file = "models_config.json"
23
  clasificador = "class"
24
  monitor = "mon"
25
  batch_size = 16
 
26
 
27
  class AudioDataset(Dataset):
28
- def __init__(self, dataset_path, label2id):
29
  self.dataset_path = dataset_path
30
  self.label2id = label2id
 
31
  self.file_paths = []
32
  self.labels = []
33
  for label_dir, label_id in self.label2id.items():
@@ -37,7 +39,7 @@ class AudioDataset(Dataset):
37
  audio_path = os.path.join(label_path, file_name)
38
  self.file_paths.append(audio_path)
39
  self.labels.append(label_id)
40
- self.file_paths.sort(key=lambda x: x.split('_part')[0])
41
 
42
  def __len__(self):
43
  return len(self.file_paths)
@@ -55,29 +57,33 @@ class AudioDataset(Dataset):
55
  waveform, sample_rate = torchaudio.load(
56
  audio_path,
57
  normalize=True, # Convierte a float32
58
- # num_frames= # TODO: Probar para que no haga falta recortar los audios
59
  )
60
  if sample_rate != SAMPLING_RATE: # Resamplear si no es 16kHz
61
  resampler = torchaudio.transforms.Resample(sample_rate, SAMPLING_RATE)
62
  waveform = resampler(waveform)
63
  if waveform.shape[0] > 1: # Si es stereo, convertir a mono
64
  waveform = waveform.mean(dim=0, keepdim=True)
65
- waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-6) # Sin 1e-6 el accuracy es pésimo!!
66
  max_length = int(SAMPLING_RATE * MAX_DURATION)
67
  if waveform.shape[1] > max_length:
68
- waveform = waveform[:, :max_length]
69
  else:
70
- waveform = torch.nn.functional.pad(waveform, (0, max_length - waveform.shape[1]))
71
  inputs = FEATURE_EXTRACTOR(
72
  waveform.squeeze(),
73
- sampling_rate=SAMPLING_RATE,
74
  return_tensors="pt",
75
  # max_length=int(SAMPLING_RATE * MAX_DURATION),
76
- # truncation=True,
77
- padding=True,
78
  )
79
  return inputs.input_values.squeeze()
80
 
 
 
 
 
 
81
  def seed_everything():
82
  torch.manual_seed(seed)
83
  torch.cuda.manual_seed(seed)
@@ -96,9 +102,9 @@ def build_label_mappings(dataset_path):
96
  label_id += 1
97
  return label2id, id2label
98
 
99
- def create_dataloader(dataset_path, test_size=0.2, num_workers=12, shuffle=True, pin_memory=True):
100
  label2id, id2label = build_label_mappings(dataset_path)
101
- dataset = AudioDataset(dataset_path, label2id)
102
  dataset_size = len(dataset)
103
  indices = list(range(dataset_size))
104
  random.shuffle(indices)
@@ -115,9 +121,9 @@ def create_dataloader(dataset_path, test_size=0.2, num_workers=12, shuffle=True,
115
  )
116
  return train_dataloader, test_dataloader, label2id, id2label
117
 
118
- def load_model(num_labels, label2id, id2label):
119
  config = HubertConfig.from_pretrained(
120
- MODEL,
121
  num_labels=num_labels,
122
  label2id=label2id,
123
  id2label=id2label,
@@ -125,18 +131,23 @@ def load_model(num_labels, label2id, id2label):
125
  )
126
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
127
  model = HubertForSequenceClassification.from_pretrained( # TODO: mirar parámetros. Posibles optimizaciones
128
- MODEL,
129
  config=config,
130
- torch_dtype=torch.float32, # No afecta 1ª época, mejor ponerlo
131
  )
132
  model.to(device)
133
  return model
134
 
135
- def model_params(dataset_path):
136
- train_dataloader, test_dataloader, label2id, id2label = create_dataloader(dataset_path)
137
- model = load_model(num_labels=len(id2label), label2id=label2id, id2label=id2label)
138
  return model, train_dataloader, test_dataloader, id2label
139
 
 
 
 
 
 
140
  def compute_metrics(eval_pred):
141
  predictions = torch.argmax(torch.tensor(eval_pred.predictions), dim=-1)
142
  references = torch.tensor(eval_pred.label_ids)
@@ -149,9 +160,9 @@ def compute_metrics(eval_pred):
149
  "f1": f1,
150
  }
151
 
152
- def main(training_args, output_dir, dataset_path):
153
  seed_everything()
154
- model, train_dataloader, test_dataloader, _ = model_params(dataset_path)
155
  trainer = Trainer(
156
  model=model,
157
  args=training_args,
@@ -162,9 +173,10 @@ def main(training_args, output_dir, dataset_path):
162
  )
163
  torch.cuda.empty_cache() # liberar memoria de la GPU
164
  trainer.train() # se pueden modificar los parámetros para continuar el train
165
- os.makedirs(output_dir, exist_ok=True) # Crear carpeta con el modelo si no existe
166
- trainer.save_model(output_dir) # para subir el modelo a Hugging Face. Necesario para hacer la predicción, no sé por qué.
167
- # upload_folder(repo_id=f"A-POR-LOS-8000/{output_dir}",folder_path=output_dir, token=token) # subir modelo a organización
 
168
 
169
  def load_config(model_name):
170
  with open(config_file, 'r') as f:
@@ -176,8 +188,10 @@ def load_config(model_name):
176
 
177
  if __name__ == "__main__":
178
  config = load_config(clasificador) # PARA CAMBIAR MODELOS
 
179
  # config = load_config(monitor) # PARA CAMBIAR MODELOS
 
180
  training_args = config["training_args"]
181
  output_dir = config["output_dir"]
182
  dataset_path = config["dataset_path"]
183
- main(training_args, output_dir, dataset_path)
 
23
  clasificador = "class"
24
  monitor = "mon"
25
  batch_size = 16
26
+ num_workers = 12
27
 
28
  class AudioDataset(Dataset):
29
+ def __init__(self, dataset_path, label2id, filter_white_noise):
30
  self.dataset_path = dataset_path
31
  self.label2id = label2id
32
+ self.filter_white_noise = filter_white_noise
33
  self.file_paths = []
34
  self.labels = []
35
  for label_dir, label_id in self.label2id.items():
 
39
  audio_path = os.path.join(label_path, file_name)
40
  self.file_paths.append(audio_path)
41
  self.labels.append(label_id)
42
+ self.file_paths.sort(key=lambda x: x.split('_part')[0]) # no sé si influye
43
 
44
  def __len__(self):
45
  return len(self.file_paths)
 
57
  waveform, sample_rate = torchaudio.load(
58
  audio_path,
59
  normalize=True, # Convierte a float32
 
60
  )
61
  if sample_rate != SAMPLING_RATE: # Resamplear si no es 16kHz
62
  resampler = torchaudio.transforms.Resample(sample_rate, SAMPLING_RATE)
63
  waveform = resampler(waveform)
64
  if waveform.shape[0] > 1: # Si es stereo, convertir a mono
65
  waveform = waveform.mean(dim=0, keepdim=True)
66
+ waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-6) # Normalizar, sin 1e-6 el accuracy es pésimo!!
67
  max_length = int(SAMPLING_RATE * MAX_DURATION)
68
  if waveform.shape[1] > max_length:
69
+ waveform = waveform[:, :max_length] # Truncar
70
  else:
71
+ waveform = torch.nn.functional.pad(waveform, (0, max_length - waveform.shape[1])) # Padding
72
  inputs = FEATURE_EXTRACTOR(
73
  waveform.squeeze(),
74
+ sampling_rate=SAMPLING_RATE, # Hecho a mano, por si acaso
75
  return_tensors="pt",
76
  # max_length=int(SAMPLING_RATE * MAX_DURATION),
77
+ # truncation=True, # Hecho a mano
78
+ # padding=True, # Hecho a mano
79
  )
80
  return inputs.input_values.squeeze()
81
 
82
+ def is_white_noise(audio):
83
+ mean = torch.mean(audio)
84
+ std = torch.std(audio)
85
+ return torch.abs(mean) < 0.001 and std < 0.01
86
+
87
  def seed_everything():
88
  torch.manual_seed(seed)
89
  torch.cuda.manual_seed(seed)
 
102
  label_id += 1
103
  return label2id, id2label
104
 
105
+ def create_dataloader(dataset_path, filter_white_noise, test_size=0.2, shuffle=True, pin_memory=True):
106
  label2id, id2label = build_label_mappings(dataset_path)
107
+ dataset = AudioDataset(dataset_path, label2id, filter_white_noise)
108
  dataset_size = len(dataset)
109
  indices = list(range(dataset_size))
110
  random.shuffle(indices)
 
121
  )
122
  return train_dataloader, test_dataloader, label2id, id2label
123
 
124
+ def load_model(model_path, num_labels, label2id, id2label):
125
  config = HubertConfig.from_pretrained(
126
+ pretrained_model_name_or_path=model_path,
127
  num_labels=num_labels,
128
  label2id=label2id,
129
  id2label=id2label,
 
131
  )
132
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
133
  model = HubertForSequenceClassification.from_pretrained( # TODO: mirar parámetros. Posibles optimizaciones
134
+ pretrained_model_name_or_path=model_path,
135
  config=config,
136
+ torch_dtype=torch.float32,
137
  )
138
  model.to(device)
139
  return model
140
 
141
+ def train_params(dataset_path, filter_white_noise):
142
+ train_dataloader, test_dataloader, label2id, id2label = create_dataloader(dataset_path, filter_white_noise)
143
+ model = load_model(model_path=MODEL, num_labels=len(id2label), label2id=label2id, id2label=id2label)
144
  return model, train_dataloader, test_dataloader, id2label
145
 
146
+ def predict_params(dataset_path, model_path, filter_white_noise):
147
+ _, _, label2id, id2label = create_dataloader(dataset_path, filter_white_noise)
148
+ model = load_model(model_path, num_labels=len(id2label), label2id=label2id, id2label=id2label)
149
+ return model, None, None, id2label
150
+
151
  def compute_metrics(eval_pred):
152
  predictions = torch.argmax(torch.tensor(eval_pred.predictions), dim=-1)
153
  references = torch.tensor(eval_pred.label_ids)
 
160
  "f1": f1,
161
  }
162
 
163
+ def main(training_args, output_dir, dataset_path, filter_white_noise):
164
  seed_everything()
165
+ model, train_dataloader, test_dataloader, _ = train_params(dataset_path, filter_white_noise)
166
  trainer = Trainer(
167
  model=model,
168
  args=training_args,
 
173
  )
174
  torch.cuda.empty_cache() # liberar memoria de la GPU
175
  trainer.train() # se pueden modificar los parámetros para continuar el train
176
+ # trainer.save_model(output_dir) # Guardar modelo local.
177
+ os.makedirs(output_dir, exist_ok=True) # Crear carpeta
178
+ trainer.push_to_hub(token=token) # Subir modelo a perfil
179
+ upload_folder(repo_id=f"A-POR-LOS-8000/{output_dir}", folder_path=output_dir, token=token) # subir a organización y local
180
 
181
  def load_config(model_name):
182
  with open(config_file, 'r') as f:
 
188
 
189
  if __name__ == "__main__":
190
  config = load_config(clasificador) # PARA CAMBIAR MODELOS
191
+ filter_white_noise = True
192
  # config = load_config(monitor) # PARA CAMBIAR MODELOS
193
+ # filter_white_noise = False
194
  training_args = config["training_args"]
195
  output_dir = config["output_dir"]
196
  dataset_path = config["dataset_path"]
197
+ main(training_args, output_dir, dataset_path, filter_white_noise)