import gradio as gr
import torch
from transformers import (
    AutoModelForCTC,
    Wav2Vec2Processor,
    AutoProcessor,
    WhisperProcessor,
    WhisperForConditionalGeneration,
)
import librosa

# Initialize device - falls back to CPU if no GPU is available
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


class ModelManager:
    """Loads ASR models lazily and caches them so each is loaded only once."""

    def __init__(self):
        self.asr_models = {}

    def load_wav2vec2_base(self):
        model = AutoModelForCTC.from_pretrained("kabir259/w2v2-base_kabir").to(DEVICE)
        processor = Wav2Vec2Processor.from_pretrained("kabir259/w2v2-base_kabir")
        return model, processor

    def load_wav2vec2_bert(self):
        model = AutoModelForCTC.from_pretrained("Kabir259/w2v2-BERT_kabir").to(DEVICE)
        processor = AutoProcessor.from_pretrained("Kabir259/w2v2-BERT_kabir")
        return model, processor

    def load_whisper_small(self):
        model = WhisperForConditionalGeneration.from_pretrained(
            "Kabir259/whisper-small_kabir"
        ).to(DEVICE)
        processor = WhisperProcessor.from_pretrained("Kabir259/whisper-small_kabir")
        model.generation_config.task = "transcribe"
        return model, processor

    def get_asr_model(self, model_name):
        # Load on first request, then serve subsequent requests from the cache
        if model_name not in self.asr_models:
            if model_name == "wav2vec2-base":
                self.asr_models[model_name] = self.load_wav2vec2_base()
            elif model_name == "wav2vec2-BERT":
                self.asr_models[model_name] = self.load_wav2vec2_bert()
            elif model_name == "whisper-small":
                self.asr_models[model_name] = self.load_whisper_small()
        return self.asr_models[model_name]


def process_audio(audio_path, asr_model_name, model_manager):
    model, processor = model_manager.get_asr_model(asr_model_name)

    # Load and resample audio to the 16 kHz rate all three models expect
    audio, sr = librosa.load(audio_path, sr=16000)

    if asr_model_name in ["wav2vec2-base", "wav2vec2-BERT"]:
        # CTC models: unpack the processor output rather than reading
        # `.input_values`, since wav2vec2-base produces `input_values`
        # while w2v2-BERT produces `input_features`; forwarding the whole
        # BatchFeature handles both.
        inputs = processor(audio, sampling_rate=16000, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            logits = model(**inputs).logits
        # Greedy CTC decoding: take the highest-scoring token at each frame
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids)[0]
    else:  # whisper-small: encoder-decoder model, decode via generate()
        input_features = processor(
            audio, sampling_rate=16000, return_tensors="pt"
        ).input_features.to(DEVICE)
        with torch.no_grad():
            predicted_ids = model.generate(input_features)
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

    return transcription


def process_pipeline(audio, asr_model_choice, model_manager):
    if audio is None:
        return "Please record some audio first."
    transcription = process_audio(audio, asr_model_choice, model_manager)
    return transcription


# Initialize the model manager
model_manager = ModelManager()

# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Medical Speech Recognition System 🥼")

    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(
                label="Record Audio",
                type="filepath",
            )
            asr_model_choice = gr.Dropdown(
                choices=["wav2vec2-base", "wav2vec2-BERT", "whisper-small"],
                value="wav2vec2-base",
                label="Select ASR Model",
            )
            submit_btn = gr.Button("Transcribe")

        with gr.Column():
            transcription_output = gr.Textbox(
                label="Transcribed Text",
                placeholder="Transcription will appear here...",
            )

    submit_btn.click(
        fn=lambda audio, asr_choice: process_pipeline(audio, asr_choice, model_manager),
        inputs=[audio_input, asr_model_choice],
        outputs=transcription_output,
    )

if __name__ == "__main__":
    demo.launch(share=True)