CHATBOT / app.py
Marcos12886's picture
Todo bien esta noche muack
abdf62b
raw
history blame
7.21 kB
import os
import torch
import gradio as gr
from huggingface_hub import InferenceClient
from model import predict_params, AudioDataset
from interfaz import estilo, my_theme
# Hugging Face access token taken from the environment; None when HF_TOKEN is unset.
token = os.environ.get("HF_TOKEN")
# Inference client for the chat model that backs the assistant tab.
client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct", token=token)
# Loaded (model, id2label) pairs, keyed by (model_path, dataset_path, filter_white_noise).
model_cache = dict()
def load_model_and_dataset(model_path, dataset_path, filter_white_noise):
    """Return the (model, id2label) pair for a configuration, loading it at most once.

    The first call for a given (model_path, dataset_path, filter_white_noise)
    combination builds the model through ``predict_params``. The result is
    stored in the module-level ``model_cache``. Later calls with the same
    arguments return that same cached tuple.
    """
    cache_key = (model_path, dataset_path, filter_white_noise)
    entry = model_cache.get(cache_key)
    if entry is None:
        # predict_params takes the dataset path first, then the model path.
        model, _, _, id2label = predict_params(dataset_path, model_path, filter_white_noise)
        entry = (model, id2label)
        model_cache[cache_key] = entry
    return entry
def predict(audio_path, model_path, dataset_path, filter_white_noise):
    """Classify one audio file and return the predicted label.

    Args:
        audio_path: Path to the audio file to classify.
        model_path: Model identifier/path passed through to the loader.
        dataset_path: Dataset path used for preprocessing and label lookup.
        filter_white_noise: Forwarded to the loader and to ``AudioDataset``.

    Returns:
        The label from ``id2label`` for the arg-max class. For the
        ``"data/mixed_data"`` dataset, class ids 0-3 are replaced by
        fixed Spanish display names.
    """
    model, id2label = load_model_and_dataset(model_path, dataset_path, filter_white_noise)
    target = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model.to(target)
    model.eval()

    dataset = AudioDataset(dataset_path, {}, filter_white_noise)
    waveform = dataset.preprocess_audio(audio_path)
    # Add a leading batch dimension of size 1.
    batch = waveform.to(target).unsqueeze(0)

    with torch.no_grad():
        logits = model(input_values=batch).logits

    class_id = logits.argmax(dim=-1).item()
    default_label = id2label[class_id]
    if dataset_path != "data/mixed_data":
        return default_label

    display_names = {0: 'Hambre', 1: 'Problemas para respirar', 2: 'Dolor', 3: 'Cansancio/Incomodidad'}
    return display_names.get(class_id, default_label)
def predict_stream(audio_path):
    """Analyse one streamed audio chunk: detect crying and, if so, classify the cause.

    First runs the cry-detector model on the chunk and takes the softmax
    probability of class column 1 (as a percentage). When that value is below
    25, the chunk is also classified with the mixed-data model via
    ``predict`` to report the cause.

    Args:
        audio_path: Path to the current audio chunk. When it is None,
            an empty string is returned.

    Returns:
        A human-readable status string in Spanish, including the probability
        (percentage, one decimal place).
    """
    if audio_path is None:
        # No chunk to analyse yet; avoid passing None to preprocess_audio.
        return ""

    detector, _ = load_model_and_dataset(
        model_path="distilhubert-finetuned-cry-detector",
        dataset_path="data/baby_cry_detection",
        filter_white_noise=False,
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    detector.to(device)
    detector.eval()

    audio_dataset = AudioDataset(dataset_path="data/baby_cry_detection", label2id={}, filter_white_noise=False)
    processed_audio = audio_dataset.preprocess_audio(audio_path)
    inputs = {"input_values": processed_audio.to(device).unsqueeze(0)}
    with torch.no_grad():
        logits = detector(**inputs).logits

    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    # NOTE(review): column 1 is treated as the "crying" class, yet the chunk is
    # reported as crying when this probability is *below* 25%. Either the
    # condition is inverted or column 1 is really the non-crying class.
    # Confirm against the cry-detector's id2label before changing it.
    avg_crying_probability = probabilities[:, 1].mean().item() * 100

    if avg_crying_probability < 25:
        # Same model/dataset/label mapping that the predictor page uses.
        label_class = predict(
            audio_path,
            model_path="distilhubert-finetuned-mixed-data",
            dataset_path="data/mixed_data",
            filter_white_noise=True,
        )
        return f"Bebé llorando por {label_class}. Probabilidad: {avg_crying_probability:.1f}%"
    return f"No está llorando. Probabilidad: {avg_crying_probability:.1f}%"
def chatbot_config(message, history: list[tuple[str, str]]):
    """Stream a reply from the baby-care assistant.

    Args:
        message: The user's latest message.
        history: Previous turns as (user, assistant) pairs. Empty entries on
            either side are skipped.

    Yields:
        The accumulated assistant reply after each streamed chunk that carries
        text, so the UI can render it progressively.
    """
    system_message = "You are a Chatbot specialized in baby health and care."
    max_tokens = 512
    temperature = 0.7
    top_p = 0.95

    messages = [{"role": "system", "content": system_message}]
    for turn in history:
        if turn[0]:
            messages.append({"role": "user", "content": turn[0]})
        if turn[1]:
            messages.append({"role": "assistant", "content": turn[1]})
    messages.append({"role": "user", "content": message})

    response = ""
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        # Stream chunks may carry no text (no choices, or delta.content None).
        # Concatenating None onto the reply would raise TypeError mid-stream.
        if not chunk.choices:
            continue
        piece = chunk.choices[0].delta.content
        if not piece:
            continue
        response += piece
        yield response
def cambiar_pestaña():
    """Swap which of two layout containers is visible.

    Returns two Gradio updates: the first hides its target and the second
    shows its target. Callers bind it as ``outputs=[to_hide, to_show]``.
    """
    hide = gr.update(visible=False)
    show = gr.update(visible=True)
    return hide, show
# UI layout: a landing page (chat assistant plus buttons to the other pages),
# a "Predictor" page that classifies a recorded clip, and a "Monitor" page
# that analyses streamed audio. Each page is a Column. Navigation toggles
# column visibility through cambiar_pestaña (outputs=[to_hide, to_show]).
with gr.Blocks(theme=my_theme) as demo:
    estilo()  # styling helper imported from `interfaz`
    # Landing page: the only column visible at start-up.
    with gr.Column(visible=True) as chatbot:
        gr.Markdown("<h2>Asistente</h2>")
        gr.ChatInterface(
            chatbot_config  # TODO: review which ChatInterface arguments should be passed
        )
        gr.Markdown("Este chatbot no sustituye a un profesional de la salud. Ante cualquier preocupación o duda, consulta con tu pediatra.")
        # Navigation cards leading to the predictor and monitor pages.
        with gr.Row():
            with gr.Column():
                gr.Markdown("<h2>Predictor</h2>")
                boton_pagina_1 = gr.Button("Prueba el predictor")
                gr.Markdown("<p>Descubre por qué llora tu bebé y resuelve dudas sobre su cuidado con nuestro Iremia assistant</p>")
            with gr.Column():
                gr.Markdown("<h2>Monitor</h2>")
                boton_pagina_2 = gr.Button("Prueba el monitor")
                gr.Markdown("<p>Un monitor inteligente que detecta si tu hijo está llorando y te indica el motivo antes de que puedas levantarte del sofá</p>")
    # Predictor page: hidden until boton_pagina_1 is clicked.
    with gr.Column(visible=False) as pag_predictor:
        gr.Markdown("<h2>Predictor</h2>")
        audio_input = gr.Audio(
            min_length=1.0,
            format="wav",
            label="Baby recorder",
            type="filepath",  # the handler receives a file path, not raw samples
        )
        classify_btn = gr.Button("¿Por qué llora?")
        classify_btn.click(
            # The lambda pins model_path/dataset_path/filter_white_noise so that
            # Gradio only has to supply the single `audio_input` value.
            lambda audio: predict(
                audio,
                model_path="distilhubert-finetuned-mixed-data",
                dataset_path="data/mixed_data",
                filter_white_noise=True
            ),
            inputs=audio_input,
            outputs=gr.Textbox(label="Tu bebé llora por:")
        )
        # Back button: hide this page and show the landing page again.
        gr.Button("Volver a la pantalla inicial").click(cambiar_pestaña, outputs=[pag_predictor, chatbot])
    # Monitor page: hidden until boton_pagina_2 is clicked.
    with gr.Column(visible=False) as pag_monitor:
        gr.Markdown("<h2>Monitor</h2>")
        audio_stream = gr.Audio(
            # min_length=1.0, # TODO: investigate why this does not work here
            format="wav",
            label="Baby recorder",
            type="filepath",
            streaming=True
        )
        # Run predict_stream on each streamed chunk; show the returned status text.
        audio_stream.stream(
            predict_stream,
            inputs=audio_stream,
            outputs=gr.Textbox(label="Tu bebé está:"),
        )
        # Back button: hide this page and show the landing page again.
        gr.Button("Volver a la pantalla inicial").click(cambiar_pestaña, outputs=[pag_monitor, chatbot])
    # Landing-page buttons: hide the landing page, show the chosen page.
    boton_pagina_1.click(cambiar_pestaña, outputs=[chatbot, pag_predictor])
    boton_pagina_2.click(cambiar_pestaña, outputs=[chatbot, pag_monitor])
# NOTE(review): launches at import time and requests a public share link;
# confirm this is intended for the deployment target.
demo.launch(share=True)