BlueDice committed
Commit 78bd8dc · 1 Parent(s): 24aa1f7

Update code/inference.py

Files changed (1)
  1. code/inference.py +9 -17
code/inference.py CHANGED
@@ -1,7 +1,6 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from sagemaker_inference import content_types, decoder
+import re
 import torch
-import json
 
 template = """Alice Gate's Persona: Alice Gate is a young, computer engineer-nerd with a knack for problem solving and a passion for technology.
 <START>
@@ -26,18 +25,9 @@ def model_fn(model_dir):
     model = torch.load(f"{model_dir}/torch_model.pt")
     return model, tokenizer
 
-def input_fn(self, input_data, content_type):
-    return decoder.decode(input_data, content_type)
-
-def output_fn(decoded_output, accept):
-    response_body = json.dumps({
-        "message": decoded_output
-    })
-    return response_body, accept
-
 def predict_fn(input_data, load_list):
     model, tokenizer = load_list
-    inputs = data.pop("inputs", input_data)
+    inputs = input_data.pop("inputs", input_data)
     user_name = inputs["user_name"]
     user_input = "\n".join(inputs["user_input"])
     prompt = template.format(
@@ -55,12 +45,14 @@ def predict_fn(input_data, load_list):
         pad_token_id = 50256,
         num_return_sequences = 1
     )
-    decoded_output = tokenizer.decode(encoded_output[0], skip_special_tokens=True)
-    decoded_output = result.rsplit("Alice Gate:", 1)[1].split(f"{user_name}:",1)[0].strip()
+    decoded_output = tokenizer.decode(encoded_output[0], skip_special_tokens=True).replace(prompt,"")
+    decoded_output = decoded_output.split("Alice Gate:", 1)[1].split(f"{user_name}:",1)[0].strip()
     parsed_result = re.sub('\*.*?\*', '', decoded_output).strip()
-    decoded_output = parsed_result if len(parsed_result) != 0 else decoded_output.replace("*","")
-    decoded_output = " ".join(result.split())
+    if len(parsed_result) != 0: decoded_output = parsed_result
+    decoded_output = decoded_output.replace("*","")
+    decoded_output = " ".join(decoded_output.split())
     try:
-        decoded_output = decoded_output[:[m.start() for m in re.finditer(r'[.!?]', decoded_output)][-1]+1]
+        parsed_result = decoded_output[:[m.start() for m in re.finditer(r'[.!?]', decoded_output)][-1]+1]
+        if len(parsed_result) != 0: decoded_output = parsed_result
     except Exception: pass
     return decoded_output
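
For a quick sanity check, the post-processing chain this commit adds to predict_fn can be run standalone on a canned generation. This is a minimal sketch: the sample text, the user_name value, and the expected output are hypothetical and only mirror the cleanup logic in the diff; the tokenizer.decode call and the prompt .replace are skipped because they need a loaded model.

import re

# Hypothetical decoded generation, as it would look after the prompt is stripped.
user_name = "Bob"
decoded_output = (
    "Alice Gate: *adjusts glasses* Sure! Segfaults usually mean a bad pointer. "
    "Paste the stack trace\nBob: ok"
)

# Keep only Alice Gate's turn, up to the next user turn.
decoded_output = decoded_output.split("Alice Gate:", 1)[1].split(f"{user_name}:", 1)[0].strip()
# Drop *action* asides; fall back to the raw text if nothing is left.
parsed_result = re.sub(r'\*.*?\*', '', decoded_output).strip()
if len(parsed_result) != 0: decoded_output = parsed_result
decoded_output = decoded_output.replace("*", "")
decoded_output = " ".join(decoded_output.split())
# Trim to the last complete sentence; IndexError from [-1] means no terminator was found.
try:
    parsed_result = decoded_output[:[m.start() for m in re.finditer(r'[.!?]', decoded_output)][-1]+1]
    if len(parsed_result) != 0: decoded_output = parsed_result
except Exception: pass

print(decoded_output)  # Sure! Segfaults usually mean a bad pointer.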