Spaces: Running on Zero
import torch | |
import numpy as np | |
import gradio as gr | |
import yaml | |
import librosa | |
from tqdm.auto import tqdm | |
import spaces | |
import look2hear.models | |
from ml_collections import ConfigDict | |
# Compute device for all models: use the GPU when PyTorch can see one.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
def load_audio(file_path):
    """Decode an audio file with librosa at 44.1 kHz, keeping every channel.

    Args:
        file_path: Path of the audio file to decode.

    Returns:
        ``(tensor, samplerate)``: the decoded samples wrapped with
        ``torch.from_numpy`` (shares memory with librosa's array) and the
        sample rate reported by ``librosa.load``.
    """
    audio, samplerate = librosa.load(file_path, sr=44100, mono=False)
    print(f'INPUT audio.shape = {audio.shape} | samplerate = {samplerate}')
    tensor = torch.from_numpy(audio)
    return tensor, samplerate
def get_config(config_path):
    """Read the YAML file at *config_path* and wrap it in a ``ConfigDict``.

    Args:
        config_path: Path of a YAML config file (the callers in this file
            pass repository-local paths under ``configs/``).

    Returns:
        ``ml_collections.ConfigDict`` built from the parsed YAML mapping.
    """
    with open(config_path) as stream:
        # NOTE(review): FullLoader is acceptable for these local config files;
        # do not reuse this helper on untrusted YAML.
        parsed = yaml.load(stream, Loader=yaml.FullLoader)
    return ConfigDict(parsed)
def _getWindowingArray(window_size, fade_size):
    """Build the per-chunk weighting window used for overlap-add.

    IMPORTANT NOTE: despite the name, no fades are applied. The first
    ``window_size - fade_size`` samples get weight 1 and the last
    ``fade_size`` samples get weight 0, which removes the failed ending of
    each chunk from the overlap-add instead of cross-fading it.

    Args:
        window_size: Number of samples in the window.
        fade_size: Number of trailing samples to zero out. A value <= 0
            leaves the window all ones; a value >= ``window_size`` zeros the
            whole window.

    Returns:
        1-D float32 tensor of shape ``(window_size,)``.
    """
    window = torch.ones(window_size)
    if fade_size > 0:
        # Guard needed: window[-0:] would select the entire window.
        window[-fade_size:] = 0
    return window
# Text rendered under the title of the Gradio interface
# (appears to be placeholder content).
description = '\ntexts\n'
# Model hyper-parameter configs, parsed once at import time.
apollo_config = get_config('configs/apollo.yaml')
apollo_vocal2_config = get_config('configs/config_apollo_vocal.yaml')


def _load_model(weights_path, config):
    """Load pretrained weights using ``config['model']`` as kwargs, then move to ``device``."""
    return look2hear.models.BaseModel.from_pretrain(weights_path, **config['model']).to(device)


# All checkpoints are loaded eagerly at import time.
# NOTE(review): 'apollo_vocal' is built with apollo_config, not
# apollo_vocal2_config; presumably it shares the base Apollo architecture —
# confirm against the checkpoint.
apollo_model = _load_model('weights/apollo.bin', apollo_config)
apollo_vocal = _load_model('weights/apollo_vocal.bin', apollo_config)
apollo_vocal2 = _load_model('weights/apollo_vocal2.bin', apollo_vocal2_config)

# Dropdown value -> loaded model.
models = dict(
    apollo=apollo_model,
    apollo_vocal=apollo_vocal,
    apollo_vocal2=apollo_vocal2,
)

# (label shown in the UI, key into `models`) pairs for the model dropdown.
choices = [
    ('MP3 restore', 'apollo'),
    ('Apollo vocal', 'apollo_vocal'),
    ('Apollo vocal2', 'apollo_vocal2'),
]
def enchance(choice, audio):
    """Run the selected model over an audio file in overlapping chunks.

    The signal is cut into ``C``-sample chunks (10 s) taken every ``step``
    samples (half a chunk). Each chunk's model output is weighted by the
    window from ``_getWindowingArray`` and summed into ``result``, while the
    weights are summed into ``counter``. The output is ``result / counter``.

    Args:
        choice: Key into ``models`` (the dropdown value).
        audio: Path of the uploaded audio file (Gradio ``type='filepath'``),
            or ``None`` when no file was provided.

    Returns:
        ``(samplerate, samples)`` where ``samples`` is the reconstructed
        ``(channels, time)`` array transposed to ``(time, channels)``.

    Raises:
        gr.Error: if ``audio`` is ``None``.
        KeyError: if ``choice`` is not a key of ``models``.

    NOTE(review): ``spaces`` is imported in this file, but this function is
    not decorated with ``@spaces.GPU``. If this runs on ZeroGPU hardware,
    confirm that a GPU is actually allocated for inference.
    """
    if audio is None:
        raise gr.Error("Please upload an audio file.")
    print(choice)
    model = models[choice]
    test_data, samplerate = load_audio(audio)

    C = 10 * samplerate  # chunk_size seconds to samples
    N = 2  # hop = C / N, so consecutive chunks overlap by half
    step = C // N
    fade_size = 3 * samplerate  # 3 seconds (derived from samplerate, not hard-coded 44100)
    print(f"N = {N} | C = {C} | step = {step} | fade_size = {fade_size}")
    border = C - step

    # handle mono inputs correctly: (time,) -> (1, time)
    if test_data.dim() == 1:
        test_data = test_data.unsqueeze(0)

    # Reflect-pad both ends when the signal is long enough for it. Record the
    # decision once so the padding is stripped under exactly the same condition.
    padded = border > 0 and test_data.shape[1] > 2 * border
    if padded:
        test_data = torch.nn.functional.pad(test_data, (border, border), mode='reflect')
    total = test_data.shape[1]

    windowingArray = _getWindowingArray(C, fade_size)
    result = torch.zeros((1,) + tuple(test_data.shape), dtype=torch.float32)
    counter = torch.zeros((1,) + tuple(test_data.shape), dtype=torch.float32)

    with tqdm(total=total, desc="Processing audio chunks", leave=False) as progress_bar, torch.no_grad():
        for i in range(0, total, step):
            part = test_data[:, i:i + C]
            length = part.shape[-1]
            if length < C:
                # Extend a short trailing chunk to C samples. Reflect padding is
                # only valid while the pad is shorter than the chunk itself.
                if length > C // 2 + 1:
                    part = torch.nn.functional.pad(input=part, pad=(0, C - length), mode='reflect')
                else:
                    part = torch.nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)
            chunk = part.unsqueeze(0).to(device)
            out = model(chunk).squeeze(0).squeeze(0).cpu()

            # Work on a copy so the edge tweaks below never leak into the
            # window shared by later chunks.
            window = windowingArray.clone()
            if i == 0:  # First audio chunk, no fadein
                window[:fade_size] = 1
            elif i + C >= total:  # Last audio chunk(s): keep the tail, nothing follows it
                window[-fade_size:] = 1
            result[..., i:i + length] += out[..., :length] * window[..., :length]
            counter[..., i:i + length] += window[..., :length]
            # Advance by the samples newly reached so the bar ends exactly at `total`.
            progress_bar.update(min(step, total - i))

    final_output = (result / counter).squeeze(0).numpy()
    # Positions with zero accumulated weight give 0/0 = NaN; replace them with 0.
    np.nan_to_num(final_output, copy=False, nan=0.0)
    # Remove padding if added earlier
    if padded:
        final_output = final_output[..., border:-border]
    return samplerate, final_output.T
if __name__ == "__main__":
    # Build and launch the Gradio UI: model dropdown + audio upload -> enhanced audio.
    demo = gr.Interface(
        fn=enchance,
        description=description,
        inputs=[
            # Dropdown choices are (label, value) pairs. The default must be a
            # *value* ('apollo'), not the whole tuple, because `enchance` indexes
            # `models` with it.
            gr.Dropdown(label="Model", choices=choices, value=choices[0][1]),
            gr.Audio(label="Input Audio:", interactive=True, type='filepath', max_length=300, waveform_options={'waveform_progress_color': '#3C82F6'}),
        ],
        outputs=[
            gr.Audio(
                label="Output Audio",
                autoplay=False,
                streaming=False,
                type="numpy",
            ),
        ],
        allow_flagging='never',
        cache_examples=False,
        title='Enchanser',
    )
    # Bounded request queue; up to 4 requests are processed concurrently.
    demo.queue(max_size=20, default_concurrency_limit=4)
    demo.launch(share=False, server_name="0.0.0.0")