MultiTrickFox committed on
Commit
366f016
·
verified ·
1 Parent(s): cc33cb3

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +27 -0
handler.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from time import time
3
+
4
+ import torch
5
+ from transformers import AutoTokenizer, AutoModel
6
+
7
+ #
8
+
9
+
10
class EndpointHandler:
    """Callable request handler that runs text generation with a model.

    The tokenizer and model are loaded once at construction time; each
    call tokenizes the request's inputs, runs ``generate`` and returns the
    decoded text together with the elapsed wall-clock time.
    """

    def __init__(self, path=''):
        """Load the tokenizer and the 8-bit quantized model from *path*.

        Args:
            path: Location handed to ``from_pretrained`` for both the
                tokenizer and the model.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # NOTE(review): ``AutoModel`` normally loads the bare model without a
        # generation head; if ``generate`` fails for this checkpoint, a
        # task-specific auto class is probably needed -- confirm against the
        # deployed model.
        self.model = AutoModel.from_pretrained(path, load_in_8bit=True)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate text for one request payload.

        Args:
            data: Request payload. ``data['inputs']`` holds the prompt(s);
                when that key is absent the payload itself is used as the
                input. ``data['parameters']`` (optional) holds keyword
                arguments forwarded to ``model.generate``. Both keys are
                popped from *data*.

        Returns:
            A one-element list containing a dict with ``'generated_text'``
            (the list returned by ``batch_decode``) and
            ``'generation_time'`` (elapsed seconds as a float).
        """
        inputs = data.pop('inputs', data)
        # Tolerate an explicit ``"parameters": null`` in the payload.
        parameters = data.pop('parameters', None) or {}

        # ``time`` is the function imported via ``from time import time``,
        # so it is called directly rather than as ``time.time()``.
        starting_time = time()

        # NOTE(review): the device is hard-coded to 'cuda', as in the
        # original; presumably the 8-bit model always lives on the GPU.
        tokenized = self.tokenizer(inputs, return_tensors='pt').to('cuda')
        # Unpack the encoding so input_ids / attention_mask reach
        # ``generate`` as keyword arguments instead of a single opaque
        # positional object.
        out = self.model.generate(**tokenized, **parameters).to('cpu')
        detokenized = self.tokenizer.batch_decode(out)

        ending_time = time()

        return [{
            'generated_text': detokenized,
            'generation_time': ending_time - starting_time,
        }]