from timbre_trap.framework.modules import TimbreTrap
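# pyharp's star import supplies ModelCard, LabelList, and build_endpoint used below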
from pyharp import *

import gradio as gr
import torchaudio
import torch
import os


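# Instantiate Timbre-Trap; these hyperparameters must match the checkpoint
# loaded below, since load_state_dict requires identical parameter shapes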
model = TimbreTrap(sample_rate=22050,
                   n_octaves=9,
                   bins_per_octave=60,
                   secs_per_block=3,
                   latent_size=128,
                   model_complexity=2,
                   skip_connections=False)
model.eval()  # switch layers to evaluation (inference) behavior

model_path_orig = os.path.join('models', 'tt-orig.pt')
#model_path_demo = os.path.join('models', 'tt-demo.pt')

tt_weights_orig = torch.load(model_path_orig, map_location='cpu')
#tt_weights_demo = torch.load(model_path_demo, map_location='cpu')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

model_card = ModelCard(
    name='Timbre-Trap',
    description='De-timbre your audio!',
    author='Frank Cwitkowitz',
    tags=['example', 'music transcription', 'multi-pitch estimation', 'timbre filtering']
)

def process_fn(audio_path, transcribe):  # , demo):
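    """
    HARP processing callback: receives the path to the input audio and the
    state of the 'Remove Timbre' checkbox, and returns the path to the
    processed audio along with an (empty) list of output labels.
    """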
    # Load the audio with torchaudio
    audio, fs = torchaudio.load(audio_path)
    # Average channels to obtain mono audio
    audio = torch.mean(audio, dim=0, keepdim=True)
    # Resample audio to the model's sampling rate (22050 Hz)
    audio = torchaudio.functional.resample(audio, fs, 22050)
    # Add a batch dimension
    audio = audio.unsqueeze(0)
    # Determine original number of samples
    n_samples = audio.size(-1)

    """
    if demo:
        # Load weights of the demo version
        model.load_state_dict(tt_weights_demo)
    else:
    """
    # Load weights of the original model
    model.load_state_dict(tt_weights_orig)

    # Move audio to the current device
    audio = audio.to(device)

    # Obtain transcription or reconstructed spectral coefficients
    # (gradients are unnecessary for inference, so disable them to save memory)
    with torch.no_grad():
        coefficients = model.chunked_inference(audio, transcribe)

    # Invert coefficients to produce audio
    audio = model.sliCQ.decode(coefficients)
    # Trim to original number of samples
    audio = audio[..., :n_samples]
    # Remove batch dimension
    audio = audio.squeeze(0)

    # Low-pass filter the audio in an attempt to remove artifacts
    audio = torchaudio.functional.lowpass_biquad(audio, 22050, 8000)

    # Resample audio back to the original sampling rate
    audio = torchaudio.functional.resample(audio, 22050, fs)

    # Bring audio back to CPU
    audio = audio.cpu()

    # Create an output directory if it does not already exist
    os.makedirs('_outputs', exist_ok=True)
    # Create a path for saving the audio
    save_path = os.path.join('_outputs', 'output.wav')
    # Save the audio
    torchaudio.save(save_path, audio, fs)

    # No output labels
    output_labels = LabelList()

    return save_path, output_labels


# Build Gradio endpoint
with gr.Blocks() as demo:
    components = [
        gr.Checkbox(
            value=False,
            label='Remove Timbre'
        )
    ]

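    # build_endpoint wires the audio input and the components above into
    # process_fn, exposing the app as a HARP-compatible Gradio endpoint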
    app = build_endpoint(model_card=model_card,
                         components=components,
                         process_fn=process_fn)

demo.queue()
demo.launch(share=True)