import re |
from test import delete_extra_zero |
import transformers |
import json |
import tqdm |
import multiprocessing |
from functools import partial |
def load_tokenizer(model_name_or_path):
    """Create the slow (non-fast) tokenizer for ``model_name_or_path``.

    The tokenizer is built with right-side padding. If the path contains the
    substring ``'phi'`` the pad token is aliased to the unk token. Otherwise,
    when the tokenizer has no pad token, the EOS/BOS/UNK special tokens are
    registered from the ``DEFAULT_*_TOKEN`` constants (this branch does not
    register a pad token itself).

    NOTE(review): ``DEFAULT_EOS_TOKEN``, ``DEFAULT_BOS_TOKEN`` and
    ``DEFAULT_UNK_TOKEN`` are neither defined nor imported in the visible part
    of this file. Confirm they are in scope; if not, the fallback branch
    raises ``NameError`` for any tokenizer that lacks a pad token.

    Args:
        model_name_or_path: Model identifier or local directory handed to
            ``transformers.AutoTokenizer.from_pretrained``.

    Returns:
        The configured tokenizer instance.
    """
    print(f"+ [Model] Initializing Tokenizer: {model_name_or_path}")
    tok = transformers.AutoTokenizer.from_pretrained(
        model_name_or_path, padding_side="right", use_fast=False
    )

    # 'phi' checkpoints: reuse the unk token for padding and stop here.
    if 'phi' in model_name_or_path:
        tok.pad_token = tok.unk_token
        return tok

    if tok.pad_token is None:
        fallback_tokens = {
            "eos_token": DEFAULT_EOS_TOKEN,
            "bos_token": DEFAULT_BOS_TOKEN,
            "unk_token": DEFAULT_UNK_TOKEN,
        }
        tok.add_special_tokens(fallback_tokens)
    return tok
def evaluate_expression(string):
    """Check the first ``<<lhs=rhs>>`` calculator annotation in ``string``.

    The left-hand side is evaluated as a Python expression and compared with
    the stated right-hand side (thousands separators removed, then passed
    through ``delete_extra_zero``).

    Args:
        string: Text that may contain a ``<<...>>`` annotation.

    Returns:
        ``1`` if the evaluated left-hand side is within ``1e-3`` of the stated
        result; ``0`` if there is no annotation, the values differ, or any
        exception occurs while splitting, evaluating or converting (the error
        is printed, not raised).
    """
    correct_count = 0
    match = re.search(r'<<(.+?)>>', string)
    if match is None:
        return correct_count

    expression = match.group(1)
    try:
        # Unpacking raises ValueError unless there is exactly one '='.
        before_equal, after_equal = expression.split('=')
        # SECURITY: eval() runs arbitrary Python and also sees this function's
        # local names. The input here is model-generated text; replace with a
        # restricted arithmetic evaluator if the data is not trusted.
        computed_value = float(eval(before_equal.strip()))
        stated_text = delete_extra_zero(after_equal.strip().replace(",", ""))
        difference = abs(computed_value - float(stated_text))
    except Exception as e:
        print(f"Error evaluating expression: {expression}. Error: {e}")
    else:
        correct_count = int(difference <= 1e-3)
    return correct_count
def _segment_boundaries(tokens, split_token_id):
    """Return the token offsets that cut ``tokens`` into segments."""
    boundaries = [0]
    boundaries.extend(
        i + 1
        for i, token_id in enumerate(tokens)
        # Cut after the first split token of a run. The explicit ``i == 0``
        # case avoids ``tokens[-1]`` wrapping around to the last token.
        if token_id == split_token_id
        and (i == 0 or tokens[i - 1] != split_token_id)
    )
    boundaries.append(len(tokens))
    return boundaries


def process_line(tokenizer, line, split_token_id=13):
    """Score whether the verifier's largest score drop lands on a wrong step.

    For every output whose ``'label'`` is truthy, the token sequence is cut
    into segments after each ``split_token_id``. The verifier score at the end
    offset of each segment is sampled, and the segment with the most negative
    score change is flagged. If the flagged segment's text contains a
    ``<<...>>`` annotation, it is checked with ``evaluate_expression``.

    Args:
        tokenizer: Object exposing ``decode(token_ids) -> str``.
        line: One JSON-encoded record with an ``'outputs'`` list; each output
            provides ``'label'``, ``'tokens'`` and ``'vscores'``.
        split_token_id: Token id that ends a segment. Defaults to ``13``
            (presumably the newline token of the tokenizer in use; confirm).

    Returns:
        A list with one entry per flagged segment that contains an
        annotation: ``1`` when ``evaluate_expression`` rejects the annotation
        (the flagged step really is wrong), ``0`` when it accepts it.
        Outputs with no negative score change, or whose flagged segment has
        no annotation, contribute nothing.
    """
    acc = []
    record = json.loads(line)
    for item in record['outputs']:
        if not item['label']:
            continue
        v_scores = item['vscores']
        solution_tokens = item['tokens']

        boundaries = _segment_boundaries(solution_tokens, split_token_id)
        # NOTE(review): the final boundary equals len(tokens), so this assumes
        # 'vscores' has at least len(tokens) + 1 entries; confirm against the
        # code that writes the file.
        end_scores = [v_scores[offset] for offset in boundaries[1:]]
        score_changes = [
            later - earlier
            for earlier, later in zip(end_scores, end_scores[1:])
        ]
        if not (score_changes and min(score_changes) < 0):
            continue

        # score_changes[k] is the change across segment k + 1, hence the +1.
        flagged = score_changes.index(min(score_changes)) + 1
        segment = solution_tokens[boundaries[flagged]:boundaries[flagged + 1]]
        # Drop the segment's last token (the split token for non-final
        # segments) before decoding. Only the flagged segment is decoded.
        detect_string = tokenizer.decode(segment[:-1])
        if not re.search(r'<<([^>>]+)>>', detect_string):
            continue
        acc.append(int(not evaluate_expression(detect_string)))
    return acc
def process_line2(tokenizer, line):
    """Parse one JSON record and walk its outputs (unfinished stub).

    Each output's ``'vscores'`` and ``'tokens'`` entries are read, but nothing
    is done with them yet: ``filter_list`` is never filled and the function
    implicitly returns ``None``. ``tokenizer`` is currently unused.

    Args:
        tokenizer: Unused.
        line: JSON-encoded record containing an ``'outputs'`` list.
    """
    record = json.loads(line)
    filter_list = []
    for item in record['outputs']:
        # Both lookups are kept so a missing key still raises KeyError.
        v_scores, solution_tokens = item['vscores'], item['tokens']
def load_vescores(
    file_path="eval_results/gsm8k/verifier/test/responses_v(mistral7b-ep2-n100-scahead-mse-lm-token)_g(llama2chatfinetuned).jsonl",
    model_dir="/data/OVM-Mistral-7b/mistral7b-ep2-n100-scahead-mse-lm-token",
    num_workers=1,
):
    """Run ``process_line`` over a JSONL results file and print the accuracy.

    Args:
        file_path: JSONL file with one record per line. Defaults to the path
            previously hard-coded in this function.
        model_dir: Tokenizer location passed to ``load_tokenizer``. Defaults
            to the directory previously hard-coded in this function.
        num_workers: Size of the worker pool. Defaults to ``1``, matching the
            previous behaviour.

    Returns:
        None. Prints ``acc : <mean>`` over all values returned by
        ``process_line``, or a notice when no value was produced.
    """
    tokenizer = load_tokenizer(model_dir)

    with open(file_path, 'r', encoding='utf-8') as fp:
        # Skip blank lines (e.g. a trailing newline) that json.loads rejects.
        lines = [raw for raw in fp if raw.strip()]

    acc = []
    func = partial(process_line, tokenizer)
    # The context manager shuts the pool down even if a worker raises; all
    # results are consumed inside the block, before the pool is terminated.
    with multiprocessing.Pool(num_workers) as pool, \
            tqdm.tqdm(total=len(lines)) as pbar:
        for result in pool.imap(func, lines):
            acc.extend(result)
            pbar.update()

    if acc:
        print(f"acc : {sum(acc)/len(acc)}")
    else:
        # Avoid ZeroDivisionError when no step could be scored.
        print("acc : n/a (no scorable steps found)")
if __name__ == '__main__':
    # Script entry point. Currently a no-op: nothing defined in this file is
    # invoked when it is executed directly.
    # NOTE(review): presumably load_vescores() is meant to be called here;
    # confirm whether leaving it disabled is intentional.
    pass