import time from typing import Optional from typing import Tuple import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange from torch.nn.utils import weight_norm # Scripting this brings model speed up 1.4x @torch.jit.script def snake(x, alpha): shape = x.shape x = x.reshape(shape[0], shape[1], -1) x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2) x = x.reshape(shape) return x class Snake1d(nn.Module): def __init__(self, channels): super().__init__() self.alpha = nn.Parameter(torch.ones(1, channels, 1)) def forward(self, x): return snake(x, self.alpha) def num_params(model): return sum(p.numel() for p in model.parameters() if p.requires_grad) def recurse_children(module, fn): for child in module.children(): if isinstance(child, nn.ModuleList): for c in child: yield recurse_children(c, fn) if isinstance(child, nn.ModuleDict): for c in child.values(): yield recurse_children(c, fn) yield recurse_children(child, fn) yield fn(child) def WNConv1d(*args, **kwargs): return weight_norm(nn.Conv1d(*args, **kwargs)) def WNConvTranspose1d(*args, **kwargs): return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) class SequentialWithFiLM(nn.Module): """ handy wrapper for nn.Sequential that allows FiLM layers to be inserted in between other layers. """ def __init__(self, *layers): super().__init__() self.layers = nn.ModuleList(layers) @staticmethod def has_film(module): mod_has_film = any( [res for res in recurse_children(module, lambda c: isinstance(c, FiLM))] ) return mod_has_film def forward(self, x, cond): for layer in self.layers: if self.has_film(layer): x = layer(x, cond) else: x = layer(x) return x class FiLM(nn.Module): def __init__(self, input_dim: int, output_dim: int): super().__init__() self.input_dim = input_dim self.output_dim = output_dim if input_dim > 0: self.beta = nn.Linear(input_dim, output_dim) self.gamma = nn.Linear(input_dim, output_dim) def forward(self, x, r): if self.input_dim == 0: return x else: beta, gamma = self.beta(r), self.gamma(r) beta, gamma = ( beta.view(x.size(0), self.output_dim, 1), gamma.view(x.size(0), self.output_dim, 1), ) x = x * (gamma + 1) + beta return x class CodebookEmbedding(nn.Module): def __init__( self, vocab_size: int, latent_dim: int, n_codebooks: int, emb_dim: int, special_tokens: Optional[Tuple[str]] = None, ): super().__init__() self.n_codebooks = n_codebooks self.emb_dim = emb_dim self.latent_dim = latent_dim self.vocab_size = vocab_size if special_tokens is not None: for tkn in special_tokens: self.special = nn.ParameterDict( { tkn: nn.Parameter(torch.randn(n_codebooks, self.latent_dim)) for tkn in special_tokens } ) self.special_idxs = { tkn: i + vocab_size for i, tkn in enumerate(special_tokens) } self.out_proj = nn.Conv1d(n_codebooks * self.latent_dim, self.emb_dim, 1) def from_codes(self, codes: torch.Tensor, codec): """ get a sequence of continuous embeddings from a sequence of discrete codes. unlike it's counterpart in the original VQ-VAE, this function adds for any special tokens necessary for the language model, like . """ n_codebooks = codes.shape[1] latent = [] for i in range(n_codebooks): c = codes[:, i, :] lookup_table = codec.quantizer.quantizers[i].codebook.weight if hasattr(self, "special"): special_lookup = torch.cat( [self.special[tkn][i : i + 1] for tkn in self.special], dim=0 ) lookup_table = torch.cat([lookup_table, special_lookup], dim=0) l = F.embedding(c, lookup_table).transpose(1, 2) latent.append(l) latent = torch.cat(latent, dim=1) return latent def forward(self, latents: torch.Tensor): """ project a sequence of latents to a sequence of embeddings """ x = self.out_proj(latents) return x