Audiofool commited on
Commit
5f28c4a
·
1 Parent(s): f800d5f

update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -155
app.py CHANGED
@@ -9,29 +9,17 @@ import base64
9
  from pathlib import Path
10
  from tempfile import NamedTemporaryFile
11
 
12
- from einops import rearrange
13
- import torch
14
  import gradio as gr
15
  import requests
16
 
17
- from audiocraft.data.audio_utils import convert_audio
18
- from audiocraft.data.audio import audio_write
19
- from audiocraft.models.encodec import InterleaveStereoCompressionModel
20
- from audiocraft.models import MusicGen, MultiBandDiffusion
21
-
22
  from theme_wave import theme, css
23
 
24
  # --- Configuration (Main App) ---
25
- MLLM_API_URL = (
26
- "http://localhost:8000"
27
- )
28
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
29
 
30
  # --- Global Variables (Main App) ---
31
- MODEL = None
32
- MBD = None
33
  INTERRUPTING = False
34
- USE_DIFFUSION = False # Keep this for now, even if unused, for easier switching
35
 
36
 
37
  # --- Utility Functions (Main App) ---
@@ -72,29 +60,8 @@ def make_waveform(*args, **kwargs):
72
  return gr.make_waveform(*args, **kwargs)
73
 
74
 
75
- # --- Model Loading (Main App) ---
76
-
77
-
78
- def load_musicgen_model(version="facebook/musicgen-stereo-melody-large"):
79
- global MODEL
80
- print(f"Loading MusicGen model: {version}")
81
- if MODEL is None or MODEL.name != version:
82
- if MODEL is not None:
83
- del MODEL
84
- torch.cuda.empty_cache()
85
- MODEL = MusicGen.get_pretrained(version, device=DEVICE)
86
-
87
-
88
- def load_diffusion_model():
89
- global MBD
90
- if MBD is None:
91
- print("Loading diffusion model")
92
- MBD = MultiBandDiffusion.get_mbd_musicgen(device=DEVICE)
93
-
94
-
95
  # --- API Client Functions ---
96
 
97
-
98
  def get_mllm_description(media_path: str, user_prompt: str) -> str:
99
  """Gets the music description from the MLLM API."""
100
 
@@ -122,7 +89,7 @@ def get_mllm_description(media_path: str, user_prompt: str) -> str:
122
  f"{MLLM_API_URL}/describe_text/", json={"user_prompt": user_prompt}
123
  )
124
 
125
- response.raise_for_status() # Raise an exception for bad status codes (4xx or 5xx).
126
  return response.json()["description"]
127
 
128
  except requests.exceptions.RequestException as e:
@@ -131,9 +98,73 @@ def get_mllm_description(media_path: str, user_prompt: str) -> str:
131
  raise gr.Error(f"An unexpected error occurred: {e}")
132
 
133
 
134
- # --- Music Generation ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
 
 
 
137
  def predict_full(
138
  model_version,
139
  media_type,
@@ -149,9 +180,9 @@ def predict_full(
149
  decoder,
150
  progress=gr.Progress(),
151
  ):
152
- global INTERRUPTING, USE_DIFFUSION
153
  INTERRUPTING = False
154
- USE_DIFFUSION = decoder == "MultiBand_Diffusion"
155
 
156
  if media_type == "Image":
157
  media = image_input if image_input else None
@@ -160,124 +191,37 @@ def predict_full(
160
  else:
161
  media = None
162
 
163
- # 1. Get Music Description (using the API client).
164
  progress(progress=None, desc="Generating music description...")
165
  if media:
166
  try:
167
  music_description = get_mllm_description(media, text_prompt)
168
  except Exception as e:
169
- raise gr.Error(str(e)) # Re-raise for Gradio to handle.
170
  else:
171
  music_description = text_prompt
172
 
173
- # 2. Load MusicGen Model (locally).
174
- progress(progress=None, desc="Loading MusicGen model...")
175
- load_musicgen_model(model_version)
176
-
177
- # 3. Set Generation Parameters (locally).
178
- MODEL.set_generation_params(
179
- duration=duration,
180
- top_k=topk,
181
- top_p=topp,
182
- temperature=temperature,
183
- cfg_coef=cfg_coef,
184
- )
185
-
186
- # 4. Melody Preprocessing (locally).
187
- progress(progress=None, desc="Processing melody...")
188
- melody_tensor = None # Use a different variable name
189
- if melody:
190
- try:
191
- sr, melody_tensor = (
192
- melody[0],
193
- torch.from_numpy(melody[1]).to(MODEL.device).float().t(),
194
- )
195
- if melody_tensor.dim() == 1:
196
- melody_tensor = melody_tensor[None]
197
- melody_tensor = melody_tensor[..., : int(sr * duration)]
198
- melody_tensor = convert_audio(
199
- melody_tensor, sr, MODEL.sample_rate, MODEL.audio_channels
200
- )
201
-
202
- except Exception as e:
203
- raise gr.Error(f"Error processing melody: {e}")
204
-
205
- # 5. Music Generation (locally).
206
- progress(progress=None, desc="Generating music...")
207
- if USE_DIFFUSION:
208
- load_diffusion_model()
209
-
210
  try:
211
- if melody_tensor is not None: # Use the new variable
212
- output = MODEL.generate_with_chroma(
213
- descriptions=[music_description],
214
- melody_wavs=[melody_tensor],
215
- melody_sample_rate=MODEL.sample_rate,
216
- progress=True,
217
- return_tokens=USE_DIFFUSION,
218
- )
219
- else:
220
- output = MODEL.generate(
221
- descriptions=[music_description],
222
- progress=True,
223
- return_tokens=USE_DIFFUSION,
224
- )
225
- except RuntimeError as e:
226
- raise gr.Error("Error while generating: " + str(e))
227
-
228
- if USE_DIFFUSION:
229
- progress(progress=None, desc="Running MultiBandDiffusion...")
230
- tokens = output[1]
231
- if isinstance(MODEL.compression_model, InterleaveStereoCompressionModel):
232
- left, right = MODEL.compression_model.get_left_right_codes(tokens)
233
- tokens = torch.cat([left, right])
234
- outputs_diffusion = MBD.tokens_to_wav(tokens)
235
- if isinstance(MODEL.compression_model, InterleaveStereoCompressionModel):
236
- assert outputs_diffusion.shape[1] == 1 # output is mono
237
- outputs_diffusion = rearrange(
238
- outputs_diffusion, "(s b) c t -> b (s c) t", s=2
239
- )
240
- output_audio = torch.cat([output[0], outputs_diffusion], dim=0)
241
- else:
242
- output_audio = output[0]
243
-
244
- output_audio = output_audio.detach().cpu().float()
245
-
246
- # 6. Save and Return (locally).
247
- progress(progress=None, desc="Saving and returning...")
248
- output_audio_paths = []
249
-
250
- for i, audio in enumerate(output_audio):
251
- with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
252
- audio_write(
253
- file.name,
254
- audio,
255
- MODEL.sample_rate,
256
- strategy="loudness",
257
- loudness_headroom_db=16,
258
- loudness_compressor=True,
259
- add_suffix=False,
260
- )
261
- output_audio_paths.append(file.name)
262
- file_cleaner.add(file.name)
263
-
264
- if USE_DIFFUSION:
265
- # Return both audios, but make sure to return the correct one first
266
- result = (
267
- output_audio_paths[0], # Original
268
- output_audio_paths[1], # MBD
269
  )
270
- else:
271
- result = (
272
- output_audio_paths[0],
273
- None,
274
- ) # Only original audio and description
275
 
276
- del melody_tensor, output, output_audio
277
- if torch.cuda.is_available():
278
- torch.cuda.empty_cache()
279
 
280
- return result
281
 
282
 
283
  Wave = theme()
@@ -349,9 +293,7 @@ def create_ui(launch_kwargs=None):
349
  )
350
  with gr.Row():
351
  submit_button = gr.Button("Generate Music", variant="primary")
352
- interrupt_button = gr.Button(
353
- "Interrupt", variant="stop"
354
- ) # Keep as gr.Button
355
  with gr.Row():
356
  model_version = gr.Dropdown(
357
  [
@@ -384,8 +326,6 @@ def create_ui(launch_kwargs=None):
384
  interactive=True,
385
  )
386
 
387
- # with gr.Row():
388
- # description_output = gr.Textbox(label="MLLM Generated Description")
389
  with gr.Row():
390
  output_audio = gr.Audio(label="Generated Music", type="filepath")
391
  output_audio_mbd = gr.Audio(
@@ -408,12 +348,9 @@ def create_ui(launch_kwargs=None):
408
  cfg_coef,
409
  decoder,
410
  ],
411
- # outputs=[output_audio, description_output, output_audio_mbd],
412
  outputs=[output_audio, output_audio_mbd],
413
  )
414
  interrupt_button.click(interrupt_handler, [], [])
415
- if INTERRUPTING:
416
- raise gr.Error("Interrupted.")
417
 
418
  gr.Examples(
419
  examples=[
@@ -495,7 +432,7 @@ if __name__ == "__main__":
495
  )
496
  parser.add_argument(
497
  "--server_port", type=int, default=0, help="Port to run the server on"
498
- ) # Add server_port argument.
499
  parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
500
  parser.add_argument("--share", action="store_true", help="Share the Gradio UI")
501
 
@@ -513,4 +450,4 @@ if __name__ == "__main__":
513
  launch_kwargs["share"] = args.share
514
 
515
  logging.basicConfig(level=logging.INFO, stream=sys.stderr)
516
- create_ui(launch_kwargs)
 
9
  from pathlib import Path
10
  from tempfile import NamedTemporaryFile
11
 
 
 
12
  import gradio as gr
13
  import requests
14
 
 
 
 
 
 
15
  from theme_wave import theme, css
16
 
17
  # --- Configuration (Main App) ---
18
+ MLLM_API_URL = "http://localhost:8000"
19
+ MUSICGEN_API_URL = "https://your-musicgen-api-endpoint.com" # Replace with actual MusicGen API endpoint
 
 
20
 
21
  # --- Global Variables (Main App) ---
 
 
22
  INTERRUPTING = False
 
23
 
24
 
25
  # --- Utility Functions (Main App) ---
 
60
  return gr.make_waveform(*args, **kwargs)
61
 
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  # --- API Client Functions ---
64
 
 
65
  def get_mllm_description(media_path: str, user_prompt: str) -> str:
66
  """Gets the music description from the MLLM API."""
67
 
 
89
  f"{MLLM_API_URL}/describe_text/", json={"user_prompt": user_prompt}
90
  )
91
 
92
+ response.raise_for_status()
93
  return response.json()["description"]
94
 
95
  except requests.exceptions.RequestException as e:
 
98
  raise gr.Error(f"An unexpected error occurred: {e}")
99
 
100
 
101
+ def generate_music_from_api(
102
+ description: str,
103
+ melody=None,
104
+ duration: int = 10,
105
+ model_version: str = "facebook/musicgen-stereo-melody-large",
106
+ topk: int = 250,
107
+ topp: float = 0,
108
+ temperature: float = 1.0,
109
+ cfg_coef: float = 3.0,
110
+ use_diffusion: bool = False,
111
+ ):
112
+ """Generates music using the MusicGen API."""
113
+
114
+ # Prepare the API request payload
115
+ payload = {
116
+ "description": description,
117
+ "duration": duration,
118
+ "model_version": model_version,
119
+ "topk": topk,
120
+ "topp": topp,
121
+ "temperature": temperature,
122
+ "cfg_coef": cfg_coef,
123
+ "use_diffusion": use_diffusion
124
+ }
125
+
126
+ # Handle melody if provided
127
+ if melody is not None:
128
+ sr, melody_data = melody
129
+ # Convert melody to base64 for API transmission
130
+ melody_bytes = melody_data.tobytes() if hasattr(melody_data, 'tobytes') else melody_data.tostring()
131
+ encoded_melody = base64.b64encode(melody_bytes).decode("utf-8")
132
+ payload["melody"] = encoded_melody
133
+ payload["melody_sample_rate"] = sr
134
+
135
+ try:
136
+ response = requests.post(f"{MUSICGEN_API_URL}/generate", json=payload)
137
+ response.raise_for_status()
138
+
139
+ result = response.json()
140
+
141
+ # Assuming API returns base64 encoded audio files
142
+ audio_data = base64.b64decode(result["audio"])
143
+ diffusion_audio_data = base64.b64decode(result.get("diffusion_audio", "")) if use_diffusion else None
144
+
145
+ # Save to temporary files
146
+ output_paths = []
147
+ with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
148
+ file.write(audio_data)
149
+ output_paths.append(file.name)
150
+ file_cleaner.add(file.name)
151
+
152
+ if diffusion_audio_data:
153
+ with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
154
+ file.write(diffusion_audio_data)
155
+ output_paths.append(file.name)
156
+ file_cleaner.add(file.name)
157
+
158
+ return output_paths[0], output_paths[1] if len(output_paths) > 1 else None
159
+
160
+ except requests.exceptions.RequestException as e:
161
+ raise gr.Error(f"Error communicating with MusicGen API: {e}")
162
+ except Exception as e:
163
+ raise gr.Error(f"An unexpected error occurred: {e}")
164
 
165
 
166
+ # --- Music Generation ---
167
+
168
  def predict_full(
169
  model_version,
170
  media_type,
 
180
  decoder,
181
  progress=gr.Progress(),
182
  ):
183
+ global INTERRUPTING
184
  INTERRUPTING = False
185
+ use_diffusion = decoder == "MultiBand_Diffusion"
186
 
187
  if media_type == "Image":
188
  media = image_input if image_input else None
 
191
  else:
192
  media = None
193
 
194
+ # 1. Get Music Description (using the MLLM API)
195
  progress(progress=None, desc="Generating music description...")
196
  if media:
197
  try:
198
  music_description = get_mllm_description(media, text_prompt)
199
  except Exception as e:
200
+ raise gr.Error(str(e))
201
  else:
202
  music_description = text_prompt
203
 
204
+ # 2. Generate music using MusicGen API
205
+ progress(progress=None, desc="Generating music via API...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
  try:
207
+ output_audio_path, output_audio_mbd_path = generate_music_from_api(
208
+ description=music_description,
209
+ melody=melody,
210
+ duration=duration,
211
+ model_version=model_version,
212
+ topk=topk,
213
+ topp=topp,
214
+ temperature=temperature,
215
+ cfg_coef=cfg_coef,
216
+ use_diffusion=use_diffusion
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
  )
218
+ except Exception as e:
219
+ raise gr.Error(f"Error generating music: {e}")
 
 
 
220
 
221
+ if INTERRUPTING:
222
+ raise gr.Error("Generation interrupted.")
 
223
 
224
+ return output_audio_path, output_audio_mbd_path
225
 
226
 
227
  Wave = theme()
 
293
  )
294
  with gr.Row():
295
  submit_button = gr.Button("Generate Music", variant="primary")
296
+ interrupt_button = gr.Button("Interrupt", variant="stop")
 
 
297
  with gr.Row():
298
  model_version = gr.Dropdown(
299
  [
 
326
  interactive=True,
327
  )
328
 
 
 
329
  with gr.Row():
330
  output_audio = gr.Audio(label="Generated Music", type="filepath")
331
  output_audio_mbd = gr.Audio(
 
348
  cfg_coef,
349
  decoder,
350
  ],
 
351
  outputs=[output_audio, output_audio_mbd],
352
  )
353
  interrupt_button.click(interrupt_handler, [], [])
 
 
354
 
355
  gr.Examples(
356
  examples=[
 
432
  )
433
  parser.add_argument(
434
  "--server_port", type=int, default=0, help="Port to run the server on"
435
+ )
436
  parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
437
  parser.add_argument("--share", action="store_true", help="Share the Gradio UI")
438
 
 
450
  launch_kwargs["share"] = args.share
451
 
452
  logging.basicConfig(level=logging.INFO, stream=sys.stderr)
453
+ create_ui(launch_kwargs)