multimodalart HF staff commited on
Commit
422da78
·
1 Parent(s): ef21785

Delete main.py

Browse files
Files changed (1) hide show
  1. main.py +0 -943
main.py DELETED
@@ -1,943 +0,0 @@
1
- import argparse
2
- import datetime
3
- import glob
4
- import inspect
5
- import os
6
- import sys
7
- from inspect import Parameter
8
- from typing import Union
9
-
10
- import numpy as np
11
- import pytorch_lightning as pl
12
- import torch
13
- import torchvision
14
- import wandb
15
- from matplotlib import pyplot as plt
16
- from natsort import natsorted
17
- from omegaconf import OmegaConf
18
- from packaging import version
19
- from PIL import Image
20
- from pytorch_lightning import seed_everything
21
- from pytorch_lightning.callbacks import Callback
22
- from pytorch_lightning.loggers import WandbLogger
23
- from pytorch_lightning.trainer import Trainer
24
- from pytorch_lightning.utilities import rank_zero_only
25
-
26
- from sgm.util import exists, instantiate_from_config, isheatmap
27
-
28
- MULTINODE_HACKS = True
29
-
30
-
31
- def default_trainer_args():
32
- argspec = dict(inspect.signature(Trainer.__init__).parameters)
33
- argspec.pop("self")
34
- default_args = {
35
- param: argspec[param].default
36
- for param in argspec
37
- if argspec[param] != Parameter.empty
38
- }
39
- return default_args
40
-
41
-
42
- def get_parser(**parser_kwargs):
43
- def str2bool(v):
44
- if isinstance(v, bool):
45
- return v
46
- if v.lower() in ("yes", "true", "t", "y", "1"):
47
- return True
48
- elif v.lower() in ("no", "false", "f", "n", "0"):
49
- return False
50
- else:
51
- raise argparse.ArgumentTypeError("Boolean value expected.")
52
-
53
- parser = argparse.ArgumentParser(**parser_kwargs)
54
- parser.add_argument(
55
- "-n",
56
- "--name",
57
- type=str,
58
- const=True,
59
- default="",
60
- nargs="?",
61
- help="postfix for logdir",
62
- )
63
- parser.add_argument(
64
- "--no_date",
65
- type=str2bool,
66
- nargs="?",
67
- const=True,
68
- default=False,
69
- help="if True, skip date generation for logdir and only use naming via opt.base or opt.name (+ opt.postfix, optionally)",
70
- )
71
- parser.add_argument(
72
- "-r",
73
- "--resume",
74
- type=str,
75
- const=True,
76
- default="",
77
- nargs="?",
78
- help="resume from logdir or checkpoint in logdir",
79
- )
80
- parser.add_argument(
81
- "-b",
82
- "--base",
83
- nargs="*",
84
- metavar="base_config.yaml",
85
- help="paths to base configs. Loaded from left-to-right. "
86
- "Parameters can be overwritten or added with command-line options of the form `--key value`.",
87
- default=list(),
88
- )
89
- parser.add_argument(
90
- "-t",
91
- "--train",
92
- type=str2bool,
93
- const=True,
94
- default=True,
95
- nargs="?",
96
- help="train",
97
- )
98
- parser.add_argument(
99
- "--no-test",
100
- type=str2bool,
101
- const=True,
102
- default=False,
103
- nargs="?",
104
- help="disable test",
105
- )
106
- parser.add_argument(
107
- "-p", "--project", help="name of new or path to existing project"
108
- )
109
- parser.add_argument(
110
- "-d",
111
- "--debug",
112
- type=str2bool,
113
- nargs="?",
114
- const=True,
115
- default=False,
116
- help="enable post-mortem debugging",
117
- )
118
- parser.add_argument(
119
- "-s",
120
- "--seed",
121
- type=int,
122
- default=23,
123
- help="seed for seed_everything",
124
- )
125
- parser.add_argument(
126
- "-f",
127
- "--postfix",
128
- type=str,
129
- default="",
130
- help="post-postfix for default name",
131
- )
132
- parser.add_argument(
133
- "--projectname",
134
- type=str,
135
- default="stablediffusion",
136
- )
137
- parser.add_argument(
138
- "-l",
139
- "--logdir",
140
- type=str,
141
- default="logs",
142
- help="directory for logging dat shit",
143
- )
144
- parser.add_argument(
145
- "--scale_lr",
146
- type=str2bool,
147
- nargs="?",
148
- const=True,
149
- default=False,
150
- help="scale base-lr by ngpu * batch_size * n_accumulate",
151
- )
152
- parser.add_argument(
153
- "--legacy_naming",
154
- type=str2bool,
155
- nargs="?",
156
- const=True,
157
- default=False,
158
- help="name run based on config file name if true, else by whole path",
159
- )
160
- parser.add_argument(
161
- "--enable_tf32",
162
- type=str2bool,
163
- nargs="?",
164
- const=True,
165
- default=False,
166
- help="enables the TensorFloat32 format both for matmuls and cuDNN for pytorch 1.12",
167
- )
168
- parser.add_argument(
169
- "--startup",
170
- type=str,
171
- default=None,
172
- help="Startuptime from distributed script",
173
- )
174
- parser.add_argument(
175
- "--wandb",
176
- type=str2bool,
177
- nargs="?",
178
- const=True,
179
- default=False, # TODO: later default to True
180
- help="log to wandb",
181
- )
182
- parser.add_argument(
183
- "--no_base_name",
184
- type=str2bool,
185
- nargs="?",
186
- const=True,
187
- default=False, # TODO: later default to True
188
- help="log to wandb",
189
- )
190
- if version.parse(torch.__version__) >= version.parse("2.0.0"):
191
- parser.add_argument(
192
- "--resume_from_checkpoint",
193
- type=str,
194
- default=None,
195
- help="single checkpoint file to resume from",
196
- )
197
- default_args = default_trainer_args()
198
- for key in default_args:
199
- parser.add_argument("--" + key, default=default_args[key])
200
- return parser
201
-
202
-
203
- def get_checkpoint_name(logdir):
204
- ckpt = os.path.join(logdir, "checkpoints", "last**.ckpt")
205
- ckpt = natsorted(glob.glob(ckpt))
206
- print('available "last" checkpoints:')
207
- print(ckpt)
208
- if len(ckpt) > 1:
209
- print("got most recent checkpoint")
210
- ckpt = sorted(ckpt, key=lambda x: os.path.getmtime(x))[-1]
211
- print(f"Most recent ckpt is {ckpt}")
212
- with open(os.path.join(logdir, "most_recent_ckpt.txt"), "w") as f:
213
- f.write(ckpt + "\n")
214
- try:
215
- version = int(ckpt.split("/")[-1].split("-v")[-1].split(".")[0])
216
- except Exception as e:
217
- print("version confusion but not bad")
218
- print(e)
219
- version = 1
220
- # version = last_version + 1
221
- else:
222
- # in this case, we only have one "last.ckpt"
223
- ckpt = ckpt[0]
224
- version = 1
225
- melk_ckpt_name = f"last-v{version}.ckpt"
226
- print(f"Current melk ckpt name: {melk_ckpt_name}")
227
- return ckpt, melk_ckpt_name
228
-
229
-
230
- class SetupCallback(Callback):
231
- def __init__(
232
- self,
233
- resume,
234
- now,
235
- logdir,
236
- ckptdir,
237
- cfgdir,
238
- config,
239
- lightning_config,
240
- debug,
241
- ckpt_name=None,
242
- ):
243
- super().__init__()
244
- self.resume = resume
245
- self.now = now
246
- self.logdir = logdir
247
- self.ckptdir = ckptdir
248
- self.cfgdir = cfgdir
249
- self.config = config
250
- self.lightning_config = lightning_config
251
- self.debug = debug
252
- self.ckpt_name = ckpt_name
253
-
254
- def on_exception(self, trainer: pl.Trainer, pl_module, exception):
255
- if not self.debug and trainer.global_rank == 0:
256
- print("Summoning checkpoint.")
257
- if self.ckpt_name is None:
258
- ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
259
- else:
260
- ckpt_path = os.path.join(self.ckptdir, self.ckpt_name)
261
- trainer.save_checkpoint(ckpt_path)
262
-
263
- def on_fit_start(self, trainer, pl_module):
264
- if trainer.global_rank == 0:
265
- # Create logdirs and save configs
266
- os.makedirs(self.logdir, exist_ok=True)
267
- os.makedirs(self.ckptdir, exist_ok=True)
268
- os.makedirs(self.cfgdir, exist_ok=True)
269
-
270
- if "callbacks" in self.lightning_config:
271
- if (
272
- "metrics_over_trainsteps_checkpoint"
273
- in self.lightning_config["callbacks"]
274
- ):
275
- os.makedirs(
276
- os.path.join(self.ckptdir, "trainstep_checkpoints"),
277
- exist_ok=True,
278
- )
279
- print("Project config")
280
- print(OmegaConf.to_yaml(self.config))
281
- if MULTINODE_HACKS:
282
- import time
283
-
284
- time.sleep(5)
285
- OmegaConf.save(
286
- self.config,
287
- os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)),
288
- )
289
-
290
- print("Lightning config")
291
- print(OmegaConf.to_yaml(self.lightning_config))
292
- OmegaConf.save(
293
- OmegaConf.create({"lightning": self.lightning_config}),
294
- os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)),
295
- )
296
-
297
- else:
298
- # ModelCheckpoint callback created log directory --- remove it
299
- if not MULTINODE_HACKS and not self.resume and os.path.exists(self.logdir):
300
- dst, name = os.path.split(self.logdir)
301
- dst = os.path.join(dst, "child_runs", name)
302
- os.makedirs(os.path.split(dst)[0], exist_ok=True)
303
- try:
304
- os.rename(self.logdir, dst)
305
- except FileNotFoundError:
306
- pass
307
-
308
-
309
- class ImageLogger(Callback):
310
- def __init__(
311
- self,
312
- batch_frequency,
313
- max_images,
314
- clamp=True,
315
- increase_log_steps=True,
316
- rescale=True,
317
- disabled=False,
318
- log_on_batch_idx=False,
319
- log_first_step=False,
320
- log_images_kwargs=None,
321
- log_before_first_step=False,
322
- enable_autocast=True,
323
- ):
324
- super().__init__()
325
- self.enable_autocast = enable_autocast
326
- self.rescale = rescale
327
- self.batch_freq = batch_frequency
328
- self.max_images = max_images
329
- self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
330
- if not increase_log_steps:
331
- self.log_steps = [self.batch_freq]
332
- self.clamp = clamp
333
- self.disabled = disabled
334
- self.log_on_batch_idx = log_on_batch_idx
335
- self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
336
- self.log_first_step = log_first_step
337
- self.log_before_first_step = log_before_first_step
338
-
339
- @rank_zero_only
340
- def log_local(
341
- self,
342
- save_dir,
343
- split,
344
- images,
345
- global_step,
346
- current_epoch,
347
- batch_idx,
348
- pl_module: Union[None, pl.LightningModule] = None,
349
- ):
350
- root = os.path.join(save_dir, "images", split)
351
- for k in images:
352
- if isheatmap(images[k]):
353
- fig, ax = plt.subplots()
354
- ax = ax.matshow(
355
- images[k].cpu().numpy(), cmap="hot", interpolation="lanczos"
356
- )
357
- plt.colorbar(ax)
358
- plt.axis("off")
359
-
360
- filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
361
- k, global_step, current_epoch, batch_idx
362
- )
363
- os.makedirs(root, exist_ok=True)
364
- path = os.path.join(root, filename)
365
- plt.savefig(path)
366
- plt.close()
367
- # TODO: support wandb
368
- else:
369
- grid = torchvision.utils.make_grid(images[k], nrow=4)
370
- if self.rescale:
371
- grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
372
- grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
373
- grid = grid.numpy()
374
- grid = (grid * 255).astype(np.uint8)
375
- filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
376
- k, global_step, current_epoch, batch_idx
377
- )
378
- path = os.path.join(root, filename)
379
- os.makedirs(os.path.split(path)[0], exist_ok=True)
380
- img = Image.fromarray(grid)
381
- img.save(path)
382
- if exists(pl_module):
383
- assert isinstance(
384
- pl_module.logger, WandbLogger
385
- ), "logger_log_image only supports WandbLogger currently"
386
- pl_module.logger.log_image(
387
- key=f"{split}/{k}",
388
- images=[
389
- img,
390
- ],
391
- step=pl_module.global_step,
392
- )
393
-
394
- @rank_zero_only
395
- def log_img(self, pl_module, batch, batch_idx, split="train"):
396
- check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
397
- if (
398
- self.check_frequency(check_idx)
399
- and hasattr(pl_module, "log_images") # batch_idx % self.batch_freq == 0
400
- and callable(pl_module.log_images)
401
- and
402
- # batch_idx > 5 and
403
- self.max_images > 0
404
- ):
405
- logger = type(pl_module.logger)
406
- is_train = pl_module.training
407
- if is_train:
408
- pl_module.eval()
409
-
410
- gpu_autocast_kwargs = {
411
- "enabled": self.enable_autocast, # torch.is_autocast_enabled(),
412
- "dtype": torch.get_autocast_gpu_dtype(),
413
- "cache_enabled": torch.is_autocast_cache_enabled(),
414
- }
415
- with torch.no_grad(), torch.cuda.amp.autocast(**gpu_autocast_kwargs):
416
- images = pl_module.log_images(
417
- batch, split=split, **self.log_images_kwargs
418
- )
419
-
420
- for k in images:
421
- N = min(images[k].shape[0], self.max_images)
422
- if not isheatmap(images[k]):
423
- images[k] = images[k][:N]
424
- if isinstance(images[k], torch.Tensor):
425
- images[k] = images[k].detach().float().cpu()
426
- if self.clamp and not isheatmap(images[k]):
427
- images[k] = torch.clamp(images[k], -1.0, 1.0)
428
-
429
- self.log_local(
430
- pl_module.logger.save_dir,
431
- split,
432
- images,
433
- pl_module.global_step,
434
- pl_module.current_epoch,
435
- batch_idx,
436
- pl_module=pl_module
437
- if isinstance(pl_module.logger, WandbLogger)
438
- else None,
439
- )
440
-
441
- if is_train:
442
- pl_module.train()
443
-
444
- def check_frequency(self, check_idx):
445
- if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and (
446
- check_idx > 0 or self.log_first_step
447
- ):
448
- try:
449
- self.log_steps.pop(0)
450
- except IndexError as e:
451
- print(e)
452
- pass
453
- return True
454
- return False
455
-
456
- @rank_zero_only
457
- def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
458
- if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
459
- self.log_img(pl_module, batch, batch_idx, split="train")
460
-
461
- @rank_zero_only
462
- def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
463
- if self.log_before_first_step and pl_module.global_step == 0:
464
- print(f"{self.__class__.__name__}: logging before training")
465
- self.log_img(pl_module, batch, batch_idx, split="train")
466
-
467
- @rank_zero_only
468
- def on_validation_batch_end(
469
- self, trainer, pl_module, outputs, batch, batch_idx, *args, **kwargs
470
- ):
471
- if not self.disabled and pl_module.global_step > 0:
472
- self.log_img(pl_module, batch, batch_idx, split="val")
473
- if hasattr(pl_module, "calibrate_grad_norm"):
474
- if (
475
- pl_module.calibrate_grad_norm and batch_idx % 25 == 0
476
- ) and batch_idx > 0:
477
- self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
478
-
479
-
480
- @rank_zero_only
481
- def init_wandb(save_dir, opt, config, group_name, name_str):
482
- print(f"setting WANDB_DIR to {save_dir}")
483
- os.makedirs(save_dir, exist_ok=True)
484
-
485
- os.environ["WANDB_DIR"] = save_dir
486
- if opt.debug:
487
- wandb.init(project=opt.projectname, mode="offline", group=group_name)
488
- else:
489
- wandb.init(
490
- project=opt.projectname,
491
- config=config,
492
- settings=wandb.Settings(code_dir="./sgm"),
493
- group=group_name,
494
- name=name_str,
495
- )
496
-
497
-
498
- if __name__ == "__main__":
499
- # custom parser to specify config files, train, test and debug mode,
500
- # postfix, resume.
501
- # `--key value` arguments are interpreted as arguments to the trainer.
502
- # `nested.key=value` arguments are interpreted as config parameters.
503
- # configs are merged from left-to-right followed by command line parameters.
504
-
505
- # model:
506
- # base_learning_rate: float
507
- # target: path to lightning module
508
- # params:
509
- # key: value
510
- # data:
511
- # target: main.DataModuleFromConfig
512
- # params:
513
- # batch_size: int
514
- # wrap: bool
515
- # train:
516
- # target: path to train dataset
517
- # params:
518
- # key: value
519
- # validation:
520
- # target: path to validation dataset
521
- # params:
522
- # key: value
523
- # test:
524
- # target: path to test dataset
525
- # params:
526
- # key: value
527
- # lightning: (optional, has sane defaults and can be specified on cmdline)
528
- # trainer:
529
- # additional arguments to trainer
530
- # logger:
531
- # logger to instantiate
532
- # modelcheckpoint:
533
- # modelcheckpoint to instantiate
534
- # callbacks:
535
- # callback1:
536
- # target: importpath
537
- # params:
538
- # key: value
539
-
540
- now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
541
-
542
- # add cwd for convenience and to make classes in this file available when
543
- # running as `python main.py`
544
- # (in particular `main.DataModuleFromConfig`)
545
- sys.path.append(os.getcwd())
546
-
547
- parser = get_parser()
548
-
549
- opt, unknown = parser.parse_known_args()
550
-
551
- if opt.name and opt.resume:
552
- raise ValueError(
553
- "-n/--name and -r/--resume cannot be specified both."
554
- "If you want to resume training in a new log folder, "
555
- "use -n/--name in combination with --resume_from_checkpoint"
556
- )
557
- melk_ckpt_name = None
558
- name = None
559
- if opt.resume:
560
- if not os.path.exists(opt.resume):
561
- raise ValueError("Cannot find {}".format(opt.resume))
562
- if os.path.isfile(opt.resume):
563
- paths = opt.resume.split("/")
564
- # idx = len(paths)-paths[::-1].index("logs")+1
565
- # logdir = "/".join(paths[:idx])
566
- logdir = "/".join(paths[:-2])
567
- ckpt = opt.resume
568
- _, melk_ckpt_name = get_checkpoint_name(logdir)
569
- else:
570
- assert os.path.isdir(opt.resume), opt.resume
571
- logdir = opt.resume.rstrip("/")
572
- ckpt, melk_ckpt_name = get_checkpoint_name(logdir)
573
-
574
- print("#" * 100)
575
- print(f'Resuming from checkpoint "{ckpt}"')
576
- print("#" * 100)
577
-
578
- opt.resume_from_checkpoint = ckpt
579
- base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
580
- opt.base = base_configs + opt.base
581
- _tmp = logdir.split("/")
582
- nowname = _tmp[-1]
583
- else:
584
- if opt.name:
585
- name = "_" + opt.name
586
- elif opt.base:
587
- if opt.no_base_name:
588
- name = ""
589
- else:
590
- if opt.legacy_naming:
591
- cfg_fname = os.path.split(opt.base[0])[-1]
592
- cfg_name = os.path.splitext(cfg_fname)[0]
593
- else:
594
- assert "configs" in os.path.split(opt.base[0])[0], os.path.split(
595
- opt.base[0]
596
- )[0]
597
- cfg_path = os.path.split(opt.base[0])[0].split(os.sep)[
598
- os.path.split(opt.base[0])[0].split(os.sep).index("configs")
599
- + 1 :
600
- ] # cut away the first one (we assert all configs are in "configs")
601
- cfg_name = os.path.splitext(os.path.split(opt.base[0])[-1])[0]
602
- cfg_name = "-".join(cfg_path) + f"-{cfg_name}"
603
- name = "_" + cfg_name
604
- else:
605
- name = ""
606
- if not opt.no_date:
607
- nowname = now + name + opt.postfix
608
- else:
609
- nowname = name + opt.postfix
610
- if nowname.startswith("_"):
611
- nowname = nowname[1:]
612
- logdir = os.path.join(opt.logdir, nowname)
613
- print(f"LOGDIR: {logdir}")
614
-
615
- ckptdir = os.path.join(logdir, "checkpoints")
616
- cfgdir = os.path.join(logdir, "configs")
617
- seed_everything(opt.seed, workers=True)
618
-
619
- # move before model init, in case a torch.compile(...) is called somewhere
620
- if opt.enable_tf32:
621
- # pt_version = version.parse(torch.__version__)
622
- torch.backends.cuda.matmul.allow_tf32 = True
623
- torch.backends.cudnn.allow_tf32 = True
624
- print(f"Enabling TF32 for PyTorch {torch.__version__}")
625
- else:
626
- print(f"Using default TF32 settings for PyTorch {torch.__version__}:")
627
- print(
628
- f"torch.backends.cuda.matmul.allow_tf32={torch.backends.cuda.matmul.allow_tf32}"
629
- )
630
- print(f"torch.backends.cudnn.allow_tf32={torch.backends.cudnn.allow_tf32}")
631
-
632
- try:
633
- # init and save configs
634
- configs = [OmegaConf.load(cfg) for cfg in opt.base]
635
- cli = OmegaConf.from_dotlist(unknown)
636
- config = OmegaConf.merge(*configs, cli)
637
- lightning_config = config.pop("lightning", OmegaConf.create())
638
- # merge trainer cli with config
639
- trainer_config = lightning_config.get("trainer", OmegaConf.create())
640
-
641
- # default to gpu
642
- trainer_config["accelerator"] = "gpu"
643
- #
644
- standard_args = default_trainer_args()
645
- for k in standard_args:
646
- if getattr(opt, k) != standard_args[k]:
647
- trainer_config[k] = getattr(opt, k)
648
-
649
- ckpt_resume_path = opt.resume_from_checkpoint
650
-
651
- if not "devices" in trainer_config and trainer_config["accelerator"] != "gpu":
652
- del trainer_config["accelerator"]
653
- cpu = True
654
- else:
655
- gpuinfo = trainer_config["devices"]
656
- print(f"Running on GPUs {gpuinfo}")
657
- cpu = False
658
- trainer_opt = argparse.Namespace(**trainer_config)
659
- lightning_config.trainer = trainer_config
660
-
661
- # model
662
- model = instantiate_from_config(config.model)
663
-
664
- # trainer and callbacks
665
- trainer_kwargs = dict()
666
-
667
- # default logger configs
668
- default_logger_cfgs = {
669
- "wandb": {
670
- "target": "pytorch_lightning.loggers.WandbLogger",
671
- "params": {
672
- "name": nowname,
673
- # "save_dir": logdir,
674
- "offline": opt.debug,
675
- "id": nowname,
676
- "project": opt.projectname,
677
- "log_model": False,
678
- # "dir": logdir,
679
- },
680
- },
681
- "csv": {
682
- "target": "pytorch_lightning.loggers.CSVLogger",
683
- "params": {
684
- "name": "testtube", # hack for sbord fanatics
685
- "save_dir": logdir,
686
- },
687
- },
688
- }
689
- default_logger_cfg = default_logger_cfgs["wandb" if opt.wandb else "csv"]
690
- if opt.wandb:
691
- # TODO change once leaving "swiffer" config directory
692
- try:
693
- group_name = nowname.split(now)[-1].split("-")[1]
694
- except:
695
- group_name = nowname
696
- default_logger_cfg["params"]["group"] = group_name
697
- init_wandb(
698
- os.path.join(os.getcwd(), logdir),
699
- opt=opt,
700
- group_name=group_name,
701
- config=config,
702
- name_str=nowname,
703
- )
704
- if "logger" in lightning_config:
705
- logger_cfg = lightning_config.logger
706
- else:
707
- logger_cfg = OmegaConf.create()
708
- logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
709
- trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
710
-
711
- # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
712
- # specify which metric is used to determine best models
713
- default_modelckpt_cfg = {
714
- "target": "pytorch_lightning.callbacks.ModelCheckpoint",
715
- "params": {
716
- "dirpath": ckptdir,
717
- "filename": "{epoch:06}",
718
- "verbose": True,
719
- "save_last": True,
720
- },
721
- }
722
- if hasattr(model, "monitor"):
723
- print(f"Monitoring {model.monitor} as checkpoint metric.")
724
- default_modelckpt_cfg["params"]["monitor"] = model.monitor
725
- default_modelckpt_cfg["params"]["save_top_k"] = 3
726
-
727
- if "modelcheckpoint" in lightning_config:
728
- modelckpt_cfg = lightning_config.modelcheckpoint
729
- else:
730
- modelckpt_cfg = OmegaConf.create()
731
- modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
732
- print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}")
733
-
734
- # https://pytorch-lightning.readthedocs.io/en/stable/extensions/strategy.html
735
- # default to ddp if not further specified
736
- default_strategy_config = {"target": "pytorch_lightning.strategies.DDPStrategy"}
737
-
738
- if "strategy" in lightning_config:
739
- strategy_cfg = lightning_config.strategy
740
- else:
741
- strategy_cfg = OmegaConf.create()
742
- default_strategy_config["params"] = {
743
- "find_unused_parameters": False,
744
- # "static_graph": True,
745
- # "ddp_comm_hook": default.fp16_compress_hook # TODO: experiment with this, also for DDPSharded
746
- }
747
- strategy_cfg = OmegaConf.merge(default_strategy_config, strategy_cfg)
748
- print(
749
- f"strategy config: \n ++++++++++++++ \n {strategy_cfg} \n ++++++++++++++ "
750
- )
751
- trainer_kwargs["strategy"] = instantiate_from_config(strategy_cfg)
752
-
753
- # add callback which sets up log directory
754
- default_callbacks_cfg = {
755
- "setup_callback": {
756
- "target": "main.SetupCallback",
757
- "params": {
758
- "resume": opt.resume,
759
- "now": now,
760
- "logdir": logdir,
761
- "ckptdir": ckptdir,
762
- "cfgdir": cfgdir,
763
- "config": config,
764
- "lightning_config": lightning_config,
765
- "debug": opt.debug,
766
- "ckpt_name": melk_ckpt_name,
767
- },
768
- },
769
- "image_logger": {
770
- "target": "main.ImageLogger",
771
- "params": {"batch_frequency": 1000, "max_images": 4, "clamp": True},
772
- },
773
- "learning_rate_logger": {
774
- "target": "pytorch_lightning.callbacks.LearningRateMonitor",
775
- "params": {
776
- "logging_interval": "step",
777
- # "log_momentum": True
778
- },
779
- },
780
- }
781
- if version.parse(pl.__version__) >= version.parse("1.4.0"):
782
- default_callbacks_cfg.update({"checkpoint_callback": modelckpt_cfg})
783
-
784
- if "callbacks" in lightning_config:
785
- callbacks_cfg = lightning_config.callbacks
786
- else:
787
- callbacks_cfg = OmegaConf.create()
788
-
789
- if "metrics_over_trainsteps_checkpoint" in callbacks_cfg:
790
- print(
791
- "Caution: Saving checkpoints every n train steps without deleting. This might require some free space."
792
- )
793
- default_metrics_over_trainsteps_ckpt_dict = {
794
- "metrics_over_trainsteps_checkpoint": {
795
- "target": "pytorch_lightning.callbacks.ModelCheckpoint",
796
- "params": {
797
- "dirpath": os.path.join(ckptdir, "trainstep_checkpoints"),
798
- "filename": "{epoch:06}-{step:09}",
799
- "verbose": True,
800
- "save_top_k": -1,
801
- "every_n_train_steps": 10000,
802
- "save_weights_only": True,
803
- },
804
- }
805
- }
806
- default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)
807
-
808
- callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
809
- if "ignore_keys_callback" in callbacks_cfg and ckpt_resume_path is not None:
810
- callbacks_cfg.ignore_keys_callback.params["ckpt_path"] = ckpt_resume_path
811
- elif "ignore_keys_callback" in callbacks_cfg:
812
- del callbacks_cfg["ignore_keys_callback"]
813
-
814
- trainer_kwargs["callbacks"] = [
815
- instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg
816
- ]
817
- if not "plugins" in trainer_kwargs:
818
- trainer_kwargs["plugins"] = list()
819
-
820
- # cmd line trainer args (which are in trainer_opt) have always priority over config-trainer-args (which are in trainer_kwargs)
821
- trainer_opt = vars(trainer_opt)
822
- trainer_kwargs = {
823
- key: val for key, val in trainer_kwargs.items() if key not in trainer_opt
824
- }
825
- trainer = Trainer(**trainer_opt, **trainer_kwargs)
826
-
827
- trainer.logdir = logdir ###
828
-
829
- # data
830
- data = instantiate_from_config(config.data)
831
- # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
832
- # calling these ourselves should not be necessary but it is.
833
- # lightning still takes care of proper multiprocessing though
834
- data.prepare_data()
835
- # data.setup()
836
- print("#### Data #####")
837
- try:
838
- for k in data.datasets:
839
- print(
840
- f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}"
841
- )
842
- except:
843
- print("datasets not yet initialized.")
844
-
845
- # configure learning rate
846
- if "batch_size" in config.data.params:
847
- bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
848
- else:
849
- bs, base_lr = (
850
- config.data.params.train.loader.batch_size,
851
- config.model.base_learning_rate,
852
- )
853
- if not cpu:
854
- ngpu = len(lightning_config.trainer.devices.strip(",").split(","))
855
- else:
856
- ngpu = 1
857
- if "accumulate_grad_batches" in lightning_config.trainer:
858
- accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
859
- else:
860
- accumulate_grad_batches = 1
861
- print(f"accumulate_grad_batches = {accumulate_grad_batches}")
862
- lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
863
- if opt.scale_lr:
864
- model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
865
- print(
866
- "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
867
- model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr
868
- )
869
- )
870
- else:
871
- model.learning_rate = base_lr
872
- print("++++ NOT USING LR SCALING ++++")
873
- print(f"Setting learning rate to {model.learning_rate:.2e}")
874
-
875
- # allow checkpointing via USR1
876
- def melk(*args, **kwargs):
877
- # run all checkpoint hooks
878
- if trainer.global_rank == 0:
879
- print("Summoning checkpoint.")
880
- if melk_ckpt_name is None:
881
- ckpt_path = os.path.join(ckptdir, "last.ckpt")
882
- else:
883
- ckpt_path = os.path.join(ckptdir, melk_ckpt_name)
884
- trainer.save_checkpoint(ckpt_path)
885
-
886
- def divein(*args, **kwargs):
887
- if trainer.global_rank == 0:
888
- import pudb
889
-
890
- pudb.set_trace()
891
-
892
- import signal
893
-
894
- signal.signal(signal.SIGUSR1, melk)
895
- signal.signal(signal.SIGUSR2, divein)
896
-
897
- # run
898
- if opt.train:
899
- try:
900
- trainer.fit(model, data, ckpt_path=ckpt_resume_path)
901
- except Exception:
902
- if not opt.debug:
903
- melk()
904
- raise
905
- if not opt.no_test and not trainer.interrupted:
906
- trainer.test(model, data)
907
- except RuntimeError as err:
908
- if MULTINODE_HACKS:
909
- import datetime
910
- import os
911
- import socket
912
-
913
- import requests
914
-
915
- device = os.environ.get("CUDA_VISIBLE_DEVICES", "?")
916
- hostname = socket.gethostname()
917
- ts = datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")
918
- resp = requests.get("http://169.254.169.254/latest/meta-data/instance-id")
919
- print(
920
- f"ERROR at {ts} on {hostname}/{resp.text} (CUDA_VISIBLE_DEVICES={device}): {type(err).__name__}: {err}",
921
- flush=True,
922
- )
923
- raise err
924
- except Exception:
925
- if opt.debug and trainer.global_rank == 0:
926
- try:
927
- import pudb as debugger
928
- except ImportError:
929
- import pdb as debugger
930
- debugger.post_mortem()
931
- raise
932
- finally:
933
- # move newly created debug project to debug_runs
934
- if opt.debug and not opt.resume and trainer.global_rank == 0:
935
- dst, name = os.path.split(logdir)
936
- dst = os.path.join(dst, "debug_runs", name)
937
- os.makedirs(os.path.split(dst)[0], exist_ok=True)
938
- os.rename(logdir, dst)
939
-
940
- if opt.wandb:
941
- wandb.finish()
942
- # if trainer.global_rank == 0:
943
- # print(trainer.profiler.summary())