import torch import torch.nn as nn import torch.nn.functional as F import numpy as np from .base_model import BaseModel class RMSNorm(nn.Module): def __init__(self, dimension, groups=1): super().__init__() self.weight = nn.Parameter(torch.ones(dimension)) self.groups = groups self.eps = 1e-5 def forward(self, input): # input size: (B, N, T) B, N, T = input.shape assert N % self.groups == 0 input_float = input.reshape(B, self.groups, -1, T).float() input_norm = input_float * torch.rsqrt(input_float.pow(2).mean(-2, keepdim=True) + self.eps) return input_norm.type_as(input).reshape(B, N, T) * self.weight.reshape(1, -1, 1) class RMVN(nn.Module): """ Rescaled MVN. """ def __init__(self, dimension, groups=1): super(RMVN, self).__init__() self.mean = nn.Parameter(torch.zeros(dimension)) self.std = nn.Parameter(torch.ones(dimension)) self.groups = groups self.eps = 1e-5 def forward(self, input): # input size: (B, N, *) B, N = input.shape[:2] assert N % self.groups == 0 input_reshape = input.reshape(B, self.groups, N // self.groups, -1) T = input_reshape.shape[-1] input_norm = (input_reshape - input_reshape.mean(2).unsqueeze(2)) / (input_reshape.var(2).unsqueeze(2) + self.eps).sqrt() input_norm = input_norm.reshape(B, N, T) * self.std.reshape(1, -1, 1) + self.mean.reshape(1, -1, 1) return input_norm.reshape(input.shape) class Roformer(nn.Module): """ Transformer with rotary positional embedding. """ def __init__(self, input_size, hidden_size, num_head=8, theta=10000, window=10000, input_drop=0., attention_drop=0., causal=True): super().__init__() self.input_size = input_size self.hidden_size = hidden_size // num_head self.num_head = num_head self.theta = theta # base frequency for RoPE self.window = window # pre-calculate rotary embeddings cos_freq, sin_freq = self._calc_rotary_emb() self.register_buffer("cos_freq", cos_freq) # win, N self.register_buffer("sin_freq", sin_freq) # win, N self.attention_drop = attention_drop self.causal = causal self.eps = 1e-5 self.input_norm = RMSNorm(self.input_size) self.input_drop = nn.Dropout(p=input_drop) self.weight = nn.Conv1d(self.input_size, self.hidden_size*self.num_head*3, 1, bias=False) self.output = nn.Conv1d(self.hidden_size*self.num_head, self.input_size, 1, bias=False) self.MLP = nn.Sequential(RMSNorm(self.input_size), nn.Conv1d(self.input_size, self.input_size*8, 1, bias=False), nn.SiLU() ) self.MLP_output = nn.Conv1d(self.input_size*4, self.input_size, 1, bias=False) def _calc_rotary_emb(self): freq = 1. / (self.theta ** (torch.arange(0, self.hidden_size, 2)[:(self.hidden_size // 2)] / self.hidden_size)) # theta_i freq = freq.reshape(1, -1) # 1, N//2 pos = torch.arange(0, self.window).reshape(-1, 1) # win, 1 cos_freq = torch.cos(pos*freq) # win, N//2 sin_freq = torch.sin(pos*freq) # win, N//2 cos_freq = torch.stack([cos_freq]*2, -1).reshape(self.window, self.hidden_size) # win, N sin_freq = torch.stack([sin_freq]*2, -1).reshape(self.window, self.hidden_size) # win, N return cos_freq, sin_freq def _add_rotary_emb(self, feature, pos): # feature shape: ..., N N = feature.shape[-1] feature_reshape = feature.reshape(-1, N) pos = min(pos, self.window-1) cos_freq = self.cos_freq[pos] sin_freq = self.sin_freq[pos] reverse_sign = torch.from_numpy(np.asarray([-1, 1])).to(feature.device).type(feature.dtype) feature_reshape_neg = (torch.flip(feature_reshape.reshape(-1, N//2, 2), [-1]) * reverse_sign.reshape(1, 1, 2)).reshape(-1, N) feature_rope = feature_reshape * cos_freq.unsqueeze(0) + feature_reshape_neg * sin_freq.unsqueeze(0) return feature_rope.reshape(feature.shape) def _add_rotary_sequence(self, feature): # feature shape: ..., T, N T, N = feature.shape[-2:] feature_reshape = feature.reshape(-1, T, N) cos_freq = self.cos_freq[:T] sin_freq = self.sin_freq[:T] reverse_sign = torch.from_numpy(np.asarray([-1, 1])).to(feature.device).type(feature.dtype) feature_reshape_neg = (torch.flip(feature_reshape.reshape(-1, N//2, 2), [-1]) * reverse_sign.reshape(1, 1, 2)).reshape(-1, T, N) feature_rope = feature_reshape * cos_freq.unsqueeze(0) + feature_reshape_neg * sin_freq.unsqueeze(0) return feature_rope.reshape(feature.shape) def forward(self, input): # input shape: B, N, T B, _, T = input.shape weight = self.weight(self.input_drop(self.input_norm(input))).reshape(B, self.num_head, self.hidden_size*3, T).mT Q, K, V = torch.split(weight, self.hidden_size, dim=-1) # B, num_head, T, N # rotary positional embedding Q_rot = self._add_rotary_sequence(Q) K_rot = self._add_rotary_sequence(K) attention_output = F.scaled_dot_product_attention(Q_rot.contiguous(), K_rot.contiguous(), V.contiguous(), dropout_p=self.attention_drop, is_causal=self.causal) # B, num_head, T, N attention_output = attention_output.mT.reshape(B, -1, T) output = self.output(attention_output) + input gate, z = self.MLP(output).chunk(2, dim=1) output = output + self.MLP_output(F.silu(gate) * z) return output, (K_rot, V) class ConvActNorm1d(nn.Module): def __init__(self, in_channel, hidden_channel, kernel=7, causal=False): super(ConvActNorm1d, self).__init__() self.in_channel = in_channel self.kernel = kernel self.causal = causal if not causal: self.conv = nn.Sequential(nn.Conv1d(in_channel, in_channel, kernel, padding=(kernel-1)//2, groups=in_channel), RMSNorm(in_channel), nn.Conv1d(in_channel, hidden_channel, 1), nn.SiLU(), nn.Conv1d(hidden_channel, in_channel, 1) ) else: self.conv = nn.Sequential(nn.Conv1d(in_channel, in_channel, kernel, padding=kernel-1, groups=in_channel), RMSNorm(in_channel), nn.Conv1d(in_channel, hidden_channel, 1), nn.SiLU(), nn.Conv1d(hidden_channel, in_channel, 1) ) def forward(self, input): output = self.conv(input) if self.causal: output = output[...,:-self.kernel+1] return input + output class ICB(nn.Module): def __init__(self, in_channel, kernel=7, causal=False): super(ICB, self).__init__() self.blocks = nn.Sequential(ConvActNorm1d(in_channel, in_channel*4, kernel, causal=causal), ConvActNorm1d(in_channel, in_channel*4, kernel, causal=causal), ConvActNorm1d(in_channel, in_channel*4, kernel, causal=causal) ) def forward(self, input): return self.blocks(input) class BSNet(nn.Module): def __init__(self, feature_dim, kernel=7): super(BSNet, self).__init__() self.feature_dim = feature_dim self.band_net = Roformer(self.feature_dim, self.feature_dim, num_head=8, window=100, causal=False) self.seq_net = ICB(self.feature_dim, kernel=kernel) def forward(self, input): # input shape: B, nband, N, T B, nband, N, T = input.shape # band comm band_input = input.permute(0,3,2,1).reshape(B*T, -1, nband) band_output, _ = self.band_net(band_input) band_output = band_output.reshape(B, T, -1, nband).permute(0,3,2,1) # sequence modeling output = self.seq_net(band_output.reshape(B*nband, -1, T)).reshape(B, nband, -1, T) # B, nband, N, T return output class Apollo(BaseModel): def __init__( self, sr: int, win: int, feature_dim: int, layer: int ): super().__init__(sample_rate=sr) self.sr = sr self.win = int(sr * win // 1000) self.stride = self.win // 2 self.enc_dim = self.win // 2 + 1 self.feature_dim = feature_dim self.eps = torch.finfo(torch.float32).eps # 80 bands bandwidth = int(self.win / 160) self.band_width = [bandwidth]*79 self.band_width.append(self.enc_dim - np.sum(self.band_width)) self.nband = len(self.band_width) print(self.band_width, self.nband) self.BN = nn.ModuleList([]) for i in range(self.nband): self.BN.append(nn.Sequential(RMSNorm(self.band_width[i]*2+1), nn.Conv1d(self.band_width[i]*2+1, self.feature_dim, 1)) ) self.net = [] for _ in range(layer): self.net.append(BSNet(self.feature_dim)) self.net = nn.Sequential(*self.net) self.output = nn.ModuleList([]) for i in range(self.nband): self.output.append(nn.Sequential(RMSNorm(self.feature_dim), nn.Conv1d(self.feature_dim, self.band_width[i]*4, 1), nn.GLU(dim=1) ) ) def spec_band_split(self, input): B, nch, nsample = input.shape spec = torch.stft(input.view(B*nch, nsample), n_fft=self.win, hop_length=self.stride, window=torch.hann_window(self.win).to(input.device), return_complex=True) subband_spec = [] subband_spec_norm = [] subband_power = [] band_idx = 0 for i in range(self.nband): this_spec = spec[:,band_idx:band_idx+self.band_width[i]] subband_spec.append(this_spec) # B, BW, T subband_power.append((this_spec.abs().pow(2).sum(1) + self.eps).sqrt().unsqueeze(1)) # B, 1, T subband_spec_norm.append(torch.complex(this_spec.real / subband_power[-1], this_spec.imag / subband_power[-1])) # B, BW, T band_idx += self.band_width[i] subband_power = torch.cat(subband_power, 1) # B, nband, T return subband_spec_norm, subband_power def feature_extractor(self, input): subband_spec_norm, subband_power = self.spec_band_split(input) # normalization and bottleneck subband_feature = [] for i in range(self.nband): concat_spec = torch.cat([subband_spec_norm[i].real, subband_spec_norm[i].imag, torch.log(subband_power[:,i].unsqueeze(1))], 1) subband_feature.append(self.BN[i](concat_spec)) subband_feature = torch.stack(subband_feature, 1) # B, nband, N, T return subband_feature def forward(self, input): B, nch, nsample = input.shape subband_feature = self.feature_extractor(input) feature = self.net(subband_feature) est_spec = [] for i in range(self.nband): this_RI = self.output[i](feature[:,i]).view(B*nch, 2, self.band_width[i], -1) est_spec.append(torch.complex(this_RI[:,0], this_RI[:,1])) est_spec = torch.cat(est_spec, 1) output = torch.istft(est_spec, n_fft=self.win, hop_length=self.stride, window=torch.hann_window(self.win).to(input.device), length=nsample).view(B, nch, -1) return output def get_model_args(self): model_args = {"n_sample_rate": 2} return model_args