Sylvain Filoni commited on
Commit
d6953d3
·
1 Parent(s): 4e7833d

try addind duration control

Browse files
Files changed (2) hide show
  1. app.py +4 -4
  2. spectro.py +2 -2
app.py CHANGED
@@ -9,10 +9,10 @@ model_id = "riffusion/riffusion-model-v1"
9
  pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
10
  pipe = pipe.to("cuda")
11
 
12
- def predict(prompt):
13
  spec = pipe(prompt).images[0]
14
  print(spec)
15
- wav = wav_bytes_from_spectrogram_image(spec)
16
  with open("output.wav", "wb") as f:
17
  f.write(wav[0].getbuffer())
18
  return spec, 'output.wav'
@@ -102,7 +102,7 @@ with gr.Blocks(css=css) as demo:
102
  gr.HTML(title)
103
 
104
  prompt_input = gr.Textbox(placeholder="a cat diva singing in a New York jazz club")
105
-
106
  send_btn = gr.Button("Get a new spectrogram ! ")
107
 
108
  with gr.Column(elem_id="col-container-2"):
@@ -111,6 +111,6 @@ with gr.Blocks(css=css) as demo:
111
 
112
  gr.HTML(article)
113
 
114
- send_btn.click(predict, inputs=[prompt_input], outputs=[spectrogram_output, sound_output])
115
 
116
  demo.queue(max_size=250).launch(debug=True)
 
9
  pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
10
  pipe = pipe.to("cuda")
11
 
12
+ def predict(prompt, sample_duration):
13
  spec = pipe(prompt).images[0]
14
  print(spec)
15
+ wav = wav_bytes_from_spectrogram_image(spec, sample_duration)
16
  with open("output.wav", "wb") as f:
17
  f.write(wav[0].getbuffer())
18
  return spec, 'output.wav'
 
102
  gr.HTML(title)
103
 
104
  prompt_input = gr.Textbox(placeholder="a cat diva singing in a New York jazz club")
105
+ sample_duration_input = gr.Slider(minimum=5, maximum=20, value=5, step=5)
106
  send_btn = gr.Button("Get a new spectrogram ! ")
107
 
108
  with gr.Column(elem_id="col-container-2"):
 
111
 
112
  gr.HTML(article)
113
 
114
+ send_btn.click(predict, inputs=[prompt_input, sample_duration_input], outputs=[spectrogram_output, sound_output])
115
 
116
  demo.queue(max_size=250).launch(debug=True)
spectro.py CHANGED
@@ -12,7 +12,7 @@ import torch
12
  import torchaudio
13
 
14
 
15
- def wav_bytes_from_spectrogram_image(image: Image.Image) -> T.Tuple[io.BytesIO, float]:
16
  """
17
  Reconstruct a WAV audio clip from a spectrogram image. Also returns the duration in seconds.
18
  """
@@ -22,7 +22,7 @@ def wav_bytes_from_spectrogram_image(image: Image.Image) -> T.Tuple[io.BytesIO,
22
  Sxx = spectrogram_from_image(image, max_volume=max_volume, power_for_image=power_for_image)
23
 
24
  sample_rate = 44100 # [Hz]
25
- clip_duration_ms = 5000 # [ms]
26
 
27
  bins_per_image = 512
28
  n_mels = 512
 
12
  import torchaudio
13
 
14
 
15
+ def wav_bytes_from_spectrogram_image(image: Image.Image, sample_duration) -> T.Tuple[io.BytesIO, float]:
16
  """
17
  Reconstruct a WAV audio clip from a spectrogram image. Also returns the duration in seconds.
18
  """
 
22
  Sxx = spectrogram_from_image(image, max_volume=max_volume, power_for_image=power_for_image)
23
 
24
  sample_rate = 44100 # [Hz]
25
+ clip_duration_ms = sample_duration*1000 # [ms]
26
 
27
  bins_per_image = 512
28
  n_mels = 512