ifwi committed 3c454af
Parent(s): 63a2e47

add more

Changed files:
- ldm/models/diffusion/ddpm.py (+30 -25)
- ldm/modules/attention.py (+37 -29)
ldm/models/diffusion/ddpm.py CHANGED

@@ -47,6 +47,7 @@ def disabled_train(self, mode=True):
 def uniform_on_device(r1, r2, shape, device):
     return (r1 - r2) * torch.rand(*shape, device=device) + r2
 
+
 class DDPM(pl.LightningModule):
     # classic DDPM with Gaussian diffusion, in image space
     def __init__(self,
@@ -124,7 +125,8 @@ class DDPM(pl.LightningModule):
             self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
             if reset_ema:
                 assert self.use_ema
-                print(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
+                print(
+                    f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
                 self.model_ema = LitEma(self.model)
         if reset_num_ema_updates:
             print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
@@ -573,7 +575,7 @@ class LatentDiffusion(DDPM):
             self.scale_factor = scale_factor
         else:
             self.register_buffer('scale_factor', torch.tensor(scale_factor))
-
+
         self.instantiate_first_stage(first_stage_config)
         self.instantiate_cond_stage(cond_stage_config)
         self.cond_stage_forward = cond_stage_forward
@@ -586,7 +588,7 @@ class LatentDiffusion(DDPM):
             self.proj_out = None
         if self.use_pbe_weight:
             print("learnable vector gene")
-            self.learnable_vector = nn.Parameter(torch.randn((1,1,768)), requires_grad=True)
+            self.learnable_vector = nn.Parameter(torch.randn((1, 1, 768)), requires_grad=True)
         else:
             self.learnable_vector = None
 
@@ -608,7 +610,7 @@ class LatentDiffusion(DDPM):
            print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
            assert self.use_ema
            self.model_ema.reset_num_updates()
-
+
    def make_cond_schedule(self, ):
        self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
        ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
@@ -646,7 +648,7 @@ class LatentDiffusion(DDPM):
         self.first_stage_model.train = disabled_train
         for param in self.first_stage_model.parameters():
             param.requires_grad = False
-
+
     def instantiate_cond_stage(self, config):
         if not self.cond_stage_trainable:
             if config == "__is_first_stage__":
@@ -791,14 +793,15 @@ class LatentDiffusion(DDPM):
 
     @torch.no_grad()
     def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
-                  cond_key=None, return_original_cond=False, bs=None, return_x=False, no_latent=False, is_controlnet=False):
+                  cond_key=None, return_original_cond=False, bs=None, return_x=False, no_latent=False,
+                  is_controlnet=False):
         x = super().get_input(batch, k)
         if bs is not None:
             x = x[:bs]
         x = x.to(self.device)
         if no_latent:
-            _,_,h,w = x.shape
-            x = resize(x, (h//8, w//8))
+            _, _, h, w = x.shape
+            x = resize(x, (h // 8, w // 8))
             return [x, None]
         encoder_posterior = self.encode_first_stage(x)
         z = self.get_first_stage_encoding(encoder_posterior).detach()
@@ -815,12 +818,12 @@ class LatentDiffusion(DDPM):
                     xc = batch
                 else:
                     xc = super().get_input(batch, cond_key).to(self.device)
-            else:
+            else:
                 xc = x
             if not self.cond_stage_trainable or force_c_encode:
                 if self.kwargs["use_imageCLIP"]:
-                    xc = resize(xc, (224,224))
-                    xc = self.imagenet_norm((xc+1)/2)
+                    xc = resize(xc, (224, 224))
+                    xc = self.imagenet_norm((xc + 1) / 2)
                     c = xc
                 else:
                     if isinstance(xc, dict) or isinstance(xc, list):
@@ -830,8 +833,8 @@ class LatentDiffusion(DDPM):
                         c = c.float()
             else:
                 if self.kwargs["use_imageCLIP"]:
-                    xc = resize(xc, (224,224))
-                    xc = self.imagenet_norm((xc+1)/2)
+                    xc = resize(xc, (224, 224))
+                    xc = self.imagenet_norm((xc + 1) / 2)
                 c = xc
         if bs is not None:
             c = c[:bs]
@@ -847,7 +850,7 @@ class LatentDiffusion(DDPM):
         if self.use_positional_encodings:
             pos_x, pos_y = self.compute_latent_shifts(batch)
             c = {'pos_x': pos_x, 'pos_y': pos_y}
-
+
         out = [z, c]
         if return_first_stage_outputs:
             xrec = self.decode_first_stage(z)
@@ -872,6 +875,7 @@ class LatentDiffusion(DDPM):
                 return output
             else:
                 return output.sample
+
     def decode_first_stage_train(self, z, predict_cids=False, force_not_quantize=False):
         if predict_cids:
             if z.dim() == 4:
@@ -905,12 +909,11 @@ class LatentDiffusion(DDPM):
         # pbe negative condition
         else:
             t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
-            self.u_cond_prop=random.uniform(0, 1)
+            self.u_cond_prop = random.uniform(0, 1)
             c["c_crossattn"] = [self.get_learned_conditioning(c["c_crossattn"])]
             if self.u_cond_prop < self.u_cond_percent:
-                c["c_crossattn"] = [self.learnable_vector.repeat(x.shape[0],1,1)]
+                c["c_crossattn"] = [self.learnable_vector.repeat(x.shape[0], 1, 1)]
             return self.p_losses(x, c, t, *args, **kwargs)
-
 
     def apply_model(self, x_noisy, t, cond, return_ids=False):
         if isinstance(cond, dict):
@@ -931,7 +934,7 @@ class LatentDiffusion(DDPM):
 
     def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
         return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
-            extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+               extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
 
     def _prior_bpd(self, x_start):
         """
@@ -946,6 +949,7 @@ class LatentDiffusion(DDPM):
         qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
         kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
         return mean_flat(kl_prior) / np.log(2.0)
+
     def p_losses(self, x_start, cond, t, noise=None):
         loss_dict = {}
         noise = default(noise, lambda: torch.randn_like(x_start))
@@ -969,11 +973,11 @@ class LatentDiffusion(DDPM):
         if self.only_agn_simple_loss:
             _, _, l_h, l_w = model_output.shape
             m_agn = F.interpolate(super().get_input(self.batch, "agn_mask"), (l_h, l_w))
-            loss_simple = self.get_loss(model_output * (1-m_agn), target * (1-m_agn), mean=False).mean([1, 2, 3])
+            loss_simple = self.get_loss(model_output * (1 - m_agn), target * (1 - m_agn), mean=False).mean([1, 2, 3])
         else:
             loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
         loss_dict.update({f'simple': loss_simple.mean()})
-
+
         logvar_t = self.logvar[t].to(self.device)
         loss = loss_simple / torch.exp(logvar_t) + logvar_t
         # loss = loss_simple / torch.exp(self.logvar) + self.logvar
@@ -981,7 +985,7 @@ class LatentDiffusion(DDPM):
             loss_dict.update({f'gamma': loss.mean()})
             loss_dict.update({'logvar': self.logvar.data.mean()})
         loss = self.l_simple_weight * loss.mean()
-
+
         loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
         loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
         if self.original_elbo_weight != 0:
@@ -990,7 +994,7 @@ class LatentDiffusion(DDPM):
 
         if model_loss is not None:
             loss += model_loss
-            loss_dict.update({f"model loss":model_loss})
+            loss_dict.update({f"model loss": model_loss})
         loss_dict.update({f'{prefix}_loss': loss})
 
         return loss, loss_dict
@@ -1540,7 +1544,7 @@ class LatentUpscaleDiffusion(LatentDiffusion):
                     uc[k] = [uc_tmp]
                 elif k == "c_adm":  # todo: only run with text-based guidance?
                     assert isinstance(c[k], torch.Tensor)
-                    #uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level
+                    # uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level
                     uc[k] = c[k]
                 elif isinstance(c[k], list):
                     uc[k] = [c[k][i] for i in range(len(c[k]))]
@@ -1807,7 +1811,7 @@ class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion):
         log = super().log_images(*args, **kwargs)
         depth = self.depth_model(args[0][self.depth_stage_key])
         depth_min, depth_max = torch.amin(depth, dim=[1, 2, 3], keepdim=True), \
-                               torch.amax(depth, dim=[1, 2, 3], keepdim=True)
+            torch.amax(depth, dim=[1, 2, 3], keepdim=True)
         log["depth"] = 2. * (depth - depth_min) / (depth_max - depth_min) - 1.
         return log
 
@@ -1816,6 +1820,7 @@ class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
     """
     condition on low-res image (and optionally on some spatial noise augmentation)
     """
+
     def __init__(self, concat_keys=("lr",), reshuffle_patch_size=None,
                  low_scale_config=None, low_scale_key=None, *args, **kwargs):
         super().__init__(concat_keys=concat_keys, *args, **kwargs)
@@ -1872,4 +1877,4 @@ class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
     def log_images(self, *args, **kwargs):
         log = super().log_images(*args, **kwargs)
         log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w')
-        return log
+        return log
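Note on the forward() hunk above: with probability u_cond_percent the real cross-attention conditioning is replaced by the learnable_vector registered in LatentDiffusion.__init__, i.e. classifier-free-guidance-style training against a learned null embedding. A minimal, self-contained sketch of that pattern follows; it is not the repository's class, and the module name CondDropout and the dimensions are illustrative assumptions.

import random

import torch
import torch.nn as nn


class CondDropout(nn.Module):
    """Swap the cross-attention context for a learned null embedding with probability p."""

    def __init__(self, u_cond_percent: float = 0.2, embed_dim: int = 768):
        super().__init__()
        self.u_cond_percent = u_cond_percent
        # Analogous to self.learnable_vector of shape (1, 1, 768) in the diff above.
        self.learnable_vector = nn.Parameter(torch.randn(1, 1, embed_dim))

    def forward(self, cond: torch.Tensor) -> torch.Tensor:
        # cond: (batch, tokens, embed_dim) conditioning for cross-attention.
        if random.uniform(0, 1) < self.u_cond_percent:
            # One draw per training step, applied to the whole batch, as in the diff.
            return self.learnable_vector.repeat(cond.shape[0], 1, 1)
        return cond


if __name__ == "__main__":
    drop = CondDropout(u_cond_percent=0.5)
    ctx = torch.randn(4, 77, 768)
    out = drop(ctx)
    print(out.shape)  # (4, 77, 768), or (4, 1, 768) when the condition was dropped

Dropping the condition for the whole batch (a single random draw per step) mirrors the diff; per-sample dropout is a common alternative but is not what this code does.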
ldm/modules/attention.py CHANGED

@@ -12,20 +12,23 @@ from ldm.modules.diffusionmodules.util import checkpoint
 try:
     import xformers
     import xformers.ops
+
     XFORMERS_IS_AVAILBLE = True
 except:
     XFORMERS_IS_AVAILBLE = False
 
 # CrossAttn precision handling
 import os
+
 _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
 
+
 def exists(val):
     return val is not None
 
 
 def uniq(arr):
-    return{el: True for el in arr}.keys()
+    return {el: True for el in arr}.keys()
 
 
 def default(val, d):
@@ -33,6 +36,7 @@ def default(val, d):
         return val
     return d() if isfunction(d) else d
 
+
 class GEGLU(nn.Module):
     def __init__(self, dim_in, dim_out):
         super().__init__()
@@ -110,12 +114,12 @@ class SpatialSelfAttention(nn.Module):
         k = self.k(h_)
         v = self.v(h_)
 
-        b,c,h,w = q.shape
+        b, c, h, w = q.shape
         q = rearrange(q, 'b c h w -> b (h w) c')
         k = rearrange(k, 'b c h w -> b c (h w)')
         w_ = torch.einsum('bij,bjk->bik', q, k)
 
-        w_ = w_ * (int(c)**(-0.5))
+        w_ = w_ * (int(c) ** (-0.5))
         w_ = torch.nn.functional.softmax(w_, dim=2)
 
         v = rearrange(v, 'b c h w -> b c (h w)')
@@ -124,7 +128,8 @@ class SpatialSelfAttention(nn.Module):
         h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
         h_ = self.proj_out(h_)
 
-        return x+h_
+        return x + h_
+
 
 class CrossAttention(nn.Module):
     def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., **kwargs):
@@ -143,7 +148,6 @@ class CrossAttention(nn.Module):
             nn.Linear(inner_dim, query_dim),
             nn.Dropout(dropout)
         )
-
 
     def forward(self, x, context=None, mask=None):
         h = self.heads
@@ -153,26 +157,27 @@ class CrossAttention(nn.Module):
         v = self.to_v(context)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
 
-        if _ATTN_PRECISION =="fp32":
-            with torch.autocast(enabled=False, device_type = 'cuda'):
+        if _ATTN_PRECISION == "fp32":
+            with torch.autocast(enabled=False, device_type='cuda'):
                 q, k = q.float(), k.float()
                 sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
         else:
             sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
-
+
         del q, k
         if exists(mask):
             mask = rearrange(mask, 'b ... -> b (...)')
             max_neg_value = -torch.finfo(sim.dtype).max
             mask = repeat(mask, 'b j -> (b h) () j', h=h)
             sim.masked_fill_(~mask, max_neg_value)
-
-        sim = sim.softmax(dim=-1)
-
+
+        sim = sim.softmax(dim=-1)
+
         out = einsum('b i j, b j d -> b i d', sim, v)
         out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
         return self.to_out(out)
 
+
 class MemoryEfficientCrossAttention(nn.Module):
     # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
     def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, zero_init=False, **kwargs):
@@ -195,7 +200,6 @@ class MemoryEfficientCrossAttention(nn.Module):
 
         self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
         self.attention_op: Optional[Any] = None
-
 
     def forward(self, x, context=None, mask=None, **kwargs):
         q = self.to_q(x)
@@ -221,23 +225,25 @@ class MemoryEfficientCrossAttention(nn.Module):
             .reshape(b, out.shape[1], self.heads * self.dim_head)
         )
         return self.to_out(out)
-
+
+
 class BasicTransformerBlock(nn.Module):
     ATTENTION_MODES = {
         "softmax": CrossAttention,  # vanilla attention
         "softmax-xformers": MemoryEfficientCrossAttention
     }
+
     def __init__(
-        self,
-        dim,
-        n_heads,
-        d_head,
-        dropout=0.,
-        context_dim=None,
-        gated_ff=True,
+            self,
+            dim,
+            n_heads,
+            d_head,
+            dropout=0.,
+            context_dim=None,
+            gated_ff=True,
             checkpoint=True,
             disable_self_attn=False
-    ):
+    ):
         super().__init__()
         attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
         assert attn_mode in self.ATTENTION_MODES
@@ -247,24 +253,25 @@ class BasicTransformerBlock(nn.Module):
                               context_dim=context_dim if self.disable_self_attn else None)
         self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
         self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
-                              heads=n_heads, dim_head=d_head, dropout=dropout)
+                              heads=n_heads, dim_head=d_head, dropout=dropout)
         self.norm1 = nn.LayerNorm(dim)
         self.norm2 = nn.LayerNorm(dim)
         self.norm3 = nn.LayerNorm(dim)
         self.checkpoint = checkpoint
 
-    def forward(self, x, context=None,hint=None):
+    def forward(self, x, context=None, hint=None):
         if hint is None:
             return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
         else:
             return checkpoint(self._forward, (x, context, hint), self.parameters(), self.checkpoint)
 
-    def _forward(self, x, context=None,hint=None):
-        x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None,hint=hint) + x
+    def _forward(self, x, context=None, hint=None):
+        x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None, hint=hint) + x
         x = self.attn2(self.norm2(x), context=context) + x
         x = self.ff(self.norm3(x)) + x
         return x
 
+
 class SpatialTransformer(nn.Module):
     """
     Transformer block for image-like data.
@@ -274,6 +281,7 @@ class SpatialTransformer(nn.Module):
     Finally, reshape to image
     NEW: use_linear for more efficiency instead of the 1x1 convs
     """
+
     def __init__(self, in_channels, n_heads, d_head,
                  depth=1, dropout=0., context_dim=None,
                  disable_self_attn=False, use_linear=False,
@@ -296,7 +304,7 @@ class SpatialTransformer(nn.Module):
         self.transformer_blocks = nn.ModuleList(
             [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
                                    disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
-                for d in range(depth)]
+             for d in range(depth)]
         )
         if not use_linear:
             self.proj_out = zero_module(nn.Conv2d(inner_dim,
@@ -308,7 +316,7 @@ class SpatialTransformer(nn.Module):
             self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
         self.use_linear = use_linear
 
-    def forward(self, x, context=None,hint=None):
+    def forward(self, x, context=None, hint=None):
         # note: if no context is given, cross-attention defaults to self-attention
         if not isinstance(context, list):
             context = [context]
@@ -321,10 +329,10 @@ class SpatialTransformer(nn.Module):
         if self.use_linear:
             x = self.proj_in(x)
         for i, block in enumerate(self.transformer_blocks):
-            x = block(x, context=context[i],hint=hint)
+            x = block(x, context=context[i])
         if self.use_linear:
             x = self.proj_out(x)
         x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
         if not self.use_linear:
             x = self.proj_out(x)
-        return x + x_in
+        return x + x_in
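Note on the CrossAttention.forward hunk above: it keeps a float32 escape hatch for the attention logits. When the ATTN_PRECISION environment variable is "fp32", the q·kᵀ product is computed with autocast disabled and float32 operands, which avoids fp16 overflow under mixed-precision training. Below is a minimal sketch of just that guard; the helper name scaled_similarity and the tensor shapes in the usage lines are assumptions for illustration, not code from this repository.

import os

import torch
from torch import einsum

_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")


def scaled_similarity(q: torch.Tensor, k: torch.Tensor, scale: float) -> torch.Tensor:
    # q, k: (batch*heads, tokens, dim_head), as produced by the rearrange in forward().
    if _ATTN_PRECISION == "fp32":
        # Leave autocast, upcast the operands, and compute the logits in float32.
        with torch.autocast(enabled=False, device_type="cuda"):
            q, k = q.float(), k.float()
            sim = einsum('b i d, b j d -> b i j', q, k) * scale
    else:
        sim = einsum('b i d, b j d -> b i j', q, k) * scale
    return sim


if __name__ == "__main__":
    q = torch.randn(16, 64, 40)   # 2 samples x 8 heads, 64 query tokens, dim_head=40
    k = torch.randn(16, 77, 40)   # 77 context tokens
    print(scaled_similarity(q, k, scale=40 ** -0.5).shape)  # torch.Size([16, 64, 77])

The same diff also guards the xformers import with the XFORMERS_IS_AVAILBLE flag, so MemoryEfficientCrossAttention is only selected by BasicTransformerBlock when xformers actually imported.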