from pytorch_lightning.strategies import DDPStrategy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, DistributedSampler, BatchSampler, Sampler
from datasets import load_from_disk
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, \
    Timer, TQDMProgressBar, LearningRateMonitor, StochasticWeightAveraging, GradientAccumulationScheduler
from pytorch_lightning.loggers import WandbLogger
from torch.optim.lr_scheduler import _LRScheduler
from argparse import ArgumentParser
import os
import uuid
import esm
import numpy as np
import torch.distributed as dist
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup

from torch.optim import Adam, AdamW
from sklearn.metrics import roc_auc_score, f1_score, matthews_corrcoef
import gc

from models.graph import ProteinGraph
from models.modules_vec import IntraGraphAttention, DiffEmbeddingLayer, MIM, CrossGraphAttention

os.environ["TORCH_CPP_LOG_LEVEL"] = "INFO"
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

# The ESM-2 tokenizer is cached at module level so each DataLoader worker loads it
# once instead of instantiating it for every batch.
_TOKENIZER = None


def collate_fn(batch):
    global _TOKENIZER
    if _TOKENIZER is None:
        _TOKENIZER = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    tokenizer = _TOKENIZER

    binders = []
    mutants = []
    wildtypes = []

    for b in batch:
        binders.append(torch.tensor(b['binder_tokens']).squeeze(0))
        mutants.append(torch.tensor(b['mutant_tokens']).squeeze(0))
        wildtypes.append(torch.tensor(b['wildtype_tokens']).squeeze(0))

    # Right-pad every sequence in the batch to the longest one using the tokenizer's pad id.
    binder_input_ids = pad_sequence(binders, batch_first=True, padding_value=tokenizer.pad_token_id)
    mutant_input_ids = pad_sequence(mutants, batch_first=True, padding_value=tokenizer.pad_token_id)
    wildtype_input_ids = pad_sequence(wildtypes, batch_first=True, padding_value=tokenizer.pad_token_id)

    return {
        'binder_input_ids': binder_input_ids.int(),
        'mutant_input_ids': mutant_input_ids.int(),
        'wildtype_input_ids': wildtype_input_ids.int(),
    }

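# Usage sketch (assumed record layout, matching the fields accessed above): each dataset row
# is expected to hold pre-tokenized ESM-2 ids, e.g.
#   {'binder_tokens': [[0, 5, 8, 2]], 'mutant_tokens': [[0, 6, 9, 2]], 'wildtype_tokens': [[0, 6, 7, 2]]}
# so that collate_fn returns int tensors of shape (batch_size, max_len_per_field).

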
class LengthAwareDistributedSampler(DistributedSampler):
    """Distributes length-sorted indices across ranks and yields them in batches."""

    def __init__(self, dataset, key, batch_size, num_replicas=None, rank=None, shuffle=True):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
        self.dataset = dataset
        self.key = key
        self.batch_size = batch_size

        # Sort indices by sequence length so batches contain similarly sized sequences.
        self.indices = sorted(range(len(self.dataset)), key=lambda i: len(self.dataset[i][key]))

    def __iter__(self):
        # Each rank takes an interleaved slice of the length-sorted indices.
        indices = self.indices[self.rank::self.num_replicas]

        if self.shuffle:
            # Shuffle the dataset indices themselves (seeded per epoch), not just their positions.
            g = torch.Generator()
            g.manual_seed(self.epoch)
            perm = torch.randperm(len(indices), generator=g).tolist()
            indices = [indices[i] for i in perm]

        for i in range(0, len(indices), self.batch_size):
            yield indices[i:i + self.batch_size]

    def __len__(self):
        # Number of samples assigned to this rank (note that __iter__ yields batches, not single indices).
        return len(self.indices) // self.num_replicas

    def set_epoch(self, epoch):
        self.epoch = epoch

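# Note: this sampler yields lists of indices (a batch-sampler pattern) and is not currently wired
# into CustomDataModule below. A minimal sketch of how it could be used, assuming a map-style
# dataset with a 'binder_tokens' field:
#   sampler = LengthAwareDistributedSampler(train_dataset, key='binder_tokens', batch_size=args.batch_size)
#   loader = DataLoader(train_dataset, batch_sampler=sampler, collate_fn=collate_fn, num_workers=8)

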
class CustomDataModule(pl.LightningDataModule):
    def __init__(self, train_dataset, val_dataset, tokenizer, batch_size: int = 128):
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        self.tokenizer = tokenizer

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, collate_fn=collate_fn,
                          num_workers=8, pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, collate_fn=collate_fn, num_workers=8,
                          pin_memory=True)

    def setup(self, stage=None):
        if stage == 'test' or stage is None:
            pass

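# Under the DDP strategy configured in main(), Lightning replaces the samplers of these dataloaders
# with a DistributedSampler by default, so each rank sees its own shard of the data; shuffle=True
# above is carried over into that replacement sampler.

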
class CosineAnnealingWithWarmup(_LRScheduler):
    def __init__(self, optimizer, warmup_steps, total_steps, base_lr, max_lr, min_lr, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.base_lr = base_lr
        self.max_lr = max_lr
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)
        print(f"SELF BASE LRS = {self.base_lrs}")

    def get_lr(self):
        if self.last_epoch < self.warmup_steps:
            # Linear warmup from base_lr to max_lr.
            return [self.base_lr + (self.max_lr - self.base_lr) * (self.last_epoch / self.warmup_steps)
                    for _ in self.base_lrs]

        # Cosine decay from max_lr down to min_lr over the remaining steps.
        progress = (self.last_epoch - self.warmup_steps) / (self.total_steps - self.warmup_steps)
        cosine_decay = 0.5 * (1 + np.cos(np.pi * progress))
        decayed_lr = self.min_lr + (self.max_lr - self.min_lr) * cosine_decay

        return [decayed_lr for _ in self.base_lrs]

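# Schedule in closed form (step = self.last_epoch, advanced once per optimizer step):
#   step <  warmup_steps: lr = base_lr + (max_lr - base_lr) * step / warmup_steps
#   step >= warmup_steps: lr = min_lr + (max_lr - min_lr) * 0.5 * (1 + cos(pi * t)),
#                         where t = (step - warmup_steps) / (total_steps - warmup_steps)

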
class muPPIt(pl.LightningModule):
    def __init__(self, d_node, d_edge, d_cross_edge, d_position, num_heads,
                 num_intra_layers, num_mim_layers, num_cross_layers, lr, delta=1.0):
        super().__init__()

        # Frozen ESM-2 650M backbone, used only as a feature extractor.
        self.esm, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        for param in self.esm.parameters():
            param.requires_grad = False

        self.graph = ProteinGraph(d_node, d_edge, d_position)

        self.intra_graph_att_layers = nn.ModuleList([
            IntraGraphAttention(d_node, d_edge, num_heads) for _ in range(num_intra_layers)
        ])

        self.diff_layer = DiffEmbeddingLayer(d_node)

        self.mim_layers = nn.ModuleList([
            MIM(d_node, d_edge, d_node, num_heads) for _ in range(num_mim_layers)
        ])

        self.cross_graph_att_layers = nn.ModuleList([
            CrossGraphAttention(d_node, d_cross_edge, d_node, num_heads) for _ in range(num_cross_layers)
        ])

        self.cross_graph_edge_mapping = nn.Linear(1, d_cross_edge)
        self.mapping = nn.Linear(d_cross_edge, 1)

        self.d_cross_edge = d_cross_edge
        self.learning_rate = lr
        self.delta = delta

    def forward(self, binder_tokens, wt_tokens, mut_tokens):
        device = binder_tokens.device

        # Build residue graphs (node/edge features plus padding masks) from the frozen ESM-2 embeddings.
        binder_node, binder_edge, binder_node_mask, binder_edge_mask = self.graph(binder_tokens, self.esm, self.alphabet)
        wt_node, wt_edge, wt_node_mask, wt_edge_mask = self.graph(wt_tokens, self.esm, self.alphabet)
        mut_node, mut_edge, mut_node_mask, mut_edge_mask = self.graph(mut_tokens, self.esm, self.alphabet)

        # Intra-graph attention over each protein independently, re-masking padded positions after every layer.
        for layer in self.intra_graph_att_layers:
            binder_node, binder_edge = layer(binder_node, binder_edge)
            binder_node = binder_node * binder_node_mask.unsqueeze(-1)
            binder_edge = binder_edge * binder_edge_mask.unsqueeze(-1)

            wt_node, wt_edge = layer(wt_node, wt_edge)
            wt_node = wt_node * wt_node_mask.unsqueeze(-1)
            wt_edge = wt_edge * wt_edge_mask.unsqueeze(-1)

            mut_node, mut_edge = layer(mut_node, mut_edge)
            mut_node = mut_node * mut_node_mask.unsqueeze(-1)
            mut_edge = mut_edge * mut_edge_mask.unsqueeze(-1)

        # Difference embedding between wildtype and mutant node representations.
        diff_vec = self.diff_layer(wt_node, mut_node)

        # Mutation-informed message passing conditioned on the difference vector.
        for layer in self.mim_layers:
            wt_node, wt_edge = layer(wt_node, wt_edge, diff_vec)
            wt_node = wt_node * wt_node_mask.unsqueeze(-1)
            wt_edge = wt_edge * wt_edge_mask.unsqueeze(-1)

            mut_node, mut_edge = layer(mut_node, mut_edge, diff_vec)
            mut_node = mut_node * mut_node_mask.unsqueeze(-1)
            mut_edge = mut_edge * mut_edge_mask.unsqueeze(-1)

        B = mut_node.shape[0]
        L_mut = mut_node.shape[1]
        L_wt = wt_node.shape[1]
        L_binder = binder_node.shape[1]

        # Cross-graph edge features are initialized randomly on every forward pass and refined below.
        mut_binder_edges = torch.randn(B, L_mut, L_binder, self.d_cross_edge).to(device)
        wt_binder_edges = torch.randn(B, L_wt, L_binder, self.d_cross_edge).to(device)

        mut_binder_mask = mut_node_mask.unsqueeze(-1) * binder_node_mask.unsqueeze(1).to(device)
        wt_binder_mask = wt_node_mask.unsqueeze(-1) * binder_node_mask.unsqueeze(1).to(device)

        # Cross-graph attention between each target (wildtype / mutant) and the binder.
        for layer in self.cross_graph_att_layers:
            wt_node, binder_node, wt_binder_edges = layer(wt_node, binder_node, wt_binder_edges, diff_vec)
            wt_node = wt_node * wt_node_mask.unsqueeze(-1)
            binder_node = binder_node * binder_node_mask.unsqueeze(-1)
            wt_binder_edges = wt_binder_edges * wt_binder_mask.unsqueeze(-1)

            mut_node, binder_node, mut_binder_edges = layer(mut_node, binder_node, mut_binder_edges, diff_vec)
            mut_node = mut_node * mut_node_mask.unsqueeze(-1)
            binder_node = binder_node * binder_node_mask.unsqueeze(-1)
            mut_binder_edges = mut_binder_edges * mut_binder_mask.unsqueeze(-1)

        # Pool the cross-graph edge features and map them to scalar interaction scores.
        wt_binder_edges = torch.mean(wt_binder_edges, dim=(1, 2))
        mut_binder_edges = torch.mean(mut_binder_edges, dim=(1, 2))

        wt_pred = torch.sigmoid(self.mapping(wt_binder_edges))
        mut_pred = torch.sigmoid(self.mapping(mut_binder_edges))

        return wt_pred, mut_pred

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        lr = opt.param_groups[0]['lr']
        self.log('learning_rate', lr, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)

        binder_tokens = batch['binder_input_ids'].to(self.device)
        mut_tokens = batch['mutant_input_ids'].to(self.device)
        wt_tokens = batch['wildtype_input_ids'].to(self.device)

        wt_pred, mut_pred = self.forward(binder_tokens, wt_tokens, mut_tokens)

        # Push the wildtype-binder score toward 1 and the mutant-binder score toward 0.
        wt_loss = (torch.relu(1 - wt_pred) ** 2).mean()
        mut_loss = (torch.relu(mut_pred) ** 2).mean()
        loss = wt_loss + mut_loss

        self.log('train_wt_loss', wt_loss.item(), on_step=True, on_epoch=True, logger=True, sync_dist=True)
        self.log('train_mut_loss', mut_loss.item(), on_step=True, on_epoch=True, logger=True, sync_dist=True)
        self.log('train_loss', loss.item(), on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        return loss

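    # Worked example of the objective above (single pair): with wt_pred = 0.9 and mut_pred = 0.2,
    # loss = relu(1 - 0.9)**2 + relu(0.2)**2 = 0.01 + 0.04 = 0.05; the loss vanishes only when
    # wt_pred reaches 1 and mut_pred reaches 0.
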
    def validation_step(self, batch, batch_idx):
        binder_tokens = batch['binder_input_ids'].to(self.device)
        mut_tokens = batch['mutant_input_ids'].to(self.device)
        wt_tokens = batch['wildtype_input_ids'].to(self.device)

        wt_pred, mut_pred = self.forward(binder_tokens, wt_tokens, mut_tokens)

        wt_loss = (torch.relu(1 - wt_pred) ** 2).mean()
        mut_loss = (torch.relu(mut_pred) ** 2).mean()
        loss = wt_loss + mut_loss

        self.log('val_wt_loss', wt_loss.item(), on_step=True, on_epoch=True, logger=True, sync_dist=True)
        self.log('val_mut_loss', mut_loss.item(), on_step=True, on_epoch=True, logger=True, sync_dist=True)
        self.log('val_loss', loss.item(), on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=self.learning_rate, betas=(0.9, 0.95))

        base_lr = 1e-5
        max_lr = self.learning_rate
        min_lr = 0.1 * self.learning_rate

        # Note: warmup_steps and total_steps are hard-coded and should be kept in sync with the
        # actual number of optimizer steps (roughly len(train_loader) / accumulation * max_epochs).
        scheduler = CosineAnnealingWithWarmup(optimizer, warmup_steps=600, total_steps=15390,
                                              base_lr=base_lr, max_lr=max_lr, min_lr=min_lr)

        lr_scheduler = {
            "scheduler": scheduler,
            "name": 'learning_rate_logs',
            "interval": 'step',
            "frequency": 1
        }
        return [optimizer], [lr_scheduler]

    def on_train_epoch_end(self):
        # Release cached memory between epochs. The base class has no training_epoch_end hook
        # in current Lightning versions, so no super() call is needed here.
        gc.collect()
        torch.cuda.empty_cache()


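# Minimal inference sketch (hypothetical usage; constructor arguments mirror the defaults in main()):
#   model = muPPIt(d_node=1024, d_edge=512, d_cross_edge=512, d_position=8, num_heads=8,
#                  num_intra_layers=1, num_mim_layers=1, num_cross_layers=1, lr=1e-3)
#   batch = collate_fn([dataset_row])
#   wt_pred, mut_pred = model(batch['binder_input_ids'], batch['wildtype_input_ids'], batch['mutant_input_ids'])

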
def main():
    parser = ArgumentParser()

    parser.add_argument("-o", dest="output_file", help="Directory for model checkpoints", required=True, type=str)
    parser.add_argument("-lr", type=float, default=1e-3)
    parser.add_argument("-batch_size", type=int, default=2, help="Batch size")
    parser.add_argument("-d_node", type=int, default=1024, help="Node Representation Dimension")
    parser.add_argument("-d_edge", type=int, default=512, help="Intra-Graph Edge Representation Dimension")
    parser.add_argument("-d_cross_edge", type=int, default=512, help="Cross-Graph Edge Representation Dimension")
    parser.add_argument("-d_position", type=int, default=8, help="Positional Embedding Dimension")
    parser.add_argument("-n_heads", type=int, default=8)
    parser.add_argument("-n_intra_layers", type=int, default=1)
    parser.add_argument("-n_mim_layers", type=int, default=1)
    parser.add_argument("-n_cross_layers", type=int, default=1)
    parser.add_argument("-sm", default=None, help="File containing initial params", type=str)
    parser.add_argument("-max_epochs", type=int, default=15, help="Max number of epochs to train")
    parser.add_argument("-dropout", type=float, default=0.2)
    parser.add_argument("-grad_clip", type=float, default=0.5)
    parser.add_argument("-delta", type=float, default=1)
    args = parser.parse_args()

    dist.init_process_group(backend='nccl')

    train_dataset = load_from_disk('/home/tc415/muPPIt/dataset/train/ppiref')
    val_dataset = load_from_disk('/home/tc415/muPPIt/dataset/val/ppiref')

    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

    data_module = CustomDataModule(train_dataset, val_dataset, tokenizer=tokenizer, batch_size=args.batch_size)

    if args.sm:
        # Resume from a checkpoint; hyperparameters are forwarded as keyword arguments so they reach
        # the constructor instead of being misread as load_from_checkpoint's own parameters.
        model = muPPIt.load_from_checkpoint(args.sm,
                                            d_node=args.d_node, d_edge=args.d_edge,
                                            d_cross_edge=args.d_cross_edge, d_position=args.d_position,
                                            num_heads=args.n_heads, num_intra_layers=args.n_intra_layers,
                                            num_mim_layers=args.n_mim_layers, num_cross_layers=args.n_cross_layers,
                                            lr=args.lr, delta=args.delta)
    else:
        print("Train from scratch!")
        model = muPPIt(args.d_node, args.d_edge, args.d_cross_edge, args.d_position, args.n_heads,
                       args.n_intra_layers, args.n_mim_layers, args.n_cross_layers, args.lr, args.delta)

    run_id = str(uuid.uuid4())

    logger = WandbLogger(project="muppit",
                         name="debug",
                         job_type='model-training',
                         id=run_id)

    # val_loss is the only validation metric logged, so both checkpointing and early stopping monitor it.
    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        dirpath=args.output_file,
        filename='model-{epoch:02d}-{val_loss:.2f}',
        save_top_k=-1,
        mode='min',
    )

    early_stopping_callback = EarlyStopping(
        monitor='val_loss',
        patience=5,
        verbose=True,
        mode='min'
    )

    # Accumulate gradients over 8 batches for epochs 0-2, 4 from epoch 3, and 2 from epoch 20 onward.
    accumulator = GradientAccumulationScheduler(scheduling={0: 8, 3: 4, 20: 2})

    trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        accelerator='gpu',
        strategy='ddp_find_unused_parameters_true',
        precision='bf16',
        devices=[0, 1, 2],
        callbacks=[checkpoint_callback, accumulator, early_stopping_callback],
        gradient_clip_val=args.grad_clip,
    )

    trainer.fit(model, datamodule=data_module)

    best_model_path = checkpoint_callback.best_model_path
    print(best_model_path)


if __name__ == "__main__":
    main()