# Source: kusht55's Hugging Face Space, commit 73836a5 ("Update app.py", verified).
import gradio as gr
import torch
from transformers import (
AutoModelForSeq2SeqLM,
AutoTokenizer,
)
from IndicTransToolkit import IndicProcessor
import os
import subprocess
# Set once setup_repo() has run in this process, so repeated calls (it is invoked
# on every translation request) do not re-clone or re-run the install script.
_REPO_SETUP_DONE = False


def setup_repo(repo_url="https://github.com/AI4Bharat/IndicTrans2", repo_dir="IndicTrans2"):
    """Clone the IndicTrans2 repository if needed and run its install script once.

    Args:
        repo_url: Git URL to clone.
        repo_dir: Directory to clone into. It is resolved to an absolute path,
            and the process working directory is never changed, so repeated
            calls cannot nest clones inside each other.

    Raises:
        subprocess.CalledProcessError: if ``git clone`` fails.
    """
    global _REPO_SETUP_DONE
    if _REPO_SETUP_DONE:
        return

    repo_path = os.path.abspath(repo_dir)
    if not os.path.exists(repo_path):
        # Fail loudly here: nothing after this point can work without the checkout.
        subprocess.run(["git", "clone", repo_url, repo_path], check=True)

    # `source` is a bash builtin, and passing a list with shell=True drops every
    # element after the first. Invoke bash explicitly instead, running inside the
    # interface directory via cwd= rather than os.chdir().
    interface_dir = os.path.join(repo_path, "huggingface_interface")
    result = subprocess.run(["bash", "install.sh"], cwd=interface_dir, check=False)
    if result.returncode != 0:
        # Best-effort, as before: report the failure but do not abort the request.
        import logging

        logging.getLogger(__name__).warning(
            "install.sh in %s exited with status %d", interface_dir, result.returncode
        )

    _REPO_SETUP_DONE = True
# IndicTrans2 ships one checkpoint per translation direction; English must be
# routed to the en-indic / indic-en models rather than the indic-indic one.
_EN_INDIC_MODEL = "ai4bharat/indictrans2-en-indic-1B"
_INDIC_EN_MODEL = "ai4bharat/indictrans2-indic-en-1B"
_INDIC_INDIC_MODEL = "ai4bharat/indictrans2-indic-indic-1B"
_ENGLISH_CODE = "eng_Latn"

_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# (tokenizer, model) pairs keyed by checkpoint name, so each 1B-parameter model
# is loaded once per process instead of once per request.
_MODEL_CACHE = {}


def _select_model_name(src_lang, tgt_lang):
    """Return the IndicTrans2 checkpoint that handles src_lang -> tgt_lang."""
    if src_lang == _ENGLISH_CODE:
        return _EN_INDIC_MODEL
    if tgt_lang == _ENGLISH_CODE:
        return _INDIC_EN_MODEL
    return _INDIC_INDIC_MODEL


def _load_model(model_name):
    """Return a cached (tokenizer, model) pair, with the model already on _DEVICE."""
    cached = _MODEL_CACHE.get(model_name)
    if cached is None:
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        # The model must live on the same device as the tokenized inputs,
        # otherwise generate() fails on CUDA hosts.
        model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name, trust_remote_code=True
        ).to(_DEVICE)
        cached = (tokenizer, model)
        _MODEL_CACHE[model_name] = cached
    return cached


def translate(input_text, src_lang, tgt_lang):
    """Translate one string between two IndicTrans2 language codes.

    Args:
        input_text: Text to translate. Empty or whitespace-only input returns "".
        src_lang: Source language code, e.g. "hin_Deva" or "eng_Latn".
        tgt_lang: Target language code, e.g. "tam_Taml".

    Returns:
        The translated text. Generation is capped at 256 tokens (max_length), so
        very long inputs may be cut short.
    """
    setup_repo()  # Ensure the repo is set up
    if not input_text or not input_text.strip():
        return ""

    tokenizer, model = _load_model(_select_model_name(src_lang, tgt_lang))
    ip = IndicProcessor(inference=True)
    batch = ip.preprocess_batch([input_text], src_lang=src_lang, tgt_lang=tgt_lang)

    inputs = tokenizer(
        batch,
        truncation=True,
        padding="longest",
        return_tensors="pt",
        return_attention_mask=True,
    ).to(_DEVICE)

    with torch.no_grad():
        generated_tokens = model.generate(
            **inputs,
            use_cache=True,
            min_length=0,
            max_length=256,
            num_beams=5,
            num_return_sequences=1,
        )

    # Decode using the tokenizer's target-side mode.
    with tokenizer.as_target_tokenizer():
        decoded = tokenizer.batch_decode(
            generated_tokens.detach().cpu().tolist(),
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )

    # Undo IndicProcessor's preprocessing (script conversion, detokenization)
    # so the result is in the target language's own script.
    return ip.postprocess_batch(decoded, lang=tgt_lang)[0]
# (display name, IndicTrans2 language code) pairs shown in the language dropdowns,
# sorted alphabetically by display name so the dropdown is easy to scan.
languages = [
    ("Assamese", "asm_Beng"),
    ("Bengali", "ben_Beng"),
    ("Bodo", "brx_Deva"),
    ("Dogri", "doi_Deva"),
    ("English", "eng_Latn"),
    ("Gujarati", "guj_Gujr"),
    ("Hindi", "hin_Deva"),
    ("Kannada", "kan_Knda"),
    ("Kashmiri (Arabic)", "kas_Arab"),
    ("Kashmiri (Devanagari)", "kas_Deva"),
    ("Konkani", "gom_Deva"),
    ("Maithili", "mai_Deva"),
    ("Malayalam", "mal_Mlym"),
    ("Manipuri (Bengali)", "mni_Beng"),
    ("Manipuri (Meitei)", "mni_Mtei"),
    ("Marathi", "mar_Deva"),
    ("Nepali", "npi_Deva"),
    ("Odia", "ory_Orya"),
    ("Punjabi", "pan_Guru"),
    ("Sanskrit", "san_Deva"),
    ("Santali", "sat_Olck"),
    ("Sindhi (Arabic)", "snd_Arab"),
    ("Sindhi (Devanagari)", "snd_Deva"),
    ("Tamil", "tam_Taml"),
    ("Telugu", "tel_Telu"),
    ("Urdu", "urd_Arab"),
]
# The dropdowns show human-readable names; translate() needs IndicTrans2 codes.
_LANGUAGE_CODES = dict(languages)
_LANGUAGE_NAMES = [name for name, _ in languages]


def _translate_by_name(input_text, src_name, tgt_name):
    """Gradio callback: map dropdown display names to language codes, then translate.

    Raises:
        gr.Error: if either dropdown is empty or holds an unknown name, so the
            user sees a message instead of a stack trace.
    """
    src_code = _LANGUAGE_CODES.get(src_name)
    tgt_code = _LANGUAGE_CODES.get(tgt_name)
    if src_code is None or tgt_code is None:
        raise gr.Error("Please select both a source and a target language.")
    return translate(input_text, src_code, tgt_code)


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# IndicTrans2 Translation")
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(label="Input Text")
            src_lang = gr.Dropdown(label="Source Language", choices=_LANGUAGE_NAMES, type="value")
            tgt_lang = gr.Dropdown(label="Target Language", choices=_LANGUAGE_NAMES, type="value")
            translate_button = gr.Button("Translate")
        output_text = gr.Textbox(label="Translated Output")
    # Call the name-to-code adapter when the button is clicked.
    translate_button.click(
        fn=_translate_by_name,
        inputs=[input_text, src_lang, tgt_lang],
        outputs=output_text,
    )

demo.launch()