Upload generation_test_hf_script.py
Browse files- generation_test_hf_script.py +84 -0
generation_test_hf_script.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
4 |
+
import torch
|
5 |
+
|
6 |
+
|
7 |
+
def load_rag_benchmark_tester_ds():
|
8 |
+
|
9 |
+
# pull 200 question rag benchmark test dataset from LLMWare HuggingFace repo
|
10 |
+
from datasets import load_dataset
|
11 |
+
|
12 |
+
ds_name = "llmware/rag_instruct_benchmark_tester"
|
13 |
+
|
14 |
+
dataset = load_dataset(ds_name)
|
15 |
+
|
16 |
+
print("update: loading test dataset - ", dataset)
|
17 |
+
|
18 |
+
test_set = []
|
19 |
+
for i, samples in enumerate(dataset["train"]):
|
20 |
+
test_set.append(samples)
|
21 |
+
|
22 |
+
# to view test set samples
|
23 |
+
# print("rag benchmark dataset test samples: ", i, samples)
|
24 |
+
|
25 |
+
return test_set
|
26 |
+
|
27 |
+
|
28 |
+
def run_test(model_name, test_ds):
|
29 |
+
|
30 |
+
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
|
31 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
32 |
+
|
33 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
34 |
+
|
35 |
+
for i, entries in enumerate(test_ds):
|
36 |
+
|
37 |
+
# prepare prompt packaging used in fine-tuning process
|
38 |
+
new_prompt = "<human>: " + entries["context"] + "\n" + entries["query"] + "\n" + "<bot>:"
|
39 |
+
|
40 |
+
inputs = tokenizer(new_prompt, return_tensors="pt")
|
41 |
+
start_of_output = len(inputs.input_ids[0])
|
42 |
+
|
43 |
+
# temperature: set at 0.3 for consistency of output
|
44 |
+
# max_new_tokens: set at 100 - may prematurely stop a few of the summaries
|
45 |
+
|
46 |
+
outputs = model.generate(
|
47 |
+
inputs.input_ids.to(device),
|
48 |
+
eos_token_id=tokenizer.eos_token_id,
|
49 |
+
pad_token_id=tokenizer.eos_token_id,
|
50 |
+
do_sample=True,
|
51 |
+
temperature=0.3,
|
52 |
+
max_new_tokens=100,
|
53 |
+
)
|
54 |
+
|
55 |
+
output_only = tokenizer.decode(outputs[0][start_of_output:],skip_special_tokens=True)
|
56 |
+
|
57 |
+
# quick/optional post-processing clean-up of potential fine-tuning artifacts
|
58 |
+
|
59 |
+
eot = output_only.find("<|endoftext|>")
|
60 |
+
if eot > -1:
|
61 |
+
output_only = output_only[:eot]
|
62 |
+
|
63 |
+
bot = output_only.find("<bot>:")
|
64 |
+
if bot > -1:
|
65 |
+
output_only = output_only[bot+len("<bot>:"):]
|
66 |
+
|
67 |
+
# end - post-processing
|
68 |
+
|
69 |
+
print("\n")
|
70 |
+
print(i, "llm_response - ", output_only)
|
71 |
+
print(i, "gold_answer - ", entries["answer"])
|
72 |
+
|
73 |
+
return 0
|
74 |
+
|
75 |
+
|
76 |
+
if __name__ == "__main__":
|
77 |
+
|
78 |
+
test_ds = load_rag_benchmark_tester_ds()
|
79 |
+
|
80 |
+
model_name = "llmware/bling-1b-0.1"
|
81 |
+
|
82 |
+
output = run_test(model_name,test_ds)
|
83 |
+
|
84 |
+
|