AlienChen committed on
Commit 65bd8af · verified · 1 Parent(s): ad15e9b

Upload 139 files

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. binder_generator_train.py +241 -0
  2. finetune.py +385 -0
  3. muppit/.gitignore +7 -0
  4. muppit/LICENSE +201 -0
  5. muppit/README.md +250 -0
  6. muppit/__pycache__/classifier.cpython-310.pyc +0 -0
  7. muppit/__pycache__/dataloader.cpython-310.pyc +0 -0
  8. muppit/__pycache__/diffusion.cpython-310.pyc +0 -0
  9. muppit/__pycache__/noise_schedule.cpython-310.pyc +0 -0
  10. muppit/__pycache__/utils.cpython-310.pyc +0 -0
  11. muppit/classifier.py +490 -0
  12. muppit/configs/callbacks/checkpoint_every_n_steps.yaml +8 -0
  13. muppit/configs/callbacks/checkpoint_monitor.yaml +10 -0
  14. muppit/configs/callbacks/learning_rate_monitor.yaml +3 -0
  15. muppit/configs/classifier_model/dimamba-classifier.yaml +14 -0
  16. muppit/configs/classifier_model/hyenadna-classifier.yaml +4 -0
  17. muppit/configs/classifier_model/small-classifier.yaml +11 -0
  18. muppit/configs/classifier_model/tiny-classifier.yaml +11 -0
  19. muppit/configs/classifier_model/tiny-dimamba-classifier.yaml +14 -0
  20. muppit/configs/config.yaml +104 -0
  21. muppit/configs/data/amazon_polarity.yaml +10 -0
  22. muppit/configs/data/cifar10.yaml +11 -0
  23. muppit/configs/data/lm1b.yaml +8 -0
  24. muppit/configs/data/peptide.yaml +8 -0
  25. muppit/configs/data/protein.yaml +8 -0
  26. muppit/configs/data/qm9.yaml +11 -0
  27. muppit/configs/data/ten_species.yaml +11 -0
  28. muppit/configs/data/text8.yaml +9 -0
  29. muppit/configs/guidance/cbg.yaml +5 -0
  30. muppit/configs/guidance/cfg.yaml +3 -0
  31. muppit/configs/guidance/fudge.yaml +5 -0
  32. muppit/configs/guidance/nos.yaml +6 -0
  33. muppit/configs/guidance/pplm.yaml +6 -0
  34. muppit/configs/lr_scheduler/constant_warmup.yaml +2 -0
  35. muppit/configs/lr_scheduler/cosine_decay_warmup.yaml +7 -0
  36. muppit/configs/model/dimamba.yaml +12 -0
  37. muppit/configs/model/fudge_predictor.yaml +4 -0
  38. muppit/configs/model/hf.yaml +2 -0
  39. muppit/configs/model/medium.yaml +10 -0
  40. muppit/configs/model/small.yaml +11 -0
  41. muppit/configs/model/tiny.yaml +10 -0
  42. muppit/configs/model/unet.yaml +19 -0
  43. muppit/configs/model/unet_campbell.yaml +19 -0
  44. muppit/configs/noise/ar.yaml +2 -0
  45. muppit/configs/noise/linear.yaml +3 -0
  46. muppit/configs/noise/loglinear.yaml +3 -0
  47. muppit/configs/noise/polynomial.yaml +5 -0
  48. muppit/configs/strategy/ddp.yaml +2 -0
  49. muppit/configs/strategy/fsdp.yaml +3 -0
  50. muppit/custom_datasets/__init__.py +2 -0
binder_generator_train.py ADDED
@@ -0,0 +1,241 @@
1
+ import torch
2
+ from torch.utils.data import DataLoader, Subset
3
+ from torch.optim import AdamW
4
+ import torch.nn.functional as F
5
+ import torch.nn as nn
6
+ from datasets import load_from_disk
7
+ import esm
8
+ import numpy as np
9
+ import math
10
+ import os
11
+ from transformers import AutoTokenizer
12
+ from torch.optim.lr_scheduler import CosineAnnealingLR
13
+ from transformers import get_linear_schedule_with_warmup
14
+ from tqdm import tqdm
15
+ from torch.cuda.amp import autocast, GradScaler
16
+ import gc
17
+ import pdb
18
+
19
+ os.environ['CUDA_VISIBLE_DEVICES'] = '1'
20
+
21
+ ##################### Hyper-parameters #############################################
22
+ max_epochs = 30
23
+ batch_size = 4
24
+ lr = 1e-4
25
+ num_layers = 4
26
+ num_heads = 4
27
+ accumulation_steps = 4
28
+ checkpoint_path = '/home/tc415/muPPIt_embedding/checkpoints/generator_0'
29
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
30
+
31
+ print(f'''
+ max_epochs = {max_epochs}
+ batch_size = {batch_size}
+ lr = {lr}
+ num_layers = {num_layers}
+ num_heads = {num_heads}
+ accumulation_steps = {accumulation_steps}
+ checkpoint_path = {checkpoint_path}
+ ''')
40
+ ####################################################################################
41
+
42
+ os.makedirs(checkpoint_path, exist_ok=True)
43
+
44
+ train_dataset = load_from_disk('/home/tc415/muPPIt_embedding/dataset/train/ppiref_generator')
45
+ val_dataset = load_from_disk('/home/tc415/muPPIt_embedding/dataset/val/ppiref_generator')
46
+ tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
47
+ print(len(train_dataset), len(val_dataset))
48
+
49
+ def collate_fn(batch):
50
+ # Unpack the batch
51
+ binders = []
52
+ targets = []
53
+
54
+ global tokenizer
55
+
56
+ for b in batch:
57
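+ # Drop the special tokens (CLS/EOS) added by the ESM-2 tokenizer ([1:-1]) before padding.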
+ binder = torch.tensor(b['binder_input_ids']['input_ids'][1:-1])
58
+ target = torch.tensor(b['target_input_ids']['input_ids'][1:-1])
59
+
60
+ if binder.dim() == 0 or binder.numel() == 0 or target.dim() == 0 or target.numel() == 0:
61
+ continue
62
+ binders.append(binder) # shape: 1*L1 -> L1
63
+ targets.append(target) # shape: 1*L2 -> L2
64
+
65
+ # Collate the tensors using torch's pad_sequence
66
+ try:
67
+ binder_input_ids = torch.nn.utils.rnn.pad_sequence(binders, batch_first=True, padding_value=tokenizer.pad_token_id)
68
+
69
+ target_input_ids = torch.nn.utils.rnn.pad_sequence(targets, batch_first=True, padding_value=tokenizer.pad_token_id)
70
+ except:
71
+ pdb.set_trace()
72
+
73
+ # Return the collated batch
74
+ return {
75
+ 'binder_input_ids': binder_input_ids.long(),
76
+ 'target_input_ids': target_input_ids.long(),
77
+ }
78
+
79
+
80
+ def RoPE(x, seq_dim=0):
81
+ """
82
+ Applies Rotary Positional Encoding to the input embeddings.
83
+ :param x: Input tensor (seq_len, batch_size, embed_dim)
84
+ :param seq_dim: The sequence dimension, usually 0 (first dimension in (seq_len, batch_size, embed_dim))
85
+ :return: Tensor with RoPE applied (seq_len, batch_size, embed_dim)
86
+ """
87
+ seq_len = x.shape[seq_dim]
88
+ d_model = x.shape[-1]
89
+
90
+ # Create the positions and the sine-cosine rotational matrices
91
+ theta = torch.arange(0, d_model, 2, dtype=torch.float32) / d_model
92
+ theta = 10000 ** (-theta) # scaling factor for RoPE
93
+ seq_idx = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)
94
+
95
+ # Compute sine and cosine embedding for each position
96
+ sin_emb = torch.sin(seq_idx * theta)
97
+ cos_emb = torch.cos(seq_idx * theta)
98
+
99
+ sin_emb = sin_emb.unsqueeze(1) # [seq_len, 1, embed_dim//2]
100
+ cos_emb = cos_emb.unsqueeze(1) # [seq_len, 1, embed_dim//2]
101
+
102
+ x1, x2 = x[..., ::2], x[..., 1::2] # Split embedding into even and odd indices
103
+
104
+ cos_emb = cos_emb.to(x1.device)
105
+ sin_emb = sin_emb.to(x1.device)
106
+
107
+ # Apply rotary transformation
108
+ x_rotated = torch.cat([x1 * cos_emb - x2 * sin_emb, x1 * sin_emb + x2 * cos_emb], dim=-1)
109
+ return x_rotated
110
+
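+ # BinderGenerator: frozen ESM-2 embeddings of the binder (encoder input) and target (decoder input) feed a seq2seq nn.Transformer; fc_out maps decoder states to per-token vocabulary logits.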
111
+ class BinderGenerator(nn.Module):
112
+ def __init__(self, vocab_size=24, embed_dim=1280, num_heads=4, num_layers=4, lr=1e-4):
113
+ super(BinderGenerator, self).__init__()
114
+ self.esm, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D()
115
+ for param in self.esm.parameters():
116
+ param.requires_grad = False
117
+
118
+ self.transformer = nn.Transformer(d_model=embed_dim, nhead=num_heads, num_encoder_layers=num_layers, num_decoder_layers=num_layers)
119
+ self.fc_out = nn.Linear(embed_dim, vocab_size)
120
+
121
+ self.criterion = nn.CrossEntropyLoss(ignore_index=self.alphabet.padding_idx)
122
+ self.vocab_size = vocab_size
123
+ self.learning_rate = lr
124
+
125
+ def forward(self, binder_tokens, target_tokens):
126
+ with torch.no_grad():
127
+ binder_pad_mask = (binder_tokens != self.alphabet.padding_idx).int()
128
+ binder_embed = self.esm(binder_tokens, repr_layers=[33], return_contacts=True)["representations"][33] * binder_pad_mask.unsqueeze(-1)
129
+
130
+ target_pad_mask = (target_tokens != self.alphabet.padding_idx).int()
131
+ target_embed = self.esm(target_tokens, repr_layers=[33], return_contacts=True)["representations"][33] * target_pad_mask.unsqueeze(-1)
132
+
133
+ binder_embed = binder_embed.transpose(0,1)
134
+ target_embed = target_embed.transpose(0,1)
135
+
136
+ binder_embed = RoPE(binder_embed) # [src_len, batch_size, embed_dim]
137
+ target_embed = RoPE(target_embed) # [tgt_len, batch_size, embed_dim]
138
+
139
+ output = self.transformer(binder_embed, target_embed) # [tgt_len, batch_size, embed_dim]
140
+ return self.fc_out(output).transpose(0,1) # [batch_size, tgt_len, vocab_size]
141
+
142
+ def compute_loss(self, binder_tokens, target_tokens):
143
+ output = self.forward(binder_tokens, target_tokens)
144
+
145
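+ # Teacher forcing with a one-position shift: logits at positions [:-1] are scored against target tokens [1:].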
+ loss = self.criterion(output[:, :-1, :].reshape(-1, self.vocab_size), target_tokens[:, 1:].reshape(-1))
146
+
147
+ return loss
148
+
149
+ def step(self, batch, compute_acc=False):
150
+ binder_tokens = batch['binder_input_ids']
151
+ target_tokens = batch['target_input_ids']
152
+
153
+ binder_tokens = binder_tokens.to(device)
154
+ target_tokens = target_tokens.to(device)
155
+
156
+ loss = self.compute_loss(binder_tokens, target_tokens)
157
+
158
+ if compute_acc:
+ # Recompute the decoder logits and measure token-level accuracy over non-padding target positions.
+ output = self.forward(binder_tokens, target_tokens)
+ preds = torch.argmax(output[:, :-1, :], dim=-1)
+ mask = target_tokens[:, 1:] != self.alphabet.padding_idx
+ correct = ((preds == target_tokens[:, 1:]) & mask).sum().item()
+ accuracy = correct / mask.sum().item()
+ return loss, accuracy
163
+ else:
164
+ return loss
165
+
166
+
167
+ def train(model, optimizer, scheduler, cosine_scheduler, train_dataset, val_dataset, batch_size, max_epochs=10, accumulation_steps=4):
168
+ train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True, num_workers=4)
169
+ val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=False, num_workers=4)
170
+
171
+ max_val_acc = 0
172
+ for epoch in range(max_epochs):
173
+ print(f"Epoch {epoch + 1}/{max_epochs}")
174
+
175
+ scaler = GradScaler()
176
+
177
+ model.train()
178
+ running_loss = 0.0
179
+ optimizer.zero_grad()
180
+
181
+ for batch_idx, batch in tqdm(enumerate(train_loader), total=len(train_loader)):
182
+ batch = {k: v.cuda(non_blocking=True) for k, v in batch.items()} # Transfer batch to GPU
183
+
184
+ with autocast():
185
+ loss = model.step(batch)
186
+
187
+ scaler.scale(loss).backward()
188
+
189
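+ # Optimizer (and LR scheduler) steps happen only every accumulation_steps mini-batches; gradients are summed, not averaged, across the accumulated batches.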
+ if (batch_idx + 1) % accumulation_steps == 0:
190
+ scaler.step(optimizer)
191
+ scaler.update()
192
+ optimizer.zero_grad()
193
+
194
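+ # Linear warmup for the first warmup_steps optimizer steps, then cosine annealing for the remainder.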
+ if scheduler.last_epoch < warmup_steps:
195
+ scheduler.step()
196
+ else:
197
+ cosine_scheduler.step()
198
+
199
+ running_loss += loss.item()
200
+
201
+ print(f"Epoch {epoch}: Training Loss = {running_loss / len(train_loader)}")
202
+
203
+ del train_loader, running_loss
204
+ gc.collect()
205
+ torch.cuda.empty_cache()
206
+
207
+ model.eval()
208
+ val_loss = 0.0
209
+ val_acc = 0.0
210
+ with torch.no_grad():
211
+ for batch in tqdm(val_loader, total=len(val_loader)):
212
+ batch = {k: v.cuda(non_blocking=True) for k, v in batch.items()}
213
+ val_loss_batch, val_acc_batch = model.step(batch, compute_acc=True)
214
+ val_loss += val_loss_batch.item()
215
+ val_acc += val_acc_batch  # accuracy is returned as a Python float
216
+
217
+ print(f"Epoch {epoch}: Val Loss = {val_loss / len(val_loader)}\tVal Acc = {val_acc / len(val_dataset)}")
218
+
219
+ if val_acc > max_val_acc:
220
+ max_val_acc = val_acc
221
+
222
+ global checkpoint_path
223
+ torch.save(model.state_dict(), os.path.join(checkpoint_path, f"epoch={epoch}_acc={round(val_acc / len(val_loader), 2)}"))
224
+
225
+
226
+
227
+ model = BinderGenerator(vocab_size=24, embed_dim=1280, num_heads=num_heads, num_layers=num_layers, lr=lr).to(device)
228
+ optimizer = AdamW(model.parameters(), lr=model.learning_rate, betas=(0.9, 0.95), weight_decay=1e-5)
229
+
230
+ total_steps = len(train_dataset) // (batch_size * accumulation_steps) * max_epochs  # optimizer steps per epoch x epochs (effective batch = batch_size * accumulation_steps)
231
+ warmup_steps = int(0.1 * total_steps)
232
+
233
+ scheduler = get_linear_schedule_with_warmup(
234
+ optimizer,
235
+ num_warmup_steps=warmup_steps,
236
+ num_training_steps=total_steps
237
+ )
238
+ cosine_scheduler = CosineAnnealingLR(optimizer, T_max=total_steps - warmup_steps, eta_min=0.1*lr)
239
+
240
+
241
+ train(model, optimizer, scheduler, cosine_scheduler, train_dataset, val_dataset, batch_size=batch_size, max_epochs=max_epochs, accumulation_steps=accumulation_steps)
finetune.py ADDED
@@ -0,0 +1,385 @@
1
+ import pdb
2
+ from pytorch_lightning.strategies import DDPStrategy
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch.utils.data import DataLoader, DistributedSampler, BatchSampler, Sampler
7
+ from datasets import load_from_disk
8
+ import pytorch_lightning as pl
9
+ from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, \
10
+ Timer, TQDMProgressBar, LearningRateMonitor, StochasticWeightAveraging, GradientAccumulationScheduler
11
+ from pytorch_lightning.loggers import WandbLogger
12
+ from torch.optim.lr_scheduler import _LRScheduler
13
+ from transformers.optimization import get_cosine_schedule_with_warmup
14
+ from argparse import ArgumentParser
15
+ import os
16
+ import uuid
17
+ import esm
18
+ import numpy as np
19
+ import torch.distributed as dist
20
+ from torch.nn.utils.rnn import pad_sequence
21
+ from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
22
+ # from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
23
+ from torch.optim import Adam, AdamW
24
+ from sklearn.metrics import roc_auc_score, f1_score, matthews_corrcoef
25
+ import torch_geometric.nn as pyg_nn
26
+ import gc
27
+ import math
28
+
29
+ # os.environ["TORCH_CPP_LOG_LEVEL"]="INFO"
30
+ # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
31
+ os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
32
+
33
+ vhse8_values = {
34
+ 'A': [0.15, -1.11, -1.35, -0.92, 0.02, -0.91, 0.36, -0.48],
35
+ 'R': [-1.47, 1.45, 1.24, 1.27, 1.55, 1.47, 1.30, 0.83],
36
+ 'N': [-0.99, 0.00, 0.69, -0.37, -0.55, 0.85, 0.73, -0.80],
37
+ 'D': [-1.15, 0.67, -0.41, -0.01, -2.68, 1.31, 0.03, 0.56],
38
+ 'C': [0.18, -1.67, -0.21, 0.00, 1.20, -1.61, -0.19, -0.41],
39
+ 'Q': [-0.96, 0.12, 0.18, 0.16, 0.09, 0.42, -0.20, -0.41],
40
+ 'E': [-1.18, 0.40, 0.10, 0.36, -2.16, -0.17, 0.91, 0.36],
41
+ 'G': [-0.20, -1.53, -2.63, 2.28, -0.53, -1.18, -1.34, 1.10],
42
+ 'H': [-0.43, -0.25, 0.37, 0.19, 0.51, 1.28, 0.93, 0.65],
43
+ 'I': [1.27, 0.14, 0.30, -1.80, 0.30, -1.61, -0.16, -0.13],
44
+ 'L': [1.36, 0.07, 0.26, -0.80, 0.22, -1.37, 0.08, -0.62],
45
+ 'K': [-1.17, 0.70, 0.80, 1.64, 0.67, 1.63, 0.13, -0.01],
46
+ 'M': [1.01, -0.53, 0.43, 0.00, 0.23, 0.10, -0.86, -0.68],
47
+ 'F': [1.52, 0.61, 0.95, -0.16, 0.25, 0.28, -1.33, -0.65],
48
+ 'P': [0.22, -0.17, -0.50, -0.05, 0.01, -1.34, 0.19, 3.56],
49
+ 'S': [-0.67, -0.86, -1.07, -0.41, -0.32, 0.27, -0.64, 0.11],
50
+ 'T': [-0.34, -0.51, -0.55, -1.06, 0.01, -0.01, -0.79, 0.39],
51
+ 'W': [1.50, 2.06, 1.79, 0.75, 0.75, 0.13, -1.06, -0.85],
52
+ 'Y': [0.61, 1.60, 1.17, 0.73, 0.53, 0.25, -0.96, -0.52],
53
+ 'V': [0.76, -0.92, 0.17, -1.91, 0.22, -1.40, -0.24, -0.03],
54
+ }
55
+
56
+ aa_to_idx = {'A': 5, 'R': 10, 'N': 17, 'D': 13, 'C': 23, 'Q': 16, 'E': 9, 'G': 6, 'H': 21, 'I': 12, 'L': 4, 'K': 15, 'M': 20, 'F': 18, 'P': 14, 'S': 8, 'T': 11, 'W': 22, 'Y': 19, 'V': 7}
57
+
58
+ vhse8_tensor = torch.zeros(24, 8)
59
+ for aa, values in vhse8_values.items():
60
+ aa_index = aa_to_idx[aa]
61
+ vhse8_tensor[aa_index] = torch.tensor(values)
62
+ vhse8_tensor.requires_grad = False
63
+
64
+
65
+ def collate_fn(batch):
66
+ # Unpack the batch
67
+ binders = []
68
+ mutants = []
69
+ wildtypes = []
70
+ affs = []
71
+ tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
72
+
73
+ for b in batch:
74
+ binder = torch.tensor(b['binder_input_ids']['input_ids'][1:-1])
75
+ mutant = torch.tensor(b['mutant_input_ids']['input_ids'][1:-1])
76
+ wildtype = torch.tensor(b['wildtype_input_ids']['input_ids'][1:-1])
77
+
78
+ if binder.dim() == 0 or binder.numel() == 0 or mutant.dim() == 0 or mutant.numel() == 0 or wildtype.dim() == 0 or wildtype.numel() == 0:
79
+ continue
80
+ binders.append(binder) # shape: 1*L1 -> L1
81
+ mutants.append(mutant) # shape: 1*L2 -> L2
82
+ wildtypes.append(wildtype) # shape: 1*L3 -> L3
83
+
84
+ affs.append(b['aff'])
85
+
86
+
87
+ # Collate the tensors using torch's pad_sequence
88
+ try:
89
+ binder_input_ids = torch.nn.utils.rnn.pad_sequence(binders, batch_first=True, padding_value=tokenizer.pad_token_id)
90
+
91
+ mutant_input_ids = torch.nn.utils.rnn.pad_sequence(mutants, batch_first=True, padding_value=tokenizer.pad_token_id)
92
+
93
+ wildtype_input_ids = torch.nn.utils.rnn.pad_sequence(wildtypes, batch_first=True, padding_value=tokenizer.pad_token_id)
94
+ except:
95
+ pdb.set_trace()
96
+
97
+ affs = torch.tensor(affs)
98
+ # Return the collated batch
99
+ return {
100
+ 'binder_input_ids': binder_input_ids.int(),
101
+ 'mutant_input_ids': mutant_input_ids.int(),
102
+ 'wildtype_input_ids': wildtype_input_ids.int(),
103
+ 'aff': affs
104
+ }
105
+
106
+
107
+ class CustomDataModule(pl.LightningDataModule):
108
+ def __init__(self, train_dataset, val_dataset, tokenizer, batch_size: int = 128):
109
+ super().__init__()
110
+ self.train_dataset = train_dataset
111
+ self.val_dataset = val_dataset
112
+ self.batch_size = batch_size
113
+ self.tokenizer = tokenizer
114
+ print(len(train_dataset))
115
+ print(len(val_dataset))
116
+
117
+ def train_dataloader(self):
118
+ # batch_sampler = LengthAwareDistributedSampler(self.train_dataset, 'mutant_tokens', self.batch_size)
119
+ return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, collate_fn=collate_fn,
120
+ num_workers=8, pin_memory=True)
121
+
122
+ def val_dataloader(self):
123
+ # batch_sampler = LengthAwareDistributedSampler(self.val_dataset, 'mutant_tokens', self.batch_size)
124
+ return DataLoader(self.val_dataset, batch_size=self.batch_size, collate_fn=collate_fn, num_workers=8,
125
+ pin_memory=True)
126
+
127
+ def setup(self, stage=None):
128
+ if stage == 'test' or stage is None:
129
+ pass
130
+
131
+
132
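+ # Linear warmup from base_lr to max_lr over warmup_steps, then cosine decay from max_lr to min_lr over the remaining steps.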
+ class CosineAnnealingWithWarmup(_LRScheduler):
133
+ def __init__(self, optimizer, warmup_steps, total_steps, base_lr, max_lr, min_lr, last_epoch=-1):
134
+ self.warmup_steps = warmup_steps
135
+ self.total_steps = total_steps
136
+ self.base_lr = base_lr
137
+ self.max_lr = max_lr
138
+ self.min_lr = min_lr
139
+ super(CosineAnnealingWithWarmup, self).__init__(optimizer, last_epoch)
140
+ print(f"SELF BASE LRS = {self.base_lrs}")
141
+
142
+ def get_lr(self):
143
+ if self.last_epoch < self.warmup_steps:
144
+ # Linear warmup phase from base_lr to max_lr
145
+ return [self.base_lr + (self.max_lr - self.base_lr) * (self.last_epoch / self.warmup_steps) for base_lr in self.base_lrs]
146
+
147
+ # Cosine annealing phase from max_lr to min_lr
148
+ progress = (self.last_epoch - self.warmup_steps) / (self.total_steps - self.warmup_steps)
149
+ cosine_decay = 0.5 * (1 + np.cos(np.pi * progress))
150
+ decayed_lr = self.min_lr + (self.max_lr - self.min_lr) * cosine_decay
151
+
152
+ return [decayed_lr for base_lr in self.base_lrs]
153
+
154
+
155
+ class muPPIt(pl.LightningModule):
156
+ def __init__(self, d_node, num_heads, dropout, margin, lr):
157
+ super(muPPIt, self).__init__()
158
+
159
+ self.esm, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D()
160
+ for param in self.esm.parameters():
161
+ param.requires_grad = False
162
+
163
+ self.attention = nn.MultiheadAttention(embed_dim=d_node, num_heads=num_heads)
164
+ self.layer_norm = nn.LayerNorm(d_node)
165
+
166
+ self.map = nn.Sequential(
167
+ nn.Linear(d_node, d_node // 2),
168
+ nn.SiLU(),
169
+ nn.Dropout(dropout),
170
+ nn.Linear(d_node // 2, 1)
171
+ )
172
+
173
+ self.margin = margin
174
+ self.learning_rate = lr
175
+
176
+ for layer in self.map:
177
+ if isinstance(layer, nn.Linear):
178
+ nn.init.kaiming_uniform_(layer.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
179
+ if layer.bias is not None:
180
+ nn.init.zeros_(layer.bias)
181
+
182
+ def forward(self, binder_tokens, wt_tokens, mut_tokens):
183
+ device = binder_tokens.device
184
+ global vhse8_tensor
185
+
186
+ vhse8_tensor = vhse8_tensor.to(device)
187
+
188
+ with torch.no_grad():
189
+ binder_pad_mask = (binder_tokens != self.alphabet.padding_idx).int()
190
+ binder_embed = self.esm(binder_tokens, repr_layers=[33], return_contacts=True)["representations"][33] * binder_pad_mask.unsqueeze(-1)
191
+ binder_vhse8 = vhse8_tensor[binder_tokens]
192
+ binder_embed = torch.concat([binder_embed, binder_vhse8], dim=-1)
193
+
194
+ mut_pad_mask = (mut_tokens != self.alphabet.padding_idx).int()
195
+ mut_embed = self.esm(mut_tokens, repr_layers=[33], return_contacts=True)["representations"][33] * mut_pad_mask.unsqueeze(-1)
196
+ mut_vhse8 = vhse8_tensor[mut_tokens]
197
+ mut_embed = torch.concat([mut_embed, mut_vhse8], dim=-1)
198
+
199
+ wt_pad_mask = (wt_tokens != self.alphabet.padding_idx).int()
200
+ wt_embed = self.esm(wt_tokens, repr_layers=[33], return_contacts=True)["representations"][33] * wt_pad_mask.unsqueeze(-1)
201
+ wt_vhse8 = vhse8_tensor[wt_tokens]
202
+ wt_embed = torch.concat([wt_embed, wt_vhse8], dim=-1)
203
+
204
+ binder_wt = torch.concat([binder_embed, wt_embed], dim=1)
205
+ binder_mut = torch.concat([binder_embed, mut_embed], dim=1)
206
+
207
+ binder_wt_attn, _ = self.attention(binder_wt, binder_wt, binder_wt)
208
+ binder_mut_attn, _ = self.attention(binder_mut, binder_mut, binder_mut)
209
+
210
+ binder_wt_attn = binder_wt + binder_wt_attn
211
+ binder_mut_attn = binder_mut + binder_mut_attn
212
+
213
+ binder_wt_attn = self.layer_norm(binder_wt_attn)
214
+ binder_mut_attn = self.layer_norm(binder_mut_attn)
215
+
216
+ mapped_binder_wt = self.map(binder_wt_attn).squeeze(-1) # B*(L1+L2)
217
+ mapped_binder_mut = self.map(binder_mut_attn).squeeze(-1) # B*(L1+L2)
218
+
219
+ # mean_binder_wt = torch.mean(mapped_binder_wt, dim=1)
220
+ # mean_binder_mut = torch.mean(mapped_binder_mut, dim=1)
221
+
222
+ distance = torch.sqrt(torch.sum((mapped_binder_wt - mapped_binder_mut) ** 2, dim=-1))
223
+ return distance
224
+
225
+
226
+ def training_step(self, batch, batch_idx):
227
+ opt = self.optimizers()
228
+ lr = opt.param_groups[0]['lr']
229
+ self.log('learning_rate', lr, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
230
+
231
+ binder_tokens = batch['binder_input_ids'].to(self.device)
232
+ mut_tokens = batch['mutant_input_ids'].to(self.device)
233
+ wt_tokens = batch['wildtype_input_ids'].to(self.device)
234
+ aff = batch['aff'].to(self.device)
235
+
236
+ distance = self.forward(binder_tokens, wt_tokens, mut_tokens)
237
+
238
+ # pdb.set_trace()
239
+ # loss = torch.clamp(self.margin * aff - distance, min=0)
240
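+ # Asymmetric hinge "band" loss: penalize distances outside [margin*aff, margin*(aff+1)], with violations of the upper bound weighted 5x below.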
+ upper_loss = F.relu(distance - self.margin *(aff + 1)) # let distance < aff + 1
241
+ lower_loss = F.relu(self.margin * aff - distance) # let distance > aff
242
+
243
+ loss = 5 * upper_loss + lower_loss
244
+
245
+ self.log('train_loss', loss.mean().item(), on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
246
+ return loss.mean()
247
+
248
+ def validation_step(self, batch, batch_idx):
249
+ binder_tokens = batch['binder_input_ids'].to(self.device)
250
+ mut_tokens = batch['mutant_input_ids'].to(self.device)
251
+ wt_tokens = batch['wildtype_input_ids'].to(self.device)
252
+ aff = batch['aff'].to(self.device)
253
+
254
+ distance = self.forward(binder_tokens, wt_tokens, mut_tokens)
255
+
256
+ # pdb.set_trace()
257
+
258
+ # loss = torch.clamp(self.margin * aff - distance, min=0)
259
+ # accuracy = torch.sum(distance >= self.margin * aff) / aff.shape[0]
260
+ upper_loss = F.relu(distance - self.margin * (aff + 1))
261
+ lower_loss = F.relu(self.margin * aff - distance)
262
+
263
+ loss = 5 * upper_loss + lower_loss
264
+
265
+ accuracy = torch.sum(torch.logical_and(torch.ge(distance, self.margin * aff), torch.le(distance, self.margin *(aff + 1)))) / aff.shape[0]
266
+
267
+ self.log('val_loss', loss.mean().item(), on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
268
+ self.log('val_acc', accuracy.item(), on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
269
+
270
+ def configure_optimizers(self):
271
+ optimizer = AdamW(self.parameters(), lr=self.learning_rate, betas=(0.9, 0.95))
272
+
273
+ base_lr = 0.1 * self.learning_rate
274
+ max_lr = self.learning_rate
275
+ min_lr = 0.1 * self.learning_rate
276
+
277
+ schedulers = CosineAnnealingWithWarmup(optimizer, warmup_steps=119, total_steps=1188,
278
+ base_lr=base_lr, max_lr=max_lr, min_lr=min_lr) # warmup_steps=3193, total_steps=31926
279
+
280
+ lr_schedulers = {
281
+ "scheduler": schedulers,
282
+ "name": 'learning_rate_logs',
283
+ "interval": 'step', # The scheduler updates the learning rate at every step (not epoch)
284
+ 'frequency': 1 # The scheduler updates the learning rate after every batch
285
+ }
286
+ return [optimizer], [lr_schedulers]
287
+
288
+ def on_train_epoch_end(self):
+ # Free cached GPU memory between epochs.
+ gc.collect()
+ torch.cuda.empty_cache()
292
+
293
+ def load_weights(self, checkpoint_path):
294
+ checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
295
+
296
+ state_dict = checkpoint['state_dict']
297
+
298
+ self.load_state_dict(state_dict, strict=True)
299
+
300
+
301
+ def main(args):
302
+ print(args)
303
+ dist.init_process_group(backend='nccl')
304
+
305
+ train_dataset = load_from_disk('/home/tc415/muPPIt_embedding/dataset/train/affinity_embedding_skempi') #408643
306
+ val_dataset = load_from_disk('/home/tc415/muPPIt_embedding/dataset/val/affinity_embedding_skempi')
307
+ # val_dataset = None
308
+ tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
309
+
310
+ data_module = CustomDataModule(train_dataset, val_dataset, tokenizer=tokenizer, batch_size=args.batch_size)
311
+
312
+ print(f"Loading Pre-trained Model from {args.sm}")
+ model = muPPIt.load_from_checkpoint(args.sm, d_node=args.d_node, num_heads=args.num_heads, dropout=args.dropout, margin=args.margin, lr=args.lr)
315
+
316
+ run_id = str(uuid.uuid4())
317
+
318
+ logger = WandbLogger(project=f"muppit_embedding",
319
+ # name="debug",
320
+ name=f"affinity_lr={args.lr}_gradclip={args.grad_clip}_margin={args.margin}",
321
+ job_type='model-training',
322
+ id=run_id)
323
+
324
+ print(f"Saving to {args.output_file}")
325
+
326
+ checkpoint_callback = ModelCheckpoint(
327
+ monitor='val_acc',
328
+ # monitor='val_loss',
329
+ dirpath=args.output_file,
330
+ # filename='model-{epoch:02d}-{val_loss:.2f}',
331
+ filename='model-{epoch:02d}-{val_acc:.2f}',
332
+ # filename='muppit',
333
+ save_top_k=-1,
334
+ mode='max',
335
+ # mode='min',
336
+ # every_n_train_steps=1000,
337
+ # save_on_train_epoch_end=False
338
+ )
339
+
340
+ early_stopping_callback = EarlyStopping(
341
+ # monitor='val_acc',
342
+ monitor='val_loss',
343
+ patience=10,
344
+ verbose=True,
345
+ # mode='max',
346
+ mode='min',
347
+ )
348
+
349
+ accumulator = GradientAccumulationScheduler(scheduling={0: 4})
350
+
351
+ trainer = pl.Trainer(
352
+ max_epochs=args.max_epochs,
353
+ accelerator='gpu',
354
+ strategy='ddp_find_unused_parameters_true',
355
+ precision='bf16',
356
+ # logger=logger,
357
+ devices=[0,1],
358
+ callbacks=[checkpoint_callback, accumulator],
359
+ gradient_clip_val=args.grad_clip,
360
+ # val_check_interval=100,
361
+ )
362
+
363
+ trainer.fit(model, datamodule=data_module)
364
+
365
+ best_model_path = checkpoint_callback.best_model_path
366
+ print(best_model_path)
367
+
368
+
369
+ if __name__ == "__main__":
370
+ parser = ArgumentParser()
371
+
372
+ parser.add_argument("-o", dest="output_file", help="File for output of model parameters", required=True, type=str)
373
+ parser.add_argument("-lr", type=float, default=1e-3)
374
+ parser.add_argument("-batch_size", type=int, default=2, help="Batch size")
375
+ parser.add_argument("-grad_clip", type=float, default=0.5)
376
+ parser.add_argument("-margin", type=float, default=0.5)
377
+ parser.add_argument("-max_epochs", type=int, default=30)
378
+ parser.add_argument("-d_node", type=int, default=1024, help="Node Representation Dimension")
379
+ parser.add_argument("-num_heads", type=int, default=4)
380
+ parser.add_argument("-dropout", type=float, default=0.1)
381
+ parser.add_argument("-sm", type=str, default='/home/tc415/muPPIt_embedding/checkpoints/train_10/model-epoch=15-val_acc=0.62.ckpt')
382
+
383
+ args = parser.parse_args()
384
+
385
+ main(args)
muppit/.gitignore ADDED
@@ -0,0 +1,7 @@
1
+ .idea/
2
+ .DS_Store
3
+ .ipynb_checkpoints/
4
+ __pycache__/
5
+ .hf_cache
6
+ outputs/
7
+ watch_folder/
muppit/LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
muppit/README.md ADDED
@@ -0,0 +1,250 @@
1
+ # Simple Guidance Mechanisms for Discrete Diffusion Models
2
+
3
+ [![arXiv](https://img.shields.io/badge/arXiv-2412.10193-red.svg)](https://arxiv.org/abs/2412.10193)
4
+ [![deploy](https://img.shields.io/badge/Blog%20%20-8A2BE2)](https://discrete-diffusion-guidance.github.io/)
5
+ [![deploy](https://img.shields.io/badge/Huggingface%20-UDLM%20-blue)](https://huggingface.co/collections/kuleshov-group/udlm-675e63ab42bc757093099e1b)
6
+
7
+ <p align="center">
8
+ <img src="https://discrete-diffusion-guidance.github.io/static/images/udlm.gif" alt="graphical abstract" width="450"/>
9
+ </p>
10
+
11
+ This repository contains code for reproducing experiments in the paper [Simple Guidance Mechanisms for Discrete Diffusion Models](https://arxiv.org/abs/2412.10193)
12
+
13
+ We also share [trained models](https://huggingface.co/collections/kuleshov-group/udlm-675e63ab42bc757093099e1b) on HuggingFace 🤗 and support integration with these models.
+ See the "[Using HuggingFace Models](#using-huggingface-models)" section below.
15
+
16
+ ## Code Organization
17
+ <a name="code-organization"></a>
18
+ 1. ```main.py```: Routines for training (language models and classifiers)
19
+ 2. ```noise_schedule.py```: Noise schedules
20
+ 3. ```diffusion.py```: Forward/reverse diffusion
21
+ - Absorbing state / uniform noise diffusion
22
+ - AR
23
+ 4. ```dataloader.py```: Dataloaders
24
+ - For Discretized CIFAR10 and the Species10 datasets we use custom dataset classes defined in ```custom_datasets/```
25
+ 5. ```utils.py```: LR scheduler, logging, `fsspec` handling
26
+ 6. ```models/```: Denoising network architectures.
27
+ 7. ```configs/```: Config files for datasets/denoising networks/noise schedules/LR schedules
28
+ 8. ```scripts/```: Shell scripts for training/evaluation
29
+ 9. ```guidance_eval/```: Guidance evaluation scripts
30
+
31
+
32
+ ### Implemented Decoding Mechanisms
33
+ <a name="implemented-decoding"></a>
34
+ In [`diffusion.py`](./diffusion.py),
35
+ we define baseline and proposed decoding mechanisms for guidance.
36
+ These decoding schemes can be controlled via the hydra config with the `guidance` field.
37
+ For example, to use the proposed D-CFG guidance mechanism,
38
+ set `guidance=cfg` in the config file and optionally set the `guidance.gamma` parameter to control the strength of the guidance signal.
39
+
40
+ The implemented decoding methods are as follows:
41
+ - AR (Baseline):
42
+ - Standard decoding (i.e., no-guidance); set `guidance=null`
43
+ - Classifier-free guidance (D-CFG); set `guidance=cfg`
44
+ - Classifier-based guidance using [FUDGE](https://arxiv.org/abs/2104.05218) (set `guidance=fudge`) and using [PPLM](https://arxiv.org/abs/1912.02164) (set `guidance=pplm`)
45
+ - Diffusion:
46
+ - Standard decoding (i.e., no guidance); set `guidance=null`
47
+ - Classifier-free guidance (D-CFG); set `guidance=cfg`
48
+ - Classifier-based guidance (D-CBG); set `guidance=cbg`
49
+ - Classifier-based (baseline) method of [NOS](https://arxiv.org/abs/2305.20009); set `guidance=nos`
50
+
51
+ ### Implemented Generative Models
52
+ <a name="implemented-models"></a>
53
+ The three modeling parameterizations
54
+ we explore in this work are:
55
+ 1. Autoregressive (AR) Models
56
+ 2. Masked Diffusion Language Models (MDLM)
57
+ 3. Uniform Diffusion Language Models (UDLM)
58
+
59
+ The `config` files can be used
60
+ to specify which of these parameterizations to use.
61
+ Below we detail which config parameters correspond to which model.
62
+
63
+ **AR**
64
+ ```bash
65
+ diffusion="absorbing_state" # AR models can be thought of as a special case of abosrbing state diffusion models
66
+ parameterization="ar"
67
+ T=0 # N/A for AR models, this is a placeholder
68
+ time_conditioning=False # AR models are not conditioned on time
69
+ zero_recon_loss=False # N/A for this model
70
+ ```
71
+
72
+ **MDLM**
73
+ ```bash
74
+ diffusion="absorbing_state"
75
+ parameterization="subs" # See MDLM paper for details: https://arxiv.org/abs/2406.07524
76
+ T=0 # Indicates continuous-time, e.g. T --> infinity
77
+ time_conditioning=False # MDLM not conditioned on time
78
+ zero_recon_loss=False # N/A for this model
79
+ ```
80
+
81
+ **UDLM**
82
+ ```bash
83
+ diffusion="uniform"
84
+ parameterization="d3pm" # Indicates that we explicitly compute KL on posteriors
85
+ T=0 # Indicates continuous-time, e.g. T --> infinity
86
+ time_conditioning=True # UDLM is conditioned on time
87
+ zero_recon_loss=True # In continuous time, recon loss evaluates to zero
88
+ ```
89
+
90
+ ## Getting started in this repository
91
+ <a name="getting-started"></a>
92
+
93
+ To get started, create a conda environment containing the required dependencies.
94
+
95
+ ```bash
96
+ conda env create -f requirements.yaml
97
+ conda activate discdiff
98
+ ```
99
+
100
+ Create the following directories to store saved models and slurm logs:
101
+ ```bash
102
+ mkdir outputs
103
+ mkdir watch_folder
104
+ ```
105
+
106
+ We rely on `wandb` integration
107
+ to log experiments and eval curves.
108
+
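+ If you prefer not to send logs to the wandb servers, one generic option (standard `wandb` behavior, not something specific to this repo) is to switch to offline mode before launching training:
+ ```python
+ import os
+
+ # Offline mode writes logs locally; alternatively, run `wandb login` once to enable online logging.
+ os.environ.setdefault("WANDB_MODE", "offline")
+ ```
+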
109
+ ## Reproducing Experiments
110
+ <a name="reproducing-experiments"></a>
111
+
112
+ Below, we describe the steps required for reproducing the experiments in the paper.
113
+ Throughout, the main entry point for running experiments is the [`main.py`](./main.py) script.
114
+ We also provide sample `slurm` scripts for launching pre-training and evaluation experiments in the [`scripts/`](./scripts) directory.
115
+
116
+
117
+ ### Language Modeling Experiments
118
+ <a name="lm_training"></a>
119
+ To reproduce the language modeling results, please refer to the following shell scripts in the [`scripts/`](./scripts) directory:
120
+ - Species10: [`train_ten_species_guidance.sh`](./scripts/train_ten_species_guidance.sh)
121
+ - QM9: [`train_qm9_no-guidance.sh`](./scripts/train_qm9_no-guidance.sh)
122
+ - CIFAR10: [`train_cifar10_unet_guidance.sh`](./scripts/train_cifar10_unet_guidance.sh)
123
+ - text8: [`train_text8.sh`](./scripts/train_text8.sh)
124
+ - Amazon Polarity: [`train_amazon_polarity.sh`](./scripts/train_amazon_polarity.sh)
125
+ - LM1B: [`train_lm1b.sh`](./scripts/train_lm1b.sh)
126
+
127
+ Each script contains a comment detailing the usage.
128
+ For example, to train either an AR,
129
+ MDLM, or UDLM model on the `text8` dataset, use the following command:
130
+ ```bash
131
+ cd scripts/
132
+ MODEL=<ar|mdlm|udlm>
133
+ sbatch \
134
+ --export=ALL,MODEL=${MODEL} \
135
+ --job-name=train_text8_${MODEL} \
136
+ train_text8.sh
137
+ ```
138
+ ### Guidance Training
139
+ <a name="guidance-training"></a>
140
+ #### Classifier-Free
141
+ <a name="guidance-training-cfg"></a>
142
+ For classifier-free guidance we require training models
143
+ that can condition on the class label
144
+ to model conditional distributions,
145
+ and we randomly mask out the signal,
146
+ replacing it with a dummy value of `num_classes + 1`, to simulate an unconditional model.
147
+ Refer to the shell scripts with the `_guidance` suffix
148
+ to train these models for CIFAR10,
149
+ QM9, and Species10 datasets.
150
+ For QM9, we have two experiments,
151
+ one where we condition on the drug-likeness
152
+ (`qed`)
153
+ of the molecules and another
154
+ where we condition on the ring counts (`ring_count`).
155
+
156
+ #### Classifier-Based
157
+ <a name="guidance-training-cbg"></a>
158
+ For classifier-based guidance,
159
+ we need to train a classifier on the noisy latent samples.
160
+ Refer to the following shell scripts
161
+ to train these classifiers:
162
+ - [FUDGE](https://arxiv.org/abs/2104.05218) (AR guidance): [`train_qm9_fudge_classifier.sh`](./scripts/train_qm9_fudge_classifier.sh)
163
+ - D-CBG (diffusion guidance): [`train_qm9_classifier.sh`](./scripts/train_qm9_classifier.sh)
164
+
165
+ ##### PPLM / NOS baselines
166
+ An alternative classifier-based guidance mechanism to D-CBG is that of [PPLM](https://arxiv.org/abs/1912.02164)
167
+ (which was adapted for diffusion models in [NOS](https://arxiv.org/abs/2305.20009)).
168
+ To train these classifiers,
169
+ refer to the following shell script:
170
+ [`train_qm9_pplm_classifier.sh`](./scripts/train_qm9_pplm_classifier.sh)
171
+ (for both PPLM and NOS classifiers).
172
+
173
+ ### Guidance Evaluation
174
+ <a name="guidance-eval"></a>
175
+ To evaluate guidance mechanisms, we load trained models
176
+ (and classifiers, if applicable)
177
+ and generate some number of samples
178
+ for which we compute "quality" metrics
179
+ (e.g., validity/novelty in the QM9 experiments)
180
+ and control-label satisfaction (e.g., the mean value of the property of interest among novel generated molecules in the QM9 experiments).
181
+
182
+ The scripts for these evaluations can be found in the [`guidance_eval/`](./guidance_eval) directory.
183
+ To run these evaluations, please refer to the following shell scripts:
184
+ - QM9: [`eval_qm9_guidance.sh`](./guidance_eval/eval_qm9_guidance.sh)
185
+ - Species10: [`eval_ten_species_guidance.sh`](./guidance_eval/eval_ten_species_guidance.sh)
186
+ - For this dataset, we also evaluate the accuracy of a HyenaDNA classifier on correctly classifying generated sequences.
187
+ This model can be trained using [`train_ten_species_eval_classifier.sh`](./scripts/train_ten_species_eval_classifier.sh).
188
+ - To see how this trained evaluation classifier performs on the validation set of the original data, use this notebook: [`eval_hyenadna_classifier.ipynb`](./notebooks/eval_hyenadna_classifier.ipynb).
189
+
190
+ In the paper,
191
+ we performed an extensive hyperparameter sweep for our proposed guidance mechanisms and for baselines.
192
+ The shell scripts can be used
193
+ to reproduce these experiments,
194
+ e.g., for the D-CFG experiments on QM9:
195
+ ```bash
196
+ export MODEL=<ar|mdlm|udlm>
197
+ export PROP=<qed|ring_count>
198
+ export GUIDANCE=cfg
199
+ for GAMMA in $(seq 1 5); do
200
+ sbatch \
201
+ --export=ALL,MODEL=${MODEL},PROP=${PROP},GUIDANCE=${GUIDANCE},GAMMA=${GAMMA} \
202
+ --job-name=eval_qm9_${GUIDANCE}_${PROP}_${MODEL}_GAMMA-${GAMMA} \
203
+ eval_qm9_guidance.sh
204
+ done
205
+ ```
206
+
207
+ Once each evaluation run is complete,
208
+ a `.csv` file
209
+ containing the results is saved in the run directory of the trained generative model.
210
+
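+ For downstream analysis, these per-run `.csv` files can be aggregated with standard tooling. A minimal sketch (the glob pattern and file layout here are assumptions; adjust to your actual run directories):
+ ```python
+ import glob
+ import pandas as pd
+
+ # Hypothetical path pattern; point it at wherever your evaluation runs write their .csv files.
+ frames = [pd.read_csv(path) for path in glob.glob("outputs/**/*.csv", recursive=True)]
+ results = pd.concat(frames, ignore_index=True)
+ print(results.head())
+ ```
+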
211
+ ## Using HuggingFace Models
212
+ <a name="hf_models"></a>
213
+ We provide pre-trained models on HuggingFace 🤗:
214
+ - UDLM trained on LM1B: [kuleshov-group/udlm-lm1b](https://huggingface.co/kuleshov-group/udlm-lm1b)
215
+ - UDLM trained on QM9: [kuleshov-group/udlm-qm9](https://huggingface.co/kuleshov-group/udlm-qm9)
216
+ - Note: this model was trained without guidance and can be used with classifier-free guidance.
217
+
218
+ Please see the README pages for these models on HuggingFace or our paper for more details about the training of these models.
219
+
220
+ To use these models, you can load them using the HuggingFace API, e.g.,
221
+ ```python
222
+ from transformers import AutoModelForMaskedLM
223
+
224
+ model = AutoModelForMaskedLM.from_pretrained("kuleshov-group/udlm-lm1b")
225
+ ```
226
+
227
+ To use these models in our repository, set the following `config` parameters:
228
+ ```bash
229
+ backbone="hf_dit"
230
+ model="hf"
231
+ model.pretrained_model_name_or_path="kuleshov-group/udlm-lm1b" # or "kuleshov-group/udlm-qm9"
232
+ ```
233
+
234
+ ## Acknowledgements
235
+ <a name="acknowledgements"></a>
236
+ This repository was built off of [MDLM](https://github.com/kuleshov-group/mdlm),
237
+ which in turn used [SEDD](https://github.com/louaaron/Score-Entropy-Discrete-Diffusion).
238
+ Our code implementation of D-CBG is adapted from Nisonoff et al.'s [repo](https://github.com/hnisonoff/discrete_guidance).
239
+
240
+ ## Citation
241
+ <a name="citation"></a>
242
+ ```
243
+ @article{
244
+ schiff2024discreteguidance,
245
+ title={Simple Guidance Mechanisms for Discrete Diffusion Models},
246
+ author={Schiff, Yair and Sahoo, Subham Sekhar and Phung, Hao and Wang, Guanghan and Boshar, Sam and Dalla-torre, Hugo and de Almeida, Bernardo P and Rush, Alexander and Pierrot, Thomas and Kuleshov, Volodymyr},
247
+ journal={arXiv preprint arXiv:2412.10193},
248
+ year={2024}
249
+ }
250
+ ```
muppit/__pycache__/classifier.cpython-310.pyc ADDED
Binary file (13.8 kB). View file
 
muppit/__pycache__/dataloader.cpython-310.pyc ADDED
Binary file (18 kB). View file
 
muppit/__pycache__/diffusion.cpython-310.pyc ADDED
Binary file (33.4 kB). View file
 
muppit/__pycache__/noise_schedule.cpython-310.pyc ADDED
Binary file (6.19 kB). View file
 
muppit/__pycache__/utils.cpython-310.pyc ADDED
Binary file (2.77 kB). View file
 
muppit/classifier.py ADDED
@@ -0,0 +1,490 @@
1
+ import itertools
2
+ import typing
3
+
4
+ import hydra.utils
5
+ import lightning as L
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import torchmetrics
9
+ import transformers
10
+
11
+ import dataloader
12
+ import models.dit
13
+ import noise_schedule
14
+
15
+
16
+ class MicroAveragingMetric(torchmetrics.Metric):
17
+ """Micro-averaging metric.
18
+
19
+ Adapted from https://github.com/HazyResearch/hyena-dna/blob/main/src/tasks/metrics.py#L12
20
+ """
21
+
22
+ def __init__(self, class_idx: typing.Optional[int] = 1,
23
+ dist_sync_on_step=False):
24
+ super().__init__(dist_sync_on_step=dist_sync_on_step)
25
+ self.class_idx = torch.tensor(class_idx) \
26
+ if class_idx is not None else None
27
+ self.add_state("numerator", default=torch.tensor(0.0),
28
+ dist_reduce_fx="sum")
29
+ self.add_state("denominator", default=torch.tensor(0.0),
30
+ dist_reduce_fx="sum")
31
+
32
+ def _update(
33
+ self, numerator, denominator, preds, y) -> tuple:
34
+ raise NotImplementedError
35
+
36
+ def update(self, logits: torch.Tensor, y: torch.Tensor):
37
+ # update metric states
38
+ preds = torch.argmax(logits, dim=-1)
39
+ y = y.view(-1)
40
+ assert preds.shape == y.shape, \
41
+ f"preds shape {preds.shape} != y shape {y.shape}"
42
+ self.numerator, self.denominator = self._update(
43
+ self.numerator, self.denominator, preds, y)
44
+
45
+ def compute(self):
46
+ # compute final result
47
+ value = self.numerator.float() / self.denominator \
48
+ if self.denominator.item() > 0. else torch.tensor(0.0)
49
+ return value
50
+
51
+ def reset(self):
52
+ self.numerator = torch.tensor(0.0).to(self.device)
53
+ self.denominator = torch.tensor(0.0).to(self.device)
54
+
55
+
56
+ class CrossEntropy(MicroAveragingMetric):
57
+ """Calculates cross-entropy loss."""
58
+ def _update(
59
+ self, numerator, denominator, logits, y) -> tuple:
60
+ with torch.no_grad():
61
+ numerator += F.cross_entropy(
62
+ logits.view(-1, logits.size(-1)),
63
+ y.view(-1),
64
+ ignore_index=-100,
65
+ reduction='sum')
66
+ denominator += y.numel()
67
+ return numerator, denominator
68
+
69
+ # Overrides parent class to use logits and not (argmax) preds
70
+ def update(self, logits: torch.Tensor, y: torch.Tensor):
71
+ y = y.view(-1)
72
+ self.numerator, self.denominator = self._update(
73
+ self.numerator, self.denominator, logits, y)
74
+
75
+
76
+ class Accuracy(MicroAveragingMetric):
77
+ """Calculates accuracy.
78
+
79
+ Can be used to calculate accuracy per class.
80
+ Copied from:
81
+ https://github.com/HazyResearch/hyena-dna/blob/main/src/tasks/metrics.py
82
+ """
83
+
84
+ def _update(
85
+ self, numerator, denominator, preds, y) -> tuple:
86
+ if self.class_idx is None:
87
+ numerator += (preds == y).sum()
88
+ denominator += y.numel()
89
+ else:
90
+ class_idx = self.class_idx
91
+ relevant_idxs = (y == class_idx)
92
+ numerator += (preds[relevant_idxs] == class_idx).sum()
93
+ denominator += relevant_idxs.sum()
94
+ relevant_idxs = (y != class_idx)
95
+ numerator += (preds[relevant_idxs] != class_idx).sum()
96
+ denominator += relevant_idxs.sum()
97
+ return numerator, denominator
98
+
99
+
100
+ class Precision(MicroAveragingMetric):
101
+ """Calculates precision.
102
+
103
+ Can be used to calculate precision per class.
104
+ Adapted from:
105
+ https://github.com/HazyResearch/hyena-dna/blob/main/src/tasks/metrics.py
106
+ """
107
+
108
+ def _update(self, numerator, denominator, preds, y) -> tuple:
109
+ class_idx = self.class_idx
110
+ relevant_idxs = (preds == class_idx)
111
+ numerator += (y[relevant_idxs] == class_idx).sum()
112
+ denominator += relevant_idxs.sum()
113
+ return numerator, denominator
114
+
115
+
116
+ class Recall(MicroAveragingMetric):
117
+ """Calculate recall.
118
+
119
+ Can be used to calculate recall per class.
120
+ Adapted from:
121
+ https://github.com/HazyResearch/hyena-dna/blob/main/src/tasks/metrics.py
122
+ """
123
+
124
+ def _update(self, numerator, denominator, preds, y) -> tuple:
125
+ class_idx = self.class_idx
126
+ relevant_idxs = (y == class_idx)
127
+ numerator += (preds[relevant_idxs] == class_idx).sum()
128
+ denominator += relevant_idxs.sum()
129
+ return numerator, denominator
130
+
131
+
132
+ class Classifier(L.LightningModule):
133
+ def __init__(
134
+ self,
135
+ config,
136
+ tokenizer: transformers.PreTrainedTokenizer,
137
+ pretrained_backbone: typing.Optional[torch.nn.Module] = None):
138
+ super().__init__()
139
+ self.save_hyperparameters(ignore=['pretrained_backbone'])
140
+ self.config = config
141
+
142
+ # This param indicates whether this model will be used
143
+ # for guidance (False) or only evaluation (True).
144
+ self.is_eval_classifier = getattr(
145
+ config, 'is_eval_classifier', False)
146
+
147
+ self.tokenizer = tokenizer
148
+ self.vocab_size = tokenizer.vocab_size
149
+ self.antithetic_sampling = config.training.antithetic_sampling
150
+ self.importance_sampling = config.training.importance_sampling
151
+ self.change_of_variables = config.training.change_of_variables
152
+ if (not hasattr(self.tokenizer, 'mask_token')
153
+ or self.tokenizer.mask_token is None):
154
+ self.mask_index = self.vocab_size
155
+ self.vocab_size += 1
156
+ else:
157
+ self.mask_index = self.tokenizer.mask_token_id
158
+
159
+ if config.classifier_backbone == 'dit':
160
+ self.classifier_model = models.dit.DITClassifier(
161
+ self.config, vocab_size=self.vocab_size)
162
+ elif self.config.classifier_backbone == 'dimamba':
163
+ self.classifier_model = models.dimamba.DiMambaClassifier(
164
+ self.config, vocab_size=self.vocab_size,
165
+ pad_token_id=self.tokenizer.pad_token_id)
166
+ elif config.classifier_backbone == 'hyenadna':
167
+ hyena_config = transformers.AutoConfig.from_pretrained(
168
+ config.classifier_model.hyena_model_name_or_path,
169
+ n_layer=config.classifier_model.n_layer,
170
+ trust_remote_code=True
171
+ )
172
+ self.classifier_model = transformers.AutoModelForSequenceClassification.from_config(
173
+ hyena_config,
174
+ pretrained=False,
175
+ num_labels=config.data.num_classes,
176
+ problem_type='single_label_classification',
177
+ trust_remote_code=True
178
+ )
179
+ else:
180
+ raise NotImplementedError(
181
+ f"Classifier backbone "
182
+ f"{self.config.classifier_backbone} not "
183
+ f"implemented.")
184
+ if pretrained_backbone is not None: # For PPLM / NOS
185
+ self.classifier_model.load_pretrained_encoder(
186
+ pretrained_backbone)
187
+ # Metrics are automatically reset at end of epoch
188
+ metrics = torchmetrics.MetricCollection({
189
+ 'cross_entropy': CrossEntropy(),
190
+ 'accuracy': Accuracy(class_idx=None),
191
+ })
192
+ if config.data.num_classes > 2:
193
+ for c in range(config.data.num_classes):
194
+ metrics.add_metrics(
195
+ {f"accuracy_class{c}": Accuracy(class_idx=c),
196
+ f"precision_class{c}": Precision(class_idx=c),
197
+ f"recall_class{c}": Recall(class_idx=c)})
198
+ else:
199
+ metrics.add_metrics(
200
+ {'precision': Precision(class_idx=1),
201
+ 'recall': Recall(class_idx=1)})
202
+ metrics.set_dtype(torch.float64)
203
+ self.train_metrics = metrics.clone(prefix='train/')
204
+ self.valid_metrics = metrics.clone(prefix='val/')
205
+
206
+ self.T = config.T
207
+ self.noise = noise_schedule.get_noise(config,
208
+ dtype=self.dtype)
209
+ self.sampling_eps = config.training.sampling_eps
210
+ self.lr = config.optim.lr
211
+ self.time_conditioning = config.time_conditioning
212
+ self.fast_forward_epochs = None
213
+ self.fast_forward_batches = None
214
+
215
+ def on_load_checkpoint(self, checkpoint):
216
+ # Copied from:
217
+ # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py#L41
218
+ self.fast_forward_epochs = checkpoint['loops'][
219
+ 'fit_loop']['epoch_progress']['current']['completed']
220
+ self.fast_forward_batches = checkpoint['loops'][
221
+ 'fit_loop']['epoch_loop.batch_progress'][
222
+ 'current']['completed']
223
+
224
+ def on_save_checkpoint(self, checkpoint):
225
+ # Copied from:
226
+ # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/tasks/seq.py
227
+ # ['epoch_loop.batch_progress']['total']['completed'] is
228
+ # 1 iteration behind, so we're using the optimizer's
229
+ # progress.
230
+ checkpoint['loops']['fit_loop'][
231
+ 'epoch_loop.batch_progress']['total'][
232
+ 'completed'] = checkpoint['loops']['fit_loop'][
233
+ 'epoch_loop.automatic_optimization.optim_progress'][
234
+ 'optimizer']['step']['total'][
235
+ 'completed'] * self.trainer.accumulate_grad_batches
236
+ checkpoint['loops']['fit_loop'][
237
+ 'epoch_loop.batch_progress']['current'][
238
+ 'completed'] = checkpoint['loops']['fit_loop'][
239
+ 'epoch_loop.automatic_optimization.optim_progress'][
240
+ 'optimizer']['step']['current'][
241
+ 'completed'] * self.trainer.accumulate_grad_batches
242
+ # _batches_that_stepped tracks the number of global
243
+ # steps, not the number of local steps, so we don't
244
+ # multiply with self.trainer.accumulate_grad_batches
245
+ # here.
246
+ checkpoint['loops']['fit_loop'][
247
+ 'epoch_loop.state_dict'][
248
+ '_batches_that_stepped'] = \
249
+ checkpoint['loops']['fit_loop'][
250
+ 'epoch_loop.automatic_optimization.optim_progress'][
251
+ 'optimizer']['step']['total']['completed']
252
+ if 'sampler' not in checkpoint.keys():
253
+ checkpoint['sampler'] = {}
254
+ if hasattr(self.trainer.train_dataloader.sampler,
255
+ 'state_dict'):
256
+ sampler_state_dict = self.trainer. \
257
+ train_dataloader.sampler.state_dict()
258
+ checkpoint['sampler'][
259
+ 'random_state'] = sampler_state_dict.get(
260
+ 'random_state', None)
261
+ else:
262
+ checkpoint['sampler']['random_state'] = None
263
+
264
+ def on_train_start(self):
265
+ # Adapted from:
266
+ # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py
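+ # Swap the default samplers for fault-tolerant ones so that, when resuming from a
+ # checkpoint, the dataloaders can fast-forward to the recorded epoch/batch.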
267
+ distributed = (
268
+ self.trainer._accelerator_connector.use_distributed_sampler
269
+ and self.trainer._accelerator_connector.is_distributed)
270
+ if distributed:
271
+ sampler_cls = dataloader.FaultTolerantDistributedSampler
272
+ else:
273
+ sampler_cls = dataloader.RandomFaultTolerantSampler
274
+ updated_dls = []
275
+ for dl in self.trainer.fit_loop._combined_loader.flattened:
276
+ if hasattr(dl.sampler, 'shuffle'):
277
+ dl_sampler = sampler_cls(
278
+ dl.dataset, shuffle=dl.sampler.shuffle)
279
+ else:
280
+ dl_sampler = sampler_cls(dl.dataset)
281
+ if (distributed
282
+ and self.fast_forward_epochs is not None
283
+ and self.fast_forward_batches is not None):
284
+ dl_sampler.load_state_dict({
285
+ 'epoch': self.fast_forward_epochs,
286
+ 'counter': (self.fast_forward_batches
287
+ * self.config.loader.batch_size)})
288
+ updated_dls.append(
289
+ torch.utils.data.DataLoader(
290
+ dl.dataset,
291
+ batch_size=self.config.loader.batch_size,
292
+ num_workers=self.config.loader.num_workers,
293
+ pin_memory=self.config.loader.pin_memory,
294
+ sampler=dl_sampler,
295
+ shuffle=False,
296
+ persistent_workers=self.config.loader.persistent_workers
297
+ ))
298
+ self.trainer.fit_loop._combined_loader.flattened = updated_dls
299
+
300
+ def forward(self, x, sigma=None, x_emb=None, attention_mask=None):
301
+ """Returns logits.
302
+
303
+ x_emb can be provided during PPLM / NoS-style guidance
304
+ (see: https://arxiv.org/abs/2305.20009).
305
+ """
306
+ if self.is_eval_classifier:
307
+ logits = self.classifier_model(x)
308
+ if hasattr(logits, 'logits'):
309
+ logits = logits.logits
310
+ else:
311
+ sigma = self._process_sigma(sigma) if sigma is not None else sigma
312
+ with torch.cuda.amp.autocast(dtype=torch.float32):
313
+ logits = self.classifier_model(x, sigma, x_emb=x_emb, attention_mask=attention_mask)
314
+ return logits
315
+
316
+ def get_log_probs(self, x, sigma, x_emb=None):
317
+ """Returns log probabilities.
318
+ Use for CBG-style guidance.
319
+ """
320
+ if self.is_eval_classifier:
321
+ raise NotImplementedError(
322
+ '`get_log_prob` not implemented for classifiers '
323
+ 'that are meant to be used for evaluation purposes '
324
+ 'only.')
325
+ with torch.cuda.amp.autocast(dtype=torch.float32):
326
+ return torch.nn.functional.log_softmax(
327
+ self.forward(x, sigma, x_emb=x_emb), dim=-1)
328
+
329
+ def training_step(self, batch, batch_idx):
330
+ loss = self._compute_loss(batch, prefix='train')
331
+ self.log(name='trainer/loss',
332
+ value=loss.item(),
333
+ on_step=True,
334
+ on_epoch=False,
335
+ sync_dist=True,
336
+ prog_bar=True)
337
+ self.log(name='lr',
338
+ value=
339
+ self.trainer.optimizers[0].param_groups[0][
340
+ 'lr'],
341
+ on_step=True,
342
+ on_epoch=False,
343
+ sync_dist=True,
344
+ prog_bar=True, logger=False)
345
+ return loss
346
+
347
+ def validation_step(self, batch, batch_idx):
348
+ return self._compute_loss(batch, prefix='val')
349
+
350
+ def configure_optimizers(self):
351
+ # TODO(yair): Lightning currently giving this warning when using `fp16`:
352
+ # "Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
353
+ # Not clear if this is a problem or not.
354
+ # See: https://github.com/Lightning-AI/pytorch-lightning/issues/5558
355
+ optimizer = torch.optim.AdamW(
356
+ itertools.chain(self.classifier_model.parameters(),
357
+ self.noise.parameters()),
358
+ lr=self.config.optim.lr,
359
+ betas=(self.config.optim.beta1,
360
+ self.config.optim.beta2),
361
+ eps=self.config.optim.eps,
362
+ weight_decay=self.config.optim.weight_decay)
363
+
364
+ scheduler = hydra.utils.instantiate(
365
+ self.config.lr_scheduler, optimizer=optimizer)
366
+ scheduler_dict = {
367
+ 'scheduler': scheduler,
368
+ 'interval': 'step',
369
+ 'monitor': 'val/loss',
370
+ 'name': 'trainer/lr',
371
+ }
372
+ return [optimizer], [scheduler_dict]
373
+
374
+ def _q_xt(self, x, move_chance):
375
+ """Computes the noisy sample xt.
376
+
377
+ Args:
378
+ x: int torch.Tensor with shape (batch_size,
379
+ diffusion_model_input_length), input.
380
+ move_chance: float torch.Tensor with shape
381
+ (batch_size, 1).
382
+ """
383
+ move_indices = torch.rand(
384
+ *x.shape, device=x.device) < move_chance
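+ # Tokens flagged in move_indices are corrupted: replaced by the mask token under
+ # absorbing-state diffusion, or by a uniformly random vocabulary token under uniform diffusion.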
385
+ if self.config.diffusion == 'absorbing_state':
386
+ return torch.where(move_indices, self.mask_index, x)
387
+ if self.config.diffusion == 'uniform':
388
+ uniform_tensor = torch.randint(
389
+ 0, self.vocab_size, x.shape, device=x.device)
390
+ return torch.where(move_indices, uniform_tensor, x)
391
+ raise NotImplementedError(
392
+ f'Diffusion type {self.config.diffusion} not '
393
+ 'implemented.')
394
+
395
+ def _compute_loss(self, batch, prefix):
396
+ x0 = batch['input_ids']
397
+ attention_mask = batch['attention_mask']
398
+ t = None
399
+ if self.is_eval_classifier:
400
+ logits = self.forward(x0)
401
+ elif self.config.parameterization == 'ar':
402
+ # do not add noise for AR FUDGE and AR PPLM
403
+ logits = self.forward(
404
+ x0, attention_mask=attention_mask)
405
+ else:
406
+ t = self._sample_t(x0.shape[0])
407
+ if self.T > 0:
408
+ t = (t * self.T).to(torch.int)
409
+ t = t / self.T
410
+ # t \in {1/T, 2/T, ..., 1}
411
+ t += (1 / self.T)
412
+ if self.change_of_variables:
413
+ time_conditioning = t[:, None]
414
+ f_T = torch.log1p(- torch.exp(- self.noise.sigma_max))
415
+ f_0 = torch.log1p(- torch.exp(- self.noise.sigma_min))
416
+ move_chance = torch.exp(f_0 + t * (f_T - f_0))
417
+ move_chance = move_chance[:, None]
418
+ else:
419
+ sigma, _ = self.noise(t)
420
+ time_conditioning = sigma[:, None]
421
+ move_chance = 1 - torch.exp(-sigma[:, None])
422
+
423
+ xt = self._q_xt(x0, move_chance)
424
+ logits = self.forward(xt, time_conditioning, attention_mask=attention_mask)
425
+ if hasattr(self.config.data, 'label_col'):
426
+ if f"{self.config.data.label_col}_threshold" in batch:
427
+ y = batch[f"{self.config.data.label_col}_threshold"]
428
+ else:
429
+ y = batch[self.config.data.label_col]
430
+ else:
431
+ y = batch['label']
432
+ if (not self.is_eval_classifier
433
+ and getattr(self.config.training, 'use_label_smoothing', False)):
434
+ # Interpolate between one-hot and uniform distribution
435
+ labels = (torch.nn.functional.one_hot(y, self.config.data.num_classes) * (1 - t)[..., None] +
436
+ (1 / self.config.data.num_classes) * t[..., None])
437
+ else:
438
+ labels = y.view(-1)
439
+ if getattr(self.config, 'is_fudge_classifier', False):
440
+ expanded_y = y.unsqueeze(1).expand(-1, logits.shape[1]) # batch x seq
441
+ logits = logits.view(-1, self.config.data.num_classes)[attention_mask.flatten()==1, ...]
442
+ y = expanded_y.flatten().long()[attention_mask.flatten()==1]
443
+ loss = torch.nn.functional.cross_entropy(
444
+ logits,
445
+ y,
446
+ ignore_index=-100,
447
+ reduction='mean')
448
+ else:
449
+ loss = torch.nn.functional.cross_entropy(
450
+ logits.view(-1, logits.size(-1)),
451
+ labels,
452
+ ignore_index=-100,
453
+ reduction='mean')
454
+
455
+ if prefix == 'train':
456
+ self.train_metrics.update(logits, y)
457
+ metrics = self.train_metrics
458
+ elif prefix == 'val':
459
+ self.valid_metrics.update(logits, y)
460
+ metrics = self.valid_metrics
461
+ elif prefix == 'test':
462
+ self.test_metrics.update(logits, y)
463
+ metrics = self.test_metrics
464
+ else:
465
+ raise ValueError(f'Invalid prefix: {prefix}')
466
+
467
+ self.log_dict(metrics,
468
+ on_step=False,
469
+ on_epoch=True,
470
+ sync_dist=True)
471
+ return loss
472
+
473
+ def _sample_t(self, n):
474
+ _eps_t = torch.rand(n, device=self.device)
475
+ if self.antithetic_sampling:
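+ # Stratified ("antithetic") sampling: spread the batch's timesteps evenly over [0, 1)
+ # to reduce the variance of the loss estimate.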
476
+ offset = torch.arange(n, device=self.device) / n
477
+ _eps_t = (_eps_t / n + offset) % 1
478
+ t = (1 - self.sampling_eps) * _eps_t + self.sampling_eps
479
+ if self.importance_sampling:
480
+ return self.noise.importance_sampling_transformation(
481
+ t)
482
+ return t
483
+
484
+ def _process_sigma(self, sigma):
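+ # Drop a trailing singleton dimension so sigma is one value per sequence, and zero it
+ # out when the classifier is not conditioned on time.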
485
+ if sigma.ndim > 1:
486
+ sigma = sigma.squeeze(-1)
487
+ if not self.time_conditioning:
488
+ sigma = torch.zeros_like(sigma)
489
+ assert sigma.ndim == 1, sigma.shape
490
+ return sigma
muppit/configs/callbacks/checkpoint_every_n_steps.yaml ADDED
@@ -0,0 +1,8 @@
1
+ checkpoint_every_n_steps:
2
+ _target_: lightning.pytorch.callbacks.ModelCheckpoint
3
+ save_top_k: -1 # Do not save any "best" models; this callback is being used to save every n train steps
4
+ save_last: True # save model as ${save_dir}/checkpoints/last.ckpt
5
+ dirpath: ${checkpointing.save_dir}/checkpoints
6
+ verbose: True
7
+ auto_insert_metric_name: False
8
+ # every_n_train_steps: 500
muppit/configs/callbacks/checkpoint_monitor.yaml ADDED
@@ -0,0 +1,10 @@
1
+ checkpoint_monitor:
2
+ _target_: lightning.pytorch.callbacks.ModelCheckpoint
3
+ monitor: val/nll # name of the logged metric which determines when model is improving
4
+ mode: min # can be "max" or "min"
5
+ save_top_k: 1 # save k best models (determined by above metric)
6
+ save_last: False # True = additionally always save model from last epoch
7
+ dirpath: ${checkpointing.save_dir}/checkpoints
8
+ filename: best
9
+ auto_insert_metric_name: False
10
+ verbose: True
muppit/configs/callbacks/learning_rate_monitor.yaml ADDED
@@ -0,0 +1,3 @@
1
+ learning_rate_monitor:
2
+ _target_: lightning.pytorch.callbacks.LearningRateMonitor
3
+ logging_interval: step
muppit/configs/classifier_model/dimamba-classifier.yaml ADDED
@@ -0,0 +1,14 @@
1
+ name: dimamba
2
+ type: dimamba
3
+ hidden_size: 256
4
+ cond_dim: 128
5
+ length: ${model.length} # Same length as diffusion model
6
+ n_blocks: 8
7
+ scale_by_sigma: True
8
+ dropout: 0.1
9
+ tie_word_embeddings: False
10
+ bidirectional: True
11
+ bidirectional_strategy: add
12
+ bidirectional_weight_tie: True
13
+ num_classes: ${data.num_classes}
14
+ pooling: mean
muppit/configs/classifier_model/hyenadna-classifier.yaml ADDED
@@ -0,0 +1,4 @@
1
+ name: hyena-32k
2
+ type: hyenadna
3
+ hyena_model_name_or_path: ???
4
+ n_layer: 4
muppit/configs/classifier_model/small-classifier.yaml ADDED
@@ -0,0 +1,11 @@
1
+ name: small
2
+ type: ddit
3
+ hidden_size: 768
4
+ cond_dim: 128
5
+ length: ${model.length} # Same length as diffusion model
6
+ n_blocks: 12
7
+ n_heads: 12
8
+ scale_by_sigma: True
9
+ dropout: 0.1
10
+ num_classes: ${data.num_classes}
11
+ pooling: mean
muppit/configs/classifier_model/tiny-classifier.yaml ADDED
@@ -0,0 +1,11 @@
1
+ name: tiny
2
+ type: ddit
3
+ hidden_size: 512
4
+ cond_dim: 128
5
+ length: ${model.length} # Same length as diffusion model
6
+ n_blocks: 8
7
+ n_heads: 8
8
+ scale_by_sigma: True
9
+ dropout: 0.1
10
+ num_classes: ${data.num_classes}
11
+ pooling: mean
muppit/configs/classifier_model/tiny-dimamba-classifier.yaml ADDED
@@ -0,0 +1,14 @@
1
+ name: tiny
2
+ type: dimamba
3
+ hidden_size: 128
4
+ cond_dim: 128
5
+ length: ${model.length} # Same length as diffusion model
6
+ n_blocks: 4
7
+ scale_by_sigma: True
8
+ dropout: 0.1
9
+ tie_word_embeddings: False
10
+ bidirectional: True
11
+ bidirectional_strategy: add
12
+ bidirectional_weight_tie: True
13
+ num_classes: ${data.num_classes}
14
+ pooling: mean
muppit/configs/config.yaml ADDED
@@ -0,0 +1,104 @@
1
+ defaults:
2
+ - _self_
3
+ - /callbacks: [checkpoint_every_n_steps, checkpoint_monitor, learning_rate_monitor]
4
+ - /data: peptide
5
+ - /model: small
6
+ - /strategy: ddp
7
+ - /noise: loglinear
8
+ - /lr_scheduler: cosine_decay_warmup # constant_warmup
9
+ - /classifier_model: null
10
+ - /guidance: cbg
11
+
12
+ mode: ppl_eval # train / train_classifier / ppl_eval
13
+ diffusion: uniform # absorbing_state / uniform
14
+ backbone: dit # dit / dimamba / ar
15
+ classifier_backbone: null
16
+ parameterization: d3pm # subs / d3pm / ar
17
+ time_conditioning: True # UDLM is conditioned on time
18
+ subs_masking: False
19
+ zero_recon_loss: True # Use for UDLM
20
+ T: 0 # 0 (continuous time) / 1000
21
+
22
+ is_vision: False
23
+ seed: 13
24
+
25
+ loader:
26
+ global_batch_size: 512
27
+ eval_global_batch_size: ${.global_batch_size}
28
+ # Note: batch_size and eval_batch_size are **per machine**
29
+ batch_size: ${div_up:${.global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
30
+ eval_batch_size: ${div_up:${.eval_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
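+ # e.g., global_batch_size=512 with trainer.devices=2 and trainer.num_nodes=1 resolves to a per-device batch_size of 256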
31
+ num_workers: 0 # ${eval:"len(__import__('os').sched_getaffinity(0))"}
32
+ pin_memory: True
33
+ persistent_workers: False # True
34
+
35
+ sampling:
36
+ use_cache: True
37
+ steps: 128
38
+ # Note: batch_size is **per machine**
39
+ batch_size: 1 # ${loader.eval_batch_size}
40
+ num_sample_batches: 50 # Total samples: `num_gpus` * `batch_size` * `num_sample_batches`
41
+ use_float64: False
42
+
43
+ eval:
44
+ checkpoint_path: '/home/tc415/muPPIt_embedding/muppit/model_path/PeptideUDLM.ckpt' # Used to evaluate a checkpoint after training.
45
+ wildtype: 'MAEYLASIFGTEKDKVNCSFYFKIGACRHGDRCSRLHNKPTFSQTIALLNIYRNPQNSSQSADGLRCAVSDVEMQEHYDEFFEEVFTEMEEKYGEVEEMNVCDNLGDHLVGNVYVKFRREEDAEKAVIDLNNRWFNGQPIHAELSPVTDFREACCRQYEMGECTRGGFCNFMHLKPISRELRRELYGRRRKKHRSRSRSRERRSRSRDRGRGGGGGGGGGGGGRERDRRRSRDRERSGRF'
46
+ mutant: 'MAEYLASIFGTEKDKVNCSFYFKIGACRHGDRCFRLHNKPTFSQTIALLNIYRNPQNSSQSADGLRCAVSDVEMQEHYDEFFEEVFTEMEEKYGEVEEMNVCDNLGDHLVGNVYVKFRREEDAEKAVIDLNNRWFNGQPIHAELSPVTDFREACCRQYEMGECTRGGFCNFMHLKPISRELRRELYGRRRKKHRSRSRSRERRSRSRDRGRGGGGGGGGGGGGRERDRRRSRDRERSGRF'
47
+
48
+ disable_ema: False
49
+ generate_samples: True
50
+ generated_samples_path: ''
51
+ max_samples: 50_000
52
+
53
+ training:
54
+ ema: 0.9999
55
+ antithetic_sampling: True
56
+ importance_sampling: False
57
+ sampling_eps: 1e-3
58
+ change_of_variables: False
59
+ compute_loss_on_pad_tokens: True
60
+ use_simple_ce_loss: False # Ignore ELBO; just use CE
61
+ guidance: null # Can turn off with `training.guidance: null`
62
+ # cond_dropout: 0.0
63
+
64
+ optim:
65
+ weight_decay: 1e-4
66
+ lr: 1e-5
67
+ beta1: 0.9
68
+ beta2: 0.999
69
+ eps: 1e-8
70
+
71
+ trainer:
72
+ _target_: lightning.Trainer
73
+ accelerator: cuda
74
+ num_nodes: 1
75
+ devices: 2 # ${device_count:}
76
+ accumulate_grad_batches: 1 # ${div_up:${loader.global_batch_size}, ${eval:${trainer.devices} * ${loader.batch_size} * ${trainer.num_nodes}}}
77
+ gradient_clip_val: 1.0
78
+ precision: 'bf16-mixed'
79
+ num_sanity_val_steps: 2
80
+ # max_epochs: 10
81
+ max_steps: 1652000
82
+ log_every_n_steps: 100
83
+ limit_train_batches: 1.0 # train on full dataset, can be used to toggle quick run
84
+ limit_val_batches: 1.0 # validate on full dataset, can be used to toggle quick run
85
+ val_check_interval: 16520 # 2545
86
+
87
+ wandb:
88
+ project: moPPIt-v2
89
+ job_type: model-training
90
+ name: protein_medium_100epochs_lr1e-5_gradclip1_wd1e-4_dropout0.1 #epochs10_lr3e-4_bsz8_64-true_all-params_gradclip1_beta-one0.9_beta-two0.999
91
+ id: ${.name}
92
+
93
+ hydra:
94
+ run:
95
+ dir: ./outputs/${wandb.name} # ./outputs/${data.train}/${now:%Y.%m.%d}/${now:%H%M%S}
96
+ job:
97
+ chdir: true
98
+
99
+ checkpointing:
100
+ # Use custom `save_dir` if, e.g., saving to S3 bucket, otherwise leave this parameter as is
101
+ save_dir: ${cwd:}
102
+ # Note: `checkpoints` path should correspond to `checkpoint_every_n_steps.dirpath`
103
+ resume_from_ckpt: False
104
+ resume_ckpt_path: ${.save_dir}/checkpoints/last.ckpt
muppit/configs/data/amazon_polarity.yaml ADDED
@@ -0,0 +1,10 @@
1
+ train: amazon_polarity
2
+ valid: amazon_polarity
3
+ tokenizer_name_or_path: bert-base-uncased
4
+ cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
5
+ wrap: False
6
+ streaming: False
7
+ override_cache: False
8
+ add_special_tokens: True
9
+ label_col: label
10
+ num_classes: 2
muppit/configs/data/cifar10.yaml ADDED
@@ -0,0 +1,11 @@
1
+ train: ??? # (Local) Path to CIFAR-10 training data
2
+ valid: ??? # (Local) Path to CIFAR-10 validation data
3
+ label_col: labels
4
+ num_classes: 10
5
+ streaming: False
6
+ size: 1024
7
+ length: 3072
8
+ add_special_tokens: True
9
+ add_mask_token: True
10
+ tokenizer_name_or_path: raw_pixels
11
+
muppit/configs/data/lm1b.yaml ADDED
@@ -0,0 +1,8 @@
1
+ train: lm1b
2
+ valid: lm1b
3
+ tokenizer_name_or_path: bert-base-uncased
4
+ cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
5
+ wrap: False
6
+ streaming: False
7
+ override_cache: False
8
+ add_special_tokens: True
muppit/configs/data/peptide.yaml ADDED
@@ -0,0 +1,8 @@
1
+ train: peptide
2
+ valid: peptide
3
+ tokenizer_name_or_path: facebook/esm2_t33_650M_UR50D
4
+ cache_dir: /home/tc415/discrete-diffusion-guidance/dataset
5
+ wrap: False
6
+ streaming: False
7
+ override_cache: False
8
+ add_special_tokens: True
muppit/configs/data/protein.yaml ADDED
@@ -0,0 +1,8 @@
1
+ train: protein_400k
2
+ valid: protein_400k
3
+ tokenizer_name_or_path: facebook/esm2_t33_650M_UR50D
4
+ cache_dir: /home/tc415/discrete-diffusion-guidance/dataset
5
+ wrap: False
6
+ streaming: False
7
+ override_cache: False
8
+ add_special_tokens: True
muppit/configs/data/qm9.yaml ADDED
@@ -0,0 +1,11 @@
1
+ train: qm9
2
+ valid: qm9
3
+ tokenizer_name_or_path: yairschiff/qm9-tokenizer
4
+ cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
5
+ wrap: False
6
+ streaming: False
7
+ override_cache: False
8
+ add_special_tokens: True
9
+ label_col: qed
10
+ label_col_pctile: 90
11
+ num_classes: 2
muppit/configs/data/ten_species.yaml ADDED
@@ -0,0 +1,11 @@
1
+ train: ten_species
2
+ valid: ten_species
3
+ tokenizer_name_or_path: kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16
4
+ cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
5
+ wrap: False
6
+ streaming: False
7
+ override_cache: False
8
+ add_special_tokens: False
9
+ label_col: species_label
10
+ num_classes: 10
11
+ rc_aug: False
muppit/configs/data/text8.yaml ADDED
@@ -0,0 +1,9 @@
1
+ # TODO: When using this dataset, set model.length = 256 to match D3PM setup
2
+ train: text8
3
+ valid: text8
4
+ tokenizer_name_or_path: text8
5
+ cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
6
+ wrap: True
7
+ streaming: False
8
+ override_cache: False
9
+ add_special_tokens: False
muppit/configs/guidance/cbg.yaml ADDED
@@ -0,0 +1,5 @@
1
+ method: cbg
2
+ condition: 0
3
+ classifier_checkpoint_path: '/home/tc415/muPPIt_embedding/checkpoints/mutBind_small'
4
+ gamma: 2.0
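+ # guidance strength: larger gamma weights the classifier term more heavily relative to the diffusion model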
5
+ use_approx: False # use first-order approximation
muppit/configs/guidance/cfg.yaml ADDED
@@ -0,0 +1,3 @@
1
+ method: cfg
2
+ condition: 0
3
+ gamma: 1.0
muppit/configs/guidance/fudge.yaml ADDED
@@ -0,0 +1,5 @@
1
+ method: fudge
2
+ condition: 0
3
+ classifier_checkpoint_path: ''
4
+ topk: 20
5
+ gamma: 1.0
muppit/configs/guidance/nos.yaml ADDED
@@ -0,0 +1,6 @@
1
+ method: nos
2
+ condition: 0
3
+ classifier_checkpoint_path: ''
4
+ num_nos_steps: 1
5
+ nos_step_size: 0.1
6
+ nos_stability_coef: 0.01
muppit/configs/guidance/pplm.yaml ADDED
@@ -0,0 +1,6 @@
1
+ method: pplm
2
+ condition: 0
3
+ classifier_checkpoint_path: ''
4
+ num_pplm_steps: 1
5
+ pplm_step_size: 0.1
6
+ pplm_stability_coef: 0.01
muppit/configs/lr_scheduler/constant_warmup.yaml ADDED
@@ -0,0 +1,2 @@
1
+ _target_: transformers.get_constant_schedule_with_warmup
2
+ num_warmup_steps: 2500
muppit/configs/lr_scheduler/cosine_decay_warmup.yaml ADDED
@@ -0,0 +1,7 @@
1
+ _target_: utils.CosineDecayWarmupLRScheduler
2
+ t_in_epochs: False
3
+ t_initial: ${eval:${trainer.max_steps}-${.warmup_t}}
4
+ warmup_prefix: True
5
+ warmup_lr_init: 1e-7
6
+ warmup_t: ${eval:0.1*${trainer.max_steps}}
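+ # with trainer.max_steps=1652000 (see config.yaml), this resolves to a warmup of 165,200 steps and t_initial of 1,486,800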
7
+ lr_min: 1e-7
muppit/configs/model/dimamba.yaml ADDED
@@ -0,0 +1,12 @@
1
+ name: dimamba
2
+ type: dimamba
3
+ hidden_size: 256
4
+ cond_dim: 128
5
+ length: 32768
6
+ n_blocks: 8
7
+ scale_by_sigma: True
8
+ dropout: 0.1
9
+ tie_word_embeddings: False
10
+ bidirectional: True,
11
+ bidirectional_strategy: add
12
+ bidirectional_weight_tie: True
muppit/configs/model/fudge_predictor.yaml ADDED
@@ -0,0 +1,4 @@
1
+ name: fudge_predictor
2
+ type: lstm
3
+ hidden_dim: 300
4
+ length: 1024
muppit/configs/model/hf.yaml ADDED
@@ -0,0 +1,2 @@
1
+ pretrained_model_name_or_path: null
2
+ length: 128
muppit/configs/model/medium.yaml ADDED
@@ -0,0 +1,10 @@
1
+ name: medium
2
+ type: ddit
3
+ hidden_size: 1024
4
+ cond_dim: 128
5
+ length: 82
6
+ n_blocks: 24
7
+ n_heads: 16
8
+ scale_by_sigma: True
9
+ dropout: 0.1
10
+ tie_word_embeddings: False
muppit/configs/model/small.yaml ADDED
@@ -0,0 +1,11 @@
1
+ name: small
2
+ type: ddit
3
+ hidden_size: 768
4
+ cond_dim: 128
5
+ length: 15
6
+ # length_range: '6-49'
7
+ n_blocks: 12
8
+ n_heads: 12
9
+ scale_by_sigma: True
10
+ dropout: 0.1
11
+ tie_word_embeddings: False
muppit/configs/model/tiny.yaml ADDED
@@ -0,0 +1,10 @@
1
+ name: tiny
2
+ type: ddit
3
+ hidden_size: 512
4
+ cond_dim: 128
5
+ length: 1024
6
+ n_blocks: 8
7
+ n_heads: 8
8
+ scale_by_sigma: True
9
+ dropout: 0.1
10
+ tie_word_embeddings: False
muppit/configs/model/unet.yaml ADDED
@@ -0,0 +1,19 @@
1
+ name: unet
2
+ type: unet
3
+ ch: 128
4
+ num_res_blocks: 2
5
+ num_scales: 4
6
+ ch_mult: [1, 2, 2, 2]
7
+ input_channels: 3
8
+ output_channels: -1 # determined by vocab_size
9
+ scale_count_to_put_attn: 1 # at 16 res
10
+ data_min_max: [0, 255] # No need currently
11
+ dropout: 0.1
12
+ skip_rescale: True
13
+ time_conditioning: True # Whether to add in time embeddings
14
+ time_scale_factor: 1000
15
+ time_embed_dim: ${.ch}
16
+ fix_logistic: False
17
+ size: ${data.size}
18
+ cond_dim: ${.ch}
19
+ length: ${data.length}
muppit/configs/model/unet_campbell.yaml ADDED
@@ -0,0 +1,19 @@
1
+ name: unet
2
+ type: unet
3
+ ch: 128
4
+ num_res_blocks: 2
5
+ num_scales: 4
6
+ ch_mult: [1, 2, 2, 2]
7
+ input_channels: 3
8
+ output_channels: -1 # determined by input_channels * 2
9
+ scale_count_to_put_attn: 1 # at 16 res
10
+ data_min_max: [0, 255] # No need currently, determined by [0, vocab_size]
11
+ dropout: 0.1
12
+ skip_rescale: True
13
+ time_conditioning: True # Whether to add in time embeddings
14
+ time_scale_factor: 1000
15
+ time_embed_dim: ${.ch}
16
+ fix_logistic: False
17
+ size: ${data.size}
18
+ cond_dim: ${.ch}
19
+ length: ${data.length}
muppit/configs/noise/ar.yaml ADDED
@@ -0,0 +1,2 @@
1
+ type: ar
2
+ scale: 6.0
muppit/configs/noise/linear.yaml ADDED
@@ -0,0 +1,3 @@
1
+ type: linear
2
+ sigma_min: 1e-3
3
+ sigma_max: 7.0
muppit/configs/noise/loglinear.yaml ADDED
@@ -0,0 +1,3 @@
1
+ type: loglinear
2
+ sigma_min: 1e-4
3
+ sigma_max: 20
muppit/configs/noise/polynomial.yaml ADDED
@@ -0,0 +1,5 @@
1
+ type: polynomial
2
+ a: -3
3
+ b: 5
4
+ c: -4
5
+ eps: 1e-3
muppit/configs/strategy/ddp.yaml ADDED
@@ -0,0 +1,2 @@
1
+ _target_: lightning.pytorch.strategies.DDPStrategy
2
+ find_unused_parameters: false
muppit/configs/strategy/fsdp.yaml ADDED
@@ -0,0 +1,3 @@
1
+ # TODO(yair): Currently not compatible with grad clipping
2
+ _target_: lightning.pytorch.strategies.FSDPStrategy
3
+ sharding_strategy: SHARD_GRAD_OP
muppit/custom_datasets/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from . import discretized_cifar10
2
+ from . import ten_species_dataset