|
|
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
|
class Model:
    """Seq2seq wrapper pairing the VietAI/vit5-base tokenizer with a
    reproduced ViT5 checkpoint for text-to-text inference."""

    def __init__(self, revision) -> None:
        """Load the tokenizer and model weights.

        Args:
            revision: Git revision (branch, tag, or commit hash) of the
                "truong-xuan-linh/vit5-reproduce" checkpoint to load.
        """
        # Tokenizer comes from the base ViT5 repo; weights from the
        # reproduced checkpoint pinned to `revision`.
        self.tokenizer = AutoTokenizer.from_pretrained("VietAI/vit5-base")
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            "truong-xuan-linh/vit5-reproduce", revision=revision
        )

    def preprocess_function(self, text, max_length=1024):
        """Tokenize `text` into PyTorch tensors ready for `generate`.

        Args:
            text: A string (or list of strings) to tokenize.
            max_length: Truncation limit in tokens (default 1024,
                matching the previous hard-coded value).

        Returns:
            A BatchEncoding holding `input_ids` and `attention_mask`
            PyTorch tensors.
        """
        inputs = self.tokenizer(
            text,
            max_length=max_length,
            truncation=True,
            padding=True,
            return_tensors="pt",
        )
        return inputs

    def inference(self, text, max_target_length=256):
        """Run generation on `text` and return the decoded output.

        Args:
            text: Input string (or list of strings; only the first
                generation is returned).
            max_target_length: Maximum generated sequence length in
                tokens (default 256, matching the previous hard-coded
                value).

        Returns:
            The decoded generation for the first input, with special
            tokens removed and tokenization spaces left untouched.
        """
        inputs = self.preprocess_function(text)
        outputs = self.model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=max_target_length,
        )
        # `tokenizer.as_target_tokenizer()` is deprecated in transformers;
        # decoding generated ids does not need the target-tokenizer
        # context, so decode directly.
        decoded = self.tokenizer.batch_decode(
            outputs,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        return decoded[0]