Spaces:
Build error
Build error
import logging | |
import os | |
from pathlib import Path | |
import numpy as np | |
import pytorch_lightning as pl | |
import torch | |
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint | |
from pytorch_lightning.utilities import rank_zero_only | |
from seq2seq_utils import save_json | |
def count_trainable_parameters(model):
    """Return the number of trainable scalar parameters in ``model``.

    Args:
        model: Any object exposing ``parameters()`` that yields tensors with
            a ``requires_grad`` flag and a ``numel()`` method
            (e.g. a ``torch.nn.Module``).

    Returns:
        int: Total element count over all parameters whose ``requires_grad``
        is true; ``0`` when there are none.
    """
    # ``numel()`` yields a plain int per tensor (1 for 0-dim tensors), so the
    # total stays a Python int instead of a NumPy scalar or float.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
# Module-level logger, named after this module so its output can be filtered
# or redirected through the standard logging configuration.
logger = logging.getLogger(__name__)
class Seq2SeqLoggingCallback(pl.Callback):
    """Lightning callback that logs learning rates, parameter counts and evaluation results.

    NOTE(review): ``rank_zero_only`` is imported by this file but not applied
    to any hook here; confirm whether the file-writing hooks should be
    restricted to rank zero in multi-process runs.
    """

    def on_batch_end(self, trainer, pl_module):
        """Log the learning rate of every parameter group of the first optimizer."""
        lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)}
        pl_module.logger.log_metrics(lrs)

    def _write_logs(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True
    ) -> None:
        """Send the callback metrics to the trainer's logger and append them to a results file.

        Args:
            trainer: Trainer whose ``callback_metrics`` and ``global_step`` are read.
            pl_module: Module whose ``hparams.output_dir`` is the root for the output files.
            type_path: Split name; ``"test"`` writes to fixed file names, anything
                else writes to per-step files in ``{type_path}_results/`` and
                ``{type_path}_generations/``.
            save_generations: When true and the metrics contain ``"preds"``, the
                predictions are written one per line to the generations file.
        """
        logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****")
        metrics = trainer.callback_metrics
        # Entries that are not individual metric values and must not be logged as such.
        skipped_keys = ("log", "progress_bar", "preds")
        trainer.logger.log_metrics({k: v for k, v in metrics.items() if k not in skipped_keys})
        # Log results
        od = Path(pl_module.hparams.output_dir)
        if type_path == "test":
            results_file = od / "test_results.txt"
            generations_file = od / "test_generations.txt"
        else:
            # this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json
            # If people want this it will be easy enough to add back.
            results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt"
            generations_file = od / f"{type_path}_generations/{trainer.global_step:05d}.txt"
        # parents=True so a missing output_dir (or nested split directory) is created
        # instead of raising FileNotFoundError.
        results_file.parent.mkdir(parents=True, exist_ok=True)
        generations_file.parent.mkdir(parents=True, exist_ok=True)
        with open(results_file, "a+") as writer:
            for key in sorted(metrics):
                if key in skipped_keys:
                    continue
                val = metrics[key]
                if isinstance(val, torch.Tensor):
                    # Unwrap to a Python number so the float format spec applies.
                    val = val.item()
                writer.write(f"{key}: {val:.6f}\n")

        if not save_generations:
            return

        if "preds" in metrics:
            # assumes metrics["preds"] is an iterable of strings -- TODO confirm
            content = "\n".join(metrics["preds"])
            # Context manager closes the handle deterministically; the previous
            # chained ``open(...).write(...)`` left it to the garbage collector.
            with generations_file.open("w+") as gen_writer:
                gen_writer.write(content)

    def on_train_start(self, trainer, pl_module):
        """Log total and trainable parameter counts once at the start of training."""
        try:
            npars = pl_module.model.model.num_parameters()
        except AttributeError:
            # Fall back when the model is not wrapped in an extra ``.model`` attribute.
            npars = pl_module.model.num_parameters()

        n_trainable_pars = count_trainable_parameters(pl_module)
        # mp stands for million parameters
        trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6})

    def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        """Save the module's metrics as JSON, then write the test results and generations."""
        save_json(pl_module.metrics, pl_module.metrics_save_path)
        return self._write_logs(trainer, pl_module, "test")

    def on_validation_end(self, trainer: pl.Trainer, pl_module):
        """Save the module's metrics as JSON after each validation run."""
        save_json(pl_module.metrics, pl_module.metrics_save_path)
        # Uncommenting this will save val generations
        # return self._write_logs(trainer, pl_module, "valid")
def get_checkpoint_callback(output_dir, metric, save_top_k=1, lower_is_better=False):
    """Build a ``ModelCheckpoint`` that keeps the best models by a validation metric.

    Args:
        output_dir: Directory the checkpoint file names are joined onto.
        metric: One of ``"rouge2"``, ``"bleu"`` or ``"loss"``. The callback
            monitors ``val_{metric}`` and names files after ``val_avg_{metric}``.
        save_top_k: Forwarded to ``ModelCheckpoint`` as ``save_top_k``.
        lower_is_better: Force ``mode="min"``. ``"loss"`` always uses ``"min"``
            regardless of this flag.

    Returns:
        The configured ``ModelCheckpoint`` instance.

    Raises:
        NotImplementedError: If ``metric`` is not one of the supported names.
    """
    supported_metrics = ("rouge2", "bleu", "loss")
    if metric not in supported_metrics:
        raise NotImplementedError(
            f"seq2seq callbacks only support rouge2, bleu and loss, got {metric}, You can make your own by adding to this function."
        )
    # Doubled braces emit literal ``{...}`` placeholders for the checkpoint
    # callback to fill in, e.g. "{val_avg_rouge2:.4f}-{step_count}".
    exp = f"{{val_avg_{metric}:.4f}}-{{step_count}}"

    # NOTE(review): ``filepath`` is the keyword used by the original code; newer
    # pytorch_lightning releases may expect ``dirpath``/``filename`` -- confirm
    # against the installed version.
    checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join(output_dir, exp),
        monitor=f"val_{metric}",
        # Honour the caller's flag, which was previously accepted but ignored.
        mode="min" if (lower_is_better or "loss" in metric) else "max",
        save_top_k=save_top_k,
    )
    return checkpoint_callback
def get_early_stopping_callback(metric, patience):
    """Build an ``EarlyStopping`` callback watching ``val_{metric}``.

    The monitored value is minimised when ``metric`` contains ``"loss"`` and
    maximised otherwise; ``patience`` is forwarded unchanged and verbose
    output is always enabled.
    """
    monitored = f"val_{metric}"  # does this need avg?
    direction = "max"
    if "loss" in metric:
        direction = "min"
    return EarlyStopping(monitor=monitored, mode=direction, patience=patience, verbose=True)