asigalov61 commited on
Commit
4154701
·
verified ·
1 Parent(s): 4d440f4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -44
app.py CHANGED
@@ -37,6 +37,7 @@ print('Loading aux Giant Music Transformer modules...')
37
  import matplotlib.pyplot as plt
38
 
39
  import gradio as gr
 
40
 
41
  print('=' * 70)
42
  print('PyTorch version:', torch.__version__)
@@ -47,50 +48,6 @@ print('=' * 70)
47
 
48
  #==================================================================================
49
 
50
- print('=' * 70)
51
- print('Instantiating model...')
52
-
53
- device_type = 'cuda'
54
- dtype = 'bfloat16'
55
-
56
- ptdtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
57
- ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
58
-
59
- SEQ_LEN = 8192
60
- PAD_IDX = 19463
61
-
62
- model = TransformerWrapper(
63
- num_tokens = PAD_IDX+1,
64
- max_seq_len = SEQ_LEN,
65
- attn_layers = Decoder(dim = 2048,
66
- depth = 8,
67
- heads = 32,
68
- rotary_pos_emb = True,
69
- attn_flash = True
70
- )
71
- )
72
-
73
- model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX)
74
-
75
- print('=' * 70)
76
- print('Loading model checkpoint...')
77
-
78
- model_path = 'Giant-Music-Transformer/Models/Medium/Giant_Music_Transformer_Medium_Trained_Model_10446_steps_0.7202_loss_0.8233_acc.pth'
79
-
80
- model.load_state_dict(torch.load(model_path))
81
-
82
- print('=' * 70)
83
-
84
- model.cuda()
85
- model.eval()
86
-
87
- print('Done!')
88
- print('=' * 70)
89
- print('Model will use', dtype, 'precision...')
90
- print('=' * 70)
91
-
92
- #==================================================================================
93
-
94
  SOUDFONT_PATH = 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7'
95
 
96
  NUM_OUT_BATCHES = 8
@@ -249,6 +206,7 @@ def save_midi(tokens, batch_number=None):
249
 
250
  #==================================================================================
251
 
 
252
  def generate_music(prime,
253
  num_gen_tokens,
254
  num_gen_batches,
@@ -258,6 +216,55 @@ def generate_music(prime,
258
  model_sampling_top_p
259
  ):
260
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
  if not prime:
262
  inputs = [19461]
263
 
@@ -290,6 +297,9 @@ def generate_music(prime,
290
 
291
  output = out.tolist()
292
 
 
 
 
293
  return output
294
 
295
  #==================================================================================
 
37
  import matplotlib.pyplot as plt
38
 
39
  import gradio as gr
40
+ import spaces
41
 
42
  print('=' * 70)
43
  print('PyTorch version:', torch.__version__)
 
48
 
49
  #==================================================================================
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  SOUDFONT_PATH = 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7'
52
 
53
  NUM_OUT_BATCHES = 8
 
206
 
207
  #==================================================================================
208
 
209
+ @spaces.GPU
210
  def generate_music(prime,
211
  num_gen_tokens,
212
  num_gen_batches,
 
216
  model_sampling_top_p
217
  ):
218
 
219
+
220
+ #==============================================================================
221
+
222
+ print('=' * 70)
223
+ print('Instantiating model...')
224
+
225
+ device_type = 'cuda'
226
+ dtype = 'bfloat16'
227
+
228
+ ptdtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
229
+ ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
230
+
231
+ SEQ_LEN = 8192
232
+ PAD_IDX = 19463
233
+
234
+ model = TransformerWrapper(
235
+ num_tokens = PAD_IDX+1,
236
+ max_seq_len = SEQ_LEN,
237
+ attn_layers = Decoder(dim = 2048,
238
+ depth = 8,
239
+ heads = 32,
240
+ rotary_pos_emb = True,
241
+ attn_flash = True
242
+ )
243
+ )
244
+
245
+ model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX)
246
+
247
+ print('=' * 70)
248
+ print('Loading model checkpoint...')
249
+
250
+ model_path = 'Giant-Music-Transformer/Models/Medium/Giant_Music_Transformer_Medium_Trained_Model_10446_steps_0.7202_loss_0.8233_acc.pth'
251
+
252
+ model.load_state_dict(torch.load(model_path))
253
+
254
+ print('=' * 70)
255
+
256
+ model.cuda()
257
+ model.eval()
258
+
259
+ print('Done!')
260
+ print('=' * 70)
261
+ print('Model will use', dtype, 'precision...')
262
+ print('=' * 70)
263
+
264
+ #==============================================================================
265
+
266
+ print('Generating...')
267
+
268
  if not prime:
269
  inputs = [19461]
270
 
 
297
 
298
  output = out.tolist()
299
 
300
+ print('Done!')
301
+ print('=' * 70)
302
+
303
  return output
304
 
305
  #==================================================================================