import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from tqdm import tqdm
import pandas as pd
import numpy as np
import pickle

# Hyperparameters dictionary
path = "/home/a03-sgoel/MDpLM"

hyperparams = {
    "batch_size": 1,
    "learning_rate": 4e-5,
    "num_epochs": 5,
    "max_length": 2000,
    "train_data": path + "/benchmarks/membrane_type_train.csv",
    "test_data": path + "/benchmarks/membrane_type_test.csv",
    "val_data": "",        # none
    "embeddings_pkl": ""   # Need to generate ESM embeddings
}

# Dataset class that loads precomputed per-sequence embeddings from a pickle file
class LocalizationDataset(Dataset):
    def __init__(self, csv_file, embeddings_pkl, max_length=2000):
        self.data = pd.read_csv(csv_file)
        self.max_length = max_length

        # Map sequences to their precomputed embeddings
        with open(embeddings_pkl, 'rb') as f:
            self.embeddings_dict = pickle.load(f)
        self.data['embedding'] = self.data['Sequence'].map(self.embeddings_dict)

        # Ensure every sequence in the CSV found a matching embedding
        assert self.data['embedding'].notna().all(), "CSV data and embeddings mismatch"

        # Convert the one-hot class columns to integer class indices for
        # CrossEntropyLoss. NOTE: the slice width must match num_classes below.
        self.data['label'] = self.data.iloc[:, 2:7].values.argmax(axis=1).tolist()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        embeddings = torch.tensor(self.data['embedding'][idx], dtype=torch.float)
        label = torch.tensor(self.data['label'][idx], dtype=torch.long)
        return embeddings, label

# Multi-class localization predictor
class LocalizationPredictor(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(LocalizationPredictor, self).__init__()
        self.classifier = nn.Linear(input_dim, num_classes)  # 1280 -> 4

    def forward(self, embeddings):
        # Mean-pool over the residue dimension: [batch, seq_len, 1280] -> [batch, 1280]
        avg_embedding = torch.mean(embeddings, dim=1)
        logits = self.classifier(avg_embedding)
        return logits  # [batch, 4] logits, passed to CE loss (4-class distribution)

# Training function
def train(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    for embeddings, labels in tqdm(dataloader):
        embeddings, labels = embeddings.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(embeddings)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)

# Evaluation function: collects predicted and true class indices
def evaluate(model, dataloader, device):
    model.eval()
    preds, true_labels = [], []
    with torch.no_grad():
        for embeddings, labels in tqdm(dataloader):
            embeddings, labels = embeddings.to(device), labels.to(device)
            outputs = model(embeddings)
            preds.append(outputs.argmax(dim=1).cpu().numpy())
            true_labels.append(labels.cpu().numpy())
    return preds, true_labels

# Metrics calculation over flattened predicted / true class indices
def calculate_metrics(preds, labels):
    flat_preds = np.concatenate(preds)
    flat_labels = np.concatenate(labels)

    accuracy = accuracy_score(flat_labels, flat_preds)
    precision = precision_score(flat_labels, flat_preds, average='macro', zero_division=0)
    recall = recall_score(flat_labels, flat_preds, average='macro', zero_division=0)
    f1 = f1_score(flat_labels, flat_preds, average='macro', zero_division=0)
    return accuracy, precision, recall, f1
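# The "embeddings_pkl" hyperparameter above is still empty ("Need to generate ESM
# embeddings"). Below is a minimal, hedged sketch of how the {sequence: embedding}
# pickle could be built with fair-esm (https://github.com/facebookresearch/esm).
# Assumptions not confirmed by this repo: the benchmark CSVs expose a 'Sequence'
# column, and per-residue representations from ESM-2's final layer (dim 1280,
# matching input_dim below) are the intended inputs. The helper name
# build_esm_embeddings is illustrative only.
def build_esm_embeddings(csv_files, out_pkl):
    import esm  # pip install fair-esm
    model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
    batch_converter = alphabet.get_batch_converter()
    model.eval()

    # Collect the unique sequences across all benchmark splits
    sequences = set()
    for f in csv_files:
        sequences.update(pd.read_csv(f)['Sequence'].tolist())

    embeddings = {}
    with torch.no_grad():
        for seq in tqdm(sequences):
            _, _, tokens = batch_converter([("protein", seq)])
            out = model(tokens, repr_layers=[33])
            # Drop BOS/EOS tokens; keep per-residue embeddings [seq_len, 1280]
            embeddings[seq] = out["representations"][33][0, 1:len(seq) + 1].numpy()

    with open(out_pkl, 'wb') as f:
        pickle.dump(embeddings, f)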
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_dataset = LocalizationDataset(hyperparams["train_data"],
                                    hyperparams["embeddings_pkl"],
                                    max_length=hyperparams["max_length"])
test_dataset = LocalizationDataset(hyperparams["test_data"],
                                   hyperparams["embeddings_pkl"],
                                   max_length=hyperparams["max_length"])

train_dataloader = DataLoader(train_dataset, batch_size=hyperparams["batch_size"], shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=hyperparams["batch_size"], shuffle=False)

model = LocalizationPredictor(input_dim=1280, num_classes=4).to(device)
optimizer = optim.Adam(model.parameters(), lr=hyperparams["learning_rate"])
criterion = nn.CrossEntropyLoss()

# Train the model
for epoch in range(hyperparams["num_epochs"]):
    train_loss = train(model, train_dataloader, optimizer, criterion, device)
    print(f"EPOCH {epoch+1}/{hyperparams['num_epochs']}")
    print(f"TRAIN LOSS: {train_loss:.4f}")
    print("\n")

# Evaluate model on test dataset
print("Test set")
test_preds, test_labels = evaluate(model, test_dataloader, device)
test_metrics = calculate_metrics(test_preds, test_labels)

print("TEST METRICS:")
print(f"Accuracy: {test_metrics[0]:.4f}")
print(f"Precision: {test_metrics[1]:.4f}")
print(f"Recall: {test_metrics[2]:.4f}")
print(f"F1 Score: {test_metrics[3]:.4f}")
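# Optional: persist the trained classifier head so it can be reloaded without
# retraining. A minimal sketch; the checkpoint filename is an assumption, not a
# path used elsewhere in this repo.
checkpoint_path = path + "/membrane_type_classifier.pt"
torch.save(model.state_dict(), checkpoint_path)
print(f"Saved classifier weights to {checkpoint_path}")

# To reload later:
#   model = LocalizationPredictor(input_dim=1280, num_classes=4)
#   model.load_state_dict(torch.load(checkpoint_path))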