import gradio as gr
import torch
import torch.nn as nn
from transformers import WhisperModel, WhisperFeatureExtractor
from huggingface_hub import PyTorchModelHubMixin
import numpy as np
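
# Run inference on GPU when available, otherwise fall back to CPU.
# (The original script referenced `device` without defining it.)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")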
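# The checkpoint loaded in predict() below was trained with a SpeechClassifier
# class that is not included in this script. The definition here is a minimal
# sketch, assuming a Whisper backbone whose decoder hidden states are
# mean-pooled into a small linear head; the actual architecture behind
# jcho02/whisper_cleft may differ, in which case load_state_dict() will
# report missing or unexpected keys.
class SpeechClassifier(nn.Module, PyTorchModelHubMixin):
    def __init__(self, config):
        super().__init__()
        self.encoder = WhisperModel.from_pretrained(config["encoder"])
        self.classifier = nn.Sequential(
            nn.Linear(self.encoder.config.d_model, 256),
            nn.ReLU(),
            nn.Linear(256, config["num_labels"]),
        )

    def forward(self, input_features, decoder_input_ids):
        outputs = self.encoder(input_features, decoder_input_ids=decoder_input_ids)
        pooled_output = outputs.last_hidden_state.mean(dim=1)
        return self.classifier(pooled_output)


# Assumed run-time configuration: "encoder" names the Whisper checkpoint the
# classifier was built on (predict() reads config["encoder"]) and "num_labels"
# the number of output classes. Both values are assumptions, not taken from
# the original script.
config = {"encoder": "openai/whisper-base", "num_labels": 2}
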
def prepare_data(audio_data, sampling_rate, model_checkpoint="openai/whisper-base"):
    # Convert the raw waveform into the log-Mel features the Whisper encoder
    # expects. Note: WhisperFeatureExtractor only accepts 16 kHz audio and
    # raises a ValueError for other sampling rates.
    feature_extractor = WhisperFeatureExtractor.from_pretrained(model_checkpoint)
    inputs = feature_extractor(audio_data, sampling_rate=sampling_rate, return_tensors="pt")
    input_features = inputs.input_features
    # Dummy decoder input ids; the classification head works off pooled hidden
    # states, so the decoder tokens only need to exist, not to be meaningful.
    decoder_input_ids = torch.tensor([[1, 1]])
    return input_features.to(device), decoder_input_ids.to(device)


def predict(audio_data, sampling_rate, config):
    input_features, decoder_input_ids = prepare_data(audio_data, sampling_rate, config["encoder"])

    # Rebuild the classifier and load the fine-tuned weights from the Hub.
    # Reloading on every call keeps the demo simple; cache the model outside
    # this function to speed up repeated predictions.
    model = SpeechClassifier(config).to(device)
    model.load_state_dict(torch.hub.load_state_dict_from_url(
        "https://huggingface.co/jcho02/whisper_cleft/resolve/main/pytorch_model.bin",
        map_location=device,
    ))
    model.eval()

    with torch.no_grad():
        logits = model(input_features, decoder_input_ids)
    # argmax over the class dimension yields the predicted label index.
    predicted_id = int(torch.argmax(logits, dim=-1))
    return predicted_id


def gradio_interface(audio_input):
    if isinstance(audio_input, tuple):
        # gr.Audio(type="numpy") passes (sample_rate, data), in that order.
        sample_rate, audio_data = audio_input
    else:
        # Fallback: treat the input as a path to headerless 16-bit PCM at 16 kHz.
        with open(audio_input, "rb") as f:
            audio_data = np.frombuffer(f.read(), np.int16)
        sample_rate = 16000

    # Collapse stereo recordings to mono before feature extraction.
    if audio_data.ndim > 1:
        audio_data = audio_data.mean(axis=1)

    # The feature extractor expects float32 samples in [-1, 1], not raw int16.
    # Audio that is not 16 kHz will still be rejected by the extractor.
    if audio_data.dtype == np.int16:
        audio_data = audio_data.astype(np.float32) / 32768.0

    prediction = predict(audio_data, sample_rate, config)
    label = "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"
    return label


demo = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Audio(type="numpy", label="Upload or Record Audio"),
    outputs=gr.Textbox(label="Prediction"),
)

demo.launch(debug=True)