|
import gradio as gr
import torch
import tempfile
import logging

from gtts import gTTS
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
from transformers import Wav2Vec2ForCTC, AutoProcessor

logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
|
|
try:
    # Prefer the project's fine-tuned checkpoint.
    checkpoint_dir = "bishaltwr/final_m2m100"
    logging.info(f"Attempting to load custom M2M100 from {checkpoint_dir}")
    tokenizer = M2M100Tokenizer.from_pretrained(checkpoint_dir)
    model_m2m = M2M100ForConditionalGeneration.from_pretrained(checkpoint_dir)
    logging.info("Custom M2M100 model loaded successfully")
except Exception as e:
    logging.error(f"Error loading custom M2M100 model: {e}")
    # Fall back to the official Facebook checkpoint.
    checkpoint_dir = "facebook/m2m100_418M"
    logging.info(f"Attempting to load official M2M100 from {checkpoint_dir}")
    tokenizer = M2M100Tokenizer.from_pretrained(checkpoint_dir)
    model_m2m = M2M100ForConditionalGeneration.from_pretrained(checkpoint_dir)
    logging.info("Official M2M100 model loaded successfully")
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"Using device: {device}")
model_m2m.to(device)
|
|
# MMS ASR model for speech-to-text.
model_id = "bishaltwr/wav2vec2-large-mms-1b-nepali"
processor = AutoProcessor.from_pretrained(model_id)
model_asr = Wav2Vec2ForCTC.from_pretrained(model_id, ignore_mismatched_sizes=True)
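
# Note (assumption, not in the original code): model_asr is left on the CPU,
# which is workable for short clips. To run ASR on the GPU as well, move both
# the model and the processor outputs, e.g. model_asr.to(device) and
# inputs = inputs.to(device) inside transcribe_audio().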
|
|
# XTransformer model from the local inference module.
from inference import translate as xtranslate
|
|
def m2m_translate(text, source_lang, target_lang):
    """Translate text using the M2M100 model."""
    tokenizer.src_lang = source_lang
    inputs = tokenizer(text, return_tensors="pt").to(device)
    translated_tokens = model_m2m.generate(
        **inputs,
        forced_bos_token_id=tokenizer.get_lang_id(target_lang)
    )
    return tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
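
# Usage sketch (illustrative; the exact output depends on the loaded checkpoint):
#   m2m_translate("Good morning", source_lang="en", target_lang="ne")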
|
|
|
def transcribe_audio(audio_path, language="npi"):
    """Transcribe audio using the MMS ASR model."""
    import librosa
    # MMS models expect 16 kHz mono input.
    audio, sr = librosa.load(audio_path, sr=16000)
    processor.tokenizer.set_target_lang(language)
    model_asr.load_adapter(language)
    inputs = processor(audio, sampling_rate=16000, return_tensors="pt")

    with torch.no_grad():
        outputs = model_asr(**inputs).logits

    ids = torch.argmax(outputs, dim=-1)[0]
    transcription = processor.decode(ids, skip_special_tokens=True)

    # Strip leftover CTC tokens; for Nepali, <pad> is treated as a word boundary.
    if language == "eng":
        transcription = transcription.replace('<pad>', '').replace('<unk>', '')
    else:
        transcription = transcription.replace('<pad>', ' ').replace('<unk>', '')

    return transcription
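
# Sketch (not part of the original code): transcribe_audio() reloads adapter
# weights on every call; a small cache like this would skip the reload when
# the requested language has not changed between requests.
_current_asr_lang = None

def set_asr_language(language):
    """Switch the MMS adapter only when the requested language changes."""
    global _current_asr_lang
    if language != _current_asr_lang:
        processor.tokenizer.set_target_lang(language)
        model_asr.load_adapter(language)
        _current_asr_lang = language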
|
|
|
def text_to_speech(text):
    """Convert text to speech using gTTS."""
    if not text:
        return None

    try:
        # gTTS writes MP3 data, so the temp file is named accordingly. Passing
        # the detected language keeps Nepali text from being read with the
        # English voice.
        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as temp_audio:
            tts = gTTS(text=text, lang=detect_language(text))
            tts.save(temp_audio.name)
            return temp_audio.name
    except Exception as e:
        logging.error(f"TTS error: {e}")
        return None
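
# Note: NamedTemporaryFile(delete=False) leaves one audio file behind per
# request; a long-running deployment would want to delete these once Gradio
# has served them.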
|
|
|
def detect_language(text):
    """Heuristic detection: mostly ASCII letters means English, else Nepali."""
    english_chars = sum(1 for c in text if c.isascii() and c.isalpha())
    return "en" if english_chars > len(text) * 0.5 else "ne"
|
|
|
def translate_text(text, model_choice, source_lang=None, target_lang=None):
    """Main translation function."""
    if not text:
        return "Please enter some text to translate"

    if not source_lang:
        source_lang = detect_language(text)
    # Default the target to the opposite language whenever none was chosen
    # explicitly, even if the source language was given.
    if not target_lang:
        target_lang = "ne" if source_lang == "en" else "en"

    if model_choice == "XTransformer":
        return xtranslate(text)
    elif model_choice == "M2M100":
        return m2m_translate(text, source_lang=source_lang, target_lang=target_lang)
    else:
        return "Selected model is not available"
|
|
with gr.Blocks(title="Nepali-English Translator") as demo:
    gr.Markdown("# Nepali-English Translator")
    gr.Markdown("Translate between Nepali and English, transcribe audio, and convert text to speech.")
    gr.Markdown("Aakash Budhathoki, Apekshya Subedi, Bishal Tiwari, and Kebin Malla, Kantipur Engineering College.")

    with gr.Column():
        gr.Markdown("### Speech to Text")
        audio_input = gr.Audio(label="Upload or Record Audio", type="filepath")
        asr_language = gr.Radio(
            choices=["eng", "npi"],
            value="npi",
            label="Speech Language"
        )
        transcribe_button = gr.Button("Transcribe")
        transcription_output = gr.Textbox(label="Transcription Output", lines=3)

        gr.Markdown("### Text Translation")
        model_choice = gr.Dropdown(
            choices=["XTransformer", "M2M100"],
            value="M2M100",
            label="Translation Model"
        )
        source_lang = gr.Dropdown(
            choices=["Auto-detect", "en", "ne"],
            value="Auto-detect",
            label="Source Language"
        )
        target_lang = gr.Dropdown(
            choices=["Auto-select", "en", "ne"],
            value="Auto-select",
            label="Target Language"
        )
        translate_button = gr.Button("Translate")
        translation_output = gr.Textbox(label="Translation Output", lines=5)

        gr.Markdown("### Text to Speech")
        tts_button = gr.Button("Convert to Speech")
        audio_output = gr.Audio(label="Audio Output")
|
|
    def process_translation(text, model, src_lang, tgt_lang):
        if src_lang == "Auto-detect":
            src_lang = None
        if tgt_lang == "Auto-select":
            tgt_lang = None
        return translate_text(text, model, src_lang, tgt_lang)

    def process_tts(text):
        return text_to_speech(text)

    def process_transcription(audio_path, language):
        if not audio_path:
            return "Please upload or record audio"
        return transcribe_audio(audio_path, language)
|
|
    transcribe_button.click(
        process_transcription,
        inputs=[audio_input, asr_language],
        outputs=transcription_output
    )

    # Translation reads from the transcription box, so typed text works too.
    translate_button.click(
        process_translation,
        inputs=[transcription_output, model_choice, source_lang, target_lang],
        outputs=translation_output
    )

    tts_button.click(
        process_tts,
        inputs=translation_output,
        outputs=audio_output
    )
|
|
if __name__ == "__main__":
    demo.launch()
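
# A public share link can be handy when running remotely (e.g. in Colab):
#   demo.launch(share=True)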