Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 4,415 Bytes
8b7a945 4a44211 8b7a945 8b7258f 8b7a945 43fbed5 8b7a945 1a2dba5 8b7a945 8102fce 8b7a945 f30cbcc 8b7a945 f30cbcc 8b7a945 32ebf18 8b7a945 443f557 8b7a945 9c49811 8b7a945 9c49811 8b7a945 a96f80a 9134169 8b7a945 9c49811 8b7a945 f30cbcc 8b7a945 9c49811 1a2dba5 8b7a945 f30cbcc 32ebf18 a96f80a 8b7a945 9134169 9c49811 e8879cc f30cbcc e8879cc f30cbcc 32ebf18 bbfe4c1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
from dataclasses import dataclass
from enum import Enum
def get_safe_name(name: str) -> str:
    """Normalize *name* into a lowercase, underscore-separated key.

    Hyphens are turned into underscores. Every character that is neither
    alphanumeric nor ``_`` is dropped, and the remaining characters are
    lower-cased. For example, ``"MS MARCO" -> "msmarco"`` and
    ``"pubmed_100k-200k_1" -> "pubmed_100k_200k_1"``.

    Note: the result is NOT RFC 1123 compatible. RFC 1123 host labels permit
    hyphens but not underscores, and this function replaces hyphens with
    underscores. ``str.isalnum`` is Unicode-aware, so non-ASCII letters and
    digits are kept.

    Args:
        name: Arbitrary display name (domain/language/dataset combination).

    Returns:
        The normalized key; ``""`` for an empty input.
    """
    name = name.replace('-', '_')
    return ''.join(
        character.lower()
        for character in name
        # Filter on the original character, then lower-case it.
        if character.isalnum() or character == '_'
    )
# Benchmark datasets, nested as: task -> domain -> language -> [dataset names].
# Key order is significant: the benchmark dicts are built by iterating this.
dataset_dict = {
    "qa": {
        "wiki": {"en": ["wikipedia_20240101"], "zh": ["wikipedia_20240101"]},
        "web": {"en": ["mC4"], "zh": ["mC4"]},
        "news": {"en": ["CC-News"], "zh": ["CC-News"]},
        "healthcare": {"en": ["PubMedQA"], "zh": ["Huatuo-26M"]},
        "law": {
            "en": ["pile-of-law"],
            # Chinese law split is disabled (reason not recorded here):
            # "zh": ["flk_npc_gov_cn", ]
        },
        "finance": {"en": ["Reuters-Financial"], "zh": ["FinCorpus"]},
        "arxiv": {"en": ["Arxiv"]},
        "msmarco": {"en": ["MS MARCO"]},
    },
    "long-doc": {
        "arxiv": {"en": ["gpt3", "llama2", "llm-survey", "gemini"]},
        "book": {
            "en": [
                "origin-of-species_darwin",
                "a-brief-history-of-time_stephen-hawking",
            ],
        },
        "healthcare": {
            "en": [
                "pubmed_100k-200k_1",
                "pubmed_100k-200k_2",
                "pubmed_100k-200k_3",
                "pubmed_40k-50k_5-merged",
                "pubmed_30k-40k_10-merged",
            ],
        },
        "law": {
            "en": [
                "lex_files_300k-400k",
                "lex_files_400k-500k",
                "lex_files_500k-600k",
                "lex_files_600k-700k",
            ],
        },
    },
}
# Retrieval metric keys of the form "<family>_at_<k>". All cutoffs of ndcg come
# first, then map, recall, precision and finally mrr (5 x 6 = 30 entries).
METRIC_LIST = [
    f"{family}_at_{cutoff}"
    for family in ("ndcg", "map", "recall", "precision", "mrr")
    for cutoff in (1, 3, 5, 10, 100, 1000)
]
@dataclass
class Benchmark:
    """One leaderboard benchmark column derived from ``dataset_dict``."""

    # Key produced by get_safe_name: "[domain]_[language]" for qa,
    # "[domain]_[language]_[dataset]" for long-doc. Used as the task key in
    # the results json.
    name: str
    # Intended to be a metric key in the results json (e.g. "ndcg_at_1").
    # NOTE(review): the qa builder in this module passes a dataset name here
    # (e.g. "wikipedia_20240101"); confirm whether anything reads this field.
    metric: str
    # Column name displayed in the leaderboard; the builder sets it to `name`.
    col_name: str
    domain: str  # e.g. "wiki", "law" (a key of dataset_dict[task])
    lang: str  # e.g. "en", "zh"
    task: str  # "qa" or "long-doc"
def _build_benchmark_dicts():
    """Expand ``dataset_dict`` into per-task ``{benchmark_name: Benchmark}`` maps.

    qa tasks get one benchmark per (domain, language). long-doc tasks get one
    benchmark per (domain, language, dataset). Keys come from get_safe_name,
    and each entry's ``col_name`` equals its key.

    Returns:
        A ``(qa, long_doc)`` tuple of dicts, in ``dataset_dict`` iteration order.
    """
    qa = {}
    long_doc = {}
    for task, domain_dict in dataset_dict.items():
        for domain, lang_dict in domain_dict.items():
            for lang, dataset_list in lang_dict.items():
                if task == "qa":
                    if not dataset_list:
                        # The previous per-dataset loop created no entry for an
                        # empty list; keep that behaviour.
                        continue
                    benchmark_name = get_safe_name(f"{domain}_{lang}")
                    # NOTE(review): formerly `for metric in dataset_list:` rewrote
                    # this key once per dataset, so only the LAST dataset name
                    # survived as `metric`. That value is a dataset, not a metric
                    # key. It is preserved here; confirm whether `.metric` is read
                    # downstream before changing it.
                    qa[benchmark_name] = Benchmark(
                        benchmark_name, dataset_list[-1], benchmark_name,
                        domain, lang, task)
                elif task == "long-doc":
                    for dataset in dataset_list:
                        benchmark_name = get_safe_name(f"{domain}_{lang}_{dataset}")
                        # NOTE(review): formerly a loop over METRIC_LIST rewrote this
                        # key once per metric, leaving only the last metric. That
                        # value is preserved here; confirm it is intended.
                        long_doc[benchmark_name] = Benchmark(
                            benchmark_name, METRIC_LIST[-1], benchmark_name,
                            domain, lang, task)
    return qa, long_doc


qa_benchmark_dict, long_doc_benchmark_dict = _build_benchmark_dicts()
# Enum views of the benchmark dicts (functional Enum API). Member names are the
# dict keys (safe benchmark names); member values are the Benchmark instances.
BenchmarksQA = Enum("BenchmarksQA", names=qa_benchmark_dict)
BenchmarksLongDoc = Enum("BenchmarksLongDoc", names=long_doc_benchmark_dict)
# Leaderboard column names: one per benchmark, in benchmark-dict order.
BENCHMARK_COLS_QA = [c.col_name for c in qa_benchmark_dict.values()]
BENCHMARK_COLS_LONG_DOC = [c.col_name for c in long_doc_benchmark_dict.values()]
# Distinct domains/languages per task, in first-seen order. These are
# deduplicated with dict.fromkeys rather than frozenset: a set of str is
# iterated in hash order, and str hashing is randomized per process, so the
# lists could previously come out in a different order on every run.
DOMAIN_COLS_QA = list(dict.fromkeys(c.domain for c in qa_benchmark_dict.values()))
LANG_COLS_QA = list(dict.fromkeys(c.lang for c in qa_benchmark_dict.values()))
DOMAIN_COLS_LONG_DOC = list(dict.fromkeys(c.domain for c in long_doc_benchmark_dict.values()))
LANG_COLS_LONG_DOC = list(dict.fromkeys(c.lang for c in long_doc_benchmark_dict.values()))
# Default metric key (an entry of METRIC_LIST); presumably the leaderboard's
# initial metric selection. Confirm with callers.
DEFAULT_METRIC = "recall_at_10"