suric commited on
Commit
48860c6
·
1 Parent(s): 694f61a

update apps and examples

Browse files
.gitattributes CHANGED
@@ -40,3 +40,5 @@ data/audio/twinkle_twinkle_little_stars_mozart.mp3 filter=lfs diff=lfs merge=lfs
40
  *.mp3 !text !filter !merge !diff
41
  *.jpeg filter=lfs diff=lfs merge=lfs -text
42
  data/audio/old_town_road20sec.mp3 filter=lfs diff=lfs merge=lfs -text
 
 
 
40
  *.mp3 !text !filter !merge !diff
41
  *.jpeg filter=lfs diff=lfs merge=lfs -text
42
  data/audio/old_town_road20sec.mp3 filter=lfs diff=lfs merge=lfs -text
43
+ data/audio/Suri's[[:space:]]Improv.mp3 filter=lfs diff=lfs merge=lfs -text
44
+ data/audio/like_no_tomorrow_20sec.wav filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,10 +1,20 @@
1
  import os
2
-
3
  import gradio as gr
 
4
 
5
- from gradio_components.image import generate_caption, improve_prompt
 
6
  from gradio_components.prediction import predict, transcribe
7
 
 
 
 
 
 
 
 
 
 
8
  theme = gr.themes.Glass(
9
  primary_hue="fuchsia",
10
  secondary_hue="indigo",
@@ -67,463 +77,304 @@ theme = gr.themes.Glass(
67
  )
68
 
69
 
70
- _AUDIOCRAFT_MODELS = [
71
- "facebook/musicgen-melody",
72
- "facebook/musicgen-medium",
73
- "facebook/musicgen-small",
74
- "facebook/musicgen-large",
75
- "facebook/musicgen-melody-large",
76
- "facebook/audiogen-medium",
77
- ]
78
-
79
-
80
- def generate_prompt(difficulty, style):
81
- _DIFFICULTY_MAPPIN = {
82
- "Easy": "beginner player",
83
- "Medium": "player who has 2-3 years experience",
84
- "Hard": "player who has more than 4 years experiences",
85
- }
86
- prompt = "piano only music for a {} to practice with the touch of {}".format(
87
- _DIFFICULTY_MAPPIN[difficulty], style
88
- )
89
- return prompt
90
-
91
-
92
- def toggle_melody_condition(melody_condition):
93
- if melody_condition:
94
- return gr.Audio(
95
- sources=["microphone", "upload"],
96
- label="Record or upload your audio",
97
- show_label=True,
98
- visible=True,
99
- )
100
- else:
101
- return gr.Audio(
102
- sources=["microphone", "upload"],
103
- label="Record or upload your audio",
104
- show_label=True,
105
- visible=False,
106
- )
107
-
108
-
109
- def toggle_custom_prompt(customize, difficulty, style):
110
- if customize:
111
- return gr.Textbox(label="Type your prompt", interactive=True, visible=True)
112
- else:
113
- prompt = generate_prompt(difficulty, style)
114
- return gr.Textbox(
115
- label="Generated Prompt", value=prompt, interactive=False, visible=True
116
- )
117
-
118
-
119
- def show_caption(show_caption_condition, description, prompt):
120
- if show_caption_condition:
121
- return (
122
- gr.Textbox(
123
- label="Image Caption",
124
- value=description,
125
- interactive=False,
126
- show_label=True,
127
- visible=True,
128
- ),
129
- gr.Textbox(
130
- label="Generated Prompt",
131
- value=prompt,
132
- interactive=True,
133
- show_label=True,
134
- visible=True,
135
- ),
136
- gr.Button("Generate Music", interactive=True, visible=True),
137
- )
138
- else:
139
- return (
140
- gr.Textbox(
141
- label="Image Caption",
142
- value=description,
143
- interactive=False,
144
- show_label=True,
145
- visible=False,
146
- ),
147
- gr.Textbox(
148
- label="Generated Prompt",
149
- value=prompt,
150
- interactive=True,
151
- show_label=True,
152
- visible=False,
153
- ),
154
- gr.Button(label="Generate Music", interactive=True, visible=True),
155
- )
156
-
157
 
158
- def optimize_fn(prompt):
159
- message_object, prompt = improve_prompt(prompt)
160
  return prompt
161
 
162
 
163
- def display_prompt(prompt):
164
- return gr.Textbox(
165
- label="Generated Prompt", value=prompt, interactive=False, visible=True
166
- )
167
-
168
-
169
- def post_submit(show_caption, model_path, image_input):
170
- _, description, prompt = generate_caption(image_input, model_path)
171
- return (
172
- gr.Textbox(
173
- label="Image Caption",
174
- value=description,
175
- interactive=False,
176
- show_label=True,
177
- visible=show_caption,
178
- ),
179
- gr.Textbox(
180
- label="Generated Prompt",
181
- value=prompt,
182
- interactive=True,
183
- show_label=True,
184
- visible=show_caption,
185
- ),
186
- gr.Button("Generate Music", interactive=True, visible=True),
187
- )
188
-
189
-
190
- def UI():
191
  with gr.Blocks() as demo:
192
- with gr.Tab("Generate Music by melody"):
193
  with gr.Row():
194
  with gr.Column():
195
  with gr.Row():
196
  model_path = gr.Dropdown(
197
- choices=_AUDIOCRAFT_MODELS,
198
  label="Select the model",
199
- value="facebook/musicgen-melody-large",
200
- )
201
- with gr.Row():
202
- duration = gr.Slider(
203
- minimum=10,
204
- maximum=60,
205
- value=10,
206
- label="Duration",
207
- interactive=True,
208
  )
 
209
  with gr.Row():
210
- topk = gr.Number(label="Top-k", value=250, interactive=True)
211
- topp = gr.Number(label="Top-p", value=0, interactive=True)
212
- temperature = gr.Number(
213
- label="Temperature", value=1.0, interactive=True
214
- )
215
- sample_rate = gr.Number(
216
- label="output music sample rate",
217
- value=32000,
218
  interactive=True,
 
219
  )
220
- difficulty = gr.Radio(
221
- ["Easy", "Medium", "Hard"],
222
- label="Difficulty",
223
- value="Easy",
 
224
  interactive=True,
225
  )
226
- style = gr.Radio(
227
- ["Jazz", "Classical Music", "Hip Hop"],
228
- value="Classical Music",
 
 
229
  label="music genre",
230
  interactive=True,
231
  )
 
 
 
232
 
233
- def update_prompt(difficulty, style):
234
- return gr.Textbox(
235
- label="",
236
- value=generate_prompt(difficulty, style),
237
- interactive=False,
238
- visible=False)
239
- customize = gr.Checkbox(
240
- label="Customize the prompt", interactive=True, value=False
241
- )
242
-
243
- _init_prompt = generate_prompt(difficulty.value, style.value)
244
- prompt = gr.Textbox(
245
- label="",
246
- value=_init_prompt,
247
- interactive=False,
248
- visible=False,
249
- )
250
- customize.change(
251
- fn=toggle_custom_prompt,
252
- inputs=[customize, difficulty, style],
253
- outputs=prompt,
254
- )
255
- difficulty.change(
256
- update_prompt,
257
- inputs=[difficulty, style],
258
- outputs=prompt
259
- )
260
- style.change(
261
- update_prompt,
262
- inputs=[difficulty, style],
263
- outputs=prompt
264
- )
265
- print(prompt)
266
- with gr.Column():
267
- optimize = gr.Button(
268
- "Optimize the prompt", interactive=True
269
- )
270
- with gr.Column():
271
- show_prompt = gr.Button("Show the prompt", interactive=True)
272
- prompt_text = gr.Textbox(
273
- "Optimized Prompt", interactive=False, visible=False
274
- )
275
- optimize.click(optimize_fn, inputs=[prompt], outputs=prompt)
276
- show_prompt.click(
277
- display_prompt, inputs=[prompt], outputs=prompt_text
278
- )
279
 
280
  with gr.Column():
281
  with gr.Row():
282
- melody = gr.Audio(
283
- sources=["microphone", "upload"],
284
- label="Record or upload your audio",
285
- # interactive=True,
286
- show_label=True,
287
- )
288
- with gr.Row():
289
  submit = gr.Button("Generate Music")
290
- output_audio = gr.Audio(
291
- "listen to the generated music", type="filepath"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
292
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
293
  with gr.Row():
294
- transcribe_button = gr.Button("Transcribe")
295
- d = gr.DownloadButton("Download the file", visible=False)
296
- transcribe_button.click(
297
- transcribe, inputs=[output_audio], outputs=d
298
- )
299
 
300
- submit.click(
301
- fn=predict,
302
- inputs=[
303
- model_path,
304
- prompt,
305
- melody,
306
- duration,
307
- topk,
308
- topp,
309
- temperature,
310
- sample_rate,
311
- ],
312
- outputs=output_audio,
313
- )
 
 
 
 
 
 
 
314
  gr.Examples(
315
- examples=[
316
- [
317
- os.path.join(
318
- os.path.dirname(__file__),
319
- "./data/audio/twinkle_twinkle_little_stars_mozart_20sec"
320
- ".mp3",
321
- ),
322
- "Easy",
323
- 32000,
324
- 20,
325
- ],
326
- [
327
- os.path.join(
328
- os.path.dirname(__file__),
329
- "./data/audio/golden_hour_20sec.mp3",
330
- ),
331
- "Easy",
332
- 32000,
333
- 20,
334
- ],
335
- [
336
- os.path.join(
337
- os.path.dirname(__file__),
338
- "./data/audio/turkish_march_mozart_20sec.mp3",
339
- ),
340
- "Easy",
341
- 32000,
342
- 20,
343
- ],
344
  [
345
  os.path.join(
346
- os.path.dirname(__file__),
347
- "./data/audio/golden_hour_20sec.mp3",
348
  ),
349
- "Hard",
350
- 32000,
351
- 20,
352
- ],
353
- [
354
- os.path.join(
355
- os.path.dirname(__file__),
356
- "./data/audio/golden_hour_20sec.mp3",
357
- ),
358
- "Hard",
359
- 32000,
360
- 40,
361
  ],
362
  [
363
  os.path.join(
364
- os.path.dirname(__file__),
365
- "./data/audio/golden_hour_20sec.mp3",
366
  ),
367
- "Hard",
368
- 16000,
369
- 20,
370
- ],
371
- [
372
- os.path.join(
373
- os.path.dirname(__file__),
374
- "./data/audio/old_town_road20sec.mp3",
375
- ),
376
- "Hard",
377
- 32000,
378
- 40,
379
- ],
380
  ],
381
- inputs=[melody, difficulty, sample_rate, duration],
382
- label="Audio Examples",
383
- outputs=[output_audio],
384
- # cache_examples=True,
385
  )
386
 
387
  with gr.Tab("Generate Music by image"):
388
- with gr.Row():
389
- with gr.Column():
390
  image_input = gr.Image("Upload an image", type="filepath")
391
- melody_condition = gr.Checkbox(
392
- label="Generate music by melody", interactive=True, value=False
393
- )
394
- melody = gr.Audio(
395
- sources=["microphone", "upload"],
396
- label="Record or upload your audio",
397
- show_label=True,
398
- visible=False,
399
- )
400
- melody_condition.change(
401
- fn=toggle_melody_condition,
402
- inputs=[melody_condition],
403
- outputs=melody,
404
- )
405
- description = gr.Textbox(
406
- label="Image Captioning",
407
- show_label=True,
408
- interactive=False,
409
- visible=False,
410
- )
411
- prompt = gr.Textbox(
412
- label="Generated Prompt",
413
- show_label=True,
414
- interactive=True,
415
- visible=False,
416
- )
417
- show_prompt = gr.Checkbox(label="Show the prompt", interactive=True)
418
- submit = gr.Button("submit", interactive=True, visible=True)
419
- generate = gr.Button(
420
- "Generate Music", interactive=True, visible=False
421
- )
422
-
423
- with gr.Column():
424
- with gr.Row():
425
- model_path = gr.Dropdown(
426
- choices=_AUDIOCRAFT_MODELS,
427
- label="Select the model",
428
- value="facebook/musicgen-large",
429
- )
430
- with gr.Row():
431
- duration = gr.Slider(
432
- minimum=10,
433
- maximum=60,
434
- value=10,
435
- label="Duration",
436
- interactive=True,
437
  )
438
- topk = gr.Number(label="Top-k", value=250, interactive=True)
439
- topp = gr.Number(label="Top-p", value=0, interactive=True)
440
- temperature = gr.Number(
441
- label="Temperature", value=1.0, interactive=True
442
- )
443
- sample_rate = gr.Number(
444
- label="output music sample rate", value=32000, interactive=True
445
- )
446
- with gr.Column():
447
- output_audio = gr.Audio(
448
- "listen to the generated music",
449
- type="filepath",
450
- show_label=True,
451
- )
452
- transcribe_button = gr.Button("Transcribe")
453
- d = gr.DownloadButton("Download the file", visible=False)
454
- submit.click(
455
- fn=post_submit,
456
- inputs=[show_prompt, model_path, image_input],
457
- outputs=[description, prompt, generate],
458
- )
459
- show_prompt.change(
460
- fn=show_caption,
461
- inputs=[show_prompt, description, prompt],
462
- outputs=[description, prompt, generate],
463
- )
464
- transcribe_button.click(transcribe, inputs=[output_audio], outputs=d)
465
- generate.click(
466
- fn=predict,
467
- inputs=[
468
- model_path,
469
- prompt,
470
- melody,
471
- duration,
472
- topk,
473
- topp,
474
- temperature,
475
- sample_rate,
476
- ],
477
- outputs=output_audio,
478
- )
479
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
480
  gr.Examples(
481
- examples=[
482
  [
483
  os.path.join(
484
- os.path.dirname(__file__),
485
- "./data/image/kids_drawing.jpeg",
486
  ),
487
- False,
488
- None,
489
  "facebook/musicgen-large",
 
 
490
  ],
491
  [
492
  os.path.join(
493
- os.path.dirname(__file__),
494
- "./data/image/cat.jpeg",
495
  ),
496
- False,
 
497
  None,
498
- "facebook/musicgen-large",
499
  ],
500
  [
501
  os.path.join(
502
- os.path.dirname(__file__),
503
- "./data/image/cat.jpeg",
504
  ),
505
- True,
506
- "./data/audio/the_nutcracker_dance_of_the_reed_flutes.mp3",
507
  "facebook/musicgen-melody-large",
 
 
 
 
508
  ],
509
  [
510
  os.path.join(
511
- os.path.dirname(__file__),
512
- "./data/image/beach.jpeg",
513
  ),
514
- False,
 
515
  None,
516
- "facebook/audiogen-medium",
517
  ],
518
  ],
519
- inputs=[image_input, melody_condition, melody, model_path],
520
- label="Audio Examples",
521
- outputs=[output_audio],
522
- # cache_examples=True,
523
  )
524
 
525
- demo.queue().launch()
526
 
527
 
528
  if __name__ == "__main__":
529
- UI()
 
 
 
 
 
 
1
  import os
 
2
  import gradio as gr
3
+ from audiocraft.models import MAGNeT, MusicGen, AudioGen
4
 
5
+ # from gradio_components.image import generate_caption, improve_prompt
6
+ from gradio_components.image import generate_caption_gpt4
7
  from gradio_components.prediction import predict, transcribe
8
 
9
+ import re
10
+ import argparse
11
+ from gradio_components.model_cards import TEXT_TO_MIDI_MODELS, TEXT_TO_SOUND_MODELS, MELODY_CONTINUATION_MODELS, TEXT_TO_MUSIC_MODELS, MODEL_CARDS, MELODY_CONDITIONED_MODELS
12
+ import ast
13
+ import json
14
+
15
+
16
+
17
+
18
  theme = gr.themes.Glass(
19
  primary_hue="fuchsia",
20
  secondary_hue="indigo",
 
77
  )
78
 
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
+ def generate_prompt(prompt, style):
82
+ prompt = ','.join([prompt]+style)
83
  return prompt
84
 
85
 
86
+ def UI(share=False):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  with gr.Blocks() as demo:
88
+ with gr.Tab("Generate Music by text"):
89
  with gr.Row():
90
  with gr.Column():
91
  with gr.Row():
92
  model_path = gr.Dropdown(
93
+ choices=TEXT_TO_MUSIC_MODELS,
94
  label="Select the model",
95
+ value="facebook/musicgen-large",
 
 
 
 
 
 
 
 
96
  )
97
+
98
  with gr.Row():
99
+ text_prompt = gr.Textbox(
100
+ label="Let's make a song about ...",
101
+ value="First day learning music generation in Standford university",
 
 
 
 
 
102
  interactive=True,
103
+ visible=True,
104
  )
105
+ num_outputs = gr.Number(
106
+ label="Number of outputs",
107
+ value=1,
108
+ minimum=1,
109
+ maximum=10,
110
  interactive=True,
111
  )
112
+
113
+ with gr.Row():
114
+ style = gr.CheckboxGroup(
115
+ ["Jazz", "Classical Music", "Hip Hop", "Ragga Jungle", "Dark Jazz", "Soul", "Blues", "80s Rock N Roll"],
116
+ value=None,
117
  label="music genre",
118
  interactive=True,
119
  )
120
+ @gr.on(inputs=[style], outputs=text_prompt)
121
+ def update_prompt(style):
122
+ return generate_prompt(text_prompt.value, style)
123
 
124
+ config_output_textbox = gr.Textbox(label="Model Configs", visible=False)
125
+
126
+ @gr.render(inputs=model_path)
127
+ def show_config_options(model_path):
128
+ print(model_path)
129
+
130
+ with gr.Accordion("Model Generation Configs"):
131
+ if "magnet" in model_path:
132
+ with gr.Row():
133
+ top_k = gr.Number(label="Top-k", value=300, interactive=True)
134
+ top_p = gr.Number(label="Top-p", value=0, interactive=True)
135
+ temperature = gr.Number(
136
+ label="Temperature", value=1.0, interactive=True
137
+ )
138
+ span_arrangement = gr.Radio(["nonoverlap", "stride1"], value='nonoverlap', label="span arrangment", info=" Use either non-overlapping spans ('nonoverlap') or overlapping spans ('stride1') ")
139
+ @gr.on(inputs=[top_k, top_p, temperature, span_arrangement], outputs=config_output_textbox)
140
+ def return_model_configs(top_k, top_p, temperature, span_arrangement):
141
+ return {"top_k": top_k, "top_p": top_p, "temperature": temperature, "span_arrangement": span_arrangement}
142
+ else:
143
+ with gr.Row():
144
+ duration = gr.Slider(
145
+ minimum=10,
146
+ maximum=30,
147
+ value=30,
148
+ label="Duration",
149
+ interactive=True,
150
+ )
151
+ use_sampling = gr.Checkbox(label="Use Sampling", interactive=True, value=True)
152
+ top_k = gr.Number(label="Top-k", value=300, interactive=True)
153
+ top_p = gr.Number(label="Top-p", value=0, interactive=True)
154
+ temperature = gr.Number(
155
+ label="Temperature", value=1.0, interactive=True
156
+ )
157
+ @gr.on(inputs=[duration, use_sampling, top_k, top_p, temperature], outputs=config_output_textbox)
158
+ def return_model_configs(duration, use_sampling, top_k, top_p, temperature):
159
+ return {"duration": duration, "use_sampling": use_sampling, "top_k": top_k, "top_p": top_p, "temperature": temperature}
 
 
 
 
 
 
 
 
 
 
160
 
161
  with gr.Column():
162
  with gr.Row():
163
+ melody = gr.Audio(sources=["upload"], type="numpy", label="File",
164
+ interactive=True, elem_id="melody-input", visible=False)
 
 
 
 
 
165
  submit = gr.Button("Generate Music")
166
+ result_text = gr.Textbox(label="Generated Music (text)", type="text", interactive=False)
167
+ print(result_text)
168
+ output_audios = []
169
+ @gr.render(inputs=result_text)
170
+ def show_output_audio(tmp_paths):
171
+ if tmp_paths:
172
+ tmp_paths = ast.literal_eval(tmp_paths)
173
+ print(tmp_paths)
174
+ for i in range(len(tmp_paths)):
175
+ tmp_path = tmp_paths[i]
176
+ _audio = gr.Audio(value=tmp_path , label=f"Generated Music {i}", type='filepath', interactive=False, visible=True)
177
+ output_audios.append(_audio)
178
+
179
+ submit.click(
180
+ fn=predict,
181
+ inputs=[model_path, config_output_textbox, text_prompt, melody, num_outputs],
182
+ outputs=result_text,
183
+ queue=True
184
  )
185
+
186
+
187
+ with gr.Tab("Generate Music by melody"):
188
+ with gr.Column():
189
+ with gr.Row():
190
+ radio_melody_condition = gr.Radio(["Muisc Continuation", "Music Conditioning"], value=None, label="Select the condition")
191
+ model_path2 = gr.Dropdown(label="model")
192
+ @gr.on(inputs=radio_melody_condition, outputs=model_path2)
193
+ def model_selection(radio_melody_condition):
194
+ if radio_melody_condition == "Muisc Continuation":
195
+ model_path2 = gr.Dropdown(
196
+ choices=MELODY_CONTINUATION_MODELS,
197
+ label="Select the model",
198
+ value="facebook/musicgen-large",
199
+ interactive=True,
200
+ visible=True
201
+ )
202
+ elif radio_melody_condition == "Music Conditioning":
203
+ model_path2 = gr.Dropdown(
204
+ choices=MELODY_CONDITIONED_MODELS,
205
+ label="Select the model",
206
+ value="facebook/musicgen-melody-large",
207
+ interactive=True,
208
+ visible=True
209
+ )
210
+ else:
211
+ model_path2 = gr.Dropdown(
212
+ choices=TEXT_TO_SOUND_MODELS,
213
+ label="Select the model",
214
+ value="facebook/musicgen-large",
215
+ interactive=True,
216
+ visible=False
217
+ )
218
+ return model_path2
219
+ upload_melody = gr.Audio(sources=["upload", "microphone"], type="filepath", label="File")
220
+ prompt_text2 = gr.Textbox(
221
+ label="Let's make a song about ...",
222
+ value=None,
223
+ interactive=True,
224
+ visible=True,
225
+ )
226
+ with gr.Row():
227
+ config_output_textbox2 = gr.Textbox(
228
+ label="Model Configs",
229
+ visible=True)
230
  with gr.Row():
231
+ duration2 = gr.Number(10, label="Duration", interactive=True)
232
+ num_outputs2 = gr.Number(1, label="Number of outputs", interactive=True)
 
 
 
233
 
234
+ @gr.on(inputs=[duration2], outputs=config_output_textbox2)
235
+ def return_model_configs2(duration):
236
+ return {"duration": duration, "use_sampling": True, "top_k": 300, "top_p": 0, "temperature": 1}
237
+ submit2 = gr.Button("Generate Music")
238
+ result_text2 = gr.Textbox(label="Generated Music (melody)", type="text", interactive=False, visible=True)
239
+ submit2.click(
240
+ fn=predict,
241
+ inputs=[model_path2, config_output_textbox2, prompt_text2, upload_melody, num_outputs2],
242
+ outputs=result_text2,
243
+ queue=True
244
+ )
245
+
246
+ @gr.render(inputs=result_text2)
247
+ def show_output_audio(tmp_paths):
248
+ if tmp_paths:
249
+ tmp_paths = ast.literal_eval(tmp_paths)
250
+ print(tmp_paths)
251
+ for i in range(len(tmp_paths)):
252
+ tmp_path = tmp_paths[i]
253
+ _audio = gr.Audio(value=tmp_path , label=f"Generated Music {i}", type='filepath', interactive=False)
254
+ output_audios.append(_audio)
255
  gr.Examples(
256
+ examples = [
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
  [
258
  os.path.join(
259
+ os.path.dirname(__file__), "./data/audio/Suri's Improv.mp3"
 
260
  ),
261
+ 30,
262
+ "facebook/musicgen-large",
263
+ "Muisc Continuation",
 
 
 
 
 
 
 
 
 
264
  ],
265
  [
266
  os.path.join(
267
+ os.path.dirname(__file__), "./data/audio/lie_no_tomorrow_20sec.wav"
 
268
  ),
269
+ 40,
270
+ "facebook/musicgen-melody-large",
271
+ "Music Conditioning",
272
+ ]
 
 
 
 
 
 
 
 
 
273
  ],
274
+ inputs=[upload_melody, duration2, model_path2, radio_melody_condition],
 
 
 
275
  )
276
 
277
  with gr.Tab("Generate Music by image"):
278
+ with gr.Column():
279
+ with gr.Row():
280
  image_input = gr.Image("Upload an image", type="filepath")
281
+ with gr.Accordion("Image Captioning", open=False):
282
+ image_description = gr.Textbox(label='image description', visible=True, interactive=False)
283
+ image_caption = gr.Textbox(label='generated text prompt', visible=True, interactive=True)
284
+ @gr.on(inputs=image_input, outputs=[image_description, image_caption])
285
+ def generate_image_text_prompt(image_input):
286
+ if image_input:
287
+ image_description, image_caption = generate_caption_gpt4(image_input, model_path)
288
+ # meesage_object, description, prompt = generate_caption_claude3(image_input, model_path)
289
+ return image_description, image_caption
290
+ return "", ""
291
+ with gr.Row():
292
+ melody3 = gr.Audio(sources=["upload", "microphone"], type="filepath", label="File", visible=True)
293
+ with gr.Column():
294
+ model_path3 = gr.Dropdown(
295
+ choices=TEXT_TO_SOUND_MODELS + TEXT_TO_MUSIC_MODELS + MELODY_CONDITIONED_MODELS,
296
+ label="Select the model",
297
+ value="facebook/musicgen-large",
298
+ )
299
+ duration3 = gr.Number(30, visible=False, label="Duration")
300
+ submit3 = gr.Button("Generate Music")
301
+ result_text3 = gr.Textbox(label="Generated Music (image)", type="text", interactive=False, visible=True)
302
+ def predict_image_music(model_path3, image_caption, duration3, melody3):
303
+ model_configs = {"duration": duration3, "use_sampling": True, "top_k": 250, "top_p": 0, "temperature": 1}
304
+ return predict(
305
+ model_version = model_path3,
306
+ generation_configs = model_configs,
307
+ prompt_text = image_caption,
308
+ prompt_wav = melody3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
309
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
310
 
311
+ submit3.click(
312
+ fn=predict_image_music,
313
+ inputs=[model_path3, image_caption, duration3, melody3],
314
+ outputs=result_text3,
315
+ queue=True
316
+ )
317
+
318
+ @gr.render(inputs=result_text3)
319
+ def show_output_audio(tmp_paths):
320
+ if tmp_paths:
321
+ tmp_paths = ast.literal_eval(tmp_paths)
322
+ print(tmp_paths)
323
+ for i in range(len(tmp_paths)):
324
+ tmp_path = tmp_paths[i]
325
+ _audio = gr.Audio(value=tmp_path , label=f"Generated Music {i}", type='filepath', interactive=False)
326
+ output_audios.append(_audio)
327
+
328
+ @gr.render(inputs=result_text3)
329
+ def show_transcribt_audio(tmp_paths):
330
+ transcribe(tmp_paths)
331
  gr.Examples(
332
+ examples = [
333
  [
334
  os.path.join(
335
+ os.path.dirname(__file__), "./data/image/beach.jpeg"
 
336
  ),
 
 
337
  "facebook/musicgen-large",
338
+ 30,
339
+ None,
340
  ],
341
  [
342
  os.path.join(
343
+ os.path.dirname(__file__), "./data/image/beach.jpeg"
 
344
  ),
345
+ "facebook/audiogen-medium",
346
+ 15,
347
  None,
 
348
  ],
349
  [
350
  os.path.join(
351
+ os.path.dirname(__file__), "./data/image/beach.jpeg"
 
352
  ),
 
 
353
  "facebook/musicgen-melody-large",
354
+ 30,
355
+ os.path.join(
356
+ os.path.dirname(__file__), "./data/audio/Suri's Improv.mp3"
357
+ ),
358
  ],
359
  [
360
  os.path.join(
361
+ os.path.dirname(__file__), "./data/image/cat.jpeg"
 
362
  ),
363
+ "facebook/musicgen-large",
364
+ 30,
365
  None,
 
366
  ],
367
  ],
368
+ inputs=[image_input, model_path3, duration3, melody3],
 
 
 
369
  )
370
 
371
+ demo.queue().launch(share=share)
372
 
373
 
374
  if __name__ == "__main__":
375
+ # Create the parser
376
+ parser = argparse.ArgumentParser()
377
+ parser.add_argument('--share', action='store_true', help='Enable sharing.')
378
+ args = parser.parse_args()
379
+
380
+ UI(share=args.share)
data/audio/Suri's Improv.mp3 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:077b6d42a3ee05b15c3a02d7e2aaad7841e52b005d0443ac4aa280464a9a9c96
3
+ size 163337
data/audio/like_no_tomorrow_20sec.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cabd551cfeea4bb608a010aa33118e744b55b816fc71697ffab7edb5ff350805
3
+ size 7680088
gradio_components/image.py CHANGED
@@ -4,9 +4,11 @@ import os
4
 
5
  import anthropic
6
  import gradio as gr
 
 
7
 
8
  # Remember to put your API Key here
9
- client = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])
10
 
11
  # image1_url = "https://i.abcnewsfe.com/a/7d849ccc-e0fe-4416-959d-85889e338add/dune-1-ht-bb-231212_1702405287482_hpMain_16x9.jpeg"
12
  image1_media_type = "image/jpeg"
@@ -20,12 +22,22 @@ The model was trained with descriptions from a stock music catalog, descriptions
20
 
21
  Try to make the prompt simple and concise with only 1-2 sentences
22
 
23
- Make sure the ouput is in JSON fomat, with two items `description` and `prompt`"""
 
 
 
 
 
 
 
24
 
25
  SYSTEM_PROMPT_AUDIO = """You are an expert llm prompt engineer, you understand the structure of llms and facebook musicgen text to audio model. You will be provided with an image, and require to output a prompt for the musicgen model to capture the essense of the image. Try to do it step by step, evaluate and analyze the image thoroughly. After that, develop a prompt that contains the detail of what background sounds this image should have. This prompt will be provided to audiogen model to generate a 15s audio clip.
26
  Try to make the prompt simple and concise with only 1-2 sentences
27
 
28
- Make sure the ouput is in JSON fomat, with two items `description` and `prompt`
 
 
 
29
  """
30
 
31
  PROMPT_IMPROVEMENT_GENERATE_PROMPT = """
@@ -58,8 +70,39 @@ def improve_prompt(prompt):
58
  prompt = message_object["prompt"]
59
  return message_object, prompt
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
- def generate_caption(image_file, model_file, progress=gr.Progress()):
63
  if model_file == "facebook/audiogen-medium":
64
  system_prompt = SYSTEM_PROMPT_AUDIO
65
  else:
 
4
 
5
  import anthropic
6
  import gradio as gr
7
+ from openai import OpenAI
8
+
9
 
10
  # Remember to put your API Key here
11
+ client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
12
 
13
  # image1_url = "https://i.abcnewsfe.com/a/7d849ccc-e0fe-4416-959d-85889e338add/dune-1-ht-bb-231212_1702405287482_hpMain_16x9.jpeg"
14
  image1_media_type = "image/jpeg"
 
22
 
23
  Try to make the prompt simple and concise with only 1-2 sentences
24
 
25
+ only return dictionary, with two items `description` and `prompt`
26
+
27
+ for example
28
+ {
29
+ "description": "A serene beach at sunset with gentle waves and a distant ship.",
30
+ "prompt": "A calming instrumental with gentle guitar, soft piano, and ocean waves sound effects, perfect for a relaxing moment by the sea."
31
+ }
32
+ """
33
 
34
  SYSTEM_PROMPT_AUDIO = """You are an expert llm prompt engineer, you understand the structure of llms and facebook musicgen text to audio model. You will be provided with an image, and require to output a prompt for the musicgen model to capture the essense of the image. Try to do it step by step, evaluate and analyze the image thoroughly. After that, develop a prompt that contains the detail of what background sounds this image should have. This prompt will be provided to audiogen model to generate a 15s audio clip.
35
  Try to make the prompt simple and concise with only 1-2 sentences
36
 
37
+ only return dictionary, with two items `description` and `prompt`
38
+ for example
39
+ {"description": "A serene beach scene at sunset with gentle waves lapping on the shore and a distant ship sailing on the water.",
40
+ "prompt": "Gentle waves flowing on the beach at sunset, with a distant ship in the background."}
41
  """
42
 
43
  PROMPT_IMPROVEMENT_GENERATE_PROMPT = """
 
70
  prompt = message_object["prompt"]
71
  return message_object, prompt
72
 
73
+ def generate_caption_gpt4(image_file, model_file):
74
+ client = OpenAI()
75
+ if model_file == "facebook/audiogen-medium":
76
+ system_prompt = SYSTEM_PROMPT_AUDIO
77
+ else:
78
+ system_prompt = SYSTEM_PROMPT
79
+ with open(image_file, "rb") as f:
80
+ image_encoded = base64.b64encode(f.read()).decode("utf-8")
81
+ response = client.chat.completions.create(
82
+ model="gpt-4o",
83
+ messages=[
84
+ {
85
+ "role": "user",
86
+ "content": [
87
+ {"type": "text",
88
+ "text": system_prompt},
89
+ {
90
+ "type": "image_url",
91
+ "image_url": {
92
+ "url": f"data:image/jpeg;base64,{image_encoded}",
93
+ },
94
+ },
95
+ ],
96
+ }
97
+ ],
98
+ max_tokens=300,
99
+ )
100
+ message = json.loads(response.choices[0].message.content)
101
+ return message['description'], message['prompt']
102
+
103
+
104
 
105
+ def generate_caption_claude3(image_file, model_file, progress=gr.Progress()):
106
  if model_file == "facebook/audiogen-medium":
107
  system_prompt = SYSTEM_PROMPT_AUDIO
108
  else:
gradio_components/model_cards.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ TEXT_TO_MUSIC_MODELS = [
4
+ "facebook/musicgen-medium",
5
+ "facebook/musicgen-small",
6
+ "facebook/musicgen-large",
7
+ 'facebook/magnet-small-10secs',
8
+ 'facebook/magnet-medium-10secs',
9
+ 'facebook/magnet-small-30secs',
10
+ 'facebook/magnet-medium-30secs',
11
+ # "facebook/musicgen-stereo-small",
12
+ # "facebook/musicgen-stereo-medium",
13
+ # "facebook/musicgen-stereo-large",
14
+ ]
15
+
16
+ TEXT_TO_MIDI_MODELS = [
17
+ "musiclang/musiclang-v2",
18
+ ]
19
+
20
+ MELODY_CONTINUATION_MODELS = [
21
+ "facebook/musicgen-medium",
22
+ "facebook/musicgen-small",
23
+ "facebook/musicgen-large",
24
+ ]
25
+
26
+ TEXT_TO_SOUND_MODELS = [
27
+ 'facebook/audio-magnet-small',
28
+ 'facebook/audio-magnet-medium',
29
+ "facebook/audiogen-medium",
30
+ ]
31
+
32
+ MELODY_CONDITIONED_MODELS = [
33
+ "facebook/musicgen-melody",
34
+ "facebook/musicgen-melody-large",
35
+ # "facebook/musicgen-stereo-melody",
36
+ # "facebook/musicgen-stereo-melody-large",
37
+ ]
38
+
39
+ STEREO_MODEL = [
40
+ "facebook/musicgen-stereo-small",
41
+ "facebook/musicgen-stereo-medium",
42
+ "facebook/musicgen-stereo-large",
43
+ "facebook/musicgen-stereo-melody",
44
+ "facebook/musicgen-stereo-melody-large",
45
+ ]
46
+
47
+
48
+ MODEL_CARDS = {
49
+ "text-to-music": TEXT_TO_MUSIC_MODELS,
50
+ "text-to-midi": TEXT_TO_MIDI_MODELS,
51
+ "text-to-sound": TEXT_TO_SOUND_MODELS,
52
+ "melody-conditioned": MELODY_CONDITIONED_MODELS,
53
+ }
54
+
55
+ MODEL_DISCLAIMERS = {
56
+ "facebook/musicgen-melody": "1.5B transformer decoder also supporting melody conditioning.",
57
+ "facebook/musicgen-medium": "1.5B transformer decoder.",
58
+ "facebook/musicgen-small": "300M transformer decoder.",
59
+ "facebook/musicgen-large": "3.3B transformer decoder also supporting melody conditioning.",
60
+ "facebook/musicgen-melody-large": "3.3B transformer decoder.",
61
+ 'facebook/magnet-small-10secs': "A 300M non-autoregressive transformer capable of generating 10-second music conditioned on text.",
62
+ 'facebook/magnet-medium-10secs': "A 1.5B parameters, 10 seconds music samples..",
63
+ 'facebook/magnet-small-30secs': "A 300M parameters, 30 seconds music samples.",
64
+ 'facebook/magnet-medium-30secs': "A 1.5B parameters, 30 seconds music samples.",
65
+ # "musiclang/musiclang-v2": "This model generates music from text prompts.", TODO: Implement MusicLang
66
+ 'facebook/audio-magnet-small': "a 300M non-autoregressive transformer capable of generating 10 second sound effects conditioned on text.",
67
+ 'facebook/audio-magnet-medium': "10 second sound effect generation, 1.5B parameters.",
68
+ "facebook/audiogen-medium": "1.5B transformer decoder capable of generating sound effects conditioned on text.",
69
+ }
70
+
71
+
72
+
73
+ def print_model_cards():
74
+ for key, value in MODEL_CARDS.items():
75
+ print(key, ":", value)
gradio_components/prediction.py CHANGED
@@ -8,77 +8,177 @@ import gradio as gr
8
  import torch
9
  from audiocraft.data.audio import audio_write
10
  from audiocraft.data.audio_utils import convert_audio
11
- from audiocraft.models import AudioGen, MusicGen
12
  from basic_pitch import ICASSP_2022_MODEL_PATH
13
- from transformers import AutoModelForSeq2SeqLM
 
 
 
 
 
 
14
 
 
15
 
16
- def load_model(version="facebook/musicgen-melody"):
17
- if version in ["facebook/audiogen-medium"]:
18
- return AudioGen.get_pretrained(version)
 
 
 
 
 
 
 
 
 
 
 
 
19
  else:
20
- return MusicGen.get_pretrained(version)
 
 
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  def _do_predictions(
24
  model_file,
25
  model,
26
- texts,
27
- melodies,
28
- duration,
29
  progress=False,
30
- gradio_progress=None,
31
- target_sr=32000,
32
- target_ac=1,
33
  **gen_kwargs,
34
  ):
35
  print(
36
- "new batch",
37
- len(texts),
38
- texts,
39
- [None if m is None else (m[0], m[1].shape) for m in melodies],
40
  )
41
  be = time.time()
42
- processed_melodies = []
43
- model.set_generation_params(duration=duration)
44
- for melody in melodies:
45
- if melody is None:
46
- processed_melodies.append(None)
47
- else:
48
- sr, melody = (
49
- melody[0],
50
- torch.from_numpy(melody[1]).to(model.device).float().t(),
51
- )
52
- print(f"Input audio sample rate is {sr}")
53
- if melody.dim() == 1:
54
- melody = melody[None]
55
- melody = melody[..., : int(sr * duration)]
56
- melody = convert_audio(melody, sr, target_sr, target_ac)
57
- processed_melodies.append(melody)
58
-
59
  try:
60
- if any(m is not None for m in processed_melodies):
61
- # melody condition
62
- outputs = model.generate_with_chroma(
63
- descriptions=texts,
64
- melody_wavs=processed_melodies,
65
- melody_sample_rate=target_sr,
66
- progress=progress,
67
- return_tokens=False,
68
- )
69
- else:
70
- if model_file == "facebook/audiogen-medium":
71
- # audio condition
72
- outputs = model.generate(texts, progress=progress)
73
  else:
74
- # text only
75
- outputs = model.generate(texts, progress=progress)
 
 
 
 
 
76
 
77
  except RuntimeError as e:
78
  raise gr.Error("Error while generating " + e.args[0])
79
  outputs = outputs.detach().cpu().float()
80
- pending_videos = []
81
- out_wavs = []
82
  for output in outputs:
83
  with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
84
  audio_write(
@@ -90,45 +190,36 @@ def _do_predictions(
90
  loudness_compressor=True,
91
  add_suffix=False,
92
  )
93
- out_wavs.append(file.name)
94
- print("generation finished", len(texts), time.time() - be)
95
- return out_wavs
96
-
 
 
 
 
 
97
 
 
 
 
 
 
 
 
 
 
98
  def predict(
99
- model_path,
100
- text,
101
- melody,
102
- duration,
103
- topk,
104
- topp,
105
- temperature,
106
- target_sr,
107
  progress=gr.Progress(),
108
  ):
109
  global INTERRUPTING
110
- global USE_DIFFUSION
111
  INTERRUPTING = False
112
  progress(0, desc="Loading model...")
113
- model_path = model_path.strip()
114
- # if model_path:
115
- # if not Path(model_path).exists():
116
- # raise gr.Error(f"Model path {model_path} doesn't exist.")
117
- # if not Path(model_path).is_dir():
118
- # raise gr.Error(f"Model path {model_path} must be a folder containing "
119
- # "state_dict.bin and compression_state_dict_.bin.")
120
- if temperature < 0:
121
- raise gr.Error("Temperature must be >= 0.")
122
- if topk < 0:
123
- raise gr.Error("Topk must be non-negative.")
124
- if topp < 0:
125
- raise gr.Error("Topp must be non-negative.")
126
-
127
- topk = int(topk)
128
- model = load_model(model_path)
129
-
130
- max_generated = 0
131
-
132
  def _progress(generated, to_generate):
133
  nonlocal max_generated
134
  max_generated = max(generated, max_generated)
@@ -136,40 +227,56 @@ def predict(
136
  if INTERRUPTING:
137
  raise gr.Error("Interrupted.")
138
 
 
139
  model.set_custom_progress_callback(_progress)
 
 
 
 
 
 
 
 
 
140
 
141
- wavs = _do_predictions(
142
- model_path,
143
  model,
144
- [text],
145
- [melody],
146
- duration,
147
  progress=True,
148
- target_ac=1,
149
- target_sr=target_sr,
150
- top_k=topk,
151
- top_p=topp,
152
- temperature=temperature,
153
- gradio_progress=progress,
154
  )
155
- return wavs[0]
156
 
157
 
158
  def transcribe(audio_path):
 
 
 
159
  # model_output, midi_data, note_events = predict("generated_0.wav")
160
- model_output, midi_data, note_events = basic_pitch.inference.predict(
161
- audio_path=audio_path,
162
- model_or_model_path=ICASSP_2022_MODEL_PATH,
163
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
 
165
- with NamedTemporaryFile("wb", suffix=".mid", delete=False) as file:
166
- try:
167
- midi_data.write(file)
168
- print(f"midi file saved to {file.name}")
169
- except Exception as e:
170
- print(f"Error while writing midi file: {e}")
171
- raise e
172
 
173
- return gr.DownloadButton(
174
- value=file.name, label=f"Download MIDI file {file.name}", visible=True
175
- )
 
8
  import torch
9
  from audiocraft.data.audio import audio_write
10
  from audiocraft.data.audio_utils import convert_audio
11
+ from audiocraft.models import AudioGen, MusicGen, MAGNeT
12
  from basic_pitch import ICASSP_2022_MODEL_PATH
13
+ # from transformers import AutoModelForSeq2SeqLM
14
+ from concurrent.futures import ProcessPoolExecutor
15
+ import typing as tp
16
+ import warnings
17
+ import json
18
+ import ast
19
+ import torchaudio
20
 
21
+ MODEL = None
22
 
23
+ def load_model(version='facebook/musicgen-large'):
24
+ global MODEL
25
+ if MODEL is None or MODEL.name != version:
26
+ del MODEL
27
+ MODEL = None # in case loading would crash
28
+ print("Loading model", version)
29
+ if "magnet" in version:
30
+ MODEL = MAGNeT.get_pretrained(version)
31
+ elif "musicgen" in version:
32
+ MODEL = MusicGen.get_pretrained(version)
33
+ elif "musiclang" in version:
34
+ # TODO: Implement MusicLang
35
+ pass
36
+ elif "audiogen" in version:
37
+ MODEL = AudioGen.get_pretrained(version)
38
  else:
39
+ raise ValueError("Invalid model version")
40
+
41
+ return MODEL
42
 
43
+ pool = ProcessPoolExecutor(4)
44
+ class FileCleaner:
45
+ def __init__(self, file_lifetime: float = 3600):
46
+ self.file_lifetime = file_lifetime
47
+ self.files = []
48
+
49
+ def add(self, path: tp.Union[str, Path]):
50
+ self._cleanup()
51
+ self.files.append((time.time(), Path(path)))
52
+
53
+ def _cleanup(self):
54
+ now = time.time()
55
+ for time_added, path in list(self.files):
56
+ if now - time_added > self.file_lifetime:
57
+ if path.exists():
58
+ path.unlink()
59
+ self.files.pop(0)
60
+ else:
61
+ break
62
+
63
+ file_cleaner = FileCleaner()
64
+
65
+ def inference_musicgen_text_to_music(model, configs, text, num_outputs=1):
66
+ model.set_generation_params(
67
+ **configs
68
+ )
69
+ descriptions = [text for _ in range(num_outputs)]
70
+ output = model.generate(descriptions=descriptions ,progress=True, return_tokens=False)
71
+ return output
72
+
73
+ def inference_musicgen_continuation(model, configs, text, prompt_waveform, prompt_sr, num_outputs=1):
74
+ model.set_generation_params(
75
+ **configs
76
+ )
77
+ # melody, prompt_sr = torchaudio.load(prompt_waveform)
78
+ # descriptions = [text for _ in range(num_outputs)]
79
+ # prompt = [prompt_waveform for _ in range(num_outputs)]
80
+ output = model.generate_continuation(prompt_waveform, prompt_sample_rate=prompt_sr, progress=True, return_tokens=False)
81
+ return output
82
+
83
+ def inference_musicgen_melody_condition(model, configs, text, prompt_waveform, prompt_sr, num_outputs=1):
84
+ model.set_generation_params(**configs)
85
+ descriptions = [text for _ in range(num_outputs)]
86
+ output = model.generate_with_chroma(
87
+ descriptions=descriptions,
88
+ melody_wavs=prompt_waveform,
89
+ melody_sample_rate=prompt_sr,
90
+ progress=True,
91
+ return_tokens=False
92
+ )
93
+ return output
94
+
95
+ def inference_magnet(model, configs, text, num_outputs=1):
96
+ model.set_generation_params(
97
+ **configs
98
+ )
99
+ descriptions = [text for _ in range(num_outputs)]
100
+ output = model.generate(descriptions=descriptions, progress=True, return_tokens=False)
101
+ return output
102
+
103
+ def inference_magnet_audio(model, configs, text, num_outputs=1):
104
+ model.set_generation_params(
105
+ **configs
106
+ )
107
+ descriptions = [text for _ in range(num_outputs)]
108
+ output = model.generate(descriptions=descriptions, progress=True, return_tokens=False)
109
+ return output
110
+
111
+ def inference_audiogen(model, configs, text, num_outputs=1):
112
+ model.set_generation_params(
113
+ **configs
114
+ )
115
+ descriptions = [text for _ in range(num_outputs)]
116
+ output = model.generate(descriptions=descriptions, progress=True, return_tokens=False)
117
+ return output
118
+
119
+ def inference_musiclang():
120
+ # TODO: Implement MusicLang
121
+ pass
122
+
123
+
124
+ def process_audio(gr_audio, prompt_duration, model):
125
+ # audio, sr = torch.from_numpy(gr_audio[1]).to(model.device).float().t(), gr_audio[0]
126
+ audio, sr = torchaudio.load(gr_audio)
127
+ audio = audio[..., :int(prompt_duration * sr)]
128
+ return audio, sr
129
+
130
+ _MODEL_INFERENCES = {
131
+ "facebook/musicgen-small": inference_musicgen_text_to_music,
132
+ "facebook/musicgen-medium": inference_musicgen_text_to_music,
133
+ "facebook/musicgen-large": inference_musicgen_text_to_music,
134
+ "facebook/musicgen-melody": inference_musicgen_melody_condition,
135
+ "facebook/musicgen-melody-large": inference_musicgen_melody_condition,
136
+ "facebook/magnet-small-10secs": inference_magnet,
137
+ "facebook/magnet-medium-10secs": inference_magnet,
138
+ "facebook/magnet-small-30secs": inference_magnet,
139
+ "facebook/magnet-medium-30secs": inference_magnet,
140
+ "facebook/audio-magnet-small": inference_magnet_audio,
141
+ "facebook/audio-magnet-medium": inference_magnet_audio,
142
+ "facebook/audiogen-medium": inference_audiogen,
143
+ "musicgen-continuation": inference_musicgen_continuation,
144
+ }
145
 
146
  def _do_predictions(
147
  model_file,
148
  model,
149
+ text,
150
+ melody = None,
151
+ mel_sample_rate=None,
152
  progress=False,
153
+ num_generations=1,
 
 
154
  **gen_kwargs,
155
  ):
156
  print(
157
+ "new generation",
158
+ text,
159
+ None if melody is None else melody.shape
 
160
  )
161
  be = time.time()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  try:
163
+ if melody is not None:
164
+ # melody condition or continuation
165
+ if 'melody' in model_file:
166
+ # melody condition - musicgen-melody, musicgen-melody-large
167
+ inderence_func = _MODEL_INFERENCES[model_file]
 
 
 
 
 
 
 
 
168
  else:
169
+ # melody continuation
170
+ inderence_func = _MODEL_INFERENCES['musicgen-continuation']
171
+ outputs = inderence_func(model, gen_kwargs, text, melody, mel_sample_rate, num_generations)
172
+ else:
173
+ # text-to-music, text-to-sound
174
+ inderence_func = _MODEL_INFERENCES[model_file]
175
+ outputs = inderence_func(model, gen_kwargs, text, num_generations)
176
 
177
  except RuntimeError as e:
178
  raise gr.Error("Error while generating " + e.args[0])
179
  outputs = outputs.detach().cpu().float()
180
+ out_audios = []
181
+ video_processes = []
182
  for output in outputs:
183
  with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
184
  audio_write(
 
190
  loudness_compressor=True,
191
  add_suffix=False,
192
  )
193
+ # video_processes.append(pool.submit(make_waveform, file.name))
194
+ out_audios.append(file.name)
195
+ file_cleaner.add(file.name)
196
+ # out_videos = [video.result() for video in video_processes]
197
+ # for video in out_videos:
198
+ # file_cleaner.add(video)
199
+
200
+ print("generation finished", len(outputs), time.time() - be)
201
+ return out_audios
202
 
203
+ def make_waveform(*args, **kwargs):
204
+ # Further remove some warnings.
205
+ be = time.time()
206
+ with warnings.catch_warnings():
207
+ warnings.simplefilter('ignore')
208
+ out = gr.make_waveform(*args, **kwargs)
209
+ print("Make a video took", time.time() - be)
210
+ return out
211
+
212
  def predict(
213
+ model_version,
214
+ generation_configs,
215
+ prompt_text=None,
216
+ prompt_wav=None,
217
+ num_generations=1,
 
 
 
218
  progress=gr.Progress(),
219
  ):
220
  global INTERRUPTING
 
221
  INTERRUPTING = False
222
  progress(0, desc="Loading model...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
  def _progress(generated, to_generate):
224
  nonlocal max_generated
225
  max_generated = max(generated, max_generated)
 
227
  if INTERRUPTING:
228
  raise gr.Error("Interrupted.")
229
 
230
+ model = load_model(model_version)
231
  model.set_custom_progress_callback(_progress)
232
+ if isinstance(generation_configs, str):
233
+ generation_configs = ast.literal_eval(generation_configs)
234
+ max_generated = 0
235
+ if prompt_wav is not None:
236
+ melody, mel_sample_rate = process_audio(prompt_wav, generation_configs['duration'], model)
237
+ else:
238
+ melody, mel_sample_rate = None, None
239
+
240
+
241
 
242
+ audios = _do_predictions(
243
+ model_version,
244
  model,
245
+ prompt_text,
246
+ melody,
247
+ mel_sample_rate,
248
  progress=True,
249
+ num_generations = num_generations,
250
+ **generation_configs,
 
 
 
 
251
  )
252
+ return audios
253
 
254
 
255
  def transcribe(audio_path):
256
+ """
257
+ Transcribe an audio file to MIDI using the basic_pitch model.
258
+ """
259
  # model_output, midi_data, note_events = predict("generated_0.wav")
260
+ tmp_paths = ast.literal_eval(audio_path)
261
+ download_buttons = []
262
+ for audio_path in tmp_paths:
263
+ model_output, midi_data, note_events = basic_pitch.inference.predict(
264
+ audio_path=audio_path,
265
+ model_or_model_path=ICASSP_2022_MODEL_PATH,
266
+ )
267
+
268
+ with NamedTemporaryFile("wb", suffix=".mid", delete=False) as file:
269
+ try:
270
+ midi_data.write(file)
271
+ print(f"midi file saved to {file.name}")
272
+ except Exception as e:
273
+ print(f"Error while writing midi file: {e}")
274
+ raise e
275
+ download_buttons.append(gr.DownloadButton(
276
+ value=file.name, label=f"Download MIDI file {file.name}", visible=True
277
+ ))
278
+ file_cleaner.add(file.name)
279
+
280
+ return download_buttons
281
 
 
 
 
 
 
 
 
282