File size: 6,279 Bytes
9118de8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
import matplotlib.pyplot as plt
import gradio as gr
from load import LoadModel
from generate import GenerateMidiText
from constants import INSTRUMENT_CLASSES, INSTRUMENT_TRANSFER_CLASSES
from decoder import TextDecoder
from utils import get_miditok, index_has_substring
from playback import get_music
from matplotlib import pylab
import sys
import matplotlib
from generation_utils import plot_piano_roll
import numpy as np

matplotlib.use("Agg")

sys.modules["pylab"] = pylab

model_repo = "JammyMachina/elec-gmusic-familized-model-13-12__17-35-53"
n_bar_generated = 8
# model_repo = "JammyMachina/improved_4bars-mdl"
# n_bar_generated = 4

model, tokenizer = LoadModel(
    model_repo,
    from_huggingface=True,
).load_model_and_tokenizer()

miditok = get_miditok()
decoder = TextDecoder(miditok)


def define_prompt(state, genesis):
    if len(state) == 0:
        input_prompt = "PIECE_START "
    else:
        input_prompt = genesis.get_whole_piece_from_bar_dict()
    return input_prompt


def generator(
    label,
    regenerate,
    temp,
    density,
    instrument,
    state,
    piece_by_track,
    add_bars=False,
    add_bar_count=1,
):

    genesis = GenerateMidiText(model, tokenizer, piece_by_track)
    track = {"label": label}
    inst = next(
        (
            inst
            for inst in INSTRUMENT_TRANSFER_CLASSES
            if inst["transfer_to"] == instrument
        ),
        {"family_number": "DRUMS"},
    )["family_number"]

    inst_index = -1  # default to last generated
    if state != []:
        for index, instrum in enumerate(state):
            if instrum["label"] == track["label"]:
                inst_index = index  # changing if exists

    # Generate
    if not add_bars:
        # Regenerate
        if regenerate:
            state.pop(inst_index)
            genesis.delete_one_track(inst_index)

            generated_text = (
                genesis.get_whole_piece_from_bar_dict()
            )  # maybe not useful here
            inst_index = -1  # reset to last generated

        # NEW TRACK
        input_prompt = define_prompt(state, genesis)
        generated_text = genesis.generate_one_new_track(
            inst, density, temp, input_prompt=input_prompt
        )

        regenerate = True  # set generate to true
    else:
        # NEW BARS
        genesis.generate_n_more_bars(add_bar_count)  # for all instruments
        generated_text = genesis.get_whole_piece_from_bar_dict()

    decoder.get_midi(generated_text, "mixed.mid")
    mixed_inst_midi, mixed_audio = get_music("mixed.mid")

    inst_text = genesis.get_selected_track_as_text(inst_index)
    inst_midi_name = f"{instrument}.mid"
    decoder.get_midi(inst_text, inst_midi_name)
    _, inst_audio = get_music(inst_midi_name)
    piano_roll = plot_piano_roll(mixed_inst_midi)
    track["text"] = inst_text
    state.append(track)

    return (
        inst_text,
        (44100, inst_audio),
        piano_roll,
        state,
        (44100, mixed_audio),
        regenerate,
        genesis.piece_by_track,
    )


def instrument_row(default_inst, row_id):
    with gr.Row():
        row = gr.Variable(row_id)
        with gr.Column(scale=1, min_width=100):
            inst = gr.Dropdown(
                sorted([inst["transfer_to"] for inst in INSTRUMENT_TRANSFER_CLASSES])
                + ["Drums"],
                value=default_inst,
                label="Instrument",
            )
            temp = gr.Dropdown(
                [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1],
                value=0.7,
                label="Creativity",
            )
            density = gr.Dropdown([1, 2, 3], value=3, label="Note Density")

        with gr.Column(scale=3):
            output_txt = gr.Textbox(
                label="output", lines=10, max_lines=10, show_label=False
            )
        with gr.Column(scale=1, min_width=100):
            inst_audio = gr.Audio(label="TRACK Audio", show_label=True)
            regenerate = gr.Checkbox(value=False, label="Regenerate", visible=False)
            # add_bars = gr.Checkbox(value=False, label="Add Bars")
            # add_bar_count = gr.Dropdown([1, 2, 4, 8], value=1, label="Add Bars")
            gen_btn = gr.Button("Generate")
            gen_btn.click(
                fn=generator,
                inputs=[row, regenerate, temp, density, inst, state, piece_by_track],
                outputs=[
                    output_txt,
                    inst_audio,
                    piano_roll,
                    state,
                    mixed_audio,
                    regenerate,
                    piece_by_track,
                ],
            )


with gr.Blocks() as demo:
    piece_by_track = gr.State([])
    state = gr.State([])
    title = gr.Markdown(
        """ # Demo-App of The-Jam-Machine
    A Generative AI trained on text transcription of MIDI music """
    )
    track1_md = gr.Markdown(""" ## Mixed Audio and Piano Roll """)
    mixed_audio = gr.Audio(label="Mixed Audio")
    piano_roll = gr.Plot(label="Piano Roll", show_label=False)
    description = gr.Markdown(
        """
        For each **TRACK**, choose your **instrument** along with **creativity** (temperature) and **note density**. Then, hit the **Generate** Button!
        You can have a look at the generated text; but most importantly, check the **piano roll** and listen to the TRACK audio!
        If you don't like the track, hit the generate button to regenerate it! Generate more tracks and listen to the **mixed audio**!
        """
    )
    track1_md = gr.Markdown(""" ## TRACK 1 """)
    instrument_row("Drums", 0)
    track1_md = gr.Markdown(""" ## TRACK 2 """)
    instrument_row("Synth Bass 1", 1)
    track1_md = gr.Markdown(""" ## TRACK 2 """)
    instrument_row("Synth Lead Square", 2)
    # instrument_row("Piano")

demo.launch(debug=True)

"""
TODO: reset button
TODO: add a button to save the generated midi
TODO: add improvise button
TODO: set values for temperature as it is done for density
TODO: Add bar should be now set for the whole piece - regenerrate should regenerate the added bars only on all instruments
TODO: row height to fix

TODO: reset state of tick boxes after used maybe (regenerate, add bars) ; 
TODO: block regenerate if add bar on
"""