krisoei committed on
Commit
4ce2301
·
verified ·
1 Parent(s): 1264bba

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +29 -10
handler.py CHANGED
@@ -3,8 +3,18 @@ import torch
3
 
4
  class EndpointHandler:
5
  def __init__(self, path=""):
 
 
 
 
6
  self.tokenizer = AutoTokenizer.from_pretrained(path)
7
- self.model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16, device_map="auto")
 
 
 
 
 
 
8
  self.pipe = pipeline(
9
  "text-generation",
10
  model=self.model,
@@ -16,14 +26,23 @@ class EndpointHandler:
16
  )
17
 
18
  def __call__(self, data):
 
 
 
 
19
  prompt = data.get("inputs", "")
20
  if not prompt:
21
- return {"error": "No input provided"}
22
-
23
- # Generate response
24
- response = self.pipe(prompt)[0]['generated_text']
25
-
26
- # Remove the original prompt from the response
27
- response = response[len(prompt):].strip()
28
-
29
- return {"generated_text": response}
 
 
 
 
 
 
3
 
4
  class EndpointHandler:
5
  def __init__(self, path=""):
6
+ if not path:
7
+ raise ValueError("A valid model path or name must be provided.")
8
+
9
+ # Load tokenizer and model
10
  self.tokenizer = AutoTokenizer.from_pretrained(path)
11
+ self.model = AutoModelForCausalLM.from_pretrained(
12
+ path,
13
+ torch_dtype=torch.float16,
14
+ device_map="auto"
15
+ )
16
+
17
+ # Set up text-generation pipeline
18
  self.pipe = pipeline(
19
  "text-generation",
20
  model=self.model,
 
26
  )
27
 
28
  def __call__(self, data):
29
+ # Validate input data
30
+ if not isinstance(data, dict):
31
+ return {"error": "Input must be a JSON object."}
32
+
33
  prompt = data.get("inputs", "")
34
  if not prompt:
35
+ return {"error": "No input provided."}
36
+
37
+ try:
38
+ # Generate response
39
+ outputs = self.pipe(prompt)
40
+ if outputs:
41
+ response = outputs[0]['generated_text']
42
+ # Remove the original prompt from the response
43
+ response = response[len(prompt):].strip()
44
+ return {"generated_text": response}
45
+ else:
46
+ return {"error": "No output generated."}
47
+ except Exception as e:
48
+ return {"error": str(e)}