Spaces:
Running
Running
"""This code is taken from <https://github.com/alexandre01/deepsvg> | |
by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte | |
from the paper >https://arxiv.org/pdf/2007.11301.pdf> | |
""" | |
import torch | |
from src.preprocessing.deepsvg.deepsvg_difflib.tensor import SVGTensor | |
from torch.distributions.categorical import Categorical | |
import torch.nn.functional as F | |
def _get_key_padding_mask(commands, seq_dim=0): | |
""" | |
Args: | |
commands: Shape [S, ...] | |
""" | |
with torch.no_grad(): | |
key_padding_mask = (commands == SVGTensor.COMMANDS_SIMPLIFIED.index("EOS")).cumsum(dim=seq_dim) > 0 | |
if seq_dim == 0: | |
return key_padding_mask.transpose(0, 1) | |
return key_padding_mask | |
def _get_padding_mask(commands, seq_dim=0, extended=False): | |
with torch.no_grad(): | |
padding_mask = (commands == SVGTensor.COMMANDS_SIMPLIFIED.index("EOS")).cumsum(dim=seq_dim) == 0 | |
padding_mask = padding_mask.float() | |
if extended: | |
# padding_mask doesn't include the final EOS, extend by 1 position to include it in the loss | |
S = commands.size(seq_dim) | |
torch.narrow(padding_mask, seq_dim, 3, S-3).add_(torch.narrow(padding_mask, seq_dim, 0, S-3)).clamp_(max=1) | |
if seq_dim == 0: | |
return padding_mask.unsqueeze(-1) | |
return padding_mask | |
def _get_group_mask(commands, seq_dim=0): | |
""" | |
Args: | |
commands: Shape [S, ...] | |
""" | |
with torch.no_grad(): | |
group_mask = (commands == SVGTensor.COMMANDS_SIMPLIFIED.index("m")).cumsum(dim=seq_dim) | |
return group_mask | |
def _get_visibility_mask(commands, seq_dim=0): | |
""" | |
Args: | |
commands: Shape [S, ...] | |
""" | |
S = commands.size(seq_dim) | |
with torch.no_grad(): | |
visibility_mask = (commands == SVGTensor.COMMANDS_SIMPLIFIED.index("EOS")).sum(dim=seq_dim) < S - 1 | |
if seq_dim == 0: | |
return visibility_mask.unsqueeze(-1) | |
return visibility_mask | |
def _get_key_visibility_mask(commands, seq_dim=0): | |
S = commands.size(seq_dim) | |
with torch.no_grad(): | |
key_visibility_mask = (commands == SVGTensor.COMMANDS_SIMPLIFIED.index("EOS")).sum(dim=seq_dim) >= S - 1 | |
if seq_dim == 0: | |
return key_visibility_mask.transpose(0, 1) | |
return key_visibility_mask | |
def _generate_square_subsequent_mask(sz): | |
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) | |
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) | |
return mask | |
def _sample_categorical(temperature=0.0001, *args_logits): | |
if len(args_logits) == 1: | |
arg_logits, = args_logits | |
return Categorical(logits=arg_logits / temperature).sample() | |
return (*(Categorical(logits=arg_logits / temperature).sample() for arg_logits in args_logits),) | |
def _threshold_sample(arg_logits, threshold=0.5, temperature=1.0): | |
scores = F.softmax(arg_logits / temperature, dim=-1)[..., 1] | |
return scores > threshold | |