viethoangtranduong committed
Commit 774805b · 1 Parent(s): bde8fae

Update handler.py

Files changed (1): handler.py (+45 -3)
handler.py CHANGED
@@ -1,8 +1,50 @@
+# import torch
+# from typing import Dict, List, Any
+# from transformers import AutoTokenizer, AutoModelForCausalLM
+
+# class EndpointHandler:
+#     def __init__(self, path: str = ""):
+
+#         self.tokenizer = AutoTokenizer.from_pretrained(path, padding_side = "left")
+#         self.model = AutoModelForCausalLM.from_pretrained(path, device_map = "auto", torch_dtype=torch.float16)
+
+#     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+#         """
+#         Args:
+#             data (:obj:):
+#                 includes the input data and the parameters for the inference.
+#         Return:
+#             A :obj:`list`:. The list contains the answer and scores of the inference inputs
+#         """
+
+#         # process input
+#         inputs_dict = data.pop("inputs", data)
+#         parameters = data.pop("parameters", {})
+
+#         prompts = [f"<human>: {prompt}\n<bot>:" for prompt in inputs_dict]
+
+#         self.tokenizer.pad_token = self.tokenizer.eos_token
+
+#         inputs = self.tokenizer(prompts, truncation=True, max_length=2048-512,
+#                                 return_tensors='pt', padding=True).to(self.model.device)
+#         input_length = inputs.input_ids.shape[1]
+
+#         if parameters.get("deterministic", False):
+#             torch.manual_seed(42)
+
+#         outputs = self.model.generate(
+#             **inputs, max_new_tokens=512, do_sample=True, temperature=0.7, top_p=0.7, top_k=50
+#         )
+
+#         output_strs = self.tokenizer.batch_decode(outputs[:, input_length:], skip_special_tokens=True)
+
+#         return {"generated_text": output_strs}
+
 import torch
 from typing import Dict, List, Any
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
-class EndpointHandler:
+class EndpointHandler():
     def __init__(self, path: str = ""):
 
         self.tokenizer = AutoTokenizer.from_pretrained(path, padding_side = "left")
@@ -18,10 +60,10 @@ class EndpointHandler:
         """
 
         # process input
-        inputs_dict = data.pop("inputs", data)
+        inputs_list = data.pop("inputs", data)
         parameters = data.pop("parameters", {})
 
-        prompts = [f"<human>: {prompt}\n<bot>:" for prompt in inputs_dict]
+        prompts = [f"<human>: {prompt}\n<bot>:" for prompt in inputs_list]
 
         self.tokenizer.pad_token = self.tokenizer.eos_token
 
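For reference, a minimal local smoke test of the updated handler could look like the sketch below. It is not part of the commit: the checkpoint directory ./model and the example prompts are placeholders, but the payload shape follows the {"inputs": [...], "parameters": {...}} structure that __call__ pops from data, and "generated_text" is the key the handler returns.

from handler import EndpointHandler

# Hypothetical local checkpoint directory; any causal LM trained on the
# "<human>: ...\n<bot>:" chat format would fit the prompt template above.
endpoint = EndpointHandler(path="./model")

payload = {
    "inputs": [  # a list of raw prompts, matching the new `inputs_list` name
        "What is machine learning?",
        "Explain beam search in one paragraph.",
    ],
    # "deterministic": True makes the handler call torch.manual_seed(42)
    # before sampling, so repeated requests return the same completions.
    "parameters": {"deterministic": True},
}

result = endpoint(payload)
print(result["generated_text"])  # one decoded completion per input prompt

Note the token budget in the handler itself: prompts are truncated to max_length=2048-512, so that together with max_new_tokens=512 the prompt plus completion stays within a 2048-token context window.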