BlueDice committed on
Commit
b4b37be
·
1 Parent(s): f7ed38a

Update code/inference.py

Browse files
Files changed (1) hide show
  1. code/inference.py +12 -5
code/inference.py CHANGED
@@ -20,17 +20,23 @@ def predict_fn(data, load_list):
20
  messages = request_inputs["messages"]
21
  char_name = request_inputs["char_name"]
22
  user_name = request_inputs["user_name"]
 
23
  user_input = [
24
  "{name}: {message}".format(
25
  name = char_name if (id["role"] == "AI") else user_name,
26
  message = id["message"].strip()
27
  ) for id in messages
28
  ]
29
- user_input = "\n".join([user_input])
30
- prompt = template.format(char_name = char_name, user_name = user_name, user_input = user_input)
31
 
32
- # tokenize the model input, generate and decode output
33
- input_ids = tokenizer(prompt + f"\n{char_name}:", return_tensors = "pt").to("cuda")
 
 
 
 
 
 
 
34
  encoded_output = model.generate(
35
  input_ids["input_ids"],
36
  max_new_tokens = 50,
@@ -54,5 +60,6 @@ def predict_fn(data, load_list):
54
  except Exception: pass
55
  return {
56
  "role": "AI",
57
- "message": decoded_output
 
58
  }
 
20
  messages = request_inputs["messages"]
21
  char_name = request_inputs["char_name"]
22
  user_name = request_inputs["user_name"]
23
+ chats_curled = request_inputs["chats_curled"]
24
  user_input = [
25
  "{name}: {message}".format(
26
  name = char_name if (id["role"] == "AI") else user_name,
27
  message = id["message"].strip()
28
  ) for id in messages
29
  ]
 
 
30
 
31
+ # Tokenize the model input
32
+ while True:
33
+ prompt = template.format(char_name = char_name, user_name = user_name, user_input = "\n".join([user_input]))
34
+ input_ids = tokenizer(prompt + f"\n{char_name}:", return_tensors = "pt").to("cuda")
35
+ if input_ids.input_ids.size(1) > 2048:
36
+ chats_curled += 1
37
+ user_input = user_input[chats_curled*2:]
38
+ else: break
39
+
40
  encoded_output = model.generate(
41
  input_ids["input_ids"],
42
  max_new_tokens = 50,
 
60
  except Exception: pass
61
  return {
62
  "role": "AI",
63
+ "message": decoded_output,
64
+ "chats_curled": chats_curled
65
  }