Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes.
- scripts/yans/lm-evaluation-harness/lm_eval/__init__.py +1 -0
- scripts/yans/lm-evaluation-harness/lm_eval/__main__.py +461 -0
- scripts/yans/lm-evaluation-harness/lm_eval/evaluator.py +649 -0
- scripts/yans/lm-evaluation-harness/lm_eval/evaluator_utils.py +542 -0
- scripts/yans/lm-evaluation-harness/lm_eval/filters/__init__.py +25 -0
- scripts/yans/lm-evaluation-harness/lm_eval/filters/__pycache__/__init__.cpython-310.pyc +0 -0
- scripts/yans/lm-evaluation-harness/lm_eval/filters/__pycache__/extraction.cpython-310.pyc +0 -0
- scripts/yans/lm-evaluation-harness/lm_eval/filters/__pycache__/selection.cpython-310.pyc +0 -0
- scripts/yans/lm-evaluation-harness/lm_eval/filters/__pycache__/transformation.cpython-310.pyc +0 -0
- scripts/yans/lm-evaluation-harness/lm_eval/filters/decontamination.py +25 -0
- scripts/yans/lm-evaluation-harness/lm_eval/filters/extraction.py +184 -0
- scripts/yans/lm-evaluation-harness/lm_eval/filters/selection.py +61 -0
- scripts/yans/lm-evaluation-harness/lm_eval/filters/transformation.py +56 -0
- scripts/yans/lm-evaluation-harness/lm_eval/prompts/__init__.py +126 -0
- scripts/yans/lm-evaluation-harness/lm_eval/prompts/__pycache__/__init__.cpython-310.pyc +0 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/README.md +119 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/__init__.py +650 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/anli/README.md +56 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/anli/anli_r1.yaml +26 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/anli/anli_r2.yaml +5 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/anli/anli_r3.yaml +5 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/drop/README.md +53 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/drop/default.yaml +26 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/drop/utils.py +205 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/hendrycks_math/README.md +54 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/hendrycks_math/hendrycks_math.yaml +15 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/hendrycks_math/hendrycks_math_algebra.yaml +25 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/hendrycks_math/hendrycks_math_counting_and_prob.yaml +3 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/hendrycks_math/hendrycks_math_geometry.yaml +3 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/hendrycks_math/hendrycks_math_intermediate_algebra.yaml +3 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/hendrycks_math/hendrycks_math_num_theory.yaml +3 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/hendrycks_math/hendrycks_math_prealgebra.yaml +3 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/hendrycks_math/hendrycks_math_precalc.yaml +3 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/hendrycks_math/utils.py +231 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/siqa/README.md +37 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/siqa/siqa.yaml +16 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/squadv2/README.md +54 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/squadv2/squadv2.yaml +2 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/squadv2/task.py +241 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/xcopa/README.md +60 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/xcopa/_xcopa.yaml +19 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/xcopa/default_et.yaml +13 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/xcopa/default_ht.yaml +4 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/xcopa/default_id.yaml +4 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/xcopa/default_it.yaml +4 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/xcopa/default_qu.yaml +4 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/xcopa/default_sw.yaml +4 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/xcopa/default_ta.yaml +4 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/xcopa/default_th.yaml +4 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/xcopa/default_tr.yaml +4 -0
scripts/yans/lm-evaluation-harness/lm_eval/__init__.py
ADDED
@@ -0,0 +1 @@
from .evaluator import evaluate, simple_evaluate
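
The re-export above means the harness can be driven from Python as well as from the CLI. A minimal sketch of calling `simple_evaluate` (defined in `evaluator.py` later in this diff) is shown below; the model backend name, pretrained checkpoint, and task name are illustrative assumptions taken from the help strings and task YAMLs in this commit, not a run verified against this checkout.

    from lm_eval import simple_evaluate

    # Hypothetical invocation; argument values are assumptions for illustration.
    results = simple_evaluate(
        model="hf",                                      # registry name of the model backend
        model_args="pretrained=EleutherAI/pythia-160m",  # comma-separated backend arguments
        tasks=["anli_r1"],                               # one of the task YAMLs added in this commit
        num_fewshot=0,
        batch_size=1,
    )
    if results is not None:  # non-zero ranks return None under distributed evaluation
        print(results["results"])
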
scripts/yans/lm-evaluation-harness/lm_eval/__main__.py
ADDED
@@ -0,0 +1,461 @@
import argparse
import json
import logging
import os
import sys
from functools import partial
from typing import Union

from lm_eval import evaluator, utils
from lm_eval.evaluator import request_caching_arg_to_dict
from lm_eval.loggers import EvaluationTracker, WandbLogger
from lm_eval.tasks import TaskManager
from lm_eval.utils import handle_non_serializable, make_table, simple_parse_args_string


def _int_or_none_list_arg_type(
    min_len: int, max_len: int, defaults: str, value: str, split_char: str = ","
):
    def parse_value(item):
        item = item.strip().lower()
        if item == "none":
            return None
        try:
            return int(item)
        except ValueError:
            raise argparse.ArgumentTypeError(f"{item} is not an integer or None")

    items = [parse_value(v) for v in value.split(split_char)]
    num_items = len(items)

    if num_items == 1:
        # Makes downstream handling the same for single and multiple values
        items = items * max_len
    elif num_items < min_len or num_items > max_len:
        raise argparse.ArgumentTypeError(
            f"Argument requires {max_len} integers or None, separated by '{split_char}'"
        )
    elif num_items != max_len:
        logging.warning(
            f"Argument requires {max_len} integers or None, separated by '{split_char}'. "
            "Missing values will be filled with defaults."
        )
        default_items = [parse_value(v) for v in defaults.split(split_char)]
        items.extend(
            default_items[num_items:]
        )  # extend items list with missing defaults

    return items


def check_argument_types(parser: argparse.ArgumentParser):
    """
    Check to make sure all CLI args are typed, raises error if not
    """
    for action in parser._actions:
        if action.dest != "help" and not action.const:
            if action.type is None:
                raise ValueError(
                    f"Argument '{action.dest}' doesn't have a type specified."
                )
            else:
                continue


def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument(
        "--model", "-m", type=str, default="hf", help="Name of model e.g. `hf`"
    )
    parser.add_argument(
        "--tasks",
        "-t",
        default=None,
        type=str,
        metavar="task1,task2",
        help="Comma-separated list of task names or task groupings to evaluate on.\nTo get full list of tasks, use one of the commands `lm-eval --tasks {{list_groups,list_subtasks,list_tags,list}}` to list out all available names for task groupings; only (sub)tasks; tags; or all of the above",
    )
    parser.add_argument(
        "--model_args",
        "-a",
        default="",
        type=str,
        help="Comma separated string arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32`",
    )
    parser.add_argument(
        "--num_fewshot",
        "-f",
        type=int,
        default=None,
        metavar="N",
        help="Number of examples in few-shot context",
    )
    parser.add_argument(
        "--batch_size",
        "-b",
        type=str,
        default=1,
        metavar="auto|auto:N|N",
        help="Acceptable values are 'auto', 'auto:N' or N, where N is an integer. Default 1.",
    )
    parser.add_argument(
        "--max_batch_size",
        type=int,
        default=None,
        metavar="N",
        help="Maximal batch size to try with --batch_size auto.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to use (e.g. cuda, cuda:0, cpu).",
    )
    parser.add_argument(
        "--output_path",
        "-o",
        default=None,
        type=str,
        metavar="DIR|DIR/file.json",
        help="The path to the output file where the result metrics will be saved. If the path is a directory and log_samples is true, the results will be saved in the directory. Else the parent directory will be used.",
    )
    parser.add_argument(
        "--limit",
        "-L",
        type=float,
        default=None,
        metavar="N|0<N<1",
        help="Limit the number of examples per task. "
        "If <1, limit is a percentage of the total number of examples.",
    )
    parser.add_argument(
        "--use_cache",
        "-c",
        type=str,
        default=None,
        metavar="DIR",
        help="A path to a sqlite db file for caching model responses. `None` if not caching.",
    )
    parser.add_argument(
        "--cache_requests",
        type=str,
        default=None,
        choices=["true", "refresh", "delete"],
        help="Speed up evaluation by caching the building of dataset requests. `None` if not caching.",
    )
    parser.add_argument(
        "--check_integrity",
        action="store_true",
        help="Whether to run the relevant part of the test suite for the tasks.",
    )
    parser.add_argument(
        "--write_out",
        "-w",
        action="store_true",
        default=False,
        help="Prints the prompt for the first few documents.",
    )
    parser.add_argument(
        "--log_samples",
        "-s",
        action="store_true",
        default=False,
        help="If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis. Use with --output_path.",
    )
    parser.add_argument(
        "--system_instruction",
        type=str,
        default=None,
        help="System instruction to be used in the prompt",
    )
    parser.add_argument(
        "--apply_chat_template",
        action="store_true",
        default=False,
        help="If True, applies the chat template to the prompt",
    )
    parser.add_argument(
        "--fewshot_as_multiturn",
        action="store_true",
        default=False,
        help="If True, uses the fewshot as a multi-turn conversation",
    )
    parser.add_argument(
        "--show_config",
        action="store_true",
        default=False,
        help="If True, shows the full config of all tasks at the end of the evaluation.",
    )
    parser.add_argument(
        "--include_path",
        type=str,
        default=None,
        metavar="DIR",
        help="Additional path to include if there are external tasks to include.",
    )
    parser.add_argument(
        "--gen_kwargs",
        type=str,
        default=None,
        help=(
            "String arguments for model generation on greedy_until tasks,"
            " e.g. `temperature=0,top_k=0,top_p=0`."
        ),
    )
    parser.add_argument(
        "--verbosity",
        "-v",
        type=str.upper,
        default="INFO",
        metavar="CRITICAL|ERROR|WARNING|INFO|DEBUG",
        help="Controls the reported logging error level. Set to DEBUG when testing + adding new task configurations for comprehensive log output.",
    )
    parser.add_argument(
        "--wandb_args",
        type=str,
        default="",
        help="Comma separated string arguments passed to wandb.init, e.g. `project=lm-eval,job_type=eval`",
    )
    parser.add_argument(
        "--hf_hub_log_args",
        type=str,
        default="",
        help="Comma separated string arguments passed to Hugging Face Hub's log function, e.g. `hub_results_org=EleutherAI,hub_repo_name=lm-eval-results`",
    )
    parser.add_argument(
        "--predict_only",
        "-x",
        action="store_true",
        default=False,
        help="Use with --log_samples. Only model outputs will be saved and metrics will not be evaluated.",
    )
    default_seed_string = "0,1234,1234,1234"
    parser.add_argument(
        "--seed",
        type=partial(_int_or_none_list_arg_type, 3, 4, default_seed_string),
        default=default_seed_string,  # for backward compatibility
        help=(
            "Set seed for python's random, numpy, torch, and fewshot sampling.\n"
            "Accepts a comma-separated list of 4 values for python's random, numpy, torch, and fewshot sampling seeds, "
            "respectively, or a single integer to set the same seed for all four.\n"
            f"The values are either an integer or 'None' to not set the seed. Default is `{default_seed_string}` "
            "(for backward compatibility).\n"
            "E.g. `--seed 0,None,8,52` sets `random.seed(0)`, `torch.manual_seed(8)`, and fewshot sampling seed to 52. "
            "Here numpy's seed is not set since the second value is `None`.\n"
            "E.g, `--seed 42` sets all four seeds to 42."
        ),
    )
    parser.add_argument(
        "--trust_remote_code",
        action="store_true",
        help="Sets trust_remote_code to True to execute code to create HF Datasets from the Hub",
    )
    return parser


def parse_eval_args(parser: argparse.ArgumentParser) -> argparse.Namespace:
    check_argument_types(parser)
    return parser.parse_args()


def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
    if not args:
        # we allow for args to be passed externally, else we parse them ourselves
        parser = setup_parser()
        args = parse_eval_args(parser)

    if args.wandb_args:
        wandb_logger = WandbLogger(**simple_parse_args_string(args.wandb_args))

    eval_logger = utils.eval_logger
    eval_logger.setLevel(getattr(logging, f"{args.verbosity}"))
    eval_logger.info(f"Verbosity set to {args.verbosity}")
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # update the evaluation tracker args with the output path and the HF token
    if args.output_path:
        args.hf_hub_log_args += f",output_path={args.output_path}"
    if os.environ.get("HF_TOKEN", None):
        args.hf_hub_log_args += f",token={os.environ.get('HF_TOKEN')}"
    evaluation_tracker_args = simple_parse_args_string(args.hf_hub_log_args)
    evaluation_tracker = EvaluationTracker(**evaluation_tracker_args)

    if args.predict_only:
        args.log_samples = True
    if (args.log_samples or args.predict_only) and not args.output_path:
        raise ValueError(
            "Specify --output_path if providing --log_samples or --predict_only"
        )

    if args.fewshot_as_multiturn and args.apply_chat_template is False:
        raise ValueError(
            "If fewshot_as_multiturn is set, apply_chat_template must be set to True."
        )

    if (
        args.num_fewshot is None or args.num_fewshot == 0
    ) and args.fewshot_as_multiturn:
        raise ValueError(
            "If fewshot_as_multiturn is set, num_fewshot must be greater than 0."
        )

    if args.include_path is not None:
        eval_logger.info(f"Including path: {args.include_path}")
    task_manager = TaskManager(args.verbosity, include_path=args.include_path)

    if "push_samples_to_hub" in evaluation_tracker_args and not args.log_samples:
        eval_logger.warning(
            "Pushing samples to the Hub requires --log_samples to be set. Samples will not be pushed to the Hub."
        )

    if args.limit:
        eval_logger.warning(
            " --limit SHOULD ONLY BE USED FOR TESTING."
            "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
        )

    if args.tasks is None:
        eval_logger.error("Need to specify task to evaluate.")
        sys.exit()
    elif args.tasks == "list":
        print(task_manager.list_all_tasks())
        sys.exit()
    elif args.tasks == "list_groups":
        print(task_manager.list_all_tasks(list_subtasks=False, list_tags=False))
        sys.exit()
    elif args.tasks == "list_tags":
        print(task_manager.list_all_tasks(list_groups=False, list_subtasks=False))
        sys.exit()
    elif args.tasks == "list_subtasks":
        print(task_manager.list_all_tasks(list_groups=False, list_tags=False))
        sys.exit()
    else:
        if os.path.isdir(args.tasks):
            import glob

            task_names = []
            yaml_path = os.path.join(args.tasks, "*.yaml")
            for yaml_file in glob.glob(yaml_path):
                config = utils.load_yaml_config(yaml_file)
                task_names.append(config)
        else:
            task_list = args.tasks.split(",")
            task_names = task_manager.match_tasks(task_list)
            for task in [task for task in task_list if task not in task_names]:
                if os.path.isfile(task):
                    config = utils.load_yaml_config(task)
                    task_names.append(config)
            task_missing = [
                task for task in task_list if task not in task_names and "*" not in task
            ]  # we don't want errors if a wildcard ("*") task name was used

            if task_missing:
                missing = ", ".join(task_missing)
                eval_logger.error(
                    f"Tasks were not found: {missing}\n"
                    f"{utils.SPACING}Try `lm-eval --tasks list` for list of available tasks",
                )
                raise ValueError(
                    f"Tasks not found: {missing}. Try `lm-eval --tasks {{list_groups,list_subtasks,list_tags,list}}` to list out all available names for task groupings; only (sub)tasks; tags; or all of the above, or pass '--verbosity DEBUG' to troubleshoot task registration issues."
                )

    # Respect user's value passed in via CLI, otherwise default to True and add to comma-separated model args
    if args.trust_remote_code:
        eval_logger.info(
            "Passed `--trust_remote_code`, setting environment variable `HF_DATASETS_TRUST_REMOTE_CODE=true`"
        )
        # HACK: import datasets and override its HF_DATASETS_TRUST_REMOTE_CODE value internally,
        # because it's already been determined based on the prior env var before launching our
        # script--`datasets` gets imported by lm_eval internally before these lines can update the env.
        import datasets

        datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True

        args.model_args = args.model_args + ",trust_remote_code=True"

    eval_logger.info(f"Selected Tasks: {task_names}")

    request_caching_args = request_caching_arg_to_dict(
        cache_requests=args.cache_requests
    )

    results = evaluator.simple_evaluate(
        model=args.model,
        model_args=args.model_args,
        tasks=task_names,
        num_fewshot=args.num_fewshot,
        batch_size=args.batch_size,
        max_batch_size=args.max_batch_size,
        device=args.device,
        use_cache=args.use_cache,
        limit=args.limit,
        check_integrity=args.check_integrity,
        write_out=args.write_out,
        log_samples=args.log_samples,
        evaluation_tracker=evaluation_tracker,
        system_instruction=args.system_instruction,
        apply_chat_template=args.apply_chat_template,
        fewshot_as_multiturn=args.fewshot_as_multiturn,
        gen_kwargs=args.gen_kwargs,
        task_manager=task_manager,
        verbosity=args.verbosity,
        predict_only=args.predict_only,
        random_seed=args.seed[0],
        numpy_random_seed=args.seed[1],
        torch_random_seed=args.seed[2],
        fewshot_random_seed=args.seed[3],
        **request_caching_args,
    )

    if results is not None:
        if args.log_samples:
            samples = results.pop("samples")
        dumped = json.dumps(
            results, indent=2, default=handle_non_serializable, ensure_ascii=False
        )
        if args.show_config:
            print(dumped)

        batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))

        # Add W&B logging
        if args.wandb_args:
            try:
                wandb_logger.post_init(results)
                wandb_logger.log_eval_result()
                if args.log_samples:
                    wandb_logger.log_eval_samples(samples)
            except Exception as e:
                eval_logger.info(f"Logging to Weights and Biases failed due to {e}")

        evaluation_tracker.save_results_aggregated(
            results=results, samples=samples if args.log_samples else None
        )

        if args.log_samples:
            for task_name, config in results["configs"].items():
                evaluation_tracker.save_results_samples(
                    task_name=task_name, samples=samples[task_name]
                )

        if (
            evaluation_tracker.push_results_to_hub
            or evaluation_tracker.push_samples_to_hub
        ):
            evaluation_tracker.recreate_metadata_card()

        print(
            f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
            f"batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
        )
        print(make_table(results))
        if "groups" in results:
            print(make_table(results, "groups"))

        if args.wandb_args:
            # Tear down wandb run once all the logging is done.
            wandb_logger.run.finish()


if __name__ == "__main__":
    cli_evaluate()
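
For reference, a hedged sketch of driving this CLI programmatically through `setup_parser` and `cli_evaluate` defined above; the flag values (model, checkpoint, task names, output directory) are illustrative assumptions based on the help strings and task YAMLs in this commit, not a command run against this checkout.

    from lm_eval.__main__ import cli_evaluate, setup_parser

    # Hypothetical argument list; an installed `lm-eval` console script would accept
    # the same flags directly on the command line.
    parser = setup_parser()
    args = parser.parse_args([
        "--model", "hf",
        "--model_args", "pretrained=EleutherAI/pythia-160m,dtype=float32",
        "--tasks", "anli_r1,siqa",
        "--num_fewshot", "0",
        "--batch_size", "auto",
        "--output_path", "results/",
        "--log_samples",
    ])
    cli_evaluate(args)
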
scripts/yans/lm-evaluation-harness/lm_eval/evaluator.py
ADDED
@@ -0,0 +1,649 @@
import itertools
import json
import logging
import random
import time
from collections import defaultdict
from typing import TYPE_CHECKING, List, Optional, Union

import numpy as np
import torch

import lm_eval.api.metrics
import lm_eval.api.registry
import lm_eval.api.task
import lm_eval.models
from lm_eval.caching.cache import delete_cache
from lm_eval.evaluator_utils import (
    consolidate_group_results,
    consolidate_results,
    get_sample_size,
    get_subtask_list,
    get_task_list,
    prepare_print_tasks,
    print_writeout,
    run_task_tests,
)
from lm_eval.loggers import EvaluationTracker
from lm_eval.loggers.utils import add_env_info, add_tokenizer_info, get_git_commit_hash
from lm_eval.tasks import (
    TaskManager,
    get_task_dict,
)
from lm_eval.utils import (
    eval_logger,
    handle_non_serializable,
    hash_string,
    positional_deprecated,
    simple_parse_args_string,
)


if TYPE_CHECKING:
    from lm_eval.api.model import LM
    from lm_eval.api.task import Task


@positional_deprecated
def simple_evaluate(
    model,
    model_args: Optional[Union[str, dict]] = None,
    tasks: Optional[List[Union[str, dict, object]]] = None,
    num_fewshot: Optional[int] = None,
    batch_size: Optional[Union[int, str]] = None,
    max_batch_size: Optional[int] = None,
    device: Optional[str] = None,
    use_cache: Optional[str] = None,
    cache_requests: bool = False,
    rewrite_requests_cache: bool = False,
    delete_requests_cache: bool = False,
    limit: Optional[Union[int, float]] = None,
    bootstrap_iters: int = 100000,
    check_integrity: bool = False,
    write_out: bool = False,
    log_samples: bool = True,
    evaluation_tracker: Optional[EvaluationTracker] = None,
    system_instruction: Optional[str] = None,
    apply_chat_template: bool = False,
    fewshot_as_multiturn: bool = False,
    gen_kwargs: Optional[str] = None,
    task_manager: Optional[TaskManager] = None,
    verbosity: str = "INFO",
    predict_only: bool = False,
    random_seed: int = 0,
    numpy_random_seed: int = 1234,
    torch_random_seed: int = 1234,
    fewshot_random_seed: int = 1234,
):
    """Instantiate and evaluate a model on a list of tasks.

    :param model: Union[str, LM]
        Name of model or LM object, see lm_eval.models.get_model
    :param model_args: Optional[str, dict]
        String or dict arguments for each model class, see LM.create_from_arg_string and LM.create_from_arg_object.
        Ignored if `model` argument is a LM object.
    :param tasks: list[Union[str, dict, Task]]
        List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
    :param num_fewshot: int
        Number of examples in few-shot context
    :param batch_size: int or str, optional
        Batch size for model
    :param max_batch_size: int, optional
        Maximal batch size to try with automatic batch size detection
    :param device: str, optional
        PyTorch device (e.g. "cpu" or "cuda:0") for running models
    :param use_cache: str, optional
        A path to a sqlite db file for caching model responses. `None` if not caching.
    :param cache_requests: bool, optional
        Speed up evaluation by caching the building of dataset requests. `None` if not caching.
    :param rewrite_requests_cache: bool, optional
        Rewrites all of the request cache if set to `True`. `None` if not desired.
    :param delete_requests_cache: bool, optional
        Deletes all of the request cache if set to `True`. `None` if not desired.
    :param limit: int or float, optional
        Limit the number of examples per task (only use this for testing), If <1, limit is a percentage of the total number of examples.
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics, used when calculating stderrs. set to 0 for no stderr calculations to be performed.
    :param check_integrity: bool
        Whether to run the relevant part of the test suite for the tasks
    :param write_out: bool
        If True, write out an example document and model input for checking task integrity
    :param log_samples: bool
        If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
    :param system_instruction: str
        System instruction to be applied to the prompt
    :param apply_chat_template: bool
        If True, apply chat template to the prompt
    :param fewshot_as_multiturn: bool
        Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
    :param gen_kwargs: str
        String arguments for model generation
        Ignored for all tasks with loglikelihood output_type
    :param predict_only: bool
        If true only model outputs will be generated and returned. Metrics will not be evaluated
    :param random_seed: int
        Random seed for python's random module. If set to None, the seed will not be set.
    :param numpy_random_seed: int
        Random seed for numpy. If set to None, the seed will not be set.
    :param torch_random_seed: int
        Random seed for torch. If set to None, the seed will not be set.
    :param fewshot_random_seed: int
        Random seed for fewshot sampler random generator. If set to None, the seed of generator will be set to None.

    :return
        Dictionary of results
    """
    eval_logger.setLevel(getattr(logging, f"{verbosity}"))
    start_date = time.time()

    if delete_requests_cache:
        eval_logger.info("Deleting requests cache...")
        delete_cache()

    seed_message = []
    if random_seed is not None:
        # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1412
        seed_message.append(f"Setting random seed to {random_seed}")
        random.seed(random_seed)

    if numpy_random_seed is not None:
        seed_message.append(f"Setting numpy seed to {numpy_random_seed}")
        np.random.seed(numpy_random_seed)

    if torch_random_seed is not None:
        seed_message.append(f"Setting torch manual seed to {torch_random_seed}")
        torch.manual_seed(torch_random_seed)

    if seed_message:
        eval_logger.info(" | ".join(seed_message))

    if tasks is None:
        tasks = []
    if len(tasks) == 0:
        raise ValueError(
            "No tasks specified, or no tasks found. Please verify the task names."
        )

    if gen_kwargs is not None:
        gen_kwargs = simple_parse_args_string(gen_kwargs)
        eval_logger.warning(
            "generation_kwargs specified through cli, these settings will update set parameters in yaml tasks. "
            "Ensure 'do_sample=True' for non-greedy decoding!"
        )
        if gen_kwargs == "":
            gen_kwargs = None

    if isinstance(model, str):
        if model_args is None:
            eval_logger.warning("model_args not specified. Using defaults.")
            model_args = ""

        if isinstance(model_args, dict):
            eval_logger.info(
                f"Initializing {model} model, with arguments: {model_args}"
            )
            lm = lm_eval.api.registry.get_model(model).create_from_arg_obj(
                model_args,
                {
                    "batch_size": batch_size,
                    "max_batch_size": max_batch_size,
                    "device": device,
                },
            )

        else:
            eval_logger.info(
                f"Initializing {model} model, with arguments: {simple_parse_args_string(model_args)}"
            )
            lm = lm_eval.api.registry.get_model(model).create_from_arg_string(
                model_args,
                {
                    "batch_size": batch_size,
                    "max_batch_size": max_batch_size,
                    "device": device,
                },
            )
    else:
        if not isinstance(model, lm_eval.api.model.LM):
            raise TypeError
        eval_logger.info("Using pre-initialized model")
        lm = model

    if use_cache is not None:
        eval_logger.info(f"Using cache at {use_cache + '_rank' + str(lm.rank) + '.db'}")
        lm = lm_eval.api.model.CachingLM(
            lm,
            use_cache
            # each rank receives a different cache db.
            # necessary to avoid multiple writes to cache at once
            + "_rank"
            + str(lm.rank)
            + ".db",
        )

    if task_manager is None:
        task_manager = TaskManager(verbosity)

    task_dict = get_task_dict(tasks, task_manager)

    # helper function to recursively apply config overrides to leaf subtasks, skipping their constituent groups.
    # (setting of num_fewshot ; bypassing metric calculation ; setting fewshot seed)
    def _adjust_config(task_dict):
        adjusted_task_dict = {}
        for task_name, task_obj in task_dict.items():
            if isinstance(task_obj, dict):
                adjusted_task_dict = {
                    **adjusted_task_dict,
                    **{task_name: _adjust_config(task_obj)},
                }

            else:
                if task_obj.get_config("output_type") == "generate_until":
                    if gen_kwargs is not None:
                        task_obj.set_config(
                            key="generation_kwargs", value=gen_kwargs, update=True
                        )

                if predict_only:
                    eval_logger.info(
                        f"Processing {task_name} in output-only mode. Metrics will not be calculated!"
                    )
                    # we have to change the class properties post-hoc. This is pretty hacky.
                    task_obj.override_metric(metric_name="bypass")

                # override tasks' fewshot values to the provided num_fewshot arg value
                # except if tasks have it set to 0 manually in their configs--then we should never overwrite that
                if num_fewshot is not None:
                    if (default_num_fewshot := task_obj.get_config("num_fewshot")) == 0:
                        eval_logger.info(
                            f"num_fewshot has been set to 0 for {task_name} in its config. Manual configuration will be ignored."
                        )
                    else:
                        eval_logger.warning(
                            f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
                        )
                        task_obj.set_config(key="num_fewshot", value=num_fewshot)
                else:
                    # if num_fewshot not provided, and the task does not define a default one, default to 0
                    if (
                        default_num_fewshot := task_obj.get_config("num_fewshot")
                    ) is None:
                        task_obj.set_config(key="num_fewshot", value=0)
                # fewshot_random_seed set for tasks, even with a default num_fewshot (e.g. in the YAML file)
                task_obj.set_fewshot_seed(seed=fewshot_random_seed)
                eval_logger.info(
                    f"Setting fewshot random generator seed to {fewshot_random_seed}"
                )

                adjusted_task_dict[task_name] = task_obj

        return adjusted_task_dict

    task_dict = _adjust_config(task_dict)

    if check_integrity:
        run_task_tests(task_list=tasks)

    if evaluation_tracker is not None:
        evaluation_tracker.general_config_tracker.log_experiment_args(
            model_source=model,
            model_args=model_args,
            system_instruction=system_instruction,
            chat_template=lm.chat_template if apply_chat_template else None,
            fewshot_as_multiturn=fewshot_as_multiturn,
        )

    results = evaluate(
        lm=lm,
        task_dict=task_dict,
        limit=limit,
        cache_requests=cache_requests,
        rewrite_requests_cache=rewrite_requests_cache,
        bootstrap_iters=bootstrap_iters,
        write_out=write_out,
        log_samples=True if predict_only else log_samples,
        system_instruction=system_instruction,
        apply_chat_template=apply_chat_template,
        fewshot_as_multiturn=fewshot_as_multiturn,
        verbosity=verbosity,
    )

    if lm.rank == 0:
        if isinstance(model, str):
            model_name = model
        elif hasattr(model, "config") and hasattr(model.config, "_name_or_path"):
            model_name = model.config._name_or_path
        else:
            model_name = type(model).__name__

        # add info about the model and few shot config
        results["config"] = {
            "model": model_name,
            "model_args": model_args,
        }
        # add more detailed model info if available
        if isinstance(lm, lm_eval.models.huggingface.HFLM):
            results["config"].update(lm.get_model_info())
        # add info about execution
        results["config"].update(
            {
                "batch_size": batch_size,
                "batch_sizes": (
                    list(lm.batch_sizes.values()) if hasattr(lm, "batch_sizes") else []
                ),
                "device": device,
                "use_cache": use_cache,
                "limit": limit,
                "bootstrap_iters": bootstrap_iters,
                "gen_kwargs": gen_kwargs,
                "random_seed": random_seed,
                "numpy_seed": numpy_random_seed,
                "torch_seed": torch_random_seed,
                "fewshot_seed": fewshot_random_seed,
            }
        )
        results["git_hash"] = get_git_commit_hash()
        results["date"] = start_date
        add_env_info(results)  # additional environment info to results
        add_tokenizer_info(results, lm)  # additional info about tokenizer
        return results
    else:
        return None


@positional_deprecated
def evaluate(
    lm: "LM",
    task_dict,
    limit: Optional[int] = None,
    cache_requests: bool = False,
    rewrite_requests_cache: bool = False,
    bootstrap_iters: Optional[int] = 100000,
    write_out: bool = False,
    log_samples: bool = True,
    system_instruction: Optional[str] = None,
    apply_chat_template: bool = False,
    fewshot_as_multiturn: bool = False,
    verbosity: str = "INFO",
):
    """Instantiate and evaluate a model on a list of tasks.

    :param lm: obj
        Language Model
    :param task_dict: dict[str, Task]
        Dictionary of tasks. Tasks will be taken to have name type(task).config.task .
    :param limit: int, optional
        Limit the number of examples per task (only use this for testing)
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics, used when calculating stderr. Set to 0 for skipping all stderr calculations.
    :param write_out: bool
        If True, write out an example document and model input for checking task integrity
    :param log_samples: bool
        If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
    :param system_instruction: str
        System instruction to be applied to the prompt
    :param apply_chat_template: bool
        If True, apply chat template to the prompt
    :param fewshot_as_multiturn: bool
        Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
    :return
        Dictionary of results
    """

    eval_logger.setLevel(getattr(logging, f"{verbosity}"))

    # tracks all Instances/requests a model must generate output on.
    requests = defaultdict(list)
    # stores the amount to pad out reqs per req. type so that
    # number of fwd passes per distributed rank is equal
    padding_requests = defaultdict(int)

    # get lists of group hierarchy and each type of request
    eval_tasks = get_task_list(task_dict)
    if not log_samples:
        if not all(
            "bypass" not in getattr(task_output.task, "_metric_fn_list", {}).keys()
            for task_output in eval_tasks
        ):
            raise ValueError("log_samples must be True for 'bypass' metric-only tasks")
    for task_output in eval_tasks:
        task: Task = task_output.task
        limit = get_sample_size(task, limit)
        task.build_all_requests(
            limit=limit,
            rank=lm.rank,
            world_size=lm.world_size,
            cache_requests=cache_requests,
            rewrite_requests_cache=rewrite_requests_cache,
            system_instruction=system_instruction,
            apply_chat_template=apply_chat_template,
            fewshot_as_multiturn=fewshot_as_multiturn,
            chat_template=getattr(lm, "apply_chat_template")
            if apply_chat_template
            else None,
            tokenizer_name=getattr(lm, "tokenizer_name", "")
            if apply_chat_template
            else "",
        )
        eval_logger.debug(
            f"Task: {task_output.task_name}; number of requests on this rank: {len(task.instances)}"
        )
        if write_out:
            print_writeout(task)
        # aggregate Instances by LM method requested to get output.
        for instance in task.instances:
            reqtype = instance.request_type
            requests[reqtype].append(instance)

        if lm.world_size > 1:
            instances_rnk = torch.tensor(len(task._instances), device=lm.device)
            gathered_item = (
                lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist()
            )
            # "multiple_choice" task types dispatch (several) "loglikelihood" request types
            reqtype = (
                "loglikelihood"
                if task.OUTPUT_TYPE == "multiple_choice"
                else task.OUTPUT_TYPE
            )
            # compute number of pseudo-batches to pad with (FSDP/DDP require even batches among ranks)
            numpad = max(gathered_item) - gathered_item[lm.rank]
            # todo: may not account for padding in cases like SquadV2 which has multiple req types
            padding_requests[reqtype] += numpad

    ### Run LM on inputs, get all outputs ###
    # execute each type of request
    for reqtype, reqs in requests.items():
        eval_logger.info(f"Running {reqtype} requests")
        # create `K` copies of each request `req` based off `K = req.repeats`
        cloned_reqs = []
        for req in reqs:
            cloned_reqs.extend([req] * req.repeats)

        if (lm.world_size > 1) and (padding_requests[reqtype] > 0):
            for _ in range(padding_requests[reqtype]):
                cloned_reqs.extend([req] * req.repeats)

        # run requests through model
        resps = getattr(lm, reqtype)(cloned_reqs)

        # put responses from model into a list of length K for each request.
        for x, req in zip(resps, cloned_reqs):
            req.resps.append(x)

        if lm.world_size > 1:
            lm.accelerator.wait_for_everyone()

    RANK = lm.rank
    WORLD_SIZE = lm.world_size
    ### Postprocess outputs ###
    # TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately)
    for task_output in eval_tasks:
        task = task_output.task
        task.apply_filters()

        ### Collect values of metrics on all datapoints ###
        # # unpack results and sort back in order and return control to Task
        # TODO: make it possible to use a different metric per filter
        # Pre-process task.instances to group by doc_id
        instances_by_doc_id = defaultdict(list)
        for instance in task.instances:
            instances_by_doc_id[instance.doc_id].append(instance)
        # Sort instances within each group
        for instances in instances_by_doc_id.values():
            instances.sort(key=lambda x: x.idx)
        # iterate over different filters used
        for filter_key in task.instances[0].filtered_resps.keys():
            doc_iterator = task.doc_iterator(
                rank=RANK, limit=limit, world_size=WORLD_SIZE
            )
            for doc_id, doc in doc_iterator:
                requests = instances_by_doc_id[doc_id]
                metrics = task.process_results(
                    doc, [req.filtered_resps[filter_key] for req in requests]
                )
                if log_samples:
                    target = task.doc_to_target(doc)
                    example = {
                        "doc_id": doc_id,
                        "doc": doc,
                        "target": target,
                        "arguments": [req.args for req in requests],
                        "resps": [req.resps for req in requests],
                        "filtered_resps": [
                            req.filtered_resps[filter_key] for req in requests
                        ],
                        "doc_hash": hash_string(
                            json.dumps(
                                requests[0].doc,
                                indent=2,
                                default=handle_non_serializable,
                                ensure_ascii=False,
                            )
                        ),
                        "prompt_hash": hash_string(requests[0].arguments[0]),
                        "target_hash": hash_string(str(target)),
                    }
                    example.update(metrics)
                    task_output.logged_samples.append(example)
                for metric, value in metrics.items():
                    task_output.sample_metrics[(metric, filter_key)].append(value)

    if WORLD_SIZE > 1:
        # if multigpu, then gather data across all ranks to rank 0
        # first gather logged samples across all ranks
        for task_output in eval_tasks:
            if log_samples:
                # for task_name, task_samples in list(samples.items()):
                full_samples = [None] * WORLD_SIZE if RANK == 0 else None
                torch.distributed.gather_object(
                    obj=task_output.logged_samples,
                    object_gather_list=full_samples,
                    dst=0,
                )

                if RANK == 0:
                    task_output.logged_samples = list(
                        itertools.chain.from_iterable(full_samples)
                    )

            # then collect metrics across all ranks
            for metrics in task_output.sample_metrics:
                metric_list = [None] * WORLD_SIZE if RANK == 0 else None
                torch.distributed.gather_object(
                    obj=task_output.sample_metrics[metrics],
                    object_gather_list=metric_list,
                    dst=0,
                )
                if RANK == 0:
                    task_output.sample_metrics[metrics] = list(
                        itertools.chain.from_iterable(metric_list)
                    )

    if RANK == 0:
        ### Aggregate results over all datapoints ###
        # aggregate results ; run bootstrap CIs
        for task_output in eval_tasks:
            task_output.calculate_aggregate_metric(bootstrap_iters=bootstrap_iters)
        (
            results,
            samples,
            configs,
            versions,
            num_fewshot,
            higher_is_better,
        ) = consolidate_results(eval_tasks)

        ### Calculate group metrics ###
        if bool(results):
            results, versions, show_group_table, *_ = consolidate_group_results(
                results, versions, task_dict
            )

        results_agg, group_agg = prepare_print_tasks(task_dict, results)
        subtask_list = get_subtask_list(task_dict)

        # collect all higher_is_better values for metrics
        # in the group's subtasks.
        # TODO: clean this up ; unify with the below metric_list loop?
        _higher_is_better = {}
        for group, task_list in subtask_list.items():
            if (
                len(task_list) != 0
            ):  # subtask list will list "task_name": [] for solo tasks
                for task in task_list:
                    for m, h in higher_is_better[task].items():
                        if m not in _higher_is_better.keys():
                            _higher_is_better[m] = h

                        if (
                            m in _higher_is_better
                            and _higher_is_better[m] is not None
                            and _higher_is_better[m] != h
                        ):
                            eval_logger.warning(
                                f"Higher_is_better values for metric {m} in group {group} are not consistent. Defaulting to None."
                            )
                            _higher_is_better[m] = None
                higher_is_better[group] = _higher_is_better

        results_dict = {
            "results": dict(results_agg.items()),
            **(
                {"groups": dict(group_agg.items())}
                if (bool(group_agg) & show_group_table)
                else {}
            ),
            "group_subtasks": dict(reversed(subtask_list.items())),
            "configs": dict(sorted(configs.items())),
            "versions": dict(sorted(versions.items())),
            "n-shot": dict(sorted(num_fewshot.items())),
            "higher_is_better": dict(sorted(higher_is_better.items())),
            "n-samples": {
                task_output.task_name: {
                    "original": len(task_output.task.eval_docs),
                    "effective": min(
                        limit if limit else len(task_output.task.eval_docs),
                        len(task_output.task.eval_docs),
                    ),
                }
                for task_output in eval_tasks
            },
        }
        if log_samples:
            results_dict["samples"] = dict(samples)

        return results_dict

    else:
        return None


def request_caching_arg_to_dict(cache_requests: str) -> dict:
    request_caching_args = {
        "cache_requests": cache_requests in {"true", "refresh"},
        "rewrite_requests_cache": cache_requests == "refresh",
        "delete_requests_cache": cache_requests == "delete",
    }

    return request_caching_args
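
As a quick illustration of the `--cache_requests` plumbing, `request_caching_arg_to_dict` above turns the single CLI choice into the three boolean flags accepted by `simple_evaluate`. A small sketch (outputs follow directly from the function as written):

    from lm_eval.evaluator import request_caching_arg_to_dict

    print(request_caching_arg_to_dict("true"))
    # {'cache_requests': True, 'rewrite_requests_cache': False, 'delete_requests_cache': False}
    print(request_caching_arg_to_dict("refresh"))
    # {'cache_requests': True, 'rewrite_requests_cache': True, 'delete_requests_cache': False}
    print(request_caching_arg_to_dict("delete"))
    # {'cache_requests': False, 'rewrite_requests_cache': False, 'delete_requests_cache': True}
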
scripts/yans/lm-evaluation-harness/lm_eval/evaluator_utils.py
ADDED
@@ -0,0 +1,542 @@
import collections
import math
import pathlib
import sys
from typing import List, Optional, Tuple, Union

from lm_eval.api.group import ConfigurableGroup
from lm_eval.api.metrics import (
    aggregate_subtask_metrics,
    pooled_sample_stderr,
    stderr_for_metric,
)
from lm_eval.api.task import Task
from lm_eval.utils import eval_logger, positional_deprecated


class TaskOutput:
    """
    Wrapper class for Task outputs. It contains various attributes and methods to manage and calculate metrics for the task.

    Attributes:
        task (object): The task object.
        task_name (str): The name of the task.
        task_config (dict): The configuration of the task.
        version (str): The version of the task.
        group_name (str): The name of the task group.
        n_shot (int): The number of shots for the task.
        task_alias (str): The alias of the task.
        group_alias (str): The alias of the task group.
        is_group (bool): Indicates if the task is a group.
        logged_samples (list): The list of logged samples.
        sample_len (int): The length of the samples.
        sample_metrics (defaultdict): The dictionary of samples' metrics.
        agg_metrics (defaultdict): The dictionary of aggregate metrics.

    Methods:
        from_taskdict(cls, task_name: str, task):
            Creates a TaskOutput instance from a task dictionary.

        calculate_aggregate_metric(bootstrap_iters=100000) -> None:
            Calculates the aggregate metrics for the task.
    """

    def __init__(
        self,
        task=None,
        task_name=None,
        task_config=None,
        version=None,
        group_name=None,
        n_shot=None,
        task_alias=None,
        group_alias=None,
        is_group=None,
    ):
        self.task = task
        self.task_config = task_config
        self.task_name = task_name
        self.group_name = group_name
        self.version = version
        self.n_shot = n_shot
        self.task_alias = task_alias
        self.group_alias = group_alias
        self.is_group = is_group
        self.logged_samples = []
        self.sample_len = None
        self.sample_metrics = collections.defaultdict(list)
        self.agg_metrics = collections.defaultdict(list)

    @classmethod
    def from_taskdict(cls, task_name: str, task):
        if isinstance(task, tuple):
            group_name, task = task
        else:
            group_name = None
        if not task:
            # these get filtered out in get_task_list
            # once they are added to the group hierarchy
            is_group = True
            return cls(
                task=task, task_name=task_name, is_group=is_group, group_name=group_name
            )
        version = task.VERSION
        task_config = dict(task.dump_config())
        if (n_shot := task_config.get("num_fewshot")) == 0:
            n_shot = task_config.get("metadata", {}).get("num_fewshot", 0)
        task_alias = task_config.get("alias")
        group_alias = task_config.get("group_alias")
        return cls(
            task=task,
            task_name=task_name,
            task_config=task_config,
            group_name=group_name,
            version=version,
            n_shot=n_shot,
            task_alias=task_alias,
            group_alias=group_alias,
        )

    def calculate_aggregate_metric(self, bootstrap_iters=100000) -> None:
        for (metric, filter_key), items in self.sample_metrics.items():
            agg_fn = self.task.aggregation()[metric]
            metric_key = f"{metric},{filter_key}"
            self.agg_metrics[metric_key] = agg_fn(items)
            self.sample_len = len(items)  # TODO: same sample size for each metric?
            if isinstance(bootstrap_iters, int):
                stderr_fn = stderr_for_metric(
                    metric=agg_fn,
                    bootstrap_iters=min(bootstrap_iters, 100)
                    if metric in ["bleu", "chrf", "ter"]
                    else bootstrap_iters,
                )
                self.agg_metrics[f"{metric}_stderr,{filter_key}"] = (
                    stderr_fn(items) if (stderr_fn and len(items) > 1) else "N/A"
                )
            else:
                raise ValueError(
                    f"Received bootstrap_iters '{bootstrap_iters}' but expected an integer. Set to 0 to turn off stderr calculations."
                )

    def __repr__(self):
        return (
            f"TaskOutput(task_name={self.task_name}, "
            f"group_name={self.group_name}, "
            f"version={self.version}, "
            f"n_shot={self.n_shot}, "
            f"task_alias={self.task_alias}, "
            f"group_alias={self.group_alias})"
        )


def get_task_list(task_dict: dict) -> List[TaskOutput]:
    outputs = []
    for task_name, task_obj in task_dict.items():
        if isinstance(task_obj, dict):
            _outputs = get_task_list(task_obj)
            outputs.extend(_outputs)
        else:
            task_output = TaskOutput.from_taskdict(task_name, task_obj)
            outputs.append(task_output)

    return outputs


def get_subtask_list(task_dict, task_root=None, depth=0):
    subtask_list = {}
    for group_obj, task_obj in task_dict.items():
        if isinstance(group_obj, ConfigurableGroup):
            group_name = group_obj.group_name
        else:
            group_name = group_obj
        if isinstance(task_obj, dict):
            _subtask_list = get_subtask_list(
                task_obj, task_root=group_name, depth=depth + 1
            )
            if task_root:
                subtask_list.setdefault((task_root, depth), []).extend(
                    [
                        _task
                        for (_task, _depth) in _subtask_list.keys()
                        if (_depth - 1) == depth
                    ]
                )

            subtask_list = {**subtask_list, **_subtask_list}
        else:
            if isinstance(task_obj, ConfigurableGroup):
                group_or_task_name = task_obj.group_name
            elif isinstance(task_obj, Task):
                group_or_task_name = task_obj.task_name

            if task_root is None:
                subtask_list.setdefault((group_or_task_name, depth), [])
            else:
                subtask_list.setdefault((task_root, depth), []).append(
                    group_or_task_name
                )

    if depth == 0:
        _subtask_list = {}
        for group_key, task_list in subtask_list.items():
            group_name, depth = group_key
            _subtask_list[group_name] = task_list
        subtask_list = _subtask_list

    return subtask_list


def print_writeout(task) -> None:
    for inst in task.instances:
        # print the prompt for the first few documents
        if inst.doc_id < 1:
            eval_logger.info(
                f"Task: {task}; document {inst.doc_id}; context prompt (starting on next line):\
\n{inst.args[0]}\n(end of prompt on previous line)\ntarget string or answer choice index (starting on next line):\n{task.doc_to_target(inst.doc)}\n(end of target on previous line)"
            )
            eval_logger.info(f"Request: {str(inst)}")


def get_sample_size(task, limit: Optional[int]) -> Union[int, None]:
    if limit is not None:
        limit = (
            int(math.ceil(len(task.eval_docs) * limit)) if limit < 1.0 else int(limit)
        )
    return limit


def prepare_print_tasks(
    task_dict: dict,
    results: dict,
    task_depth=0,
    group_depth=0,
) -> Tuple[dict, dict]:
    """
    @param task_dict: Dictionary representing the group hierarchy of tasks. Each key is a group name and its
        value is a list of task names.
    @param results: Dictionary containing the results of each task. Each key is a
        group name and its value is a dictionary of task results.
    @param task_depth: The indentation level for printing the task
        hierarchy. Default is 0.
    @param group_depth: The indentation level for printing the group
        hierarchy. Default is 0.
    @return: A tuple of two dictionaries: results_agg and groups_agg. results_agg contains
        aggregated results for each task, and groups_agg contains aggregated results for each group.

    Prepares the task hierarchy and aggregates the results for each task and group recursively for printing.
    """

    def _sort_task_dict(task_dict):
        """
        Helper utility. Sorts the task dict at the current level of the hierarchy based on alphabetized task name.
        Required so that we end up sorting within each sub-header correctly.
        """

        return dict(
            sorted(
                task_dict.items(),
                key=lambda item: item[0].group_name
                if isinstance(item[0], ConfigurableGroup)
                else item[0],
            )
        )

    task_agg = collections.defaultdict(dict)
    group_agg = collections.defaultdict(dict)
    task_dict = _sort_task_dict(task_dict)
    for task_or_group_name, task_or_group_obj in task_dict.items():
        tab_string = " " * task_depth + "- " if task_depth > 0 else ""
        if isinstance(task_or_group_name, ConfigurableGroup):
            name = task_or_group_name.group_name
            from_configurable_group = True
            task_or_group_obj = _sort_task_dict(task_or_group_obj)
        elif isinstance(task_or_group_name, str):
            name = task_or_group_name
            if isinstance(task_or_group_obj, Task):
                name = task_or_group_obj.task_name
            from_configurable_group = False

        task_agg[name] = results[name].copy()
        if from_configurable_group:
            if task_or_group_name.group_alias is not None:
                alias = task_or_group_name.group_alias
            else:
                alias = task_or_group_name.group
        else:
            if "alias" in task_agg[name]:
                alias = task_agg[name]["alias"]
            else:
                alias = name

        task_agg[name]["alias"] = tab_string + alias
        if "samples" in task_agg[name]:
            task_agg[name].pop("samples")

        if from_configurable_group and (" " not in results[name]):
            group_tab_string = " " * group_depth + "- " if group_depth > 0 else ""
            group_agg[name] = results[name].copy()
            group_agg[name]["alias"] = group_tab_string + alias
            if "samples" in group_agg[name]:
                group_agg[name].pop("samples")

        if isinstance(task_or_group_obj, dict):
            task_depth += 1
            group_depth += 1
            _task_agg, _group_agg = prepare_print_tasks(
                task_or_group_obj, results, task_depth, group_depth
            )
            task_agg = {
                **task_agg,
                **_task_agg,
            }
            group_agg = {**group_agg, **_group_agg}
            task_depth -= 1
            group_depth -= 1
    return task_agg, group_agg


def consolidate_results(
    eval_tasks: List[TaskOutput],
) -> Tuple[dict, dict, dict, dict, dict, dict]:
    """
    @param eval_tasks: list(TaskOutput).
    @return: A tuple containing the consolidated results, samples, configs, versions, and num_fewshot.

    Consolidates the results of multiple evaluation tasks into a single structure.

    The method iterates over each evaluation instance and extracts relevant information to create the consolidated
    results structure. The consolidated results structure has the following properties:

    - results: A defaultdict with task names as keys and dictionaries as values. Each dictionary contains
      metric/filter pairs as keys and corresponding metric values as values. The "alias" key is used to store task
      aliases specified in the task configuration.
    - samples: A defaultdict with task names as keys and lists of log samples as values.
    - configs: A defaultdict with task names as keys and task configurations as values.
    - versions: A defaultdict with task names as keys and task versions as values.
    - num_fewshot: A defaultdict with task names as keys and number of few-shot samples as values.
    - higher_is_better: A defaultdict with task names as keys and indicators of whether higher values are better
      for each metric as values.

    The method then returns the consolidated results, samples, configs, versions, and num_fewshot as a tuple.
    """
    # stores the final result for each task, for each metric/filter pair.
    results = collections.defaultdict(dict)
    # logs info about each document evaluated.
    samples = collections.defaultdict(list)
    # store num-fewshot value per task
    num_fewshot = collections.defaultdict(int)
    # Tracks the YAML configs of all chosen tasks
    configs = collections.defaultdict(dict)
    # Tracks each task's version.
    versions = collections.defaultdict(dict)
    # Track `higher_is_better` for each metric
    higher_is_better = collections.defaultdict(dict)

    for task_output in eval_tasks:
        if "task_alias" in (task_config := task_output.task_config):
            results[task_output.task_name]["alias"] = task_config["task_alias"]
        else:
            results[task_output.task_name]["alias"] = task_output.task_name
        if group_alias := task_output.group_alias:
            if group_alias not in results and (group_name := task_output.group_name):
                results[group_name]["alias"] = group_alias
        num_fewshot[task_output.task_name] = task_output.n_shot
        configs[task_output.task_name] = task_output.task_config
        versions[task_output.task_name] = task_output.version
        samples[task_output.task_name] = task_output.logged_samples
        higher_is_better[task_output.task_name] = task_output.task.higher_is_better()
        for (metric, filter_key), items in task_output.sample_metrics.items():
            metric_key = f"{metric},{filter_key}"
            results[task_output.task_name][metric_key] = task_output.agg_metrics[
                metric_key
            ]
            results[task_output.task_name]["samples"] = task_output.sample_len
            results[task_output.task_name][f"{metric}_stderr,{filter_key}"] = (
                task_output.agg_metrics[f"{metric}_stderr,{filter_key}"]
            )
    return results, samples, configs, versions, num_fewshot, higher_is_better


def consolidate_group_results(
    results,
    versions,
    task_dict,
    task_root=None,
    show_group_table=False,
    task_aggregation_list=None,
) -> Tuple[dict, dict, bool, Union[None,]]:
    """
    (Recursively) calculates groups' aggregated metrics and updates the results and versions dictionaries with this info.

    @return: a tuple [results, versions, show_group_table, task_aggregation_list] with formats described below:

    - results: A defaultdict with task names (and, after this function is called, group names of
      groups that perform aggregation) as keys, and dictionaries with "alias" and metric,filter_name pairs as keys.
    - versions: A defaultdict with task names (and, after this function is called, group names of
      groups that perform aggregation) as keys, and float values representing the task or group's version if a version is specified. (defaulting to None).
    - show_group_table: a boolean which is true if there exists a group that requires printing of its aggregated scores in a group table.
    - task_aggregation_list: a defaultdict listing the subtasks to average over to produce a given group's end metric.

    The method then returns the updated results, versions, show_group_table, and task_aggregation_list as a tuple.
    In the top-level invocation of this function, task_aggregation_list is ignored.
    """
    if task_root is None:
        task_root = {}

    if task_aggregation_list is None:
        task_aggregation_list = {}

    for group_or_task, group_or_task_info in task_dict.items():
        # Convert to string
        if isinstance(group_or_task, ConfigurableGroup):
            group_config = group_or_task.config
            group_or_task = group_or_task.group_name
        else:
            group_config = None

        if isinstance(group_or_task_info, Task):
            if task_root:
                task_aggregation_list.setdefault(task_root, []).append(
                    group_or_task_info.task_name
                )
        else:
            (
                results,
                versions,
                show_group_table,
                _task_aggregation_list,
            ) = consolidate_group_results(
                results,
                versions,
                group_or_task_info,
                group_or_task,
                show_group_table,
                task_aggregation_list,
            )
            if task_root:
                task_aggregation_list.setdefault(task_root, []).extend(
                    task_aggregation_list.get(group_or_task, [])
                )

            if (group_config is None) or (
                group_config["aggregate_metric_list"] is None
            ):
                results[group_or_task][" "] = " "
                continue

            if "aggregate_metric_list" in group_config:
                agg_metric_list = group_config["aggregate_metric_list"]

            show_group_table = show_group_table | bool(
                group_config["aggregate_metric_list"]
            )

            task_list = _task_aggregation_list[group_or_task]

            metric_list = list(
                {
                    key
                    for task in task_list
                    for key in results[task].keys()
                    if "_stderr" not in key and key not in ["task", "alias", "samples"]
                }
            )
            for metric in metric_list:
                stderr = "_stderr,".join(metric.split(","))

                # gather metrics, sizes, and stderrs from subtasks
                metrics = [
                    results[task][metric]
                    for task in task_list
                    if metric in results[task]
                ]  # TODO: copy?
                stderrs = [
                    results[task][stderr]
                    for task in task_list
                    if stderr in results[task]
                ]
                sizes = [
                    results[task]["samples"]
                    for task in task_list
                    if metric in results[task]
                ]

                for metric_config in agg_metric_list:
                    for filter_name in metric_config["filter_list"]:
                        if metric != ",".join([metric_config["metric"], filter_name]):
                            continue

                        # compute group's pooled metric and stderr
                        if metric_config["aggregation"] == "mean":
                            aggregate_fn = aggregate_subtask_metrics
                        else:
                            raise ValueError(
                                f"Currently, only 'mean' is supported for automatically aggregating scores across groups' subtasks. Got '{metric_config['aggregation']}' for group '{group_or_task}'"
                            )

                        results[group_or_task][metric] = aggregate_fn(
                            metrics,
                            sizes,
                            metric_config["weight_by_size"],
                        )
                        # TODO: calculate groups' metrics using arbitrary agg fns
                        if "N/A" in stderrs:
                            results[group_or_task][stderr] = "N/A"
                        else:
                            # NOTE: this assumes we are using the mean to aggregate. There are warnings about this elsewhere
                            results[group_or_task][stderr] = pooled_sample_stderr(
                                stderrs, sizes
                            )

                results[group_or_task]["samples"] = sum(sizes)
            group_metadata = group_config.get("metadata", None)
            if group_metadata is not None:
                versions[group_or_task] = group_metadata.get("version", None)
    return results, versions, show_group_table, task_aggregation_list


@positional_deprecated
def find_test_root(start_path: pathlib.Path) -> pathlib.Path:
    """
    Search upward in the directory tree to a maximum of three layers
    to find and return the package root (containing the 'tests' folder)
    """
    cur_path = start_path.resolve()
    max_layers = 3
    for _ in range(max_layers):
        if (cur_path / "tests" / "test_version_stable.py").exists():
            return cur_path
        else:
            cur_path = cur_path.parent.resolve()
    raise FileNotFoundError(
        f"Unable to find package root within {max_layers} upwards" + f"of {start_path}"
    )


@positional_deprecated
def run_task_tests(task_list: List[str]):
    """
    Find the package root and run the tests for the given tasks
    """
    import pytest

    package_root = find_test_root(start_path=pathlib.Path(__file__))
    task_string = " or ".join(task_list)
    args = [
        f"{package_root}/tests/test_version_stable.py",
        f"--rootdir={package_root}",
        "-k",
        f"{task_string}",
    ]
    sys.path.append(str(package_root))
    pytest_return_val = pytest.main(args)
    if pytest_return_val:
        raise ValueError(
            f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {pytest_return_val}"
        )
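A minimal sketch of the `get_sample_size` helper defined above; `DummyTask` is a hypothetical stand-in invented for illustration, not part of the library:

```python
# Sketch: how get_sample_size interprets the --limit value.
from lm_eval.evaluator_utils import get_sample_size


class DummyTask:
    eval_docs = list(range(200))  # pretend the task has 200 evaluation documents


assert get_sample_size(DummyTask(), None) is None  # no limit -> evaluate everything
assert get_sample_size(DummyTask(), 50) == 50      # value >= 1 -> absolute cap
assert get_sample_size(DummyTask(), 0.25) == 50    # value < 1.0 -> fraction of eval_docs, rounded up
```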
scripts/yans/lm-evaluation-harness/lm_eval/filters/__init__.py
ADDED
@@ -0,0 +1,25 @@
from functools import partial
from typing import List

from lm_eval.api.filter import FilterEnsemble
from lm_eval.api.registry import get_filter

from . import extraction, selection, transformation


def build_filter_ensemble(
    filter_name: str, components: List[List[str]]
) -> FilterEnsemble:
    """
    Create a filtering pipeline.
    """
    filters = []
    for function, kwargs in components:
        if kwargs is None:
            kwargs = {}
        # create a filter given its name in the registry
        f = partial(get_filter(function), **kwargs)
        # add the filter as a pipeline step
        filters.append(f)

    return FilterEnsemble(name=filter_name, filters=filters)
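A minimal sketch of how `build_filter_ensemble` assembles a pipeline from names in the filter registry; the component list mirrors what a task YAML's filter configuration would supply, and the regex value here is illustrative:

```python
# Sketch: build a "get-answer" pipeline from two registered filters
# ("regex" and "take_first" are registered in the filter modules above).
from lm_eval.filters import build_filter_ensemble

ensemble = build_filter_ensemble(
    "get-answer",
    [
        ["regex", {"regex_pattern": r"#### (\-?[0-9\.\,]+)"}],  # extract the number after '####'
        ["take_first", None],                                   # keep only the first response
    ],
)
print(ensemble.name)  # -> "get-answer"
```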
scripts/yans/lm-evaluation-harness/lm_eval/filters/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (815 Bytes).
scripts/yans/lm-evaluation-harness/lm_eval/filters/__pycache__/extraction.cpython-310.pyc
ADDED
Binary file (6.01 kB).
scripts/yans/lm-evaluation-harness/lm_eval/filters/__pycache__/selection.cpython-310.pyc
ADDED
Binary file (2.9 kB).
scripts/yans/lm-evaluation-harness/lm_eval/filters/__pycache__/transformation.cpython-310.pyc
ADDED
Binary file (3.44 kB).
scripts/yans/lm-evaluation-harness/lm_eval/filters/decontamination.py
ADDED
@@ -0,0 +1,25 @@
from lm_eval.api.filter import Filter
from lm_eval.api.registry import register_filter


@register_filter("decontaminate")
class DecontaminationFilter(Filter):
    """
    A filter which evaluates
    """

    name = "track_decontamination"

    def __init__(self, path) -> None:
        """
        TODO: make sure only ever run one time on the train set (should this be cached as a class var? keyed by value for "path").
        should further cache result on a given (task_name, doc_id)
        """
        self._decontam_results = None

    def apply(self, resps, docs) -> None:
        """
        Return {"no_contamination", "only_contamination"} keys for the 2 different subsets
        """
        pass
scripts/yans/lm-evaluation-harness/lm_eval/filters/extraction.py
ADDED
@@ -0,0 +1,184 @@
import re
import sys
import unicodedata

from lm_eval.api.filter import Filter
from lm_eval.api.registry import register_filter


@register_filter("regex")
class RegexFilter(Filter):
    """ """

    def __init__(
        self,
        regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
        group_select=0,
        fallback: str = "[invalid]",
    ) -> None:
        """
        pass a string `regex` to run `re.compile(r"regex")` on.
        `fallback` defines the output returned if no matches for the regex are located.
        """
        self.regex_pattern = regex_pattern
        self.regex = re.compile(regex_pattern)
        self.group_select = group_select
        self.fallback = fallback

    def apply(self, resps, docs):
        # here, we assume we have a list, in which each element is
        # a list of model responses for some particular input/target pair.
        # so we process each of these (same input/target response sets)
        # independently (and keep them a list.)
        def filter_set(inst):
            filtered = []
            for resp in inst:
                match = self.regex.findall(resp)
                if match:
                    match = match[self.group_select]
                    if isinstance(match, tuple):
                        match = [m for m in match if m][0]
                    match = match.strip()
                else:
                    match = self.fallback
                filtered.append(match)
            return filtered

        filtered_resps = list(map(lambda x: filter_set(x), resps))

        return filtered_resps


@register_filter("remove_whitespace")
class WhitespaceFilter(Filter):
    """ """

    def __init__(self) -> None:
        pass

    def apply(self, resps, docs):
        def filter_set(inst):
            filtered_resp = []
            for resp in inst:
                resp = resp.lstrip()
                filtered_resp.append(resp)
            return filtered_resp

        filtered_resps = [filter_set(resp) for resp in resps]

        return filtered_resps


@register_filter("multi_choice_regex")
class MultiChoiceRegexFilter(RegexFilter):
    """
    A filter used to extract a model's answer on multiple choice questions with
    letter answers. assumes each document has a "choices" field
    containing the list of answer choices and that the answer label symbols
    are of the form (A), (B), (C), ... or A, B, C.
    """

    def __init__(
        self,
        regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
        group_select=0,
        fallback: str = "[invalid]",
        ignore_case=False,
        ignore_punctuation=False,
        regexes_to_ignore=None,
    ) -> None:
        """
        regex_pattern: The basic regex pattern to use. If fails to match, we will use the customized match procedure
            - step 1 : We parse the choices between ([A-Z])s then try to find these choices in the response.
            - step 2 : We parse the choice with regex :[\s]*([A-?]), where ? varies by number of choices.
        group_select: Selects the (group_select)th match from the findall result.
        ignore_case: Ignores the case during step 1 matching
        ignore_punctuation: Remove the punctuation during step 1 matching
        regexes_to_ignore: Remove these regexes during step 1 matching
        """
        super().__init__(regex_pattern, group_select, fallback)
        self.ignore_case = ignore_case
        self.ignore_punctuation = ignore_punctuation
        self.regexes_to_ignore = regexes_to_ignore

    def apply(self, resps, docs):
        # here, we assume we have a list, in which each element is
        # a list of model responses for some particular input/target pair.
        # so we process each of these (same input/target response sets)
        # independently (and keep them a list.)

        def find_match(regex, resp, convert_dict={}):
            match = regex.findall(resp)
            if match:
                match = match[self.group_select]
                if isinstance(match, tuple):
                    match = [m for m in match if m][0]
                match = match.strip()
                if match and match in convert_dict:
                    match = convert_dict[match]
            return match

        punct_tbl = dict.fromkeys(
            i
            for i in range(sys.maxunicode)
            if unicodedata.category(chr(i)).startswith("P")
        )

        def filter_ignores(st):
            if self.regexes_to_ignore is not None:
                for s in self.regexes_to_ignore:
                    st = re.sub(s, "", st)

            if self.ignore_case:
                st = st.lower()

            if self.ignore_punctuation:
                # https://stackoverflow.com/a/266162
                st = st.translate(punct_tbl)
            return st

        filtered_resps = []

        for r, doc in zip(resps, docs):
            fallback_regexes = []
            choice_to_alpha = {}
            next_alpha = "A"

            without_paren_fallback_regexes = []
            without_paren_to_target = {}

            choices = doc["choices"]
            for c in choices:
                m = filter_ignores(c.strip())
                fallback_regexes.append(f"{re.escape(m)}")
                choice_to_alpha[m] = f"({next_alpha})"

                without_paren_fallback_regexes.append(next_alpha)
                without_paren_to_target[next_alpha] = f"({next_alpha})"

                next_alpha = chr(ord(next_alpha) + 1)
            fallback_regex = re.compile("|".join(fallback_regexes))
            without_paren_fallback_regex = "|".join(without_paren_fallback_regexes)
            without_paren_fallback_regex = re.compile(
                f":[\s]*({without_paren_fallback_regex})"
            )

            filtered = []
            for resp in r:
                match = find_match(self.regex, resp)
                if not match:
                    match = find_match(
                        fallback_regex, filter_ignores(resp), choice_to_alpha
                    )
                    if not match:
                        match = find_match(
                            without_paren_fallback_regex, resp, without_paren_to_target
                        )
                        if not match:
                            match = self.fallback
                filtered.append(match)
            filtered_resps.append(filtered)

        return filtered_resps
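A minimal sketch of `RegexFilter` on made-up responses (one document, two sampled answers), using the default pattern that extracts the number following `#### `:

```python
# Sketch: extract final answers from generated text with the default GSM8K-style pattern.
from lm_eval.filters.extraction import RegexFilter

f = RegexFilter()  # default regex_pattern=r"#### (\-?[0-9\.\,]+)", fallback="[invalid]"
resps = [["Some reasoning... #### 42", "I am not sure."]]  # one doc, two model responses
print(f.apply(resps, docs=[{}]))  # -> [['42', '[invalid]']]
```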
scripts/yans/lm-evaluation-harness/lm_eval/filters/selection.py
ADDED
@@ -0,0 +1,61 @@
from collections import Counter

from lm_eval.api.filter import Filter
from lm_eval.api.registry import register_filter


# TODO: implement "arg_max" filter. either it should take in an arbitrary "scoring"/reward function
# that takes an input and returns a scalar and then should select the max reward,
# or should implement different filters for different ways of handling a reward model's inference.


@register_filter("take_first")
class TakeFirstFilter(Filter):
    def __init__(self) -> None:
        """
        Can define custom behavior here, if an individual instantiation of a Filter class should have state.
        """

    def apply(self, resps, docs):
        """
        Assuming each entry of `resps` is a list of model responses, we discard all but the first response.
        """
        return map(lambda r: r[0], resps)


@register_filter("take_first_k")
class TakeKFilter(Filter):
    def __init__(self, **kwargs) -> None:
        self.k = kwargs.pop("k")

        super().__init__(**kwargs)

    def apply(self, resps, docs):
        # need resp to be subscriptable to check below
        resps = list(resps)
        # check we have at least k responses per doc, else we can't take the first k
        assert (
            len(resps[0]) >= self.k
        ), f"Need at least {self.k} responses per doc to take first {self.k}, but got {len(resps[0])} only! Please increase TaskConfig.repeats ."
        return map(lambda r: r[: self.k], resps)


@register_filter("majority_vote")
class MajorityVoteFilter(Filter):
    def __init__(self) -> None:
        """
        Can define custom behavior here, if an individual instantiation of a Filter class should have state.
        """

    def apply(self, resps, docs):
        """
        Each entry of `resps` is a list of model responses.
        We select the response that occurs most frequently in each entry of `resps`.
        """

        def select_majority(resp):
            counts = Counter(resp)
            vote = counts.most_common(1)[0][0]
            return vote

        return map(lambda r: [select_majority(r)], resps)
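A minimal sketch of the selection filters on made-up responses (one document, three sampled answers):

```python
# Sketch: first-response selection vs. majority voting over repeated samples.
from lm_eval.filters.selection import MajorityVoteFilter, TakeFirstFilter

resps = [["(B)", "(A)", "(B)"]]                       # one doc, three sampled answers
print(list(TakeFirstFilter().apply(resps, None)))     # -> ['(B)']
print(list(MajorityVoteFilter().apply(resps, None)))  # -> [['(B)']]
```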
scripts/yans/lm-evaluation-harness/lm_eval/filters/transformation.py
ADDED
@@ -0,0 +1,56 @@
from lm_eval.api.filter import Filter
from lm_eval.api.registry import register_filter


@register_filter("lowercase")
class LowercaseFilter(Filter):
    def __init__(self) -> None:
        pass

    def apply(self, resps, docs):
        def filter_set(inst):
            return [resp.lower() for resp in inst]

        return [filter_set(resp) for resp in resps]


@register_filter("uppercase")
class UppercaseFilter(Filter):
    def __init__(self) -> None:
        pass

    def apply(self, resps, docs):
        def filter_set(inst):
            return [resp.upper() for resp in inst]

        return [filter_set(resp) for resp in resps]


@register_filter("map")
class MapFilter(Filter):
    def __init__(self, mapping_dict: dict = None, default_value=None) -> None:
        """
        Initializes the MapFilter with a given mapping dictionary and default value.

        Args:
        - mapping_dict (dict): A dictionary containing the key-value mappings.
          Default is an empty dictionary.
        - default_value (Any): The value to be returned when a key is not found in the mapping_dict.
          Default is None.

        Example:
        mapper = MapFilter({'A': 1, 'B': 2}, default_value=0)
        """
        if mapping_dict is None:
            mapping_dict = {}
        assert isinstance(
            mapping_dict, dict
        ), "Provided mapping_dict is not a dictionary"
        self.mapping_dict = mapping_dict
        self.default_value = default_value

    def apply(self, resps, docs):
        def filter_set(inst):
            return [self.mapping_dict.get(resp, self.default_value) for resp in inst]

        return [filter_set(resp) for resp in resps]
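A minimal sketch of `MapFilter`, e.g. mapping yes/no strings to label ids (the mapping itself is illustrative):

```python
# Sketch: map filtered string responses onto label ids, with a default for unknown values.
from lm_eval.filters.transformation import MapFilter

mapper = MapFilter({"yes": 1, "no": 0}, default_value=-1)
print(mapper.apply([["yes", "no", "maybe"]], None))  # -> [[1, 0, -1]]
```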
scripts/yans/lm-evaluation-harness/lm_eval/prompts/__init__.py
ADDED
@@ -0,0 +1,126 @@
import ast
import os
from typing import Dict

from lm_eval import utils
from lm_eval.utils import eval_logger


# Prompt library.
# Stores prompts in a dictionary indexed by 2 levels:
# prompt category name, and prompt name.
# This allows us to access prompts
PROMPT_REGISTRY: Dict[str, Dict[str, str]] = {
    "qa-basic": {
        "question-newline-answer": "Question: {{question}}\nAnswer:",
        "q-newline-a": "Q: {{question}}\nA:",
    },
}


def get_prompt(prompt_id: str, dataset_name: str = None, subset_name: str = None):
    # unpack prompt name
    category_name, prompt_name = prompt_id.split(":")
    if subset_name is None:
        dataset_full_name = dataset_name
    else:
        dataset_full_name = f"{dataset_name}-{subset_name}"
    eval_logger.info(f"Loading prompt from {category_name} for {dataset_full_name}")
    if category_name == "promptsource":
        try:
            from promptsource.templates import DatasetTemplates
        except ModuleNotFoundError:
            raise Exception(
                "Tried to load a Promptsource template, but promptsource is not installed ",
                "please install promptsource via pip install lm-eval[promptsource] or pip install -e .[promptsource]",
            )
        try:
            if subset_name is None:
                prompts = DatasetTemplates(dataset_name=dataset_name)
            else:
                prompts = DatasetTemplates(
                    dataset_name=dataset_name, subset_name=subset_name
                )
        except Exception:
            raise ValueError(f"{dataset_name} and {subset_name} not found")
        if prompt_name in prompts.all_template_names:
            return prompts[prompt_name]
        else:
            raise ValueError(
                f"{prompt_name} not in prompt list {prompts.all_template_names}"
            )
    elif ".yaml" in category_name:
        import yaml

        with open(category_name, "rb") as file:
            prompt_yaml_file = yaml.full_load(file)

        prompt_string = prompt_yaml_file["prompts"][prompt_name]
        return PromptString(prompt_string)
    else:
        try:
            return PROMPT_REGISTRY[category_name][prompt_name]
        except Exception:
            raise ValueError(
                f"expected only a single `:` as separator between \
                prompt category and name, but got `{prompt_id}` instead"
            )


def load_prompt_list(
    use_prompt: str, dataset_name=None, subset_name=None, yaml_path=None, **kwargs
):
    category_name, prompt_name = use_prompt.split(":")

    if category_name == "promptsource":
        from promptsource.templates import DatasetTemplates

        if subset_name is None:
            prompts = DatasetTemplates(dataset_name=dataset_name)
        else:
            prompts = DatasetTemplates(
                dataset_name=dataset_name, subset_name=subset_name
            )

        prompt_list = utils.pattern_match(prompt_name, prompts.all_template_names)

    elif ".yaml" in category_name:
        import yaml

        if yaml_path is not None:
            category_name = os.path.realpath(os.path.join(yaml_path, category_name))

        with open(category_name, "rb") as file:
            prompt_yaml_file = yaml.full_load(file)

        prompt_list = utils.pattern_match(
            prompt_name, prompt_yaml_file["prompts"].keys()
        )

    # category_name, *prompt_name = use_prompt.split(":")
    # TODO allow to multiple prompt naming
    # if len(prompt_name) > 1:
    #     prompt_list = []
    #     for prompt in prompt_name:
    #         prompt_list.append(utils.pattern_match(prompt_name, prompts.all_template_names))
    # else:
    #     prompt_list = utils.pattern_match(prompt_name, prompts.all_template_names)
    return [":".join([category_name, prompt]) for prompt in prompt_list]


class PromptString:
    def __init__(self, prompt_string):
        self.prompt_string = prompt_string

    def apply(self, doc):
        doc_to_text = self.prompt_string["doc_to_text"]
        doc_to_target = self.prompt_string["doc_to_target"]

        # TODO need a way to process doc_to_choice
        if "doc_to_choice" in self.prompt_string:
            raise Exception("Not yet implemented to accept doc_to_choice")

        text_string = utils.apply_template(doc_to_text, doc)
        target_string = utils.apply_template(doc_to_target, doc)

        return [text_string, target_string]
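A minimal sketch of reading a template back out of `PROMPT_REGISTRY` via `get_prompt`, using the built-in `qa-basic` category defined above:

```python
# Sketch: resolve a registry prompt id of the form "<category>:<name>".
from lm_eval.prompts import get_prompt

template = get_prompt("qa-basic:q-newline-a")
print(template)  # -> 'Q: {{question}}\nA:'
```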
scripts/yans/lm-evaluation-harness/lm_eval/prompts/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (3.22 kB).
scripts/yans/lm-evaluation-harness/lm_eval/tasks/README.md
ADDED
@@ -0,0 +1,119 @@
# Tasks

A list of supported tasks and task groupings can be viewed with `lm-eval --tasks list`.

For more information, including a full list of task names and their precise meanings or sources, follow the links provided to the individual README.md files for each subfolder.

| Task Family | Description | Language(s) |
|-------------|-------------|-------------|
| [aclue](aclue/README.md) | Tasks focusing on ancient Chinese language understanding and cultural aspects. | Ancient Chinese |
| [aexams](aexams/README.md) | Tasks in Arabic related to various academic exams covering a range of subjects. | Arabic |
| [agieval](agieval/README.md) | Tasks involving historical data or questions related to history and historical texts. | English, Chinese |
| [anli](anli/README.md) | Adversarial natural language inference tasks designed to test model robustness. | English |
| [arabicmmlu](arabicmmlu/README.md) | Localized Arabic version of MMLU with multiple-choice questions from 40 subjects. | Arabic |
| [arc](arc/README.md) | Tasks involving complex reasoning over a diverse set of questions. | English |
| [arithmetic](arithmetic/README.md) | Tasks involving numerical computations and arithmetic reasoning. | English |
| [asdiv](asdiv/README.md) | Tasks involving arithmetic and mathematical reasoning challenges. | English |
| [babi](babi/README.md) | Tasks designed as question and answering challenges based on simulated stories. | English |
| [basqueglue](basqueglue/README.md) | Tasks designed to evaluate language understanding in Basque language. | Basque |
| [bbh](bbh/README.md) | Tasks focused on deep semantic understanding through hypothesization and reasoning. | English, German |
| [belebele](belebele/README.md) | Language understanding tasks in a variety of languages and scripts. | Multiple (122 languages) |
| benchmarks | General benchmarking tasks that test a wide range of language understanding capabilities. | |
| [bertaqa](bertaqa/README.md) | Local Basque cultural trivia QA tests in English and Basque languages. | English, Basque, Basque (MT) |
| [bigbench](bigbench/README.md) | Broad tasks from the BIG-bench benchmark designed to push the boundaries of large models. | Multiple |
| [blimp](blimp/README.md) | Tasks testing grammatical phenomena to evaluate language model's linguistic capabilities. | English |
| [ceval](ceval/README.md) | Tasks that evaluate language understanding and reasoning in an educational context. | Chinese |
| [cmmlu](cmmlu/README.md) | Multi-subject multiple choice question tasks for comprehensive academic assessment. | Chinese |
| code_x_glue | Tasks that involve understanding and generating code across multiple programming languages. | Go, Java, JS, PHP, Python, Ruby |
| [commonsense_qa](commonsense_qa/README.md) | CommonsenseQA, a multiple-choice QA dataset for measuring commonsense knowledge. | English |
| [copal_id](copal_id/README.md) | Indonesian causal commonsense reasoning dataset that captures local nuances. | Indonesian |
| [coqa](coqa/README.md) | Conversational question answering tasks to test dialog understanding. | English |
| [crows_pairs](crows_pairs/README.md) | Tasks designed to test model biases in various sociodemographic groups. | English, French |
| csatqa | Tasks related to SAT and other standardized testing questions for academic assessment. | Korean |
| [drop](drop/README.md) | Tasks requiring numerical reasoning, reading comprehension, and question answering. | English |
| [eq_bench](eq_bench/README.md) | Tasks focused on equality and ethics in question answering and decision-making. | English |
| [eus_exams](eus_exams/README.md) | Tasks based on various professional and academic exams in the Basque language. | Basque |
| [eus_proficiency](eus_proficiency/README.md) | Tasks designed to test proficiency in the Basque language across various topics. | Basque |
| [eus_reading](eus_reading/README.md) | Reading comprehension tasks specifically designed for the Basque language. | Basque |
| [eus_trivia](eus_trivia/README.md) | Trivia and knowledge testing tasks in the Basque language. | Basque |
| [fda](fda/README.md) | Tasks for extracting key-value pairs from FDA documents to test information extraction. | English |
| [fld](fld/README.md) | Tasks involving free-form and directed dialogue understanding. | English |
| [french_bench](french_bench/README.md) | Set of tasks designed to assess language model performance in French. | French |
| [glue](glue/README.md) | General Language Understanding Evaluation benchmark to test broad language abilities. | English |
| [gpqa](gpqa/README.md) | Tasks designed for general public question answering and knowledge verification. | English |
| [gsm8k](gsm8k/README.md) | A benchmark of grade school math problems aimed at evaluating reasoning capabilities. | English |
| [haerae](haerae/README.md) | Tasks focused on assessing detailed factual and historical knowledge. | Korean |
| [headqa](headqa/README.md) | A high-level education-based question answering dataset to test specialized knowledge. | Spanish, English |
| [hellaswag](hellaswag/README.md) | Tasks to predict the ending of stories or scenarios, testing comprehension and creativity. | English |
| [hendrycks_ethics](hendrycks_ethics/README.md) | Tasks designed to evaluate the ethical reasoning capabilities of models. | English |
| [hendrycks_math](hendrycks_math/README.md) | Mathematical problem-solving tasks to test numerical reasoning and problem-solving. | English |
| [ifeval](ifeval/README.md) | Interactive fiction evaluation tasks for narrative understanding and reasoning. | English |
| [inverse_scaling](inverse_scaling/README.md) | Multiple-choice tasks from the Inverse Scaling Prize, designed to find settings where larger language models perform worse. | English |
| [kmmlu](kmmlu/README.md) | Knowledge-based multi-subject multiple choice questions for academic evaluation. | Korean |
| [kobest](kobest/README.md) | A collection of tasks designed to evaluate understanding in Korean language. | Korean |
| [kormedmcqa](kormedmcqa/README.md) | Medical question answering tasks in Korean to test specialized domain knowledge. | Korean |
| [lambada](lambada/README.md) | Tasks designed to predict the endings of text passages, testing language prediction skills. | English |
| [lambada_cloze](lambada_cloze/README.md) | Cloze-style LAMBADA dataset. | English |
| [lambada_multilingual](lambada_multilingual/README.md) | Multilingual LAMBADA dataset. This is a legacy version of the multilingual dataset, and users should instead use `lambada_multilingual_stablelm`. | German, English, Spanish, French, Italian |
| [lambada_multilingual_stablelm](lambada_multilingual_stablelm/README.md) | Multilingual LAMBADA dataset. Users should prefer evaluating on this version of the multilingual dataset instead of on `lambada_multilingual`. | German, English, Spanish, French, Italian, Dutch, Portuguese |
| [leaderboard](leaderboard/README.md) | Task group used by Hugging Face's [Open LLM Leaderboard v2](https://huggingface.co/spaces/open-llm-leaderboard/open_llm_leaderboard). Those tasks are static and will not change through time | English |
| [logiqa](logiqa/README.md) | Logical reasoning tasks requiring advanced inference and deduction. | English, Chinese |
| [logiqa2](logiqa2/README.md) | Large-scale logical reasoning dataset adapted from the Chinese Civil Service Examination. | English, Chinese |
| [mathqa](mathqa/README.md) | Question answering tasks involving mathematical reasoning and problem-solving. | English |
| [mc_taco](mc_taco/README.md) | Question-answer pairs that require temporal commonsense comprehension. | English |
| [med_concepts_qa](med_concepts_qa/README.md) | Benchmark for evaluating LLMs on their abilities to interpret medical codes and distinguish between medical concepts. | English |
| medmcqa | Medical multiple choice questions assessing detailed medical knowledge. | English |
| medqa | Multiple choice question answering based on the United States Medical License Exams. | |
| [mgsm](mgsm/README.md) | Benchmark of multilingual grade-school math problems. | Spanish, French, German, Russian, Chinese, Japanese, Thai, Swahili, Bengali, Telugu |
| [minerva_math](minerva_math/README.md) | Mathematics-focused tasks requiring numerical reasoning and problem-solving skills. | English |
| mmlu | Massive Multitask Language Understanding benchmark for broad domain language evaluation. Several variants are supported. | English |
| [mmlusr](mmlusr/README.md) | Variation of MMLU designed to be more rigorous. | English |
| model_written_evals | Evaluation tasks auto-generated for evaluating a collection of AI Safety concerns. | |
| [mutual](mutual/README.md) | A retrieval-based dataset for multi-turn dialogue reasoning. | English |
| [nq_open](nq_open/README.md) | Open domain question answering tasks based on the Natural Questions dataset. | English |
| [okapi/arc_multilingual](okapi/arc_multilingual/README.md) | Tasks that involve reading comprehension and information retrieval challenges. | Multiple (31 languages) **Machine Translated.** |
| [okapi/hellaswag_multilingual](okapi/hellaswag_multilingual/README.md) | Tasks that involve reading comprehension and information retrieval challenges. | Multiple (30 languages) **Machine Translated.** |
| okapi/mmlu_multilingual | Tasks that involve reading comprehension and information retrieval challenges. | Multiple (34 languages) **Machine Translated.** |
| [okapi/truthfulqa_multilingual](okapi/truthfulqa_multilingual/README.md) | Tasks that involve reading comprehension and information retrieval challenges. | Multiple (31 languages) **Machine Translated.** |
| [openbookqa](openbookqa/README.md) | Open-book question answering tasks that require external knowledge and reasoning. | English |
| [paloma](paloma/README.md) | Paloma is a comprehensive benchmark designed to evaluate open language models across a wide range of domains, ranging from niche artist communities to mental health forums on Reddit. | English |
| [paws-x](paws-x/README.md) | Paraphrase Adversaries from Word Scrambling, focusing on cross-lingual capabilities. | English, French, Spanish, German, Chinese, Japanese, Korean |
| [pile](pile/README.md) | Open source language modelling data set that consists of 22 smaller, high-quality datasets. | English |
| [pile_10k](pile_10k/README.md) | The first 10K elements of The Pile, useful for debugging models trained on it. | English |
| [piqa](piqa/README.md) | Physical Interaction Question Answering tasks to test physical commonsense reasoning. | English |
| [polemo2](polemo2/README.md) | Sentiment analysis and emotion detection tasks based on Polish language data. | Polish |
| [prost](prost/README.md) | Tasks requiring understanding of professional standards and ethics in various domains. | English |
| [pubmedqa](pubmedqa/README.md) | Question answering tasks based on PubMed research articles for biomedical understanding. | English |
| [qa4mre](qa4mre/README.md) | Question Answering for Machine Reading Evaluation, assessing comprehension and reasoning. | English |
| [qasper](qasper/README.md) | Question Answering dataset based on academic papers, testing in-depth scientific knowledge. | English |
| [race](race/README.md) | Reading comprehension assessment tasks based on English exams in China. | English |
| realtoxicityprompts | Tasks to evaluate language models for generating text with potential toxicity. | |
| [sciq](sciq/README.md) | Science Question Answering tasks to assess understanding of scientific concepts. | English |
| [scrolls](scrolls/README.md) | Tasks that involve long-form reading comprehension across various domains. | English |
| [siqa](siqa/README.md) | Social Interaction Question Answering to evaluate common sense and social reasoning. | English |
| [squad_completion](squad_completion/README.md) | A variant of the SQuAD question answering task designed for zero-shot evaluation of small LMs. | English |
| [squadv2](squadv2/README.md) | Stanford Question Answering Dataset version 2, a reading comprehension benchmark. | English |
| [storycloze](storycloze/README.md) | Tasks to predict story endings, focusing on narrative logic and coherence. | English |
| [super_glue](super_glue/README.md) | A suite of challenging tasks designed to test a range of language understanding skills. | English |
| [swag](swag/README.md) | Situations With Adversarial Generations, predicting the next event in videos. | English |
| [swde](swde/README.md) | Information extraction tasks from semi-structured web pages. | English |
| [tinyBenchmarks](tinyBenchmarks/README.md) | Evaluation of large language models with fewer examples using tiny versions of popular benchmarks. | English |
| [tmmluplus](tmmluplus/README.md) | An extended set of tasks under the TMMLU framework for broader academic assessments. | Traditional Chinese |
| [toxigen](toxigen/README.md) | Tasks designed to evaluate language models on their propensity to generate toxic content. | English |
| [translation](translation/README.md) | Tasks focused on evaluating the language translation capabilities of models. | Arabic, English, Spanish, Basque, Hindi, Indonesian, Burmese, Russian, Swahili, Telugu, Chinese |
| [triviaqa](triviaqa/README.md) | A large-scale dataset for trivia question answering to test general knowledge. | English |
| [truthfulqa](truthfulqa/README.md) | A QA task aimed at evaluating the truthfulness and factual accuracy of model responses. | English |
| [unitxt](unitxt/README.md) | A number of tasks implemented using the unitxt library for flexible, shareable, and reusable data preparation and evaluation for generative AI. | English |
| [unscramble](unscramble/README.md) | Tasks involving the rearrangement of scrambled sentences to test syntactic understanding. | English |
| [webqs](webqs/README.md) | Web-based question answering tasks designed to evaluate internet search and retrieval. | English |
| [wikitext](wikitext/README.md) | Tasks based on text from Wikipedia articles to assess language modeling and generation. | English |
| [winogrande](winogrande/README.md) | A large-scale dataset for coreference resolution, inspired by the Winograd Schema Challenge. | English |
| [wmdp](wmdp/README.md) | A benchmark with the objective of minimizing performance, based on potentially-sensitive multiple-choice knowledge questions. | English |
| [wmt2016](wmt2016/README.md) | Tasks from the WMT 2016 shared task, focusing on translation between multiple languages. | English, Czech, German, Finnish, Russian, Romanian, Turkish |
| [wsc273](wsc273/README.md) | The Winograd Schema Challenge, a test of commonsense reasoning and coreference resolution. | English |
| [xcopa](xcopa/README.md) | Cross-lingual Choice of Plausible Alternatives, testing reasoning in multiple languages. | Estonian, Haitian, Indonesian, Italian, Quechua, Swahili, Tamil, Thai, Turkish, Vietnamese, Chinese |
| [xnli](xnli/README.md) | Cross-Lingual Natural Language Inference to test understanding across different languages. | Arabic, Bulgarian, German, Greek, English, Spanish, French, Hindi, Russian, Swahili, Thai, Turkish, Urdu, Vietnamese, Chinese |
| [xnli_eu](xnli_eu/README.md) | Cross-lingual Natural Language Inference tasks in Basque. | Basque |
| [xstorycloze](xstorycloze/README.md) | Cross-lingual narrative understanding tasks to predict story endings in multiple languages. | Russian, Simplified Chinese, Spanish, Arabic, Hindi, Indonesian, Telugu, Swahili, Basque, Burmese |
| [xwinograd](xwinograd/README.md) | Cross-lingual Winograd schema tasks for coreference resolution in multiple languages. | English, French, Japanese, Portuguese, Russian, Chinese |
scripts/yans/lm-evaluation-harness/lm_eval/tasks/__init__.py
ADDED
@@ -0,0 +1,650 @@
import collections
import inspect
import logging
import os
from functools import partial
from typing import Dict, List, Mapping, Optional, Union

from lm_eval import utils
from lm_eval.api.group import ConfigurableGroup, GroupConfig
from lm_eval.api.task import ConfigurableTask, Task
from lm_eval.evaluator_utils import get_subtask_list


GROUP_ONLY_KEYS = list(GroupConfig().to_dict().keys())


class TaskManager:
    """TaskManager indexes all tasks from the default `lm_eval/tasks/`
    and an optional directory if provided.
    """

    def __init__(
        self,
        verbosity="INFO",
        include_path: Optional[Union[str, List]] = None,
        include_defaults: bool = True,
    ) -> None:
        self.verbosity = verbosity
        self.include_path = include_path
        self.logger = utils.eval_logger
        self.logger.setLevel(getattr(logging, f"{verbosity}"))

        self._task_index = self.initialize_tasks(
            include_path=include_path, include_defaults=include_defaults
        )
        self._all_tasks = sorted(list(self._task_index.keys()))

        self._all_groups = sorted(
            [x for x in self._all_tasks if self._task_index[x]["type"] == "group"]
        )
        self._all_subtasks = sorted(
            [x for x in self._all_tasks if self._task_index[x]["type"] == "task"]
        )
        self._all_tags = sorted(
            [x for x in self._all_tasks if self._task_index[x]["type"] == "tag"]
        )

        self.task_group_map = collections.defaultdict(list)

    def initialize_tasks(
        self,
        include_path: Optional[Union[str, List]] = None,
        include_defaults: bool = True,
    ):
        """Creates a dictionary of tasks index.

        :param include_path: Union[str, List] = None
            An additional path to be searched for tasks recursively.
            Can provide more than one such path as a list.
        :param include_defaults: bool = True
            If set to false, default tasks (those in lm_eval/tasks/) are not indexed.
        :return
            Dictionary of task names as key and task metadata
        """
        if include_defaults:
            all_paths = [os.path.dirname(os.path.abspath(__file__)) + "/"]
        else:
            all_paths = []
        if include_path is not None:
            if isinstance(include_path, str):
                include_path = [include_path]
            all_paths.extend(include_path)

        task_index = {}
        for task_dir in all_paths:
            tasks = self._get_task_and_group(task_dir)
            task_index = {**tasks, **task_index}

        return task_index

    @property
    def all_tasks(self):
        return self._all_tasks

    @property
    def all_groups(self):
        return self._all_groups

    @property
    def all_subtasks(self):
        return self._all_subtasks

    @property
    def all_tags(self):
        return self._all_tags

    @property
    def task_index(self):
        return self._task_index

    def list_all_tasks(
        self, list_groups=True, list_tags=True, list_subtasks=True
    ) -> str:
        from pytablewriter import MarkdownTableWriter

        def sanitize_path(path):
            # don't print full path if we are within the lm_eval/tasks dir !
            # if we aren't though, provide the full path.
            if "lm_eval/tasks/" in path:
                return "lm_eval/tasks/" + path.split("lm_eval/tasks/")[-1]
            else:
                return path

        group_table = MarkdownTableWriter()
        group_table.headers = ["Group", "Config Location"]
        gt_values = []
        for g in self.all_groups:
            path = self.task_index[g]["yaml_path"]
            if path == -1:
                path = "---"
            else:
                path = sanitize_path(path)
            gt_values.append([g, path])
        group_table.value_matrix = gt_values

        tag_table = MarkdownTableWriter()
        tag_table.headers = ["Tag"]
        tag_table.value_matrix = [[t] for t in self.all_tags]

        subtask_table = MarkdownTableWriter()
        subtask_table.headers = ["Task", "Config Location", "Output Type"]
        st_values = []
        for t in self.all_subtasks:
            path = self.task_index[t]["yaml_path"]

            output_type = ""

            # read the yaml file to determine the output type
            if path != -1:
                config = utils.load_yaml_config(path, mode="simple")
                if "output_type" in config:
                    output_type = config["output_type"]
                elif (
                    "include" in config
                ):  # if no output type, check if there is an include with an output type
                    include_path = path.split("/")[:-1] + config["include"]
                    include_config = utils.load_yaml_config(include_path, mode="simple")
                    if "output_type" in include_config:
                        output_type = include_config["output_type"]

            if path == -1:
                path = "---"
            else:
                path = sanitize_path(path)
            st_values.append([t, path, output_type])
        subtask_table.value_matrix = st_values

        result = "\n"
        if list_groups:
            result += group_table.dumps() + "\n\n"
        if list_tags:
            result += tag_table.dumps() + "\n\n"
        if list_subtasks:
            result += subtask_table.dumps() + "\n\n"
        return result

    def match_tasks(self, task_list):
        return utils.pattern_match(task_list, self.all_tasks)

    def _name_is_registered(self, name) -> bool:
        if name in self.all_tasks:
            return True
        return False

    def _name_is_task(self, name) -> bool:
        if self._name_is_registered(name) and (self.task_index[name]["type"] == "task"):
            return True
        return False

    def _name_is_tag(self, name) -> bool:
        if self._name_is_registered(name) and (self.task_index[name]["type"] == "tag"):
            return True
        return False

    def _name_is_group(self, name) -> bool:
        if self._name_is_registered(name) and (
            self.task_index[name]["type"] == "group"
        ):
            return True
        return False

    def _name_is_python_task(self, name):
        if self._name_is_registered(name) and (
            self.task_index[name]["type"] == "python_task"
        ):
            return True
        return False

    def _config_is_task(self, config) -> bool:
        if ("task" in config) and isinstance(config["task"], str):
            return True
        return False

    def _config_is_group(self, config) -> bool:
        if ("task" in config) and isinstance(config["task"], list):
            return True
        return False

    def _config_is_python_task(self, config) -> bool:
        if "class" in config:
            return True
        return False

    def _get_yaml_path(self, name):
        if name not in self.task_index:
            raise ValueError
        return self.task_index[name]["yaml_path"]

    def _get_config(self, name):
        if name not in self.task_index:
            raise ValueError
        yaml_path = self._get_yaml_path(name)
        if yaml_path == -1:
            return {}
        else:
            return utils.load_yaml_config(yaml_path, mode="full")

    def _get_tasklist(self, name):
        if self._name_is_task(name):
            raise ValueError
        return self.task_index[name]["task"]

    def _process_alias(self, config, group=None):
        # If the group is not the same as the original
        # group which the group alias was intended for,
        # Set the group_alias to None instead.
        if ("group_alias" in config) and ("group" in config) and group is not None:
            if config["group"] != group:
                config["group_alias"] = None
        return config

    def _class_has_config_in_constructor(self, cls):
        constructor = getattr(cls, "__init__", None)
        return (
            "config" in inspect.signature(constructor).parameters
            if constructor
            else False
        )

    def _load_individual_task_or_group(
        self,
        name_or_config: Optional[Union[str, dict]] = None,
        parent_name: Optional[str] = None,
        update_config: Optional[dict] = None,
    ) -> Mapping:
        def _load_task(config, task):
            if "include" in config:
                config = {
                    **utils.load_yaml_config(
                        yaml_path=None,
                        yaml_config={"include": config.pop("include")},
                        mode="full",
                    ),
                    **config,
                }
            if self._config_is_python_task(config):
                if self._class_has_config_in_constructor(config["class"]):
                    task_object = config["class"](config=config)
                else:
                    task_object = config["class"]()
                if isinstance(task_object, ConfigurableTask):
                    # very scuffed: set task name here. TODO: fixme?
                    task_object.config.task = config["task"]
            else:
                task_object = ConfigurableTask(config=config)

            return {task: task_object}

        def _get_group_and_subtask_from_config(config):
            group_name = ConfigurableGroup(config=config)
            subtask_list = []
            for task in group_name.config["task"]:
                if isinstance(task, str) and self._name_is_tag(task):
                    subtask_list.extend(self._get_tasklist(task))
                else:
                    subtask_list.append(task)
            return group_name, subtask_list

        def _process_group_config(config, update_config=None):
            if update_config is not None:
                config = {**config, **update_config}
            _update_config = {
                k: v for k, v in config.items() if k not in GROUP_ONLY_KEYS
            }
            if not bool(_update_config):
                _update_config = None

            group_config = {k: v for k, v in config.items() if k in GROUP_ONLY_KEYS}
            return group_config, _update_config

        if isinstance(name_or_config, str):
            if update_config is not None:
                # Process name_or_config as a dict instead
                name_or_config = {"task": name_or_config, **update_config}
            elif self._name_is_task(name_or_config) or self._name_is_python_task(
                name_or_config
            ):
                task_config = self._get_config(name_or_config)
                return _load_task(task_config, task=name_or_config)
            else:
                subtask_list = self._get_tasklist(name_or_config)
                if subtask_list == -1:
                    group_config = self._get_config(name_or_config)
                    group_config, update_config = _process_group_config(group_config)
                    group_name, subtask_list = _get_group_and_subtask_from_config(
                        group_config
                    )
                else:
                    if self._name_is_tag(name_or_config):
                        fn = partial(
                            self._load_individual_task_or_group,
                            update_config=name_or_config
                            if isinstance(name_or_config, dict)
                            else None,
                        )
                        return dict(
                            collections.ChainMap(*map(fn, reversed(subtask_list)))
                        )
                    else:
                        group_name = ConfigurableGroup(
                            config={"group": name_or_config, "task": subtask_list}
                        )

        if isinstance(name_or_config, dict):
            if self._config_is_task(name_or_config):
                name = name_or_config.pop("task")
                if update_config is not None:
                    name_or_config = {**name_or_config, **update_config}
                # If the name is registered as a group
                if self._name_is_group(name):
                    group_config = self._get_config(name)

                    group_config, update_config = _process_group_config(
                        group_config, name_or_config
                    )
                    group_name, subtask_list = _get_group_and_subtask_from_config(
                        group_config
                    )
                elif self._name_is_tag(name):
                    subtask_list = self._get_tasklist(name)
                    fn = partial(
                        self._load_individual_task_or_group,
                        update_config=name_or_config,
                    )
                    return dict(collections.ChainMap(*map(fn, reversed(subtask_list))))
                else:
                    if self._name_is_registered(name):
                        base_task_config = self._get_config(name)

                        # Check if this is a duplicate.
                        if parent_name is not None:
                            num_duplicate = len(
                                list(
                                    filter(
                                        lambda x: x.startswith(name),
                                        self.task_group_map[parent_name],
                                    )
                                )
                            )
                            if num_duplicate > 0:
                                name = f"{name}-{num_duplicate}"
                            self.task_group_map[parent_name].append(name)

                        task_config = {
                            **base_task_config,
                            **name_or_config,
                        }
                    else:
                        task_config = name_or_config
                    return _load_task(task_config, task=name)
            else:
                group_config, update_config = _process_group_config(name_or_config)
                group_name, subtask_list = _get_group_and_subtask_from_config(
                    group_config
                )

        fn = partial(
            self._load_individual_task_or_group,
            parent_name=group_name,
            update_config=update_config,
        )
        return {
            group_name: dict(collections.ChainMap(*map(fn, reversed(subtask_list))))
        }

    def load_task_or_group(self, task_list: Optional[Union[str, list]] = None) -> dict:
        """Loads a dictionary of task objects from a list

        :param task_list: Union[str, list] = None
            Single string or list of string of task names to be loaded

        :return
            Dictionary of task objects
        """
        if isinstance(task_list, str):
            task_list = [task_list]

        all_loaded_tasks = dict(
            collections.ChainMap(*map(self._load_individual_task_or_group, task_list))
        )
        return all_loaded_tasks

    def load_config(self, config: Dict):
        return self._load_individual_task_or_group(config)

    def _get_task_and_group(self, task_dir: str):
        """Creates a dictionary of tasks index with the following metadata,
        - `type`, that can be either `task`, `python_task`, `group` or `tags`.
            `task` refer to regular task configs, `python_task` are special
            yaml files that only consists of `task` and `class` parameters.
            `group` are group configs. `tags` are labels that can be assigned
            to tasks to assist in sorting and calling tasks of certain themes.
        - `yaml_path`, path to the yaml file. If the entry is a `group` that
            was configured through a task config, the yaml_path will be -1
            and all subtasks will be listed in `task` (see below)
        - `task`, reserved for entries with `type` as `group`. This will list
            all subtasks. When a group config is created (as opposed to task
            config having `group` parameter set), this will be set to -1 to
            avoid recursive indexing. The whole list of subtasks will be loaded
            at evaluation.

        :param task_dir: str
            A directory to check for tasks

        :return
            Dictionary of task names as key and task metadata
        """
        # TODO: remove group in next release
        print_info = True
        ignore_dirs = [
            "__pycache__",
            ".ipynb_checkpoints",
        ]
        tasks_and_groups = collections.defaultdict()
        for root, dirs, file_list in os.walk(task_dir):
            dirs[:] = [d for d in dirs if d not in ignore_dirs]
            for f in file_list:
                if f.endswith(".yaml"):
                    yaml_path = os.path.join(root, f)
                    config = utils.load_yaml_config(yaml_path, mode="simple")
                    if self._config_is_python_task(config):
                        # This is a python class config
                        tasks_and_groups[config["task"]] = {
                            "type": "python_task",
                            "yaml_path": yaml_path,
                        }
                    elif self._config_is_group(config):
                        # This is a group config
                        tasks_and_groups[config["group"]] = {
                            "type": "group",
                            "task": -1,  # This signals that
                            # we don't need to know
                            # the task list for indexing
                            # as it can be loaded
                            # when called.
                            "yaml_path": yaml_path,
                        }

                        # # Registered the level 1 tasks from a group config
                        # for config in config["task"]:
                        #     if isinstance(config, dict) and self._config_is_task(config):
                        #         task = config["task"]
                        #         tasks_and_groups[task] = {
                        #             "type": "task",
                        #             "yaml_path": yaml_path,
                        #         }

                    elif self._config_is_task(config):
                        # This is a task config
                        task = config["task"]
                        tasks_and_groups[task] = {
                            "type": "task",
                            "yaml_path": yaml_path,
                        }

                        # TODO: remove group in next release
                        for attr in ["tag", "group"]:
                            if attr in config:
                                if attr == "group" and print_info:
                                    self.logger.info(
                                        "`group` and `group_alias` keys in tasks' configs will no longer be used in the next release of lm-eval. "
                                        "`tag` will be used to allow to call a collection of tasks just like `group`. "
                                        "`group` will be removed in order to not cause confusion with the new ConfigurableGroup "
                                        "which will be the official way to create groups with addition of group-wide configurations."
                                    )
                                    print_info = False
                                    # attr = "tag"

                                attr_list = config[attr]
                                if isinstance(attr_list, str):
                                    attr_list = [attr_list]

                                for tag in attr_list:
                                    if tag not in tasks_and_groups:
                                        tasks_and_groups[tag] = {
                                            "type": "tag",
                                            "task": [task],
                                            "yaml_path": -1,
                                        }
                                    elif tasks_and_groups[tag]["type"] != "tag":
                                        self.logger.info(
                                            f"The tag {tag} is already registered as a group, this tag will not be registered. "
                                            "This may affect tasks you want to call."
                                        )
                                        break
                                    else:
                                        tasks_and_groups[tag]["task"].append(task)
                    else:
                        self.logger.debug(f"File {f} in {root} could not be loaded")

        return tasks_and_groups


def get_task_name_from_config(task_config: Dict[str, str]) -> str:
    if "task" in task_config:
        return task_config["task"]
    if "dataset_name" in task_config:
        return "{dataset_path}_{dataset_name}".format(**task_config)
    else:
        return "{dataset_path}".format(**task_config)


def get_task_name_from_object(task_object):
    if hasattr(task_object, "config"):
        return task_object._config["task"]

    # TODO: scrap this
    # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
    return (
        task_object.EVAL_HARNESS_NAME
        if hasattr(task_object, "EVAL_HARNESS_NAME")
        else type(task_object).__name__
    )


def _check_duplicates(task_dict: dict) -> List[str]:
    """helper function solely used in validating get_task_dict output.
    Takes the output of lm_eval.evaluator_utils.get_subtask_list and
    returns a list of all leaf subtasks contained within, and errors if any such leaf subtasks are
    "oversubscribed" to several disjoint groups.
    """
    subtask_names = []
    for key, value in task_dict.items():
        subtask_names.extend(value)

    duplicate_tasks = {
        task_name for task_name in subtask_names if subtask_names.count(task_name) > 1
    }

    # locate the potentially problematic groups that seem to 'compete' for constituent subtasks
    competing_groups = [
        group
        for group in task_dict.keys()
        if len(set(task_dict[group]).intersection(duplicate_tasks)) > 0
    ]

    if len(duplicate_tasks) > 0:
        raise ValueError(
            f"Found 1 or more tasks while trying to call get_task_dict() that were members of more than 1 called group: {list(duplicate_tasks)}. Offending groups: {competing_groups}. Please call groups which overlap their constituent tasks in separate evaluation runs."
        )


def get_task_dict(
    task_name_list: Union[str, List[Union[str, Dict, Task]]],
    task_manager: Optional[TaskManager] = None,
):
    """Creates a dictionary of task objects from either a name of task, config, or prepared Task object.

    :param task_name_list: List[Union[str, Dict, Task]]
        Name of model or LM object, see lm_eval.models.get_model
    :param task_manager: TaskManager = None
        A TaskManager object that stores indexed tasks. If not set,
        task_manager will load one. This should be set by the user
        if there are additional paths that want to be included
        via `include_path`

    :return
        Dictionary of task objects
    """

    task_name_from_string_dict = {}
    task_name_from_config_dict = {}
    task_name_from_object_dict = {}

    if isinstance(task_name_list, str):
        task_name_list = [task_name_list]
    elif isinstance(task_name_list, list):
        if not all([isinstance(task, (str, dict, Task)) for task in task_name_list]):
            raise TypeError(
                "Expected all list items to be of types 'str', 'dict', or 'Task', but at least one entry did not match."
            )
    else:
        raise TypeError(
            f"Expected a 'str' or 'list' but received {type(task_name_list)}."
        )

    string_task_name_list = [task for task in task_name_list if isinstance(task, str)]
    others_task_name_list = [
        task for task in task_name_list if not isinstance(task, str)
    ]
    if len(string_task_name_list) > 0:
        if task_manager is None:
            task_manager = TaskManager()

        task_name_from_string_dict = task_manager.load_task_or_group(
            string_task_name_list
        )

    for task_element in others_task_name_list:
        if isinstance(task_element, dict):
            task_name_from_config_dict = {
                **task_name_from_config_dict,
                **task_manager.load_config(config=task_element),
            }

        elif isinstance(task_element, Task):
            task_name_from_object_dict = {
                **task_name_from_object_dict,
                get_task_name_from_object(task_element): task_element,
            }

    if not set(task_name_from_string_dict.keys()).isdisjoint(
        set(task_name_from_object_dict.keys())
    ):
        raise ValueError

    final_task_dict = {
        **task_name_from_string_dict,
        **task_name_from_config_dict,
        **task_name_from_object_dict,
    }

    # behavior can get odd if one tries to invoke several groups that "compete" for the same task.
    # (notably, because one could request several num_fewshot values at once in GroupConfig overrides for the subtask
    # and we'd be unsure which to use and report.)
    # we explicitly check and error in this case.
    _check_duplicates(get_subtask_list(final_task_dict))

    return final_task_dict
scripts/yans/lm-evaluation-harness/lm_eval/tasks/anli/README.md
ADDED
@@ -0,0 +1,56 @@
# ANLI

### Paper

Title: `Adversarial NLI: A New Benchmark for Natural Language Understanding`

Paper Link: https://arxiv.org/abs/1910.14599

Adversarial NLI (ANLI) is a dataset collected via an iterative, adversarial
human-and-model-in-the-loop procedure. It consists of three rounds that progressively
increase in difficulty and complexity, and each question-answer includes annotator-
provided explanations.

Homepage: https://github.com/facebookresearch/anli

### Citation

```
@inproceedings{nie-etal-2020-adversarial,
    title = "Adversarial {NLI}: A New Benchmark for Natural Language Understanding",
    author = "Nie, Yixin and
      Williams, Adina and
      Dinan, Emily and
      Bansal, Mohit and
      Weston, Jason and
      Kiela, Douwe",
    booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics",
    year = "2020",
    publisher = "Association for Computational Linguistics",
}
```

### Groups and Tasks

#### Groups

* `anli`: Evaluates `anli_r1`, `anli_r2`, and `anli_r3`

#### Tasks
* `anli_r1`: The data collected adversarially in the first round.
* `anli_r2`: The data collected adversarially in the second round, after training on the previous round's data.
* `anli_r3`: The data collected adversarially in the third round, after training on the previous multiple rounds of data.


### Checklist

For adding novel benchmarks/datasets to the library:
* [x] Is the task an existing benchmark in the literature?
* [x] Have you referenced the original paper that introduced the task?
* [ ] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?


If other tasks on this dataset are already supported:
* [ ] Is the "Main" variant of this task clearly denoted?
* [x] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
scripts/yans/lm-evaluation-harness/lm_eval/tasks/anli/anli_r1.yaml
ADDED
@@ -0,0 +1,26 @@
tag:
  - anli
task: anli_r1
dataset_path: anli
dataset_name: null
output_type: multiple_choice
training_split: train_r1
validation_split: dev_r1
test_split: test_r1
doc_to_text: "{{premise}}\nQuestion: {{hypothesis}} True, False, or Neither?\nAnswer:"
# True = entailment
# False = contradiction
# Neither = neutral
doc_to_target: "{{['True', 'Neither', 'False'][label]}}"
doc_to_choice:
  - "True"
  - "Neither"
  - "False"
should_decontaminate: true
doc_to_decontamination_query: premise
metric_list:
  - metric: acc
    aggregation: mean
    higher_is_better: true
metadata:
  version: 1.0
scripts/yans/lm-evaluation-harness/lm_eval/tasks/anli/anli_r2.yaml
ADDED
@@ -0,0 +1,5 @@
include: anli_r1.yaml
task: anli_r2
training_split: train_r2
validation_split: dev_r2
test_split: test_r2
scripts/yans/lm-evaluation-harness/lm_eval/tasks/anli/anli_r3.yaml
ADDED
@@ -0,0 +1,5 @@
include: anli_r1.yaml
task: anli_r3
training_split: train_r3
validation_split: dev_r3
test_split: test_r3
scripts/yans/lm-evaluation-harness/lm_eval/tasks/drop/README.md
ADDED
@@ -0,0 +1,53 @@
# DROP

### Paper

Title: `DROP: A Reading Comprehension Benchmark Requiring Discrete Reasoning Over Paragraphs`

Abstract: https://aclanthology.org/attachments/N19-1246.Supplementary.pdf

DROP is a QA dataset which tests comprehensive understanding of paragraphs. In
this crowdsourced, adversarially-created, 96k question-answering benchmark, a
system must resolve multiple references in a question, map them onto a paragraph,
and perform discrete operations over them (such as addition, counting, or sorting).

Homepage: https://allenai.org/data/drop

Acknowledgement: This implementation is based on the official evaluation for `DROP`:
https://github.com/allenai/allennlp-reading-comprehension/blob/master/allennlp_rc/eval/drop_eval.py

### Citation

```
@misc{dua2019drop,
    title={DROP: A Reading Comprehension Benchmark Requiring Discrete Reasoning Over Paragraphs},
    author={Dheeru Dua and Yizhong Wang and Pradeep Dasigi and Gabriel Stanovsky and Sameer Singh and Matt Gardner},
    year={2019},
    eprint={1903.00161},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
```

### Groups and Tasks

#### Groups

* Not part of a group yet.

#### Tasks

* `drop`

### Checklist

For adding novel benchmarks/datasets to the library:
* [ ] Is the task an existing benchmark in the literature?
* [ ] Have you referenced the original paper that introduced the task?
* [ ] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?


If other tasks on this dataset are already supported:
* [ ] Is the "Main" variant of this task clearly denoted?
* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
scripts/yans/lm-evaluation-harness/lm_eval/tasks/drop/default.yaml
ADDED
@@ -0,0 +1,26 @@
task: drop
dataset_path: EleutherAI/drop
output_type: generate_until
training_split: train
validation_split: validation
process_docs: !function utils.process_docs
doc_to_text: "{{passage}} {{question}}"
doc_to_target: "{{ answer|join(',')}}"
target_delimiter: ""
process_results: !function utils.process_results
should_decontaminate: true
doc_to_decontamination_query: "{{passage}} {{question}}"
generation_kwargs:
  until:
    - "."
metric_list:
  - metric: em
    aggregation: mean
    higher_is_better: true
  - metric: f1
    aggregation: mean
    higher_is_better: true
metadata:
  version: 3.0
dataset_kwargs:
  trust_remote_code: true
scripts/yans/lm-evaluation-harness/lm_eval/tasks/drop/utils.py
ADDED
@@ -0,0 +1,205 @@
import re
import string

import numpy as np


_ARTICLES = re.compile(r"\b(a|an|the)\b", re.UNICODE)


def process_docs(dataset):
    def _process(doc):
        return {
            "id": doc["query_id"],
            "passage": doc["passage"],
            "question": doc["question"],
            "answers": get_answers(doc),
        }

    return dataset.map(_process)


def get_answers(doc):
    def _flatten_validated_answers(validated_answers):
        """Flattens a dict of lists of validated answers.
        {"number": ['1', '8'], ...}
        -> [{"number": ['1'], ...}, {"number": ['8'], ...}]
        """
        valid_answers = []
        for i in range(len(validated_answers["number"])):
            valid_answers.append(
                {
                    "number": validated_answers["number"][i],
                    "date": validated_answers["date"][i],
                    "spans": validated_answers["spans"][i],
                }
            )
        return valid_answers

    answers = []
    answers_set = set()
    candidates = [doc["answer"]] + _flatten_validated_answers(doc["validated_answers"])
    for candidate in candidates:
        answer = parse_answer(candidate)
        if answer in answers_set:
            continue
        answers_set.add(answer)
        answers.append(answer)
    return answers


def parse_answer(answer):
    # NOTE: Everything is returned as a tuple for uniformity and hashability.
    if answer["number"] != "":
        return (str(answer["number"]),)
    if answer["spans"] != []:
        return tuple(answer["spans"])
    return (
        " ".join(
            [answer["date"]["day"], answer["date"]["month"], answer["date"]["year"]]
        ).strip(),
    )


def process_results(doc, results):
    preds, golds = results, doc["answers"]
    max_em = 0
    max_f1 = 0
    for gold_answer in golds:
        exact_match, f1_score = get_metrics(preds, gold_answer)
        if gold_answer[0].strip():
            max_em = max(max_em, exact_match)
            max_f1 = max(max_f1, f1_score)
    return {"em": max_em, "f1": max_f1}


def get_metrics(predicted, gold):
    """
    Takes a predicted answer and a gold answer (that are both either a string or a list of
    strings), and returns exact match and the DROP F1 metric for the prediction. If you are
    writing a script for evaluating objects in memory (say, the output of predictions during
    validation, or while training), this is the function you want to call, after using
    :func:`answer_json_to_strings` when reading the gold answer from the released data file.
    """
    predicted_bags = _answer_to_bags(predicted)
    gold_bags = _answer_to_bags(gold)

    if set(predicted_bags[0]) == set(gold_bags[0]) and len(predicted_bags[0]) == len(
        gold_bags[0]
    ):
        exact_match = 1.0
    else:
        exact_match = 0.0

    f1_per_bag = _align_bags(predicted_bags[1], gold_bags[1])
    f1 = np.mean(f1_per_bag)
    f1 = round(f1, 2)
    return exact_match, f1


def _answer_to_bags(answer):
    if isinstance(answer, (list, tuple)):
        raw_spans = answer
    else:
        raw_spans = [answer]
    normalized_spans = []
    token_bags = []
    for raw_span in raw_spans:
        normalized_span = _normalize(raw_span)
        normalized_spans.append(normalized_span)
        token_bags.append(set(normalized_span.split()))
    return normalized_spans, token_bags


def _align_bags(predicted, gold):
    """
    Takes gold and predicted answer sets and first finds the optimal 1-1 alignment
    between them and gets maximum metric values over all the answers.
    """
    from scipy.optimize import linear_sum_assignment

    scores = np.zeros([len(gold), len(predicted)])
    for gold_index, gold_item in enumerate(gold):
        for pred_index, pred_item in enumerate(predicted):
            if _match_numbers_if_present(gold_item, pred_item):
                scores[gold_index, pred_index] = _compute_f1(pred_item, gold_item)
    row_ind, col_ind = linear_sum_assignment(-scores)

    max_scores = np.zeros([max(len(gold), len(predicted))])
    for row, column in zip(row_ind, col_ind):
        max_scores[row] = max(max_scores[row], scores[row, column])
    return max_scores


def _compute_f1(predicted_bag, gold_bag):
    intersection = len(gold_bag.intersection(predicted_bag))
    if not predicted_bag:
        precision = 1.0
    else:
        precision = intersection / float(len(predicted_bag))
    if not gold_bag:
        recall = 1.0
    else:
        recall = intersection / float(len(gold_bag))
    f1 = (
        (2 * precision * recall) / (precision + recall)
        if not (precision == 0.0 and recall == 0.0)
        else 0.0
    )
    return f1


def _match_numbers_if_present(gold_bag, predicted_bag):
    gold_numbers = set()
    predicted_numbers = set()
    for word in gold_bag:
        if _is_number(word):
            gold_numbers.add(word)
    for word in predicted_bag:
        if _is_number(word):
            predicted_numbers.add(word)
    if (not gold_numbers) or gold_numbers.intersection(predicted_numbers):
        return True
    return False


def _is_number(text):
    try:
        float(text)
        return True
    except ValueError:
        return False


def _remove_articles(text):
    return _ARTICLES.sub(" ", text)


def _white_space_fix(text):
    return " ".join(text.split())


def _remove_punc(text):
    exclude = set(string.punctuation)
    if not _is_number(text):
        return "".join(ch for ch in text if ch not in exclude)
    else:
        return text


def _fix_number(text):
    return str(float(text)) if _is_number(text) else text


def _tokenize(text):
    return re.split(" |-", text)


def _normalize(answer):
    tokens = [
        _white_space_fix(_remove_articles(_fix_number(_remove_punc(token.lower()))))
        for token in _tokenize(answer)
    ]
    tokens = [token for token in tokens if token.strip()]
    normalized = " ".join(tokens).strip()
    return normalized
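
A toy check (not part of the uploaded diff) of the DROP scoring helpers above; the `lm_eval.tasks.drop.utils` import path is an assumption about how the module is loaded.

```python
# Sketch only: get_metrics() normalizes both sides, so case, articles, and
# punctuation are ignored, while mismatched numbers force the F1 to zero.
from lm_eval.tasks.drop.utils import get_metrics

em, f1 = get_metrics("The Eiffel Tower", ("eiffel tower",))
print(em, f1)  # 1.0 1.0 after normalization

em, f1 = get_metrics("40 percent", ("35 percent",))
print(em, f1)  # 0.0 0.0: the numeric tokens disagree, so no bag alignment scores
```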
scripts/yans/lm-evaluation-harness/lm_eval/tasks/hendrycks_math/README.md
ADDED
@@ -0,0 +1,54 @@
# MATH

## Paper
Measuring Mathematical Problem Solving With the MATH Dataset
https://arxiv.org/abs/2103.03874

Many intellectual endeavors require mathematical problem solving, but this skill remains beyond the capabilities of computers. To measure this ability in machine learning models, we introduce MATH, a new dataset of 12,500 challenging competition mathematics problems. Each problem in MATH has a full step-by-step solution which can be used to teach models to generate answer derivations and explanations.

NOTE: This task corresponds to the MATH (`hendrycks_math`) implementation at https://github.com/EleutherAI/lm-evaluation-harness/tree/master . For the variant which uses the custom 4-shot prompt in the Minerva paper (https://arxiv.org/abs/2206.14858), and SymPy answer checking as done by Minerva, see `lm_eval/tasks/minerva_math`.

Homepage: https://github.com/hendrycks/math


## Citation
```
@article{hendrycksmath2021,
    title={Measuring Mathematical Problem Solving With the MATH Dataset},
    author={Dan Hendrycks and Collin Burns and Saurav Kadavath and Akul Arora and Steven Basart and Eric Tang and Dawn Song and Jacob Steinhardt},
    journal={NeurIPS},
    year={2021}
}
```

### Groups and Tasks

#### Groups

- `hendrycks_math`: the MATH benchmark from Hendrycks et al. 0- or few-shot.

#### Tasks

- `hendrycks_math_algebra`
- `hendrycks_math_counting_and_prob`
- `hendrycks_math_geometry`
- `hendrycks_math_intermediate_algebra`
- `hendrycks_math_num_theory`
- `hendrycks_math_prealgebra`
- `hendrycks_math_precalc`

### Checklist

The checklist is the following:

For adding novel benchmarks/datasets to the library:
* [x] Is the task an existing benchmark in the literature?
* [x] Have you referenced the original paper that introduced the task?
* [x] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?
  * Answer extraction code is taken from the original MATH benchmark paper's repository.


If other tasks on this dataset are already supported:
* [x] Is the "Main" variant of this task clearly denoted?
* [x] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [x] Have you noted which, if any, published evaluation setups are matched by this variant?
scripts/yans/lm-evaluation-harness/lm_eval/tasks/hendrycks_math/hendrycks_math.yaml
ADDED
@@ -0,0 +1,15 @@
group: hendrycks_math
task:
  - hendrycks_math_algebra
  - hendrycks_math_counting_and_prob
  - hendrycks_math_geometry
  - hendrycks_math_intermediate_algebra
  - hendrycks_math_num_theory
  - hendrycks_math_prealgebra
  - hendrycks_math_precalc
aggregate_metric_list:
  - metric: exact_match
    aggregation: mean
    weight_by_size: true
metadata:
  version: 1.0
scripts/yans/lm-evaluation-harness/lm_eval/tasks/hendrycks_math/hendrycks_math_algebra.yaml
ADDED
@@ -0,0 +1,25 @@
tag:
  - math_word_problems
task: hendrycks_math_algebra
dataset_path: EleutherAI/hendrycks_math
process_docs: !function utils.process_docs
dataset_name: algebra
output_type: generate_until
training_split: train
test_split: test
doc_to_text: "Problem: {{problem}}\nAnswer:"
process_results: !function utils.process_results
doc_to_target: "{{answer}}"
generation_kwargs:
  until:
    - "Problem:"
  do_sample: false
  temperature: 0
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
metadata:
  version: 1.0
dataset_kwargs:
  trust_remote_code: true
scripts/yans/lm-evaluation-harness/lm_eval/tasks/hendrycks_math/hendrycks_math_counting_and_prob.yaml
ADDED
@@ -0,0 +1,3 @@
include: hendrycks_math_algebra.yaml
dataset_name: counting_and_probability
task: hendrycks_math_counting_and_prob
scripts/yans/lm-evaluation-harness/lm_eval/tasks/hendrycks_math/hendrycks_math_geometry.yaml
ADDED
@@ -0,0 +1,3 @@
include: hendrycks_math_algebra.yaml
dataset_name: geometry
task: hendrycks_math_geometry
scripts/yans/lm-evaluation-harness/lm_eval/tasks/hendrycks_math/hendrycks_math_intermediate_algebra.yaml
ADDED
@@ -0,0 +1,3 @@
include: hendrycks_math_algebra.yaml
dataset_name: intermediate_algebra
task: hendrycks_math_intermediate_algebra
scripts/yans/lm-evaluation-harness/lm_eval/tasks/hendrycks_math/hendrycks_math_num_theory.yaml
ADDED
@@ -0,0 +1,3 @@
include: hendrycks_math_algebra.yaml
dataset_name: number_theory
task: hendrycks_math_num_theory
scripts/yans/lm-evaluation-harness/lm_eval/tasks/hendrycks_math/hendrycks_math_prealgebra.yaml
ADDED
@@ -0,0 +1,3 @@
include: hendrycks_math_algebra.yaml
dataset_name: prealgebra
task: hendrycks_math_prealgebra
scripts/yans/lm-evaluation-harness/lm_eval/tasks/hendrycks_math/hendrycks_math_precalc.yaml
ADDED
@@ -0,0 +1,3 @@
include: hendrycks_math_algebra.yaml
dataset_name: precalculus
task: hendrycks_math_precalc
scripts/yans/lm-evaluation-harness/lm_eval/tasks/hendrycks_math/utils.py
ADDED
@@ -0,0 +1,231 @@
1 |
+
from typing import Dict, List
|
2 |
+
|
3 |
+
import datasets
|
4 |
+
|
5 |
+
|
6 |
+
def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
|
7 |
+
def _process_doc(doc: dict) -> dict:
|
8 |
+
out_doc = {
|
9 |
+
"problem": doc["problem"],
|
10 |
+
"solution": doc["solution"],
|
11 |
+
"answer": remove_boxed(last_boxed_only_string(doc["solution"])),
|
12 |
+
}
|
13 |
+
return out_doc
|
14 |
+
|
15 |
+
return dataset.map(_process_doc)
|
16 |
+
|
17 |
+
|
18 |
+
def process_results(doc: dict, results: List[str]) -> Dict[str, int]:
|
19 |
+
retval = 0
|
20 |
+
indices = [pos for pos, char in enumerate(results[0]) if char == "$"]
|
21 |
+
if len(indices) <= 1:
|
22 |
+
answer = results[0]
|
23 |
+
else:
|
24 |
+
answer = results[0][indices[0] + 1 : indices[-1]]
|
25 |
+
|
26 |
+
    if is_equiv(answer, remove_boxed(last_boxed_only_string(doc["solution"]))):
        retval = 1

    results = {
        "exact_match": retval,
    }
    return results


# string normalization from https://github.com/EleutherAI/lm-evaluation-harness/blob/master/lm_eval/tasks/hendrycks_math.py
def is_equiv(str1, str2, verbose=False):
    if str1 is None and str2 is None:
        print("WARNING: Both None")
        return True
    if str1 is None or str2 is None:
        return False

    try:
        ss1 = strip_string(str1)
        ss2 = strip_string(str2)
        if verbose:
            print(ss1, ss2)
        return ss1 == ss2
    except Exception:
        return str1 == str2


def remove_boxed(s):
    if "\\boxed " in s:
        left = "\\boxed "
        assert s[: len(left)] == left
        return s[len(left) :]

    left = "\\boxed{"

    assert s[: len(left)] == left
    assert s[-1] == "}"

    return s[len(left) : -1]


def last_boxed_only_string(string):
    idx = string.rfind("\\boxed")
    if "\\boxed " in string:
        return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0]
    if idx < 0:
        idx = string.rfind("\\fbox")
        if idx < 0:
            return None

    i = idx
    right_brace_idx = None
    num_left_braces_open = 0
    while i < len(string):
        if string[i] == "{":
            num_left_braces_open += 1
        if string[i] == "}":
            num_left_braces_open -= 1
            if num_left_braces_open == 0:
                right_brace_idx = i
                break
        i += 1

    if right_brace_idx is None:
        retval = None
    else:
        retval = string[idx : right_brace_idx + 1]

    return retval


def fix_fracs(string):
    substrs = string.split("\\frac")
    new_str = substrs[0]
    if len(substrs) > 1:
        substrs = substrs[1:]
        for substr in substrs:
            new_str += "\\frac"
            if substr[0] == "{":
                new_str += substr
            else:
                try:
                    assert len(substr) >= 2
                except AssertionError:
                    return string
                a = substr[0]
                b = substr[1]
                if b != "{":
                    if len(substr) > 2:
                        post_substr = substr[2:]
                        new_str += "{" + a + "}{" + b + "}" + post_substr
                    else:
                        new_str += "{" + a + "}{" + b + "}"
                else:
                    if len(substr) > 2:
                        post_substr = substr[2:]
                        new_str += "{" + a + "}" + b + post_substr
                    else:
                        new_str += "{" + a + "}" + b
    string = new_str
    return string


def fix_a_slash_b(string):
    if len(string.split("/")) != 2:
        return string
    a = string.split("/")[0]
    b = string.split("/")[1]
    try:
        a = int(a)
        b = int(b)
        assert string == "{}/{}".format(a, b)
        new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
        return new_string
    except AssertionError:
        return string


def remove_right_units(string):
    # "\\text{ " only ever occurs (at least in the val set) when describing units
    if "\\text{ " in string:
        splits = string.split("\\text{ ")
        assert len(splits) == 2
        return splits[0]
    else:
        return string


def fix_sqrt(string):
    if "\\sqrt" not in string:
        return string
    splits = string.split("\\sqrt")
    new_string = splits[0]
    for split in splits[1:]:
        if split[0] != "{":
            a = split[0]
            new_substr = "\\sqrt{" + a + "}" + split[1:]
        else:
            new_substr = "\\sqrt" + split
        new_string += new_substr
    return new_string


def strip_string(string):
    # linebreaks
    string = string.replace("\n", "")

    # remove inverse spaces
    string = string.replace("\\!", "")

    # replace \\ with \
    string = string.replace("\\\\", "\\")

    # replace tfrac and dfrac with frac
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")

    # remove \left and \right
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")

    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # remove dollar signs
    string = string.replace("\\$", "")

    # remove units (on the right)
    string = remove_right_units(string)

    # remove percentage
    string = string.replace("\\%", "")
    string = string.replace("\%", "")  # noqa: W605

    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")
    # if empty, return empty string
    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string

    # to consider: get rid of e.g. "k = " or "q = " at beginning
    if len(string.split("=")) == 2:
        if len(string.split("=")[0]) <= 2:
            string = string.split("=")[1]

    # fix sqrt3 --> sqrt{3}
    string = fix_sqrt(string)

    # remove spaces
    string = string.replace(" ", "")

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
    string = fix_fracs(string)

    # manually change 0.5 --> \frac{1}{2}
    if string == "0.5":
        string = "\\frac{1}{2}"

    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
    string = fix_a_slash_b(string)

    return string
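
Taken together, these helpers make the `exact_match` check robust to surface-level formatting differences in LaTeX answers. The snippet below is illustrative only (it is not part of the diff) and assumes the functions above are in scope, e.g. run at the bottom of this `utils.py`:

```python
# Illustrative only: exercises the normalization helpers defined above.
solution = "The answer is $\\boxed{\\dfrac{1}{2}}$."
gold = remove_boxed(last_boxed_only_string(solution))

print(gold)                   # \dfrac{1}{2}
print(strip_string(gold))     # \frac{1}{2}   (dfrac is rewritten to frac)
print(strip_string("0.5"))    # \frac{1}{2}   (0.5 is special-cased)
print(is_equiv(gold, "1/2"))  # True -- a/b is rewritten to \frac{a}{b} before comparison
```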
scripts/yans/lm-evaluation-harness/lm_eval/tasks/siqa/README.md
ADDED
@@ -0,0 +1,37 @@
# Social IQA

### Paper

Title: Social IQA: Commonsense Reasoning about Social Interactions

Abstract: https://arxiv.org/abs/1904.09728

> We introduce Social IQa, the first largescale benchmark for commonsense reasoning about social situations. Social IQa contains 38,000 multiple choice questions for probing emotional and social intelligence in a variety of everyday situations (e.g., Q: "Jordan wanted to tell Tracy a secret, so Jordan leaned towards Tracy. Why did Jordan do this?" A: "Make sure no one else could hear"). Through crowdsourcing, we collect commonsense questions along with correct and incorrect answers about social interactions, using a new framework that mitigates stylistic artifacts in incorrect answers by asking workers to provide the right answer to a different but related question. Empirical results show that our benchmark is challenging for existing question-answering models based on pretrained language models, compared to human performance (>20% gap). Notably, we further establish Social IQa as a resource for transfer learning of commonsense knowledge, achieving state-of-the-art performance on multiple commonsense reasoning tasks (Winograd Schemas, COPA).

Homepage: https://allenai.org/data/socialiqa


### Citation

```
@inproceedings{sap2019social,
    title={Social IQa: Commonsense Reasoning about Social Interactions},
    author={Sap, Maarten and Rashkin, Hannah and Chen, Derek and Le Bras, Ronan and Choi, Yejin},
    booktitle={Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)},
    pages={4463--4473},
    year={2019}
}
```

### Checklist

For adding novel benchmarks/datasets to the library:
* [X] Is the task an existing benchmark in the literature?
* [X] Have you referenced the original paper that introduced the task?
* [X] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test? The original paper doesn't have an associated implementation, but there is an official entry in [BigBench](https://github.com/google/BIG-bench/tree/main/bigbench/benchmark_tasks/social_iqa). I use the same prompting format as BigBench.


If other tasks on this dataset are already supported:
* [ ] Is the "Main" variant of this task clearly denoted?
* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
scripts/yans/lm-evaluation-harness/lm_eval/tasks/siqa/siqa.yaml
ADDED
@@ -0,0 +1,16 @@
task: social_iqa
dataset_path: social_i_qa
dataset_name: null
output_type: multiple_choice
training_split: train
validation_split: validation
doc_to_text: "Q: {{context}} {{question}}\nA:"
target_delimiter: " "
doc_to_choice: "{{[answerA, answerB, answerC]}}"
doc_to_target: "{{ (label|int) - 1 }}"
metric_list:
  - metric: acc
    aggregation: mean
    higher_is_better: true
metadata:
  version: 0.0
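
For reference, the Jinja templates in this config resolve as follows for one `social_i_qa` record. The snippet is illustrative only; the context, question, and correct answer come from the example quoted in the README above, while the two distractor answers are made up for demonstration:

```python
# Illustrative only: how the doc_to_text / doc_to_choice / doc_to_target templates
# above resolve for one Social IQA record (field names follow the social_i_qa schema).
from jinja2 import Template

doc = {
    "context": "Jordan wanted to tell Tracy a secret, so Jordan leaned towards Tracy.",
    "question": "Why did Jordan do this?",
    "answerA": "Make sure no one else could hear",
    "answerB": "Fill a silence",        # distractor made up for illustration
    "answerC": "Embarrass Tracy",       # distractor made up for illustration
    "label": "1",
}

prompt = Template("Q: {{context}} {{question}}\nA:").render(**doc)
print(prompt)
# Q: Jordan wanted to tell Tracy a secret, so Jordan leaned towards Tracy. Why did Jordan do this?
# A:

choices = [doc["answerA"], doc["answerB"], doc["answerC"]]   # doc_to_choice
target = int(doc["label"]) - 1                               # doc_to_target: "{{ (label|int) - 1 }}"
print(choices[target])                                       # Make sure no one else could hear
```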
scripts/yans/lm-evaluation-harness/lm_eval/tasks/squadv2/README.md
ADDED
@@ -0,0 +1,54 @@
# SQuADv2

### Paper

Title: `Know What You Don’t Know: Unanswerable Questions for SQuAD`
Abstract: https://arxiv.org/abs/1806.03822

Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset,
consisting of questions posed by crowdworkers on a set of Wikipedia articles,
where the answer to every question is a segment of text, or span, from the
corresponding reading passage, or the question might be unanswerable.
SQuAD2.0 combines the 100,000 questions in SQuAD1.1 with over 50,000 unanswerable
questions written adversarially by crowdworkers to look similar to answerable ones.
To do well on SQuAD2.0, systems must not only answer questions when possible, but
also determine when no answer is supported by the paragraph and abstain from answering.

Homepage: https://rajpurkar.github.io/SQuAD-explorer/


### Citation

```
@misc{rajpurkar2018know,
    title={Know What You Don't Know: Unanswerable Questions for SQuAD},
    author={Pranav Rajpurkar and Robin Jia and Percy Liang},
    year={2018},
    eprint={1806.03822},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
```

### Groups and Tasks

#### Groups

* Not part of a group yet

#### Tasks

* `squadv2`: `Default squadv2 task`

### Checklist

For adding novel benchmarks/datasets to the library:
* [ ] Is the task an existing benchmark in the literature?
* [ ] Have you referenced the original paper that introduced the task?
* [ ] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?


If other tasks on this dataset are already supported:
* [ ] Is the "Main" variant of this task clearly denoted?
* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
scripts/yans/lm-evaluation-harness/lm_eval/tasks/squadv2/squadv2.yaml
ADDED
@@ -0,0 +1,2 @@
task: squadv2
class: !function task.SQuAD2
scripts/yans/lm-evaluation-harness/lm_eval/tasks/squadv2/task.py
ADDED
@@ -0,0 +1,241 @@
"""
Know What You Don’t Know: Unanswerable Questions for SQuAD
https://arxiv.org/pdf/1806.03822.pdf

Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset,
consisting of questions posed by crowdworkers on a set of Wikipedia articles,
where the answer to every question is a segment of text, or span, from the
corresponding reading passage, or the question might be unanswerable.
SQuAD2.0 combines the 100,000 questions in SQuAD1.1 with over 50,000 unanswerable
questions written adversarially by crowdworkers to look similar to answerable ones.
To do well on SQuAD2.0, systems must not only answer questions when possible, but
also determine when no answer is supported by the paragraph and abstain from answering.

Homepage: https://rajpurkar.github.io/SQuAD-explorer/
"""

from functools import partial
from math import exp

import datasets
from packaging import version

from lm_eval.api.instance import Instance
from lm_eval.api.task import ConfigurableTask


_CITATION = """
@misc{rajpurkar2018know,
    title={Know What You Don't Know: Unanswerable Questions for SQuAD},
    author={Pranav Rajpurkar and Robin Jia and Percy Liang},
    year={2018},
    eprint={1806.03822},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
"""


def _squad_metric(predictions, references):
    squad_metric = datasets.load_metric("squad_v2")
    return squad_metric.compute(predictions=predictions, references=references)


def _squad_agg(key, items):
    predictions, references = zip(*items)

    return _squad_metric(predictions=predictions, references=references).get(key, 0)


class SQuAD2(ConfigurableTask):
    VERSION = 3
    DATASET_PATH = "squad_v2"
    DATASET_NAME = None

    def __init__(self, config=None):
        super().__init__(config={"metadata": {"version": self.VERSION}})

        # HF changed squad on us so we have to make sure we aren't running the old one
        assert version.parse(datasets.__version__) >= version.parse(
            "1.11.0"
        ), "datasets v1.11.0 or later required for SQuAD"

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        return self.dataset["train"]

    def validation_docs(self):
        return self.dataset["validation"]

    def doc_to_text(self, doc):
        return (
            "Title: "
            + doc["title"]
            + "\n\n"
            + "Background: "
            + doc["context"]
            + "\n\n"
            + "Question: "
            + doc["question"]
            + "\n\n"
            + "Answer:"
        )

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        return doc["context"]

    def doc_to_target(self, doc):
        answer_list = doc["answers"]["text"]
        if len(answer_list) > 0:
            answer = answer_list[0]
        else:
            answer = "unanswerable"
        return " " + answer

    def construct_requests(self, doc, ctx, **kwargs):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """

        return [
            Instance(
                request_type="generate_until",
                doc=doc,
                arguments=(ctx, {"until": ["\n"]}),
                idx=0,
                **kwargs,
            ),
            Instance(
                request_type="loglikelihood",
                doc=doc,
                arguments=(ctx, " " + "unanswerable"),
                idx=0,
                **kwargs,
            ),
        ]

    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """

        continuation, (logprob_unanswerable, _) = results

        no_answer_probability = exp(logprob_unanswerable)

        predictions = {
            "id": doc["id"],
            "prediction_text": continuation,
            "no_answer_probability": no_answer_probability,
        }

        references = {
            "id": doc["id"],
            "answers": doc["answers"],
        }

        return {
            "exact": (
                predictions,
                references,
            ),  # Exact match (the normalized answer exactly match the gold answer)
            "f1": (
                predictions,
                references,
            ),  # The F-score of predicted tokens versus the gold answer
            "HasAns_exact": (
                predictions,
                references,
            ),  # Exact match (the normalized answer exactly match the gold answer)
            "HasAns_f1": (
                predictions,
                references,
            ),  # The F-score of predicted tokens versus the gold answer
            "NoAns_exact": (
                predictions,
                references,
            ),  # Exact match (the normalized answer exactly match the gold answer)
            "NoAns_f1": (
                predictions,
                references,
            ),  # The F-score of predicted tokens versus the gold answer
            "best_exact": (
                predictions,
                references,
            ),  # Best exact match (with varying threshold)
            "best_f1": (predictions, references),  # Best F1 (with varying threshold)
        }

    def aggregation(self):
        """
        :returns: {str: [float] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
        return {
            "exact": partial(
                _squad_agg, "exact"
            ),  # Exact match (the normalized answer exactly match the gold answer)
            "f1": partial(
                _squad_agg, "f1"
            ),  # The F-score of predicted tokens versus the gold answer
            "HasAns_exact": partial(
                _squad_agg, "HasAns_exact"
            ),  # Exact match (the normalized answer exactly match the gold answer)
            "HasAns_f1": partial(
                _squad_agg, "HasAns_f1"
            ),  # The F-score of predicted tokens versus the gold answer
            "NoAns_exact": partial(
                _squad_agg, "NoAns_exact"
            ),  # Exact match (the normalized answer exactly match the gold answer)
            "NoAns_f1": partial(
                _squad_agg, "NoAns_f1"
            ),  # The F-score of predicted tokens versus the gold answer
            "best_exact": partial(
                _squad_agg, "best_exact"
            ),  # Best exact match (with varying threshold)
            "best_f1": partial(
                _squad_agg, "best_f1"
            ),  # Best F1 (with varying threshold)
        }

    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        return {
            "exact": True,  # Exact match (the normalized answer exactly match the gold answer)
            "f1": True,  # The F-score of predicted tokens versus the gold answer
            "HasAns_exact": True,  # Exact match (the normalized answer exactly match the gold answer)
            "HasAns_f1": True,  # The F-score of predicted tokens versus the gold answer
            "NoAns_exact": True,  # Exact match (the normalized answer exactly match the gold answer)
            "NoAns_f1": True,  # The F-score of predicted tokens versus the gold answer
            "best_exact": True,  # Best exact match (with varying threshold)
            "best_f1": True,  # Best F1 (with varying threshold)
        }
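
The `(predictions, references)` pairs returned by `process_results` are exactly what the HF `squad_v2` metric expects, and `aggregation` simply routes the collected pairs through `_squad_agg` per submetric. The sketch below is illustrative only, with made-up values; it uses the same `datasets.load_metric("squad_v2")` call as `_squad_metric` above:

```python
# Illustrative only: the shape of the prediction/reference dicts that
# process_results() emits and aggregation() consumes. Values are toy data.
import datasets

squad_metric = datasets.load_metric("squad_v2")

predictions = [{
    "id": "example-0",                 # made-up id
    "prediction_text": "France",
    "no_answer_probability": 0.03,     # exp(loglikelihood of " unanswerable")
}]
references = [{
    "id": "example-0",
    "answers": {"text": ["France"], "answer_start": [0]},
}]

scores = squad_metric.compute(predictions=predictions, references=references)
print(scores["exact"], scores["f1"])   # 100.0 100.0 for this toy pair
```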
scripts/yans/lm-evaluation-harness/lm_eval/tasks/xcopa/README.md
ADDED
@@ -0,0 +1,60 @@
# XCOPA

### Paper

Title: `XCOPA: A Multilingual Dataset for Causal Commonsense Reasoning`

Abstract: https://ducdauge.github.io/files/xcopa.pdf

The Cross-lingual Choice of Plausible Alternatives dataset is a benchmark to evaluate the ability of machine learning models to transfer commonsense reasoning across languages.
The dataset is the translation and reannotation of the English COPA (Roemmele et al. 2011) and covers 11 languages from 11 families and several areas around the globe.
The dataset is challenging as it requires both the command of world knowledge and the ability to generalise to new languages.
All the details about the creation of XCOPA and the implementation of the baselines are available in the paper.

Homepage: https://github.com/cambridgeltl/xcopa

### Citation

```
@inproceedings{ponti2020xcopa,
    title={{XCOPA: A} Multilingual Dataset for Causal Commonsense Reasoning},
    author={Edoardo M. Ponti, Goran Glava\v{s}, Olga Majewska, Qianchu Liu, Ivan Vuli\'{c} and Anna Korhonen},
    booktitle={Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP)},
    year={2020},
    url={https://ducdauge.github.io/files/xcopa.pdf}
}
```

### Groups and Tasks

#### Groups

* `xcopa`

#### Tasks

* `xcopa_et`: Estonian
* `xcopa_ht`: Haitian Creole
* `xcopa_id`: Indonesian
* `xcopa_it`: Italian
* `xcopa_qu`: Cusco-Collao Quechua
* `xcopa_sw`: Kiswahili
* `xcopa_ta`: Tamil
* `xcopa_th`: Thai
* `xcopa_tr`: Turkish
* `xcopa_vi`: Vietnamese
* `xcopa_zh`: Mandarin Chinese


### Checklist

For adding novel benchmarks/datasets to the library:
* [ ] Is the task an existing benchmark in the literature?
* [ ] Have you referenced the original paper that introduced the task?
* [ ] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?


If other tasks on this dataset are already supported:
* [ ] Is the "Main" variant of this task clearly denoted?
* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
scripts/yans/lm-evaluation-harness/lm_eval/tasks/xcopa/_xcopa.yaml
ADDED
@@ -0,0 +1,19 @@
group: xcopa
task:
  - xcopa_et
  - xcopa_ht
  - xcopa_id
  - xcopa_it
  - xcopa_qu
  - xcopa_sw
  - xcopa_ta
  - xcopa_th
  - xcopa_tr
  - xcopa_vi
  - xcopa_zh
aggregate_metric_list:
  - metric: acc
    aggregation: mean
    weight_by_size: True
metadata:
  version: 1.0
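
A minimal sketch of running this group through the harness' Python API; the model string and arguments are illustrative, not prescribed by the config:

```python
# Sketch only: evaluate the "xcopa" group defined above. The model choice is an
# assumption for demonstration; any HuggingFace causal LM identifier can be used.
from lm_eval.evaluator import simple_evaluate

results = simple_evaluate(
    model="hf",                                     # HF backend
    model_args="pretrained=EleutherAI/pythia-160m", # illustrative model
    tasks=["xcopa"],                                # the group; single tasks like "xcopa_et" also work
    num_fewshot=0,
)
# Per-task and group-level accuracies (the group acc is size-weighted per the config above).
print(results["results"])
```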
scripts/yans/lm-evaluation-harness/lm_eval/tasks/xcopa/default_et.yaml
ADDED
@@ -0,0 +1,13 @@
task: xcopa_et
dataset_path: xcopa
dataset_name: et
output_type: multiple_choice
validation_split: validation
test_split: test
doc_to_text: !function utils.doc_to_text_et
doc_to_target: label
doc_to_choice: !function utils.doc_to_choice
metric_list:
  - metric: acc
metadata:
  version: 1.0
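
The `!function utils.doc_to_text_et` and `!function utils.doc_to_choice` entries point at helpers in the task folder's `utils.py`, which is not included in this excerpt. A minimal sketch of what such helpers could look like, assuming XCOPA's `premise` / `choice1` / `choice2` / `question` (`"cause"` or `"effect"`) / `label` schema; the Estonian connective words are an assumption:

```python
# Sketch only -- not the actual xcopa/utils.py from this diff.
from functools import partial


def doc_to_choice(doc):
    # Lower-case the leading character so each choice reads as a continuation of the premise.
    return [c[0].lower() + c[1:] for c in (doc["choice1"], doc["choice2"])]


def _doc_to_text(doc, connector):
    # Drop the premise's final punctuation and append a cause/effect connective.
    return doc["premise"].strip()[:-1] + " " + connector[doc["question"]]


# Assumed Estonian connectives: "sest" (because) for causes, "seetõttu" (therefore) for effects.
doc_to_text_et = partial(_doc_to_text, connector={"cause": "sest", "effect": "seetõttu"})
```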
scripts/yans/lm-evaluation-harness/lm_eval/tasks/xcopa/default_ht.yaml
ADDED
@@ -0,0 +1,4 @@
include: default_et.yaml
task: xcopa_ht
dataset_name: ht
doc_to_text: !function utils.doc_to_text_ht
scripts/yans/lm-evaluation-harness/lm_eval/tasks/xcopa/default_id.yaml
ADDED
@@ -0,0 +1,4 @@
include: default_et.yaml
task: xcopa_id
dataset_name: id
doc_to_text: !function utils.doc_to_text_id
scripts/yans/lm-evaluation-harness/lm_eval/tasks/xcopa/default_it.yaml
ADDED
@@ -0,0 +1,4 @@
include: default_et.yaml
task: xcopa_it
dataset_name: it
doc_to_text: !function utils.doc_to_text_it
scripts/yans/lm-evaluation-harness/lm_eval/tasks/xcopa/default_qu.yaml
ADDED
@@ -0,0 +1,4 @@
include: default_et.yaml
task: xcopa_qu
dataset_name: qu
doc_to_text: !function utils.doc_to_text_qu
scripts/yans/lm-evaluation-harness/lm_eval/tasks/xcopa/default_sw.yaml
ADDED
@@ -0,0 +1,4 @@
include: default_et.yaml
task: xcopa_sw
dataset_name: sw
doc_to_text: !function utils.doc_to_text_sw
scripts/yans/lm-evaluation-harness/lm_eval/tasks/xcopa/default_ta.yaml
ADDED
@@ -0,0 +1,4 @@
include: default_et.yaml
task: xcopa_ta
dataset_name: ta
doc_to_text: !function utils.doc_to_text_ta
scripts/yans/lm-evaluation-harness/lm_eval/tasks/xcopa/default_th.yaml
ADDED
@@ -0,0 +1,4 @@
include: default_et.yaml
task: xcopa_th
dataset_name: th
doc_to_text: !function utils.doc_to_text_th
scripts/yans/lm-evaluation-harness/lm_eval/tasks/xcopa/default_tr.yaml
ADDED
@@ -0,0 +1,4 @@
include: default_et.yaml
task: xcopa_tr
dataset_name: tr
doc_to_text: !function utils.doc_to_text_tr