"""Per-residue binary classification on protein language model embeddings.

Embeddings come from one of three encoders (ESM-2, an MLM checkpoint or an
MDLM checkpoint) and feed a small MLP that outputs one probability per residue.
"""
import functools
import logging
import os
import sys
from datetime import datetime

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from sklearn.model_selection import ParameterGrid
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

logging.getLogger("transformers").setLevel(logging.ERROR)

# Hyperparameters dictionary
path = "/workspace/sg666/MDpLM"
hyperparams = {
    "train_data": path + "/data/membrane/train.csv",
    "val_data": path + "/data/membrane/val.csv",
    "test_data": path + "/data/membrane/test.csv",
    'esm_model_path': "facebook/esm2_t33_650M_UR50D",
    'mlm_model_path': path + "/benchmarks/MLM/model_ckpts/best_model_epoch",
    "mdlm_model_path": path + "/checkpoints/membrane_automodel/epochs30_lr3e-4_bsz16_gradclip1_beta-one0.9_beta-two0.999_bf16_all-params",
    "batch_size": 1,
    "learning_rate": 5e-5,
    "num_epochs": 2,
    "num_layers": 4,
    "num_heads": 16,
    "dropout": 0.5
}


# Helper functions to obtain all embeddings for a sequence
def _default_device():
    """Return CUDA when it is available, otherwise CPU."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


# The key space is tiny (a handful of model paths x one device), so an
# unbounded cache cannot grow without limit.
@functools.lru_cache(maxsize=None)
def _load_tokenizer(model_path):
    """Load the tokenizer at ``model_path`` once and reuse it afterwards."""
    return AutoTokenizer.from_pretrained(model_path)


@functools.lru_cache(maxsize=None)
def _load_encoder(model_path, device):
    """Load the model at ``model_path`` onto ``device`` once and reuse it afterwards."""
    return AutoModel.from_pretrained(model_path).to(device)


def load_models(esm_model_path, mlm_model_path, mdlm_model_path, device=None):
    """Return the ESM tokenizer and the ESM, MLM and MDLM encoders.

    Loaded objects are cached, so repeated calls with the same paths/device
    return the same instances instead of re-reading the checkpoints.

    Args:
        esm_model_path (str): path or hub id of the ESM model (also provides the tokenizer).
        mlm_model_path (str): path or hub id of the MLM checkpoint.
        mdlm_model_path (str): path or hub id of the MDLM checkpoint.
        device (torch.device, optional): where to place the encoders; defaults
            to CUDA when available, otherwise CPU.

    Returns:
        tuple: ``(esm_tokenizer, esm_model, mlm_model, mdlm_model)``.
    """
    if device is None:
        device = _default_device()
    return (
        _load_tokenizer(esm_model_path),
        _load_encoder(esm_model_path, device),
        _load_encoder(mlm_model_path, device),
        _load_encoder(mdlm_model_path, device),
    )


def get_latents(embedding_type, esm_model_path, mlm_model_path, mdlm_model_path, sequence, device):
    """Embed ``sequence`` with the encoder selected by ``embedding_type``.

    Only the requested encoder is loaded (and cached across calls). The ESM
    tokenizer is used for every encoder type.

    Args:
        embedding_type (str): one of ``"esm"``, ``"mlm"`` or ``"mdlm"``.
        esm_model_path (str): ESM model path; also the tokenizer source.
        mlm_model_path (str): MLM checkpoint path.
        mdlm_model_path (str): MDLM checkpoint path.
        sequence (str): residue string; it is upper-cased before tokenization.
        device (torch.device): device for the encoder and the input ids.

    Returns:
        torch.Tensor: the encoder's last hidden state with the batch axis
        squeezed and the first and last token positions dropped. Presumably
        ``(len(sequence), hidden_dim)`` if the tokenizer adds exactly one
        special token at each end (TODO confirm).

    Raises:
        ValueError: if ``embedding_type`` is not one of the supported values.
    """
    model_paths = {"esm": esm_model_path, "mlm": mlm_model_path, "mdlm": mdlm_model_path}
    if embedding_type not in model_paths:
        raise ValueError(
            f"Unknown embedding_type {embedding_type!r}; expected one of {sorted(model_paths)}"
        )
    tokenizer = _load_tokenizer(esm_model_path)
    model = _load_encoder(model_paths[embedding_type], device)

    input_ids = tokenizer(sequence.upper(), return_tensors="pt")["input_ids"].to(device)
    with torch.no_grad():
        hidden = model(input_ids).last_hidden_state
    # Drop the batch axis and the leading/trailing special-token positions.
    return hidden.squeeze(0)[1:-1]


# Dataset class that loads embeddings and labels
class SolubilityDataset(Dataset):
    """Per-residue labels plus on-the-fly encoder embeddings, read from a CSV.

    The CSV must contain a ``Sequence`` column. Residue case encodes the label:
    lowercase -> 0, uppercase -> 1. Embeddings are computed in ``__getitem__``.
    """

    def __init__(self, embedding_type, csv_file, esm_model_path, mlm_model_path, mdlm_model_path, device,
                 max_samples=None):
        """
        Args:
            embedding_type (str): encoder to embed with (``"esm"``, ``"mlm"`` or ``"mdlm"``).
            csv_file (str): path of the CSV to read.
            esm_model_path (str): ESM model path (also the tokenizer source).
            mlm_model_path (str): MLM checkpoint path.
            mdlm_model_path (str): MDLM checkpoint path.
            device (torch.device): device the embeddings are computed on.
            max_samples (int, optional): keep only the first ``max_samples``
                rows (handy for quick debugging runs). ``None`` keeps every row.
        """
        data = pd.read_csv(csv_file)
        if max_samples is not None:
            data = data.head(max_samples)
        # NOTE(review): the encoders accept a bounded number of positions, so very
        # long sequences may need filtering (e.g. len < 1024) — TODO confirm for this data.
        self.data = data
        self.embedding_type = embedding_type
        self.esm_model_path = esm_model_path
        self.mlm_model_path = mlm_model_path
        self.mdlm_model_path = mdlm_model_path
        self.device = device

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(embeddings, labels, seq_len)`` for row ``idx``."""
        sequence = self.data.iloc[idx]['Sequence']
        seq_len = len(sequence)
        embeddings = get_latents(self.embedding_type, self.esm_model_path, self.mlm_model_path,
                                 self.mdlm_model_path, sequence, self.device)
        # Lowercase residues = soluble (0), uppercase = insoluble (1).
        # NOTE(review): the data paths point at a membrane dataset — confirm the label meaning.
        labels = torch.tensor([0 if residue.islower() else 1 for residue in sequence],
                              dtype=torch.float32)
        return embeddings, labels, seq_len


# Transformer model class
class SolubilityPredictor(nn.Module):
    """MLP head that maps every residue embedding to a probability.

    Architecture: ``Linear(input_dim, 320) -> ReLU -> Linear(320, 1) -> Sigmoid``.
    ``hidden_dim``, ``num_heads``, ``num_layers`` and ``dropout`` are accepted so
    the hyperparameter grid can pass them, but the current architecture does
    not use them.
    """

    def __init__(self, input_dim, hidden_dim, num_heads, num_layers, dropout):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, 320),
            nn.ReLU(),
            nn.Linear(320, 1)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, embeddings):
        """Return probabilities of shape ``embeddings.shape[:-1]`` (one per residue)."""
        logits = self.classifier(embeddings)
        probs = self.sigmoid(logits.squeeze(-1))
        return probs  # Get probabilities of dimension seq_len
# Training function
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Make one pass over ``loader`` and return the mean per-batch loss.

    With an ``optimizer`` the model is put in training mode and updated after
    every batch. Without one, the model is put in evaluation mode and gradients
    are disabled. An empty loader yields ``nan`` instead of a ZeroDivisionError.
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    prog_bar = tqdm(total=len(loader), leave=True, file=sys.stdout)
    with torch.set_grad_enabled(training):
        for embeddings, labels, _seq_len in loader:
            embeddings, labels = embeddings.to(device), labels.to(device)
            # Embeddings are used exactly as collated in both modes. Squeezing
            # axis 1 would drop the sequence axis of a single-residue sequence
            # and make the outputs' shape disagree with the labels'.
            outputs = model(embeddings)
            loss = criterion(outputs, labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
            prog_bar.update()
            sys.stdout.flush()
    prog_bar.close()
    return total_loss / len(loader) if len(loader) else float("nan")


def train(model, train_loader, val_loader, optimizer, criterion, device):
    """Train ``model`` for one epoch, then measure the validation loss.

    Args:
        model (nn.Module): model to train; outputs per-residue probabilities.
        train_loader (DataLoader): yields ``(embeddings, labels, seq_len)`` training batches.
        val_loader (DataLoader): yields ``(embeddings, labels, seq_len)`` validation batches.
        optimizer (torch.optim.Optimizer): optimizer over ``model``'s parameters.
        criterion (nn.Module): loss applied to ``(outputs, labels)``.
        device (torch.device): device the batches are moved to.

    Returns:
        tuple[float, float]: mean training loss and mean validation loss per batch.
        The model is left in evaluation mode.
    """
    train_loss = _run_epoch(model, train_loader, criterion, device, optimizer=optimizer)
    val_loss = _run_epoch(model, val_loader, criterion, device)
    return train_loss, val_loss


# Evaluation function
def evaluate(model, dataloader, device):
    """Run inference and collect predictions and labels batch by batch.

    Args:
        model (nn.Module): the trained model.
        dataloader (DataLoader): yields ``(embeddings, labels, seq_len)`` test batches.
        device (torch.device): device used for inference.

    Returns:
        preds (list[np.ndarray]): predicted per-residue probabilities, one array per batch.
        true_labels (list[np.ndarray]): ground-truth per-residue labels, one array per batch.
    """
    model.eval()
    preds, true_labels = [], []
    with torch.no_grad():
        for embeddings, labels, _seq_len in tqdm(dataloader):
            outputs = model(embeddings.to(device))
            preds.append(outputs.cpu().numpy())
            true_labels.append(labels.cpu().numpy())
    return preds, true_labels


# Metrics calculation
def calculate_metrics(preds, labels, threshold=0.5):
    """Compute classification metrics over all residues of all batches.

    Args:
        preds (list[np.ndarray]): predicted probabilities, as returned by ``evaluate``.
        labels (list[np.ndarray]): ground-truth 0/1 labels, paired with ``preds``.
        threshold (float): a probability strictly above this counts as class 1.

    Returns:
        tuple: ``(accuracy, precision, recall, f1, roc_auc)``. Precision, recall
        and F1 are 0.0 when undefined. ROC AUC is ``nan`` when the labels
        contain a single class, because AUROC is undefined in that case.
    """
    prob_chunks, label_chunks = [], []
    for pred, label in zip(preds, labels):
        prob_chunks.append(np.ravel(pred))
        label_chunks.append(np.ravel(label))
    flat_prob_preds = np.concatenate(prob_chunks) if prob_chunks else np.empty(0)
    flat_labels = np.concatenate(label_chunks) if label_chunks else np.empty(0)
    flat_binary_preds = (flat_prob_preds > threshold).astype(int)

    accuracy = accuracy_score(flat_labels, flat_binary_preds)
    precision = precision_score(flat_labels, flat_binary_preds, zero_division=0)
    recall = recall_score(flat_labels, flat_binary_preds, zero_division=0)
    f1 = f1_score(flat_labels, flat_binary_preds, zero_division=0)
    if np.unique(flat_labels).size < 2:
        roc_auc = float("nan")
    else:
        roc_auc = roc_auc_score(flat_labels, flat_prob_preds)
    return accuracy, precision, recall, f1, roc_auc


# Metric names, in the order calculate_metrics returns them.
_METRIC_NAMES = ("Accuracy", "Precision", "Recall", "F1 Score", "ROC AUC")


def _metric_lines(test_metrics):
    """Format a ``calculate_metrics`` result as ``"<name>: <value>"`` lines (4 decimals)."""
    return [f"{name}: {value:.4f}" for name, value in zip(_METRIC_NAMES, test_metrics)]


def _make_loader(embedding_type, csv_file, batch_size, shuffle, device):
    """Build a DataLoader over one CSV split, using the model paths from ``hyperparams``."""
    dataset = SolubilityDataset(embedding_type, csv_file,
                                hyperparams['esm_model_path'],
                                hyperparams['mlm_model_path'],
                                hyperparams['mdlm_model_path'],
                                device)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


if __name__ == "__main__":
    # Kept at module level (not inside a function) so any code reading the
    # global `device` still finds it.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    ### Grid search to explore hyperparameter space
    param_grid = {
        'learning_rate': [5e-4],
        'batch_size': [1],
        'num_heads': [4],
        'num_layers': [2],
        'dropout': [0.5],
        'num_epochs': [5]
    }

    for embedding_type in ['mlm', 'esm', 'mdlm']:
        best_val_loss = float('inf')
        best_model = None
        best_hyperparams = None
        model_root = f"{path}/benchmarks/Supervised/Solubility/transformer_models/{embedding_type}"

        for params in ParameterGrid(param_grid):
            hyperparams.update(params)

            # Loaders are built after the update so the grid's batch_size is used.
            # NOTE(review): batch_size > 1 presumably needs a padding collate_fn,
            # since the default collate stacks tensors and sequences may differ in length.
            batch_size = hyperparams["batch_size"]
            train_dataloader = _make_loader(embedding_type, hyperparams['train_data'], batch_size, True, device)
            val_dataloader = _make_loader(embedding_type, hyperparams['val_data'], batch_size, False, device)
            test_dataloader = _make_loader(embedding_type, hyperparams['test_data'], batch_size, False, device)

            input_dim = 640 if embedding_type == "mdlm" else 1280
            model = SolubilityPredictor(
                input_dim=input_dim,
                hidden_dim=input_dim,
                num_layers=hyperparams["num_layers"],
                num_heads=hyperparams["num_heads"],
                dropout=hyperparams['dropout']
            ).to(device)
            optimizer = optim.Adam(model.parameters(), lr=hyperparams["learning_rate"])
            criterion = nn.BCELoss()

            num_epochs = hyperparams["num_epochs"]
            for epoch in range(num_epochs):
                print(f"EPOCH {epoch+1}/{num_epochs}")
                train_loss, val_loss = train(model, train_dataloader, val_dataloader, optimizer, criterion, device)
                print(f"TRAIN LOSS: {train_loss:.4f}")
                print(f"VALIDATION LOSS: {val_loss:.4f}\n")
                sys.stdout.flush()

                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    # state_dict() returns references to the live tensors; clone them
                    # so later optimizer steps do not overwrite the best snapshot.
                    best_model = {k: v.detach().clone() for k, v in model.state_dict().items()}
                    # Snapshot the combination that produced the best model.
                    best_hyperparams = dict(hyperparams)

            # Evaluate the final model of this combination on test sequences
            print("TEST METRICS:")
            test_preds, test_labels = evaluate(model, test_dataloader, device)
            test_metrics = calculate_metrics(test_preds, test_labels)
            metric_lines = _metric_lines(test_metrics)
            for line in metric_lines:
                print(line)
            print(f"\n")
            sys.stdout.flush()

            ### Save model and metrics for this hyperparameter combination
            folder_name = (
                f"{model_root}/lr{hyperparams['learning_rate']}_bs{hyperparams['batch_size']}"
                f"_epochs{hyperparams['num_epochs']}_layers{hyperparams['num_layers']}"
                f"_heads{hyperparams['num_heads']}_drpt{hyperparams['dropout']}"
            )
            os.makedirs(folder_name, exist_ok=True)
            torch.save(model.state_dict(), os.path.join(folder_name, "model.pth"))

            output_file_path = os.path.join(folder_name, "hyperparams_and_test_results.txt")
            with open(output_file_path, 'w', encoding="utf-8") as out_file:
                for key, value in hyperparams.items():
                    out_file.write(f"{key}: {value}\n")
                out_file.write("\nTEST METRICS:\n")
                for line in metric_lines:
                    out_file.write(f"{line}\n")

        # Save the best model and the hyperparameters that produced it
        if best_model is not None:
            os.makedirs(model_root, exist_ok=True)
            torch.save(best_model, os.path.join(model_root, "best_model.pth"))
            best_hyperparams_path = os.path.join(model_root, "best_model_hyperparams.txt")
            with open(best_hyperparams_path, 'w', encoding="utf-8") as out_file:
                out_file.write("Best Validation Loss: {:.4f}\n".format(best_val_loss))
                for key, value in best_hyperparams.items():
                    out_file.write(f"{key}: {value}\n")