hugo flores garcia committed on
Commit cd84ee3 · 1 Parent(s): 39bff10

several mods

TODOS ADDED
@@ -0,0 +1 @@
+ [ ] add sketch2sound finetuning
app.py CHANGED
@@ -21,6 +21,7 @@ interface = Interface.default()
 init_model_choice = open("DEFAULT_MODEL").read().strip()
 # load the init model
 interface.load_finetuned(init_model_choice)
+interface.to(device)
 
 def to_output(sig):
     return sig.sample_rate, sig.cpu().detach().numpy()[0][0]
@@ -105,9 +106,33 @@ def _vamp(
     n_mask_codebooks, periodic_w, onset_mask_width,
     dropout, sampletemp, typical_filtering,
     typical_mass, typical_min_tokens, top_p,
-    sample_cutoff, stretch_factor, api=False
+    sample_cutoff, stretch_factor, sampling_steps, beat_mask_ms, num_feedback_steps, api=False
 ):
 
+    print("args!")
+    print(f"seed: {seed}")
+    print(f"input_audio: {input_audio}")
+    print(f"model_choice: {model_choice}")
+    print(f"pitch_shift_amt: {pitch_shift_amt}")
+    print(f"periodic_p: {periodic_p}")
+    print(f"n_mask_codebooks: {n_mask_codebooks}")
+    print(f"periodic_w: {periodic_w}")
+    print(f"onset_mask_width: {onset_mask_width}")
+    print(f"dropout: {dropout}")
+    print(f"sampletemp: {sampletemp}")
+    print(f"typical_filtering: {typical_filtering}")
+    print(f"typical_mass: {typical_mass}")
+    print(f"typical_min_tokens: {typical_min_tokens}")
+    print(f"top_p: {top_p}")
+    print(f"sample_cutoff: {sample_cutoff}")
+    print(f"stretch_factor: {stretch_factor}")
+    print(f"sampling_steps: {sampling_steps}")
+    print(f"api: {api}")
+    print(f"beat_mask_ms: {beat_mask_ms}")
+    print(f"using device {interface.device}")
+    print(f"num feedback steps: {num_feedback_steps}")
+
+
     t0 = time.time()
     interface.to("cuda" if torch.cuda.is_available() else "cpu")
     print(f"using device {interface.device}")
@@ -121,6 +146,9 @@ def _vamp(
 
     sig = at.AudioSignal(input_audio, sr).to_mono()
 
+    loudness = sig.loudness()
+    sig = interface._preprocess(sig)
+
     # reload the model if necessary
     interface.load_finetuned(model_choice)
 
@@ -129,38 +157,70 @@ def _vamp(
 
     codes = interface.encode(sig)
 
-    mask = new_vampnet_mask(
-        interface,
-        codes,
-        onset_idxs=onsets(sig, hop_length=interface.codec.hop_length),
-        width=onset_mask_width,
+    # mask = new_vampnet_mask(
+    #     interface,
+    #     codes,
+    #     onset_idxs=onsets(sig, hop_length=interface.codec.hop_length),
+    #     width=onset_mask_width,
+    #     periodic_prompt=periodic_p,
+    #     upper_codebook_mask=n_mask_codebooks,
+    #     drop_amt=dropout
+    # ).long()
+
+
+    mask = interface.build_mask(
+        codes,
+        sig=sig,
         periodic_prompt=periodic_p,
+        periodic_prompt_width=periodic_w,
+        onset_mask_width=onset_mask_width,
+        _dropout=dropout,
         upper_codebook_mask=n_mask_codebooks,
-        drop_amt=dropout
-    ).long()
+    )
+    if beat_mask_ms > 0:
+        # bm = pmask.mask_or(
+        #     pmask.periodic_mask(
+        #         codes, periodic_p, periodic_w, random_roll=False
+        #     ),
+        # )
+        mask = pmask.mask_and(
+            mask, interface.make_beat_mask(
+                sig, after_beat_s=beat_mask_ms/1000.,
+            )
+        )
+        mask = pmask.codebook_mask(mask, n_mask_codebooks)
+    np.savetxt("scratch/rms_mask.txt", mask[0].cpu().numpy(), fmt='%d')
 
-    # save the mask as a txt file
     interface.set_chunk_size(10.0)
+
+    # lord help me
+    if top_p is not None:
+        if top_p > 0:
+            pass
+        else:
+            top_p = None
+
     codes, mask = interface.vamp(
         codes, mask,
-        batch_size=1 if api else 1,
-        feedback_steps=1,
-        _sampling_steps=12 if sig.duration <6.0 else 24,
+        batch_size=2,
+        feedback_steps=num_feedback_steps,
+        _sampling_steps=sampling_steps,
         time_stretch_factor=stretch_factor,
         return_mask=True,
         temperature=sampletemp,
         typical_filtering=typical_filtering,
         typical_mass=typical_mass,
         typical_min_tokens=typical_min_tokens,
-        top_p=None,
+        top_p=top_p,
         seed=_seed,
-        sample_cutoff=1.0,
+        sample_cutoff=sample_cutoff,
     )
     print(f"vamp took {time.time() - t0} seconds")
 
     sig = interface.decode(codes)
+    sig = sig.normalize(loudness)
 
-    return to_output(sig)
+    return to_output(sig[0]), to_output(sig[1])
 
 def vamp(data):
     return _vamp(
@@ -180,31 +240,29 @@ def vamp(data):
         top_p=data[top_p],
         sample_cutoff=data[sample_cutoff],
         stretch_factor=data[stretch_factor],
+        sampling_steps=data[sampling_steps],
+        beat_mask_ms=data[beat_mask_ms],
+        num_feedback_steps=data[num_feedback_steps],
         api=False,
     )
 
-# def api_vamp(data):
-#     return _vamp(
-#         seed=data[seed],
-#         input_audio=data[input_audio],
-#         model_choice=data[model_choice],
-#         pitch_shift_amt=data[pitch_shift_amt],
-#         periodic_p=data[periodic_p],
-#         n_mask_codebooks=data[n_mask_codebooks],
-#         periodic_w=data[periodic_w],
-#         onset_mask_width=data[onset_mask_width],
-#         dropout=data[dropout],
-#         sampletemp=data[sampletemp],
-#         typical_filtering=data[typical_filtering],
-#         typical_mass=data[typical_mass],
-#         typical_min_tokens=data[typical_min_tokens],
-#         top_p=data[top_p],
-#         sample_cutoff=data[sample_cutoff],
-#         stretch_factor=data[stretch_factor],
-#         api=True,
-#     )
-
-def api_vamp(input_audio, sampletemp, top_p, periodic_p, periodic_w, dropout, stretch_factor, onset_mask_width, typical_filtering, typical_mass, typical_min_tokens, seed, model_choice, n_mask_codebooks, pitch_shift_amt, sample_cutoff):
+
+def api_vamp(input_audio,
+             sampletemp, top_p,
+             periodic_p, periodic_w,
+             dropout,
+             stretch_factor,
+             onset_mask_width,
+             typical_filtering,
+             typical_mass,
+             typical_min_tokens,
+             seed,
+             model_choice,
+             n_mask_codebooks,
+             pitch_shift_amt,
+             sample_cutoff,
+             sampling_steps,
+             beat_mask_ms, num_feedback_steps):
     return _vamp(
         seed=seed,
         input_audio=input_audio,
@@ -222,50 +280,12 @@ def api_vamp(input_audio, sampletemp, top_p, periodic_p, periodic_w, dropout, st
         top_p=top_p,
         sample_cutoff=sample_cutoff,
         stretch_factor=stretch_factor,
+        sampling_steps=sampling_steps,
+        beat_mask_ms=beat_mask_ms,
+        num_feedback_steps=num_feedback_steps,
         api=True,
     )
 
-OUT_DIR = Path("gradio-outputs")
-OUT_DIR.mkdir(exist_ok=True)
-def harp_vamp(input_audio_file, periodic_p, n_mask_codebooks):
-    sig = at.AudioSignal(input_audio_file)
-    sr, samples = sig.sample_rate, sig.samples[0][0].detach().cpu().numpy()
-    # convert to int32
-    samples = (samples * np.iinfo(np.int32).max).astype(np.int32)
-    sr, samples = _vamp(
-        seed=0,
-        input_audio=(sr, samples),
-        model_choice=init_model_choice,
-        pitch_shift_amt=0,
-        periodic_p=periodic_p,
-        n_mask_codebooks=n_mask_codebooks,
-        periodic_w=1,
-        onset_mask_width=0,
-        dropout=0.0,
-        sampletemp=1.0,
-        typical_filtering=True,
-        typical_mass=0.15,
-        typical_min_tokens=64,
-        top_p=0.0,
-        sample_cutoff=1.0,
-        stretch_factor=1,
-    )
-
-    sig = at.AudioSignal(samples, sr)
-    # write to file
-    # clear the outdir
-    for p in OUT_DIR.glob("*"):
-        p.unlink()
-    OUT_DIR.mkdir(exist_ok=True)
-    # outpath = OUT_DIR / f"{uuid.uuid4()}.wav"
-    from pyharp import AudioLabel, LabelList, save_audio
-    outpath = save_audio(sig)
-    sig.write(outpath)
-    output_labels = LabelList()
-    output_labels.append(AudioLabel(label='~', t=0.0, amplitude=0.5, description='generated audio'))
-    return outpath, output_labels
-
-
 with gr.Blocks() as demo:
     with gr.Row():
         with gr.Column():
@@ -359,6 +379,11 @@ with gr.Blocks() as demo:
                 value=1,
             )
 
+            beat_mask_ms = gr.Number(
+                label="beat mask width (milliseconds)",
+                value=0,
+            )
+
 
             with gr.Accordion("sampling settings", open=False):
                 sampletemp = gr.Slider(
@@ -399,6 +424,22 @@ with gr.Blocks() as demo:
                     value=1.0,
                     step=0.01
                 )
+                sampling_steps = gr.Slider(
+                    label="sampling steps",
+                    minimum=1,
+                    maximum=128,
+                    step=1,
+                    value=36
+                )
+                num_feedback_steps = gr.Slider(
+                    label="feedback steps",
+                    minimum=1,
+                    maximum=16,
+                    step=1,
+                    value=1
+                )
+
+
 
 
                 dropout = gr.Slider(
@@ -433,7 +474,7 @@ with gr.Blocks() as demo:
 
         audio_outs = []
         use_as_input_btns = []
-        for i in range(1):
+        for i in range(2):
            with gr.Column():
                audio_outs.append(gr.Audio(
                    label=f"output audio {i+1}",
@@ -466,13 +507,16 @@ with gr.Blocks() as demo:
        n_mask_codebooks,
        pitch_shift_amt,
        sample_cutoff,
+        sampling_steps,
+        beat_mask_ms,
+        num_feedback_steps
    }
 
    # connect widgets
    vamp_button.click(
        fn=vamp,
        inputs=_inputs,
-        outputs=[audio_outs[0]],
+        outputs=[audio_outs[0], audio_outs[1]],
    )
 
    api_vamp_button = gr.Button("api vamp", visible=True)
@@ -491,31 +535,15 @@ with gr.Blocks() as demo:
            model_choice,
            n_mask_codebooks,
            pitch_shift_amt,
-            sample_cutoff
+            sample_cutoff,
+            sampling_steps,
+            beat_mask_ms,
+            num_feedback_steps
        ],
-        outputs=[audio_outs[0]],
+        outputs=[audio_outs[0], audio_outs[1]],
        api_name="vamp"
    )
 
-    from pyharp import ModelCard, build_endpoint
-    card = ModelCard(
-        name="vampnet",
-        description="vampnet! is a model for generating audio from audio",
-        author="hugo flores garcía",
-        tags=["music generation"],
-        midi_in=False,
-        midi_out=False
-    )
-
-    # Build a HARP-compatible endpoint
-    app = build_endpoint(model_card=card,
-                         components=[
-                             periodic_p,
-                             n_mask_codebooks,
-                         ],
-                         process_fn=harp_vamp)
-
-
 
    try:
        demo.queue()
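
Editor's note: with api_name="vamp" now taking the extra sampling_steps, beat_mask_ms, and num_feedback_steps inputs and returning two audio outputs, a remote call could look like the sketch below (hypothetical values; the argument order is assumed to match api_vamp's signature, and the server URL is a placeholder, not part of this commit):

from gradio_client import Client, handle_file

client = Client("http://localhost:7860")
out_a, out_b = client.predict(
    handle_file("input.wav"),     # input_audio
    1.0, 0.0,                     # sampletemp, top_p
    13, 1,                        # periodic_p, periodic_w
    0.0, 1.0,                     # dropout, stretch_factor
    0,                            # onset_mask_width
    True, 0.15, 64,               # typical_filtering, typical_mass, typical_min_tokens
    0,                            # seed
    "default",                    # model_choice (placeholder)
    3, 0, 1.0,                    # n_mask_codebooks, pitch_shift_amt, sample_cutoff
    36, 0, 1,                     # sampling_steps, beat_mask_ms, num_feedback_steps
    api_name="/vamp",
)

Two outputs come back because batch_size is now 2 and both audio_outs are wired to the endpoint.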
conf/generated/ivo/c2f.yml ADDED
@@ -0,0 +1,15 @@
+$include:
+- conf/lora/lora.yml
+AudioDataset.duration: 3.0
+AudioDataset.loudness_cutoff: -40.0
+VampNet.embedding_dim: 1280
+VampNet.n_codebooks: 14
+VampNet.n_conditioning_codebooks: 4
+VampNet.n_heads: 20
+VampNet.n_layers: 16
+fine_tune: true
+fine_tune_checkpoint: ./models/vampnet/c2f.pth
+save_path: ./runs/ivo/c2f
+train/AudioLoader.sources: &id001
+- ./scratch/miguel/ivo/separated
+val/AudioLoader.sources: *id001
conf/generated/ivo/coarse.yml ADDED
@@ -0,0 +1,8 @@
+$include:
+- conf/lora/lora.yml
+fine_tune: true
+fine_tune_checkpoint: ./models/vampnet/coarse.pth
+save_path: ./runs/ivo/coarse
+train/AudioLoader.sources: &id001
+- ./scratch/miguel/ivo/separated
+val/AudioLoader.sources: *id001
conf/generated/ivo/interface.yml ADDED
@@ -0,0 +1,6 @@
+AudioLoader.sources:
+- - ./scratch/miguel/ivo/separated
+Interface.coarse2fine_ckpt: ./runs/ivo/c2f/latest/vampnet/weights.pth
+Interface.coarse_ckpt: ./runs/ivo/coarse/latest/vampnet/weights.pth
+Interface.codec_ckpt: ./models/vampnet/codec.pth
+Interface.wavebeat_ckpt: ./models/wavebeat.pth
conf/generated/lazaro-ros-sep/c2f.yml ADDED
@@ -0,0 +1,15 @@
+$include:
+- conf/lora/lora.yml
+AudioDataset.duration: 3.0
+AudioDataset.loudness_cutoff: -40.0
+VampNet.embedding_dim: 1280
+VampNet.n_codebooks: 14
+VampNet.n_conditioning_codebooks: 4
+VampNet.n_heads: 20
+VampNet.n_layers: 16
+fine_tune: true
+fine_tune_checkpoint: ./models/vampnet/c2f.pth
+save_path: ./runs/lazaro-ros-sep/c2f
+train/AudioLoader.sources: &id001
+- ./scratch/miguel/lazaro-ros/separated
+val/AudioLoader.sources: *id001
conf/generated/lazaro-ros-sep/coarse.yml ADDED
@@ -0,0 +1,8 @@
+$include:
+- conf/lora/lora.yml
+fine_tune: true
+fine_tune_checkpoint: ./models/vampnet/coarse.pth
+save_path: ./runs/lazaro-ros-sep/coarse
+train/AudioLoader.sources: &id001
+- ./scratch/miguel/lazaro-ros/separated
+val/AudioLoader.sources: *id001
conf/generated/lazaro-ros-sep/interface.yml ADDED
@@ -0,0 +1,6 @@
+AudioLoader.sources:
+- - ./scratch/miguel/lazaro-ros/separated
+Interface.coarse2fine_ckpt: ./runs/lazaro-ros-sep/c2f/latest/vampnet/weights.pth
+Interface.coarse_ckpt: ./runs/lazaro-ros-sep/coarse/latest/vampnet/weights.pth
+Interface.codec_ckpt: ./models/vampnet/codec.pth
+Interface.wavebeat_ckpt: ./models/wavebeat.pth
conf/generated/lazaro-ros/c2f.yml ADDED
@@ -0,0 +1,15 @@
+$include:
+- conf/lora/lora.yml
+AudioDataset.duration: 3.0
+AudioDataset.loudness_cutoff: -40.0
+VampNet.embedding_dim: 1280
+VampNet.n_codebooks: 14
+VampNet.n_conditioning_codebooks: 4
+VampNet.n_heads: 20
+VampNet.n_layers: 16
+fine_tune: true
+fine_tune_checkpoint: ./models/vampnet/c2f.pth
+save_path: ./runs/lazaro-ros/c2f
+train/AudioLoader.sources: &id001
+- ./scratch/miguel/lazaro-ros
+val/AudioLoader.sources: *id001
conf/generated/lazaro-ros/coarse.yml ADDED
@@ -0,0 +1,8 @@
+$include:
+- conf/lora/lora.yml
+fine_tune: true
+fine_tune_checkpoint: ./models/vampnet/coarse.pth
+save_path: ./runs/lazaro-ros/coarse
+train/AudioLoader.sources: &id001
+- ./scratch/miguel/lazaro-ros
+val/AudioLoader.sources: *id001
conf/generated/lazaro-ros/interface.yml ADDED
@@ -0,0 +1,6 @@
+AudioLoader.sources:
+- - ./scratch/miguel/lazaro-ros
+Interface.coarse2fine_ckpt: ./runs/lazaro-ros/c2f/latest/vampnet/weights.pth
+Interface.coarse_ckpt: ./runs/lazaro-ros/coarse/latest/vampnet/weights.pth
+Interface.codec_ckpt: ./models/vampnet/codec.pth
+Interface.wavebeat_ckpt: ./models/wavebeat.pth
conf/generated/march-31/c2f.yml ADDED
@@ -0,0 +1,15 @@
+$include:
+- conf/lora/lora.yml
+AudioDataset.duration: 3.0
+AudioDataset.loudness_cutoff: -40.0
+VampNet.embedding_dim: 1280
+VampNet.n_codebooks: 14
+VampNet.n_conditioning_codebooks: 4
+VampNet.n_heads: 20
+VampNet.n_layers: 16
+fine_tune: true
+fine_tune_checkpoint: ./models/vampnet/c2f.pth
+save_path: ./runs/march-31/c2f
+train/AudioLoader.sources: &id001
+- sound-journal-march-31
+val/AudioLoader.sources: *id001
conf/generated/march-31/coarse.yml ADDED
@@ -0,0 +1,8 @@
+$include:
+- conf/lora/lora.yml
+fine_tune: true
+fine_tune_checkpoint: ./models/vampnet/coarse.pth
+save_path: ./runs/march-31/coarse
+train/AudioLoader.sources: &id001
+- sound-journal-march-31
+val/AudioLoader.sources: *id001
conf/generated/march-31/interface.yml ADDED
@@ -0,0 +1,6 @@
+AudioLoader.sources:
+- - sound-journal-march-31
+Interface.coarse2fine_ckpt: ./runs/march-31/c2f/latest/vampnet/weights.pth
+Interface.coarse_ckpt: ./runs/march-31/coarse/latest/vampnet/weights.pth
+Interface.codec_ckpt: ./models/vampnet/codec.pth
+Interface.wavebeat_ckpt: ./models/wavebeat.pth
conf/generated/sax-new/c2f.yml ADDED
@@ -0,0 +1,15 @@
+$include:
+- conf/lora/lora.yml
+AudioDataset.duration: 3.0
+AudioDataset.loudness_cutoff: -40.0
+VampNet.embedding_dim: 1280
+VampNet.n_codebooks: 14
+VampNet.n_conditioning_codebooks: 4
+VampNet.n_heads: 20
+VampNet.n_layers: 16
+fine_tune: true
+fine_tune_checkpoint: ./models/vampnet/c2f.pth
+save_path: ./runs/sax-new/c2f
+train/AudioLoader.sources: &id001
+- ./scratch/miguel/saxophone-new/
+val/AudioLoader.sources: *id001
conf/generated/sax-new/coarse.yml ADDED
@@ -0,0 +1,8 @@
+$include:
+- conf/lora/lora.yml
+fine_tune: true
+fine_tune_checkpoint: ./models/vampnet/coarse.pth
+save_path: ./runs/sax-new/coarse
+train/AudioLoader.sources: &id001
+- ./scratch/miguel/saxophone-new/
+val/AudioLoader.sources: *id001
conf/generated/sax-new/interface.yml ADDED
@@ -0,0 +1,6 @@
+AudioLoader.sources:
+- - ./scratch/miguel/saxophone-new/
+Interface.coarse2fine_ckpt: ./runs/sax-new/c2f/latest/vampnet/weights.pth
+Interface.coarse_ckpt: ./runs/sax-new/coarse/latest/vampnet/weights.pth
+Interface.codec_ckpt: ./models/vampnet/codec.pth
+Interface.wavebeat_ckpt: ./models/wavebeat.pth
conf/generated/saxophone/c2f.yml ADDED
@@ -0,0 +1,15 @@
+$include:
+- conf/lora/lora.yml
+AudioDataset.duration: 3.0
+AudioDataset.loudness_cutoff: -40.0
+VampNet.embedding_dim: 1280
+VampNet.n_codebooks: 14
+VampNet.n_conditioning_codebooks: 4
+VampNet.n_heads: 20
+VampNet.n_layers: 16
+fine_tune: true
+fine_tune_checkpoint: ./models/vampnet/c2f.pth
+save_path: ./runs/saxophone/c2f
+train/AudioLoader.sources: &id001
+- scratch/sounds
+val/AudioLoader.sources: *id001
conf/generated/saxophone/coarse.yml ADDED
@@ -0,0 +1,8 @@
+$include:
+- conf/lora/lora.yml
+fine_tune: true
+fine_tune_checkpoint: ./models/vampnet/coarse.pth
+save_path: ./runs/saxophone/coarse
+train/AudioLoader.sources: &id001
+- scratch/sounds
+val/AudioLoader.sources: *id001
conf/generated/saxophone/interface.yml ADDED
@@ -0,0 +1,6 @@
+AudioLoader.sources:
+- - scratch/sounds
+Interface.coarse2fine_ckpt: ./runs/saxophone/c2f/latest/vampnet/weights.pth
+Interface.coarse_ckpt: ./runs/saxophone/coarse/latest/vampnet/weights.pth
+Interface.codec_ckpt: ./models/vampnet/codec.pth
+Interface.wavebeat_ckpt: ./models/wavebeat.pth
conf/lora/lora-s2s.yml ADDED
@@ -0,0 +1,27 @@
+$include:
+- conf/vampnet.yml
+
+fine_tune: True
+
+train/AudioDataset.n_examples: 100000000
+val/AudioDataset.n_examples: 500
+
+
+NoamScheduler.warmup: 500
+
+batch_size: 7
+num_workers: 7
+save_iters: [2000, 4000, 10000,20000, 40000, 100000]
+sample_freq: 2000
+val_freq: 1000
+
+AdamW.lr: 0.0001
+
+# let's us organize sound classes into folders and choose from those sound classes uniformly
+AudioDataset.without_replacement: False
+num_iters: 500000
+
+
+# control signals to use as conditioning.
+Sketch2SoundController.ctrl_keys: ['rmsq16',]
+
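
Editor's note: lora-s2s.yml selects the 'rmsq16' control signal; in vampnet/control.py (added later in this commit) that key maps to an RMS extractor quantized to 16 steps. A minimal sketch of what the config resolves to (hop_length and sample_rate here are placeholders; training would take them from the codec):

from vampnet.control import Sketch2SoundController

controller = Sketch2SoundController(
    ctrl_keys=["rmsq16"],   # quantized-RMS control, per CONTROLLERS in vampnet/control.py
    hop_length=512,
    sample_rate=44100,
)
# controller.extract(sig) -> {"rmsq16": Tensor of shape (batch, 1, frames)}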
conf/lora/lora.yml CHANGED
@@ -19,4 +19,4 @@ AdamW.lr: 0.0001
 
 # let's us organize sound classes into folders and choose from those sound classes uniformly
 AudioDataset.without_replacement: False
-num_iters: 500000
+num_iters: 500000
scripts/exp/export.py CHANGED
@@ -1,11 +1,10 @@
 from pathlib import Path
 
-run_dir = Path("runs/sample-instrument")
+run_dir = Path("runs/lazaro-ros-sep")
 name = run_dir.name
 
 repo_dir = Path("models/vampnet")
 
-
 for part in ("coarse", "c2f"):
     outdir = repo_dir / "loras" / name
     outdir.mkdir(parents=True, exist_ok=True)
@@ -16,7 +15,7 @@ for part in ("coarse", "c2f"):
 
 # now, push to hub
 from huggingface_hub import Repository
-repo = Repository(repo_dir, git_user="hugofloresgarcia", git_email="[email protected]")
+repo = Repository(str(repo_dir), git_user="hugofloresgarcia", git_email="[email protected]")
 repo.push_to_hub(
     commit_message=f"add {name}"
 )
scripts/exp/train.py CHANGED
@@ -18,6 +18,7 @@ from torch.utils.tensorboard import SummaryWriter
 
 import vampnet
 from vampnet.modules.transformer import VampNet
+# from vampnet.control import Sketch2SoundController
 from vampnet.util import codebook_unflatten, codebook_flatten
 from vampnet import mask as pmask
 # from dac.model.dac import DAC
@@ -66,6 +67,8 @@ AudioDataset = argbind.bind(at.datasets.AudioDataset, "train", "val")
 
 IGNORE_INDEX = -100
 
+# Sketch2SoundController = argbind.bind(Sketch2SoundController)
+
 
 @argbind.bind("train", "val", without_prefix=True)
 def build_transform():
@@ -118,6 +121,36 @@ def add_num_params_repr_hook(model):
 
     setattr(m, "extra_repr", partial(num_params_hook, o=o, p=p))
 
+def get_controls(state, sig: at.AudioSignal):
+    # get controls
+    n_batch = sig.samples.shape[0]
+    if state.controller is not None:
+        ctrls = state.controller.extract(sig)
+        # draw control masks
+        ctrl_masks = state.controller.random_mask(
+            ctrls,
+            r=state.rng.draw(n_batch)[:, 0].to(state.device)
+        )
+    else:
+        ctrls = None
+        ctrl_masks = None
+
+    return ctrls, ctrl_masks
+
+
+def generate_z_mask(state, z, vn, n_batch, ctrl_masks=None):
+    r = state.rng.draw(n_batch)[:, 0].to(state.device)
+
+    mask, ii = state.model.random_mask(z, r)
+    mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
+
+    # outpaint?
+    # if state.outpaint_prob > 0:
+    #     if flip_coin(state.outpaint_prob):
+    #         mask, ctrl_masks = state.build_tria_mask(mask, ctrl_masks)
+    z_mask = pmask.apply_mask(z, mask, vn.mask_token)
+
+    return z_mask, mask, ii, r, ctrl_masks
 
 def accuracy(
     preds: torch.Tensor,
@@ -184,6 +217,8 @@ def _metrics(z_hat, r, target, flat_mask, output):
 class State:
     model: VampNet
     codec: DAC
+    # controller: Sketch2SoundController
+    controller: Optional[object]
 
     optimizer: AdamW
     scheduler: NoamScheduler
@@ -218,6 +253,11 @@ def train_loop(state: State, batch: dict, accel: Accelerator):
     mask = pmask.random(z, r)
     mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
     z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
+
+    # get controls
+    ctrls, ctrl_masks = get_controls(state, signal)
+
+    # TODO: KEEP INCORPORATING ZMASK CODE
 
     z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
 
@@ -266,6 +306,22 @@ def train_loop(state: State, batch: dict, accel: Accelerator):
 
     return {k: v for k, v in sorted(output.items())}
 
+    # def get_controls(self, sig: sn.Signal, controller):
+    #     # get controls
+    #     n_batch = sig.wav.shape[0]
+    #     if self.controller is not None:
+    #         ctrls = self.controller.extract(sig)
+    #         # draw control masks
+    #         ctrl_masks = self.controller.random_mask(
+    #             ctrls,
+    #             r=self.rng.draw(n_batch)[:, 0].to(self.device)
+    #         )
+    #     else:
+    #         ctrls = None
+    #         ctrl_masks = None
+
+    #     return ctrls, ctrl_masks
+
 
 @timer()
 @torch.no_grad()
@@ -561,6 +617,8 @@ def load(
     # load the datasets
     train_data, val_data = build_datasets(args, sample_rate)
 
+    # controller = Sketch2SoundController(sample_rate=sample_rate, hop_length=codec.hop_length)
+
     return State(
         tracker=tracker,
         model=model,
@@ -572,6 +630,7 @@ def load(
         train_data=train_data,
         val_data=val_data,
         grad_clip_val=grad_clip_val,
+        controller=None,
     )
 
 
@@ -612,6 +671,7 @@ def train(
         tracker=tracker,
         save_path=save_path)
     print("initialized state.")
+    state.device = accel.device
 
     train_dataloader = accel.prepare_dataloader(
         state.train_data,
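
Editor's note: in this commit the controller is still stubbed out (controller=None and the Sketch2SoundController import is commented), but get_controls together with the new VampNet.forward signature suggests the intended wiring. A rough sketch under that assumption, not what this commit actually executes:

# inside train_loop, once a real controller is attached to State:
ctrls, ctrl_masks = get_controls(state, signal)   # dicts keyed by control name, or (None, None)
z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
z_hat = state.model(z_mask_latent, ctrls=ctrls, ctrl_masks=ctrl_masks)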
token_telephone/tt.py CHANGED
@@ -16,10 +16,25 @@ import numpy as np
 import torch
 from einops import rearrange
 
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# ~~~~~~ configs! ~~~~~~~~
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+MAX_LOUDNESS = -20
+MIN_LOUDNESS = -40
+COLS = 40
+ROWS = 13
+
+device = 'Scarlett 4i4 4th Gen'
+sample_rate = 48000
+num_channels = 4
+blocksize = 16384
+
 PROFILE = False
 DEBUG = False
 DEBUG_NO_VAMPNET = False
 set_debug(DEBUG)
+
 # if DEBUG:
 #     import gc
 #     # log when gc start and stops
@@ -80,19 +95,6 @@ Thread(target=draw_intro_screen).start()
 from audiotools import AudioSignal
 from vamp_helper import load_interface, ez_variation
 
-# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-# ~~~~~~ configs! ~~~~~~~~
-# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-MAX_LOUDNESS = -20
-MIN_LOUDNESS = -40
-COLS = 40
-ROWS = 13
-
-device = 'Scarlett 4i4 4th Gen'
-sample_rate = 48000
-num_channels = 4
-blocksize = 16384
 
 
 # TODO:
vampnet/beats.py CHANGED
@@ -213,10 +213,11 @@ class WaveBeat(BeatTracker):
     def extract_beats(self, signal: AudioSignal) -> Tuple[np.ndarray, np.ndarray]:
         """returns beat and downbeat times, in seconds"""
         # extract beats
+        self.model.to('cuda' if torch.cuda.is_available() else 'cpu')
         beats, downbeats = self.model.predict_beats_from_array(
             audio=signal.audio_data.squeeze(0),
             sr=signal.sample_rate,
-            use_gpu=self.device != "cpu",
+            use_gpu=torch.cuda.is_available(),
         )
 
         return beats, downbeats
vampnet/control.py ADDED
@@ -0,0 +1,277 @@
+from dataclasses import dataclass
+from functools import partial
+from typing import Optional
+
+from torch import nn
+
+import vampnet.dsp.signal as sn
+from vampnet.dsp.signal import Signal
+from vampnet.mask import random_along_time
+from torch import Tensor
+import torch
+
+
+class MedianFilterAugment(nn.Module):
+
+    def __init__(self,
+        kernel_size: int,
+        train_min: int = 1,
+        train_max: int = 20,
+    ):
+        super().__init__()
+        self.kernel_size = kernel_size
+        self.train_min = train_min
+        self.train_max = train_max
+
+    def forward(self, x: Tensor) -> Tensor:
+        if self.training:
+            sizes = torch.randint(
+                self.train_min,
+                self.train_max,
+                size=(x.shape[0],)
+            )
+        else:
+            sizes = self.kernel_size
+        # print(f"median filter sizes: {sizes}")
+        return sn.median_filter_1d(x, sizes)
+
+class RMS(nn.Module):
+
+    def __init__(self,
+        hop_length,
+        window_length=2048,
+        n_quantize=None,
+        sample_rate=44100,
+        median_filter_size: Optional[int] = None,
+        train_median_filter_min=1,
+        train_median_filter_max=15,
+    ):
+        super().__init__()
+
+        self.hop_length = hop_length
+        self.window_length = window_length
+        self.n_quantize = n_quantize
+        self.sample_rate = sample_rate
+
+        self.mf = MedianFilterAugment(
+            kernel_size=median_filter_size,
+            train_min=train_median_filter_min,
+            train_max=train_median_filter_max
+        ) if median_filter_size is not None else None
+
+    @property
+    def dim(self):
+        return 1
+
+    def extract(self, sig: Signal) -> Tensor:
+        rmsd = sn.rms(sig,
+            window_length=self.window_length,
+            hop_length=self.hop_length,
+        )[:, :, :-1]  # TODO: cutting the last frame to match DAC tokens but why :'(
+        nb, _, _ = rmsd.shape
+
+        if self.n_quantize is not None:
+            # standardize to 0-1
+            rmsd = (rmsd - rmsd.min()) / (rmsd.max() - rmsd.min())
+
+            # quantize to 128 steps
+            rmsd = torch.round(rmsd * self.n_quantize)
+            rmsd = rmsd / self.n_quantize
+
+        if self.mf is not None:
+            rmsd = self.mf(rmsd)
+
+        return rmsd
+
+
+
+class HarmonicChroma(nn.Module):
+
+    def __init__(self,
+        hop_length: int, window_length: int = 4096,
+        n_chroma: int = 48, sample_rate: int = 44100,
+        top_n: int = 0
+    ):
+        super().__init__()
+        from torchaudio.prototype.transforms import ChromaScale
+        self.hop_length = hop_length
+        self.window_length = window_length
+        self.n_chroma = n_chroma
+        self.sample_rate = sample_rate
+        self.top_n = top_n
+
+        # HUGO: this representation, as is,
+        # encodes timbre information in the chroma
+        # which is not what we want!!!
+        # would a median filter help perhaps?
+        self.chroma = ChromaScale(
+            sample_rate=self.sample_rate,
+            n_freqs=self.window_length // 2 + 1,
+            n_chroma=self.n_chroma,
+            octwidth=5.0,
+        )
+
+    @property
+    def dim(self):
+        return self.n_chroma
+
+    def extract(self, sig: Signal) -> Tensor:
+        from vampnet.dsp.hpss import hpss
+        self.chroma.to(sig.wav.device)
+
+        # spectrogram
+        spec = sn.stft(sig,
+            window_length=self.window_length,
+            hop_length=self.hop_length
+        )
+        # magnitude
+        spec = torch.abs(spec)
+
+        # hpss
+        spec = hpss(spec, kernel_size=51, hard=True)[0]
+
+        # chroma
+        chroma = self.chroma(spec)
+
+        # get the rms of this spec
+        rms_d = sn.rms_from_spec(
+            spec, window_length=self.window_length
+        )
+
+        # convert the rms to db
+        rms_d = 10 * torch.log10(rms_d + 1e-7)
+
+        # make a mask based on the rms < -40
+        mask = torch.where(rms_d < -40, torch.zeros_like(rms_d), torch.ones_like(rms_d))
+
+        # remove anything below 80 (where the fuck did I get this number from?)
+        chroma = torch.where(chroma < 100, torch.zeros_like(chroma), chroma)
+
+        # Get top 2 values and indices along the -2 dimension
+        if self.top_n:
+            _, topk_indices = torch.topk(chroma, self.top_n, dim=-2)
+
+            # Create a mask for the top 2 values
+            topk_mask = torch.zeros_like(chroma).scatter_(-2, topk_indices, 1.0)
+
+            # Retain only the top 2 values
+            chroma = chroma * topk_mask
+
+        # apply the mask
+        chroma = chroma * mask.unsqueeze(-2)
+
+        # Apply softmax along dim=-2
+        if self.top_n > 0:
+            chroma = torch.nn.functional.softmax(chroma, dim=-2)
+
+        # mask out any timesteps whose chroma have all equal values (all 0s before softmax)
+        # TODO: i did this with chatgpt, there's gott a be a better way
+        chroma_mean = chroma.mean(dim=-2, keepdim=True)
+        chroma_diff = torch.abs(chroma - chroma_mean)
+        equal_mask = torch.all(chroma_diff < 1e-6, dim=-2, keepdim=True)
+
+        # Set chroma values to zero for timesteps with all equal values
+        chroma = torch.where(equal_mask, torch.zeros_like(chroma), chroma)
+
+
+        return chroma[:, 0, :, :-1]  # mono only :( FIX ME!
+
+
+# TODO: try harmonic mel?
+
+CONTROLLERS = {
+    "rms": RMS,
+    "rmsq128": partial(RMS, n_quantize=128),
+    "rmsq16": partial(RMS, n_quantize=16),
+    "rms-median": partial(RMS, median_filter_size=5),
+    "rmsq16-median": partial(RMS, n_quantize=16, median_filter_size=3),
+    "hchroma": HarmonicChroma,
+    "hchroma-12c-top2": partial(HarmonicChroma, n_chroma=12, top_n=2),  # TODO: refactor me. If this works, this should just be named hchroma.
+    "hchroma-36c-top3": partial(HarmonicChroma, n_chroma=36, top_n=3)  # TODO: refactor me. If this works, this should just be named hchroma.
+}
+
+class Sketch2SoundController(nn.Module):
+
+    def __init__(
+        self,
+        ctrl_keys: list[str],
+        hop_length: str,
+        sample_rate: int,
+    ):
+        super().__init__()
+
+        assert all([k in CONTROLLERS for k in ctrl_keys]), f"got an unsupported control key in {ctrl_keys}!\n supported: {CONTROLLERS.keys()}"
+
+        self.hop_length = hop_length
+        self.ctrl_keys = ctrl_keys
+        self.sample_rate = sample_rate
+
+        self.controllers = {
+            k: CONTROLLERS[k](hop_length=hop_length, sample_rate=sample_rate)
+            for k in self.ctrl_keys
+        }
+
+    @property
+    def ctrl_dims(self, ) -> dict[str, int]:
+        return {
+            k: controller.dim for k, controller in self.controllers.items()
+        }
+
+    def extract(self, sig: Signal) -> dict[str, Tensor]:
+        ctrls = {
+            k: controller.extract(sig) for k, controller in self.controllers.items()
+        }
+        return ctrls
+
+    def random_mask(self, ctrls: dict[str, Tensor], r: float):
+        masks = {}
+        for k, ctrl in ctrls.items():
+            masks[k] = 1-random_along_time(ctrl, r)
+        return masks
+
+    def empty_mask(self, ctrls: dict[str, Tensor]):
+        first_key = next(iter(ctrls))
+        mask = torch.zeros_like(ctrls[first_key])
+        return {k: mask for k in ctrls}
+
+
+def test_controller():
+    controller = Sketch2SoundController(
+        ctrl_keys=["rms-median", "rms", "rmsq128"],
+        hop_length=512,
+        sample_rate=44100
+    )
+    controller.train()
+    # sig = sn.read_from_file("assets/example.wav")
+    # sig = sn.read_from_file("/Users/hugo/Downloads/DCS_SE_FullChoir_ScaleUpDown06_A2_DYN.wav")
+    # sig = sn.excerpt('/Users/hugo/Downloads/(guitarra - hugo mix) bubararu - tambor negro.wav', offset=0, duration=10)
+    sig = sn.read_from_file("assets/voice-prompt.wav")
+    ctrls = controller.extract(sig)
+    print(f"given sig of shape {sig.wav.shape}, extracted controls: {ctrls}")
+
+    # print the whole thing
+    # torch.set_printoptions(profile="full")
+    # print(ctrls["hchroma"][0][0][:, 200:210])
+
+    # imshow the chroma
+    import matplotlib.pyplot as plt
+
+    # Define relative heights for the subplots
+    fig, (ax1, ax2, ax3, ax4) = plt.subplots(
+        4, 1,
+        sharex=True,
+    )
+
+    # Display the spectrogram on the top
+    ax1.imshow(sn.stft(sig, hop_length=512, window_length=2048).abs()[0][0].cpu().log().numpy(), aspect='auto', origin='lower')
+    # display rms on the bottom
+    ax2.plot(ctrls["rms-median"][0][0])
+    ax3.plot(ctrls["rms"][0][0])
+    ax4.plot(ctrls["rmsq128"][0][0])
+
+    plt.tight_layout()  # Ensure proper spacing
+    plt.savefig("img.png")
+
+
+if __name__ == "__main__":
+    test_controller()
vampnet/interface.py CHANGED
@@ -59,7 +59,7 @@ class Interface(torch.nn.Module):
         coarse2fine_ckpt: str = None,
         coarse2fine_lora_ckpt: str = None,
         codec_ckpt: str = None,
-        wavebeat_ckpt: str = None,
+        wavebeat_ckpt: str = "./models/vampnet/wavebeat.pth",
         device: str = "cpu",
         coarse_chunk_size_s: int = 10,
         coarse2fine_chunk_size_s: int = 3,
@@ -96,7 +96,7 @@ class Interface(torch.nn.Module):
 
         if wavebeat_ckpt is not None:
             logging.debug(f"loading wavebeat from {wavebeat_ckpt}")
-            self.beat_tracker = WaveBeat(wavebeat_ckpt)
+            self.beat_tracker = WaveBeat(wavebeat_ckpt, device=device)
             self.beat_tracker.model.to(device)
         else:
             self.beat_tracker = None
@@ -254,6 +254,7 @@ class Interface(torch.nn.Module):
         """
         assert self.beat_tracker is not None, "No beat tracker loaded"
 
+
         # get the beat times
         beats, downbeats = self.beat_tracker.extract_beats(signal)
 
@@ -516,12 +517,19 @@ class Interface(torch.nn.Module):
         # the forward pass
         logging.debug(z.shape)
         logging.debug("coarse!")
-        zv, mask_z = self.coarse_vamp(
-            z,
-            mask=mask,
-            return_mask=True,
-            **kwargs
-        )
+        zv = z
+        for i in range(feedback_steps):
+            zv, mask_z = self.coarse_vamp(
+                zv,
+                mask=mask,
+                return_mask=True,
+                **kwargs)
+            # roll the mask around a random amount
+            mask_z = mask_z.roll(
+                shifts=(i + 1) % feedback_steps,
+                dims=-1
+            )
+
 
         # add the top codebooks back in
         if zv.shape[1] < z.shape[1]:
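
Editor's note: the new feedback_steps loop re-runs coarse_vamp on its own output and rolls the mask between passes. From the app side the new knobs are driven roughly as in this sketch (parameter values are placeholders):

codes = interface.encode(sig)
mask = interface.build_mask(codes, sig=sig, periodic_prompt=13, upper_codebook_mask=3)
codes, mask = interface.vamp(
    codes, mask,
    feedback_steps=4,      # number of re-vamp passes
    _sampling_steps=36,
    return_mask=True,
)
out = interface.decode(codes)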
vampnet/mask.py CHANGED
@@ -163,14 +163,18 @@ def dropout(
     mask: torch.Tensor,
     p: float,
 ):
-    assert 0 <= p <= 1, "p must be between 0 and 1"
-    assert mask.max() <= 1, "mask must be binary"
-    assert mask.min() >= 0, "mask must be binary"
-    mask = (~mask.bool()).float()
-    mask = torch.bernoulli(mask * (1 - p))
-    mask = ~mask.round().bool()
+    # instead of the above, mask along the last dimensions
+    tsteps = mask.shape[-1]
+    tsteps_to_drop = int(tsteps * p)
+    tsteps_to_keep = tsteps - tsteps_to_drop
+    idxs_to_drop = torch.randint(0, tsteps, (tsteps_to_drop,))
+    mask = mask.clone()
+    mask[:, :, idxs_to_drop] = 1
     return mask.long()
 
+
+
+
 def mask_or(
     mask1: torch.Tensor,
     mask2: torch.Tensor
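
Editor's note: the rewritten dropout no longer Bernoulli-samples individual positions; it picks int(tsteps * p) random frame indices (possibly with repeats) and forces them to 1 across every codebook. A quick sketch of the new behavior:

import torch
from vampnet import mask as pmask

m = torch.zeros(1, 4, 10).long()   # nothing masked
m = pmask.dropout(m, p=0.3)        # up to 3 of the 10 timesteps forced to 1 in all codebooks
print(m.sum(dim=1))                # dropped columns read 4, kept columns read 0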
vampnet/modules/transformer.py CHANGED
@@ -6,6 +6,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch import Tensor
 from einops import rearrange
 import loralib as lora
 import audiotools as at
@@ -405,7 +406,7 @@ class TransformerStack(nn.Module):
         )
 
         # Perform last normalization
-        self.norm = RMSNorm(d_model) if last_layer else None
+        self.norm = RMSNorm(d_model) if last_layer else None
 
     def subsequent_mask(self, size):
         return torch.ones(1, size, size).tril().bool()
@@ -461,6 +462,75 @@ class TransformerStack(nn.Module):
         else:
             return out
 
+class CFGDropout(nn.Module):
+
+    def __init__(self, p: float = 0.2):
+        super().__init__()
+        self.p = p
+
+    def forward(self, x: Tensor):
+        # dropout along the batch dim
+        if self.training:
+            mask = torch.rand(x.shape[0], 1, 1, device=x.device) > self.p
+        else:
+            mask = torch.ones(x.shape[0], 1, 1, device=x.device)
+        return x * mask
+
+class ControlEncoder(nn.Module):
+
+    def __init__(self,
+        ctrl_dims: dict[str, int],
+        embedding_dim: int,
+        cfg_dropout_prob: float
+    ):
+        super().__init__()
+        self.ctrl_encoders = nn.ModuleDict({
+            key: nn.Linear(dim, embedding_dim)
+            for key, dim in ctrl_dims.items()
+        })
+
+        self.cfg_dropout = CFGDropout(p=cfg_dropout_prob)
+        self.all_dropout = CFGDropout(p=cfg_dropout_prob / 2)
+
+    def forward(self,
+        embedding: Tensor,  # embedding to which we will add ctrls
+        ctrls: dict[str, Tensor],
+        ctrl_masks: dict[str, Tensor]
+    ):
+        # INPUT: ctrl tensor should be shape (b d n)
+
+        # assert that we got all the right ctrls and ctrl_masks according to the encoders that we have
+        assert list(sorted(ctrls.keys())) == list(sorted(self.ctrl_encoders.keys())), "ctrls and ctrl_encoders keys do not match"
+        assert list(sorted(ctrl_masks.keys())) == list(sorted(self.ctrl_encoders.keys())), "ctrl_masks and ctrl_encoders keys do not match"
+
+        out_emb = torch.zeros_like(embedding)
+        for ck in ctrls:
+            ctrld = ctrls[ck]
+            ctrlmask = ctrl_masks[ck]
+
+            assert ctrld.shape[-1] == embedding.shape[-1], "ctrls should match x along time dimension"
+            assert ctrlmask.ndim == 2, "ctrlmask should be 2d"
+            assert ctrlmask.shape[-1] == ctrld.shape[-1], "ctrlmask should match ctrld along time dimension"
+
+            # project ctrl with encoder
+            ctrld = rearrange(ctrld, "b d n -> b n d")
+            ctrl_emb = self.ctrl_encoders[ck](ctrld)
+            ctrld = rearrange(ctrld, "b n d -> b d n")
+            ctrl_emb = rearrange(ctrl_emb, "b n d -> b d n")
+
+            # apply ctrl mask
+            ctrl_emb = ctrl_emb * ctrlmask[:, None, :]
+
+            # apply cfg dropout
+            ctrl_emb = self.cfg_dropout(ctrl_emb)
+
+            # add to the out_emb
+            out_emb = out_emb + ctrl_emb
+
+        # randomly dropout all ctrls
+        out_emb = self.all_dropout(out_emb)
+
+        return out_emb
 
 class VampNet(at.ml.BaseModel):
     def __init__(
@@ -475,7 +545,10 @@ class VampNet(at.ml.BaseModel):
         vocab_size: int = 1024,
         flash_attn: bool = True,
         noise_mode: str = "mask",
-        dropout: float = 0.1
+        dropout: float = 0.1,
+        ctrl_dims: Optional[dict[str, int]] = None,
+        cfg_dropout_prob: float = 0.2,
+        cond_dim: int = 0,
     ):
         super().__init__()
         assert r_cond_dim == 0, f"r_cond_dim must be 0 (not supported), but got {r_cond_dim}"
@@ -489,6 +562,11 @@ class VampNet(at.ml.BaseModel):
         self.latent_dim = latent_dim
         self.flash_attn = flash_attn
         self.noise_mode = noise_mode
+        self.cond_dim = cond_dim
+        self.r_cond_dim = r_cond_dim
+        self.dropout = dropout
+        self.cfg_dropout_prob = cfg_dropout_prob
+        # self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
         assert self.noise_mode == "mask", "deprecated"
 
@@ -525,10 +603,25 @@ class VampNet(at.ml.BaseModel):
             ),
         )
 
-    def forward(self, x, return_activations: bool = False):
+        if self.cond_dim > 0:
+            self.cfg_dropout = CFGDropout(p=cfg_dropout_prob)
+
+        self.ctrl_dims = ctrl_dims
+        if self.ctrl_dims is not None:
+            self.ctrl_encoder = ControlEncoder(
+                ctrl_dims,
+                embedding_dim=embedding_dim,
+                cfg_dropout_prob=cfg_dropout_prob
+            )
+
+    def forward(self, x, ctrls=None, ctrl_masks=None, return_activations: bool = False):
         x = self.embedding(x)
         x_mask = torch.ones_like(x, dtype=torch.bool)[:, :1, :].squeeze(1)
 
+        if self.ctrl_dims is not None:
+            # apply controls
+            x = x + self.ctrl_encoder(x, ctrls, ctrl_masks)
+
         x = rearrange(x, "b d n -> b n d")
         out = self.transformer(x=x, x_mask=x_mask, return_activations=return_activations)
         if return_activations:
@@ -600,6 +693,8 @@ class VampNet(at.ml.BaseModel):
         temperature: float = 1.0,
         mask: Optional[torch.Tensor] = None,
         mask_temperature: float = 10.5,
+        ctrls: dict = None,
+        ctrl_masks: dict = None,
         typical_filtering=True,
         typical_mass=0.15,
         typical_min_tokens=64,
@@ -609,7 +704,9 @@ class VampNet(at.ml.BaseModel):
         return_signal=True,
         debug=False,
         causal_weight: float = 0.0,
+        cfg_scale: float = 3.0,
         cfg_guidance: float = None,
+        cond=None  # unused
     ):
         if seed is not None:
             at.util.seed(seed)
@@ -622,6 +719,22 @@ class VampNet(at.ml.BaseModel):
         z = start_tokens
         nb = z.shape[0]
 
+        use_cfg = ctrls is not None
+        tocfg = lambda x: x.repeat(2, 1, 1) if use_cfg else x
+        tocfgblank = lambda x: torch.cat([x, torch.zeros_like(x)], dim=0) if use_cfg else x
+        def fromcfg(x):
+            if use_cfg:
+                xcond, xuncond = x.chunk(2)
+                return xuncond + cfg_scale * (xcond - xuncond)
+            return x
+
+        z = tocfg(z)
+        if ctrls is not None:
+            ctrls = {k: tocfg(v) for k, v in ctrls.items()}
+            ctrl_masks = {k: tocfgblank(v) for k, v in ctrl_masks.items()}
+        if cond is not None:
+            cond = tocfg(cond)
+
         if z is None:
             z = torch.full((1, self.n_codebooks, time_steps), self.mask_token).to(
                 self.device
@@ -727,6 +840,7 @@ class VampNet(at.ml.BaseModel):
             # infer from latents
             # NOTE: this collapses the codebook dimension into the sequence dimension
             logits = self.forward(latents)  # b, prob, seq
+            logits = fromcfg(logits)
 
             if cfg_guidance is not None:
                 logits_cond, logits_uncond = logits[:nb], logits[nb:]
@@ -774,9 +888,6 @@ class VampNet(at.ml.BaseModel):
                 plt.imshow(_mask[0].cpu().numpy())
                 plt.savefig(f"{STEP_FOLDER}/mask.png")
 
-
-
-
             # update the mask, remove conditioning codebooks from the mask
             # add z back into sampled z where the mask was false
             sampled_z = torch.where(
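
Editor's note: the fromcfg helper added here applies standard classifier-free guidance: the batch is doubled into a conditioned half and a control-free half, and their logits are blended. A minimal numeric sketch of the blend (tensor shapes are illustrative only):

import torch

cfg_scale = 3.0
logits = torch.randn(2 * 4, 1024, 300)           # [conditioned batch; unconditioned batch]
cond, uncond = logits.chunk(2)
guided = uncond + cfg_scale * (cond - uncond)    # cfg_scale = 1.0 recovers the conditioned logits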