"""Gradio demo that translates text between Indic languages with IndicTrans2."""

import functools
import logging
import os
import subprocess

import gradio as gr
import torch
from IndicTransToolkit import IndicProcessor
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
)

logger = logging.getLogger(__name__)

REPO_URL = "https://github.com/AI4Bharat/IndicTrans2"
REPO_DIR = "IndicTrans2"
MODEL_NAME = "ai4bharat/indictrans2-indic-indic-1B"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# List of languages with their code names: (display name, language code).
# NOTE(review): the "indic-indic" checkpoint presumably does not cover
# English as a source/target -- confirm against the model card.
languages = [
    ("Assamese", "asm_Beng"),
    ("Kashmiri (Arabic)", "kas_Arab"),
    ("Punjabi", "pan_Guru"),
    ("Bengali", "ben_Beng"),
    ("Kashmiri (Devanagari)", "kas_Deva"),
    ("Sanskrit", "san_Deva"),
    ("Bodo", "brx_Deva"),
    ("Maithili", "mai_Deva"),
    ("Santali", "sat_Olck"),
    ("Dogri", "doi_Deva"),
    ("Malayalam", "mal_Mlym"),
    ("Sindhi (Arabic)", "snd_Arab"),
    ("English", "eng_Latn"),
    ("Marathi", "mar_Deva"),
    ("Sindhi (Devanagari)", "snd_Deva"),
    ("Konkani", "gom_Deva"),
    ("Manipuri (Bengali)", "mni_Beng"),
    ("Tamil", "tam_Taml"),
    ("Gujarati", "guj_Gujr"),
    ("Manipuri (Meitei)", "mni_Mtei"),
    ("Telugu", "tel_Telu"),
    ("Hindi", "hin_Deva"),
    ("Nepali", "npi_Deva"),
    ("Urdu", "urd_Arab"),
    ("Kannada", "kan_Knda"),
    ("Odia", "ory_Orya"),
]

# Lookup tables derived from ``languages`` so the UI can show display names
# while the model pipeline receives language codes.
_CODE_BY_NAME = dict(languages)
_KNOWN_CODES = frozenset(_CODE_BY_NAME.values())


@functools.lru_cache(maxsize=None)
def setup_repo():
    """Clone the IndicTrans2 repository and run its install script.

    The work is done at most once per process (successful runs are cached),
    and the process working directory is left untouched: the install script
    is executed with ``cwd=`` instead of ``os.chdir`` so repeated calls keep
    resolving ``REPO_DIR`` against the same location.

    Raises:
        subprocess.CalledProcessError: if ``git clone`` fails.
    """
    repo_dir = os.path.abspath(REPO_DIR)
    if not os.path.exists(repo_dir):
        subprocess.run(["git", "clone", REPO_URL, repo_dir], check=True)

    interface_dir = os.path.join(repo_dir, "huggingface_interface")
    # ``source`` is a shell builtin and cannot be launched as a program;
    # run the script through bash explicitly, without shell=True.
    result = subprocess.run(["bash", "install.sh"], cwd=interface_dir)
    if result.returncode != 0:
        # Best effort: the toolkit may already be installed in this
        # environment, so a failed install is reported but not fatal.
        logger.warning("install.sh exited with status %d", result.returncode)


@functools.lru_cache(maxsize=1)
def _load_pipeline():
    """Load the tokenizer, model and processor once and keep them cached."""
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        MODEL_NAME, trust_remote_code=True
    )
    # The inputs are moved to DEVICE below, so the model must live there too.
    model = model.to(DEVICE)
    model.eval()
    processor = IndicProcessor(inference=True)
    return tokenizer, model, processor


def _resolve_lang_code(lang, role):
    """Return the language code for a display name (or pass a code through)."""
    if lang in _CODE_BY_NAME:
        return _CODE_BY_NAME[lang]
    if lang in _KNOWN_CODES:
        return lang
    raise gr.Error(f"Please select a valid {role} language.")


def translate(input_text: str, src_lang: str, tgt_lang: str) -> str:
    """Translate ``input_text`` from ``src_lang`` to ``tgt_lang``.

    Args:
        input_text: Text to translate. Empty or whitespace-only text yields
            an empty string without touching the model.
        src_lang: Source language, either a display name from ``languages``
            (e.g. ``"Hindi"``) or its code (e.g. ``"hin_Deva"``).
        tgt_lang: Target language, in the same form as ``src_lang``.

    Returns:
        The translated text.

    Raises:
        gr.Error: if either language is missing or unknown.
    """
    if not input_text or not input_text.strip():
        return ""

    src_code = _resolve_lang_code(src_lang, "source")
    tgt_code = _resolve_lang_code(tgt_lang, "target")

    setup_repo()  # Ensure the repo is set up (no-op after the first call)
    tokenizer, model, ip = _load_pipeline()

    batch = ip.preprocess_batch([input_text], src_lang=src_code, tgt_lang=tgt_code)

    inputs = tokenizer(
        batch,
        truncation=True,
        padding="longest",
        return_tensors="pt",
        return_attention_mask=True,
    ).to(DEVICE)

    with torch.no_grad():
        generated_tokens = model.generate(
            **inputs,
            use_cache=True,
            min_length=0,
            max_length=256,
            num_beams=5,
            num_return_sequences=1,
        )

    with tokenizer.as_target_tokenizer():
        decoded = tokenizer.batch_decode(
            generated_tokens.detach().cpu().tolist(),
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )

    # Undo the preprocessing applied by ``preprocess_batch`` on the output.
    translations = ip.postprocess_batch(decoded, lang=tgt_code)
    return translations[0]


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# IndicTrans2 Translation")

    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(label="Input Text")
            src_lang = gr.Dropdown(
                label="Source Language",
                choices=[lang[0] for lang in languages],
                type="value",
            )
            tgt_lang = gr.Dropdown(
                label="Target Language",
                choices=[lang[0] for lang in languages],
                type="value",
            )
            translate_button = gr.Button("Translate")
            output_text = gr.Textbox(label="Translated Output")

    # Call translate function when button is clicked
    translate_button.click(
        fn=translate,
        inputs=[input_text, src_lang, tgt_lang],
        outputs=output_text,
    )

if __name__ == "__main__":
    demo.launch()