import logging
import warnings
from typing import List, Sequence

import pytorch_lightning as pl
import rich.syntax
import rich.tree
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.utilities import rank_zero_only


def get_logger(name=__name__) -> logging.Logger:
    """Initializes a multi-GPU-friendly Python command-line logger."""

    logger = logging.getLogger(name)
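
    # Mark every logging method with the `rank_zero_only` decorator so that
    # in multi-GPU runs each message is emitted once, by the rank-zero
    # process, instead of once per GPU.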
    for level in (
        "debug",
        "info",
        "warning",
        "error",
        "exception",
        "fatal",
        "critical",
    ):
        setattr(logger, level, rank_zero_only(getattr(logger, level)))

    return logger


log = get_logger(__name__)


def extras(config: DictConfig) -> None:
    """Applies optional utilities, controlled by config flags.

    Utilities:
        - Ignoring Python warnings
        - Rich config printing
    """
    if config.get("ignore_warnings"):
        log.info("Disabling python warnings! <config.ignore_warnings=True>")
        warnings.filterwarnings("ignore")
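
    # Pretty-print the config tree with the Rich library if <config.print_config=True>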
    if config.get("print_config"):
        log.info("Printing config tree with Rich! <config.print_config=True>")
        print_config(config, resolve=True)
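
# A minimal usage sketch (hypothetical Hydra entry point; the config path,
# config name, and surrounding module layout are assumptions, not part of
# this file):
#
#     import hydra
#     from omegaconf import DictConfig
#
#     @hydra.main(config_path="configs/", config_name="config")
#     def main(config: DictConfig) -> None:
#         extras(config)
#         ...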


@rank_zero_only
def print_config(
    config: DictConfig,
    print_order: Sequence[str] = (
        "datamodule",
        "model",
        "callbacks",
        "logger",
        "trainer",
    ),
    resolve: bool = True,
) -> None:
    """Prints the content of a DictConfig as a tree structure using the Rich library.

    Args:
        config (DictConfig): Configuration composed by Hydra.
        print_order (Sequence[str], optional): Determines in what order config components are printed.
        resolve (bool, optional): Whether to resolve reference fields of DictConfig.
    """

    style = "dim"
    tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)

    queue = []
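
    # Add fields listed in `print_order` first, preserving their order.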
    for field in print_order:
        if field in config:
            queue.append(field)
        else:
            log.info(f"Field '{field}' not found in config")
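
    # Append any remaining config fields not covered by `print_order`.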
    for field in config:
        if field not in queue:
            queue.append(field)
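
    # Render each top-level field as a YAML-highlighted branch of the tree.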
    for field in queue:
        branch = tree.add(field, style=style, guide_style=style)

        config_group = config[field]
        if isinstance(config_group, DictConfig):
            branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
        else:
            branch_content = str(config_group)

        branch.add(rich.syntax.Syntax(branch_content, "yaml"))

    rich.print(tree)
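
    # Also persist the rendered tree to disk for later inspection.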
    with open("config_tree.log", "w") as file:
        rich.print(tree, file=file)


@rank_zero_only
def log_hyperparameters(
    config: DictConfig,
    model: pl.LightningModule,
    datamodule: pl.LightningDataModule,
    trainer: pl.Trainer,
    callbacks: List[pl.Callback],
    logger: List[pl.loggers.LightningLoggerBase],
) -> None:
    """Controls which config parts are saved by Lightning loggers.

    Additionally saves:
        - number of model parameters
    """

    hparams = {}
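
    # Choose which parts of the Hydra config are sent to the loggers.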
    hparams["trainer"] = config["trainer"]
    hparams["model"] = config["model"]
    hparams["datamodule"] = config["datamodule"]

    if "seed" in config:
        hparams["seed"] = config["seed"]
    if "callbacks" in config:
        hparams["callbacks"] = config["callbacks"]
    hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
    hparams["model/params/trainable"] = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    hparams["model/params/non_trainable"] = sum(
        p.numel() for p in model.parameters() if not p.requires_grad
    )
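
    # Send the collected hyperparameters to all loggers attached to the trainer.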
    trainer.logger.log_hyperparams(hparams)


def finish(
    config: DictConfig,
    model: pl.LightningModule,
    datamodule: pl.LightningDataModule,
    trainer: pl.Trainer,
    callbacks: List[pl.Callback],
    logger: List[pl.loggers.LightningLoggerBase],
) -> None:
    """Makes sure everything closed properly."""
    for lg in logger:
        if isinstance(lg, pl.loggers.wandb.WandbLogger):
            import wandb

            wandb.finish()
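

# A minimal sketch of how these helpers are typically wired together in a
# training pipeline (the `train` function and the way its objects are
# instantiated are assumptions, not part of this module):
#
#     def train(config: DictConfig) -> None:
#         model, datamodule, trainer, callbacks, logger = ...  # built from config
#         log_hyperparameters(config, model, datamodule, trainer, callbacks, logger)
#         trainer.fit(model=model, datamodule=datamodule)
#         finish(config, model, datamodule, trainer, callbacks, logger)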