# audio_denoiser / app.py
# Vageesh1's picture
# Update app.py
# 13e479e
# raw
# history blame
# 2.86 kB
import torch
import torchaudio
import torchaudio.functional as F
from torchaudio.utils import download_asset
from pesq import pesq
from pystoi import stoi
import mir_eval
from pydub import AudioSegment
import matplotlib.pyplot as plt
import streamlit as st
from helper import plot_spectrogram,plot_mask,si_snr,generate_mixture,evaluate,get_irms
# ---------------------------------------------------------------------------
# Mixing / beamforming settings
# ---------------------------------------------------------------------------
# SNR value handed to generate_mixture() when noise is added to the upload.
target_snr = 3
# Microphone index passed to the MVDR transform as ``reference_channel``.
REFERENCE_CHANNEL = 0

# ---------------------------------------------------------------------------
# STFT settings and the forward / inverse transforms built from them
# ---------------------------------------------------------------------------
N_FFT = 1024
N_HOP = 256

# power=None asks torchaudio for the complex-valued spectrogram, which is
# what the inverse transform below consumes.
stft = torchaudio.transforms.Spectrogram(n_fft=N_FFT, hop_length=N_HOP, power=None)
istft = torchaudio.transforms.InverseSpectrogram(n_fft=N_FFT, hop_length=N_HOP)

# Power-spectral-density estimator and the Souden MVDR beamformer.
psd_transform = torchaudio.transforms.PSD()
mvdr_transform = torchaudio.transforms.SoudenMVDR()

# ---------------------------------------------------------------------------
# Noise clip: fetched once at import time from the torchaudio tutorial assets,
# loaded, promoted to float64 and transformed with the same STFT as above.
# ---------------------------------------------------------------------------
SAMPLE_NOISE = download_asset("tutorial-assets/mvdr/noise.wav")
waveform_noise, sr2 = torchaudio.load(SAMPLE_NOISE)
waveform_noise = waveform_noise.double()
stft_noise = stft(waveform_noise)
def _match_noise(noise, noise_sr, target_sr, num_frames):
    """Return a copy of ``noise`` at ``target_sr`` that is ``num_frames`` long.

    Args:
        noise: Noise waveform tensor whose last dimension is time.
        noise_sr: Sample rate the noise was loaded at.
        target_sr: Sample rate of the audio the noise will be mixed into.
        num_frames: Number of samples required along the time axis.

    Returns:
        A new tensor (never a view of ``noise``): resampled when the two
        rates differ, tiled end-to-end when too short, then cropped to
        ``num_frames`` samples.
    """
    if noise_sr != target_sr:
        noise = F.resample(noise, noise_sr, target_sr)
    available = noise.shape[-1]
    if 0 < available < num_frames:
        repeats = -(-num_frames // available)  # ceiling division
        noise = torch.cat([noise] * repeats, dim=-1)
    # clone() so that nothing done to the result can touch the module-level
    # noise tensor it was derived from.
    return noise[..., :num_frames].clone()


def ui():
    """Render the Streamlit page for the speech enhancer.

    Once a WAV file is uploaded the page:

    1. loads it and plays it back,
    2. mixes it with the bundled noise clip at ``target_snr`` and shows the
       mixture's spectrogram and audio,
    3. builds speech/noise masks with ``get_irms``, estimates the two PSD
       matrices, applies the Souden MVDR beamformer and shows the enhanced
       spectrogram and audio.

    The noise clip is first resampled to the upload's sample rate and
    tiled/cropped to the upload's length; without that, any upload whose rate
    or duration differs from the bundled noise asset reaches the mixing and
    masking helpers with mismatched time axes.

    Side effects: writes ``./waveform_mix.wav`` and ``./waveform_souden.wav``
    in the current working directory.
    """
    st.title("Speech Enhancer")
    st.markdown("Made by Vageesh")

    # Audio uploader; nothing else is rendered until a file is provided.
    audio_file = st.file_uploader("Upload an audio file in wav format", type=["wav"])
    if audio_file is None:
        return

    waveform_clean, sr = torchaudio.load(audio_file)
    waveform_clean = waveform_clean.to(torch.double)
    stft_clean = stft(waveform_clean)
    st.text("Your uploaded audio")
    st.audio(audio_file)

    # Bring the noise onto the same sample rate and length as the upload.
    # NOTE(review): channel counts are still assumed to be compatible for the
    # helpers (equal, or mono upload) - confirm against helper.py.
    noise = _match_noise(waveform_noise, sr2, sr, waveform_clean.shape[-1])

    # Mixture of the uploaded audio and the aligned noise.
    waveform_mix = generate_mixture(waveform_clean, noise, target_snr)
    waveform_mix = waveform_mix.to(torch.double)
    stft_mix = stft(waveform_mix)

    # Spectrogram and playback of the noisy mixture.
    spec_img = plot_spectrogram(stft_mix)
    st.image(spec_img)
    # Saved as float32: float64 is not accepted by every torchaudio save
    # backend, and 32-bit float WAV is the more widely playable encoding.
    torchaudio.save("./waveform_mix.wav", waveform_mix.to(torch.float32), sr)
    st.audio("./waveform_mix.wav")

    # Masks are computed from STFTs that share the upload's time axis.
    # NOTE(review): the noise STFT is taken after generate_mixture on purpose -
    # if that helper rescales the noise in place (TODO confirm in helper.py),
    # the mask then reflects the noise level actually present in the mix.
    stft_noise_aligned = stft(noise)
    irm_speech, irm_noise = get_irms(stft_clean, stft_noise_aligned)

    # PSD matrices for speech and noise, then Souden MVDR beamforming.
    psd_speech = psd_transform(stft_mix, irm_speech)
    psd_noise = psd_transform(stft_mix, irm_noise)
    stft_souden = mvdr_transform(
        stft_mix, psd_speech, psd_noise, reference_channel=REFERENCE_CHANNEL
    )
    waveform_souden = istft(stft_souden, length=waveform_mix.shape[-1])

    # Spectrogram and playback of the enhanced signal.
    spec_clean_img = plot_spectrogram(stft_souden)
    # torchaudio.save expects a 2D tensor, hence the explicit channel axis.
    waveform_souden = waveform_souden.reshape(1, -1)
    st.image(spec_clean_img)
    torchaudio.save("./waveform_souden.wav", waveform_souden.to(torch.float32), sr)
    st.audio("./waveform_souden.wav")
# Build the page only when this file is executed as a script.
if __name__ == "__main__":
    ui()