sgoel30 committed on
Commit 60ee22e · verified · 1 Parent(s): 3c395e0

Upload 2 files

Files changed (2):
  1. scripts/diffusion.py +293 -0
  2. scripts/train.py +3 -1
scripts/diffusion.py ADDED
@@ -0,0 +1,293 @@
+ import itertools
+ import math
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ import pytorch_lightning as L
+ import torchmetrics
+ from dataclasses import dataclass
+ from esm_utils import load_esm2_model
+ from transformers import AutoModel, AutoTokenizer
+ import dit, ema
+ import sys
+ import config
+ import wandb
+ import noise_schedule  # Assuming this is part of the MDLM repository
+
+ # NOTE: prefer reading the API key from an environment variable instead of hardcoding it.
+ wandb_key = "2b76a2fa2c1cdfddc5f443602c17b011fefb0a8f"
+ wandb.login(key=wandb_key)
+ wandb.init(project=config.Wandb.PROJECT, group=config.Wandb.GROUP)
+
+ LOG2 = math.log(2)
+
+ # Goal is to build an MDLM head on the BERT-style ESM model.
+ # Wrap the ESM model to obtain logits and ignore sigma so it works with the MDLM codebase.
+ class WrapESM(nn.Module):
+     def __init__(self, esm_model_path):
+         super(WrapESM, self).__init__()
+         self.esm_tokenizer, self.esm_model, _ = load_esm2_model(esm_model_path)
+
+         ### Only fine-tune the last config.ESM_LAYERS layers of ESM
+         # Count the number of encoder layers
+         model_layers = len(self.esm_model.esm.encoder.layer)
+
+         # Disable parameter updates for all layers
+         for param in self.esm_model.parameters():
+             param.requires_grad = False
+
+         # Now that all parameters are frozen, re-enable updates only for the
+         # key/query/value projections of the last config.ESM_LAYERS layers.
+         for i, layer in enumerate(self.esm_model.esm.encoder.layer):
+             if i >= model_layers - config.ESM_LAYERS:
+                 for module in (layer.attention.self.key,
+                                layer.attention.self.query,
+                                layer.attention.self.value):
+                     for param in module.parameters():
+                         param.requires_grad = True
+
+     def forward(self, latents, sigma):
+         # sigma is accepted only for call-compatibility with MDLM backbones; ESM is not time-conditioned.
+         # NOTE (assumption): load_esm2_model is taken to return a masked-LM ESM model
+         # (e.g. EsmForMaskedLM), so a forward pass on token ids yields per-token vocabulary logits.
+         return self.esm_model(input_ids=latents).logits
+
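+ # The (latents, sigma) signature mirrors the sigma-conditioned MDLM backbones so the wrapper
+ # can be dropped into the Diffusion module unchanged; ESM itself never consumes sigma.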
+ @dataclass
+ class Loss:
+     loss: torch.FloatTensor
+     nlls: torch.FloatTensor
+     token_mask: torch.FloatTensor
+
+ class NLL(torchmetrics.MeanMetric):
+     pass
+
+ class BPD(NLL):
+     def compute(self) -> torch.Tensor:
+         """Computes the bits per dimension.
+         Returns:
+             bpd
+         """
+         return self.mean_value / self.weight / LOG2
+
+ class Perplexity(NLL):
+     def compute(self) -> torch.Tensor:
+         """Computes the perplexity.
+         Returns:
+             Perplexity
+         """
+         return torch.exp(self.mean_value / self.weight)
+
+
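+ # Both metrics reuse MeanMetric's running accumulators (mean_value / weight): BPD reports the
+ # mean NLL in bits (nats divided by LOG2) and Perplexity exponentiates the same mean NLL.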
+ # Based on the MDLM repo
+ class Diffusion(L.LightningModule):
+     def __init__(self, config, latent_dim, tokenizer):
+         super().__init__()
+         self.config = config
+         self.latent_dim = latent_dim
+         self.tokenizer = tokenizer
+
+         self.softplus = torch.nn.Softplus()
+         metrics = torchmetrics.MetricCollection({
+             'nll': NLL(),
+             'bpd': BPD(),
+             'ppl': Perplexity(),
+         })
+         metrics.set_dtype(torch.float64)
+         self.train_metrics = metrics.clone(prefix='train/')
+         self.valid_metrics = metrics.clone(prefix='val/')
+         self.test_metrics = metrics.clone(prefix='test/')
+
+         self.T = self.config.T
+         self.lr = self.config.Optim.LR
+         self.backbone = WrapESM(self.config.MODEL_NAME)
+         self.noise = noise_schedule.get_noise(self.config, dtype=self.dtype)
+         self.time_conditioning = self.config.TIME_CONDITIONING
+         self.subs_masking = self.config.SUBS_MASKING
+         self.mask_index = self.tokenizer.mask_token_id
+         self.antithetic_sampling = self.config.Training.ANTITHETIC_SAMPLING
+         self.sampling_eps = self.config.Training.SAMPLING_EPS
+         self.neg_infinity = -1000000.0
+
+
+     ############ FORWARD DIFFUSION #########
+     def subs_parameterization(self, logits, noised_latents):
+         logits = logits.float()
+         # Give the mask token zero probability.
+         logits[:, :, self.mask_index] += self.neg_infinity
+
+         # Normalize the logits such that logits.exp() is a probability distribution over vocab_size.
+         logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
+
+         # Carry over unmasked tokens: force probability 1 (log-prob 0) on the observed token id.
+         unmasked_indices = (noised_latents != self.mask_index)
+         logits[unmasked_indices] = self.neg_infinity
+         logits[unmasked_indices, noised_latents[unmasked_indices]] = 0
+
+         return logits
+
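+     # SUBS in short: masked positions keep the model's re-normalized distribution (with the
+     # mask token excluded), while already-unmasked positions become a point mass on their
+     # observed token id.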
+     def forward(self, latents, sigma):
+         latents = latents.long()
+         logits = self.backbone(latents, sigma)
+         optimized_logits = self.subs_parameterization(logits, latents)
+         return optimized_logits
+
+     def q_xt(self, latents, move_chance):
+         """Computes the noisy sample xt by randomly masking input tokens.
+         Args:
+             latents: int torch.Tensor with shape (batch_size, diffusion_model_input_length), input token ids.
+             move_chance: float torch.Tensor with shape (batch_size, 1), per-sample masking probability.
+         """
+         # latents = latents.mean(dim=1)  # [bsz x seq_len x 1280] --> [bsz x 1280] as per markdown
+         move_indices = torch.rand(*latents.shape, device=latents.device) < move_chance
+         noised_latents = torch.where(move_indices, self.mask_index, latents)
+         return noised_latents
+
+     def sample_timestep(self, n, device):
+         _eps_t = torch.rand(n, device=device)
+         if self.antithetic_sampling:
+             offset = torch.arange(n, device=device) / n
+             _eps_t = (_eps_t / n + offset) % 1
+         t = (1 - self.sampling_eps) * _eps_t + self.sampling_eps
+         # if self.importance_sampling:
+         #     return self.noise.importance_sampling_transformation(t)
+         return t
+
+     def forward_diffusion(self, x0):
+         """Forward diffusion process: adds noise to the latents and scores the denoiser."""
+
+         t = self.sample_timestep(x0.shape[0], x0.device)
+         sigma, dsigma = self.noise(t)
+         unet_conditioning = sigma[:, None]
+         move_chance = 1 - torch.exp(-sigma[:, None])
+
+         xt = self.q_xt(x0, move_chance)
+         model_output = self.forward(xt, unet_conditioning)
+
+         # SUBS parameterization, continuous time.
+         idx = x0.long()
+         print(f'idx: {idx.size()}')
+         print(f'idx min: {idx.min()}')
+         print(f'idx max: {idx.max()}')
+         print(f'model out: {model_output.size()}')
+         log_p_theta = torch.gather(input=model_output, dim=-1, index=idx[:, :, None]).squeeze(-1)
+         scale = (dsigma / torch.expm1(sigma))[:, None]
+         return -log_p_theta * scale
+
+
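+     # The value returned above is the per-token negative log-likelihood weighted by
+     # dsigma / (exp(sigma) - 1), the continuous-time coefficient of the MDLM objective;
+     # compute_loss below masks and averages it over valid tokens.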
+     ######### LOSS CALCULATIONS #########
+     def compute_loss(self, latents, attention_mask):
+         """Average of the per-token MLM losses to stabilize training."""
+         loss = self.forward_diffusion(latents)
+
+         nlls = loss * attention_mask
+         count = attention_mask.sum()
+         batch_nll = nlls.sum()
+         token_nll = batch_nll / count
+
+         return Loss(loss=token_nll, nlls=nlls, token_mask=attention_mask)
+
+
+     ######### TRAINING #########
+     def training_step(self, batch):
+         latents, attention_mask = batch
+         loss = self.compute_loss(latents, attention_mask)
+         wandb.log({"train_loss": loss.loss.item()})
+         return loss.loss
+
+     def configure_optimizers(self):
+         optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
+         return optimizer
+
+     def validation_step(self, batch):
+         latents, attention_mask = batch
+         loss = self.compute_loss(latents, attention_mask)
+         wandb.log({"val_loss": loss.loss.item()})
+         return loss.loss
+
+
+     ######### GENERATION #########
+     def sample_prior(self, *batch_dims):
+         return self.mask_index * torch.ones(*batch_dims, dtype=torch.int64)
+
+     @staticmethod
+     def sample_categorical(categorical_probs):
+         gumbel_norm = (1e-10 - (torch.rand_like(categorical_probs) + 1e-10).log())
+         return (categorical_probs / gumbel_norm).argmax(dim=-1)
+
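+     # sample_categorical uses the Gumbel-max trick: dividing the probabilities by
+     # independent Exp(1) noise and taking the argmax yields a sample from the
+     # (row-wise normalized) categorical distribution.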
+     def ddpm_caching_update(self, x, t, dt, p_x0=None):
+         assert self.config.noise.type == 'loglinear'
+         sigma_t, _ = self.noise(t)
+         if t.ndim > 1:
+             t = t.squeeze(-1)
+         assert t.ndim == 1
+         move_chance_t = t[:, None, None]
+         move_chance_s = (t - dt)[:, None, None]
+         assert move_chance_t.ndim == 3, move_chance_t.shape
+         if p_x0 is None:
+             p_x0 = self.forward(x, sigma_t).exp()
+
+         assert move_chance_t.ndim == p_x0.ndim
+         q_xs = p_x0 * (move_chance_t - move_chance_s)
+         q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
+         _x = self.sample_categorical(q_xs)
+
+         copy_flag = (x != self.mask_index).to(x.dtype)
+         return p_x0, copy_flag * x + (1 - copy_flag) * _x
+
+
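+     # One reverse step of the masking diffusion: q_xs mixes the predicted p(x0 | xt) with the
+     # probability of remaining masked, copy_flag pins tokens that are already unmasked, and the
+     # cached p_x0 is reused while the sequence has not changed.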
+     @torch.no_grad()
+     def sample_subs_guidance(self, n_samples, stride_length, num_strides, dt=0.001):
+         ones = torch.ones(n_samples, dtype=self.dtype, device=self.device)
+         num_steps = int(1 / dt)
+         sampling_steps = 0
+         intermediate_tokens = []
+         target = None
+
+         for _ in range(num_strides + 1):
+             p_x0_cache = None
+             x = self.sample_prior(n_samples, self.config.model.length).to(self.device)
+
+             if target is not None:
+                 x[:, :-stride_length] = target
+
+             for i in range(num_steps + 1):
+                 p_x0_cache, x_next = self.ddpm_caching_update(x=x, t=(1 - i * dt) * ones, dt=dt, p_x0=p_x0_cache)
+                 if not torch.allclose(x_next, x) or self.time_conditioning:
+                     p_x0_cache = None
+                     sampling_steps += 1
+                 x = x_next
+             x = self.forward(x, 0 * ones).argmax(dim=-1)
+             intermediate_tokens.append(x[:, :stride_length].cpu().numpy())
+             target = x[:, stride_length:]
+
+         intermediate_tokens.append(target.cpu().numpy())
+         intermediate_text_samples = []
+         sequence_lengths = ((np.concatenate(intermediate_tokens, axis=1)[:, 1:]
+                              == self.tokenizer.eos_token_id).cumsum(-1) == 0).sum(-1)
+
+         for i in range(2, len(intermediate_tokens) + 1):
+             intermediate_text_samples.append(self.tokenizer.decode(np.concatenate(intermediate_tokens[:i], axis=1)))
+
+         return (sampling_steps, intermediate_text_samples, sequence_lengths)
+
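+     # Semi-autoregressive sampling: each stride denoises a freshly masked window while
+     # conditioning on previously generated tokens via `target`; sequence_lengths counts the
+     # tokens that appear before the first EOS in the concatenated output.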
+     def restore_model_and_semi_ar_sample(self, stride_length, num_strides, dt=0.001):
+         """Generate samples from the model."""
+         # Lightning auto-casting is not working in this method for some reason
+         self.backbone.eval()
+         self.noise.eval()
+
+         sampling_steps, samples, sequence_lengths = self.sample_subs_guidance(
+             n_samples=self.config.Loader.BATCH_SIZE,
+             stride_length=stride_length,
+             num_strides=num_strides,
+             dt=dt)
+
+         self.backbone.train()
+         self.noise.train()
+         return sampling_steps, samples, sequence_lengths
scripts/train.py CHANGED
@@ -5,13 +5,14 @@ import config
  from data_loader import get_dataloaders
  from esm_utils import load_esm2_model
  from diffusion import Diffusion
+ import wandb
  import sys
 
  # Get dataloaders
  train_loader, val_loader, _ = get_dataloaders(config)
 
  # Initialize ESM tokenizer and model
- tokenizer, model = load_esm2_model(config.MODEL_NAME)
+ tokenizer, _, _ = load_esm2_model(config.MODEL_NAME)
 
  # Initialize diffusion model
  latent_diffusion_model = Diffusion(config, latent_dim=config.LATENT_DIM, tokenizer=tokenizer)
@@ -46,3 +47,4 @@ sys.stdout.flush()
  # Train the model
  trainer.fit(latent_diffusion_model, train_loader, val_loader)
 
+ wandb.finish()