Update code/inference.py
Browse files- code/inference.py +2 -3
code/inference.py
CHANGED
@@ -17,8 +17,7 @@ Alice Gate: Yeah, it's really fun. I'm lucky to be able to do this as a job.
|
|
17 |
{user_name}: Definetly.
|
18 |
<END>
|
19 |
Alice Gate: *Alice strides into the room with a smile, her eyes lighting up when she sees you. She's wearing a light blue t-shirt and jeans, her laptop bag slung over one shoulder. She takes a seat next to you, her enthusiasm palpable in the air* Hey! I'm so excited to finally meet you. I've heard so many great things about you and I'm eager to pick your brain about computers. I'm sure you have a wealth of knowledge that I can learn from. *She grins, eyes twinkling with excitement* Let's get started!
|
20 |
-
{user_input}
|
21 |
-
Alice Gate:"""
|
22 |
|
23 |
def model_fn(model_dir):
|
24 |
tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
@@ -34,7 +33,7 @@ def predict_fn(input_data, load_list):
|
|
34 |
user_name = user_name,
|
35 |
user_input = user_input
|
36 |
)
|
37 |
-
input_ids = tokenizer(prompt, return_tensors = "pt").to("cuda")
|
38 |
encoded_output = model.generate(
|
39 |
input_ids["input_ids"],
|
40 |
max_new_tokens = 50,
|
|
|
17 |
{user_name}: Definetly.
|
18 |
<END>
|
19 |
Alice Gate: *Alice strides into the room with a smile, her eyes lighting up when she sees you. She's wearing a light blue t-shirt and jeans, her laptop bag slung over one shoulder. She takes a seat next to you, her enthusiasm palpable in the air* Hey! I'm so excited to finally meet you. I've heard so many great things about you and I'm eager to pick your brain about computers. I'm sure you have a wealth of knowledge that I can learn from. *She grins, eyes twinkling with excitement* Let's get started!
|
20 |
+
{user_input}"""
|
|
|
21 |
|
22 |
def model_fn(model_dir):
|
23 |
tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
|
|
33 |
user_name = user_name,
|
34 |
user_input = user_input
|
35 |
)
|
36 |
+
input_ids = tokenizer(prompt + "\nAlice Gate:", return_tensors = "pt").to("cuda")
|
37 |
encoded_output = model.generate(
|
38 |
input_ids["input_ids"],
|
39 |
max_new_tokens = 50,
|