import gradio as gr
import torch
import tempfile
import logging

from gtts import gTTS

# Hugging Face models: M2M100 for translation, MMS (Wav2Vec2) for ASR
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
from transformers import Wav2Vec2ForCTC, AutoProcessor

logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
# Load the M2M100 translation model, preferring the custom fine-tuned checkpoint
try:
    checkpoint_dir = "bishaltwr/final_m2m100"
    logging.info(f"Attempting to load custom M2M100 from {checkpoint_dir}")
    tokenizer = M2M100Tokenizer.from_pretrained(checkpoint_dir)
    model_m2m = M2M100ForConditionalGeneration.from_pretrained(checkpoint_dir)
    m2m_available = True  # must be set here too, or the device check below raises NameError
    logging.info("Custom M2M100 model loaded successfully")
except Exception as e:
    logging.error(f"Error loading custom M2M100 model: {e}")
    try:
        # Fall back to official model
        checkpoint_dir = "facebook/m2m100_418M"
        logging.info(f"Attempting to load official M2M100 from {checkpoint_dir}")
        tokenizer = M2M100Tokenizer.from_pretrained(checkpoint_dir)
        model_m2m = M2M100ForConditionalGeneration.from_pretrained(checkpoint_dir)
        logging.info("Official M2M100 model loaded successfully")
        m2m_available = True
    except Exception as e2:
        logging.error(f"Error loading official M2M100 model: {e2}")
        m2m_available = False
        logging.info("Setting m2m_available to False")

# Set device after model loading
if m2m_available:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging.info(f"Using device: {device}")
    model_m2m.to(device)
    
# Initialize ASR model
model_id = "bishaltwr/wav2vec2-large-mms-1b-nepali"
try:
    processor = AutoProcessor.from_pretrained(model_id)
    model_asr = Wav2Vec2ForCTC.from_pretrained(model_id, ignore_mismatched_sizes=True)
    asr_available = True
except Exception as e:
    logging.error(f"Error loading ASR model: {e}")
    asr_available = False

# Initialize X-Transformer model
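# (`inference` is assumed to be a local module shipped with this app that
# exposes the custom X-Transformer's translate() function)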
try:
    from inference import translate as xtranslate
    xtransformer_available = True
except Exception as e:
    logging.error(f"Error loading XTransformer model: {e}")
    xtransformer_available = False

def m2m_translate(text, source_lang, target_lang):
    """Translation using M2M100 model"""
    if not m2m_available:
        return "M2M100 model not available"
    
    tokenizer.src_lang = source_lang
    inputs = tokenizer(text, return_tensors="pt").to(device)
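    # M2M100 is many-to-many: the output language is selected by forcing the
    # first generated token to the target language's BOS id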
    translated_tokens = model_m2m.generate(
        **inputs, 
        forced_bos_token_id=tokenizer.get_lang_id(target_lang)
    )
    translated_text = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
    return translated_text
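
# Example (assuming a model loaded): m2m_translate("Good morning", "en", "ne")
# should return a Devanagari translation along the lines of "शुभ प्रभात".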

def transcribe_audio(audio_path, language="npi"):
    """Transcribe audio using ASR model"""
    if not asr_available:
        return "ASR model not available"
    
    # Lazy import so the rest of the app still works if librosa is missing
    import librosa
    audio, sr = librosa.load(audio_path, sr=16000)  # MMS models expect 16 kHz mono audio
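    # MMS ships per-language adapters: point the tokenizer and the model at the
    # requested language before running inference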
    processor.tokenizer.set_target_lang(language)
    model_asr.load_adapter(language)
    inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
    
    with torch.no_grad():
        outputs = model_asr(**inputs).logits
    
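    # Greedy CTC decoding: take the most likely token at each frame;
    # processor.decode collapses repeats and blanks into the final text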
    ids = torch.argmax(outputs, dim=-1)[0]
    transcription = processor.decode(ids, skip_special_tokens=True)
    
    if language == "eng":
        transcription = transcription.replace('<pad>','').replace('<unk>','')
    else:
        transcription = transcription.replace('<pad>',' ').replace('<unk>','')
    
    return transcription

def text_to_speech(text):
    """Convert text to speech using gTTS"""
    if not text:
        return None
    
    try:
        # gTTS produces MP3 data, so use an .mp3 suffix; pick the voice via the
        # same script heuristic used for translation ("ne" or "en")
        lang = detect_language(text)
        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as temp_audio:
            tts = gTTS(text=text, lang=lang)
            tts.save(temp_audio.name)
            return temp_audio.name
    except Exception as e:
        logging.error(f"TTS error: {e}")
        return None

def detect_language(text):
    """Simple language detection function"""
    english_chars = sum(1 for c in text if c.isascii() and c.isalpha())
    return "en" if english_chars > len(text) * 0.5 else "ne"

def translate_text(text, model_choice, source_lang=None, target_lang=None):
    """Main translation function"""
    if not text:
        return "Please enter some text to translate"
        
    # Auto-detect language if not specified
    if not source_lang:
        source_lang = detect_language(text)
        target_lang = "ne" if source_lang == "en" else "en"
    
    # Choose the translation model
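    # (the X-Transformer path is assumed to infer the direction itself, since
    # xtranslate() only receives the text)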
    if model_choice == "XTransformer" and xtransformer_available:
        return xtranslate(text)
    elif model_choice == "M2M100" and m2m_available:
        return m2m_translate(text, source_lang=source_lang, target_lang=target_lang)
    else:
        return "Selected model is not available"

# Set up the Gradio interface
with gr.Blocks(title="Nepali-English Translator") as demo:
    gr.Markdown("# Nepali-English Translation Service")
    gr.Markdown("Translate between Nepali and English, transcribe audio, and convert text to speech.")
    
    # Set up tabs for different functions
    with gr.Tabs():
        # Text Translation Tab
        with gr.TabItem("Text Translation"):
            with gr.Row():
                with gr.Column():
                    text_input = gr.Textbox(label="Input Text", lines=5)
                    
                    with gr.Row():
                        model_choice = gr.Radio(
                            choices=["XTransformer", "M2M100"], 
                            value="XTransformer", 
                            label="Translation Model"
                        )
                        
                    with gr.Row():
                        source_lang = gr.Dropdown(
                            choices=["Auto-detect", "en", "ne"], 
                            value="Auto-detect", 
                            label="Source Language",
                            visible=True
                        )
                        target_lang = gr.Dropdown(
                            choices=["Auto-select", "en", "ne"], 
                            value="Auto-select", 
                            label="Target Language",
                            visible=True
                        )
                    
                    translate_button = gr.Button("Translate")
                
                with gr.Column():
                    translation_output = gr.Textbox(label="Translation Output", lines=5)
                    tts_button = gr.Button("Convert to Speech")
                    audio_output = gr.Audio(label="Audio Output")
        
        # Speech to Text Tab
        with gr.TabItem("Speech to Text"):
            with gr.Column():
                audio_input = gr.Audio(label="Upload or Record Audio", type="filepath")
                asr_language = gr.Radio(
                    choices=["eng", "npi"], 
                    value="npi", 
                    label="Speech Language"
                )
                transcribe_button = gr.Button("Transcribe")
                transcription_output = gr.Textbox(label="Transcription Output", lines=3)
    
    # Define event handlers
    def process_translation(text, model, src_lang, tgt_lang):
        if src_lang == "Auto-detect":
            src_lang = None
        if tgt_lang == "Auto-select":
            tgt_lang = None
        return translate_text(text, model, src_lang, tgt_lang)
    
    def process_tts(text):
        return text_to_speech(text)
    
    def process_transcription(audio_path, language):
        if not audio_path:
            return "Please upload or record audio"
        return transcribe_audio(audio_path, language)
    
    # Connect the components
    translate_button.click(
        process_translation, 
        inputs=[text_input, model_choice, source_lang, target_lang], 
        outputs=translation_output
    )
    
    tts_button.click(
        process_tts,
        inputs=translation_output,
        outputs=audio_output
    )
    
    transcribe_button.click(
        process_transcription,
        inputs=[audio_input, asr_language],
        outputs=transcription_output
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()
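    # demo.launch(share=True) would also create a temporary public link;
    # server_name="0.0.0.0" exposes the app on the local network.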