Audiofool
committed on
Commit
·
5f28c4a
1
Parent(s):
f800d5f
update app.py
Browse files
app.py
CHANGED
@@ -9,29 +9,17 @@ import base64
|
|
9 |
from pathlib import Path
|
10 |
from tempfile import NamedTemporaryFile
|
11 |
|
12 |
-
from einops import rearrange
|
13 |
-
import torch
|
14 |
import gradio as gr
|
15 |
import requests
|
16 |
|
17 |
-
from audiocraft.data.audio_utils import convert_audio
|
18 |
-
from audiocraft.data.audio import audio_write
|
19 |
-
from audiocraft.models.encodec import InterleaveStereoCompressionModel
|
20 |
-
from audiocraft.models import MusicGen, MultiBandDiffusion
|
21 |
-
|
22 |
from theme_wave import theme, css
|
23 |
|
24 |
# --- Configuration (Main App) ---
|
25 |
-
MLLM_API_URL =
|
26 |
-
|
27 |
-
)
|
28 |
-
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
29 |
|
30 |
# --- Global Variables (Main App) ---
|
31 |
-
MODEL = None
|
32 |
-
MBD = None
|
33 |
INTERRUPTING = False
|
34 |
-
USE_DIFFUSION = False # Keep this for now, even if unused, for easier switching
|
35 |
|
36 |
|
37 |
# --- Utility Functions (Main App) ---
|
@@ -72,29 +60,8 @@ def make_waveform(*args, **kwargs):
|
|
72 |
return gr.make_waveform(*args, **kwargs)
|
73 |
|
74 |
|
75 |
-
# --- Model Loading (Main App) ---
|
76 |
-
|
77 |
-
|
78 |
-
def load_musicgen_model(version="facebook/musicgen-stereo-melody-large"):
|
79 |
-
global MODEL
|
80 |
-
print(f"Loading MusicGen model: {version}")
|
81 |
-
if MODEL is None or MODEL.name != version:
|
82 |
-
if MODEL is not None:
|
83 |
-
del MODEL
|
84 |
-
torch.cuda.empty_cache()
|
85 |
-
MODEL = MusicGen.get_pretrained(version, device=DEVICE)
|
86 |
-
|
87 |
-
|
88 |
-
def load_diffusion_model():
|
89 |
-
global MBD
|
90 |
-
if MBD is None:
|
91 |
-
print("Loading diffusion model")
|
92 |
-
MBD = MultiBandDiffusion.get_mbd_musicgen(device=DEVICE)
|
93 |
-
|
94 |
-
|
95 |
# --- API Client Functions ---
|
96 |
|
97 |
-
|
98 |
def get_mllm_description(media_path: str, user_prompt: str) -> str:
|
99 |
"""Gets the music description from the MLLM API."""
|
100 |
|
@@ -122,7 +89,7 @@ def get_mllm_description(media_path: str, user_prompt: str) -> str:
|
|
122 |
f"{MLLM_API_URL}/describe_text/", json={"user_prompt": user_prompt}
|
123 |
)
|
124 |
|
125 |
-
response.raise_for_status()
|
126 |
return response.json()["description"]
|
127 |
|
128 |
except requests.exceptions.RequestException as e:
|
@@ -131,9 +98,73 @@ def get_mllm_description(media_path: str, user_prompt: str) -> str:
|
|
131 |
raise gr.Error(f"An unexpected error occurred: {e}")
|
132 |
|
133 |
|
134 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
135 |
|
136 |
|
|
|
|
|
137 |
def predict_full(
|
138 |
model_version,
|
139 |
media_type,
|
@@ -149,9 +180,9 @@ def predict_full(
|
|
149 |
decoder,
|
150 |
progress=gr.Progress(),
|
151 |
):
|
152 |
-
global INTERRUPTING
|
153 |
INTERRUPTING = False
|
154 |
-
|
155 |
|
156 |
if media_type == "Image":
|
157 |
media = image_input if image_input else None
|
@@ -160,124 +191,37 @@ def predict_full(
|
|
160 |
else:
|
161 |
media = None
|
162 |
|
163 |
-
# 1. Get Music Description (using the API
|
164 |
progress(progress=None, desc="Generating music description...")
|
165 |
if media:
|
166 |
try:
|
167 |
music_description = get_mllm_description(media, text_prompt)
|
168 |
except Exception as e:
|
169 |
-
raise gr.Error(str(e))
|
170 |
else:
|
171 |
music_description = text_prompt
|
172 |
|
173 |
-
# 2.
|
174 |
-
progress(progress=None, desc="
|
175 |
-
load_musicgen_model(model_version)
|
176 |
-
|
177 |
-
# 3. Set Generation Parameters (locally).
|
178 |
-
MODEL.set_generation_params(
|
179 |
-
duration=duration,
|
180 |
-
top_k=topk,
|
181 |
-
top_p=topp,
|
182 |
-
temperature=temperature,
|
183 |
-
cfg_coef=cfg_coef,
|
184 |
-
)
|
185 |
-
|
186 |
-
# 4. Melody Preprocessing (locally).
|
187 |
-
progress(progress=None, desc="Processing melody...")
|
188 |
-
melody_tensor = None # Use a different variable name
|
189 |
-
if melody:
|
190 |
-
try:
|
191 |
-
sr, melody_tensor = (
|
192 |
-
melody[0],
|
193 |
-
torch.from_numpy(melody[1]).to(MODEL.device).float().t(),
|
194 |
-
)
|
195 |
-
if melody_tensor.dim() == 1:
|
196 |
-
melody_tensor = melody_tensor[None]
|
197 |
-
melody_tensor = melody_tensor[..., : int(sr * duration)]
|
198 |
-
melody_tensor = convert_audio(
|
199 |
-
melody_tensor, sr, MODEL.sample_rate, MODEL.audio_channels
|
200 |
-
)
|
201 |
-
|
202 |
-
except Exception as e:
|
203 |
-
raise gr.Error(f"Error processing melody: {e}")
|
204 |
-
|
205 |
-
# 5. Music Generation (locally).
|
206 |
-
progress(progress=None, desc="Generating music...")
|
207 |
-
if USE_DIFFUSION:
|
208 |
-
load_diffusion_model()
|
209 |
-
|
210 |
try:
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
descriptions=[music_description],
|
222 |
-
progress=True,
|
223 |
-
return_tokens=USE_DIFFUSION,
|
224 |
-
)
|
225 |
-
except RuntimeError as e:
|
226 |
-
raise gr.Error("Error while generating: " + str(e))
|
227 |
-
|
228 |
-
if USE_DIFFUSION:
|
229 |
-
progress(progress=None, desc="Running MultiBandDiffusion...")
|
230 |
-
tokens = output[1]
|
231 |
-
if isinstance(MODEL.compression_model, InterleaveStereoCompressionModel):
|
232 |
-
left, right = MODEL.compression_model.get_left_right_codes(tokens)
|
233 |
-
tokens = torch.cat([left, right])
|
234 |
-
outputs_diffusion = MBD.tokens_to_wav(tokens)
|
235 |
-
if isinstance(MODEL.compression_model, InterleaveStereoCompressionModel):
|
236 |
-
assert outputs_diffusion.shape[1] == 1 # output is mono
|
237 |
-
outputs_diffusion = rearrange(
|
238 |
-
outputs_diffusion, "(s b) c t -> b (s c) t", s=2
|
239 |
-
)
|
240 |
-
output_audio = torch.cat([output[0], outputs_diffusion], dim=0)
|
241 |
-
else:
|
242 |
-
output_audio = output[0]
|
243 |
-
|
244 |
-
output_audio = output_audio.detach().cpu().float()
|
245 |
-
|
246 |
-
# 6. Save and Return (locally).
|
247 |
-
progress(progress=None, desc="Saving and returning...")
|
248 |
-
output_audio_paths = []
|
249 |
-
|
250 |
-
for i, audio in enumerate(output_audio):
|
251 |
-
with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
|
252 |
-
audio_write(
|
253 |
-
file.name,
|
254 |
-
audio,
|
255 |
-
MODEL.sample_rate,
|
256 |
-
strategy="loudness",
|
257 |
-
loudness_headroom_db=16,
|
258 |
-
loudness_compressor=True,
|
259 |
-
add_suffix=False,
|
260 |
-
)
|
261 |
-
output_audio_paths.append(file.name)
|
262 |
-
file_cleaner.add(file.name)
|
263 |
-
|
264 |
-
if USE_DIFFUSION:
|
265 |
-
# Return both audios, but make sure to return the correct one first
|
266 |
-
result = (
|
267 |
-
output_audio_paths[0], # Original
|
268 |
-
output_audio_paths[1], # MBD
|
269 |
)
|
270 |
-
|
271 |
-
|
272 |
-
output_audio_paths[0],
|
273 |
-
None,
|
274 |
-
) # Only original audio and description
|
275 |
|
276 |
-
|
277 |
-
|
278 |
-
torch.cuda.empty_cache()
|
279 |
|
280 |
-
return
|
281 |
|
282 |
|
283 |
Wave = theme()
|
@@ -349,9 +293,7 @@ def create_ui(launch_kwargs=None):
|
|
349 |
)
|
350 |
with gr.Row():
|
351 |
submit_button = gr.Button("Generate Music", variant="primary")
|
352 |
-
interrupt_button = gr.Button(
|
353 |
-
"Interrupt", variant="stop"
|
354 |
-
) # Keep as gr.Button
|
355 |
with gr.Row():
|
356 |
model_version = gr.Dropdown(
|
357 |
[
|
@@ -384,8 +326,6 @@ def create_ui(launch_kwargs=None):
|
|
384 |
interactive=True,
|
385 |
)
|
386 |
|
387 |
-
# with gr.Row():
|
388 |
-
# description_output = gr.Textbox(label="MLLM Generated Description")
|
389 |
with gr.Row():
|
390 |
output_audio = gr.Audio(label="Generated Music", type="filepath")
|
391 |
output_audio_mbd = gr.Audio(
|
@@ -408,12 +348,9 @@ def create_ui(launch_kwargs=None):
|
|
408 |
cfg_coef,
|
409 |
decoder,
|
410 |
],
|
411 |
-
# outputs=[output_audio, description_output, output_audio_mbd],
|
412 |
outputs=[output_audio, output_audio_mbd],
|
413 |
)
|
414 |
interrupt_button.click(interrupt_handler, [], [])
|
415 |
-
if INTERRUPTING:
|
416 |
-
raise gr.Error("Interrupted.")
|
417 |
|
418 |
gr.Examples(
|
419 |
examples=[
|
@@ -495,7 +432,7 @@ if __name__ == "__main__":
|
|
495 |
)
|
496 |
parser.add_argument(
|
497 |
"--server_port", type=int, default=0, help="Port to run the server on"
|
498 |
-
)
|
499 |
parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
|
500 |
parser.add_argument("--share", action="store_true", help="Share the Gradio UI")
|
501 |
|
@@ -513,4 +450,4 @@ if __name__ == "__main__":
|
|
513 |
launch_kwargs["share"] = args.share
|
514 |
|
515 |
logging.basicConfig(level=logging.INFO, stream=sys.stderr)
|
516 |
-
create_ui(launch_kwargs)
|
|
|
9 |
from pathlib import Path
|
10 |
from tempfile import NamedTemporaryFile
|
11 |
|
|
|
|
|
12 |
import gradio as gr
|
13 |
import requests
|
14 |
|
|
|
|
|
|
|
|
|
|
|
15 |
from theme_wave import theme, css
|
16 |
|
17 |
# --- Configuration (Main App) ---
|
18 |
+
MLLM_API_URL = "http://localhost:8000"
|
19 |
+
MUSICGEN_API_URL = "https://your-musicgen-api-endpoint.com" # Replace with actual MusicGen API endpoint
|
|
|
|
|
20 |
|
21 |
# --- Global Variables (Main App) ---
|
|
|
|
|
22 |
INTERRUPTING = False
|
|
|
23 |
|
24 |
|
25 |
# --- Utility Functions (Main App) ---
|
|
|
60 |
return gr.make_waveform(*args, **kwargs)
|
61 |
|
62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
# --- API Client Functions ---
|
64 |
|
|
|
65 |
def get_mllm_description(media_path: str, user_prompt: str) -> str:
|
66 |
"""Gets the music description from the MLLM API."""
|
67 |
|
|
|
89 |
f"{MLLM_API_URL}/describe_text/", json={"user_prompt": user_prompt}
|
90 |
)
|
91 |
|
92 |
+
response.raise_for_status()
|
93 |
return response.json()["description"]
|
94 |
|
95 |
except requests.exceptions.RequestException as e:
|
|
|
98 |
raise gr.Error(f"An unexpected error occurred: {e}")
|
99 |
|
100 |
|
101 |
+
def _write_temp_wav(data: bytes) -> str:
    """Write *data* to a persistent temporary ``.wav`` file and return its path."""
    # delete=False: the path has to outlive this function so the UI can serve
    # the file; it is registered with file_cleaner instead of being removed here.
    with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
        file.write(data)
        file_cleaner.add(file.name)
    return file.name


def generate_music_from_api(
    description: str,
    melody=None,
    duration: int = 10,
    model_version: str = "facebook/musicgen-stereo-melody-large",
    topk: int = 250,
    topp: float = 0,
    temperature: float = 1.0,
    cfg_coef: float = 3.0,
    use_diffusion: bool = False,
    timeout=(10.0, 600.0),
):
    """Generate music by POSTing the request to the MusicGen API.

    Args:
        description: Text description sent as the ``description`` field.
        melody: Optional ``(sample_rate, samples)`` pair. ``samples`` must
            expose ``tobytes()`` (e.g. a NumPy array); its raw buffer is sent
            base64-encoded.
        duration: Sent as the ``duration`` field.
        model_version: Sent as the ``model_version`` field.
        topk: Sent as the ``topk`` field.
        topp: Sent as the ``topp`` field.
        temperature: Sent as the ``temperature`` field.
        cfg_coef: Sent as the ``cfg_coef`` field.
        use_diffusion: Sent as the ``use_diffusion`` field; when true the
            response's ``diffusion_audio`` entry is also decoded and saved.
        timeout: Timeout forwarded to ``requests.post`` (seconds, or a
            ``(connect, read)`` tuple). Without one a stalled server would
            block the worker forever.

    Returns:
        A ``(audio_path, diffusion_audio_path)`` tuple of temporary ``.wav``
        file paths; the second element is ``None`` when no diffusion audio was
        requested or returned.

    Raises:
        gr.Error: If the request fails, times out, or the response/melody
            cannot be processed.
    """
    # Prepare the API request payload
    payload = {
        "description": description,
        "duration": duration,
        "model_version": model_version,
        "topk": topk,
        "topp": topp,
        "temperature": temperature,
        "cfg_coef": cfg_coef,
        "use_diffusion": use_diffusion,
    }

    try:
        # Melody handling lives inside the try so that a malformed melody is
        # reported through gr.Error like every other failure.
        if melody is not None:
            sr, melody_data = melody
            # NOTE(review): only the raw sample buffer is transmitted; dtype and
            # channel layout are not part of the payload - confirm the server
            # knows how to interpret it.
            melody_bytes = melody_data.tobytes()
            payload["melody"] = base64.b64encode(melody_bytes).decode("utf-8")
            # int() keeps the value JSON-serializable if sr is a NumPy integer.
            payload["melody_sample_rate"] = int(sr)

        response = requests.post(
            f"{MUSICGEN_API_URL}/generate", json=payload, timeout=timeout
        )
        response.raise_for_status()
        result = response.json()

        # Assumes the API returns base64-encoded audio files.
        audio_data = base64.b64decode(result["audio"])
        diffusion_audio_data = None
        if use_diffusion:
            # "or ''" also covers an explicit JSON null for the key.
            diffusion_audio_data = base64.b64decode(
                result.get("diffusion_audio") or ""
            )

        audio_path = _write_temp_wav(audio_data)
        diffusion_path = (
            _write_temp_wav(diffusion_audio_data) if diffusion_audio_data else None
        )
        return audio_path, diffusion_path

    except requests.exceptions.RequestException as e:
        raise gr.Error(f"Error communicating with MusicGen API: {e}") from e
    except Exception as e:
        raise gr.Error(f"An unexpected error occurred: {e}") from e
|
164 |
|
165 |
|
166 |
+
# --- Music Generation ---
|
167 |
+
|
168 |
def predict_full(
|
169 |
model_version,
|
170 |
media_type,
|
|
|
180 |
decoder,
|
181 |
progress=gr.Progress(),
|
182 |
):
|
183 |
+
global INTERRUPTING
|
184 |
INTERRUPTING = False
|
185 |
+
use_diffusion = decoder == "MultiBand_Diffusion"
|
186 |
|
187 |
if media_type == "Image":
|
188 |
media = image_input if image_input else None
|
|
|
191 |
else:
|
192 |
media = None
|
193 |
|
194 |
+
# 1. Get Music Description (using the MLLM API)
|
195 |
progress(progress=None, desc="Generating music description...")
|
196 |
if media:
|
197 |
try:
|
198 |
music_description = get_mllm_description(media, text_prompt)
|
199 |
except Exception as e:
|
200 |
+
raise gr.Error(str(e))
|
201 |
else:
|
202 |
music_description = text_prompt
|
203 |
|
204 |
+
# 2. Generate music using MusicGen API
|
205 |
+
progress(progress=None, desc="Generating music via API...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
206 |
try:
|
207 |
+
output_audio_path, output_audio_mbd_path = generate_music_from_api(
|
208 |
+
description=music_description,
|
209 |
+
melody=melody,
|
210 |
+
duration=duration,
|
211 |
+
model_version=model_version,
|
212 |
+
topk=topk,
|
213 |
+
topp=topp,
|
214 |
+
temperature=temperature,
|
215 |
+
cfg_coef=cfg_coef,
|
216 |
+
use_diffusion=use_diffusion
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
217 |
)
|
218 |
+
except Exception as e:
|
219 |
+
raise gr.Error(f"Error generating music: {e}")
|
|
|
|
|
|
|
220 |
|
221 |
+
if INTERRUPTING:
|
222 |
+
raise gr.Error("Generation interrupted.")
|
|
|
223 |
|
224 |
+
return output_audio_path, output_audio_mbd_path
|
225 |
|
226 |
|
227 |
Wave = theme()
|
|
|
293 |
)
|
294 |
with gr.Row():
|
295 |
submit_button = gr.Button("Generate Music", variant="primary")
|
296 |
+
interrupt_button = gr.Button("Interrupt", variant="stop")
|
|
|
|
|
297 |
with gr.Row():
|
298 |
model_version = gr.Dropdown(
|
299 |
[
|
|
|
326 |
interactive=True,
|
327 |
)
|
328 |
|
|
|
|
|
329 |
with gr.Row():
|
330 |
output_audio = gr.Audio(label="Generated Music", type="filepath")
|
331 |
output_audio_mbd = gr.Audio(
|
|
|
348 |
cfg_coef,
|
349 |
decoder,
|
350 |
],
|
|
|
351 |
outputs=[output_audio, output_audio_mbd],
|
352 |
)
|
353 |
interrupt_button.click(interrupt_handler, [], [])
|
|
|
|
|
354 |
|
355 |
gr.Examples(
|
356 |
examples=[
|
|
|
432 |
)
|
433 |
parser.add_argument(
|
434 |
"--server_port", type=int, default=0, help="Port to run the server on"
|
435 |
+
)
|
436 |
parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
|
437 |
parser.add_argument("--share", action="store_true", help="Share the Gradio UI")
|
438 |
|
|
|
450 |
launch_kwargs["share"] = args.share
|
451 |
|
452 |
logging.basicConfig(level=logging.INFO, stream=sys.stderr)
|
453 |
+
create_ui(launch_kwargs)
|