import logging |
import warnings |
from typing import List, Sequence |
import pytorch_lightning as pl |
import rich.syntax |
import rich.tree |
from omegaconf import DictConfig, OmegaConf |
from pytorch_lightning.utilities import rank_zero_only |
def get_logger(name=__name__) -> logging.Logger:
    """Return the named python logger with its level methods wrapped by ``rank_zero_only``.

    The methods ``debug``, ``info``, ``warning``, ``error``, ``exception``, ``fatal`` and
    ``critical`` are replaced on the logger object by ``rank_zero_only``-wrapped versions,
    which makes the logger multi-GPU-friendly (messages are emitted by rank zero only).

    Note: ``logging.getLogger`` returns the same object for the same name, so calling this
    twice with one name wraps the already-wrapped methods again.

    Args:
        name: Logger name passed to ``logging.getLogger``.

    Returns:
        The (mutated) ``logging.Logger`` instance.
    """
    method_names = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
    result = logging.getLogger(name)
    for method_name in method_names:
        plain_method = getattr(result, method_name)
        setattr(result, method_name, rank_zero_only(plain_method))
    return result
# Module-level logger used by the helpers in this file; created through ``get_logger`` so its
# logging methods are wrapped with ``rank_zero_only``.
log = get_logger(name=__name__)
def extras(config: DictConfig) -> None:
    """Run optional start-up utilities switched on by flags in ``config``.

    Flags:
    - ``ignore_warnings``: when truthy, install a global ``"ignore"`` filter for python warnings.
    - ``print_config``: when truthy, print the config as a Rich tree with interpolations resolved.

    Args:
        config (DictConfig): Configuration composed by Hydra.
    """
    silence_warnings = config.get("ignore_warnings")
    if silence_warnings:
        log.info("Disabling python warnings! <config.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    show_tree = config.get("print_config")
    if show_tree:
        log.info("Printing config tree with Rich! <config.print_config=True>")
        print_config(config, resolve=True)
@rank_zero_only
def print_config(
    config: DictConfig,
    print_order: Sequence[str] = (
        "datamodule",
        "model",
        "callbacks",
        "logger",
        "trainer",
    ),
    resolve: bool = True,
) -> None:
    """Prints content of DictConfig using Rich library and its tree structure.

    Top-level fields named in ``print_order`` are printed first, in that order. Any names
    missing from ``config`` are reported via ``log.info`` and skipped. All remaining
    top-level fields follow in the config's own order. ``DictConfig`` sections are rendered
    as YAML; any other value is rendered with ``str``.

    Besides printing to the console, the same tree is written to ``config_tree.log`` in the
    current working directory (overwriting it).

    Args:
        config (DictConfig): Configuration composed by Hydra.
        print_order (Sequence[str], optional): Determines in what order config components are printed.
        resolve (bool, optional): Whether to resolve reference fields of DictConfig.
    """
    style = "dim"
    tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)

    # Requested fields first; a missing one is reported rather than treated as an error.
    queue = []
    for field in print_order:
        if field in config:
            queue.append(field)
        else:
            log.info(f"Field '{field}' not found in config")

    # Then every other top-level field, keeping the config's own order.
    already_queued = set(queue)
    queue += [field for field in config if field not in already_queued]

    for field in queue:
        branch = tree.add(field, style=style, guide_style=style)
        config_group = config[field]
        if isinstance(config_group, DictConfig):
            branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
        else:
            branch_content = str(config_group)
        branch.add(rich.syntax.Syntax(branch_content, "yaml"))

    rich.print(tree)

    # Explicit UTF-8: config values may contain non-ASCII text that the platform's default
    # (locale) encoding cannot represent, which would raise UnicodeEncodeError on write.
    with open("config_tree.log", "w", encoding="utf-8") as file:
        rich.print(tree, file=file)
@rank_zero_only
def log_hyperparameters(
    config: DictConfig,
    model: pl.LightningModule,
    datamodule: pl.LightningDataModule,
    trainer: pl.Trainer,
    callbacks: List[pl.Callback],
    logger: List[pl.loggers.LightningLoggerBase],
) -> None:
    """Controls which config parts are saved by Lightning loggers.

    Saves the ``trainer``, ``model`` and ``datamodule`` config sections, plus ``seed`` and
    ``callbacks`` when they are present in ``config``.

    Additionally saves:
    - number of model parameters (total, trainable and non-trainable)

    If the trainer has no logger attached (``trainer.logger is None``), a warning is logged
    and nothing is saved.

    Args:
        config (DictConfig): Configuration composed by Hydra.
        model (pl.LightningModule): Model whose parameters are counted.
        datamodule (pl.LightningDataModule): Not used here.
        trainer (pl.Trainer): Trainer whose ``logger`` receives the hyperparameters.
        callbacks (List[pl.Callback]): Not used here.
        logger (List[pl.loggers.LightningLoggerBase]): Not used here.
    """
    # Without this guard, a trainer created with logger=False would make the final
    # ``trainer.logger.log_hyperparams`` call fail with AttributeError on ``None``.
    if trainer.logger is None:
        log.warning("Logger not found! Skipping hyperparameter logging...")
        return

    hparams = {
        "trainer": config["trainer"],
        "model": config["model"],
        "datamodule": config["datamodule"],
    }
    if "seed" in config:
        hparams["seed"] = config["seed"]
    if "callbacks" in config:
        hparams["callbacks"] = config["callbacks"]

    # Single pass over the parameters; total is by construction trainable + non_trainable.
    trainable = 0
    non_trainable = 0
    for p in model.parameters():
        if p.requires_grad:
            trainable += p.numel()
        else:
            non_trainable += p.numel()
    hparams["model/params/total"] = trainable + non_trainable
    hparams["model/params/trainable"] = trainable
    hparams["model/params/non_trainable"] = non_trainable

    trainer.logger.log_hyperparams(hparams)
def finish(
    config: DictConfig,
    model: pl.LightningModule,
    datamodule: pl.LightningDataModule,
    trainer: pl.Trainer,
    callbacks: List[pl.Callback],
    logger: List[pl.loggers.LightningLoggerBase],
) -> None:
    """Makes sure everything closed properly.

    The only cleanup performed is for Weights & Biases: ``wandb.finish()`` is called once
    for each ``WandbLogger`` found in ``logger``. The remaining arguments are not used here.
    """
    for lightning_logger in logger:
        if not isinstance(lightning_logger, pl.loggers.wandb.WandbLogger):
            continue
        # Imported lazily, so ``wandb`` is only needed when a WandbLogger is actually in use.
        import wandb

        wandb.finish()