Soumic
:hammer: Read token from os environment, or .env file
b3624c4
raw
history blame
21.1 kB
import logging
import os
import random
from typing import Any
import numpy as np
import pandas as pd
from pytorch_lightning import Trainer, LightningModule, LightningDataModule
from pytorch_lightning.utilities.types import OptimizerLRScheduler, STEP_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from torchmetrics.classification import BinaryAccuracy, BinaryAUROC, BinaryF1Score, BinaryPrecision, BinaryRecall
from transformers import BertModel, BatchEncoding, BertTokenizer, TrainingArguments
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
import torch
from torch import nn
from datasets import load_dataset, IterableDataset
from huggingface_hub import PyTorchModelHubMixin
from dotenv import load_dotenv
from huggingface_hub import login
timber = logging.getLogger()
# logging.basicConfig(level=logging.DEBUG)
logging.basicConfig(level=logging.INFO) # change to level=logging.DEBUG to print more logs...
black = "\u001b[30m"
red = "\u001b[31m"
green = "\u001b[32m"
yellow = "\u001b[33m"
blue = "\u001b[34m"
magenta = "\u001b[35m"
cyan = "\u001b[36m"
white = "\u001b[37m"
FORWARD = "FORWARD_INPUT"
BACKWARD = "BACKWARD_INPUT"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def login_inside_huggingface_virtualmachine():
# Load the .env file, but don't crash if it's not found (e.g., in Hugging Face Space)
try:
load_dotenv() # Only useful on your laptop if .env exists
print(".env file loaded successfully.")
except Exception as e:
print(f"Warning: Could not load .env file. Exception: {e}")
# Try to get the token from environment variables
try:
token = os.getenv("HF_TOKEN")
if not token:
raise ValueError("HF_TOKEN not found. Make sure to set it in the environment variables or .env file.")
# Log in to Hugging Face Hub
login(token)
print("Logged in to Hugging Face Hub successfully.")
except Exception as e:
print(f"Error during Hugging Face login: {e}")
# Handle the error appropriately (e.g., exit or retry)
def one_hot_e(dna_seq: str) -> np.ndarray:
mydict = {'A': np.asarray([1.0, 0.0, 0.0, 0.0]), 'C': np.asarray([0.0, 1.0, 0.0, 0.0]),
'G': np.asarray([0.0, 0.0, 1.0, 0.0]), 'T': np.asarray([0.0, 0.0, 0.0, 1.0]),
'N': np.asarray([0.0, 0.0, 0.0, 0.0]), 'H': np.asarray([0.0, 0.0, 0.0, 0.0]),
'a': np.asarray([1.0, 0.0, 0.0, 0.0]), 'c': np.asarray([0.0, 1.0, 0.0, 0.0]),
'g': np.asarray([0.0, 0.0, 1.0, 0.0]), 't': np.asarray([0.0, 0.0, 0.0, 1.0]),
'n': np.asarray([0.0, 0.0, 0.0, 0.0]), '-': np.asarray([0.0, 0.0, 0.0, 0.0])}
size_of_a_seq: int = len(dna_seq)
# forward = np.zeros(shape=(size_of_a_seq, 4))
forward_list: list = [mydict[dna_seq[i]] for i in range(0, size_of_a_seq)]
encoded = np.asarray(forward_list)
encoded_transposed = encoded.transpose() # todo: Needs review
return encoded_transposed
def one_hot_e_column(column: pd.Series) -> np.ndarray:
tmp_list: list = [one_hot_e(seq) for seq in column]
encoded_column = np.asarray(tmp_list).astype(np.float32)
return encoded_column
def reverse_dna_seq(dna_seq: str) -> str:
# m_reversed = ""
# for i in range(0, len(dna_seq)):
# m_reversed = dna_seq[i] + m_reversed
# return m_reversed
return dna_seq[::-1]
def complement_dna_seq(dna_seq: str) -> str:
comp_map = {"A": "T", "C": "G", "T": "A", "G": "C",
"a": "t", "c": "g", "t": "a", "g": "c",
"N": "N", "H": "H", "-": "-",
"n": "n", "h": "h"
}
comp_dna_seq_list: list = [comp_map[nucleotide] for nucleotide in dna_seq]
comp_dna_seq: str = "".join(comp_dna_seq_list)
return comp_dna_seq
def reverse_complement_dna_seq(dna_seq: str) -> str:
return reverse_dna_seq(complement_dna_seq(dna_seq))
def reverse_complement_column(column: pd.Series) -> np.ndarray:
rc_column: list = [reverse_complement_dna_seq(seq) for seq in column]
return rc_column
class TorchMetrics:
def __init__(self, device=DEVICE):
self.binary_accuracy = BinaryAccuracy().to(device)
self.binary_auc = BinaryAUROC().to(device)
self.binary_f1_score = BinaryF1Score().to(device)
self.binary_precision = BinaryPrecision().to(device)
self.binary_recall = BinaryRecall().to(device)
pass
def update_on_each_step(self, batch_predicted_labels, batch_actual_labels): # todo: Add log if needed
self.binary_accuracy.update(preds=batch_predicted_labels, target=batch_actual_labels)
self.binary_auc.update(preds=batch_predicted_labels, target=batch_actual_labels)
self.binary_f1_score.update(preds=batch_predicted_labels, target=batch_actual_labels)
self.binary_precision.update(preds=batch_predicted_labels, target=batch_actual_labels)
self.binary_recall.update(preds=batch_predicted_labels, target=batch_actual_labels)
pass
def compute_and_reset_on_epoch_end(self, log, log_prefix: str, log_color: str = green):
b_accuracy = self.binary_accuracy.compute()
b_auc = self.binary_auc.compute()
b_f1_score = self.binary_f1_score.compute()
b_precision = self.binary_precision.compute()
b_recall = self.binary_recall.compute()
timber.info(
log_color + f"{log_prefix}_acc = {b_accuracy}, {log_prefix}_auc = {b_auc}, {log_prefix}_f1_score = {b_f1_score}, {log_prefix}_precision = {b_precision}, {log_prefix}_recall = {b_recall}")
log(f"{log_prefix}_accuracy", b_accuracy)
log(f"{log_prefix}_auc", b_auc)
log(f"{log_prefix}_f1_score", b_f1_score)
log(f"{log_prefix}_precision", b_precision)
log(f"{log_prefix}_recall", b_recall)
self.binary_accuracy.reset()
self.binary_auc.reset()
self.binary_f1_score.reset()
self.binary_precision.reset()
self.binary_recall.reset()
pass
def insert_debug_motif_at_random_position(seq, DEBUG_MOTIF):
start = 0
end = len(seq)
rand_pos = random.randrange(start, (end - len(DEBUG_MOTIF)))
random_end = rand_pos + len(DEBUG_MOTIF)
output = seq[start: rand_pos] + DEBUG_MOTIF + seq[random_end: end]
assert len(seq) == len(output)
return output
class MQTLDataset(IterableDataset):
def __init__(self, m_dataset, seq_len, check_if_pipeline_is_ok_by_inserting_debug_motif=False):
self.dataset = m_dataset
self.check_if_pipeline_is_ok_by_inserting_debug_motif = check_if_pipeline_is_ok_by_inserting_debug_motif
self.debug_motif = "ATCGCCTA"
self.seq_len = seq_len
pass
def __iter__(self):
for row in self.dataset:
processed = self.preprocess(row)
if processed is not None:
yield processed
def preprocess(self, row):
seq = row['sequence'] # Fetch the 'sequence' column
if len(seq) != self.seq_len:
return None # skip problematic row!
label = row['label'] # Fetch the 'label' column (or whatever target you use)
if label == 1 and self.check_if_pipeline_is_ok_by_inserting_debug_motif:
seq = insert_debug_motif_at_random_position(seq=seq, DEBUG_MOTIF=self.debug_motif)
seq_rc = reverse_complement_dna_seq(seq)
ohe_seq = one_hot_e(dna_seq=seq)
# print(f"shape fafafa = { ohe_seq.shape = }")
ohe_seq_rc = one_hot_e(dna_seq=seq_rc)
label_number = label * 1.0
label_np_array = np.asarray([label_number]).astype(np.float32)
# return ohe_seq, ohe_seq_rc, label
return [ohe_seq, ohe_seq_rc], label_np_array
# def collate_fn(batch):
# sequences, labels = zip(*batch)
# ohe_seq, ohe_seq_rc = sequences[0], sequences[1]
# # Pad sequences to the maximum length in this batch
# padded_sequences = pad_sequence(ohe_seq, batch_first=True, padding_value=0)
# padded_sequences_rc = pad_sequence(ohe_seq_rc, batch_first=True, padding_value=0)
# # Convert labels to a tensor
# labels = torch.stack(labels)
# return [padded_sequences, padded_sequences_rc], labels
class MqtlDataModule(LightningDataModule):
def __init__(self, train_ds: Dataset, val_ds: Dataset, test_ds: Dataset, batch_size=16):
super().__init__()
self.batch_size = batch_size
self.train_loader = DataLoader(train_ds, batch_size=self.batch_size, shuffle=False,
# collate_fn=collate_fn,
num_workers=15,
# persistent_workers=True
)
self.validate_loader = DataLoader(val_ds, batch_size=self.batch_size, shuffle=False,
# collate_fn=collate_fn,
num_workers=15,
# persistent_workers=True
)
self.test_loader = DataLoader(test_ds, batch_size=self.batch_size, shuffle=False,
# collate_fn=collate_fn,
num_workers=15,
# persistent_workers=True
)
pass
def prepare_data(self):
pass
def setup(self, stage: str) -> None:
timber.info(f"inside setup: {stage = }")
pass
def train_dataloader(self) -> TRAIN_DATALOADERS:
return self.train_loader
def val_dataloader(self) -> EVAL_DATALOADERS:
return self.validate_loader
def test_dataloader(self) -> EVAL_DATALOADERS:
return self.test_loader
class MQtlClassifierLightningModule(LightningModule):
def __init__(self,
classifier: nn.Module,
criterion=nn.BCELoss(), # nn.BCEWithLogitsLoss(),
regularization: int = 2, # 1 == L1, 2 == L2, 3 (== 1 | 2) == both l1 and l2, else ignore / don't care
l1_lambda=0.001,
l2_wright_decay=0.001,
m_optimizer=torch.optim.Adam,
*args: Any,
**kwargs: Any):
super().__init__(*args, **kwargs)
self.classifier = classifier
self.criterion = criterion
self.train_metrics = TorchMetrics()
self.validate_metrics = TorchMetrics()
self.test_metrics = TorchMetrics()
self.regularization = regularization
self.l1_lambda = l1_lambda
self.l2_weight_decay = l2_wright_decay
self.m_optimizer = m_optimizer
pass
def forward(self, x, *args: Any, **kwargs: Any) -> Any:
return self.classifier.forward(x)
def configure_optimizers(self) -> OptimizerLRScheduler:
# Here we add weight decay (L2 regularization) to the optimizer
weight_decay = 0.0
if self.regularization == 2 or self.regularization == 3:
weight_decay = self.l2_weight_decay
return self.m_optimizer(self.parameters(), lr=1e-3, weight_decay=weight_decay) # , weight_decay=0.005)
def training_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
# Accuracy on training batch data
x, y = batch
x = [i.float() for i in x]
preds = self.forward(x)
loss = self.criterion(preds, y)
if self.regularization == 1 or self.regularization == 3: # apply l1 regularization
l1_norm = sum(p.abs().sum() for p in self.parameters())
loss += self.l1_lambda * l1_norm
self.log("train_loss", loss)
# calculate the scores start
self.train_metrics.update_on_each_step(batch_predicted_labels=preds, batch_actual_labels=y)
# calculate the scores end
return loss
def on_train_epoch_end(self) -> None:
timber.info(green + "on_train_epoch_end")
self.train_metrics.compute_and_reset_on_epoch_end(log=self.log, log_prefix="train")
pass
def validation_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
# Accuracy on validation batch data
x, y = batch
x = [i.float() for i in x]
preds = self.forward(x)
loss = self.criterion(preds, y)
self.log("valid_loss", loss)
# calculate the scores start
self.validate_metrics.update_on_each_step(batch_predicted_labels=preds, batch_actual_labels=y)
# calculate the scores end
return loss
def on_validation_epoch_end(self) -> None:
timber.info(blue + "on_validation_epoch_end")
self.validate_metrics.compute_and_reset_on_epoch_end(log=self.log, log_prefix="validate", log_color=blue)
return None
def test_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
# Accuracy on validation batch data
x, y = batch
x = [i.float() for i in x]
preds = self.forward(x)
loss = self.criterion(preds, y)
self.log("test_loss", loss) # do we need this?
# calculate the scores start
self.test_metrics.update_on_each_step(batch_predicted_labels=preds, batch_actual_labels=y)
# calculate the scores end
return loss
def on_test_epoch_end(self) -> None:
timber.info(magenta + "on_test_epoch_end")
self.test_metrics.compute_and_reset_on_epoch_end(log=self.log, log_prefix="test", log_color=magenta)
return None
pass
# Some more util functions!
def create_conv_sequence(in_channel_num_of_nucleotides, num_filters, kernel_size_k_mer_motif) -> nn.Sequential:
conv1d = nn.Conv1d(in_channels=in_channel_num_of_nucleotides, out_channels=num_filters,
kernel_size=kernel_size_k_mer_motif,
padding="same") # stride = 2, just dont use stride, keep it simple for now
activation = nn.ReLU(inplace=False) # (inplace=True) will fess with interpretability
pooling = nn.MaxPool1d(
kernel_size=kernel_size_k_mer_motif) # stride = 2, just dont use stride, keep it simple for now
return nn.Sequential(conv1d, activation, pooling)
class Cnn1dClassifier(nn.Module,
PyTorchModelHubMixin
):
def __init__(self,
seq_len,
in_channel_num_of_nucleotides=4,
kernel_size_k_mer_motif=4,
num_filters=32,
lstm_hidden_size=128,
dnn_size=128,
conv_seq_list_size=3,
*args, **kwargs):
super().__init__(*args, **kwargs)
self.file_name = f"weights_Cnn1dClassifier_seqlen_{seq_len}.pth"
self.seq_layer_forward = create_conv_sequence(in_channel_num_of_nucleotides, num_filters,
kernel_size_k_mer_motif)
self.seq_layer_backward = create_conv_sequence(in_channel_num_of_nucleotides, num_filters,
kernel_size_k_mer_motif)
self.flatten = nn.Flatten()
dnn_in_features = int(num_filters * (seq_len * 2) / kernel_size_k_mer_motif) # no idea why
# two because forward_sequence,and backward_sequence
self.dnn = nn.Linear(in_features=dnn_in_features, out_features=dnn_size)
self.dnn_activation = nn.ReLU(inplace=False) # inplace = true messes with interpretability!
self.dropout = nn.Dropout(p=0.33)
self.output_layer = nn.Linear(in_features=dnn_size, out_features=1)
self.output_activation = torch.sigmoid # not needed if using nn.BCEWithLogitsLoss()
self.layer_output_logger: dict = {}
pass
def forward(self, x):
xf, xb = x[0], x[1]
hf = self.seq_layer_forward(xf)
timber.debug(red + f"1{ hf.shape = }")
hb = self.seq_layer_backward(xb)
timber.debug(green + f"2{ hb.shape = }")
h = torch.concatenate(tensors=(hf, hb), dim=2)
timber.debug(yellow + f"4{ h.shape = } concat")
h = self.flatten(h)
timber.debug(yellow + f"5{ h.shape = } flatten")
h = self.dnn(h)
timber.debug(yellow + f"8{ h.shape = } dnn")
h = self.dnn_activation(h)
timber.debug(blue + f"9{ h.shape = } dnn_activation")
h = self.dropout(h)
timber.debug(blue + f"10{ h.shape = } dropout")
h = self.output_layer(h)
timber.debug(blue + f"11{ h.shape = } output_layer")
h = self.output_activation(h)
timber.debug(blue + f"12{ h.shape = } output_activation")
return h
def start(classifier_model, model_save_path, is_attention_model=False, m_optimizer=torch.optim.Adam, WINDOW=200,
dataset_folder_prefix="inputdata/", is_binned=True, is_debug=False, max_epochs=10):
# experiment = 'tutorial_3'
# if not os.path.exists(experiment):
# os.makedirs(experiment)
"""
x_train, x_tmp, y_train, y_tmp = train_test_split(df["sequence"], df["label"], test_size=0.2)
x_test, x_val, y_test, y_val = train_test_split(x_tmp, y_tmp, test_size=0.5)
train_dataset = MyDataSet(x_train, y_train)
val_dataset = MyDataSet(x_val, y_val)
test_dataset = MyDataSet(x_test, y_test)
"""
file_suffix = ""
if is_binned:
file_suffix = "_binned"
data_files = {
# small samples
"train_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_train_binned.csv",
"validate_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_validate_binned.csv",
"test_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_test_binned.csv",
# large samples
"train_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_train_binned.csv",
"validate_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_validate_binned.csv",
"test_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_test_binned.csv",
}
dataset_map = None
is_my_laptop = os.path.isfile("/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_test_binned.csv")
if is_my_laptop:
dataset_map = load_dataset("csv", data_files=data_files, streaming=True)
else:
dataset_map = load_dataset("fahimfarhan/mqtl-classification-datasets", streaming=True)
train_dataset = MQTLDataset(dataset_map[f"train_binned_{WINDOW}"],
check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug,
seq_len=WINDOW
)
val_dataset = MQTLDataset(dataset_map[f"validate_binned_{WINDOW}"],
check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug,
seq_len=WINDOW)
test_dataset = MQTLDataset(dataset_map[f"test_binned_{WINDOW}"],
check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug,
seq_len=WINDOW)
data_module = MqtlDataModule(train_ds=train_dataset, val_ds=val_dataset, test_ds=test_dataset)
classifier_model = classifier_model #.to(DEVICE)
try:
classifier_model = classifier_model.from_pretrained(f"my-awesome-model-{WINDOW}")
except Exception as x:
print(x)
classifier_module = MQtlClassifierLightningModule(classifier=classifier_model, regularization=2,
m_optimizer=m_optimizer)
# if os.path.exists(model_save_path):
# classifier_module.load_state_dict(torch.load(model_save_path))
classifier_module = classifier_module # .double()
trainer = Trainer(max_epochs=max_epochs, precision="32")
trainer.fit(model=classifier_module, datamodule=data_module)
timber.info("\n\n")
trainer.test(model=classifier_module, datamodule=data_module)
timber.info("\n\n")
torch.save(classifier_module.state_dict(), model_save_path)
# save locally
model_subdirectory = f"my-awesome-model-{WINDOW}"
classifier_model.save_pretrained(model_subdirectory)
# push to the hub
classifier_model.push_to_hub(
repo_id="fahimfarhan/mqtl-classifier-model",
# subfolder=f"my-awesome-model-{WINDOW}", subfolder didn't work :/
commit_message=f":tada: Push model for window size {WINDOW}"
)
# reload
model = classifier_model.from_pretrained(f"my-awesome-model-{WINDOW}")
# repo_url = "https://huggingface.co/fahimfarhan/mqtl-classifier-model"
#
# push_to_hub(
# model_file=classifier_model.file_name, # Replace with your model file path
# repo_url=repo_url,
# # config_file="config.json" # Optional, if you have a config file
# )
# start_interpreting_ig_and_dl(classifier_model, WINDOW, dataset_folder_prefix=dataset_folder_prefix)
# start_interpreting_with_dlshap(classifier_model, WINDOW, dataset_folder_prefix=dataset_folder_prefix)
# if is_attention_model: # todo: repair it later
# start_interpreting_attention_failed(classifier_model)
pass
if __name__ == '__main__':
login_inside_huggingface_virtualmachine()
WINDOW = 200
simple_cnn = Cnn1dClassifier(seq_len=WINDOW)
simple_cnn.enable_logging = True
start(classifier_model=simple_cnn, model_save_path=simple_cnn.file_name, WINDOW=WINDOW,
dataset_folder_prefix="inputdata/", is_debug=True, max_epochs=10)
pass
"""
lightning_logs/
*.pth
my-awesome-model
INFO:root:validate_acc = 0.5625, validate_auc = 0.5490195751190186, validate_f1_score = 0.30000001192092896, validate_precision = 0.6000000238418579, validate_recall = 0.20000000298023224
/home/soumic/Codes/mqtl-classification/venv/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
"""