# muPPIt/train_2.py
import torch
from torch.utils.data import DataLoader, Subset
from torch.optim import AdamW
import torch.nn.functional as F
import torch.nn as nn
from datasets import load_from_disk
import esm
import numpy as np
import os
from transformers import AutoTokenizer
from torch.optim.lr_scheduler import CosineAnnealingLR
from transformers import get_linear_schedule_with_warmup
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler
import gc
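# NOTE: must run before the first CUDA context is created, or the mask is ignored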
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
##################### Hyper-parameters #############################################
max_epochs = 30
batch_size = 4
lr = 1e-4
# dropout = 0.1
margin = 10
accumulation_steps = 4 # 16
num_heads = 4
checkpoint_path = '/home/tc415/muPPIt_embedding/checkpoints/improved_train_7'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f'''
max_epochs = {max_epochs}
batch_size = {batch_size}
lr = {lr}
margin = {margin}
accumulation_steps = {accumulation_steps}
num_heads = {num_heads}
checkpoint_path = {checkpoint_path}
''')
####################################################################################
os.makedirs(checkpoint_path, exist_ok=True)
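# VHSE8: eight principal-component descriptors of amino acid physicochemical
# properties, concatenated below onto the ESM-2 residue embeddings.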
vhse8_values = {
'A': [0.15, -1.11, -1.35, -0.92, 0.02, -0.91, 0.36, -0.48],
'R': [-1.47, 1.45, 1.24, 1.27, 1.55, 1.47, 1.30, 0.83],
'N': [-0.99, 0.00, 0.69, -0.37, -0.55, 0.85, 0.73, -0.80],
'D': [-1.15, 0.67, -0.41, -0.01, -2.68, 1.31, 0.03, 0.56],
'C': [0.18, -1.67, -0.21, 0.00, 1.20, -1.61, -0.19, -0.41],
'Q': [-0.96, 0.12, 0.18, 0.16, 0.09, 0.42, -0.20, -0.41],
'E': [-1.18, 0.40, 0.10, 0.36, -2.16, -0.17, 0.91, 0.36],
'G': [-0.20, -1.53, -2.63, 2.28, -0.53, -1.18, -1.34, 1.10],
'H': [-0.43, -0.25, 0.37, 0.19, 0.51, 1.28, 0.93, 0.65],
'I': [1.27, 0.14, 0.30, -1.80, 0.30, -1.61, -0.16, -0.13],
'L': [1.36, 0.07, 0.26, -0.80, 0.22, -1.37, 0.08, -0.62],
'K': [-1.17, 0.70, 0.80, 1.64, 0.67, 1.63, 0.13, -0.01],
'M': [1.01, -0.53, 0.43, 0.00, 0.23, 0.10, -0.86, -0.68],
'F': [1.52, 0.61, 0.95, -0.16, 0.25, 0.28, -1.33, -0.65],
'P': [0.22, -0.17, -0.50, -0.05, 0.01, -1.34, 0.19, 3.56],
'S': [-0.67, -0.86, -1.07, -0.41, -0.32, 0.27, -0.64, 0.11],
'T': [-0.34, -0.51, -0.55, -1.06, 0.01, -0.01, -0.79, 0.39],
'W': [1.50, 2.06, 1.79, 0.75, 0.75, 0.13, -1.06, -0.85],
'Y': [0.61, 1.60, 1.17, 0.73, 0.53, 0.25, -0.96, -0.52],
'V': [0.76, -0.92, 0.17, -1.91, 0.22, -1.40, -0.24, -0.03],
}
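# Token indices of the 20 standard residues in the ESM-2 alphabet
# (special tokens cls/pad/eos/unk occupy indices 0-3).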
aa_to_idx = {'A': 5, 'R': 10, 'N': 17, 'D': 13, 'C': 23, 'Q': 16, 'E': 9, 'G': 6, 'H': 21, 'I': 12, 'L': 4, 'K': 15, 'M': 20, 'F': 18, 'P': 14, 'S': 8, 'T': 11, 'W': 22, 'Y': 19, 'V': 7}
# One VHSE8 row per token in the 33-token ESM-2 alphabet; special tokens and
# non-standard residues (X, B, ...) keep zero vectors instead of indexing out of bounds.
vhse8_tensor = torch.zeros(33, 8)
for aa, values in vhse8_values.items():
    aa_index = aa_to_idx[aa]
    vhse8_tensor[aa_index] = torch.tensor(values)
vhse8_tensor = vhse8_tensor.to(device)
vhse8_tensor.requires_grad = False
train_dataset = load_from_disk('/home/tc415/muPPIt_embedding/dataset/train/ppiref_skempi_2') #16689, 16609, 17465
val_dataset = load_from_disk('/home/tc415/muPPIt_embedding/dataset/val/ppiref_skempi_2')
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
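# The HF tokenizer for facebook/esm2_t33_650M_UR50D shares the fairseq ESM-2
# vocabulary, so tokenizer.pad_token_id matches alphabet.padding_idx below.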
def collate_fn(batch):
    # Unpack the batch, dropping the BOS/EOS tokens the HF tokenizer added
    binders = []
    mutants = []
    wildtypes = []
    affs = []
    for b in batch:
        binder = torch.tensor(b['binder_input_ids']['input_ids'][1:-1])
        mutant = torch.tensor(b['mutant_input_ids']['input_ids'][1:-1])
        wildtype = torch.tensor(b['wildtype_input_ids']['input_ids'][1:-1])
        # Skip malformed examples with an empty or scalar sequence
        if binder.dim() == 0 or binder.numel() == 0 or mutant.dim() == 0 or mutant.numel() == 0 or wildtype.dim() == 0 or wildtype.numel() == 0:
            continue
        binders.append(binder)      # shape: L1
        mutants.append(mutant)      # shape: L2
        wildtypes.append(wildtype)  # shape: L3
        affs.append(b['aff'])
    # Pad each group to the longest sequence in the batch
    # (assumes at least one valid example per batch; pad_sequence raises on an empty list)
    binder_input_ids = torch.nn.utils.rnn.pad_sequence(binders, batch_first=True, padding_value=tokenizer.pad_token_id)
    mutant_input_ids = torch.nn.utils.rnn.pad_sequence(mutants, batch_first=True, padding_value=tokenizer.pad_token_id)
    wildtype_input_ids = torch.nn.utils.rnn.pad_sequence(wildtypes, batch_first=True, padding_value=tokenizer.pad_token_id)
    affs = torch.tensor(affs)
    # Return the collated batch
    return {
        'binder_input_ids': binder_input_ids.int(),
        'mutant_input_ids': mutant_input_ids.int(),
        'wildtype_input_ids': wildtype_input_ids.int(),
        'aff': affs
    }
class muPPIt(torch.nn.Module):
    def __init__(self, d_node, num_heads, margin, lr, device):
        super().__init__()
        # Frozen ESM-2 650M encoder; only the attention / map heads below are trained
        self.esm, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        for param in self.esm.parameters():
            param.requires_grad = False
        self.attention = torch.nn.MultiheadAttention(embed_dim=d_node, num_heads=num_heads)
        self.layer_norm = torch.nn.LayerNorm(d_node)
        # Per-token MLP that projects each d_node-dim embedding to a scalar
        self.map = torch.nn.Sequential(
            torch.nn.Linear(d_node, d_node // 2),
            torch.nn.SiLU(),
            torch.nn.Linear(d_node // 2, d_node // 4),
            torch.nn.SiLU(),
            torch.nn.Linear(d_node // 4, 1)
        )
for layer in self.map:
if isinstance(layer, nn.Linear):
nn.init.kaiming_uniform_(layer.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
if layer.bias is not None:
nn.init.zeros_(layer.bias)
        self.margin = margin
        self.learning_rate = lr
        self.loss_threshold = 10  # losses above this mark hard examples for up-weighting
        self.device = device
        # Curriculum indices: "easy" PPIRef examples and "hard" SKEMPI examples
        self.easy_example_indices = np.load('/home/tc415/muPPIt_embedding/dataset/ppiref_index.npy').tolist()
        self.hard_example_indices = np.load('/home/tc415/muPPIt_embedding/dataset/skempi_index.npy').tolist()
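    # Guard (an addition, not in the original script): nn.Module.train() would
    # re-enable dropout inside the frozen ESM encoder, making the no_grad
    # embeddings stochastic; keep ESM in eval mode regardless of training state.
    def train(self, mode=True):
        super().train(mode)
        self.esm.eval()
        return self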
    def forward(self, binder_tokens, wt_tokens, mut_tokens):
        with torch.no_grad():
            # Final-layer ESM-2 embeddings, zeroed at padding positions, then
            # concatenated with the 8-dim VHSE8 descriptors -> d_node = 1280 + 8.
            # return_contacts=False skips the contact head; representations are unchanged.
            binder_pad_mask = (binder_tokens != self.alphabet.padding_idx).int()
            binder_embed = self.esm(binder_tokens, repr_layers=[33], return_contacts=False)["representations"][33] * binder_pad_mask.unsqueeze(-1)
            binder_vhse8 = vhse8_tensor[binder_tokens]
            binder_embed = torch.concat([binder_embed, binder_vhse8], dim=-1)

            mut_pad_mask = (mut_tokens != self.alphabet.padding_idx).int()
            mut_embed = self.esm(mut_tokens, repr_layers=[33], return_contacts=False)["representations"][33] * mut_pad_mask.unsqueeze(-1)
            mut_vhse8 = vhse8_tensor[mut_tokens]
            mut_embed = torch.concat([mut_embed, mut_vhse8], dim=-1)

            wt_pad_mask = (wt_tokens != self.alphabet.padding_idx).int()
            wt_embed = self.esm(wt_tokens, repr_layers=[33], return_contacts=False)["representations"][33] * wt_pad_mask.unsqueeze(-1)
            wt_vhse8 = vhse8_tensor[wt_tokens]
            wt_embed = torch.concat([wt_embed, wt_vhse8], dim=-1)

        # Concatenate the binder with each target along the sequence dimension
        binder_wt = torch.concat([binder_embed, wt_embed], dim=1)    # B * (L_binder + L_wt) * d_node
        binder_mut = torch.concat([binder_embed, mut_embed], dim=1)  # B * (L_binder + L_mut) * d_node

        # nn.MultiheadAttention defaults to (L, B, D) inputs
        binder_wt = binder_wt.transpose(0, 1)
        binder_mut = binder_mut.transpose(0, 1)

        # Self-attention with a residual connection; note that padding positions
        # still participate as keys, since no key_padding_mask is passed
        binder_wt_attn, _ = self.attention(binder_wt, binder_wt, binder_wt)
        binder_mut_attn, _ = self.attention(binder_mut, binder_mut, binder_mut)
        binder_wt_attn = binder_wt + binder_wt_attn
        binder_mut_attn = binder_mut + binder_mut_attn

        binder_wt_attn = self.layer_norm(binder_wt_attn.transpose(0, 1))
        binder_mut_attn = self.layer_norm(binder_mut_attn.transpose(0, 1))

        # Project each position to a scalar, then take the L2 distance between the
        # binder-wildtype and binder-mutant profiles (assumes wt and mut have equal length)
        mapped_binder_wt = self.map(binder_wt_attn).squeeze(-1)    # B * (L_binder + L_wt)
        mapped_binder_mut = self.map(binder_mut_attn).squeeze(-1)  # B * (L_binder + L_mut)
        distance = torch.sqrt(torch.sum((mapped_binder_wt - mapped_binder_mut) ** 2, dim=-1))
        return distance
    def compute_loss(self, binder_tokens, wt_tokens, mut_tokens, aff):
        distance = self.forward(binder_tokens, wt_tokens, mut_tokens)
        # Hinge the distance into the band [margin*aff, margin*(aff+1)]
        upper_loss = F.relu(distance - self.margin * (aff + 1))  # penalize distance > margin*(aff+1)
        lower_loss = F.relu(self.margin * aff - distance)        # penalize distance < margin*aff
        loss = upper_loss + lower_loss
        # Double the weight of hard examples (large loss)
        loss_weights = torch.ones_like(loss)
        hard_example_mask = loss > self.loss_threshold
        loss_weights[hard_example_mask] = 2.0
        weighted_loss = loss * loss_weights
        return weighted_loss.mean(), distance
    def step(self, batch, compute_acc=False):
        binder_tokens = batch['binder_input_ids'].to(self.device)
        mut_tokens = batch['mutant_input_ids'].to(self.device)
        wt_tokens = batch['wildtype_input_ids'].to(self.device)
        aff = batch['aff'].to(self.device)
        loss, distance = self.compute_loss(binder_tokens, wt_tokens, mut_tokens, aff)
        if compute_acc:
            # A prediction counts as correct if the distance falls inside the target band
            accuracy = torch.sum(torch.logical_and(torch.ge(distance, self.margin * aff), torch.le(distance, self.margin * (aff + 1))))
            return loss, accuracy
        else:
            return loss
def train(model, optimizer, scheduler, cosine_scheduler, train_dataset, val_dataset, batch_size, max_epochs=10, accumulation_steps=4):
    val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=False, num_workers=4)
    scaler = GradScaler()  # create once so the loss-scale state persists across epochs
    max_val_acc = 0

    for epoch in range(max_epochs):
        print(f"Epoch {epoch + 1}/{max_epochs}")

        # Curriculum: the first 3 epochs use only easy (PPIRef) examples, then a
        # linearly growing fraction of hard (SKEMPI) examples is mixed in
        if epoch < 3:
            train_subset = Subset(train_dataset, model.easy_example_indices)
        else:
            num_hard_examples = int((epoch / max_epochs) * len(model.hard_example_indices))
            selected_hard_indices = model.hard_example_indices[:num_hard_examples]
            combined_indices = model.easy_example_indices + selected_hard_indices
            train_subset = Subset(train_dataset, combined_indices)
        train_loader = DataLoader(train_subset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True, num_workers=4)

        model.train()
        running_loss = 0.0
        optimizer.zero_grad()

        for batch_idx, batch in tqdm(enumerate(train_loader), total=len(train_loader)):
            batch = {k: v.cuda(non_blocking=True) for k, v in batch.items()}  # Transfer batch to GPU
            with autocast():
                loss = model.step(batch)

            # Scale the loss down so the accumulated gradient matches a full batch
            scaler.scale(loss / accumulation_steps).backward()

            if (batch_idx + 1) % accumulation_steps == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                # Linear warmup first, then hand over to cosine annealing
                if scheduler.last_epoch < warmup_steps:
                    scheduler.step()
                else:
                    cosine_scheduler.step()

            running_loss += loss.item()

        # Flush gradients left over when the epoch length is not a multiple of accumulation_steps
        if len(train_loader) % accumulation_steps != 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        print(f"Epoch {epoch}: Training Loss = {running_loss / len(train_loader)}")
        del train_loader
        gc.collect()
        torch.cuda.empty_cache()

        model.eval()
        val_loss = 0.0
        val_acc = 0.0
        with torch.no_grad():
            for batch in tqdm(val_loader, total=len(val_loader)):
                batch = {k: v.cuda(non_blocking=True) for k, v in batch.items()}
                val_loss_batch, val_acc_batch = model.step(batch, compute_acc=True)
                val_loss += val_loss_batch.item()
                val_acc += val_acc_batch.item()

        print(f"Epoch {epoch}: Val Loss = {val_loss / len(val_loader)}\tVal Acc = {val_acc / len(val_dataset)}")

        if val_acc > max_val_acc:
            max_val_acc = val_acc
            torch.save(model.state_dict(), os.path.join(checkpoint_path, f"epoch={epoch}_acc={round(val_acc / len(val_dataset), 2)}.pt"))
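# Note (an observation, not from the original script): model.state_dict() above
# also stores the frozen ESM-2 650M weights, so each checkpoint is roughly 2.5 GB.
# Saving only the trainable head would be far smaller, e.g.:
#   trainable = {k: v for k, v in model.state_dict().items() if not k.startswith('esm.')}
#   torch.save(trainable, path)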
# d_node = 1288: 1280-dim ESM-2 (650M) embeddings + 8-dim VHSE8 descriptors
model = muPPIt(d_node=1288, num_heads=num_heads, margin=margin, lr=lr, device=device).to(device)
optimizer = AdamW(model.parameters(), lr=model.learning_rate, betas=(0.9, 0.95), weight_decay=1e-5)
# Optimizer steps per epoch times epochs; an upper bound, since the curriculum
# trains on subsets of train_dataset during early epochs
total_steps = len(train_dataset) // (batch_size * accumulation_steps) * max_epochs
warmup_steps = int(0.1 * total_steps)
scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=warmup_steps,
num_training_steps=total_steps
)
# After warmup, decay from the peak LR down to 0.1*lr over the remaining steps
cosine_scheduler = CosineAnnealingLR(optimizer, T_max=total_steps - warmup_steps, eta_min=0.1 * lr)
train(model, optimizer, scheduler, cosine_scheduler, train_dataset, val_dataset, batch_size=batch_size, max_epochs=max_epochs, accumulation_steps=accumulation_steps)