import gc
import math
import os
import pdb
import uuid
from argparse import ArgumentParser

import esm
import numpy as np
import pytorch_lightning as pl
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as pyg_nn
from datasets import load_from_disk
from pytorch_lightning.callbacks import (EarlyStopping, ModelCheckpoint, Timer, TQDMProgressBar,
                                         LearningRateMonitor, StochasticWeightAveraging,
                                         GradientAccumulationScheduler)
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies import DDPStrategy
from sklearn.metrics import roc_auc_score, f1_score, matthews_corrcoef
from torch.nn.utils.rnn import pad_sequence
from torch.optim import Adam, AdamW
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader, DistributedSampler, BatchSampler, Sampler
from transformers import AutoTokenizer
from transformers.optimization import get_cosine_schedule_with_warmup

os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

vhse8_values = {
    'A': [0.15, -1.11, -1.35, -0.92, 0.02, -0.91, 0.36, -0.48],
    'R': [-1.47, 1.45, 1.24, 1.27, 1.55, 1.47, 1.30, 0.83],
    'N': [-0.99, 0.00, 0.69, -0.37, -0.55, 0.85, 0.73, -0.80],
    'D': [-1.15, 0.67, -0.41, -0.01, -2.68, 1.31, 0.03, 0.56],
    'C': [0.18, -1.67, -0.21, 0.00, 1.20, -1.61, -0.19, -0.41],
    'Q': [-0.96, 0.12, 0.18, 0.16, 0.09, 0.42, -0.20, -0.41],
    'E': [-1.18, 0.40, 0.10, 0.36, -2.16, -0.17, 0.91, 0.36],
    'G': [-0.20, -1.53, -2.63, 2.28, -0.53, -1.18, -1.34, 1.10],
    'H': [-0.43, -0.25, 0.37, 0.19, 0.51, 1.28, 0.93, 0.65],
    'I': [1.27, 0.14, 0.30, -1.80, 0.30, -1.61, -0.16, -0.13],
    'L': [1.36, 0.07, 0.26, -0.80, 0.22, -1.37, 0.08, -0.62],
    'K': [-1.17, 0.70, 0.80, 1.64, 0.67, 1.63, 0.13, -0.01],
    'M': [1.01, -0.53, 0.43, 0.00, 0.23, 0.10, -0.86, -0.68],
    'F': [1.52, 0.61, 0.95, -0.16, 0.25, 0.28, -1.33, -0.65],
    'P': [0.22, -0.17, -0.50, -0.05, 0.01, -1.34, 0.19, 3.56],
    'S': [-0.67, -0.86, -1.07, -0.41, -0.32, 0.27, -0.64, 0.11],
    'T': [-0.34, -0.51, -0.55, -1.06, 0.01, -0.01, -0.79, 0.39],
    'W': [1.50, 2.06, 1.79, 0.75, 0.75, 0.13, -1.06, -0.85],
    'Y': [0.61, 1.60, 1.17, 0.73, 0.53, 0.25, -0.96, -0.52],
    'V': [0.76, -0.92, 0.17, -1.91, 0.22, -1.40, -0.24, -0.03],
}

aa_to_idx = {'A': 5, 'R': 10, 'N': 17, 'D': 13, 'C': 23, 'Q': 16, 'E': 9, 'G': 6, 'H': 21, 'I': 12,
             'L': 4, 'K': 15, 'M': 20, 'F': 18, 'P': 14, 'S': 8, 'T': 11, 'W': 22, 'Y': 19, 'V': 7}

vhse8_tensor = torch.zeros(24, 8)
for aa, values in vhse8_values.items():
    aa_index = aa_to_idx[aa]
    vhse8_tensor[aa_index] = torch.tensor(values)
vhse8_tensor.requires_grad = False
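# Note: aa_to_idx uses the ESM-2 alphabet token ids for the 20 standard amino acids,
# so vhse8_tensor can be indexed directly with ESM token tensors; rows for special
# tokens (cls/pad/eos/unk, ids 0-3) stay zero.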


def collate_fn(batch):
    binders = []
    mutants = []
    wildtypes = []
    # The tokenizer is only needed here for its pad token id; note that it is
    # reloaded on every call, which works but is wasteful inside a DataLoader.
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

    for b in batch:
        # Drop the BOS/EOS tokens the tokenizer adds around each sequence.
        binder = torch.tensor(b['binder_input_ids']['input_ids'][1:-1])
        mutant = torch.tensor(b['mutant_input_ids']['input_ids'][1:-1])
        wildtype = torch.tensor(b['wildtype_input_ids']['input_ids'][1:-1])

        # Skip malformed or empty examples.
        if (binder.dim() == 0 or binder.numel() == 0 or mutant.dim() == 0
                or mutant.numel() == 0 or wildtype.dim() == 0 or wildtype.numel() == 0):
            continue

        binders.append(binder)
        mutants.append(mutant)
        wildtypes.append(wildtype)

    try:
        binder_input_ids = pad_sequence(binders, batch_first=True, padding_value=tokenizer.pad_token_id)
        mutant_input_ids = pad_sequence(mutants, batch_first=True, padding_value=tokenizer.pad_token_id)
        wildtype_input_ids = pad_sequence(wildtypes, batch_first=True, padding_value=tokenizer.pad_token_id)
    except Exception:
        # Debug hook: padding fails if every example in the batch was skipped.
        pdb.set_trace()

    return {
        'binder_input_ids': binder_input_ids.int(),
        'mutant_input_ids': mutant_input_ids.int(),
        'wildtype_input_ids': wildtype_input_ids.int(),
    }


class CustomDataModule(pl.LightningDataModule):
    def __init__(self, train_dataset, val_dataset, tokenizer, batch_size: int = 128):
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        self.tokenizer = tokenizer
        print(len(train_dataset))
        print(len(val_dataset))

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True,
                          collate_fn=collate_fn, num_workers=8, pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, collate_fn=collate_fn,
                          num_workers=8, pin_memory=True)

    def setup(self, stage=None):
        if stage == 'test' or stage is None:
            pass
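

# Illustrative only (not used by this inference script): a minimal sketch of how
# CustomDataModule and collate_fn could be wired into training, assuming that
# `train_ds` and `val_ds` are HuggingFace datasets carrying the *_input_ids
# columns expected by collate_fn.
def _example_training_setup(train_ds, val_ds, batch_size=32):
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    data_module = CustomDataModule(train_ds, val_ds, tokenizer, batch_size=batch_size)
    trainer = pl.Trainer(max_epochs=10, accelerator="auto", devices=1)
    return data_module, trainer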


class CosineAnnealingWithWarmup(_LRScheduler):
    """Linear warmup from base_lr to max_lr, then cosine decay down to min_lr."""

    def __init__(self, optimizer, warmup_steps, total_steps, base_lr, max_lr, min_lr, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.base_lr = base_lr
        self.max_lr = max_lr
        self.min_lr = min_lr
        super(CosineAnnealingWithWarmup, self).__init__(optimizer, last_epoch)
        print(f"SELF BASE LRS = {self.base_lrs}")

    def get_lr(self):
        if self.last_epoch < self.warmup_steps:
            # Linear warmup; the same LR is applied to every parameter group.
            warmup_lr = self.base_lr + (self.max_lr - self.base_lr) * (self.last_epoch / self.warmup_steps)
            return [warmup_lr for _ in self.base_lrs]

        # Cosine decay from max_lr to min_lr over the remaining steps.
        progress = (self.last_epoch - self.warmup_steps) / (self.total_steps - self.warmup_steps)
        cosine_decay = 0.5 * (1 + np.cos(np.pi * progress))
        decayed_lr = self.min_lr + (self.max_lr - self.min_lr) * cosine_decay
        return [decayed_lr for _ in self.base_lrs]
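

# Illustrative only: how this scheduler is typically attached inside a
# LightningModule's configure_optimizers. The step counts and LR bounds below are
# placeholder assumptions, not values from the original training run.
def _example_configure_optimizers(model, lr=1e-3, total_steps=10000, warmup_steps=500):
    optimizer = AdamW(model.parameters(), lr=lr)
    scheduler = CosineAnnealingWithWarmup(optimizer, warmup_steps=warmup_steps, total_steps=total_steps,
                                          base_lr=lr * 0.1, max_lr=lr, min_lr=lr * 0.01)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
    }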


class muPPIt(pl.LightningModule):
    def __init__(self, d_node, num_heads, dropout, margin, lr):
        super(muPPIt, self).__init__()

        # Frozen ESM-2 650M encoder used as a feature extractor.
        self.esm, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        for param in self.esm.parameters():
            param.requires_grad = False

        self.attention = nn.MultiheadAttention(embed_dim=d_node, num_heads=num_heads)
        self.layer_norm = nn.LayerNorm(d_node)

        # Maps each position down to a scalar.
        self.map = nn.Sequential(
            nn.Linear(d_node, d_node // 2),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(d_node // 2, 1)
        )

        self.margin = margin
        self.learning_rate = lr

        for layer in self.map:
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_uniform_(layer.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)

    def forward(self, binder_tokens, wt_tokens, mut_tokens):
        device = binder_tokens.device
        global vhse8_tensor
        vhse8_tensor = vhse8_tensor.to(device)

        with torch.no_grad():
            # Per-residue ESM-2 embeddings (layer 33), zeroed at padding positions,
            # concatenated with the 8-dim VHSE descriptors of each token.
            binder_pad_mask = (binder_tokens != self.alphabet.padding_idx).int()
            binder_embed = self.esm(binder_tokens, repr_layers=[33], return_contacts=True)["representations"][33] * binder_pad_mask.unsqueeze(-1)
            binder_vhse8 = vhse8_tensor[binder_tokens]
            binder_embed = torch.concat([binder_embed, binder_vhse8], dim=-1)

            mut_pad_mask = (mut_tokens != self.alphabet.padding_idx).int()
            mut_embed = self.esm(mut_tokens, repr_layers=[33], return_contacts=True)["representations"][33] * mut_pad_mask.unsqueeze(-1)
            mut_vhse8 = vhse8_tensor[mut_tokens]
            mut_embed = torch.concat([mut_embed, mut_vhse8], dim=-1)

            wt_pad_mask = (wt_tokens != self.alphabet.padding_idx).int()
            wt_embed = self.esm(wt_tokens, repr_layers=[33], return_contacts=True)["representations"][33] * wt_pad_mask.unsqueeze(-1)
            wt_vhse8 = vhse8_tensor[wt_tokens]
            wt_embed = torch.concat([wt_embed, wt_vhse8], dim=-1)

        # Concatenate binder and target along the sequence dimension, then apply
        # self-attention with a residual connection and layer norm.
        binder_wt = torch.concat([binder_embed, wt_embed], dim=1)
        binder_mut = torch.concat([binder_embed, mut_embed], dim=1)

        # nn.MultiheadAttention expects (seq_len, batch, embed_dim) by default.
        binder_wt = binder_wt.transpose(0, 1)
        binder_mut = binder_mut.transpose(0, 1)

        binder_wt_attn, _ = self.attention(binder_wt, binder_wt, binder_wt)
        binder_mut_attn, _ = self.attention(binder_mut, binder_mut, binder_mut)

        binder_wt_attn = binder_wt + binder_wt_attn
        binder_mut_attn = binder_mut + binder_mut_attn

        binder_wt_attn = binder_wt_attn.transpose(0, 1)
        binder_mut_attn = binder_mut_attn.transpose(0, 1)

        binder_wt_attn = self.layer_norm(binder_wt_attn)
        binder_mut_attn = self.layer_norm(binder_mut_attn)

        # Map each position to a scalar and take the L2 distance between the
        # binder-wildtype and binder-mutant profiles.
        mapped_binder_wt = self.map(binder_wt_attn).squeeze(-1)
        mapped_binder_mut = self.map(binder_mut_attn).squeeze(-1)

        distance = torch.sqrt(torch.sum((mapped_binder_wt - mapped_binder_mut) ** 2, dim=-1))
        return distance

    def load_weights(self, checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
        state_dict = checkpoint['state_dict']
        self.load_state_dict(state_dict, strict=True)

        for name, param in self.named_parameters():
            param.requires_grad = False


def main(args):
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

    model = muPPIt(args.d_node, args.num_heads, args.dropout, args.margin, args.lr)
    model.load_weights(args.sm)

    device = model.device
    model.eval()

    # Tokenize the raw sequences and strip the BOS/EOS tokens, matching the
    # preprocessing used in collate_fn.
    binder_tokens = torch.tensor(tokenizer(args.binder)['input_ids'][1:-1]).unsqueeze(0).to(device)
    mut_tokens = torch.tensor(tokenizer(args.mutant)['input_ids'][1:-1]).unsqueeze(0).to(device)
    wt_tokens = torch.tensor(tokenizer(args.wildtype)['input_ids'][1:-1]).unsqueeze(0).to(device)

    with torch.no_grad():
        distance = model(binder_tokens, wt_tokens, mut_tokens)

    print(distance.item())


if __name__ == "__main__":
    parser = ArgumentParser()

    parser.add_argument("-sm", required=True, type=str, help="Path to the trained model checkpoint")
    parser.add_argument("-binder", required=True, type=str)
    parser.add_argument("-mutant", required=True, type=str)
    parser.add_argument("-wildtype", required=True, type=str)
    parser.add_argument("-lr", type=float, default=1e-3)
    parser.add_argument("-batch_size", type=int, default=2, help="Batch size")
    parser.add_argument("-grad_clip", type=float, default=0.5)
    parser.add_argument("-margin", type=float, default=0.5)
    parser.add_argument("-max_epochs", type=int, default=30)
    # 1280-dim ESM-2 (t33, 650M) embedding + 8-dim VHSE descriptor = 1288.
    parser.add_argument("-d_node", type=int, default=1288, help="Node Representation Dimension")
    parser.add_argument("-num_heads", type=int, default=4)
    parser.add_argument("-dropout", type=float, default=0.1)

    args = parser.parse_args()

    main(args)
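
# Example invocation (hypothetical script name, checkpoint path, and sequences):
#   python muppit_inference.py -sm checkpoints/muppit.ckpt \
#       -binder FLKDHRISTFK -wildtype MKTAYIAKQRQISFVKSHFSRQ -mutant MKTAYIAKQRKISFVKSHFSRQ \
#       -d_node 1288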