# tlem / tlem.py
# Zhuoyang Song — commit 272edd2: "FIX: fix conflicts"
# (file-viewer residue: raw · history · blame · 4.84 kB)
# %%
try:
from ipytorch import logging
except Exception as e:
import logging
from typing import Any, Optional, Protocol, Iterable, Callable
from numpy.lib import extract
from tqdm.auto import tqdm
from evaluate.evaluation_suite import EvaluationSuite
import evaluate
import numpy as np
import datasets
import pandas as pd
from .tasks import *
from .utils import is_equiv
class ReasoningMetric(evaluate.Metric):
    """Exact-match accuracy metric for reasoning benchmarks.

    Raw model responses and raw reference answers are first passed through an
    answer-extraction function looked up on ``Metrics`` by this metric's
    ``config_name``. Accuracy is the fraction of rows whose extracted response
    equals the extracted reference.
    """

    def _info(self):
        """Describe the module: one string response and one string reference per example."""
        features = datasets.Features(
            {
                "responses": datasets.Value("string"),
                "references": datasets.Value("string"),
            }
        )
        # TODO: fill in the description/citation/homepage placeholders below.
        return evaluate.EvaluationModuleInfo(
            description="",
            citation="",
            inputs_description="",
            # Format of each prediction and reference.
            features=features,
            homepage="http://module.homepage",
            codebase_urls=["http://github.com/path/to/codebase/of/new_module"],
            reference_urls=["http://path.to.reference.url/new_module"],
        )

    def _compute(self, responses, references, verbose=False):
        """Score ``responses`` against ``references`` by exact match after extraction.

        Args:
            responses: Raw model outputs, one string per example.
            references: Raw gold answers, aligned element-wise with ``responses``.
            verbose: When true, print the per-example DataFrame and also return
                it under the ``"df"`` key.

        Returns:
            A dict ``{"Accuracy": mean_of_matches}`` (NaN when there are no
            rows), plus ``"df"`` — the per-example DataFrame — when ``verbose``.

        Raises:
            AttributeError: if ``Metrics`` has no extractor named ``config_name``.
        """
        # The config name selects which extractor on ``Metrics`` is applied.
        extract = getattr(Metrics, self.config_name)
        extract_responses, extract_references = extract(responses, references)

        df = pd.DataFrame({"responses": responses, "references": references})
        df["extract_responses"] = extract_responses
        df["extract_references"] = extract_references
        if verbose:
            # Previously printed unconditionally, dumping every row on every
            # call; only show the table when the caller asks for detail.
            print(df)

        results = {
            "Accuracy": (df["extract_references"] == df["extract_responses"])
            .astype(int)
            .mean(),
        }
        logging.info(results)
        if verbose:
            results["df"] = df
        return results
class Suite(EvaluationSuite):
task_class = Task
def run(
self,
model_or_pipeline: Any,
) -> dict[str, float]:
self.assert_suite_nonempty()
def run_tasks(tasks):
for task in (bar := tqdm(tasks, leave=False)):
bar.desc = f"complete {task.name}."
if task.name not in self.cached_result:
self.cached_result[task.name] = task.run(model_or_pipeline)
results = [self.cached_result[task.name] for task in tasks]
return pd.DataFrame(results).mean().to_dict()
if isinstance(self.suite, dict):
for category, tasks in (bar := tqdm(self.suite.items())):
bar.desc = f"complete {category}."
logging.warning(f"Combined results {category}: {run_tasks(tasks)}")
else:
logging.warning(f"Combined results: {run_tasks(self.suite)}")
return self.cached_result
def add(self, name):
self.load(name)
def load(self, name):
chat = False
match name:
case _ if "chat" in name:
chat = True
match name:
case _ if name.startswith("mmlu"):
suite = MMLU.suite(chat=chat)
case _ if name.startswith("cmmlu"):
suite = CMMLU.suite(chat=chat)
case "gsm8k":
suite = Task(
dataset_name=("gsm8k", "main"),
metric_name=("sustech/tlem", "gsm8k"),
input_column="question",
label_column="answer",
)
case "bbh":
suite = BBH.suite()
case "arc":
suite = ARC.suite()
case "hellaswag":
suite = HellaSwag.suite()
case "drop":
suite = DROP.suite()
case "winogrande":
suite = Winogrande.suite()
case _ if name.startswith("ceval"):
suite = CEVAL.suite(chat=chat)
case "mt_bench":
suite = Task(
dataset_name="SUSTech/mt_bench_judge",
split="train",
prompt=mt_bench_prompt
# metric_name=("sustech/tlem", "gsm8k"),
)
match name:
case _ if "test" in name:
suite = suite["Test"]
self.suite = suite
def __init__(self, name="tlem"):
super().__init__(name)
self.cached_result = {}
self.suite = []