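"""Gradio demo app: Nepali-English translation with speech in and speech out.

Pipeline: audio -> MMS ASR transcription -> M2M100 (or a custom X-Transformer)
translation -> gTTS audio.
"""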
import gradio as gr
import torch
import librosa
import tempfile
import logging
from gtts import gTTS

# Model classes: M2M100 for translation, MMS/Wav2Vec2 for speech recognition
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
from transformers import Wav2Vec2ForCTC, AutoProcessor

logging.basicConfig(
    level=logging.INFO,  # INFO keeps app messages visible without library DEBUG noise
    format='%(asctime)s - %(levelname)s - %(message)s'
)
# Load the fine-tuned M2M100 checkpoint, falling back to the official model if it fails
try:
    # Try to load custom model
    checkpoint_dir = "bishaltwr/final_m2m100"
    logging.info(f"Attempting to load custom M2M100 from {checkpoint_dir}")
    tokenizer = M2M100Tokenizer.from_pretrained(checkpoint_dir)
    model_m2m = M2M100ForConditionalGeneration.from_pretrained(checkpoint_dir)
    logging.info("Custom M2M100 model loaded successfully")
except Exception as e:
    logging.error(f"Error loading custom M2M100 model: {e}")
    # Fall back to official model
    checkpoint_dir = "facebook/m2m100_418M"
    logging.info(f"Attempting to load official M2M100 from {checkpoint_dir}")
    tokenizer = M2M100Tokenizer.from_pretrained(checkpoint_dir)
    model_m2m = M2M100ForConditionalGeneration.from_pretrained(checkpoint_dir)
    logging.info("Official M2M100 model loaded successfully")

# Set device after model loading
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"Using device: {device}")
model_m2m.to(device)

# Initialize ASR model
model_id = "bishaltwr/wav2vec2-large-mms-1b-nepali"
processor = AutoProcessor.from_pretrained(model_id)
model_asr = Wav2Vec2ForCTC.from_pretrained(model_id, ignore_mismatched_sizes=True)
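# Note: ignore_mismatched_sizes tolerates the language-specific CTC head that
# fine-tuned MMS checkpoints carry. The ASR model stays on the CPU, matching
# the CPU tensors built in transcribe_audio() below.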

# Initialize X-Transformer model
from inference import translate as xtranslate
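# The custom X-Transformer exposes a single translate(text) entry point; its
# translation direction is presumably fixed at training time, so the UI's
# language dropdowns only affect the M2M100 path (see translate_text below).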

def m2m_translate(text, source_lang, target_lang):
    """Translation using M2M100 model"""
    tokenizer.src_lang = source_lang
    inputs = tokenizer(text, return_tensors="pt").to(device)
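    # M2M100 is many-to-many: the output language is chosen by forcing the
    # target language token as the first generated token.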
    translated_tokens = model_m2m.generate(
        **inputs, 
        forced_bos_token_id=tokenizer.get_lang_id(target_lang)
    )
    translated_text = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
    return translated_text

def transcribe_audio(audio_path, language="npi"):
    """Transcribe audio using ASR model"""
    # Resample to the 16 kHz rate the model was trained on
    audio, sr = librosa.load(audio_path, sr=16000)
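    # MMS covers many languages via small per-language adapters: point the
    # tokenizer at the right vocabulary, then swap in the matching adapter
    # weights. This runs on every call, which is simple but not free.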
    processor.tokenizer.set_target_lang(language)
    model_asr.load_adapter(language)
    inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
    
    with torch.no_grad():
        outputs = model_asr(**inputs).logits
    
    ids = torch.argmax(outputs, dim=-1)[0]
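    # Greedy CTC decoding: argmax per frame above; decode() collapses repeated
    # tokens and strips blank/special tokens.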
    transcription = processor.decode(ids, skip_special_tokens=True)
    
    # Clean up any leftover special tokens; the Nepali branch maps <pad> to a
    # space rather than dropping it outright.
    if language == "eng":
        transcription = transcription.replace('<pad>','').replace('<unk>','')
    else:
        transcription = transcription.replace('<pad>',' ').replace('<unk>','')
    
    return transcription

def text_to_speech(text):
    """Convert text to speech using gTTS"""
    if not text:
        return None
    
    try:
        # gTTS produces MP3 audio, so give the temp file an .mp3 suffix.
        # Pick the voice with the same heuristic used for translation
        # (gTTS's language list includes both "en" and "ne").
        lang = detect_language(text)
        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as temp_audio:
            tts = gTTS(text=text, lang=lang)
            tts.save(temp_audio.name)
            return temp_audio.name
    except Exception as e:
        logging.error(f"TTS error: {e}")
        return None

def detect_language(text):
    """Heuristic detection: mostly ASCII letters means English, else Nepali."""
    # e.g. "hello" is 5/5 ASCII letters -> "en"; Devanagari text scores 0 -> "ne"
    english_chars = sum(1 for c in text if c.isascii() and c.isalpha())
    return "en" if english_chars > len(text) * 0.5 else "ne"

def translate_text(text, model_choice, source_lang=None, target_lang=None):
    """Main translation function"""
    if not text:
        return "Please enter some text to translate"
        
    # Auto-detect the source language if not specified
    if not source_lang:
        source_lang = detect_language(text)
    # Default the target to the opposite language
    if not target_lang:
        target_lang = "ne" if source_lang == "en" else "en"
    
    # Choose the translation model
    if model_choice == "XTransformer":
        return xtranslate(text)
    elif model_choice == "M2M100":
        return m2m_translate(text, source_lang=source_lang, target_lang=target_lang)
    else:
        return "Selected model is not available"

# Set up the Gradio interface
with gr.Blocks(title="Nepali-English Translator") as demo:
    gr.Markdown("# Nepali-English Translator")
    gr.Markdown("Translate between Nepali and English, transcribe audio, and convert text to speech.")
    gr.Markdown("Aakash Budhathoki, Apekshya Subedi, Bishal Tiwari, Kebin Malla - Kantipur Engineering College.")
    
    with gr.Column():
        gr.Markdown("### Speech to Text")
        audio_input = gr.Audio(label="Upload or Record Audio", type="filepath")
        asr_language = gr.Radio(
            choices=["eng", "npi"], 
            value="npi", 
            label="Speech Language"
        )
        transcribe_button = gr.Button("Transcribe")
        transcription_output = gr.Textbox(label="Transcription Output", lines=3)
        
        gr.Markdown("### Text Translation")
        model_choice = gr.Dropdown(
            choices=["XTransformer", "M2M100"], 
            value="M2M100", 
            label="Translation Model"
        )
        source_lang = gr.Dropdown(
            choices=["Auto-detect", "en", "ne"], 
            value="Auto-detect", 
            label="Source Language"
        )
        target_lang = gr.Dropdown(
            choices=["Auto-select", "en", "ne"], 
            value="Auto-select", 
            label="Target Language"
        )
        translate_button = gr.Button("Translate")
        translation_output = gr.Textbox(label="Translation Output", lines=5)
        
        gr.Markdown("### Text to Speech")
        tts_button = gr.Button("Convert to Speech")
        audio_output = gr.Audio(label="Audio Output")
    
    # Define event handlers
    def process_translation(text, model, src_lang, tgt_lang):
        if src_lang == "Auto-detect":
            src_lang = None
        if tgt_lang == "Auto-select":
            tgt_lang = None
        return translate_text(text, model, src_lang, tgt_lang)
    
    def process_tts(text):
        return text_to_speech(text)
    
    def process_transcription(audio_path, language):
        if not audio_path:
            return "Please upload or record audio"
        return transcribe_audio(audio_path, language)
    
    # Connect the components
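    # The stages chain through the textboxes: the transcription output doubles
    # as the translator's input (it stays editable, so text can also be typed
    # in directly), and the translation output feeds text-to-speech.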
    transcribe_button.click(
        process_transcription,
        inputs=[audio_input, asr_language],
        outputs=transcription_output
    )
    
    translate_button.click(
        process_translation, 
        inputs=[transcription_output, model_choice, source_lang, target_lang], 
        outputs=translation_output
    )
    
    tts_button.click(
        process_tts,
        inputs=translation_output,
        outputs=audio_output
    )

# Launch the app
if __name__ == "__main__":
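    # Pass share=True to launch() for a temporary public URL if needed.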
    demo.launch()