rookiemango committed on
Commit 167df7c · verified · 1 Parent(s): d559180

Upload folder using huggingface_hub

__pycache__/process_supervision.cpython-39.pyc ADDED
Binary file (3.82 kB).

__pycache__/test.cpython-39.pyc ADDED
Binary file (1.91 kB).
 
generation_method.py CHANGED
@@ -229,6 +229,8 @@ def parse_arguments():
     if args.dataset == "lean4_5k_test":
         args.data_path = "data/lean4_gpt_5k/test/data.jsonl"
 
+    elif args.dataset == "lean4_15k_train":
+        args.data_path = "data/lean4_random/15k_filtered.json"
 
     elif args.dataset == "math_train":
         args.data_path = "data/test/math/train.jsonl"
@@ -237,12 +239,12 @@ def parse_arguments():
         args.data_path = "data/test/gsm8k/train.jsonl"
 
     elif args.dataset == "wild_test":
-        args.data_path = "/hpc2hdd/home/zyang398/data_2/wild_sample1k.jsonl"
+        args.data_path = "data/wild/wild_sample1k.jsonl"
 
     elif args.dataset == "lean4_basic_test":
-        args.data_path = "data/lean4_basic/1k_test.jsonl"
+        args.data_path = "data/lean4_basic/1k_test_filtered.jsonl"
     elif args.dataset == "lean4_random_test":
-        args.data_path = "data/lean4_random/1k_test.json"
+        args.data_path = "data/lean4_random/1k_test_filtered.json"
     elif args.dataset == "lean4_random_first_train":
        args.data_path = "data/lean4_random/5k_first.json"
     elif args.dataset == "lean4_random_second_train":
@@ -324,7 +326,7 @@ PROMPT_DICT = {
     ),
     "lean4": (
         "Statement and proof in natural language:\n\n"
-        "{statement_text}\n\n"
+        "{model_response}\n\n"
         "Translate the statement and proof in natural language to lean4:"
     ),
     "prompt_no_input": (
@@ -366,7 +368,7 @@ def get_question_answer(args):
         questions = [PROMPT_DICT['wild'].format(question=questions[id], answer=answers[id][args.data_answer_key]) for id in range(len(questions))]
 
     else:
-        questions = [PROMPT_DICT['lean4'].format(statement_text=item) for item in questions]
+        questions = [PROMPT_DICT['lean4'].format(model_response=item) for item in questions]
 
 
     return questions, answers
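
Editor's note: a minimal sketch (not from the repository) of how the renamed placeholder in the "lean4" template gets filled after this change; the template text and the field name model_response are taken from the diff above, the sample statement is invented.

# Hypothetical illustration of the updated 'lean4' prompt; the placeholder was
# renamed from {statement_text} to {model_response} in this commit.
lean4_template = (
    "Statement and proof in natural language:\n\n"
    "{model_response}\n\n"
    "Translate the statement and proof in natural language to lean4:"
)
sample_item = "The sum of two even numbers is even."  # invented example input
prompt = lean4_template.format(model_response=sample_item)
print(prompt)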
multirun/allresults.json_temp_0.json ADDED
File without changes
process_supervision_training_data.py ADDED
@@ -0,0 +1,403 @@
+ import pdb
+ import re
+ import json
+ import tqdm
+
+ from process_supervision import load_tokenizer
+ from test import delete_extra_zero
+
+ def process_line(tokenizer, lines, wf_name):
+     acc = []  # will store a 1/0 correctness flag for each sampled output
+     recall_count = [0, 0]  # [number of correct positives, number of actual positives]
+     hullucination = []
+     import json
+
+     rft_file_line = 0
+     rft_list = []
+     verifier_file_line = 0
+     verifier_list = []
+     rft_verifier_file_line = 0
+     rft_verifier_list = []
+     acc = []
+
+
+     with open(wf_name, 'w', encoding='utf-8') as wf:
+         for line in tqdm.tqdm(lines):
+             for output in line['outputs']:
+                 v_scores = output.get('vscores', [])
+                 response = output.get('response', "")
+                 is_true = output.get('label', "")
+
+                 if is_true:
+                     rft_list.append({"question": line['question'], "answer": output['response']})
+                     rft_file_line += 1
+                     if v_scores and v_scores[-1] > 0.5:
+                         # Save to rft_verifier_enhanced.json
+                         rft_verifier_list.append({"question": line['question'], "answer": output['response']})
+                         rft_verifier_file_line += 1
+                 if v_scores and v_scores[-1] > 0.5:
+                     verifier_list.append({"question": line['question'], "answer": output['response']})
+                     verifier_file_line += 1
+                 if is_true:
+                     acc.append(1)
+                 else:
+                     acc.append(0)
+     print(rft_file_line)
+     print(verifier_file_line)
+     print(rft_verifier_file_line)
+     print("acc", sum(acc) / len(acc))
+
+     with open("data/continual_training/rft.json", 'w', encoding='utf-8') as rft_file:
+         json.dump(rft_list, rft_file, ensure_ascii=False, indent=2)
+
+     with open("data/continual_training/verifier_enhanced.json", 'w', encoding='utf-8') as verifier_file:
+         json.dump(verifier_list, verifier_file, ensure_ascii=False, indent=2)
+
+     with open("data/continual_training/rft_verifier_enhanced.json", 'w', encoding='utf-8') as rft_verifier_file:
+         json.dump(rft_verifier_list, rft_verifier_file, ensure_ascii=False, indent=2)
+
+
+
+
+
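
Editor's note: for orientation, a sketch of the per-line record that process_line appears to consume, inferred from the .get() calls above; the field names come from the code, the values are invented.

# Inferred (not confirmed) shape of one JSONL record read by process_line:
example_line = {
    "question": "Prove that 2 + 2 = 4 in Lean 4.",                    # invented
    "outputs": [
        {
            "response": "theorem two_add_two : 2 + 2 = 4 := rfl",      # invented sample
            "vscores": [0.91, 0.88, 0.95],                              # per-token verifier scores
            "label": True,                                              # whether the sample is judged correct
        },
    ],
}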
+ def locate_sublist(lst, sublst):
+     for i in range(len(lst) - len(sublst) + 1):
+         if lst[i:i+len(sublst)] == sublst:
+             return i  # Return the starting index of the sublist in the list
+     assert False, 'sublist not found'
+
+
+ def split_string_list(a_list, number='\n'):
+     sublists = []
+     current_sublist = []
+     for item in a_list:
+         current_sublist.append(item)
+         if item == number:
+             if current_sublist:  # if the current sublist is not empty
+                 sublists.append(''.join(current_sublist))
+                 current_sublist = []  # start a new sublist
+
+     # Don't forget to add the last sublist if it's not empty
+     if current_sublist:
+         sublists.append(''.join(current_sublist))
+
+     return sublists
+ def split_token_list(a_list, number=13):
+     sublists = []
+     current_sublist = []
+     for item in a_list:
+         current_sublist.append(item)
+         if item == number:
+             if current_sublist:  # if the current sublist is not empty
+                 sublists.append(current_sublist)
+                 current_sublist = []  # start a new sublist
+
+     # Don't forget to add the last sublist if it's not empty
+     if current_sublist:
+         sublists.append(current_sublist)
+
+     return sublists
+ # Modify evaluate_expression function to return a list of results
+
+
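
Editor's note: a small illustration of how the two splitters are meant to line up (my sketch, using the functions defined above). It assumes token id 13 is the newline token of the tokenizer in use, which holds for LLaMA-style tokenizers but should be verified.

# Split a response into per-paragraph strings and matching per-paragraph token
# chunks, so the same index refers to the same paragraph in both lists.
response = "Step 1: 2 + 3 = 5\nStep 2: 5 * 2 = 10\n"
string_chunks = split_string_list(response)   # ['Step 1: 2 + 3 = 5\n', 'Step 2: 5 * 2 = 10\n']
token_ids = [101, 13, 102, 103, 13]           # fake token ids; 13 assumed to encode '\n'
token_chunks = split_token_list(token_ids)    # [[101, 13], [102, 103, 13]]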
+ def evaluate_expression_para(response_all, v_score, tokenizer, is_true):
+     # Initialize lists to hold multiple evaluation results for each expression
+     # here we turn the v_score into labels with a "first error detection" scheme
+     labels = []
+     predictions = []
+     sol_tokens = tokenizer(response_all).input_ids
+     process_v_score = [0] * len(sol_tokens)
+     hullucination = False
+     gt_help = False
+     error_detection = False
+     response_list = split_string_list(response_all)
+     token_list = split_token_list(sol_tokens)
+     previous_len = 0
+     for idx, string in enumerate(response_list):
+         # match = re.search(r'<<(.+?)>>', string)
+         para_token = token_list[idx]
+         para_token_location = sum([len(item) for item in token_list[:idx]])
+
+         if error_detection:
+             break
+
+
+         if abs(v_score[para_token_location]) < 1e-5:
+             error_detection = True
+
+         elif (v_score[para_token_location + len(para_token) - 1] - v_score[para_token_location]) / v_score[para_token_location] < -0.5:
+             error_detection = True
+
+         else:
+             if not error_detection:
+                 process_v_score[para_token_location : para_token_location + len(para_token)] = [1] * len(para_token)
+
+         # if match:
+         #     expression = match.group(1)
+         #     start_token = tokenizer(response_all[: previous_len + match.span()[0]]).input_ids
+         #     if sol_tokens[:len(start_token)] != start_token:
+         #         start_token = start_token[:-1]
+         #     # print(tokenizer.decode(start_token))
+         #     seg_token_location = len(start_token)
+         #     seq_token = tokenizer(response_all[: previous_len + match.span()[1]]).input_ids[len(start_token):]
+         #     # print(tokenizer.decode(seq_token))
+         #     # Check if v_score change is positive
+         #     try:
+         #         if abs(v_score[seg_token_location]) < 1e-5:
+         #             prediction = 'negative'  # there is an extra example in v_score
+         #             error_detection = True
+         #
+         #         elif (v_score[min(seg_token_location + len(seq_token), len(v_score) - 1)] - v_score[seg_token_location]) / (v_score[seg_token_location]) < -0.9:
+         #             prediction = 'negative'  # there is a negative change in v_score
+         #             error_detection = True
+         #         else:
+         #             prediction = 'positive'  # no negative change in v_score
+         #         if not error_detection:
+         #             process_v_score[para_token_location : para_token_location + len(para_token)] = [1] * len(para_token)
+         #     except:
+         #         import pdb
+         #         pdb.set_trace()
+         #     try:
+         #         before_equal, after_equal = expression.split('=')
+         #         computed_value = float(eval(before_equal.strip()))
+         #         actual_value = float(delete_extra_zero(after_equal.strip().replace(",", "")))
+         #         # Use the positive v_score change as a proxy for a correct evaluation
+         #         if abs(computed_value - actual_value) <= 1e-3:
+         #             label = 'positive'
+         #         else:
+         #             label = 'negative'
+         #             hullucination = True
+         #
+         #         # Record the label and prediction for this expression
+         #         labels.append(label)
+         #         predictions.append(prediction)
+         #     except Exception as e:
+         #         pass
+         # else:
+         #     if not error_detection:
+         #         process_v_score[para_token_location: para_token_location + len(para_token)] = [1] * len(para_token)
+         #
+
+         if idx == len(response_list) - 1 and not error_detection and not is_true:
+             process_v_score[para_token_location: para_token_location + len(para_token)] = [0] * len(para_token)
+             gt_help = True
+
+         previous_len += len(string)
+     # if sum(process_v_score) != len(process_v_score) and sum(process_v_score) != 0:
+     #     print(process_v_score)
+
+     return {'label': labels, 'prediction': predictions, 'hullucination': hullucination, 'gt_help': gt_help}, process_v_score
+
+
+
+
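
Editor's note: to make the paragraph-level criterion above concrete, a small numeric sketch with invented numbers. A paragraph is flagged as the first error when the verifier score over its tokens drops by more than 50 percent, or when the score at its first token is essentially zero; later paragraphs then keep the 0 label.

# Hypothetical verifier scores at the first and last token of one paragraph.
start_score, end_score = 0.80, 0.30
relative_change = (end_score - start_score) / start_score   # (0.30 - 0.80) / 0.80 = -0.625
first_error_here = relative_change < -0.5                   # True: this and later paragraphs stay labeled 0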
+ def evaluate_expression(response_all, v_score, tokenizer, is_true):
+     # Initialize lists to hold multiple evaluation results for each expression
+     # here we turn the v_score into labels with a "first error detection" scheme
+     labels = []
+     predictions = []
+     sol_tokens = tokenizer(response_all).input_ids
+     process_v_score = [0] * len(sol_tokens)
+     hullucination = False
+     gt_help = False
+     error_detection = False
+     response_list = split_string_list(response_all)
+     token_list = split_token_list(sol_tokens)
+     previous_len = 0
+     for idx, string in enumerate(response_list):
+         match = re.search(r'<<(.+?)>>', string)
+         para_token = token_list[idx]
+         para_token_location = sum([len(item) for item in token_list[:idx]])
+         if match:
+             expression = match.group(1)
+             start_token = tokenizer(response_all[: previous_len + match.span()[0]]).input_ids
+             if sol_tokens[:len(start_token)] != start_token:
+                 start_token = start_token[:-1]
+             # print(tokenizer.decode(start_token))
+             seg_token_location = len(start_token)
+             seq_token = tokenizer(response_all[: previous_len + match.span()[1]]).input_ids[len(start_token):]
+             # print(tokenizer.decode(seq_token))
+             # Check if v_score change is positive
+             try:
+                 if abs(v_score[seg_token_location]) < 1e-5:
+                     prediction = 'negative'  # there is an extra example in v_score
+                     error_detection = True
+
+                 elif (v_score[min(seg_token_location + len(seq_token), len(v_score) - 1)] - v_score[seg_token_location]) / (v_score[seg_token_location]) < -0.9:
+                     prediction = 'negative'  # there is a negative change in v_score
+                     error_detection = True
+                 else:
+                     prediction = 'positive'  # no negative change in v_score
+                 if not error_detection:
+                     process_v_score[para_token_location : para_token_location + len(para_token)] = [1] * len(para_token)
+             except:
+                 import pdb
+                 pdb.set_trace()
+             try:
+                 before_equal, after_equal = expression.split('=')
+                 computed_value = float(eval(before_equal.strip()))
+                 actual_value = float(delete_extra_zero(after_equal.strip().replace(",", "")))
+                 # Use the positive v_score change as a proxy for a correct evaluation
+                 if abs(computed_value - actual_value) <= 1e-3:
+                     label = 'positive'
+                 else:
+                     label = 'negative'
+                     hullucination = True
+
+                 # Record the label and prediction for this expression
+                 labels.append(label)
+                 predictions.append(prediction)
+             except Exception as e:
+                 pass
+         else:
+             if not error_detection:
+                 process_v_score[para_token_location: para_token_location + len(para_token)] = [1] * len(para_token)
+
+
+         # if idx == len(response_list) - 1 and not error_detection and not is_true:
+         #     process_v_score[para_token_location: para_token_location + len(para_token)] = [0] * len(para_token)
+         #     gt_help = True
+
+         previous_len += len(string)
+     # if sum(process_v_score) != len(process_v_score) and sum(process_v_score) != 0:
+     #     print(process_v_score)
+
+     return {'label': labels, 'prediction': predictions, 'hullucination': hullucination, 'gt_help': gt_help}, process_v_score
+
+
+
+
+
+
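
Editor's note: the expression check above relies on GSM8K-style calculator annotations of the form <<lhs=rhs>>. A minimal sketch of that label computation in isolation; the example string is mine, and eval on untrusted text is kept only to mirror the code above.

import re

text = "So she pays <<3*12=36>> dollars."              # invented example
match = re.search(r'<<(.+?)>>', text)
before_equal, after_equal = match.group(1).split('=')
computed_value = float(eval(before_equal.strip()))     # 36.0
actual_value = float(after_equal.strip())              # 36.0
label = 'positive' if abs(computed_value - actual_value) <= 1e-3 else 'negative'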
+ import multiprocessing
+ from functools import partial
+ import os
+ def process_chunk(tokenizer, chunk, wf_path):
+     acc = []  # Will store tuples of (label, prediction) for each expression
+     recall_count = [0, 0]  # [number of correct positives, number of actual positives]
+     hullucination = []
+     gt_help = []
+
+     with open(wf_path, 'w', encoding='utf-8') as wf:
+         for line in tqdm.tqdm(chunk):
+             for output in line['outputs']:
+                 # import pdb
+                 # pdb.set_trace()
+                 v_scores = output.get('vscores', [])
+                 response = output.get('response', "")
+                 is_true = output.get('label', "")
+                 evaluation_results, process_v_scores = evaluate_expression_para(response, v_scores, tokenizer, is_true)
+                 # output['process_vscores'] = process_v_scores
+
+                 if evaluation_results['hullucination']:
+                     hullucination.append(1)
+                 else:
+                     hullucination.append(0)
+
+                 if evaluation_results['gt_help']:
+                     gt_help.append(1)
+                 else:
+                     gt_help.append(0)
+
+
+                 # Add the results to the accuracy list for each expression
+                 for label, prediction in zip(evaluation_results['label'], evaluation_results['prediction']):
+                     acc.append((label, prediction))
+
+                 # Update recall counts for each expression
+                 for idx, prediction in enumerate(evaluation_results['prediction']):
+                     label = evaluation_results['label'][idx]
+                     if label == 'positive':
+                         recall_count[1] += 1  # Increment the count of actual positives
+                         if prediction == 'positive':
+                             recall_count[0] += 1  # Increment the count of correct positives
+             wf.writelines(json.dumps(line, ensure_ascii=False) + '\n')
+
+     # Calculate metrics for the chunk
+     accuracy = sum(1 for label, prediction in acc if label == prediction) / len(acc) if acc else 0
+     hullucination_rate = sum(hullucination) / len(hullucination) if hullucination else 0
+     # Return the metrics and counts, not just the rates, to allow aggregation
+     return {
+         "accuracy_sum": sum(1 for label, prediction in acc if label == prediction),
+         "total": len(acc),
+         "recall_correct": recall_count[0],
+         "recall_total": recall_count[1],
+         "hullucination_sum": sum(hullucination),
+         "hullucination_total": len(hullucination),
+         "gt_help_sum": sum(gt_help),
+         "gt_help_total": len(gt_help),
+     }
+     # print(
+     #     f"Chunk accuracy: {accuracy}, Chunk recall: {recall}, Chunk hullucination: {sum(hullucination) / len(hullucination) if hullucination else 0}")
+
+
+
+ def parallel_process_line(tokenizer, lines, wf_path, num_processes=1):
+     if num_processes is None:
+         num_processes = multiprocessing.cpu_count()
+
+     # Split lines into chunks
+     chunk_size = int(len(lines) / num_processes)
+     chunks = [lines[i:i + chunk_size] for i in range(0, len(lines), chunk_size)]
+
+     # Generate a unique temporary file path for each chunk
+     temp_files = [f"multirun/{wf_path}_temp_{i}.json" for i in range(len(chunks))]
+
+     # Create a pool of workers to process data in parallel
+     with multiprocessing.Pool(processes=num_processes) as pool:
+         # Map each chunk to process_chunk function along with a unique temporary file path
+         results = pool.starmap(process_chunk, [(tokenizer, chunk, temp_file) for chunk, temp_file in zip(chunks, temp_files)])
+
+     # Combine results from temporary files into the final output file
+     with open(f"multirun2/{wf_path}.json", 'w', encoding='utf-8') as wf:
+         for temp_file in temp_files:
+             with open(temp_file, 'r', encoding='utf-8') as tf:
+                 wf.write(tf.read())
+             os.remove(temp_file)  # Clean up temporary file
+
+     # Aggregate metrics from all chunks
+     total_acc = sum(result['accuracy_sum'] for result in results)
+     total = sum(result['total'] for result in results)
+     total_recall_correct = sum(result['recall_correct'] for result in results)
+     total_recall = sum(result['recall_total'] for result in results)
+     total_hullucination = sum(result['hullucination_sum'] for result in results)
+     total_hullucination_counts = sum(result['hullucination_total'] for result in results)
+     total_gt_help = sum(result['gt_help_sum'] for result in results)
+     total_gt_help_counts = sum(result['gt_help_total'] for result in results)
+
+     # Calculate overall metrics
+     overall_accuracy = total_acc / total if total else 0
+     overall_recall = total_recall_correct / total_recall if total_recall else 0
+     overall_hullucination = total_hullucination / total_hullucination_counts if total_hullucination_counts else 0
+     overall_gt_help = total_gt_help / total_gt_help_counts if total_gt_help_counts else 0
+
+     print(f"Overall accuracy: {overall_accuracy}")
+     print(f"Overall recall: {overall_recall}")
+     print(f"Overall hullucination: {overall_hullucination}")
+     print(f"Overall gt_help: {overall_gt_help}")
+
+
+
+ # Example usage
+ # line = '{"outputs": [{"solution_str": "The result is <<5 * 3 = 15>>."}, {"solution_str": "The answer is <<2 + 2 = 5>>."}]}'
+ # file_path = "eval_results/gsm8k/verifier/train/responses_v(threemodel)_g(threemodel).jsonl"
+ file_path = "eval_results/math/verifier/test/responses_v(lean4_random_15k_all-sample10-osv-gt2)_g(lean4_rand).jsonl"
+ line = [json.loads(line) for line in open(file_path, 'r', encoding='utf-8').readlines()]
+ for ex in line:
+     dedup_outputs = []
+     for output in ex['outputs']:
+         if len(output['tokens']) > 2048:
+             continue
+         dedup_outputs.append(output)
+     ex['outputs'] = dedup_outputs
+
+ model_dir = "../models/lean4_random_15k_all-sample10-osv-gt2/"
+ tokenizer = load_tokenizer(model_dir)
+ process_line(tokenizer, line, 'good.json')
+
+ # Example usage
+ # tokenizer = load_tokenizer(model_dir)
+ # parallel_process_line(tokenizer, line, "allresults.json")
+
+
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ transformers==4.39.2
+ DeepSpeed==0.14.0
+ SentencePiece
+ accelerate>=0.21.0
test.py ADDED
@@ -0,0 +1,82 @@
+ from fraction import Fraction
+ import re
+ def is_number(s):
+     try:
+         float(s)
+         return True
+     except ValueError:
+         pass
+     try:
+         import unicodedata
+         unicodedata.numeric(s)
+         return True
+     except (TypeError, ValueError):
+         pass
+     return False
+
+ ANSWER_TRIGGER = 'The answer is'
+ def handle_frac(pred):
+     if '/' in pred:
+         denominator = pred.split('/')[1]
+         numerator = pred.split('/')[0]
+         if is_number(denominator) == True and is_number(numerator) == True:
+             if denominator == '0':
+                 return round(float(numerator.replace(',', '')))
+             else:
+                 frac = Fraction(pred.replace(',', ''))
+                 num_numerator = frac.numerator
+                 num_denominator = frac.denominator
+                 return round(float(num_numerator / num_denominator))
+ def delete_extra_zero(n):
+     '''Remove redundant trailing zeros after the decimal point.'''
+     try: n = float(n)
+     except:
+         # print("None {}".format(n))
+         try:
+             rr = str(handle_frac(n))
+             return rr
+         except:
+             return ''
+     if isinstance(n, int):
+         return str(n)
+     if isinstance(n, float):
+         n = str(n).rstrip('0')  # strip redundant zeros after the decimal point
+         n = int(n.rstrip('.')) if n.endswith('.') else float(n)  # if only the decimal point remains, convert to int, otherwise back to float
+         n = str(n)
+         return n
+
+ def output_answer_clean(model_pred):
+     model_pred = model_pred.lower()
+     preds = model_pred.split(ANSWER_TRIGGER.lower())
+     answer_flag = True if len(preds) > 1 else False
+     if answer_flag:
+         # Pick first answer with flag
+         pred = preds[1]
+     else:
+         # Pick last number without flag
+         pred = preds[-1]
+
+     pred = pred.replace(",", "")
+     # pred = [s for s in re.findall(r'-?\d+\.?\d*', pred)]
+     # pred = [s.replace(",", "") for s in re.findall(r'-?\d+/?\.?\d*', pred)]
+     pred = [delete_extra_zero(s.replace(",", "")) for s in re.findall(r'-?\d+/?\.?\d*', pred)]
+
+     if len(pred) == 0:
+         return None
+
+     if answer_flag:
+         # choose the first element in list
+         pred = pred[0]
+     else:
+         # choose the last element in list
+         pred = pred
+
+     try:
+         if pred[-1] == ".":
+             pred = pred[:-1]
+     except:
+         pass
+     if isinstance(pred, list):
+         return pred[-1]
+     else:
+         return pred
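
Editor's note: a few illustrative calls tracing the helpers above; the inputs are my own examples and the expected outputs follow from the code as written.

print(delete_extra_zero("3.50"))                   # -> '3.5'  (trailing zero removed)
print(delete_extra_zero("2.0"))                    # -> '2'    (integer-valued float collapses to an int)
print(output_answer_clean("The answer is 42."))    # -> '42'   (number following the answer trigger)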