import json
import multiprocessing
import os
import re

import tqdm

from process_supervision import load_tokenizer
from test import delete_extra_zero
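
# Post-processes verifier-scored generation results: filters sampled solutions
# into RFT / verifier-filtered training sets (process_line) and derives
# token-level process supervision from verifier-score trajectories, either per
# paragraph (evaluate_expression_para) or per <<expression=result>> calculator
# annotation (evaluate_expression).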
|
def process_line(tokenizer, lines, wf_name, output_dir, thres=0.3):
    """Split sampled outputs into RFT / verifier-filtered training sets.

    An output joins the RFT set when its label is correct, the verifier set
    when its final verifier score reaches `thres`, and the combined set when
    both hold.
    """
    acc = []

    rft_file_line = 0
    rft_list = []
    verifier_file_line = 0
    verifier_list = []
    rft_verifier_file_line = 0
    rft_verifier_list = []

    # wf_name is created for compatibility with the rest of the pipeline;
    # nothing is written to it in this function.
    with open(wf_name, 'w', encoding='utf-8') as wf:
        for line in tqdm.tqdm(lines):
            for output in line['outputs']:
                v_scores = output.get('vscores', [])
                response = output.get('response', "")
                is_true = output.get('label', "")

                if is_true:
                    rft_list.append({"question": line['question'], "answer": response})
                    rft_file_line += 1
                    if v_scores and v_scores[-1] >= thres:
                        rft_verifier_list.append({"question": line['question'], "answer": response})
                        rft_verifier_file_line += 1
                if v_scores and v_scores[-1] >= thres:
                    verifier_list.append({"question": line['question'], "answer": response})
                    verifier_file_line += 1

                acc.append(1 if is_true else 0)

    print("rft:", rft_file_line)
    print("verifier:", verifier_file_line)
    print("rft+verifier:", rft_verifier_file_line)
    print("acc:", sum(acc) / len(acc) if acc else 0)

    with open(f"{output_dir}/rft.json", 'w', encoding='utf-8') as rft_file:
        json.dump(rft_list, rft_file, ensure_ascii=False, indent=2)

    with open(f"{output_dir}/verifier_enhanced.json", 'w', encoding='utf-8') as verifier_file:
        json.dump(verifier_list, verifier_file, ensure_ascii=False, indent=2)

    with open(f"{output_dir}/rft_verifier_enhanced.json", 'w', encoding='utf-8') as rft_verifier_file:
        json.dump(rft_verifier_list, rft_verifier_file, ensure_ascii=False, indent=2)
|
|
def locate_sublist(lst, sublst):
    """Return the index at which `sublst` first occurs inside `lst`."""
    for i in range(len(lst) - len(sublst) + 1):
        if lst[i:i + len(sublst)] == sublst:
            return i
    raise ValueError('sublist not found in list')
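
# Example: locate_sublist([4, 8, 15, 16], [8, 15]) returns 1; a ValueError is
# raised when the sublist does not occur.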
|
|
def split_string_list(a_list, sep='\n'):
    """Split an iterable of characters into strings, keeping each `sep` at
    the end of the piece it terminates."""
    sublists = []
    current_sublist = []
    for item in a_list:
        current_sublist.append(item)
        if item == sep:
            sublists.append(''.join(current_sublist))
            current_sublist = []
    if current_sublist:
        sublists.append(''.join(current_sublist))
    return sublists
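
# Example: split_string_list("step one\nstep two") returns
# ["step one\n", "step two"].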
|
def split_token_list(a_list, number=13):
    """Split a token-id list on a separator id, keeping the separator at the
    end of each piece.  The default id 13 is assumed to be the tokenizer's
    newline token."""
    sublists = []
    current_sublist = []
    for item in a_list:
        current_sublist.append(item)
        if item == number:
            sublists.append(current_sublist)
            current_sublist = []
    if current_sublist:
        sublists.append(current_sublist)
    return sublists
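
# Example (with 13 standing in for the newline id):
#   split_token_list([42, 13, 7, 8, 13, 9]) returns [[42, 13], [7, 8, 13], [9]]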
|
|
def evaluate_expression_para(response_all, v_score, tokenizer, is_true):
    """Derive paragraph-level token supervision from verifier scores.

    Tokens of a paragraph are marked 1 until the verifier score collapses
    (near zero at the paragraph start, or falling by more than 50% across the
    paragraph); everything from the first detected error onward stays 0.
    Unlike evaluate_expression, no expression-level labels or predictions are
    produced, so `labels` and `predictions` are returned empty.
    """
    labels = []
    predictions = []
    sol_tokens = tokenizer(response_all).input_ids
    process_v_score = [0] * len(sol_tokens)
    hallucination = False
    gt_help = False
    error_detection = False
    response_list = split_string_list(response_all)
    token_list = split_token_list(sol_tokens)
    previous_len = 0

    for idx, string in enumerate(response_list):
        if error_detection:
            break

        para_token = token_list[idx]
        para_token_location = sum(len(item) for item in token_list[:idx])

        if abs(v_score[para_token_location]) < 1e-5:
            # A (near-)zero score at the paragraph start marks an error.
            error_detection = True
        elif (v_score[para_token_location + len(para_token) - 1]
              - v_score[para_token_location]) / v_score[para_token_location] < -0.5:
            # The score dropped by more than half across the paragraph.
            error_detection = True
        else:
            process_v_score[para_token_location:para_token_location + len(para_token)] = [1] * len(para_token)

        # If no error was detected but the final answer is wrong, fall back to
        # the ground-truth label and zero out the last paragraph.
        if idx == len(response_list) - 1 and not error_detection and not is_true:
            process_v_score[para_token_location:para_token_location + len(para_token)] = [0] * len(para_token)
            gt_help = True

        previous_len += len(string)

    return {'label': labels, 'prediction': predictions,
            'hallucination': hallucination, 'gt_help': gt_help}, process_v_score
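
# Illustrative (hypothetical) scores: for two 3-token paragraphs with
# v_score = [0.9, 0.8, 0.8, 0.3, 0.2, 0.1], paragraph 1 drops only ~11%
# (0.9 -> 0.8), so its tokens are marked 1; paragraph 2 drops ~67%
# (0.3 -> 0.1), which exceeds the 50% threshold and triggers error_detection.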
|
|
def evaluate_expression(response_all, v_score, tokenizer, is_true):
    """Evaluate <<expression=result>> annotations against verifier scores.

    For each annotation, the prediction ('positive'/'negative') comes from the
    verifier-score trajectory over the annotation's tokens, while the label
    comes from actually evaluating the expression.  Token-level supervision
    (`process_v_score`) is 1 up to the first detected error.
    """
    labels = []
    predictions = []
    sol_tokens = tokenizer(response_all).input_ids
    process_v_score = [0] * len(sol_tokens)
    hallucination = False
    gt_help = False
    error_detection = False
    response_list = split_string_list(response_all)
    token_list = split_token_list(sol_tokens)
    previous_len = 0

    for idx, string in enumerate(response_list):
        match = re.search(r'<<(.+?)>>', string)
        para_token = token_list[idx]
        para_token_location = sum(len(item) for item in token_list[:idx])
        if match:
            expression = match.group(1)
            # Locate the annotation's token span by re-tokenizing the prefix;
            # drop the final prefix token if it straddles the boundary.
            start_token = tokenizer(response_all[:previous_len + match.span()[0]]).input_ids
            if sol_tokens[:len(start_token)] != start_token:
                start_token = start_token[:-1]
            # Clamp indices so a v_score shorter than the token list cannot
            # raise an IndexError.
            seg_token_location = min(len(start_token), len(v_score) - 1)
            seq_token = tokenizer(response_all[:previous_len + match.span()[1]]).input_ids[len(start_token):]

            if abs(v_score[seg_token_location]) < 1e-5:
                prediction = 'negative'
                error_detection = True
            elif (v_score[min(seg_token_location + len(seq_token), len(v_score) - 1)]
                  - v_score[seg_token_location]) / v_score[seg_token_location] < -0.9:
                # The verifier score collapses by more than 90% across the
                # annotation.
                prediction = 'negative'
                error_detection = True
            else:
                prediction = 'positive'
                if not error_detection:
                    process_v_score[para_token_location:para_token_location + len(para_token)] = [1] * len(para_token)

            try:
                before_equal, after_equal = expression.split('=')
                # eval() of the arithmetic left-hand side; the annotations are
                # assumed to come from trusted data.
                computed_value = float(eval(before_equal.strip()))
                actual_value = float(delete_extra_zero(after_equal.strip().replace(",", "")))

                if abs(computed_value - actual_value) <= 1e-3:
                    label = 'positive'
                else:
                    label = 'negative'
                    hallucination = True

                labels.append(label)
                predictions.append(prediction)
            except Exception:
                # Skip annotations that are not a well-formed `expr=value` pair.
                pass
        else:
            if not error_detection:
                process_v_score[para_token_location:para_token_location + len(para_token)] = [1] * len(para_token)

        previous_len += len(string)

    return {'label': labels, 'prediction': predictions,
            'hallucination': hallucination, 'gt_help': gt_help}, process_v_score
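
# Illustrative (hypothetical) annotation: in "He pays 5 * 3 = <<5*3=15>>15
# dollars.", eval("5*3") == 15.0 matches the stated result, so the label is
# 'positive'; "<<5*3=16>>" would be labeled 'negative' and flag hallucination.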
|
|
def process_chunk(tokenizer, chunk, wf_path):
    """Evaluate one chunk of lines and write them back out as JSONL."""
    acc = []
    recall_count = [0, 0]  # [true positives predicted, total positives]
    hallucination = []
    gt_help = []

    with open(wf_path, 'w', encoding='utf-8') as wf:
        for line in tqdm.tqdm(chunk):
            for output in line['outputs']:
                v_scores = output.get('vscores', [])
                response = output.get('response', "")
                is_true = output.get('label', "")
                evaluation_results, process_v_scores = evaluate_expression_para(response, v_scores, tokenizer, is_true)

                hallucination.append(1 if evaluation_results['hallucination'] else 0)
                gt_help.append(1 if evaluation_results['gt_help'] else 0)

                for label, prediction in zip(evaluation_results['label'], evaluation_results['prediction']):
                    acc.append((label, prediction))

                # Recall over the ground-truth positive annotations.
                for idx, prediction in enumerate(evaluation_results['prediction']):
                    label = evaluation_results['label'][idx]
                    if label == 'positive':
                        recall_count[1] += 1
                        if prediction == 'positive':
                            recall_count[0] += 1
            wf.write(json.dumps(line, ensure_ascii=False) + '\n')

    return {
        "accuracy_sum": sum(1 for label, prediction in acc if label == prediction),
        "total": len(acc),
        "recall_correct": recall_count[0],
        "recall_total": recall_count[1],
        "hallucination_sum": sum(hallucination),
        "hallucination_total": len(hallucination),
        "gt_help_sum": sum(gt_help),
        "gt_help_total": len(gt_help),
    }
|
|
def parallel_process_line(tokenizer, lines, wf_path, num_processes=1):
    if num_processes is None:
        num_processes = multiprocessing.cpu_count()

    # Split the input into roughly equal chunks, one per process; max(1, ...)
    # guards against a zero chunk size when there are more processes than lines.
    chunk_size = max(1, len(lines) // num_processes)
    chunks = [lines[i:i + chunk_size] for i in range(0, len(lines), chunk_size)]

    # Workers write to temp files under multirun/; the merged output goes to
    # multirun2/.
    temp_files = [f"multirun/{wf_path}_temp_{i}.json" for i in range(len(chunks))]

    with multiprocessing.Pool(processes=num_processes) as pool:
        results = pool.starmap(process_chunk, [(tokenizer, chunk, temp_file) for chunk, temp_file in zip(chunks, temp_files)])

    with open(f"multirun2/{wf_path}.json", 'w', encoding='utf-8') as wf:
        for temp_file in temp_files:
            with open(temp_file, 'r', encoding='utf-8') as tf:
                wf.write(tf.read())
            os.remove(temp_file)

    # Aggregate the per-worker statistics.
    total_acc = sum(result['accuracy_sum'] for result in results)
    total = sum(result['total'] for result in results)
    total_recall_correct = sum(result['recall_correct'] for result in results)
    total_recall = sum(result['recall_total'] for result in results)
    total_hallucination = sum(result['hallucination_sum'] for result in results)
    total_hallucination_counts = sum(result['hallucination_total'] for result in results)
    total_gt_help = sum(result['gt_help_sum'] for result in results)
    total_gt_help_counts = sum(result['gt_help_total'] for result in results)

    overall_accuracy = total_acc / total if total else 0
    overall_recall = total_recall_correct / total_recall if total_recall else 0
    overall_hallucination = total_hallucination / total_hallucination_counts if total_hallucination_counts else 0
    overall_gt_help = total_gt_help / total_gt_help_counts if total_gt_help_counts else 0

    print(f"Overall accuracy: {overall_accuracy}")
    print(f"Overall recall: {overall_recall}")
    print(f"Overall hallucination: {overall_hallucination}")
    print(f"Overall gt_help: {overall_gt_help}")
|
|
file_path_list = [
    "eval_results/math/verifier/test/responses_v(lean4_random_15k_all-sample10-osv-gt2)_g(lean4_rand).jsonl",
]
lines = []
for file_path in file_path_list:
    with open(file_path, 'r', encoding='utf-8') as f:
        lines += [json.loads(l) for l in f]

# Drop overly long generations before building the training sets.
for ex in lines:
    ex['outputs'] = [output for output in ex['outputs'] if len(output['tokens']) <= 2048]

model_dir = "../models/lean4_random_15k_all-sample10-osv-gt2/"
tokenizer = load_tokenizer(model_dir)
process_line(tokenizer, lines, 'good.json', "data/continual_training_lean", 0.3)
|