from typing import Optional
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm


# Scripting this speeds the model up 1.4x
@torch.jit.script
def snake(x, alpha):
    # Snake activation: x + (1 / alpha) * sin^2(alpha * x), applied
    # element-wise; flatten trailing dims so alpha (1, C, 1) broadcasts.
    shape = x.shape
    x = x.reshape(shape[0], shape[1], -1)
    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
    x = x.reshape(shape)
    return x


class Snake1d(nn.Module):
    def __init__(self, channels):
        super().__init__()
        # One learnable frequency parameter per channel.
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x):
        return snake(x, self.alpha)
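

def _demo_snake():
    # Hypothetical usage sketch (not part of the original file): Snake1d is
    # an element-wise activation, so the output shape matches the input.
    act = Snake1d(channels=8)
    x = torch.randn(2, 8, 100)  # (batch, channels, time)
    return act(x)  # same shape: (2, 8, 100)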


def num_params(model):
    # Count only trainable parameters.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def recurse_children(module, fn):
    # Walk every descendant of `module`, yielding fn(child) for each one.
    # `yield from` flattens the recursion so callers receive values rather
    # than nested generator objects.
    for child in module.children():
        if isinstance(child, nn.ModuleList):
            for c in child:
                yield from recurse_children(c, fn)
        if isinstance(child, nn.ModuleDict):
            for c in child.values():
                yield from recurse_children(c, fn)

        yield from recurse_children(child, fn)
        yield fn(child)
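

def _demo_param_walk():
    # Hypothetical usage sketch (not part of the original file): count the
    # trainable parameters of a tiny stack and list its leaf module types.
    model = nn.Sequential(nn.Conv1d(4, 8, kernel_size=3), Snake1d(8))
    total = num_params(model)  # 96 (conv weight) + 8 (bias) + 8 (alpha) = 112
    names = list(recurse_children(model, lambda c: type(c).__name__))
    return total, names  # (112, ["Conv1d", "Snake1d"])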


def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


class SequentialWithFiLM(nn.Module):
    """
    Handy wrapper for nn.Sequential that allows FiLM layers to be
    inserted in between other layers.
    """

    def __init__(self, *layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)

    @staticmethod
    def has_film(module):
        # True if the module itself, or any of its descendants, is a FiLM layer.
        if isinstance(module, FiLM):
            return True
        return any(recurse_children(module, lambda c: isinstance(c, FiLM)))

    def forward(self, x, cond):
        for layer in self.layers:
            if self.has_film(layer):
                x = layer(x, cond)
            else:
                x = layer(x)
        return x


class FiLM(nn.Module):
    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim

        # With input_dim == 0 the layer is a no-op; otherwise the conditioning
        # vector is projected to a per-channel shift (beta) and scale (gamma).
        if input_dim > 0:
            self.beta = nn.Linear(input_dim, output_dim)
            self.gamma = nn.Linear(input_dim, output_dim)

    def forward(self, x, r):
        if self.input_dim == 0:
            return x
        else:
            beta, gamma = self.beta(r), self.gamma(r)
            beta, gamma = (
                beta.view(x.size(0), self.output_dim, 1),
                gamma.view(x.size(0), self.output_dim, 1),
            )
            x = x * (gamma + 1) + beta
        return x
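

def _demo_film_stack():
    # Hypothetical usage sketch (not part of the original file): layers that
    # contain FiLM receive the conditioning tensor; plain layers get x only.
    stack = SequentialWithFiLM(
        WNConv1d(8, 8, kernel_size=3, padding=1),
        FiLM(input_dim=16, output_dim=8),
        Snake1d(8),
    )
    x = torch.randn(2, 8, 50)  # (batch, channels, time)
    cond = torch.randn(2, 16)  # one conditioning vector per batch item
    return stack(x, cond)  # (2, 8, 50)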


class CodebookEmbedding(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        latent_dim: int,
        n_codebooks: int,
        emb_dim: int,
        special_tokens: Optional[Tuple[str]] = None,
    ):
        super().__init__()
        self.n_codebooks = n_codebooks
        self.emb_dim = emb_dim
        self.latent_dim = latent_dim
        self.vocab_size = vocab_size

        if special_tokens is not None:
            # One learned latent per codebook for each special token; special
            # tokens are indexed directly after the regular vocabulary.
            self.special = nn.ParameterDict(
                {
                    tkn: nn.Parameter(torch.randn(n_codebooks, self.latent_dim))
                    for tkn in special_tokens
                }
            )
            self.special_idxs = {
                tkn: i + vocab_size for i, tkn in enumerate(special_tokens)
            }

        self.out_proj = nn.Conv1d(n_codebooks * self.latent_dim, self.emb_dim, 1)

    def from_codes(self, codes: torch.Tensor, codec):
        """
        Get a sequence of continuous embeddings from a sequence of discrete codes.
        Unlike its counterpart in the original VQ-VAE, this function accounts for
        any special tokens necessary for the language model, like <MASK>.
        """
        n_codebooks = codes.shape[1]
        latent = []
        for i in range(n_codebooks):
            c = codes[:, i, :]
            lookup_table = codec.quantizer.quantizers[i].codebook.weight
            if hasattr(self, "special"):
                # Append the learned special-token latents for this codebook
                # so indices >= vocab_size resolve to them.
                special_lookup = torch.cat(
                    [self.special[tkn][i : i + 1] for tkn in self.special], dim=0
                )
                lookup_table = torch.cat([lookup_table, special_lookup], dim=0)

            l = F.embedding(c, lookup_table).transpose(1, 2)
            latent.append(l)
        latent = torch.cat(latent, dim=1)
        return latent

    def forward(self, latents: torch.Tensor):
        """
        Project a sequence of latents to a sequence of embeddings.
        """
        x = self.out_proj(latents)
        return x
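

def _demo_codebook_embedding():
    # Hypothetical usage sketch (not part of the original file). The real
    # `codec` is a pretrained VQ codec exposing one codebook per quantizer
    # as `codec.quantizer.quantizers[i].codebook`; a stand-in with random
    # codebooks mimics that interface here.
    from types import SimpleNamespace

    vocab_size, latent_dim, n_codebooks, emb_dim = 1024, 8, 4, 512
    codec = SimpleNamespace(
        quantizer=SimpleNamespace(
            quantizers=[
                SimpleNamespace(codebook=nn.Embedding(vocab_size, latent_dim))
                for _ in range(n_codebooks)
            ]
        )
    )
    emb = CodebookEmbedding(
        vocab_size, latent_dim, n_codebooks, emb_dim, special_tokens=("MASK",)
    )
    codes = torch.randint(0, vocab_size, (2, n_codebooks, 50))
    codes[:, 0, :10] = emb.special_idxs["MASK"]  # mask a few positions
    latents = emb.from_codes(codes, codec)  # (2, n_codebooks * latent_dim, 50)
    return emb(latents)  # (2, emb_dim, 50)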