import gradio as gr
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import WhisperModel, WhisperFeatureExtractor
import datasets
from datasets import load_dataset, DatasetDict, Audio
from huggingface_hub import PyTorchModelHubMixin
import numpy as np

# Ensure you have the device setup (cuda or cpu)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the config for your model
config = {"encoder": "openai/whisper-base", "num_labels": 2}

# Define data class
class SpeechInferenceDataset(Dataset):
    def __init__(self, audio_data, text_processor):
        self.audio_data = audio_data
        self.text_processor = text_processor

    def __len__(self):
        return len(self.audio_data)

    def __getitem__(self, index):
        inputs = self.text_processor(self.audio_data[index]["audio"]["array"],
                                     return_tensors="pt",
                                     sampling_rate=self.audio_data[index]["audio"]["sampling_rate"])
        input_features = inputs.input_features
        decoder_input_ids = torch.tensor([[1, 1]])  # Modify as per your model's requirements
        return input_features, decoder_input_ids

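# Example (not exercised by this app): the dataset above could be batched with the
# DataLoader imported earlier for offline inference over a datasets split that has an
# "audio" column (decoded via the Audio feature imported above). The dataset name is a
# placeholder, not necessarily the data this model was trained on.
# feature_extractor = WhisperFeatureExtractor.from_pretrained(config["encoder"])
# eval_dataset = SpeechInferenceDataset(load_dataset("...", split="test"), feature_extractor)
# eval_loader = DataLoader(eval_dataset, batch_size=8)
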
# Define model class
class SpeechClassifier(nn.Module, PyTorchModelHubMixin):
    def __init__(self, config):
        super(SpeechClassifier, self).__init__()
        self.encoder = WhisperModel.from_pretrained(config["encoder"])
        self.classifier = nn.Sequential(
            nn.Linear(self.encoder.config.hidden_size, 4096),
            nn.ReLU(),
            nn.Linear(4096, 2048),
            nn.ReLU(),
            nn.Linear(2048, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, config["num_labels"])
        )

    def forward(self, input_features, decoder_input_ids):
        outputs = self.encoder(input_features, decoder_input_ids=decoder_input_ids)
        pooled_output = outputs['last_hidden_state'][:, 0, :]
        logits = self.classifier(pooled_output)
        return logits

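# Note: since SpeechClassifier also inherits from PyTorchModelHubMixin, it gains the
# mixin's save_pretrained / push_to_hub / from_pretrained helpers. Assuming the Hub repo
# stores the weights in the format the mixin expects, loading could also look like:
# model = SpeechClassifier.from_pretrained("jcho02/whisper_cleft", config=config)
# (predict() below instead downloads the raw state dict directly.)
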
# Prepare data function
def prepare_data(audio_data, sampling_rate, model_checkpoint="openai/whisper-base"):
    feature_extractor = WhisperFeatureExtractor.from_pretrained(model_checkpoint)
    inputs = feature_extractor(audio_data, sampling_rate=sampling_rate, return_tensors="pt")
    input_features = inputs.input_features
    decoder_input_ids = torch.tensor([[1, 1]])  # Modify as per your model's requirements
    return input_features.to(device), decoder_input_ids.to(device)

# Prediction function
def predict(audio_data, sampling_rate, config):
    input_features, decoder_input_ids = prepare_data(audio_data, sampling_rate, config["encoder"])

    model = SpeechClassifier(config).to(device)
    # Here we load the model from Hugging Face Hub
    model.load_state_dict(torch.hub.load_state_dict_from_url(
        "https://huggingface.co/jcho02/whisper_cleft/resolve/main/pytorch_model.bin",
        map_location=device))
    model.eval()

    with torch.no_grad():
        logits = model(input_features, decoder_input_ids)
        predicted_ids = int(torch.argmax(logits, dim=-1))
    return predicted_ids

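# Note: as written, predict() rebuilds the classifier and reloads the checkpoint on every
# call (torch.hub.load_state_dict_from_url caches the downloaded file after the first
# request, but the state dict is still re-read and re-applied each time). A sketch of a
# one-time, module-level load would look like the commented lines below; predict() would
# then reuse this shared instance instead of constructing its own.
# MODEL = SpeechClassifier(config).to(device)
# MODEL.load_state_dict(torch.hub.load_state_dict_from_url(
#     "https://huggingface.co/jcho02/whisper_cleft/resolve/main/pytorch_model.bin",
#     map_location=device))
# MODEL.eval()
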
# Gradio Interface functions
def gradio_file_interface(uploaded_file):
    # Assuming the uploaded_file is a filepath (str). The file is read as raw
    # little-endian 16-bit PCM (any container header bytes are read as samples too)
    # and rescaled to [-1, 1] floats, which is what Whisper's feature extractor expects.
    with open(uploaded_file, "rb") as f:
        audio_data = np.frombuffer(f.read(), np.int16).astype(np.float32) / 32768.0
    prediction = predict(audio_data, 16000, config)  # Assume 16kHz sample rate
    label = "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"
    return label

def gradio_mic_interface(mic_input):
    # With gr.Audio(type="numpy"), Gradio passes a (sample_rate, data) tuple; the
    # recording arrives as int16, so rescale to [-1, 1] floats as above.
    sample_rate, data = mic_input
    audio_data = data.astype(np.float32) / 32768.0
    prediction = predict(audio_data, sample_rate, config)
    label = "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"
    return label

# Initialize Blocks
demo = gr.Blocks()

# Define the interfaces inside the Blocks context
with demo:
    #mic_transcribe = gr.Interface(
    #    fn=gradio_mic_interface,
    #    inputs=gr.Audio(type="numpy"),  # Use numpy for real-time audio like microphone
    #    outputs=gr.Textbox(label="Prediction")
    #)

    file_transcribe = gr.Interface(
        fn=gradio_file_interface,
        inputs=gr.Audio(type="filepath"),  # Use filepath for uploaded audio files
        outputs=gr.Textbox(label="Prediction")
    )

    # Combine interfaces into a tabbed interface
    #gr.TabbedInterface([mic_transcribe, file_transcribe], ["Transcribe Microphone", "Transcribe Audio File"])

# Launch the demo with debugging enabled
demo.launch(debug=True)