Spaces:
Sleeping
Sleeping
init app
Browse files- README.md +4 -1
- app.py +205 -0
- data/audio/golden_hour.mp3 +3 -0
- data/audio/turkish_march_mozart.mp3 +3 -0
- data/audio/twinkle_twinkle_little_stars_mozart.mp3 +3 -0
- gradio_components/prediction.py +103 -0
- requirements.txt +4 -0
README.md
CHANGED
@@ -2,11 +2,14 @@
|
|
2 |
title: MMM MagicMusicMachine
|
3 |
emoji: 🐨
|
4 |
colorFrom: purple
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
sdk_version: 4.21.0
|
|
|
|
|
8 |
app_file: app.py
|
9 |
pinned: false
|
|
|
10 |
license: mit
|
11 |
---
|
12 |
|
|
|
2 |
title: MMM MagicMusicMachine
|
3 |
emoji: 🐨
|
4 |
colorFrom: purple
|
5 |
+
colorTo: magenta
|
6 |
sdk: gradio
|
7 |
sdk_version: 4.21.0
|
8 |
+
python_version: python 3.9
|
9 |
+
suggested_hardware: "a10g-large"
|
10 |
app_file: app.py
|
11 |
pinned: false
|
12 |
+
tags: MusicAI, MultiModal, Audio, Text, Image
|
13 |
license: mit
|
14 |
---
|
15 |
|
app.py
ADDED
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import gradio as gr
|
4 |
+
from gradio_components.prediction import load_model, predict
|
5 |
+
|
6 |
+
theme = gr.themes.Glass(
|
7 |
+
primary_hue="fuchsia",
|
8 |
+
secondary_hue="indigo",
|
9 |
+
neutral_hue="slate",
|
10 |
+
font=[gr.themes.GoogleFont('Source Sans Pro'), 'ui-sans-serif', 'system-ui',
|
11 |
+
'sans-serif'],
|
12 |
+
).set(
|
13 |
+
body_background_fill_dark='*background_fill_primary',
|
14 |
+
embed_radius='*table_radius',
|
15 |
+
background_fill_primary='*neutral_50',
|
16 |
+
background_fill_primary_dark='*neutral_950',
|
17 |
+
background_fill_secondary_dark='*neutral_900',
|
18 |
+
border_color_accent='*neutral_600',
|
19 |
+
border_color_accent_subdued='*color_accent',
|
20 |
+
border_color_primary_dark='*neutral_700',
|
21 |
+
block_background_fill='*background_fill_primary',
|
22 |
+
block_background_fill_dark='*neutral_800',
|
23 |
+
block_border_width='1px',
|
24 |
+
block_label_background_fill='*background_fill_primary',
|
25 |
+
block_label_background_fill_dark='*background_fill_secondary',
|
26 |
+
block_label_text_color='*neutral_500',
|
27 |
+
block_label_text_size='*text_sm',
|
28 |
+
block_label_text_weight='400',
|
29 |
+
block_shadow='none',
|
30 |
+
block_shadow_dark='none',
|
31 |
+
block_title_text_color='*neutral_500',
|
32 |
+
block_title_text_weight='400',
|
33 |
+
panel_border_width='0',
|
34 |
+
panel_border_width_dark='0',
|
35 |
+
checkbox_background_color_dark='*neutral_800',
|
36 |
+
checkbox_border_width='*input_border_width',
|
37 |
+
checkbox_label_border_width='*input_border_width',
|
38 |
+
input_background_fill='*neutral_100',
|
39 |
+
input_background_fill_dark='*neutral_700',
|
40 |
+
input_border_color_focus_dark='*neutral_700',
|
41 |
+
input_border_width='0px',
|
42 |
+
input_border_width_dark='0px',
|
43 |
+
slider_color='#2563eb',
|
44 |
+
slider_color_dark='#2563eb',
|
45 |
+
table_even_background_fill_dark='*neutral_950',
|
46 |
+
table_odd_background_fill_dark='*neutral_900',
|
47 |
+
button_border_width='*input_border_width',
|
48 |
+
button_shadow_active='none',
|
49 |
+
button_primary_background_fill='*primary_200',
|
50 |
+
button_primary_background_fill_dark='*primary_700',
|
51 |
+
button_primary_background_fill_hover='*button_primary_background_fill',
|
52 |
+
button_primary_background_fill_hover_dark='*button_primary_background_fill',
|
53 |
+
button_secondary_background_fill='*neutral_200',
|
54 |
+
button_secondary_background_fill_dark='*neutral_600',
|
55 |
+
button_secondary_background_fill_hover='*button_secondary_background_fill',
|
56 |
+
button_secondary_background_fill_hover_dark='*button_secondary_background_fill',
|
57 |
+
button_cancel_background_fill='*button_secondary_background_fill',
|
58 |
+
button_cancel_background_fill_dark='*button_secondary_background_fill',
|
59 |
+
button_cancel_background_fill_hover='*button_cancel_background_fill',
|
60 |
+
button_cancel_background_fill_hover_dark='*button_cancel_background_fill'
|
61 |
+
)
|
62 |
+
|
63 |
+
|
64 |
+
_AUDIOCRAFT_MODELS = ["facebook/musicgen-melody",
|
65 |
+
"facebook/musicgen-medium",
|
66 |
+
"facebook/musicgen-small",
|
67 |
+
"facebook/musicgen-large",
|
68 |
+
"facebook/musicgen-melody-large"]
|
69 |
+
|
70 |
+
|
71 |
+
def generate_prompt(difficulty, style):
|
72 |
+
_DIFFICULTY_MAPPIN = {
|
73 |
+
"Easy": "beginner player",
|
74 |
+
"Medum": "player who has 2-3 years experience",
|
75 |
+
"Hard": "player who has more than 4 years experiences"
|
76 |
+
}
|
77 |
+
prompt = 'piano only music for a {} to pratice with the touch of {}'.format(
|
78 |
+
_DIFFICULTY_MAPPIN[difficulty], style
|
79 |
+
)
|
80 |
+
return prompt
|
81 |
+
|
82 |
+
def UI():
|
83 |
+
with gr.Blocks() as demo:
|
84 |
+
with gr.Tab("Generate Music by melody"):
|
85 |
+
with gr.Row():
|
86 |
+
with gr.Column():
|
87 |
+
with gr.Row():
|
88 |
+
model_path = gr.Dropdown(
|
89 |
+
choices=_AUDIOCRAFT_MODELS,
|
90 |
+
label="Select the model",
|
91 |
+
value="facebook/musicgen-melody-large"
|
92 |
+
)
|
93 |
+
with gr.Row():
|
94 |
+
duration = gr.Slider(
|
95 |
+
minimum=10,
|
96 |
+
maximum=60,
|
97 |
+
value=10,
|
98 |
+
label="Duration",
|
99 |
+
interactive=True
|
100 |
+
)
|
101 |
+
with gr.Row():
|
102 |
+
topk = gr.Number(label="Top-k", value=250, interactive=True)
|
103 |
+
topp = gr.Number(label="Top-p", value=0, interactive=True)
|
104 |
+
temperature = gr.Number(
|
105 |
+
label="Temperature", value=1.0, interactive=True
|
106 |
+
)
|
107 |
+
sample_rate = gr.Number(label="output music sample rate", value=32000, interactive=True)
|
108 |
+
difficulty = gr.Radio(["Easy", "Medium", "Hard"], label="Difficulty", value="Easy", interactive=True)
|
109 |
+
style = gr.Radio(["Jazz", "Classical Music", "Hip Hop", "Others"], value="Classical Music", label="music genre", interactive=True)
|
110 |
+
if style == "Others":
|
111 |
+
style = gr.Textbox(label="Type your music genre")
|
112 |
+
prompt = generate_prompt(difficulty.value, style.value)
|
113 |
+
customize = gr.Checkbox(
|
114 |
+
label="Customize the prompt", interactive=True
|
115 |
+
)
|
116 |
+
if customize:
|
117 |
+
prompt = gr.Textbox(label="Type your prompt")
|
118 |
+
with gr.Column():
|
119 |
+
with gr.Row():
|
120 |
+
melody = gr.Audio(
|
121 |
+
sources=["microphone", "upload"],
|
122 |
+
streaming=True,
|
123 |
+
label="Record or upload your audio",
|
124 |
+
interactive=True,
|
125 |
+
type="numpy",
|
126 |
+
show_label=True,
|
127 |
+
)
|
128 |
+
with gr.Row():
|
129 |
+
submit = gr.Button("Generate Music")
|
130 |
+
output = gr.Audio("listen to the generated music")
|
131 |
+
submit.click(fn=predict, inputs=melody, outputs=output)
|
132 |
+
|
133 |
+
generate_music = gr.Button("Generate Music")
|
134 |
+
output = gr.Audio("listen to the music")
|
135 |
+
generate_music.click(fn=predict,
|
136 |
+
inputs=[model_path, prompt, melody, duration, topk, topp, temperature, sample_rate],
|
137 |
+
outputs=output)
|
138 |
+
|
139 |
+
gr.Examples(
|
140 |
+
examples=[
|
141 |
+
[
|
142 |
+
os.path.join(
|
143 |
+
os.path.dirname(__file__),
|
144 |
+
"./data/audio/twinkle_twinkle_little_stars_mozart.mp3"
|
145 |
+
),
|
146 |
+
"Easy",
|
147 |
+
32000,
|
148 |
+
20
|
149 |
+
],
|
150 |
+
[
|
151 |
+
os.path.join(
|
152 |
+
os.path.dirname(__file__),
|
153 |
+
"./data/audio/golden_hour.mp3"
|
154 |
+
),
|
155 |
+
"Easy",
|
156 |
+
32000,
|
157 |
+
20
|
158 |
+
],
|
159 |
+
[
|
160 |
+
os.path.join(
|
161 |
+
os.path.dirname(__file__),
|
162 |
+
"./data/audio/turkish_march_mozart.mp3"
|
163 |
+
),
|
164 |
+
"Easy",
|
165 |
+
32000,
|
166 |
+
20
|
167 |
+
],
|
168 |
+
[
|
169 |
+
os.path.join(
|
170 |
+
os.path.dirname(__file__),
|
171 |
+
"./data/audio/golden_hour.mp3"
|
172 |
+
),
|
173 |
+
"Hard",
|
174 |
+
32000,
|
175 |
+
20
|
176 |
+
],
|
177 |
+
[
|
178 |
+
os.path.join(
|
179 |
+
os.path.dirname(__file__),
|
180 |
+
"./data/audio/golden_hour.mp3"
|
181 |
+
),
|
182 |
+
"Hard",
|
183 |
+
32000,
|
184 |
+
40
|
185 |
+
],
|
186 |
+
[
|
187 |
+
os.path.join(
|
188 |
+
os.path.dirname(__file__),
|
189 |
+
"./data/audio/golden_hour.mp3"
|
190 |
+
),
|
191 |
+
"Hard",
|
192 |
+
16000,
|
193 |
+
20
|
194 |
+
],
|
195 |
+
],
|
196 |
+
inputs=[melody, difficulty, sample_rate, duration],
|
197 |
+
label="Audio Examples",
|
198 |
+
outputs=[output],
|
199 |
+
# cache_examples=True,
|
200 |
+
)
|
201 |
+
demo.queue().launch()
|
202 |
+
|
203 |
+
if __name__ == "__main__":
|
204 |
+
UI()
|
205 |
+
|
data/audio/golden_hour.mp3
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:90d863d50f73f0d84a2e199349909a2ef8e82cdec12c64e9ee29043e3f3a7730
|
3 |
+
size 5468297
|
data/audio/turkish_march_mozart.mp3
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6367d283829147ff0db47b4bbf0fdd9d159ef95d40e5a9279c81e8be93f9f2cd
|
3 |
+
size 5085237
|
data/audio/twinkle_twinkle_little_stars_mozart.mp3
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ed46f25fb0031b270dafc14981e67121ecf094e15c6c6c138f7998672de8ce7a
|
3 |
+
size 20276397
|
gradio_components/prediction.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from audiocraft.data.audio_utils import convert_audio
|
5 |
+
from audiocraft.data.audio import audio_write
|
6 |
+
import gradio as gr
|
7 |
+
from audiocraft.models import MusicGen
|
8 |
+
|
9 |
+
from tempfile import NamedTemporaryFile
|
10 |
+
from pathlib import Path
|
11 |
+
|
12 |
+
|
13 |
+
def load_model(version='facebook/musicgen-melody'):
|
14 |
+
return MusicGen.get_pretrained(version)
|
15 |
+
|
16 |
+
|
17 |
+
def _do_predictions(model, texts, melodies, duration, progress=False, gradio_progress=None, target_sr=32000, target_ac = 1, **gen_kwargs):
|
18 |
+
print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
|
19 |
+
be = time.time()
|
20 |
+
processed_melodies = []
|
21 |
+
for melody in melodies:
|
22 |
+
if melody is None:
|
23 |
+
processed_melodies.append(None)
|
24 |
+
else:
|
25 |
+
sr, melody = melody[0], torch.from_numpy(melody[1]).to(model.device).float().t()
|
26 |
+
if melody.dim() == 1:
|
27 |
+
melody = melody[None]
|
28 |
+
melody = melody[..., :int(sr * duration)]
|
29 |
+
melody = convert_audio(melody, sr, target_sr, target_ac)
|
30 |
+
processed_melodies.append(melody)
|
31 |
+
|
32 |
+
try:
|
33 |
+
if any(m is not None for m in processed_melodies):
|
34 |
+
# melody condition
|
35 |
+
outputs = model.generate_with_chroma(
|
36 |
+
descriptions=texts,
|
37 |
+
melody_wavs=processed_melodies,
|
38 |
+
melody_sample_rate=target_sr,
|
39 |
+
progress=progress,
|
40 |
+
return_tokens=False
|
41 |
+
)
|
42 |
+
else:
|
43 |
+
# text only
|
44 |
+
outputs = model.generate(texts, progress=progress, return_tokens=False)
|
45 |
+
except RuntimeError as e:
|
46 |
+
raise gr.Error("Error while generating " + e.args[0])
|
47 |
+
outputs = outputs.detach().cpu().float()
|
48 |
+
pending_videos = []
|
49 |
+
out_wavs = []
|
50 |
+
for output in outputs:
|
51 |
+
with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
|
52 |
+
audio_write(
|
53 |
+
file.name, output, model.sample_rate, strategy="loudness",
|
54 |
+
loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
|
55 |
+
out_wavs.append(file.name)
|
56 |
+
print("generation finished", len(texts), time.time() - be)
|
57 |
+
return out_wavs
|
58 |
+
|
59 |
+
|
60 |
+
def predict(model_path, text, melody, duration, topk, topp, temperature, target_sr, progress=gr.Progress()):
|
61 |
+
global INTERRUPTING
|
62 |
+
global USE_DIFFUSION
|
63 |
+
INTERRUPTING = False
|
64 |
+
progress(0, desc="Loading model...")
|
65 |
+
model_path = model_path.strip()
|
66 |
+
if model_path:
|
67 |
+
if not Path(model_path).exists():
|
68 |
+
raise gr.Error(f"Model path {model_path} doesn't exist.")
|
69 |
+
if not Path(model_path).is_dir():
|
70 |
+
raise gr.Error(f"Model path {model_path} must be a folder containing "
|
71 |
+
"state_dict.bin and compression_state_dict_.bin.")
|
72 |
+
if temperature < 0:
|
73 |
+
raise gr.Error("Temperature must be >= 0.")
|
74 |
+
if topk < 0:
|
75 |
+
raise gr.Error("Topk must be non-negative.")
|
76 |
+
if topp < 0:
|
77 |
+
raise gr.Error("Topp must be non-negative.")
|
78 |
+
|
79 |
+
topk = int(topk)
|
80 |
+
model = load_model(model_path)
|
81 |
+
|
82 |
+
max_generated = 0
|
83 |
+
|
84 |
+
def _progress(generated, to_generate):
|
85 |
+
nonlocal max_generated
|
86 |
+
max_generated = max(generated, max_generated)
|
87 |
+
progress((min(max_generated, to_generate), to_generate))
|
88 |
+
if INTERRUPTING:
|
89 |
+
raise gr.Error("Interrupted.")
|
90 |
+
model.set_custom_progress_callback(_progress)
|
91 |
+
|
92 |
+
wavs = _do_predictions(
|
93 |
+
[text],
|
94 |
+
[melody],
|
95 |
+
duration,
|
96 |
+
progress=True,
|
97 |
+
target_ac=1,
|
98 |
+
target_sr=target_sr,
|
99 |
+
top_k=topk,
|
100 |
+
top_p=topp,
|
101 |
+
temperature=temperature,
|
102 |
+
gradio_progress=progress)
|
103 |
+
return wavs
|
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==2.1.0
|
2 |
+
audiocraft
|
3 |
+
basic-pitch
|
4 |
+
gradio
|