ifwi committed on
Commit
3c454af
1 Parent(s): 63a2e47
ldm/models/diffusion/ddpm.py CHANGED
@@ -47,6 +47,7 @@ def disabled_train(self, mode=True):
 def uniform_on_device(r1, r2, shape, device):
     return (r1 - r2) * torch.rand(*shape, device=device) + r2
 
+
 class DDPM(pl.LightningModule):
     # classic DDPM with Gaussian diffusion, in image space
     def __init__(self,
@@ -124,7 +125,8 @@ class DDPM(pl.LightningModule):
             self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
             if reset_ema:
                 assert self.use_ema
-                print(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
+                print(
+                    f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
                 self.model_ema = LitEma(self.model)
         if reset_num_ema_updates:
             print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
@@ -573,7 +575,7 @@ class LatentDiffusion(DDPM):
             self.scale_factor = scale_factor
         else:
             self.register_buffer('scale_factor', torch.tensor(scale_factor))
-
+
         self.instantiate_first_stage(first_stage_config)
         self.instantiate_cond_stage(cond_stage_config)
         self.cond_stage_forward = cond_stage_forward
@@ -586,7 +588,7 @@ class LatentDiffusion(DDPM):
         self.proj_out = None
         if self.use_pbe_weight:
             print("learnable vector gene")
-            self.learnable_vector = nn.Parameter(torch.randn((1,1,768)), requires_grad=True)
+            self.learnable_vector = nn.Parameter(torch.randn((1, 1, 768)), requires_grad=True)
         else:
             self.learnable_vector = None
 
@@ -608,7 +610,7 @@ class LatentDiffusion(DDPM):
             print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
             assert self.use_ema
             self.model_ema.reset_num_updates()
-
+
     def make_cond_schedule(self, ):
         self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
         ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
@@ -646,7 +648,7 @@ class LatentDiffusion(DDPM):
         self.first_stage_model.train = disabled_train
         for param in self.first_stage_model.parameters():
             param.requires_grad = False
-
+
     def instantiate_cond_stage(self, config):
         if not self.cond_stage_trainable:
             if config == "__is_first_stage__":
@@ -791,14 +793,15 @@ class LatentDiffusion(DDPM):
 
     @torch.no_grad()
     def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
-                  cond_key=None, return_original_cond=False, bs=None, return_x=False, no_latent=False, is_controlnet=False):
+                  cond_key=None, return_original_cond=False, bs=None, return_x=False, no_latent=False,
+                  is_controlnet=False):
         x = super().get_input(batch, k)
         if bs is not None:
             x = x[:bs]
         x = x.to(self.device)
         if no_latent:
-            _,_,h,w = x.shape
-            x = resize(x, (h//8, w//8))
+            _, _, h, w = x.shape
+            x = resize(x, (h // 8, w // 8))
             return [x, None]
         encoder_posterior = self.encode_first_stage(x)
         z = self.get_first_stage_encoding(encoder_posterior).detach()
@@ -815,12 +818,12 @@ class LatentDiffusion(DDPM):
                     xc = batch
                 else:
                     xc = super().get_input(batch, cond_key).to(self.device)
-            else:
+            else:
                 xc = x
             if not self.cond_stage_trainable or force_c_encode:
                 if self.kwargs["use_imageCLIP"]:
-                    xc = resize(xc, (224,224))
-                    xc = self.imagenet_norm((xc+1)/2)
+                    xc = resize(xc, (224, 224))
+                    xc = self.imagenet_norm((xc + 1) / 2)
                     c = xc
                 else:
                     if isinstance(xc, dict) or isinstance(xc, list):
@@ -830,8 +833,8 @@ class LatentDiffusion(DDPM):
                         c = c.float()
             else:
                 if self.kwargs["use_imageCLIP"]:
-                    xc = resize(xc, (224,224))
-                    xc = self.imagenet_norm((xc+1)/2)
+                    xc = resize(xc, (224, 224))
+                    xc = self.imagenet_norm((xc + 1) / 2)
                     c = xc
             if bs is not None:
                 c = c[:bs]
@@ -847,7 +850,7 @@ class LatentDiffusion(DDPM):
             if self.use_positional_encodings:
                 pos_x, pos_y = self.compute_latent_shifts(batch)
                 c = {'pos_x': pos_x, 'pos_y': pos_y}
-
+
         out = [z, c]
         if return_first_stage_outputs:
             xrec = self.decode_first_stage(z)
@@ -872,6 +875,7 @@ class LatentDiffusion(DDPM):
             return output
         else:
             return output.sample
+
     def decode_first_stage_train(self, z, predict_cids=False, force_not_quantize=False):
         if predict_cids:
             if z.dim() == 4:
@@ -905,12 +909,11 @@ class LatentDiffusion(DDPM):
         # pbe negative condition
         else:
             t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
-            self.u_cond_prop=random.uniform(0, 1)
+            self.u_cond_prop = random.uniform(0, 1)
             c["c_crossattn"] = [self.get_learned_conditioning(c["c_crossattn"])]
             if self.u_cond_prop < self.u_cond_percent:
-                c["c_crossattn"] = [self.learnable_vector.repeat(x.shape[0],1,1)]
+                c["c_crossattn"] = [self.learnable_vector.repeat(x.shape[0], 1, 1)]
             return self.p_losses(x, c, t, *args, **kwargs)
-
 
     def apply_model(self, x_noisy, t, cond, return_ids=False):
         if isinstance(cond, dict):
@@ -931,7 +934,7 @@ class LatentDiffusion(DDPM):
 
     def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
         return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
-               extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+            extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
 
     def _prior_bpd(self, x_start):
         """
@@ -946,6 +949,7 @@ class LatentDiffusion(DDPM):
         qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
         kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
         return mean_flat(kl_prior) / np.log(2.0)
+
     def p_losses(self, x_start, cond, t, noise=None):
         loss_dict = {}
         noise = default(noise, lambda: torch.randn_like(x_start))
@@ -969,11 +973,11 @@ class LatentDiffusion(DDPM):
         if self.only_agn_simple_loss:
             _, _, l_h, l_w = model_output.shape
             m_agn = F.interpolate(super().get_input(self.batch, "agn_mask"), (l_h, l_w))
-            loss_simple = self.get_loss(model_output * (1-m_agn), target * (1-m_agn), mean=False).mean([1, 2, 3])
+            loss_simple = self.get_loss(model_output * (1 - m_agn), target * (1 - m_agn), mean=False).mean([1, 2, 3])
         else:
             loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
         loss_dict.update({f'simple': loss_simple.mean()})
-
+
         logvar_t = self.logvar[t].to(self.device)
         loss = loss_simple / torch.exp(logvar_t) + logvar_t
         # loss = loss_simple / torch.exp(self.logvar) + self.logvar
@@ -981,7 +985,7 @@ class LatentDiffusion(DDPM):
             loss_dict.update({f'gamma': loss.mean()})
             loss_dict.update({'logvar': self.logvar.data.mean()})
         loss = self.l_simple_weight * loss.mean()
-
+
         loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
         loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
         if self.original_elbo_weight != 0:
@@ -990,7 +994,7 @@ class LatentDiffusion(DDPM):
 
         if model_loss is not None:
             loss += model_loss
-            loss_dict.update({f"model loss" : model_loss})
+            loss_dict.update({f"model loss": model_loss})
         loss_dict.update({f'{prefix}_loss': loss})
 
         return loss, loss_dict
@@ -1540,7 +1544,7 @@ class LatentUpscaleDiffusion(LatentDiffusion):
                     uc[k] = [uc_tmp]
                 elif k == "c_adm":  # todo: only run with text-based guidance?
                     assert isinstance(c[k], torch.Tensor)
-                    #uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level
+                    # uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level
                     uc[k] = c[k]
                 elif isinstance(c[k], list):
                     uc[k] = [c[k][i] for i in range(len(c[k]))]
@@ -1807,7 +1811,7 @@ class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion):
         log = super().log_images(*args, **kwargs)
        depth = self.depth_model(args[0][self.depth_stage_key])
         depth_min, depth_max = torch.amin(depth, dim=[1, 2, 3], keepdim=True), \
-                               torch.amax(depth, dim=[1, 2, 3], keepdim=True)
+            torch.amax(depth, dim=[1, 2, 3], keepdim=True)
         log["depth"] = 2. * (depth - depth_min) / (depth_max - depth_min) - 1.
         return log
 
@@ -1816,6 +1820,7 @@ class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
     """
     condition on low-res image (and optionally on some spatial noise augmentation)
     """
+
     def __init__(self, concat_keys=("lr",), reshuffle_patch_size=None,
                  low_scale_config=None, low_scale_key=None, *args, **kwargs):
         super().__init__(concat_keys=concat_keys, *args, **kwargs)
@@ -1872,4 +1877,4 @@ class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
     def log_images(self, *args, **kwargs):
         log = super().log_images(*args, **kwargs)
         log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w')
-        return log
+        return log
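
The `u_cond_prop` / `learnable_vector` lines touched above implement the "pbe negative condition" path: during training, with probability `u_cond_percent`, the cross-attention conditioning is replaced by a single learnable null embedding, the usual setup for classifier-free guidance at sampling time. Below is a minimal, self-contained sketch of that pattern; the class and argument names (`NullCondDropout`, `u_cond_percent`) are illustrative, not part of this repository.

import random

import torch
import torch.nn as nn


class NullCondDropout(nn.Module):
    """Replace the cross-attention context with a learnable null token some of the time."""

    def __init__(self, embed_dim=768, u_cond_percent=0.2):
        super().__init__()
        # a single learnable embedding that stands in for "no condition"
        self.learnable_vector = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.u_cond_percent = u_cond_percent

    def forward(self, cond):
        # cond: (B, N, embed_dim) conditioning tokens
        if self.training and random.uniform(0, 1) < self.u_cond_percent:
            # broadcast the null token over the batch, as in the diff above
            return self.learnable_vector.repeat(cond.shape[0], 1, 1)
        return cond


# usage: drop = NullCondDropout(); c = drop(cond_tokens)  # returns (B, 1, 768) when the condition is dropped
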
ldm/modules/attention.py CHANGED
@@ -12,20 +12,23 @@ from ldm.modules.diffusionmodules.util import checkpoint
 try:
     import xformers
     import xformers.ops
+
     XFORMERS_IS_AVAILBLE = True
 except:
     XFORMERS_IS_AVAILBLE = False
 
 # CrossAttn precision handling
 import os
+
 _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
 
+
 def exists(val):
     return val is not None
 
 
 def uniq(arr):
-    return{el: True for el in arr}.keys()
+    return {el: True for el in arr}.keys()
 
 
 def default(val, d):
@@ -33,6 +36,7 @@ def default(val, d):
         return val
     return d() if isfunction(d) else d
 
+
 class GEGLU(nn.Module):
     def __init__(self, dim_in, dim_out):
         super().__init__()
@@ -110,12 +114,12 @@ class SpatialSelfAttention(nn.Module):
         k = self.k(h_)
         v = self.v(h_)
 
-        b,c,h,w = q.shape
+        b, c, h, w = q.shape
         q = rearrange(q, 'b c h w -> b (h w) c')
         k = rearrange(k, 'b c h w -> b c (h w)')
         w_ = torch.einsum('bij,bjk->bik', q, k)
 
-        w_ = w_ * (int(c)**(-0.5))
+        w_ = w_ * (int(c) ** (-0.5))
         w_ = torch.nn.functional.softmax(w_, dim=2)
 
         v = rearrange(v, 'b c h w -> b c (h w)')
@@ -124,7 +128,8 @@ class SpatialSelfAttention(nn.Module):
         h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
         h_ = self.proj_out(h_)
 
-        return x+h_
+        return x + h_
+
 
 class CrossAttention(nn.Module):
     def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., **kwargs):
@@ -143,7 +148,6 @@ class CrossAttention(nn.Module):
             nn.Linear(inner_dim, query_dim),
             nn.Dropout(dropout)
         )
-
 
     def forward(self, x, context=None, mask=None):
         h = self.heads
@@ -153,26 +157,27 @@ class CrossAttention(nn.Module):
         v = self.to_v(context)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
 
-        if _ATTN_PRECISION =="fp32":
-            with torch.autocast(enabled=False, device_type = 'cuda'):
+        if _ATTN_PRECISION == "fp32":
+            with torch.autocast(enabled=False, device_type='cuda'):
                 q, k = q.float(), k.float()
                 sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
         else:
             sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
-
+
         del q, k
         if exists(mask):
             mask = rearrange(mask, 'b ... -> b (...)')
             max_neg_value = -torch.finfo(sim.dtype).max
             mask = repeat(mask, 'b j -> (b h) () j', h=h)
             sim.masked_fill_(~mask, max_neg_value)
-
-        sim = sim.softmax(dim=-1)
-
+
+        sim = sim.softmax(dim=-1)
+
         out = einsum('b i j, b j d -> b i d', sim, v)
         out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
         return self.to_out(out)
 
+
 class MemoryEfficientCrossAttention(nn.Module):
     # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
     def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, zero_init=False, **kwargs):
@@ -195,7 +200,6 @@ class MemoryEfficientCrossAttention(nn.Module):
 
         self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
         self.attention_op: Optional[Any] = None
-
 
     def forward(self, x, context=None, mask=None, **kwargs):
         q = self.to_q(x)
@@ -221,23 +225,25 @@ class MemoryEfficientCrossAttention(nn.Module):
             .reshape(b, out.shape[1], self.heads * self.dim_head)
         )
         return self.to_out(out)
-
+
+
 class BasicTransformerBlock(nn.Module):
     ATTENTION_MODES = {
         "softmax": CrossAttention,  # vanilla attention
         "softmax-xformers": MemoryEfficientCrossAttention
     }
+
     def __init__(
-        self,
-        dim,
-        n_heads,
-        d_head,
-        dropout=0.,
-        context_dim=None,
-        gated_ff=True,
+            self,
+            dim,
+            n_heads,
+            d_head,
+            dropout=0.,
+            context_dim=None,
+            gated_ff=True,
             checkpoint=True,
             disable_self_attn=False
-        ):
+    ):
         super().__init__()
         attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
         assert attn_mode in self.ATTENTION_MODES
@@ -247,24 +253,25 @@ class BasicTransformerBlock(nn.Module):
                               context_dim=context_dim if self.disable_self_attn else None)
         self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
         self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
-                              heads=n_heads, dim_head=d_head, dropout=dropout)
+                              heads=n_heads, dim_head=d_head, dropout=dropout)
         self.norm1 = nn.LayerNorm(dim)
         self.norm2 = nn.LayerNorm(dim)
         self.norm3 = nn.LayerNorm(dim)
         self.checkpoint = checkpoint
 
-    def forward(self, x, context=None,hint=None):
+    def forward(self, x, context=None, hint=None):
         if hint is None:
             return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
         else:
             return checkpoint(self._forward, (x, context, hint), self.parameters(), self.checkpoint)
 
-    def _forward(self, x, context=None,hint=None):
-        x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None,hint=hint) + x
+    def _forward(self, x, context=None, hint=None):
+        x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None, hint=hint) + x
         x = self.attn2(self.norm2(x), context=context) + x
         x = self.ff(self.norm3(x)) + x
         return x
 
+
 class SpatialTransformer(nn.Module):
     """
     Transformer block for image-like data.
@@ -274,6 +281,7 @@ class SpatialTransformer(nn.Module):
     Finally, reshape to image
     NEW: use_linear for more efficiency instead of the 1x1 convs
     """
+
     def __init__(self, in_channels, n_heads, d_head,
                  depth=1, dropout=0., context_dim=None,
                  disable_self_attn=False, use_linear=False,
@@ -296,7 +304,7 @@ class SpatialTransformer(nn.Module):
         self.transformer_blocks = nn.ModuleList(
             [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
                                    disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
-                for d in range(depth)]
+             for d in range(depth)]
         )
         if not use_linear:
             self.proj_out = zero_module(nn.Conv2d(inner_dim,
@@ -308,7 +316,7 @@ class SpatialTransformer(nn.Module):
             self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
         self.use_linear = use_linear
 
-    def forward(self, x, context=None,hint=None):
+    def forward(self, x, context=None, hint=None):
         # note: if no context is given, cross-attention defaults to self-attention
         if not isinstance(context, list):
             context = [context]
@@ -321,10 +329,10 @@ class SpatialTransformer(nn.Module):
         if self.use_linear:
             x = self.proj_in(x)
         for i, block in enumerate(self.transformer_blocks):
-            x = block(x, context=context[i],hint=hint)
+            x = block(x, context=context[i])
         if self.use_linear:
             x = self.proj_out(x)
         x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
         if not self.use_linear:
             x = self.proj_out(x)
-        return x + x_in
+        return x + x_in
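
The `_ATTN_PRECISION` branch reformatted above in `CrossAttention.forward` computes the q·kᵀ similarity in float32 with autocast disabled whenever `ATTN_PRECISION=fp32`, which guards the attention logits against fp16 overflow. Below is a standalone sketch of that guard under the same environment variable; the `attention_scores` helper name and the extra `is_cuda` check are illustrative additions, not part of this file.

import os

import torch
from torch import einsum

_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")


def attention_scores(q, k, scale):
    # q, k: (B*heads, N, d_head), as produced by the rearrange in CrossAttention.forward
    if _ATTN_PRECISION == "fp32" and q.is_cuda:
        # step out of any surrounding fp16 autocast region and do the matmul in float32
        with torch.autocast(enabled=False, device_type='cuda'):
            return einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
    return einsum('b i d, b j d -> b i j', q, k) * scale


# usage: sim = attention_scores(q, k, scale=64 ** -0.5); attn = sim.softmax(dim=-1)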