# Hugging Face Spaces status at capture time: Runtime error
import gradio as gr | |
import torch | |
from transformers import ( | |
AutoModelForSeq2SeqLM, | |
AutoTokenizer, | |
) | |
from IndicTransToolkit import IndicProcessor | |
import os | |
import subprocess | |
# Function to clone the repository and set up the environment | |
def setup_repo():
    """Clone IndicTrans2 and run its Hugging Face install script once per process.

    The repository is cloned into ``IndicTrans2`` under the current working
    directory when that directory does not exist yet, and then
    ``huggingface_interface/install.sh`` is executed with ``bash``.

    The working directory of this process is left untouched: the script is
    run with ``cwd=`` instead of ``os.chdir``. Changing directory with a
    relative path made every later call look for ``IndicTrans2`` inside the
    previous checkout and clone it again one level deeper.

    Raises:
        subprocess.CalledProcessError: if ``git clone`` fails.
    """
    # Repeated calls within one process are no-ops once setup has completed.
    if getattr(setup_repo, "_completed", False):
        return

    repo_url = "https://github.com/AI4Bharat/IndicTrans2"
    repo_dir = "IndicTrans2"
    interface_dir = os.path.join(repo_dir, "huggingface_interface")

    if not os.path.exists(repo_dir):
        # Without a checkout nothing below can work, so fail loudly here.
        subprocess.run(["git", "clone", repo_url], check=True)

    # ``source`` is a shell builtin, and with ``shell=True`` plus a list only
    # the first element is used as the command, so the old
    # ``["source", "install.sh"]`` form never ran the script. Invoke bash
    # explicitly instead.
    result = subprocess.run(["bash", "install.sh"], cwd=interface_dir)
    if result.returncode != 0:
        # Best effort: report the failure but do not abort the caller.
        print(f"Warning: install.sh exited with status {result.returncode}")

    setup_repo._completed = True
# Function to process translation | |
# Tokenizer/model pairs keyed by model name, so the checkpoint is loaded once
# per process instead of once per request.
_TRANSLATION_RESOURCES = {}


def _load_translation_resources(model_name):
    """Return the cached ``(tokenizer, model, device)`` triple for *model_name*."""
    if model_name not in _TRANSLATION_RESOURCES:
        setup_repo()  # Ensure the repo is set up (first load only)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
        # The inputs are moved to ``device`` below, so the model must live there too.
        model = model.to(device)
        _TRANSLATION_RESOURCES[model_name] = (tokenizer, model, device)
    return _TRANSLATION_RESOURCES[model_name]


def translate(input_text, src_lang, tgt_lang):
    """Translate *input_text* from *src_lang* to *tgt_lang* with IndicTrans2.

    Args:
        input_text: Text to translate. Empty or whitespace-only input
            returns ``""`` without loading the model.
        src_lang: Source language, either a display name from the
            module-level ``languages`` list (e.g. ``"Hindi"``) or a language
            code (e.g. ``"hin_Deva"``).
        tgt_lang: Target language, in the same two accepted forms.

    Returns:
        The translated string.

    Raises:
        ValueError: if text is given but a language is not selected.
    """
    if not input_text or not input_text.strip():
        return ""
    if not src_lang or not tgt_lang:
        raise ValueError("Both a source and a target language must be selected.")

    # The UI dropdowns deliver display names; the processor needs codes.
    # Values that are not display names are passed through as codes.
    name_to_code = dict(languages)
    src_code = name_to_code.get(src_lang, src_lang)
    tgt_code = name_to_code.get(tgt_lang, tgt_lang)

    tokenizer, model, device = _load_translation_resources(
        "ai4bharat/indictrans2-indic-indic-1B"
    )

    # A fresh processor per call, as before.
    # NOTE(review): presumably it keeps state between preprocess and postprocess - confirm.
    ip = IndicProcessor(inference=True)
    batch = ip.preprocess_batch([input_text], src_lang=src_code, tgt_lang=tgt_code)

    inputs = tokenizer(
        batch,
        truncation=True,
        padding="longest",
        return_tensors="pt",
        return_attention_mask=True,
    ).to(device)

    with torch.no_grad():
        generated_tokens = model.generate(
            **inputs,
            use_cache=True,
            min_length=0,
            max_length=256,
            num_beams=5,
            num_return_sequences=1,
        )

    # NOTE(review): as_target_tokenizer() is deprecated in recent transformers
    # releases; kept because the original decode path relies on it.
    with tokenizer.as_target_tokenizer():
        decoded = tokenizer.batch_decode(
            generated_tokens.detach().cpu().tolist(),
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )

    # Post-process the decoded text for the target language; the original skipped this step.
    return ip.postprocess_batch(decoded, lang=tgt_code)[0]
# List of languages with their code names | |
languages = [
    # (display name, language code) pairs; list order is the dropdown order.
    ("Assamese", "asm_Beng"),
    ("Kashmiri (Arabic)", "kas_Arab"),
    ("Punjabi", "pan_Guru"),
    ("Bengali", "ben_Beng"),
    ("Kashmiri (Devanagari)", "kas_Deva"),
    ("Sanskrit", "san_Deva"),
    ("Bodo", "brx_Deva"),
    ("Maithili", "mai_Deva"),
    ("Santali", "sat_Olck"),
    ("Dogri", "doi_Deva"),
    ("Malayalam", "mal_Mlym"),
    ("Sindhi (Arabic)", "snd_Arab"),
    ("English", "eng_Latn"),
    ("Marathi", "mar_Deva"),
    ("Sindhi (Devanagari)", "snd_Deva"),
    ("Konkani", "gom_Deva"),
    ("Manipuri (Bengali)", "mni_Beng"),
    ("Tamil", "tam_Taml"),
    ("Gujarati", "guj_Gujr"),
    ("Manipuri (Meitei)", "mni_Mtei"),
    ("Telugu", "tel_Telu"),
    ("Hindi", "hin_Deva"),
    ("Nepali", "npi_Deva"),
    ("Urdu", "urd_Arab"),
    ("Kannada", "kan_Knda"),
    ("Odia", "ory_Orya"),
]
# Gradio interface | |
# Display names offered by both language dropdowns, in list order.
language_names = [display_name for display_name, _code in languages]

# NOTE(review): the original nesting was lost in extraction; the layout below
# (inputs and button in one column, output beside it) is the assumed
# structure - confirm against the deployed app.
with gr.Blocks() as demo:
    gr.Markdown("# IndicTrans2 Translation")

    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(label="Input Text")
            src_lang = gr.Dropdown(
                label="Source Language",
                choices=list(language_names),
                type="value",
            )
            tgt_lang = gr.Dropdown(
                label="Target Language",
                choices=list(language_names),
                type="value",
            )
            translate_button = gr.Button("Translate")
        output_text = gr.Textbox(label="Translated Output")

    # Wire the button to the translation function.
    translate_button.click(
        fn=translate,
        inputs=[input_text, src_lang, tgt_lang],
        outputs=output_text,
    )

demo.launch()