suric's picture
add transcribe button
4336e0a
raw
history blame
4.55 kB
import time
import torch
from audiocraft.data.audio_utils import convert_audio
from audiocraft.data.audio import audio_write
import gradio as gr
from audiocraft.models import MusicGen
from tempfile import NamedTemporaryFile
from pathlib import Path
from transformers import AutoModelForSeq2SeqLM
import basic_pitch
import basic_pitch.inference
from basic_pitch import ICASSP_2022_MODEL_PATH
def load_model(version='facebook/musicgen-melody'):
    """Return the pretrained MusicGen model identified by ``version``.

    Args:
        version: Identifier handed unchanged to ``MusicGen.get_pretrained``.

    Returns:
        Whatever ``MusicGen.get_pretrained`` returns for that identifier.
    """
    pretrained = MusicGen.get_pretrained(version)
    return pretrained
def _do_predictions(model, texts, melodies, duration, progress=False, gradio_progress=None, target_sr=32000, target_ac=1, **gen_kwargs):
    """Generate one clip per prompt and return the paths of the written WAV files.

    Args:
        model: Model object exposing ``device``, ``sample_rate``,
            ``set_generation_params``, ``generate`` and ``generate_with_chroma``.
        texts: List of text descriptions, one per clip.
        melodies: List parallel to ``texts``; each entry is ``None`` or a
            ``(sample_rate, numpy_array)`` pair used as melody conditioning.
        duration: Requested clip length in seconds; also used to truncate
            each melody before conversion.
        progress: Forwarded to the model's generate calls.
        gradio_progress: Accepted for interface compatibility; not used here.
        target_sr: Sample rate the melodies are converted to.
        target_ac: Channel count the melodies are converted to.
        **gen_kwargs: Extra generation settings (e.g. ``top_k``, ``top_p``,
            ``temperature``) forwarded to ``model.set_generation_params``.

    Returns:
        List of file paths, one temporary ``.wav`` file per generated clip.

    Raises:
        gr.Error: If the model raises ``RuntimeError`` during generation.
    """
    # Without this call the requested duration and sampling options are
    # silently ignored and the model runs with whatever it was last set to.
    model.set_generation_params(duration=duration, **gen_kwargs)

    print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
    be = time.time()

    processed_melodies = []
    for melody in melodies:
        if melody is None:
            processed_melodies.append(None)
            continue
        # The array is transposed so the last axis can be sliced by sample
        # count; presumably the input is (samples, channels) - TODO confirm.
        sr, melody = melody[0], torch.from_numpy(melody[1]).to(model.device).float().t()
        print(f"Input audio sample rate is {sr}")
        if melody.dim() == 1:
            # Mono input arrives 1-D; add a leading channel axis.
            melody = melody[None]
        # Keep at most `duration` seconds of the conditioning melody.
        melody = melody[..., :int(sr * duration)]
        melody = convert_audio(melody, sr, target_sr, target_ac)
        processed_melodies.append(melody)

    try:
        if any(m is not None for m in processed_melodies):
            # melody condition
            outputs = model.generate_with_chroma(
                descriptions=texts,
                melody_wavs=processed_melodies,
                melody_sample_rate=target_sr,
                progress=progress,
                return_tokens=False
            )
        else:
            # text only
            outputs = model.generate(texts, progress=progress, return_tokens=False)
    except RuntimeError as e:
        # str(e) is safe even when e.args is empty or holds a non-string.
        raise gr.Error(f"Error while generating {e}") from e

    outputs = outputs.detach().cpu().float()
    out_wavs = []
    for output in outputs:
        # delete=False: the file must outlive this function so the caller
        # can serve it.
        with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
            audio_write(
                file.name, output, model.sample_rate, strategy="loudness",
                loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
            out_wavs.append(file.name)
    print("generation finished", len(texts), time.time() - be)
    return out_wavs
def predict(model_path, text, melody, duration, topk, topp, temperature, target_sr, progress=gr.Progress()):
    """Validate the UI inputs, load the model and generate a single clip.

    Args:
        model_path: Model identifier passed to ``load_model``; a blank or
            missing value falls back to ``load_model``'s default.
        text: Text description of the clip to generate.
        melody: ``None`` or a ``(sample_rate, numpy_array)`` pair used as
            melody conditioning.
        duration: Requested clip length in seconds.
        topk: Top-k sampling value; converted to ``int`` before use.
        topp: Top-p sampling value.
        temperature: Sampling temperature.
        target_sr: Sample rate the melody is converted to.
        progress: Gradio progress tracker used to report generation progress.

    Returns:
        Path of the generated ``.wav`` file.

    Raises:
        gr.Error: If ``temperature``, ``topk`` or ``topp`` is negative, or if
            the module-level ``INTERRUPTING`` flag is set during generation.
    """
    global INTERRUPTING
    INTERRUPTING = False
    progress(0, desc="Loading model...")
    model_path = (model_path or "").strip()

    if temperature < 0:
        raise gr.Error("Temperature must be >= 0.")
    if topk < 0:
        raise gr.Error("Topk must be non-negative.")
    if topp < 0:
        raise gr.Error("Topp must be non-negative.")

    topk = int(topk)
    # An empty path would otherwise be passed through as '' and bypass
    # load_model's default checkpoint.
    model = load_model(model_path) if model_path else load_model()

    max_generated = 0

    def _progress(generated, to_generate):
        nonlocal max_generated
        # Report the high-water mark so the progress value never goes backwards.
        max_generated = max(generated, max_generated)
        progress((min(max_generated, to_generate), to_generate))
        if INTERRUPTING:
            raise gr.Error("Interrupted.")

    model.set_custom_progress_callback(_progress)

    wavs = _do_predictions(
        model,
        [text],
        [melody],
        duration,
        progress=True,
        target_ac=1,
        target_sr=target_sr,
        top_k=topk,
        top_p=topp,
        temperature=temperature,
        gradio_progress=progress)
    return wavs[0]
def transcribe(audio_path):
    """Transcribe an audio file to MIDI and return a button to download it.

    Runs ``basic_pitch.inference.predict`` with ``ICASSP_2022_MODEL_PATH`` on
    ``audio_path`` and writes the resulting MIDI data to a temporary ``.mid``
    file.

    Args:
        audio_path: Path of the audio file to transcribe.

    Returns:
        A visible ``gr.DownloadButton`` whose value is the path of the
        temporary MIDI file.

    Raises:
        gr.Error: If ``audio_path`` is empty or ``None``.
        Exception: Any error raised while writing the MIDI file is logged
            and re-raised unchanged.
    """
    # Fail with a readable message instead of an opaque library error when
    # there is nothing to transcribe yet.
    if not audio_path:
        raise gr.Error("No audio to transcribe. Generate or upload audio first.")

    _model_output, midi_data, _note_events = basic_pitch.inference.predict(
        audio_path=audio_path,
        model_or_model_path=ICASSP_2022_MODEL_PATH,
    )

    # delete=False: the file must still exist after this function returns so
    # the download button can serve it.
    with NamedTemporaryFile("wb", suffix=".mid", delete=False) as file:
        try:
            midi_data.write(file)
        except Exception as e:
            print(f"Error while writing midi file: {e}")
            raise
        print(f"midi file saved to {file.name}")

    return gr.DownloadButton(
        value=file.name,
        label=f"Download MIDI file {file.name}",
        visible=True)