jcho02 committed on
Commit 55127ad · verified · 1 Parent(s): a15c916

Create app.py

Files changed (1)
  1. app.py +87 -0
app.py ADDED
@@ -0,0 +1,87 @@
+
+ # Import necessary libraries
+ import gradio as gr
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import Dataset, DataLoader
+ from transformers import WhisperConfig, WhisperFeatureExtractor, WhisperModel
+ import datasets
+ from datasets import Audio
+ from huggingface_hub import PyTorchModelHubMixin
+
+ # Pick the inference device once so the model and all tensors end up in the same place
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Define data class
+ class SpeechInferenceDataset(Dataset):
+     def __init__(self, audio_data, text_processor, decoder_start_token_id):
+         self.audio_data = audio_data
+         self.text_processor = text_processor
+         # Passed in explicitly so __getitem__ does not depend on a global encoder object
+         self.decoder_start_token_id = decoder_start_token_id
+
+     def __len__(self):
+         return len(self.audio_data)
+
+     def __getitem__(self, index):
+         inputs = self.text_processor(self.audio_data[index]["audio"]["array"],
+                                      return_tensors="pt",
+                                      sampling_rate=self.audio_data[index]["audio"]["sampling_rate"])
+         input_features = inputs.input_features
+         decoder_input_ids = torch.tensor([[1, 1]]) * self.decoder_start_token_id
+         return input_features, decoder_input_ids
+
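+ # Note: decoder_input_ids of shape (1, 2) filled with the start token are
+ # sufficient here because forward() below only reads the decoder state at position 0.
+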
+ # Define model class
+ class SpeechClassifier(nn.Module, PyTorchModelHubMixin):
+     def __init__(self, config):
+         super(SpeechClassifier, self).__init__()
+         self.encoder = WhisperModel.from_pretrained(config["encoder"])
+         self.classifier = nn.Sequential(
+             nn.Linear(self.encoder.config.hidden_size, 4096),
+             nn.ReLU(),
+             nn.Linear(4096, 2048),
+             nn.ReLU(),
+             nn.Linear(2048, 1024),
+             nn.ReLU(),
+             nn.Linear(1024, 512),
+             nn.ReLU(),
+             nn.Linear(512, config["num_labels"])
+         )
+
+     def forward(self, input_features, decoder_input_ids):
+         outputs = self.encoder(input_features, decoder_input_ids=decoder_input_ids)
+         # Use the decoder's first hidden state as a pooled utterance representation
+         pooled_output = outputs['last_hidden_state'][:, 0, :]
+         logits = self.classifier(pooled_output)
+         return logits
+
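+ # Shape note (from Whisper's defaults, not stated in this commit): the feature
+ # extractor pads/trims audio to 30 s and emits log-mel features of shape
+ # (batch, 80, 3000), which is what WhisperModel expects as input_features.
+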
+ # Prepare data function
+ def prepare_data(audio_file_path, model_checkpoint="openai/whisper-base"):
+     feature_extractor = WhisperFeatureExtractor.from_pretrained(model_checkpoint)
+     # Look up the decoder start token id from the checkpoint's config
+     decoder_start_token_id = WhisperConfig.from_pretrained(model_checkpoint).decoder_start_token_id
+     inference_data = datasets.Dataset.from_dict(
+         {"path": [audio_file_path], "audio": [audio_file_path]}
+     ).cast_column("audio", Audio(sampling_rate=16_000))
+     inference_dataset = SpeechInferenceDataset(inference_data, feature_extractor, decoder_start_token_id)
+     inference_loader = DataLoader(inference_dataset, batch_size=1, shuffle=False)
+     input_features, decoder_input_ids = next(iter(inference_loader))
+     # Drop the extra batch dimension added by the DataLoader and move to the inference device
+     input_features = input_features.squeeze(1).to(device)
+     decoder_input_ids = decoder_input_ids.squeeze(1).to(device)
+     return input_features, decoder_input_ids
+
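+ # Note: casting with Audio(sampling_rate=16_000) makes `datasets` decode and
+ # resample any uploaded file to 16 kHz, the rate Whisper's feature extractor expects.
+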
+ # Prediction function
+ def predict(audio_file_path, config={"encoder": "openai/whisper-base", "num_labels": 2}):
+     input_features, decoder_input_ids = prepare_data(audio_file_path)
+     model = SpeechClassifier(config).to(device)
+     model.eval()
+     with torch.no_grad():
+         logits = model(input_features, decoder_input_ids)
+     predicted_ids = int(torch.argmax(logits, dim=-1))
+     return predicted_ids
+
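+ # NOTE: SpeechClassifier(config) above starts from the pretrained Whisper
+ # encoder but a randomly initialized classification head. Because the class
+ # mixes in PyTorchModelHubMixin, a trained checkpoint would presumably be
+ # loaded with SpeechClassifier.from_pretrained("<hub-repo-id>") instead;
+ # the repo id is a placeholder, not taken from this commit.
+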
+ # Gradio Interface function
+ def gradio_interface(uploaded_file_path):
+     # With gr.File(type="filepath"), Gradio passes the path of a temporary copy
+     # of the upload, so it can go straight to predict() without rewriting the file
+     prediction = predict(uploaded_file_path)
+     label = "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"
+     return label
+
+ # Create and launch Gradio Interface with File upload input
+ iface = gr.Interface(fn=gradio_interface,
+                      inputs=gr.File(label="Upload Audio File", type="filepath"),
+                      outputs="text")
+ iface.launch()
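+
+ # Usage sketch: running `python app.py` locally starts the Gradio server;
+ # on a Hugging Face Space the platform runs this file automatically.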