import pickle

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

path = "/home/a03-sgoel/MDpLM"

hyperparams = {
    "batch_size": 1,
    "learning_rate": 4e-5,
    "num_epochs": 5,
    "max_length": 2000,
    "train_data": path + "/benchmarks/membrane_type_train.csv",
    "test_data": path + "/benchmarks/membrane_type_test.csv",
    "val_data": "",
    "embeddings_pkl": ""
}
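

# LocalizationDataset assumes the CSV has a 'Sequence' column plus one label
# column per class (columns 2-6, 0-indexed), and that the pickle maps each
# sequence string to a per-residue embedding array of shape (seq_len, dim).
# Sequences vary in length, so they cannot be stacked into batches without a
# padding collate_fn; that is why hyperparams["batch_size"] is 1.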
class LocalizationDataset(Dataset):
    def __init__(self, csv_file, embeddings_pkl, max_length=2000):
        self.data = pd.read_csv(csv_file)
        self.max_length = max_length  # kept for API compatibility; not used yet

        # Attach the precomputed embedding for every sequence
        with open(embeddings_pkl, 'rb') as f:
            self.embeddings_dict = pickle.load(f)
        self.data['embedding'] = self.data['Sequence'].map(self.embeddings_dict)

        # .map() leaves NaN where a sequence has no embedding, so check for gaps
        assert self.data['embedding'].notna().all(), "CSV data and embeddings mismatch"

        # Collect the per-class label columns into one vector per row
        self.data['label'] = self.data.iloc[:, 2:7].values.tolist()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        embeddings = torch.tensor(self.data['embedding'][idx], dtype=torch.float)
        # Float targets: CrossEntropyLoss treats same-shape targets as
        # class probabilities (PyTorch >= 1.10)
        labels = torch.tensor(self.data['label'][idx], dtype=torch.float)
        return embeddings, labels
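

# A linear probe over mean-pooled residue embeddings: a single Linear layer
# maps the averaged embedding to per-class logits.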
class LocalizationPredictor(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(LocalizationPredictor, self).__init__()
        self.classifier = nn.Linear(input_dim, num_classes)

    def forward(self, embeddings):
        # Mean-pool over the sequence dimension:
        # (batch, seq_len, dim) -> (batch, dim), not over dim=0 (the batch axis)
        avg_embedding = torch.mean(embeddings, dim=1)
        logits = self.classifier(avg_embedding)
        return logits
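

# One optimizer step per batch; returns the mean loss over the epoch.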
def train(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    for embeddings, labels in tqdm(dataloader):
        embeddings, labels = embeddings.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(embeddings)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)
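

# Inference pass: collects per-batch class probabilities and true labels as
# numpy arrays for the metric computation below.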
def evaluate(model, dataloader, device):
    model.eval()
    preds, true_labels = [], []
    with torch.no_grad():
        for embeddings, labels in tqdm(dataloader):
            embeddings, labels = embeddings.to(device), labels.to(device)
            outputs = model(embeddings)
            # Map logits to probabilities so the 0.5 threshold used in
            # calculate_metrics() is meaningful
            probs = torch.softmax(outputs, dim=-1)
            preds.append(probs.cpu().numpy())
            true_labels.append(labels.cpu().numpy())
    return preds, true_labels
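

# Metrics are computed over the flattened per-class binary decisions, so
# "accuracy" here is per-class-slot accuracy rather than exact-match accuracy.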
def calculate_metrics(preds, labels, threshold=0.5):
    flat_binary_preds, flat_labels = [], []

    for pred, label in zip(preds, labels):
        flat_binary_preds.extend((pred > threshold).astype(int).flatten())
        flat_labels.extend(label.astype(int).flatten())

    flat_binary_preds = np.array(flat_binary_preds)
    flat_labels = np.array(flat_labels)

    accuracy = accuracy_score(flat_labels, flat_binary_preds)
    precision = precision_score(flat_labels, flat_binary_preds, average='macro')
    recall = recall_score(flat_labels, flat_binary_preds, average='macro')
    f1 = f1_score(flat_labels, flat_binary_preds, average='macro')

    return accuracy, precision, recall, f1
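

# Device, data, and model setup. The input dimension of 1280 must match the
# precomputed embeddings; hyperparams["embeddings_pkl"] is left blank above
# and must be filled in before the script will run.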
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_dataset = LocalizationDataset(hyperparams["train_data"], hyperparams["embeddings_pkl"], max_length=hyperparams["max_length"])
test_dataset = LocalizationDataset(hyperparams["test_data"], hyperparams["embeddings_pkl"], max_length=hyperparams["max_length"])

train_dataloader = DataLoader(train_dataset, batch_size=hyperparams["batch_size"], shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=hyperparams["batch_size"], shuffle=False)

# Infer the number of classes from the label vectors so it always matches the
# label columns selected in LocalizationDataset
num_classes = len(train_dataset.data['label'].iloc[0])
model = LocalizationPredictor(input_dim=1280, num_classes=num_classes).to(device)
optimizer = optim.Adam(model.parameters(), lr=hyperparams["learning_rate"])
# CrossEntropyLoss accepts the float one-hot/probability targets produced by
# the dataset (PyTorch >= 1.10)
criterion = nn.CrossEntropyLoss()
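
# Train for the configured number of epochs, reporting the mean train loss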
for epoch in range(hyperparams["num_epochs"]):
    train_loss = train(model, train_dataloader, optimizer, criterion, device)
    print(f"EPOCH {epoch+1}/{hyperparams['num_epochs']}")
    print(f"TRAIN LOSS: {train_loss:.4f}")
    print("\n")
print("Test set") |
|
test_preds, test_labels = evaluate(model, test_dataloader, device) |
|
test_metrics = calculate_metrics(test_preds, test_labels) |
|
print("TEST METRICS:") |
|
print(f"Accuracy: {test_metrics[0]:.4f}") |
|
print(f"Precision: {test_metrics[1]:.4f}") |
|
print(f"Recall: {test_metrics[2]:.4f}") |
|
print(f"F1 Score: {test_metrics[3]:.4f}") |