import time
from pathlib import Path
from tempfile import NamedTemporaryFile

import torch
from audiocraft.data.audio_utils import convert_audio
from audiocraft.data.audio import audio_write
import gradio as gr
from audiocraft.models import MusicGen


def load_model(version='facebook/musicgen-melody'):
    """Load a MusicGen model with ``MusicGen.get_pretrained``.

    Args:
        version: Pretrained model name (defaults to ``facebook/musicgen-melody``)
            or, as ``predict`` uses it, a local folder holding exported weights.

    Returns:
        The model returned by ``MusicGen.get_pretrained``.
    """
    return MusicGen.get_pretrained(version)


def _do_predictions(model, texts, melodies, duration, progress=False, gradio_progress=None,
                    target_sr=32000, target_ac=1, **gen_kwargs):
    """Generate one clip per text prompt and write each clip to a temporary WAV file.

    Args:
        model: Loaded ``MusicGen`` model.
        texts: Text descriptions, one per clip.
        melodies: One entry per text; each is ``None`` or a ``(sample_rate, array)``
            tuple whose array is converted with ``torch.from_numpy``.
        duration: Length of the generated clips in seconds; melody inputs are
            also cropped to this length.
        progress: Forwarded to ``generate`` / ``generate_with_chroma``.
        gradio_progress: Unused; kept so existing callers keep working.
        target_sr: Sample rate melodies are converted to before conditioning.
        target_ac: Channel count melodies are converted to.
        **gen_kwargs: Extra generation settings (e.g. ``top_k``, ``top_p``,
            ``temperature``) forwarded to ``model.set_generation_params``.

    Returns:
        List of paths to the written WAV files. The files are created with
        ``delete=False``, so removing them is up to the caller.

    Raises:
        gr.Error: If generation raises a ``RuntimeError``.
    """
    # Without this call the model keeps whatever duration/sampling settings it
    # already had, silently ignoring ``duration`` and ``gen_kwargs``.
    model.set_generation_params(duration=duration, **gen_kwargs)
    print("new batch", len(texts), texts,
          [None if m is None else (m[0], m[1].shape) for m in melodies])
    be = time.time()

    processed_melodies = []
    for melody in melodies:
        if melody is None:
            processed_melodies.append(None)
            continue
        sr = melody[0]
        # Transpose so channels come first; assumes the array is laid out as
        # (samples,) or (samples, channels) -- TODO confirm against the caller.
        wav = torch.from_numpy(melody[1]).to(model.device).float().t()
        if wav.dim() == 1:
            wav = wav[None]  # add a channel axis to mono input
        wav = wav[..., :int(sr * duration)]  # keep at most `duration` seconds
        processed_melodies.append(convert_audio(wav, sr, target_sr, target_ac))

    try:
        if any(m is not None for m in processed_melodies):
            # At least one melody was given: condition on its chroma as well.
            outputs = model.generate_with_chroma(
                descriptions=texts,
                melody_wavs=processed_melodies,
                melody_sample_rate=target_sr,
                progress=progress,
                return_tokens=False,
            )
        else:
            # Text-only conditioning.
            outputs = model.generate(texts, progress=progress, return_tokens=False)
    except RuntimeError as e:
        # str(e) is safe even when e.args is empty or holds a non-string,
        # unlike concatenating e.args[0]; keep the original as the cause.
        raise gr.Error(f"Error while generating {e}") from e

    outputs = outputs.detach().cpu().float()
    out_wavs = []
    for output in outputs:
        # Only reserve a unique path here and write after the handle is closed:
        # reopening a still-open temporary file by name is not portable.
        with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
            path = file.name
        audio_write(
            path, output, model.sample_rate, strategy="loudness",
            loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
        out_wavs.append(path)
    print("generation finished", len(texts), time.time() - be)
    return out_wavs


def predict(model_path, text, melody, duration, topk, topp, temperature, target_sr,
            progress=gr.Progress()):
    """Gradio handler: validate inputs, load a model and generate a single clip.

    Args:
        model_path: Folder containing ``state_dict.bin`` and
            ``compression_state_dict.bin``. Surrounding whitespace is stripped;
            when it is empty, the default model of ``load_model`` is used.
        text: Text prompt.
        melody: ``None`` or a ``(sample_rate, array)`` melody tuple.
        duration: Clip length in seconds.
        topk: Top-k sampling value; must be >= 0 and is truncated to ``int``.
        topp: Top-p sampling value; must be >= 0.
        temperature: Sampling temperature; must be >= 0.
        target_sr: Sample rate the melody is converted to before conditioning.
        progress: Gradio progress tracker (the default instance is Gradio's
            convention for requesting progress injection).

    Returns:
        A one-element list holding the path of the generated WAV file.

    Raises:
        gr.Error: On invalid inputs, on generation failure, or when the
            module-level ``INTERRUPTING`` flag is set during generation.
    """
    global INTERRUPTING
    INTERRUPTING = False
    progress(0, desc="Loading model...")

    model_path = model_path.strip()
    if model_path:
        if not Path(model_path).exists():
            raise gr.Error(f"Model path {model_path} doesn't exist.")
        if not Path(model_path).is_dir():
            raise gr.Error(f"Model path {model_path} must be a folder containing "
                           "state_dict.bin and compression_state_dict.bin.")
    if temperature < 0:
        raise gr.Error("Temperature must be >= 0.")
    if topk < 0:
        raise gr.Error("Topk must be non-negative.")
    if topp < 0:
        raise gr.Error("Topp must be non-negative.")

    topk = int(topk)
    # An empty path would otherwise be passed verbatim to get_pretrained("").
    model = load_model(model_path) if model_path else load_model()

    max_generated = 0

    def _progress(generated, to_generate):
        # Report monotonically non-decreasing progress and honour interrupts.
        nonlocal max_generated
        max_generated = max(generated, max_generated)
        progress((min(max_generated, to_generate), to_generate))
        if INTERRUPTING:
            raise gr.Error("Interrupted.")

    model.set_custom_progress_callback(_progress)

    wavs = _do_predictions(
        model,
        [text], [melody], duration, progress=True,
        target_ac=1, target_sr=target_sr,
        top_k=topk, top_p=topp, temperature=temperature, gradio_progress=progress)
    return wavs