Spaces:
Running
Running
"""This code is taken from <https://github.com/alexandre01/deepsvg> | |
by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte | |
from the paper >https://arxiv.org/pdf/2007.11301.pdf> | |
""" | |
from src.preprocessing.deepsvg.deepsvg_config.config import _Config | |
from src.preprocessing.deepsvg.deepsvg_models.model import SVGTransformer | |
from src.preprocessing.deepsvg.deepsvg_models.loss import SVGLoss | |
from src.preprocessing.deepsvg.deepsvg_models.model_config import * | |
from src.preprocessing.deepsvg.deepsvg_svglib.svg import SVG | |
from src.preprocessing.deepsvg.deepsvg_difflib.tensor import SVGTensor | |
from src.preprocessing.deepsvg.deepsvg_svglib.svglib_utils import make_grid | |
from src.preprocessing.deepsvg.deepsvg_svglib.geom import Bbox | |
from src.preprocessing.deepsvg.deepsvg_utils.utils import batchify, linear | |
import torchvision.transforms.functional as TF | |
import torch.optim.lr_scheduler as lr_scheduler | |
import random | |
class ModelConfig(Hierarchical):
    """
    Model configuration for this training setup.

    Currently inherits every setting from the ``Hierarchical`` preset
    unchanged; assign attribute overrides after the base initializer runs
    to customize it.
    """
    def __init__(self):
        # Nothing is overridden yet: defer entirely to the Hierarchical defaults.
        super(ModelConfig, self).__init__()
class Config(_Config):
    """
    Training configuration overriding the defaults from ``_Config``.

    Dataloader workers, learning rate and batch size are scaled linearly
    with ``num_gpus``.
    """
    def __init__(self, num_gpus=1):
        """
        Args:
            num_gpus: number of GPUs; multiplies loader_num_workers,
                learning_rate and batch_size.
        """
        super().__init__(num_gpus=num_gpus)

        # Model
        self.model_cfg = ModelConfig()
        self.model_args = self.model_cfg.get_model_args()

        # Dataset
        self.filter_category = None
        self.train_ratio = 1.0
        self.max_num_groups = 8
        self.max_total_len = 50

        # Dataloader
        self.loader_num_workers = 4 * num_gpus

        # Training
        self.num_epochs = 50
        self.val_every = 1000

        # Optimization
        self.learning_rate = 1e-3 * num_gpus
        self.batch_size = 60 * num_gpus
        self.grad_clip = 1.0

    def make_schedulers(self, optimizers, epoch_size):
        """
        Build one StepLR scheduler that multiplies the LR by 0.9 every
        2.5 * ``epoch_size`` scheduler steps.

        Args:
            optimizers: sequence holding exactly one optimizer.
            epoch_size: scheduler steps per epoch (presumably batches per
                epoch — confirm against the trainer).

        Returns:
            A single-element list containing the scheduler.
        """
        optimizer, = optimizers
        # StepLR expects an integer step_size: it decays when
        # `last_epoch % step_size == 0`, which with a fractional step
        # (2.5 * an odd epoch_size) only happens every *other* period,
        # silently halving the decay frequency. Clamp to 1 so a tiny or
        # empty epoch cannot cause a modulo-by-zero.
        step_size = max(1, int(2.5 * epoch_size))
        return [lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.9)]

    def make_model(self):
        """Instantiate the SVGTransformer described by ``self.model_cfg``."""
        return SVGTransformer(self.model_cfg)

    def make_losses(self):
        """Return the list of loss modules (a single SVGLoss)."""
        return [SVGLoss(self.model_cfg)]

    def get_weights(self, step, epoch):
        """
        Return the loss weights for the given training step.

        ``loss_kl_weight`` is computed by ``linear(0, 10, step, 0, 10000)``
        (presumably a ramp from 0 to 10 over the first 10k steps — see
        deepsvg ``utils.linear``). ``epoch`` is unused.
        """
        return {
            "kl_tolerance": 0.1,
            "loss_kl_weight": linear(0, 10, step, 0, 10000),
            "loss_hierarch_weight": 1.0,
            "loss_cmd_weight": 1.0,
            "loss_args_weight": 2.0,
            "loss_visibility_weight": 1.0
        }

    def set_train_vars(self, train_vars, dataloader):
        """
        Cache a random subset (up to 10) of training samples for visualization.

        Each sample is fetched with the model's input keys plus
        ``"tensor_grouped"`` and stored on ``train_vars.x_inputs_train``.
        """
        dataset = dataloader.dataset
        # random.sample raises ValueError if k exceeds the population size,
        # so cap the count for datasets with fewer than 10 items.
        num_samples = min(10, len(dataset))
        train_vars.x_inputs_train = [dataset.get(idx, [*self.model_args, "tensor_grouped"])
                                     for idx in random.sample(range(len(dataset)), k=num_samples)]

    def visualize(self, model, output, train_vars, step, epoch, summary_writer, visualization_dir):
        """
        Log reconstruction images for the samples cached by set_train_vars.

        For each sample, the model's greedy reconstruction and the ground
        truth are rendered side by side and written to ``summary_writer``
        under ``reconstructions_train/{i}`` at ``step``. Samples whose
        prediction cannot be turned into an SVG are skipped.

        ``output``, ``epoch`` and ``visualization_dir`` are unused here; they
        are part of the shared visualize hook signature.
        """
        device = next(model.parameters()).device
        # Use the wrapped network when model is a DataParallel-style wrapper
        # (exposing `.module`); fall back to the model itself otherwise.
        net = getattr(model, "module", model)

        # Reconstruction
        for i, data in enumerate(train_vars.x_inputs_train):
            model_args = batchify((data[key] for key in self.model_args), device)
            commands_y, args_y = net.greedy_sample(*model_args)
            tensor_pred = SVGTensor.from_cmd_args(commands_y[0].cpu(), args_y[0].cpu())

            try:
                svg_path_sample = SVG.from_tensor(
                    tensor_pred.data, viewbox=Bbox(256), allow_empty=True
                ).normalize().split_paths().set_color("random")
            except Exception:
                # Best-effort: a malformed prediction just skips this sample
                # (but no longer swallows KeyboardInterrupt/SystemExit).
                continue

            tensor_target = data["tensor_grouped"][0].copy().drop_sos().unpad()
            svg_path_gt = SVG.from_tensor(
                tensor_target.data, viewbox=Bbox(256)
            ).normalize().split_paths().set_color("random")

            img = make_grid([svg_path_sample, svg_path_gt]).draw(
                do_display=False, return_png=True, fill=False, with_points=False
            )
            summary_writer.add_image(f"reconstructions_train/{i}", TF.to_tensor(img), step)