Update handler.py
Browse files- handler.py +3 -3
handler.py
CHANGED
@@ -17,7 +17,6 @@ class EndpointHandler():
|
|
17 |
char_name = request_inputs["char_name"]
|
18 |
user_name = request_inputs["user_name"]
|
19 |
chats_curled = request_inputs["chats_curled"]
|
20 |
-
template = self.default_template
|
21 |
user_input = [
|
22 |
"{name}: {message}".format(
|
23 |
name = char_name if (id["role"] == "AI") else user_name,
|
@@ -25,7 +24,7 @@ class EndpointHandler():
|
|
25 |
) for id in messages
|
26 |
]
|
27 |
while True:
|
28 |
-
prompt =
|
29 |
input_ids = self.tokenizer(prompt + f"\n{char_name}:", return_tensors = "pt").to("cuda")
|
30 |
if input_ids.input_ids.size(1) > 2048:
|
31 |
chats_curled += 1
|
@@ -52,5 +51,6 @@ class EndpointHandler():
|
|
52 |
except Exception: pass
|
53 |
return {
|
54 |
"role": "AI",
|
55 |
-
"message": decoded_output
|
|
|
56 |
}
|
|
|
17 |
char_name = request_inputs["char_name"]
|
18 |
user_name = request_inputs["user_name"]
|
19 |
chats_curled = request_inputs["chats_curled"]
|
|
|
20 |
user_input = [
|
21 |
"{name}: {message}".format(
|
22 |
name = char_name if (id["role"] == "AI") else user_name,
|
|
|
24 |
) for id in messages
|
25 |
]
|
26 |
while True:
|
27 |
+
prompt = self.default_template.format(char_name = char_name, user_name = user_name, user_input = "\n".join(user_input))
|
28 |
input_ids = self.tokenizer(prompt + f"\n{char_name}:", return_tensors = "pt").to("cuda")
|
29 |
if input_ids.input_ids.size(1) > 2048:
|
30 |
chats_curled += 1
|
|
|
51 |
except Exception: pass
|
52 |
return {
|
53 |
"role": "AI",
|
54 |
+
"message": decoded_output,
|
55 |
+
"chats_curled": chats_curled
|
56 |
}
|