teticio commited on
Commit
f30235e
1 Parent(s): 68d16a1

get ready for migration to diffusers

Browse files
README.md CHANGED
@@ -23,6 +23,8 @@ Go to https://soundcloud.com/teticio2/sets/audio-diffusion-loops for more exampl
23
  ---
24
  #### Updates
25
 
 
 
26
  **7/11/2022**. Added pre-trained latent audio diffusion models [teticio/latent-audio-diffusion-256](https://huggingface.co/teticio/latent-audio-diffusion-256) and [teticio/latent-audio-diffusion-ddim-256](https://huggingface.co/teticio/latent-audio-diffusion-ddim-256). You can use the pre-trained VAE to train your own latent diffusion models on a different set of audio files.
27
 
28
  **22/10/2022**. Added DDIM encoder and ability to interpolate between audios in latent "noise" space. Mel spectrograms no longer have to be square (thanks to Tristan for this one), so you can set the vertical (frequency) and horizontal (time) resolutions independently.
 
23
  ---
24
  #### Updates
25
 
26
+ **2/12/2022**. Added Mel to pipeline and updated the pretrained models to save Mel config (they are now no longer compatible with previous versions of this repo). It is relatively straightforward to migrate previously trained models to the new format (see https://huggingface.co/teticio/audio-diffusion-256).
27
+
28
  **7/11/2022**. Added pre-trained latent audio diffusion models [teticio/latent-audio-diffusion-256](https://huggingface.co/teticio/latent-audio-diffusion-256) and [teticio/latent-audio-diffusion-ddim-256](https://huggingface.co/teticio/latent-audio-diffusion-ddim-256). You can use the pre-trained VAE to train your own latent diffusion models on a different set of audio files.
29
 
30
  **22/10/2022**. Added DDIM encoder and ability to interpolate between audios in latent "noise" space. Mel spectrograms no longer have to be square (thanks to Tristan for this one), so you can set the vertical (frequency) and horizontal (time) resolutions independently.
audiodiffusion/__init__.py CHANGED
@@ -1,62 +1,34 @@
1
- from math import acos, sin
2
- from typing import Iterable, Tuple, Union, List
3
 
4
  import torch
5
  import numpy as np
6
  from PIL import Image
7
  from tqdm.auto import tqdm
8
  from librosa.beat import beat_track
9
- from diffusers import (DiffusionPipeline, UNet2DConditionModel, DDIMScheduler,
10
- DDPMScheduler, AutoencoderKL)
11
- from diffusers.pipeline_utils import (AudioPipelineOutput, BaseOutput,
12
- ImagePipelineOutput)
13
-
14
- from .mel import Mel
15
 
16
- VERSION = "1.2.7"
17
 
18
 
19
  class AudioDiffusion:
20
 
21
  def __init__(self,
22
  model_id: str = "teticio/audio-diffusion-256",
23
- sample_rate: int = 22050,
24
- n_fft: int = 2048,
25
- hop_length: int = 512,
26
- top_db: int = 80,
27
  cuda: bool = torch.cuda.is_available(),
28
  progress_bar: Iterable = tqdm):
29
  """Class for generating audio using De-noising Diffusion Probabilistic Models.
30
 
31
  Args:
32
  model_id (String): name of model (local directory or Hugging Face Hub)
33
- sample_rate (int): sample rate of audio
34
- n_fft (int): number of Fast Fourier Transforms
35
- hop_length (int): hop length (a higher number is recommended for lower than 256 y_res)
36
- top_db (int): loudest in decibels
37
  cuda (bool): use CUDA?
38
  progress_bar (iterable): iterable callback for progress updates or None
39
  """
40
  self.model_id = model_id
41
- pipeline = {
42
- 'LatentAudioDiffusionPipeline': LatentAudioDiffusionPipeline,
43
- 'AudioDiffusionPipeline': AudioDiffusionPipeline
44
- }.get(
45
- DiffusionPipeline.get_config_dict(self.model_id)['_class_name'],
46
- AudioDiffusionPipeline)
47
- self.pipe = pipeline.from_pretrained(self.model_id)
48
  if cuda:
49
  self.pipe.to("cuda")
50
  self.progress_bar = progress_bar or (lambda _: _)
51
 
52
- sample_size = self.pipe.get_input_dims()
53
- self.mel = Mel(x_res=sample_size[1],
54
- y_res=sample_size[0],
55
- sample_rate=sample_rate,
56
- n_fft=n_fft,
57
- hop_length=hop_length,
58
- top_db=top_db)
59
-
60
  def generate_spectrogram_and_audio(
61
  self,
62
  steps: int = None,
@@ -79,8 +51,7 @@ class AudioDiffusion:
79
  (float, np.ndarray): sample rate and raw audio
80
  """
81
  images, (sample_rate,
82
- audios) = self.pipe(mel=self.mel,
83
- batch_size=1,
84
  steps=steps,
85
  generator=generator,
86
  step_generator=step_generator,
@@ -124,8 +95,7 @@ class AudioDiffusion:
124
  """
125
 
126
  images, (sample_rate,
127
- audios) = self.pipe(mel=self.mel,
128
- batch_size=1,
129
  audio_file=audio_file,
130
  raw_audio=raw_audio,
131
  slice=slice,
@@ -161,18 +131,274 @@ class AudioDiffusion:
161
  return None
162
 
163
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  class AudioDiffusionPipeline(DiffusionPipeline):
165
- def __init__(self, unet: UNet2DConditionModel, scheduler: Union[DDIMScheduler, DDPMScheduler]):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  super().__init__()
167
- self.register_modules(unet=unet, scheduler=scheduler)
168
 
169
  def get_input_dims(self) -> Tuple:
170
  """Returns dimension of input image
171
 
172
  Returns:
173
- Tuple: (height, width)
174
  """
175
- input_module = self.vqvae if hasattr(self, "vqvae") else self.unet
176
  # For backwards compatibility
177
  sample_size = (
178
  (input_module.sample_size, input_module.sample_size)
@@ -185,14 +411,13 @@ class AudioDiffusionPipeline(DiffusionPipeline):
185
  """Returns default number of steps recommended for inference
186
 
187
  Returns:
188
- int: number of steps
189
  """
190
  return 50 if isinstance(self.scheduler, DDIMScheduler) else 1000
191
 
192
  @torch.no_grad()
193
  def __call__(
194
  self,
195
- mel: Mel,
196
  batch_size: int = 1,
197
  audio_file: str = None,
198
  raw_audio: np.ndarray = None,
@@ -212,23 +437,22 @@ class AudioDiffusionPipeline(DiffusionPipeline):
212
  """Generate random mel spectrogram from audio input and convert to audio.
213
 
214
  Args:
215
- mel (Mel): instance of Mel class to perform image <-> audio
216
- batch_size (int): number of samples to generate
217
- audio_file (str): must be a file on disk due to Librosa limitation or
218
- raw_audio (np.ndarray): audio as numpy array
219
- slice (int): slice number of audio to convert
220
  start_step (int): step to start from
221
- steps (int): number of de-noising steps (defaults to 50 for DDIM, 1000 for DDPM)
222
- generator (torch.Generator): random number generator or None
223
- mask_start_secs (float): number of seconds of audio to mask (not generate) at start
224
- mask_end_secs (float): number of seconds of audio to mask (not generate) at end
225
- step_generator (torch.Generator): random number generator used to de-noise or None
226
- eta (float): parameter between 0 and 1 used with DDIM scheduler
227
- noise (torch.Tensor): noise tensor of shape (batch_size, 1, height, width) or None
228
- return_dict (bool): if True return AudioPipelineOutput, ImagePipelineOutput else Tuple
229
 
230
  Returns:
231
- List[PIL Image]: mel spectrograms (float, List[np.ndarray]): sample rate and raw audios
232
  """
233
 
234
  steps = steps or self.get_default_steps()
@@ -238,7 +462,7 @@ class AudioDiffusionPipeline(DiffusionPipeline):
238
  if type(self.unet.sample_size) == int:
239
  self.unet.sample_size = (self.unet.sample_size, self.unet.sample_size)
240
  input_dims = self.get_input_dims()
241
- mel.set_resolution(x_res=input_dims[1], y_res=input_dims[0])
242
  if noise is None:
243
  noise = torch.randn(
244
  (batch_size, self.unet.in_channels, self.unet.sample_size[0], self.unet.sample_size[1]),
@@ -249,15 +473,15 @@ class AudioDiffusionPipeline(DiffusionPipeline):
249
  mask = None
250
 
251
  if audio_file is not None or raw_audio is not None:
252
- mel.load_audio(audio_file, raw_audio)
253
- input_image = mel.audio_slice_to_image(slice)
254
  input_image = np.frombuffer(input_image.tobytes(), dtype="uint8").reshape(
255
  (input_image.height, input_image.width)
256
  )
257
  input_image = (input_image / 255) * 2 - 1
258
  input_images = torch.tensor(input_image[np.newaxis, :, :], dtype=torch.float).to(self.device)
259
 
260
- if hasattr(self, "vqvae"):
261
  input_images = self.vqvae.encode(torch.unsqueeze(input_images, 0)).latent_dist.sample(
262
  generator=generator
263
  )[0]
@@ -266,7 +490,9 @@ class AudioDiffusionPipeline(DiffusionPipeline):
266
  if start_step > 0:
267
  images[0, 0] = self.scheduler.add_noise(input_images, noise, self.scheduler.timesteps[start_step - 1])
268
 
269
- pixels_per_second = self.unet.sample_size[1] * mel.get_sample_rate() / mel.x_res / mel.hop_length
 
 
270
  mask_start = int(mask_start_secs * pixels_per_second)
271
  mask_end = int(mask_end_secs * pixels_per_second)
272
  mask = self.scheduler.add_noise(input_images, noise, torch.tensor(self.scheduler.timesteps[start_step:]))
@@ -289,7 +515,7 @@ class AudioDiffusionPipeline(DiffusionPipeline):
289
  if mask_end > 0:
290
  images[:, :, :, -mask_end:] = mask[:, step, :, -mask_end:]
291
 
292
- if hasattr(self, "vqvae"):
293
  # 0.18215 was scaling factor used in training to ensure unit variance
294
  images = 1 / 0.18215 * images
295
  images = self.vqvae.decode(images)["sample"]
@@ -303,9 +529,9 @@ class AudioDiffusionPipeline(DiffusionPipeline):
303
  else map(lambda _: Image.fromarray(_, mode="RGB").convert("L"), images)
304
  )
305
 
306
- audios = list(map(lambda _: mel.image_to_audio(_), images))
307
  if not return_dict:
308
- return images, (mel.get_sample_rate(), audios)
309
 
310
  return BaseOutput(**AudioPipelineOutput(np.array(audios)[:, np.newaxis, :]), **ImagePipelineOutput(images))
311
 
@@ -314,11 +540,11 @@ class AudioDiffusionPipeline(DiffusionPipeline):
314
  """Reverse step process: recover noisy image from generated image.
315
 
316
  Args:
317
- images (List[PIL Image]): list of images to encode
318
- steps (int): number of encoding steps to perform (defaults to 50)
319
 
320
  Returns:
321
- np.ndarray: noise tensor of shape (batch_size, 1, height, width)
322
  """
323
 
324
  # Only works with DDIM as this method is deterministic
@@ -351,24 +577,22 @@ class AudioDiffusionPipeline(DiffusionPipeline):
351
  """Spherical Linear intERPolation
352
 
353
  Args:
354
- x0 (torch.Tensor): first tensor to interpolate between
355
- x1 (torch.Tensor): seconds tensor to interpolate between
356
- alpha (float): interpolation between 0 and 1
357
 
358
  Returns:
359
- torch.Tensor: interpolated tensor
360
  """
361
 
362
  theta = acos(torch.dot(torch.flatten(x0), torch.flatten(x1)) / torch.norm(x0) / torch.norm(x1))
363
  return sin((1 - alpha) * theta) * x0 / sin(theta) + sin(alpha * theta) * x1 / sin(theta)
364
 
365
 
366
- class LatentAudioDiffusionPipeline(AudioDiffusionPipeline):
367
- def __init__(
368
- self, unet: UNet2DConditionModel, scheduler: Union[DDIMScheduler, DDPMScheduler], vqvae: AutoencoderKL
369
- ):
370
- super().__init__(unet=unet, scheduler=scheduler)
371
- self.register_modules(vqvae=vqvae)
372
 
373
- def __call__(self, *args, **kwargs):
374
- return super().__call__(*args, **kwargs)
 
 
 
 
1
+ from typing import Iterable, Tuple, Union
 
2
 
3
  import torch
4
  import numpy as np
5
  from PIL import Image
6
  from tqdm.auto import tqdm
7
  from librosa.beat import beat_track
8
+ #from diffusers import DiffusionPipeline
 
 
 
 
 
9
 
10
+ VERSION = "1.3.0"
11
 
12
 
13
  class AudioDiffusion:
14
 
15
  def __init__(self,
16
  model_id: str = "teticio/audio-diffusion-256",
 
 
 
 
17
  cuda: bool = torch.cuda.is_available(),
18
  progress_bar: Iterable = tqdm):
19
  """Class for generating audio using De-noising Diffusion Probabilistic Models.
20
 
21
  Args:
22
  model_id (String): name of model (local directory or Hugging Face Hub)
 
 
 
 
23
  cuda (bool): use CUDA?
24
  progress_bar (iterable): iterable callback for progress updates or None
25
  """
26
  self.model_id = model_id
27
+ self.pipe = AudioDiffusionPipeline.from_pretrained(self.model_id)
 
 
 
 
 
 
28
  if cuda:
29
  self.pipe.to("cuda")
30
  self.progress_bar = progress_bar or (lambda _: _)
31
 
 
 
 
 
 
 
 
 
32
  def generate_spectrogram_and_audio(
33
  self,
34
  steps: int = None,
 
51
  (float, np.ndarray): sample rate and raw audio
52
  """
53
  images, (sample_rate,
54
+ audios) = self.pipe(batch_size=1,
 
55
  steps=steps,
56
  generator=generator,
57
  step_generator=step_generator,
 
95
  """
96
 
97
  images, (sample_rate,
98
+ audios) = self.pipe(batch_size=1,
 
99
  audio_file=audio_file,
100
  raw_audio=raw_audio,
101
  slice=slice,
 
131
  return None
132
 
133
 
134
+ # This code will be migrated to diffusers shortly
135
+
136
+ #-----------------------------------------------------------------------------#
137
+
138
+ import os
139
+ import warnings
140
+ from typing import Any, Dict, Optional, Union
141
+
142
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
143
+
144
+
145
+ warnings.filterwarnings("ignore")
146
+
147
+ import numpy as np # noqa: E402
148
+
149
+ import librosa # noqa: E402
150
+ from PIL import Image # noqa: E402
151
+
152
+
153
+ class Mel(ConfigMixin):
154
+ """
155
+ Parameters:
156
+ x_res (`int`): x resolution of spectrogram (time)
157
+ y_res (`int`): y resolution of spectrogram (frequency bins)
158
+ sample_rate (`int`): sample rate of audio
159
+ n_fft (`int`): number of Fast Fourier Transforms
160
+ hop_length (`int`): hop length (a higher number is recommended for lower than 256 y_res)
161
+ top_db (`int`): loudest in decibels
162
+ n_iter (`int`): number of iterations for Griffin Linn mel inversion
163
+ """
164
+
165
+ config_name = "mel_config.json"
166
+
167
+ @register_to_config
168
+ def __init__(
169
+ self,
170
+ x_res: int = 256,
171
+ y_res: int = 256,
172
+ sample_rate: int = 22050,
173
+ n_fft: int = 2048,
174
+ hop_length: int = 512,
175
+ top_db: int = 80,
176
+ n_iter: int = 32,
177
+ ):
178
+ self.hop_length = hop_length
179
+ self.sr = sample_rate
180
+ self.n_fft = n_fft
181
+ self.top_db = top_db
182
+ self.n_iter = n_iter
183
+ self.set_resolution(x_res, y_res)
184
+ self.audio = None
185
+
186
+ def set_resolution(self, x_res: int, y_res: int):
187
+ """Set resolution.
188
+
189
+ Args:
190
+ x_res (`int`): x resolution of spectrogram (time)
191
+ y_res (`int`): y resolution of spectrogram (frequency bins)
192
+ """
193
+ self.x_res = x_res
194
+ self.y_res = y_res
195
+ self.n_mels = self.y_res
196
+ self.slice_size = self.x_res * self.hop_length - 1
197
+
198
+ def load_audio(self, audio_file: str = None, raw_audio: np.ndarray = None):
199
+ """Load audio.
200
+
201
+ Args:
202
+ audio_file (`str`): must be a file on disk due to Librosa limitation or
203
+ raw_audio (`np.ndarray`): audio as numpy array
204
+ """
205
+ if audio_file is not None:
206
+ self.audio, _ = librosa.load(audio_file, mono=True, sr=self.sr)
207
+ else:
208
+ self.audio = raw_audio
209
+
210
+ # Pad with silence if necessary.
211
+ if len(self.audio) < self.x_res * self.hop_length:
212
+ self.audio = np.concatenate([self.audio, np.zeros((self.x_res * self.hop_length - len(self.audio),))])
213
+
214
+ def get_number_of_slices(self) -> int:
215
+ """Get number of slices in audio.
216
+
217
+ Returns:
218
+ `int`: number of spectograms audio can be sliced into
219
+ """
220
+ return len(self.audio) // self.slice_size
221
+
222
+ def get_audio_slice(self, slice: int = 0) -> np.ndarray:
223
+ """Get slice of audio.
224
+
225
+ Args:
226
+ slice (`int`): slice number of audio (out of get_number_of_slices())
227
+
228
+ Returns:
229
+ `np.ndarray`: audio as numpy array
230
+ """
231
+ return self.audio[self.slice_size * slice : self.slice_size * (slice + 1)]
232
+
233
+ def get_sample_rate(self) -> int:
234
+ """Get sample rate:
235
+
236
+ Returns:
237
+ `int`: sample rate of audio
238
+ """
239
+ return self.sr
240
+
241
+ def audio_slice_to_image(self, slice: int) -> Image.Image:
242
+ """Convert slice of audio to spectrogram.
243
+
244
+ Args:
245
+ slice (`int`): slice number of audio to convert (out of get_number_of_slices())
246
+
247
+ Returns:
248
+ `PIL Image`: grayscale image of x_res x y_res
249
+ """
250
+ S = librosa.feature.melspectrogram(
251
+ y=self.get_audio_slice(slice), sr=self.sr, n_fft=self.n_fft, hop_length=self.hop_length, n_mels=self.n_mels
252
+ )
253
+ log_S = librosa.power_to_db(S, ref=np.max, top_db=self.top_db)
254
+ bytedata = (((log_S + self.top_db) * 255 / self.top_db).clip(0, 255) + 0.5).astype(np.uint8)
255
+ image = Image.fromarray(bytedata)
256
+ return image
257
+
258
+ def image_to_audio(self, image: Image.Image) -> np.ndarray:
259
+ """Converts spectrogram to audio.
260
+
261
+ Args:
262
+ image (`PIL Image`): x_res x y_res grayscale image
263
+
264
+ Returns:
265
+ audio (`np.ndarray`): raw audio
266
+ """
267
+ bytedata = np.frombuffer(image.tobytes(), dtype="uint8").reshape((image.height, image.width))
268
+ log_S = bytedata.astype("float") * self.top_db / 255 - self.top_db
269
+ S = librosa.db_to_power(log_S)
270
+ audio = librosa.feature.inverse.mel_to_audio(
271
+ S, sr=self.sr, n_fft=self.n_fft, hop_length=self.hop_length, n_iter=self.n_iter
272
+ )
273
+ return audio
274
+
275
+ @classmethod
276
+ def from_pretrained(
277
+ cls,
278
+ pretrained_model_name_or_path: Dict[str, Any] = None,
279
+ subfolder: Optional[str] = None,
280
+ return_unused_kwargs=False,
281
+ **kwargs,
282
+ ):
283
+ r"""
284
+ Instantiate a Mel class from a pre-defined JSON configuration file inside a directory or Hub repo.
285
+
286
+ Parameters:
287
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
288
+ Can be either:
289
+
290
+ - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
291
+ organization name, like `google/ddpm-celebahq-256`.
292
+ - A path to a *directory* containing the mel configurations saved using [`~Mel.save_pretrained`],
293
+ e.g., `./my_model_directory/`.
294
+ subfolder (`str`, *optional*):
295
+ In case the relevant files are located inside a subfolder of the model repo (either remote in
296
+ huggingface.co or downloaded locally), you can specify the folder name here.
297
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
298
+ Whether kwargs that are not consumed by the Python class should be returned or not.
299
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
300
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
301
+ standard cache should not be used.
302
+ force_download (`bool`, *optional*, defaults to `False`):
303
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
304
+ cached versions if they exist.
305
+ resume_download (`bool`, *optional*, defaults to `False`):
306
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
307
+ file exists.
308
+ proxies (`Dict[str, str]`, *optional*):
309
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
310
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
311
+ output_loading_info(`bool`, *optional*, defaults to `False`):
312
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
313
+ local_files_only(`bool`, *optional*, defaults to `False`):
314
+ Whether or not to only look at local files (i.e., do not try to download the model).
315
+ use_auth_token (`str` or *bool*, *optional*):
316
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
317
+ when running `transformers-cli login` (stored in `~/.huggingface`).
318
+ revision (`str`, *optional*, defaults to `"main"`):
319
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
320
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
321
+ identifier allowed by git.
322
+
323
+ <Tip>
324
+
325
+ It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
326
+ models](https://huggingface.co/docs/hub/models-gated#gated-models).
327
+
328
+ </Tip>
329
+
330
+ <Tip>
331
+
332
+ Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
333
+ use this method in a firewalled environment.
334
+
335
+ </Tip>
336
+
337
+ """
338
+ config, kwargs = cls.load_config(
339
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
340
+ subfolder=subfolder,
341
+ return_unused_kwargs=True,
342
+ **kwargs,
343
+ )
344
+ return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)
345
+
346
+ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
347
+ """
348
+ Save a mel configuration object to the directory `save_directory`, so that it can be re-loaded using the
349
+ [`~Mel.from_pretrained`] class method.
350
+
351
+ Args:
352
+ save_directory (`str` or `os.PathLike`):
353
+ Directory where the configuration JSON file will be saved (will be created if it does not exist).
354
+ """
355
+ self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
356
+
357
+ #-----------------------------------------------------------------------------#
358
+
359
+ from math import acos, sin
360
+ from typing import List, Tuple, Union
361
+
362
+ import numpy as np
363
+ import torch
364
+
365
+ from PIL import Image
366
+
367
+ from diffusers import AutoencoderKL, UNet2DConditionModel, DiffusionPipeline, DDIMScheduler, DDPMScheduler
368
+ from diffusers.pipeline_utils import AudioPipelineOutput, BaseOutput, ImagePipelineOutput
369
+
370
+
371
  class AudioDiffusionPipeline(DiffusionPipeline):
372
+ """
373
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
374
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
375
+
376
+ Parameters:
377
+ vqae ([`AutoencoderKL`]): Variational AutoEncoder for Latent Audio Diffusion or None
378
+ unet ([`UNet2DConditionModel`]): UNET model
379
+ mel ([`Mel`]): transform audio <-> spectrogram
380
+ scheduler ([`DDIMScheduler` or `DDPMScheduler`]): de-noising scheduler
381
+ """
382
+
383
+ _optional_components = ["vqvae"]
384
+
385
+ def __init__(
386
+ self,
387
+ vqvae: AutoencoderKL,
388
+ unet: UNet2DConditionModel,
389
+ mel: Mel,
390
+ scheduler: Union[DDIMScheduler, DDPMScheduler],
391
+ ):
392
  super().__init__()
393
+ self.register_modules(unet=unet, scheduler=scheduler, mel=mel, vqvae=vqvae)
394
 
395
  def get_input_dims(self) -> Tuple:
396
  """Returns dimension of input image
397
 
398
  Returns:
399
+ `Tuple`: (height, width)
400
  """
401
+ input_module = self.vqvae if self.vqvae is not None else self.unet
402
  # For backwards compatibility
403
  sample_size = (
404
  (input_module.sample_size, input_module.sample_size)
 
411
  """Returns default number of steps recommended for inference
412
 
413
  Returns:
414
+ `int`: number of steps
415
  """
416
  return 50 if isinstance(self.scheduler, DDIMScheduler) else 1000
417
 
418
  @torch.no_grad()
419
  def __call__(
420
  self,
 
421
  batch_size: int = 1,
422
  audio_file: str = None,
423
  raw_audio: np.ndarray = None,
 
437
  """Generate random mel spectrogram from audio input and convert to audio.
438
 
439
  Args:
440
+ batch_size (`int`): number of samples to generate
441
+ audio_file (`str`): must be a file on disk due to Librosa limitation or
442
+ raw_audio (`np.ndarray`): audio as numpy array
443
+ slice (`int`): slice number of audio to convert
 
444
  start_step (int): step to start from
445
+ steps (`int`): number of de-noising steps (defaults to 50 for DDIM, 1000 for DDPM)
446
+ generator (`torch.Generator`): random number generator or None
447
+ mask_start_secs (`float`): number of seconds of audio to mask (not generate) at start
448
+ mask_end_secs (`float`): number of seconds of audio to mask (not generate) at end
449
+ step_generator (`torch.Generator`): random number generator used to de-noise or None
450
+ eta (`float`): parameter between 0 and 1 used with DDIM scheduler
451
+ noise (`torch.Tensor`): noise tensor of shape (batch_size, 1, height, width) or None
452
+ return_dict (`bool`): if True return AudioPipelineOutput, ImagePipelineOutput else Tuple
453
 
454
  Returns:
455
+ `List[PIL Image]`: mel spectrograms (`float`, `List[np.ndarray]`): sample rate and raw audios
456
  """
457
 
458
  steps = steps or self.get_default_steps()
 
462
  if type(self.unet.sample_size) == int:
463
  self.unet.sample_size = (self.unet.sample_size, self.unet.sample_size)
464
  input_dims = self.get_input_dims()
465
+ self.mel.set_resolution(x_res=input_dims[1], y_res=input_dims[0])
466
  if noise is None:
467
  noise = torch.randn(
468
  (batch_size, self.unet.in_channels, self.unet.sample_size[0], self.unet.sample_size[1]),
 
473
  mask = None
474
 
475
  if audio_file is not None or raw_audio is not None:
476
+ self.mel.load_audio(audio_file, raw_audio)
477
+ input_image = self.mel.audio_slice_to_image(slice)
478
  input_image = np.frombuffer(input_image.tobytes(), dtype="uint8").reshape(
479
  (input_image.height, input_image.width)
480
  )
481
  input_image = (input_image / 255) * 2 - 1
482
  input_images = torch.tensor(input_image[np.newaxis, :, :], dtype=torch.float).to(self.device)
483
 
484
+ if self.vqvae is not None:
485
  input_images = self.vqvae.encode(torch.unsqueeze(input_images, 0)).latent_dist.sample(
486
  generator=generator
487
  )[0]
 
490
  if start_step > 0:
491
  images[0, 0] = self.scheduler.add_noise(input_images, noise, self.scheduler.timesteps[start_step - 1])
492
 
493
+ pixels_per_second = (
494
+ self.unet.sample_size[1] * self.mel.get_sample_rate() / self.mel.x_res / self.mel.hop_length
495
+ )
496
  mask_start = int(mask_start_secs * pixels_per_second)
497
  mask_end = int(mask_end_secs * pixels_per_second)
498
  mask = self.scheduler.add_noise(input_images, noise, torch.tensor(self.scheduler.timesteps[start_step:]))
 
515
  if mask_end > 0:
516
  images[:, :, :, -mask_end:] = mask[:, step, :, -mask_end:]
517
 
518
+ if self.vqvae is not None:
519
  # 0.18215 was scaling factor used in training to ensure unit variance
520
  images = 1 / 0.18215 * images
521
  images = self.vqvae.decode(images)["sample"]
 
529
  else map(lambda _: Image.fromarray(_, mode="RGB").convert("L"), images)
530
  )
531
 
532
+ audios = list(map(lambda _: self.mel.image_to_audio(_), images))
533
  if not return_dict:
534
+ return images, (self.mel.get_sample_rate(), audios)
535
 
536
  return BaseOutput(**AudioPipelineOutput(np.array(audios)[:, np.newaxis, :]), **ImagePipelineOutput(images))
537
 
 
540
  """Reverse step process: recover noisy image from generated image.
541
 
542
  Args:
543
+ images (`List[PIL Image]`): list of images to encode
544
+ steps (`int`): number of encoding steps to perform (defaults to 50)
545
 
546
  Returns:
547
+ `np.ndarray`: noise tensor of shape (batch_size, 1, height, width)
548
  """
549
 
550
  # Only works with DDIM as this method is deterministic
 
577
  """Spherical Linear intERPolation
578
 
579
  Args:
580
+ x0 (`torch.Tensor`): first tensor to interpolate between
581
+ x1 (`torch.Tensor`): seconds tensor to interpolate between
582
+ alpha (`float`): interpolation between 0 and 1
583
 
584
  Returns:
585
+ `torch.Tensor`: interpolated tensor
586
  """
587
 
588
  theta = acos(torch.dot(torch.flatten(x0), torch.flatten(x1)) / torch.norm(x0) / torch.norm(x1))
589
  return sin((1 - alpha) * theta) * x0 / sin(theta) + sin(alpha * theta) * x1 / sin(theta)
590
 
591
 
592
+ import diffusers
 
 
 
 
 
593
 
594
+ diffusers.Mel = Mel
595
+ setattr(diffusers, Mel.__name__, Mel)
596
+ diffusers.AudioDiffusionPipeline = AudioDiffusionPipeline
597
+ setattr(diffusers, AudioDiffusionPipeline.__name__, AudioDiffusionPipeline)
598
+ diffusers.pipeline_utils.LOADABLE_CLASSES['diffusers']['Mel'] = ["save_pretrained", "from_pretrained"]
audiodiffusion/mel.py DELETED
@@ -1,129 +0,0 @@
1
- import warnings
2
-
3
-
4
- warnings.filterwarnings("ignore")
5
-
6
- import numpy as np # noqa: E402
7
-
8
- import librosa # noqa: E402
9
- from PIL import Image # noqa: E402
10
-
11
-
12
- class Mel:
13
- def __init__(
14
- self,
15
- x_res: int = 256,
16
- y_res: int = 256,
17
- sample_rate: int = 22050,
18
- n_fft: int = 2048,
19
- hop_length: int = 512,
20
- top_db: int = 80,
21
- n_iter: int = 32,
22
- ):
23
- """Class to convert audio to mel spectrograms and vice versa.
24
-
25
- Args:
26
- x_res (int): x resolution of spectrogram (time)
27
- y_res (int): y resolution of spectrogram (frequency bins)
28
- sample_rate (int): sample rate of audio
29
- n_fft (int): number of Fast Fourier Transforms
30
- hop_length (int): hop length (a higher number is recommended for lower than 256 y_res)
31
- top_db (int): loudest in decibels
32
- n_iter (int): number of iterations for Griffin Linn mel inversion
33
- """
34
- self.hop_length = hop_length
35
- self.sr = sample_rate
36
- self.n_fft = n_fft
37
- self.top_db = top_db
38
- self.n_iter = n_iter
39
- self.set_resolution(x_res, y_res)
40
- self.audio = None
41
-
42
- def set_resolution(self, x_res: int, y_res: int):
43
- """Set resolution.
44
-
45
- Args:
46
- x_res (int): x resolution of spectrogram (time)
47
- y_res (int): y resolution of spectrogram (frequency bins)
48
- """
49
- self.x_res = x_res
50
- self.y_res = y_res
51
- self.n_mels = self.y_res
52
- self.slice_size = self.x_res * self.hop_length - 1
53
-
54
- def load_audio(self, audio_file: str = None, raw_audio: np.ndarray = None):
55
- """Load audio.
56
-
57
- Args:
58
- audio_file (str): must be a file on disk due to Librosa limitation or
59
- raw_audio (np.ndarray): audio as numpy array
60
- """
61
- if audio_file is not None:
62
- self.audio, _ = librosa.load(audio_file, mono=True, sr=self.sr)
63
- else:
64
- self.audio = raw_audio
65
-
66
- # Pad with silence if necessary.
67
- if len(self.audio) < self.x_res * self.hop_length:
68
- self.audio = np.concatenate([self.audio, np.zeros((self.x_res * self.hop_length - len(self.audio),))])
69
-
70
- def get_number_of_slices(self) -> int:
71
- """Get number of slices in audio.
72
-
73
- Returns:
74
- int: number of spectograms audio can be sliced into
75
- """
76
- return len(self.audio) // self.slice_size
77
-
78
- def get_audio_slice(self, slice: int = 0) -> np.ndarray:
79
- """Get slice of audio.
80
-
81
- Args:
82
- slice (int): slice number of audio (out of get_number_of_slices())
83
-
84
- Returns:
85
- np.ndarray: audio as numpy array
86
- """
87
- return self.audio[self.slice_size * slice : self.slice_size * (slice + 1)]
88
-
89
- def get_sample_rate(self) -> int:
90
- """Get sample rate:
91
-
92
- Returns:
93
- int: sample rate of audio
94
- """
95
- return self.sr
96
-
97
- def audio_slice_to_image(self, slice: int) -> Image.Image:
98
- """Convert slice of audio to spectrogram.
99
-
100
- Args:
101
- slice (int): slice number of audio to convert (out of get_number_of_slices())
102
-
103
- Returns:
104
- PIL Image: grayscale image of x_res x y_res
105
- """
106
- S = librosa.feature.melspectrogram(
107
- y=self.get_audio_slice(slice), sr=self.sr, n_fft=self.n_fft, hop_length=self.hop_length, n_mels=self.n_mels
108
- )
109
- log_S = librosa.power_to_db(S, ref=np.max, top_db=self.top_db)
110
- bytedata = (((log_S + self.top_db) * 255 / self.top_db).clip(0, 255) + 0.5).astype(np.uint8)
111
- image = Image.fromarray(bytedata)
112
- return image
113
-
114
- def image_to_audio(self, image: Image.Image) -> np.ndarray:
115
- """Converts spectrogram to audio.
116
-
117
- Args:
118
- image (PIL Image): x_res x y_res grayscale image
119
-
120
- Returns:
121
- audio (np.ndarray): raw audio
122
- """
123
- bytedata = np.frombuffer(image.tobytes(), dtype="uint8").reshape((image.height, image.width))
124
- log_S = bytedata.astype("float") * self.top_db / 255 - self.top_db
125
- S = librosa.db_to_power(log_S)
126
- audio = librosa.feature.inverse.mel_to_audio(
127
- S, sr=self.sr, n_fft=self.n_fft, hop_length=self.hop_length, n_iter=self.n_iter
128
- )
129
- return audio
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
notebooks/audio_diffusion_pipeline.ipynb CHANGED
@@ -89,7 +89,7 @@
89
  "\n",
90
  "#@markdown teticio/audio-diffusion-instrumental-hiphop-256 - trained on instrumental hiphop\n",
91
  "\n",
92
- "model_id = \"teticio/audio-diffusion-256-new\" #@param [\"teticio/audio-diffusion-256\", \"teticio/audio-diffusion-breaks-256\", \"audio-diffusion-instrumenal-hiphop-256\", \"teticio/audio-diffusion-ddim-256\"]"
93
  ]
94
  },
95
  {
@@ -356,7 +356,7 @@
356
  "metadata": {},
357
  "outputs": [],
358
  "source": [
359
- "audio_diffusion = DiffusionPipeline.from_pretrained('teticio/audio-diffusion-ddim-256-new').to(device)\n",
360
  "mel = audio_diffusion.mel\n",
361
  "sample_rate = mel.get_sample_rate()"
362
  ]
@@ -532,7 +532,7 @@
532
  "metadata": {},
533
  "outputs": [],
534
  "source": [
535
- "model_id = \"teticio/latent-audio-diffusion-ddim-256-new\" #@param [\"teticio/latent-audio-diffusion-256\", \"teticio/latent-audio-diffusion-ddim-256\"]"
536
  ]
537
  },
538
  {
 
89
  "\n",
90
  "#@markdown teticio/audio-diffusion-instrumental-hiphop-256 - trained on instrumental hiphop\n",
91
  "\n",
92
+ "model_id = \"teticio/audio-diffusion-256\" #@param [\"teticio/audio-diffusion-256\", \"teticio/audio-diffusion-breaks-256\", \"audio-diffusion-instrumenal-hiphop-256\", \"teticio/audio-diffusion-ddim-256\"]"
93
  ]
94
  },
95
  {
 
356
  "metadata": {},
357
  "outputs": [],
358
  "source": [
359
+ "audio_diffusion = DiffusionPipeline.from_pretrained('teticio/audio-diffusion-ddim-256').to(device)\n",
360
  "mel = audio_diffusion.mel\n",
361
  "sample_rate = mel.get_sample_rate()"
362
  ]
 
532
  "metadata": {},
533
  "outputs": [],
534
  "source": [
535
+ "model_id = \"teticio/latent-audio-diffusion-ddim-256\" #@param [\"teticio/latent-audio-diffusion-256\", \"teticio/latent-audio-diffusion-ddim-256\"]"
536
  ]
537
  },
538
  {
notebooks/test_mel.ipynb CHANGED
@@ -25,18 +25,6 @@
25
  " pass"
26
  ]
27
  },
28
- {
29
- "cell_type": "code",
30
- "execution_count": null,
31
- "id": "21f27189",
32
- "metadata": {},
33
- "outputs": [],
34
- "source": [
35
- "import os\n",
36
- "import sys\n",
37
- "sys.path.insert(0, os.path.dirname(os.path.abspath(\"\")))"
38
- ]
39
- },
40
  {
41
  "cell_type": "code",
42
  "execution_count": null,
@@ -46,7 +34,7 @@
46
  "source": [
47
  "from datasets import load_dataset\n",
48
  "from IPython.display import Audio\n",
49
- "from audiodiffusion.mel import Mel"
50
  ]
51
  },
52
  {
 
25
  " pass"
26
  ]
27
  },
 
 
 
 
 
 
 
 
 
 
 
 
28
  {
29
  "cell_type": "code",
30
  "execution_count": null,
 
34
  "source": [
35
  "from datasets import load_dataset\n",
36
  "from IPython.display import Audio\n",
37
+ "from audiodiffusion import Mel"
38
  ]
39
  },
40
  {
notebooks/test_model.ipynb CHANGED
@@ -49,7 +49,6 @@
49
  "import numpy as np\n",
50
  "from datasets import load_dataset\n",
51
  "from IPython.display import Audio\n",
52
- "from audiodiffusion.mel import Mel\n",
53
  "from audiodiffusion import AudioDiffusion"
54
  ]
55
  },
@@ -60,7 +59,6 @@
60
  "metadata": {},
61
  "outputs": [],
62
  "source": [
63
- "mel = Mel()\n",
64
  "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
65
  "generator = torch.Generator(device=device)"
66
  ]
@@ -104,7 +102,8 @@
104
  "metadata": {},
105
  "outputs": [],
106
  "source": [
107
- "audio_diffusion = AudioDiffusion(model_id=model_id)"
 
108
  ]
109
  },
110
  {
@@ -336,7 +335,8 @@
336
  "metadata": {},
337
  "outputs": [],
338
  "source": [
339
- "audio_diffusion = AudioDiffusion(model_id='teticio/audio-diffusion-ddim-256')"
 
340
  ]
341
  },
342
  {
@@ -507,7 +507,7 @@
507
  "metadata": {},
508
  "outputs": [],
509
  "source": [
510
- "model_id = \"teticio/latent-audio-diffusion-ddim-256\" #@param [\"teticio/latent-audio-diffusion-256\", \"teticio/latent-audio-diffusion-ddim-256\"]"
511
  ]
512
  },
513
  {
@@ -517,7 +517,8 @@
517
  "metadata": {},
518
  "outputs": [],
519
  "source": [
520
- "audio_diffusion = AudioDiffusion(model_id=model_id)"
 
521
  ]
522
  },
523
  {
@@ -568,9 +569,10 @@
568
  "source": [
569
  "generator.manual_seed(seed)\n",
570
  "latents = torch.randn((1, audio_diffusion.pipe.unet.in_channels,\n",
571
- " audio_diffusion.pipe.unet.sample_size[0],\n",
572
- " audio_diffusion.pipe.unet.sample_size[1]),\n",
573
- " generator=generator)\n",
 
574
  "latents.shape"
575
  ]
576
  },
@@ -583,9 +585,10 @@
583
  "source": [
584
  "generator.manual_seed(seed2)\n",
585
  "latents2 = torch.randn((1, audio_diffusion.pipe.unet.in_channels,\n",
586
- " audio_diffusion.pipe.unet.sample_size[0],\n",
587
- " audio_diffusion.pipe.unet.sample_size[1]),\n",
588
- " generator=generator)\n",
 
589
  "latents2.shape"
590
  ]
591
  },
 
49
  "import numpy as np\n",
50
  "from datasets import load_dataset\n",
51
  "from IPython.display import Audio\n",
 
52
  "from audiodiffusion import AudioDiffusion"
53
  ]
54
  },
 
59
  "metadata": {},
60
  "outputs": [],
61
  "source": [
 
62
  "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
63
  "generator = torch.Generator(device=device)"
64
  ]
 
102
  "metadata": {},
103
  "outputs": [],
104
  "source": [
105
+ "audio_diffusion = AudioDiffusion(model_id=model_id)\n",
106
+ "mel = audio_diffusion.pipe.mel"
107
  ]
108
  },
109
  {
 
335
  "metadata": {},
336
  "outputs": [],
337
  "source": [
338
+ "audio_diffusion = AudioDiffusion(model_id='teticio/audio-diffusion-ddim-256')\n",
339
+ "mel = audio_diffusion.pipe.mel"
340
  ]
341
  },
342
  {
 
507
  "metadata": {},
508
  "outputs": [],
509
  "source": [
510
+ "model_id = \"teticio/latent-audio-diffusion-ddim-256-new\" #@param [\"teticio/latent-audio-diffusion-256\", \"teticio/latent-audio-diffusion-ddim-256\"]"
511
  ]
512
  },
513
  {
 
517
  "metadata": {},
518
  "outputs": [],
519
  "source": [
520
+ "audio_diffusion = AudioDiffusion(model_id=model_id)\n",
521
+ "mel = audio_diffusion.pipe.mel"
522
  ]
523
  },
524
  {
 
569
  "source": [
570
  "generator.manual_seed(seed)\n",
571
  "latents = torch.randn((1, audio_diffusion.pipe.unet.in_channels,\n",
572
+ " audio_diffusion.pipe.unet.sample_size[0],\n",
573
+ " audio_diffusion.pipe.unet.sample_size[1]),\n",
574
+ " device=device,\n",
575
+ " generator=generator)\n",
576
  "latents.shape"
577
  ]
578
  },
 
585
  "source": [
586
  "generator.manual_seed(seed2)\n",
587
  "latents2 = torch.randn((1, audio_diffusion.pipe.unet.in_channels,\n",
588
+ " audio_diffusion.pipe.unet.sample_size[0],\n",
589
+ " audio_diffusion.pipe.unet.sample_size[1]),\n",
590
+ " device=device,\n",
591
+ " generator=generator)\n",
592
  "latents2.shape"
593
  ]
594
  },
notebooks/test_vae.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
requirements.txt CHANGED
@@ -1,7 +1,7 @@
1
  torch
2
  numpy
3
  Pillow
4
- diffusers>=0.4.1
5
  librosa
6
  datasets
7
  gradio
 
1
  torch
2
  numpy
3
  Pillow
4
+ diffusers>=0.9.0
5
  librosa
6
  datasets
7
  gradio
scripts/audio_to_images.py CHANGED
@@ -9,7 +9,7 @@ import pandas as pd
9
  from tqdm.auto import tqdm
10
  from datasets import Dataset, DatasetDict, Features, Image, Value
11
 
12
- from audiodiffusion.mel import Mel
13
 
14
  logging.basicConfig(level=logging.WARN)
15
  logger = logging.getLogger('audio_to_images')
 
9
  from tqdm.auto import tqdm
10
  from datasets import Dataset, DatasetDict, Features, Image, Value
11
 
12
+ from audiodiffusion import Mel
13
 
14
  logging.basicConfig(level=logging.WARN)
15
  logger = logging.getLogger('audio_to_images')
scripts/train_unconditional.py CHANGED
@@ -1,5 +1,9 @@
1
  # based on https://github.com/huggingface/diffusers/blob/main/examples/train_unconditional.py
2
 
 
 
 
 
3
  import argparse
4
  import os
5
 
@@ -9,8 +13,13 @@ import torch.nn.functional as F
9
  from accelerate import Accelerator
10
  from accelerate.logging import get_logger
11
  from datasets import load_from_disk, load_dataset
12
- from diffusers import (DiffusionPipeline, DDPMScheduler, UNet2DModel,
13
- DDIMScheduler, AutoencoderKL)
 
 
 
 
 
14
  from diffusers.hub_utils import init_git_repo, push_to_hub
15
  from diffusers.optimization import get_scheduler
16
  from diffusers.training_utils import EMAModel
@@ -23,8 +32,8 @@ import numpy as np
23
  from tqdm.auto import tqdm
24
  from librosa.util import normalize
25
 
26
- from audiodiffusion.mel import Mel
27
- from audiodiffusion import LatentAudioDiffusionPipeline, AudioDiffusionPipeline
28
 
29
  logger = get_logger(__name__)
30
 
@@ -59,7 +68,7 @@ def main(args):
59
  split="train",
60
  )
61
  # Determine image resolution
62
- resolution = dataset[0]['image'].height, dataset[0]['image'].width
63
 
64
  augmentations = Compose([
65
  ToTensor(),
@@ -67,9 +76,9 @@ def main(args):
67
  ])
68
 
69
  def transforms(examples):
70
- if args.vae is not None and vqvae.config['in_channels'] == 3:
71
  images = [
72
- augmentations(image.convert('RGB'))
73
  for image in examples["image"]
74
  ]
75
  else:
@@ -85,32 +94,27 @@ def main(args):
85
  try:
86
  vqvae = AutoencoderKL.from_pretrained(args.vae)
87
  except EnvironmentError:
88
- vqvae = LatentAudioDiffusionPipeline.from_pretrained(
89
  args.vae).vqvae
90
  # Determine latent resolution
91
  with torch.no_grad():
92
- latent_resolution = vqvae.encode(
93
  torch.zeros((1, 1) +
94
- resolution)).latent_dist.sample().shape[2:]
95
 
96
  if args.from_pretrained is not None:
97
- pipeline = {
98
- 'LatentAudioDiffusionPipeline': LatentAudioDiffusionPipeline,
99
- 'AudioDiffusionPipeline': AudioDiffusionPipeline
100
- }.get(
101
- DiffusionPipeline.get_config_dict(
102
- args.from_pretrained)['_class_name'], AudioDiffusionPipeline)
103
- pipeline = pipeline.from_pretrained(args.from_pretrained)
104
  model = pipeline.unet
105
- if hasattr(pipeline, 'vqvae'):
106
  vqvae = pipeline.vqvae
107
  else:
108
  model = UNet2DModel(
109
  sample_size=resolution if vqvae is None else latent_resolution,
110
  in_channels=1
111
- if vqvae is None else vqvae.config['latent_channels'],
112
  out_channels=1
113
- if vqvae is None else vqvae.config['latent_channels'],
114
  layers_per_block=2,
115
  block_out_channels=(128, 128, 256, 256, 512, 512),
116
  down_block_types=(
@@ -171,11 +175,13 @@ def main(args):
171
  run = os.path.split(__file__)[-1].split(".")[0]
172
  accelerator.init_trackers(run)
173
 
174
- mel = Mel(x_res=resolution[1],
175
- y_res=resolution[0],
176
- hop_length=args.hop_length,
177
- sample_rate=args.sample_rate,
178
- n_fft=args.n_fft)
 
 
179
 
180
  global_step = 0
181
  for epoch in range(args.num_epochs):
@@ -256,20 +262,14 @@ def main(args):
256
  if (
257
  epoch + 1
258
  ) % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
259
- if vqvae is not None:
260
- pipeline = LatentAudioDiffusionPipeline(
261
- unet=accelerator.unwrap_model(
262
- ema_model.averaged_model if args.use_ema else model
263
- ),
264
- vqvae=vqvae,
265
- scheduler=noise_scheduler)
266
- else:
267
- pipeline = AudioDiffusionPipeline(
268
- unet=accelerator.unwrap_model(
269
- ema_model.averaged_model if args.use_ema else model
270
- ),
271
- scheduler=noise_scheduler,
272
- )
273
 
274
  # save the model
275
  if args.push_to_hub:
@@ -287,12 +287,13 @@ def main(args):
287
  pipeline.save_pretrained(output_dir)
288
 
289
  if (epoch + 1) % args.save_images_epochs == 0:
290
- generator = torch.manual_seed(42)
 
291
  # run pipeline in inference (sample random noise and denoise)
292
  images, (sample_rate, audios) = pipeline(
293
- mel=mel,
294
  generator=generator,
295
  batch_size=args.eval_batch_size,
 
296
  )
297
 
298
  # denormalize the images and save to tensorboard
@@ -373,10 +374,12 @@ if __name__ == "__main__":
373
  type=str,
374
  default="ddpm",
375
  help="ddpm or ddim")
376
- parser.add_argument("--vae",
377
- type=str,
378
- default=None,
379
- help="pretrained VAE model for latent diffusion")
 
 
380
 
381
  args = parser.parse_args()
382
  env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
 
1
  # based on https://github.com/huggingface/diffusers/blob/main/examples/train_unconditional.py
2
 
3
+ # TODO
4
+ # Migrate to diffusers
5
+ # from diffusers.hub_utils import Repository
6
+
7
  import argparse
8
  import os
9
 
 
13
  from accelerate import Accelerator
14
  from accelerate.logging import get_logger
15
  from datasets import load_from_disk, load_dataset
16
+ from diffusers import (
17
+ #AudioDiffusionPipeline,
18
+ DDPMScheduler,
19
+ UNet2DModel,
20
+ DDIMScheduler,
21
+ AutoencoderKL,
22
+ )
23
  from diffusers.hub_utils import init_git_repo, push_to_hub
24
  from diffusers.optimization import get_scheduler
25
  from diffusers.training_utils import EMAModel
 
32
  from tqdm.auto import tqdm
33
  from librosa.util import normalize
34
 
35
+ #from diffusers import Mel, AudioDiffusionPipeline
36
+ from audiodiffusion import Mel, AudioDiffusionPipeline
37
 
38
  logger = get_logger(__name__)
39
 
 
68
  split="train",
69
  )
70
  # Determine image resolution
71
+ resolution = dataset[0]["image"].height, dataset[0]["image"].width
72
 
73
  augmentations = Compose([
74
  ToTensor(),
 
76
  ])
77
 
78
  def transforms(examples):
79
+ if args.vae is not None and vqvae.config["in_channels"] == 3:
80
  images = [
81
+ augmentations(image.convert("RGB"))
82
  for image in examples["image"]
83
  ]
84
  else:
 
94
  try:
95
  vqvae = AutoencoderKL.from_pretrained(args.vae)
96
  except EnvironmentError:
97
+ vqvae = AudioDiffusionPipeline.from_pretrained(
98
  args.vae).vqvae
99
  # Determine latent resolution
100
  with torch.no_grad():
101
+ latent_resolution = (vqvae.encode(
102
  torch.zeros((1, 1) +
103
+ resolution)).latent_dist.sample().shape[2:])
104
 
105
  if args.from_pretrained is not None:
106
+ pipeline = AudioDiffusionPipeline.from_pretrained(args.from_pretrained)
107
+ mel = pipeline.mel
 
 
 
 
 
108
  model = pipeline.unet
109
+ if hasattr(pipeline, "vqvae"):
110
  vqvae = pipeline.vqvae
111
  else:
112
  model = UNet2DModel(
113
  sample_size=resolution if vqvae is None else latent_resolution,
114
  in_channels=1
115
+ if vqvae is None else vqvae.config["latent_channels"],
116
  out_channels=1
117
+ if vqvae is None else vqvae.config["latent_channels"],
118
  layers_per_block=2,
119
  block_out_channels=(128, 128, 256, 256, 512, 512),
120
  down_block_types=(
 
175
  run = os.path.split(__file__)[-1].split(".")[0]
176
  accelerator.init_trackers(run)
177
 
178
+ mel = Mel(
179
+ x_res=resolution[1],
180
+ y_res=resolution[0],
181
+ hop_length=args.hop_length,
182
+ sample_rate=args.sample_rate,
183
+ n_fft=args.n_fft,
184
+ )
185
 
186
  global_step = 0
187
  for epoch in range(args.num_epochs):
 
262
  if (
263
  epoch + 1
264
  ) % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
265
+ pipeline = AudioDiffusionPipeline(
266
+ vqvae=vqvae,
267
+ unet=accelerator.unwrap_model(
268
+ ema_model.averaged_model if args.use_ema else model
269
+ ),
270
+ mel=mel,
271
+ scheduler=noise_scheduler,
272
+ )
 
 
 
 
 
 
273
 
274
  # save the model
275
  if args.push_to_hub:
 
287
  pipeline.save_pretrained(output_dir)
288
 
289
  if (epoch + 1) % args.save_images_epochs == 0:
290
+ generator = torch.Generator(
291
+ device=clean_images.device).manual_seed(42)
292
  # run pipeline in inference (sample random noise and denoise)
293
  images, (sample_rate, audios) = pipeline(
 
294
  generator=generator,
295
  batch_size=args.eval_batch_size,
296
+ return_dict=False
297
  )
298
 
299
  # denormalize the images and save to tensorboard
 
374
  type=str,
375
  default="ddpm",
376
  help="ddpm or ddim")
377
+ parser.add_argument(
378
+ "--vae",
379
+ type=str,
380
+ default=None,
381
+ help="pretrained VAE model for latent diffusion",
382
+ )
383
 
384
  args = parser.parse_args()
385
  env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
scripts/train_vae.py CHANGED
@@ -17,7 +17,8 @@ from datasets import load_from_disk, load_dataset
17
  from pytorch_lightning.callbacks import Callback, ModelCheckpoint
18
  from pytorch_lightning.utilities.distributed import rank_zero_only
19
 
20
- from audiodiffusion.mel import Mel
 
21
  from audiodiffusion.utils import convert_ldm_to_hf_vae
22
 
23
 
 
17
  from pytorch_lightning.callbacks import Callback, ModelCheckpoint
18
  from pytorch_lightning.utilities.distributed import rank_zero_only
19
 
20
+ #from diffusers import Mel
21
+ from audiodiffusion import Mel
22
  from audiodiffusion.utils import convert_ldm_to_hf_vae
23
 
24
 
setup.cfg CHANGED
@@ -15,6 +15,6 @@ install_requires =
15
  torch
16
  numpy
17
  Pillow
18
- diffusers>=0.4.1
19
  librosa
20
  datasets
 
15
  torch
16
  numpy
17
  Pillow
18
+ diffusers>=0.9.0
19
  librosa
20
  datasets