Update code/inference.py
code/inference.py  +12 -5
@@ -20,17 +20,23 @@ def predict_fn(data, load_list):
     messages = request_inputs["messages"]
     char_name = request_inputs["char_name"]
     user_name = request_inputs["user_name"]
+    chats_curled = request_inputs["chats_curled"]
     user_input = [
         "{name}: {message}".format(
             name = char_name if (id["role"] == "AI") else user_name,
             message = id["message"].strip()
         ) for id in messages
     ]
-    user_input = "\n".join([user_input])
-    prompt = template.format(char_name = char_name, user_name = user_name, user_input = user_input)
 
-    # Tokenize the model input
-    input_ids = tokenizer(prompt + f"\n{char_name}:", return_tensors = "pt").to("cuda")
+    # Tokenize the model input
+    while True:
+        prompt = template.format(char_name = char_name, user_name = user_name, user_input = "\n".join(user_input))
+        input_ids = tokenizer(prompt + f"\n{char_name}:", return_tensors = "pt").to("cuda")
+        if input_ids.input_ids.size(1) > 2048:
+            chats_curled += 1
+            user_input = user_input[chats_curled*2:]
+        else: break
+
     encoded_output = model.generate(
         input_ids["input_ids"],
         max_new_tokens = 50,
@@ -54,5 +60,6 @@ def predict_fn(data, load_list):
     except Exception: pass
     return {
         "role": "AI",
-        "message": decoded_output
+        "message": decoded_output,
+        "chats_curled": chats_curled
     }
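
The change replaces the one-shot prompt build with a retry loop: `chats_curled` arrives with the request, the prompt is re-rendered and re-tokenized each pass, and while it exceeds the 2048-token limit the oldest chat lines are sliced off before trying again; the final count is echoed back in the response so the client can send it with the next turn. Below is a minimal, CPU-only sketch of that loop, assuming GPT-2's tokenizer as a stand-in for the deployed model's and a hypothetical `template`; only the message formatting and slicing logic come from the diff.

# Minimal sketch of the curling loop. GPT-2's tokenizer stands in for the
# deployed model's, `template` is hypothetical, and the .to("cuda") move is
# dropped so this runs on CPU.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
template = "{char_name}'s Persona: ...\n{user_name} talks to {char_name}.\n{user_input}"  # hypothetical

def curl_context(messages, char_name, user_name, chats_curled, budget = 2048):
    # Render each chat message as "Name: text", as the handler does.
    user_input = [
        "{name}: {message}".format(
            name = char_name if (m["role"] == "AI") else user_name,
            message = m["message"].strip()
        ) for m in messages
    ]
    while True:
        prompt = template.format(char_name = char_name, user_name = user_name,
                                 user_input = "\n".join(user_input))
        input_ids = tokenizer(prompt + f"\n{char_name}:", return_tensors = "pt")
        if input_ids.input_ids.size(1) > budget:
            chats_curled += 1                           # remember how far we trimmed
            user_input = user_input[chats_curled * 2:]  # drop the oldest lines
        else: break
    return input_ids, chats_curled

Because the slice index grows while the list shrinks, each pass discards progressively more history, so the loop settles in a few iterations even for long chats; echoing `chats_curled` back lets the client resubmit it, so the next request starts already trimmed instead of rediscovering the cutoff.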