"""Mask and sampling helpers for SVG command sequences.

This code is taken from the work of Alexandre Carlier, Martin Danelljan,
Alexandre Alahi and Radu Timofte, described in the paper
<https://arxiv.org/pdf/2007.11301.pdf>.
"""

import torch
from src.preprocessing.deepsvg.deepsvg_difflib.tensor import SVGTensor
from torch.distributions.categorical import Categorical
import torch.nn.functional as F


def _get_key_padding_mask(commands, seq_dim=0):
    """Return a boolean mask that is True at and after the first EOS command.

    Args:
        commands: Tensor of command indices into
            ``SVGTensor.COMMANDS_SIMPLIFIED``; the sequence runs along
            ``seq_dim`` (shape ``[S, ...]`` for the default ``seq_dim=0``).
        seq_dim: Axis along which the command sequence runs.

    Returns:
        Bool tensor shaped like ``commands``, except that when
        ``seq_dim == 0`` dims 0 and 1 are swapped (e.g. ``[S, N] -> [N, S]``).
        NOTE(review): the swapped layout presumably feeds an attention
        ``key_padding_mask`` -- confirm against callers.
    """
    with torch.no_grad():
        # Running count of EOS tokens (inclusive) is > 0 exactly at the first
        # EOS and at every position after it.
        key_padding_mask = (commands == SVGTensor.COMMANDS_SIMPLIFIED.index("EOS")).cumsum(dim=seq_dim) > 0

        if seq_dim == 0:
            return key_padding_mask.transpose(0, 1)
        return key_padding_mask


def _get_padding_mask(commands, seq_dim=0, extended=False):
    """Return a float mask that is 1.0 before the first EOS command, 0.0 after.

    Args:
        commands: Tensor of command indices into
            ``SVGTensor.COMMANDS_SIMPLIFIED``; the sequence runs along
            ``seq_dim``.
        seq_dim: Axis along which the command sequence runs.
        extended: If True, every position ``j >= 3`` whose position ``j - 3``
            was 1.0 is also set to 1.0, extending the unpadded prefix past the
            first EOS so that EOS contributes to the loss.

    Returns:
        Float tensor shaped like ``commands``; when ``seq_dim == 0`` a trailing
        singleton dimension is appended.
    """
    with torch.no_grad():
        # Positions where no EOS has been seen yet (the EOS itself is excluded).
        padding_mask = (commands == SVGTensor.COMMANDS_SIMPLIFIED.index("EOS")).cumsum(dim=seq_dim) == 0
        padding_mask = padding_mask.float()

        if extended:
            # padding_mask doesn't include the final EOS; extend it so EOS is
            # included in the loss. The upstream comment says "by 1 position"
            # but the shift actually applied is 3 -- kept unchanged.
            # NOTE(review): confirm that a shift of 3 is intended.
            shift = 3
            S = commands.size(seq_dim)
            # With S <= shift there is no position j >= shift to extend into
            # (and narrow() would reject a negative length for S < shift).
            if S > shift:
                # Clone the source slice first: both narrowed views share
                # storage and overlap, so an in-place add between them can
                # raise a memory-overlap RuntimeError (dense views) or read
                # elements the same op already overwrote (strided views),
                # smearing the mask to the end of the sequence.
                shifted = torch.narrow(padding_mask, seq_dim, 0, S - shift).clone()
                torch.narrow(padding_mask, seq_dim, shift, S - shift).add_(shifted).clamp_(max=1)

        if seq_dim == 0:
            return padding_mask.unsqueeze(-1)
        return padding_mask


def _get_group_mask(commands, seq_dim=0):
    """Return the running count of ``"m"`` commands along ``seq_dim``.

    Args:
        commands: Tensor of command indices into
            ``SVGTensor.COMMANDS_SIMPLIFIED``; the sequence runs along
            ``seq_dim`` (shape ``[S, ...]`` for the default ``seq_dim=0``).
        seq_dim: Axis along which the command sequence runs.

    Returns:
        Integer tensor shaped like ``commands``: at each position, the number
        of ``"m"`` commands up to and including that position.
    """
    with torch.no_grad():
        group_mask = (commands == SVGTensor.COMMANDS_SIMPLIFIED.index("m")).cumsum(dim=seq_dim)
        return group_mask


def _get_visibility_mask(commands, seq_dim=0):
    """Return True for sequences that contain fewer than ``S - 1`` EOS commands.

    Args:
        commands: Tensor of command indices into
            ``SVGTensor.COMMANDS_SIMPLIFIED``; the sequence (length ``S``) runs
            along ``seq_dim`` (shape ``[S, ...]`` for the default
            ``seq_dim=0``).
        seq_dim: Axis along which the command sequence runs.

    Returns:
        Bool tensor with ``seq_dim`` reduced away; when ``seq_dim == 0`` a
        trailing singleton dimension is appended.
    """
    S = commands.size(seq_dim)
    with torch.no_grad():
        # Fewer than S - 1 EOS tokens means at least two non-EOS entries.
        visibility_mask = (commands == SVGTensor.COMMANDS_SIMPLIFIED.index("EOS")).sum(dim=seq_dim) < S - 1

        if seq_dim == 0:
            return visibility_mask.unsqueeze(-1)
        return visibility_mask


def _get_key_visibility_mask(commands, seq_dim=0):
    """Return True for sequences that contain at least ``S - 1`` EOS commands.

    This is the logical complement of ``_get_visibility_mask`` (before its
    reshaping).

    Args:
        commands: Tensor of command indices into
            ``SVGTensor.COMMANDS_SIMPLIFIED``; the sequence (length ``S``) runs
            along ``seq_dim``.
        seq_dim: Axis along which the command sequence runs.

    Returns:
        Bool tensor with ``seq_dim`` reduced away; when ``seq_dim == 0`` its
        first two dims are swapped, which requires ``commands`` to have at
        least three dims in that case.
    """
    S = commands.size(seq_dim)
    with torch.no_grad():
        key_visibility_mask = (commands == SVGTensor.COMMANDS_SIMPLIFIED.index("EOS")).sum(dim=seq_dim) >= S - 1

        if seq_dim == 0:
            return key_visibility_mask.transpose(0, 1)
        return key_visibility_mask


def _generate_square_subsequent_mask(sz):
    """Return an ``sz x sz`` additive causal attention mask.

    Entry ``[i, j]`` is ``0.0`` when ``j <= i`` and ``-inf`` when ``j > i``, so
    row ``i`` may only attend to columns up to and including ``i``.
    """
    # triu(...).T is True on and below the diagonal.
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask


def _sample_categorical(temperature=0.0001, *args_logits):
    """Draw one sample from each categorical distribution in ``args_logits``.

    Each logits tensor (categories on its last dim) is divided by
    ``temperature`` before sampling; smaller temperatures sharpen the
    distribution toward its argmax.

    Note: ``temperature`` is the first positional parameter, so callers must
    pass it explicitly before the logits, e.g.
    ``_sample_categorical(0.0001, logits)``. Calling
    ``_sample_categorical(logits)`` would bind ``logits`` to ``temperature``.

    Returns:
        A single tensor of sampled category indices when exactly one logits
        tensor is given; otherwise a tuple with one sample tensor per logits
        tensor (an empty tuple when none are given).
    """
    if len(args_logits) == 1:
        arg_logits, = args_logits
        return Categorical(logits=arg_logits / temperature).sample()
    return tuple(Categorical(logits=arg_logits / temperature).sample() for arg_logits in args_logits)


def _threshold_sample(arg_logits, threshold=0.5, temperature=1.0):
    """Return True where the softmax probability of class 1 exceeds ``threshold``.

    Softmax is taken over the last dimension of ``arg_logits / temperature``,
    and the probability at index 1 of that dimension is compared.
    """
    scores = F.softmax(arg_logits / temperature, dim=-1)[..., 1]
    return scores > threshold