"""This code is taken from <https://github.com/alexandre01/deepsvg>
by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
from the paper <https://arxiv.org/pdf/2007.11301.pdf>
"""

import torch.optim as optim
from src.preprocessing.deepsvg.deepsvg_schedulers.warmup import GradualWarmupScheduler


class _Config:
    """
    Training config.
    """
    def __init__(self, num_gpus=1):

        self.num_gpus = num_gpus                              # Number of GPUs to train on

        self.dataloader_module = "deepsvg.svgtensor_dataset"  # Module providing the dataset / dataloader
        self.collate_fn = None                                # Custom collate function (None = PyTorch default)
        self.data_dir = "./data/svgs_tensors/"                # Directory of preprocessed SVG tensors
        self.meta_filepath = "./data/svgs_meta.csv"           # CSV file with per-SVG metadata
        self.loader_num_workers = 0                           # DataLoader worker processes

        self.pretrained_path = "./models/hierarchical_ordered.pth.tar"  # Checkpoint to initialize weights from

        self.model_cfg = None                                 # Model configuration (set by concrete configs)

        self.num_epochs = None                                # Train for a fixed number of epochs, or...
        self.num_steps = None                                 # ...a fixed number of steps (set one of the two)
        self.learning_rate = 1e-3                             # Base learning rate
        self.batch_size = 100                                 # Samples per training batch
        self.warmup_steps = 500                               # Steps of learning-rate warmup


        # Dataset
        self.train_ratio = 1.0                                # Fraction of the dataset used for training
        self.nb_augmentations = 1                             # Number of augmented copies per sample

        self.max_num_groups = 15                              # Maximum number of path groups per SVG
        self.max_seq_len = 30                                 # Maximum number of commands per path group
        self.max_total_len = None                             # Maximum total sequence length (None = no limit)

        self.filter_uni = None                                # Keep only these unicode glyphs (None = keep all)
        self.filter_category = None                           # Keep only these icon categories (None = keep all)
        self.filter_platform = None                           # Keep only these platforms (None = keep all)

        self.filter_labels = None                             # Keep only these labels (None = keep all)

        self.grad_clip = None                                 # Gradient-norm clipping threshold (None = disabled)

        self.log_every = 20                                   # Log training stats every N steps
        self.val_every = 1000                                 # Run validation every N steps
        self.ckpt_every = 1000                                # Save a checkpoint every N steps

        self.stats_to_print = {
            "train": ["lr", "time"]
        }

        self.model_args = []                                  # Batch fields passed as arguments to the model
        self.optimizer_starts = [0]                           # Step at which each optimizer becomes active

    # Overridable methods
    def make_model(self):
        raise NotImplementedError

    def make_losses(self):
        raise NotImplementedError

    def make_optimizers(self, model):
        # By default, a single AdamW optimizer over all model parameters.
        return [optim.AdamW(model.parameters(), self.learning_rate)]

    def make_schedulers(self, optimizers, epoch_size):
        # No learning-rate scheduler by default; subclasses may override.
        return [None] * len(optimizers)

    def make_warmup_schedulers(self, optimizers, scheduler_lrs):
        # Wrap every optimizer in a gradual warmup over `warmup_steps`,
        # handing over to its regular scheduler (if any) afterwards.
        return [GradualWarmupScheduler(optimizer, multiplier=1.0, total_epoch=self.warmup_steps, after_scheduler=scheduler_lr)
                for optimizer, scheduler_lr in zip(optimizers, scheduler_lrs)]

    def get_params(self, step, epoch):
        return {}

    def get_weights(self, step, epoch):
        return {}

    def set_train_vars(self, train_vars, dataloader):
        pass

    def visualize(self, model, output, train_vars, step, epoch, summary_writer, visualization_dir):
        pass

    # Utility methods
    def values(self):
        # Yield every non-callable, non-dunder attribute as a (key, value) pair.
        for key in dir(self):
            if not key.startswith("__") and not callable(getattr(self, key)):
                yield key, getattr(self, key)

    def to_dict(self):
        return {key: val for key, val in self.values()}

    def load_dict(self, config_dict):
        for key, val in config_dict.items():
            setattr(self, key, val)

    def print_params(self):
        for key, val in self.values():
            print(f"  {key} = {val}")