import asyncio
import logging

from typing import Any, Optional, Protocol, Iterable, Callable
from tqdm.auto import tqdm
from evaluate.evaluation_suite import EvaluationSuite
import evaluate
import numpy as np
import datasets
import pandas as pd
from .tasks import *
from .utils import *
from itertools import chain
from copy import deepcopy
from . import utils


class ReasoningMetric(evaluate.Metric):
    """TODO: Short description of my evaluation module."""

    def _info(self):
        # if self.config_name in ["cmmlu"]:
        features = datasets.Features(
            {
                "responses": datasets.Value("string"),
                # "responses": datasets.Sequence(datasets.Value("float")),
                "references": datasets.Value("string"),
            }
        )

        # TODO: Specifies the evaluate.EvaluationModuleInfo object
        return evaluate.EvaluationModuleInfo(
            # This is the description that will appear on the modules page.
            # module_type="measurement",
            description="",
            citation="",
            inputs_description="",
            # This defines the format of each prediction and reference
            features=features,
            # Homepage of the module for documentation
            homepage="http://module.homepage",
            # Additional links to the codebase or references
            codebase_urls=["http://github.com/path/to/codebase/of/new_module"],
            reference_urls=["http://path.to.reference.url/new_module"],
        )

    def _compute(self, responses, references):
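        """Dispatch to ``Metrics.<config_name>`` and normalise its return value.

        The metric function may return a ``(responses, references)`` pair of
        extracted answers (scored by exact match), a ready-made results dict,
        or a list of per-example scores (averaged); anything else raises
        ``NotImplementedError``.
        """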
        return_value = getattr(Metrics, self.config_name)(responses, references)
        match return_value:
            case extract_responses, extract_references:
                results = {
                    self.config_name: np.mean(
                        sync_pipe(lambda x, y: x == y)(
                            zip(extract_responses, extract_references)
                        )
                    )
                }
            case dict():
                results = return_value

            case list():
                results = {self.config_name: np.mean(return_value)}

            case _:
                raise NotImplementedError

        return results


class Suite(EvaluationSuite):
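    """EvaluationSuite bundling the benchmarks listed in ``supported_datasets``.

    Minimal usage sketch (``my_pipeline`` is a placeholder for whatever model or
    pipeline object your ``Task.run`` accepts):

        suite = Suite()
        suite.load("gsm8k")
        results = suite.run(my_pipeline)  # {"gsm8k": <mean score>}
    """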
    task_class = Task
    utils = utils
    supported_datasets = [
        "arc",
        "hellaswag",
        "mmlu-chat",
        "winogrande",
        "gsm8k",
        "cmmlu-chat",
        "ceval-chat",
        "bbh",
        "drop",
        "MATH",
    ]

    def __getitem__(self, key) -> Task:
        match key:
            case str():
                return self.suite[key]
            case slice() | int():
                return self.tasks[key]
            case _:
                raise KeyError(key)

    def agg(self, suite):
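        """Recursively replace each list of tasks with the mean of their per-task mean results."""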
        for cate, tasks in suite.items():
            if isinstance(tasks, dict):
                suite[cate] = self.agg(tasks)
            else:
                suite[cate] = np.mean([pd.Series(task.result).mean() for task in tasks])

        return suite

    def run(
        self,
        model_or_pipeline: Any,
    ) -> dict[str, float]:
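        """Run every registered task with ``model_or_pipeline`` and return the aggregated suite results."""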
        self.assert_suite_nonempty()

        self.suite: dict[str, list[Task]]
        for task in (bar := tqdm(self.tasks)):
            bar.desc = f"running {task.name}"
            _ = task.run(model_or_pipeline)
            logging.info(f"{task.name} {task.result=}")
        return self.agg(deepcopy(self.suite))

    def arun(self, model_or_pipeline):
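        """Async variant of ``run``: awaits every ``task.arun`` concurrently via ``tqdm.gather``."""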
        async def gather_tasks():
            return await tqdm.gather(
                *[task.arun(model_or_pipeline) for task in self.tasks], leave=False
            )

        asyncio.run(gather_tasks())

        return self.agg(deepcopy(self.suite))

    def get_suite(self, name) -> dict:
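        """Resolve a suite name (e.g. ``"mmlu-chat"``, ``"gsm8k"``, ``"tlem"``) into a
        ``{name: tasks}`` mapping, recursing into composite suites.
        """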
        chat = "chat" in name
        suite = {}
        match name:
            case _ if name.startswith("mmlu"):
                suite = MMLU.suite(chat=chat)
            case _ if name.startswith("cmmlu"):
                suite = CMMLU.suite(chat=chat)
            case _ if name.startswith("ceval"):
                suite = CEVAL.suite(chat=chat)
            case "gsm8k":
                suite = Task(
                    dataset_name=("gsm8k", "main"),
                    metric_name=("sustech/tlem", "gsm8k"),
                    input_column="question",
                    label_column="answer",
                )
            case "bbh":
                suite = BBH.suite()
            case "arc":
                suite = ARC.suite()
            case "hellaswag":
                suite = HellaSwag.suite()
            case "drop":
                suite = DROP.suite()
            case "winogrande":
                suite = Winogrande.suite()

            case "mt_bench":
                suite = Task(
                    dataset_name="SUSTech/mt_bench_judge",
                    split="train",
                    prompt=mt_bench_prompt
                    # metric_name=("sustech/tlem", "gsm8k"),
                )
            case "MATH" | "competition_math":
                suite = Task(
                    dataset_name="hendrycks/competition_math",
                    prompt="This is a math problem, please think step by step and slove it: {input_column}. Simplify your final answer as much as possible and surround them with '$' in TeX form.",
                    metric_name=("sustech/tlem", "MATH"),
                    input_column="problem",
                    label_column="solution",
                )

            case "open-leaderboard":
                for name in [
                    "arc",
                    "hellaswag",
                    "mmlu-chat",
                    "winogrande",
                    "gsm8k",
                    # "truthful_qa",
                    "drop",
                ]:
                    suite.update(self.get_suite(name))
            case "tlem":
                for name in [
                    "arc",
                    "hellaswag",
                    "mmlu-chat",
                    "winogrande",
                    "gsm8k",
                    # "truthful_qa",
                    "cmmlu-chat",
                    "ceval-chat",
                    "bbh",
                ]:
                    suite.update(self.get_suite(name))

            case "all":
                for name in self.supported_datasets:
                    suite.update(self.get_suite(name))
            case _:
                raise NotImplementedError(
                    f"{name} is not supported in {self.supported_datasets}"
                )

        if isinstance(suite, Task):
            suite = [suite]
        suite = {name: suite}

        return suite

    def singleton(self, task):
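        """Return the canonical instance of ``task`` from ``self.tasks``, appending it first if it is new."""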
        try:
            return self.tasks[self.tasks.index(task)]
        except ValueError:
            logging.debug(f"add {task.name} to suite.")
            self.tasks.append(task)
            logging.debug(self.tasks)
            return self.tasks[-1]

    def drop_duplicates(self, suite):
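        """Recursively replace every task in ``suite`` with its canonical instance from ``self.tasks``."""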
        for category, tasks in suite.items():
            match tasks:
                case list():
                    suite[category] = [self.singleton(task) for task in tasks]
                case dict():
                    suite[category] = self.drop_duplicates(tasks)
                case _:
                    raise NotImplementedError
        return suite

    def load(self, name):
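        """Merge the suite resolved from ``name`` into ``self.suite`` and deduplicate its tasks."""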
        sub_suite = self.get_suite(name)
        self.suite.update(sub_suite)
        self.suite = self.drop_duplicates(self.suite)
        # return self

    def __init__(self, name="tlem"):
        super().__init__(name)
        self.tasks = []
        self.suite = {}