"""Gradio demo for translation with AI4Bharat IndicTrans2 (indic-indic, 1B)."""

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from IndicTransToolkit import IndicProcessor
import gradio as gr

# Choose the device once, at module level, so the model and every batch of
# inputs are guaranteed to live on the same device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Define the model and tokenizer
model_name = "ai4bharat/indictrans2-indic-indic-1B"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# The model must be moved to DEVICE as well: the inputs are sent to DEVICE in
# translate(), and generate() fails on a CPU/CUDA mismatch.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True).to(DEVICE)
ip = IndicProcessor(inference=True)

# Dropdown label -> language code passed to IndicProcessor.
# NOTE(review): this checkpoint is the *indic-indic* model; confirm that
# "eng_Latn" is actually supported by it (English may require the en-indic /
# indic-en checkpoints instead).
LANGUAGES = {
    "Assamese (asm_Beng)": "asm_Beng",
    "Kashmiri (kas_Arab)": "kas_Arab",
    "Punjabi (pan_Guru)": "pan_Guru",
    "Bengali (ben_Beng)": "ben_Beng",
    "Kashmiri (kas_Deva)": "kas_Deva",
    "Sanskrit (san_Deva)": "san_Deva",
    "Bodo (brx_Deva)": "brx_Deva",
    "Maithili (mai_Deva)": "mai_Deva",
    "Santali (sat_Olck)": "sat_Olck",
    "Dogri (doi_Deva)": "doi_Deva",
    "Malayalam (mal_Mlym)": "mal_Mlym",
    "Sindhi (snd_Arab)": "snd_Arab",
    "English (eng_Latn)": "eng_Latn",
    "Marathi (mar_Deva)": "mar_Deva",
    "Sindhi (snd_Deva)": "snd_Deva",
    "Konkani (gom_Deva)": "gom_Deva",
    "Manipuri (mni_Beng)": "mni_Beng",
    "Tamil (tam_Taml)": "tam_Taml",
    "Gujarati (guj_Gujr)": "guj_Gujr",
    "Manipuri (mni_Mtei)": "mni_Mtei",
    "Telugu (tel_Telu)": "tel_Telu",
    "Hindi (hin_Deva)": "hin_Deva",
    "Nepali (npi_Deva)": "npi_Deva",
    "Urdu (urd_Arab)": "urd_Arab",
    "Kannada (kan_Knda)": "kan_Knda",
    "Odia (ory_Orya)": "ory_Orya",
}


def translate(text, src_lang, tgt_lang):
    """Translate *text* from *src_lang* to *tgt_lang* with IndicTrans2.

    Args:
        text: Input text in the source language.
        src_lang: Source language code, one of the values of ``LANGUAGES``.
        tgt_lang: Target language code, one of the values of ``LANGUAGES``.

    Returns:
        The translated text, or ``""`` if *text* is empty or only whitespace.
    """
    if not text or not text.strip():
        return ""

    # IndicProcessor prepares the input for the model given both language codes.
    batch = ip.preprocess_batch([text], src_lang=src_lang, tgt_lang=tgt_lang)
    inputs = tokenizer(
        batch,
        truncation=True,
        padding="longest",
        return_tensors="pt",
    ).to(DEVICE)

    with torch.no_grad():
        generated_tokens = model.generate(
            **inputs,
            use_cache=True,
            min_length=0,
            max_length=256,
            num_beams=5,
            num_return_sequences=1,
        )

    # Decode in target mode (kept from the original; presumably selects the
    # target-side vocabulary of the IndicTrans2 tokenizer — confirm).
    with tokenizer.as_target_tokenizer():
        decoded = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)

    # Counterpart of preprocess_batch: the toolkit expects model output to be
    # post-processed for the target language before it is shown to users.
    return ip.postprocess_batch([decoded], lang=tgt_lang)[0]


def on_translate(text, src_label, tgt_label):
    """Gradio callback: map dropdown labels to language codes and translate.

    Raises:
        gr.Error: If either language dropdown has no selection.
    """
    if not src_label or not tgt_label:
        raise gr.Error("Please select both a source and a target language.")
    return translate(text, LANGUAGES[src_label], LANGUAGES[tgt_label])


# Create a Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("### Indic Translations")
    input_text = gr.Textbox(label="Input Text", placeholder="Enter text to translate")
    src_lang = gr.Dropdown(label="Source Language", choices=list(LANGUAGES.keys()))
    tgt_lang = gr.Dropdown(label="Target Language", choices=list(LANGUAGES.keys()))
    translate_button = gr.Button("Translate")
    translation_output = gr.Textbox(label="Translation", interactive=False)

    # Inputs and outputs must be declared explicitly: Gradio passes component
    # values as arguments and writes the return value to the output component.
    # Assigning `.value` inside a callback does not update the rendered UI.
    translate_button.click(
        fn=on_translate,
        inputs=[input_text, src_lang, tgt_lang],
        outputs=translation_output,
    )

if __name__ == "__main__":
    demo.launch()