BlueDice committed on
Commit fa9d559 · 1 Parent(s): 8ac4df3

Update handler.py (#1)

- Update handler.py (32958b5317300ea7ce493f656d75945c5f350421)

Files changed (1)
  1. handler.py +41 -38
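
For context before the diff: __call__ reads four fields from the request body, "template", "messages", "char_name", and "user_name". A minimal sketch of a request payload, inferred from those reads; the template name, character name, and messages below are invented for the example and not taken from the repo:

    payload = {
        "inputs": {
            "template": "chat",   # handler opens f"{template}.txt" from disk
            "char_name": "Aria",  # hypothetical character name
            "user_name": "You",   # hypothetical user name
            "messages": [
                {"role": "User", "message": "Hi, how are you?"},
                {"role": "AI", "message": "Doing well, thanks!"}
            ]
        }
    }

    # response = EndpointHandler(path=".")(payload)
    # expected shape: {"role": "AI", "message": "..."}

Note that any role other than "AI" is attributed to user_name, since the code only checks id["role"] == "AI".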
handler.py CHANGED
@@ -10,46 +10,49 @@ class EndpointHandler():
         self.model = torch.load(f"{path}/torch_model.pt")
 
     def __call__(self, data):
-        request_inputs = data.pop("inputs", data)
-        template = request_inputs["template"]
-        messages = request_inputs["messages"]
-        char_name = request_inputs["char_name"]
-        user_name = request_inputs["user_name"]
-        template = open(f"{template}.txt", "r").read()
-        user_input = "\n".join([
-            "{name}: {message}".format(
-                name = char_name if (id["role"] == "AI") else user_name,
-                message = id["message"].strip()
-            ) for id in messages
-        ])
-        prompt = template.format(
-            char_name = char_name,
-            user_name = user_name,
-            user_input = user_input
-        )
-        input_ids = self.tokenizer(
-            prompt + f"\n{char_name}:",
-            return_tensors = "pt"
-        ).to("cuda")
-        encoded_output = self.model.generate(
-            input_ids["input_ids"],
-            max_new_tokens = 50,
-            temperature = 0.5,
-            top_p = 0.9,
-            top_k = 0,
-            repetition_penalty = 1.1,
-            pad_token_id = 50256,
-            num_return_sequences = 1
-        )
-        decoded_output = self.tokenizer.decode(encoded_output[0], skip_special_tokens=True).replace(prompt,"")
-        decoded_output = decoded_output.split(f"{char_name}:", 1)[1].split(f"{user_name}:",1)[0].strip()
-        parsed_result = re.sub('\*.*?\*', '', decoded_output).strip()
-        if len(parsed_result) != 0: decoded_output = parsed_result
-        decoded_output = " ".join(decoded_output.replace("*","").split())
         try:
-            parsed_result = decoded_output[:[m.start() for m in re.finditer(r'[.!?]', decoded_output)][-1]+1]
+            request_inputs = data.pop("inputs", data)
+            template = request_inputs["template"]
+            messages = request_inputs["messages"]
+            char_name = request_inputs["char_name"]
+            user_name = request_inputs["user_name"]
+            template = open(f"{template}.txt", "r").read()
+            user_input = "\n".join([
+                "{name}: {message}".format(
+                    name = char_name if (id["role"] == "AI") else user_name,
+                    message = id["message"].strip()
+                ) for id in messages
+            ])
+            prompt = template.format(
+                char_name = char_name,
+                user_name = user_name,
+                user_input = user_input
+            )
+            input_ids = self.tokenizer(
+                prompt + f"\n{char_name}:",
+                return_tensors = "pt"
+            ).to("cuda")
+            encoded_output = self.model.generate(
+                input_ids["input_ids"],
+                max_new_tokens = 50,
+                temperature = 0.5,
+                top_p = 0.9,
+                top_k = 0,
+                repetition_penalty = 1.1,
+                pad_token_id = 50256,
+                num_return_sequences = 1
+            )
+            decoded_output = self.tokenizer.decode(encoded_output[0], skip_special_tokens=True).replace(prompt,"")
+            decoded_output = decoded_output.split(f"{char_name}:", 1)[1].split(f"{user_name}:",1)[0].strip()
+            parsed_result = re.sub('\*.*?\*', '', decoded_output).strip()
             if len(parsed_result) != 0: decoded_output = parsed_result
-        except Exception: pass
+            decoded_output = " ".join(decoded_output.replace("*","").split())
+            try:
+                parsed_result = decoded_output[:[m.start() for m in re.finditer(r'[.!?]', decoded_output)][-1]+1]
+                if len(parsed_result) != 0: decoded_output = parsed_result
+            except Exception: pass
+        except Exception as e:
+            decoded_output = str(e)
         return {
             "role": "AI",
             "message": decoded_output
 