DGEB / scripts /plot_benchmarks.py
Joshua Kravitz
Initial commit
e284167
"""
Given a directory of results, plot the benchmarks for each task as a bar chart and line chart.
"""
import argparse
import os
from typing import Optional
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from dgeb import TaskResult, get_all_tasks, get_output_folder, get_tasks_by_name
# Ids of every task reported by get_all_tasks(); interpolated into the
# --tasks help text of the command line interface.
ALL_TASKS = [
    registered_task.metadata.id
    for registered_task in get_all_tasks()
]
def plot_benchmarks(
    results_dir,
    task_ids: Optional[list[str]] = None,
    output="benchmarks.png",
    model_substring=None,
):
    """Plot per-task benchmark scores for every model found in ``results_dir``.

    One column of subplots is drawn per task. The top row is a bar chart of
    each model's best primary-metric score; the bottom row is a line chart of
    that score against the model's parameter count. The figure is saved to
    ``output``.

    Args:
        results_dir: Directory whose entries are treated as model names.
        task_ids: Names passed to ``get_tasks_by_name``; ``None`` selects
            every task returned by ``get_all_tasks``.
        output: Path the figure is written to.
        model_substring: Optional iterable of substrings. A model is kept
            when its name contains at least one of them; ``None`` keeps all.

    Raises:
        ValueError: If no result file is found for the selected models and
            tasks, or a result file holds no value for its primary metric.
    """
    # Materialise once: the task collection is walked again for every model.
    tasks = list(get_all_tasks() if task_ids is None else get_tasks_by_name(task_ids))

    all_results = []
    # os.listdir returns entries in arbitrary order; sorting keeps the plot
    # reproducible from run to run.
    for model_name in sorted(os.listdir(results_dir)):
        if model_substring is not None and not any(
            substr in model_name for substr in model_substring
        ):
            continue
        for task in tasks:
            if task.metadata.display_name == "NoOp Task":
                continue
            filepath = get_output_folder(model_name, task, results_dir, create=False)
            # Not every model has been evaluated on every task.
            if not os.path.exists(filepath):
                continue
            with open(filepath, encoding="utf-8") as f:
                task_result = TaskResult.model_validate_json(f.read())
            primary_metric_id = task_result.task.primary_metric_id
            main_scores = [
                metric.value
                for layer_result in task_result.results
                for metric in layer_result.metrics
                if metric.id == primary_metric_id
            ]
            if not main_scores:
                # max() would raise a ValueError without saying which file
                # is at fault; keep the exception type but add the context.
                raise ValueError(
                    f"No value for primary metric {primary_metric_id!r} in {filepath}"
                )
            all_results.append(
                {
                    "task": task.metadata.display_name,
                    "model": model_name,
                    "num_params": task_result.model["num_params"],
                    # A result holds one entry per layer; report the best one.
                    "score": max(main_scores),
                }
            )

    if not all_results:
        # Without rows the DataFrame has no columns and a 2x0 subplot grid
        # cannot be created, so fail early with an actionable message.
        raise ValueError(
            f"No benchmark results found in {results_dir!r} "
            "for the requested tasks and models"
        )

    # First-appearance order gives a deterministic column layout, unlike
    # iterating a set of task names.
    task_names = list(dict.fromkeys(row["task"] for row in all_results))

    results_df = pd.DataFrame(all_results)
    # order the models by ascending number of parameters; the stable sort
    # keeps models with equal sizes in alphabetical order
    results_df["num_params"] = results_df["num_params"].astype(int)
    results_df = results_df.sort_values(by="num_params", kind="stable")

    n_tasks = len(task_names)
    # squeeze=False always yields a 2-D axes array, so a single task needs
    # no special casing.
    fig, axes = plt.subplots(
        2, n_tasks, figsize=(5 * n_tasks, 10), squeeze=False
    )
    try:
        for col, task_name in enumerate(task_names):
            task_df = results_df[results_df["task"] == task_name]
            bar_ax = axes[0][col]
            line_ax = axes[1][col]

            sns.barplot(x="model", y="score", data=task_df, ax=bar_ax)
            bar_ax.set_title(task_name)
            # rotate the x axis labels so long model names do not overlap
            for tick in bar_ax.get_xticklabels():
                tick.set_rotation(90)

            # second row: score as a function of model size
            sns.lineplot(x="num_params", y="score", data=task_df, ax=line_ax)
            line_ax.set_title(task_name)
            line_ax.set_xlabel("Number of parameters")

        fig.tight_layout()
        fig.savefig(output)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
if __name__ == "__main__":

    def _split_on_commas(raw):
        """Turn one comma separated command line value into a list of its pieces."""
        return raw.split(",")

    # Command line entry point: collect the options, then draw the figure.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "-d",
        "--results_dir",
        default="results",
        type=str,
        help="Directory containing the results of the benchmarking",
    )
    cli.add_argument(
        "-t",
        "--tasks",
        default=None,
        type=_split_on_commas,
        help=f"Comma separated list of tasks to plot. Choose from {ALL_TASKS} or do not specify to plot all tasks. ",
    )
    cli.add_argument(
        "-o",
        "--output",
        default="benchmarks.png",
        type=str,
        help="Output file for the plot",
    )
    cli.add_argument(
        "--model_substring",
        default=None,
        type=_split_on_commas,
        help="Comma separated list of model substrings. Only plot results for models containing this substring",
    )

    options = cli.parse_args()
    plot_benchmarks(
        results_dir=options.results_dir,
        task_ids=options.tasks,
        output=options.output,
        model_substring=options.model_substring,
    )