# muPPIt/train.py
import pdb
import gc
import os
import uuid
from argparse import ArgumentParser

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.utils.rnn import pad_sequence
from torch.optim import AdamW
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader

import esm
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, GradientAccumulationScheduler
from pytorch_lightning.loggers import WandbLogger
from datasets import load_from_disk
from transformers import AutoTokenizer
# from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR

# os.environ["TORCH_CPP_LOG_LEVEL"] = "INFO"
# os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
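
# VHSE8 descriptors: eight principal-component scores per residue summarizing
# hydrophobic, steric, and electronic properties; used below as extra per-token
# features alongside the ESM-2 embeddings.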
vhse8_values = {
'A': [0.15, -1.11, -1.35, -0.92, 0.02, -0.91, 0.36, -0.48],
'R': [-1.47, 1.45, 1.24, 1.27, 1.55, 1.47, 1.30, 0.83],
'N': [-0.99, 0.00, 0.69, -0.37, -0.55, 0.85, 0.73, -0.80],
'D': [-1.15, 0.67, -0.41, -0.01, -2.68, 1.31, 0.03, 0.56],
'C': [0.18, -1.67, -0.21, 0.00, 1.20, -1.61, -0.19, -0.41],
'Q': [-0.96, 0.12, 0.18, 0.16, 0.09, 0.42, -0.20, -0.41],
'E': [-1.18, 0.40, 0.10, 0.36, -2.16, -0.17, 0.91, 0.36],
'G': [-0.20, -1.53, -2.63, 2.28, -0.53, -1.18, -1.34, 1.10],
'H': [-0.43, -0.25, 0.37, 0.19, 0.51, 1.28, 0.93, 0.65],
'I': [1.27, 0.14, 0.30, -1.80, 0.30, -1.61, -0.16, -0.13],
'L': [1.36, 0.07, 0.26, -0.80, 0.22, -1.37, 0.08, -0.62],
'K': [-1.17, 0.70, 0.80, 1.64, 0.67, 1.63, 0.13, -0.01],
'M': [1.01, -0.53, 0.43, 0.00, 0.23, 0.10, -0.86, -0.68],
'F': [1.52, 0.61, 0.95, -0.16, 0.25, 0.28, -1.33, -0.65],
'P': [0.22, -0.17, -0.50, -0.05, 0.01, -1.34, 0.19, 3.56],
'S': [-0.67, -0.86, -1.07, -0.41, -0.32, 0.27, -0.64, 0.11],
'T': [-0.34, -0.51, -0.55, -1.06, 0.01, -0.01, -0.79, 0.39],
'W': [1.50, 2.06, 1.79, 0.75, 0.75, 0.13, -1.06, -0.85],
'Y': [0.61, 1.60, 1.17, 0.73, 0.53, 0.25, -0.96, -0.52],
'V': [0.76, -0.92, 0.17, -1.91, 0.22, -1.40, -0.24, -0.03],
}
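
# Token indices of the 20 standard residues in the ESM-2 alphabet, so the rows of
# the VHSE8 lookup table line up with ESM-2 token ids (special tokens keep zero vectors).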
aa_to_idx = {'A': 5, 'R': 10, 'N': 17, 'D': 13, 'C': 23, 'Q': 16, 'E': 9, 'G': 6, 'H': 21, 'I': 12, 'L': 4, 'K': 15, 'M': 20, 'F': 18, 'P': 14, 'S': 8, 'T': 11, 'W': 22, 'Y': 19, 'V': 7}
vhse8_tensor = torch.zeros(24, 8)
for aa, values in vhse8_values.items():
aa_index = aa_to_idx[aa]
vhse8_tensor[aa_index] = torch.tensor(values)
vhse8_tensor.requires_grad = False
def collate_fn(batch):
    """Pad variable-length binder/mutant/wildtype token sequences into a batch."""
    binders = []
    mutants = []
    wildtypes = []
    affs = []

    # Only pad_token_id is used here; from_pretrained reads from the local HF cache after the first download.
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

    for b in batch:
        # Drop the BOS/EOS tokens added by the ESM-2 tokenizer
        binder = torch.tensor(b['binder_input_ids']['input_ids'][1:-1])
        mutant = torch.tensor(b['mutant_input_ids']['input_ids'][1:-1])
        wildtype = torch.tensor(b['wildtype_input_ids']['input_ids'][1:-1])

        # Skip malformed (scalar) or empty sequences
        if (binder.dim() == 0 or binder.numel() == 0
                or mutant.dim() == 0 or mutant.numel() == 0
                or wildtype.dim() == 0 or wildtype.numel() == 0):
            continue

        binders.append(binder)      # shape: 1*L1 -> L1
        mutants.append(mutant)      # shape: 1*L2 -> L2
        wildtypes.append(wildtype)  # shape: 1*L3 -> L3
        affs.append(b['aff'])

    # Pad each group to the longest sequence in the batch
    try:
        binder_input_ids = pad_sequence(binders, batch_first=True, padding_value=tokenizer.pad_token_id)
        mutant_input_ids = pad_sequence(mutants, batch_first=True, padding_value=tokenizer.pad_token_id)
        wildtype_input_ids = pad_sequence(wildtypes, batch_first=True, padding_value=tokenizer.pad_token_id)
    except Exception:
        # Drop into the debugger if padding fails (e.g. every item in the batch was skipped)
        pdb.set_trace()

    affs = torch.tensor(affs)

    return {
        'binder_input_ids': binder_input_ids.int(),
        'mutant_input_ids': mutant_input_ids.int(),
        'wildtype_input_ids': wildtype_input_ids.int(),
        'aff': affs
    }
class CustomDataModule(pl.LightningDataModule):
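    """Wraps pre-tokenized train/val datasets; batches are padded in collate_fn."""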
def __init__(self, train_dataset, val_dataset, tokenizer, batch_size: int = 128):
super().__init__()
self.train_dataset = train_dataset
self.val_dataset = val_dataset
self.batch_size = batch_size
self.tokenizer = tokenizer
print(len(train_dataset))
print(len(val_dataset))
def train_dataloader(self):
# batch_sampler = LengthAwareDistributedSampler(self.train_dataset, 'mutant_tokens', self.batch_size)
return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, collate_fn=collate_fn,
num_workers=8, pin_memory=True)
def val_dataloader(self):
# batch_sampler = LengthAwareDistributedSampler(self.val_dataset, 'mutant_tokens', self.batch_size)
return DataLoader(self.val_dataset, batch_size=self.batch_size, collate_fn=collate_fn, num_workers=8,
pin_memory=True)
def setup(self, stage=None):
if stage == 'test' or stage is None:
pass
class CosineAnnealingWithWarmup(_LRScheduler):
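    """Linear warmup from base_lr to max_lr, followed by cosine decay from max_lr to min_lr."""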
def __init__(self, optimizer, warmup_steps, total_steps, base_lr, max_lr, min_lr, last_epoch=-1):
self.warmup_steps = warmup_steps
self.total_steps = total_steps
self.base_lr = base_lr
self.max_lr = max_lr
self.min_lr = min_lr
super(CosineAnnealingWithWarmup, self).__init__(optimizer, last_epoch)
print(f"SELF BASE LRS = {self.base_lrs}")
    def get_lr(self):
        if self.last_epoch < self.warmup_steps:
            # Linear warmup phase from base_lr to max_lr
            warmup_lr = self.base_lr + (self.max_lr - self.base_lr) * (self.last_epoch / self.warmup_steps)
            return [warmup_lr for _ in self.base_lrs]

        # Cosine annealing phase from max_lr down to min_lr
        progress = (self.last_epoch - self.warmup_steps) / (self.total_steps - self.warmup_steps)
        cosine_decay = 0.5 * (1 + np.cos(np.pi * progress))
        decayed_lr = self.min_lr + (self.max_lr - self.min_lr) * cosine_decay
        return [decayed_lr for _ in self.base_lrs]
class muPPIt(pl.LightningModule):
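    """Mutation-aware binder scoring model.

    Sequences are embedded with a frozen ESM-2 650M encoder plus VHSE8 features,
    the binder is concatenated with the wildtype (or mutant) target along the
    sequence dimension, passed through a residual self-attention block, and
    projected to one scalar per position. Training pushes the distance between
    the wildtype and mutant projections into a band determined by the affinity
    label (see training_step).
    """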
def __init__(self, d_node, num_heads, dropout, margin, lr):
super(muPPIt, self).__init__()
self.esm, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D()
for param in self.esm.parameters():
param.requires_grad = False
self.attention = nn.MultiheadAttention(embed_dim=d_node, num_heads=num_heads)
self.layer_norm = nn.LayerNorm(d_node)
self.map = nn.Sequential(
nn.Linear(d_node, d_node // 2),
nn.SiLU(),
nn.Dropout(dropout),
nn.Linear(d_node // 2, 1)
)
self.margin = margin
self.learning_rate = lr
# self.best_val_acc = 0.0 # To track best validation accuracy
# self.epochs_without_improvement = 0 # To track epochs without improvement
for layer in self.map:
if isinstance(layer, nn.Linear):
nn.init.kaiming_uniform_(layer.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
if layer.bias is not None:
nn.init.zeros_(layer.bias)
    def forward(self, binder_tokens, wt_tokens, mut_tokens):
        device = binder_tokens.device
        global vhse8_tensor
        vhse8_tensor = vhse8_tensor.to(device)

        with torch.no_grad():
            # Frozen ESM-2 embeddings (layer 33), zeroed at padding positions,
            # concatenated with the 8-dim VHSE8 descriptors per token.
            binder_pad_mask = (binder_tokens != self.alphabet.padding_idx).int()
            binder_embed = self.esm(binder_tokens, repr_layers=[33], return_contacts=True)["representations"][33] * binder_pad_mask.unsqueeze(-1)
            binder_vhse8 = vhse8_tensor[binder_tokens]
            binder_embed = torch.concat([binder_embed, binder_vhse8], dim=-1)

            mut_pad_mask = (mut_tokens != self.alphabet.padding_idx).int()
            mut_embed = self.esm(mut_tokens, repr_layers=[33], return_contacts=True)["representations"][33] * mut_pad_mask.unsqueeze(-1)
            mut_vhse8 = vhse8_tensor[mut_tokens]
            mut_embed = torch.concat([mut_embed, mut_vhse8], dim=-1)

            wt_pad_mask = (wt_tokens != self.alphabet.padding_idx).int()
            wt_embed = self.esm(wt_tokens, repr_layers=[33], return_contacts=True)["representations"][33] * wt_pad_mask.unsqueeze(-1)
            wt_vhse8 = vhse8_tensor[wt_tokens]
            wt_embed = torch.concat([wt_embed, wt_vhse8], dim=-1)

        # Join binder with wildtype / mutant along the sequence dimension
        binder_wt = torch.concat([binder_embed, wt_embed], dim=1)      # B * (L1+L2) * d
        binder_mut = torch.concat([binder_embed, mut_embed], dim=1)    # B * (L1+L2) * d

        # nn.MultiheadAttention defaults to (seq, batch, dim) inputs
        binder_wt = binder_wt.transpose(0, 1)
        binder_mut = binder_mut.transpose(0, 1)

        # Self-attention with a residual connection, then layer norm
        binder_wt_attn, _ = self.attention(binder_wt, binder_wt, binder_wt)
        binder_mut_attn, _ = self.attention(binder_mut, binder_mut, binder_mut)
        binder_wt_attn = binder_wt + binder_wt_attn
        binder_mut_attn = binder_mut + binder_mut_attn

        binder_wt_attn = binder_wt_attn.transpose(0, 1)
        binder_mut_attn = binder_mut_attn.transpose(0, 1)

        binder_wt_attn = self.layer_norm(binder_wt_attn)
        binder_mut_attn = self.layer_norm(binder_mut_attn)

        # Project each position to a scalar: B*(L1+L2)
        mapped_binder_wt = self.map(binder_wt_attn).squeeze(-1)
        mapped_binder_mut = self.map(binder_mut_attn).squeeze(-1)
        # mean_binder_wt = torch.mean(mapped_binder_wt, dim=1)
        # mean_binder_mut = torch.mean(mapped_binder_mut, dim=1)
        # pdb.set_trace()

        # Euclidean distance between the wildtype and mutant projections
        distance = torch.sqrt(torch.sum((mapped_binder_wt - mapped_binder_mut) ** 2, dim=-1))
        return distance
def training_step(self, batch, batch_idx):
opt = self.optimizers()
lr = opt.param_groups[0]['lr']
self.log('learning_rate', lr, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
binder_tokens = batch['binder_input_ids'].to(self.device)
mut_tokens = batch['mutant_input_ids'].to(self.device)
wt_tokens = batch['wildtype_input_ids'].to(self.device)
aff = batch['aff'].to(self.device)
distance = self.forward(binder_tokens, wt_tokens, mut_tokens)
# pdb.set_trace()
# loss = torch.clamp(self.margin * aff - distance, min=0)
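        # Two-sided hinge: push the distance into the band [margin * aff, margin * (aff + 1)];
        # lower_loss penalizes distances below the band, upper_loss distances above it.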
upper_loss = F.relu(distance - self.margin *(aff + 1)) # let distance < aff + 1
lower_loss = F.relu(self.margin * aff - distance) # let distance > aff
loss = upper_loss + lower_loss
self.log('train_loss', loss.mean().item(), on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
return loss.mean()
def validation_step(self, batch, batch_idx):
binder_tokens = batch['binder_input_ids'].to(self.device)
mut_tokens = batch['mutant_input_ids'].to(self.device)
wt_tokens = batch['wildtype_input_ids'].to(self.device)
aff = batch['aff'].to(self.device)
distance = self.forward(binder_tokens, wt_tokens, mut_tokens)
# pdb.set_trace()
# loss = torch.clamp(self.margin * aff - distance, min=0)
upper_loss = F.relu(distance - self.margin * (aff + 1))
lower_loss = F.relu(self.margin * aff - distance)
loss = upper_loss + lower_loss
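        # A sample counts as correct when its distance lies inside the target band
        # [margin * aff, margin * (aff + 1)].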
accuracy = torch.sum(torch.logical_and(torch.ge(distance, self.margin * aff), torch.le(distance, self.margin *(aff + 1)))) / aff.shape[0]
self.log('val_loss', loss.mean().item(), on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
self.log('val_acc', accuracy.item(), on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
return accuracy.item()
def configure_optimizers(self):
optimizer = AdamW(self.parameters(), lr=self.learning_rate, betas=(0.9, 0.95))
base_lr = 0.1 * self.learning_rate
max_lr = self.learning_rate
min_lr = 0.1 * self.learning_rate
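        # warmup_steps / total_steps are hard-coded for this particular run
        # (they depend on dataset size, devices, batch size, and accumulation).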
schedulers = CosineAnnealingWithWarmup(optimizer, warmup_steps=779, total_steps=7786,
base_lr=base_lr, max_lr=max_lr, min_lr=min_lr)
lr_schedulers = {
"scheduler": schedulers,
"name": 'learning_rate_logs',
"interval": 'step', # The scheduler updates the learning rate at every step (not epoch)
'frequency': 1 # The scheduler updates the learning rate after every batch
}
return [optimizer], [lr_schedulers]
    def on_train_epoch_end(self):
        # Release Python and CUDA caches between epochs
        gc.collect()
        torch.cuda.empty_cache()
# def on_validation_epoch_end(self):
# avg_val_acc = self.trainer.callback_metrics.get('val_acc')
# if avg_val_acc > self.best_val_acc:
# self.best_val_acc = avg_val_acc
# self.epochs_without_improvement = 0
# else:
# self.epochs_without_improvement += 1
# if self.epochs_without_improvement >= 3:
# self.margin *= 2
# self.epochs_without_improvement = 0
# print(f"Margin increased to {self.margin} due to no improvement in validation accuracy for 3 epochs.")
# # Log margin to track its changes
# # self.log('current_margin', self.margin, prog_bar=True, logger=True)
def main(args):
print(args)
dist.init_process_group(backend='nccl')
train_dataset = load_from_disk('/home/tc415/muPPIt_embedding/dataset/train/ppiref_skempi_2') #16689, 16609, 17465
val_dataset = load_from_disk('/home/tc415/muPPIt_embedding/dataset/val/ppiref_skempi_2')
# val_dataset = None
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
data_module = CustomDataModule(train_dataset, val_dataset, tokenizer=tokenizer, batch_size=args.batch_size)
model = muPPIt(args.d_node, args.num_heads, args.dropout, args.margin, args.lr)
run_id = str(uuid.uuid4())
    logger = WandbLogger(project="muppit_embedding",
                         # name="debug",
                         name=f"affinity_lr={args.lr}_gradclip={args.grad_clip}_margin={args.margin}",
                         job_type='model-training',
                         id=run_id)
print(f"Saving to {args.output_file}")
checkpoint_callback = ModelCheckpoint(
monitor='val_acc',
# monitor='val_loss',
dirpath=args.output_file,
# filename='model-{epoch:02d}-{val_loss:.2f}',
filename='model-{epoch:02d}-{val_acc:.2f}',
# filename='muppit',
save_top_k=-1,
mode='max',
# mode='min',
# every_n_train_steps=1000,
# save_on_train_epoch_end=False
)
early_stopping_callback = EarlyStopping(
# monitor='val_acc',
monitor='val_loss',
patience=10,
verbose=True,
# mode='max',
mode='min',
)
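    # NOTE: early_stopping_callback is defined above but not passed to the Trainer below.
    # Gradients are accumulated over 8 batches starting from epoch 0.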
accumulator = GradientAccumulationScheduler(scheduling={0: 8})
trainer = pl.Trainer(
max_epochs=args.max_epochs,
accelerator='gpu',
strategy='ddp_find_unused_parameters_true',
precision='bf16',
logger=logger,
devices=[0,1],
callbacks=[checkpoint_callback, accumulator],
gradient_clip_val=args.grad_clip,
# val_check_interval=100,
)
trainer.fit(model, datamodule=data_module)
best_model_path = checkpoint_callback.best_model_path
print(best_model_path)
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("-o", dest="output_file", help="File for output of model parameters", required=True, type=str)
parser.add_argument("-lr", type=float, default=1e-3)
parser.add_argument("-batch_size", type=int, default=2, help="Batch size")
parser.add_argument("-grad_clip", type=float, default=0.5)
parser.add_argument("-margin", type=float, default=0.5)
parser.add_argument("-max_epochs", type=int, default=30)
parser.add_argument("-d_node", type=int, default=1024, help="Node Representation Dimension")
parser.add_argument("-num_heads", type=int, default=4)
parser.add_argument("-dropout", type=float, default=0.1)
args = parser.parse_args()
main(args)
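
# Example launch (illustrative values; paths and hyperparameters are assumptions, and the
# manual dist.init_process_group call expects a torchrun-style environment):
#   torchrun --nproc_per_node=2 train.py -o checkpoints/muppit_affinity -lr 1e-4 -batch_size 16 -grad_clip 0.5 -margin 0.5 -max_epochs 30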