Robertomarting commited on
Commit
c101beb
·
verified ·
1 Parent(s): a16c24f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +324 -0
app.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchaudio
2
+ import gradio as gr
3
+ import soundfile as sf
4
+ import tempfile
5
+ import os
6
+ import io
7
+ import librosa
8
+ import numpy as np
9
+ import pandas as pd
10
+ from transformers import ASTFeatureExtractor, AutoModelForAudioClassification, Trainer, Wav2Vec2FeatureExtractor, HubertForSequenceClassification, pipeline
11
+ from datasets import Dataset, DatasetDict
12
+ import torch.nn.functional as F
13
+ import torch
14
+ from collections import Counter
15
+ from scipy.stats import kurtosis
16
+ from huggingface_hub import InferenceClient
17
+ import os
18
+
19
+ access_token_mod_1 = os.getenv('HF_Access_Personal')
20
+
21
+ # Cargar el procesador y modelo
22
+ processor = ASTFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
23
+ model = AutoModelForAudioClassification.from_pretrained("Robertomarting/tmp_trainer",token=access_token_mod_1)
24
+
25
+ def is_white_noise(audio, threshold=0.75):
26
+ kurt = kurtosis(audio)
27
+ return np.abs(kurt) < 0.1 and np.mean(np.abs(audio)) < threshold
28
+
29
+ def create_audio_dataframe(audio_tuple, target_sr=16000, target_duration=1.0):
30
+ data = []
31
+ target_length = int(target_sr * target_duration)
32
+
33
+ wav_buffer = io.BytesIO()
34
+ sf.write(wav_buffer, audio_tuple[1], audio_tuple[0], format='wav')
35
+
36
+ wav_buffer.seek(0)
37
+ audio_data, sample_rate = sf.read(wav_buffer)
38
+
39
+ audio_data = audio_data.astype(np.float32)
40
+
41
+ if len(audio_data.shape) > 1:
42
+ audio_data = np.mean(audio_data, axis=1)
43
+
44
+ if sample_rate != target_sr:
45
+ audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=target_sr)
46
+
47
+ audio_data, _ = librosa.effects.trim(audio_data)
48
+
49
+ if len(audio_data) > target_length:
50
+ for i in range(0, len(audio_data), target_length):
51
+ segment = audio_data[i:i + target_length]
52
+ if len(segment) == target_length:
53
+ if not is_white_noise(segment):
54
+ data.append({"audio": segment})
55
+ else:
56
+ if not is_white_noise(audio_data):
57
+ data.append({"audio": audio_data})
58
+
59
+ df = pd.DataFrame(data)
60
+ return df
61
+
62
+ def convert_bytes_to_float64(byte_list):
63
+ return [float(i) for i in byte_list]
64
+
65
+ def preprocess_function(examples):
66
+ audio_arrays = examples["audio"]
67
+ inputs = processor(
68
+ audio_arrays,
69
+ padding=True,
70
+ sampling_rate=processor.sampling_rate,
71
+ max_length=int(processor.sampling_rate * 1),
72
+ truncation=True,
73
+ )
74
+ return inputs
75
+
76
+ def predict_audio(audio):
77
+ df = create_audio_dataframe(audio)
78
+ df['audio'] = df['audio'].apply(convert_bytes_to_float64)
79
+
80
+ # Convertir el dataframe a Dataset
81
+ predict_dataset = Dataset.from_pandas(df)
82
+ dataset = DatasetDict({
83
+ 'train': predict_dataset
84
+ })
85
+
86
+ if '__index_level_0__' in dataset['train'].column_names:
87
+ dataset['train'] = dataset['train'].remove_columns(['__index_level_0__'])
88
+
89
+ encoded_dataset = dataset.map(preprocess_function, remove_columns=["audio"], batched=True)
90
+
91
+ # Crear el Trainer para la predicción
92
+ trainer = Trainer(
93
+ model=model,
94
+ eval_dataset=encoded_dataset["train"]
95
+ )
96
+
97
+ # Realizar las predicciones
98
+ predictions_output = trainer.predict(encoded_dataset["train"].with_format("torch"))
99
+
100
+ # Obtener las predicciones y etiquetas verdaderas
101
+ predictions = predictions_output.predictions
102
+ labels = predictions_output.label_ids
103
+
104
+ # Convertir logits a probabilidades
105
+ probabilities = F.softmax(torch.tensor(predictions), dim=-1).numpy()
106
+ predicted_classes = probabilities.argmax(axis=1)
107
+
108
+ # Obtener la etiqueta más común
109
+ most_common_predicted_label = Counter(predicted_classes).most_common(1)[0][0]
110
+
111
+ # Mapear etiquetas numéricas a etiquetas de texto
112
+ replace_dict = {0: 'Hambre', 1: 'Problemas para respirar', 2: 'Dolor', 3: 'Cansancio/Incomodidad'}
113
+ most_common_predicted_label = replace_dict[most_common_predicted_label]
114
+
115
+ return most_common_predicted_label
116
+
117
+ def clear_audio_input(audio):
118
+ return ""
119
+
120
+ access_token = os.getenv('HF_ACCESS_TOKEN')
121
+
122
+ client = InferenceClient("mistralai/Mistral-Nemo-Instruct-2407", token=access_token)
123
+
124
+ def respond(
125
+ message,
126
+ history: list[tuple[str, str]],
127
+ system_message,
128
+ max_tokens,
129
+ temperature,
130
+ top_p,
131
+ ):
132
+ messages = [{"role": "system", "content": system_message}]
133
+
134
+ for val in history:
135
+ if val[0]:
136
+ messages.append({"role": "user", "content": val[0]})
137
+ if val[1]:
138
+ messages.append({"role": "assistant", "content": val[1]})
139
+
140
+ messages.append({"role": "user", "content": message})
141
+
142
+ response = ""
143
+
144
+ for message in client.chat_completion(
145
+ messages,
146
+ max_tokens=max_tokens,
147
+ stream=True,
148
+ temperature=temperature,
149
+ top_p=top_p,
150
+ ):
151
+ token = message.choices[0].delta.content
152
+ response += token
153
+ yield response
154
+
155
+ my_theme = gr.themes.Soft(
156
+ primary_hue="emerald",
157
+ secondary_hue="green",
158
+ neutral_hue="slate",
159
+ text_size="sm",
160
+ spacing_size="sm",
161
+ font=[gr.themes.GoogleFont('Nunito'), 'ui-sans-serif', 'system-ui', 'sans-serif'],
162
+ font_mono=[gr.themes.GoogleFont('Nunito'), 'ui-monospace', 'Consolas', 'monospace'],
163
+ ).set(
164
+ body_background_fill='*neutral_50',
165
+ body_text_color='*neutral_600',
166
+ body_text_size='*text_sm',
167
+ embed_radius='*radius_md',
168
+ shadow_drop='*shadow_spread',
169
+ shadow_spread='*button_shadow_active'
170
+ )
171
+
172
+ # Función para mostrar la página 1
173
+ def mostrar_pagina_1():
174
+ return gr.update(visible=False), gr.update(visible=True)
175
+
176
+ # Función para mostrar la página 2
177
+ def mostrar_pagina_2():
178
+ return gr.update(visible=False), gr.update(visible=True)
179
+
180
+ # Función para regresar a la pantalla inicial
181
+ def redirigir_a_pantalla_inicial():
182
+ return gr.update(visible=True), gr.update(visible=False)
183
+
184
+ ### Monitor
185
+
186
+ processor = Wav2Vec2FeatureExtractor.from_pretrained("ntu-spml/distilhubert")
187
+
188
+ monitor_model = HubertForSequenceClassification.from_pretrained("A-POR-LOS-8000/distilhubert-finetuned-cry-detector",token=access_token_mod_1)
189
+
190
+ pipeline_monitor = pipeline(model=monitor_model,feature_extractor=processor)
191
+
192
+ def predict_monitor(stream, new_chunk):
193
+ sr, y = new_chunk
194
+ y = y.astype(np.float32)
195
+ y /= np.max(np.abs(y))
196
+
197
+ if stream is not None:
198
+ stream = np.concatenate([stream, y])
199
+ else:
200
+ stream = y
201
+ return stream, pipeline_monitor(stream)
202
+
203
+ my_theme = gr.themes.Soft(
204
+ primary_hue="emerald",
205
+ secondary_hue="green",
206
+ neutral_hue="slate",
207
+ text_size="sm",
208
+ spacing_size="sm",
209
+ font=[gr.themes.GoogleFont('Nunito'), 'ui-sans-serif', 'system-ui', 'sans-serif'],
210
+ font_mono=[gr.themes.GoogleFont('Nunito'), 'ui-monospace', 'Consolas', 'monospace'],
211
+ ).set(
212
+ body_background_fill='*neutral_50',
213
+ body_text_color='*neutral_600',
214
+ body_text_size='*text_sm',
215
+ embed_radius='*radius_md',
216
+ shadow_drop='*shadow_spread',
217
+ shadow_spread='*button_shadow_active'
218
+ )
219
+
220
+ with gr.Blocks(theme = my_theme) as demo:
221
+
222
+ with gr.Column() as pantalla_inicial:
223
+ gr.HTML(
224
+ """
225
+ <style>
226
+ @import url('https://fonts.googleapis.com/css2?family=Lobster&display=swap');
227
+ @import url('https://fonts.googleapis.com/css2?family=Roboto&display=swap');
228
+
229
+ h1 {
230
+ font-family: 'Lobster', cursive;
231
+ font-size: 5em !important;
232
+ text-align: center;
233
+ margin: 0;
234
+ }
235
+ h2 {
236
+ font-family: 'Lobster', cursive;
237
+ font-size: 3em !important;
238
+ text-align: center;
239
+ margin: 0;
240
+ }
241
+ p.slogan, h4, p, h3 {
242
+ font-family: 'Roboto', sans-serif;
243
+ text-align: center;
244
+ }
245
+ </style>
246
+ <h1>Iremia</h1>
247
+ <h4 style='text-align: center; font-size: 1.5em'>El mejor aliado para el bienestar de tu bebé</h4>
248
+ """
249
+ )
250
+ gr.Markdown("<h4 style='text-align: left; font-size: 1.5em;'>¿Qué es Iremia?</h4>")
251
+ gr.Markdown("<p style='text-align: left'>Iremia es un proyecto llevado a cabo por un grupo de estudiantes interesados en el desarrollo de modelos de inteligencia artificial, enfocados específicamente en casos de uso relevantes para ayudar a cuidar a los más pequeños de la casa.</p>")
252
+ gr.Markdown("<h4 style='text-align: left; font-size: 1.5em;'>Nuestra misión</h4>")
253
+ gr.Markdown("<p style='text-align: left'>Sabemos que la paternidad puede suponer un gran desafío. Nuestra misión es brindarles a todos los padres unas herramientas de última tecnología que los ayuden a navegar esos primeros meses de vida tan cruciales en el desarrollo de sus pequeños.</p>")
254
+ gr.Markdown("<h4 style='text-align: left; font-size: 1.5em;'>¿Qué ofrece Iremia?</h4>")
255
+ gr.Markdown("<p style='text-align: left'>Iremia ofrece dos funcionalidades muy interesantes:</p>")
256
+ gr.Markdown("<p style='text-align: left'>Predictor: Con nuestro modelo de inteligencia artificial, somos capaces de predecir por qué tu hijo de menos de 2 años está llorando. Además, tendrás acceso a un asistente personal para consultar cualquier duda que tengas sobre el cuidado de tu pequeño.</p>")
257
+ gr.Markdown("<p style='text-align: left'>Monitor: Nuestro monitor no es como otros que hay en el mercado, ya que es capaz de reconocer si un sonido es un llanto del bebé o no, y si está llorando, predice automáticamente la causa, lo cual te brindará la tranquilidad de saber siempre qué pasa con tu pequeño y te ahorrará tiempo y muchas horas de sueño.</p>")
258
+
259
+ with gr.Row():
260
+ with gr.Column():
261
+ gr.Markdown("<h2>Predictor</h2>")
262
+ boton_pagina_1 = gr.Button("Prueba el predictor")
263
+ gr.Markdown("<p>Descubre por qué llora tu bebé y resuelve dudas sobre su cuidado con nuestro Iremia assistant</p>")
264
+ with gr.Column():
265
+ gr.Markdown("<h2>Monitor</h2>")
266
+ boton_pagina_2 = gr.Button("Prueba el monitor")
267
+ gr.Markdown("<p>Un monitor inteligente que detecta si tu hijo está llorando y te indica el motivo antes de que puedas levantarte del sofá</p>")
268
+
269
+ with gr.Column(visible=False) as pagina_1:
270
+ with gr.Row():
271
+ with gr.Column():
272
+ gr.Markdown("<h2>Predictor</h2>")
273
+ audio_input = gr.Audio(type="numpy", label="Baby recorder")
274
+ classify_btn = gr.Button("¿Por qué llora?")
275
+ classification_output = gr.Textbox(label="Tu bebé llora por:")
276
+
277
+ classify_btn.click(predict_audio, inputs=audio_input, outputs=classification_output)
278
+ audio_input.change(fn=clear_audio_input, inputs=audio_input, outputs=classification_output)
279
+
280
+
281
+ with gr.Column():
282
+ gr.Markdown("<h2>Assistant</h2>")
283
+ system_message = "You are a Chatbot specialized in baby health and care."
284
+ max_tokens = 512
285
+ temperature = 0.7
286
+ top_p = 0.95
287
+
288
+ chatbot = gr.ChatInterface(
289
+ respond,
290
+ additional_inputs=[
291
+ gr.State(value=system_message),
292
+ gr.State(value=max_tokens),
293
+ gr.State(value=temperature),
294
+ gr.State(value=top_p)
295
+ ],
296
+ )
297
+
298
+ gr.Markdown("Este chatbot no sustituye a un profesional de la salud. Ante cualquier preocupación o duda, consulta con tu pediatra.")
299
+
300
+ boton_volver_inicio_1 = gr.Button("Volver a la pantalla inicial")
301
+ boton_volver_inicio_1.click(redirigir_a_pantalla_inicial, inputs=None, outputs=[pantalla_inicial, pagina_1])
302
+
303
+ with gr.Column(visible=False) as pagina_2:
304
+ gr.Markdown("<h2>Monitor</h2>")
305
+ gr.Markdown("# Detección en tiempo real del llanto del bebé con Pipeline")
306
+
307
+ # Componente de audio en streaming
308
+ audio_input = gr.Audio(source="microphone", streaming=True, format="wav", label="Habla cerca del micrófono")
309
+
310
+ # Salida del texto donde se muestra la predicción
311
+ output_text = gr.Textbox(label="Resultado de la predicción")
312
+
313
+ # Vincular la predicción en streaming con el audio
314
+ audio_input.stream(fn=lambda audio: predict_monitor(audio, audio_classifier),
315
+ inputs=audio_input,
316
+ outputs=output_text)
317
+
318
+ boton_volver_inicio_2 = gr.Button("Volver a la pantalla inicial")
319
+ boton_volver_inicio_2.click(redirigir_a_pantalla_inicial, inputs=None, outputs=[pantalla_inicial, pagina_2])
320
+
321
+ boton_pagina_1.click(mostrar_pagina_1, inputs=None, outputs=[pantalla_inicial, pagina_1])
322
+ boton_pagina_2.click(mostrar_pagina_2, inputs=None, outputs=[pantalla_inicial, pagina_2])
323
+
324
+ demo.launch()