skytnt committed
Commit fd012a7
1 Parent(s): 81592e1
Files changed (3)
  1. app.py +58 -92
  2. midi_model.py +56 -16
  3. requirements.txt +3 -1
app.py CHANGED
@@ -1,79 +1,53 @@
 import argparse
 import glob
 import json
-import os.path
+import os
 import time
 
 import gradio as gr
 import numpy as np
-import onnxruntime as rt
+import torch
+import torch.nn.functional as F
 import tqdm
 from huggingface_hub import hf_hub_download
 
 import MIDI
+from midi_model import MIDIModel, MIDIModelConfig
 from midi_synthesizer import MidiSynthesizer
-from midi_tokenizer import MIDITokenizer
 
 MAX_SEED = np.iinfo(np.int32).max
 in_space = os.getenv("SYSTEM") == "spaces"
 
 
-def softmax(x, axis):
-    x_max = np.amax(x, axis=axis, keepdims=True)
-    exp_x_shifted = np.exp(x - x_max)
-    return exp_x_shifted / np.sum(exp_x_shifted, axis=axis, keepdims=True)
-
-
-def sample_top_p_k(probs, p, k, generator=None):
-    if generator is None:
-        generator = np.random
-    probs_idx = np.argsort(-probs, axis=-1)
-    probs_sort = np.take_along_axis(probs, probs_idx, -1)
-    probs_sum = np.cumsum(probs_sort, axis=-1)
-    mask = probs_sum - probs_sort > p
-    probs_sort[mask] = 0.0
-    mask = np.zeros(probs_sort.shape[-1])
-    mask[:k] = 1
-    probs_sort = probs_sort * mask
-    probs_sort /= np.sum(probs_sort, axis=-1, keepdims=True)
-    shape = probs_sort.shape
-    probs_sort_flat = probs_sort.reshape(-1, shape[-1])
-    probs_idx_flat = probs_idx.reshape(-1, shape[-1])
-    next_token = np.stack([generator.choice(idxs, p=pvals) for pvals, idxs in zip(probs_sort_flat, probs_idx_flat)])
-    next_token = next_token.reshape(*shape[:-1])
-    return next_token
-
-
-def generate(model, prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
+@torch.inference_mode()
+def generate(model: MIDIModel, prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
              disable_patch_change=False, disable_control_change=False, disable_channels=None, generator=None):
-    tokenizer = model[2]
+    tokenizer = model.tokenizer
     if disable_channels is not None:
         disable_channels = [tokenizer.parameter_ids["channel"][c] for c in disable_channels]
     else:
         disable_channels = []
-    if generator is None:
-        generator = np.random
     max_token_seq = tokenizer.max_token_seq
     if prompt is None:
-        input_tensor = np.full((1, max_token_seq), tokenizer.pad_id, dtype=np.int64)
+        input_tensor = torch.full((1, max_token_seq), tokenizer.pad_id, dtype=torch.long, device=model.device)
         input_tensor[0, 0] = tokenizer.bos_id  # bos
     else:
         prompt = prompt[:, :max_token_seq]
         if prompt.shape[-1] < max_token_seq:
             prompt = np.pad(prompt, ((0, 0), (0, max_token_seq - prompt.shape[-1])),
                             mode="constant", constant_values=tokenizer.pad_id)
-        input_tensor = prompt
-    input_tensor = input_tensor[None, :, :]
+        input_tensor = torch.from_numpy(prompt).to(dtype=torch.long, device=model.device)
+    input_tensor = input_tensor.unsqueeze(0)
     cur_len = input_tensor.shape[1]
-    bar = tqdm.tqdm(desc="generating", total=max_len - cur_len, disable=in_space)
+    bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
     with bar:
         while cur_len < max_len:
             end = False
-            hidden = model[0].run(None, {'x': input_tensor})[0][:, -1]
-            next_token_seq = np.empty((1, 0), dtype=np.int64)
+            hidden = model.forward(input_tensor)[0, -1].unsqueeze(0)
+            next_token_seq = None
            event_name = ""
             for i in range(max_token_seq):
-                mask = np.zeros(tokenizer.vocab_size, dtype=np.int64)
+                mask = torch.zeros(tokenizer.vocab_size, dtype=torch.int64, device=model.device)
                 if i == 0:
                     mask_ids = list(tokenizer.event_ids.values()) + [tokenizer.eos_id]
                     if disable_patch_change:
@@ -87,9 +61,9 @@ def generate(model, prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
                     if param_name == "channel":
                         mask_ids = [i for i in mask_ids if i not in disable_channels]
                     mask[mask_ids] = 1
-                logits = model[1].run(None, {'x': next_token_seq, "hidden": hidden})[0][:, -1:]
-                scores = softmax(logits / temp, -1) * mask
-                sample = sample_top_p_k(scores, top_p, top_k, generator)
+                logits = model.forward_token(hidden, next_token_seq)[:, -1:]
+                scores = torch.softmax(logits / temp, dim=-1) * mask
+                sample = model.sample_top_p_k(scores, top_p, top_k, generator=generator)
                 if i == 0:
                     next_token_seq = sample
                     eid = sample.item()
@@ -98,17 +72,17 @@ def generate(model, prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
                         break
                     event_name = tokenizer.id_events[eid]
                 else:
-                    next_token_seq = np.concatenate([next_token_seq, sample], axis=1)
+                    next_token_seq = torch.cat([next_token_seq, sample], dim=1)
                     if len(tokenizer.events[event_name]) == i:
                         break
             if next_token_seq.shape[1] < max_token_seq:
-                next_token_seq = np.pad(next_token_seq, ((0, 0), (0, max_token_seq - next_token_seq.shape[-1])),
-                                        mode="constant", constant_values=tokenizer.pad_id)
-            next_token_seq = next_token_seq[None, :, :]
-            input_tensor = np.concatenate([input_tensor, next_token_seq], axis=1)
+                next_token_seq = F.pad(next_token_seq, (0, max_token_seq - next_token_seq.shape[1]),
+                                       "constant", value=tokenizer.pad_id)
+            next_token_seq = next_token_seq.unsqueeze(1)
+            input_tensor = torch.cat([input_tensor, next_token_seq], dim=1)
             cur_len += 1
             bar.update(1)
-            yield next_token_seq.reshape(-1)
+            yield next_token_seq.reshape(-1).cpu().numpy()
             if end:
                 break
 
@@ -125,7 +99,7 @@ def run(model_name, tab, mid_seq, continuation_state, instruments, drum_kit, bpm
         reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels, seed, seed_rand,
         gen_events, temp, top_p, top_k, allow_cc):
     model = models[model_name]
-    tokenizer = model[2]
+    tokenizer = model.tokenizer
     bpm = int(bpm)
     if time_sig == "auto":
         time_sig = None
@@ -147,7 +121,7 @@ def run(model_name, tab, mid_seq, continuation_state, instruments, drum_kit, bpm
     max_len = gen_events
     if seed_rand:
         seed = np.random.randint(0, MAX_SEED)
-    generator = np.random.RandomState(seed)
+    generator = torch.Generator(opt.device).manual_seed(seed)
     disable_patch_change = False
     disable_channels = None
     if tab == 0:
@@ -203,22 +177,24 @@ def run(model_name, tab, mid_seq, continuation_state, instruments, drum_kit, bpm
     init_msgs += [create_msg("visualizer_clear", tokenizer.version),
                   create_msg("visualizer_append", events)]
     yield mid_seq, continuation_state, None, None, seed, send_msgs(init_msgs)
-    midi_generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
-                              disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
-                              disable_channels=disable_channels, generator=generator)
-    events = []
-    t = time.time() + 1
-    for i, token_seq in enumerate(midi_generator):
-        token_seq = token_seq.tolist()
-        mid_seq.append(token_seq)
-        events.append(tokenizer.tokens2event(token_seq))
-        ct = time.time()
-        if ct - t > 0.5:
-            yield (mid_seq, continuation_state, None, None, seed,
-                   send_msgs([create_msg("visualizer_append", events),
-                              create_msg("progress", [i + 1, gen_events])]))
-            t = ct
-            events = []
+    ctx = torch.amp.autocast(device_type=opt.device, dtype=torch.bfloat16, enabled=opt.device != "cpu")
+    with ctx:
+        midi_generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
+                                  disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
+                                  disable_channels=disable_channels, generator=generator)
+        events = []
+        t = time.time() + 1
+        for i, token_seq in enumerate(midi_generator):
+            token_seq = token_seq.tolist()
+            mid_seq.append(token_seq)
+            events.append(tokenizer.tokens2event(token_seq))
+            ct = time.time()
+            if ct - t > 0.5:
+                yield (mid_seq, continuation_state, None, None, seed,
+                       send_msgs([create_msg("visualizer_append", events),
                                  create_msg("progress", [i + 1, gen_events])]))
+                t = ct
+                events = []
 
     events = [tokenizer.tokens2event(tokens) for tokens in mid_seq]
     mid = tokenizer.detokenize(mid_seq)
@@ -235,7 +211,7 @@ def run(model_name, tab, mid_seq, continuation_state, instruments, drum_kit, bpm
 def cancel_run(model_name, mid_seq):
     if mid_seq is None:
         return None, None, []
-    tokenizer = models[model_name][2]
+    tokenizer = models[model_name].tokenizer
     events = [tokenizer.tokens2event(tokens) for tokens in mid_seq]
     mid = tokenizer.detokenize(mid_seq)
     audio = synthesizer.synthesis(MIDI.score2opus(mid))
@@ -248,11 +224,12 @@ def cancel_run(model_name, mid_seq):
     return "output.mid", (44100, audio), send_msgs(end_msgs)
 
 
-def undo_continuation(mid_seq, continuation_state):
+def undo_continuation(model_name, mid_seq, continuation_state):
     if mid_seq is None or len(continuation_state) < 2:
         return mid_seq, continuation_state, send_msgs([])
     mid_seq = mid_seq[:continuation_state[-1]]
     continuation_state = continuation_state[:-1]
+    tokenizer = models[model_name].tokenizer
     events = [tokenizer.tokens2event(tokens) for tokens in mid_seq]
     end_msgs = [create_msg("visualizer_clear", tokenizer.version),
                 create_msg("visualizer_append", events),
@@ -293,21 +270,6 @@ def hf_hub_download_retry(repo_id, filename):
         raise err
 
 
-def get_tokenizer(config_name):
-    tv, size = config_name.split("-")
-    tv = tv[1:]
-    if tv[-1] == "o":
-        o = True
-        tv = tv[:-1]
-    else:
-        o = False
-    if tv not in ["v1", "v2"]:
-        raise ValueError(f"Unknown tokenizer version {tv}")
-    tokenizer = MIDITokenizer(tv)
-    tokenizer.set_optimise_midi(o)
-    return tokenizer
-
-
 number2drum_kits = {-1: "None", 0: "Standard", 8: "Room", 16: "Power", 24: "Electric", 25: "TR-808", 32: "Jazz",
                     40: "Blush", 48: "Orchestra"}
 patch2number = {v: k for k, v in MIDI.Number2patch.items()}
@@ -319,6 +281,7 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
     parser.add_argument("--port", type=int, default=7860, help="gradio server port")
+    parser.add_argument("--device", type=str, default="cuda", help="device to run model")
     parser.add_argument("--max-gen", type=int, default=1024, help="max")
     opt = parser.parse_args()
     soundfont_path = hf_hub_download_retry(repo_id="skytnt/midi-model", filename="soundfont.sf2")
@@ -331,14 +294,17 @@ if __name__ == "__main__":
         "touhou finetune model (tv1-medium) by skytnt": ["skytnt/midi-model-ft", "touhou/", "tv1-medium"],
     }
     models = {}
-    providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
+    if opt.device == "cuda":
+        torch.backends.cuda.enable_mem_efficient_sdp(True)
+        torch.backends.cuda.enable_flash_sdp(True)
     for name, (repo_id, path, config) in models_info.items():
-        model_base_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}onnx/model_base.onnx")
-        model_token_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}onnx/model_token.onnx")
-        model_base = rt.InferenceSession(model_base_path, providers=providers)
-        model_token = rt.InferenceSession(model_token_path, providers=providers)
-        tokenizer = get_tokenizer(config)
-        models[name] = [model_base, model_token, tokenizer]
+        model_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}model.ckpt")
+        model = MIDIModel(config=MIDIModelConfig.from_name(config))
+        ckpt = torch.load(model_path, map_location="cpu")
+        state_dict = ckpt.get("state_dict", ckpt)
+        model.load_state_dict(state_dict, strict=False)
+        model.to(device=opt.device, dtype=torch.bfloat16 if opt.device == "cuda" else torch.float32).eval()
+        models[name] = model
 
     load_javascript()
     app = gr.Blocks()
@@ -447,6 +413,6 @@ if __name__ == "__main__":
     stop_btn.click(cancel_run, [input_model, output_midi_seq],
                    [output_midi, output_audio, js_msg],
                    cancels=run_event, queue=False)
-    undo_btn.click(undo_continuation, [output_midi_seq, output_continuation_state],
+    undo_btn.click(undo_continuation, [input_model, output_midi_seq, output_continuation_state],
                    [output_midi_seq, output_continuation_state, js_msg], queue=False)
     app.launch(server_port=opt.port, share=opt.share, inbrowser=True)
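
Two behavioural details of the new inference path are easy to miss in the hunks above: reproducibility now comes from a device-local torch.Generator seeded per request (replacing np.random.RandomState), and generation runs under bfloat16 autocast whenever the device is not CPU. A minimal standalone sketch of that pattern, using an illustrative fixed seed and toy logits rather than real model output:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
seed = 42  # illustrative; app.py draws a fresh seed when seed_rand is enabled
generator = torch.Generator(device).manual_seed(seed)

# Mirrors `enabled=opt.device != "cpu"` above: autocast becomes a no-op on CPU.
ctx = torch.amp.autocast(device_type=device, dtype=torch.bfloat16, enabled=device != "cpu")
with ctx:
    logits = torch.randn(1, 16, device=device)                 # stand-in for model logits
    probs = torch.softmax(logits / 1.0, dim=-1)                # temperature of 1.0
    token = torch.multinomial(probs, 1, generator=generator)   # same seed, same token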
midi_model.py CHANGED
@@ -1,3 +1,5 @@
+from typing import Union
+
 import numpy as np
 import torch
 import torch.nn as nn
@@ -5,23 +7,61 @@ import torch.nn.functional as F
 import tqdm
 from transformers import LlamaModel, LlamaConfig
 
-from midi_tokenizer import MIDITokenizer
+from midi_tokenizer import MIDITokenizerV1, MIDITokenizerV2, MIDITokenizer
+
+config_name_list = ["tv1-medium", "tv2-medium", "tv2o-medium", "tv2-large", "tv2o-large"]
+
+
+class MIDIModelConfig:
+    def __init__(self, tokenizer: Union[MIDITokenizerV1, MIDITokenizerV2],
+                 net_config: LlamaConfig, net_token_config: LlamaConfig):
+        self.tokenizer = tokenizer
+        self.net_config = net_config
+        self.net_token_config = net_token_config
+        self.n_embd = net_token_config.hidden_size
+
+    @staticmethod
+    def get_config(tokenizer_ver="v2", optimise_midi=True, n_layer=12, n_head=16, n_embd=1024, n_inner=4096):
+        tokenizer = MIDITokenizer(tokenizer_ver)
+        tokenizer.set_optimise_midi(optimise_midi)
+        net_config = LlamaConfig(vocab_size=tokenizer.vocab_size,
+                                 hidden_size=n_embd, num_attention_heads=n_head,
+                                 num_hidden_layers=n_layer, intermediate_size=n_inner,
+                                 pad_token_id=tokenizer.pad_id, max_position_embeddings=4096)
+        net_token_config = LlamaConfig(vocab_size=tokenizer.vocab_size,
+                                       hidden_size=n_embd, num_attention_heads=n_head // 4,
+                                       num_hidden_layers=n_layer // 4, intermediate_size=n_inner // 4,
+                                       pad_token_id=tokenizer.pad_id, max_position_embeddings=4096)
+        return MIDIModelConfig(tokenizer, net_config, net_token_config)
+
+    @staticmethod
+    def from_name(name="tv2o-medium"):
+        tv, size = name.split("-")
+        tv = tv[1:]
+        if tv[-1] == "o":
+            o = True
+            tv = tv[:-1]
+        else:
+            o = False
+        if tv not in ["v1", "v2"]:
+            raise ValueError(f"Unknown tokenizer version {tv}")
+        if size == "medium":
+            return MIDIModelConfig.get_config(tokenizer_ver=tv, optimise_midi=o,
+                                              n_layer=12, n_head=16, n_embd=1024, n_inner=4096)
+        elif size == "large":
+            return MIDIModelConfig.get_config(tokenizer_ver=tv, optimise_midi=o,
+                                              n_layer=24, n_head=16, n_embd=1024, n_inner=4096)
+        else:
+            raise ValueError(f"Unknown model size {size}")
 
 
 class MIDIModel(nn.Module):
-    def __init__(self, tokenizer: MIDITokenizer, n_layer=12, n_head=16, n_embd=1024, n_inner=4096,
-                 *args, **kwargs):
+    def __init__(self, config: MIDIModelConfig, *args, **kwargs):
         super(MIDIModel, self).__init__()
-        self.tokenizer = tokenizer
-        self.net = LlamaModel(LlamaConfig(vocab_size=tokenizer.vocab_size,
-                                          hidden_size=n_embd, num_attention_heads=n_head,
-                                          num_hidden_layers=n_layer, intermediate_size=n_inner,
-                                          pad_token_id=tokenizer.pad_id, max_position_embeddings=4096))
-        self.net_token = LlamaModel(LlamaConfig(vocab_size=tokenizer.vocab_size,
-                                                hidden_size=n_embd, num_attention_heads=n_head // 4,
-                                                num_hidden_layers=n_layer // 4, intermediate_size=n_inner // 4,
-                                                pad_token_id=tokenizer.pad_id, max_position_embeddings=4096))
-        self.lm_head = nn.Linear(n_embd, tokenizer.vocab_size, bias=False)
+        self.tokenizer = config.tokenizer
+        self.net = LlamaModel(config.net_config)
+        self.net_token = LlamaModel(config.net_token_config)
+        self.lm_head = nn.Linear(config.n_embd, self.tokenizer.vocab_size, bias=False)
         self.device = "cpu"
 
     def to(self, *args, **kwargs):
@@ -71,7 +111,7 @@ class MIDIModel(nn.Module):
         return next_token
 
     @torch.inference_mode()
-    def generate(self, prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20, amp=True, generator=None):
+    def generate(self, prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20, generator=None):
         tokenizer = self.tokenizer
         max_token_seq = tokenizer.max_token_seq
         if prompt is None:
@@ -86,7 +126,7 @@ class MIDIModel(nn.Module):
         input_tensor = input_tensor.unsqueeze(0)
         cur_len = input_tensor.shape[1]
         bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
-        with bar, torch.cuda.amp.autocast(enabled=amp):
+        with bar:
             while cur_len < max_len:
                 end = False
                 hidden = self.forward(input_tensor)[0, -1].unsqueeze(0)
@@ -123,4 +163,4 @@ class MIDIModel(nn.Module):
                 bar.update(1)
                 if end:
                     break
-        return input_tensor[0].cpu().numpy()
+        return input_tensor[0].cpu().numpy()
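
With the constructor change above, a model is now built from a named config instead of a bare tokenizer plus layer counts. A short usage sketch mirroring the loading sequence in app.py; the local "model.ckpt" path is a placeholder for the file fetched via hf_hub_download:

import torch
from midi_model import MIDIModel, MIDIModelConfig

config = MIDIModelConfig.from_name("tv2o-medium")    # tokenizer v2, optimise_midi on, 12-layer net
model = MIDIModel(config=config)

ckpt = torch.load("model.ckpt", map_location="cpu")  # placeholder path
state_dict = ckpt.get("state_dict", ckpt)            # accepts Lightning-style or raw state dicts
model.load_state_dict(state_dict, strict=False)
model.to(device="cpu", dtype=torch.float32).eval()   # app.py uses cuda + bfloat16 when available

from_name accepts the five names in config_name_list; a trailing "o" in the version part switches the tokenizer to set_optimise_midi(True), and "large" doubles n_layer from 12 to 24.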
requirements.txt CHANGED
@@ -1,6 +1,8 @@
+--extra-index-url https://download.pytorch.org/whl/cu124
 Pillow
 numpy
-onnxruntime-gpu
+torch
+transformers>=4.36
 gradio==4.43.0
 pyfluidsynth
 tqdm
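
The new first line points pip at the CUDA 12.4 wheel index, so a plain `pip install -r requirements.txt` should resolve torch to a cu124 build where one exists for the platform; CPU-only deployments can drop that line (and pass `--device cpu` to app.py) to fall back to the default PyPI wheels.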