matthew mitton m4lw4r3exe commited on
Commit
9118de8
0 Parent(s):

Duplicate from JammyMachina/the-jam-machine-app

Browse files

Co-authored-by: Halid Bayram <[email protected]>

Files changed (16) hide show
  1. .gitattributes +34 -0
  2. .gitignore +1 -0
  3. .vscode/launch.json +16 -0
  4. .vscode/settings.json +3 -0
  5. README.md +26 -0
  6. constants.py +121 -0
  7. decoder.py +197 -0
  8. familizer.py +137 -0
  9. generate.py +486 -0
  10. generation_utils.py +161 -0
  11. load.py +63 -0
  12. packages.txt +1 -0
  13. playback.py +35 -0
  14. playground.py +195 -0
  15. requirements.txt +17 -0
  16. utils.py +246 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__
.vscode/launch.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ // Use IntelliSense to learn about possible attributes.
3
+ // Hover to view descriptions of existing attributes.
4
+ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5
+ "version": "0.2.0",
6
+ "configurations": [
7
+ {
8
+ "name": "playground.py",
9
+ "type": "python",
10
+ "request": "launch",
11
+ "program": "playground.py",
12
+ "console": "integratedTerminal",
13
+ "justMyCode": false
14
+ }
15
+ ]
16
+ }
.vscode/settings.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "python.formatting.provider": "black"
3
+ }
README.md ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: The Jam Machine
3
+ emoji: 🎶
4
+ colorFrom: darkblue
5
+ colorTo: black
6
+ sdk: gradio
7
+ sdk_version: 3.13.1
8
+ python_version: 3.10.6
9
+ app_file: playground.py
10
+ pinned: true
11
+ duplicated_from: JammyMachina/the-jam-machine-app
12
+ ---
13
+
14
+ [Presentation](pitch.com/public/417162a8-88b0-4472-a651-c66bb89428be)
15
+ ## Contributors:
16
+ ### Jean Simonnet:
17
+ - [Github](https://github.com/misnaej)
18
+ - [Linkedin](https://www.linkedin.com/in/jeansimonnet/)
19
+
20
+ ### Louis Demetz:
21
+ - [Github](https://github.com/louis-demetz)
22
+ - [Linkedin](https://www.linkedin.com/in/ldemetz/)
23
+
24
+ ### Halid Bayram:
25
+ - [Github](https://github.com/m41w4r3exe)
26
+ - [Linkedin](https://www.linkedin.com/in/halid-bayram-6b9ba861/)
constants.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # fmt: off
2
+ # Instrument mapping and mapping functions
3
+ INSTRUMENT_CLASSES = [
4
+ {"name": "Piano", "program_range": range(0, 8), "family_number": 0},
5
+ {"name": "Chromatic Percussion", "program_range": range(8, 16), "family_number": 1},
6
+ {"name": "Organ", "program_range": range(16, 24), "family_number": 2},
7
+ {"name": "Guitar", "program_range": range(24, 32), "family_number": 3},
8
+ {"name": "Bass", "program_range": range(32, 40), "family_number": 4},
9
+ {"name": "Strings", "program_range": range(40, 48), "family_number": 5},
10
+ {"name": "Ensemble", "program_range": range(48, 56), "family_number": 6},
11
+ {"name": "Brass", "program_range": range(56, 64), "family_number": 7},
12
+ {"name": "Reed", "program_range": range(64, 72), "family_number": 8},
13
+ {"name": "Pipe", "program_range": range(72, 80), "family_number": 9},
14
+ {"name": "Synth Lead", "program_range": range(80, 88), "family_number": 10},
15
+ {"name": "Synth Pad", "program_range": range(88, 96), "family_number": 11},
16
+ {"name": "Synth Effects", "program_range": range(96, 104), "family_number": 12},
17
+ {"name": "Ethnic", "program_range": range(104, 112), "family_number": 13},
18
+ {"name": "Percussive", "program_range": range(112, 120), "family_number": 14},
19
+ {"name": "Sound Effects", "program_range": range(120, 128), "family_number": 15,},
20
+ ]
21
+ # fmt: on
22
+
23
+ # Instrument mapping for decodiing our midi sequence into midi instruments of our choice
24
+ INSTRUMENT_TRANSFER_CLASSES = [
25
+ {
26
+ "name": "Piano",
27
+ "program_range": [4],
28
+ "family_number": 0,
29
+ "transfer_to": "Electric Piano 1",
30
+ },
31
+ {
32
+ "name": "Chromatic Percussion",
33
+ "program_range": [11],
34
+ "family_number": 1,
35
+ "transfer_to": "Vibraphone",
36
+ },
37
+ {
38
+ "name": "Organ",
39
+ "program_range": [17],
40
+ "family_number": 2,
41
+ "transfer_to": "Percussive Organ",
42
+ },
43
+ {
44
+ "name": "Guitar",
45
+ "program_range": [80],
46
+ "family_number": 3,
47
+ "transfer_to": "Synth Lead Square",
48
+ },
49
+ {
50
+ "name": "Bass",
51
+ "program_range": [38],
52
+ "family_number": 4,
53
+ "transfer_to": "Synth Bass 1",
54
+ },
55
+ {
56
+ "name": "Strings",
57
+ "program_range": [50],
58
+ "family_number": 5,
59
+ "transfer_to": "Synth Strings 1",
60
+ },
61
+ {
62
+ "name": "Ensemble",
63
+ "program_range": [51],
64
+ "family_number": 6,
65
+ "transfer_to": "Synth Strings 2",
66
+ },
67
+ {
68
+ "name": "Brass",
69
+ "program_range": [63],
70
+ "family_number": 7,
71
+ "transfer_to": "Synth Brass 1",
72
+ },
73
+ {
74
+ "name": "Reed",
75
+ "program_range": [64],
76
+ "family_number": 8,
77
+ "transfer_to": "Synth Brass 2",
78
+ },
79
+ {
80
+ "name": "Pipe",
81
+ "program_range": [82],
82
+ "family_number": 9,
83
+ "transfer_to": "Synth Lead Calliope",
84
+ },
85
+ {
86
+ "name": "Synth Lead",
87
+ "program_range": [81], # Synth Lead Sawtooth
88
+ "family_number": 10,
89
+ "transfer_to": "Synth Lead Sawtooth",
90
+ },
91
+ {
92
+ "name": "Synth Pad",
93
+ "program_range": range(88, 96),
94
+ "family_number": 11,
95
+ "transfer_to": "Synth Pad",
96
+ },
97
+ {
98
+ "name": "Synth Effects",
99
+ "program_range": range(96, 104),
100
+ "family_number": 12,
101
+ "transfer_to": "Synth Effects",
102
+ },
103
+ {
104
+ "name": "Ethnic",
105
+ "program_range": range(104, 112),
106
+ "family_number": 13,
107
+ "transfer_to": "Ethnic",
108
+ },
109
+ {
110
+ "name": "Percussive",
111
+ "program_range": range(112, 120),
112
+ "family_number": 14,
113
+ "transfer_to": "Percussive",
114
+ },
115
+ {
116
+ "name": "Sound Effects",
117
+ "program_range": range(120, 128),
118
+ "family_number": 15,
119
+ "transfer_to": "Sound Effects",
120
+ },
121
+ ]
decoder.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils import *
2
+ from familizer import Familizer
3
+ from miditok import Event
4
+
5
+
6
+ class TextDecoder:
7
+ """Decodes text into:
8
+ 1- List of events
9
+ 2- Then converts these events to midi file via MidiTok and miditoolkit
10
+
11
+ :param tokenizer: from MidiTok
12
+
13
+ Usage with write_to_midi method:
14
+ args: text(String) example -> PIECE_START TRACK_START INST=25 DENSITY=2 BAR_START NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50...BAR_END TRACK_END
15
+ returns: midi file from miditoolkit
16
+ """
17
+
18
+ def __init__(self, tokenizer, familized=True):
19
+ self.tokenizer = tokenizer
20
+ self.familized = familized
21
+
22
+ def decode(self, text):
23
+ r"""converts from text to instrument events
24
+ Args:
25
+ text (String): example -> PIECE_START TRACK_START INST=25 DENSITY=2 BAR_START NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50...BAR_END TRACK_END
26
+
27
+ Returns:
28
+ Dict{inst_id: List[Events]}: List of events of Notes with velocities, aggregated Timeshifts, for each instrument
29
+ """
30
+ piece_events = self.text_to_events(text)
31
+ inst_events = self.piece_to_inst_events(piece_events)
32
+ events = self.add_timeshifts_for_empty_bars(inst_events)
33
+ events = self.aggregate_timeshifts(events)
34
+ events = self.add_velocity(events)
35
+ return events
36
+
37
+ def tokenize(self, events):
38
+ r"""converts from events to MidiTok tokens
39
+ Args:
40
+ events (Dict{inst_id: List[Events]}): List of events for each instrument
41
+
42
+ Returns:
43
+ List[List[Events]]: List of tokens for each instrument
44
+ """
45
+ tokens = []
46
+ for inst in events.keys():
47
+ tokens.append(self.tokenizer.events_to_tokens(events[inst]))
48
+ return tokens
49
+
50
+ def get_midi(self, text, filename=None):
51
+ r"""converts from text to midi
52
+ Args:
53
+ text (String): example -> PIECE_START TRACK_START INST=25 DENSITY=2 BAR_START NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50...BAR_END TRACK_END
54
+
55
+ Returns:
56
+ miditoolkit midi: Returns and writes to midi
57
+ """
58
+ events = self.decode(text)
59
+ tokens = self.tokenize(events)
60
+ instruments = self.get_instruments_tuple(events)
61
+ midi = self.tokenizer.tokens_to_midi(tokens, instruments)
62
+
63
+ if filename is not None:
64
+ midi.dump(f"{filename}")
65
+ print(f"midi file written: {filename}")
66
+
67
+ return midi
68
+
69
+ @staticmethod
70
+ def text_to_events(text):
71
+ events = []
72
+ for word in text.split(" "):
73
+ # TODO: Handle bar and track values with a counter
74
+ _event = word.split("=")
75
+ value = _event[1] if len(_event) > 1 else None
76
+ event = get_event(_event[0], value)
77
+ if event:
78
+ events.append(event)
79
+ return events
80
+
81
+ @staticmethod
82
+ def piece_to_inst_events(piece_events):
83
+ """Converts piece events of 8 bars to instrument events for entire song
84
+
85
+ Args:
86
+ piece_events (List[Events]): List of events of Notes, Timeshifts, Bars, Tracks
87
+
88
+ Returns:
89
+ Dict{inst_id: List[Events]}: List of events for each instrument
90
+
91
+ """
92
+ inst_events = {}
93
+ current_instrument = -1
94
+ for event in piece_events:
95
+ if event.type == "Instrument":
96
+ current_instrument = event.value
97
+ if current_instrument not in inst_events:
98
+ inst_events[current_instrument] = []
99
+ elif current_instrument != -1:
100
+ inst_events[current_instrument].append(event)
101
+ return inst_events
102
+
103
+ @staticmethod
104
+ def add_timeshifts_for_empty_bars(inst_events):
105
+ """Adds time shift events instead of consecutive [BAR_START BAR_END] events"""
106
+ new_inst_events = {}
107
+ for inst, events in inst_events.items():
108
+ new_inst_events[inst] = []
109
+ for index, event in enumerate(events):
110
+ if event.type == "Bar-End" or event.type == "Bar-Start":
111
+ if events[index - 1].type == "Bar-Start":
112
+ new_inst_events[inst].append(Event("Time-Shift", "4.0.8"))
113
+ else:
114
+ new_inst_events[inst].append(event)
115
+ return new_inst_events
116
+
117
+ @staticmethod
118
+ def add_timeshifts(beat_values1, beat_values2):
119
+ """Adds two beat values
120
+
121
+ Args:
122
+ beat_values1 (String): like 0.3.8
123
+ beat_values2 (String): like 1.7.8
124
+
125
+ Returns:
126
+ beat_str (String): added beats like 2.2.8 for example values
127
+ """
128
+ value1 = to_base10(beat_values1)
129
+ value2 = to_base10(beat_values2)
130
+ return to_beat_str(value1 + value2)
131
+
132
+ def aggregate_timeshifts(self, events):
133
+ """Aggregates consecutive time shift events bigger than a bar
134
+ -> like Timeshift 4.0.8
135
+
136
+ Args:
137
+ events (_type_): _description_
138
+
139
+ Returns:
140
+ _type_: _description_
141
+ """
142
+ new_events = {}
143
+ for inst, events in events.items():
144
+ inst_events = []
145
+ for i, event in enumerate(events):
146
+ if (
147
+ event.type == "Time-Shift"
148
+ and len(inst_events) > 0
149
+ and inst_events[-1].type == "Time-Shift"
150
+ ):
151
+ inst_events[-1].value = self.add_timeshifts(
152
+ inst_events[-1].value, event.value
153
+ )
154
+ else:
155
+ inst_events.append(event)
156
+ new_events[inst] = inst_events
157
+ return new_events
158
+
159
+ @staticmethod
160
+ def add_velocity(events):
161
+ """Adds default velocity 99 to note events since they are removed from text, needed to generate midi"""
162
+ new_events = {}
163
+ for inst, events in events.items():
164
+ inst_events = []
165
+ for event in events:
166
+ inst_events.append(event)
167
+ if event.type == "Note-On":
168
+ inst_events.append(Event("Velocity", 99))
169
+ new_events[inst] = inst_events
170
+ return new_events
171
+
172
+ def get_instruments_tuple(self, events):
173
+ """Returns instruments tuple for midi generation"""
174
+ instruments = []
175
+ for inst in events.keys():
176
+ is_drum = 0
177
+ if inst == "DRUMS":
178
+ inst = 0
179
+ is_drum = 1
180
+ if self.familized:
181
+ inst = Familizer(arbitrary=True).get_program_number(int(inst))
182
+ instruments.append((int(inst), is_drum))
183
+ return tuple(instruments)
184
+
185
+
186
+ if __name__ == "__main__":
187
+
188
+ filename = "midi/generated/misnaej/the-jam-machine-elec-famil/20221209_175750"
189
+ encoded_json = readFromFile(
190
+ f"{filename}.json",
191
+ True,
192
+ )
193
+ encoded_text = encoded_json["sequence"]
194
+ # encoded_text = "PIECE_START TRACK_START INST=25 DENSITY=2 BAR_START NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=67 NOTE_ON=64 TIME_DELTA=1 NOTE_OFF=67 NOTE_OFF=64 BAR_END BAR_START NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 BAR_END BAR_START NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_OFF=50 NOTE_ON=67 NOTE_ON=64 TIME_DELTA=1 NOTE_OFF=67 NOTE_OFF=64 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 BAR_END BAR_START NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_OFF=50 NOTE_ON=67 NOTE_ON=64 TIME_DELTA=1 NOTE_OFF=67 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 BAR_END BAR_START NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=69 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=69 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=57 TIME_DELTA=1 NOTE_OFF=57 NOTE_ON=56 TIME_DELTA=1 NOTE_OFF=56 NOTE_ON=64 NOTE_ON=60 NOTE_ON=55 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=55 BAR_END BAR_START NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=66 NOTE_ON=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=67 NOTE_ON=64 TIME_DELTA=1 NOTE_OFF=67 NOTE_OFF=64 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=59 NOTE_ON=55 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=59 NOTE_OFF=50 NOTE_OFF=55 NOTE_OFF=50 BAR_END BAR_START BAR_END TRACK_END"
195
+
196
+ miditok = get_miditok()
197
+ TextDecoder(miditok).get_midi(encoded_text, filename=filename)
familizer.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from joblib import Parallel, delayed
3
+ from pathlib import Path
4
+ from constants import INSTRUMENT_CLASSES, INSTRUMENT_TRANSFER_CLASSES
5
+ from utils import get_files, timeit, FileCompressor
6
+
7
+
8
+ class Familizer:
9
+ def __init__(self, n_jobs=-1, arbitrary=False):
10
+ self.n_jobs = n_jobs
11
+ self.reverse_family(arbitrary)
12
+
13
+ def get_family_number(self, program_number):
14
+ """
15
+ Given a MIDI instrument number, return its associated instrument family number.
16
+ """
17
+ for instrument_class in INSTRUMENT_CLASSES:
18
+ if program_number in instrument_class["program_range"]:
19
+ return instrument_class["family_number"]
20
+
21
+ def reverse_family(self, arbitrary):
22
+ """
23
+ Create a dictionary of family numbers to randomly assigned program numbers.
24
+ This is used to reverse the family number tokens back to program number tokens.
25
+ """
26
+
27
+ if arbitrary is True:
28
+ int_class = INSTRUMENT_TRANSFER_CLASSES
29
+ else:
30
+ int_class = INSTRUMENT_CLASSES
31
+
32
+ self.reference_programs = {}
33
+ for family in int_class:
34
+ self.reference_programs[family["family_number"]] = random.choice(
35
+ family["program_range"]
36
+ )
37
+
38
+ def get_program_number(self, family_number):
39
+ """
40
+ Given given a family number return a random program number in the respective program_range.
41
+ This is the reverse operation of get_family_number.
42
+ """
43
+ assert family_number in self.reference_programs
44
+ return self.reference_programs[family_number]
45
+
46
+ # Replace instruments in text files
47
+ def replace_instrument_token(self, token):
48
+ """
49
+ Given a MIDI program number in a word token, replace it with the family or program
50
+ number token depending on the operation.
51
+ e.g. INST=86 -> INST=10
52
+ """
53
+ inst_number = int(token.split("=")[1])
54
+ if self.operation == "family":
55
+ return "INST=" + str(self.get_family_number(inst_number))
56
+ elif self.operation == "program":
57
+ return "INST=" + str(self.get_program_number(inst_number))
58
+
59
+ def replace_instrument_in_text(self, text):
60
+ """Given a text piece, replace all instrument tokens with family number tokens."""
61
+ return " ".join(
62
+ [
63
+ self.replace_instrument_token(token)
64
+ if token.startswith("INST=") and not token == "INST=DRUMS"
65
+ else token
66
+ for token in text.split(" ")
67
+ ]
68
+ )
69
+
70
+ def replace_instruments_in_file(self, file):
71
+ """Given a text file, replace all instrument tokens with family number tokens."""
72
+ text = file.read_text()
73
+ file.write_text(self.replace_instrument_in_text(text))
74
+
75
+ @timeit
76
+ def replace_instruments(self):
77
+ """
78
+ Given a directory of text files:
79
+ Replace all instrument tokens with family number tokens.
80
+ """
81
+ files = get_files(self.output_directory, extension="txt")
82
+ Parallel(n_jobs=self.n_jobs)(
83
+ delayed(self.replace_instruments_in_file)(file) for file in files
84
+ )
85
+
86
+ def replace_tokens(self, input_directory, output_directory, operation):
87
+ """
88
+ Given a directory and an operation, perform the operation on all text files in the directory.
89
+ operation can be either 'family' or 'program'.
90
+ """
91
+ self.input_directory = input_directory
92
+ self.output_directory = output_directory
93
+ self.operation = operation
94
+
95
+ # Uncompress files, replace tokens, compress files
96
+ fc = FileCompressor(self.input_directory, self.output_directory, self.n_jobs)
97
+ fc.unzip()
98
+ self.replace_instruments()
99
+ fc.zip()
100
+ print(self.operation + " complete.")
101
+
102
+ def to_family(self, input_directory, output_directory):
103
+ """
104
+ Given a directory containing zip files, replace all instrument tokens with
105
+ family number tokens. The output is a directory of zip files.
106
+ """
107
+ self.replace_tokens(input_directory, output_directory, "family")
108
+
109
+ def to_program(self, input_directory, output_directory):
110
+ """
111
+ Given a directory containing zip files, replace all instrument tokens with
112
+ program number tokens. The output is a directory of zip files.
113
+ """
114
+ self.replace_tokens(input_directory, output_directory, "program")
115
+
116
+
117
+ if __name__ == "__main__":
118
+
119
+ # Choose number of jobs for parallel processing
120
+ n_jobs = -1
121
+
122
+ # Instantiate Familizer
123
+ familizer = Familizer(n_jobs)
124
+
125
+ # Choose directory to process for program
126
+ input_directory = Path("midi/dataset/first_selection/validate").resolve() # fmt: skip
127
+ output_directory = input_directory / "family"
128
+
129
+ # familize files
130
+ familizer.to_family(input_directory, output_directory)
131
+
132
+ # Choose directory to process for family
133
+ # input_directory = Path("../data/music_picks/encoded_samples/validate/family").resolve() # fmt: skip
134
+ # output_directory = input_directory.parent / "program"
135
+
136
+ # # programize files
137
+ # familizer.to_program(input_directory, output_directory)
generate.py ADDED
@@ -0,0 +1,486 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from generation_utils import *
2
+ from utils import WriteTextMidiToFile, get_miditok
3
+ from load import LoadModel
4
+ from decoder import TextDecoder
5
+ from playback import get_music
6
+
7
+
8
+ class GenerateMidiText:
9
+ """Generating music with Class
10
+
11
+ LOGIC:
12
+
13
+ FOR GENERATING FROM SCRATCH:
14
+ - self.generate_one_new_track()
15
+ it calls
16
+ - self.generate_until_track_end()
17
+
18
+ FOR GENERATING NEW BARS:
19
+ - self.generate_one_more_bar()
20
+ it calls
21
+ - self.process_prompt_for_next_bar()
22
+ - self.generate_until_track_end()"""
23
+
24
+ def __init__(self, model, tokenizer, piece_by_track=[]):
25
+ self.model = model
26
+ self.tokenizer = tokenizer
27
+ # default initialization
28
+ self.initialize_default_parameters()
29
+ self.initialize_dictionaries(piece_by_track)
30
+
31
+ """Setters"""
32
+
33
+ def initialize_default_parameters(self):
34
+ self.set_device()
35
+ self.set_attention_length()
36
+ self.generate_until = "TRACK_END"
37
+ self.set_force_sequence_lenth()
38
+ self.set_nb_bars_generated()
39
+ self.set_improvisation_level(0)
40
+
41
+ def initialize_dictionaries(self, piece_by_track):
42
+ self.piece_by_track = piece_by_track
43
+
44
+ def set_device(self, device="cpu"):
45
+ self.device = ("cpu",)
46
+
47
+ def set_attention_length(self):
48
+ self.max_length = self.model.config.n_positions
49
+ print(
50
+ f"Attention length set to {self.max_length} -> 'model.config.n_positions'"
51
+ )
52
+
53
+ def set_force_sequence_lenth(self, force_sequence_length=True):
54
+ self.force_sequence_length = force_sequence_length
55
+
56
+ def set_improvisation_level(self, improvisation_value):
57
+ self.no_repeat_ngram_size = improvisation_value
58
+ print("--------------------")
59
+ print(f"no_repeat_ngram_size set to {improvisation_value}")
60
+ print("--------------------")
61
+
62
+ def reset_temperatures(self, track_id, temperature):
63
+ self.piece_by_track[track_id]["temperature"] = temperature
64
+
65
+ def set_nb_bars_generated(self, n_bars=8): # default is a 8 bar model
66
+ self.model_n_bar = n_bars
67
+
68
+ """ Generation Tools - Dictionnaries """
69
+
70
+ def initiate_track_dict(self, instr, density, temperature):
71
+ label = len(self.piece_by_track)
72
+ self.piece_by_track.append(
73
+ {
74
+ "label": f"track_{label}",
75
+ "instrument": instr,
76
+ "density": density,
77
+ "temperature": temperature,
78
+ "bars": [],
79
+ }
80
+ )
81
+
82
+ def update_track_dict__add_bars(self, bars, track_id):
83
+ """Add bars to the track dictionnary"""
84
+ for bar in self.striping_track_ends(bars).split("BAR_START "):
85
+ if bar == "": # happens is there is one bar only
86
+ continue
87
+ else:
88
+ if "TRACK_START" in bar:
89
+ self.piece_by_track[track_id]["bars"].append(bar)
90
+ else:
91
+ self.piece_by_track[track_id]["bars"].append("BAR_START " + bar)
92
+
93
+ def get_all_instr_bars(self, track_id):
94
+ return self.piece_by_track[track_id]["bars"]
95
+
96
+ def striping_track_ends(self, text):
97
+ if "TRACK_END" in text:
98
+ # first get rid of extra space if any
99
+ # then gets rid of "TRACK_END"
100
+ text = text.rstrip(" ").rstrip("TRACK_END")
101
+ return text
102
+
103
+ def get_last_generated_track(self, full_piece):
104
+ track = (
105
+ "TRACK_START "
106
+ + self.striping_track_ends(full_piece.split("TRACK_START ")[-1])
107
+ + "TRACK_END "
108
+ ) # forcing the space after track and
109
+ return track
110
+
111
+ def get_selected_track_as_text(self, track_id):
112
+ text = ""
113
+ for bar in self.piece_by_track[track_id]["bars"]:
114
+ text += bar
115
+ text += "TRACK_END "
116
+ return text
117
+
118
+ @staticmethod
119
+ def get_newly_generated_text(input_prompt, full_piece):
120
+ return full_piece[len(input_prompt) :]
121
+
122
+ def get_whole_piece_from_bar_dict(self):
123
+ text = "PIECE_START "
124
+ for track_id, _ in enumerate(self.piece_by_track):
125
+ text += self.get_selected_track_as_text(track_id)
126
+ return text
127
+
128
+ def delete_one_track(self, track): # TO BE TESTED
129
+ self.piece_by_track.pop(track)
130
+
131
+ # def update_piece_dict__add_track(self, track_id, track):
132
+ # self.piece_dict[track_id] = track
133
+
134
+ # def update_all_dictionnaries__add_track(self, track):
135
+ # self.update_piece_dict__add_track(track_id, track)
136
+
137
+ """Basic generation tools"""
138
+
139
+ def tokenize_input_prompt(self, input_prompt, verbose=True):
140
+ """Tokenizing prompt
141
+
142
+ Args:
143
+ - input_prompt (str): prompt to tokenize
144
+
145
+ Returns:
146
+ - input_prompt_ids (torch.tensor): tokenized prompt
147
+ """
148
+ if verbose:
149
+ print("Tokenizing input_prompt...")
150
+
151
+ return self.tokenizer.encode(input_prompt, return_tensors="pt")
152
+
153
+ def generate_sequence_of_token_ids(
154
+ self,
155
+ input_prompt_ids,
156
+ temperature,
157
+ verbose=True,
158
+ ):
159
+ """
160
+ generate a sequence of token ids based on input_prompt_ids
161
+ The sequence length depends on the trained model (self.model_n_bar)
162
+ """
163
+ generated_ids = self.model.generate(
164
+ input_prompt_ids,
165
+ max_length=self.max_length,
166
+ do_sample=True,
167
+ temperature=temperature,
168
+ no_repeat_ngram_size=self.no_repeat_ngram_size, # default = 0
169
+ eos_token_id=self.tokenizer.encode(self.generate_until)[0], # good
170
+ )
171
+
172
+ if verbose:
173
+ print("Generating a token_id sequence...")
174
+
175
+ return generated_ids
176
+
177
+ def convert_ids_to_text(self, generated_ids, verbose=True):
178
+ """converts the token_ids to text"""
179
+ generated_text = self.tokenizer.decode(generated_ids[0])
180
+ if verbose:
181
+ print("Converting token sequence to MidiText...")
182
+ return generated_text
183
+
184
+ def generate_until_track_end(
185
+ self,
186
+ input_prompt="PIECE_START ",
187
+ instrument=None,
188
+ density=None,
189
+ temperature=None,
190
+ verbose=True,
191
+ expected_length=None,
192
+ ):
193
+
194
+ """generate until the TRACK_END token is reached
195
+ full_piece = input_prompt + generated"""
196
+ if expected_length is None:
197
+ expected_length = self.model_n_bar
198
+
199
+ if instrument is not None:
200
+ input_prompt = f"{input_prompt}TRACK_START INST={str(instrument)} "
201
+ if density is not None:
202
+ input_prompt = f"{input_prompt}DENSITY={str(density)} "
203
+
204
+ if instrument is None and density is not None:
205
+ print("Density cannot be defined without an input_prompt instrument #TOFIX")
206
+
207
+ if temperature is None:
208
+ ValueError("Temperature must be defined")
209
+
210
+ if verbose:
211
+ print("--------------------")
212
+ print(
213
+ f"Generating {instrument} - Density {density} - temperature {temperature}"
214
+ )
215
+ bar_count_checks = False
216
+ failed = 0
217
+ while not bar_count_checks: # regenerate until right length
218
+ input_prompt_ids = self.tokenize_input_prompt(input_prompt, verbose=verbose)
219
+ generated_tokens = self.generate_sequence_of_token_ids(
220
+ input_prompt_ids, temperature, verbose=verbose
221
+ )
222
+ full_piece = self.convert_ids_to_text(generated_tokens, verbose=verbose)
223
+ generated = self.get_newly_generated_text(input_prompt, full_piece)
224
+ # bar_count_checks
225
+ bar_count_checks, bar_count = bar_count_check(generated, expected_length)
226
+
227
+ if not self.force_sequence_length:
228
+ # set bar_count_checks to true to exist the while loop
229
+ bar_count_checks = True
230
+
231
+ if not bar_count_checks and self.force_sequence_length:
232
+ # if the generated sequence is not the expected length
233
+ if failed > -1: # deactivated for speed
234
+ full_piece, bar_count_checks = forcing_bar_count(
235
+ input_prompt,
236
+ generated,
237
+ bar_count,
238
+ expected_length,
239
+ )
240
+ else:
241
+ print('"--- Wrong length - Regenerating ---')
242
+ if not bar_count_checks:
243
+ failed += 1
244
+ if failed > 2:
245
+ bar_count_checks = True # TOFIX exit the while loop
246
+
247
+ return full_piece
248
+
249
+ def generate_one_new_track(
250
+ self,
251
+ instrument,
252
+ density,
253
+ temperature,
254
+ input_prompt="PIECE_START ",
255
+ ):
256
+ self.initiate_track_dict(instrument, density, temperature)
257
+ full_piece = self.generate_until_track_end(
258
+ input_prompt=input_prompt,
259
+ instrument=instrument,
260
+ density=density,
261
+ temperature=temperature,
262
+ )
263
+
264
+ track = self.get_last_generated_track(full_piece)
265
+ self.update_track_dict__add_bars(track, -1)
266
+ full_piece = self.get_whole_piece_from_bar_dict()
267
+ return full_piece
268
+
269
+ """ Piece generation - Basics """
270
+
271
+ def generate_piece(self, instrument_list, density_list, temperature_list):
272
+ """generate a sequence with mutiple tracks
273
+ - inst_list sets the list of instruments of the order of generation
274
+ - density is paired with inst_list
275
+ Each track/intrument is generated on a prompt which contains the previously generated track/instrument
276
+ This means that the first instrument is generated with less bias than the next one, and so on.
277
+
278
+ 'generated_piece' keeps track of the entire piece
279
+ 'generated_piece' is returned by self.generate_until_track_end
280
+ # it is returned by self.generate_until_track_end"""
281
+
282
+ generated_piece = "PIECE_START "
283
+ for instrument, density, temperature in zip(
284
+ instrument_list, density_list, temperature_list
285
+ ):
286
+ generated_piece = self.generate_one_new_track(
287
+ instrument,
288
+ density,
289
+ temperature,
290
+ input_prompt=generated_piece,
291
+ )
292
+
293
+ # generated_piece = self.get_whole_piece_from_bar_dict()
294
+ self.check_the_piece_for_errors()
295
+ return generated_piece
296
+
297
+ """ Piece generation - Extra Bars """
298
+
299
+ @staticmethod
300
+ def process_prompt_for_next_bar(self, track_idx):
301
+ """Processing the prompt for the model to generate one more bar only.
302
+ The prompt containts:
303
+ if not the first bar: the previous, already processed, bars of the track
304
+ the bar initialization (ex: "TRACK_START INST=DRUMS DENSITY=2 ")
305
+ the last (self.model_n_bar)-1 bars of the track
306
+ Args:
307
+ track_idx (int): the index of the track to be processed
308
+
309
+ Returns:
310
+ the processed prompt for generating the next bar
311
+ """
312
+ track = self.piece_by_track[track_idx]
313
+ # for bars which are not the bar to prolong
314
+ pre_promt = "PIECE_START "
315
+ for i, othertrack in enumerate(self.piece_by_track):
316
+ if i != track_idx:
317
+ len_diff = len(othertrack["bars"]) - len(track["bars"])
318
+ if len_diff > 0:
319
+ # if other bars are longer, it mean that this one should catch up
320
+ pre_promt += othertrack["bars"][0]
321
+ for bar in track["bars"][-self.model_n_bar :]:
322
+ pre_promt += bar
323
+ pre_promt += "TRACK_END "
324
+ elif False: # len_diff <= 0: # THIS GENERATES EMPTINESS
325
+ # adding an empty bars at the end of the other tracks if they have not been processed yet
326
+ pre_promt += othertracks["bars"][0]
327
+ for bar in track["bars"][-(self.model_n_bar - 1) :]:
328
+ pre_promt += bar
329
+ for _ in range(abs(len_diff) + 1):
330
+ pre_promt += "BAR_START BAR_END "
331
+ pre_promt += "TRACK_END "
332
+
333
+ # for the bar to prolong
334
+ # initialization e.g TRACK_START INST=DRUMS DENSITY=2
335
+ processed_prompt = track["bars"][0]
336
+ for bar in track["bars"][-(self.model_n_bar - 1) :]:
337
+ # adding the "last" bars of the track
338
+ processed_prompt += bar
339
+
340
+ processed_prompt += "BAR_START "
341
+ print(
342
+ f"--- prompt length = {len((pre_promt + processed_prompt).split(' '))} ---"
343
+ )
344
+ return pre_promt + processed_prompt
345
+
346
+ def generate_one_more_bar(self, i):
347
+ """Generate one more bar from the input_prompt"""
348
+ processed_prompt = self.process_prompt_for_next_bar(self, i)
349
+ prompt_plus_bar = self.generate_until_track_end(
350
+ input_prompt=processed_prompt,
351
+ temperature=self.piece_by_track[i]["temperature"],
352
+ expected_length=1,
353
+ verbose=False,
354
+ )
355
+ added_bar = self.get_newly_generated_bar(prompt_plus_bar)
356
+ self.update_track_dict__add_bars(added_bar, i)
357
+
358
+ def get_newly_generated_bar(self, prompt_plus_bar):
359
+ return "BAR_START " + self.striping_track_ends(
360
+ prompt_plus_bar.split("BAR_START ")[-1]
361
+ )
362
+
363
+ def generate_n_more_bars(self, n_bars, only_this_track=None, verbose=True):
364
+ """Generate n more bars from the input_prompt"""
365
+ if only_this_track is None:
366
+ only_this_track
367
+
368
+ print(f"================== ")
369
+ print(f"Adding {n_bars} more bars to the piece ")
370
+ for bar_id in range(n_bars):
371
+ print(f"----- added bar #{bar_id+1} --")
372
+ for i, track in enumerate(self.piece_by_track):
373
+ if only_this_track is None or i == only_this_track:
374
+ print(f"--------- {track['label']}")
375
+ self.generate_one_more_bar(i)
376
+ self.check_the_piece_for_errors()
377
+
378
+ def check_the_piece_for_errors(self, piece: str = None):
379
+
380
+ if piece is None:
381
+ piece = generate_midi.get_whole_piece_from_bar_dict()
382
+ errors = []
383
+ errors.append(
384
+ [
385
+ (token, id)
386
+ for id, token in enumerate(piece.split(" "))
387
+ if token not in self.tokenizer.vocab or token == "UNK"
388
+ ]
389
+ )
390
+ if len(errors) > 0:
391
+ # print(piece)
392
+ for er in errors:
393
+ er
394
+ print(f"Token not found in the piece at {er[0][1]}: {er[0][0]}")
395
+ print(piece.split(" ")[er[0][1] - 5 : er[0][1] + 5])
396
+
397
+
398
+ if __name__ == "__main__":
399
+
400
+ # worker
401
+ DEVICE = "cpu"
402
+
403
+ # define generation parameters
404
+ N_FILES_TO_GENERATE = 2
405
+ Temperatures_to_try = [0.7]
406
+
407
+ USE_FAMILIZED_MODEL = True
408
+ force_sequence_length = True
409
+
410
+ if USE_FAMILIZED_MODEL:
411
+ # model_repo = "misnaej/the-jam-machine-elec-famil"
412
+ # model_repo = "misnaej/the-jam-machine-elec-famil-ft32"
413
+
414
+ # model_repo = "JammyMachina/elec-gmusic-familized-model-13-12__17-35-53"
415
+ # n_bar_generated = 8
416
+
417
+ model_repo = "JammyMachina/improved_4bars-mdl"
418
+ n_bar_generated = 4
419
+ instrument_promt_list = ["4", "DRUMS", "3"]
420
+ # DRUMS = drums, 0 = piano, 1 = chromatic percussion, 2 = organ, 3 = guitar, 4 = bass, 5 = strings, 6 = ensemble, 7 = brass, 8 = reed, 9 = pipe, 10 = synth lead, 11 = synth pad, 12 = synth effects, 13 = ethnic, 14 = percussive, 15 = sound effects
421
+ density_list = [3, 2, 2]
422
+ # temperature_list = [0.7, 0.7, 0.75]
423
+ else:
424
+ model_repo = "misnaej/the-jam-machine"
425
+ instrument_promt_list = ["30"] # , "DRUMS", "0"]
426
+ density_list = [3] # , 2, 3]
427
+ # temperature_list = [0.7, 0.5, 0.75]
428
+ pass
429
+
430
+ # define generation directory
431
+ generated_sequence_files_path = define_generation_dir(model_repo)
432
+
433
+ # load model and tokenizer
434
+ model, tokenizer = LoadModel(
435
+ model_repo, from_huggingface=True
436
+ ).load_model_and_tokenizer()
437
+
438
+ # does the prompt make sense
439
+ check_if_prompt_inst_in_tokenizer_vocab(tokenizer, instrument_promt_list)
440
+
441
+ for temperature in Temperatures_to_try:
442
+ print(f"================= TEMPERATURE {temperature} =======================")
443
+ for _ in range(N_FILES_TO_GENERATE):
444
+ print(f"========================================")
445
+ # 1 - instantiate
446
+ generate_midi = GenerateMidiText(model, tokenizer)
447
+ # 0 - set the n_bar for this model
448
+ generate_midi.set_nb_bars_generated(n_bars=n_bar_generated)
449
+ # 1 - defines the instruments, densities and temperatures
450
+ # 2- generate the first 8 bars for each instrument
451
+ generate_midi.set_improvisation_level(30)
452
+ generate_midi.generate_piece(
453
+ instrument_promt_list,
454
+ density_list,
455
+ [temperature for _ in density_list],
456
+ )
457
+ # 3 - force the model to improvise
458
+ # generate_midi.set_improvisation_level(20)
459
+ # 4 - generate the next 4 bars for each instrument
460
+ # generate_midi.generate_n_more_bars(n_bar_generated)
461
+ # 5 - lower the improvisation level
462
+ generate_midi.generated_piece = (
463
+ generate_midi.get_whole_piece_from_bar_dict()
464
+ )
465
+
466
+ # print the generated sequence in terminal
467
+ print("=========================================")
468
+ print(generate_midi.generated_piece)
469
+ print("=========================================")
470
+
471
+ # write to JSON file
472
+ filename = WriteTextMidiToFile(
473
+ generate_midi,
474
+ generated_sequence_files_path,
475
+ ).text_midi_to_file()
476
+
477
+ # decode the sequence to MIDI """
478
+ decode_tokenizer = get_miditok()
479
+ TextDecoder(decode_tokenizer, USE_FAMILIZED_MODEL).get_midi(
480
+ generate_midi.generated_piece, filename=filename.split(".")[0] + ".mid"
481
+ )
482
+ inst_midi, mixed_audio = get_music(filename.split(".")[0] + ".mid")
483
+ max_time = get_max_time(inst_midi)
484
+ plot_piano_roll(inst_midi)
485
+
486
+ print("Et voilà! Your MIDI file is ready! GO JAM!")
generation_utils.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import matplotlib
5
+
6
+ from constants import INSTRUMENT_CLASSES
7
+ from playback import get_music, show_piano_roll
8
+
9
+ # matplotlib settings
10
+ matplotlib.use("Agg") # for server
11
+ matplotlib.rcParams["xtick.major.size"] = 0
12
+ matplotlib.rcParams["ytick.major.size"] = 0
13
+ matplotlib.rcParams["axes.facecolor"] = "none"
14
+ matplotlib.rcParams["axes.edgecolor"] = "grey"
15
+
16
+
17
+ def define_generation_dir(model_repo_path):
18
+ #### to remove later ####
19
+ if model_repo_path == "models/model_2048_fake_wholedataset":
20
+ model_repo_path = "misnaej/the-jam-machine"
21
+ #### to remove later ####
22
+ generated_sequence_files_path = f"midi/generated/{model_repo_path}"
23
+ if not os.path.exists(generated_sequence_files_path):
24
+ os.makedirs(generated_sequence_files_path)
25
+ return generated_sequence_files_path
26
+
27
+
28
+ def bar_count_check(sequence, n_bars):
29
+ """check if the sequence contains the right number of bars"""
30
+ sequence = sequence.split(" ")
31
+ # find occurences of "BAR_END" in a "sequence"
32
+ # I don't check for "BAR_START" because it is not always included in "sequence"
33
+ # e.g. BAR_START is included the prompt when generating one more bar
34
+ bar_count = 0
35
+ for seq in sequence:
36
+ if seq == "BAR_END":
37
+ bar_count += 1
38
+ bar_count_matches = bar_count == n_bars
39
+ if not bar_count_matches:
40
+ print(f"Bar count is {bar_count} - but should be {n_bars}")
41
+ return bar_count_matches, bar_count
42
+
43
+
44
+ def print_inst_classes(INSTRUMENT_CLASSES):
45
+ """Print the instrument classes"""
46
+ for classe in INSTRUMENT_CLASSES:
47
+ print(f"{classe}")
48
+
49
+
50
+ def check_if_prompt_inst_in_tokenizer_vocab(tokenizer, inst_prompt_list):
51
+ """Check if the prompt instrument are in the tokenizer vocab"""
52
+ for inst in inst_prompt_list:
53
+ if f"INST={inst}" not in tokenizer.vocab:
54
+ instruments_in_dataset = np.sort(
55
+ [tok.split("=")[-1] for tok in tokenizer.vocab if "INST" in tok]
56
+ )
57
+ print_inst_classes(INSTRUMENT_CLASSES)
58
+ raise ValueError(
59
+ f"""The instrument {inst} is not in the tokenizer vocabulary.
60
+ Available Instruments: {instruments_in_dataset}"""
61
+ )
62
+
63
+
64
+ def forcing_bar_count(input_prompt, generated, bar_count, expected_length):
65
+ """Forcing the generated sequence to have the expected length
66
+ expected_length and bar_count refers to the length of newly_generated_only (without input prompt)"""
67
+
68
+ if bar_count - expected_length > 0: # Cut the sequence if too long
69
+ full_piece = ""
70
+ splited = generated.split("BAR_END ")
71
+ for count, spl in enumerate(splited):
72
+ if count < expected_length:
73
+ full_piece += spl + "BAR_END "
74
+
75
+ full_piece += "TRACK_END "
76
+ full_piece = input_prompt + full_piece
77
+ print(f"Generated sequence trunkated at {expected_length} bars")
78
+ bar_count_checks = True
79
+
80
+ elif bar_count - expected_length < 0: # Do nothing it the sequence if too short
81
+ full_piece = input_prompt + generated
82
+ bar_count_checks = False
83
+ print(f"--- Generated sequence is too short - Force Regeration ---")
84
+
85
+ return full_piece, bar_count_checks
86
+
87
+
88
+ def get_max_time(inst_midi):
89
+ max_time = 0
90
+ for inst in inst_midi.instruments:
91
+ max_time = max(max_time, inst.get_end_time())
92
+ return max_time
93
+
94
+
95
+ def plot_piano_roll(inst_midi):
96
+ piano_roll_fig = plt.figure(figsize=(25, 3 * len(inst_midi.instruments)))
97
+ piano_roll_fig.tight_layout()
98
+ piano_roll_fig.patch.set_alpha(0)
99
+ inst_count = 0
100
+ beats_per_bar = 4
101
+ sec_per_beat = 0.5
102
+ next_beat = max(inst_midi.get_beats()) + np.diff(inst_midi.get_beats())[0]
103
+ bars_time = np.append(inst_midi.get_beats(), (next_beat))[::beats_per_bar].astype(
104
+ int
105
+ )
106
+ for inst in inst_midi.instruments:
107
+ # hardcoded for now
108
+ if inst.name == "Drums":
109
+ color = "purple"
110
+ elif inst.name == "Synth Bass 1":
111
+ color = "orange"
112
+ else:
113
+ color = "green"
114
+
115
+ inst_count += 1
116
+ plt.subplot(len(inst_midi.instruments), 1, inst_count)
117
+
118
+ for bar in bars_time:
119
+ plt.axvline(bar, color="grey", linewidth=0.5)
120
+ octaves = np.arange(0, 128, 12)
121
+ for octave in octaves:
122
+ plt.axhline(octave, color="grey", linewidth=0.5)
123
+ plt.yticks(octaves, visible=False)
124
+
125
+ p_midi_note_list = inst.notes
126
+ note_time = []
127
+ note_pitch = []
128
+ for note in p_midi_note_list:
129
+ note_time.append([note.start, note.end])
130
+ note_pitch.append([note.pitch, note.pitch])
131
+ note_pitch = np.array(note_pitch)
132
+ note_time = np.array(note_time)
133
+
134
+ plt.plot(
135
+ note_time.T,
136
+ note_pitch.T,
137
+ color=color,
138
+ linewidth=4,
139
+ solid_capstyle="butt",
140
+ )
141
+ plt.ylim(0, 128)
142
+ xticks = np.array(bars_time)[:-1]
143
+ plt.tight_layout()
144
+ plt.xlim(min(bars_time), max(bars_time))
145
+ plt.ylim(max([note_pitch.min() - 5, 0]), note_pitch.max() + 5)
146
+ plt.xticks(
147
+ xticks + 0.5 * beats_per_bar * sec_per_beat,
148
+ labels=xticks.argsort() + 1,
149
+ visible=False,
150
+ )
151
+ plt.text(
152
+ 0.2,
153
+ note_pitch.max() + 4,
154
+ inst.name,
155
+ fontsize=20,
156
+ color=color,
157
+ horizontalalignment="left",
158
+ verticalalignment="top",
159
+ )
160
+
161
+ return piano_roll_fig
load.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import GPT2LMHeadModel
2
+ from transformers import PreTrainedTokenizerFast
3
+ import os
4
+ import torch
5
+
6
+
7
+ class LoadModel:
8
+ """
9
+ Example usage:
10
+
11
+ # if loading model and tokenizer from Huggingface
12
+ model_repo = "misnaej/the-jam-machine"
13
+ model, tokenizer = LoadModel(
14
+ model_repo, from_huggingface=True
15
+ ).load_model_and_tokenizer()
16
+
17
+ # if loading model and tokenizer from a local folder
18
+ model_path = "models/model_2048_wholedataset"
19
+ model, tokenizer = LoadModel(
20
+ model_path, from_huggingface=False
21
+ ).load_model_and_tokenizer()
22
+
23
+ """
24
+
25
+ def __init__(self, path, from_huggingface=True, device="cpu", revision=None):
26
+ # path is either a relative path on a local/remote machine or a model repo on HuggingFace
27
+ if not from_huggingface:
28
+ if not os.path.exists(path):
29
+ print(path)
30
+ raise Exception("Model path does not exist")
31
+ self.from_huggingface = from_huggingface
32
+ self.path = path
33
+ self.device = device
34
+ self.revision = revision
35
+ if torch.cuda.is_available():
36
+ self.device = "cuda"
37
+
38
+ def load_model_and_tokenizer(self):
39
+ model = self.load_model()
40
+ tokenizer = self.load_tokenizer()
41
+
42
+ return model, tokenizer
43
+
44
+ def load_model(self):
45
+ if self.revision is None:
46
+ model = GPT2LMHeadModel.from_pretrained(self.path) # .to(self.device)
47
+ else:
48
+ model = GPT2LMHeadModel.from_pretrained(
49
+ self.path, revision=self.revision
50
+ ) # .to(self.device)
51
+
52
+ return model
53
+
54
+ def load_tokenizer(self):
55
+ if self.from_huggingface:
56
+ pass
57
+ else:
58
+ if not os.path.exists(f"{self.path}/tokenizer.json"):
59
+ raise Exception(
60
+ f"There is no 'tokenizer.json'file in the defined {self.path}"
61
+ )
62
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(self.path)
63
+ return tokenizer
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ fluidsynth
playback.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import librosa.display
3
+ from pretty_midi import PrettyMIDI
4
+
5
+
6
+ # Note: these functions are meant to be played within an interactive Python shell
7
+ # Please refer to the synth.ipynb for an example of how to use them
8
+
9
+
10
+ def get_music(midi_file):
11
+ """
12
+ Load a midi file and return the PrettyMIDI object and the audio signal
13
+ """
14
+ music = PrettyMIDI(midi_file=midi_file)
15
+ waveform = music.fluidsynth()
16
+ return music, waveform
17
+
18
+
19
+ def show_piano_roll(music_notes, fs=100):
20
+ """
21
+ Show the piano roll of a music piece, with all instruments squashed onto a single 128xN matrix
22
+ :param music_notes: PrettyMIDI object
23
+ :param fs: sampling frequency
24
+ """
25
+ # get the piano roll
26
+ piano_roll = music_notes.get_piano_roll(fs)
27
+ print("Piano roll shape: {}".format(piano_roll.shape))
28
+
29
+ # plot the piano roll
30
+ plt.figure(figsize=(12, 4))
31
+ librosa.display.specshow(piano_roll, sr=100, x_axis="time", y_axis="cqt_note")
32
+ plt.colorbar()
33
+ plt.title("Piano roll")
34
+ plt.tight_layout()
35
+ plt.show()
playground.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import gradio as gr
3
+ from load import LoadModel
4
+ from generate import GenerateMidiText
5
+ from constants import INSTRUMENT_CLASSES, INSTRUMENT_TRANSFER_CLASSES
6
+ from decoder import TextDecoder
7
+ from utils import get_miditok, index_has_substring
8
+ from playback import get_music
9
+ from matplotlib import pylab
10
+ import sys
11
+ import matplotlib
12
+ from generation_utils import plot_piano_roll
13
+ import numpy as np
14
+
15
+ matplotlib.use("Agg")
16
+
17
+ sys.modules["pylab"] = pylab
18
+
19
+ model_repo = "JammyMachina/elec-gmusic-familized-model-13-12__17-35-53"
20
+ n_bar_generated = 8
21
+ # model_repo = "JammyMachina/improved_4bars-mdl"
22
+ # n_bar_generated = 4
23
+
24
+ model, tokenizer = LoadModel(
25
+ model_repo,
26
+ from_huggingface=True,
27
+ ).load_model_and_tokenizer()
28
+
29
+ miditok = get_miditok()
30
+ decoder = TextDecoder(miditok)
31
+
32
+
33
+ def define_prompt(state, genesis):
34
+ if len(state) == 0:
35
+ input_prompt = "PIECE_START "
36
+ else:
37
+ input_prompt = genesis.get_whole_piece_from_bar_dict()
38
+ return input_prompt
39
+
40
+
41
+ def generator(
42
+ label,
43
+ regenerate,
44
+ temp,
45
+ density,
46
+ instrument,
47
+ state,
48
+ piece_by_track,
49
+ add_bars=False,
50
+ add_bar_count=1,
51
+ ):
52
+
53
+ genesis = GenerateMidiText(model, tokenizer, piece_by_track)
54
+ track = {"label": label}
55
+ inst = next(
56
+ (
57
+ inst
58
+ for inst in INSTRUMENT_TRANSFER_CLASSES
59
+ if inst["transfer_to"] == instrument
60
+ ),
61
+ {"family_number": "DRUMS"},
62
+ )["family_number"]
63
+
64
+ inst_index = -1 # default to last generated
65
+ if state != []:
66
+ for index, instrum in enumerate(state):
67
+ if instrum["label"] == track["label"]:
68
+ inst_index = index # changing if exists
69
+
70
+ # Generate
71
+ if not add_bars:
72
+ # Regenerate
73
+ if regenerate:
74
+ state.pop(inst_index)
75
+ genesis.delete_one_track(inst_index)
76
+
77
+ generated_text = (
78
+ genesis.get_whole_piece_from_bar_dict()
79
+ ) # maybe not useful here
80
+ inst_index = -1 # reset to last generated
81
+
82
+ # NEW TRACK
83
+ input_prompt = define_prompt(state, genesis)
84
+ generated_text = genesis.generate_one_new_track(
85
+ inst, density, temp, input_prompt=input_prompt
86
+ )
87
+
88
+ regenerate = True # set generate to true
89
+ else:
90
+ # NEW BARS
91
+ genesis.generate_n_more_bars(add_bar_count) # for all instruments
92
+ generated_text = genesis.get_whole_piece_from_bar_dict()
93
+
94
+ decoder.get_midi(generated_text, "mixed.mid")
95
+ mixed_inst_midi, mixed_audio = get_music("mixed.mid")
96
+
97
+ inst_text = genesis.get_selected_track_as_text(inst_index)
98
+ inst_midi_name = f"{instrument}.mid"
99
+ decoder.get_midi(inst_text, inst_midi_name)
100
+ _, inst_audio = get_music(inst_midi_name)
101
+ piano_roll = plot_piano_roll(mixed_inst_midi)
102
+ track["text"] = inst_text
103
+ state.append(track)
104
+
105
+ return (
106
+ inst_text,
107
+ (44100, inst_audio),
108
+ piano_roll,
109
+ state,
110
+ (44100, mixed_audio),
111
+ regenerate,
112
+ genesis.piece_by_track,
113
+ )
114
+
115
+
116
+ def instrument_row(default_inst, row_id):
117
+ with gr.Row():
118
+ row = gr.Variable(row_id)
119
+ with gr.Column(scale=1, min_width=100):
120
+ inst = gr.Dropdown(
121
+ sorted([inst["transfer_to"] for inst in INSTRUMENT_TRANSFER_CLASSES])
122
+ + ["Drums"],
123
+ value=default_inst,
124
+ label="Instrument",
125
+ )
126
+ temp = gr.Dropdown(
127
+ [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1],
128
+ value=0.7,
129
+ label="Creativity",
130
+ )
131
+ density = gr.Dropdown([1, 2, 3], value=3, label="Note Density")
132
+
133
+ with gr.Column(scale=3):
134
+ output_txt = gr.Textbox(
135
+ label="output", lines=10, max_lines=10, show_label=False
136
+ )
137
+ with gr.Column(scale=1, min_width=100):
138
+ inst_audio = gr.Audio(label="TRACK Audio", show_label=True)
139
+ regenerate = gr.Checkbox(value=False, label="Regenerate", visible=False)
140
+ # add_bars = gr.Checkbox(value=False, label="Add Bars")
141
+ # add_bar_count = gr.Dropdown([1, 2, 4, 8], value=1, label="Add Bars")
142
+ gen_btn = gr.Button("Generate")
143
+ gen_btn.click(
144
+ fn=generator,
145
+ inputs=[row, regenerate, temp, density, inst, state, piece_by_track],
146
+ outputs=[
147
+ output_txt,
148
+ inst_audio,
149
+ piano_roll,
150
+ state,
151
+ mixed_audio,
152
+ regenerate,
153
+ piece_by_track,
154
+ ],
155
+ )
156
+
157
+
158
+ with gr.Blocks() as demo:
159
+ piece_by_track = gr.State([])
160
+ state = gr.State([])
161
+ title = gr.Markdown(
162
+ """ # Demo-App of The-Jam-Machine
163
+ A Generative AI trained on text transcription of MIDI music """
164
+ )
165
+ track1_md = gr.Markdown(""" ## Mixed Audio and Piano Roll """)
166
+ mixed_audio = gr.Audio(label="Mixed Audio")
167
+ piano_roll = gr.Plot(label="Piano Roll", show_label=False)
168
+ description = gr.Markdown(
169
+ """
170
+ For each **TRACK**, choose your **instrument** along with **creativity** (temperature) and **note density**. Then, hit the **Generate** Button!
171
+ You can have a look at the generated text; but most importantly, check the **piano roll** and listen to the TRACK audio!
172
+ If you don't like the track, hit the generate button to regenerate it! Generate more tracks and listen to the **mixed audio**!
173
+ """
174
+ )
175
+ track1_md = gr.Markdown(""" ## TRACK 1 """)
176
+ instrument_row("Drums", 0)
177
+ track1_md = gr.Markdown(""" ## TRACK 2 """)
178
+ instrument_row("Synth Bass 1", 1)
179
+ track1_md = gr.Markdown(""" ## TRACK 2 """)
180
+ instrument_row("Synth Lead Square", 2)
181
+ # instrument_row("Piano")
182
+
183
+ demo.launch(debug=True)
184
+
185
+ """
186
+ TODO: reset button
187
+ TODO: add a button to save the generated midi
188
+ TODO: add improvise button
189
+ TODO: set values for temperature as it is done for density
190
+ TODO: Add bar should be now set for the whole piece - regenerrate should regenerate the added bars only on all instruments
191
+ TODO: row height to fix
192
+
193
+ TODO: reset state of tick boxes after used maybe (regenerate, add bars) ;
194
+ TODO: block regenerate if add bar on
195
+ """
requirements.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ matplotlib
3
+ matplotlib
4
+ numpy
5
+ joblib
6
+ pathlib
7
+ transformers
8
+ miditok == 1.3.2
9
+ librosa
10
+ pretty_midi
11
+ pydub
12
+ scipy
13
+ datetime
14
+ torch
15
+ torchvision
16
+ pyFluidSynth
17
+ accelerate
utils.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ from miditok import Event, MIDILike
3
+ import os
4
+ import json
5
+ from time import perf_counter
6
+ from joblib import Parallel, delayed
7
+ from zipfile import ZipFile, ZIP_DEFLATED
8
+ from scipy.io.wavfile import write
9
+ import numpy as np
10
+ from pydub import AudioSegment
11
+ import shutil
12
+
13
+
14
+ def writeToFile(path, content):
15
+ if type(content) is dict:
16
+ with open(f"{path}", "w") as json_file:
17
+ json.dump(content, json_file)
18
+ else:
19
+ if type(content) is not str:
20
+ content = str(content)
21
+ os.makedirs(os.path.dirname(path), exist_ok=True)
22
+ with open(path, "w") as f:
23
+ f.write(content)
24
+
25
+
26
+ # Function to read from text from txt file:
27
+ def readFromFile(path, isJSON=False):
28
+ with open(path, "r") as f:
29
+ if isJSON:
30
+ return json.load(f)
31
+ else:
32
+ return f.read()
33
+
34
+
35
+ def chain(input, funcs, *params):
36
+ res = input
37
+ for func in funcs:
38
+ try:
39
+ res = func(res, *params)
40
+ except TypeError:
41
+ res = func(res)
42
+ return res
43
+
44
+
45
+ def to_beat_str(value, beat_res=8):
46
+ values = [
47
+ int(int(value * beat_res) / beat_res),
48
+ int(int(value * beat_res) % beat_res),
49
+ beat_res,
50
+ ]
51
+ return ".".join(map(str, values))
52
+
53
+
54
+ def to_base10(beat_str):
55
+ integer, decimal, base = split_dots(beat_str)
56
+ return integer + decimal / base
57
+
58
+
59
+ def split_dots(value):
60
+ return list(map(int, value.split(".")))
61
+
62
+
63
+ def compute_list_average(l):
64
+ return sum(l) / len(l)
65
+
66
+
67
+ def get_datetime():
68
+ return datetime.now().strftime("%Y%m%d_%H%M%S")
69
+
70
+
71
+ def get_text(event):
72
+ match event.type:
73
+ case "Piece-Start":
74
+ return "PIECE_START "
75
+ case "Track-Start":
76
+ return "TRACK_START "
77
+ case "Track-End":
78
+ return "TRACK_END "
79
+ case "Instrument":
80
+ return f"INST={event.value} "
81
+ case "Bar-Start":
82
+ return "BAR_START "
83
+ case "Bar-End":
84
+ return "BAR_END "
85
+ case "Time-Shift":
86
+ return f"TIME_SHIFT={event.value} "
87
+ case "Note-On":
88
+ return f"NOTE_ON={event.value} "
89
+ case "Note-Off":
90
+ return f"NOTE_OFF={event.value} "
91
+ case _:
92
+ return ""
93
+
94
+
95
+ def get_event(text, value=None):
96
+ match text:
97
+ case "PIECE_START":
98
+ return Event("Piece-Start", value)
99
+ case "TRACK_START":
100
+ return None
101
+ case "TRACK_END":
102
+ return None
103
+ case "INST":
104
+ return Event("Instrument", value)
105
+ case "BAR_START":
106
+ return Event("Bar-Start", value)
107
+ case "BAR_END":
108
+ return Event("Bar-End", value)
109
+ case "TIME_SHIFT":
110
+ return Event("Time-Shift", value)
111
+ case "TIME_DELTA":
112
+ return Event("Time-Shift", to_beat_str(int(value) / 4))
113
+ case "NOTE_ON":
114
+ return Event("Note-On", value)
115
+ case "NOTE_OFF":
116
+ return Event("Note-Off", value)
117
+ case _:
118
+ return None
119
+
120
+
121
+ # TODO: Make this singleton
122
+ def get_miditok():
123
+ pitch_range = range(0, 140) # was (21, 109)
124
+ beat_res = {(0, 400): 8}
125
+ return MIDILike(pitch_range, beat_res)
126
+
127
+
128
+ class WriteTextMidiToFile: # utils saving to file
129
+ def __init__(self, generate_midi, output_path):
130
+ self.generated_midi = generate_midi.generated_piece
131
+ self.output_path = output_path
132
+ self.hyperparameter_and_bars = generate_midi.piece_by_track
133
+
134
+ def hashing_seq(self):
135
+ self.current_time = get_datetime()
136
+ self.output_path_filename = f"{self.output_path}/{self.current_time}.json"
137
+
138
+ def wrapping_seq_hyperparameters_in_dict(self):
139
+ # assert type(self.generated_midi) is str, "error: generate_midi must be a string"
140
+ # assert (
141
+ # type(self.hyperparameter_dict) is dict
142
+ # ), "error: feature_dict must be a dictionnary"
143
+ return {
144
+ "generate_midi": self.generated_midi,
145
+ "hyperparameters_and_bars": self.hyperparameter_and_bars,
146
+ }
147
+
148
+ def text_midi_to_file(self):
149
+ self.hashing_seq()
150
+ output_dict = self.wrapping_seq_hyperparameters_in_dict()
151
+ print(f"Token generate_midi written: {self.output_path_filename}")
152
+ writeToFile(self.output_path_filename, output_dict)
153
+ return self.output_path_filename
154
+
155
+
156
+ def get_files(directory, extension, recursive=False):
157
+ """
158
+ Given a directory, get a list of the file paths of all files matching the
159
+ specified file extension.
160
+ directory: the directory to search as a Path object
161
+ extension: the file extension to match as a string
162
+ recursive: whether to search recursively in the directory or not
163
+ """
164
+ if recursive:
165
+ return list(directory.rglob(f"*.{extension}"))
166
+ else:
167
+ return list(directory.glob(f"*.{extension}"))
168
+
169
+
170
+ def timeit(func):
171
+ def wrapper(*args, **kwargs):
172
+ start = perf_counter()
173
+ result = func(*args, **kwargs)
174
+ end = perf_counter()
175
+ print(f"{func.__name__} took {end - start:.2f} seconds to run.")
176
+ return result
177
+
178
+ return wrapper
179
+
180
+
181
+ class FileCompressor:
182
+ def __init__(self, input_directory, output_directory, n_jobs=-1):
183
+ self.input_directory = input_directory
184
+ self.output_directory = output_directory
185
+ self.n_jobs = n_jobs
186
+
187
+ # File compression and decompression
188
+ def unzip_file(self, file):
189
+ """uncompress single zip file"""
190
+ with ZipFile(file, "r") as zip_ref:
191
+ zip_ref.extractall(self.output_directory)
192
+
193
+ def zip_file(self, file):
194
+ """compress a single text file to a new zip file and delete the original"""
195
+ output_file = self.output_directory / (file.stem + ".zip")
196
+ with ZipFile(output_file, "w") as zip_ref:
197
+ zip_ref.write(file, arcname=file.name, compress_type=ZIP_DEFLATED)
198
+ file.unlink()
199
+
200
+ @timeit
201
+ def unzip(self):
202
+ """uncompress all zip files in folder"""
203
+ files = get_files(self.input_directory, extension="zip")
204
+ Parallel(n_jobs=self.n_jobs)(delayed(self.unzip_file)(file) for file in files)
205
+
206
+ @timeit
207
+ def zip(self):
208
+ """compress all text files in folder to new zip files and remove the text files"""
209
+ files = get_files(self.output_directory, extension="txt")
210
+ Parallel(n_jobs=self.n_jobs)(delayed(self.zip_file)(file) for file in files)
211
+
212
+
213
+ def load_jsonl(filepath):
214
+ """Load a jsonl file"""
215
+ with open(filepath, "r") as f:
216
+ data = [json.loads(line) for line in f]
217
+ return data
218
+
219
+
220
+ def write_mp3(waveform, output_path, bitrate="92k"):
221
+ """
222
+ Write a waveform to an mp3 file.
223
+ output_path: Path object for the output mp3 file
224
+ waveform: numpy array of the waveform
225
+ bitrate: bitrate of the mp3 file (64k, 92k, 128k, 256k, 312k)
226
+ """
227
+ # write the wav file
228
+ wav_path = output_path.with_suffix(".wav")
229
+ write(wav_path, 44100, waveform.astype(np.float32))
230
+ # compress the wav file as mp3
231
+ AudioSegment.from_wav(wav_path).export(output_path, format="mp3", bitrate=bitrate)
232
+ # remove the wav file
233
+ wav_path.unlink()
234
+
235
+
236
+ def copy_file(input_file, output_dir):
237
+ """Copy an input file to the output_dir"""
238
+ output_file = output_dir / input_file.name
239
+ shutil.copy(input_file, output_file)
240
+
241
+
242
+ def index_has_substring(list, substring):
243
+ for i, s in enumerate(list):
244
+ if substring in s:
245
+ return i
246
+ return -1