update apps and examples
Files changed:
- .gitattributes +2 -0
- app.py +244 -393
- data/audio/Suri's Improv.mp3 +3 -0
- data/audio/like_no_tomorrow_20sec.wav +3 -0
- gradio_components/image.py +47 -4
- gradio_components/model_cards.py +75 -0
- gradio_components/prediction.py +215 -108
.gitattributes CHANGED
@@ -40,3 +40,5 @@ data/audio/twinkle_twinkle_little_stars_mozart.mp3 filter=lfs diff=lfs merge=lfs
 *.mp3 !text !filter !merge !diff
 *.jpeg filter=lfs diff=lfs merge=lfs -text
 data/audio/old_town_road20sec.mp3 filter=lfs diff=lfs merge=lfs -text
+data/audio/Suri's[[:space:]]Improv.mp3 filter=lfs diff=lfs merge=lfs -text
+data/audio/like_no_tomorrow_20sec.wav filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,10 +1,20 @@
 import os
-
 import gradio as gr
+from audiocraft.models import MAGNeT, MusicGen, AudioGen
 
-from gradio_components.image import generate_caption, improve_prompt
+# from gradio_components.image import generate_caption, improve_prompt
+from gradio_components.image import generate_caption_gpt4
 from gradio_components.prediction import predict, transcribe
 
+import re
+import argparse
+from gradio_components.model_cards import TEXT_TO_MIDI_MODELS, TEXT_TO_SOUND_MODELS, MELODY_CONTINUATION_MODELS, TEXT_TO_MUSIC_MODELS, MODEL_CARDS, MELODY_CONDITIONED_MODELS
+import ast
+import json
+
+
 theme = gr.themes.Glass(
     primary_hue="fuchsia",
     secondary_hue="indigo",
@@ -67,463 +77,304 @@ theme = gr.themes.Glass(
 )
 
 
-_AUDIOCRAFT_MODELS = [
-    "facebook/musicgen-melody",
-    "facebook/musicgen-medium",
-    "facebook/musicgen-small",
-    "facebook/musicgen-large",
-    "facebook/musicgen-melody-large",
-    "facebook/audiogen-medium",
-]
-
-
-def generate_prompt(difficulty, style):
-    _DIFFICULTY_MAPPIN = {
-        "Easy": "beginner player",
-        "Medium": "player who has 2-3 years experience",
-        "Hard": "player who has more than 4 years experiences",
-    }
-    prompt = "piano only music for a {} to practice with the touch of {}".format(
-        _DIFFICULTY_MAPPIN[difficulty], style
-    )
-    return prompt
-
-
-def toggle_melody_condition(melody_condition):
-    if melody_condition:
-        return gr.Audio(
-            sources=["microphone", "upload"],
-            label="Record or upload your audio",
-            show_label=True,
-            visible=True,
-        )
-    else:
-        return gr.Audio(
-            sources=["microphone", "upload"],
-            label="Record or upload your audio",
-            show_label=True,
-            visible=False,
-        )
-
-
-def toggle_custom_prompt(customize, difficulty, style):
-    if customize:
-        return gr.Textbox(label="Type your prompt", interactive=True, visible=True)
-    else:
-        prompt = generate_prompt(difficulty, style)
-        return gr.Textbox(
-            label="Generated Prompt", value=prompt, interactive=False, visible=True
-        )
-
-
-def show_caption(show_caption_condition, description, prompt):
-    if show_caption_condition:
-        return (
-            gr.Textbox(
-                label="Image Caption",
-                value=description,
-                interactive=False,
-                show_label=True,
-                visible=True,
-            ),
-            gr.Textbox(
-                label="Generated Prompt",
-                value=prompt,
-                interactive=True,
-                show_label=True,
-                visible=True,
-            ),
-            gr.Button("Generate Music", interactive=True, visible=True),
-        )
-    else:
-        return (
-            gr.Textbox(
-                label="Image Caption",
-                value=description,
-                interactive=False,
-                show_label=True,
-                visible=False,
-            ),
-            gr.Textbox(
-                label="Generated Prompt",
-                value=prompt,
-                interactive=True,
-                show_label=True,
-                visible=False,
-            ),
-            gr.Button(label="Generate Music", interactive=True, visible=True),
-        )
-
-
-def optimize_fn(prompt):
-    return prompt
-
-
-def display_prompt(prompt):
-    return gr.Textbox(
-        label="Generated Prompt", value=prompt, interactive=False, visible=True
-    )
-
-
-def post_submit(show_caption, model_path, image_input):
-    _, description, prompt = generate_caption(image_input, model_path)
-    return (
-        gr.Textbox(
-            label="Image Caption",
-            value=description,
-            interactive=False,
-            show_label=True,
-            visible=show_caption,
-        ),
-        gr.Textbox(
-            label="Generated Prompt",
-            value=prompt,
-            interactive=True,
-            show_label=True,
-            visible=show_caption,
-        ),
-        gr.Button("Generate Music", interactive=True, visible=True),
-    )
-
-
-def UI():
-    with gr.Blocks() as demo:
-        with gr.Tab("Generate Music by
-            with gr.Row():
-                with gr.Column():
-                    with gr.Row():
-                        model_path = gr.Dropdown(
-                            choices=_AUDIOCRAFT_MODELS,
-                            label="Select the model",
-                            value="facebook/musicgen-
-                        )
-                    with gr.Row():
-                        duration = gr.Slider(
-                            minimum=10,
-                            maximum=60,
-                            value=10,
-                            label="Duration",
-                            interactive=True,
-                        )
-                    with gr.Row():
-                        temperature = gr.Number(
-                            label="Temperature", value=1.0, interactive=True
-                        )
-                        sample_rate = gr.Number(
-                            label="output music sample rate",
-                            value=32000,
-                            interactive=True,
-                        )
-                            label="music genre",
-                            interactive=True,
-                        )
-                with gr.Column():
-                    show_prompt = gr.Button("Show the prompt", interactive=True)
-                    prompt_text = gr.Textbox(
-                        "Optimized Prompt", interactive=False, visible=False
-                    )
-                    optimize.click(optimize_fn, inputs=[prompt], outputs=prompt)
-                    show_prompt.click(
-                        display_prompt, inputs=[prompt], outputs=prompt_text
-                    )
-
-                with gr.Column():
-                    with gr.Row():
-                        melody = gr.Audio(
-                            label="Record or upload your audio",
-                            # interactive=True,
-                            show_label=True,
-                        )
-                    with gr.Row():
-                        submit = gr.Button("Generate Music")
-                    with gr.Row():
-                        transcribe_button.click(
-                            transcribe, inputs=[output_audio], outputs=d
-                        )
-
-            gr.Examples(
-                examples=[
-                    [
-                        os.path.join(
-                            os.path.dirname(__file__),
-                            "./data/audio/twinkle_twinkle_little_stars_mozart_20sec"
-                            ".mp3",
-                        ),
-                        "Easy",
-                        32000,
-                        20,
-                    ],
-                    [
-                        os.path.join(
-                            os.path.dirname(__file__),
-                            "./data/audio/golden_hour_20sec.mp3",
-                        ),
-                        "Easy",
-                        32000,
-                        20,
-                    ],
-                    [
-                        os.path.join(
-                            os.path.dirname(__file__),
-                            "./data/audio/turkish_march_mozart_20sec.mp3",
-                        ),
-                        "Easy",
-                        32000,
-                        20,
-                    ],
-                    [
-                        os.path.join(
-                            os.path.dirname(__file__),
-                            "./data/audio/golden_hour_20sec.mp3",
-                        ),
-                    ],
-                    [
-                        os.path.join(
-                            os.path.dirname(__file__),
-                            "./data/audio/golden_hour_20sec.mp3",
-                        ),
-                        "Hard",
-                        32000,
-                        40,
-                    ],
-                    [
-                        os.path.join(
-                            os.path.dirname(__file__),
-                            "./data/audio/golden_hour_20sec.mp3",
-                        ),
-                    ]
-                    [
-                        os.path.join(
-                            os.path.dirname(__file__),
-                            "./data/audio/old_town_road20sec.mp3",
-                        ),
-                        "Hard",
-                        32000,
-                        40,
-                    ],
-                ],
-                inputs=[
-                label="Audio Examples",
-                outputs=[output_audio],
-                # cache_examples=True,
-            )
-
-        with gr.Tab("Generate Music by image"):
-            with gr.Column():
-                with gr.Row():
-                    image_input = gr.Image("Upload an image", type="filepath")
-                        label=
-                    generate = gr.Button(
-                        "Generate Music", interactive=True, visible=False
-                    )
-
-                with gr.Column():
-                    with gr.Row():
-                        model_path = gr.Dropdown(
-                            choices=_AUDIOCRAFT_MODELS,
-                            label="Select the model",
-                            value="facebook/musicgen-large",
-                        )
-                    with gr.Row():
-                        duration = gr.Slider(
-                            minimum=10,
-                            maximum=60,
-                            value=10,
-                            label="Duration",
-                            interactive=True,
-                        )
-                        topk = gr.Number(label="Top-k", value=250, interactive=True)
-                        topp = gr.Number(label="Top-p", value=0, interactive=True)
-                        temperature = gr.Number(
-                            label="Temperature", value=1.0, interactive=True
-                        )
-                        sample_rate = gr.Number(
-                            label="output music sample rate", value=32000, interactive=True
-                        )
-                with gr.Column():
-                    output_audio = gr.Audio(
-                        "listen to the generated music",
-                        type="filepath",
-                        show_label=True,
-                    )
-                    transcribe_button = gr.Button("Transcribe")
-                    d = gr.DownloadButton("Download the file", visible=False)
-                submit.click(
-                    fn=post_submit,
-                    inputs=[show_prompt, model_path, image_input],
-                    outputs=[description, prompt, generate],
-                )
-                show_prompt.change(
-                    fn=show_caption,
-                    inputs=[show_prompt, description, prompt],
-                    outputs=[description, prompt, generate],
-                )
-                transcribe_button.click(transcribe, inputs=[output_audio], outputs=d)
-                generate.click(
-                    fn=predict,
-                    inputs=[
-                        model_path,
-                        prompt,
-                        melody,
-                        duration,
-                        topk,
-                        topp,
-                        temperature,
-                        sample_rate,
-                    ],
-                    outputs=output_audio,
-                )
-
-            gr.Examples(
-                examples=[
-                    [
-                        os.path.join(
-                            os.path.dirname(__file__),
-                            "./data/image/kids_drawing.jpeg",
-                        ),
-                        False,
-                        None,
-                        "facebook/musicgen-large",
-                    ],
-                    [
-                        os.path.join(
-                            os.path.dirname(__file__),
-                            "./data/image/cat.jpeg",
-                        ),
-                        None,
-                        "facebook/musicgen-large",
-                    ],
-                    [
-                        os.path.join(
-                            os.path.dirname(__file__),
-                            "./data/image/cat.jpeg",
-                        ),
-                        True,
-                        "./data/audio/the_nutcracker_dance_of_the_reed_flutes.mp3",
-                        "facebook/musicgen-melody-large",
-                    ],
-                    [
-                        os.path.join(
-                            os.path.dirname(__file__),
-                            "./data/image/beach.jpeg",
-                        ),
-                        None,
-                        "facebook/audiogen-medium",
-                    ],
-                ],
-                inputs=[image_input,
-                label="Audio Examples",
-                outputs=[output_audio],
-                # cache_examples=True,
-            )
-
-    demo.queue().launch()
+def generate_prompt(prompt, style):
+    prompt = ','.join([prompt] + style)
+    return prompt
+
+
+def UI(share=False):
+    with gr.Blocks() as demo:
+        with gr.Tab("Generate Music by text"):
+            with gr.Row():
+                with gr.Column():
+                    with gr.Row():
+                        model_path = gr.Dropdown(
+                            choices=TEXT_TO_MUSIC_MODELS,
+                            label="Select the model",
+                            value="facebook/musicgen-large",
+                        )
+
+                    with gr.Row():
+                        text_prompt = gr.Textbox(
+                            label="Let's make a song about ...",
+                            value="First day learning music generation at Stanford University",
+                            interactive=True,
+                            visible=True,
+                        )
+                        num_outputs = gr.Number(
+                            label="Number of outputs",
+                            value=1,
+                            minimum=1,
+                            maximum=10,
+                            interactive=True,
+                        )
+
+                    with gr.Row():
+                        style = gr.CheckboxGroup(
+                            ["Jazz", "Classical Music", "Hip Hop", "Ragga Jungle", "Dark Jazz", "Soul", "Blues", "80s Rock N Roll"],
+                            value=None,
+                            label="music genre",
+                            interactive=True,
+                        )
+
+                        @gr.on(inputs=[style], outputs=text_prompt)
+                        def update_prompt(style):
+                            return generate_prompt(text_prompt.value, style)
+
+                    config_output_textbox = gr.Textbox(label="Model Configs", visible=False)
+
+                    @gr.render(inputs=model_path)
+                    def show_config_options(model_path):
+                        print(model_path)
+
+                        with gr.Accordion("Model Generation Configs"):
+                            if "magnet" in model_path:
+                                with gr.Row():
+                                    top_k = gr.Number(label="Top-k", value=300, interactive=True)
+                                    top_p = gr.Number(label="Top-p", value=0, interactive=True)
+                                    temperature = gr.Number(
+                                        label="Temperature", value=1.0, interactive=True
+                                    )
+                                    span_arrangement = gr.Radio(["nonoverlap", "stride1"], value='nonoverlap', label="span arrangement", info="Use either non-overlapping spans ('nonoverlap') or overlapping spans ('stride1')")
+
+                                    @gr.on(inputs=[top_k, top_p, temperature, span_arrangement], outputs=config_output_textbox)
+                                    def return_model_configs(top_k, top_p, temperature, span_arrangement):
+                                        return {"top_k": top_k, "top_p": top_p, "temperature": temperature, "span_arrangement": span_arrangement}
+                            else:
+                                with gr.Row():
+                                    duration = gr.Slider(
+                                        minimum=10,
+                                        maximum=30,
+                                        value=30,
+                                        label="Duration",
+                                        interactive=True,
+                                    )
+                                    use_sampling = gr.Checkbox(label="Use Sampling", interactive=True, value=True)
+                                    top_k = gr.Number(label="Top-k", value=300, interactive=True)
+                                    top_p = gr.Number(label="Top-p", value=0, interactive=True)
+                                    temperature = gr.Number(
+                                        label="Temperature", value=1.0, interactive=True
+                                    )
+
+                                    @gr.on(inputs=[duration, use_sampling, top_k, top_p, temperature], outputs=config_output_textbox)
+                                    def return_model_configs(duration, use_sampling, top_k, top_p, temperature):
+                                        return {"duration": duration, "use_sampling": use_sampling, "top_k": top_k, "top_p": top_p, "temperature": temperature}
+
+                with gr.Column():
+                    with gr.Row():
+                        melody = gr.Audio(sources=["upload"], type="numpy", label="File",
+                                          interactive=True, elem_id="melody-input", visible=False)
+                        submit = gr.Button("Generate Music")
+                    result_text = gr.Textbox(label="Generated Music (text)", type="text", interactive=False)
+                    print(result_text)
+                    output_audios = []
+
+                    @gr.render(inputs=result_text)
+                    def show_output_audio(tmp_paths):
+                        if tmp_paths:
+                            tmp_paths = ast.literal_eval(tmp_paths)
+                            print(tmp_paths)
+                            for i in range(len(tmp_paths)):
+                                tmp_path = tmp_paths[i]
+                                _audio = gr.Audio(value=tmp_path, label=f"Generated Music {i}", type='filepath', interactive=False, visible=True)
+                                output_audios.append(_audio)
+
+                    submit.click(
+                        fn=predict,
+                        inputs=[model_path, config_output_textbox, text_prompt, melody, num_outputs],
+                        outputs=result_text,
+                        queue=True
+                    )
+
+        with gr.Tab("Generate Music by melody"):
+            with gr.Column():
+                with gr.Row():
+                    radio_melody_condition = gr.Radio(["Music Continuation", "Music Conditioning"], value=None, label="Select the condition")
+                    model_path2 = gr.Dropdown(label="model")
+
+                    @gr.on(inputs=radio_melody_condition, outputs=model_path2)
+                    def model_selection(radio_melody_condition):
+                        if radio_melody_condition == "Music Continuation":
+                            model_path2 = gr.Dropdown(
+                                choices=MELODY_CONTINUATION_MODELS,
+                                label="Select the model",
+                                value="facebook/musicgen-large",
+                                interactive=True,
+                                visible=True
+                            )
+                        elif radio_melody_condition == "Music Conditioning":
+                            model_path2 = gr.Dropdown(
+                                choices=MELODY_CONDITIONED_MODELS,
+                                label="Select the model",
+                                value="facebook/musicgen-melody-large",
+                                interactive=True,
+                                visible=True
+                            )
+                        else:
+                            model_path2 = gr.Dropdown(
+                                choices=TEXT_TO_SOUND_MODELS,
+                                label="Select the model",
+                                value="facebook/musicgen-large",
+                                interactive=True,
+                                visible=False
+                            )
+                        return model_path2
+
+                upload_melody = gr.Audio(sources=["upload", "microphone"], type="filepath", label="File")
+                prompt_text2 = gr.Textbox(
+                    label="Let's make a song about ...",
+                    value=None,
+                    interactive=True,
+                    visible=True,
+                )
+                with gr.Row():
+                    config_output_textbox2 = gr.Textbox(
+                        label="Model Configs",
+                        visible=True)
+                with gr.Row():
+                    duration2 = gr.Number(10, label="Duration", interactive=True)
+                    num_outputs2 = gr.Number(1, label="Number of outputs", interactive=True)
+
+                @gr.on(inputs=[duration2], outputs=config_output_textbox2)
+                def return_model_configs2(duration):
+                    return {"duration": duration, "use_sampling": True, "top_k": 300, "top_p": 0, "temperature": 1}
+
+                submit2 = gr.Button("Generate Music")
+                result_text2 = gr.Textbox(label="Generated Music (melody)", type="text", interactive=False, visible=True)
+                submit2.click(
+                    fn=predict,
+                    inputs=[model_path2, config_output_textbox2, prompt_text2, upload_melody, num_outputs2],
+                    outputs=result_text2,
+                    queue=True
+                )
+
+                @gr.render(inputs=result_text2)
+                def show_output_audio(tmp_paths):
+                    if tmp_paths:
+                        tmp_paths = ast.literal_eval(tmp_paths)
+                        print(tmp_paths)
+                        for i in range(len(tmp_paths)):
+                            tmp_path = tmp_paths[i]
+                            _audio = gr.Audio(value=tmp_path, label=f"Generated Music {i}", type='filepath', interactive=False)
+                            output_audios.append(_audio)
+
+                gr.Examples(
+                    examples=[
+                        [
+                            os.path.join(
+                                os.path.dirname(__file__), "./data/audio/Suri's Improv.mp3"
+                            ),
+                            30,
+                            "facebook/musicgen-large",
+                            "Music Continuation",
+                        ],
+                        [
+                            os.path.join(
+                                os.path.dirname(__file__), "./data/audio/like_no_tomorrow_20sec.wav"
+                            ),
+                            40,
+                            "facebook/musicgen-melody-large",
+                            "Music Conditioning",
+                        ],
+                    ],
+                    inputs=[upload_melody, duration2, model_path2, radio_melody_condition],
+                )
+
+        with gr.Tab("Generate Music by image"):
+            with gr.Column():
+                with gr.Row():
+                    image_input = gr.Image("Upload an image", type="filepath")
+                    with gr.Accordion("Image Captioning", open=False):
+                        image_description = gr.Textbox(label='image description', visible=True, interactive=False)
+                        image_caption = gr.Textbox(label='generated text prompt', visible=True, interactive=True)
+
+                        @gr.on(inputs=image_input, outputs=[image_description, image_caption])
+                        def generate_image_text_prompt(image_input):
+                            if image_input:
+                                image_description, image_caption = generate_caption_gpt4(image_input, model_path)
+                                # message_object, description, prompt = generate_caption_claude3(image_input, model_path)
+                                return image_description, image_caption
+                            return "", ""
+
+                with gr.Row():
+                    melody3 = gr.Audio(sources=["upload", "microphone"], type="filepath", label="File", visible=True)
+                with gr.Column():
+                    model_path3 = gr.Dropdown(
+                        choices=TEXT_TO_SOUND_MODELS + TEXT_TO_MUSIC_MODELS + MELODY_CONDITIONED_MODELS,
+                        label="Select the model",
+                        value="facebook/musicgen-large",
+                    )
+                    duration3 = gr.Number(30, visible=False, label="Duration")
+                submit3 = gr.Button("Generate Music")
+                result_text3 = gr.Textbox(label="Generated Music (image)", type="text", interactive=False, visible=True)
+
+                def predict_image_music(model_path3, image_caption, duration3, melody3):
+                    model_configs = {"duration": duration3, "use_sampling": True, "top_k": 250, "top_p": 0, "temperature": 1}
+                    return predict(
+                        model_version=model_path3,
+                        generation_configs=model_configs,
+                        prompt_text=image_caption,
+                        prompt_wav=melody3
+                    )
+
+                submit3.click(
+                    fn=predict_image_music,
+                    inputs=[model_path3, image_caption, duration3, melody3],
+                    outputs=result_text3,
+                    queue=True
+                )
+
+                @gr.render(inputs=result_text3)
+                def show_output_audio(tmp_paths):
+                    if tmp_paths:
+                        tmp_paths = ast.literal_eval(tmp_paths)
+                        print(tmp_paths)
+                        for i in range(len(tmp_paths)):
+                            tmp_path = tmp_paths[i]
+                            _audio = gr.Audio(value=tmp_path, label=f"Generated Music {i}", type='filepath', interactive=False)
+                            output_audios.append(_audio)
+
+                @gr.render(inputs=result_text3)
+                def show_transcribed_audio(tmp_paths):
+                    transcribe(tmp_paths)
+
+                gr.Examples(
+                    examples=[
+                        [
+                            os.path.join(
+                                os.path.dirname(__file__), "./data/image/beach.jpeg"
+                            ),
+                            "facebook/musicgen-large",
+                            30,
+                            None,
+                        ],
+                        [
+                            os.path.join(
+                                os.path.dirname(__file__), "./data/image/beach.jpeg"
+                            ),
+                            "facebook/audiogen-medium",
+                            15,
+                            None,
+                        ],
+                        [
+                            os.path.join(
+                                os.path.dirname(__file__), "./data/image/beach.jpeg"
+                            ),
+                            "facebook/musicgen-melody-large",
+                            30,
+                            os.path.join(
+                                os.path.dirname(__file__), "./data/audio/Suri's Improv.mp3"
+                            ),
+                        ],
+                        [
+                            os.path.join(
+                                os.path.dirname(__file__), "./data/image/cat.jpeg"
+                            ),
+                            "facebook/musicgen-large",
+                            30,
+                            None,
+                        ],
+                    ],
+                    inputs=[image_input, model_path3, duration3, melody3],
+                )
+
+    demo.queue().launch(share=share)
 
 
 if __name__ == "__main__":
-    UI()
+    # Create the parser
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--share', action='store_true', help='Enable sharing.')
+    args = parser.parse_args()
+
+    UI(share=args.share)
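Note on the pattern used throughout the new app.py: each tab's event handler returns the generated file paths as a stringified Python list into a hidden-ish Textbox, and a `@gr.render` block re-parses that string with `ast.literal_eval` and rebuilds one `gr.Audio` player per path. A minimal self-contained sketch of just this mechanism (not part of the commit; assumes Gradio 4.x with `@gr.render`, and `fake_generate` is a stand-in for the real `predict`):

import ast

import gradio as gr


def fake_generate(prompt):
    # A real handler would synthesize audio and return the temp-file paths.
    return str(["/tmp/sample_0.wav", "/tmp/sample_1.wav"])


with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    go = gr.Button("Generate")
    result_text = gr.Textbox(visible=False)  # carries the stringified path list

    @gr.render(inputs=result_text)
    def show_audios(tmp_paths):
        # Re-runs whenever result_text changes; rebuilds the audio players.
        if tmp_paths:
            for i, path in enumerate(ast.literal_eval(tmp_paths)):
                gr.Audio(value=path, label=f"Generated Music {i}", type="filepath")

    go.click(fake_generate, inputs=prompt, outputs=result_text)

if __name__ == "__main__":
    demo.launch()

Round-tripping a list through `str()`/`ast.literal_eval` works here because the values are plain path strings; a `gr.State` holding the list directly would avoid the parse step.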
data/audio/Suri's Improv.mp3 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:077b6d42a3ee05b15c3a02d7e2aaad7841e52b005d0443ac4aa280464a9a9c96
+size 163337
data/audio/like_no_tomorrow_20sec.wav ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cabd551cfeea4bb608a010aa33118e744b55b816fc71697ffab7edb5ff350805
+size 7680088
gradio_components/image.py CHANGED
@@ -4,9 +4,11 @@ import os
 
 import anthropic
 import gradio as gr
+from openai import OpenAI
+
 
 # Remember to put your API Key here
-client = anthropic.Anthropic(api_key=os.
+client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
 
 # image1_url = "https://i.abcnewsfe.com/a/7d849ccc-e0fe-4416-959d-85889e338add/dune-1-ht-bb-231212_1702405287482_hpMain_16x9.jpeg"
 image1_media_type = "image/jpeg"
@@ -20,12 +22,22 @@ The model was trained with descriptions from a stock music catalog, descriptions
 
 Try to make the prompt simple and concise with only 1-2 sentences
 
-"""
+only return a dictionary, with two items `description` and `prompt`
+
+for example
+{
+    "description": "A serene beach at sunset with gentle waves and a distant ship.",
+    "prompt": "A calming instrumental with gentle guitar, soft piano, and ocean waves sound effects, perfect for a relaxing moment by the sea."
+}
+"""
 
 SYSTEM_PROMPT_AUDIO = """You are an expert llm prompt engineer, you understand the structure of llms and facebook musicgen text to audio model. You will be provided with an image, and require to output a prompt for the musicgen model to capture the essence of the image. Try to do it step by step, evaluate and analyze the image thoroughly. After that, develop a prompt that contains the detail of what background sounds this image should have. This prompt will be provided to audiogen model to generate a 15s audio clip.
 Try to make the prompt simple and concise with only 1-2 sentences
 
+only return a dictionary, with two items `description` and `prompt`
+for example
+{"description": "A serene beach scene at sunset with gentle waves lapping on the shore and a distant ship sailing on the water.",
+ "prompt": "Gentle waves flowing on the beach at sunset, with a distant ship in the background."}
 """
 
 PROMPT_IMPROVEMENT_GENERATE_PROMPT = """
@@ -58,8 +70,39 @@ def improve_prompt(prompt):
     prompt = message_object["prompt"]
     return message_object, prompt
 
-
-def generate_caption(image_file, model_file, progress=gr.Progress()):
+def generate_caption_gpt4(image_file, model_file):
+    client = OpenAI()
+    if model_file == "facebook/audiogen-medium":
+        system_prompt = SYSTEM_PROMPT_AUDIO
+    else:
+        system_prompt = SYSTEM_PROMPT
+    with open(image_file, "rb") as f:
+        image_encoded = base64.b64encode(f.read()).decode("utf-8")
+    response = client.chat.completions.create(
+        model="gpt-4o",
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text",
+                     "text": system_prompt},
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/jpeg;base64,{image_encoded}",
+                        },
+                    },
+                ],
+            }
+        ],
+        max_tokens=300,
+    )
+    message = json.loads(response.choices[0].message.content)
+    return message['description'], message['prompt']
+
+
+def generate_caption_claude3(image_file, model_file, progress=gr.Progress()):
     if model_file == "facebook/audiogen-medium":
         system_prompt = SYSTEM_PROMPT_AUDIO
     else:
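Note that `generate_caption_gpt4` parses the model reply with `json.loads`, so it relies on GPT-4o honoring the "only return a dictionary" instruction in the system prompts; a reply wrapped in markdown fences or prose would raise a `JSONDecodeError`. A hypothetical call, assuming `OPENAI_API_KEY` is set in the environment and the example image exists locally:

# Hypothetical usage of the new GPT-4o captioner (not part of the commit).
from gradio_components.image import generate_caption_gpt4

description, prompt = generate_caption_gpt4(
    image_file="data/image/beach.jpeg",
    model_file="facebook/musicgen-large",  # any non-audiogen model selects SYSTEM_PROMPT
)
print(description)  # one-sentence summary of the scene
print(prompt)       # 1-2 sentence MusicGen prompt derived from the image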
gradio_components/model_cards.py ADDED
@@ -0,0 +1,75 @@
+import re
+
+TEXT_TO_MUSIC_MODELS = [
+    "facebook/musicgen-medium",
+    "facebook/musicgen-small",
+    "facebook/musicgen-large",
+    'facebook/magnet-small-10secs',
+    'facebook/magnet-medium-10secs',
+    'facebook/magnet-small-30secs',
+    'facebook/magnet-medium-30secs',
+    # "facebook/musicgen-stereo-small",
+    # "facebook/musicgen-stereo-medium",
+    # "facebook/musicgen-stereo-large",
+]
+
+TEXT_TO_MIDI_MODELS = [
+    "musiclang/musiclang-v2",
+]
+
+MELODY_CONTINUATION_MODELS = [
+    "facebook/musicgen-medium",
+    "facebook/musicgen-small",
+    "facebook/musicgen-large",
+]
+
+TEXT_TO_SOUND_MODELS = [
+    'facebook/audio-magnet-small',
+    'facebook/audio-magnet-medium',
+    "facebook/audiogen-medium",
+]
+
+MELODY_CONDITIONED_MODELS = [
+    "facebook/musicgen-melody",
+    "facebook/musicgen-melody-large",
+    # "facebook/musicgen-stereo-melody",
+    # "facebook/musicgen-stereo-melody-large",
+]
+
+STEREO_MODEL = [
+    "facebook/musicgen-stereo-small",
+    "facebook/musicgen-stereo-medium",
+    "facebook/musicgen-stereo-large",
+    "facebook/musicgen-stereo-melody",
+    "facebook/musicgen-stereo-melody-large",
+]
+
+
+MODEL_CARDS = {
+    "text-to-music": TEXT_TO_MUSIC_MODELS,
+    "text-to-midi": TEXT_TO_MIDI_MODELS,
+    "text-to-sound": TEXT_TO_SOUND_MODELS,
+    "melody-conditioned": MELODY_CONDITIONED_MODELS,
+}
+
+MODEL_DISCLAIMERS = {
+    "facebook/musicgen-melody": "1.5B transformer decoder also supporting melody conditioning.",
+    "facebook/musicgen-medium": "1.5B transformer decoder.",
+    "facebook/musicgen-small": "300M transformer decoder.",
+    "facebook/musicgen-large": "3.3B transformer decoder.",
+    "facebook/musicgen-melody-large": "3.3B transformer decoder also supporting melody conditioning.",
+    'facebook/magnet-small-10secs': "A 300M non-autoregressive transformer generating 10-second music samples conditioned on text.",
+    'facebook/magnet-medium-10secs': "A 1.5B non-autoregressive transformer; 10-second music samples.",
+    'facebook/magnet-small-30secs': "A 300M non-autoregressive transformer; 30-second music samples.",
+    'facebook/magnet-medium-30secs': "A 1.5B non-autoregressive transformer; 30-second music samples.",
+    # "musiclang/musiclang-v2": "This model generates music from text prompts.",  TODO: Implement MusicLang
+    'facebook/audio-magnet-small': "A 300M non-autoregressive transformer generating 10-second sound effects conditioned on text.",
+    'facebook/audio-magnet-medium': "A 1.5B non-autoregressive transformer; 10-second sound effects.",
+    "facebook/audiogen-medium": "1.5B transformer decoder capable of generating sound effects conditioned on text.",
+}
+
+
+def print_model_cards():
+    for key, value in MODEL_CARDS.items():
+        print(key, ":", value)
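These tables are what the tabs in app.py build their dropdowns from; a small sketch of consuming them directly (not part of the commit), pairing each model name with its `MODEL_DISCLAIMERS` blurb as helper text:

from gradio_components.model_cards import MODEL_CARDS, MODEL_DISCLAIMERS, TEXT_TO_MUSIC_MODELS

assert "facebook/musicgen-large" in TEXT_TO_MUSIC_MODELS

for category, models in MODEL_CARDS.items():
    print(category)
    for name in models:
        # .get() covers entries with no disclaimer, e.g. musiclang/musiclang-v2
        print(" ", name, "-", MODEL_DISCLAIMERS.get(name, "no description"))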
gradio_components/prediction.py CHANGED
@@ -8,77 +8,177 @@ import gradio as gr
 import torch
 from audiocraft.data.audio import audio_write
 from audiocraft.data.audio_utils import convert_audio
-from audiocraft.models import AudioGen, MusicGen
+from audiocraft.models import AudioGen, MusicGen, MAGNeT
 from basic_pitch import ICASSP_2022_MODEL_PATH
-from transformers import AutoModelForSeq2SeqLM
+# from transformers import AutoModelForSeq2SeqLM
+from concurrent.futures import ProcessPoolExecutor
+import typing as tp
+import warnings
+import json
+import ast
+import torchaudio
 
+MODEL = None
 
-def load_model(version=
+def load_model(version='facebook/musicgen-large'):
+    global MODEL
+    if MODEL is None or MODEL.name != version:
+        del MODEL
+        MODEL = None  # in case loading would crash
+        print("Loading model", version)
+        if "magnet" in version:
+            MODEL = MAGNeT.get_pretrained(version)
+        elif "musicgen" in version:
+            MODEL = MusicGen.get_pretrained(version)
+        elif "musiclang" in version:
+            # TODO: Implement MusicLang
+            pass
+        elif "audiogen" in version:
+            MODEL = AudioGen.get_pretrained(version)
+        else:
+            raise ValueError("Invalid model version")
+
+    return MODEL
+
+
+pool = ProcessPoolExecutor(4)
+
+
+class FileCleaner:
+    def __init__(self, file_lifetime: float = 3600):
+        self.file_lifetime = file_lifetime
+        self.files = []
+
+    def add(self, path: tp.Union[str, Path]):
+        self._cleanup()
+        self.files.append((time.time(), Path(path)))
+
+    def _cleanup(self):
+        now = time.time()
+        for time_added, path in list(self.files):
+            if now - time_added > self.file_lifetime:
+                if path.exists():
+                    path.unlink()
+                self.files.pop(0)
+            else:
+                break
+
+
+file_cleaner = FileCleaner()
+
+
+def inference_musicgen_text_to_music(model, configs, text, num_outputs=1):
+    model.set_generation_params(**configs)
+    descriptions = [text for _ in range(num_outputs)]
+    output = model.generate(descriptions=descriptions, progress=True, return_tokens=False)
+    return output
+
+
+def inference_musicgen_continuation(model, configs, text, prompt_waveform, prompt_sr, num_outputs=1):
+    model.set_generation_params(**configs)
+    # melody, prompt_sr = torchaudio.load(prompt_waveform)
+    # descriptions = [text for _ in range(num_outputs)]
+    # prompt = [prompt_waveform for _ in range(num_outputs)]
+    output = model.generate_continuation(prompt_waveform, prompt_sample_rate=prompt_sr, progress=True, return_tokens=False)
+    return output
+
+
+def inference_musicgen_melody_condition(model, configs, text, prompt_waveform, prompt_sr, num_outputs=1):
+    model.set_generation_params(**configs)
+    descriptions = [text for _ in range(num_outputs)]
+    output = model.generate_with_chroma(
+        descriptions=descriptions,
+        melody_wavs=prompt_waveform,
+        melody_sample_rate=prompt_sr,
+        progress=True,
+        return_tokens=False
+    )
+    return output
+
+
+def inference_magnet(model, configs, text, num_outputs=1):
+    model.set_generation_params(**configs)
+    descriptions = [text for _ in range(num_outputs)]
+    output = model.generate(descriptions=descriptions, progress=True, return_tokens=False)
+    return output
+
+
+def inference_magnet_audio(model, configs, text, num_outputs=1):
+    model.set_generation_params(**configs)
+    descriptions = [text for _ in range(num_outputs)]
+    output = model.generate(descriptions=descriptions, progress=True, return_tokens=False)
+    return output
+
+
+def inference_audiogen(model, configs, text, num_outputs=1):
+    model.set_generation_params(**configs)
+    descriptions = [text for _ in range(num_outputs)]
+    output = model.generate(descriptions=descriptions, progress=True, return_tokens=False)
+    return output
+
+
+def inference_musiclang():
+    # TODO: Implement MusicLang
+    pass
+
+
+def process_audio(gr_audio, prompt_duration, model):
+    # audio, sr = torch.from_numpy(gr_audio[1]).to(model.device).float().t(), gr_audio[0]
+    audio, sr = torchaudio.load(gr_audio)
+    audio = audio[..., :int(prompt_duration * sr)]
+    return audio, sr
+
+
+_MODEL_INFERENCES = {
+    "facebook/musicgen-small": inference_musicgen_text_to_music,
+    "facebook/musicgen-medium": inference_musicgen_text_to_music,
+    "facebook/musicgen-large": inference_musicgen_text_to_music,
+    "facebook/musicgen-melody": inference_musicgen_melody_condition,
+    "facebook/musicgen-melody-large": inference_musicgen_melody_condition,
+    "facebook/magnet-small-10secs": inference_magnet,
+    "facebook/magnet-medium-10secs": inference_magnet,
+    "facebook/magnet-small-30secs": inference_magnet,
+    "facebook/magnet-medium-30secs": inference_magnet,
+    "facebook/audio-magnet-small": inference_magnet_audio,
+    "facebook/audio-magnet-medium": inference_magnet_audio,
+    "facebook/audiogen-medium": inference_audiogen,
+    "musicgen-continuation": inference_musicgen_continuation,
+}
 
 def _do_predictions(
     model_file,
     model,
-    texts,
-    melodies,
-    duration,
+    text,
+    melody=None,
+    mel_sample_rate=None,
     progress=False,
-    target_sr=32000,
-    target_ac=1,
+    num_generations=1,
     **gen_kwargs,
 ):
     print(
-        "new batch",
-        len(texts),
-        texts,
-        [None if m is None else (m[0], m[1].shape) for m in melodies],
+        "new generation",
+        text,
+        None if melody is None else melody.shape
     )
     be = time.time()
-    processed_melodies = []
-    model.set_generation_params(duration=duration)
-    for melody in melodies:
-        if melody is None:
-            processed_melodies.append(None)
-        else:
-            sr, melody = (
-                melody[0],
-                torch.from_numpy(melody[1]).to(model.device).float().t(),
-            )
-            print(f"Input audio sample rate is {sr}")
-            if melody.dim() == 1:
-                melody = melody[None]
-            melody = melody[..., : int(sr * duration)]
-            melody = convert_audio(melody, sr, target_sr, target_ac)
-            processed_melodies.append(melody)
-
     try:
-        if any(m is not None for m in processed_melodies):
-            # melody condition
-            outputs = model.generate_with_chroma(
-                melody_sample_rate=target_sr,
-                progress=progress,
-                return_tokens=False,
-            )
-        else:
-            if model_file == "facebook/audiogen-medium":
-                # audio condition
-                outputs = model.generate(texts, progress=progress)
-            else:
-                # text condition
+        if melody is not None:
+            # melody condition or continuation
+            if 'melody' in model_file:
+                # melody condition - musicgen-melody, musicgen-melody-large
+                inference_func = _MODEL_INFERENCES[model_file]
+            else:
+                # melody continuation
+                inference_func = _MODEL_INFERENCES['musicgen-continuation']
+            outputs = inference_func(model, gen_kwargs, text, melody, mel_sample_rate, num_generations)
+        else:
+            # text-to-music, text-to-sound
+            inference_func = _MODEL_INFERENCES[model_file]
+            outputs = inference_func(model, gen_kwargs, text, num_generations)
 
     except RuntimeError as e:
         raise gr.Error("Error while generating " + e.args[0])
     outputs = outputs.detach().cpu().float()
+    out_audios = []
+    video_processes = []
     for output in outputs:
         with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
            audio_write(
@@ -90,45 +190,36 @@ def _do_predictions(
                 loudness_compressor=True,
                 add_suffix=False,
             )
+            # video_processes.append(pool.submit(make_waveform, file.name))
+            out_audios.append(file.name)
+            file_cleaner.add(file.name)
+    # out_videos = [video.result() for video in video_processes]
+    # for video in out_videos:
+    #     file_cleaner.add(video)
+
+    print("generation finished", len(outputs), time.time() - be)
+    return out_audios
 
+def make_waveform(*args, **kwargs):
+    # Further remove some warnings.
+    be = time.time()
+    with warnings.catch_warnings():
+        warnings.simplefilter('ignore')
+        out = gr.make_waveform(*args, **kwargs)
+        print("Make a video took", time.time() - be)
+        return out
+
 def predict(
-    model_path,
-    topk,
-    topp,
-    temperature,
-    target_sr,
+    model_version,
+    generation_configs,
+    prompt_text=None,
+    prompt_wav=None,
+    num_generations=1,
     progress=gr.Progress(),
 ):
     global INTERRUPTING
-    global USE_DIFFUSION
     INTERRUPTING = False
     progress(0, desc="Loading model...")
-    model_path = model_path.strip()
-    # if model_path:
-    #     if not Path(model_path).exists():
-    #         raise gr.Error(f"Model path {model_path} doesn't exist.")
-    #     if not Path(model_path).is_dir():
-    #         raise gr.Error(f"Model path {model_path} must be a folder containing "
-    #                        "state_dict.bin and compression_state_dict_.bin.")
-    if temperature < 0:
-        raise gr.Error("Temperature must be >= 0.")
-    if topk < 0:
-        raise gr.Error("Topk must be non-negative.")
-    if topp < 0:
-        raise gr.Error("Topp must be non-negative.")
-
-    topk = int(topk)
-    model = load_model(model_path)
-
-    max_generated = 0
-
     def _progress(generated, to_generate):
         nonlocal max_generated
         max_generated = max(generated, max_generated)
@@ -136,40 +227,56 @@ def predict(
         if INTERRUPTING:
             raise gr.Error("Interrupted.")
 
+    model = load_model(model_version)
     model.set_custom_progress_callback(_progress)
+    if isinstance(generation_configs, str):
+        generation_configs = ast.literal_eval(generation_configs)
+    max_generated = 0
+    if prompt_wav is not None:
+        melody, mel_sample_rate = process_audio(prompt_wav, generation_configs['duration'], model)
+    else:
+        melody, mel_sample_rate = None, None
 
+    audios = _do_predictions(
+        model_version,
         model,
+        prompt_text,
+        melody,
+        mel_sample_rate,
         progress=True,
-        top_k=topk,
-        top_p=topp,
-        temperature=temperature,
-        gradio_progress=progress,
+        num_generations=num_generations,
+        **generation_configs,
     )
-    return
+    return audios
 
 
 def transcribe(audio_path):
+    """
+    Transcribe an audio file to MIDI using the basic_pitch model.
+    """
     # model_output, midi_data, note_events = predict("generated_0.wav")
+    tmp_paths = ast.literal_eval(audio_path)
+    download_buttons = []
+    for audio_path in tmp_paths:
+        model_output, midi_data, note_events = basic_pitch.inference.predict(
+            audio_path=audio_path,
+            model_or_model_path=ICASSP_2022_MODEL_PATH,
+        )
 
-    with NamedTemporaryFile("wb", suffix=".mid", delete=False) as file:
-        try:
-            midi_data.write(file)
-            print(f"midi file saved to {file.name}")
-        except Exception as e:
-            print(f"Error while writing midi file: {e}")
-            raise e
+        with NamedTemporaryFile("wb", suffix=".mid", delete=False) as file:
+            try:
+                midi_data.write(file)
+                print(f"midi file saved to {file.name}")
+            except Exception as e:
+                print(f"Error while writing midi file: {e}")
+                raise e
+            download_buttons.append(gr.DownloadButton(
+                value=file.name, label=f"Download MIDI file {file.name}", visible=True
+            ))
+            file_cleaner.add(file.name)
 
-    return gr.DownloadButton(
-        value=file.name, label=f"Download MIDI file {file.name}", visible=True
-    )
+    return download_buttons
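The new predict() keeps a single cached model in the module-level MODEL global (reloading only when the requested version changes) and accepts generation_configs either as a dict or, when wired to the Gradio Textbox, as its string repr. A rough sketch of the call contract (not part of the commit; in the app this runs inside a Gradio event so that gr.Progress is live, and calling it standalone is purely for illustration):

# Hypothetical direct call against the new API; model download happens on first use.
from gradio_components.prediction import predict

paths = predict(
    model_version="facebook/musicgen-small",  # smallest model keeps the demo cheap
    generation_configs={"duration": 10, "use_sampling": True,
                        "top_k": 300, "top_p": 0, "temperature": 1.0},
    prompt_text="lo-fi piano over soft rain",
    prompt_wav=None,      # no melody: plain text-to-music path
    num_generations=1,
)
print(paths)  # list of temp .wav paths, e.g. ['/tmp/tmpab12cd.wav']

app.py then passes str(paths) through a Textbox into its @gr.render blocks, and FileCleaner deletes each temp file roughly an hour after creation.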