tspsram commited on
Commit
35410df
1 Parent(s): 08ab31c

Upload 19 files

Browse files
.gitattributes CHANGED
@@ -1,35 +1,36 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.mid filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
105
+ __pypackages__/
106
+
107
+ # Celery stuff
108
+ celerybeat-schedule
109
+ celerybeat.pid
110
+
111
+ # SageMath parsed files
112
+ *.sage.py
113
+
114
+ # Environments
115
+ .env
116
+ .venv
117
+ env/
118
+ venv/
119
+ ENV/
120
+ env.bak/
121
+ venv.bak/
122
+
123
+ # Spyder project settings
124
+ .spyderproject
125
+ .spyproject
126
+
127
+ # Rope project settings
128
+ .ropeproject
129
+
130
+ # mkdocs documentation
131
+ /site
132
+
133
+ # mypy
134
+ .mypy_cache/
135
+ .dmypy.json
136
+ dmypy.json
137
+
138
+ # Pyre type checker
139
+ .pyre/
140
+
141
+ # pytype static type analyzer
142
+ .pytype/
143
+
144
+ # Cython debug symbols
145
+ cython_debug/
146
+
147
+ # PyCharm
148
+ # JetBrains specific template is maintainted in a separate JetBrains.gitignore that can
149
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
150
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
151
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
152
+ .idea/
153
+ output.mid
154
+ /outputs/
Dockerfile ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvidia/cuda:11.6.1-cudnn8-devel-ubuntu20.04
2
+
3
+ ARG DEBIAN_FRONTEND=noninteractive
4
+
5
+ ENV PYTHONUNBUFFERED=1
6
+
7
+ RUN apt-get update && apt-get install --no-install-recommends -y \
8
+ build-essential \
9
+ python3.9 \
10
+ python3-pip \
11
+ git \
12
+ ffmpeg \
13
+ fluidsynth \
14
+ && apt-get clean && rm -rf /var/lib/apt/lists/*
15
+
16
+ WORKDIR /code
17
+
18
+ COPY ./requirements.txt /code/requirements.txt
19
+
20
+ # Set up a new user named "user" with user ID 1000
21
+ RUN useradd -m -u 1000 user
22
+ # Switch to the "user" user
23
+ USER user
24
+ # Set home to the user's home directory
25
+ ENV HOME=/home/user \
26
+ PATH=/home/user/.local/bin:$PATH \
27
+ PYTHONPATH=$HOME/app \
28
+ PYTHONUNBUFFERED=1 \
29
+ GRADIO_ALLOW_FLAGGING=never \
30
+ GRADIO_NUM_PORTS=1 \
31
+ GRADIO_SERVER_NAME=0.0.0.0 \
32
+ GRADIO_THEME=huggingface \
33
+ SYSTEM=spaces
34
+
35
+ RUN pip3 install --no-cache-dir --upgrade -r /code/requirements.txt
36
+
37
+ # Set the working directory to the user's home directory
38
+ WORKDIR $HOME/app
39
+
40
+ # Copy the current directory contents into the container at $HOME/app setting the owner to the user
41
+ COPY --chown=user . $HOME/app
42
+
43
+ CMD ["python3", "app.py"]
MIDI.py ADDED
@@ -0,0 +1,1735 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #! /usr/bin/python3
2
+ # unsupported 20091104 ...
3
+ # ['set_sequence_number', dtime, sequence]
4
+ # ['raw_data', dtime, raw]
5
+
6
+ # 20150914 jimbo1qaz MIDI.py str/bytes bug report
7
+ # I found a MIDI file which had Shift-JIS titles. When midi.py decodes it as
8
+ # latin-1, it produces a string which cannot even be accessed without raising
9
+ # a UnicodeDecodeError. Maybe, when converting raw byte strings from MIDI,
10
+ # you should keep them as bytes, not improperly decode them. However, this
11
+ # would change the API. (ie: text = a "string" ? of 0 or more bytes). It
12
+ # could break compatiblity, but there's not much else you can do to fix the bug
13
+ # https://en.wikipedia.org/wiki/Shift_JIS
14
+
15
+ r'''
16
+ This module offers functions: concatenate_scores(), grep(),
17
+ merge_scores(), mix_scores(), midi2opus(), midi2score(), opus2midi(),
18
+ opus2score(), play_score(), score2midi(), score2opus(), score2stats(),
19
+ score_type(), segment(), timeshift() and to_millisecs(),
20
+ where "midi" means the MIDI-file bytes (as can be put in a .mid file,
21
+ or piped into aplaymidi), and "opus" and "score" are list-structures
22
+ as inspired by Sean Burke's MIDI-Perl CPAN module.
23
+
24
+ Warning: Version 6.4 is not necessarily backward-compatible with
25
+ previous versions, in that text-data is now bytes, not strings.
26
+ This reflects the fact that many MIDI files have text data in
27
+ encodings other that ISO-8859-1, for example in Shift-JIS.
28
+
29
+ Download MIDI.py from http://www.pjb.com.au/midi/free/MIDI.py
30
+ and put it in your PYTHONPATH. MIDI.py depends on Python3.
31
+
32
+ There is also a call-compatible translation into Lua of this
33
+ module: see http://www.pjb.com.au/comp/lua/MIDI.html
34
+
35
+ The "opus" is a direct translation of the midi-file-events, where
36
+ the times are delta-times, in ticks, since the previous event.
37
+
38
+ The "score" is more human-centric; it uses absolute times, and
39
+ combines the separate note_on and note_off events into one "note"
40
+ event, with a duration:
41
+ ['note', start_time, duration, channel, note, velocity] # in a "score"
42
+
43
+ EVENTS (in an "opus" structure)
44
+ ['note_off', dtime, channel, note, velocity] # in an "opus"
45
+ ['note_on', dtime, channel, note, velocity] # in an "opus"
46
+ ['key_after_touch', dtime, channel, note, velocity]
47
+ ['control_change', dtime, channel, controller(0-127), value(0-127)]
48
+ ['patch_change', dtime, channel, patch]
49
+ ['channel_after_touch', dtime, channel, velocity]
50
+ ['pitch_wheel_change', dtime, channel, pitch_wheel]
51
+ ['text_event', dtime, text]
52
+ ['copyright_text_event', dtime, text]
53
+ ['track_name', dtime, text]
54
+ ['instrument_name', dtime, text]
55
+ ['lyric', dtime, text]
56
+ ['marker', dtime, text]
57
+ ['cue_point', dtime, text]
58
+ ['text_event_08', dtime, text]
59
+ ['text_event_09', dtime, text]
60
+ ['text_event_0a', dtime, text]
61
+ ['text_event_0b', dtime, text]
62
+ ['text_event_0c', dtime, text]
63
+ ['text_event_0d', dtime, text]
64
+ ['text_event_0e', dtime, text]
65
+ ['text_event_0f', dtime, text]
66
+ ['end_track', dtime]
67
+ ['set_tempo', dtime, tempo]
68
+ ['smpte_offset', dtime, hr, mn, se, fr, ff]
69
+ ['time_signature', dtime, nn, dd, cc, bb]
70
+ ['key_signature', dtime, sf, mi]
71
+ ['sequencer_specific', dtime, raw]
72
+ ['raw_meta_event', dtime, command(0-255), raw]
73
+ ['sysex_f0', dtime, raw]
74
+ ['sysex_f7', dtime, raw]
75
+ ['song_position', dtime, song_pos]
76
+ ['song_select', dtime, song_number]
77
+ ['tune_request', dtime]
78
+
79
+ DATA TYPES
80
+ channel = a value 0 to 15
81
+ controller = 0 to 127 (see http://www.pjb.com.au/muscript/gm.html#cc )
82
+ dtime = time measured in "ticks", 0 to 268435455
83
+ velocity = a value 0 (soft) to 127 (loud)
84
+ note = a value 0 to 127 (middle-C is 60)
85
+ patch = 0 to 127 (see http://www.pjb.com.au/muscript/gm.html )
86
+ pitch_wheel = a value -8192 to 8191 (0x1FFF)
87
+ raw = bytes, of length 0 or more (for sysex events see below)
88
+ sequence_number = a value 0 to 65,535 (0xFFFF)
89
+ song_pos = a value 0 to 16,383 (0x3FFF)
90
+ song_number = a value 0 to 127
91
+ tempo = microseconds per crochet (quarter-note), 0 to 16777215
92
+ text = bytes, of length 0 or more
93
+ ticks = the number of ticks per crochet (quarter-note)
94
+
95
+ In sysex_f0 events, the raw data must not start with a \xF0 byte,
96
+ since this gets added automatically;
97
+ but it must end with an explicit \xF7 byte!
98
+ In the very unlikely case that you ever need to split sysex data
99
+ into one sysex_f0 followed by one or more sysex_f7s, then only the
100
+ last of those sysex_f7 events must end with the explicit \xF7 byte
101
+ (again, the raw data of individual sysex_f7 events must not start
102
+ with any \xF7 byte, since this gets added automatically).
103
+
104
+ Since version 6.4, text data is in bytes, not in a ISO-8859-1 string.
105
+
106
+
107
+ GOING THROUGH A SCORE WITHIN A PYTHON PROGRAM
108
+ channels = {2,3,5,8,13}
109
+ itrack = 1 # skip 1st element which is ticks
110
+ while itrack < len(score):
111
+ for event in score[itrack]:
112
+ if event[0] == 'note': # for example,
113
+ pass # do something to all notes
114
+ # or, to work on events in only particular channels...
115
+ channel_index = MIDI.Event2channelindex.get(event[0], False)
116
+ if channel_index and (event[channel_index] in channels):
117
+ pass # do something to channels 2,3,5,8 and 13
118
+ itrack += 1
119
+
120
+ '''
121
+
122
+ import sys, struct, copy
123
+ # sys.stdout = os.fdopen(sys.stdout.fileno(), 'wb')
124
+ Version = '6.7'
125
+ VersionDate = '20201120'
126
+ # 20201120 6.7 call to bytest() removed, and protect _unshift_ber_int
127
+ # 20160702 6.6 to_millisecs() now handles set_tempo across multiple Tracks
128
+ # 20150921 6.5 segment restores controllers as well as patch and tempo
129
+ # 20150914 6.4 text data is bytes or bytearray, not ISO-8859-1 strings
130
+ # 20150628 6.3 absent any set_tempo, default is 120bpm (see MIDI file spec 1.1)
131
+ # 20150101 6.2 all text events can be 8-bit; let user get the right encoding
132
+ # 20141231 6.1 fix _some_text_event; sequencer_specific data can be 8-bit
133
+ # 20141230 6.0 synth_specific data can be 8-bit
134
+ # 20120504 5.9 add the contents of mid_opus_tracks()
135
+ # 20120208 5.8 fix num_notes_by_channel() ; should be a dict
136
+ # 20120129 5.7 _encode handles empty tracks; score2stats num_notes_by_channel
137
+ # 20111111 5.6 fix patch 45 and 46 in Number2patch, should be Harp
138
+ # 20110129 5.5 add mix_opus_tracks() and event2alsaseq()
139
+ # 20110126 5.4 "previous message repeated N times" to save space on stderr
140
+ # 20110125 5.2 opus2score terminates unended notes at the end of the track
141
+ # 20110124 5.1 the warnings in midi2opus display track_num
142
+ # 21110122 5.0 if garbage, midi2opus returns the opus so far
143
+ # 21110119 4.9 non-ascii chars stripped out of the text_events
144
+ # 21110110 4.8 note_on with velocity=0 treated as a note-off
145
+ # 21110108 4.6 unknown F-series event correctly eats just one byte
146
+ # 21011010 4.2 segment() uses start_time, end_time named params
147
+ # 21011005 4.1 timeshift() must not pad the set_tempo command
148
+ # 21011003 4.0 pitch2note_event must be chapitch2note_event
149
+ # 21010918 3.9 set_sequence_number supported, FWIW
150
+ # 20100913 3.7 many small bugfixes; passes all tests
151
+ # 20100910 3.6 concatenate_scores enforce ticks=1000, just like merge_scores
152
+ # 20100908 3.5 minor bugs fixed in score2stats
153
+ # 20091104 3.4 tune_request now supported
154
+ # 20091104 3.3 fixed bug in decoding song_position and song_select
155
+ # 20091104 3.2 unsupported: set_sequence_number tune_request raw_data
156
+ # 20091101 3.1 document how to traverse a score within Python
157
+ # 20091021 3.0 fixed bug in score2stats detecting GM-mode = 0
158
+ # 20091020 2.9 score2stats reports GM-mode and bank msb,lsb events
159
+ # 20091019 2.8 in merge_scores, channel 9 must remain channel 9 (in GM)
160
+ # 20091018 2.7 handles empty tracks gracefully
161
+ # 20091015 2.6 grep() selects channels
162
+ # 20091010 2.5 merge_scores reassigns channels to avoid conflicts
163
+ # 20091010 2.4 fixed bug in to_millisecs which now only does opusses
164
+ # 20091010 2.3 score2stats returns channels & patch_changes, by_track & total
165
+ # 20091010 2.2 score2stats() returns also pitches and percussion dicts
166
+ # 20091010 2.1 bugs: >= not > in segment, to notice patch_change at time 0
167
+ # 20091010 2.0 bugs: spurious pop(0) ( in _decode sysex
168
+ # 20091008 1.9 bugs: ISO decoding in sysex; str( not int( in note-off warning
169
+ # 20091008 1.8 add concatenate_scores()
170
+ # 20091006 1.7 score2stats() measures nticks and ticks_per_quarter
171
+ # 20091004 1.6 first mix_scores() and merge_scores()
172
+ # 20090424 1.5 timeshift() bugfix: earliest only sees events after from_time
173
+ # 20090330 1.4 timeshift() has also a from_time argument
174
+ # 20090322 1.3 timeshift() has also a start_time argument
175
+ # 20090319 1.2 add segment() and timeshift()
176
+ # 20090301 1.1 add to_millisecs()
177
+
178
+ _previous_warning = '' # 5.4
179
+ _previous_times = 0 # 5.4
180
+ _no_warning = True
181
+ #------------------------------- Encoding stuff --------------------------
182
+
183
+ def opus2midi(opus=[]):
184
+ r'''The argument is a list: the first item in the list is the "ticks"
185
+ parameter, the others are the tracks. Each track is a list
186
+ of midi-events, and each event is itself a list; see above.
187
+ opus2midi() returns a bytestring of the MIDI, which can then be
188
+ written either to a file opened in binary mode (mode='wb'),
189
+ or to stdout by means of: sys.stdout.buffer.write()
190
+
191
+ my_opus = [
192
+ 96,
193
+ [ # track 0:
194
+ ['patch_change', 0, 1, 8], # and these are the events...
195
+ ['note_on', 5, 1, 25, 96],
196
+ ['note_off', 96, 1, 25, 0],
197
+ ['note_on', 0, 1, 29, 96],
198
+ ['note_off', 96, 1, 29, 0],
199
+ ], # end of track 0
200
+ ]
201
+ my_midi = opus2midi(my_opus)
202
+ sys.stdout.buffer.write(my_midi)
203
+ '''
204
+ if len(opus) < 2:
205
+ opus=[1000, [],]
206
+ tracks = copy.deepcopy(opus)
207
+ ticks = int(tracks.pop(0))
208
+ ntracks = len(tracks)
209
+ if ntracks == 1:
210
+ format = 0
211
+ else:
212
+ format = 1
213
+
214
+ my_midi = b"MThd\x00\x00\x00\x06"+struct.pack('>HHH',format,ntracks,ticks)
215
+ for track in tracks:
216
+ events = _encode(track)
217
+ my_midi += b'MTrk' + struct.pack('>I',len(events)) + events
218
+ _clean_up_warnings()
219
+ return my_midi
220
+
221
+
222
+ def score2opus(score=None):
223
+ r'''
224
+ The argument is a list: the first item in the list is the "ticks"
225
+ parameter, the others are the tracks. Each track is a list
226
+ of score-events, and each event is itself a list. A score-event
227
+ is similar to an opus-event (see above), except that in a score:
228
+ 1) the times are expressed as an absolute number of ticks
229
+ from the track's start time
230
+ 2) the pairs of 'note_on' and 'note_off' events in an "opus"
231
+ are abstracted into a single 'note' event in a "score":
232
+ ['note', start_time, duration, channel, pitch, velocity]
233
+ score2opus() returns a list specifying the equivalent "opus".
234
+
235
+ my_score = [
236
+ 96,
237
+ [ # track 0:
238
+ ['patch_change', 0, 1, 8],
239
+ ['note', 5, 96, 1, 25, 96],
240
+ ['note', 101, 96, 1, 29, 96]
241
+ ], # end of track 0
242
+ ]
243
+ my_opus = score2opus(my_score)
244
+ '''
245
+ if len(score) < 2:
246
+ score=[1000, [],]
247
+ tracks = copy.deepcopy(score)
248
+ ticks = int(tracks.pop(0))
249
+ opus_tracks = []
250
+ for scoretrack in tracks:
251
+ time2events = dict([])
252
+ for scoreevent in scoretrack:
253
+ if scoreevent[0] == 'note':
254
+ note_on_event = ['note_on',scoreevent[1],
255
+ scoreevent[3],scoreevent[4],scoreevent[5]]
256
+ note_off_event = ['note_off',scoreevent[1]+scoreevent[2],
257
+ scoreevent[3],scoreevent[4],scoreevent[5]]
258
+ if time2events.get(note_on_event[1]):
259
+ time2events[note_on_event[1]].append(note_on_event)
260
+ else:
261
+ time2events[note_on_event[1]] = [note_on_event,]
262
+ if time2events.get(note_off_event[1]):
263
+ time2events[note_off_event[1]].append(note_off_event)
264
+ else:
265
+ time2events[note_off_event[1]] = [note_off_event,]
266
+ continue
267
+ if time2events.get(scoreevent[1]):
268
+ time2events[scoreevent[1]].append(scoreevent)
269
+ else:
270
+ time2events[scoreevent[1]] = [scoreevent,]
271
+
272
+ sorted_times = [] # list of keys
273
+ for k in time2events.keys():
274
+ sorted_times.append(k)
275
+ sorted_times.sort()
276
+
277
+ sorted_events = [] # once-flattened list of values sorted by key
278
+ for time in sorted_times:
279
+ sorted_events.extend(time2events[time])
280
+
281
+ abs_time = 0
282
+ for event in sorted_events: # convert abs times => delta times
283
+ delta_time = event[1] - abs_time
284
+ abs_time = event[1]
285
+ event[1] = delta_time
286
+ opus_tracks.append(sorted_events)
287
+ opus_tracks.insert(0,ticks)
288
+ _clean_up_warnings()
289
+ return opus_tracks
290
+
291
+ def score2midi(score=None):
292
+ r'''
293
+ Translates a "score" into MIDI, using score2opus() then opus2midi()
294
+ '''
295
+ return opus2midi(score2opus(score))
296
+
297
+ #--------------------------- Decoding stuff ------------------------
298
+
299
+ def midi2opus(midi=b''):
300
+ r'''Translates MIDI into a "opus". For a description of the
301
+ "opus" format, see opus2midi()
302
+ '''
303
+ my_midi=bytearray(midi)
304
+ if len(my_midi) < 4:
305
+ _clean_up_warnings()
306
+ return [1000,[],]
307
+ id = bytes(my_midi[0:4])
308
+ if id != b'MThd':
309
+ _warn("midi2opus: midi starts with "+str(id)+" instead of 'MThd'")
310
+ _clean_up_warnings()
311
+ return [1000,[],]
312
+ [length, format, tracks_expected, ticks] = struct.unpack(
313
+ '>IHHH', bytes(my_midi[4:14]))
314
+ if length != 6:
315
+ _warn("midi2opus: midi header length was "+str(length)+" instead of 6")
316
+ _clean_up_warnings()
317
+ return [1000,[],]
318
+ my_opus = [ticks,]
319
+ my_midi = my_midi[14:]
320
+ track_num = 1 # 5.1
321
+ while len(my_midi) >= 8:
322
+ track_type = bytes(my_midi[0:4])
323
+ if track_type != b'MTrk':
324
+ _warn('midi2opus: Warning: track #'+str(track_num)+' type is '+str(track_type)+" instead of b'MTrk'")
325
+ [track_length] = struct.unpack('>I', my_midi[4:8])
326
+ my_midi = my_midi[8:]
327
+ if track_length > len(my_midi):
328
+ _warn('midi2opus: track #'+str(track_num)+' length '+str(track_length)+' is too large')
329
+ _clean_up_warnings()
330
+ return my_opus # 5.0
331
+ my_midi_track = my_midi[0:track_length]
332
+ my_track = _decode(my_midi_track)
333
+ my_opus.append(my_track)
334
+ my_midi = my_midi[track_length:]
335
+ track_num += 1 # 5.1
336
+ _clean_up_warnings()
337
+ return my_opus
338
+
339
+ def opus2score(opus=[]):
340
+ r'''For a description of the "opus" and "score" formats,
341
+ see opus2midi() and score2opus().
342
+ '''
343
+ if len(opus) < 2:
344
+ _clean_up_warnings()
345
+ return [1000,[],]
346
+ tracks = copy.deepcopy(opus) # couple of slices probably quicker...
347
+ ticks = int(tracks.pop(0))
348
+ score = [ticks,]
349
+ for opus_track in tracks:
350
+ ticks_so_far = 0
351
+ score_track = []
352
+ chapitch2note_on_events = dict([]) # 4.0
353
+ for opus_event in opus_track:
354
+ ticks_so_far += opus_event[1]
355
+ if opus_event[0] == 'note_off' or (opus_event[0] == 'note_on' and opus_event[4] == 0): # 4.8
356
+ cha = opus_event[2]
357
+ pitch = opus_event[3]
358
+ key = cha*128 + pitch
359
+ if chapitch2note_on_events.get(key):
360
+ new_event = chapitch2note_on_events[key].pop(0)
361
+ new_event[2] = ticks_so_far - new_event[1]
362
+ score_track.append(new_event)
363
+ elif pitch > 127:
364
+ pass #_warn('opus2score: note_off with no note_on, bad pitch='+str(pitch))
365
+ else:
366
+ pass #_warn('opus2score: note_off with no note_on cha='+str(cha)+' pitch='+str(pitch))
367
+ elif opus_event[0] == 'note_on':
368
+ cha = opus_event[2]
369
+ pitch = opus_event[3]
370
+ key = cha*128 + pitch
371
+ new_event = ['note',ticks_so_far,0,cha,pitch, opus_event[4]]
372
+ if chapitch2note_on_events.get(key):
373
+ chapitch2note_on_events[key].append(new_event)
374
+ else:
375
+ chapitch2note_on_events[key] = [new_event,]
376
+ else:
377
+ opus_event[1] = ticks_so_far
378
+ score_track.append(opus_event)
379
+ # check for unterminated notes (Oisín) -- 5.2
380
+ for chapitch in chapitch2note_on_events:
381
+ note_on_events = chapitch2note_on_events[chapitch]
382
+ for new_e in note_on_events:
383
+ new_e[2] = ticks_so_far - new_e[1]
384
+ score_track.append(new_e)
385
+ pass #_warn("opus2score: note_on with no note_off cha="+str(new_e[3])+' pitch='+str(new_e[4])+'; adding note_off at end')
386
+ score.append(score_track)
387
+ _clean_up_warnings()
388
+ return score
389
+
390
+ def midi2score(midi=b''):
391
+ r'''
392
+ Translates MIDI into a "score", using midi2opus() then opus2score()
393
+ '''
394
+ return opus2score(midi2opus(midi))
395
+
396
+ def midi2ms_score(midi=b''):
397
+ r'''
398
+ Translates MIDI into a "score" with one beat per second and one
399
+ tick per millisecond, using midi2opus() then to_millisecs()
400
+ then opus2score()
401
+ '''
402
+ return opus2score(to_millisecs(midi2opus(midi)))
403
+
404
+ #------------------------ Other Transformations ---------------------
405
+
406
+ def to_millisecs(old_opus=None):
407
+ r'''Recallibrates all the times in an "opus" to use one beat
408
+ per second and one tick per millisecond. This makes it
409
+ hard to retrieve any information about beats or barlines,
410
+ but it does make it easy to mix different scores together.
411
+ '''
412
+ if old_opus == None:
413
+ return [1000,[],]
414
+ try:
415
+ old_tpq = int(old_opus[0])
416
+ except IndexError: # 5.0
417
+ _warn('to_millisecs: the opus '+str(type(old_opus))+' has no elements')
418
+ return [1000,[],]
419
+ new_opus = [1000,]
420
+ # 6.7 first go through building a table of set_tempos by absolute-tick
421
+ ticks2tempo = {}
422
+ itrack = 1
423
+ while itrack < len(old_opus):
424
+ ticks_so_far = 0
425
+ for old_event in old_opus[itrack]:
426
+ if old_event[0] == 'note':
427
+ raise TypeError('to_millisecs needs an opus, not a score')
428
+ ticks_so_far += old_event[1]
429
+ if old_event[0] == 'set_tempo':
430
+ ticks2tempo[ticks_so_far] = old_event[2]
431
+ itrack += 1
432
+ # then get the sorted-array of their keys
433
+ tempo_ticks = [] # list of keys
434
+ for k in ticks2tempo.keys():
435
+ tempo_ticks.append(k)
436
+ tempo_ticks.sort()
437
+ # then go through converting to millisec, testing if the next
438
+ # set_tempo lies before the next track-event, and using it if so.
439
+ itrack = 1
440
+ while itrack < len(old_opus):
441
+ ms_per_old_tick = 500.0 / old_tpq # float: will round later 6.3
442
+ i_tempo_ticks = 0
443
+ ticks_so_far = 0
444
+ ms_so_far = 0.0
445
+ previous_ms_so_far = 0.0
446
+ new_track = [['set_tempo',0,1000000],] # new "crochet" is 1 sec
447
+ for old_event in old_opus[itrack]:
448
+ # detect if ticks2tempo has something before this event
449
+ # 20160702 if ticks2tempo is at the same time, leave it
450
+ event_delta_ticks = old_event[1]
451
+ if (i_tempo_ticks < len(tempo_ticks) and
452
+ tempo_ticks[i_tempo_ticks] < (ticks_so_far + old_event[1])):
453
+ delta_ticks = tempo_ticks[i_tempo_ticks] - ticks_so_far
454
+ ms_so_far += (ms_per_old_tick * delta_ticks)
455
+ ticks_so_far = tempo_ticks[i_tempo_ticks]
456
+ ms_per_old_tick = ticks2tempo[ticks_so_far] / (1000.0*old_tpq)
457
+ i_tempo_ticks += 1
458
+ event_delta_ticks -= delta_ticks
459
+ new_event = copy.deepcopy(old_event) # now handle the new event
460
+ ms_so_far += (ms_per_old_tick * old_event[1])
461
+ new_event[1] = round(ms_so_far - previous_ms_so_far)
462
+ if old_event[0] != 'set_tempo':
463
+ previous_ms_so_far = ms_so_far
464
+ new_track.append(new_event)
465
+ ticks_so_far += event_delta_ticks
466
+ new_opus.append(new_track)
467
+ itrack += 1
468
+ _clean_up_warnings()
469
+ return new_opus
470
+
471
def event2alsaseq(event=None):   # 5.5
    r'''Convert one event to the representation used by the "alsaseq"
    module, http://pp.com.mx/python/alsaseq
    The type of track (opus or score) is meant to be autodetected.
    Currently unimplemented: always returns None.
    '''
    return None
477
+
478
def grep(score=None, channels=None):
    r'''Returns a "score" containing only the channels specified
    '''
    if score == None:
        return [1000,[],]
    new_score = [score[0],]    # preserve the ticks parameter
    if channels == None:
        return new_score
    wanted = set(channels)     # defend against tuples and lists
    global Event2channelindex
    for track in score[1:]:
        kept_events = []
        for event in track:
            # index of the channel field for this event type, or False
            # for channel-less (meta) events, which are always kept
            chan_at = Event2channelindex.get(event[0], False)
            if (not chan_at) or (event[chan_at] in wanted):
                kept_events.append(event)
        new_score.append(kept_events)
    return new_score
501
+
502
def play_score(score=None):
    r'''Converts the "score" to midi, and feeds it into 'aplaymidi -'
    '''
    if score == None:
        return
    import subprocess
    # stream the encoded MIDI straight into aplaymidi's stdin
    pipe = subprocess.Popen(['aplaymidi','-'], stdin=subprocess.PIPE)
    if score_type(score) == 'opus':
        midi_bytes = opus2midi(score)
    else:
        midi_bytes = score2midi(score)
    pipe.stdin.write(midi_bytes)
    pipe.stdin.close()
514
+
515
def timeshift(score=None, shift=None, start_time=None, from_time=0, tracks={0,1,2,3,4,5,6,7,8,10,12,13,14,15}):
    r'''Returns a "score" shifted in time by "shift" ticks, or shifted
    so that the first event starts at "start_time" ticks.

    If "from_time" is specified, only those events in the score
    that begin after it are shifted. If "start_time" is less than
    "from_time" (or "shift" is negative), then the intermediate
    notes are deleted, though patch-change events are preserved.

    If "tracks" are specified, then only those tracks get shifted.
    "tracks" can be a list, tuple or set; it gets converted to set
    internally.

    It is deprecated to specify both "shift" and "start_time".
    If this does happen, timeshift() will print a warning to
    stderr and ignore the "shift" argument.

    If "shift" is negative and sufficiently large that it would
    leave some event with a negative tick-value, then the score
    is shifted so that the first event occurs at time 0. This
    also occurs if "start_time" is negative, and is also the
    default if neither "shift" nor "start_time" are specified.
    '''
    # NOTE(review): the default "tracks" set omits channels... i.e. track
    # indices 9 and 11; omitting 11 looks like a typo (cf. segment(),
    # whose default includes 11) — confirm.  The mutable default set is
    # safe here because it is only read, then rebound below.
    #_warn('tracks='+str(tracks))
    if score == None or len(score) < 2:
        return [1000, [],]
    new_score = [score[0],]      # preserve the ticks parameter
    my_type = score_type(score)
    if my_type == '':
        return new_score
    if my_type == 'opus':
        _warn("timeshift: opus format is not supported\n")
        # _clean_up_scores() 6.2; doesn't exist! what was it supposed to do?
        return new_score
    # "shift" and "start_time" are mutually exclusive; start_time wins
    if not (shift == None) and not (start_time == None):
        _warn("timeshift: shift and start_time specified: ignoring shift\n")
        shift = None
    if shift == None:
        if (start_time == None) or (start_time < 0):
            start_time = 0
        # shift = start_time - from_time

    i = 1   # ignore first element (ticks)
    tracks = set(tracks)   # defend against tuples and lists
    earliest = 1000000000
    # First pass: find the earliest to-be-shifted event, needed either to
    # realise "start_time" or to clamp a too-negative "shift" at tick 0.
    if not (start_time == None) or shift < 0:   # first find the earliest event
        while i < len(score):
            if len(tracks) and not ((i-1) in tracks):
                i += 1
                continue
            for event in score[i]:
                if event[1] < from_time:
                    continue   # just inspect the to_be_shifted events
                if event[1] < earliest:
                    earliest = event[1]
            i += 1
        if earliest > 999999999:   # no events found: nothing to anchor on
            earliest = 0
        if shift == None:
            shift = start_time - earliest
        elif (earliest + shift) < 0:
            # clamp: never leave an event with a negative tick-value
            start_time = 0
            shift = 0 - earliest

    # Second pass: copy tracks, shifting the selected events
    i = 1   # ignore first element (ticks)
    while i < len(score):
        if len(tracks) == 0 or not ((i-1) in tracks):   # 3.8
            # unselected tracks are passed through by reference, unshifted
            new_score.append(score[i])
            i += 1
            continue
        new_track = []
        for event in score[i]:
            new_event = list(event)   # shallow copy so the input survives
            #if new_event[1] == 0 and shift > 0 and new_event[0] != 'note':
            #    pass
            #elif new_event[1] >= from_time:
            if new_event[1] >= from_time:
                # 4.1 must not rightshift set_tempo
                if new_event[0] != 'set_tempo' or shift<0:
                    new_event[1] += shift
            elif (shift < 0) and (new_event[1] >= (from_time+shift)):
                # event falls into the deleted gap: drop it
                continue
            new_track.append(new_event)
        if len(new_track) > 0:
            new_score.append(new_track)
        i += 1
    _clean_up_warnings()
    return new_score
603
+
604
def segment(score=None, start_time=None, end_time=None, start=0, end=100000000,
    tracks={0,1,2,3,4,5,6,7,8,10,11,12,13,14,15}):
    r'''Returns a "score" which is a segment of the one supplied
    as the argument, beginning at "start_time" ticks and ending
    at "end_time" ticks (or at the end if "end_time" is not supplied).
    If the set "tracks" is specified, only those tracks will
    be returned.
    '''
    # NOTE(review): the default "tracks" set omits 9 — presumably so the
    # GM percussion track is excluded by default; confirm.  The mutable
    # default is safe: it is only read, then rebound below.
    if score == None or len(score) < 2:
        return [1000, [],]
    if start_time == None:   # as of 4.2 start_time is recommended
        start_time = start   # start is legacy usage
    if end_time == None:     # likewise
        end_time = end
    new_score = [score[0],]  # preserve the ticks parameter
    my_type = score_type(score)
    if my_type == '':
        return new_score
    if my_type == 'opus':
        # more difficult (disconnecting note_on's from their note_off's)...
        _warn("segment: opus format is not supported\n")
        _clean_up_warnings()
        return new_score
    i = 1   # ignore first element (ticks); we count in ticks anyway
    tracks = set(tracks)   # defend against tuples and lists
    while i < len(score):
        if len(tracks) and not ((i-1) in tracks):
            i += 1
            continue
        new_track = []
        # Per-channel "sticky state" in force at start_time, so the
        # segment can be re-asserted at its beginning:
        channel2cc_num  = {}     # most recent controller change before start
        channel2cc_val  = {}
        channel2cc_time = {}
        channel2patch_num  = {}  # keep most recent patch change before start
        channel2patch_time = {}
        set_tempo_num  = 500000  # most recent tempo change before start 6.3
        set_tempo_time = 0
        earliest_note_time = end_time   # NOTE(review): computed but unused
        for event in score[i]:
            if event[0] == 'control_change':   # 6.5
                cc_time = channel2cc_time.get(event[2]) or 0
                if (event[1] <= start_time) and (event[1] >= cc_time):
                    channel2cc_num[event[2]]  = event[3]
                    channel2cc_val[event[2]]  = event[4]
                    channel2cc_time[event[2]] = event[1]
            elif event[0] == 'patch_change':
                patch_time = channel2patch_time.get(event[2]) or 0
                if (event[1]<=start_time) and (event[1] >= patch_time):   # 2.0
                    channel2patch_num[event[2]]  = event[3]
                    channel2patch_time[event[2]] = event[1]
            elif event[0] == 'set_tempo':
                if (event[1]<=start_time) and (event[1]>=set_tempo_time):   #6.4
                    set_tempo_num  = event[2]
                    set_tempo_time = event[1]
            # keep every event that lies inside the window
            if (event[1] >= start_time) and (event[1] <= end_time):
                new_track.append(event)
                if (event[0] == 'note') and (event[1] < earliest_note_time):
                    earliest_note_time = event[1]
        if len(new_track) > 0:
            # re-assert the sticky state at the start of the segment
            new_track.append(['set_tempo', start_time, set_tempo_num])
            for c in channel2patch_num:
                new_track.append(['patch_change',start_time,c,channel2patch_num[c]],)
            for c in channel2cc_num:   # 6.5
                new_track.append(['control_change',start_time,c,channel2cc_num[c],channel2cc_val[c]])
            new_score.append(new_track)
        i += 1
    _clean_up_warnings()
    return new_score
672
+
673
def score_type(opus_or_score=None):
    r'''Returns a string, either 'opus' or 'score' or ''
    '''
    # Reject non-lists and lists too short to hold ticks + one track.
    if (opus_or_score == None
            or 'list' not in str(type(opus_or_score))
            or len(opus_or_score) < 2):
        return ''
    # The first 'note' or 'note_on' event we meet decides the answer.
    for track in opus_or_score[1:]:   # skip first element (ticks)
        for event in track:
            if event[0] == 'note':
                return 'score'
            elif event[0] == 'note_on':
                return 'opus'
    return ''
687
+
688
def concatenate_scores(scores):
    r'''Concatenates a list of scores into one score.
    If the scores differ in their "ticks" parameter,
    they will all get converted to millisecond-tick format.
    '''
    # The deepcopies guard against the input scores being references to
    # the same object, e.g. when invoked by midisox's repeat().
    input_scores = _consistentise_ticks(scores)   # 3.7
    output_score = copy.deepcopy(input_scores[0])
    for input_score in input_scores[1:]:
        # each appended score starts where the accumulated one ends
        delta_ticks = score2stats(output_score)['nticks']
        for itrack in range(1, len(input_score)):
            if itrack >= len(output_score):   # new output track if doesn't exist
                output_score.append([])
            for event in input_score[itrack]:
                shifted_event = copy.deepcopy(event)
                shifted_event[1] += delta_ticks
                output_score[itrack].append(shifted_event)
    return output_score
709
+
710
def merge_scores(scores):
    r'''Merges a list of scores into one score. A merged score comprises
    all of the tracks from all of the input scores; un-merging is possible
    by selecting just some of the tracks. If the scores differ in their
    "ticks" parameter, they will all get converted to millisecond-tick
    format. merge_scores attempts to resolve channel-conflicts,
    but there are of course only 15 available channels...
    '''
    input_scores = _consistentise_ticks(scores)   # 3.6
    output_score = [1000]
    channels_so_far = set()   # channels already claimed by earlier scores
    all_channels = {0,1,2,3,4,5,6,7,8,10,11,12,13,14,15}
    global Event2channelindex
    for input_score in input_scores:
        new_channels = set(score2stats(input_score).get('channels_total', []))
        new_channels.discard(9)   # 2.8 cha9 must remain cha9 (in GM)
        # Any channel used both before and now is a conflict: move this
        # score's events on it to an unused channel.
        for channel in channels_so_far & new_channels:
            # consistently choose lowest available, to ease testing
            free_channels = list(all_channels - (channels_so_far|new_channels))
            if len(free_channels) > 0:
                free_channels.sort()
                free_channel = free_channels[0]
            else:
                # NOTE(review): if no channel is free, free_channel stays
                # None and earlier iterations may already have written
                # None into events — confirm intended behaviour.
                free_channel = None
                break
            itrack = 1
            while itrack < len(input_score):
                for input_event in input_score[itrack]:
                    channel_index=Event2channelindex.get(input_event[0],False)
                    if channel_index and input_event[channel_index]==channel:
                        # mutates the (deep-copied) input event in place
                        input_event[channel_index] = free_channel
                itrack += 1
            channels_so_far.add(free_channel)

        channels_so_far |= new_channels
        output_score.extend(input_score[1:])   # tracks are kept, never merged
    return output_score
747
+
748
+ def _ticks(event):
749
+ return event[1]
750
def mix_opus_tracks(input_tracks):   # 5.5
    r'''Mixes an array of tracks into one track. A mixed track
    cannot be un-mixed. It is assumed that the tracks share the same
    ticks parameter and the same tempo.
    Mixing score-tracks is trivial (just insert all events into one array).
    Mixing opus-tracks is only slightly harder, but it's common enough
    that a dedicated function is useful.
    '''
    # Convert each opus-track to score form (absolute-ish 'note' events),
    # pool the events, sort by time, then convert back to opus form.
    pooled = [1000, []]
    for opus_track in input_tracks:   # 5.8
        pooled[1].extend(opus2score([1000, opus_track])[1])
    pooled[1].sort(key=_ticks)
    return score2opus(pooled)[1]
766
+
767
def mix_scores(scores):
    r'''Mixes a list of scores into one one-track score.
    A mixed score cannot be un-mixed. Hopefully the scores
    have no undesirable channel-conflicts between them.
    If the scores differ in their "ticks" parameter,
    they will all get converted to millisecond-tick format.
    '''
    mixed = [1000, []]   # single output track
    for consistent_score in _consistentise_ticks(scores):   # 3.6
        for track in consistent_score[1:]:
            mixed[1].extend(track)
    return mixed
780
+
781
def score2stats(opus_or_score=None):
    r'''Returns a dict of some basic stats about the score, like
    bank_select (list of tuples (msb,lsb)),
    channels_by_track (list of lists), channels_total (set),
    general_midi_mode (list),
    ntracks, nticks, patch_changes_by_track (list of dicts),
    num_notes_by_channel (list of numbers),
    patch_changes_total (set),
    percussion (dict histogram of channel 9 events),
    pitches (dict histogram of pitches on channels other than 9),
    pitch_range_by_track (list, by track, of two-member-tuples),
    pitch_range_sum (sum over tracks of the pitch_ranges),
    '''
    # Accumulators; filled in one pass over every track's events.
    bank_select_msb = -1
    bank_select_lsb = -1
    bank_select = []
    channels_by_track = []
    channels_total = set([])
    general_midi_mode = []
    num_notes_by_channel = dict([])
    patches_used_by_track = []   # NOTE(review): initialised but never filled
    patches_used_total = set([]) # NOTE(review): initialised but never filled
    patch_changes_by_track = []
    patch_changes_total = set([])
    percussion = dict([])   # histogram of channel 9 "pitches"
    pitches = dict([])      # histogram of pitch-occurrences channels 0-8,10-15
    pitch_range_sum = 0     # sum over tracks of the pitch-ranges
    pitch_range_by_track = []
    is_a_score = True       # flips False on the first 'note_on' seen
    if opus_or_score == None:
        # empty-input stats, same keys as the real return value
        return {'bank_select':[], 'channels_by_track':[], 'channels_total':[],
            'general_midi_mode':[], 'ntracks':0, 'nticks':0,
            'num_notes_by_channel':dict([]),
            'patch_changes_by_track':[], 'patch_changes_total':[],
            'percussion':{}, 'pitches':{}, 'pitch_range_by_track':[],
            'ticks_per_quarter':0, 'pitch_range_sum':0}
    ticks_per_quarter = opus_or_score[0]
    i = 1   # ignore first element, which is ticks
    nticks = 0
    while i < len(opus_or_score):
        highest_pitch = 0
        lowest_pitch = 128
        channels_this_track = set([])
        patch_changes_this_track = dict({})
        for event in opus_or_score[i]:
            if event[0] == 'note':
                # score-style event: [time, duration, channel, pitch, velocity]
                num_notes_by_channel[event[3]] = num_notes_by_channel.get(event[3],0) + 1
                if event[3] == 9:
                    percussion[event[4]] = percussion.get(event[4],0) + 1
                else:
                    pitches[event[4]] = pitches.get(event[4],0) + 1
                    if event[4] > highest_pitch:
                        highest_pitch = event[4]
                    if event[4] < lowest_pitch:
                        lowest_pitch = event[4]
                channels_this_track.add(event[3])
                channels_total.add(event[3])
                finish_time = event[1] + event[2]
                if finish_time > nticks:
                    nticks = finish_time
            elif event[0] == 'note_off' or (event[0] == 'note_on' and event[4] == 0):   # 4.8
                # note-end in opus form (velocity-0 note_on counts too)
                finish_time = event[1]
                if finish_time > nticks:
                    nticks = finish_time
            elif event[0] == 'note_on':
                # opus-style event: [dtime, channel, pitch, velocity]
                is_a_score = False
                num_notes_by_channel[event[2]] = num_notes_by_channel.get(event[2],0) + 1
                if event[2] == 9:
                    percussion[event[3]] = percussion.get(event[3],0) + 1
                else:
                    pitches[event[3]] = pitches.get(event[3],0) + 1
                    if event[3] > highest_pitch:
                        highest_pitch = event[3]
                    if event[3] < lowest_pitch:
                        lowest_pitch = event[3]
                channels_this_track.add(event[2])
                channels_total.add(event[2])
            elif event[0] == 'patch_change':
                patch_changes_this_track[event[2]] = event[3]
                patch_changes_total.add(event[3])
            elif event[0] == 'control_change':
                if event[3] == 0:    # bank select MSB
                    bank_select_msb = event[4]
                elif event[3] == 32: # bank select LSB
                    bank_select_lsb = event[4]
                # a complete (msb,lsb) pair has been seen: record it
                if bank_select_msb >= 0 and bank_select_lsb >= 0:
                    bank_select.append((bank_select_msb,bank_select_lsb))
                    bank_select_msb = -1
                    bank_select_lsb = -1
            elif event[0] == 'sysex_f0':
                # recognise the GM-mode on/off sysex messages
                if _sysex2midimode.get(event[2], -1) >= 0:
                    general_midi_mode.append(_sysex2midimode.get(event[2]))
            # nticks: absolute times in a score, delta-times in an opus
            if is_a_score:
                if event[1] > nticks:
                    nticks = event[1]
            else:
                nticks += event[1]
        if lowest_pitch == 128:   # no pitched notes in this track
            lowest_pitch = 0
        channels_by_track.append(channels_this_track)
        patch_changes_by_track.append(patch_changes_this_track)
        pitch_range_by_track.append((lowest_pitch,highest_pitch))
        pitch_range_sum += (highest_pitch-lowest_pitch)
        i += 1

    return {'bank_select':bank_select,
            'channels_by_track':channels_by_track,
            'channels_total':channels_total,
            'general_midi_mode':general_midi_mode,
            'ntracks':len(opus_or_score)-1,
            'nticks':nticks,
            'num_notes_by_channel':num_notes_by_channel,
            'patch_changes_by_track':patch_changes_by_track,
            'patch_changes_total':patch_changes_total,
            'percussion':percussion,
            'pitches':pitches,
            'pitch_range_by_track':pitch_range_by_track,
            'pitch_range_sum':pitch_range_sum,
            'ticks_per_quarter':ticks_per_quarter}
900
+
901
#----------------------------- Event stuff --------------------------

# Sysex payloads that switch General-MIDI mode: on (1), off (0), GM2 (2).
# NOTE(review): the keys are str, but sysex event data may be bytes in
# Python 3, in which case lookups here would never match — verify against
# what _decode() stores in sysex_f0 events.
_sysex2midimode = {
    "\x7E\x7F\x09\x01\xF7": 1,
    "\x7E\x7F\x09\x02\xF7": 0,
    "\x7E\x7F\x09\x03\xF7": 2,
}

# Some public-access tuples:
# Channel (voice) events:
MIDI_events = tuple('''note_off note_on key_after_touch
control_change patch_change channel_after_touch
pitch_wheel_change'''.split())

# Meta-events carrying text payloads:
Text_events = tuple('''text_event copyright_text_event
track_name instrument_name lyric marker cue_point text_event_08
text_event_09 text_event_0a text_event_0b text_event_0c
text_event_0d text_event_0e text_event_0f'''.split())

# Meta-events with non-text payloads:
Nontext_meta_events = tuple('''end_track set_tempo
smpte_offset time_signature key_signature sequencer_specific
raw_meta_event sysex_f0 sysex_f7 song_position song_select
tune_request'''.split())
# unsupported: raw_data

# Actually, 'tune_request' is an F-series event, not strictly a meta-event...
Meta_events = Text_events + Nontext_meta_events
All_events  = MIDI_events + Meta_events

# And three dictionaries:
Number2patch = {   # General MIDI patch numbers:
0:'Acoustic Grand',
1:'Bright Acoustic',
2:'Electric Grand',
3:'Honky-Tonk',
4:'Electric Piano 1',
5:'Electric Piano 2',
6:'Harpsichord',
7:'Clav',
8:'Celesta',
9:'Glockenspiel',
10:'Music Box',
11:'Vibraphone',
12:'Marimba',
13:'Xylophone',
14:'Tubular Bells',
15:'Dulcimer',
16:'Drawbar Organ',
17:'Percussive Organ',
18:'Rock Organ',
19:'Church Organ',
20:'Reed Organ',
21:'Accordion',
22:'Harmonica',
23:'Tango Accordion',
24:'Acoustic Guitar(nylon)',
25:'Acoustic Guitar(steel)',
26:'Electric Guitar(jazz)',
27:'Electric Guitar(clean)',
28:'Electric Guitar(muted)',
29:'Overdriven Guitar',
30:'Distortion Guitar',
31:'Guitar Harmonics',
32:'Acoustic Bass',
33:'Electric Bass(finger)',
34:'Electric Bass(pick)',
35:'Fretless Bass',
36:'Slap Bass 1',
37:'Slap Bass 2',
38:'Synth Bass 1',
39:'Synth Bass 2',
40:'Violin',
41:'Viola',
42:'Cello',
43:'Contrabass',
44:'Tremolo Strings',
45:'Pizzicato Strings',
46:'Orchestral Harp',
47:'Timpani',
48:'String Ensemble 1',
49:'String Ensemble 2',
50:'SynthStrings 1',
51:'SynthStrings 2',
52:'Choir Aahs',
53:'Voice Oohs',
54:'Synth Voice',
55:'Orchestra Hit',
56:'Trumpet',
57:'Trombone',
58:'Tuba',
59:'Muted Trumpet',
60:'French Horn',
61:'Brass Section',
62:'SynthBrass 1',
63:'SynthBrass 2',
64:'Soprano Sax',
65:'Alto Sax',
66:'Tenor Sax',
67:'Baritone Sax',
68:'Oboe',
69:'English Horn',
70:'Bassoon',
71:'Clarinet',
72:'Piccolo',
73:'Flute',
74:'Recorder',
75:'Pan Flute',
76:'Blown Bottle',
77:'Skakuhachi',
78:'Whistle',
79:'Ocarina',
80:'Lead 1 (square)',
81:'Lead 2 (sawtooth)',
82:'Lead 3 (calliope)',
83:'Lead 4 (chiff)',
84:'Lead 5 (charang)',
85:'Lead 6 (voice)',
86:'Lead 7 (fifths)',
87:'Lead 8 (bass+lead)',
88:'Pad 1 (new age)',
89:'Pad 2 (warm)',
90:'Pad 3 (polysynth)',
91:'Pad 4 (choir)',
92:'Pad 5 (bowed)',
93:'Pad 6 (metallic)',
94:'Pad 7 (halo)',
95:'Pad 8 (sweep)',
96:'FX 1 (rain)',
97:'FX 2 (soundtrack)',
98:'FX 3 (crystal)',
99:'FX 4 (atmosphere)',
100:'FX 5 (brightness)',
101:'FX 6 (goblins)',
102:'FX 7 (echoes)',
103:'FX 8 (sci-fi)',
104:'Sitar',
105:'Banjo',
106:'Shamisen',
107:'Koto',
108:'Kalimba',
109:'Bagpipe',
110:'Fiddle',
111:'Shanai',
112:'Tinkle Bell',
113:'Agogo',
114:'Steel Drums',
115:'Woodblock',
116:'Taiko Drum',
117:'Melodic Tom',
118:'Synth Drum',
119:'Reverse Cymbal',
120:'Guitar Fret Noise',
121:'Breath Noise',
122:'Seashore',
123:'Bird Tweet',
124:'Telephone Ring',
125:'Helicopter',
126:'Applause',
127:'Gunshot',
}
Notenum2percussion = {   # General MIDI Percussion (on Channel 9):
35:'Acoustic Bass Drum',
36:'Bass Drum 1',
37:'Side Stick',
38:'Acoustic Snare',
39:'Hand Clap',
40:'Electric Snare',
41:'Low Floor Tom',
42:'Closed Hi-Hat',
43:'High Floor Tom',
44:'Pedal Hi-Hat',
45:'Low Tom',
46:'Open Hi-Hat',
47:'Low-Mid Tom',
48:'Hi-Mid Tom',
49:'Crash Cymbal 1',
50:'High Tom',
51:'Ride Cymbal 1',
52:'Chinese Cymbal',
53:'Ride Bell',
54:'Tambourine',
55:'Splash Cymbal',
56:'Cowbell',
57:'Crash Cymbal 2',
58:'Vibraslap',
59:'Ride Cymbal 2',
60:'Hi Bongo',
61:'Low Bongo',
62:'Mute Hi Conga',
63:'Open Hi Conga',
64:'Low Conga',
65:'High Timbale',
66:'Low Timbale',
67:'High Agogo',
68:'Low Agogo',
69:'Cabasa',
70:'Maracas',
71:'Short Whistle',
72:'Long Whistle',
73:'Short Guiro',
74:'Long Guiro',
75:'Claves',
76:'Hi Wood Block',
77:'Low Wood Block',
78:'Mute Cuica',
79:'Open Cuica',
80:'Mute Triangle',
81:'Open Triangle',
}

# For each channel event type, the index of its channel field.
Event2channelindex = { 'note':3, 'note_off':2, 'note_on':2,
 'key_after_touch':2, 'control_change':2, 'patch_change':2,
 'channel_after_touch':2, 'pitch_wheel_change':2
}
1114
+
1115
+ ################################################################
1116
+ # The code below this line is full of frightening things, all to
1117
+ # do with the actual encoding and decoding of binary MIDI data.
1118
+
1119
+ def _twobytes2int(byte_a):
1120
+ r'''decode a 16 bit quantity from two bytes,'''
1121
+ return (byte_a[1] | (byte_a[0] << 8))
1122
+
1123
+ def _int2twobytes(int_16bit):
1124
+ r'''encode a 16 bit quantity into two bytes,'''
1125
+ return bytes([(int_16bit>>8) & 0xFF, int_16bit & 0xFF])
1126
+
1127
+ def _read_14_bit(byte_a):
1128
+ r'''decode a 14 bit quantity from two bytes,'''
1129
+ return (byte_a[0] | (byte_a[1] << 7))
1130
+
1131
+ def _write_14_bit(int_14bit):
1132
+ r'''encode a 14 bit quantity into two bytes,'''
1133
+ return bytes([int_14bit & 0x7F, (int_14bit>>7) & 0x7F])
1134
+
1135
+ def _ber_compressed_int(integer):
1136
+ r'''BER compressed integer (not an ASN.1 BER, see perlpacktut for
1137
+ details). Its bytes represent an unsigned integer in base 128,
1138
+ most significant digit first, with as few digits as possible.
1139
+ Bit eight (the high bit) is set on each byte except the last.
1140
+ '''
1141
+ ber = bytearray(b'')
1142
+ seven_bits = 0x7F & integer
1143
+ ber.insert(0, seven_bits) # XXX surely should convert to a char ?
1144
+ integer >>= 7
1145
+ while integer > 0:
1146
+ seven_bits = 0x7F & integer
1147
+ ber.insert(0, 0x80|seven_bits) # XXX surely should convert to a char ?
1148
+ integer >>= 7
1149
+ return ber
1150
+
1151
def _unshift_ber_int(ba):
    r'''Given a bytearray, returns a tuple of (the ber-integer at the
    start, and the remainder of the bytearray).
    '''
    # NB: consumes bytes from "ba" in place.
    if not len(ba):   # 6.7
        _warn('_unshift_ber_int: no integer found')
        return ((0, b""))
    integer = 0
    while True:
        byte = ba.pop(0)
        integer += (byte & 0x7F)       # low 7 bits are the digit
        if not (byte & 0x80):          # high bit clear: last digit
            return ((integer, ba))
        if not len(ba):                # ran out mid-integer: malformed
            _warn('_unshift_ber_int: no end-of-integer found')
            return ((0, ba))
        integer <<= 7                  # make room for the next digit
1169
+
1170
def _clean_up_warnings():   # 5.4
    # Call this before returning from any publicly callable function
    # whenever a warning might have been printed by it, or by any
    # private function it called: it flushes the repeat-counter that
    # _warn() maintains, and resets the dedup state.
    if _no_warning:
        return
    global _previous_times, _previous_warning
    if _previous_times > 1:
        # 6.7: sys.stderr.write rather than print(file=...), which once
        # triggered a pylint syntax-error false positive
        sys.stderr.write(' previous message repeated {0} times\n'.format(_previous_times))
    elif _previous_times > 0:
        sys.stderr.write(' previous message repeated\n')
    _previous_times = 0
    _previous_warning = ''
1187
+
1188
def _warn(s=''):
    # Print a warning to stderr, deduplicating consecutive repeats:
    # a repeated message only bumps a counter, which
    # _clean_up_warnings() later reports.
    if _no_warning:
        return
    global _previous_times, _previous_warning
    if s == _previous_warning:   # 5.4
        _previous_times += 1
    else:
        _clean_up_warnings()
        sys.stderr.write(str(s)+"\n")
        _previous_warning = s
1199
+
1200
def _some_text_event(which_kind=0x01, text=b'some_text'):
    # Encode a text-style meta-event: 0xFF, the kind byte, a
    # ber-compressed length, then the payload bytes.
    # 6.4: accept str for back-compatibility, encoding it as Latin-1.
    if str(type(text)).find("'str'") >= 0:
        data = bytes(text, encoding='ISO-8859-1')
    else:
        data = bytes(text)
    header = b'\xFF' + bytes((which_kind,))
    return header + _ber_compressed_int(len(data)) + data
1206
+
1207
def _consistentise_ticks(scores):   # 3.6
    # used by mix_scores, merge_scores, concatenate_scores
    # Returns deep copies; if the scores disagree on their ticks
    # parameter, every one is converted to millisecond-tick format.
    if len(scores) == 1:
        return copy.deepcopy(scores)
    reference_ticks = scores[0][0]
    if all(score[0] == reference_ticks for score in scores[1:]):
        return copy.deepcopy(scores)
    return [opus2score(to_millisecs(score2opus(score))) for score in scores]
1228
+
1229
+
1230
+ ###########################################################################
1231
+
1232
+ def _decode(trackdata=b'', exclude=None, include=None,
1233
+ event_callback=None, exclusive_event_callback=None, no_eot_magic=False):
1234
+ r'''Decodes MIDI track data into an opus-style list of events.
1235
+ The options:
1236
+ 'exclude' is a list of event types which will be ignored SHOULD BE A SET
1237
+ 'include' (and no exclude), makes exclude a list
1238
+ of all possible events, /minus/ what include specifies
1239
+ 'event_callback' is a coderef
1240
+ 'exclusive_event_callback' is a coderef
1241
+ '''
1242
+ trackdata = bytearray(trackdata)
1243
+ if exclude == None:
1244
+ exclude = []
1245
+ if include == None:
1246
+ include = []
1247
+ if include and not exclude:
1248
+ exclude = All_events
1249
+ include = set(include)
1250
+ exclude = set(exclude)
1251
+
1252
+ # Pointer = 0; not used here; we eat through the bytearray instead.
1253
+ event_code = -1; # used for running status
1254
+ event_count = 0;
1255
+ events = []
1256
+
1257
+ while(len(trackdata)):
1258
+ # loop while there's anything to analyze ...
1259
+ eot = False # When True, the event registrar aborts this loop
1260
+ event_count += 1
1261
+
1262
+ E = []
1263
+ # E for events - we'll feed it to the event registrar at the end.
1264
+
1265
+ # Slice off the delta time code, and analyze it
1266
+ [time, remainder] = _unshift_ber_int(trackdata)
1267
+
1268
+ # Now let's see what we can make of the command
1269
+ first_byte = trackdata.pop(0) & 0xFF
1270
+
1271
+ if (first_byte < 0xF0): # It's a MIDI event
1272
+ if (first_byte & 0x80):
1273
+ event_code = first_byte
1274
+ else:
1275
+ # It wants running status; use last event_code value
1276
+ trackdata.insert(0, first_byte)
1277
+ if (event_code == -1):
1278
+ _warn("Running status not set; Aborting track.")
1279
+ return []
1280
+
1281
+ command = event_code & 0xF0
1282
+ channel = event_code & 0x0F
1283
+
1284
+ if (command == 0xF6): # 0-byte argument
1285
+ pass
1286
+ elif (command == 0xC0 or command == 0xD0): # 1-byte argument
1287
+ parameter = trackdata.pop(0) # could be B
1288
+ else: # 2-byte argument could be BB or 14-bit
1289
+ parameter = (trackdata.pop(0), trackdata.pop(0))
1290
+
1291
+ #################################################################
1292
+ # MIDI events
1293
+
1294
+ if (command == 0x80):
1295
+ if 'note_off' in exclude:
1296
+ continue
1297
+ E = ['note_off', time, channel, parameter[0], parameter[1]]
1298
+ elif (command == 0x90):
1299
+ if 'note_on' in exclude:
1300
+ continue
1301
+ E = ['note_on', time, channel, parameter[0], parameter[1]]
1302
+ elif (command == 0xA0):
1303
+ if 'key_after_touch' in exclude:
1304
+ continue
1305
+ E = ['key_after_touch',time,channel,parameter[0],parameter[1]]
1306
+ elif (command == 0xB0):
1307
+ if 'control_change' in exclude:
1308
+ continue
1309
+ E = ['control_change',time,channel,parameter[0],parameter[1]]
1310
+ elif (command == 0xC0):
1311
+ if 'patch_change' in exclude:
1312
+ continue
1313
+ E = ['patch_change', time, channel, parameter]
1314
+ elif (command == 0xD0):
1315
+ if 'channel_after_touch' in exclude:
1316
+ continue
1317
+ E = ['channel_after_touch', time, channel, parameter]
1318
+ elif (command == 0xE0):
1319
+ if 'pitch_wheel_change' in exclude:
1320
+ continue
1321
+ E = ['pitch_wheel_change', time, channel,
1322
+ _read_14_bit(parameter)-0x2000]
1323
+ else:
1324
+ _warn("Shouldn't get here; command="+hex(command))
1325
+
1326
+ elif (first_byte == 0xFF): # It's a Meta-Event! ##################
1327
+ #[command, length, remainder] =
1328
+ # unpack("xCwa*", substr(trackdata, $Pointer, 6));
1329
+ #Pointer += 6 - len(remainder);
1330
+ # # Move past JUST the length-encoded.
1331
+ command = trackdata.pop(0) & 0xFF
1332
+ [length, trackdata] = _unshift_ber_int(trackdata)
1333
+ if (command == 0x00):
1334
+ if (length == 2):
1335
+ E = ['set_sequence_number',time,_twobytes2int(trackdata)]
1336
+ else:
1337
+ _warn('set_sequence_number: length must be 2, not '+str(length))
1338
+ E = ['set_sequence_number', time, 0]
1339
+
1340
+ elif command >= 0x01 and command <= 0x0f: # Text events
1341
+ # 6.2 take it in bytes; let the user get the right encoding.
1342
+ # text_str = trackdata[0:length].decode('ascii','ignore')
1343
+ # text_str = trackdata[0:length].decode('ISO-8859-1')
1344
+ # 6.4 take it in bytes; let the user get the right encoding.
1345
+ text_data = bytes(trackdata[0:length]) # 6.4
1346
+ # Defined text events
1347
+ if (command == 0x01):
1348
+ E = ['text_event', time, text_data]
1349
+ elif (command == 0x02):
1350
+ E = ['copyright_text_event', time, text_data]
1351
+ elif (command == 0x03):
1352
+ E = ['track_name', time, text_data]
1353
+ elif (command == 0x04):
1354
+ E = ['instrument_name', time, text_data]
1355
+ elif (command == 0x05):
1356
+ E = ['lyric', time, text_data]
1357
+ elif (command == 0x06):
1358
+ E = ['marker', time, text_data]
1359
+ elif (command == 0x07):
1360
+ E = ['cue_point', time, text_data]
1361
+ # Reserved but apparently unassigned text events
1362
+ elif (command == 0x08):
1363
+ E = ['text_event_08', time, text_data]
1364
+ elif (command == 0x09):
1365
+ E = ['text_event_09', time, text_data]
1366
+ elif (command == 0x0a):
1367
+ E = ['text_event_0a', time, text_data]
1368
+ elif (command == 0x0b):
1369
+ E = ['text_event_0b', time, text_data]
1370
+ elif (command == 0x0c):
1371
+ E = ['text_event_0c', time, text_data]
1372
+ elif (command == 0x0d):
1373
+ E = ['text_event_0d', time, text_data]
1374
+ elif (command == 0x0e):
1375
+ E = ['text_event_0e', time, text_data]
1376
+ elif (command == 0x0f):
1377
+ E = ['text_event_0f', time, text_data]
1378
+
1379
+ # Now the sticky events -------------------------------------
1380
+ elif (command == 0x2F):
1381
+ E = ['end_track', time]
1382
+ # The code for handling this, oddly, comes LATER,
1383
+ # in the event registrar.
1384
+ elif (command == 0x51): # DTime, Microseconds/Crochet
1385
+ if length != 3:
1386
+ _warn('set_tempo event, but length='+str(length))
1387
+ E = ['set_tempo', time,
1388
+ struct.unpack(">I", b'\x00'+trackdata[0:3])[0]]
1389
+ elif (command == 0x54):
1390
+ if length != 5: # DTime, HR, MN, SE, FR, FF
1391
+ _warn('smpte_offset event, but length='+str(length))
1392
+ E = ['smpte_offset',time] + list(struct.unpack(">BBBBB",trackdata[0:5]))
1393
+ elif (command == 0x58):
1394
+ if length != 4: # DTime, NN, DD, CC, BB
1395
+ _warn('time_signature event, but length='+str(length))
1396
+ E = ['time_signature', time]+list(trackdata[0:4])
1397
+ elif (command == 0x59):
1398
+ if length != 2: # DTime, SF(signed), MI
1399
+ _warn('key_signature event, but length='+str(length))
1400
+ E = ['key_signature',time] + list(struct.unpack(">bB",trackdata[0:2]))
1401
+ elif (command == 0x7F): # 6.4
1402
+ E = ['sequencer_specific',time, bytes(trackdata[0:length])]
1403
+ else:
1404
+ E = ['raw_meta_event', time, command,
1405
+ bytes(trackdata[0:length])] # 6.0
1406
+ #"[uninterpretable meta-event command of length length]"
1407
+ # DTime, Command, Binary Data
1408
+ # It's uninterpretable; record it as raw_data.
1409
+
1410
+ # Pointer += length; # Now move Pointer
1411
+ trackdata = trackdata[length:]
1412
+
1413
+ ######################################################################
1414
+ elif (first_byte == 0xF0 or first_byte == 0xF7):
1415
+ # Note that sysexes in MIDI /files/ are different than sysexes
1416
+ # in MIDI transmissions!! The vast majority of system exclusive
1417
+ # messages will just use the F0 format. For instance, the
1418
+ # transmitted message F0 43 12 00 07 F7 would be stored in a
1419
+ # MIDI file as F0 05 43 12 00 07 F7. As mentioned above, it is
1420
+ # required to include the F7 at the end so that the reader of the
1421
+ # MIDI file knows that it has read the entire message. (But the F7
1422
+ # is omitted if this is a non-final block in a multiblock sysex;
1423
+ # but the F7 (if there) is counted in the message's declared
1424
+ # length, so we don't have to think about it anyway.)
1425
+ #command = trackdata.pop(0)
1426
+ [length, trackdata] = _unshift_ber_int(trackdata)
1427
+ if first_byte == 0xF0:
1428
+ # 20091008 added ISO-8859-1 to get an 8-bit str
1429
+ # 6.4 return bytes instead
1430
+ E = ['sysex_f0', time, bytes(trackdata[0:length])]
1431
+ else:
1432
+ E = ['sysex_f7', time, bytes(trackdata[0:length])]
1433
+ trackdata = trackdata[length:]
1434
+
1435
+ ######################################################################
1436
+ # Now, the MIDI file spec says:
1437
+ # <track data> = <MTrk event>+
1438
+ # <MTrk event> = <delta-time> <event>
1439
+ # <event> = <MIDI event> | <sysex event> | <meta-event>
1440
+ # I know that, on the wire, <MIDI event> can include note_on,
1441
+ # note_off, and all the other 8x to Ex events, AND Fx events
1442
+ # other than F0, F7, and FF -- namely, <song position msg>,
1443
+ # <song select msg>, and <tune request>.
1444
+ #
1445
+ # Whether these can occur in MIDI files is not clear specified
1446
+ # from the MIDI file spec. So, I'm going to assume that
1447
+ # they CAN, in practice, occur. I don't know whether it's
1448
+ # proper for you to actually emit these into a MIDI file.
1449
+
1450
+ elif (first_byte == 0xF2): # DTime, Beats
1451
+ # <song position msg> ::= F2 <data pair>
1452
+ E = ['song_position', time, _read_14_bit(trackdata[:2])]
1453
+ trackdata = trackdata[2:]
1454
+
1455
+ elif (first_byte == 0xF3): # <song select msg> ::= F3 <data singlet>
1456
+ # E = ['song_select', time, struct.unpack('>B',trackdata.pop(0))[0]]
1457
+ E = ['song_select', time, trackdata[0]]
1458
+ trackdata = trackdata[1:]
1459
+ # DTime, Thing (what?! song number? whatever ...)
1460
+
1461
+ elif (first_byte == 0xF6): # DTime
1462
+ E = ['tune_request', time]
1463
+ # What would a tune request be doing in a MIDI /file/?
1464
+
1465
+ #########################################################
1466
+ # ADD MORE META-EVENTS HERE. TODO:
1467
+ # f1 -- MTC Quarter Frame Message. One data byte follows
1468
+ # the Status; it's the time code value, from 0 to 127.
1469
+ # f8 -- MIDI clock. no data.
1470
+ # fa -- MIDI start. no data.
1471
+ # fb -- MIDI continue. no data.
1472
+ # fc -- MIDI stop. no data.
1473
+ # fe -- Active sense. no data.
1474
+ # f4 f5 f9 fd -- unallocated
1475
+
1476
+ r'''
1477
+ elif (first_byte > 0xF0) { # Some unknown kinda F-series event ####
1478
+ # Here we only produce a one-byte piece of raw data.
1479
+ # But the encoder for 'raw_data' accepts any length of it.
1480
+ E = [ 'raw_data',
1481
+ time, substr(trackdata,Pointer,1) ]
1482
+ # DTime and the Data (in this case, the one Event-byte)
1483
+ ++Pointer; # itself
1484
+
1485
+ '''
1486
+ elif first_byte > 0xF0: # Some unknown F-series event
1487
+ # Here we only produce a one-byte piece of raw data.
1488
+ # E = ['raw_data', time, bytest(trackdata[0])] # 6.4
1489
+ E = ['raw_data', time, trackdata[0]] # 6.4 6.7
1490
+ trackdata = trackdata[1:]
1491
+ else: # Fallthru.
1492
+ _warn("Aborting track. Command-byte first_byte="+hex(first_byte))
1493
+ break
1494
+ # End of the big if-group
1495
+
1496
+
1497
+ ######################################################################
1498
+ # THE EVENT REGISTRAR...
1499
+ if E and (E[0] == 'end_track'):
1500
+ # This is the code for exceptional handling of the EOT event.
1501
+ eot = True
1502
+ if not no_eot_magic:
1503
+ if E[1] > 0: # a null text-event to carry the delta-time
1504
+ E = ['text_event', E[1], '']
1505
+ else:
1506
+ E = [] # EOT with a delta-time of 0; ignore it.
1507
+
1508
+ if E and not (E[0] in exclude):
1509
+ #if ( $exclusive_event_callback ):
1510
+ # &{ $exclusive_event_callback }( @E );
1511
+ #else:
1512
+ # &{ $event_callback }( @E ) if $event_callback;
1513
+ events.append(E)
1514
+ if eot:
1515
+ break
1516
+
1517
+ # End of the big "Event" while-block
1518
+
1519
+ return events
1520
+
1521
+
1522
+ ###########################################################################
1523
def _encode(events_lol, unknown_callback=None, never_add_eot=False,
      no_eot_magic=False, no_running_status=False):
    '''Encode an event structure (a list of events, each of the form
    [name, dtime, ...parameters]) into MIDI track data (bytes).

    Unless never_add_eot is set, an 'end_track' event is guaranteed at the
    end: a track-final 0-length text_event is converted into one (keeping
    its delta-time) unless no_eot_magic is set.  Channel events use MIDI
    running status unless no_running_status is set.  Unrecognised events
    are passed to unknown_callback if given, otherwise warned about and
    skipped.
    '''
    data = []  # the chunks of byte-data; joined at the end

    # Work on a deep copy so the end_track magic below can't corrupt the
    # caller's structure.
    events = copy.deepcopy(events_lol)

    if not never_add_eot:
        # One way or another, tack on an 'end_track'.
        if events:
            last = events[-1]
            if not (last[0] == 'end_track'):  # no end_track already
                if (last[0] == 'text_event' and len(last[2]) == 0):
                    # 0-length text event at track-end.
                    if no_eot_magic:
                        # Exceptional case: don't mess with track-final
                        # 0-length text_events; just peg on an end_track.
                        events.append(['end_track', 0])
                    else:
                        # NORMAL CASE: replace with an end_track,
                        # leaving the delta-time unchanged.
                        last[0] = 'end_track'
                else:
                    # last event was neither a 0-length text_event
                    # nor an end_track
                    events.append(['end_track', 0])
        else:  # an eventless track!
            events = [['end_track', 0], ]

    # All text-style meta-events share one wire encoding and differ only in
    # the command byte; dispatch through a table instead of a 15-branch
    # elif chain.
    text_meta_commands = {
        'text_event': 0x01, 'copyright_text_event': 0x02,
        'track_name': 0x03, 'instrument_name': 0x04, 'lyric': 0x05,
        'marker': 0x06, 'cue_point': 0x07,
        'text_event_08': 0x08, 'text_event_09': 0x09,
        'text_event_0a': 0x0A, 'text_event_0b': 0x0B,
        'text_event_0c': 0x0C, 'text_event_0d': 0x0D,
        'text_event_0e': 0x0E, 'text_event_0f': 0x0F,
    }

    # Status byte of the previous channel event, for running status;
    # -1 means "no usable previous status".
    last_status = -1

    for event_r in events:
        E = copy.deepcopy(event_r)
        # otherwise the shifting'd corrupt the original
        if not E:
            continue

        event = E.pop(0)
        if not event:  # empty event name
            continue

        dtime = int(E.pop(0))

        event_data = ''

        if (  # MIDI channel events -- eligible for running status
                event == 'note_on'
                or event == 'note_off'
                or event == 'control_change'
                or event == 'key_after_touch'
                or event == 'patch_change'
                or event == 'channel_after_touch'
                or event == 'pitch_wheel_change'):

            # This block is where we spend most of the time.  Gotta be tight.
            if (event == 'note_off'):
                status = 0x80 | (int(E[0]) & 0x0F)
                parameters = struct.pack('>BB', int(E[1]) & 0x7F,
                                         int(E[2]) & 0x7F)
            elif (event == 'note_on'):
                status = 0x90 | (int(E[0]) & 0x0F)
                parameters = struct.pack('>BB', int(E[1]) & 0x7F,
                                         int(E[2]) & 0x7F)
            elif (event == 'key_after_touch'):
                status = 0xA0 | (int(E[0]) & 0x0F)
                parameters = struct.pack('>BB', int(E[1]) & 0x7F,
                                         int(E[2]) & 0x7F)
            elif (event == 'control_change'):
                status = 0xB0 | (int(E[0]) & 0x0F)
                parameters = struct.pack('>BB', int(E[1]) & 0xFF,
                                         int(E[2]) & 0xFF)
            elif (event == 'patch_change'):
                status = 0xC0 | (int(E[0]) & 0x0F)
                parameters = struct.pack('>B', int(E[1]) & 0xFF)
            elif (event == 'channel_after_touch'):
                status = 0xD0 | (int(E[0]) & 0x0F)
                parameters = struct.pack('>B', int(E[1]) & 0xFF)
            elif (event == 'pitch_wheel_change'):
                status = 0xE0 | (int(E[0]) & 0x0F)
                parameters = _write_14_bit(int(E[1]) + 0x2000)
            else:
                # Unreachable: the outer condition already restricted the
                # event names, but keep the guard for safety.
                _warn("BADASS FREAKOUT ERROR 31415!")

            # Delta-time is a BER compressed integer: base-128 digits,
            # most significant first, high bit set on all but the last.
            data.append(_ber_compressed_int(dtime))
            if (status != last_status) or no_running_status:
                data.append(struct.pack('>B', status))
            data.append(parameters)

            last_status = status
            continue
        else:
            # Not a channel event: running status can't span it.
            last_status = -1

            if event == 'raw_meta_event':
                event_data = _some_text_event(int(E[0]), E[1])
            elif (event == 'set_sequence_number'):  # 3.9
                event_data = b'\xFF\x00\x02' + _int2twobytes(E[0])
            elif event in text_meta_commands:
                event_data = _some_text_event(text_meta_commands[event], E[0])
            elif (event == 'end_track'):
                event_data = b"\xFF\x2F\x00"
            elif (event == 'set_tempo'):
                # 3-byte big-endian microseconds-per-quarter-note.
                event_data = b'\xFF\x51\x03' + struct.pack('>I', E[0])[1:]
            elif (event == 'smpte_offset'):
                event_data = struct.pack(">BBBbBBBB", 0xFF, 0x54, 0x05,
                                         E[0], E[1], E[2], E[3], E[4])
            elif (event == 'time_signature'):
                event_data = struct.pack(">BBBbBBB", 0xFF, 0x58, 0x04,
                                         E[0], E[1], E[2], E[3])
            elif (event == 'key_signature'):
                # sf is signed (flats negative), mi is 0=major 1=minor.
                event_data = struct.pack(">BBBbB", 0xFF, 0x59, 0x02,
                                         E[0], E[1])
            elif (event == 'sequencer_specific'):
                event_data = _some_text_event(0x7F, E[0])
            # End of Meta-events

            # Other Things...
            elif (event == 'sysex_f0'):
                event_data = (bytearray(b'\xF0')
                              + _ber_compressed_int(len(E[0]))
                              + bytearray(E[0]))
            elif (event == 'sysex_f7'):
                event_data = (bytearray(b'\xF7')
                              + _ber_compressed_int(len(E[0]))
                              + bytearray(E[0]))
            elif (event == 'song_position'):
                event_data = b"\xF2" + _write_14_bit(E[0])
            elif (event == 'song_select'):
                event_data = struct.pack('>BB', 0xF3, E[0])
            elif (event == 'tune_request'):
                event_data = b"\xF6"
            elif (event == 'raw_data'):
                _warn("_encode: raw_data event not supported")
                continue
            # End of Other Stuff

            else:
                # The Big Fallthru
                if unknown_callback:
                    # push(@data, &{ $unknown_callback }( @$event_r ))
                    pass
                else:
                    _warn("Unknown event: " + str(event))
                    # To suppress the complaint here, just set
                    # 'unknown_callback' => sub { return () }
                continue

        if isinstance(event_data, str):
            # Legacy str payload: force to 8-bit bytes.
            event_data = bytearray(event_data.encode('Latin1', 'ignore'))
        if len(event_data):  # how could event_data be empty
            data.append(_ber_compressed_int(dtime) + event_data)

    return b''.join(data)
1735
+
README.md CHANGED
@@ -1,12 +1,13 @@
1
- ---
2
- title: Music
3
- emoji: 📊
4
- colorFrom: yellow
5
- colorTo: blue
6
- sdk: gradio
7
- sdk_version: 5.0.2
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
1
+ ---
2
+ title: Midi Music Generator
3
+ emoji: 🎼🎶
4
+ colorFrom: red
5
+ colorTo: indigo
6
+ sdk: gradio
7
+ sdk_version: 5.0.1
8
+ app_file: app_onnx.py
9
+ pinned: true
10
+ license: apache-2.0
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,534 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import random
3
+ import argparse
4
+ import glob
5
+ import json
6
+ import os
7
+ import time
8
+ from concurrent.futures import ThreadPoolExecutor
9
+
10
+ import gradio as gr
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn.functional as F
14
+ import tqdm
15
+ from huggingface_hub import hf_hub_download
16
+ from transformers import DynamicCache
17
+
18
+ import MIDI
19
+ from midi_model import MIDIModel, MIDIModelConfig
20
+ from midi_synthesizer import MidiSynthesizer
21
+
22
# Largest value a random seed may take (fits in a signed 32-bit int).
MAX_SEED = np.iinfo(np.int32).max
# SYSTEM=spaces is presumably set by the Hugging Face Spaces runtime — TODO confirm.
in_space = os.getenv("SYSTEM") == "spaces"
24
+
25
+
26
@torch.inference_mode()
def generate(model: MIDIModel, prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98, top_k=20,
             disable_patch_change=False, disable_control_change=False, disable_channels=None, generator=None):
    """Autoregressively generate MIDI token sequences, yielding one event row per step.

    Each yielded value is an int64 ndarray of shape (batch_size, max_token_seq):
    the token row just generated for every batch element.  Generation stops
    after max_len rows or once every batch element has emitted EOS.

    Args:
        model: the MIDI transformer; provides .tokenizer, .forward,
            .forward_token and .sample_top_p_k.
        prompt: optional ndarray of prompt tokens — 2-D (rows, tokens) or
            3-D (batch, rows, tokens); broadcast across the batch when its
            leading dim is 1.  None starts from a bare BOS row.
        batch_size: number of sequences generated in parallel.
        temp / top_p / top_k: sampling hyperparameters.
        disable_patch_change / disable_control_change: mask those event types
            out of the first-token choice.
        disable_channels: iterable of channel numbers to forbid.
        generator: optional torch.Generator for reproducible sampling.

    Raises:
        ValueError: if prompt has an incompatible shape.
    """
    tokenizer = model.tokenizer
    # Translate channel numbers into their token ids once, up front.
    if disable_channels is not None:
        disable_channels = [tokenizer.parameter_ids["channel"][c] for c in disable_channels]
    else:
        disable_channels = []
    max_token_seq = tokenizer.max_token_seq
    if prompt is None:
        # Start from a single BOS-padded row, replicated across the batch.
        input_tensor = torch.full((1, max_token_seq), tokenizer.pad_id, dtype=torch.long, device=model.device)
        input_tensor[0, 0] = tokenizer.bos_id  # bos
        input_tensor = input_tensor.unsqueeze(0)
        input_tensor = torch.cat([input_tensor] * batch_size, dim=0)
    else:
        # Normalise prompt to shape (batch_size, rows, tokens).
        if len(prompt.shape) == 2:
            prompt = prompt[None, :]
            prompt = np.repeat(prompt, repeats=batch_size, axis=0)
        elif prompt.shape[0] == 1:
            prompt = np.repeat(prompt, repeats=batch_size, axis=0)
        elif len(prompt.shape) != 3 or prompt.shape[0] != batch_size:
            raise ValueError(f"invalid shape for prompt, {prompt.shape}")
        # Clip or pad each row to exactly max_token_seq tokens.
        prompt = prompt[..., :max_token_seq]
        if prompt.shape[-1] < max_token_seq:
            prompt = np.pad(prompt, ((0, 0), (0, 0), (0, max_token_seq - prompt.shape[-1])),
                            mode="constant", constant_values=tokenizer.pad_id)
        input_tensor = torch.from_numpy(prompt).to(dtype=torch.long, device=model.device)
    cur_len = input_tensor.shape[1]
    bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
    cache1 = DynamicCache()  # KV cache for the outer (row-level) model
    past_len = 0             # rows already fed to the cached outer model
    with bar:
        while cur_len < max_len:
            end = [False] * batch_size
            # Only feed the rows not yet seen by cache1; take the last hidden state.
            hidden = model.forward(input_tensor[:, past_len:], cache=cache1)[:, -1]
            next_token_seq = None
            event_names = [""] * batch_size
            cache2 = DynamicCache()  # fresh KV cache for the inner token-level decoder
            for i in range(max_token_seq):
                # Build a 0/1 vocabulary mask constraining what each batch
                # element may sample at intra-row position i.
                mask = torch.zeros((batch_size, tokenizer.vocab_size), dtype=torch.int64, device=model.device)
                for b in range(batch_size):
                    if end[b]:
                        # Finished sequences can only emit padding.
                        mask[b, tokenizer.pad_id] = 1
                        continue
                    if i == 0:
                        # First token of a row picks the event type (or EOS).
                        mask_ids = list(tokenizer.event_ids.values()) + [tokenizer.eos_id]
                        if disable_patch_change:
                            mask_ids.remove(tokenizer.event_ids["patch_change"])
                        if disable_control_change:
                            mask_ids.remove(tokenizer.event_ids["control_change"])
                        mask[b, mask_ids] = 1
                    else:
                        # Subsequent tokens fill in the event's parameters.
                        param_names = tokenizer.events[event_names[b]]
                        if i > len(param_names):
                            mask[b, tokenizer.pad_id] = 1
                            continue
                        param_name = param_names[i - 1]
                        mask_ids = tokenizer.parameter_ids[param_name]
                        if param_name == "channel":
                            mask_ids = [i for i in mask_ids if i not in disable_channels]
                        mask[b, mask_ids] = 1
                mask = mask.unsqueeze(1)
                x = next_token_seq
                if i != 0:
                    # After the first token only the newest token is fed;
                    # cache2 carries the earlier intra-row context.
                    hidden = None
                    x = x[:, -1:]
                logits = model.forward_token(hidden, x, cache=cache2)[:, -1:]
                # Masked softmax then nucleus/top-k sampling.
                scores = torch.softmax(logits / temp, dim=-1) * mask
                samples = model.sample_top_p_k(scores, top_p, top_k, generator=generator)
                if i == 0:
                    next_token_seq = samples
                    for b in range(batch_size):
                        if end[b]:
                            continue
                        eid = samples[b].item()
                        if eid == tokenizer.eos_id:
                            end[b] = True
                        else:
                            event_names[b] = tokenizer.id_events[eid]
                else:
                    next_token_seq = torch.cat([next_token_seq, samples], dim=1)
                    # Stop early once every live sequence has all its parameters.
                    if all([len(tokenizer.events[event_names[b]]) == i for b in range(batch_size) if not end[b]]):
                        break
            # Pad the freshly generated row out to the fixed row width.
            if next_token_seq.shape[1] < max_token_seq:
                next_token_seq = F.pad(next_token_seq, (0, max_token_seq - next_token_seq.shape[1]),
                                       "constant", value=tokenizer.pad_id)
            next_token_seq = next_token_seq.unsqueeze(1)
            input_tensor = torch.cat([input_tensor, next_token_seq], dim=1)
            past_len = cur_len
            cur_len += 1
            bar.update(1)
            yield next_token_seq[:, 0].cpu().numpy()
            if all(end):
                break
120
+
121
+
122
def create_msg(name, data):
    """Wrap a message name and payload in the dict shape the JS frontend expects."""
    return dict(name=name, data=data)
124
+
125
+
126
def send_msgs(msgs):
    """Serialize a list of frontend messages (see create_msg) to a JSON string."""
    return json.dumps(msgs)
128
+
129
+
130
def get_duration(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm,
                 time_sig, key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr,
                 remove_empty_channels, seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
    """Estimate the GPU-seconds budget for a run.

    Only model_name and gen_events matter; the remaining parameters exist so
    the signature mirrors run() (the @spaces.GPU duration hook receives the
    same arguments).  Large models are assumed to generate ~14 events/s,
    others ~23 events/s, plus a 5-second safety margin.
    """
    events_per_second = 14 if "large" in model_name else 23
    return gen_events // events_per_second + 5
137
+
138
+
139
@spaces.GPU(duration=get_duration)
def run(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm, time_sig,
        key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels,
        seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
    """Top-level generation entry point driven by the Gradio UI.

    Builds a prompt according to the active tab (0 = from scratch with
    instrument/tempo/key settings, 1 = from an uploaded MIDI file,
    2 = continuation of the previous result), then streams tokens from
    generate(), yielding (mid_seq, continuation_state, seed, msgs_json)
    tuples so the frontend can update progressively.

    Relies on module globals: models, opt, OUTPUT_BATCH_SIZE, MAX_SEED,
    patch2number, drum_kits2number.
    """
    model = models[model_name]
    model.to(device=opt.device)
    tokenizer = model.tokenizer
    bpm = int(bpm)
    # Parse the "N/D" time signature; denominator is stored as log2 (2->1, 4->2, 8->3).
    if time_sig == "auto":
        time_sig = None
        time_sig_nn = 4
        time_sig_dd = 2
    else:
        time_sig_nn, time_sig_dd = time_sig.split('/')
        time_sig_nn = int(time_sig_nn)
        time_sig_dd = {2: 1, 4: 2, 8: 3}[int(time_sig_dd)]
    # Decode the key-signature dropdown index into (sharps/flats, major/minor).
    if key_sig == 0:
        key_sig = None
        key_sig_sf = 0
        key_sig_mi = 0
    else:
        key_sig = (key_sig - 1)
        key_sig_sf = key_sig // 2 - 7  # -7 flats .. +7 sharps
        key_sig_mi = key_sig % 2       # 0 = major, 1 = minor
    gen_events = int(gen_events)
    max_len = gen_events
    if seed_rand:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(opt.device).manual_seed(seed)
    disable_patch_change = False
    disable_channels = None
    if tab == 0:
        # Tab 0: build a fresh prompt from the UI settings.
        i = 0
        mid = [[tokenizer.bos_id] + [tokenizer.pad_id] * (tokenizer.max_token_seq - 1)]
        if tokenizer.version == "v2":
            # time/key signature events exist only in the v2 tokenizer.
            if time_sig is not None:
                mid.append(tokenizer.event2tokens(["time_signature", 0, 0, 0, time_sig_nn - 1, time_sig_dd - 1]))
            if key_sig is not None:
                mid.append(tokenizer.event2tokens(["key_signature", 0, 0, 0, key_sig_sf + 7, key_sig_mi]))
        if bpm != 0:
            mid.append(tokenizer.event2tokens(["set_tempo", 0, 0, 0, bpm]))
        patches = {}
        if instruments is None:
            instruments = []
        for instr in instruments:
            patches[i] = patch2number[instr]
            # Skip channel 9, which is reserved for drums.
            i = (i + 1) if i != 8 else 10
        if drum_kit != "None":
            patches[9] = drum_kits2number[drum_kit]
        for i, (c, p) in enumerate(patches.items()):
            mid.append(tokenizer.event2tokens(["patch_change", 0, 0, i + 1, c, p]))
        mid = np.asarray([mid] * OUTPUT_BATCH_SIZE, dtype=np.int64)
        mid_seq = mid.tolist()
        if len(instruments) > 0:
            # Lock the chosen instrumentation: no further patch changes,
            # and only the configured channels may be used.
            disable_patch_change = True
            disable_channels = [i for i in range(16) if i not in patches]
    elif tab == 1 and mid is not None:
        # Tab 1: tokenize an uploaded MIDI file as the prompt.
        eps = 4 if reduce_cc_st else 0
        mid = tokenizer.tokenize(MIDI.midi2score(mid), cc_eps=eps, tempo_eps=eps,
                                 remap_track_channel=remap_track_channel,
                                 add_default_instr=add_default_instr,
                                 remove_empty_channels=remove_empty_channels)
        mid = mid[:int(midi_events)]
        mid = np.asarray([mid] * OUTPUT_BATCH_SIZE, dtype=np.int64)
        mid_seq = mid.tolist()
    elif tab == 2 and mid_seq is not None:
        # Tab 2: continue a previous result.
        mid = np.asarray(mid_seq, dtype=np.int64)
        if continuation_select > 0:
            # Continue from one selected batch element; remember the full
            # previous state so undo_continuation can restore it.
            continuation_state.append(mid_seq)
            mid = np.repeat(mid[continuation_select - 1:continuation_select], repeats=OUTPUT_BATCH_SIZE, axis=0)
            mid_seq = mid.tolist()
        else:
            # Continue all; remember only the current length for undo.
            continuation_state.append(mid.shape[1])
    else:
        # Fallback: empty prompt, reset continuation history.
        continuation_state = [0]
        mid = [[tokenizer.bos_id] + [tokenizer.pad_id] * (tokenizer.max_token_seq - 1)]
        mid = np.asarray([mid] * OUTPUT_BATCH_SIZE, dtype=np.int64)
        mid_seq = mid.tolist()

    if mid is not None:
        max_len += mid.shape[1]

    init_msgs = [create_msg("progress", [0, gen_events])]
    if not (tab == 2 and continuation_select == 0):
        # Redraw the visualizers with the prompt before generation starts.
        for i in range(OUTPUT_BATCH_SIZE):
            events = [tokenizer.tokens2event(tokens) for tokens in mid_seq[i]]
            init_msgs += [create_msg("visualizer_clear", [i, tokenizer.version]),
                          create_msg("visualizer_append", [i, events])]
    yield mid_seq, continuation_state, seed, send_msgs(init_msgs)
    midi_generator = generate(model, mid, batch_size=OUTPUT_BATCH_SIZE, max_len=max_len, temp=temp,
                              top_p=top_p, top_k=top_k, disable_patch_change=disable_patch_change,
                              disable_control_change=not allow_cc, disable_channels=disable_channels,
                              generator=generator)
    events = [list() for i in range(OUTPUT_BATCH_SIZE)]
    t = time.time() + 1
    for i, token_seqs in enumerate(midi_generator):
        token_seqs = token_seqs.tolist()
        for j in range(OUTPUT_BATCH_SIZE):
            token_seq = token_seqs[j]
            mid_seq[j].append(token_seq)
            events[j].append(tokenizer.tokens2event(token_seq))
        # Throttle UI updates to at most ~2/s.
        if time.time() - t > 0.5:
            msgs = [create_msg("progress", [i + 1, gen_events])]
            for j in range(OUTPUT_BATCH_SIZE):
                msgs += [create_msg("visualizer_append", [j, events[j]])]
                events[j] = list()
            yield mid_seq, continuation_state, seed, send_msgs(msgs)
            t = time.time()
    yield mid_seq, continuation_state, seed, send_msgs([])
248
+
249
+
250
def finish_run(model_name, mid_seq):
    """Detokenize the finished batch, write each sequence to outputs/outputN.mid
    and return the file paths plus the final visualizer messages."""
    if mid_seq is None:
        placeholders = [None] * OUTPUT_BATCH_SIZE
        return *placeholders, []
    tokenizer = models[model_name].tokenizer
    if not os.path.exists("outputs"):
        os.mkdir("outputs")
    file_paths = []
    end_msgs = [create_msg("progress", [0, 0])]
    for idx in range(OUTPUT_BATCH_SIZE):
        shown_events = [tokenizer.tokens2event(tokens) for tokens in mid_seq[idx]]
        score = tokenizer.detokenize(mid_seq[idx])
        path = f"outputs/output{idx + 1}.mid"
        with open(path, 'wb') as midi_file:
            midi_file.write(MIDI.score2midi(score))
        file_paths.append(path)
        end_msgs += [create_msg("visualizer_clear", [idx, tokenizer.version]),
                     create_msg("visualizer_append", [idx, shown_events]),
                     create_msg("visualizer_end", idx)]
    return *file_paths, send_msgs(end_msgs)
269
+
270
+
271
def synthesis_task(mid):
    # Render one detokenized score to audio via the shared synthesizer.
    # Submitted to the thread pool by render_audio().
    return synthesizer.synthesis(MIDI.score2opus(mid))
273
+
274
def render_audio(model_name, mid_seq, should_render_audio):
    """Synthesize each generated sequence to audio in parallel.

    Returns one (sample_rate, waveform) pair per batch element — unwrapped
    to a single value when OUTPUT_BATCH_SIZE == 1 — or None placeholders
    when rendering is disabled or nothing has been generated yet.
    """
    if (not should_render_audio) or mid_seq is None:
        outputs = [None] * OUTPUT_BATCH_SIZE
        # Fix: keep the return arity consistent with the rendering path
        # below — a single value, not a 1-tuple, when the batch size is 1.
        if OUTPUT_BATCH_SIZE == 1:
            return outputs[0]
        return tuple(outputs)
    tokenizer = models[model_name].tokenizer
    outputs = []
    if not os.path.exists("outputs"):
        os.mkdir("outputs")
    # Fan the synthesis jobs out to the thread pool, then collect in order.
    audio_futures = []
    for i in range(OUTPUT_BATCH_SIZE):
        mid = tokenizer.detokenize(mid_seq[i])
        audio_future = thread_pool.submit(synthesis_task, mid)
        audio_futures.append(audio_future)
    for future in audio_futures:
        outputs.append((44100, future.result()))  # 44100 Hz sample rate
    if OUTPUT_BATCH_SIZE == 1:
        return outputs[0]
    return tuple(outputs)
292
+
293
+
294
def undo_continuation(model_name, mid_seq, continuation_state):
    """Revert the most recent continuation step and redraw the visualizers."""
    if mid_seq is None or len(continuation_state) < 2:
        return mid_seq, continuation_state, send_msgs([])
    tokenizer = models[model_name].tokenizer
    checkpoint = continuation_state[-1]
    # A list checkpoint is a full saved sequence set; an int checkpoint is
    # the sequence length to truncate back to.
    if isinstance(checkpoint, list):
        mid_seq = checkpoint
    else:
        mid_seq = [seq[:checkpoint] for seq in mid_seq]
    continuation_state = continuation_state[:-1]
    redraw_msgs = [create_msg("progress", [0, 0])]
    for idx in range(OUTPUT_BATCH_SIZE):
        shown_events = [tokenizer.tokens2event(tokens) for tokens in mid_seq[idx]]
        redraw_msgs += [create_msg("visualizer_clear", [idx, tokenizer.version]),
                        create_msg("visualizer_append", [idx, shown_events]),
                        create_msg("visualizer_end", idx)]
    return mid_seq, continuation_state, send_msgs(redraw_msgs)
310
+
311
+
312
def load_javascript(dir="javascript"):
    """Inject every .js file under *dir* into the pages Gradio serves.

    Monkey-patches gr.routes.templates.TemplateResponse so each rendered
    page gets the collected <script> tags spliced in just before </head>.
    The JS constant MIDI_OUTPUT_BATCH_SIZE is rewritten to match the
    server-side OUTPUT_BATCH_SIZE.
    """
    scripts_list = glob.glob(f"{dir}/*.js")
    javascript = ""
    for path in scripts_list:
        with open(path, "r", encoding="utf8") as jsfile:
            js_content = jsfile.read()
            # Keep the frontend batch size in sync with the server.
            js_content = js_content.replace("const MIDI_OUTPUT_BATCH_SIZE=4;",
                                            f"const MIDI_OUTPUT_BATCH_SIZE={OUTPUT_BATCH_SIZE};")
            javascript += f"\n<!-- {path} --><script>{js_content}</script>"
    template_response_ori = gr.routes.templates.TemplateResponse

    def template_response(*args, **kwargs):
        # Render the page normally, then splice our scripts into <head>.
        res = template_response_ori(*args, **kwargs)
        res.body = res.body.replace(
            b'</head>', f'{javascript}</head>'.encode("utf8"))
        res.init_headers()  # body changed, so headers must be recomputed
        return res

    gr.routes.templates.TemplateResponse = template_response
331
+
332
+
333
def hf_hub_download_retry(repo_id, filename):
    """Download a file from the Hugging Face Hub, retrying up to 30 times.

    Returns the local path reported by hf_hub_download; if every attempt
    fails, the last exception is re-raised.
    """
    # Fix: log the actual filename (the message previously printed the
    # literal placeholder "(unknown)").
    print(f"downloading {repo_id} {filename}")
    retry = 0
    err = None
    while retry < 30:
        try:
            return hf_hub_download(repo_id=repo_id, filename=filename)
        except Exception as e:
            err = e
            retry += 1
    if err:
        raise err
345
+
346
+
347
# Drum-kit program numbers -> human-readable names (-1 means "no drums").
number2drum_kits = {-1: "None", 0: "Standard", 8: "Room", 16: "Power", 24: "Electric", 25: "TR-808", 32: "Jazz",
                    40: "Blush", 48: "Orchestra"}
# Inverse lookups: instrument / drum-kit name -> MIDI program number.
patch2number = {v: k for k, v in MIDI.Number2patch.items()}
drum_kits2number = {v: k for k, v in number2drum_kits.items()}
# Key-signature dropdown choices; run() maps the selected index to the MIDI
# sf (flats/sharps) and mi (major/minor) values.
key_signatures = ['C♭', 'A♭m', 'G♭', 'E♭m', 'D♭', 'B♭m', 'A♭', 'Fm', 'E♭', 'Cm', 'B♭', 'Gm', 'F', 'Dm',
                  'C', 'Am', 'G', 'Em', 'D', 'Bm', 'A', 'F♯m', 'E', 'C♯m', 'B', 'G♯m', 'F♯', 'D♯m', 'C♯', 'A♯m']
353
+
354
if __name__ == "__main__":
    # ---- CLI options -------------------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
    parser.add_argument("--port", type=int, default=7860, help="gradio server port")
    parser.add_argument("--device", type=str, default="cuda", help="device to run model")
    parser.add_argument("--batch", type=int, default=8, help="batch size")
    parser.add_argument("--max-gen", type=int, default=1024, help="max")
    opt = parser.parse_args()
    OUTPUT_BATCH_SIZE = opt.batch

    # ---- shared resources: soundfont, synth, worker pool -------------------
    soundfont_path = hf_hub_download_retry(repo_id="skytnt/midi-model", filename="soundfont.sf2")
    thread_pool = ThreadPoolExecutor(max_workers=OUTPUT_BATCH_SIZE)
    synthesizer = MidiSynthesizer(soundfont_path)

    # Display name -> [hub repo id, {lora display name -> lora repo id}]
    models_info = {
        "generic pretrain model (tv2o-medium) by skytnt": [
            "skytnt/midi-model-tv2o-medium", {
                "jpop": "skytnt/midi-model-tv2om-jpop-lora",
                "touhou": "skytnt/midi-model-tv2om-touhou-lora"
            }
        ],
        "generic pretrain model (tv2o-large) by asigalov61": [
            "asigalov61/Music-Llama", {}
        ],
        "generic pretrain model (tv2o-medium) by asigalov61": [
            "asigalov61/Music-Llama-Medium", {}
        ],
        "generic pretrain model (tv1-medium) by skytnt": [
            "skytnt/midi-model", {}
        ]
    }
    models = {}
    if opt.device == "cuda":
        # deterministic kernels + TF32/SDP fast paths for inference
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cuda.enable_mem_efficient_sdp(True)
        torch.backends.cuda.enable_flash_sdp(True)

    # Load each base model, then a fresh copy per LoRA with weights merged in.
    # Models are kept on CPU in fp32; presumably moved to the target device
    # at generation time — TODO confirm against run()'s implementation.
    for name, (repo_id, loras) in models_info.items():
        model = MIDIModel.from_pretrained(repo_id)
        model.to(device="cpu", dtype=torch.float32)
        models[name] = model
        for lora_name, lora_repo in loras.items():
            model = MIDIModel.from_pretrained(repo_id)
            print(f"loading lora {lora_repo} for {name}")
            model = model.load_merge_lora(lora_repo)
            model.to(device="cpu", dtype=torch.float32)
            models[f"{name} with {lora_name} lora"] = model

    # ---- gradio UI ---------------------------------------------------------
    load_javascript()
    app = gr.Blocks(theme=gr.themes.Soft())
    with app:
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Midi Composer</h1>")
        gr.Markdown("![Visitors](https://api.visitorbadge.io/api/visitors?path=skytnt.midi-composer&style=flat)\n\n"
                    "Midi event transformer for symbolic music generation\n\n"
                    "Demo for [SkyTNT/midi-model](https://github.com/SkyTNT/midi-model)\n\n"
                    "[Open In Colab]"
                    "(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
                    " or [download windows app](https://github.com/SkyTNT/midi-model/releases)"
                    " for unlimited generation\n\n"
                    "**Update v1.3**: MIDITokenizerV2 and new MidiVisualizer\n\n"
                    "The current **best** model: generic pretrain model (tv2o-medium) by skytnt"
                    )
        # Hidden textbox: server pushes JSON message batches here; the JS
        # callback fans them out to the visualizer scripts.
        js_msg = gr.Textbox(elem_id="msg_receiver", visible=False)
        js_msg.change(None, [js_msg], [], js="""
            (msg_json) =>{
                let msgs = JSON.parse(msg_json);
                executeCallbacks(msgReceiveCallbacks, msgs);
                return [];
            }
            """)
        input_model = gr.Dropdown(label="select model", choices=list(models.keys()),
                                  type="value", value=list(models.keys())[0])
        # Which prompt tab is active: 0=custom, 1=midi file, 2=last output.
        tab_select = gr.State(value=0)
        with gr.Tabs():
            with gr.TabItem("custom prompt") as tab1:
                input_instruments = gr.Dropdown(label="🪗instruments (auto if empty)", choices=list(patch2number.keys()),
                                                multiselect=True, max_choices=15, type="value")
                input_drum_kit = gr.Dropdown(label="🥁drum kit", choices=list(drum_kits2number.keys()), type="value",
                                             value="None")
                input_bpm = gr.Slider(label="BPM (beats per minute, auto if 0)", minimum=0, maximum=255,
                                      step=1,
                                      value=0)
                input_time_sig = gr.Radio(label="time signature (only for tv2 models)",
                                          value="auto",
                                          choices=["auto", "4/4", "2/4", "3/4", "6/4", "7/4",
                                                   "2/2", "3/2", "4/2", "3/8", "5/8", "6/8", "7/8", "9/8", "12/8"]
                                          )
                input_key_sig = gr.Radio(label="key signature (only for tv2 models)",
                                         value="auto",
                                         choices=["auto"] + key_signatures,
                                         type="index"
                                         )
                example1 = gr.Examples([
                    [[], "None"],
                    [["Acoustic Grand"], "None"],
                    [['Acoustic Grand', 'SynthStrings 2', 'SynthStrings 1', 'Pizzicato Strings',
                      'Pad 2 (warm)', 'Tremolo Strings', 'String Ensemble 1'], "Orchestra"],
                    [['Trumpet', 'Oboe', 'Trombone', 'String Ensemble 1', 'Clarinet',
                      'French Horn', 'Pad 4 (choir)', 'Bassoon', 'Flute'], "None"],
                    [['Flute', 'French Horn', 'Clarinet', 'String Ensemble 2', 'English Horn', 'Bassoon',
                      'Oboe', 'Pizzicato Strings'], "Orchestra"],
                    [['Electric Piano 2', 'Lead 5 (charang)', 'Electric Bass(pick)', 'Lead 2 (sawtooth)',
                      'Pad 1 (new age)', 'Orchestra Hit', 'Cello', 'Electric Guitar(clean)'], "Standard"],
                    [["Electric Guitar(clean)", "Electric Guitar(muted)", "Overdriven Guitar", "Distortion Guitar",
                      "Electric Bass(finger)"], "Standard"]
                ], [input_instruments, input_drum_kit])
            with gr.TabItem("midi prompt") as tab2:
                input_midi = gr.File(label="input midi", file_types=[".midi", ".mid"], type="binary")
                input_midi_events = gr.Slider(label="use first n midi events as prompt", minimum=1, maximum=512,
                                              step=1,
                                              value=128)
                input_reduce_cc_st = gr.Checkbox(label="reduce control_change and set_tempo events", value=True)
                input_remap_track_channel = gr.Checkbox(
                    label="remap tracks and channels so each track has only one channel and in order", value=True)
                input_add_default_instr = gr.Checkbox(
                    label="add a default instrument to channels that don't have an instrument", value=True)
                input_remove_empty_channels = gr.Checkbox(label="remove channels without notes", value=False)
                example2 = gr.Examples([[file, 128] for file in glob.glob("example/*.mid")],
                                       [input_midi, input_midi_events])
            with gr.TabItem("last output prompt") as tab3:
                gr.Markdown("Continue generating on the last output.")
                input_continuation_select = gr.Radio(label="select output to continue generating", value="all",
                                                     choices=["all"] + [f"output{i + 1}" for i in
                                                                        range(OUTPUT_BATCH_SIZE)],
                                                     type="index"
                                                     )
                undo_btn = gr.Button("undo the last continuation")

        tab1.select(lambda: 0, None, tab_select, queue=False)
        tab2.select(lambda: 1, None, tab_select, queue=False)
        tab3.select(lambda: 2, None, tab_select, queue=False)
        input_seed = gr.Slider(label="seed", minimum=0, maximum=2 ** 31 - 1,
                               step=1, value=0)
        input_seed_rand = gr.Checkbox(label="random seed", value=True)
        input_gen_events = gr.Slider(label="generate max n midi events", minimum=1, maximum=opt.max_gen,
                                     step=1, value=opt.max_gen // 2)
        with gr.Accordion("options", open=False):
            input_temp = gr.Slider(label="temperature", minimum=0.1, maximum=1.2, step=0.01, value=1)
            input_top_p = gr.Slider(label="top p", minimum=0.1, maximum=1, step=0.01, value=0.95)
            input_top_k = gr.Slider(label="top k", minimum=1, maximum=128, step=1, value=20)
            input_allow_cc = gr.Checkbox(label="allow midi cc event", value=True)
            input_render_audio = gr.Checkbox(label="render audio after generation", value=True)
            example3 = gr.Examples([[1, 0.94, 128], [1, 0.98, 20], [1, 0.98, 12]],
                                   [input_temp, input_top_p, input_top_k])
        run_btn = gr.Button("generate", variant="primary")
        # stop_btn = gr.Button("stop and output")
        output_midi_seq = gr.State()
        output_continuation_state = gr.State([0])
        midi_outputs = []
        audio_outputs = []
        # One visualizer/audio/file slot per batch element.
        with gr.Tabs(elem_id="output_tabs"):
            for i in range(OUTPUT_BATCH_SIZE):
                with gr.TabItem(f"output {i + 1}") as tab1:
                    output_midi_visualizer = gr.HTML(elem_id=f"midi_visualizer_container_{i}")
                    output_audio = gr.Audio(label="output audio", format="mp3", elem_id=f"midi_audio_{i}")
                    output_midi = gr.File(label="output midi", file_types=[".mid"])
                    midi_outputs.append(output_midi)
                    audio_outputs.append(output_audio)
        # Pipeline: run (streams tokens) -> finish_run (writes .mid files)
        # -> render_audio (synthesizes waveforms).
        run_event = run_btn.click(run, [input_model, tab_select, output_midi_seq, output_continuation_state,
                                        input_continuation_select, input_instruments, input_drum_kit, input_bpm,
                                        input_time_sig, input_key_sig, input_midi, input_midi_events,
                                        input_reduce_cc_st, input_remap_track_channel,
                                        input_add_default_instr, input_remove_empty_channels,
                                        input_seed, input_seed_rand, input_gen_events, input_temp, input_top_p,
                                        input_top_k, input_allow_cc],
                                  [output_midi_seq, output_continuation_state, input_seed, js_msg],
                                  concurrency_limit=10, queue=True)
        finish_run_event = run_event.then(fn=finish_run,
                                          inputs=[input_model, output_midi_seq],
                                          outputs=midi_outputs + [js_msg],
                                          queue=False)
        finish_run_event.then(fn=render_audio,
                              inputs=[input_model, output_midi_seq, input_render_audio],
                              outputs=audio_outputs,
                              queue=False)
        # stop_btn.click(None, [], [], cancels=run_event,
        #                queue=False)
        undo_btn.click(undo_continuation, [input_model, output_midi_seq, output_continuation_state],
                       [output_midi_seq, output_continuation_state, js_msg], queue=False)
    app.queue().launch(server_port=opt.port, share=opt.share, inbrowser=True, ssr_mode=False)
    thread_pool.shutdown()
app_onnx.py ADDED
@@ -0,0 +1,626 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import random
3
+ import argparse
4
+ import glob
5
+ import json
6
+ import os
7
+ import time
8
+ from concurrent.futures import ThreadPoolExecutor
9
+
10
+ import gradio as gr
11
+ import numpy as np
12
+ import onnxruntime as rt
13
+ import tqdm
14
+ from huggingface_hub import hf_hub_download
15
+
16
+ import MIDI
17
+ from midi_synthesizer import MidiSynthesizer
18
+ from midi_tokenizer import MIDITokenizer
19
+
20
MAX_SEED = np.iinfo(np.int32).max  # upper bound for user-supplied RNG seeds
in_space = os.getenv("SYSTEM") == "spaces"  # True when running on HF Spaces
22
+
23
+
24
def softmax(x, axis):
    """Numerically stable softmax of `x` along `axis`."""
    # Shift by the max so exp() never overflows; result is unchanged.
    shifted = x - np.amax(x, axis=axis, keepdims=True)
    exps = np.exp(shifted)
    return exps / exps.sum(axis=axis, keepdims=True)
28
+
29
+
30
def sample_top_p_k(probs, p, k, generator=None):
    """Sample token indices from `probs` (last axis) with nucleus (top-p)
    and top-k filtering.

    Probabilities outside the nucleus or beyond rank k are zeroed, the rest
    renormalized, and one index is drawn per leading-axis row using
    `generator` (defaults to the global numpy RNG).
    """
    rng = np.random if generator is None else generator
    # Sort descending and remember the original indices.
    order = np.argsort(-probs, axis=-1)
    ranked = np.take_along_axis(probs, order, -1)
    # Nucleus mask: drop entries whose preceding cumulative mass exceeds p.
    cumulative = np.cumsum(ranked, axis=-1)
    ranked[cumulative - ranked > p] = 0.0
    # Top-k mask: keep only the first k ranked entries.
    keep = np.zeros(ranked.shape[-1])
    keep[:k] = 1
    ranked = ranked * keep
    ranked /= np.sum(ranked, axis=-1, keepdims=True)
    # Draw one original-vocabulary index per row.
    width = ranked.shape[-1]
    flat_probs = ranked.reshape(-1, width)
    flat_idx = order.reshape(-1, width)
    drawn = np.stack([rng.choice(ids, p=pv) for pv, ids in zip(flat_probs, flat_idx)])
    return drawn.reshape(*ranked.shape[:-1])
48
+
49
+
50
def apply_io_binding(model: rt.InferenceSession, inputs, outputs, batch_size, past_len, cur_len):
    """Bind input/output OrtValues for one ONNX Runtime inference step.

    KV-cache plumbing: each model input named ``past_key_values*`` is fed the
    matching ``present*`` output left over from the previous step when one
    exists, otherwise a freshly allocated buffer of sequence length
    `past_len`.  Each ``present*`` output gets a new buffer of length
    `cur_len`.  `inputs` and `outputs` are mutated in place so the caller can
    reuse the dicts across steps.  Buffer placement uses the module-level
    `device`.
    """
    io_binding = model.io_binding()
    for input_ in model.get_inputs():
        name = input_.name
        if name.startswith("past_key_values"):
            present_name = name.replace("past_key_values", "present")
            if present_name in outputs:
                # Reuse last step's KV cache as this step's "past".
                v = outputs[present_name]
            else:
                v = rt.OrtValue.ortvalue_from_shape_and_type(
                    (batch_size, input_.shape[1], past_len, input_.shape[3]),
                    element_type=np.float32,
                    device_type=device)
            inputs[name] = v
        else:
            v = inputs[name]
        io_binding.bind_ortvalue_input(name, v)

    for output in model.get_outputs():
        name = output.name
        if name.startswith("present"):
            # Allocate the KV-cache output for this step.
            v = rt.OrtValue.ortvalue_from_shape_and_type(
                (batch_size, output.shape[1], cur_len, output.shape[3]),
                element_type=np.float32,
                device_type=device)
            outputs[name] = v
        else:
            v = outputs[name]
        io_binding.bind_ortvalue_output(name, v)
    return io_binding
80
+
81
def generate(model, prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98, top_k=20,
             disable_patch_change=False, disable_control_change=False, disable_channels=None, generator=None):
    """Autoregressively generate MIDI event token rows.

    `model` is a triple [base ONNX session, token-head ONNX session,
    tokenizer].  Yields one ``(batch_size, max_token_seq)`` int64 array per
    generated event until `max_len` events or all sequences emit EOS.

    `prompt` may be None (BOS-only start), a 2-D ``(events, tokens)`` array
    broadcast to the batch, or a full 3-D ``(batch, events, tokens)`` array.
    Channel/patch/CC restrictions are applied through per-step logit masks.
    """
    tokenizer = model[2]
    if disable_channels is not None:
        # Translate channel numbers to their token ids once, up front.
        disable_channels = [tokenizer.parameter_ids["channel"][c] for c in disable_channels]
    else:
        disable_channels = []
    if generator is None:
        generator = np.random
    max_token_seq = tokenizer.max_token_seq
    if prompt is None:
        # Start every sequence from a single BOS event row.
        input_tensor = np.full((1, max_token_seq), tokenizer.pad_id, dtype=np.int64)
        input_tensor[0, 0] = tokenizer.bos_id  # bos
        input_tensor = input_tensor[None, :, :]
        input_tensor = np.repeat(input_tensor, repeats=batch_size, axis=0)
    else:
        # Normalize the prompt to shape (batch_size, events, max_token_seq).
        if len(prompt.shape) == 2:
            prompt = prompt[None, :]
            prompt = np.repeat(prompt, repeats=batch_size, axis=0)
        elif prompt.shape[0] == 1:
            prompt = np.repeat(prompt, repeats=batch_size, axis=0)
        elif len(prompt.shape) != 3 or prompt.shape[0] != batch_size:
            raise ValueError(f"invalid shape for prompt, {prompt.shape}")
        prompt = prompt[..., :max_token_seq]
        if prompt.shape[-1] < max_token_seq:
            prompt = np.pad(prompt, ((0, 0), (0, 0), (0, max_token_seq - prompt.shape[-1])),
                            mode="constant", constant_values=tokenizer.pad_id)
        input_tensor = prompt
    cur_len = input_tensor.shape[1]
    bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
    model0_inputs = {}
    model0_outputs = {}
    emb_size = 1024  # fallback; overwritten from the model's "hidden" output below
    for output in model[0].get_outputs():
        if output.name == "hidden":
            emb_size = output.shape[2]
    past_len = 0
    with bar:
        while cur_len < max_len:
            end = [False] * batch_size
            # Run the base model on the yet-unseen event rows only (KV cache
            # holds everything before past_len).
            model0_inputs["x"] = rt.OrtValue.ortvalue_from_numpy(input_tensor[:, past_len:], device_type=device)
            model0_outputs["hidden"] = rt.OrtValue.ortvalue_from_shape_and_type(
                (batch_size, cur_len - past_len, emb_size),
                element_type=np.float32,
                device_type=device)
            io_binding = apply_io_binding(model[0], model0_inputs, model0_outputs, batch_size, past_len, cur_len)
            io_binding.synchronize_inputs()
            model[0].run_with_iobinding(io_binding)
            io_binding.synchronize_outputs()

            # Decode one event row token-by-token with the token-head model.
            hidden = model0_outputs["hidden"].numpy()[:, -1:]
            next_token_seq = np.zeros((batch_size, 0), dtype=np.int64)
            event_names = [""] * batch_size
            model1_inputs = {"hidden": rt.OrtValue.ortvalue_from_numpy(hidden, device_type=device)}
            model1_outputs = {}
            for i in range(max_token_seq):
                # Build the per-sample legality mask for token position i.
                mask = np.zeros((batch_size, tokenizer.vocab_size), dtype=np.int64)
                for b in range(batch_size):
                    if end[b]:
                        mask[b, tokenizer.pad_id] = 1
                        continue
                    if i == 0:
                        # Position 0 chooses the event type (or EOS).
                        mask_ids = list(tokenizer.event_ids.values()) + [tokenizer.eos_id]
                        if disable_patch_change:
                            mask_ids.remove(tokenizer.event_ids["patch_change"])
                        if disable_control_change:
                            mask_ids.remove(tokenizer.event_ids["control_change"])
                        mask[b, mask_ids] = 1
                    else:
                        # Later positions choose that event's parameters.
                        param_names = tokenizer.events[event_names[b]]
                        if i > len(param_names):
                            mask[b, tokenizer.pad_id] = 1
                            continue
                        param_name = param_names[i - 1]
                        mask_ids = tokenizer.parameter_ids[param_name]
                        if param_name == "channel":
                            mask_ids = [i for i in mask_ids if i not in disable_channels]
                        mask[b, mask_ids] = 1
                mask = mask[:, None, :]
                x = next_token_seq
                if i != 0:
                    # cached: after the first token, feed only the newest token
                    # (and an empty hidden) — the token head's KV cache has the rest.
                    if i == 1:
                        hidden = np.zeros((batch_size, 0, emb_size), dtype=np.float32)
                        model1_inputs["hidden"] = rt.OrtValue.ortvalue_from_numpy(hidden, device_type=device)
                    x = x[:, -1:]
                model1_inputs["x"] = rt.OrtValue.ortvalue_from_numpy(x, device_type=device)
                model1_outputs["y"] = rt.OrtValue.ortvalue_from_shape_and_type(
                    (batch_size, 1, tokenizer.vocab_size),
                    element_type=np.float32,
                    device_type=device
                )
                io_binding = apply_io_binding(model[1], model1_inputs, model1_outputs, batch_size, i, i + 1)
                io_binding.synchronize_inputs()
                model[1].run_with_iobinding(io_binding)
                io_binding.synchronize_outputs()
                logits = model1_outputs["y"].numpy()
                # Mask illegal tokens, then nucleus/top-k sample.
                scores = softmax(logits / temp, -1) * mask
                samples = sample_top_p_k(scores, top_p, top_k, generator)
                if i == 0:
                    next_token_seq = samples
                    for b in range(batch_size):
                        if end[b]:
                            continue
                        eid = samples[b].item()
                        if eid == tokenizer.eos_id:
                            end[b] = True
                        else:
                            event_names[b] = tokenizer.id_events[eid]
                else:
                    next_token_seq = np.concatenate([next_token_seq, samples], axis=1)
                    # Stop early once every live sample has all its parameters.
                    if all([len(tokenizer.events[event_names[b]]) == i for b in range(batch_size) if not end[b]]):
                        break
            # Pad the event row to the fixed width and append it.
            if next_token_seq.shape[1] < max_token_seq:
                next_token_seq = np.pad(next_token_seq,
                                        ((0, 0), (0, max_token_seq - next_token_seq.shape[-1])),
                                        mode="constant", constant_values=tokenizer.pad_id)
            next_token_seq = next_token_seq[:, None, :]
            input_tensor = np.concatenate([input_tensor, next_token_seq], axis=1)
            past_len = cur_len
            cur_len += 1
            bar.update(1)
            yield next_token_seq[:, 0]
            if all(end):
                break
206
+
207
+
208
def create_msg(name, data):
    """Build one message dict for the front-end visualizer protocol."""
    return dict(name=name, data=data)
210
+
211
+
212
def send_msgs(msgs):
    """Serialize a message list to the JSON string the js_msg textbox expects."""
    return json.dumps(msgs)
214
+
215
+
216
def get_duration(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm,
                 time_sig, key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr,
                 remove_empty_channels, seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
    """Estimate the GPU-seconds budget for one `run` call (used by @spaces.GPU).

    Assumes ~30 events/s for medium models and ~23 events/s for large ones,
    plus a 5-second setup margin.  Only `model_name` and `gen_events` matter;
    the other parameters mirror `run`'s signature as spaces requires.
    """
    events_per_second = 23 if "large" in model_name else 30
    return gen_events // events_per_second + 5
223
+
224
+
225
@spaces.GPU(duration=get_duration)
def run(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm, time_sig,
        key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels,
        seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
    """Gradio handler: build a prompt from the active tab and stream generation.

    `tab` selects the prompt source: 0 = custom (instruments/bpm/signatures),
    1 = uploaded MIDI file, 2 = continue from the last output.  Yields
    ``(mid_seq, continuation_state, seed, js_msg_json)`` tuples so the UI can
    update the visualizers while tokens stream in.
    """
    # Fresh ONNX sessions per call; model paths come from the global registry.
    model = models[model_name]
    model_base = rt.InferenceSession(model[0], providers=providers)
    model_token = rt.InferenceSession(model[1], providers=providers)
    tokenizer = model[2]
    model = [model_base, model_token, tokenizer]
    bpm = int(bpm)
    # Translate the UI time signature into the tokenizer's (nn, dd) encoding;
    # dd is the log2 of the denominator per the SMF convention.
    if time_sig == "auto":
        time_sig = None
        time_sig_nn = 4
        time_sig_dd = 2
    else:
        time_sig_nn, time_sig_dd = time_sig.split('/')
        time_sig_nn = int(time_sig_nn)
        time_sig_dd = {2: 1, 4: 2, 8: 3}[int(time_sig_dd)]
    # key_sig arrives as a radio *index*: 0 = "auto"; otherwise map the
    # circle-of-fifths position to (sharps/flats count, major/minor flag).
    if key_sig == 0:
        key_sig = None
        key_sig_sf = 0
        key_sig_mi = 0
    else:
        key_sig = (key_sig - 1)
        key_sig_sf = key_sig // 2 - 7
        key_sig_mi = key_sig % 2
    gen_events = int(gen_events)
    max_len = gen_events
    if seed_rand:
        seed = random.randint(0, MAX_SEED)
    generator = np.random.RandomState(seed)
    disable_patch_change = False
    disable_channels = None
    if tab == 0:
        # Custom prompt: BOS + optional meta events + patch assignments.
        i = 0
        mid = [[tokenizer.bos_id] + [tokenizer.pad_id] * (tokenizer.max_token_seq - 1)]
        if tokenizer.version == "v2":
            if time_sig is not None:
                mid.append(tokenizer.event2tokens(["time_signature", 0, 0, 0, time_sig_nn - 1, time_sig_dd - 1]))
            if key_sig is not None:
                mid.append(tokenizer.event2tokens(["key_signature", 0, 0, 0, key_sig_sf + 7, key_sig_mi]))
        if bpm != 0:
            mid.append(tokenizer.event2tokens(["set_tempo", 0, 0, 0, bpm]))
        patches = {}
        if instruments is None:
            instruments = []
        for instr in instruments:
            patches[i] = patch2number[instr]
            # Skip channel 9, which is reserved for drums in General MIDI.
            i = (i + 1) if i != 8 else 10
        if drum_kit != "None":
            patches[9] = drum_kits2number[drum_kit]
        for i, (c, p) in enumerate(patches.items()):
            mid.append(tokenizer.event2tokens(["patch_change", 0, 0, i + 1, c, p]))
        mid = np.asarray([mid] * OUTPUT_BATCH_SIZE, dtype=np.int64)
        mid_seq = mid.tolist()
        if len(instruments) > 0:
            # Pin the instrumentation: no new patches, no unused channels.
            disable_patch_change = True
            disable_channels = [i for i in range(16) if i not in patches]
    elif tab == 1 and mid is not None:
        # MIDI-file prompt: tokenize the upload and keep the first n events.
        eps = 4 if reduce_cc_st else 0
        mid = tokenizer.tokenize(MIDI.midi2score(mid), cc_eps=eps, tempo_eps=eps,
                                 remap_track_channel=remap_track_channel,
                                 add_default_instr=add_default_instr,
                                 remove_empty_channels=remove_empty_channels)
        mid = mid[:int(midi_events)]
        mid = np.asarray([mid] * OUTPUT_BATCH_SIZE, dtype=np.int64)
        mid_seq = mid.tolist()
    elif tab == 2 and mid_seq is not None:
        # Continuation: either branch from one chosen output (snapshot the
        # whole mid_seq for undo) or extend all outputs (record the length).
        mid = np.asarray(mid_seq, dtype=np.int64)
        if continuation_select > 0:
            continuation_state.append(mid_seq)
            mid = np.repeat(mid[continuation_select - 1:continuation_select], repeats=OUTPUT_BATCH_SIZE, axis=0)
            mid_seq = mid.tolist()
        else:
            continuation_state.append(mid.shape[1])
    else:
        # Fallback: empty BOS-only prompt with reset continuation history.
        continuation_state = [0]
        mid = [[tokenizer.bos_id] + [tokenizer.pad_id] * (tokenizer.max_token_seq - 1)]
        mid = np.asarray([mid] * OUTPUT_BATCH_SIZE, dtype=np.int64)
        mid_seq = mid.tolist()

    if mid is not None:
        max_len += mid.shape[1]

    init_msgs = [create_msg("progress", [0, gen_events])]
    # When extending all outputs in place the visualizers already show the
    # prompt, so skip the redraw.
    if not (tab == 2 and continuation_select == 0):
        for i in range(OUTPUT_BATCH_SIZE):
            events = [tokenizer.tokens2event(tokens) for tokens in mid_seq[i]]
            init_msgs += [create_msg("visualizer_clear", [i, tokenizer.version]),
                          create_msg("visualizer_append", [i, events])]
    yield mid_seq, continuation_state, seed, send_msgs(init_msgs)
    midi_generator = generate(model, mid, batch_size=OUTPUT_BATCH_SIZE, max_len=max_len, temp=temp,
                              top_p=top_p, top_k=top_k, disable_patch_change=disable_patch_change,
                              disable_control_change=not allow_cc, disable_channels=disable_channels,
                              generator=generator)
    events = [list() for i in range(OUTPUT_BATCH_SIZE)]
    t = time.time() + 1
    for i, token_seqs in enumerate(midi_generator):
        token_seqs = token_seqs.tolist()
        for j in range(OUTPUT_BATCH_SIZE):
            token_seq = token_seqs[j]
            mid_seq[j].append(token_seq)
            events[j].append(tokenizer.tokens2event(token_seq))
        # Throttle UI updates to roughly twice per second.
        if time.time() - t > 0.5:
            msgs = [create_msg("progress", [i + 1, gen_events])]
            for j in range(OUTPUT_BATCH_SIZE):
                msgs += [create_msg("visualizer_append", [j, events[j]])]
                events[j] = list()
            yield mid_seq, continuation_state, seed, send_msgs(msgs)
            t = time.time()
    yield mid_seq, continuation_state, seed, send_msgs([])
336
+
337
+
338
def finish_run(model_name, mid_seq):
    """Write each generated sequence to ``outputs/outputN.mid`` and finalize
    the visualizers.

    Returns one file path (or None) per output slot, followed by the JSON
    message string for the js_msg textbox.
    """
    if mid_seq is None:
        outputs = [None] * OUTPUT_BATCH_SIZE
        # Fix: js_msg expects a JSON string (send_msgs), not a raw list —
        # every other return path in the app serializes with send_msgs.
        return *outputs, send_msgs([])
    tokenizer = models[model_name][2]
    outputs = []
    end_msgs = [create_msg("progress", [0, 0])]
    if not os.path.exists("outputs"):
        os.mkdir("outputs")
    for i in range(OUTPUT_BATCH_SIZE):
        events = [tokenizer.tokens2event(tokens) for tokens in mid_seq[i]]
        mid = tokenizer.detokenize(mid_seq[i])
        with open(f"outputs/output{i + 1}.mid", 'wb') as f:
            f.write(MIDI.score2midi(mid))
        outputs.append(f"outputs/output{i + 1}.mid")
        end_msgs += [create_msg("visualizer_clear", [i, tokenizer.version]),
                     create_msg("visualizer_append", [i, events]),
                     create_msg("visualizer_end", i)]
    return *outputs, send_msgs(end_msgs)
357
+
358
+
359
def synthesis_task(mid):
    """Render one tokenizer score to audio with the shared synthesizer
    (runs on the module-level thread pool)."""
    return synthesizer.synthesis(MIDI.score2opus(mid))
361
+
362
def render_audio(model_name, mid_seq, should_render_audio):
    """Synthesize audio for every generated sequence in parallel.

    Returns one ``(sample_rate, waveform)`` tuple per output slot (a single
    tuple when OUTPUT_BATCH_SIZE == 1), or None per slot when rendering is
    disabled or there is nothing to render.
    """
    if (not should_render_audio) or mid_seq is None:
        outputs = [None] * OUTPUT_BATCH_SIZE
        return tuple(outputs)
    tokenizer = models[model_name][2]
    outputs = []
    if not os.path.exists("outputs"):
        os.mkdir("outputs")
    audio_futures = []
    for i in range(OUTPUT_BATCH_SIZE):
        mid = tokenizer.detokenize(mid_seq[i])
        # Fan synthesis out to the thread pool; results collected in order.
        audio_future = thread_pool.submit(synthesis_task, mid)
        audio_futures.append(audio_future)
    for future in audio_futures:
        outputs.append((44100, future.result()))  # 44.1 kHz sample rate
    if OUTPUT_BATCH_SIZE == 1:
        return outputs[0]
    return tuple(outputs)
380
+
381
+
382
def undo_continuation(model_name, mid_seq, continuation_state):
    """Revert the last continuation step and redraw the visualizers.

    Entries in `continuation_state` are either a full saved `mid_seq`
    snapshot (list — pushed when branching from one output) or the sequence
    length before the continuation (int — pushed when extending all outputs);
    the last entry determines how the rollback is applied.
    """
    if mid_seq is None or len(continuation_state) < 2:
        # Nothing to undo (state always keeps its initial sentinel entry).
        return mid_seq, continuation_state, send_msgs([])
    tokenizer = models[model_name][2]
    if isinstance(continuation_state[-1], list):
        mid_seq = continuation_state[-1]  # restore the full snapshot
    else:
        # Truncate every sequence back to the recorded length.
        mid_seq = [ms[:continuation_state[-1]] for ms in mid_seq]
    continuation_state = continuation_state[:-1]
    end_msgs = [create_msg("progress", [0, 0])]
    for i in range(OUTPUT_BATCH_SIZE):
        events = [tokenizer.tokens2event(tokens) for tokens in mid_seq[i]]
        end_msgs += [create_msg("visualizer_clear", [i, tokenizer.version]),
                     create_msg("visualizer_append", [i, events]),
                     create_msg("visualizer_end", i)]
    return mid_seq, continuation_state, send_msgs(end_msgs)
398
+
399
+
400
def load_javascript(dir="javascript"):
    """Inline every ``javascript/*.js`` file into gradio's served HTML.

    Each script has its hard-coded JS-side batch size rewritten to match the
    module-level OUTPUT_BATCH_SIZE, then gr.routes.templates.TemplateResponse
    is monkey-patched so every rendered page gets the scripts injected just
    before ``</head>``.
    """
    scripts_list = glob.glob(f"{dir}/*.js")
    javascript = ""
    for path in scripts_list:
        with open(path, "r", encoding="utf8") as jsfile:
            js_content = jsfile.read()
            # keep the front-end's output-slot count in sync with the server
            js_content = js_content.replace("const MIDI_OUTPUT_BATCH_SIZE=4;",
                                            f"const MIDI_OUTPUT_BATCH_SIZE={OUTPUT_BATCH_SIZE};")
            javascript += f"\n<!-- {path} --><script>{js_content}</script>"
    template_response_ori = gr.routes.templates.TemplateResponse

    def template_response(*args, **kwargs):
        # Wrap the original response class: splice the scripts into <head>
        # and recompute Content-Length via init_headers().
        res = template_response_ori(*args, **kwargs)
        res.body = res.body.replace(
            b'</head>', f'{javascript}</head>'.encode("utf8"))
        res.init_headers()
        return res

    gr.routes.templates.TemplateResponse = template_response
419
+
420
+
421
def hf_hub_download_retry(repo_id, filename):
    """Download ``filename`` from Hub repo ``repo_id``, retrying up to 30 times.

    Returns the local file path on success; re-raises the last download
    error when every attempt fails.
    """
    # Fix: the log line previously printed the literal "(unknown)" instead of
    # the file actually being fetched.
    print(f"downloading {filename} from {repo_id}")
    err = None
    for _ in range(30):
        try:
            return hf_hub_download(repo_id=repo_id, filename=filename)
        except Exception as e:
            err = e
    # The loop only falls through when all 30 attempts raised, so err is set.
    raise err
433
+
434
+
435
def get_tokenizer(repo_id):
    """Build the MIDITokenizer described by a Hub repo's config.json.

    Downloads the config (with retry), then instantiates the tokenizer with
    the recorded version and optimise_midi flag.
    """
    # "config.json" had a redundant f-string prefix with no placeholders.
    config_path = hf_hub_download_retry(repo_id=repo_id, filename="config.json")
    with open(config_path, "r") as f:
        config = json.load(f)
    tokenizer = MIDITokenizer(config["tokenizer"]["version"])
    tokenizer.set_optimise_midi(config["tokenizer"]["optimise_midi"])
    return tokenizer
442
+
443
+
444
# GM drum-kit program numbers keyed by display name (-1 = no drum track).
number2drum_kits = {-1: "None", 0: "Standard", 8: "Room", 16: "Power", 24: "Electric", 25: "TR-808", 32: "Jazz",
                    40: "Blush", 48: "Orchestra"}
# Reverse lookups: instrument / drum-kit display name -> GM number.
patch2number = {v: k for k, v in MIDI.Number2patch.items()}
drum_kits2number = {v: k for k, v in number2drum_kits.items()}
# Key signatures ordered flats-to-sharps; list index maps to the tokenizer's
# key_signature event parameters (see the index arithmetic in run()).
key_signatures = ['C♭', 'A♭m', 'G♭', 'E♭m', 'D♭', 'B♭m', 'A♭', 'Fm', 'E♭', 'Cm', 'B♭', 'Gm', 'F', 'Dm',
                  'C', 'Am', 'G', 'Em', 'D', 'Bm', 'A', 'F♯m', 'E', 'C♯m', 'B', 'G♯m', 'F♯', 'D♯m', 'C♯', 'A♯m']
450
+
451
+ if __name__ == "__main__":
452
+ parser = argparse.ArgumentParser()
453
+ parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
454
+ parser.add_argument("--port", type=int, default=7860, help="gradio server port")
455
+ parser.add_argument("--device", type=str, default="cuda", help="device to run model")
456
+ parser.add_argument("--batch", type=int, default=8, help="batch size")
457
+ parser.add_argument("--max-gen", type=int, default=1024, help="max")
458
+ opt = parser.parse_args()
459
+ OUTPUT_BATCH_SIZE = opt.batch
460
+ soundfont_path = hf_hub_download_retry(repo_id="skytnt/midi-model", filename="soundfont.sf2")
461
+ thread_pool = ThreadPoolExecutor(max_workers=OUTPUT_BATCH_SIZE)
462
+ synthesizer = MidiSynthesizer(soundfont_path)
463
+ models_info = {
464
+ "generic pretrain model (tv2o-medium) by skytnt": [
465
+ "skytnt/midi-model-tv2o-medium", "", {
466
+ "jpop": "skytnt/midi-model-tv2om-jpop-lora",
467
+ "touhou": "skytnt/midi-model-tv2om-touhou-lora"
468
+ }
469
+ ],
470
+ "generic pretrain model (tv2o-large) by asigalov61": [
471
+ "asigalov61/Music-Llama", "", {}
472
+ ],
473
+ "generic pretrain model (tv2o-medium) by asigalov61": [
474
+ "asigalov61/Music-Llama-Medium", "", {}
475
+ ],
476
+ "generic pretrain model (tv1-medium) by skytnt": [
477
+ "skytnt/midi-model", "", {}
478
+ ]
479
+ }
480
+ models = {}
481
+ providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
482
+ device = "cuda"
483
+
484
+ for name, (repo_id, path, loras) in models_info.items():
485
+ model_base_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}onnx/model_base.onnx")
486
+ model_token_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}onnx/model_token.onnx")
487
+ tokenizer = get_tokenizer(repo_id)
488
+ models[name] = [model_base_path, model_token_path, tokenizer]
489
+ for lora_name, lora_repo in loras.items():
490
+ model_base_path = hf_hub_download_retry(repo_id=lora_repo, filename=f"onnx/model_base.onnx")
491
+ model_token_path = hf_hub_download_retry(repo_id=lora_repo, filename=f"onnx/model_token.onnx")
492
+ models[f"{name} with {lora_name} lora"] = [model_base_path, model_token_path, tokenizer]
493
+
494
+ load_javascript()
495
+ app = gr.Blocks(theme=gr.themes.Soft())
496
+ with app:
497
+ gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Midi Composer</h1>")
498
+ gr.Markdown("![Visitors](https://api.visitorbadge.io/api/visitors?path=skytnt.midi-composer&style=flat)\n\n"
499
+ "Midi event transformer for symbolic music generation\n\n"
500
+ "Demo for [SkyTNT/midi-model](https://github.com/SkyTNT/midi-model)\n\n"
501
+ "[Open In Colab]"
502
+ "(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
503
+ " or [download windows app](https://github.com/SkyTNT/midi-model/releases)"
504
+ " for unlimited generation\n\n"
505
+ "**Update v1.3**: MIDITokenizerV2 and new MidiVisualizer\n\n"
506
+ "The current **best** model: generic pretrain model (tv2o-medium) by skytnt"
507
+ )
508
+ js_msg = gr.Textbox(elem_id="msg_receiver", visible=False)
509
+ js_msg.change(None, [js_msg], [], js="""
510
+ (msg_json) =>{
511
+ let msgs = JSON.parse(msg_json);
512
+ executeCallbacks(msgReceiveCallbacks, msgs);
513
+ return [];
514
+ }
515
+ """)
516
+ input_model = gr.Dropdown(label="select model", choices=list(models.keys()),
517
+ type="value", value=list(models.keys())[0])
518
+ tab_select = gr.State(value=0)
519
+ with gr.Tabs():
520
+ with gr.TabItem("custom prompt") as tab1:
521
+ input_instruments = gr.Dropdown(label="🪗instruments (auto if empty)", choices=list(patch2number.keys()),
522
+ multiselect=True, max_choices=15, type="value")
523
+ input_drum_kit = gr.Dropdown(label="🥁drum kit", choices=list(drum_kits2number.keys()), type="value",
524
+ value="None")
525
+ input_bpm = gr.Slider(label="BPM (beats per minute, auto if 0)", minimum=0, maximum=255,
526
+ step=1,
527
+ value=0)
528
+ input_time_sig = gr.Radio(label="time signature (only for tv2 models)",
529
+ value="auto",
530
+ choices=["auto", "4/4", "2/4", "3/4", "6/4", "7/4",
531
+ "2/2", "3/2", "4/2", "3/8", "5/8", "6/8", "7/8", "9/8", "12/8"]
532
+ )
533
+ input_key_sig = gr.Radio(label="key signature (only for tv2 models)",
534
+ value="auto",
535
+ choices=["auto"] + key_signatures,
536
+ type="index"
537
+ )
538
+ example1 = gr.Examples([
539
+ [[], "None"],
540
+ [["Acoustic Grand"], "None"],
541
+ [['Acoustic Grand', 'SynthStrings 2', 'SynthStrings 1', 'Pizzicato Strings',
542
+ 'Pad 2 (warm)', 'Tremolo Strings', 'String Ensemble 1'], "Orchestra"],
543
+ [['Trumpet', 'Oboe', 'Trombone', 'String Ensemble 1', 'Clarinet',
544
+ 'French Horn', 'Pad 4 (choir)', 'Bassoon', 'Flute'], "None"],
545
+ [['Flute', 'French Horn', 'Clarinet', 'String Ensemble 2', 'English Horn', 'Bassoon',
546
+ 'Oboe', 'Pizzicato Strings'], "Orchestra"],
547
+ [['Electric Piano 2', 'Lead 5 (charang)', 'Electric Bass(pick)', 'Lead 2 (sawtooth)',
548
+ 'Pad 1 (new age)', 'Orchestra Hit', 'Cello', 'Electric Guitar(clean)'], "Standard"],
549
+ [["Electric Guitar(clean)", "Electric Guitar(muted)", "Overdriven Guitar", "Distortion Guitar",
550
+ "Electric Bass(finger)"], "Standard"]
551
+ ], [input_instruments, input_drum_kit])
552
+ with gr.TabItem("midi prompt") as tab2:
553
+ input_midi = gr.File(label="input midi", file_types=[".midi", ".mid"], type="binary")
554
+ input_midi_events = gr.Slider(label="use first n midi events as prompt", minimum=1, maximum=512,
555
+ step=1,
556
+ value=128)
557
+ input_reduce_cc_st = gr.Checkbox(label="reduce control_change and set_tempo events", value=True)
558
+ input_remap_track_channel = gr.Checkbox(
559
+ label="remap tracks and channels so each track has only one channel and in order", value=True)
560
+ input_add_default_instr = gr.Checkbox(
561
+ label="add a default instrument to channels that don't have an instrument", value=True)
562
+ input_remove_empty_channels = gr.Checkbox(label="remove channels without notes", value=False)
563
+ example2 = gr.Examples([[file, 128] for file in glob.glob("example/*.mid")],
564
+ [input_midi, input_midi_events])
565
+ with gr.TabItem("last output prompt") as tab3:
566
+ gr.Markdown("Continue generating on the last output.")
567
+ input_continuation_select = gr.Radio(label="select output to continue generating", value="all",
568
+ choices=["all"] + [f"output{i + 1}" for i in
569
+ range(OUTPUT_BATCH_SIZE)],
570
+ type="index"
571
+ )
572
+ undo_btn = gr.Button("undo the last continuation")
573
+
574
+ tab1.select(lambda: 0, None, tab_select, queue=False)
575
+ tab2.select(lambda: 1, None, tab_select, queue=False)
576
+ tab3.select(lambda: 2, None, tab_select, queue=False)
577
+ input_seed = gr.Slider(label="seed", minimum=0, maximum=2 ** 31 - 1,
578
+ step=1, value=0)
579
+ input_seed_rand = gr.Checkbox(label="random seed", value=True)
580
+ input_gen_events = gr.Slider(label="generate max n midi events", minimum=1, maximum=opt.max_gen,
581
+ step=1, value=opt.max_gen // 2)
582
+ with gr.Accordion("options", open=False):
583
+ input_temp = gr.Slider(label="temperature", minimum=0.1, maximum=1.2, step=0.01, value=1)
584
+ input_top_p = gr.Slider(label="top p", minimum=0.1, maximum=1, step=0.01, value=0.95)
585
+ input_top_k = gr.Slider(label="top k", minimum=1, maximum=128, step=1, value=20)
586
+ input_allow_cc = gr.Checkbox(label="allow midi cc event", value=True)
587
+ input_render_audio = gr.Checkbox(label="render audio after generation", value=True)
588
+ example3 = gr.Examples([[1, 0.94, 128], [1, 0.98, 20], [1, 0.98, 12]],
589
+ [input_temp, input_top_p, input_top_k])
590
+ run_btn = gr.Button("generate", variant="primary")
591
+ # stop_btn = gr.Button("stop and output")
592
+ output_midi_seq = gr.State()
593
+ output_continuation_state = gr.State([0])
594
+ midi_outputs = []
595
+ audio_outputs = []
596
+ with gr.Tabs(elem_id="output_tabs"):
597
+ for i in range(OUTPUT_BATCH_SIZE):
598
+ with gr.TabItem(f"output {i + 1}") as tab1:
599
+ output_midi_visualizer = gr.HTML(elem_id=f"midi_visualizer_container_{i}")
600
+ output_audio = gr.Audio(label="output audio", format="mp3", elem_id=f"midi_audio_{i}")
601
+ output_midi = gr.File(label="output midi", file_types=[".mid"])
602
+ midi_outputs.append(output_midi)
603
+ audio_outputs.append(output_audio)
604
+ run_event = run_btn.click(run, [input_model, tab_select, output_midi_seq, output_continuation_state,
605
+ input_continuation_select, input_instruments, input_drum_kit, input_bpm,
606
+ input_time_sig, input_key_sig, input_midi, input_midi_events,
607
+ input_reduce_cc_st, input_remap_track_channel,
608
+ input_add_default_instr, input_remove_empty_channels,
609
+ input_seed, input_seed_rand, input_gen_events, input_temp, input_top_p,
610
+ input_top_k, input_allow_cc],
611
+ [output_midi_seq, output_continuation_state, input_seed, js_msg],
612
+ concurrency_limit=10, queue=True)
613
+ finish_run_event = run_event.then(fn=finish_run,
614
+ inputs=[input_model, output_midi_seq],
615
+ outputs=midi_outputs + [js_msg],
616
+ queue=False)
617
+ finish_run_event.then(fn=render_audio,
618
+ inputs=[input_model, output_midi_seq, input_render_audio],
619
+ outputs=audio_outputs,
620
+ queue=False)
621
+ # stop_btn.click(None, [], [], cancels=run_event,
622
+ # queue=False)
623
+ undo_btn.click(undo_continuation, [input_model, output_midi_seq, output_continuation_state],
624
+ [output_midi_seq, output_continuation_state, js_msg], queue=False)
625
+ app.queue().launch(server_port=opt.port, share=opt.share, inbrowser=True, ssr_mode=False)
626
+ thread_pool.shutdown()
example/Bach--Fugue-in-D-Minor.mid ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1398121eb86a33e73f90ec84be71dac6abc0ddf11372ea7cdd9e01586938a56b
3
+ size 7720
example/Beethoven--Symphony-No5-in-C-Minor-Fate-Opus-67.mid ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:28ff6fdcd644e781d36411bf40ab7a1f4849adddbcd1040eaec22751c5ca99d2
3
+ size 87090
example/Chopin--Nocturne No. 9 in B Major, Opus 32 No.1, Andante Sostenuto.mid ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3a236e647ad9f5d0af680d3ca19d3b60f334c4bde6b4f86310f63405245c476e
3
+ size 13484
example/Mozart--Requiem, No.1..mid ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aa49bf4633401e16777fe47f6f53a494c2166f5101af6dafc60114932a59b9bd
3
+ size 14695
example/castle_in_the_sky.mid ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fa14aec6f1be15c4fddd0decc6d9152204f160d4e07e05d8d1dc9f209c309ff7
3
+ size 7957
example/eva-残酷な天使のテーゼ.mid ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e513487543d7e27ec5dc30f027302d2a3b5a3aaf9af554def1e5cd6a7a8d355a
3
+ size 17671
javascript/app.js ADDED
@@ -0,0 +1,732 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
// Number of parallel output slots rendered by the app (one visualizer each).
const MIDI_OUTPUT_BATCH_SIZE=4;
//Do not change MIDI_OUTPUT_BATCH_SIZE. It will be automatically replaced.
// NOTE(review): presumably the Python server rewrites this literal to match its
// OUTPUT_BATCH_SIZE before serving the page — confirm against load_javascript().
3
+
/**
 * querySelector that transparently descends into open shadow roots.
 * @param {string} selector - CSS selector to match.
 * @returns {Element|null} The first matching element, or null if none found.
 */
function deepQuerySelector(selector) {
    /**
     * Depth-first search for `selector` under `root`, recursing into any
     * open shadowRoot encountered along the way.
     * @param {Element|Document|ShadowRoot} root - Search root.
     * @param {string} selector - CSS selector to match.
     * @returns {Element|null} Match, or null.
     */
    function deepSearch(root, selector) {
        // Fast path: a direct match in the current tree.
        const direct = root.querySelector(selector);
        if (direct) {
            return direct;
        }
        // Otherwise walk every descendant and recurse into its shadow DOM.
        const candidates = root.querySelectorAll('*');
        for (const host of candidates) {
            if (!host.shadowRoot) {
                continue;
            }
            const found = deepSearch(host.shadowRoot, selector);
            if (found) {
                return found;
            }
        }
        // Nothing matched anywhere below this root.
        return null;
    }

    return deepSearch(this, selector);
}

// Expose the helper on both elements and documents so call sites can use
// `someNode.deepQuerySelector(...)` uniformly.
Element.prototype.deepQuerySelector = deepQuerySelector;
Document.prototype.deepQuerySelector = deepQuerySelector;
46
+
// Return the root node to query Gradio widgets under: the <gradio-app>
// shadow root when the page uses one, otherwise the plain document.
function gradioApp() {
    const apps = document.getElementsByTagName('gradio-app');
    const shadow = apps.length ? apps[0].shadowRoot : null;
    return shadow || document;
}
52
+
// Registries of listeners fired on Gradio UI mutations and on messages
// pushed from the Python side. Declared with `var` (was an implicit global
// assignment, which throws under strict mode / ES modules); in a classic
// top-level script `var` creates the same global binding.
var uiUpdateCallbacks = []
var msgReceiveCallbacks = []

// Register a callback invoked with the MutationRecord list on every UI update.
function onUiUpdate(callback){
    uiUpdateCallbacks.push(callback)
}

// Register a callback invoked with each message batch received from Python.
function onMsgReceive(callback){
    msgReceiveCallbacks.push(callback)
}

// Invoke one callback, logging (not propagating) any exception so a single
// faulty listener cannot break the others.
function runCallback(x, m){
    try {
        x(m)
    } catch (e) {
        (console.error || console.log).call(console, e.message, e);
    }
}

// Invoke every callback in `queue` with argument `m`.
function executeCallbacks(queue, m) {
    queue.forEach(function(x){runCallback(x, m)})
}
74
+
// Once the page is ready, watch the whole Gradio tree and fan every DOM
// mutation batch out to the registered UI-update callbacks.
document.addEventListener("DOMContentLoaded", () => {
    const observer = new MutationObserver((mutations) => {
        executeCallbacks(uiUpdateCallbacks, mutations);
    });
    observer.observe(gradioApp(), {childList: true, subtree: true});
});
81
+
// Convert an HSV color (all components in [0, 1]) to 8-bit RGB.
// Standard sector-based conversion: the hue circle is split into six
// sectors and (v, p, q, t) are permuted according to the sector index.
function HSVtoRGB(h, s, v) {
    const sector = Math.floor(h * 6);
    const f = h * 6 - sector;
    const p = v * (1 - s);
    const q = v * (1 - f * s);
    const t = v * (1 - (1 - f) * s);
    // Row k of this table is the (r, g, b) permutation for sector k.
    const bySector = [
        [v, t, p],
        [q, v, p],
        [p, v, t],
        [p, q, v],
        [t, p, v],
        [v, p, q]
    ];
    const [r, g, b] = bySector[sector % 6];
    return {
        r: Math.round(r * 255),
        g: Math.round(g * 255),
        b: Math.round(b * 255)
    };
}
103
+
// Heuristic mobile detection via the user-agent string; used to shrink the
// piano roll and hide the track list on small screens.
function isMobile(){
    const mobilePattern = /(iPhone|iPad|iPod|iOS|Android|Windows Phone)/i;
    return mobilePattern.test(navigator.userAgent);
}
107
+
// General MIDI program number (array index 0-127) -> instrument display name.
const number2patch = ['Acoustic Grand', 'Bright Acoustic', 'Electric Grand', 'Honky-Tonk', 'Electric Piano 1', 'Electric Piano 2', 'Harpsichord', 'Clav', 'Celesta', 'Glockenspiel', 'Music Box', 'Vibraphone', 'Marimba', 'Xylophone', 'Tubular Bells', 'Dulcimer', 'Drawbar Organ', 'Percussive Organ', 'Rock Organ', 'Church Organ', 'Reed Organ', 'Accordion', 'Harmonica', 'Tango Accordion', 'Acoustic Guitar(nylon)', 'Acoustic Guitar(steel)', 'Electric Guitar(jazz)', 'Electric Guitar(clean)', 'Electric Guitar(muted)', 'Overdriven Guitar', 'Distortion Guitar', 'Guitar Harmonics', 'Acoustic Bass', 'Electric Bass(finger)', 'Electric Bass(pick)', 'Fretless Bass', 'Slap Bass 1', 'Slap Bass 2', 'Synth Bass 1', 'Synth Bass 2', 'Violin', 'Viola', 'Cello', 'Contrabass', 'Tremolo Strings', 'Pizzicato Strings', 'Orchestral Harp', 'Timpani', 'String Ensemble 1', 'String Ensemble 2', 'SynthStrings 1', 'SynthStrings 2', 'Choir Aahs', 'Voice Oohs', 'Synth Voice', 'Orchestra Hit', 'Trumpet', 'Trombone', 'Tuba', 'Muted Trumpet', 'French Horn', 'Brass Section', 'SynthBrass 1', 'SynthBrass 2', 'Soprano Sax', 'Alto Sax', 'Tenor Sax', 'Baritone Sax', 'Oboe', 'English Horn', 'Bassoon', 'Clarinet', 'Piccolo', 'Flute', 'Recorder', 'Pan Flute', 'Blown Bottle', 'Skakuhachi', 'Whistle', 'Ocarina', 'Lead 1 (square)', 'Lead 2 (sawtooth)', 'Lead 3 (calliope)', 'Lead 4 (chiff)', 'Lead 5 (charang)', 'Lead 6 (voice)', 'Lead 7 (fifths)', 'Lead 8 (bass+lead)', 'Pad 1 (new age)', 'Pad 2 (warm)', 'Pad 3 (polysynth)', 'Pad 4 (choir)', 'Pad 5 (bowed)', 'Pad 6 (metallic)', 'Pad 7 (halo)', 'Pad 8 (sweep)', 'FX 1 (rain)', 'FX 2 (soundtrack)', 'FX 3 (crystal)', 'FX 4 (atmosphere)', 'FX 5 (brightness)', 'FX 6 (goblins)', 'FX 7 (echoes)', 'FX 8 (sci-fi)', 'Sitar', 'Banjo', 'Shamisen', 'Koto', 'Kalimba', 'Bagpipe', 'Fiddle', 'Shanai', 'Tinkle Bell', 'Agogo', 'Steel Drums', 'Woodblock', 'Taiko Drum', 'Melodic Tom', 'Synth Drum', 'Reverse Cymbal', 'Guitar Fret Noise', 'Breath Noise', 'Seashore', 'Bird Tweet', 'Telephone Ring', 'Helicopter', 'Applause', 'Gunshot']
// GM drum-kit program number (on channel 9) -> kit display name.
const number2drum_kits = {0: "Standard", 8: "Room", 16: "Power", 24: "Electric", 25: "TR-808", 32: "Jazz", 40: "Blush", 48: "Orchestra"}
110
+
// Custom element <midi-visualizer>: an SVG piano-roll renderer for tokenized
// MIDI events streamed from the Python server. Tracks are keyed by
// (track, channel), each with a toggleable SVG group; a green time line tracks
// playback position driven by the bound <audio> element / waveform cursor.
// NOTE(review): `timePreBeat` appears to mean "ticks per beat" (presumably a
// typo for timePerBeat) — time values are beats * timePreBeat + offset.
class MidiVisualizer extends HTMLElement{
    constructor() {
        super();
        this.midiEvents = [];      // all appended events, time-normalized in ticks
        this.activeNotes = [];     // note rects currently highlighted during playback
        this.midiTimes = [];       // piecewise (ms, tick, tempo) map built by finishAppendMidiEvent
        this.trackMap = new Map()  // (track*16+channel) -> track record
        this.patches = [];         // per-channel [tick, program] history for instrument labels
        for (let i=0;i<16;i++){
            this.patches.push([[0,0]])
        }
        this.container = null;
        this.trackList = null
        this.pianoRoll = null;
        this.svg = null;
        this.timeLine = null;
        // Pixel geometry of the roll; shrunk below for mobile screens.
        this.config = {
            noteHeight : 4,
            beatWidth: 32
        }
        if (isMobile()){
            this.config.noteHeight = 1;
            this.config.beatWidth = 16;
        }
        this.timePreBeat = 16
        this.svgWidth = 0;
        this.t1 = 0;               // running beat accumulator while appending events
        this.totalTimeMs = 0
        this.playTime = 0          // current playback position in ticks
        this.playTimeMs = 0
        this.lastUpdateTime = 0
        this.colorMap = new Map();
        this.playing = false;
        this.timer = null;
        this.version = "v2"        // tokenizer event layout ("v1" or "v2"), set via messages
        this.init();
    }

    // Build the shadow-DOM UI: track list with All/None toggles on the left,
    // scrollable SVG piano roll with the playback time line on the right.
    init(){
        this.innerHTML=''
        const shadow = this.attachShadow({mode: 'open'});
        const style = document.createElement("style");
        style.textContent = ".note.active {stroke: black;stroke-width: 0.75;stroke-opacity: 0.75;}";
        const container = document.createElement('div');
        container.style.display="flex";
        container.style.height=`${this.config.noteHeight*128 + 25}px`;
        const trackListContainer = document.createElement('div');
        trackListContainer.style.width = "260px";
        trackListContainer.style.minWidth = "260px";
        trackListContainer.style.height = "100%";
        trackListContainer.style.display="flex";
        trackListContainer.style.flexDirection="column";
        const trackList = document.createElement('div');
        trackList.style.width = "100%";
        trackList.style.height = "100%";
        trackList.style.overflowY= "scroll";
        trackList.style.display="flex";
        trackList.style.flexDirection="column";
        trackList.style.flexGrow="1";
        const trackControls = document.createElement('div');
        trackControls.style.display="flex";
        trackControls.style.flexDirection="row";
        trackControls.style.width = "100%";
        trackControls.style.height = "50px";
        trackControls.style.minHeight = "50px";
        const allTrackBtn = document.createElement('button');
        allTrackBtn.textContent = "All";
        allTrackBtn.style.width = "50%";
        allTrackBtn.style.height = "100%";
        allTrackBtn.style.backgroundColor = "rgba(200, 200, 200, 0.3)";
        allTrackBtn.style.color = 'inherit';
        allTrackBtn.style.border = "none";
        allTrackBtn.style.cursor = 'pointer';
        let self = this;
        allTrackBtn.onclick = function (){
            self.trackMap.forEach((track, id) => {
                track.setChecked(true);
            })
        };
        const noneTrackBtn = document.createElement('button');
        noneTrackBtn.textContent = "None";
        noneTrackBtn.style.width = "50%";
        noneTrackBtn.style.height = "100%";
        noneTrackBtn.style.backgroundColor = "rgba(200, 200, 200, 0.3)";
        noneTrackBtn.style.color = 'inherit';
        noneTrackBtn.style.border = "none";
        noneTrackBtn.style.cursor = 'pointer';
        noneTrackBtn.onclick = function (){
            self.trackMap.forEach((track, id) => {
                track.setChecked(false);
            });
        };
        const pianoRoll = document.createElement('div');
        pianoRoll.style.overflowX= "scroll";
        pianoRoll.style.flexGrow="1";
        const svg = document.createElementNS('http://www.w3.org/2000/svg', 'svg');
        svg.style.height = `${this.config.noteHeight*128}px`;
        svg.style.width = `${this.svgWidth}px`;
        const timeLine = document.createElementNS('http://www.w3.org/2000/svg', 'line');
        timeLine.style.stroke = "green"
        timeLine.style.strokeWidth = "2";

        if (isMobile()){
            trackListContainer.style.display = "none";
            timeLine.style.strokeWidth = "1";
        }
        shadow.appendChild(style)
        shadow.appendChild(container);
        container.appendChild(trackListContainer);
        trackListContainer.appendChild(trackList);
        trackListContainer.appendChild(trackControls);
        trackControls.appendChild(allTrackBtn);
        trackControls.appendChild(noneTrackBtn);
        container.appendChild(pianoRoll);
        pianoRoll.appendChild(svg);
        svg.appendChild(timeLine)
        this.container = container;
        this.trackList = trackList;
        this.pianoRoll = pianoRoll;
        this.svg = svg;
        this.timeLine= timeLine;
        // Pre-compute one distinct hue per possible track color slot.
        for(let i = 0; i < 128 ; i++){
            this.colorMap.set(i, HSVtoRGB(i / 128, 1, 1))
        }
        this.setPlayTime(0);
    }

    // Register a new visual track record and its list-item row.
    addTrack(id, tr, cl, name, color){
        const track = {id, tr, cl, name, color, empty: true,
            lastCC: new Map(),
            instrument: cl===9?"Standard Drum":"Acoustic Grand",
            svg: document.createElementNS('http://www.w3.org/2000/svg', 'g'),
            ccPaths: new Map()
        }
        this.svg.appendChild(track.svg)
        const trackItem = this.createTrackItem(track);
        this.trackList.appendChild(trackItem);
        this.trackMap.set(id, track);
        return track;
    }

    // Fetch (or lazily create) the track record for (track, channel).
    getTrack(tr, cl){
        const id = tr * 16 + cl
        let track = this.trackMap.get(id)
        if (!!track){
            return track
        }
        // 53 is coprime with 128, so successive tracks get well-spread hues.
        let color = this.colorMap.get((this.trackMap.size*53)%128)
        return this.addTrack(id, tr, cl, `Track ${tr}, Channel ${cl}`, color)
    }

    // Build the sidebar row for a track: color bar, name/instrument label,
    // and a visibility checkbox; wires setChecked/setEmpty/updateInstrument
    // helpers onto the track record.
    createTrackItem(track) {
        const trackItem = document.createElement('div');
        trackItem.style.display = 'flex';
        trackItem.style.alignItems = 'center';
        trackItem.style.width = '100%';
        trackItem.style.position = 'relative';

        const colorBar = document.createElement('div');
        colorBar.style.width = '5%';
        colorBar.style.height = '100%';
        colorBar.style.position = 'absolute';
        colorBar.style.left = '0';
        colorBar.style.top = '0';
        let color = track.color;
        colorBar.style.backgroundColor = `rgb(${color.r}, ${color.g}, ${color.b})`;
        trackItem.appendChild(colorBar);

        const content = document.createElement('div');
        content.style.paddingLeft = '30px';
        content.style.flexGrow = '1';
        content.style.color = "grey"
        content.innerHTML = `<p>${track.name}<br>${track.instrument}</p>`;
        trackItem.appendChild(content);
        track.updateInstrument = function (instrument){
            track.instrument = instrument;
            content.innerHTML = `<p>${track.name}<br>${track.instrument}</p>`;
        }
        // Grey out tracks that have no notes yet.
        track.setEmpty = function (empty){
            if (empty!==track.empty){
                content.style.color = empty?"grey":"inherit";
            }
        }

        const toggleSwitch = document.createElement('input');
        toggleSwitch.type = 'checkbox';
        toggleSwitch.checked = true;
        toggleSwitch.style.marginLeft = 'auto';
        toggleSwitch.style.marginRight = '10px';
        toggleSwitch.style.width = '20px';
        toggleSwitch.style.height = '20px';
        toggleSwitch.style.cursor = 'pointer';

        toggleSwitch.onchange = function () {
            track.svg.setAttribute('visibility',toggleSwitch.checked? "visible" : "hidden")
        };
        track.setChecked = function (checked){
            toggleSwitch.checked = checked;
            track.svg.setAttribute('visibility',toggleSwitch.checked? "visible" : "hidden")
        }
        trackItem.appendChild(toggleSwitch);
        return trackItem;
    }

    // Reset all state and clear the SVG (the time line is re-attached).
    clearMidiEvents(){
        this.pause()
        this.midiEvents = [];
        this.activeNotes = [];
        this.midiTimes = [];
        this.trackMap = new Map()
        this.patches = [];
        for (let i=0;i<16;i++){
            this.patches.push([[0,0]])
        }
        this.t1 = 0
        this.setPlayTime(0);
        this.totalTimeMs = 0;
        this.playTimeMs = 0
        this.lastUpdateTime = 0
        this.trackList.innerHTML = ''
        this.svgWidth = 0
        this.svg.innerHTML = ''
        this.svg.style.width = `${this.svgWidth}px`;
        this.svg.appendChild(this.timeLine)
    }

    // Ingest one tokenized event [name, beat_delta, tick_offset, ...payload],
    // normalize its time to absolute ticks, and draw notes / patch changes /
    // control changes incrementally as they stream in.
    // NOTE(review): payload field order differs between tokenizer versions —
    // v1 notes are [duration, channel, pitch, velocity], v2 notes are
    // [channel, pitch, velocity, duration].
    appendMidiEvent(midiEvent){
        if(midiEvent instanceof Array && midiEvent.length > 0){

            this.t1 += midiEvent[1]
            let t = this.t1*this.timePreBeat + midiEvent[2]
            midiEvent = [midiEvent[0], t].concat(midiEvent.slice(3))
            if(midiEvent[0] === "note"){
                let track = midiEvent[2]
                let duration = 0
                let channel = 0
                let pitch = 0
                let velocity = 0
                if(this.version === "v1"){
                    duration = midiEvent[3]
                    channel = midiEvent[4]
                    pitch = midiEvent[5]
                    velocity = midiEvent[6]
                }else if (this.version === "v2"){
                    channel = midiEvent[3]
                    pitch = midiEvent[4]
                    velocity = midiEvent[5]
                    duration = midiEvent[6]
                }
                let vis_track = this.getTrack(track, channel);
                vis_track.setEmpty(false);
                let x = (t/this.timePreBeat)*this.config.beatWidth
                let y = (127 - pitch)*this.config.noteHeight
                let w = (duration/this.timePreBeat)*this.config.beatWidth
                let h = this.config.noteHeight
                this.svgWidth = Math.ceil(Math.max(x + w, this.svgWidth))
                let opacity = Math.min(1, velocity/127 + 0.1).toFixed(2)
                let rect = this.drawNote(vis_track, x,y,w,h, opacity)
                midiEvent.push(rect);  // keep the rect so playback can highlight it
                this.setPlayTime(t);
                // Follow the newest note while streaming.
                this.pianoRoll.scrollTo(this.svgWidth - this.pianoRoll.offsetWidth, this.pianoRoll.scrollTop)
            }else if(midiEvent[0] === "patch_change"){
                let track = midiEvent[2];
                let channel = midiEvent[3];
                this.patches[channel].push([t, midiEvent[4]]);
                this.patches[channel].sort((a, b) => a[0] - b[0]);
                this.getTrack(track, channel);
            }else if(midiEvent[0] === "control_change"){
                let track = midiEvent[2];
                let channel = midiEvent[3];
                let controller = midiEvent[4];
                let value = midiEvent[5];
                let vis_track = this.getTrack(track, channel);
                this.drawCC(vis_track, t, controller, value);
                this.setPlayTime(t);
            }
            this.midiEvents.push(midiEvent);
            this.svg.style.width = `${this.svgWidth}px`;
        }

    }

    // Draw one note rectangle into the track's SVG group; returns the rect.
    drawNote(track, x, y, w, h, opacity) {
        if (!track.svg) {
            return null;
        }
        const rect = document.createElementNS('http://www.w3.org/2000/svg', 'rect');
        rect.classList.add('note');
        const color = track.color;
        rect.setAttribute('fill', `rgba(${color.r}, ${color.g}, ${color.b}, ${opacity})`);
        // Round values to the nearest integer to avoid partially filled pixels.
        rect.setAttribute('x', `${Math.round(x)}`);
        rect.setAttribute('y', `${Math.round(y)}`);
        rect.setAttribute('width', `${Math.round(w)}`);
        rect.setAttribute('height', `${Math.round(h)}`);
        track.svg.appendChild(rect);
        return rect
    }

    // Extend (or start) the step-line path for one CC controller on a track.
    // Paths start hidden and become visible only once the value changes.
    drawCC(track, t, controller, value){
        if (!track.svg) {
            return null;
        }
        let path = track.ccPaths.get(controller);
        let x = (t/this.timePreBeat)*this.config.beatWidth
        let y = (127 - value)*this.config.noteHeight
        if (!path){
            path = document.createElementNS('http://www.w3.org/2000/svg', 'path');
            path.setAttribute('visibility',"hidden");
            path.setAttribute('fill', "transparent");
            const color = track.color;
            path.setAttribute('stroke', `rgba(${color.r}, ${color.g}, ${color.b}, 0.6)`);
            path.setAttribute('stroke-width', "1");
            path.setAttribute('d',
                t===0?`M ${x} ${y}`:`M 0 ${127*this.config.noteHeight} H ${x} V ${y}`);
            track.svg.appendChild(path);
            track.ccPaths.set(controller, path);
            track.lastCC.set(controller, value);
            return path;
        }
        let lastVal = track.lastCC.get(controller);
        if(lastVal !== value){
            path.removeAttribute('visibility');
        }
        let d = path.getAttribute("d");
        d += `H ${x} V ${y}`
        path.setAttribute('d', d);
        return path
    }

    // Called once streaming ends: sort events by time, build the tick<->ms
    // tempo map, compute total duration, and close off all CC paths.
    finishAppendMidiEvent(){
        this.pause()
        let midiEvents = this.midiEvents.sort((a, b)=>a[1]-b[1])
        let tempo = (60 / 120) * 10 ** 3  // default 120 BPM, in ms per beat
        let ms = 0
        let lastT = 0
        this.midiTimes.push({ms:ms, t: 0, tempo: tempo})
        midiEvents.forEach((midiEvent)=>{
            let t = midiEvent[1]
            ms += ((t- lastT) / this.timePreBeat) * tempo
            if(midiEvent[0]==="set_tempo"){
                tempo = (60 / midiEvent[3]) * 10 ** 3
                this.midiTimes.push({ms:ms, t: t, tempo: tempo})
            }
            if(midiEvent[0]==="note"){
                // NOTE(review): uses midiEvent[3] as duration, which matches the
                // v1 layout only; for v2 the duration is at index 6 — confirm.
                this.totalTimeMs = Math.max(this.totalTimeMs, ms + (midiEvent[3]/ this.timePreBeat)*tempo)
            }else{
                this.totalTimeMs = Math.max(this.totalTimeMs, ms);
            }
            lastT = t;
        })
        let x = (lastT/this.timePreBeat)*this.config.beatWidth;
        this.trackMap.forEach((track, id)=>{
            track.ccPaths.forEach((path, controller)=>{
                let d = path.getAttribute("d");
                d += `H ${x}`
                path.setAttribute('d', d);
            })
        })
    }

    // Move the time line to tick `t`, keep it centered in view, refresh each
    // track's instrument label from the patch history, and (rate-limited to
    // one pass per 50ms while playing) highlight the notes sounding at `t`.
    setPlayTime(t){
        this.playTime = t
        let x = Math.round((t/this.timePreBeat)*this.config.beatWidth)
        this.timeLine.setAttribute('x1', `${x}`);
        this.timeLine.setAttribute('y1', '0');
        this.timeLine.setAttribute('x2', `${x}`);
        this.timeLine.setAttribute('y2', `${this.config.noteHeight*128}`);

        this.pianoRoll.scrollTo(Math.max(0, x - this.pianoRoll.offsetWidth/2), this.pianoRoll.scrollTop)

        this.trackMap.forEach((track, id)=>{
            let instrument = track.instrument
            let cl = track.cl;
            let patches = this.patches[cl]
            let p = 0
            // Last patch change at or before t wins.
            for (let i = 0; i < patches.length ; i++){
                let tp = patches[i]
                if (t < tp[0])
                    break
                p = tp[1]
            }
            if (cl === 9){
                let drumKit = number2drum_kits[`${p}`];
                if (!!drumKit)
                    instrument = drumKit + " Drum";
            }else{
                instrument = number2patch[p]
            }
            if (instrument !== track.instrument)
                track.updateInstrument(instrument)
        });

        let dt = Date.now() - this.lastUpdateTime; // limit the update rate of ActiveNotes
        if(this.playing && dt > 50){
            let activeNotes = []
            this.removeActiveNotes(this.activeNotes)
            this.midiEvents.forEach((midiEvent)=>{
                if(midiEvent[0] === "note"){
                    let time = midiEvent[1]
                    let duration = this.version==="v1"? midiEvent[3]:midiEvent[6]
                    let note = midiEvent[midiEvent.length - 1]
                    if(time <=this.playTime && time+duration>= this.playTime){
                        activeNotes.push(note)
                    }
                }
            });
            this.addActiveNotes(activeNotes)
            this.lastUpdateTime = Date.now();
        }

    }

    // Convert a playback position in ms to ticks via the tempo map, then
    // delegate to setPlayTime.
    setPlayTimeMs(ms){
        this.playTimeMs = ms
        let playTime = 0
        for(let i =0;i<this.midiTimes.length;i++){
            let midiTime = this.midiTimes[i]
            if(midiTime.ms>=ms){
                break;
            }
            playTime = midiTime.t + (ms-midiTime.ms) * this.timePreBeat / midiTime.tempo
        }
        this.setPlayTime(playTime)
    }

    // Highlight the given note rects and remember them as active.
    addActiveNotes(notes){
        notes.forEach((note)=>{
            this.activeNotes.push(note)
            note.classList.add('active');
        });
    }

    // Un-highlight the given note rects and drop them from the active list.
    removeActiveNotes(notes){
        notes.forEach((note)=>{
            let idx = this.activeNotes.indexOf(note)
            if(idx>-1)
                this.activeNotes.splice(idx, 1);
            note.classList.remove('active');
        });
    }

    play(){
        this.playing = true;
    }

    pause(){
        this.removeActiveNotes(this.activeNotes)
        this.playing = false;
    }


    // Sync play/pause state with an <audio> element and take the authoritative
    // total duration from its metadata.
    bindAudioPlayer(audio){
        this.pause()
        audio.addEventListener("play", (event)=>{
            this.play()
        })
        audio.addEventListener("pause", (event)=>{
            this.pause()
        })
        audio.addEventListener("loadedmetadata", (event)=>{
            //I don't know why the calculated totalTimeMs is different from audio.duration*10**3
            this.totalTimeMs = audio.duration*10**3;
        })
    }

    // Follow the Gradio waveform cursor: its `left: NN%` style encodes the
    // playback progress, which we convert to ms and feed into setPlayTimeMs.
    bindWaveformCursor(cursor){
        let self = this;
        const callback = function(mutationsList, observer) {
            for(let mutation of mutationsList) {
                if (mutation.type === 'attributes' && mutation.attributeName === 'style') {
                    let progress = parseFloat(mutation.target.style.left.slice(0,-1))*0.01;
                    if(!isNaN(progress)){
                        self.setPlayTimeMs(progress*self.totalTimeMs);
                    }
                }
            }
        };
        const observer = new MutationObserver(callback);
        observer.observe(cursor, {
            attributes: true,
            attributeFilter: ['style']
        });
    }
}
595
+ }
596
+
597
+ customElements.define('midi-visualizer', MidiVisualizer);
598
+
599
+ (()=>{
600
+ function midi_visualizer_setup(idx, midi_visualizer){
601
+ let midi_visualizer_container_inited = null
602
+ let midi_audio_audio_inited = null;
603
+ let midi_audio_cursor_inited = null;
604
+ onUiUpdate((m)=>{
605
+ let app = gradioApp()
606
+ let midi_visualizer_container = app.querySelector(`#midi_visualizer_container_${idx}`);
607
+ if(!!midi_visualizer_container && midi_visualizer_container_inited!== midi_visualizer_container){
608
+ midi_visualizer_container.appendChild(midi_visualizer)
609
+ midi_visualizer_container_inited = midi_visualizer_container;
610
+ }
611
+ let midi_audio = app.querySelector(`#midi_audio_${idx}`);
612
+ if (!!midi_audio){
613
+ let midi_audio_cursor = midi_audio.deepQuerySelector(".cursor");
614
+ if(!!midi_audio_cursor && midi_audio_cursor_inited!==midi_audio_cursor){
615
+ midi_visualizer.bindWaveformCursor(midi_audio_cursor)
616
+ midi_audio_cursor_inited = midi_audio_cursor
617
+ }
618
+ let midi_audio_waveform = midi_audio.deepQuerySelector("#waveform");
619
+ if(!!midi_audio_waveform){
620
+ let midi_audio_audio = midi_audio_waveform.deepQuerySelector("audio");
621
+ if(!!midi_audio_audio && midi_audio_audio_inited!==midi_audio_audio){
622
+ midi_visualizer.bindAudioPlayer(midi_audio_audio)
623
+ midi_audio_audio_inited = midi_audio_audio
624
+ }
625
+ }
626
+ }
627
+ });
628
+ }
629
+
630
+ let midi_visualizers = []
631
+ for (let i = 0; i < MIDI_OUTPUT_BATCH_SIZE ; i++){
632
+ let midi_visualizer = document.createElement('midi-visualizer');
633
+ midi_visualizers.push(midi_visualizer);
634
+ midi_visualizer_setup(i, midi_visualizer)
635
+ }
636
+
637
+ let hasProgressBar = false;
638
+ let output_tabs_inited = null;
639
+ onUiUpdate((m)=>{
640
+ let app = gradioApp()
641
+ let output_tabs = app.querySelector("#output_tabs");
642
+ if(!!output_tabs && output_tabs_inited!== output_tabs){
643
+ output_tabs_inited = output_tabs;
644
+ }
645
+ });
646
+
647
+ function createProgressBar(progressbarContainer){
648
+ let parentProgressbar = progressbarContainer.parentNode;
649
+ let divProgress = document.createElement('div');
650
+ divProgress.className='progressDiv';
651
+ let rect = progressbarContainer.getBoundingClientRect();
652
+ divProgress.style.width = rect.width + "px";
653
+ divProgress.style.background = "#b4c0cc";
654
+ divProgress.style.borderRadius = "8px";
655
+ let divInner = document.createElement('div');
656
+ divInner.className='progress';
657
+ divInner.style.color = "white";
658
+ divInner.style.background = "#0060df";
659
+ divInner.style.textAlign = "right";
660
+ divInner.style.fontWeight = "bold";
661
+ divInner.style.borderRadius = "8px";
662
+ divInner.style.height = "20px";
663
+ divInner.style.lineHeight = "20px";
664
+ divInner.style.paddingRight = "8px"
665
+ divInner.style.width = "0%";
666
+ divProgress.appendChild(divInner);
667
+ parentProgressbar.insertBefore(divProgress, progressbarContainer);
668
+ hasProgressBar = true;
669
+ }
670
+
671
+ function removeProgressBar(progressbarContainer){
672
+ let parentProgressbar = progressbarContainer.parentNode;
673
+ let divProgress = parentProgressbar.querySelector(".progressDiv");
674
+ parentProgressbar.removeChild(divProgress);
675
+ hasProgressBar = false;
676
+ }
677
+
678
+ function setProgressBar(progress, total){
679
+ if (!hasProgressBar)
680
+ createProgressBar(output_tabs_inited)
681
+ if (hasProgressBar && total === 0){
682
+ removeProgressBar(output_tabs_inited)
683
+ return
684
+ }
685
+ let parentProgressbar = output_tabs_inited.parentNode;
686
+ // let divProgress = parentProgressbar.querySelector(".progressDiv");
687
+ let divInner = parentProgressbar.querySelector(".progress");
688
+ if(total===0)
689
+ total = 1;
690
+ divInner.style.width = `${(progress/total)*100}%`;
691
+ divInner.textContent = `${progress}/${total}`;
692
+ }
693
+
694
+ onMsgReceive((msgs)=>{
695
+ for(let msg of msgs){
696
+ if(msg instanceof Array){
697
+ msg.forEach((o)=>{handleMsg(o)});
698
+ }else{
699
+ handleMsg(msg);
700
+ }
701
+ }
702
+ })
703
+ function handleMsg(msg){
704
+ let idx;
705
+ switch (msg.name) {
706
+ case "visualizer_clear":
707
+ idx = msg.data[0];
708
+ let ver = msg.data[1];
709
+ midi_visualizers[idx].clearMidiEvents(false);
710
+ midi_visualizers[idx].version = ver;
711
+ break;
712
+ case "visualizer_append":
713
+ idx = msg.data[0];
714
+ let events = msg.data[1];
715
+ events.forEach( value => {
716
+ midi_visualizers[idx].appendMidiEvent(value);
717
+ })
718
+ break;
719
+ case "visualizer_end":
720
+ idx = msg.data;
721
+ midi_visualizers[idx].finishAppendMidiEvent()
722
+ midi_visualizers[idx].setPlayTime(0);
723
+ break;
724
+ case "progress":
725
+ let progress = msg.data[0]
726
+ let total = msg.data[1]
727
+ setProgressBar(progress, total)
728
+ break;
729
+ default:
730
+ }
731
+ }
732
+ })();
midi_model.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from typing import Union, Dict, Any
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import tqdm
9
+ from peft import PeftConfig, LoraModel, load_peft_weights, set_peft_model_state_dict
10
+ from transformers import LlamaModel, LlamaConfig, DynamicCache, PretrainedConfig, PreTrainedModel
11
+
12
+ from midi_tokenizer import MIDITokenizerV1, MIDITokenizerV2, MIDITokenizer
13
+
14
+ config_name_list = ["tv1-medium", "tv2-medium", "tv2o-medium", "tv2-large", "tv2o-large"]
15
+
16
+
17
class MIDIModelConfig(PretrainedConfig):
    """HF config bundling the MIDI tokenizer with the two LLaMA sub-network
    configs: ``net_config`` for the event-sequence model and
    ``net_token_config`` for the smaller per-event token model.
    """
    model_type = "midi_model"

    def __init__(self,
                 tokenizer: Union[MIDITokenizerV1, MIDITokenizerV2, Dict]=None,
                 net_config: Union[LlamaConfig, Dict]=None,
                 net_token_config: Union[LlamaConfig, Dict]=None,
                 **kwargs):
        """Each argument may be a live object or its dict form (as produced by
        ``to_dict`` when reloading via ``from_pretrained``); dicts are rebuilt
        into objects, missing values get defaults.
        """
        super().__init__(**kwargs)
        if tokenizer:
            if isinstance(tokenizer, dict):
                # Rebuild the tokenizer from its serialized form.
                self.tokenizer = MIDITokenizer(tokenizer["version"])
                self.tokenizer.set_optimise_midi(tokenizer["optimise_midi"])
            else:
                self.tokenizer = tokenizer
        else:
            self.tokenizer = MIDITokenizer()
        if net_config:
            if isinstance(net_config, dict):
                self.net_config = LlamaConfig(**net_config)
            else:
                self.net_config = net_config
        else:
            self.net_config = LlamaConfig()
        if net_token_config:
            if isinstance(net_token_config, dict):
                self.net_token_config = LlamaConfig(**net_token_config)
            else:
                self.net_token_config = net_token_config
        else:
            self.net_token_config = LlamaConfig()
        # Embedding width shared by the token net and the model's lm_head.
        self.n_embd = self.net_token_config.hidden_size

    def to_dict(self) -> Dict[str, Any]:
        """Serialize; the live tokenizer object is replaced by its dict form."""
        d = super().to_dict()
        d["tokenizer"] = self.tokenizer.to_dict()
        return d

    def __str__(self):
        # NOTE(review): only the two net configs are rendered here; the
        # tokenizer is intentionally omitted from the printable summary.
        d = {
            "net": self.net_config.to_json_string(use_diff=False),
            "net_token": self.net_token_config.to_json_string(use_diff=False)
        }
        return json.dumps(d, indent=4)

    @staticmethod
    def get_config(tokenizer_ver="v2", optimise_midi=True, n_layer=12, n_head=16, n_embd=1024, n_inner=4096):
        """Build a config pair: a full-size event net and a token net scaled
        down to a quarter of its depth/heads/FFN width.
        """
        tokenizer = MIDITokenizer(tokenizer_ver)
        tokenizer.set_optimise_midi(optimise_midi)
        net_config = LlamaConfig(vocab_size=tokenizer.vocab_size,
                                 hidden_size=n_embd, num_attention_heads=n_head,
                                 num_hidden_layers=n_layer, intermediate_size=n_inner,
                                 pad_token_id=tokenizer.pad_id, max_position_embeddings=4096,
                                 use_cache=False)
        net_token_config = LlamaConfig(vocab_size=tokenizer.vocab_size,
                                       hidden_size=n_embd, num_attention_heads=n_head // 4,
                                       num_hidden_layers=n_layer // 4, intermediate_size=n_inner // 4,
                                       pad_token_id=tokenizer.pad_id, max_position_embeddings=4096,
                                       use_cache=False)
        return MIDIModelConfig(tokenizer, net_config, net_token_config)

    @staticmethod
    def from_name(name="tv2o-medium"):
        """Parse a preset name like ``tv2o-medium``: ``t<version>[o]-<size>``,
        where the trailing ``o`` enables MIDI optimisation.

        :raises ValueError: on unknown tokenizer version or model size.
        """
        tv, size = name.split("-")
        tv = tv[1:]  # drop the leading "t"
        if tv[-1] == "o":
            o = True
            tv = tv[:-1]
        else:
            o = False
        if tv not in ["v1", "v2"]:
            raise ValueError(f"Unknown tokenizer version {tv}")
        if size == "medium":
            return MIDIModelConfig.get_config(tokenizer_ver=tv, optimise_midi=o,
                                              n_layer=12, n_head=16, n_embd=1024, n_inner=4096)
        elif size == "large":
            return MIDIModelConfig.get_config(tokenizer_ver=tv, optimise_midi=o,
                                              n_layer=24, n_head=16, n_embd=1024, n_inner=4096)
        else:
            raise ValueError(f"Unknown model size {size}")
97
+
98
+
99
class MIDIModel(PreTrainedModel):
    """Two-stage MIDI language model.

    ``net`` runs over the event sequence (each event is the sum of its token
    embeddings); ``net_token`` then autoregressively expands one event hidden
    state into the event's parameter tokens; ``lm_head`` projects to the
    tokenizer vocabulary.
    """
    config_class = MIDIModelConfig

    def __init__(self, config: MIDIModelConfig, *args, **kwargs):
        super(MIDIModel, self).__init__(config, *args, **kwargs)
        self.tokenizer = config.tokenizer
        self.net = LlamaModel(config.net_config)
        self.net_token = LlamaModel(config.net_token_config)
        # Shared head for both stages; bias-free projection to vocab logits.
        self.lm_head = nn.Linear(config.n_embd, self.tokenizer.vocab_size, bias=False)

    def load_merge_lora(self, model_id):
        """Load a LoRA adapter from ``model_id`` and merge it into the base
        weights, returning the merged (adapter-free) model."""
        peft_config = PeftConfig.from_pretrained(model_id)
        model = LoraModel(self, peft_config, adapter_name="default")
        adapter_state_dict = load_peft_weights(model_id, device=str(self.device))
        set_peft_model_state_dict(self, adapter_state_dict, "default")
        return model.merge_and_unload()

    def forward_token(self, hidden_state=None, x=None, cache=None):
        """Run the token-level net over an event hidden state and/or already
        sampled parameter tokens.

        :param hidden_state: (batch_size, n_embd) event state from `forward`,
            or None on cached steps (the cache already holds it)
        :param x: (batch_size, token_sequence_length) sampled tokens, or None
        :param cache: Cache used during incremental decoding
        :return: logits, (batch_size, 1 + token_sequence_length, vocab_size)
        """
        if hidden_state is not None:
            #if you use cache, you don't need to pass in hidden_state
            hidden_state = hidden_state.unsqueeze(1)  # (batch_size, 1, n_embd)
        if x is not None:
            x = self.net_token.embed_tokens(x)
            if hidden_state is not None:
                # Prepend the event state before the token embeddings.
                x = torch.cat([hidden_state, x], dim=1)
            hidden_state = x
        hidden_state = self.net_token.forward(inputs_embeds=hidden_state,
                                              past_key_values=cache,
                                              use_cache=cache is not None).last_hidden_state
        return self.lm_head(hidden_state)

    def forward(self, x, cache = None):
        """
        :param x: (batch_size, midi_sequence_length, token_sequence_length)
        :param cache: Cache
        :return: hidden (batch_size, midi_sequence_length, n_embd)
        """

        # merge token sequence: one embedding per event = sum of its token
        # embeddings along the token axis.
        x = self.net.embed_tokens(x)
        x = x.sum(dim=-2)
        x = self.net.forward(inputs_embeds=x,
                             past_key_values=cache,
                             use_cache=cache is not None)
        return x.last_hidden_state

    def sample_top_p_k(self, probs, p, k, generator=None):
        """Nucleus (top-p) + top-k sampling over the last dimension of
        ``probs``; returns the sampled indices with that dimension removed."""
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        # Drop everything outside the smallest prefix with cumulative mass p.
        mask = probs_sum - probs_sort > p
        probs_sort[mask] = 0.0
        # Keep only the k highest-probability entries.
        mask = torch.zeros(probs_sort.shape[-1], device=probs_sort.device)
        mask[:k] = 1
        probs_sort = probs_sort * mask
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        shape = probs_sort.shape
        next_token = torch.multinomial(probs_sort.reshape(-1, shape[-1]),
                                       num_samples=1, generator=generator).reshape(*shape[:-1], 1)
        # Map positions in the sorted order back to vocabulary ids.
        next_token = torch.gather(probs_idx, -1, next_token).reshape(*shape[:-1])
        return next_token

    @torch.inference_mode()
    def generate(self, prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98, top_k=20, generator=None):
        """Autoregressively generate a batch of token-level MIDI sequences.

        :param prompt: optional numpy prompt of shape (seq, tokens) or
            (batch, seq, tokens); broadcast to ``batch_size`` when needed
        :param batch_size: number of sequences generated in parallel
        :param max_len: maximum event-sequence length (including the prompt)
        :param temp: softmax temperature
        :param top_p: nucleus sampling threshold
        :param top_k: top-k sampling cutoff
        :param generator: optional torch.Generator for reproducibility
        :return: numpy array (batch_size, seq_len, max_token_seq)
        """
        tokenizer = self.tokenizer
        max_token_seq = tokenizer.max_token_seq
        if prompt is None:
            # Start every sequence from a lone BOS event.
            input_tensor = torch.full((1, max_token_seq), tokenizer.pad_id, dtype=torch.long, device=self.device)
            input_tensor[0, 0] = tokenizer.bos_id  # bos
            input_tensor = input_tensor.unsqueeze(0)
            input_tensor = torch.cat([input_tensor] * batch_size, dim=0)
        else:
            if len(prompt.shape) == 2:
                prompt = prompt[None, :]
                prompt = np.repeat(prompt, repeats=batch_size, axis=0)
            elif prompt.shape[0] == 1:
                prompt = np.repeat(prompt, repeats=batch_size, axis=0)
            elif len(prompt.shape) != 3 or prompt.shape[0] != batch_size:
                raise ValueError(f"invalid shape for prompt, {prompt.shape}")
            # Clip/pad the token axis to exactly max_token_seq.
            prompt = prompt[..., :max_token_seq]
            if prompt.shape[-1] < max_token_seq:
                prompt = np.pad(prompt, ((0, 0), (0, 0), (0, max_token_seq - prompt.shape[-1])),
                                mode="constant", constant_values=tokenizer.pad_id)
            input_tensor = torch.from_numpy(prompt).to(dtype=torch.long, device=self.device)

        cur_len = input_tensor.shape[1]
        bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
        cache1 = DynamicCache()  # KV cache for the event-level net
        past_len = 0
        with bar:
            while cur_len < max_len:
                end = [False] * batch_size
                # Only feed events not yet in cache1; keep the last state.
                hidden = self.forward(input_tensor[:, past_len:], cache=cache1)[:, -1]
                next_token_seq = None
                event_names = [""] * batch_size
                cache2 = DynamicCache()  # fresh token-level cache per event
                for i in range(max_token_seq):
                    # Constrain sampling to tokens that are grammatical at
                    # position i of the current event.
                    mask = torch.zeros((batch_size, tokenizer.vocab_size), dtype=torch.int64, device=self.device)
                    for b in range(batch_size):
                        if end[b]:
                            # Finished sequences may only emit padding.
                            mask[b, tokenizer.pad_id] = 1
                            continue
                        if i == 0:
                            # First token: an event name or EOS.
                            mask[b, list(tokenizer.event_ids.values()) + [tokenizer.eos_id]] = 1
                        else:
                            param_names = tokenizer.events[event_names[b]]
                            if i > len(param_names):
                                mask[b, tokenizer.pad_id] = 1
                                continue
                            # Token i must be the (i-1)-th parameter's value.
                            mask[b, tokenizer.parameter_ids[param_names[i - 1]]] = 1
                    mask = mask.unsqueeze(1)
                    x = next_token_seq
                    if i != 0:
                        # cached: cache2 holds the prefix, feed only the
                        # newest token and drop the hidden state input.
                        hidden = None
                        x = x[:, -1:]
                    logits = self.forward_token(hidden, x, cache=cache2)[:, -1:]
                    scores = torch.softmax(logits / temp, dim=-1) * mask
                    samples = self.sample_top_p_k(scores, top_p, top_k, generator=generator)
                    if i == 0:
                        next_token_seq = samples
                        for b in range(batch_size):
                            if end[b]:
                                continue
                            eid = samples[b].item()
                            if eid == tokenizer.eos_id:
                                end[b] = True
                            else:
                                event_names[b] = tokenizer.id_events[eid]
                    else:
                        next_token_seq = torch.cat([next_token_seq, samples], dim=1)
                        # Stop early once every live sequence has emitted all
                        # parameters of its event.
                        if all([len(tokenizer.events[event_names[b]]) == i for b in range(batch_size) if not end[b]]):
                            break

                if next_token_seq.shape[1] < max_token_seq:
                    next_token_seq = F.pad(next_token_seq, (0, max_token_seq - next_token_seq.shape[1]),
                                           "constant", value=tokenizer.pad_id)
                next_token_seq = next_token_seq.unsqueeze(1)
                input_tensor = torch.cat([input_tensor, next_token_seq], dim=1)
                past_len = cur_len
                cur_len += 1
                bar.update(1)

                if all(end):
                    break
        return input_tensor.cpu().numpy()
midi_synthesizer.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from threading import Lock
2
+
3
+ import fluidsynth
4
+ import numpy as np
5
+
6
+
7
class MidiSynthesizer:
    """Render a MIDI "opus" (tracks of delta-time events) to int16 stereo PCM
    using FluidSynth.

    A pool of ``fluidsynth.Synth`` devices is kept so several threads can
    synthesize concurrently; each pool entry is ``[synth, soundfont_id, in_use]``.
    """

    def __init__(self, soundfont_path, sample_rate=44100):
        """
        :param soundfont_path: path to the SoundFont (.sf2) file
        :param sample_rate: output sample rate in Hz
        """
        self.soundfont_path = soundfont_path
        self.sample_rate = sample_rate
        fl = fluidsynth.Synth(samplerate=float(sample_rate))
        sfid = fl.sfload(soundfont_path)
        # Pool entries: [synth, soundfont_id, in_use].
        self.devices = [[fl, sfid, False]]
        self.file_lock = Lock()

    def get_fluidsynth(self):
        """Claim an idle device from the pool, creating one if all are busy.

        The whole scan-and-claim runs under the lock: previously the in_use
        flag was checked and set outside the lock, so two threads could claim
        the same device concurrently.
        """
        with self.file_lock:
            for device in self.devices:
                if not device[2]:
                    device[2] = True
                    return device
            fl = fluidsynth.Synth(samplerate=float(self.sample_rate))
            sfid = fl.sfload(self.soundfont_path)
            device = [fl, sfid, True]
            self.devices.append(device)
            return device

    def release_fluidsynth(self, device):
        """Reset a device and mark it idle again."""
        device[0].system_reset()
        # Drain 5 seconds of audio so release/reverb tails cannot bleed into
        # the next synthesis that reuses this device.
        device[0].get_samples(self.sample_rate * 5)  # wait for silence
        device[2] = False

    def synthesis(self, midi_opus):
        """Render ``midi_opus`` to audio.

        :param midi_opus: ``[ticks_per_beat, track1, track2, ...]`` where each
            track is a list of events with delta-time ticks at index 1
        :return: (num_samples, 2) int16 numpy array, peak-normalised
        """
        ticks_per_beat = midi_opus[0]
        # Flatten all tracks into one list with absolute tick times.
        event_list = []
        for track_idx, track in enumerate(midi_opus[1:]):
            abs_t = 0
            for event in track:
                abs_t += event[1]
                event_new = [*event]
                event_new[1] = abs_t
                event_list.append(event_new)
        event_list = sorted(event_list, key=lambda e: e[1])

        tempo = int((60 / 120) * 10 ** 6)  # default 120 bpm, in us per beat
        ss = np.empty((0, 2), dtype=np.int16)
        device = self.get_fluidsynth()
        fl, sfid = device[:-1]
        last_t = 0
        for c in range(16):
            # Bank 128 is the percussion bank; channel 9 is the GM drum channel.
            fl.program_select(c, sfid, 128 if c == 9 else 0, 0)
        for event in event_list:
            name = event[0]
            # Number of samples elapsed since the previous event at the
            # current tempo.
            sample_len = int(((event[1] / ticks_per_beat) * tempo / (10 ** 6)) * self.sample_rate)
            sample_len -= int(((last_t / ticks_per_beat) * tempo / (10 ** 6)) * self.sample_rate)
            last_t = event[1]
            if sample_len > 0:
                sample = fl.get_samples(sample_len).reshape(sample_len, 2)
                ss = np.concatenate([ss, sample])
            if name == "set_tempo":
                tempo = event[2]
            elif name == "patch_change":
                c, p = event[2:4]
                fl.program_select(c, sfid, 128 if c == 9 else 0, p)
            elif name == "control_change":
                c, cc, v = event[2:5]
                fl.cc(c, cc, v)
            elif name == "note_on" and event[4] > 0:
                # note_on layout: [name, t, channel, pitch, velocity].
                # Fixed: the velocity is event[4]; the original tested
                # event[3] (the pitch), so pitch-0 note_ons were dropped and
                # velocity-0 note_ons were played.
                c, p, v = event[2:5]
                fl.noteon(c, p, v)
            elif name == "note_off" or (name == "note_on" and event[4] == 0):
                # A note_on with velocity 0 is a note_off by MIDI convention.
                c, p = event[2:4]
                fl.noteoff(c, p)

        self.release_fluidsynth(device)
        if ss.shape[0] > 0:
            # Peak-normalise to full int16 range.
            max_val = np.abs(ss).max()
            if max_val != 0:
                ss = (ss / max_val) * np.iinfo(np.int16).max
        ss = ss.astype(np.int16)
        return ss
midi_tokenizer.py ADDED
@@ -0,0 +1,1196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from typing import Dict, Any
3
+
4
+ import PIL.Image
5
+ import numpy as np
6
+
7
+
8
class MIDITokenizerV1:
    """Version-1 MIDI tokenizer.

    Converts between the score representation ``[ticks_per_beat, track, ...]``
    and fixed-width token rows: each event becomes
    ``[event_id, param_token, ..., pad, ...]`` of length ``max_token_seq``.
    Time is quantized to 16 steps per beat and split into a coarse delta
    ("time1", per-beat) and a fine offset ("time2", within-beat).
    """

    def __init__(self):
        self.version = "v1"
        self.optimise_midi = False
        self.vocab_size = 0

        def allocate_ids(size):
            # Reserve `size` consecutive ids at the end of the vocabulary.
            ids = [self.vocab_size + i for i in range(size)]
            self.vocab_size += size
            return ids

        self.pad_id = allocate_ids(1)[0]
        self.bos_id = allocate_ids(1)[0]
        self.eos_id = allocate_ids(1)[0]
        # Event name -> ordered list of its parameter names.
        self.events = {
            "note": ["time1", "time2", "track", "duration", "channel", "pitch", "velocity"],
            "patch_change": ["time1", "time2", "track", "channel", "patch"],
            "control_change": ["time1", "time2", "track", "channel", "controller", "value"],
            "set_tempo": ["time1", "time2", "track", "bpm"],
        }
        # Parameter name -> number of distinct values (= id-range size).
        self.event_parameters = {
            "time1": 128, "time2": 16, "duration": 2048, "track": 128, "channel": 16, "pitch": 128, "velocity": 128,
            "patch": 128, "controller": 128, "value": 128, "bpm": 256
        }
        self.event_ids = {e: allocate_ids(1)[0] for e in self.events.keys()}
        self.id_events = {i: e for e, i in self.event_ids.items()}
        self.parameter_ids = {p: allocate_ids(s) for p, s in self.event_parameters.items()}
        # Longest event (name token + its parameters) fixes the row width.
        self.max_token_seq = max([len(ps) for ps in self.events.values()]) + 1

    def to_dict(self) -> Dict[str, Any]:
        """Serializable description of this tokenizer (used by the model config)."""
        d = {
            "version":self.version,
            "optimise_midi":self.optimise_midi,
            "vocab_size": self.vocab_size,
            "events": self.events,
            "event_parameters": self.event_parameters,
            "max_token_seq": self.max_token_seq,
            "pad_id": self.pad_id,
            "bos_id": self.bos_id,
            "eos_id": self.eos_id,
        }
        return d

    def set_optimise_midi(self, optimise_midi=True):
        # Default for tokenize()'s remap/add-instrument/remove-empty options.
        self.optimise_midi = optimise_midi

    @staticmethod
    def tempo2bpm(tempo):
        tempo = tempo / 10 ** 6  # us to s
        bpm = 60 / tempo
        return bpm

    @staticmethod
    def bpm2tempo(bpm):
        if bpm == 0:
            bpm = 1
        tempo = int((60 / bpm) * 10 ** 6)
        return tempo

    def tokenize(self, midi_score, add_bos_eos=True, cc_eps=4, tempo_eps=4,
                 remap_track_channel=None, add_default_instr=None, remove_empty_channels=None):
        """Tokenize a score into a list of fixed-width token rows.

        :param midi_score: ``[ticks_per_beat, track, ...]`` score data
        :param add_bos_eos: wrap the result in BOS/EOS rows
        :param cc_eps: drop control changes closer than this to the last value
        :param tempo_eps: drop tempo changes closer than this (in bpm)
        :param remap_track_channel: compact track/channel numbering
            (defaults to ``self.optimise_midi``)
        :param add_default_instr: add patch 0 on channels with no patch_change
        :param remove_empty_channels: drop cc/patch events on note-less channels
        """
        if remap_track_channel is None:  # set default value
            remap_track_channel = self.optimise_midi
        if add_default_instr is None:
            add_default_instr = self.optimise_midi
        if remove_empty_channels is None:
            remove_empty_channels = self.optimise_midi

        ticks_per_beat = midi_score[0]
        event_list = {}  # key -> event; the key dedupes identical events
        track_idx_map = {i: dict() for i in range(16)}  # channel -> {old track -> new track}
        track_idx_dict = {}  # channel -> first track that plays notes on it
        channels = []
        patch_channels = []
        empty_channels = [True] * 16  # True until a note is seen on the channel
        channel_note_tracks = {i: list() for i in range(16)}  # channel -> tracks with notes
        for track_idx, track in enumerate(midi_score[1:129]):
            last_notes = {}
            patch_dict = {}
            control_dict = {}
            last_tempo = 0
            for event in track:
                if event[0] not in self.events:
                    continue
                c = -1
                t = round(16 * event[1] / ticks_per_beat)  # quantization
                # Split absolute step t into (time1, time2) and insert track.
                new_event = [event[0], t // 16, t % 16, track_idx] + event[2:]
                if event[0] == "note":
                    c = event[3]
                    if c > 15 or c < 0:
                        continue
                    empty_channels[c] = False
                    track_idx_dict.setdefault(c, track_idx)
                    note_tracks = channel_note_tracks[c]
                    if track_idx not in note_tracks:
                        note_tracks.append(track_idx)
                    # Quantize duration, at least one step.
                    new_event[4] = max(1, round(16 * new_event[4] / ticks_per_beat))
                elif event[0] == "set_tempo":
                    if new_event[4] == 0:  # invalid tempo
                        continue
                    bpm = int(self.tempo2bpm(new_event[4]))
                    new_event[4] = min(bpm, 255)
                # Dedup key: notes ignore velocity (last component), others
                # ignore only their last parameter.
                if event[0] == "note":
                    key = tuple(new_event[:4] + new_event[5:-1])
                else:
                    key = tuple(new_event[:-1])
                if event[0] == "patch_change":
                    c, p = event[2:]
                    if c > 15 or c < 0:
                        continue
                    last_p = patch_dict.setdefault(c, None)
                    if last_p == p:  # drop no-op patch changes
                        continue
                    patch_dict[c] = p
                    if c not in patch_channels:
                        patch_channels.append(c)
                elif event[0] == "control_change":
                    c, cc, v = event[2:]
                    if c > 15 or c < 0:
                        continue
                    last_v = control_dict.setdefault((c, cc), 0)
                    if abs(last_v - v) < cc_eps:  # drop tiny cc jitters
                        continue
                    control_dict[(c, cc)] = v
                elif event[0] == "set_tempo":
                    tempo = new_event[-1]
                    if abs(last_tempo - tempo) < tempo_eps:
                        continue
                    last_tempo = tempo

                if c != -1:
                    if c not in channels:
                        channels.append(c)
                    tr_map = track_idx_map[c]
                    if track_idx not in tr_map:
                        tr_map[track_idx] = 0

                if event[0] == "note":  # to eliminate note overlap due to quantization
                    cp = tuple(new_event[5:7])
                    if cp in last_notes:
                        last_note_key, last_note = last_notes[cp]
                        last_t = last_note[1] * 16 + last_note[2]
                        # Shorten the previous same-channel/pitch note so it
                        # ends no later than this one starts.
                        last_note[4] = max(0, min(last_note[4], t - last_t))
                        if last_note[4] == 0:
                            event_list.pop(last_note_key)
                    last_notes[cp] = (key, new_event)
                event_list[key] = new_event
        event_list = list(event_list.values())

        # Keep only channels that appeared yet never played a note.
        empty_channels = [c for c in channels if empty_channels[c]]

        if remap_track_channel:
            patch_channels = []
            channels_count = 0
            # Channel 9 (drums) keeps its number; others are renumbered
            # compactly, skipping 9.
            channels_map = {9: 9} if 9 in channels else {}
            if remove_empty_channels:
                # Sort empty channels to the end so they get the highest numbers.
                channels = sorted(channels, key=lambda x: 1 if x in empty_channels else 0)
            for c in channels:
                if c == 9:
                    continue
                channels_map[c] = channels_count
                channels_count += 1
                if channels_count == 9:
                    channels_count = 10
            channels = list(channels_map.values())

            track_count = 0
            # Renumber tracks in the order of the remapped channels.
            track_idx_map_order = [k for k, v in sorted(list(channels_map.items()), key=lambda x: x[1])]
            for c in track_idx_map_order:  # tracks not to remove
                if remove_empty_channels and c in empty_channels:
                    continue
                tr_map = track_idx_map[c]
                for track_idx in tr_map:
                    note_tracks = channel_note_tracks[c]
                    if len(note_tracks) != 0 and track_idx not in note_tracks:
                        continue
                    track_count += 1
                    tr_map[track_idx] = track_count
            for c in track_idx_map_order:  # tracks to remove
                if not (remove_empty_channels and c in empty_channels):
                    continue
                tr_map = track_idx_map[c]
                for track_idx in tr_map:
                    note_tracks = channel_note_tracks[c]
                    if not (len(note_tracks) != 0 and track_idx not in note_tracks):
                        continue
                    track_count += 1
                    tr_map[track_idx] = track_count

            empty_channels = [channels_map[c] for c in empty_channels]
            track_idx_dict = {}
            for event in event_list:
                name = event[0]
                track_idx = event[3]
                if name == "note":
                    c = event[5]
                    event[5] = channels_map[c]
                    event[3] = track_idx_map[c][track_idx]
                    track_idx_dict.setdefault(event[5], event[3])
                    # setdefault, so the track_idx is first of the channel
                elif name == "set_tempo":
                    event[3] = 0  # tempo always lives on track 0
                elif name == "control_change" or name == "patch_change":
                    c = event[4]
                    event[4] = channels_map[c]
                    tr_map = track_idx_map[c]
                    # move the event to first track of the channel if its original track is empty
                    note_tracks = channel_note_tracks[c]
                    if len(note_tracks) != 0 and track_idx not in note_tracks:
                        track_idx = channel_note_tracks[c][0]
                    new_track_idx = tr_map[track_idx]
                    event[3] = new_track_idx
                    if name == "patch_change" and event[4] not in patch_channels:
                        patch_channels.append(event[4])

        if add_default_instr:
            # Give every sounding channel without a patch_change the default
            # instrument (patch 0) at time zero.
            for c in channels:
                if c not in patch_channels and c in track_idx_dict:
                    event_list.append(["patch_change", 0, 0, track_idx_dict[c], c, 0])

        # Stable global ordering: by time, then track, then event kind.
        events_name_order = {"set_tempo": 0, "patch_change": 1, "control_change": 2, "note": 3}
        events_order = lambda e: e[1:4] + [events_name_order[e[0]]]
        event_list = sorted(event_list, key=events_order)

        setup_events = {}
        notes_in_setup = False
        for i, event in enumerate(event_list):  # optimise setup
            # Collapse the leading "setup" run (patches, ccs, tempo, and any
            # simultaneous opening notes) onto time zero, deduped per target.
            new_event = [*event]
            if event[0] != "note":
                new_event[1] = 0
                new_event[2] = 0
            has_next = False
            has_pre = False
            if i < len(event_list) - 1:
                next_event = event_list[i + 1]
                has_next = event[1] + event[2] == next_event[1] + next_event[2]
            if notes_in_setup and i > 0:
                pre_event = event_list[i - 1]
                has_pre = event[1] + event[2] == pre_event[1] + pre_event[2]
            if (event[0] == "note" and not has_next) or (notes_in_setup and not has_pre):
                event_list = sorted(setup_events.values(), key=events_order) + event_list[i:]
                break
            else:
                if event[0] == "note":
                    notes_in_setup = True
                    key = tuple([event[0]] + event[3:-2])
                else:
                    key = tuple([event[0]] + event[3:-1])
                setup_events[key] = new_event

        # Emit token rows; time1 becomes a delta between consecutive events.
        last_t1 = 0
        midi_seq = []
        for event in event_list:
            if remove_empty_channels and event[0] in ["control_change", "patch_change"] and event[4] in empty_channels:
                continue
            cur_t1 = event[1]
            event[1] = event[1] - last_t1
            tokens = self.event2tokens(event)
            if not tokens:
                continue
            midi_seq.append(tokens)
            last_t1 = cur_t1

        if add_bos_eos:
            bos = [self.bos_id] + [self.pad_id] * (self.max_token_seq - 1)
            eos = [self.eos_id] + [self.pad_id] * (self.max_token_seq - 1)
            midi_seq = [bos] + midi_seq + [eos]
        return midi_seq

    def event2tokens(self, event):
        """Encode one event (name + parameter values) into a padded token row;
        returns [] if any parameter is out of range."""
        name = event[0]
        params = event[1:]
        if not all([0 <= params[i] < self.event_parameters[p] for i, p in enumerate(self.events[name])]):
            return []
        tokens = [self.event_ids[name]] + [self.parameter_ids[p][params[i]]
                                           for i, p in enumerate(self.events[name])]
        tokens += [self.pad_id] * (self.max_token_seq - len(tokens))
        return tokens

    def tokens2event(self, tokens):
        """Decode one token row back to [name, param, ...]; returns [] for
        malformed or out-of-range rows."""
        if tokens[0] not in self.id_events:
            return []
        name = self.id_events[tokens[0]]
        if len(tokens) <= len(self.events[name]):
            return []
        params = tokens[1:]
        # Shift each token back to its parameter's zero-based value.
        params = [params[i] - self.parameter_ids[p][0] for i, p in enumerate(self.events[name])]
        if not all([0 <= params[i] < self.event_parameters[p] for i, p in enumerate(self.events[name])]):
            return []
        event = [name] + params
        return event

    def detokenize(self, midi_seq):
        """Convert a token sequence back into score form
        ``[ticks_per_beat, track, ...]`` at a fixed 480 ticks per beat."""
        ticks_per_beat = 480
        tracks_dict = {}
        t1 = 0  # running absolute coarse time (time1 deltas accumulated)
        for tokens in midi_seq:
            if tokens[0] in self.id_events:
                event = self.tokens2event(tokens)
                if not event:
                    continue
                name = event[0]
                if name == "set_tempo":
                    event[4] = self.bpm2tempo(event[4])
                if event[0] == "note":
                    # Duration back from 16th-steps to ticks.
                    event[4] = int(event[4] * ticks_per_beat / 16)
                t1 += event[1]
                t = t1 * 16 + event[2]
                t = int(t * ticks_per_beat / 16)
                track_idx = event[3]
                if track_idx not in tracks_dict:
                    tracks_dict[track_idx] = []
                tracks_dict[track_idx].append([event[0], t] + event[4:])
        tracks = [tr for idx, tr in sorted(list(tracks_dict.items()), key=lambda it: it[0])]

        for i in range(len(tracks)):  # to eliminate note overlap
            track = tracks[i]
            track = sorted(track, key=lambda e: e[1])
            last_note_t = {}
            zero_len_notes = []
            # Walk backwards so each note can be clipped against the next
            # note with the same (channel, pitch).
            for e in reversed(track):
                if e[0] == "note":
                    t, d, c, p = e[1:5]
                    key = (c, p)
                    if key in last_note_t:
                        d = min(d, max(last_note_t[key] - t, 0))
                    last_note_t[key] = t
                    e[2] = d
                    if d == 0:
                        zero_len_notes.append(e)
            for e in zero_len_notes:
                track.remove(e)
            tracks[i] = track
        return [ticks_per_beat, *tracks]

    def midi2img(self, midi_score):
        """Render a score as a piano-roll PIL image; each (track, channel)
        pair gets a random color, pitch on the vertical axis."""
        ticks_per_beat = midi_score[0]
        notes = []
        max_time = 1
        track_num = len(midi_score[1:])
        for track_idx, track in enumerate(midi_score[1:]):
            for event in track:
                t = round(16 * event[1] / ticks_per_beat)
                if event[0] == "note":
                    d = max(1, round(16 * event[2] / ticks_per_beat))
                    c, p = event[3:5]
                    max_time = max(max_time, t + d + 1)
                    notes.append((track_idx, c, p, t, d))
        img = np.zeros((128, max_time, 3), dtype=np.uint8)
        colors = {(i, j): np.random.randint(50, 256, 3) for i in range(track_num) for j in range(16)}
        for note in notes:
            tr, c, p, t, d = note
            img[p, t: t + d] = colors[(tr, c)]
        # Flip vertically so higher pitches appear at the top.
        img = PIL.Image.fromarray(np.flip(img, 0))
        return img

    def augment(self, midi_seq, max_pitch_shift=4, max_vel_shift=10, max_cc_val_shift=10, max_bpm_shift=10,
                max_track_shift=0, max_channel_shift=16):
        """Randomly transpose/offset a token sequence for data augmentation.

        One shift per attribute is drawn and applied to the whole sequence.
        If a pitch shift would leave the 0..127 range, the ORIGINAL sequence
        is returned unchanged.
        """
        pitch_shift = random.randint(-max_pitch_shift, max_pitch_shift)
        vel_shift = random.randint(-max_vel_shift, max_vel_shift)
        cc_val_shift = random.randint(-max_cc_val_shift, max_cc_val_shift)
        bpm_shift = random.randint(-max_bpm_shift, max_bpm_shift)
        track_shift = random.randint(0, max_track_shift)
        channel_shift = random.randint(0, max_channel_shift)
        midi_seq_new = []
        for tokens in midi_seq:
            tokens_new = [*tokens]
            if tokens[0] in self.id_events:
                name = self.id_events[tokens[0]]
                for i, pn in enumerate(self.events[name]):
                    if pn == "track":
                        tr = tokens[1 + i] - self.parameter_ids[pn][0]
                        tr += track_shift
                        tr = tr % self.event_parameters[pn]
                        tokens_new[1 + i] = self.parameter_ids[pn][tr]
                    elif pn == "channel":
                        c = tokens[1 + i] - self.parameter_ids[pn][0]
                        c0 = c
                        c += channel_shift
                        c = c % self.event_parameters[pn]
                        # Keep the drum channel (9) fixed; a channel that
                        # would land on 9 is moved past it.
                        if c0 == 9:
                            c = 9
                        elif c == 9:
                            c = (9 + channel_shift) % self.event_parameters[pn]
                        tokens_new[1 + i] = self.parameter_ids[pn][c]

                if name == "note":
                    c = tokens[5] - self.parameter_ids["channel"][0]
                    p = tokens[6] - self.parameter_ids["pitch"][0]
                    v = tokens[7] - self.parameter_ids["velocity"][0]
                    if c != 9:  # no shift for drums
                        p += pitch_shift
                    if not 0 <= p < 128:
                        return midi_seq
                    v += vel_shift
                    v = max(1, min(127, v))
                    tokens_new[6] = self.parameter_ids["pitch"][p]
                    tokens_new[7] = self.parameter_ids["velocity"][v]
                elif name == "control_change":
                    cc = tokens[5] - self.parameter_ids["controller"][0]
                    val = tokens[6] - self.parameter_ids["value"][0]
                    # Only shift continuous controllers (mod wheel, breath,
                    # volume, expression).
                    if cc in [1, 2, 7, 11]:
                        val += cc_val_shift
                        val = max(1, min(127, val))
                        tokens_new[6] = self.parameter_ids["value"][val]
                elif name == "set_tempo":
                    bpm = tokens[4] - self.parameter_ids["bpm"][0]
                    bpm += bpm_shift
                    bpm = max(1, min(255, bpm))
                    tokens_new[4] = self.parameter_ids["bpm"][bpm]
            midi_seq_new.append(tokens_new)
        return midi_seq_new

    def check_quality(self, midi_seq, alignment_min=0.3, tonality_min=0.8, piano_max=0.7, notes_bandwidth_min=3,
                      notes_density_max=50, notes_density_min=2.5, total_notes_max=20000, total_notes_min=256,
                      note_window_size=16):
        """Heuristic dataset-quality filter.

        :return: (passed, reasons) where ``reasons`` lists every failed check:
            total_min/total_max/drum_only, then alignment, tonality,
            bandwidth, density, piano.
        """
        total_notes = 0
        channels = []
        time_hist = [0] * 16  # histogram of within-beat (time2) positions
        note_windows = {}  # coarse-time window -> pitches (non-drum)
        notes_sametime = []  # (end_time, pitch) of currently sounding notes
        notes_density_list = []
        tonality_list = []
        notes_bandwidth_list = []
        instruments = {}
        piano_channels = []
        abs_t1 = 0
        last_t = 0
        for tsi, tokens in enumerate(midi_seq):
            event = self.tokens2event(tokens)
            if not event:
                continue
            t1, t2, tr = event[1:4]
            abs_t1 += t1
            t = abs_t1 * 16 + t2
            c = None
            if event[0] == "note":
                d, c, p, v = event[4:]
                total_notes += 1
                time_hist[t2] += 1
                if c != 9:  # ignore drum channel
                    if c not in instruments:
                        # No patch seen yet: default instrument is piano (0).
                        instruments[c] = 0
                        if c not in piano_channels:
                            piano_channels.append(c)
                    note_windows.setdefault(abs_t1 // note_window_size, []).append(p)
                if last_t != t:
                    # Drop notes that ended before this time step, then
                    # measure the pitch spread of what still sounds.
                    notes_sametime = [(et, p_) for et, p_ in notes_sametime if et > last_t]
                    notes_sametime_p = [p_ for _, p_ in notes_sametime]
                    if len(notes_sametime) > 0:
                        notes_bandwidth_list.append(max(notes_sametime_p) - min(notes_sametime_p))
                notes_sametime.append((t + d - 1, p))
            elif event[0] == "patch_change":
                c, p = event[4:]
                instruments[c] = p
                if p == 0 and c not in piano_channels:
                    piano_channels.append(c)
            if c is not None and c not in channels:
                channels.append(c)
            last_t = t
        reasons = []
        if total_notes < total_notes_min:
            reasons.append("total_min")
        if total_notes > total_notes_max:
            reasons.append("total_max")
        if len(note_windows) == 0 and total_notes > 0:
            reasons.append("drum_only")
        if reasons:
            return False, reasons
        time_hist = sorted(time_hist, reverse=True)
        # Fraction of notes on the two most common within-beat positions.
        alignment = sum(time_hist[:2]) / total_notes
        for notes in note_windows.values():
            key_hist = [0] * 12
            for p in notes:
                key_hist[p % 12] += 1
            key_hist = sorted(key_hist, reverse=True)
            # Fraction of notes in the 7 most common pitch classes (~one key).
            tonality_list.append(sum(key_hist[:7]) / len(notes))
            notes_density_list.append(len(notes) / note_window_size)
        tonality_list = sorted(tonality_list)
        tonality = sum(tonality_list) / len(tonality_list)
        notes_bandwidth = sum(notes_bandwidth_list) / len(notes_bandwidth_list) if notes_bandwidth_list else 0
        notes_density = max(notes_density_list) if notes_density_list else 0
        piano_ratio = len(piano_channels) / len(channels)
        if len(channels) <= 3:  # ignore piano threshold if it is a piano solo midi
            piano_max = 1
        if alignment < alignment_min:  # check whether the notes align to the bars (because some midi files are recorded)
            reasons.append("alignment")
        if tonality < tonality_min:  # check whether the music is tonal
            reasons.append("tonality")
        if notes_bandwidth < notes_bandwidth_min:  # check whether music is melodic line only
            reasons.append("bandwidth")
        if not notes_density_min < notes_density < notes_density_max:
            reasons.append("density")
        if piano_ratio > piano_max:  # check whether most instruments is piano (because some midi files don't have instruments assigned correctly)
            reasons.append("piano")
        return not reasons, reasons
504
+
505
+
506
class MIDITokenizerV2:
    """Tokenizer (vocabulary version "v2") mapping MIDI events to token rows.

    Token-id layout: ids 0-2 are pad/bos/eos, then one id per event type,
    then one contiguous id block per event parameter. A token row is
    [event_id, parameter_ids..., pad...] of length ``max_token_seq``.
    """

    def __init__(self):
        self.version = "v2"
        # When True, tokenize() defaults to remapping tracks/channels, adding
        # default instruments and removing empty channels.
        self.optimise_midi = False
        self.vocab_size = 0

        def allocate_ids(size):
            # Reserve `size` consecutive ids and grow the vocabulary.
            ids = [self.vocab_size + i for i in range(size)]
            self.vocab_size += size
            return ids

        self.pad_id = allocate_ids(1)[0]
        self.bos_id = allocate_ids(1)[0]
        self.eos_id = allocate_ids(1)[0]
        # Parameter order per event type; (time1, time2, track) always lead.
        self.events = {
            "note": ["time1", "time2", "track", "channel", "pitch", "velocity", "duration"],
            "patch_change": ["time1", "time2", "track", "channel", "patch"],
            "control_change": ["time1", "time2", "track", "channel", "controller", "value"],
            "set_tempo": ["time1", "time2", "track", "bpm"],
            "time_signature": ["time1", "time2", "track", "nn", "dd"],
            "key_signature": ["time1", "time2", "track", "sf", "mi"],
        }
        # Value range (id-block size) for each parameter.
        self.event_parameters = {
            "time1": 128, "time2": 16, "duration": 2048, "track": 128, "channel": 16, "pitch": 128, "velocity": 128,
            "patch": 128, "controller": 128, "value": 128, "bpm": 384, "nn": 16, "dd": 4, "sf": 15, "mi": 2
        }
        # NOTE: the allocation order below fixes the vocabulary layout;
        # changing it changes every token id.
        self.event_ids = {e: allocate_ids(1)[0] for e in self.events.keys()}
        self.id_events = {i: e for e, i in self.event_ids.items()}
        self.parameter_ids = {p: allocate_ids(s) for p, s in self.event_parameters.items()}
        # Longest parameter list ("note") plus the leading event-id token.
        self.max_token_seq = max([len(ps) for ps in self.events.values()]) + 1
536
+
537
+ def to_dict(self) -> Dict[str, Any]:
538
+ d = {
539
+ "version":self.version,
540
+ "optimise_midi":self.optimise_midi,
541
+ "vocab_size": self.vocab_size,
542
+ "events": self.events,
543
+ "event_parameters": self.event_parameters,
544
+ "max_token_seq": self.max_token_seq,
545
+ "pad_id": self.pad_id,
546
+ "bos_id": self.bos_id,
547
+ "eos_id": self.eos_id,
548
+ }
549
+ return d
550
+
551
+ def set_optimise_midi(self, optimise_midi=True):
552
+ self.optimise_midi = optimise_midi
553
+
554
+ @staticmethod
555
+ def tempo2bpm(tempo):
556
+ tempo = tempo / 10 ** 6 # us to s
557
+ bpm = 60 / tempo
558
+ return bpm
559
+
560
+ @staticmethod
561
+ def bpm2tempo(bpm):
562
+ if bpm == 0:
563
+ bpm = 1
564
+ tempo = int((60 / bpm) * 10 ** 6)
565
+ return tempo
566
+
567
+ @staticmethod
568
+ def sf2key(sf):
569
+ # sf in key_signature to key.
570
+ # key represents the sequence from C note to B note (12 in total)
571
+ return (sf * 7) % 12
572
+
573
+ @staticmethod
574
+ def key2sf(k, mi):
575
+ # key to sf
576
+ sf = (k * 7) % 12
577
+ if sf > 6 or (mi == 1 and sf >= 5):
578
+ sf -= 12
579
+ return sf
580
+
581
+ @staticmethod
582
+ def detect_key_signature(key_hist, threshold=0.7):
583
+ if len(key_hist) != 12:
584
+ return None
585
+ if sum(key_hist) == 0:
586
+ return None
587
+ p = sum(sorted(key_hist, reverse=True)[:7]) / sum(key_hist)
588
+ if p < threshold:
589
+ return None
590
+ keys = [x[1] for x in sorted(zip(key_hist, range(len(key_hist))), reverse=True, key=lambda x: x[0])[:7]]
591
+ keys = sorted(keys)
592
+ semitones = []
593
+ for i in range(len(keys)):
594
+ dis = keys[i] - keys[i - 1]
595
+ if dis == 1 or dis == -11:
596
+ semitones.append(keys[i])
597
+ if len(semitones) != 2:
598
+ return None
599
+ semitones_dis = semitones[1] - semitones[0]
600
+ if semitones_dis == 5:
601
+ root_key = semitones[0]
602
+ elif semitones_dis == 7:
603
+ root_key = semitones[1]
604
+ else:
605
+ return None
606
+ return root_key
607
+
608
    def tokenize(self, midi_score, add_bos_eos=True, cc_eps=4, tempo_eps=4,
                 remap_track_channel=None, add_default_instr=None, remove_empty_channels=None):
        """Convert a midi_score ([ticks_per_beat, track0, track1, ...]) into token rows.

        Times are quantized to 16 steps per beat and split into
        (time1, time2) = (beat, sixteenth-within-beat); time1 is
        delta-encoded between rows at the very end. Near-duplicate
        control_change / set_tempo events (within cc_eps / tempo_eps) are
        dropped and same-pitch note overlaps caused by quantization are
        clipped. When the optimisation options are enabled (they default to
        ``self.optimise_midi``) tracks and channels are remapped to a compact
        range, a default patch 0 is added to channels without one, and events
        on note-less channels are removed. If the score has no key signature
        (or only the default sf=0 one), one is detected from the pitch
        histogram.

        Returns a list of fixed-length token rows, optionally wrapped in
        BOS/EOS rows.
        """
        if remap_track_channel is None:  # set default value
            remap_track_channel = self.optimise_midi
        if add_default_instr is None:
            add_default_instr = self.optimise_midi
        if remove_empty_channels is None:
            remove_empty_channels = self.optimise_midi

        # --- pass 1: quantize, validate and de-duplicate events per track ---
        ticks_per_beat = midi_score[0]
        event_list = {}
        track_idx_map = {i: dict() for i in range(16)}  # channel -> {orig track -> new track}
        track_idx_dict = {}  # channel -> first track that plays it
        channels = []
        patch_channels = []
        empty_channels = [True] * 16  # True until a note is seen on the channel
        channel_note_tracks = {i: list() for i in range(16)}  # channel -> tracks with notes
        note_key_hist = [0]*12  # pitch-class histogram (non-drum) for key detection
        key_sigs = []
        track_to_channels = {}
        for track_idx, track in enumerate(midi_score[1:129]):
            last_notes = {}
            patch_dict = {}
            control_dict = {}
            last_bpm = 0
            track_channels = []
            track_to_channels.setdefault(track_idx, track_channels)
            for event in track:
                if event[0] not in self.events:
                    continue
                name = event[0]
                c = -1
                t = round(16 * event[1] / ticks_per_beat)  # quantization
                new_event = [name, t // 16, t % 16, track_idx]
                if name == "note":
                    d, c, p, v = event[2:]
                    if not (0 <= c <= 15):
                        continue
                    d = max(1, round(16 * d / ticks_per_beat))
                    new_event += [c, p, v, d]
                    empty_channels[c] = False
                    track_idx_dict.setdefault(c, track_idx)
                    note_tracks = channel_note_tracks[c]
                    if track_idx not in note_tracks:
                        note_tracks.append(track_idx)
                    if c != 9:
                        note_key_hist[p%12] += 1
                    if c not in track_channels:
                        track_channels.append(c)
                elif name == "patch_change":
                    c, p = event[2:]
                    if not (0 <= c <= 15):
                        continue
                    new_event += [c, p]
                    last_p = patch_dict.setdefault(c, None)
                    if last_p == p:  # drop repeats of the same patch
                        continue
                    patch_dict[c] = p
                    if c not in patch_channels:
                        patch_channels.append(c)
                elif name == "control_change":
                    c, cc, v = event[2:]
                    if not (0 <= c <= 15):
                        continue
                    new_event += [c, cc, v]
                    last_v = control_dict.setdefault((c, cc), 0)
                    if abs(last_v - v) < cc_eps:  # drop near-duplicate values
                        continue
                    control_dict[(c, cc)] = v
                elif name == "set_tempo":
                    tempo = event[2]
                    if tempo == 0:  # invalid tempo
                        continue
                    bpm = min(int(self.tempo2bpm(tempo)), 383)
                    new_event += [bpm]
                    if abs(last_bpm - bpm) < tempo_eps:  # drop near-duplicate tempi
                        continue
                    last_bpm = bpm
                elif name == "time_signature":
                    nn, dd = event[2:4]
                    if not (1 <= nn <= 16 and 1 <= dd <= 4):  # invalid
                        continue
                    nn -= 1  # make it start from 0
                    dd -= 1
                    new_event += [nn, dd]
                elif name == "key_signature":
                    sf, mi = event[2:]
                    if not (-7 <= sf <= 7 and 0 <= mi <= 1):  # invalid
                        continue
                    sf += 7  # store as non-negative token value
                    new_event += [sf, mi]
                    key_sigs.append(new_event)

                # Dedup key ignores the trailing value fields so a later event
                # at the same position overwrites the earlier one.
                if name in ["note", "time_signature", "key_signature"]:
                    key = tuple(new_event[:-2])
                else:
                    key = tuple(new_event[:-1])

                if c != -1:
                    if c not in channels:
                        channels.append(c)
                    tr_map = track_idx_map[c]
                    if track_idx not in tr_map:
                        tr_map[track_idx] = 0

                if event[0] == "note":  # to eliminate note overlap due to quantization
                    cp = tuple(new_event[4:6])  # channel pitch
                    if cp in last_notes:
                        last_note_key, last_note = last_notes[cp]
                        last_t = last_note[1] * 16 + last_note[2]
                        last_note[-1] = max(0, min(last_note[-1], t - last_t))  # modify duration
                        if last_note[-1] == 0:
                            event_list.pop(last_note_key)
                    last_notes[cp] = (key, new_event)
                event_list[key] = new_event
        event_list = list(event_list.values())

        # Channels that were declared (patch/cc) but never play a note.
        empty_channels = [c for c in channels if empty_channels[c]]

        # --- optional pass: remap tracks/channels to a compact contiguous range ---
        if remap_track_channel:
            patch_channels = []
            channels_count = 0
            channels_map = {9: 9} if 9 in channels else {}  # drums stay on channel 9
            if remove_empty_channels:
                # Push empty channels to the end so they get the highest ids.
                channels = sorted(channels, key=lambda x: 1 if x in empty_channels else 0)
            for c in channels:
                if c == 9:
                    continue
                channels_map[c] = channels_count
                channels_count += 1
                if channels_count == 9:  # skip the reserved drum channel
                    channels_count = 10
            channels = list(channels_map.values())

            track_count = 0
            track_idx_map_order = [k for k, v in sorted(list(channels_map.items()), key=lambda x: x[1])]
            for c in track_idx_map_order:  # tracks not to remove
                if remove_empty_channels and c in empty_channels:
                    continue
                tr_map = track_idx_map[c]
                for track_idx in tr_map:
                    note_tracks = channel_note_tracks[c]
                    if len(note_tracks) != 0 and track_idx not in note_tracks:
                        continue
                    track_count += 1
                    tr_map[track_idx] = track_count
            for c in track_idx_map_order:  # tracks to remove
                if not (remove_empty_channels and c in empty_channels):
                    continue
                tr_map = track_idx_map[c]
                for track_idx in tr_map:
                    note_tracks = channel_note_tracks[c]
                    if not (len(note_tracks) != 0 and track_idx not in note_tracks):
                        continue
                    track_count += 1
                    tr_map[track_idx] = track_count

            empty_channels = [channels_map[c] for c in empty_channels]
            track_idx_dict = {}
            key_sigs = []
            key_signature_to_add = []
            key_signature_to_remove = []
            for event in event_list:
                name = event[0]
                track_idx = event[3]
                if name == "note":
                    c = event[4]
                    event[4] = channels_map[c]  # channel
                    event[3] = track_idx_map[c][track_idx]  # track
                    track_idx_dict.setdefault(event[4], event[3])
                    # setdefault, so the track_idx is first of the channel
                elif name in ["set_tempo", "time_signature"]:
                    event[3] = 0  # set track 0 for meta events
                elif name == "key_signature":
                    # Duplicate the key signature onto every remapped track
                    # that inherited events from its original track.
                    new_channel_track_idxs = []
                    for c, tr_map in track_idx_map.items():
                        if track_idx in tr_map:
                            new_track_idx = tr_map[track_idx]
                            c = channels_map[c]
                            new_channel_track_idx = (c, new_track_idx)
                            if new_track_idx == 0:
                                continue
                            if new_channel_track_idx not in new_channel_track_idxs:
                                new_channel_track_idxs.append(new_channel_track_idx)

                    if len(new_channel_track_idxs) == 0:
                        if event[3] == 0:  # keep key_signature on track 0 (meta)
                            key_sigs.append(event)
                            continue
                        event[3] = -1  # avoid remove same event
                        key_signature_to_remove.append(event)  # empty track
                        continue
                    c, nt = new_channel_track_idxs[0]
                    event[3] = nt
                    key_sigs.append(event)
                    if c == 9:
                        event[4] = 7  # sf=0
                    for c, nt in new_channel_track_idxs[1:]:
                        new_event = [*event]
                        new_event[3] = nt
                        if c == 9:
                            new_event[4] = 7  # sf=0
                        key_sigs.append(new_event)
                        key_signature_to_add.append(new_event)
                elif name == "control_change" or name == "patch_change":
                    c = event[4]
                    event[4] = channels_map[c]  # channel
                    tr_map = track_idx_map[c]
                    # move the event to first track of the channel if its original track is empty
                    note_tracks = channel_note_tracks[c]
                    if len(note_tracks) != 0 and track_idx not in note_tracks:
                        track_idx = channel_note_tracks[c][0]
                    new_track_idx = tr_map[track_idx]
                    event[3] = new_track_idx
                    if name == "patch_change" and event[4] not in patch_channels:
                        patch_channels.append(event[4])
            for key_sig in key_signature_to_remove:
                event_list.remove(key_sig)
            event_list += key_signature_to_add
            # Rebuild the track -> channels index in the remapped numbering.
            track_to_channels ={}
            for c, tr_map in track_idx_map.items():
                if c not in channels_map:
                    continue
                c = channels_map[c]
                for _, track_idx in tr_map.items():
                    track_to_channels.setdefault(track_idx, [])
                    cs = track_to_channels[track_idx]
                    if c not in cs:
                        cs.append(c)

        # Give every sounding channel an explicit instrument (patch 0).
        if add_default_instr:
            for c in channels:
                if c not in patch_channels and c in track_idx_dict:
                    event_list.append(["patch_change", 0, 0, track_idx_dict[c], c, 0])

        if len(key_sigs) == 0 or all([key_sig[4]==7 for key_sig in key_sigs]):
            # detect key signature or fix the default key signature
            root_key = self.detect_key_signature(note_key_hist)
            if root_key is not None:
                sf = self.key2sf(root_key, 0)
                # print("detect_key_signature",sf)
                if len(key_sigs) == 0:
                    for tr, cs in track_to_channels.items():
                        if remap_track_channel and tr == 0:
                            continue
                        # Drum-only tracks keep sf=0; others get the detected key.
                        new_event = ["key_signature", 0, 0, tr, (0 if (len(cs) == 1 and cs[0] == 9) else sf) + 7, 0]
                        event_list.append(new_event)
                else:
                    for key_sig in key_sigs:
                        tr = key_sig[3]
                        if tr in track_to_channels:
                            cs = track_to_channels[tr]
                            if len(cs) == 1 and cs[0] == 9:
                                continue
                        key_sig[4] = sf + 7
                        key_sig[5] = 0
            else:
                # remove default key signature
                for key_sig in key_sigs:
                    event_list.remove(key_sig)

        # Stable ordering: by time, then track, then meta events before notes.
        events_name_order = ["time_signature", "key_signature", "set_tempo", "patch_change", "control_change", "note"]
        events_name_order = {name: i for i, name in enumerate(events_name_order)}
        events_order = lambda e: e[1:4] + [events_name_order[e[0]]]
        event_list = sorted(event_list, key=events_order)

        # Collapse the leading setup region (everything before the first real
        # note onset gap) to time 0, keeping only the last value per key.
        setup_events = {}
        notes_in_setup = False
        for i, event in enumerate(event_list):  # optimise setup
            new_event = [*event]  # make copy of event
            if event[0] not in ["note", "time_signature"]:
                new_event[1] = 0
                new_event[2] = 0
            has_next = False
            has_pre = False
            if i < len(event_list) - 1:
                next_event = event_list[i + 1]
                has_next = event[1] + event[2] == next_event[1] + next_event[2]
            if notes_in_setup and i > 0:
                pre_event = event_list[i - 1]
                has_pre = event[1] + event[2] == pre_event[1] + pre_event[2]
            if (event[0] == "note" and not has_next) or (notes_in_setup and not has_pre):
                event_list = sorted(setup_events.values(), key=events_order) + event_list[i:]
                break
            else:
                if event[0] == "note":
                    notes_in_setup = True
                if event[0] in ["note", "time_signature", "key_signature"]:
                    key = tuple([event[0]]+event[3:-2])
                else:
                    key = tuple([event[0]]+event[3:-1])
                setup_events[key] = new_event

        # Delta-encode time1 and emit the token rows.
        last_t1 = 0
        midi_seq = []
        for event in event_list:
            if remove_empty_channels and event[0] in ["control_change", "patch_change"] and event[4] in empty_channels:
                continue
            cur_t1 = event[1]
            event[1] = event[1] - last_t1
            tokens = self.event2tokens(event)
            if not tokens:
                continue
            midi_seq.append(tokens)
            last_t1 = cur_t1

        if add_bos_eos:
            bos = [self.bos_id] + [self.pad_id] * (self.max_token_seq - 1)
            eos = [self.eos_id] + [self.pad_id] * (self.max_token_seq - 1)
            midi_seq = [bos] + midi_seq + [eos]
        return midi_seq
919
+
920
+ def event2tokens(self, event):
921
+ name = event[0]
922
+ params = event[1:]
923
+ if not all([0 <= params[i] < self.event_parameters[p] for i, p in enumerate(self.events[name])]):
924
+ return []
925
+ tokens = [self.event_ids[name]] + [self.parameter_ids[p][params[i]]
926
+ for i, p in enumerate(self.events[name])]
927
+ tokens += [self.pad_id] * (self.max_token_seq - len(tokens))
928
+ return tokens
929
+
930
+ def tokens2event(self, tokens):
931
+ if tokens[0] not in self.id_events:
932
+ return []
933
+ name = self.id_events[tokens[0]]
934
+ if len(tokens) <= len(self.events[name]):
935
+ return []
936
+ params = tokens[1:]
937
+ params = [params[i] - self.parameter_ids[p][0] for i, p in enumerate(self.events[name])]
938
+ if not all([0 <= params[i] < self.event_parameters[p] for i, p in enumerate(self.events[name])]):
939
+ return []
940
+ event = [name] + params
941
+ return event
942
+
943
    def detokenize(self, midi_seq):
        """Invert tokenize(): rebuild a midi_score ([ticks_per_beat, track0, ...]).

        Emits a fixed resolution of 480 ticks per beat. time1 is accumulated
        (it is delta-encoded in the token rows), events are grouped into
        tracks by their track index, and overlapping same-(channel, pitch)
        notes within a track are clipped; notes clipped to zero length are
        dropped. Invalid rows are skipped.
        """
        ticks_per_beat = 480
        tracks_dict = {}
        t1 = 0  # running absolute beat (time1 is delta-encoded)
        for tokens in midi_seq:
            if tokens[0] in self.id_events:
                event = self.tokens2event(tokens)
                if not event:
                    continue
                name = event[0]
                t1 += event[1]
                t = t1 * 16 + event[2]
                t = int(t * ticks_per_beat / 16)  # sixteenth steps -> ticks
                track_idx = event[3]
                event_new = [name, t]
                if name == "note":
                    c, p, v, d = event[4:]
                    d = int(d * ticks_per_beat / 16)
                    event_new += [d, c, p, v]
                elif name == "control_change" or name == "patch_change":
                    event_new += event[4:]
                elif name == "set_tempo":
                    event_new += [self.bpm2tempo(event[4])]
                elif name == "time_signature":
                    nn, dd = event[4:]
                    nn += 1  # undo the 0-based storage offset
                    dd += 1
                    event_new += [nn, dd, 24, 8]  # usually cc, bb = 24, 8
                elif name == "key_signature":
                    sf, mi = event[4:]
                    sf -= 7  # undo the +7 storage offset
                    event_new += [sf, mi]
                else:  # should not go here
                    continue
                if track_idx not in tracks_dict:
                    tracks_dict[track_idx] = []
                tracks_dict[track_idx].append(event_new)
        tracks = [tr for idx, tr in sorted(list(tracks_dict.items()), key=lambda it: it[0])]

        for i in range(len(tracks)):  # to eliminate note overlap
            track = tracks[i]
            track = sorted(track, key=lambda e: e[1])
            last_note_t = {}  # (channel, pitch) -> onset of the next-later note
            zero_len_notes = []
            # Walk backwards so each note sees the onset of its successor.
            for e in reversed(track):
                if e[0] == "note":
                    t, d, c, p = e[1:5]
                    key = (c, p)
                    if key in last_note_t:
                        d = min(d, max(last_note_t[key] - t, 0))
                    last_note_t[key] = t
                    e[2] = d
                    if d == 0:
                        zero_len_notes.append(e)
            for e in zero_len_notes:
                track.remove(e)
            tracks[i] = track
        return [ticks_per_beat, *tracks]
1001
+
1002
+ def midi2img(self, midi_score):
1003
+ ticks_per_beat = midi_score[0]
1004
+ notes = []
1005
+ max_time = 1
1006
+ track_num = len(midi_score[1:])
1007
+ for track_idx, track in enumerate(midi_score[1:]):
1008
+ for event in track:
1009
+ t = round(16 * event[1] / ticks_per_beat)
1010
+ if event[0] == "note":
1011
+ d = max(1, round(16 * event[2] / ticks_per_beat))
1012
+ c, p = event[3:5]
1013
+ max_time = max(max_time, t + d + 1)
1014
+ notes.append((track_idx, c, p, t, d))
1015
+ img = np.zeros((128, max_time, 3), dtype=np.uint8)
1016
+ colors = {(i, j): np.random.randint(50, 256, 3) for i in range(track_num) for j in range(16)}
1017
+ for note in notes:
1018
+ tr, c, p, t, d = note
1019
+ img[p, t: t + d] = colors[(tr, c)]
1020
+ img = PIL.Image.fromarray(np.flip(img, 0))
1021
+ return img
1022
+
1023
    def augment(self, midi_seq, max_pitch_shift=4, max_vel_shift=10, max_cc_val_shift=10, max_bpm_shift=10,
                max_track_shift=0, max_channel_shift=16):
        """Randomly transpose/perturb a token sequence for data augmentation.

        Draws one random shift each for pitch, velocity, cc value, bpm, track
        and channel, then applies them to every row. The drum channel 9 is
        kept fixed: its pitches are not shifted and the channel rotation maps
        around it. Key signatures are transposed to follow the pitch shift,
        except on tracks that play only drums, which are reset to sf=0.
        If any shifted pitch would leave 0..127 the ORIGINAL sequence is
        returned unchanged.
        """
        pitch_shift = random.randint(-max_pitch_shift, max_pitch_shift)
        vel_shift = random.randint(-max_vel_shift, max_vel_shift)
        cc_val_shift = random.randint(-max_cc_val_shift, max_cc_val_shift)
        bpm_shift = random.randint(-max_bpm_shift, max_bpm_shift)
        track_shift = random.randint(0, max_track_shift)
        channel_shift = random.randint(0, max_channel_shift)
        midi_seq_new = []
        key_signature_tokens = []
        track_to_channels = {}  # track -> channels it plays (for the drum fix-up below)
        for tokens in midi_seq:
            tokens_new = [*tokens]
            if tokens[0] in self.id_events:
                name = self.id_events[tokens[0]]
                # Rotate track/channel ids for every event type that has them.
                for i, pn in enumerate(self.events[name]):
                    if pn == "track":
                        tr = tokens[1 + i] - self.parameter_ids[pn][0]
                        tr += track_shift
                        tr = tr % self.event_parameters[pn]
                        tokens_new[1 + i] = self.parameter_ids[pn][tr]
                    elif pn == "channel":
                        c = tokens[1 + i] - self.parameter_ids[pn][0]
                        c0 = c
                        c += channel_shift
                        c = c % self.event_parameters[pn]
                        # Keep drums on 9; whichever channel would land on 9
                        # takes the slot drums would have moved to.
                        if c0 == 9:
                            c = 9
                        elif c == 9:
                            c = (9 + channel_shift) % self.event_parameters[pn]
                        tokens_new[1 + i] = self.parameter_ids[pn][c]

                if name == "note":
                    tr = tokens[3] - self.parameter_ids["track"][0]
                    c = tokens[4] - self.parameter_ids["channel"][0]
                    p = tokens[5] - self.parameter_ids["pitch"][0]
                    v = tokens[6] - self.parameter_ids["velocity"][0]
                    if c != 9:  # no shift for drums
                        p += pitch_shift
                    if not 0 <= p < 128:
                        # Transposition left the MIDI range: abort augmentation.
                        return midi_seq
                    v += vel_shift
                    v = max(1, min(127, v))
                    tokens_new[5] = self.parameter_ids["pitch"][p]
                    tokens_new[6] = self.parameter_ids["velocity"][v]
                    track_to_channels.setdefault(tr, [])
                    cs = track_to_channels[tr]
                    if c not in cs:
                        cs.append(c)
                elif name == "control_change":
                    cc = tokens[5] - self.parameter_ids["controller"][0]
                    val = tokens[6] - self.parameter_ids["value"][0]
                    if cc in [1, 2, 7, 11]:
                        val += cc_val_shift
                        val = max(1, min(127, val))
                        tokens_new[6] = self.parameter_ids["value"][val]
                elif name == "set_tempo":
                    bpm = tokens[4] - self.parameter_ids["bpm"][0]
                    bpm += bpm_shift
                    bpm = max(1, min(383, bpm))
                    tokens_new[4] = self.parameter_ids["bpm"][bpm]
                elif name == "key_signature":
                    # Transpose the key to follow the pitch shift.
                    sf = tokens[4] - self.parameter_ids["sf"][0]
                    mi = tokens[5] - self.parameter_ids["mi"][0]
                    sf -= 7
                    k = self.sf2key(sf)
                    k = (k + pitch_shift) % 12
                    sf = self.key2sf(k, mi)
                    sf += 7
                    tokens_new[4] = self.parameter_ids["sf"][sf]
                    tokens_new[5] = self.parameter_ids["mi"][mi]
                    key_signature_tokens.append(tokens_new)
            midi_seq_new.append(tokens_new)
        # Drum-only tracks must not carry a transposed key signature.
        for tokens in key_signature_tokens:
            tr = tokens[3] - self.parameter_ids["track"][0]
            if tr in track_to_channels:
                cs = track_to_channels[tr]
                if len(cs) == 1 and cs[0] == 9:
                    tokens[4] = self.parameter_ids["sf"][7]  # sf=0
        return midi_seq_new
1103
+
1104
+ def check_quality(self, midi_seq, alignment_min=0.3, tonality_min=0.8, piano_max=0.7, notes_bandwidth_min=3,
1105
+ notes_density_max=50, notes_density_min=2.5, total_notes_max=20000, total_notes_min=256,
1106
+ note_window_size=16):
1107
+ total_notes = 0
1108
+ channels = []
1109
+ time_hist = [0] * 16
1110
+ note_windows = {}
1111
+ notes_sametime = []
1112
+ notes_density_list = []
1113
+ tonality_list = []
1114
+ notes_bandwidth_list = []
1115
+ instruments = {}
1116
+ piano_channels = []
1117
+ abs_t1 = 0
1118
+ last_t = 0
1119
+ for tsi, tokens in enumerate(midi_seq):
1120
+ event = self.tokens2event(tokens)
1121
+ if not event:
1122
+ continue
1123
+ t1, t2, tr = event[1:4]
1124
+ abs_t1 += t1
1125
+ t = abs_t1 * 16 + t2
1126
+ c = None
1127
+ if event[0] == "note":
1128
+ c, p, v, d = event[4:]
1129
+ total_notes += 1
1130
+ time_hist[t2] += 1
1131
+ if c != 9: # ignore drum channel
1132
+ if c not in instruments:
1133
+ instruments[c] = 0
1134
+ if c not in piano_channels:
1135
+ piano_channels.append(c)
1136
+ note_windows.setdefault(abs_t1 // note_window_size, []).append(p)
1137
+ if last_t != t:
1138
+ notes_sametime = [(et, p_) for et, p_ in notes_sametime if et > last_t]
1139
+ notes_sametime_p = [p_ for _, p_ in notes_sametime]
1140
+ if len(notes_sametime) > 0:
1141
+ notes_bandwidth_list.append(max(notes_sametime_p) - min(notes_sametime_p))
1142
+ notes_sametime.append((t + d - 1, p))
1143
+ elif event[0] == "patch_change":
1144
+ c, p = event[4:]
1145
+ instruments[c] = p
1146
+ if p == 0 and c not in piano_channels:
1147
+ piano_channels.append(c)
1148
+ if c is not None and c not in channels:
1149
+ channels.append(c)
1150
+ last_t = t
1151
+ reasons = []
1152
+ if total_notes < total_notes_min:
1153
+ reasons.append("total_min")
1154
+ if total_notes > total_notes_max:
1155
+ reasons.append("total_max")
1156
+ if len(note_windows) == 0 and total_notes > 0:
1157
+ reasons.append("drum_only")
1158
+ if reasons:
1159
+ return False, reasons
1160
+ time_hist = sorted(time_hist, reverse=True)
1161
+ alignment = sum(time_hist[:2]) / total_notes
1162
+ for notes in note_windows.values():
1163
+ key_hist = [0] * 12
1164
+ for p in notes:
1165
+ key_hist[p % 12] += 1
1166
+ key_hist = sorted(key_hist, reverse=True)
1167
+ tonality_list.append(sum(key_hist[:7]) / len(notes))
1168
+ notes_density_list.append(len(notes) / note_window_size)
1169
+ tonality_list = sorted(tonality_list)
1170
+ tonality = sum(tonality_list) / len(tonality_list)
1171
+ notes_bandwidth = sum(notes_bandwidth_list) / len(notes_bandwidth_list) if notes_bandwidth_list else 0
1172
+ notes_density = max(notes_density_list) if notes_density_list else 0
1173
+ piano_ratio = len(piano_channels) / len(channels)
1174
+ if len(channels) <= 3: # ignore piano threshold if it is a piano solo midi
1175
+ piano_max = 1
1176
+ if alignment < alignment_min: # check weather the notes align to the bars (because some midi files are recorded)
1177
+ reasons.append("alignment")
1178
+ if tonality < tonality_min: # check whether the music is tonal
1179
+ reasons.append("tonality")
1180
+ if notes_bandwidth < notes_bandwidth_min: # check whether music is melodic line only
1181
+ reasons.append("bandwidth")
1182
+ if not notes_density_min < notes_density < notes_density_max:
1183
+ reasons.append("density")
1184
+ if piano_ratio > piano_max: # check whether most instruments is piano (because some midi files don't have instruments assigned correctly)
1185
+ reasons.append("piano")
1186
+ return not reasons, reasons
1187
+
1188
+
1189
class MIDITokenizer:
    """Factory returning the concrete tokenizer for a requested vocabulary version."""

    def __new__(cls, version="v2"):
        """Instantiate MIDITokenizerV1 or MIDITokenizerV2 for *version*.

        Raises:
            ValueError: if *version* is neither "v1" nor "v2".
        """
        if version == "v1":
            return MIDITokenizerV1()
        if version == "v2":
            return MIDITokenizerV2()
        raise ValueError(f"Unsupported version: {version}")
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ fluidsynth
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu124
2
+ Pillow
3
+ numpy
4
+ torch
5
+ onnxruntime-gpu
6
+ peft>=0.13.0
7
+ transformers>=4.36
8
+ gradio==5.0.1
9
+ pyfluidsynth
10
+ tqdm
11
+ huggingface_hub