import gc
import os
import random
import tempfile

import gradio as gr
import librosa
import numpy as np
import pyloudnorm as pyln
import soundfile as sf
import spaces
import torch
from audiosr import build_model, super_resolution
from scipy import signal
from scipy.signal.windows import hann
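

# AudioUpscaler wraps the AudioSR pipeline used by this app: the input is split
# into overlapping chunks, each chunk is upsampled to 48 kHz by the diffusion
# model, loudness-matched to the source, cross-faded back together, and
# optionally blended with the original low band (multiband ensemble).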
class AudioUpscaler:
    """
    Upscales audio using the AudioSR model.
    """

    def __init__(self, model_name="basic", device="auto"):
        """
        Initializes the AudioUpscaler.

        Args:
            model_name (str, optional): Name of the AudioSR model to use. Defaults to "basic".
            device (str, optional): Device to use for inference. Defaults to "auto".
        """
        self.model_name = model_name
        self.device = device
        self.sr = 48000
        self.audiosr = None

    def setup(self):
        """
        Loads the AudioSR model.
        """
        print("Loading Model...")
        self.audiosr = build_model(model_name=self.model_name, device=self.device)
        print("Model loaded!")

    def _match_array_shapes(self, array_1: np.ndarray, array_2: np.ndarray):
        """
        Matches the shape of the first array to the second by truncating or zero-padding it.

        Args:
            array_1 (np.ndarray): First array.
            array_2 (np.ndarray): Second array.

        Returns:
            np.ndarray: The first array, reshaped to match the second array.
        """
        if len(array_1.shape) == 1 and len(array_2.shape) == 1:
            if array_1.shape[0] > array_2.shape[0]:
                array_1 = array_1[: array_2.shape[0]]
            elif array_1.shape[0] < array_2.shape[0]:
                # Pad at the end so the two signals stay time-aligned,
                # mirroring the stereo branch below.
                padding = array_2.shape[0] - array_1.shape[0]
                array_1 = np.pad(array_1, (0, padding), "constant", constant_values=0)
        else:
            if array_1.shape[1] > array_2.shape[1]:
                array_1 = array_1[:, : array_2.shape[1]]
            elif array_1.shape[1] < array_2.shape[1]:
                padding = array_2.shape[1] - array_1.shape[1]
                array_1 = np.pad(
                    array_1, ((0, 0), (0, padding)), "constant", constant_values=0
                )
        return array_1

    def _lr_filter(self, audio, cutoff, filter_type, order=12, sr=48000):
        """
        Applies a zero-phase low-pass or high-pass Butterworth filter to the audio.

        Args:
            audio (np.ndarray): Audio data.
            cutoff (int): Cutoff frequency in Hz.
            filter_type (str): Filter type ("lowpass" or "highpass").
            order (int, optional): Filter order. Defaults to 12.
            sr (int, optional): Sample rate. Defaults to 48000.

        Returns:
            np.ndarray: Filtered audio data.
        """
        audio = audio.T
        nyquist = 0.5 * sr
        normal_cutoff = cutoff / nyquist
        b, a = signal.butter(order // 2, normal_cutoff, btype=filter_type, analog=False)
        sos = signal.tf2sos(b, a)
        # sosfiltfilt runs the filter forward and backward: the effective order
        # is doubled (hence order // 2 above) and the phase response stays flat.
        filtered_audio = signal.sosfiltfilt(sos, audio)
        return filtered_audio.T

    def _process_audio(
        self,
        input_file,
        chunk_size=5.12,
        overlap=0.1,
        seed=None,
        guidance_scale=3.5,
        ddim_steps=50,
        multiband_ensemble=True,
        input_cutoff=14000,
    ):
        """
        Processes the audio in overlapping chunks and performs upsampling.

        Args:
            input_file (str): Path to the input audio file.
            chunk_size (float, optional): Chunk size in seconds. Defaults to 5.12.
            overlap (float, optional): Overlap between chunks, as a fraction of the chunk size. Defaults to 0.1.
            seed (int, optional): Random seed. Defaults to None.
            guidance_scale (float, optional): Scale for classifier-free guidance. Defaults to 3.5.
            ddim_steps (int, optional): Number of inference steps. Defaults to 50.
            multiband_ensemble (bool, optional): Whether to use the multiband ensemble. Defaults to True.
            input_cutoff (int, optional): Input cutoff frequency for the multiband ensemble. Defaults to 14000.

        Returns:
            np.ndarray: Upsampled audio data at 48 kHz.
        """
        audio, sr = librosa.load(input_file, sr=input_cutoff * 2, mono=False)
        audio = audio.T
        sr = input_cutoff * 2

        is_stereo = len(audio.shape) == 2
        if is_stereo:
            audio_ch1, audio_ch2 = audio[:, 0], audio[:, 1]
        else:
            audio_ch1 = audio

        chunk_samples = int(chunk_size * sr)
        overlap_samples = int(overlap * chunk_samples)

        output_chunk_samples = int(chunk_size * self.sr)
        output_overlap_samples = int(overlap * output_chunk_samples)
        enable_overlap = overlap > 0
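
        # Chunking: each chunk is chunk_size seconds of input, zero-padded to a
        # full chunk at the end of the file. With overlap enabled, consecutive
        # chunks advance by (chunk_samples - overlap_samples); the matching hop
        # for the 48 kHz output uses output_overlap_samples.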
        def process_chunks(audio):
            chunks = []
            original_lengths = []
            start = 0
            while start < len(audio):
                print(f"{start} / {len(audio)}")
                end = min(start + chunk_samples, len(audio))
                chunk = audio[start:end]
                if len(chunk) < chunk_samples:
                    original_lengths.append(len(chunk))
                    pad = np.zeros(chunk_samples - len(chunk))
                    chunk = np.concatenate([chunk, pad])
                else:
                    original_lengths.append(chunk_samples)
                chunks.append(chunk)
                start += (
                    chunk_samples - overlap_samples
                    if enable_overlap
                    else chunk_samples
                )
            return chunks, original_lengths

        chunks_ch1, original_lengths_ch1 = process_chunks(audio_ch1)
        if is_stereo:
            chunks_ch2, original_lengths_ch2 = process_chunks(audio_ch2)

        sample_rate_ratio = self.sr / sr
        total_length = len(chunks_ch1) * output_chunk_samples - (len(chunks_ch1) - 1) * (
            output_overlap_samples if enable_overlap else 0
        )
        reconstructed_ch1 = np.zeros((1, total_length))

        meter_before = pyln.Meter(sr)
        meter_after = pyln.Meter(self.sr)
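
        # Per-chunk pipeline: write the chunk to a temporary wav, run AudioSR
        # super-resolution on it, trim the padded tail, match the loudness of
        # the upsampled chunk to the source chunk, apply linear cross-fades at
        # the overlaps, and overlap-add into the output buffer.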
        for i, chunk in enumerate(chunks_ch1):
            print(f"{i} / {len(chunks_ch1)}")
            loudness_before = meter_before.integrated_loudness(chunk)
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_wav:
                sf.write(temp_wav.name, chunk, sr)

                out_chunk = super_resolution(
                    self.audiosr,
                    temp_wav.name,
                    seed=seed,
                    guidance_scale=guidance_scale,
                    ddim_steps=ddim_steps,
                    latent_t_per_second=12.8,
                )
                out_chunk = out_chunk[0]
                num_samples_to_keep = int(original_lengths_ch1[i] * sample_rate_ratio)
                out_chunk = out_chunk[:, :num_samples_to_keep].squeeze()

                loudness_after = meter_after.integrated_loudness(out_chunk)
                out_chunk = pyln.normalize.loudness(
                    out_chunk, loudness_after, loudness_before
                )

                if enable_overlap:
                    actual_overlap_samples = min(
                        output_overlap_samples, num_samples_to_keep
                    )
                    fade_out = np.linspace(1.0, 0.0, actual_overlap_samples)
                    fade_in = np.linspace(0.0, 1.0, actual_overlap_samples)

                    if i == 0:
                        out_chunk[-actual_overlap_samples:] *= fade_out
                    elif i < len(chunks_ch1) - 1:
                        out_chunk[:actual_overlap_samples] *= fade_in
                        out_chunk[-actual_overlap_samples:] *= fade_out
                    else:
                        out_chunk[:actual_overlap_samples] *= fade_in

                start = i * (
                    output_chunk_samples - output_overlap_samples
                    if enable_overlap
                    else output_chunk_samples
                )
                end = start + out_chunk.shape[0]
                reconstructed_ch1[0, start:end] += out_chunk.flatten()
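
        # The second (right) channel is processed with the same per-chunk
        # pipeline, then the two channels are stacked back into a stereo signal.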
        if is_stereo:
            reconstructed_ch2 = np.zeros((1, total_length))
            for i, chunk in enumerate(chunks_ch2):
                print(f"{i} / {len(chunks_ch2)}")
                loudness_before = meter_before.integrated_loudness(chunk)
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_wav:
                    sf.write(temp_wav.name, chunk, sr)

                    out_chunk = super_resolution(
                        self.audiosr,
                        temp_wav.name,
                        seed=seed,
                        guidance_scale=guidance_scale,
                        ddim_steps=ddim_steps,
                        latent_t_per_second=12.8,
                    )
                    out_chunk = out_chunk[0]
                    num_samples_to_keep = int(original_lengths_ch2[i] * sample_rate_ratio)
                    out_chunk = out_chunk[:, :num_samples_to_keep].squeeze()

                    loudness_after = meter_after.integrated_loudness(out_chunk)
                    out_chunk = pyln.normalize.loudness(
                        out_chunk, loudness_after, loudness_before
                    )

                    if enable_overlap:
                        actual_overlap_samples = min(
                            output_overlap_samples, num_samples_to_keep
                        )
                        fade_out = np.linspace(1.0, 0.0, actual_overlap_samples)
                        fade_in = np.linspace(0.0, 1.0, actual_overlap_samples)

                        if i == 0:
                            out_chunk[-actual_overlap_samples:] *= fade_out
                        elif i < len(chunks_ch2) - 1:
                            out_chunk[:actual_overlap_samples] *= fade_in
                            out_chunk[-actual_overlap_samples:] *= fade_out
                        else:
                            out_chunk[:actual_overlap_samples] *= fade_in

                    start = i * (
                        output_chunk_samples - output_overlap_samples
                        if enable_overlap
                        else output_chunk_samples
                    )
                    end = start + out_chunk.shape[0]
                    reconstructed_ch2[0, start:end] += out_chunk.flatten()

            reconstructed_audio = np.stack(
                [reconstructed_ch1, reconstructed_ch2], axis=-1
            )
        else:
            reconstructed_audio = reconstructed_ch1
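
        # Multiband ensemble: reload the original file at 48 kHz and keep its
        # band below the crossover (input_cutoff - 1000 Hz); take only the band
        # above the crossover from the AudioSR output, tame it with a gentle
        # 23 kHz low-pass, and sum the two bands.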
        if multiband_ensemble:
            low, _ = librosa.load(input_file, sr=48000, mono=False)
            output = self._match_array_shapes(reconstructed_audio[0].T, low)
            crossover_freq = input_cutoff - 1000
            low = self._lr_filter(low.T, crossover_freq, "lowpass", order=10)
            high = self._lr_filter(output.T, crossover_freq, "highpass", order=10)
            high = self._lr_filter(high, 23000, "lowpass", order=2)
            output = low + high
        else:
            output = reconstructed_audio[0]

        return output

    def predict(
        self,
        input_file,
        output_folder,
        ddim_steps=50,
        guidance_scale=3.5,
        overlap=0.04,
        chunk_size=10.24,
        seed=None,
        multiband_ensemble=True,
        input_cutoff=14000,
    ):
        """
        Upscales the audio and saves the result.

        Args:
            input_file (str): Path to the input audio file.
            output_folder (str): Path to the output folder.
            ddim_steps (int, optional): Number of inference steps. Defaults to 50.
            guidance_scale (float, optional): Scale for classifier-free guidance. Defaults to 3.5.
            overlap (float, optional): Overlap between chunks, as a fraction of the chunk size. Defaults to 0.04.
            chunk_size (float, optional): Chunk size in seconds. Defaults to 10.24.
            seed (int, optional): Random seed. Defaults to None.
            multiband_ensemble (bool, optional): Whether to use the multiband ensemble. Defaults to True.
            input_cutoff (int, optional): Input cutoff frequency for the multiband ensemble. Defaults to 14000.

        Returns:
            np.ndarray: Upsampled audio data at 48 kHz.
        """
        if seed == 0:
            seed = random.randint(0, 2**32 - 1)

        os.makedirs(output_folder, exist_ok=True)
        waveform = self._process_audio(
            input_file,
            chunk_size=chunk_size,
            overlap=overlap,
            seed=seed,
            guidance_scale=guidance_scale,
            ddim_steps=ddim_steps,
            multiband_ensemble=multiband_ensemble,
            input_cutoff=input_cutoff,
        )

        filename = os.path.splitext(os.path.basename(input_file))[0]
        output_file = f"{output_folder}/SR_{filename}.wav"
        sf.write(output_file, data=waveform, samplerate=48000, subtype="PCM_16")
        print(f"File created: {output_file}")

        gc.collect()
        torch.cuda.empty_cache()
        return waveform
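

# A minimal sketch of standalone (non-Gradio) usage; the file paths below are
# placeholders, not part of this app:
#
#     upscaler = AudioUpscaler()
#     upscaler.setup()
#     upscaler.predict("input.wav", ".", ddim_steps=50, seed=0)
#
# The upscaled file is written into the chosen output folder as SR_<name>.wav.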


@spaces.GPU(duration=300)
def inference(audio_file, model_name, guidance_scale, ddim_steps, seed):
    """Single-pass super-resolution without chunking (not wired into the Gradio UI below)."""
    audiosr = build_model(model_name=model_name)
    gc.collect()

    if seed == 0:
        seed = random.randint(1, 2**32 - 1)

    waveform = super_resolution(
        audiosr,
        audio_file,
        seed=seed,
        guidance_scale=guidance_scale,
        ddim_steps=ddim_steps,
    )

    return (48000, waveform)


@spaces.GPU(duration=300)
def upscale_audio(
    input_file,
    output_folder,
    ddim_steps=20,
    guidance_scale=3.5,
    overlap=0.04,
    chunk_size=10.24,
    seed=0,
    multiband_ensemble=True,
    input_cutoff=14000,
):
    """
    Upscales the audio using the AudioSR model.

    Args:
        input_file (str): Path to the input audio file.
        output_folder (str): Path to the output folder.
        ddim_steps (int, optional): Number of inference steps. Defaults to 20.
        guidance_scale (float, optional): Scale for classifier-free guidance. Defaults to 3.5.
        overlap (float, optional): Overlap between chunks, as a fraction of the chunk size. Defaults to 0.04.
        chunk_size (float, optional): Chunk size in seconds. Defaults to 10.24.
        seed (int, optional): Random seed. Defaults to 0 (0 selects a random seed).
        multiband_ensemble (bool, optional): Whether to use the multiband ensemble. Defaults to True.
        input_cutoff (int, optional): Input cutoff frequency for the multiband ensemble. Defaults to 14000.

    Returns:
        tuple: Sample rate (48000) and the upscaled audio data.
    """
    torch.cuda.empty_cache()
    gc.collect()

    upscaler = AudioUpscaler()
    upscaler.setup()
    waveform = upscaler.predict(
        input_file,
        output_folder,
        ddim_steps=ddim_steps,
        guidance_scale=guidance_scale,
        overlap=overlap,
        chunk_size=chunk_size,
        seed=seed,
        multiband_ensemble=multiband_ensemble,
        input_cutoff=input_cutoff,
    )

    torch.cuda.empty_cache()
    gc.collect()

    return (48000, waveform)
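

# The Gradio inputs below are passed positionally to upscale_audio, in the
# order: input_file, output_folder, ddim_steps, guidance_scale, overlap,
# chunk_size, seed, multiband_ensemble, input_cutoff.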
iface = gr.Interface(
    fn=upscale_audio,
    inputs=[
        gr.Audio(type="filepath", label="Input Audio"),
        gr.Textbox(value=".", label="Output Folder"),
        gr.Slider(10, 500, value=20, step=1, label="DDIM Steps", info="Number of inference steps (quality/speed trade-off)"),
        gr.Slider(1.0, 20.0, value=3.5, step=0.1, label="Guidance Scale", info="Classifier-free guidance scale (creativity/fidelity)"),
        gr.Slider(0.0, 0.5, value=0.04, step=0.01, label="Overlap", info="Overlap between chunks, as a fraction of the chunk size (smoother transitions)"),
        gr.Slider(5.12, 20.48, value=5.12, step=0.64, label="Chunk Size (s)", info="Chunk size in seconds (memory/artifact balance)"),
        gr.Number(value=0, precision=0, label="Seed", info="Random seed (0 for random)"),
        gr.Checkbox(label="Multiband Ensemble", value=False, info="Keep the original low band and use AudioSR only above the crossover"),
        gr.Slider(500, 15000, value=9000, step=500, label="Input Cutoff Frequency (Hz)", info="Estimated bandwidth of the input; the multiband crossover sits 1 kHz below this"),
    ],
    outputs=gr.Audio(type="numpy", label="Output Audio"),
    title="AudioSR",
    description="Audio Super Resolution with AudioSR",
)

iface.launch(share=False)
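
# share=False: no public *.gradio.live link is created; when hosted (e.g. on a
# Hugging Face Space) the app is exposed by the platform itself.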
|