|
import functools |
|
import math |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torchaudio |
|
from x_transformers import ContinuousTransformerWrapper |
|
from x_transformers.x_transformers import RelativePositionBias |
|
|
|
|
|
def zero_module(module):
    """
    Overwrite every parameter of ``module`` with zeros, in place, and hand the module back.

    :param module: the nn.Module whose parameters are zeroed.
    :return: the very same module instance.
    """
    for weight in module.parameters():
        # Work on a detached alias so the in-place write is not tracked by autograd.
        detached = weight.detach()
        detached.zero_()
    return module
|
|
|
|
|
class GroupNorm32(nn.GroupNorm):
    """
    GroupNorm that runs the normalization on a float32 copy of the input and casts the result back to the
    input's original dtype.
    """

    def forward(self, x):
        incoming_dtype = x.dtype
        normalized = super().forward(x.float())
        return normalized.type(incoming_dtype)
|
|
|
|
|
def normalization(channels):
    """
    Build the group-normalization layer for a given channel count.

    The starting group count is 8 for up to 16 channels, 16 for up to 64 channels and 32 otherwise; it is then
    halved until it divides ``channels`` evenly. The result must stay above 2 groups.

    :param channels: number of input channels.
    :return: a GroupNorm32 module over ``channels`` channels.
    """
    if channels <= 16:
        groups = 8
    elif channels <= 64:
        groups = 16
    else:
        groups = 32
    # Shrink the group count until the channels split evenly across groups.
    while channels % groups:
        groups //= 2
    assert groups > 2
    return GroupNorm32(groups, channels)
|
|
|
|
|
class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
    """

    def __init__(self, n_heads):
        """
        :param n_heads: number of attention heads packed into the channel axis of the qkv input.
        """
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv, mask=None, rel_pos=None):
        """
        Apply QKV attention.

        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :param mask: optional [N x T] tensor multiplied into the attention weights along the key axis
                     (a 0 entry removes that key position for the corresponding sample).
        :param rel_pos: optional callable applied to the [N x H x T x T] attention logits before the softmax;
                        its result is reshaped back to [(N * H) x T x T].
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        # Heads are folded into the batch axis batch-major: row (b * n_heads + h) is head h of sample b.
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        # q and k are each scaled by ch ** -0.25, so their product carries the usual 1 / sqrt(ch) factor.
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)
        if rel_pos is not None:
            per_head = weight.reshape(bs, self.n_heads, length, length)
            weight = rel_pos(per_head).reshape(bs * self.n_heads, length, length)
        # Softmax is evaluated in float32 and cast back to the working dtype.
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        if mask is not None:
            if mask.dim() < 2:
                mask = mask.reshape(1, -1)
            # Each sample's mask row has to be repeated for that sample's own heads so it lines up with the
            # batch-major head folding above. Tiling the whole mask (mask.repeat(n_heads, 1)) would pair
            # heads with the mask of a different sample whenever the batch holds more than one item.
            mask = mask.repeat_interleave(self.n_heads, dim=0).unsqueeze(1)
            # The mask is applied after the softmax, so the remaining weights are not renormalised.
            weight = weight * mask
        a = torch.einsum("bts,bcs->bct", weight, v)

        return a.reshape(bs, -1, length)
|
|
|
|
|
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.

    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.

    :param channels: number of input (and output) channels.
    :param num_heads: number of attention heads, used when ``num_head_channels`` is -1.
    :param num_head_channels: if not -1, the head count becomes ``channels // num_head_channels`` and
                              ``num_heads`` is ignored; ``channels`` must be divisible by it.
    :param do_checkpoint: stored on the instance; forward() in this class does not consult it.
    :param relative_pos_embeddings: if True, a RelativePositionBias is created and applied to the attention
                                    logits.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        do_checkpoint=True,
        relative_pos_embeddings=False,
    ):
        super().__init__()
        self.channels = channels
        self.do_checkpoint = do_checkpoint
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.norm = normalization(channels)
        self.qkv = nn.Conv1d(channels, channels * 3, 1)

        self.attention = QKVAttentionLegacy(self.num_heads)

        # Zero-initialised output projection: the block starts out as an identity mapping (x + 0).
        self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
        if relative_pos_embeddings:
            # Use the effective head count (self.num_heads), not the raw num_heads argument, so the bias
            # matches the attention heads even when num_head_channels determined the head count.
            self.relative_pos_embeddings = RelativePositionBias(
                scale=(channels // self.num_heads) ** .5,
                causal=False,
                heads=self.num_heads,
                num_buckets=32,
                max_distance=64,
            )
        else:
            self.relative_pos_embeddings = None

    def forward(self, x, mask=None):
        """
        :param x: tensor of shape [batch, channels, *spatial]; the spatial axes are flattened for attention.
        :param mask: optional mask forwarded unchanged to the attention module.
        :return: tensor with the same shape as ``x`` (input plus the projected attention output).
        """
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv, mask, self.relative_pos_embeddings)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)
|
|
|
|
|
class Upsample(nn.Module):
    """
    Nearest-neighbour upsampling, optionally followed by a convolution.

    :param channels: channels in the inputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param out_channels: channels produced by the convolution; falls back to ``channels`` when not given.
    :param factor: scale factor handed to the interpolation.
    """

    def __init__(self, channels, use_conv, out_channels=None, factor=4):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.factor = factor
        if not use_conv:
            return
        # Kernel 5 with padding 2 leaves the sequence length unchanged.
        self.conv = nn.Conv1d(self.channels, self.out_channels, 5, padding=2)

    def forward(self, x):
        assert x.shape[1] == self.channels
        stretched = F.interpolate(x, scale_factor=self.factor, mode="nearest")
        return self.conv(stretched) if self.use_conv else stretched
|
|
|
|
|
class Downsample(nn.Module):
    """
    Downsampling by a strided convolution or, without convolution, by average pooling.

    :param channels: channels in the inputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param out_channels: channels in the outputs; falls back to ``channels`` and must equal it when no
                         convolution is used.
    :param factor: stride of the convolution, or kernel size and stride of the average pooling.
    :param ksize: kernel size of the convolution.
    :param pad: padding of the convolution.
    """

    def __init__(self, channels, use_conv, out_channels=None, factor=4, ksize=5, pad=2):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv

        if not use_conv:
            # Pooling cannot change the channel count.
            assert self.channels == self.out_channels
            self.op = nn.AvgPool1d(kernel_size=factor, stride=factor)
        else:
            self.op = nn.Conv1d(self.channels, self.out_channels, ksize, stride=factor, padding=pad)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)
|
|
|
|
|
class ResBlock(nn.Module):
    """
    Residual 1-D convolution block.

    The main path is norm -> SiLU -> conv followed by norm -> SiLU -> dropout -> zero-initialised conv; its
    output is added to a skip path taken from the (optionally resampled) input.

    :param channels: number of input channels.
    :param dropout: dropout probability used in the second half of the main path.
    :param out_channels: number of output channels; falls back to ``channels`` when not given.
    :param use_conv: when the channel count changes, use a ``kernel_size`` convolution on the skip path
                     instead of a 1x1 convolution.
    :param use_scale_shift_norm: stored on the instance; not otherwise used by this class.
    :param up: resample both paths with Upsample (no convolution, its default factor).
    :param down: resample both paths with Downsample (no convolution, its default factor); only consulted
                 when ``up`` is False.
    :param kernel_size: kernel size of the convolutions.
    """

    def __init__(
        self,
        channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        up=False,
        down=False,
        kernel_size=3,
    ):
        super().__init__()
        self.channels = channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_scale_shift_norm = use_scale_shift_norm
        # Length-preserving padding for odd kernels: 1 for kernel 3 and 2 for kernel 5 (as before), and the
        # matching value for other odd sizes, which the skip connection needs in order to line up.
        padding = kernel_size // 2

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding),
        )

        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, False)
            self.x_upd = Upsample(channels, False)
        elif down:
            self.h_upd = Downsample(channels, False)
            self.x_upd = Downsample(channels, False)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            # Zero-initialised so the main path contributes nothing at initialisation.
            zero_module(
                nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = nn.Conv1d(
                channels, self.out_channels, kernel_size, padding=padding
            )
        else:
            self.skip_connection = nn.Conv1d(channels, self.out_channels, 1)

    def forward(self, x):
        """
        :param x: input tensor with ``channels`` channels on axis 1.
        :return: skip path plus main path output.
        """
        if self.updown:
            # Resample between the activation and the first convolution, and resample the skip input too
            # so both paths end up with the same length.
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        h = self.out_layers(h)
        return self.skip_connection(x) + h
|
|
|
|
|
class AudioMiniEncoder(nn.Module):
    """
    Small convolution + attention encoder that reduces a spectrogram-like input to one embedding per sample.

    An input convolution is followed by ``depth`` stages of ``resnet_blocks`` ResBlocks plus a channel-doubling
    Downsample, a norm/SiLU/1x1-conv projection to ``embedding_dim`` and ``attn_blocks`` AttentionBlocks.
    forward() returns the features at the first position of the last axis.
    """

    def __init__(self,
                 spec_dim,
                 embedding_dim,
                 base_channels=128,
                 depth=2,
                 resnet_blocks=2,
                 attn_blocks=4,
                 num_attn_heads=4,
                 dropout=0,
                 downsample_factor=2,
                 kernel_size=3):
        super().__init__()
        self.init = nn.Sequential(nn.Conv1d(spec_dim, base_channels, 3, padding=1))

        width = base_channels
        stages = []
        for _ in range(depth):
            stages.extend(ResBlock(width, dropout, kernel_size=kernel_size) for _ in range(resnet_blocks))
            stages.append(Downsample(width, use_conv=True, out_channels=width * 2, factor=downsample_factor))
            # Every stage doubles the channel count.
            width *= 2
        self.res = nn.Sequential(*stages)

        self.final = nn.Sequential(
            normalization(width),
            nn.SiLU(),
            nn.Conv1d(width, embedding_dim, 1),
        )
        attention_stack = [AttentionBlock(embedding_dim, num_attn_heads) for _ in range(attn_blocks)]
        self.attn = nn.Sequential(*attention_stack)
        self.dim = embedding_dim

    def forward(self, x):
        features = self.attn(self.final(self.res(self.init(x))))
        # Only the first position along the last axis is kept as the embedding.
        return features[:, :, 0]
|
|
|
|
|
class TorchMelSpectrogram(nn.Module):
    """
    Log-mel spectrogram front end built on torchaudio.transforms.MelSpectrogram.

    The power mel spectrogram is clamped at 1e-5, passed through a natural log and, when a norm file was
    loaded, divided by the loaded norms.

    :param filter_length: FFT size (``n_fft``).
    :param hop_length: hop length between frames.
    :param win_length: window length.
    :param n_mel_channels: number of mel bins.
    :param mel_fmin: lowest mel filter frequency.
    :param mel_fmax: highest mel filter frequency.
    :param sampling_rate: sample rate of the input audio.
    :param normalize: forwarded to MelSpectrogram as ``normalized``.
    :param mel_norm_file: path of a file loaded with torch.load to obtain the norms, or None to skip
                          normalisation.
    """

    def __init__(self, filter_length=1024, hop_length=256, win_length=1024, n_mel_channels=80, mel_fmin=0, mel_fmax=8000,
                 sampling_rate=22050, normalize=False, mel_norm_file='data/mel_norms.pth'):
        super().__init__()

        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.n_mel_channels = n_mel_channels
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax
        self.sampling_rate = sampling_rate
        self.mel_stft = torchaudio.transforms.MelSpectrogram(n_fft=self.filter_length, hop_length=self.hop_length,
                                                             win_length=self.win_length, power=2, normalized=normalize,
                                                             sample_rate=self.sampling_rate, f_min=self.mel_fmin,
                                                             f_max=self.mel_fmax, n_mels=self.n_mel_channels,
                                                             norm="slaney")
        self.mel_norm_file = mel_norm_file
        if self.mel_norm_file is not None:
            # NOTE: torch.load unpickles the file; only point mel_norm_file at trusted data.
            # Loading onto the CPU keeps construction working on hosts without the device the norms were
            # saved from; forward() moves the norms to the input's device anyway.
            self.mel_norms = torch.load(self.mel_norm_file, map_location='cpu')
        else:
            self.mel_norms = None

    def forward(self, inp):
        """
        :param inp: 2-D tensor, or a 3-D tensor whose axis 1 has size 1 (it is squeezed away).
                    Presumably [batch, samples] / [batch, 1, samples] -- confirm against callers.
        :return: the (optionally normalised) log-mel spectrogram.
        """
        if len(inp.shape) == 3:
            inp = inp.squeeze(1)
        assert len(inp.shape) == 2
        # The transform and the norms follow the input's device lazily on every call.
        self.mel_stft = self.mel_stft.to(inp.device)
        mel = self.mel_stft(inp)

        # Clamp before the log so silent bins do not produce -inf.
        mel = torch.log(torch.clamp(mel, min=1e-5))
        if self.mel_norms is not None:
            self.mel_norms = self.mel_norms.to(mel.device)
            # Norms are broadcast over the first and last axes (presumably one value per mel bin).
            mel = mel / self.mel_norms.unsqueeze(0).unsqueeze(-1)
        return mel
|
|
|
|
|
class CheckpointedLayer(nn.Module):
    """
    Wraps a module so that its forward pass runs under activation checkpointing.

    Positional arguments are handed to torch.utils.checkpoint.checkpoint(). Keyword arguments are bound to the
    wrapped module beforehand with functools.partial, so they bypass checkpoint(); for that reason keyword
    arguments must not be tensors that require grad.
    """
    def __init__(self, wrap):
        """
        :param wrap: the module (or callable) whose forward pass is checkpointed.
        """
        super().__init__()
        self.wrap = wrap

    def forward(self, x, *args, **kwargs):
        """
        :param x: first positional input of the wrapped module.
        :param args: further positional inputs, passed through checkpoint().
        :param kwargs: keyword inputs bound directly to the wrapped module.
        :return: whatever the wrapped module returns.
        :raises AssertionError: if a keyword argument is a tensor that requires grad.
        """
        # Imported explicitly: a bare `import torch` does not guarantee that the torch.utils.checkpoint
        # submodule has been loaded as an attribute.
        from torch.utils.checkpoint import checkpoint

        for value in kwargs.values():
            assert not (isinstance(value, torch.Tensor) and value.requires_grad)
        partial = functools.partial(self.wrap, **kwargs)
        return checkpoint(partial, x, *args)
|
|
|
|
|
class CheckpointedXTransformerEncoder(nn.Module):
    """
    Builds a ContinuousTransformerWrapper from the given keyword arguments, optionally wraps the middle module
    of each of its attention layers in a CheckpointedLayer, and optionally swaps axes 1 and 2 on the way in
    and on the way out (channels-mid <-> channels-last).
    """
    def __init__(self, needs_permute=True, exit_permute=True, checkpoint=True, **xtransformer_kwargs):
        super().__init__()
        self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs)
        self.needs_permute = needs_permute
        self.exit_permute = exit_permute

        if checkpoint:
            layers = self.transformer.attn_layers.layers
            # Iterate over a snapshot because entries are replaced while walking the list.
            for idx, (norm, block, residual) in enumerate(list(layers)):
                layers[idx] = nn.ModuleList([norm, CheckpointedLayer(block), residual])

    def forward(self, x, **kwargs):
        inp = x.permute(0, 2, 1) if self.needs_permute else x
        out = self.transformer(inp, **kwargs)
        return out.permute(0, 2, 1) if self.exit_permute else out
|
return h |