Update handler.py
handler.py  CHANGED  (+29 -10)
@@ -3,8 +3,18 @@ import torch
 
 class EndpointHandler:
     def __init__(self, path=""):
+        if not path:
+            raise ValueError("A valid model path or name must be provided.")
+
+        # Load tokenizer and model
         self.tokenizer = AutoTokenizer.from_pretrained(path)
-        self.model = AutoModelForCausalLM.from_pretrained(path)
+        self.model = AutoModelForCausalLM.from_pretrained(
+            path,
+            torch_dtype=torch.float16,
+            device_map="auto"
+        )
+
+        # Set up text-generation pipeline
         self.pipe = pipeline(
             "text-generation",
             model=self.model,
@@ -16,14 +26,23 @@
         )
 
     def __call__(self, data):
+        # Validate input data
+        if not isinstance(data, dict):
+            return {"error": "Input must be a JSON object."}
+
         prompt = data.get("inputs", "")
         if not prompt:
-            return {"error": "No input provided"}
-
-
-
-
-
-
-
-
+            return {"error": "No input provided."}
+
+        try:
+            # Generate response
+            outputs = self.pipe(prompt)
+            if outputs:
+                response = outputs[0]['generated_text']
+                # Remove the original prompt from the response
+                response = response[len(prompt):].strip()
+                return {"generated_text": response}
+            else:
+                return {"error": "No output generated."}
+        except Exception as e:
+            return {"error": str(e)}
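For context, a minimal smoke test of the updated handler might look like the sketch below. It assumes transformers and torch are installed (plus accelerate, since device_map="auto" is used), and it uses "gpt2" as a stand-in model name; the actual model served by this endpoint is not named in the commit. The payload shape simply mirrors the data.get("inputs", "") contract in __call__.

    # smoke_test.py -- illustrative only; "gpt2" is a placeholder model name
    from handler import EndpointHandler

    handler = EndpointHandler(path="gpt2")

    # Inference Endpoints deliver the parsed request body as a dict;
    # __call__ reads the prompt from the "inputs" key.
    print(handler({"inputs": "Once upon a time"}))  # {"generated_text": "..."}

    # Malformed payloads now come back as structured errors instead of raising:
    print(handler("not a dict"))  # {"error": "Input must be a JSON object."}
    print(handler({}))            # {"error": "No input provided."}

Returning error dicts rather than letting exceptions propagate keeps the endpoint answering with well-formed JSON even on bad input, which is why __call__ wraps generation in a broad try/except.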