File size: 2,929 Bytes
33e3a91
 
 
 
4f821f0
33e3a91
 
 
 
 
 
 
 
 
 
 
 
 
 
0ead0b4
33e3a91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96c45a5
33e3a91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96c45a5
 
 
 
 
 
 
 
 
 
 
 
 
 
4f821f0
96c45a5
 
33e3a91
 
 
28d63d4
33e3a91
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import json
from tqdm import tqdm
from copy import deepcopy

import soundfile as sf
import numpy as np
import gradio as gr
import torch

import random
# Seed every RNG source used by this script (Python, PyTorch, NumPy) with 0
# so repeated runs are reproducible.
for _seed_rng in (random.seed, torch.manual_seed, np.random.seed):
    _seed_rng(0)

from util import print_size, sampling
from network import CleanUNet
import torchaudio

def load_simple(filename):
    """Load an audio file and return only its waveform tensor.

    The path is echoed to stdout first; the sample rate that
    ``torchaudio.load`` also returns is discarded.
    """
    print(filename)
    waveform = torchaudio.load(filename)[0]
    return waveform

CONFIG = "configs/DNS-large-full.json"
CHECKPOINT = "./exp/DNS-large-high/checkpoint/pretrained.pkl"

# Parse the experiment config once at import time and expose its sections as
# module-level globals (a `global` statement is a no-op at module scope, so
# plain assignment is all that is needed). JSON text is UTF-8 by spec.
with open(CONFIG, encoding="utf-8") as f:
    config = json.load(f)

gen_config = config["gen_config"]
network_config = config["network_config"]    # keyword arguments for the network constructor
train_config = config["train_config"]        # training settings
trainset_config = config["trainset_config"]  # training-set settings

# Networks keyed by checkpoint path, so the model is built and its weights are
# loaded once per process instead of on every request.
_MODEL_CACHE = {}


def _get_model(ckpt_path):
    """Return a CleanUNet in eval mode with weights from ``ckpt_path`` (cached per path)."""
    net = _MODEL_CACHE.get(ckpt_path)
    if net is None:
        net = CleanUNet(**network_config)
        print_size(net)
        # torch.load unpickles the file, so ckpt_path must be a trusted checkpoint.
        checkpoint = torch.load(ckpt_path, map_location='cpu')
        net.load_state_dict(checkpoint['model_state_dict'])
        net.eval()
        _MODEL_CACHE[ckpt_path] = net
    return net


def _load_mono(filename, target_sr=None):
    """Load ``filename`` as a (1, num_samples) mono tensor plus its sample rate.

    Channels are averaged together. When ``target_sr`` is truthy and differs
    from the file's rate, the audio is resampled to ``target_sr`` and that
    rate is returned instead.
    """
    print(filename)
    audio, sample_rate = torchaudio.load(filename)  # (channels, frames)
    # Downmix so every chunk is a single time series: the chunk outputs are
    # flattened and joined end to end in denoise(), which is only correct
    # for one channel.
    audio = audio.mean(dim=0, keepdim=True)
    if target_sr and sample_rate != target_sr:
        audio = torchaudio.functional.resample(audio, sample_rate, target_sr)
        sample_rate = target_sr
    return audio, sample_rate


def denoise(filename, ckpt_path = CHECKPOINT):
    """
    Denoise an audio file with CleanUNet and save the result as a WAV file.

    Parameters:
    filename (str):         path of the noisy input audio file
    ckpt_path (str):        checkpoint file whose 'model_state_dict' entry is loaded

    Returns:
    str:                    path of the written file, ``filename + "_denoised.wav"``

    Raises:
    ValueError:             if the input file contains no samples
    """
    net = _get_model(ckpt_path)

    # NOTE(review): assumes trainset_config["sample_rate"], when present, is the
    # rate the model was trained at; if absent, the file's own rate is kept.
    # The output is written at the processing rate: a fixed rate would change
    # the playback speed of any input recorded at a different rate.
    noisy_audio, sample_rate = _load_mono(filename, trainset_config.get("sample_rate"))
    num_samples = noisy_audio.shape[-1]
    if num_samples == 0:
        raise ValueError(f"{filename!r} contains no audio samples")

    # Run inference on near-equal chunks of at most max_chunk samples to
    # bound memory use.
    max_chunk = 1000000
    num_chunks = -(-num_samples // max_chunk)  # ceiling division
    denoised = []
    with torch.no_grad():
        for chunk in tqdm(torch.chunk(noisy_audio, num_chunks, dim=1)):
            out = sampling(net, chunk)
            # reshape(-1) rather than squeeze(): a 1-sample chunk stays 1-D
            # and can still be concatenated.
            denoised.append(out.cpu().numpy().reshape(-1))

    new_file_name = filename + "_denoised.wav"
    all_audio = np.concatenate(denoised)
    print("saved to:", new_file_name)
    sf.write(new_file_name, all_audio, sample_rate)
    return new_file_name


# Gradio UI: one audio clip in (passed to denoise() as a file path) and one
# file path out.
# NOTE(review): gr.inputs / gr.outputs and enable_queue belong to the
# pre-4.0 Gradio API — confirm the pinned Gradio version before upgrading.
audio = gr.inputs.Audio(
    label="Audio to denoise",
    type='filepath',
)
inputs = [audio]
outputs = gr.outputs.Audio(
    label="Denoised audio",
    type='filepath',
)

title = "Speech Denoising in the Waveform Domain with Self-Attention from Nvidia"

demo = gr.Interface(denoise, inputs, outputs, title=title, enable_queue=True)
demo.launch()