asigalov61 commited on
Commit
48a1204
·
verified ·
1 Parent(s): bac0956

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -12
app.py CHANGED
@@ -317,7 +317,7 @@ def tokens_to_MIDI(tokens, MIDI_name):
317
  print('Done!')
318
  print('=' * 70)
319
 
320
- return audio, new_fn
321
 
322
  # =================================================================================================
323
 
@@ -381,27 +381,71 @@ def CompareMIDIs(input_src_midi, input_trg_midi, input_sampling_resolution, inpu
381
 
382
  #===============================================================================
383
 
384
- toekns, notes = read_MIDI(input_midi.name)
385
-
386
 
 
 
 
387
  #==================================================================
388
 
389
  print('=' * 70)
390
- print('Number of tokens:', len(toekns))
391
- print('Number of notes:', len(notes))
392
- print('Sample output events', toekns[:5])
393
- print('=' * 70)
394
- print('Generating...')
395
 
396
- temperature = 0.85
397
 
 
 
398
  print('=' * 70)
399
  print('Giant Music Transformer MIDI Comparator')
400
  print('=' * 70)
401
 
402
- #==========================================================================
403
-
404
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
405
 
406
  print('=' * 70)
407
  print('Done!')
 
317
  print('Done!')
318
  print('=' * 70)
319
 
320
+ return new_fn, audio
321
 
322
  # =================================================================================================
323
 
 
381
 
382
  #===============================================================================
383
 
384
+ print('Loading MIDIs...')
 
385
 
386
+ src_tokens, src_notes = read_MIDI(input_src_midi.name)
387
+ trg_tokens, trg_notes = read_MIDI(input_trg_midi.name)
388
+
389
  #==================================================================
390
 
391
  print('=' * 70)
392
+ print('Number of src tokens:', len(src_tokens))
393
+ print('Number of src notes:', len(src_notes))
394
+ print('Number of trg tokens:', len(trg_tokens))
395
+ print('Number of trg notes:', len(trg_notes))
 
396
 
397
+ #==========================================================================
398
 
399
+ print('=' * 70)
400
+ print('Comparing...')
401
  print('=' * 70)
402
  print('Giant Music Transformer MIDI Comparator')
403
  print('=' * 70)
404
 
405
+ input_src_tokens = src_tokens
406
+ input_trg_tokens = trg_tokens
407
+
408
+ sampling_resolution = input_sampling_overlap * 3
409
+ sampling_overlap = input_sampling_overlap * 3
410
+
411
+ comp_length = (min(len(input_src_tokens), len(input_trg_tokens)) // sampling_resolution) * sampling_resolution
412
+
413
+ comp_cos_sims = []
414
+
415
+ for i in range(0, comp_length, sampling_resolution-sampling_overlap):
416
+
417
+ torch.cuda.empty_cache()
418
+
419
+ inp = [input_src_tokens[i:i+sampling_resolution]]
420
+
421
+ inp = torch.LongTensor(inp).cuda()
422
+
423
+ with ctx:
424
+ out = model(inp)
425
+ cache = out[2]
426
+ src_embedings = cache.layer_hiddens[-1]
427
+
428
+ torch.cuda.empty_cache()
429
+
430
+ inp = [input_trg_tokens[i:i+sampling_resolution]]
431
+
432
+ inp = torch.LongTensor(inp).cuda()
433
+
434
+ with ctx:
435
+ out = model(inp)
436
+ cache = out[2]
437
+ trg_embedings = cache.layer_hiddens[-1]
438
+
439
+
440
+ cos_sim = pairwise.cosine_similarity([src_embedings.cpu().detach().numpy()[0].flatten()],
441
+ [trg_embedings.cpu().detach().numpy()[0].flatten()]
442
+ ).tolist()[0][0]
443
+
444
+ comp_cos_sims.append(cos_sim)
445
+
446
+ min_cos_sim = min(comp_cos_sims)
447
+ avg_cos_sim = sum(comp_cos_sims) / len(comp_cos_sims)
448
+ max_cos_sim = max(comp_cos_sims)
449
 
450
  print('=' * 70)
451
  print('Done!')