import torch from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan from datasets import load_dataset import gradio as gr import numpy as np from transformers import MarianMTModel, MarianTokenizer, T5Tokenizer,T5ForConditionalGeneration from transformers import WhisperProcessor, WhisperForConditionalGeneration import torchaudio from torchaudio.transforms import Resample device = "cuda" if torch.cuda.is_available() else "cpu" ''' Importing necessary models for TTS module ''' processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts") tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts") vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan") embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation") speaker_embeddings = torch.tensor(embeddings_dataset[2000]["xvector"]).unsqueeze(0) tts_model.to(device) vocoder.to(device) ''' Importing necessary models for text to text translation model ''' model_name='sinahalafinetuned' tokenizer = T5Tokenizer.from_pretrained(model_name) model = T5ForConditionalGeneration.from_pretrained(model_name) ''' Importing necessary models for ASR module ''' asr_processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2") asr_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v2") forced_decoder_ids = asr_processor.get_decoder_prompt_ids(language="sinhala", task="transcribe") def TTS(text, speaker_embeddings): if len(text.strip()) == 0: return np.zeros(0).astype(np.int16) inputs = processor(text=text, return_tensors="pt") input_ids = inputs["input_ids"] speech = tts_model.generate_speech(input_ids, speaker_embeddings, vocoder=vocoder) speech = (speech.numpy() * 32767).astype(np.int16) return (16000, speech) # This function performs the text to text translation def translate(text): inputs = tokenizer(text, return_tensors="pt") translated = model.generate(**inputs, max_length=50) translated_text = tokenizer.batch_decode(translated, skip_special_tokens=True)[0] return translated_text # This function performs the speech recognition work def ASR(audio): waveform, sample_rate = torchaudio.load(audio) resampler = Resample(orig_freq=sample_rate, new_freq=16000) resampled_waveform = resampler(waveform) input_features = asr_processor(resampled_waveform.numpy(), sampling_rate=16000, return_tensors="pt").input_features predicted_ids = asr_model.generate(input_features, forced_decoder_ids=forced_decoder_ids) transcription = asr_processor.batch_decode(predicted_ids, skip_special_tokens=True) return transcription # Gradio Interface for Speech-to-Speech Translation def speech_to_speech(audio): text = ASR(audio) print("*************************** Performing ASR **********************************") print(text) translated_text = translate(text) print("*************************** Performing TT **********************************") print(translated_text) generated_audio = TTS(translated_text,speaker_embeddings) print("*************************** Performing TTS **********************************") return translated_text,generated_audio demo = gr.Blocks() mic_translate = gr.Interface( fn=speech_to_speech, inputs=gr.Audio(source="microphone", type="filepath"), outputs=[gr.Textbox(label="English Transcript"),gr.Audio(label="Generated Speech", type="numpy")], title='SPEECH-TRANSLATION', live=True ) file_translate = gr.Interface( fn=speech_to_speech, inputs=gr.Audio(source="upload", type="filepath"), outputs=[gr.Textbox(label="English Transcript"),gr.Audio(label="Generated Speech", type="numpy")], title='SPEECH-TRANSLATION', live=True ) with demo: gr.TabbedInterface([mic_translate, file_translate], ["Microphone", "Audio File"]) demo.launch(debug=True)