Spaces:
Sleeping
Sleeping
File size: 4,550 Bytes
c318a73 97a428f 4336e0a c318a73 97a428f c318a73 97a428f c318a73 3e1d74a c318a73 4336e0a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import time
import torch
from audiocraft.data.audio_utils import convert_audio
from audiocraft.data.audio import audio_write
import gradio as gr
from audiocraft.models import MusicGen
from tempfile import NamedTemporaryFile
from pathlib import Path
from transformers import AutoModelForSeq2SeqLM
import basic_pitch
import basic_pitch.inference
from basic_pitch import ICASSP_2022_MODEL_PATH
def load_model(version='facebook/musicgen-melody'):
    """Fetch a pretrained MusicGen checkpoint.

    Args:
        version: Hugging Face model id (or local path) of the checkpoint.

    Returns:
        The loaded ``MusicGen`` model instance.
    """
    pretrained = MusicGen.get_pretrained(version)
    return pretrained
def _do_predictions(model, texts, melodies, duration, progress=False, gradio_progress=None, target_sr=32000, target_ac = 1, **gen_kwargs):
    """Generate one audio clip per text prompt, optionally melody-conditioned.

    Args:
        model: A loaded ``MusicGen`` instance.
        texts: List of text prompts (one generation per prompt).
        melodies: List aligned with ``texts``; each entry is either ``None``
            or a ``(sample_rate, numpy_waveform)`` pair as produced by a
            Gradio audio component.
        duration: Target duration in seconds; also used to trim the melody.
        progress: Whether the model should report generation progress.
        gradio_progress: Unused here; kept for interface compatibility.
        target_sr: Sample rate the melody is resampled to before conditioning.
        target_ac: Target number of audio channels for the melody.
        **gen_kwargs: Sampling parameters (e.g. ``top_k``, ``top_p``,
            ``temperature``) forwarded to ``model.set_generation_params``.

    Returns:
        List of paths to the generated ``.wav`` files, one per prompt.

    Raises:
        gr.Error: If the underlying model raises a ``RuntimeError``.
    """
    print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
    be = time.time()
    processed_melodies = []
    for melody in melodies:
        if melody is None:
            processed_melodies.append(None)
        else:
            sr, melody = melody[0], torch.from_numpy(melody[1]).to(model.device).float().t()
            print(f"Input audio sample rate is {sr}")
            if melody.dim() == 1:
                # Mono waveform: add a channel dimension -> (1, samples).
                melody = melody[None]
            # Trim the melody to the requested duration, then resample/remix.
            melody = melody[..., :int(sr * duration)]
            melody = convert_audio(melody, sr, target_sr, target_ac)
            processed_melodies.append(melody)
    # BUG FIX: duration and the sampling kwargs (top_k/top_p/temperature)
    # were previously collected but never applied, so the UI controls had
    # no effect on generation. Forward them to the model here.
    model.set_generation_params(duration=duration, **gen_kwargs)
    try:
        if any(m is not None for m in processed_melodies):
            # melody condition
            outputs = model.generate_with_chroma(
                descriptions=texts,
                melody_wavs=processed_melodies,
                melody_sample_rate=target_sr,
                progress=progress,
                return_tokens=False
            )
        else:
            # text only
            outputs = model.generate(texts, progress=progress, return_tokens=False)
    except RuntimeError as e:
        raise gr.Error("Error while generating " + e.args[0])
    outputs = outputs.detach().cpu().float()
    out_wavs = []
    for output in outputs:
        # delete=False so the file survives past the context manager;
        # Gradio serves it by path afterwards.
        with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
            audio_write(
                file.name, output, model.sample_rate, strategy="loudness",
                loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
            out_wavs.append(file.name)
    print("generation finished", len(texts), time.time() - be)
    return out_wavs
def predict(model_path, text, melody, duration, topk, topp, temperature, target_sr, progress=gr.Progress()):
    """Gradio entry point: load a model and generate one audio clip.

    Args:
        model_path: Model id/path passed to ``load_model`` (whitespace-stripped).
        text: Text prompt describing the desired audio.
        melody: Optional ``(sample_rate, waveform)`` pair from the audio input.
        duration: Generation length in seconds.
        topk: Top-k sampling cutoff (coerced to ``int``; must be >= 0).
        topp: Top-p (nucleus) sampling cutoff (must be >= 0).
        temperature: Sampling temperature (must be >= 0).
        target_sr: Sample rate used when conditioning on the melody.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        Path to the generated ``.wav`` file.

    Raises:
        gr.Error: On invalid sampling parameters or if interrupted.
    """
    global INTERRUPTING
    INTERRUPTING = False
    progress(0, desc="Loading model...")
    model_path = model_path.strip()
    if temperature < 0:
        raise gr.Error("Temperature must be >= 0.")
    if topk < 0:
        raise gr.Error("Topk must be non-negative.")
    if topp < 0:
        raise gr.Error("Topp must be non-negative.")
    topk = int(topk)
    model = load_model(model_path)

    max_generated = 0

    def _progress(generated, to_generate):
        # Progress can be reported out of order; track the max seen so the
        # bar never moves backwards, and clamp to the total.
        nonlocal max_generated
        max_generated = max(generated, max_generated)
        progress((min(max_generated, to_generate), to_generate))
        if INTERRUPTING:
            raise gr.Error("Interrupted.")
    model.set_custom_progress_callback(_progress)

    wavs = _do_predictions(
        model,
        [text],
        [melody],
        duration,
        progress=True,
        target_ac=1,
        target_sr=target_sr,
        top_k=topk,
        top_p=topp,
        temperature=temperature,
        gradio_progress=progress)
    return wavs[0]
def transcribe(audio_path):
    """Transcribe an audio file to MIDI with basic_pitch and offer it for download.

    Args:
        audio_path: Path to the audio file to transcribe.

    Returns:
        A visible ``gr.DownloadButton`` pointing at the written ``.mid`` file.

    Raises:
        Exception: Re-raised if writing the MIDI file fails.
    """
    model_output, midi_data, note_events = basic_pitch.inference.predict(
        audio_path=audio_path,
        model_or_model_path=ICASSP_2022_MODEL_PATH,
    )
    # delete=False so the file outlives the context manager and Gradio can
    # serve it for download.
    with NamedTemporaryFile("wb", suffix=".mid", delete=False) as file:
        try:
            midi_data.write(file)
            print(f"midi file saved to {file.name}")
        except Exception as e:
            print(f"Error while writing midi file: {e}")
            # Bare raise preserves the original traceback (was `raise e`).
            raise
    # NOTE: removed a stray trailing `|` character (copy/paste artifact)
    # that made this return statement a syntax error.
    return gr.DownloadButton(
        value=file.name,
        label=f"Download MIDI file {file.name}",
        visible=True)