# Page header captured together with the file (not code): jcho02's picture
# Update app.py
# a9b6797 verified
import gradio as gr
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import WhisperModel, WhisperFeatureExtractor
import datasets
from datasets import load_dataset, DatasetDict, Audio
from huggingface_hub import PyTorchModelHubMixin
import numpy as np
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Model settings: the Whisper checkpoint to load and the number of output
# classes produced by the classification head.
config = dict(encoder="openai/whisper-base", num_labels=2)
# Define data class
class SpeechInferenceDataset(Dataset):
    """Map-style dataset that turns raw audio records into model inputs.

    Every record is read as ``record["audio"]["array"]`` and
    ``record["audio"]["sampling_rate"]``.
    """

    def __init__(self, audio_data, text_processor):
        """Store the records and the processor applied to them.

        Args:
            audio_data: Indexable, sized collection of records shaped as
                described in the class docstring.
            text_processor: Callable invoked as
                ``text_processor(array, return_tensors="pt", sampling_rate=...)``
                whose result exposes ``input_features``. Despite its name it
                is applied to the audio array.
        """
        self.audio_data = audio_data
        self.text_processor = text_processor

    def __len__(self):
        """Return the number of records."""
        return len(self.audio_data)

    def __getitem__(self, index):
        """Return ``(input_features, decoder_input_ids)`` for one record."""
        # Look the record up once. Containers that decode audio lazily on
        # access (presumably Hugging Face datasets -- confirm) would otherwise
        # do that work twice per item.
        audio = self.audio_data[index]["audio"]
        inputs = self.text_processor(
            audio["array"],
            return_tensors="pt",
            sampling_rate=audio["sampling_rate"],
        )
        # Fixed two-token decoder prompt; modify as per your model's requirements.
        decoder_input_ids = torch.tensor([[1, 1]])
        return inputs.input_features, decoder_input_ids
# Define model class
class SpeechClassifier(nn.Module, PyTorchModelHubMixin):
    """Whisper backbone followed by a multi-layer perceptron classifier."""

    def __init__(self, config):
        """Build the backbone and the classification head.

        Args:
            config: Mapping with ``"encoder"`` (checkpoint name passed to
                ``WhisperModel.from_pretrained``) and ``"num_labels"``
                (output width of the final linear layer).
        """
        super().__init__()
        self.encoder = WhisperModel.from_pretrained(config["encoder"])

        # Widths of the head: backbone hidden size, then four hidden layers.
        widths = [self.encoder.config.hidden_size, 4096, 2048, 1024, 512]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(widths[-1], config["num_labels"]))
        # The layer order is unchanged, so the positional indices inside the
        # Sequential (and hence the parameter names) stay the same.
        self.classifier = nn.Sequential(*layers)

    def forward(self, input_features, decoder_input_ids):
        """Return classification logits for a batch.

        Args:
            input_features: Audio features forwarded to the Whisper model.
            decoder_input_ids: Token ids forwarded as ``decoder_input_ids``.

        Returns:
            Output of the classification head applied to the first sequence
            position of the model's ``last_hidden_state``.
        """
        backbone_out = self.encoder(input_features, decoder_input_ids=decoder_input_ids)
        first_position = backbone_out['last_hidden_state'][:, 0, :]
        return self.classifier(first_position)
# Prepare data function
# Feature extractors already loaded, keyed by checkpoint name, so repeated
# calls do not re-instantiate them.
_FEATURE_EXTRACTORS = {}


def prepare_data(audio_data, sampling_rate, model_checkpoint="openai/whisper-base"):
    """Convert raw audio into the tensors the classifier consumes.

    Args:
        audio_data: Audio samples handed unchanged to the Whisper feature
            extractor.
        sampling_rate: Sampling rate of ``audio_data`` in Hz, forwarded to the
            feature extractor. NOTE(review): the extractor is expected to
            reject rates other than the one the checkpoint was trained with
            -- confirm against the transformers documentation.
        model_checkpoint: Checkpoint name whose feature extractor is used.

    Returns:
        Tuple ``(input_features, decoder_input_ids)``, both moved to the
        module-level ``device``.
    """
    feature_extractor = _FEATURE_EXTRACTORS.get(model_checkpoint)
    if feature_extractor is None:
        # First use of this checkpoint: load once and keep for later calls.
        feature_extractor = WhisperFeatureExtractor.from_pretrained(model_checkpoint)
        _FEATURE_EXTRACTORS[model_checkpoint] = feature_extractor

    inputs = feature_extractor(audio_data, sampling_rate=sampling_rate, return_tensors="pt")
    input_features = inputs.input_features
    # Fixed two-token decoder prompt; modify as per your model's requirements.
    decoder_input_ids = torch.tensor([[1, 1]])
    return input_features.to(device), decoder_input_ids.to(device)
# Prediction function
# Location of the fine-tuned classifier weights on the Hugging Face Hub.
_WEIGHTS_URL = "https://huggingface.co/jcho02/whisper_cleft/resolve/main/pytorch_model.bin"
# Ready-to-use models keyed by (encoder checkpoint, number of labels).
_MODEL_CACHE = {}


def _get_model(config):
    """Return the evaluation-mode classifier for *config*, building it once."""
    key = (config["encoder"], config["num_labels"])
    model = _MODEL_CACHE.get(key)
    if model is None:
        model = SpeechClassifier(config).to(device)
        # Here we load the model weights from the Hugging Face Hub.
        state_dict = torch.hub.load_state_dict_from_url(_WEIGHTS_URL, map_location=device)
        model.load_state_dict(state_dict)
        model.eval()
        _MODEL_CACHE[key] = model
    return model


def predict(audio_data, sampling_rate, config):
    """Classify one audio clip and return the predicted class index.

    The model is built and its weights are loaded on the first call only;
    later calls with the same ``config`` values reuse the cached instance
    instead of rebuilding the network for every prediction.

    Args:
        audio_data: Audio samples, forwarded to ``prepare_data``.
        sampling_rate: Sampling rate of ``audio_data`` in Hz.
        config: Mapping with ``"encoder"`` and ``"num_labels"`` entries.

    Returns:
        The index of the highest logit as an ``int``.
    """
    # Prepare the inputs first so that invalid audio fails before any model
    # is constructed, matching the original order of operations.
    input_features, decoder_input_ids = prepare_data(audio_data, sampling_rate, config["encoder"])
    model = _get_model(config)
    with torch.no_grad():
        logits = model(input_features, decoder_input_ids)
    return int(torch.argmax(logits, dim=-1))
# Gradio Interface functions
def gradio_file_interface(uploaded_file):
# Assuming the uploaded_file is a filepath (str)
with open(uploaded_file, "rb") as f:
audio_data = np.frombuffer(f.read(), np.int16)
prediction = predict(audio_data, 16000, config) # Assume 16kHz sample rate
label = "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"
return label
def gradio_mic_interface(mic_input):
# mic_input is a dictionary with 'data' and 'sample_rate' keys
prediction = predict(mic_input['data'], mic_input['sample_rate'], config)
label = "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"
return label
# Initialize Blocks
# Build the UI; components created inside the ``with`` block belong to ``demo``.
with gr.Blocks() as demo:
    # A microphone tab is currently disabled. It was sketched as
    #   gr.Interface(fn=gradio_mic_interface,
    #                inputs=gr.Audio(type="numpy"),
    #                outputs=gr.Textbox(label="Prediction"))
    # and was meant to be combined with the file tab through
    #   gr.TabbedInterface([mic_transcribe, file_transcribe],
    #                      ["Transcribe Microphone", "Transcribe Audio File"])
    audio_input = gr.Audio(type="filepath")  # uploaded files arrive as a path
    prediction_box = gr.Textbox(label="Prediction")
    file_transcribe = gr.Interface(
        fn=gradio_file_interface,
        inputs=audio_input,
        outputs=prediction_box,
    )

# Start the app with debugging enabled.
demo.launch(debug=True)