import os, re
from omegaconf import OmegaConf
import logging
mainlogger = logging.getLogger('mainlogger')
import torch
from collections import OrderedDict
def init_workspace(name, logdir, model_config, lightning_config, rank=0):
    workdir = os.path.join(logdir, name)
    ckptdir = os.path.join(workdir, "checkpoints")
    cfgdir = os.path.join(workdir, "configs")
    loginfo = os.path.join(workdir, "loginfo")

    # Create logdirs and save configs (all ranks do this to avoid a missing-directory error if rank 0 is slower)
    os.makedirs(workdir, exist_ok=True)
    os.makedirs(ckptdir, exist_ok=True)
    os.makedirs(cfgdir, exist_ok=True)
    os.makedirs(loginfo, exist_ok=True)

    if rank == 0:
        if "callbacks" in lightning_config and 'metrics_over_trainsteps_checkpoint' in lightning_config.callbacks:
            os.makedirs(os.path.join(ckptdir, 'trainstep_checkpoints'), exist_ok=True)
        OmegaConf.save(model_config, os.path.join(cfgdir, "model.yaml"))
        OmegaConf.save(OmegaConf.create({"lightning": lightning_config}), os.path.join(cfgdir, "lightning.yaml"))
    return workdir, ckptdir, cfgdir, loginfo

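# Illustrative usage (a sketch, not called in this file): all ranks call init_workspace
# with the same arguments; the experiment name and log directory below are hypothetical.
# workdir, ckptdir, cfgdir, loginfo = init_workspace(
#     name="my_experiment", logdir="./results",
#     model_config=model_config, lightning_config=lightning_config, rank=0)
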
def check_config_attribute(config, name):
    if name in config:
        value = getattr(config, name)
        return value
    else:
        return None

def get_trainer_callbacks(lightning_config, config, logdir, ckptdir, logger):
    default_callbacks_cfg = {
        "model_checkpoint": {
            "target": "pytorch_lightning.callbacks.ModelCheckpoint",
            "params": {
                "dirpath": ckptdir,
                "filename": "{epoch}",
                "verbose": True,
                "save_last": False,
            }
        },
        "batch_logger": {
            "target": "callbacks.ImageLogger",
            "params": {
                "save_dir": logdir,
                "batch_frequency": 1000,
                "max_images": 4,
                "clamp": True,
            }
        },
        "learning_rate_logger": {
            "target": "pytorch_lightning.callbacks.LearningRateMonitor",
            "params": {
                "logging_interval": "step",
                "log_momentum": False
            }
        },
        "cuda_callback": {
            "target": "callbacks.CUDACallback"
        },
    }

    ## optional setting for saving checkpoints
    monitor_metric = check_config_attribute(config.model.params, "monitor")
    if monitor_metric is not None:
        mainlogger.info(f"Monitoring {monitor_metric} as checkpoint metric.")
        default_callbacks_cfg["model_checkpoint"]["params"]["monitor"] = monitor_metric
        default_callbacks_cfg["model_checkpoint"]["params"]["save_top_k"] = 3
        default_callbacks_cfg["model_checkpoint"]["params"]["mode"] = "min"

    if "callbacks" in lightning_config and 'metrics_over_trainsteps_checkpoint' in lightning_config.callbacks:
        mainlogger.info('Caution: Saving checkpoints every n train steps without deleting. This might require some free space.')
        default_metrics_over_trainsteps_ckpt_dict = {
            'metrics_over_trainsteps_checkpoint': {
                "target": 'pytorch_lightning.callbacks.ModelCheckpoint',
                'params': {
                    "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
                    "filename": "{epoch}-{step}",
                    "verbose": True,
                    'save_top_k': -1,
                    'every_n_train_steps': 10000,
                    'save_weights_only': True
                }
            }
        }
        default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)

    if "callbacks" in lightning_config:
        callbacks_cfg = lightning_config.callbacks
    else:
        callbacks_cfg = OmegaConf.create()
    callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
    return callbacks_cfg

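# A minimal sketch (assumption, not part of this file): one way the merged callbacks_cfg
# could be turned into callback objects before being passed to the Trainer.
# `_instantiate_from_config` is a hypothetical helper; the training script may use its
# own equivalent.
# import importlib
# def _instantiate_from_config(cfg):
#     module_name, cls_name = cfg["target"].rsplit(".", 1)
#     cls = getattr(importlib.import_module(module_name), cls_name)
#     return cls(**cfg.get("params", dict()))
# callbacks = [_instantiate_from_config(callbacks_cfg[name]) for name in callbacks_cfg]
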
def get_trainer_logger(lightning_config, logdir, on_debug):
    default_logger_cfgs = {
        "tensorboard": {
            "target": "pytorch_lightning.loggers.TensorBoardLogger",
            "params": {
                "save_dir": logdir,
                "name": "tensorboard",
            }
        },
        "testtube": {
            "target": "pytorch_lightning.loggers.CSVLogger",
            "params": {
                "name": "testtube",
                "save_dir": logdir,
            }
        },
    }
    os.makedirs(os.path.join(logdir, "tensorboard"), exist_ok=True)
    default_logger_cfg = default_logger_cfgs["tensorboard"]
    if "logger" in lightning_config:
        logger_cfg = lightning_config.logger
    else:
        logger_cfg = OmegaConf.create()
    logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
    return logger_cfg

def get_trainer_strategy(lightning_config):
    # fairscale-based sharded DDP (PyTorch Lightning 1.x)
    default_strategy_dict = {
        "target": "pytorch_lightning.strategies.DDPShardedStrategy"
    }
    if "strategy" in lightning_config:
        strategy_cfg = lightning_config.strategy
        return strategy_cfg
    else:
        strategy_cfg = OmegaConf.create()
        strategy_cfg = OmegaConf.merge(default_strategy_dict, strategy_cfg)
        return strategy_cfg

def load_checkpoints(model, model_cfg):
    if check_config_attribute(model_cfg, "pretrained_checkpoint"):
        pretrained_ckpt = model_cfg.pretrained_checkpoint
        assert os.path.exists(pretrained_ckpt), "Error: Pre-trained checkpoint NOT found at: %s" % pretrained_ckpt
        mainlogger.info(">>> Load weights from pretrained checkpoint")
        pl_sd = torch.load(pretrained_ckpt, map_location="cpu")
        try:
            if 'state_dict' in pl_sd.keys():
                model.load_state_dict(pl_sd["state_dict"], strict=True)
                mainlogger.info(">>> Loaded weights from pretrained checkpoint: %s" % pretrained_ckpt)
            else:
                # deepspeed checkpoint: weights live under 'module' and carry the
                # "_forward_module." prefix (16 characters), which is stripped here
                new_pl_sd = OrderedDict()
                for key in pl_sd['module'].keys():
                    new_pl_sd[key[16:]] = pl_sd['module'][key]
                model.load_state_dict(new_pl_sd, strict=True)
        except Exception:
            # fall back to treating the file as a raw state dict
            model.load_state_dict(pl_sd)
    else:
        mainlogger.info(">>> Start training from scratch")
    return model

def set_logger(logfile, name='mainlogger'):
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    fh = logging.FileHandler(logfile, mode='w')
    fh.setLevel(logging.INFO)
    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    fh.setFormatter(logging.Formatter("%(asctime)s-%(levelname)s: %(message)s"))
    ch.setFormatter(logging.Formatter("%(message)s"))
    logger.addHandler(fh)
    logger.addHandler(ch)
    return logger
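
# Illustrative usage (a sketch; the filename below is hypothetical): create the run
# logger once per rank before training starts, then log through `mainlogger` as above.
# logger = set_logger(os.path.join(loginfo, "log_rank0.txt"))
# logger.info("Workspace initialized at %s", workdir)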