File size: 1,040 Bytes
e276af2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
class Model:
    """Wrapper around a Hugging Face seq2seq checkpoint that returns one generated string.

    The tokenizer is loaded from ``tokenizer_name`` and the model weights from
    ``model_name`` at the requested ``revision``. Both names default to the
    repositories this wrapper was originally written against.
    """

    def __init__(
        self,
        revision,
        tokenizer_name="VietAI/vit5-base",
        model_name="truong-xuan-linh/vit5-reproduce",
    ) -> None:
        """Load the tokenizer and the seq2seq model.

        Args:
            revision: Hub revision (branch, tag or commit) of ``model_name`` to load.
            tokenizer_name: Hub id or local path of the tokenizer.
            model_name: Hub id or local path of the seq2seq model.
        """
        # NOTE(review): tokenizer and model come from different repos; this
        # assumes the fine-tuned checkpoint shares the base tokenizer's
        # vocabulary — confirm.
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name, revision=revision)

    def preprocess_function(self, text, max_length=1024):
        """Tokenize ``text`` into padded, truncated PyTorch tensors.

        Args:
            text: Input string (or list of strings) to encode.
            max_length: Maximum number of input tokens; longer inputs are truncated.

        Returns:
            The tokenizer's ``BatchEncoding`` with ``input_ids`` and
            ``attention_mask`` as ``torch`` tensors.
        """
        return self.tokenizer(
            text,
            max_length=max_length,
            truncation=True,
            padding=True,
            return_tensors="pt",
        )

    def inference(self, text, max_target_length=256):
        """Generate output for ``text`` and return the first decoded sequence.

        Args:
            text: Input string (or list of strings); only the first result is returned.
            max_target_length: ``max_length`` passed to ``generate``.

        Returns:
            The first generated sequence decoded to a string, with special
            tokens removed and tokenization spaces left as-is.
        """
        inputs = self.preprocess_function(text)
        # Keep the input tensors on the model's device so generation still
        # works if the caller moves ``self.model`` (e.g. to a GPU).
        inputs = inputs.to(self.model.device)
        outputs = self.model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=max_target_length,
        )
        # Only the first sequence is returned, so decode just that one.
        # The deprecated ``as_target_tokenizer`` context is dropped: it only
        # affects how targets are encoded, not how ids are decoded.
        return self.tokenizer.decode(
            outputs[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )