Commit cd84ee3 · hugo flores garcia committed · Parent: 39bff10

several mods
- TODOS +1 -0
- app.py +129 -101
- conf/generated/ivo/c2f.yml +15 -0
- conf/generated/ivo/coarse.yml +8 -0
- conf/generated/ivo/interface.yml +6 -0
- conf/generated/lazaro-ros-sep/c2f.yml +15 -0
- conf/generated/lazaro-ros-sep/coarse.yml +8 -0
- conf/generated/lazaro-ros-sep/interface.yml +6 -0
- conf/generated/lazaro-ros/c2f.yml +15 -0
- conf/generated/lazaro-ros/coarse.yml +8 -0
- conf/generated/lazaro-ros/interface.yml +6 -0
- conf/generated/march-31/c2f.yml +15 -0
- conf/generated/march-31/coarse.yml +8 -0
- conf/generated/march-31/interface.yml +6 -0
- conf/generated/sax-new/c2f.yml +15 -0
- conf/generated/sax-new/coarse.yml +8 -0
- conf/generated/sax-new/interface.yml +6 -0
- conf/generated/saxophone/c2f.yml +15 -0
- conf/generated/saxophone/coarse.yml +8 -0
- conf/generated/saxophone/interface.yml +6 -0
- conf/lora/lora-s2s.yml +27 -0
- conf/lora/lora.yml +1 -1
- scripts/exp/export.py +2 -3
- scripts/exp/train.py +60 -0
- token_telephone/tt.py +15 -13
- vampnet/beats.py +2 -1
- vampnet/control.py +277 -0
- vampnet/interface.py +16 -8
- vampnet/mask.py +10 -6
- vampnet/modules/transformer.py +117 -6
TODOS
ADDED
@@ -0,0 +1 @@
+[ ] add sketch2sound finetuning
app.py
CHANGED
@@ -21,6 +21,7 @@ interface = Interface.default()
 init_model_choice = open("DEFAULT_MODEL").read().strip()
 # load the init model
 interface.load_finetuned(init_model_choice)
+interface.to(device)
 
 def to_output(sig):
     return sig.sample_rate, sig.cpu().detach().numpy()[0][0]
@@ -105,9 +106,33 @@ def _vamp(
     n_mask_codebooks, periodic_w, onset_mask_width,
     dropout, sampletemp, typical_filtering,
     typical_mass, typical_min_tokens, top_p,
-    sample_cutoff, stretch_factor, api=False
+    sample_cutoff, stretch_factor, sampling_steps, beat_mask_ms, num_feedback_steps, api=False
 ):
 
+    print("args!")
+    print(f"seed: {seed}")
+    print(f"input_audio: {input_audio}")
+    print(f"model_choice: {model_choice}")
+    print(f"pitch_shift_amt: {pitch_shift_amt}")
+    print(f"periodic_p: {periodic_p}")
+    print(f"n_mask_codebooks: {n_mask_codebooks}")
+    print(f"periodic_w: {periodic_w}")
+    print(f"onset_mask_width: {onset_mask_width}")
+    print(f"dropout: {dropout}")
+    print(f"sampletemp: {sampletemp}")
+    print(f"typical_filtering: {typical_filtering}")
+    print(f"typical_mass: {typical_mass}")
+    print(f"typical_min_tokens: {typical_min_tokens}")
+    print(f"top_p: {top_p}")
+    print(f"sample_cutoff: {sample_cutoff}")
+    print(f"stretch_factor: {stretch_factor}")
+    print(f"sampling_steps: {sampling_steps}")
+    print(f"api: {api}")
+    print(f"beat_mask_ms: {beat_mask_ms}")
+    print(f"using device {interface.device}")
+    print(f"num feedback steps: {num_feedback_steps}")
+
+
     t0 = time.time()
     interface.to("cuda" if torch.cuda.is_available() else "cpu")
     print(f"using device {interface.device}")
@@ -121,6 +146,9 @@ def _vamp(
 
     sig = at.AudioSignal(input_audio, sr).to_mono()
 
+    loudness = sig.loudness()
+    sig = interface._preprocess(sig)
+
     # reload the model if necessary
     interface.load_finetuned(model_choice)
 
@@ -129,38 +157,70 @@ def _vamp(
 
     codes = interface.encode(sig)
 
-    mask = new_vampnet_mask(
-        interface,
-        codes,
-        onset_idxs=onsets(sig, hop_length=interface.codec.hop_length),
-        width=onset_mask_width,
+    # mask = new_vampnet_mask(
+    #     interface,
+    #     codes,
+    #     onset_idxs=onsets(sig, hop_length=interface.codec.hop_length),
+    #     width=onset_mask_width,
+    #     periodic_prompt=periodic_p,
+    #     upper_codebook_mask=n_mask_codebooks,
+    #     drop_amt=dropout
+    # ).long()
+
+
+    mask = interface.build_mask(
+        codes,
+        sig=sig,
         periodic_prompt=periodic_p,
+        periodic_prompt_width=periodic_w,
+        onset_mask_width=onset_mask_width,
+        _dropout=dropout,
         upper_codebook_mask=n_mask_codebooks,
-        drop_amt=dropout
-    ).long()
+    )
+    if beat_mask_ms > 0:
+        # bm = pmask.mask_or(
+        #     pmask.periodic_mask(
+        #         codes, periodic_p, periodic_w, random_roll=False
+        #     ),
+        # )
+        mask = pmask.mask_and(
+            mask, interface.make_beat_mask(
+                sig, after_beat_s=beat_mask_ms/1000.,
+            )
+        )
+        mask = pmask.codebook_mask(mask, n_mask_codebooks)
+    np.savetxt("scratch/rms_mask.txt", mask[0].cpu().numpy(), fmt='%d')
 
-    # save the mask as a txt file
     interface.set_chunk_size(10.0)
+
+    # lord help me
+    if top_p is not None:
+        if top_p > 0:
+            pass
+        else:
+            top_p = None
+
     codes, mask = interface.vamp(
         codes, mask,
-        batch_size=…,
-        feedback_steps=…,
-        _sampling_steps=…,
+        batch_size=2,
+        feedback_steps=num_feedback_steps,
+        _sampling_steps=sampling_steps,
         time_stretch_factor=stretch_factor,
         return_mask=True,
         temperature=sampletemp,
         typical_filtering=typical_filtering,
         typical_mass=typical_mass,
         typical_min_tokens=typical_min_tokens,
-        top_p=…,
+        top_p=top_p,
         seed=_seed,
-        sample_cutoff=…,
+        sample_cutoff=sample_cutoff,
     )
    (the three deleted argument lines above are truncated in this view; their old values are not recoverable)
     print(f"vamp took {time.time() - t0} seconds")
 
     sig = interface.decode(codes)
+    sig = sig.normalize(loudness)
 
-    return to_output(sig)
+    return to_output(sig[0]), to_output(sig[1])
 
 def vamp(data):
     return _vamp(
@@ -180,31 +240,29 @@ def vamp(data):
         top_p=data[top_p],
         sample_cutoff=data[sample_cutoff],
         stretch_factor=data[stretch_factor],
+        sampling_steps=data[sampling_steps],
+        beat_mask_ms=data[beat_mask_ms],
+        num_feedback_steps=data[num_feedback_steps],
         api=False,
     )
 
-# … (a block of commented-out code, truncated in this view)
-#         stretch_factor=data[stretch_factor],
-#         api=True,
-#     )
-
-def api_vamp(input_audio, sampletemp, top_p, periodic_p, periodic_w, dropout, stretch_factor, onset_mask_width, typical_filtering, typical_mass, typical_min_tokens, seed, model_choice, n_mask_codebooks, pitch_shift_amt, sample_cutoff):
+
+def api_vamp(input_audio,
+    sampletemp, top_p,
+    periodic_p, periodic_w,
+    dropout,
+    stretch_factor,
+    onset_mask_width,
+    typical_filtering,
+    typical_mass,
+    typical_min_tokens,
+    seed,
+    model_choice,
+    n_mask_codebooks,
+    pitch_shift_amt,
+    sample_cutoff,
+    sampling_steps,
+    beat_mask_ms, num_feedback_steps):
     return _vamp(
         seed=seed,
         input_audio=input_audio,
@@ -222,50 +280,12 @@ def api_vamp(
         top_p=top_p,
         sample_cutoff=sample_cutoff,
         stretch_factor=stretch_factor,
+        sampling_steps=sampling_steps,
+        beat_mask_ms=beat_mask_ms,
+        num_feedback_steps=num_feedback_steps,
         api=True,
     )
 
-OUT_DIR = Path("gradio-outputs")
-OUT_DIR.mkdir(exist_ok=True)
-def harp_vamp(input_audio_file, periodic_p, n_mask_codebooks):
-    sig = at.AudioSignal(input_audio_file)
-    sr, samples = sig.sample_rate, sig.samples[0][0].detach().cpu().numpy()
-    # convert to int32
-    samples = (samples * np.iinfo(np.int32).max).astype(np.int32)
-    sr, samples = _vamp(
-        seed=0,
-        input_audio=(sr, samples),
-        model_choice=init_model_choice,
-        pitch_shift_amt=0,
-        periodic_p=periodic_p,
-        n_mask_codebooks=n_mask_codebooks,
-        periodic_w=1,
-        onset_mask_width=0,
-        dropout=0.0,
-        sampletemp=1.0,
-        typical_filtering=True,
-        typical_mass=0.15,
-        typical_min_tokens=64,
-        top_p=0.0,
-        sample_cutoff=1.0,
-        stretch_factor=1,
-    )
-
-    sig = at.AudioSignal(samples, sr)
-    # write to file
-    # clear the outdir
-    for p in OUT_DIR.glob("*"):
-        p.unlink()
-    OUT_DIR.mkdir(exist_ok=True)
-    # outpath = OUT_DIR / f"{uuid.uuid4()}.wav"
-    from pyharp import AudioLabel, LabelList, save_audio
-    outpath = save_audio(sig)
-    sig.write(outpath)
-    output_labels = LabelList()
-    output_labels.append(AudioLabel(label='~', t=0.0, amplitude=0.5, description='generated audio'))
-    return outpath, output_labels
-
-
 with gr.Blocks() as demo:
     with gr.Row():
         with gr.Column():
@@ -359,6 +379,11 @@ with gr.Blocks() as demo:
                 value=1,
             )
 
+            beat_mask_ms = gr.Number(
+                label="beat mask width (milliseconds)",
+                value=0,
+            )
+
             with gr.Accordion("sampling settings", open=False):
                 sampletemp = gr.Slider(
@@ -399,6 +424,22 @@ with gr.Blocks() as demo:
                     value=1.0,
                     step=0.01
                 )
+                sampling_steps = gr.Slider(
+                    label="sampling steps",
+                    minimum=1,
+                    maximum=128,
+                    step=1,
+                    value=36
+                )
+                num_feedback_steps = gr.Slider(
+                    label="feedback steps",
+                    minimum=1,
+                    maximum=16,
+                    step=1,
+                    value=1
+                )
+
+
 
                 dropout = gr.Slider(
@@ -433,7 +474,7 @@ with gr.Blocks() as demo:
 
         audio_outs = []
        use_as_input_btns = []
-        for i in range(…):
+        for i in range(2):
            with gr.Column():
                audio_outs.append(gr.Audio(
                    label=f"output audio {i+1}",
@@ -466,13 +507,16 @@ with gr.Blocks() as demo:
        n_mask_codebooks,
        pitch_shift_amt,
        sample_cutoff,
+        sampling_steps,
+        beat_mask_ms,
+        num_feedback_steps
    }
 
    # connect widgets
    vamp_button.click(
        fn=vamp,
        inputs=_inputs,
-        outputs=[audio_outs[0]],
+        outputs=[audio_outs[0], audio_outs[1]],
    )
 
    api_vamp_button = gr.Button("api vamp", visible=True)
@@ -491,31 +535,15 @@ with gr.Blocks() as demo:
            model_choice,
            n_mask_codebooks,
            pitch_shift_amt,
-            sample_cutoff
+            sample_cutoff,
+            sampling_steps,
+            beat_mask_ms,
+            num_feedback_steps
        ],
-        outputs=[audio_outs[0]],
+        outputs=[audio_outs[0], audio_outs[1]],
        api_name="vamp"
    )
 
-    from pyharp import ModelCard, build_endpoint
-    card = ModelCard(
-        name="vampnet",
-        description="vampnet! is a model for generating audio from audio",
-        author="hugo flores garcía",
-        tags=["music generation"],
-        midi_in=False,
-        midi_out=False
-    )
-
-    # Build a HARP-compatible endpoint
-    app = build_endpoint(model_card=card,
-                         components=[
-                             periodic_p,
-                             n_mask_codebooks,
-                         ],
-                         process_fn=harp_vamp)
-
-
 
 try:
     demo.queue()
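The `_vamp` changes above capture the input's loudness before preprocessing and re-apply it after decoding, so generations come back at the level of the source audio. A minimal sketch of that pattern, assuming descript-audiotools' `AudioSignal` API (`loudness()` returns integrated loudness; `normalize(db)` scales to a target level); the file path is hypothetical:

```python
import audiotools as at

sig = at.AudioSignal("input.wav")  # hypothetical input file
loudness = sig.loudness()          # measure before any processing
# ... encode, vamp, and decode would happen here ...
out = sig                          # stand-in for interface.decode(codes)
out = out.normalize(loudness)      # match the output level to the input
```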
conf/generated/ivo/c2f.yml
ADDED
@@ -0,0 +1,15 @@
+$include:
+- conf/lora/lora.yml
+AudioDataset.duration: 3.0
+AudioDataset.loudness_cutoff: -40.0
+VampNet.embedding_dim: 1280
+VampNet.n_codebooks: 14
+VampNet.n_conditioning_codebooks: 4
+VampNet.n_heads: 20
+VampNet.n_layers: 16
+fine_tune: true
+fine_tune_checkpoint: ./models/vampnet/c2f.pth
+save_path: ./runs/ivo/c2f
+train/AudioLoader.sources: &id001
+- ./scratch/miguel/ivo/separated
+val/AudioLoader.sources: *id001
conf/generated/ivo/coarse.yml
ADDED
@@ -0,0 +1,8 @@
+$include:
+- conf/lora/lora.yml
+fine_tune: true
+fine_tune_checkpoint: ./models/vampnet/coarse.pth
+save_path: ./runs/ivo/coarse
+train/AudioLoader.sources: &id001
+- ./scratch/miguel/ivo/separated
+val/AudioLoader.sources: *id001
conf/generated/ivo/interface.yml
ADDED
@@ -0,0 +1,6 @@
+AudioLoader.sources:
+- - ./scratch/miguel/ivo/separated
+Interface.coarse2fine_ckpt: ./runs/ivo/c2f/latest/vampnet/weights.pth
+Interface.coarse_ckpt: ./runs/ivo/coarse/latest/vampnet/weights.pth
+Interface.codec_ckpt: ./models/vampnet/codec.pth
+Interface.wavebeat_ckpt: ./models/wavebeat.pth
conf/generated/lazaro-ros-sep/c2f.yml
ADDED
@@ -0,0 +1,15 @@
+$include:
+- conf/lora/lora.yml
+AudioDataset.duration: 3.0
+AudioDataset.loudness_cutoff: -40.0
+VampNet.embedding_dim: 1280
+VampNet.n_codebooks: 14
+VampNet.n_conditioning_codebooks: 4
+VampNet.n_heads: 20
+VampNet.n_layers: 16
+fine_tune: true
+fine_tune_checkpoint: ./models/vampnet/c2f.pth
+save_path: ./runs/lazaro-ros-sep/c2f
+train/AudioLoader.sources: &id001
+- ./scratch/miguel/lazaro-ros/separated
+val/AudioLoader.sources: *id001
conf/generated/lazaro-ros-sep/coarse.yml
ADDED
@@ -0,0 +1,8 @@
+$include:
+- conf/lora/lora.yml
+fine_tune: true
+fine_tune_checkpoint: ./models/vampnet/coarse.pth
+save_path: ./runs/lazaro-ros-sep/coarse
+train/AudioLoader.sources: &id001
+- ./scratch/miguel/lazaro-ros/separated
+val/AudioLoader.sources: *id001
conf/generated/lazaro-ros-sep/interface.yml
ADDED
@@ -0,0 +1,6 @@
+AudioLoader.sources:
+- - ./scratch/miguel/lazaro-ros/separated
+Interface.coarse2fine_ckpt: ./runs/lazaro-ros-sep/c2f/latest/vampnet/weights.pth
+Interface.coarse_ckpt: ./runs/lazaro-ros-sep/coarse/latest/vampnet/weights.pth
+Interface.codec_ckpt: ./models/vampnet/codec.pth
+Interface.wavebeat_ckpt: ./models/wavebeat.pth
conf/generated/lazaro-ros/c2f.yml
ADDED
@@ -0,0 +1,15 @@
+$include:
+- conf/lora/lora.yml
+AudioDataset.duration: 3.0
+AudioDataset.loudness_cutoff: -40.0
+VampNet.embedding_dim: 1280
+VampNet.n_codebooks: 14
+VampNet.n_conditioning_codebooks: 4
+VampNet.n_heads: 20
+VampNet.n_layers: 16
+fine_tune: true
+fine_tune_checkpoint: ./models/vampnet/c2f.pth
+save_path: ./runs/lazaro-ros/c2f
+train/AudioLoader.sources: &id001
+- ./scratch/miguel/lazaro-ros
+val/AudioLoader.sources: *id001
conf/generated/lazaro-ros/coarse.yml
ADDED
@@ -0,0 +1,8 @@
+$include:
+- conf/lora/lora.yml
+fine_tune: true
+fine_tune_checkpoint: ./models/vampnet/coarse.pth
+save_path: ./runs/lazaro-ros/coarse
+train/AudioLoader.sources: &id001
+- ./scratch/miguel/lazaro-ros
+val/AudioLoader.sources: *id001
conf/generated/lazaro-ros/interface.yml
ADDED
@@ -0,0 +1,6 @@
+AudioLoader.sources:
+- - ./scratch/miguel/lazaro-ros
+Interface.coarse2fine_ckpt: ./runs/lazaro-ros/c2f/latest/vampnet/weights.pth
+Interface.coarse_ckpt: ./runs/lazaro-ros/coarse/latest/vampnet/weights.pth
+Interface.codec_ckpt: ./models/vampnet/codec.pth
+Interface.wavebeat_ckpt: ./models/wavebeat.pth
conf/generated/march-31/c2f.yml
ADDED
@@ -0,0 +1,15 @@
+$include:
+- conf/lora/lora.yml
+AudioDataset.duration: 3.0
+AudioDataset.loudness_cutoff: -40.0
+VampNet.embedding_dim: 1280
+VampNet.n_codebooks: 14
+VampNet.n_conditioning_codebooks: 4
+VampNet.n_heads: 20
+VampNet.n_layers: 16
+fine_tune: true
+fine_tune_checkpoint: ./models/vampnet/c2f.pth
+save_path: ./runs/march-31/c2f
+train/AudioLoader.sources: &id001
+- sound-journal-march-31
+val/AudioLoader.sources: *id001
conf/generated/march-31/coarse.yml
ADDED
@@ -0,0 +1,8 @@
+$include:
+- conf/lora/lora.yml
+fine_tune: true
+fine_tune_checkpoint: ./models/vampnet/coarse.pth
+save_path: ./runs/march-31/coarse
+train/AudioLoader.sources: &id001
+- sound-journal-march-31
+val/AudioLoader.sources: *id001
conf/generated/march-31/interface.yml
ADDED
@@ -0,0 +1,6 @@
+AudioLoader.sources:
+- - sound-journal-march-31
+Interface.coarse2fine_ckpt: ./runs/march-31/c2f/latest/vampnet/weights.pth
+Interface.coarse_ckpt: ./runs/march-31/coarse/latest/vampnet/weights.pth
+Interface.codec_ckpt: ./models/vampnet/codec.pth
+Interface.wavebeat_ckpt: ./models/wavebeat.pth
conf/generated/sax-new/c2f.yml
ADDED
@@ -0,0 +1,15 @@
+$include:
+- conf/lora/lora.yml
+AudioDataset.duration: 3.0
+AudioDataset.loudness_cutoff: -40.0
+VampNet.embedding_dim: 1280
+VampNet.n_codebooks: 14
+VampNet.n_conditioning_codebooks: 4
+VampNet.n_heads: 20
+VampNet.n_layers: 16
+fine_tune: true
+fine_tune_checkpoint: ./models/vampnet/c2f.pth
+save_path: ./runs/sax-new/c2f
+train/AudioLoader.sources: &id001
+- ./scratch/miguel/saxophone-new/
+val/AudioLoader.sources: *id001
conf/generated/sax-new/coarse.yml
ADDED
@@ -0,0 +1,8 @@
+$include:
+- conf/lora/lora.yml
+fine_tune: true
+fine_tune_checkpoint: ./models/vampnet/coarse.pth
+save_path: ./runs/sax-new/coarse
+train/AudioLoader.sources: &id001
+- ./scratch/miguel/saxophone-new/
+val/AudioLoader.sources: *id001
conf/generated/sax-new/interface.yml
ADDED
@@ -0,0 +1,6 @@
+AudioLoader.sources:
+- - ./scratch/miguel/saxophone-new/
+Interface.coarse2fine_ckpt: ./runs/sax-new/c2f/latest/vampnet/weights.pth
+Interface.coarse_ckpt: ./runs/sax-new/coarse/latest/vampnet/weights.pth
+Interface.codec_ckpt: ./models/vampnet/codec.pth
+Interface.wavebeat_ckpt: ./models/wavebeat.pth
conf/generated/saxophone/c2f.yml
ADDED
@@ -0,0 +1,15 @@
+$include:
+- conf/lora/lora.yml
+AudioDataset.duration: 3.0
+AudioDataset.loudness_cutoff: -40.0
+VampNet.embedding_dim: 1280
+VampNet.n_codebooks: 14
+VampNet.n_conditioning_codebooks: 4
+VampNet.n_heads: 20
+VampNet.n_layers: 16
+fine_tune: true
+fine_tune_checkpoint: ./models/vampnet/c2f.pth
+save_path: ./runs/saxophone/c2f
+train/AudioLoader.sources: &id001
+- scratch/sounds
+val/AudioLoader.sources: *id001
conf/generated/saxophone/coarse.yml
ADDED
@@ -0,0 +1,8 @@
+$include:
+- conf/lora/lora.yml
+fine_tune: true
+fine_tune_checkpoint: ./models/vampnet/coarse.pth
+save_path: ./runs/saxophone/coarse
+train/AudioLoader.sources: &id001
+- scratch/sounds
+val/AudioLoader.sources: *id001
conf/generated/saxophone/interface.yml
ADDED
@@ -0,0 +1,6 @@
+AudioLoader.sources:
+- - scratch/sounds
+Interface.coarse2fine_ckpt: ./runs/saxophone/c2f/latest/vampnet/weights.pth
+Interface.coarse_ckpt: ./runs/saxophone/coarse/latest/vampnet/weights.pth
+Interface.codec_ckpt: ./models/vampnet/codec.pth
+Interface.wavebeat_ckpt: ./models/wavebeat.pth
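All of the generated configs above lean on `$include` to pull in a base config and then override a few keys. A sketch of how that merge behaves, assuming argbind-style resolution (included files load first, local keys win); the `load_args` helper here is hypothetical, not the repo's actual loader:

```python
import yaml

def load_args(path: str) -> dict:
    """Hypothetical resolver for $include-style config files."""
    with open(path) as f:
        args = yaml.safe_load(f)
    merged = {}
    for inc in args.pop("$include", []):
        merged.update(load_args(inc))  # base configs load first
    merged.update(args)                # local keys override included ones
    return merged

# load_args("conf/generated/ivo/coarse.yml") would yield conf/lora/lora.yml's
# keys plus fine_tune, fine_tune_checkpoint, save_path, and the data sources.
```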
conf/lora/lora-s2s.yml
ADDED
@@ -0,0 +1,27 @@
+$include:
+  - conf/vampnet.yml
+
+fine_tune: True
+
+train/AudioDataset.n_examples: 100000000
+val/AudioDataset.n_examples: 500
+
+
+NoamScheduler.warmup: 500
+
+batch_size: 7
+num_workers: 7
+save_iters: [2000, 4000, 10000,20000, 40000, 100000]
+sample_freq: 2000
+val_freq: 1000
+
+AdamW.lr: 0.0001
+
+# let's us organize sound classes into folders and choose from those sound classes uniformly
+AudioDataset.without_replacement: False
+num_iters: 500000
+
+
+# control signals to use as conditioning.
+Sketch2SoundController.ctrl_keys: ['rmsq16',]
+
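The new `Sketch2SoundController.ctrl_keys` entry selects which control-signal extractors are active; the keys index the `CONTROLLERS` registry added in `vampnet/control.py` below. A minimal sketch of that lookup (the hop length and sample rate values here are illustrative):

```python
from vampnet.control import CONTROLLERS

# 'rmsq16' maps to an RMS extractor quantized to 16 levels
ctrl = CONTROLLERS["rmsq16"](hop_length=512, sample_rate=44100)
# ctrl.extract(sig) returns one RMS frame per codec hop, used as a
# sketch2sound conditioning signal during fine-tuning.
```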
conf/lora/lora.yml
CHANGED
@@ -19,4 +19,4 @@ AdamW.lr: 0.0001
 
 # let's us organize sound classes into folders and choose from those sound classes uniformly
 AudioDataset.without_replacement: False
-num_iters: 500000
+num_iters: 500000
(whitespace-only change in this hunk, likely a newline at end of file)
scripts/exp/export.py
CHANGED
@@ -1,11 +1,10 @@
 from pathlib import Path
 
-run_dir = Path("runs/…")
+run_dir = Path("runs/lazaro-ros-sep")
 name = run_dir.name
 
 repo_dir = Path("models/vampnet")
 
-
 for part in ("coarse", "c2f"):
     outdir = repo_dir / "loras" / name
     outdir.mkdir(parents=True, exist_ok=True)
@@ -16,7 +15,7 @@ for part in ("coarse", "c2f"):
 
 # now, push to hub
 from huggingface_hub import Repository
-repo = Repository(repo_dir, …)
+repo = Repository(str(repo_dir), git_user="hugofloresgarcia", git_email="[email protected]")
 repo.push_to_hub(
     commit_message=f"add {name}"
 )
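One caveat on the export script: `huggingface_hub.Repository` is deprecated in recent releases of the library. A sketch of the same push using the HTTP-based upload API instead (the repo id is hypothetical; assumes write access and a cached token):

```python
from huggingface_hub import HfApi

name = "lazaro-ros-sep"  # matches run_dir.name in the script above
api = HfApi()
api.upload_folder(
    folder_path="models/vampnet",
    repo_id="hugofloresgarcia/vampnet",  # hypothetical repo id
    commit_message=f"add {name}",
)
```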
scripts/exp/train.py
CHANGED
@@ -18,6 +18,7 @@ from torch.utils.tensorboard import SummaryWriter
 
 import vampnet
 from vampnet.modules.transformer import VampNet
+# from vampnet.control import Sketch2SoundController
 from vampnet.util import codebook_unflatten, codebook_flatten
 from vampnet import mask as pmask
 # from dac.model.dac import DAC
@@ -66,6 +67,8 @@ AudioDataset = argbind.bind(at.datasets.AudioDataset, "train", "val")
 
 IGNORE_INDEX = -100
 
+# Sketch2SoundController = argbind.bind(Sketch2SoundController)
+
 
 @argbind.bind("train", "val", without_prefix=True)
 def build_transform():
@@ -118,6 +121,36 @@ def add_num_params_repr_hook(model):
 
         setattr(m, "extra_repr", partial(num_params_hook, o=o, p=p))
 
+
+def get_controls(state, sig: at.AudioSignal):
+    # get controls
+    n_batch = sig.samples.shape[0]
+    if state.controller is not None:
+        ctrls = state.controller.extract(sig)
+        # draw control masks
+        ctrl_masks = state.controller.random_mask(
+            ctrls,
+            r=state.rng.draw(n_batch)[:, 0].to(state.device)
+        )
+    else:
+        ctrls = None
+        ctrl_masks = None
+
+    return ctrls, ctrl_masks
+
+
+def generate_z_mask(state, z, vn, n_batch, ctrl_masks=None):
+    r = state.rng.draw(n_batch)[:, 0].to(state.device)
+
+    mask, ii = state.model.random_mask(z, r)
+    mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
+
+    # outpaint?
+    # if state.outpaint_prob > 0:
+    #     if flip_coin(state.outpaint_prob):
+    #         mask, ctrl_masks = state.build_tria_mask(mask, ctrl_masks)
+    z_mask = pmask.apply_mask(z, mask, vn.mask_token)
+
+    return z_mask, mask, ii, r, ctrl_masks
+
 
 def accuracy(
     preds: torch.Tensor,
@@ -184,6 +217,8 @@ def _metrics(z_hat, r, target, flat_mask, output):
 class State:
     model: VampNet
     codec: DAC
+    # controller: Sketch2SoundController
+    controller: Optional[object]
 
     optimizer: AdamW
     scheduler: NoamScheduler
@@ -218,6 +253,11 @@ def train_loop(state: State, batch: dict, accel: Accelerator):
         mask = pmask.random(z, r)
         mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
         z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
+
+        # get controls
+        ctrls, ctrl_masks = get_controls(state, signal)
+
+        # TODO: KEEP INCORPORATING ZMASK CODE
 
         z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
 
@@ -266,6 +306,22 @@ def train_loop(state: State, batch: dict, accel: Accelerator):
 
     return {k: v for k, v in sorted(output.items())}
 
+    # def get_controls(self, sig: sn.Signal, controller):
+    #     # get controls
+    #     n_batch = sig.wav.shape[0]
+    #     if self.controller is not None:
+    #         ctrls = self.controller.extract(sig)
+    #         # draw control masks
+    #         ctrl_masks = self.controller.random_mask(
+    #             ctrls,
+    #             r=self.rng.draw(n_batch)[:, 0].to(self.device)
+    #         )
+    #     else:
+    #         ctrls = None
+    #         ctrl_masks = None
+
+    #     return ctrls, ctrl_masks
+
 
 @timer()
 @torch.no_grad()
@@ -561,6 +617,8 @@ def load(
     # load the datasets
     train_data, val_data = build_datasets(args, sample_rate)
 
+    # controller = Sketch2SoundController(sample_rate=sample_rate, hop_length=codec.hop_length)
+
     return State(
         tracker=tracker,
         model=model,
@@ -572,6 +630,7 @@ def load(
         train_data=train_data,
         val_data=val_data,
         grad_clip_val=grad_clip_val,
+        controller=None,
     )
 
 
@@ -612,6 +671,7 @@ def train(
         tracker=tracker,
         save_path=save_path)
     print("initialized state.")
+    state.device = accel.device
 
     train_dataloader = accel.prepare_dataloader(
         state.train_data,
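The helpers added to `scripts/exp/train.py` are staged for the sketch2sound path: `get_controls` extracts control features and draws per-control masks, while `generate_z_mask` draws the token mask. A sketch of how they would compose inside `train_loop` once the controller is enabled (names follow the diff above; this wiring is not yet active in the commit):

```python
def train_step_sketch(state, signal, z, vn):
    """Hypothetical composition of get_controls and generate_z_mask."""
    ctrls, ctrl_masks = get_controls(state, signal)
    z_mask, mask, ii, r, ctrl_masks = generate_z_mask(
        state, z, vn, n_batch=z.shape[0], ctrl_masks=ctrl_masks
    )
    z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
    # controls enter through VampNet.forward's new ctrls/ctrl_masks kwargs
    return state.model(z_mask_latent, ctrls=ctrls, ctrl_masks=ctrl_masks)
```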
token_telephone/tt.py
CHANGED
@@ -16,10 +16,25 @@ import numpy as np
 import torch
 from einops import rearrange
 
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# ~~~~~~ configs! ~~~~~~~~
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+MAX_LOUDNESS = -20
+MIN_LOUDNESS = -40
+COLS = 40
+ROWS = 13
+
+device = 'Scarlett 4i4 4th Gen'
+sample_rate = 48000
+num_channels = 4
+blocksize = 16384
+
 PROFILE = False
 DEBUG = False
 DEBUG_NO_VAMPNET = False
 set_debug(DEBUG)
+
 # if DEBUG:
 #     import gc
 #     # log when gc start and stops
@@ -80,19 +95,6 @@ Thread(target=draw_intro_screen).start()
 from audiotools import AudioSignal
 from vamp_helper import load_interface, ez_variation
 
-# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-# ~~~~~~ configs! ~~~~~~~~
-# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-MAX_LOUDNESS = -20
-MIN_LOUDNESS = -40
-COLS = 40
-ROWS = 13
-
-device = 'Scarlett 4i4 4th Gen'
-sample_rate = 48000
-num_channels = 4
-blocksize = 16384
 
 
 # TODO:
vampnet/beats.py
CHANGED
@@ -213,10 +213,11 @@ class WaveBeat(BeatTracker):
     def extract_beats(self, signal: AudioSignal) -> Tuple[np.ndarray, np.ndarray]:
         """returns beat and downbeat times, in seconds"""
         # extract beats
+        self.model.to('cuda' if torch.cuda.is_available() else 'cpu')
         beats, downbeats = self.model.predict_beats_from_array(
             audio=signal.audio_data.squeeze(0),
             sr=signal.sample_rate,
-            use_gpu=…,
+            use_gpu=torch.cuda.is_available(),
         )
 
         return beats, downbeats
vampnet/control.py
ADDED
@@ -0,0 +1,277 @@
+from dataclasses import dataclass
+from functools import partial
+from typing import Optional
+
+from torch import nn
+
+import vampnet.dsp.signal as sn
+from vampnet.dsp.signal import Signal
+from vampnet.mask import random_along_time
+from torch import Tensor
+import torch
+
+
+class MedianFilterAugment(nn.Module):
+
+    def __init__(self,
+        kernel_size: int,
+        train_min: int = 1,
+        train_max: int = 20,
+    ):
+        super().__init__()
+        self.kernel_size = kernel_size
+        self.train_min = train_min
+        self.train_max = train_max
+
+    def forward(self, x: Tensor) -> Tensor:
+        if self.training:
+            sizes = torch.randint(
+                self.train_min,
+                self.train_max,
+                size=(x.shape[0],)
+            )
+        else:
+            sizes = self.kernel_size
+        # print(f"median filter sizes: {sizes}")
+        return sn.median_filter_1d(x, sizes)
+
+class RMS(nn.Module):
+
+    def __init__(self,
+        hop_length,
+        window_length=2048,
+        n_quantize=None,
+        sample_rate=44100,
+        median_filter_size: Optional[int] = None,
+        train_median_filter_min=1,
+        train_median_filter_max=15,
+    ):
+        super().__init__()
+
+        self.hop_length = hop_length
+        self.window_length = window_length
+        self.n_quantize = n_quantize
+        self.sample_rate = sample_rate
+
+        self.mf = MedianFilterAugment(
+            kernel_size=median_filter_size,
+            train_min=train_median_filter_min,
+            train_max=train_median_filter_max
+        ) if median_filter_size is not None else None
+
+    @property
+    def dim(self):
+        return 1
+
+    def extract(self, sig: Signal) -> Tensor:
+        rmsd = sn.rms(sig,
+            window_length=self.window_length,
+            hop_length=self.hop_length,
+        )[:, :, :-1] # TODO: cutting the last frame to match DAC tokens but why :'(
+        nb, _, _ = rmsd.shape
+
+        if self.n_quantize is not None:
+            # standardize to 0-1
+            rmsd = (rmsd - rmsd.min()) / (rmsd.max() - rmsd.min())
+
+            # quantize to 128 steps
+            rmsd = torch.round(rmsd * self.n_quantize)
+            rmsd = rmsd / self.n_quantize
+
+        if self.mf is not None:
+            rmsd = self.mf(rmsd)
+
+        return rmsd
+
+
+
+class HarmonicChroma(nn.Module):
+
+    def __init__(self,
+        hop_length: int, window_length: int = 4096,
+        n_chroma: int = 48, sample_rate: int = 44100,
+        top_n: int = 0
+    ):
+        super().__init__()
+        from torchaudio.prototype.transforms import ChromaScale
+        self.hop_length = hop_length
+        self.window_length = window_length
+        self.n_chroma = n_chroma
+        self.sample_rate = sample_rate
+        self.top_n = top_n
+
+        # HUGO: this representation, as is,
+        # encodes timbre information in the chroma
+        # which is not what we want!!!
+        # would a median filter help perhaps?
+        self.chroma = ChromaScale(
+            sample_rate=self.sample_rate,
+            n_freqs=self.window_length // 2 + 1,
+            n_chroma=self.n_chroma,
+            octwidth=5.0,
+        )
+
+    @property
+    def dim(self):
+        return self.n_chroma
+
+    def extract(self, sig: Signal) -> Tensor:
+        from vampnet.dsp.hpss import hpss
+        self.chroma.to(sig.wav.device)
+
+        # spectrogram
+        spec = sn.stft(sig,
+            window_length=self.window_length,
+            hop_length=self.hop_length
+        )
+        # magnitude
+        spec = torch.abs(spec)
+
+        # hpss
+        spec = hpss(spec, kernel_size=51, hard=True)[0]
+
+        # chroma
+        chroma = self.chroma(spec)
+
+        # get the rms of this spec
+        rms_d = sn.rms_from_spec(
+            spec, window_length=self.window_length
+        )
+
+        # convert the rms to db
+        rms_d = 10 * torch.log10(rms_d + 1e-7)
+
+        # make a mask based on the rms < -40
+        mask = torch.where(rms_d < -40, torch.zeros_like(rms_d), torch.ones_like(rms_d))
+
+        # remove anything below 80 (where the fuck did I get this number from?)
+        chroma = torch.where(chroma < 100, torch.zeros_like(chroma), chroma)
+
+        # Get top 2 values and indices along the -2 dimension
+        if self.top_n:
+            _, topk_indices = torch.topk(chroma, self.top_n, dim=-2)
+
+            # Create a mask for the top 2 values
+            topk_mask = torch.zeros_like(chroma).scatter_(-2, topk_indices, 1.0)
+
+            # Retain only the top 2 values
+            chroma = chroma * topk_mask
+
+        # apply the mask
+        chroma = chroma * mask.unsqueeze(-2)
+
+        # Apply softmax along dim=-2
+        if self.top_n > 0:
+            chroma = torch.nn.functional.softmax(chroma, dim=-2)
+
+        # mask out any timesteps whose chroma have all equal values (all 0s before softmax)
+        # TODO: i did this with chatgpt, there's gott a be a better way
+        chroma_mean = chroma.mean(dim=-2, keepdim=True)
+        chroma_diff = torch.abs(chroma - chroma_mean)
+        equal_mask = torch.all(chroma_diff < 1e-6, dim=-2, keepdim=True)
+
+        # Set chroma values to zero for timesteps with all equal values
+        chroma = torch.where(equal_mask, torch.zeros_like(chroma), chroma)
+
+
+        return chroma[:, 0, :, :-1] # mono only :( FIX ME!
+
+
+# TODO: try harmonic mel?
+
+CONTROLLERS = {
+    "rms": RMS,
+    "rmsq128": partial(RMS, n_quantize=128),
+    "rmsq16": partial(RMS, n_quantize=16),
+    "rms-median": partial(RMS, median_filter_size=5),
+    "rmsq16-median": partial(RMS, n_quantize=16, median_filter_size=3),
+    "hchroma": HarmonicChroma,
+    "hchroma-12c-top2": partial(HarmonicChroma, n_chroma=12, top_n=2), # TODO: refactor me. If this works, this should just be named hchroma.
+    "hchroma-36c-top3": partial(HarmonicChroma, n_chroma=36, top_n=3) # TODO: refactor me. If this works, this should just be named hchroma.
+}
+
+class Sketch2SoundController(nn.Module):
+
+    def __init__(
+        self,
+        ctrl_keys: list[str],
+        hop_length: str,
+        sample_rate: int,
+    ):
+        super().__init__()
+
+        assert all([k in CONTROLLERS for k in ctrl_keys]), f"got an unsupported control key in {ctrl_keys}!\n supported: {CONTROLLERS.keys()}"
+
+        self.hop_length = hop_length
+        self.ctrl_keys = ctrl_keys
+        self.sample_rate = sample_rate
+
+        self.controllers = {
+            k: CONTROLLERS[k](hop_length=hop_length, sample_rate=sample_rate)
+            for k in self.ctrl_keys
+        }
+
+    @property
+    def ctrl_dims(self, ) -> dict[str, int]:
+        return {
+            k: controller.dim for k, controller in self.controllers.items()
+        }
+
+    def extract(self, sig: Signal) -> dict[str, Tensor]:
+        ctrls = {
+            k: controller.extract(sig) for k, controller in self.controllers.items()
+        }
+        return ctrls
+
+    def random_mask(self, ctrls: dict[str, Tensor], r: float):
+        masks = {}
+        for k, ctrl in ctrls.items():
+            masks[k] = 1-random_along_time(ctrl, r)
+        return masks
+
+    def empty_mask(self, ctrls: dict[str, Tensor]):
+        first_key = next(iter(ctrls))
+        mask = torch.zeros_like(ctrls[first_key])
+        return {k: mask for k in ctrls}
+
+
+def test_controller():
+    controller = Sketch2SoundController(
+        ctrl_keys=["rms-median", "rms", "rmsq128"],
+        hop_length=512,
+        sample_rate=44100
+    )
+    controller.train()
+    # sig = sn.read_from_file("assets/example.wav")
+    # sig = sn.read_from_file("/Users/hugo/Downloads/DCS_SE_FullChoir_ScaleUpDown06_A2_DYN.wav")
+    # sig = sn.excerpt('/Users/hugo/Downloads/(guitarra - hugo mix) bubararu - tambor negro.wav', offset=0, duration=10)
+    sig = sn.read_from_file("assets/voice-prompt.wav")
+    ctrls = controller.extract(sig)
+    print(f"given sig of shape {sig.wav.shape}, extracted controls: {ctrls}")
+
+    # print the whole thing
+    # torch.set_printoptions(profile="full")
+    # print(ctrls["hchroma"][0][0][:, 200:210])
+
+    # imshow the chroma
+    import matplotlib.pyplot as plt
+
+    # Define relative heights for the subplots
+    fig, (ax1, ax2, ax3, ax4) = plt.subplots(
+        4, 1,
+        sharex=True,
+    )
+
+    # Display the spectrogram on the top
+    ax1.imshow(sn.stft(sig, hop_length=512, window_length=2048).abs()[0][0].cpu().log().numpy(), aspect='auto', origin='lower')
+    # display rms on the bottom
+    ax2.plot(ctrls["rms-median"][0][0])
+    ax3.plot(ctrls["rms"][0][0])
+    ax4.plot(ctrls["rmsq128"][0][0])
+
+    plt.tight_layout() # Ensure proper spacing
+    plt.savefig("img.png")
+
+
+if __name__ == "__main__":
+    test_controller()
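A quick shape check for the `RMS` controller above, assuming `vampnet.dsp.signal`'s `Signal` can be built from a `(batch, channels, samples)` tensor plus a sample rate (the constructor details are an assumption here):

```python
import torch
import vampnet.dsp.signal as sn
from vampnet.control import RMS

rms = RMS(hop_length=512, sample_rate=44100, n_quantize=16)
sig = sn.Signal(torch.randn(1, 1, 44100), 44100)  # one second of noise
print(rms.extract(sig).shape)  # expect (1, 1, n_frames): one value per hop
```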
vampnet/interface.py
CHANGED
@@ -59,7 +59,7 @@ class Interface(torch.nn.Module):
     coarse2fine_ckpt: str = None,
     coarse2fine_lora_ckpt: str = None,
     codec_ckpt: str = None,
-    wavebeat_ckpt: str = …,
+    wavebeat_ckpt: str = "./models/vampnet/wavebeat.pth",
     device: str = "cpu",
     coarse_chunk_size_s: int = 10,
     coarse2fine_chunk_size_s: int = 3,
@@ -96,7 +96,7 @@ class Interface(torch.nn.Module):
 
     if wavebeat_ckpt is not None:
         logging.debug(f"loading wavebeat from {wavebeat_ckpt}")
-        self.beat_tracker = WaveBeat(wavebeat_ckpt)
+        self.beat_tracker = WaveBeat(wavebeat_ckpt, device=device)
         self.beat_tracker.model.to(device)
     else:
         self.beat_tracker = None
@@ -254,6 +254,7 @@ class Interface(torch.nn.Module):
     """
     assert self.beat_tracker is not None, "No beat tracker loaded"
 
+
     # get the beat times
     beats, downbeats = self.beat_tracker.extract_beats(signal)
 
@@ -516,12 +517,19 @@ class Interface(torch.nn.Module):
         # the forward pass
         logging.debug(z.shape)
         logging.debug("coarse!")
-        zv = …
-        …
-        …
-        …
-        …
+        zv = z
+        for i in range(feedback_steps):
+            zv, mask_z = self.coarse_vamp(
+                zv,
+                mask=mask,
+                return_mask=True,
+                **kwargs)
+            # roll the mask around a random amount
+            mask_z = mask_z.roll(
+                shifts=(i + 1) % feedback_steps,
+                dims=-1
+            )
 
         # add the top codebooks back in
         if zv.shape[1] < z.shape[1]:
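The new feedback loop in `vampnet/interface.py` re-vamps its own output `feedback_steps` times, rotating the mask along the time axis between passes so successive steps regenerate different regions. What the roll does, in isolation:

```python
import torch

mask = torch.tensor([[1, 1, 0, 0, 1, 1]])
print(mask.roll(shifts=1, dims=-1))  # tensor([[1, 1, 1, 0, 0, 1]])
```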
vampnet/mask.py
CHANGED
@@ -163,14 +163,18 @@ def dropout(
     mask: torch.Tensor,
     p: float,
 ):
-    …
-    …
-    …
-    …
-    …
-    mask = …
+    # instead of the above, mask along the last dimensions
+    tsteps = mask.shape[-1]
+    tsteps_to_drop = int(tsteps * p)
+    tsteps_to_keep = tsteps - tsteps_to_drop
+    idxs_to_drop = torch.randint(0, tsteps, (tsteps_to_drop,))
+    mask = mask.clone()
+    mask[:, :, idxs_to_drop] = 1
     return mask.long()
    (the deleted body above is truncated in this view; its old contents are not recoverable)
 
+
+
+
 def mask_or(
     mask1: torch.Tensor,
     mask2: torch.Tensor
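Note that the rewritten `dropout` draws indices with `torch.randint`, i.e. with replacement, so duplicate draws can leave slightly fewer than `int(p * tsteps)` steps dropped (and `tsteps_to_keep` is computed but unused). A sketch of an exact variant using a permutation, should that ever matter:

```python
import torch

def dropout_exact(mask: torch.Tensor, p: float) -> torch.Tensor:
    tsteps = mask.shape[-1]
    idxs = torch.randperm(tsteps)[: int(tsteps * p)]  # unique indices
    mask = mask.clone()
    mask[:, :, idxs] = 1
    return mask.long()
```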
vampnet/modules/transformer.py
CHANGED
@@ -6,6 +6,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch import Tensor
 from einops import rearrange
 import loralib as lora
 import audiotools as at
@@ -405,7 +406,7 @@ class TransformerStack(nn.Module):
         )
 
         # Perform last normalization
-        self.norm = RMSNorm(d_model) if last_layer else None
+        self.norm = RMSNorm(d_model) if last_layer else None
(whitespace-only change in this hunk)
@@ -461,6 +462,75 @@ class TransformerStack(nn.Module):
         else:
             return out
 
+class CFGDropout(nn.Module):
+
+    def __init__(self, p: float = 0.2):
+        super().__init__()
+        self.p = p
+
+    def forward(self, x: Tensor):
+        # dropout along the batch dim
+        if self.training:
+            mask = torch.rand(x.shape[0], 1, 1, device=x.device) > self.p
+        else:
+            mask = torch.ones(x.shape[0], 1, 1, device=x.device)
+        return x * mask
+
+class ControlEncoder(nn.Module):
+
+    def __init__(self,
+        ctrl_dims: dict[str, int],
+        embedding_dim: int,
+        cfg_dropout_prob: float
+    ):
+        super().__init__()
+        self.ctrl_encoders = nn.ModuleDict({
+            key: nn.Linear(dim, embedding_dim)
+            for key, dim in ctrl_dims.items()
+        })
+
+        self.cfg_dropout = CFGDropout(p=cfg_dropout_prob)
+        self.all_dropout = CFGDropout(p=cfg_dropout_prob / 2)
+
+    def forward(self,
+        embedding: Tensor, # embedding to which we will add ctrls
+        ctrls: dict[str, Tensor],
+        ctrl_masks: dict[str, Tensor]
+    ):
+        # INPUT: ctrl tensor should be shape (b d n)
+
+        # assert that we got all the right ctrls and ctrl_masks according to the encoders that we have
+        assert list(sorted(ctrls.keys())) == list(sorted(self.ctrl_encoders.keys())), "ctrls and ctrl_encoders keys do not match"
+        assert list(sorted(ctrl_masks.keys())) == list(sorted(self.ctrl_encoders.keys())), "ctrl_masks and ctrl_encoders keys do not match"
+
+        out_emb = torch.zeros_like(embedding)
+        for ck in ctrls:
+            ctrld = ctrls[ck]
+            ctrlmask = ctrl_masks[ck]
+
+            assert ctrld.shape[-1] == embedding.shape[-1], "ctrls should match x along time dimension"
+            assert ctrlmask.ndim == 2, "ctrlmask should be 2d"
+            assert ctrlmask.shape[-1] == ctrld.shape[-1], "ctrlmask should match ctrld along time dimension"
+
+            # project ctrl with encoder
+            ctrld = rearrange(ctrld, "b d n -> b n d")
+            ctrl_emb = self.ctrl_encoders[ck](ctrld)
+            ctrld = rearrange(ctrld, "b n d -> b d n")
+            ctrl_emb = rearrange(ctrl_emb, "b n d -> b d n")
+
+            # apply ctrl mask
+            ctrl_emb = ctrl_emb * ctrlmask[:, None, :]
+
+            # apply cfg dropout
+            ctrl_emb = self.cfg_dropout(ctrl_emb)
+
+            # add to the out_emb
+            out_emb = out_emb + ctrl_emb
+
+        # randomly dropout all ctrls
+        out_emb = self.all_dropout(out_emb)
+
+        return out_emb
 
 class VampNet(at.ml.BaseModel):
     def __init__(
@@ -475,7 +545,10 @@ class VampNet(at.ml.BaseModel):
         vocab_size: int = 1024,
         flash_attn: bool = True,
         noise_mode: str = "mask",
-        dropout: float = 0.1
+        dropout: float = 0.1,
+        ctrl_dims: Optional[dict[str, int]] = None,
+        cfg_dropout_prob: float = 0.2,
+        cond_dim: int = 0,
     ):
         super().__init__()
         assert r_cond_dim == 0, f"r_cond_dim must be 0 (not supported), but got {r_cond_dim}"
@@ -489,6 +562,11 @@ class VampNet(at.ml.BaseModel):
         self.latent_dim = latent_dim
         self.flash_attn = flash_attn
         self.noise_mode = noise_mode
+        self.cond_dim = cond_dim
+        self.r_cond_dim = r_cond_dim
+        self.dropout = dropout
+        self.cfg_dropout_prob = cfg_dropout_prob
+        # self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
         assert self.noise_mode == "mask", "deprecated"
 
@@ -525,10 +603,25 @@ class VampNet(at.ml.BaseModel):
             ),
         )
 
-    def forward(self, x, return_activations: bool = False):
+        if self.cond_dim > 0:
+            self.cfg_dropout = CFGDropout(p=cfg_dropout_prob)
+
+        self.ctrl_dims = ctrl_dims
+        if self.ctrl_dims is not None:
+            self.ctrl_encoder = ControlEncoder(
+                ctrl_dims,
+                embedding_dim=embedding_dim,
+                cfg_dropout_prob=cfg_dropout_prob
+            )
+
+    def forward(self, x, ctrls=None, ctrl_masks=None, return_activations: bool = False):
         x = self.embedding(x)
         x_mask = torch.ones_like(x, dtype=torch.bool)[:, :1, :].squeeze(1)
 
+        if self.ctrl_dims is not None:
+            # apply controls
+            x = x + self.ctrl_encoder(x, ctrls, ctrl_masks)
+
         x = rearrange(x, "b d n -> b n d")
         out = self.transformer(x=x, x_mask=x_mask, return_activations=return_activations)
         if return_activations:
@@ -600,6 +693,8 @@ class VampNet(at.ml.BaseModel):
         temperature: float = 1.0,
         mask: Optional[torch.Tensor] = None,
         mask_temperature: float = 10.5,
+        ctrls: dict = None,
+        ctrl_masks: dict = None,
         typical_filtering=True,
         typical_mass=0.15,
         typical_min_tokens=64,
@@ -609,7 +704,9 @@ class VampNet(at.ml.BaseModel):
         return_signal=True,
         debug=False,
         causal_weight: float = 0.0,
+        cfg_scale: float = 3.0,
         cfg_guidance: float = None,
+        cond=None  # unused
     ):
         if seed is not None:
             at.util.seed(seed)
@@ -622,6 +719,22 @@ class VampNet(at.ml.BaseModel):
         z = start_tokens
         nb = z.shape[0]
 
+        use_cfg = ctrls is not None
+        tocfg = lambda x: x.repeat(2, 1, 1) if use_cfg else x
+        tocfgblank = lambda x: torch.cat([x, torch.zeros_like(x)], dim=0) if use_cfg else x
+        def fromcfg(x):
+            if use_cfg:
+                xcond, xuncond = x.chunk(2)
+                return xuncond + cfg_scale * (xcond - xuncond)
+            return x
+
+        z = tocfg(z)
+        if ctrls is not None:
+            ctrls = {k: tocfg(v) for k, v in ctrls.items()}
+            ctrl_masks = {k: tocfgblank(v) for k, v in ctrl_masks.items()}
+        if cond is not None:
+            cond = tocfg(cond)
+
         if z is None:
             z = torch.full((1, self.n_codebooks, time_steps), self.mask_token).to(
                 self.device
@@ -727,6 +840,7 @@ class VampNet(at.ml.BaseModel):
             # infer from latents
             # NOTE: this collapses the codebook dimension into the sequence dimension
             logits = self.forward(latents)  # b, prob, seq
+            logits = fromcfg(logits)
 
             if cfg_guidance is not None:
                 logits_cond, logits_uncond = logits[:nb], logits[nb:]
@@ -774,9 +888,6 @@ class VampNet(at.ml.BaseModel):
                 plt.imshow(_mask[0].cpu().numpy())
                 plt.savefig(f"{STEP_FOLDER}/mask.png")
 
-
-
-
             # update the mask, remove conditioning codebooks from the mask
             # add z back into sampled z where the mask was false
             sampled_z = torch.where(
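The classifier-free guidance plumbing added to `generate()` doubles the batch (a conditioned copy plus a copy with blanked control masks) and recombines the two halves of the logits with `cfg_scale`. The recombination step on its own, mirroring the `fromcfg` closure above:

```python
import torch

def apply_cfg(logits: torch.Tensor, cfg_scale: float = 3.0) -> torch.Tensor:
    cond, uncond = logits.chunk(2)  # first half conditioned, second half not
    return uncond + cfg_scale * (cond - uncond)
```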