teticio committed on
Commit c78ba1a
1 Parent(s): ffb86f0

add parameters to mel

Files changed (1)
  1. audiodiffusion/__init__.py +16 -4
audiodiffusion/__init__.py CHANGED
@@ -9,7 +9,7 @@ from diffusers import DDPMPipeline, DDPMScheduler
 
 from .mel import Mel
 
-VERSION = "1.1.4"
+VERSION = "1.1.5"
 
 
 class AudioDiffusion:
@@ -17,6 +17,10 @@ class AudioDiffusion:
     def __init__(self,
                  model_id: str = "teticio/audio-diffusion-256",
                  resolution: int = 256,
+                 sample_rate: int = 22050,
+                 n_fft: int = 2048,
+                 hop_length: int = 512,
+                 top_db: int = 80,
                  cuda: bool = torch.cuda.is_available(),
                  progress_bar: Iterable = tqdm):
         """Class for generating audio using Denoising Diffusion Probabilistic Models.
@@ -24,10 +28,19 @@
         Args:
             model_id (String): name of model (local directory or Hugging Face Hub)
             resolution (int): size of square mel spectrogram in pixels
+            sample_rate (int): sample rate of audio
+            n_fft (int): number of Fast Fourier Transforms
+            hop_length (int): hop length (a higher number is recommended for lower than 256 y_res)
+            top_db (int): loudest in decibels
             cuda (bool): use CUDA?
             progress_bar (iterable): iterable callback for progress updates or None
         """
-        self.mel = Mel(x_res=resolution, y_res=resolution)
+        self.mel = Mel(x_res=resolution,
+                       y_res=resolution,
+                       sample_rate=sample_rate,
+                       n_fft=n_fft,
+                       hop_length=hop_length,
+                       top_db=top_db)
         self.model_id = model_id
         self.ddpm = DDPMPipeline.from_pretrained(self.model_id)
         if cuda:
@@ -92,8 +105,7 @@
         images = noise = torch.randn(
             (1, self.ddpm.unet.in_channels, self.ddpm.unet.sample_size,
              self.ddpm.unet.sample_size),
-            generator=generator
-        )
+            generator=generator)
 
         if audio_file is not None or raw_audio is not None:
             self.mel.load_audio(audio_file, raw_audio)
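
With this change, the audio configuration can be set when the class is constructed instead of relying on the Mel defaults. A minimal sketch of a caller, using only the constructor shown in this diff (the values are the new defaults; the model_id is the one from the signature):

from audiodiffusion import AudioDiffusion

# Constructor keyword arguments are forwarded to Mel(x_res=..., y_res=..., ...)
# exactly as in the diff above; adjust them to match the audio the model was trained on.
audio_diffusion = AudioDiffusion(model_id="teticio/audio-diffusion-256",
                                 resolution=256,
                                 sample_rate=22050,
                                 n_fft=2048,
                                 hop_length=512,
                                 top_db=80)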