skytnt commited on
Commit
1b97a00
1 Parent(s): 3e02d8a

fix midi visualizer

Browse files
Files changed (2) hide show
  1. app.py +25 -9
  2. javascript/app.js +21 -31
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import argparse
2
  import glob
3
  import os.path
 
4
 
5
  import gradio as gr
6
  import numpy as np
@@ -107,7 +108,7 @@ def generate(model, prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
107
 
108
 
109
  def create_msg(name, data):
110
- return {"name": name, "data": data}
111
 
112
 
113
  def run(model_name, tab, instruments, drum_kit, mid, midi_events, gen_events, temp, top_p, top_k, allow_cc):
@@ -164,7 +165,7 @@ def run(model_name, tab, instruments, drum_kit, mid, midi_events, gen_events, te
164
 
165
  def cancel_run(mid_seq):
166
  if mid_seq is None:
167
- return None, None
168
  mid = tokenizer.detokenize(mid_seq)
169
  with open(f"output.mid", 'wb') as f:
170
  f.write(MIDI.score2midi(mid))
@@ -189,17 +190,25 @@ def load_javascript(dir="javascript"):
189
 
190
  gr.routes.templates.TemplateResponse = template_response
191
 
 
192
  # JSMsgReceiver
193
- HTML_postprocess_ori = gr.HTML.postprocess
 
 
194
 
195
 
 
196
  def JSMsgReceiver_postprocess(self, y):
 
197
  if self.elem_id == "msg_receiver" and y:
198
- y = f"<p>{json.dumps(y)}</p>"
199
- return HTML_postprocess_ori(self, y)
 
 
 
200
 
201
 
202
- gr.HTML.postprocess = JSMsgReceiver_postprocess
203
 
204
  number2drum_kits = {-1: "None", 0: "Standard", 8: "Room", 16: "Power", 24: "Electric", 25: "TR-808", 32: "Jazz",
205
  40: "Blush", 48: "Orchestra"}
@@ -214,8 +223,8 @@ if __name__ == "__main__":
214
  opt = parser.parse_args()
215
  soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
216
  models_info = {"generic pretrain model": ["skytnt/midi-model", ""],
217
- "j-pop finetune model": ["skytnt/midi-model-ft", "jpop/"],
218
- "touhou finetune model": ["skytnt/midi-model-ft", "touhou/"],
219
  }
220
  models = {}
221
  tokenizer = MIDITokenizer()
@@ -238,7 +247,14 @@ if __name__ == "__main__":
238
  "(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
239
  " for faster running and longer generation"
240
  )
241
- js_msg = gr.HTML(elem_id="msg_receiver", visible=False)
 
 
 
 
 
 
 
242
  input_model = gr.Dropdown(label="select model", choices=list(models.keys()),
243
  type="value", value=list(models.keys())[0])
244
  tab_select = gr.State(value=0)
 
1
  import argparse
2
  import glob
3
  import os.path
4
+ import uuid
5
 
6
  import gradio as gr
7
  import numpy as np
 
108
 
109
 
110
  def create_msg(name, data):
111
+ return {"name": name, "data": data, "uuid": uuid.uuid4().hex}
112
 
113
 
114
  def run(model_name, tab, instruments, drum_kit, mid, midi_events, gen_events, temp, top_p, top_k, allow_cc):
 
165
 
166
  def cancel_run(mid_seq):
167
  if mid_seq is None:
168
+ return None, None, []
169
  mid = tokenizer.detokenize(mid_seq)
170
  with open(f"output.mid", 'wb') as f:
171
  f.write(MIDI.score2midi(mid))
 
190
 
191
  gr.routes.templates.TemplateResponse = template_response
192
 
193
+
194
  # JSMsgReceiver
195
+ Textbox_postprocess_ori = gr.Textbox.postprocess
196
+
197
+ msg_history = []
198
 
199
 
200
+ # the change event may not trigger every time, so send msg history to avoid msg missing.
201
  def JSMsgReceiver_postprocess(self, y):
202
+ global msg_history
203
  if self.elem_id == "msg_receiver" and y:
204
+ msg_history.append(y)
205
+ if len(msg_history) > 50:
206
+ msg_history = msg_history[1:]
207
+ y = json.dumps(msg_history)
208
+ return Textbox_postprocess_ori(self, y)
209
 
210
 
211
+ gr.Textbox.postprocess = JSMsgReceiver_postprocess
212
 
213
  number2drum_kits = {-1: "None", 0: "Standard", 8: "Room", 16: "Power", 24: "Electric", 25: "TR-808", 32: "Jazz",
214
  40: "Blush", 48: "Orchestra"}
 
223
  opt = parser.parse_args()
224
  soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
225
  models_info = {"generic pretrain model": ["skytnt/midi-model", ""],
226
+ # "j-pop finetune model": ["skytnt/midi-model-ft", "jpop/"],
227
+ # "touhou finetune model": ["skytnt/midi-model-ft", "touhou/"],
228
  }
229
  models = {}
230
  tokenizer = MIDITokenizer()
 
247
  "(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
248
  " for faster running and longer generation"
249
  )
250
+ js_msg = gr.Textbox(elem_id="msg_receiver", visible=False)
251
+ js_msg.change(None, [js_msg], [], js="""
252
+ (msg_json) =>{
253
+ let msgs = JSON.parse(msg_json);
254
+ executeCallbacks(msgReceiveCallbacks, msgs);
255
+ return [];
256
+ }
257
+ """)
258
  input_model = gr.Dropdown(label="select model", choices=list(models.keys()),
259
  type="value", value=list(models.keys())[0])
260
  tab_select = gr.State(value=0)
javascript/app.js CHANGED
@@ -76,33 +76,6 @@ document.addEventListener("DOMContentLoaded", function() {
76
  mutationObserver.observe( gradioApp(), { childList:true, subtree:true })
77
  });
78
 
79
- (()=>{
80
- let mse_receiver_inited = null
81
- onUiUpdate(()=>{
82
- let app = gradioApp()
83
- let msg_receiver = app.querySelector("#msg_receiver");
84
- if(!!msg_receiver && mse_receiver_inited !== msg_receiver){
85
- let mutationObserver = new MutationObserver(function(ms){
86
- ms.forEach((m)=>{
87
- m.addedNodes.forEach((node)=>{
88
- if(node.nodeName === "P"){
89
- let obj = JSON.parse(node.innerText);
90
- if(obj instanceof Array){
91
- obj.forEach((o)=>{executeCallbacks(msgReceiveCallbacks, o);});
92
- }else{
93
- executeCallbacks(msgReceiveCallbacks, obj);
94
- }
95
- }
96
- })
97
- })
98
- });
99
- mutationObserver.observe( msg_receiver, {childList:true, subtree:true, characterData:true})
100
- console.log("receiver init");
101
- mse_receiver_inited = msg_receiver;
102
- }
103
- })
104
- })()
105
-
106
  function HSVtoRGB(h, s, v) {
107
  let r, g, b, i, f, p, q, t;
108
  i = Math.floor(h * 6);
@@ -261,9 +234,11 @@ class MidiVisualizer extends HTMLElement{
261
  this.midiTimes.push({ms:ms, t: t, tempo: tempo})
262
  }
263
  if(midiEvent[0]==="note"){
264
- this.totalTimeMs = ms + (midiEvent[3]/ this.timePreBeat)*tempo
 
 
265
  }
266
- lastT = t
267
  })
268
  }
269
 
@@ -431,7 +406,22 @@ customElements.define('midi-visualizer', MidiVisualizer);
431
  divInner.textContent = `${progress}/${total}`;
432
  }
433
 
434
- onMsgReceive((msg)=>{
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
435
  switch (msg.name) {
436
  case "visualizer_clear":
437
  midi_visualizer.clearMidiEvents();
@@ -452,5 +442,5 @@ customElements.define('midi-visualizer', MidiVisualizer);
452
  break;
453
  default:
454
  }
455
- })
456
  })();
 
76
  mutationObserver.observe( gradioApp(), { childList:true, subtree:true })
77
  });
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  function HSVtoRGB(h, s, v) {
80
  let r, g, b, i, f, p, q, t;
81
  i = Math.floor(h * 6);
 
234
  this.midiTimes.push({ms:ms, t: t, tempo: tempo})
235
  }
236
  if(midiEvent[0]==="note"){
237
+ this.totalTimeMs = Math.max(this.totalTimeMs, ms + (midiEvent[3]/ this.timePreBeat)*tempo)
238
+ }else{
239
+ this.totalTimeMs = Math.max(this.totalTimeMs, ms);
240
  }
241
+ lastT = t;
242
  })
243
  }
244
 
 
406
  divInner.textContent = `${progress}/${total}`;
407
  }
408
 
409
+ onMsgReceive((msgs)=>{
410
+ for(let msg of msgs){
411
+ if(msg instanceof Array){
412
+ msg.forEach((o)=>{handleMsg(o)});
413
+ }else{
414
+ handleMsg(msg);
415
+ }
416
+ }
417
+ })
418
+ let handled_msgs = [];
419
+ function handleMsg(msg){
420
+ if(handled_msgs.indexOf(msg.uuid)!== -1)
421
+ return;
422
+ handled_msgs.push(msg.uuid);
423
+ if(handled_msgs.length > 200)
424
+ handled_msgs = handled_msgs.slice(1);
425
  switch (msg.name) {
426
  case "visualizer_clear":
427
  midi_visualizer.clearMidiEvents();
 
442
  break;
443
  default:
444
  }
445
+ }
446
  })();