import os
import sys
import math
import torch
import librosa.display
import auraloss
import torchaudio
import numpy as np
import scipy.signal
from tqdm.notebook import tqdm
from time import sleep
import pyloudnorm as pyln
import gradio as gr

# Location of the pre-trained effect models and the files fetched from it.
MODEL_BASE_URL = "https://csteinmetz1.github.io/steerable-nafx/models"
MODEL_FILES = (
    "compressor_full.pt",
    "reverb_full.pt",
    "amp_full.pt",
    "delay_full.pt",
    "synth2synth_full.pt",
)


def _download_if_missing(filename):
    """Fetch `filename` from MODEL_BASE_URL with wget unless it already exists.

    Plain `wget` on an already-present file saves a numbered copy
    (e.g. `compressor_full.pt.1`) that is never loaded, so existing files are
    skipped instead of re-downloaded on every start-up.
    """
    if not os.path.exists(filename):
        os.system("wget " + MODEL_BASE_URL + "/" + filename)


for _model_file in MODEL_FILES:
    _download_if_missing(_model_file)


def measure_rt60(h, fs=1, decay_db=30, rt60_tgt=None):
    """
    Analyze the RT60 of an impulse response.

    The energy decay curve is obtained by Schroeder backward integration. The
    time taken to decay from -5 dB to -(5 + decay_db) dB is measured and
    linearly extrapolated to a 60 dB decay.

    Args:
        h (ndarray): The discrete time impulse response as 1d array.
        fs (float, optional): Sample rate of the impulse response. (Default: 1)
        decay_db (float, optional): The decay in decibels over which the time
            is actually measured before extrapolating to 60 dB. (Default: 30)
        rt60_tgt (float, optional): Accepted for API compatibility; currently
            unused. (Default: None)

    Returns:
        est_rt60 (float): Estimated RT60, in seconds when `fs` is the sample
            rate (in samples with the default fs=1). Returns ``np.array(0.0)``
            when the decay cannot be measured, e.g. for an all-zero response
            or one that never decays by ``5 + decay_db`` dB.
    """
    h = np.array(h)
    fs = float(fs)

    # The power of the impulse response
    power = h ** 2
    # Integration according to Schroeder (reverse cumulative energy)
    energy = np.cumsum(power[::-1])[::-1]

    try:
        # remove the possibly all zero tail
        i_nz = np.max(np.where(energy > 0)[0])
        energy = energy[:i_nz]
        energy_db = 10 * np.log10(energy)
        energy_db -= energy_db[0]

        # -5 dB headroom
        i_5db = np.min(np.where(-5 - energy_db > 0)[0])
        t_5db = i_5db / fs

        # after decay
        i_decay = np.min(np.where(-5 - decay_db - energy_db > 0)[0])
        t_decay = i_decay / fs

        # compute the decay time and extrapolate it to a 60 dB decay
        decay_time = t_decay - t_5db
        est_rt60 = (60 / decay_db) * decay_time
    except (ValueError, IndexError):
        # np.max / np.min raise ValueError on an empty selection (no energy, or
        # the curve never reaches the threshold); energy_db[0] raises
        # IndexError when the trimmed curve is empty.
        est_rt60 = np.array(0.0)

    return est_rt60


def causal_crop(x, length: int):
    """Crop the last (time) axis of `x` to `length` samples taken from its end.

    NOTE: the kept window is x[..., N-1-length : N-1], i.e. it stops one sample
    short of the final sample. The pre-trained models loaded below presumably
    were trained with this exact alignment (TCNBlock's residual path uses it),
    so it is intentionally left unchanged. Tensors already `length` samples
    long are returned untouched.
    """
    if x.shape[-1] != length:
        stop = x.shape[-1] - 1
        start = stop - length
        x = x[..., start:stop]
    return x


class FiLM(torch.nn.Module):
    """Feature-wise linear modulation of conv features by a conditioning vector.

    A linear layer maps the conditioning to a per-channel gain `g` and bias `b`,
    applied as ``x * g + b`` (optionally after a non-affine BatchNorm).

    Attribute names must not change: the pickled models loaded below restore
    their state into these attributes, and `forward` reads them.
    """

    def __init__(
        self,
        cond_dim,  # dim of conditioning input
        num_features,  # dim of the conv channel
        batch_norm=True,
    ):
        super().__init__()
        self.num_features = num_features
        self.batch_norm = batch_norm
        if batch_norm:
            self.bn = torch.nn.BatchNorm1d(num_features, affine=False)
        self.adaptor = torch.nn.Linear(cond_dim, num_features * 2)

    def forward(self, x, cond):
        # cond must be 3-D for the permute below; inference() builds it as
        # (1, 1, cond_dim), so g and b become (1, num_features, 1) and
        # broadcast over the time axis of x.
        cond = self.adaptor(cond)
        g, b = torch.chunk(cond, 2, dim=-1)
        g = g.permute(0, 2, 1)
        b = b.permute(0, 2, 1)

        if self.batch_norm:
            x = self.bn(x)  # apply BatchNorm without affine
        x = (x * g) + b  # then apply conditional affine

        return x


class TCNBlock(torch.nn.Module):
    """One dilated, unpadded Conv1d layer with optional FiLM and PReLU, plus a
    1x1-conv residual path cropped (via causal_crop) to the conv output length.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation, cond_dim=0, activation=True):
        super().__init__()
        # padding=0: each block shortens the signal by (kernel_size-1)*dilation;
        # the caller pre-pads the input to compensate (see inference()).
        self.conv = torch.nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            dilation=dilation,
            padding=0,
            bias=True,
        )
        if cond_dim > 0:
            self.film = FiLM(cond_dim, out_channels, batch_norm=False)
        if activation:
            self.act = torch.nn.PReLU()
        self.res = torch.nn.Conv1d(in_channels, out_channels, 1, bias=False)

    def forward(self, x, c=None):
        x_in = x
        x = self.conv(x)
        # Optional sub-modules are detected by attribute presence, which also
        # works for instances restored from the pickled models.
        if hasattr(self, "film"):
            x = self.film(x, c)
        if hasattr(self, "act"):
            x = self.act(x)
        x_res = causal_crop(self.res(x_in), x.shape[-1])
        x = x + x_res

        return x


class TCN(torch.nn.Module):
    """Stack of `n_blocks` TCNBlocks; block n uses dilation dilation_growth**n.

    The first block maps n_inputs -> n_channels, the last maps
    n_channels -> n_outputs, and the rest are n_channels -> n_channels.
    """

    def __init__(self, n_inputs=1, n_outputs=1, n_blocks=10, kernel_size=13, n_channels=64, dilation_growth=4, cond_dim=0):
        super().__init__()
        self.kernel_size = kernel_size
        self.n_channels = n_channels
        self.dilation_growth = dilation_growth
        self.n_blocks = n_blocks
        self.stack_size = n_blocks

        self.blocks = torch.nn.ModuleList()
        for n in range(n_blocks):
            in_ch = n_inputs if n == 0 else n_channels
            # The final block always produces n_outputs channels, including the
            # single-block case (previously that block output n_channels).
            out_ch = n_outputs if n == n_blocks - 1 else n_channels
            dilation = dilation_growth ** n
            self.blocks.append(
                TCNBlock(in_ch, out_ch, kernel_size, dilation, cond_dim=cond_dim, activation=True)
            )

    def forward(self, x, c=None):
        for block in self.blocks:
            x = block(x, c)
        return x

    def compute_receptive_field(self):
        """Compute the receptive field in samples.

        rf = kernel_size + sum over blocks n >= 1 of (kernel_size - 1) * dilation_n,
        which also equals 1 + the total number of samples the unpadded
        convolutions remove.
        """
        rf = self.kernel_size
        for n in range(1, self.n_blocks):
            dilation = self.dilation_growth ** (n % self.stack_size)
            rf = rf + ((self.kernel_size - 1) * dilation)
        return rf


# setup the pre-trained models
# NOTE(review): these files are fully pickled nn.Modules, so torch.load runs
# pickle code -- only load them from a trusted source. Unpickling looks up
# FiLM / TCNBlock / TCN by name, so they must stay defined above this point.
# Newer torch releases default torch.load to weights_only=True, which rejects
# full-module pickles; pass weights_only=False there if loading fails.
model_comp = torch.load("compressor_full.pt", map_location="cpu").eval()
model_verb = torch.load("reverb_full.pt", map_location="cpu").eval()
model_amp = torch.load("amp_full.pt", map_location="cpu").eval()
model_delay = torch.load("delay_full.pt", map_location="cpu").eval()
model_synth = torch.load("synth2synth_full.pt", map_location="cpu").eval()

# Effect name (as offered in the UI) -> pre-trained model.
MODELS = {
    "Compressor": model_comp,
    "Reverb": model_verb,
    "Amp": model_amp,
    "Analog Delay": model_delay,
    "Synth2Synth": model_synth,
}

# Samples dropped from the start of the rendered output (start-up transient).
TRANSIENT_SAMPLES = 8192


def _peak_normalize(x):
    """Scale `x` so its peak absolute value is 1.

    Empty or all-zero tensors are returned unchanged instead of producing
    NaNs from a 0/0 division.
    """
    if x.numel() == 0:
        return x
    peak = x.abs().max()
    return x / peak if peak > 0 else x


def inference(
    aud,
    effect_type,
    *,
    gain_dB=-24,
    c0=-1.4,
    c1=3,
    mix=70,
    width=50,
    max_length=30,
    stereo=True,
    tail=True,
):
    """Process an audio file with one of the pre-trained effect models.

    Args:
        aud: Path to the input audio file (passed to torchaudio.load).
        effect_type: One of "Compressor", "Reverb", "Amp", "Analog Delay",
            "Synth2Synth" (the keys of MODELS).
        gain_dB: Gain in dB applied to the input before the model. (Default: -24)
        c0, c1: The two conditioning values fed to the model.
            (Defaults: -1.4 and 3)
        mix: Wet/dry mix in percent; 100 is fully processed. (Default: 70)
        width: Stereo width in percent; the conditioning is offset by
            +width*5e-3 on the first channel and -width*5e-3 on the others.
            (Default: 50)
        max_length: Maximum input duration in seconds; longer input is
            truncated. (Default: 30)
        stereo: Duplicate a mono input to two channels so `width` applies.
            (Default: True)
        tail: Zero-pad the end of the input so the output continues for
            (receptive field - 1) samples past the end of the input.
            (Default: True)

    Returns:
        str: Path of the written MP3 file, "output.mp3".

    Raises:
        ValueError: If no audio is given or `effect_type` is unknown.
    """
    if aud is None:
        raise ValueError("No input audio provided")
    try:
        pt_model = MODELS[effect_type]
    except KeyError:
        raise ValueError(
            "Unknown effect_type %r; expected one of %s" % (effect_type, list(MODELS))
        ) from None

    x_p, sample_rate = torchaudio.load(aud)

    # measure the receptive field
    pt_model_rf = pt_model.compute_receptive_field()

    # crop input signal if needed
    max_samples = int(sample_rate * max_length)
    x_p_crop = x_p[:, :max_samples]
    chs = x_p_crop.shape[0]

    # if mono and stereo requested, duplicate the channel
    if chs == 1 and stereo:
        x_p_crop = x_p_crop.repeat(2, 1)
        chs = 2

    # pad the input signal: the unpadded convolutions consume rf - 1 samples,
    # so front padding keeps the output aligned with (and as long as) the input
    front_pad = pt_model_rf - 1
    back_pad = front_pad if tail else 0
    x_p_pad = torch.nn.functional.pad(x_p_crop, (front_pad, back_pad))

    # design highpass filter (8th-order Butterworth at 20 Hz)
    sos = scipy.signal.butter(8, 20.0, fs=sample_rate, output="sos", btype="highpass")

    # compute linear gain
    gain_ln = 10 ** (gain_dB / 20.0)

    # conditioning offset: +spread for channel 0, -spread for every other channel
    spread = width * 5e-3

    # process audio with pre-trained model
    with torch.no_grad():
        y_hat = torch.zeros(chs, x_p_crop.shape[1] + back_pad)
        for n in range(chs):
            factor = spread if n == 0 else -spread
            c = torch.tensor([float(c0 + factor), float(c1 + factor)]).view(1, 1, -1)
            y_hat_ch = pt_model(gain_ln * x_p_pad[n, :].view(1, 1, -1), c)
            y_hat_ch = scipy.signal.sosfilt(sos, y_hat_ch.view(-1).numpy())
            y_hat[n, :] = torch.tensor(y_hat_ch)

    # pad the dry signal to the wet signal's length
    x_dry = torch.nn.functional.pad(x_p_crop, (0, back_pad))

    # normalize each first
    y_hat = _peak_normalize(y_hat)
    x_dry = _peak_normalize(x_dry)

    # mix
    wet = mix / 100.0
    y_hat = (wet * y_hat) + ((1 - wet) * x_dry)

    # remove transient, unless that would leave no audio at all
    if y_hat.shape[-1] > TRANSIENT_SAMPLES:
        y_hat = y_hat[..., TRANSIENT_SAMPLES:]
    y_hat = _peak_normalize(y_hat)

    torchaudio.save("output.mp3", y_hat.view(chs, -1), sample_rate, compression=320.0)
    return "output.mp3"


title = "Steerable nafx"
# User-facing texts for the demo page. The previous texts were copied from a
# Demucs (music source separation) demo and described the wrong project; they
# also referred to examples that this interface does not provide.
description = (
    "Gradio demo for steerable-nafx: neural audio effects implemented as "
    "conditioned temporal convolutional networks. To use it, simply upload your "
    "audio, choose an effect type, and listen to the processed output. "
    "Read more at the links below."
)
article = (
    "<p style='text-align: center'>"
    "<a href='https://csteinmetz1.github.io/steerable-nafx/' target='_blank'>"
    "Steerable discovery of neural audio effects</a> | "
    "<a href='https://github.com/csteinmetz1/steerable-nafx' target='_blank'>"
    "Github Repo</a></p>"
)

# Build and launch the web UI: an audio file plus an effect choice go into
# `inference`, which returns the path of the processed audio.
gr.Interface(
    inference,
    [
        gr.inputs.Audio(type="filepath", label="Input"),
        gr.inputs.Dropdown(
            choices=["Compressor", "Reverb", "Amp", "Analog Delay", "Synth2Synth"],
            type="value",
            default="Analog Delay",
            label="Effect Type",
        ),
    ],
    gr.outputs.Audio(type="file", label="Output"),
    title=title,
    description=description,
    article=article,
    enable_queue=True,
).launch(debug=True)