File size: 1,871 Bytes
53f6532
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch
import gradio as gr
from model import predict_params, AudioDataset
from interfaz import estilo, my_theme

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load the classifier and its class-id -> label lookup once, at import time,
# so every prediction request reuses the same model object.
model_class, id2label_class = predict_params(
    model_path="distilhubert-finetuned-mixed-data",
    dataset_path="data/mixed_data",
    filter_white_noise=True,
)

def call(audiopath, model, dataset_path, filter_white_noise):
    """Run ``model`` on a single audio file and return its raw logits.

    Args:
        audiopath: Path of the audio file to classify.
        model: Classifier whose output exposes ``.logits``. It is moved to
            ``device`` and switched to eval mode on every call.
        dataset_path: Forwarded to ``AudioDataset``, whose ``preprocess_audio``
            turns the file into the model's ``input_values``.
        filter_white_noise: Forwarded to ``AudioDataset``.

    Returns:
        The ``logits`` attribute of the model output, computed without
        gradient tracking.
    """
    model.to(device)
    model.eval()

    preprocessor = AudioDataset(dataset_path, {}, filter_white_noise)
    waveform = preprocessor.preprocess_audio(audiopath)
    # Add a leading dimension so the model receives a batch of one.
    batch = waveform.to(device).unsqueeze(0)

    with torch.no_grad():
        return model(input_values=batch).logits

# Spanish display names for known class ids. Ids not listed here fall back to
# the model's own id2label entry.
_LABEL_MAPPING = {0: 'Hambre', 1: 'Problemas para respirar', 2: 'Dolor', 3: 'Cansancio/Incomodidad'}


def predict(audio_path_pred):
    """Classify a recorded cry and return a human-readable reason.

    Args:
        audio_path_pred: Filepath supplied by the ``gr.Audio`` input
            (``type="filepath"``). It may be ``None``/empty when the user
            clicks the button without recording or uploading anything.

    Returns:
        The display label for the predicted class: the Spanish name from
        ``_LABEL_MAPPING`` when available, otherwise ``id2label_class``'s entry.

    Raises:
        gr.Error: If no audio was provided. Gradio shows this message to the
            user instead of an opaque failure from the preprocessing step.
    """
    if not audio_path_pred:
        raise gr.Error("Por favor, graba o sube un audio antes de predecir.")

    # `call` already runs the forward pass under torch.no_grad().
    logits = call(audio_path_pred, model=model_class, dataset_path="data/mixed_data", filter_white_noise=True)
    class_id = torch.argmax(logits, dim=-1).item()

    # Prefer the Spanish display name. Only consult the model's id2label when
    # the id is not covered, so a missing id2label key can't raise KeyError
    # for an id the override already handles.
    if class_id in _LABEL_MAPPING:
        return _LABEL_MAPPING[class_id]
    return id2label_class[class_id]

def cambiar_pestaña():
    """Return a pair of Gradio updates: hide the first output, show the second.

    Presumably meant to swap two page containers (e.g. a landing page and the
    predictor column). It is not attached to any event in this file.
    """
    ocultar = gr.update(visible=False)
    mostrar = gr.update(visible=True)
    return ocultar, mostrar

# Build the Gradio UI: an audio recorder, a predict button and a textbox that
# shows the predicted reason, then launch the app with a public share link.
with gr.Blocks(theme=my_theme) as demo:
    estilo()
    # Visible from the start. Nothing in this app toggles this column's
    # visibility, so with visible=False the predictor UI would never appear.
    with gr.Column(visible=True) as pag_predictor:
        gr.Markdown("<h2>Predictor</h2>")
        audio_input = gr.Audio(
            min_length=1.0,
            format="wav",
            label="Baby recorder",
            type="filepath",
        )
        # Create the button before the textbox so the result renders below it.
        boton_predecir = gr.Button("¿Por qué llora?")
        resultado = gr.Textbox(label="Tu bebé llora por:")
        boton_predecir.click(
            predict,
            inputs=audio_input,
            outputs=resultado,
        )
demo.launch(share=True)