Itamarl committed
Commit 3ac4d0b · 1 Parent(s): fc14f0e

Update handler.py

Files changed (1)
  1. handler.py +18 -6
handler.py CHANGED
@@ -20,21 +20,33 @@ class EndpointHandler():
 
         self.tokenizer = transformers.AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
         print("tokenizer created ", datetime.now())
+
+
+        stop_token_ids = self.tokenizer.convert_tokens_to_ids(["<|endoftext|>"])
+
+        class StopOnTokens(StoppingCriteria):
+            def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
+                for stop_id in stop_token_ids:
+                    if input_ids[0][-1] == stop_id:
+                        return True
+                return False
+
+        stopping_criteria = StoppingCriteriaList([StopOnTokens()])
+
         self.generate_text = transformers.pipeline(
             model=self.model,
             tokenizer=self.tokenizer,
+            stopping_criteria=stopping_criteria,
             task='text-generation',
             return_full_text=True,
             temperature=0.1,
             top_p=0.15,
             top_k=0,
-            # max_new_tokens=64,
+            max_new_tokens=64,
             repetition_penalty=1.1
         )
 
-    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
-        print("iiiiiiiiii ", data)
-        inputs = data.pop("inputs ", data)
-        print(inputs)
-        res = self.generate_text("Explain to me the difference between nuclear fission and fusion." , max_length= 60)
+
+    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+        res = self.generate_text("Explain to me the difference between nuclear fission and fusion.")
         return res
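
For reference, the stopping-criteria pattern introduced in this diff can be tried outside the endpoint. The sketch below is a minimal, self-contained approximation and is not part of this commit: the small 'gpt2' checkpoint and a direct model.generate() call are assumptions made purely for illustration, standing in for the gpt-neox-20b pipeline that handler.py builds.

# Minimal sketch of the StopOnTokens pattern used in this commit.
# Assumptions for illustration only: 'gpt2' and model.generate()
# replace the gpt-neox-20b pipeline configured in handler.py.
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Token ids that should end generation as soon as one of them is emitted.
stop_token_ids = tokenizer.convert_tokens_to_ids(["<|endoftext|>"])


class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Stop when the most recently generated token is one of the stop ids.
        return input_ids[0][-1].item() in stop_token_ids


inputs = tokenizer(
    "Explain to me the difference between nuclear fission and fusion.",
    return_tensors="pt",
)
output_ids = model.generate(
    **inputs,
    max_new_tokens=64,
    stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

Binding the criterion to the pipeline at construction time, as the handler does, applies it to every request; passing it per generate() call, as in the sketch, is the equivalent lower-level form.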