# Captured from a Hugging Face Space whose status was shown as "Runtime error".
import functools
import json
import random
from copy import deepcopy

import gradio as gr
import numpy as np
import soundfile as sf
import torch
from tqdm import tqdm
# Seed the global RNGs of Python's `random`, PyTorch and NumPy with 0,
# in that order.
for _seed_fn in (random.seed, torch.manual_seed, np.random.seed):
    _seed_fn(0)
del _seed_fn
from util import print_size, sampling | |
from network import CleanUNet | |
import torchaudio | |
import torchaudio.transforms as T | |
# Sample rate (Hz) that input audio is resampled to and that the denoised
# output is written at.
# NOTE(review): confirm this matches the rate the checkpoint was trained at
# (see the trainset section of the JSON config); a mismatch would feed the
# model audio at the wrong rate.
SAMPLE_RATE = 22050


def load_simple(filename, target_sr=SAMPLE_RATE):
    """Load an audio file and resample it to ``target_sr``.

    Parameters:
        filename (str): path of an audio file readable by ``torchaudio.load``.
        target_sr (int): sample rate of the returned waveform; defaults to
            ``SAMPLE_RATE``.

    Returns:
        torch.Tensor: the waveform from ``torchaudio.load`` (channels-first
        by default) resampled from the file's native rate to ``target_sr``.
    """
    wav, sr = torchaudio.load(filename)
    resampler = T.Resample(sr, target_sr, dtype=wav.dtype)
    # Resample the waveform that was just loaded (not any other global).
    return resampler(wav)
CONFIG = "configs/DNS-large-full.json"
CHECKPOINT = "./exp/DNS-large-high/checkpoint/pretrained.pkl"

# Parse the JSON config once at import time and expose its sections as
# module-level globals that the rest of this module reads directly.
# (A `global` statement at module level is a no-op, so none is needed.)
with open(CONFIG, encoding="utf-8") as f:
    config = json.load(f)
gen_config = config["gen_config"]
network_config = config["network_config"]  # keyword arguments for CleanUNet(...)
train_config = config["train_config"]  # training settings
trainset_config = config["trainset_config"]  # training-set settings
@functools.lru_cache(maxsize=4)
def _load_model(ckpt_path):
    """Build a CleanUNet from ``network_config`` and load ``ckpt_path`` into it.

    Cached per checkpoint path so repeated requests reuse the same network
    instead of rebuilding it and re-reading the checkpoint from disk.
    """
    net = CleanUNet(**network_config)
    print_size(net)
    # Weights are mapped to the CPU; the network is never moved elsewhere.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    net.load_state_dict(checkpoint["model_state_dict"])
    net.eval()
    return net


def denoise(filename, ckpt_path=CHECKPOINT, out="out.wav"):
    """Denoise an audio file with a pretrained CleanUNet and save the result.

    Parameters:
        filename (str): path of the noisy input audio file.
        ckpt_path (str): path of a checkpoint whose ``"model_state_dict"``
            entry holds the network weights.
        out (str): path the denoised audio is written to, at ``SAMPLE_RATE``.
            NOTE(review): with the default value, every call writes the same
            file, so overlapping requests would overwrite each other.

    Returns:
        str: ``out``, the path of the written file.
    """
    net = _load_model(ckpt_path)
    noisy_audio = load_simple(filename)
    # Model and input both live on the CPU, so a CUDA autocast context would
    # have no effect here; only gradient tracking needs to be disabled.
    with torch.no_grad():
        generated_audio = sampling(net, noisy_audio)
    generated_audio = generated_audio[0].squeeze().cpu().numpy()
    sf.write(out, np.ravel(generated_audio), SAMPLE_RATE)
    return out
# Gradio UI: one audio file in, one denoised audio file out.
# `gr.inputs` / `gr.outputs` and `Interface(enable_queue=...)` are deprecated in
# Gradio 3 and removed in Gradio 4; `gr.Audio` and `.queue()` work in both.
audio = gr.Audio(label="Audio to denoise", type="filepath")
inputs = [audio]
outputs = gr.Audio(label="Denoised audio", type="filepath")
title = "Speech Denoising in the Waveform Domain with Self-Attention from Nvidia"
demo = gr.Interface(denoise, inputs, outputs, title=title)
demo.queue().launch()