"""This code is taken from <https://github.com/alexandre01/deepsvg>
by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
from the paper >https://arxiv.org/pdf/2007.11301.pdf>
"""
import torch
import torch.nn.functional as F
from torch.distributions.categorical import Categorical

from src.preprocessing.deepsvg.deepsvg_difflib.tensor import SVGTensor


def _get_key_padding_mask(commands, seq_dim=0):
    """
    Args:
        commands: Shape [S, ...]

    Returns a bool mask that is True at and after the first EOS token,
    i.e. at the positions attention should ignore as padding.
    """
    with torch.no_grad():
        key_padding_mask = (commands == SVGTensor.COMMANDS_SIMPLIFIED.index("EOS")).cumsum(dim=seq_dim) > 0

        if seq_dim == 0:
            # transpose to [batch, S], the layout nn.Transformer expects for key_padding_mask
            return key_padding_mask.transpose(0, 1)
        return key_padding_mask
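
# Illustrative example (not part of the original file): two sequences of
# length S=5 laid out [S, batch]; the mask flags every position at or after
# the first EOS, transposed to [batch, S]:
#
#   eos = SVGTensor.COMMANDS_SIMPLIFIED.index("EOS")
#   m = SVGTensor.COMMANDS_SIMPLIFIED.index("m")
#   cmds = torch.tensor([[m, m], [m, eos], [eos, eos], [eos, eos], [eos, eos]])
#   _get_key_padding_mask(cmds)
#   # tensor([[False, False,  True,  True,  True],
#   #         [False,  True,  True,  True,  True]])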


def _get_padding_mask(commands, seq_dim=0, extended=False):
    with torch.no_grad():
        # 1.0 strictly before the first EOS, 0.0 at and after it
        padding_mask = (commands == SVGTensor.COMMANDS_SIMPLIFIED.index("EOS")).cumsum(dim=seq_dim) == 0
        padding_mask = padding_mask.float()

        if extended:
            # padding_mask doesn't include the final EOS, extend by 1 position to include it in the loss
            S = commands.size(seq_dim)
            torch.narrow(padding_mask, seq_dim, 1, S - 1).add_(torch.narrow(padding_mask, seq_dim, 0, S - 1)).clamp_(max=1)

        if seq_dim == 0:
            return padding_mask.unsqueeze(-1)
        return padding_mask
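
# Illustrative example (not part of the original file), 1-D sequence, S=5;
# the extended variant shifts the kept region one step so the EOS at index 2
# contributes to the loss:
#
#   eos = SVGTensor.COMMANDS_SIMPLIFIED.index("EOS")
#   m = SVGTensor.COMMANDS_SIMPLIFIED.index("m")
#   cmds = torch.tensor([m, m, eos, eos, eos])
#   _get_padding_mask(cmds).squeeze(-1)                  # tensor([1., 1., 0., 0., 0.])
#   _get_padding_mask(cmds, extended=True).squeeze(-1)   # tensor([1., 1., 1., 0., 0.])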


def _get_group_mask(commands, seq_dim=0):
    """
    Args:
        commands: Shape [S, ...]

    Returns the 1-based index of the subpath each position belongs to,
    computed as a running count of "m" (moveto) commands.
    """
    with torch.no_grad():
        group_mask = (commands == SVGTensor.COMMANDS_SIMPLIFIED.index("m")).cumsum(dim=seq_dim)
        return group_mask
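
# Illustrative example (not part of the original file; assumes "l" (lineto)
# is also among SVGTensor.COMMANDS_SIMPLIFIED, like "m" and "EOS"):
#
#   m = SVGTensor.COMMANDS_SIMPLIFIED.index("m")
#   l = SVGTensor.COMMANDS_SIMPLIFIED.index("l")
#   eos = SVGTensor.COMMANDS_SIMPLIFIED.index("EOS")
#   _get_group_mask(torch.tensor([m, l, l, m, l, eos]))
#   # tensor([1, 1, 1, 2, 2, 2])  -- two subpaths, each starting at an "m"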


def _get_visibility_mask(commands, seq_dim=0):
    """
    Args:
        commands: Shape [S, ...]

    A sequence counts as visible iff fewer than S-1 of its S positions are
    EOS, i.e. it contains at least two non-EOS tokens.
    """
    S = commands.size(seq_dim)
    with torch.no_grad():
        visibility_mask = (commands == SVGTensor.COMMANDS_SIMPLIFIED.index("EOS")).sum(dim=seq_dim) < S - 1

        if seq_dim == 0:
            return visibility_mask.unsqueeze(-1)
        return visibility_mask
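
# Illustrative example (not part of the original file): commands laid out
# [S=3, groups=2]; the second group is all EOS and hence not visible:
#
#   eos = SVGTensor.COMMANDS_SIMPLIFIED.index("EOS")
#   m = SVGTensor.COMMANDS_SIMPLIFIED.index("m")
#   cmds = torch.tensor([[m, eos], [m, eos], [eos, eos]])
#   _get_visibility_mask(cmds)   # tensor([[ True], [False]])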


def _get_key_visibility_mask(commands, seq_dim=0):
    # complement of _get_visibility_mask: True for sequences made up
    # (almost) entirely of EOS tokens, which can be masked out as padding
    S = commands.size(seq_dim)
    with torch.no_grad():
        key_visibility_mask = (commands == SVGTensor.COMMANDS_SIMPLIFIED.index("EOS")).sum(dim=seq_dim) >= S - 1

        if seq_dim == 0:
            return key_visibility_mask.transpose(0, 1)
        return key_visibility_mask
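
# Illustrative example (not part of the original file): the same commands as
# above with an explicit batch axis, [S=3, groups=2, batch=1]; the result is
# the complement of the visibility mask, transposed to [batch, groups]:
#
#   cmds3d = torch.tensor([[[m], [eos]], [[m], [eos]], [[eos], [eos]]])
#   _get_key_visibility_mask(cmds3d)   # tensor([[False,  True]])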


def _generate_square_subsequent_mask(sz):
    # standard additive causal mask: 0.0 on and below the diagonal, -inf
    # above it, so position i can only attend to positions j <= i
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask
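
# Illustrative example (not part of the original file); this is the additive
# attn_mask format consumed by nn.Transformer:
#
#   _generate_square_subsequent_mask(3)
#   # tensor([[0., -inf, -inf],
#   #         [0., 0., -inf],
#   #         [0., 0., 0.]])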


def _sample_categorical(temperature=0.0001, *args_logits):
    # divide the logits by the temperature before sampling: values close to
    # 0 approach an argmax, larger values give more diverse samples
    if len(args_logits) == 1:
        arg_logits, = args_logits
        return Categorical(logits=arg_logits / temperature).sample()
    return tuple(Categorical(logits=arg_logits / temperature).sample() for arg_logits in args_logits)
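
# Illustrative example (not part of the original file): with the default
# near-zero temperature this is effectively an argmax over the last axis;
# passing several logits tensors yields one sample per tensor:
#
#   logits = torch.tensor([[0.1, 2.0, 0.3]])
#   _sample_categorical(0.0001, logits)               # tensor([1])
#   a, b = _sample_categorical(1.0, logits, logits)   # tuple of two samples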


def _threshold_sample(arg_logits, threshold=0.5, temperature=1.0):
    # deterministic binary decision: softmax over the two classes, then
    # compare the probability of class 1 against the threshold
    scores = F.softmax(arg_logits / temperature, dim=-1)[..., 1]
    return scores > threshold
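
# Illustrative example (not part of the original file): two-class logits of
# shape [..., 2]; keep the entries whose class-1 probability exceeds 0.5:
#
#   logits = torch.tensor([[0.0, 1.0], [1.0, 0.0]])
#   _threshold_sample(logits)   # tensor([ True, False])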