Upload diffusion.py
scripts/diffusion.py (+8, -6) CHANGED
@@ -110,6 +110,7 @@ class Diffusion(L.LightningModule):
 
     ############ FORWARD DIFFUSION #########
     def subs_parameterization(self, logits, noised_latents):
+        print(logits.size()) # [bsz x bsz x seq_len]
         logits = logits.float()
         logits[:, :, self.mask_index] += self.neg_infinity
 
@@ -147,7 +148,7 @@ class Diffusion(L.LightningModule):
         x: int torch.Tensor with shape (batch_size, diffusion_model_input_length), input.
         move_chance: float torch.Tensor with shape (batch_size, 1).
         """
-
+        latents = torch.mean(latents, dim=2) # [bsz x seq_len x 1280] --> [bsz x seq_len] as per markdown
         move_indices = torch.rand(* latents.shape, device=latents.device) < move_chance
         noised_latents = torch.where(move_indices, self.mask_index, latents)
         return noised_latents
@@ -172,13 +173,14 @@ class Diffusion(L.LightningModule):
 
         xt = self.q_xt(x0, move_chance)
         model_output = self.forward(xt, unet_conditioning)
+        print(f'model out: {model_output}')
+        print(f'model out dim: {model_output.size()}') # [bsz x bsz x seq_len]
 
         # SUBS parameterization, continuous time.
-        idx = x0.long()
-        print(f'idx: {idx}')
-        print(f'idx dim: {idx.size()}')
-
-        print(f'model out: {model_output.size()}')
+        idx = torch.mean(x0, dim=2).long()[:, :, None]
+        print(f'idx: {idx}')
+        print(f'idx dim: {idx.size()}') # [bsz x seq_len x 1]
+
         log_p_theta = torch.gather(input=model_output, dim=-1, index=idx).squeeze(-1)
         scale = (dsigma / torch.expm1(sigma))[:, None]
         return - log_p_theta * scale
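For anyone following along, the shape bookkeeping these debug prints are chasing can be reproduced in isolation. Below is a minimal standalone sketch, not part of this repo: bsz, seq_len, latent_dim, and vocab_size are made-up sizes, and a random log-softmax tensor stands in for the model's output. It shows why q_xt first mean-pools the [bsz, seq_len, 1280] latents down to [bsz, seq_len], and why idx needs the trailing [:, :, None] before torch.gather can pick out per-token log-probabilities.

# Standalone shape check. Assumes (unlike the printed [bsz x bsz x seq_len])
# that model_output is [bsz, seq_len, vocab_size]; all sizes are made up.
import torch

bsz, seq_len, latent_dim, vocab_size = 2, 16, 1280, 512
mask_index = vocab_size - 1

x0 = torch.randint(0, mask_index, (bsz, seq_len, latent_dim)).float()

# q_xt: collapse the latent dimension, then mask each position with prob move_chance.
latents = torch.mean(x0, dim=2)                      # [bsz, seq_len]
move_chance = torch.full((bsz, 1), 0.3)
move_indices = torch.rand(*latents.shape) < move_chance
xt = torch.where(move_indices, mask_index, latents)  # [bsz, seq_len]

# Loss side: gather needs idx to match model_output's rank, hence [:, :, None].
model_output = torch.randn(bsz, seq_len, vocab_size).log_softmax(dim=-1)
idx = torch.mean(x0, dim=2).long()[:, :, None]       # [bsz, seq_len, 1]
log_p_theta = torch.gather(model_output, dim=-1, index=idx).squeeze(-1)
print(log_p_theta.shape)                             # torch.Size([2, 16])

Note that the gather only lines up if the model output's last dimension is the vocabulary; the commit's own comment records model_output as [bsz x bsz x seq_len], which is presumably what these prints are meant to diagnose.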