File size: 1,944 Bytes
04fbff5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import time
import wandb
import shutil
import logging
import os.path as osp
from torch.utils.tensorboard import SummaryWriter


def mv_archived_logger(name, time_fmt="%Y-%m-%d_%H:%M:%S_"):
    """Rename an existing log file to a timestamped ``archived_`` copy.

    The file is moved within its own directory to
    ``archived_<timestamp><basename>``, where ``<timestamp>`` is the current
    local time rendered with ``time_fmt``.

    Args:
        name: path of the file to archive; it must exist.
        time_fmt: ``time.strftime`` format used for the timestamp prefix.
            The default reproduces the historical naming scheme.

    Returns:
        The path the file was moved to.
    """
    import os

    timestamp = time.strftime(time_fmt, time.localtime())
    if os.name == 'nt':
        # ':' is not a legal file-name character on Windows; without this the
        # default format makes shutil.move fail there.
        timestamp = timestamp.replace(':', '-')
    basename = 'archived_' + timestamp + osp.basename(name)
    archived_name = osp.join(osp.dirname(name), basename)
    shutil.move(name, archived_name)
    return archived_name


class CustomLogger:
    """Rank-aware logger that fans messages out to console, file, TensorBoard and W&B.

    Only the process with ``rank == 0`` creates any sink; on every other rank
    calls to the instance and ``close()`` are no-ops. Constructing an instance
    also stores it in the module-level ``global_logger``.

    Args:
        common_cfg: dict with keys ``'format'`` (a ``logging.Formatter``
            format string), ``'filename'`` (log file path; an existing file
            there is archived first) and ``'filemode'`` (mode passed to
            ``logging.FileHandler``).
        tb_cfg: optional keyword arguments for ``SummaryWriter``; TensorBoard
            logging is disabled when ``None``.
        wandb_cfg: optional keyword arguments for ``wandb.init``; W&B is
            disabled when ``None``.
        rank: process rank; only rank 0 logs anything.
    """

    def __init__(self, common_cfg, tb_cfg=None, wandb_cfg=None, rank=0):
        global global_logger
        self.rank = rank
        # Defined on every rank so the attributes always exist.
        self.tb_logger = None
        self.enable_wandb = False
        self._handlers = []

        if self.rank == 0:
            self.logger = logging.getLogger('VFI')
            self.logger.setLevel(logging.INFO)
            # getLogger('VFI') returns a process-wide singleton: drop handlers
            # left by an earlier instance so messages are not duplicated and the
            # previous log file is closed before it is archived below.
            self._detach(self.logger.handlers)
            format_str = logging.Formatter(common_cfg['format'])

            if osp.exists(common_cfg['filename']):
                mv_archived_logger(common_cfg['filename'])

            console_handler = logging.StreamHandler()
            console_handler.setFormatter(format_str)
            file_handler = logging.FileHandler(common_cfg['filename'],
                                               mode=common_cfg['filemode'])
            file_handler.setFormatter(format_str)

            for handler in (console_handler, file_handler):
                self.logger.addHandler(handler)
                self._handlers.append(handler)

            if wandb_cfg is not None:
                self.enable_wandb = True
                wandb.init(**wandb_cfg)

            if tb_cfg is not None:
                self.tb_logger = SummaryWriter(**tb_cfg)

        global_logger = self

    def _detach(self, handlers):
        """Remove each of *handlers* from ``self.logger`` and close it."""
        for handler in list(handlers):
            self.logger.removeHandler(handler)
            handler.close()

    def __call__(self, msg=None, level=logging.INFO, tb_msg=None):
        """Log *msg* at *level* and/or forward *tb_msg* to TensorBoard.

        Args:
            msg: text passed to ``logging.Logger.log``; skipped when ``None``.
            level: logging level for *msg*.
            tb_msg: positional arguments unpacked into
                ``SummaryWriter.add_scalar`` (presumably ``(tag, value, step)``
                — verify against callers); ignored when TensorBoard is off.
        """
        if self.rank != 0:
            return
        if msg is not None:
            self.logger.log(level, msg)

        if self.tb_logger is not None and tb_msg is not None:
            self.tb_logger.add_scalar(*tb_msg)

    def close(self):
        """Release every sink this instance opened (rank 0 only).

        Safe to call more than once.
        """
        if self.rank != 0:
            return
        if self.tb_logger is not None:
            # Close the writer so buffered events reach disk.
            self.tb_logger.close()
            self.tb_logger = None
        if self.enable_wandb:
            wandb.finish()
            self.enable_wandb = False
        # Only this instance's handlers: a newer instance may own the logger now.
        self._detach(self._handlers)
        self._handlers = []