# NOTE: `spaces` is imported before `torch`, as in the original file; keep this
# order (HF ZeroGPU expects `spaces` to be imported before CUDA is touched).
import gradio as gr
import spaces
import numpy as np
import torch
from fastgeco.model import ScoreModel
from geco.util.other import pad_spec
import os
import torchaudio
from speechbrain.lobes.models.dual_path import (
    Encoder,
    SBTransformerBlock,
    Dual_Path_Model,
    Decoder,
)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
sample_rate = 8000  # rate the SepFormer pipeline runs at; inputs are resampled to it
num_spks = 2
ckpt_path = 'ckpts/'

# Floor for peak-normalisation divisors so an all-silent signal yields zeros
# instead of NaNs from a 0/0 division.
_EPS = 1e-8


def _make_sb_transformer():
    """Return one SepFormer transformer block (same config for intra and inter paths)."""
    return SBTransformerBlock(
        num_layers=8,
        d_model=256,
        nhead=8,
        d_ffn=1024,
        dropout=0,
        use_positional_encoding=True,
        norm_before=True,
    )


def load_sepformer(ckpt_path):
    """Build the SepFormer encoder, masking network and decoder and load their weights.

    Args:
        ckpt_path: Directory containing the ``encoder.ckpt``, ``masknet.ckpt``
            and ``decoder.ckpt`` state dicts.

    Returns:
        ``(encoder, masknet, decoder)``, each in eval mode and moved to ``device``.
    """
    encoder = Encoder(kernel_size=160, out_channels=256, in_channels=1)
    masknet = Dual_Path_Model(
        num_spks=num_spks,
        in_channels=256,
        out_channels=256,
        num_layers=2,
        K=250,
        intra_model=_make_sb_transformer(),
        inter_model=_make_sb_transformer(),
        norm='ln',
        linear_layer_after_inter_intra=False,
        skip_around_intra=True,
    )
    decoder = Decoder(
        in_channels=256,
        out_channels=1,
        kernel_size=160,
        stride=80,
        bias=False,
    )

    for module, name in ((encoder, 'encoder'), (masknet, 'masknet'), (decoder, 'decoder')):
        # Load onto CPU first so checkpoints saved from a GPU also load on
        # CPU-only hosts; the modules are moved to `device` afterwards.
        state = torch.load(os.path.join(ckpt_path, f'{name}.ckpt'), map_location='cpu')
        module.load_state_dict(state)

    encoder = encoder.eval().to(device)
    masknet = masknet.eval().to(device)
    decoder = decoder.eval().to(device)
    return encoder, masknet, decoder


def load_fastgeco(ckpt_path):
    """Load the Fast-GeCo ``ScoreModel`` from ``<ckpt_path>/fastgeco.ckpt`` onto ``device``."""
    checkpoint_file = os.path.join(ckpt_path, 'fastgeco.ckpt')
    model = ScoreModel.load_from_checkpoint(
        checkpoint_file,
        batch_size=1,
        num_workers=0,
        kwargs=dict(gpu=False),
    )
    # Project-specific eval() signature; presumably keeps EMA weights active — confirm.
    model.eval(no_ema=False)
    model.to(device)
    return model


encoder, masknet, decoder = load_sepformer(ckpt_path)
fastgeco_model = load_fastgeco(ckpt_path)


@spaces.GPU
def separate(test_file, encoder, masknet, decoder):
    """Separate a mixture file into ``num_spks`` sources with SepFormer.

    Args:
        test_file: Path to the mixture audio file.
        encoder, masknet, decoder: Modules returned by ``load_sepformer``.

    Returns:
        ``(est_sources, mix)``. ``est_sources`` is the squeezed stack of
        per-speaker estimates with speakers on the last axis (presumably
        ``(time, num_spks)`` for a single file — confirm), each peak-normalised.
        ``mix`` is the mono mixture at ``sample_rate`` Hz, shape ``(1, time)``,
        on ``device``.
    """
    with torch.no_grad():
        print('Process SepFormer...')
        # torchaudio.load returns (channels, frames) and the file's sample rate.
        mix, fs_file = torchaudio.load(test_file)
        mix = mix.to(device)
        # Always downmix to mono. Previously this only happened when resampling,
        # so multi-channel input already at `sample_rate` reached the encoder as
        # a batch of channels and produced wrongly shaped sources.
        mix = mix.mean(dim=0, keepdim=True)

        fs_model = sample_rate
        if fs_file != fs_model:
            print(f"Resampling the audio from {fs_file} Hz to {fs_model} Hz")
            resampler = torchaudio.transforms.Resample(
                orig_freq=fs_file, new_freq=fs_model
            ).to(device)
            mix = resampler(mix)

        # Separation: one copy of the encoded mixture per speaker, each masked.
        mix_w = encoder(mix)
        est_mask = masknet(mix_w)
        mix_w = torch.stack([mix_w] * num_spks)
        sep_h = mix_w * est_mask

        # Decoding: stack decoded speakers along a new trailing axis.
        est_sources = torch.cat(
            [decoder(sep_h[i]).unsqueeze(-1) for i in range(num_spks)],
            dim=-1,
        )
        # Peak-normalise each speaker; the clamp avoids NaNs for silent outputs.
        peak = est_sources.abs().max(dim=1, keepdim=True)[0].clamp_min(_EPS)
        est_sources = (est_sources / peak).squeeze()
    return est_sources, mix


@spaces.GPU
def correct(model, est_sources, mix):
    """Refine each SepFormer estimate with a single Fast-GeCo reverse step.

    Args:
        model: The Fast-GeCo ``ScoreModel`` from ``load_fastgeco``.
        est_sources: Output of ``separate``; speaker ``idx`` is ``est_sources[:, idx]``.
        mix: Mono mixture from ``separate``, shape ``(1, time)``.

    Returns:
        Two ``(sample_rate, waveform)`` tuples, one per speaker, where each
        waveform is a peak-normalised numpy array (the form
        ``gr.Audio(type="numpy")`` accepts).
    """
    with torch.no_grad():
        print('Process Fast-Geco...')
        N = 1
        reverse_starting_point = 0.5
        output = []
        for idx in range(num_spks):
            y = est_sources[:, idx].unsqueeze(0)  # noisy estimate, (1, time)
            m = mix
            # Trim both to the shorter length so their spectrograms line up.
            min_leng = min(y.shape[-1], m.shape[-1])
            y = y[..., :min_leng]
            m = m[..., :min_leng]
            T_orig = y.size(1)

            # Scale estimate and mixture by the same factor (the estimate's peak).
            norm_factor = y.abs().max().clamp_min(_EPS)
            y = y / norm_factor
            m = m / norm_factor

            Y = torch.unsqueeze(model._forward_transform(model._stft(y.to(device))), 0)
            Y = pad_spec(Y)
            M = torch.unsqueeze(model._forward_transform(model._stft(m.to(device))), 0)
            M = pad_spec(M)

            # With N == 1, linspace returns only its start value, so t and dt
            # below both equal reverse_starting_point.
            timesteps = torch.linspace(reverse_starting_point, 0.03, N, device=Y.device)
            std = model.sde._std(
                reverse_starting_point * torch.ones((Y.shape[0],), device=Y.device)
            )
            # Perturb the estimate's spectrogram to the starting noise level.
            z = torch.randn_like(Y)
            X_t = Y + z * std[:, None, None, None]

            t = timesteps[0]
            dt = timesteps[-1]
            f, g = model.sde.sde(X_t, t, Y)
            vec_t = torch.ones(Y.shape[0], device=Y.device) * t
            # Mean of x_{t-1}; model.forward presumably returns the score estimate.
            mean_x_tm1 = X_t - (
                f - g**2 * model.forward(X_t, vec_t, Y, M, vec_t[:, None, None, None])
            ) * dt
            sample = mean_x_tm1.squeeze()

            x_hat = model.to_audio(sample.squeeze(), T_orig)
            x_hat = x_hat * norm_factor
            new_norm_factor = x_hat.abs().max().clamp_min(_EPS)
            x_hat = x_hat / new_norm_factor
            x_hat = x_hat.squeeze().cpu().numpy()
            output.append(x_hat)

    return (sample_rate, output[0]), (sample_rate, output[1])


@spaces.GPU
def process_audio(test_file):
    """Gradio callback: separate ``test_file`` and refine both speakers.

    Raises:
        gr.Error: If no audio file was provided.
    """
    if test_file is None:
        raise gr.Error("Please upload an audio file or select a demo first.")
    result, mix = separate(test_file, encoder, masknet, decoder)
    audio1, audio2 = correct(fastgeco_model, result, mix)
    return audio1, audio2


# List of demo audio files
demo_audio_files = [
    ("Demo Audio 1", "demo/item0_mix.wav"),
    ("Demo Audio 2", "demo/item1_mix.wav"),
    ("Demo Audio 3", "demo/item2_mix.wav"),
    ("Demo Audio 4", "demo/item3_mix.wav"),
    ("Demo Audio 5", "demo/item4_mix.wav"),
]
_DEMO_AUDIO_PATHS = dict(demo_audio_files)


def update_audio_input(choice):
    return choice


def _demo_path_for(choice):
    """Map a dropdown label to its demo file path (``None`` if unknown or cleared)."""
    return _DEMO_AUDIO_PATHS.get(choice)


# CSS styling (optional)
css = """
#col-container {
    margin: 0 auto;
    max-width: 1280px;
}
"""

# Gradio Blocks layout
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""
        # Fast-GeCo: Noise-robust Speech Separation with Fast Generative Correction

        Separate the noisy mixture speech with a generative correction method, only support 2 speakers now.

        Learn more about 🟣**Fast-GeCo** on the [Fast-GeCo Repo](https://github.com/WangHelin1997/Fast-GeCo/).
        """)

        with gr.Tab("Speech Separation"):
            # Input: upload an audio file or pick a demo
            with gr.Row():
                gt_file_input = gr.Audio(
                    label="Upload Audio to Separate",
                    type="filepath",
                    value="demo/item0_mix.wav",
                )
                demo_selector = gr.Dropdown(
                    label="Select Demo Audio",
                    choices=[name for name, _ in demo_audio_files],
                    value="Demo Audio 1",
                )
                button = gr.Button("Generate", scale=1)

            # Output components for the separated audio
            with gr.Row():
                result1 = gr.Audio(label="Separated Audio 1", type="numpy")
                result2 = gr.Audio(label="Separated Audio 2", type="numpy")

            # Update the audio input with the selected demo audio file
            demo_selector.change(
                fn=_demo_path_for,
                inputs=demo_selector,
                outputs=gt_file_input,
            )

            # Define the trigger and input-output linking
            button.click(
                fn=process_audio,
                inputs=[gt_file_input],
                outputs=[result1, result2],
            )

# Launch the Gradio demo
demo.launch()