Commit 9118de8
Duplicate from JammyMachina/the-jam-machine-app
Co-authored-by: Halid Bayram <[email protected]>
- .gitattributes +34 -0
- .gitignore +1 -0
- .vscode/launch.json +16 -0
- .vscode/settings.json +3 -0
- README.md +26 -0
- constants.py +121 -0
- decoder.py +197 -0
- familizer.py +137 -0
- generate.py +486 -0
- generation_utils.py +161 -0
- load.py +63 -0
- packages.txt +1 -0
- playback.py +35 -0
- playground.py +195 -0
- requirements.txt +17 -0
- utils.py +246 -0
.gitattributes
ADDED
@@ -0,0 +1,34 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text

.gitignore
ADDED
@@ -0,0 +1 @@
__pycache__

.vscode/launch.json
ADDED
@@ -0,0 +1,16 @@
{
    // Use IntelliSense to learn about possible attributes.
    // Hover to view descriptions of existing attributes.
    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        {
            "name": "playground.py",
            "type": "python",
            "request": "launch",
            "program": "playground.py",
            "console": "integratedTerminal",
            "justMyCode": false
        }
    ]
}

.vscode/settings.json
ADDED
@@ -0,0 +1,3 @@
{
    "python.formatting.provider": "black"
}

README.md
ADDED
@@ -0,0 +1,26 @@
---
title: The Jam Machine
emoji: 🎶
colorFrom: darkblue
colorTo: black
sdk: gradio
sdk_version: 3.13.1
python_version: 3.10.6
app_file: playground.py
pinned: true
duplicated_from: JammyMachina/the-jam-machine-app
---

[Presentation](pitch.com/public/417162a8-88b0-4472-a651-c66bb89428be)

## Contributors:

### Jean Simonnet:
- [Github](https://github.com/misnaej)
- [Linkedin](https://www.linkedin.com/in/jeansimonnet/)

### Louis Demetz:
- [Github](https://github.com/louis-demetz)
- [Linkedin](https://www.linkedin.com/in/ldemetz/)

### Halid Bayram:
- [Github](https://github.com/m41w4r3exe)
- [Linkedin](https://www.linkedin.com/in/halid-bayram-6b9ba861/)

constants.py
ADDED
@@ -0,0 +1,121 @@
# fmt: off
# Instrument mapping and mapping functions
INSTRUMENT_CLASSES = [
    {"name": "Piano", "program_range": range(0, 8), "family_number": 0},
    {"name": "Chromatic Percussion", "program_range": range(8, 16), "family_number": 1},
    {"name": "Organ", "program_range": range(16, 24), "family_number": 2},
    {"name": "Guitar", "program_range": range(24, 32), "family_number": 3},
    {"name": "Bass", "program_range": range(32, 40), "family_number": 4},
    {"name": "Strings", "program_range": range(40, 48), "family_number": 5},
    {"name": "Ensemble", "program_range": range(48, 56), "family_number": 6},
    {"name": "Brass", "program_range": range(56, 64), "family_number": 7},
    {"name": "Reed", "program_range": range(64, 72), "family_number": 8},
    {"name": "Pipe", "program_range": range(72, 80), "family_number": 9},
    {"name": "Synth Lead", "program_range": range(80, 88), "family_number": 10},
    {"name": "Synth Pad", "program_range": range(88, 96), "family_number": 11},
    {"name": "Synth Effects", "program_range": range(96, 104), "family_number": 12},
    {"name": "Ethnic", "program_range": range(104, 112), "family_number": 13},
    {"name": "Percussive", "program_range": range(112, 120), "family_number": 14},
    {"name": "Sound Effects", "program_range": range(120, 128), "family_number": 15},
]
# fmt: on

# Instrument mapping for decoding our midi sequence into midi instruments of our choice
INSTRUMENT_TRANSFER_CLASSES = [
    {
        "name": "Piano",
        "program_range": [4],
        "family_number": 0,
        "transfer_to": "Electric Piano 1",
    },
    {
        "name": "Chromatic Percussion",
        "program_range": [11],
        "family_number": 1,
        "transfer_to": "Vibraphone",
    },
    {
        "name": "Organ",
        "program_range": [17],
        "family_number": 2,
        "transfer_to": "Percussive Organ",
    },
    {
        "name": "Guitar",
        "program_range": [80],
        "family_number": 3,
        "transfer_to": "Synth Lead Square",
    },
    {
        "name": "Bass",
        "program_range": [38],
        "family_number": 4,
        "transfer_to": "Synth Bass 1",
    },
    {
        "name": "Strings",
        "program_range": [50],
        "family_number": 5,
        "transfer_to": "Synth Strings 1",
    },
    {
        "name": "Ensemble",
        "program_range": [51],
        "family_number": 6,
        "transfer_to": "Synth Strings 2",
    },
    {
        "name": "Brass",
        "program_range": [63],
        "family_number": 7,
        "transfer_to": "Synth Brass 1",
    },
    {
        "name": "Reed",
        "program_range": [64],
        "family_number": 8,
        "transfer_to": "Synth Brass 2",
    },
    {
        "name": "Pipe",
        "program_range": [82],
        "family_number": 9,
        "transfer_to": "Synth Lead Calliope",
    },
    {
        "name": "Synth Lead",
        "program_range": [81],  # Synth Lead Sawtooth
        "family_number": 10,
        "transfer_to": "Synth Lead Sawtooth",
    },
    {
        "name": "Synth Pad",
        "program_range": range(88, 96),
        "family_number": 11,
        "transfer_to": "Synth Pad",
    },
    {
        "name": "Synth Effects",
        "program_range": range(96, 104),
        "family_number": 12,
        "transfer_to": "Synth Effects",
    },
    {
        "name": "Ethnic",
        "program_range": range(104, 112),
        "family_number": 13,
        "transfer_to": "Ethnic",
    },
    {
        "name": "Percussive",
        "program_range": range(112, 120),
        "family_number": 14,
        "transfer_to": "Percussive",
    },
    {
        "name": "Sound Effects",
        "program_range": range(120, 128),
        "family_number": 15,
        "transfer_to": "Sound Effects",
    },
]

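A minimal sketch of how this mapping is typically consumed. The helper below is illustrative only (the repository's own lookup lives in familizer.py); it simply scans the table above for the entry whose program_range contains a given General MIDI program number.

from constants import INSTRUMENT_CLASSES

def family_for_program(program_number):
    # Illustrative helper: return the family entry whose program_range contains the program number.
    for instrument_class in INSTRUMENT_CLASSES:
        if program_number in instrument_class["program_range"]:
            return instrument_class
    raise ValueError(f"program number {program_number} is outside 0-127")

print(family_for_program(30)["family_number"])  # 3 -> the "Guitar" family
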
decoder.py
ADDED
@@ -0,0 +1,197 @@
from utils import *
from familizer import Familizer
from miditok import Event


class TextDecoder:
    """Decodes text into:
    1- List of events
    2- Then converts these events to midi file via MidiTok and miditoolkit

    :param tokenizer: from MidiTok

    Usage with get_midi method:
        args: text(String) example -> PIECE_START TRACK_START INST=25 DENSITY=2 BAR_START NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50...BAR_END TRACK_END
        returns: midi file from miditoolkit
    """

    def __init__(self, tokenizer, familized=True):
        self.tokenizer = tokenizer
        self.familized = familized

    def decode(self, text):
        r"""converts from text to instrument events
        Args:
            text (String): example -> PIECE_START TRACK_START INST=25 DENSITY=2 BAR_START NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50...BAR_END TRACK_END

        Returns:
            Dict{inst_id: List[Events]}: List of events of Notes with velocities, aggregated Timeshifts, for each instrument
        """
        piece_events = self.text_to_events(text)
        inst_events = self.piece_to_inst_events(piece_events)
        events = self.add_timeshifts_for_empty_bars(inst_events)
        events = self.aggregate_timeshifts(events)
        events = self.add_velocity(events)
        return events

    def tokenize(self, events):
        r"""converts from events to MidiTok tokens
        Args:
            events (Dict{inst_id: List[Events]}): List of events for each instrument

        Returns:
            List[List[Events]]: List of tokens for each instrument
        """
        tokens = []
        for inst in events.keys():
            tokens.append(self.tokenizer.events_to_tokens(events[inst]))
        return tokens

    def get_midi(self, text, filename=None):
        r"""converts from text to midi
        Args:
            text (String): example -> PIECE_START TRACK_START INST=25 DENSITY=2 BAR_START NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50...BAR_END TRACK_END

        Returns:
            miditoolkit midi: Returns and writes to midi
        """
        events = self.decode(text)
        tokens = self.tokenize(events)
        instruments = self.get_instruments_tuple(events)
        midi = self.tokenizer.tokens_to_midi(tokens, instruments)

        if filename is not None:
            midi.dump(f"{filename}")
            print(f"midi file written: {filename}")

        return midi

    @staticmethod
    def text_to_events(text):
        events = []
        for word in text.split(" "):
            # TODO: Handle bar and track values with a counter
            _event = word.split("=")
            value = _event[1] if len(_event) > 1 else None
            event = get_event(_event[0], value)
            if event:
                events.append(event)
        return events

    @staticmethod
    def piece_to_inst_events(piece_events):
        """Converts piece events of 8 bars to instrument events for entire song

        Args:
            piece_events (List[Events]): List of events of Notes, Timeshifts, Bars, Tracks

        Returns:
            Dict{inst_id: List[Events]}: List of events for each instrument

        """
        inst_events = {}
        current_instrument = -1
        for event in piece_events:
            if event.type == "Instrument":
                current_instrument = event.value
                if current_instrument not in inst_events:
                    inst_events[current_instrument] = []
            elif current_instrument != -1:
                inst_events[current_instrument].append(event)
        return inst_events

    @staticmethod
    def add_timeshifts_for_empty_bars(inst_events):
        """Adds time shift events instead of consecutive [BAR_START BAR_END] events"""
        new_inst_events = {}
        for inst, events in inst_events.items():
            new_inst_events[inst] = []
            for index, event in enumerate(events):
                if event.type == "Bar-End" or event.type == "Bar-Start":
                    if events[index - 1].type == "Bar-Start":
                        new_inst_events[inst].append(Event("Time-Shift", "4.0.8"))
                else:
                    new_inst_events[inst].append(event)
        return new_inst_events

    @staticmethod
    def add_timeshifts(beat_values1, beat_values2):
        """Adds two beat values

        Args:
            beat_values1 (String): like 0.3.8
            beat_values2 (String): like 1.7.8

        Returns:
            beat_str (String): added beats like 2.2.8 for example values
        """
        value1 = to_base10(beat_values1)
        value2 = to_base10(beat_values2)
        return to_beat_str(value1 + value2)

    def aggregate_timeshifts(self, events):
        """Aggregates consecutive time shift events bigger than a bar
        -> like Timeshift 4.0.8

        Args:
            events (_type_): _description_

        Returns:
            _type_: _description_
        """
        new_events = {}
        for inst, events in events.items():
            inst_events = []
            for i, event in enumerate(events):
                if (
                    event.type == "Time-Shift"
                    and len(inst_events) > 0
                    and inst_events[-1].type == "Time-Shift"
                ):
                    inst_events[-1].value = self.add_timeshifts(
                        inst_events[-1].value, event.value
                    )
                else:
                    inst_events.append(event)
            new_events[inst] = inst_events
        return new_events

    @staticmethod
    def add_velocity(events):
        """Adds default velocity 99 to note events since they are removed from text, needed to generate midi"""
        new_events = {}
        for inst, events in events.items():
            inst_events = []
            for event in events:
                inst_events.append(event)
                if event.type == "Note-On":
                    inst_events.append(Event("Velocity", 99))
            new_events[inst] = inst_events
        return new_events

    def get_instruments_tuple(self, events):
        """Returns instruments tuple for midi generation"""
        instruments = []
        for inst in events.keys():
            is_drum = 0
            if inst == "DRUMS":
                inst = 0
                is_drum = 1
            if self.familized:
                inst = Familizer(arbitrary=True).get_program_number(int(inst))
            instruments.append((int(inst), is_drum))
        return tuple(instruments)


if __name__ == "__main__":

    filename = "midi/generated/misnaej/the-jam-machine-elec-famil/20221209_175750"
    encoded_json = readFromFile(
        f"{filename}.json",
        True,
    )
    encoded_text = encoded_json["sequence"]
    # encoded_text = "PIECE_START TRACK_START INST=25 DENSITY=2 BAR_START NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=67 NOTE_ON=64 TIME_DELTA=1 NOTE_OFF=67 NOTE_OFF=64 BAR_END BAR_START NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 BAR_END BAR_START NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_OFF=50 NOTE_ON=67 NOTE_ON=64 TIME_DELTA=1 NOTE_OFF=67 NOTE_OFF=64 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 BAR_END BAR_START NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_OFF=50 NOTE_ON=67 NOTE_ON=64 TIME_DELTA=1 NOTE_OFF=67 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 BAR_END BAR_START NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=69 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=69 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=57 TIME_DELTA=1 NOTE_OFF=57 NOTE_ON=56 TIME_DELTA=1 NOTE_OFF=56 NOTE_ON=64 NOTE_ON=60 NOTE_ON=55 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=55 BAR_END BAR_START NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=66 NOTE_ON=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=67 NOTE_ON=64 TIME_DELTA=1 NOTE_OFF=67 NOTE_OFF=64 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=59 NOTE_ON=55 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=59 NOTE_OFF=50 NOTE_OFF=55 NOTE_OFF=50 BAR_END BAR_START BAR_END TRACK_END"

    miditok = get_miditok()
    TextDecoder(miditok).get_midi(encoded_text, filename=filename)

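A minimal usage sketch for TextDecoder on a hand-written one-bar sequence rather than a generated JSON file. It assumes get_miditok from this commit's utils.py, uses a placeholder output path, and passes familized=False so that INST=25 is treated as a raw program number instead of a family token.

from utils import get_miditok
from decoder import TextDecoder

# One short bar in the token format shown in the class docstring above.
encoded_text = (
    "PIECE_START TRACK_START INST=25 DENSITY=2 "
    "BAR_START NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 BAR_END TRACK_END"
)
decoder = TextDecoder(get_miditok(), familized=False)
decoder.get_midi(encoded_text, filename="decoded_example.mid")  # placeholder output path
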
familizer.py
ADDED
@@ -0,0 +1,137 @@
import random
from joblib import Parallel, delayed
from pathlib import Path
from constants import INSTRUMENT_CLASSES, INSTRUMENT_TRANSFER_CLASSES
from utils import get_files, timeit, FileCompressor


class Familizer:
    def __init__(self, n_jobs=-1, arbitrary=False):
        self.n_jobs = n_jobs
        self.reverse_family(arbitrary)

    def get_family_number(self, program_number):
        """
        Given a MIDI instrument number, return its associated instrument family number.
        """
        for instrument_class in INSTRUMENT_CLASSES:
            if program_number in instrument_class["program_range"]:
                return instrument_class["family_number"]

    def reverse_family(self, arbitrary):
        """
        Create a dictionary of family numbers to randomly assigned program numbers.
        This is used to reverse the family number tokens back to program number tokens.
        """

        if arbitrary is True:
            int_class = INSTRUMENT_TRANSFER_CLASSES
        else:
            int_class = INSTRUMENT_CLASSES

        self.reference_programs = {}
        for family in int_class:
            self.reference_programs[family["family_number"]] = random.choice(
                family["program_range"]
            )

    def get_program_number(self, family_number):
        """
        Given a family number, return a random program number in the respective program_range.
        This is the reverse operation of get_family_number.
        """
        assert family_number in self.reference_programs
        return self.reference_programs[family_number]

    # Replace instruments in text files
    def replace_instrument_token(self, token):
        """
        Given a MIDI program number in a word token, replace it with the family or program
        number token depending on the operation.
        e.g. INST=86 -> INST=10
        """
        inst_number = int(token.split("=")[1])
        if self.operation == "family":
            return "INST=" + str(self.get_family_number(inst_number))
        elif self.operation == "program":
            return "INST=" + str(self.get_program_number(inst_number))

    def replace_instrument_in_text(self, text):
        """Given a text piece, replace all instrument tokens with family number tokens."""
        return " ".join(
            [
                self.replace_instrument_token(token)
                if token.startswith("INST=") and not token == "INST=DRUMS"
                else token
                for token in text.split(" ")
            ]
        )

    def replace_instruments_in_file(self, file):
        """Given a text file, replace all instrument tokens with family number tokens."""
        text = file.read_text()
        file.write_text(self.replace_instrument_in_text(text))

    @timeit
    def replace_instruments(self):
        """
        Given a directory of text files:
        Replace all instrument tokens with family number tokens.
        """
        files = get_files(self.output_directory, extension="txt")
        Parallel(n_jobs=self.n_jobs)(
            delayed(self.replace_instruments_in_file)(file) for file in files
        )

    def replace_tokens(self, input_directory, output_directory, operation):
        """
        Given a directory and an operation, perform the operation on all text files in the directory.
        operation can be either 'family' or 'program'.
        """
        self.input_directory = input_directory
        self.output_directory = output_directory
        self.operation = operation

        # Uncompress files, replace tokens, compress files
        fc = FileCompressor(self.input_directory, self.output_directory, self.n_jobs)
        fc.unzip()
        self.replace_instruments()
        fc.zip()
        print(self.operation + " complete.")

    def to_family(self, input_directory, output_directory):
        """
        Given a directory containing zip files, replace all instrument tokens with
        family number tokens. The output is a directory of zip files.
        """
        self.replace_tokens(input_directory, output_directory, "family")

    def to_program(self, input_directory, output_directory):
        """
        Given a directory containing zip files, replace all instrument tokens with
        program number tokens. The output is a directory of zip files.
        """
        self.replace_tokens(input_directory, output_directory, "program")


if __name__ == "__main__":

    # Choose number of jobs for parallel processing
    n_jobs = -1

    # Instantiate Familizer
    familizer = Familizer(n_jobs)

    # Choose directory to process for program
    input_directory = Path("midi/dataset/first_selection/validate").resolve()  # fmt: skip
    output_directory = input_directory / "family"

    # familize files
    familizer.to_family(input_directory, output_directory)

    # Choose directory to process for family
    # input_directory = Path("../data/music_picks/encoded_samples/validate/family").resolve() # fmt: skip
    # output_directory = input_directory.parent / "program"

    # # programize files
    # familizer.to_program(input_directory, output_directory)

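A small sketch of the family/program round trip on a single program number. With arbitrary=True the reverse mapping draws from INSTRUMENT_TRANSFER_CLASSES, so family 3 always maps back to program 80.

from familizer import Familizer

familizer = Familizer(arbitrary=True)
family = familizer.get_family_number(30)        # program 30 falls in range(24, 32) -> family 3 (Guitar)
program = familizer.get_program_number(family)  # family 3 -> 80 with the arbitrary transfer mapping
print(family, program)                          # 3 80
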
generate.py
ADDED
@@ -0,0 +1,486 @@
from generation_utils import *
from utils import WriteTextMidiToFile, get_miditok
from load import LoadModel
from decoder import TextDecoder
from playback import get_music


class GenerateMidiText:
    """Generating music with Class

    LOGIC:

    FOR GENERATING FROM SCRATCH:
    - self.generate_one_new_track()
    it calls
        - self.generate_until_track_end()

    FOR GENERATING NEW BARS:
    - self.generate_one_more_bar()
    it calls
        - self.process_prompt_for_next_bar()
        - self.generate_until_track_end()"""

    def __init__(self, model, tokenizer, piece_by_track=None):
        self.model = model
        self.tokenizer = tokenizer
        # default initialization
        self.initialize_default_parameters()
        self.initialize_dictionaries([] if piece_by_track is None else piece_by_track)

    """Setters"""

    def initialize_default_parameters(self):
        self.set_device()
        self.set_attention_length()
        self.generate_until = "TRACK_END"
        self.set_force_sequence_lenth()
        self.set_nb_bars_generated()
        self.set_improvisation_level(0)

    def initialize_dictionaries(self, piece_by_track):
        self.piece_by_track = piece_by_track

    def set_device(self, device="cpu"):
        self.device = device

    def set_attention_length(self):
        self.max_length = self.model.config.n_positions
        print(
            f"Attention length set to {self.max_length} -> 'model.config.n_positions'"
        )

    def set_force_sequence_lenth(self, force_sequence_length=True):
        self.force_sequence_length = force_sequence_length

    def set_improvisation_level(self, improvisation_value):
        self.no_repeat_ngram_size = improvisation_value
        print("--------------------")
        print(f"no_repeat_ngram_size set to {improvisation_value}")
        print("--------------------")

    def reset_temperatures(self, track_id, temperature):
        self.piece_by_track[track_id]["temperature"] = temperature

    def set_nb_bars_generated(self, n_bars=8):  # default is an 8 bar model
        self.model_n_bar = n_bars

    """ Generation Tools - Dictionaries """

    def initiate_track_dict(self, instr, density, temperature):
        label = len(self.piece_by_track)
        self.piece_by_track.append(
            {
                "label": f"track_{label}",
                "instrument": instr,
                "density": density,
                "temperature": temperature,
                "bars": [],
            }
        )

    def update_track_dict__add_bars(self, bars, track_id):
        """Add bars to the track dictionary"""
        for bar in self.striping_track_ends(bars).split("BAR_START "):
            if bar == "":  # happens if there is one bar only
                continue
            else:
                if "TRACK_START" in bar:
                    self.piece_by_track[track_id]["bars"].append(bar)
                else:
                    self.piece_by_track[track_id]["bars"].append("BAR_START " + bar)

    def get_all_instr_bars(self, track_id):
        return self.piece_by_track[track_id]["bars"]

    def striping_track_ends(self, text):
        if "TRACK_END" in text:
            # first get rid of extra space if any
            # then gets rid of "TRACK_END"
            text = text.rstrip(" ").rstrip("TRACK_END")
        return text

    def get_last_generated_track(self, full_piece):
        track = (
            "TRACK_START "
            + self.striping_track_ends(full_piece.split("TRACK_START ")[-1])
            + "TRACK_END "
        )  # forcing the space after track and
        return track

    def get_selected_track_as_text(self, track_id):
        text = ""
        for bar in self.piece_by_track[track_id]["bars"]:
            text += bar
        text += "TRACK_END "
        return text

    @staticmethod
    def get_newly_generated_text(input_prompt, full_piece):
        return full_piece[len(input_prompt) :]

    def get_whole_piece_from_bar_dict(self):
        text = "PIECE_START "
        for track_id, _ in enumerate(self.piece_by_track):
            text += self.get_selected_track_as_text(track_id)
        return text

    def delete_one_track(self, track):  # TO BE TESTED
        self.piece_by_track.pop(track)

    # def update_piece_dict__add_track(self, track_id, track):
    #     self.piece_dict[track_id] = track

    # def update_all_dictionnaries__add_track(self, track):
    #     self.update_piece_dict__add_track(track_id, track)

    """Basic generation tools"""

    def tokenize_input_prompt(self, input_prompt, verbose=True):
        """Tokenizing prompt

        Args:
        - input_prompt (str): prompt to tokenize

        Returns:
        - input_prompt_ids (torch.tensor): tokenized prompt
        """
        if verbose:
            print("Tokenizing input_prompt...")

        return self.tokenizer.encode(input_prompt, return_tensors="pt")

    def generate_sequence_of_token_ids(
        self,
        input_prompt_ids,
        temperature,
        verbose=True,
    ):
        """
        generate a sequence of token ids based on input_prompt_ids
        The sequence length depends on the trained model (self.model_n_bar)
        """
        generated_ids = self.model.generate(
            input_prompt_ids,
            max_length=self.max_length,
            do_sample=True,
            temperature=temperature,
            no_repeat_ngram_size=self.no_repeat_ngram_size,  # default = 0
            eos_token_id=self.tokenizer.encode(self.generate_until)[0],  # good
        )

        if verbose:
            print("Generating a token_id sequence...")

        return generated_ids

    def convert_ids_to_text(self, generated_ids, verbose=True):
        """converts the token_ids to text"""
        generated_text = self.tokenizer.decode(generated_ids[0])
        if verbose:
            print("Converting token sequence to MidiText...")
        return generated_text

    def generate_until_track_end(
        self,
        input_prompt="PIECE_START ",
        instrument=None,
        density=None,
        temperature=None,
        verbose=True,
        expected_length=None,
    ):

        """generate until the TRACK_END token is reached
        full_piece = input_prompt + generated"""
        if expected_length is None:
            expected_length = self.model_n_bar

        if instrument is not None:
            input_prompt = f"{input_prompt}TRACK_START INST={str(instrument)} "
            if density is not None:
                input_prompt = f"{input_prompt}DENSITY={str(density)} "

        if instrument is None and density is not None:
            print("Density cannot be defined without an input_prompt instrument #TOFIX")

        if temperature is None:
            raise ValueError("Temperature must be defined")

        if verbose:
            print("--------------------")
            print(
                f"Generating {instrument} - Density {density} - temperature {temperature}"
            )
        bar_count_checks = False
        failed = 0
        while not bar_count_checks:  # regenerate until right length
            input_prompt_ids = self.tokenize_input_prompt(input_prompt, verbose=verbose)
            generated_tokens = self.generate_sequence_of_token_ids(
                input_prompt_ids, temperature, verbose=verbose
            )
            full_piece = self.convert_ids_to_text(generated_tokens, verbose=verbose)
            generated = self.get_newly_generated_text(input_prompt, full_piece)
            # bar_count_checks
            bar_count_checks, bar_count = bar_count_check(generated, expected_length)

            if not self.force_sequence_length:
                # set bar_count_checks to true to exit the while loop
                bar_count_checks = True

            if not bar_count_checks and self.force_sequence_length:
                # if the generated sequence is not the expected length
                if failed > -1:  # deactivated for speed
                    full_piece, bar_count_checks = forcing_bar_count(
                        input_prompt,
                        generated,
                        bar_count,
                        expected_length,
                    )
                else:
                    print("--- Wrong length - Regenerating ---")
                    if not bar_count_checks:
                        failed += 1
                        if failed > 2:
                            bar_count_checks = True  # TOFIX exit the while loop

        return full_piece

    def generate_one_new_track(
        self,
        instrument,
        density,
        temperature,
        input_prompt="PIECE_START ",
    ):
        self.initiate_track_dict(instrument, density, temperature)
        full_piece = self.generate_until_track_end(
            input_prompt=input_prompt,
            instrument=instrument,
            density=density,
            temperature=temperature,
        )

        track = self.get_last_generated_track(full_piece)
        self.update_track_dict__add_bars(track, -1)
        full_piece = self.get_whole_piece_from_bar_dict()
        return full_piece

    """ Piece generation - Basics """

    def generate_piece(self, instrument_list, density_list, temperature_list):
        """generate a sequence with multiple tracks
        - inst_list sets the list of instruments of the order of generation
        - density is paired with inst_list
        Each track/instrument is generated on a prompt which contains the previously generated track/instrument
        This means that the first instrument is generated with less bias than the next one, and so on.

        'generated_piece' keeps track of the entire piece
        'generated_piece' is returned by self.generate_until_track_end"""

        generated_piece = "PIECE_START "
        for instrument, density, temperature in zip(
            instrument_list, density_list, temperature_list
        ):
            generated_piece = self.generate_one_new_track(
                instrument,
                density,
                temperature,
                input_prompt=generated_piece,
            )

        # generated_piece = self.get_whole_piece_from_bar_dict()
        self.check_the_piece_for_errors()
        return generated_piece

    """ Piece generation - Extra Bars """

    def process_prompt_for_next_bar(self, track_idx):
        """Processing the prompt for the model to generate one more bar only.
        The prompt contains:
            if not the first bar: the previous, already processed, bars of the track
            the bar initialization (ex: "TRACK_START INST=DRUMS DENSITY=2 ")
            the last (self.model_n_bar)-1 bars of the track
        Args:
            track_idx (int): the index of the track to be processed

        Returns:
            the processed prompt for generating the next bar
        """
        track = self.piece_by_track[track_idx]
        # for bars which are not the bar to prolong
        pre_promt = "PIECE_START "
        for i, othertrack in enumerate(self.piece_by_track):
            if i != track_idx:
                len_diff = len(othertrack["bars"]) - len(track["bars"])
                if len_diff > 0:
                    # if other bars are longer, it means that this one should catch up
                    pre_promt += othertrack["bars"][0]
                    for bar in track["bars"][-self.model_n_bar :]:
                        pre_promt += bar
                    pre_promt += "TRACK_END "
                elif False:  # len_diff <= 0:  # THIS GENERATES EMPTINESS
                    # adding empty bars at the end of the other tracks if they have not been processed yet
                    pre_promt += othertrack["bars"][0]
                    for bar in track["bars"][-(self.model_n_bar - 1) :]:
                        pre_promt += bar
                    for _ in range(abs(len_diff) + 1):
                        pre_promt += "BAR_START BAR_END "
                    pre_promt += "TRACK_END "

        # for the bar to prolong
        # initialization e.g TRACK_START INST=DRUMS DENSITY=2
        processed_prompt = track["bars"][0]
        for bar in track["bars"][-(self.model_n_bar - 1) :]:
            # adding the "last" bars of the track
            processed_prompt += bar

        processed_prompt += "BAR_START "
        print(
            f"--- prompt length = {len((pre_promt + processed_prompt).split(' '))} ---"
        )
        return pre_promt + processed_prompt

    def generate_one_more_bar(self, i):
        """Generate one more bar from the input_prompt"""
        processed_prompt = self.process_prompt_for_next_bar(i)
        prompt_plus_bar = self.generate_until_track_end(
            input_prompt=processed_prompt,
            temperature=self.piece_by_track[i]["temperature"],
            expected_length=1,
            verbose=False,
        )
        added_bar = self.get_newly_generated_bar(prompt_plus_bar)
        self.update_track_dict__add_bars(added_bar, i)

    def get_newly_generated_bar(self, prompt_plus_bar):
        return "BAR_START " + self.striping_track_ends(
            prompt_plus_bar.split("BAR_START ")[-1]
        )

    def generate_n_more_bars(self, n_bars, only_this_track=None, verbose=True):
        """Generate n more bars from the input_prompt"""
        if only_this_track is None:
            pass  # no specific track selected: add bars to every track

        print(f"================== ")
        print(f"Adding {n_bars} more bars to the piece ")
        for bar_id in range(n_bars):
            print(f"----- added bar #{bar_id+1} --")
            for i, track in enumerate(self.piece_by_track):
                if only_this_track is None or i == only_this_track:
                    print(f"--------- {track['label']}")
                    self.generate_one_more_bar(i)
        self.check_the_piece_for_errors()

    def check_the_piece_for_errors(self, piece: str = None):

        if piece is None:
            piece = self.get_whole_piece_from_bar_dict()
        errors = [
            (token, id)
            for id, token in enumerate(piece.split(" "))
            if token not in self.tokenizer.vocab or token == "UNK"
        ]
        if len(errors) > 0:
            # print(piece)
            for er in errors:
                print(f"Token not found in the piece at {er[1]}: {er[0]}")
                print(piece.split(" ")[er[1] - 5 : er[1] + 5])


if __name__ == "__main__":

    # worker
    DEVICE = "cpu"

    # define generation parameters
    N_FILES_TO_GENERATE = 2
    Temperatures_to_try = [0.7]

    USE_FAMILIZED_MODEL = True
    force_sequence_length = True

    if USE_FAMILIZED_MODEL:
        # model_repo = "misnaej/the-jam-machine-elec-famil"
        # model_repo = "misnaej/the-jam-machine-elec-famil-ft32"

        # model_repo = "JammyMachina/elec-gmusic-familized-model-13-12__17-35-53"
        # n_bar_generated = 8

        model_repo = "JammyMachina/improved_4bars-mdl"
        n_bar_generated = 4
        instrument_promt_list = ["4", "DRUMS", "3"]
        # DRUMS = drums, 0 = piano, 1 = chromatic percussion, 2 = organ, 3 = guitar, 4 = bass, 5 = strings, 6 = ensemble, 7 = brass, 8 = reed, 9 = pipe, 10 = synth lead, 11 = synth pad, 12 = synth effects, 13 = ethnic, 14 = percussive, 15 = sound effects
        density_list = [3, 2, 2]
        # temperature_list = [0.7, 0.7, 0.75]
    else:
        model_repo = "misnaej/the-jam-machine"
        instrument_promt_list = ["30"]  # , "DRUMS", "0"]
        density_list = [3]  # , 2, 3]
        # temperature_list = [0.7, 0.5, 0.75]
        pass

    # define generation directory
    generated_sequence_files_path = define_generation_dir(model_repo)

    # load model and tokenizer
    model, tokenizer = LoadModel(
        model_repo, from_huggingface=True
    ).load_model_and_tokenizer()

    # does the prompt make sense
    check_if_prompt_inst_in_tokenizer_vocab(tokenizer, instrument_promt_list)

    for temperature in Temperatures_to_try:
        print(f"================= TEMPERATURE {temperature} =======================")
        for _ in range(N_FILES_TO_GENERATE):
            print(f"========================================")
            # 1 - instantiate
            generate_midi = GenerateMidiText(model, tokenizer)
            # 0 - set the n_bar for this model
            generate_midi.set_nb_bars_generated(n_bars=n_bar_generated)
            # 1 - defines the instruments, densities and temperatures
            # 2 - generate the first 8 bars for each instrument
            generate_midi.set_improvisation_level(30)
            generate_midi.generate_piece(
                instrument_promt_list,
                density_list,
                [temperature for _ in density_list],
            )
            # 3 - force the model to improvise
            # generate_midi.set_improvisation_level(20)
            # 4 - generate the next 4 bars for each instrument
            # generate_midi.generate_n_more_bars(n_bar_generated)
            # 5 - lower the improvisation level
            generate_midi.generated_piece = (
                generate_midi.get_whole_piece_from_bar_dict()
            )

            # print the generated sequence in terminal
            print("=========================================")
            print(generate_midi.generated_piece)
            print("=========================================")

            # write to JSON file
            filename = WriteTextMidiToFile(
                generate_midi,
                generated_sequence_files_path,
            ).text_midi_to_file()

            # decode the sequence to MIDI
            decode_tokenizer = get_miditok()
            TextDecoder(decode_tokenizer, USE_FAMILIZED_MODEL).get_midi(
                generate_midi.generated_piece, filename=filename.split(".")[0] + ".mid"
            )
            inst_midi, mixed_audio = get_music(filename.split(".")[0] + ".mid")
            max_time = get_max_time(inst_midi)
            plot_piano_roll(inst_midi)

            print("Et voilà! Your MIDI file is ready! GO JAM!")

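A condensed sketch of the flow in the __main__ block above, generating a single drum track only. It assumes network access to the JammyMachina/improved_4bars-mdl repo used above and can take a while on CPU.

from load import LoadModel
from generate import GenerateMidiText

model, tokenizer = LoadModel(
    "JammyMachina/improved_4bars-mdl", from_huggingface=True
).load_model_and_tokenizer()

generate_midi = GenerateMidiText(model, tokenizer)
generate_midi.set_nb_bars_generated(n_bars=4)  # this repo is a 4-bar model
piece = generate_midi.generate_piece(["DRUMS"], [2], [0.7])  # instrument, density, temperature
print(piece[:200])
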
generation_utils.py
ADDED
@@ -0,0 +1,161 @@
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib

from constants import INSTRUMENT_CLASSES
from playback import get_music, show_piano_roll

# matplotlib settings
matplotlib.use("Agg")  # for server
matplotlib.rcParams["xtick.major.size"] = 0
matplotlib.rcParams["ytick.major.size"] = 0
matplotlib.rcParams["axes.facecolor"] = "none"
matplotlib.rcParams["axes.edgecolor"] = "grey"


def define_generation_dir(model_repo_path):
    #### to remove later ####
    if model_repo_path == "models/model_2048_fake_wholedataset":
        model_repo_path = "misnaej/the-jam-machine"
    #### to remove later ####
    generated_sequence_files_path = f"midi/generated/{model_repo_path}"
    if not os.path.exists(generated_sequence_files_path):
        os.makedirs(generated_sequence_files_path)
    return generated_sequence_files_path


def bar_count_check(sequence, n_bars):
    """check if the sequence contains the right number of bars"""
    sequence = sequence.split(" ")
    # find occurrences of "BAR_END" in a "sequence"
    # I don't check for "BAR_START" because it is not always included in "sequence"
    # e.g. BAR_START is included in the prompt when generating one more bar
    bar_count = 0
    for seq in sequence:
        if seq == "BAR_END":
            bar_count += 1
    bar_count_matches = bar_count == n_bars
    if not bar_count_matches:
        print(f"Bar count is {bar_count} - but should be {n_bars}")
    return bar_count_matches, bar_count


def print_inst_classes(INSTRUMENT_CLASSES):
    """Print the instrument classes"""
    for classe in INSTRUMENT_CLASSES:
        print(f"{classe}")


def check_if_prompt_inst_in_tokenizer_vocab(tokenizer, inst_prompt_list):
    """Check if the prompt instruments are in the tokenizer vocab"""
    for inst in inst_prompt_list:
        if f"INST={inst}" not in tokenizer.vocab:
            instruments_in_dataset = np.sort(
                [tok.split("=")[-1] for tok in tokenizer.vocab if "INST" in tok]
            )
            print_inst_classes(INSTRUMENT_CLASSES)
            raise ValueError(
                f"""The instrument {inst} is not in the tokenizer vocabulary.
                Available Instruments: {instruments_in_dataset}"""
            )


def forcing_bar_count(input_prompt, generated, bar_count, expected_length):
    """Forcing the generated sequence to have the expected length
    expected_length and bar_count refer to the length of newly_generated_only (without input prompt)"""

    if bar_count - expected_length > 0:  # Cut the sequence if too long
        full_piece = ""
        splited = generated.split("BAR_END ")
        for count, spl in enumerate(splited):
            if count < expected_length:
                full_piece += spl + "BAR_END "

        full_piece += "TRACK_END "
        full_piece = input_prompt + full_piece
        print(f"Generated sequence truncated at {expected_length} bars")
        bar_count_checks = True

    elif bar_count - expected_length < 0:  # Do nothing if the sequence is too short
        full_piece = input_prompt + generated
        bar_count_checks = False
        print(f"--- Generated sequence is too short - Force Regeneration ---")

    return full_piece, bar_count_checks


def get_max_time(inst_midi):
    max_time = 0
    for inst in inst_midi.instruments:
        max_time = max(max_time, inst.get_end_time())
    return max_time


def plot_piano_roll(inst_midi):
    piano_roll_fig = plt.figure(figsize=(25, 3 * len(inst_midi.instruments)))
    piano_roll_fig.tight_layout()
    piano_roll_fig.patch.set_alpha(0)
    inst_count = 0
    beats_per_bar = 4
    sec_per_beat = 0.5
    next_beat = max(inst_midi.get_beats()) + np.diff(inst_midi.get_beats())[0]
    bars_time = np.append(inst_midi.get_beats(), (next_beat))[::beats_per_bar].astype(
        int
    )
    for inst in inst_midi.instruments:
        # hardcoded for now
        if inst.name == "Drums":
            color = "purple"
        elif inst.name == "Synth Bass 1":
            color = "orange"
        else:
            color = "green"

        inst_count += 1
        plt.subplot(len(inst_midi.instruments), 1, inst_count)

        for bar in bars_time:
            plt.axvline(bar, color="grey", linewidth=0.5)
        octaves = np.arange(0, 128, 12)
        for octave in octaves:
            plt.axhline(octave, color="grey", linewidth=0.5)
        plt.yticks(octaves, visible=False)

        p_midi_note_list = inst.notes
        note_time = []
        note_pitch = []
        for note in p_midi_note_list:
            note_time.append([note.start, note.end])
            note_pitch.append([note.pitch, note.pitch])
        note_pitch = np.array(note_pitch)
        note_time = np.array(note_time)

        plt.plot(
            note_time.T,
            note_pitch.T,
            color=color,
            linewidth=4,
            solid_capstyle="butt",
        )
        plt.ylim(0, 128)
        xticks = np.array(bars_time)[:-1]
        plt.tight_layout()
        plt.xlim(min(bars_time), max(bars_time))
        plt.ylim(max([note_pitch.min() - 5, 0]), note_pitch.max() + 5)
        plt.xticks(
            xticks + 0.5 * beats_per_bar * sec_per_beat,
            labels=xticks.argsort() + 1,
            visible=False,
        )
        plt.text(
            0.2,
            note_pitch.max() + 4,
            inst.name,
            fontsize=20,
            color=color,
            horizontalalignment="left",
            verticalalignment="top",
        )

    return piano_roll_fig

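A quick sketch of bar_count_check on a hand-written two-bar sequence. Importing this module also pulls in matplotlib and the playback dependencies, which are assumed to be installed from requirements.txt.

from generation_utils import bar_count_check

sequence = "BAR_START NOTE_ON=36 TIME_DELTA=4 NOTE_OFF=36 BAR_END BAR_START BAR_END TRACK_END"
matches, count = bar_count_check(sequence, n_bars=2)
print(matches, count)  # True 2
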
load.py
ADDED
@@ -0,0 +1,63 @@
from transformers import GPT2LMHeadModel
from transformers import PreTrainedTokenizerFast
import os
import torch


class LoadModel:
    """
    Example usage:

    # if loading model and tokenizer from Huggingface
    model_repo = "misnaej/the-jam-machine"
    model, tokenizer = LoadModel(
        model_repo, from_huggingface=True
    ).load_model_and_tokenizer()

    # if loading model and tokenizer from a local folder
    model_path = "models/model_2048_wholedataset"
    model, tokenizer = LoadModel(
        model_path, from_huggingface=False
    ).load_model_and_tokenizer()

    """

    def __init__(self, path, from_huggingface=True, device="cpu", revision=None):
        # path is either a relative path on a local/remote machine or a model repo on HuggingFace
        if not from_huggingface:
            if not os.path.exists(path):
                print(path)
                raise Exception("Model path does not exist")
        self.from_huggingface = from_huggingface
        self.path = path
        self.device = device
        self.revision = revision
        if torch.cuda.is_available():
            self.device = "cuda"

    def load_model_and_tokenizer(self):
        model = self.load_model()
        tokenizer = self.load_tokenizer()

        return model, tokenizer

    def load_model(self):
        if self.revision is None:
            model = GPT2LMHeadModel.from_pretrained(self.path)  # .to(self.device)
        else:
            model = GPT2LMHeadModel.from_pretrained(
                self.path, revision=self.revision
            )  # .to(self.device)

        return model

    def load_tokenizer(self):
        if self.from_huggingface:
            pass
        else:
            if not os.path.exists(f"{self.path}/tokenizer.json"):
                raise Exception(
                    f"There is no 'tokenizer.json' file in the defined {self.path}"
                )
        tokenizer = PreTrainedTokenizerFast.from_pretrained(self.path)
        return tokenizer
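A short, hedged sketch of how this loader is typically driven (mirroring the class docstring). The repo name is the one used by playground.py further down; note that LoadModel only records self.device and never moves the model, so an explicit .to(...) call is the caller's responsibility if GPU placement is wanted.

import torch
from load import LoadModel

model, tokenizer = LoadModel(
    "JammyMachina/elec-gmusic-familized-model-13-12__17-35-53",  # repo used by playground.py
    from_huggingface=True,
).load_model_and_tokenizer()
# LoadModel stores the device but does not move the model; do it explicitly if needed.
model = model.to("cuda" if torch.cuda.is_available() else "cpu")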
packages.txt
ADDED
@@ -0,0 +1 @@
fluidsynth
playback.py
ADDED
@@ -0,0 +1,35 @@
import matplotlib.pyplot as plt
import librosa.display
from pretty_midi import PrettyMIDI


# Note: these functions are meant to be used within an interactive Python shell
# Please refer to synth.ipynb for an example of how to use them


def get_music(midi_file):
    """
    Load a midi file and return the PrettyMIDI object and the audio signal
    """
    music = PrettyMIDI(midi_file=midi_file)
    waveform = music.fluidsynth()
    return music, waveform


def show_piano_roll(music_notes, fs=100):
    """
    Show the piano roll of a music piece, with all instruments squashed onto a single 128xN matrix
    :param music_notes: PrettyMIDI object
    :param fs: sampling frequency
    """
    # get the piano roll
    piano_roll = music_notes.get_piano_roll(fs)
    print("Piano roll shape: {}".format(piano_roll.shape))

    # plot the piano roll
    plt.figure(figsize=(12, 4))
    librosa.display.specshow(piano_roll, sr=100, x_axis="time", y_axis="cqt_note")
    plt.colorbar()
    plt.title("Piano roll")
    plt.tight_layout()
    plt.show()
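A minimal interactive-shell sketch of these two helpers, assuming fluidsynth (see packages.txt) is installed and some MIDI file such as "mixed.mid" exists; the file paths are hypothetical examples, not part of the commit.

import numpy as np
from scipy.io.wavfile import write
from playback import get_music, show_piano_roll

music, waveform = get_music("mixed.mid")                 # PrettyMIDI object + synthesized audio
write("mixed.wav", 44100, waveform.astype(np.float32))   # fluidsynth renders at 44100 Hz by default
show_piano_roll(music, fs=100)                           # squashed 128 x N piano roll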
playground.py
ADDED
@@ -0,0 +1,195 @@
import matplotlib.pyplot as plt
import gradio as gr
from load import LoadModel
from generate import GenerateMidiText
from constants import INSTRUMENT_CLASSES, INSTRUMENT_TRANSFER_CLASSES
from decoder import TextDecoder
from utils import get_miditok, index_has_substring
from playback import get_music
from matplotlib import pylab
import sys
import matplotlib
from generation_utils import plot_piano_roll
import numpy as np

matplotlib.use("Agg")

sys.modules["pylab"] = pylab

model_repo = "JammyMachina/elec-gmusic-familized-model-13-12__17-35-53"
n_bar_generated = 8
# model_repo = "JammyMachina/improved_4bars-mdl"
# n_bar_generated = 4

model, tokenizer = LoadModel(
    model_repo,
    from_huggingface=True,
).load_model_and_tokenizer()

miditok = get_miditok()
decoder = TextDecoder(miditok)


def define_prompt(state, genesis):
    if len(state) == 0:
        input_prompt = "PIECE_START "
    else:
        input_prompt = genesis.get_whole_piece_from_bar_dict()
    return input_prompt


def generator(
    label,
    regenerate,
    temp,
    density,
    instrument,
    state,
    piece_by_track,
    add_bars=False,
    add_bar_count=1,
):

    genesis = GenerateMidiText(model, tokenizer, piece_by_track)
    track = {"label": label}
    inst = next(
        (
            inst
            for inst in INSTRUMENT_TRANSFER_CLASSES
            if inst["transfer_to"] == instrument
        ),
        {"family_number": "DRUMS"},
    )["family_number"]

    inst_index = -1  # default to last generated
    if state != []:
        for index, instrum in enumerate(state):
            if instrum["label"] == track["label"]:
                inst_index = index  # changing if it exists

    # Generate
    if not add_bars:
        # Regenerate
        if regenerate:
            state.pop(inst_index)
            genesis.delete_one_track(inst_index)

            generated_text = (
                genesis.get_whole_piece_from_bar_dict()
            )  # maybe not useful here
            inst_index = -1  # reset to last generated

        # NEW TRACK
        input_prompt = define_prompt(state, genesis)
        generated_text = genesis.generate_one_new_track(
            inst, density, temp, input_prompt=input_prompt
        )

        regenerate = True  # set regenerate to true
    else:
        # NEW BARS
        genesis.generate_n_more_bars(add_bar_count)  # for all instruments
        generated_text = genesis.get_whole_piece_from_bar_dict()

    decoder.get_midi(generated_text, "mixed.mid")
    mixed_inst_midi, mixed_audio = get_music("mixed.mid")

    inst_text = genesis.get_selected_track_as_text(inst_index)
    inst_midi_name = f"{instrument}.mid"
    decoder.get_midi(inst_text, inst_midi_name)
    _, inst_audio = get_music(inst_midi_name)
    piano_roll = plot_piano_roll(mixed_inst_midi)
    track["text"] = inst_text
    state.append(track)

    return (
        inst_text,
        (44100, inst_audio),
        piano_roll,
        state,
        (44100, mixed_audio),
        regenerate,
        genesis.piece_by_track,
    )


def instrument_row(default_inst, row_id):
    with gr.Row():
        row = gr.Variable(row_id)
        with gr.Column(scale=1, min_width=100):
            inst = gr.Dropdown(
                sorted([inst["transfer_to"] for inst in INSTRUMENT_TRANSFER_CLASSES])
                + ["Drums"],
                value=default_inst,
                label="Instrument",
            )
            temp = gr.Dropdown(
                [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1],
                value=0.7,
                label="Creativity",
            )
            density = gr.Dropdown([1, 2, 3], value=3, label="Note Density")

        with gr.Column(scale=3):
            output_txt = gr.Textbox(
                label="output", lines=10, max_lines=10, show_label=False
            )
        with gr.Column(scale=1, min_width=100):
            inst_audio = gr.Audio(label="TRACK Audio", show_label=True)
            regenerate = gr.Checkbox(value=False, label="Regenerate", visible=False)
            # add_bars = gr.Checkbox(value=False, label="Add Bars")
            # add_bar_count = gr.Dropdown([1, 2, 4, 8], value=1, label="Add Bars")
            gen_btn = gr.Button("Generate")
            gen_btn.click(
                fn=generator,
                inputs=[row, regenerate, temp, density, inst, state, piece_by_track],
                outputs=[
                    output_txt,
                    inst_audio,
                    piano_roll,
                    state,
                    mixed_audio,
                    regenerate,
                    piece_by_track,
                ],
            )


with gr.Blocks() as demo:
    piece_by_track = gr.State([])
    state = gr.State([])
    title = gr.Markdown(
        """ # Demo-App of The-Jam-Machine
        A Generative AI trained on text transcription of MIDI music """
    )
    track1_md = gr.Markdown(""" ## Mixed Audio and Piano Roll """)
    mixed_audio = gr.Audio(label="Mixed Audio")
    piano_roll = gr.Plot(label="Piano Roll", show_label=False)
    description = gr.Markdown(
        """
        For each **TRACK**, choose your **instrument** along with **creativity** (temperature) and **note density**. Then, hit the **Generate** Button!
        You can have a look at the generated text; but most importantly, check the **piano roll** and listen to the TRACK audio!
        If you don't like the track, hit the Generate button to regenerate it! Generate more tracks and listen to the **mixed audio**!
        """
    )
    track1_md = gr.Markdown(""" ## TRACK 1 """)
    instrument_row("Drums", 0)
    track1_md = gr.Markdown(""" ## TRACK 2 """)
    instrument_row("Synth Bass 1", 1)
    track1_md = gr.Markdown(""" ## TRACK 3 """)
    instrument_row("Synth Lead Square", 2)
    # instrument_row("Piano")

demo.launch(debug=True)

"""
TODO: reset button
TODO: add a button to save the generated midi
TODO: add improvise button
TODO: set values for temperature as it is done for density
TODO: "Add bars" should now be set for the whole piece - regenerate should regenerate only the added bars, on all instruments
TODO: row height to fix

TODO: reset state of tick boxes after use, maybe (regenerate, add bars)
TODO: block regenerate if add bars is on
"""
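Because the Gradio wiring makes the state flow implicit, here is a hedged, headless sketch of a single call to generator (assuming the module-level model and tokenizer above loaded successfully). state and piece_by_track start empty and the returned values are threaded back in for each subsequent track, exactly as the UI does through gr.State.

# Headless sketch of one "Generate" click for the Drums row.
state, piece_by_track = [], []
(inst_text, track_audio, roll_fig, state,
 mixed, regenerate_flag, piece_by_track) = generator(
    label=0, regenerate=False, temp=0.7, density=3,
    instrument="Drums", state=state, piece_by_track=piece_by_track,
)
# track_audio and mixed are (44100, waveform) tuples, ready for gr.Audio.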
requirements.txt
ADDED
@@ -0,0 +1,17 @@
gradio
matplotlib
numpy
joblib
pathlib
transformers
miditok == 1.3.2
librosa
pretty_midi
pydub
scipy
datetime
torch
torchvision
pyFluidSynth
accelerate
utils.py
ADDED
@@ -0,0 +1,246 @@
from datetime import datetime
from miditok import Event, MIDILike
import os
import json
from time import perf_counter
from joblib import Parallel, delayed
from zipfile import ZipFile, ZIP_DEFLATED
from scipy.io.wavfile import write
import numpy as np
from pydub import AudioSegment
import shutil


def writeToFile(path, content):
    if type(content) is dict:
        with open(f"{path}", "w") as json_file:
            json.dump(content, json_file)
    else:
        if type(content) is not str:
            content = str(content)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, "w") as f:
            f.write(content)


# Function to read text from a txt file:
def readFromFile(path, isJSON=False):
    with open(path, "r") as f:
        if isJSON:
            return json.load(f)
        else:
            return f.read()


def chain(input, funcs, *params):
    res = input
    for func in funcs:
        try:
            res = func(res, *params)
        except TypeError:
            res = func(res)
    return res


def to_beat_str(value, beat_res=8):
    values = [
        int(int(value * beat_res) / beat_res),
        int(int(value * beat_res) % beat_res),
        beat_res,
    ]
    return ".".join(map(str, values))


def to_base10(beat_str):
    integer, decimal, base = split_dots(beat_str)
    return integer + decimal / base


def split_dots(value):
    return list(map(int, value.split(".")))


def compute_list_average(l):
    return sum(l) / len(l)


def get_datetime():
    return datetime.now().strftime("%Y%m%d_%H%M%S")


def get_text(event):
    match event.type:
        case "Piece-Start":
            return "PIECE_START "
        case "Track-Start":
            return "TRACK_START "
        case "Track-End":
            return "TRACK_END "
        case "Instrument":
            return f"INST={event.value} "
        case "Bar-Start":
            return "BAR_START "
        case "Bar-End":
            return "BAR_END "
        case "Time-Shift":
            return f"TIME_SHIFT={event.value} "
        case "Note-On":
            return f"NOTE_ON={event.value} "
        case "Note-Off":
            return f"NOTE_OFF={event.value} "
        case _:
            return ""


def get_event(text, value=None):
    match text:
        case "PIECE_START":
            return Event("Piece-Start", value)
        case "TRACK_START":
            return None
        case "TRACK_END":
            return None
        case "INST":
            return Event("Instrument", value)
        case "BAR_START":
            return Event("Bar-Start", value)
        case "BAR_END":
            return Event("Bar-End", value)
        case "TIME_SHIFT":
            return Event("Time-Shift", value)
        case "TIME_DELTA":
            return Event("Time-Shift", to_beat_str(int(value) / 4))
        case "NOTE_ON":
            return Event("Note-On", value)
        case "NOTE_OFF":
            return Event("Note-Off", value)
        case _:
            return None


# TODO: Make this a singleton
def get_miditok():
    pitch_range = range(0, 140)  # was (21, 109)
    beat_res = {(0, 400): 8}
    return MIDILike(pitch_range, beat_res)


class WriteTextMidiToFile:  # utils for saving to file
    def __init__(self, generate_midi, output_path):
        self.generated_midi = generate_midi.generated_piece
        self.output_path = output_path
        self.hyperparameter_and_bars = generate_midi.piece_by_track

    def hashing_seq(self):
        self.current_time = get_datetime()
        self.output_path_filename = f"{self.output_path}/{self.current_time}.json"

    def wrapping_seq_hyperparameters_in_dict(self):
        # assert type(self.generated_midi) is str, "error: generate_midi must be a string"
        # assert (
        #     type(self.hyperparameter_dict) is dict
        # ), "error: feature_dict must be a dictionary"
        return {
            "generate_midi": self.generated_midi,
            "hyperparameters_and_bars": self.hyperparameter_and_bars,
        }

    def text_midi_to_file(self):
        self.hashing_seq()
        output_dict = self.wrapping_seq_hyperparameters_in_dict()
        print(f"Token generate_midi written: {self.output_path_filename}")
        writeToFile(self.output_path_filename, output_dict)
        return self.output_path_filename


def get_files(directory, extension, recursive=False):
    """
    Given a directory, get a list of the file paths of all files matching the
    specified file extension.
    directory: the directory to search as a Path object
    extension: the file extension to match as a string
    recursive: whether to search recursively in the directory or not
    """
    if recursive:
        return list(directory.rglob(f"*.{extension}"))
    else:
        return list(directory.glob(f"*.{extension}"))


def timeit(func):
    def wrapper(*args, **kwargs):
        start = perf_counter()
        result = func(*args, **kwargs)
        end = perf_counter()
        print(f"{func.__name__} took {end - start:.2f} seconds to run.")
        return result

    return wrapper


class FileCompressor:
    def __init__(self, input_directory, output_directory, n_jobs=-1):
        self.input_directory = input_directory
        self.output_directory = output_directory
        self.n_jobs = n_jobs

    # File compression and decompression
    def unzip_file(self, file):
        """uncompress a single zip file"""
        with ZipFile(file, "r") as zip_ref:
            zip_ref.extractall(self.output_directory)

    def zip_file(self, file):
        """compress a single text file to a new zip file and delete the original"""
        output_file = self.output_directory / (file.stem + ".zip")
        with ZipFile(output_file, "w") as zip_ref:
            zip_ref.write(file, arcname=file.name, compress_type=ZIP_DEFLATED)
        file.unlink()

    @timeit
    def unzip(self):
        """uncompress all zip files in the folder"""
        files = get_files(self.input_directory, extension="zip")
        Parallel(n_jobs=self.n_jobs)(delayed(self.unzip_file)(file) for file in files)

    @timeit
    def zip(self):
        """compress all text files in the folder to new zip files and remove the text files"""
        files = get_files(self.output_directory, extension="txt")
        Parallel(n_jobs=self.n_jobs)(delayed(self.zip_file)(file) for file in files)


def load_jsonl(filepath):
    """Load a jsonl file"""
    with open(filepath, "r") as f:
        data = [json.loads(line) for line in f]
    return data


def write_mp3(waveform, output_path, bitrate="92k"):
    """
    Write a waveform to an mp3 file.
    output_path: Path object for the output mp3 file
    waveform: numpy array of the waveform
    bitrate: bitrate of the mp3 file (64k, 92k, 128k, 256k, 312k)
    """
    # write the wav file
    wav_path = output_path.with_suffix(".wav")
    write(wav_path, 44100, waveform.astype(np.float32))
    # compress the wav file as mp3
    AudioSegment.from_wav(wav_path).export(output_path, format="mp3", bitrate=bitrate)
    # remove the wav file
    wav_path.unlink()


def copy_file(input_file, output_dir):
    """Copy an input file to the output_dir"""
    output_file = output_dir / input_file.name
    shutil.copy(input_file, output_file)


def index_has_substring(list, substring):
    for i, s in enumerate(list):
        if substring in s:
            return i
    return -1
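To close, a hedged sketch of the beat-string helpers and write_mp3 from utils.py. The file paths are hypothetical; write_mp3 expects a pathlib.Path (it calls .with_suffix and .unlink) and relies on ffmpeg through pydub.

from pathlib import Path
from playback import get_music
from utils import to_beat_str, to_base10, write_mp3

print(to_beat_str(1.25))   # -> "1.2.8": one whole beat plus 2/8 of a beat
print(to_base10("1.2.8"))  # -> 1.25, the inverse mapping
_, waveform = get_music("mixed.mid")                    # assumed example MIDI file
write_mp3(waveform, Path("mixed.mp3"), bitrate="92k")   # writes a temporary wav, then an mp3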