r""" Logging """

import datetime
import logging
import os

from tensorboardX import SummaryWriter
import torch


class Logger:
    r""" Writes results of training/testing """
    @classmethod
    def initialize(cls, args, training):
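        r""" Creates the log directory, file/console loggers, and a TensorBoard writer """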
        logtime = datetime.datetime.now().__format__('_%m%d_%H%M%S')
        logpath = args.logpath if training else '_TEST_' + args.load.split('/')[-1].split('.')[0] + logtime
        if logpath == '': logpath = logtime

        cls.logpath = os.path.join('logs', logpath + '.log')
        cls.benchmark = args.benchmark
        os.makedirs(cls.logpath)

        logging.basicConfig(filemode='w',
                            filename=os.path.join(cls.logpath, 'log.txt'),
                            level=logging.INFO,
                            format='%(message)s',
                            datefmt='%m-%d %H:%M:%S')

        # Console log config
        console = logging.StreamHandler()
        console.setLevel(logging.INFO)
        formatter = logging.Formatter('%(message)s')
        console.setFormatter(formatter)
        logging.getLogger('').addHandler(console)

        # Tensorboard writer
        cls.tbd_writer = SummaryWriter(os.path.join(cls.logpath, 'tbd/runs'))

        # Log arguments
        if training:
            logging.info(':======== Convolutional Hough Matching Networks =========')
            for arg_key in args.__dict__:
                logging.info('| %20s: %-24s' % (arg_key, str(args.__dict__[arg_key])))
            logging.info(':========================================================\n')

    @classmethod
    def info(cls, msg):
        r""" Writes message to .txt """
        logging.info(msg)

    @classmethod
    def save_model(cls, model, epoch, val_pck):
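        r""" Saves model weights to pck_best_model.pt and logs the epoch and validation PCK """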
        torch.save(model.state_dict(), os.path.join(cls.logpath, 'pck_best_model.pt'))
        cls.info('Model saved @%d w/ val. PCK: %5.2f.\n' % (epoch, val_pck))


class AverageMeter:
    r""" Stores loss, evaluation results, selected layers """
    def __init__(self, benchmark):
        r""" Initializes evaluation-result and loss buffers """
        self.buffer_keys = ['pck']
        self.buffer = {}
        for key in self.buffer_keys:
            self.buffer[key] = []

        self.loss_buffer = []

    def update(self, eval_result, loss=None):
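        r""" Accumulates per-batch evaluation results and (optionally) the loss """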
        for key in self.buffer_keys:
            self.buffer[key] += eval_result[key]

        if loss is not None:
            self.loss_buffer.append(loss)

    def write_result(self, split, epoch):
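        r""" Logs averaged loss and metrics for the given split at the end of an epoch """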
        msg = '\n*** %s ' % split
        msg += '[@Epoch %02d] ' % epoch

        if len(self.loss_buffer) > 0:
            msg += 'Loss: %5.2f  ' % (sum(self.loss_buffer) / len(self.loss_buffer))

        for key in self.buffer_keys:
            msg += '%s: %6.2f  ' % (key.upper(), sum(self.buffer[key]) / len(self.buffer[key]))
        msg += '***\n'
        Logger.info(msg)

    def write_process(self, batch_idx, datalen, epoch):
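        r""" Logs running loss and averaged metrics for the current training batch """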
        msg = '[Epoch: %02d] ' % epoch
        msg += '[Batch: %04d/%04d] ' % (batch_idx+1, datalen)
        if len(self.loss_buffer) > 0:
            msg += 'Loss: %5.2f  ' % self.loss_buffer[-1]
            msg += 'Avg Loss: %5.5f  ' % (sum(self.loss_buffer) / len(self.loss_buffer))

        for key in self.buffer_keys:
            msg += 'Avg %s: %5.2f  ' % (key.upper(), sum(self.buffer[key]) / len(self.buffer[key]) * 100)
        Logger.info(msg)

    def write_test_process(self, batch_idx, datalen):
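        r""" Logs averaged PCK values during testing (expects tensor entries in the buffer) """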
        msg = '[Batch: %04d/%04d] ' % (batch_idx+1, datalen)

        for key in self.buffer_keys:
            if key == 'pck':
                pcks = torch.stack(self.buffer[key]).mean(dim=0) * 100
                val = ''
                for p in pcks:
                    val += '%5.2f   ' % p.item()
                msg += 'Avg %s: %s   ' % (key.upper(), val)
            else:
                msg += 'Avg %s: %5.2f  ' % (key.upper(), sum(self.buffer[key]) / len(self.buffer[key]))
        Logger.info(msg)

    def get_test_result(self):
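        r""" Returns buffered metrics averaged over all test batches, scaled to percent """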
        result = {}
        for key in self.buffer_keys:
            result[key] = torch.stack(self.buffer[key]).mean(dim=0) * 100

        return result
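

# ----------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The argparse fields
# below ('logpath', 'load', 'benchmark') mirror what Logger.initialize reads;
# the evaluation values are placeholders, purely for illustration.
if __name__ == '__main__':
    import argparse

    args = argparse.Namespace(logpath='demo_run', load='', benchmark='pfpascal')
    Logger.initialize(args, training=True)

    meter = AverageMeter(args.benchmark)
    meter.update({'pck': [0.61, 0.58]}, loss=0.42)   # one hypothetical batch
    meter.write_process(batch_idx=0, datalen=1, epoch=0)
    meter.write_result('Validation', epoch=0)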