Spaces:
Build error
Build error
File size: 6,340 Bytes
d7a991a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 |
import time
import warnings
from functools import wraps
from importlib.util import find_spec
from pathlib import Path
from typing import Callable, List, Optional, Union

import hydra
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Callback
from pytorch_lightning.loggers import Logger
from pytorch_lightning.utilities import rank_zero_only

from . import pylogger, rich_utils
# Module-level logger shared by every function in this file; created through
# the project's `pylogger.get_pylogger` helper with this module's `__name__`.
log = pylogger.get_pylogger(__name__)
def task_wrapper(task_func: Callable) -> Callable:
    """Optional decorator that wraps the task function in extra utilities.

    Makes multirun more resistant to failure.

    Utilities:
    - Calling the `utils.extras()` before the task is started
    - Calling the `utils.close_loggers()` after the task is finished
    - Logging the exception if occurs
    - Logging the task total execution time
    - Logging the output dir

    Args:
        task_func: The task callable. It is invoked as ``task_func(cfg=cfg)``.

    Returns:
        A wrapper that accepts the config and returns whatever ``task_func``
        returns. Exceptions raised by the task are logged and re-raised.
    """

    @wraps(task_func)  # keep the task's __name__ / __doc__ on the wrapper
    def wrap(cfg: DictConfig):
        # apply extra utilities
        extras(cfg)

        # Start the clock before entering `try`, so the `finally` clause can
        # always compute the elapsed time.
        start_time = time.time()

        # execute the task
        try:
            ret = task_func(cfg=cfg)
        except Exception:
            log.exception("")  # save exception to `.log` file
            raise  # bare raise: propagate the original exception unchanged
        finally:
            path = Path(cfg.paths.output_dir, "exec_time.log")
            content = f"'{cfg.task_name}' execution time: {time.time() - start_time} (s)"
            save_file(path, content)  # save task execution time (even if exception occurs)
            close_loggers()  # close loggers (even if exception occurs so multirun won't fail)

        # only reached when the task finished without raising
        log.info(f"Output dir: {cfg.paths.output_dir}")

        return ret

    return wrap
def extras(cfg: DictConfig) -> None:
    """Run the optional pre-task utilities selected by the `extras` config.

    Supported switches (each one is skipped when falsy or absent):
    - `ignore_warnings`: silence python warnings
    - `enforce_tags`: enforce tags via `rich_utils.enforce_tags`
    - `print_config`: print the config tree via `rich_utils.print_config_tree`

    If the config has no (or an empty) `extras` entry, a warning is logged
    and nothing else happens.
    """
    extras_cfg = cfg.get("extras")

    # nothing to do without an `extras` section
    if not extras_cfg:
        log.warning("Extras config not found! <cfg.extras=null>")
        return

    # silence python warnings
    wants_quiet = extras_cfg.get("ignore_warnings")
    if wants_quiet:
        log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    # make sure tags are set (delegated to rich_utils)
    wants_tags = extras_cfg.get("enforce_tags")
    if wants_tags:
        log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
        rich_utils.enforce_tags(cfg, save_to_file=True)

    # render the config tree with Rich
    wants_print = extras_cfg.get("print_config")
    if wants_print:
        log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
        rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)
@rank_zero_only
def save_file(path: Union[str, Path], content: str) -> None:
    """Save file in rank zero mode (only on one process in multi-GPU setup).

    Args:
        path: Destination file, given as a string or a `Path`. An existing
            file is overwritten.
        content: Text to write.
    """
    # Write-only mode is sufficient because nothing is read back. The explicit
    # UTF-8 encoding keeps the output independent of the platform locale.
    with open(path, "w", encoding="utf-8") as file:
        file.write(content)
def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
    """Build callback objects from a config mapping.

    Every entry that is itself a `DictConfig` containing a `_target_` key is
    passed to `hydra.utils.instantiate`; all other entries are ignored.

    Returns:
        The instantiated callbacks, or an empty list when the config is empty.

    Raises:
        TypeError: If a non-empty config is not a `DictConfig`.
    """
    # empty / missing config: warn and hand back an empty list
    if not callbacks_cfg:
        log.warning("Callbacks config is empty.")
        return []

    if not isinstance(callbacks_cfg, DictConfig):
        raise TypeError("Callbacks config must be a DictConfig!")

    instantiated: List[Callback] = []
    for _name, entry in callbacks_cfg.items():
        # skip anything that is not an instantiable node
        has_target = isinstance(entry, DictConfig) and "_target_" in entry
        if not has_target:
            continue
        log.info(f"Instantiating callback <{entry._target_}>")
        instantiated.append(hydra.utils.instantiate(entry))

    return instantiated
def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
    """Build logger objects from a config mapping.

    Every entry that is itself a `DictConfig` containing a `_target_` key is
    passed to `hydra.utils.instantiate`; all other entries are ignored.

    Returns:
        The instantiated loggers, or an empty list when the config is empty.

    Raises:
        TypeError: If a non-empty config is not a `DictConfig`.
    """
    # empty / missing config: warn and hand back an empty list
    if not logger_cfg:
        log.warning("Logger config is empty.")
        return []

    if not isinstance(logger_cfg, DictConfig):
        raise TypeError("Logger config must be a DictConfig!")

    instantiated: List[Logger] = []
    for _name, entry in logger_cfg.items():
        # skip anything that is not an instantiable node
        has_target = isinstance(entry, DictConfig) and "_target_" in entry
        if not has_target:
            continue
        log.info(f"Instantiating logger <{entry._target_}>")
        instantiated.append(hydra.utils.instantiate(entry))

    return instantiated
@rank_zero_only
def log_hyperparameters(object_dict: dict) -> None:
    """Controls which config parts are saved by lightning loggers.

    Additionally saves:
    - Number of model parameters

    Args:
        object_dict: Mapping that must contain the keys ``"cfg"``,
            ``"model"`` and ``"trainer"``.

    Raises:
        KeyError: If one of the required keys is missing from ``object_dict``.
    """
    cfg = object_dict["cfg"]
    model = object_dict["model"]
    trainer = object_dict["trainer"]

    if not trainer.logger:
        log.warning("Logger not found! Skipping hyperparameter logging...")
        return

    # save number of model parameters (single pass over the parameters)
    trainable = 0
    non_trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
        else:
            non_trainable += param.numel()

    hparams = {
        "model/params/total": trainable + non_trainable,
        "model/params/trainable": trainable,
        "model/params/non_trainable": non_trainable,
    }

    # Copy every top-level config entry, resolving interpolations. Any
    # OmegaConf container (dict- or list-like) is converted to plain python
    # objects; other values are stored as they are.
    for key in cfg.keys():
        value = cfg.get(key)
        if OmegaConf.is_config(value):
            value = OmegaConf.to_container(value, resolve=True)
        hparams[key] = value

    # send hparams to all loggers (`trainer.logger` is only the first one)
    for logger in trainer.loggers:
        logger.log_hyperparams(hparams)
def get_metric_value(metric_dict: dict, metric_name: Optional[str]) -> Optional[float]:
    """Safely retrieves value of the metric logged in LightningModule.

    Args:
        metric_dict: Mapping from metric name to its logged value.
        metric_name: Name of the metric to look up. May be ``None`` or empty,
            in which case no lookup is performed.

    Returns:
        The metric value, or ``None`` when ``metric_name`` is empty.

    Raises:
        Exception: If ``metric_name`` is given but absent from ``metric_dict``.
    """
    if not metric_name:
        log.info("Metric name is None! Skipping metric value retrieval...")
        return None

    if metric_name not in metric_dict:
        raise Exception(
            f"Metric value not found! <metric_name={metric_name}>\n"
            "Make sure metric name logged in LightningModule is correct!\n"
            "Make sure `optimized_metric` name in `hparams_search` config is correct!"
        )

    metric_value = metric_dict[metric_name]
    # Values exposing `.item()` are unwrapped through it; values without it
    # (e.g. plain python numbers) are returned unchanged.
    if hasattr(metric_value, "item"):
        metric_value = metric_value.item()
    log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")

    return metric_value
def close_loggers() -> None:
    """Finish any active logger run so a following multirun job starts clean.

    Currently only wandb is handled: if the package is installed and a run is
    active, the run is finished.
    """
    log.info("Closing loggers...")

    # wandb is optional; bail out when it is not installed
    if not find_spec("wandb"):
        return

    import wandb

    # nothing to close when no run is active
    if not wandb.run:
        return

    log.info("Closing wandb!")
    wandb.finish()
|