Marcos12886 committed on
Commit
2567e73
·
1 Parent(s): 96d8764

Usar solo una función para ambos modelos

Browse files
Files changed (1) hide show
  1. app.py +14 -22
app.py CHANGED
@@ -7,28 +7,13 @@ from model import model_params, AudioDataset
7
  token = os.getenv("HF_TOKEN")
8
  client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct", token=token)
9
 
10
- def predict_class(audio_path):
11
- dataset_path = f"data/mixed_data" # PARA CLASIFICADOR
12
  model, _, _, id2label = model_params(dataset_path)
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# Usar a GPU o CPU
14
  model.to(device)# Usar a GPU o CPU
15
  audio_dataset = AudioDataset(dataset_path, {})
16
  inputs = audio_dataset.preprocess_audio(audio_path)
17
- inputs = {"input_values": inputs.to(device).unsqueeze(0)}
18
- with torch.no_grad():
19
- outputs = model(**inputs)
20
- predicted_class_ids = outputs.logits.argmax(-1)
21
- label = id2label[predicted_class_ids.item()]
22
- return label
23
-
24
- def predict_mon(audio_path):
25
- dataset_path = f"data/baby_cry_detection" # PARA CLASIFICADOR
26
- model, _, _, id2label = model_params(dataset_path)
27
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# Usar a GPU o CPU
28
- model.to(device)# Usar a GPU o CPU
29
- audio_dataset = AudioDataset(dataset_path, {})
30
- inputs = audio_dataset.preprocess_audio(audio_path)
31
- inputs = {"input_values": inputs.to(device).unsqueeze(0)}
32
  with torch.no_grad():
33
  outputs = model(**inputs)
34
  predicted_class_ids = outputs.logits.argmax(-1)
@@ -161,7 +146,11 @@ with gr.Blocks(theme=my_theme) as demo:
161
  )
162
  classify_btn = gr.Button("¿Por qué llora?")
163
  classification_output = gr.Textbox(label="Tu bebé llora por:")
164
- classify_btn.click(predict_class, inputs=audio_input, outputs=classification_output)
 
 
 
 
165
  with gr.Column():
166
  gr.Markdown("<h2>Assistant</h2>")
167
  system_message = "You are a Chatbot specialized in baby health and care."
@@ -190,8 +179,12 @@ with gr.Blocks(theme=my_theme) as demo:
190
  type="filepath", # Para no usar numpy y preprocesar siempre igual
191
  )
192
  classify_btn = gr.Button("¿Por qué llora?")
193
- classification_output = gr.Textbox(label="Tu bebé llora por:")
194
- classify_btn.click(predict_mon, inputs=audio_input, outputs=classification_output)
 
 
 
 
195
  with gr.Column():
196
  gr.Markdown("<h2>Assistant</h2>")
197
  system_message = "You are a Chatbot specialized in baby health and care."
@@ -208,8 +201,7 @@ with gr.Blocks(theme=my_theme) as demo:
208
  ],
209
  )
210
  gr.Markdown("Este chatbot no sustituye a un profesional de la salud. Ante cualquier preocupación o duda, consulta con tu pediatra.")
211
- gr.Markdown("Contenido de la Página 2")
212
- boton_volver_inicio_2 = gr.Button("Volver a la pantalla inicial").click(cambiar_pestaña, outputs=[pagina_2, pantalla_inicial])
213
  boton_pagina_1.click(cambiar_pestaña, outputs=[pantalla_inicial, pagina_1])
214
  boton_pagina_2.click(cambiar_pestaña, outputs=[pantalla_inicial, pagina_2])
215
  demo.launch()
 
7
  token = os.getenv("HF_TOKEN")
8
  client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct", token=token)
9
 
10
+ def predict(audio_path, dataset_path):
 
11
  model, _, _, id2label = model_params(dataset_path)
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# Usar a GPU o CPU
13
  model.to(device)# Usar a GPU o CPU
14
  audio_dataset = AudioDataset(dataset_path, {})
15
  inputs = audio_dataset.preprocess_audio(audio_path)
16
+ inputs = {"input_values": inputs.to(device).unsqueeze(0)} # Add batch dimension
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  with torch.no_grad():
18
  outputs = model(**inputs)
19
  predicted_class_ids = outputs.logits.argmax(-1)
 
146
  )
147
  classify_btn = gr.Button("¿Por qué llora?")
148
  classification_output = gr.Textbox(label="Tu bebé llora por:")
149
+ classify_btn.click(
150
+ lambda audio: predict(audio, dataset_path="data/mixed_data"),
151
+ inputs=audio_input,
152
+ outputs=classification_output
153
+ )
154
  with gr.Column():
155
  gr.Markdown("<h2>Assistant</h2>")
156
  system_message = "You are a Chatbot specialized in baby health and care."
 
179
  type="filepath", # Para no usar numpy y preprocesar siempre igual
180
  )
181
  classify_btn = gr.Button("¿Por qué llora?")
182
+ classification_output = gr.Textbox(label="Tu bebé está:")
183
+ classify_btn.click(
184
+ lambda audio: predict(audio, dataset_path="data/baby_cry_detection"),
185
+ inputs=audio_input,
186
+ outputs=classification_output
187
+ )
188
  with gr.Column():
189
  gr.Markdown("<h2>Assistant</h2>")
190
  system_message = "You are a Chatbot specialized in baby health and care."
 
201
  ],
202
  )
203
  gr.Markdown("Este chatbot no sustituye a un profesional de la salud. Ante cualquier preocupación o duda, consulta con tu pediatra.")
204
+ boton_volver_inicio_2 = gr.Button("Volver a la pantalla inicial").click(cambiar_pestaña, outputs=[pagina_2, pantalla_inicial])
 
205
  boton_pagina_1.click(cambiar_pestaña, outputs=[pantalla_inicial, pagina_1])
206
  boton_pagina_2.click(cambiar_pestaña, outputs=[pantalla_inicial, pagina_2])
207
  demo.launch()