suric committed on
Commit
6a24aec
·
1 Parent(s): da26cf8

update image-to-music tab

Browse files
.gitattributes CHANGED
@@ -38,3 +38,4 @@ data/audio/turkish_march_mozart.mp3 filter=lfs diff=lfs merge=lfs -text
38
  data/audio/twinkle_twinkle_little_stars_mozart.mp3 filter=lfs diff=lfs merge=lfs -text
39
  *.mp3 filter=lfs diff=lfs merge=lfs -text
40
  *.mp3 !text !filter !merge !diff
 
 
38
  data/audio/twinkle_twinkle_little_stars_mozart.mp3 filter=lfs diff=lfs merge=lfs -text
39
  *.mp3 filter=lfs diff=lfs merge=lfs -text
40
  *.mp3 !text !filter !merge !diff
41
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -2,217 +2,458 @@ import os
2
 
3
  import gradio as gr
4
 
 
5
  from gradio_components.prediction import predict, transcribe
6
 
7
  theme = gr.themes.Glass(
8
- primary_hue="fuchsia",
9
- secondary_hue="indigo",
10
- neutral_hue="slate",
11
- font=[gr.themes.GoogleFont('Source Sans Pro'), 'ui-sans-serif', 'system-ui',
12
- 'sans-serif'],
13
- ).set(
14
- body_background_fill_dark='*background_fill_primary',
15
- embed_radius='*table_radius',
16
- background_fill_primary='*neutral_50',
17
- background_fill_primary_dark='*neutral_950',
18
- background_fill_secondary_dark='*neutral_900',
19
- border_color_accent='*neutral_600',
20
- border_color_accent_subdued='*color_accent',
21
- border_color_primary_dark='*neutral_700',
22
- block_background_fill='*background_fill_primary',
23
- block_background_fill_dark='*neutral_800',
24
- block_border_width='1px',
25
- block_label_background_fill='*background_fill_primary',
26
- block_label_background_fill_dark='*background_fill_secondary',
27
- block_label_text_color='*neutral_500',
28
- block_label_text_size='*text_sm',
29
- block_label_text_weight='400',
30
- block_shadow='none',
31
- block_shadow_dark='none',
32
- block_title_text_color='*neutral_500',
33
- block_title_text_weight='400',
34
- panel_border_width='0',
35
- panel_border_width_dark='0',
36
- checkbox_background_color_dark='*neutral_800',
37
- checkbox_border_width='*input_border_width',
38
- checkbox_label_border_width='*input_border_width',
39
- input_background_fill='*neutral_100',
40
- input_background_fill_dark='*neutral_700',
41
- input_border_color_focus_dark='*neutral_700',
42
- input_border_width='0px',
43
- input_border_width_dark='0px',
44
- slider_color='#2563eb',
45
- slider_color_dark='#2563eb',
46
- table_even_background_fill_dark='*neutral_950',
47
- table_odd_background_fill_dark='*neutral_900',
48
- button_border_width='*input_border_width',
49
- button_shadow_active='none',
50
- button_primary_background_fill='*primary_200',
51
- button_primary_background_fill_dark='*primary_700',
52
- button_primary_background_fill_hover='*button_primary_background_fill',
53
- button_primary_background_fill_hover_dark='*button_primary_background_fill',
54
- button_secondary_background_fill='*neutral_200',
55
- button_secondary_background_fill_dark='*neutral_600',
56
- button_secondary_background_fill_hover='*button_secondary_background_fill',
57
- button_secondary_background_fill_hover_dark='*button_secondary_background_fill',
58
- button_cancel_background_fill='*button_secondary_background_fill',
59
- button_cancel_background_fill_dark='*button_secondary_background_fill',
60
- button_cancel_background_fill_hover='*button_cancel_background_fill',
61
- button_cancel_background_fill_hover_dark='*button_cancel_background_fill'
62
- )
63
-
64
- _AUDIOCRAFT_MODELS = ["facebook/musicgen-melody",
65
- "facebook/musicgen-medium",
66
- "facebook/musicgen-small",
67
- "facebook/musicgen-large",
68
- "facebook/musicgen-melody-large"]
 
 
 
 
 
 
 
 
69
 
70
 
71
  def generate_prompt(difficulty, style):
72
- _DIFFICULTY_MAPPIN = {
73
- "Easy": "beginner player",
74
- "Medum": "player who has 2-3 years experience",
75
- "Hard": "player who has more than 4 years experiences"
76
- }
77
- prompt = 'piano only music for a {} to pratice with the touch of {}'.format(
78
- _DIFFICULTY_MAPPIN[difficulty], style
79
- )
80
- return prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
 
83
  def UI():
84
- with gr.Blocks() as demo:
85
- with gr.Tab("Generate Music by melody"):
86
- with gr.Row():
87
- with gr.Column():
88
- with gr.Row():
89
- model_path = gr.Dropdown(
90
- choices=_AUDIOCRAFT_MODELS,
91
- label="Select the model",
92
- value="facebook/musicgen-melody-large"
93
- )
94
- with gr.Row():
95
- duration = gr.Slider(
96
- minimum=10,
97
- maximum=60,
98
- value=10,
99
- label="Duration",
100
- interactive=True
101
- )
102
- with gr.Row():
103
- topk = gr.Number(label="Top-k", value=250, interactive=True)
104
- topp = gr.Number(label="Top-p", value=0, interactive=True)
105
- temperature = gr.Number(
106
- label="Temperature", value=1.0, interactive=True
107
- )
108
- sample_rate = gr.Number(
109
- label="output music sample rate", value=32000,
110
- interactive=True
111
- )
112
- difficulty = gr.Radio(
113
- ["Easy", "Medium", "Hard"], label="Difficulty",
114
- value="Easy", interactive=True
115
- )
116
- style = gr.Radio(
117
- ["Jazz", "Classical Music", "Hip Hop", "Others"],
118
- value="Classical Music", label="music genre",
119
- interactive=True
120
- )
121
- if style == "Others":
122
- style = gr.Textbox(label="Type your music genre")
123
- prompt = generate_prompt(difficulty.value, style.value)
124
- customize = gr.Checkbox(
125
- label="Customize the prompt", interactive=True
126
- )
127
- if customize:
128
- prompt = gr.Textbox(label="Type your prompt")
129
- with gr.Column():
130
- with gr.Row():
131
- melody = gr.Audio(
132
- sources=["microphone", "upload"],
133
- label="Record or upload your audio",
134
- #interactive=True,
135
- show_label=True,
136
- )
137
- with gr.Row():
138
- submit = gr.Button("Generate Music")
139
- output_audio = gr.Audio("listen to the generated music", type="filepath")
140
- with gr.Row():
141
- transcribe_button = gr.Button("Transcribe")
142
- d = gr.DownloadButton("Download the file", visible=False)
143
- transcribe_button.click(transcribe, inputs=[output_audio], outputs=d)
144
-
145
- submit.click(
146
- fn=predict,
147
- inputs=[model_path, prompt, melody, duration, topk, topp, temperature,
148
- sample_rate],
149
- outputs=output_audio
150
- )
151
-
152
- gr.Examples(
153
- examples=[
154
- [
155
- os.path.join(
156
- os.path.dirname(__file__),
157
- "./data/audio/twinkle_twinkle_little_stars_mozart_20sec.mp3"
158
- ),
159
- "Easy",
160
- 32000,
161
- 20
162
- ],
163
- [
164
- os.path.join(
165
- os.path.dirname(__file__),
166
- "./data/audio/golden_hour_20sec.mp3"
167
- ),
168
- "Easy",
169
- 32000,
170
- 20
171
- ],
172
- [
173
- os.path.join(
174
- os.path.dirname(__file__),
175
- "./data/audio/turkish_march_mozart_20sec.mp3"
176
- ),
177
- "Easy",
178
- 32000,
179
- 20
180
- ],
181
- [
182
- os.path.join(
183
- os.path.dirname(__file__),
184
- "./data/audio/golden_hour_20sec.mp3"
185
- ),
186
- "Hard",
187
- 32000,
188
- 20
189
- ],
190
- [
191
- os.path.join(
192
- os.path.dirname(__file__),
193
- "./data/audio/golden_hour_20sec.mp3"
194
- ),
195
- "Hard",
196
- 32000,
197
- 40
198
- ],
199
- [
200
- os.path.join(
201
- os.path.dirname(__file__),
202
- "./data/audio/golden_hour_20sec.mp3"
203
- ),
204
- "Hard",
205
- 16000,
206
- 20
207
- ],
208
- ],
209
- inputs=[melody, difficulty, sample_rate, duration],
210
- label="Audio Examples",
211
- outputs=[output_audio],
212
- # cache_examples=True,
213
- )
214
- demo.queue().launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
 
216
 
217
  if __name__ == "__main__":
218
- UI()
 
2
 
3
  import gradio as gr
4
 
5
+ from gradio_components.image import generate_caption
6
  from gradio_components.prediction import predict, transcribe
7
 
8
  theme = gr.themes.Glass(
9
+ primary_hue="fuchsia",
10
+ secondary_hue="indigo",
11
+ neutral_hue="slate",
12
+ font=[
13
+ gr.themes.GoogleFont("Source Sans Pro"),
14
+ "ui-sans-serif",
15
+ "system-ui",
16
+ "sans-serif",
17
+ ],
18
+ ).set(
19
+ body_background_fill_dark="*background_fill_primary",
20
+ embed_radius="*table_radius",
21
+ background_fill_primary="*neutral_50",
22
+ background_fill_primary_dark="*neutral_950",
23
+ background_fill_secondary_dark="*neutral_900",
24
+ border_color_accent="*neutral_600",
25
+ border_color_accent_subdued="*color_accent",
26
+ border_color_primary_dark="*neutral_700",
27
+ block_background_fill="*background_fill_primary",
28
+ block_background_fill_dark="*neutral_800",
29
+ block_border_width="1px",
30
+ block_label_background_fill="*background_fill_primary",
31
+ block_label_background_fill_dark="*background_fill_secondary",
32
+ block_label_text_color="*neutral_500",
33
+ block_label_text_size="*text_sm",
34
+ block_label_text_weight="400",
35
+ block_shadow="none",
36
+ block_shadow_dark="none",
37
+ block_title_text_color="*neutral_500",
38
+ block_title_text_weight="400",
39
+ panel_border_width="0",
40
+ panel_border_width_dark="0",
41
+ checkbox_background_color_dark="*neutral_800",
42
+ checkbox_border_width="*input_border_width",
43
+ checkbox_label_border_width="*input_border_width",
44
+ input_background_fill="*neutral_100",
45
+ input_background_fill_dark="*neutral_700",
46
+ input_border_color_focus_dark="*neutral_700",
47
+ input_border_width="0px",
48
+ input_border_width_dark="0px",
49
+ slider_color="#2563eb",
50
+ slider_color_dark="#2563eb",
51
+ table_even_background_fill_dark="*neutral_950",
52
+ table_odd_background_fill_dark="*neutral_900",
53
+ button_border_width="*input_border_width",
54
+ button_shadow_active="none",
55
+ button_primary_background_fill="*primary_200",
56
+ button_primary_background_fill_dark="*primary_700",
57
+ button_primary_background_fill_hover="*button_primary_background_fill",
58
+ button_primary_background_fill_hover_dark="*button_primary_background_fill",
59
+ button_secondary_background_fill="*neutral_200",
60
+ button_secondary_background_fill_dark="*neutral_600",
61
+ button_secondary_background_fill_hover="*button_secondary_background_fill",
62
+ button_secondary_background_fill_hover_dark="*button_secondary_background_fill",
63
+ button_cancel_background_fill="*button_secondary_background_fill",
64
+ button_cancel_background_fill_dark="*button_secondary_background_fill",
65
+ button_cancel_background_fill_hover="*button_cancel_background_fill",
66
+ button_cancel_background_fill_hover_dark="*button_cancel_background_fill",
67
+ )
68
+
69
+
70
+ _AUDIOCRAFT_MODELS = [
71
+ "facebook/musicgen-melody",
72
+ "facebook/musicgen-medium",
73
+ "facebook/musicgen-small",
74
+ "facebook/musicgen-large",
75
+ "facebook/musicgen-melody-large",
76
+ "facebook/audiogen-medium",
77
+ ]
78
 
79
 
80
def generate_prompt(difficulty, style):
    """Build a MusicGen text prompt from a difficulty level and a music style.

    Args:
        difficulty: One of ``"Easy"``, ``"Medium"`` or ``"Hard"`` (the exact
            choices offered by the UI radio button).
        style: Free-form music genre/style string.

    Returns:
        A one-sentence prompt describing piano practice music for the given
        skill level and style.

    Raises:
        KeyError: If ``difficulty`` is not one of the three known levels.
    """
    # Bug fix: the key was misspelled "Medum", so selecting the "Medium"
    # choice offered by the UI raised a KeyError.  Also fixed the
    # "pratice" typo in the generated prompt text.
    _DIFFICULTY_MAPPING = {
        "Easy": "beginner player",
        "Medium": "player who has 2-3 years experience",
        "Hard": "player who has more than 4 years experiences",
    }
    return "piano only music for a {} to practice with the touch of {}".format(
        _DIFFICULTY_MAPPING[difficulty], style
    )
90
+
91
+
92
def toggle_melody_condition(melody_condition):
    """Show or hide the melody-recording audio widget.

    Args:
        melody_condition: Checkbox state; ``True`` when the user wants to
            condition generation on a recorded/uploaded melody.

    Returns:
        A ``gr.Audio`` component whose visibility mirrors the checkbox.
    """
    # The two branches of the original differed only in `visible`, so they
    # are collapsed into a single parameterized construction.
    return gr.Audio(
        sources=["microphone", "upload"],
        label="Record or upload your audio",
        show_label=True,
        visible=bool(melody_condition),
    )
107
+
108
+
109
def show_caption(show_caption_condition, description, prompt):
    """Toggle visibility of the image-caption and generated-prompt textboxes.

    Args:
        show_caption_condition: Checkbox state; ``True`` to reveal the boxes.
        description: Current image caption text to keep in the caption box.
        prompt: Current generated prompt text to keep in the prompt box.

    Returns:
        Tuple of ``(caption_textbox, prompt_textbox, generate_button)``
        updates, matching the ``outputs=[description, prompt, generate]``
        wiring in ``UI``.
    """
    visible = bool(show_caption_condition)
    caption_box = gr.Textbox(
        label="Image Caption",
        value=description,
        interactive=False,
        show_label=True,
        visible=visible,
    )
    prompt_box = gr.Textbox(
        label="Generated Prompt",
        value=prompt,
        interactive=True,
        show_label=True,
        visible=visible,
    )
    # Bug fix: the hidden branch called gr.Button(label=...), but gr.Button
    # takes its caption as the positional `value` argument and has no
    # `label` parameter, so unchecking the box raised a TypeError.  The
    # button is identical in both branches, so build it once.
    generate_button = gr.Button("Generate Music", interactive=True, visible=True)
    return caption_box, prompt_box, generate_button
146
+
147
+
148
def post_submit(show_caption, image_input):
    """Caption the uploaded image and populate the prompt controls.

    NOTE(review): the ``show_caption`` parameter shadows the module-level
    ``show_caption`` function; here it is only a visibility flag fed from
    the "Show the prompt" checkbox.

    Args:
        show_caption: Whether the caption/prompt textboxes should be visible.
        image_input: Path of the uploaded image file.

    Returns:
        Tuple of ``(caption_textbox, prompt_textbox, generate_button)``
        updates for the ``outputs=[description, prompt, generate]`` wiring.
    """
    _, description, prompt = generate_caption(image_input)
    shared = {"show_label": True, "visible": show_caption}
    caption_box = gr.Textbox(
        label="Image Caption",
        value=description,
        interactive=False,
        **shared,
    )
    prompt_box = gr.Textbox(
        label="Generated Prompt",
        value=prompt,
        interactive=True,
        **shared,
    )
    generate_button = gr.Button("Generate Music", interactive=True, visible=True)
    return caption_box, prompt_box, generate_button
167
 
168
 
169
def UI():
    """Build and launch the Gradio app.

    Two tabs: "Generate Music by melody" (text/melody-conditioned MusicGen)
    and "Generate Music by image" (Claude image caption -> MusicGen prompt).
    This is a blocking call: it ends with ``demo.queue().launch()``.
    """
    with gr.Blocks() as demo:
        with gr.Tab("Generate Music by melody"):
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        model_path = gr.Dropdown(
                            choices=_AUDIOCRAFT_MODELS,
                            label="Select the model",
                            value="facebook/musicgen-melody-large",
                        )
                    with gr.Row():
                        # Output clip length in seconds.
                        duration = gr.Slider(
                            minimum=10,
                            maximum=60,
                            value=10,
                            label="Duration",
                            interactive=True,
                        )
                    with gr.Row():
                        # Sampling hyper-parameters forwarded verbatim to `predict`.
                        topk = gr.Number(label="Top-k", value=250, interactive=True)
                        topp = gr.Number(label="Top-p", value=0, interactive=True)
                        temperature = gr.Number(
                            label="Temperature", value=1.0, interactive=True
                        )
                        sample_rate = gr.Number(
                            label="output music sample rate",
                            value=32000,
                            interactive=True,
                        )
                        difficulty = gr.Radio(
                            ["Easy", "Medium", "Hard"],
                            label="Difficulty",
                            value="Easy",
                            interactive=True,
                        )
                        style = gr.Radio(
                            ["Jazz", "Classical Music", "Hip Hop", "Others"],
                            value="Classical Music",
                            label="music genre",
                            interactive=True,
                        )
                        # NOTE(review): `style` is a gr.Radio component object
                        # here, so comparing it with a string is always False
                        # and the Textbox branch below is dead code -- a
                        # .change() event handler is probably what was intended.
                        if style == "Others":
                            style = gr.Textbox(label="Type your music genre")
                        # Prompt is built once at UI-construction time from the
                        # radio defaults; it does not track later UI changes.
                        prompt = generate_prompt(difficulty.value, style.value)
                        customize = gr.Checkbox(
                            label="Customize the prompt", interactive=True
                        )
                        # NOTE(review): `customize` is the component object, not
                        # its runtime value, so this branch is decided at build
                        # time rather than per-interaction -- confirm whether a
                        # .change() handler was intended here as well.
                        if customize:
                            prompt = gr.Textbox(label="Type your prompt")
                with gr.Column():
                    with gr.Row():
                        melody = gr.Audio(
                            sources=["microphone", "upload"],
                            label="Record or upload your audio",
                            # interactive=True,
                            show_label=True,
                        )
                    with gr.Row():
                        submit = gr.Button("Generate Music")
                        # NOTE(review): the first positional argument of
                        # gr.Audio is `value`, not a label -- presumably a
                        # label was intended; verify in the rendered UI.
                        output_audio = gr.Audio(
                            "listen to the generated music", type="filepath"
                        )
                    with gr.Row():
                        transcribe_button = gr.Button("Transcribe")
                        # Hidden until `transcribe` returns a MIDI download.
                        d = gr.DownloadButton("Download the file", visible=False)
                        transcribe_button.click(
                            transcribe, inputs=[output_audio], outputs=d
                        )

            submit.click(
                fn=predict,
                inputs=[
                    model_path,
                    prompt,
                    melody,
                    duration,
                    topk,
                    topp,
                    temperature,
                    sample_rate,
                ],
                outputs=output_audio,
            )
            gr.Examples(
                examples=[
                    [
                        os.path.join(
                            os.path.dirname(__file__),
                            "./data/audio/twinkle_twinkle_little_stars_mozart_20sec"
                            ".mp3",
                        ),
                        "Easy",
                        32000,
                        20,
                    ],
                    [
                        os.path.join(
                            os.path.dirname(__file__),
                            "./data/audio/golden_hour_20sec.mp3",
                        ),
                        "Easy",
                        32000,
                        20,
                    ],
                    [
                        os.path.join(
                            os.path.dirname(__file__),
                            "./data/audio/turkish_march_mozart_20sec.mp3",
                        ),
                        "Easy",
                        32000,
                        20,
                    ],
                    [
                        os.path.join(
                            os.path.dirname(__file__),
                            "./data/audio/golden_hour_20sec.mp3",
                        ),
                        "Hard",
                        32000,
                        20,
                    ],
                    [
                        os.path.join(
                            os.path.dirname(__file__),
                            "./data/audio/golden_hour_20sec.mp3",
                        ),
                        "Hard",
                        32000,
                        40,
                    ],
                    [
                        os.path.join(
                            os.path.dirname(__file__),
                            "./data/audio/golden_hour_20sec.mp3",
                        ),
                        "Hard",
                        16000,
                        20,
                    ],
                ],
                inputs=[melody, difficulty, sample_rate, duration],
                label="Audio Examples",
                outputs=[output_audio],
                # cache_examples=True,
            )

        with gr.Tab("Generate Music by image"):
            with gr.Row():
                with gr.Column():
                    # NOTE(review): first positional argument of gr.Image is
                    # `value`, not a label -- verify in the rendered UI.
                    image_input = gr.Image("Upload an image", type="filepath")
                    melody_condition = gr.Checkbox(
                        label="Generate music by melody", interactive=True, value=False
                    )
                    # Hidden until the melody checkbox is ticked (see the
                    # .change() wiring just below).
                    melody = gr.Audio(
                        sources=["microphone", "upload"],
                        label="Record or upload your audio",
                        show_label=True,
                        visible=False,
                    )
                    melody_condition.change(
                        fn=toggle_melody_condition,
                        inputs=[melody_condition],
                        outputs=melody,
                    )
                    # Caption/prompt boxes start hidden; `post_submit` and
                    # `show_caption` toggle them.
                    description = gr.Textbox(
                        label="Image Captioning",
                        show_label=True,
                        interactive=False,
                        visible=False,
                    )
                    prompt = gr.Textbox(
                        label="Generated Prompt",
                        show_label=True,
                        interactive=True,
                        visible=False,
                    )
                    show_prompt = gr.Checkbox(label="Show the prompt", interactive=True)
                    submit = gr.Button("submit", interactive=True, visible=True)
                    generate = gr.Button(
                        "Generate Music", interactive=True, visible=False
                    )
                    submit.click(
                        fn=post_submit,
                        inputs=[show_prompt, image_input],
                        outputs=[description, prompt, generate],
                    )
                    show_prompt.change(
                        fn=show_caption,
                        inputs=[show_prompt, description, prompt],
                        outputs=[description, prompt, generate],
                    )

                with gr.Column():
                    with gr.Row():
                        model_path = gr.Dropdown(
                            choices=_AUDIOCRAFT_MODELS,
                            label="Select the model",
                            value="facebook/musicgen-large",
                        )
                    with gr.Row():
                        duration = gr.Slider(
                            minimum=10,
                            maximum=60,
                            value=10,
                            label="Duration",
                            interactive=True,
                        )
                        topk = gr.Number(label="Top-k", value=250, interactive=True)
                        topp = gr.Number(label="Top-p", value=0, interactive=True)
                        temperature = gr.Number(
                            label="Temperature", value=1.0, interactive=True
                        )
                        sample_rate = gr.Number(
                            label="output music sample rate", value=32000, interactive=True
                        )
                    with gr.Column():
                        output_audio = gr.Audio(
                            "listen to the generated music",
                            type="filepath",
                            show_label=True,
                        )
                        transcribe_button = gr.Button("Transcribe")
                        d = gr.DownloadButton("Download the file", visible=False)
                        transcribe_button.click(transcribe, inputs=[output_audio], outputs=d)
                        generate.click(
                            fn=predict,
                            inputs=[
                                model_path,
                                prompt,
                                melody,
                                duration,
                                topk,
                                topp,
                                temperature,
                                sample_rate,
                            ],
                            outputs=output_audio,
                        )

            gr.Examples(
                examples=[
                    [
                        os.path.join(
                            os.path.dirname(__file__),
                            "./data/image/kids_drawing.jpeg",
                        ),
                        False,
                        None,
                        "facebook/musicgen-large",
                    ],
                    [
                        os.path.join(
                            os.path.dirname(__file__),
                            "./data/image/cat.jpeg",
                        ),
                        False,
                        None,
                        "facebook/musicgen-large",
                    ],
                    [
                        os.path.join(
                            os.path.dirname(__file__),
                            "./data/image/cat.jpeg",
                        ),
                        True,
                        "./data/audio/the_nutcracker_dance_of_the_reed_flutes.mp3",
                        "facebook/musicgen-melody-large",
                    ],
                    [
                        os.path.join(
                            os.path.dirname(__file__),
                            "./data/image/beach.jpeg",
                        ),
                        False,
                        None,
                        "facebook/audiogen-medium",
                    ],
                ],
                inputs=[image_input, melody_condition, melody, model_path],
                label="Audio Examples",
                outputs=[output_audio],
                # cache_examples=True,
            )

    demo.queue().launch()
456
 
457
 
458
  if __name__ == "__main__":
459
+ UI()
data/audio/the_nutcracker_dance_of_the_reed_flutes.mp3 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aa933b78b2380d325e6436d91de191c7abcb82e4c62ef2ed52a04868233a5012
3
+ size 3577581
data/image/.DS_Store ADDED
Binary file (6.15 kB). View file
 
data/image/beach.jpeg ADDED

Git LFS Details

  • SHA256: 1b742f5752fcab31147a6c213e3b60f56c45b35344faa0a7266ddb95944bcfa3
  • Pointer size: 130 Bytes
  • Size of remote file: 39.7 kB
data/image/cat.jpeg ADDED

Git LFS Details

  • SHA256: 5f63e517121b2e3e8b21d1cbba5e1fac9e5317da7bbc9980dbaf622cf2439518
  • Pointer size: 132 Bytes
  • Size of remote file: 2.4 MB
data/image/kids_drawing.jpeg ADDED

Git LFS Details

  • SHA256: d8802f50a0b4353fb9f76f9291dfff758a0109aa437c9b32282c80b74e471d84
  • Pointer size: 131 Bytes
  • Size of remote file: 639 kB
gradio_components/image.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import json
3
+ import os
4
+
5
+ import anthropic
6
+ import gradio as gr
7
+
8
+ # Remember to put your API Key here
9
+ client = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])
10
+
11
+ # image1_url = "https://i.abcnewsfe.com/a/7d849ccc-e0fe-4416-959d-85889e338add/dune-1-ht-bb-231212_1702405287482_hpMain_16x9.jpeg"
12
+ image1_media_type = "image/jpeg"
13
+ # image1_data = base64.b64encode(httpx.get(image1_url).content).decode("utf-8")
14
+ #
15
+
16
+ SYSTEM_PROMPT = """You are an expert llm prompt engineer, you understand the structure of llms and facebook musicgen text to audio model. You will be provided with an image, and require to output a prompt for the musicgen model to capture the essense of the image. Try to do it step by step, evaluate and analyze the image thoroughly. After that, develop a prompt that contains music genera, style, instrument, and all the other details needed. This prompt will be provided to musicgen model to generate a 15s audio clip.
17
+
18
+ Here are some descriptions from musicgen model:
19
+ The model was trained with descriptions from a stock music catalog, descriptions that will work best should include some level of detail on the instruments present, along with some intended use case (e.g. adding “perfect for a commercial” can somehow help).
20
+
21
+ Try to make the prompt simple and concise with only 1-2 sentences
22
+
23
+ Make sure the ouput is in JSON fomat, with two items `description` and `prompt`"""
24
+
25
+
26
def generate_caption(image_file, progress=gr.Progress()):
    """Caption an image with Claude and derive a MusicGen prompt from it.

    Args:
        image_file: Path to a JPEG image on disk.
        progress: Gradio progress tracker (injected by Gradio; the default
            ``gr.Progress()`` instance is the standard Gradio idiom).

    Returns:
        Tuple ``(message_object, description, prompt)`` where
        ``message_object`` is the parsed JSON reply and ``description`` /
        ``prompt`` are its two fields.

    Raises:
        json.JSONDecodeError: If the model reply is not valid JSON.
        KeyError: If the JSON lacks ``description`` or ``prompt``.
    """
    with open(image_file, "rb") as f:
        image_encoded = base64.b64encode(f.read()).decode("utf-8")
    progress(0, desc="Starting image captioning...")
    message = client.messages.create(
        model="claude-3-opus-20240229",
        max_tokens=1024,
        system=SYSTEM_PROMPT,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            # NOTE(review): hard-coded image/jpeg -- confirm
                            # uploads are always JPEG (PNG would need its own
                            # media type).
                            "media_type": image1_media_type,
                            "data": image_encoded,
                        },
                    },
                    {"type": "text", "text": "develop the prompt based on this image"},
                ],
            }
        ],
    )
    # Bug fix: gr.Progress expects a fraction in [0, 1]; the previous value
    # of 100 rendered as "10000%".
    progress(1, desc="image captioning...Done!")
    # Parse the content string into a Python object.
    # NOTE(review): assumes the model returns bare JSON with exactly the keys
    # `description` and `prompt` (as instructed by SYSTEM_PROMPT); a reply
    # wrapped in prose would raise here.
    message_object = json.loads(message.content[0].text)
    # Access the description and prompt from the message object.
    description = message_object["description"]
    prompt = message_object["prompt"]
    print(description)
    print(prompt)
    return message_object, description, prompt
gradio_components/prediction.py CHANGED
@@ -1,36 +1,53 @@
1
  import time
2
- import torch
3
-
4
- from audiocraft.data.audio_utils import convert_audio
5
- from audiocraft.data.audio import audio_write
6
- import gradio as gr
7
- from audiocraft.models import MusicGen
8
-
9
- from tempfile import NamedTemporaryFile
10
  from pathlib import Path
11
- from transformers import AutoModelForSeq2SeqLM
 
12
  import basic_pitch
13
  import basic_pitch.inference
 
 
 
 
 
14
  from basic_pitch import ICASSP_2022_MODEL_PATH
 
15
 
16
 
17
- def load_model(version='facebook/musicgen-melody'):
18
  return MusicGen.get_pretrained(version)
19
 
20
 
21
- def _do_predictions(model, texts, melodies, duration, progress=False, gradio_progress=None, target_sr=32000, target_ac = 1, **gen_kwargs):
22
- print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  be = time.time()
24
  processed_melodies = []
25
  for melody in melodies:
26
  if melody is None:
27
  processed_melodies.append(None)
28
  else:
29
- sr, melody = melody[0], torch.from_numpy(melody[1]).to(model.device).float().t()
 
 
 
30
  print(f"Input audio sample rate is {sr}")
31
  if melody.dim() == 1:
32
  melody = melody[None]
33
- melody = melody[..., :int(sr * duration)]
34
  melody = convert_audio(melody, sr, target_sr, target_ac)
35
  processed_melodies.append(melody)
36
 
@@ -42,7 +59,7 @@ def _do_predictions(model, texts, melodies, duration, progress=False, gradio_pro
42
  melody_wavs=processed_melodies,
43
  melody_sample_rate=target_sr,
44
  progress=progress,
45
- return_tokens=False
46
  )
47
  else:
48
  # text only
@@ -55,14 +72,30 @@ def _do_predictions(model, texts, melodies, duration, progress=False, gradio_pro
55
  for output in outputs:
56
  with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
57
  audio_write(
58
- file.name, output, model.sample_rate, strategy="loudness",
59
- loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
 
 
 
 
 
 
60
  out_wavs.append(file.name)
61
  print("generation finished", len(texts), time.time() - be)
62
  return out_wavs
63
 
64
 
65
- def predict(model_path, text, melody, duration, topk, topp, temperature, target_sr, progress=gr.Progress()):
 
 
 
 
 
 
 
 
 
 
66
  global INTERRUPTING
67
  global USE_DIFFUSION
68
  INTERRUPTING = False
@@ -92,6 +125,7 @@ def predict(model_path, text, melody, duration, topk, topp, temperature, target_
92
  progress((min(max_generated, to_generate), to_generate))
93
  if INTERRUPTING:
94
  raise gr.Error("Interrupted.")
 
95
  model.set_custom_progress_callback(_progress)
96
 
97
  wavs = _do_predictions(
@@ -105,7 +139,8 @@ def predict(model_path, text, melody, duration, topk, topp, temperature, target_
105
  top_k=topk,
106
  top_p=topp,
107
  temperature=temperature,
108
- gradio_progress=progress)
 
109
  return wavs[0]
110
 
111
 
@@ -114,7 +149,7 @@ def transcribe(audio_path):
114
  model_output, midi_data, note_events = basic_pitch.inference.predict(
115
  audio_path=audio_path,
116
  model_or_model_path=ICASSP_2022_MODEL_PATH,
117
- )
118
 
119
  with NamedTemporaryFile("wb", suffix=".mid", delete=False) as file:
120
  try:
@@ -125,6 +160,5 @@ def transcribe(audio_path):
125
  raise e
126
 
127
  return gr.DownloadButton(
128
- value=file.name,
129
- label=f"Download MIDI file {file.name}",
130
- visible=True)
 
1
  import time
 
 
 
 
 
 
 
 
2
  from pathlib import Path
3
+ from tempfile import NamedTemporaryFile
4
+
5
  import basic_pitch
6
  import basic_pitch.inference
7
+ import gradio as gr
8
+ import torch
9
+ from audiocraft.data.audio import audio_write
10
+ from audiocraft.data.audio_utils import convert_audio
11
+ from audiocraft.models import MusicGen
12
  from basic_pitch import ICASSP_2022_MODEL_PATH
13
+ from transformers import AutoModelForSeq2SeqLM
14
 
15
 
16
def load_model(version="facebook/musicgen-melody"):
    """Return the pretrained MusicGen model identified by *version*."""
    model = MusicGen.get_pretrained(version)
    return model
18
 
19
 
20
+ def _do_predictions(
21
+ model,
22
+ texts,
23
+ melodies,
24
+ duration,
25
+ progress=False,
26
+ gradio_progress=None,
27
+ target_sr=32000,
28
+ target_ac=1,
29
+ **gen_kwargs,
30
+ ):
31
+ print(
32
+ "new batch",
33
+ len(texts),
34
+ texts,
35
+ [None if m is None else (m[0], m[1].shape) for m in melodies],
36
+ )
37
  be = time.time()
38
  processed_melodies = []
39
  for melody in melodies:
40
  if melody is None:
41
  processed_melodies.append(None)
42
  else:
43
+ sr, melody = (
44
+ melody[0],
45
+ torch.from_numpy(melody[1]).to(model.device).float().t(),
46
+ )
47
  print(f"Input audio sample rate is {sr}")
48
  if melody.dim() == 1:
49
  melody = melody[None]
50
+ melody = melody[..., : int(sr * duration)]
51
  melody = convert_audio(melody, sr, target_sr, target_ac)
52
  processed_melodies.append(melody)
53
 
 
59
  melody_wavs=processed_melodies,
60
  melody_sample_rate=target_sr,
61
  progress=progress,
62
+ return_tokens=False,
63
  )
64
  else:
65
  # text only
 
72
  for output in outputs:
73
  with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
74
  audio_write(
75
+ file.name,
76
+ output,
77
+ model.sample_rate,
78
+ strategy="loudness",
79
+ loudness_headroom_db=16,
80
+ loudness_compressor=True,
81
+ add_suffix=False,
82
+ )
83
  out_wavs.append(file.name)
84
  print("generation finished", len(texts), time.time() - be)
85
  return out_wavs
86
 
87
 
88
+ def predict(
89
+ model_path,
90
+ text,
91
+ melody,
92
+ duration,
93
+ topk,
94
+ topp,
95
+ temperature,
96
+ target_sr,
97
+ progress=gr.Progress(),
98
+ ):
99
  global INTERRUPTING
100
  global USE_DIFFUSION
101
  INTERRUPTING = False
 
125
  progress((min(max_generated, to_generate), to_generate))
126
  if INTERRUPTING:
127
  raise gr.Error("Interrupted.")
128
+
129
  model.set_custom_progress_callback(_progress)
130
 
131
  wavs = _do_predictions(
 
139
  top_k=topk,
140
  top_p=topp,
141
  temperature=temperature,
142
+ gradio_progress=progress,
143
+ )
144
  return wavs[0]
145
 
146
 
 
149
  model_output, midi_data, note_events = basic_pitch.inference.predict(
150
  audio_path=audio_path,
151
  model_or_model_path=ICASSP_2022_MODEL_PATH,
152
+ )
153
 
154
  with NamedTemporaryFile("wb", suffix=".mid", delete=False) as file:
155
  try:
 
160
  raise e
161
 
162
  return gr.DownloadButton(
163
+ value=file.name, label=f"Download MIDI file {file.name}", visible=True
164
+ )
 
requirements.txt CHANGED
@@ -2,4 +2,5 @@ torch==2.1.0
2
  audiocraft
3
  basic-pitch
4
  gradio
5
- tensorflow==2.15.0
 
 
2
  audiocraft
3
  basic-pitch
4
  gradio
5
+ tensorflow==2.15.0
6
+ anthropic