SeyedAli commited on
Commit
f4fa8f9
·
1 Parent(s): 621f018

Upload 3 files

Browse files
audiodiffusion/__init__.py ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from math import acos, sin
2
+ from typing import Iterable, Tuple, Union, List
3
+
4
+ import torch
5
+ import numpy as np
6
+ from PIL import Image
7
+ from tqdm.auto import tqdm
8
+ from librosa.beat import beat_track
9
+ from diffusers import (DiffusionPipeline, UNet2DConditionModel, DDIMScheduler,
10
+ DDPMScheduler, AutoencoderKL)
11
+
12
+ from .mel import Mel
13
+
14
+ VERSION = "1.2.5"
15
+
16
+
17
+ class AudioDiffusion:
18
+
19
+ def __init__(self,
20
+ model_id: str = "teticio/audio-diffusion-256",
21
+ sample_rate: int = 22050,
22
+ n_fft: int = 2048,
23
+ hop_length: int = 512,
24
+ top_db: int = 80,
25
+ cuda: bool = torch.cuda.is_available(),
26
+ progress_bar: Iterable = tqdm):
27
+ """Class for generating audio using De-noising Diffusion Probabilistic Models.
28
+
29
+ Args:
30
+ model_id (String): name of model (local directory or Hugging Face Hub)
31
+ sample_rate (int): sample rate of audio
32
+ n_fft (int): number of Fast Fourier Transforms
33
+ hop_length (int): hop length (a higher number is recommended for lower than 256 y_res)
34
+ top_db (int): loudest in decibels
35
+ cuda (bool): use CUDA?
36
+ progress_bar (iterable): iterable callback for progress updates or None
37
+ """
38
+ self.model_id = model_id
39
+ pipeline = {
40
+ 'LatentAudioDiffusionPipeline': LatentAudioDiffusionPipeline,
41
+ 'AudioDiffusionPipeline': AudioDiffusionPipeline
42
+ }.get(
43
+ DiffusionPipeline.get_config_dict(self.model_id)['_class_name'],
44
+ AudioDiffusionPipeline)
45
+ self.pipe = pipeline.from_pretrained(self.model_id)
46
+ if cuda:
47
+ self.pipe.to("cuda")
48
+ self.progress_bar = progress_bar or (lambda _: _)
49
+
50
+ # For backwards compatibility
51
+ sample_size = (self.pipe.unet.sample_size,
52
+ self.pipe.unet.sample_size) if type(
53
+ self.pipe.unet.sample_size
54
+ ) == int else self.pipe.unet.sample_size
55
+ self.mel = Mel(x_res=sample_size[1],
56
+ y_res=sample_size[0],
57
+ sample_rate=sample_rate,
58
+ n_fft=n_fft,
59
+ hop_length=hop_length,
60
+ top_db=top_db)
61
+
62
+ def generate_spectrogram_and_audio(
63
+ self,
64
+ steps: int = None,
65
+ generator: torch.Generator = None,
66
+ step_generator: torch.Generator = None,
67
+ eta: float = 0,
68
+ noise: torch.Tensor = None
69
+ ) -> Tuple[Image.Image, Tuple[int, np.ndarray]]:
70
+ """Generate random mel spectrogram and convert to audio.
71
+
72
+ Args:
73
+ steps (int): number of de-noising steps (defaults to 50 for DDIM, 1000 for DDPM)
74
+ generator (torch.Generator): random number generator or None
75
+ step_generator (torch.Generator): random number generator used to de-noise or None
76
+ eta (float): parameter between 0 and 1 used with DDIM scheduler
77
+ noise (torch.Tensor): noisy image or None
78
+
79
+ Returns:
80
+ PIL Image: mel spectrogram
81
+ (float, np.ndarray): sample rate and raw audio
82
+ """
83
+ images, (sample_rate,
84
+ audios) = self.pipe(mel=self.mel,
85
+ batch_size=1,
86
+ steps=steps,
87
+ generator=generator,
88
+ step_generator=step_generator,
89
+ eta=eta,
90
+ noise=noise)
91
+ return images[0], (sample_rate, audios[0])
92
+
93
+ def generate_spectrogram_and_audio_from_audio(
94
+ self,
95
+ audio_file: str = None,
96
+ raw_audio: np.ndarray = None,
97
+ slice: int = 0,
98
+ start_step: int = 0,
99
+ steps: int = None,
100
+ generator: torch.Generator = None,
101
+ mask_start_secs: float = 0,
102
+ mask_end_secs: float = 0,
103
+ step_generator: torch.Generator = None,
104
+ eta: float = 0,
105
+ noise: torch.Tensor = None
106
+ ) -> Tuple[Image.Image, Tuple[int, np.ndarray]]:
107
+ """Generate random mel spectrogram from audio input and convert to audio.
108
+
109
+ Args:
110
+ audio_file (str): must be a file on disk due to Librosa limitation or
111
+ raw_audio (np.ndarray): audio as numpy array
112
+ slice (int): slice number of audio to convert
113
+ start_step (int): step to start from
114
+ steps (int): number of de-noising steps (defaults to 50 for DDIM, 1000 for DDPM)
115
+ generator (torch.Generator): random number generator or None
116
+ mask_start_secs (float): number of seconds of audio to mask (not generate) at start
117
+ mask_end_secs (float): number of seconds of audio to mask (not generate) at end
118
+ step_generator (torch.Generator): random number generator used to de-noise or None
119
+ eta (float): parameter between 0 and 1 used with DDIM scheduler
120
+ noise (torch.Tensor): noisy image or None
121
+
122
+ Returns:
123
+ PIL Image: mel spectrogram
124
+ (float, np.ndarray): sample rate and raw audio
125
+ """
126
+
127
+ images, (sample_rate,
128
+ audios) = self.pipe(mel=self.mel,
129
+ batch_size=1,
130
+ audio_file=audio_file,
131
+ raw_audio=raw_audio,
132
+ slice=slice,
133
+ start_step=start_step,
134
+ steps=steps,
135
+ generator=generator,
136
+ mask_start_secs=mask_start_secs,
137
+ mask_end_secs=mask_end_secs,
138
+ step_generator=step_generator,
139
+ eta=eta,
140
+ noise=noise)
141
+ return images[0], (sample_rate, audios[0])
142
+
143
+ @staticmethod
144
+ def loop_it(audio: np.ndarray,
145
+ sample_rate: int,
146
+ loops: int = 12) -> np.ndarray:
147
+ """Loop audio
148
+
149
+ Args:
150
+ audio (np.ndarray): audio as numpy array
151
+ sample_rate (int): sample rate of audio
152
+ loops (int): number of times to loop
153
+
154
+ Returns:
155
+ (float, np.ndarray): sample rate and raw audio or None
156
+ """
157
+ _, beats = beat_track(y=audio, sr=sample_rate, units='samples')
158
+ for beats_in_bar in [16, 12, 8, 4]:
159
+ if len(beats) > beats_in_bar:
160
+ return np.tile(audio[beats[0]:beats[beats_in_bar]], loops)
161
+ return None
162
+
163
+
164
+ class AudioDiffusionPipeline(DiffusionPipeline):
165
+
166
+ def __init__(self, unet: UNet2DConditionModel,
167
+ scheduler: Union[DDIMScheduler, DDPMScheduler]):
168
+ super().__init__()
169
+ self.register_modules(unet=unet, scheduler=scheduler)
170
+
171
+ @torch.no_grad()
172
+ def __call__(
173
+ self,
174
+ mel: Mel,
175
+ batch_size: int = 1,
176
+ audio_file: str = None,
177
+ raw_audio: np.ndarray = None,
178
+ slice: int = 0,
179
+ start_step: int = 0,
180
+ steps: int = None,
181
+ generator: torch.Generator = None,
182
+ mask_start_secs: float = 0,
183
+ mask_end_secs: float = 0,
184
+ step_generator: torch.Generator = None,
185
+ eta: float = 0,
186
+ noise: torch.Tensor = None
187
+ ) -> Tuple[List[Image.Image], Tuple[int, List[np.ndarray]]]:
188
+ """Generate random mel spectrogram from audio input and convert to audio.
189
+
190
+ Args:
191
+ mel (Mel): instance of Mel class to perform image <-> audio
192
+ batch_size (int): number of samples to generate
193
+ audio_file (str): must be a file on disk due to Librosa limitation or
194
+ raw_audio (np.ndarray): audio as numpy array
195
+ slice (int): slice number of audio to convert
196
+ start_step (int): step to start from
197
+ steps (int): number of de-noising steps (defaults to 50 for DDIM, 1000 for DDPM)
198
+ generator (torch.Generator): random number generator or None
199
+ mask_start_secs (float): number of seconds of audio to mask (not generate) at start
200
+ mask_end_secs (float): number of seconds of audio to mask (not generate) at end
201
+ step_generator (torch.Generator): random number generator used to de-noise or None
202
+ eta (float): parameter between 0 and 1 used with DDIM scheduler
203
+ noise (torch.Tensor): noise tensor of shape (batch_size, 1, height, width) or None
204
+
205
+ Returns:
206
+ List[PIL Image]: mel spectrograms
207
+ (float, List[np.ndarray]): sample rate and raw audios
208
+ """
209
+
210
+ steps = steps or 50 if isinstance(self.scheduler,
211
+ DDIMScheduler) else 1000
212
+ self.scheduler.set_timesteps(steps)
213
+ step_generator = step_generator or generator
214
+ # For backwards compatibility
215
+ if type(self.unet.sample_size) == int:
216
+ self.unet.sample_size = (self.unet.sample_size,
217
+ self.unet.sample_size)
218
+ if noise is None:
219
+ noise = torch.randn(
220
+ (batch_size, self.unet.in_channels, self.unet.sample_size[0],
221
+ self.unet.sample_size[1]),
222
+ generator=generator)
223
+ images = noise
224
+ mask = None
225
+
226
+ if audio_file is not None or raw_audio is not None:
227
+ mel.load_audio(audio_file, raw_audio)
228
+ input_image = mel.audio_slice_to_image(slice)
229
+ input_image = np.frombuffer(input_image.tobytes(),
230
+ dtype="uint8").reshape(
231
+ (input_image.height,
232
+ input_image.width))
233
+ input_image = ((input_image / 255) * 2 - 1)
234
+ input_images = np.tile(input_image, (batch_size, 1, 1, 1))
235
+
236
+ if hasattr(self, 'vqvae'):
237
+ input_images = self.vqvae.encode(
238
+ input_images).latent_dist.sample(generator=generator)
239
+ input_images = 0.18215 * input_images
240
+
241
+ if start_step > 0:
242
+ images[0, 0] = self.scheduler.add_noise(
243
+ torch.tensor(input_images[:, np.newaxis, np.newaxis, :]),
244
+ noise, torch.tensor(steps - start_step))
245
+
246
+ pixels_per_second = (self.unet.sample_size[1] *
247
+ mel.get_sample_rate() / mel.x_res /
248
+ mel.hop_length)
249
+ mask_start = int(mask_start_secs * pixels_per_second)
250
+ mask_end = int(mask_end_secs * pixels_per_second)
251
+ mask = self.scheduler.add_noise(
252
+ torch.tensor(input_images[:, np.newaxis, :]), noise,
253
+ torch.tensor(self.scheduler.timesteps[start_step:]))
254
+
255
+ images = images.to(self.device)
256
+ for step, t in enumerate(
257
+ self.progress_bar(self.scheduler.timesteps[start_step:])):
258
+ model_output = self.unet(images, t)['sample']
259
+
260
+ if isinstance(self.scheduler, DDIMScheduler):
261
+ images = self.scheduler.step(
262
+ model_output=model_output,
263
+ timestep=t,
264
+ sample=images,
265
+ eta=eta,
266
+ generator=step_generator)['prev_sample']
267
+ else:
268
+ images = self.scheduler.step(
269
+ model_output=model_output,
270
+ timestep=t,
271
+ sample=images,
272
+ generator=step_generator)['prev_sample']
273
+
274
+ if mask is not None:
275
+ if mask_start > 0:
276
+ images[:, :, :, :mask_start] = mask[
277
+ step, :, :, :, :mask_start]
278
+ if mask_end > 0:
279
+ images[:, :, :, -mask_end:] = mask[step, :, :, :,
280
+ -mask_end:]
281
+
282
+ if hasattr(self, 'vqvae'):
283
+ # 0.18215 was scaling factor used in training to ensure unit variance
284
+ images = 1 / 0.18215 * images
285
+ images = self.vqvae.decode(images)['sample']
286
+
287
+ images = (images / 2 + 0.5).clamp(0, 1)
288
+ images = images.cpu().permute(0, 2, 3, 1).numpy()
289
+ images = (images * 255).round().astype("uint8")
290
+ images = list(
291
+ map(lambda _: Image.fromarray(_[:, :, 0]), images) if images.
292
+ shape[3] == 1 else map(
293
+ lambda _: Image.fromarray(_, mode='RGB').convert('L'), images))
294
+
295
+ audios = list(map(lambda _: mel.image_to_audio(_), images))
296
+ return images, (mel.get_sample_rate(), audios)
297
+
298
+ @torch.no_grad()
299
+ def encode(self, images: List[Image.Image], steps: int = 50) -> np.ndarray:
300
+ """Reverse step process: recover noisy image from generated image.
301
+
302
+ Args:
303
+ images (List[PIL Image]): list of images to encode
304
+ steps (int): number of encoding steps to perform (defaults to 50)
305
+
306
+ Returns:
307
+ np.ndarray: noise tensor of shape (batch_size, 1, height, width)
308
+ """
309
+
310
+ # Only works with DDIM as this method is deterministic
311
+ assert isinstance(self.scheduler, DDIMScheduler)
312
+ self.scheduler.set_timesteps(steps)
313
+ sample = np.array([
314
+ np.frombuffer(image.tobytes(), dtype="uint8").reshape(
315
+ (1, image.height, image.width)) for image in images
316
+ ])
317
+ sample = ((sample / 255) * 2 - 1)
318
+ sample = torch.Tensor(sample).to(self.device)
319
+
320
+ for t in self.progress_bar(torch.flip(self.scheduler.timesteps,
321
+ (0, ))):
322
+ prev_timestep = (t - self.scheduler.num_train_timesteps //
323
+ self.scheduler.num_inference_steps)
324
+ alpha_prod_t = self.scheduler.alphas_cumprod[t]
325
+ alpha_prod_t_prev = (self.scheduler.alphas_cumprod[prev_timestep]
326
+ if prev_timestep >= 0 else
327
+ self.scheduler.final_alpha_cumprod)
328
+ beta_prod_t = 1 - alpha_prod_t
329
+ model_output = self.unet(sample, t)['sample']
330
+ pred_sample_direction = (1 -
331
+ alpha_prod_t_prev)**(0.5) * model_output
332
+ sample = (sample -
333
+ pred_sample_direction) * alpha_prod_t_prev**(-0.5)
334
+ sample = sample * alpha_prod_t**(0.5) + beta_prod_t**(
335
+ 0.5) * model_output
336
+
337
+ return sample
338
+
339
+ @staticmethod
340
+ def slerp(x0: torch.Tensor, x1: torch.Tensor,
341
+ alpha: float) -> torch.Tensor:
342
+ """Spherical Linear intERPolation
343
+
344
+ Args:
345
+ x0 (torch.Tensor): first tensor to interpolate between
346
+ x1 (torch.Tensor): seconds tensor to interpolate between
347
+ alpha (float): interpolation between 0 and 1
348
+
349
+ Returns:
350
+ torch.Tensor: interpolated tensor
351
+ """
352
+
353
+ theta = acos(
354
+ torch.dot(torch.flatten(x0), torch.flatten(x1)) / torch.norm(x0) /
355
+ torch.norm(x1))
356
+ return sin((1 - alpha) * theta) * x0 / sin(theta) + sin(
357
+ alpha * theta) * x1 / sin(theta)
358
+
359
+
360
+ class LatentAudioDiffusionPipeline(AudioDiffusionPipeline):
361
+
362
+ def __init__(self, unet: UNet2DConditionModel,
363
+ scheduler: Union[DDIMScheduler,
364
+ DDPMScheduler], vqvae: AutoencoderKL):
365
+ super().__init__(unet=unet, scheduler=scheduler)
366
+ self.register_modules(vqvae=vqvae)
367
+
368
+ def __call__(self, *args, **kwargs):
369
+ return super().__call__(*args, **kwargs)
audiodiffusion/mel.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+
3
+ warnings.filterwarnings('ignore')
4
+
5
+ import librosa
6
+ import numpy as np
7
+ from PIL import Image
8
+
9
+
10
+ class Mel:
11
+
12
+ def __init__(
13
+ self,
14
+ x_res: int = 256,
15
+ y_res: int = 256,
16
+ sample_rate: int = 22050,
17
+ n_fft: int = 2048,
18
+ hop_length: int = 512,
19
+ top_db: int = 80,
20
+ ):
21
+ """Class to convert audio to mel spectrograms and vice versa.
22
+
23
+ Args:
24
+ x_res (int): x resolution of spectrogram (time)
25
+ y_res (int): y resolution of spectrogram (frequency bins)
26
+ sample_rate (int): sample rate of audio
27
+ n_fft (int): number of Fast Fourier Transforms
28
+ hop_length (int): hop length (a higher number is recommended for lower than 256 y_res)
29
+ top_db (int): loudest in decibels
30
+ """
31
+ self.x_res = x_res
32
+ self.y_res = y_res
33
+ self.sr = sample_rate
34
+ self.n_fft = n_fft
35
+ self.hop_length = hop_length
36
+ self.n_mels = self.y_res
37
+ self.slice_size = self.x_res * self.hop_length - 1
38
+ self.fmax = self.sr / 2
39
+ self.top_db = top_db
40
+ self.audio = None
41
+
42
+ def load_audio(self, audio_file: str = None, raw_audio: np.ndarray = None):
43
+ """Load audio.
44
+
45
+ Args:
46
+ audio_file (str): must be a file on disk due to Librosa limitation or
47
+ raw_audio (np.ndarray): audio as numpy array
48
+ """
49
+ if audio_file is not None:
50
+ self.audio, _ = librosa.load(audio_file, mono=True, sr=self.sr)
51
+ else:
52
+ self.audio = raw_audio
53
+
54
+ # Pad with silence if necessary.
55
+ if len(self.audio) < self.x_res * self.hop_length:
56
+ self.audio = np.concatenate([
57
+ self.audio,
58
+ np.zeros((self.x_res * self.hop_length - len(self.audio), ))
59
+ ])
60
+
61
+ def get_number_of_slices(self) -> int:
62
+ """Get number of slices in audio.
63
+
64
+ Returns:
65
+ int: number of spectograms audio can be sliced into
66
+ """
67
+ return len(self.audio) // self.slice_size
68
+
69
+ def get_audio_slice(self, slice: int = 0) -> np.ndarray:
70
+ """Get slice of audio.
71
+
72
+ Args:
73
+ slice (int): slice number of audio (out of get_number_of_slices())
74
+
75
+ Returns:
76
+ np.ndarray: audio as numpy array
77
+ """
78
+ return self.audio[self.slice_size * slice:self.slice_size *
79
+ (slice + 1)]
80
+
81
+ def get_sample_rate(self) -> int:
82
+ """Get sample rate:
83
+
84
+ Returns:
85
+ int: sample rate of audio
86
+ """
87
+ return self.sr
88
+
89
+ def audio_slice_to_image(self, slice: int) -> Image.Image:
90
+ """Convert slice of audio to spectrogram.
91
+
92
+ Args:
93
+ slice (int): slice number of audio to convert (out of get_number_of_slices())
94
+
95
+ Returns:
96
+ PIL Image: grayscale image of x_res x y_res
97
+ """
98
+ S = librosa.feature.melspectrogram(
99
+ y=self.get_audio_slice(slice),
100
+ sr=self.sr,
101
+ n_fft=self.n_fft,
102
+ hop_length=self.hop_length,
103
+ n_mels=self.n_mels,
104
+ fmax=self.fmax,
105
+ )
106
+ log_S = librosa.power_to_db(S, ref=np.max, top_db=self.top_db)
107
+ bytedata = (((log_S + self.top_db) * 255 / self.top_db).clip(0, 255) +
108
+ 0.5).astype(np.uint8)
109
+ image = Image.fromarray(bytedata)
110
+ return image
111
+
112
+ def image_to_audio(self, image: Image.Image) -> np.ndarray:
113
+ """Converts spectrogram to audio.
114
+
115
+ Args:
116
+ image (PIL Image): x_res x y_res grayscale image
117
+
118
+ Returns:
119
+ audio (np.ndarray): raw audio
120
+ """
121
+ bytedata = np.frombuffer(image.tobytes(), dtype="uint8").reshape(
122
+ (image.height, image.width))
123
+ log_S = bytedata.astype("float") * self.top_db / 255 - self.top_db
124
+ S = librosa.db_to_power(log_S)
125
+ audio = librosa.feature.inverse.mel_to_audio(
126
+ S, sr=self.sr, n_fft=self.n_fft, hop_length=self.hop_length)
127
+ return audio
audiodiffusion/utils.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # adpated from https://github.com/huggingface/diffusers/blob/main/scripts/convert_original_stable_diffusion_to_diffusers.py
2
+
3
+ import torch
4
+ from diffusers import AutoencoderKL
5
+
6
+
7
+ def shave_segments(path, n_shave_prefix_segments=1):
8
+ """
9
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
10
+ """
11
+ if n_shave_prefix_segments >= 0:
12
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
13
+ else:
14
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
15
+
16
+
17
+ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
18
+ """
19
+ Updates paths inside resnets to the new naming scheme (local renaming)
20
+ """
21
+ mapping = []
22
+ for old_item in old_list:
23
+ new_item = old_item
24
+
25
+ new_item = new_item.replace("nin_shortcut", "conv_shortcut")
26
+ new_item = shave_segments(
27
+ new_item, n_shave_prefix_segments=n_shave_prefix_segments)
28
+
29
+ mapping.append({"old": old_item, "new": new_item})
30
+
31
+ return mapping
32
+
33
+
34
+ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
35
+ """
36
+ Updates paths inside attentions to the new naming scheme (local renaming)
37
+ """
38
+ mapping = []
39
+ for old_item in old_list:
40
+ new_item = old_item
41
+
42
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
43
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
44
+
45
+ new_item = new_item.replace("q.weight", "query.weight")
46
+ new_item = new_item.replace("q.bias", "query.bias")
47
+
48
+ new_item = new_item.replace("k.weight", "key.weight")
49
+ new_item = new_item.replace("k.bias", "key.bias")
50
+
51
+ new_item = new_item.replace("v.weight", "value.weight")
52
+ new_item = new_item.replace("v.bias", "value.bias")
53
+
54
+ new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
55
+ new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
56
+
57
+ new_item = shave_segments(
58
+ new_item, n_shave_prefix_segments=n_shave_prefix_segments)
59
+
60
+ mapping.append({"old": old_item, "new": new_item})
61
+
62
+ return mapping
63
+
64
+
65
+ def assign_to_checkpoint(paths,
66
+ checkpoint,
67
+ old_checkpoint,
68
+ attention_paths_to_split=None,
69
+ additional_replacements=None,
70
+ config=None):
71
+ """
72
+ This does the final conversion step: take locally converted weights and apply a global renaming
73
+ to them. It splits attention layers, and takes into account additional replacements
74
+ that may arise.
75
+
76
+ Assigns the weights to the new checkpoint.
77
+ """
78
+ assert isinstance(
79
+ paths, list
80
+ ), "Paths should be a list of dicts containing 'old' and 'new' keys."
81
+
82
+ # Splits the attention layers into three variables.
83
+ if attention_paths_to_split is not None:
84
+ for path, path_map in attention_paths_to_split.items():
85
+ old_tensor = old_checkpoint[path]
86
+ channels = old_tensor.shape[0] // 3
87
+
88
+ target_shape = (-1,
89
+ channels) if len(old_tensor.shape) == 3 else (-1)
90
+
91
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
92
+
93
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels //
94
+ num_heads) + old_tensor.shape[1:])
95
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
96
+
97
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
98
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
99
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
100
+
101
+ for path in paths:
102
+ new_path = path["new"]
103
+
104
+ # These have already been assigned
105
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
106
+ continue
107
+
108
+ # Global renaming happens here
109
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
110
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
111
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
112
+
113
+ if additional_replacements is not None:
114
+ for replacement in additional_replacements:
115
+ new_path = new_path.replace(replacement["old"],
116
+ replacement["new"])
117
+
118
+ # proj_attn.weight has to be converted from conv 1D to linear
119
+ if "proj_attn.weight" in new_path:
120
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
121
+ else:
122
+ checkpoint[new_path] = old_checkpoint[path["old"]]
123
+
124
+
125
+ def conv_attn_to_linear(checkpoint):
126
+ keys = list(checkpoint.keys())
127
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
128
+ for key in keys:
129
+ if ".".join(key.split(".")[-2:]) in attn_keys:
130
+ if checkpoint[key].ndim > 2:
131
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
132
+ elif "proj_attn.weight" in key:
133
+ if checkpoint[key].ndim > 2:
134
+ checkpoint[key] = checkpoint[key][:, :, 0]
135
+
136
+
137
+ def create_vae_diffusers_config(original_config):
138
+ """
139
+ Creates a config for the diffusers based on the config of the LDM model.
140
+ """
141
+ vae_params = original_config.model.params.ddconfig
142
+ _ = original_config.model.params.embed_dim
143
+
144
+ block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
145
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
146
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
147
+
148
+ config = dict(
149
+ sample_size=vae_params.resolution,
150
+ in_channels=vae_params.in_channels,
151
+ out_channels=vae_params.out_ch,
152
+ down_block_types=tuple(down_block_types),
153
+ up_block_types=tuple(up_block_types),
154
+ block_out_channels=tuple(block_out_channels),
155
+ latent_channels=vae_params.z_channels,
156
+ layers_per_block=vae_params.num_res_blocks,
157
+ )
158
+ return config
159
+
160
+
161
+ def convert_ldm_vae_checkpoint(checkpoint, config):
162
+ # extract state dict for VAE
163
+ vae_state_dict = checkpoint
164
+
165
+ new_checkpoint = {}
166
+
167
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict[
168
+ "encoder.conv_in.weight"]
169
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict[
170
+ "encoder.conv_in.bias"]
171
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict[
172
+ "encoder.conv_out.weight"]
173
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict[
174
+ "encoder.conv_out.bias"]
175
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict[
176
+ "encoder.norm_out.weight"]
177
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict[
178
+ "encoder.norm_out.bias"]
179
+
180
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict[
181
+ "decoder.conv_in.weight"]
182
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict[
183
+ "decoder.conv_in.bias"]
184
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict[
185
+ "decoder.conv_out.weight"]
186
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict[
187
+ "decoder.conv_out.bias"]
188
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict[
189
+ "decoder.norm_out.weight"]
190
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict[
191
+ "decoder.norm_out.bias"]
192
+
193
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
194
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
195
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict[
196
+ "post_quant_conv.weight"]
197
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict[
198
+ "post_quant_conv.bias"]
199
+
200
+ # Retrieves the keys for the encoder down blocks only
201
+ num_down_blocks = len({
202
+ ".".join(layer.split(".")[:3])
203
+ for layer in vae_state_dict if "encoder.down" in layer
204
+ })
205
+ down_blocks = {
206
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key]
207
+ for layer_id in range(num_down_blocks)
208
+ }
209
+
210
+ # Retrieves the keys for the decoder up blocks only
211
+ num_up_blocks = len({
212
+ ".".join(layer.split(".")[:3])
213
+ for layer in vae_state_dict if "decoder.up" in layer
214
+ })
215
+ up_blocks = {
216
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key]
217
+ for layer_id in range(num_up_blocks)
218
+ }
219
+
220
+ for i in range(num_down_blocks):
221
+ resnets = [
222
+ key for key in down_blocks[i]
223
+ if f"down.{i}" in key and f"down.{i}.downsample" not in key
224
+ ]
225
+
226
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
227
+ new_checkpoint[
228
+ f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
229
+ f"encoder.down.{i}.downsample.conv.weight")
230
+ new_checkpoint[
231
+ f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
232
+ f"encoder.down.{i}.downsample.conv.bias")
233
+
234
+ paths = renew_vae_resnet_paths(resnets)
235
+ meta_path = {
236
+ "old": f"down.{i}.block",
237
+ "new": f"down_blocks.{i}.resnets"
238
+ }
239
+ assign_to_checkpoint(paths,
240
+ new_checkpoint,
241
+ vae_state_dict,
242
+ additional_replacements=[meta_path],
243
+ config=config)
244
+
245
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
246
+ num_mid_res_blocks = 2
247
+ for i in range(1, num_mid_res_blocks + 1):
248
+ resnets = [
249
+ key for key in mid_resnets if f"encoder.mid.block_{i}" in key
250
+ ]
251
+
252
+ paths = renew_vae_resnet_paths(resnets)
253
+ meta_path = {
254
+ "old": f"mid.block_{i}",
255
+ "new": f"mid_block.resnets.{i - 1}"
256
+ }
257
+ assign_to_checkpoint(paths,
258
+ new_checkpoint,
259
+ vae_state_dict,
260
+ additional_replacements=[meta_path],
261
+ config=config)
262
+
263
+ mid_attentions = [
264
+ key for key in vae_state_dict if "encoder.mid.attn" in key
265
+ ]
266
+ paths = renew_vae_attention_paths(mid_attentions)
267
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
268
+ assign_to_checkpoint(paths,
269
+ new_checkpoint,
270
+ vae_state_dict,
271
+ additional_replacements=[meta_path],
272
+ config=config)
273
+ conv_attn_to_linear(new_checkpoint)
274
+
275
+ for i in range(num_up_blocks):
276
+ block_id = num_up_blocks - 1 - i
277
+ resnets = [
278
+ key for key in up_blocks[block_id]
279
+ if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
280
+ ]
281
+
282
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
283
+ new_checkpoint[
284
+ f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
285
+ f"decoder.up.{block_id}.upsample.conv.weight"]
286
+ new_checkpoint[
287
+ f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
288
+ f"decoder.up.{block_id}.upsample.conv.bias"]
289
+
290
+ paths = renew_vae_resnet_paths(resnets)
291
+ meta_path = {
292
+ "old": f"up.{block_id}.block",
293
+ "new": f"up_blocks.{i}.resnets"
294
+ }
295
+ assign_to_checkpoint(paths,
296
+ new_checkpoint,
297
+ vae_state_dict,
298
+ additional_replacements=[meta_path],
299
+ config=config)
300
+
301
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
302
+ num_mid_res_blocks = 2
303
+ for i in range(1, num_mid_res_blocks + 1):
304
+ resnets = [
305
+ key for key in mid_resnets if f"decoder.mid.block_{i}" in key
306
+ ]
307
+
308
+ paths = renew_vae_resnet_paths(resnets)
309
+ meta_path = {
310
+ "old": f"mid.block_{i}",
311
+ "new": f"mid_block.resnets.{i - 1}"
312
+ }
313
+ assign_to_checkpoint(paths,
314
+ new_checkpoint,
315
+ vae_state_dict,
316
+ additional_replacements=[meta_path],
317
+ config=config)
318
+
319
+ mid_attentions = [
320
+ key for key in vae_state_dict if "decoder.mid.attn" in key
321
+ ]
322
+ paths = renew_vae_attention_paths(mid_attentions)
323
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
324
+ assign_to_checkpoint(paths,
325
+ new_checkpoint,
326
+ vae_state_dict,
327
+ additional_replacements=[meta_path],
328
+ config=config)
329
+ conv_attn_to_linear(new_checkpoint)
330
+ return new_checkpoint
331
+
332
+ def convert_ldm_to_hf_vae(ldm_checkpoint, ldm_config, hf_checkpoint):
333
+ checkpoint = torch.load(ldm_checkpoint)["state_dict"]
334
+
335
+ # Convert the VAE model.
336
+ vae_config = create_vae_diffusers_config(ldm_config)
337
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(
338
+ checkpoint, vae_config)
339
+
340
+ vae = AutoencoderKL(**vae_config)
341
+ vae.load_state_dict(converted_vae_checkpoint)
342
+ vae.save_pretrained(hf_checkpoint)