import spaces

from timbre_trap.framework.modules import TimbreTrap
from pyharp import *

import gradio as gr
import torchaudio
import torch
import os


# Initialize the Timbre-Trap model
model = TimbreTrap(sample_rate=22050,
                   n_octaves=9,
                   bins_per_octave=60,
                   secs_per_block=3,
                   latent_size=128,
                   model_complexity=2,
                   skip_connections=False)
model.eval()

# Paths to the pre-trained weights
model_path_orig = os.path.join('models', 'tt-orig.pt')
#model_path_demo = os.path.join('models', 'tt-demo.pt')

# Load the weights onto the CPU
tt_weights_orig = torch.load(model_path_orig, map_location='cpu')
#tt_weights_demo = torch.load(model_path_demo, map_location='cpu')

# if torch.cuda.is_available():
#     model = model.cuda()

# Move the model to the GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Metadata for the PyHARP endpoint
model_card = ModelCard(
    name='Timbre-Trap',
    description='De-timbre your audio!',
    author='Frank Cwitkowitz',
    tags=['example', 'music transcription', 'multi-pitch estimation', 'timbre filtering']
)


@spaces.GPU
def process_fn(audio_path, transcribe):#, demo):
    # Load the audio with torchaudio
    audio, fs = torchaudio.load(audio_path)
    # Average channels to obtain mono-channel audio
    audio = torch.mean(audio, dim=0, keepdim=True)
    # Resample audio to the sampling rate expected by the model
    audio = torchaudio.functional.resample(audio, fs, 22050)
    # Add a batch dimension
    audio = audio.unsqueeze(0)
    # Determine the number of samples prior to inference
    n_samples = audio.size(-1)

    """
    if demo:
        # Load weights of the demo version
        model.load_state_dict(tt_weights_demo)
    else:
    """
    # Load weights of the original model
    model.load_state_dict(tt_weights_orig)

    audio = audio.to(device)

    # Obtain transcription or reconstructed spectral coefficients
    coefficients = model.chunked_inference(audio, transcribe)

    # Invert coefficients to produce audio
    audio = model.sliCQ.decode(coefficients)

    # Trim to the original number of samples
    audio = audio[..., :n_samples]
    # Remove batch dimension
    audio = audio.squeeze(0)

    # Low-pass filter the audio in an attempt to remove artifacts
    audio = torchaudio.functional.lowpass_biquad(audio, 22050, 8000)

    # Resample audio back to the original sampling rate
    audio = torchaudio.functional.resample(audio, 22050, fs)

    audio = audio.cpu()

    # Create a temporary directory for output
    os.makedirs('_outputs', exist_ok=True)
    # Create a path for saving the audio
    save_path = os.path.join('_outputs', 'output.wav')
    # Save the audio
    torchaudio.save(save_path, audio, fs)

    # No output labels
    output_labels = LabelList()

    return save_path, output_labels


# Build the Gradio endpoint
with gr.Blocks() as demo:
    components = [
        gr.Checkbox(
            value=False,
            label='Remove Timbre'
        )
    ]

    app = build_endpoint(model_card=model_card,
                         components=components,
                         process_fn=process_fn)

demo.queue()
demo.launch(share=True)
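
# --- Hedged usage sketch (not part of the original app) ---
# For quick local testing without going through the PyHARP/Gradio endpoint,
# process_fn can be called directly; 'example.wav' is a hypothetical input
# path used only for illustration:
#
#     out_path, labels = process_fn('example.wav', transcribe=True)
#     print(f'De-timbred audio written to {out_path}')
#
# Since demo.launch() above blocks, run a sketch like this in a separate
# session (or before launching the app).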