TristanBehrens committed on
Commit 87ae0b7
1 Parent(s): 2d7a385

Initial commit

Files changed (5)
  1. app.py +245 -0
  2. assets/asciilogo.txt +11 -0
  3. requirements.txt +11 -0
  4. source/languagemodel.py +288 -0
  5. source/utilities.py +331 -0
app.py ADDED
@@ -0,0 +1,245 @@
+import streamlit as st
+from source.languagemodel import LanguageModel
+from source.utilities import (
+    convert_tokens_to_songdata,
+    convert_songdata_to_notesequence,
+    convert_songdata_to_pianoroll,
+    convert_notesequence_to_wave,
+    convert_notesequence_to_midi
+)
+
+# Define the MIDI instruments.
+midi_instruments = {
+    "Harpsichord": 6,
+    "Church Organ": 19,
+    "Piano": 0,
+}
+
+# Load the model once and cache it.
+@st.cache_resource
+def load_model():
+    model = LanguageModel("TristanBehrens/bach-garland-mambaplus")
+    return model
+model = load_model()
+
+
+# Initialize the session state if it doesn't exist yet.
+if "token_sequence" not in st.session_state:
+    st.session_state.token_sequence = "GARLAND_START"
+    st.session_state.song_data = None
+    st.session_state.piano_roll = None
+    st.session_state.wave = None
+    st.session_state.note_sequence = None
+    st.session_state.midi_file_content = None
+    st.session_state.temperature = 0.1
+    st.session_state.bpm = 100
+    st.session_state.instrument = "Piano"
+
+
+# Define the main function.
+def main():
+
+    columns = st.columns([0.7, 0.3])
+
+    # Set up the Streamlit application.
+    column = columns.pop(0)
+    with column:
+
+        # Change the color of the a-tag to (255, 75, 75).
+        st.markdown("<style>a:link { color: #FF4B4B; } a:visited { color: #FF4B4B; }</style>", unsafe_allow_html=True)
+
+        # Add a title.
+        st.title("Garland Composer")
+        linkedin_url = "https://huggingface.co/TristanBehrens/bach-garland-mambaplus/"
+        x_url = "https://huggingface.co/TristanBehrens/bach-garland-mambaplus/"
+        st.write(f"By Dr. Tristan Behrens. Find me on [LinkedIn]({linkedin_url}) and [X]({x_url}).")
+        hf_url = "https://huggingface.co/TristanBehrens/bach-garland-mambaplus/"
+        st.write(f"Model available on [Hugging Face]({hf_url}).")
+
+    # Add a picture.
+    column = columns.pop(0)
+    with column:
+        st.write(" ")
+        st.write(" ")
+        st.write(" ")
+        st.image("garland.jpg", use_column_width=True)
+
+    # Add a horizontal line.
+    st.markdown("---")
+
+    # Create three columns for the settings.
+    columns = st.columns(3)
+
+    # Add a slider to control the temperature.
+    state_temperature = st.session_state.temperature
+    with columns.pop(0):
+        temperature = st.slider("Temperature", 0.0, 1.0, state_temperature)
+        st.session_state.temperature = temperature
+
+    # Add a slider to control the BPM.
+    state_bpm = st.session_state.bpm
+    with columns.pop(0):
+        bpm = st.slider("BPM", 80, 120, state_bpm, 5)
+        st.session_state.bpm = bpm
+
+    # Dropdown for the instrument.
+    state_instrument = st.session_state.instrument
+    with columns.pop(0):
+        instrument = st.selectbox("Instrument", list(midi_instruments.keys()), index=list(midi_instruments.keys()).index(state_instrument))
+        st.session_state.instrument = instrument
+
+    # Get the token sequence from the session state.
+    token_sequence = st.session_state.token_sequence
+
+    # Columns for the buttons.
+    columns = st.columns(5)
+
+    # Add a button to generate the next bar.
+    column = columns.pop(0)
+    with column:
+        if st.button("Add a bar", use_container_width=True):
+            token_sequence = extend_sequence(model, token_sequence, temperature)
+            refresh(token_sequence, bpm, instrument)
+
+    # Add a button to compose a longer piece automatically.
+    column = columns.pop(0)
+    with column:
+        if st.button("Auto compose", use_container_width=True):
+            token_sequence = auto_compose(model, token_sequence, temperature)
+            refresh(token_sequence, bpm, instrument)
+
+    # Add a button to remove the last bar.
+    column = columns.pop(0)
+    with column:
+        if st.button("Remove last", use_container_width=True):
+            token_sequence = shortened_sequence(token_sequence)
+            refresh(token_sequence, bpm, instrument)
+
+    # Add a button to reset the sequence.
+    column = columns.pop(0)
+    if token_sequence != "GARLAND_START":
+        with column:
+            if st.button("Reset", use_container_width=True):
+                token_sequence = "GARLAND_START"
+                refresh(token_sequence, bpm, instrument)
+
+    # Provide a download button for the MIDI file.
+    column = columns.pop(0)
+    if "midi_file_content" in st.session_state and st.session_state.midi_file_content is not None:
+        with column:
+            midi_file_content = st.session_state.midi_file_content
+            st.download_button(
+                label="Download MIDI",
+                data=midi_file_content,
+                file_name="music.mid",
+                mime="audio/midi",
+                use_container_width=True
+            )
+
+    # Add a horizontal line.
+    st.markdown("---")
+
+    # Display the piano roll.
+    if "piano_roll" in st.session_state and st.session_state.piano_roll is not None:
+        st.image(st.session_state.piano_roll)
+
+    # Display an audio player.
+    if "wave" in st.session_state and st.session_state.wave is not None:
+        st.audio(st.session_state.wave, format="audio/wav", sample_rate=44100, autoplay=True)
+
+    # Add a horizontal line.
+    st.markdown("---")
+
+    # Tell the user whether the model considers the piece finished.
+    if token_sequence.endswith("GARLAND_END"):
+        st.write("The AI believes that the music is finished.")
+    else:
+        st.write("The AI believes that the music is not finished.")
+
+
+def auto_compose(model, token_sequence, temperature):
+
+    # Keep adding bars until the model emits GARLAND_END or the iteration limit is reached.
+    max_iterations = 100
+    for _ in range(max_iterations):
+        token_sequence = extend_sequence(model, token_sequence, temperature)
+        if token_sequence.endswith("GARLAND_END"):
+            break
+    return token_sequence
+
+
+def extend_sequence(model, token_sequence, temperature):
+
+    # Replace the last GARLAND_END token with NEXT.
+    if token_sequence.endswith("GARLAND_END"):
+        token_sequence = token_sequence.replace("GARLAND_END", "NEXT")
+
+    # The maximum length of the generated music.
+    max_length = 16_384
+
+    # When to stop the generation.
+    end_tokens = ["NEXT", "GARLAND_END"]
+
+    # Compose the music iteratively, bar by bar.
+    output_dict = model.generate(
+        prompt=token_sequence,
+        temperature=temperature,
+        max_length=max_length,
+        end_tokens=end_tokens,
+        forbidden_tokens=["[PAD]", "[EOS]"],
+        return_structured_output=True
+    )
+    output = output_dict["output"]
+    return output
+
+
+def shortened_sequence(token_sequence):
+
+    # Find the position of the next-to-last NEXT (or GARLAND_END) token and cut the sequence there.
+    next_tokens = token_sequence.split()
+    next_positions = [i for i, x in enumerate(next_tokens) if x == "NEXT" or x == "GARLAND_END"]
+    if len(next_positions) <= 1:
+        token_sequence = "GARLAND_START"
+    else:
+        next_position = next_positions[-2]
+        token_sequence = " ".join(next_tokens[:next_position + 1])
+    return token_sequence
+
+
+def refresh(token_sequence="GARLAND_START", bpm=120, instrument="Piano"):
+
+    # Put the token sequence into the session state.
+    st.session_state.token_sequence = token_sequence
+
+    # Convert to song data.
+    song_data = convert_tokens_to_songdata(token_sequence)
+    song_data["bpm"] = bpm
+    st.session_state.song_data = song_data
+
+    # Set the instrument.
+    for track in song_data["tracks"]:
+        track["instrument"] = midi_instruments[instrument]
+
+    # Convert to piano roll.
+    piano_roll = convert_songdata_to_pianoroll(song_data)
+    st.session_state.piano_roll = piano_roll
+
+    # Convert to note sequence.
+    note_sequence = convert_songdata_to_notesequence(song_data)
+    st.session_state.note_sequence = note_sequence
+
+    # Render the note sequence to audio.
+    wave = convert_notesequence_to_wave(note_sequence)
+    st.session_state.wave = wave
+
+    # Get the MIDI file content.
+    midi_file_content = convert_notesequence_to_midi(note_sequence)
+    st.session_state.midi_file_content = midi_file_content
+
+    # Rerun the app.
+    st.rerun()
+
+
+if __name__ == "__main__":
+    main()
assets/asciilogo.txt ADDED
@@ -0,0 +1,11 @@
+▄█ █▄ ▄████████ ▄█ ▄█ ▀█████████▄ ▄████████ ███ █▄ ███▄▄▄▄ ███▄▄▄▄ ▄████████
+███ ███ ███ ███ ███ ███ ███ ███ ███ ███ ███ ███ ███▀▀▀██▄ ███▀▀▀██▄ ███ ███
+███ ███ ███ █▀ ███ ███▌ ███ ███ ███ ███ ███ ███ ███ ███ ███ ███ ███ ███
+▄███▄▄▄▄███▄▄ ▄███▄▄▄ ███ ███▌ ▄███▄▄▄██▀ ▄███▄▄▄▄██▀ ███ ███ ███ ███ ███ ███ ███ ███
+▀▀███▀▀▀▀███▀ ▀▀███▀▀▀ ███ ███▌ ▀▀███▀▀▀██▄ ▀▀███▀▀▀▀▀ ███ ███ ███ ███ ███ ███ ▀███████████
+███ ███ ███ █▄ ███ ███ ███ ██▄ ▀███████████ ███ ███ ███ ███ ███ ███ ███ ███
+███ ███ ███ ███ ███▌ ▄ ███ ███ ███ ███ ███ ███ ███ ███ ███ ███ ███ ███ ███
+███ █▀ ██████████ █████▄▄██ █▀ ▄█████████▀ ███ ███ ████████▀ ▀█ █▀ ▀█ █▀ ███ █▀
+▀ ███ ███
+
+By Dr. Tristan Behrens
requirements.txt ADDED
@@ -0,0 +1,11 @@
+dacite==1.8.1
+colorama==0.4.6
+omegaconf==2.3.0
+streamlit==1.38.0
+note_seq==0.0.5
+pyfluidsynth==1.3.2
+torch==2.2.0
+transformers==4.44.0
+mamba-ssm==2.2.2
+einops==0.8.0
+mambapy==1.2.0
source/languagemodel.py ADDED
@@ -0,0 +1,288 @@
+# Helibrunna - A HuggingFace compatible xLSTM trainer.
+# Copyright (c) 2024 Dr. Tristan Behrens
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+import os
+import glob
+from omegaconf import OmegaConf
+from transformers import PreTrainedTokenizerFast
+import torch
+from safetensors.torch import load_file
+import time
+from .utilities import display_logo, model_from_config
+
+
+class LanguageModel:
+
+    def __init__(self, model_path_or_repo, config_overrides={}, mask_special_tokens=True, device="auto"):
+        """
+        Initializes the LanguageModel object.
+
+        Args:
+            model_path_or_repo (str): The path to the model or the repository ID.
+
+        Raises:
+            ValueError: If the model checkpoint, tokenizer, config, or weights are not found,
+                or if the model could not be downloaded.
+
+        Returns:
+            None
+        """
+
+        # Set the mask_special_tokens flag.
+        self.mask_special_tokens = mask_special_tokens
+
+        # Set the device. CPU is the default.
+        if device != "auto":
+
+            # Check if CUDA is available.
+            if not torch.cuda.is_available() and device == "cuda":
+                raise ValueError("CUDA is not available on this system.")
+
+            # Check if MPS is available.
+            if not torch.backends.mps.is_available() and device == "mps":
+                raise ValueError("MPS is not available on this system.")
+
+            # Set the device.
+            self.device = device
+
+        # Select the device automatically.
+        else:
+
+            # Fall back to CPU.
+            self.device = "cpu"
+
+            # Use CUDA if it is available.
+            if torch.cuda.is_available():
+                self.device = "cuda"
+
+            # See if MPS is available.
+            # Note: This is disabled for now. It's not working as expected. It is very slow.
+            #if torch.backends.mps.is_available():
+            #    self.device = "mps"
+
+        # Display the logo.
+        display_logo()
+
+        # Download the model if it doesn't exist locally. Or at least try to.
+        if not os.path.exists(model_path_or_repo):
+            from huggingface_hub import snapshot_download
+            try:
+                model_path = snapshot_download(repo_id=model_path_or_repo)
+                tokenizer_path = model_path
+            except Exception as e:
+                raise ValueError(f"Failed to download the model: {e}")
+
+        # Use a local model.
+        else:
+            # Set the model path and tokenizer path.
+            model_path = None
+            tokenizer_path = model_path_or_repo
+
+            # Find all the checkpoint folders, i.e. folders that start with "checkpoint-". Then pick the last one.
+            checkpoint_folders = glob.glob(os.path.join(model_path_or_repo, "checkpoint-*"))
+            for checkpoint_folder in checkpoint_folders:
+                if checkpoint_folder.endswith("-last"):
+                    model_path = checkpoint_folder
+                    break
+            if model_path is None:
+                raise ValueError("No model checkpoint found.")
+
+            # Find the tokenizer folder.
+            if os.path.exists(os.path.join(model_path_or_repo, "tokenizer.json")):
+                tokenizer_path = model_path_or_repo
+            if not os.path.exists(tokenizer_path):
+                raise ValueError("Tokenizer not found.")
+
+        # Load the config.
+        config_path = os.path.join(model_path, "config.yaml")
+        if not os.path.exists(config_path):
+            raise ValueError(f"Config not found at {config_path}")
+        model_config = OmegaConf.load(config_path)
+
+        # Override the config.
+        if config_overrides != {} and config_overrides is not None:
+            model_config = OmegaConf.merge(model_config, config_overrides)
+            import json
+            print(json.dumps(OmegaConf.to_container(model_config), indent=4))
+
+        # Create the model from the config.
+        model = model_from_config(model_config, device=self.device)
+        model.to(self.device)
+        self.config = model_config
+
+        # Load the weights from the checkpoint.
+        weights_path = os.path.join(model_path, "model.safetensors")
+        if not os.path.exists(weights_path):
+            raise ValueError(f"Weights not found at {weights_path}")
+        state_dict = load_file(weights_path)
+
+        # TODO: Permute the last two dimensions of these parameters: xlstm_block_stack.blocks.2.xlstm.slstm_cell._recurrent_kernel_.
+        # Check if we have an xLSTM model and if CUDA is not available.
+        if not torch.cuda.is_available() and model_config.get("type", "xLSTMLMModel") == "xLSTMLMModel":
+            print(state_dict.keys())
+            endings = ["xlstm.slstm_cell._recurrent_kernel_"]
+            for key, values in state_dict.items():
+                for ending in endings:
+                    if key.endswith(ending):
+                        print(key)
+                        print(values.shape)
+
+                        # Option: Permute the last two dimensions.
+                        values = values.permute(0, 2, 1)
+
+                        # Option: View the tensor.
+                        #new_shape = (values.shape[0], values.shape[2], values.shape[1])
+                        #values = values.view(new_shape)
+
+                        print(values.shape)
+                        state_dict[key] = values
+                        break
+
+        # Load the weights into the model.
+        model.load_state_dict(state_dict)
+        self.model = model
+
+        # Load the tokenizer.
+        tokenizer_path = os.path.join(tokenizer_path, "tokenizer.json")
+        if not os.path.exists(tokenizer_path):
+            raise ValueError(f"Tokenizer not found at {tokenizer_path}")
+        tokenizer = PreTrainedTokenizerFast.from_pretrained(tokenizer_path)
+        self.tokenizer = tokenizer
+
+    def generate(
+        self,
+        prompt: str,
+        temperature: float = 1.0,
+        max_length: int = 100,
+        end_tokens: list[str] = [],
+        forbidden_tokens: list[str] = [],
+        return_structured_output: bool = False
+    ):
+        """
+        Generates a continuation for a given prompt using the language model.
+
+        Args:
+            prompt (str): The prompt to generate a continuation for.
+            temperature (float, optional): The temperature value for controlling the randomness of the generated output.
+                Higher values (e.g., 1.0) make the output more random, while lower values (e.g., 0.5) make it more deterministic.
+                Defaults to 1.0.
+            max_length (int, optional): The maximum length of the generated output. Defaults to 100.
+            end_tokens (list[str], optional): A list of end tokens that, if encountered, will stop the generation process.
+                Defaults to an empty list.
+            forbidden_tokens (list[str], optional): A list of tokens that must not be sampled. Defaults to an empty list.
+            return_structured_output (bool, optional): If True, returns a dictionary with the generated output, elapsed time,
+                and tokens per second. If False, returns only the generated output as a string. Defaults to False.
+
+        Returns:
+            str or dict: The generated output as a string if return_structured_output is False.
+                A dictionary with the generated output, elapsed time, and tokens per second if return_structured_output is True.
+        """
+
+        # Tokenize the prompt.
+        inputs = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
+        assert inputs.shape[0] == 1
+
+        # Determine the end token ids.
+        end_token_ids = []
+        for end_token in end_tokens:
+            assert end_token in self.tokenizer.vocab
+            end_token_ids.append(self.tokenizer(end_token).input_ids[0])
+
+        # Initialize the ids to mask.
+        ids_to_mask = []
+
+        # Collect the ids of the forbidden tokens.
+        for forbidden_token in forbidden_tokens:
+            assert forbidden_token in self.tokenizer.vocab
+            ids_to_mask.extend(self.tokenizer(forbidden_token).input_ids)
+
+        # Generate the continuation.
+        start_time = time.time()
+        tokens_count = 0
+        while inputs.shape[1] < max_length:
+
+            # Stop if the maximum context length is reached.
+            if inputs.shape[1] >= self.config.context_length:
+                print("Warning: The maximum context length has been reached.")
+                break
+
+            # Generate the continuation.
+            outputs = self.model(inputs.to(device=self.device))
+            assert outputs.shape[0] == 1
+
+            # Mask the special and forbidden tokens.
+            outputs[:, :, self.tokenizer.all_special_ids] = float("-inf")
+            if ids_to_mask:
+                outputs[:, :, ids_to_mask] = float("-inf")
+
+            # Use the temperature to sample from the distribution.
+            outputs = outputs / temperature
+            outputs = torch.nn.functional.softmax(outputs, dim=-1)
+            outputs = torch.multinomial(outputs[0, -1], num_samples=1)
+
+            # Add to the inputs.
+            inputs = torch.cat([inputs, outputs.unsqueeze(0)], dim=1)
+
+            # Increment the tokens count.
+            tokens_count += 1
+
+            # Check if an end token is reached.
+            if outputs[0] in end_token_ids:
+                break
+
+        # Compute the elapsed time and tokens per second.
+        elapsed_time = time.time() - start_time
+        tokens_per_second = tokens_count / elapsed_time
+
+        # Decode the output.
+        output = self.tokenizer.decode(inputs[0].tolist())
+
+        # Return the output.
+        if not return_structured_output:
+            return output
+
+        # Return the structured output.
+        else:
+            return {
+                "output": output,
+                "elapsed_time": elapsed_time,
+                "tokens_per_second": tokens_per_second
+            }
+
+    def summary(self):
+        """
+        Prints a summary of the model. Makes the model architecture readable. Includes the number of parameters.
+        """
+
+        # Print the model.
+        print(self.model)
+
+        # Get the number of parameters.
+        number_of_parameters = sum(p.numel() for p in self.model.parameters())
+        print(f"Number of parameters: {number_of_parameters:_}")
+        sizes = ["", "K", "M", "B", "T"]
+        size_index = 0
+        while number_of_parameters > 1000:
+            number_of_parameters /= 1000
+            size_index += 1
+        print(f"Number of parameters: {number_of_parameters:.2f}{sizes[size_index]}")
+
+        # Estimate the size of the model, assuming 32-bit floats, and make it human readable.
+        number_of_parameters = sum(p.numel() for p in self.model.parameters())
+        total_size = number_of_parameters * 4
+        sizes = ["B", "KB", "MB", "GB", "TB"]
+        size_index = 0
+        while total_size > 1024:
+            total_size /= 1024
+            size_index += 1
+        print(f"Total size of the model: {total_size:.2f}{sizes[size_index]} for 32-bit float precision.")
+
+        # Print on which device the model is running.
+        print(f"Device: {self.device}")
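Note: the following is a minimal usage sketch of the LanguageModel API above, not part of the commit. The repository ID and the token names are taken from app.py; the exact sampling output is illustrative only.

    from source.languagemodel import LanguageModel

    # Load the model from the Hugging Face Hub (or a local checkpoint folder).
    model = LanguageModel("TristanBehrens/bach-garland-mambaplus")

    # Generate one bar, stopping at the separator tokens used by the Garland format.
    result = model.generate(
        prompt="GARLAND_START",
        temperature=0.5,
        max_length=16_384,
        end_tokens=["NEXT", "GARLAND_END"],
        forbidden_tokens=["[PAD]", "[EOS]"],
        return_structured_output=True,
    )
    print(result["output"])
    print(f"{result['tokens_per_second']:.1f} tokens/s")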
source/utilities.py ADDED
@@ -0,0 +1,331 @@
+import copy
+import note_seq
+from PIL import Image
+import tempfile
+import os
+import colorama
+import colorsys
+from omegaconf import DictConfig, OmegaConf
+import torch
+from typing import List, Tuple, Dict
+from dacite import from_dict
+from collections.abc import MutableMapping
+import sys
+
+
+# NOTE: Imported from helibrunna.
+def display_logo():
+    """
+    Display the logo by printing it line by line with a cyberpunk color scheme.
+
+    Raises:
+        FileNotFoundError: If the logo file is missing.
+    """
+
+    # Get the path of this script and use it to find the logo.
+    script_path = os.path.dirname(os.path.realpath(__file__))
+    search_path = os.path.dirname(script_path)
+
+    # Load the logo.
+    logo_path = os.path.join(search_path, "assets", "asciilogo.txt")
+    if not os.path.exists(logo_path):
+        raise FileNotFoundError("The logo file is missing.")
+    with open(logo_path, "r") as f:
+        logo = f.read()
+
+    # Print the logo line by line. Use colorama to colorize the output. Use a cyberpunk color scheme.
+    for line_index, line in enumerate(logo.split("\n")):
+        color = colorama.Fore.GREEN
+        style = colorama.Style.BRIGHT if line_index % 2 == 0 else colorama.Style.NORMAL
+        print(color + style + line)
+    print(colorama.Style.RESET_ALL)
+
+
+# NOTE: Imported from helibrunna.
+def model_from_config(model_config: DictConfig, device: str) -> torch.nn.Module:
+    """
+    Create a model based on the provided model configuration.
+
+    Args:
+        model_config (DictConfig): The configuration for the model.
+        device (str): The device to move the model to.
+
+    Returns:
+        The created model.
+
+    Raises:
+        ValueError: If the model type is unknown.
+    """
+
+    # Get the model type from the configuration.
+    model_type = model_config.get("type", "xLSTMLMModel")
+
+    # Create the xLSTMLMModel.
+    if model_type == "xLSTMLMModel":
+        print("Creating xLSTMLMModel...")
+        from xlstm.xlstm_lm_model import xLSTMLMModel, xLSTMLMModelConfig
+
+        # If there is no GPU, use the vanilla backend.
+        if not torch.cuda.is_available():
+            #model_config.backend = "vanilla"
+            model_config.slstm_block.slstm.backend = "vanilla"
+            model_config.mlstm_block.mlstm.backend = "vanilla"
+        model_config_object = from_dict(xLSTMLMModelConfig, OmegaConf.to_container(model_config))
+
+        # Create the model.
+        model = xLSTMLMModel(model_config_object)
+        model.reset_parameters()
+
+    # Create the GPT2LMModel.
+    elif model_type == "gpt2":
+        print("Creating GPT2LMModel...")
+        from .models.gpttwo import GPT2LMModel, GPT2LMModelConfig
+        model_config_object = from_dict(GPT2LMModelConfig, OmegaConf.to_container(model_config))
+        model = GPT2LMModel(model_config_object)
+
+    # Create the Mamba LM.
+    elif model_type == "mamba":
+        print("Creating Mamba LM...")
+        from mambapy.lm import LM, MambaConfig
+        model_config_object = from_dict(MambaConfig, OmegaConf.to_container(model_config))
+        model = LM(model_config_object, model_config.vocab_size)
+
+    # Create the Transformer.
+    elif model_type == "transformer":
+        from .models.transformer import TransformerConfig, Transformer
+        model_config_object = from_dict(TransformerConfig, OmegaConf.to_container(model_config))
+        model = Transformer(model_config_object)
+
+    # Create a Pharia instance.
+    elif model_type == "pharia":
+        from .models.pharia import PhariaConfig, PhariaModel
+        model_config_object = from_dict(PhariaConfig, OmegaConf.to_container(model_config))
+        model = PhariaModel(model_config_object)
+
+    # Unknown model type.
+    else:
+        raise ValueError(f"Unknown model type: {model_type}")
+
+    # Move the model to the device.
+    model.to(device)
+    return model
+
+
+def convert_tokens_to_songdata(tokens):
+
+    # Accept both a token string and a list of tokens.
+    if isinstance(tokens, str):
+        tokens = tokens.split()
+
+    song_data = {}
+    song_data["tracks"] = []
+
+    # Walk through the tokens and build up tracks, bars, and notes.
+    current_track_index = 0
+    current_timestep = 0
+    for token in tokens:
+        if token == "GARLAND_START":
+            pass
+        elif token == "BAR_START":
+            if current_track_index == len(song_data["tracks"]):
+                song_data["tracks"] += [{"bars": [], "instrument": "0"}]
+            bar_data = {"notes": []}
+            song_data["tracks"][current_track_index]["bars"] += [bar_data]
+            current_timestep = 0
+        elif token.startswith("INST="):
+            instrument = token.split("=")[1]
+            song_data["tracks"][current_track_index]["instrument"] = instrument
+        elif token.startswith("DENSITY="):
+            pass
+        elif token.startswith("NOTE_ON="):
+            note_pitch = int(token.split("=")[1])
+            note_data = {
+                "note": note_pitch,
+                "start": current_timestep,
+                "end": current_timestep,
+                "velocity": 80
+            }
+            song_data["tracks"][current_track_index]["bars"][-1]["notes"] += [note_data]
+        elif token.startswith("TIME_DELTA="):
+            current_timestep += int(token.split("=")[1])
+        elif token.startswith("NOTE_OFF="):
+            note_pitch = int(token.split("=")[1])
+            for note_data in song_data["tracks"][current_track_index]["bars"][-1]["notes"]:
+                if note_data["note"] == note_pitch and note_data["start"] == note_data["end"]:
+                    note_data["end"] = current_timestep
+                    break
+        elif token == "BAR_END":
+            current_track_index += 1
+        elif token == "NEXT":
+            current_track_index = 0
+        elif token == "GARLAND_END":
+            pass
+        elif token == "[PAD]":
+            pass
+        elif token == "[EOS]":
+            pass
+        else:
+            raise Exception(f"Unknown token: {token}")
+
+    assert isinstance(song_data, dict)
+    return song_data
+
+
+def convert_songdata_to_notesequence(song_data: dict, quantize_steps_per_quarter=8, remove_disabled_tracks=True):
+
+    assert isinstance(song_data, dict), f"Invalid song data type: {type(song_data)}"
+
+    # Clone the song data.
+    song_data = copy.deepcopy(song_data)
+
+    # Sort the tracks by instrument.
+    assert "tracks" in song_data, f"Invalid song data: {song_data.keys()}"
+    tracks = sorted(song_data["tracks"], key=lambda t: t["instrument"])
+    song_data["tracks"] = tracks
+
+    # Remove tracks that are not enabled.
+    if remove_disabled_tracks:
+        song_data["tracks"] = [t for t in song_data["tracks"] if t.get("enabled", True)]
+
+    # Create an empty note sequence.
+    note_sequence = note_seq.protobuf.music_pb2.NoteSequence()
+
+    # Add the tempo.
+    bpm = song_data["bpm"] if "bpm" in song_data else 120
+    note_sequence.tempos.add().qpm = bpm
+
+    # Compute some lengths.
+    step_length_seconds = 60.0 / bpm / quantize_steps_per_quarter
+    bar_length_seconds = 4 * step_length_seconds * quantize_steps_per_quarter
+
+    # Get the instruments.
+    instruments = list(set([t["instrument"] for t in song_data["tracks"]]))
+
+    # Add the tracks.
+    for track_index, track_data in enumerate(song_data["tracks"]):
+        instrument = track_data["instrument"]
+        for bar_index, bar_data in enumerate(track_data["bars"]):
+            bar_start_time = bar_index * bar_length_seconds
+            for note_data in bar_data["notes"]:
+                assert "note" in note_data
+                assert "start" in note_data
+                assert "end" in note_data
+                note = note_sequence.notes.add()
+                #note.instrument = instrument TODO
+                note.pitch = note_data["note"]
+                note.start_time = note_data["start"] * step_length_seconds + bar_start_time
+                note.end_time = note_data["end"] * step_length_seconds + bar_start_time
+                if "velocity" in note_data:
+                    note.velocity = note_data["velocity"]
+                else:
+                    note.velocity = 80
+                note.instrument = track_index
+                if instrument == "drums":
+                    note.is_drum = True
+                else:
+                    note.is_drum = False
+                    note.program = int(instrument)
+
+    return note_sequence
+
+
+def convert_songdata_to_pianoroll(song_data):
+
+    # The bars are 4/4 and the quantization is 8 steps per quarter, i.e. 32 steps per bar.
+    # We render a grid: the width is 32 pixels per bar, the height covers the used pitch range.
+
+    # Determine the number of bars.
+    lengths = [len(track["bars"]) for track in song_data["tracks"]]
+    if lengths == []:
+        return None
+    assert len(set(lengths)) == 1, f"Unequal number of bars: {lengths}"
+    num_bars = lengths[0]
+
+    # Get the note extremes.
+    min_note = 128
+    max_note = 0
+    for track_data in song_data["tracks"]:
+        for bar_data in track_data["bars"]:
+            for note_data in bar_data["notes"]:
+                min_note = min(min_note, note_data["note"])
+                max_note = max(max_note, note_data["note"])
+
+    # The width depends on the bars.
+    width = 32 * num_bars
+
+    # The height depends on the notes.
+    height = 1 + max_note - min_note
+
+    # Create the image.
+    image = Image.new("RGB", (width, height), (14, 17, 23))
+
+    # Define some colors, one per track, by rotating the hue of the base color.
+    base_color = (255, 75, 75)
+    adjustments = [1.2, 1.0, 0.8, 0.6]
+    colors = []
+    for adjustment in adjustments:
+        rgb = base_color
+        rgb = [float(c) / 255.0 for c in rgb]
+        hsv = colorsys.rgb_to_hsv(*rgb)
+        # Rotate the hue.
+        offset = (adjustment - 1.0) * 0.1
+        hsv = (hsv[0] + offset, hsv[1], hsv[2])
+        rgb = colorsys.hsv_to_rgb(*hsv)
+        rgb = tuple([int(255.0 * c) for c in rgb])
+        colors += [rgb]
+
+    # Draw the notes.
+    for track_index, track_data in enumerate(song_data["tracks"]):
+        color = colors[track_index % len(colors)]
+        for bar_index, bar_data in enumerate(track_data["bars"]):
+            x = bar_index * 32
+
+            for note_data in bar_data["notes"]:
+                y = max_note - note_data["note"]
+                assert y >= 0 and y < height, f"Invalid y: {y}, note {note_data['note']}, min_note: {min_note}, max_note: {max_note}, difference: {max_note - min_note}, height: {height}"
+                for i in range(note_data["start"], note_data["end"]):
+                    image.putpixel((x + i, y), color)
+
+    # Resize the image. Use nearest neighbor for pixel art.
+    factor = 4
+    image = image.resize((width * factor, height * factor), Image.NEAREST)
+
+    return image
+
+
+def convert_notesequence_to_wave(note_sequence):
+
+    # Nothing to synthesize for an empty sequence.
+    if len(note_sequence.notes) == 0:
+        return None
+
+    # Prefer fluidsynth; fall back to the simple synthesizer if it is not available.
+    try:
+        synthesizer = note_seq.fluidsynth
+        wave = synthesizer(note_sequence, sample_rate=44100)
+        return wave
+    except Exception as e:
+        synthesizer = note_seq.synthesize
+        wave = synthesizer(note_sequence)
+        return wave
+
+
+def convert_notesequence_to_midi(note_sequence, filename="output.mid"):
+
+    # Nothing to export for an empty sequence.
+    if len(note_sequence.notes) == 0:
+        return None
+
+    # Write the sequence to a temporary MIDI file and return the file content.
+    with tempfile.NamedTemporaryFile(delete=False) as temp_file:
+        filename = temp_file.name
+        note_seq.sequence_proto_to_midi_file(note_sequence, filename)
+        with open(filename, "rb") as file:
+            content = file.read()
+        return content
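Note: a minimal end-to-end sketch of the conversion helpers in source/utilities.py, not part of the commit. The hand-written token string is a hypothetical example in the Garland format consumed by app.py.

    from source.utilities import (
        convert_tokens_to_songdata,
        convert_songdata_to_notesequence,
        convert_notesequence_to_midi,
    )

    # A tiny hand-written token sequence: one track, one bar, a single middle C held for 8 steps.
    tokens = "GARLAND_START BAR_START INST=0 NOTE_ON=60 TIME_DELTA=8 NOTE_OFF=60 BAR_END GARLAND_END"

    song_data = convert_tokens_to_songdata(tokens)                # tokens -> nested dict of tracks/bars/notes
    song_data["bpm"] = 100                                        # the tempo is attached by the caller, as in app.py
    note_sequence = convert_songdata_to_notesequence(song_data)   # dict -> note_seq NoteSequence
    midi_bytes = convert_notesequence_to_midi(note_sequence)      # NoteSequence -> MIDI file bytes

    with open("example.mid", "wb") as f:
        f.write(midi_bytes)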