import os.path as osp
import warnings

warnings.filterwarnings("ignore")

from pathlib import Path

import hydra
import pytorch_lightning as pl
import torch
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.utilities import rank_zero_only

from dataset import UavMapDatasetModule
from logger import EXPERIMENTS_PATH, logger, pl_logger
from module import GenericModule

# print(osp.join(osp.dirname(__file__), "conf"))


class CleanProgressBar(pl.callbacks.TQDMProgressBar):
    """Progress bar that hides uninformative metrics."""

    def get_metrics(self, trainer, model):
        items = super().get_metrics(trainer, model)
        items.pop("v_num", None)  # don't show the version number
        items.pop("loss", None)
        return items


class SeedingCallback(pl.callbacks.Callback):
    """Re-seed all RNGs at the start of every epoch for reproducibility."""

    def on_epoch_start_(self, trainer, module):
        seed = module.cfg.experiment.seed
        is_overfit = module.cfg.training.trainer.get("overfit_batches", 0) > 0
        if trainer.training and not is_overfit:
            # Use a different seed per training epoch so the data order varies.
            seed = seed + trainer.current_epoch

        # Temporarily disable the logging (does not seem to work?)
        pl_logger.disabled = True
        try:
            pl.seed_everything(seed, workers=True)
        finally:
            pl_logger.disabled = False

    def on_train_epoch_start(self, *args, **kwargs):
        self.on_epoch_start_(*args, **kwargs)

    def on_validation_epoch_start(self, *args, **kwargs):
        self.on_epoch_start_(*args, **kwargs)

    def on_test_epoch_start(self, *args, **kwargs):
        self.on_epoch_start_(*args, **kwargs)


class ConsoleLogger(pl.callbacks.Callback):
    @rank_zero_only
    def on_train_epoch_start(self, trainer, module):
        logger.info(
            "New training epoch %d for experiment '%s'.",
            module.current_epoch,
            module.cfg.experiment.name,
        )

    # @rank_zero_only
    # def on_validation_epoch_end(self, trainer, module):
    #     results = {
    #         **dict(module.metrics_val.items()),
    #         **dict(module.losses_val.items()),
    #     }
    #     results = [f"{k} {v.compute():.3E}" for k, v in results.items()]
    #     logger.info(f'[Validation] {{{", ".join(results)}}}')


def find_last_checkpoint_path(experiment_dir):
    """Return the path of the last checkpoint in the experiment dir, if any."""
    cls = pl.callbacks.ModelCheckpoint
    path = osp.join(experiment_dir, cls.CHECKPOINT_NAME_LAST + cls.FILE_EXTENSION)
    if osp.exists(path):
        return path
    else:
        return None


def prepare_experiment_dir(experiment_dir, cfg, rank):
    """Create the experiment dir, save the config, and detect a resumable checkpoint."""
    config_path = osp.join(experiment_dir, "config.yaml")
    last_checkpoint_path = find_last_checkpoint_path(experiment_dir)
    if last_checkpoint_path is not None:
        if rank == 0:
            logger.info(
                "Resuming the training from checkpoint %s", last_checkpoint_path
            )
        if osp.exists(config_path):
            with open(config_path, "r") as fp:
                cfg_prev = OmegaConf.create(fp.read())
            # Refuse to resume if the relevant parts of the config have changed.
            compare_keys = ["experiment", "data", "model", "training"]
            if OmegaConf.masked_copy(cfg, compare_keys) != OmegaConf.masked_copy(
                cfg_prev, compare_keys
            ):
                raise ValueError(
                    "Attempting to resume training with a different config: "
                    f"{OmegaConf.masked_copy(cfg, compare_keys)} vs "
                    f"{OmegaConf.masked_copy(cfg_prev, compare_keys)}"
                )
    if rank == 0:
        Path(experiment_dir).mkdir(exist_ok=True, parents=True)
        with open(config_path, "w") as fp:
            OmegaConf.save(cfg, fp)
    return last_checkpoint_path


def train(cfg: DictConfig) -> None:
    torch.set_float32_matmul_precision("medium")
    OmegaConf.resolve(cfg)
    rank = rank_zero_only.rank

    if rank == 0:
        logger.info("Starting training with config:\n%s", OmegaConf.to_yaml(cfg))
    if cfg.experiment.gpus in (None, 0):
        logger.warning("Will train on CPU...")
        cfg.experiment.gpus = 0
    elif not torch.cuda.is_available():
        raise ValueError("Requested GPU but no NVIDIA drivers found.")
    pl.seed_everything(cfg.experiment.seed, workers=True)

    # Optionally initialize the weights from a pretrained checkpoint.
    init_checkpoint_path = cfg.training.get("finetune_from_checkpoint")
    if init_checkpoint_path is not None:
        logger.info("Initializing the model from checkpoint %s.", init_checkpoint_path)
        model = GenericModule.load_from_checkpoint(
            init_checkpoint_path, strict=True, find_best=False, cfg=cfg
        )
    else:
        model = GenericModule(cfg)
    if rank == 0:
        logger.info("Network:\n%s", model.model)

    experiment_dir = osp.join(EXPERIMENTS_PATH, cfg.experiment.name)
    last_checkpoint_path = prepare_experiment_dir(experiment_dir, cfg, rank)

    # Save a checkpoint at the end of every epoch and every 1000 training steps.
    checkpointing_epoch = pl.callbacks.ModelCheckpoint(
        dirpath=experiment_dir,
        filename="checkpoint-epoch-{epoch:02d}-loss-{loss/total/val:02f}",
        auto_insert_metric_name=False,
        save_last=True,
        every_n_epochs=1,
        save_on_train_epoch_end=True,
        verbose=True,
        **cfg.training.checkpointing,
    )
    checkpointing_step = pl.callbacks.ModelCheckpoint(
        dirpath=experiment_dir,
        filename="checkpoint-step-{step}-{loss/total/val:02f}",
        auto_insert_metric_name=False,
        save_last=True,
        every_n_train_steps=1000,
        verbose=True,
        **cfg.training.checkpointing,
    )
    checkpointing_step.CHECKPOINT_NAME_LAST = "last-step-checkpointing"

    # Create the EarlyStopping callback (currently not added to the callbacks list below).
    early_stopping_callback = EarlyStopping(
        monitor=cfg.training.checkpointing.monitor, patience=5
    )

    strategy = None
    if cfg.experiment.gpus > 1:
        strategy = pl.strategies.DDPStrategy(find_unused_parameters=False)
        # Split the per-node batch size and workers across the GPUs.
        for split in ["train", "val"]:
            cfg.data[split].batch_size = (
                cfg.data[split].batch_size // cfg.experiment.gpus
            )
            cfg.data[split].num_workers = int(
                (cfg.data[split].num_workers + cfg.experiment.gpus - 1)
                / cfg.experiment.gpus
            )

    # data = data_modules[cfg.data.get("name", "mapillary")](cfg.data)
    datamodule = UavMapDatasetModule(cfg.data)

    tb_args = {"name": cfg.experiment.name, "version": ""}
    tb = pl.loggers.TensorBoardLogger(EXPERIMENTS_PATH, **tb_args)

    callbacks = [
        checkpointing_epoch,
        checkpointing_step,
        # early_stopping_callback,
        pl.callbacks.LearningRateMonitor(),
        SeedingCallback(),
        CleanProgressBar(),
        ConsoleLogger(),
    ]
    if cfg.experiment.gpus > 0:
        callbacks.append(pl.callbacks.DeviceStatsMonitor())

    trainer = pl.Trainer(
        default_root_dir=experiment_dir,
        detect_anomaly=False,
        # strategy=ddp_find_unused_parameters_true,
        enable_model_summary=True,
        sync_batchnorm=True,
        enable_checkpointing=True,
        logger=tb,
        callbacks=callbacks,
        strategy=strategy,
        check_val_every_n_epoch=1,
        accelerator="gpu",
        num_nodes=1,
        **cfg.training.trainer,
    )
    trainer.fit(model=model, datamodule=datamodule, ckpt_path=last_checkpoint_path)


@hydra.main(
    config_path=osp.join(osp.dirname(__file__), "conf"), config_name="maplocnet.yaml"
)
def main(cfg: DictConfig) -> None:
    OmegaConf.save(config=cfg, f="maplocnet.yaml")
    train(cfg)


if __name__ == "__main__":
    main()