Upload 139 files
This view is limited to 50 files because it contains too many changes. See raw diff.
- binder_generator_train.py +241 -0
- finetune.py +385 -0
- muppit/.gitignore +7 -0
- muppit/LICENSE +201 -0
- muppit/README.md +250 -0
- muppit/__pycache__/classifier.cpython-310.pyc +0 -0
- muppit/__pycache__/dataloader.cpython-310.pyc +0 -0
- muppit/__pycache__/diffusion.cpython-310.pyc +0 -0
- muppit/__pycache__/noise_schedule.cpython-310.pyc +0 -0
- muppit/__pycache__/utils.cpython-310.pyc +0 -0
- muppit/classifier.py +490 -0
- muppit/configs/callbacks/checkpoint_every_n_steps.yaml +8 -0
- muppit/configs/callbacks/checkpoint_monitor.yaml +10 -0
- muppit/configs/callbacks/learning_rate_monitor.yaml +3 -0
- muppit/configs/classifier_model/dimamba-classifier.yaml +14 -0
- muppit/configs/classifier_model/hyenadna-classifier.yaml +4 -0
- muppit/configs/classifier_model/small-classifier.yaml +11 -0
- muppit/configs/classifier_model/tiny-classifier.yaml +11 -0
- muppit/configs/classifier_model/tiny-dimamba-classifier.yaml +14 -0
- muppit/configs/config.yaml +104 -0
- muppit/configs/data/amazon_polarity.yaml +10 -0
- muppit/configs/data/cifar10.yaml +11 -0
- muppit/configs/data/lm1b.yaml +8 -0
- muppit/configs/data/peptide.yaml +8 -0
- muppit/configs/data/protein.yaml +8 -0
- muppit/configs/data/qm9.yaml +11 -0
- muppit/configs/data/ten_species.yaml +11 -0
- muppit/configs/data/text8.yaml +9 -0
- muppit/configs/guidance/cbg.yaml +5 -0
- muppit/configs/guidance/cfg.yaml +3 -0
- muppit/configs/guidance/fudge.yaml +5 -0
- muppit/configs/guidance/nos.yaml +6 -0
- muppit/configs/guidance/pplm.yaml +6 -0
- muppit/configs/lr_scheduler/constant_warmup.yaml +2 -0
- muppit/configs/lr_scheduler/cosine_decay_warmup.yaml +7 -0
- muppit/configs/model/dimamba.yaml +12 -0
- muppit/configs/model/fudge_predictor.yaml +4 -0
- muppit/configs/model/hf.yaml +2 -0
- muppit/configs/model/medium.yaml +10 -0
- muppit/configs/model/small.yaml +11 -0
- muppit/configs/model/tiny.yaml +10 -0
- muppit/configs/model/unet.yaml +19 -0
- muppit/configs/model/unet_campbell.yaml +19 -0
- muppit/configs/noise/ar.yaml +2 -0
- muppit/configs/noise/linear.yaml +3 -0
- muppit/configs/noise/loglinear.yaml +3 -0
- muppit/configs/noise/polynomial.yaml +5 -0
- muppit/configs/strategy/ddp.yaml +2 -0
- muppit/configs/strategy/fsdp.yaml +3 -0
- muppit/custom_datasets/__init__.py +2 -0
binder_generator_train.py
ADDED
@@ -0,0 +1,241 @@
import torch
from torch.utils.data import DataLoader, Subset
from torch.optim import AdamW
import torch.nn.functional as F
import torch.nn as nn
from datasets import load_from_disk
import esm
import numpy as np
import math
import os
from transformers import AutoTokenizer
from torch.optim.lr_scheduler import CosineAnnealingLR
from transformers import get_linear_schedule_with_warmup
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler
import gc
import pdb

os.environ['CUDA_VISIBLE_DEVICES'] = '1'

##################### Hyper-parameters #############################################
max_epochs = 30
batch_size = 4
lr = 1e-4
num_layers = 4
num_heads = 4
accumulation_steps = 4
checkpoint_path = '/home/tc415/muPPIt_embedding/checkpoints/generator_0'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print(f'''
max_epochs = {max_epochs}
batch_size = {batch_size}
lr = {lr}
num_layers = {num_layers}
num_heads = {num_heads}
accumulation_steps = {accumulation_steps}
checkpoint_path = {checkpoint_path}
''')
####################################################################################

os.makedirs(checkpoint_path, exist_ok=True)

train_dataset = load_from_disk('/home/tc415/muPPIt_embedding/dataset/train/ppiref_generator')
val_dataset = load_from_disk('/home/tc415/muPPIt_embedding/dataset/val/ppiref_generator')
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
print(len(train_dataset), len(val_dataset))


def collate_fn(batch):
    # Unpack the batch, dropping the BOS/EOS tokens added by the tokenizer
    binders = []
    targets = []

    for b in batch:
        binder = torch.tensor(b['binder_input_ids']['input_ids'][1:-1])
        target = torch.tensor(b['target_input_ids']['input_ids'][1:-1])

        if binder.dim() == 0 or binder.numel() == 0 or target.dim() == 0 or target.numel() == 0:
            continue
        binders.append(binder)  # shape: 1*L1 -> L1
        targets.append(target)  # shape: 1*L2 -> L2

    # Collate the tensors using torch's pad_sequence
    try:
        binder_input_ids = torch.nn.utils.rnn.pad_sequence(binders, batch_first=True, padding_value=tokenizer.pad_token_id)
        target_input_ids = torch.nn.utils.rnn.pad_sequence(targets, batch_first=True, padding_value=tokenizer.pad_token_id)
    except Exception:
        pdb.set_trace()

    # Return the collated batch
    return {
        'binder_input_ids': binder_input_ids.long(),
        'target_input_ids': target_input_ids.long(),
    }


def RoPE(x, seq_dim=0):
    """
    Applies Rotary Positional Encoding to the input embeddings.
    :param x: Input tensor (seq_len, batch_size, embed_dim)
    :param seq_dim: The sequence dimension, usually 0 (first dimension in (seq_len, batch_size, embed_dim))
    :return: Tensor with RoPE applied (seq_len, batch_size, embed_dim)
    """
    seq_len = x.shape[seq_dim]
    d_model = x.shape[-1]

    # Create the positions and the sine-cosine rotation factors
    theta = torch.arange(0, d_model, 2, dtype=torch.float32) / d_model
    theta = 10000 ** (-theta)  # scaling factor for RoPE
    seq_idx = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)

    # Compute sine and cosine embeddings for each position
    sin_emb = torch.sin(seq_idx * theta)
    cos_emb = torch.cos(seq_idx * theta)

    sin_emb = sin_emb.unsqueeze(1)  # [seq_len, 1, embed_dim//2]
    cos_emb = cos_emb.unsqueeze(1)  # [seq_len, 1, embed_dim//2]

    x1, x2 = x[..., ::2], x[..., 1::2]  # Split embedding into even and odd indices

    cos_emb = cos_emb.to(x1.device)
    sin_emb = sin_emb.to(x1.device)

    # Apply the rotary transformation
    x_rotated = torch.cat([x1 * cos_emb - x2 * sin_emb, x1 * sin_emb + x2 * cos_emb], dim=-1)
    return x_rotated


class BinderGenerator(nn.Module):
    def __init__(self, vocab_size=24, embed_dim=1280, num_heads=4, num_layers=4, lr=1e-4):
        super(BinderGenerator, self).__init__()
        self.esm, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        for param in self.esm.parameters():
            param.requires_grad = False

        self.transformer = nn.Transformer(d_model=embed_dim, nhead=num_heads, num_encoder_layers=num_layers, num_decoder_layers=num_layers)
        self.fc_out = nn.Linear(embed_dim, vocab_size)

        self.criterion = nn.CrossEntropyLoss(ignore_index=self.alphabet.padding_idx)
        self.vocab_size = vocab_size
        self.learning_rate = lr

    def forward(self, binder_tokens, target_tokens):
        with torch.no_grad():
            # Frozen ESM-2 embeddings, zeroed at padding positions
            binder_pad_mask = (binder_tokens != self.alphabet.padding_idx).int()
            binder_embed = self.esm(binder_tokens, repr_layers=[33])["representations"][33] * binder_pad_mask.unsqueeze(-1)

            target_pad_mask = (target_tokens != self.alphabet.padding_idx).int()
            target_embed = self.esm(target_tokens, repr_layers=[33])["representations"][33] * target_pad_mask.unsqueeze(-1)

        binder_embed = binder_embed.transpose(0, 1)
        target_embed = target_embed.transpose(0, 1)

        binder_embed = RoPE(binder_embed)  # [src_len, batch_size, embed_dim]
        target_embed = RoPE(target_embed)  # [tgt_len, batch_size, embed_dim]

        # Causal mask so each target position only attends to earlier target tokens
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(target_embed.size(0)).to(target_embed.device)
        output = self.transformer(binder_embed, target_embed, tgt_mask=tgt_mask)  # [tgt_len, batch_size, embed_dim]
        return self.fc_out(output).transpose(0, 1)  # [batch_size, tgt_len, vocab_size]

    def compute_loss(self, binder_tokens, target_tokens):
        output = self.forward(binder_tokens, target_tokens)
        # Shift by one position: predict token t+1 from positions <= t
        loss = self.criterion(output[:, :-1, :].reshape(-1, self.vocab_size), target_tokens[:, 1:].reshape(-1))
        return loss

    def step(self, batch, compute_acc=False):
        binder_tokens = batch['binder_input_ids'].to(device)
        target_tokens = batch['target_input_ids'].to(device)

        if not compute_acc:
            return self.compute_loss(binder_tokens, target_tokens)

        output = self.forward(binder_tokens, target_tokens)
        loss = self.criterion(output[:, :-1, :].reshape(-1, self.vocab_size), target_tokens[:, 1:].reshape(-1))

        # Token-level accuracy over non-padding positions, with the same shift as the loss
        preds = torch.argmax(output[:, :-1, :], dim=-1)
        mask = target_tokens[:, 1:] != self.alphabet.padding_idx
        correct = ((preds == target_tokens[:, 1:]) & mask).sum().item()
        accuracy = correct / mask.sum().item()
        return loss, accuracy


def train(model, optimizer, scheduler, cosine_scheduler, train_dataset, val_dataset, batch_size, max_epochs=10, accumulation_steps=4):
    train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=False, num_workers=4)

    max_val_acc = 0
    for epoch in range(max_epochs):
        print(f"Epoch {epoch + 1}/{max_epochs}")

        scaler = GradScaler()

        model.train()
        running_loss = 0.0
        optimizer.zero_grad()

        for batch_idx, batch in tqdm(enumerate(train_loader), total=len(train_loader)):
            batch = {k: v.cuda(non_blocking=True) for k, v in batch.items()}  # Transfer batch to GPU

            with autocast():
                loss = model.step(batch)

            scaler.scale(loss).backward()

            if (batch_idx + 1) % accumulation_steps == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

                # Linear warmup first, then hand over to cosine annealing
                if scheduler.last_epoch < warmup_steps:
                    scheduler.step()
                else:
                    cosine_scheduler.step()

            running_loss += loss.item()

        print(f"Epoch {epoch}: Training Loss = {running_loss / len(train_loader)}")

        gc.collect()
        torch.cuda.empty_cache()

        model.eval()
        val_loss = 0.0
        val_acc = 0.0
        with torch.no_grad():
            for batch in tqdm(val_loader, total=len(val_loader)):
                batch = {k: v.cuda(non_blocking=True) for k, v in batch.items()}
                val_loss_batch, val_acc_batch = model.step(batch, compute_acc=True)
                val_loss += val_loss_batch.item()
                val_acc += val_acc_batch

        # Average the per-batch accuracies over the number of validation batches
        val_acc /= len(val_loader)
        print(f"Epoch {epoch}: Val Loss = {val_loss / len(val_loader)}\tVal Acc = {val_acc}")

        if val_acc > max_val_acc:
            max_val_acc = val_acc
            torch.save(model.state_dict(), os.path.join(checkpoint_path, f"epoch={epoch}_acc={round(val_acc, 2)}"))


model = BinderGenerator(vocab_size=24, embed_dim=1280, num_heads=num_heads, num_layers=num_layers, lr=lr).to(device)
optimizer = AdamW(model.parameters(), lr=model.learning_rate, betas=(0.9, 0.95), weight_decay=1e-5)

total_steps = len(train_dataset) // (batch_size * accumulation_steps) * max_epochs
warmup_steps = int(0.1 * total_steps)

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)
cosine_scheduler = CosineAnnealingLR(optimizer, T_max=total_steps - warmup_steps, eta_min=0.1 * lr)

train(model, optimizer, scheduler, cosine_scheduler, train_dataset, val_dataset, batch_size=batch_size, max_epochs=max_epochs, accumulation_steps=accumulation_steps)
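A quick way to sanity-check the `RoPE` helper above is to confirm that it preserves shapes and per-position norms (rotations are norm-preserving). This is an illustrative standalone snippet, not part of the uploaded file:

```python
# Sanity check for the RoPE helper defined above (illustrative only).
import torch

x = torch.randn(7, 2, 8)  # [seq_len, batch_size, embed_dim]
y = RoPE(x)

assert y.shape == x.shape
# Each even/odd channel pair is rotated by a position-dependent angle,
# so the per-position L2 norm should be unchanged.
print(torch.allclose(x.norm(dim=-1), y.norm(dim=-1), atol=1e-5))  # True
```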
finetune.py
ADDED
@@ -0,0 +1,385 @@
import pdb
from pytorch_lightning.strategies import DDPStrategy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, DistributedSampler, BatchSampler, Sampler
from datasets import load_from_disk
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, \
    Timer, TQDMProgressBar, LearningRateMonitor, StochasticWeightAveraging, GradientAccumulationScheduler
from pytorch_lightning.loggers import WandbLogger
from torch.optim.lr_scheduler import _LRScheduler
from argparse import ArgumentParser
import os
import uuid
import esm
import numpy as np
import torch.distributed as dist
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
# from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from torch.optim import Adam, AdamW
from sklearn.metrics import roc_auc_score, f1_score, matthews_corrcoef
import torch_geometric.nn as pyg_nn
import gc
import math

# os.environ["TORCH_CPP_LOG_LEVEL"] = "INFO"
# os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

# VHSE8 physicochemical descriptors (8 values per standard amino acid)
vhse8_values = {
    'A': [0.15, -1.11, -1.35, -0.92, 0.02, -0.91, 0.36, -0.48],
    'R': [-1.47, 1.45, 1.24, 1.27, 1.55, 1.47, 1.30, 0.83],
    'N': [-0.99, 0.00, 0.69, -0.37, -0.55, 0.85, 0.73, -0.80],
    'D': [-1.15, 0.67, -0.41, -0.01, -2.68, 1.31, 0.03, 0.56],
    'C': [0.18, -1.67, -0.21, 0.00, 1.20, -1.61, -0.19, -0.41],
    'Q': [-0.96, 0.12, 0.18, 0.16, 0.09, 0.42, -0.20, -0.41],
    'E': [-1.18, 0.40, 0.10, 0.36, -2.16, -0.17, 0.91, 0.36],
    'G': [-0.20, -1.53, -2.63, 2.28, -0.53, -1.18, -1.34, 1.10],
    'H': [-0.43, -0.25, 0.37, 0.19, 0.51, 1.28, 0.93, 0.65],
    'I': [1.27, 0.14, 0.30, -1.80, 0.30, -1.61, -0.16, -0.13],
    'L': [1.36, 0.07, 0.26, -0.80, 0.22, -1.37, 0.08, -0.62],
    'K': [-1.17, 0.70, 0.80, 1.64, 0.67, 1.63, 0.13, -0.01],
    'M': [1.01, -0.53, 0.43, 0.00, 0.23, 0.10, -0.86, -0.68],
    'F': [1.52, 0.61, 0.95, -0.16, 0.25, 0.28, -1.33, -0.65],
    'P': [0.22, -0.17, -0.50, -0.05, 0.01, -1.34, 0.19, 3.56],
    'S': [-0.67, -0.86, -1.07, -0.41, -0.32, 0.27, -0.64, 0.11],
    'T': [-0.34, -0.51, -0.55, -1.06, 0.01, -0.01, -0.79, 0.39],
    'W': [1.50, 2.06, 1.79, 0.75, 0.75, 0.13, -1.06, -0.85],
    'Y': [0.61, 1.60, 1.17, 0.73, 0.53, 0.25, -0.96, -0.52],
    'V': [0.76, -0.92, 0.17, -1.91, 0.22, -1.40, -0.24, -0.03],
}

# Map amino acids to their ESM-2 vocabulary indices
aa_to_idx = {'A': 5, 'R': 10, 'N': 17, 'D': 13, 'C': 23, 'Q': 16, 'E': 9, 'G': 6, 'H': 21, 'I': 12, 'L': 4, 'K': 15, 'M': 20, 'F': 18, 'P': 14, 'S': 8, 'T': 11, 'W': 22, 'Y': 19, 'V': 7}

# Lookup table indexed by token id; special tokens keep all-zero descriptors
vhse8_tensor = torch.zeros(24, 8)
for aa, values in vhse8_values.items():
    aa_index = aa_to_idx[aa]
    vhse8_tensor[aa_index] = torch.tensor(values)
vhse8_tensor.requires_grad = False


def collate_fn(batch):
    # Unpack the batch, dropping the BOS/EOS tokens added by the tokenizer
    binders = []
    mutants = []
    wildtypes = []
    affs = []
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

    for b in batch:
        binder = torch.tensor(b['binder_input_ids']['input_ids'][1:-1])
        mutant = torch.tensor(b['mutant_input_ids']['input_ids'][1:-1])
        wildtype = torch.tensor(b['wildtype_input_ids']['input_ids'][1:-1])

        if binder.dim() == 0 or binder.numel() == 0 or mutant.dim() == 0 or mutant.numel() == 0 or wildtype.dim() == 0 or wildtype.numel() == 0:
            continue
        binders.append(binder)      # shape: 1*L1 -> L1
        mutants.append(mutant)      # shape: 1*L2 -> L2
        wildtypes.append(wildtype)  # shape: 1*L3 -> L3

        affs.append(b['aff'])

    # Collate the tensors using torch's pad_sequence
    try:
        binder_input_ids = torch.nn.utils.rnn.pad_sequence(binders, batch_first=True, padding_value=tokenizer.pad_token_id)
        mutant_input_ids = torch.nn.utils.rnn.pad_sequence(mutants, batch_first=True, padding_value=tokenizer.pad_token_id)
        wildtype_input_ids = torch.nn.utils.rnn.pad_sequence(wildtypes, batch_first=True, padding_value=tokenizer.pad_token_id)
    except Exception:
        pdb.set_trace()

    affs = torch.tensor(affs)
    # Return the collated batch
    return {
        'binder_input_ids': binder_input_ids.int(),
        'mutant_input_ids': mutant_input_ids.int(),
        'wildtype_input_ids': wildtype_input_ids.int(),
        'aff': affs
    }


class CustomDataModule(pl.LightningDataModule):
    def __init__(self, train_dataset, val_dataset, tokenizer, batch_size: int = 128):
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        self.tokenizer = tokenizer
        print(len(train_dataset))
        print(len(val_dataset))

    def train_dataloader(self):
        # batch_sampler = LengthAwareDistributedSampler(self.train_dataset, 'mutant_tokens', self.batch_size)
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, collate_fn=collate_fn,
                          num_workers=8, pin_memory=True)

    def val_dataloader(self):
        # batch_sampler = LengthAwareDistributedSampler(self.val_dataset, 'mutant_tokens', self.batch_size)
        return DataLoader(self.val_dataset, batch_size=self.batch_size, collate_fn=collate_fn, num_workers=8,
                          pin_memory=True)

    def setup(self, stage=None):
        if stage == 'test' or stage is None:
            pass


class CosineAnnealingWithWarmup(_LRScheduler):
    def __init__(self, optimizer, warmup_steps, total_steps, base_lr, max_lr, min_lr, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.base_lr = base_lr
        self.max_lr = max_lr
        self.min_lr = min_lr
        super(CosineAnnealingWithWarmup, self).__init__(optimizer, last_epoch)
        print(f"SELF BASE LRS = {self.base_lrs}")

    def get_lr(self):
        if self.last_epoch < self.warmup_steps:
            # Linear warmup phase from base_lr to max_lr
            return [self.base_lr + (self.max_lr - self.base_lr) * (self.last_epoch / self.warmup_steps) for _ in self.base_lrs]

        # Cosine annealing phase from max_lr to min_lr
        progress = (self.last_epoch - self.warmup_steps) / (self.total_steps - self.warmup_steps)
        cosine_decay = 0.5 * (1 + np.cos(np.pi * progress))
        decayed_lr = self.min_lr + (self.max_lr - self.min_lr) * cosine_decay

        return [decayed_lr for _ in self.base_lrs]


class muPPIt(pl.LightningModule):
    def __init__(self, d_node, num_heads, dropout, margin, lr):
        super(muPPIt, self).__init__()

        self.esm, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        for param in self.esm.parameters():
            param.requires_grad = False

        self.attention = nn.MultiheadAttention(embed_dim=d_node, num_heads=num_heads)
        self.layer_norm = nn.LayerNorm(d_node)

        self.map = nn.Sequential(
            nn.Linear(d_node, d_node // 2),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(d_node // 2, 1)
        )

        self.margin = margin
        self.learning_rate = lr

        for layer in self.map:
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_uniform_(layer.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)

    def forward(self, binder_tokens, wt_tokens, mut_tokens):
        device = binder_tokens.device
        global vhse8_tensor
        vhse8_tensor = vhse8_tensor.to(device)

        with torch.no_grad():
            # Frozen ESM-2 embeddings, zeroed at padding positions, with VHSE8 descriptors appended
            binder_pad_mask = (binder_tokens != self.alphabet.padding_idx).int()
            binder_embed = self.esm(binder_tokens, repr_layers=[33])["representations"][33] * binder_pad_mask.unsqueeze(-1)
            binder_vhse8 = vhse8_tensor[binder_tokens]
            binder_embed = torch.concat([binder_embed, binder_vhse8], dim=-1)

            mut_pad_mask = (mut_tokens != self.alphabet.padding_idx).int()
            mut_embed = self.esm(mut_tokens, repr_layers=[33])["representations"][33] * mut_pad_mask.unsqueeze(-1)
            mut_vhse8 = vhse8_tensor[mut_tokens]
            mut_embed = torch.concat([mut_embed, mut_vhse8], dim=-1)

            wt_pad_mask = (wt_tokens != self.alphabet.padding_idx).int()
            wt_embed = self.esm(wt_tokens, repr_layers=[33])["representations"][33] * wt_pad_mask.unsqueeze(-1)
            wt_vhse8 = vhse8_tensor[wt_tokens]
            wt_embed = torch.concat([wt_embed, wt_vhse8], dim=-1)

        binder_wt = torch.concat([binder_embed, wt_embed], dim=1)
        binder_mut = torch.concat([binder_embed, mut_embed], dim=1)

        # Self-attention over each binder-target pair, with a residual connection and layer norm
        binder_wt_attn, _ = self.attention(binder_wt, binder_wt, binder_wt)
        binder_mut_attn, _ = self.attention(binder_mut, binder_mut, binder_mut)

        binder_wt_attn = binder_wt + binder_wt_attn
        binder_mut_attn = binder_mut + binder_mut_attn

        binder_wt_attn = self.layer_norm(binder_wt_attn)
        binder_mut_attn = self.layer_norm(binder_mut_attn)

        mapped_binder_wt = self.map(binder_wt_attn).squeeze(-1)    # B*(L1+L2)
        mapped_binder_mut = self.map(binder_mut_attn).squeeze(-1)  # B*(L1+L2)

        # Euclidean distance between the wild-type and mutant pair embeddings
        distance = torch.sqrt(torch.sum((mapped_binder_wt - mapped_binder_mut) ** 2, dim=-1))
        return distance

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        lr = opt.param_groups[0]['lr']
        self.log('learning_rate', lr, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)

        binder_tokens = batch['binder_input_ids'].to(self.device)
        mut_tokens = batch['mutant_input_ids'].to(self.device)
        wt_tokens = batch['wildtype_input_ids'].to(self.device)
        aff = batch['aff'].to(self.device)

        distance = self.forward(binder_tokens, wt_tokens, mut_tokens)

        # Two-sided hinge: keep margin * aff <= distance <= margin * (aff + 1)
        upper_loss = F.relu(distance - self.margin * (aff + 1))  # let distance < margin * (aff + 1)
        lower_loss = F.relu(self.margin * aff - distance)        # let distance > margin * aff

        loss = 5 * upper_loss + lower_loss

        self.log('train_loss', loss.mean().item(), on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        return loss.mean()

    def validation_step(self, batch, batch_idx):
        binder_tokens = batch['binder_input_ids'].to(self.device)
        mut_tokens = batch['mutant_input_ids'].to(self.device)
        wt_tokens = batch['wildtype_input_ids'].to(self.device)
        aff = batch['aff'].to(self.device)

        distance = self.forward(binder_tokens, wt_tokens, mut_tokens)

        upper_loss = F.relu(distance - self.margin * (aff + 1))
        lower_loss = F.relu(self.margin * aff - distance)

        loss = 5 * upper_loss + lower_loss

        # A prediction counts as correct when the distance falls inside the target band
        accuracy = torch.sum(torch.logical_and(torch.ge(distance, self.margin * aff), torch.le(distance, self.margin * (aff + 1)))) / aff.shape[0]

        self.log('val_loss', loss.mean().item(), on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.log('val_acc', accuracy.item(), on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=self.learning_rate, betas=(0.9, 0.95))

        base_lr = 0.1 * self.learning_rate
        max_lr = self.learning_rate
        min_lr = 0.1 * self.learning_rate

        scheduler = CosineAnnealingWithWarmup(optimizer, warmup_steps=119, total_steps=1188,
                                              base_lr=base_lr, max_lr=max_lr, min_lr=min_lr)  # warmup_steps=3193, total_steps=31926

        lr_schedulers = {
            "scheduler": scheduler,
            "name": 'learning_rate_logs',
            "interval": 'step',  # The scheduler updates the learning rate at every step (not epoch)
            'frequency': 1  # The scheduler updates the learning rate after every batch
        }
        return [optimizer], [lr_schedulers]

    def on_train_epoch_end(self):
        gc.collect()
        torch.cuda.empty_cache()

    def load_weights(self, checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
        state_dict = checkpoint['state_dict']
        self.load_state_dict(state_dict, strict=True)


def main(args):
    print(args)
    dist.init_process_group(backend='nccl')

    train_dataset = load_from_disk('/home/tc415/muPPIt_embedding/dataset/train/affinity_embedding_skempi')  # 408643
    val_dataset = load_from_disk('/home/tc415/muPPIt_embedding/dataset/val/affinity_embedding_skempi')
    # val_dataset = None
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

    data_module = CustomDataModule(train_dataset, val_dataset, tokenizer=tokenizer, batch_size=args.batch_size)

    print(f"Loading Pre-trained Model from {args.sm}")
    model = muPPIt.load_from_checkpoint(args.sm, d_node=args.d_node, num_heads=args.num_heads, dropout=args.dropout, margin=args.margin, lr=args.lr)

    run_id = str(uuid.uuid4())

    logger = WandbLogger(project="muppit_embedding",
                         # name="debug",
                         name=f"affinity_lr={args.lr}_gradclip={args.grad_clip}_margin={args.margin}",
                         job_type='model-training',
                         id=run_id)

    print(f"Saving to {args.output_file}")

    checkpoint_callback = ModelCheckpoint(
        monitor='val_acc',
        # monitor='val_loss',
        dirpath=args.output_file,
        # filename='model-{epoch:02d}-{val_loss:.2f}',
        filename='model-{epoch:02d}-{val_acc:.2f}',
        # filename='muppit',
        save_top_k=-1,
        mode='max',
        # mode='min',
        # every_n_train_steps=1000,
        # save_on_train_epoch_end=False
    )

    early_stopping_callback = EarlyStopping(
        # monitor='val_acc',
        monitor='val_loss',
        patience=10,
        verbose=True,
        # mode='max',
        mode='min',
    )

    accumulator = GradientAccumulationScheduler(scheduling={0: 4})

    trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        accelerator='gpu',
        strategy='ddp_find_unused_parameters_true',
        precision='bf16',
        # logger=logger,
        devices=[0, 1],
        callbacks=[checkpoint_callback, accumulator],
        gradient_clip_val=args.grad_clip,
        # val_check_interval=100,
    )

    trainer.fit(model, datamodule=data_module)

    best_model_path = checkpoint_callback.best_model_path
    print(best_model_path)


if __name__ == "__main__":
    parser = ArgumentParser()

    parser.add_argument("-o", dest="output_file", help="File for output of model parameters", required=True, type=str)
    parser.add_argument("-lr", type=float, default=1e-3)
    parser.add_argument("-batch_size", type=int, default=2, help="Batch size")
    parser.add_argument("-grad_clip", type=float, default=0.5)
    parser.add_argument("-margin", type=float, default=0.5)
    parser.add_argument("-max_epochs", type=int, default=30)
    parser.add_argument("-d_node", type=int, default=1024, help="Node Representation Dimension")
    parser.add_argument("-num_heads", type=int, default=4)
    parser.add_argument("-dropout", type=float, default=0.1)
    parser.add_argument("-sm", type=str, default='/home/tc415/muPPIt_embedding/checkpoints/train_10/model-epoch=15-val_acc=0.62.ckpt')

    args = parser.parse_args()

    main(args)
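The objective in `training_step`/`validation_step` above is a two-sided hinge that pushes the wild-type/mutant embedding distance into the band `[margin * aff, margin * (aff + 1)]`, weighting overshoot five times more heavily than undershoot. A toy numerical check (standalone sketch with made-up values, not code from this commit) makes the band explicit:

```python
import torch
import torch.nn.functional as F

margin = 0.5
aff = torch.tensor([2.0, 2.0, 2.0])
distance = torch.tensor([0.8, 1.2, 1.9])  # below, inside, and above the band [1.0, 1.5]

upper_loss = F.relu(distance - margin * (aff + 1))  # nonzero only when distance > margin * (aff + 1)
lower_loss = F.relu(margin * aff - distance)        # nonzero only when distance < margin * aff
loss = 5 * upper_loss + lower_loss

print(loss)  # tensor([0.2000, 0.0000, 2.0000])
```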
muppit/.gitignore
ADDED
@@ -0,0 +1,7 @@
.idea/
.DS_Store
.ipynb_checkpoints/
__pycache__/
.hf_cache
outputs/
watch_folder/
muppit/LICENSE
ADDED
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
muppit/README.md
ADDED
@@ -0,0 +1,250 @@
# Simple Guidance Mechanisms for Discrete Diffusion Models

[](https://arxiv.org/abs/2412.10193)
[](https://discrete-diffusion-guidance.github.io/)
[](https://huggingface.co/collections/kuleshov-group/udlm-675e63ab42bc757093099e1b)

<p align="center">
  <img src="https://discrete-diffusion-guidance.github.io/static/images/udlm.gif" alt="graphical abstract" width="450"/>
</p>

This repository contains code for reproducing experiments in the paper [Simple Guidance Mechanisms for Discrete Diffusion Models](https://arxiv.org/abs/2412.10193).

We also share [trained models](https://huggingface.co/collections/kuleshov-group/udlm-675e63ab42bc757093099e1b) on HuggingFace 🤗 and support integration with these models.
See the ["Using HuggingFace Models" section](#using-huggingface-models) below.

## Code Organization
<a name="code-organization"></a>
1. ```main.py```: Routines for training (language models and classifiers)
2. ```noise_schedule.py```: Noise schedules
3. ```diffusion.py```: Forward/reverse diffusion
   - Absorbing state / uniform noise diffusion
   - AR
4. ```dataloader.py```: Dataloaders
   - For the discretized CIFAR10 and Species10 datasets we use custom dataset classes defined in ```custom_datasets/```
5. ```utils.py```: LR scheduler, logging, `fsspec` handling
6. ```models/```: Denoising network architectures
7. ```configs/```: Config files for datasets/denoising networks/noise schedules/LR schedules
8. ```scripts/```: Shell scripts for training/evaluation
9. ```guidance_eval/```: Guidance evaluation scripts
### Implemented Decoding Mechanisms
<a name="implemented-decoding"></a>
In [`diffusion.py`](./diffusion.py), we define baseline and proposed decoding mechanisms for guidance.
These decoding schemes can be controlled via the hydra config with the `guidance` field.
For example, to use the proposed D-CFG guidance mechanism, set `guidance=cfg` in the config file and optionally set the `guidance.gamma` parameter to control the strength of the guidance signal (see the example command after the list below).

The implemented decoding methods are as follows:
- AR (Baseline):
  - Standard decoding (i.e., no guidance); set `guidance=null`
  - Classifier-free guidance (D-CFG); set `guidance=cfg`
  - Classifier-based guidance using [FUDGE](https://arxiv.org/abs/2104.05218) (set `guidance=fudge`) and using [PPLM](https://arxiv.org/abs/1912.02164) (set `guidance=pplm`)
- Diffusion:
  - Standard decoding (i.e., no guidance); set `guidance=null`
  - Classifier-free guidance (D-CFG); set `guidance=cfg`
  - Classifier-based guidance (D-CBG); set `guidance=cbg`
  - Classifier-based (baseline) method of [NOS](https://arxiv.org/abs/2305.20009); set `guidance=nos`
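Since `guidance` is a hydra config group (see `configs/guidance/`), a mechanism can also be selected from the command line. The overrides below are a sketch: the exact flag names should be verified against `configs/config.yaml`.

```bash
# Illustrative hydra overrides; verify names against configs/config.yaml
python main.py \
  data=qm9 \
  guidance=cfg \
  guidance.gamma=3.0
```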
### Implemented Generative Models
<a name="implemented-models"></a>
The three modeling parameterizations we explore in this work are:
1. Autoregressive (AR) Models
2. Masked Diffusion Language Models (MDLM)
3. Uniform Diffusion Language Models (UDLM)

The `config` files can be used to specify which of these parameterizations to use.
Below we detail which config parameters correspond to which model.

**AR**
```bash
diffusion="absorbing_state"  # AR models can be thought of as a special case of absorbing state diffusion models
parameterization="ar"
T=0  # N/A for AR models; this is a placeholder
time_conditioning=False  # AR models are not conditioned on time
zero_recon_loss=False  # N/A for this model
```

**MDLM**
```bash
diffusion="absorbing_state"
parameterization="subs"  # See the MDLM paper for details: https://arxiv.org/abs/2406.07524
T=0  # Indicates continuous time, i.e., T --> infinity
time_conditioning=False  # MDLM is not conditioned on time
zero_recon_loss=False  # N/A for this model
```

**UDLM**
```bash
diffusion="uniform"
parameterization="d3pm"  # Indicates that we explicitly compute KL on posteriors
T=0  # Indicates continuous time, i.e., T --> infinity
time_conditioning=True  # UDLM is conditioned on time
zero_recon_loss=True  # In continuous time, the reconstruction loss evaluates to zero
```

## Getting started in this repository
<a name="getting-started"></a>

To get started, create a conda environment containing the required dependencies.

```bash
conda env create -f requirements.yaml
conda activate discdiff
```

Create the following directories to store saved models and slurm logs:
```bash
mkdir outputs
mkdir watch_folder
```

We rely on `wandb` integration to log experiments and eval curves.
+
## Reproducing Experiments
|
110 |
+
<a name="reproducing-experiments"></a>
|
111 |
+
|
112 |
+
Below, we describe the steps required for reproducing the experiments in the paper.
|
113 |
+
Throughout, the main entry point for running experiments is the [`main.py`](./main.py) script.
|
114 |
+
We also provide sample `slurm` scripts for launching pre-training and evaluation experiments in the [`scrips/`](./scripts) directory.
|
115 |
+
|
116 |
+
|
117 |
+
### Language Modeling Experiments
|
118 |
+
<a name="lm_training"></a>
|
119 |
+
To reproduce the language modeling results, please refer to the following shell scripts in the [`scripts/`](./scripts) directory:
|
120 |
+
- Species10: [`train_ten_species_guidance.sh`](./scripts/train_ten_species_guidance.sh)
|
121 |
+
- QM9: [`train_qm9_no-guidance.sh`](./scripts/train_qm9_no-guidance.sh)
|
122 |
+
- CIFAR10: [`train_cifar10_unet_guidance.sh`](./scripts/train_cifar10_unet_guidance.sh)
|
123 |
+
- text8: [`train_text8.sh`](./scripts/train_text8.sh)
|
124 |
+
- Amazon Polarity: [`train_amazon_polarity.sh`](./scripts/train_amazon_polarity.sh)
|
125 |
+
- LM1B: [`train_lm1b.sh`](./scripts/train_lm1b.sh)
|
126 |
+
|
127 |
+
Each script contains a comment detailing the usage.
|
128 |
+
For example, to train either an AR,
|
129 |
+
MDLM, or UDLM model on the `text8` dataset, use the following command:
|
130 |
+
```bash
|
131 |
+
cd scripts/
|
132 |
+
MODEL=<ar|mdlm|udlm>
|
133 |
+
sbatch \
|
134 |
+
--export=ALL,MODEL=${MODEL} \
|
135 |
+
--job-name=train_text8_${MODEL} \
|
136 |
+
train_text8.sh
|
137 |
+
```
|
138 |
+
### Guidance Training
|
139 |
+
<a name="guidance-training"></a>
|
140 |
+
#### Classifier-Free
|
141 |
+
<a name="guidance-training-cfg"></a>
|
142 |
+
For classifier-free guidance we require training models
|
143 |
+
that can condition on the class label
|
144 |
+
to model conditional distributions,
|
145 |
+
and we randomly mask out the signal,
|
146 |
+
replacing it with a dummy value of `num_claseses + 1`, to simulate an unconditional model.
|
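As a concrete picture of this conditioning-signal dropout, here is a minimal standalone sketch (not code from this repo; the dropout probability is an assumption for illustration, and the dummy value follows the `num_classes + 1` convention described above):

```python
import torch

def mask_labels_for_cfg(labels: torch.Tensor, num_classes: int, p_uncond: float = 0.1) -> torch.Tensor:
    """Randomly replace labels with the dummy 'unconditional' class num_classes + 1."""
    drop = torch.rand(labels.shape, device=labels.device) < p_uncond
    return torch.where(drop, torch.full_like(labels, num_classes + 1), labels)

# Example: a batch of 8 labels over 10 classes; roughly 10% become the dummy class 11
labels = torch.randint(0, 10, (8,))
print(mask_labels_for_cfg(labels, num_classes=10))
```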
147 |
+
Refer to the shell scripts with the `_guidance` suffix
|
148 |
+
to train these models for CIFAR10,
|
149 |
+
QM9, and Species10 datasets.
|
150 |
+
For QM9, we have two experiments,
|
151 |
+
one where we condition on the drug-likeness
|
152 |
+
(`qed`)
|
153 |
+
of the molecules and another
|
154 |
+
where we condition on the ring counts (`ring_count`).
|
155 |
+
|
156 |
+
#### Classifier-Based
|
157 |
+
<a name="guidance-training-cbg"></a>
|
158 |
+
For classifier-based guidance,
|
159 |
+
we need to train a classifier on the noisy latent samples.
|
160 |
+
Refer to the following shell scripts
|
161 |
+
to train these classifiers:
|
162 |
+
- [FUDGE](https://arxiv.org/abs/2104.05218) (AR guidance): [`train_qm9_fudge_classifier.sh`](./scripts/train_qm9_fudge_classifier.sh)
|
163 |
+
- D-CBG (diffusion guidance): [`train_qm9_classifier.sh`](./scripts/train_qm9_classifier.sh)
|
164 |
+
|
165 |
+
##### PPLM / NOS baselines
|
166 |
+
An alternative classifier-based guidance mechanism to D-CBG is that of [PPLM](https://arxiv.org/abs/1912.02164)
|
167 |
+
(which was adapted for diffusion models in [NOS](https://arxiv.org/abs/2305.20009)).
|
168 |
+
To train these classifiers,
|
169 |
+
refer to the following shell script:
|
170 |
+
[`train_qm9_pplm_classifier.sh`](./scripts/train_qm9_pplm_classifier.sh)
|
171 |
+
(for both PPLM and NOS classifiers).

### Guidance Evaluation
<a name="guidance-eval"></a>
To evaluate guidance mechanisms, we load trained models (and classifiers, if applicable) and generate some number of samples for which we compute "quality" metrics (e.g., validity/novelty in the QM9 experiments) and control label satisfaction (e.g., mean value of novel generated molecules for the property of interest in the QM9 experiments).

The scripts for these evaluations can be found in the [`guidance_eval/`](./guidance_eval) directory.
To run these evaluations, please refer to the following shell scripts:
- QM9: [`eval_qm9_guidance.sh`](./guidance_eval/eval_qm9_guidance.sh)
- Species10: [`eval_ten_species_guidance.sh`](./guidance_eval/eval_ten_species_guidance.sh)
  - For this dataset, we also evaluate the accuracy of a HyenaDNA classifier at correctly classifying generated sequences. This model can be trained using [`train_ten_species_eval_classifier.sh`](./scripts/train_ten_species_eval_classifier.sh).
    - To see how this trained evaluation classifier performs on the validation set of the original data, use the notebook [`eval_hyenadna_classifier.ipynb`](./notebooks/eval_hyenadna_classifier.ipynb).

In the paper, we performed an extensive hyperparameter sweep for our proposed guidance mechanisms and for baselines.
The shell scripts can be used to reproduce these experiments, e.g., for the D-CFG experiments on QM9:
```bash
export MODEL=<ar|mdlm|udlm>
export PROP=<qed|ring_count>
export GUIDANCE=cfg
for GAMMA in $(seq 1 5); do
  sbatch \
    --export=ALL,MODEL=${MODEL},PROP=${PROP},GUIDANCE=${GUIDANCE},GAMMA=${GAMMA} \
    --job-name=eval_qm9_${GUIDANCE}_${PROP}_${MODEL}_GAMMA-${GAMMA} \
    eval_qm9_guidance.sh
done
```

Once each evaluation run is complete, a `.csv` file containing the results is saved in the run directory of the trained generative model.
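
To compare runs from a sweep like the one above, the per-run `.csv` files can be collected with a few lines of pandas. A sketch under assumed paths (the actual file names and columns depend on the evaluation script):
```python
from pathlib import Path

import pandas as pd

# Assumption: each run directory under outputs/ holds one or more result
# .csv files; column names depend on the evaluation script.
frames = []
for path in Path("outputs").rglob("*.csv"):
    df = pd.read_csv(path)
    df["run"] = str(path.parent)  # remember which run each row came from
    frames.append(df)
results = pd.concat(frames, ignore_index=True)
print(results.head())
```
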

## Using HuggingFace Models
<a name="hf_models"></a>
We provide pre-trained models on HuggingFace 🤗:
- UDLM trained on LM1B: [kuleshov-group/udlm-lm1b](https://huggingface.co/kuleshov-group/udlm-lm1b)
- UDLM trained on QM9: [kuleshov-group/udlm-qm9](https://huggingface.co/kuleshov-group/udlm-qm9)
  - Note: this model was trained without guidance and can be used with classifier-free guidance.

Please see the README pages for these models on HuggingFace or our paper for more details about the training of these models.

To use these models, you can load them using the HuggingFace API, e.g.,
```python
from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained("kuleshov-group/udlm-lm1b")
```
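
A matching tokenizer should be loaded before encoding any text. A minimal sketch, assuming the tokenizer is hosted under the same repo id (custom architectures on HuggingFace may additionally require `trust_remote_code=True`):
```python
from transformers import AutoModelForMaskedLM, AutoTokenizer

model_id = "kuleshov-group/udlm-lm1b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForMaskedLM.from_pretrained(model_id)
print(model.config)  # inspect vocab size, hidden size, etc.
```
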

To use these models in our repository, set the following `config` parameters:
```bash
backbone="hf_dit"
model="hf"
model.pretrained_model_name_or_path="kuleshov-group/udlm-lm1b"  # or "kuleshov-group/udlm-qm9"
```

## Acknowledgements
<a name="acknowledgements"></a>
This repository was built off of [MDLM](https://github.com/kuleshov-group/mdlm), which in turn built on [SEDD](https://github.com/louaaron/Score-Entropy-Discrete-Diffusion).
Our code implementation of D-CBG is adapted from Nisonoff et al.'s [repo](https://github.com/hnisonoff/discrete_guidance).

## Citation
<a name="citation"></a>
```
@article{schiff2024discreteguidance,
  title={Simple Guidance Mechanisms for Discrete Diffusion Models},
  author={Schiff, Yair and Sahoo, Subham Sekhar and Phung, Hao and Wang, Guanghan and Boshar, Sam and Dalla-torre, Hugo and de Almeida, Bernardo P and Rush, Alexander and Pierrot, Thomas and Kuleshov, Volodymyr},
  journal={arXiv preprint arXiv:2412.10193},
  year={2024}
}
```
muppit/__pycache__/classifier.cpython-310.pyc
ADDED
Binary file (13.8 kB)
muppit/__pycache__/dataloader.cpython-310.pyc
ADDED
Binary file (18 kB)
muppit/__pycache__/diffusion.cpython-310.pyc
ADDED
Binary file (33.4 kB)
muppit/__pycache__/noise_schedule.cpython-310.pyc
ADDED
Binary file (6.19 kB)
muppit/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (2.77 kB)
muppit/classifier.py
ADDED
@@ -0,0 +1,490 @@
import itertools
import typing

import hydra.utils
import lightning as L
import torch
import torch.nn.functional as F
import torchmetrics
import transformers

import dataloader
import models.dimamba  # needed for the DiMamba classifier backbone below
import models.dit
import noise_schedule


class MicroAveragingMetric(torchmetrics.Metric):
  """Micro-averaging metric.

  Adapted from https://github.com/HazyResearch/hyena-dna/blob/main/src/tasks/metrics.py#L12
  """

  def __init__(self, class_idx: typing.Optional[int] = 1,
               dist_sync_on_step=False):
    super().__init__(dist_sync_on_step=dist_sync_on_step)
    self.class_idx = torch.tensor(class_idx) \
      if class_idx is not None else None
    self.add_state("numerator", default=torch.tensor(0.0),
                   dist_reduce_fx="sum")
    self.add_state("denominator", default=torch.tensor(0.0),
                   dist_reduce_fx="sum")

  def _update(
      self, numerator, denominator, preds, y) -> tuple:
    raise NotImplementedError

  def update(self, logits: torch.Tensor, y: torch.Tensor):
    # update metric states
    preds = torch.argmax(logits, dim=-1)
    y = y.view(-1)
    assert preds.shape == y.shape, \
      f"preds shape {preds.shape} != y shape {y.shape}"
    self.numerator, self.denominator = self._update(
      self.numerator, self.denominator, preds, y)

  def compute(self):
    # compute final result
    value = self.numerator.float() / self.denominator \
      if self.denominator.item() > 0. else torch.tensor(0.0)
    return value

  def reset(self):
    self.numerator = torch.tensor(0.0).to(self.device)
    self.denominator = torch.tensor(0.0).to(self.device)


class CrossEntropy(MicroAveragingMetric):
  """Calculates cross-entropy loss."""

  def _update(
      self, numerator, denominator, logits, y) -> tuple:
    with torch.no_grad():
      numerator += F.cross_entropy(
        logits.view(-1, logits.size(-1)),
        y.view(-1),
        ignore_index=-100,
        reduction='sum')
      denominator += y.numel()
    return numerator, denominator

  # Overrides parent class to use logits and not (argmax) preds
  def update(self, logits: torch.Tensor, y: torch.Tensor):
    y = y.view(-1)
    self.numerator, self.denominator = self._update(
      self.numerator, self.denominator, logits, y)


class Accuracy(MicroAveragingMetric):
  """Calculates accuracy.

  Can be used to calculate accuracy per class.
  Copied from:
    https://github.com/HazyResearch/hyena-dna/blob/main/src/tasks/metrics.py
  """

  def _update(
      self, numerator, denominator, preds, y) -> tuple:
    if self.class_idx is None:
      numerator += (preds == y).sum()
      denominator += y.numel()
    else:
      class_idx = self.class_idx
      relevant_idxs = (y == class_idx)
      numerator += (preds[relevant_idxs] == class_idx).sum()
      denominator += relevant_idxs.sum()
      relevant_idxs = (y != class_idx)
      numerator += (preds[relevant_idxs] != class_idx).sum()
      denominator += relevant_idxs.sum()
    return numerator, denominator


class Precision(MicroAveragingMetric):
  """Calculates precision.

  Can be used to calculate precision per class.
  Adapted from:
    https://github.com/HazyResearch/hyena-dna/blob/main/src/tasks/metrics.py
  """

  def _update(self, numerator, denominator, preds, y) -> tuple:
    class_idx = self.class_idx
    relevant_idxs = (preds == class_idx)
    numerator += (y[relevant_idxs] == class_idx).sum()
    denominator += relevant_idxs.sum()
    return numerator, denominator


class Recall(MicroAveragingMetric):
  """Calculates recall.

  Can be used to calculate recall per class.
  Adapted from:
    https://github.com/HazyResearch/hyena-dna/blob/main/src/tasks/metrics.py
  """

  def _update(self, numerator, denominator, preds, y) -> tuple:
    class_idx = self.class_idx
    relevant_idxs = (y == class_idx)
    numerator += (preds[relevant_idxs] == class_idx).sum()
    denominator += relevant_idxs.sum()
    return numerator, denominator


class Classifier(L.LightningModule):
  def __init__(
      self,
      config,
      tokenizer: transformers.PreTrainedTokenizer,
      pretrained_backbone: typing.Optional[torch.nn.Module] = None):
    super().__init__()
    self.save_hyperparameters(ignore=['pretrained_backbone'])
    self.config = config

    # This param indicates whether this model will be used
    # for guidance (False) or only evaluation (True).
    self.is_eval_classifier = getattr(
      config, 'is_eval_classifier', False)

    self.tokenizer = tokenizer
    self.vocab_size = tokenizer.vocab_size
    self.antithetic_sampling = config.training.antithetic_sampling
    self.importance_sampling = config.training.importance_sampling
    self.change_of_variables = config.training.change_of_variables
    if (not hasattr(self.tokenizer, 'mask_token')
        or self.tokenizer.mask_token is None):
      self.mask_index = self.vocab_size
      self.vocab_size += 1
    else:
      self.mask_index = self.tokenizer.mask_token_id

    if config.classifier_backbone == 'dit':
      self.classifier_model = models.dit.DITClassifier(
        self.config, vocab_size=self.vocab_size)
    elif self.config.classifier_backbone == 'dimamba':
      self.classifier_model = models.dimamba.DiMambaClassifier(
        self.config, vocab_size=self.vocab_size,
        pad_token_id=self.tokenizer.pad_token_id)
    elif config.classifier_backbone == 'hyenadna':
      hyena_config = transformers.AutoConfig.from_pretrained(
        config.classifier_model.hyena_model_name_or_path,
        n_layer=config.classifier_model.n_layer,
        trust_remote_code=True
      )
      self.classifier_model = transformers.AutoModelForSequenceClassification.from_config(
        hyena_config,
        pretrained=False,
        num_labels=config.data.num_classes,
        problem_type='single_label_classification',
        trust_remote_code=True
      )
    else:
      raise NotImplementedError(
        f"Classifier backbone "
        f"{self.config.classifier_backbone} not "
        f"implemented.")
    if pretrained_backbone is not None:  # For PPLM / NOS
      self.classifier_model.load_pretrained_encoder(
        pretrained_backbone)
    # Metrics are automatically reset at end of epoch
    metrics = torchmetrics.MetricCollection({
      'cross_entropy': CrossEntropy(),
      'accuracy': Accuracy(class_idx=None),
    })
    if config.data.num_classes > 2:
      for c in range(config.data.num_classes):
        metrics.add_metrics(
          {f"accuracy_class{c}": Accuracy(class_idx=c),
           f"precision_class{c}": Precision(class_idx=c),
           f"recall_class{c}": Recall(class_idx=c)})
    else:
      metrics.add_metrics(
        {'precision': Precision(class_idx=1),
         'recall': Recall(class_idx=1)})
    metrics.set_dtype(torch.float64)
    self.train_metrics = metrics.clone(prefix='train/')
    self.valid_metrics = metrics.clone(prefix='val/')

    self.T = config.T
    self.noise = noise_schedule.get_noise(config,
                                          dtype=self.dtype)
    self.sampling_eps = config.training.sampling_eps
    self.lr = config.optim.lr
    self.time_conditioning = config.time_conditioning
    self.fast_forward_epochs = None
    self.fast_forward_batches = None

  def on_load_checkpoint(self, checkpoint):
    # Copied from:
    # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py#L41
    self.fast_forward_epochs = checkpoint['loops'][
      'fit_loop']['epoch_progress']['current']['completed']
    self.fast_forward_batches = checkpoint['loops'][
      'fit_loop']['epoch_loop.batch_progress'][
      'current']['completed']

  def on_save_checkpoint(self, checkpoint):
    # Copied from:
    # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/tasks/seq.py
    # ['epoch_loop.batch_progress']['total']['completed'] is
    # 1 iteration behind, so we're using the optimizer's
    # progress.
    checkpoint['loops']['fit_loop'][
      'epoch_loop.batch_progress']['total'][
      'completed'] = checkpoint['loops']['fit_loop'][
      'epoch_loop.automatic_optimization.optim_progress'][
      'optimizer']['step']['total'][
      'completed'] * self.trainer.accumulate_grad_batches
    checkpoint['loops']['fit_loop'][
      'epoch_loop.batch_progress']['current'][
      'completed'] = checkpoint['loops']['fit_loop'][
      'epoch_loop.automatic_optimization.optim_progress'][
      'optimizer']['step']['current'][
      'completed'] * self.trainer.accumulate_grad_batches
    # _batches_that_stepped tracks the number of global
    # steps, not the number of local steps, so we don't
    # multiply with self.trainer.accumulate_grad_batches
    # here.
    checkpoint['loops']['fit_loop'][
      'epoch_loop.state_dict'][
      '_batches_that_stepped'] = \
      checkpoint['loops']['fit_loop'][
        'epoch_loop.automatic_optimization.optim_progress'][
        'optimizer']['step']['total']['completed']
    if 'sampler' not in checkpoint.keys():
      checkpoint['sampler'] = {}
    if hasattr(self.trainer.train_dataloader.sampler,
               'state_dict'):
      sampler_state_dict = self.trainer. \
        train_dataloader.sampler.state_dict()
      checkpoint['sampler'][
        'random_state'] = sampler_state_dict.get(
        'random_state', None)
    else:
      checkpoint['sampler']['random_state'] = None

  def on_train_start(self):
    # Adapted from:
    # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py
    distributed = (
      self.trainer._accelerator_connector.use_distributed_sampler
      and self.trainer._accelerator_connector.is_distributed)
    if distributed:
      sampler_cls = dataloader.FaultTolerantDistributedSampler
    else:
      sampler_cls = dataloader.RandomFaultTolerantSampler
    updated_dls = []
    for dl in self.trainer.fit_loop._combined_loader.flattened:
      if hasattr(dl.sampler, 'shuffle'):
        dl_sampler = sampler_cls(
          dl.dataset, shuffle=dl.sampler.shuffle)
      else:
        dl_sampler = sampler_cls(dl.dataset)
      if (distributed
          and self.fast_forward_epochs is not None
          and self.fast_forward_batches is not None):
        dl_sampler.load_state_dict({
          'epoch': self.fast_forward_epochs,
          'counter': (self.fast_forward_batches
                      * self.config.loader.batch_size)})
      updated_dls.append(
        torch.utils.data.DataLoader(
          dl.dataset,
          batch_size=self.config.loader.batch_size,
          num_workers=self.config.loader.num_workers,
          pin_memory=self.config.loader.pin_memory,
          sampler=dl_sampler,
          shuffle=False,
          persistent_workers=self.config.loader.persistent_workers
        ))
    self.trainer.fit_loop._combined_loader.flattened = updated_dls

  def forward(self, x, sigma=None, x_emb=None, attention_mask=None):
    """Returns logits.

    x_emb can be provided during PPLM / NoS-style guidance
    (see: https://arxiv.org/abs/2305.20009).
    """
    if self.is_eval_classifier:
      logits = self.classifier_model(x)
      if hasattr(logits, 'logits'):
        logits = logits.logits
    else:
      sigma = self._process_sigma(sigma) if sigma is not None else sigma
      with torch.cuda.amp.autocast(dtype=torch.float32):
        logits = self.classifier_model(
          x, sigma, x_emb=x_emb, attention_mask=attention_mask)
    return logits

  def get_log_probs(self, x, sigma, x_emb=None):
    """Returns log probabilities.

    Use for CBG-style guidance.
    """
    if self.is_eval_classifier:
      raise NotImplementedError(
        '`get_log_probs` not implemented for classifiers '
        'that are meant to be used for evaluation purposes '
        'only.')
    with torch.cuda.amp.autocast(dtype=torch.float32):
      return torch.nn.functional.log_softmax(
        self.forward(x, sigma, x_emb=x_emb), dim=-1)

  def training_step(self, batch, batch_idx):
    loss = self._compute_loss(batch, prefix='train')
    self.log(name='trainer/loss',
             value=loss.item(),
             on_step=True,
             on_epoch=False,
             sync_dist=True,
             prog_bar=True)
    self.log(name='lr',
             value=self.trainer.optimizers[0].param_groups[0]['lr'],
             on_step=True,
             on_epoch=False,
             sync_dist=True,
             prog_bar=True, logger=False)
    return loss

  def validation_step(self, batch, batch_idx):
    return self._compute_loss(batch, prefix='val')

  def configure_optimizers(self):
    # TODO(yair): Lightning currently giving this warning when using `fp16`:
    #  "Detected call of `lr_scheduler.step()` before `optimizer.step()`."
    #  Not clear if this is a problem or not.
    #  See: https://github.com/Lightning-AI/pytorch-lightning/issues/5558
    optimizer = torch.optim.AdamW(
      itertools.chain(self.classifier_model.parameters(),
                      self.noise.parameters()),
      lr=self.config.optim.lr,
      betas=(self.config.optim.beta1,
             self.config.optim.beta2),
      eps=self.config.optim.eps,
      weight_decay=self.config.optim.weight_decay)

    scheduler = hydra.utils.instantiate(
      self.config.lr_scheduler, optimizer=optimizer)
    scheduler_dict = {
      'scheduler': scheduler,
      'interval': 'step',
      'monitor': 'val/loss',
      'name': 'trainer/lr',
    }
    return [optimizer], [scheduler_dict]

  def _q_xt(self, x, move_chance):
    """Computes the noisy sample xt.

    Args:
      x: int torch.Tensor with shape (batch_size,
        diffusion_model_input_length), input.
      move_chance: float torch.Tensor with shape
        (batch_size, 1).
    """
    move_indices = torch.rand(
      *x.shape, device=x.device) < move_chance
    if self.config.diffusion == 'absorbing_state':
      return torch.where(move_indices, self.mask_index, x)
    if self.config.diffusion == 'uniform':
      uniform_tensor = torch.randint(
        0, self.vocab_size, x.shape, device=x.device)
      return torch.where(move_indices, uniform_tensor, x)
    raise NotImplementedError(
      f'Diffusion type {self.config.diffusion} not '
      'implemented.')

  def _compute_loss(self, batch, prefix):
    x0 = batch['input_ids']
    attention_mask = batch['attention_mask']
    t = None
    if self.is_eval_classifier:
      logits = self.forward(x0)
    elif self.config.parameterization == 'ar':
      # do not add noise for AR FUDGE and AR PPLM
      logits = self.forward(
        x0, attention_mask=attention_mask)
    else:
      t = self._sample_t(x0.shape[0])
      if self.T > 0:
        t = (t * self.T).to(torch.int)
        t = t / self.T
        # t \in {1/T, 2/T, ..., 1}
        t += (1 / self.T)
      if self.change_of_variables:
        time_conditioning = t[:, None]
        f_T = torch.log1p(- torch.exp(- self.noise.sigma_max))
        f_0 = torch.log1p(- torch.exp(- self.noise.sigma_min))
        move_chance = torch.exp(f_0 + t * (f_T - f_0))
        move_chance = move_chance[:, None]
      else:
        sigma, _ = self.noise(t)
        time_conditioning = sigma[:, None]
        move_chance = 1 - torch.exp(-sigma[:, None])

      xt = self._q_xt(x0, move_chance)
      logits = self.forward(
        xt, time_conditioning, attention_mask=attention_mask)
    if hasattr(self.config.data, 'label_col'):
      if f"{self.config.data.label_col}_threshold" in batch:
        y = batch[f"{self.config.data.label_col}_threshold"]
      else:
        y = batch[self.config.data.label_col]
    else:
      y = batch['label']
    if (not self.is_eval_classifier
        and getattr(self.config.training, 'use_label_smoothing', False)):
      # Interpolate between one-hot and uniform distribution
      labels = (
        torch.nn.functional.one_hot(
          y, self.config.data.num_classes) * (1 - t)[..., None]
        + (1 / self.config.data.num_classes) * t[..., None])
    else:
      labels = y.view(-1)
    if getattr(self.config, 'is_fudge_classifier', False):
      # FUDGE trains a per-prefix predictor: expand the sequence-level
      # label to every non-padding position.
      expanded_y = y.unsqueeze(1).expand(-1, logits.shape[1])  # batch x seq
      logits = logits.view(
        -1, self.config.data.num_classes)[attention_mask.flatten() == 1, ...]
      y = expanded_y.flatten().long()[attention_mask.flatten() == 1]
      loss = torch.nn.functional.cross_entropy(
        logits,
        y,
        ignore_index=-100,
        reduction='mean')
    else:
      loss = torch.nn.functional.cross_entropy(
        logits.view(-1, logits.size(-1)),
        labels,
        ignore_index=-100,
        reduction='mean')

    if prefix == 'train':
      self.train_metrics.update(logits, y)
      metrics = self.train_metrics
    elif prefix == 'val':
      self.valid_metrics.update(logits, y)
      metrics = self.valid_metrics
    elif prefix == 'test':
      self.test_metrics.update(logits, y)
      metrics = self.test_metrics
    else:
      raise ValueError(f'Invalid prefix: {prefix}')

    self.log_dict(metrics,
                  on_step=False,
                  on_epoch=True,
                  sync_dist=True)
    return loss

  def _sample_t(self, n):
    _eps_t = torch.rand(n, device=self.device)
    if self.antithetic_sampling:
      # Stratify the time samples across the batch to reduce variance.
      offset = torch.arange(n, device=self.device) / n
      _eps_t = (_eps_t / n + offset) % 1
    t = (1 - self.sampling_eps) * _eps_t + self.sampling_eps
    if self.importance_sampling:
      return self.noise.importance_sampling_transformation(t)
    return t

  def _process_sigma(self, sigma):
    if sigma.ndim > 1:
      sigma = sigma.squeeze(-1)
    if not self.time_conditioning:
      sigma = torch.zeros_like(sigma)
    assert sigma.ndim == 1, sigma.shape
    return sigma
muppit/configs/callbacks/checkpoint_every_n_steps.yaml
ADDED
@@ -0,0 +1,8 @@
checkpoint_every_n_steps:
  _target_: lightning.pytorch.callbacks.ModelCheckpoint
  save_top_k: -1  # Do not save any "best" models; this callback is being used to save every n train steps
  save_last: True  # save model as ${save_dir}/checkpoints/last.ckpt
  dirpath: ${checkpointing.save_dir}/checkpoints
  verbose: True
  auto_insert_metric_name: False
  # every_n_train_steps: 500
muppit/configs/callbacks/checkpoint_monitor.yaml
ADDED
@@ -0,0 +1,10 @@
checkpoint_monitor:
  _target_: lightning.pytorch.callbacks.ModelCheckpoint
  monitor: val/nll  # name of the logged metric which determines when model is improving
  mode: min  # can be "max" or "min"
  save_top_k: 1  # save k best models (determined by above metric)
  save_last: False  # True = additionally always save model from last epoch
  dirpath: ${checkpointing.save_dir}/checkpoints
  filename: best
  auto_insert_metric_name: False
  verbose: True
muppit/configs/callbacks/learning_rate_monitor.yaml
ADDED
@@ -0,0 +1,3 @@
learning_rate_monitor:
  _target_: lightning.pytorch.callbacks.LearningRateMonitor
  logging_interval: step
muppit/configs/classifier_model/dimamba-classifier.yaml
ADDED
@@ -0,0 +1,14 @@
name: dimamba
type: dimamba
hidden_size: 256
cond_dim: 128
length: ${model.length}  # Same length as diffusion model
n_blocks: 8
scale_by_sigma: True
dropout: 0.1
tie_word_embeddings: False
bidirectional: True
bidirectional_strategy: add
bidirectional_weight_tie: True
num_classes: ${data.num_classes}
pooling: mean
muppit/configs/classifier_model/hyenadna-classifier.yaml
ADDED
@@ -0,0 +1,4 @@
name: hyena-32k
type: hyenadna
hyena_model_name_or_path: ???
n_layer: 4
muppit/configs/classifier_model/small-classifier.yaml
ADDED
@@ -0,0 +1,11 @@
name: small
type: ddit
hidden_size: 768
cond_dim: 128
length: ${model.length}  # Same length as diffusion model
n_blocks: 12
n_heads: 12
scale_by_sigma: True
dropout: 0.1
num_classes: ${data.num_classes}
pooling: mean
muppit/configs/classifier_model/tiny-classifier.yaml
ADDED
@@ -0,0 +1,11 @@
name: tiny
type: ddit
hidden_size: 512
cond_dim: 128
length: ${model.length}  # Same length as diffusion model
n_blocks: 8
n_heads: 8
scale_by_sigma: True
dropout: 0.1
num_classes: ${data.num_classes}
pooling: mean
muppit/configs/classifier_model/tiny-dimamba-classifier.yaml
ADDED
@@ -0,0 +1,14 @@
name: tiny
type: dimamba
hidden_size: 128
cond_dim: 128
length: ${model.length}  # Same length as diffusion model
n_blocks: 4
scale_by_sigma: True
dropout: 0.1
tie_word_embeddings: False
bidirectional: True
bidirectional_strategy: add
bidirectional_weight_tie: True
num_classes: ${data.num_classes}
pooling: mean
muppit/configs/config.yaml
ADDED
@@ -0,0 +1,104 @@
defaults:
  - _self_
  - /callbacks: [checkpoint_every_n_steps, checkpoint_monitor, learning_rate_monitor]
  - /data: peptide
  - /model: small
  - /strategy: ddp
  - /noise: loglinear
  - /lr_scheduler: cosine_decay_warmup  # constant_warmup
  - /classifier_model: null
  - /guidance: cbg

mode: ppl_eval  # train / train_classifier / ppl_eval
diffusion: uniform  # absorbing_state / uniform
backbone: dit  # dit / dimamba / ar
classifier_backbone: null
parameterization: d3pm  # subs / d3pm / ar
time_conditioning: True  # UDLM is conditioned on time
subs_masking: False
zero_recon_loss: True  # Use for UDLM
T: 0  # 0 (continuous time) / 1000

is_vision: False
seed: 13

loader:
  global_batch_size: 512
  eval_global_batch_size: ${.global_batch_size}
  # Note: batch_size and eval_batch_size are **per machine**
  batch_size: ${div_up:${.global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
  eval_batch_size: ${div_up:${.eval_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
  num_workers: 0  # ${eval:"len(__import__('os').sched_getaffinity(0))"}
  pin_memory: True
  persistent_workers: False  # True

sampling:
  use_cache: True
  steps: 128
  # Note: batch_size is **per machine**
  batch_size: 1  # ${loader.eval_batch_size}
  num_sample_batches: 50  # Total samples: `num_gpus` * `batch_size` * `num_sample_batches`
  use_float64: False

eval:
  checkpoint_path: '/home/tc415/muPPIt_embedding/muppit/model_path/PeptideUDLM.ckpt'  # Used to evaluate a checkpoint after training.
  wildtype: 'MAEYLASIFGTEKDKVNCSFYFKIGACRHGDRCSRLHNKPTFSQTIALLNIYRNPQNSSQSADGLRCAVSDVEMQEHYDEFFEEVFTEMEEKYGEVEEMNVCDNLGDHLVGNVYVKFRREEDAEKAVIDLNNRWFNGQPIHAELSPVTDFREACCRQYEMGECTRGGFCNFMHLKPISRELRRELYGRRRKKHRSRSRSRERRSRSRDRGRGGGGGGGGGGGGRERDRRRSRDRERSGRF'
  mutant: 'MAEYLASIFGTEKDKVNCSFYFKIGACRHGDRCFRLHNKPTFSQTIALLNIYRNPQNSSQSADGLRCAVSDVEMQEHYDEFFEEVFTEMEEKYGEVEEMNVCDNLGDHLVGNVYVKFRREEDAEKAVIDLNNRWFNGQPIHAELSPVTDFREACCRQYEMGECTRGGFCNFMHLKPISRELRRELYGRRRKKHRSRSRSRERRSRSRDRGRGGGGGGGGGGGGRERDRRRSRDRERSGRF'

  disable_ema: False
  generate_samples: True
  generated_samples_path: ''
  max_samples: 50_000

training:
  ema: 0.9999
  antithetic_sampling: True
  importance_sampling: False
  sampling_eps: 1e-3
  change_of_variables: False
  compute_loss_on_pad_tokens: True
  use_simple_ce_loss: False  # Ignore ELBO; just use CE
  guidance: null  # Can turn off with `training.guidance: null`
  # cond_dropout: 0.0

optim:
  weight_decay: 1e-4
  lr: 1e-5
  beta1: 0.9
  beta2: 0.999
  eps: 1e-8

trainer:
  _target_: lightning.Trainer
  accelerator: cuda
  num_nodes: 1
  devices: 2  # ${device_count:}
  accumulate_grad_batches: 1  # ${div_up:${loader.global_batch_size}, ${eval:${trainer.devices} * ${loader.batch_size} * ${trainer.num_nodes}}}
  gradient_clip_val: 1.0
  precision: 'bf16-mixed'
  num_sanity_val_steps: 2
  # max_epochs: 10
  max_steps: 1652000
  log_every_n_steps: 100
  limit_train_batches: 1.0  # train on full dataset, can be used to toggle quick run
  limit_val_batches: 1.0  # validate on full dataset, can be used to toggle quick run
  val_check_interval: 16520  # 2545

wandb:
  project: moPPIt-v2
  job_type: model-training
  name: protein_medium_100epochs_lr1e-5_gradclip1_wd1e-4_dropout0.1  # epochs10_lr3e-4_bsz8_64-true_all-params_gradclip1_beta-one0.9_beta-two0.999
  id: ${.name}

hydra:
  run:
    dir: ./outputs/${wandb.name}  # ./outputs/${data.train}/${now:%Y.%m.%d}/${now:%H%M%S}
  job:
    chdir: true

checkpointing:
  # Use custom `save_dir` if, e.g., saving to S3 bucket, otherwise leave this parameter as is
  save_dir: ${cwd:}
  # Note: `checkpoints` path should correspond to `checkpoint_every_n_steps.dirpath`
  resume_from_ckpt: False
  resume_ckpt_path: ${.save_dir}/checkpoints/last.ckpt
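
Note: the interpolations in this config (`${div_up:...}`, `${eval:...}`, `${cwd:}`, `${device_count:}`) rely on custom OmegaConf resolvers that must be registered before Hydra composes the config. A minimal sketch of how such resolvers can be registered (illustrative; the repository's own registration code may differ):
```python
import os

import torch
from omegaconf import OmegaConf

OmegaConf.register_new_resolver(
    'div_up', lambda x, y: (x + y - 1) // y)         # ceiling division
OmegaConf.register_new_resolver('eval', eval)        # arithmetic in configs
OmegaConf.register_new_resolver('cwd', os.getcwd)    # current working dir
OmegaConf.register_new_resolver(
    'device_count', torch.cuda.device_count)         # visible GPUs
```
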
muppit/configs/data/amazon_polarity.yaml
ADDED
@@ -0,0 +1,10 @@
train: amazon_polarity
valid: amazon_polarity
tokenizer_name_or_path: bert-base-uncased
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
wrap: False
streaming: False
override_cache: False
add_special_tokens: True
label_col: label
num_classes: 2
muppit/configs/data/cifar10.yaml
ADDED
@@ -0,0 +1,11 @@
train: ???  # (Local) Path to CIFAR-10 training data
valid: ???  # (Local) Path to CIFAR-10 validation data
label_col: labels
num_classes: 10
streaming: False
size: 1024
length: 3072
add_special_tokens: True
add_mask_token: True
tokenizer_name_or_path: raw_pixels
muppit/configs/data/lm1b.yaml
ADDED
@@ -0,0 +1,8 @@
train: lm1b
valid: lm1b
tokenizer_name_or_path: bert-base-uncased
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
wrap: False
streaming: False
override_cache: False
add_special_tokens: True
muppit/configs/data/peptide.yaml
ADDED
@@ -0,0 +1,8 @@
train: peptide
valid: peptide
tokenizer_name_or_path: facebook/esm2_t33_650M_UR50D
cache_dir: /home/tc415/discrete-diffusion-guidance/dataset
wrap: False
streaming: False
override_cache: False
add_special_tokens: True
muppit/configs/data/protein.yaml
ADDED
@@ -0,0 +1,8 @@
train: protein_400k
valid: protein_400k
tokenizer_name_or_path: facebook/esm2_t33_650M_UR50D
cache_dir: /home/tc415/discrete-diffusion-guidance/dataset
wrap: False
streaming: False
override_cache: False
add_special_tokens: True
muppit/configs/data/qm9.yaml
ADDED
@@ -0,0 +1,11 @@
train: qm9
valid: qm9
tokenizer_name_or_path: yairschiff/qm9-tokenizer
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
wrap: False
streaming: False
override_cache: False
add_special_tokens: True
label_col: qed
label_col_pctile: 90
num_classes: 2
muppit/configs/data/ten_species.yaml
ADDED
@@ -0,0 +1,11 @@
train: ten_species
valid: ten_species
tokenizer_name_or_path: kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
wrap: False
streaming: False
override_cache: False
add_special_tokens: False
label_col: species_label
num_classes: 10
rc_aug: False
muppit/configs/data/text8.yaml
ADDED
@@ -0,0 +1,9 @@
# TODO: When using this dataset, set model.length = 256 to match D3PM setup
train: text8
valid: text8
tokenizer_name_or_path: text8
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
wrap: True
streaming: False
override_cache: False
add_special_tokens: False
muppit/configs/guidance/cbg.yaml
ADDED
@@ -0,0 +1,5 @@
method: cbg
condition: 0
classifier_checkpoint_path: '/home/tc415/muPPIt_embedding/checkpoints/mutBind_small'
gamma: 2.0
use_approx: False  # use first-order approximation
muppit/configs/guidance/cfg.yaml
ADDED
@@ -0,0 +1,3 @@
method: cfg
condition: 0
gamma: 1.0
muppit/configs/guidance/fudge.yaml
ADDED
@@ -0,0 +1,5 @@
method: fudge
condition: 0
classifier_checkpoint_path: ''
topk: 20
gamma: 1.0
muppit/configs/guidance/nos.yaml
ADDED
@@ -0,0 +1,6 @@
method: nos
condition: 0
classifier_checkpoint_path: ''
num_nos_steps: 1
nos_step_size: 0.1
nos_stability_coef: 0.01
muppit/configs/guidance/pplm.yaml
ADDED
@@ -0,0 +1,6 @@
method: pplm
condition: 0
classifier_checkpoint_path: ''
num_pplm_steps: 1
pplm_step_size: 0.1
pplm_stability_coef: 0.01
muppit/configs/lr_scheduler/constant_warmup.yaml
ADDED
@@ -0,0 +1,2 @@
_target_: transformers.get_constant_schedule_with_warmup
num_warmup_steps: 2500
muppit/configs/lr_scheduler/cosine_decay_warmup.yaml
ADDED
@@ -0,0 +1,7 @@
_target_: utils.CosineDecayWarmupLRScheduler
t_in_epochs: False
t_initial: ${eval:${trainer.max_steps}-${.warmup_t}}
warmup_prefix: True
warmup_lr_init: 1e-7
warmup_t: ${eval:0.1*${trainer.max_steps}}
lr_min: 1e-7
muppit/configs/model/dimamba.yaml
ADDED
@@ -0,0 +1,12 @@
name: dimamba
type: dimamba
hidden_size: 256
cond_dim: 128
length: 32768
n_blocks: 8
scale_by_sigma: True
dropout: 0.1
tie_word_embeddings: False
bidirectional: True
bidirectional_strategy: add
bidirectional_weight_tie: True
muppit/configs/model/fudge_predictor.yaml
ADDED
@@ -0,0 +1,4 @@
name: fudge_predictor
type: lstm
hidden_dim: 300
length: 1024
muppit/configs/model/hf.yaml
ADDED
@@ -0,0 +1,2 @@
pretrained_model_name_or_path: null
length: 128
muppit/configs/model/medium.yaml
ADDED
@@ -0,0 +1,10 @@
name: medium
type: ddit
hidden_size: 1024
cond_dim: 128
length: 82
n_blocks: 24
n_heads: 16
scale_by_sigma: True
dropout: 0.1
tie_word_embeddings: False
muppit/configs/model/small.yaml
ADDED
@@ -0,0 +1,11 @@
name: small
type: ddit
hidden_size: 768
cond_dim: 128
length: 15
# length_range: '6-49'
n_blocks: 12
n_heads: 12
scale_by_sigma: True
dropout: 0.1
tie_word_embeddings: False
muppit/configs/model/tiny.yaml
ADDED
@@ -0,0 +1,10 @@
name: tiny
type: ddit
hidden_size: 512
cond_dim: 128
length: 1024
n_blocks: 8
n_heads: 8
scale_by_sigma: True
dropout: 0.1
tie_word_embeddings: False
muppit/configs/model/unet.yaml
ADDED
@@ -0,0 +1,19 @@
name: unet
type: unet
ch: 128
num_res_blocks: 2
num_scales: 4
ch_mult: [1, 2, 2, 2]
input_channels: 3
output_channels: -1  # determined by vocab_size
scale_count_to_put_attn: 1  # at 16 res
data_min_max: [0, 255]  # No need currently
dropout: 0.1
skip_rescale: True
time_conditioning: True  # Whether to add in time embeddings
time_scale_factor: 1000
time_embed_dim: ${.ch}
fix_logistic: False
size: ${data.size}
cond_dim: ${.ch}
length: ${data.length}
muppit/configs/model/unet_campbell.yaml
ADDED
@@ -0,0 +1,19 @@
name: unet
type: unet
ch: 128
num_res_blocks: 2
num_scales: 4
ch_mult: [1, 2, 2, 2]
input_channels: 3
output_channels: -1  # determined by input_channels * 2
scale_count_to_put_attn: 1  # at 16 res
data_min_max: [0, 255]  # No need currently, determined by [0, vocab_size]
dropout: 0.1
skip_rescale: True
time_conditioning: True  # Whether to add in time embeddings
time_scale_factor: 1000
time_embed_dim: ${.ch}
fix_logistic: False
size: ${data.size}
cond_dim: ${.ch}
length: ${data.length}
muppit/configs/noise/ar.yaml
ADDED
@@ -0,0 +1,2 @@
type: ar
scale: 6.0
muppit/configs/noise/linear.yaml
ADDED
@@ -0,0 +1,3 @@
type: linear
sigma_min: 1e-3
sigma_max: 7.0
muppit/configs/noise/loglinear.yaml
ADDED
@@ -0,0 +1,3 @@
type: loglinear
sigma_min: 1e-4
sigma_max: 20
muppit/configs/noise/polynomial.yaml
ADDED
@@ -0,0 +1,5 @@
type: polynomial
a: -3
b: 5
c: -4
eps: 1e-3
muppit/configs/strategy/ddp.yaml
ADDED
@@ -0,0 +1,2 @@
_target_: lightning.pytorch.strategies.DDPStrategy
find_unused_parameters: false
muppit/configs/strategy/fsdp.yaml
ADDED
@@ -0,0 +1,3 @@
# TODO(yair): Currently not compatible with grad clipping
_target_: lightning.pytorch.strategies.FSDPStrategy
sharding_strategy: SHARD_GRAD_OP
muppit/custom_datasets/__init__.py
ADDED
@@ -0,0 +1,2 @@
from . import discretized_cifar10
from . import ten_species_dataset