# NOTE(review): removed HF file-viewer scrape residue that was not valid Python
# (file-size header, commit hashes, line-number gutter).
from typing import Dict, List, Any
import time
import torch
from transformers import AutoTokenizer, AutoModel
#
class EndpointHandler:
    """Hugging Face Inference Endpoints handler: tokenize -> generate -> decode.

    Loads an 8-bit quantized model from *path* and serves text-generation
    requests of the form ``{"inputs": ..., "parameters": {...}}``.
    """

    def __init__(self, path=''):
        # Load tokenizer and model from the checkpoint directory.
        # load_in_8bit quantizes weights via bitsandbytes and places the
        # model on GPU — presumably a CUDA device is required; TODO confirm.
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModel.from_pretrained(path, load_in_8bit=True)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run one generation request.

        Args:
            data: request payload. ``'inputs'`` is the text (str or list of
                str); optional ``'parameters'`` is a dict of kwargs forwarded
                to ``model.generate``. If ``'inputs'`` is missing, the whole
                payload is treated as the input.

        Returns:
            A one-element list: ``[{'generated_text': <list of decoded
            strings>, 'generation_time': <seconds as float>}]``.
        """
        # Pop 'parameters' BEFORE resolving the 'inputs' fallback: otherwise,
        # when 'inputs' is absent, `inputs` aliases `data` and the pop would
        # mutate the very object we are about to tokenize.
        parameters = data.pop('parameters', {})
        inputs = data.pop('inputs', data)

        starting_time = time.time()
        tokenized = self.tokenizer(inputs, return_tensors='pt')
        # BUG FIX: generate() expects tensor kwargs, not a positional
        # BatchEncoding — unpack it so input_ids AND attention_mask are
        # passed. Also move tensors to wherever the model actually lives
        # instead of hard-coding 'cuda'.
        device = getattr(self.model, 'device', 'cpu')
        out = self.model.generate(**tokenized.to(device), **parameters)
        detokenized = self.tokenizer.batch_decode(out.to('cpu'))
        ending_time = time.time()

        return [{
            'generated_text': detokenized,
            'generation_time': ending_time - starting_time,
        }]