import json
from typing import Union, Dict, Any

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from peft import PeftConfig, LoraModel, load_peft_weights, set_peft_model_state_dict
from transformers import LlamaModel, LlamaConfig, DynamicCache, PretrainedConfig, PreTrainedModel

from midi_tokenizer import MIDITokenizerV1, MIDITokenizerV2, MIDITokenizer

# Known preset names, in the "t<tokenizer version>[o]-<size>" form that
# MIDIModelConfig.from_name parses.
config_name_list = ["tv1-medium", "tv2-medium", "tv2o-medium", "tv2-large", "tv2o-large"]


class MIDIModelConfig(PretrainedConfig):
    """Configuration bundling a MIDI tokenizer with two Llama configs.

    ``net_config`` configures the network run over the sequence of MIDI
    events and ``net_token_config`` configures the network run over the
    tokens inside one event (see ``MIDIModel``).
    """

    model_type = "midi_model"

    def __init__(self,
                 tokenizer: Union[MIDITokenizerV1, MIDITokenizerV2, Dict] = None,
                 net_config: Union[LlamaConfig, Dict] = None,
                 net_token_config: Union[LlamaConfig, Dict] = None,
                 **kwargs):
        """Build the config from objects or from their dict forms.

        :param tokenizer: tokenizer instance, or a dict with the keys
            "version" and "optimise_midi"; a falsy value creates a default
            ``MIDITokenizer()``
        :param net_config: ``LlamaConfig`` or a dict of its keyword arguments;
            a falsy value creates a default ``LlamaConfig()``
        :param net_token_config: same as ``net_config`` for the token network
        :param kwargs: forwarded to ``PretrainedConfig.__init__``
        """
        super().__init__(**kwargs)
        if tokenizer:
            if isinstance(tokenizer, dict):
                self.tokenizer = MIDITokenizer(tokenizer["version"])
                self.tokenizer.set_optimise_midi(tokenizer["optimise_midi"])
            else:
                self.tokenizer = tokenizer
        else:
            self.tokenizer = MIDITokenizer()

        if net_config:
            if isinstance(net_config, dict):
                self.net_config = LlamaConfig(**net_config)
            else:
                self.net_config = net_config
        else:
            self.net_config = LlamaConfig()

        if net_token_config:
            if isinstance(net_token_config, dict):
                self.net_token_config = LlamaConfig(**net_token_config)
            else:
                self.net_token_config = net_token_config
        else:
            self.net_token_config = LlamaConfig()

        # The LM head in MIDIModel is sized from the token network's width.
        self.n_embd = self.net_token_config.hidden_size

    def to_dict(self) -> Dict[str, Any]:
        """Return the config as a dict, with the tokenizer in its dict form."""
        d = super().to_dict()
        d["tokenizer"] = self.tokenizer.to_dict()
        return d

    def __str__(self):
        """Return a JSON string holding the full JSON of both Llama configs."""
        d = {
            "net": self.net_config.to_json_string(use_diff=False),
            "net_token": self.net_token_config.to_json_string(use_diff=False)
        }
        return json.dumps(d, indent=4)

    @staticmethod
    def get_config(tokenizer_ver="v2", optimise_midi=True, n_layer=12, n_head=16, n_embd=1024, n_inner=4096):
        """Create a config from a handful of size hyper-parameters.

        The token network uses the same hidden size as the main network but a
        quarter of its heads, layers and intermediate size.

        :param tokenizer_ver: version string passed to ``MIDITokenizer``
        :param optimise_midi: value passed to ``tokenizer.set_optimise_midi``
        :param n_layer: number of hidden layers of the main network
        :param n_head: number of attention heads of the main network
        :param n_embd: hidden size of both networks
        :param n_inner: intermediate size of the main network
        :return: MIDIModelConfig
        """
        tokenizer = MIDITokenizer(tokenizer_ver)
        tokenizer.set_optimise_midi(optimise_midi)
        net_config = LlamaConfig(vocab_size=tokenizer.vocab_size,
                                 hidden_size=n_embd, num_attention_heads=n_head,
                                 num_hidden_layers=n_layer, intermediate_size=n_inner,
                                 pad_token_id=tokenizer.pad_id, max_position_embeddings=4096,
                                 use_cache=False)
        net_token_config = LlamaConfig(vocab_size=tokenizer.vocab_size,
                                       hidden_size=n_embd, num_attention_heads=n_head // 4,
                                       num_hidden_layers=n_layer // 4, intermediate_size=n_inner // 4,
                                       pad_token_id=tokenizer.pad_id, max_position_embeddings=4096,
                                       use_cache=False)
        return MIDIModelConfig(tokenizer, net_config, net_token_config)

    @staticmethod
    def from_name(name="tv2o-medium"):
        """Create a config from a preset name such as "tv2o-medium".

        The name is "<prefix char><tokenizer version>[o]-<size>": the first
        character is dropped, a trailing "o" enables ``optimise_midi``, and
        size is "medium" (12 layers) or "large" (24 layers).

        :param name: preset name
        :return: MIDIModelConfig
        :raises ValueError: if the name does not have exactly one "-", or the
            tokenizer version or the size is unknown
        """
        tv, size = name.split("-")
        tv = tv[1:]
        # endswith() (rather than tv[-1]) keeps an empty version string on the
        # ValueError path below instead of raising IndexError.
        o = tv.endswith("o")
        if o:
            tv = tv[:-1]
        if tv not in ["v1", "v2"]:
            raise ValueError(f"Unknown tokenizer version {tv}")
        # The presets differ only in depth.
        layers_by_size = {"medium": 12, "large": 24}
        if size not in layers_by_size:
            raise ValueError(f"Unknown model size {size}")
        return MIDIModelConfig.get_config(tokenizer_ver=tv, optimise_midi=o,
                                          n_layer=layers_by_size[size], n_head=16,
                                          n_embd=1024, n_inner=4096)


class MIDIModel(PreTrainedModel):
    """Two-level Llama model: ``net`` over MIDI events, ``net_token`` over the
    tokens of one event, with a linear head producing vocabulary logits."""

    config_class = MIDIModelConfig

    def __init__(self, config: MIDIModelConfig, *args, **kwargs):
        """Build both Llama networks and the LM head from ``config``."""
        super().__init__(config, *args, **kwargs)
        self.tokenizer = config.tokenizer
        self.net = LlamaModel(config.net_config)
        self.net_token = LlamaModel(config.net_token_config)
        self.lm_head = nn.Linear(config.n_embd, self.tokenizer.vocab_size, bias=False)

    def load_merge_lora(self, model_id):
        """Load the LoRA adapter ``model_id`` onto this model and merge it.

        :param model_id: identifier passed to ``PeftConfig.from_pretrained``
            and ``load_peft_weights``
        :return: the result of ``LoraModel.merge_and_unload()``
        """
        peft_config = PeftConfig.from_pretrained(model_id)
        model = LoraModel(self, peft_config, adapter_name="default")
        adapter_state_dict = load_peft_weights(model_id, device=str(self.device))
        set_peft_model_state_dict(self, adapter_state_dict, "default")
        return model.merge_and_unload()

    def forward_token(self, hidden_state=None, x=None, cache=None):
        """Run the token-level network and project to vocabulary logits.

        :param hidden_state: (batch_size, n_embd)
        :param x: (batch_size, token_sequence_length)
        :param cache: Cache
        :return: (batch_size, 1 + token_sequence_length, vocab_size)
        """
        if hidden_state is not None:
            # if you use cache, you don't need to pass in hidden_state
            hidden_state = hidden_state.unsqueeze(1)  # (batch_size, 1, n_embd)
        if x is not None:
            x = self.net_token.embed_tokens(x)
            if hidden_state is not None:
                # The event hidden state acts as the first position of the
                # token sequence.
                x = torch.cat([hidden_state, x], dim=1)
            hidden_state = x

        hidden_state = self.net_token.forward(inputs_embeds=hidden_state,
                                              past_key_values=cache,
                                              use_cache=cache is not None).last_hidden_state
        return self.lm_head(hidden_state)

    def forward(self, x, cache=None):
        """Run the event-level network.

        :param x: (batch_size, midi_sequence_length, token_sequence_length)
        :param cache: Cache
        :return: hidden (batch_size, midi_sequence_length, n_embd)
        """
        # merge token sequence: each event is the sum of its token embeddings
        x = self.net.embed_tokens(x)
        x = x.sum(dim=-2)
        x = self.net.forward(inputs_embeds=x,
                             past_key_values=cache,
                             use_cache=cache is not None)
        return x.last_hidden_state

    def sample_top_p_k(self, probs, p, k, generator=None):
        """Sample one index per row using combined top-p and top-k filtering.

        :param probs: probabilities along the last dimension
        :param p: keep the sorted entries whose preceding cumulative mass is
            not greater than ``p``
        :param k: keep at most the ``k`` most probable entries
        :param generator: optional generator passed to ``torch.multinomial``
        :return: sampled indices into the last dimension of ``probs``,
            shape ``probs.shape[:-1]``
        """
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        # Mass accumulated *before* each entry; the top entry is always kept.
        mask = probs_sum - probs_sort > p
        probs_sort[mask] = 0.0
        mask = torch.zeros(probs_sort.shape[-1], device=probs_sort.device)
        mask[:k] = 1
        probs_sort = probs_sort * mask
        # Renormalise what is left so it is a distribution again.
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        shape = probs_sort.shape
        next_token = torch.multinomial(probs_sort.reshape(-1, shape[-1]),
                                       num_samples=1, generator=generator).reshape(*shape[:-1], 1)
        # Map positions in the sorted order back to vocabulary indices.
        next_token = torch.gather(probs_idx, -1, next_token).reshape(*shape[:-1])
        return next_token

    @torch.inference_mode()
    def generate(self, prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98, top_k=20, generator=None):
        """Autoregressively generate MIDI events.

        Each step runs the event-level network on the new events, then samples
        the tokens of the next event one by one with the token-level network.
        Sampling is restricted to event ids / EOS for the first token and to
        the ids of the matching parameter for the following ones. A sequence
        that has produced EOS only receives pad tokens afterwards; generation
        stops when every sequence has ended or ``max_len`` events are reached.

        :param prompt: optional array of shape (midi_sequence_length,
            token_sequence_length) or (1 or batch_size, midi_sequence_length,
            token_sequence_length); the last axis is truncated / padded to
            ``tokenizer.max_token_seq``. Without a prompt, generation starts
            from a single BOS event.
        :param batch_size: number of sequences generated in parallel
        :param max_len: maximum number of events, prompt included
        :param temp: softmax temperature
        :param top_p: top-p threshold for ``sample_top_p_k``
        :param top_k: top-k limit for ``sample_top_p_k``
        :param generator: optional generator used for sampling
        :return: numpy array (batch_size, midi_sequence_length, max_token_seq)
        :raises ValueError: if ``prompt`` is not 2-D or 3-D, or its batch
            dimension is neither 1 nor ``batch_size``
        """
        tokenizer = self.tokenizer
        max_token_seq = tokenizer.max_token_seq
        if prompt is None:
            input_tensor = torch.full((1, max_token_seq), tokenizer.pad_id, dtype=torch.long, device=self.device)
            input_tensor[0, 0] = tokenizer.bos_id  # bos
            input_tensor = input_tensor.unsqueeze(0)
            input_tensor = torch.cat([input_tensor] * batch_size, dim=0)
        else:
            # Check the rank before looking at the batch dimension so that a
            # malformed prompt always fails with the ValueError below.
            if len(prompt.shape) == 2:
                prompt = prompt[None, :]
                prompt = np.repeat(prompt, repeats=batch_size, axis=0)
            elif len(prompt.shape) != 3:
                raise ValueError(f"invalid shape for prompt, {prompt.shape}")
            elif prompt.shape[0] == 1:
                prompt = np.repeat(prompt, repeats=batch_size, axis=0)
            elif prompt.shape[0] != batch_size:
                raise ValueError(f"invalid shape for prompt, {prompt.shape}")
            prompt = prompt[..., :max_token_seq]
            if prompt.shape[-1] < max_token_seq:
                prompt = np.pad(prompt, ((0, 0), (0, 0), (0, max_token_seq - prompt.shape[-1])),
                                mode="constant", constant_values=tokenizer.pad_id)
            input_tensor = torch.from_numpy(prompt).to(dtype=torch.long, device=self.device)

        cur_len = input_tensor.shape[1]
        bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
        cache1 = DynamicCache()
        past_len = 0
        # Ids allowed as the first token of an event (loop invariant).
        first_token_ids = list(tokenizer.event_ids.values()) + [tokenizer.eos_id]
        # Must persist across steps: once a sequence has emitted EOS it stays
        # ended and is only padded from then on.
        end = [False] * batch_size
        with bar:
            while cur_len < max_len:
                # Only the events not yet in cache1 are fed to the network.
                hidden = self.forward(input_tensor[:, past_len:], cache=cache1)[:, -1]
                next_token_seq = None
                event_names = [""] * batch_size
                cache2 = DynamicCache()  # fresh token-level cache per event
                for i in range(max_token_seq):
                    mask = torch.zeros((batch_size, tokenizer.vocab_size), dtype=torch.int64, device=self.device)
                    for b in range(batch_size):
                        if end[b]:
                            mask[b, tokenizer.pad_id] = 1
                            continue
                        if i == 0:
                            mask[b, first_token_ids] = 1
                        else:
                            param_names = tokenizer.events[event_names[b]]
                            if i > len(param_names):
                                # This event has no more parameters.
                                mask[b, tokenizer.pad_id] = 1
                                continue
                            mask[b, tokenizer.parameter_ids[param_names[i - 1]]] = 1
                    mask = mask.unsqueeze(1)
                    x = next_token_seq
                    if i != 0:
                        # cached
                        hidden = None
                        x = x[:, -1:]
                    logits = self.forward_token(hidden, x, cache=cache2)[:, -1:]
                    scores = torch.softmax(logits / temp, dim=-1) * mask
                    samples = self.sample_top_p_k(scores, top_p, top_k, generator=generator)
                    if i == 0:
                        next_token_seq = samples
                        for b in range(batch_size):
                            if end[b]:
                                continue
                            eid = samples[b].item()
                            if eid == tokenizer.eos_id:
                                end[b] = True
                            else:
                                event_names[b] = tokenizer.id_events[eid]
                    else:
                        next_token_seq = torch.cat([next_token_seq, samples], dim=1)
                        # Stop early once every live sequence has all of its
                        # event's parameters.
                        if all(len(tokenizer.events[event_names[b]]) == i
                               for b in range(batch_size) if not end[b]):
                            break

                if next_token_seq.shape[1] < max_token_seq:
                    next_token_seq = F.pad(next_token_seq, (0, max_token_seq - next_token_seq.shape[1]),
                                           "constant", value=tokenizer.pad_id)
                next_token_seq = next_token_seq.unsqueeze(1)
                input_tensor = torch.cat([input_tensor, next_token_seq], dim=1)
                past_len = cur_len
                cur_len += 1
                bar.update(1)

                if all(end):
                    break
        return input_tensor.cpu().numpy()