root committed on
Commit dfdc6c0
Parent(s): aba3fe3
add long_32k_eval

Files changed:
- evaluation/long_32k_eval/dataset_evaluator_retro.py +178 -0
- evaluation/long_32k_eval/dataset_evaluator_retro_longbench.py +203 -0
- evaluation/long_32k_eval/dataset_evaluator_retro_nv.py +181 -0
- evaluation/long_32k_eval/eval_retro_vllm.sh +118 -0
- evaluation/long_32k_eval/extract_log.py +88 -0
- evaluation/long_32k_eval/longbench/__pycache__/eval.cpython-310.pyc +0 -0
- evaluation/long_32k_eval/longbench/__pycache__/metrics.cpython-310.pyc +0 -0
- evaluation/long_32k_eval/longbench/eval.py +127 -0
- evaluation/long_32k_eval/longbench/metrics.py +154 -0
- evaluation/long_32k_eval/run_eval_vllm.sh +10 -0
evaluation/long_32k_eval/dataset_evaluator_retro.py
ADDED
@@ -0,0 +1,178 @@
import os
import argparse
import json
import shutil
import re

from datasets import load_dataset, load_metric
from huggingface_hub import hf_hub_download

DATASETS = [
    "gov_report",
    "summ_screen_fd",
    "qmsum",
    "qasper",
    "narrative_qa",
    "quality",
    "quality_hard",
    "contract_nli",
]

PATTERN = re.compile(r'\b[A-D]\b')

def find_answer(s):
    match = PATTERN.search(s)
    if match is None:
        return None  # None is a signal of not find! NOTE
    return match.group()

def read_json_data(data_path):
    references = []
    questions = []
    id_to_labels = dict()
    id_list = list()
    idx = 0
    with open(data_path, "r") as f:
        examples = json.load(f)
        for data_item in examples:  # dict_keys(['source', 'paragraph_id', 'question', 'answer', 'sub-paragraphs', 'word_count', 'id', 'ctxs'])
            idx_str = str(idx) if 'id' not in data_item else str(data_item['id'])
            idx += 1
            id_list.append(idx_str)

            questions.append(data_item['question'])
            if "answers" in data_item:
                references.append(data_item['answers'][0])
                answer_list = [answer_str for answer_str in data_item['answers']]
                id_to_labels[idx_str] = answer_list

            elif "answer" in data_item:
                references.append(data_item['answer'])  # take the single answer
                id_to_labels[idx_str] = [data_item['answer']]
            else:
                raise ValueError("need answer or answers from input json")
    return id_to_labels, id_list, questions

def convert_to_seq(aquestion, apred):
    if apred is None:
        apred = ""

    matched_pred = find_answer(apred)
    if matched_pred is None:
        matched_pred = apred

    apred = '({})'.format(matched_pred)

    alist = aquestion.split('\n')
    for aitem in alist:
        aitem = aitem.strip()
        if aitem.startswith(apred):
            pred_out = ' '.join(aitem.split(' ')[1:])
            print('from {} to [{}]'.format(apred, pred_out))
            return pred_out
    print('Warning: could not find ({}) from question {}'.format(apred, aquestion))

    return apred

# 500 -> 100
def load_prediction(test_file, id_list, id_to_labels, questions, dataset_name):
    predictions = []
    with open(test_file, "r") as f:
        for line in f.readlines():
            predictions.append(line.strip())
    if len(predictions) != len(id_list):
        print("NOTE: different number of samples, {} in prediction, yet {} in reference".format(
            len(predictions), len(id_list)))
        id_list = id_list[0: len(predictions)]

    id_to_prediction = dict()
    for aid, apred in zip(id_list, predictions):
        id_to_prediction[aid] = apred

    if dataset_name.startswith('quality'):
        print('quality dataset, and rewriting the prediction to the full textual sequence...')
        questions = questions[0: len(predictions)]
        id_to_prediction = dict()
        for aid, aquestion, apred in zip(id_list, questions, predictions):
            apred_seq = convert_to_seq(aquestion, apred)
            id_to_prediction[aid] = apred_seq

    return id_to_prediction, id_list

def main(args, raise_on_errors=False):
    datasets = [args.dataset] if args.dataset in DATASETS else DATASETS
    for dataset_name in datasets:
        print(dataset_name)
        scrolls_metric = load_metric(download_metric(), dataset_name)  # TODO cost time to load ! NOTE

        id_to_labels, id_list, questions = read_json_data(args.datapath)
        id_to_pred, id_list = load_prediction(args.gen_test_file,
                                              id_list, id_to_labels, questions,
                                              dataset_name)

        if len(id_to_labels) > len(id_list):
            print('NOTE: prune the reference set from {} to {}'.format(
                len(id_to_labels), len(id_list)))
            id_to_labels = {aid: id_to_labels[aid] for aid in id_list}

        errors, details = verify(id_to_pred, id_to_labels)

        if len(errors) == 0:
            metrics = scrolls_metric.compute(**scrolls_metric.convert_from_map_format(id_to_pred, id_to_labels))
            print(json.dumps(metrics, indent=4))
            dislist = [str(item) for item in metrics['display']]
            print('final display:', dataset_name, ' '.join(dislist))
        elif len(errors) > 0:
            errors_msg = errors[0] if len(errors) == 1 else " ".join(f"{i}: {err}" for i, err in enumerate(errors))
            print(json.dumps(errors, indent=4))
            raise ValueError(f"Failed to evaluate due to: {errors_msg}")


def download_metric():
    scrolls_metric_path = hf_hub_download(repo_id="tau/scrolls", filename="metrics/scrolls.py", repo_type="dataset")
    updated_scrolls_metric_path = (
        os.path.dirname(scrolls_metric_path) + os.path.basename(scrolls_metric_path).replace(".", "_") + ".py"
    )
    shutil.copy(scrolls_metric_path, updated_scrolls_metric_path)
    return updated_scrolls_metric_path

def verify(id_to_pred, id_to_labels):
    errors = []
    details = {"missing_keys": [], "redundant_keys": []}
    if not isinstance(id_to_pred, dict):
        errors.append('The predictions must be saved a JSON object: {"id1": "prediction1", "id2": "prediction2", ...}')
    else:
        if not all(isinstance(key, str) for key in id_to_pred.keys()):
            errors.append("All keys of the predictions dictionary must be strings")
        if not all(isinstance(value, str) for value in id_to_pred.values()):
            errors.append("All values of the predictions dictionary must be strings")
        if len(errors) == 0:
            predictions_keys, reference_keys = set(id_to_pred.keys()), set(id_to_labels.keys())
            missing_keys = reference_keys - predictions_keys
            redundant_keys = predictions_keys - reference_keys

            if len(missing_keys) > 0:
                details["missing_keys"] = list(missing_keys)
                errors.append(f"There are missing example IDs.")
            else:
                del details["missing_keys"]

            if len(redundant_keys) > 0:
                details["redundant_keys"] = list(redundant_keys)
                errors.append(f"There are redundant example IDs.")
            else:
                del details["redundant_keys"]

    return errors, details

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate SCROLLS predictions per dataset")

    parser.add_argument("--datapath", type=str,
                        default=None, help="datapath for test json file [reference]")
    parser.add_argument("--gen_test_file", type=str,
                        default=None, help="generations for test file [system prediction]")
    parser.add_argument("--dataset", type=str,
                        default=None, help="name of the dataset used in scrolls: {}".format(DATASETS))

    args = parser.parse_args()
    main(args)
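For reference, a minimal sketch (with a made-up question and prediction) of how find_answer and convert_to_seq above rewrite a multiple-choice letter into the full option text for the quality datasets:

    from dataset_evaluator_retro import find_answer, convert_to_seq

    # Made-up QuALITY-style question and raw model prediction:
    question = "Which planet is largest?\n(A) Mars\n(B) Jupiter\n(C) Venus\n(D) Mercury"
    prediction = "The answer is B"

    print(find_answer(prediction))               # "B"  (first standalone A-D letter)
    print(convert_to_seq(question, prediction))  # "Jupiter" (also logs: from (B) to [Jupiter])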
evaluation/long_32k_eval/dataset_evaluator_retro_longbench.py
ADDED
@@ -0,0 +1,203 @@
import os
import argparse
import json
import shutil
import re

from datasets import load_dataset, load_metric
from huggingface_hub import hf_hub_download

from longbench.eval import scorer

LONGBENCH_DATASETS = [
    'musique',  # NOTE TODO to add other 20 datasets
    'hotpotqa',
    'multifieldqa_en'
]

PATTERN = re.compile(r'\b[A-D]\b')

def find_answer(s):
    match = PATTERN.search(s)
    if match is None:
        return None  # None is a signal of not find! NOTE
    return match.group()

def read_json_data(data_path):
    references = []
    questions = []
    id_to_labels = dict()
    id_list = list()
    idx = 0
    with open(data_path, "r") as f:
        examples = json.load(f)
        for data_item in examples:  # dict_keys(['source', 'paragraph_id', 'question', 'answer', 'sub-paragraphs', 'word_count', 'id', 'ctxs'])
            idx_str = str(idx) if 'id' not in data_item else str(data_item['id'])
            idx += 1
            id_list.append(idx_str)

            questions.append(data_item['question'])
            if "answers" in data_item:
                references.append(data_item['answers'])  # NOTE take all the answers!
                answer_list = [answer_str for answer_str in data_item['answers']]
                id_to_labels[idx_str] = answer_list

            elif "answer" in data_item:
                references.append([data_item['answer']])  # take the single answer, as a list
                id_to_labels[idx_str] = [data_item['answer']]
            else:
                raise ValueError("need answer or answers from input json")
    return id_to_labels, id_list, questions, references  # answers

def convert_to_seq(aquestion, apred):
    if apred is None:
        apred = ""

    matched_pred = find_answer(apred)
    if matched_pred is None:
        matched_pred = apred

    apred = '({})'.format(matched_pred)

    alist = aquestion.split('\n')
    for aitem in alist:
        aitem = aitem.strip()
        if aitem.startswith(apred):
            pred_out = ' '.join(aitem.split(' ')[1:])
            print('from {} to [{}]'.format(apred, pred_out))
            return pred_out
    print('Warning: could not find ({}) from question {}'.format(apred, aquestion))

    return apred

def load_prediction_openai(test_file):
    predictions = []
    with open(test_file, "r") as f:
        apred_list = list()
        for aline in f.readlines():
            if aline.startswith('assistant: '):
                if len(apred_list) > 0:
                    print('\n'.join(apred_list))
                    predictions.append('\n'.join(apred_list))
                    apred_list = list()
                apred_list.append(aline[len('assistant: '):].strip())
            else:
                apred_list.append(aline.strip())

    if len(apred_list) > 0:
        predictions.append('\n'.join(apred_list))
    print(len(predictions))
    return predictions


# 500 -> 100
def load_prediction(test_file, id_list, id_to_labels,
                    questions, dataset_name, is_openai_assistant=False):
    if is_openai_assistant:
        predictions = load_prediction_openai(test_file)
    else:
        predictions = []
        with open(test_file, "r") as f:
            for line in f.readlines():
                predictions.append(line.strip())

    if len(predictions) != len(id_list):
        print("NOTE: different number of samples, {} in prediction, yet {} in reference".format(
            len(predictions), len(id_list)))
        id_list = id_list[0: len(predictions)]

    id_to_prediction = dict()
    for aid, apred in zip(id_list, predictions):
        id_to_prediction[aid] = apred

    if dataset_name.startswith('quality'):
        print('quality dataset, and rewriting the prediction to the full textual sequence...')
        questions = questions[0: len(predictions)]
        id_to_prediction = dict()
        for aid, aquestion, apred in zip(id_list, questions, predictions):
            apred_seq = convert_to_seq(aquestion, apred)
            id_to_prediction[aid] = apred_seq

    return id_to_prediction, id_list, predictions

def main(args, raise_on_errors=False):
    datasets = [args.dataset] if args.dataset in LONGBENCH_DATASETS else LONGBENCH_DATASETS
    for dataset_name in datasets:
        print(dataset_name)

        id_to_labels, id_list, questions, answers = read_json_data(args.datapath)
        id_to_pred, id_list, predictions = load_prediction(args.gen_test_file,
                                                           id_list, id_to_labels, questions,
                                                           dataset_name, args.is_openai_assistant)

        if len(id_to_labels) > len(id_list):
            print('NOTE: prune the reference set from {} to {}'.format(
                len(id_to_labels), len(id_list)))
            id_to_labels = {aid: id_to_labels[aid] for aid in id_list}

        errors, details = verify(id_to_pred, id_to_labels)

        if len(errors) == 0:
            score = scorer(dataset_name, predictions, answers, all_classes=None)
            print('final display:', dataset_name, score, "\n", args.gen_test_file)
        elif len(errors) > 0:
            errors_msg = errors[0] if len(errors) == 1 else " ".join(f"{i}: {err}" for i, err in enumerate(errors))
            print(json.dumps(errors, indent=4))
            raise ValueError(f"Failed to evaluate due to: {errors_msg}")


def download_metric():
    scrolls_metric_path = hf_hub_download(repo_id="tau/scrolls", filename="metrics/scrolls.py", repo_type="dataset")
    updated_scrolls_metric_path = (
        os.path.dirname(scrolls_metric_path) + os.path.basename(scrolls_metric_path).replace(".", "_") + ".py"
    )
    shutil.copy(scrolls_metric_path, updated_scrolls_metric_path)
    return updated_scrolls_metric_path

def verify(id_to_pred, id_to_labels):
    errors = []
    details = {"missing_keys": [], "redundant_keys": []}
    if not isinstance(id_to_pred, dict):
        errors.append('The predictions must be saved a JSON object: {"id1": "prediction1", "id2": "prediction2", ...}')
    else:
        if not all(isinstance(key, str) for key in id_to_pred.keys()):
            errors.append("All keys of the predictions dictionary must be strings")
        if not all(isinstance(value, str) for value in id_to_pred.values()):
            errors.append("All values of the predictions dictionary must be strings")
        if len(errors) == 0:
            predictions_keys, reference_keys = set(id_to_pred.keys()), set(id_to_labels.keys())
            missing_keys = reference_keys - predictions_keys
            redundant_keys = predictions_keys - reference_keys

            if len(missing_keys) > 0:
                details["missing_keys"] = list(missing_keys)
                errors.append(f"There are missing example IDs.")
            else:
                del details["missing_keys"]

            if len(redundant_keys) > 0:
                details["redundant_keys"] = list(redundant_keys)
                errors.append(f"There are redundant example IDs.")
            else:
                del details["redundant_keys"]

    return errors, details

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate SCROLLS predictions per dataset")

    dataset_help = "name of the dataset used in longbench: {}".format(LONGBENCH_DATASETS)
    parser.add_argument("--datapath", type=str, required=True,
                        default=None, help="datapath for test json file [reference]")
    parser.add_argument("--gen_test_file", type=str, required=True,
                        default=None, help="generations for test file [system prediction]")
    parser.add_argument("--dataset", type=str, required=True,
                        default=None, help=dataset_help)
    parser.add_argument("--is_openai_assistant", type=bool, required=False,
                        default=False,
                        help='if openai assistant, then combine multiple lines and the 1st-line starts with assistant:')

    args = parser.parse_args()
    print(args)
    main(args)
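A minimal sketch of the --is_openai_assistant path, assuming the longbench dependencies are installed and the code is run from this directory; the file contents below are made up. A new record starts at every line beginning with "assistant: ", and following lines are appended to the same record:

    import tempfile
    from dataset_evaluator_retro_longbench import load_prediction_openai

    # Made-up generation file in the OpenAI-assistant layout:
    sample = ("assistant: first answer line\n"
              "more of the first answer\n"
              "assistant: second answer\n")
    with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as tmp:
        tmp.write(sample)

    print(load_prediction_openai(tmp.name))
    # ['first answer line\nmore of the first answer', 'second answer']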
evaluation/long_32k_eval/dataset_evaluator_retro_nv.py
ADDED
@@ -0,0 +1,181 @@
import os
import argparse
import json
import shutil
import re

from datasets import load_dataset, load_metric
from huggingface_hub import hf_hub_download

from nv.evaluate_f1_sft_zeroshot import evaluate_f1

DATASETS = [
    'doc2dial_full_dialogue',
]

PATTERN = re.compile(r'\b[A-D]\b')

def find_answer(s):
    match = PATTERN.search(s)
    if match is None:
        return None  # None is a signal of not find! NOTE
    return match.group()

def read_json_data(data_path):
    references = []
    questions = []
    id_to_labels = dict()
    id_list = list()
    idx = 0
    with open(data_path, "r") as f:
        examples = json.load(f)
        for data_item in examples:  # dict_keys(['source', 'paragraph_id', 'question', 'answer', 'sub-paragraphs', 'word_count', 'id', 'ctxs'])
            idx_str = str(idx) if 'id' not in data_item else str(data_item['id'])
            idx += 1
            id_list.append(idx_str)

            questions.append(data_item['question'])
            if "answers" in data_item:
                references.append(data_item['answers'])  # NOTE take all the answers!
                answer_list = [answer_str for answer_str in data_item['answers']]
                id_to_labels[idx_str] = answer_list

            elif "answer" in data_item:
                references.append([data_item['answer']])  # take the single answer, as a list
                id_to_labels[idx_str] = [data_item['answer']]
            else:
                raise ValueError("need answer or answers from input json")
    return id_to_labels, id_list, questions, references  # answers

def convert_to_seq(aquestion, apred):
    if apred is None:
        apred = ""

    matched_pred = find_answer(apred)
    if matched_pred is None:
        matched_pred = apred

    apred = '({})'.format(matched_pred)

    alist = aquestion.split('\n')
    for aitem in alist:
        aitem = aitem.strip()
        if aitem.startswith(apred):
            pred_out = ' '.join(aitem.split(' ')[1:])
            print('from {} to [{}]'.format(apred, pred_out))
            return pred_out
    print('Warning: could not find ({}) from question {}'.format(apred, aquestion))

    return apred

# 500 -> 100
def load_prediction(test_file, id_list, id_to_labels, questions, dataset_name):
    predictions = []
    with open(test_file, "r") as f:
        for line in f.readlines():
            predictions.append(line.strip())
    if len(predictions) != len(id_list):
        print("NOTE: different number of samples, {} in prediction, yet {} in reference".format(
            len(predictions), len(id_list)))
        id_list = id_list[0: len(predictions)]

    id_to_prediction = dict()
    for aid, apred in zip(id_list, predictions):
        id_to_prediction[aid] = apred

    if dataset_name.startswith('quality'):
        print('quality dataset, and rewriting the prediction to the full textual sequence...')
        questions = questions[0: len(predictions)]
        id_to_prediction = dict()
        for aid, aquestion, apred in zip(id_list, questions, predictions):
            apred_seq = convert_to_seq(aquestion, apred)
            id_to_prediction[aid] = apred_seq

    return id_to_prediction, id_list, predictions

def main(args):
    datasets = [args.dataset] if args.dataset in DATASETS else DATASETS
    for dataset_name in datasets:
        print(dataset_name)

        ground_truth_file = args.datapath
        prediction_file = args.gen_test_file

        evaluate_f1(ground_truth_file, prediction_file, dataset_name)


def main_orig(args, raise_on_errors=False):
    datasets = [args.dataset] if args.dataset in DATASETS else DATASETS
    for dataset_name in datasets:
        print(dataset_name)

        id_to_labels, id_list, questions, answers = read_json_data(args.datapath)
        id_to_pred, id_list, predictions = load_prediction(args.gen_test_file,
                                                           id_list, id_to_labels, questions,
                                                           dataset_name)

        if len(id_to_labels) > len(id_list):
            print('NOTE: prune the reference set from {} to {}'.format(
                len(id_to_labels), len(id_list)))
            id_to_labels = {aid: id_to_labels[aid] for aid in id_list}

        errors, details = verify(id_to_pred, id_to_labels)

        if len(errors) == 0:
            score = scorer(dataset_name, predictions, answers, all_classes=None)
            print('final display:', dataset_name, score)
        elif len(errors) > 0:
            errors_msg = errors[0] if len(errors) == 1 else " ".join(f"{i}: {err}" for i, err in enumerate(errors))
            print(json.dumps(errors, indent=4))
            raise ValueError(f"Failed to evaluate due to: {errors_msg}")


def download_metric():
    scrolls_metric_path = hf_hub_download(repo_id="tau/scrolls", filename="metrics/scrolls.py", repo_type="dataset")
    updated_scrolls_metric_path = (
        os.path.dirname(scrolls_metric_path) + os.path.basename(scrolls_metric_path).replace(".", "_") + ".py"
    )
    shutil.copy(scrolls_metric_path, updated_scrolls_metric_path)
    return updated_scrolls_metric_path

def verify(id_to_pred, id_to_labels):
    errors = []
    details = {"missing_keys": [], "redundant_keys": []}
    if not isinstance(id_to_pred, dict):
        errors.append('The predictions must be saved a JSON object: {"id1": "prediction1", "id2": "prediction2", ...}')
    else:
        if not all(isinstance(key, str) for key in id_to_pred.keys()):
            errors.append("All keys of the predictions dictionary must be strings")
        if not all(isinstance(value, str) for value in id_to_pred.values()):
            errors.append("All values of the predictions dictionary must be strings")
        if len(errors) == 0:
            predictions_keys, reference_keys = set(id_to_pred.keys()), set(id_to_labels.keys())
            missing_keys = reference_keys - predictions_keys
            redundant_keys = predictions_keys - reference_keys

            if len(missing_keys) > 0:
                details["missing_keys"] = list(missing_keys)
                errors.append(f"There are missing example IDs.")
            else:
                del details["missing_keys"]

            if len(redundant_keys) > 0:
                details["redundant_keys"] = list(redundant_keys)
                errors.append(f"There are redundant example IDs.")
            else:
                del details["redundant_keys"]

    return errors, details

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate SCROLLS predictions per dataset")

    parser.add_argument("--datapath", type=str,
                        default=None, help="datapath for test json file [reference]")
    parser.add_argument("--gen_test_file", type=str,
                        default=None, help="generations for test file [system prediction]")
    parser.add_argument("--dataset", type=str,
                        default=None, help="name of the dataset used in scrolls: {}".format(DATASETS))

    args = parser.parse_args()
    main(args)
evaluation/long_32k_eval/eval_retro_vllm.sh
ADDED
@@ -0,0 +1,118 @@
#########################################################################
# File Name: eval.sh
# Author: Xianchao Wu, Peng Xu
# mail: [email protected], [email protected]
# Created Time: Mon Sep 4 07:33:40 2024
#########################################################################
#!/bin/bash

# TODO change this to your reference file dir:
REFDIR="" # data_home https://huggingface.co/nvidia/Llama3-ChatQA-2-70B/tree/main/data


# TODO change to your tstdir
model_path=""
TSTDIR="${model_path}/outputs/"

model_size=70b # TODO change this
retriever="e5_mistral_retriever_chunkbysents1200"

adir=$TSTDIR
echo $adir

declare -A dataset2num_samples
dataset2num_samples["gov_report"]=200
dataset2num_samples["narrative_qa"]=2000
dataset2num_samples["qasper"]=2000
dataset2num_samples["qmsum"]=200
dataset2num_samples["quality"]=2000
dataset2num_samples["summ_screen_fd"]=200

dataset2num_samples["musique"]=200
dataset2num_samples["hotpotqa"]=200
dataset2num_samples["multifieldqa_en"]=200

dataset2num_samples["squality"]=200

dataset2num_samples["doc2dial_full_dialogue"]=1000

echo "ref path = $REFDIR"
echo "tst out path = $TSTDIR"

declare -A sys2name
sys2name["baseline"]=""
sys2name["ret"]="_ctx5"

for system in "baseline" "ret"
do
    suffix=${sys2name[${system}]}

    echo "--final display----$system----"
    for adataset in "qmsum" "qasper" "quality"
    do
        echo $adataset

        ref_fn="${REFDIR}/${adataset}.${retriever}/test.json"
        tst_fn="${adir}/${adataset}.e5_mistral_retriever_chunkbysents1200_output_0to${dataset2num_samples[${adataset}]}${suffix}.txt"

        echo "ref for ${adataset}", ${ref_fn}
        echo "tstout for ${adataset}", ${tst_fn}
        if [[ ! -e ${tst_fn} ]]; then
            echo "Error: tst_fn=${tst_fn} not exist!"
        fi
        if [[ ! -e ${ref_fn} ]]; then
            echo "Error: ref_fn=${ref_fn} not exist!"
        fi
        #continue

        if [[ -e ${tst_fn} && -e ${ref_fn} ]]
        then
            python3 dataset_evaluator_retro.py \
                --datapath ${ref_fn} \
                --gen_test_file ${tst_fn} \
                --dataset $adataset
        else
            if [[ ! -e ${tst_fn} ]]; then
                echo "Error: tst_fn=${tst_fn} not exist!"
            fi
            if [[ ! -e ${ref_fn} ]]; then
                echo "Error: ref_fn=${ref_fn} not exist!"
            fi
        fi
    done

    for adataset in "musique" "hotpotqa" "multifieldqa_en"
    do
        echo $adataset

        ref_fn="${REFDIR}/${adataset}.${retriever}/test.json"
        # TODO change this if necessary, model's prediction output
        tst_fn="${adir}/${adataset}.e5_mistral_retriever_chunkbysents1200_output_0to${dataset2num_samples[${adataset}]}${suffix}.txt"

        echo "ref for ${adataset}", ${ref_fn}
        echo "tstout for ${adataset}", ${tst_fn}
        if [[ ! -e ${tst_fn} ]]; then
            echo "Error: tst_fn=${tst_fn} not exist!"
        fi
        if [[ ! -e ${ref_fn} ]]; then
            echo "Error: ref_fn=${ref_fn} not exist!"
        fi
        #continue

        if [[ -e ${tst_fn} && -e ${ref_fn} ]]
        then
            python3 dataset_evaluator_retro_longbench.py \
                --datapath ${ref_fn} \
                --gen_test_file ${tst_fn} \
                --dataset $adataset
        else
            if [[ ! -e ${tst_fn} ]]; then
                echo "Error: tst_fn=${tst_fn} not exist!"
            fi
            if [[ ! -e ${ref_fn} ]]; then
                echo "Error: ref_fn=${ref_fn} not exist!"
            fi
        fi
    done
done
evaluation/long_32k_eval/extract_log.py
ADDED
@@ -0,0 +1,88 @@
import sys
import numpy as np

DATASETS1 = [
    "qmsum",
    "qasper",
    "quality",
    'musique',
    'hotpotqa',
    'multifieldqa_en'
]

DATASETS = [
    "qmsum",
    "qasper",
    "quality",
    'musique',
    'hotpotqa',
    'multifieldqa_en',
]

outrow = ''
data2res = dict()

def average(data2res):
    sumvalue = 0.0
    sumnum = 0.0
    for adata in data2res:
        avalue = data2res[adata]
        sumvalue += avalue
        sumnum += 1

    assert sumnum > 0.0
    return sumvalue/sumnum

def collect(value_list, outrow, data2res):
    #print(value_list)
    # first add the single avg score:
    avg = round(np.mean(value_list), 4)
    outrow += str(avg) + ' '

    avg2 = average(data2res)
    avg2 = round(avg2, 4)
    outrow += str(avg2) + ' '

    for adata in DATASETS:
        ares = data2res[adata] if adata in data2res else "NA"
        outrow += str(ares) + " "
    print(outrow.strip())

print('system avg6 avg6 ' + ' '.join(DATASETS))

#infn = "eval_retro_2.sh.log.2"
#with open(infn) as br:
#for aline in br.readlines():
value_list = list()
for aline in sys.stdin:
    #import ipdb; ipdb.set_trace()
    aline = aline.strip()
    if 'final display' in aline:
        if '-baseline-' in aline or '-ret-' in aline:
            if len(outrow) > 0 and len(data2res) > 0:
                collect(value_list, outrow, data2res)

            outrow = ""  # reset
            data2res = dict()
            value_list = list()

            aline2 = aline.replace('-', '')
            aline2 = aline2.replace('final display', '')
            outrow += aline2 + ' '
            continue

        cols = aline.split(' ')
        adata = cols[2]
        ares = '/'.join(cols[3:])  # NOTE use one geometric_mean instead
        scores = cols[3:]
        # for R1/R2/RL geometric_mean:
        if len(scores) == 3:
            scores = [float(item) for item in scores]
            geo_mean = (scores[0] * scores[1] * scores[2]) ** (1.0 / 3.0)
            ares = str(round(geo_mean, 4))

        data2res[adata] = float(ares)
        value_list.append(float(ares))

collect(value_list, outrow, data2res)
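A worked check (scores made up) of how extract_log.py collapses a three-value "final display" line such as "final display: qmsum 0.25 0.08 0.2" into a single number via the geometric mean:

    # Geometric mean of the three ROUGE scores, as computed in the loop above:
    scores = [0.25, 0.08, 0.2]
    geo_mean = (scores[0] * scores[1] * scores[2]) ** (1.0 / 3.0)
    print(round(geo_mean, 4))  # 0.1587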
evaluation/long_32k_eval/longbench/__pycache__/eval.cpython-310.pyc
ADDED
Binary file (3.23 kB)
evaluation/long_32k_eval/longbench/__pycache__/metrics.cpython-310.pyc
ADDED
Binary file (5.93 kB)
evaluation/long_32k_eval/longbench/eval.py
ADDED
@@ -0,0 +1,127 @@
import os
import json
import argparse
import numpy as np

from .metrics import (
    qa_f1_score,
    rouge_zh_score,
    qa_f1_zh_score,
    rouge_score,
    classification_score,
    retrieval_score,
    retrieval_zh_score,
    count_score,
    code_sim_score,
)

dataset2metric = {
    "narrativeqa": qa_f1_score,
    "qasper": qa_f1_score,
    "multifieldqa_en": qa_f1_score,  # NOTE
    "multifieldqa_zh": qa_f1_zh_score,
    "hotpotqa": qa_f1_score,  # NOTE
    "2wikimqa": qa_f1_score,
    "musique": qa_f1_score,  # NOTE
    "dureader": rouge_zh_score,
    "gov_report": rouge_score,
    "qmsum": rouge_score,
    "multi_news": rouge_score,
    "vcsum": rouge_zh_score,
    "trec": classification_score,
    "triviaqa": qa_f1_score,
    "samsum": rouge_score,
    "lsht": classification_score,
    "passage_retrieval_en": retrieval_score,
    "passage_count": count_score,
    "passage_retrieval_zh": retrieval_zh_score,
    "lcc": code_sim_score,
    "repobench-p": code_sim_score,
}

def parse_args(args=None):
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default=None)
    parser.add_argument('--e', action='store_true', help="Evaluate on LongBench-E")
    return parser.parse_args(args)

def scorer_e(dataset, predictions, answers, lengths, all_classes):
    scores = {"0-4k": [], "4-8k": [], "8k+": []}
    for (prediction, ground_truths, length) in zip(predictions, answers, lengths):
        score = 0.
        if dataset in ["trec", "triviaqa", "samsum", "lsht"]:
            prediction = prediction.lstrip('\n').split('\n')[0]
        for ground_truth in ground_truths:
            score = max(score, dataset2metric[dataset](prediction, ground_truth, all_classes=all_classes))
        if length < 4000:
            scores["0-4k"].append(score)
        elif length < 8000:
            scores["4-8k"].append(score)
        else:
            scores["8k+"].append(score)
    for key in scores.keys():
        scores[key] = round(100 * np.mean(scores[key]), 2)
    return scores

def scorer(dataset, predictions, answers, all_classes):
    # dataset = 'hotpotqa', 'musique', 'multifieldqa_en'
    # predictions = [pred.str, ..., ]
    # answers = [ [answer.str, ...], ... ]
    # all_classes = None

    #import ipdb; ipdb.set_trace()  # all_classes=None for 'hotpotqa' dataset NOTE
    total_score = 0.
    for (prediction, ground_truths) in zip(predictions, answers):
        score = 0.
        if dataset in ["trec", "triviaqa", "samsum", "lsht"]:
            prediction = prediction.lstrip('\n').split('\n')[0]
        for ground_truth in ground_truths:
            score = max(score, dataset2metric[dataset](prediction, ground_truth, all_classes=all_classes))
        total_score += score

    outscore = round(100 * total_score / len(predictions), 2)
    print(dataset, outscore)
    return outscore

if __name__ == '__main__':
    #import ipdb; ipdb.set_trace()
    args = parse_args()
    scores = dict()
    if args.e:
        path = f"pred_e/{args.model}/"
    else:
        path = f"pred/{args.model}/"  # 'pred/chatglm2-6b-32k/' NOTE
    all_files = os.listdir(path)  # 21 files
    print("Evaluating on:", all_files)

    for filename in all_files:
        #import ipdb; ipdb.set_trace()
        if not filename.endswith("jsonl"):
            continue
        predictions, answers, lengths = [], [], []
        dataset = filename.split('.')[0]  # get the dataset name
        if not dataset in ['musique', 'hotpotqa', 'multifieldqa_en']:
            continue  # TODO debug only
        with open(f"{path}{filename}", "r", encoding="utf-8") as f:
            for line in f:  # parse each line as one JSON record
                data = json.loads(line)
                predictions.append(data["pred"])
                answers.append(data["answers"])
                all_classes = data["all_classes"]  # this gets re-assigned on every line
                if "length" in data:
                    lengths.append(data["length"])
        if args.e:
            score = scorer_e(dataset, predictions, answers, lengths, all_classes)
        else:
            score = scorer(dataset, predictions, answers, all_classes)  # NOTE the main scoring entry point. TODO 1. dataset = the concrete dataset name; predictions = list of str (model outputs); answers = list of lists (reference answers); all_classes comes with the original test data
        scores[dataset] = score
    if args.e:
        out_path = f"pred_e/{args.model}/result.json"
    else:
        out_path = f"pred/{args.model}/result.json"

    print(scores)

    with open(out_path, "w") as f:
        json.dump(scores, f, ensure_ascii=False, indent=4)
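A minimal sketch of calling scorer directly with made-up predictions and multi-reference answers; for these QA sets it dispatches to qa_f1_score and averages the per-example maximum over references:

    from longbench.eval import scorer

    # Made-up predictions and references (all_classes is None for these QA sets):
    preds = ["Paris", "the moon"]
    refs = [["Paris"], ["Mars", "the red planet"]]
    scorer("hotpotqa", preds, refs, all_classes=None)
    # per-example best F1 is 1.0 and 0.0 -> prints "hotpotqa 50.0" and returns 50.0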
evaluation/long_32k_eval/longbench/metrics.py
ADDED
@@ -0,0 +1,154 @@
import re
import string

import jieba
from fuzzywuzzy import fuzz
import difflib

from typing import List
from collections import Counter
from rouge import Rouge

def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def normalize_zh_answer(s):
    """Lower text and remove punctuation, extra whitespace."""

    def white_space_fix(text):
        return "".join(text.split())

    def remove_punc(text):
        cn_punctuation = "！？｡。＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～｟｠｢｣､、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
        all_punctuation = set(string.punctuation + cn_punctuation)
        return "".join(ch for ch in text if ch not in all_punctuation)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_punc(lower(s)))

def count_score(prediction, ground_truth, **kwargs):
    numbers = re.findall(r"\d+", prediction)
    right_num = 0
    for number in numbers:
        if str(number) == str(ground_truth):
            right_num += 1
    final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
    return float(final_score)

def retrieval_score(prediction, ground_truth, **kwargs):
    pattern = r'Paragraph (\d+)'
    matches = re.findall(pattern, ground_truth)
    ground_truth_id = matches[0]
    numbers = re.findall(r"\d+", prediction)
    right_num = 0
    for number in numbers:
        if str(number) == str(ground_truth_id):
            right_num += 1
    final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
    return float(final_score)

def retrieval_zh_score(prediction, ground_truth, **kwargs):
    pattern = r'段落(\d+)'
    matches = re.findall(pattern, ground_truth)
    ground_truth_id = matches[0]
    numbers = re.findall(r"\d+", prediction)
    right_num = 0
    for number in numbers:
        if str(number) == str(ground_truth_id):
            right_num += 1
    final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
    return float(final_score)

def code_sim_score(prediction, ground_truth, **kwargs):
    all_lines = prediction.lstrip('\n').split('\n')
    prediction = ""
    for line in all_lines:
        if ('`' not in line) and ('#' not in line) and ('//' not in line):
            prediction = line
            break
    return (fuzz.ratio(prediction, ground_truth) / 100)

def classification_score(prediction, ground_truth, **kwargs):
    em_match_list = []
    all_classes = kwargs["all_classes"]
    for class_name in all_classes:
        if class_name in prediction:
            em_match_list.append(class_name)
    for match_term in em_match_list:
        if match_term in ground_truth and match_term != ground_truth:
            em_match_list.remove(match_term)
    if em_match_list != 0:
        if ground_truth in em_match_list:
            score = (1.0 / len(em_match_list))
        else:
            score = 0.0
    else:
        best_match = None
        highest_similarity = 0
        for string in all_classes:
            similarity = difflib.SequenceMatcher(None, string, prediction).ratio()
            if similarity > highest_similarity:
                highest_similarity = similarity
                best_match = string
        score = float(best_match == ground_truth)
    return score

def rouge_score(prediction, ground_truth, **kwargs):
    rouge = Rouge()
    try:
        scores = rouge.get_scores([prediction], [ground_truth], avg=True)
    except:
        return 0.0
    return scores["rouge-l"]["f"]

def rouge_zh_score(prediction, ground_truth, **kwargs):
    prediction = " ".join(list(jieba.cut(prediction, cut_all=False)))
    ground_truth = " ".join(list(jieba.cut(ground_truth, cut_all=False)))
    score = rouge_score(prediction, ground_truth)
    return score

def f1_score(prediction, ground_truth, **kwargs):
    common = Counter(prediction) & Counter(ground_truth)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction)
    recall = 1.0 * num_same / len(ground_truth)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1

def qa_f1_score(prediction, ground_truth, **kwargs):
    normalized_prediction = normalize_answer(prediction)
    normalized_ground_truth = normalize_answer(ground_truth)

    prediction_tokens = normalized_prediction.split()
    ground_truth_tokens = normalized_ground_truth.split()
    return f1_score(prediction_tokens, ground_truth_tokens)


def qa_f1_zh_score(prediction, ground_truth, **kwargs):
    prediction_tokens = list(jieba.cut(prediction, cut_all=False))
    ground_truth_tokens = list(jieba.cut(ground_truth, cut_all=False))
    prediction_tokens = [normalize_zh_answer(token) for token in prediction_tokens]
    ground_truth_tokens = [normalize_zh_answer(token) for token in ground_truth_tokens]
    prediction_tokens = [token for token in prediction_tokens if len(token) > 0]
    ground_truth_tokens = [token for token in ground_truth_tokens if len(token) > 0]
    return f1_score(prediction_tokens, ground_truth_tokens)
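A worked example (made-up strings) of the token-level F1 defined above:

    from longbench.metrics import qa_f1_score

    # normalize_answer drops articles and punctuation: "The capital is Paris." -> "capital is paris"
    # token overlap with "paris": precision 1/3, recall 1 -> F1 = 0.5
    print(qa_f1_score("The capital is Paris.", "Paris"))  # 0.5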
evaluation/long_32k_eval/run_eval_vllm.sh
ADDED
@@ -0,0 +1,10 @@
#!/bin/bash

adate=`date +%Y_%m_%d_%H_%M_%S`
outlog="eval_retro_vllm.sh.log.6_4setsplus.$adate"

echo "out log = $outlog"

bash eval_retro_vllm.sh > $outlog 2>&1

grep "final display" $outlog | python3 extract_log.py