import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModel, AutoTokenizer
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from sklearn.model_selection import ParameterGrid
from tqdm import tqdm
import pandas as pd
import numpy as np
import sys
import os
import logging

# Silence verbose transformers warnings (e.g., weight-initialization notices).
logging.getLogger("transformers").setLevel(logging.ERROR)

path = "/workspace/sg666/MDpLM"

# Default hyperparameters; param_grid in __main__ overrides the
# training-related entries during the grid search.
hyperparams = {
    "train_data": path + "/data/membrane/train.csv",
    "val_data": path + "/data/membrane/val.csv",
    "test_data": path + "/data/membrane/test.csv",
    "esm_model_path": "facebook/esm2_t33_650M_UR50D",
    "mlm_model_path": path + "/benchmarks/MLM/model_ckpts/best_model_epoch",
    "mdlm_model_path": path + "/checkpoints/membrane_automodel/epochs30_lr3e-4_bsz16_gradclip1_beta-one0.9_beta-two0.999_bf16_all-params",
    "batch_size": 1,
    "learning_rate": 5e-5,
    "num_epochs": 2,
    "num_layers": 4,
    "num_heads": 16,
    "dropout": 0.5
}


def load_models(esm_model_path, mlm_model_path, mdlm_model_path, device):
    # Load the shared ESM tokenizer and move all three embedding models to `device`.
    esm_tokenizer = AutoTokenizer.from_pretrained(esm_model_path)
    esm_model = AutoModel.from_pretrained(esm_model_path).to(device)
    mlm_model = AutoModel.from_pretrained(mlm_model_path).to(device)
    mdlm_model = AutoModel.from_pretrained(mdlm_model_path).to(device)
    return esm_tokenizer, esm_model, mlm_model, mdlm_model
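
# NOTE: get_latents below calls load_models once per sequence, so every
# __getitem__ reloads all three checkpoints from disk. A minimal caching
# sketch (an assumption about intended usage, not part of the original script):
#
#   from functools import lru_cache
#
#   @lru_cache(maxsize=1)
#   def load_models_cached(esm_model_path, mlm_model_path, mdlm_model_path, device):
#       return load_models(esm_model_path, mlm_model_path, mdlm_model_path, device)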


def get_latents(embedding_type, esm_model_path, mlm_model_path, mdlm_model_path, sequence, device):
    tokenizer, esm_model, mlm_model, mdlm_model = load_models(esm_model_path, mlm_model_path, mdlm_model_path, device)

    if embedding_type == "esm":
        model = esm_model
    elif embedding_type == "mlm":
        model = mlm_model
    elif embedding_type == "mdlm":
        model = mdlm_model
    else:
        raise ValueError(f"Unknown embedding_type: {embedding_type}")

    inputs = tokenizer(sequence.upper(), return_tensors="pt").to(device)['input_ids']
    with torch.no_grad():
        # Strip the BOS/EOS special tokens so each row aligns with a residue.
        embeddings = model(inputs).last_hidden_state.squeeze(0)[1:-1]

    return embeddings
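
# Illustrative usage (a sketch; the sequence is made up, not from the dataset).
# With the ESM-2 650M model, each residue maps to a 1280-dimensional vector:
#
#   emb = get_latents("esm", hyperparams["esm_model_path"],
#                     hyperparams["mlm_model_path"], hyperparams["mdlm_model_path"],
#                     "MKTAYIAKQR", torch.device("cpu"))
#   emb.shape  # -> torch.Size([10, 1280]): one row per residue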


class SolubilityDataset(Dataset):
    def __init__(self, embedding_type, csv_file, esm_model_path, mlm_model_path, mdlm_model_path, device):
        self.data = pd.read_csv(csv_file)

        self.embedding_type = embedding_type
        self.esm_model_path = esm_model_path
        self.mlm_model_path = mlm_model_path
        self.mdlm_model_path = mdlm_model_path
        self.device = device

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sequence = self.data.iloc[idx]['Sequence']
        seq_len = len(sequence)
        embeddings = get_latents(self.embedding_type, self.esm_model_path, self.mlm_model_path, self.mdlm_model_path,
                                 sequence, self.device)

        # Residue case encodes the per-residue label: lowercase -> 0, uppercase -> 1.
        label = [0 if residue.islower() else 1 for residue in sequence]
        labels = torch.tensor(label, dtype=torch.float32)

        return embeddings, labels, seq_len
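
# Label-encoding example (the sequence is made up for illustration):
#
#   seq = "mKTa"
#   [0 if r.islower() else 1 for r in seq]  # -> [0, 1, 1, 0]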


class SolubilityPredictor(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_heads, num_layers, dropout):
        super().__init__()
        # hidden_dim, num_heads, num_layers, and dropout are accepted so the grid
        # search can pass them, but this per-residue MLP head does not use them.
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, 320),
            nn.ReLU(),
            nn.Linear(320, 1)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, embeddings):
        # embeddings: (batch, seq_len, input_dim) -> per-residue probabilities (batch, seq_len)
        logits = self.classifier(embeddings)
        probs = self.sigmoid(logits.squeeze(-1))

        return probs
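
# Minimal shape check (a sketch with made-up dimensions):
#
#   head = SolubilityPredictor(input_dim=1280, hidden_dim=1280,
#                              num_heads=16, num_layers=4, dropout=0.5)
#   x = torch.randn(1, 50, 1280)  # batch of one 50-residue protein
#   head(x).shape                 # -> torch.Size([1, 50]), values in (0, 1)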


def train(model, train_loader, val_loader, optimizer, criterion, device):
    """
    Trains the model for a single epoch, then evaluates it on the validation set.
    Args:
        model (nn.Module): model that will be trained
        train_loader (DataLoader): PyTorch DataLoader with training data
        val_loader (DataLoader): PyTorch DataLoader with validation data
        optimizer (torch.optim.Optimizer): optimizer
        criterion (nn.Module): loss function
        device (torch.device): device (GPU or CPU) to train the model on
    Returns:
        (float, float): average training loss and average validation loss
    """
    model.train()
    train_loss = 0

    prog_bar = tqdm(total=len(train_loader), leave=True, file=sys.stdout)
    for batch in train_loader:
        embeddings, labels, seq_len = batch
        embeddings, labels = embeddings.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(embeddings)  # (1, seq_len), matching labels when batch_size=1
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        prog_bar.update()
        sys.stdout.flush()
    prog_bar.close()

    model.eval()
    val_loss = 0.0

    prog_bar = tqdm(total=len(val_loader), leave=True, file=sys.stdout)
    for batch in val_loader:
        embeddings, labels, seq_len = batch
        embeddings, labels = embeddings.to(device), labels.to(device)
        with torch.no_grad():
            outputs = model(embeddings)
            loss = criterion(outputs, labels)
        val_loss += loss.item()
        prog_bar.update()
        sys.stdout.flush()
    prog_bar.close()

    return train_loss / len(train_loader), val_loss / len(val_loader)


def evaluate(model, dataloader, device):
    """
    Performs inference with a trained model.
    Args:
        model (nn.Module): the trained model
        dataloader (DataLoader): PyTorch DataLoader with testing data
        device (torch.device): device (GPU or CPU) to be used for inference
    Returns:
        preds (list): predicted per-residue solubility probabilities
        true_labels (list): ground-truth per-residue solubility labels
    """
    model.eval()
    preds, true_labels = [], []
    with torch.no_grad():
        for embeddings, labels, seq_len in tqdm(dataloader):
            embeddings, labels = embeddings.to(device), labels.to(device)
            outputs = model(embeddings)
            preds.append(outputs.cpu().numpy())
            true_labels.append(labels.cpu().numpy())
    return preds, true_labels


def calculate_metrics(preds, labels, threshold=0.5):
    """
    Calculates metrics to assess model performance.
    Args:
        preds (list): model's predicted probabilities, one array per sequence
        labels (list): ground-truth labels, one array per sequence
        threshold (float): probability cutoff; predictions strictly above it count as positive
    Returns:
        accuracy (float): accuracy
        precision (float): precision
        recall (float): recall
        f1 (float): F1 score
        roc_auc (float): AUROC score
    """
    flat_binary_preds, flat_prob_preds, flat_labels = [], [], []

    # Flatten the per-sequence arrays so metrics are computed over all residues at once.
    for pred, label in zip(preds, labels):
        flat_binary_preds.extend((pred > threshold).astype(int).flatten())
        flat_prob_preds.extend(pred.flatten())
        flat_labels.extend(label.flatten())

    flat_binary_preds = np.array(flat_binary_preds)
    flat_prob_preds = np.array(flat_prob_preds)
    flat_labels = np.array(flat_labels)

    accuracy = accuracy_score(flat_labels, flat_binary_preds)
    precision = precision_score(flat_labels, flat_binary_preds)
    recall = recall_score(flat_labels, flat_binary_preds)
    f1 = f1_score(flat_labels, flat_binary_preds)
    roc_auc = roc_auc_score(flat_labels, flat_prob_preds)

    return accuracy, precision, recall, f1, roc_auc
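
# Worked toy example (all values made up for illustration):
#
#   preds  = [np.array([[0.9, 0.2, 0.7]])]
#   labels = [np.array([[1.0, 0.0, 0.0]])]
#   calculate_metrics(preds, labels)
#   # binarized preds at 0.5 -> [1, 0, 1]; accuracy = 2/3, precision = 1/2,
#   # recall = 1.0, F1 = 2/3, ROC AUC = 1.0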


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    for embedding_type in ['mlm', 'esm', 'mdlm']:
        best_val_loss = float('inf')
        best_model = None

        train_dataset = SolubilityDataset(embedding_type,
                                          hyperparams['train_data'],
                                          hyperparams['esm_model_path'],
                                          hyperparams['mlm_model_path'],
                                          hyperparams['mdlm_model_path'],
                                          device)
        test_dataset = SolubilityDataset(embedding_type,
                                         hyperparams['test_data'],
                                         hyperparams['esm_model_path'],
                                         hyperparams['mlm_model_path'],
                                         hyperparams['mdlm_model_path'],
                                         device)
        val_dataset = SolubilityDataset(embedding_type,
                                        hyperparams['val_data'],
                                        hyperparams['esm_model_path'],
                                        hyperparams['mlm_model_path'],
                                        hyperparams['mdlm_model_path'],
                                        device)

        train_dataloader = DataLoader(train_dataset, batch_size=hyperparams["batch_size"], shuffle=True)
        val_dataloader = DataLoader(val_dataset, batch_size=hyperparams["batch_size"], shuffle=False)
        test_dataloader = DataLoader(test_dataset, batch_size=hyperparams["batch_size"], shuffle=False)

        # Grid of hyperparameter combinations to sweep; each combination
        # overrides the corresponding defaults in `hyperparams`.
        param_grid = {
            'learning_rate': [5e-4],
            'batch_size': [1],
            'num_heads': [4],
            'num_layers': [2],
            'dropout': [0.5],
            'num_epochs': [5]
        }

        grid = ParameterGrid(param_grid)
        for params in grid:
            hyperparams.update(params)

            # ESM-2 650M and the MLM baseline emit 1280-dim embeddings; the MDLM checkpoint emits 640-dim.
            input_dim = 640 if embedding_type == "mdlm" else 1280
            hidden_dim = input_dim
            model = SolubilityPredictor(
                input_dim=input_dim,
                hidden_dim=hidden_dim,
                num_layers=hyperparams["num_layers"],
                num_heads=hyperparams["num_heads"],
                dropout=hyperparams['dropout']
            )
            model = model.to(device)

            optimizer = optim.Adam(model.parameters(), lr=hyperparams["learning_rate"])
            criterion = nn.BCELoss()
            num_epochs = hyperparams['num_epochs']

            for epoch in range(num_epochs):
                print(f"EPOCH {epoch+1}/{num_epochs}")
                train_loss, val_loss = train(model, train_dataloader, val_dataloader, optimizer, criterion, device)
                print(f"TRAIN LOSS: {train_loss:.4f}")
                print(f"VALIDATION LOSS: {val_loss:.4f}\n")
                sys.stdout.flush()

                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    best_model = model.state_dict()

            print("TEST METRICS:")
            test_preds, test_labels = evaluate(model, test_dataloader, device)
            test_metrics = calculate_metrics(test_preds, test_labels)
            print(f"Accuracy: {test_metrics[0]:.4f}")
            print(f"Precision: {test_metrics[1]:.4f}")
            print(f"Recall: {test_metrics[2]:.4f}")
            print(f"F1 Score: {test_metrics[3]:.4f}")
            print(f"ROC AUC: {test_metrics[4]:.4f}")
            print()
            sys.stdout.flush()

            # Save this combination's model plus its hyperparameters and test results.
            folder_name = f"{path}/benchmarks/Supervised/Solubility/transformer_models/{embedding_type}/lr{hyperparams['learning_rate']}_bs{hyperparams['batch_size']}_epochs{hyperparams['num_epochs']}_layers{hyperparams['num_layers']}_heads{hyperparams['num_heads']}_drpt{hyperparams['dropout']}"
            os.makedirs(folder_name, exist_ok=True)

            model_file_path = os.path.join(folder_name, "model.pth")
            torch.save(model.state_dict(), model_file_path)

            output_file_path = os.path.join(folder_name, "hyperparams_and_test_results.txt")
            with open(output_file_path, 'w') as out_file:
                for key, value in hyperparams.items():
                    out_file.write(f"{key}: {value}\n")

                out_file.write("\nTEST METRICS:\n")
                out_file.write(f"Accuracy: {test_metrics[0]:.4f}\n")
                out_file.write(f"Precision: {test_metrics[1]:.4f}\n")
                out_file.write(f"Recall: {test_metrics[2]:.4f}\n")
                out_file.write(f"F1 Score: {test_metrics[3]:.4f}\n")
                out_file.write(f"ROC AUC: {test_metrics[4]:.4f}\n")

        # Save the state dict with the lowest validation loss across the grid for this embedding type.
        if best_model is not None:
            best_model_dir = f"{path}/benchmarks/Supervised/Solubility/transformer_models/{embedding_type}"
            os.makedirs(best_model_dir, exist_ok=True)
            best_model_path = os.path.join(best_model_dir, "best_model.pth")
            torch.save(best_model, best_model_path)

            best_hyperparams_path = os.path.join(best_model_dir, "best_model_hyperparams.txt")
            with open(best_hyperparams_path, 'w') as out_file:
                out_file.write(f"Best Validation Loss: {best_val_loss:.4f}\n")
                for key, value in hyperparams.items():
                    out_file.write(f"{key}: {value}\n")