updating handler
Browse files- handler.py +11 -14
- tester.py +2 -2
handler.py
CHANGED
@@ -4,22 +4,19 @@ import torch
|
|
4 |
|
5 |
MAX_TOKENS=8192
|
6 |
|
7 |
-
class EndpointHandler():
|
8 |
-
def __init__(self
|
9 |
-
self.pipeline = transformers.pipeline(
|
10 |
"text-generation",
|
11 |
model="humane-intelligence/gemma2-9b-cpt-sealionv3-instruct-endpoint",
|
12 |
-
model_kwargs={"torch_dtype": torch.bfloat16},
|
13 |
device_map="auto",
|
14 |
)
|
15 |
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
)
|
24 |
-
print(outputs[0]["generated_text"][-1])
|
25 |
-
return outputs
|
|
|
4 |
|
5 |
# Maximum number of new tokens a single generation request may produce.
MAX_TOKENS = 8192


class EndpointHandler:
    """Hugging Face Inference Endpoints custom handler for text generation.

    Loads the SEA-LION v3 instruct model into a ``transformers`` pipeline
    once at startup and serves generation requests through ``__call__``.
    """

    def __init__(self, path: str = "") -> None:
        """Build the text-generation pipeline.

        Args:
            path: Local model directory passed in by the Inference
                Endpoints runtime.  Accepted (with a default) to satisfy
                the standard ``EndpointHandler(path)`` contract; unused
                here because the model id is pinned below.
        """
        # device_map="auto" lets accelerate place the model on available
        # GPU(s); low_cpu_mem_usage avoids materialising a full copy of
        # the weights in host RAM while loading in bfloat16.
        self.pipeline = transformers.pipeline(
            "text-generation",
            model="humane-intelligence/gemma2-9b-cpt-sealionv3-instruct-endpoint",
            model_kwargs={"torch_dtype": torch.bfloat16, "low_cpu_mem_usage": True},
            device_map="auto",
        )

    def __call__(self, text_inputs: Any) -> List[Dict[str, Any]]:
        """Run generation on the request payload.

        Args:
            text_inputs: Prompt string or chat-message list, forwarded
                unchanged to the underlying pipeline.

        Returns:
            The pipeline output: a list of generation results, each a dict
            with a "generated_text" entry.  NOTE(review): the previous
            annotation ``List[List[Dict[str, float]]]`` contradicted the
            ``outputs[0]["generated_text"]`` indexing below and the fact
            that generated text is not a float.
        """
        outputs = self.pipeline(
            text_inputs,
            max_new_tokens=MAX_TOKENS,
        )
        # Log the final generated message so it appears in endpoint logs.
        print(outputs[0]["generated_text"][-1])
        return outputs
|
|
|
|
|
|
tester.py
CHANGED
@@ -2,7 +2,7 @@ from handler import EndpointHandler
|
|
2 |
|
3 |
if __name__ == "__main__":
|
4 |
# init handler
|
5 |
-
my_handler = EndpointHandler(
|
6 |
|
7 |
# prepare sample payload
|
8 |
messages = [
|
@@ -10,7 +10,7 @@ if __name__ == "__main__":
|
|
10 |
]
|
11 |
|
12 |
# test the handler
|
13 |
-
pred=my_handler
|
14 |
|
15 |
# show results
|
16 |
print(pred)
|
|
|
2 |
|
3 |
if __name__ == "__main__":
|
4 |
# init handler
|
5 |
+
my_handler = EndpointHandler()
|
6 |
|
7 |
# prepare sample payload
|
8 |
messages = [
|
|
|
10 |
]
|
11 |
|
12 |
# test the handler
|
13 |
+
pred=my_handler(messages)
|
14 |
|
15 |
# show results
|
16 |
print(pred)
|