import numpy as np |
import streamlit as st |
import librosa |
import soundfile as sf |
import librosa.display |
from config import CONFIG |
import torch |
from dataset import MaskGenerator |
import onnxruntime, onnx |
import matplotlib.pyplot as plt |
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas |
from pystoi import stoi |
from pesq import pesq |
import pandas as pd |
import torchaudio |
from torchmetrics.audio import ShortTimeObjectiveIntelligibility as STOI |
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality as PESQ |
from PLCMOS.plc_mos import PLCMOSEstimator |
from speechmos import dnsmos |
from speechmos import plcmos |
import speech_recognition as speech_r |
from jiwer import wer |
import time |
# ``st.cache`` is deprecated in current Streamlit releases; ``st.cache_resource``
# is its replacement for unserialisable objects such as inference sessions.
# Use the new decorator when this Streamlit version provides it, otherwise fall
# back to the legacy one with output hashing disabled (the returned session is
# not a value Streamlit should try to hash).
_cache_model = getattr(st, "cache_resource", None) or st.cache(allow_output_mutation=True)


@_cache_model
def load_model(model):
    """Load an ONNX checkpoint and create an ONNX Runtime session for it.

    The result is cached by Streamlit, so each distinct ``model`` value is
    loaded once rather than on every script rerun.

    Args:
        model: File name of the checkpoint located in
            ``lightning_logs/version_0/checkpoints/``.

    Returns:
        A tuple ``(session, onnx_model, input_names, output_names)`` where
        ``session`` is the ``onnxruntime.InferenceSession``, ``onnx_model`` is
        the graph returned by ``onnx.load`` and the two lists hold the names
        of the session's inputs and outputs in the order the session reports
        them.
    """
    path = 'lightning_logs/version_0/checkpoints/' + str(model)
    onnx_model = onnx.load(path)

    options = onnxruntime.SessionOptions()
    options.intra_op_num_threads = 2
    options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    session = onnxruntime.InferenceSession(path, options)

    input_names = [x.name for x in session.get_inputs()]
    output_names = [x.name for x in session.get_outputs()]
    return session, onnx_model, input_names, output_names
def inference(re_im, session, onnx_model, input_names, output_names):
    """Run the model frame by frame and reconstruct the time-domain signal.

    The session is executed once per entry along the first axis of ``re_im``.
    Each run returns the output frame plus three further tensors, which are
    fed back as inputs 1-3 of the next run. The collected frames are then
    turned into a complex spectrogram and inverted with ``torch.istft`` using
    the module-level ``window``, ``stride`` and ``hann`` values.

    Args:
        re_im: Array whose first axis indexes the frames fed to input 0.
        session: Object whose ``run(output_names, feed)`` yields four outputs.
        onnx_model: ONNX graph; the declared dimensions of its inputs size the
            zero-filled initial feed.
        input_names: Names of the session inputs (at least four are used).
        output_names: Names of the outputs requested from the session.

    Returns:
        The reconstructed waveform as a NumPy array.
    """
    # Start every graph input at zeros of its declared shape.
    feed = {}
    for idx, graph_input in enumerate(onnx_model.graph.input):
        dims = [d.dim_value for d in graph_input.type.tensor_type.shape.dim]
        feed[input_names[idx]] = np.zeros(dims, dtype=np.float32)

    frames = []
    for frame in re_im:
        feed[input_names[0]] = frame
        frame_out, next_mag, next_predictor, next_mlp = session.run(output_names, feed)
        # Carry the recurrent tensors over to the next step.
        feed[input_names[1]] = next_mag
        feed[input_names[2]] = next_predictor
        feed[input_names[3]] = next_mlp
        frames.append(frame_out)

    stacked = torch.tensor(np.concatenate(frames, 0))
    spectrum = torch.view_as_complex(stacked.permute(1, 0, 2).contiguous())
    waveform = torch.istft(spectrum, window, stride, window=hann)
    return waveform.numpy()
def visualize(hr, lr, recon, sr):
    """Draw stacked log-frequency spectrograms of three signals.

    Each signal is transformed with a 1024-point STFT (hop 512, Hann window),
    its magnitude is scaled by ``2 / sum(window)``, converted to decibels and
    rendered with ``librosa.display.specshow`` on its own row of a shared-axis
    3x1 figure.

    Args:
        hr: Signal drawn in the top row (titled as the original signal).
        lr: Signal drawn in the middle row (titled as the lossy signal).
        recon: Signal drawn in the bottom row (titled as the enhanced signal).
        sr: Sampling rate forwarded to ``specshow``.

    Returns:
        The matplotlib figure holding the three spectrograms.
    """
    n_fft = 1024
    taper = np.hanning(n_fft)

    def _scaled_magnitude(signal):
        # Magnitude spectrogram normalised by the window sum.
        spectrum = librosa.core.spectrum.stft(signal, n_fft=n_fft, hop_length=512, window=taper)
        return 2 * np.abs(spectrum) / np.sum(taper)

    magnitudes = [_scaled_magnitude(signal) for signal in (hr, lr, recon)]

    fig, axes = plt.subplots(3, 1, sharey=True, sharex=True, figsize=(16, 12))
    captions = ('Оригинальный сигнал', 'Сигнал с потерями', 'Улучшенный сигнал')
    for axis, caption in zip(axes, captions):
        axis.title.set_text(caption)

    # The canvas object itself is not used afterwards; the call is kept
    # because it was made on the figure originally (presumably to bind an
    # Agg canvas to it -- confirm before removing).
    FigureCanvas(fig)

    for axis, magnitude in zip(axes, magnitudes):
        librosa.display.specshow(librosa.amplitude_to_db(magnitude), ax=axis,
                                 y_axis='log', x_axis='time', sr=sr)

    # Labels are applied after plotting so they replace whatever specshow set.
    for axis in axes:
        axis.set_xlabel('Время, с')
        axis.set_ylabel('Частота, Гц')
    return fig
# ---------------------------------------------------------------------------
# Streamlit page: load audio, simulate packet loss, run the concealment model
# and report PESQ / STOI / PLCMOS / WER for clean, lossy and enhanced audio.
# ---------------------------------------------------------------------------

METRIC_SR = 16000  # sampling rate at which PESQ and PLCMOS are evaluated below
WAV_PATHS = ('target.wav', 'lossy.wav', 'enhanced.wav')  # clean, lossy, enhanced


def _resample_for_metrics(data, source_sr):
    """Return ``data`` resampled from ``source_sr`` to ``METRIC_SR`` (unchanged if equal)."""
    if source_sr == METRIC_SR:
        return data
    return librosa.resample(data, orig_sr=source_sr, target_sr=METRIC_SR)


def _transcribe(recognizer, path, language):
    """Transcribe the audio file at ``path``; return '' when no speech is recognised."""
    with speech_r.AudioFile(path) as source:
        audio = recognizer.record(source)
    try:
        return recognizer.recognize_google(audio, language=str(language))
    except speech_r.UnknownValueError:
        return ''


# These three names are also read as globals by ``inference``.
packet_size = CONFIG.DATA.EVAL.packet_size
window = CONFIG.DATA.window_size
stride = CONFIG.DATA.stride

title = 'Сокрытие потерь пакетов'
st.set_page_config(page_title=title, page_icon=":sound:")
st.title(title)

st.subheader('1. Загрузка аудио')
uploaded_file = st.file_uploader("Загрузите аудио формата (.wav) 48 КГц")
is_file_uploaded = uploaded_file is not None
if not is_file_uploaded:
    # Fall back to the bundled sample when nothing has been uploaded.
    uploaded_file = 'sample.wav'

target, sr = librosa.load(uploaded_file, sr=48000)
# Keep a whole number of packets so the signal reshapes to (-1, packet_size).
target = target[:packet_size * (len(target) // packet_size)]
if len(target) == 0:
    # Shorter than one packet: nothing to mask or enhance, stop this run early.
    st.error('Аудио слишком короткое: в нём нет ни одного полного пакета.')
    st.stop()

st.text('Ваше аудио')
st.audio(uploaded_file)

model_ver = st.selectbox(
    'Оригинал или Pruned ?',
    ('frn.onnx', 'frn_modified.onnx'))
st.write('Вы выбрали:', model_ver)

# NOTE(review): 'en-EN' is not a standard locale tag (compare 'en-US'); it is
# kept unchanged here -- confirm what the speech recognizer accepts.
lang = st.selectbox(
    'Выберите язык',
    ('ru-RU', 'en-EN'))
st.write('Вы выбрали:', lang)

st.subheader('2. Выберите желаемый процент потерь')
loss_slider_value = st.slider(
    "Ожидаемый процент потерь для генератора потерь цепи Маркова", 0, 100, step=1)
loss_percent = float(loss_slider_value) / 100

mask_gen = MaskGenerator(is_train=False, probs=[(1 - loss_percent, loss_percent)])
lossy_input = target.copy().reshape(-1, packet_size)
# One mask value per packet, broadcast over the samples of that packet.
mask = mask_gen.gen_mask(len(lossy_input), seed=0)[:, np.newaxis]
lossy_input *= mask
lossy_input = lossy_input.reshape(-1)

hann = torch.sqrt(torch.hann_window(window))
lossy_input_tensor = torch.tensor(lossy_input)
# ``return_complex=False`` is deprecated in torch.stft; taking the real view of
# the complex result gives the same (..., 2) real/imaginary layout.
lossy_spec = torch.view_as_real(
    torch.stft(lossy_input_tensor, window, stride, window=hann, return_complex=True))
re_im = lossy_spec.permute(1, 0, 2).unsqueeze(1).numpy().astype(np.float32)

session, onnx_model, input_names, output_names = load_model(model_ver)

if st.button('Сгенерировать потери'):
    start_time = time.time()
    with st.spinner('Ожидайте...'):
        output = inference(re_im, session, onnx_model, input_names, output_names)

    st.subheader('3. Визуализация')
    fig = visualize(target, lossy_input, output, sr)
    st.pyplot(fig)
    # Figures created through pyplot stay registered until closed; release
    # this one so repeated reruns do not accumulate figures.
    plt.close(fig)
    st.success('Сделано!')
    st.text(str(time.time() - start_time))

    sf.write('target.wav', target, sr)
    sf.write('lossy.wav', lossy_input, sr)
    sf.write('enhanced.wav', output, sr)
    st.text('Оригинальное аудио')
    st.audio('target.wav')
    st.text('Аудио с потерями')
    st.audio('lossy.wav')
    st.text('Улучшенное аудио')
    st.audio('enhanced.wav')

    # --- STOI and PESQ on length-aligned copies of the written files -------
    data_clean, samplerate = sf.read('target.wav')
    data_lossy, samplerate = sf.read('lossy.wav')
    data_enhanced, samplerate = sf.read('enhanced.wav')
    min_len = min(data_clean.shape[0], data_lossy.shape[0], data_enhanced.shape[0])
    data_clean = data_clean[:min_len]
    data_lossy = data_lossy[:min_len]
    data_enhanced = data_enhanced[:min_len]

    stoi_mass = [
        round(stoi(data_clean, degraded, samplerate, extended=False), 5)
        for degraded in (data_clean, data_lossy, data_enhanced)
    ]

    # Resample from the rate actually read from the files rather than an
    # assumed 48 kHz.
    clean_16k = _resample_for_metrics(data_clean, samplerate)
    lossy_16k = _resample_for_metrics(data_lossy, samplerate)
    enhanced_16k = _resample_for_metrics(data_enhanced, samplerate)
    psq_mas = [
        pesq(fs=METRIC_SR, ref=clean_16k, deg=degraded, mode='wb')
        for degraded in (clean_16k, lossy_16k, enhanced_16k)
    ]

    # --- PLCMOS v1: as before, fed the full-length (untruncated) files -----
    plc_inputs = []
    for wav_path in WAV_PATHS:
        data, fs = sf.read(wav_path)
        plc_inputs.append(_resample_for_metrics(data, fs))
    plc_clean = plc_inputs[0]
    PLC_example = PLCMOSEstimator()
    PLC_massv1 = [
        PLC_example.run(audio_degraded=degraded, audio_clean=plc_clean)[0]
        for degraded in plc_inputs
    ]

    # --- PLCMOS v2 computed directly from the files -------------------------
    PLC_massv2 = [plcmos.run(wav_path, sr=METRIC_SR)['plcmos'] for wav_path in WAV_PATHS]

    # --- WER of the lossy / enhanced transcripts against the clean one ------
    r = speech_r.Recognizer()
    try:
        orig, lossy, enhanced = [_transcribe(r, wav_path, lang) for wav_path in WAV_PATHS]
    except speech_r.RequestError as err:
        # Recognition service unreachable: keep the other metrics, skip WER.
        st.warning(f'Сервис распознавания речи недоступен, WER не рассчитан: {err}')
        WER_mass = [np.nan, np.nan, np.nan]
    else:
        if orig.strip():
            WER_mass = [wer(orig, hypothesis) * 100 for hypothesis in (orig, lossy, enhanced)]
        else:
            # WER is undefined without a reference transcript.
            st.warning('Не удалось распознать речь в оригинальном аудио, WER не рассчитан.')
            WER_mass = [np.nan, np.nan, np.nan]

    df_1 = pd.DataFrame({
        'Audio': ['Clean', 'Lossy', 'Enhanced'],
        'PESQ': psq_mas,
        'STOI': stoi_mass,
        'PLCMOSv1': PLC_massv1,
        'PLCMOSv2': PLC_massv2,
        'WER': WER_mass,
    })
    st.dataframe(df_1)