binaryaaron committed on
Commit
d646830
·
unverified ·
1 Parent(s): 944bc62

updating handler

Browse files
Files changed (2) hide show
  1. handler.py +11 -14
  2. tester.py +2 -2
handler.py CHANGED
@@ -4,22 +4,19 @@ import torch
4
 
5
  MAX_TOKENS=8192
6
 
7
class EndpointHandler:
    """Serve the SEA-LION v3 instruct model behind an Inference Endpoint."""

    def __init__(self, path=""):
        # The model is fetched by hub id, so the checkout ``path`` supplied by
        # the endpoint runtime is accepted but not used for loading.
        self.pipeline = transformers.pipeline(
            "text-generation",
            model="humane-intelligence/gemma2-9b-cpt-sealionv3-instruct-endpoint",
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map="auto",
        )

    def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
        """Run generation on the endpoint payload and return the raw pipeline output."""
        # The payload is either {"inputs": ...} or the inputs themselves.
        prompt = data.get("inputs", data)
        generated = self.pipeline(prompt, max_new_tokens=256)
        # Echo the last generated message for endpoint-side debugging.
        print(generated[0]["generated_text"][-1])
        return generated
 
4
 
5
  MAX_TOKENS=8192
6
 
7
class EndpointHandler:
    """Hugging Face Inference Endpoints handler around a text-generation pipeline.

    Loads the SEA-LION v3 instruct model once at start-up and serves
    generation requests via ``__call__``.
    """

    def __init__(self, path: str = "") -> None:
        # ``path`` is the model-directory argument the HF custom-handler
        # contract passes at construction time; the model here is loaded by
        # hub id, so the argument is accepted for compatibility but unused.
        self.pipeline: transformers.Pipeline = transformers.pipeline(
            "text-generation",
            model="humane-intelligence/gemma2-9b-cpt-sealionv3-instruct-endpoint",
            model_kwargs={"torch_dtype": torch.bfloat16, "low_cpu_mem_usage": True},
            device_map="auto",
        )

    def __call__(self, text_inputs: Any) -> List[List[Dict[str, Any]]]:
        """Generate text for ``text_inputs`` and return the raw pipeline output.

        Accepts either raw pipeline inputs (a prompt string or a list of chat
        messages) or the Inference Endpoints JSON payload ``{"inputs": ...}``.
        The returned items carry ``generated_text`` strings, hence the
        ``Dict[str, Any]`` (not ``float``) in the annotation.
        """
        # Unwrap the standard endpoint payload if one was passed; anything
        # else is forwarded to the pipeline unchanged.
        if isinstance(text_inputs, dict):
            text_inputs = text_inputs.get("inputs", text_inputs)
        outputs = self.pipeline(
            text_inputs,
            max_new_tokens=MAX_TOKENS,
        )
        # Log the last generated message for debugging on the endpoint.
        print(outputs[0]["generated_text"][-1])
        return outputs
 
 
 
tester.py CHANGED
@@ -2,7 +2,7 @@ from handler import EndpointHandler
2
 
3
  if __name__ == "__main__":
4
  # init handler
5
- my_handler = EndpointHandler(path=".")
6
 
7
  # prepare sample payload
8
  messages = [
@@ -10,7 +10,7 @@ if __name__ == "__main__":
10
  ]
11
 
12
  # test the handler
13
- pred=my_handler.pipeline(messages)
14
 
15
  # show results
16
  print(pred)
 
2
 
3
  if __name__ == "__main__":
4
  # init handler
5
+ my_handler = EndpointHandler()
6
 
7
  # prepare sample payload
8
  messages = [
 
10
  ]
11
 
12
  # test the handler
13
+ pred=my_handler(messages)
14
 
15
  # show results
16
  print(pred)