Spaces:
Running
Running
from dataclasses import dataclass, field | |
from datasets import load_dataset, Dataset | |
from functools import cached_property | |
from tqdm.auto import tqdm | |
from typing import Any, Optional, Protocol, Iterable, Callable | |
import logging | |
import pandas as pd | |
from functools import partial | |
from datasets.utils.logging import disable_progress_bar | |
from .utils import * | |
from evaluate import load | |
from collections import defaultdict | |
import sys | |
# Turn off the `datasets` progress bars globally: the Task/suite code below
# issues many small load_dataset() / Dataset.map() calls, and a bar per call
# would flood the output.
disable_progress_bar()
def mt_bench_prompt(example):
    """Build the chat messages for an MT-Bench style LLM-as-judge request.

    Args:
        example: mapping with
            ``"conversation"`` -- list of ``{"role": ..., "content": ...}`` dicts;
            ``"turn"`` -- when ``2``, the judge is told to focus on the
            assistant's answer to the second user question;
            ``"reference"`` -- ``None`` or a list of reference-answer strings,
            matched in order to the user/assistant exchanges.

    Returns:
        A list of message dicts. The first is a system message with the judging
        instructions. The conversation follows it. When references are present,
        the conversation is rebuilt as question, ``"Reference Answer: ..."``
        system message, answer, for each user/assistant exchange.
    """
    judge_prompt = "You are ChatGPT, a large language model trained by OpenAI. Your task is to act as an impartial judge and evaluate the quality of the responses provided by an 'assistant' role in the displayed conversation. Your evaluation should focus on the helpfulness, relevance, accuracy, depth, creativity, language fluency, clarity, and level of detail in the assistant's responses. Please note that the evaluation should not consider the user's questions or the overall conversation, but solely the quality of the assistant's replies."
    multi_prompt = "Your evaluation should focus on the assistant's answer to the second user question."
    ref_prompt = "In the conversation, you will encounter system messages labeled 'Reference Answer' followed by the assistant's response. Your task is to evaluate the quality of the assistant's response by comparing it with the reference answer."
    json_prompt = 'You must rate the response on a scale of 1 to 10 in JSON format, for example: {"rating": 5}.'

    prompt_list = [judge_prompt]
    conversation = example["conversation"]
    if example["turn"] == 2:
        prompt_list.append(multi_prompt)

    references = example["reference"]
    # An empty reference list is treated like "no references"; zipping against
    # it would otherwise discard the entire conversation.
    if references:
        questions = [m for m in conversation if m["role"] == "user"]
        answers = [m for m in conversation if m["role"] == "assistant"]
        rebuilt = []
        for i, (question, answer) in enumerate(zip(questions, answers)):
            rebuilt.append(question)
            # Exchanges beyond the supplied references are kept, just without
            # a reference message (instead of being silently dropped).
            if i < len(references) and references[i] is not None:
                rebuilt.append(
                    {"role": "system", "content": "Reference Answer: " + references[i]}
                )
            rebuilt.append(answer)
        conversation = rebuilt
        prompt_list.append(ref_prompt)

    prompt_list.append(json_prompt)
    return [{"role": "system", "content": " ".join(prompt_list)}] + conversation
class Task: | |
dataset_name: str | tuple[str, str] = ("gsm8k", "main") | |
split: str = "test" | |
# metrics: list[str] = field(default_factory=list) | |
metric_name: str | tuple[str, str] = ("sustech/tlem", "gsm8k") | |
input_column: str = "question" | |
label_column: str = "" | |
prompt: Optional[Callable | str] = None | |
few_shot: int = 0 | |
few_shot_from: Optional[str] = None | |
# results: dict[str, Any] = field(default_factory=dict) | |
def __post_init__(self): | |
names = ( | |
[self.dataset_name] | |
if isinstance(self.dataset_name, str) | |
else list(self.dataset_name) | |
) | |
names[0] = names[0].split("/")[-1] | |
self.name = "-".join(names) + f"-{self.split}" | |
if isinstance(self.prompt, str): | |
self.prompt = lambda example: { | |
self.input_column: self.prompt.format( | |
input_column=example[self.input_column] | |
) | |
} | |
self.label_column = self.label_column or self.input_column | |
def samples(self): | |
return self.dataset[self.input_column] | |
def dataset(self): | |
ds = load_dataset( | |
*self.dataset_name | |
if isinstance(self.dataset_name, tuple) | |
else self.dataset_name, | |
# split=self.split, | |
) | |
test_ds = ds[self.split] | |
if self.prompt is not None: | |
test_ds = test_ds.map(self.prompt) | |
if self.few_shot: | |
if self.few_shot_from is None: | |
for name in ["train", "validation", "val", "dev"]: | |
if name in ds: | |
self.few_shot_from = name | |
break | |
assert self.few_shot_from != self.split | |
shots = ds[self.few_shot_from].select(range(self.few_shot)) | |
if self.prompt is not None: | |
shots = shots.map(self.prompt) | |
shots = shots.map( | |
lambda example: { | |
self.input_column: example[self.input_column] | |
+ example[self.label_column], | |
} | |
)[self.input_column] | |
few_shot_prompts = "\n\n".join(shots) | |
test_ds = test_ds.map( | |
lambda example: { | |
self.input_column: few_shot_prompts | |
+ "\n\n" | |
+ example[self.input_column], | |
} | |
) | |
return test_ds | |
def metric(self): | |
metric = ( | |
load(self.metric_name) | |
if isinstance(self.metric_name, str) | |
else load(*self.metric_name) | |
) | |
return metric | |
# @cache | |
def run( | |
self, | |
pipeline, | |
): | |
if (outputs := pipeline(self.samples)) is None: | |
logging.warning("pipeline returns None") | |
return | |
self.outputs = outputs | |
try: | |
result = self.metric._compute( | |
responses=outputs, references=self.dataset[self.label_column] | |
) | |
except Exception as e: | |
result = self.metric.compute( | |
responses=outputs, references=self.dataset[self.label_column] | |
) | |
finally: | |
result = outputs | |
# if log: | |
# name = name or pipeline.__name__ | |
# self.results[name] = result | |
return result | |
def _normalize_choices(responses: Any, zh: bool = False) -> Any:
    """Turn raw model responses into choice predictions.

    String responses are parsed one by one with ``extract_choice`` (or
    ``extract_choice_zh`` when ``zh`` is true). Any other response type is
    handed to ``decode_choice`` as a whole. Empty input yields ``[]``.
    """
    if len(responses) == 0:
        # The str / non-str dispatch below peeks at responses[0].
        return []
    if isinstance(responses[0], str):
        extract = extract_choice_zh if zh else extract_choice
        return [extract(response) for response in responses]
    return decode_choice(responses)


def multichoice(responses: Any, references: list[str]):
    """Return ``(predicted_choices, references)`` for English multiple choice."""
    return _normalize_choices(responses), references


def multichoice_zh(responses: Any, references: list[str]):
    """Return ``(predicted_choices, references)`` for Chinese multiple choice."""
    return _normalize_choices(responses, zh=True), references
class Metrics:
    """Per-benchmark response post-processing, looked up by attribute name.

    ``cmmlu``, ``mmlu`` and ``gsm8k`` return a ``(predictions, references)``
    pair. ``MATH`` and ``math23k`` return one score (0.0 or 1.0) per example.
    """

    cmmlu = staticmethod(multichoice_zh)
    mmlu = staticmethod(multichoice)

    @staticmethod
    def gsm8k(responses: list[str], answers: list[str | int]):
        """Reduce responses and gold answers to their extracted numbers.

        String answers go through ``extract_numeric``; non-string answers
        (e.g. ints) are converted with ``str``.
        """
        predictions = [extract_numeric(response) for response in responses]
        golds = [
            extract_numeric(answer) if isinstance(answer, str) else str(answer)
            for answer in answers
        ]
        return predictions, golds

    @staticmethod
    def MATH(responses: list[str], answers: list[str]):
        """Score 1.0 when the text between the last two ``$`` signs of a response
        is ``is_equiv`` to ``get_answer(answer)``.

        A response with fewer than two ``$`` signs scores 0.0. Exactly two is
        enough, e.g. ``"The answer is $5$"`` yields ``"5"``.
        """
        scores = []
        for response, answer in zip(responses, answers):
            close = response.rfind("$")
            open_ = response.rfind("$", 0, close) if close > 0 else -1
            if open_ == -1:
                # No complete $...$ span to read an answer from.
                scores.append(0.0)
                continue
            prediction = response[open_ + 1 : close]
            scores.append(1.0 * is_equiv(prediction, get_answer(answer)))
        return scores

    @staticmethod
    def math23k(responses: list[str], answers: list[str]):
        """Score 1.0 when the numbers extracted with ``NUMERIC_IN_ZH`` match."""
        return [
            1.0
            * (
                extract_numeric(response, pattern=NUMERIC_IN_ZH)
                == extract_numeric(answer, pattern=NUMERIC_IN_ZH)
            )
            for response, answer in zip(responses, answers)
        ]
class CMMLU:
    """CMMLU (Chinese multiple-choice) benchmark: prompt format and task suites."""

    input_column = "prompt"
    label_column = "Answer"

    @staticmethod
    def prompt_cmmlu(example, chat=False):
        """Render one CMMLU row as a multiple-choice prompt.

        Reads ``example["Question"]`` and the options ``example["A"]`` ..
        ``example["D"]`` and returns ``{"prompt": text}``. With ``chat=True``
        the prompt opens with an instruction ("The following is a
        multiple-choice question; choose the most suitable answer from A, B, C
        and D"). Otherwise it opens with "问题:" ("Question:"). It always ends
        with "答案:" ("Answer:").
        """
        prefix = "以下是一道多项选择题,请从A、B、C和D中选择最合适的答案作为这个问题的答案。\n\n" if chat else "问题:"
        options = "".join(f"\n{letter}. {example[letter]}" for letter in "ABCD")
        return {"prompt": prefix + example["Question"] + options + "\n答案:"}

    # Subject name (used as the dataset config in `suite`) -> fine categories.
    subcategories = {
        "agronomy": ["other"],
        "anatomy": ["biology"],
        "ancient_chinese": ["linguistics", "china specific"],
        "arts": ["arts"],
        "astronomy": ["physics"],
        "business_ethics": ["business"],
        "chinese_civil_service_exam": ["politics", "china specific"],
        "chinese_driving_rule": ["other", "china specific"],
        "chinese_food_culture": ["culture", "china specific"],
        "chinese_foreign_policy": ["politics", "china specific"],
        "chinese_history": ["history", "china specific"],
        "chinese_literature": ["literature", "china specific"],
        "chinese_teacher_qualification": ["education", "china specific"],
        "college_actuarial_science": ["math"],
        "college_education": ["education"],
        "college_engineering_hydrology": ["engineering"],
        "college_law": ["law"],
        "college_mathematics": ["math"],
        "college_medical_statistics": ["statistics"],
        "clinical_knowledge": ["other"],
        "college_medicine": ["other"],
        "computer_science": ["computer science"],
        "computer_security": ["other"],
        "conceptual_physics": ["physics"],
        "construction_project_management": ["other", "china specific"],
        "economics": ["economics"],
        "education": ["education"],
        "elementary_chinese": ["linguistics", "china specific"],
        "elementary_commonsense": ["other", "china specific"],
        "elementary_information_and_technology": ["other"],
        "electrical_engineering": ["engineering"],
        "elementary_mathematics": ["math"],
        "ethnology": ["culture", "china specific"],
        "food_science": ["other"],
        "genetics": ["biology"],
        "global_facts": ["global"],
        "high_school_biology": ["biology"],
        "high_school_chemistry": ["chemistry"],
        "high_school_geography": ["geography"],
        "high_school_mathematics": ["math"],
        "high_school_physics": ["physics"],
        "high_school_politics": ["politics", "china specific"],
        "human_sexuality": ["other"],
        "international_law": ["law"],
        "journalism": ["sociology"],
        "jurisprudence": ["law"],
        "legal_and_moral_basis": ["other"],
        "logical": ["philosophy"],
        "machine_learning": ["computer science"],
        "management": ["business"],
        "marketing": ["business"],
        "marxist_theory": ["philosophy"],
        "modern_chinese": ["linguistics", "china specific"],
        "nutrition": ["other"],
        "philosophy": ["philosophy"],
        "professional_accounting": ["business"],
        "professional_law": ["law"],
        "professional_medicine": ["other"],
        "professional_psychology": ["psychology"],
        "public_relations": ["politics"],
        "security_study": ["politics"],
        "sociology": ["culture"],
        "sports_science": ["other"],
        "traditional_chinese_medicine": ["other", "china specific"],
        "virology": ["biology"],
        "world_history": ["history"],
        "world_religions": ["global"],
    }

    # Top-level category -> fine categories it aggregates.
    categories = {
        "STEM": [
            "physics",
            "chemistry",
            "biology",
            "computer science",
            "math",
            "engineering",
            "statistics",
        ],
        "Humanities": ["history", "philosophy", "law", "arts", "literature", "global"],
        "Social Science": [
            "linguistics",
            "business",
            "politics",
            "culture",
            "economics",
            "geography",
            "psychology",
            "education",
            "sociology",
        ],
        "Other": ["other"],
        "China specific": ["china specific"],
        # NOTE(review): small "computer science"-only group, presumably for
        # quick test runs.
        "Test": ["computer science"],
    }

    @classmethod
    def _category_subjects(cls):
        """Map each top-level category, plus ``"all"``, to its subject names.

        Subjects are ordered by fine category, then by position in
        ``subcategories``. Each subject appears at most once per category; a
        subject tagged with two fine categories (e.g. "linguistics" and
        "china specific") would otherwise be listed twice under ``"all"``.
        """
        by_fine = defaultdict(list)
        for subject, fine_categories in cls.subcategories.items():
            for fine in fine_categories:
                by_fine[fine].append(subject)
        # Build a new mapping instead of writing "all" into the shared class dict.
        groups = {**cls.categories, "all": sorted(by_fine)}
        return {
            name: list(
                dict.fromkeys(
                    subject for fine in fines for subject in by_fine.get(fine, [])
                )
            )
            for name, fines in groups.items()
        }

    @classmethod
    def suite(cls, chat=False):
        """Build CMMLU tasks grouped by top-level category (plus ``"all"``).

        Args:
            chat: use the instruction-style prompt and zero-shot; otherwise
                5-shot with demonstrations from the ``dev`` split.

        Returns:
            ``defaultdict(list)`` mapping category name -> list of ``Task``.
        """
        prompt = partial(cls.prompt_cmmlu, chat=chat)
        tasks = defaultdict(list)
        for category, subjects in cls._category_subjects().items():
            tasks[category].extend(
                Task(
                    ("haonan-li/cmmlu", subject),
                    metric_name=("sustech/tlem", "cmmlu"),
                    input_column=cls.input_column,
                    label_column=cls.label_column,
                    prompt=prompt,
                    few_shot=0 if chat else 5,
                    few_shot_from="dev",
                )
                for subject in subjects
            )
        return tasks
class MMLU:
    """MMLU (English multiple-choice) benchmark: prompt format and task suites."""

    input_column = "prompt"
    label_column = "target"

    @classmethod
    def prompt_mmlu(cls, example, chat=False):
        """Render one MMLU row as a multiple-choice prompt.

        Reads ``example["input"]`` and the options ``example["A"]`` ..
        ``example["D"]`` and returns ``{"prompt": text}``. With ``chat=True``
        the prompt opens with an instruction; otherwise with ``"Question: "``.
        It always ends with ``"\\nAnswer:"``.
        """
        prefix = (
            "The following is a multiple-choice question. Please choose the most suitable one among A, B, C and D as the answer to this question.\n\n"
            if chat
            else "Question: "
        )
        options = "".join(f"\n{letter}. {example[letter]}" for letter in "ABCD")
        return {"prompt": prefix + example["input"] + options + "\nAnswer:"}

    # Subject name (used as the dataset config in `suite`) -> fine categories.
    subcategories = {
        "abstract_algebra": ["math"],
        "anatomy": ["health"],
        "astronomy": ["physics"],
        "business_ethics": ["business"],
        "clinical_knowledge": ["health"],
        "college_biology": ["biology"],
        "college_chemistry": ["chemistry"],
        "college_computer_science": ["computer science"],
        "college_mathematics": ["math"],
        "college_medicine": ["health"],
        "college_physics": ["physics"],
        "computer_security": ["computer science"],
        "conceptual_physics": ["physics"],
        "econometrics": ["economics"],
        "electrical_engineering": ["engineering"],
        "elementary_mathematics": ["math"],
        "formal_logic": ["philosophy"],
        "global_facts": ["other"],
        "high_school_biology": ["biology"],
        "high_school_chemistry": ["chemistry"],
        "high_school_computer_science": ["computer science"],
        "high_school_european_history": ["history"],
        "high_school_geography": ["geography"],
        "high_school_government_and_politics": ["politics"],
        "high_school_macroeconomics": ["economics"],
        "high_school_mathematics": ["math"],
        "high_school_microeconomics": ["economics"],
        "high_school_physics": ["physics"],
        "high_school_psychology": ["psychology"],
        "high_school_statistics": ["math"],
        "high_school_us_history": ["history"],
        "high_school_world_history": ["history"],
        "human_aging": ["health"],
        "human_sexuality": ["culture"],
        "international_law": ["law"],
        "jurisprudence": ["law"],
        "logical_fallacies": ["philosophy"],
        "machine_learning": ["computer science"],
        "management": ["business"],
        "marketing": ["business"],
        "medical_genetics": ["health"],
        "miscellaneous": ["other"],
        "moral_disputes": ["philosophy"],
        "moral_scenarios": ["philosophy"],
        "nutrition": ["health"],
        "philosophy": ["philosophy"],
        "prehistory": ["history"],
        "professional_accounting": ["other"],
        "professional_law": ["law"],
        "professional_medicine": ["health"],
        "professional_psychology": ["psychology"],
        "public_relations": ["politics"],
        "security_studies": ["politics"],
        "sociology": ["culture"],
        "us_foreign_policy": ["politics"],
        "virology": ["health"],
        "world_religions": ["philosophy"],
    }

    # Top-level category -> fine categories it aggregates.
    categories = {
        "STEM": [
            "physics",
            "chemistry",
            "biology",
            "computer science",
            "math",
            "engineering",
        ],
        "humanities": ["history", "philosophy", "law"],
        "social sciences": [
            "politics",
            "culture",
            "economics",
            "geography",
            "psychology",
        ],
        "other": ["other", "business", "health"],
    }

    @classmethod
    def _category_subjects(cls):
        """Map each top-level category, plus ``"all"``, to its subject names.

        Subjects are ordered by fine category, then by position in
        ``subcategories``. Each subject appears at most once per category.
        """
        by_fine = defaultdict(list)
        for subject, fine_categories in cls.subcategories.items():
            for fine in fine_categories:
                by_fine[fine].append(subject)
        # Build a new mapping instead of writing "all" into the shared class dict.
        groups = {**cls.categories, "all": sorted(by_fine)}
        return {
            name: list(
                dict.fromkeys(
                    subject for fine in fines for subject in by_fine.get(fine, [])
                )
            )
            for name, fines in groups.items()
        }

    @classmethod
    def suite(cls, chat=False):
        """Build MMLU tasks grouped by top-level category (plus ``"all"``).

        Args:
            chat: use the instruction-style prompt and zero-shot; otherwise
                5-shot with demonstrations from the ``validation`` split.

        Returns:
            ``defaultdict(list)`` mapping category name -> list of ``Task``.
        """
        prompt = partial(cls.prompt_mmlu, chat=chat)
        tasks = defaultdict(list)
        for category, subjects in cls._category_subjects().items():
            tasks[category].extend(
                Task(
                    ("lukaemon/mmlu", subject),
                    metric_name=("sustech/tlem", "mmlu"),
                    input_column=cls.input_column,
                    label_column=cls.label_column,
                    prompt=prompt,
                    few_shot=0 if chat else 5,
                    few_shot_from="validation",
                )
                for subject in subjects
            )
        return tasks