asigalov61 commited on
Commit
a3ef88a
·
verified ·
1 Parent(s): 9282a07

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +502 -0
app.py ADDED
@@ -0,0 +1,502 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #==================================================================================
2
+
3
+ print('=' * 70)
4
+ print('Loading core Giant Music Transformer modules...')
5
+
6
+ import os
7
+ import sys
8
+
9
+ print('=' * 70)
10
+ print('Loading main Giant Music Transformer modules...')
11
+
12
+ os.environ['USE_FLASH_ATTENTION'] = '1'
13
+
14
+ import torch
15
+
16
+ torch.set_float32_matmul_precision('high')
17
+ torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
18
+ torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
19
+ torch.backends.cuda.enable_mem_efficient_sdp(True)
20
+ torch.backends.cuda.enable_math_sdp(True)
21
+ torch.backends.cuda.enable_flash_sdp(True)
22
+ torch.backends.cuda.enable_cudnn_sdp(True)
23
+
24
+ os.chdir('/home/ubuntu/Giant-Music-Transformer/')
25
+ print("Current working directory: ", os.getcwd())
26
+ sys.path.append(os.getcwd())
27
+ import TMIDIX
28
+
29
+ from midi_to_colab_audio import midi_to_colab_audio
30
+
31
+ from x_transformer_1_23_2 import *
32
+
33
+ import random
34
+
35
+ os.chdir('/home/ubuntu/')
36
+ print('=' * 70)
37
+ print('Loading aux Giant Music Transformer modules...')
38
+
39
+ import matplotlib.pyplot as plt
40
+
41
+ import gradio as gr
42
+
43
+ print('=' * 70)
44
+ print('PyTorch version:', torch.__version__)
45
+ print('=' * 70)
46
+ print('Done!')
47
+ print('Enjoy! :)')
48
+ print('=' * 70)
49
+
50
+ #==================================================================================
51
+
52
+ print('=' * 70)
53
+ print('Instantiating model...')
54
+
55
+ device_type = 'cuda'
56
+ dtype = 'bfloat16'
57
+
58
+ ptdtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
59
+ ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
60
+
61
+ SEQ_LEN = 8192
62
+ PAD_IDX = 19463
63
+
64
+ model = TransformerWrapper(
65
+ num_tokens = PAD_IDX+1,
66
+ max_seq_len = SEQ_LEN,
67
+ attn_layers = Decoder(dim = 2048,
68
+ depth = 8,
69
+ heads = 32,
70
+ rotary_pos_emb = True,
71
+ attn_flash = True
72
+ )
73
+ )
74
+
75
+ model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX)
76
+
77
+ print('=' * 70)
78
+ print('Loading model checkpoint...')
79
+
80
+ model_path = '/home/ubuntu/Giant-Music-Transformer/Models/Medium/Giant_Music_Transformer_Medium_Trained_Model_10446_steps_0.7202_loss_0.8233_acc.pth'
81
+
82
+ model.load_state_dict(torch.load(model_path))
83
+
84
+ print('=' * 70)
85
+
86
+ model.cuda()
87
+ model.eval()
88
+
89
+ print('Done!')
90
+ print('=' * 70)
91
+ print('Model will use', dtype, 'precision...')
92
+ print('=' * 70)
93
+
94
+ #==================================================================================
95
+
96
+ SOUDFONT_PATH = '/usr/share/sounds/sf2/FluidR3_GM.sf2'
97
+
98
+ NUM_OUT_BATCHES = 8
99
+
100
+ #==================================================================================
101
+
102
+ def load_midi(input_midi):
103
+
104
+ raw_score = TMIDIX.midi2single_track_ms_score(input_midi.name)
105
+
106
+ escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)
107
+
108
+ escore_notes = TMIDIX.augment_enhanced_score_notes(escore_notes[0], timings_divider=16)
109
+
110
+ instruments_list = list(set([y[6] for y in escore_notes]))
111
+
112
+ #=======================================================
113
+ # FINAL PROCESSING
114
+ #=======================================================
115
+
116
+ melody_chords = []
117
+
118
+ # Break between compositions / Intro seq
119
+
120
+ if 128 in instruments_list:
121
+ drums_present = 19331 # Yes
122
+ else:
123
+ drums_present = 19330 # No
124
+
125
+ pat = escore_notes[0][6]
126
+
127
+ melody_chords.extend([19461, drums_present, 19332+pat]) # Intro seq
128
+
129
+ #=======================================================
130
+ # MAIN PROCESSING CYCLE
131
+ #=======================================================
132
+
133
+ pe = escore_notes[0]
134
+
135
+ for e in escore_notes:
136
+
137
+ #=======================================================
138
+ # Timings...
139
+
140
+ # Cliping all values...
141
+ delta_time = max(0, min(255, e[1]-pe[1]))
142
+
143
+ # Durations and channels
144
+
145
+ dur = max(0, min(255, e[2]))
146
+ cha = max(0, min(15, e[3]))
147
+
148
+ # Patches
149
+ if cha == 9: # Drums patch will be == 128
150
+ pat = 128
151
+
152
+ else:
153
+ pat = e[6]
154
+
155
+ # Pitches
156
+
157
+ ptc = max(1, min(127, e[4]))
158
+
159
+ # Velocities
160
+
161
+ # Calculating octo-velocity
162
+ vel = max(8, min(127, e[5]))
163
+ velocity = round(vel / 15)-1
164
+
165
+ #=======================================================
166
+ # FINAL NOTE SEQ
167
+ #=======================================================
168
+
169
+ # Writing final note asynchronously
170
+
171
+ dur_vel = (8 * dur) + velocity
172
+ pat_ptc = (129 * pat) + ptc
173
+
174
+ melody_chords.extend([delta_time, dur_vel+256, pat_ptc+2304])
175
+
176
+ pe = e
177
+
178
+ return melody_chords
179
+
180
+ #==================================================================================
181
+
182
+ def save_midi(tokens, batch_number=None):
183
+
184
+ song = tokens
185
+ song_f = []
186
+
187
+ time = 0
188
+ dur = 0
189
+ vel = 90
190
+ pitch = 0
191
+ channel = 0
192
+
193
+ patches = [-1] * 16
194
+
195
+ channels = [0] * 16
196
+ channels[9] = 1
197
+
198
+ for ss in song:
199
+
200
+ if 0 <= ss < 256:
201
+
202
+ time += ss * 16
203
+
204
+ if 256 <= ss < 2304:
205
+
206
+ dur = ((ss-256) // 8) * 16
207
+ vel = (((ss-256) % 8)+1) * 15
208
+
209
+ if 2304 <= ss < 18945:
210
+
211
+ patch = (ss-2304) // 129
212
+
213
+ if patch < 128:
214
+
215
+ if patch not in patches:
216
+ if 0 in channels:
217
+ cha = channels.index(0)
218
+ channels[cha] = 1
219
+ else:
220
+ cha = 15
221
+
222
+ patches[cha] = patch
223
+ channel = patches.index(patch)
224
+ else:
225
+ channel = patches.index(patch)
226
+
227
+ if patch == 128:
228
+ channel = 9
229
+
230
+ pitch = (ss-2304) % 129
231
+
232
+ song_f.append(['note', time, dur, channel, pitch, vel, patch ])
233
+
234
+ patches = [0 if x==-1 else x for x in patches]
235
+
236
+ if batch_number == None:
237
+ fname = '/home/ubuntu/Giant-Music-Transformer-Music-Composition'
238
+
239
+ else:
240
+ fname = '/home/ubuntu/Giant-Music-Transformer-Music-Composition_'+str(batch_number)
241
+
242
+ data = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,
243
+ output_signature = 'Giant Music Transformer',
244
+ output_file_name = fname,
245
+ track_name='Project Los Angeles',
246
+ list_of_MIDI_patches=patches,
247
+ verbose=False
248
+ )
249
+
250
+ return song_f
251
+
252
+ #==================================================================================
253
+
254
+ def generate_music(prime,
255
+ num_gen_tokens,
256
+ num_gen_batches,
257
+ gen_outro,
258
+ gen_drums,
259
+ model_temperature,
260
+ model_sampling_top_p
261
+ ):
262
+
263
+ if not prime:
264
+ inputs = [19461]
265
+
266
+ else:
267
+ inputs = prime
268
+
269
+ if gen_outro:
270
+ inputs.extend([18945])
271
+
272
+ if gen_drums:
273
+ drums = [36, 38]
274
+ drum_pitch = random.choice(drums)
275
+ inputs.extend([0, ((8*8)+6)+256, ((128*129)+drum_pitch)+2304])
276
+
277
+ torch.cuda.empty_cache()
278
+
279
+ inp = [inputs] * num_gen_batches
280
+
281
+ inp = torch.LongTensor(inp).cuda()
282
+
283
+ with ctx:
284
+ with torch.inference_mode():
285
+ out = model.generate(inp,
286
+ num_gen_tokens,
287
+ filter_logits_fn=top_p,
288
+ filter_kwargs={'thres': model_sampling_top_p},
289
+ temperature=model_temperature,
290
+ return_prime=False,
291
+ verbose=False)
292
+
293
+ output = out.tolist()
294
+
295
+ return output
296
+
297
+ #==================================================================================
298
+
299
+ final_composition = []
300
+ generated_batches = []
301
+
302
+ #==================================================================================
303
+
304
+ def generate_callback(input_midi,
305
+ num_prime_tokens,
306
+ num_gen_tokens,
307
+ gen_outro,
308
+ gen_drums,
309
+ model_temperature,
310
+ model_sampling_top_p
311
+ ):
312
+
313
+ global generated_batches
314
+ generated_batches = []
315
+
316
+ if not final_composition and input_midi is not None:
317
+ final_composition.extend(load_midi(input_midi)[:num_prime_tokens])
318
+
319
+ batched_gen_tokens = generate_music(final_composition,
320
+ num_gen_tokens,
321
+ NUM_OUT_BATCHES,
322
+ gen_outro,
323
+ gen_drums,
324
+ model_temperature,
325
+ model_sampling_top_p
326
+ )
327
+
328
+ outputs = []
329
+
330
+ for i in range(len(batched_gen_tokens)):
331
+
332
+ tokens = batched_gen_tokens[i]
333
+
334
+ # Save MIDI to a temporary file
335
+ midi_score = save_midi(tokens, i)
336
+
337
+ # MIDI plot
338
+ midi_plot = TMIDIX.plot_ms_SONG(midi_score, plot_title='Batch # ' + str(i), return_plt=True)
339
+
340
+ # File name
341
+ fname = '/home/ubuntu/Giant-Music-Transformer-Music-Composition_'+str(i)
342
+
343
+ # Save audio to a temporary file
344
+ midi_audio = midi_to_colab_audio(fname + '.mid',
345
+ soundfont_path=SOUDFONT_PATH,
346
+ sample_rate=16000,
347
+ output_for_gradio=True
348
+ )
349
+
350
+ outputs.append(((16000, midi_audio), midi_plot, tokens))
351
+
352
+ return outputs
353
+
354
+ #==================================================================================
355
+
356
+ def generate_callback_wrapper(input_midi,
357
+ num_prime_tokens,
358
+ num_gen_tokens,
359
+ gen_outro,
360
+ gen_drums,
361
+ model_temperature,
362
+ model_sampling_top_p
363
+ ):
364
+
365
+ result = generate_callback(input_midi,
366
+ num_prime_tokens,
367
+ num_gen_tokens,
368
+ gen_outro,
369
+ gen_drums,
370
+ model_temperature,
371
+ model_sampling_top_p
372
+ )
373
+
374
+ generated_batches.extend([sublist[2] for sublist in result])
375
+
376
+ return tuple(item for sublist in result for item in sublist[:2])
377
+
378
+ #==================================================================================
379
+
380
+ def add_batch(batch_number):
381
+
382
+ final_composition.extend(generated_batches[batch_number])
383
+
384
+ # Save MIDI to a temporary file
385
+ midi_score = save_midi(final_composition)
386
+
387
+ # MIDI plot
388
+ midi_plot = TMIDIX.plot_ms_SONG(midi_score, plot_title='Giant Music Transformer Composition', return_plt=True)
389
+
390
+ # File name
391
+ fname = 'Giant-Music-Transformer-Music-Composition'
392
+
393
+ # Save audio to a temporary file
394
+ midi_audio = midi_to_colab_audio(fname + '.mid',
395
+ soundfont_path=SOUDFONT_PATH,
396
+ sample_rate=16000,
397
+ output_for_gradio=True
398
+ )
399
+
400
+ return (16000, midi_audio), midi_plot, fname+'.mid'
401
+
402
+ #==================================================================================
403
+
404
+ def remove_batch(batch_number, num_tokens):
405
+
406
+ global final_composition
407
+
408
+ if len(final_composition) > num_tokens:
409
+ final_composition = final_composition[:-num_tokens]
410
+
411
+ # Save MIDI to a temporary file
412
+ midi_score = save_midi(final_composition)
413
+
414
+ # MIDI plot
415
+ midi_plot = TMIDIX.plot_ms_SONG(midi_score, plot_title='Giant Music Transformer Composition', return_plt=True)
416
+
417
+ # File name
418
+ fname = 'Giant-Music-Transformer-Music-Composition'
419
+
420
+ # Save audio to a temporary file
421
+ midi_audio = midi_to_colab_audio(fname + '.mid',
422
+ soundfont_path=SOUDFONT_PATH,
423
+ sample_rate=16000,
424
+ output_for_gradio=True
425
+ )
426
+
427
+ return (16000, midi_audio), midi_plot, fname+'.mid'
428
+
429
+ #==================================================================================
430
+
431
+ def reset():
432
+ global final_composition
433
+ final_composition = []
434
+
435
+ #==================================================================================
436
+
437
+ with gr.Blocks() as demo:
438
+
439
+ gr.Markdown("## Upload your MIDI or select a sample example MIDI")
440
+
441
+ input_midi = gr.File(label="Input MIDI", file_types=[".midi", ".mid", ".kar"])
442
+ clear_btn = gr.ClearButton(input_midi, variant="stop", value="Reset")
443
+
444
+ clear_btn.click(reset)
445
+
446
+ gr.Markdown("## Generate")
447
+
448
+ num_prime_tokens = gr.Slider(15, 6999, value=600, step=3, label="Number of prime tokens")
449
+ num_gen_tokens = gr.Slider(15, 1200, value=600, step=3, label="Number of tokens to generate")
450
+ gen_outro = gr.Checkbox(value=False, label="Try to generate an outro")
451
+ gen_drums = gr.Checkbox(value=False, label="Try to introduce drums")
452
+ model_temperature = gr.Slider(0.1, 1, value=0.9, step=0.01, label="Model temperature")
453
+ model_sampling_top_p = gr.Slider(0.1, 1, value=0.96, step=0.01, label="Model sampling top p value")
454
+
455
+ generate_btn = gr.Button("Generate", variant="primary")
456
+
457
+ gr.Markdown("## Select batch")
458
+
459
+ outputs = []
460
+
461
+ for i in range(NUM_OUT_BATCHES):
462
+ with gr.Tab(f"Batch # {i}") as tab:
463
+
464
+ audio_output = gr.Audio(label=f"Batch # {i} MIDI Audio", format="mp3", elem_id="midi_audio")
465
+ plot_output = gr.Plot(label=f"Batch # {i} MIDI Plot")
466
+
467
+ outputs.extend([audio_output, plot_output])
468
+
469
+ generate_btn.click(generate_callback_wrapper,
470
+ [input_midi,
471
+ num_prime_tokens,
472
+ num_gen_tokens,
473
+ gen_outro,
474
+ gen_drums,
475
+ model_temperature,
476
+ model_sampling_top_p
477
+ ],
478
+ outputs
479
+ )
480
+
481
+ gr.Markdown("## Add/Remove batch")
482
+
483
+ batch_number = gr.Slider(0, NUM_OUT_BATCHES, value=0, step=1, label="Batch number to add/remove")
484
+
485
+ add_btn = gr.Button("Add batch", variant="primary")
486
+ remove_btn = gr.Button("Remove batch", variant="stop")
487
+
488
+ final_audio_output = gr.Audio(label="Final MIDI audio", format="mp3", elem_id="midi_audio")
489
+ final_plot_output = gr.Plot(label="Final MIDI plot")
490
+ final_file_output = gr.File(label="Final MIDI file")
491
+
492
+ add_btn.click(add_batch, inputs=[batch_number],
493
+ outputs=[final_audio_output, final_plot_output, final_file_output]
494
+ )
495
+
496
+ remove_btn.click(remove_batch, inputs=[batch_number, num_gen_tokens],
497
+ outputs=[final_audio_output, final_plot_output, final_file_output]
498
+ )
499
+
500
+ demo.unload(lambda: print("User ended session."))
501
+
502
+ demo.launch(share=True)