File size: 5,704 Bytes
da66274
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import re
# NOTE(review): 'test' is a project-local module providing delete_extra_zero,
# but the name shadows the stdlib 'test' package — consider renaming it.
from  test import delete_extra_zero
import transformers

# Fallback special tokens, registered when a tokenizer ships without them.
DEFAULT_PAD_TOKEN = "<pad>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_UNK_TOKEN = "<unk>"

import json
import tqdm
import multiprocessing
from functools import partial

def load_tokenizer(model_name_or_path):
    """Load a slow HF tokenizer with right-side padding and a usable pad token.

    Args:
        model_name_or_path: HF hub id or local checkpoint path passed to
            ``AutoTokenizer.from_pretrained``.

    Returns:
        A ``transformers`` tokenizer whose ``pad_token`` is guaranteed set.
    """
    print(f"+ [Model] Initializing Tokenizer: {model_name_or_path}")
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_name_or_path,
        padding_side="right",
        use_fast=False,
    )

    if 'phi' in model_name_or_path:
        # Phi checkpoints ship without a pad token; reuse <unk> so no
        # embedding resize is required.
        tokenizer.pad_token = tokenizer.unk_token
    elif tokenizer.pad_token is None:
        # Bug fix: the original registered eos/bos/unk here but never a pad
        # token, so tokenizer.pad_token stayed None despite this very check.
        # DEFAULT_PAD_TOKEN was defined at module level but unused.
        # NOTE(review): adding a brand-new pad token grows the vocab; the
        # model embeddings must be resized by the caller — confirm.
        tokenizer.add_special_tokens({
            "pad_token": DEFAULT_PAD_TOKEN,
            "eos_token": DEFAULT_EOS_TOKEN,
            "bos_token": DEFAULT_BOS_TOKEN,
            "unk_token": DEFAULT_UNK_TOKEN,
        })

    return tokenizer

def evaluate_expression(string):
    """Check the arithmetic annotated as ``<<lhs=rhs>>`` inside *string*.

    Returns 1 when the first ``<<...>>`` annotation's left-hand side
    evaluates (within 1e-3) to its right-hand side, otherwise 0 — including
    when no annotation is present or evaluation fails.
    """
    found = re.search(r'<<(.+?)>>', string)
    if found is None:
        return 0

    expression = found.group(1)
    try:
        # Split the annotation into the formula and its claimed result.
        lhs, rhs = expression.split('=')
        # SECURITY NOTE(review): eval() on model-generated text — tolerable
        # only because inputs are local evaluation files, never user input.
        computed = float(eval(lhs.strip()))
        # Normalize the claimed value (strip thousands separators / zeros).
        expected = float(delete_extra_zero(rhs.strip().replace(",", "")))
        return 1 if abs(computed - expected) <= 1e-3 else 0
    except Exception as e:
        print(f"Error evaluating expression: {expression}. Error: {e}")
        return 0



def process_line(tokenizer, line):
    """Score one JSONL record of verifier outputs.

    For each output whose 'label' is truthy, find the solution step where the
    verifier value-score drops the most, decode it, and record 1 if that
    step's <<...>> arithmetic is actually wrong (the drop flagged a genuine
    error), else 0. Steps without an equation are skipped.

    Cleanup vs. original: the duplicated commented-out code and the
    'highlighted_solution' list (built but never used or returned) are
    removed; the per-segment loop collapses to direct extraction of the
    flagged segment. The returned 'acc' values are unchanged.

    Args:
        tokenizer: object with a ``decode(token_ids)`` method.
        line: JSON string with an 'outputs' list; each output carries
            'vscores', 'tokens' and a boolean 'label'.

    Returns:
        List of 0/1 flags, one per flagged step containing an equation.
    """
    acc = []
    record = json.loads(line)
    for item in record['outputs']:
        if not item['label']:
            continue
        v_scores = item['vscores']
        solution_tokens = item['tokens']

        # Segment boundaries: split on newline token id 13, collapsing runs.
        split_token_id = 13
        split_indices = [0]
        split_indices.extend(
            i + 1 for i, token_id in enumerate(solution_tokens)
            if token_id == split_token_id and solution_tokens[i - 1] != split_token_id
        )
        split_indices.append(len(solution_tokens))

        # Verifier score sampled at the start of each segment (from the
        # second boundary on), and the change between consecutive segments.
        segment_v_scores = [v_scores[split_indices[i]] for i in range(1, len(split_indices))]
        score_changes = [segment_v_scores[i] - segment_v_scores[i - 1]
                         for i in range(1, len(segment_v_scores))]

        # Only proceed when some step made the verifier LESS confident.
        if not score_changes or min(score_changes) >= 0:
            continue

        # +1 because change i describes the transition into segment i+1.
        max_change_index = score_changes.index(min(score_changes)) + 1
        segment = solution_tokens[split_indices[max_change_index]:split_indices[max_change_index + 1]]
        # Drop the trailing newline token before decoding.
        detect_string = tokenizer.decode(segment[:-1])
        if not re.findall(r'<<([^>]+)>>', detect_string):
            # No equation in the flagged step — nothing to verify.
            continue
        # 1 when the flagged step's arithmetic is indeed wrong.
        acc.append(0 if evaluate_expression(detect_string) else 1)
    return acc

def process_line2(tokenizer, line):
    """Parse one JSONL record and walk its outputs.

    NOTE(review): appears to be an unfinished stub — it binds each output's
    'vscores'/'tokens' without using them, never touches 'tokenizer', and
    always returns None. Key accesses are kept so missing-key errors still
    surface exactly as before.
    """
    filter_list = []
    record = json.loads(line)
    for output in record['outputs']:
        v_scores = output['vscores']
        solution_tokens = output['tokens']


def load_vescores(
    file_path="eval_results/gsm8k/verifier/test/responses_v(mistral7b-ep2-n100-scahead-mse-lm-token)_g(llama2chatfinetuned).jsonl",
    model_dir="/data/OVM-Mistral-7b/mistral7b-ep2-n100-scahead-mse-lm-token",
    num_workers=1,
):
    """Run ``process_line`` over every record of a responses JSONL file and
    print the fraction of flagged error steps whose equations are wrong.

    Fixes vs. original: the pool is now closed via ``with`` (it leaked), the
    unused ``number = cpu_count()`` local is gone, an empty result no longer
    divides by zero, and the hard-coded paths became parameters with the old
    values as backward-compatible defaults.

    Args:
        file_path: JSONL file of verifier responses, one record per line.
        model_dir: tokenizer checkpoint directory.
        num_workers: worker processes; defaults to 1 (original behavior).
            Pass ``multiprocessing.cpu_count()`` to use every core.
    """
    tokenizer = load_tokenizer(model_dir)
    acc = []
    with open(file_path, 'r', encoding='utf-8') as fp:
        lines = fp.readlines()
    # Context manager guarantees the pool is terminated and joined.
    with multiprocessing.Pool(num_workers) as pool:
        with tqdm.tqdm(total=len(lines)) as pbar:
            func = partial(process_line, tokenizer)
            for result in pool.imap(func, lines):
                acc.extend(result)
                pbar.update()

    # Guard against division by zero when no step was flagged.
    if acc:
        print(f"acc : {sum(acc)/len(acc)}")
    else:
        print("acc : n/a (no flagged steps)")

# Script entry point — currently disabled on purpose (pass); uncomment
# load_vescores() to run the full evaluation pipeline.
if __name__ == '__main__':
    pass
    # load_vescores()
    # Example usage of evaluate_expression — note it takes ONE string,
    # not a list of strings:
    # s = ("In the fourth game, Clayton scored the average of his points "
    #      "from the first three games. This is 24+14+10 = <<24+14+10=48>>")
    # print(f"Accuracy: {evaluate_expression(s)}")