# Spaces: Runtime error
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import csv
from collections import defaultdict
from dataclasses import dataclass, field
from typing import List, Optional

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import ScalarFormatter

from transformers import HfArgumentParser
def list_field(default=None, metadata=None):
    """Create a dataclass ``field`` whose default value is produced by a factory.

    ``dataclasses`` rejects mutable defaults such as lists, so the default is supplied
    through ``default_factory`` instead. The factory returns a shallow copy of
    ``default`` on every call, so dataclass instances never share (and accidentally
    mutate) one list object.

    Args:
        default: Default value for the field (typically a list, or ``None``).
        metadata: Optional mapping attached to the field, e.g. ``{"help": "..."}``.

    Returns:
        A ``dataclasses.Field`` to be used as a dataclass attribute default.
    """
    return field(default_factory=lambda: copy.copy(default), metadata=metadata)
@dataclass
class PlotArguments:
    """
    Command-line arguments controlling how benchmark results stored in a CSV file are plotted.

    This must be a dataclass: it is passed to ``HfArgumentParser``, which builds its
    command-line options from the dataclass fields.
    """

    csv_file: str = field(
        metadata={"help": "The csv file to plot."},
    )
    plot_along_batch: bool = field(
        default=False,
        metadata={"help": "Whether to plot along batch size or sequence length. Defaults to sequence length."},
    )
    is_time: bool = field(
        default=False,
        metadata={"help": "Whether the csv file has time results or memory results. Defaults to memory results."},
    )
    no_log_scale: bool = field(
        default=False,
        metadata={"help": "Disable logarithmic scale when plotting"},
    )
    is_train: bool = field(
        default=False,
        metadata={
            "help": "Whether the csv file has training results or inference results. Defaults to inference results."
        },
    )
    # When None, Plot.plot() calls plt.show() instead of saving a file.
    figure_png_file: Optional[str] = field(
        default=None,
        metadata={"help": "Filename under which the plot will be saved. If unused no plot is saved."},
    )
    # Looked up by position: the i-th name replaces the i-th model encountered in the CSV.
    short_model_names: Optional[List[str]] = list_field(
        default=None, metadata={"help": "List of model names that are used instead of the ones in the csv file."}
    )
def can_convert_to_int(string):
    """Return True if ``int(string)`` succeeds, otherwise False.

    ``TypeError`` is caught in addition to ``ValueError`` because ``csv.DictReader``
    fills the missing columns of short rows with ``None``, and ``int(None)`` raises
    ``TypeError`` rather than ``ValueError``.
    """
    try:
        int(string)
    except (TypeError, ValueError):
        return False
    return True
def can_convert_to_float(string):
    """Return True if ``float(string)`` succeeds, otherwise False.

    ``TypeError`` is caught in addition to ``ValueError`` because ``csv.DictReader``
    fills the missing columns of short rows with ``None``, and ``float(None)`` raises
    ``TypeError`` rather than ``ValueError``.
    """
    try:
        float(string)
    except (TypeError, ValueError):
        return False
    return True
class Plot:
    """Load benchmark results from a CSV file and plot them with matplotlib."""

    def __init__(self, args):
        """Read ``args.csv_file`` and group its rows by model name.

        Each row must provide ``model``, ``batch_size``, ``sequence_length`` and ``result``
        columns. For every model, ``self.result_dict[model]`` holds:

        * ``"bsz"``: the batch sizes in file order (may contain duplicates),
        * ``"seq_len"``: the sequence lengths in file order (may contain duplicates),
        * ``"result"``: a ``{(batch_size, sequence_length): value}`` mapping that only
          contains rows whose ``result`` parses as an ``int`` or a ``float``.

        Args:
            args: A ``PlotArguments`` instance (or any object with the same attributes).
        """
        self.args = args
        self.result_dict = defaultdict(lambda: {"bsz": [], "seq_len": [], "result": {}})

        with open(self.args.csv_file, newline="") as csv_file:
            reader = csv.DictReader(csv_file)
            for row in reader:
                model_name = row["model"]
                batch_size = int(row["batch_size"])
                sequence_length = int(row["sequence_length"])

                entry = self.result_dict[model_name]
                entry["bsz"].append(batch_size)
                entry["seq_len"].append(sequence_length)

                # A non-numeric result gets no entry in "result", so that point is not plotted.
                raw_result = row["result"]
                if can_convert_to_int(raw_result):
                    entry["result"][(batch_size, sequence_length)] = int(raw_result)
                elif can_convert_to_float(raw_result):
                    entry["result"][(batch_size, sequence_length)] = float(raw_result)

    def plot(self):
        """Plot the results of every model, then save the figure or show it.

        With ``args.plot_along_batch`` the x axis is the batch size and one series is drawn
        per sequence length. Otherwise the x axis is the sequence length and one series is
        drawn per batch size. Each series is a scatter plot joined by a dashed line.
        (batch_size, sequence_length) combinations without a numeric result are skipped.
        The figure is written to ``args.figure_png_file`` when that is set; otherwise
        ``plt.show()`` is called.
        """
        _, ax = plt.subplots()
        title_str = "Time usage" if self.args.is_time else "Memory usage"
        title_str = title_str + " for training" if self.args.is_train else title_str + " for inference"

        if not self.args.no_log_scale:
            # set logarithm scales
            ax.set_xscale("log")
            ax.set_yscale("log")

        # Use ScalarFormatter so tick labels are printed as plain numbers.
        for axis in [ax.xaxis, ax.yaxis]:
            axis.set_major_formatter(ScalarFormatter())

        # The labels depend only on the plotting direction. Computing them up front also keeps
        # them defined when the CSV contains no rows.
        x_axis_label, inner_loop_label = (
            ("batch_size", "len") if self.args.plot_along_batch else ("in #tokens", "bsz")
        )

        model_labels = []
        for model_name_idx, model_name in enumerate(self.result_dict.keys()):
            batch_sizes = sorted(set(self.result_dict[model_name]["bsz"]))
            sequence_lengths = sorted(set(self.result_dict[model_name]["seq_len"]))
            results = self.result_dict[model_name]["result"]

            x_axis_values, inner_loop_values = (
                (batch_sizes, sequence_lengths) if self.args.plot_along_batch else (sequence_lengths, batch_sizes)
            )

            label_model_name = (
                model_name if self.args.short_model_names is None else self.args.short_model_names[model_name_idx]
            )
            model_labels.append(label_model_name)

            for inner_loop_value in inner_loop_values:
                # Collect x and y together so each y value stays on its own x value, even when a
                # combination in the middle of the range has no result. x_axis_values itself is
                # never modified, so every series is built from the full x range.
                xs, ys = [], []
                for x in x_axis_values:
                    key = (x, inner_loop_value) if self.args.plot_along_batch else (inner_loop_value, x)
                    if key in results:
                        xs.append(x)
                        ys.append(results[key])

                x_axis_array = np.asarray(xs, dtype=int)
                # Float dtype in both directions: time results are fractional seconds, and an
                # integer dtype would truncate them.
                y_axis_array = np.asarray(ys, dtype=np.float32)

                plt.scatter(
                    x_axis_array, y_axis_array, label=f"{label_model_name} - {inner_loop_label}: {inner_loop_value}"
                )
                plt.plot(x_axis_array, y_axis_array, "--")

        # Produces e.g. "... for inference A vs. B" without stripping characters from the
        # base title when no model was found.
        if model_labels:
            title_str += " " + " vs. ".join(model_labels)

        y_axis_label = "Time in s" if self.args.is_time else "Memory in MB"

        # plot
        plt.title(title_str)
        plt.xlabel(x_axis_label)
        plt.ylabel(y_axis_label)
        plt.legend()

        if self.args.figure_png_file is not None:
            plt.savefig(self.args.figure_png_file)
        else:
            plt.show()
def main():
    """Script entry point: parse ``PlotArguments`` from the command line and draw the plot."""
    arg_parser = HfArgumentParser(PlotArguments)
    parsed_dataclasses = arg_parser.parse_args_into_dataclasses()
    plotter = Plot(args=parsed_dataclasses[0])
    plotter.plot()


if __name__ == "__main__":
    main()