Update app.py
app.py CHANGED
@@ -1,4 +1,3 @@
-# Import necessary libraries
 import gradio as gr
 import torch
 import torch.nn as nn
@@ -8,12 +7,13 @@ import datasets
 from datasets import load_dataset, DatasetDict, Audio
 from huggingface_hub import PyTorchModelHubMixin
 import numpy as np
-import tempfile
-import os
 
 # Ensure you have the device setup (cuda or cpu)
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
+# Define the config for your model
+config = {"encoder": "openai/whisper-base", "num_labels": 2}
+
 # Define data class
 class SpeechInferenceDataset(Dataset):
     def __init__(self, audio_data, text_processor):
@@ -28,8 +28,7 @@ class SpeechInferenceDataset(Dataset):
             return_tensors="pt",
             sampling_rate=self.audio_data[index]["audio"]["sampling_rate"])
         input_features = inputs.input_features
-        # Modify
-        decoder_input_ids = torch.tensor([[1, 1]])
+        decoder_input_ids = torch.tensor([[1, 1]]) # Modify as per your model's requirements
         return input_features, decoder_input_ids
 
 # Define model class
@@ -55,34 +54,40 @@ class SpeechClassifier(nn.Module, PyTorchModelHubMixin):
         logits = self.classifier(pooled_output)
         return logits
 
-# Prepare data function
+# Prepare data function
 def prepare_data(audio_data, sampling_rate, model_checkpoint="openai/whisper-base"):
     feature_extractor = WhisperFeatureExtractor.from_pretrained(model_checkpoint)
-
-
-
+    inputs = feature_extractor(audio_data, sampling_rate=sampling_rate, return_tensors="pt")
+    input_features = inputs.input_features
+    decoder_input_ids = torch.tensor([[1, 1]]) # Modify as per your model's requirements
+    return input_features.to(device), decoder_input_ids.to(device)
 
-# Prediction function
-def predict(audio_data, sampling_rate, config
-
+# Prediction function
+def predict(audio_data, sampling_rate, config):
+    input_features, decoder_input_ids = prepare_data(audio_data, sampling_rate, config["encoder"])
+
     model = SpeechClassifier(config).to(device)
-    #
-
-
+    # Here we load the model from Hugging Face Hub
+    model.load_state_dict(torch.hub.load_state_dict_from_url("https://huggingface.co/jcho02/whisper_cleft/resolve/main/pytorch_model.bin", map_location=device))
+
+    model.eval()
+    with torch.no_grad():
+        logits = model(input_features, decoder_input_ids)
+        predicted_ids = int(torch.argmax(logits, dim=-1))
+    return predicted_ids
 
-# Gradio Interface
+# Gradio Interface functions
 def gradio_file_interface(uploaded_file):
-    #
-
+    # Assuming the uploaded_file is a filepath (str)
+    with open(uploaded_file, "rb") as f:
+        audio_data = np.frombuffer(f.read(), np.int16)
+    prediction = predict(audio_data, 16000, config) # Assume 16kHz sample rate
     label = "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"
     return label
 
-# Gradio Interface function for microphone input
 def gradio_mic_interface(mic_input):
-    #
-
-    sampling_rate = mic_input['sample_rate']
-    prediction = predict(audio_data, sampling_rate, config)
+    # mic_input is a dictionary with 'data' and 'sample_rate' keys
+    prediction = predict(mic_input['data'], mic_input['sample_rate'], config)
     label = "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"
     return label
 
@@ -93,18 +98,17 @@ demo = gr.Blocks()
 with demo:
     mic_transcribe = gr.Interface(
         fn=gradio_mic_interface,
-        inputs=gr.Audio(type="numpy"), #
+        inputs=gr.Audio(source="microphone", type="numpy"), # Correct type for microphone
         outputs=gr.Textbox(label="Prediction")
     )
 
    file_transcribe = gr.Interface(
         fn=gradio_file_interface,
-        inputs=gr.Audio(type="
+        inputs=gr.Audio(source="upload", type="file"), # Correct type for file upload
         outputs=gr.Textbox(label="Prediction")
    )
 
-    # Use a tabbed interface to switch between the microphone and file upload interfaces
     gr.TabbedInterface([mic_transcribe, file_transcribe], ["Transcribe Microphone", "Transcribe Audio File"])
 
-# Launch the demo
+# Launch the demo
 demo.launch(debug=True)
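
As committed, the two handlers assume input formats that gr.Audio does not actually deliver: with type="numpy", Gradio calls the function with a (sample_rate, data) tuple rather than a dict holding 'data' and 'sample_rate' keys, and np.frombuffer over a whole uploaded file also decodes the WAV header bytes as if they were samples. Below is a minimal sketch of both handlers written against Gradio's documented audio formats. It reuses predict and config from app.py above, but the switch to type="filepath" for the upload tab and the use of librosa for decoding and resampling are assumptions, not part of this commit.

import librosa
import numpy as np

def gradio_mic_interface(mic_input):
    # With type="numpy", Gradio passes a (sample_rate, data) tuple,
    # not a dict with 'data' / 'sample_rate' keys.
    sampling_rate, audio_data = mic_input
    # Whisper feature extractors expect mono float audio at 16 kHz;
    # Gradio delivers int16 PCM at the browser's native rate.
    audio_data = audio_data.astype(np.float32) / np.iinfo(np.int16).max
    if audio_data.ndim > 1:
        audio_data = audio_data.mean(axis=1)  # downmix stereo to mono
    audio_data = librosa.resample(audio_data, orig_sr=sampling_rate, target_sr=16000)
    prediction = predict(audio_data, 16000, config)
    return "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"

def gradio_file_interface(uploaded_file):
    # With type="filepath", uploaded_file is a path string; librosa decodes
    # the audio container and resamples, instead of reading raw bytes with
    # np.frombuffer (which would also interpret the WAV header as samples).
    audio_data, _ = librosa.load(uploaded_file, sr=16000, mono=True)
    prediction = predict(audio_data, 16000, config)
    return "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"

Note also that source= and type="file" belong to the Gradio 3.x API; if the Space is upgraded to Gradio 4.x, the equivalents would be gr.Audio(sources=["microphone"], type="numpy") and gr.Audio(sources=["upload"], type="filepath").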
|