import os
import json
import random
from copy import deepcopy
from functools import lru_cache

from tqdm import tqdm
import numpy as np
import gradio as gr
import torch
import torchaudio
from scipy.io.wavfile import write as wavwrite

from util import print_size, sampling
from network import CleanUNet

# Seed every RNG the script touches so repeated runs are reproducible.
random.seed(0)
torch.manual_seed(0)
np.random.seed(0)


def load_simple(filename, *, return_sample_rate=False):
    """Load an audio file with torchaudio.

    Parameters:
        filename (str): path to an audio file torchaudio can decode
        return_sample_rate (bool): if True, also return the file's sample
            rate; defaults to False so callers expecting only the waveform
            keep working

    Returns:
        the waveform tensor (channels, samples), or a
        ``(waveform, sample_rate)`` tuple when ``return_sample_rate`` is True
    """
    audio, sample_rate = torchaudio.load(filename)
    if return_sample_rate:
        return audio, sample_rate
    return audio


CONFIG = "configs/DNS-large-full.json"
CHECKPOINT = "./exp/DNS-large-high/checkpoint/pretrained.pkl"

# Maximum number of samples fed to the network per forward pass; long
# recordings are split into pieces of at most this length to bound memory.
CHUNK_SAMPLES = 1000000

# Parse configs once at import time; module-level globals are simplest for a
# single-purpose demo script.
with open(CONFIG) as f:
    config = json.load(f)
gen_config = config["gen_config"]
network_config = config["network_config"]  # keyword arguments for CleanUNet
train_config = config["train_config"]  # training configuration
trainset_config = config["trainset_config"]  # trainset configuration


@lru_cache(maxsize=4)
def _load_model(ckpt_path):
    """Build CleanUNet and load its weights from ``ckpt_path`` (cached per path)."""
    net = CleanUNet(**network_config)
    print_size(net)
    # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    net.load_state_dict(checkpoint["model_state_dict"])
    net.eval()
    return net


def denoise(files, ckpt_path=CHECKPOINT):
    """Denoise one or more audio files with a pretrained CleanUNet.

    Each input is written next to itself as ``<stem>_denoised.wav``.

    Parameters:
        files (str or iterable of str): path of the audio file to denoise (as
            passed by the Gradio ``filepath`` input), or an iterable of paths
        ckpt_path (str): checkpoint file holding a ``model_state_dict`` entry;
            defaults to CHECKPOINT

    Returns:
        str or list of str: the output path when ``files`` is a single path,
        otherwise the list of output paths in input order
    """
    # Gradio hands over one path string; iterating it directly would walk
    # its characters, so wrap single paths in a list.
    single = isinstance(files, (str, os.PathLike))
    paths = [files] if single else list(files)

    net = _load_model(ckpt_path)
    # Presumably the rate the model was trained at; if the config has no such
    # key the audio is processed at its native rate. TODO confirm key name.
    model_rate = trainset_config.get("sample_rate")

    saved = []
    for file_path in tqdm(paths):
        noisy_audio, sample_rate = load_simple(file_path, return_sample_rate=True)
        if model_rate and sample_rate != model_rate:
            noisy_audio = torchaudio.functional.resample(
                noisy_audio, sample_rate, model_rate
            )
            sample_rate = model_rate

        pieces = []
        with torch.no_grad():
            # Split along the time axis (dim 1 of the (channels, samples) tensor).
            for chunk in tqdm(torch.split(noisy_audio, CHUNK_SAMPLES, dim=1)):
                generated = sampling(net, chunk).cpu().numpy()
                # Collapse leading dims to (rows, time); assumes the network keeps
                # time as the last axis, as in its input. Avoids squeeze() turning
                # a length-1 tail chunk into a 0-d array.
                pieces.append(generated.reshape(-1, generated.shape[-1]))
        all_audio = np.concatenate(pieces, axis=-1)
        # Mono -> 1-D; multi-channel -> scipy's (samples, channels) layout.
        all_audio = all_audio[0] if all_audio.shape[0] == 1 else all_audio.T

        stem = os.path.splitext(os.path.basename(file_path))[0]
        save_file = os.path.join(os.path.dirname(file_path), stem + "_denoised.wav")
        wavwrite(save_file, sample_rate, all_audio)
        print("saved to:", save_file)
        saved.append(save_file)

    return saved[0] if single else saved


audio = gr.inputs.Audio(label="Audio to denoise", type="filepath")
# Only the audio is user-supplied; the checkpoint comes from denoise's default.
inputs = [audio]
outputs = gr.outputs.Audio(label="Denoised audio", type="filepath")
title = "Speech Denoising in the Waveform Domain with Self-Attention from Nvidia"
gr.Interface(denoise, inputs, outputs, title=title, enable_queue=True).launch()