BlueDice commited on
Commit
8ac4df3
·
1 Parent(s): a0e38ab

Update code/inference.py

Browse files
Files changed (1) hide show
  1. code/inference.py +2 -2
code/inference.py CHANGED
@@ -15,12 +15,12 @@ def predict_fn(data, load_list):
15
  char_name = request_inputs["char_name"]
16
  user_name = request_inputs["user_name"]
17
  template = open(f"{template}.txt", "r").read()
18
- user_input = [
19
  "{name}: {message}".format(
20
  name = char_name if (id["role"] == "AI") else user_name,
21
  message = id["message"].strip()
22
  ) for id in messages
23
- ]
24
  prompt = template.format(char_name = char_name, user_name = user_name, user_input = user_input)
25
  input_ids = tokenizer(prompt + f"\n{char_name}:", return_tensors = "pt").to("cuda")
26
  encoded_output = model.generate(
 
15
  char_name = request_inputs["char_name"]
16
  user_name = request_inputs["user_name"]
17
  template = open(f"{template}.txt", "r").read()
18
+ user_input = "\n".join([
19
  "{name}: {message}".format(
20
  name = char_name if (id["role"] == "AI") else user_name,
21
  message = id["message"].strip()
22
  ) for id in messages
23
+ ])
24
  prompt = template.format(char_name = char_name, user_name = user_name, user_input = user_input)
25
  input_ids = tokenizer(prompt + f"\n{char_name}:", return_tensors = "pt").to("cuda")
26
  encoded_output = model.generate(