toto10 commited on
Commit
8736faa
·
1 Parent(s): 8ebef76

60aad666c3153a196036b6ec1827ad52480727f2aa0721b4b0845396eb0febee

Browse files
Files changed (50) hide show
  1. repositories/generative-models/sgm/modules/diffusionmodules/__pycache__/denoiser.cpython-310.pyc +0 -0
  2. repositories/generative-models/sgm/modules/diffusionmodules/__pycache__/denoiser_scaling.cpython-310.pyc +0 -0
  3. repositories/generative-models/sgm/modules/diffusionmodules/__pycache__/discretizer.cpython-310.pyc +0 -0
  4. repositories/generative-models/sgm/modules/diffusionmodules/__pycache__/loss.cpython-310.pyc +0 -0
  5. repositories/generative-models/sgm/modules/diffusionmodules/__pycache__/model.cpython-310.pyc +0 -0
  6. repositories/generative-models/sgm/modules/diffusionmodules/__pycache__/openaimodel.cpython-310.pyc +0 -0
  7. repositories/generative-models/sgm/modules/diffusionmodules/__pycache__/sampling.cpython-310.pyc +0 -0
  8. repositories/generative-models/sgm/modules/diffusionmodules/__pycache__/sampling_utils.cpython-310.pyc +0 -0
  9. repositories/generative-models/sgm/modules/diffusionmodules/__pycache__/util.cpython-310.pyc +0 -0
  10. repositories/generative-models/sgm/modules/diffusionmodules/__pycache__/wrappers.cpython-310.pyc +0 -0
  11. repositories/generative-models/sgm/modules/diffusionmodules/denoiser.py +63 -0
  12. repositories/generative-models/sgm/modules/diffusionmodules/denoiser_scaling.py +31 -0
  13. repositories/generative-models/sgm/modules/diffusionmodules/denoiser_weighting.py +24 -0
  14. repositories/generative-models/sgm/modules/diffusionmodules/discretizer.py +68 -0
  15. repositories/generative-models/sgm/modules/diffusionmodules/guiders.py +53 -0
  16. repositories/generative-models/sgm/modules/diffusionmodules/loss.py +69 -0
  17. repositories/generative-models/sgm/modules/diffusionmodules/model.py +743 -0
  18. repositories/generative-models/sgm/modules/diffusionmodules/openaimodel.py +1262 -0
  19. repositories/generative-models/sgm/modules/diffusionmodules/sampling.py +365 -0
  20. repositories/generative-models/sgm/modules/diffusionmodules/sampling_utils.py +48 -0
  21. repositories/generative-models/sgm/modules/diffusionmodules/sigma_sampling.py +31 -0
  22. repositories/generative-models/sgm/modules/diffusionmodules/util.py +308 -0
  23. repositories/generative-models/sgm/modules/diffusionmodules/wrappers.py +34 -0
  24. repositories/generative-models/sgm/modules/distributions/__init__.py +0 -0
  25. repositories/generative-models/sgm/modules/distributions/__pycache__/__init__.cpython-310.pyc +0 -0
  26. repositories/generative-models/sgm/modules/distributions/__pycache__/distributions.cpython-310.pyc +0 -0
  27. repositories/generative-models/sgm/modules/distributions/distributions.py +102 -0
  28. repositories/generative-models/sgm/modules/ema.py +86 -0
  29. repositories/generative-models/sgm/modules/encoders/__init__.py +0 -0
  30. repositories/generative-models/sgm/modules/encoders/__pycache__/__init__.cpython-310.pyc +0 -0
  31. repositories/generative-models/sgm/modules/encoders/__pycache__/modules.cpython-310.pyc +0 -0
  32. repositories/generative-models/sgm/modules/encoders/modules.py +960 -0
  33. repositories/generative-models/sgm/util.py +231 -0
  34. repositories/k-diffusion/.github/workflows/python-publish.yml +37 -0
  35. repositories/k-diffusion/.gitignore +10 -0
  36. repositories/k-diffusion/LICENSE +19 -0
  37. repositories/k-diffusion/README.md +61 -0
  38. repositories/k-diffusion/configs/config_32x32_small.json +43 -0
  39. repositories/k-diffusion/configs/config_32x32_small_butterflies.json +44 -0
  40. repositories/k-diffusion/configs/config_cifar10.json +43 -0
  41. repositories/k-diffusion/configs/config_mnist.json +43 -0
  42. repositories/k-diffusion/k_diffusion/__init__.py +2 -0
  43. repositories/k-diffusion/k_diffusion/__pycache__/__init__.cpython-310.pyc +0 -0
  44. repositories/k-diffusion/k_diffusion/__pycache__/augmentation.cpython-310.pyc +0 -0
  45. repositories/k-diffusion/k_diffusion/__pycache__/config.cpython-310.pyc +0 -0
  46. repositories/k-diffusion/k_diffusion/__pycache__/evaluation.cpython-310.pyc +0 -0
  47. repositories/k-diffusion/k_diffusion/__pycache__/external.cpython-310.pyc +0 -0
  48. repositories/k-diffusion/k_diffusion/__pycache__/gns.cpython-310.pyc +0 -0
  49. repositories/k-diffusion/k_diffusion/__pycache__/layers.cpython-310.pyc +0 -0
  50. repositories/k-diffusion/k_diffusion/__pycache__/sampling.cpython-310.pyc +0 -0
repositories/generative-models/sgm/modules/diffusionmodules/__pycache__/denoiser.cpython-310.pyc ADDED
Binary file (2.66 kB). View file
 
repositories/generative-models/sgm/modules/diffusionmodules/__pycache__/denoiser_scaling.cpython-310.pyc ADDED
Binary file (1.51 kB). View file
 
repositories/generative-models/sgm/modules/diffusionmodules/__pycache__/discretizer.cpython-310.pyc ADDED
Binary file (3.02 kB). View file
 
repositories/generative-models/sgm/modules/diffusionmodules/__pycache__/loss.cpython-310.pyc ADDED
Binary file (2.34 kB). View file
 
repositories/generative-models/sgm/modules/diffusionmodules/__pycache__/model.cpython-310.pyc ADDED
Binary file (16.4 kB). View file
 
repositories/generative-models/sgm/modules/diffusionmodules/__pycache__/openaimodel.cpython-310.pyc ADDED
Binary file (27.5 kB). View file
 
repositories/generative-models/sgm/modules/diffusionmodules/__pycache__/sampling.cpython-310.pyc ADDED
Binary file (11.8 kB). View file
 
repositories/generative-models/sgm/modules/diffusionmodules/__pycache__/sampling_utils.cpython-310.pyc ADDED
Binary file (1.88 kB). View file
 
repositories/generative-models/sgm/modules/diffusionmodules/__pycache__/util.cpython-310.pyc ADDED
Binary file (10 kB). View file
 
repositories/generative-models/sgm/modules/diffusionmodules/__pycache__/wrappers.cpython-310.pyc ADDED
Binary file (1.7 kB). View file
 
repositories/generative-models/sgm/modules/diffusionmodules/denoiser.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+ from ...util import append_dims, instantiate_from_config
4
+
5
+
6
+ class Denoiser(nn.Module):
7
+ def __init__(self, weighting_config, scaling_config):
8
+ super().__init__()
9
+
10
+ self.weighting = instantiate_from_config(weighting_config)
11
+ self.scaling = instantiate_from_config(scaling_config)
12
+
13
+ def possibly_quantize_sigma(self, sigma):
14
+ return sigma
15
+
16
+ def possibly_quantize_c_noise(self, c_noise):
17
+ return c_noise
18
+
19
+ def w(self, sigma):
20
+ return self.weighting(sigma)
21
+
22
+ def __call__(self, network, input, sigma, cond):
23
+ sigma = self.possibly_quantize_sigma(sigma)
24
+ sigma_shape = sigma.shape
25
+ sigma = append_dims(sigma, input.ndim)
26
+ c_skip, c_out, c_in, c_noise = self.scaling(sigma)
27
+ c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
28
+ return network(input * c_in, c_noise, cond) * c_out + input * c_skip
29
+
30
+
31
+ class DiscreteDenoiser(Denoiser):
32
+ def __init__(
33
+ self,
34
+ weighting_config,
35
+ scaling_config,
36
+ num_idx,
37
+ discretization_config,
38
+ do_append_zero=False,
39
+ quantize_c_noise=True,
40
+ flip=True,
41
+ ):
42
+ super().__init__(weighting_config, scaling_config)
43
+ sigmas = instantiate_from_config(discretization_config)(
44
+ num_idx, do_append_zero=do_append_zero, flip=flip
45
+ )
46
+ self.register_buffer("sigmas", sigmas)
47
+ self.quantize_c_noise = quantize_c_noise
48
+
49
+ def sigma_to_idx(self, sigma):
50
+ dists = sigma - self.sigmas[:, None]
51
+ return dists.abs().argmin(dim=0).view(sigma.shape)
52
+
53
+ def idx_to_sigma(self, idx):
54
+ return self.sigmas[idx]
55
+
56
+ def possibly_quantize_sigma(self, sigma):
57
+ return self.idx_to_sigma(self.sigma_to_idx(sigma))
58
+
59
+ def possibly_quantize_c_noise(self, c_noise):
60
+ if self.quantize_c_noise:
61
+ return self.sigma_to_idx(c_noise)
62
+ else:
63
+ return c_noise
repositories/generative-models/sgm/modules/diffusionmodules/denoiser_scaling.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ class EDMScaling:
5
+ def __init__(self, sigma_data=0.5):
6
+ self.sigma_data = sigma_data
7
+
8
+ def __call__(self, sigma):
9
+ c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
10
+ c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
11
+ c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
12
+ c_noise = 0.25 * sigma.log()
13
+ return c_skip, c_out, c_in, c_noise
14
+
15
+
16
+ class EpsScaling:
17
+ def __call__(self, sigma):
18
+ c_skip = torch.ones_like(sigma, device=sigma.device)
19
+ c_out = -sigma
20
+ c_in = 1 / (sigma**2 + 1.0) ** 0.5
21
+ c_noise = sigma.clone()
22
+ return c_skip, c_out, c_in, c_noise
23
+
24
+
25
+ class VScaling:
26
+ def __call__(self, sigma):
27
+ c_skip = 1.0 / (sigma**2 + 1.0)
28
+ c_out = -sigma / (sigma**2 + 1.0) ** 0.5
29
+ c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
30
+ c_noise = sigma.clone()
31
+ return c_skip, c_out, c_in, c_noise
repositories/generative-models/sgm/modules/diffusionmodules/denoiser_weighting.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ class UnitWeighting:
5
+ def __call__(self, sigma):
6
+ return torch.ones_like(sigma, device=sigma.device)
7
+
8
+
9
+ class EDMWeighting:
10
+ def __init__(self, sigma_data=0.5):
11
+ self.sigma_data = sigma_data
12
+
13
+ def __call__(self, sigma):
14
+ return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
15
+
16
+
17
+ class VWeighting(EDMWeighting):
18
+ def __init__(self):
19
+ super().__init__(sigma_data=1.0)
20
+
21
+
22
+ class EpsWeighting:
23
+ def __call__(self, sigma):
24
+ return sigma**-2.0
repositories/generative-models/sgm/modules/diffusionmodules/discretizer.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from functools import partial
4
+ from abc import abstractmethod
5
+
6
+ from ...util import append_zero
7
+ from ...modules.diffusionmodules.util import make_beta_schedule
8
+
9
+
10
+ def generate_roughly_equally_spaced_steps(
11
+ num_substeps: int, max_step: int
12
+ ) -> np.ndarray:
13
+ return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1]
14
+
15
+
16
+ class Discretization:
17
+ def __call__(self, n, do_append_zero=True, device="cpu", flip=False):
18
+ sigmas = self.get_sigmas(n, device=device)
19
+ sigmas = append_zero(sigmas) if do_append_zero else sigmas
20
+ return sigmas if not flip else torch.flip(sigmas, (0,))
21
+
22
+ @abstractmethod
23
+ def get_sigmas(self, n, device):
24
+ pass
25
+
26
+
27
+ class EDMDiscretization(Discretization):
28
+ def __init__(self, sigma_min=0.02, sigma_max=80.0, rho=7.0):
29
+ self.sigma_min = sigma_min
30
+ self.sigma_max = sigma_max
31
+ self.rho = rho
32
+
33
+ def get_sigmas(self, n, device="cpu"):
34
+ ramp = torch.linspace(0, 1, n, device=device)
35
+ min_inv_rho = self.sigma_min ** (1 / self.rho)
36
+ max_inv_rho = self.sigma_max ** (1 / self.rho)
37
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho
38
+ return sigmas
39
+
40
+
41
+ class LegacyDDPMDiscretization(Discretization):
42
+ def __init__(
43
+ self,
44
+ linear_start=0.00085,
45
+ linear_end=0.0120,
46
+ num_timesteps=1000,
47
+ ):
48
+ super().__init__()
49
+ self.num_timesteps = num_timesteps
50
+ betas = make_beta_schedule(
51
+ "linear", num_timesteps, linear_start=linear_start, linear_end=linear_end
52
+ )
53
+ alphas = 1.0 - betas
54
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
55
+ self.to_torch = partial(torch.tensor, dtype=torch.float32)
56
+
57
+ def get_sigmas(self, n, device="cpu"):
58
+ if n < self.num_timesteps:
59
+ timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
60
+ alphas_cumprod = self.alphas_cumprod[timesteps]
61
+ elif n == self.num_timesteps:
62
+ alphas_cumprod = self.alphas_cumprod
63
+ else:
64
+ raise ValueError
65
+
66
+ to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
67
+ sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
68
+ return torch.flip(sigmas, (0,))
repositories/generative-models/sgm/modules/diffusionmodules/guiders.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+
3
+ import torch
4
+
5
+ from ...util import default, instantiate_from_config
6
+
7
+
8
+ class VanillaCFG:
9
+ """
10
+ implements parallelized CFG
11
+ """
12
+
13
+ def __init__(self, scale, dyn_thresh_config=None):
14
+ scale_schedule = lambda scale, sigma: scale # independent of step
15
+ self.scale_schedule = partial(scale_schedule, scale)
16
+ self.dyn_thresh = instantiate_from_config(
17
+ default(
18
+ dyn_thresh_config,
19
+ {
20
+ "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
21
+ },
22
+ )
23
+ )
24
+
25
+ def __call__(self, x, sigma):
26
+ x_u, x_c = x.chunk(2)
27
+ scale_value = self.scale_schedule(sigma)
28
+ x_pred = self.dyn_thresh(x_u, x_c, scale_value)
29
+ return x_pred
30
+
31
+ def prepare_inputs(self, x, s, c, uc):
32
+ c_out = dict()
33
+
34
+ for k in c:
35
+ if k in ["vector", "crossattn", "concat"]:
36
+ c_out[k] = torch.cat((uc[k], c[k]), 0)
37
+ else:
38
+ assert c[k] == uc[k]
39
+ c_out[k] = c[k]
40
+ return torch.cat([x] * 2), torch.cat([s] * 2), c_out
41
+
42
+
43
+ class IdentityGuider:
44
+ def __call__(self, x, sigma):
45
+ return x
46
+
47
+ def prepare_inputs(self, x, s, c, uc):
48
+ c_out = dict()
49
+
50
+ for k in c:
51
+ c_out[k] = c[k]
52
+
53
+ return x, s, c_out
repositories/generative-models/sgm/modules/diffusionmodules/loss.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from omegaconf import ListConfig
6
+ from taming.modules.losses.lpips import LPIPS
7
+
8
+ from ...util import append_dims, instantiate_from_config
9
+
10
+
11
+ class StandardDiffusionLoss(nn.Module):
12
+ def __init__(
13
+ self,
14
+ sigma_sampler_config,
15
+ type="l2",
16
+ offset_noise_level=0.0,
17
+ batch2model_keys: Optional[Union[str, List[str], ListConfig]] = None,
18
+ ):
19
+ super().__init__()
20
+
21
+ assert type in ["l2", "l1", "lpips"]
22
+
23
+ self.sigma_sampler = instantiate_from_config(sigma_sampler_config)
24
+
25
+ self.type = type
26
+ self.offset_noise_level = offset_noise_level
27
+
28
+ if type == "lpips":
29
+ self.lpips = LPIPS().eval()
30
+
31
+ if not batch2model_keys:
32
+ batch2model_keys = []
33
+
34
+ if isinstance(batch2model_keys, str):
35
+ batch2model_keys = [batch2model_keys]
36
+
37
+ self.batch2model_keys = set(batch2model_keys)
38
+
39
+ def __call__(self, network, denoiser, conditioner, input, batch):
40
+ cond = conditioner(batch)
41
+ additional_model_inputs = {
42
+ key: batch[key] for key in self.batch2model_keys.intersection(batch)
43
+ }
44
+
45
+ sigmas = self.sigma_sampler(input.shape[0]).to(input.device)
46
+ noise = torch.randn_like(input)
47
+ if self.offset_noise_level > 0.0:
48
+ noise = noise + self.offset_noise_level * append_dims(
49
+ torch.randn(input.shape[0], device=input.device), input.ndim
50
+ )
51
+ noised_input = input + noise * append_dims(sigmas, input.ndim)
52
+ model_output = denoiser(
53
+ network, noised_input, sigmas, cond, **additional_model_inputs
54
+ )
55
+ w = append_dims(denoiser.w(sigmas), input.ndim)
56
+ return self.get_loss(model_output, input, w)
57
+
58
+ def get_loss(self, model_output, target, w):
59
+ if self.type == "l2":
60
+ return torch.mean(
61
+ (w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1
62
+ )
63
+ elif self.type == "l1":
64
+ return torch.mean(
65
+ (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1
66
+ )
67
+ elif self.type == "lpips":
68
+ loss = self.lpips(model_output, target).reshape(-1)
69
+ return loss
repositories/generative-models/sgm/modules/diffusionmodules/model.py ADDED
@@ -0,0 +1,743 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pytorch_diffusion + derived encoder decoder
2
+ import math
3
+ from typing import Any, Callable, Optional
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ from einops import rearrange
9
+ from packaging import version
10
+
11
+ try:
12
+ import xformers
13
+ import xformers.ops
14
+
15
+ XFORMERS_IS_AVAILABLE = True
16
+ except:
17
+ XFORMERS_IS_AVAILABLE = False
18
+ print("no module 'xformers'. Processing without...")
19
+
20
+ from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention
21
+
22
+
23
+ def get_timestep_embedding(timesteps, embedding_dim):
24
+ """
25
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
26
+ From Fairseq.
27
+ Build sinusoidal embeddings.
28
+ This matches the implementation in tensor2tensor, but differs slightly
29
+ from the description in Section 3.5 of "Attention Is All You Need".
30
+ """
31
+ assert len(timesteps.shape) == 1
32
+
33
+ half_dim = embedding_dim // 2
34
+ emb = math.log(10000) / (half_dim - 1)
35
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
36
+ emb = emb.to(device=timesteps.device)
37
+ emb = timesteps.float()[:, None] * emb[None, :]
38
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
39
+ if embedding_dim % 2 == 1: # zero pad
40
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
41
+ return emb
42
+
43
+
44
+ def nonlinearity(x):
45
+ # swish
46
+ return x * torch.sigmoid(x)
47
+
48
+
49
+ def Normalize(in_channels, num_groups=32):
50
+ return torch.nn.GroupNorm(
51
+ num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
52
+ )
53
+
54
+
55
+ class Upsample(nn.Module):
56
+ def __init__(self, in_channels, with_conv):
57
+ super().__init__()
58
+ self.with_conv = with_conv
59
+ if self.with_conv:
60
+ self.conv = torch.nn.Conv2d(
61
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
62
+ )
63
+
64
+ def forward(self, x):
65
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
66
+ if self.with_conv:
67
+ x = self.conv(x)
68
+ return x
69
+
70
+
71
+ class Downsample(nn.Module):
72
+ def __init__(self, in_channels, with_conv):
73
+ super().__init__()
74
+ self.with_conv = with_conv
75
+ if self.with_conv:
76
+ # no asymmetric padding in torch conv, must do it ourselves
77
+ self.conv = torch.nn.Conv2d(
78
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
79
+ )
80
+
81
+ def forward(self, x):
82
+ if self.with_conv:
83
+ pad = (0, 1, 0, 1)
84
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
85
+ x = self.conv(x)
86
+ else:
87
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
88
+ return x
89
+
90
+
91
+ class ResnetBlock(nn.Module):
92
+ def __init__(
93
+ self,
94
+ *,
95
+ in_channels,
96
+ out_channels=None,
97
+ conv_shortcut=False,
98
+ dropout,
99
+ temb_channels=512,
100
+ ):
101
+ super().__init__()
102
+ self.in_channels = in_channels
103
+ out_channels = in_channels if out_channels is None else out_channels
104
+ self.out_channels = out_channels
105
+ self.use_conv_shortcut = conv_shortcut
106
+
107
+ self.norm1 = Normalize(in_channels)
108
+ self.conv1 = torch.nn.Conv2d(
109
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
110
+ )
111
+ if temb_channels > 0:
112
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
113
+ self.norm2 = Normalize(out_channels)
114
+ self.dropout = torch.nn.Dropout(dropout)
115
+ self.conv2 = torch.nn.Conv2d(
116
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
117
+ )
118
+ if self.in_channels != self.out_channels:
119
+ if self.use_conv_shortcut:
120
+ self.conv_shortcut = torch.nn.Conv2d(
121
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
122
+ )
123
+ else:
124
+ self.nin_shortcut = torch.nn.Conv2d(
125
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
126
+ )
127
+
128
+ def forward(self, x, temb):
129
+ h = x
130
+ h = self.norm1(h)
131
+ h = nonlinearity(h)
132
+ h = self.conv1(h)
133
+
134
+ if temb is not None:
135
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
136
+
137
+ h = self.norm2(h)
138
+ h = nonlinearity(h)
139
+ h = self.dropout(h)
140
+ h = self.conv2(h)
141
+
142
+ if self.in_channels != self.out_channels:
143
+ if self.use_conv_shortcut:
144
+ x = self.conv_shortcut(x)
145
+ else:
146
+ x = self.nin_shortcut(x)
147
+
148
+ return x + h
149
+
150
+
151
+ class LinAttnBlock(LinearAttention):
152
+ """to match AttnBlock usage"""
153
+
154
+ def __init__(self, in_channels):
155
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
156
+
157
+
158
+ class AttnBlock(nn.Module):
159
+ def __init__(self, in_channels):
160
+ super().__init__()
161
+ self.in_channels = in_channels
162
+
163
+ self.norm = Normalize(in_channels)
164
+ self.q = torch.nn.Conv2d(
165
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
166
+ )
167
+ self.k = torch.nn.Conv2d(
168
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
169
+ )
170
+ self.v = torch.nn.Conv2d(
171
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
172
+ )
173
+ self.proj_out = torch.nn.Conv2d(
174
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
175
+ )
176
+
177
+ def attention(self, h_: torch.Tensor) -> torch.Tensor:
178
+ h_ = self.norm(h_)
179
+ q = self.q(h_)
180
+ k = self.k(h_)
181
+ v = self.v(h_)
182
+
183
+ b, c, h, w = q.shape
184
+ q, k, v = map(
185
+ lambda x: rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v)
186
+ )
187
+ h_ = torch.nn.functional.scaled_dot_product_attention(
188
+ q, k, v
189
+ ) # scale is dim ** -0.5 per default
190
+ # compute attention
191
+
192
+ return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
193
+
194
+ def forward(self, x, **kwargs):
195
+ h_ = x
196
+ h_ = self.attention(h_)
197
+ h_ = self.proj_out(h_)
198
+ return x + h_
199
+
200
+
201
+ class MemoryEfficientAttnBlock(nn.Module):
202
+ """
203
+ Uses xformers efficient implementation,
204
+ see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
205
+ Note: this is a single-head self-attention operation
206
+ """
207
+
208
+ #
209
+ def __init__(self, in_channels):
210
+ super().__init__()
211
+ self.in_channels = in_channels
212
+
213
+ self.norm = Normalize(in_channels)
214
+ self.q = torch.nn.Conv2d(
215
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
216
+ )
217
+ self.k = torch.nn.Conv2d(
218
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
219
+ )
220
+ self.v = torch.nn.Conv2d(
221
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
222
+ )
223
+ self.proj_out = torch.nn.Conv2d(
224
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
225
+ )
226
+ self.attention_op: Optional[Any] = None
227
+
228
+ def attention(self, h_: torch.Tensor) -> torch.Tensor:
229
+ h_ = self.norm(h_)
230
+ q = self.q(h_)
231
+ k = self.k(h_)
232
+ v = self.v(h_)
233
+
234
+ # compute attention
235
+ B, C, H, W = q.shape
236
+ q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v))
237
+
238
+ q, k, v = map(
239
+ lambda t: t.unsqueeze(3)
240
+ .reshape(B, t.shape[1], 1, C)
241
+ .permute(0, 2, 1, 3)
242
+ .reshape(B * 1, t.shape[1], C)
243
+ .contiguous(),
244
+ (q, k, v),
245
+ )
246
+ out = xformers.ops.memory_efficient_attention(
247
+ q, k, v, attn_bias=None, op=self.attention_op
248
+ )
249
+
250
+ out = (
251
+ out.unsqueeze(0)
252
+ .reshape(B, 1, out.shape[1], C)
253
+ .permute(0, 2, 1, 3)
254
+ .reshape(B, out.shape[1], C)
255
+ )
256
+ return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)
257
+
258
+ def forward(self, x, **kwargs):
259
+ h_ = x
260
+ h_ = self.attention(h_)
261
+ h_ = self.proj_out(h_)
262
+ return x + h_
263
+
264
+
265
+ class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
266
+ def forward(self, x, context=None, mask=None, **unused_kwargs):
267
+ b, c, h, w = x.shape
268
+ x = rearrange(x, "b c h w -> b (h w) c")
269
+ out = super().forward(x, context=context, mask=mask)
270
+ out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c)
271
+ return x + out
272
+
273
+
274
+ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
275
+ assert attn_type in [
276
+ "vanilla",
277
+ "vanilla-xformers",
278
+ "memory-efficient-cross-attn",
279
+ "linear",
280
+ "none",
281
+ ], f"attn_type {attn_type} unknown"
282
+ if (
283
+ version.parse(torch.__version__) < version.parse("2.0.0")
284
+ and attn_type != "none"
285
+ ):
286
+ assert XFORMERS_IS_AVAILABLE, (
287
+ f"We do not support vanilla attention in {torch.__version__} anymore, "
288
+ f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'"
289
+ )
290
+ attn_type = "vanilla-xformers"
291
+ print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
292
+ if attn_type == "vanilla":
293
+ assert attn_kwargs is None
294
+ return AttnBlock(in_channels)
295
+ elif attn_type == "vanilla-xformers":
296
+ print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
297
+ return MemoryEfficientAttnBlock(in_channels)
298
+ elif type == "memory-efficient-cross-attn":
299
+ attn_kwargs["query_dim"] = in_channels
300
+ return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
301
+ elif attn_type == "none":
302
+ return nn.Identity(in_channels)
303
+ else:
304
+ return LinAttnBlock(in_channels)
305
+
306
+
307
+ class Model(nn.Module):
308
+ def __init__(
309
+ self,
310
+ *,
311
+ ch,
312
+ out_ch,
313
+ ch_mult=(1, 2, 4, 8),
314
+ num_res_blocks,
315
+ attn_resolutions,
316
+ dropout=0.0,
317
+ resamp_with_conv=True,
318
+ in_channels,
319
+ resolution,
320
+ use_timestep=True,
321
+ use_linear_attn=False,
322
+ attn_type="vanilla",
323
+ ):
324
+ super().__init__()
325
+ if use_linear_attn:
326
+ attn_type = "linear"
327
+ self.ch = ch
328
+ self.temb_ch = self.ch * 4
329
+ self.num_resolutions = len(ch_mult)
330
+ self.num_res_blocks = num_res_blocks
331
+ self.resolution = resolution
332
+ self.in_channels = in_channels
333
+
334
+ self.use_timestep = use_timestep
335
+ if self.use_timestep:
336
+ # timestep embedding
337
+ self.temb = nn.Module()
338
+ self.temb.dense = nn.ModuleList(
339
+ [
340
+ torch.nn.Linear(self.ch, self.temb_ch),
341
+ torch.nn.Linear(self.temb_ch, self.temb_ch),
342
+ ]
343
+ )
344
+
345
+ # downsampling
346
+ self.conv_in = torch.nn.Conv2d(
347
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
348
+ )
349
+
350
+ curr_res = resolution
351
+ in_ch_mult = (1,) + tuple(ch_mult)
352
+ self.down = nn.ModuleList()
353
+ for i_level in range(self.num_resolutions):
354
+ block = nn.ModuleList()
355
+ attn = nn.ModuleList()
356
+ block_in = ch * in_ch_mult[i_level]
357
+ block_out = ch * ch_mult[i_level]
358
+ for i_block in range(self.num_res_blocks):
359
+ block.append(
360
+ ResnetBlock(
361
+ in_channels=block_in,
362
+ out_channels=block_out,
363
+ temb_channels=self.temb_ch,
364
+ dropout=dropout,
365
+ )
366
+ )
367
+ block_in = block_out
368
+ if curr_res in attn_resolutions:
369
+ attn.append(make_attn(block_in, attn_type=attn_type))
370
+ down = nn.Module()
371
+ down.block = block
372
+ down.attn = attn
373
+ if i_level != self.num_resolutions - 1:
374
+ down.downsample = Downsample(block_in, resamp_with_conv)
375
+ curr_res = curr_res // 2
376
+ self.down.append(down)
377
+
378
+ # middle
379
+ self.mid = nn.Module()
380
+ self.mid.block_1 = ResnetBlock(
381
+ in_channels=block_in,
382
+ out_channels=block_in,
383
+ temb_channels=self.temb_ch,
384
+ dropout=dropout,
385
+ )
386
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
387
+ self.mid.block_2 = ResnetBlock(
388
+ in_channels=block_in,
389
+ out_channels=block_in,
390
+ temb_channels=self.temb_ch,
391
+ dropout=dropout,
392
+ )
393
+
394
+ # upsampling
395
+ self.up = nn.ModuleList()
396
+ for i_level in reversed(range(self.num_resolutions)):
397
+ block = nn.ModuleList()
398
+ attn = nn.ModuleList()
399
+ block_out = ch * ch_mult[i_level]
400
+ skip_in = ch * ch_mult[i_level]
401
+ for i_block in range(self.num_res_blocks + 1):
402
+ if i_block == self.num_res_blocks:
403
+ skip_in = ch * in_ch_mult[i_level]
404
+ block.append(
405
+ ResnetBlock(
406
+ in_channels=block_in + skip_in,
407
+ out_channels=block_out,
408
+ temb_channels=self.temb_ch,
409
+ dropout=dropout,
410
+ )
411
+ )
412
+ block_in = block_out
413
+ if curr_res in attn_resolutions:
414
+ attn.append(make_attn(block_in, attn_type=attn_type))
415
+ up = nn.Module()
416
+ up.block = block
417
+ up.attn = attn
418
+ if i_level != 0:
419
+ up.upsample = Upsample(block_in, resamp_with_conv)
420
+ curr_res = curr_res * 2
421
+ self.up.insert(0, up) # prepend to get consistent order
422
+
423
+ # end
424
+ self.norm_out = Normalize(block_in)
425
+ self.conv_out = torch.nn.Conv2d(
426
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
427
+ )
428
+
429
+ def forward(self, x, t=None, context=None):
430
+ # assert x.shape[2] == x.shape[3] == self.resolution
431
+ if context is not None:
432
+ # assume aligned context, cat along channel axis
433
+ x = torch.cat((x, context), dim=1)
434
+ if self.use_timestep:
435
+ # timestep embedding
436
+ assert t is not None
437
+ temb = get_timestep_embedding(t, self.ch)
438
+ temb = self.temb.dense[0](temb)
439
+ temb = nonlinearity(temb)
440
+ temb = self.temb.dense[1](temb)
441
+ else:
442
+ temb = None
443
+
444
+ # downsampling
445
+ hs = [self.conv_in(x)]
446
+ for i_level in range(self.num_resolutions):
447
+ for i_block in range(self.num_res_blocks):
448
+ h = self.down[i_level].block[i_block](hs[-1], temb)
449
+ if len(self.down[i_level].attn) > 0:
450
+ h = self.down[i_level].attn[i_block](h)
451
+ hs.append(h)
452
+ if i_level != self.num_resolutions - 1:
453
+ hs.append(self.down[i_level].downsample(hs[-1]))
454
+
455
+ # middle
456
+ h = hs[-1]
457
+ h = self.mid.block_1(h, temb)
458
+ h = self.mid.attn_1(h)
459
+ h = self.mid.block_2(h, temb)
460
+
461
+ # upsampling
462
+ for i_level in reversed(range(self.num_resolutions)):
463
+ for i_block in range(self.num_res_blocks + 1):
464
+ h = self.up[i_level].block[i_block](
465
+ torch.cat([h, hs.pop()], dim=1), temb
466
+ )
467
+ if len(self.up[i_level].attn) > 0:
468
+ h = self.up[i_level].attn[i_block](h)
469
+ if i_level != 0:
470
+ h = self.up[i_level].upsample(h)
471
+
472
+ # end
473
+ h = self.norm_out(h)
474
+ h = nonlinearity(h)
475
+ h = self.conv_out(h)
476
+ return h
477
+
478
+ def get_last_layer(self):
479
+ return self.conv_out.weight
480
+
481
+
482
+ class Encoder(nn.Module):
483
+ def __init__(
484
+ self,
485
+ *,
486
+ ch,
487
+ out_ch,
488
+ ch_mult=(1, 2, 4, 8),
489
+ num_res_blocks,
490
+ attn_resolutions,
491
+ dropout=0.0,
492
+ resamp_with_conv=True,
493
+ in_channels,
494
+ resolution,
495
+ z_channels,
496
+ double_z=True,
497
+ use_linear_attn=False,
498
+ attn_type="vanilla",
499
+ **ignore_kwargs,
500
+ ):
501
+ super().__init__()
502
+ if use_linear_attn:
503
+ attn_type = "linear"
504
+ self.ch = ch
505
+ self.temb_ch = 0
506
+ self.num_resolutions = len(ch_mult)
507
+ self.num_res_blocks = num_res_blocks
508
+ self.resolution = resolution
509
+ self.in_channels = in_channels
510
+
511
+ # downsampling
512
+ self.conv_in = torch.nn.Conv2d(
513
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
514
+ )
515
+
516
+ curr_res = resolution
517
+ in_ch_mult = (1,) + tuple(ch_mult)
518
+ self.in_ch_mult = in_ch_mult
519
+ self.down = nn.ModuleList()
520
+ for i_level in range(self.num_resolutions):
521
+ block = nn.ModuleList()
522
+ attn = nn.ModuleList()
523
+ block_in = ch * in_ch_mult[i_level]
524
+ block_out = ch * ch_mult[i_level]
525
+ for i_block in range(self.num_res_blocks):
526
+ block.append(
527
+ ResnetBlock(
528
+ in_channels=block_in,
529
+ out_channels=block_out,
530
+ temb_channels=self.temb_ch,
531
+ dropout=dropout,
532
+ )
533
+ )
534
+ block_in = block_out
535
+ if curr_res in attn_resolutions:
536
+ attn.append(make_attn(block_in, attn_type=attn_type))
537
+ down = nn.Module()
538
+ down.block = block
539
+ down.attn = attn
540
+ if i_level != self.num_resolutions - 1:
541
+ down.downsample = Downsample(block_in, resamp_with_conv)
542
+ curr_res = curr_res // 2
543
+ self.down.append(down)
544
+
545
+ # middle
546
+ self.mid = nn.Module()
547
+ self.mid.block_1 = ResnetBlock(
548
+ in_channels=block_in,
549
+ out_channels=block_in,
550
+ temb_channels=self.temb_ch,
551
+ dropout=dropout,
552
+ )
553
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
554
+ self.mid.block_2 = ResnetBlock(
555
+ in_channels=block_in,
556
+ out_channels=block_in,
557
+ temb_channels=self.temb_ch,
558
+ dropout=dropout,
559
+ )
560
+
561
+ # end
562
+ self.norm_out = Normalize(block_in)
563
+ self.conv_out = torch.nn.Conv2d(
564
+ block_in,
565
+ 2 * z_channels if double_z else z_channels,
566
+ kernel_size=3,
567
+ stride=1,
568
+ padding=1,
569
+ )
570
+
571
+ def forward(self, x):
572
+ # timestep embedding
573
+ temb = None
574
+
575
+ # downsampling
576
+ hs = [self.conv_in(x)]
577
+ for i_level in range(self.num_resolutions):
578
+ for i_block in range(self.num_res_blocks):
579
+ h = self.down[i_level].block[i_block](hs[-1], temb)
580
+ if len(self.down[i_level].attn) > 0:
581
+ h = self.down[i_level].attn[i_block](h)
582
+ hs.append(h)
583
+ if i_level != self.num_resolutions - 1:
584
+ hs.append(self.down[i_level].downsample(hs[-1]))
585
+
586
+ # middle
587
+ h = hs[-1]
588
+ h = self.mid.block_1(h, temb)
589
+ h = self.mid.attn_1(h)
590
+ h = self.mid.block_2(h, temb)
591
+
592
+ # end
593
+ h = self.norm_out(h)
594
+ h = nonlinearity(h)
595
+ h = self.conv_out(h)
596
+ return h
597
+
598
+
599
+ class Decoder(nn.Module):
600
+ def __init__(
601
+ self,
602
+ *,
603
+ ch,
604
+ out_ch,
605
+ ch_mult=(1, 2, 4, 8),
606
+ num_res_blocks,
607
+ attn_resolutions,
608
+ dropout=0.0,
609
+ resamp_with_conv=True,
610
+ in_channels,
611
+ resolution,
612
+ z_channels,
613
+ give_pre_end=False,
614
+ tanh_out=False,
615
+ use_linear_attn=False,
616
+ attn_type="vanilla",
617
+ **ignorekwargs,
618
+ ):
619
+ super().__init__()
620
+ if use_linear_attn:
621
+ attn_type = "linear"
622
+ self.ch = ch
623
+ self.temb_ch = 0
624
+ self.num_resolutions = len(ch_mult)
625
+ self.num_res_blocks = num_res_blocks
626
+ self.resolution = resolution
627
+ self.in_channels = in_channels
628
+ self.give_pre_end = give_pre_end
629
+ self.tanh_out = tanh_out
630
+
631
+ # compute in_ch_mult, block_in and curr_res at lowest res
632
+ in_ch_mult = (1,) + tuple(ch_mult)
633
+ block_in = ch * ch_mult[self.num_resolutions - 1]
634
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
635
+ self.z_shape = (1, z_channels, curr_res, curr_res)
636
+ print(
637
+ "Working with z of shape {} = {} dimensions.".format(
638
+ self.z_shape, np.prod(self.z_shape)
639
+ )
640
+ )
641
+
642
+ make_attn_cls = self._make_attn()
643
+ make_resblock_cls = self._make_resblock()
644
+ make_conv_cls = self._make_conv()
645
+ # z to block_in
646
+ self.conv_in = torch.nn.Conv2d(
647
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
648
+ )
649
+
650
+ # middle
651
+ self.mid = nn.Module()
652
+ self.mid.block_1 = make_resblock_cls(
653
+ in_channels=block_in,
654
+ out_channels=block_in,
655
+ temb_channels=self.temb_ch,
656
+ dropout=dropout,
657
+ )
658
+ self.mid.attn_1 = make_attn_cls(block_in, attn_type=attn_type)
659
+ self.mid.block_2 = make_resblock_cls(
660
+ in_channels=block_in,
661
+ out_channels=block_in,
662
+ temb_channels=self.temb_ch,
663
+ dropout=dropout,
664
+ )
665
+
666
+ # upsampling
667
+ self.up = nn.ModuleList()
668
+ for i_level in reversed(range(self.num_resolutions)):
669
+ block = nn.ModuleList()
670
+ attn = nn.ModuleList()
671
+ block_out = ch * ch_mult[i_level]
672
+ for i_block in range(self.num_res_blocks + 1):
673
+ block.append(
674
+ make_resblock_cls(
675
+ in_channels=block_in,
676
+ out_channels=block_out,
677
+ temb_channels=self.temb_ch,
678
+ dropout=dropout,
679
+ )
680
+ )
681
+ block_in = block_out
682
+ if curr_res in attn_resolutions:
683
+ attn.append(make_attn_cls(block_in, attn_type=attn_type))
684
+ up = nn.Module()
685
+ up.block = block
686
+ up.attn = attn
687
+ if i_level != 0:
688
+ up.upsample = Upsample(block_in, resamp_with_conv)
689
+ curr_res = curr_res * 2
690
+ self.up.insert(0, up) # prepend to get consistent order
691
+
692
+ # end
693
+ self.norm_out = Normalize(block_in)
694
+ self.conv_out = make_conv_cls(
695
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
696
+ )
697
+
698
+ def _make_attn(self) -> Callable:
699
+ return make_attn
700
+
701
+ def _make_resblock(self) -> Callable:
702
+ return ResnetBlock
703
+
704
+ def _make_conv(self) -> Callable:
705
+ return torch.nn.Conv2d
706
+
707
+ def get_last_layer(self, **kwargs):
708
+ return self.conv_out.weight
709
+
710
+ def forward(self, z, **kwargs):
711
+ # assert z.shape[1:] == self.z_shape[1:]
712
+ self.last_z_shape = z.shape
713
+
714
+ # timestep embedding
715
+ temb = None
716
+
717
+ # z to block_in
718
+ h = self.conv_in(z)
719
+
720
+ # middle
721
+ h = self.mid.block_1(h, temb, **kwargs)
722
+ h = self.mid.attn_1(h, **kwargs)
723
+ h = self.mid.block_2(h, temb, **kwargs)
724
+
725
+ # upsampling
726
+ for i_level in reversed(range(self.num_resolutions)):
727
+ for i_block in range(self.num_res_blocks + 1):
728
+ h = self.up[i_level].block[i_block](h, temb, **kwargs)
729
+ if len(self.up[i_level].attn) > 0:
730
+ h = self.up[i_level].attn[i_block](h, **kwargs)
731
+ if i_level != 0:
732
+ h = self.up[i_level].upsample(h)
733
+
734
+ # end
735
+ if self.give_pre_end:
736
+ return h
737
+
738
+ h = self.norm_out(h)
739
+ h = nonlinearity(h)
740
+ h = self.conv_out(h, **kwargs)
741
+ if self.tanh_out:
742
+ h = torch.tanh(h)
743
+ return h
repositories/generative-models/sgm/modules/diffusionmodules/openaimodel.py ADDED
@@ -0,0 +1,1262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from abc import abstractmethod
3
+ from functools import partial
4
+ from typing import Iterable
5
+
6
+ import numpy as np
7
+ import torch as th
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from einops import rearrange
11
+
12
+ from ...modules.attention import SpatialTransformer
13
+ from ...modules.diffusionmodules.util import (
14
+ avg_pool_nd,
15
+ checkpoint,
16
+ conv_nd,
17
+ linear,
18
+ normalization,
19
+ timestep_embedding,
20
+ zero_module,
21
+ )
22
+ from ...util import default, exists
23
+
24
+
25
+ # dummy replace
26
+ def convert_module_to_f16(x):
27
+ pass
28
+
29
+
30
+ def convert_module_to_f32(x):
31
+ pass
32
+
33
+
34
+ ## go
35
+ class AttentionPool2d(nn.Module):
36
+ """
37
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
38
+ """
39
+
40
+ def __init__(
41
+ self,
42
+ spacial_dim: int,
43
+ embed_dim: int,
44
+ num_heads_channels: int,
45
+ output_dim: int = None,
46
+ ):
47
+ super().__init__()
48
+ self.positional_embedding = nn.Parameter(
49
+ th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
50
+ )
51
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
52
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
53
+ self.num_heads = embed_dim // num_heads_channels
54
+ self.attention = QKVAttention(self.num_heads)
55
+
56
+ def forward(self, x):
57
+ b, c, *_spatial = x.shape
58
+ x = x.reshape(b, c, -1) # NC(HW)
59
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
60
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
61
+ x = self.qkv_proj(x)
62
+ x = self.attention(x)
63
+ x = self.c_proj(x)
64
+ return x[:, :, 0]
65
+
66
+
67
+ class TimestepBlock(nn.Module):
68
+ """
69
+ Any module where forward() takes timestep embeddings as a second argument.
70
+ """
71
+
72
+ @abstractmethod
73
+ def forward(self, x, emb):
74
+ """
75
+ Apply the module to `x` given `emb` timestep embeddings.
76
+ """
77
+
78
+
79
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
80
+ """
81
+ A sequential module that passes timestep embeddings to the children that
82
+ support it as an extra input.
83
+ """
84
+
85
+ def forward(
86
+ self,
87
+ x,
88
+ emb,
89
+ context=None,
90
+ skip_time_mix=False,
91
+ time_context=None,
92
+ num_video_frames=None,
93
+ time_context_cat=None,
94
+ use_crossframe_attention_in_spatial_layers=False,
95
+ ):
96
+ for layer in self:
97
+ if isinstance(layer, TimestepBlock):
98
+ x = layer(x, emb)
99
+ elif isinstance(layer, SpatialTransformer):
100
+ x = layer(x, context)
101
+ else:
102
+ x = layer(x)
103
+ return x
104
+
105
+
106
+ class Upsample(nn.Module):
107
+ """
108
+ An upsampling layer with an optional convolution.
109
+ :param channels: channels in the inputs and outputs.
110
+ :param use_conv: a bool determining if a convolution is applied.
111
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
112
+ upsampling occurs in the inner-two dimensions.
113
+ """
114
+
115
+ def __init__(
116
+ self, channels, use_conv, dims=2, out_channels=None, padding=1, third_up=False
117
+ ):
118
+ super().__init__()
119
+ self.channels = channels
120
+ self.out_channels = out_channels or channels
121
+ self.use_conv = use_conv
122
+ self.dims = dims
123
+ self.third_up = third_up
124
+ if use_conv:
125
+ self.conv = conv_nd(
126
+ dims, self.channels, self.out_channels, 3, padding=padding
127
+ )
128
+
129
+ def forward(self, x):
130
+ assert x.shape[1] == self.channels
131
+ if self.dims == 3:
132
+ t_factor = 1 if not self.third_up else 2
133
+ x = F.interpolate(
134
+ x,
135
+ (t_factor * x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
136
+ mode="nearest",
137
+ )
138
+ else:
139
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
140
+ if self.use_conv:
141
+ x = self.conv(x)
142
+ return x
143
+
144
+
145
+ class TransposedUpsample(nn.Module):
146
+ "Learned 2x upsampling without padding"
147
+
148
+ def __init__(self, channels, out_channels=None, ks=5):
149
+ super().__init__()
150
+ self.channels = channels
151
+ self.out_channels = out_channels or channels
152
+
153
+ self.up = nn.ConvTranspose2d(
154
+ self.channels, self.out_channels, kernel_size=ks, stride=2
155
+ )
156
+
157
+ def forward(self, x):
158
+ return self.up(x)
159
+
160
+
161
+ class Downsample(nn.Module):
162
+ """
163
+ A downsampling layer with an optional convolution.
164
+ :param channels: channels in the inputs and outputs.
165
+ :param use_conv: a bool determining if a convolution is applied.
166
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
167
+ downsampling occurs in the inner-two dimensions.
168
+ """
169
+
170
+ def __init__(
171
+ self, channels, use_conv, dims=2, out_channels=None, padding=1, third_down=False
172
+ ):
173
+ super().__init__()
174
+ self.channels = channels
175
+ self.out_channels = out_channels or channels
176
+ self.use_conv = use_conv
177
+ self.dims = dims
178
+ stride = 2 if dims != 3 else ((1, 2, 2) if not third_down else (2, 2, 2))
179
+ if use_conv:
180
+ print(f"Building a Downsample layer with {dims} dims.")
181
+ print(
182
+ f" --> settings are: \n in-chn: {self.channels}, out-chn: {self.out_channels}, "
183
+ f"kernel-size: 3, stride: {stride}, padding: {padding}"
184
+ )
185
+ if dims == 3:
186
+ print(f" --> Downsampling third axis (time): {third_down}")
187
+ self.op = conv_nd(
188
+ dims,
189
+ self.channels,
190
+ self.out_channels,
191
+ 3,
192
+ stride=stride,
193
+ padding=padding,
194
+ )
195
+ else:
196
+ assert self.channels == self.out_channels
197
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
198
+
199
+ def forward(self, x):
200
+ assert x.shape[1] == self.channels
201
+ return self.op(x)
202
+
203
+
204
+ class ResBlock(TimestepBlock):
205
+ """
206
+ A residual block that can optionally change the number of channels.
207
+ :param channels: the number of input channels.
208
+ :param emb_channels: the number of timestep embedding channels.
209
+ :param dropout: the rate of dropout.
210
+ :param out_channels: if specified, the number of out channels.
211
+ :param use_conv: if True and out_channels is specified, use a spatial
212
+ convolution instead of a smaller 1x1 convolution to change the
213
+ channels in the skip connection.
214
+ :param dims: determines if the signal is 1D, 2D, or 3D.
215
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
216
+ :param up: if True, use this block for upsampling.
217
+ :param down: if True, use this block for downsampling.
218
+ """
219
+
220
+ def __init__(
221
+ self,
222
+ channels,
223
+ emb_channels,
224
+ dropout,
225
+ out_channels=None,
226
+ use_conv=False,
227
+ use_scale_shift_norm=False,
228
+ dims=2,
229
+ use_checkpoint=False,
230
+ up=False,
231
+ down=False,
232
+ kernel_size=3,
233
+ exchange_temb_dims=False,
234
+ skip_t_emb=False,
235
+ ):
236
+ super().__init__()
237
+ self.channels = channels
238
+ self.emb_channels = emb_channels
239
+ self.dropout = dropout
240
+ self.out_channels = out_channels or channels
241
+ self.use_conv = use_conv
242
+ self.use_checkpoint = use_checkpoint
243
+ self.use_scale_shift_norm = use_scale_shift_norm
244
+ self.exchange_temb_dims = exchange_temb_dims
245
+
246
+ if isinstance(kernel_size, Iterable):
247
+ padding = [k // 2 for k in kernel_size]
248
+ else:
249
+ padding = kernel_size // 2
250
+
251
+ self.in_layers = nn.Sequential(
252
+ normalization(channels),
253
+ nn.SiLU(),
254
+ conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding),
255
+ )
256
+
257
+ self.updown = up or down
258
+
259
+ if up:
260
+ self.h_upd = Upsample(channels, False, dims)
261
+ self.x_upd = Upsample(channels, False, dims)
262
+ elif down:
263
+ self.h_upd = Downsample(channels, False, dims)
264
+ self.x_upd = Downsample(channels, False, dims)
265
+ else:
266
+ self.h_upd = self.x_upd = nn.Identity()
267
+
268
+ self.skip_t_emb = skip_t_emb
269
+ self.emb_out_channels = (
270
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels
271
+ )
272
+ if self.skip_t_emb:
273
+ print(f"Skipping timestep embedding in {self.__class__.__name__}")
274
+ assert not self.use_scale_shift_norm
275
+ self.emb_layers = None
276
+ self.exchange_temb_dims = False
277
+ else:
278
+ self.emb_layers = nn.Sequential(
279
+ nn.SiLU(),
280
+ linear(
281
+ emb_channels,
282
+ self.emb_out_channels,
283
+ ),
284
+ )
285
+
286
+ self.out_layers = nn.Sequential(
287
+ normalization(self.out_channels),
288
+ nn.SiLU(),
289
+ nn.Dropout(p=dropout),
290
+ zero_module(
291
+ conv_nd(
292
+ dims,
293
+ self.out_channels,
294
+ self.out_channels,
295
+ kernel_size,
296
+ padding=padding,
297
+ )
298
+ ),
299
+ )
300
+
301
+ if self.out_channels == channels:
302
+ self.skip_connection = nn.Identity()
303
+ elif use_conv:
304
+ self.skip_connection = conv_nd(
305
+ dims, channels, self.out_channels, kernel_size, padding=padding
306
+ )
307
+ else:
308
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
309
+
310
+ def forward(self, x, emb):
311
+ """
312
+ Apply the block to a Tensor, conditioned on a timestep embedding.
313
+ :param x: an [N x C x ...] Tensor of features.
314
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
315
+ :return: an [N x C x ...] Tensor of outputs.
316
+ """
317
+ return checkpoint(
318
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
319
+ )
320
+
321
+ def _forward(self, x, emb):
322
+ if self.updown:
323
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
324
+ h = in_rest(x)
325
+ h = self.h_upd(h)
326
+ x = self.x_upd(x)
327
+ h = in_conv(h)
328
+ else:
329
+ h = self.in_layers(x)
330
+
331
+ if self.skip_t_emb:
332
+ emb_out = th.zeros_like(h)
333
+ else:
334
+ emb_out = self.emb_layers(emb).type(h.dtype)
335
+ while len(emb_out.shape) < len(h.shape):
336
+ emb_out = emb_out[..., None]
337
+ if self.use_scale_shift_norm:
338
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
339
+ scale, shift = th.chunk(emb_out, 2, dim=1)
340
+ h = out_norm(h) * (1 + scale) + shift
341
+ h = out_rest(h)
342
+ else:
343
+ if self.exchange_temb_dims:
344
+ emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
345
+ h = h + emb_out
346
+ h = self.out_layers(h)
347
+ return self.skip_connection(x) + h
348
+
349
+
350
+ class AttentionBlock(nn.Module):
351
+ """
352
+ An attention block that allows spatial positions to attend to each other.
353
+ Originally ported from here, but adapted to the N-d case.
354
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
355
+ """
356
+
357
+ def __init__(
358
+ self,
359
+ channels,
360
+ num_heads=1,
361
+ num_head_channels=-1,
362
+ use_checkpoint=False,
363
+ use_new_attention_order=False,
364
+ ):
365
+ super().__init__()
366
+ self.channels = channels
367
+ if num_head_channels == -1:
368
+ self.num_heads = num_heads
369
+ else:
370
+ assert (
371
+ channels % num_head_channels == 0
372
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
373
+ self.num_heads = channels // num_head_channels
374
+ self.use_checkpoint = use_checkpoint
375
+ self.norm = normalization(channels)
376
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
377
+ if use_new_attention_order:
378
+ # split qkv before split heads
379
+ self.attention = QKVAttention(self.num_heads)
380
+ else:
381
+ # split heads before split qkv
382
+ self.attention = QKVAttentionLegacy(self.num_heads)
383
+
384
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
385
+
386
+ def forward(self, x, **kwargs):
387
+ # TODO add crossframe attention and use mixed checkpoint
388
+ return checkpoint(
389
+ self._forward, (x,), self.parameters(), True
390
+ ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
391
+ # return pt_checkpoint(self._forward, x) # pytorch
392
+
393
+ def _forward(self, x):
394
+ b, c, *spatial = x.shape
395
+ x = x.reshape(b, c, -1)
396
+ qkv = self.qkv(self.norm(x))
397
+ h = self.attention(qkv)
398
+ h = self.proj_out(h)
399
+ return (x + h).reshape(b, c, *spatial)
400
+
401
+
402
+ def count_flops_attn(model, _x, y):
403
+ """
404
+ A counter for the `thop` package to count the operations in an
405
+ attention operation.
406
+ Meant to be used like:
407
+ macs, params = thop.profile(
408
+ model,
409
+ inputs=(inputs, timestamps),
410
+ custom_ops={QKVAttention: QKVAttention.count_flops},
411
+ )
412
+ """
413
+ b, c, *spatial = y[0].shape
414
+ num_spatial = int(np.prod(spatial))
415
+ # We perform two matmuls with the same number of ops.
416
+ # The first computes the weight matrix, the second computes
417
+ # the combination of the value vectors.
418
+ matmul_ops = 2 * b * (num_spatial**2) * c
419
+ model.total_ops += th.DoubleTensor([matmul_ops])
420
+
421
+
422
+ class QKVAttentionLegacy(nn.Module):
423
+ """
424
+ A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
425
+ """
426
+
427
+ def __init__(self, n_heads):
428
+ super().__init__()
429
+ self.n_heads = n_heads
430
+
431
+ def forward(self, qkv):
432
+ """
433
+ Apply QKV attention.
434
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
435
+ :return: an [N x (H * C) x T] tensor after attention.
436
+ """
437
+ bs, width, length = qkv.shape
438
+ assert width % (3 * self.n_heads) == 0
439
+ ch = width // (3 * self.n_heads)
440
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
441
+ scale = 1 / math.sqrt(math.sqrt(ch))
442
+ weight = th.einsum(
443
+ "bct,bcs->bts", q * scale, k * scale
444
+ ) # More stable with f16 than dividing afterwards
445
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
446
+ a = th.einsum("bts,bcs->bct", weight, v)
447
+ return a.reshape(bs, -1, length)
448
+
449
+ @staticmethod
450
+ def count_flops(model, _x, y):
451
+ return count_flops_attn(model, _x, y)
452
+
453
+
454
+ class QKVAttention(nn.Module):
455
+ """
456
+ A module which performs QKV attention and splits in a different order.
457
+ """
458
+
459
+ def __init__(self, n_heads):
460
+ super().__init__()
461
+ self.n_heads = n_heads
462
+
463
+ def forward(self, qkv):
464
+ """
465
+ Apply QKV attention.
466
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
467
+ :return: an [N x (H * C) x T] tensor after attention.
468
+ """
469
+ bs, width, length = qkv.shape
470
+ assert width % (3 * self.n_heads) == 0
471
+ ch = width // (3 * self.n_heads)
472
+ q, k, v = qkv.chunk(3, dim=1)
473
+ scale = 1 / math.sqrt(math.sqrt(ch))
474
+ weight = th.einsum(
475
+ "bct,bcs->bts",
476
+ (q * scale).view(bs * self.n_heads, ch, length),
477
+ (k * scale).view(bs * self.n_heads, ch, length),
478
+ ) # More stable with f16 than dividing afterwards
479
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
480
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
481
+ return a.reshape(bs, -1, length)
482
+
483
+ @staticmethod
484
+ def count_flops(model, _x, y):
485
+ return count_flops_attn(model, _x, y)
486
+
487
+
488
+ class Timestep(nn.Module):
489
+ def __init__(self, dim):
490
+ super().__init__()
491
+ self.dim = dim
492
+
493
+ def forward(self, t):
494
+ return timestep_embedding(t, self.dim)
495
+
496
+
497
+ class UNetModel(nn.Module):
498
+ """
499
+ The full UNet model with attention and timestep embedding.
500
+ :param in_channels: channels in the input Tensor.
501
+ :param model_channels: base channel count for the model.
502
+ :param out_channels: channels in the output Tensor.
503
+ :param num_res_blocks: number of residual blocks per downsample.
504
+ :param attention_resolutions: a collection of downsample rates at which
505
+ attention will take place. May be a set, list, or tuple.
506
+ For example, if this contains 4, then at 4x downsampling, attention
507
+ will be used.
508
+ :param dropout: the dropout probability.
509
+ :param channel_mult: channel multiplier for each level of the UNet.
510
+ :param conv_resample: if True, use learned convolutions for upsampling and
511
+ downsampling.
512
+ :param dims: determines if the signal is 1D, 2D, or 3D.
513
+ :param num_classes: if specified (as an int), then this model will be
514
+ class-conditional with `num_classes` classes.
515
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
516
+ :param num_heads: the number of attention heads in each attention layer.
517
+ :param num_heads_channels: if specified, ignore num_heads and instead use
518
+ a fixed channel width per attention head.
519
+ :param num_heads_upsample: works with num_heads to set a different number
520
+ of heads for upsampling. Deprecated.
521
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
522
+ :param resblock_updown: use residual blocks for up/downsampling.
523
+ :param use_new_attention_order: use a different attention pattern for potentially
524
+ increased efficiency.
525
+ """
526
+
527
+ def __init__(
528
+ self,
529
+ in_channels,
530
+ model_channels,
531
+ out_channels,
532
+ num_res_blocks,
533
+ attention_resolutions,
534
+ dropout=0,
535
+ channel_mult=(1, 2, 4, 8),
536
+ conv_resample=True,
537
+ dims=2,
538
+ num_classes=None,
539
+ use_checkpoint=False,
540
+ use_fp16=False,
541
+ num_heads=-1,
542
+ num_head_channels=-1,
543
+ num_heads_upsample=-1,
544
+ use_scale_shift_norm=False,
545
+ resblock_updown=False,
546
+ use_new_attention_order=False,
547
+ use_spatial_transformer=False, # custom transformer support
548
+ transformer_depth=1, # custom transformer support
549
+ context_dim=None, # custom transformer support
550
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
551
+ legacy=True,
552
+ disable_self_attentions=None,
553
+ num_attention_blocks=None,
554
+ disable_middle_self_attn=False,
555
+ use_linear_in_transformer=False,
556
+ spatial_transformer_attn_type="softmax",
557
+ adm_in_channels=None,
558
+ use_fairscale_checkpoint=False,
559
+ offload_to_cpu=False,
560
+ transformer_depth_middle=None,
561
+ ):
562
+ super().__init__()
563
+ from omegaconf.listconfig import ListConfig
564
+
565
+ if use_spatial_transformer:
566
+ assert (
567
+ context_dim is not None
568
+ ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
569
+
570
+ if context_dim is not None:
571
+ assert (
572
+ use_spatial_transformer
573
+ ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
574
+ if type(context_dim) == ListConfig:
575
+ context_dim = list(context_dim)
576
+
577
+ if num_heads_upsample == -1:
578
+ num_heads_upsample = num_heads
579
+
580
+ if num_heads == -1:
581
+ assert (
582
+ num_head_channels != -1
583
+ ), "Either num_heads or num_head_channels has to be set"
584
+
585
+ if num_head_channels == -1:
586
+ assert (
587
+ num_heads != -1
588
+ ), "Either num_heads or num_head_channels has to be set"
589
+
590
+ self.in_channels = in_channels
591
+ self.model_channels = model_channels
592
+ self.out_channels = out_channels
593
+ if isinstance(transformer_depth, int):
594
+ transformer_depth = len(channel_mult) * [transformer_depth]
595
+ elif isinstance(transformer_depth, ListConfig):
596
+ transformer_depth = list(transformer_depth)
597
+ transformer_depth_middle = default(
598
+ transformer_depth_middle, transformer_depth[-1]
599
+ )
600
+
601
+ if isinstance(num_res_blocks, int):
602
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
603
+ else:
604
+ if len(num_res_blocks) != len(channel_mult):
605
+ raise ValueError(
606
+ "provide num_res_blocks either as an int (globally constant) or "
607
+ "as a list/tuple (per-level) with the same length as channel_mult"
608
+ )
609
+ self.num_res_blocks = num_res_blocks
610
+ # self.num_res_blocks = num_res_blocks
611
+ if disable_self_attentions is not None:
612
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
613
+ assert len(disable_self_attentions) == len(channel_mult)
614
+ if num_attention_blocks is not None:
615
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
616
+ assert all(
617
+ map(
618
+ lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
619
+ range(len(num_attention_blocks)),
620
+ )
621
+ )
622
+ print(
623
+ f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
624
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
625
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
626
+ f"attention will still not be set."
627
+ ) # todo: convert to warning
628
+
629
+ self.attention_resolutions = attention_resolutions
630
+ self.dropout = dropout
631
+ self.channel_mult = channel_mult
632
+ self.conv_resample = conv_resample
633
+ self.num_classes = num_classes
634
+ self.use_checkpoint = use_checkpoint
635
+ if use_fp16:
636
+ print("WARNING: use_fp16 was dropped and has no effect anymore.")
637
+ # self.dtype = th.float16 if use_fp16 else th.float32
638
+ self.num_heads = num_heads
639
+ self.num_head_channels = num_head_channels
640
+ self.num_heads_upsample = num_heads_upsample
641
+ self.predict_codebook_ids = n_embed is not None
642
+
643
+ assert use_fairscale_checkpoint != use_checkpoint or not (
644
+ use_checkpoint or use_fairscale_checkpoint
645
+ )
646
+
647
+ self.use_fairscale_checkpoint = False
648
+ checkpoint_wrapper_fn = (
649
+ partial(checkpoint_wrapper, offload_to_cpu=offload_to_cpu)
650
+ if self.use_fairscale_checkpoint
651
+ else lambda x: x
652
+ )
653
+
654
+ time_embed_dim = model_channels * 4
655
+ self.time_embed = checkpoint_wrapper_fn(
656
+ nn.Sequential(
657
+ linear(model_channels, time_embed_dim),
658
+ nn.SiLU(),
659
+ linear(time_embed_dim, time_embed_dim),
660
+ )
661
+ )
662
+
663
+ if self.num_classes is not None:
664
+ if isinstance(self.num_classes, int):
665
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
666
+ elif self.num_classes == "continuous":
667
+ print("setting up linear c_adm embedding layer")
668
+ self.label_emb = nn.Linear(1, time_embed_dim)
669
+ elif self.num_classes == "timestep":
670
+ self.label_emb = checkpoint_wrapper_fn(
671
+ nn.Sequential(
672
+ Timestep(model_channels),
673
+ nn.Sequential(
674
+ linear(model_channels, time_embed_dim),
675
+ nn.SiLU(),
676
+ linear(time_embed_dim, time_embed_dim),
677
+ ),
678
+ )
679
+ )
680
+ elif self.num_classes == "sequential":
681
+ assert adm_in_channels is not None
682
+ self.label_emb = nn.Sequential(
683
+ nn.Sequential(
684
+ linear(adm_in_channels, time_embed_dim),
685
+ nn.SiLU(),
686
+ linear(time_embed_dim, time_embed_dim),
687
+ )
688
+ )
689
+ else:
690
+ raise ValueError()
691
+
692
+ self.input_blocks = nn.ModuleList(
693
+ [
694
+ TimestepEmbedSequential(
695
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
696
+ )
697
+ ]
698
+ )
699
+ self._feature_size = model_channels
700
+ input_block_chans = [model_channels]
701
+ ch = model_channels
702
+ ds = 1
703
+ for level, mult in enumerate(channel_mult):
704
+ for nr in range(self.num_res_blocks[level]):
705
+ layers = [
706
+ checkpoint_wrapper_fn(
707
+ ResBlock(
708
+ ch,
709
+ time_embed_dim,
710
+ dropout,
711
+ out_channels=mult * model_channels,
712
+ dims=dims,
713
+ use_checkpoint=use_checkpoint,
714
+ use_scale_shift_norm=use_scale_shift_norm,
715
+ )
716
+ )
717
+ ]
718
+ ch = mult * model_channels
719
+ if ds in attention_resolutions:
720
+ if num_head_channels == -1:
721
+ dim_head = ch // num_heads
722
+ else:
723
+ num_heads = ch // num_head_channels
724
+ dim_head = num_head_channels
725
+ if legacy:
726
+ # num_heads = 1
727
+ dim_head = (
728
+ ch // num_heads
729
+ if use_spatial_transformer
730
+ else num_head_channels
731
+ )
732
+ if exists(disable_self_attentions):
733
+ disabled_sa = disable_self_attentions[level]
734
+ else:
735
+ disabled_sa = False
736
+
737
+ if (
738
+ not exists(num_attention_blocks)
739
+ or nr < num_attention_blocks[level]
740
+ ):
741
+ layers.append(
742
+ checkpoint_wrapper_fn(
743
+ AttentionBlock(
744
+ ch,
745
+ use_checkpoint=use_checkpoint,
746
+ num_heads=num_heads,
747
+ num_head_channels=dim_head,
748
+ use_new_attention_order=use_new_attention_order,
749
+ )
750
+ )
751
+ if not use_spatial_transformer
752
+ else checkpoint_wrapper_fn(
753
+ SpatialTransformer(
754
+ ch,
755
+ num_heads,
756
+ dim_head,
757
+ depth=transformer_depth[level],
758
+ context_dim=context_dim,
759
+ disable_self_attn=disabled_sa,
760
+ use_linear=use_linear_in_transformer,
761
+ attn_type=spatial_transformer_attn_type,
762
+ use_checkpoint=use_checkpoint,
763
+ )
764
+ )
765
+ )
766
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
767
+ self._feature_size += ch
768
+ input_block_chans.append(ch)
769
+ if level != len(channel_mult) - 1:
770
+ out_ch = ch
771
+ self.input_blocks.append(
772
+ TimestepEmbedSequential(
773
+ checkpoint_wrapper_fn(
774
+ ResBlock(
775
+ ch,
776
+ time_embed_dim,
777
+ dropout,
778
+ out_channels=out_ch,
779
+ dims=dims,
780
+ use_checkpoint=use_checkpoint,
781
+ use_scale_shift_norm=use_scale_shift_norm,
782
+ down=True,
783
+ )
784
+ )
785
+ if resblock_updown
786
+ else Downsample(
787
+ ch, conv_resample, dims=dims, out_channels=out_ch
788
+ )
789
+ )
790
+ )
791
+ ch = out_ch
792
+ input_block_chans.append(ch)
793
+ ds *= 2
794
+ self._feature_size += ch
795
+
796
+ if num_head_channels == -1:
797
+ dim_head = ch // num_heads
798
+ else:
799
+ num_heads = ch // num_head_channels
800
+ dim_head = num_head_channels
801
+ if legacy:
802
+ # num_heads = 1
803
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
804
+ self.middle_block = TimestepEmbedSequential(
805
+ checkpoint_wrapper_fn(
806
+ ResBlock(
807
+ ch,
808
+ time_embed_dim,
809
+ dropout,
810
+ dims=dims,
811
+ use_checkpoint=use_checkpoint,
812
+ use_scale_shift_norm=use_scale_shift_norm,
813
+ )
814
+ ),
815
+ checkpoint_wrapper_fn(
816
+ AttentionBlock(
817
+ ch,
818
+ use_checkpoint=use_checkpoint,
819
+ num_heads=num_heads,
820
+ num_head_channels=dim_head,
821
+ use_new_attention_order=use_new_attention_order,
822
+ )
823
+ )
824
+ if not use_spatial_transformer
825
+ else checkpoint_wrapper_fn(
826
+ SpatialTransformer( # always uses a self-attn
827
+ ch,
828
+ num_heads,
829
+ dim_head,
830
+ depth=transformer_depth_middle,
831
+ context_dim=context_dim,
832
+ disable_self_attn=disable_middle_self_attn,
833
+ use_linear=use_linear_in_transformer,
834
+ attn_type=spatial_transformer_attn_type,
835
+ use_checkpoint=use_checkpoint,
836
+ )
837
+ ),
838
+ checkpoint_wrapper_fn(
839
+ ResBlock(
840
+ ch,
841
+ time_embed_dim,
842
+ dropout,
843
+ dims=dims,
844
+ use_checkpoint=use_checkpoint,
845
+ use_scale_shift_norm=use_scale_shift_norm,
846
+ )
847
+ ),
848
+ )
849
+ self._feature_size += ch
850
+
851
+ self.output_blocks = nn.ModuleList([])
852
+ for level, mult in list(enumerate(channel_mult))[::-1]:
853
+ for i in range(self.num_res_blocks[level] + 1):
854
+ ich = input_block_chans.pop()
855
+ layers = [
856
+ checkpoint_wrapper_fn(
857
+ ResBlock(
858
+ ch + ich,
859
+ time_embed_dim,
860
+ dropout,
861
+ out_channels=model_channels * mult,
862
+ dims=dims,
863
+ use_checkpoint=use_checkpoint,
864
+ use_scale_shift_norm=use_scale_shift_norm,
865
+ )
866
+ )
867
+ ]
868
+ ch = model_channels * mult
869
+ if ds in attention_resolutions:
870
+ if num_head_channels == -1:
871
+ dim_head = ch // num_heads
872
+ else:
873
+ num_heads = ch // num_head_channels
874
+ dim_head = num_head_channels
875
+ if legacy:
876
+ # num_heads = 1
877
+ dim_head = (
878
+ ch // num_heads
879
+ if use_spatial_transformer
880
+ else num_head_channels
881
+ )
882
+ if exists(disable_self_attentions):
883
+ disabled_sa = disable_self_attentions[level]
884
+ else:
885
+ disabled_sa = False
886
+
887
+ if (
888
+ not exists(num_attention_blocks)
889
+ or i < num_attention_blocks[level]
890
+ ):
891
+ layers.append(
892
+ checkpoint_wrapper_fn(
893
+ AttentionBlock(
894
+ ch,
895
+ use_checkpoint=use_checkpoint,
896
+ num_heads=num_heads_upsample,
897
+ num_head_channels=dim_head,
898
+ use_new_attention_order=use_new_attention_order,
899
+ )
900
+ )
901
+ if not use_spatial_transformer
902
+ else checkpoint_wrapper_fn(
903
+ SpatialTransformer(
904
+ ch,
905
+ num_heads,
906
+ dim_head,
907
+ depth=transformer_depth[level],
908
+ context_dim=context_dim,
909
+ disable_self_attn=disabled_sa,
910
+ use_linear=use_linear_in_transformer,
911
+ attn_type=spatial_transformer_attn_type,
912
+ use_checkpoint=use_checkpoint,
913
+ )
914
+ )
915
+ )
916
+ if level and i == self.num_res_blocks[level]:
917
+ out_ch = ch
918
+ layers.append(
919
+ checkpoint_wrapper_fn(
920
+ ResBlock(
921
+ ch,
922
+ time_embed_dim,
923
+ dropout,
924
+ out_channels=out_ch,
925
+ dims=dims,
926
+ use_checkpoint=use_checkpoint,
927
+ use_scale_shift_norm=use_scale_shift_norm,
928
+ up=True,
929
+ )
930
+ )
931
+ if resblock_updown
932
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
933
+ )
934
+ ds //= 2
935
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
936
+ self._feature_size += ch
937
+
938
+ self.out = checkpoint_wrapper_fn(
939
+ nn.Sequential(
940
+ normalization(ch),
941
+ nn.SiLU(),
942
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
943
+ )
944
+ )
945
+ if self.predict_codebook_ids:
946
+ self.id_predictor = checkpoint_wrapper_fn(
947
+ nn.Sequential(
948
+ normalization(ch),
949
+ conv_nd(dims, model_channels, n_embed, 1),
950
+ # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
951
+ )
952
+ )
953
+
954
+ def convert_to_fp16(self):
955
+ """
956
+ Convert the torso of the model to float16.
957
+ """
958
+ self.input_blocks.apply(convert_module_to_f16)
959
+ self.middle_block.apply(convert_module_to_f16)
960
+ self.output_blocks.apply(convert_module_to_f16)
961
+
962
+ def convert_to_fp32(self):
963
+ """
964
+ Convert the torso of the model to float32.
965
+ """
966
+ self.input_blocks.apply(convert_module_to_f32)
967
+ self.middle_block.apply(convert_module_to_f32)
968
+ self.output_blocks.apply(convert_module_to_f32)
969
+
970
+ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
971
+ """
972
+ Apply the model to an input batch.
973
+ :param x: an [N x C x ...] Tensor of inputs.
974
+ :param timesteps: a 1-D batch of timesteps.
975
+ :param context: conditioning plugged in via crossattn
976
+ :param y: an [N] Tensor of labels, if class-conditional.
977
+ :return: an [N x C x ...] Tensor of outputs.
978
+ """
979
+ assert (y is not None) == (
980
+ self.num_classes is not None
981
+ ), "must specify y if and only if the model is class-conditional"
982
+ hs = []
983
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
984
+ emb = self.time_embed(t_emb)
985
+
986
+ if self.num_classes is not None:
987
+ assert y.shape[0] == x.shape[0]
988
+ emb = emb + self.label_emb(y)
989
+
990
+ # h = x.type(self.dtype)
991
+ h = x
992
+ for module in self.input_blocks:
993
+ h = module(h, emb, context)
994
+ hs.append(h)
995
+ h = self.middle_block(h, emb, context)
996
+ for module in self.output_blocks:
997
+ h = th.cat([h, hs.pop()], dim=1)
998
+ h = module(h, emb, context)
999
+ h = h.type(x.dtype)
1000
+ if self.predict_codebook_ids:
1001
+ assert False, "not supported anymore. what the f*** are you doing?"
1002
+ else:
1003
+ return self.out(h)
1004
+
1005
+
1006
+ class NoTimeUNetModel(UNetModel):
1007
+ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
1008
+ timesteps = th.zeros_like(timesteps)
1009
+ return super().forward(x, timesteps, context, y, **kwargs)
1010
+
1011
+
1012
+ class EncoderUNetModel(nn.Module):
1013
+ """
1014
+ The half UNet model with attention and timestep embedding.
1015
+ For usage, see UNet.
1016
+ """
1017
+
1018
+ def __init__(
1019
+ self,
1020
+ image_size,
1021
+ in_channels,
1022
+ model_channels,
1023
+ out_channels,
1024
+ num_res_blocks,
1025
+ attention_resolutions,
1026
+ dropout=0,
1027
+ channel_mult=(1, 2, 4, 8),
1028
+ conv_resample=True,
1029
+ dims=2,
1030
+ use_checkpoint=False,
1031
+ use_fp16=False,
1032
+ num_heads=1,
1033
+ num_head_channels=-1,
1034
+ num_heads_upsample=-1,
1035
+ use_scale_shift_norm=False,
1036
+ resblock_updown=False,
1037
+ use_new_attention_order=False,
1038
+ pool="adaptive",
1039
+ *args,
1040
+ **kwargs,
1041
+ ):
1042
+ super().__init__()
1043
+
1044
+ if num_heads_upsample == -1:
1045
+ num_heads_upsample = num_heads
1046
+
1047
+ self.in_channels = in_channels
1048
+ self.model_channels = model_channels
1049
+ self.out_channels = out_channels
1050
+ self.num_res_blocks = num_res_blocks
1051
+ self.attention_resolutions = attention_resolutions
1052
+ self.dropout = dropout
1053
+ self.channel_mult = channel_mult
1054
+ self.conv_resample = conv_resample
1055
+ self.use_checkpoint = use_checkpoint
1056
+ self.dtype = th.float16 if use_fp16 else th.float32
1057
+ self.num_heads = num_heads
1058
+ self.num_head_channels = num_head_channels
1059
+ self.num_heads_upsample = num_heads_upsample
1060
+
1061
+ time_embed_dim = model_channels * 4
1062
+ self.time_embed = nn.Sequential(
1063
+ linear(model_channels, time_embed_dim),
1064
+ nn.SiLU(),
1065
+ linear(time_embed_dim, time_embed_dim),
1066
+ )
1067
+
1068
+ self.input_blocks = nn.ModuleList(
1069
+ [
1070
+ TimestepEmbedSequential(
1071
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
1072
+ )
1073
+ ]
1074
+ )
1075
+ self._feature_size = model_channels
1076
+ input_block_chans = [model_channels]
1077
+ ch = model_channels
1078
+ ds = 1
1079
+ for level, mult in enumerate(channel_mult):
1080
+ for _ in range(num_res_blocks):
1081
+ layers = [
1082
+ ResBlock(
1083
+ ch,
1084
+ time_embed_dim,
1085
+ dropout,
1086
+ out_channels=mult * model_channels,
1087
+ dims=dims,
1088
+ use_checkpoint=use_checkpoint,
1089
+ use_scale_shift_norm=use_scale_shift_norm,
1090
+ )
1091
+ ]
1092
+ ch = mult * model_channels
1093
+ if ds in attention_resolutions:
1094
+ layers.append(
1095
+ AttentionBlock(
1096
+ ch,
1097
+ use_checkpoint=use_checkpoint,
1098
+ num_heads=num_heads,
1099
+ num_head_channels=num_head_channels,
1100
+ use_new_attention_order=use_new_attention_order,
1101
+ )
1102
+ )
1103
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
1104
+ self._feature_size += ch
1105
+ input_block_chans.append(ch)
1106
+ if level != len(channel_mult) - 1:
1107
+ out_ch = ch
1108
+ self.input_blocks.append(
1109
+ TimestepEmbedSequential(
1110
+ ResBlock(
1111
+ ch,
1112
+ time_embed_dim,
1113
+ dropout,
1114
+ out_channels=out_ch,
1115
+ dims=dims,
1116
+ use_checkpoint=use_checkpoint,
1117
+ use_scale_shift_norm=use_scale_shift_norm,
1118
+ down=True,
1119
+ )
1120
+ if resblock_updown
1121
+ else Downsample(
1122
+ ch, conv_resample, dims=dims, out_channels=out_ch
1123
+ )
1124
+ )
1125
+ )
1126
+ ch = out_ch
1127
+ input_block_chans.append(ch)
1128
+ ds *= 2
1129
+ self._feature_size += ch
1130
+
1131
+ self.middle_block = TimestepEmbedSequential(
1132
+ ResBlock(
1133
+ ch,
1134
+ time_embed_dim,
1135
+ dropout,
1136
+ dims=dims,
1137
+ use_checkpoint=use_checkpoint,
1138
+ use_scale_shift_norm=use_scale_shift_norm,
1139
+ ),
1140
+ AttentionBlock(
1141
+ ch,
1142
+ use_checkpoint=use_checkpoint,
1143
+ num_heads=num_heads,
1144
+ num_head_channels=num_head_channels,
1145
+ use_new_attention_order=use_new_attention_order,
1146
+ ),
1147
+ ResBlock(
1148
+ ch,
1149
+ time_embed_dim,
1150
+ dropout,
1151
+ dims=dims,
1152
+ use_checkpoint=use_checkpoint,
1153
+ use_scale_shift_norm=use_scale_shift_norm,
1154
+ ),
1155
+ )
1156
+ self._feature_size += ch
1157
+ self.pool = pool
1158
+ if pool == "adaptive":
1159
+ self.out = nn.Sequential(
1160
+ normalization(ch),
1161
+ nn.SiLU(),
1162
+ nn.AdaptiveAvgPool2d((1, 1)),
1163
+ zero_module(conv_nd(dims, ch, out_channels, 1)),
1164
+ nn.Flatten(),
1165
+ )
1166
+ elif pool == "attention":
1167
+ assert num_head_channels != -1
1168
+ self.out = nn.Sequential(
1169
+ normalization(ch),
1170
+ nn.SiLU(),
1171
+ AttentionPool2d(
1172
+ (image_size // ds), ch, num_head_channels, out_channels
1173
+ ),
1174
+ )
1175
+ elif pool == "spatial":
1176
+ self.out = nn.Sequential(
1177
+ nn.Linear(self._feature_size, 2048),
1178
+ nn.ReLU(),
1179
+ nn.Linear(2048, self.out_channels),
1180
+ )
1181
+ elif pool == "spatial_v2":
1182
+ self.out = nn.Sequential(
1183
+ nn.Linear(self._feature_size, 2048),
1184
+ normalization(2048),
1185
+ nn.SiLU(),
1186
+ nn.Linear(2048, self.out_channels),
1187
+ )
1188
+ else:
1189
+ raise NotImplementedError(f"Unexpected {pool} pooling")
1190
+
1191
+ def convert_to_fp16(self):
1192
+ """
1193
+ Convert the torso of the model to float16.
1194
+ """
1195
+ self.input_blocks.apply(convert_module_to_f16)
1196
+ self.middle_block.apply(convert_module_to_f16)
1197
+
1198
+ def convert_to_fp32(self):
1199
+ """
1200
+ Convert the torso of the model to float32.
1201
+ """
1202
+ self.input_blocks.apply(convert_module_to_f32)
1203
+ self.middle_block.apply(convert_module_to_f32)
1204
+
1205
+ def forward(self, x, timesteps):
1206
+ """
1207
+ Apply the model to an input batch.
1208
+ :param x: an [N x C x ...] Tensor of inputs.
1209
+ :param timesteps: a 1-D batch of timesteps.
1210
+ :return: an [N x K] Tensor of outputs.
1211
+ """
1212
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
1213
+
1214
+ results = []
1215
+ # h = x.type(self.dtype)
1216
+ h = x
1217
+ for module in self.input_blocks:
1218
+ h = module(h, emb)
1219
+ if self.pool.startswith("spatial"):
1220
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
1221
+ h = self.middle_block(h, emb)
1222
+ if self.pool.startswith("spatial"):
1223
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
1224
+ h = th.cat(results, axis=-1)
1225
+ return self.out(h)
1226
+ else:
1227
+ h = h.type(x.dtype)
1228
+ return self.out(h)
1229
+
1230
+
1231
+ if __name__ == "__main__":
1232
+
1233
+ class Dummy(nn.Module):
1234
+ def __init__(self, in_channels=3, model_channels=64):
1235
+ super().__init__()
1236
+ self.input_blocks = nn.ModuleList(
1237
+ [
1238
+ TimestepEmbedSequential(
1239
+ conv_nd(2, in_channels, model_channels, 3, padding=1)
1240
+ )
1241
+ ]
1242
+ )
1243
+
1244
+ model = UNetModel(
1245
+ use_checkpoint=True,
1246
+ image_size=64,
1247
+ in_channels=4,
1248
+ out_channels=4,
1249
+ model_channels=128,
1250
+ attention_resolutions=[4, 2],
1251
+ num_res_blocks=2,
1252
+ channel_mult=[1, 2, 4],
1253
+ num_head_channels=64,
1254
+ use_spatial_transformer=False,
1255
+ use_linear_in_transformer=True,
1256
+ transformer_depth=1,
1257
+ legacy=False,
1258
+ ).cuda()
1259
+ x = th.randn(11, 4, 64, 64).cuda()
1260
+ t = th.randint(low=0, high=10, size=(11,), device="cuda")
1261
+ o = model(x, t)
1262
+ print("done.")
repositories/generative-models/sgm/modules/diffusionmodules/sampling.py ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py
3
+ """
4
+
5
+
6
+ from typing import Dict, Union
7
+
8
+ import torch
9
+ from omegaconf import ListConfig, OmegaConf
10
+ from tqdm import tqdm
11
+
12
+ from ...modules.diffusionmodules.sampling_utils import (
13
+ get_ancestral_step,
14
+ linear_multistep_coeff,
15
+ to_d,
16
+ to_neg_log_sigma,
17
+ to_sigma,
18
+ )
19
+ from ...util import append_dims, default, instantiate_from_config
20
+
21
+ DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}
22
+
23
+
24
+ class BaseDiffusionSampler:
25
+ def __init__(
26
+ self,
27
+ discretization_config: Union[Dict, ListConfig, OmegaConf],
28
+ num_steps: Union[int, None] = None,
29
+ guider_config: Union[Dict, ListConfig, OmegaConf, None] = None,
30
+ verbose: bool = False,
31
+ device: str = "cuda",
32
+ ):
33
+ self.num_steps = num_steps
34
+ self.discretization = instantiate_from_config(discretization_config)
35
+ self.guider = instantiate_from_config(
36
+ default(
37
+ guider_config,
38
+ DEFAULT_GUIDER,
39
+ )
40
+ )
41
+ self.verbose = verbose
42
+ self.device = device
43
+
44
+ def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
45
+ sigmas = self.discretization(
46
+ self.num_steps if num_steps is None else num_steps, device=self.device
47
+ )
48
+ uc = default(uc, cond)
49
+
50
+ x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
51
+ num_sigmas = len(sigmas)
52
+
53
+ s_in = x.new_ones([x.shape[0]])
54
+
55
+ return x, s_in, sigmas, num_sigmas, cond, uc
56
+
57
+ def denoise(self, x, denoiser, sigma, cond, uc):
58
+ denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc))
59
+ denoised = self.guider(denoised, sigma)
60
+ return denoised
61
+
62
+ def get_sigma_gen(self, num_sigmas):
63
+ sigma_generator = range(num_sigmas - 1)
64
+ if self.verbose:
65
+ print("#" * 30, " Sampling setting ", "#" * 30)
66
+ print(f"Sampler: {self.__class__.__name__}")
67
+ print(f"Discretization: {self.discretization.__class__.__name__}")
68
+ print(f"Guider: {self.guider.__class__.__name__}")
69
+ sigma_generator = tqdm(
70
+ sigma_generator,
71
+ total=num_sigmas,
72
+ desc=f"Sampling with {self.__class__.__name__} for {num_sigmas} steps",
73
+ )
74
+ return sigma_generator
75
+
76
+
77
+ class SingleStepDiffusionSampler(BaseDiffusionSampler):
78
+ def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc, *args, **kwargs):
79
+ raise NotImplementedError
80
+
81
+ def euler_step(self, x, d, dt):
82
+ return x + dt * d
83
+
84
+
85
+ class EDMSampler(SingleStepDiffusionSampler):
86
+ def __init__(
87
+ self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs
88
+ ):
89
+ super().__init__(*args, **kwargs)
90
+
91
+ self.s_churn = s_churn
92
+ self.s_tmin = s_tmin
93
+ self.s_tmax = s_tmax
94
+ self.s_noise = s_noise
95
+
96
+ def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0):
97
+ sigma_hat = sigma * (gamma + 1.0)
98
+ if gamma > 0:
99
+ eps = torch.randn_like(x) * self.s_noise
100
+ x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5
101
+
102
+ denoised = self.denoise(x, denoiser, sigma_hat, cond, uc)
103
+ d = to_d(x, sigma_hat, denoised)
104
+ dt = append_dims(next_sigma - sigma_hat, x.ndim)
105
+
106
+ euler_step = self.euler_step(x, d, dt)
107
+ x = self.possible_correction_step(
108
+ euler_step, x, d, dt, next_sigma, denoiser, cond, uc
109
+ )
110
+ return x
111
+
112
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
113
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
114
+ x, cond, uc, num_steps
115
+ )
116
+
117
+ for i in self.get_sigma_gen(num_sigmas):
118
+ gamma = (
119
+ min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
120
+ if self.s_tmin <= sigmas[i] <= self.s_tmax
121
+ else 0.0
122
+ )
123
+ x = self.sampler_step(
124
+ s_in * sigmas[i],
125
+ s_in * sigmas[i + 1],
126
+ denoiser,
127
+ x,
128
+ cond,
129
+ uc,
130
+ gamma,
131
+ )
132
+
133
+ return x
134
+
135
+
136
+ class AncestralSampler(SingleStepDiffusionSampler):
137
+ def __init__(self, eta=1.0, s_noise=1.0, *args, **kwargs):
138
+ super().__init__(*args, **kwargs)
139
+
140
+ self.eta = eta
141
+ self.s_noise = s_noise
142
+ self.noise_sampler = lambda x: torch.randn_like(x)
143
+
144
+ def ancestral_euler_step(self, x, denoised, sigma, sigma_down):
145
+ d = to_d(x, sigma, denoised)
146
+ dt = append_dims(sigma_down - sigma, x.ndim)
147
+
148
+ return self.euler_step(x, d, dt)
149
+
150
+ def ancestral_step(self, x, sigma, next_sigma, sigma_up):
151
+ x = torch.where(
152
+ append_dims(next_sigma, x.ndim) > 0.0,
153
+ x + self.noise_sampler(x) * self.s_noise * append_dims(sigma_up, x.ndim),
154
+ x,
155
+ )
156
+ return x
157
+
158
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
159
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
160
+ x, cond, uc, num_steps
161
+ )
162
+
163
+ for i in self.get_sigma_gen(num_sigmas):
164
+ x = self.sampler_step(
165
+ s_in * sigmas[i],
166
+ s_in * sigmas[i + 1],
167
+ denoiser,
168
+ x,
169
+ cond,
170
+ uc,
171
+ )
172
+
173
+ return x
174
+
175
+
176
+ class LinearMultistepSampler(BaseDiffusionSampler):
177
+ def __init__(
178
+ self,
179
+ order=4,
180
+ *args,
181
+ **kwargs,
182
+ ):
183
+ super().__init__(*args, **kwargs)
184
+
185
+ self.order = order
186
+
187
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
188
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
189
+ x, cond, uc, num_steps
190
+ )
191
+
192
+ ds = []
193
+ sigmas_cpu = sigmas.detach().cpu().numpy()
194
+ for i in self.get_sigma_gen(num_sigmas):
195
+ sigma = s_in * sigmas[i]
196
+ denoised = denoiser(
197
+ *self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs
198
+ )
199
+ denoised = self.guider(denoised, sigma)
200
+ d = to_d(x, sigma, denoised)
201
+ ds.append(d)
202
+ if len(ds) > self.order:
203
+ ds.pop(0)
204
+ cur_order = min(i + 1, self.order)
205
+ coeffs = [
206
+ linear_multistep_coeff(cur_order, sigmas_cpu, i, j)
207
+ for j in range(cur_order)
208
+ ]
209
+ x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
210
+
211
+ return x
212
+
213
+
214
+ class EulerEDMSampler(EDMSampler):
215
+ def possible_correction_step(
216
+ self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
217
+ ):
218
+ return euler_step
219
+
220
+
221
+ class HeunEDMSampler(EDMSampler):
222
+ def possible_correction_step(
223
+ self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
224
+ ):
225
+ if torch.sum(next_sigma) < 1e-14:
226
+ # Save a network evaluation if all noise levels are 0
227
+ return euler_step
228
+ else:
229
+ denoised = self.denoise(euler_step, denoiser, next_sigma, cond, uc)
230
+ d_new = to_d(euler_step, next_sigma, denoised)
231
+ d_prime = (d + d_new) / 2.0
232
+
233
+ # apply correction if noise level is not 0
234
+ x = torch.where(
235
+ append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step
236
+ )
237
+ return x
238
+
239
+
240
+ class EulerAncestralSampler(AncestralSampler):
241
+ def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc):
242
+ sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta)
243
+ denoised = self.denoise(x, denoiser, sigma, cond, uc)
244
+ x = self.ancestral_euler_step(x, denoised, sigma, sigma_down)
245
+ x = self.ancestral_step(x, sigma, next_sigma, sigma_up)
246
+
247
+ return x
248
+
249
+
250
+ class DPMPP2SAncestralSampler(AncestralSampler):
251
+ def get_variables(self, sigma, sigma_down):
252
+ t, t_next = [to_neg_log_sigma(s) for s in (sigma, sigma_down)]
253
+ h = t_next - t
254
+ s = t + 0.5 * h
255
+ return h, s, t, t_next
256
+
257
+ def get_mult(self, h, s, t, t_next):
258
+ mult1 = to_sigma(s) / to_sigma(t)
259
+ mult2 = (-0.5 * h).expm1()
260
+ mult3 = to_sigma(t_next) / to_sigma(t)
261
+ mult4 = (-h).expm1()
262
+
263
+ return mult1, mult2, mult3, mult4
264
+
265
+ def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, **kwargs):
266
+ sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta)
267
+ denoised = self.denoise(x, denoiser, sigma, cond, uc)
268
+ x_euler = self.ancestral_euler_step(x, denoised, sigma, sigma_down)
269
+
270
+ if torch.sum(sigma_down) < 1e-14:
271
+ # Save a network evaluation if all noise levels are 0
272
+ x = x_euler
273
+ else:
274
+ h, s, t, t_next = self.get_variables(sigma, sigma_down)
275
+ mult = [
276
+ append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next)
277
+ ]
278
+
279
+ x2 = mult[0] * x - mult[1] * denoised
280
+ denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc)
281
+ x_dpmpp2s = mult[2] * x - mult[3] * denoised2
282
+
283
+ # apply correction if noise level is not 0
284
+ x = torch.where(append_dims(sigma_down, x.ndim) > 0.0, x_dpmpp2s, x_euler)
285
+
286
+ x = self.ancestral_step(x, sigma, next_sigma, sigma_up)
287
+ return x
288
+
289
+
290
+ class DPMPP2MSampler(BaseDiffusionSampler):
291
+ def get_variables(self, sigma, next_sigma, previous_sigma=None):
292
+ t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)]
293
+ h = t_next - t
294
+
295
+ if previous_sigma is not None:
296
+ h_last = t - to_neg_log_sigma(previous_sigma)
297
+ r = h_last / h
298
+ return h, r, t, t_next
299
+ else:
300
+ return h, None, t, t_next
301
+
302
+ def get_mult(self, h, r, t, t_next, previous_sigma):
303
+ mult1 = to_sigma(t_next) / to_sigma(t)
304
+ mult2 = (-h).expm1()
305
+
306
+ if previous_sigma is not None:
307
+ mult3 = 1 + 1 / (2 * r)
308
+ mult4 = 1 / (2 * r)
309
+ return mult1, mult2, mult3, mult4
310
+ else:
311
+ return mult1, mult2
312
+
313
+ def sampler_step(
314
+ self,
315
+ old_denoised,
316
+ previous_sigma,
317
+ sigma,
318
+ next_sigma,
319
+ denoiser,
320
+ x,
321
+ cond,
322
+ uc=None,
323
+ ):
324
+ denoised = self.denoise(x, denoiser, sigma, cond, uc)
325
+
326
+ h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma)
327
+ mult = [
328
+ append_dims(mult, x.ndim)
329
+ for mult in self.get_mult(h, r, t, t_next, previous_sigma)
330
+ ]
331
+
332
+ x_standard = mult[0] * x - mult[1] * denoised
333
+ if old_denoised is None or torch.sum(next_sigma) < 1e-14:
334
+ # Save a network evaluation if all noise levels are 0 or on the first step
335
+ return x_standard, denoised
336
+ else:
337
+ denoised_d = mult[2] * denoised - mult[3] * old_denoised
338
+ x_advanced = mult[0] * x - mult[1] * denoised_d
339
+
340
+ # apply correction if noise level is not 0 and not first step
341
+ x = torch.where(
342
+ append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard
343
+ )
344
+
345
+ return x, denoised
346
+
347
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
348
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
349
+ x, cond, uc, num_steps
350
+ )
351
+
352
+ old_denoised = None
353
+ for i in self.get_sigma_gen(num_sigmas):
354
+ x, old_denoised = self.sampler_step(
355
+ old_denoised,
356
+ None if i == 0 else s_in * sigmas[i - 1],
357
+ s_in * sigmas[i],
358
+ s_in * sigmas[i + 1],
359
+ denoiser,
360
+ x,
361
+ cond,
362
+ uc=uc,
363
+ )
364
+
365
+ return x
repositories/generative-models/sgm/modules/diffusionmodules/sampling_utils.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from scipy import integrate
3
+
4
+ from ...util import append_dims
5
+
6
+
7
+ class NoDynamicThresholding:
8
+ def __call__(self, uncond, cond, scale):
9
+ return uncond + scale * (cond - uncond)
10
+
11
+
12
+ def linear_multistep_coeff(order, t, i, j, epsrel=1e-4):
13
+ if order - 1 > i:
14
+ raise ValueError(f"Order {order} too high for step {i}")
15
+
16
+ def fn(tau):
17
+ prod = 1.0
18
+ for k in range(order):
19
+ if j == k:
20
+ continue
21
+ prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
22
+ return prod
23
+
24
+ return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0]
25
+
26
+
27
+ def get_ancestral_step(sigma_from, sigma_to, eta=1.0):
28
+ if not eta:
29
+ return sigma_to, 0.0
30
+ sigma_up = torch.minimum(
31
+ sigma_to,
32
+ eta
33
+ * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5,
34
+ )
35
+ sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
36
+ return sigma_down, sigma_up
37
+
38
+
39
+ def to_d(x, sigma, denoised):
40
+ return (x - denoised) / append_dims(sigma, x.ndim)
41
+
42
+
43
+ def to_neg_log_sigma(sigma):
44
+ return sigma.log().neg()
45
+
46
+
47
+ def to_sigma(neg_log_sigma):
48
+ return neg_log_sigma.neg().exp()
repositories/generative-models/sgm/modules/diffusionmodules/sigma_sampling.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from ...util import default, instantiate_from_config
4
+
5
+
6
+ class EDMSampling:
7
+ def __init__(self, p_mean=-1.2, p_std=1.2):
8
+ self.p_mean = p_mean
9
+ self.p_std = p_std
10
+
11
+ def __call__(self, n_samples, rand=None):
12
+ log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples,)))
13
+ return log_sigma.exp()
14
+
15
+
16
+ class DiscreteSampling:
17
+ def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True):
18
+ self.num_idx = num_idx
19
+ self.sigmas = instantiate_from_config(discretization_config)(
20
+ num_idx, do_append_zero=do_append_zero, flip=flip
21
+ )
22
+
23
+ def idx_to_sigma(self, idx):
24
+ return self.sigmas[idx]
25
+
26
+ def __call__(self, n_samples, rand=None):
27
+ idx = default(
28
+ rand,
29
+ torch.randint(0, self.num_idx, (n_samples,)),
30
+ )
31
+ return self.idx_to_sigma(idx)
repositories/generative-models/sgm/modules/diffusionmodules/util.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ adopted from
3
+ https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
4
+ and
5
+ https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
6
+ and
7
+ https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
8
+
9
+ thanks!
10
+ """
11
+
12
+ import math
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ from einops import repeat
17
+
18
+
19
+ def make_beta_schedule(
20
+ schedule,
21
+ n_timestep,
22
+ linear_start=1e-4,
23
+ linear_end=2e-2,
24
+ ):
25
+ if schedule == "linear":
26
+ betas = (
27
+ torch.linspace(
28
+ linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
29
+ )
30
+ ** 2
31
+ )
32
+ return betas.numpy()
33
+
34
+
35
+ def extract_into_tensor(a, t, x_shape):
36
+ b, *_ = t.shape
37
+ out = a.gather(-1, t)
38
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
39
+
40
+
41
+ def mixed_checkpoint(func, inputs: dict, params, flag):
42
+ """
43
+ Evaluate a function without caching intermediate activations, allowing for
44
+ reduced memory at the expense of extra compute in the backward pass. This differs from the original checkpoint function
45
+ borrowed from https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py in that
46
+ it also works with non-tensor inputs
47
+ :param func: the function to evaluate.
48
+ :param inputs: the argument dictionary to pass to `func`.
49
+ :param params: a sequence of parameters `func` depends on but does not
50
+ explicitly take as arguments.
51
+ :param flag: if False, disable gradient checkpointing.
52
+ """
53
+ if flag:
54
+ tensor_keys = [key for key in inputs if isinstance(inputs[key], torch.Tensor)]
55
+ tensor_inputs = [
56
+ inputs[key] for key in inputs if isinstance(inputs[key], torch.Tensor)
57
+ ]
58
+ non_tensor_keys = [
59
+ key for key in inputs if not isinstance(inputs[key], torch.Tensor)
60
+ ]
61
+ non_tensor_inputs = [
62
+ inputs[key] for key in inputs if not isinstance(inputs[key], torch.Tensor)
63
+ ]
64
+ args = tuple(tensor_inputs) + tuple(non_tensor_inputs) + tuple(params)
65
+ return MixedCheckpointFunction.apply(
66
+ func,
67
+ len(tensor_inputs),
68
+ len(non_tensor_inputs),
69
+ tensor_keys,
70
+ non_tensor_keys,
71
+ *args,
72
+ )
73
+ else:
74
+ return func(**inputs)
75
+
76
+
77
+ class MixedCheckpointFunction(torch.autograd.Function):
78
+ @staticmethod
79
+ def forward(
80
+ ctx,
81
+ run_function,
82
+ length_tensors,
83
+ length_non_tensors,
84
+ tensor_keys,
85
+ non_tensor_keys,
86
+ *args,
87
+ ):
88
+ ctx.end_tensors = length_tensors
89
+ ctx.end_non_tensors = length_tensors + length_non_tensors
90
+ ctx.gpu_autocast_kwargs = {
91
+ "enabled": torch.is_autocast_enabled(),
92
+ "dtype": torch.get_autocast_gpu_dtype(),
93
+ "cache_enabled": torch.is_autocast_cache_enabled(),
94
+ }
95
+ assert (
96
+ len(tensor_keys) == length_tensors
97
+ and len(non_tensor_keys) == length_non_tensors
98
+ )
99
+
100
+ ctx.input_tensors = {
101
+ key: val for (key, val) in zip(tensor_keys, list(args[: ctx.end_tensors]))
102
+ }
103
+ ctx.input_non_tensors = {
104
+ key: val
105
+ for (key, val) in zip(
106
+ non_tensor_keys, list(args[ctx.end_tensors : ctx.end_non_tensors])
107
+ )
108
+ }
109
+ ctx.run_function = run_function
110
+ ctx.input_params = list(args[ctx.end_non_tensors :])
111
+
112
+ with torch.no_grad():
113
+ output_tensors = ctx.run_function(
114
+ **ctx.input_tensors, **ctx.input_non_tensors
115
+ )
116
+ return output_tensors
117
+
118
+ @staticmethod
119
+ def backward(ctx, *output_grads):
120
+ # additional_args = {key: ctx.input_tensors[key] for key in ctx.input_tensors if not isinstance(ctx.input_tensors[key],torch.Tensor)}
121
+ ctx.input_tensors = {
122
+ key: ctx.input_tensors[key].detach().requires_grad_(True)
123
+ for key in ctx.input_tensors
124
+ }
125
+
126
+ with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
127
+ # Fixes a bug where the first op in run_function modifies the
128
+ # Tensor storage in place, which is not allowed for detach()'d
129
+ # Tensors.
130
+ shallow_copies = {
131
+ key: ctx.input_tensors[key].view_as(ctx.input_tensors[key])
132
+ for key in ctx.input_tensors
133
+ }
134
+ # shallow_copies.update(additional_args)
135
+ output_tensors = ctx.run_function(**shallow_copies, **ctx.input_non_tensors)
136
+ input_grads = torch.autograd.grad(
137
+ output_tensors,
138
+ list(ctx.input_tensors.values()) + ctx.input_params,
139
+ output_grads,
140
+ allow_unused=True,
141
+ )
142
+ del ctx.input_tensors
143
+ del ctx.input_params
144
+ del output_tensors
145
+ return (
146
+ (None, None, None, None, None)
147
+ + input_grads[: ctx.end_tensors]
148
+ + (None,) * (ctx.end_non_tensors - ctx.end_tensors)
149
+ + input_grads[ctx.end_tensors :]
150
+ )
151
+
152
+
153
+ def checkpoint(func, inputs, params, flag):
154
+ """
155
+ Evaluate a function without caching intermediate activations, allowing for
156
+ reduced memory at the expense of extra compute in the backward pass.
157
+ :param func: the function to evaluate.
158
+ :param inputs: the argument sequence to pass to `func`.
159
+ :param params: a sequence of parameters `func` depends on but does not
160
+ explicitly take as arguments.
161
+ :param flag: if False, disable gradient checkpointing.
162
+ """
163
+ if flag:
164
+ args = tuple(inputs) + tuple(params)
165
+ return CheckpointFunction.apply(func, len(inputs), *args)
166
+ else:
167
+ return func(*inputs)
168
+
169
+
170
+ class CheckpointFunction(torch.autograd.Function):
171
+ @staticmethod
172
+ def forward(ctx, run_function, length, *args):
173
+ ctx.run_function = run_function
174
+ ctx.input_tensors = list(args[:length])
175
+ ctx.input_params = list(args[length:])
176
+ ctx.gpu_autocast_kwargs = {
177
+ "enabled": torch.is_autocast_enabled(),
178
+ "dtype": torch.get_autocast_gpu_dtype(),
179
+ "cache_enabled": torch.is_autocast_cache_enabled(),
180
+ }
181
+ with torch.no_grad():
182
+ output_tensors = ctx.run_function(*ctx.input_tensors)
183
+ return output_tensors
184
+
185
+ @staticmethod
186
+ def backward(ctx, *output_grads):
187
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
188
+ with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
189
+ # Fixes a bug where the first op in run_function modifies the
190
+ # Tensor storage in place, which is not allowed for detach()'d
191
+ # Tensors.
192
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
193
+ output_tensors = ctx.run_function(*shallow_copies)
194
+ input_grads = torch.autograd.grad(
195
+ output_tensors,
196
+ ctx.input_tensors + ctx.input_params,
197
+ output_grads,
198
+ allow_unused=True,
199
+ )
200
+ del ctx.input_tensors
201
+ del ctx.input_params
202
+ del output_tensors
203
+ return (None, None) + input_grads
204
+
205
+
206
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
207
+ """
208
+ Create sinusoidal timestep embeddings.
209
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
210
+ These may be fractional.
211
+ :param dim: the dimension of the output.
212
+ :param max_period: controls the minimum frequency of the embeddings.
213
+ :return: an [N x dim] Tensor of positional embeddings.
214
+ """
215
+ if not repeat_only:
216
+ half = dim // 2
217
+ freqs = torch.exp(
218
+ -math.log(max_period)
219
+ * torch.arange(start=0, end=half, dtype=torch.float32)
220
+ / half
221
+ ).to(device=timesteps.device)
222
+ args = timesteps[:, None].float() * freqs[None]
223
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
224
+ if dim % 2:
225
+ embedding = torch.cat(
226
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
227
+ )
228
+ else:
229
+ embedding = repeat(timesteps, "b -> b d", d=dim)
230
+ return embedding
231
+
232
+
233
+ def zero_module(module):
234
+ """
235
+ Zero out the parameters of a module and return it.
236
+ """
237
+ for p in module.parameters():
238
+ p.detach().zero_()
239
+ return module
240
+
241
+
242
+ def scale_module(module, scale):
243
+ """
244
+ Scale the parameters of a module and return it.
245
+ """
246
+ for p in module.parameters():
247
+ p.detach().mul_(scale)
248
+ return module
249
+
250
+
251
+ def mean_flat(tensor):
252
+ """
253
+ Take the mean over all non-batch dimensions.
254
+ """
255
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
256
+
257
+
258
+ def normalization(channels):
259
+ """
260
+ Make a standard normalization layer.
261
+ :param channels: number of input channels.
262
+ :return: an nn.Module for normalization.
263
+ """
264
+ return GroupNorm32(32, channels)
265
+
266
+
267
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
268
+ class SiLU(nn.Module):
269
+ def forward(self, x):
270
+ return x * torch.sigmoid(x)
271
+
272
+
273
+ class GroupNorm32(nn.GroupNorm):
274
+ def forward(self, x):
275
+ return super().forward(x.float()).type(x.dtype)
276
+
277
+
278
+ def conv_nd(dims, *args, **kwargs):
279
+ """
280
+ Create a 1D, 2D, or 3D convolution module.
281
+ """
282
+ if dims == 1:
283
+ return nn.Conv1d(*args, **kwargs)
284
+ elif dims == 2:
285
+ return nn.Conv2d(*args, **kwargs)
286
+ elif dims == 3:
287
+ return nn.Conv3d(*args, **kwargs)
288
+ raise ValueError(f"unsupported dimensions: {dims}")
289
+
290
+
291
+ def linear(*args, **kwargs):
292
+ """
293
+ Create a linear module.
294
+ """
295
+ return nn.Linear(*args, **kwargs)
296
+
297
+
298
+ def avg_pool_nd(dims, *args, **kwargs):
299
+ """
300
+ Create a 1D, 2D, or 3D average pooling module.
301
+ """
302
+ if dims == 1:
303
+ return nn.AvgPool1d(*args, **kwargs)
304
+ elif dims == 2:
305
+ return nn.AvgPool2d(*args, **kwargs)
306
+ elif dims == 3:
307
+ return nn.AvgPool3d(*args, **kwargs)
308
+ raise ValueError(f"unsupported dimensions: {dims}")
repositories/generative-models/sgm/modules/diffusionmodules/wrappers.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from packaging import version
4
+
5
+ OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper"
6
+
7
+
8
+ class IdentityWrapper(nn.Module):
9
+ def __init__(self, diffusion_model, compile_model: bool = False):
10
+ super().__init__()
11
+ compile = (
12
+ torch.compile
13
+ if (version.parse(torch.__version__) >= version.parse("2.0.0"))
14
+ and compile_model
15
+ else lambda x: x
16
+ )
17
+ self.diffusion_model = compile(diffusion_model)
18
+
19
+ def forward(self, *args, **kwargs):
20
+ return self.diffusion_model(*args, **kwargs)
21
+
22
+
23
+ class OpenAIWrapper(IdentityWrapper):
24
+ def forward(
25
+ self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs
26
+ ) -> torch.Tensor:
27
+ x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)
28
+ return self.diffusion_model(
29
+ x,
30
+ timesteps=t,
31
+ context=c.get("crossattn", None),
32
+ y=c.get("vector", None),
33
+ **kwargs
34
+ )
repositories/generative-models/sgm/modules/distributions/__init__.py ADDED
File without changes
repositories/generative-models/sgm/modules/distributions/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (185 Bytes). View file
 
repositories/generative-models/sgm/modules/distributions/__pycache__/distributions.cpython-310.pyc ADDED
Binary file (3.78 kB). View file
 
repositories/generative-models/sgm/modules/distributions/distributions.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
+ class AbstractDistribution:
6
+ def sample(self):
7
+ raise NotImplementedError()
8
+
9
+ def mode(self):
10
+ raise NotImplementedError()
11
+
12
+
13
+ class DiracDistribution(AbstractDistribution):
14
+ def __init__(self, value):
15
+ self.value = value
16
+
17
+ def sample(self):
18
+ return self.value
19
+
20
+ def mode(self):
21
+ return self.value
22
+
23
+
24
+ class DiagonalGaussianDistribution(object):
25
+ def __init__(self, parameters, deterministic=False):
26
+ self.parameters = parameters
27
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29
+ self.deterministic = deterministic
30
+ self.std = torch.exp(0.5 * self.logvar)
31
+ self.var = torch.exp(self.logvar)
32
+ if self.deterministic:
33
+ self.var = self.std = torch.zeros_like(self.mean).to(
34
+ device=self.parameters.device
35
+ )
36
+
37
+ def sample(self):
38
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(
39
+ device=self.parameters.device
40
+ )
41
+ return x
42
+
43
+ def kl(self, other=None):
44
+ if self.deterministic:
45
+ return torch.Tensor([0.0])
46
+ else:
47
+ if other is None:
48
+ return 0.5 * torch.sum(
49
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
50
+ dim=[1, 2, 3],
51
+ )
52
+ else:
53
+ return 0.5 * torch.sum(
54
+ torch.pow(self.mean - other.mean, 2) / other.var
55
+ + self.var / other.var
56
+ - 1.0
57
+ - self.logvar
58
+ + other.logvar,
59
+ dim=[1, 2, 3],
60
+ )
61
+
62
+ def nll(self, sample, dims=[1, 2, 3]):
63
+ if self.deterministic:
64
+ return torch.Tensor([0.0])
65
+ logtwopi = np.log(2.0 * np.pi)
66
+ return 0.5 * torch.sum(
67
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
68
+ dim=dims,
69
+ )
70
+
71
+ def mode(self):
72
+ return self.mean
73
+
74
+
75
+ def normal_kl(mean1, logvar1, mean2, logvar2):
76
+ """
77
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
78
+ Compute the KL divergence between two gaussians.
79
+ Shapes are automatically broadcasted, so batches can be compared to
80
+ scalars, among other use cases.
81
+ """
82
+ tensor = None
83
+ for obj in (mean1, logvar1, mean2, logvar2):
84
+ if isinstance(obj, torch.Tensor):
85
+ tensor = obj
86
+ break
87
+ assert tensor is not None, "at least one argument must be a Tensor"
88
+
89
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
90
+ # Tensors, but it does not work for torch.exp().
91
+ logvar1, logvar2 = [
92
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
93
+ for x in (logvar1, logvar2)
94
+ ]
95
+
96
+ return 0.5 * (
97
+ -1.0
98
+ + logvar2
99
+ - logvar1
100
+ + torch.exp(logvar1 - logvar2)
101
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
102
+ )
repositories/generative-models/sgm/modules/ema.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class LitEma(nn.Module):
6
+ def __init__(self, model, decay=0.9999, use_num_upates=True):
7
+ super().__init__()
8
+ if decay < 0.0 or decay > 1.0:
9
+ raise ValueError("Decay must be between 0 and 1")
10
+
11
+ self.m_name2s_name = {}
12
+ self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
13
+ self.register_buffer(
14
+ "num_updates",
15
+ torch.tensor(0, dtype=torch.int)
16
+ if use_num_upates
17
+ else torch.tensor(-1, dtype=torch.int),
18
+ )
19
+
20
+ for name, p in model.named_parameters():
21
+ if p.requires_grad:
22
+ # remove as '.'-character is not allowed in buffers
23
+ s_name = name.replace(".", "")
24
+ self.m_name2s_name.update({name: s_name})
25
+ self.register_buffer(s_name, p.clone().detach().data)
26
+
27
+ self.collected_params = []
28
+
29
+ def reset_num_updates(self):
30
+ del self.num_updates
31
+ self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int))
32
+
33
+ def forward(self, model):
34
+ decay = self.decay
35
+
36
+ if self.num_updates >= 0:
37
+ self.num_updates += 1
38
+ decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
39
+
40
+ one_minus_decay = 1.0 - decay
41
+
42
+ with torch.no_grad():
43
+ m_param = dict(model.named_parameters())
44
+ shadow_params = dict(self.named_buffers())
45
+
46
+ for key in m_param:
47
+ if m_param[key].requires_grad:
48
+ sname = self.m_name2s_name[key]
49
+ shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
50
+ shadow_params[sname].sub_(
51
+ one_minus_decay * (shadow_params[sname] - m_param[key])
52
+ )
53
+ else:
54
+ assert not key in self.m_name2s_name
55
+
56
+ def copy_to(self, model):
57
+ m_param = dict(model.named_parameters())
58
+ shadow_params = dict(self.named_buffers())
59
+ for key in m_param:
60
+ if m_param[key].requires_grad:
61
+ m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
62
+ else:
63
+ assert not key in self.m_name2s_name
64
+
65
+ def store(self, parameters):
66
+ """
67
+ Save the current parameters for restoring later.
68
+ Args:
69
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
70
+ temporarily stored.
71
+ """
72
+ self.collected_params = [param.clone() for param in parameters]
73
+
74
+ def restore(self, parameters):
75
+ """
76
+ Restore the parameters stored with the `store` method.
77
+ Useful to validate the model with EMA parameters without affecting the
78
+ original optimization process. Store the parameters before the
79
+ `copy_to` method. After validation (or model saving), use this to
80
+ restore the former parameters.
81
+ Args:
82
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
83
+ updated with the stored parameters.
84
+ """
85
+ for c_param, param in zip(self.collected_params, parameters):
86
+ param.data.copy_(c_param.data)
repositories/generative-models/sgm/modules/encoders/__init__.py ADDED
File without changes
repositories/generative-models/sgm/modules/encoders/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (180 Bytes). View file
 
repositories/generative-models/sgm/modules/encoders/__pycache__/modules.cpython-310.pyc ADDED
Binary file (26.7 kB). View file
 
repositories/generative-models/sgm/modules/encoders/modules.py ADDED
@@ -0,0 +1,960 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from contextlib import nullcontext
2
+ from functools import partial
3
+ from typing import Dict, List, Optional, Tuple, Union
4
+
5
+ import kornia
6
+ import numpy as np
7
+ import open_clip
8
+ import torch
9
+ import torch.nn as nn
10
+ from einops import rearrange, repeat
11
+ from omegaconf import ListConfig
12
+ from torch.utils.checkpoint import checkpoint
13
+ from transformers import (
14
+ ByT5Tokenizer,
15
+ CLIPTextModel,
16
+ CLIPTokenizer,
17
+ T5EncoderModel,
18
+ T5Tokenizer,
19
+ )
20
+
21
+ from ...modules.autoencoding.regularizers import DiagonalGaussianRegularizer
22
+ from ...modules.diffusionmodules.model import Encoder
23
+ from ...modules.diffusionmodules.openaimodel import Timestep
24
+ from ...modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
25
+ from ...modules.distributions.distributions import DiagonalGaussianDistribution
26
+ from ...util import (
27
+ autocast,
28
+ count_params,
29
+ default,
30
+ disabled_train,
31
+ expand_dims_like,
32
+ instantiate_from_config,
33
+ )
34
+
35
+
36
+ class AbstractEmbModel(nn.Module):
37
+ def __init__(self):
38
+ super().__init__()
39
+ self._is_trainable = None
40
+ self._ucg_rate = None
41
+ self._input_key = None
42
+
43
+ @property
44
+ def is_trainable(self) -> bool:
45
+ return self._is_trainable
46
+
47
+ @property
48
+ def ucg_rate(self) -> Union[float, torch.Tensor]:
49
+ return self._ucg_rate
50
+
51
+ @property
52
+ def input_key(self) -> str:
53
+ return self._input_key
54
+
55
+ @is_trainable.setter
56
+ def is_trainable(self, value: bool):
57
+ self._is_trainable = value
58
+
59
+ @ucg_rate.setter
60
+ def ucg_rate(self, value: Union[float, torch.Tensor]):
61
+ self._ucg_rate = value
62
+
63
+ @input_key.setter
64
+ def input_key(self, value: str):
65
+ self._input_key = value
66
+
67
+ @is_trainable.deleter
68
+ def is_trainable(self):
69
+ del self._is_trainable
70
+
71
+ @ucg_rate.deleter
72
+ def ucg_rate(self):
73
+ del self._ucg_rate
74
+
75
+ @input_key.deleter
76
+ def input_key(self):
77
+ del self._input_key
78
+
79
+
80
+ class GeneralConditioner(nn.Module):
81
+ OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"}
82
+ KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1}
83
+
84
+ def __init__(self, emb_models: Union[List, ListConfig]):
85
+ super().__init__()
86
+ embedders = []
87
+ for n, embconfig in enumerate(emb_models):
88
+ embedder = instantiate_from_config(embconfig)
89
+ assert isinstance(
90
+ embedder, AbstractEmbModel
91
+ ), f"embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel"
92
+ embedder.is_trainable = embconfig.get("is_trainable", False)
93
+ embedder.ucg_rate = embconfig.get("ucg_rate", 0.0)
94
+ if not embedder.is_trainable:
95
+ embedder.train = disabled_train
96
+ for param in embedder.parameters():
97
+ param.requires_grad = False
98
+ embedder.eval()
99
+ print(
100
+ f"Initialized embedder #{n}: {embedder.__class__.__name__} "
101
+ f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}"
102
+ )
103
+
104
+ if "input_key" in embconfig:
105
+ embedder.input_key = embconfig["input_key"]
106
+ elif "input_keys" in embconfig:
107
+ embedder.input_keys = embconfig["input_keys"]
108
+ else:
109
+ raise KeyError(
110
+ f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}"
111
+ )
112
+
113
+ embedder.legacy_ucg_val = embconfig.get("legacy_ucg_value", None)
114
+ if embedder.legacy_ucg_val is not None:
115
+ embedder.ucg_prng = np.random.RandomState()
116
+
117
+ embedders.append(embedder)
118
+ self.embedders = nn.ModuleList(embedders)
119
+
120
+ def possibly_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict) -> Dict:
121
+ assert embedder.legacy_ucg_val is not None
122
+ p = embedder.ucg_rate
123
+ val = embedder.legacy_ucg_val
124
+ for i in range(len(batch[embedder.input_key])):
125
+ if embedder.ucg_prng.choice(2, p=[1 - p, p]):
126
+ batch[embedder.input_key][i] = val
127
+ return batch
128
+
129
+ def forward(
130
+ self, batch: Dict, force_zero_embeddings: Optional[List] = None
131
+ ) -> Dict:
132
+ output = dict()
133
+ if force_zero_embeddings is None:
134
+ force_zero_embeddings = []
135
+ for embedder in self.embedders:
136
+ embedding_context = nullcontext if embedder.is_trainable else torch.no_grad
137
+ with embedding_context():
138
+ if hasattr(embedder, "input_key") and (embedder.input_key is not None):
139
+ if embedder.legacy_ucg_val is not None:
140
+ batch = self.possibly_get_ucg_val(embedder, batch)
141
+ emb_out = embedder(batch[embedder.input_key])
142
+ elif hasattr(embedder, "input_keys"):
143
+ emb_out = embedder(*[batch[k] for k in embedder.input_keys])
144
+ assert isinstance(
145
+ emb_out, (torch.Tensor, list, tuple)
146
+ ), f"encoder outputs must be tensors or a sequence, but got {type(emb_out)}"
147
+ if not isinstance(emb_out, (list, tuple)):
148
+ emb_out = [emb_out]
149
+ for emb in emb_out:
150
+ out_key = self.OUTPUT_DIM2KEYS[emb.dim()]
151
+ if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
152
+ emb = (
153
+ expand_dims_like(
154
+ torch.bernoulli(
155
+ (1.0 - embedder.ucg_rate)
156
+ * torch.ones(emb.shape[0], device=emb.device)
157
+ ),
158
+ emb,
159
+ )
160
+ * emb
161
+ )
162
+ if (
163
+ hasattr(embedder, "input_key")
164
+ and embedder.input_key in force_zero_embeddings
165
+ ):
166
+ emb = torch.zeros_like(emb)
167
+ if out_key in output:
168
+ output[out_key] = torch.cat(
169
+ (output[out_key], emb), self.KEY2CATDIM[out_key]
170
+ )
171
+ else:
172
+ output[out_key] = emb
173
+ return output
174
+
175
+ def get_unconditional_conditioning(
176
+ self, batch_c, batch_uc=None, force_uc_zero_embeddings=None
177
+ ):
178
+ if force_uc_zero_embeddings is None:
179
+ force_uc_zero_embeddings = []
180
+ ucg_rates = list()
181
+ for embedder in self.embedders:
182
+ ucg_rates.append(embedder.ucg_rate)
183
+ embedder.ucg_rate = 0.0
184
+ c = self(batch_c)
185
+ uc = self(batch_c if batch_uc is None else batch_uc, force_uc_zero_embeddings)
186
+
187
+ for embedder, rate in zip(self.embedders, ucg_rates):
188
+ embedder.ucg_rate = rate
189
+ return c, uc
190
+
191
+
192
+ class InceptionV3(nn.Module):
193
+ """Wrapper around the https://github.com/mseitzer/pytorch-fid inception
194
+ port with an additional squeeze at the end"""
195
+
196
+ def __init__(self, normalize_input=False, **kwargs):
197
+ super().__init__()
198
+ from pytorch_fid import inception
199
+
200
+ kwargs["resize_input"] = True
201
+ self.model = inception.InceptionV3(normalize_input=normalize_input, **kwargs)
202
+
203
+ def forward(self, inp):
204
+ # inp = kornia.geometry.resize(inp, (299, 299),
205
+ # interpolation='bicubic',
206
+ # align_corners=False,
207
+ # antialias=True)
208
+ # inp = inp.clamp(min=-1, max=1)
209
+
210
+ outp = self.model(inp)
211
+
212
+ if len(outp) == 1:
213
+ return outp[0].squeeze()
214
+
215
+ return outp
216
+
217
+
218
+ class IdentityEncoder(AbstractEmbModel):
219
+ def encode(self, x):
220
+ return x
221
+
222
+ def forward(self, x):
223
+ return x
224
+
225
+
226
+ class ClassEmbedder(AbstractEmbModel):
227
+ def __init__(self, embed_dim, n_classes=1000, add_sequence_dim=False):
228
+ super().__init__()
229
+ self.embedding = nn.Embedding(n_classes, embed_dim)
230
+ self.n_classes = n_classes
231
+ self.add_sequence_dim = add_sequence_dim
232
+
233
+ def forward(self, c):
234
+ c = self.embedding(c)
235
+ if self.add_sequence_dim:
236
+ c = c[:, None, :]
237
+ return c
238
+
239
+ def get_unconditional_conditioning(self, bs, device="cuda"):
240
+ uc_class = (
241
+ self.n_classes - 1
242
+ ) # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
243
+ uc = torch.ones((bs,), device=device) * uc_class
244
+ uc = {self.key: uc.long()}
245
+ return uc
246
+
247
+
248
+ class ClassEmbedderForMultiCond(ClassEmbedder):
249
+ def forward(self, batch, key=None, disable_dropout=False):
250
+ out = batch
251
+ key = default(key, self.key)
252
+ islist = isinstance(batch[key], list)
253
+ if islist:
254
+ batch[key] = batch[key][0]
255
+ c_out = super().forward(batch, key, disable_dropout)
256
+ out[key] = [c_out] if islist else c_out
257
+ return out
258
+
259
+
260
+ class FrozenT5Embedder(AbstractEmbModel):
261
+ """Uses the T5 transformer encoder for text"""
262
+
263
+ def __init__(
264
+ self, version="google/t5-v1_1-xxl", device="cuda", max_length=77, freeze=True
265
+ ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
266
+ super().__init__()
267
+ self.tokenizer = T5Tokenizer.from_pretrained(version)
268
+ self.transformer = T5EncoderModel.from_pretrained(version)
269
+ self.device = device
270
+ self.max_length = max_length
271
+ if freeze:
272
+ self.freeze()
273
+
274
+ def freeze(self):
275
+ self.transformer = self.transformer.eval()
276
+
277
+ for param in self.parameters():
278
+ param.requires_grad = False
279
+
280
+ # @autocast
281
+ def forward(self, text):
282
+ batch_encoding = self.tokenizer(
283
+ text,
284
+ truncation=True,
285
+ max_length=self.max_length,
286
+ return_length=True,
287
+ return_overflowing_tokens=False,
288
+ padding="max_length",
289
+ return_tensors="pt",
290
+ )
291
+ tokens = batch_encoding["input_ids"].to(self.device)
292
+ with torch.autocast("cuda", enabled=False):
293
+ outputs = self.transformer(input_ids=tokens)
294
+ z = outputs.last_hidden_state
295
+ return z
296
+
297
+ def encode(self, text):
298
+ return self(text)
299
+
300
+
301
+ class FrozenByT5Embedder(AbstractEmbModel):
302
+ """
303
+ Uses the ByT5 transformer encoder for text. Is character-aware.
304
+ """
305
+
306
+ def __init__(
307
+ self, version="google/byt5-base", device="cuda", max_length=77, freeze=True
308
+ ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
309
+ super().__init__()
310
+ self.tokenizer = ByT5Tokenizer.from_pretrained(version)
311
+ self.transformer = T5EncoderModel.from_pretrained(version)
312
+ self.device = device
313
+ self.max_length = max_length
314
+ if freeze:
315
+ self.freeze()
316
+
317
+ def freeze(self):
318
+ self.transformer = self.transformer.eval()
319
+
320
+ for param in self.parameters():
321
+ param.requires_grad = False
322
+
323
+ def forward(self, text):
324
+ batch_encoding = self.tokenizer(
325
+ text,
326
+ truncation=True,
327
+ max_length=self.max_length,
328
+ return_length=True,
329
+ return_overflowing_tokens=False,
330
+ padding="max_length",
331
+ return_tensors="pt",
332
+ )
333
+ tokens = batch_encoding["input_ids"].to(self.device)
334
+ with torch.autocast("cuda", enabled=False):
335
+ outputs = self.transformer(input_ids=tokens)
336
+ z = outputs.last_hidden_state
337
+ return z
338
+
339
+ def encode(self, text):
340
+ return self(text)
341
+
342
+
343
+ class FrozenCLIPEmbedder(AbstractEmbModel):
344
+ """Uses the CLIP transformer encoder for text (from huggingface)"""
345
+
346
+ LAYERS = ["last", "pooled", "hidden"]
347
+
348
+ def __init__(
349
+ self,
350
+ version="openai/clip-vit-large-patch14",
351
+ device="cuda",
352
+ max_length=77,
353
+ freeze=True,
354
+ layer="last",
355
+ layer_idx=None,
356
+ always_return_pooled=False,
357
+ ): # clip-vit-base-patch32
358
+ super().__init__()
359
+ assert layer in self.LAYERS
360
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
361
+ self.transformer = CLIPTextModel.from_pretrained(version)
362
+ self.device = device
363
+ self.max_length = max_length
364
+ if freeze:
365
+ self.freeze()
366
+ self.layer = layer
367
+ self.layer_idx = layer_idx
368
+ self.return_pooled = always_return_pooled
369
+ if layer == "hidden":
370
+ assert layer_idx is not None
371
+ assert 0 <= abs(layer_idx) <= 12
372
+
373
+ def freeze(self):
374
+ self.transformer = self.transformer.eval()
375
+
376
+ for param in self.parameters():
377
+ param.requires_grad = False
378
+
379
+ @autocast
380
+ def forward(self, text):
381
+ batch_encoding = self.tokenizer(
382
+ text,
383
+ truncation=True,
384
+ max_length=self.max_length,
385
+ return_length=True,
386
+ return_overflowing_tokens=False,
387
+ padding="max_length",
388
+ return_tensors="pt",
389
+ )
390
+ tokens = batch_encoding["input_ids"].to(self.device)
391
+ outputs = self.transformer(
392
+ input_ids=tokens, output_hidden_states=self.layer == "hidden"
393
+ )
394
+ if self.layer == "last":
395
+ z = outputs.last_hidden_state
396
+ elif self.layer == "pooled":
397
+ z = outputs.pooler_output[:, None, :]
398
+ else:
399
+ z = outputs.hidden_states[self.layer_idx]
400
+ if self.return_pooled:
401
+ return z, outputs.pooler_output
402
+ return z
403
+
404
+ def encode(self, text):
405
+ return self(text)
406
+
407
+
408
+ class FrozenOpenCLIPEmbedder2(AbstractEmbModel):
409
+ """
410
+ Uses the OpenCLIP transformer encoder for text
411
+ """
412
+
413
+ LAYERS = ["pooled", "last", "penultimate"]
414
+
415
+ def __init__(
416
+ self,
417
+ arch="ViT-H-14",
418
+ version="laion2b_s32b_b79k",
419
+ device="cuda",
420
+ max_length=77,
421
+ freeze=True,
422
+ layer="last",
423
+ always_return_pooled=False,
424
+ legacy=True,
425
+ ):
426
+ super().__init__()
427
+ assert layer in self.LAYERS
428
+ model, _, _ = open_clip.create_model_and_transforms(
429
+ arch,
430
+ device=torch.device("cpu"),
431
+ pretrained=version,
432
+ )
433
+ del model.visual
434
+ self.model = model
435
+
436
+ self.device = device
437
+ self.max_length = max_length
438
+ self.return_pooled = always_return_pooled
439
+ if freeze:
440
+ self.freeze()
441
+ self.layer = layer
442
+ if self.layer == "last":
443
+ self.layer_idx = 0
444
+ elif self.layer == "penultimate":
445
+ self.layer_idx = 1
446
+ else:
447
+ raise NotImplementedError()
448
+ self.legacy = legacy
449
+
450
+ def freeze(self):
451
+ self.model = self.model.eval()
452
+ for param in self.parameters():
453
+ param.requires_grad = False
454
+
455
+ @autocast
456
+ def forward(self, text):
457
+ tokens = open_clip.tokenize(text)
458
+ z = self.encode_with_transformer(tokens.to(self.device))
459
+ if not self.return_pooled and self.legacy:
460
+ return z
461
+ if self.return_pooled:
462
+ assert not self.legacy
463
+ return z[self.layer], z["pooled"]
464
+ return z[self.layer]
465
+
466
+ def encode_with_transformer(self, text):
467
+ x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
468
+ x = x + self.model.positional_embedding
469
+ x = x.permute(1, 0, 2) # NLD -> LND
470
+ x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
471
+ if self.legacy:
472
+ x = x[self.layer]
473
+ x = self.model.ln_final(x)
474
+ return x
475
+ else:
476
+ # x is a dict and will stay a dict
477
+ o = x["last"]
478
+ o = self.model.ln_final(o)
479
+ pooled = self.pool(o, text)
480
+ x["pooled"] = pooled
481
+ return x
482
+
483
+ def pool(self, x, text):
484
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
485
+ x = (
486
+ x[torch.arange(x.shape[0]), text.argmax(dim=-1)]
487
+ @ self.model.text_projection
488
+ )
489
+ return x
490
+
491
+ def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
492
+ outputs = {}
493
+ for i, r in enumerate(self.model.transformer.resblocks):
494
+ if i == len(self.model.transformer.resblocks) - 1:
495
+ outputs["penultimate"] = x.permute(1, 0, 2) # LND -> NLD
496
+ if (
497
+ self.model.transformer.grad_checkpointing
498
+ and not torch.jit.is_scripting()
499
+ ):
500
+ x = checkpoint(r, x, attn_mask)
501
+ else:
502
+ x = r(x, attn_mask=attn_mask)
503
+ outputs["last"] = x.permute(1, 0, 2) # LND -> NLD
504
+ return outputs
505
+
506
+ def encode(self, text):
507
+ return self(text)
508
+
509
+
510
+ class FrozenOpenCLIPEmbedder(AbstractEmbModel):
511
+ LAYERS = [
512
+ # "pooled",
513
+ "last",
514
+ "penultimate",
515
+ ]
516
+
517
+ def __init__(
518
+ self,
519
+ arch="ViT-H-14",
520
+ version="laion2b_s32b_b79k",
521
+ device="cuda",
522
+ max_length=77,
523
+ freeze=True,
524
+ layer="last",
525
+ ):
526
+ super().__init__()
527
+ assert layer in self.LAYERS
528
+ model, _, _ = open_clip.create_model_and_transforms(
529
+ arch, device=torch.device("cpu"), pretrained=version
530
+ )
531
+ del model.visual
532
+ self.model = model
533
+
534
+ self.device = device
535
+ self.max_length = max_length
536
+ if freeze:
537
+ self.freeze()
538
+ self.layer = layer
539
+ if self.layer == "last":
540
+ self.layer_idx = 0
541
+ elif self.layer == "penultimate":
542
+ self.layer_idx = 1
543
+ else:
544
+ raise NotImplementedError()
545
+
546
+ def freeze(self):
547
+ self.model = self.model.eval()
548
+ for param in self.parameters():
549
+ param.requires_grad = False
550
+
551
+ def forward(self, text):
552
+ tokens = open_clip.tokenize(text)
553
+ z = self.encode_with_transformer(tokens.to(self.device))
554
+ return z
555
+
556
+ def encode_with_transformer(self, text):
557
+ x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
558
+ x = x + self.model.positional_embedding
559
+ x = x.permute(1, 0, 2) # NLD -> LND
560
+ x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
561
+ x = x.permute(1, 0, 2) # LND -> NLD
562
+ x = self.model.ln_final(x)
563
+ return x
564
+
565
+ def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
566
+ for i, r in enumerate(self.model.transformer.resblocks):
567
+ if i == len(self.model.transformer.resblocks) - self.layer_idx:
568
+ break
569
+ if (
570
+ self.model.transformer.grad_checkpointing
571
+ and not torch.jit.is_scripting()
572
+ ):
573
+ x = checkpoint(r, x, attn_mask)
574
+ else:
575
+ x = r(x, attn_mask=attn_mask)
576
+ return x
577
+
578
+ def encode(self, text):
579
+ return self(text)
580
+
581
+
582
+ class FrozenOpenCLIPImageEmbedder(AbstractEmbModel):
583
+ """
584
+ Uses the OpenCLIP vision transformer encoder for images
585
+ """
586
+
587
+ def __init__(
588
+ self,
589
+ arch="ViT-H-14",
590
+ version="laion2b_s32b_b79k",
591
+ device="cuda",
592
+ max_length=77,
593
+ freeze=True,
594
+ antialias=True,
595
+ ucg_rate=0.0,
596
+ unsqueeze_dim=False,
597
+ repeat_to_max_len=False,
598
+ num_image_crops=0,
599
+ output_tokens=False,
600
+ ):
601
+ super().__init__()
602
+ model, _, _ = open_clip.create_model_and_transforms(
603
+ arch,
604
+ device=torch.device("cpu"),
605
+ pretrained=version,
606
+ )
607
+ del model.transformer
608
+ self.model = model
609
+ self.max_crops = num_image_crops
610
+ self.pad_to_max_len = self.max_crops > 0
611
+ self.repeat_to_max_len = repeat_to_max_len and (not self.pad_to_max_len)
612
+ self.device = device
613
+ self.max_length = max_length
614
+ if freeze:
615
+ self.freeze()
616
+
617
+ self.antialias = antialias
618
+
619
+ self.register_buffer(
620
+ "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
621
+ )
622
+ self.register_buffer(
623
+ "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
624
+ )
625
+ self.ucg_rate = ucg_rate
626
+ self.unsqueeze_dim = unsqueeze_dim
627
+ self.stored_batch = None
628
+ self.model.visual.output_tokens = output_tokens
629
+ self.output_tokens = output_tokens
630
+
631
+ def preprocess(self, x):
632
+ # normalize to [0,1]
633
+ x = kornia.geometry.resize(
634
+ x,
635
+ (224, 224),
636
+ interpolation="bicubic",
637
+ align_corners=True,
638
+ antialias=self.antialias,
639
+ )
640
+ x = (x + 1.0) / 2.0
641
+ # renormalize according to clip
642
+ x = kornia.enhance.normalize(x, self.mean, self.std)
643
+ return x
644
+
645
+ def freeze(self):
646
+ self.model = self.model.eval()
647
+ for param in self.parameters():
648
+ param.requires_grad = False
649
+
650
+ @autocast
651
+ def forward(self, image, no_dropout=False):
652
+ z = self.encode_with_vision_transformer(image)
653
+ tokens = None
654
+ if self.output_tokens:
655
+ z, tokens = z[0], z[1]
656
+ z = z.to(image.dtype)
657
+ if self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0):
658
+ z = (
659
+ torch.bernoulli(
660
+ (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device)
661
+ )[:, None]
662
+ * z
663
+ )
664
+ if tokens is not None:
665
+ tokens = (
666
+ expand_dims_like(
667
+ torch.bernoulli(
668
+ (1.0 - self.ucg_rate)
669
+ * torch.ones(tokens.shape[0], device=tokens.device)
670
+ ),
671
+ tokens,
672
+ )
673
+ * tokens
674
+ )
675
+ if self.unsqueeze_dim:
676
+ z = z[:, None, :]
677
+ if self.output_tokens:
678
+ assert not self.repeat_to_max_len
679
+ assert not self.pad_to_max_len
680
+ return tokens, z
681
+ if self.repeat_to_max_len:
682
+ if z.dim() == 2:
683
+ z_ = z[:, None, :]
684
+ else:
685
+ z_ = z
686
+ return repeat(z_, "b 1 d -> b n d", n=self.max_length), z
687
+ elif self.pad_to_max_len:
688
+ assert z.dim() == 3
689
+ z_pad = torch.cat(
690
+ (
691
+ z,
692
+ torch.zeros(
693
+ z.shape[0],
694
+ self.max_length - z.shape[1],
695
+ z.shape[2],
696
+ device=z.device,
697
+ ),
698
+ ),
699
+ 1,
700
+ )
701
+ return z_pad, z_pad[:, 0, ...]
702
+ return z
703
+
704
+ def encode_with_vision_transformer(self, img):
705
+ # if self.max_crops > 0:
706
+ # img = self.preprocess_by_cropping(img)
707
+ if img.dim() == 5:
708
+ assert self.max_crops == img.shape[1]
709
+ img = rearrange(img, "b n c h w -> (b n) c h w")
710
+ img = self.preprocess(img)
711
+ if not self.output_tokens:
712
+ assert not self.model.visual.output_tokens
713
+ x = self.model.visual(img)
714
+ tokens = None
715
+ else:
716
+ assert self.model.visual.output_tokens
717
+ x, tokens = self.model.visual(img)
718
+ if self.max_crops > 0:
719
+ x = rearrange(x, "(b n) d -> b n d", n=self.max_crops)
720
+ # drop out between 0 and all along the sequence axis
721
+ x = (
722
+ torch.bernoulli(
723
+ (1.0 - self.ucg_rate)
724
+ * torch.ones(x.shape[0], x.shape[1], 1, device=x.device)
725
+ )
726
+ * x
727
+ )
728
+ if tokens is not None:
729
+ tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops)
730
+ print(
731
+ f"You are running very experimental token-concat in {self.__class__.__name__}. "
732
+ f"Check what you are doing, and then remove this message."
733
+ )
734
+ if self.output_tokens:
735
+ return x, tokens
736
+ return x
737
+
738
+ def encode(self, text):
739
+ return self(text)
740
+
741
+
742
+ class FrozenCLIPT5Encoder(AbstractEmbModel):
743
+ def __init__(
744
+ self,
745
+ clip_version="openai/clip-vit-large-patch14",
746
+ t5_version="google/t5-v1_1-xl",
747
+ device="cuda",
748
+ clip_max_length=77,
749
+ t5_max_length=77,
750
+ ):
751
+ super().__init__()
752
+ self.clip_encoder = FrozenCLIPEmbedder(
753
+ clip_version, device, max_length=clip_max_length
754
+ )
755
+ self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
756
+ print(
757
+ f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
758
+ f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params."
759
+ )
760
+
761
+ def encode(self, text):
762
+ return self(text)
763
+
764
+ def forward(self, text):
765
+ clip_z = self.clip_encoder.encode(text)
766
+ t5_z = self.t5_encoder.encode(text)
767
+ return [clip_z, t5_z]
768
+
769
+
770
+ class SpatialRescaler(nn.Module):
771
+ def __init__(
772
+ self,
773
+ n_stages=1,
774
+ method="bilinear",
775
+ multiplier=0.5,
776
+ in_channels=3,
777
+ out_channels=None,
778
+ bias=False,
779
+ wrap_video=False,
780
+ kernel_size=1,
781
+ remap_output=False,
782
+ ):
783
+ super().__init__()
784
+ self.n_stages = n_stages
785
+ assert self.n_stages >= 0
786
+ assert method in [
787
+ "nearest",
788
+ "linear",
789
+ "bilinear",
790
+ "trilinear",
791
+ "bicubic",
792
+ "area",
793
+ ]
794
+ self.multiplier = multiplier
795
+ self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
796
+ self.remap_output = out_channels is not None or remap_output
797
+ if self.remap_output:
798
+ print(
799
+ f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing."
800
+ )
801
+ self.channel_mapper = nn.Conv2d(
802
+ in_channels,
803
+ out_channels,
804
+ kernel_size=kernel_size,
805
+ bias=bias,
806
+ padding=kernel_size // 2,
807
+ )
808
+ self.wrap_video = wrap_video
809
+
810
+ def forward(self, x):
811
+ if self.wrap_video and x.ndim == 5:
812
+ B, C, T, H, W = x.shape
813
+ x = rearrange(x, "b c t h w -> b t c h w")
814
+ x = rearrange(x, "b t c h w -> (b t) c h w")
815
+
816
+ for stage in range(self.n_stages):
817
+ x = self.interpolator(x, scale_factor=self.multiplier)
818
+
819
+ if self.wrap_video:
820
+ x = rearrange(x, "(b t) c h w -> b t c h w", b=B, t=T, c=C)
821
+ x = rearrange(x, "b t c h w -> b c t h w")
822
+ if self.remap_output:
823
+ x = self.channel_mapper(x)
824
+ return x
825
+
826
+ def encode(self, x):
827
+ return self(x)
828
+
829
+
830
+ class LowScaleEncoder(nn.Module):
831
+ def __init__(
832
+ self,
833
+ model_config,
834
+ linear_start,
835
+ linear_end,
836
+ timesteps=1000,
837
+ max_noise_level=250,
838
+ output_size=64,
839
+ scale_factor=1.0,
840
+ ):
841
+ super().__init__()
842
+ self.max_noise_level = max_noise_level
843
+ self.model = instantiate_from_config(model_config)
844
+ self.augmentation_schedule = self.register_schedule(
845
+ timesteps=timesteps, linear_start=linear_start, linear_end=linear_end
846
+ )
847
+ self.out_size = output_size
848
+ self.scale_factor = scale_factor
849
+
850
+ def register_schedule(
851
+ self,
852
+ beta_schedule="linear",
853
+ timesteps=1000,
854
+ linear_start=1e-4,
855
+ linear_end=2e-2,
856
+ cosine_s=8e-3,
857
+ ):
858
+ betas = make_beta_schedule(
859
+ beta_schedule,
860
+ timesteps,
861
+ linear_start=linear_start,
862
+ linear_end=linear_end,
863
+ cosine_s=cosine_s,
864
+ )
865
+ alphas = 1.0 - betas
866
+ alphas_cumprod = np.cumprod(alphas, axis=0)
867
+ alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
868
+
869
+ (timesteps,) = betas.shape
870
+ self.num_timesteps = int(timesteps)
871
+ self.linear_start = linear_start
872
+ self.linear_end = linear_end
873
+ assert (
874
+ alphas_cumprod.shape[0] == self.num_timesteps
875
+ ), "alphas have to be defined for each timestep"
876
+
877
+ to_torch = partial(torch.tensor, dtype=torch.float32)
878
+
879
+ self.register_buffer("betas", to_torch(betas))
880
+ self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
881
+ self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
882
+
883
+ # calculations for diffusion q(x_t | x_{t-1}) and others
884
+ self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
885
+ self.register_buffer(
886
+ "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
887
+ )
888
+ self.register_buffer(
889
+ "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
890
+ )
891
+ self.register_buffer(
892
+ "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
893
+ )
894
+ self.register_buffer(
895
+ "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
896
+ )
897
+
898
+ def q_sample(self, x_start, t, noise=None):
899
+ noise = default(noise, lambda: torch.randn_like(x_start))
900
+ return (
901
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
902
+ + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
903
+ * noise
904
+ )
905
+
906
+ def forward(self, x):
907
+ z = self.model.encode(x)
908
+ if isinstance(z, DiagonalGaussianDistribution):
909
+ z = z.sample()
910
+ z = z * self.scale_factor
911
+ noise_level = torch.randint(
912
+ 0, self.max_noise_level, (x.shape[0],), device=x.device
913
+ ).long()
914
+ z = self.q_sample(z, noise_level)
915
+ if self.out_size is not None:
916
+ z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest")
917
+ # z = z.repeat_interleave(2, -2).repeat_interleave(2, -1)
918
+ return z, noise_level
919
+
920
+ def decode(self, z):
921
+ z = z / self.scale_factor
922
+ return self.model.decode(z)
923
+
924
+
925
+ class ConcatTimestepEmbedderND(AbstractEmbModel):
926
+ """embeds each dimension independently and concatenates them"""
927
+
928
+ def __init__(self, outdim):
929
+ super().__init__()
930
+ self.timestep = Timestep(outdim)
931
+ self.outdim = outdim
932
+
933
+ def forward(self, x):
934
+ if x.ndim == 1:
935
+ x = x[:, None]
936
+ assert len(x.shape) == 2
937
+ b, dims = x.shape[0], x.shape[1]
938
+ x = rearrange(x, "b d -> (b d)")
939
+ emb = self.timestep(x)
940
+ emb = rearrange(emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
941
+ return emb
942
+
943
+
944
+ class GaussianEncoder(Encoder, AbstractEmbModel):
945
+ def __init__(
946
+ self, weight: float = 1.0, flatten_output: bool = True, *args, **kwargs
947
+ ):
948
+ super().__init__(*args, **kwargs)
949
+ self.posterior = DiagonalGaussianRegularizer()
950
+ self.weight = weight
951
+ self.flatten_output = flatten_output
952
+
953
+ def forward(self, x) -> Tuple[Dict, torch.Tensor]:
954
+ z = super().forward(x)
955
+ z, log = self.posterior(z)
956
+ log["loss"] = log["kl_loss"]
957
+ log["weight"] = self.weight
958
+ if self.flatten_output:
959
+ z = rearrange(z, "b c h w -> b (h w ) c")
960
+ return log, z
repositories/generative-models/sgm/util.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import importlib
3
+ import os
4
+ from functools import partial
5
+ from inspect import isfunction
6
+
7
+ import fsspec
8
+ import numpy as np
9
+ import torch
10
+ from PIL import Image, ImageDraw, ImageFont
11
+ from safetensors.torch import load_file as load_safetensors
12
+
13
+
14
+ def disabled_train(self, mode=True):
15
+ """Overwrite model.train with this function to make sure train/eval mode
16
+ does not change anymore."""
17
+ return self
18
+
19
+
20
+ def get_string_from_tuple(s):
21
+ try:
22
+ # Check if the string starts and ends with parentheses
23
+ if s[0] == "(" and s[-1] == ")":
24
+ # Convert the string to a tuple
25
+ t = eval(s)
26
+ # Check if the type of t is tuple
27
+ if type(t) == tuple:
28
+ return t[0]
29
+ else:
30
+ pass
31
+ except:
32
+ pass
33
+ return s
34
+
35
+
36
+ def is_power_of_two(n):
37
+ """
38
+ chat.openai.com/chat
39
+ Return True if n is a power of 2, otherwise return False.
40
+
41
+ The function is_power_of_two takes an integer n as input and returns True if n is a power of 2, otherwise it returns False.
42
+ The function works by first checking if n is less than or equal to 0. If n is less than or equal to 0, it can't be a power of 2, so the function returns False.
43
+ If n is greater than 0, the function checks whether n is a power of 2 by using a bitwise AND operation between n and n-1. If n is a power of 2, then it will have only one bit set to 1 in its binary representation. When we subtract 1 from a power of 2, all the bits to the right of that bit become 1, and the bit itself becomes 0. So, when we perform a bitwise AND between n and n-1, we get 0 if n is a power of 2, and a non-zero value otherwise.
44
+ Thus, if the result of the bitwise AND operation is 0, then n is a power of 2 and the function returns True. Otherwise, the function returns False.
45
+
46
+ """
47
+ if n <= 0:
48
+ return False
49
+ return (n & (n - 1)) == 0
50
+
51
+
52
+ def autocast(f, enabled=True):
53
+ def do_autocast(*args, **kwargs):
54
+ with torch.cuda.amp.autocast(
55
+ enabled=enabled,
56
+ dtype=torch.get_autocast_gpu_dtype(),
57
+ cache_enabled=torch.is_autocast_cache_enabled(),
58
+ ):
59
+ return f(*args, **kwargs)
60
+
61
+ return do_autocast
62
+
63
+
64
+ def load_partial_from_config(config):
65
+ return partial(get_obj_from_str(config["target"]), **config.get("params", dict()))
66
+
67
+
68
+ def log_txt_as_img(wh, xc, size=10):
69
+ # wh a tuple of (width, height)
70
+ # xc a list of captions to plot
71
+ b = len(xc)
72
+ txts = list()
73
+ for bi in range(b):
74
+ txt = Image.new("RGB", wh, color="white")
75
+ draw = ImageDraw.Draw(txt)
76
+ font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
77
+ nc = int(40 * (wh[0] / 256))
78
+ if isinstance(xc[bi], list):
79
+ text_seq = xc[bi][0]
80
+ else:
81
+ text_seq = xc[bi]
82
+ lines = "\n".join(
83
+ text_seq[start : start + nc] for start in range(0, len(text_seq), nc)
84
+ )
85
+
86
+ try:
87
+ draw.text((0, 0), lines, fill="black", font=font)
88
+ except UnicodeEncodeError:
89
+ print("Cant encode string for logging. Skipping.")
90
+
91
+ txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
92
+ txts.append(txt)
93
+ txts = np.stack(txts)
94
+ txts = torch.tensor(txts)
95
+ return txts
96
+
97
+
98
+ def partialclass(cls, *args, **kwargs):
99
+ class NewCls(cls):
100
+ __init__ = functools.partialmethod(cls.__init__, *args, **kwargs)
101
+
102
+ return NewCls
103
+
104
+
105
+ def make_path_absolute(path):
106
+ fs, p = fsspec.core.url_to_fs(path)
107
+ if fs.protocol == "file":
108
+ return os.path.abspath(p)
109
+ return path
110
+
111
+
112
+ def ismap(x):
113
+ if not isinstance(x, torch.Tensor):
114
+ return False
115
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
116
+
117
+
118
+ def isimage(x):
119
+ if not isinstance(x, torch.Tensor):
120
+ return False
121
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
122
+
123
+
124
+ def isheatmap(x):
125
+ if not isinstance(x, torch.Tensor):
126
+ return False
127
+
128
+ return x.ndim == 2
129
+
130
+
131
+ def isneighbors(x):
132
+ if not isinstance(x, torch.Tensor):
133
+ return False
134
+ return x.ndim == 5 and (x.shape[2] == 3 or x.shape[2] == 1)
135
+
136
+
137
+ def exists(x):
138
+ return x is not None
139
+
140
+
141
+ def expand_dims_like(x, y):
142
+ while x.dim() != y.dim():
143
+ x = x.unsqueeze(-1)
144
+ return x
145
+
146
+
147
+ def default(val, d):
148
+ if exists(val):
149
+ return val
150
+ return d() if isfunction(d) else d
151
+
152
+
153
+ def mean_flat(tensor):
154
+ """
155
+ https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
156
+ Take the mean over all non-batch dimensions.
157
+ """
158
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
159
+
160
+
161
+ def count_params(model, verbose=False):
162
+ total_params = sum(p.numel() for p in model.parameters())
163
+ if verbose:
164
+ print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
165
+ return total_params
166
+
167
+
168
+ def instantiate_from_config(config):
169
+ if not "target" in config:
170
+ if config == "__is_first_stage__":
171
+ return None
172
+ elif config == "__is_unconditional__":
173
+ return None
174
+ raise KeyError("Expected key `target` to instantiate.")
175
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
176
+
177
+
178
+ def get_obj_from_str(string, reload=False, invalidate_cache=True):
179
+ module, cls = string.rsplit(".", 1)
180
+ if invalidate_cache:
181
+ importlib.invalidate_caches()
182
+ if reload:
183
+ module_imp = importlib.import_module(module)
184
+ importlib.reload(module_imp)
185
+ return getattr(importlib.import_module(module, package=None), cls)
186
+
187
+
188
+ def append_zero(x):
189
+ return torch.cat([x, x.new_zeros([1])])
190
+
191
+
192
+ def append_dims(x, target_dims):
193
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
194
+ dims_to_append = target_dims - x.ndim
195
+ if dims_to_append < 0:
196
+ raise ValueError(
197
+ f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
198
+ )
199
+ return x[(...,) + (None,) * dims_to_append]
200
+
201
+
202
+ def load_model_from_config(config, ckpt, verbose=True, freeze=True):
203
+ print(f"Loading model from {ckpt}")
204
+ if ckpt.endswith("ckpt"):
205
+ pl_sd = torch.load(ckpt, map_location="cpu")
206
+ if "global_step" in pl_sd:
207
+ print(f"Global Step: {pl_sd['global_step']}")
208
+ sd = pl_sd["state_dict"]
209
+ elif ckpt.endswith("safetensors"):
210
+ sd = load_safetensors(ckpt)
211
+ else:
212
+ raise NotImplementedError
213
+
214
+ model = instantiate_from_config(config.model)
215
+ sd = pl_sd["state_dict"]
216
+
217
+ m, u = model.load_state_dict(sd, strict=False)
218
+
219
+ if len(m) > 0 and verbose:
220
+ print("missing keys:")
221
+ print(m)
222
+ if len(u) > 0 and verbose:
223
+ print("unexpected keys:")
224
+ print(u)
225
+
226
+ if freeze:
227
+ for param in model.parameters():
228
+ param.requires_grad = False
229
+
230
+ model.eval()
231
+ return model
repositories/k-diffusion/.github/workflows/python-publish.yml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Release
2
+
3
+ on:
4
+ push:
5
+ branches:
6
+ - master
7
+ jobs:
8
+ deploy:
9
+ runs-on: ubuntu-latest
10
+ steps:
11
+ - uses: actions/checkout@v2
12
+ - uses: actions-ecosystem/action-regex-match@v2
13
+ id: regex-match
14
+ with:
15
+ text: ${{ github.event.head_commit.message }}
16
+ regex: '^Release ([^ ]+)'
17
+ - name: Set up Python
18
+ uses: actions/setup-python@v2
19
+ with:
20
+ python-version: '3.8'
21
+ - name: Install dependencies
22
+ run: |
23
+ python -m pip install --upgrade pip
24
+ pip install setuptools wheel twine
25
+ - name: Release
26
+ if: ${{ steps.regex-match.outputs.match != '' }}
27
+ uses: softprops/action-gh-release@v1
28
+ with:
29
+ tag_name: v${{ steps.regex-match.outputs.group1 }}
30
+ - name: Build and publish
31
+ if: ${{ steps.regex-match.outputs.match != '' }}
32
+ env:
33
+ TWINE_USERNAME: __token__
34
+ TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
35
+ run: |
36
+ python setup.py sdist bdist_wheel
37
+ twine upload dist/*
repositories/k-diffusion/.gitignore ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ venv*
2
+ __pycache__
3
+ .ipynb_checkpoints
4
+ *.pth
5
+ *.egg-info
6
+ data
7
+ *_demo_*.png
8
+ wandb/*
9
+ *.csv
10
+ .env
repositories/k-diffusion/LICENSE ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Copyright (c) 2022 Katherine Crowson
2
+
3
+ Permission is hereby granted, free of charge, to any person obtaining a copy
4
+ of this software and associated documentation files (the "Software"), to deal
5
+ in the Software without restriction, including without limitation the rights
6
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7
+ copies of the Software, and to permit persons to whom the Software is
8
+ furnished to do so, subject to the following conditions:
9
+
10
+ The above copyright notice and this permission notice shall be included in
11
+ all copies or substantial portions of the Software.
12
+
13
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19
+ THE SOFTWARE.
repositories/k-diffusion/README.md ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # k-diffusion
2
+
3
+ An implementation of [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364) (Karras et al., 2022) for PyTorch. The patching method in [Improving Diffusion Model Efficiency Through Patching](https://arxiv.org/abs/2207.04316) is implemented as well.
4
+
5
+ ## Installation
6
+
7
+ `k-diffusion` can be installed via PyPI (`pip install k-diffusion`) but it will not include training and inference scripts, only library code that others can depend on. To run the training and inference scripts, clone this repository and run `pip install -e <path to repository>`.
8
+
9
+ ## Training:
10
+
11
+ To train models:
12
+
13
+ ```sh
14
+ $ ./train.py --config CONFIG_FILE --name RUN_NAME
15
+ ```
16
+
17
+ For instance, to train a model on MNIST:
18
+
19
+ ```sh
20
+ $ ./train.py --config configs/config_mnist.json --name RUN_NAME
21
+ ```
22
+
23
+ The configuration file allows you to specify the dataset type. Currently supported types are `"imagefolder"` (finds all images in that folder and its subfolders, recursively), `"cifar10"` (CIFAR-10), and `"mnist"` (MNIST). `"huggingface"` [Hugging Face Datasets](https://huggingface.co/docs/datasets/index) is also supported.
24
+
25
+ Multi-GPU and multi-node training is supported with [Hugging Face Accelerate](https://huggingface.co/docs/accelerate/index). You can configure Accelerate by running:
26
+
27
+ ```sh
28
+ $ accelerate config
29
+ ```
30
+
31
+ on all nodes, then running:
32
+
33
+ ```sh
34
+ $ accelerate launch train.py --config CONFIG_FILE --name RUN_NAME
35
+ ```
36
+
37
+ on all nodes.
38
+
39
+ ## Enhancements/additional features:
40
+
41
+ - k-diffusion supports an experimental model output type, an isotropic Gaussian, which seems to have a lower gradient noise scale and to train faster than Karras et al. (2022) diffusion models.
42
+
43
+ - k-diffusion has wrappers for [v-diffusion-pytorch](https://github.com/crowsonkb/v-diffusion-pytorch), [OpenAI diffusion](https://github.com/openai/guided-diffusion), and [CompVis diffusion](https://github.com/CompVis/latent-diffusion) models allowing them to be used with its samplers and ODE/SDE.
44
+
45
+ - k-diffusion models support progressive growing.
46
+
47
+ - k-diffusion implements [DPM-Solver](https://arxiv.org/abs/2206.00927), which produces higher quality samples at the same number of function evalutions as Karras Algorithm 2, as well as supporting adaptive step size control. [DPM-Solver++(2S) and (2M)](https://arxiv.org/abs/2211.01095) are implemented now too for improved quality with low numbers of steps.
48
+
49
+ - k-diffusion supports [CLIP](https://openai.com/blog/clip/) guided sampling from unconditional diffusion models (see `sample_clip_guided.py`).
50
+
51
+ - k-diffusion supports log likelihood calculation (not a variational lower bound) for native models and all wrapped models.
52
+
53
+ - k-diffusion can calculate, during training, the [FID](https://papers.nips.cc/paper/2017/file/8a1d694707eb0fefe65871369074926d-Paper.pdf) and [KID](https://arxiv.org/abs/1801.01401) vs the training set.
54
+
55
+ - k-diffusion can calculate, during training, the gradient noise scale (1 / SNR), from _An Empirical Model of Large-Batch Training_, https://arxiv.org/abs/1812.06162).
56
+
57
+ ## To do:
58
+
59
+ - Anything except unconditional image diffusion models
60
+
61
+ - Latent diffusion
repositories/k-diffusion/configs/config_32x32_small.json ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model": {
3
+ "type": "image_v1",
4
+ "input_channels": 3,
5
+ "input_size": [32, 32],
6
+ "patch_size": 1,
7
+ "mapping_out": 256,
8
+ "depths": [2, 4, 4],
9
+ "channels": [128, 256, 512],
10
+ "self_attn_depths": [false, true, true],
11
+ "has_variance": true,
12
+ "dropout_rate": 0.05,
13
+ "augment_wrapper": true,
14
+ "augment_prob": 0.12,
15
+ "sigma_data": 0.5,
16
+ "sigma_min": 1e-2,
17
+ "sigma_max": 80,
18
+ "sigma_sample_density": {
19
+ "type": "lognormal",
20
+ "mean": -1.2,
21
+ "std": 1.2
22
+ }
23
+ },
24
+ "dataset": {
25
+ "type": "imagefolder",
26
+ "location": "/path/to/dataset"
27
+ },
28
+ "optimizer": {
29
+ "type": "adamw",
30
+ "lr": 1e-4,
31
+ "betas": [0.95, 0.999],
32
+ "eps": 1e-6,
33
+ "weight_decay": 1e-3
34
+ },
35
+ "lr_sched": {
36
+ "type": "constant"
37
+ },
38
+ "ema_sched": {
39
+ "type": "inverse",
40
+ "power": 0.6667,
41
+ "max_value": 0.9999
42
+ }
43
+ }
repositories/k-diffusion/configs/config_32x32_small_butterflies.json ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model": {
3
+ "type": "image_v1",
4
+ "input_channels": 3,
5
+ "input_size": [32, 32],
6
+ "patch_size": 1,
7
+ "mapping_out": 256,
8
+ "depths": [2, 4, 4],
9
+ "channels": [128, 256, 512],
10
+ "self_attn_depths": [false, true, true],
11
+ "has_variance": true,
12
+ "dropout_rate": 0.05,
13
+ "augment_wrapper": true,
14
+ "augment_prob": 0.12,
15
+ "sigma_data": 0.5,
16
+ "sigma_min": 1e-2,
17
+ "sigma_max": 80,
18
+ "sigma_sample_density": {
19
+ "type": "lognormal",
20
+ "mean": -1.2,
21
+ "std": 1.2
22
+ }
23
+ },
24
+ "dataset": {
25
+ "type": "huggingface",
26
+ "location": "huggan/smithsonian_butterflies_subset",
27
+ "image_key": "image"
28
+ },
29
+ "optimizer": {
30
+ "type": "adamw",
31
+ "lr": 1e-4,
32
+ "betas": [0.95, 0.999],
33
+ "eps": 1e-6,
34
+ "weight_decay": 1e-3
35
+ },
36
+ "lr_sched": {
37
+ "type": "constant"
38
+ },
39
+ "ema_sched": {
40
+ "type": "inverse",
41
+ "power": 0.6667,
42
+ "max_value": 0.9999
43
+ }
44
+ }
repositories/k-diffusion/configs/config_cifar10.json ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model": {
3
+ "type": "image_v1",
4
+ "input_channels": 3,
5
+ "input_size": [32, 32],
6
+ "patch_size": 1,
7
+ "mapping_out": 256,
8
+ "depths": [2, 4, 4],
9
+ "channels": [128, 256, 512],
10
+ "self_attn_depths": [false, true, true],
11
+ "has_variance": true,
12
+ "dropout_rate": 0.05,
13
+ "augment_wrapper": true,
14
+ "augment_prob": 0.12,
15
+ "sigma_data": 0.5,
16
+ "sigma_min": 1e-2,
17
+ "sigma_max": 80,
18
+ "sigma_sample_density": {
19
+ "type": "lognormal",
20
+ "mean": -1.2,
21
+ "std": 1.2
22
+ }
23
+ },
24
+ "dataset": {
25
+ "type": "cifar10",
26
+ "location": "data"
27
+ },
28
+ "optimizer": {
29
+ "type": "adamw",
30
+ "lr": 1e-4,
31
+ "betas": [0.95, 0.999],
32
+ "eps": 1e-6,
33
+ "weight_decay": 1e-3
34
+ },
35
+ "lr_sched": {
36
+ "type": "constant"
37
+ },
38
+ "ema_sched": {
39
+ "type": "inverse",
40
+ "power": 0.6667,
41
+ "max_value": 0.9999
42
+ }
43
+ }
repositories/k-diffusion/configs/config_mnist.json ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model": {
3
+ "type": "image_v1",
4
+ "input_channels": 1,
5
+ "input_size": [28, 28],
6
+ "patch_size": 1,
7
+ "mapping_out": 256,
8
+ "depths": [2, 4, 4],
9
+ "channels": [128, 128, 256],
10
+ "self_attn_depths": [false, false, true],
11
+ "has_variance": true,
12
+ "dropout_rate": 0.05,
13
+ "augment_wrapper": true,
14
+ "augment_prob": 0.12,
15
+ "sigma_data": 0.6162,
16
+ "sigma_min": 1e-2,
17
+ "sigma_max": 80,
18
+ "sigma_sample_density": {
19
+ "type": "lognormal",
20
+ "mean": -1.2,
21
+ "std": 1.2
22
+ }
23
+ },
24
+ "dataset": {
25
+ "type": "mnist",
26
+ "location": "data"
27
+ },
28
+ "optimizer": {
29
+ "type": "adamw",
30
+ "lr": 2e-4,
31
+ "betas": [0.95, 0.999],
32
+ "eps": 1e-6,
33
+ "weight_decay": 1e-3
34
+ },
35
+ "lr_sched": {
36
+ "type": "constant"
37
+ },
38
+ "ema_sched": {
39
+ "type": "inverse",
40
+ "power": 0.6667,
41
+ "max_value": 0.9999
42
+ }
43
+ }
repositories/k-diffusion/k_diffusion/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from . import augmentation, config, evaluation, external, gns, layers, models, sampling, utils
2
+ from .layers import Denoiser
repositories/k-diffusion/k_diffusion/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (373 Bytes). View file
 
repositories/k-diffusion/k_diffusion/__pycache__/augmentation.cpython-310.pyc ADDED
Binary file (3.72 kB). View file
 
repositories/k-diffusion/k_diffusion/__pycache__/config.cpython-310.pyc ADDED
Binary file (3.27 kB). View file
 
repositories/k-diffusion/k_diffusion/__pycache__/evaluation.cpython-310.pyc ADDED
Binary file (5.89 kB). View file
 
repositories/k-diffusion/k_diffusion/__pycache__/external.cpython-310.pyc ADDED
Binary file (8.53 kB). View file
 
repositories/k-diffusion/k_diffusion/__pycache__/gns.cpython-310.pyc ADDED
Binary file (4.76 kB). View file
 
repositories/k-diffusion/k_diffusion/__pycache__/layers.cpython-310.pyc ADDED
Binary file (11.1 kB). View file
 
repositories/k-diffusion/k_diffusion/__pycache__/sampling.cpython-310.pyc ADDED
Binary file (23.7 kB). View file