# Source note: file taken from a Hugging Face model repository page.
# Author: Daniel Gil-U Fuhge — commit "add model files" (e17e8cc), ~11 kB.
"""This code is taken from <https://github.com/alexandre01/deepsvg>
by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
from the paper <https://arxiv.org/pdf/2007.11301.pdf>
"""
from __future__ import annotations
import torch
import torch.utils.data
from typing import Union
Num = Union[int, float]
class AnimationTensor:
    """Tensor wrapper for a sequence of animation commands (deepsvg-style).

    Mirrors the layout of ``SVGTensor`` for ten animation command tokens
    ``a0``..``a9``; only ``a3`` and ``a8`` carry the three argument slots
    (duration, from, begin).
    """

    # Ten simplified animation command tokens: 'a0' .. 'a9'.
    COMMANDS_SIMPLIFIED = [f'a{i}' for i in range(10)]

    # Per-command argument usage: rows a3 and a8 consume all three
    # argument slots (duration, from, begin); every other command none.
    CMD_ARGS_MASK = torch.tensor(
        [[1, 1, 1] if i in (3, 8) else [0, 0, 0] for i in range(10)])

    class Index:
        # Column indices into the full per-row layout (command + args).
        COMMAND = 0
        DURATION = 1
        FROM = 2
        BEGIN = 3

    class IndexArgs:
        # Column indices into the argument-only layout.
        DURATION = 0
        FROM = 1
        BEGIN = 2

    # NOTE(review): these key strings use 'from', but the instance
    # attribute is stored as 'from_' (see __init__) — a getattr-based
    # lookup over these keys (as SVGTensor.get_data does) would fail.
    # Confirm whether that is intended before relying on it.
    all_arg_keys = ['duration', 'from', 'begin']
    cmd_arg_keys = ["commands", *all_arg_keys]
    all_keys = ["commands", *all_arg_keys]

    def __init__(self, commands, duration, from_, begin,
                 seq_len=None, label=None, PAD_VAL=-1, ARGS_DIM=256, filling=0):
        """Store command and argument tensors as float columns.

        Args:
            commands: tensor of command indices; reshaped to ``(N, 1)``.
            duration: per-command duration argument tensor.
            from_: per-command "from" argument tensor (trailing underscore
                avoids the ``from`` keyword).
            begin: per-command "begin" argument tensor.
            seq_len: optional precomputed length; defaults to ``len(commands)``.
            label: optional label payload, stored untouched.
            PAD_VAL: fill value for padded / unused argument slots.
            ARGS_DIM: argument quantization range.
            filling: fill attribute carried alongside the tensors.
        """
        self.commands = commands.reshape(-1, 1).float()
        self.duration = duration.float()
        self.from_ = from_.float()
        self.begin = begin.float()

        self.seq_len = torch.tensor(len(commands)) if seq_len is None else seq_len
        self.label = label

        self.PAD_VAL = PAD_VAL
        self.ARGS_DIM = ARGS_DIM
        self.filling = filling
class SVGTensor:
    """Tensor representation of a simplified SVG path command sequence.

    One row per command ("m", "l", "c", "a", plus EOS/SOS/z control
    tokens). ``CMD_ARGS_MASK`` marks which of the 11 argument slots each
    command type actually consumes; unused slots hold ``PAD_VAL``.
    """

    #                       0    1    2    3    4      5      6
    COMMANDS_SIMPLIFIED = ["m", "l", "c", "a", "EOS", "SOS", "z"]

    # Argument-slot usage per command (11 columns, see IndexArgs):
    #                              rad     x    lrg  sw  ctrl  ctrl  end
    #                              ius     axs  arc  eep  1     2    pos
    #                                      rot  fg   fg
    CMD_ARGS_MASK = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1],   # m
                                  [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1],   # l
                                  [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1],   # c
                                  [1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1],   # a
                                  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],   # EOS
                                  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],   # SOS
                                  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])  # z

    class Index:
        """Column indices into the full 14-column layout (includes start_pos)."""
        COMMAND = 0
        RADIUS = slice(1, 3)
        X_AXIS_ROT = 3
        LARGE_ARC_FLG = 4
        SWEEP_FLG = 5
        START_POS = slice(6, 8)
        CONTROL1 = slice(8, 10)
        CONTROL2 = slice(10, 12)
        END_POS = slice(12, 14)

    class IndexArgs:
        """Column indices into the 11-column argument layout (no command, no start_pos)."""
        RADIUS = slice(0, 2)
        X_AXIS_ROT = 2
        LARGE_ARC_FLG = 3
        SWEEP_FLG = 4
        CONTROL1 = slice(5, 7)
        CONTROL2 = slice(7, 9)
        END_POS = slice(9, 11)

    position_keys = ["control1", "control2", "end_pos"]
    all_position_keys = ["start_pos", *position_keys]
    arg_keys = ["radius", "x_axis_rot", "large_arc_flg", "sweep_flg", *position_keys]
    all_arg_keys = [*arg_keys[:4], "start_pos", *arg_keys[4:]]
    cmd_arg_keys = ["commands", *arg_keys]
    all_keys = ["commands", *all_arg_keys]

    def __init__(self, commands, radius, x_axis_rot, large_arc_flg, sweep_flg, control1, control2, end_pos,
                 seq_len=None, label=None, PAD_VAL=-1, ARGS_DIM=256, filling=0):
        """Store command and argument tensors as float columns.

        Args:
            commands: tensor of command indices; reshaped to ``(N, 1)``.
            radius: arc radii, ``(N, 2)``.
            x_axis_rot, large_arc_flg, sweep_flg: arc parameters; each
                reshaped to ``(N, 1)``.
            control1, control2: cubic Bezier control points, ``(N, 2)``.
            end_pos: command end points, ``(N, 2)``.
            seq_len: optional precomputed length; defaults to ``len(commands)``.
            label: optional label payload, stored untouched.
            PAD_VAL: fill value for padded / unused argument slots.
            ARGS_DIM: argument quantization range (see get_relative_args).
            filling: fill attribute carried alongside the tensors.
        """
        self.commands = commands.reshape(-1, 1).float()

        self.radius = radius.float()
        self.x_axis_rot = x_axis_rot.reshape(-1, 1).float()
        self.large_arc_flg = large_arc_flg.reshape(-1, 1).float()
        self.sweep_flg = sweep_flg.reshape(-1, 1).float()
        self.control1 = control1.float()
        self.control2 = control2.float()
        self.end_pos = end_pos.float()

        self.seq_len = torch.tensor(len(commands)) if seq_len is None else seq_len
        self.label = label

        self.PAD_VAL = PAD_VAL
        self.ARGS_DIM = ARGS_DIM

        # Single-token (1, 1) tensors used by add_sos/add_eos/pad.
        self.sos_token = torch.Tensor([self.COMMANDS_SIMPLIFIED.index("SOS")]).unsqueeze(-1)
        self.eos_token = self.pad_token = torch.Tensor([self.COMMANDS_SIMPLIFIED.index("EOS")]).unsqueeze(-1)

        self.filling = filling

    @property
    def start_pos(self):
        """Start point of each command: previous end_pos, origin for the first."""
        start_pos = self.end_pos[:-1]

        return torch.cat([
            start_pos.new_zeros(1, 2),
            start_pos
        ])

    @staticmethod
    def from_data(data, *args, **kwargs):
        """Build an SVGTensor from a full ``(N, 14)`` data tensor (Index layout)."""
        return SVGTensor(data[:, SVGTensor.Index.COMMAND], data[:, SVGTensor.Index.RADIUS], data[:, SVGTensor.Index.X_AXIS_ROT],
                         data[:, SVGTensor.Index.LARGE_ARC_FLG], data[:, SVGTensor.Index.SWEEP_FLG], data[:, SVGTensor.Index.CONTROL1],
                         data[:, SVGTensor.Index.CONTROL2], data[:, SVGTensor.Index.END_POS], *args, **kwargs)

    @staticmethod
    def from_cmd_args(commands, args, *nargs, **kwargs):
        """Build an SVGTensor from commands plus an ``(N, 11)`` args tensor (IndexArgs layout)."""
        return SVGTensor(commands, args[:, SVGTensor.IndexArgs.RADIUS], args[:, SVGTensor.IndexArgs.X_AXIS_ROT],
                         args[:, SVGTensor.IndexArgs.LARGE_ARC_FLG], args[:, SVGTensor.IndexArgs.SWEEP_FLG], args[:, SVGTensor.IndexArgs.CONTROL1],
                         args[:, SVGTensor.IndexArgs.CONTROL2], args[:, SVGTensor.IndexArgs.END_POS], *nargs, **kwargs)

    def get_data(self, keys):
        """Concatenate the attributes named in ``keys`` along the last dim."""
        return torch.cat([self.__getattribute__(key) for key in keys], dim=-1)

    @property
    def data(self):
        """Full ``(N, 14)`` data tensor in the ``Index`` column layout."""
        return self.get_data(self.all_keys)

    def copy(self):
        """Return a deep copy (tensors cloned; label shared)."""
        return SVGTensor(*[self.__getattribute__(key).clone() for key in self.cmd_arg_keys],
                         seq_len=self.seq_len.clone(), label=self.label, PAD_VAL=self.PAD_VAL, ARGS_DIM=self.ARGS_DIM,
                         filling=self.filling)

    def add_sos(self):
        """Prepend an SOS row (args filled with PAD_VAL); increments seq_len."""
        self.commands = torch.cat([self.sos_token, self.commands])

        for key in self.arg_keys:
            v = self.__getattribute__(key)
            self.__setattr__(key, torch.cat([v.new_full((1, v.size(-1)), self.PAD_VAL), v]))

        self.seq_len += 1
        return self

    def drop_sos(self):
        """Remove the first row from every tensor; decrements seq_len."""
        for key in self.cmd_arg_keys:
            self.__setattr__(key, self.__getattribute__(key)[1:])

        self.seq_len -= 1
        return self

    def add_eos(self):
        """Append an EOS row (args filled with PAD_VAL).

        NOTE: seq_len is intentionally left unchanged, so unpad() strips
        the EOS row along with the padding.
        """
        self.commands = torch.cat([self.commands, self.eos_token])

        for key in self.arg_keys:
            v = self.__getattribute__(key)
            self.__setattr__(key, torch.cat([v, v.new_full((1, v.size(-1)), self.PAD_VAL)]))

        return self

    def pad(self, seq_len=51):
        """Right-pad all tensors with EOS/PAD_VAL rows up to ``seq_len`` rows."""
        pad_len = max(seq_len - len(self.commands), 0)

        self.commands = torch.cat([self.commands, self.pad_token.repeat(pad_len, 1)])

        for key in self.arg_keys:
            v = self.__getattribute__(key)
            self.__setattr__(key, torch.cat([v, v.new_full((pad_len, v.size(-1)), self.PAD_VAL)]))

        return self

    def unpad(self):
        """Truncate every tensor back to seq_len rows (removes EOS + padding)."""
        for key in self.cmd_arg_keys:
            self.__setattr__(key, self.__getattribute__(key)[:self.seq_len])
        return self

    def draw(self, *args, **kwargs):
        """Render this tensor via deepsvg's SVGPath (project dependency)."""
        from deepsvg.svglib.svg import SVGPath
        return SVGPath.from_tensor(self.data).draw(*args, **kwargs)

    def cmds(self):
        """Command indices as a flat 1-D tensor."""
        return self.commands.reshape(-1)

    def args(self, with_start_pos=False):
        """Argument tensor: ``(N, 11)``, or ``(N, 13)`` when start_pos is included."""
        if with_start_pos:
            return self.get_data(self.all_arg_keys)

        return self.get_data(self.arg_keys)

    def _get_real_commands_mask(self):
        """Boolean mask of real drawing commands (anything before EOS index)."""
        mask = self.cmds() < self.COMMANDS_SIMPLIFIED.index("EOS")
        return mask

    def _get_args_mask(self):
        """Boolean ``(N, 11)`` mask of argument slots each row's command uses."""
        mask = SVGTensor.CMD_ARGS_MASK[self.cmds().long()].bool()
        return mask

    def get_relative_args(self):
        """Return args with positions relative to the previous command's end point.

        Used argument slots are shifted by ``ARGS_DIM - 1`` (quantization
        offset); unused slots are set to ``PAD_VAL``.
        """
        data = self.args().clone()

        real_commands = self._get_real_commands_mask()
        data_real_commands = data[real_commands]

        start_pos = data_real_commands[:-1, SVGTensor.IndexArgs.END_POS].clone()

        data_real_commands[1:, SVGTensor.IndexArgs.CONTROL1] -= start_pos
        data_real_commands[1:, SVGTensor.IndexArgs.CONTROL2] -= start_pos
        data_real_commands[1:, SVGTensor.IndexArgs.END_POS] -= start_pos
        data[real_commands] = data_real_commands

        mask = self._get_args_mask()
        data[mask] += self.ARGS_DIM - 1
        data[~mask] = self.PAD_VAL

        return data

    def sample_points(self, n=10):
        """Sample ``n`` points along each "l"/"c" command.

        Lines and cubic Beziers are evaluated in polynomial (matrix) form;
        all other commands ("m", "a", control tokens) are skipped.
        Consecutive commands share endpoints, so all but the final sample
        of each command are kept, plus the very last point.
        """
        device = self.commands.device

        z = torch.linspace(0, 1, n, device=device)
        Z = torch.stack([torch.ones_like(z), z, z.pow(2), z.pow(3)], dim=1)

        # Per-command polynomial basis matrices. Built from plain nested
        # lists only: the original mixed torch.zeros(4, 4) sub-tensors into
        # the list passed to torch.tensor, which raises
        # "only one element tensors can be converted to Python scalars"
        # on recent PyTorch versions.
        zeros_4x4 = [[0.] * 4] * 4
        Q = torch.tensor([
            zeros_4x4,                # "m": not sampled

            [[1., 0., 0., 0.],       # "l": linear interpolation start -> end
             [-1., 0., 0., 1.],
             [0., 0., 0., 0.],
             [0., 0., 0., 0.]],

            [[1., 0., 0., 0.],       # "c": cubic Bezier basis matrix
             [-3., 3., 0., 0.],
             [3., -6., 3., 0.],
             [-1., 3., -3., 1.]],

            zeros_4x4,               # "a": no support yet
            zeros_4x4,               # "EOS"
            zeros_4x4,               # "SOS"
            zeros_4x4,               # "z"
        ], device=device)

        commands, pos = self.commands.reshape(-1).long(), self.get_data(self.all_position_keys).reshape(-1, 4, 2)

        # Keep only line and cubic commands.
        inds = (commands == self.COMMANDS_SIMPLIFIED.index("l")) | (commands == self.COMMANDS_SIMPLIFIED.index("c"))
        commands, pos = commands[inds], pos[inds]

        Z_coeffs = torch.matmul(Q[commands], pos)

        # Last point of one command is the first point of the next:
        # drop each command's final sample except for the very last command.
        sample_points = torch.matmul(Z, Z_coeffs)
        sample_points = torch.cat([sample_points[:, :-1].reshape(-1, 2), sample_points[-1, -1].unsqueeze(0)])

        return sample_points

    @staticmethod
    def get_length_distribution(p, normalize=True):
        """Cumulative arc-length distribution of a polyline ``p``.

        Returns ``len(p)`` values starting at 0; when ``normalize`` is
        True the distribution is scaled to end at 1.
        NOTE(review): a degenerate polyline of identical points yields a
        0/0 division (NaNs) when normalizing — confirm callers avoid it.
        """
        start, end = p[:-1], p[1:]
        length_distr = torch.norm(end - start, dim=-1).cumsum(dim=0)
        length_distr = torch.cat([length_distr.new_zeros(1),
                                  length_distr])
        if normalize:
            length_distr = length_distr / length_distr[-1]
        return length_distr

    def sample_uniform_points(self, n=100):
        """Sample ``n`` points approximately uniformly spaced by arc length."""
        p = self.sample_points(n=n)

        distr_unif = torch.linspace(0., 1., n).to(p.device)
        distr = self.get_length_distribution(p, normalize=True)

        # For each uniform target, pick the nearest sampled point by
        # cumulative-length distance.
        d = torch.cdist(distr_unif.unsqueeze(-1), distr.unsqueeze(-1))
        matching = d.argmin(dim=-1)

        return p[matching]