asigalov61 committed · verified
Commit 32bd50f · 1 Parent(s): ab0089a

Update app.py

Files changed (1):
  app.py (+35 -37)
app.py CHANGED
@@ -33,55 +33,50 @@ import TMIDIX
 import matplotlib.pyplot as plt
 
 # =================================================================================================
-
-@spaces.GPU
-def InpaintPitches(input_midi, input_num_of_notes, input_patch_number):
-    print('=' * 70)
-    print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
-    start_time = reqtime.time()
 
-    print('Loading model...')
+print('Loading model...')
 
-    SEQ_LEN = 8192 # Models seq len
-    PAD_IDX = 19463 # Models pad index
-    DEVICE = 'cuda' # 'cpu'
+SEQ_LEN = 8192 # Models seq len
+PAD_IDX = 19463 # Models pad index
+DEVICE = 'cuda' # 'cpu'
 
-    # instantiate the model
+# instantiate the model
 
-    model = TransformerWrapper(
-        num_tokens = PAD_IDX+1,
-        max_seq_len = SEQ_LEN,
-        attn_layers = Decoder(dim = 2048, depth = 8, heads = 32, rotary_pos_emb = True, attn_flash = True)
-    )
-
-    model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)
+model = TransformerWrapper(
+    num_tokens = PAD_IDX+1,
+    max_seq_len = SEQ_LEN,
+    attn_layers = Decoder(dim = 2048, depth = 8, heads = 32, rotary_pos_emb = True, attn_flash = True)
+)
 
-    print('=' * 70)
+model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)
 
-    print('Loading model checkpoint...')
+print('=' * 70)
 
-    model_checkpoint = hf_hub_download(repo_id='asigalov61/Giant-Music-Transformer',
-                                       filename='Giant_Music_Transformer_Medium_Trained_Model_42174_steps_0.5211_loss_0.8542_acc.pth'
-                                       )
-
-    model.load_state_dict(torch.load(model_checkpoint, map_location='cpu', weights_only=True))
-
-    model = torch.compile(model, mode='max-autotune')
-
-    print('=' * 70)
+print('Loading model checkpoint...')
 
-    model.to(DEVICE)
-    model.eval()
+model_checkpoint = hf_hub_download(repo_id='asigalov61/Giant-Music-Transformer',
+                                   filename='Giant_Music_Transformer_Medium_Trained_Model_42174_steps_0.5211_loss_0.8542_acc.pth'
+                                   )
 
-    if DEVICE == 'cpu':
-        dtype = torch.bfloat16
-    else:
-        dtype = torch.bfloat16
+model.load_state_dict(torch.load(model_checkpoint, map_location='cpu', weights_only=True))
 
-    ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype)
+print('=' * 70)
 
-    print('Done!')
+if DEVICE == 'cpu':
+    dtype = torch.bfloat16
+else:
+    dtype = torch.bfloat16
+
+ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype)
+
+print('Done!')
+print('=' * 70)
+
+@spaces.GPU
+def InpaintPitches(input_midi, input_num_of_notes, input_patch_number):
     print('=' * 70)
+    print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
+    start_time = reqtime.time()
 
     fn = os.path.basename(input_midi.name)
     fn1 = fn.split('.')[0]
@@ -94,6 +89,9 @@ def InpaintPitches(input_midi, input_num_of_notes, input_patch_number):
     print('Req patch number:', input_patch_number)
     print('-' * 70)
 
+    model.to(DEVICE)
+    model.eval()
+
     #===============================================================================
     raw_score = TMIDIX.midi2single_track_ms_score(input_midi.name)
 
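
For context, the new layout follows the usual ZeroGPU pattern: build the model and load its checkpoint once at module import (on CPU), and attach the GPU only inside the @spaces.GPU-decorated request handler. The sketch below is a minimal illustration of that pattern, not the Space's actual inpainting code; it assumes the x_transformers-style TransformerWrapper / Decoder / AutoregressiveWrapper classes (the Space bundles its own copy of them) and uses a hypothetical run_inference handler with an illustrative model.generate call.

# Minimal sketch (not the Space's actual code): module-level model load,
# GPU work confined to the @spaces.GPU handler.
# Assumption: the wrapper classes are taken from the pip x_transformers package;
# run_inference and its body are hypothetical placeholders.

import spaces
import torch
from huggingface_hub import hf_hub_download
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

SEQ_LEN = 8192    # model sequence length
PAD_IDX = 19463   # model pad index
DEVICE = 'cuda'   # ZeroGPU attaches CUDA only while a @spaces.GPU call runs

# Build the model and load the checkpoint once, on CPU, at import time.
model = TransformerWrapper(
    num_tokens=PAD_IDX + 1,
    max_seq_len=SEQ_LEN,
    attn_layers=Decoder(dim=2048, depth=8, heads=32, rotary_pos_emb=True, attn_flash=True)
)
model = AutoregressiveWrapper(model, ignore_index=PAD_IDX)

ckpt_path = hf_hub_download(
    repo_id='asigalov61/Giant-Music-Transformer',
    filename='Giant_Music_Transformer_Medium_Trained_Model_42174_steps_0.5211_loss_0.8542_acc.pth'
)
model.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True))

# Mixed-precision context; it is only entered later, inside the handler.
ctx = torch.amp.autocast(device_type=DEVICE, dtype=torch.bfloat16)

@spaces.GPU
def run_inference(prompt_tokens, num_new_tokens=32):
    """Hypothetical request handler: move the model to the GPU per call and generate."""
    model.to(DEVICE)
    model.eval()

    x = torch.tensor([prompt_tokens], dtype=torch.long, device=DEVICE)

    with ctx:
        out = model.generate(x, num_new_tokens, temperature=0.9)  # illustrative generation call

    return out[0].tolist()

Keeping the load on CPU at import time means startup happens during ZeroGPU's CPU-only phase; only model.to(DEVICE), model.eval() and generation run while a GPU is attached, which is the split this commit moves to.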