Soumic commited on
Commit
9e90264
·
1 Parent(s): 1b17226

:zap: Ready the app.py for next experiment. Looks ok

Browse files
Files changed (3) hide show
  1. app.py +370 -273
  2. app_v1_backup.py +22 -3
  3. failed_app.py +337 -0
app.py CHANGED
@@ -1,21 +1,17 @@
 
1
  import os
2
- import random
3
-
4
- import huggingface_hub
5
- import numpy as np
6
- from datasets import load_dataset, Dataset
7
- from dotenv import load_dotenv
8
- from pytorch_lightning import LightningDataModule
9
- from pytorch_lightning.utilities.types import TRAIN_DATALOADERS, EVAL_DATALOADERS
10
- from torch.utils.data import DataLoader, IterableDataset
11
- from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, roc_auc_score
12
- # from torchmetrics.classification import BinaryAccuracy, BinaryAUROC, BinaryF1Score, BinaryPrecision, BinaryRecall
13
- from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoModel, BertModel
14
- from transformers import TrainingArguments, Trainer
15
  import torch
16
- import logging
17
- import wandb
18
-
19
  timber = logging.getLogger()
20
  # logging.basicConfig(level=logging.DEBUG)
21
  logging.basicConfig(level=logging.INFO) # change to level=logging.DEBUG to print more logs...
@@ -32,36 +28,273 @@ white = "\u001b[37m"
32
  FORWARD = "FORWARD_INPUT"
33
  BACKWARD = "BACKWARD_INPUT"
34
 
35
- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
36
 
37
- PRETRAINED_MODEL_NAME: str = "zhihan1996/DNA_bert_6"
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
 
40
- def insert_debug_motif_at_random_position(seq, DEBUG_MOTIF):
41
- start = 0
42
- end = len(seq)
43
- rand_pos = random.randrange(start, (end - len(DEBUG_MOTIF)))
44
- random_end = rand_pos + len(DEBUG_MOTIF)
45
- output = seq[start: rand_pos] + DEBUG_MOTIF + seq[random_end: end]
46
- assert len(seq) == len(output)
47
- return output
48
 
49
 
50
- class PagingMQTLDataset(IterableDataset):
51
  def __init__(self,
52
- m_dataset,
53
- seq_len,
54
- tokenizer,
55
- max_length=512,
56
- check_if_pipeline_is_ok_by_inserting_debug_motif=False):
57
- self.dataset = m_dataset
58
- self.check_if_pipeline_is_ok_by_inserting_debug_motif = check_if_pipeline_is_ok_by_inserting_debug_motif
59
- self.debug_motif = "ATCGCCTA"
60
- self.seq_len = seq_len
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  self.bert_tokenizer = tokenizer
63
  self.max_length = max_length
64
- pass
 
 
65
 
66
  def __iter__(self):
67
  for row in self.dataset:
@@ -71,267 +304,131 @@ class PagingMQTLDataset(IterableDataset):
71
 
72
  def preprocess(self, row):
73
  sequence = row['sequence'] # Fetch the 'sequence' column
74
- if len(sequence) != self.seq_len:
75
- return None # skip problematic row!
76
  label = row['label'] # Fetch the 'label' column (or whatever target you use)
77
- if label == 1 and self.check_if_pipeline_is_ok_by_inserting_debug_motif:
78
- sequence = insert_debug_motif_at_random_position(seq=sequence, DEBUG_MOTIF=self.debug_motif)
79
 
80
- input_ids = self.bert_tokenizer(sequence)["input_ids"]
81
- tokenized_tensor = torch.tensor(input_ids)
82
- label_tensor = torch.tensor(label)
83
- output_dict = {"input_ids": tokenized_tensor, "labels": label_tensor} # so this is now you do it?
84
- return output_dict # tokenized_tensor, label_tensor
 
 
 
 
 
 
85
 
86
 
87
- class MqtlDataModule(LightningDataModule):
88
- def __init__(self, train_ds, val_ds, test_ds, batch_size=16):
89
  super().__init__()
 
 
 
 
 
 
90
  self.batch_size = batch_size
91
- self.train_loader = DataLoader(train_ds, batch_size=self.batch_size, shuffle=False,
92
- # collate_fn=collate_fn,
93
- num_workers=1,
94
- # persistent_workers=True
95
- )
96
- self.validate_loader = DataLoader(val_ds, batch_size=self.batch_size, shuffle=False,
97
- # collate_fn=collate_fn,
98
- num_workers=1,
99
- # persistent_workers=True
100
- )
101
- self.test_loader = DataLoader(test_ds, batch_size=self.batch_size, shuffle=False,
102
- # collate_fn=collate_fn,
103
- num_workers=1,
104
- # persistent_workers=True
105
- )
106
- pass
107
 
108
  def prepare_data(self):
109
- pass
110
-
111
- def setup(self, stage: str) -> None:
112
- timber.info(f"inside setup: {stage = }")
113
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
- def train_dataloader(self) -> TRAIN_DATALOADERS:
116
- return self.train_loader
117
 
118
- def val_dataloader(self) -> EVAL_DATALOADERS:
119
- return self.validate_loader
120
 
121
- def test_dataloader(self) -> EVAL_DATALOADERS:
122
- return self.test_loader
123
-
124
-
125
- def create_paging_train_val_test_datasets(tokenizer, WINDOW, is_debug, batch_size=1000):
126
- data_files = {
127
- # small samples
128
- "train_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_train_binned.csv",
129
- "validate_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_validate_binned.csv",
130
- "test_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_test_binned.csv",
131
- # medium samples
132
- "train_binned_1000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_1000_train_binned.csv",
133
- "validate_binned_1000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_1000_validate_binned.csv",
134
- "test_binned_1000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_1000_test_binned.csv",
135
-
136
- # large samples
137
- "train_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_train_binned.csv",
138
- "validate_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_validate_binned.csv",
139
- "test_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_test_binned.csv",
140
- }
141
-
142
- dataset_map = None
143
  is_my_laptop = os.path.isfile("/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_train_binned.csv")
144
- if is_my_laptop:
145
- dataset_map = load_dataset("csv", data_files=data_files, streaming=True)
146
- else:
147
- dataset_map = load_dataset("fahimfarhan/mqtl-classification-datasets", streaming=True)
148
-
149
- train_dataset = PagingMQTLDataset(dataset_map[f"train_binned_{WINDOW}"],
150
- check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug,
151
- tokenizer=tokenizer,
152
- seq_len=WINDOW
153
- )
154
- val_dataset = PagingMQTLDataset(dataset_map[f"validate_binned_{WINDOW}"],
155
- check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug,
156
- tokenizer=tokenizer,
157
- seq_len=WINDOW)
158
- test_dataset = PagingMQTLDataset(dataset_map[f"test_binned_{WINDOW}"],
159
- check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug,
160
- tokenizer=tokenizer,
161
- seq_len=WINDOW)
162
- # data_module = MqtlDataModule(train_ds=train_dataset, val_ds=val_dataset, test_ds=test_dataset, batch_size=batch_size)
163
- return train_dataset, val_dataset, test_dataset
164
-
165
-
166
- def login_inside_huggingface_virtualmachine():
167
- # Load the .env file, but don't crash if it's not found (e.g., in Hugging Face Space)
168
- try:
169
- load_dotenv() # Only useful on your laptop if .env exists
170
- print(".env file loaded successfully.")
171
- except Exception as e:
172
- print(f"Warning: Could not load .env file. Exception: {e}")
173
-
174
- # Try to get the token from environment variables
175
- try:
176
- token = os.getenv("HF_TOKEN")
177
-
178
- if not token:
179
- raise ValueError("HF_TOKEN not found. Make sure to set it in the environment variables or .env file.")
180
-
181
- # Log in to Hugging Face Hub
182
- huggingface_hub.login(token)
183
- print("Logged in to Hugging Face Hub successfully.")
184
-
185
- except Exception as e:
186
- print(f"Error during Hugging Face login: {e}")
187
- # Handle the error appropriately (e.g., exit or retry)
188
-
189
- # wand db login
190
- try:
191
- api_key = os.getenv("WAND_DB_API_KEY")
192
- timber.info(f"{api_key = }")
193
-
194
- if not api_key:
195
- raise ValueError("WAND_DB_API_KEY not found. Make sure to set it in the environment variables or .env file.")
196
-
197
- # Log in to Hugging Face Hub
198
- wandb.login(key=api_key)
199
- print("Logged in to wand db successfully.")
200
-
201
- except Exception as e:
202
- print(f"Error during wand db Face login: {e}")
203
- pass
204
 
 
205
 
206
- # use sklearn cz torchmetrics.classification gave array index out of bound exception :/ (whatever it is called in python)
207
- def compute_metrics_using_sklearn(p):
208
- try:
209
- pred, labels = p
210
 
211
- # Get predicted class labels
212
- pred_labels = np.argmax(pred, axis=1)
 
213
 
214
- # Get predicted probabilities for the positive class
215
- pred_probs = pred[:, 1] # Assuming binary classification and 2 output classes
216
 
217
- accuracy = accuracy_score(y_true=labels, y_pred=pred_labels)
218
- recall = recall_score(y_true=labels, y_pred=pred_labels)
219
- precision = precision_score(y_true=labels, y_pred=pred_labels)
220
- f1 = f1_score(y_true=labels, y_pred=pred_labels)
221
- roc_auc = roc_auc_score(y_true=labels, y_score=pred_probs)
222
 
223
- return {"accuracy": accuracy, "roc_auc": roc_auc, "precision": precision, "recall": recall, "f1": f1}
 
 
224
 
225
- except Exception as x:
226
- print(f"compute_metrics_using_sklearn failed with exception: {x}")
227
- return {"accuracy": 0, "roc_auc": 0, "precision": 0, "recall": 0, "f1": 0}
228
 
 
 
 
 
229
 
230
- def start():
231
- os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
232
 
233
- login_inside_huggingface_virtualmachine()
234
- WINDOW = 4000
235
- batch_size = 100
236
- model_local_directory = f"my-awesome-model-{WINDOW}"
237
- model_remote_repository = f"fahimfarhan/dnabert-6-mqtl-classifier-{WINDOW}"
238
-
239
- is_my_laptop = os.path.isfile("/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_train_binned.csv")
240
 
241
- tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME, trust_remote_code=True)
242
- classifier_model = AutoModelForSequenceClassification.from_pretrained(PRETRAINED_MODEL_NAME, num_labels=2)
243
- args = {
244
- "output_dir": "output_dnabert-6-mqtl_classification",
245
- "num_train_epochs": 1,
246
- "max_steps": 20_000, # train 36k + val 4k = 40k
247
- # Set the number of steps you expect to train, originally 1000, takes too much time. So I set it to 10 to run faster and check my code/pipeline
248
- "run_name": "laptop_run_dna-bert-6-mqtl_classification", # Override run_name here
249
- "per_device_train_batch_size": 1,
250
- "gradient_accumulation_steps": 32,
251
- "gradient_checkpointing": True,
252
- "learning_rate": 1e-3,
253
- "save_safetensors": False, # I added it. this solves the runtime error!
254
- # not sure if it is a good idea. sklearn may slow down training, causing time loss... if so, disable these 2 lines below
255
- "evaluation_strategy": "epoch", # To calculate metrics per epoch
256
- "logging_strategy": "epoch" # Extra: to log training data stats for loss
257
- }
258
-
259
- training_args = TrainingArguments(**args)
260
- # train_dataset, eval_dataset, test_dataset = create_data_module(tokenizer=tokenizer, WINDOW=WINDOW,
261
- # batch_size=batch_size,
262
- # is_debug=False)
263
- """ # example code
264
- max_length = 32_000
265
- sequence = 'ACTG' * int(max_length / 4)
266
- # sequence = 'ACTG' * int(1000) # seq_len = 4000 it works!
267
- sequence = [sequence] * 8 # Create 8 identical samples
268
- tokenized = tokenizer(sequence)["input_ids"]
269
- labels = [0, 1] * 4
270
-
271
- # Create a dataset for training
272
- run_the_code_ds = Dataset.from_dict({"input_ids": tokenized, "labels": labels})
273
- run_the_code_ds.set_format("pt")
274
- """
275
-
276
- train_ds, val_ds, test_ds = create_paging_train_val_test_datasets(tokenizer, WINDOW=WINDOW, is_debug=False)
277
- # train_ds, val_ds, test_ds = run_the_code_ds, run_the_code_ds, run_the_code_ds
278
- # train_ds.set_format("pt") # doesn't work!
279
-
280
- trainer = Trainer(
281
- model=classifier_model,
282
- args=training_args,
283
- train_dataset=train_ds,
284
- eval_dataset=val_ds,
285
- compute_metrics=compute_metrics_using_sklearn # torch_metrics.compute_metrics
286
  )
287
- # train, and validate
288
- result = trainer.train()
289
- try:
290
- print(f"{result = }")
291
- except Exception as x:
292
- print(f"{x = }")
293
-
294
- # testing
295
- try:
296
- # with torch.no_grad(): # didn't work :/
297
- test_results = trainer.evaluate(eval_dataset=test_ds)
298
- print(f"{test_results = }")
299
- except Exception as oome:
300
- print(f"{oome = }")
301
- finally:
302
- # save the model
303
- model_name = "DnaBert6MQtlClassifier"
304
-
305
- classifier_model.save_pretrained(save_directory=model_local_directory, safe_serialization=False)
306
-
307
- # push to the hub
308
- commit_message = f":tada: Push model for window size {WINDOW} from huggingface space"
309
- if is_my_laptop:
310
- commit_message = f":tada: Push model for window size {WINDOW} from zephyrus"
311
-
312
- classifier_model.push_to_hub(
313
- repo_id=model_remote_repository,
314
- # subfolder=f"my-awesome-model-{WINDOW}", subfolder didn't work :/
315
- commit_message=commit_message, # f":tada: Push model for window size {WINDOW}"
316
- safe_serialization=False
317
- )
318
  pass
319
 
320
 
321
- def interprete_demo():
322
- is_my_laptop = True
323
- WINDOW = 4000
324
- batch_size = 100
325
- model_local_directory = f"my-awesome-model-{WINDOW}"
326
- model_remote_repository = f"fahimfarhan/dnabert-6-mqtl-classifier-{WINDOW}"
327
-
328
- try:
329
- classifier_model = AutoModel.from_pretrained(model_remote_repository)
330
- # todo: use captum / gentech-grelu to interpret the model
331
- except Exception as x:
332
- print(x)
333
-
334
-
335
- if __name__ == '__main__':
336
- start()
337
  pass
 
1
+ import logging
2
  import os
3
+ from typing import Any
4
+
5
+ from huggingface_hub import PyTorchModelHubMixin
6
+ from pytorch_lightning import Trainer, LightningModule, LightningDataModule
7
+ from pytorch_lightning.utilities.types import OptimizerLRScheduler, STEP_OUTPUT, EVAL_DATALOADERS
8
+ from torch.utils.data import DataLoader, Dataset, IterableDataset
9
+ from torcheval.metrics import BinaryAccuracy, BinaryAUROC, BinaryF1Score, BinaryPrecision, BinaryRecall
10
+ from transformers import BertModel, BatchEncoding, BertTokenizer, TrainingArguments
11
+ from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
 
 
 
 
12
  import torch
13
+ from torch import nn
14
+ from datasets import load_dataset
 
15
  timber = logging.getLogger()
16
  # logging.basicConfig(level=logging.DEBUG)
17
  logging.basicConfig(level=logging.INFO) # change to level=logging.DEBUG to print more logs...
 
28
  FORWARD = "FORWARD_INPUT"
29
  BACKWARD = "BACKWARD_INPUT"
30
 
31
+ DNA_BERT_6 = "zhihan1996/DNA_bert_6"
32
+
33
 
34
+ class CommonAttentionLayer(nn.Module):
35
+ def __init__(self, hidden_size, *args, **kwargs):
36
+ super().__init__(*args, **kwargs)
37
+ self.attention_linear = nn.Linear(hidden_size, 1)
38
+ pass
39
+
40
+ def forward(self, hidden_states):
41
+ # Apply linear layer
42
+ attn_weights = self.attention_linear(hidden_states)
43
+ # Apply softmax to get attention scores
44
+ attn_weights = torch.softmax(attn_weights, dim=1)
45
+ # Apply attention weights to hidden states
46
+ context_vector = torch.sum(attn_weights * hidden_states, dim=1)
47
+ return context_vector, attn_weights
48
 
49
 
50
+ class ReshapedBCEWithLogitsLoss(nn.BCEWithLogitsLoss):
51
+ def forward(self, input, target):
52
+ return super().forward(input.squeeze(), target.float())
 
 
 
 
 
53
 
54
 
55
+ class MQtlDnaBERT6Classifier(nn.Module, PyTorchModelHubMixin):
56
  def __init__(self,
57
+ bert_model=BertModel.from_pretrained(pretrained_model_name_or_path=DNA_BERT_6),
58
+ hidden_size=768,
59
+ num_classes=1,
60
+ *args,
61
+ **kwargs
62
+ ):
63
+ super().__init__(*args, **kwargs)
64
+
65
+ self.model_name = "MQtlDnaBERT6Classifier"
66
+
67
+ self.bert_model = bert_model
68
+ self.attention = CommonAttentionLayer(hidden_size)
69
+ self.classifier = nn.Linear(hidden_size, num_classes)
70
+ pass
71
+
72
+ def forward(self, input_ids: torch.tensor, attention_mask: torch.tensor, token_type_ids):
73
+ """
74
+ # torch.Size([128, 1, 512]) --> [128, 512]
75
+ input_ids = input_ids.squeeze(dim=1).to(DEVICE)
76
+ # torch.Size([16, 1, 512]) --> [16, 512]
77
+ attention_mask = attention_mask.squeeze(dim=1).to(DEVICE)
78
+ token_type_ids = token_type_ids.squeeze(dim=1).to(DEVICE)
79
+ """
80
+ bert_output: BaseModelOutputWithPoolingAndCrossAttentions = self.bert_model(
81
+ input_ids=input_ids,
82
+ attention_mask=attention_mask,
83
+ token_type_ids=token_type_ids
84
+ )
85
+
86
+ last_hidden_state = bert_output.last_hidden_state
87
+ context_vector, ignore_attention_weight = self.attention(last_hidden_state)
88
+ y = self.classifier(context_vector)
89
+ return y
90
+
91
+ """
92
+ class TorchMetrics:
93
+ def __init__(self):
94
+ self.binary_accuracy = BinaryAccuracy() #.to(device)
95
+ self.binary_auc = BinaryAUROC() # .to(device)
96
+ self.binary_f1_score = BinaryF1Score() # .to(device)
97
+ self.binary_precision = BinaryPrecision() # .to(device)
98
+ self.binary_recall = BinaryRecall() # .to(device)
99
+ pass
100
+
101
+ def update_on_each_step(self, batch_predicted_labels, batch_actual_labels): # todo: Add log if needed
102
+ # it looks like the library maintainers changed preds to input, ie, before: preds, now: input
103
+ self.binary_accuracy.update(input=batch_predicted_labels, target=batch_actual_labels)
104
+ self.binary_auc.update(input=batch_predicted_labels, target=batch_actual_labels)
105
+ self.binary_f1_score.update(input=batch_predicted_labels, target=batch_actual_labels)
106
+ self.binary_precision.update(input=batch_predicted_labels, target=batch_actual_labels)
107
+ self.binary_recall.update(input=batch_predicted_labels, target=batch_actual_labels)
108
+ pass
109
+
110
+ def compute_and_log_on_each_step(self, log, log_prefix: str, log_color: str = green):
111
+ b_accuracy = self.binary_accuracy.compute()
112
+ b_auc = self.binary_auc.compute()
113
+ b_f1_score = self.binary_f1_score.compute()
114
+ b_precision = self.binary_precision.compute()
115
+ b_recall = self.binary_recall.compute()
116
+ timber.info(log_color + f"{log_prefix}_acc = {b_accuracy}, {log_prefix}_auc = {b_auc}, {log_prefix}_f1_score = {b_f1_score}, {log_prefix}_precision = {b_precision}, {log_prefix}_recall = {b_recall}")
117
+ log(f"{log_prefix}_accuracy", b_accuracy)
118
+ log(f"{log_prefix}_auc", b_auc)
119
+ log(f"{log_prefix}_f1_score", b_f1_score)
120
+ log(f"{log_prefix}_precision", b_precision)
121
+ log(f"{log_prefix}_recall", b_recall)
122
+
123
+ # def reset_on_epoch_end(self):
124
+ # self.binary_accuracy.reset()
125
+ # self.binary_auc.reset()
126
+ # self.binary_f1_score.reset()
127
+ # self.binary_precision.reset()
128
+ # self.binary_recall.reset()
129
+
130
+ def compute_and_reset_on_epoch_end(self, log, log_prefix: str, log_color: str = green):
131
+ b_accuracy = self.binary_accuracy.compute()
132
+ b_auc = self.binary_auc.compute()
133
+ b_f1_score = self.binary_f1_score.compute()
134
+ b_precision = self.binary_precision.compute()
135
+ b_recall = self.binary_recall.compute()
136
+ timber.info( log_color + f"{log_prefix}_acc = {b_accuracy}, {log_prefix}_auc = {b_auc}, {log_prefix}_f1_score = {b_f1_score}, {log_prefix}_precision = {b_precision}, {log_prefix}_recall = {b_recall}")
137
+ log(f"{log_prefix}_accuracy", b_accuracy)
138
+ log(f"{log_prefix}_auc", b_auc)
139
+ log(f"{log_prefix}_f1_score", b_f1_score)
140
+ log(f"{log_prefix}_precision", b_precision)
141
+ log(f"{log_prefix}_recall", b_recall)
142
+
143
+ self.binary_accuracy.reset()
144
+ self.binary_auc.reset()
145
+ self.binary_f1_score.reset()
146
+ self.binary_precision.reset()
147
+ self.binary_recall.reset()
148
+ pass
149
+ """
150
+
151
+
152
+ class TorchMetrics:
153
+ def __init__(self):
154
+ self.binary_accuracy = BinaryAccuracy() #.to(device)
155
+ self.binary_auc = BinaryAUROC() # .to(device)
156
+ self.binary_f1_score = BinaryF1Score() # .to(device)
157
+ self.binary_precision = BinaryPrecision() # .to(device)
158
+ self.binary_recall = BinaryRecall() # .to(device)
159
+ pass
160
+
161
+ def update_on_each_step(self, batch_predicted_labels, batch_actual_labels): # todo: Add log if needed
162
+ # it looks like the library maintainers changed preds to input, ie, before: preds, now: input
163
+ self.binary_accuracy.update(input=batch_predicted_labels, target=batch_actual_labels)
164
+ self.binary_auc.update(input=batch_predicted_labels, target=batch_actual_labels)
165
+ self.binary_f1_score.update(input=batch_predicted_labels, target=batch_actual_labels)
166
+ self.binary_precision.update(input=batch_predicted_labels, target=batch_actual_labels)
167
+ self.binary_recall.update(input=batch_predicted_labels, target=batch_actual_labels)
168
+ pass
169
+
170
+ def compute_and_reset_on_epoch_end(self, log, log_prefix: str, log_color: str = green):
171
+ b_accuracy = self.binary_accuracy.compute()
172
+ b_auc = self.binary_auc.compute()
173
+ b_f1_score = self.binary_f1_score.compute()
174
+ b_precision = self.binary_precision.compute()
175
+ b_recall = self.binary_recall.compute()
176
+ timber.info( log_color + f"{log_prefix}_acc = {b_accuracy}, {log_prefix}_auc = {b_auc}, {log_prefix}_f1_score = {b_f1_score}, {log_prefix}_precision = {b_precision}, {log_prefix}_recall = {b_recall}")
177
+ log(f"{log_prefix}_accuracy", b_accuracy)
178
+ log(f"{log_prefix}_auc", b_auc)
179
+ log(f"{log_prefix}_f1_score", b_f1_score)
180
+ log(f"{log_prefix}_precision", b_precision)
181
+ log(f"{log_prefix}_recall", b_recall)
182
+
183
+ self.binary_accuracy.reset()
184
+ self.binary_auc.reset()
185
+ self.binary_f1_score.reset()
186
+ self.binary_precision.reset()
187
+ self.binary_recall.reset()
188
+ pass
189
+
190
+
191
+
192
+ class MQtlBertClassifierLightningModule(LightningModule):
193
+ def __init__(self,
194
+ classifier: nn.Module,
195
+ criterion=None, # nn.BCEWithLogitsLoss(),
196
+ regularization: int = 2, # 1 == L1, 2 == L2, 3 (== 1 | 2) == both l1 and l2, else ignore / don't care
197
+ l1_lambda=0.001,
198
+ l2_wright_decay=0.001,
199
+ *args: Any,
200
+ **kwargs: Any):
201
+ super().__init__(*args, **kwargs)
202
+ self.classifier = classifier
203
+ self.criterion = criterion
204
+ self.train_metrics = TorchMetrics()
205
+ self.validate_metrics = TorchMetrics()
206
+ self.test_metrics = TorchMetrics()
207
+
208
+ self.regularization = regularization
209
+ self.l1_lambda = l1_lambda
210
+ self.l2_weight_decay = l2_wright_decay
211
+ pass
212
+
213
+ def forward(self, x, *args: Any, **kwargs: Any) -> Any:
214
+ input_ids: torch.tensor = x["input_ids"]
215
+ attention_mask: torch.tensor = x["attention_mask"]
216
+ token_type_ids: torch.tensor = x["token_type_ids"]
217
+ # print(f"\n{ type(input_ids) = }, {input_ids = }")
218
+ # print(f"{ type(attention_mask) = }, { attention_mask = }")
219
+ # print(f"{ type(token_type_ids) = }, { token_type_ids = }")
220
+
221
+ return self.classifier.forward(input_ids, attention_mask, token_type_ids)
222
+
223
+ def configure_optimizers(self) -> OptimizerLRScheduler:
224
+ # Here we add weight decay (L2 regularization) to the optimizer
225
+ weight_decay = 0.0
226
+ if self.regularization == 2 or self.regularization == 3:
227
+ weight_decay = self.l2_weight_decay
228
+ return torch.optim.Adam(self.parameters(), lr=1e-3, weight_decay=weight_decay) # , weight_decay=0.005)
229
+
230
+ def training_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
231
+ # Accuracy on training batch data
232
+ x, y = batch
233
+ preds = self.forward(x)
234
+ loss = self.criterion(preds, y)
235
+
236
+ if self.regularization == 1 or self.regularization == 3: # apply l1 regularization
237
+ l1_norm = sum(p.abs().sum() for p in self.parameters())
238
+ loss += self.l1_lambda * l1_norm
239
+
240
+ self.log("train_loss", loss)
241
+ # calculate the scores start
242
+ self.train_metrics.update_on_each_step(batch_predicted_labels=preds.squeeze(), batch_actual_labels=y)
243
+ # self.train_metrics.compute_and_log_on_each_step(log=self.log, log_prefix="train")
244
+ # calculate the scores end
245
+ return loss
246
+
247
+ def on_train_epoch_end(self) -> None:
248
+ self.train_metrics.compute_and_reset_on_epoch_end(log=self.log, log_prefix="train")
249
+ pass
250
 
251
+ def validation_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
252
+ # Accuracy on validation batch data
253
+ # print(f"debug { batch = }")
254
+ x, y = batch
255
+ preds = self.forward(x)
256
+ loss = self.criterion(preds, y)
257
+ """ loss = 0 # <------------------------- maybe the loss calculation is problematic """
258
+ self.log("valid_loss", loss)
259
+ # calculate the scores start
260
+ self.validate_metrics.update_on_each_step(batch_predicted_labels=preds.squeeze(), batch_actual_labels=y)
261
+ # self.validate_metrics.compute_and_log_on_each_step(log=self.log, log_prefix="validate", log_color=blue)
262
+
263
+ # calculate the scores end
264
+ return loss
265
+
266
+ def on_validation_epoch_end(self) -> None:
267
+ self.validate_metrics.compute_and_reset_on_epoch_end(log=self.log, log_prefix="validate", log_color=blue)
268
+ return None
269
+
270
+ def test_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
271
+ # Accuracy on validation batch data
272
+ x, y = batch
273
+ preds = self.forward(x)
274
+ loss = self.criterion(preds, y)
275
+ self.log("test_loss", loss) # do we need this?
276
+ # calculate the scores start
277
+ self.test_metrics.update_on_each_step(batch_predicted_labels=preds.squeeze(), batch_actual_labels=y)
278
+ # self.test_metrics.compute_and_log_on_each_step(log=self.log, log_prefix="test", log_color=magenta)
279
+
280
+ # calculate the scores end
281
+ return loss
282
+
283
+ def on_test_epoch_end(self) -> None:
284
+ self.test_metrics.compute_and_reset_on_epoch_end(log=self.log, log_prefix="test", log_color=magenta)
285
+ return None
286
+
287
+ pass
288
+
289
+
290
+ class PagingMQTLDnaBertDataset(IterableDataset):
291
+ def __init__(self, dataset, tokenizer, max_length=512): # hold on! why is it 512? I added 4000, and it crashed, the error suggested 512, that's why 512
292
+ self.dataset = dataset
293
  self.bert_tokenizer = tokenizer
294
  self.max_length = max_length
295
+
296
+ # def __len__(self):
297
+ # return len(self.dataset)
298
 
299
  def __iter__(self):
300
  for row in self.dataset:
 
304
 
305
  def preprocess(self, row):
306
  sequence = row['sequence'] # Fetch the 'sequence' column
 
 
307
  label = row['label'] # Fetch the 'label' column (or whatever target you use)
 
 
308
 
309
+ # Tokenize the sequence
310
+ encoded_sequence: BatchEncoding = self.bert_tokenizer(
311
+ sequence,
312
+ truncation=True,
313
+ padding='max_length',
314
+ max_length=self.max_length,
315
+ return_tensors='pt'
316
+ )
317
+
318
+ encoded_sequence_squeezed = {key: val.squeeze() for key, val in encoded_sequence.items()}
319
+ return encoded_sequence_squeezed, label
320
 
321
 
322
+ class DNABERTDataModule(LightningDataModule):
323
+ def __init__(self, model_name=DNA_BERT_6, batch_size=8, WINDOW=-1, is_local=False):
324
  super().__init__()
325
+ self.tokenized_dataset = None
326
+ self.dataset = None
327
+ self.train_dataset: PagingMQTLDnaBertDataset = None
328
+ self.validate_dataset: PagingMQTLDnaBertDataset = None
329
+ self.test_dataset: PagingMQTLDnaBertDataset = None
330
+ self.tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)
331
  self.batch_size = batch_size
332
+ self.is_local = is_local
333
+ self.window = WINDOW
 
 
 
 
 
 
 
 
 
 
 
 
 
 
334
 
335
  def prepare_data(self):
336
+ # Download and prepare dataset
337
+ data_files = {
338
+ # small samples
339
+ "train_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_train_binned.csv",
340
+ "validate_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_validate_binned.csv",
341
+ "test_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_test_binned.csv",
342
+ # medium samples
343
+ "train_binned_1000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_1000_train_binned.csv",
344
+ "validate_binned_1000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_1000_validate_binned.csv",
345
+ "test_binned_1000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_1000_test_binned.csv",
346
+
347
+ # large samples
348
+ "train_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_train_binned.csv",
349
+ "validate_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_validate_binned.csv",
350
+ "test_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_test_binned.csv",
351
+
352
+ # really tiny
353
+ # "tiny_train": "/home/soumic/Codes/mqtl-classification/src/inputdata/tiny_dataset_4000_train_binned.csv",
354
+ # "tiny_validate": "/home/soumic/Codes/mqtl-classification/src/inputdata/tiny_dataset_4000_validate_binned.csv",
355
+ # "tiny_test": "/home/soumic/Codes/mqtl-classification/src/inputdata/tiny_dataset_4000_test_binned.csv",
356
+
357
+ }
358
+ if self.is_local:
359
+ self.dataset = load_dataset("csv", data_files=data_files, streaming=True)
360
+ else:
361
+ self.dataset = load_dataset("fahimfarhan/mqtl-classification-datasets")
362
+
363
+ def setup(self, stage=None):
364
+ self.train_dataset = PagingMQTLDnaBertDataset(self.dataset['train_binned_4000'], self.tokenizer)
365
+ self.validate_dataset = PagingMQTLDnaBertDataset(self.dataset['validate_binned_4000'], self.tokenizer)
366
+ self.test_dataset = PagingMQTLDnaBertDataset(self.dataset['test_binned_4000'], self.tokenizer)
367
+
368
+ def train_dataloader(self):
369
+ return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=1)
370
+
371
+ def val_dataloader(self):
372
+ return DataLoader(self.validate_dataset, batch_size=self.batch_size, num_workers=1)
373
 
374
+ def test_dataloader(self) -> EVAL_DATALOADERS:
375
+ return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=1)
376
 
 
 
377
 
378
+ def start_bert(classifier_model, model_save_path, criterion, WINDOW, batch_size=4,
379
+ dataset_folder_prefix="inputdata/", is_binned=True, is_debug=False, max_epochs=10):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
380
  is_my_laptop = os.path.isfile("/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_train_binned.csv")
381
+ model_local_directory = f"my-awesome-model-{WINDOW}"
382
+ model_remote_repository = f"fahimfarhan/dnabert-6-mqtl-classifier-{WINDOW}"
383
+ file_suffix = ""
384
+ if is_binned:
385
+ file_suffix = "_binned"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
386
 
387
+ data_module = DNABERTDataModule(batch_size=batch_size, WINDOW=WINDOW, is_local=is_my_laptop)
388
 
389
+ # classifier_model = classifier_model.to(DEVICE)
 
 
 
390
 
391
+ classifier_module = MQtlBertClassifierLightningModule(
392
+ classifier=classifier_model,
393
+ regularization=2, criterion=criterion)
394
 
395
+ # if os.path.exists(model_save_path):
396
+ # classifier_module.load_state_dict(torch.load(model_save_path))
397
 
398
+ classifier_module = classifier_module # .double()
 
 
 
 
399
 
400
+ # Prepare data using the DataModule
401
+ data_module.prepare_data()
402
+ data_module.setup()
403
 
404
+ trainer = Trainer(max_epochs=max_epochs, precision="32")
 
 
405
 
406
+ # Train the model
407
+ trainer.fit(model=classifier_module, datamodule=data_module)
408
+ trainer.test(model=classifier_module, datamodule=data_module)
409
+ torch.save(classifier_module.state_dict(), model_save_path)
410
 
411
+ # classifier_module.push_to_hub("fahimfarhan/mqtl-classifier-model")
 
412
 
413
+ classifier_model.save_pretrained(save_directory=model_local_directory, safe_serialization=False)
414
+ # push to the hub
415
+ commit_message = f":tada: Push model for window size {WINDOW} from huggingface space"
416
+ if is_my_laptop:
417
+ commit_message = f":tada: Push model for window size {WINDOW} from zephyrus"
 
 
418
 
419
+ classifier_model.push_to_hub(
420
+ repo_id=model_remote_repository,
421
+ # subfolder=f"my-awesome-model-{WINDOW}", subfolder didn't work :/
422
+ commit_message=commit_message, # f":tada: Push model for window size {WINDOW}"
423
+ # safe_serialization=False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
424
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
425
  pass
426
 
427
 
428
+ if __name__ == "__main__":
429
+ dataset_folder_prefix = "inputdata/"
430
+ pytorch_model = MQtlDnaBERT6Classifier()
431
+ start_bert(classifier_model=pytorch_model, model_save_path=f"weights_{pytorch_model.model_name}.pth",
432
+ criterion=ReshapedBCEWithLogitsLoss(), WINDOW=4000, batch_size=12, # max 14 on my laptop...
433
+ dataset_folder_prefix=dataset_folder_prefix, max_epochs=1)
 
 
 
 
 
 
 
 
 
 
434
  pass
app_v1_backup.py CHANGED
@@ -1,5 +1,7 @@
 
1
  from typing import Any
2
 
 
3
  from pytorch_lightning import Trainer, LightningModule, LightningDataModule
4
  from pytorch_lightning.utilities.types import OptimizerLRScheduler, STEP_OUTPUT, EVAL_DATALOADERS
5
  from torch.utils.data import DataLoader, Dataset
@@ -46,7 +48,7 @@ class ReshapedBCEWithLogitsLoss(nn.BCEWithLogitsLoss):
46
  return super().forward(input.squeeze(), target.float())
47
 
48
 
49
- class MQtlDnaBERT6Classifier(nn.Module):
50
  def __init__(self,
51
  bert_model=BertModel.from_pretrained(pretrained_model_name_or_path=DNA_BERT_6),
52
  hidden_size=768,
@@ -277,6 +279,10 @@ data_module = DNABERTDataModule(model_name=model_name, batch_size=8)
277
 
278
  def start_bert(classifier_model, model_save_path, criterion, WINDOW=200, batch_size=4,
279
  dataset_folder_prefix="inputdata/", is_binned=True, is_debug=False, max_epochs=10):
 
 
 
 
280
  file_suffix = ""
281
  if is_binned:
282
  file_suffix = "_binned"
@@ -324,7 +330,20 @@ def start_bert(classifier_model, model_save_path, criterion, WINDOW=200, batch_s
324
  trainer.test(model=classifier_module, datamodule=data_module)
325
  torch.save(classifier_module.state_dict(), model_save_path)
326
 
327
- classifier_module.push_to_hub("fahimfarhan/mqtl-classifier-model")
 
 
 
 
 
 
 
 
 
 
 
 
 
328
  pass
329
 
330
 
@@ -332,6 +351,6 @@ if __name__ == "__main__":
332
  dataset_folder_prefix = "inputdata/"
333
  pytorch_model = MQtlDnaBERT6Classifier()
334
  start_bert(classifier_model=pytorch_model, model_save_path=f"weights_{pytorch_model.model_name}.pth",
335
- criterion=ReshapedBCEWithLogitsLoss(), WINDOW=200, batch_size=4,
336
  dataset_folder_prefix=dataset_folder_prefix, max_epochs=2)
337
  pass
 
1
+ import os
2
  from typing import Any
3
 
4
+ from huggingface_hub import PyTorchModelHubMixin
5
  from pytorch_lightning import Trainer, LightningModule, LightningDataModule
6
  from pytorch_lightning.utilities.types import OptimizerLRScheduler, STEP_OUTPUT, EVAL_DATALOADERS
7
  from torch.utils.data import DataLoader, Dataset
 
48
  return super().forward(input.squeeze(), target.float())
49
 
50
 
51
+ class MQtlDnaBERT6Classifier(nn.Module, PyTorchModelHubMixin):
52
  def __init__(self,
53
  bert_model=BertModel.from_pretrained(pretrained_model_name_or_path=DNA_BERT_6),
54
  hidden_size=768,
 
279
 
280
  def start_bert(classifier_model, model_save_path, criterion, WINDOW=200, batch_size=4,
281
  dataset_folder_prefix="inputdata/", is_binned=True, is_debug=False, max_epochs=10):
282
+
283
+ is_my_laptop = os.path.isfile("/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_train_binned.csv")
284
+ model_local_directory = f"my-awesome-model-{WINDOW}"
285
+ model_remote_repository = f"fahimfarhan/dnabert-6-mqtl-classifier-{WINDOW}"
286
  file_suffix = ""
287
  if is_binned:
288
  file_suffix = "_binned"
 
330
  trainer.test(model=classifier_module, datamodule=data_module)
331
  torch.save(classifier_module.state_dict(), model_save_path)
332
 
333
+ # classifier_module.push_to_hub("fahimfarhan/mqtl-classifier-model")
334
+
335
+ classifier_model.save_pretrained(save_directory=model_local_directory, safe_serialization=False)
336
+ # push to the hub
337
+ commit_message = f":tada: Push model for window size {WINDOW} from huggingface space"
338
+ if is_my_laptop:
339
+ commit_message = f":tada: Push model for window size {WINDOW} from zephyrus"
340
+
341
+ classifier_model.push_to_hub(
342
+ repo_id=model_remote_repository,
343
+ # subfolder=f"my-awesome-model-{WINDOW}", subfolder didn't work :/
344
+ commit_message=commit_message, # f":tada: Push model for window size {WINDOW}"
345
+ safe_serialization=False
346
+ )
347
  pass
348
 
349
 
 
351
  dataset_folder_prefix = "inputdata/"
352
  pytorch_model = MQtlDnaBERT6Classifier()
353
  start_bert(classifier_model=pytorch_model, model_save_path=f"weights_{pytorch_model.model_name}.pth",
354
+ criterion=ReshapedBCEWithLogitsLoss(), WINDOW=4000, batch_size=4,
355
  dataset_folder_prefix=dataset_folder_prefix, max_epochs=2)
356
  pass
failed_app.py ADDED
@@ -0,0 +1,337 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+
4
+ import huggingface_hub
5
+ import numpy as np
6
+ from datasets import load_dataset, Dataset
7
+ from dotenv import load_dotenv
8
+ from pytorch_lightning import LightningDataModule
9
+ from pytorch_lightning.utilities.types import TRAIN_DATALOADERS, EVAL_DATALOADERS
10
+ from torch.utils.data import DataLoader, IterableDataset
11
+ from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, roc_auc_score
12
+ # from torchmetrics.classification import BinaryAccuracy, BinaryAUROC, BinaryF1Score, BinaryPrecision, BinaryRecall
13
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoModel, BertModel
14
+ from transformers import TrainingArguments, Trainer
15
+ import torch
16
+ import logging
17
+ import wandb
18
+
19
+ timber = logging.getLogger()
20
+ # logging.basicConfig(level=logging.DEBUG)
21
+ logging.basicConfig(level=logging.INFO) # change to level=logging.DEBUG to print more logs...
22
+
23
+ black = "\u001b[30m"
24
+ red = "\u001b[31m"
25
+ green = "\u001b[32m"
26
+ yellow = "\u001b[33m"
27
+ blue = "\u001b[34m"
28
+ magenta = "\u001b[35m"
29
+ cyan = "\u001b[36m"
30
+ white = "\u001b[37m"
31
+
32
+ FORWARD = "FORWARD_INPUT"
33
+ BACKWARD = "BACKWARD_INPUT"
34
+
35
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
+
37
+ PRETRAINED_MODEL_NAME: str = "zhihan1996/DNA_bert_6"
38
+
39
+
40
+ def insert_debug_motif_at_random_position(seq, DEBUG_MOTIF):
41
+ start = 0
42
+ end = len(seq)
43
+ rand_pos = random.randrange(start, (end - len(DEBUG_MOTIF)))
44
+ random_end = rand_pos + len(DEBUG_MOTIF)
45
+ output = seq[start: rand_pos] + DEBUG_MOTIF + seq[random_end: end]
46
+ assert len(seq) == len(output)
47
+ return output
48
+
49
+
50
+ class PagingMQTLDataset(IterableDataset):
51
+ def __init__(self,
52
+ m_dataset,
53
+ seq_len,
54
+ tokenizer,
55
+ max_length=512,
56
+ check_if_pipeline_is_ok_by_inserting_debug_motif=False):
57
+ self.dataset = m_dataset
58
+ self.check_if_pipeline_is_ok_by_inserting_debug_motif = check_if_pipeline_is_ok_by_inserting_debug_motif
59
+ self.debug_motif = "ATCGCCTA"
60
+ self.seq_len = seq_len
61
+
62
+ self.bert_tokenizer = tokenizer
63
+ self.max_length = max_length
64
+ pass
65
+
66
+ def __iter__(self):
67
+ for row in self.dataset:
68
+ processed = self.preprocess(row)
69
+ if processed is not None:
70
+ yield processed
71
+
72
+ def preprocess(self, row):
73
+ sequence = row['sequence'] # Fetch the 'sequence' column
74
+ if len(sequence) != self.seq_len:
75
+ return None # skip problematic row!
76
+ label = row['label'] # Fetch the 'label' column (or whatever target you use)
77
+ if label == 1 and self.check_if_pipeline_is_ok_by_inserting_debug_motif:
78
+ sequence = insert_debug_motif_at_random_position(seq=sequence, DEBUG_MOTIF=self.debug_motif)
79
+
80
+ input_ids = self.bert_tokenizer(sequence)["input_ids"]
81
+ tokenized_tensor = torch.tensor(input_ids)
82
+ label_tensor = torch.tensor(label)
83
+ output_dict = {"input_ids": tokenized_tensor, "labels": label_tensor} # so this is now you do it?
84
+ return output_dict # tokenized_tensor, label_tensor
85
+
86
+
87
+ class MqtlDataModule(LightningDataModule):
88
+ def __init__(self, train_ds, val_ds, test_ds, batch_size=16):
89
+ super().__init__()
90
+ self.batch_size = batch_size
91
+ self.train_loader = DataLoader(train_ds, batch_size=self.batch_size, shuffle=False,
92
+ # collate_fn=collate_fn,
93
+ num_workers=1,
94
+ # persistent_workers=True
95
+ )
96
+ self.validate_loader = DataLoader(val_ds, batch_size=self.batch_size, shuffle=False,
97
+ # collate_fn=collate_fn,
98
+ num_workers=1,
99
+ # persistent_workers=True
100
+ )
101
+ self.test_loader = DataLoader(test_ds, batch_size=self.batch_size, shuffle=False,
102
+ # collate_fn=collate_fn,
103
+ num_workers=1,
104
+ # persistent_workers=True
105
+ )
106
+ pass
107
+
108
+ def prepare_data(self):
109
+ pass
110
+
111
+ def setup(self, stage: str) -> None:
112
+ timber.info(f"inside setup: {stage = }")
113
+ pass
114
+
115
+ def train_dataloader(self) -> TRAIN_DATALOADERS:
116
+ return self.train_loader
117
+
118
+ def val_dataloader(self) -> EVAL_DATALOADERS:
119
+ return self.validate_loader
120
+
121
+ def test_dataloader(self) -> EVAL_DATALOADERS:
122
+ return self.test_loader
123
+
124
+
125
+ def create_paging_train_val_test_datasets(tokenizer, WINDOW, is_debug, batch_size=1000):
126
+ data_files = {
127
+ # small samples
128
+ "train_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_train_binned.csv",
129
+ "validate_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_validate_binned.csv",
130
+ "test_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_test_binned.csv",
131
+ # medium samples
132
+ "train_binned_1000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_1000_train_binned.csv",
133
+ "validate_binned_1000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_1000_validate_binned.csv",
134
+ "test_binned_1000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_1000_test_binned.csv",
135
+
136
+ # large samples
137
+ "train_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_train_binned.csv",
138
+ "validate_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_validate_binned.csv",
139
+ "test_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_test_binned.csv",
140
+ }
141
+
142
+ dataset_map = None
143
+ is_my_laptop = os.path.isfile("/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_train_binned.csv")
144
+ if is_my_laptop:
145
+ dataset_map = load_dataset("csv", data_files=data_files, streaming=True)
146
+ else:
147
+ dataset_map = load_dataset("fahimfarhan/mqtl-classification-datasets", streaming=True)
148
+
149
+ train_dataset = PagingMQTLDataset(dataset_map[f"train_binned_{WINDOW}"],
150
+ check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug,
151
+ tokenizer=tokenizer,
152
+ seq_len=WINDOW
153
+ )
154
+ val_dataset = PagingMQTLDataset(dataset_map[f"validate_binned_{WINDOW}"],
155
+ check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug,
156
+ tokenizer=tokenizer,
157
+ seq_len=WINDOW)
158
+ test_dataset = PagingMQTLDataset(dataset_map[f"test_binned_{WINDOW}"],
159
+ check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug,
160
+ tokenizer=tokenizer,
161
+ seq_len=WINDOW)
162
+ # data_module = MqtlDataModule(train_ds=train_dataset, val_ds=val_dataset, test_ds=test_dataset, batch_size=batch_size)
163
+ return train_dataset, val_dataset, test_dataset
164
+
165
+
166
+ def login_inside_huggingface_virtualmachine():
167
+ # Load the .env file, but don't crash if it's not found (e.g., in Hugging Face Space)
168
+ try:
169
+ load_dotenv() # Only useful on your laptop if .env exists
170
+ print(".env file loaded successfully.")
171
+ except Exception as e:
172
+ print(f"Warning: Could not load .env file. Exception: {e}")
173
+
174
+ # Try to get the token from environment variables
175
+ try:
176
+ token = os.getenv("HF_TOKEN")
177
+
178
+ if not token:
179
+ raise ValueError("HF_TOKEN not found. Make sure to set it in the environment variables or .env file.")
180
+
181
+ # Log in to Hugging Face Hub
182
+ huggingface_hub.login(token)
183
+ print("Logged in to Hugging Face Hub successfully.")
184
+
185
+ except Exception as e:
186
+ print(f"Error during Hugging Face login: {e}")
187
+ # Handle the error appropriately (e.g., exit or retry)
188
+
189
+ # wand db login
190
+ try:
191
+ api_key = os.getenv("WAND_DB_API_KEY")
192
+ timber.info(f"{api_key = }")
193
+
194
+ if not api_key:
195
+ raise ValueError("WAND_DB_API_KEY not found. Make sure to set it in the environment variables or .env file.")
196
+
197
+ # Log in to Hugging Face Hub
198
+ wandb.login(key=api_key)
199
+ print("Logged in to wand db successfully.")
200
+
201
+ except Exception as e:
202
+ print(f"Error during wand db Face login: {e}")
203
+ pass
204
+
205
+
206
+ # use sklearn cz torchmetrics.classification gave array index out of bound exception :/ (whatever it is called in python)
207
+ def compute_metrics_using_sklearn(p):
208
+ try:
209
+ pred, labels = p
210
+
211
+ # Get predicted class labels
212
+ pred_labels = np.argmax(pred, axis=1)
213
+
214
+ # Get predicted probabilities for the positive class
215
+ pred_probs = pred[:, 1] # Assuming binary classification and 2 output classes
216
+
217
+ accuracy = accuracy_score(y_true=labels, y_pred=pred_labels)
218
+ recall = recall_score(y_true=labels, y_pred=pred_labels)
219
+ precision = precision_score(y_true=labels, y_pred=pred_labels)
220
+ f1 = f1_score(y_true=labels, y_pred=pred_labels)
221
+ roc_auc = roc_auc_score(y_true=labels, y_score=pred_probs)
222
+
223
+ return {"accuracy": accuracy, "roc_auc": roc_auc, "precision": precision, "recall": recall, "f1": f1}
224
+
225
+ except Exception as x:
226
+ print(f"compute_metrics_using_sklearn failed with exception: {x}")
227
+ return {"accuracy": 0, "roc_auc": 0, "precision": 0, "recall": 0, "f1": 0}
228
+
229
+
230
+ def start():
231
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
232
+
233
+ login_inside_huggingface_virtualmachine()
234
+ WINDOW = 4000
235
+ batch_size = 100
236
+ model_local_directory = f"my-awesome-model-{WINDOW}"
237
+ model_remote_repository = f"fahimfarhan/dnabert-6-mqtl-classifier-{WINDOW}"
238
+
239
+ is_my_laptop = os.path.isfile("/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_train_binned.csv")
240
+
241
+ tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME, trust_remote_code=True)
242
+ classifier_model = AutoModelForSequenceClassification.from_pretrained(PRETRAINED_MODEL_NAME, num_labels=2)
243
+ args = {
244
+ "output_dir": "output_dnabert-6-mqtl_classification",
245
+ "num_train_epochs": 1,
246
+ "max_steps": 20_000, # train 36k + val 4k = 40k
247
+ # Set the number of steps you expect to train, originally 1000, takes too much time. So I set it to 10 to run faster and check my code/pipeline
248
+ "run_name": "laptop_run_dna-bert-6-mqtl_classification", # Override run_name here
249
+ "per_device_train_batch_size": 1,
250
+ "gradient_accumulation_steps": 32,
251
+ "gradient_checkpointing": True,
252
+ "learning_rate": 1e-3,
253
+ "save_safetensors": False, # I added it. this solves the runtime error!
254
+ # not sure if it is a good idea. sklearn may slow down training, causing time loss... if so, disable these 2 lines below
255
+ "evaluation_strategy": "epoch", # To calculate metrics per epoch
256
+ "logging_strategy": "epoch" # Extra: to log training data stats for loss
257
+ }
258
+
259
+ training_args = TrainingArguments(**args)
260
+ # train_dataset, eval_dataset, test_dataset = create_data_module(tokenizer=tokenizer, WINDOW=WINDOW,
261
+ # batch_size=batch_size,
262
+ # is_debug=False)
263
+ """ # example code
264
+ max_length = 32_000
265
+ sequence = 'ACTG' * int(max_length / 4)
266
+ # sequence = 'ACTG' * int(1000) # seq_len = 4000 it works!
267
+ sequence = [sequence] * 8 # Create 8 identical samples
268
+ tokenized = tokenizer(sequence)["input_ids"]
269
+ labels = [0, 1] * 4
270
+
271
+ # Create a dataset for training
272
+ run_the_code_ds = Dataset.from_dict({"input_ids": tokenized, "labels": labels})
273
+ run_the_code_ds.set_format("pt")
274
+ """
275
+
276
+ train_ds, val_ds, test_ds = create_paging_train_val_test_datasets(tokenizer, WINDOW=WINDOW, is_debug=False)
277
+ # train_ds, val_ds, test_ds = run_the_code_ds, run_the_code_ds, run_the_code_ds
278
+ # train_ds.set_format("pt") # doesn't work!
279
+
280
+ trainer = Trainer(
281
+ model=classifier_model,
282
+ args=training_args,
283
+ train_dataset=train_ds,
284
+ eval_dataset=val_ds,
285
+ compute_metrics=compute_metrics_using_sklearn # torch_metrics.compute_metrics
286
+ )
287
+ # train, and validate
288
+ result = trainer.train()
289
+ try:
290
+ print(f"{result = }")
291
+ except Exception as x:
292
+ print(f"{x = }")
293
+
294
+ # testing
295
+ try:
296
+ # with torch.no_grad(): # didn't work :/
297
+ test_results = trainer.evaluate(eval_dataset=test_ds)
298
+ print(f"{test_results = }")
299
+ except Exception as oome:
300
+ print(f"{oome = }")
301
+ finally:
302
+ # save the model
303
+ model_name = "DnaBert6MQtlClassifier"
304
+
305
+ classifier_model.save_pretrained(save_directory=model_local_directory, safe_serialization=False)
306
+
307
+ # push to the hub
308
+ commit_message = f":tada: Push model for window size {WINDOW} from huggingface space"
309
+ if is_my_laptop:
310
+ commit_message = f":tada: Push model for window size {WINDOW} from zephyrus"
311
+
312
+ classifier_model.push_to_hub(
313
+ repo_id=model_remote_repository,
314
+ # subfolder=f"my-awesome-model-{WINDOW}", subfolder didn't work :/
315
+ commit_message=commit_message, # f":tada: Push model for window size {WINDOW}"
316
+ safe_serialization=False
317
+ )
318
+ pass
319
+
320
+
321
+ def interprete_demo():
322
+ is_my_laptop = True
323
+ WINDOW = 4000
324
+ batch_size = 100
325
+ model_local_directory = f"my-awesome-model-{WINDOW}"
326
+ model_remote_repository = f"fahimfarhan/dnabert-6-mqtl-classifier-{WINDOW}"
327
+
328
+ try:
329
+ classifier_model = AutoModel.from_pretrained(model_remote_repository)
330
+ # todo: use captum / gentech-grelu to interpret the model
331
+ except Exception as x:
332
+ print(x)
333
+
334
+
335
+ if __name__ == '__main__':
336
+ start()
337
+ pass