# auto-info / process_supervision.py
# Uploaded by rookiemango via huggingface_hub (commit da66274, verified; 5.7 kB)
import re
from test import delete_extra_zero
import transformers
DEFAULT_PAD_TOKEN = "<pad>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_UNK_TOKEN = "<unk>"
import json
import tqdm
import multiprocessing
from functools import partial
def load_tokenizer(model_name_or_path):
    """Load a slow HuggingFace tokenizer with right padding and a usable pad token.

    Args:
        model_name_or_path: HF hub id or local checkpoint directory.

    Returns:
        A ``transformers`` tokenizer whose ``pad_token`` is guaranteed to be set.
    """
    print(f"+ [Model] Initializing Tokenizer: {model_name_or_path}")
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_name_or_path,
        padding_side="right",
        use_fast=False,
    )
    if 'phi' in model_name_or_path:
        # Phi checkpoints ship without a pad token; reuse <unk> so batching works.
        tokenizer.pad_token = tokenizer.unk_token
    else:
        if tokenizer.pad_token is None:
            # Bug fix: the original registered eos/bos/unk but never a pad
            # token, so tokenizer.pad_token stayed None even though
            # DEFAULT_PAD_TOKEN was defined at module level. Register it too.
            # NOTE(review): adding a token grows the vocab — confirm callers
            # resize model embeddings if they load a model alongside this.
            tokenizer.add_special_tokens({
                "pad_token": DEFAULT_PAD_TOKEN,
                "eos_token": DEFAULT_EOS_TOKEN,
                "bos_token": DEFAULT_BOS_TOKEN,
                "unk_token": DEFAULT_UNK_TOKEN,
            })
    return tokenizer
def evaluate_expression(string):
    """Verify the first calculator annotation of the form ``<<lhs=rhs>>``.

    Returns 1 when evaluating the left-hand side agrees with the stated
    right-hand side within 1e-3, otherwise 0 (also 0 when no annotation is
    present or the evaluation fails).
    """
    annotation = re.search(r'<<(.+?)>>', string)
    if annotation is None:
        return 0
    expression = annotation.group(1)
    try:
        lhs, rhs = expression.split('=')
        # NOTE: eval on model-produced text — tolerable only because the
        # inputs are locally generated, never untrusted user data.
        computed = float(eval(lhs.strip()))
        stated = float(delete_extra_zero(rhs.strip().replace(",", "")))
        if abs(computed - stated) <= 1e-3:
            return 1
    except Exception as e:
        print(f"Error evaluating expression: {expression}. Error: {e}")
    return 0
def process_line(tokenizer, line):
    """Check whether the largest verifier-score drop localises a real error.

    For each output in the JSONL record that has a truthy ``label``, the
    token sequence is split into segments at token id 13 (presumably a
    newline token in this tokenizer's vocab — TODO confirm). The segment
    whose verifier value-score drops the most is decoded, and if it carries
    a ``<<...>>`` calculator annotation, ``evaluate_expression`` decides
    whether the flagged segment is genuinely wrong. Appends 1 to the result
    when the flagged segment is indeed incorrect, 0 otherwise.

    Args:
        tokenizer: tokenizer used to decode token-id segments.
        line: one JSONL record with ``outputs`` -> ``vscores``/``tokens``/``label``.

    Returns:
        List of 0/1 detection outcomes (one per checked annotation).
    """
    acc = []
    line = json.loads(line)
    for idx in range(len(line['outputs'])):
        item = line['outputs'][idx]
        v_scores = item['vscores']
        solution_tokens = item['tokens']
        # Only outputs with a truthy label are analysed — presumably these
        # are solutions known to contain an error; verify against producer.
        if item['label']:
            split_token_id = 13
            split_indices = [0]
            # Segment boundaries after each run of split tokens.
            # NOTE(review): at i == 0, solution_tokens[i - 1] wraps to the
            # LAST token (Python negative indexing) — likely an off-by-one.
            split_indices.extend([i + 1 for i, token_id in enumerate(solution_tokens) if
                                  token_id == split_token_id and solution_tokens[i - 1] != split_token_id])
            split_indices.append(len(solution_tokens))
            # v-score sampled at the start of each segment (skipping index 0).
            # NOTE(review): the final index equals len(solution_tokens), so
            # this assumes len(v_scores) > len(solution_tokens) — confirm.
            segment_v_scores = [v_scores[split_indices[i]] for i in range(1, len(split_indices))]
            # Delta between consecutive segment scores; a negative delta
            # marks a confidence drop.
            score_changes = [(segment_v_scores[i] - segment_v_scores[i - 1]) for i in
                             range(1, len(segment_v_scores))]
            # split_token_id = 13
            # split_indices = [0]
            # split_indices.extend([i + 1 for i, token_id in enumerate(solution_tokens) if token_id == split_token_id and solution_tokens[i - 1] != split_token_id])
            # split_indices.append(len(solution_tokens))
            # segment_v_scores = [v_scores[split_indices[i]] for i in range(len(split_indices) - 1)]
            # score_changes = [(segment_v_scores[i] - segment_v_scores[i - 1]) for i in range(1, len(segment_v_scores))]
            if len(score_changes) and min(score_changes) < 0:
                # Segment with the steepest drop is the suspected error site.
                max_change_index = score_changes.index(min(score_changes)) + 1
                highlighted_solution = []
                for i in range(len(split_indices) - 1):
                    segment = solution_tokens[split_indices[i]:split_indices[i + 1]]
                    if i == max_change_index:
                        # Decode without the trailing split token.
                        detect_string = tokenizer.decode(segment[:-1])
                        highlighted_solution.append("<Error>" + detect_string + "<Error>\n")
                        matches = re.findall(r'<<([^>>]+)>>', detect_string)
                        if not matches:
                            # No calculator annotation — nothing to verify.
                            continue
                        # Score 1 when the flagged segment really is wrong
                        # (evaluate_expression returns 0), else 0.
                        is_false = not evaluate_expression(detect_string)
                        if is_false:
                            acc.append(1)
                        else:
                            acc.append(0)
                    else:
                        highlighted_solution.append(tokenizer.decode(segment))
    return acc
def process_line2(tokenizer, line):
    """Parse one JSONL record and walk its outputs.

    NOTE(review): appears to be an unfinished stub — the extracted values
    and the accumulator list are never used and the function returns None.
    """
    kept = []
    record = json.loads(line)
    for output in record['outputs']:
        v_scores = output['vscores']
        solution_tokens = output['tokens']
def load_vescores():
    """Evaluate error-localisation accuracy over a verifier-scored JSONL file.

    Reads every record of *file_path*, feeds each through ``process_line``
    via a worker pool, and prints the fraction of flagged segments that were
    genuinely wrong.
    """
    file_path = "eval_results/gsm8k/verifier/test/responses_v(mistral7b-ep2-n100-scahead-mse-lm-token)_g(llama2chatfinetuned).jsonl"
    model_dir = "/data/OVM-Mistral-7b/mistral7b-ep2-n100-scahead-mse-lm-token"
    tokenizer = load_tokenizer(model_dir)
    acc = []
    with open(file_path, 'r', encoding='utf-8') as fp:
        lines = fp.readlines()
    # Bug fix: the pool was created but never closed/joined (worker leak);
    # the context manager guarantees shutdown even on error. The original's
    # unused cpu_count() local was dropped; pool size 1 is kept to preserve
    # the original sequential behaviour.
    with multiprocessing.Pool(1) as pool:
        with tqdm.tqdm(total=len(lines)) as pbar:
            func = partial(process_line, tokenizer)
            for result in pool.imap(func, lines):
                acc.extend(result)
                pbar.update()
    # Bug fix: guard against ZeroDivisionError when nothing was scored.
    if acc:
        print(f"acc : {sum(acc)/len(acc)}")
    else:
        print("acc : n/a (no segments scored)")
# Guard the entry point: multiprocessing re-imports this module in worker
# processes on some platforms, so top-level work must not run on import.
if __name__ == '__main__':
    pass
    # load_vescores()
# Example usage
# strings = [
#     "In the fourth game, Clayton scored the average of his points from the first three games. This is 24+14+10 = <<24+14+10=40>>40 points."
#     "In the fourth game, Clayton scored the average of his points from the first three games. This is 24+14+10 = <<24+14+10=48>>",
#     "Another example where 2*5 = <<2*5=10>>"
# ]
#
# accuracy = evaluate_expression(strings)
# print(f"Accuracy: {accuracy}")