import time
import wandb
import shutil
import logging
import os.path as osp
from torch.utils.tensorboard import SummaryWriter

# The most recently constructed CustomLogger. Rebound by
# CustomLogger.__init__; None until the first instance is created.
global_logger = None


def mv_archived_logger(name):
    """Rename an existing log file so that a new run does not reuse it.

    The file is moved within its own directory to
    ``archived_<YYYY-mm-dd_HH:MM:SS>_<original basename>``, where the
    timestamp is the local time of the call.

    Args:
        name: Path of the existing log file.
    """
    timestamp = time.strftime("%Y-%m-%d_%H:%M:%S_", time.localtime())
    basename = 'archived_' + timestamp + osp.basename(name)
    archived_name = osp.join(osp.dirname(name), basename)
    shutil.move(name, archived_name)


class CustomLogger:
    """Rank-aware logger that fans out to console, file, TensorBoard and wandb.

    Only the process with ``rank == 0`` creates any sinks; on every other
    rank the instance is inert and calling it is a no-op.

    Attributes:
        rank: Rank passed to the constructor.
        logger: The ``logging.Logger`` named ``'VFI'`` on rank 0, else None.
        tb_logger: ``SummaryWriter`` when ``tb_cfg`` was given on rank 0,
            else None.
        enable_wandb: True when ``wandb.init`` was called by this instance.
    """

    def __init__(self, common_cfg, tb_cfg=None, wandb_cfg=None, rank=0):
        """Set up the log sinks and register this instance as ``global_logger``.

        Args:
            common_cfg: Mapping with the keys ``'format'`` (logging format
                string), ``'filename'`` (log file path) and ``'filemode'``
                (mode passed to ``logging.FileHandler``). Only read on rank 0.
            tb_cfg: Optional keyword arguments for ``SummaryWriter``.
            wandb_cfg: Optional keyword arguments for ``wandb.init``.
            rank: Process rank; sinks are created only when it is 0.
        """
        global global_logger
        self.rank = rank

        # Defined on every rank so the attributes can always be read,
        # even on processes that never create a sink.
        self.logger = None
        self.tb_logger = None
        self.enable_wandb = False
        # Handlers attached by this instance; close() detaches exactly these.
        self._handlers = []

        if self.rank == 0:
            self.logger = logging.getLogger('VFI')
            self.logger.setLevel(logging.INFO)

            format_str = logging.Formatter(common_cfg['format'])

            console_handler = logging.StreamHandler()
            console_handler.setFormatter(format_str)

            # Keep the log of a previous run instead of writing into it.
            if osp.exists(common_cfg['filename']):
                mv_archived_logger(common_cfg['filename'])

            file_handler = logging.FileHandler(common_cfg['filename'],
                                               common_cfg['filemode'])
            file_handler.setFormatter(format_str)

            for handler in (console_handler, file_handler):
                self.logger.addHandler(handler)
                self._handlers.append(handler)

            if wandb_cfg is not None:
                self.enable_wandb = True
                wandb.init(**wandb_cfg)

            if tb_cfg is not None:
                self.tb_logger = SummaryWriter(**tb_cfg)

        global_logger = self

    def __call__(self, msg=None, level=logging.INFO, tb_msg=None):
        """Emit a text message and/or a TensorBoard scalar (rank 0 only).

        Args:
            msg: Message passed to ``Logger.log``; skipped when None.
            level: Logging level used for ``msg``.
            tb_msg: Sequence unpacked as the positional arguments of
                ``SummaryWriter.add_scalar``; skipped when None or when no
                TensorBoard writer was configured.
        """
        if self.rank != 0:
            return
        if msg is not None:
            self.logger.log(level, msg)

        if self.tb_logger is not None and tb_msg is not None:
            self.tb_logger.add_scalar(*tb_msg)

    def close(self):
        """Release every sink this instance opened; safe to call repeatedly.

        Finishes the wandb run, closes the TensorBoard writer, and detaches
        and closes the console/file handlers added in ``__init__``. The
        ``'VFI'`` logger is process-wide, so leaving the handlers attached
        would duplicate output if another CustomLogger were created later.
        After this call, messages no longer reach the log file or console.
        """
        if self.rank != 0:
            return

        if self.enable_wandb:
            wandb.finish()

        if self.tb_logger is not None:
            self.tb_logger.close()
            self.tb_logger = None

        for handler in self._handlers:
            self.logger.removeHandler(handler)
            handler.close()
        self._handlers = []