Spaces:
Running
on
A10G
Running
on
A10G
import argparse | |
import datetime | |
import glob | |
import inspect | |
import os | |
import sys | |
from inspect import Parameter | |
from typing import Union | |
import numpy as np | |
import pytorch_lightning as pl | |
import torch | |
import torchvision | |
import wandb | |
from matplotlib import pyplot as plt | |
from natsort import natsorted | |
from omegaconf import OmegaConf | |
from packaging import version | |
from PIL import Image | |
from pytorch_lightning import seed_everything | |
from pytorch_lightning.callbacks import Callback | |
from pytorch_lightning.loggers import WandbLogger | |
from pytorch_lightning.trainer import Trainer | |
from pytorch_lightning.utilities import rank_zero_only | |
from sgm.util import exists, instantiate_from_config, isheatmap | |
MULTINODE_HACKS = True | |
def default_trainer_args(): | |
argspec = dict(inspect.signature(Trainer.__init__).parameters) | |
argspec.pop("self") | |
default_args = { | |
param: argspec[param].default | |
for param in argspec | |
if argspec[param] != Parameter.empty | |
} | |
return default_args | |
def get_parser(**parser_kwargs): | |
def str2bool(v): | |
if isinstance(v, bool): | |
return v | |
if v.lower() in ("yes", "true", "t", "y", "1"): | |
return True | |
elif v.lower() in ("no", "false", "f", "n", "0"): | |
return False | |
else: | |
raise argparse.ArgumentTypeError("Boolean value expected.") | |
parser = argparse.ArgumentParser(**parser_kwargs) | |
parser.add_argument( | |
"-n", | |
"--name", | |
type=str, | |
const=True, | |
default="", | |
nargs="?", | |
help="postfix for logdir", | |
) | |
parser.add_argument( | |
"--no_date", | |
type=str2bool, | |
nargs="?", | |
const=True, | |
default=False, | |
help="if True, skip date generation for logdir and only use naming via opt.base or opt.name (+ opt.postfix, optionally)", | |
) | |
parser.add_argument( | |
"-r", | |
"--resume", | |
type=str, | |
const=True, | |
default="", | |
nargs="?", | |
help="resume from logdir or checkpoint in logdir", | |
) | |
parser.add_argument( | |
"-b", | |
"--base", | |
nargs="*", | |
metavar="base_config.yaml", | |
help="paths to base configs. Loaded from left-to-right. " | |
"Parameters can be overwritten or added with command-line options of the form `--key value`.", | |
default=list(), | |
) | |
parser.add_argument( | |
"-t", | |
"--train", | |
type=str2bool, | |
const=True, | |
default=True, | |
nargs="?", | |
help="train", | |
) | |
parser.add_argument( | |
"--no-test", | |
type=str2bool, | |
const=True, | |
default=False, | |
nargs="?", | |
help="disable test", | |
) | |
parser.add_argument( | |
"-p", "--project", help="name of new or path to existing project" | |
) | |
parser.add_argument( | |
"-d", | |
"--debug", | |
type=str2bool, | |
nargs="?", | |
const=True, | |
default=False, | |
help="enable post-mortem debugging", | |
) | |
parser.add_argument( | |
"-s", | |
"--seed", | |
type=int, | |
default=23, | |
help="seed for seed_everything", | |
) | |
parser.add_argument( | |
"-f", | |
"--postfix", | |
type=str, | |
default="", | |
help="post-postfix for default name", | |
) | |
parser.add_argument( | |
"--projectname", | |
type=str, | |
default="stablediffusion", | |
) | |
parser.add_argument( | |
"-l", | |
"--logdir", | |
type=str, | |
default="logs", | |
help="directory for logging dat shit", | |
) | |
parser.add_argument( | |
"--scale_lr", | |
type=str2bool, | |
nargs="?", | |
const=True, | |
default=False, | |
help="scale base-lr by ngpu * batch_size * n_accumulate", | |
) | |
parser.add_argument( | |
"--legacy_naming", | |
type=str2bool, | |
nargs="?", | |
const=True, | |
default=False, | |
help="name run based on config file name if true, else by whole path", | |
) | |
parser.add_argument( | |
"--enable_tf32", | |
type=str2bool, | |
nargs="?", | |
const=True, | |
default=False, | |
help="enables the TensorFloat32 format both for matmuls and cuDNN for pytorch 1.12", | |
) | |
parser.add_argument( | |
"--startup", | |
type=str, | |
default=None, | |
help="Startuptime from distributed script", | |
) | |
parser.add_argument( | |
"--wandb", | |
type=str2bool, | |
nargs="?", | |
const=True, | |
default=False, # TODO: later default to True | |
help="log to wandb", | |
) | |
parser.add_argument( | |
"--no_base_name", | |
type=str2bool, | |
nargs="?", | |
const=True, | |
default=False, # TODO: later default to True | |
help="log to wandb", | |
) | |
if version.parse(torch.__version__) >= version.parse("2.0.0"): | |
parser.add_argument( | |
"--resume_from_checkpoint", | |
type=str, | |
default=None, | |
help="single checkpoint file to resume from", | |
) | |
default_args = default_trainer_args() | |
for key in default_args: | |
parser.add_argument("--" + key, default=default_args[key]) | |
return parser | |
def get_checkpoint_name(logdir): | |
ckpt = os.path.join(logdir, "checkpoints", "last**.ckpt") | |
ckpt = natsorted(glob.glob(ckpt)) | |
print('available "last" checkpoints:') | |
print(ckpt) | |
if len(ckpt) > 1: | |
print("got most recent checkpoint") | |
ckpt = sorted(ckpt, key=lambda x: os.path.getmtime(x))[-1] | |
print(f"Most recent ckpt is {ckpt}") | |
with open(os.path.join(logdir, "most_recent_ckpt.txt"), "w") as f: | |
f.write(ckpt + "\n") | |
try: | |
version = int(ckpt.split("/")[-1].split("-v")[-1].split(".")[0]) | |
except Exception as e: | |
print("version confusion but not bad") | |
print(e) | |
version = 1 | |
# version = last_version + 1 | |
else: | |
# in this case, we only have one "last.ckpt" | |
ckpt = ckpt[0] | |
version = 1 | |
melk_ckpt_name = f"last-v{version}.ckpt" | |
print(f"Current melk ckpt name: {melk_ckpt_name}") | |
return ckpt, melk_ckpt_name | |
class SetupCallback(Callback): | |
def __init__( | |
self, | |
resume, | |
now, | |
logdir, | |
ckptdir, | |
cfgdir, | |
config, | |
lightning_config, | |
debug, | |
ckpt_name=None, | |
): | |
super().__init__() | |
self.resume = resume | |
self.now = now | |
self.logdir = logdir | |
self.ckptdir = ckptdir | |
self.cfgdir = cfgdir | |
self.config = config | |
self.lightning_config = lightning_config | |
self.debug = debug | |
self.ckpt_name = ckpt_name | |
def on_exception(self, trainer: pl.Trainer, pl_module, exception): | |
if not self.debug and trainer.global_rank == 0: | |
print("Summoning checkpoint.") | |
if self.ckpt_name is None: | |
ckpt_path = os.path.join(self.ckptdir, "last.ckpt") | |
else: | |
ckpt_path = os.path.join(self.ckptdir, self.ckpt_name) | |
trainer.save_checkpoint(ckpt_path) | |
def on_fit_start(self, trainer, pl_module): | |
if trainer.global_rank == 0: | |
# Create logdirs and save configs | |
os.makedirs(self.logdir, exist_ok=True) | |
os.makedirs(self.ckptdir, exist_ok=True) | |
os.makedirs(self.cfgdir, exist_ok=True) | |
if "callbacks" in self.lightning_config: | |
if ( | |
"metrics_over_trainsteps_checkpoint" | |
in self.lightning_config["callbacks"] | |
): | |
os.makedirs( | |
os.path.join(self.ckptdir, "trainstep_checkpoints"), | |
exist_ok=True, | |
) | |
print("Project config") | |
print(OmegaConf.to_yaml(self.config)) | |
if MULTINODE_HACKS: | |
import time | |
time.sleep(5) | |
OmegaConf.save( | |
self.config, | |
os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)), | |
) | |
print("Lightning config") | |
print(OmegaConf.to_yaml(self.lightning_config)) | |
OmegaConf.save( | |
OmegaConf.create({"lightning": self.lightning_config}), | |
os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)), | |
) | |
else: | |
# ModelCheckpoint callback created log directory --- remove it | |
if not MULTINODE_HACKS and not self.resume and os.path.exists(self.logdir): | |
dst, name = os.path.split(self.logdir) | |
dst = os.path.join(dst, "child_runs", name) | |
os.makedirs(os.path.split(dst)[0], exist_ok=True) | |
try: | |
os.rename(self.logdir, dst) | |
except FileNotFoundError: | |
pass | |
class ImageLogger(Callback): | |
def __init__( | |
self, | |
batch_frequency, | |
max_images, | |
clamp=True, | |
increase_log_steps=True, | |
rescale=True, | |
disabled=False, | |
log_on_batch_idx=False, | |
log_first_step=False, | |
log_images_kwargs=None, | |
log_before_first_step=False, | |
enable_autocast=True, | |
): | |
super().__init__() | |
self.enable_autocast = enable_autocast | |
self.rescale = rescale | |
self.batch_freq = batch_frequency | |
self.max_images = max_images | |
self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)] | |
if not increase_log_steps: | |
self.log_steps = [self.batch_freq] | |
self.clamp = clamp | |
self.disabled = disabled | |
self.log_on_batch_idx = log_on_batch_idx | |
self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {} | |
self.log_first_step = log_first_step | |
self.log_before_first_step = log_before_first_step | |
def log_local( | |
self, | |
save_dir, | |
split, | |
images, | |
global_step, | |
current_epoch, | |
batch_idx, | |
pl_module: Union[None, pl.LightningModule] = None, | |
): | |
root = os.path.join(save_dir, "images", split) | |
for k in images: | |
if isheatmap(images[k]): | |
fig, ax = plt.subplots() | |
ax = ax.matshow( | |
images[k].cpu().numpy(), cmap="hot", interpolation="lanczos" | |
) | |
plt.colorbar(ax) | |
plt.axis("off") | |
filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format( | |
k, global_step, current_epoch, batch_idx | |
) | |
os.makedirs(root, exist_ok=True) | |
path = os.path.join(root, filename) | |
plt.savefig(path) | |
plt.close() | |
# TODO: support wandb | |
else: | |
grid = torchvision.utils.make_grid(images[k], nrow=4) | |
if self.rescale: | |
grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w | |
grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1) | |
grid = grid.numpy() | |
grid = (grid * 255).astype(np.uint8) | |
filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format( | |
k, global_step, current_epoch, batch_idx | |
) | |
path = os.path.join(root, filename) | |
os.makedirs(os.path.split(path)[0], exist_ok=True) | |
img = Image.fromarray(grid) | |
img.save(path) | |
if exists(pl_module): | |
assert isinstance( | |
pl_module.logger, WandbLogger | |
), "logger_log_image only supports WandbLogger currently" | |
pl_module.logger.log_image( | |
key=f"{split}/{k}", | |
images=[ | |
img, | |
], | |
step=pl_module.global_step, | |
) | |
def log_img(self, pl_module, batch, batch_idx, split="train"): | |
check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step | |
if ( | |
self.check_frequency(check_idx) | |
and hasattr(pl_module, "log_images") # batch_idx % self.batch_freq == 0 | |
and callable(pl_module.log_images) | |
and | |
# batch_idx > 5 and | |
self.max_images > 0 | |
): | |
logger = type(pl_module.logger) | |
is_train = pl_module.training | |
if is_train: | |
pl_module.eval() | |
gpu_autocast_kwargs = { | |
"enabled": self.enable_autocast, # torch.is_autocast_enabled(), | |
"dtype": torch.get_autocast_gpu_dtype(), | |
"cache_enabled": torch.is_autocast_cache_enabled(), | |
} | |
with torch.no_grad(), torch.cuda.amp.autocast(**gpu_autocast_kwargs): | |
images = pl_module.log_images( | |
batch, split=split, **self.log_images_kwargs | |
) | |
for k in images: | |
N = min(images[k].shape[0], self.max_images) | |
if not isheatmap(images[k]): | |
images[k] = images[k][:N] | |
if isinstance(images[k], torch.Tensor): | |
images[k] = images[k].detach().float().cpu() | |
if self.clamp and not isheatmap(images[k]): | |
images[k] = torch.clamp(images[k], -1.0, 1.0) | |
self.log_local( | |
pl_module.logger.save_dir, | |
split, | |
images, | |
pl_module.global_step, | |
pl_module.current_epoch, | |
batch_idx, | |
pl_module=pl_module | |
if isinstance(pl_module.logger, WandbLogger) | |
else None, | |
) | |
if is_train: | |
pl_module.train() | |
def check_frequency(self, check_idx): | |
if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and ( | |
check_idx > 0 or self.log_first_step | |
): | |
try: | |
self.log_steps.pop(0) | |
except IndexError as e: | |
print(e) | |
pass | |
return True | |
return False | |
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): | |
if not self.disabled and (pl_module.global_step > 0 or self.log_first_step): | |
self.log_img(pl_module, batch, batch_idx, split="train") | |
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): | |
if self.log_before_first_step and pl_module.global_step == 0: | |
print(f"{self.__class__.__name__}: logging before training") | |
self.log_img(pl_module, batch, batch_idx, split="train") | |
def on_validation_batch_end( | |
self, trainer, pl_module, outputs, batch, batch_idx, *args, **kwargs | |
): | |
if not self.disabled and pl_module.global_step > 0: | |
self.log_img(pl_module, batch, batch_idx, split="val") | |
if hasattr(pl_module, "calibrate_grad_norm"): | |
if ( | |
pl_module.calibrate_grad_norm and batch_idx % 25 == 0 | |
) and batch_idx > 0: | |
self.log_gradients(trainer, pl_module, batch_idx=batch_idx) | |
def init_wandb(save_dir, opt, config, group_name, name_str): | |
print(f"setting WANDB_DIR to {save_dir}") | |
os.makedirs(save_dir, exist_ok=True) | |
os.environ["WANDB_DIR"] = save_dir | |
if opt.debug: | |
wandb.init(project=opt.projectname, mode="offline", group=group_name) | |
else: | |
wandb.init( | |
project=opt.projectname, | |
config=config, | |
settings=wandb.Settings(code_dir="./sgm"), | |
group=group_name, | |
name=name_str, | |
) | |
if __name__ == "__main__": | |
# custom parser to specify config files, train, test and debug mode, | |
# postfix, resume. | |
# `--key value` arguments are interpreted as arguments to the trainer. | |
# `nested.key=value` arguments are interpreted as config parameters. | |
# configs are merged from left-to-right followed by command line parameters. | |
# model: | |
# base_learning_rate: float | |
# target: path to lightning module | |
# params: | |
# key: value | |
# data: | |
# target: main.DataModuleFromConfig | |
# params: | |
# batch_size: int | |
# wrap: bool | |
# train: | |
# target: path to train dataset | |
# params: | |
# key: value | |
# validation: | |
# target: path to validation dataset | |
# params: | |
# key: value | |
# test: | |
# target: path to test dataset | |
# params: | |
# key: value | |
# lightning: (optional, has sane defaults and can be specified on cmdline) | |
# trainer: | |
# additional arguments to trainer | |
# logger: | |
# logger to instantiate | |
# modelcheckpoint: | |
# modelcheckpoint to instantiate | |
# callbacks: | |
# callback1: | |
# target: importpath | |
# params: | |
# key: value | |
now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | |
# add cwd for convenience and to make classes in this file available when | |
# running as `python main.py` | |
# (in particular `main.DataModuleFromConfig`) | |
sys.path.append(os.getcwd()) | |
parser = get_parser() | |
opt, unknown = parser.parse_known_args() | |
if opt.name and opt.resume: | |
raise ValueError( | |
"-n/--name and -r/--resume cannot be specified both." | |
"If you want to resume training in a new log folder, " | |
"use -n/--name in combination with --resume_from_checkpoint" | |
) | |
melk_ckpt_name = None | |
name = None | |
if opt.resume: | |
if not os.path.exists(opt.resume): | |
raise ValueError("Cannot find {}".format(opt.resume)) | |
if os.path.isfile(opt.resume): | |
paths = opt.resume.split("/") | |
# idx = len(paths)-paths[::-1].index("logs")+1 | |
# logdir = "/".join(paths[:idx]) | |
logdir = "/".join(paths[:-2]) | |
ckpt = opt.resume | |
_, melk_ckpt_name = get_checkpoint_name(logdir) | |
else: | |
assert os.path.isdir(opt.resume), opt.resume | |
logdir = opt.resume.rstrip("/") | |
ckpt, melk_ckpt_name = get_checkpoint_name(logdir) | |
print("#" * 100) | |
print(f'Resuming from checkpoint "{ckpt}"') | |
print("#" * 100) | |
opt.resume_from_checkpoint = ckpt | |
base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml"))) | |
opt.base = base_configs + opt.base | |
_tmp = logdir.split("/") | |
nowname = _tmp[-1] | |
else: | |
if opt.name: | |
name = "_" + opt.name | |
elif opt.base: | |
if opt.no_base_name: | |
name = "" | |
else: | |
if opt.legacy_naming: | |
cfg_fname = os.path.split(opt.base[0])[-1] | |
cfg_name = os.path.splitext(cfg_fname)[0] | |
else: | |
assert "configs" in os.path.split(opt.base[0])[0], os.path.split( | |
opt.base[0] | |
)[0] | |
cfg_path = os.path.split(opt.base[0])[0].split(os.sep)[ | |
os.path.split(opt.base[0])[0].split(os.sep).index("configs") | |
+ 1 : | |
] # cut away the first one (we assert all configs are in "configs") | |
cfg_name = os.path.splitext(os.path.split(opt.base[0])[-1])[0] | |
cfg_name = "-".join(cfg_path) + f"-{cfg_name}" | |
name = "_" + cfg_name | |
else: | |
name = "" | |
if not opt.no_date: | |
nowname = now + name + opt.postfix | |
else: | |
nowname = name + opt.postfix | |
if nowname.startswith("_"): | |
nowname = nowname[1:] | |
logdir = os.path.join(opt.logdir, nowname) | |
print(f"LOGDIR: {logdir}") | |
ckptdir = os.path.join(logdir, "checkpoints") | |
cfgdir = os.path.join(logdir, "configs") | |
seed_everything(opt.seed, workers=True) | |
# move before model init, in case a torch.compile(...) is called somewhere | |
if opt.enable_tf32: | |
# pt_version = version.parse(torch.__version__) | |
torch.backends.cuda.matmul.allow_tf32 = True | |
torch.backends.cudnn.allow_tf32 = True | |
print(f"Enabling TF32 for PyTorch {torch.__version__}") | |
else: | |
print(f"Using default TF32 settings for PyTorch {torch.__version__}:") | |
print( | |
f"torch.backends.cuda.matmul.allow_tf32={torch.backends.cuda.matmul.allow_tf32}" | |
) | |
print(f"torch.backends.cudnn.allow_tf32={torch.backends.cudnn.allow_tf32}") | |
try: | |
# init and save configs | |
configs = [OmegaConf.load(cfg) for cfg in opt.base] | |
cli = OmegaConf.from_dotlist(unknown) | |
config = OmegaConf.merge(*configs, cli) | |
lightning_config = config.pop("lightning", OmegaConf.create()) | |
# merge trainer cli with config | |
trainer_config = lightning_config.get("trainer", OmegaConf.create()) | |
# default to gpu | |
trainer_config["accelerator"] = "gpu" | |
# | |
standard_args = default_trainer_args() | |
for k in standard_args: | |
if getattr(opt, k) != standard_args[k]: | |
trainer_config[k] = getattr(opt, k) | |
ckpt_resume_path = opt.resume_from_checkpoint | |
if not "devices" in trainer_config and trainer_config["accelerator"] != "gpu": | |
del trainer_config["accelerator"] | |
cpu = True | |
else: | |
gpuinfo = trainer_config["devices"] | |
print(f"Running on GPUs {gpuinfo}") | |
cpu = False | |
trainer_opt = argparse.Namespace(**trainer_config) | |
lightning_config.trainer = trainer_config | |
# model | |
model = instantiate_from_config(config.model) | |
# trainer and callbacks | |
trainer_kwargs = dict() | |
# default logger configs | |
default_logger_cfgs = { | |
"wandb": { | |
"target": "pytorch_lightning.loggers.WandbLogger", | |
"params": { | |
"name": nowname, | |
# "save_dir": logdir, | |
"offline": opt.debug, | |
"id": nowname, | |
"project": opt.projectname, | |
"log_model": False, | |
# "dir": logdir, | |
}, | |
}, | |
"csv": { | |
"target": "pytorch_lightning.loggers.CSVLogger", | |
"params": { | |
"name": "testtube", # hack for sbord fanatics | |
"save_dir": logdir, | |
}, | |
}, | |
} | |
default_logger_cfg = default_logger_cfgs["wandb" if opt.wandb else "csv"] | |
if opt.wandb: | |
# TODO change once leaving "swiffer" config directory | |
try: | |
group_name = nowname.split(now)[-1].split("-")[1] | |
except: | |
group_name = nowname | |
default_logger_cfg["params"]["group"] = group_name | |
init_wandb( | |
os.path.join(os.getcwd(), logdir), | |
opt=opt, | |
group_name=group_name, | |
config=config, | |
name_str=nowname, | |
) | |
if "logger" in lightning_config: | |
logger_cfg = lightning_config.logger | |
else: | |
logger_cfg = OmegaConf.create() | |
logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg) | |
trainer_kwargs["logger"] = instantiate_from_config(logger_cfg) | |
# modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to | |
# specify which metric is used to determine best models | |
default_modelckpt_cfg = { | |
"target": "pytorch_lightning.callbacks.ModelCheckpoint", | |
"params": { | |
"dirpath": ckptdir, | |
"filename": "{epoch:06}", | |
"verbose": True, | |
"save_last": True, | |
}, | |
} | |
if hasattr(model, "monitor"): | |
print(f"Monitoring {model.monitor} as checkpoint metric.") | |
default_modelckpt_cfg["params"]["monitor"] = model.monitor | |
default_modelckpt_cfg["params"]["save_top_k"] = 3 | |
if "modelcheckpoint" in lightning_config: | |
modelckpt_cfg = lightning_config.modelcheckpoint | |
else: | |
modelckpt_cfg = OmegaConf.create() | |
modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg) | |
print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}") | |
# https://pytorch-lightning.readthedocs.io/en/stable/extensions/strategy.html | |
# default to ddp if not further specified | |
default_strategy_config = {"target": "pytorch_lightning.strategies.DDPStrategy"} | |
if "strategy" in lightning_config: | |
strategy_cfg = lightning_config.strategy | |
else: | |
strategy_cfg = OmegaConf.create() | |
default_strategy_config["params"] = { | |
"find_unused_parameters": False, | |
# "static_graph": True, | |
# "ddp_comm_hook": default.fp16_compress_hook # TODO: experiment with this, also for DDPSharded | |
} | |
strategy_cfg = OmegaConf.merge(default_strategy_config, strategy_cfg) | |
print( | |
f"strategy config: \n ++++++++++++++ \n {strategy_cfg} \n ++++++++++++++ " | |
) | |
trainer_kwargs["strategy"] = instantiate_from_config(strategy_cfg) | |
# add callback which sets up log directory | |
default_callbacks_cfg = { | |
"setup_callback": { | |
"target": "main.SetupCallback", | |
"params": { | |
"resume": opt.resume, | |
"now": now, | |
"logdir": logdir, | |
"ckptdir": ckptdir, | |
"cfgdir": cfgdir, | |
"config": config, | |
"lightning_config": lightning_config, | |
"debug": opt.debug, | |
"ckpt_name": melk_ckpt_name, | |
}, | |
}, | |
"image_logger": { | |
"target": "main.ImageLogger", | |
"params": {"batch_frequency": 1000, "max_images": 4, "clamp": True}, | |
}, | |
"learning_rate_logger": { | |
"target": "pytorch_lightning.callbacks.LearningRateMonitor", | |
"params": { | |
"logging_interval": "step", | |
# "log_momentum": True | |
}, | |
}, | |
} | |
if version.parse(pl.__version__) >= version.parse("1.4.0"): | |
default_callbacks_cfg.update({"checkpoint_callback": modelckpt_cfg}) | |
if "callbacks" in lightning_config: | |
callbacks_cfg = lightning_config.callbacks | |
else: | |
callbacks_cfg = OmegaConf.create() | |
if "metrics_over_trainsteps_checkpoint" in callbacks_cfg: | |
print( | |
"Caution: Saving checkpoints every n train steps without deleting. This might require some free space." | |
) | |
default_metrics_over_trainsteps_ckpt_dict = { | |
"metrics_over_trainsteps_checkpoint": { | |
"target": "pytorch_lightning.callbacks.ModelCheckpoint", | |
"params": { | |
"dirpath": os.path.join(ckptdir, "trainstep_checkpoints"), | |
"filename": "{epoch:06}-{step:09}", | |
"verbose": True, | |
"save_top_k": -1, | |
"every_n_train_steps": 10000, | |
"save_weights_only": True, | |
}, | |
} | |
} | |
default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict) | |
callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg) | |
if "ignore_keys_callback" in callbacks_cfg and ckpt_resume_path is not None: | |
callbacks_cfg.ignore_keys_callback.params["ckpt_path"] = ckpt_resume_path | |
elif "ignore_keys_callback" in callbacks_cfg: | |
del callbacks_cfg["ignore_keys_callback"] | |
trainer_kwargs["callbacks"] = [ | |
instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg | |
] | |
if not "plugins" in trainer_kwargs: | |
trainer_kwargs["plugins"] = list() | |
# cmd line trainer args (which are in trainer_opt) have always priority over config-trainer-args (which are in trainer_kwargs) | |
trainer_opt = vars(trainer_opt) | |
trainer_kwargs = { | |
key: val for key, val in trainer_kwargs.items() if key not in trainer_opt | |
} | |
trainer = Trainer(**trainer_opt, **trainer_kwargs) | |
trainer.logdir = logdir ### | |
# data | |
data = instantiate_from_config(config.data) | |
# NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html | |
# calling these ourselves should not be necessary but it is. | |
# lightning still takes care of proper multiprocessing though | |
data.prepare_data() | |
# data.setup() | |
print("#### Data #####") | |
try: | |
for k in data.datasets: | |
print( | |
f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}" | |
) | |
except: | |
print("datasets not yet initialized.") | |
# configure learning rate | |
if "batch_size" in config.data.params: | |
bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate | |
else: | |
bs, base_lr = ( | |
config.data.params.train.loader.batch_size, | |
config.model.base_learning_rate, | |
) | |
if not cpu: | |
ngpu = len(lightning_config.trainer.devices.strip(",").split(",")) | |
else: | |
ngpu = 1 | |
if "accumulate_grad_batches" in lightning_config.trainer: | |
accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches | |
else: | |
accumulate_grad_batches = 1 | |
print(f"accumulate_grad_batches = {accumulate_grad_batches}") | |
lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches | |
if opt.scale_lr: | |
model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr | |
print( | |
"Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format( | |
model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr | |
) | |
) | |
else: | |
model.learning_rate = base_lr | |
print("++++ NOT USING LR SCALING ++++") | |
print(f"Setting learning rate to {model.learning_rate:.2e}") | |
# allow checkpointing via USR1 | |
def melk(*args, **kwargs): | |
# run all checkpoint hooks | |
if trainer.global_rank == 0: | |
print("Summoning checkpoint.") | |
if melk_ckpt_name is None: | |
ckpt_path = os.path.join(ckptdir, "last.ckpt") | |
else: | |
ckpt_path = os.path.join(ckptdir, melk_ckpt_name) | |
trainer.save_checkpoint(ckpt_path) | |
def divein(*args, **kwargs): | |
if trainer.global_rank == 0: | |
import pudb | |
pudb.set_trace() | |
import signal | |
signal.signal(signal.SIGUSR1, melk) | |
signal.signal(signal.SIGUSR2, divein) | |
# run | |
if opt.train: | |
try: | |
trainer.fit(model, data, ckpt_path=ckpt_resume_path) | |
except Exception: | |
if not opt.debug: | |
melk() | |
raise | |
if not opt.no_test and not trainer.interrupted: | |
trainer.test(model, data) | |
except RuntimeError as err: | |
if MULTINODE_HACKS: | |
import datetime | |
import os | |
import socket | |
import requests | |
device = os.environ.get("CUDA_VISIBLE_DEVICES", "?") | |
hostname = socket.gethostname() | |
ts = datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S") | |
resp = requests.get("http://169.254.169.254/latest/meta-data/instance-id") | |
print( | |
f"ERROR at {ts} on {hostname}/{resp.text} (CUDA_VISIBLE_DEVICES={device}): {type(err).__name__}: {err}", | |
flush=True, | |
) | |
raise err | |
except Exception: | |
if opt.debug and trainer.global_rank == 0: | |
try: | |
import pudb as debugger | |
except ImportError: | |
import pdb as debugger | |
debugger.post_mortem() | |
raise | |
finally: | |
# move newly created debug project to debug_runs | |
if opt.debug and not opt.resume and trainer.global_rank == 0: | |
dst, name = os.path.split(logdir) | |
dst = os.path.join(dst, "debug_runs", name) | |
os.makedirs(os.path.split(dst)[0], exist_ok=True) | |
os.rename(logdir, dst) | |
if opt.wandb: | |
wandb.finish() | |
# if trainer.global_rank == 0: | |
# print(trainer.profiler.summary()) | |