teticio committed
Commit: 59e49d0
Parent: b981ece

fixes for latent diffusion

Files changed (1): audiodiffusion/__init__.py (+36 -17)
audiodiffusion/__init__.py CHANGED
@@ -11,7 +11,7 @@ from diffusers import (DiffusionPipeline, UNet2DConditionModel, DDIMScheduler,
 
 from .mel import Mel
 
-VERSION = "1.2.5"
+VERSION = "1.2.6"
 
 
 class AudioDiffusion:
@@ -47,11 +47,7 @@ class AudioDiffusion:
         self.pipe.to("cuda")
         self.progress_bar = progress_bar or (lambda _: _)
 
-        # For backwards compatibility
-        sample_size = (self.pipe.unet.sample_size,
-                       self.pipe.unet.sample_size) if type(
-                           self.pipe.unet.sample_size
-                       ) == int else self.pipe.unet.sample_size
+        sample_size = self.pipe.get_input_dims()
         self.mel = Mel(x_res=sample_size[1],
                        y_res=sample_size[0],
                        sample_rate=sample_rate,
@@ -168,6 +164,27 @@ class AudioDiffusionPipeline(DiffusionPipeline):
         super().__init__()
         self.register_modules(unet=unet, scheduler=scheduler)
 
+    def get_input_dims(self) -> Tuple:
+        """Returns dimension of input image
+
+        Returns:
+            Tuple: (height, width)
+        """
+        input_module = self.vqvae if hasattr(self, 'vqvae') else self.unet
+        # For backwards compatibility
+        sample_size = (
+            input_module.sample_size, input_module.sample_size) if type(
+                input_module.sample_size) == int else input_module.sample_size
+        return sample_size
+
+    def get_default_steps(self) -> int:
+        """Returns default number of steps recommended for inference
+
+        Returns:
+            int: number of steps
+        """
+        return 50 if isinstance(self.scheduler, DDIMScheduler) else 1000
+
     @torch.no_grad()
     def __call__(
         self,
@@ -207,8 +224,7 @@ class AudioDiffusionPipeline(DiffusionPipeline):
             (float, List[np.ndarray]): sample rate and raw audios
         """
 
-        steps = steps or 50 if isinstance(self.scheduler,
-                                          DDIMScheduler) else 1000
+        steps = steps or self.get_default_steps()
         self.scheduler.set_timesteps(steps)
         step_generator = step_generator or generator
         # For backwards compatibility
@@ -231,17 +247,21 @@ class AudioDiffusionPipeline(DiffusionPipeline):
                     (input_image.height,
                      input_image.width))
             input_image = ((input_image / 255) * 2 - 1)
-            input_images = np.tile(input_image, (batch_size, 1, 1, 1))
+            input_images = torch.tensor(input_image[np.newaxis, :, :],
+                                        dtype=torch.float)
 
             if hasattr(self, 'vqvae'):
                 input_images = self.vqvae.encode(
-                    input_images).latent_dist.sample(generator=generator)
+                    torch.unsqueeze(input_images,
+                                    0).to(self.device)).latent_dist.sample(
+                                        generator=generator).cpu()[0]
                 input_images = 0.18215 * input_images
 
             if start_step > 0:
                 images[0, 0] = self.scheduler.add_noise(
-                    torch.tensor(input_images[:, np.newaxis, np.newaxis, :]),
-                    noise, torch.tensor(steps - start_step))
+                    input_images, noise,
+                    self.scheduler.timesteps[start_step - 1])
+                print(self.scheduler.timesteps[start_step - 1])
 
             pixels_per_second = (self.unet.sample_size[1] *
                                  mel.get_sample_rate() / mel.x_res /
@@ -249,7 +269,7 @@ class AudioDiffusionPipeline(DiffusionPipeline):
             mask_start = int(mask_start_secs * pixels_per_second)
             mask_end = int(mask_end_secs * pixels_per_second)
             mask = self.scheduler.add_noise(
-                torch.tensor(input_images[:, np.newaxis, :]), noise,
+                input_images, noise,
                 torch.tensor(self.scheduler.timesteps[start_step:]))
 
         images = images.to(self.device)
@@ -273,11 +293,10 @@ class AudioDiffusionPipeline(DiffusionPipeline):
 
             if mask is not None:
                 if mask_start > 0:
-                    images[:, :, :, :mask_start] = mask[
-                        step, :, :, :, :mask_start]
+                    images[:, :, :, :mask_start] = mask[:,
+                                                        step, :, :mask_start]
                 if mask_end > 0:
-                    images[:, :, :, -mask_end:] = mask[step, :, :, :,
-                                                       -mask_end:]
+                    images[:, :, :, -mask_end:] = mask[:, step, :, -mask_end:]
 
         if hasattr(self, 'vqvae'):
             # 0.18215 was scaling factor used in training to ensure unit variance
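
For context, a minimal sketch of how the two helpers introduced in this commit might be exercised. It assumes the audiodiffusion package from this repo is installed; the model id is illustrative:

# Minimal usage sketch (model id illustrative; assumes this package is
# installed and the checkpoint exists on the Hugging Face Hub).
from audiodiffusion import AudioDiffusion

audio_diffusion = AudioDiffusion(model_id="teticio/audio-diffusion-256")
pipe = audio_diffusion.pipe

# (height, width) of the model input; for latent diffusion this is now
# taken from the VQ-VAE rather than the UNet, which is the point of
# get_input_dims().
height, width = pipe.get_input_dims()

# 50 for DDIM schedulers, 1000 otherwise.
steps = pipe.get_default_steps()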
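The substance of the fix is the VQ-VAE latent path: a single mel image (rather than a batch tile) is encoded, scaled by 0.18215 to match the unit-variance convention from training, noised to the chosen start step, and decoded again at the end. A rough sketch of that round trip, assuming a latent checkpoint whose pipeline registers a vqvae with the diffusers AutoencoderKL interface (the model id is illustrative):

import torch
from audiodiffusion import AudioDiffusion

# Illustrative latent checkpoint; any pipeline that registers a vqvae works.
pipe = AudioDiffusion(model_id="teticio/latent-audio-diffusion-256").pipe

# Dummy single-channel mel image scaled to [-1, 1], as in __call__ above.
image = (torch.rand(1, 1, 256, 256) * 2 - 1).to(pipe.device)

latents = pipe.vqvae.encode(image).latent_dist.sample()
latents = 0.18215 * latents          # scaling factor used in training

# ... denoising would run here, in latent space ...

decoded = pipe.vqvae.decode(latents / 0.18215).sample  # back to a mel image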
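Finally, the mask bookkeeping converts seconds of audio into spectrogram columns via pixels_per_second = width * sample_rate / x_res / hop_length, as in the diff above. A quick worked check, assuming this repo's usual Mel defaults (sample_rate=22050, hop_length=512, x_res=256) and a 256-pixel-wide model input:

# Worked check of the seconds -> spectrogram-columns conversion used for
# masking. Parameter values are assumed defaults, not read from a model.
sample_rate, hop_length, x_res, width = 22050, 512, 256, 256

pixels_per_second = width * sample_rate / x_res / hop_length
print(pixels_per_second)                    # ~43.07 columns per second

mask_start = int(1.0 * pixels_per_second)   # mask_start_secs=1.0 -> 43 columns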