Spaces:
Sleeping
Sleeping
update image-to-music tab
Browse files- .gitattributes +1 -0
- app.py +443 -202
- data/audio/the_nutcracker_dance_of_the_reed_flutes.mp3 +3 -0
- data/image/.DS_Store +0 -0
- data/image/beach.jpeg +3 -0
- data/image/cat.jpeg +3 -0
- data/image/kids_drawing.jpeg +3 -0
- gradio_components/image.py +59 -0
- gradio_components/prediction.py +57 -23
- requirements.txt +2 -1
.gitattributes
CHANGED
@@ -38,3 +38,4 @@ data/audio/turkish_march_mozart.mp3 filter=lfs diff=lfs merge=lfs -text
|
|
38 |
data/audio/twinkle_twinkle_little_stars_mozart.mp3 filter=lfs diff=lfs merge=lfs -text
|
39 |
*.mp3 filter=lfs diff=lfs merge=lfs -text
|
40 |
*.mp3 !text !filter !merge !diff
|
|
|
|
38 |
data/audio/twinkle_twinkle_little_stars_mozart.mp3 filter=lfs diff=lfs merge=lfs -text
|
39 |
*.mp3 filter=lfs diff=lfs merge=lfs -text
|
40 |
*.mp3 !text !filter !merge !diff
|
41 |
+
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
app.py
CHANGED
@@ -2,217 +2,458 @@ import os
|
|
2 |
|
3 |
import gradio as gr
|
4 |
|
|
|
5 |
from gradio_components.prediction import predict, transcribe
|
6 |
|
7 |
theme = gr.themes.Glass(
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
|
70 |
|
71 |
def generate_prompt(difficulty, style):
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
|
82 |
|
83 |
def UI():
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
215 |
|
216 |
|
217 |
if __name__ == "__main__":
|
218 |
-
|
|
|
2 |
|
3 |
import gradio as gr
|
4 |
|
5 |
+
from gradio_components.image import generate_caption
|
6 |
from gradio_components.prediction import predict, transcribe
|
7 |
|
8 |
# Font stack for the app theme: a Google web font first, then generic fallbacks.
_THEME_FONTS = [
    gr.themes.GoogleFont("Source Sans Pro"),
    "ui-sans-serif",
    "system-ui",
    "sans-serif",
]

# Variable overrides applied on top of the Glass base theme.  Values that start
# with "*" refer to other theme variables.
_THEME_OVERRIDES = {
    # page / generic backgrounds and borders
    "body_background_fill_dark": "*background_fill_primary",
    "embed_radius": "*table_radius",
    "background_fill_primary": "*neutral_50",
    "background_fill_primary_dark": "*neutral_950",
    "background_fill_secondary_dark": "*neutral_900",
    "border_color_accent": "*neutral_600",
    "border_color_accent_subdued": "*color_accent",
    "border_color_primary_dark": "*neutral_700",
    # blocks
    "block_background_fill": "*background_fill_primary",
    "block_background_fill_dark": "*neutral_800",
    "block_border_width": "1px",
    "block_label_background_fill": "*background_fill_primary",
    "block_label_background_fill_dark": "*background_fill_secondary",
    "block_label_text_color": "*neutral_500",
    "block_label_text_size": "*text_sm",
    "block_label_text_weight": "400",
    "block_shadow": "none",
    "block_shadow_dark": "none",
    "block_title_text_color": "*neutral_500",
    "block_title_text_weight": "400",
    # panels
    "panel_border_width": "0",
    "panel_border_width_dark": "0",
    # checkboxes
    "checkbox_background_color_dark": "*neutral_800",
    "checkbox_border_width": "*input_border_width",
    "checkbox_label_border_width": "*input_border_width",
    # inputs
    "input_background_fill": "*neutral_100",
    "input_background_fill_dark": "*neutral_700",
    "input_border_color_focus_dark": "*neutral_700",
    "input_border_width": "0px",
    "input_border_width_dark": "0px",
    # sliders
    "slider_color": "#2563eb",
    "slider_color_dark": "#2563eb",
    # tables
    "table_even_background_fill_dark": "*neutral_950",
    "table_odd_background_fill_dark": "*neutral_900",
    # buttons
    "button_border_width": "*input_border_width",
    "button_shadow_active": "none",
    "button_primary_background_fill": "*primary_200",
    "button_primary_background_fill_dark": "*primary_700",
    "button_primary_background_fill_hover": "*button_primary_background_fill",
    "button_primary_background_fill_hover_dark": "*button_primary_background_fill",
    "button_secondary_background_fill": "*neutral_200",
    "button_secondary_background_fill_dark": "*neutral_600",
    "button_secondary_background_fill_hover": "*button_secondary_background_fill",
    "button_secondary_background_fill_hover_dark": "*button_secondary_background_fill",
    "button_cancel_background_fill": "*button_secondary_background_fill",
    "button_cancel_background_fill_dark": "*button_secondary_background_fill",
    "button_cancel_background_fill_hover": "*button_cancel_background_fill",
    "button_cancel_background_fill_hover_dark": "*button_cancel_background_fill",
}

# Glass base theme with the project's hues and fonts, then the overrides above.
theme = gr.themes.Glass(
    primary_hue="fuchsia",
    secondary_hue="indigo",
    neutral_hue="slate",
    font=_THEME_FONTS,
).set(**_THEME_OVERRIDES)


# Checkpoint names offered in the "Select the model" dropdowns.
_AUDIOCRAFT_MODELS = [
    "facebook/musicgen-melody",
    "facebook/musicgen-medium",
    "facebook/musicgen-small",
    "facebook/musicgen-large",
    "facebook/musicgen-melody-large",
    "facebook/audiogen-medium",
]
|
78 |
|
79 |
|
80 |
# Maps the difficulty choices shown in the UI to the player description that is
# inserted into the text prompt.
_DIFFICULTY_TO_PLAYER = {
    "Easy": "beginner player",
    "Medium": "player who has 2-3 years experience",
    # Legacy misspelling that earlier revisions used as the key; kept so any
    # caller that still passes it keeps working.
    "Medum": "player who has 2-3 years experience",
    "Hard": "player who has more than 4 years experiences",
}


def generate_prompt(difficulty, style):
    """Build the text prompt for a piano practice piece.

    Args:
        difficulty: One of "Easy", "Medium" or "Hard".
        style: Free-form music genre text inserted into the prompt verbatim.

    Returns:
        The prompt string describing the target player and genre.

    Raises:
        KeyError: If ``difficulty`` is not a known difficulty level.
    """
    player = _DIFFICULTY_TO_PLAYER[difficulty]
    return f"piano only music for a {player} to pratice with the touch of {style}"
|
90 |
+
|
91 |
+
|
92 |
+
def toggle_melody_condition(melody_condition):
    """Return the melody audio input, shown only when melody conditioning is on.

    Args:
        melody_condition: Truthy when the "Generate music by melody" box is
            ticked.

    Returns:
        A ``gr.Audio`` update whose ``visible`` flag mirrors the checkbox.
    """
    is_shown = bool(melody_condition)
    return gr.Audio(
        sources=["microphone", "upload"],
        label="Record or upload your audio",
        show_label=True,
        visible=is_shown,
    )
|
107 |
+
|
108 |
+
|
109 |
+
def show_caption(show_caption_condition, description, prompt):
    """Show or hide the caption and prompt text boxes.

    Args:
        show_caption_condition: Truthy when the "Show the prompt" box is ticked.
        description: Current image caption text; passed through unchanged.
        prompt: Current generated prompt text; passed through unchanged.

    Returns:
        A 3-tuple of component updates: the read-only caption box, the editable
        prompt box, and the "Generate Music" button.  Only the two text boxes
        change visibility; the button is always returned visible.
    """
    boxes_visible = bool(show_caption_condition)
    caption_box = gr.Textbox(
        label="Image Caption",
        value=description,
        interactive=False,
        show_label=True,
        visible=boxes_visible,
    )
    prompt_box = gr.Textbox(
        label="Generated Prompt",
        value=prompt,
        interactive=True,
        show_label=True,
        visible=boxes_visible,
    )
    # The button caption is passed as the positional `value`.  The hidden
    # branch previously used `label=`, which is not how gr.Button takes its
    # text and was inconsistent with the other branch.
    generate_button = gr.Button("Generate Music", interactive=True, visible=True)
    return caption_box, prompt_box, generate_button
|
146 |
+
|
147 |
+
|
148 |
+
def post_submit(show_caption, image_input):
    """Caption the submitted image and fill in the caption and prompt boxes.

    Args:
        show_caption: Value of the "Show the prompt" checkbox; used as the
            visibility of both text boxes.
        image_input: The image handed to ``generate_caption``.

    Returns:
        A 3-tuple of component updates: the read-only caption box, the editable
        prompt box, and a visible "Generate Music" button.
    """
    _raw, caption_text, music_prompt = generate_caption(image_input)

    caption_box = gr.Textbox(
        label="Image Caption",
        value=caption_text,
        interactive=False,
        show_label=True,
        visible=show_caption,
    )
    prompt_box = gr.Textbox(
        label="Generated Prompt",
        value=music_prompt,
        interactive=True,
        show_label=True,
        visible=show_caption,
    )
    generate_button = gr.Button("Generate Music", interactive=True, visible=True)
    return (caption_box, prompt_box, generate_button)
|
167 |
|
168 |
|
169 |
def UI():
    """Build the two-tab Gradio demo and launch it.

    Tab "Generate Music by melody": model and sampling controls, difficulty and
    genre radios, a prompt box, a melody audio input, a button that calls
    ``predict`` and a button that calls ``transcribe`` on the generated audio.

    Tab "Generate Music by image": an image input whose "submit" button calls
    ``post_submit`` to fill the caption and prompt boxes, an optional melody
    input, the same model and sampling controls, and buttons that call
    ``predict`` and ``transcribe``.

    The function returns nothing; it ends by queueing and launching the app.
    """
    here = os.path.dirname(__file__)

    def _asset(relative_path):
        # Example assets live next to this file, so anchor every path here
        # instead of relying on the current working directory.
        return os.path.join(here, relative_path)

    golden_hour = _asset("./data/audio/golden_hour_20sec.mp3")

    with gr.Blocks() as demo:
        with gr.Tab("Generate Music by melody"):
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        model_path = gr.Dropdown(
                            choices=_AUDIOCRAFT_MODELS,
                            label="Select the model",
                            value="facebook/musicgen-melody-large",
                        )
                    with gr.Row():
                        duration = gr.Slider(
                            minimum=10,
                            maximum=60,
                            value=10,
                            label="Duration",
                            interactive=True,
                        )
                    with gr.Row():
                        topk = gr.Number(label="Top-k", value=250, interactive=True)
                        topp = gr.Number(label="Top-p", value=0, interactive=True)
                        temperature = gr.Number(
                            label="Temperature", value=1.0, interactive=True
                        )
                        sample_rate = gr.Number(
                            label="output music sample rate",
                            value=32000,
                            interactive=True,
                        )
                    difficulty = gr.Radio(
                        ["Easy", "Medium", "Hard"],
                        label="Difficulty",
                        value="Easy",
                        interactive=True,
                    )
                    style = gr.Radio(
                        ["Jazz", "Classical Music", "Hip Hop", "Others"],
                        value="Classical Music",
                        label="music genre",
                        interactive=True,
                    )
                    customize = gr.Checkbox(
                        label="Customize the prompt", interactive=True
                    )
                    # The earlier code tested the Radio and Checkbox component
                    # objects themselves (`style == "Others"`, `if customize:`)
                    # while building the layout.  Those tests do not depend on
                    # what the user selects, so the prompt box was always
                    # created and the generated default prompt was always
                    # discarded.  The box is now created unconditionally and
                    # pre-filled with the default prompt.
                    # NOTE(review): no event listener is attached to
                    # `difficulty`, `style` or `customize`, so changing them
                    # does not update the prompt box.
                    prompt = gr.Textbox(
                        label="Type your prompt",
                        value=generate_prompt(difficulty.value, style.value),
                    )
                with gr.Column():
                    with gr.Row():
                        melody = gr.Audio(
                            sources=["microphone", "upload"],
                            label="Record or upload your audio",
                            # interactive=True,
                            show_label=True,
                        )
                    with gr.Row():
                        submit = gr.Button("Generate Music")
                        # NOTE(review): the first positional argument of
                        # gr.Audio is presumably `value`, not `label` --
                        # confirm this text is meant as the initial value.
                        output_audio = gr.Audio(
                            "listen to the generated music", type="filepath"
                        )
                    with gr.Row():
                        transcribe_button = gr.Button("Transcribe")
                        d = gr.DownloadButton("Download the file", visible=False)
                        transcribe_button.click(
                            transcribe, inputs=[output_audio], outputs=d
                        )

            submit.click(
                fn=predict,
                inputs=[
                    model_path,
                    prompt,
                    melody,
                    duration,
                    topk,
                    topp,
                    temperature,
                    sample_rate,
                ],
                outputs=output_audio,
            )
            # Each example row is: melody file, difficulty, sample rate, duration.
            gr.Examples(
                examples=[
                    [
                        _asset(
                            "./data/audio/twinkle_twinkle_little_stars_mozart_20sec"
                            ".mp3"
                        ),
                        "Easy",
                        32000,
                        20,
                    ],
                    [golden_hour, "Easy", 32000, 20],
                    [
                        _asset("./data/audio/turkish_march_mozart_20sec.mp3"),
                        "Easy",
                        32000,
                        20,
                    ],
                    [golden_hour, "Hard", 32000, 20],
                    [golden_hour, "Hard", 32000, 40],
                    [golden_hour, "Hard", 16000, 20],
                ],
                inputs=[melody, difficulty, sample_rate, duration],
                label="Audio Examples",
                outputs=[output_audio],
                # cache_examples=True,
            )

        with gr.Tab("Generate Music by image"):
            with gr.Row():
                with gr.Column():
                    # NOTE(review): as with gr.Audio above, the positional
                    # string is presumably taken as `value` -- confirm.
                    image_input = gr.Image("Upload an image", type="filepath")
                    melody_condition = gr.Checkbox(
                        label="Generate music by melody", interactive=True, value=False
                    )
                    melody = gr.Audio(
                        sources=["microphone", "upload"],
                        label="Record or upload your audio",
                        show_label=True,
                        visible=False,
                    )
                    melody_condition.change(
                        fn=toggle_melody_condition,
                        inputs=[melody_condition],
                        outputs=melody,
                    )
                    description = gr.Textbox(
                        label="Image Captioning",
                        show_label=True,
                        interactive=False,
                        visible=False,
                    )
                    prompt = gr.Textbox(
                        label="Generated Prompt",
                        show_label=True,
                        interactive=True,
                        visible=False,
                    )
                    show_prompt = gr.Checkbox(label="Show the prompt", interactive=True)
                    submit = gr.Button("submit", interactive=True, visible=True)
                    # Hidden until `post_submit` or `show_caption` returns it
                    # as visible.
                    generate = gr.Button(
                        "Generate Music", interactive=True, visible=False
                    )
                    submit.click(
                        fn=post_submit,
                        inputs=[show_prompt, image_input],
                        outputs=[description, prompt, generate],
                    )
                    show_prompt.change(
                        fn=show_caption,
                        inputs=[show_prompt, description, prompt],
                        outputs=[description, prompt, generate],
                    )

                with gr.Column():
                    with gr.Row():
                        model_path = gr.Dropdown(
                            choices=_AUDIOCRAFT_MODELS,
                            label="Select the model",
                            value="facebook/musicgen-large",
                        )
                    with gr.Row():
                        duration = gr.Slider(
                            minimum=10,
                            maximum=60,
                            value=10,
                            label="Duration",
                            interactive=True,
                        )
                    topk = gr.Number(label="Top-k", value=250, interactive=True)
                    topp = gr.Number(label="Top-p", value=0, interactive=True)
                    temperature = gr.Number(
                        label="Temperature", value=1.0, interactive=True
                    )
                    sample_rate = gr.Number(
                        label="output music sample rate", value=32000, interactive=True
                    )
                with gr.Column():
                    output_audio = gr.Audio(
                        "listen to the generated music",
                        type="filepath",
                        show_label=True,
                    )
                    transcribe_button = gr.Button("Transcribe")
                    d = gr.DownloadButton("Download the file", visible=False)
                transcribe_button.click(transcribe, inputs=[output_audio], outputs=d)
                generate.click(
                    fn=predict,
                    inputs=[
                        model_path,
                        prompt,
                        melody,
                        duration,
                        topk,
                        topp,
                        temperature,
                        sample_rate,
                    ],
                    outputs=output_audio,
                )

            # Each example row is: image file, melody checkbox, melody file, model.
            gr.Examples(
                examples=[
                    [
                        _asset("./data/image/kids_drawing.jpeg"),
                        False,
                        None,
                        "facebook/musicgen-large",
                    ],
                    [
                        _asset("./data/image/cat.jpeg"),
                        False,
                        None,
                        "facebook/musicgen-large",
                    ],
                    [
                        _asset("./data/image/cat.jpeg"),
                        True,
                        # Anchored like every other asset; this path used to be
                        # relative to the current working directory.
                        _asset(
                            "./data/audio/the_nutcracker_dance_of_the_reed_flutes.mp3"
                        ),
                        "facebook/musicgen-melody-large",
                    ],
                    [
                        _asset("./data/image/beach.jpeg"),
                        False,
                        None,
                        "facebook/audiogen-medium",
                    ],
                ],
                inputs=[image_input, melody_condition, melody, model_path],
                label="Audio Examples",
                outputs=[output_audio],
                # cache_examples=True,
            )

    demo.queue().launch()
|
456 |
|
457 |
|
458 |
if __name__ == "__main__":
|
459 |
+
UI()
|
data/audio/the_nutcracker_dance_of_the_reed_flutes.mp3
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:aa933b78b2380d325e6436d91de191c7abcb82e4c62ef2ed52a04868233a5012
|
3 |
+
size 3577581
|
data/image/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
data/image/beach.jpeg
ADDED
Git LFS Details
|
data/image/cat.jpeg
ADDED
Git LFS Details
|
data/image/kids_drawing.jpeg
ADDED
Git LFS Details
|
gradio_components/image.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
|
5 |
+
import anthropic
|
6 |
+
import gradio as gr
|
7 |
+
|
8 |
+
# The API key is read from the ANTHROPIC_API_KEY environment variable; set it before starting the app.
|
9 |
+
client = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])
|
10 |
+
|
11 |
+
# image1_url = "https://i.abcnewsfe.com/a/7d849ccc-e0fe-4416-959d-85889e338add/dune-1-ht-bb-231212_1702405287482_hpMain_16x9.jpeg"
|
12 |
+
image1_media_type = "image/jpeg"
|
13 |
+
# image1_data = base64.b64encode(httpx.get(image1_url).content).decode("utf-8")
|
14 |
+
#
|
15 |
+
|
16 |
+
SYSTEM_PROMPT = """You are an expert llm prompt engineer, you understand the structure of llms and facebook musicgen text to audio model. You will be provided with an image, and require to output a prompt for the musicgen model to capture the essense of the image. Try to do it step by step, evaluate and analyze the image thoroughly. After that, develop a prompt that contains music genera, style, instrument, and all the other details needed. This prompt will be provided to musicgen model to generate a 15s audio clip.
|
17 |
+
|
18 |
+
Here are some descriptions from musicgen model:
|
19 |
+
The model was trained with descriptions from a stock music catalog, descriptions that will work best should include some level of detail on the instruments present, along with some intended use case (e.g. adding “perfect for a commercial” can somehow help).
|
20 |
+
|
21 |
+
Try to make the prompt simple and concise with only 1-2 sentences
|
22 |
+
|
23 |
+
Make sure the ouput is in JSON fomat, with two items `description` and `prompt`"""
|
24 |
+
|
25 |
+
|
26 |
+
def generate_caption(image_file, progress=gr.Progress()):
    """Ask the Anthropic model for a caption and a music prompt for an image.

    Args:
        image_file: Path of the image file to describe.
        progress: Gradio progress tracker used to report start and completion.

    Returns:
        A 3-tuple ``(message_object, description, prompt)``: the parsed JSON
        reply as a dict, followed by its ``"description"`` and ``"prompt"``
        entries.

    Raises:
        gr.Error: If the model reply does not contain parseable JSON.
        KeyError: If the parsed reply lacks ``"description"`` or ``"prompt"``.
    """
    import mimetypes  # local import: only this function needs it

    with open(image_file, "rb") as f:
        image_encoded = base64.b64encode(f.read()).decode("utf-8")

    # Declare the media type that matches the file instead of always claiming
    # JPEG; fall back to the module default when the extension is unknown or
    # outside this image-type list.
    # NOTE(review): list taken from the Anthropic vision docs -- confirm.
    media_type = mimetypes.guess_type(image_file)[0]
    if media_type not in ("image/jpeg", "image/png", "image/gif", "image/webp"):
        media_type = image1_media_type

    progress(0, desc="Starting image captioning...")
    message = client.messages.create(
        model="claude-3-opus-20240229",
        max_tokens=1024,
        system=SYSTEM_PROMPT,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": media_type,
                            "data": image_encoded,
                        },
                    },
                    {"type": "text", "text": "develop the prompt based on this image"},
                ],
            }
        ],
    )
    # A float progress value is a completion fraction, so "done" is 1.0.
    progress(1.0, desc="image captioning...Done!")

    # Parse the reply into a Python object.  The model is asked for JSON but may
    # wrap it in extra text, so retry on the outermost {...} span before failing.
    reply_text = message.content[0].text
    try:
        message_object = json.loads(reply_text)
    except json.JSONDecodeError:
        start = reply_text.find("{")
        end = reply_text.rfind("}")
        if start == -1 or end <= start:
            raise gr.Error("Image captioning did not return JSON.")
        try:
            message_object = json.loads(reply_text[start : end + 1])
        except json.JSONDecodeError as err:
            raise gr.Error("Image captioning returned malformed JSON.") from err

    # Access the description and prompt from the message object
    description = message_object["description"]
    prompt = message_object["prompt"]
    print(description)
    print(prompt)
    return message_object, description, prompt
|
gradio_components/prediction.py
CHANGED
@@ -1,36 +1,53 @@
|
|
1 |
import time
|
2 |
-
import torch
|
3 |
-
|
4 |
-
from audiocraft.data.audio_utils import convert_audio
|
5 |
-
from audiocraft.data.audio import audio_write
|
6 |
-
import gradio as gr
|
7 |
-
from audiocraft.models import MusicGen
|
8 |
-
|
9 |
-
from tempfile import NamedTemporaryFile
|
10 |
from pathlib import Path
|
11 |
-
from
|
|
|
12 |
import basic_pitch
|
13 |
import basic_pitch.inference
|
|
|
|
|
|
|
|
|
|
|
14 |
from basic_pitch import ICASSP_2022_MODEL_PATH
|
|
|
15 |
|
16 |
|
17 |
-
def load_model(version=
|
18 |
return MusicGen.get_pretrained(version)
|
19 |
|
20 |
|
21 |
-
def _do_predictions(
|
22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
be = time.time()
|
24 |
processed_melodies = []
|
25 |
for melody in melodies:
|
26 |
if melody is None:
|
27 |
processed_melodies.append(None)
|
28 |
else:
|
29 |
-
sr, melody =
|
|
|
|
|
|
|
30 |
print(f"Input audio sample rate is {sr}")
|
31 |
if melody.dim() == 1:
|
32 |
melody = melody[None]
|
33 |
-
melody = melody[..., :int(sr * duration)]
|
34 |
melody = convert_audio(melody, sr, target_sr, target_ac)
|
35 |
processed_melodies.append(melody)
|
36 |
|
@@ -42,7 +59,7 @@ def _do_predictions(model, texts, melodies, duration, progress=False, gradio_pro
|
|
42 |
melody_wavs=processed_melodies,
|
43 |
melody_sample_rate=target_sr,
|
44 |
progress=progress,
|
45 |
-
return_tokens=False
|
46 |
)
|
47 |
else:
|
48 |
# text only
|
@@ -55,14 +72,30 @@ def _do_predictions(model, texts, melodies, duration, progress=False, gradio_pro
|
|
55 |
for output in outputs:
|
56 |
with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
|
57 |
audio_write(
|
58 |
-
file.name,
|
59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
out_wavs.append(file.name)
|
61 |
print("generation finished", len(texts), time.time() - be)
|
62 |
return out_wavs
|
63 |
|
64 |
|
65 |
-
def predict(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
global INTERRUPTING
|
67 |
global USE_DIFFUSION
|
68 |
INTERRUPTING = False
|
@@ -92,6 +125,7 @@ def predict(model_path, text, melody, duration, topk, topp, temperature, target_
|
|
92 |
progress((min(max_generated, to_generate), to_generate))
|
93 |
if INTERRUPTING:
|
94 |
raise gr.Error("Interrupted.")
|
|
|
95 |
model.set_custom_progress_callback(_progress)
|
96 |
|
97 |
wavs = _do_predictions(
|
@@ -105,7 +139,8 @@ def predict(model_path, text, melody, duration, topk, topp, temperature, target_
|
|
105 |
top_k=topk,
|
106 |
top_p=topp,
|
107 |
temperature=temperature,
|
108 |
-
gradio_progress=progress
|
|
|
109 |
return wavs[0]
|
110 |
|
111 |
|
@@ -114,7 +149,7 @@ def transcribe(audio_path):
|
|
114 |
model_output, midi_data, note_events = basic_pitch.inference.predict(
|
115 |
audio_path=audio_path,
|
116 |
model_or_model_path=ICASSP_2022_MODEL_PATH,
|
117 |
-
|
118 |
|
119 |
with NamedTemporaryFile("wb", suffix=".mid", delete=False) as file:
|
120 |
try:
|
@@ -125,6 +160,5 @@ def transcribe(audio_path):
|
|
125 |
raise e
|
126 |
|
127 |
return gr.DownloadButton(
|
128 |
-
value=file.name,
|
129 |
-
|
130 |
-
visible=True)
|
|
|
1 |
import time
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
from pathlib import Path
|
3 |
+
from tempfile import NamedTemporaryFile
|
4 |
+
|
5 |
import basic_pitch
|
6 |
import basic_pitch.inference
|
7 |
+
import gradio as gr
|
8 |
+
import torch
|
9 |
+
from audiocraft.data.audio import audio_write
|
10 |
+
from audiocraft.data.audio_utils import convert_audio
|
11 |
+
from audiocraft.models import MusicGen
|
12 |
from basic_pitch import ICASSP_2022_MODEL_PATH
|
13 |
+
from transformers import AutoModelForSeq2SeqLM
|
14 |
|
15 |
|
16 |
+
def load_model(version="facebook/musicgen-melody"):
    """Return the pretrained MusicGen model for the checkpoint name ``version``."""
    pretrained = MusicGen.get_pretrained(version)
    return pretrained
|
18 |
|
19 |
|
20 |
+
def _do_predictions(
|
21 |
+
model,
|
22 |
+
texts,
|
23 |
+
melodies,
|
24 |
+
duration,
|
25 |
+
progress=False,
|
26 |
+
gradio_progress=None,
|
27 |
+
target_sr=32000,
|
28 |
+
target_ac=1,
|
29 |
+
**gen_kwargs,
|
30 |
+
):
|
31 |
+
print(
|
32 |
+
"new batch",
|
33 |
+
len(texts),
|
34 |
+
texts,
|
35 |
+
[None if m is None else (m[0], m[1].shape) for m in melodies],
|
36 |
+
)
|
37 |
be = time.time()
|
38 |
processed_melodies = []
|
39 |
for melody in melodies:
|
40 |
if melody is None:
|
41 |
processed_melodies.append(None)
|
42 |
else:
|
43 |
+
sr, melody = (
|
44 |
+
melody[0],
|
45 |
+
torch.from_numpy(melody[1]).to(model.device).float().t(),
|
46 |
+
)
|
47 |
print(f"Input audio sample rate is {sr}")
|
48 |
if melody.dim() == 1:
|
49 |
melody = melody[None]
|
50 |
+
melody = melody[..., : int(sr * duration)]
|
51 |
melody = convert_audio(melody, sr, target_sr, target_ac)
|
52 |
processed_melodies.append(melody)
|
53 |
|
|
|
59 |
melody_wavs=processed_melodies,
|
60 |
melody_sample_rate=target_sr,
|
61 |
progress=progress,
|
62 |
+
return_tokens=False,
|
63 |
)
|
64 |
else:
|
65 |
# text only
|
|
|
72 |
for output in outputs:
|
73 |
with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
|
74 |
audio_write(
|
75 |
+
file.name,
|
76 |
+
output,
|
77 |
+
model.sample_rate,
|
78 |
+
strategy="loudness",
|
79 |
+
loudness_headroom_db=16,
|
80 |
+
loudness_compressor=True,
|
81 |
+
add_suffix=False,
|
82 |
+
)
|
83 |
out_wavs.append(file.name)
|
84 |
print("generation finished", len(texts), time.time() - be)
|
85 |
return out_wavs
|
86 |
|
87 |
|
88 |
+
def predict(
|
89 |
+
model_path,
|
90 |
+
text,
|
91 |
+
melody,
|
92 |
+
duration,
|
93 |
+
topk,
|
94 |
+
topp,
|
95 |
+
temperature,
|
96 |
+
target_sr,
|
97 |
+
progress=gr.Progress(),
|
98 |
+
):
|
99 |
global INTERRUPTING
|
100 |
global USE_DIFFUSION
|
101 |
INTERRUPTING = False
|
|
|
125 |
progress((min(max_generated, to_generate), to_generate))
|
126 |
if INTERRUPTING:
|
127 |
raise gr.Error("Interrupted.")
|
128 |
+
|
129 |
model.set_custom_progress_callback(_progress)
|
130 |
|
131 |
wavs = _do_predictions(
|
|
|
139 |
top_k=topk,
|
140 |
top_p=topp,
|
141 |
temperature=temperature,
|
142 |
+
gradio_progress=progress,
|
143 |
+
)
|
144 |
return wavs[0]
|
145 |
|
146 |
|
|
|
149 |
model_output, midi_data, note_events = basic_pitch.inference.predict(
|
150 |
audio_path=audio_path,
|
151 |
model_or_model_path=ICASSP_2022_MODEL_PATH,
|
152 |
+
)
|
153 |
|
154 |
with NamedTemporaryFile("wb", suffix=".mid", delete=False) as file:
|
155 |
try:
|
|
|
160 |
raise e
|
161 |
|
162 |
return gr.DownloadButton(
|
163 |
+
value=file.name, label=f"Download MIDI file {file.name}", visible=True
|
164 |
+
)
|
|
requirements.txt
CHANGED
@@ -2,4 +2,5 @@ torch==2.1.0
|
|
2 |
audiocraft
|
3 |
basic-pitch
|
4 |
gradio
|
5 |
-
tensorflow==2.15.0
|
|
|
|
2 |
audiocraft
|
3 |
basic-pitch
|
4 |
gradio
|
5 |
+
tensorflow==2.15.0
|
6 |
+
anthropic
|