Spaces:
Running
on
Zero
Running
on
Zero
asigalov61
committed on
Update app.py
Browse files
app.py
CHANGED
@@ -317,7 +317,7 @@ def tokens_to_MIDI(tokens, MIDI_name):
|
|
317 |
print('Done!')
|
318 |
print('=' * 70)
|
319 |
|
320 |
-
return
|
321 |
|
322 |
# =================================================================================================
|
323 |
|
@@ -381,27 +381,71 @@ def CompareMIDIs(input_src_midi, input_trg_midi, input_sampling_resolution, inpu
|
|
381 |
|
382 |
#===============================================================================
|
383 |
|
384 |
-
|
385 |
-
|
386 |
|
|
|
|
|
|
|
387 |
#==================================================================
|
388 |
|
389 |
print('=' * 70)
|
390 |
-
print('Number of tokens:', len(
|
391 |
-
print('Number of notes:', len(
|
392 |
-
print('
|
393 |
-
print('
|
394 |
-
print('Generating...')
|
395 |
|
396 |
-
|
397 |
|
|
|
|
|
398 |
print('=' * 70)
|
399 |
print('Giant Music Transformer MIDI Comparator')
|
400 |
print('=' * 70)
|
401 |
|
402 |
-
|
403 |
-
|
404 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
405 |
|
406 |
print('=' * 70)
|
407 |
print('Done!')
|
|
|
317 |
print('Done!')
|
318 |
print('=' * 70)
|
319 |
|
320 |
+
return new_fn, audio
|
321 |
|
322 |
# =================================================================================================
|
323 |
|
|
|
381 |
|
382 |
#===============================================================================
|
383 |
|
384 |
+
print('Loading MIDIs...')
|
|
|
385 |
|
386 |
+
src_tokens, src_notes = read_MIDI(input_src_midi.name)
|
387 |
+
trg_tokens, trg_notes = read_MIDI(input_trg_midi.name)
|
388 |
+
|
389 |
#==================================================================
|
390 |
|
391 |
print('=' * 70)
|
392 |
+
print('Number of src tokens:', len(src_tokens))
|
393 |
+
print('Number of src notes:', len(src_notes))
|
394 |
+
print('Number of trg tokens:', len(trg_tokens))
|
395 |
+
print('Number of trg notes:', len(trg_notes))
|
|
|
396 |
|
397 |
+
#==========================================================================
|
398 |
|
399 |
+
print('=' * 70)
|
400 |
+
print('Comparing...')
|
401 |
print('=' * 70)
|
402 |
print('Giant Music Transformer MIDI Comparator')
|
403 |
print('=' * 70)
|
404 |
|
405 |
+
input_src_tokens = src_tokens
|
406 |
+
input_trg_tokens = trg_tokens
|
407 |
+
|
408 |
+
sampling_resolution = input_sampling_resolution * 3
|
409 |
+
sampling_overlap = input_sampling_overlap * 3
|
410 |
+
|
411 |
+
comp_length = (min(len(input_src_tokens), len(input_trg_tokens)) // sampling_resolution) * sampling_resolution
|
412 |
+
|
413 |
+
comp_cos_sims = []
|
414 |
+
|
415 |
+
for i in range(0, comp_length, sampling_resolution-sampling_overlap):
|
416 |
+
|
417 |
+
torch.cuda.empty_cache()
|
418 |
+
|
419 |
+
inp = [input_src_tokens[i:i+sampling_resolution]]
|
420 |
+
|
421 |
+
inp = torch.LongTensor(inp).cuda()
|
422 |
+
|
423 |
+
with ctx:
|
424 |
+
out = model(inp)
|
425 |
+
cache = out[2]
|
426 |
+
src_embedings = cache.layer_hiddens[-1]
|
427 |
+
|
428 |
+
torch.cuda.empty_cache()
|
429 |
+
|
430 |
+
inp = [input_trg_tokens[i:i+sampling_resolution]]
|
431 |
+
|
432 |
+
inp = torch.LongTensor(inp).cuda()
|
433 |
+
|
434 |
+
with ctx:
|
435 |
+
out = model(inp)
|
436 |
+
cache = out[2]
|
437 |
+
trg_embedings = cache.layer_hiddens[-1]
|
438 |
+
|
439 |
+
|
440 |
+
cos_sim = pairwise.cosine_similarity([src_embedings.cpu().detach().numpy()[0].flatten()],
|
441 |
+
[trg_embedings.cpu().detach().numpy()[0].flatten()]
|
442 |
+
).tolist()[0][0]
|
443 |
+
|
444 |
+
comp_cos_sims.append(cos_sim)
|
445 |
+
|
446 |
+
min_cos_sim = min(comp_cos_sims)
|
447 |
+
avg_cos_sim = sum(comp_cos_sims) / len(comp_cos_sims)
|
448 |
+
max_cos_sim = max(comp_cos_sims)
|
449 |
|
450 |
print('=' * 70)
|
451 |
print('Done!')
|