Update code/inference.py

code/inference.py (CHANGED, +2 -2)

@@ -15,12 +15,12 @@ def predict_fn(data, load_list):
     char_name = request_inputs["char_name"]
     user_name = request_inputs["user_name"]
     template = open(f"{template}.txt", "r").read()
-    user_input = [
+    user_input = "\n".join([
         "{name}: {message}".format(
             name = char_name if (id["role"] == "AI") else user_name,
             message = id["message"].strip()
         ) for id in messages
-    ]
+    ])
     prompt = template.format(char_name = char_name, user_name = user_name, user_input = user_input)
     input_ids = tokenizer(prompt + f"\n{char_name}:", return_tensors = "pt").to("cuda")
     encoded_output = model.generate(