Spaces:
Runtime error
Runtime error
File size: 5,542 Bytes
9118de8 175ee87 9118de8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from constants import INSTRUMENT_CLASSES
from playback import get_music, show_piano_roll
# matplotlib settings
# "Agg" is a non-interactive backend: figures are rendered off-screen (for server)
matplotlib.use("Agg")
# Hide major tick marks, make axes backgrounds transparent and draw grey spines.
matplotlib.rcParams.update(
    {
        "xtick.major.size": 0,
        "ytick.major.size": 0,
        "axes.facecolor": "none",
        "axes.edgecolor": "grey",
    }
)
def define_generation_dir(model_repo_path):
    """Return the output directory for a model's generated sequences, creating it if needed.

    Args:
        model_repo_path: model repository path/identifier, used as a sub-path
            of ``midi/generated``.

    Returns:
        The path ``midi/generated/<model_repo_path>`` (relative to the current
        working directory). The directory exists when this function returns.
    """
    #### to remove later ####
    # Temporary hard-coded remap of one local model path to a hub repo id.
    if model_repo_path == "models/model_2048_fake_wholedataset":
        model_repo_path = "misnaej/the-jam-machine-wdtef6l"
    #### to remove later ####
    generated_sequence_files_path = f"midi/generated/{model_repo_path}"
    # exist_ok=True replaces the racy "check exists, then create" sequence:
    # a concurrent creation between the two calls no longer raises.
    os.makedirs(generated_sequence_files_path, exist_ok=True)
    return generated_sequence_files_path
def bar_count_check(sequence, n_bars):
    """Check whether ``sequence`` contains exactly ``n_bars`` bars.

    Bars are counted as occurrences of the ``BAR_END`` token. ``BAR_START`` is
    not counted because it is not always included in ``sequence`` (e.g. it is
    part of the prompt when generating one more bar).

    Args:
        sequence: token string whose tokens are separated by whitespace.
        n_bars: expected number of bars.

    Returns:
        ``(bar_count_matches, bar_count)``: whether the count equals
        ``n_bars``, and the number of ``BAR_END`` tokens found.
    """
    # split() with no argument splits on any run of whitespace, so a BAR_END
    # next to a newline or a double space is still seen as its own token.
    bar_count = sequence.split().count("BAR_END")
    bar_count_matches = bar_count == n_bars
    if not bar_count_matches:
        print(f"Bar count is {bar_count} - but should be {n_bars}")
    return bar_count_matches, bar_count
def print_inst_classes(INSTRUMENT_CLASSES):
    """Print every entry of the given instrument classes, one per line.

    Note: the parameter shadows the module-level ``INSTRUMENT_CLASSES``
    import. Whatever iterable is passed in is what gets printed.
    """
    for instrument_class in INSTRUMENT_CLASSES:
        line = f"{instrument_class}"
        print(line)
def check_if_prompt_inst_in_tokenizer_vocab(tokenizer, inst_prompt_list):
    """Ensure every prompt instrument has an ``INST=<inst>`` token in the vocabulary.

    Args:
        tokenizer: object exposing ``vocab``, a collection/mapping of token
            strings that supports ``in`` and iteration.
        inst_prompt_list: instrument identifiers requested in the prompt.

    Raises:
        ValueError: for the first instrument whose ``INST=<inst>`` token is
            missing. Before raising, the module's instrument classes are
            printed. The message lists the instruments present in the vocabulary.
    """
    # Read the vocabulary once. On some tokenizers ``vocab`` is a computed
    # property (Hugging Face fast tokenizers rebuild the dict on each access).
    vocab = tokenizer.vocab
    for inst in inst_prompt_list:
        if f"INST={inst}" not in vocab:
            # Only real instrument tokens ("INST=..."), not any token that
            # merely contains the substring "INST".
            instruments_in_dataset = np.sort(
                [tok.split("=")[-1] for tok in vocab if tok.startswith("INST=")]
            )
            print_inst_classes(INSTRUMENT_CLASSES)
            raise ValueError(
                f"""The instrument {inst} is not in the tokenizer vocabulary.
Available Instruments: {instruments_in_dataset}"""
            )
def forcing_bar_count(input_prompt, generated, bar_count, expected_length):
    """Force the generated sequence to have the expected number of bars.

    ``expected_length`` and ``bar_count`` refer to the newly generated part
    only (without ``input_prompt``).

    - Too many bars: keep the first ``expected_length`` bars of ``generated``,
      append ``TRACK_END`` and prepend ``input_prompt``.
    - Too few bars: return ``input_prompt + generated`` and flag it as failing
      so the caller can regenerate.
    - Exact count: return ``input_prompt + generated`` unchanged.

    Returns:
        ``(full_piece, bar_count_checks)``. ``bar_count_checks`` is False only
        when the generated sequence is too short.
    """
    if bar_count > expected_length:  # Cut the sequence if too long
        kept_bars = generated.split("BAR_END ")[:expected_length]
        full_piece = "".join(f"{bar}BAR_END " for bar in kept_bars) + "TRACK_END "
        full_piece = input_prompt + full_piece
        print(f"Generated sequence trunkated at {expected_length} bars")
        bar_count_checks = True
    elif bar_count < expected_length:  # Do nothing if the sequence is too short
        full_piece = input_prompt + generated
        bar_count_checks = False
        print("--- Generated sequence is too short - Force Regeration ---")
    else:
        # Counts already match: keep the sequence as is. Without this branch
        # full_piece would be unbound at the return below.
        full_piece = input_prompt + generated
        bar_count_checks = True
    return full_piece, bar_count_checks
def get_max_time(inst_midi):
    """Return the latest ``get_end_time()`` over ``inst_midi.instruments``.

    Returns 0 when there are no instruments (or when no end time exceeds 0).
    """
    end_times = [instrument.get_end_time() for instrument in inst_midi.instruments]
    # Seed with 0 so an empty instrument list yields 0 instead of raising.
    return max([0, *end_times])
def plot_piano_roll(inst_midi):
    """Draw a piano roll with one stacked subplot per instrument.

    Each note is drawn as a thick horizontal segment from ``note.start`` to
    ``note.end`` at height ``note.pitch``. Grey vertical lines mark every
    ``beats_per_bar``-th beat (bar lines). Grey horizontal lines mark every
    12th MIDI pitch. Tick labels are hidden. The instrument name is written
    in the top-left corner of its subplot.

    Args:
        inst_midi: MIDI object exposing ``instruments`` (each with ``name``
            and ``notes`` carrying ``start``/``end``/``pitch``) and
            ``get_beats()``. Presumably a ``pretty_midi.PrettyMIDI``;
            confirm against callers.

    Returns:
        The matplotlib figure holding the piano roll.
    """
    instruments = inst_midi.instruments
    n_rows = len(instruments)
    # max(..., 1) keeps the figure height positive when there are no instruments.
    piano_roll_fig = plt.figure(figsize=(25, 3 * max(n_rows, 1)))
    piano_roll_fig.tight_layout()
    piano_roll_fig.patch.set_alpha(0)  # transparent figure background
    beats_per_bar = 4
    sec_per_beat = 0.5
    # hardcoded for now; every other instrument is drawn in green
    colors = {"Drums": "purple", "Synth Bass 1": "orange"}

    beats = inst_midi.get_beats()
    # Append one extrapolated beat so the last bar also gets a closing line.
    # np.diff needs at least two beats; fall back to sec_per_beat otherwise.
    beat_step = np.diff(beats)[0] if len(beats) > 1 else sec_per_beat
    next_beat = max(beats) + beat_step
    # NOTE(review): astype(int) truncates bar times to whole seconds, so bar
    # lines only sit on real bar boundaries when a bar lasts a whole number of
    # seconds (e.g. 4 beats at 0.5 s/beat) - confirm this is intended.
    bars_time = np.append(beats, next_beat)[::beats_per_bar].astype(int)
    xticks = bars_time[:-1]
    octaves = np.arange(0, 128, 12)

    for inst_count, inst in enumerate(instruments, start=1):
        color = colors.get(inst.name, "green")
        plt.subplot(n_rows, 1, inst_count)
        for bar in bars_time:
            plt.axvline(bar, color="grey", linewidth=0.5)
        for octave in octaves:
            plt.axhline(octave, color="grey", linewidth=0.5)
        plt.yticks(octaves, visible=False)

        notes = inst.notes
        if len(notes) > 0:
            # One (start, end) / (pitch, pitch) pair per note: plotting the
            # transposed arrays draws one horizontal segment per note.
            note_time = np.array([[note.start, note.end] for note in notes])
            note_pitch = np.array([[note.pitch, note.pitch] for note in notes])
            plt.plot(
                note_time.T,
                note_pitch.T,
                color=color,
                linewidth=4,
                solid_capstyle="butt",
            )
            y_low = max([note_pitch.min() - 5, 0])
            y_high = note_pitch.max() + 5
        else:
            # No notes: .min()/.max() would raise on an empty array, so show
            # the full MIDI pitch range instead.
            y_low, y_high = 0, 128

        plt.tight_layout()
        plt.xlim(min(bars_time), max(bars_time))
        plt.ylim(y_low, y_high)
        # Ticks centred in each bar (half a bar of 0.5 s beats), labelled 1..n; hidden.
        plt.xticks(
            xticks + 0.5 * beats_per_bar * sec_per_beat,
            labels=xticks.argsort() + 1,
            visible=False,
        )
        plt.text(
            0.2,
            y_high - 1,  # 4 pitches above the highest note when notes exist
            inst.name,
            fontsize=20,
            color=color,
            horizontalalignment="left",
            verticalalignment="top",
        )
    return piano_roll_fig
|