root committed
Commit dfdc6c0 · 1 Parent(s): aba3fe3

add long_32k_eval
evaluation/long_32k_eval/dataset_evaluator_retro.py ADDED
@@ -0,0 +1,178 @@
+ import os
+ import argparse
+ import json
+ import shutil
+ import re
+
+ from datasets import load_dataset, load_metric
+ from huggingface_hub import hf_hub_download
+
+ DATASETS = [
+     "gov_report",
+     "summ_screen_fd",
+     "qmsum",
+     "qasper",
+     "narrative_qa",
+     "quality",
+     "quality_hard",
+     "contract_nli",
+ ]
+
+ PATTERN = re.compile(r'\b[A-D]\b')
+
+ def find_answer(s):
+     match = PATTERN.search(s)
+     if match is None:
+         return None  # NOTE: None signals that no answer letter was found
+     return match.group()
+
+ def read_json_data(data_path):
+     references = []
+     questions = []
+     id_to_labels = dict()
+     id_list = list()
+     idx = 0
+     with open(data_path, "r") as f:
+         examples = json.load(f)
+     for data_item in examples:  # dict_keys(['source', 'paragraph_id', 'question', 'answer', 'sub-paragraphs', 'word_count', 'id', 'ctxs'])
+         idx_str = str(idx) if 'id' not in data_item else str(data_item['id'])
+         idx += 1
+         id_list.append(idx_str)
+
+         questions.append(data_item['question'])
+         if "answers" in data_item:
+             references.append(data_item['answers'][0])
+             answer_list = [answer_str for answer_str in data_item['answers']]
+             id_to_labels[idx_str] = answer_list
+
+         elif "answer" in data_item:
+             references.append(data_item['answer'])  # take the single answer
+             id_to_labels[idx_str] = [data_item['answer']]
+         else:
+             raise ValueError("need 'answer' or 'answers' in the input json")
+     return id_to_labels, id_list, questions
+
+ def convert_to_seq(aquestion, apred):
+     if apred is None:
+         apred = ""
+
+     matched_pred = find_answer(apred)
+     if matched_pred is None:
+         matched_pred = apred
+
+     apred = '({})'.format(matched_pred)
+
+     alist = aquestion.split('\n')
+     for aitem in alist:
+         aitem = aitem.strip()
+         if aitem.startswith(apred):
+             pred_out = ' '.join(aitem.split(' ')[1:])
+             print('from {} to [{}]'.format(apred, pred_out))
+             return pred_out
+     print('Warning: could not find {} in question {}'.format(apred, aquestion))
+
+     return apred
+
+ # 500 -> 100
+ def load_prediction(test_file, id_list, id_to_labels, questions, dataset_name):
+     predictions = []
+     with open(test_file, "r") as f:
+         for line in f.readlines():
+             predictions.append(line.strip())
+     if len(predictions) != len(id_list):
+         print("NOTE: different number of samples, {} in prediction but {} in reference".format(
+             len(predictions), len(id_list)))
+         id_list = id_list[0: len(predictions)]
+
+     id_to_prediction = dict()
+     for aid, apred in zip(id_list, predictions):
+         id_to_prediction[aid] = apred
+
+     if dataset_name.startswith('quality'):
+         print('quality dataset: rewriting the predictions to the full textual sequence...')
+         questions = questions[0: len(predictions)]
+         id_to_prediction = dict()
+         for aid, aquestion, apred in zip(id_list, questions, predictions):
+             apred_seq = convert_to_seq(aquestion, apred)
+             id_to_prediction[aid] = apred_seq
+
+     return id_to_prediction, id_list
+
+ def main(args, raise_on_errors=False):
+     datasets = [args.dataset] if args.dataset in DATASETS else DATASETS
+     for dataset_name in datasets:
+         print(dataset_name)
+         scrolls_metric = load_metric(download_metric(), dataset_name)  # NOTE: loading the metric takes a while
+
+         id_to_labels, id_list, questions = read_json_data(args.datapath)
+         id_to_pred, id_list = load_prediction(args.gen_test_file,
+                                               id_list, id_to_labels, questions,
+                                               dataset_name)
+
+         if len(id_to_labels) > len(id_list):
+             print('NOTE: prune the reference set from {} to {}'.format(
+                 len(id_to_labels), len(id_list)))
+             id_to_labels = {aid: id_to_labels[aid] for aid in id_list}
+
+         errors, details = verify(id_to_pred, id_to_labels)
+
+         if len(errors) == 0:
+             metrics = scrolls_metric.compute(**scrolls_metric.convert_from_map_format(id_to_pred, id_to_labels))
+             print(json.dumps(metrics, indent=4))
+             dislist = [str(item) for item in metrics['display']]
+             print('final display:', dataset_name, ' '.join(dislist))
+         elif len(errors) > 0:
+             errors_msg = errors[0] if len(errors) == 1 else " ".join(f"{i}: {err}" for i, err in enumerate(errors))
+             print(json.dumps(errors, indent=4))
+             raise ValueError(f"Failed to evaluate due to: {errors_msg}")
+
+
+ def download_metric():
+     scrolls_metric_path = hf_hub_download(repo_id="tau/scrolls", filename="metrics/scrolls.py", repo_type="dataset")
+     updated_scrolls_metric_path = (
+         os.path.dirname(scrolls_metric_path) + os.path.basename(scrolls_metric_path).replace(".", "_") + ".py"
+     )
+     shutil.copy(scrolls_metric_path, updated_scrolls_metric_path)
+     return updated_scrolls_metric_path
+
+ def verify(id_to_pred, id_to_labels):
+     errors = []
+     details = {"missing_keys": [], "redundant_keys": []}
+     if not isinstance(id_to_pred, dict):
+         errors.append('The predictions must be saved as a JSON object: {"id1": "prediction1", "id2": "prediction2", ...}')
+     else:
+         if not all(isinstance(key, str) for key in id_to_pred.keys()):
+             errors.append("All keys of the predictions dictionary must be strings")
+         if not all(isinstance(value, str) for value in id_to_pred.values()):
+             errors.append("All values of the predictions dictionary must be strings")
+         if len(errors) == 0:
+             predictions_keys, reference_keys = set(id_to_pred.keys()), set(id_to_labels.keys())
+             missing_keys = reference_keys - predictions_keys
+             redundant_keys = predictions_keys - reference_keys
+
+             if len(missing_keys) > 0:
+                 details["missing_keys"] = list(missing_keys)
+                 errors.append("There are missing example IDs.")
+             else:
+                 del details["missing_keys"]
+
+             if len(redundant_keys) > 0:
+                 details["redundant_keys"] = list(redundant_keys)
+                 errors.append("There are redundant example IDs.")
+             else:
+                 del details["redundant_keys"]
+
+     return errors, details
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Evaluate SCROLLS predictions per dataset")
+
+     parser.add_argument("--datapath", type=str,
+                         default=None, help="path to the test json file [reference]")
+     parser.add_argument("--gen_test_file", type=str,
+                         default=None, help="generations for the test file [system prediction]")
+     parser.add_argument("--dataset", type=str,
+                         default=None, help="name of the dataset used in scrolls: {}".format(DATASETS))
+
+     args = parser.parse_args()
+     main(args)
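A short illustrative sketch (not part of the commit) of how the quality-style predictions are rewritten before scoring: find_answer() pulls out the predicted letter and convert_to_seq() maps it back to the full option text taken from the question. The question and raw prediction below are made up, and the sketch assumes the script's dependencies are installed and that you run it from evaluation/long_32k_eval/.

# toy_quality_rewrite.py -- illustrative only, not part of the commit
from dataset_evaluator_retro import find_answer, convert_to_seq

question = "Who owns the boat?\n(A) Tom\n(B) The harbour master\n(C) Ann\n(D) Nobody"
raw_pred = "The answer is B."

print(find_answer(raw_pred))               # -> 'B'
print(convert_to_seq(question, raw_pred))  # -> 'The harbour master'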
evaluation/long_32k_eval/dataset_evaluator_retro_longbench.py ADDED
@@ -0,0 +1,203 @@
+ import os
+ import argparse
+ import json
+ import shutil
+ import re
+
+ from datasets import load_dataset, load_metric
+ from huggingface_hub import hf_hub_download
+
+ from longbench.eval import scorer
+
+ LONGBENCH_DATASETS = [
+     'musique',  # TODO: add the remaining LongBench datasets
+     'hotpotqa',
+     'multifieldqa_en'
+ ]
+
+ PATTERN = re.compile(r'\b[A-D]\b')
+
+ def find_answer(s):
+     match = PATTERN.search(s)
+     if match is None:
+         return None  # NOTE: None signals that no answer letter was found
+     return match.group()
+
+ def read_json_data(data_path):
+     references = []
+     questions = []
+     id_to_labels = dict()
+     id_list = list()
+     idx = 0
+     with open(data_path, "r") as f:
+         examples = json.load(f)
+     for data_item in examples:  # dict_keys(['source', 'paragraph_id', 'question', 'answer', 'sub-paragraphs', 'word_count', 'id', 'ctxs'])
+         idx_str = str(idx) if 'id' not in data_item else str(data_item['id'])
+         idx += 1
+         id_list.append(idx_str)
+
+         questions.append(data_item['question'])
+         if "answers" in data_item:
+             references.append(data_item['answers'])  # NOTE: take all the answers
+             answer_list = [answer_str for answer_str in data_item['answers']]
+             id_to_labels[idx_str] = answer_list
+
+         elif "answer" in data_item:
+             references.append([data_item['answer']])  # take the single answer, as a list
+             id_to_labels[idx_str] = [data_item['answer']]
+         else:
+             raise ValueError("need 'answer' or 'answers' in the input json")
+     return id_to_labels, id_list, questions, references  # references == answers
+
+ def convert_to_seq(aquestion, apred):
+     if apred is None:
+         apred = ""
+
+     matched_pred = find_answer(apred)
+     if matched_pred is None:
+         matched_pred = apred
+
+     apred = '({})'.format(matched_pred)
+
+     alist = aquestion.split('\n')
+     for aitem in alist:
+         aitem = aitem.strip()
+         if aitem.startswith(apred):
+             pred_out = ' '.join(aitem.split(' ')[1:])
+             print('from {} to [{}]'.format(apred, pred_out))
+             return pred_out
+     print('Warning: could not find {} in question {}'.format(apred, aquestion))
+
+     return apred
+
+ def load_prediction_openai(test_file):
+     predictions = []
+     with open(test_file, "r") as f:
+         apred_list = list()
+         for aline in f.readlines():
+             if aline.startswith('assistant: '):
+                 if len(apred_list) > 0:
+                     print('\n'.join(apred_list))
+                     predictions.append('\n'.join(apred_list))
+                     apred_list = list()
+                 apred_list.append(aline[len('assistant: '):].strip())
+             else:
+                 apred_list.append(aline.strip())
+
+     if len(apred_list) > 0:
+         predictions.append('\n'.join(apred_list))
+     print(len(predictions))
+     return predictions
+
+
+ # 500 -> 100
+ def load_prediction(test_file, id_list, id_to_labels,
+                     questions, dataset_name, is_openai_assistant=False):
+     if is_openai_assistant:
+         predictions = load_prediction_openai(test_file)
+     else:
+         predictions = []
+         with open(test_file, "r") as f:
+             for line in f.readlines():
+                 predictions.append(line.strip())
+
+     if len(predictions) != len(id_list):
+         print("NOTE: different number of samples, {} in prediction but {} in reference".format(
+             len(predictions), len(id_list)))
+         id_list = id_list[0: len(predictions)]
+
+     id_to_prediction = dict()
+     for aid, apred in zip(id_list, predictions):
+         id_to_prediction[aid] = apred
+
+     if dataset_name.startswith('quality'):
+         print('quality dataset: rewriting the predictions to the full textual sequence...')
+         questions = questions[0: len(predictions)]
+         id_to_prediction = dict()
+         for aid, aquestion, apred in zip(id_list, questions, predictions):
+             apred_seq = convert_to_seq(aquestion, apred)
+             id_to_prediction[aid] = apred_seq
+
+     return id_to_prediction, id_list, predictions
+
+ def main(args, raise_on_errors=False):
+     datasets = [args.dataset] if args.dataset in LONGBENCH_DATASETS else LONGBENCH_DATASETS
+     for dataset_name in datasets:
+         print(dataset_name)
+
+         id_to_labels, id_list, questions, answers = read_json_data(args.datapath)
+         id_to_pred, id_list, predictions = load_prediction(args.gen_test_file,
+                                                            id_list, id_to_labels, questions,
+                                                            dataset_name, args.is_openai_assistant)
+
+         if len(id_to_labels) > len(id_list):
+             print('NOTE: prune the reference set from {} to {}'.format(
+                 len(id_to_labels), len(id_list)))
+             id_to_labels = {aid: id_to_labels[aid] for aid in id_list}
+
+         errors, details = verify(id_to_pred, id_to_labels)
+
+         if len(errors) == 0:
+             score = scorer(dataset_name, predictions, answers, all_classes=None)
+             print('final display:', dataset_name, score, "\n", args.gen_test_file)
+         elif len(errors) > 0:
+             errors_msg = errors[0] if len(errors) == 1 else " ".join(f"{i}: {err}" for i, err in enumerate(errors))
+             print(json.dumps(errors, indent=4))
+             raise ValueError(f"Failed to evaluate due to: {errors_msg}")
+
+
+ def download_metric():
+     scrolls_metric_path = hf_hub_download(repo_id="tau/scrolls", filename="metrics/scrolls.py", repo_type="dataset")
+     updated_scrolls_metric_path = (
+         os.path.dirname(scrolls_metric_path) + os.path.basename(scrolls_metric_path).replace(".", "_") + ".py"
+     )
+     shutil.copy(scrolls_metric_path, updated_scrolls_metric_path)
+     return updated_scrolls_metric_path
+
+ def verify(id_to_pred, id_to_labels):
+     errors = []
+     details = {"missing_keys": [], "redundant_keys": []}
+     if not isinstance(id_to_pred, dict):
+         errors.append('The predictions must be saved as a JSON object: {"id1": "prediction1", "id2": "prediction2", ...}')
+     else:
+         if not all(isinstance(key, str) for key in id_to_pred.keys()):
+             errors.append("All keys of the predictions dictionary must be strings")
+         if not all(isinstance(value, str) for value in id_to_pred.values()):
+             errors.append("All values of the predictions dictionary must be strings")
+         if len(errors) == 0:
+             predictions_keys, reference_keys = set(id_to_pred.keys()), set(id_to_labels.keys())
+             missing_keys = reference_keys - predictions_keys
+             redundant_keys = predictions_keys - reference_keys
+
+             if len(missing_keys) > 0:
+                 details["missing_keys"] = list(missing_keys)
+                 errors.append("There are missing example IDs.")
+             else:
+                 del details["missing_keys"]
+
+             if len(redundant_keys) > 0:
+                 details["redundant_keys"] = list(redundant_keys)
+                 errors.append("There are redundant example IDs.")
+             else:
+                 del details["redundant_keys"]
+
+     return errors, details
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Evaluate LongBench predictions per dataset")
+
+     dataset_help = "name of the dataset used in longbench: {}".format(LONGBENCH_DATASETS)
+     parser.add_argument("--datapath", type=str, required=True,
+                         default=None, help="path to the test json file [reference]")
+     parser.add_argument("--gen_test_file", type=str, required=True,
+                         default=None, help="generations for the test file [system prediction]")
+     parser.add_argument("--dataset", type=str, required=True,
+                         default=None, help=dataset_help)
+     parser.add_argument("--is_openai_assistant", action="store_true",
+                         help='if set, merge multi-line OpenAI assistant outputs; each answer starts on a line beginning with "assistant: "')
+
+     args = parser.parse_args()
+     print(args)
+     main(args)
+
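A short illustrative sketch (not part of the commit) of the prediction-file layout that load_prediction_openai() expects when --is_openai_assistant is set: every answer starts on a line beginning with "assistant: ", and any following lines are folded into the same prediction. The file contents are made up, and the sketch assumes the longbench dependencies (jieba, fuzzywuzzy, rouge) are installed and that you run it from evaluation/long_32k_eval/.

# toy_openai_format.py -- illustrative only, not part of the commit
from dataset_evaluator_retro_longbench import load_prediction_openai

with open("toy_openai_preds.txt", "w") as f:
    f.write("assistant: Paris is the capital.\n"
            "It is located on the Seine.\n"
            "assistant: The answer is 42.\n")

preds = load_prediction_openai("toy_openai_preds.txt")
# preds == ['Paris is the capital.\nIt is located on the Seine.', 'The answer is 42.']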
evaluation/long_32k_eval/dataset_evaluator_retro_nv.py ADDED
@@ -0,0 +1,181 @@
+ import os
+ import argparse
+ import json
+ import shutil
+ import re
+
+ from datasets import load_dataset, load_metric
+ from huggingface_hub import hf_hub_download
+
+ from nv.evaluate_f1_sft_zeroshot import evaluate_f1
+ from longbench.eval import scorer  # NOTE: main_orig() below relies on scorer
+
+ DATASETS = [
+     'doc2dial_full_dialogue',
+ ]
+
+ PATTERN = re.compile(r'\b[A-D]\b')
+
+ def find_answer(s):
+     match = PATTERN.search(s)
+     if match is None:
+         return None  # NOTE: None signals that no answer letter was found
+     return match.group()
+
+ def read_json_data(data_path):
+     references = []
+     questions = []
+     id_to_labels = dict()
+     id_list = list()
+     idx = 0
+     with open(data_path, "r") as f:
+         examples = json.load(f)
+     for data_item in examples:  # dict_keys(['source', 'paragraph_id', 'question', 'answer', 'sub-paragraphs', 'word_count', 'id', 'ctxs'])
+         idx_str = str(idx) if 'id' not in data_item else str(data_item['id'])
+         idx += 1
+         id_list.append(idx_str)
+
+         questions.append(data_item['question'])
+         if "answers" in data_item:
+             references.append(data_item['answers'])  # NOTE: take all the answers
+             answer_list = [answer_str for answer_str in data_item['answers']]
+             id_to_labels[idx_str] = answer_list
+
+         elif "answer" in data_item:
+             references.append([data_item['answer']])  # take the single answer, as a list
+             id_to_labels[idx_str] = [data_item['answer']]
+         else:
+             raise ValueError("need 'answer' or 'answers' in the input json")
+     return id_to_labels, id_list, questions, references  # references == answers
+
+ def convert_to_seq(aquestion, apred):
+     if apred is None:
+         apred = ""
+
+     matched_pred = find_answer(apred)
+     if matched_pred is None:
+         matched_pred = apred
+
+     apred = '({})'.format(matched_pred)
+
+     alist = aquestion.split('\n')
+     for aitem in alist:
+         aitem = aitem.strip()
+         if aitem.startswith(apred):
+             pred_out = ' '.join(aitem.split(' ')[1:])
+             print('from {} to [{}]'.format(apred, pred_out))
+             return pred_out
+     print('Warning: could not find {} in question {}'.format(apred, aquestion))
+
+     return apred
+
+ # 500 -> 100
+ def load_prediction(test_file, id_list, id_to_labels, questions, dataset_name):
+     predictions = []
+     with open(test_file, "r") as f:
+         for line in f.readlines():
+             predictions.append(line.strip())
+     if len(predictions) != len(id_list):
+         print("NOTE: different number of samples, {} in prediction but {} in reference".format(
+             len(predictions), len(id_list)))
+         id_list = id_list[0: len(predictions)]
+
+     id_to_prediction = dict()
+     for aid, apred in zip(id_list, predictions):
+         id_to_prediction[aid] = apred
+
+     if dataset_name.startswith('quality'):
+         print('quality dataset: rewriting the predictions to the full textual sequence...')
+         questions = questions[0: len(predictions)]
+         id_to_prediction = dict()
+         for aid, aquestion, apred in zip(id_list, questions, predictions):
+             apred_seq = convert_to_seq(aquestion, apred)
+             id_to_prediction[aid] = apred_seq
+
+     return id_to_prediction, id_list, predictions
+
+ def main(args):
+     datasets = [args.dataset] if args.dataset in DATASETS else DATASETS
+     for dataset_name in datasets:
+         print(dataset_name)
+
+         ground_truth_file = args.datapath
+         prediction_file = args.gen_test_file
+
+         evaluate_f1(ground_truth_file, prediction_file, dataset_name)
+
+
+ def main_orig(args, raise_on_errors=False):
+     datasets = [args.dataset] if args.dataset in DATASETS else DATASETS
+     for dataset_name in datasets:
+         print(dataset_name)
+
+         id_to_labels, id_list, questions, answers = read_json_data(args.datapath)
+         id_to_pred, id_list, predictions = load_prediction(args.gen_test_file,
+                                                            id_list, id_to_labels, questions,
+                                                            dataset_name)
+
+         if len(id_to_labels) > len(id_list):
+             print('NOTE: prune the reference set from {} to {}'.format(
+                 len(id_to_labels), len(id_list)))
+             id_to_labels = {aid: id_to_labels[aid] for aid in id_list}
+
+         errors, details = verify(id_to_pred, id_to_labels)
+
+         if len(errors) == 0:
+             score = scorer(dataset_name, predictions, answers, all_classes=None)
+             print('final display:', dataset_name, score)
+         elif len(errors) > 0:
+             errors_msg = errors[0] if len(errors) == 1 else " ".join(f"{i}: {err}" for i, err in enumerate(errors))
+             print(json.dumps(errors, indent=4))
+             raise ValueError(f"Failed to evaluate due to: {errors_msg}")
+
+
+ def download_metric():
+     scrolls_metric_path = hf_hub_download(repo_id="tau/scrolls", filename="metrics/scrolls.py", repo_type="dataset")
+     updated_scrolls_metric_path = (
+         os.path.dirname(scrolls_metric_path) + os.path.basename(scrolls_metric_path).replace(".", "_") + ".py"
+     )
+     shutil.copy(scrolls_metric_path, updated_scrolls_metric_path)
+     return updated_scrolls_metric_path
+
+ def verify(id_to_pred, id_to_labels):
+     errors = []
+     details = {"missing_keys": [], "redundant_keys": []}
+     if not isinstance(id_to_pred, dict):
+         errors.append('The predictions must be saved as a JSON object: {"id1": "prediction1", "id2": "prediction2", ...}')
+     else:
+         if not all(isinstance(key, str) for key in id_to_pred.keys()):
+             errors.append("All keys of the predictions dictionary must be strings")
+         if not all(isinstance(value, str) for value in id_to_pred.values()):
+             errors.append("All values of the predictions dictionary must be strings")
+         if len(errors) == 0:
+             predictions_keys, reference_keys = set(id_to_pred.keys()), set(id_to_labels.keys())
+             missing_keys = reference_keys - predictions_keys
+             redundant_keys = predictions_keys - reference_keys
+
+             if len(missing_keys) > 0:
+                 details["missing_keys"] = list(missing_keys)
+                 errors.append("There are missing example IDs.")
+             else:
+                 del details["missing_keys"]
+
+             if len(redundant_keys) > 0:
+                 details["redundant_keys"] = list(redundant_keys)
+                 errors.append("There are redundant example IDs.")
+             else:
+                 del details["redundant_keys"]
+
+     return errors, details
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Evaluate SCROLLS predictions per dataset")
+
+     parser.add_argument("--datapath", type=str,
+                         default=None, help="path to the test json file [reference]")
+     parser.add_argument("--gen_test_file", type=str,
+                         default=None, help="generations for the test file [system prediction]")
+     parser.add_argument("--dataset", type=str,
+                         default=None, help="name of the dataset used in scrolls: {}".format(DATASETS))
+
+     args = parser.parse_args()
+     main(args)
evaluation/long_32k_eval/eval_retro_vllm.sh ADDED
@@ -0,0 +1,118 @@
+ #########################################################################
+ # File Name: eval.sh
+ # Author: Xianchao Wu, Peng Xu
+ # Created Time: Mon Sep 4 07:33:40 2024
+ #########################################################################
+ #!/bin/bash
+
+ # TODO change this to your reference file dir:
+ REFDIR="" # data_home https://huggingface.co/nvidia/Llama3-ChatQA-2-70B/tree/main/data
+
+
+ # TODO change to your tstdir
+ model_path=""
+ TSTDIR="${model_path}/outputs/"
+
+ model_size=70b # TODO change this
+ retriever="e5_mistral_retriever_chunkbysents1200"
+
+ adir=$TSTDIR
+ echo $adir
+
+ declare -A dataset2num_samples
+ dataset2num_samples["gov_report"]=200
+ dataset2num_samples["narrative_qa"]=2000
+ dataset2num_samples["qasper"]=2000
+ dataset2num_samples["qmsum"]=200
+ dataset2num_samples["quality"]=2000
+ dataset2num_samples["summ_screen_fd"]=200
+
+ dataset2num_samples["musique"]=200
+ dataset2num_samples["hotpotqa"]=200
+ dataset2num_samples["multifieldqa_en"]=200
+
+ dataset2num_samples["squality"]=200
+
+ dataset2num_samples["doc2dial_full_dialogue"]=1000
+
+ echo "ref path = $REFDIR"
+ echo "tst out path = $TSTDIR"
+
+ declare -A sys2name
+ sys2name["baseline"]=""
+ sys2name["ret"]="_ctx5"
+
+ for system in "baseline" "ret"
+ do
+     suffix=${sys2name[${system}]}
+
+     echo "--final display----$system----"
+     for adataset in "qmsum" "qasper" "quality"
+     do
+         echo $adataset
+
+         ref_fn="${REFDIR}/${adataset}.${retriever}/test.json"
+         tst_fn="${adir}/${adataset}.e5_mistral_retriever_chunkbysents1200_output_0to${dataset2num_samples[${adataset}]}${suffix}.txt"
+
+         echo "ref for ${adataset}", ${ref_fn}
+         echo "tstout for ${adataset}", ${tst_fn}
+         if [[ ! -e ${tst_fn} ]]; then
+             echo "Error: tst_fn=${tst_fn} does not exist!"
+         fi
+         if [[ ! -e ${ref_fn} ]]; then
+             echo "Error: ref_fn=${ref_fn} does not exist!"
+         fi
+         #continue
+
+         if [[ -e ${tst_fn} && -e ${ref_fn} ]]
+         then
+             python3 dataset_evaluator_retro.py \
+                 --datapath ${ref_fn} \
+                 --gen_test_file ${tst_fn} \
+                 --dataset $adataset
+         else
+             if [[ ! -e ${tst_fn} ]]; then
+                 echo "Error: tst_fn=${tst_fn} does not exist!"
+             fi
+             if [[ ! -e ${ref_fn} ]]; then
+                 echo "Error: ref_fn=${ref_fn} does not exist!"
+             fi
+         fi
+     done
+
+     for adataset in "musique" "hotpotqa" "multifieldqa_en"
+     do
+         echo $adataset
+
+         ref_fn="${REFDIR}/${adataset}.${retriever}/test.json"
+         # TODO change this if necessary, model's prediction output
+         tst_fn="${adir}/${adataset}.e5_mistral_retriever_chunkbysents1200_output_0to${dataset2num_samples[${adataset}]}${suffix}.txt"
+
+         echo "ref for ${adataset}", ${ref_fn}
+         echo "tstout for ${adataset}", ${tst_fn}
+         if [[ ! -e ${tst_fn} ]]; then
+             echo "Error: tst_fn=${tst_fn} does not exist!"
+         fi
+         if [[ ! -e ${ref_fn} ]]; then
+             echo "Error: ref_fn=${ref_fn} does not exist!"
+         fi
+         #continue
+
+         if [[ -e ${tst_fn} && -e ${ref_fn} ]]
+         then
+             python3 dataset_evaluator_retro_longbench.py \
+                 --datapath ${ref_fn} \
+                 --gen_test_file ${tst_fn} \
+                 --dataset $adataset
+         else
+             if [[ ! -e ${tst_fn} ]]; then
+                 echo "Error: tst_fn=${tst_fn} does not exist!"
+             fi
+             if [[ ! -e ${ref_fn} ]]; then
+                 echo "Error: ref_fn=${ref_fn} does not exist!"
+             fi
+         fi
+     done
+ done
+
evaluation/long_32k_eval/extract_log.py ADDED
@@ -0,0 +1,88 @@
+ import sys
+ import numpy as np
+
+ DATASETS1 = [
+     "qmsum",
+     "qasper",
+     "quality",
+     'musique',
+     'hotpotqa',
+     'multifieldqa_en'
+ ]
+
+ DATASETS = [
+     "qmsum",
+     "qasper",
+     "quality",
+     'musique',
+     'hotpotqa',
+     'multifieldqa_en',
+ ]
+
+ outrow = ''
+ data2res = dict()
+
+ def average(data2res):
+     sumvalue = 0.0
+     sumnum = 0.0
+     for adata in data2res:
+         avalue = data2res[adata]
+         sumvalue += avalue
+         sumnum += 1
+
+     assert sumnum > 0.0
+     return sumvalue / sumnum
+
+ def collect(value_list, outrow, data2res):
+     #print(value_list)
+     # first add the single avg score:
+     avg = round(np.mean(value_list), 4)
+     outrow += str(avg) + ' '
+
+     avg2 = average(data2res)
+     avg2 = round(avg2, 4)
+     outrow += str(avg2) + ' '
+
+     for adata in DATASETS:
+         ares = data2res[adata] if adata in data2res else "NA"
+         outrow += str(ares) + " "
+     print(outrow.strip())
+
+ print('system avg6 avg6 ' + ' '.join(DATASETS))
+
+ #infn = "eval_retro_2.sh.log.2"
+ #with open(infn) as br:
+ #for aline in br.readlines():
+ value_list = list()
+ for aline in sys.stdin:
+     #import ipdb; ipdb.set_trace()
+     aline = aline.strip()
+     if 'final display' in aline:
+         if '-baseline-' in aline or '-ret-' in aline:
+             if len(outrow) > 0 and len(data2res) > 0:
+                 collect(value_list, outrow, data2res)
+
+             outrow = ""  # reset
+             data2res = dict()
+             value_list = list()
+
+             aline2 = aline.replace('-', '')
+             aline2 = aline2.replace('final display', '')
+             outrow += aline2 + ' '
+             continue
+
+         cols = aline.split(' ')
+         adata = cols[2]
+         ares = '/'.join(cols[3:])  # NOTE use one geometric_mean instead
+         scores = cols[3:]
+         # for R1/R2/RL geometric_mean:
+         if len(scores) == 3:
+             scores = [float(item) for item in scores]
+             geo_mean = (scores[0] * scores[1] * scores[2]) ** (1.0 / 3.0)
+             ares = str(round(geo_mean, 4))
+
+         data2res[adata] = float(ares)
+         value_list.append(float(ares))
+
+ collect(value_list, outrow, data2res)
+
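A short illustrative sketch (not part of the commit) of the two kinds of "final display" lines extract_log.py reads from stdin, and of how a three-value ROUGE-1/2/L result is collapsed into a single number via the geometric mean. The scores are made up.

# toy_collapse.py -- illustrative only, not part of the commit
header_line = "--final display----baseline----"        # row marker emitted by eval_retro_vllm.sh
score_line = "final display: qmsum 34.01 10.23 22.78"  # per-dataset line emitted by dataset_evaluator_retro.py

scores = [float(x) for x in score_line.split(' ')[3:]]  # [34.01, 10.23, 22.78], as in cols[3:]
geo_mean = (scores[0] * scores[1] * scores[2]) ** (1.0 / 3.0)
print(round(geo_mean, 4))  # one summary value for R1/R2/RL (about 19.94 here)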
evaluation/long_32k_eval/longbench/__pycache__/eval.cpython-310.pyc ADDED
Binary file (3.23 kB).
 
evaluation/long_32k_eval/longbench/__pycache__/metrics.cpython-310.pyc ADDED
Binary file (5.93 kB).
 
evaluation/long_32k_eval/longbench/eval.py ADDED
@@ -0,0 +1,127 @@
+ import os
+ import json
+ import argparse
+ import numpy as np
+
+ from .metrics import (
+     qa_f1_score,
+     rouge_zh_score,
+     qa_f1_zh_score,
+     rouge_score,
+     classification_score,
+     retrieval_score,
+     retrieval_zh_score,
+     count_score,
+     code_sim_score,
+ )
+
+ dataset2metric = {
+     "narrativeqa": qa_f1_score,
+     "qasper": qa_f1_score,
+     "multifieldqa_en": qa_f1_score,  # NOTE
+     "multifieldqa_zh": qa_f1_zh_score,
+     "hotpotqa": qa_f1_score,  # NOTE
+     "2wikimqa": qa_f1_score,
+     "musique": qa_f1_score,  # NOTE
+     "dureader": rouge_zh_score,
+     "gov_report": rouge_score,
+     "qmsum": rouge_score,
+     "multi_news": rouge_score,
+     "vcsum": rouge_zh_score,
+     "trec": classification_score,
+     "triviaqa": qa_f1_score,
+     "samsum": rouge_score,
+     "lsht": classification_score,
+     "passage_retrieval_en": retrieval_score,
+     "passage_count": count_score,
+     "passage_retrieval_zh": retrieval_zh_score,
+     "lcc": code_sim_score,
+     "repobench-p": code_sim_score,
+ }
+
+ def parse_args(args=None):
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--model', type=str, default=None)
+     parser.add_argument('--e', action='store_true', help="Evaluate on LongBench-E")
+     return parser.parse_args(args)
+
+ def scorer_e(dataset, predictions, answers, lengths, all_classes):
+     scores = {"0-4k": [], "4-8k": [], "8k+": []}
+     for (prediction, ground_truths, length) in zip(predictions, answers, lengths):
+         score = 0.
+         if dataset in ["trec", "triviaqa", "samsum", "lsht"]:
+             prediction = prediction.lstrip('\n').split('\n')[0]
+         for ground_truth in ground_truths:
+             score = max(score, dataset2metric[dataset](prediction, ground_truth, all_classes=all_classes))
+         if length < 4000:
+             scores["0-4k"].append(score)
+         elif length < 8000:
+             scores["4-8k"].append(score)
+         else:
+             scores["8k+"].append(score)
+     for key in scores.keys():
+         scores[key] = round(100 * np.mean(scores[key]), 2)
+     return scores
+
+ def scorer(dataset, predictions, answers, all_classes):
+     # dataset = 'hotpotqa', 'musique', 'multifieldqa_en'
+     # predictions = [pred_str, ...]
+     # answers = [[answer_str, ...], ...]
+     # all_classes = None
+
+     #import ipdb; ipdb.set_trace()  # NOTE: all_classes=None for the 'hotpotqa' dataset
+     total_score = 0.
+     for (prediction, ground_truths) in zip(predictions, answers):
+         score = 0.
+         if dataset in ["trec", "triviaqa", "samsum", "lsht"]:
+             prediction = prediction.lstrip('\n').split('\n')[0]
+         for ground_truth in ground_truths:
+             score = max(score, dataset2metric[dataset](prediction, ground_truth, all_classes=all_classes))
+         total_score += score
+
+     outscore = round(100 * total_score / len(predictions), 2)
+     print(dataset, outscore)
+     return outscore
+
+ if __name__ == '__main__':
+     #import ipdb; ipdb.set_trace()
+     args = parse_args()
+     scores = dict()
+     if args.e:
+         path = f"pred_e/{args.model}/"
+     else:
+         path = f"pred/{args.model}/"  # e.g. 'pred/chatglm2-6b-32k/' NOTE
+     all_files = os.listdir(path)  # 21 files
+     print("Evaluating on:", all_files)
+
+     for filename in all_files:
+         #import ipdb; ipdb.set_trace()
+         if not filename.endswith("jsonl"):
+             continue
+         predictions, answers, lengths = [], [], []
+         dataset = filename.split('.')[0]  # the dataset name comes from the file name
+         if not dataset in ['musique', 'hotpotqa', 'multifieldqa_en']:
+             continue  # TODO debug only
+         with open(f"{path}{filename}", "r", encoding="utf-8") as f:
+             for line in f:  # parse one JSON object per line
+                 data = json.loads(line)
+                 predictions.append(data["pred"])
+                 answers.append(data["answers"])
+                 all_classes = data["all_classes"]  # re-assigned on every line; the value is the same for the whole file
+                 if "length" in data:
+                     lengths.append(data["length"])
+         if args.e:
+             score = scorer_e(dataset, predictions, answers, lengths, all_classes)
+         else:
+             score = scorer(dataset, predictions, answers, all_classes)  # NOTE: main scoring entry point. dataset = dataset name; predictions = list of str (model outputs); answers = list of lists of str (reference answers); all_classes comes with the original test data
+         scores[dataset] = score
+     if args.e:
+         out_path = f"pred_e/{args.model}/result.json"
+     else:
+         out_path = f"pred/{args.model}/result.json"
+
+     print(scores)
+
+     with open(out_path, "w") as f:
+         json.dump(scores, f, ensure_ascii=False, indent=4)
+
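A short illustrative sketch (not part of the commit) of calling the LongBench scorer directly, the same way dataset_evaluator_retro_longbench.py does. The predictions and answers are made up, and all_classes is only needed for classification-style datasets such as "trec" or "lsht".

# toy_scorer.py -- illustrative only, not part of the commit
from longbench.eval import scorer

predictions = ["Paris", "forty-two"]        # one raw prediction string per example
answers = [["Paris"], ["42", "forty two"]]  # one list of reference answers per example

# For 'hotpotqa', dataset2metric maps to qa_f1_score; the best F1 over the
# reference answers is kept per example, then averaged and scaled to 0-100.
score = scorer("hotpotqa", predictions, answers, all_classes=None)
print(score)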
evaluation/long_32k_eval/longbench/metrics.py ADDED
@@ -0,0 +1,154 @@
+ import re
+ import string
+
+ import jieba
+ from fuzzywuzzy import fuzz
+ import difflib
+
+ from typing import List
+ from collections import Counter
+ from rouge import Rouge
+
+ def normalize_answer(s):
+     """Lower text and remove punctuation, articles and extra whitespace."""
+
+     def remove_articles(text):
+         return re.sub(r"\b(a|an|the)\b", " ", text)
+
+     def white_space_fix(text):
+         return " ".join(text.split())
+
+     def remove_punc(text):
+         exclude = set(string.punctuation)
+         return "".join(ch for ch in text if ch not in exclude)
+
+     def lower(text):
+         return text.lower()
+
+     return white_space_fix(remove_articles(remove_punc(lower(s))))
+
+
+ def normalize_zh_answer(s):
+     """Lower text and remove punctuation, extra whitespace."""
+
+     def white_space_fix(text):
+         return "".join(text.split())
+
+     def remove_punc(text):
+         cn_punctuation = "!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
+         all_punctuation = set(string.punctuation + cn_punctuation)
+         return "".join(ch for ch in text if ch not in all_punctuation)
+
+     def lower(text):
+         return text.lower()
+
+     return white_space_fix(remove_punc(lower(s)))
+
+ def count_score(prediction, ground_truth, **kwargs):
+     numbers = re.findall(r"\d+", prediction)
+     right_num = 0
+     for number in numbers:
+         if str(number) == str(ground_truth):
+             right_num += 1
+     final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
+     return float(final_score)
+
+ def retrieval_score(prediction, ground_truth, **kwargs):
+     pattern = r'Paragraph (\d+)'
+     matches = re.findall(pattern, ground_truth)
+     ground_truth_id = matches[0]
+     numbers = re.findall(r"\d+", prediction)
+     right_num = 0
+     for number in numbers:
+         if str(number) == str(ground_truth_id):
+             right_num += 1
+     final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
+     return float(final_score)
+
+ def retrieval_zh_score(prediction, ground_truth, **kwargs):
+     pattern = r'段落(\d+)'
+     matches = re.findall(pattern, ground_truth)
+     ground_truth_id = matches[0]
+     numbers = re.findall(r"\d+", prediction)
+     right_num = 0
+     for number in numbers:
+         if str(number) == str(ground_truth_id):
+             right_num += 1
+     final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
+     return float(final_score)
+
+ def code_sim_score(prediction, ground_truth, **kwargs):
+     all_lines = prediction.lstrip('\n').split('\n')
+     prediction = ""
+     for line in all_lines:
+         if ('`' not in line) and ('#' not in line) and ('//' not in line):
+             prediction = line
+             break
+     return (fuzz.ratio(prediction, ground_truth) / 100)
+
+ def classification_score(prediction, ground_truth, **kwargs):
+     em_match_list = []
+     all_classes = kwargs["all_classes"]
+     for class_name in all_classes:
+         if class_name in prediction:
+             em_match_list.append(class_name)
+     for match_term in list(em_match_list):  # iterate over a copy so removal is safe
+         if match_term in ground_truth and match_term != ground_truth:
+             em_match_list.remove(match_term)
+     if len(em_match_list) != 0:  # at least one class name appears verbatim in the prediction
+         if ground_truth in em_match_list:
+             score = (1.0 / len(em_match_list))
+         else:
+             score = 0.0
+     else:
+         # otherwise fall back to the fuzzily closest class name
+         best_match = None
+         highest_similarity = 0
+         for candidate in all_classes:
+             similarity = difflib.SequenceMatcher(None, candidate, prediction).ratio()
+             if similarity > highest_similarity:
+                 highest_similarity = similarity
+                 best_match = candidate
+         score = float(best_match == ground_truth)
+     return score
+
+ def rouge_score(prediction, ground_truth, **kwargs):
+     rouge = Rouge()
+     try:
+         scores = rouge.get_scores([prediction], [ground_truth], avg=True)
+     except Exception:
+         return 0.0
+     return scores["rouge-l"]["f"]
+
+ def rouge_zh_score(prediction, ground_truth, **kwargs):
+     prediction = " ".join(list(jieba.cut(prediction, cut_all=False)))
+     ground_truth = " ".join(list(jieba.cut(ground_truth, cut_all=False)))
+     score = rouge_score(prediction, ground_truth)
+     return score
+
+ def f1_score(prediction, ground_truth, **kwargs):
+     common = Counter(prediction) & Counter(ground_truth)
+     num_same = sum(common.values())
+     if num_same == 0:
+         return 0
+     precision = 1.0 * num_same / len(prediction)
+     recall = 1.0 * num_same / len(ground_truth)
+     f1 = (2 * precision * recall) / (precision + recall)
+     return f1
+
+ def qa_f1_score(prediction, ground_truth, **kwargs):
+     normalized_prediction = normalize_answer(prediction)
+     normalized_ground_truth = normalize_answer(ground_truth)
+
+     prediction_tokens = normalized_prediction.split()
+     ground_truth_tokens = normalized_ground_truth.split()
+     return f1_score(prediction_tokens, ground_truth_tokens)
+
+
+ def qa_f1_zh_score(prediction, ground_truth, **kwargs):
+     prediction_tokens = list(jieba.cut(prediction, cut_all=False))
+     ground_truth_tokens = list(jieba.cut(ground_truth, cut_all=False))
+     prediction_tokens = [normalize_zh_answer(token) for token in prediction_tokens]
+     ground_truth_tokens = [normalize_zh_answer(token) for token in ground_truth_tokens]
+     prediction_tokens = [token for token in prediction_tokens if len(token) > 0]
+     ground_truth_tokens = [token for token in ground_truth_tokens if len(token) > 0]
+     return f1_score(prediction_tokens, ground_truth_tokens)
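A short illustrative sketch (not part of the commit) of what qa_f1_score computes: both strings are lowercased, punctuation and articles are stripped, and a token-level F1 is taken between the two bags of words. The strings are made up.

# toy_qa_f1.py -- illustrative only, not part of the commit
from longbench.metrics import normalize_answer, qa_f1_score

print(normalize_answer("The Harbour Master!"))  # -> 'harbour master'
print(round(qa_f1_score("It was the harbour master", "The Harbour Master"), 3))
# 4 prediction tokens vs 2 reference tokens, 2 in common:
# precision 2/4, recall 2/2 -> F1 = 0.667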
evaluation/long_32k_eval/run_eval_vllm.sh ADDED
@@ -0,0 +1,10 @@
+ #!/bin/bash
+
+ adate=`date +%Y_%m_%d_%H_%M_%S`
+ outlog="eval_retro_vllm.sh.log.6_4setsplus.$adate"
+
+ echo "out log = $outlog"
+
+ bash eval_retro_vllm.sh > $outlog 2>&1
+
+ grep "final display" $outlog | python3 extract_log.py