skytnt committed on
Commit: 2f27e32
Parent: b37778c
Files changed (4)
  1. app.py +9 -14
  2. app_onnx.py +12 -19
  3. midi_model.py +53 -23
  4. midi_tokenizer.py +29 -0
app.py CHANGED
@@ -365,19 +365,19 @@ if __name__ == "__main__":
     synthesizer = MidiSynthesizer(soundfont_path)
     models_info = {
         "generic pretrain model (tv2o-medium) by skytnt": [
-            "skytnt/midi-model-tv2o-medium", "", "tv2o-medium", {
+            "skytnt/midi-model-tv2o-medium", {
                 "jpop": "skytnt/midi-model-tv2om-jpop-lora",
                 "touhou": "skytnt/midi-model-tv2om-touhou-lora"
             }
         ],
         "generic pretrain model (tv2o-large) by asigalov61": [
-            "asigalov61/Music-Llama", "", "tv2o-large", {}
+            "asigalov61/Music-Llama", {}
         ],
         "generic pretrain model (tv2o-medium) by asigalov61": [
-            "asigalov61/Music-Llama-Medium", "", "tv2o-medium", {}
+            "asigalov61/Music-Llama-Medium", {}
         ],
         "generic pretrain model (tv1-medium) by skytnt": [
-            "skytnt/midi-model", "", "tv1-medium", {}
+            "skytnt/midi-model", {}
         ]
     }
     models = {}
@@ -388,20 +388,15 @@ if __name__ == "__main__":
     torch.backends.cudnn.allow_tf32 = True
     torch.backends.cuda.enable_mem_efficient_sdp(True)
     torch.backends.cuda.enable_flash_sdp(True)
-    for name, (repo_id, path, config, loras) in models_info.items():
-        model_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}model.ckpt")
-        model = MIDIModel(config=MIDIModelConfig.from_name(config))
-        ckpt = torch.load(model_path, map_location="cpu", weights_only=True)
-        state_dict = ckpt.get("state_dict", ckpt)
-        model.load_state_dict(state_dict, strict=False)
-        model.to(device="cpu", dtype=torch.float32).eval()
+    for name, (repo_id, loras) in models_info.items():
+        model = MIDIModel.from_pretrained(repo_id)
+        model.to(device="cpu", dtype=torch.float32)
         models[name] = model
         for lora_name, lora_repo in loras.items():
-            model = MIDIModel(config=MIDIModelConfig.from_name(config))
-            model.load_state_dict(state_dict, strict=False)
+            model = MIDIModel.from_pretrained(repo_id)
             print(f"loading lora {lora_repo} for {name}")
             model = model.load_merge_lora(lora_repo)
-            model.to(device="cpu", dtype=torch.float32).eval()
+            model.to(device="cpu", dtype=torch.float32)
             models[f"{name} with {lora_name} lora"] = model
 
     load_javascript()
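With MIDIModel now a transformers PreTrainedModel (see midi_model.py below), the manual checkpoint download and state_dict handling collapse into a single from_pretrained call. A minimal sketch of the new loading flow, assuming the listed repos ship weights and config.json in the layout from_pretrained expects:

import torch
from midi_model import MIDIModel

# Downloads config.json + weights from the Hub and builds the model.
model = MIDIModel.from_pretrained("skytnt/midi-model-tv2o-medium")
model.to(device="cpu", dtype=torch.float32)

# A LoRA variant starts from a fresh copy of the base weights,
# then merges the adapter into them:
lora_model = MIDIModel.from_pretrained("skytnt/midi-model-tv2o-medium")
lora_model = lora_model.load_merge_lora("skytnt/midi-model-tv2om-jpop-lora")
lora_model.to(device="cpu", dtype=torch.float32)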
app_onnx.py CHANGED
@@ -432,18 +432,12 @@ def hf_hub_download_retry(repo_id, filename):
     raise err
 
 
-def get_tokenizer(config_name):
-    tv, size = config_name.split("-")
-    tv = tv[1:]
-    if tv[-1] == "o":
-        o = True
-        tv = tv[:-1]
-    else:
-        o = False
-    if tv not in ["v1", "v2"]:
-        raise ValueError(f"Unknown tokenizer version {tv}")
-    tokenizer = MIDITokenizer(tv)
-    tokenizer.set_optimise_midi(o)
+def get_tokenizer(repo_id):
+    config_path = hf_hub_download_retry(repo_id=repo_id, filename=f"config.json")
+    with open(config_path, "r") as f:
+        config = json.load(f)
+    tokenizer = MIDITokenizer(config["tokenizer"]["version"])
+    tokenizer.set_optimise_midi(config["tokenizer"]["optimise_midi"])
     return tokenizer
 
 
@@ -468,34 +462,33 @@ if __name__ == "__main__":
     synthesizer = MidiSynthesizer(soundfont_path)
     models_info = {
         "generic pretrain model (tv2o-medium) by skytnt": [
-            "skytnt/midi-model-tv2o-medium", "", "tv2o-medium", {
+            "skytnt/midi-model-tv2o-medium", "", {
                 "jpop": "skytnt/midi-model-tv2om-jpop-lora",
                 "touhou": "skytnt/midi-model-tv2om-touhou-lora"
             }
         ],
         "generic pretrain model (tv2o-large) by asigalov61": [
-            "asigalov61/Music-Llama", "", "tv2o-large", {}
+            "asigalov61/Music-Llama", "", {}
         ],
         "generic pretrain model (tv2o-medium) by asigalov61": [
-            "asigalov61/Music-Llama-Medium", "", "tv2o-medium", {}
+            "asigalov61/Music-Llama-Medium", "", {}
         ],
         "generic pretrain model (tv1-medium) by skytnt": [
-            "skytnt/midi-model", "", "tv1-medium", {}
+            "skytnt/midi-model", "", {}
         ]
     }
     models = {}
     providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
     device = "cuda"
 
-    for name, (repo_id, path, config, loras) in models_info.items():
+    for name, (repo_id, path, loras) in models_info.items():
         model_base_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}onnx/model_base.onnx")
         model_token_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}onnx/model_token.onnx")
-        tokenizer = get_tokenizer(config)
+        tokenizer = get_tokenizer(repo_id)
         models[name] = [model_base_path, model_token_path, tokenizer]
         for lora_name, lora_repo in loras.items():
            model_base_path = hf_hub_download_retry(repo_id=lora_repo, filename=f"onnx/model_base.onnx")
            model_token_path = hf_hub_download_retry(repo_id=lora_repo, filename=f"onnx/model_token.onnx")
-            tokenizer = get_tokenizer(config)
             models[f"{name} with {lora_name} lora"] = [model_base_path, model_token_path, tokenizer]
 
     load_javascript()
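get_tokenizer no longer parses a config name like "tv2o-medium"; it reads the "tokenizer" block that MIDIModelConfig.to_dict() (midi_model.py below) writes into each repo's config.json. Note the LoRA loop also stops rebuilding the tokenizer per adapter and reuses the base model's. A sketch of the new behavior, assuming the repo's config.json carries that block:

# Downloads config.json from the repo and rebuilds the tokenizer from
# its "tokenizer" entry (version + optimise_midi).
tokenizer = get_tokenizer("skytnt/midi-model-tv2o-medium")
print(tokenizer.version, tokenizer.optimise_midi)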
midi_model.py CHANGED
@@ -1,4 +1,5 @@
-from typing import Union
+import json
+from typing import Union, Dict, Any
 
 import numpy as np
 import torch
@@ -6,21 +7,57 @@ import torch.nn as nn
 import torch.nn.functional as F
 import tqdm
 from peft import PeftConfig, LoraModel, load_peft_weights, set_peft_model_state_dict
-from transformers import LlamaModel, LlamaConfig, DynamicCache
-from transformers.integrations import PeftAdapterMixin
+from transformers import LlamaModel, LlamaConfig, DynamicCache, PretrainedConfig, PreTrainedModel
 
 from midi_tokenizer import MIDITokenizerV1, MIDITokenizerV2, MIDITokenizer
 
 config_name_list = ["tv1-medium", "tv2-medium", "tv2o-medium", "tv2-large", "tv2o-large"]
 
 
-class MIDIModelConfig:
-    def __init__(self, tokenizer: Union[MIDITokenizerV1, MIDITokenizerV2],
-                 net_config: LlamaConfig, net_token_config: LlamaConfig):
-        self.tokenizer = tokenizer
-        self.net_config = net_config
-        self.net_token_config = net_token_config
-        self.n_embd = net_token_config.hidden_size
+class MIDIModelConfig(PretrainedConfig):
+    model_type = "midi_model"
+
+    def __init__(self,
+                 tokenizer: Union[MIDITokenizerV1, MIDITokenizerV2, Dict]=None,
+                 net_config: Union[LlamaConfig, Dict]=None,
+                 net_token_config: Union[LlamaConfig, Dict]=None,
+                 **kwargs):
+        super().__init__(**kwargs)
+        if tokenizer:
+            if isinstance(tokenizer, dict):
+                self.tokenizer = MIDITokenizer(tokenizer["version"])
+                self.tokenizer.set_optimise_midi(tokenizer["optimise_midi"])
+            else:
+                self.tokenizer = tokenizer
+        else:
+            self.tokenizer = MIDITokenizer()
+        if net_config:
+            if isinstance(net_config, dict):
+                self.net_config = LlamaConfig(**net_config)
+            else:
+                self.net_config = net_config
+        else:
+            self.net_config = LlamaConfig()
+        if net_token_config:
+            if isinstance(net_token_config, dict):
+                self.net_token_config = LlamaConfig(**net_token_config)
+            else:
+                self.net_token_config = net_token_config
+        else:
+            self.net_token_config = LlamaConfig()
+        self.n_embd = self.net_token_config.hidden_size
+
+    def to_dict(self) -> Dict[str, Any]:
+        d = super().to_dict()
+        d["tokenizer"] = self.tokenizer.to_dict()
+        return d
+
+    def __str__(self):
+        d = {
+            "net": self.net_config.to_json_string(use_diff=False),
+            "net_token": self.net_token_config.to_json_string(use_diff=False)
+        }
+        return json.dumps(d, indent=4)
 
     @staticmethod
     def get_config(tokenizer_ver="v2", optimise_midi=True, n_layer=12, n_head=16, n_embd=1024, n_inner=4096):
@@ -59,27 +96,20 @@ class MIDIModelConfig:
         raise ValueError(f"Unknown model size {size}")
 
 
-class MIDIModel(nn.Module, PeftAdapterMixin):
+class MIDIModel(PreTrainedModel):
+    config_class = MIDIModelConfig
+
     def __init__(self, config: MIDIModelConfig, *args, **kwargs):
-        super(MIDIModel, self).__init__()
+        super(MIDIModel, self).__init__(config, *args, **kwargs)
         self.tokenizer = config.tokenizer
         self.net = LlamaModel(config.net_config)
         self.net_token = LlamaModel(config.net_token_config)
         self.lm_head = nn.Linear(config.n_embd, self.tokenizer.vocab_size, bias=False)
-        self.device = "cpu"
-
-    def to(self, *args, **kwargs):
-        if "device" in kwargs:
-            self.device = kwargs["device"]
-        return super(MIDIModel, self).to(*args, **kwargs)
-
-    def peft_loaded(self):
-        return self._hf_peft_config_loaded
 
     def load_merge_lora(self, model_id):
         peft_config = PeftConfig.from_pretrained(model_id)
         model = LoraModel(self, peft_config, adapter_name="default")
-        adapter_state_dict = load_peft_weights(model_id, device=self.device)
+        adapter_state_dict = load_peft_weights(model_id, device=str(self.device))
         set_peft_model_state_dict(self, adapter_state_dict, "default")
         return model.merge_and_unload()
 
@@ -164,7 +194,7 @@ class MIDIModel(nn.Module, PeftAdapterMixin):
         with bar:
             while cur_len < max_len:
                 end = [False] * batch_size
-                hidden = self.forward(input_tensor[:,past_len:], cache=cache1)[:, -1]
+                hidden = self.forward(input_tensor[:, past_len:], cache=cache1)[:, -1]
                 next_token_seq = None
                 event_names = [""] * batch_size
                 cache2 = DynamicCache()
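Subclassing PretrainedConfig and PreTrainedModel is what enables the from_pretrained calls above: config serialization, weight loading, and the device property now come from transformers, which is why the hand-rolled to()/self.device tracking could be deleted. A round-trip sketch; the local path is illustrative, and it assumes get_config (kept from the existing code) returns a MIDIModelConfig:

config = MIDIModelConfig.get_config(tokenizer_ver="v2", optimise_midi=True)
model = MIDIModel(config)
model.save_pretrained("./my-midi-model")  # config.json includes the "tokenizer" dict from to_dict()
reloaded = MIDIModel.from_pretrained("./my-midi-model")
# On reload, MIDIModelConfig.__init__ receives tokenizer / net_config /
# net_token_config as plain dicts and rebuilds the MIDITokenizer and
# LlamaConfig objects from them.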
midi_tokenizer.py CHANGED
@@ -1,4 +1,5 @@
 import random
+from typing import Dict, Any
 
 import PIL.Image
 import numpy as np
@@ -33,6 +34,20 @@ class MIDITokenizerV1:
         self.parameter_ids = {p: allocate_ids(s) for p, s in self.event_parameters.items()}
         self.max_token_seq = max([len(ps) for ps in self.events.values()]) + 1
 
+    def to_dict(self) -> Dict[str, Any]:
+        d = {
+            "version":self.version,
+            "optimise_midi":self.optimise_midi,
+            "vocab_size": self.vocab_size,
+            "events": self.events,
+            "event_parameters": self.event_parameters,
+            "max_token_seq": self.max_token_seq,
+            "pad_id": self.pad_id,
+            "bos_id": self.bos_id,
+            "eos_id": self.eos_id,
+        }
+        return d
+
     def set_optimise_midi(self, optimise_midi=True):
         self.optimise_midi = optimise_midi
 
@@ -519,6 +534,20 @@ class MIDITokenizerV2:
         self.parameter_ids = {p: allocate_ids(s) for p, s in self.event_parameters.items()}
         self.max_token_seq = max([len(ps) for ps in self.events.values()]) + 1
 
+    def to_dict(self) -> Dict[str, Any]:
+        d = {
+            "version":self.version,
+            "optimise_midi":self.optimise_midi,
+            "vocab_size": self.vocab_size,
+            "events": self.events,
+            "event_parameters": self.event_parameters,
+            "max_token_seq": self.max_token_seq,
+            "pad_id": self.pad_id,
+            "bos_id": self.bos_id,
+            "eos_id": self.eos_id,
+        }
+        return d
+
     def set_optimise_midi(self, optimise_midi=True):
         self.optimise_midi = optimise_midi
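to_dict() is the serialization half of the new config plumbing: its output lands under "tokenizer" in config.json via MIDIModelConfig.to_dict(), and only version and optimise_midi are read back when reconstructing; the remaining keys are derived values recorded for inspection. A round-trip sketch:

tok = MIDITokenizer("v2")
tok.set_optimise_midi(True)
d = tok.to_dict()
restored = MIDITokenizer(d["version"])
restored.set_optimise_midi(d["optimise_midi"])
assert restored.vocab_size == d["vocab_size"]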