timbre-trap / app.py
import spaces
from timbre_trap.framework.modules import TimbreTrap
from pyharp import *
import gradio as gr
import torchaudio
import torch
import os
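
# Instantiate Timbre-Trap with the configuration expected by the pre-trained checkpoint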
model = TimbreTrap(sample_rate=22050,
                   n_octaves=9,
                   bins_per_octave=60,
                   secs_per_block=3,
                   latent_size=128,
                   model_complexity=2,
                   skip_connections=False)
model.eval()
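
# Pre-load the checkpoint weights on the CPU; they are loaded into the model inside process_fn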
model_path_orig = os.path.join('models', 'tt-orig.pt')
#model_path_demo = os.path.join('models', 'tt-demo.pt')
tt_weights_orig = torch.load(model_path_orig, map_location='cpu')
#tt_weights_demo = torch.load(model_path_demo, map_location='cpu')
# if torch.cuda.is_available():
# model = model.cuda()
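# Select a GPU when one is available, otherwise fall back to the CPU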
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
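
# Metadata describing the model to the pyharp/HARP interface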
model_card = ModelCard(
    name='Timbre-Trap',
    description='De-timbre your audio!',
    author='Frank Cwitkowitz',
    tags=['example', 'music transcription', 'multi-pitch estimation', 'timbre filtering']
)
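
# On Hugging Face ZeroGPU hardware, spaces.GPU allocates a GPU for the duration of each call to the decorated function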
@spaces.GPU
def process_fn(audio_path, transcribe):  # , demo):
    # Load the audio with torchaudio
    audio, fs = torchaudio.load(audio_path)
    # Average the channels to obtain a mono signal
    audio = torch.mean(audio, dim=0, keepdim=True)
    # Resample the audio to the model's sampling rate (22050 Hz)
    audio = torchaudio.functional.resample(audio, fs, 22050)
    # Add a batch dimension
    audio = audio.unsqueeze(0)
    # Determine the original number of samples
    n_samples = audio.size(-1)

    # The demo variant of the model is currently disabled
    """
    if demo:
        # Load weights of the demo version
        model.load_state_dict(tt_weights_demo)
    else:
    """
    # Load weights of the original model
    model.load_state_dict(tt_weights_orig)

    # Move the audio to the same device as the model
    audio = audio.to(device)

    # Obtain transcription or reconstructed spectral coefficients (computed in fixed-length chunks)
    coefficients = model.chunked_inference(audio, transcribe)

    # Invert the coefficients to produce audio
    audio = model.sliCQ.decode(coefficients)
    # Trim to the original number of samples
    audio = audio[..., :n_samples]
    # Remove the batch dimension
    audio = audio.squeeze(0)

    # Low-pass filter the audio in an attempt to remove artifacts
    audio = torchaudio.functional.lowpass_biquad(audio, 22050, 8000)
    # Resample the audio back to the original sampling rate
    audio = torchaudio.functional.resample(audio, 22050, fs)
    # Move the audio back to the CPU for saving
    audio = audio.cpu()

    # Create a temporary directory for output
    os.makedirs('_outputs', exist_ok=True)
    # Create a path for saving the audio
    save_path = os.path.join('_outputs', 'output.wav')
    # Save the audio
    torchaudio.save(save_path, audio, fs)

    # No output labels
    output_labels = LabelList()

    return save_path, output_labels
# Build Gradio endpoint
with gr.Blocks() as demo:
    components = [
        gr.Checkbox(
            value=False,
            label='Remove Timbre'
        )
    ]
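    # build_endpoint maps the checkbox to process_fn's transcribe flag and wires the
    # model card, components, and processing function into a HARP-compatible Gradio app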
    app = build_endpoint(model_card=model_card,
                         components=components,
                         process_fn=process_fn)
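
# Queue incoming requests and launch the app with a public share link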
demo.queue()
demo.launch(share=True)