chatcell-base / handler.py
Linear-Matrix-Probability's picture
Create handler.py
eee337d verified
raw
history blame
1.06 kB
from typing import Any, Dict, List

from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
)
# Default keyword arguments handed to ``model.generate`` for every request:
# sampling is enabled (top-k / nucleus), repeated bigrams are blocked, and a
# single sequence of at most 512 tokens is produced.
# NOTE(review): originally described as matching the model's default config;
# that cannot be confirmed from this file.
CONFIG = dict(
    max_length=512,
    num_return_sequences=1,
    no_repeat_ngram_size=2,
    top_k=50,
    top_p=0.95,
    do_sample=True,
)
class EndpointHandler:
    """Text-to-text generation handler backed by a seq2seq checkpoint.

    The tokenizer and ``AutoModelForSeq2SeqLM`` are loaded once from ``path``
    at construction; each call to the instance serves one request payload.
    """

    def __init__(self, path: str = ""):
        """Load the tokenizer and model.

        Args:
            path: Directory or identifier passed to ``from_pretrained`` for
                both the tokenizer and the model.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """Generate text for one request payload.

        Args:
            data: Request payload. ``data["inputs"]`` holds a prompt string
                or a list of prompt strings. An optional ``data["parameters"]``
                dict overrides entries of ``CONFIG`` for this request only.
                ``data`` itself is not modified.

        Returns:
            One ``{"generated_text": ...}`` dict per sequence returned by
            ``generate``, in that order. With the default ``CONFIG`` this is
            one dict per prompt. A missing or empty ``inputs`` yields
            ``[{"generated_text": "No input provided"}]``.
        """
        # Read instead of pop so the caller's payload is left untouched.
        inputs = data.get("inputs")
        # Covers None, "" and an empty list of prompts.
        if not inputs:
            return [{"generated_text": "No input provided"}]

        # Per-request overrides layered over the module defaults; CONFIG
        # itself is never mutated.
        gen_kwargs = {**CONFIG, **(data.get("parameters") or {})}

        # preprocess: padding is a no-op for a single prompt and lets a list
        # of prompts with different lengths be stacked into one tensor.
        encoded = self.tokenizer(inputs, return_tensors="pt", padding=True)

        # inference: forward the attention mask so padded positions are
        # ignored when a batch of prompts is generated together.
        output_ids = self.model.generate(
            input_ids=encoded["input_ids"],
            attention_mask=encoded.get("attention_mask"),
            **gen_kwargs,
        )

        # postprocess: decode every returned sequence, not just the first,
        # so no prompt in a batch is silently dropped.
        texts = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        return [{"generated_text": text} for text in texts]