test: add unit tests for benchmarks
Files changed:
- src/benchmarks.py (+41 -34)
- tests/src/test_benchmarks.py (+43 -9)
src/benchmarks.py
CHANGED

@@ -10,7 +10,7 @@ from src.models import TaskType, get_safe_name
 @dataclass
 class Benchmark:
     name: str  # [domain]_[language]_[metric], task_key in the json file,
-    metric: str  #
+    metric: str  # metric_key in the json file
     col_name: str  # [domain]_[language], name to display in the leaderboard
     domain: str
     lang: str

@@ -18,54 +18,61 @@ class Benchmark:


 # create a function return an enum class containing all the benchmarks
-def
+def get_qa_benchmarks_dict(version: str):
     benchmark_dict = {}
-for
+    for task, domain_dict in BenchmarkTable[version].items():
+        if task != TaskType.qa.value:
+            continue
+        for domain, lang_dict in domain_dict.items():
+            for lang, dataset_list in lang_dict.items():
+                benchmark_name = get_safe_name(f"{domain}_{lang}")
+                col_name = benchmark_name
+                for metric in dataset_list:
+                    if "test" not in dataset_list[metric]["splits"]:
+                        continue
+                    benchmark_dict[benchmark_name] = Benchmark(
+                        benchmark_name, metric, col_name, domain, lang, task
+                    )
+    return benchmark_dict
+
+
+def get_doc_benchmarks_dict(version: str):
+    benchmark_dict = {}
+    for task, domain_dict in BenchmarkTable[version].items():
+        if task != TaskType.long_doc.value:
+            continue
+        for domain, lang_dict in domain_dict.items():
+            for lang, dataset_list in lang_dict.items():
+                for dataset in dataset_list:
+                    benchmark_name = f"{domain}_{lang}_{dataset}"
+                    benchmark_name = get_safe_name(benchmark_name)
                     col_name = benchmark_name
+                    if "test" not in dataset_list[dataset]["splits"]:
+                        continue
+                    for metric in METRIC_LIST:
                         benchmark_dict[benchmark_name] = Benchmark(
                             benchmark_name, metric, col_name, domain, lang, task
                         )
-    elif task_type == TaskType.long_doc:
-        for task, domain_dict in BenchmarkTable[benchmark_version].items():
-            if task != task_type.value:
-                continue
-            for domain, lang_dict in domain_dict.items():
-                for lang, dataset_list in lang_dict.items():
-                    for dataset in dataset_list:
-                        benchmark_name = f"{domain}_{lang}_{dataset}"
-                        benchmark_name = get_safe_name(benchmark_name)
-                        col_name = benchmark_name
-                        if "test" not in dataset_list[dataset]["splits"]:
-                            continue
-                        for metric in METRIC_LIST:
-                            benchmark_dict[benchmark_name] = Benchmark(
-                                benchmark_name, metric, col_name, domain, lang, task
-                            )
     return benchmark_dict


 _qa_benchmark_dict = {}
 for version in BENCHMARK_VERSION_LIST:
     safe_version_name = get_safe_name(version)
-    _qa_benchmark_dict[safe_version_name] =
+    _qa_benchmark_dict[safe_version_name] = \
+        Enum(
+            f"QABenchmarks_{safe_version_name}",
+            get_qa_benchmarks_dict(version)
+        )

 _doc_benchmark_dict = {}
 for version in BENCHMARK_VERSION_LIST:
     safe_version_name = get_safe_name(version)
-    _doc_benchmark_dict[safe_version_name] =
+    _doc_benchmark_dict[safe_version_name] = \
+        Enum(
+            f"LongDocBenchmarks_{safe_version_name}",
+            get_doc_benchmarks_dict(version)
+        )


 QABenchmarks = Enum("QABenchmarks", _qa_benchmark_dict)
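For orientation, a minimal usage sketch of the per-version enums built above (not part of this commit). It assumes the safe version names include air_bench_2404, as the tests below suggest; each QABenchmarks member's value is the per-version Enum whose members wrap Benchmark instances.

# Hypothetical usage sketch, not part of the commit. Each QABenchmarks member
# is a benchmark version; its value is the per-version Enum built by
# get_qa_benchmarks_dict, whose members wrap Benchmark dataclass instances.
from src.benchmarks import QABenchmarks

qa_2404 = QABenchmarks["air_bench_2404"].value  # assumed safe version name
for member in qa_2404:
    bench = member.value  # Benchmark(name, metric, col_name, domain, lang, task)
    print(bench.col_name, bench.domain, bench.lang, bench.metric)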
tests/src/test_benchmarks.py
CHANGED

@@ -1,15 +1,49 @@
+import pytest
+
 from src.benchmarks import LongDocBenchmarks, QABenchmarks
+from src.envs import BENCHMARK_VERSION_LIST
+

+# Ref: https://github.com/AIR-Bench/AIR-Bench/blob/4b27b8a8f2047a963805fcf6fb9d74be51ec440c/docs/available_tasks.md
+# 24.05
+# | Task     | dev | test |
+# | -------- | --- | ---- |
+# | Long-Doc | 4   | 11   |
+# | QA       | 54  | 53   |
+#
+# 24.04
+# | Task     | test |
+# | -------- | ---- |
+# | Long-Doc | 15   |
+# | QA       | 13   |

+@pytest.mark.parametrize(
+    "num_datasets_dict",
+    [
+        {
+            "air_bench_2404": 13,
+            "air_bench_2405": 53
+        }
+    ]
+)
+def test_qa_benchmarks(num_datasets_dict):
+    assert len(QABenchmarks) == len(BENCHMARK_VERSION_LIST)
     for benchmark_list in list(QABenchmarks):
-    print(b)
-    qa_benchmarks = QABenchmarks["2404"]
-    l = list(frozenset([c.value.domain for c in list(qa_benchmarks.value)]))
-    print(l)
+        version_slug = benchmark_list.name
+        assert num_datasets_dict[version_slug] == len(benchmark_list.value)


+@pytest.mark.parametrize(
+    "num_datasets_dict",
+    [
+        {
+            "air_bench_2404": 15,
+            "air_bench_2405": 11
+        }
+    ]
+)
+def test_doc_benchmarks(num_datasets_dict):
+    assert len(LongDocBenchmarks) == len(BENCHMARK_VERSION_LIST)
+    for benchmark_list in list(LongDocBenchmarks):
+        version_slug = benchmark_list.name
+        assert num_datasets_dict[version_slug] == len(benchmark_list.value)
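Each parametrize list above holds a single dict, so each test runs once, checking the expected dataset counts (keyed by safe version name) against the table quoted in the comment. A quick way to run just this file, assuming pytest is installed and the command is issued from the repository root:

pytest tests/src/test_benchmarks.py -v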