rookiemango
commited on
Upload folder using huggingface_hub
Browse files- __pycache__/process_supervision.cpython-39.pyc +0 -0
- __pycache__/test.cpython-39.pyc +0 -0
- generation_method.py +7 -5
- multirun/allresults.json_temp_0.json +0 -0
- process_supervision_training_data.py +403 -0
- requirements.txt +4 -0
- test.py +82 -0
__pycache__/process_supervision.cpython-39.pyc
ADDED
Binary file (3.82 kB). View file
|
|
__pycache__/test.cpython-39.pyc
ADDED
Binary file (1.91 kB). View file
|
|
generation_method.py
CHANGED
@@ -229,6 +229,8 @@ def parse_arguments():
|
|
229 |
|
230 |
if args.dataset == "lean4_5k_test":
|
231 |
args.data_path = "data/lean4_gpt_5k/test/data.jsonl"
|
|
|
|
|
232 |
|
233 |
elif args.dataset == "math_train":
|
234 |
args.data_path = "data/test/math/train.jsonl"
|
@@ -237,12 +239,12 @@ def parse_arguments():
|
|
237 |
args.data_path = "data/test/gsm8k/train.jsonl"
|
238 |
|
239 |
elif args.dataset == "wild_test":
|
240 |
-
args.data_path = "/
|
241 |
|
242 |
elif args.dataset == "lean4_basic_test":
|
243 |
-
args.data_path = "data/lean4_basic/
|
244 |
elif args.dataset == "lean4_random_test":
|
245 |
-
args.data_path = "data/lean4_random/
|
246 |
elif args.dataset == "lean4_random_first_train":
|
247 |
args.data_path = "data/lean4_random/5k_first.json"
|
248 |
elif args.dataset == "lean4_random_second_train":
|
@@ -324,7 +326,7 @@ PROMPT_DICT = {
|
|
324 |
),
|
325 |
"lean4": (
|
326 |
"Statement and proof in natural language:\n\n"
|
327 |
-
"{
|
328 |
"Translate the statement and proof in natural language to lean4:"
|
329 |
),
|
330 |
"prompt_no_input": (
|
@@ -366,7 +368,7 @@ def get_question_answer(args):
|
|
366 |
questions = [ PROMPT_DICT['wild'].format(question= questions[id], answer =answers[id][args.data_answer_key] ) for id in range(len(questions))]
|
367 |
|
368 |
else:
|
369 |
-
questions = [ PROMPT_DICT['lean4'].format(
|
370 |
|
371 |
|
372 |
return questions, answers
|
|
|
229 |
|
230 |
if args.dataset == "lean4_5k_test":
|
231 |
args.data_path = "data/lean4_gpt_5k/test/data.jsonl"
|
232 |
+
elif args.dataset == "lean4_15k_train":
|
233 |
+
args.data_path = "data/lean4_random/15k_filtered.json"
|
234 |
|
235 |
elif args.dataset == "math_train":
|
236 |
args.data_path = "data/test/math/train.jsonl"
|
|
|
239 |
args.data_path = "data/test/gsm8k/train.jsonl"
|
240 |
|
241 |
elif args.dataset == "wild_test":
|
242 |
+
args.data_path = "data/wild/wild_sample1k.jsonl"
|
243 |
|
244 |
elif args.dataset == "lean4_basic_test":
|
245 |
+
args.data_path = "data/lean4_basic/1k_test_filtered.jsonl"
|
246 |
elif args.dataset == "lean4_random_test":
|
247 |
+
args.data_path = "data/lean4_random/1k_test_filtered.json"
|
248 |
elif args.dataset == "lean4_random_first_train":
|
249 |
args.data_path = "data/lean4_random/5k_first.json"
|
250 |
elif args.dataset == "lean4_random_second_train":
|
|
|
326 |
),
|
327 |
"lean4": (
|
328 |
"Statement and proof in natural language:\n\n"
|
329 |
+
"{model_response}\n\n"
|
330 |
"Translate the statement and proof in natural language to lean4:"
|
331 |
),
|
332 |
"prompt_no_input": (
|
|
|
368 |
questions = [ PROMPT_DICT['wild'].format(question= questions[id], answer =answers[id][args.data_answer_key] ) for id in range(len(questions))]
|
369 |
|
370 |
else:
|
371 |
+
questions = [ PROMPT_DICT['lean4'].format(model_response = item) for item in questions]
|
372 |
|
373 |
|
374 |
return questions, answers
|
multirun/allresults.json_temp_0.json
ADDED
File without changes
|
process_supervision_training_data.py
ADDED
@@ -0,0 +1,403 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pdb
|
2 |
+
import re
|
3 |
+
import json
|
4 |
+
import tqdm
|
5 |
+
|
6 |
+
from process_supervision import load_tokenizer
|
7 |
+
from test import delete_extra_zero
|
8 |
+
|
9 |
+
def process_line(tokenizer, lines, wf_name):
|
10 |
+
acc = [] # Will store tuples of (label, prediction) for each expression
|
11 |
+
recall_count = [0, 0] # [number of correct positives, number of actual positives]
|
12 |
+
hullucination = []
|
13 |
+
import json
|
14 |
+
|
15 |
+
rft_file_line = 0
|
16 |
+
rft_list = []
|
17 |
+
verifier_file_line = 0
|
18 |
+
verifier_list = []
|
19 |
+
rft_verifier_file_line = 0
|
20 |
+
rft_verifier_list = []
|
21 |
+
acc = []
|
22 |
+
|
23 |
+
|
24 |
+
with open(wf_name, 'w', encoding='utf-8') as wf:
|
25 |
+
for line in tqdm.tqdm(lines):
|
26 |
+
for output in line['outputs']:
|
27 |
+
v_scores = output.get('vscores', [])
|
28 |
+
response = output.get('response', "")
|
29 |
+
is_true = output.get('label', "")
|
30 |
+
|
31 |
+
if is_true:
|
32 |
+
rft_list.append({"question": line['question'], "answer": output['response']})
|
33 |
+
rft_file_line += 1
|
34 |
+
if v_scores and v_scores[-1] > 0.5:
|
35 |
+
# Save to rft_verifier_enhanced.json
|
36 |
+
rft_verifier_list.append({"question": line['question'], "answer": output['response']})
|
37 |
+
rft_verifier_file_line += 1
|
38 |
+
if v_scores and v_scores[-1] > 0.5:
|
39 |
+
verifier_list.append({"question": line['question'], "answer": output['response']})
|
40 |
+
verifier_file_line += 1
|
41 |
+
if is_true:
|
42 |
+
acc.append(1)
|
43 |
+
else:
|
44 |
+
acc.append(0)
|
45 |
+
print(rft_file_line)
|
46 |
+
print(verifier_file_line)
|
47 |
+
print(rft_verifier_file_line)
|
48 |
+
print("acc" , sum(acc)/len(acc))
|
49 |
+
|
50 |
+
with open("data/continual_training/rft.json", 'w', encoding='utf-8') as rft_file:
|
51 |
+
json.dump(rft_list, rft_file , ensure_ascii=False, indent=2)
|
52 |
+
|
53 |
+
with open("data/continual_training/verifier_enhanced.json", 'w', encoding='utf-8') as verifier_file:
|
54 |
+
json.dump(verifier_list, verifier_file, ensure_ascii=False, indent=2)
|
55 |
+
|
56 |
+
with open("data/continual_training/rft_verifier_enhanced.json", 'w', encoding='utf-8') as rft_verifier_file:
|
57 |
+
json.dump(rft_verifier_list, rft_verifier_file , ensure_ascii=False, indent=2)
|
58 |
+
|
59 |
+
|
60 |
+
|
61 |
+
|
62 |
+
|
63 |
+
def locate_sublist(lst, sublst):
|
64 |
+
for i in range(len(lst) - len(sublst) + 1):
|
65 |
+
if lst[i:i+len(sublst)] == sublst:
|
66 |
+
return i # Return the starting index of the sublist in the list
|
67 |
+
assert ('not right')
|
68 |
+
|
69 |
+
|
70 |
+
def split_string_list(a_list, number ='\n'):
|
71 |
+
sublists = []
|
72 |
+
current_sublist = []
|
73 |
+
for item in a_list:
|
74 |
+
current_sublist.append(item)
|
75 |
+
if item == number:
|
76 |
+
if current_sublist: # if the current sublist is not empty
|
77 |
+
sublists.append(''.join(current_sublist))
|
78 |
+
current_sublist = [] # start a new sublist
|
79 |
+
|
80 |
+
# Don't forget to add the last sublist if it's not empty
|
81 |
+
if current_sublist:
|
82 |
+
sublists.append(''.join(current_sublist))
|
83 |
+
|
84 |
+
return sublists
|
85 |
+
def split_token_list(a_list, number =13):
|
86 |
+
sublists = []
|
87 |
+
current_sublist = []
|
88 |
+
for item in a_list:
|
89 |
+
current_sublist.append(item)
|
90 |
+
if item == number:
|
91 |
+
if current_sublist: # if the current sublist is not empty
|
92 |
+
sublists.append(current_sublist)
|
93 |
+
current_sublist = [] # start a new sublist
|
94 |
+
|
95 |
+
# Don't forget to add the last sublist if it's not empty
|
96 |
+
if current_sublist:
|
97 |
+
sublists.append(current_sublist)
|
98 |
+
|
99 |
+
return sublists
|
100 |
+
# Modify evaluate_expression function to return a list of results
|
101 |
+
|
102 |
+
|
103 |
+
def evaluate_expression_para(response_all, v_score, tokenizer, is_true):
|
104 |
+
# Initialize lists to hold multiple evaluation results for each expression
|
105 |
+
# here we make the v_score label in a "first error detection"
|
106 |
+
labels = []
|
107 |
+
predictions = []
|
108 |
+
sol_tokens = tokenizer(response_all).input_ids
|
109 |
+
process_v_score = [0] * len(sol_tokens)
|
110 |
+
hullucination = False
|
111 |
+
gt_help = False
|
112 |
+
error_detection = False
|
113 |
+
response_list = split_string_list(response_all)
|
114 |
+
token_list = split_token_list(sol_tokens)
|
115 |
+
previous_len = 0
|
116 |
+
for idx, string in enumerate(response_list):
|
117 |
+
# match = re.search(r'<<(.+?)>>', string)
|
118 |
+
para_token = token_list[idx]
|
119 |
+
para_token_location = sum([len(item) for item in token_list[:idx]])
|
120 |
+
|
121 |
+
if error_detection:
|
122 |
+
break
|
123 |
+
|
124 |
+
|
125 |
+
if abs(v_score[para_token_location]) < 1e-5:
|
126 |
+
error_detection = True
|
127 |
+
|
128 |
+
elif (v_score[para_token_location + len(para_token) - 1] - v_score[para_token_location])/v_score[para_token_location] < -0.5:
|
129 |
+
error_detection = True
|
130 |
+
|
131 |
+
else:
|
132 |
+
if not error_detection:
|
133 |
+
process_v_score[para_token_location : para_token_location + len(para_token) ] = [1] * len(para_token)
|
134 |
+
|
135 |
+
# if match:
|
136 |
+
# expression = match.group(1)
|
137 |
+
# start_token = tokenizer(response_all[ : previous_len + match.span()[0]]).input_ids
|
138 |
+
# if sol_tokens[:len(start_token)] != start_token:
|
139 |
+
# start_token = start_token[:-1]
|
140 |
+
# # print(tokenizer.decode(start_token))
|
141 |
+
# seg_token_location = len(start_token)
|
142 |
+
# seq_token = tokenizer(response_all[: previous_len + match.span()[1]]).input_ids[len(start_token):]
|
143 |
+
# # print(tokenizer.decode(seq_token))
|
144 |
+
# # Check if v_score change is positive
|
145 |
+
# try:
|
146 |
+
# if abs(v_score[seg_token_location]) < 1e-5:
|
147 |
+
# prediction = 'negative' # there is a extra example in v_score
|
148 |
+
# error_detection = True
|
149 |
+
#
|
150 |
+
# elif (v_score[min(seg_token_location + len(seq_token), len(v_score) - 1)] - v_score[seg_token_location]) / (v_score[seg_token_location]) < -0.9:
|
151 |
+
# prediction = 'negative' # there is a negative change in v_score
|
152 |
+
# error_detection = True
|
153 |
+
# else:
|
154 |
+
# prediction = 'positive' # no negative change in v_score
|
155 |
+
# if not error_detection:
|
156 |
+
# process_v_score[para_token_location : para_token_location + len(para_token)] = [1] * len(para_token)
|
157 |
+
# except:
|
158 |
+
# import pdb
|
159 |
+
# pdb.set_trace()
|
160 |
+
# try:
|
161 |
+
# before_equal, after_equal = expression.split('=')
|
162 |
+
# computed_value = float(eval(before_equal.strip()))
|
163 |
+
# actual_value = float(delete_extra_zero(after_equal.strip().replace(",", "")))
|
164 |
+
# # Use the positive v_score change as a proxy for a correct evaluation
|
165 |
+
# if abs(computed_value - actual_value) <= 1e-3:
|
166 |
+
# label = 'positive'
|
167 |
+
# else:
|
168 |
+
# label = 'negative'
|
169 |
+
# hullucination = True
|
170 |
+
#
|
171 |
+
# # Record the label and prediction for this expression
|
172 |
+
# labels.append(label)
|
173 |
+
# predictions.append(prediction)
|
174 |
+
# except Exception as e:
|
175 |
+
# pass
|
176 |
+
# else:
|
177 |
+
# if not error_detection:
|
178 |
+
# process_v_score[para_token_location: para_token_location + len(para_token)] = [1] * len(para_token)
|
179 |
+
#
|
180 |
+
|
181 |
+
if idx == len(response_list) - 1 and not error_detection and not is_true:
|
182 |
+
process_v_score[para_token_location: para_token_location + len(para_token)] = [0] * len(para_token)
|
183 |
+
gt_help = True
|
184 |
+
|
185 |
+
previous_len += len(string)
|
186 |
+
# if sum(process_v_score) != len(process_v_score) and sum(process_v_score) != 0:
|
187 |
+
# print(process_v_score)
|
188 |
+
|
189 |
+
return {'label': labels, 'prediction': predictions, 'hullucination': hullucination, 'gt_help': gt_help}, process_v_score
|
190 |
+
|
191 |
+
|
192 |
+
|
193 |
+
|
194 |
+
|
195 |
+
def evaluate_expression(response_all, v_score, tokenizer, is_true):
|
196 |
+
# Initialize lists to hold multiple evaluation results for each expression
|
197 |
+
# here we make the v_score label in a "first error detection"
|
198 |
+
labels = []
|
199 |
+
predictions = []
|
200 |
+
sol_tokens = tokenizer(response_all).input_ids
|
201 |
+
process_v_score = [0] * len(sol_tokens)
|
202 |
+
hullucination = False
|
203 |
+
gt_help = False
|
204 |
+
error_detection = False
|
205 |
+
response_list = split_string_list(response_all)
|
206 |
+
token_list = split_token_list(sol_tokens)
|
207 |
+
previous_len = 0
|
208 |
+
for idx, string in enumerate(response_list):
|
209 |
+
match = re.search(r'<<(.+?)>>', string)
|
210 |
+
para_token = token_list[idx]
|
211 |
+
para_token_location = sum([len(item) for item in token_list[:idx]])
|
212 |
+
if match:
|
213 |
+
expression = match.group(1)
|
214 |
+
start_token = tokenizer(response_all[ : previous_len + match.span()[0]]).input_ids
|
215 |
+
if sol_tokens[:len(start_token)] != start_token:
|
216 |
+
start_token = start_token[:-1]
|
217 |
+
# print(tokenizer.decode(start_token))
|
218 |
+
seg_token_location = len(start_token)
|
219 |
+
seq_token = tokenizer(response_all[: previous_len + match.span()[1]]).input_ids[len(start_token):]
|
220 |
+
# print(tokenizer.decode(seq_token))
|
221 |
+
# Check if v_score change is positive
|
222 |
+
try:
|
223 |
+
if abs(v_score[seg_token_location]) < 1e-5:
|
224 |
+
prediction = 'negative' # there is a extra example in v_score
|
225 |
+
error_detection = True
|
226 |
+
|
227 |
+
elif (v_score[min(seg_token_location + len(seq_token), len(v_score) - 1)] - v_score[seg_token_location]) / (v_score[seg_token_location]) < -0.9:
|
228 |
+
prediction = 'negative' # there is a negative change in v_score
|
229 |
+
error_detection = True
|
230 |
+
else:
|
231 |
+
prediction = 'positive' # no negative change in v_score
|
232 |
+
if not error_detection:
|
233 |
+
process_v_score[para_token_location : para_token_location + len(para_token)] = [1] * len(para_token)
|
234 |
+
except:
|
235 |
+
import pdb
|
236 |
+
pdb.set_trace()
|
237 |
+
try:
|
238 |
+
before_equal, after_equal = expression.split('=')
|
239 |
+
computed_value = float(eval(before_equal.strip()))
|
240 |
+
actual_value = float(delete_extra_zero(after_equal.strip().replace(",", "")))
|
241 |
+
# Use the positive v_score change as a proxy for a correct evaluation
|
242 |
+
if abs(computed_value - actual_value) <= 1e-3:
|
243 |
+
label = 'positive'
|
244 |
+
else:
|
245 |
+
label = 'negative'
|
246 |
+
hullucination = True
|
247 |
+
|
248 |
+
# Record the label and prediction for this expression
|
249 |
+
labels.append(label)
|
250 |
+
predictions.append(prediction)
|
251 |
+
except Exception as e:
|
252 |
+
pass
|
253 |
+
else:
|
254 |
+
if not error_detection:
|
255 |
+
process_v_score[para_token_location: para_token_location + len(para_token)] = [1] * len(para_token)
|
256 |
+
|
257 |
+
|
258 |
+
# if idx == len(response_list) - 1 and not error_detection and not is_true:
|
259 |
+
# process_v_score[para_token_location: para_token_location + len(para_token)] = [0] * len(para_token)
|
260 |
+
# gt_help = True
|
261 |
+
|
262 |
+
previous_len += len(string)
|
263 |
+
# if sum(process_v_score) != len(process_v_score) and sum(process_v_score) != 0:
|
264 |
+
# print(process_v_score)
|
265 |
+
|
266 |
+
return {'label': labels, 'prediction': predictions, 'hullucination': hullucination, 'gt_help': gt_help}, process_v_score
|
267 |
+
|
268 |
+
|
269 |
+
|
270 |
+
|
271 |
+
|
272 |
+
|
273 |
+
import multiprocessing
|
274 |
+
from functools import partial
|
275 |
+
import os
|
276 |
+
def process_chunk(tokenizer, chunk, wf_path):
|
277 |
+
acc = [] # Will store tuples of (label, prediction) for each expression
|
278 |
+
recall_count = [0, 0] # [number of correct positives, number of actual positives]
|
279 |
+
hullucination = []
|
280 |
+
gt_help = []
|
281 |
+
|
282 |
+
with open(wf_path, 'w', encoding='utf-8') as wf:
|
283 |
+
for line in tqdm.tqdm(chunk):
|
284 |
+
for output in line['outputs']:
|
285 |
+
import pdb
|
286 |
+
pdb.set_trace()
|
287 |
+
v_scores = output.get('vscores', [])
|
288 |
+
response = output.get('response', "")
|
289 |
+
is_true = output.get('label', "")
|
290 |
+
evaluation_results, process_v_scores = evaluate_expression_para(response, v_scores, tokenizer, is_true)
|
291 |
+
# output['process_vscores'] = process_v_scores
|
292 |
+
|
293 |
+
if evaluation_results['hullucination']:
|
294 |
+
hullucination.append(1)
|
295 |
+
else:
|
296 |
+
hullucination.append(0)
|
297 |
+
|
298 |
+
if evaluation_results['gt_help']:
|
299 |
+
gt_help.append(1)
|
300 |
+
else:
|
301 |
+
gt_help.append(0)
|
302 |
+
|
303 |
+
|
304 |
+
# Add the results to the accuracy list for each expression
|
305 |
+
for label, prediction in zip(evaluation_results['label'], evaluation_results['prediction']):
|
306 |
+
acc.append((label, prediction))
|
307 |
+
|
308 |
+
# Update recall counts for each expression
|
309 |
+
for idx, prediction in enumerate(evaluation_results['prediction']):
|
310 |
+
label = evaluation_results['label'][idx]
|
311 |
+
if label == 'positive':
|
312 |
+
recall_count[1] += 1 # Increment the count of actual positives
|
313 |
+
if prediction == 'positive':
|
314 |
+
recall_count[0] += 1 # Increment the count of correct positives
|
315 |
+
wf.writelines(json.dumps(line, ensure_ascii=False) + '\n')
|
316 |
+
|
317 |
+
# Calculate metrics for the chunk
|
318 |
+
accuracy = sum(1 for label, prediction in acc if label == prediction) / len(acc) if acc else 0
|
319 |
+
hullucination_rate = sum(hullucination) / len(hullucination) if hullucination else 0
|
320 |
+
# Return the metrics and counts, not just the rates, to allow aggregation
|
321 |
+
return {
|
322 |
+
"accuracy_sum": sum(1 for label, prediction in acc if label == prediction),
|
323 |
+
"total": len(acc),
|
324 |
+
"recall_correct": recall_count[0],
|
325 |
+
"recall_total": recall_count[1],
|
326 |
+
"hullucination_sum": sum(hullucination),
|
327 |
+
"hullucination_total": len(hullucination),
|
328 |
+
"gt_help_sum": sum(gt_help),
|
329 |
+
"gt_help_total": len(gt_help),
|
330 |
+
}
|
331 |
+
# print(
|
332 |
+
# f"Chunk accuracy: {accuracy}, Chunk recall: {recall}, Chunk hullucination: {sum(hullucination) / len(hullucination) if hullucination else 0}")
|
333 |
+
|
334 |
+
|
335 |
+
|
336 |
+
def parallel_process_line(tokenizer, lines, wf_path, num_processes=1):
|
337 |
+
if num_processes is None:
|
338 |
+
num_processes = multiprocessing.cpu_count()
|
339 |
+
|
340 |
+
# Split lines into chunks
|
341 |
+
chunk_size = int(len(lines) / num_processes)
|
342 |
+
chunks = [lines[i:i + chunk_size] for i in range(0, len(lines), chunk_size)]
|
343 |
+
|
344 |
+
# Generate a unique temporary file path for each chunk
|
345 |
+
temp_files = [f"multirun/{wf_path}_temp_{i}.json" for i in range(len(chunks))]
|
346 |
+
|
347 |
+
# Create a pool of workers to process data in parallel
|
348 |
+
with multiprocessing.Pool(processes=num_processes) as pool:
|
349 |
+
# Map each chunk to process_chunk function along with a unique temporary file path
|
350 |
+
results = pool.starmap(process_chunk, [(tokenizer, chunk, temp_file) for chunk, temp_file in zip(chunks, temp_files)])
|
351 |
+
|
352 |
+
# Combine results from temporary files into the final output file
|
353 |
+
with open(f"multirun2/{wf_path}.json", 'w', encoding='utf-8') as wf:
|
354 |
+
for temp_file in temp_files:
|
355 |
+
with open(temp_file, 'r', encoding='utf-8') as tf:
|
356 |
+
wf.write(tf.read())
|
357 |
+
os.remove(temp_file) # Clean up temporary file
|
358 |
+
|
359 |
+
# Aggregate metrics from all chunks
|
360 |
+
total_acc = sum(result['accuracy_sum'] for result in results)
|
361 |
+
total = sum(result['total'] for result in results)
|
362 |
+
total_recall_correct = sum(result['recall_correct'] for result in results)
|
363 |
+
total_recall = sum(result['recall_total'] for result in results)
|
364 |
+
total_hullucination = sum(result['hullucination_sum'] for result in results)
|
365 |
+
total_hullucination_counts = sum(result['hullucination_total'] for result in results)
|
366 |
+
total_gt_help = sum(result['gt_help_sum'] for result in results)
|
367 |
+
total_gt_help_counts = sum(result['gt_help_total'] for result in results)
|
368 |
+
|
369 |
+
# Calculate overall metrics
|
370 |
+
overall_accuracy = total_acc / total if total else 0
|
371 |
+
overall_recall = total_recall_correct / total_recall if total_recall else 0
|
372 |
+
overall_hullucination = total_hullucination / total_hullucination_counts if total_hullucination_counts else 0
|
373 |
+
overall_gt_help = total_gt_help/ total_gt_help_counts if total_gt_help_counts else 0
|
374 |
+
|
375 |
+
print(f"Overall accuracy: {overall_accuracy}")
|
376 |
+
print(f"Overall recall: {overall_recall}")
|
377 |
+
print(f"Overall hullucination: {overall_hullucination}")
|
378 |
+
print(f"Overall gt_help: {overall_gt_help}")
|
379 |
+
|
380 |
+
|
381 |
+
|
382 |
+
# Example usage
|
383 |
+
# line = '{"outputs": [{"solution_str": "The result is <<5 * 3 = 15>>."}, {"solution_str": "The answer is <<2 + 2 = 5>>."}]}'
|
384 |
+
# file_path = "eval_results/gsm8k/verifier/train/responses_v(threemodel)_g(threemodel).jsonl"
|
385 |
+
file_path = "eval_results/math/verifier/test/responses_v(lean4_random_15k_all-sample10-osv-gt2)_g(lean4_rand).jsonl"
|
386 |
+
line = [json.loads(line) for line in open(file_path, 'r', encoding = 'utf-8').readlines()]
|
387 |
+
for ex in line:
|
388 |
+
dedup_outputs = []
|
389 |
+
for output in ex['outputs']:
|
390 |
+
if len(output['tokens']) > 2048:
|
391 |
+
continue
|
392 |
+
dedup_outputs.append(output)
|
393 |
+
ex['outputs'] = dedup_outputs
|
394 |
+
|
395 |
+
model_dir = "../models/lean4_random_15k_all-sample10-osv-gt2/"
|
396 |
+
tokenizer = load_tokenizer(model_dir)
|
397 |
+
process_line(tokenizer, line,'good.json' )
|
398 |
+
|
399 |
+
# Example usage
|
400 |
+
# tokenizer = load_tokenizer(model_dir)
|
401 |
+
# parallel_process_line(tokenizer, line, "allresults.json")
|
402 |
+
|
403 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
transformers==4.39.2
|
2 |
+
DeepSpeed==0.14.0
|
3 |
+
SentencePiece
|
4 |
+
accelerate>=0.21.0
|
test.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fraction import Fraction
|
2 |
+
import re
|
3 |
+
def is_number(s):
|
4 |
+
try:
|
5 |
+
float(s)
|
6 |
+
return True
|
7 |
+
except ValueError:
|
8 |
+
pass
|
9 |
+
try:
|
10 |
+
import unicodedata
|
11 |
+
unicodedata.numeric(s)
|
12 |
+
return True
|
13 |
+
except (TypeError, ValueError):
|
14 |
+
pass
|
15 |
+
return False
|
16 |
+
|
17 |
+
ANSWER_TRIGGER = 'The answer is'
|
18 |
+
def handle_frac(pred):
|
19 |
+
if '/' in pred:
|
20 |
+
denominator = pred.split('/')[1]
|
21 |
+
numerator = pred.split('/')[0]
|
22 |
+
if is_number(denominator) == True and is_number(numerator) == True:
|
23 |
+
if denominator == '0':
|
24 |
+
return round(float(numerator.replace(',', '')))
|
25 |
+
else:
|
26 |
+
frac = Fraction(pred.replace(',', ''))
|
27 |
+
num_numerator = frac.numerator
|
28 |
+
num_denominator = frac.denominator
|
29 |
+
return round(float(num_numerator / num_denominator))
|
30 |
+
def delete_extra_zero(n):
|
31 |
+
'''删除小数点后多余的0'''
|
32 |
+
try: n=float(n)
|
33 |
+
except:
|
34 |
+
# print("None {}".format(n))
|
35 |
+
try:
|
36 |
+
rr = str(handle_frac(n))
|
37 |
+
return rr
|
38 |
+
except:
|
39 |
+
return ''
|
40 |
+
if isinstance(n, int):
|
41 |
+
return str(n)
|
42 |
+
if isinstance(n, float):
|
43 |
+
n = str(n).rstrip('0') # 删除小数点后多余的0
|
44 |
+
n = int(n.rstrip('.')) if n.endswith('.') else float(n) # 只剩小数点直接转int,否则转回float
|
45 |
+
n=str(n)
|
46 |
+
return n
|
47 |
+
|
48 |
+
def output_answer_clean(model_pred):
|
49 |
+
model_pred = model_pred.lower()
|
50 |
+
preds = model_pred.split(ANSWER_TRIGGER.lower())
|
51 |
+
answer_flag = True if len(preds) > 1 else False
|
52 |
+
if answer_flag:
|
53 |
+
# Pick first answer with flag
|
54 |
+
pred = preds[1]
|
55 |
+
else:
|
56 |
+
# Pick last number without flag
|
57 |
+
pred = preds[-1]
|
58 |
+
|
59 |
+
pred = pred.replace(",", "")
|
60 |
+
# pred = [s for s in re.findall(r'-?\d+\.?\d*', pred)]
|
61 |
+
# pred = [s.replace(",", "") for s in re.findall(r'-?\d+/?\.?\d*', pred)]
|
62 |
+
pred = [delete_extra_zero(s.replace(",", "")) for s in re.findall(r'-?\d+/?\.?\d*', pred)]
|
63 |
+
|
64 |
+
if len(pred) == 0:
|
65 |
+
return None
|
66 |
+
|
67 |
+
if answer_flag:
|
68 |
+
# choose the first element in list
|
69 |
+
pred = pred[0]
|
70 |
+
else:
|
71 |
+
# choose the last element in list
|
72 |
+
pred = pred
|
73 |
+
|
74 |
+
try:
|
75 |
+
if pred[-1] == ".":
|
76 |
+
pred = pred[:-1]
|
77 |
+
except:
|
78 |
+
pass
|
79 |
+
if isinstance(pred, list):
|
80 |
+
return pred[-1]
|
81 |
+
else:
|
82 |
+
return pred
|