from timbre_trap.framework.modules import TimbreTrap
from pyharp import *
import gradio as gr
import torchaudio
import torch
import os
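
# Instantiate Timbre-Trap with the configuration matching the pre-trained weights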
model = TimbreTrap(sample_rate=22050,
                   n_octaves=9,
                   bins_per_octave=60,
                   secs_per_block=3,
                   latent_size=128,
                   model_complexity=2,
                   skip_connections=False)
model.eval()
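
# Build paths to the model weights and load them onto the CPU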
model_path_orig = os.path.join('models', 'tt-orig.pt')
#model_path_demo = os.path.join('models', 'tt-demo.pt')
tt_weights_orig = torch.load(model_path_orig, map_location='cpu')
#tt_weights_demo = torch.load(model_path_demo, map_location='cpu')
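
# Select the GPU if one is available and move the model onto it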
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
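
# Describe the model for display in the HARP plugin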
model_card = ModelCard(
    name='Timbre-Trap',
    description='De-timbre your audio!',
    author='Frank Cwitkowitz',
    tags=['example', 'music transcription', 'multi-pitch estimation', 'timbre filtering']
)
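
# process_fn is invoked for each audio file sent from the HARP plugin; it must
# return a path to the processed audio along with a list of output labels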
def process_fn(audio_path, transcribe):#, demo):
    # Load the audio with torchaudio
    audio, fs = torchaudio.load(audio_path)
    # Average channels to obtain mono-channel
    audio = torch.mean(audio, dim=0, keepdim=True)
    # Resample audio to the specified sampling rate
    audio = torchaudio.functional.resample(audio, fs, 22050)
    # Add a batch dimension
    audio = audio.unsqueeze(0)
    # Determine original number of samples
    n_samples = audio.size(-1)

    """
    if demo:
        # Load weights of the demo version
        model.load_state_dict(tt_weights_demo)
    else:
    """
    # Load weights of the original model
    model.load_state_dict(tt_weights_orig)

    # Add audio to current device
    audio = audio.to(device)

    # Obtain transcription or reconstructed spectral coefficients
    coefficients = model.chunked_inference(audio, transcribe)

    # Invert coefficients to produce audio
    audio = model.sliCQ.decode(coefficients)

    # Trim to original number of samples
    audio = audio[..., :n_samples]
    # Remove batch dimension
    audio = audio.squeeze(0)

    # Low-pass filter the audio in an attempt to remove artifacts
    audio = torchaudio.functional.lowpass_biquad(audio, 22050, 8000)
    # Resample audio back to the original sampling rate
    audio = torchaudio.functional.resample(audio, 22050, fs)
    # Bring audio back to CPU
    audio = audio.cpu()

    # Create a temporary directory for output
    os.makedirs('_outputs', exist_ok=True)
    # Create a path for saving the audio
    save_path = os.path.join('_outputs', 'output.wav')
    # Save the audio
    torchaudio.save(save_path, audio, fs)

    # No output labels
    output_labels = LabelList()

    return save_path, output_labels
# Build Gradio endpoint
with gr.Blocks() as demo:
    components = [
        gr.Checkbox(
            value=False,
            label='Remove Timbre'
        )
    ]

    app = build_endpoint(model_card=model_card,
                         components=components,
                         process_fn=process_fn)

demo.queue()
demo.launch(share=True)
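
# To run this Space locally (a sketch, assuming this file is saved as app.py and
# that pyharp, gradio, torchaudio, and the timbre_trap package are installed):
#   python app.py
# Gradio then prints a local URL (and a public share link, since share=True)
# that can be pasted into the HARP plugin.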