# Spaces: Sleeping
import time | |
import torch | |
from audiocraft.data.audio_utils import convert_audio | |
from audiocraft.data.audio import audio_write | |
import gradio as gr | |
from audiocraft.models import MusicGen | |
from tempfile import NamedTemporaryFile | |
from pathlib import Path | |
from transformers import AutoModelForSeq2SeqLM | |
import basic_pitch | |
import basic_pitch.inference | |
from basic_pitch import ICASSP_2022_MODEL_PATH | |
def load_model(version='facebook/musicgen-melody'):
    """Fetch a pretrained MusicGen checkpoint.

    Args:
        version: Checkpoint identifier handed to ``MusicGen.get_pretrained``;
            defaults to ``'facebook/musicgen-melody'``.

    Returns:
        Whatever ``MusicGen.get_pretrained`` returns for ``version``.
    """
    pretrained = MusicGen.get_pretrained(version)
    return pretrained
def _do_predictions(model, texts, melodies, duration, progress=False, gradio_progress=None, target_sr=32000, target_ac=1, **gen_kwargs):
    """Generate one waveform per prompt and write each to a temporary WAV file.

    Args:
        model: A loaded MusicGen model (e.g. as returned by ``load_model``).
        texts: Text descriptions, one per output.
        melodies: Optional melody prompts aligned with ``texts``. Each entry is
            ``None`` or a ``(sample_rate, samples)`` pair whose second item is
            a NumPy array (it is passed to ``torch.from_numpy``).
        duration: Target length in seconds. Forwarded to
            ``model.set_generation_params`` and also used to trim melody prompts.
        progress: Forwarded to ``generate`` / ``generate_with_chroma``.
        gradio_progress: Accepted for API compatibility; not used here.
        target_sr: Sample rate melody prompts are converted to before conditioning.
        target_ac: Channel count melody prompts are converted to.
        **gen_kwargs: Extra generation settings (e.g. ``top_k``, ``top_p``,
            ``temperature``) forwarded to ``model.set_generation_params``.

    Returns:
        list[str]: Paths of the written WAV files, one per generated output.

    Raises:
        gr.Error: If the model raises ``RuntimeError`` during generation.
    """
    print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
    be = time.time()
    # Without this call `duration` would only trim the melody prompt and every
    # sampling option in gen_kwargs would be silently dropped.
    model.set_generation_params(duration=duration, **gen_kwargs)

    processed_melodies = []
    for melody in melodies:
        if melody is None:
            processed_melodies.append(None)
            continue
        sr, wav = melody[0], torch.from_numpy(melody[1]).to(model.device).float().t()
        print(f"Input audio sample rate is {sr}")
        if wav.dim() == 1:
            wav = wav[None]  # add a leading channel axis for mono input
        # Keep at most `duration` seconds of the prompt (assumes the last axis
        # is time after the transpose above -- TODO confirm for all inputs).
        wav = wav[..., :int(sr * duration)]
        wav = convert_audio(wav, sr, target_sr, target_ac)
        processed_melodies.append(wav)

    try:
        if any(m is not None for m in processed_melodies):
            # melody condition
            outputs = model.generate_with_chroma(
                descriptions=texts,
                melody_wavs=processed_melodies,
                melody_sample_rate=target_sr,
                progress=progress,
                return_tokens=False,
            )
        else:
            # text only
            outputs = model.generate(texts, progress=progress, return_tokens=False)
    except RuntimeError as e:
        # str(e) also works when the exception carries no args (e.args[0] would not).
        raise gr.Error("Error while generating " + str(e)) from e

    outputs = outputs.detach().cpu().float()
    out_wavs = []
    for output in outputs:
        # delete=False: the file must outlive this function; its path is returned.
        with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
            audio_write(
                file.name, output, model.sample_rate, strategy="loudness",
                loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
            out_wavs.append(file.name)
    print("generation finished", len(texts), time.time() - be)
    return out_wavs
def predict(model_path, text, melody, duration, topk, topp, temperature, target_sr, progress=gr.Progress()):
    """Gradio handler: load a MusicGen checkpoint and generate a single clip.

    Args:
        model_path: Checkpoint identifier passed to ``load_model``. A blank
            value or ``None`` falls back to ``load_model``'s default checkpoint.
        text: Text description of the music to generate.
        melody: Optional ``(sample_rate, samples)`` melody prompt, or ``None``.
        duration: Length of the clip in seconds; must be > 0.
        topk: Top-k sampling cutoff; must be >= 0 (converted to ``int``).
        topp: Top-p sampling cutoff; must be >= 0.
        temperature: Sampling temperature; must be >= 0.
        target_sr: Sample rate the melody prompt is converted to.
        progress: Gradio progress tracker, updated from the model's progress
            callback.

    Returns:
        str: Path of the generated WAV file.

    Raises:
        gr.Error: On invalid parameters, or if the module-level
            ``INTERRUPTING`` flag is set (by other code) during generation.
    """
    global INTERRUPTING
    INTERRUPTING = False
    progress(0, desc="Loading model...")
    # Tolerate None from an empty UI field as well as surrounding whitespace.
    model_path = (model_path or "").strip()
    if duration <= 0:
        raise gr.Error("Duration must be positive.")
    if temperature < 0:
        raise gr.Error("Temperature must be >= 0.")
    if topk < 0:
        raise gr.Error("Topk must be non-negative.")
    if topp < 0:
        raise gr.Error("Topp must be non-negative.")
    topk = int(topk)
    # An empty identifier does not name a checkpoint; use load_model's default.
    model = load_model(model_path) if model_path else load_model()

    max_generated = 0

    def _progress(generated, to_generate):
        # Report a monotonically non-decreasing count so the bar never moves back.
        nonlocal max_generated
        max_generated = max(generated, max_generated)
        progress((min(max_generated, to_generate), to_generate))
        if INTERRUPTING:
            raise gr.Error("Interrupted.")

    model.set_custom_progress_callback(_progress)
    wavs = _do_predictions(
        model,
        [text],
        [melody],
        duration,
        progress=True,
        target_ac=1,
        target_sr=target_sr,
        top_k=topk,
        top_p=topp,
        temperature=temperature,
        gradio_progress=progress)
    return wavs[0]
def transcribe(audio_path):
    """Transcribe an audio file to MIDI with Basic Pitch and offer it for download.

    Args:
        audio_path: Path of the audio file to transcribe.

    Returns:
        gr.DownloadButton: A visible button whose value is the written ``.mid`` path.

    Raises:
        Exception: Whatever ``midi_data.write`` raises; the partially written
            temporary file is removed before the exception propagates.
    """
    # predict returns (model_output, midi_data, note_events); only the MIDI is used.
    _, midi_data, _ = basic_pitch.inference.predict(
        audio_path=audio_path,
        model_or_model_path=ICASSP_2022_MODEL_PATH,
    )
    # delete=False keeps the file after close; its path is returned to the UI.
    midi_file = NamedTemporaryFile("wb", suffix=".mid", delete=False)
    try:
        with midi_file:
            midi_data.write(midi_file)
    except Exception as e:
        print(f"Error while writing midi file: {e}")
        # The handle is already closed here, so the orphaned file can be removed.
        Path(midi_file.name).unlink(missing_ok=True)
        raise
    print(f"midi file saved to {midi_file.name}")
    return gr.DownloadButton(
        value=midi_file.name,
        label=f"Download MIDI file {midi_file.name}",
        visible=True)