import json
import random
from copy import deepcopy

import numpy as np
import soundfile as sf
import torch
import torchaudio
import torchaudio.transforms as T
import gradio as gr

from util import print_size, sampling
from network import CleanUNet

# fix seeds so the enhancement output is reproducible across runs
random.seed(0)
torch.manual_seed(0)
np.random.seed(0)

SAMPLE_RATE = 22050  # sample rate every input is resampled to before inference
def load_simple(filename):
    """Load an audio file and resample it to SAMPLE_RATE."""
    wav, sr = torchaudio.load(filename)
    resampler = T.Resample(sr, SAMPLE_RATE, dtype=wav.dtype)
    return resampler(wav)
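# Example usage (hypothetical path; any format torchaudio can decode works):
#   wav = load_simple("noisy_sample.wav")  # tensor of shape (channels, samples) at 22050 Hz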
CONFIG = "configs/DNS-large-full.json"
CHECKPOINT = "./exp/DNS-large-full/checkpoint/pretrained.pkl"
# Parse configs; module-level globals keep this small demo simple.
with open(CONFIG) as f:
    config = json.load(f)
gen_config = config["gen_config"]
network_config = config["network_config"]    # CleanUNet architecture parameters
train_config = config["train_config"]        # training/experiment settings
trainset_config = config["trainset_config"]  # training-set configuration
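# Expected top-level layout of DNS-large-full.json, inferred from the keys read
# above (inner fields are illustrative, not exhaustive):
#   {
#     "gen_config": {...},
#     "network_config": {...},                 # kwargs passed to CleanUNet(...)
#     "train_config": {"exp_path": ..., ...},
#     "trainset_config": {"crop_length_sec": ..., ...}
#   }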
def denoise(filename, ckpt_path=CHECKPOINT, out="out.wav"):
    """
    Denoise a single audio file with a pretrained CleanUNet.

    Parameters:
    filename (str):  path to the noisy input audio
    ckpt_path (str): pretrained checkpoint to load
    out (str):       path the enhanced (denoised) audio is written to

    Returns:
    str: the output path, so Gradio can serve the file
    """
    # local experiment path (informational)
    exp_path = train_config["exp_path"]
    print('exp_path:', exp_path)

    # data-loader settings (not used further in this file-based demo);
    # crop_length_sec = 0 disables random cropping
    loader_config = deepcopy(trainset_config)
    loader_config["crop_length_sec"] = 0
    # instantiate the model and report its parameter count
    net = CleanUNet(**network_config)
    print_size(net)

    # load pretrained weights; map to CPU so the demo also runs without a GPU
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    net.load_state_dict(checkpoint['model_state_dict'])
    net.eval()
    # inference: enable CUDA autocast only when a GPU is actually available
    # (with the CPU-mapped model above it is a no-op otherwise)
    noisy_audio = load_simple(filename)
    with torch.no_grad():
        with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
            generated_audio = sampling(net, noisy_audio)
    generated_audio = generated_audio[0].squeeze().cpu().numpy()
    sf.write(out, np.ravel(generated_audio), SAMPLE_RATE)
    return out
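# Quick sanity check without the UI (hypothetical file names):
#   denoise("noisy_sample.wav", out="denoised_sample.wav")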
# Build the Gradio UI. The gr.Audio(sources=...) signature assumes Gradio 4.x;
# older releases spelled this gr.inputs.Audio(source="microphone", ...).
demo = gr.Interface(
    fn=denoise,
    inputs=gr.Audio(sources=["microphone"], type="filepath", label="Audio to denoise"),
    outputs=gr.Audio(type="filepath", label="Denoised audio"),
    title="My Demo: Speech enhancement",
    allow_flagging="never",
)
demo.launch()
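# To expose the demo beyond the local machine, Gradio's built-in tunnel can be
# used instead of the plain launch above:
#   demo.launch(share=True)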