asigalov61 committed
Commit 32da081 · verified · 1 Parent(s): f85efc3

Update app.py
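This commit hoists the one-time model setup out of the CompareMIDIs request handler to module level: the sequence-length/pad-index constants, the TransformerWrapper + AutoregressiveWrapper construction, the checkpoint download and load_state_dict() call, and the torch.compile() step now run once at Space startup instead of on every request. CompareMIDIs keeps only the per-request work: logging the request start time and moving the already-prepared model to the GPU.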

Files changed (1): app.py (+30 -28)
app.py CHANGED
@@ -325,43 +325,45 @@ def tokens_to_MIDI(tokens, MIDI_name):
     return new_fn, song_f, audio
 
 # =================================================================================================
-
-@spaces.GPU
-def CompareMIDIs(input_src_midi, input_trg_midi, input_sampling_resolution, input_sampling_overlap):
-
-    print('=' * 70)
-    print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
-    start_time = reqtime.time()
 
-    print('Loading model...')
+print('Loading model...')
 
-    SEQ_LEN = 8192 # Models seq len
-    PAD_IDX = 19463 # Models pad index
-    DEVICE = 'cuda' # 'cuda'
+SEQ_LEN = 8192 # Models seq len
+PAD_IDX = 19463 # Models pad index
+DEVICE = 'cuda' # 'cuda'
 
-    # instantiate the model
+# instantiate the model
 
-    model = TransformerWrapper(
-        num_tokens = PAD_IDX+1,
-        max_seq_len = SEQ_LEN,
-        attn_layers = Decoder(dim = 2048, depth = 8, heads = 32, rotary_pos_emb = True, attn_flash = True)
-    )
-
-    model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)
+model = TransformerWrapper(
+    num_tokens = PAD_IDX+1,
+    max_seq_len = SEQ_LEN,
+    attn_layers = Decoder(dim = 2048, depth = 8, heads = 32, rotary_pos_emb = True, attn_flash = True)
+)
 
-    print('=' * 70)
+model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)
 
-    print('Loading model checkpoint...')
+print('=' * 70)
 
-    model_checkpoint = hf_hub_download(repo_id='asigalov61/Giant-Music-Transformer',
-                                       filename='Giant_Music_Transformer_Medium_Trained_Model_42174_steps_0.5211_loss_0.8542_acc.pth'
-                                       )
-
-    model.load_state_dict(torch.load(model_checkpoint, map_location='cpu', weights_only=True))
-
-    model = torch.compile(model, mode='max-autotune')
+print('Loading model checkpoint...')
+
+model_checkpoint = hf_hub_download(repo_id='asigalov61/Giant-Music-Transformer',
+                                   filename='Giant_Music_Transformer_Medium_Trained_Model_42174_steps_0.5211_loss_0.8542_acc.pth'
+                                   )
+
+model.load_state_dict(torch.load(model_checkpoint, map_location='cpu', weights_only=True))
+
+model = torch.compile(model, mode='max-autotune')
+
+print('=' * 70)
+
+# =================================================================================================
+
+@spaces.GPU
+def CompareMIDIs(input_src_midi, input_trg_midi, input_sampling_resolution, input_sampling_overlap):
 
     print('=' * 70)
+    print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
+    start_time = reqtime.time()
 
     model.to(DEVICE)
     model.eval()
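
For context, this move follows the standard pattern for Hugging Face ZeroGPU Spaces: module-level code runs once, on CPU, when the Space boots, and a GPU is attached only while a function decorated with @spaces.GPU is executing. Loading the checkpoint with map_location='cpu' is what lets the setup run at startup, when no GPU is available yet; the handler then only has to move the prepared model onto the device. Below is a minimal sketch of that split; the run() function and the toy nn.Sequential model are illustrative stand-ins, not the Space's actual Giant Music Transformer code.

import spaces
import torch
import torch.nn as nn

DEVICE = 'cuda'

# One-time, CPU-side setup at import time: build the model (and, in the real
# app, download and load the checkpoint) so per-request latency excludes it.
# NOTE: this tiny model is a hypothetical stand-in for illustration only.
model = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 16))
model.eval()

@spaces.GPU
def run(x: torch.Tensor) -> torch.Tensor:
    # Per-request, GPU-side work only: ZeroGPU attaches a GPU for the
    # duration of this call, so the .to(DEVICE) move happens here.
    model.to(DEVICE)
    with torch.no_grad():
        return model(x.to(DEVICE)).cpu()

One caveat worth noting: torch.compile(model, mode='max-autotune') returns a module that compiles lazily, on the first forward pass with real inputs. Hoisting it to startup therefore saves the repeated construction and checkpoint-loading work on every request, but the first call into CompareMIDIs will still pay the compilation cost.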