File size: 827 Bytes
366f016
bf2292e
366f016
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import time
from typing import Any, Dict, List

import torch
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer

#


class EndpointHandler:
    """Callable request handler that generates text with an 8-bit model.

    The tokenizer and model are loaded once at construction time; each call
    tokenizes the request inputs, runs ``generate`` and decodes the result.
    """

    def __init__(self, path=''):
        """Load the tokenizer and the 8-bit quantized model from *path*.

        Args:
            path: Directory or model identifier handed to ``from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # ``generate`` needs a language-modelling head; the bare ``AutoModel``
        # backbone does not have one and is rejected by ``generate``.
        # NOTE(review): assumes a decoder-only (causal LM) checkpoint -- switch
        # to ``AutoModelForSeq2SeqLM`` if this serves an encoder-decoder model.
        self.model = AutoModelForCausalLM.from_pretrained(path, load_in_8bit=True)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate text for one request payload.

        Args:
            data: Request payload. ``data['inputs']`` holds the prompt(s); when
                the key is missing the payload itself is used as the input.
                ``data['parameters']`` (optional) holds keyword arguments that
                are forwarded to ``generate``. Both keys are popped from
                *data*.

        Returns:
            A one-element list containing a dict with ``'generated_text'``
            (the list returned by ``batch_decode``) and ``'generation_time'``
            (elapsed seconds covering tokenization, generation and decoding).
        """
        inputs = data.pop('inputs', data)
        # An explicit ``"parameters": null`` must not break the ``**`` expansion.
        parameters = data.pop('parameters', None) or {}

        # perf_counter is monotonic, so the measured interval cannot be skewed
        # by system clock adjustments.
        started = time.perf_counter()

        # Follow the model's own device instead of hard-coding 'cuda'.
        encoded = self.tokenizer(inputs, return_tensors='pt').to(self.model.device)
        # Pass the tensors by keyword: ``generate`` expects tensors, not the
        # tokenizer's container object as a single positional argument. Only
        # the two fields generation uses are forwarded, so extra tokenizer
        # outputs are never handed to ``generate`` as unused kwargs.
        generated = self.model.generate(
            input_ids=encoded['input_ids'],
            attention_mask=encoded.get('attention_mask'),
            **parameters,
        ).to('cpu')
        texts = self.tokenizer.batch_decode(generated)

        elapsed = time.perf_counter() - started

        return [{'generated_text': texts, 'generation_time': elapsed}]