"""
Copied from https://github.com/KdaiP/StableTTS by https://github.com/KdaiP
https://github.com/KdaiP/StableTTS/blob/eebb177ebf195fd1246dedabec4ef69d9351a4f8/models/dit.py
Code is under MIT License

Local changes: the rotary-embedding cache is rebuilt when the input moves to a
different device, and cached cos/sin values are cast to the input dtype.
"""

# References:
# https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/components/transformer.py
# https://github.com/jaywalnut310/vits/blob/main/attentions.py
# https://github.com/pytorch-labs/gpt-fast/blob/main/model.py

import torch
import torch.nn as nn
import torch.nn.functional as F


class FFN(nn.Module):
    """Masked two-layer 1-D convolutional feed-forward network.

    conv(in -> filter) -> GELU(tanh) -> dropout -> conv(filter -> out).
    The input is multiplied by ``x_mask`` before each convolution and the output
    is masked too, so positions where the mask is zero come out as zero.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        filter_channels: hidden channels between the two convolutions.
        kernel_size: kernel size of both convolutions; padding is
            ``kernel_size // 2`` (length-preserving for odd kernel sizes).
        p_dropout: dropout probability applied after the activation.
        gin_channels: stored on the module but not used here.
    """

    def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., gin_channels=0):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.drop = nn.Dropout(p_dropout)
        self.act1 = nn.GELU(approximate="tanh")

    def forward(self, x, x_mask):
        """
        Args:
            x: [batch, in_channels, time]
            x_mask: mask broadcastable against ``x`` (e.g. [batch, 1, time]).

        Returns:
            [batch, out_channels, time], zero wherever ``x_mask`` is zero.
        """
        x = self.conv_1(x * x_mask)
        x = self.act1(x)
        x = self.drop(x)
        # Re-mask so the second convolution does not read padded positions.
        x = self.conv_2(x * x_mask)
        return x * x_mask


class MultiHeadAttention(nn.Module):
    """Multi-head self-attention over [batch, channels, time] inputs.

    Q/K/V/output projections are 1x1 convolutions. Rotary position embeddings
    are applied to the first half of each head's features of the queries and
    keys (``RotaryPositionalEmbeddings(k_channels * 0.5)``); the remaining half
    passes through unrotated.

    Args:
        channels: input channels; must be divisible by ``n_heads``.
        out_channels: output channels of the final projection.
        n_heads: number of attention heads.
        p_dropout: attention-probability dropout, applied by
            ``F.scaled_dot_product_attention`` only in training mode.
    """

    def __init__(self, channels, out_channels, n_heads, p_dropout=0.):
        super().__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout

        self.k_channels = channels // n_heads
        self.conv_q = torch.nn.Conv1d(channels, channels, 1)
        self.conv_k = torch.nn.Conv1d(channels, channels, 1)
        self.conv_v = torch.nn.Conv1d(channels, channels, 1)

        # from https://nn.labml.ai/transformers/rope/index.html
        self.query_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)
        self.key_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)

        self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
        # Not used in forward: attention dropout is done inside
        # F.scaled_dot_product_attention via dropout_p.
        self.drop = torch.nn.Dropout(p_dropout)

        torch.nn.init.xavier_uniform_(self.conv_q.weight)
        torch.nn.init.xavier_uniform_(self.conv_k.weight)
        torch.nn.init.xavier_uniform_(self.conv_v.weight)

    def forward(self, x, attn_mask=None):
        """
        Args:
            x: [batch, channels, time]
            attn_mask: passed unchanged to ``F.scaled_dot_product_attention``
                (bool masks select positions, float masks are added to scores).

        Returns:
            [batch, out_channels, time]
        """
        q = self.conv_q(x)
        k = self.conv_k(x)
        v = self.conv_v(x)

        x = self.attention(q, k, v, mask=attn_mask)

        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask=None):
        """Split into heads, apply RoPE to q/k, and run scaled dot-product attention.

        Args:
            query: [b, d, t_t]
            key, value: [b, d, t_s]
            mask: optional ``attn_mask`` for ``F.scaled_dot_product_attention``.

        Returns:
            [b, d, t_t]
        """
        b, d, t_s, t_t = (*key.size(), query.size(2))
        # [b, d, t] -> [b, n_heads, t, k_channels]
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

        query = self.query_rotary_pe(query)  # [b, n_head, t, c // n_head]
        key = self.key_rotary_pe(key)

        output = F.scaled_dot_product_attention(query, key, value, attn_mask=mask, dropout_p=self.p_dropout if self.training else 0)
        output = output.transpose(2, 3).contiguous().view(b, d, t_t)  # [b, n_h, t_t, d_k] -> [b, d, t_t]
        return output


# modified from https://github.com/sh-lee-prml/HierSpeechpp/blob/main/modules.py#L390
class DiTConVBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.

    Self-attention and a convolutional FFN, each preceded by a non-affine
    LayerNorm whose output is shifted/scaled per sample from the condition
    vector ``c``, and each residual branch gated by a ``c``-derived gate.
    The block works on ``hidden_channels + out_channels`` channels throughout.
    Note: ``adaLN_modulation`` is not zero-initialized in this class.
    """

    def __init__(self, hidden_channels, out_channels, filter_channels, num_heads, kernel_size=3, p_dropout=0.1, gin_channels=0):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_channels + out_channels, elementwise_affine=False, eps=1e-6)
        self.attn = MultiHeadAttention(hidden_channels + out_channels, hidden_channels + out_channels, num_heads, p_dropout)
        self.norm2 = nn.LayerNorm(hidden_channels + out_channels, elementwise_affine=False, eps=1e-6)
        self.mlp = FFN(hidden_channels + out_channels, hidden_channels + out_channels, filter_channels, kernel_size, p_dropout=p_dropout)
        # Maps c to 6 * channels: (shift, scale, gate) for attention and for the FFN.
        # The first projection is skipped when c already has the block's width.
        self.adaLN_modulation = nn.Sequential(
            nn.Linear(gin_channels, hidden_channels + out_channels) if gin_channels != hidden_channels + out_channels else nn.Identity(),
            nn.SiLU(),
            nn.Linear(hidden_channels + out_channels, 6 * (hidden_channels + out_channels), bias=True)
        )

    def forward(self, x, c, x_mask):
        """
        Args:
            x : [batch_size, channel, time]
            c : [batch_size, channel]
            x_mask : [batch_size, 1, time]
        return the same shape as x
        """
        x = x * x_mask
        # [batch_size, 1, time, time]; entry (i, j) = x_mask[i] * x_mask[j].
        # NOTE(review): with a floating-point x_mask this is a float mask, and
        # F.scaled_dot_product_attention ADDS a float attn_mask to the scores
        # rather than excluding positions. Valid keys therefore get +1 and padded
        # keys +0, so padded keys are only down-weighted (by a factor of e), not
        # removed. Converting to bool (attn_mask.to(torch.bool)) would exclude
        # them, but rows for padded queries would then be fully masked, which can
        # produce NaN in SDPA. Kept as-is since pretrained weights were presumably
        # trained with this behavior; confirm before changing.
        attn_mask = x_mask.unsqueeze(1) * x_mask.unsqueeze(-1)

        # Each modulation tensor: [batch_size, channel, 1], broadcast over time.
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).unsqueeze(2).chunk(6, dim=1)

        # LayerNorm normalizes the last dim, hence the [B, C, T] <-> [B, T, C] transposes.
        x = x + gate_msa * self.attn(self.modulate(self.norm1(x.transpose(1, 2)).transpose(1, 2), shift_msa, scale_msa), attn_mask) * x_mask
        x = x + gate_mlp * self.mlp(self.modulate(self.norm2(x.transpose(1, 2)).transpose(1, 2), shift_mlp, scale_mlp), x_mask) * x_mask
        return x

    @staticmethod
    def modulate(x, shift, scale):
        """Apply adaptive affine modulation: ``x * (1 + scale) + shift``."""
        return x * (1 + scale) + shift


class RotaryPositionalEmbeddings(nn.Module):
    r"""
    ## RoPE module

    Rotary encoding transforms pairs of features by rotating in the 2D plane.
    That is, it organizes the $d$ features as $\frac{d}{2}$ pairs.
    Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it
    by an angle depending on the position of the token.
    """

    def __init__(self, d: int, base: int = 10_000):
        r"""
        * `d` is the number of features $d$ (converted with ``int``; features
          beyond the first `d` pass through unrotated)
        * `base` is the constant used for calculating $\Theta$
        """
        super().__init__()

        self.base = base
        self.d = int(d)
        # float32 caches of shape [seq_len, 1, 1, d]; plain attributes (not
        # registered buffers), so `module.to(...)` does not move them.
        self.cos_cached = None
        self.sin_cached = None

    def _build_cache(self, x: torch.Tensor):
        r"""
        Cache $\cos$ and $\sin$ values for positions ``0 .. x.shape[0] - 1``.

        `x` is laid out as `[seq_len, ...]`. The cache is (re)built when it is
        missing, shorter than `seq_len`, or on a different device than `x`.
        Values are kept in float32; `forward` casts them to `x.dtype`.
        """
        # Return if cache is already built, long enough, and on the right device
        if (
            self.cos_cached is not None
            and x.shape[0] <= self.cos_cached.shape[0]
            and self.cos_cached.device == x.device
        ):
            return

        # Get sequence length
        seq_len = x.shape[0]

        # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
        theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)

        # Create position indexes `[0, 1, ..., seq_len - 1]`
        seq_idx = torch.arange(seq_len, device=x.device).float()

        # Calculate the product of position index and $\theta_i$
        idx_theta = torch.einsum("n,d->nd", seq_idx, theta)

        # Concatenate so that for row $m$ we have
        # $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)

        # Cache them
        self.cos_cached = idx_theta2.cos()[:, None, None, :]
        self.sin_cached = idx_theta2.sin()[:, None, None, :]

    def _neg_half(self, x: torch.Tensor):
        # $\frac{d}{2}$
        d_2 = self.d // 2

        # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
        return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)

    def forward(self, x: torch.Tensor):
        """
        * `x` is the Tensor at the head of a key or a query with shape `[batch_size, n_heads, seq_len, d]`
          (it is permuted internally to `[seq_len, batch_size, n_heads, d]`)

        Returns a tensor with the same shape and dtype as `x`.
        """
        # Cache $\cos$ and $\sin$ values
        x = x.permute(2, 0, 1, 3)  # b h t d -> t b h d

        self._build_cache(x)

        # Split the features, we can choose to apply rotary embeddings only to a partial set of features.
        x_rope, x_pass = x[..., : self.d], x[..., self.d:]

        # Calculate
        # $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
        neg_half_x = self._neg_half(x_rope)

        # Cast the float32 cache to x's dtype so half/bfloat16 inputs are not
        # upcast (which would leave q/k and v with different dtypes in SDPA).
        seq_len = x.shape[0]
        cos = self.cos_cached[:seq_len].to(x.dtype)
        sin = self.sin_cached[:seq_len].to(x.dtype)

        x_rope = (x_rope * cos) + (neg_half_x * sin)

        return torch.cat((x_rope, x_pass), dim=-1).permute(1, 2, 0, 3)  # t b h d -> b h t d


class Transpose(nn.Identity):
    """Swap dims 1 and 2: (N, T, D) -> (N, D, T) (and vice versa)."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return input.transpose(1, 2)