File size: 2,345 Bytes
73208f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from IndicTransToolkit.processor import IndicProcessor
class _LanguagePairError(ValueError, AssertionError):
    """Raised when a requested translation direction cannot be served.

    Derives from ``AssertionError`` in addition to ``ValueError`` so callers
    written against the earlier ``assert``-based check keep working.
    """


class IndicTrans:
    """Translate single strings using three direction-specific seq2seq models.

    One tokenizer/model pair is loaded per direction: English -> Indic,
    Indic -> English, and Indic -> Indic. ``translate`` picks the pair based
    on whether the source or the target language is ``"eng_Latn"``.
    """

    def __init__(self, all_lang: list[str], en2indic, indic2en, indic2indic):
        """Load the pre/post-processor and all three tokenizer/model pairs.

        Args:
            all_lang: Language codes accepted by ``translate``
                (for example ``"eng_Latn"``).
            en2indic: Checkpoint name or path handed to ``from_pretrained``
                for the English -> Indic direction.
            indic2en: Same, for the Indic -> English direction.
            indic2indic: Same, for the Indic -> Indic direction.
        """
        self.all_lang = all_lang
        self.ip = IndicProcessor(inference=True)
        # SECURITY NOTE: trust_remote_code=True runs Python code shipped with
        # the checkpoint repository; only pass checkpoints you trust.
        self.indictrans_en2indic_tokenizer = AutoTokenizer.from_pretrained(en2indic, trust_remote_code=True)
        self.indictrans_en2indic_model = AutoModelForSeq2SeqLM.from_pretrained(en2indic, trust_remote_code=True)
        self.indictrans_indic2en_tokenizer = AutoTokenizer.from_pretrained(indic2en, trust_remote_code=True)
        self.indictrans_indic2en_model = AutoModelForSeq2SeqLM.from_pretrained(indic2en, trust_remote_code=True)
        self.indictrans_indic2indic_tokenizer = AutoTokenizer.from_pretrained(indic2indic, trust_remote_code=True)
        self.indictrans_indic2indic_model = AutoModelForSeq2SeqLM.from_pretrained(indic2indic, trust_remote_code=True)

    def _translate(self, model, tokenizer, input_list: list[str], source_lang: str, target_lang: str) -> list[str]:
        """Translate a batch of sentences with the given model/tokenizer pair.

        Args:
            model: Seq2seq model used for generation.
            tokenizer: Tokenizer matching ``model``.
            input_list: Sentences to translate.
            source_lang: Language code of the inputs.
            target_lang: Language code to translate into.

        Returns:
            The post-processed translations; ``[]`` for an empty batch.
        """
        if not input_list:
            # Nothing to encode; skip the tokenizer and model entirely.
            return []

        prepared = self.ip.preprocess_batch(
            input_list, src_lang=source_lang, tgt_lang=target_lang, visualize=False
        )
        # Inputs longer than 256 tokens are truncated.
        encoded = tokenizer(
            prepared,
            padding="longest",
            truncation=True,
            max_length=256,
            return_tensors="pt",
        )
        # Keep inputs on the model's device so a model that was moved (e.g. to
        # GPU) after construction still works; a no-op when both are on CPU.
        encoded = encoded.to(model.device)

        with torch.inference_mode():
            generated = model.generate(
                **encoded, num_beams=5, num_return_sequences=1, max_length=256
            )

        # NOTE(review): presumably switches the tokenizer to its target-side
        # vocabulary for decoding. Deprecated in recent transformers releases;
        # confirm against the checkpoint's tokenizer code before removing.
        with tokenizer.as_target_tokenizer():
            decoded = tokenizer.batch_decode(
                generated, skip_special_tokens=True, clean_up_tokenization_spaces=True
            )
        return self.ip.postprocess_batch(decoded, lang=target_lang)

    def translate(self, input: str, source_lang: str, target_lang: str) -> str:
        """Translate one string from ``source_lang`` into ``target_lang``.

        Args:
            input: Text to translate.
            source_lang: Language code of ``input``; must be in ``all_lang``.
            target_lang: Language code to translate into; must be in
                ``all_lang`` and differ from ``source_lang``.

        Returns:
            The translated text.

        Raises:
            ValueError: If the two languages are identical or either one is
                not in ``all_lang``. The raised exception is also an
                ``AssertionError`` for backward compatibility.
        """
        # Explicit raises instead of ``assert`` so the checks survive
        # ``python -O`` and report which argument was rejected.
        if source_lang == target_lang:
            raise _LanguagePairError(
                f"source_lang and target_lang must differ, got {source_lang!r} for both"
            )
        for role, lang in (("source_lang", source_lang), ("target_lang", target_lang)):
            if lang not in self.all_lang:
                raise _LanguagePairError(
                    f"{role} {lang!r} is not one of the supported languages"
                )

        if source_lang == "eng_Latn":
            model = self.indictrans_en2indic_model
            tokenizer = self.indictrans_en2indic_tokenizer
        elif target_lang == "eng_Latn":
            model = self.indictrans_indic2en_model
            tokenizer = self.indictrans_indic2en_tokenizer
        else:
            model = self.indictrans_indic2indic_model
            tokenizer = self.indictrans_indic2indic_tokenizer
        return self._translate(model, tokenizer, [input], source_lang, target_lang)[0]