sgoel30 committed
Commit bce7ea6 · verified · 1 parent: faaee28

Delete scripts

Files changed (4):
  1. scripts/diffusion.py +0 -295
  2. scripts/generate.py +0 -131
  3. scripts/test.py +0 -17
  4. scripts/train.py +0 -50
scripts/diffusion.py DELETED
@@ -1,295 +0,0 @@
- import itertools
- import math
- import torch
- import torch.nn as nn
- import numpy as np
- import pytorch_lightning as L
- import torchmetrics
- from dataclasses import dataclass
- from esm_utils import load_esm2_model
- from transformers import AutoModel, AutoTokenizer
- import dit, ema
- import sys
- import config
- import wandb
- import noise_schedule  # Assuming this is part of the MDLM repository
-
- wandb_key = "2b76a2fa2c1cdfddc5f443602c17b011fefb0a8f"
- wandb.login(key=wandb_key)
- wandb.init(project=config.Wandb.PROJECT, group=config.Wandb.GROUP)
-
- LOG2 = math.log(2)
-
- # Goal is to build an MDLM head on the BERT-style ESM model
- # Wrap the ESM model to obtain embeddings and ignore sigma to work with MDLM codebase
- class WrapESM(nn.Module):
-     def __init__(self, esm_model_path):
-         super(WrapESM, self).__init__()
-         self.esm_tokenizer, self.esm_model, _ = load_esm2_model(esm_model_path)
-
-         ### Only fine-tune the last 3 layers of ESM
-         # Count number of encoder layers
-         model_layers = len(self.esm_model.esm.encoder.layer)
-
-         # Disable parameter updates for all layers
-         for param in self.esm_model.parameters():
-             param.requires_grad = False
-
-         # Now that all parameters are disabled, only enable updates for the last 3 layers
-         for i, layer in enumerate(self.esm_model.esm.encoder.layer):
-             if i >= model_layers - config.ESM_LAYERS:
-                 for module in layer.attention.self.key.modules():
-                     for param in module.parameters():
-                         param.requires_grad = True
-                 for module in layer.attention.self.query.modules():
-                     for param in module.parameters():
-                         param.requires_grad = True
-                 for module in layer.attention.self.value.modules():
-                     for param in module.parameters():
-                         param.requires_grad = True
-
-     def forward(self, latents, sigma):
-         return latents
-
- @dataclass
- class Loss:
-     loss: torch.FloatTensor
-     nlls: torch.FloatTensor
-     token_mask: torch.FloatTensor
-
- class NLL(torchmetrics.MeanMetric):
-     pass
-
- class BPD(NLL):
-     def compute(self) -> torch.Tensor:
-         """Computes the bits per dimension.
-         Returns:
-             bpd
-         """
-         return self.mean_value / self.weight / LOG2
-
- class Perplexity(NLL):
-     def compute(self) -> torch.Tensor:
-         """Computes the Perplexity.
-         Returns:
-             Perplexity
-         """
-         return torch.exp(self.mean_value / self.weight)
-
-
- # Based on MDLM repo
- class Diffusion(L.LightningModule):
-     def __init__(self, config, latent_dim, tokenizer):
-         super().__init__()
-         self.config = config
-         self.latent_dim = latent_dim
-         self.tokenizer = tokenizer
-
-         self.softplus = torch.nn.Softplus()
-         metrics = torchmetrics.MetricCollection({
-             'nll': NLL(),
-             'bpd': BPD(),
-             'ppl': Perplexity(),
-         })
-         metrics.set_dtype(torch.float64)
-         self.train_metrics = metrics.clone(prefix='train/')
-         self.valid_metrics = metrics.clone(prefix='val/')
-         self.test_metrics = metrics.clone(prefix='test/')
-
-         self.T = self.config.T
-         self.lr = self.config.Optim.LR
-         self.backbone = WrapESM(self.config.MODEL_NAME)
-         self.noise = noise_schedule.get_noise(self.config, dtype=self.dtype)
-         self.time_conditioning = self.config.TIME_CONDITIONING
-         self.subs_masking = self.config.SUBS_MASKING
-         self.mask_index = self.tokenizer.mask_token_id
-         self.antithetic_sampling = self.config.Training.ANTITHETIC_SAMPLING
-         self.sampling_eps = self.config.Training.SAMPLING_EPS
-         self.neg_infinity = -1000000.0
-
-
-     ############ FORWARD DIFFUSION #########
-     def subs_parameterization(self, logits, noised_latents):
-         print(logits.size())  # [bsz x bsz x seq_len]
-         logits = logits.float()
-         logits[:, :, self.mask_index] += self.neg_infinity
-
-         # Normalize the logits such that x.exp() is a probability distribution over vocab_size.
-         logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
-
-         unmasked_indices = (noised_latents != self.mask_index)
-         logits[unmasked_indices] = self.neg_infinity
-         logits[~unmasked_indices] = 0
-
-         return logits
-
-         # # -inf probability of selecting a masked token
-         # unmasked_indices = (noised_latents != self.mask_index)
-         # logits[unmasked_indices] = self.neg_infinity
-
-         # # Carry over unmasked tokens
-         # bsz, seq_len, input_dim = logits.shape
-         # for batch_idx in range(bsz):
-         #     for residue in range(seq_len):
-         #         logits[batch_idx, residue, noised_latents[batch_idx, residue]] = 0
-
-         # return logits
-
-     def forward(self, latents, sigma):
-         latents = latents.long()
-         logits = self.backbone(latents, sigma)
-         optimized_logits = self.subs_parameterization(logits, latents)
-         return optimized_logits
-
-     def q_xt(self, latents, move_chance):
-         """
-         Computes the noisy sample xt.
-         Args:
-             x: int torch.Tensor with shape (batch_size, diffusion_model_input_length), input.
-             move_chance: float torch.Tensor with shape (batch_size, 1).
-         """
-         latents = torch.mean(latents, dim=2)  # [bsz x seq_len x 1280] --> [bsz x seq_len] as per markdown
-         move_indices = torch.rand(*latents.shape, device=latents.device) < move_chance
-         noised_latents = torch.where(move_indices, self.mask_index, latents)
-         return noised_latents
-
-     def sample_timestep(self, n, device):
-         _eps_t = torch.rand(n, device=device)
-         if self.antithetic_sampling:
-             offset = torch.arange(n, device=device) / n
-             _eps_t = (_eps_t / n + offset) % 1
-         t = (1 - self.sampling_eps) * _eps_t + self.sampling_eps
-         # if self.importance_sampling:
-         #     return self.noise.importance_sampling_transformation(t)
-         return t
-
-     def forward_diffusion(self, x0):
-         """Forward diffusion process, adds noise to the latents."""
-
-         t = self.sample_timestep(x0.shape[0], x0.device)
-         sigma, dsigma = self.noise(t)
-         unet_conditioning = sigma[:, None]
-         move_chance = 1 - torch.exp(-sigma[:, None, None])
-
-         xt = self.q_xt(x0, move_chance)
-         model_output = self.forward(xt, unet_conditioning)
-         print(f'model out: {model_output}')
-         print(f'model out dim: {model_output.size()}')  # [bsz x bsz x seq_len]
-
-         # SUBS parameterization, continuous time.
-         idx = torch.mean(x0, dim=2).long()[:, :, None]
-         print(f'idx: {idx}')
-         print(f'idx dim: {idx.size()}')  # [bsz x seq_len x 1]
-
-         log_p_theta = torch.gather(input=model_output, dim=-1, index=idx).squeeze(-1)
-         scale = (dsigma / torch.expm1(sigma))[:, None]
-         return -log_p_theta * scale
-
-
-     ######### LOSS CALCULATIONS #########
-     def compute_loss(self, latents, attention_mask):
-         """Average of MLM losses to stabilize training"""
-         loss = self.forward_diffusion(latents)
-
-         nlls = loss * attention_mask
-         count = attention_mask.sum()
-         batch_nll = nlls.sum()
-         token_nll = batch_nll / count
-
-         return Loss(loss=token_nll, nlls=nlls, token_mask=attention_mask)
-
-
-     ######### TRAINING #########
-     def training_step(self, batch):
-         latents, attention_mask = batch
-         loss = self.compute_loss(latents, attention_mask)
-         wandb.log({"train_loss": loss.loss.item()})
-         return loss.loss
-
-     def configure_optimizers(self):
-         optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
-         return optimizer
-
-     def validation_step(self, batch):
-         latents, attention_mask = batch
-         loss = self.compute_loss(latents, attention_mask)
-         wandb.log({"val_loss": loss.loss.item()})
-         return loss.loss
-
-
-     ######### GENERATION #########
-     def sample_prior(self, *batch_dims):
-         return self.mask_index * torch.ones(*batch_dims, dtype=torch.int64)
-
-     def sample_categorical(categorical_probs):
-         gumbel_norm = (1e-10 - (torch.rand_like(categorical_probs) + 1e-10).log())
-         return (categorical_probs / gumbel_norm).argmax(dim=-1)
-
-     def ddpm_caching_update(self, x, t, dt, p_x0=None):
-         assert self.config.noise.type == 'loglinear'
-         sigma_t, _ = self.noise(t)
-         if t.ndim > 1:
-             t = t.squeeze(-1)
-         assert t.ndim == 1
-         move_chance_t = t[:, None, None]
-         move_chance_s = (t - dt)[:, None, None]
-         assert move_chance_t.ndim == 3, move_chance_t.shape
-         if p_x0 is None:
-             p_x0 = self.forward(x, sigma_t).exp()
-
-         assert move_chance_t.ndim == p_x0.ndim
-         q_xs = p_x0 * (move_chance_t - move_chance_s)
-         q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
-         _x = self.sample_categorical(q_xs)
-
-         copy_flag = (x != self.mask_index).to(x.dtype)
-         return p_x0, copy_flag * x + (1 - copy_flag) * _x
-
-
-     @torch.no_grad()
-     def sample_subs_guidance(self, n_samples, stride_length, num_strides, dt=0.001):
-         ones = torch.ones(n_samples, dtype=self.dtype, device=self.device)
-         num_steps = int(1 / dt)
-         sampling_steps = 0
-         intermediate_tokens = []
-         target = None
-
-         for _ in range(num_strides + 1):
-             p_x0_cache = None
-             x = self._sample_prior(n_samples, self.config.model.length).to(self.device)
-
-             if target is not None:
-                 x[:, :-stride_length] = target
-
-             for i in range(num_steps + 1):
-                 p_x0_cache, x_next = self.ddpm_caching_update(x=x, t=(1 - i * dt) * ones, dt=dt, p_x0=p_x0_cache)
-                 if (not torch.allclose(x_next, x) or self.time_conditioning):
-                     p_x0_cache = None
-                     sampling_steps += 1
-                 x = x_next
-             x = self.forward(x, 0 * ones).argmax(dim=-1)
-             intermediate_tokens.append(x[:, :stride_length].cpu().numpy())
-             target = x[:, stride_length:]
-
-         intermediate_tokens.append(target.cpu().numpy())
-         intermediate_text_samples = []
-         sequence_lengths = ((np.concatenate(intermediate_tokens, axis=1)[:, 1:]
-                              == self.tokenizer.eos_token_id).cumsum(-1) == 0).sum(-1)
-
-         for i in range(2, len(intermediate_tokens) + 1):
-             intermediate_text_samples.append(self.tokenizer.decode(np.concatenate(intermediate_tokens[:i], axis=1)))
-
-         return (sampling_steps, intermediate_text_samples,
-                 sequence_lengths)
-
-     def restore_model_and_semi_ar_sample(self, stride_length, num_strides, dt=0.001):
-         """Generate samples from the model."""
-         # Lightning auto-casting is not working in this method for some reason
-         self.backbone.eval()
-         self.noise.eval()
-
-         (sampling_steps, samples, sequence_lengths) = self.sample_subs_guidance(n_samples=self.config.Loader.BATCH_SIZE, stride_length=stride_length, num_strides=num_strides, dt=dt)
-
-         self.backbone.train()
-         self.noise.train()
-         return sampling_steps, samples, sequence_lengths
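As context for the objective that forward_diffusion and compute_loss implement above: a minimal standalone sketch of MDLM-style masked-diffusion training, assuming a log-linear noise schedule sigma(t) = -log(1 - (1 - eps)·t), so the per-token masking probability is 1 - exp(-sigma). The helper name, shapes, and averaging over masked positions are illustrative assumptions, not code from this repository.

import torch

# Illustrative sketch only (assumed log-linear schedule); not part of the deleted scripts.
def toy_masked_diffusion_loss(model, x0, t, mask_index, eps=1e-3):
    """model: (xt, sigma) -> [bsz, seq_len, vocab] logits; x0: [bsz, seq_len] token ids; t: [bsz] in (0, 1]."""
    sigma = -torch.log1p(-(1 - eps) * t)              # noise level sigma(t)
    dsigma = (1 - eps) / (1 - (1 - eps) * t)          # d sigma / d t
    move_chance = 1 - torch.exp(-sigma)               # per-token masking probability

    mask = torch.rand(x0.shape, device=x0.device) < move_chance[:, None]
    xt = torch.where(mask, torch.full_like(x0, mask_index), x0)   # noisy tokens x_t

    logits = model(xt, sigma)                         # denoiser predicts x_0 from x_t
    log_p_x0 = logits.log_softmax(dim=-1).gather(-1, x0.unsqueeze(-1)).squeeze(-1)
    weight = (dsigma / torch.expm1(sigma))[:, None]   # continuous-time SUBS weight
    return -(weight * log_p_x0 * mask).sum() / mask.sum().clamp(min=1)

Unmasked positions contribute no loss in this sketch, which mirrors what the carry-over step in subs_parameterization enforces by pinning their log-probabilities to zero.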
 
scripts/generate.py DELETED
@@ -1,131 +0,0 @@
- import torch
- import numpy as np
- from transformers import AutoTokenizer, AutoModel
- from models.diffusion import Diffusion
- from configs.config import Config
- from utils.esm_utils import load_esm2_model, get_latents
-
- def mask_sequence(sequence, mask_char='X'):
-     """Masks parts of the sequence based on the mask_char."""
-     mask_indices = [i for i, char in enumerate(sequence) if char == mask_char]
-     masked_sequence = sequence.replace(mask_char, '[MASK]')
-     return masked_sequence, mask_indices
-
- def generate_filled_sequence(model, tokenizer, esm_model, masked_sequence, mask_indices):
-     """Generates the filled sequence for the masked regions."""
-     inputs = tokenizer(masked_sequence, return_tensors="pt")
-     with torch.no_grad():
-         outputs = esm_model(**inputs)
-         latents = outputs.last_hidden_state.squeeze(0)
-
-     sigma = torch.rand(1, device=latents.device)
-     noisy_latents = model.forward(latents, sigma)
-     denoised_latents = model.reverse_diffusion(noisy_latents, sigma)
-
-     filled_sequence = list(masked_sequence)
-     for idx in mask_indices:
-         token_id = torch.argmax(denoised_latents[idx]).item()
-         filled_sequence[idx] = tokenizer.decode([token_id])
-
-     return ''.join(filled_sequence)
-
- def generate_scaffold_sequence(model, tokenizer, esm_model, peptides, final_length):
-     """Generates a scaffold sequence to connect multiple peptides."""
-     total_peptide_length = sum(len(peptide) for peptide in peptides)
-     scaffold_length = final_length - total_peptide_length
-     if scaffold_length <= 0:
-         raise ValueError("Final length must be greater than the combined length of the peptides.")
-
-     scaffold = "[MASK]" * scaffold_length
-     masked_sequence = "".join(peptides[:1] + [scaffold] + peptides[1:])
-
-     inputs = tokenizer(masked_sequence, return_tensors="pt")
-     with torch.no_grad():
-         outputs = esm_model(**inputs)
-         latents = outputs.last_hidden_state.squeeze(0)
-
-     sigma = torch.rand(1, device=latents.device)
-     noisy_latents = model.forward(latents, sigma)
-     denoised_latents = model.reverse_diffusion(noisy_latents, sigma)
-
-     filled_sequence = list(masked_sequence)
-     scaffold_start = len(peptides[0])
-     scaffold_end = scaffold_start + scaffold_length
-     for idx in range(scaffold_start, scaffold_end):
-         token_id = torch.argmax(denoised_latents[idx]).item()
-         filled_sequence[idx] = tokenizer.decode([token_id])
-
-     return ''.join(filled_sequence)
-
- def generate_de_novo_sequence(model, tokenizer, esm_model, sequence_length):
-     """Generates a de novo protein sequence of the specified length."""
-     scaffold = "[MASK]" * sequence_length
-     masked_sequence = scaffold
-
-     inputs = tokenizer(masked_sequence, return_tensors="pt")
-     with torch.no_grad():
-         outputs = esm_model(**inputs)
-         latents = outputs.last_hidden_state.squeeze(0)
-
-     sigma = torch.rand(1, device=latents.device)
-     noisy_latents = model.forward(latents, sigma)
-     denoised_latents = model.reverse_diffusion(noisy_latents, sigma)
-
-     filled_sequence = list(masked_sequence)
-     for idx in range(sequence_length):
-         token_id = torch.argmax(denoised_latents[idx]).item()
-         filled_sequence[idx] = tokenizer.decode([token_id])
-
-     return ''.join(filled_sequence)
-
- if __name__ == "__main__":
-     import argparse
-
-     # Argument parsing
-     parser = argparse.ArgumentParser(description="Generate protein sequences using latent diffusion model.")
-     subparsers = parser.add_subparsers(dest="mode")
-
-     # Subparser for the first strategy (multiple peptides to scaffold)
-     parser_scaffold = subparsers.add_parser("scaffold", help="Generate scaffold to connect multiple peptides.")
-     parser_scaffold.add_argument("peptides", nargs='+', help="Peptides to connect.")
-     parser_scaffold.add_argument("final_length", type=int, help="Final length of the protein sequence.")
-
-     # Subparser for the second strategy (fill in regions)
-     parser_fill = subparsers.add_parser("fill", help="Fill in specified regions in a given protein sequence.")
-     parser_fill.add_argument("sequence", help="Protein sequence with regions to fill specified by 'X'.")
-
-     # Subparser for the third strategy (de novo generation)
-     parser_de_novo = subparsers.add_parser("de_novo", help="Generate a de novo protein sequence.")
-     parser_de_novo.add_argument("sequence_length", type=int, help="Length of the de novo generated protein sequence.")
-
-     args = parser.parse_args()
-
-     # Load configurations
-     config = Config()
-
-     # Load models
-     tokenizer, esm_model = load_esm2_model(config.model_name)
-     diffusion_model = Diffusion.load_from_checkpoint(config.training["save_dir"] + "example.ckpt", config=config, latent_dim=config.latent_dim)
-     diffusion_model.eval()
-
-     if args.mode == "scaffold":
-         peptides = args.peptides
-         final_length = args.final_length
-         filled_sequence = generate_scaffold_sequence(diffusion_model, tokenizer, esm_model, peptides, final_length)
-         print(f"Peptides: {' '.join(peptides)}")
-         print(f"Final Length: {final_length}")
-         print(f"Generated Protein: {filled_sequence}")
-
-     elif args.mode == "fill":
-         sequence = args.sequence
-         masked_sequence, mask_indices = mask_sequence(sequence)
-         filled_sequence = generate_filled_sequence(diffusion_model, tokenizer, esm_model, masked_sequence, mask_indices)
-         print(f"Original Sequence: {sequence}")
-         print(f"Masked Sequence: {masked_sequence}")
-         print(f"Filled Sequence: {filled_sequence}")
-
-     elif args.mode == "de_novo":
-         sequence_length = args.sequence_length
-         filled_sequence = generate_de_novo_sequence(diffusion_model, tokenizer, esm_model, sequence_length)
-         print(f"De Novo Sequence Length: {sequence_length}")
-         print(f"Generated Protein: {filled_sequence}")
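One detail of mask_sequence above that matters when indexing back into the output: mask_indices records the positions of the mask character in the original sequence, while each 'X' is replaced by the multi-character string '[MASK]' in the returned masked_sequence. A toy illustration, with a hypothetical input and assuming mask_sequence is in scope as defined above:

# Toy illustration of mask_sequence; the input value is hypothetical.
sequence = "ACXXDE"
masked_sequence, mask_indices = mask_sequence(sequence)
print(masked_sequence)   # AC[MASK][MASK]DE
print(mask_indices)      # [2, 3] -- indices into the original sequence, not the masked string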
 
scripts/test.py DELETED
@@ -1,17 +0,0 @@
- import pytorch_lightning as L
- from configs.config import Config
- from utils.data_loader import get_dataloaders
- from models.diffusion import Diffusion
-
- # Get dataloaders
- _, _, test_loader = get_dataloaders(Config)
-
- # Initialize model
- checkpoint_path = Config.training["save_dir"] + "example.ckpt"
- latent_diffusion_model = Diffusion.load_from_checkpoint(checkpoint_path, config=Config, latent_dim=Config.latent_dim)
-
- # Initialize trainer
- trainer = L.Trainer(gpus=Config.training["gpus"], precision=Config.training["precision"])
-
- # Test the model
- trainer.test(latent_diffusion_model, test_loader)
 
scripts/train.py DELETED
@@ -1,50 +0,0 @@
- import pytorch_lightning as L
- from pytorch_lightning.strategies import DDPStrategy
- from pytorch_lightning.callbacks import ModelCheckpoint
- import config
- from data_loader import get_dataloaders
- from esm_utils import load_esm2_model
- from diffusion import Diffusion
- import wandb
- import sys
-
- # Get dataloaders
- train_loader, val_loader, _ = get_dataloaders(config)
-
- # Initialize ESM tokenizer and model
- tokenizer, _, _ = load_esm2_model(config.MODEL_NAME)
-
- # Initialize diffusion model
- latent_diffusion_model = Diffusion(config, latent_dim=config.LATENT_DIM, tokenizer=tokenizer)
- print(latent_diffusion_model)
- sys.stdout.flush()
-
- # Define checkpoints to save best model by minimum validation loss
- checkpoint_callback = ModelCheckpoint(
-     monitor='val_loss',
-     save_top_k=1,
-     mode='min',
-     dirpath="/workspace/a03-sgoel/MDpLM/",
-     filename="best_model_epoch{epoch:02d}"
- )
-
- # Initialize trainer
- trainer = L.Trainer(
-     max_epochs=config.Training.NUM_EPOCHS,
-     precision=config.Training.PRECISION,
-     devices=1,
-     accelerator='gpu',
-     strategy=DDPStrategy(find_unused_parameters=False),
-     accumulate_grad_batches=config.Training.ACCUMULATE_GRAD_BATCHES,
-     default_root_dir=config.Training.SAVE_DIR,
-     callbacks=[checkpoint_callback]
- )
-
- print(trainer)
- print("Training model...")
- sys.stdout.flush()
-
- # Train the model
- trainer.fit(latent_diffusion_model, train_loader, val_loader)
-
- wandb.finish()
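The train.py and diffusion.py scripts above rely on a shared config module with nested namespaces (config.Wandb, config.Optim, config.Loader, config.Training) plus a handful of top-level constants. A hypothetical sketch of the interface they appear to expect; the attribute names come from the code above, while every value is a placeholder:

# Hypothetical reconstruction of the config interface referenced by the deleted
# scripts; names are taken from the code above, all values are placeholders.
MODEL_NAME = "facebook/esm2_t33_650M_UR50D"   # placeholder ESM-2 checkpoint
LATENT_DIM = 1280                             # placeholder embedding width
ESM_LAYERS = 3                                # ESM layers left trainable (per the comment in WrapESM)
T = 0                                         # placeholder
TIME_CONDITIONING = False                     # placeholder
SUBS_MASKING = False                          # placeholder

class Wandb:
    PROJECT = "protein-mdlm"                  # placeholder
    GROUP = "baseline"                        # placeholder

class Optim:
    LR = 3e-4                                 # placeholder

class Loader:
    BATCH_SIZE = 8                            # placeholder

class Training:
    NUM_EPOCHS = 10                           # placeholder
    PRECISION = "bf16-mixed"                  # placeholder
    ACCUMULATE_GRAD_BATCHES = 1               # placeholder
    SAVE_DIR = "./checkpoints/"               # placeholder
    ANTITHETIC_SAMPLING = True                # placeholder
    SAMPLING_EPS = 1e-3                       # placeholder

class noise:                                  # accessed as config.noise.type in ddpm_caching_update
    type = "loglinear"

class model:                                  # accessed as config.model.length in sample_subs_guidance
    length = 512                              # placeholder

Note that generate.py and test.py instead read a configs.config.Config class with lowercase, dictionary-style fields (config.model_name, Config.training["save_dir"], ["gpus"], ["precision"]), so the two configuration styles in the deleted scripts are not fully consistent with each other.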