Spaces:
Running
Running
File size: 3,736 Bytes
e17e8cc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
"""This code is taken from <https://github.com/alexandre01/deepsvg>
by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
from the paper <https://arxiv.org/pdf/2007.11301.pdf>
"""
import torch.optim as optim
from src.preprocessing.deepsvg.deepsvg_schedulers.warmup import GradualWarmupScheduler
class _Config:
    """Base training configuration.

    Holds hyper-parameters as plain instance attributes and exposes a set of
    hook methods that concrete configs are expected to override.
    ``make_model`` and ``make_losses`` raise ``NotImplementedError`` here; the
    remaining hooks provide simple defaults.

    The utility methods (``values``, ``to_dict``, ``load_dict``,
    ``print_params``) treat every non-dunder, non-callable attribute as a
    parameter.  Note that an attribute whose *value* is callable (for example
    a function assigned to ``collate_fn``) is therefore skipped by ``values``
    and everything built on top of it.
    """

    def __init__(self, num_gpus=1):
        """Populate every parameter with its default value.

        Args:
            num_gpus: Stored as ``self.num_gpus``.
        """
        self.num_gpus = num_gpus

        # Data loading
        self.dataloader_module = "deepsvg.svgtensor_dataset"
        self.collate_fn = None
        self.data_dir = "./data/svgs_tensors/"
        self.meta_filepath = "./data/svgs_meta.csv"
        self.loader_num_workers = 0

        # Model
        self.pretrained_path = "./models/hierarchical_ordered.pth.tar"
        self.model_cfg = None

        # Training schedule
        self.num_epochs = None
        self.num_steps = None
        self.learning_rate = 1e-3  # passed to AdamW in make_optimizers
        self.batch_size = 100
        self.warmup_steps = 500  # passed as total_epoch in make_warmup_schedulers

        # Dataset
        self.train_ratio = 1.0
        self.nb_augmentations = 1
        self.max_num_groups = 15
        self.max_seq_len = 30
        self.max_total_len = None
        self.filter_uni = None
        self.filter_category = None
        self.filter_platform = None
        self.filter_labels = None

        # Training-loop knobs; their exact meaning is defined by whichever
        # trainer consumes this config (not visible from this class).
        self.grad_clip = None
        self.log_every = 20
        self.val_every = 1000
        self.ckpt_every = 1000
        self.stats_to_print = {
            "train": ["lr", "time"]
        }
        self.model_args = []
        self.optimizer_starts = [0]

    # Overridable methods
    def make_model(self):
        """Build the model. Must be overridden.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def make_losses(self):
        """Build the loss object(s). Must be overridden.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def make_optimizers(self, model):
        """Return a one-element list holding an AdamW optimizer.

        The optimizer covers ``model.parameters()`` and uses
        ``self.learning_rate`` as its learning rate.
        """
        return [optim.AdamW(model.parameters(), self.learning_rate)]

    def make_schedulers(self, optimizers, epoch_size):
        """Return one ``None`` placeholder per optimizer.

        ``epoch_size`` is unused by this default implementation.
        """
        return [None] * len(optimizers)

    def make_warmup_schedulers(self, optimizers, scheduler_lrs):
        """Wrap each optimizer in a ``GradualWarmupScheduler``.

        Optimizers and ``scheduler_lrs`` are paired positionally with ``zip``;
        each scheduler is created with ``multiplier=1.0``,
        ``total_epoch=self.warmup_steps`` and the paired entry as
        ``after_scheduler``.
        """
        return [
            GradualWarmupScheduler(
                optimizer,
                multiplier=1.0,
                total_epoch=self.warmup_steps,
                after_scheduler=scheduler_lr,
            )
            for optimizer, scheduler_lr in zip(optimizers, scheduler_lrs)
        ]

    def get_params(self, step, epoch):
        """Return an empty dict; hook for subclasses."""
        return {}

    def get_weights(self, step, epoch):
        """Return an empty dict; hook for subclasses."""
        return {}

    def set_train_vars(self, train_vars, dataloader):
        """Do nothing; hook for subclasses."""
        pass

    def visualize(self, model, output, train_vars, step, epoch, summary_writer, visualization_dir):
        """Do nothing; hook for subclasses."""
        pass

    # Utility methods
    def values(self):
        """Yield ``(name, value)`` for every parameter attribute.

        Names come from ``dir(self)``; those starting with ``"__"`` and those
        whose value is callable (which covers all methods) are skipped.
        """
        for key in dir(self):
            if key.startswith("__"):
                continue
            # Look the attribute up once and reuse it for both the filter
            # and the yielded value.
            val = getattr(self, key)
            if not callable(val):
                yield key, val

    def to_dict(self):
        """Return the parameters produced by ``values`` as a dict."""
        return dict(self.values())

    def load_dict(self, dict):
        """Set one attribute on ``self`` per ``(key, value)`` item.

        The parameter name shadows the ``dict`` builtin; it is kept as-is so
        that existing keyword callers (``load_dict(dict=...)``) keep working.
        """
        for key, val in dict.items():
            setattr(self, key, val)

    def print_params(self):
        """Print every parameter as `` name = value``, one per line."""
        for key, val in self.values():
            print(f" {key} = {val}")
|