LA1512 commited on
Commit
3dd9480
1 Parent(s): 61ee63b

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +28 -0
handler.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from transformers import pipeline, AutoTokenizer, BartForConditionalGeneration
3
+
4
+ class EndpointHandler():
5
+ def __init__(self, path=""):
6
+ self.model = BartForConditionalGeneration.from_pretrained(path)
7
+ self.tokenizer = AutoTokenizer(path)
8
+
9
+
10
+ def __call__(self, data: str) -> str:
11
+ """
12
+ data args:
13
+ inputs (:obj: `str`)
14
+ date (:obj: `str`)
15
+ Return:
16
+ A :obj:`list` | `dict`: will be serialized and returned
17
+ """
18
+ # get inputs
19
+
20
+
21
+ text_tokenized = self.tokenizer(
22
+ [data], padding="max_length", truncation=True, max_length=1024,return_tensors='pt')
23
+
24
+ prediction_token = self.model.generate(text_tokenized["input_ids"], max_length = 256, num_beams = 6)
25
+
26
+ prediction_summary = self.tokenizer.decode(prediction_token[0][2:-1:1])
27
+
28
+ return prediction_summary