import json
import multiprocessing
import re
from functools import partial

import tqdm
import transformers

from test import delete_extra_zero

DEFAULT_PAD_TOKEN = "<pad>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_UNK_TOKEN = "<unk>"


def load_tokenizer(model_name_or_path):
    """Load a slow, right-padding tokenizer and make sure it has a pad token."""
    print(f"+ [Model] Initializing Tokenizer: {model_name_or_path}")
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_name_or_path,
        padding_side="right",
        use_fast=False,
    )

    if 'phi' in model_name_or_path:
        # Phi checkpoints ship without a pad token; reuse the unk token.
        tokenizer.pad_token = tokenizer.unk_token
    else:
        if tokenizer.pad_token is None:
            # Register the default special tokens, including a pad token.
            tokenizer.add_special_tokens({
                "pad_token": DEFAULT_PAD_TOKEN,
                "eos_token": DEFAULT_EOS_TOKEN,
                "bos_token": DEFAULT_BOS_TOKEN,
                "unk_token": DEFAULT_UNK_TOKEN,
            })

    return tokenizer
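

# Illustrative usage sketch (not called anywhere): the checkpoint name below is a
# hypothetical example, and loading it requires the files locally or from the Hub.
def _example_load_tokenizer():
    tokenizer = load_tokenizer("mistralai/Mistral-7B-v0.1")  # hypothetical checkpoint
    assert tokenizer.pad_token is not None
    assert tokenizer.padding_side == "right"
    return tokenizer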


def evaluate_expression(string):
    """Return 1 if the first "<<lhs=rhs>>" calculation in `string` is arithmetically
    correct (within 1e-3), else 0."""
    correct_count = 0

    match = re.search(r'<<(.+?)>>', string)
    if match:
        expression = match.group(1)
        try:
            before_equal, after_equal = expression.split('=')
            # Evaluate the left-hand side and compare it with the annotated result.
            computed_value = float(eval(before_equal.strip()))
            actual_value = float(delete_extra_zero(after_equal.strip().replace(",", "")))
            if abs(computed_value - actual_value) <= 1e-3:
                correct_count = 1
        except Exception as e:
            print(f"Error evaluating expression: {expression}. Error: {e}")

    return correct_count
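

# Illustrative sketch (not called anywhere): made-up GSM8K-style strings showing what
# evaluate_expression returns; assumes delete_extra_zero leaves a plain "24" unchanged.
def _example_evaluate_expression():
    assert evaluate_expression("Natalia sold <<48/2=24>> clips in May.") == 1
    assert evaluate_expression("She bought <<3*4=13>> apples.") == 0
    assert evaluate_expression("No annotated calculation here.") == 0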


def process_line(tokenizer, line):
    """For each labeled output in a jsonl record, flag the step where the verifier
    score drops the most and check whether that step contains a wrong calculation."""
    acc = []
    line = json.loads(line)
    for idx in range(len(line['outputs'])):
        item = line['outputs'][idx]
        v_scores = item['vscores']
        solution_tokens = item['tokens']
        if item['label']:
            # Token id 13 is treated as the newline token separating reasoning steps.
            split_token_id = 13
            split_indices = [0]
            split_indices.extend([i + 1 for i, token_id in enumerate(solution_tokens) if
                                  token_id == split_token_id and solution_tokens[i - 1] != split_token_id])
            split_indices.append(len(solution_tokens))
            # Verifier score at each step boundary, and its change across consecutive steps.
            segment_v_scores = [v_scores[split_indices[i]] for i in range(1, len(split_indices))]
            score_changes = [(segment_v_scores[i] - segment_v_scores[i - 1]) for i in
                             range(1, len(segment_v_scores))]

            if len(score_changes) and min(score_changes) < 0:
                # The step with the largest score drop is the suspected error location.
                max_change_index = score_changes.index(min(score_changes)) + 1
                highlighted_solution = []
                for i in range(len(split_indices) - 1):
                    segment = solution_tokens[split_indices[i]:split_indices[i + 1]]
                    if i == max_change_index:
                        detect_string = tokenizer.decode(segment[:-1])
                        highlighted_solution.append("<Error>" + detect_string + "<Error>\n")
                        matches = re.findall(r'<<([^>>]+)>>', detect_string)
                        if not matches:
                            continue
                        # Count the detection as a hit when the flagged step's
                        # annotated calculation is actually wrong.
                        is_false = not evaluate_expression(detect_string)
                        if is_false:
                            acc.append(1)
                        else:
                            acc.append(0)
                    else:
                        highlighted_solution.append(tokenizer.decode(segment))
    return acc
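

# Illustrative sketch (not called anywhere) of the step-splitting logic above, on
# made-up token ids: consecutive newline tokens (id 13) do not open empty steps.
def _example_split_indices():
    solution_tokens = [101, 102, 13, 103, 104, 13, 13, 105]
    split_token_id = 13
    split_indices = [0]
    split_indices.extend(i + 1 for i, token_id in enumerate(solution_tokens)
                         if token_id == split_token_id and solution_tokens[i - 1] != split_token_id)
    split_indices.append(len(solution_tokens))
    assert split_indices == [0, 3, 6, 8]
    # Each slice is one step: [101, 102, 13], [103, 104, 13], [13, 105]
    return [solution_tokens[a:b] for a, b in zip(split_indices, split_indices[1:])]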


def process_line2(tokenizer, line):
    filter_list = []
    line = json.loads(line)
    for idx in range(len(line['outputs'])):
        item = line['outputs'][idx]
        v_scores = item['vscores']
        solution_tokens = item['tokens']
        # Per-output filtering of (v_scores, solution_tokens) would go here.
    return filter_list
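

# Assumed shape of one jsonl line consumed by process_line / process_line2, inferred
# from the field accesses above; the values below are made up:
#
#   {"outputs": [{"vscores": [0.91, 0.88, ...],     # verifier scores per token position
#                 "tokens":  [319, 29871, 13, ...],  # generated token ids
#                 "label":   true}]}                 # whether the final answer is correct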


def load_vescores():
    file_path = "eval_results/gsm8k/verifier/test/responses_v(mistral7b-ep2-n100-scahead-mse-lm-token)_g(llama2chatfinetuned).jsonl"
    model_dir = "/data/OVM-Mistral-7b/mistral7b-ep2-n100-scahead-mse-lm-token"
    tokenizer = load_tokenizer(model_dir)
    number = multiprocessing.cpu_count()
    # A single worker is used here; pass `number` instead to parallelize across CPUs.
    pool = multiprocessing.Pool(1)
    acc = []
    with open(file_path, 'r', encoding='utf-8') as fp:
        lines = fp.readlines()
        with tqdm.tqdm(total=len(lines)) as pbar:
            func = partial(process_line, tokenizer)
            for result in pool.imap(func, lines):
                acc.extend(result)
                pbar.update()
    pool.close()
    pool.join()

    print(f"acc : {sum(acc)/len(acc)}")


if __name__ == '__main__':
    pass