Update code/inference.py
Browse files- code/inference.py +2 -2
code/inference.py
CHANGED
@@ -7,9 +7,9 @@ def model_fn(model_dir):
|
|
7 |
model = torch.load(f"{model_dir}/torch_model.pt")
|
8 |
return model, tokenizer
|
9 |
|
10 |
-
def predict_fn(
|
11 |
model, tokenizer = load_list
|
12 |
-
request_inputs = input_data.pop("inputs",
|
13 |
template = request_inputs["template"]
|
14 |
messages = request_inputs["messages"]
|
15 |
char_name = request_inputs["char_name"]
|
|
|
7 |
model = torch.load(f"{model_dir}/torch_model.pt")
|
8 |
return model, tokenizer
|
9 |
|
10 |
+
def predict_fn(data, load_list):
|
11 |
model, tokenizer = load_list
|
12 |
+
request_inputs = input_data.pop("inputs", data)
|
13 |
template = request_inputs["template"]
|
14 |
messages = request_inputs["messages"]
|
15 |
char_name = request_inputs["char_name"]
|