import time
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm
# Scripting this brings model speed up 1.4x
@torch.jit.script
def snake(x, alpha):
    # Snake activation: x + (1 / alpha) * sin^2(alpha * x), applied channel-wise.
    shape = x.shape
    x = x.reshape(shape[0], shape[1], -1)
    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
    x = x.reshape(shape)
    return x

class Snake1d(nn.Module):
    """Snake activation with a learnable per-channel alpha."""

    def __init__(self, channels):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x):
        return snake(x, self.alpha)
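
# A minimal usage sketch (shapes are illustrative, not taken from the original
# training code): Snake1d acts on (batch, channels, time) tensors and returns a
# tensor of the same shape.
#
#   act = Snake1d(channels=64)
#   y = act(torch.randn(2, 64, 16000))  # y.shape == (2, 64, 16000)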

def num_params(model):
    """Count the trainable parameters of a model."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def recurse_children(module, fn):
    """Recursively yield fn(child) for every submodule of `module`."""
    for child in module.children():
        if isinstance(child, nn.ModuleList):
            for c in child:
                yield from recurse_children(c, fn)
        if isinstance(child, nn.ModuleDict):
            for c in child.values():
                yield from recurse_children(c, fn)

        yield from recurse_children(child, fn)
        yield fn(child)

def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))

class SequentialWithFiLM(nn.Module):
    """
    Handy wrapper around nn.Sequential that allows FiLM layers to be
    inserted in between other layers.
    """

    def __init__(self, *layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)

    @staticmethod
    def has_film(module):
        # True if the module is itself a FiLM layer or contains one anywhere.
        if isinstance(module, FiLM):
            return True
        return any(recurse_children(module, lambda c: isinstance(c, FiLM)))

    def forward(self, x, cond):
        for layer in self.layers:
            if self.has_film(layer):
                x = layer(x, cond)
            else:
                x = layer(x)
        return x
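
# Usage sketch (layer sizes are assumptions, not values from this repo): layers
# that are, or contain, a FiLM module (defined below) receive the conditioning
# tensor, while plain layers are called on the input alone.
#
#   block = SequentialWithFiLM(
#       WNConv1d(64, 64, kernel_size=3, padding=1),
#       FiLM(input_dim=128, output_dim=64),
#   )
#   y = block(torch.randn(2, 64, 100), cond=torch.randn(2, 128))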

class FiLM(nn.Module):
    """Feature-wise linear modulation: x -> x * (1 + gamma(r)) + beta(r)."""

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim

        if input_dim > 0:
            self.beta = nn.Linear(input_dim, output_dim)
            self.gamma = nn.Linear(input_dim, output_dim)

    def forward(self, x, r):
        if self.input_dim == 0:
            return x
        else:
            beta, gamma = self.beta(r), self.gamma(r)
            beta, gamma = (
                beta.view(x.size(0), self.output_dim, 1),
                gamma.view(x.size(0), self.output_dim, 1),
            )
            x = x * (gamma + 1) + beta
        return x
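
# Sketch (dimensions are assumptions): modulate a (batch, channels, time)
# feature map with a (batch, cond_dim) conditioning vector.
#
#   film = FiLM(input_dim=32, output_dim=64)
#   out = film(torch.randn(2, 64, 100), torch.randn(2, 32))  # (2, 64, 100)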

class CodebookEmbedding(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        latent_dim: int,
        n_codebooks: int,
        emb_dim: int,
        special_tokens: Optional[Tuple[str, ...]] = None,
    ):
        super().__init__()
        self.n_codebooks = n_codebooks
        self.emb_dim = emb_dim
        self.latent_dim = latent_dim
        self.vocab_size = vocab_size

        if special_tokens is not None:
            # One learnable latent per codebook for each special token (e.g. <MASK>).
            self.special = nn.ParameterDict(
                {
                    tkn: nn.Parameter(torch.randn(n_codebooks, self.latent_dim))
                    for tkn in special_tokens
                }
            )
            self.special_idxs = {
                tkn: i + vocab_size for i, tkn in enumerate(special_tokens)
            }

        self.out_proj = nn.Conv1d(n_codebooks * self.latent_dim, self.emb_dim, 1)

    def from_codes(self, codes: torch.Tensor, codec):
        """
        Get a sequence of continuous embeddings from a sequence of discrete codes.
        Unlike its counterpart in the original VQ-VAE, this function also accounts
        for any special tokens needed by the language model, like <MASK>.
        """
        n_codebooks = codes.shape[1]
        latent = []
        for i in range(n_codebooks):
            c = codes[:, i, :]

            lookup_table = codec.quantizer.quantizers[i].codebook.weight
            if hasattr(self, "special"):
                special_lookup = torch.cat(
                    [self.special[tkn][i : i + 1] for tkn in self.special], dim=0
                )
                lookup_table = torch.cat([lookup_table, special_lookup], dim=0)

            l = F.embedding(c, lookup_table).transpose(1, 2)
            latent.append(l)

        latent = torch.cat(latent, dim=1)
        return latent

    def forward(self, latents: torch.Tensor):
        """
        Project a sequence of latents to a sequence of embeddings.
        """
        x = self.out_proj(latents)
        return x
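

# A small smoke test (a sketch with assumed sizes). `from_codes` additionally
# needs a DAC-style codec exposing `quantizer.quantizers[i].codebook`, so only
# the latent-to-embedding projection is exercised here.
if __name__ == "__main__":
    emb = CodebookEmbedding(
        vocab_size=1024, latent_dim=8, n_codebooks=4, emb_dim=512,
        special_tokens=("MASK",),
    )
    latents = torch.randn(2, 4 * 8, 100)  # (batch, n_codebooks * latent_dim, time)
    print(emb(latents).shape)  # torch.Size([2, 512, 100])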