import logging
import warnings
from typing import List, Sequence

import pytorch_lightning as pl
import rich.syntax
import rich.tree
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.utilities import rank_zero_only


def get_logger(name=__name__) -> logging.Logger:
    """Initializes multi-GPU-friendly python command line logger.

    Args:
        name: Name passed to ``logging.getLogger``.

    Returns:
        logging.Logger: Logger whose level methods are wrapped with
        ``rank_zero_only``.
    """
    logger = logging.getLogger(name)

    # this ensures all logging levels get marked with the rank zero decorator
    # otherwise logs would get multiplied for each GPU process in multi-GPU setup
    for level in (
        "debug",
        "info",
        "warning",
        "error",
        "exception",
        "fatal",
        "critical",
    ):
        setattr(logger, level, rank_zero_only(getattr(logger, level)))

    return logger


log = get_logger(__name__)


def extras(config: DictConfig) -> None:
    """Applies optional utilities, controlled by config flags.

    Utilities:
    - Ignoring python warnings
    - Rich config printing

    Args:
        config (DictConfig): Configuration whose ``ignore_warnings`` and
            ``print_config`` flags are read with ``config.get``.
    """
    # disable python warnings if the "ignore_warnings" flag is truthy
    if config.get("ignore_warnings"):
        log.info("Disabling python warnings! ")
        warnings.filterwarnings("ignore")

    # pretty print config tree using Rich library if the "print_config" flag is truthy
    if config.get("print_config"):
        log.info("Printing config tree with Rich! ")
        print_config(config, resolve=True)


@rank_zero_only
def print_config(
    config: DictConfig,
    print_order: Sequence[str] = (
        "datamodule",
        "model",
        "callbacks",
        "logger",
        "trainer",
    ),
    resolve: bool = True,
) -> None:
    """Prints content of DictConfig using Rich library and its tree structure.

    The rendered tree is printed to stdout and also written to
    ``config_tree.log`` in the current working directory.

    Args:
        config (DictConfig): Configuration composed by Hydra.
        print_order (Sequence[str], optional): Determines in what order config
            components are printed. Fields missing from ``config`` are reported
            via the module logger and skipped.
        resolve (bool, optional): Whether to resolve reference fields of DictConfig.
    """
    style = "dim"
    tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)

    # requested fields come first, in the requested order
    queue = []
    for field in print_order:
        if field in config:
            queue.append(field)
        else:
            log.info(f"Field '{field}' not found in config")

    # then every remaining field, in the order the config yields them
    for field in config:
        if field not in queue:
            queue.append(field)

    for field in queue:
        branch = tree.add(field, style=style, guide_style=style)

        config_group = config[field]
        if isinstance(config_group, DictConfig):
            branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
        else:
            # non-DictConfig values are shown via their str() form
            branch_content = str(config_group)

        branch.add(rich.syntax.Syntax(branch_content, "yaml"))

    rich.print(tree)

    # explicit encoding so the written file does not depend on the platform default
    with open("config_tree.log", "w", encoding="utf-8") as file:
        rich.print(tree, file=file)


@rank_zero_only
def log_hyperparameters(
    config: DictConfig,
    model: pl.LightningModule,
    datamodule: pl.LightningDataModule,
    trainer: pl.Trainer,
    callbacks: List[pl.Callback],
    logger: List[pl.loggers.LightningLoggerBase],
) -> None:
    """Controls which config parts are saved by Lightning loggers.

    Additionally saves:
    - number of model parameters

    Does nothing (apart from a warning) when ``trainer.logger`` is ``None``.

    Args:
        config (DictConfig): Configuration; must contain "trainer", "model" and
            "datamodule", and may contain "seed" and "callbacks".
        model (pl.LightningModule): Model whose parameters are counted.
        datamodule (pl.LightningDataModule): Unused here; kept for a uniform signature.
        trainer (pl.Trainer): Trainer whose ``logger`` receives the hyperparameters.
        callbacks (List[pl.Callback]): Unused here; kept for a uniform signature.
        logger (List[pl.loggers.LightningLoggerBase]): Unused here; the
            hyperparameters are sent through ``trainer.logger``.
    """
    # without a logger on the trainer there is nowhere to send hparams to
    if trainer.logger is None:
        log.warning("Logger not found! Skipping hyperparameter logging...")
        return

    hparams = {}

    # choose which parts of hydra config are saved by loggers
    hparams["trainer"] = config["trainer"]
    hparams["model"] = config["model"]
    hparams["datamodule"] = config["datamodule"]
    if "seed" in config:
        hparams["seed"] = config["seed"]
    if "callbacks" in config:
        hparams["callbacks"] = config["callbacks"]

    # save number of model parameters
    hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
    hparams["model/params/trainable"] = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    hparams["model/params/non_trainable"] = sum(
        p.numel() for p in model.parameters() if not p.requires_grad
    )

    # send hparams to all loggers
    trainer.logger.log_hyperparams(hparams)


def finish(
    config: DictConfig,
    model: pl.LightningModule,
    datamodule: pl.LightningDataModule,
    trainer: pl.Trainer,
    callbacks: List[pl.Callback],
    logger: List[pl.loggers.LightningLoggerBase],
) -> None:
    """Makes sure everything closed properly.

    Args:
        config (DictConfig): Unused here; kept for a uniform signature.
        model (pl.LightningModule): Unused here; kept for a uniform signature.
        datamodule (pl.LightningDataModule): Unused here; kept for a uniform signature.
        trainer (pl.Trainer): Unused here; kept for a uniform signature.
        callbacks (List[pl.Callback]): Unused here; kept for a uniform signature.
        logger (List[pl.loggers.LightningLoggerBase]): Loggers to inspect;
            ``wandb.finish()`` is called for each ``WandbLogger`` found.
    """
    # without this sweeps with wandb logger might crash!
    for lg in logger:
        if isinstance(lg, pl.loggers.wandb.WandbLogger):
            # imported lazily so wandb is only required when a WandbLogger is in use
            import wandb

            wandb.finish()