import os import torch import gradio as gr import numpy as np import soundfile as sf import librosa from transformers import ( AutoModelForSeq2SeqLM, AutoTokenizer, VitsModel, AutoProcessor, AutoModelForCTC, WhisperProcessor, WhisperForConditionalGeneration ) from typing import Optional, Tuple, Dict, List class TalklasTranslator: """ Speech-to-Speech translation pipeline for Philippine languages. Uses MMS/Whisper for STT, NLLB for MT, and MMS for TTS with pitch-shifting for voice gender. """ LANGUAGE_MAPPING = { "English": "eng", "Tagalog": "tgl", "Cebuano": "ceb", "Ilocano": "ilo", "Waray": "war", "Pangasinan": "pag" } NLLB_LANGUAGE_CODES = { "eng": "eng_Latn", "tgl": "tgl_Latn", "ceb": "ceb_Latn", "ilo": "ilo_Latn", "war": "war_Latn", "pag": "pag_Latn" } def __init__( self, source_lang: str = "eng", target_lang: str = "tgl", device: Optional[str] = None ): self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") self.source_lang = source_lang self.target_lang = target_lang self.sample_rate = 16000 print(f"Initializing Talklas Translator on {self.device}") # Initialize models self._initialize_stt_model() self._initialize_mt_model() self._initialize_tts_model() def _initialize_stt_model(self): """Initialize speech-to-text model with fallback to Whisper""" try: print("Loading STT model...") try: # Try loading MMS model first self.stt_processor = AutoProcessor.from_pretrained("facebook/mms-1b-all") self.stt_model = AutoModelForCTC.from_pretrained("facebook/mms-1b-all") # Set language if available if self.source_lang in self.stt_processor.tokenizer.vocab.keys(): self.stt_processor.tokenizer.set_target_lang(self.source_lang) self.stt_model.load_adapter(self.source_lang) print(f"Loaded MMS STT model for {self.source_lang}") else: print(f"Language {self.source_lang} not in MMS, using default") except Exception as mms_error: print(f"MMS loading failed: {mms_error}") # Fallback to Whisper print("Loading Whisper as fallback...") self.stt_processor = WhisperProcessor.from_pretrained("openai/whisper-small") self.stt_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small") print("Loaded Whisper STT model") self.stt_model.to(self.device) except Exception as e: print(f"STT model initialization failed: {e}") raise RuntimeError("Could not initialize STT model") def _initialize_mt_model(self): """Initialize machine translation model""" try: print("Loading NLLB Translation model...") self.mt_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M") self.mt_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M") self.mt_model.to(self.device) print("NLLB Translation model loaded") except Exception as e: print(f"MT model initialization failed: {e}") raise def _initialize_tts_model(self): """Initialize text-to-speech model""" try: print("Loading TTS model...") try: self.tts_model = VitsModel.from_pretrained(f"facebook/mms-tts-{self.target_lang}") self.tts_tokenizer = AutoTokenizer.from_pretrained(f"facebook/mms-tts-{self.target_lang}") print(f"Loaded TTS model for {self.target_lang}") except Exception as tts_error: print(f"Target language TTS failed: {tts_error}") print("Falling back to English TTS") self.tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng") self.tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng") self.tts_model.to(self.device) except Exception as e: print(f"TTS model initialization failed: {e}") raise def update_languages(self, source_lang: str, target_lang: str) -> str: """Update languages and reinitialize models if needed""" if source_lang == self.source_lang and target_lang == self.target_lang: return "Languages already set" self.source_lang = source_lang self.target_lang = target_lang # Only reinitialize models that depend on language self._initialize_stt_model() self._initialize_tts_model() return f"Languages updated to {source_lang} → {target_lang}" def speech_to_text(self, audio_path: str) -> str: """Convert speech to text using loaded STT model""" try: waveform, sample_rate = sf.read(audio_path) if sample_rate != 16000: waveform = librosa.resample(waveform, orig_sr=sample_rate, target_sr=16000) inputs = self.stt_processor( waveform, sampling_rate=16000, return_tensors="pt" ).to(self.device) with torch.no_grad(): if isinstance(self.stt_model, WhisperForConditionalGeneration): # Whisper model generated_ids = self.stt_model.generate(**inputs) transcription = self.stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0] else: # MMS model (Wav2Vec2ForCTC) logits = self.stt_model(**inputs).logits predicted_ids = torch.argmax(logits, dim=-1) transcription = self.stt_processor.batch_decode(predicted_ids)[0] return transcription except Exception as e: print(f"Speech recognition failed: {e}") raise RuntimeError("Speech recognition failed") def translate_text(self, text: str) -> str: """Translate text using NLLB model""" try: source_code = self.NLLB_LANGUAGE_CODES[self.source_lang] target_code = self.NLLB_LANGUAGE_CODES[self.target_lang] self.mt_tokenizer.src_lang = source_code inputs = self.mt_tokenizer(text, return_tensors="pt").to(self.device) with torch.no_grad(): generated_tokens = self.mt_model.generate( **inputs, forced_bos_token_id=self.mt_tokenizer.convert_tokens_to_ids(target_code), max_length=448 ) return self.mt_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0] except Exception as e: print(f"Translation failed: {e}") raise RuntimeError("Text translation failed") def text_to_speech(self, text: str, voice_gender: str = "neutral") -> Tuple[int, np.ndarray]: """Convert text to speech with optional pitch-shifting for voice gender""" try: inputs = self.tts_tokenizer(text, return_tensors="pt").to(self.device) with torch.no_grad(): output = self.tts_model(**inputs) speech = output.waveform.cpu().numpy().squeeze() # Apply pitch-shifting based on voice_gender if voice_gender.lower() == "female": # Increase pitch (e.g., +4 semitones for a more traditionally feminine voice) speech = librosa.effects.pitch_shift(speech, sr=self.tts_model.config.sampling_rate, n_steps=1) elif voice_gender.lower() == "male": # Decrease pitch (e.g., -4 semitones for a more traditionally masculine voice) speech = librosa.effects.pitch_shift(speech, sr=self.tts_model.config.sampling_rate, n_steps=-2) # Convert to 16-bit PCM speech = (speech * 32767).astype(np.int16) return self.tts_model.config.sampling_rate, speech except Exception as e: print(f"Speech synthesis failed: {e}") raise RuntimeError("Speech synthesis failed") def translate_speech(self, audio_path: str, voice_gender: str = "neutral") -> Dict: """Full speech-to-speech translation with voice gender option""" try: source_text = self.speech_to_text(audio_path) translated_text = self.translate_text(source_text) sample_rate, audio = self.text_to_speech(translated_text, voice_gender) return { "source_text": source_text, "translated_text": translated_text, "output_audio": (sample_rate, audio), "performance": "Translation successful" } except Exception as e: return { "source_text": "Error", "translated_text": "Error", "output_audio": (16000, np.zeros(1000, dtype=np.int16)), "performance": f"Error: {str(e)}" } def translate_text_only(self, text: str, voice_gender: str = "neutral") -> Dict: """Text-to-speech translation with voice gender option""" try: translated_text = self.translate_text(text) sample_rate, audio = self.text_to_speech(translated_text, voice_gender) return { "source_text": text, "translated_text": translated_text, "output_audio": (sample_rate, audio), "performance": "Translation successful" } except Exception as e: return { "source_text": text, "translated_text": "Error", "output_audio": (16000, np.zeros(1000, dtype=np.int16)), "performance": f"Error: {str(e)}" } class TranslatorSingleton: _instance = None @classmethod def get_instance(cls): if cls._instance is None: cls._instance = TalklasTranslator() return cls._instance def process_audio(audio_path, source_lang, target_lang, voice_gender): """Process audio through the full translation pipeline with voice gender""" # Validate input if not audio_path: return None, "No audio provided", "No translation available", "Please provide audio input" # Update languages source_code = TalklasTranslator.LANGUAGE_MAPPING[source_lang] target_code = TalklasTranslator.LANGUAGE_MAPPING[target_lang] translator = TranslatorSingleton.get_instance() status = translator.update_languages(source_code, target_code) # Process the audio results = translator.translate_speech(audio_path, voice_gender) return results["output_audio"], results["source_text"], results["translated_text"], results["performance"] def process_text(text, source_lang, target_lang, voice_gender): """Process text through the translation pipeline with voice gender""" # Validate input if not text: return None, "No text provided", "No translation available", "Please provide text input" # Update languages source_code = TalklasTranslator.LANGUAGE_MAPPING[source_lang] target_code = TalklasTranslator.LANGUAGE_MAPPING[target_lang] translator = TranslatorSingleton.get_instance() status = translator.update_languages(source_code, target_code) # Process the text results = translator.translate_text_only(text, voice_gender) return results["output_audio"], results["source_text"], results["translated_text"], results["performance"] def create_gradio_interface(): """Create and launch Gradio interface with voice gender selection""" # Define language options languages = list(TalklasTranslator.LANGUAGE_MAPPING.keys()) voice_genders = ["Neutral", "Male", "Female"] # Define the interface demo = gr.Blocks(title="Talklas - Speech & Text Translation") with demo: gr.Markdown("# Talklas: Speech-to-Speech Translation System") gr.Markdown("### Translate between Philippine Languages and English") with gr.Row(): with gr.Column(): source_lang = gr.Dropdown( choices=languages, value="English", label="Source Language" ) target_lang = gr.Dropdown( choices=languages, value="Tagalog", label="Target Language" ) voice_gender = gr.Dropdown( choices=voice_genders, value="Neutral", label="Voice Gender" ) language_status = gr.Textbox(label="Language Status") update_btn = gr.Button("Update Languages") with gr.Tabs(): with gr.TabItem("Audio Input"): with gr.Row(): with gr.Column(): gr.Markdown("### Audio Input") audio_input = gr.Audio( type="filepath", label="Upload Audio File" ) audio_translate_btn = gr.Button("Translate Audio", variant="primary") with gr.Column(): gr.Markdown("### Output") audio_output = gr.Audio( label="Translated Speech", type="numpy", autoplay=True ) with gr.TabItem("Text Input"): with gr.Row(): with gr.Column(): gr.Markdown("### Text Input") text_input = gr.Textbox( label="Enter text to translate", lines=3 ) text_translate_btn = gr.Button("Translate Text", variant="primary") with gr.Column(): gr.Markdown("### Output") text_output = gr.Audio( label="Translated Speech", type="numpy", autoplay=True ) with gr.Row(): with gr.Column(): source_text = gr.Textbox(label="Source Text") translated_text = gr.Textbox(label="Translated Text") performance_info = gr.Textbox(label="Performance Metrics") # Set up events update_btn.click( lambda source_lang, target_lang: TranslatorSingleton.get_instance().update_languages( TalklasTranslator.LANGUAGE_MAPPING[source_lang], TalklasTranslator.LANGUAGE_MAPPING[target_lang] ), inputs=[source_lang, target_lang], outputs=[language_status] ) # Audio translate button click audio_translate_btn.click( process_audio, inputs=[audio_input, source_lang, target_lang, voice_gender], outputs=[audio_output, source_text, translated_text, performance_info] ).then( None, None, None, js="""() => { const audioElements = document.querySelectorAll('audio'); if (audioElements.length > 0) { const lastAudio = audioElements[audioElements.length - 1]; lastAudio.play().catch(error => { console.warn('Autoplay failed:', error); alert('Audio may require user interaction to play'); }); } }""" ) # Text translate button click text_translate_btn.click( process_text, inputs=[text_input, source_lang, target_lang, voice_gender], outputs=[text_output, source_text, translated_text, performance_info] ).then( None, None, None, js="""() => { const audioElements = document.querySelectorAll('audio'); if (audioElements.length > 0) { const lastAudio = audioElements[audioElements.length - 1]; lastAudio.play().catch(error => { console.warn('Autoplay failed:', error); alert('Audio may require user interaction to play'); }); } }""" ) return demo if __name__ == "__main__": demo = create_gradio_interface() demo.launch(share=True, debug=True)