Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change list.
- scripts/yans/lm-evaluation-harness/lm_eval/__pycache__/__init__.cpython-310.pyc +0 -0
- scripts/yans/lm-evaluation-harness/lm_eval/__pycache__/__main__.cpython-310.pyc +0 -0
- scripts/yans/lm-evaluation-harness/lm_eval/__pycache__/evaluator.cpython-310.pyc +0 -0
- scripts/yans/lm-evaluation-harness/lm_eval/__pycache__/evaluator_utils.cpython-310.pyc +0 -0
- scripts/yans/lm-evaluation-harness/lm_eval/__pycache__/utils.cpython-310.pyc +0 -0
- scripts/yans/lm-evaluation-harness/lm_eval/caching/__init__.py +0 -0
- scripts/yans/lm-evaluation-harness/lm_eval/caching/__pycache__/__init__.cpython-310.pyc +0 -0
- scripts/yans/lm-evaluation-harness/lm_eval/caching/__pycache__/cache.cpython-310.pyc +0 -0
- scripts/yans/lm-evaluation-harness/lm_eval/caching/cache.py +55 -0
- scripts/yans/lm-evaluation-harness/lm_eval/decontamination/__init__.py +0 -0
- scripts/yans/lm-evaluation-harness/lm_eval/decontamination/archiver.py +171 -0
- scripts/yans/lm-evaluation-harness/lm_eval/decontamination/decontaminate.py +166 -0
- scripts/yans/lm-evaluation-harness/lm_eval/decontamination/janitor.py +328 -0
- scripts/yans/lm-evaluation-harness/lm_eval/loggers/__init__.py +2 -0
- scripts/yans/lm-evaluation-harness/lm_eval/loggers/__pycache__/__init__.cpython-310.pyc +0 -0
- scripts/yans/lm-evaluation-harness/lm_eval/loggers/__pycache__/evaluation_tracker.cpython-310.pyc +0 -0
- scripts/yans/lm-evaluation-harness/lm_eval/loggers/__pycache__/utils.cpython-310.pyc +0 -0
- scripts/yans/lm-evaluation-harness/lm_eval/loggers/__pycache__/wandb_logger.cpython-310.pyc +0 -0
- scripts/yans/lm-evaluation-harness/lm_eval/loggers/evaluation_tracker.py +521 -0
- scripts/yans/lm-evaluation-harness/lm_eval/loggers/utils.py +143 -0
- scripts/yans/lm-evaluation-harness/lm_eval/loggers/wandb_logger.py +352 -0
- scripts/yans/lm-evaluation-harness/lm_eval/models/dummy.py +41 -0
- scripts/yans/lm-evaluation-harness/lm_eval/models/gguf.py +130 -0
- scripts/yans/lm-evaluation-harness/lm_eval/models/mamba_lm.py +126 -0
- scripts/yans/lm-evaluation-harness/lm_eval/models/neuron_optimum.py +737 -0
- scripts/yans/lm-evaluation-harness/lm_eval/models/vllm_causallms.py +540 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/code_x_glue/code-text/bleu.py +241 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/code_x_glue/code-text/go.yaml +21 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/code_x_glue/code-text/java.yaml +21 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/code_x_glue/code-text/javascript.yaml +21 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/code_x_glue/code-text/php.yaml +21 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/code_x_glue/code-text/python.yaml +21 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/code_x_glue/code-text/ruby.yaml +21 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/code_x_glue/code-text/utils.py +12 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/hendrycks_ethics/README.md +54 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/hendrycks_ethics/commonsense.yaml +15 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/hendrycks_ethics/deontology.yaml +9 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/hendrycks_ethics/justice.yaml +9 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/hendrycks_ethics/utilitarianism.yaml +12 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/hendrycks_ethics/utilitarianism_original_yaml +16 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/hendrycks_ethics/utils.py +25 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/hendrycks_ethics/virtue.yaml +10 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/mc_taco/README.md +53 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/mc_taco/default.yaml +15 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/pubmedqa/README.md +56 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/pubmedqa/preprocess_pubmedqa.py +6 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/pubmedqa/pubmedqa.yaml +16 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/qa4mre/README.md +55 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/qa4mre/preprocess_qa4mre.py +6 -0
- scripts/yans/lm-evaluation-harness/lm_eval/tasks/qa4mre/qa4mre_2011.yaml +22 -0
scripts/yans/lm-evaluation-harness/lm_eval/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (224 Bytes).
scripts/yans/lm-evaluation-harness/lm_eval/__pycache__/__main__.cpython-310.pyc
ADDED
Binary file (12.2 kB).
scripts/yans/lm-evaluation-harness/lm_eval/__pycache__/evaluator.cpython-310.pyc
ADDED
Binary file (15.5 kB).
scripts/yans/lm-evaluation-harness/lm_eval/__pycache__/evaluator_utils.cpython-310.pyc
ADDED
Binary file (15 kB).
scripts/yans/lm-evaluation-harness/lm_eval/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (15.5 kB).
scripts/yans/lm-evaluation-harness/lm_eval/caching/__init__.py
ADDED
File without changes
scripts/yans/lm-evaluation-harness/lm_eval/caching/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (164 Bytes).
scripts/yans/lm-evaluation-harness/lm_eval/caching/__pycache__/cache.cpython-310.pyc
ADDED
Binary file (1.6 kB).
scripts/yans/lm-evaluation-harness/lm_eval/caching/cache.py
ADDED
@@ -0,0 +1,55 @@
import hashlib
import os

import dill

from lm_eval.utils import eval_logger


MODULE_DIR = os.path.dirname(os.path.realpath(__file__))

OVERRIDE_PATH = os.getenv("LM_HARNESS_CACHE_PATH")


PATH = OVERRIDE_PATH if OVERRIDE_PATH else f"{MODULE_DIR}/.cache"

# This should be sufficient for uniqueness
HASH_INPUT = "EleutherAI-lm-evaluation-harness"

HASH_PREFIX = hashlib.sha256(HASH_INPUT.encode("utf-8")).hexdigest()

FILE_SUFFIX = f".{HASH_PREFIX}.pickle"


def load_from_cache(file_name):
    try:
        path = f"{PATH}/{file_name}{FILE_SUFFIX}"

        with open(path, "rb") as file:
            cached_task_dict = dill.loads(file.read())
            return cached_task_dict

    except Exception:
        eval_logger.debug(f"{file_name} is not cached, generating...")
        pass


def save_to_cache(file_name, obj):
    if not os.path.exists(PATH):
        os.mkdir(PATH)

    file_path = f"{PATH}/{file_name}{FILE_SUFFIX}"

    eval_logger.debug(f"Saving {file_path} to cache...")
    with open(file_path, "wb") as file:
        file.write(dill.dumps(obj))


# NOTE the "key" param is to allow for flexibility
def delete_cache(key: str = ""):
    files = os.listdir(PATH)

    for file in files:
        if file.startswith(key) and file.endswith(FILE_SUFFIX):
            file_path = f"{PATH}/{file}"
            os.unlink(file_path)
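
A minimal usage sketch (an assumption, not part of this commit) of the cache helpers above, assuming dill is installed; LM_HARNESS_CACHE_PATH may be set to redirect the cache directory. The "mmlu" key and the payload dict are hypothetical placeholders.

from lm_eval.caching.cache import delete_cache, load_from_cache, save_to_cache

task_dict = load_from_cache("mmlu")  # returns None on a cache miss (and logs a debug message)
if task_dict is None:
    task_dict = {"mmlu": "expensive-to-build task object"}  # hypothetical payload
    save_to_cache("mmlu", task_dict)
delete_cache("mmlu")  # removes every cached pickle whose name starts with "mmlu"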
scripts/yans/lm-evaluation-harness/lm_eval/decontamination/__init__.py
ADDED
File without changes
scripts/yans/lm-evaluation-harness/lm_eval/decontamination/archiver.py
ADDED
@@ -0,0 +1,171 @@
import datetime
import io
import json
import mmap
import os
from pathlib import Path
from typing import Any

import jsonlines
import tqdm
import zstandard


def json_serial(obj: Any) -> str:
    """JSON serializer for objects not serializable by default json code"""

    if isinstance(obj, (datetime.datetime,)):
        return obj.isoformat()
    raise TypeError("Type %s not serializable" % type(obj))


# Modified version of lm_dataformat Archive for single file.
class Archive:
    def __init__(self, file_path: str, compression_level: int = 3) -> None:
        self.file_path = file_path
        dir_name = os.path.dirname(file_path)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)
        self.fh = open(self.file_path, "wb")
        self.cctx = zstandard.ZstdCompressor(level=compression_level)
        self.compressor = self.cctx.stream_writer(self.fh)

    def add_data(self, data, meta=None) -> None:
        if meta is None:
            meta = {}
        self.compressor.write(
            json.dumps({"text": data, "meta": meta}, default=json_serial).encode(
                "UTF-8"
            )
            + b"\n"
        )

    def commit(self) -> None:
        self.compressor.flush(zstandard.FLUSH_FRAME)
        self.fh.flush()
        self.fh.close()


# Modified version of lm_dataformat Reader with self.fh set, allowing peeking for tqdm.
class Reader:
    def __init__(self) -> None:
        pass

    def read(
        self,
        file,
        get_meta: bool = False,
        autojoin_paragraphs: bool = True,
        para_joiner: str = "\n\n",
    ):
        with open(file, "rb") as fh:
            self.fh = fh
            cctx = zstandard.ZstdDecompressor()
            reader = io.BufferedReader(cctx.stream_reader(fh))
            rdr = jsonlines.Reader(reader)
            for ob in rdr:
                # naive jsonl where each object is just the string itself, with no meta. For legacy compatibility.
                if isinstance(ob, str):
                    assert not get_meta
                    yield ob
                    continue

                text = ob["text"]

                if autojoin_paragraphs and isinstance(text, list):
                    text = para_joiner.join(text)

                if get_meta:
                    yield text, (ob["meta"] if "meta" in ob else {})
                else:
                    yield text


class TextArchive:
    def __init__(self, file_path, mode: str = "rb+") -> None:
        self.file_path = file_path
        dir_name = os.path.dirname(file_path)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)

        if not os.path.exists(file_path):
            Path(file_path).touch()

        self.fh = open(self.file_path, mode)

    def add_data(self, data) -> None:
        self.fh.write(data.encode("UTF-8") + b"\n")

    def commit(self) -> None:
        self.fh.flush()
        self.fh.close()


class TextReader:
    def __init__(self, file_path) -> None:
        self.file_path = file_path

    # Optimized mmap read with infrequent tqdm updates to maintain speed
    # Tested up to 250MB/s.
    def read_tqdm(self, update_frequency: int = 10000):
        current_file_position = 0
        line_counter = 0
        with open(self.file_path, "r", encoding="utf-8") as fh, tqdm.tqdm(
            total=os.path.getsize(self.file_path),
            dynamic_ncols=True,
            unit="byte",
            unit_scale=1,
        ) as progress:
            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                for line in iter(mmap_obj.readline, b""):
                    line = line.decode("utf-8")
                    line_counter += 1
                    if line_counter == update_frequency:
                        new_file_pos = mmap_obj.tell()
                        bytes_read = new_file_pos - current_file_position
                        current_file_position = new_file_pos
                        progress.update(bytes_read)
                        line_counter = 0
                    yield line[:-1]

    def read_and_tell(self):
        current_file_position = 0
        with open(self.file_path, "r", encoding="utf8") as fh:
            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                for line in iter(mmap_obj.readline, b""):
                    line = line.decode("utf-8")
                    new_file_pos = mmap_obj.tell()
                    raw_bytes_read = new_file_pos - current_file_position
                    current_file_position = new_file_pos
                    yield line[:-1], raw_bytes_read

    def read(self):
        with open(self.file_path, "r", encoding="utf8") as fh:
            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                for line in iter(mmap_obj.readline, b""):
                    line = line.decode("utf-8")
                    yield line[:-1]

    def read_slow(self):
        with open(self.file_path, "r", encoding="utf8") as fh:
            while True:
                line = fh.readline()
                if line == -1 or line == "":
                    break
                else:
                    yield line[:-1]


# Optimized for speed. Decompresses the archive in shell before
# using the mmap'd TextReader.
class ZStdTextReader:
    def __init__(self, file) -> None:
        self.file = file

    def read_tqdm(self):
        decompressed_file = self.file[:-4]
        print("Decompressing file, please wait...")
        os.system(f"zstd -d {self.file}")  # linux decompress is faster
        reader = TextReader(decompressed_file)
        yield from reader.read_tqdm()
        os.remove(decompressed_file)
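
A minimal round-trip sketch (an assumption, not part of this commit) using the Archive and Reader classes above: write a zstandard-compressed JSONL archive, then stream it back. The "data/example.jsonl.zst" path is hypothetical.

archive = Archive("data/example.jsonl.zst")
archive.add_data("first document", meta={"source": "example"})
archive.add_data("second document")
archive.commit()

reader = Reader()
for text, meta in reader.read("data/example.jsonl.zst", get_meta=True):
    print(text, meta)  # each record comes back as (text, meta dict)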
scripts/yans/lm-evaluation-harness/lm_eval/decontamination/decontaminate.py
ADDED
@@ -0,0 +1,166 @@
import collections
import glob
import json
import os
import pickle
import random
import time

from .archiver import ZStdTextReader
from .janitor import Janitor, word_ngrams


# Was used for testing the evaluator decoupled from the full logic below
def get_train_overlap_stub(docs: dict, ngrams_path: str, ngrams_n_size: str):
    simulated_overlap = 0.1
    contaminated = int(len(docs) * simulated_overlap)
    return random.sample(range(len(docs)), contaminated)


# Returns a dictionary containing all overlapping documents in each
# task. In the standard use case, an overlap occurs when any of the 13-grams
# found in the task document exist in the training set documents.
#
# To generate 13-grams for the pile see scripts/clean_training_data. The final output of these
# scripts are an info.json file containing the n_gram_size (13) and a bunch of "ngrams_{x}.bkt.txt.sorted.zst"
# files. These should exist in the "ngrams_path" provided to this function.


# Algorithm:
# 1. Build lookups for each dataset {ngram: list(document_ids)}
# 2. Merge into an overall lookup {ngram: [(task_name, task_set, doc_ids),]}
# 3. Full scan the 13-grams from the training set against the merged lookup,
#    saving matches in the "duplicates" dictionary {(task_name, task_set): set(doc_ids)}
# 4. Strip the task_set from the dictionary keys and return
#
# We cache the task+set lookups as well as the overlaps.
def get_train_overlap(docs_by_task_set: dict, ngrams_path: str, limit: int) -> dict:
    # return get_train_overlap_stub(docs, ngrams_path, ngrams_n_size)

    info_dict_path = os.path.join(ngrams_path, "info.json")
    info_dict = json.load(open(info_dict_path, "r", encoding="utf-8"))
    ngrams_n_size = info_dict["ngram_size"]

    janitor = Janitor()

    # Build lookup for each dataset first in case we use different task combinations later
    print("Building Lookups...")
    start = time.perf_counter()

    def get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit) -> str:
        return f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.overlaps"

    lookups = {}
    duplicates = {}  # (task_name, task_set): set(doc_ids)}
    sets_to_decontaminate = len(docs_by_task_set.keys())

    for (task_name, task_set), docs in docs_by_task_set.items():
        if not os.path.exists(f"data/{task_name}"):
            os.mkdir(f"data/{task_name}")

        # Check if we've decontaminated this combination before
        overlaps_dump_path = get_overlaps_dump_path(
            task_name, task_set, ngrams_n_size, limit
        )
        if os.path.exists(overlaps_dump_path):
            duplicates[(task_name, task_set)] = pickle.load(
                open(overlaps_dump_path, "rb")
            )
            sets_to_decontaminate -= 1
            continue
        else:
            duplicates[(task_name, task_set)] = set()

        # Build/load the task lookup {ngram: set(documents)}.
        task_set_lookup_path = (
            f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.lookup"
        )
        if os.path.exists(task_set_lookup_path):
            print(f"{task_set_lookup_path} available, loading...")
            lookups[(task_name, task_set)] = pickle.load(
                open(task_set_lookup_path, "rb")
            )
        else:
            print(f"{task_set_lookup_path} not available, building...")
            lookup = collections.defaultdict(set)

            for doc_id, document in enumerate(docs):
                ngrams = word_ngrams(janitor.normalize_string(document), ngrams_n_size)
                for ngram in ngrams:
                    lookup[ngram].add(doc_id)

            pickle.dump(lookup, open(task_set_lookup_path, "wb"))
            lookups[(task_name, task_set)] = lookup

    elapsed = time.perf_counter() - start
    print(f"Building lookups took {elapsed:0.5f} seconds.")

    matched_ngrams = []

    if sets_to_decontaminate > 0:
        print("Merging lookups...")
        start = time.perf_counter()
        merged_lookup = collections.defaultdict(list)
        for (task_name, task_set), lookup in lookups.items():
            for ngram, doc_ids in lookup.items():
                merged_lookup[ngram].append((task_name, task_set, doc_ids))

        elapsed = time.perf_counter() - start
        print(f"Merging lookups took {elapsed:0.5f} seconds.")

        print(f"{ngrams_n_size} grams files found in {ngrams_path}:")
        files = glob.glob(os.path.join(ngrams_path, "*.sorted.zst"))
        print(files)

        for file in files:
            start = time.perf_counter()
            print(f"Scanning {file}")
            reader = ZStdTextReader(file)
            total_ngrams = 0
            unique_ngrams = 0
            matching_unique = 0
            non_matching_unique = 0

            current_ngram = ""
            for line in reader.read_tqdm():  # Scan training set ngrams file
                total_ngrams += 1
                [ngram, document_id] = line.rsplit(" ", 1)
                if (
                    ngram != current_ngram
                ):  # Only need to match the ngram once in training set
                    unique_ngrams += 1
                    current_ngram = ngram
                    if ngram in merged_lookup:
                        matched_ngrams.append(ngram)  # For logging
                        matching_unique += 1
                        for task_name, task_set, doc_ids in merged_lookup[ngram]:
                            task_doc_set = duplicates[(task_name, task_set)]
                            for doc_id in doc_ids:  # Record contamination across all relevant task/set combos
                                task_doc_set.add(doc_id)
                        del merged_lookup[ngram]  # No point matching again
                    else:
                        non_matching_unique += 1

            print(f"Total Ngrams: {total_ngrams}")
            print(f"Unique Ngrams: {unique_ngrams}")
            print(f"Unique Matching: {matching_unique}")
            print(f"Unique Non Matching: {non_matching_unique}")
            print("Matched ngrams:")
            for ngram in matched_ngrams:
                print(ngram)

            elapsed = time.perf_counter() - start
            print(f"Read took {elapsed:0.5f} seconds.")
            print(f"Speed: {(os.path.getsize(file)/1000000.0)/elapsed}MB/second")

        print(duplicates)

        # Dump overlaps separately
        for (task_name, task_set), doc_ids in duplicates.items():
            overlaps_dump_path = get_overlaps_dump_path(
                task_name, task_set, ngrams_n_size, limit
            )
            pickle.dump(doc_ids, open(overlaps_dump_path, "wb"))

    # Strip task set and return
    return {task_name: doc_ids for (task_name, task_set), doc_ids in duplicates.items()}
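
A minimal invocation sketch (an assumption, not part of this commit) of get_train_overlap above. It presumes 13-gram files and an info.json produced by scripts/clean_training_data already exist under the hypothetical "pile_13grams" directory, and that a local "data/" directory is present for the cached lookups and overlaps.

docs_by_task_set = {
    ("lambada", "test"): ["first eval document text ...", "second eval document text ..."],
}
overlaps = get_train_overlap(docs_by_task_set, ngrams_path="pile_13grams", limit=1000)
# e.g. {"lambada": {0, 7}} -- indices of eval documents whose 13-grams also occur in the training data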
scripts/yans/lm-evaluation-harness/lm_eval/decontamination/janitor.py
ADDED
@@ -0,0 +1,328 @@
import pickle
import re
import string
import traceback
from typing import Iterator, List, Sequence, Tuple, TypeVar


# This is a cpp module. Compile janitor_util.cpp with:
# c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup
try:
    import janitor_util

    JANITOR_CPP = True
except Exception:
    print("WARNING: C++ module could not be loaded. Janitor running in python mode")
    traceback.print_exc()
    JANITOR_CPP = False

T = TypeVar("T")


# Implementation from nltk source
# https://www.nltk.org/_modules/nltk/util.html
def form_ngrams(sequence: Iterator[T], n: int) -> Iterator[Tuple[T, ...]]:
    history = []
    while n > 1:
        # PEP 479, prevent RuntimeError from being raised when StopIteration bubbles out of generator
        try:
            next_item = next(sequence)
        except StopIteration:
            # no more data, terminate the generator
            return
        history.append(next_item)
        n -= 1
    for item in sequence:
        history.append(item)
        yield tuple(history)
        del history[0]


def word_ngrams(s: str, n: int) -> Iterator[str]:
    """Splits a string into ngram words"""
    tokens = s.split()  # not a generator :(
    ngram_seqs = form_ngrams(iter(tokens), n)
    return (" ".join(ngram) for ngram in ngram_seqs)


# Does character sequences only - combined faster function to play around with later
# def word_ngrams_indices_combined(sequence, n):
#     current_word = ""
#     history = []
#     gap = False;
#     start = 0
#     end = 0
#     for character in sequence:
#         if character == " ":
#             if not gap:
#                 gap = True
#                 history.append(current_word)
#                 end += len(current_word) - 1
#                 current_word = ""
#                 if len(history) == n:
#                     yield (tuple(history), start, end)
#                     del history[0]
#                     start = end + 1
#                     end = start
#         else:
#             gap = False
#             current_word += character


# https://stackoverflow.com/questions/13734451/string-split-with-indices-in-python
def split_indices(s: str) -> Iterator[Tuple[str, Tuple[int, int]]]:
    """Splits a string on whitespaces and records the indices of each in the original string.
    @:return generator((word, (start_idx, end_idx)), ...)
    """
    return ((m.group(0), (m.start(), m.end() - 1)) for m in re.finditer(r"\S+", s))


def word_ngrams_indices(s: str, n: int) -> Iterator[Tuple[str, Tuple[int, int]]]:
    """Splits a string into pairs of (ngram words, their start/end indices)"""
    tokens_with_indices = split_indices(s)

    # Generator of ngrams of (word, idx_pairs)
    # (
    #   [(word, (start,end)), (word, (start, end))...],
    #   [(word, (start, end)), ...],
    #   ...
    # )
    ngram_seqs_with_indices = form_ngrams(tokens_with_indices, n)

    # Generator of pairs of word and index ngrams
    # (
    #   ([word, word, ...], [(start,end), (start,end), ...]),
    #   ...
    # )
    ngram_indices_pairs = (
        zip(*ngram_with_indices) for ngram_with_indices in ngram_seqs_with_indices
    )

    # Generator of ( (word_ngram, (start, end)), (word_ngram, start, end)), ...)
    return (
        (" ".join(ngram_seq), (indices[0][0], indices[-1][1]))
        for ngram_seq, indices in ngram_indices_pairs
    )


class Janitor:
    # FIXME delete_chars: Should anything else go here? Special chars?
    def __init__(
        self,
        ngram_n: int = 13,
        window_to_remove: int = 200,
        too_dirty_cutoff: int = 10,
        minimum_slice_length: int = 200,
        delete_chars: str = string.punctuation,
    ) -> None:
        self.ngram_n = ngram_n
        self.window_to_remove = window_to_remove
        self.too_dirty_cutoff = too_dirty_cutoff
        self.minimum_slice_length = minimum_slice_length
        self.delete_chars = delete_chars

        self.dirt_ngrams = set()

        # If in python, we'll translate uppercase to lowercase and delete naughty characters.
        # This is fast by python standards
        # https://stackoverflow.com/questions/638893/what-is-the-most-efficient-way-in-python-to-convert-a-string-to-all-lowercase-st
        self.translation_table = str.maketrans(
            string.ascii_lowercase + string.ascii_uppercase,  # These characters
            string.ascii_lowercase * 2,  # Become these characters
            self.delete_chars,  # These are deleted
        )

    ##############
    # I/O for saving contamination ngrams
    ##############

    def save_contamination_ngrams(self, filename: str) -> None:
        with open(filename, "wb") as fp:
            pickle.dump(filename, fp)

    def load_contamination_ngrams(self, filename: str) -> None:
        with open(filename, "rb") as fp:
            self.dirt_ngrams = pickle.load(fp)

    ##############
    # Call these :)
    ##############

    def register_contaminant(self, dirt_string: str) -> None:
        """Register a string as contamination to be removed, e.g. a test set
        This breaks the dirt_string into ngrams to store for future cleaning"""
        if JANITOR_CPP:
            return self.register_contaminant_cpp(dirt_string)
        else:
            print("WARNING: Janitor running in python mode")
            return self.register_contaminant_python(dirt_string)

    def clean(self, dirty_string: str) -> List[str]:
        """Clean a string (e.g. a training set) by removing all ngrams previously
        registered as contaminants. Returns a list of clean chunks, or empty if
        the string was too dirty"""
        if JANITOR_CPP:
            return self.clean_cpp(dirty_string)
        else:
            print("WARNING: Janitor running in python mode")
            return self.clean_python(dirty_string)

    def _split_chunks(
        self, dirty_string: str, dirty_parts: Sequence[Tuple]
    ) -> List[str]:
        clean_chunks = []
        splice_idx = 0
        end = -1
        for i, (ngram, start, end) in enumerate(dirty_parts):
            if i >= self.too_dirty_cutoff:
                return []
            start = max(0, start - self.window_to_remove)
            end = min(len(dirty_string), end + self.window_to_remove)

            if start - splice_idx > self.minimum_slice_length:
                clean_chunks.append(dirty_string[splice_idx:start])
            splice_idx = end

        if end < len(dirty_string) - self.minimum_slice_length:
            clean_chunks.append(dirty_string[end + 1 :])

        return clean_chunks

    ##############
    # Fast C++
    ##############

    def register_contaminant_cpp(self, dirt_string) -> None:
        self.dirt_ngrams.update(
            janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n)
        )

    def clean_cpp(self, dirty_string: str) -> List[str]:
        contamination_indices = janitor_util.clean_ngram_with_indices(
            dirty_string, self.delete_chars, self.ngram_n
        )
        return self._split_chunks(dirty_string, contamination_indices)

    ##############
    # Slow python
    ##############

    def normalize_string(self, s: str) -> str:
        return s.translate(self.translation_table)

    def register_contaminant_python(self, dirt_string: str) -> None:
        self.dirt_ngrams.update(
            word_ngrams(self.normalize_string(dirt_string), self.ngram_n)
        )

    def clean_python(self, dirty_string: str) -> List[str]:
        contamination_indices = (
            (None, *idx_pair)
            for dirty_ngram, idx_pair in word_ngrams_indices(dirty_string, self.ngram_n)
            if self.normalize_string(dirty_ngram) in self.dirt_ngrams
        )
        return self._split_chunks(dirty_string, contamination_indices)


##################################################################
# Tests
#################################################################

# def print_cpp():
#     source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2

#     for i in range(1, 10, 2):
#         pprint(janitor_util.clean_ngram(source, string.punctuation, i))
#         for ngram, start, end in \
#                 janitor_util.clean_ngram_with_indices(source, string.punctuation, i):
#             print(ngram, "\t", start, end, source[start:end].replace("\n", "\\n"))


# def test_cpp():
#     source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2
#     contaminant = "dirty boy. Clean he he"

#     jan_python = Janitor()
#     jan_cpp = Janitor()

#     jan_python.register_contaminant_python(contaminant)
#     jan_cpp.register_contaminant(contaminant)

#     assert jan_python.dirt_ngrams == jan_cpp.dirt_ngrams, (jan_python.dirt_ngrams, jan_cpp.dirt_ngrams)

#     assert jan_python.clean_python(source) == jan_cpp.clean(source), \
#         (jan_python.clean_python(source), jan_cpp.clean(source))

#     print("Passed test, python==cpp")


# def benchmark():
#     # Download and put in data folder: enwik8 (100 MB) from https://cs.fit.edu/~mmahoney/compression/textdata.html
#     setup = \
#         """
# with open("data/enwik8", "r") as f:
#     data = f.read()
# jan = Janitor(too_dirty_cutoff=1000)
# jan.register_contaminant('''
# theories is that there is a connection between "geekdom" and autism.
# This is hinted, for instance, by a ''Wired Magazine'' article in 2001 entitled "
# The [[Geek]] Syndrome", which is a point argued by many in the autism rights
# movement{{ref|Wired}}. This article, many professionals assert, is just one example of
# the media's application of mental disease labels to what is actually variant normal behavior
# &mdash;they argue that shyness, lack of athletic ability or social skills, and intellectual
# interests, even when they seem unusual to others, are not in themselves signs of autism or
# Asperger's syndrome. Others assert that it is actually the medical profession which is applying
# mental disease labels to children who in the past would have simply been accepted as a little
# different or even labeled 'gifted'. See [[clinomorphism]] for further discussion of this issue.
# Due to the recent publicity surrounding autism and autis
# ultan Al Nahyan]] granted [[Petroleum]] concessions, and oil was first found in 1958. At first,
# oil money had a marginal impact. A few lowrise concete buildings were erected, and the first
# paved road was completed in 1961, but Sheikh Shakbut, uncertain whether the new oil royalties
# would last, took a cautious approach, preferring to save the revenue rather than investing it in
# development. His brother, [[Zayed bin Sultan Al Nahayan]], saw that oil wealth had the potential
# to transform Abu Dhabi. The ruling Al Nahayan family decided that Sheikh Zayed should replace his
# brother as Ruler and carry out his vision of developing the country. On [[August 6]], [[1966]],
# with the assistance of the British, Sheikh Zayed became the new ruler. See generally, Al-Fahim, M,
# ''From Rags to Riches: A Story of Abu Dhabi'', Chapter Six (London Centre of Arab Studies, 1995),
# ISBN 1 900404 00 1. With the announcement by Britain in 1968 that it would withdraw from the
# Gulf area by 1971, Sheikh Zayed became the main driving force behind the formation of the
# [[United Arab Emirates]]. After the Emirates gained independence in 1971,
# ''')
# """

#     n = 1
#     print(f"Timing {n} run on 100 MB")
#     print("Register contaminant")
#     # print("\tPython", timeit.timeit("jan.register_contaminant_python(data)", setup=setup, globals=globals(), number=n))
#     print("\tCpp", timeit.timeit("jan.register_contaminant(data)", setup=setup, globals=globals(), number=n))

#     print("Clean")
#     # print("\tPython", timeit.timeit("jan.clean_python(data)", setup=setup, globals=globals(), number=n))
#     print("\tCpp", timeit.timeit("jan.clean(data)", setup=setup, globals=globals(), number=n))


# def test_janitor_general():
#     source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2
#     contaminant = "dirty boy. Clean he he"

#     jan = Janitor(ngram_n=3)
#     jan.register_contaminant(contaminant)
#     cleaned = " ".join(jan.clean(source))
#     for contam in jan.dirt_ngrams:
#         assert contam not in cleaned, contam

#     filename = "data/saved_contam"
#     jan.save_contamination_ngrams(filename)

#     jan = Janitor(ngram_n=3)
#     jan.load_contamination_ngrams(filename)
#     cleaned = " ".join(jan.clean(source))
#     for contam in jan.dirt_ngrams:
#         assert contam not in cleaned, contam


# if __name__ == "__main__":
#     test()
#     # print_cpp()
#     # test_cpp()
#     # benchmark()
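
A minimal sketch (an assumption, not part of this commit) of the pure-Python path, which is what runs whenever the compiled janitor_util extension is unavailable; the small window_to_remove and minimum_slice_length values are chosen only so a short example string survives the cleaning.

jan = Janitor(ngram_n=3, window_to_remove=10, minimum_slice_length=5)
jan.register_contaminant_python("dirty boy. Clean he he")
chunks = jan.clean_python(" ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. lastword ")
print(chunks)  # the slices that survive once a window around each contaminated ngram is cut out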
scripts/yans/lm-evaluation-harness/lm_eval/loggers/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .evaluation_tracker import EvaluationTracker
from .wandb_logger import WandbLogger
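
A minimal usage sketch (an assumption, not part of this commit) of the exported EvaluationTracker, based on the constructor and methods added in evaluation_tracker.py below; with no Hugging Face token, results are only written under output_path and nothing is pushed to the Hub. The model_args string and the empty results/samples dicts are placeholders.

from lm_eval.loggers import EvaluationTracker

tracker = EvaluationTracker(output_path="results")  # no HF token, so nothing is pushed
tracker.general_config_tracker.log_experiment_args(
    model_source="hf",
    model_args="pretrained=EleutherAI/pythia-70m",
    system_instruction=None,
    chat_template=None,
    fewshot_as_multiturn=False,
)
# aggregated metrics dict as produced by the evaluator; samples may be empty
tracker.save_results_aggregated(results={"results": {}}, samples={})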
scripts/yans/lm-evaluation-harness/lm_eval/loggers/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (272 Bytes).
scripts/yans/lm-evaluation-harness/lm_eval/loggers/__pycache__/evaluation_tracker.cpython-310.pyc
ADDED
Binary file (15.6 kB).
scripts/yans/lm-evaluation-harness/lm_eval/loggers/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (4.25 kB).
scripts/yans/lm-evaluation-harness/lm_eval/loggers/__pycache__/wandb_logger.cpython-310.pyc
ADDED
Binary file (11.6 kB).
scripts/yans/lm-evaluation-harness/lm_eval/loggers/evaluation_tracker.py
ADDED
@@ -0,0 +1,521 @@
import json
import os
import re
import time
from collections import defaultdict
from dataclasses import asdict, dataclass
from datetime import datetime
from pathlib import Path

from datasets import load_dataset
from datasets.utils.metadata import MetadataConfigs
from huggingface_hub import (
    DatasetCard,
    DatasetCardData,
    HfApi,
    hf_hub_url,
)
from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status

from lm_eval.utils import (
    eval_logger,
    get_file_datetime,
    get_file_task_name,
    get_results_filenames,
    get_sample_results_filenames,
    handle_non_serializable,
    hash_string,
    sanitize_list,
    sanitize_model_name,
    sanitize_task_name,
)


@dataclass(init=False)
class GeneralConfigTracker:
    """
    Tracker for the evaluation parameters.

    Attributes:
        model_source (str): Source of the model (e.g. Hugging Face, GGUF, etc.)
        model_name (str): Name of the model.
        model_name_sanitized (str): Sanitized model name for directory creation.
        start_time (float): Start time of the experiment. Logged at class init.
        end_time (float): Start time of the experiment. Logged when calling [`GeneralConfigTracker.log_end_time`]
        total_evaluation_time_seconds (str): Inferred total evaluation time in seconds (from the start and end times).
    """

    model_source: str = None
    model_name: str = None
    model_name_sanitized: str = None
    system_instruction: str = None
    system_instruction_sha: str = None
    fewshot_as_multiturn: bool = None
    chat_template: str = None
    chat_template_sha: str = None
    start_time: float = None
    end_time: float = None
    total_evaluation_time_seconds: str = None

    def __init__(self) -> None:
        """Starts the evaluation timer."""
        self.start_time = time.perf_counter()

    @staticmethod
    def _get_model_name(model_args: str) -> str:
        """Extracts the model name from the model arguments."""

        def extract_model_name(model_args: str, key: str) -> str:
            """Extracts the model name from the model arguments using a key."""
            args_after_key = model_args.split(key)[1]
            return args_after_key.split(",")[0]

        # order does matter, e.g. peft and delta are provided together with pretrained
        prefixes = ["peft=", "delta=", "pretrained=", "model=", "path=", "engine="]
        for prefix in prefixes:
            if prefix in model_args:
                return extract_model_name(model_args, prefix)
        return ""

    def log_experiment_args(
        self,
        model_source: str,
        model_args: str,
        system_instruction: str,
        chat_template: str,
        fewshot_as_multiturn: bool,
    ) -> None:
        """Logs model parameters and job ID."""
        self.model_source = model_source
        self.model_name = GeneralConfigTracker._get_model_name(model_args)
        self.model_name_sanitized = sanitize_model_name(self.model_name)
        self.system_instruction = system_instruction
        self.system_instruction_sha = (
            hash_string(system_instruction) if system_instruction else None
        )
        self.chat_template = chat_template
        self.chat_template_sha = hash_string(chat_template) if chat_template else None
        self.fewshot_as_multiturn = fewshot_as_multiturn

    def log_end_time(self) -> None:
        """Logs the end time of the evaluation and calculates the total evaluation time."""
        self.end_time = time.perf_counter()
        self.total_evaluation_time_seconds = str(self.end_time - self.start_time)


class EvaluationTracker:
    """
    Keeps track and saves relevant information of the evaluation process.
    Compiles the data from trackers and writes it to files, which can be published to the Hugging Face hub if requested.
    """

    def __init__(
        self,
        output_path: str = None,
        hub_results_org: str = "",
        hub_repo_name: str = "",
        details_repo_name: str = "",
        results_repo_name: str = "",
        push_results_to_hub: bool = False,
        push_samples_to_hub: bool = False,
        public_repo: bool = False,
        token: str = "",
        leaderboard_url: str = "",
        point_of_contact: str = "",
        gated: bool = False,
    ) -> None:
        """
        Creates all the necessary loggers for evaluation tracking.

        Args:
            output_path (str): Path to save the results. If not provided, the results won't be saved.
            hub_results_org (str): The Hugging Face organization to push the results to. If not provided, the results will be pushed to the owner of the Hugging Face token.
            hub_repo_name (str): The name of the Hugging Face repository to push the results to. If not provided, the results will be pushed to `lm-eval-results`.
            details_repo_name (str): The name of the Hugging Face repository to push the details to. If not provided, the results will be pushed to `lm-eval-results`.
            result_repo_name (str): The name of the Hugging Face repository to push the results to. If not provided, the results will not be pushed and will be found in the details_hub_repo.
            push_results_to_hub (bool): Whether to push the results to the Hugging Face hub.
            push_samples_to_hub (bool): Whether to push the samples to the Hugging Face hub.
            public_repo (bool): Whether to push the results to a public or private repository.
            token (str): Token to use when pushing to the Hugging Face hub. This token should have write access to `hub_results_org`.
            leaderboard_url (str): URL to the leaderboard on the Hugging Face hub on the dataset card.
            point_of_contact (str): Contact information on the Hugging Face hub dataset card.
            gated (bool): Whether to gate the repository.
        """
        self.general_config_tracker = GeneralConfigTracker()

        self.output_path = output_path
        self.push_results_to_hub = push_results_to_hub
        self.push_samples_to_hub = push_samples_to_hub
        self.public_repo = public_repo
        self.leaderboard_url = leaderboard_url
        self.point_of_contact = point_of_contact
        self.api = HfApi(token=token) if token else None
        self.gated_repo = gated

        if not self.api and (push_results_to_hub or push_samples_to_hub):
            raise ValueError(
                "Hugging Face token is not defined, but 'push_results_to_hub' or 'push_samples_to_hub' is set to True. "
                "Please provide a valid Hugging Face token by setting the HF_TOKEN environment variable."
            )

        if (
            self.api
            and hub_results_org == ""
            and (push_results_to_hub or push_samples_to_hub)
        ):
            hub_results_org = self.api.whoami()["name"]
            eval_logger.warning(
                f"hub_results_org was not specified. Results will be pushed to '{hub_results_org}'."
            )

        if hub_repo_name == "":
            details_repo_name = (
                details_repo_name if details_repo_name != "" else "lm-eval-results"
            )
            results_repo_name = (
                results_repo_name if results_repo_name != "" else details_repo_name
            )
        else:
            details_repo_name = hub_repo_name
            results_repo_name = hub_repo_name
            eval_logger.warning(
                "hub_repo_name was specified. Both details and results will be pushed to the same repository. Using hub_repo_name is no longer recommended, details_repo_name and results_repo_name should be used instead."
            )

        self.details_repo = f"{hub_results_org}/{details_repo_name}"
        self.details_repo_private = f"{hub_results_org}/{details_repo_name}-private"
        self.results_repo = f"{hub_results_org}/{results_repo_name}"
        self.results_repo_private = f"{hub_results_org}/{results_repo_name}-private"

    def save_results_aggregated(
        self,
        results: dict,
        samples: dict,
    ) -> None:
        """
        Saves the aggregated results and samples to the output path and pushes them to the Hugging Face hub if requested.

        Args:
            results (dict): The aggregated results to save.
            samples (dict): The samples results to save.
        """
        self.general_config_tracker.log_end_time()

        if self.output_path:
            try:
                eval_logger.info("Saving results aggregated")

                # calculate cumulative hash for each task - only if samples are provided
                task_hashes = {}
                if samples:
                    for task_name, task_samples in samples.items():
                        sample_hashes = [
                            s["doc_hash"] + s["prompt_hash"] + s["target_hash"]
                            for s in task_samples
                        ]
                        task_hashes[task_name] = hash_string("".join(sample_hashes))

                # update initial results dict
                results.update({"task_hashes": task_hashes})
                results.update(asdict(self.general_config_tracker))
                dumped = json.dumps(
                    results,
                    indent=2,
                    default=handle_non_serializable,
                    ensure_ascii=False,
                )

                path = Path(self.output_path if self.output_path else Path.cwd())
                path = path.joinpath(self.general_config_tracker.model_name_sanitized)
                path.mkdir(parents=True, exist_ok=True)

                self.date_id = datetime.now().isoformat().replace(":", "-")
                file_results_aggregated = path.joinpath(f"results_{self.date_id}.json")
                file_results_aggregated.open("w", encoding="utf-8").write(dumped)

                if self.api and self.push_results_to_hub:
                    repo_id = (
                        self.results_repo
                        if self.public_repo
                        else self.results_repo_private
                    )
                    self.api.create_repo(
                        repo_id=repo_id,
                        repo_type="dataset",
                        private=not self.public_repo,
                        exist_ok=True,
                    )
                    self.api.upload_file(
                        repo_id=repo_id,
                        path_or_fileobj=str(
                            path.joinpath(f"results_{self.date_id}.json")
                        ),
                        path_in_repo=os.path.join(
                            self.general_config_tracker.model_name,
                            f"results_{self.date_id}.json",
                        ),
                        repo_type="dataset",
                        commit_message=f"Adding aggregated results for {self.general_config_tracker.model_name}",
                    )
                    eval_logger.info(
                        "Successfully pushed aggregated results to the Hugging Face Hub. "
                        f"You can find them at: {repo_id}"
                    )

            except Exception as e:
                eval_logger.warning("Could not save results aggregated")
                eval_logger.info(repr(e))
        else:
            eval_logger.info(
                "Output path not provided, skipping saving results aggregated"
            )

    def save_results_samples(
        self,
        task_name: str,
        samples: dict,
    ) -> None:
        """
        Saves the samples results to the output path and pushes them to the Hugging Face hub if requested.

        Args:
            task_name (str): The task name to save the samples for.
            samples (dict): The samples results to save.
        """
        if self.output_path:
            try:
                eval_logger.info(f"Saving per-sample results for: {task_name}")

                path = Path(self.output_path if self.output_path else Path.cwd())
                path = path.joinpath(self.general_config_tracker.model_name_sanitized)
                path.mkdir(parents=True, exist_ok=True)

                file_results_samples = path.joinpath(
                    f"samples_{task_name}_{self.date_id}.jsonl"
                )

                for sample in samples:
                    # we first need to sanitize arguments and resps
                    # otherwise we won't be able to load the dataset
                    # using the datasets library
                    arguments = {}
                    for i, arg in enumerate(sample["arguments"]):
                        arguments[f"gen_args_{i}"] = {}
                        for j, tmp in enumerate(arg):
                            arguments[f"gen_args_{i}"][f"arg_{j}"] = tmp

                    sample["resps"] = sanitize_list(sample["resps"])
                    sample["filtered_resps"] = sanitize_list(sample["filtered_resps"])
                    sample["arguments"] = arguments
                    sample["target"] = str(sample["target"])

                    sample_dump = (
                        json.dumps(
                            sample,
                            default=handle_non_serializable,
                            ensure_ascii=False,
                        )
                        + "\n"
                    )

                    with open(file_results_samples, "a", encoding="utf-8") as f:
                        f.write(sample_dump)

                if self.api and self.push_samples_to_hub:
                    repo_id = (
                        self.details_repo
                        if self.public_repo
                        else self.details_repo_private
                    )
                    self.api.create_repo(
                        repo_id=repo_id,
                        repo_type="dataset",
                        private=not self.public_repo,
                        exist_ok=True,
                    )
                    try:
                        if self.gated_repo:
                            headers = build_hf_headers()
                            r = get_session().put(
                                url=f"https://huggingface.co/api/datasets/{repo_id}/settings",
                                headers=headers,
                                json={"gated": "auto"},
                            )
                            hf_raise_for_status(r)
                    except Exception as e:
                        eval_logger.warning("Could not gate the repository")
                        eval_logger.info(repr(e))
                    self.api.upload_folder(
                        repo_id=repo_id,
                        folder_path=str(path),
                        path_in_repo=self.general_config_tracker.model_name_sanitized,
                        repo_type="dataset",
                        commit_message=f"Adding samples results for {task_name} to {self.general_config_tracker.model_name}",
                    )
                    eval_logger.info(
                        f"Successfully pushed sample results for task: {task_name} to the Hugging Face Hub. "
                        f"You can find them at: {repo_id}"
                    )

            except Exception as e:
                eval_logger.warning("Could not save sample results")
                eval_logger.info(repr(e))
        else:
            eval_logger.info("Output path not provided, skipping saving sample results")

    def recreate_metadata_card(self) -> None:
        """
        Creates a metadata card for the evaluation results dataset and pushes it to the Hugging Face hub.
        """

        eval_logger.info("Recreating metadata card")
        repo_id = self.details_repo if self.public_repo else self.details_repo_private

        files_in_repo = self.api.list_repo_files(repo_id=repo_id, repo_type="dataset")
        results_files = get_results_filenames(files_in_repo)
        sample_files = get_sample_results_filenames(files_in_repo)

        # Build a dictionary to store the latest evaluation datetime for:
        # - Each tested model and its aggregated results
        # - Each task and sample results, if existing
        # i.e. {
        #     "org__model_name__gsm8k": "2021-09-01T12:00:00",
        #     "org__model_name__ifeval": "2021-09-01T12:00:00",
        #     "org__model_name__results": "2021-09-01T12:00:00"
        # }
        latest_task_results_datetime = defaultdict(lambda: datetime.min.isoformat())

        for file_path in sample_files:
            file_path = Path(file_path)
            filename = file_path.name
            model_name = file_path.parent
            task_name = get_file_task_name(filename)
            results_datetime = get_file_datetime(filename)
            task_name_sanitized = sanitize_task_name(task_name)
            # Results and sample results for the same model and task will have the same datetime
            samples_key = f"{model_name}__{task_name_sanitized}"
            results_key = f"{model_name}__results"
            latest_datetime = max(
                latest_task_results_datetime[samples_key],
                results_datetime,
            )
            latest_task_results_datetime[samples_key] = latest_datetime
            latest_task_results_datetime[results_key] = max(
                latest_task_results_datetime[results_key],
                latest_datetime,
            )

        # Create metadata card
        card_metadata = MetadataConfigs()

        # Add the latest aggregated results to the metadata card for easy access
        for file_path in results_files:
            file_path = Path(file_path)
            results_filename = file_path.name
            model_name = file_path.parent
            eval_date = get_file_datetime(results_filename)
            eval_date_sanitized = re.sub(r"[^\w\.]", "_", eval_date)
            results_filename = Path("**") / Path(results_filename).name
            config_name = f"{model_name}__results"
            sanitized_last_eval_date_results = re.sub(
                r"[^\w\.]", "_", latest_task_results_datetime[config_name]
            )

            if eval_date_sanitized == sanitized_last_eval_date_results:
                # Ensure that all results files are listed in the metadata card
                current_results = card_metadata.get(config_name, {"data_files": []})
                current_results["data_files"].append(
                    {"split": eval_date_sanitized, "path": [str(results_filename)]}
                )
                card_metadata[config_name] = current_results
                # If the results file is the newest, update the "latest" field in the metadata card
                card_metadata[config_name]["data_files"].append(
                    {"split": "latest", "path": [str(results_filename)]}
                )

        # Add the tasks details configs
        for file_path in sample_files:
            file_path = Path(file_path)
            filename = file_path.name
            model_name = file_path.parent
            task_name = get_file_task_name(filename)
            eval_date = get_file_datetime(filename)
            task_name_sanitized = sanitize_task_name(task_name)
            eval_date_sanitized = re.sub(r"[^\w\.]", "_", eval_date)
            results_filename = Path("**") / Path(filename).name
            config_name = f"{model_name}__{task_name_sanitized}"
            sanitized_last_eval_date_results = re.sub(
                r"[^\w\.]", "_", latest_task_results_datetime[config_name]
            )
            if eval_date_sanitized == sanitized_last_eval_date_results:
                # Ensure that all sample results files are listed in the metadata card
                current_details_for_task = card_metadata.get(
                    config_name, {"data_files": []}
                )
                current_details_for_task["data_files"].append(
                    {"split": eval_date_sanitized, "path": [str(results_filename)]}
                )
                card_metadata[config_name] = current_details_for_task
                # If the samples results file is the newest, update the "latest" field in the metadata card
                card_metadata[config_name]["data_files"].append(
                    {"split": "latest", "path": [str(results_filename)]}
|
462 |
+
)
|
463 |
+
|
464 |
+
# Get latest results and extract info to update metadata card examples
|
465 |
+
latest_datetime = max(latest_task_results_datetime.values())
|
466 |
+
latest_model_name = max(
|
467 |
+
latest_task_results_datetime, key=lambda k: latest_task_results_datetime[k]
|
468 |
+
)
|
469 |
+
last_results_file = [
|
470 |
+
f for f in results_files if latest_datetime.replace(":", "-") in f
|
471 |
+
][0]
|
472 |
+
last_results_file_path = hf_hub_url(
|
473 |
+
repo_id=repo_id, filename=last_results_file, repo_type="dataset"
|
474 |
+
)
|
475 |
+
latest_results_file = load_dataset(
|
476 |
+
"json", data_files=last_results_file_path, split="train"
|
477 |
+
)
|
478 |
+
results_dict = latest_results_file["results"][0]
|
479 |
+
new_dictionary = {"all": results_dict}
|
480 |
+
new_dictionary.update(results_dict)
|
481 |
+
results_string = json.dumps(new_dictionary, indent=4)
|
482 |
+
|
483 |
+
dataset_summary = (
|
484 |
+
"Dataset automatically created during the evaluation run of model "
|
485 |
+
)
|
486 |
+
if self.general_config_tracker.model_source == "hf":
|
487 |
+
dataset_summary += f"[{self.general_config_tracker.model_name}](https://huggingface.co/{self.general_config_tracker.model_name})\n"
|
488 |
+
else:
|
489 |
+
dataset_summary += f"{self.general_config_tracker.model_name}\n"
|
490 |
+
dataset_summary += (
|
491 |
+
f"The dataset is composed of {len(card_metadata)-1} configuration(s), each one corresponding to one of the evaluated task.\n\n"
|
492 |
+
f"The dataset has been created from {len(results_files)} run(s). Each run can be found as a specific split in each "
|
493 |
+
'configuration, the split being named using the timestamp of the run.The "train" split is always pointing to the latest results.\n\n'
|
494 |
+
'An additional configuration "results" store all the aggregated results of the run.\n\n'
|
495 |
+
"To load the details from a run, you can for instance do the following:\n"
|
496 |
+
)
|
497 |
+
if self.general_config_tracker.model_source == "hf":
|
498 |
+
dataset_summary += (
|
499 |
+
"```python\nfrom datasets import load_dataset\n"
|
500 |
+
f'data = load_dataset(\n\t"{repo_id}",\n\tname="{latest_model_name}",\n\tsplit="latest"\n)\n```\n\n'
|
501 |
+
)
|
502 |
+
dataset_summary += (
|
503 |
+
"## Latest results\n\n"
|
504 |
+
f'These are the [latest results from run {latest_datetime}]({last_results_file_path.replace("/resolve/", "/blob/")}) '
|
505 |
+
"(note that there might be results for other tasks in the repos if successive evals didn't cover the same tasks. "
|
506 |
+
'You find each in the results and the "latest" split for each eval):\n\n'
|
507 |
+
f"```python\n{results_string}\n```"
|
508 |
+
)
|
509 |
+
card_data = DatasetCardData(
|
510 |
+
dataset_summary=dataset_summary,
|
511 |
+
repo_url=f"https://huggingface.co/{self.general_config_tracker.model_name}",
|
512 |
+
pretty_name=f"Evaluation run of {self.general_config_tracker.model_name}",
|
513 |
+
leaderboard_url=self.leaderboard_url,
|
514 |
+
point_of_contact=self.point_of_contact,
|
515 |
+
)
|
516 |
+
card_metadata.to_dataset_card_data(card_data)
|
517 |
+
card = DatasetCard.from_template(
|
518 |
+
card_data,
|
519 |
+
pretty_name=card_data.pretty_name,
|
520 |
+
)
|
521 |
+
card.push_to_hub(repo_id, repo_type="dataset")
|
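For reference, a minimal sketch of how the details pushed by this tracker can be read back with the datasets library; the repo id and config name below are hypothetical, and the "latest" split follows the convention set up in recreate_metadata_card() above.

# Sketch (hypothetical repo/config names) of reading back the pushed details.
from datasets import load_dataset

details = load_dataset(
    "my-org/lm-eval-details",         # hypothetical details repo_id
    name="my_org__my_model__gsm8k",   # hypothetical "<model>__<task>" configuration
    split="latest",
)
print(details[0]["target"])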
scripts/yans/lm-evaluation-harness/lm_eval/loggers/utils.py
ADDED
@@ -0,0 +1,143 @@
import logging
import os
import re
import subprocess
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
from torch.utils.collect_env import get_pretty_env_info
from transformers import __version__ as trans_version


logger = logging.getLogger(__name__)


def remove_none_pattern(input_string: str) -> Tuple[str, bool]:
    """Remove the ',none' substring from the input_string if it exists at the end.

    Args:
        input_string (str): The input string from which to remove the ',none' substring.

    Returns:
        Tuple[str, bool]: A tuple containing the modified input_string with the ',none' substring removed
        and a boolean indicating whether the modification was made (True) or not (False).
    """
    # Define the pattern to match ',none' at the end of the string
    pattern = re.compile(r",none$")

    # Use sub() to replace ',none' with an empty string
    result = re.sub(pattern, "", input_string)

    # check if the input_string changed
    removed = result != input_string

    return result, removed


def _handle_non_serializable(o: Any) -> Union[int, str, list]:
    """Handle non-serializable objects by converting them to serializable types.

    Args:
        o (Any): The object to be handled.

    Returns:
        Union[int, str, list]: The converted object. If the object is of type np.int64 or np.int32,
            it will be converted to int. If the object is of type set, it will be converted
            to a list. Otherwise, it will be converted to str.
    """
    if isinstance(o, np.int64) or isinstance(o, np.int32):
        return int(o)
    elif isinstance(o, set):
        return list(o)
    else:
        return str(o)


def get_commit_from_path(repo_path: Union[Path, str]) -> Optional[str]:
    try:
        git_folder = Path(repo_path, ".git")
        if git_folder.is_file():
            git_folder = Path(
                git_folder.parent,
                git_folder.read_text(encoding="utf-8").split("\n")[0].split(" ")[-1],
            )
        if Path(git_folder, "HEAD").exists():
            head_name = (
                Path(git_folder, "HEAD")
                .read_text(encoding="utf-8")
                .split("\n")[0]
                .split(" ")[-1]
            )
            head_ref = Path(git_folder, head_name)
            git_hash = head_ref.read_text(encoding="utf-8").replace("\n", "")
        else:
            git_hash = None
    except Exception as err:
        logger.debug(
            f"Failed to retrieve a Git commit hash from path: {str(repo_path)}. Error: {err}"
        )
        return None
    return git_hash


def get_git_commit_hash():
    """
    Gets the git commit hash of your current repo (if it exists).
    Source: https://github.com/EleutherAI/gpt-neox/blob/b608043be541602170bfcfb8ec9bf85e8a0799e0/megatron/neox_arguments/neox_args.py#L42
    """
    try:
        git_hash = subprocess.check_output(["git", "describe", "--always"]).strip()
        git_hash = git_hash.decode()
    except (subprocess.CalledProcessError, FileNotFoundError):
        # FileNotFoundError occurs when git not installed on system
        git_hash = get_commit_from_path(os.getcwd())  # git hash of repo if exists
    return git_hash


def add_env_info(storage: Dict[str, Any]):
    try:
        pretty_env_info = get_pretty_env_info()
    except Exception as err:
        pretty_env_info = str(err)
    transformers_version = trans_version
    upper_dir_commit = get_commit_from_path(
        Path(os.getcwd(), "..")
    )  # git hash of upper repo if exists
    added_info = {
        "pretty_env_info": pretty_env_info,
        "transformers_version": transformers_version,
        "upper_git_hash": upper_dir_commit,  # in case this repo is submodule
    }
    storage.update(added_info)


def add_tokenizer_info(storage: Dict[str, Any], lm):
    if getattr(lm, "tokenizer", False):
        try:
            tokenizer_info = {
                "tokenizer_pad_token": [
                    lm.tokenizer.pad_token,
                    str(lm.tokenizer.pad_token_id),
                ],
                "tokenizer_eos_token": [
                    lm.tokenizer.eos_token,
                    str(lm.tokenizer.eos_token_id),
                ],
                "tokenizer_bos_token": [
                    lm.tokenizer.bos_token,
                    str(lm.tokenizer.bos_token_id),
                ],
                "eot_token_id": getattr(lm, "eot_token_id", None),
                "max_length": getattr(lm, "max_length", None),
            }
            storage.update(tokenizer_info)
        except Exception as err:
            logger.debug(
                f"Logging detailed tokenizer info failed with {err}, skipping..."
            )
        # seems gguf and textsynth do not have tokenizer
    else:
        logger.debug(
            "LM does not have a 'tokenizer' attribute, not logging tokenizer metadata to results."
        )
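A small usage sketch of the helpers defined above; the dictionary contents are illustrative.

# Illustrative use of the logging helpers above.
from lm_eval.loggers.utils import add_env_info, get_git_commit_hash, remove_none_pattern

metric, removed = remove_none_pattern("acc,none")  # -> ("acc", True)
results_meta = {"git_hash": get_git_commit_hash()}
add_env_info(results_meta)  # adds pretty_env_info, transformers_version, upper_git_hash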
scripts/yans/lm-evaluation-harness/lm_eval/loggers/wandb_logger.py
ADDED
@@ -0,0 +1,352 @@
import copy
import json
import logging
from typing import Any, Dict, List, Literal, Tuple

import numpy as np
import pandas as pd
from packaging.version import Version

from lm_eval.loggers.utils import _handle_non_serializable, remove_none_pattern


logger = logging.getLogger(__name__)


def get_wandb_printer() -> Literal["Printer"]:
    """Returns a wandb printer instance for pretty stdout."""
    from wandb.sdk.lib.printer import get_printer
    from wandb.sdk.wandb_settings import Settings

    printer = get_printer(Settings()._jupyter)
    return printer


class WandbLogger:
    def __init__(self, **kwargs) -> None:
        """Attaches to wandb logger if already initialized. Otherwise, passes kwargs to wandb.init()

        Args:
            kwargs Optional[Any]: Arguments for configuration.

        Parse and log the results returned from evaluator.simple_evaluate() with:
            wandb_logger.post_init(results)
            wandb_logger.log_eval_result()
            wandb_logger.log_eval_samples(results["samples"])
        """
        try:
            import wandb

            assert Version(wandb.__version__) >= Version("0.13.6")
            if Version(wandb.__version__) < Version("0.13.6"):
                wandb.require("report-editing:v0")
        except Exception as e:
            logger.warning(
                "To use the wandb reporting functionality please install wandb>=0.13.6.\n"
                "To install the latest version of wandb run `pip install wandb --upgrade`\n"
                f"{e}"
            )

        self.wandb_args: Dict[str, Any] = kwargs

        # initialize a W&B run
        if wandb.run is None:
            self.run = wandb.init(**self.wandb_args)
        else:
            self.run = wandb.run

        self.printer = get_wandb_printer()

    def post_init(self, results: Dict[str, Any]) -> None:
        self.results: Dict[str, Any] = copy.deepcopy(results)
        self.task_names: List[str] = list(results.get("results", {}).keys())
        self.group_names: List[str] = list(results.get("groups", {}).keys())

    def _get_config(self) -> Dict[str, Any]:
        """Get configuration parameters."""
        self.task_configs = self.results.get("configs", {})
        cli_configs = self.results.get("config", {})
        configs = {
            "task_configs": self.task_configs,
            "cli_configs": cli_configs,
        }

        return configs

    def _sanitize_results_dict(self) -> Tuple[Dict[str, str], Dict[str, Any]]:
        """Sanitize the results dictionary."""
        _results = copy.deepcopy(self.results.get("results", dict()))

        # Remove None from the metric string name
        tmp_results = copy.deepcopy(_results)
        for task_name in self.task_names:
            task_result = tmp_results.get(task_name, dict())
            for metric_name, metric_value in task_result.items():
                _metric_name, removed = remove_none_pattern(metric_name)
                if removed:
                    _results[task_name][_metric_name] = metric_value
                    _results[task_name].pop(metric_name)

        # remove string valued keys from the results dict
        wandb_summary = {}
        for task in self.task_names:
            task_result = _results.get(task, dict())
            for metric_name, metric_value in task_result.items():
                if isinstance(metric_value, str):
                    wandb_summary[f"{task}/{metric_name}"] = metric_value

        for summary_metric, summary_value in wandb_summary.items():
            _task, _summary_metric = summary_metric.split("/")
            _results[_task].pop(_summary_metric)

        tmp_results = copy.deepcopy(_results)
        for task_name, task_results in tmp_results.items():
            for metric_name, metric_value in task_results.items():
                _results[f"{task_name}/{metric_name}"] = metric_value
                _results[task_name].pop(metric_name)
        for task in self.task_names:
            _results.pop(task)

        return wandb_summary, _results

    def _log_results_as_table(self) -> None:
        """Generate and log evaluation results as a table to W&B."""
        columns = [
            "Version",
            "Filter",
            "num_fewshot",
            "Metric",
            "Value",
            "Stderr",
        ]

        def make_table(columns: List[str], key: str = "results"):
            import wandb

            table = wandb.Table(columns=columns)
            results = copy.deepcopy(self.results)

            for k, dic in results.get(key).items():
                if k in self.group_names and not key == "groups":
                    continue
                version = results.get("versions").get(k)
                if version == "N/A":
                    version = None
                n = results.get("n-shot").get(k)

                for (mf), v in dic.items():
                    m, _, f = mf.partition(",")
                    if m.endswith("_stderr"):
                        continue
                    if m == "alias":
                        continue

                    if m + "_stderr" + "," + f in dic:
                        se = dic[m + "_stderr" + "," + f]
                        if se != "N/A":
                            se = "%.4f" % se
                        table.add_data(*[k, version, f, n, m, str(v), str(se)])
                    else:
                        table.add_data(*[k, version, f, n, m, str(v), ""])

            return table

        # log the complete eval result to W&B Table
        table = make_table(["Tasks"] + columns, "results")
        self.run.log({"evaluation/eval_results": table})

        if "groups" in self.results.keys():
            table = make_table(["Groups"] + columns, "groups")
            self.run.log({"evaluation/group_eval_results": table})

    def _log_results_as_artifact(self) -> None:
        """Log results as JSON artifact to W&B."""
        import wandb

        dumped = json.dumps(
            self.results, indent=2, default=_handle_non_serializable, ensure_ascii=False
        )
        artifact = wandb.Artifact("results", type="eval_results")
        with artifact.new_file("results.json", mode="w", encoding="utf-8") as f:
            f.write(dumped)
        self.run.log_artifact(artifact)

    def log_eval_result(self) -> None:
        """Log evaluation results to W&B."""
        # Log configs to wandb
        configs = self._get_config()
        self.run.config.update(configs)

        wandb_summary, self.wandb_results = self._sanitize_results_dict()
        # update wandb.run.summary with items that were removed
        self.run.summary.update(wandb_summary)
        # Log the evaluation metrics to wandb
        self.run.log(self.wandb_results)
        # Log the evaluation metrics as W&B Table
        self._log_results_as_table()
        # Log the results dict as json to W&B Artifacts
        self._log_results_as_artifact()

    def _generate_dataset(
        self, data: List[Dict[str, Any]], config: Dict[str, Any]
    ) -> pd.DataFrame:
        """Generate a dataset from evaluation data.

        Args:
            data (List[Dict[str, Any]]): The data to generate a dataset for.
            config (Dict[str, Any]): The configuration of the task.

        Returns:
            pd.DataFrame: A dataframe that is ready to be uploaded to W&B.
        """
        ids = [x["doc_id"] for x in data]
        labels = [x["target"] for x in data]
        instance = [""] * len(ids)
        resps = [""] * len(ids)
        filtered_resps = [""] * len(ids)
        model_outputs = {}

        metrics_list = config["metric_list"]
        metrics = {}
        for metric in metrics_list:
            metric = metric.get("metric")
            if metric in ["word_perplexity", "byte_perplexity", "bits_per_byte"]:
                metrics[f"{metric}_loglikelihood"] = [x[metric][0] for x in data]
                if metric in ["byte_perplexity", "bits_per_byte"]:
                    metrics[f"{metric}_bytes"] = [x[metric][1] for x in data]
                else:
                    metrics[f"{metric}_words"] = [x[metric][1] for x in data]
            else:
                metrics[metric] = [x[metric] for x in data]

        if config["output_type"] == "loglikelihood":
            instance = [x["arguments"][0][0] for x in data]
            labels = [x["arguments"][0][1] for x in data]
            resps = [
                f'log probability of continuation is {x["resps"][0][0][0]} '
                + "\n\n"
                + "continuation will {} generated with greedy sampling".format(
                    "not be" if not x["resps"][0][0][1] else "be"
                )
                for x in data
            ]
            filtered_resps = [
                f'log probability of continuation is {x["filtered_resps"][0][0]} '
                + "\n\n"
                + "continuation will {} generated with greedy sampling".format(
                    "not be" if not x["filtered_resps"][0][1] else "be"
                )
                for x in data
            ]
        elif config["output_type"] == "multiple_choice":
            instance = [x["arguments"][0][0] for x in data]
            choices = [
                "\n".join([f"{idx}. {y[1]}" for idx, y in enumerate(x["arguments"])])
                for x in data
            ]
            resps = [np.argmax([n[0][0] for n in x["resps"]]) for x in data]
            filtered_resps = [
                np.argmax([n[0] for n in x["filtered_resps"]]) for x in data
            ]
        elif config["output_type"] == "loglikelihood_rolling":
            instance = [x["arguments"][0][0] for x in data]
            resps = [x["resps"][0][0] for x in data]
            filtered_resps = [x["filtered_resps"][0] for x in data]
        elif config["output_type"] == "generate_until":
            instance = [x["arguments"][0][0] for x in data]
            resps = [x["resps"][0][0] for x in data]
            filtered_resps = [x["filtered_resps"][0] for x in data]

        model_outputs["raw_predictions"] = resps
        model_outputs["filtered_predictions"] = filtered_resps

        df_data = {
            "id": ids,
            "data": instance,
        }
        if config["output_type"] == "multiple_choice":
            df_data["choices"] = choices

        tmp_data = {
            "input_len": [len(x) for x in instance],
            "labels": labels,
            "output_type": config["output_type"],
        }
        df_data.update(tmp_data)
        df_data.update(model_outputs)
        df_data.update(metrics)

        return pd.DataFrame(df_data)

    def _log_samples_as_artifact(
        self, data: List[Dict[str, Any]], task_name: str
    ) -> None:
        import wandb

        # log the samples as an artifact
        dumped = json.dumps(
            data,
            indent=2,
            default=_handle_non_serializable,
            ensure_ascii=False,
        )
        artifact = wandb.Artifact(f"{task_name}", type="samples_by_task")
        with artifact.new_file(
            f"{task_name}_eval_samples.json", mode="w", encoding="utf-8"
        ) as f:
            f.write(dumped)
        self.run.log_artifact(artifact)
        # artifact.wait()

    def log_eval_samples(self, samples: Dict[str, List[Dict[str, Any]]]) -> None:
        """Log evaluation samples to W&B.

        Args:
            samples (Dict[str, List[Dict[str, Any]]]): Evaluation samples for each task.
        """
        task_names: List[str] = [
            x for x in self.task_names if x not in self.group_names
        ]

        ungrouped_tasks = []
        tasks_by_groups = {}

        for task_name in task_names:
            group_names = self.task_configs[task_name].get("group", None)
            if group_names:
                if isinstance(group_names, str):
                    group_names = [group_names]

                for group_name in group_names:
                    if not tasks_by_groups.get(group_name):
                        tasks_by_groups[group_name] = [task_name]
                    else:
                        tasks_by_groups[group_name].append(task_name)
            else:
                ungrouped_tasks.append(task_name)

        for task_name in ungrouped_tasks:
            eval_preds = samples[task_name]

            # log the samples as a W&B Table
            df = self._generate_dataset(eval_preds, self.task_configs.get(task_name))
            self.run.log({f"{task_name}_eval_results": df})

            # log the samples as a json file as W&B Artifact
            self._log_samples_as_artifact(eval_preds, task_name)

        for group, grouped_tasks in tasks_by_groups.items():
            grouped_df = pd.DataFrame()
            for task_name in grouped_tasks:
                eval_preds = samples[task_name]
                df = self._generate_dataset(
                    eval_preds, self.task_configs.get(task_name)
                )
                df["group"] = group
                df["task"] = task_name
                grouped_df = pd.concat([grouped_df, df], ignore_index=True)

                # log the samples as a json file as W&B Artifact
                self._log_samples_as_artifact(eval_preds, task_name)

            self.run.log({f"{group}_eval_results": grouped_df})
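For context, a minimal sketch of the call sequence described in the WandbLogger docstring above; the project name, task, and limit are illustrative, and log_samples=True is assumed so that results["samples"] exists.

# Sketch of the documented usage (illustrative project/task/limit values).
import lm_eval
from lm_eval.loggers.wandb_logger import WandbLogger

results = lm_eval.simple_evaluate(
    model="dummy", tasks=["lambada_openai"], limit=4, log_samples=True
)
wandb_logger = WandbLogger(project="lm-eval-harness", job_type="eval")  # kwargs go to wandb.init()
wandb_logger.post_init(results)
wandb_logger.log_eval_result()
wandb_logger.log_eval_samples(results["samples"])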
scripts/yans/lm-evaluation-harness/lm_eval/models/dummy.py
ADDED
@@ -0,0 +1,41 @@
import random

from tqdm import tqdm

from lm_eval.api.model import LM
from lm_eval.api.registry import register_model


@register_model("dummy")
class DummyLM(LM):
    def __init__(self) -> None:
        super().__init__()

    @classmethod
    def create_from_arg_string(cls, arg_string, additional_config=None):
        return cls()

    def loglikelihood(self, requests, disable_tqdm: bool = False):
        res = []

        for _ in tqdm(requests, disable=disable_tqdm):
            res.append((-random.random(), False))

        return res

    def generate_until(self, requests, disable_tqdm: bool = False):
        res = []

        for ctx, _ in tqdm(requests, disable=disable_tqdm):
            res.append("lol")
            assert ctx.strip() != ""

        return res

    def loglikelihood_rolling(self, requests, disable_tqdm: bool = False):
        res = []

        for _ in tqdm(requests, disable=disable_tqdm):
            res.append(-random.random())

        return res
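The dummy model ignores its inputs: loglikelihood() returns random negative values and generate_until() always returns "lol", which makes it useful for smoke-testing task plumbing without loading a real model. A small sketch with illustrative request tuples:

# Direct use of the dummy model defined above (illustrative inputs).
from lm_eval.models.dummy import DummyLM

lm = DummyLM()
print(lm.generate_until([("Say hi:", {"until": ["\n"]})]))  # ["lol"]
print(lm.loglikelihood([("context", " continuation")]))     # e.g. [(-0.42, False)]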
scripts/yans/lm-evaluation-harness/lm_eval/models/gguf.py
ADDED
@@ -0,0 +1,130 @@
import logging
import time

import requests
from requests.exceptions import RequestException
from tqdm import tqdm

from lm_eval.api.model import LM
from lm_eval.api.registry import register_model


logger = logging.getLogger(__name__)


def get_result(logprobs, context_length):
    is_greedy = True
    offsets = logprobs["text_offset"]
    tokens = logprobs["tokens"]
    tokens_logprobs = logprobs["token_logprobs"]

    idx = 0
    while offsets[idx] < context_length:
        idx += 1
    continuation_logprobs = sum(tokens_logprobs[idx:-1])
    for i in range(idx, len(tokens)):
        token = tokens[i]
        top_tokens = logprobs["top_logprobs"][i]
        top_token = max(top_tokens.keys(), key=lambda x: top_tokens[x])
        if top_token != token:
            is_greedy = False
            break

    return continuation_logprobs, is_greedy


@register_model("gguf", "ggml")
class GGUFLM(LM):
    def __init__(self, base_url=None, max_length=2048, **kwargs):
        super().__init__()
        self.base_url = base_url
        assert self.base_url, "must pass `base_url` to use GGUF LM!"
        self.logprobs = 10
        self.temperature = 0.0
        self.max_length = max_length

    def gguf_completion(
        self, context, continuation=None, stop=None, retries=3, delay=5, **kwargs
    ):
        for _ in range(retries):
            try:
                prompt = context
                request = {
                    "prompt": prompt,
                    "logprobs": self.logprobs,
                    "temperature": self.temperature,
                }
                if continuation:
                    prompt += continuation
                    request.update({"prompt": prompt, "max_tokens": 1, "echo": True})
                if stop is not None:
                    request["stop"] = stop
                response = requests.post(
                    f"{self.base_url}/v1/completions", json=request
                )
                response.raise_for_status()
                return response.json()
            except RequestException as e:
                logger.error(f"RequestException: {e}")
                time.sleep(delay)  # wait before retrying
        else:
            raise Exception(f"Failed to get a valid response after {retries} retries.")

    def loglikelihood(self, requests, disable_tqdm: bool = False):
        if not requests:
            return []
        res = []
        for context, continuation in tqdm(
            [req.args for req in requests], disable=disable_tqdm
        ):
            response = self.gguf_completion(context=context, continuation=continuation)
            if response and "choices" in response and response["choices"]:
                choice = response["choices"][0]
                logprobs = choice.get("logprobs")
                if (
                    logprobs
                    and "token_logprobs" in logprobs
                    and logprobs["token_logprobs"]
                ):
                    logprob, is_greedy = get_result(logprobs, len(context))
                    res.append((logprob, is_greedy))
                else:
                    logger.warning(
                        "Invalid logprobs data. Expected 'logprobs' to contain 'token_logprobs' list."
                    )
            else:
                logger.error(
                    f"Invalid response for loglikelihood. Response: {response}"
                )
                assert False
        return res

    def generate_until(self, requests, disable_tqdm: bool = False):
        if not requests:
            return []

        res = []
        for request in tqdm([req.args for req in requests], disable=disable_tqdm):
            inp = request[0]
            request_args = request[1]
            until = request_args.get("until", ["</s>"])
            response = self.gguf_completion(context=inp, stop=until)
            if response and "choices" in response and response["choices"]:
                choice = response["choices"][0]
                if "text" in choice:
                    generated_text = choice["text"].strip()
                    res.append(generated_text)
                else:
                    logger.error(
                        f"Invalid response for greedy_until. Response: {response}"
                    )
                    res.append(None)  # Add default value in case of error
            else:
                logger.error(f"Invalid response for greedy_until. Response: {response}")
                res.append(None)  # Add default value in case of error
        return res

    def loglikelihood_rolling(self, requests, disable_tqdm: bool = False):
        raise NotImplementedError(
            "loglikelihood_rolling not yet supported for GGUF models"
        )
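A sketch of pointing this backend at a local completion server; the URL is illustrative, and the server is assumed to expose the /v1/completions route that gguf_completion() above posts to.

# Sketch: attach the GGUF backend to a local server (illustrative URL).
from lm_eval.models.gguf import GGUFLM

lm = GGUFLM(base_url="http://localhost:8080")
# Each request body sent by gguf_completion() looks roughly like:
# {"prompt": "<context><continuation>", "logprobs": 10, "temperature": 0.0,
#  "max_tokens": 1, "echo": True}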
scripts/yans/lm-evaluation-harness/lm_eval/models/mamba_lm.py
ADDED
@@ -0,0 +1,126 @@
from typing import Optional, Union

import torch

import lm_eval.models.utils
from lm_eval.api.registry import register_model
from lm_eval.models.huggingface import HFLM


@register_model("mamba_ssm")
class MambaLMWrapper(HFLM):
    def __init__(
        self,
        pretrained="state-spaces/mamba-130m",
        **kwargs,
    ) -> None:
        """
        Mamba (via the `mamba_ssm` package) supports the following args:
        ```
        d_model: int,
        n_layer: int,
        vocab_size: int,
        initializer_cfg=None,
        pad_vocab_size_multiple: int = 1,
        ssm_cfg=None,
        norm_epsilon: float = 1e-5,
        rms_norm: bool = False,
        initializer_cfg=None,
        fused_add_norm=False,
        residual_in_fp32=False,
        ```

        See https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py#L175 for more info.
        The above can all be passed via `--model_args` or to this __init__() directly
        but we recommend placing many of these within the config.json file uploaded alongside your
        Mamba model to the HF Hub instead.
        All other HuggingFace from_pretrained() kwargs
        such as those related to
        `parallelize=True`, PEFT, autoGPTQ,
        or any sub-configurations of these advanced args,
        are unsupported by the `mamba_ssm` package.

        The HFLM arguments

        `backend`, `tokenizer`, `truncation`, `max_length`,
        `device`, `dtype`, `batch_size`, `max_batch_size`, `trust_remote_code`, `use_fast_tokenizer`

        Are all supported by Mamba where they do not conflict
        with Mamba-specific restrictions such as causal LMs only.
        """

        if "backend" in kwargs:
            # mamba currently only supports causal models
            assert kwargs["backend"] == "causal"

        super().__init__(
            pretrained=pretrained,
            # set appropriate defaults for tokenizer, max length, etc
            backend=kwargs.pop("backend", "causal"),
            tokenizer=kwargs.pop("tokenizer", "EleutherAI/gpt-neox-20b"),
            max_length=kwargs.pop("max_length", 2048),
            **kwargs,
        )

    def _get_config(
        self,
        pretrained: str,
        **kwargs,
    ) -> None:
        try:
            from mamba_ssm.utils.hf import load_config_hf  # noqa: F811
        except ModuleNotFoundError:
            raise Exception(
                "attempted to use 'mamba_ssm' LM type, but package `mamba_ssm` is not installed. \
please install mamba via `pip install lm-eval[mamba]` or `pip install -e .[mamba]`",
            )

        self._config = load_config_hf(pretrained)

    def _create_model(
        self,
        pretrained: str,
        dtype: Optional[Union[str, torch.dtype]] = "float16",
        # no `parallelize=True` options
        # no PEFT and quantization options
        # Mamba does not support arbitrary HF from_pretrained() args
        **kwargs,
    ) -> None:
        try:
            from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel  # noqa: F811
        except ModuleNotFoundError:
            raise Exception(
                "attempted to use 'mamba_ssm' LM type, but package `mamba_ssm` is not installed. \
please install mamba via `pip install lm-eval[mamba]` or `pip install -e .[mamba]`",
            )

        self._model = MambaLMHeadModel.from_pretrained(
            pretrained,
            device=self._device,
            dtype=torch.float16
            if dtype == "auto"
            else lm_eval.models.utils.get_dtype(dtype),
        )

    def _model_generate(self, context, max_length, stop, **generation_kwargs):
        for key in ("do_sample", "attention_mask"):
            if key in generation_kwargs:
                generation_kwargs.pop(key)

        # mamba's custom GenerationMixin currently does not support
        # passing stopping criteria.
        # for the time being, we simply generate to max length,
        # then truncate (equivalent result)
        # -- this should be revisited to speed up generation
        # stopping_criteria = stop_sequences_criteria(
        #     self.tokenizer, stop, 1, context.shape[0]
        # )

        return self.model.generate(
            input_ids=context,
            max_length=max_length,
            # stopping_criteria=stopping_criteria,
            # pad_token_id=self.tokenizer.pad_token_id,
            # use_cache=True,
            **generation_kwargs,
        )
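A sketch of loading the wrapper directly, roughly equivalent to `--model mamba_ssm --model_args pretrained=...` on the CLI; it assumes the optional `mamba_ssm` dependency is installed (`pip install lm-eval[mamba]`), and the batch size and dtype values are illustrative.

# Sketch: instantiate the wrapper defined above (illustrative kwargs).
from lm_eval.models.mamba_lm import MambaLMWrapper

lm = MambaLMWrapper(pretrained="state-spaces/mamba-130m", batch_size=8, dtype="float16")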
scripts/yans/lm-evaluation-harness/lm_eval/models/neuron_optimum.py
ADDED
@@ -0,0 +1,737 @@
import copy
import json
import logging
import subprocess
from collections import defaultdict
from typing import List, Optional, Union

import torch
import torch.nn.functional as F
import transformers
from packaging import version
from tqdm import tqdm
from transformers import GenerationConfig
from transformers.generation import StoppingCriteriaList

import lm_eval.models.utils
from lm_eval import utils
from lm_eval.api.model import TemplateLM
from lm_eval.api.registry import register_model
from lm_eval.models.utils import stop_sequences_criteria


try:
    NEURON_AVAILABLE = True
    from optimum.neuron import NeuronModelForCausalLM
    from optimum.neuron.generation import TokenSelector
    from optimum.neuron.version import __version__ as optimum_neuron_version
except ImportError:
    NeuronModelForCausalLM = object
    NEURON_AVAILABLE = False


logger = logging.getLogger(__name__)


def get_nc_count() -> Union[int, None]:
    """Returns the number of neuron cores on the current instance."""
    try:
        cmd = "neuron-ls --json-output"
        result = subprocess.run(cmd, shell=True, capture_output=True)
        print(f"inferring nc_count from `neuron-ls` {result.stdout}")
        json_output = json.loads(result.stdout)
        count = sum([x["nc_count"] for x in json_output])
        print(f"nc_count={count}")
        return count
    except Exception:
        return None


def wrap_constant_batch_size(func):
    def _decorator(self, input_ids):
        """input_ids a 2D array with batch_size on dim=0

        makes sure the func runs with self.batch_size
        """
        # access a from TestSample
        batch_size = input_ids.shape[0]

        if batch_size < self.batch_size:
            # handle the event of input_ids.shape[0] != batch_size
            # Neuron cores expect constant batch_size
            input_ids = torch.concat(
                (
                    input_ids,
                    # add missing_batch_size dummy
                    torch.zeros(
                        [self.batch_size - batch_size, *input_ids.size()[1:]],
                        dtype=input_ids.dtype,
                        device=input_ids.device,
                    ),
                ),
                dim=0,
            )
        elif batch_size > self.batch_size:
            raise ValueError(
                f"The specified batch_size ({batch_size}) exceeds the model static batch size ({self.batch_size})"
            )
        # return the forward pass that requires constant batch size
        return func(self, input_ids)[:batch_size]

    return _decorator


class CustomNeuronModelForCausalLM(NeuronModelForCausalLM):
    """NeuronModelForCausalLM with `stopping_criteria` in `generate`"""

    def generate(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        stopping_criteria: Optional["StoppingCriteriaList"] = None,
        generation_config: Optional["GenerationConfig"] = None,
        **kwargs,
    ) -> torch.LongTensor:
        r"""
        A streamlined generate() method overriding the transformers.GenerationMixin.generate() method.

        This method uses the same logits processors/warpers and stopping criteria as the transformers library
        `generate()` method but restricts the generation to greedy search and sampling.

        It does not support transformers `generate()` advanced options.

        Please refer to https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.GenerationMixin.generate
        for details on generation configuration.

        Parameters:
            input_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices.
            generation_config (`~transformers.generation.GenerationConfig`, *optional*):
                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
                passed to generate matching the attributes of `generation_config` will override them. If
                `generation_config` is not provided, default will be used, which had the following loading
                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
                configuration. Please note that unspecified parameters will inherit [`~transformers.generation.GenerationConfig`]'s
                default values, whose documentation should be checked to parameterize generation.

        Returns:
            `torch.Tensor`: A `torch.FloatTensor`.
        """
        # The actual generation configuration is a combination of config and parameters
        generation_config = copy.deepcopy(
            self.generation_config if generation_config is None else generation_config
        )
        model_kwargs = generation_config.update(
            **kwargs
        )  # All unused kwargs must be model kwargs
        # Check model kwargs are actually used by either prepare_inputs_for_generation or forward
        self._validate_model_kwargs(model_kwargs)

        # Instantiate a TokenSelector for the specified configuration
        selector = TokenSelector.create(
            input_ids, generation_config, self, self.max_length
        )
        selector.stopping_criteria.append(stopping_criteria)
        # Verify that the inputs are compatible with the model static input dimensions
        batch_size, sequence_length = input_ids.shape
        if sequence_length > self.max_length:
            raise ValueError(
                f"The input sequence length ({sequence_length}) exceeds the model static sequence length ({self.max_length})"
            )
        padded_input_ids = input_ids
        padded_attention_mask = attention_mask
        if batch_size > self.batch_size:
            raise ValueError(
                f"The specified batch_size ({batch_size}) exceeds the model static batch size ({self.batch_size})"
            )
        elif batch_size < self.batch_size:
            logger.warning(
                "Inputs will be padded to match the model static batch size. This will increase latency."
            )
            padding_shape = [self.batch_size - batch_size, sequence_length]
            padding = torch.full(
                padding_shape, fill_value=self.config.eos_token_id, dtype=torch.int64
            )
            padded_input_ids = torch.cat([input_ids, padding])
            if attention_mask is not None:
                padding = torch.zeros(padding_shape, dtype=torch.int64)
                padded_attention_mask = torch.cat([attention_mask, padding])
        # Drop the current generation context and clear the Key/Value cache
        self.reset_generation()

        output_ids = self.generate_tokens(
            padded_input_ids,
            selector,
            batch_size,
            attention_mask=padded_attention_mask,
            **model_kwargs,
        )
        return output_ids[:batch_size, :]


@register_model("neuronx")
class NEURON_HF(TemplateLM):
    """
    Enables usage with on AWS Neuron
    using the HuggingFace Transformers + Transformers neuronx library.
    Tested with neuron 2.17.0
    """

    _DEFAULT_MAX_LENGTH = 2048

    def __init__(
        self,
        pretrained: Optional[str] = "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        revision: Optional[str] = "main",
        tp_degree: Optional[int] = None,
        subfolder: Optional[str] = None,
        tokenizer: Optional[str] = None,
        truncation: Optional[bool] = False,
        max_length: Optional[int] = None,
        dtype: Optional[Union[str, torch.dtype]] = "auto",
        batch_size: Optional[int] = 1,
        low_cpu_mem_usage: Optional[bool] = True,
        trust_remote_code: Optional[bool] = False,
        use_fast_tokenizer: Optional[bool] = True,
        add_bos_token: Optional[bool] = False,
    ) -> None:
        if not NEURON_AVAILABLE:
            raise Exception(
                "Tried to load neuron model, but neuron is not installed ",
                "please install neuron via pip install transformers-neuron ",
                "also make sure you are running on an AWS inf2 instance",
            )
        if version.parse(optimum_neuron_version) != version.parse("0.0.17"):
            logger.warning(
                '`optimum-neuron` model requires `pip install "optimum[neuronx]>=0.0.17" '
                "preferably using the Hugging Face Neuron Deep Learning AMI (Ubuntu 22.04) "
                "https://aws.amazon.com/marketplace/pp/prodview-gr3e6yiscria2 "
                f"You are using optimum-neuron={optimum_neuron_version}"
            )
        super().__init__()

        assert isinstance(pretrained, str)
        assert isinstance(batch_size, (int, str))

        self.batch_size_per_gpu = int(batch_size)
        batch_size = int(batch_size)
        if tp_degree is None:
            # execute `neuron-ls --json-output | jq '.[0].nc_count'``
            # to get the number of neuron cores on your instance
            tp_degree = get_nc_count()

        assert isinstance(tp_degree, int), (
            f"model_args must include tp_degree. tp_degree must be set to an integer,"
            f" but is tp_degree=`{tp_degree}` with type=`{type(tp_degree)}`."
            "Set it to number of neuron cores on your instance."
            " For inf2.xlarge and inf2.8xlarge, set it to `2`."
            " For inf2.24xlarge, set it to `12`."
            " For inf2.48xlarge, set it to `24`."
        )

        revision = str(revision)  # cast to string if not already one
        # TODO: update this to be less of a hack once subfolder is fixed in HF
        revision = revision + ("/" + subfolder if subfolder is not None else "")

        self._config = transformers.AutoConfig.from_pretrained(
            pretrained,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )
        torch_dtype = lm_eval.models.utils.get_dtype(dtype)

        assert torch_dtype in [
            torch.float16,
            torch.bfloat16,
        ], "Only float16 and bfloat16 are supported"

        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            pretrained if tokenizer is None else tokenizer,
            revision=revision,
            trust_remote_code=trust_remote_code,
            use_fast=use_fast_tokenizer,
        )

        # Neuron specific code
        if torch_dtype == torch.float16:
            self.amp_dtype = "f16"
        elif torch_dtype == torch.bfloat16:
            self.amp_dtype = "bf16"
        elif torch_dtype == torch.float32:
            self.amp_dtype = "f32"
        else:
            raise NotImplementedError("Only float16 and bfloat16 are implemented.")

        compiler_args = {"num_cores": tp_degree, "auto_cast_type": self.amp_dtype}
        input_shapes = {
            "batch_size": batch_size,
            "sequence_length": self._DEFAULT_MAX_LENGTH,
        }

        print(
            f"{'='*20} \n loading model to neuron with"
            f" {compiler_args}, {input_shapes}..."
        )
        self.model = CustomNeuronModelForCausalLM.from_pretrained(
            pretrained,
            revision=revision,
            trust_remote_code=trust_remote_code,
            low_cpu_mem_usage=low_cpu_mem_usage,
            export=True,
            **compiler_args,
            **input_shapes,
        )
        print(f"SUCCESS: neuron model compiled. \n {'='*20}")

        self.truncation = truncation

        self.vocab_size = self.tokenizer.vocab_size
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.add_bos_token = add_bos_token

        self._max_length = max_length

        self.batch_schedule = 1
        self.batch_sizes = {}

    @property
    def config(self):
        # return the associated transformers.AutoConfig for the given pretrained model.
        return self._config

    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def prefix_token_id(self):
        # it is used as prefix for loglikelihood
        return self.tokenizer.bos_token_id or self.tokenizer.eos_token_id

    @property
    def max_length(self):
        if self._max_length:  # if max length manually set, return it
            return self._max_length
        seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
        for attr in seqlen_config_attrs:
            if hasattr(self.model.config, attr):
                return getattr(self.model.config, attr)
        if hasattr(self.tokenizer, "model_max_length"):
            if self.tokenizer.model_max_length == 1000000000000000019884624838656:
                return self._DEFAULT_MAX_LENGTH
            return self.tokenizer.model_max_length
        return self._DEFAULT_MAX_LENGTH

    @property
    def max_gen_toks(self) -> int:
        return 256

    @property
    def batch_size(self):
        return self.batch_size_per_gpu

    @property
    def device(self):
        """device are neuron cores, but the created tensors are on CPU."""
        return "cpu"

    @property
    def rank(self):
        return 0

    @property
    def world_size(self):
        return 1

    def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None):
        """ """
        if add_special_tokens is None:
            add_special_tokens = False or self.add_bos_token

        encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)

        # left-truncate the encoded context to be at most `left_truncate_len` tokens long
        if left_truncate_len:
            encoding = encoding[-left_truncate_len:]

        return encoding

    def tok_batch_encode(
        self,
        strings: List[str],
        padding_side: str = "left",
        left_truncate_len: int = None,
        truncation: bool = False,
    ):
        # encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
        old_padding_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = padding_side

        add_special_tokens = False or self.add_bos_token

        encoding = self.tokenizer(
            strings,
            truncation=truncation,
            padding="longest",
            return_tensors="pt",
            add_special_tokens=add_special_tokens,
        )
        if left_truncate_len:
            encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
            encoding["attention_mask"] = encoding["attention_mask"][
                :, -left_truncate_len:
            ]
        self.tokenizer.padding_side = old_padding_side

        return encoding["input_ids"], encoding["attention_mask"]

    def tok_decode(self, tokens):
        return self.tokenizer.decode(tokens)

    @wrap_constant_batch_size
    def _model_call(self, input_ids: torch.Tensor):
        """
        get logits for the entire sequence

        :param input_ids: torch.Tensor
            A torch tensor of shape [batch, sequence_cont]
            the size of sequence may vary from call to call
        :return
            A torch tensor of shape [batch, sequence, vocab] with the
            logits returned from the model's decoder-lm head
        """
        _, sequence_length = input_ids.shape

        with torch.inference_mode():
            cache_ids = torch.arange(0, sequence_length, dtype=torch.int32).split(1)
            input_ids_split = input_ids.split(1, dim=1)

            return torch.concat(
                [
                    self.model.forward(
                        input_ids=input_id, cache_ids=cache_id, return_dict=False
                    )[0]
                    for input_id, cache_id in zip(input_ids_split, cache_ids)
                ],
                dim=1,
            )

    def _model_generate(self, context, max_length, stop, **generation_kwargs):
        # we require users to pass do_sample=True explicitly
        # for non-greedy gen. This should be reevaluated when considering beam search.

        with torch.inference_mode():
            if "do_sample" not in generation_kwargs.keys():
                generation_kwargs["do_sample"] = False

            stopping_criteria = stop_sequences_criteria(
                self.tokenizer,
                stop + [self.tokenizer.decode([self.config.eos_token_id])],
                1,
context.shape[0],
|
435 |
+
)
|
436 |
+
|
437 |
+
return self.model.generate(
|
438 |
+
input_ids=context,
|
439 |
+
max_length=max_length,
|
440 |
+
stopping_criteria=stopping_criteria,
|
441 |
+
pad_token_id=self.eot_token_id,
|
442 |
+
use_cache=True,
|
443 |
+
**generation_kwargs,
|
444 |
+
)
|
445 |
+
|
446 |
+
def _select_cont_toks(self, logits, contlen=None, inplen=None):
|
447 |
+
assert (
|
448 |
+
contlen and inplen
|
449 |
+
), "Must pass input len and cont. len to select scored logits for causal LM"
|
450 |
+
# discard right-padding.
|
451 |
+
# also discard the input/context tokens. we'll only score continuations.
|
452 |
+
logits = logits[inplen - contlen : inplen]
|
453 |
+
|
454 |
+
return logits
|
455 |
+
|
456 |
+
def loglikelihood_rolling(self, requests, disable_tqdm: bool = False):
|
457 |
+
loglikelihoods = []
|
458 |
+
|
459 |
+
adaptive_batch_size = None
|
460 |
+
|
461 |
+
for (string,) in tqdm(
|
462 |
+
[req.args for req in requests], disable=(disable_tqdm or (self.rank != 0))
|
463 |
+
):
|
464 |
+
rolling_token_windows = list(
|
465 |
+
map(
|
466 |
+
utils.make_disjoint_window,
|
467 |
+
utils.get_rolling_token_windows(
|
468 |
+
token_list=self.tok_encode(string),
|
469 |
+
prefix_token=self.prefix_token_id,
|
470 |
+
max_seq_len=self.max_length,
|
471 |
+
context_len=1,
|
472 |
+
),
|
473 |
+
)
|
474 |
+
)
|
475 |
+
|
476 |
+
# TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
|
477 |
+
rolling_token_windows = [(None,) + x for x in rolling_token_windows]
|
478 |
+
|
479 |
+
pad_amnt = 0
|
480 |
+
if self.world_size > 1:
|
481 |
+
# We pad out the external document-level iterator so the inner iterator doesn't hang
|
482 |
+
mytensor = torch.tensor(len(rolling_token_windows), device=self.device)
|
483 |
+
gathered = (
|
484 |
+
self.accelerator.gather(mytensor).cpu().detach().numpy().tolist()
|
485 |
+
)
|
486 |
+
|
487 |
+
pad_amnt = max(gathered) - gathered[self.rank]
|
488 |
+
if pad_amnt > 0:
|
489 |
+
rolling_token_windows += pad_amnt * [rolling_token_windows[0]]
|
490 |
+
|
491 |
+
string_nll = self._loglikelihood_tokens(
|
492 |
+
rolling_token_windows,
|
493 |
+
disable_tqdm=True,
|
494 |
+
override_bs=adaptive_batch_size,
|
495 |
+
)
|
496 |
+
|
497 |
+
if (self.world_size > 1) and (pad_amnt > 0):
|
498 |
+
string_nll = [x[0] for x in string_nll[:-pad_amnt]]
|
499 |
+
else:
|
500 |
+
# discard is_greedy
|
501 |
+
string_nll = [x[0] for x in string_nll]
|
502 |
+
|
503 |
+
string_nll = sum(string_nll)
|
504 |
+
loglikelihoods.append(string_nll)
|
505 |
+
|
506 |
+
return loglikelihoods
|
507 |
+
|
508 |
+
def _loglikelihood_tokens(
|
509 |
+
self, requests, disable_tqdm: bool = False, override_bs=None
|
510 |
+
):
|
511 |
+
# TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
|
512 |
+
res = []
|
513 |
+
|
514 |
+
def _collate(x):
|
515 |
+
# the negative sign on len(toks) sorts descending - this has a few advantages:
|
516 |
+
# - time estimates will always be over not underestimates, which is more useful for planning
|
517 |
+
# - to know the size of a batch when going through the list, you know the first one is always the batch
|
518 |
+
# padded context length. this is useful to simplify the batching logic and more importantly to make
|
519 |
+
# automatic adaptive batches much much easier to implement
|
520 |
+
# - any OOMs will happen right away rather than near the end
|
521 |
+
|
522 |
+
toks = x[1] + x[2]
|
523 |
+
return -len(toks), tuple(toks)
|
524 |
+
|
525 |
+
re_ord = utils.Reorderer(requests, _collate)
|
526 |
+
|
527 |
+
n_reordered_requests = len(re_ord.get_reordered()) # noqa
|
528 |
+
# automatic (variable) batch size detection for vectorization
|
529 |
+
# pull longest context sample from request
|
530 |
+
|
531 |
+
chunks = lm_eval.models.utils.chunks(
|
532 |
+
re_ord.get_reordered(),
|
533 |
+
n=self.batch_size,
|
534 |
+
fn=None,
|
535 |
+
)
|
536 |
+
|
537 |
+
for chunk in tqdm(chunks, disable=(disable_tqdm or (self.rank != 0))):
|
538 |
+
inps = []
|
539 |
+
cont_toks_list = []
|
540 |
+
inplens = []
|
541 |
+
|
542 |
+
conts = [] # noqa
|
543 |
+
encoder_attns = [] # noqa
|
544 |
+
|
545 |
+
padding_len_inp = None
|
546 |
+
padding_len_cont = None # noqa
|
547 |
+
# because vectorizing is annoying, we first convert each (context, continuation) pair to padded
|
548 |
+
# tensors, then we pack them together into a batch, call the model, and then pick it all apart
|
549 |
+
# again because vectorizing is annoying
|
550 |
+
|
551 |
+
for _, context_enc, continuation_enc in chunk:
|
552 |
+
# sanity check
|
553 |
+
assert len(context_enc) > 0
|
554 |
+
assert len(continuation_enc) > 0
|
555 |
+
assert len(continuation_enc) <= self.max_length
|
556 |
+
|
557 |
+
# how this all works (illustrated on a causal decoder-only setup):
|
558 |
+
# CTX CONT
|
559 |
+
# inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1]
|
560 |
+
# model \ \
|
561 |
+
# logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the
|
562 |
+
# cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice
|
563 |
+
|
564 |
+
# when too long to fit in context, truncate from the left
|
565 |
+
inp = torch.tensor(
|
566 |
+
(context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
|
567 |
+
dtype=torch.long,
|
568 |
+
device=self.device,
|
569 |
+
)
|
570 |
+
(inplen,) = inp.shape
|
571 |
+
|
572 |
+
padding_len_inp = (
|
573 |
+
max(padding_len_inp, inplen)
|
574 |
+
if padding_len_inp is not None
|
575 |
+
else inplen
|
576 |
+
)
|
577 |
+
|
578 |
+
inps.append(inp) # [1, inp_length]
|
579 |
+
cont_toks_list.append(continuation_enc)
|
580 |
+
inplens.append(inplen)
|
581 |
+
|
582 |
+
# create encoder attn mask and batched conts, if seq2seq
|
583 |
+
call_kwargs = {}
|
584 |
+
batched_inps = lm_eval.models.utils.pad_and_concat(
|
585 |
+
padding_len_inp, inps, padding_side="right"
|
586 |
+
) # [batch, padding_len_inp]
|
587 |
+
|
588 |
+
multi_logits = F.log_softmax(
|
589 |
+
self._model_call(batched_inps, **call_kwargs), dim=-1
|
590 |
+
) # [batch, padding_length (inp or cont), vocab]
|
591 |
+
|
592 |
+
for (cache_key, _, _), logits, inplen, cont_toks in zip(
|
593 |
+
chunk, multi_logits, inplens, cont_toks_list
|
594 |
+
):
|
595 |
+
# Slice to original seq length
|
596 |
+
contlen = len(cont_toks)
|
597 |
+
# take only logits in the continuation
|
598 |
+
# (discard context toks if decoder-only ; discard right-padding)
|
599 |
+
# also discards + checks for "virtual tokens" in the causal LM's input window
|
600 |
+
# from prompt/prefix tuning tokens, if applicable
|
601 |
+
ctx_len = inplen + (logits.shape[0] - padding_len_inp)
|
602 |
+
logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
|
603 |
+
logits = logits.unsqueeze(0) # [1, seq, vocab]
|
604 |
+
|
605 |
+
# Check if per-token argmax is exactly equal to continuation
|
606 |
+
greedy_tokens = logits.argmax(dim=-1)
|
607 |
+
cont_toks = torch.tensor(
|
608 |
+
cont_toks, dtype=torch.long, device=self.device
|
609 |
+
).unsqueeze(0) # [1, seq]
|
610 |
+
max_equal = (greedy_tokens == cont_toks).all()
|
611 |
+
|
612 |
+
# Obtain log-probs at the corresponding continuation token indices
|
613 |
+
# last_token_slice = logits[:, -1, :].squeeze(0).tolist()
|
614 |
+
logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
|
615 |
+
-1
|
616 |
+
) # [1, seq]
|
617 |
+
|
618 |
+
# Answer: (log prob, is-exact-match)
|
619 |
+
answer = (float(logits.sum()), bool(max_equal))
|
620 |
+
|
621 |
+
res.append(answer)
|
622 |
+
|
623 |
+
self.cache_hook.add_partial("loglikelihood", cache_key, answer)
|
624 |
+
|
625 |
+
return re_ord.get_original(res)
|
626 |
+
|
627 |
+
def generate_until(self, requests, disable_tqdm: bool = False):
|
628 |
+
res = defaultdict(list)
|
629 |
+
re_ords = {}
|
630 |
+
|
631 |
+
def _collate(x):
|
632 |
+
# the negative sign on len(toks) sorts descending - this has a few advantages:
|
633 |
+
# - time estimates will always be over not underestimates, which is more useful for planning
|
634 |
+
# - to know the size of a batch when going through the list, you know the first one is always the batch
|
635 |
+
# padded context length. this is useful to simplify the batching logic and more importantly to make
|
636 |
+
# automatic adaptive batches much much easier to implement
|
637 |
+
# - any OOMs will happen right away rather than near the end
|
638 |
+
toks = self.tok_encode(x[0])
|
639 |
+
return -len(toks), x[0]
|
640 |
+
|
641 |
+
# we group requests by their generation_kwargs,
|
642 |
+
# so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
|
643 |
+
# in the same batch.
|
644 |
+
grouper = lm_eval.models.utils.Grouper(requests, lambda x: str(x.args[1]))
|
645 |
+
for key, reqs in grouper.get_grouped().items():
|
646 |
+
# within each set of reqs for given kwargs, we reorder by token length, descending.
|
647 |
+
re_ords[key] = utils.Reorderer([req.args for req in reqs], _collate)
|
648 |
+
|
649 |
+
pbar = tqdm(total=len(requests), disable=(disable_tqdm or (self.rank != 0)))
|
650 |
+
|
651 |
+
# for each different set of kwargs, we execute all requests, by batch.
|
652 |
+
for key, re_ord in re_ords.items():
|
653 |
+
chunks = lm_eval.models.utils.chunks(
|
654 |
+
re_ord.get_reordered(), n=self.batch_size
|
655 |
+
)
|
656 |
+
for chunk in tqdm(chunks, disable=self.rank != 0):
|
657 |
+
contexts, all_gen_kwargs = zip(*chunk)
|
658 |
+
# we assume all gen kwargs in the batch are the same
|
659 |
+
# this is safe to assume because the `grouper` object ensures it.
|
660 |
+
gen_kwargs = all_gen_kwargs[0]
|
661 |
+
# unpack our keyword arguments.
|
662 |
+
until = None
|
663 |
+
if isinstance(gen_kwargs, dict):
|
664 |
+
kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
|
665 |
+
if "until" in kwargs.keys():
|
666 |
+
until = kwargs.pop("until")
|
667 |
+
if isinstance(until, str):
|
668 |
+
until = [until]
|
669 |
+
elif not isinstance(until, list):
|
670 |
+
raise ValueError(
|
671 |
+
f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
|
672 |
+
)
|
673 |
+
else:
|
674 |
+
raise ValueError(
|
675 |
+
f"Expected `kwargs` to be of type `dict` but got {kwargs}"
|
676 |
+
)
|
677 |
+
# add EOS token to stop sequences
|
678 |
+
eos = self.tok_decode(self.eot_token_id)
|
679 |
+
if not until:
|
680 |
+
until = [eos]
|
681 |
+
else:
|
682 |
+
until.append(eos)
|
683 |
+
if "max_gen_toks" in kwargs.keys():
|
684 |
+
max_gen_toks = kwargs.pop("max_gen_toks")
|
685 |
+
else:
|
686 |
+
max_gen_toks = self.max_gen_toks
|
687 |
+
# first stop sequence is used to halt generation upon encountering
|
688 |
+
primary_until = [until[0]]
|
689 |
+
|
690 |
+
max_ctx_len = self.max_length - max_gen_toks
|
691 |
+
|
692 |
+
# encode, pad, and truncate contexts for this batch
|
693 |
+
context_enc, attn_masks = self.tok_batch_encode(
|
694 |
+
contexts,
|
695 |
+
left_truncate_len=max_ctx_len,
|
696 |
+
truncation=self.truncation,
|
697 |
+
)
|
698 |
+
context_enc = context_enc.to(self.device)
|
699 |
+
attn_masks = attn_masks.to(self.device)
|
700 |
+
|
701 |
+
if "max_length" not in kwargs:
|
702 |
+
kwargs["max_length"] = context_enc.shape[1] + max_gen_toks
|
703 |
+
|
704 |
+
# perform batched generation
|
705 |
+
cont = self._model_generate(
|
706 |
+
context=context_enc,
|
707 |
+
attention_mask=attn_masks,
|
708 |
+
stop=primary_until,
|
709 |
+
**kwargs,
|
710 |
+
)
|
711 |
+
|
712 |
+
cont_toks_list = cont.tolist()
|
713 |
+
for cont_toks, context in zip(cont_toks_list, contexts):
|
714 |
+
# discard context + left-padding toks if using causal decoder-only LM
|
715 |
+
cont_toks = cont_toks[context_enc.shape[1] :]
|
716 |
+
|
717 |
+
s = self.tok_decode(cont_toks)
|
718 |
+
|
719 |
+
# use secondary stop seqs to cut off should-have-been-stopped content post-hoc
|
720 |
+
for term in until:
|
721 |
+
if len(term) > 0:
|
722 |
+
# ignore '' separator,
|
723 |
+
# for seq2seq case where self.tok_decode(self.eot_token_id) = ''
|
724 |
+
s = s.split(term)[0]
|
725 |
+
|
726 |
+
res[key].append(s)
|
727 |
+
|
728 |
+
self.cache_hook.add_partial(
|
729 |
+
"generate_until", (context, gen_kwargs), s
|
730 |
+
)
|
731 |
+
pbar.update(1)
|
732 |
+
# reorder this group of results back to original unsorted form
|
733 |
+
res[key] = re_ord.get_original(res[key])
|
734 |
+
|
735 |
+
pbar.close()
|
736 |
+
|
737 |
+
return grouper.get_original(res)
|
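The `_model_call` above scores a whole sequence on Neuron by feeding the compiled decoder one token at a time with explicit `cache_ids`, then concatenating the per-step logits. A minimal, self-contained sketch of that accumulation pattern follows; the `dummy_forward` stand-in and all names in it are hypothetical and not part of this diff.

# Illustrative sketch only: per-token logit accumulation, as in _model_call above.
import torch

VOCAB = 8

def dummy_forward(input_ids, cache_ids, return_dict=False):
    # stands in for the compiled model's forward; returns logits of shape [batch, 1, vocab]
    return (torch.zeros(input_ids.shape[0], 1, VOCAB),)

def score_sequence(input_ids: torch.Tensor) -> torch.Tensor:
    _, seq_len = input_ids.shape
    cache_ids = torch.arange(0, seq_len, dtype=torch.int32).split(1)
    steps = input_ids.split(1, dim=1)
    # one forward pass per position, concatenated along the sequence axis
    return torch.concat(
        [dummy_forward(tok, cid)[0] for tok, cid in zip(steps, cache_ids)],
        dim=1,
    )

print(score_sequence(torch.ones(2, 5, dtype=torch.long)).shape)  # torch.Size([2, 5, 8])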
scripts/yans/lm-evaluation-harness/lm_eval/models/vllm_causallms.py
ADDED
@@ -0,0 +1,540 @@
import copy
from importlib.metadata import version
from importlib.util import find_spec
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union

from more_itertools import distribute
from packaging.version import parse as parse_version
from tqdm import tqdm

from lm_eval.api.instance import Instance
from lm_eval.api.model import TemplateLM
from lm_eval.api.registry import register_model
from lm_eval.models.utils import Collator, configure_pad_token, undistribute
from lm_eval.utils import (
    eval_logger,
    get_rolling_token_windows,
    make_disjoint_window,
)


try:
    import ray
    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest
    from vllm.transformers_utils.tokenizer import get_tokenizer
except ModuleNotFoundError:
    pass

if TYPE_CHECKING:
    pass

eval_logger = eval_logger


@register_model("vllm")
class VLLM(TemplateLM):
    _DEFAULT_MAX_LENGTH = 2048

    def __init__(
        self,
        pretrained: str,
        dtype: Literal["float16", "bfloat16", "float32", "auto"] = "auto",
        revision: Optional[str] = None,
        trust_remote_code: Optional[bool] = False,
        tokenizer: Optional[str] = None,
        tokenizer_mode: Literal["auto", "slow"] = "auto",
        tokenizer_revision: Optional[str] = None,
        add_bos_token: Optional[bool] = False,
        prefix_token_id: Optional[int] = None,
        tensor_parallel_size: int = 1,
        quantization: Optional[str] = None,
        max_gen_toks: int = 256,
        swap_space: int = 4,
        batch_size: Union[str, int] = 1,
        max_batch_size=None,
        max_length: int = None,
        max_model_len: int = None,
        seed: int = 1234,
        gpu_memory_utilization: float = 0.9,
        device: str = "cuda",
        data_parallel_size: int = 1,
        lora_local_path: str = None,
        **kwargs,
    ):
        super().__init__()

        if not find_spec("vllm"):
            raise Exception(
                "attempted to use 'vllm' LM type, but package `vllm` is not installed. "
                "Please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
            )

        assert "cuda" in device or device is None, "vLLM only supports CUDA"
        assert (
            max_length is None or max_model_len is None
        ), "Either max_length or max_model_len may be provided, but not both"

        self._max_length = max_model_len if max_model_len is not None else max_length
        self.tensor_parallel_size = int(tensor_parallel_size)
        self.data_parallel_size = int(data_parallel_size)
        self.model_args = {
            "model": pretrained,
            "gpu_memory_utilization": float(gpu_memory_utilization),
            "revision": revision,
            "dtype": dtype,
            "tokenizer": tokenizer,
            "tokenizer_mode": tokenizer_mode,
            "tokenizer_revision": tokenizer_revision,
            "trust_remote_code": trust_remote_code,
            "tensor_parallel_size": int(tensor_parallel_size),
            "max_model_len": int(self._max_length) if self._max_length else None,
            "swap_space": int(swap_space),
            "quantization": quantization,
            "seed": int(seed),
        }
        self.model_args.update(kwargs)
        self.batch_size = (
            "auto"
            if isinstance(batch_size, str) and "auto" in batch_size
            else batch_size
        )
        if self.data_parallel_size <= 1:
            self.model = LLM(**self.model_args)
        else:
            eval_logger.warning(
                "You might experience occasional issues with model weight downloading when data_parallel is in use. To ensure stable performance, run with data_parallel_size=1 until the weights are downloaded and cached."
            )
            self.model_args["worker_use_ray"] = True
            self.batch_size = "auto"
            eval_logger.info("Manual batching is not compatible with data parallelism.")

        from transformers import AutoConfig

        self._config = AutoConfig.from_pretrained(
            pretrained, trust_remote_code=trust_remote_code, revision=revision
        )
        self.tokenizer = get_tokenizer(
            tokenizer if tokenizer else pretrained,
            tokenizer_mode=tokenizer_mode,
            trust_remote_code=trust_remote_code,
            tokenizer_revision=tokenizer_revision,
        )
        self.tokenizer = configure_pad_token(self.tokenizer)
        self.add_bos_token = add_bos_token
        if "gemma" in pretrained.lower():
            self.add_bos_token = True
            eval_logger.info(
                "Found 'gemma' in model name, a BOS token will be used as Gemma series models underperform without it."
            )

        self.custom_prefix_token_id = prefix_token_id
        if prefix_token_id is not None:
            eval_logger.info(
                f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}"
            )

        self._max_gen_toks = max_gen_toks

        if lora_local_path is not None:
            assert parse_version(version("vllm")) > parse_version(
                "0.3.0"
            ), "lora adapters only compatible with vllm > v0.3.0."
            self.lora_request = LoRARequest("finetuned", 1, lora_local_path)
        else:
            self.lora_request = None

    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def prefix_token_id(self):
        # it is used as prefix for loglikelihood
        if self.custom_prefix_token_id is not None:
            return self.custom_prefix_token_id
        if self.tokenizer.bos_token_id is not None:
            return self.tokenizer.bos_token_id
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        if self._max_length:  # if max length manually set, return it
            return self._max_length
        if self.data_parallel_size <= 1:
            return self.model.llm_engine.model_config.max_model_len
        else:
            seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
            for attr in seqlen_config_attrs:
                if hasattr(self._config, attr):
                    return getattr(self._config, attr)
            if hasattr(self.tokenizer, "model_max_length"):
                if self.tokenizer.model_max_length == 1000000000000000019884624838656:
                    return self._DEFAULT_MAX_LENGTH
                return self.tokenizer.model_max_length
            return self._DEFAULT_MAX_LENGTH

    @property
    def max_gen_toks(self):
        return self._max_gen_toks

    def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> str:
        """
        Method to apply a chat template to a list of chat history between user and model.
        """
        return self.tokenizer.apply_chat_template(
            chat_history, tokenize=False, add_generation_prompt=True
        )

    @property
    def chat_template(self) -> str:
        if self.tokenizer.chat_template is not None:
            return self.tokenizer.chat_template
        return self.tokenizer.default_chat_template

    @property
    def tokenizer_name(self) -> str:
        return self.tokenizer.name_or_path.replace("/", "__")

    def tok_encode(
        self,
        string: Union[str, List[str]],
        left_truncate_len: int = None,
        add_special_tokens: bool = False,
        truncation: bool = False,
    ) -> Union[List[int], List[List[int]]]:
        if not add_special_tokens:
            add_special_tokens = False or self.add_bos_token
        encoding: Union[List[List[int]], List[int]] = self.tokenizer(
            string,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            return_attention_mask=False,
        ).input_ids

        # left-truncate the encoded context to be at most `left_truncate_len` tokens long
        if left_truncate_len:
            if not isinstance(string, str):
                encoding = [enc[-left_truncate_len:] for enc in encoding]
            else:
                encoding = encoding[-left_truncate_len:]

        return encoding

    def _model_generate(
        self,
        requests: List[List[int]] = None,
        generate: bool = False,
        max_tokens: int = None,
        stop: Optional[List[str]] = None,
        **kwargs,
    ):
        if generate:
            kwargs = self.modify_gen_kwargs(kwargs)
            sampling_params = SamplingParams(max_tokens=max_tokens, stop=stop, **kwargs)
        else:
            sampling_params = SamplingParams(
                temperature=0, prompt_logprobs=1, max_tokens=1, detokenize=False
            )
        if self.data_parallel_size > 1:
            # vLLM hangs if tensor_parallel > 1 and resources are set in ray.remote
            # also seems to only work with decorator and not with ray.remote() fn
            # see https://github.com/vllm-project/vllm/issues/973
            # note: this has changed on 0.3.3, and it only works now if num_gpus are set.
            # but then tensor_parallel breaks
            @ray.remote
            def run_inference_one_model(
                model_args: dict, sampling_params, requests: List[List[int]]
            ):
                llm = LLM(**model_args)
                return llm.generate(
                    prompt_token_ids=requests, sampling_params=sampling_params
                )

            # dispatch requests to all self.data_parallel_size workers, in interleaved fashion
            # interleaved important to balance context lengths across workers
            requests = [list(x) for x in distribute(self.data_parallel_size, requests)]
            inputs = ((self.model_args, sampling_params, req) for req in requests)
            object_refs = [run_inference_one_model.remote(*x) for x in inputs]
            results = ray.get(object_refs)
            # Invoke ray.shutdown() to prevent hang-ups if subsequent calls required.
            ray.shutdown()
            # flatten results
            return undistribute(results)

        if self.lora_request is not None:
            outputs = self.model.generate(
                prompt_token_ids=requests,
                sampling_params=sampling_params,
                use_tqdm=True if self.batch_size == "auto" else False,
                lora_request=self.lora_request,
            )
        else:
            outputs = self.model.generate(
                prompt_token_ids=requests,
                sampling_params=sampling_params,
                use_tqdm=True if self.batch_size == "auto" else False,
            )
        return outputs

    def loglikelihood_rolling(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[float]:
        loglikelihoods = []

        for (string,) in tqdm([req.args for req in requests], disable=disable_tqdm):
            rolling_token_windows = list(
                map(
                    make_disjoint_window,
                    get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.eot_token_id,
                        max_seq_len=self.max_length - 1,
                        context_len=1,
                    ),
                )
            )

            rolling_token_windows = [(None,) + x for x in rolling_token_windows]

            string_nll = self._loglikelihood_tokens(
                rolling_token_windows,
            )

            # discard is_greedy
            string_nll = [x[0] for x in string_nll]

            string_nll = sum(string_nll)
            loglikelihoods.append(string_nll)
        return loglikelihoods

    def generate_until(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[str]:
        res = []

        # batch tokenize contexts
        context, all_gen_kwargs = zip(*(req.args for req in requests))
        context_encoding: List[List[int]] = self.tok_encode(
            context, add_special_tokens=self.add_bos_token
        )
        requests = [
            ((a, b), c) for a, b, c in zip(context, context_encoding, all_gen_kwargs)
        ]

        def _collate_gen(_requests):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            return -len(_requests[0][1]), _requests[0][0]

        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        re_ords = Collator(requests, _collate_gen, group_by="gen_kwargs")
        chunks = re_ords.get_batched(
            n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
        )

        pbar = tqdm(
            total=len(requests),
            disable=(disable_tqdm or (self.rank != 0)),
            desc="Running generate_until requests",
        )
        # for each different set of kwargs, we execute all requests, by batch.
        for chunk in chunks:
            context_and_encoding, all_gen_kwargs = zip(*chunk)
            context, context_encoding = zip(*context_and_encoding)
            # we assume all gen kwargs in the batch are the same
            # this is safe to assume because the `grouper` object ensures it.
            gen_kwargs = all_gen_kwargs[0]
            # unpack our keyword arguments.
            until = None
            if isinstance(gen_kwargs, dict):
                kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
                if "until" in kwargs.keys():
                    until = kwargs.pop("until")
                    if isinstance(until, str):
                        until = [until]
                    elif not isinstance(until, list):
                        raise ValueError(
                            f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
                        )
            else:
                raise ValueError(
                    f"Expected `kwargs` to be of type `dict` but got {gen_kwargs}"
                )
            # add EOS token to stop sequences
            eos = self.tokenizer.decode(self.eot_token_id)
            if not until:
                until = [eos]
            else:
                until.append(eos)
            if "max_gen_toks" in kwargs.keys():
                max_gen_toks = kwargs.pop("max_gen_toks")
            else:
                max_gen_toks = self.max_gen_toks

            # set the max length in tokens of inputs ("context_enc")
            # max len for inputs = max length, minus room to generate the max new tokens
            max_ctx_len = self.max_length - max_gen_toks
            context_encoding = [x[-max_ctx_len:] for x in context_encoding]

            # perform batched generation
            cont = self._model_generate(
                requests=context_encoding,
                generate=True,
                max_tokens=max_gen_toks,
                stop=until,
                **kwargs,
            )

            # cache generations
            for output, context in zip(cont, context):
                generated_text = output.outputs[0].text
                res.append(generated_text)
                self.cache_hook.add_partial(
                    "generate_until", (context, gen_kwargs), generated_text
                )
                pbar.update(1)

        pbar.close()
        # reorder all group of results back to original unsorted form
        return re_ords.get_original(res)

    def _loglikelihood_tokens(
        self,
        requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
        disable_tqdm: bool = False,
    ) -> List[Tuple[float, bool]]:
        res = []

        def _collate(x):
            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        # Reorder requests by length and batch
        re_ord = Collator(requests, sort_fn=_collate)
        chunks = re_ord.get_batched(
            n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
        )

        pbar = tqdm(
            total=len(requests),
            disable=disable_tqdm,
            desc="Running loglikelihood requests",
        )
        for chunk in chunks:
            inputs = []
            ctxlens = []
            for cache_key, context_enc, continuation_enc in chunk:
                inp = (context_enc + continuation_enc)[-(self.max_length) :]
                ctxlen = len(context_enc) - max(
                    0, len(context_enc) + len(continuation_enc) - (self.max_length)
                )

                inputs.append(inp)
                ctxlens.append(ctxlen)

            outputs = self._model_generate(requests=inputs, generate=False)

            for output, ctxlen, (cache_key, _, _), inp in zip(
                outputs, ctxlens, chunk, inputs
            ):
                answer = self._parse_logprobs(
                    tokens=inp,
                    outputs=output,
                    ctxlen=ctxlen,
                )

                res.append(answer)

                # partial caching
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)
                pbar.update(1)
        pbar.close()
        return re_ord.get_original(res)

    @staticmethod
    def _parse_logprobs(tokens: List, outputs, ctxlen: int) -> Tuple[float, bool]:
        """Process logprobs and tokens.

        :param tokens: list
            Input tokens (potentially left-truncated)
        :param outputs: RequestOutput
            Contains prompt_logprobs
        :param ctxlen: int
            Length of context (so we can slice them away and only keep the predictions)
        :return:
            continuation_logprobs: float
                Log probabilities of continuation tokens
            is_greedy: bool
                Whether argmax matches given continuation exactly
        """

        # The first entry of prompt_logprobs is None because the model has no previous tokens to condition on.
        continuation_logprobs_dicts = outputs.prompt_logprobs

        def coerce_logprob_to_num(logprob):
            # vLLM changed the return type of logprobs from float
            # to a Logprob object storing the float value + extra data
            # (https://github.com/vllm-project/vllm/pull/3065).
            # If we are dealing with vllm's Logprob object, return
            # the logprob value stored as an attribute. Otherwise,
            # return the object itself (which should be a float
            # for older versions of vLLM).
            return getattr(logprob, "logprob", logprob)

        continuation_logprobs_dicts = [
            {
                token: coerce_logprob_to_num(logprob)
                for token, logprob in logprob_dict.items()
            }
            if logprob_dict is not None
            else None
            for logprob_dict in continuation_logprobs_dicts
        ]

        # Calculate continuation_logprobs
        # assume ctxlen always >= 1
        continuation_logprobs = sum(
            logprob_dict.get(token)
            for token, logprob_dict in zip(
                tokens[ctxlen:], continuation_logprobs_dicts[ctxlen:]
            )
        )

        # Determine if is_greedy
        is_greedy = True
        for token, logprob_dict in zip(
            tokens[ctxlen:], continuation_logprobs_dicts[ctxlen:]
        ):
            # Get the token with the maximum log probability from the logprob_dict
            if logprob_dict:  # Ensure the logprob_dict is not None
                top_token = max(logprob_dict, key=logprob_dict.get)
                if top_token != token:
                    is_greedy = False
                    break

        return continuation_logprobs, is_greedy

    @staticmethod
    def modify_gen_kwargs(kwargs: dict) -> dict:
        # sampling_params
        do_sample = kwargs.pop("do_sample", None)
        if do_sample is False and "temperature" not in kwargs:
            eval_logger.debug(
                "Got `do_sample=False` and no temperature value, setting VLLM temperature to 0.0 ..."
            )
            kwargs["temperature"] = 0.0
        # hf defaults
        kwargs["skip_special_tokens"] = kwargs.get("skip_special_tokens", False)
        kwargs["spaces_between_special_tokens"] = kwargs.get(
            "spaces_between_special_tokens", False
        )
        return kwargs
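A hedged usage sketch of the `vllm` model type registered above, driven through the harness' Python entry point. The exact `simple_evaluate` signature, the checkpoint name, and the task name are assumptions for illustration only, not taken from this diff.

# minimal sketch, assuming lm_eval exposes simple_evaluate at package level
import lm_eval

results = lm_eval.simple_evaluate(
    model="vllm",  # resolved via @register_model("vllm") above
    model_args="pretrained=EleutherAI/pythia-160m,dtype=auto,gpu_memory_utilization=0.8",
    tasks=["lambada_openai"],  # hypothetical task choice
    batch_size="auto",
)
print(results["results"])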
scripts/yans/lm-evaluation-harness/lm_eval/tasks/code_x_glue/code-text/bleu.py
ADDED
@@ -0,0 +1,241 @@
#!/usr/bin/python
import math
import re
import sys
import xml.sax.saxutils
from typing import Any, Dict, List, Optional, Pattern, Tuple, Union


"""
This script was adapted from the original version by hieuhoang1972 which is part of MOSES.
"""

# $Id: bleu.py 1307 2007-03-14 22:22:36Z hieuhoang1972 $

"""Provides:

cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test().
cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked().
score_cooked(alltest, n=4): Score a list of cooked test sentences.

score_set(s, testid, refids, n=4): Interface with dataset.py; calculate BLEU score of testid against refids.

The reason for breaking the BLEU computation into three phases cook_refs(), cook_test(), and score_cooked() is to allow the caller to calculate BLEU scores for multiple test sets as efficiently as possible.
"""

# Added to bypass NIST-style pre-processing of hyp and ref files -- wade
nonorm = 0

preserve_case = False
eff_ref_len = "shortest"

normalize1: List[Tuple[Union[Pattern[str], str], str]] = [
    ("<skipped>", ""),  # strip "skipped" tags
    (r"-\n", ""),  # strip end-of-line hyphenation and join lines
    (r"\n", " "),  # join lines
    # (r'(\d)\s+(?=\d)', r'\1'), # join digits
]
normalize1 = [(re.compile(pattern), replace) for (pattern, replace) in normalize1]

normalize2: List[Tuple[Union[Pattern[str], str], str]] = [
    (
        r"([\{-\~\[-\` -\&\(-\+\:-\@\/])",
        r" \1 ",
    ),  # tokenize punctuation. apostrophe is missing
    (
        r"([^0-9])([\.,])",
        r"\1 \2 ",
    ),  # tokenize period and comma unless preceded by a digit
    (
        r"([\.,])([^0-9])",
        r" \1 \2",
    ),  # tokenize period and comma unless followed by a digit
    (r"([0-9])(-)", r"\1 \2 "),  # tokenize dash when preceded by a digit
]
normalize2 = [(re.compile(pattern), replace) for (pattern, replace) in normalize2]


def normalize(s):
    """Normalize and tokenize text. This is lifted from NIST mteval-v11a.pl."""
    # Added to bypass NIST-style pre-processing of hyp and ref files -- wade
    if nonorm:
        return s.split()
    if not isinstance(s, str):
        s = " ".join(s)
    # language-independent part:
    for pattern, replace in normalize1:
        s = re.sub(pattern, replace, s)
    s = xml.sax.saxutils.unescape(s, {"&quot;": '"'})
    # language-dependent part (assuming Western languages):
    s = " %s " % s
    if not preserve_case:
        s = s.lower()  # this might not be identical to the original
    for pattern, replace in normalize2:
        s = re.sub(pattern, replace, s)
    return s.split()


def count_ngrams(words, n=4):
    counts: Dict[Any, int] = {}
    for k in range(1, n + 1):
        for i in range(len(words) - k + 1):
            ngram = tuple(words[i : i + k])
            counts[ngram] = counts.get(ngram, 0) + 1
    return counts


def cook_refs(refs, n=4):
    """Takes a list of reference sentences for a single segment
    and returns an object that encapsulates everything that BLEU
    needs to know about them."""

    refs = [normalize(ref) for ref in refs]
    maxcounts: Dict[Tuple[str], int] = {}
    for ref in refs:
        counts = count_ngrams(ref, n)
        for ngram, count in counts.items():
            maxcounts[ngram] = max(maxcounts.get(ngram, 0), count)
    return ([len(ref) for ref in refs], maxcounts)


def cook_test(test, item, n=4):
    """Takes a test sentence and returns an object that
    encapsulates everything that BLEU needs to know about it."""
    (reflens, refmaxcounts) = item
    test = normalize(test)
    result: Dict[str, Any] = {}
    result["testlen"] = len(test)

    # Calculate effective reference sentence length.

    if eff_ref_len == "shortest":
        result["reflen"] = min(reflens)
    elif eff_ref_len == "average":
        result["reflen"] = float(sum(reflens)) / len(reflens)
    elif eff_ref_len == "closest":
        min_diff: Optional[int] = None
        for reflen in reflens:
            if min_diff is None or abs(reflen - len(test)) < min_diff:
                min_diff = abs(reflen - len(test))
                result["reflen"] = reflen

    result["guess"] = [max(len(test) - k + 1, 0) for k in range(1, n + 1)]

    result["correct"] = [0] * n
    counts = count_ngrams(test, n)
    for ngram, count in counts.items():
        result["correct"][len(ngram) - 1] += min(refmaxcounts.get(ngram, 0), count)

    return result


def score_cooked(allcomps, n=4, ground=0, smooth=1):
    totalcomps: Dict[str, Any] = {
        "testlen": 0,
        "reflen": 0,
        "guess": [0] * n,
        "correct": [0] * n,
    }
    for comps in allcomps:
        for key in ["testlen", "reflen"]:
            totalcomps[key] += comps[key]
        for key in ["guess", "correct"]:
            for k in range(n):
                totalcomps[key][k] += comps[key][k]
    logbleu = 0.0
    all_bleus: List[float] = []
    for k in range(n):
        correct = totalcomps["correct"][k]
        guess = totalcomps["guess"][k]
        addsmooth = 0
        if smooth == 1 and k > 0:
            addsmooth = 1
        logbleu += math.log(correct + addsmooth + sys.float_info.min) - math.log(
            guess + addsmooth + sys.float_info.min
        )
        if guess == 0:
            all_bleus.append(-10000000.0)
        else:
            all_bleus.append(math.log(correct + sys.float_info.min) - math.log(guess))

    logbleu /= float(n)
    all_bleus.insert(0, logbleu)

    brevPenalty = min(
        0, 1 - float(totalcomps["reflen"] + 1) / (totalcomps["testlen"] + 1)
    )
    for i in range(len(all_bleus)):
        if i == 0:
            all_bleus[i] += brevPenalty
        all_bleus[i] = math.exp(all_bleus[i])
    return all_bleus


def bleu(refs, candidate, ground=0, smooth=1):
    refs = cook_refs(refs)
    test = cook_test(candidate, refs)
    return score_cooked([test], ground=ground, smooth=smooth)


def splitPuncts(line):
    return " ".join(re.findall(r"[\w]+|[^\s\w]", line))


def computeMaps(predictions, goldfile):
    predictionMap: Dict[str, list] = {}
    goldMap: Dict[str, list] = {}
    gf = open(goldfile, "r", encoding="utf-8")

    for row in predictions:
        cols = row.strip().split("\t")
        if len(cols) == 1:
            (rid, pred) = (cols[0], "")
        else:
            (rid, pred) = (cols[0], cols[1])
        predictionMap[rid] = [splitPuncts(pred.strip().lower())]

    for row in gf:
        (rid, pred) = row.split("\t")
        if rid in predictionMap:  # Only insert if the id exists for the method
            if rid not in goldMap:
                goldMap[rid] = []
            goldMap[rid].append(splitPuncts(pred.strip().lower()))

    sys.stderr.write("Total: " + str(len(goldMap)) + "\n")
    return (goldMap, predictionMap)


# m1 is the reference map
# m2 is the prediction map
def bleuFromMaps(m1, m2):
    score = [0] * 5
    num = 0.0

    for key in m1:
        if key in m2:
            bl = bleu(m1[key], m2[key][0])
            score = [score[i] + bl[i] for i in range(0, len(bl))]
            num += 1
    return [s * 100.0 / num for s in score]


def smoothed_bleu_4(references, predictions, **kwargs):
    predictionMap = {}
    goldMap = {}

    for rid, pred in enumerate(predictions):
        predictionMap[rid] = [splitPuncts(pred.strip().lower())]

    for rid, row in enumerate(references):
        goldMap[rid] = [splitPuncts(row.strip().lower())]

    return bleuFromMaps(goldMap, predictionMap)[0]


if __name__ == "__main__":
    reference_file = sys.argv[1]
    predictions = []
    for row in sys.stdin:
        predictions.append(row)
    (goldMap, predictionMap) = computeMaps(predictions, reference_file)
    print(bleuFromMaps(goldMap, predictionMap)[0])
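A small usage sketch of the `smoothed_bleu_4` metric defined above, assuming the functions in this module are in scope: it takes parallel lists of reference and predicted strings and returns one smoothed BLEU-4 score on a 0-100 scale. The sample strings are made up for illustration.

refs = ["return the sum of two numbers", "open a file and read its contents"]
preds = ["returns the sum of two numbers", "opens a file for reading"]
print(smoothed_bleu_4(refs, preds))  # a float in [0, 100]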
scripts/yans/lm-evaluation-harness/lm_eval/tasks/code_x_glue/code-text/go.yaml
ADDED
@@ -0,0 +1,21 @@
group:
  - codexglue_code2text
task: code2text_go
dataset_path: CM/codexglue_code2text_go
training_split: train
validation_split: validation
test_split: test
output_type: generate_until
generation_kwargs:
  num_beams: 10
  max_gen_toks: 128
  until:
    - "</s>"
doc_to_text: !function utils.doc_to_text
doc_to_target: !function utils.doc_to_target
metric_list:
  - metric: !function bleu.smoothed_bleu_4
    aggregation: mean
    higher_is_better: True
metadata:
  version: 1.0
scripts/yans/lm-evaluation-harness/lm_eval/tasks/code_x_glue/code-text/java.yaml
ADDED
@@ -0,0 +1,21 @@
group:
  - codexglue_code2text
task: code2text_java
dataset_path: CM/codexglue_code2text_java
training_split: train
validation_split: validation
test_split: test
output_type: generate_until
generation_kwargs:
  num_beams: 10
  max_gen_toks: 128
  until:
    - "</s>"
doc_to_text: !function utils.doc_to_text
doc_to_target: !function utils.doc_to_target
metric_list:
  - metric: !function bleu.smoothed_bleu_4
    aggregation: mean
    higher_is_better: True
metadata:
  version: 1.0
scripts/yans/lm-evaluation-harness/lm_eval/tasks/code_x_glue/code-text/javascript.yaml
ADDED
@@ -0,0 +1,21 @@
group:
  - codexglue_code2text
task: code2text_javascript
dataset_path: CM/codexglue_code2text_javascript
training_split: train
validation_split: validation
test_split: test
output_type: generate_until
generation_kwargs:
  num_beams: 10
  max_gen_toks: 128
  until:
    - "</s>"
doc_to_text: !function utils.doc_to_text
doc_to_target: !function utils.doc_to_target
metric_list:
  - metric: !function bleu.smoothed_bleu_4
    aggregation: mean
    higher_is_better: True
metadata:
  version: 1.0
scripts/yans/lm-evaluation-harness/lm_eval/tasks/code_x_glue/code-text/php.yaml
ADDED
@@ -0,0 +1,21 @@
group:
  - codexglue_code2text
task: code2text_php
dataset_path: CM/codexglue_code2text_php
training_split: train
validation_split: validation
test_split: test
output_type: generate_until
generation_kwargs:
  num_beams: 10
  max_gen_toks: 128
  until:
    - "</s>"
doc_to_text: !function utils.doc_to_text
doc_to_target: !function utils.doc_to_target
metric_list:
  - metric: !function bleu.smoothed_bleu_4
    aggregation: mean
    higher_is_better: True
metadata:
  version: 1.0
scripts/yans/lm-evaluation-harness/lm_eval/tasks/code_x_glue/code-text/python.yaml
ADDED
@@ -0,0 +1,21 @@
group:
  - codexglue_code2text
task: code2text_python
dataset_path: CM/codexglue_code2text_python
training_split: train
validation_split: validation
test_split: test
output_type: generate_until
generation_kwargs:
  num_beams: 10
  max_gen_toks: 128
  until:
    - "</s>"
doc_to_text: !function utils.doc_to_text
doc_to_target: !function utils.doc_to_target
metric_list:
  - metric: !function bleu.smoothed_bleu_4
    aggregation: mean
    higher_is_better: True
metadata:
  version: 1.0
scripts/yans/lm-evaluation-harness/lm_eval/tasks/code_x_glue/code-text/ruby.yaml
ADDED
@@ -0,0 +1,21 @@
group:
  - codexglue_code2text
task: code2text_ruby
dataset_path: CM/codexglue_code2text_ruby
training_split: train
validation_split: validation
test_split: test
output_type: generate_until
generation_kwargs:
  num_beams: 10
  max_gen_toks: 128
  until:
    - "</s>"
doc_to_text: !function utils.doc_to_text
doc_to_target: !function utils.doc_to_target
metric_list:
  - metric: !function bleu.smoothed_bleu_4
    aggregation: mean
    higher_is_better: True
metadata:
  version: 3.0
scripts/yans/lm-evaluation-harness/lm_eval/tasks/code_x_glue/code-text/utils.py
ADDED
@@ -0,0 +1,12 @@
def doc_to_text(doc):
    inputs = " ".join(doc["code_tokens"]).replace("\n", " ")
    inputs = " ".join(inputs.strip().split())

    return inputs


def doc_to_target(doc):
    targets = " ".join(doc["docstring_tokens"]).replace("\n", "")
    targets = " ".join(targets.strip().split())

    return targets
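The two helpers above simply re-join the pre-tokenized CodeXGLUE fields into single whitespace-normalized strings; a quick sketch on a made-up example record:

doc = {  # hypothetical example document, for illustration only
    "code_tokens": ["def", "add", "(", "a", ",", "b", ")", ":", "return", "a", "+", "b"],
    "docstring_tokens": ["Add", "two", "numbers", "."],
}
print(doc_to_text(doc))    # def add ( a , b ) : return a + b
print(doc_to_target(doc))  # Add two numbers .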
scripts/yans/lm-evaluation-harness/lm_eval/tasks/hendrycks_ethics/README.md
ADDED
@@ -0,0 +1,54 @@
# ETHICS Dataset

### Paper

Aligning AI With Shared Human Values
https://arxiv.org/pdf/2008.02275.pdf

The ETHICS dataset is a benchmark that spans concepts in justice, well-being,
duties, virtues, and commonsense morality. Models predict widespread moral
judgments about diverse text scenarios. This requires connecting physical and
social world knowledge to value judgements, a capability that may enable us
to steer chatbot outputs or eventually regularize open-ended reinforcement
learning agents.

Homepage: https://github.com/hendrycks/ethics

### Citation

```
@article{hendrycks2021ethics,
  title={Aligning AI With Shared Human Values},
  author={Dan Hendrycks and Collin Burns and Steven Basart and Andrew Critch and Jerry Li and Dawn Song and Jacob Steinhardt},
  journal={Proceedings of the International Conference on Learning Representations (ICLR)},
  year={2021}
}
```

### Groups and Tasks

#### Groups

- `hendrycks_ethics`

#### Tasks

* `ethics_cm`
* `ethics_deontology`
* `ethics_justice`
* `ethics_utilitarianism`
* (MISSING) `ethics_utilitarianism_original`
* `ethics_virtue`

### Checklist

* [x] Is the task an existing benchmark in the literature?
  * [ ] Have you referenced the original paper that introduced the task?
  * [ ] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?


If other tasks on this dataset are already supported:
* [x] Is the "Main" variant of this task clearly denoted?
* [x] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
  * [ ] Matches v0.3.0 of Eval Harness
scripts/yans/lm-evaluation-harness/lm_eval/tasks/hendrycks_ethics/commonsense.yaml
ADDED
@@ -0,0 +1,15 @@
+tag:
+  - hendrycks_ethics
+task: ethics_cm
+dataset_path: EleutherAI/hendrycks_ethics
+dataset_name: commonsense
+output_type: multiple_choice
+training_split: train
+test_split: test
+doc_to_text: "{{input}}\nQuestion: Is this wrong?\nAnswer:"
+doc_to_target: label
+doc_to_choice: ['no', 'yes']
+metric_list:
+  - metric: acc
+metadata:
+  version: 1.0
scripts/yans/lm-evaluation-harness/lm_eval/tasks/hendrycks_ethics/deontology.yaml
ADDED
@@ -0,0 +1,9 @@
+include: commonsense.yaml
+task: ethics_deontology
+dataset_name: deontology
+doc_to_text: "Question: Would most people believe this reasonable or unreasonable to say? \"{{scenario}} {{excuse.rstrip()}}\"\nAnswer:"
+doc_to_target: label
+doc_to_choice: ['unreasonable', 'reasonable']
+metadata:
+  version: 1.0
+# TODO: implement exact-match metric for this subset
scripts/yans/lm-evaluation-harness/lm_eval/tasks/hendrycks_ethics/justice.yaml
ADDED
@@ -0,0 +1,9 @@
+include: deontology.yaml
+tag:
+  - hendrycks_ethics
+task: ethics_justice
+dataset_name: justice
+doc_to_text: "Question: Would most people believe this reasonable or unreasonable to say? \"{{scenario}}\"\nAnswer:"
+# TODO: impl. exact match for this and deontology
+metadata:
+  version: 1.0
scripts/yans/lm-evaluation-harness/lm_eval/tasks/hendrycks_ethics/utilitarianism.yaml
ADDED
@@ -0,0 +1,12 @@
+include: commonsense.yaml
+tag:
+  - hendrycks_ethics
+task: ethics_utilitarianism
+dataset_name: utilitarianism
+doc_to_text: !function utils.doc_to_text
+doc_to_target: !function utils.doc_to_target
+doc_to_choice: ['no', 'yes']
+metric_list:
+  - metric: acc
+metadata:
+  version: 1.0
scripts/yans/lm-evaluation-harness/lm_eval/tasks/hendrycks_ethics/utilitarianism_original_yaml
ADDED
@@ -0,0 +1,16 @@
+# group:
+#   - hendrycks_ethics
+# task: ethics_utilitarianism_original
+# dataset_path: hails/hendrycks_ethics
+# dataset_name: utilitarianism
+# output_type: winograd_schema
+# fewshot_split: null # TODO: implement a special fewshot split for this dataset subsets
+# test_split: test
+# template_aliases: #"{% set answer_choices = range(1, 11)|list %}"
+# doc_to_text: 'Activity: "{{activity}}"\nRating:'
+# doc_to_target: "{{answer_choices[label]}}"
+# metric_list:
+#   - metric: acc
+# TODO: we want this to be implemented as a winograd_schema task type, actually
+# metadata:
+#   version: 1.0
scripts/yans/lm-evaluation-harness/lm_eval/tasks/hendrycks_ethics/utils.py
ADDED
@@ -0,0 +1,25 @@
+import random
+
+
+### Utils for `ethics_utilitarianism` task below
+def _preproc_doc(doc):
+    rnd = random.Random(doc["activity"])
+    scenarios = [doc["activity"], doc["baseline"]]
+    ordering = [0, 1]
+    rnd.shuffle(ordering)
+    doc = {
+        "scenarios": [scenarios[ordering[0]], scenarios[ordering[1]]],
+        # The correct scenario is always first
+        "label": int(ordering.index(0) == 0),
+    }
+    return doc
+
+
+def doc_to_text(doc) -> str:
+    doc = _preproc_doc(doc)
+    return f"Scenario 1: {doc['scenarios'][0]}\nScenario 2: {doc['scenarios'][1]}\nQuestion: Is Scenario 1 preferable?\nAnswer:"
+
+
+def doc_to_target(doc):
+    doc = _preproc_doc(doc)
+    return doc["label"]
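A note on the utilities above: `doc_to_text` and `doc_to_target` each call `_preproc_doc` independently, so the scenario order is shuffled with a `random.Random` seeded on `doc["activity"]`, which makes both calls agree for a given document. A minimal check with a made-up record, assuming the functions above are in scope:

```python
# Hypothetical utilitarianism record: "activity" is the higher-utility scenario.
doc = {
    "activity": "I built a treehouse with my son.",
    "baseline": "I built a treehouse.",
}

# Seeding on the activity string makes the shuffle deterministic per document,
# so the prompt and the label always describe the same ordering.
assert _preproc_doc(doc) == _preproc_doc(doc)

print(doc_to_text(doc))
print(doc_to_target(doc))  # 1 ("yes") if the activity landed in Scenario 1, else 0 ("no")
```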
scripts/yans/lm-evaluation-harness/lm_eval/tasks/hendrycks_ethics/virtue.yaml
ADDED
@@ -0,0 +1,10 @@
+include: commonsense.yaml
+tag:
+  - hendrycks_ethics
+task: ethics_virtue
+dataset_name: virtue
+doc_to_text: "Sentence: {{scenario}}\nQuestion: Does the character in this sentence exhibit the trait \"{{trait}}\"?\nAnswer:"
+doc_to_target: label
+doc_to_choice: ['no', 'yes']
+metadata:
+  version: 1.0
scripts/yans/lm-evaluation-harness/lm_eval/tasks/mc_taco/README.md
ADDED
@@ -0,0 +1,53 @@
+# MC Taco
+
+### Paper
+
+Title: `"Going on a vacation" takes longer than "Going for a walk": A Study of Temporal Commonsense Understanding`
+Abstract: https://arxiv.org/abs/1909.03065
+
+MC-TACO is a dataset of 13k question-answer pairs that require temporal commonsense
+comprehension. The dataset covers five temporal properties: (1) duration (how long
+an event takes), (2) temporal ordering (typical order of events), (3) typical time
+(when an event occurs), (4) frequency (how often an event occurs), and (5) stationarity
+(whether a state is maintained for a very long time or indefinitely).
+
+WARNING: Running this task with a `--limit` arg will give misleading results! The
+corresponding dataset is structured such that each multiple-choice question gathered
+by the authors is split into question-option pairs, where each such pair gets
+siloed into an individual document for plausibility testing. Because the harness
+shuffles these documents, setting `--limit` will likely "cut off" certain candidate
+answers. This is a problem because the task's metrics require an exhaustive evaluation
+of a question's options. See section 4 of the paper for details.
+
+Homepage: https://leaderboard.allenai.org/mctaco/submissions/public
+
+
+### Citation
+
+```
+BibTeX-formatted citation goes here
+```
+
+### Groups and Tasks
+
+#### Groups
+
+* Not part of a group yet.
+
+#### Tasks
+
+* `mc_taco`
+
+
+### Checklist
+
+For adding novel benchmarks/datasets to the library:
+* [ ] Is the task an existing benchmark in the literature?
+* [ ] Have you referenced the original paper that introduced the task?
+* [ ] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?
+
+
+If other tasks on this dataset are already supported:
+* [ ] Is the "Main" variant of this task clearly denoted?
+* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates?
+* [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
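To make the `--limit` warning in the README above concrete, here is a hedged sketch (made-up rows, not harness internals) of the per-question grouping that MC-TACO's exact-match style scoring needs; dropping any question-option row corrupts that question's score.

```python
from collections import defaultdict

# Hypothetical question-option documents, one row per candidate answer.
rows = [
    {"question": "How long did the walk take?", "answer": "30 minutes", "pred": "yes", "gold": "yes"},
    {"question": "How long did the walk take?", "answer": "3 seconds",  "pred": "no",  "gold": "no"},
    {"question": "How long did the walk take?", "answer": "2 weeks",    "pred": "yes", "gold": "no"},
]

# Exact match is judged per question: every candidate answer must be classified correctly,
# which is why truncating the document list mid-question gives misleading numbers.
by_question = defaultdict(list)
for r in rows:
    by_question[r["question"]].append(r["pred"] == r["gold"])

exact_match = sum(all(v) for v in by_question.values()) / len(by_question)
print(exact_match)  # 0.0 here: one mis-judged option sinks the whole question
```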
scripts/yans/lm-evaluation-harness/lm_eval/tasks/mc_taco/default.yaml
ADDED
@@ -0,0 +1,15 @@
+task: mc_taco
+dataset_path: mc_taco
+output_type: multiple_choice
+validation_split: validation
+test_split: test
+doc_to_text: "{{sentence}}\nQuestion: {{question}}\nAnswer: {{answer}}\nPlausible:"
+doc_to_target: label
+doc_to_choice: ["no", "yes"]
+should_decontaminate: true
+doc_to_decontamination_query: "{{question}} {{sentence}}"
+metric_list:
+  - metric: acc
+  - metric: f1
+metadata:
+  version: 1.0
scripts/yans/lm-evaluation-harness/lm_eval/tasks/pubmedqa/README.md
ADDED
@@ -0,0 +1,56 @@
+# PubMedQA
+
+### Paper
+
+Title: `PubMedQA: A Dataset for Biomedical Research Question Answering`
+
+Abstract: https://arxiv.org/abs/1909.06146
+
+PubMedQA is a novel biomedical question answering (QA) dataset collected from
+PubMed abstracts. The task of PubMedQA is to answer research questions with
+yes/no/maybe (e.g.: Do preoperative statins reduce atrial fibrillation after
+coronary artery bypass grafting?) using the corresponding abstracts. PubMedQA
+has 1k expert-annotated, 61.2k unlabeled and 211.3k artificially generated QA
+instances. Each PubMedQA instance is composed of (1) a question which is either
+an existing research article title or derived from one, (2) a context which is
+the corresponding abstract without its conclusion, (3) a long answer, which is
+the conclusion of the abstract and, presumably, answers the research question,
+and (4) a yes/no/maybe answer which summarizes the conclusion.
+
+Homepage: https://pubmedqa.github.io/
+
+
+### Citation
+
+```
+@inproceedings{jin2019pubmedqa,
+  title={PubMedQA: A Dataset for Biomedical Research Question Answering},
+  author={Jin, Qiao and Dhingra, Bhuwan and Liu, Zhengping and Cohen, William and Lu, Xinghua},
+  booktitle={Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)},
+  pages={2567--2577},
+  year={2019}
+}
+```
+
+### Groups and Tasks
+
+#### Groups
+
+* Not part of a group yet
+
+#### Tasks
+
+* `pubmed_qa`
+
+### Checklist
+
+For adding novel benchmarks/datasets to the library:
+* [ ] Is the task an existing benchmark in the literature?
+* [ ] Have you referenced the original paper that introduced the task?
+* [ ] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?
+
+
+If other tasks on this dataset are already supported:
+* [ ] Is the "Main" variant of this task clearly denoted?
+* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates?
+* [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
scripts/yans/lm-evaluation-harness/lm_eval/tasks/pubmedqa/preprocess_pubmedqa.py
ADDED
@@ -0,0 +1,6 @@
+def doc_to_text(doc) -> str:
+    ctxs = "\n".join(doc["CONTEXTS"])
+    return "Abstract: {}\nQuestion: {}\nAnswer:".format(
+        ctxs,
+        doc["QUESTION"],
+    )
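A quick usage sketch for the preprocessing function above (assuming `doc_to_text` from that file is in scope); the record is made up but uses the upper-case `CONTEXTS` / `QUESTION` fields the function expects.

```python
# Hypothetical labeled PubMedQA record.
doc = {
    "CONTEXTS": [
        "Statins have anti-inflammatory effects.",
        "We studied patients undergoing coronary artery bypass grafting.",
    ],
    "QUESTION": "Do preoperative statins reduce atrial fibrillation after CABG?",
}

print(doc_to_text(doc))
# Abstract: Statins have anti-inflammatory effects.
# We studied patients undergoing coronary artery bypass grafting.
# Question: Do preoperative statins reduce atrial fibrillation after CABG?
# Answer:
```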
scripts/yans/lm-evaluation-harness/lm_eval/tasks/pubmedqa/pubmedqa.yaml
ADDED
@@ -0,0 +1,16 @@
+task: pubmedqa
+dataset_path: bigbio/pubmed_qa
+dataset_name: pubmed_qa_labeled_fold0_source
+output_type: multiple_choice
+training_split: train
+validation_split: validation
+test_split: test
+doc_to_text: !function preprocess_pubmedqa.doc_to_text
+doc_to_target: final_decision
+doc_to_choice: ["yes", "no", "maybe"]
+metric_list:
+  - metric: acc
+    aggregation: mean
+    higher_is_better: true
+metadata:
+  version: 1.0
scripts/yans/lm-evaluation-harness/lm_eval/tasks/qa4mre/README.md
ADDED
@@ -0,0 +1,55 @@
+# QA4MRE
+
+### Paper
+
+Title: `QA4MRE 2011-2013: Overview of Question Answering for Machine Reading Evaluation`
+
+Abstract: https://www.cs.cmu.edu/~./hovy/papers/13CLEF-QA4MRE.pdf
+
+The (English only) QA4MRE challenge was run as a Lab at CLEF 2011-2013.
+The main objective of this exercise is to develop a methodology for evaluating
+Machine Reading systems through Question Answering and Reading Comprehension
+Tests. Systems should be able to extract knowledge from large volumes of text
+and use this knowledge to answer questions. Four different tasks have been
+organized during these years: Main Task, Processing Modality and Negation for
+Machine Reading, Machine Reading of Biomedical Texts about Alzheimer's disease,
+and Entrance Exam.
+
+Homepage: http://nlp.uned.es/clef-qa/repository/qa4mre.php
+
+
+### Citation
+
+```
+@inproceedings{Peas2013QA4MRE2O,
+  title={QA4MRE 2011-2013: Overview of Question Answering for Machine Reading Evaluation},
+  author={Anselmo Pe{\~n}as and Eduard H. Hovy and Pamela Forner and {\'A}lvaro Rodrigo and Richard F. E. Sutcliffe and Roser Morante},
+  booktitle={CLEF},
+  year={2013}
+}
+```
+
+### Groups and Tasks
+
+#### Groups
+
+* `qa4mre`
+
+#### Tasks
+
+* `qa4mre_2011`
+* `qa4mre_2012`
+* `qa4mre_2013`
+
+### Checklist
+
+For adding novel benchmarks/datasets to the library:
+* [ ] Is the task an existing benchmark in the literature?
+* [ ] Have you referenced the original paper that introduced the task?
+* [ ] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?
+
+
+If other tasks on this dataset are already supported:
+* [ ] Is the "Main" variant of this task clearly denoted?
+* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates?
+* [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
scripts/yans/lm-evaluation-harness/lm_eval/tasks/qa4mre/preprocess_qa4mre.py
ADDED
@@ -0,0 +1,6 @@
+def qa4mre_process(doc):
+    return int(doc["correct_answer_id"]) - 1
+
+
+def doc_to_target(doc):
+    return doc["answer_options"]["answer_str"][qa4mre_process(doc)]
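A small sanity check for the helpers above (assuming they are in scope): `correct_answer_id` is 1-based in this dataset, so `qa4mre_process` shifts it to a 0-based index into `answer_options.answer_str`. The record below is made up.

```python
# Hypothetical QA4MRE record with a 1-based correct_answer_id.
doc = {
    "correct_answer_id": "3",
    "answer_options": {"answer_str": ["a volcano", "a glacier", "a coral reef", "a desert"]},
}

assert qa4mre_process(doc) == 2
print(doc_to_target(doc))  # -> a coral reef
```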
scripts/yans/lm-evaluation-harness/lm_eval/tasks/qa4mre/qa4mre_2011.yaml
ADDED
@@ -0,0 +1,22 @@
+tag:
+  - qa4mre
+task: qa4mre_2011
+dataset_path: qa4mre
+dataset_name: 2011.main.EN
+output_type: multiple_choice
+test_split: train
+# doc_to_text: "{{document_str.strip()}}\nQuestion: {{question_str}}\nChoices:\n- {{answer_choices|join('\n- ')}}\nAnswer:"
+doc_to_text: "{{document_str.strip()}}\nQuestion: {{question_str}}\nAnswer:"
+doc_to_target: "{{correct_answer_id|int - 1}}"
+doc_to_choice: "{{answer_options.answer_str}}"
+should_decontaminate: true
+doc_to_decontamination_query: "{{document_str.strip()}} + ' ' + {{question_str}}"
+metric_list:
+  - metric: acc
+    aggregation: mean
+    higher_is_better: true
+  - metric: acc_norm
+    aggregation: mean
+    higher_is_better: true
+metadata:
+  version: 1.0