skytnt committed on
Commit cfa70b3
1 Parent(s): 23810d5

add asigalov61/Music-Llama tv2o-large model

Files changed (3)
  1. app.py +57 -23
  2. javascript/app.js +29 -7
  3. midi_tokenizer.py +695 -33
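
The commit keys each model entry to a tokenizer config name such as "tv2o-large": a "t" prefix, the tokenizer version ("v1" or "v2"), an optional trailing "o" for the optimise-MIDI mode, and a size suffix that the tokenizer itself ignores. The sketch below mirrors the get_tokenizer helper added to app.py in this commit; the standalone function name parse_tokenizer_config and the usage at the bottom are illustrative, and it assumes the midi_tokenizer.py from this commit is importable.

    from midi_tokenizer import MIDITokenizer

    def parse_tokenizer_config(config_name):
        # "tv2o-large" -> version "v2", optimise_midi True (the size part is not used here)
        tv, _size = config_name.split("-")
        tv = tv[1:]                        # drop the leading "t"
        optimise = tv.endswith("o")
        if optimise:
            tv = tv[:-1]
        if tv not in ("v1", "v2"):
            raise ValueError(f"Unknown tokenizer version {tv}")
        tokenizer = MIDITokenizer(tv)      # factory: returns MIDITokenizerV1 or MIDITokenizerV2
        tokenizer.set_optimise_midi(optimise)
        return tokenizer

    tok = parse_tokenizer_config("tv2o-large")   # config used for asigalov61/Music-Llama
    print(tok.version, tok.optimise_midi)        # v2 True
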
app.py CHANGED
@@ -1,14 +1,13 @@
1
  import argparse
2
  import glob
 
3
  import os.path
4
  import time
5
- import uuid
6
 
7
  import gradio as gr
8
  import numpy as np
9
  import onnxruntime as rt
10
  import tqdm
11
- import json
12
  from huggingface_hub import hf_hub_download
13
 
14
  import MIDI
@@ -47,6 +46,7 @@ def sample_top_p_k(probs, p, k, generator=None):
47
 
48
  def generate(model, prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
49
  disable_patch_change=False, disable_control_change=False, disable_channels=None, generator=None):
 
50
  if disable_channels is not None:
51
  disable_channels = [tokenizer.parameter_ids["channel"][c] for c in disable_channels]
52
  else:
@@ -121,10 +121,11 @@ def send_msgs(msgs):
121
  return json.dumps(msgs)
122
 
123
 
124
- def run(model_name, tab, instruments, drum_kit, bpm, mid, midi_events,
125
  reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels, seed, seed_rand,
126
  gen_events, temp, top_p, top_k, allow_cc):
127
- mid_seq = []
 
128
  bpm = int(bpm)
129
  gen_events = int(gen_events)
130
  max_len = gen_events
@@ -137,7 +138,7 @@ def run(model_name, tab, instruments, drum_kit, bpm, mid, midi_events,
137
  i = 0
138
  mid = [[tokenizer.bos_id] + [tokenizer.pad_id] * (tokenizer.max_token_seq - 1)]
139
  if bpm != 0:
140
- mid.append(tokenizer.event2tokens(["set_tempo",0,0,0, bpm]))
141
  patches = {}
142
  if instruments is None:
143
  instruments = []
@@ -153,7 +154,7 @@ def run(model_name, tab, instruments, drum_kit, bpm, mid, midi_events,
153
  if len(instruments) > 0:
154
  disable_patch_change = True
155
  disable_channels = [i for i in range(16) if i not in patches]
156
- elif mid is not None:
157
  eps = 4 if reduce_cc_st else 0
158
  mid = tokenizer.tokenize(MIDI.midi2score(mid), cc_eps=eps, tempo_eps=eps,
159
  remap_track_channel=remap_track_channel,
@@ -161,14 +162,26 @@ def run(model_name, tab, instruments, drum_kit, bpm, mid, midi_events,
161
  remove_empty_channels=remove_empty_channels)
162
  mid = np.asarray(mid, dtype=np.int64)
163
  mid = mid[:int(midi_events)]
 
164
  for token_seq in mid:
165
  mid_seq.append(token_seq.tolist())
166
- max_len += len(mid)
167
  events = [tokenizer.tokens2event(tokens) for tokens in mid_seq]
168
- init_msgs = [create_msg("visualizer_clear", None), create_msg("visualizer_append", events)]
169
- t = time.time() + 1
170
  yield mid_seq, None, None, seed, send_msgs(init_msgs)
171
- model = models[model_name]
172
  midi_generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
173
  disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
174
  disable_channels=disable_channels, generator=generator)
@@ -191,9 +204,10 @@ def run(model_name, tab, instruments, drum_kit, bpm, mid, midi_events,
191
  yield mid_seq, "output.mid", (44100, audio), seed, send_msgs([create_msg("visualizer_end", events)])
192
 
193
 
194
- def cancel_run(mid_seq):
195
  if mid_seq is None:
196
  return None, None, []
 
197
  mid = tokenizer.detokenize(mid_seq)
198
  with open(f"output.mid", 'wb') as f:
199
  f.write(MIDI.score2midi(mid))
@@ -233,6 +247,20 @@ def hf_hub_download_retry(repo_id, filename):
233
  if err:
234
  raise err
235
236
  number2drum_kits = {-1: "None", 0: "Standard", 8: "Room", 16: "Power", 24: "Electric", 25: "TR-808", 32: "Jazz",
237
  40: "Blush", 48: "Orchestra"}
238
  patch2number = {v: k for k, v in MIDI.Number2patch.items()}
@@ -245,19 +273,20 @@ if __name__ == "__main__":
245
  parser.add_argument("--max-gen", type=int, default=1024, help="max")
246
  opt = parser.parse_args()
247
  soundfont_path = hf_hub_download_retry(repo_id="skytnt/midi-model", filename="soundfont.sf2")
248
- models_info = {"generic pretrain model": ["skytnt/midi-model", ""],
249
- "j-pop finetune model": ["skytnt/midi-model-ft", "jpop/"],
250
- "touhou finetune model": ["skytnt/midi-model-ft", "touhou/"],
 
251
  }
252
  models = {}
253
- tokenizer = MIDITokenizer()
254
  providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
255
- for name, (repo_id, path) in models_info.items():
256
  model_base_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}onnx/model_base.onnx")
257
  model_token_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}onnx/model_token.onnx")
258
  model_base = rt.InferenceSession(model_base_path, providers=providers)
259
  model_token = rt.InferenceSession(model_token_path, providers=providers)
260
- models[name] = [model_base, model_token]
 
261
 
262
  load_javascript()
263
  app = gr.Blocks()
@@ -316,9 +345,12 @@ if __name__ == "__main__":
316
  input_remove_empty_channels = gr.Checkbox(label="remove channels without notes", value=False)
317
  example2 = gr.Examples([[file, 128] for file in glob.glob("example/*.mid")],
318
  [input_midi, input_midi_events])
 
 
319
 
320
  tab1.select(lambda: 0, None, tab_select, queue=False)
321
  tab2.select(lambda: 1, None, tab_select, queue=False)
 
322
  input_seed = gr.Slider(label="seed", minimum=0, maximum=2 ** 31 - 1,
323
  step=1, value=0)
324
  input_seed_rand = gr.Checkbox(label="random seed", value=True)
@@ -336,12 +368,14 @@ if __name__ == "__main__":
336
  output_midi_visualizer = gr.HTML(elem_id="midi_visualizer_container")
337
  output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
338
  output_midi = gr.File(label="output midi", file_types=[".mid"])
339
- run_event = run_btn.click(run, [input_model, tab_select, input_instruments, input_drum_kit, input_bpm,
340
- input_midi, input_midi_events, input_reduce_cc_st, input_remap_track_channel,
341
- input_add_default_instr, input_remove_empty_channels, input_seed,
342
- input_seed_rand, input_gen_events, input_temp, input_top_p, input_top_k,
343
- input_allow_cc],
344
  [output_midi_seq, output_midi, output_audio, input_seed, js_msg],
345
  concurrency_limit=3)
346
- stop_btn.click(cancel_run, [output_midi_seq], [output_midi, output_audio, js_msg], cancels=run_event, queue=False)
 
 
347
  app.launch(server_port=opt.port, share=opt.share, inbrowser=True)
 
1
  import argparse
2
  import glob
3
+ import json
4
  import os.path
5
  import time
 
6
 
7
  import gradio as gr
8
  import numpy as np
9
  import onnxruntime as rt
10
  import tqdm
 
11
  from huggingface_hub import hf_hub_download
12
 
13
  import MIDI
 
46
 
47
  def generate(model, prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
48
  disable_patch_change=False, disable_control_change=False, disable_channels=None, generator=None):
49
+ tokenizer = model[2]
50
  if disable_channels is not None:
51
  disable_channels = [tokenizer.parameter_ids["channel"][c] for c in disable_channels]
52
  else:
 
121
  return json.dumps(msgs)
122
 
123
 
124
+ def run(model_name, tab, mid_seq, instruments, drum_kit, bpm, mid, midi_events,
125
  reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels, seed, seed_rand,
126
  gen_events, temp, top_p, top_k, allow_cc):
127
+ model = models[model_name]
128
+ tokenizer = model[2]
129
  bpm = int(bpm)
130
  gen_events = int(gen_events)
131
  max_len = gen_events
 
138
  i = 0
139
  mid = [[tokenizer.bos_id] + [tokenizer.pad_id] * (tokenizer.max_token_seq - 1)]
140
  if bpm != 0:
141
+ mid.append(tokenizer.event2tokens(["set_tempo", 0, 0, 0, bpm]))
142
  patches = {}
143
  if instruments is None:
144
  instruments = []
 
154
  if len(instruments) > 0:
155
  disable_patch_change = True
156
  disable_channels = [i for i in range(16) if i not in patches]
157
+ elif tab == 1 and mid is not None:
158
  eps = 4 if reduce_cc_st else 0
159
  mid = tokenizer.tokenize(MIDI.midi2score(mid), cc_eps=eps, tempo_eps=eps,
160
  remap_track_channel=remap_track_channel,
 
162
  remove_empty_channels=remove_empty_channels)
163
  mid = np.asarray(mid, dtype=np.int64)
164
  mid = mid[:int(midi_events)]
165
+ mid_seq = []
166
  for token_seq in mid:
167
  mid_seq.append(token_seq.tolist())
168
+ elif tab == 2 and mid_seq is not None:
169
+ mid = np.asarray(mid_seq, dtype=np.int64)
170
+ else:
171
+ mid_seq = []
172
+ mid = None
173
+
174
+ if mid is not None:
175
+ max_len += len(mid)
176
+
177
  events = [tokenizer.tokens2event(tokens) for tokens in mid_seq]
178
+ if tab == 2:
179
+ init_msgs = [create_msg("visualizer_continue", tokenizer.version)]
180
+ else:
181
+ init_msgs = [create_msg("visualizer_clear", tokenizer.version),
182
+ create_msg("visualizer_append", events)]
183
  yield mid_seq, None, None, seed, send_msgs(init_msgs)
184
+ t = time.time() + 1
185
  midi_generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
186
  disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
187
  disable_channels=disable_channels, generator=generator)
 
204
  yield mid_seq, "output.mid", (44100, audio), seed, send_msgs([create_msg("visualizer_end", events)])
205
 
206
 
207
+ def cancel_run(model_name, mid_seq):
208
  if mid_seq is None:
209
  return None, None, []
210
+ tokenizer = models[model_name][2]
211
  mid = tokenizer.detokenize(mid_seq)
212
  with open(f"output.mid", 'wb') as f:
213
  f.write(MIDI.score2midi(mid))
 
247
  if err:
248
  raise err
249
 
250
+ def get_tokenizer(config_name):
251
+ tv, size = config_name.split("-")
252
+ tv = tv[1:]
253
+ if tv[-1] == "o":
254
+ o = True
255
+ tv = tv[:-1]
256
+ else:
257
+ o = False
258
+ if tv not in ["v1", "v2"]:
259
+ raise ValueError(f"Unknown tokenizer version {tv}")
260
+ tokenizer = MIDITokenizer(tv)
261
+ tokenizer.set_optimise_midi(o)
262
+ return tokenizer
263
+
264
  number2drum_kits = {-1: "None", 0: "Standard", 8: "Room", 16: "Power", 24: "Electric", 25: "TR-808", 32: "Jazz",
265
  40: "Blush", 48: "Orchestra"}
266
  patch2number = {v: k for k, v in MIDI.Number2patch.items()}
 
273
  parser.add_argument("--max-gen", type=int, default=1024, help="max")
274
  opt = parser.parse_args()
275
  soundfont_path = hf_hub_download_retry(repo_id="skytnt/midi-model", filename="soundfont.sf2")
276
+ models_info = {"generic pretrain model (tv2o-large)": ["asigalov61/Music-Llama", "", "tv2o-large"],
277
+ "generic pretrain model (tv1-medium)": ["skytnt/midi-model", "", "tv1-medium"],
278
+ "j-pop finetune model (tv1-medium)": ["skytnt/midi-model-ft", "jpop/", "tv1-medium"],
279
+ "touhou finetune model (tv1-medium)": ["skytnt/midi-model-ft", "touhou/", "tv1-medium"],
280
  }
281
  models = {}
 
282
  providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
283
+ for name, (repo_id, path, config) in models_info.items():
284
  model_base_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}onnx/model_base.onnx")
285
  model_token_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}onnx/model_token.onnx")
286
  model_base = rt.InferenceSession(model_base_path, providers=providers)
287
  model_token = rt.InferenceSession(model_token_path, providers=providers)
288
+ tokenizer = get_tokenizer(config)
289
+ models[name] = [model_base, model_token, tokenizer]
290
 
291
  load_javascript()
292
  app = gr.Blocks()
 
345
  input_remove_empty_channels = gr.Checkbox(label="remove channels without notes", value=False)
346
  example2 = gr.Examples([[file, 128] for file in glob.glob("example/*.mid")],
347
  [input_midi, input_midi_events])
348
+ with gr.TabItem("last output prompt") as tab3:
349
+ gr.Markdown("Continue generating on the last output. Just click the generate button")
350
 
351
  tab1.select(lambda: 0, None, tab_select, queue=False)
352
  tab2.select(lambda: 1, None, tab_select, queue=False)
353
+ tab3.select(lambda: 2, None, tab_select, queue=False)
354
  input_seed = gr.Slider(label="seed", minimum=0, maximum=2 ** 31 - 1,
355
  step=1, value=0)
356
  input_seed_rand = gr.Checkbox(label="random seed", value=True)
 
368
  output_midi_visualizer = gr.HTML(elem_id="midi_visualizer_container")
369
  output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
370
  output_midi = gr.File(label="output midi", file_types=[".mid"])
371
+ run_event = run_btn.click(run, [input_model, tab_select, output_midi_seq, input_instruments,
372
+ input_drum_kit, input_bpm, input_midi, input_midi_events, input_reduce_cc_st,
373
+ input_remap_track_channel, input_add_default_instr, input_remove_empty_channels,
374
+ input_seed, input_seed_rand, input_gen_events, input_temp, input_top_p,
375
+ input_top_k, input_allow_cc],
376
  [output_midi_seq, output_midi, output_audio, input_seed, js_msg],
377
  concurrency_limit=3)
378
+ stop_btn.click(cancel_run, [input_model,output_midi_seq],
379
+ [output_midi, output_audio, js_msg],
380
+ cancels=run_event, queue=False)
381
  app.launch(server_port=opt.port, share=opt.share, inbrowser=True)
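
Because each models[...] entry now bundles its own tokenizer ([model_base, model_token, tokenizer]), run() picks the prompt by tab and extends max_len by the prompt length so that gen_events new events are produced on top of it. The helper below is a condensed, illustrative summary of that tab handling, not the exact app code; build_prompt is a hypothetical name, numpy is the only dependency, and tab 0's instrument-prompt construction is omitted.

    import numpy as np

    def build_prompt(tab, midi_tokens, last_output_seq, gen_events):
        # tab 0: instrument prompt (builds its own bos/tempo/patch prompt, omitted here)
        # tab 1: uploaded midi prompt, tab 2: last output prompt (new in this commit)
        max_len = int(gen_events)
        if tab == 1 and midi_tokens is not None:
            prompt = np.asarray(midi_tokens, dtype=np.int64)
        elif tab == 2 and last_output_seq is not None:
            prompt = np.asarray(last_output_seq, dtype=np.int64)
        else:
            prompt = None
        if prompt is not None:
            max_len += len(prompt)   # so gen_events new events are generated after the prompt
        return prompt, max_len

    # e.g. continuing from a previous 300-row output with 512 new events requested:
    # prompt, max_len = build_prompt(2, None, previous_mid_seq, 512)  -> max_len == 812
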
javascript/app.js CHANGED
@@ -117,9 +117,11 @@ class MidiVisualizer extends HTMLElement{
117
  this.totalTimeMs = 0
118
  this.playTime = 0
119
  this.playTimeMs = 0
 
120
  this.colorMap = new Map();
121
  this.playing = false;
122
  this.timer = null;
 
123
  this.init();
124
  }
125
 
@@ -157,6 +159,7 @@ class MidiVisualizer extends HTMLElement{
157
  this.setPlayTime(0);
158
  this.totalTimeMs = 0;
159
  this.playTimeMs = 0
 
160
  this.svgWidth = 0
161
  this.svg.innerHTML = ''
162
  this.svg.style.width = `${this.svgWidth}px`;
@@ -171,10 +174,22 @@ class MidiVisualizer extends HTMLElement{
171
  midiEvent = [midiEvent[0], t].concat(midiEvent.slice(3))
172
  if(midiEvent[0] === "note"){
173
  let track = midiEvent[2]
174
- let duration = midiEvent[3]
175
- let channel = midiEvent[4]
176
- let pitch = midiEvent[5]
177
- let velocity = midiEvent[6]
178
  let x = (t/this.timePreBeat)*this.config.beatWidth
179
  let y = (127 - pitch)*this.config.noteHeight
180
  let w = (duration/this.timePreBeat)*this.config.beatWidth
@@ -252,14 +267,14 @@ class MidiVisualizer extends HTMLElement{
252
  this.timeLine.setAttribute('y2', `${this.config.noteHeight*128}`);
253
 
254
  this.wrapper.scrollTo(Math.max(0, x - this.wrapper.offsetWidth/2), 0)
255
-
256
- if(this.playing){
257
  let activeNotes = []
258
  this.removeActiveNotes(this.activeNotes)
259
  this.midiEvents.forEach((midiEvent)=>{
260
  if(midiEvent[0] === "note"){
261
  let time = midiEvent[1]
262
- let duration = midiEvent[3]
263
  let note = midiEvent[midiEvent.length - 1]
264
  if(time <=this.playTime && time+duration>= this.playTime){
265
  activeNotes.push(note)
@@ -267,7 +282,9 @@ class MidiVisualizer extends HTMLElement{
267
  }
268
  })
269
  this.addActiveNotes(activeNotes)
 
270
  }
 
271
  }
272
 
273
  setPlayTimeMs(ms){
@@ -424,6 +441,11 @@ customElements.define('midi-visualizer', MidiVisualizer);
424
  switch (msg.name) {
425
  case "visualizer_clear":
426
  midi_visualizer.clearMidiEvents(false);
427
  createProgressBar(midi_visualizer_container_inited)
428
  break;
429
  case "visualizer_append":
 
117
  this.totalTimeMs = 0
118
  this.playTime = 0
119
  this.playTimeMs = 0
120
+ this.lastUpdateTime = 0
121
  this.colorMap = new Map();
122
  this.playing = false;
123
  this.timer = null;
124
+ this.version = "v2"
125
  this.init();
126
  }
127
 
 
159
  this.setPlayTime(0);
160
  this.totalTimeMs = 0;
161
  this.playTimeMs = 0
162
+ this.lastUpdateTime = 0
163
  this.svgWidth = 0
164
  this.svg.innerHTML = ''
165
  this.svg.style.width = `${this.svgWidth}px`;
 
174
  midiEvent = [midiEvent[0], t].concat(midiEvent.slice(3))
175
  if(midiEvent[0] === "note"){
176
  let track = midiEvent[2]
177
+ let duration = 0
178
+ let channel = 0
179
+ let pitch = 0
180
+ let velocity = 0
181
+ if(this.version === "v1"){
182
+ duration = midiEvent[3]
183
+ channel = midiEvent[4]
184
+ pitch = midiEvent[5]
185
+ velocity = midiEvent[6]
186
+ }else if (this.version === "v2"){
187
+ channel = midiEvent[3]
188
+ pitch = midiEvent[4]
189
+ velocity = midiEvent[5]
190
+ duration = midiEvent[6]
191
+ }
192
+
193
  let x = (t/this.timePreBeat)*this.config.beatWidth
194
  let y = (127 - pitch)*this.config.noteHeight
195
  let w = (duration/this.timePreBeat)*this.config.beatWidth
 
267
  this.timeLine.setAttribute('y2', `${this.config.noteHeight*128}`);
268
 
269
  this.wrapper.scrollTo(Math.max(0, x - this.wrapper.offsetWidth/2), 0)
270
+ let dt = Date.now() - this.lastUpdateTime; // limit the update rate of ActiveNotes
271
+ if(this.playing && dt > 50){
272
  let activeNotes = []
273
  this.removeActiveNotes(this.activeNotes)
274
  this.midiEvents.forEach((midiEvent)=>{
275
  if(midiEvent[0] === "note"){
276
  let time = midiEvent[1]
277
+ let duration = this.version==="v1"? midiEvent[3]:midiEvent[6]
278
  let note = midiEvent[midiEvent.length - 1]
279
  if(time <=this.playTime && time+duration>= this.playTime){
280
  activeNotes.push(note)
 
282
  }
283
  })
284
  this.addActiveNotes(activeNotes)
285
+ this.lastUpdateTime = Date.now();
286
  }
287
+
288
  }
289
 
290
  setPlayTimeMs(ms){
 
441
  switch (msg.name) {
442
  case "visualizer_clear":
443
  midi_visualizer.clearMidiEvents(false);
444
+ midi_visualizer.version = msg.data
445
+ createProgressBar(midi_visualizer_container_inited)
446
+ break;
447
+ case "visualizer_continue":
448
+ midi_visualizer.version = msg.data
449
  createProgressBar(midi_visualizer_container_inited)
450
  break;
451
  case "visualizer_append":
midi_tokenizer.py CHANGED
@@ -1,11 +1,13 @@
1
  import random
2
 
3
- import PIL
4
  import numpy as np
5
 
6
 
7
- class MIDITokenizer:
8
  def __init__(self):
 
 
9
  self.vocab_size = 0
10
 
11
  def allocate_ids(size):
@@ -31,26 +33,38 @@ class MIDITokenizer:
31
  self.parameter_ids = {p: allocate_ids(s) for p, s in self.event_parameters.items()}
32
  self.max_token_seq = max([len(ps) for ps in self.events.values()]) + 1
33
 
34
- def tempo2bpm(self, tempo):
35
  tempo = tempo / 10 ** 6 # us to s
36
  bpm = 60 / tempo
37
  return bpm
38
 
39
- def bpm2tempo(self, bpm):
 
40
  if bpm == 0:
41
  bpm = 1
42
  tempo = int((60 / bpm) * 10 ** 6)
43
  return tempo
44
 
45
  def tokenize(self, midi_score, add_bos_eos=True, cc_eps=4, tempo_eps=4,
46
- remap_track_channel=False, add_default_instr=False, remove_empty_channels=False):
47
  ticks_per_beat = midi_score[0]
48
  event_list = {}
49
  track_idx_map = {i: dict() for i in range(16)}
50
  track_idx_dict = {}
51
  channels = []
52
  patch_channels = []
53
- empty_channels = [True]*16
54
  channel_note_tracks = {i: list() for i in range(16)}
55
  for track_idx, track in enumerate(midi_score[1:129]):
56
  last_notes = {}
@@ -74,7 +88,7 @@ class MIDITokenizer:
74
  note_tracks.append(track_idx)
75
  new_event[4] = max(1, round(16 * new_event[4] / ticks_per_beat))
76
  elif event[0] == "set_tempo":
77
- if new_event[4] == 0: # invalid tempo
78
  continue
79
  bpm = int(self.tempo2bpm(new_event[4]))
80
  new_event[4] = min(bpm, 255)
@@ -143,8 +157,8 @@ class MIDITokenizer:
143
  channels = list(channels_map.values())
144
 
145
  track_count = 0
146
- track_idx_map_order = [k for k,v in sorted(list(channels_map.items()), key=lambda x: x[1])]
147
- for c in track_idx_map_order: # tracks not to remove
148
  if remove_empty_channels and c in empty_channels:
149
  continue
150
  tr_map = track_idx_map[c]
@@ -154,7 +168,7 @@ class MIDITokenizer:
154
  continue
155
  track_count += 1
156
  tr_map[track_idx] = track_count
157
- for c in track_idx_map_order: # tracks to remove
158
  if not (remove_empty_channels and c in empty_channels):
159
  continue
160
  tr_map = track_idx_map[c]
@@ -166,7 +180,7 @@ class MIDITokenizer:
166
  tr_map[track_idx] = track_count
167
 
168
  empty_channels = [channels_map[c] for c in empty_channels]
169
-
170
  for event in event_list:
171
  name = event[0]
172
  track_idx = event[3]
@@ -174,7 +188,8 @@ class MIDITokenizer:
174
  c = event[5]
175
  event[5] = channels_map[c]
176
  event[3] = track_idx_map[c][track_idx]
177
- track_idx_dict[event[5]] = event[3]
 
178
  elif name == "set_tempo":
179
  event[3] = 0
180
  elif name == "control_change" or name == "patch_change":
@@ -192,10 +207,10 @@ class MIDITokenizer:
192
 
193
  if add_default_instr:
194
  for c in channels:
195
- if c not in patch_channels:
196
- event_list.append(["patch_change", 0,0, track_idx_dict[c], c, 0])
197
 
198
- events_name_order = {"set_tempo":0, "patch_change":1, "control_change":2, "note":3}
199
  events_order = lambda e: e[1:4] + [events_name_order[e[0]]]
200
  event_list = sorted(event_list, key=events_order)
201
 
@@ -214,7 +229,7 @@ class MIDITokenizer:
214
  if notes_in_setup and i > 0:
215
  pre_event = event_list[i - 1]
216
  has_pre = event[1] + event[2] == pre_event[1] + pre_event[2]
217
- if (event[0] == "note" and not has_next) or (notes_in_setup and not has_pre) :
218
  event_list = sorted(setup_events.values(), key=events_order) + event_list[i:]
219
  break
220
  else:
@@ -253,17 +268,17 @@ class MIDITokenizer:
253
  return tokens
254
 
255
  def tokens2event(self, tokens):
256
- if tokens[0] in self.id_events:
257
- name = self.id_events[tokens[0]]
258
- if len(tokens) <= len(self.events[name]):
259
- return []
260
- params = tokens[1:]
261
- params = [params[i] - self.parameter_ids[p][0] for i, p in enumerate(self.events[name])]
262
- if not all([0 <= params[i] < self.event_parameters[p] for i, p in enumerate(self.events[name])]):
263
- return []
264
- event = [name] + params
265
- return event
266
- return []
267
 
268
  def detokenize(self, midi_seq):
269
  ticks_per_beat = 480
@@ -386,7 +401,9 @@ class MIDITokenizer:
386
  midi_seq_new.append(tokens_new)
387
  return midi_seq_new
388
 
389
- def check_quality(self, midi_seq, alignment_min=0.4, tonality_min=0.8, piano_max=0.7, notes_bandwidth_min=3, notes_density_max=30, notes_density_min=2.5, total_notes_max=10000, total_notes_min=500, note_window_size=16):
 
 
390
  total_notes = 0
391
  channels = []
392
  time_hist = [0] * 16
@@ -450,13 +467,648 @@ class MIDITokenizer:
450
  tonality_list.append(sum(key_hist[:7]) / len(notes))
451
  notes_density_list.append(len(notes) / note_window_size)
452
  tonality_list = sorted(tonality_list)
453
- tonality = sum(tonality_list)/len(tonality_list)
454
- notes_bandwidth = sum(notes_bandwidth_list)/len(notes_bandwidth_list) if notes_bandwidth_list else 0
455
  notes_density = max(notes_density_list) if notes_density_list else 0
456
  piano_ratio = len(piano_channels) / len(channels)
457
- if len(channels) <=3: # ignore piano threshold if it is a piano solo midi
458
  piano_max = 1
459
- if alignment < alignment_min: # check weather the notes align to the bars (because some midi files are recorded)
460
  reasons.append("alignment")
461
  if tonality < tonality_min: # check whether the music is tonal
462
  reasons.append("tonality")
@@ -464,6 +1116,16 @@ class MIDITokenizer:
464
  reasons.append("bandwidth")
465
  if not notes_density_min < notes_density < notes_density_max:
466
  reasons.append("density")
467
- if piano_ratio > piano_max: # check whether most instruments is piano (because some midi files don't have instruments assigned correctly)
468
  reasons.append("piano")
469
  return not reasons, reasons
 
1
  import random
2
 
3
+ import PIL.Image
4
  import numpy as np
5
 
6
 
7
+ class MIDITokenizerV1:
8
  def __init__(self):
9
+ self.version = "v1"
10
+ self.optimise_midi = False
11
  self.vocab_size = 0
12
 
13
  def allocate_ids(size):
 
33
  self.parameter_ids = {p: allocate_ids(s) for p, s in self.event_parameters.items()}
34
  self.max_token_seq = max([len(ps) for ps in self.events.values()]) + 1
35
 
36
+ def set_optimise_midi(self, optimise_midi=True):
37
+ self.optimise_midi = optimise_midi
38
+
39
+ @staticmethod
40
+ def tempo2bpm(tempo):
41
  tempo = tempo / 10 ** 6 # us to s
42
  bpm = 60 / tempo
43
  return bpm
44
 
45
+ @staticmethod
46
+ def bpm2tempo(bpm):
47
  if bpm == 0:
48
  bpm = 1
49
  tempo = int((60 / bpm) * 10 ** 6)
50
  return tempo
51
 
52
  def tokenize(self, midi_score, add_bos_eos=True, cc_eps=4, tempo_eps=4,
53
+ remap_track_channel=None, add_default_instr=None, remove_empty_channels=None):
54
+ if remap_track_channel is None: # set default value
55
+ remap_track_channel = self.optimise_midi
56
+ if add_default_instr is None:
57
+ add_default_instr = self.optimise_midi
58
+ if remove_empty_channels is None:
59
+ remove_empty_channels = self.optimise_midi
60
+
61
  ticks_per_beat = midi_score[0]
62
  event_list = {}
63
  track_idx_map = {i: dict() for i in range(16)}
64
  track_idx_dict = {}
65
  channels = []
66
  patch_channels = []
67
+ empty_channels = [True] * 16
68
  channel_note_tracks = {i: list() for i in range(16)}
69
  for track_idx, track in enumerate(midi_score[1:129]):
70
  last_notes = {}
 
88
  note_tracks.append(track_idx)
89
  new_event[4] = max(1, round(16 * new_event[4] / ticks_per_beat))
90
  elif event[0] == "set_tempo":
91
+ if new_event[4] == 0: # invalid tempo
92
  continue
93
  bpm = int(self.tempo2bpm(new_event[4]))
94
  new_event[4] = min(bpm, 255)
 
157
  channels = list(channels_map.values())
158
 
159
  track_count = 0
160
+ track_idx_map_order = [k for k, v in sorted(list(channels_map.items()), key=lambda x: x[1])]
161
+ for c in track_idx_map_order: # tracks not to remove
162
  if remove_empty_channels and c in empty_channels:
163
  continue
164
  tr_map = track_idx_map[c]
 
168
  continue
169
  track_count += 1
170
  tr_map[track_idx] = track_count
171
+ for c in track_idx_map_order: # tracks to remove
172
  if not (remove_empty_channels and c in empty_channels):
173
  continue
174
  tr_map = track_idx_map[c]
 
180
  tr_map[track_idx] = track_count
181
 
182
  empty_channels = [channels_map[c] for c in empty_channels]
183
+ track_idx_dict = {}
184
  for event in event_list:
185
  name = event[0]
186
  track_idx = event[3]
 
188
  c = event[5]
189
  event[5] = channels_map[c]
190
  event[3] = track_idx_map[c][track_idx]
191
+ track_idx_dict.setdefault(event[5], event[3])
192
+ # setdefault, so the track_idx is the first of the channel
193
  elif name == "set_tempo":
194
  event[3] = 0
195
  elif name == "control_change" or name == "patch_change":
 
207
 
208
  if add_default_instr:
209
  for c in channels:
210
+ if c not in patch_channels and c in track_idx_dict:
211
+ event_list.append(["patch_change", 0, 0, track_idx_dict[c], c, 0])
212
 
213
+ events_name_order = {"set_tempo": 0, "patch_change": 1, "control_change": 2, "note": 3}
214
  events_order = lambda e: e[1:4] + [events_name_order[e[0]]]
215
  event_list = sorted(event_list, key=events_order)
216
 
 
229
  if notes_in_setup and i > 0:
230
  pre_event = event_list[i - 1]
231
  has_pre = event[1] + event[2] == pre_event[1] + pre_event[2]
232
+ if (event[0] == "note" and not has_next) or (notes_in_setup and not has_pre):
233
  event_list = sorted(setup_events.values(), key=events_order) + event_list[i:]
234
  break
235
  else:
 
268
  return tokens
269
 
270
  def tokens2event(self, tokens):
271
+ if tokens[0] not in self.id_events:
272
+ return []
273
+ name = self.id_events[tokens[0]]
274
+ if len(tokens) <= len(self.events[name]):
275
+ return []
276
+ params = tokens[1:]
277
+ params = [params[i] - self.parameter_ids[p][0] for i, p in enumerate(self.events[name])]
278
+ if not all([0 <= params[i] < self.event_parameters[p] for i, p in enumerate(self.events[name])]):
279
+ return []
280
+ event = [name] + params
281
+ return event
282
 
283
  def detokenize(self, midi_seq):
284
  ticks_per_beat = 480
 
401
  midi_seq_new.append(tokens_new)
402
  return midi_seq_new
403
 
404
+ def check_quality(self, midi_seq, alignment_min=0.3, tonality_min=0.8, piano_max=0.7, notes_bandwidth_min=3,
405
+ notes_density_max=50, notes_density_min=2.5, total_notes_max=20000, total_notes_min=256,
406
+ note_window_size=16):
407
  total_notes = 0
408
  channels = []
409
  time_hist = [0] * 16
 
467
  tonality_list.append(sum(key_hist[:7]) / len(notes))
468
  notes_density_list.append(len(notes) / note_window_size)
469
  tonality_list = sorted(tonality_list)
470
+ tonality = sum(tonality_list) / len(tonality_list)
471
+ notes_bandwidth = sum(notes_bandwidth_list) / len(notes_bandwidth_list) if notes_bandwidth_list else 0
472
+ notes_density = max(notes_density_list) if notes_density_list else 0
473
+ piano_ratio = len(piano_channels) / len(channels)
474
+ if len(channels) <= 3: # ignore piano threshold if it is a piano solo midi
475
+ piano_max = 1
476
+ if alignment < alignment_min: # check whether the notes align to the bars (because some midi files are recorded)
477
+ reasons.append("alignment")
478
+ if tonality < tonality_min: # check whether the music is tonal
479
+ reasons.append("tonality")
480
+ if notes_bandwidth < notes_bandwidth_min: # check whether music is melodic line only
481
+ reasons.append("bandwidth")
482
+ if not notes_density_min < notes_density < notes_density_max:
483
+ reasons.append("density")
484
+ if piano_ratio > piano_max: # check whether most instruments are piano (because some midi files don't have instruments assigned correctly)
485
+ reasons.append("piano")
486
+ return not reasons, reasons
487
+
488
+
489
+ class MIDITokenizerV2:
490
+ def __init__(self):
491
+ self.version = "v2"
492
+ self.optimise_midi = False
493
+ self.vocab_size = 0
494
+
495
+ def allocate_ids(size):
496
+ ids = [self.vocab_size + i for i in range(size)]
497
+ self.vocab_size += size
498
+ return ids
499
+
500
+ self.pad_id = allocate_ids(1)[0]
501
+ self.bos_id = allocate_ids(1)[0]
502
+ self.eos_id = allocate_ids(1)[0]
503
+ self.events = {
504
+ "note": ["time1", "time2", "track", "channel", "pitch", "velocity", "duration"],
505
+ "patch_change": ["time1", "time2", "track", "channel", "patch"],
506
+ "control_change": ["time1", "time2", "track", "channel", "controller", "value"],
507
+ "set_tempo": ["time1", "time2", "track", "bpm"],
508
+ "time_signature": ["time1", "time2", "track", "nn", "dd"],
509
+ "key_signature": ["time1", "time2", "track", "sf", "mi"],
510
+ }
511
+ self.event_parameters = {
512
+ "time1": 128, "time2": 16, "duration": 2048, "track": 128, "channel": 16, "pitch": 128, "velocity": 128,
513
+ "patch": 128, "controller": 128, "value": 128, "bpm": 384, "nn": 16, "dd": 4, "sf": 15, "mi": 2
514
+ }
515
+ self.event_ids = {e: allocate_ids(1)[0] for e in self.events.keys()}
516
+ self.id_events = {i: e for e, i in self.event_ids.items()}
517
+ self.parameter_ids = {p: allocate_ids(s) for p, s in self.event_parameters.items()}
518
+ self.max_token_seq = max([len(ps) for ps in self.events.values()]) + 1
519
+
520
+ def set_optimise_midi(self, optimise_midi=True):
521
+ self.optimise_midi = optimise_midi
522
+
523
+ @staticmethod
524
+ def tempo2bpm(tempo):
525
+ tempo = tempo / 10 ** 6 # us to s
526
+ bpm = 60 / tempo
527
+ return bpm
528
+
529
+ @staticmethod
530
+ def bpm2tempo(bpm):
531
+ if bpm == 0:
532
+ bpm = 1
533
+ tempo = int((60 / bpm) * 10 ** 6)
534
+ return tempo
535
+
536
+ @staticmethod
537
+ def sf2key(sf):
538
+ # sf in key_signature to key.
539
+ # key represents the sequence from C note to B note (12 in total)
540
+ return (sf * 7) % 12
541
+
542
+ @staticmethod
543
+ def key2sf(k, mi):
544
+ # key to sf
545
+ sf = (k * 7) % 12
546
+ if sf > 6 or (mi == 1 and sf >= 5):
547
+ sf -= 12
548
+ return sf
549
+
550
+ @staticmethod
551
+ def detect_key_signature(key_hist, threshold=0.7):
552
+ if len(key_hist) != 12:
553
+ return None
554
+ p = sum(sorted(key_hist, reverse=True)[:7]) / sum(key_hist)
555
+ if p < threshold:
556
+ return None
557
+ keys = [x[1] for x in sorted(zip(key_hist, range(len(key_hist))), reverse=True, key=lambda x: x[0])[:7]]
558
+ keys = sorted(keys)
559
+ semitones = []
560
+ for i in range(len(keys)):
561
+ dis = keys[i] - keys[i - 1]
562
+ if dis == 1 or dis == -11:
563
+ semitones.append(keys[i])
564
+ if len(semitones) != 2:
565
+ return None
566
+ semitones_dis = semitones[1] - semitones[0]
567
+ if semitones_dis == 5:
568
+ root_key = semitones[0]
569
+ elif semitones_dis == 7:
570
+ root_key = semitones[1]
571
+ else:
572
+ return None
573
+ return root_key
574
+
575
+ def tokenize(self, midi_score, add_bos_eos=True, cc_eps=4, tempo_eps=4,
576
+ remap_track_channel=None, add_default_instr=None, remove_empty_channels=None):
577
+ if remap_track_channel is None: # set default value
578
+ remap_track_channel = self.optimise_midi
579
+ if add_default_instr is None:
580
+ add_default_instr = self.optimise_midi
581
+ if remove_empty_channels is None:
582
+ remove_empty_channels = self.optimise_midi
583
+
584
+ ticks_per_beat = midi_score[0]
585
+ event_list = {}
586
+ track_idx_map = {i: dict() for i in range(16)}
587
+ track_idx_dict = {}
588
+ channels = []
589
+ patch_channels = []
590
+ empty_channels = [True] * 16
591
+ channel_note_tracks = {i: list() for i in range(16)}
592
+ note_key_hist = [0]*12
593
+ key_sig_num = 0
594
+ track_to_channels = {}
595
+ for track_idx, track in enumerate(midi_score[1:129]):
596
+ last_notes = {}
597
+ patch_dict = {}
598
+ control_dict = {}
599
+ last_bpm = 0
600
+ track_channels = []
601
+ track_to_channels.setdefault(track_idx, track_channels)
602
+ for event in track:
603
+ if event[0] not in self.events:
604
+ continue
605
+ name = event[0]
606
+ c = -1
607
+ t = round(16 * event[1] / ticks_per_beat) # quantization
608
+ new_event = [name, t // 16, t % 16, track_idx]
609
+ if name == "note":
610
+ d, c, p, v = event[2:]
611
+ if not (0 <= c <= 15):
612
+ continue
613
+ d = max(1, round(16 * d / ticks_per_beat))
614
+ new_event += [c, p, v, d]
615
+ empty_channels[c] = False
616
+ track_idx_dict.setdefault(c, track_idx)
617
+ note_tracks = channel_note_tracks[c]
618
+ if track_idx not in note_tracks:
619
+ note_tracks.append(track_idx)
620
+ if c != 9:
621
+ note_key_hist[p%12] += 1
622
+ if c not in track_channels:
623
+ track_channels.append(c)
624
+ elif name == "patch_change":
625
+ c, p = event[2:]
626
+ if not (0 <= c <= 15):
627
+ continue
628
+ new_event += [c, p]
629
+ last_p = patch_dict.setdefault(c, None)
630
+ if last_p == p:
631
+ continue
632
+ patch_dict[c] = p
633
+ if c not in patch_channels:
634
+ patch_channels.append(c)
635
+ elif name == "control_change":
636
+ c, cc, v = event[2:]
637
+ if not (0 <= c <= 15):
638
+ continue
639
+ new_event += [c, cc, v]
640
+ last_v = control_dict.setdefault((c, cc), 0)
641
+ if abs(last_v - v) < cc_eps:
642
+ continue
643
+ control_dict[(c, cc)] = v
644
+ elif name == "set_tempo":
645
+ tempo = event[2]
646
+ if tempo == 0: # invalid tempo
647
+ continue
648
+ bpm = min(int(self.tempo2bpm(tempo)), 383)
649
+ new_event += [bpm]
650
+ if abs(last_bpm - bpm) < tempo_eps:
651
+ continue
652
+ last_bpm = bpm
653
+ elif name == "time_signature":
654
+ nn, dd = event[2:4]
655
+ if not (1 <= nn <= 16 and 1 <= dd <= 4): # invalid
656
+ continue
657
+ nn -= 1 # make it start from 0
658
+ dd -= 1
659
+ new_event += [nn, dd]
660
+ elif name == "key_signature":
661
+ sf, mi = event[2:]
662
+ if not (-7 <= sf <= 7 and 0 <= mi <= 1): # invalid
663
+ continue
664
+ key_sig_num += 1
665
+ sf += 7
666
+ new_event += [sf, mi]
667
+
668
+ if name == "note":
669
+ key = tuple(new_event[:-2])
670
+ else:
671
+ key = tuple(new_event[:-1])
672
+
673
+ if c != -1:
674
+ if c not in channels:
675
+ channels.append(c)
676
+ tr_map = track_idx_map[c]
677
+ if track_idx not in tr_map:
678
+ tr_map[track_idx] = 0
679
+
680
+ if event[0] == "note": # to eliminate note overlap due to quantization
681
+ cp = tuple(new_event[4:6]) # channel pitch
682
+ if cp in last_notes:
683
+ last_note_key, last_note = last_notes[cp]
684
+ last_t = last_note[1] * 16 + last_note[2]
685
+ last_note[-1] = max(0, min(last_note[-1], t - last_t)) # modify duration
686
+ if last_note[-1] == 0:
687
+ event_list.pop(last_note_key)
688
+ last_notes[cp] = (key, new_event)
689
+ event_list[key] = new_event
690
+ event_list = list(event_list.values())
691
+
692
+ empty_channels = [c for c in channels if empty_channels[c]]
693
+
694
+ if remap_track_channel:
695
+ patch_channels = []
696
+ channels_count = 0
697
+ channels_map = {9: 9} if 9 in channels else {}
698
+ if remove_empty_channels:
699
+ channels = sorted(channels, key=lambda x: 1 if x in empty_channels else 0)
700
+ for c in channels:
701
+ if c == 9:
702
+ continue
703
+ channels_map[c] = channels_count
704
+ channels_count += 1
705
+ if channels_count == 9:
706
+ channels_count = 10
707
+ channels = list(channels_map.values())
708
+
709
+ track_count = 0
710
+ track_idx_map_order = [k for k, v in sorted(list(channels_map.items()), key=lambda x: x[1])]
711
+ for c in track_idx_map_order: # tracks not to remove
712
+ if remove_empty_channels and c in empty_channels:
713
+ continue
714
+ tr_map = track_idx_map[c]
715
+ for track_idx in tr_map:
716
+ note_tracks = channel_note_tracks[c]
717
+ if len(note_tracks) != 0 and track_idx not in note_tracks:
718
+ continue
719
+ track_count += 1
720
+ tr_map[track_idx] = track_count
721
+ for c in track_idx_map_order: # tracks to remove
722
+ if not (remove_empty_channels and c in empty_channels):
723
+ continue
724
+ tr_map = track_idx_map[c]
725
+ for track_idx in tr_map:
726
+ note_tracks = channel_note_tracks[c]
727
+ if not (len(note_tracks) != 0 and track_idx not in note_tracks):
728
+ continue
729
+ track_count += 1
730
+ tr_map[track_idx] = track_count
731
+
732
+ empty_channels = [channels_map[c] for c in empty_channels]
733
+ track_idx_dict = {}
734
+ key_signature_to_add = []
735
+ for event in event_list:
736
+ name = event[0]
737
+ track_idx = event[3]
738
+ if name == "note":
739
+ c = event[4]
740
+ event[4] = channels_map[c] # channel
741
+ event[3] = track_idx_map[c][track_idx] # track
742
+ track_idx_dict.setdefault(event[4], event[3])
743
+ # setdefault, so the track_idx is the first of the channel
744
+ elif name in ["set_tempo", "time_signature"]:
745
+ event[3] = 0 # set track 0 for meta events
746
+ elif name == "key_signature":
747
+ new_channel_track_idxs = []
748
+ for c, tr_map in track_idx_map.items():
749
+ if track_idx in tr_map:
750
+ new_track_idx = tr_map[track_idx]
751
+ new_channel_track_idx = (c, new_track_idx)
752
+ if new_channel_track_idx not in new_channel_track_idxs:
753
+ new_channel_track_idxs.append(new_channel_track_idx)
754
+ if len(new_channel_track_idxs) == 0:
755
+ event[3] = 0
756
+ continue
757
+ c, nt = new_channel_track_idxs[0]
758
+ event[3] = nt
759
+ if c == 9:
760
+ event[4] = 7 # sf=0
761
+ for c, nt in new_channel_track_idxs[1:]:
762
+ new_event = [*event]
763
+ new_event[3] = nt
764
+ if c == 9:
765
+ new_event[4] = 7 # sf=0
766
+ key_signature_to_add.append(new_event)
767
+ elif name == "control_change" or name == "patch_change":
768
+ c = event[4]
769
+ event[4] = channels_map[c] # channel
770
+ tr_map = track_idx_map[c]
771
+ # move the event to the first track of the channel if its original track is empty
772
+ note_tracks = channel_note_tracks[c]
773
+ if len(note_tracks) != 0 and track_idx not in note_tracks:
774
+ track_idx = channel_note_tracks[c][0]
775
+ new_track_idx = tr_map.setdefault(track_idx, next(iter(tr_map.values())))
776
+ event[3] = new_track_idx
777
+ if name == "patch_change" and event[4] not in patch_channels:
778
+ patch_channels.append(event[4])
779
+ event_list += key_signature_to_add
780
+ track_to_channels = {}
781
+ for c, tr_map in track_idx_map.items():
782
+ if c not in channels_map:
783
+ continue
784
+ c = channels_map[c]
785
+ for _, track_idx in tr_map.items():
786
+ track_to_channels.setdefault(track_idx, [])
787
+ cs = track_to_channels[track_idx]
788
+ if c not in cs:
789
+ cs.append(c)
790
+
791
+ if add_default_instr:
792
+ for c in channels:
793
+ if c not in patch_channels and c in track_idx_dict:
794
+ event_list.append(["patch_change", 0, 0, track_idx_dict[c], c, 0])
795
+
796
+ if key_sig_num == 0:
797
+ # detect key signature.
798
+ root_key = self.detect_key_signature(note_key_hist)
799
+ if root_key is not None:
800
+ sf = self.key2sf(root_key, 0)
801
+ # print("detect_key_signature",sf)
802
+ for tr, cs in track_to_channels.items():
803
+ if remap_track_channel and tr == 0:
804
+ continue
805
+ event_list.append(["key_signature", 0, 0, tr, (0 if (len(cs) == 1 and cs[0] == 9) else sf) + 7, 0])
806
+
807
+ events_name_order = ["time_signature", "key_signature", "set_tempo", "patch_change", "control_change", "note"]
808
+ events_name_order = {name: i for i, name in enumerate(events_name_order)}
809
+ events_order = lambda e: e[1:4] + [events_name_order[e[0]]]
810
+ event_list = sorted(event_list, key=events_order)
811
+
812
+ setup_events = {}
813
+ notes_in_setup = False
814
+ for i, event in enumerate(event_list): # optimise setup
815
+ new_event = [*event] # make copy of event
816
+ if event[0] not in ["note", "time_signature"]:
817
+ new_event[1] = 0
818
+ new_event[2] = 0
819
+ has_next = False
820
+ has_pre = False
821
+ if i < len(event_list) - 1:
822
+ next_event = event_list[i + 1]
823
+ has_next = event[1] + event[2] == next_event[1] + next_event[2]
824
+ if notes_in_setup and i > 0:
825
+ pre_event = event_list[i - 1]
826
+ has_pre = event[1] + event[2] == pre_event[1] + pre_event[2]
827
+ if (event[0] == "note" and not has_next) or (notes_in_setup and not has_pre):
828
+ event_list = sorted(setup_events.values(), key=events_order) + event_list[i:]
829
+ break
830
+ else:
831
+ if event[0] == "note":
832
+ notes_in_setup = True
833
+ key = tuple(event[3:-1])
834
+ setup_events[key] = new_event
835
+
836
+ last_t1 = 0
837
+ midi_seq = []
838
+ for event in event_list:
839
+ if remove_empty_channels and event[0] in ["control_change", "patch_change"] and event[4] in empty_channels:
840
+ continue
841
+ cur_t1 = event[1]
842
+ event[1] = event[1] - last_t1
843
+ tokens = self.event2tokens(event)
844
+ if not tokens:
845
+ continue
846
+ midi_seq.append(tokens)
847
+ last_t1 = cur_t1
848
+
849
+ if add_bos_eos:
850
+ bos = [self.bos_id] + [self.pad_id] * (self.max_token_seq - 1)
851
+ eos = [self.eos_id] + [self.pad_id] * (self.max_token_seq - 1)
852
+ midi_seq = [bos] + midi_seq + [eos]
853
+ return midi_seq
854
+
855
+ def event2tokens(self, event):
856
+ name = event[0]
857
+ params = event[1:]
858
+ if not all([0 <= params[i] < self.event_parameters[p] for i, p in enumerate(self.events[name])]):
859
+ return []
860
+ tokens = [self.event_ids[name]] + [self.parameter_ids[p][params[i]]
861
+ for i, p in enumerate(self.events[name])]
862
+ tokens += [self.pad_id] * (self.max_token_seq - len(tokens))
863
+ return tokens
864
+
865
+ def tokens2event(self, tokens):
866
+ if tokens[0] not in self.id_events:
867
+ return []
868
+ name = self.id_events[tokens[0]]
869
+ if len(tokens) <= len(self.events[name]):
870
+ return []
871
+ params = tokens[1:]
872
+ params = [params[i] - self.parameter_ids[p][0] for i, p in enumerate(self.events[name])]
873
+ if not all([0 <= params[i] < self.event_parameters[p] for i, p in enumerate(self.events[name])]):
874
+ return []
875
+ event = [name] + params
876
+ return event
877
+
878
+ def detokenize(self, midi_seq):
879
+ ticks_per_beat = 480
880
+ tracks_dict = {}
881
+ t1 = 0
882
+ for tokens in midi_seq:
883
+ if tokens[0] in self.id_events:
884
+ event = self.tokens2event(tokens)
885
+ if not event:
886
+ continue
887
+ name = event[0]
888
+ t1 += event[1]
889
+ t = t1 * 16 + event[2]
890
+ t = int(t * ticks_per_beat / 16)
891
+ track_idx = event[3]
892
+ event_new = [name, t]
893
+ if name == "note":
894
+ c, p, v, d = event[4:]
895
+ d = int(d * ticks_per_beat / 16)
896
+ event_new += [d, c, p, v]
897
+ elif name == "control_change" or name == "patch_change":
898
+ event_new += event[4:]
899
+ elif name == "set_tempo":
900
+ event_new += [self.bpm2tempo(event[4])]
901
+ elif name == "time_signature":
902
+ nn, dd = event[4:]
903
+ nn += 1
904
+ dd += 1
905
+ event_new += [nn, dd, 24, 8] # usually cc, bb = 24, 8
906
+ elif name == "key_signature":
907
+ sf, mi = event[4:]
908
+ sf -= 7
909
+ event_new += [sf, mi]
910
+ else: # should not go here
911
+ continue
912
+ if track_idx not in tracks_dict:
913
+ tracks_dict[track_idx] = []
914
+ tracks_dict[track_idx].append(event_new)
915
+ tracks = [tr for idx, tr in sorted(list(tracks_dict.items()), key=lambda it: it[0])]
916
+
917
+ for i in range(len(tracks)): # to eliminate note overlap
918
+ track = tracks[i]
919
+ track = sorted(track, key=lambda e: e[1])
920
+ last_note_t = {}
921
+ zero_len_notes = []
922
+ for e in reversed(track):
923
+ if e[0] == "note":
924
+ t, d, c, p = e[1:5]
925
+ key = (c, p)
926
+ if key in last_note_t:
927
+ d = min(d, max(last_note_t[key] - t, 0))
928
+ last_note_t[key] = t
929
+ e[2] = d
930
+ if d == 0:
931
+ zero_len_notes.append(e)
932
+ for e in zero_len_notes:
933
+ track.remove(e)
934
+ tracks[i] = track
935
+ return [ticks_per_beat, *tracks]
936
+
937
+ def midi2img(self, midi_score):
938
+ ticks_per_beat = midi_score[0]
939
+ notes = []
940
+ max_time = 1
941
+ track_num = len(midi_score[1:])
942
+ for track_idx, track in enumerate(midi_score[1:]):
943
+ for event in track:
944
+ t = round(16 * event[1] / ticks_per_beat)
945
+ if event[0] == "note":
946
+ d = max(1, round(16 * event[2] / ticks_per_beat))
947
+ c, p = event[3:5]
948
+ max_time = max(max_time, t + d + 1)
949
+ notes.append((track_idx, c, p, t, d))
950
+ img = np.zeros((128, max_time, 3), dtype=np.uint8)
951
+ colors = {(i, j): np.random.randint(50, 256, 3) for i in range(track_num) for j in range(16)}
952
+ for note in notes:
953
+ tr, c, p, t, d = note
954
+ img[p, t: t + d] = colors[(tr, c)]
955
+ img = PIL.Image.fromarray(np.flip(img, 0))
956
+ return img
957
+
958
+ def augment(self, midi_seq, max_pitch_shift=4, max_vel_shift=10, max_cc_val_shift=10, max_bpm_shift=10,
959
+ max_track_shift=0, max_channel_shift=16):
960
+ pitch_shift = random.randint(-max_pitch_shift, max_pitch_shift)
961
+ vel_shift = random.randint(-max_vel_shift, max_vel_shift)
962
+ cc_val_shift = random.randint(-max_cc_val_shift, max_cc_val_shift)
963
+ bpm_shift = random.randint(-max_bpm_shift, max_bpm_shift)
964
+ track_shift = random.randint(0, max_track_shift)
965
+ channel_shift = random.randint(0, max_channel_shift)
966
+ midi_seq_new = []
967
+ key_signature_tokens = []
968
+ track_to_channels = {}
969
+ for tokens in midi_seq:
970
+ tokens_new = [*tokens]
971
+ if tokens[0] in self.id_events:
972
+ name = self.id_events[tokens[0]]
973
+ for i, pn in enumerate(self.events[name]):
974
+ if pn == "track":
975
+ tr = tokens[1 + i] - self.parameter_ids[pn][0]
976
+ tr += track_shift
977
+ tr = tr % self.event_parameters[pn]
978
+ tokens_new[1 + i] = self.parameter_ids[pn][tr]
979
+ elif pn == "channel":
980
+ c = tokens[1 + i] - self.parameter_ids[pn][0]
981
+ c0 = c
982
+ c += channel_shift
983
+ c = c % self.event_parameters[pn]
984
+ if c0 == 9:
985
+ c = 9
986
+ elif c == 9:
987
+ c = (9 + channel_shift) % self.event_parameters[pn]
988
+ tokens_new[1 + i] = self.parameter_ids[pn][c]
989
+
990
+ if name == "note":
991
+ tr = tokens[3] - self.parameter_ids["track"][0]
992
+ c = tokens[4] - self.parameter_ids["channel"][0]
993
+ p = tokens[5] - self.parameter_ids["pitch"][0]
994
+ v = tokens[6] - self.parameter_ids["velocity"][0]
995
+ if c != 9: # no shift for drums
996
+ p += pitch_shift
997
+ if not 0 <= p < 128:
998
+ return midi_seq
999
+ v += vel_shift
1000
+ v = max(1, min(127, v))
1001
+ tokens_new[5] = self.parameter_ids["pitch"][p]
1002
+ tokens_new[6] = self.parameter_ids["velocity"][v]
1003
+ track_to_channels.setdefault(tr, [])
1004
+ cs = track_to_channels[tr]
1005
+ if c not in cs:
1006
+ cs.append(c)
1007
+ elif name == "control_change":
1008
+ cc = tokens[5] - self.parameter_ids["controller"][0]
1009
+ val = tokens[6] - self.parameter_ids["value"][0]
1010
+ if cc in [1, 2, 7, 11]:
1011
+ val += cc_val_shift
1012
+ val = max(1, min(127, val))
1013
+ tokens_new[6] = self.parameter_ids["value"][val]
1014
+ elif name == "set_tempo":
1015
+ bpm = tokens[4] - self.parameter_ids["bpm"][0]
1016
+ bpm += bpm_shift
1017
+ bpm = max(1, min(383, bpm))
1018
+ tokens_new[4] = self.parameter_ids["bpm"][bpm]
1019
+ elif name == "key_signature":
1020
+ sf = tokens[4] - self.parameter_ids["sf"][0]
1021
+ mi = tokens[5] - self.parameter_ids["mi"][0]
1022
+ sf -= 7
1023
+ k = self.sf2key(sf)
1024
+ k = (k + pitch_shift) % 12
1025
+ sf = self.key2sf(k, mi)
1026
+ sf += 7
1027
+ tokens_new[4] = self.parameter_ids["sf"][sf]
1028
+ tokens_new[5] = self.parameter_ids["mi"][mi]
1029
+ key_signature_tokens.append(tokens_new)
1030
+ midi_seq_new.append(tokens_new)
1031
+ for tokens in key_signature_tokens:
1032
+ tr = tokens[3] - self.parameter_ids["track"][0]
1033
+ if tr in track_to_channels:
1034
+ cs = track_to_channels[tr]
1035
+ if len(cs) == 1 and cs[0] == 9:
1036
+ tokens[4] = self.parameter_ids["sf"][7] # sf=0
1037
+ return midi_seq_new
1038
+
1039
+ def check_quality(self, midi_seq, alignment_min=0.3, tonality_min=0.8, piano_max=0.7, notes_bandwidth_min=3,
1040
+ notes_density_max=50, notes_density_min=2.5, total_notes_max=20000, total_notes_min=256,
1041
+ note_window_size=16):
1042
+ total_notes = 0
1043
+ channels = []
1044
+ time_hist = [0] * 16
1045
+ note_windows = {}
1046
+ notes_sametime = []
1047
+ notes_density_list = []
1048
+ tonality_list = []
1049
+ notes_bandwidth_list = []
1050
+ instruments = {}
1051
+ piano_channels = []
1052
+ abs_t1 = 0
1053
+ last_t = 0
1054
+ for tsi, tokens in enumerate(midi_seq):
1055
+ event = self.tokens2event(tokens)
1056
+ if not event:
1057
+ continue
1058
+ t1, t2, tr = event[1:4]
1059
+ abs_t1 += t1
1060
+ t = abs_t1 * 16 + t2
1061
+ c = None
1062
+ if event[0] == "note":
1063
+ c, p, v, d = event[4:]
1064
+ total_notes += 1
1065
+ time_hist[t2] += 1
1066
+ if c != 9: # ignore drum channel
1067
+ if c not in instruments:
1068
+ instruments[c] = 0
1069
+ if c not in piano_channels:
1070
+ piano_channels.append(c)
1071
+ note_windows.setdefault(abs_t1 // note_window_size, []).append(p)
1072
+ if last_t != t:
1073
+ notes_sametime = [(et, p_) for et, p_ in notes_sametime if et > last_t]
1074
+ notes_sametime_p = [p_ for _, p_ in notes_sametime]
1075
+ if len(notes_sametime) > 0:
1076
+ notes_bandwidth_list.append(max(notes_sametime_p) - min(notes_sametime_p))
1077
+ notes_sametime.append((t + d - 1, p))
1078
+ elif event[0] == "patch_change":
1079
+ c, p = event[4:]
1080
+ instruments[c] = p
1081
+ if p == 0 and c not in piano_channels:
1082
+ piano_channels.append(c)
1083
+ if c is not None and c not in channels:
1084
+ channels.append(c)
1085
+ last_t = t
1086
+ reasons = []
1087
+ if total_notes < total_notes_min:
1088
+ reasons.append("total_min")
1089
+ if total_notes > total_notes_max:
1090
+ reasons.append("total_max")
1091
+ if len(note_windows) == 0 and total_notes > 0:
1092
+ reasons.append("drum_only")
1093
+ if reasons:
1094
+ return False, reasons
1095
+ time_hist = sorted(time_hist, reverse=True)
1096
+ alignment = sum(time_hist[:2]) / total_notes
1097
+ for notes in note_windows.values():
1098
+ key_hist = [0] * 12
1099
+ for p in notes:
1100
+ key_hist[p % 12] += 1
1101
+ key_hist = sorted(key_hist, reverse=True)
1102
+ tonality_list.append(sum(key_hist[:7]) / len(notes))
1103
+ notes_density_list.append(len(notes) / note_window_size)
1104
+ tonality_list = sorted(tonality_list)
1105
+ tonality = sum(tonality_list) / len(tonality_list)
1106
+ notes_bandwidth = sum(notes_bandwidth_list) / len(notes_bandwidth_list) if notes_bandwidth_list else 0
1107
  notes_density = max(notes_density_list) if notes_density_list else 0
1108
  piano_ratio = len(piano_channels) / len(channels)
1109
+ if len(channels) <= 3: # ignore piano threshold if it is a piano solo midi
1110
  piano_max = 1
1111
+ if alignment < alignment_min: # check whether the notes align to the bars (because some midi files are recorded)
1112
  reasons.append("alignment")
1113
  if tonality < tonality_min: # check whether the music is tonal
1114
  reasons.append("tonality")
 
1116
  reasons.append("bandwidth")
1117
  if not notes_density_min < notes_density < notes_density_max:
1118
  reasons.append("density")
1119
+ if piano_ratio > piano_max: # check whether most instruments are piano (because some midi files don't have instruments assigned correctly)
1120
  reasons.append("piano")
1121
  return not reasons, reasons
1122
+
1123
+
1124
+ class MIDITokenizer:
1125
+ def __new__(cls, version="v2"):
1126
+ if version == "v1":
1127
+ return MIDITokenizerV1()
1128
+ elif version == "v2":
1129
+ return MIDITokenizerV2()
1130
+ else:
1131
+ raise ValueError(f"Unsupported version: {version}")
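
A minimal end-to-end sketch of the new MIDITokenizer factory class added at the bottom of midi_tokenizer.py. It assumes the repo's MIDI.py module is available and uses an illustrative local file path; it is a usage sketch, not part of the committed code.

    import MIDI
    from midi_tokenizer import MIDITokenizer

    tokenizer = MIDITokenizer("v2")      # factory returns a MIDITokenizerV2 instance
    tokenizer.set_optimise_midi(True)    # tokenize() then defaults remap/default-instr/cleanup to True

    with open("example/song.mid", "rb") as f:   # illustrative path
        score = MIDI.midi2score(f.read())
    seq = tokenizer.tokenize(score)             # token rows, bos/eos included
    ok, reasons = tokenizer.check_quality(seq)
    print(len(seq), ok, reasons)
    with open("roundtrip.mid", "wb") as f:
        f.write(MIDI.score2midi(tokenizer.detokenize(seq)))
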