File size: 4,513 Bytes
d2e7940 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
import logging
import warnings
from typing import List, Sequence
import pytorch_lightning as pl
import rich.syntax
import rich.tree
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.utilities import rank_zero_only
def get_logger(name=__name__) -> logging.Logger:
    """Build a python command line logger that is safe to use with multiple GPUs.

    Args:
        name: Name handed to ``logging.getLogger``. Defaults to this module's name.

    Returns:
        logging.Logger: The logger whose level methods were wrapped with
        ``rank_zero_only``.
    """
    rank_zero_logger = logging.getLogger(name)

    # Every level method is passed through the rank zero decorator; without it
    # each GPU process of a multi-GPU run would emit its own copy of every message.
    level_methods = (
        "debug",
        "info",
        "warning",
        "error",
        "exception",
        "fatal",
        "critical",
    )
    for method_name in level_methods:
        unwrapped = getattr(rank_zero_logger, method_name)
        setattr(rank_zero_logger, method_name, rank_zero_only(unwrapped))

    return rank_zero_logger
# Module-level logger used by the helper functions in this file.
log = get_logger(__name__)
def extras(config: DictConfig) -> None:
    """Run the optional utilities that are switched on by flags in the config.

    Utilities:
    - ``ignore_warnings``: silence python warnings
    - ``print_config``: pretty print the config tree with Rich

    Args:
        config (DictConfig): Configuration whose flags are inspected.
    """
    # <config.ignore_warnings=True> turns off all python warnings
    suppress_warnings = config.get("ignore_warnings")
    if suppress_warnings:
        log.info("Disabling python warnings! <config.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    # <config.print_config=True> dumps the resolved config tree via Rich
    show_config = config.get("print_config")
    if show_config:
        log.info("Printing config tree with Rich! <config.print_config=True>")
        print_config(config, resolve=True)
@rank_zero_only
def print_config(
    config: DictConfig,
    print_order: Sequence[str] = (
        "datamodule",
        "model",
        "callbacks",
        "logger",
        "trainer",
    ),
    resolve: bool = True,
) -> None:
    """Prints content of DictConfig using Rich library and its tree structure.

    The tree is printed to the console and also written to ``config_tree.log``
    in the current working directory.

    Args:
        config (DictConfig): Configuration composed by Hydra.
        print_order (Sequence[str], optional): Determines in what order config
            components are printed. Names missing from ``config`` are reported
            via the module logger and skipped; fields of ``config`` not listed
            here are printed afterwards in the config's own iteration order.
        resolve (bool, optional): Whether to resolve reference fields of DictConfig.
    """
    style = "dim"
    tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)

    queue = []

    # fields requested through `print_order` come first
    for field in print_order:
        if field in config:
            queue.append(field)
        else:
            log.info(f"Field '{field}' not found in config")

    # then every remaining field of the config
    for field in config:
        if field not in queue:
            queue.append(field)

    # one branch per field; nested configs are rendered as YAML
    for field in queue:
        branch = tree.add(field, style=style, guide_style=style)

        config_group = config[field]
        if isinstance(config_group, DictConfig):
            branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
        else:
            branch_content = str(config_group)

        branch.add(rich.syntax.Syntax(branch_content, "yaml"))

    rich.print(tree)

    # explicit encoding so non-ASCII config values do not depend on the
    # platform's default locale encoding
    with open("config_tree.log", "w", encoding="utf-8") as file:
        rich.print(tree, file=file)
@rank_zero_only
def log_hyperparameters(
    config: DictConfig,
    model: pl.LightningModule,
    datamodule: pl.LightningDataModule,
    trainer: pl.Trainer,
    callbacks: List[pl.Callback],
    logger: List[pl.loggers.LightningLoggerBase],
) -> None:
    """Controls which config parts are saved by Lightning loggers.

    Additionally saves:
    - number of model parameters (total, trainable, non-trainable)

    Args:
        config (DictConfig): Configuration; its "trainer", "model" and
            "datamodule" entries are always saved, "seed" and "callbacks"
            only when present.
        model (pl.LightningModule): Model whose parameters are counted.
        datamodule (pl.LightningDataModule): Not used by this function.
        trainer (pl.Trainer): Trainer whose ``logger`` receives the hyperparameters.
        callbacks (List[pl.Callback]): Not used by this function.
        logger (List[pl.loggers.LightningLoggerBase]): Not used by this function;
            logging goes through ``trainer.logger``.
    """
    # nothing to send the hyperparameters to when the trainer has no logger
    if not trainer.logger:
        logging.getLogger(__name__).warning(
            "Logger not found! Skipping hyperparameter logging..."
        )
        return

    hparams = {}

    # choose which parts of hydra config are saved by loggers
    hparams["trainer"] = config["trainer"]
    hparams["model"] = config["model"]
    hparams["datamodule"] = config["datamodule"]
    if "seed" in config:
        hparams["seed"] = config["seed"]
    if "callbacks" in config:
        hparams["callbacks"] = config["callbacks"]

    # count model parameters in a single pass over model.parameters()
    total = 0
    trainable = 0
    for param in model.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count

    hparams["model/params/total"] = total
    hparams["model/params/trainable"] = trainable
    hparams["model/params/non_trainable"] = total - trainable

    # send hparams to all loggers
    trainer.logger.log_hyperparams(hparams)
def finish(
    config: DictConfig,
    model: pl.LightningModule,
    datamodule: pl.LightningDataModule,
    trainer: pl.Trainer,
    callbacks: List[pl.Callback],
    logger: List[pl.loggers.LightningLoggerBase],
) -> None:
    """Close whatever needs closing at the end of a run.

    Only ``logger`` is inspected; the remaining arguments are accepted but
    not used by this function.

    Args:
        logger (List[pl.loggers.LightningLoggerBase]): Loggers of the run;
            ``wandb.finish()`` is called once for each ``WandbLogger`` found.
    """
    for candidate in logger:
        if not isinstance(candidate, pl.loggers.wandb.WandbLogger):
            continue

        # sweeps with the wandb logger might crash unless the run is finished
        import wandb

        wandb.finish()
|