BlueDice committed on
Commit
0890b21
·
1 Parent(s): fa9d559

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +46 -42
handler.py CHANGED
@@ -8,52 +8,56 @@ class EndpointHandler():
8
  def __init__(self, path = ""):
9
  self.tokenizer = AutoTokenizer.from_pretrained(path)
10
  self.model = torch.load(f"{path}/torch_model.pt")
 
11
 
12
  def __call__(self, data):
 
 
 
 
 
13
  try:
14
- request_inputs = data.pop("inputs", data)
15
- template = request_inputs["template"]
16
- messages = request_inputs["messages"]
17
- char_name = request_inputs["char_name"]
18
- user_name = request_inputs["user_name"]
19
- template = open(f"{template}.txt", "r").read()
20
- user_input = "\n".join([
21
- "{name}: {message}".format(
22
- name = char_name if (id["role"] == "AI") else user_name,
23
- message = id["message"].strip()
24
- ) for id in messages
25
- ])
26
- prompt = template.format(
27
- char_name = char_name,
28
- user_name = user_name,
29
- user_input = user_input
30
- )
31
- input_ids = self.tokenizer(
32
- prompt + f"\n{char_name}:",
33
- return_tensors = "pt"
34
- ).to("cuda")
35
- encoded_output = self.model.generate(
36
- input_ids["input_ids"],
37
- max_new_tokens = 50,
38
- temperature = 0.5,
39
- top_p = 0.9,
40
- top_k = 0,
41
- repetition_penalty = 1.1,
42
- pad_token_id = 50256,
43
- num_return_sequences = 1
44
- )
45
- decoded_output = self.tokenizer.decode(encoded_output[0], skip_special_tokens=True).replace(prompt,"")
46
- decoded_output = decoded_output.split(f"{char_name}:", 1)[1].split(f"{user_name}:",1)[0].strip()
47
- parsed_result = re.sub('\*.*?\*', '', decoded_output).strip()
48
- if len(parsed_result) != 0: decoded_output = parsed_result
49
- decoded_output = " ".join(decoded_output.replace("*","").split())
50
- try:
51
- parsed_result = decoded_output[:[m.start() for m in re.finditer(r'[.!?]', decoded_output)][-1]+1]
52
- if len(parsed_result) != 0: decoded_output = parsed_result
53
- except Exception: pass
54
  except Exception as e:
55
- decoded_output = str(e)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  return {
57
  "role": "AI",
58
- "message": decoded_output
 
59
  }
 
8
  def __init__(self, path = ""):
9
  self.tokenizer = AutoTokenizer.from_pretrained(path)
10
  self.model = torch.load(f"{path}/torch_model.pt")
11
+ self.default_template = open(f"{path}/default_template.txt", "r").read()
12
 
13
  def __call__(self, data):
14
+ request_inputs = data.pop("inputs", data)
15
+ template = request_inputs["template"]
16
+ messages = request_inputs["messages"]
17
+ char_name = request_inputs["char_name"]
18
+ user_name = request_inputs["user_name"]
19
  try:
20
+ template = open(f"/{template}.txt", "r").read()
21
+ check = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  except Exception as e:
23
+ template = self.default_template
24
+ check = 1
25
+ user_input = "\n".join([
26
+ "{name}: {message}".format(
27
+ name = char_name if (id["role"] == "AI") else user_name,
28
+ message = id["message"].strip()
29
+ ) for id in messages
30
+ ])
31
+ prompt = template.format(
32
+ char_name = char_name,
33
+ user_name = user_name,
34
+ user_input = user_input
35
+ )
36
+ input_ids = self.tokenizer(
37
+ prompt + f"\n{char_name}:",
38
+ return_tensors = "pt"
39
+ ).to("cuda")
40
+ encoded_output = self.model.generate(
41
+ input_ids["input_ids"],
42
+ max_new_tokens = 50,
43
+ temperature = 0.5,
44
+ top_p = 0.9,
45
+ top_k = 0,
46
+ repetition_penalty = 1.1,
47
+ pad_token_id = 50256,
48
+ num_return_sequences = 1
49
+ )
50
+ decoded_output = self.tokenizer.decode(encoded_output[0], skip_special_tokens=True).replace(prompt,"")
51
+ decoded_output = decoded_output.split(f"{char_name}:", 1)[1].split(f"{user_name}:",1)[0].strip()
52
+ parsed_result = re.sub('\*.*?\*', '', decoded_output).strip()
53
+ if len(parsed_result) != 0: decoded_output = parsed_result
54
+ decoded_output = " ".join(decoded_output.replace("*","").split())
55
+ try:
56
+ parsed_result = decoded_output[:[m.start() for m in re.finditer(r'[.!?]', decoded_output)][-1]+1]
57
+ if len(parsed_result) != 0: decoded_output = parsed_result
58
+ except Exception: pass
59
  return {
60
  "role": "AI",
61
+ "message": decoded_output,
62
+ "check": check
63
  }