sgoel30 committed
Commit 9875b25 · verified · 1 Parent(s): 3340437

Upload 3 files

Files changed (2):
  1. models/diffusion.py +246 -22
  2. models/dit.py +3 -4
models/diffusion.py CHANGED
@@ -1,11 +1,11 @@
 import itertools
 import math
 import torch
-import torch.nn.functional as F
+import numpy as np
 import pytorch_lightning as L
 import torchmetrics
 from dataclasses import dataclass
-from models import dit, ema
+import dit, ema
 import noise_schedule  # Assuming this is part of the MDLM repository
 
 LOG2 = math.log(2)
@@ -22,7 +22,6 @@ class NLL(torchmetrics.MeanMetric):
 class BPD(NLL):
   def compute(self) -> torch.Tensor:
     """Computes the bits per dimension.
-
     Returns:
       bpd
     """
@@ -31,21 +30,24 @@ class BPD(NLL):
 class Perplexity(NLL):
   def compute(self) -> torch.Tensor:
     """Computes the Perplexity.
-
     Returns:
       Perplexity
     """
     return torch.exp(self.mean_value / self.weight)
 
 
+# Based on MDLM repo
 class Diffusion(L.LightningModule):
-  def __init__(self, config, latent_dim):
+  def __init__(self, config, latent_dim, tokenizer):
     super().__init__()
     self.config = config
     self.latent_dim = latent_dim
+    self.tokenizer = tokenizer
 
-    self.backbone = dit.DIT(config, vocab_size=self.latent_dim)
+    self.backbone = dit.DIT(self.config, vocab_size=self.latent_dim)
     self.T = self.config.T
-    self.subs_masking = self.config.subs_masking
+    self.subs_masking = self.config.SUBS_MASKING
+    self.antithetic_sampling = self.config.Training.ANTITHETIC_SAMPLING
+    self.mask_index = self.tokenizer.mask_token_id
 
     self.softplus = torch.nn.Softplus()
     metrics = torchmetrics.MetricCollection({
@@ -59,30 +61,252 @@ class Diffusion(L.LightningModule):
     self.test_metrics = metrics.clone(prefix='test/')
 
     self.noise = noise_schedule.get_noise(self.config, dtype=self.dtype)
-    self.lr = self.config.optim["lr"]
-    self.sampling_eps = self.config.training.get("sampling_eps", 1e-5)
-    self.time_conditioning = self.config.get("time_conditioning", True)
+    self.lr = self.config.Optim.LR
+    self.sampling_eps = self.config.Training.SAMPLING_EPS
+    self.time_conditioning = self.config.TIME_CONDITIONING
     self.neg_infinity = -1000000.0
 
+
+  ############ FORWARD DIFFUSION #########
+  def subs_parameterization(self, logits, noised_latents):
+    # log prob at the mask index = - infinity
+    logits[:, :, self.mask_index] += self.neg_infinity
+
+    # Normalize the logits such that x.exp() is
+    # a probability distribution over vocab_size.
+    logits = logits - torch.logsumexp(logits, dim=-1,
+                                      keepdim=True)
+
+    # Apply updates directly in the logits matrix.
+    # For the logits of the unmasked tokens, set all values
+    # to -infinity except for the indices corresponding to
+    # the unmasked tokens.
+    unmasked_indices = (noised_latents != self.mask_index)
+    logits[unmasked_indices] = self.neg_infinity
+    logits[unmasked_indices, noised_latents[unmasked_indices]] = 0
+    return logits
+
   def forward(self, latents, sigma):
+    latents = latents.long()
+    with torch.cuda.amp.autocast(dtype=torch.float32):
+      logits = self.backbone(latents, sigma)
+    print(logits)
+    optimized_logits = self.subs_parameterization(logits, latents)
+    return optimized_logits
+
+  def q_xt(self, latents, move_chance):
+    """
+    Computes the noisy sample xt.
+    Args:
+      x: int torch.Tensor with shape (batch_size, diffusion_model_input_length), input.
+      move_chance: float torch.Tensor with shape (batch_size, 1).
+    """
+    latents = latents.mean(dim=1)  # [bsz x seq_len x 1280] --> [bsz x 1280] as per args
+    move_indices = torch.rand(*latents.shape, device=latents.device) < move_chance
+    noised_latents = torch.where(move_indices, self.mask_index, latents)
+    return noised_latents
+
+  def sample_timestep(self, n, device):
+    _eps_t = torch.rand(n, device=device)
+    if self.antithetic_sampling:
+      offset = torch.arange(n, device=device) / n
+      _eps_t = (_eps_t / n + offset) % 1
+    t = (1 - self.sampling_eps) * _eps_t + self.sampling_eps
+    # if self.importance_sampling:
+    #   return self.noise.importance_sampling_transformation(t)
+    return t
+
+  def d3pm_loss(self, model_output, xt, x0, t):
+    """Computes the D3PM loss between noisy latents and the original input at a given time step."""
+    dt = 1 / self.T
+
+    if torch.is_tensor(t):
+      t = t[:, None]
+      assert t.ndim == 2
+      t = t.clamp(0., 1. - 1e-4)
+    alpha_t = 1 - t + torch.zeros_like(xt)
+    alpha_s = 1 - (t - dt) + torch.zeros_like(xt)
+
+    x0 = x0.to(torch.int64)
+    log_x_theta_at_x0 = torch.gather(model_output, -1, x0[:, :, None]).squeeze(-1)
+    log_x_theta_at_m = model_output[:, :, self.mask_index]
+    x_theta_at_m = log_x_theta_at_m.exp()
+
+    term_1_coef = dt / t
+    term_1_log_nr = torch.log(alpha_t * x_theta_at_m / t + 1)
+    term_1_log_dr = log_x_theta_at_x0
+
+    term_2_coef = 1 - dt / t
+    term_2_log_nr = term_1_log_nr
+    term_2_log_dr = torch.log(alpha_s * x_theta_at_m / (t - dt) + 1)
+
+    L_vb_masked = (
+        term_1_coef * (term_1_log_nr - term_1_log_dr)
+        + term_2_coef * (term_2_log_nr - term_2_log_dr))
+
+    L_vb = L_vb_masked * (xt == self.mask_index)
+
+    return self.T * L_vb
+
+  def forward_diffusion(self, latents):
     """Forward diffusion process, adds noise to the latents."""
-    noise = sigma * torch.randn_like(latents)
-    noisy_latents = latents + noise
-    return noisy_latents
-
-  def reverse_diffusion(self, noisy_latents, sigma):
-    """Reverse diffusion process, denoises the latents."""
-    denoised_latents = self.backbone(noisy_latents, sigma)
-    return denoised_latents
+    t = self.sample_timestep(latents.shape[0], latents.device)
+    if self.T > 0:
+      t = (t * self.T).to(torch.int)
+      t = t / self.T
+      # t \in {1/T, 2/T, ..., 1}
+      t += (1 / self.T)
+
+    sigma, dsigma = self.noise(t)
+    unet_conditioning = sigma[:, None]
+    move_chance = 1 - torch.exp(-sigma[:, None])
+
+    noised_latents = self.q_xt(latents, move_chance)
+    model_output = self.forward(noised_latents, unet_conditioning)
+
+    if self.T > 0:
+      diffusion_loss = self.d3pm_loss(model_output=model_output, xt=noised_latents, x0=latents, t=t)
+      return diffusion_loss
+    # SUBS parameterization, continuous time.
+    else:
+      log_p_theta = torch.gather(input=model_output, dim=-1, index=latents[:, :, None]).squeeze(-1)
+      return - log_p_theta * (dsigma / torch.expm1(sigma))[:, None]
 
+  ######### LOSS CALCULATIONS #########
+  def maybe_sub_sample(self, x0, attention_mask):
+    # seqlen = x0.shape[1]
+    # print(seqlen)
+    # if seqlen > self.config.model.length:
+    #   assert seqlen == 2 * self.config.model.length
+    #   # cropping is needed for text8-crop dataset
+    #   # try the same starting point for now
+    #   start = np.random.choice(self.config.model.length)
+    #   end = start + self.config.model.length
+    #   input_tokens = x0[:, start: end]
+    #   output_tokens = x0[:, start + 1: end + 1]
+    #   new_attention_mask = attention_mask[:, start: end]
+
+    #   # Helps with validation PPL, since the val
+    #   # examples will all start and end with BOS/EOS
+    #   input_tokens[:, 0] = self.tokenizer.bos_token_id
+    #   output_tokens[:, -1] = self.tokenizer.eos_token_id
+
+    # elif self.parameterization == 'ar':
+    #   input_tokens = x0[:, :-1]
+    #   output_tokens = x0[:, 1:]
+    #   new_attention_mask = attention_mask[:, 1:]
+    # else:
+    input_tokens = x0
+    output_tokens = None
+    new_attention_mask = attention_mask
+
+    return input_tokens, output_tokens, new_attention_mask
+
+  def compute_loss(self, latents, attention_mask):
+    """Average of MLM losses to stabilize training"""
+    (input_tokens, output_tokens, attention_mask) = self.maybe_sub_sample(latents, attention_mask)
+    loss = self.forward_diffusion(input_tokens)
+
+    nlls = loss * attention_mask
+    count = attention_mask.sum()
+    batch_nll = nlls.sum()
+    token_nll = batch_nll / count
+
+    return Loss(loss=token_nll, nlls=nlls, token_mask=attention_mask)
+
+
+  ######### TRAINING #########
   def training_step(self, batch, batch_idx):
-    sigma = torch.rand(batch.size(0), device=self.device)
-    noisy_latents = self.forward(batch, sigma)
-    denoised_latents = self.reverse_diffusion(noisy_latents, sigma)
-    loss = F.mse_loss(denoised_latents, batch)
-    self.log("train_loss", loss)
+    latents, attention_mask = batch
+    loss = self.compute_loss(latents, attention_mask)
     return loss
 
   def configure_optimizers(self):
     optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
     return optimizer
+
+  def validation_step(self, batch):
+    latents, attention_mask = batch
+    loss = self.compute_loss(latents, attention_mask)
+    return loss
+
+
+  ######### GENERATION #########
+  def sample_prior(self, *batch_dims):
+    return self.mask_index * torch.ones(*batch_dims, dtype=torch.int64)
+
+  def sample_categorical(categorical_probs):
+    gumbel_norm = (1e-10 - (torch.rand_like(categorical_probs) + 1e-10).log())
+    return (categorical_probs / gumbel_norm).argmax(dim=-1)
+
+  def ddpm_caching_update(self, x, t, dt, p_x0=None):
+    assert self.config.noise.type == 'loglinear'
+    sigma_t, _ = self.noise(t)
+    if t.ndim > 1:
+      t = t.squeeze(-1)
+    assert t.ndim == 1
+    move_chance_t = t[:, None, None]
+    move_chance_s = (t - dt)[:, None, None]
+    assert move_chance_t.ndim == 3, move_chance_t.shape
+    if p_x0 is None:
+      p_x0 = self.forward(x, sigma_t).exp()
+
+    assert move_chance_t.ndim == p_x0.ndim
+    q_xs = p_x0 * (move_chance_t - move_chance_s)
+    q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
+    _x = self.sample_categorical(q_xs)
+
+    copy_flag = (x != self.mask_index).to(x.dtype)
+    return p_x0, copy_flag * x + (1 - copy_flag) * _x
+
+
+  @torch.no_grad()
+  def sample_subs_guidance(self, n_samples, stride_length, num_strides, dt=0.001):
+    ones = torch.ones(n_samples, dtype=self.dtype, device=self.device)
+    num_steps = int(1 / dt)
+    sampling_steps = 0
+    intermediate_tokens = []
+    target = None
+
+    for _ in range(num_strides + 1):
+      p_x0_cache = None
+      x = self._sample_prior(n_samples, self.config.model.length).to(self.device)
+
+      if target is not None:
+        x[:, :-stride_length] = target
+
+      for i in range(num_steps + 1):
+        p_x0_cache, x_next = self.ddpm_caching_update(x=x, t=(1 - i * dt) * ones, dt=dt, p_x0=p_x0_cache)
+        if (not torch.allclose(x_next, x) or self.time_conditioning):
+          p_x0_cache = None
+          sampling_steps += 1
+        x = x_next
+      x = self.forward(x, 0 * ones).argmax(dim=-1)
+      intermediate_tokens.append(x[:, :stride_length].cpu().numpy())
+      target = x[:, stride_length:]
+
+    intermediate_tokens.append(target.cpu().numpy())
+    intermediate_text_samples = []
+    sequence_lengths = ((np.concatenate(intermediate_tokens, axis=1)[:, 1:]
+                         == self.tokenizer.eos_token_id).cumsum(-1) == 0).sum(-1)
+
+    for i in range(2, len(intermediate_tokens) + 1):
+      intermediate_text_samples.append(self.tokenizer.decode(np.concatenate(intermediate_tokens[:i], axis=1)))
+
+    return (sampling_steps, intermediate_text_samples,
+            sequence_lengths)
+
+  def restore_model_and_semi_ar_sample(self, stride_length, num_strides, dt=0.001):
+    """Generate samples from the model."""
+    # Lightning auto-casting is not working in this method for some reason
+    self.backbone.eval()
+    self.noise.eval()
+
+    (sampling_steps, samples, sequence_lengths) = self.sample_subs_guidance(n_samples=self.config.Loader.BATCH_SIZE, stride_length=stride_length, num_strides=num_strides, dt=dt)
+
+    self.backbone.train()
+    self.noise.train()
+    return sampling_steps, samples, sequence_lengths
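
Note (editorial illustration, not part of the commit): the sketch below shows, with plain torch and made-up sizes, the two pieces of the forward process added above. `MASK_INDEX`, `bsz`, `seq_len`, and `eps` are placeholder values rather than names from the repository; the logic mirrors the absorbing-state corruption in `q_xt()` and the stratified ("antithetic") timestep draw in `sample_timestep()`.

```python
import torch

# Placeholder sizes and mask id, for illustration only.
MASK_INDEX = 4
bsz, seq_len, eps = 2, 8, 1e-3

# Absorbing-state corruption, as in q_xt(): each position is independently
# replaced by the mask id with probability move_chance; the (bsz, 1) tensor
# broadcasts over the sequence dimension.
x0 = torch.randint(0, MASK_INDEX, (bsz, seq_len))
move_chance = torch.rand(bsz, 1)
move_indices = torch.rand(bsz, seq_len) < move_chance
xt = torch.where(move_indices, torch.full_like(x0, MASK_INDEX), x0)

# Antithetic / stratified timestep draw, as in sample_timestep(): one uniform
# sample per 1/n bin, then squeezed into [eps, 1] so t never hits exactly 0.
n = bsz
u = torch.rand(n)
u = (u / n + torch.arange(n) / n) % 1
t = (1 - eps) * u + eps
```

At t near 1 almost every position is masked, which is why `subs_parameterization()` above then pushes the model's probability mass off the mask token and pins already-unmasked positions to their observed ids.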
models/dit.py CHANGED
@@ -246,8 +246,7 @@ class DDiTBlock(nn.Module):
 
     bias_dropout_scale_fn = self._get_bias_dropout_scale()
 
-    (shift_msa, scale_msa, gate_msa, shift_mlp,
-     scale_mlp, gate_mlp) = self.adaLN_modulation(c)[:, None].chunk(6, dim=2)
+    (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp) = self.adaLN_modulation(c)[:, None][0].chunk(6, dim=2)
 
     # attention operation
     x_skip = x
@@ -315,7 +314,7 @@ class DDitFinalLayer(nn.Module):
 
 
   def forward(self, x, c):
-    shift, scale = self.adaLN_modulation(c)[:, None].chunk(2, dim=2)
+    shift, scale = self.adaLN_modulation(c)[:, None][0].chunk(2, dim=2)
     x = modulate_fused(self.norm_final(x), shift, scale)
     x = self.linear(x)
     return x
@@ -348,7 +347,7 @@ class DIT(nn.Module, huggingface_hub.PyTorchModelHubMixin):
         config.model.hidden_size,
         vocab_size,
         config.model.cond_dim)
-    self.scale_by_sigma = config.model.scale_by_sigma
+    #self.scale_by_sigma = config.model.scale_by_sigma
 
   def _get_bias_dropout_scale(self):
     if self.training:
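
Note (editorial illustration, not part of the commit): the dit.py change adds an extra `[0]` index on the adaLN output before chunking, presumably because the conditioning tensor carries an extra leading axis in this setup. For reference, the sketch below shows the shape pattern the original chunking relies on; `B`, `H`, and `seq_len` are made-up sizes.

```python
import torch

B, H, seq_len = 2, 16, 10
adaLN_out = torch.randn(B, 6 * H)            # stand-in for adaLN_modulation(c)

# Insert a length-1 axis so each chunk broadcasts over the sequence dimension.
chunks = adaLN_out[:, None].chunk(6, dim=2)  # six tensors of shape (B, 1, H)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = chunks

x = torch.randn(B, seq_len, H)
# adaLN-style modulation: scale and shift the (normalized) activations.
x_mod = x * (1 + scale_msa) + shift_msa      # broadcasts to (B, seq_len, H)
```

Each chunk keeps a length-1 sequence axis so it broadcasts against activations of shape (B, seq_len, H) inside `modulate_fused`.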