Spaces:
Sleeping
Sleeping
File size: 6,632 Bytes
1e6dc54 5cf41d0 ace06e3 abdf62b cc3562b 6d1143c 2ca1b49 166aa6c d0064d4 ace06e3 166aa6c abdf62b 166aa6c 1e6dc54 abdf62b 166aa6c 1e6dc54 166aa6c abdf62b 1d21972 166aa6c 1d21972 abdf62b 166aa6c 1d21972 166aa6c abdf62b 166aa6c abdf62b ace06e3 ebf42ac ace06e3 abdf62b ace06e3 1e6dc54 ace06e3 abdf62b ebf42ac 166aa6c ebf42ac 166aa6c abdf62b 166aa6c abdf62b 166aa6c abdf62b 166aa6c abdf62b 166aa6c abdf62b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
import os
import torch
import gradio as gr
from huggingface_hub import InferenceClient
from model import predict_params, AudioDataset
from interfaz import estilo, my_theme
# --- Module-level setup: LLM client, compute device and the two audio models ---
# HF_TOKEN comes from the environment; os.environ.get yields None when it is unset.
token = os.environ.get("HF_TOKEN")
client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct", token=token)
# Prefer the GPU whenever torch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Classifier for *why* the baby cries (consumed by predict()).
model_class, id2label_class = predict_params(
    model_path="A-POR-LOS-8000/distilhubert-finetuned-mixed-data",
    dataset_path="data/mixed_data",
    filter_white_noise=True,
)
# Detector for *whether* the baby cries (consumed by predict_stream() and decibelios()).
model_mon, id2label_mon = predict_params(
    model_path="A-POR-LOS-8000/distilhubert-finetuned-cry-detector",
    dataset_path="data/baby_cry_detection",
    filter_white_noise=False,
)
def call(audiopath, model, dataset_path, filter_white_noise):
    """Run ``model`` on a single audio file and return its raw logits.

    The model is moved to the module-level ``device`` and put in eval mode on
    every call. ``AudioDataset`` is used here only for its
    ``preprocess_audio`` step.

    Args:
        audiopath: Path of the audio file to analyse.
        model: Model invoked as ``model(input_values=...)`` whose output has
            a ``.logits`` attribute.
        dataset_path: Forwarded to ``AudioDataset``.
        filter_white_noise: Forwarded to ``AudioDataset``.

    Returns:
        The ``logits`` tensor of the model output, computed without autograd.
    """
    model.to(device)
    model.eval()
    preprocessor = AudioDataset(dataset_path, {}, filter_white_noise)
    waveform = preprocessor.preprocess_audio(audiopath)
    # Add a leading batch dimension of size 1 before the forward pass.
    batch = waveform.to(device).unsqueeze(0)
    with torch.no_grad():
        return model(input_values=batch).logits
def predict(audio_path_pred):
    """Classify why the baby in ``audio_path_pred`` is crying.

    Runs ``model_class`` through ``call`` and takes the arg-max class id.
    Ids 0-3 are reported with fixed Spanish names; any other id falls back to
    the model's own ``id2label_class`` entry.

    Returns:
        The label string for the predicted class.
    """
    with torch.no_grad():
        class_logits = call(audio_path_pred, model=model_class, dataset_path="A-POR-LOS-8000/data/mixed_data", filter_white_noise=True)
        best_id = torch.argmax(class_logits, dim=-1).item()
    # Looked up first (as before), so an id missing from id2label_class still raises.
    model_label = id2label_class[best_id]
    spanish_names = {
        0: 'Hambre',
        1: 'Problemas para respirar',
        2: 'Dolor',
        3: 'Cansancio/Incomodidad',
    }
    return spanish_names.get(best_id, model_label)
def predict_stream(audio_path_stream):
    """Run the cry detector on one stream chunk and describe the result.

    Softmaxes the ``model_mon`` logits, averages the class-1 probability over
    the batch and scales it to a percentage. When that percentage is below 15,
    ``predict`` is also run to obtain the crying reason.

    Returns:
        A ``(status, detail)`` pair of strings for display.
    """
    with torch.no_grad():
        mon_logits = call(audio_path_stream, model=model_mon, dataset_path="A-POR-LOS-8000/data/baby_cry_detection", filter_white_noise=False)
        class_probs = mon_logits.softmax(dim=-1)
        percent = class_probs[:, 1].mean() * 100
    # NOTE(review): the "is crying" branch fires when the class-1 ("crying")
    # probability is *below* 15%, which looks inverted unless index 1 is the
    # non-crying class. Confirm against id2label_mon before changing.
    if percent < 15:
        reason = predict(audio_path_stream)
        return "Está llorando por:", f"{reason}. Probabilidad: {percent:.1f}%"
    return "No está llorando.", f"Probabilidad: {percent:.1f}%"
def decibelios(audio_path_stream):
    """Return a dB-style level, 20*log10(RMS + 1e-6), for one stream chunk.

    NOTE(review): the RMS is taken over the cry detector's *logits*, not over
    the audio samples, so this is not an acoustic loudness measure. The UI
    thresholds appear to be tuned around that; confirm before switching to
    a waveform-based level.

    Returns:
        The level as a Python float.
    """
    with torch.no_grad():
        mon_logits = call(audio_path_stream, model=model_mon, dataset_path="A-POR-LOS-8000/data/baby_cry_detection", filter_white_noise=False)
    root_mean_square = mon_logits.square().mean().sqrt()
    # The small epsilon keeps log10 finite when every logit is zero.
    return 20 * torch.log10(root_mean_square + 1e-6).item()
def mostrar_decibelios(audio_path_stream, visual_threshold):
    """Return the monitor status text for one stream chunk.

    Args:
        audio_path_stream: File path of the current audio chunk.
        visual_threshold: Threshold from the UI slider, compared against
            ``decibelios(audio_path_stream)``.

    Returns:
        ``"Prediciendo. Decibelios: <level>"`` when the level is below the
        threshold, otherwise ``"No detectamos ruido..."``. Previously a level
        exactly equal to the threshold fell through both branches and returned
        ``None``. Equality is now treated like ``predict_stream_decib`` treats
        it, i.e. as "not predicting".
    """
    db_level = decibelios(audio_path_stream)
    # NOTE(review): "below threshold => predicting" reads inverted relative to
    # the messages, but it mirrors predict_stream_decib and the logits-based
    # level from decibelios, so it is kept as-is.
    if db_level < visual_threshold:
        return f"Prediciendo. Decibelios: {db_level:.2f}"
    return "No detectamos ruido..."
def predict_stream_decib(audio_path_stream, visual_threshold):
    """Return the cry prediction text for one chunk, gated by the level threshold.

    When ``decibelios`` is below ``visual_threshold`` the chunk goes through
    ``predict_stream`` and its two strings are joined with a space. Otherwise
    (including equality) an empty string is returned, which clears the
    textbox.
    """
    level = decibelios(audio_path_stream)
    if not level < visual_threshold:
        return ""
    status, detail = predict_stream(audio_path_stream)
    return f"{status} {detail}"
def chatbot_config(message, history: list[tuple[str, str]]):
    """Stream a reply from the Llama 3 chat model for ``gr.ChatInterface``.

    Builds an OpenAI-style message list from a fixed system prompt, the prior
    ``(user, assistant)`` turns in ``history`` (empty entries are skipped) and
    the new ``message``. It then streams the completion from the module-level
    ``client``.

    Yields:
        The accumulated response text after each streamed chunk.
    """
    system_message = "You are a Chatbot specialized in baby health and care."
    max_tokens = 512
    temperature = 0.7
    top_p = 0.95
    messages = [{"role": "system", "content": system_message}]
    for turn in history:
        user_text, assistant_text = turn[0], turn[1]
        if user_text:
            messages.append({"role": "user", "content": user_text})
        if assistant_text:
            messages.append({"role": "assistant", "content": assistant_text})
    messages.append({"role": "user", "content": message})
    response = ""
    for chunk in client.chat_completion(messages, max_tokens=max_tokens, stream=True, temperature=temperature, top_p=top_p):
        # A streamed chunk may have no choices, or a delta whose content is
        # None (typically the final one). Appending None would raise
        # TypeError and abort the whole reply, so treat it as "".
        choices = chunk.choices
        piece = choices[0].delta.content if choices else None
        response += piece or ""
        yield response
def cambiar_pestaña():
    """Return Gradio updates that hide the first output and show the second.

    Wired with ``outputs=[current_view, next_view]`` to switch between views.
    """
    hide_update = gr.update(visible=False)
    show_update = gr.update(visible=True)
    return hide_update, show_update
# Build the Gradio UI: a landing view (chat assistant + navigation buttons) and
# two initially hidden tool views (predictor, monitor), toggled via cambiar_pestaña.
with gr.Blocks(theme=my_theme) as demo:
    # Project styling helper imported from `interfaz`; what it renders is defined there.
    estilo()
    # Landing view, visible on load.
    with gr.Column(visible=True) as chatbot:
        gr.Markdown("<h2>Asistente</h2>")
        # Chat UI backed by the streaming generator chatbot_config.
        gr.ChatInterface(
            chatbot_config  # TODO: review which arguments to pass
        )
        gr.Markdown("Este chatbot no sustituye a un profesional de la salud. Ante cualquier preocupación o duda, consulta con tu pediatra.")
        # Navigation cards to the two tool views; their click handlers are wired at the end.
        with gr.Row():
            with gr.Column():
                gr.Markdown("<h2>Predictor</h2>")
                boton_predictor = gr.Button("Prueba el predictor")
                gr.Markdown("<p>Descubre por qué llora tu bebé</p>")
            with gr.Column():
                gr.Markdown("<h2>Monitor</h2>")
                boton_monitor = gr.Button("Prueba el monitor")
                gr.Markdown("<p>Monitoriza si tu hijo está llorando y por qué, sin levantarte del sofá</p>")
    # Predictor view: record/upload a clip and classify it once per button click.
    with gr.Column(visible=False) as pag_predictor:
        gr.Markdown("<h2>Predictor</h2>")
        audio_input = gr.Audio(
            min_length=1.0,  # minimum clip length (seconds, per Gradio's Audio API)
            format="wav",
            label="Baby recorder",
            type="filepath",  # handlers receive a file path, not raw samples
        )
        gr.Button("¿Por qué llora?").click(
            predict,
            inputs=audio_input,
            outputs=gr.Textbox(label="Tu bebé llora por:")
        )
        # cambiar_pestaña hides the first output and shows the second: back to landing.
        gr.Button("Volver a la pantalla inicial").click(cambiar_pestaña, outputs=[pag_predictor, chatbot])
    # Monitor view: streaming microphone input analysed chunk by chunk.
    with gr.Column(visible=False) as pag_monitor:
        gr.Markdown("<h2>Monitor</h2>")
        audio_stream = gr.Audio(
            format="wav",
            label="Baby recorder",
            type="filepath",
            streaming=True
        )
        # Threshold compared against decibelios(). NOTE(review): that value is
        # derived from model logits, so this "dB" scale is not an acoustic level.
        threshold_db = gr.Slider(
            minimum=0,
            maximum=100,
            step=1,
            value=30,
            label="Umbral de dB para activar la predicción"
        )
        # Two independent listeners on the same stream: one shows the level
        # status, the other the cry prediction. Each path runs the detector model
        # on its own, so one chunk triggers several forward passes.
        audio_stream.stream(
            mostrar_decibelios,
            inputs=[audio_stream, threshold_db],
            outputs=gr.Textbox(value="Esperando...", label="Estado")
        )
        audio_stream.stream(
            predict_stream_decib,
            inputs=[audio_stream, threshold_db],
            outputs=gr.Textbox(value="", label="Tu bebé:")
        )
        gr.Button("Volver a la pantalla inicial").click(cambiar_pestaña, outputs=[pag_monitor, chatbot])
    # Landing-view buttons: hide the landing column and reveal the chosen tool view.
    boton_predictor.click(cambiar_pestaña, outputs=[chatbot, pag_predictor])
    boton_monitor.click(cambiar_pestaña, outputs=[chatbot, pag_monitor])
# Starts the server at import time. NOTE(review): share=True requests a public
# share link; when hosted on Hugging Face Spaces this is typically unsupported — confirm.
demo.launch(share=True)
|