sgoel30 committed on
Commit
b6a71c9
·
verified ·
1 Parent(s): d6c63a1

Upload diffusion.py

Browse files
Files changed (1) hide show
  1. scripts/diffusion.py +8 -6
scripts/diffusion.py CHANGED
@@ -110,6 +110,7 @@ class Diffusion(L.LightningModule):
110
 
111
  ############ FORWARD DIFFUSION #########
112
  def subs_parameterization(self, logits, noised_latents):
 
113
  logits = logits.float()
114
  logits[:, :, self.mask_index] += self.neg_infinity
115
 
@@ -147,7 +148,7 @@ class Diffusion(L.LightningModule):
147
  x: int torch.Tensor with shape (batch_size, diffusion_model_input_length), input.
148
  move_chance: float torch.Tensor with shape (batch_size, 1).
149
  """
150
- #latents = latents.mean(dim=1) # [bsz x seq_len x 1280] --> [bsz x 1280] as per markdown
151
  move_indices = torch.rand(* latents.shape, device=latents.device) < move_chance
152
  noised_latents = torch.where(move_indices, self.mask_index, latents)
153
  return noised_latents
@@ -172,13 +173,14 @@ class Diffusion(L.LightningModule):
172
 
173
  xt = self.q_xt(x0, move_chance)
174
  model_output = self.forward(xt, unet_conditioning)
 
 
175
 
176
  # SUBS parameterization, continuous time.
177
- idx = x0.long()
178
- print(f'idx: {idx.size()}')
179
- print(f'idx min: {idx.min()}')
180
- print(f'idx max: {idx.max()}')
181
- print(f'model out: {model_output.size()}')
182
  log_p_theta = torch.gather(input=model_output, dim=-1, index=idx).squeeze(-1)
183
  scale = (dsigma / torch.expm1(sigma))[:, None]
184
  return - log_p_theta * scale
 
110
 
111
  ############ FORWARD DIFFUSION #########
112
  def subs_parameterization(self, logits, noised_latents):
113
+ print(logits.size()) # [bsz x bsz x seq_len]
114
  logits = logits.float()
115
  logits[:, :, self.mask_index] += self.neg_infinity
116
 
 
148
  x: int torch.Tensor with shape (batch_size, diffusion_model_input_length), input.
149
  move_chance: float torch.Tensor with shape (batch_size, 1).
150
  """
151
+ latents = torch.mean(latents, dim=2) # [bsz x seq_len x 1280] --> [bsz x seq_len] as per markdown
152
  move_indices = torch.rand(* latents.shape, device=latents.device) < move_chance
153
  noised_latents = torch.where(move_indices, self.mask_index, latents)
154
  return noised_latents
 
173
 
174
  xt = self.q_xt(x0, move_chance)
175
  model_output = self.forward(xt, unet_conditioning)
176
+ print(f'model out: {model_output}')
177
+ print(f'model out dim: {model_output.size()}') # [bsz x bsz x seq_len]
178
 
179
  # SUBS parameterization, continuous time.
180
+ idx = torch.mean(x0, dim=2).long()[:, :, None]
181
+ print(f'idx: {idx}')
182
+ print(f'idx dim: {idx.size()}') # [bsz x seq_len x 1]
183
+
 
184
  log_p_theta = torch.gather(input=model_output, dim=-1, index=idx).squeeze(-1)
185
  scale = (dsigma / torch.expm1(sigma))[:, None]
186
  return - log_p_theta * scale