import gradio as gr
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import WhisperModel, WhisperFeatureExtractor
import datasets
from datasets import load_dataset, DatasetDict, Audio
from huggingface_hub import PyTorchModelHubMixin
import numpy as np

# Ensure you have the device setup (cuda or cpu)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the config for your model
config = {"encoder": "openai/whisper-base", "num_labels": 2}
# Define data class
class SpeechInferenceDataset(Dataset):
    def __init__(self, audio_data, text_processor):
        self.audio_data = audio_data
        self.text_processor = text_processor

    def __len__(self):
        return len(self.audio_data)

    def __getitem__(self, index):
        inputs = self.text_processor(self.audio_data[index]["audio"]["array"],
                                     return_tensors="pt",
                                     sampling_rate=self.audio_data[index]["audio"]["sampling_rate"])
        input_features = inputs.input_features
        decoder_input_ids = torch.tensor([[1, 1]])  # Modify as per your model's requirements
        return input_features, decoder_input_ids
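# Illustrative sketch only (not part of the running app): the dataset class above can be
# paired with the imported DataLoader for batched offline inference over a Hugging Face
# audio dataset. The dataset id below is a placeholder.
#
#   ds = load_dataset("your-username/your-audio-dataset", split="test")
#   ds = ds.cast_column("audio", Audio(sampling_rate=16000))
#   feature_extractor = WhisperFeatureExtractor.from_pretrained(config["encoder"])
#   loader = DataLoader(SpeechInferenceDataset(ds, feature_extractor), batch_size=1)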
# Define model class
class SpeechClassifier(nn.Module, PyTorchModelHubMixin):
    def __init__(self, config):
        super(SpeechClassifier, self).__init__()
        self.encoder = WhisperModel.from_pretrained(config["encoder"])
        self.classifier = nn.Sequential(
            nn.Linear(self.encoder.config.hidden_size, 4096),
            nn.ReLU(),
            nn.Linear(4096, 2048),
            nn.ReLU(),
            nn.Linear(2048, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, config["num_labels"])
        )

    def forward(self, input_features, decoder_input_ids):
        outputs = self.encoder(input_features, decoder_input_ids=decoder_input_ids)
        pooled_output = outputs['last_hidden_state'][:, 0, :]
        logits = self.classifier(pooled_output)
        return logits
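# Illustrative sketch only: because SpeechClassifier mixes in PyTorchModelHubMixin, a
# trained instance could be saved and shared with the mixin's helpers (repo and directory
# names below are placeholders; exact from_pretrained behavior depends on the
# huggingface_hub version).
#
#   model = SpeechClassifier(config)
#   model.save_pretrained("whisper_cleft")
#   model.push_to_hub("your-username/whisper_cleft")
#   model = SpeechClassifier.from_pretrained("your-username/whisper_cleft", config=config)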
# Prepare data function
def prepare_data(audio_data, sampling_rate, model_checkpoint="openai/whisper-base"):
    feature_extractor = WhisperFeatureExtractor.from_pretrained(model_checkpoint)
    inputs = feature_extractor(audio_data, sampling_rate=sampling_rate, return_tensors="pt")
    input_features = inputs.input_features
    decoder_input_ids = torch.tensor([[1, 1]])  # Modify as per your model's requirements
    return input_features.to(device), decoder_input_ids.to(device)
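# Note: WhisperFeatureExtractor pads/truncates the waveform to a 30-second window and
# returns log-Mel features of shape (batch, num_mel_bins, 3000); for whisper-base that
# is (1, 80, 3000) per clip.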
# Prediction function
def predict(audio_data, sampling_rate, config):
    input_features, decoder_input_ids = prepare_data(audio_data, sampling_rate, config["encoder"])

    model = SpeechClassifier(config).to(device)
    # Here we load the model from Hugging Face Hub
    model.load_state_dict(torch.hub.load_state_dict_from_url("https://huggingface.co/jcho02/whisper_cleft/resolve/main/pytorch_model.bin", map_location=device))
    model.eval()

    with torch.no_grad():
        logits = model(input_features, decoder_input_ids)
        predicted_ids = int(torch.argmax(logits, dim=-1))
    return predicted_ids
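# Illustrative usage only: calling the prediction helper on one second of silence at 16 kHz.
# Note that the weights are re-downloaded on every call, which keeps the demo simple but slow.
#
#   example_pred = predict(np.zeros(16000, dtype=np.float32), 16000, config)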
# Gradio Interface functions
def gradio_file_interface(uploaded_file):
    # Assuming the uploaded_file is a filepath (str) to raw 16-bit PCM audio
    with open(uploaded_file, "rb") as f:
        audio_data = np.frombuffer(f.read(), np.int16)
    # Convert int16 PCM to float32 in [-1, 1] for the feature extractor
    audio_data = audio_data.astype(np.float32) / 32768.0
    prediction = predict(audio_data, 16000, config)  # Assume 16kHz sample rate
    label = "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"
    return label

def gradio_mic_interface(mic_input):
    # gr.Audio(type="numpy") passes a (sample_rate, data) tuple
    sample_rate, audio_data = mic_input
    if audio_data.dtype == np.int16:
        audio_data = audio_data.astype(np.float32) / 32768.0
    prediction = predict(audio_data, sample_rate, config)
    label = "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"
    return label
# Initialize Blocks
demo = gr.Blocks()

# Define the interfaces inside the Blocks context
with demo:
    #mic_transcribe = gr.Interface(
    #    fn=gradio_mic_interface,
    #    inputs=gr.Audio(type="numpy"),  # Use numpy for real-time audio like microphone
    #    outputs=gr.Textbox(label="Prediction")
    #)

    file_transcribe = gr.Interface(
        fn=gradio_file_interface,
        inputs=gr.Audio(type="filepath"),  # Use filepath for uploaded audio files
        outputs=gr.Textbox(label="Prediction")
    )

    # Combine interfaces into a tabbed interface
    #gr.TabbedInterface([mic_transcribe, file_transcribe], ["Transcribe Microphone", "Transcribe Audio File"])

# Launch the demo with debugging enabled
demo.launch(debug=True)