lauraibnz commited on
Commit
cdc91fe
1 Parent(s): 9d07bb1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -4
app.py CHANGED
@@ -26,7 +26,10 @@ generator = torch.Generator(device)
26
  def predict(midi_file=None, prompt="", negative_prompt="", audio_length_in_s=10, random_seed=0, controlnet_conditioning_scale=1, num_inference_steps=20, guidance_scale=2.5, guess_mode=False):
27
  if isinstance(midi_file, _TemporaryFileWrapper):
28
  midi_file = midi_file.name
29
- midi = PrettyMIDI(midi_file)
 
 
 
30
  audio = pipe(
31
  prompt,
32
  negative_prompt=negative_prompt,
@@ -46,7 +49,7 @@ def synthesize(midi_file=None):
46
  midi = PrettyMIDI(midi_file)
47
  midi_synth = midi.synthesize(fs=SAMPLE_RATE)
48
  midi_synth = midi_synth.reshape(midi_synth.shape[0], 1)
49
- return (SAMPLE_RATE, midi_synth)
50
 
51
  with gr.Blocks(title="🎹 MIDI-AudioLDM", theme=gr.themes.Base(text_size=gr.themes.sizes.text_md, font=[gr.themes.GoogleFont("Nunito Sans")])) as demo:
52
  gr.HTML(
@@ -60,8 +63,9 @@ with gr.Blocks(title="🎹 MIDI-AudioLDM", theme=gr.themes.Base(text_size=gr.the
60
  with gr.Row():
61
  with gr.Column(variant='panel'):
62
  midi_synth = gr.Audio(label="synthesized midi")
63
- midi = gr.UploadButton("Upload a MIDI File", file_types=[".mid"], value="S00.mid")
64
- midi.upload(synthesize, midi, midi_synth)
 
65
  prompt = gr.Textbox(label="prompt", info="Enter a descriptive text prompt to guide the audio generation.")
66
  with gr.Column(variant='panel'):
67
  audio = gr.Audio(label="generated audio")
 
26
  def predict(midi_file=None, prompt="", negative_prompt="", audio_length_in_s=10, random_seed=0, controlnet_conditioning_scale=1, num_inference_steps=20, guidance_scale=2.5, guess_mode=False):
27
  if isinstance(midi_file, _TemporaryFileWrapper):
28
  midi_file = midi_file.name
29
+ if not isinstance(midi_file, PrettyMIDI):
30
+ midi = PrettyMIDI(midi_file)
31
+ else:
32
+ midi = midi_file
33
  audio = pipe(
34
  prompt,
35
  negative_prompt=negative_prompt,
 
49
  midi = PrettyMIDI(midi_file)
50
  midi_synth = midi.synthesize(fs=SAMPLE_RATE)
51
  midi_synth = midi_synth.reshape(midi_synth.shape[0], 1)
52
+ return midi, (SAMPLE_RATE, midi_synth)
53
 
54
  with gr.Blocks(title="🎹 MIDI-AudioLDM", theme=gr.themes.Base(text_size=gr.themes.sizes.text_md, font=[gr.themes.GoogleFont("Nunito Sans")])) as demo:
55
  gr.HTML(
 
63
  with gr.Row():
64
  with gr.Column(variant='panel'):
65
  midi_synth = gr.Audio(label="synthesized midi")
66
+ midi_file = gr.UploadButton("Upload a MIDI File", file_types=[".mid"], value="S00.mid")
67
+ midi = PrettyMIDI()
68
+ midi_file.upload(synthesize, midi_file, [midi, midi_synth])
69
  prompt = gr.Textbox(label="prompt", info="Enter a descriptive text prompt to guide the audio generation.")
70
  with gr.Column(variant='panel'):
71
  audio = gr.Audio(label="generated audio")