"""Base training configuration.

This code is taken from the DeepSVG project by Alexandre Carlier, Martin
Danelljan, Alexandre Alahi and Radu Timofte, from the paper
<https://arxiv.org/pdf/2007.11301.pdf>.
"""
import torch.optim as optim

from src.preprocessing.deepsvg.deepsvg_schedulers.warmup import GradualWarmupScheduler


class _Config:
    """
    Training config.

    Holds default hyper-parameters as instance attributes and exposes factory
    hooks for training. ``make_model`` and ``make_losses`` raise
    ``NotImplementedError`` here and must be overridden by subclasses; the
    remaining hooks have working (mostly no-op) defaults.
    """

    def __init__(self, num_gpus=1):
        self.num_gpus = num_gpus

        self.dataloader_module = "deepsvg.svgtensor_dataset"
        self.collate_fn = None
        self.data_dir = "./data/svgs_tensors/"
        self.meta_filepath = "./data/svgs_meta.csv"
        self.loader_num_workers = 0

        self.pretrained_path = "./models/hierarchical_ordered.pth.tar"

        self.model_cfg = None

        self.num_epochs = None
        self.num_steps = None
        self.learning_rate = 1e-3  # used as the AdamW lr in make_optimizers
        self.batch_size = 100
        self.warmup_steps = 500  # used as total_epoch in make_warmup_schedulers

        # Dataset
        self.train_ratio = 1.0
        self.nb_augmentations = 1

        self.max_num_groups = 15
        self.max_seq_len = 30
        self.max_total_len = None

        self.filter_uni = None
        self.filter_category = None
        self.filter_platform = None

        self.filter_labels = None

        self.grad_clip = None

        # NOTE(review): presumably counted in training steps — confirm with the trainer.
        self.log_every = 20
        self.val_every = 1000
        self.ckpt_every = 1000

        self.stats_to_print = {
            "train": ["lr", "time"]
        }

        self.model_args = []
        self.optimizer_starts = [0]

    # Overridable methods
    def make_model(self):
        """Build and return the model; subclasses must override this."""
        raise NotImplementedError

    def make_losses(self):
        """Build and return the losses; subclasses must override this."""
        raise NotImplementedError

    def make_optimizers(self, model):
        """Return a one-element list holding an AdamW optimizer.

        The optimizer covers ``model.parameters()`` and uses
        ``self.learning_rate`` as its learning rate.
        """
        return [optim.AdamW(model.parameters(), self.learning_rate)]

    def make_schedulers(self, optimizers, epoch_size):
        """Return one scheduler per optimizer; by default ``None`` for each.

        ``epoch_size`` is unused by this default implementation.
        """
        return [None] * len(optimizers)

    def make_warmup_schedulers(self, optimizers, scheduler_lrs):
        """Wrap each ``(optimizer, scheduler_lr)`` pair in a GradualWarmupScheduler.

        Each wrapper uses ``multiplier=1.0``, ``total_epoch=self.warmup_steps``
        and the paired scheduler as ``after_scheduler``. Pairs are formed with
        ``zip``, so surplus entries in the longer list are ignored.
        """
        return [
            GradualWarmupScheduler(
                optimizer,
                multiplier=1.0,
                total_epoch=self.warmup_steps,
                after_scheduler=scheduler_lr,
            )
            for optimizer, scheduler_lr in zip(optimizers, scheduler_lrs)
        ]

    def get_params(self, step, epoch):
        """Return extra parameters for ``step``/``epoch``; empty by default."""
        return {}

    def get_weights(self, step, epoch):
        """Return weights for ``step``/``epoch``; empty by default."""
        return {}

    def set_train_vars(self, train_vars, dataloader):
        """Hook for populating ``train_vars``; does nothing by default."""
        pass

    def visualize(self, model, output, train_vars, step, epoch, summary_writer,
                  visualization_dir):
        """Visualization hook; does nothing by default."""
        pass

    # Utility methods
    def values(self):
        """Yield ``(name, value)`` for each non-dunder, non-callable attribute.

        Names are taken from ``dir(self)`` (alphabetical order). Each attribute
        is read exactly once, so properties are not evaluated twice.
        """
        for key in dir(self):
            if key.startswith("__"):
                continue
            val = getattr(self, key)
            if not callable(val):
                yield key, val

    def to_dict(self):
        """Return the ``(name, value)`` pairs from ``values()`` as a dict."""
        return dict(self.values())

    def load_dict(self, dict):
        """Set every ``key: val`` item of the given mapping as an attribute."""
        # NOTE: the parameter name shadows the builtin ``dict``; it is kept
        # unchanged so existing keyword callers (``load_dict(dict=...)``) work.
        for key, val in dict.items():
            setattr(self, key, val)

    def print_params(self):
        """Print each attribute from ``values()`` as ``name = value``."""
        for key, val in self.values():
            print(f" {key} = {val}")