import tempfile
import logging

import gradio as gr
import librosa
import torch
from gtts import gTTS
# Hugging Face model classes: M2M100 for translation, MMS wav2vec2 for speech recognition
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
from transformers import Wav2Vec2ForCTC, AutoProcessor
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
# Load the translation model: prefer the fine-tuned checkpoint, fall back to the base model
try:
    checkpoint_dir = "bishaltwr/final_m2m100"
    logging.info(f"Attempting to load custom M2M100 from {checkpoint_dir}")
    tokenizer = M2M100Tokenizer.from_pretrained(checkpoint_dir)
    model_m2m = M2M100ForConditionalGeneration.from_pretrained(checkpoint_dir)
    logging.info("Custom M2M100 model loaded successfully")
except Exception as e:
    logging.error(f"Error loading custom M2M100 model: {e}")
    # Fall back to the official base model
    checkpoint_dir = "facebook/m2m100_418M"
    logging.info(f"Attempting to load official M2M100 from {checkpoint_dir}")
    tokenizer = M2M100Tokenizer.from_pretrained(checkpoint_dir)
    model_m2m = M2M100ForConditionalGeneration.from_pretrained(checkpoint_dir)
    logging.info("Official M2M100 model loaded successfully")
# Set device after model loading
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"Using device: {device}")
model_m2m.to(device)
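# Note: only the translation model is moved to `device`; the ASR model below is
# left on the CPU, which matches the CPU tensors built in transcribe_audio().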
# Initialize the ASR model (MMS wav2vec2 fine-tuned for Nepali)
model_id = "bishaltwr/wav2vec2-large-mms-1b-nepali"
processor = AutoProcessor.from_pretrained(model_id)
# ignore_mismatched_sizes tolerates a CTC head whose vocabulary size
# differs from the base configuration
model_asr = Wav2Vec2ForCTC.from_pretrained(model_id, ignore_mismatched_sizes=True)
# Initialize X-Transformer model
from inference import translate as xtranslate
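# `inference` is assumed to be a module shipped alongside this app (not a PyPI
# package); its `translate` function backs the "XTransformer" option in the UI.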
def m2m_translate(text, source_lang, target_lang):
    """Translation using the M2M100 model."""
    tokenizer.src_lang = source_lang
    inputs = tokenizer(text, return_tensors="pt").to(device)
    translated_tokens = model_m2m.generate(
        **inputs,
        forced_bos_token_id=tokenizer.get_lang_id(target_lang)
    )
    translated_text = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
    return translated_text
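# Usage sketch (M2M100 expects ISO 639-1 language codes):
#   m2m_translate("How are you?", source_lang="en", target_lang="ne")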
def transcribe_audio(audio_path, language="npi"):
    """Transcribe audio with the MMS ASR model ('npi' = Nepali, 'eng' = English)."""
    # MMS models expect 16 kHz mono audio
    audio, _ = librosa.load(audio_path, sr=16000)
    # Switch the tokenizer vocabulary and adapter weights to the target language
    processor.tokenizer.set_target_lang(language)
    model_asr.load_adapter(language)
    inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
    with torch.no_grad():
        outputs = model_asr(**inputs).logits
    ids = torch.argmax(outputs, dim=-1)[0]
    transcription = processor.decode(ids, skip_special_tokens=True)
    if language == "eng":
        transcription = transcription.replace('<pad>', '').replace('<unk>', '')
    else:
        transcription = transcription.replace('<pad>', ' ').replace('<unk>', '')
    return transcription
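# Usage sketch (the path is illustrative; Gradio supplies a real temp file):
#   transcribe_audio("/tmp/recording.wav", language="npi")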
def text_to_speech(text):
    """Convert text to speech using gTTS."""
    if not text:
        return None
    try:
        # gTTS writes MP3 data, so name the temp file accordingly
        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as temp_audio:
            # detect_language returns "en" or "ne"; both are gTTS language codes
            tts = gTTS(text=text, lang=detect_language(text))
            tts.save(temp_audio.name)
            return temp_audio.name
    except Exception as e:
        logging.error(f"TTS error: {e}")
        return None
def detect_language(text):
    """Heuristic language detection: mostly ASCII letters -> English, else Nepali."""
    english_chars = sum(1 for c in text if c.isascii() and c.isalpha())
    return "en" if english_chars > len(text) * 0.5 else "ne"
def translate_text(text, model_choice, source_lang=None, target_lang=None):
    """Main translation entry point."""
    if not text:
        return "Please enter some text to translate"
    # Auto-detect the source language if not specified
    if not source_lang:
        source_lang = detect_language(text)
    # Auto-select the target language if not specified (the opposite of the source)
    if not target_lang:
        target_lang = "ne" if source_lang == "en" else "en"
    # Choose the translation model
    if model_choice == "XTransformer":
        return xtranslate(text)
    elif model_choice == "M2M100":
        return m2m_translate(text, source_lang=source_lang, target_lang=target_lang)
    else:
        return "Selected model is not available"
# Set up the Gradio interface
with gr.Blocks(title="Nepali-English Translator") as demo:
    gr.Markdown("# Nepali-English Translator")
    gr.Markdown("Translate between Nepali and English, transcribe audio, and convert text to speech.")
    gr.Markdown("Aakash Budhathoki, Apekshya Subedi, Bishal Tiwari, Kebin Malla - Kantipur Engineering College.")
    with gr.Column():
        gr.Markdown("### Speech to Text")
        audio_input = gr.Audio(label="Upload or Record Audio", type="filepath")
        asr_language = gr.Radio(
            choices=["eng", "npi"],
            value="npi",
            label="Speech Language"
        )
        transcribe_button = gr.Button("Transcribe")
        transcription_output = gr.Textbox(label="Transcription Output", lines=3)
        gr.Markdown("### Text Translation")
        model_choice = gr.Dropdown(
            choices=["XTransformer", "M2M100"],
            value="M2M100",
            label="Translation Model"
        )
        source_lang = gr.Dropdown(
            choices=["Auto-detect", "en", "ne"],
            value="Auto-detect",
            label="Source Language"
        )
        target_lang = gr.Dropdown(
            choices=["Auto-select", "en", "ne"],
            value="Auto-select",
            label="Target Language"
        )
        translate_button = gr.Button("Translate")
        translation_output = gr.Textbox(label="Translation Output", lines=5)
        gr.Markdown("### Text to Speech")
        tts_button = gr.Button("Convert to Speech")
        audio_output = gr.Audio(label="Audio Output")
    # Define event handlers
    def process_translation(text, model, src_lang, tgt_lang):
        # Map the UI's "Auto" options to None so translate_text handles detection
        if src_lang == "Auto-detect":
            src_lang = None
        if tgt_lang == "Auto-select":
            tgt_lang = None
        return translate_text(text, model, src_lang, tgt_lang)

    def process_tts(text):
        return text_to_speech(text)

    def process_transcription(audio_path, language):
        if not audio_path:
            return "Please upload or record audio"
        return transcribe_audio(audio_path, language)
    # Connect the components
    transcribe_button.click(
        process_transcription,
        inputs=[audio_input, asr_language],
        outputs=transcription_output
    )
    translate_button.click(
        process_translation,
        inputs=[transcription_output, model_choice, source_lang, target_lang],
        outputs=translation_output
    )
    tts_button.click(
        process_tts,
        inputs=translation_output,
        outputs=audio_output
    )
# Launch the app
if __name__ == "__main__":
    demo.launch()