""" | |
Given a directory of results, plot the benchmarks for each task as a bar chart and line chart. | |
""" | |

import argparse
import os
from typing import Optional

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from dgeb import TaskResult, get_all_tasks, get_output_folder, get_tasks_by_name
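
# All registered task ids; used below to list the valid choices in the --tasks help text.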
ALL_TASKS = [task.metadata.id for task in get_all_tasks()]


def plot_benchmarks(
    results_dir,
    task_ids: Optional[list[str]] = None,
    output="benchmarks.png",
    model_substring=None,
):
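    """Plot benchmark scores for every model/task pair found in results_dir.

    Args:
        results_dir: directory of saved benchmark results, one entry per model.
        task_ids: optional list of task ids to plot; all tasks are plotted when None.
        output: path of the image file to write.
        model_substring: optional list of substrings; only models whose name
            contains at least one of them are plotted.
    """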
    models = os.listdir(results_dir)
    all_results = []
    tasks = get_all_tasks() if task_ids is None else get_tasks_by_name(task_ids)
    for model_name in models:
        if model_substring is not None and all(
            substr not in model_name for substr in model_substring
        ):
            continue
        for task in tasks:
            if task.metadata.display_name == "NoOp Task":
                continue
            filepath = get_output_folder(model_name, task, results_dir, create=False)
            # if the file does not exist, skip
            if not os.path.exists(filepath):
                continue
            with open(filepath) as f:
                task_result = TaskResult.model_validate_json(f.read())
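            # The saved JSON deserializes into a TaskResult carrying the model
            # metadata (e.g. parameter count) and per-layer metric results.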
            num_params = task_result.model["num_params"]
            primary_metric_id = task_result.task.primary_metric_id
            main_scores = [
                metric.value
                for layer_result in task_result.results
                for metric in layer_result.metrics
                if metric.id == primary_metric_id
            ]
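            # Keep the best primary-metric score across embedding layers for this
            # model/task pair.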
            best_score = max(main_scores)
            all_results.append(
                {
                    "task": task.metadata.display_name,
                    "model": model_name,
                    "num_params": num_params,
                    "score": best_score,
                }
            )

    results_df = pd.DataFrame(all_results)
    # order the models by ascending number of parameters
    results_df["num_params"] = results_df["num_params"].astype(int)
    results_df = results_df.sort_values(by="num_params")
    # unique task names and their count
    task_names = set(results_df["task"])
    n_tasks = len(task_names)
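
    # One column of panels per task: bar charts of per-model scores in the top
    # row, score vs. parameter count in the bottom row.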
    _, ax = plt.subplots(2, n_tasks, figsize=(5 * n_tasks, 10))
    for i, task in enumerate(task_names):
        if n_tasks > 1:
            sns.barplot(
                x="model",
                y="score",
                data=results_df[results_df["task"] == task],
                ax=ax[0][i],
            )
            ax[0][i].set_title(task)
            # rotate the x axis labels
            for tick in ax[0][i].get_xticklabels():
                tick.set_rotation(90)
        else:
            sns.barplot(
                x="model",
                y="score",
                data=results_df[results_df["task"] == task],
                ax=ax[0],
            )
            ax[0].set_title(task)
            # rotate the x axis labels
            for tick in ax[0].get_xticklabels():
                tick.set_rotation(90)

    # make a line graph with number of parameters on x axis for each task in the second row of figures
    for i, task in enumerate(task_names):
        if n_tasks > 1:
            sns.lineplot(
                x="num_params",
                y="score",
                data=results_df[results_df["task"] == task],
                ax=ax[1][i],
            )
            ax[1][i].set_title(task)
            ax[1][i].set_xlabel("Number of parameters")
        else:
            sns.lineplot(
                x="num_params",
                y="score",
                data=results_df[results_df["task"] == task],
                ax=ax[1],
            )
            ax[1].set_title(task)
            ax[1].set_xlabel("Number of parameters")

    plt.tight_layout()
    plt.savefig(output)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-d",
        "--results_dir",
        type=str,
        default="results",
        help="Directory containing the results of the benchmarking",
    )
    parser.add_argument(
        "-t",
        "--tasks",
        type=lambda s: s.split(","),
        default=None,
        help=f"Comma separated list of tasks to plot. Choose from {ALL_TASKS} or do not specify to plot all tasks.",
    )
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        default="benchmarks.png",
        help="Output file for the plot",
    )
    parser.add_argument(
        "--model_substring",
        type=lambda s: s.split(","),
        default=None,
        help="Comma separated list of model substrings. Only plot results for models whose name contains at least one of these substrings",
    )
    args = parser.parse_args()
    plot_benchmarks(args.results_dir, args.tasks, args.output, args.model_substring)