Spaces:
Sleeping
Sleeping
File size: 1,871 Bytes
53f6532 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
import torch
import gradio as gr
from model import predict_params, AudioDataset
from interfaz import estilo, my_theme
# Run inference on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the classifier once at import time. The second value is indexed by the
# predicted class id further down to obtain a label.
model_class, id2label_class = predict_params(
    model_path="distilhubert-finetuned-mixed-data",
    dataset_path="data/mixed_data",
    filter_white_noise=True,
)
def call(audiopath, model, dataset_path, filter_white_noise):
    """Run ``model`` on a single audio file and return its logits.

    Args:
        audiopath: Path of the audio file to classify.
        model: Model to run; it is moved to the module-level ``device`` and
            switched to eval mode on every call.
        dataset_path: Forwarded to ``AudioDataset`` as its first argument.
        filter_white_noise: Forwarded to ``AudioDataset`` as its third
            argument.

    Returns:
        The ``logits`` attribute of the model output for a batch of one.
    """
    model.to(device)
    model.eval()

    # AudioDataset is used here only for its preprocessing method; the second
    # constructor argument is an empty dict, as in the original call.
    preprocessor = AudioDataset(dataset_path, {}, filter_white_noise)
    waveform = preprocessor.preprocess_audio(audiopath)

    # unsqueeze(0) adds a leading dimension so the model receives a batch of one.
    batch = waveform.to(device).unsqueeze(0)

    with torch.no_grad():
        return model(input_values=batch).logits
def predict(audio_path_pred):
    """Classify a recorded cry and return the label to show to the user.

    Args:
        audio_path_pred: Path of the audio file provided by the Gradio audio
            input, or ``None``/empty when nothing was recorded or uploaded.

    Returns:
        The Spanish display label for the predicted class id when one is
        defined in the local mapping, otherwise the entry of
        ``id2label_class`` for that id.

    Raises:
        gr.Error: If no audio was provided, so the UI shows a readable
            message instead of a failure deep inside preprocessing.
    """
    # The button can be clicked before anything has been recorded; in that
    # case the audio component hands over no path at all.
    if not audio_path_pred:
        raise gr.Error("Graba o sube un audio antes de pedir la predicción.")

    # `call` already runs the forward pass under torch.no_grad(), so no extra
    # gradient context is needed here.
    logits = call(
        audio_path_pred,
        model=model_class,
        dataset_path="data/mixed_data",
        filter_white_noise=True,
    )
    predicted_class_id = torch.argmax(logits, dim=-1).item()

    # User-facing labels take precedence over the model's own label table.
    label_mapping = {0: 'Hambre', 1: 'Problemas para respirar', 2: 'Dolor', 3: 'Cansancio/Incomodidad'}
    if predicted_class_id in label_mapping:
        return label_mapping[predicted_class_id]

    # Only consult the model's table when no display label is defined, so a
    # key mismatch there cannot break ids that are already covered above.
    return id2label_class[predicted_class_id]
def cambiar_pestaña():
    """Return a pair of Gradio updates: the first hides, the second shows.

    Intended as an event-handler return value for two output components.
    NOTE(review): no event in the visible source uses this function yet.
    """
    hide_update = gr.update(visible=False)
    show_update = gr.update(visible=True)
    return hide_update, show_update
# Build the UI: shared styling first, then the predictor page.
with gr.Blocks(theme=my_theme) as demo:
    estilo()

    # The predictor page is created hidden.
    # NOTE(review): nothing in the visible source makes this column visible;
    # confirm the toggle (e.g. via cambiar_pestaña) is wired up somewhere.
    with gr.Column(visible=False) as pag_predictor:
        gr.Markdown("<h2>Predictor</h2>")
        audio_input = gr.Audio(
            min_length=1.0,
            format="wav",
            label="Baby recorder",
            type="filepath",
        )
        # Create the button before the textbox so the on-page order stays
        # button first, result second.
        predict_button = gr.Button("¿Por qué llora?")
        result_box = gr.Textbox(label="Tu bebé llora por:")
        predict_button.click(predict, inputs=audio_input, outputs=result_box)

# Start the app and request a public share link.
demo.launch(share=True)
|