# Indic_Translation / utils_trans.py
# shethjenil's picture
# Upload 6 files
# 73208f7 verified
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from IndicTransToolkit.processor import IndicProcessor
class IndicTrans:
    """Translate single strings between languages using three seq2seq checkpoints.

    One tokenizer/model pair is loaded per direction: English -> Indic,
    Indic -> English and Indic -> Indic.  ``translate`` chooses the pair
    from the language codes it receives; ``"eng_Latn"`` is the code that
    marks the English side.
    """

    # Code that selects the English-side checkpoints in ``translate``.
    _ENGLISH = "eng_Latn"
    # Token limit applied both when truncating the input and when generating.
    _MAX_LENGTH = 256
    # Beam width passed to ``model.generate``.
    _NUM_BEAMS = 5

    def __init__(self, all_lang: list[str], en2indic, indic2en, indic2indic):
        """Load the processor and the three tokenizer/model pairs.

        Args:
            all_lang: Language codes that ``translate`` accepts.
            en2indic: Checkpoint identifier handed to ``from_pretrained``
                for the English -> Indic direction.
            indic2en: Checkpoint identifier for the Indic -> English direction.
            indic2indic: Checkpoint identifier for the Indic -> Indic direction.

        Note:
            Every checkpoint is loaded with ``trust_remote_code=True``, so
            code shipped with the checkpoint is executed; only pass
            identifiers you trust.
        """
        self.all_lang = all_lang
        self.ip = IndicProcessor(inference=True)
        self.indictrans_en2indic_tokenizer = AutoTokenizer.from_pretrained(en2indic, trust_remote_code=True)
        self.indictrans_en2indic_model = AutoModelForSeq2SeqLM.from_pretrained(en2indic, trust_remote_code=True)
        self.indictrans_indic2en_tokenizer = AutoTokenizer.from_pretrained(indic2en, trust_remote_code=True)
        self.indictrans_indic2en_model = AutoModelForSeq2SeqLM.from_pretrained(indic2en, trust_remote_code=True)
        self.indictrans_indic2indic_tokenizer = AutoTokenizer.from_pretrained(indic2indic, trust_remote_code=True)
        self.indictrans_indic2indic_model = AutoModelForSeq2SeqLM.from_pretrained(indic2indic, trust_remote_code=True)

    def _translate(self, model, tokenizer, input_list: list[str], source_lang: str, target_lang: str) -> list[str]:
        """Run one batch through preprocess -> tokenize -> generate -> decode -> postprocess.

        Args:
            model: Seq2seq model used for generation.
            tokenizer: Tokenizer matching ``model``.
            input_list: Sentences to translate.
            source_lang: Language code of the inputs.
            target_lang: Language code to translate into.

        Returns:
            The output of ``IndicProcessor.postprocess_batch`` for the
            decoded generations.
        """
        with torch.inference_mode():
            prepared = self.ip.preprocess_batch(
                input_list,
                src_lang=source_lang,
                tgt_lang=target_lang,
                visualize=False,
            )
            encoded = tokenizer(
                prepared,
                padding="longest",
                truncation=True,
                max_length=self._MAX_LENGTH,
                return_tensors="pt",
            )
            # Generation needs inputs on the same device as the weights.  This
            # is a no-op for CPU models and keeps the call working if the
            # caller has moved a model to an accelerator.
            encoded = encoded.to(model.device)
            generated = model.generate(
                **encoded,
                num_beams=self._NUM_BEAMS,
                num_return_sequences=1,
                max_length=self._MAX_LENGTH,
            )
        # Decoding stays inside the tokenizer's target-side context, exactly
        # as before.  NOTE(review): `as_target_tokenizer` is marked deprecated
        # in recent transformers releases - confirm against the pinned version.
        with tokenizer.as_target_tokenizer():
            decoded = tokenizer.batch_decode(
                generated,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )
        return self.ip.postprocess_batch(decoded, lang=target_lang)

    def _check_language_pair(self, source_lang: str, target_lang: str) -> None:
        """Reject identical or unsupported language codes.

        Raises:
            AssertionError: Raised explicitly (not via ``assert``) so the
                check survives ``python -O`` while callers that already
                catch ``AssertionError`` keep working.
        """
        if source_lang == target_lang:
            raise AssertionError(
                f"source_lang and target_lang must differ, got {source_lang!r} for both"
            )
        unsupported = [code for code in (source_lang, target_lang) if code not in self.all_lang]
        if unsupported:
            raise AssertionError(
                f"unsupported language code(s) {unsupported!r}; expected one of {list(self.all_lang)!r}"
            )

    def translate(self, input: str, source_lang: str, target_lang: str) -> str:
        """Translate one string from ``source_lang`` to ``target_lang``.

        Args:
            input: Text to translate.  (The name shadows the builtin but is
                kept so keyword callers are not broken.)
            source_lang: Language code of ``input``; must be in ``all_lang``.
            target_lang: Language code to translate into; must be in
                ``all_lang`` and different from ``source_lang``.

        Returns:
            The first (and only) element of the translated batch.

        Raises:
            AssertionError: If the codes are equal or not in ``all_lang``.
        """
        self._check_language_pair(source_lang, target_lang)

        # Pick the checkpoint pair for this direction, then translate once.
        if source_lang == self._ENGLISH:
            model = self.indictrans_en2indic_model
            tokenizer = self.indictrans_en2indic_tokenizer
        elif target_lang == self._ENGLISH:
            model = self.indictrans_indic2en_model
            tokenizer = self.indictrans_indic2en_tokenizer
        else:
            model = self.indictrans_indic2indic_model
            tokenizer = self.indictrans_indic2indic_tokenizer

        return self._translate(model, tokenizer, [input], source_lang, target_lang)[0]