Hannes Kuchelmeister
cleanup to make ready as submodule
c7be723
raw
history blame
4.51 kB
import logging
import warnings
from typing import List, Sequence
import pytorch_lightning as pl
import rich.syntax
import rich.tree
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.utilities import rank_zero_only
def get_logger(name=__name__) -> logging.Logger:
    """Return a command line logger that is safe to use in multi-GPU runs.

    Every level method of the logger is wrapped with ``rank_zero_only``;
    without the wrapping each GPU process in a multi-GPU setup would emit
    its own copy of every message.

    Args:
        name: Name passed to ``logging.getLogger``. Defaults to this module's name.

    Returns:
        logging.Logger: The logger with its level methods wrapped.
    """
    rank_aware_logger = logging.getLogger(name)

    # All level methods that must be marked with the rank zero decorator.
    level_method_names = (
        "debug",
        "info",
        "warning",
        "error",
        "exception",
        "fatal",
        "critical",
    )
    for method_name in level_method_names:
        unwrapped_method = getattr(rank_aware_logger, method_name)
        setattr(rank_aware_logger, method_name, rank_zero_only(unwrapped_method))

    return rank_aware_logger
# Module-level logger used by the helpers below; built with get_logger so its
# level methods are wrapped with rank_zero_only.
log = get_logger(__name__)
def extras(config: DictConfig) -> None:
    """Run the optional helpers that are switched on in the config.

    Supported flags:
    - ``ignore_warnings``: silence python warnings
    - ``print_config``: pretty print the config tree with Rich

    Args:
        config (DictConfig): Configuration whose flags are inspected.
    """
    # <config.ignore_warnings=True> turns off python warnings
    suppress_warnings = config.get("ignore_warnings")
    if suppress_warnings:
        log.info("Disabling python warnings! <config.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    # <config.print_config=True> prints the config tree through the Rich library
    show_config_tree = config.get("print_config")
    if show_config_tree:
        log.info("Printing config tree with Rich! <config.print_config=True>")
        print_config(config, resolve=True)
@rank_zero_only
def print_config(
    config: DictConfig,
    print_order: Sequence[str] = (
        "datamodule",
        "model",
        "callbacks",
        "logger",
        "trainer",
    ),
    resolve: bool = True,
) -> None:
    """Prints content of DictConfig using Rich library and its tree structure.

    The tree is printed to standard output and also written to
    ``config_tree.log`` in the current working directory.

    Args:
        config (DictConfig): Configuration composed by Hydra.
        print_order (Sequence[str], optional): Determines in what order config components are printed.
            Fields missing from ``config`` are reported and skipped; fields of ``config`` that are
            not listed are printed afterwards in their own order.
        resolve (bool, optional): Whether to resolve reference fields of DictConfig.
    """
    style = "dim"
    tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)

    queue = []

    # requested fields come first, in the requested order
    for field in print_order:
        if field in config:
            queue.append(field)
        else:
            log.info(f"Field '{field}' not found in config")

    # every remaining field of the config follows
    for field in config:
        if field not in queue:
            queue.append(field)

    for field in queue:
        branch = tree.add(field, style=style, guide_style=style)

        config_group = config[field]
        if isinstance(config_group, DictConfig):
            branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
        else:
            # plain values (numbers, strings, ...) are rendered as-is
            branch_content = str(config_group)

        branch.add(rich.syntax.Syntax(branch_content, "yaml"))

    rich.print(tree)

    # explicit UTF-8 so non-ASCII config values cannot fail to encode on
    # platforms whose default text encoding is not UTF-8
    with open("config_tree.log", "w", encoding="utf-8") as file:
        rich.print(tree, file=file)
@rank_zero_only
def log_hyperparameters(
    config: DictConfig,
    model: pl.LightningModule,
    datamodule: pl.LightningDataModule,
    trainer: pl.Trainer,
    callbacks: List[pl.Callback],
    logger: List[pl.loggers.LightningLoggerBase],
) -> None:
    """Controls which config parts are saved by Lightning loggers.

    Additionally saves:
    - number of model parameters

    Args:
        config (DictConfig): Configuration; its "trainer", "model" and "datamodule"
            entries are always saved, "seed" and "callbacks" only when present.
        model (pl.LightningModule): Model whose parameters are counted.
        datamodule (pl.LightningDataModule): Unused here; kept for a uniform call signature.
        trainer (pl.Trainer): Trainer whose logger receives the hyperparameters.
        callbacks (List[pl.Callback]): Unused here; kept for a uniform call signature.
        logger (List[pl.loggers.LightningLoggerBase]): Unused here; kept for a uniform call signature.
    """
    # trainer.logger is None when the Trainer runs without a logger
    # (e.g. logger=False); there is nothing to send the hyperparameters to
    if trainer.logger is None:
        log.warning("Logger not found! Skipping hyperparameter logging...")
        return

    hparams = {}

    # choose which parts of hydra config are saved by loggers
    hparams["trainer"] = config["trainer"]
    hparams["model"] = config["model"]
    hparams["datamodule"] = config["datamodule"]
    if "seed" in config:
        hparams["seed"] = config["seed"]
    if "callbacks" in config:
        hparams["callbacks"] = config["callbacks"]

    # save number of model parameters, counted in a single pass over the model
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count

    hparams["model/params/total"] = total_params
    hparams["model/params/trainable"] = trainable_params
    hparams["model/params/non_trainable"] = total_params - trainable_params

    # send hparams to all loggers
    trainer.logger.log_hyperparams(hparams)
def finish(
    config: DictConfig,
    model: pl.LightningModule,
    datamodule: pl.LightningDataModule,
    trainer: pl.Trainer,
    callbacks: List[pl.Callback],
    logger: List[pl.loggers.LightningLoggerBase],
) -> None:
    """Closes whatever has to be closed at the end of a run.

    Only ``logger`` is inspected; the other arguments are accepted but not used.
    """
    for candidate in logger:
        # only wandb loggers need explicit closing
        if not isinstance(candidate, pl.loggers.wandb.WandbLogger):
            continue

        # imported lazily, so wandb is only required when a wandb logger is in use
        import wandb

        # sweeps with the wandb logger might crash without this call
        wandb.finish()