File size: 4,550 Bytes
c318a73
 
 
 
 
 
 
 
 
 
97a428f
4336e0a
 
 
c318a73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97a428f
c318a73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97a428f
 
 
 
 
 
c318a73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e1d74a
c318a73
 
 
 
 
 
 
 
 
 
4336e0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import time
import torch

from audiocraft.data.audio_utils import convert_audio
from audiocraft.data.audio import audio_write
import gradio as gr
from audiocraft.models import MusicGen

from tempfile import NamedTemporaryFile
from pathlib import Path
from transformers import AutoModelForSeq2SeqLM
import basic_pitch
import basic_pitch.inference
from basic_pitch import ICASSP_2022_MODEL_PATH


def load_model(version='facebook/musicgen-melody'):
    """Fetch the pretrained MusicGen model identified by ``version``.

    Args:
        version: Identifier handed unchanged to ``MusicGen.get_pretrained``.

    Returns:
        Whatever ``MusicGen.get_pretrained`` returns for that identifier.
    """
    pretrained = MusicGen.get_pretrained(version)
    return pretrained


def _do_predictions(model, texts, melodies, duration, progress=False, gradio_progress=None, target_sr=32000, target_ac=1, **gen_kwargs):
    """Generate one clip per text prompt and save each clip to a temporary WAV file.

    Args:
        model: Loaded generation model. This function uses its ``device`` and
            ``sample_rate`` attributes and its ``set_generation_params``,
            ``generate`` and ``generate_with_chroma`` methods.
        texts: Text descriptions, one per clip to generate.
        melodies: One entry per text: either ``None`` or a
            ``(sample_rate, numpy_array)`` pair.
        duration: Length in seconds. Melodies are trimmed to this length and it
            is also passed to the model as the requested generation length.
        progress: Forwarded to the model's generate call.
        gradio_progress: Accepted for caller compatibility; not used here.
        target_sr: Sample rate the melody is converted to before conditioning.
        target_ac: Channel count the melody is converted to.
        **gen_kwargs: Sampling options (e.g. ``top_k``, ``top_p``,
            ``temperature``) forwarded to ``model.set_generation_params``.

    Returns:
        List of paths of the WAV files written, in the order of the model outputs.

    Raises:
        gr.Error: If the model raises ``RuntimeError`` during generation.
    """
    print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
    be = time.time()

    # Apply the requested length and sampling options; without this call the
    # caller's duration/top_k/top_p/temperature would be silently ignored.
    model.set_generation_params(duration=duration, **gen_kwargs)

    processed_melodies = []
    for melody in melodies:
        if melody is None:
            processed_melodies.append(None)
            continue
        sr, melody = melody[0], torch.from_numpy(melody[1]).to(model.device).float().t()
        print(f"Input audio sample rate is {sr}")
        if melody.dim() == 1:
            # Add a leading axis so the tensor is 2-D before slicing/conversion.
            melody = melody[None]
        # Keep at most `duration` seconds of the conditioning audio.
        melody = melody[..., :int(sr * duration)]
        melody = convert_audio(melody, sr, target_sr, target_ac)
        processed_melodies.append(melody)

    try:
        if any(m is not None for m in processed_melodies):
            # melody condition
            outputs = model.generate_with_chroma(
                descriptions=texts,
                melody_wavs=processed_melodies,
                melody_sample_rate=target_sr,
                progress=progress,
                return_tokens=False
            )
        else:
            # text only
            outputs = model.generate(texts, progress=progress, return_tokens=False)
    except RuntimeError as e:
        # str(e) is safe even when e.args is empty or its first item is not a string.
        raise gr.Error("Error while generating " + str(e)) from e
    outputs = outputs.detach().cpu().float()

    out_wavs = []
    for output in outputs:
        # delete=False: the file has to outlive this function so the caller can
        # serve it. Only the name is reserved here; the audio is written after
        # the handle is closed so no second open handle exists on the path.
        with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
            wav_path = file.name
        audio_write(
            wav_path, output, model.sample_rate, strategy="loudness",
            loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
        out_wavs.append(wav_path)
    print("generation finished", len(texts), time.time() - be)
    return out_wavs


def predict(model_path, text, melody, duration, topk, topp, temperature, target_sr, progress=gr.Progress()):
    """Validate the request, load the model and generate a single clip.

    Args:
        model_path: Model identifier passed to ``load_model``. Surrounding
            whitespace is stripped; an empty or ``None`` value selects
            ``load_model``'s default.
        text: Text description of the clip to generate.
        melody: ``None`` or a ``(sample_rate, numpy_array)`` pair used as
            melody conditioning.
        duration: Length in seconds, forwarded to ``_do_predictions``.
        topk: Top-k sampling value; must be non-negative, converted to ``int``.
        topp: Top-p sampling value; must be non-negative.
        temperature: Sampling temperature; must be non-negative.
        target_sr: Sample rate the melody is converted to.
        progress: Progress tracker; called with ``(done, total)`` tuples as
            generation advances.

    Returns:
        Path of the generated WAV file.

    Raises:
        gr.Error: If a sampling parameter is negative, or if the module-level
            ``INTERRUPTING`` flag is set while generation is in progress.
    """
    global INTERRUPTING
    # Reset the flag at the start of every request; nothing in this function
    # sets it to True, so that has to come from elsewhere in the module.
    INTERRUPTING = False
    progress(0, desc="Loading model...")
    model_path = (model_path or "").strip()
    # NOTE: model_path is not checked against the local filesystem; it is
    # handed to load_model as-is.
    if temperature < 0:
        raise gr.Error("Temperature must be >= 0.")
    if topk < 0:
        raise gr.Error("Topk must be non-negative.")
    if topp < 0:
        raise gr.Error("Topp must be non-negative.")

    topk = int(topk)
    # An empty path cannot name a model, so fall back to load_model's default.
    model = load_model(model_path) if model_path else load_model()

    max_generated = 0

    def _progress(generated, to_generate):
        # Report the highest step count seen so far, capped at the total, so
        # the displayed progress never moves backwards.
        nonlocal max_generated
        max_generated = max(generated, max_generated)
        progress((min(max_generated, to_generate), to_generate))
        if INTERRUPTING:
            raise gr.Error("Interrupted.")
    model.set_custom_progress_callback(_progress)

    wavs = _do_predictions(
        model,
        [text],
        [melody],
        duration,
        progress=True,
        target_ac=1,
        target_sr=target_sr,
        top_k=topk,
        top_p=topp,
        temperature=temperature,
        gradio_progress=progress)
    return wavs[0]


def transcribe(audio_path):
    """Transcribe an audio file to MIDI and return a download button for it.

    Runs ``basic_pitch.inference.predict`` on ``audio_path`` with the
    ICASSP 2022 model and writes the resulting MIDI data to a temporary
    ``.mid`` file.

    Args:
        audio_path: Path of the audio file to transcribe.

    Returns:
        A visible ``gr.DownloadButton`` whose value is the path of the MIDI file.

    Raises:
        Exception: Whatever the MIDI writer raises; the partially written
            temporary file is removed before the error propagates.
    """
    import os

    # Only the MIDI object is needed; the other two results are ignored.
    _model_output, midi_data, _note_events = basic_pitch.inference.predict(
        audio_path=audio_path,
        model_or_model_path=ICASSP_2022_MODEL_PATH,
        )

    # delete=False: the file must outlive this function so it can be downloaded.
    with NamedTemporaryFile("wb", suffix=".mid", delete=False) as file:
        midi_path = file.name
        try:
            midi_data.write(file)
        except Exception as e:
            print(f"Error while writing midi file: {e}")
            # Do not leave an orphaned temp file behind on failure.
            file.close()
            try:
                os.remove(midi_path)
            except OSError:
                # Best-effort cleanup; the original write error matters more.
                pass
            raise
        print(f"midi file saved to {midi_path}")

    return gr.DownloadButton(
        value=midi_path,
        label=f"Download MIDI file {midi_path}",
        visible=True)