Soumic committed on
Commit d06b274 • 1 Parent(s): 681d043

:lady_beetle: Repaired some major mistakes, but the model returns accuracy = 50%

Files changed (3)
  1. README.md +4 -0
  2. app.py → failed_app_v3.py +18 -10
  3. failed_app_v4.py +435 -0
README.md CHANGED
@@ -9,3 +9,7 @@ license: creativeml-openrail-m
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+## TODOS:
+https://github.com/jerryji1993/DNABERT/issues/11#issuecomment-802389446 Based on this comment, we need to split the
+sequences into 512-length subsequences. Maybe that will give better results with DNABert 6.
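A minimal sketch of the splitting idea described in the TODO above (the helper name and the example window length are illustrative, not code from this repository):

```python
def split_into_subsequences(sequence: str, chunk_size: int = 512) -> list[str]:
    # Split a DNA sequence into consecutive chunks of at most chunk_size bases.
    return [sequence[i:i + chunk_size] for i in range(0, len(sequence), chunk_size)]

# A 4000-base window yields seven 512-base chunks plus one 416-base tail chunk.
chunks = split_into_subsequences("ACGT" * 1000)
print([len(c) for c in chunks])  # [512, 512, 512, 512, 512, 512, 512, 416]
```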
app.py → failed_app_v3.py RENAMED
@@ -17,6 +17,11 @@ timber = logging.getLogger()
 # logging.basicConfig(level=logging.DEBUG)
 logging.basicConfig(level=logging.INFO)  # change to level=logging.DEBUG to print more logs...
 
+NO_REGULARIZATION = 0
+L1_REGULARIZATION_CODE = 1
+L2_REGULARIZATION_CODE = 2
+L1_AND_L2_REGULARIZATION_CODE = 3
+
 black = "\u001b[30m"
 red = "\u001b[31m"
 green = "\u001b[32m"
@@ -195,9 +200,9 @@ class MQtlBertClassifierLightningModule(LightningModule):
     def __init__(self,
                  classifier: nn.Module,
                  criterion=None,  # nn.BCEWithLogitsLoss(),
-                 regularization: int = 2,  # 1 == L1, 2 == L2, 3 (== 1 | 2) == both l1 and l2, else ignore / don't care
-                 l1_lambda=0.001,
-                 l2_wright_decay=0.001,
+                 regularization: int = L2_REGULARIZATION_CODE,  # 1 == L1, 2 == L2, 3 (== 1 | 2) == both l1 and l2, else ignore / don't care
+                 l1_lambda=0.0001,
+                 l2_wright_decay=0.0001,
                  *args: Any,
                  **kwargs: Any):
         super().__init__(*args, **kwargs)
@@ -227,7 +232,7 @@
         weight_decay = 0.0
         if self.regularization == 2 or self.regularization == 3:
             weight_decay = self.l2_weight_decay
-        return torch.optim.Adam(self.parameters(), lr=1e-3, weight_decay=weight_decay)  # , weight_decay=0.005)
+        return torch.optim.Adam(self.parameters(), lr=1e-5, weight_decay=weight_decay)  # , weight_decay=0.005)
 
     def training_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
         # Accuracy on training batch data
@@ -362,6 +367,9 @@ class DNABERTDataModule(LightningDataModule):
             # "tiny_validate": "/home/soumic/Codes/mqtl-classification/src/inputdata/tiny_dataset_4000_validate_binned.csv",
             # "tiny_test": "/home/soumic/Codes/mqtl-classification/src/inputdata/tiny_dataset_4000_test_binned.csv",
 
+            "tiny_train": "/home/soumic/Codes/mqtl-classification/src/inputdata/medium_dataset_4000_train_binned.csv",
+            "tiny_validate": "/home/soumic/Codes/mqtl-classification/src/inputdata/medium_dataset_4000_validate_binned.csv",
+            "tiny_test": "/home/soumic/Codes/mqtl-classification/src/inputdata/medium_dataset_4000_test_binned.csv",
         }
         if self.is_local:
             self.dataset = load_dataset("csv", data_files=data_files, streaming=True)
@@ -369,9 +377,9 @@
             self.dataset = load_dataset("fahimfarhan/mqtl-classification-datasets")
 
     def setup(self, stage=None):
-        self.train_dataset = PagingMQTLDnaBertDataset(self.dataset['train_binned_4000'], self.tokenizer)
-        self.validate_dataset = PagingMQTLDnaBertDataset(self.dataset['validate_binned_4000'], self.tokenizer)
-        self.test_dataset = PagingMQTLDnaBertDataset(self.dataset['test_binned_4000'], self.tokenizer)
+        self.train_dataset = PagingMQTLDnaBertDataset(self.dataset['tiny_train'], self.tokenizer)
+        self.validate_dataset = PagingMQTLDnaBertDataset(self.dataset['tiny_validate'], self.tokenizer)
+        self.test_dataset = PagingMQTLDnaBertDataset(self.dataset['tiny_test'], self.tokenizer)
 
     def train_dataloader(self):
         return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=1)
@@ -384,7 +392,7 @@
 
 
 def start_bert(classifier_model, model_save_path, criterion, WINDOW, batch_size=4,
-               dataset_folder_prefix="inputdata/", is_binned=True, is_debug=False, max_epochs=10):
+               is_binned=True, is_debug=False, max_epochs=10, regularization_code = L2_REGULARIZATION_CODE):
     is_my_laptop = os.path.isfile("/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_train_binned.csv")
     model_local_directory = f"my-awesome-model-{WINDOW}"
     model_remote_repository = f"fahimfarhan/dnabert-6-mqtl-classifier-{WINDOW}"
@@ -398,7 +406,7 @@ def start_bert(classifier_model, model_save_path, criterion, WINDOW, batch_size=
 
     classifier_module = MQtlBertClassifierLightningModule(
         classifier=classifier_model,
-        regularization=2, criterion=criterion)
+        regularization=regularization_code, criterion=criterion)
 
     # if os.path.exists(model_save_path):
     #     classifier_module.load_state_dict(torch.load(model_save_path))
@@ -440,7 +448,7 @@ if __name__ == "__main__":
     pytorch_model = MQtlDnaBERT6Classifier()
     start_bert(classifier_model=pytorch_model, model_save_path=f"weights_{pytorch_model.model_name}.pth",
                criterion=ReshapedBCEWithLogitsLoss(), WINDOW=4000, batch_size=12,  # max 14 on my laptop...
-               dataset_folder_prefix=dataset_folder_prefix, max_epochs=1)
+               max_epochs=1, regularization_code=L2_REGULARIZATION_CODE)
 
     # Record the end time
     end_time = time.time()
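For reference, a small sketch of how the integer regularization code introduced in this diff is applied elsewhere in the file: L2 decay is handed to Adam's weight_decay, while L1 is added to the training loss. The helper functions below are illustrative only; the constants and default lambdas mirror the ones in the diff.

```python
import torch
from torch import nn

NO_REGULARIZATION = 0
L1_REGULARIZATION_CODE = 1
L2_REGULARIZATION_CODE = 2
L1_AND_L2_REGULARIZATION_CODE = 3  # == 1 | 2, i.e. both penalties


def weight_decay_for(code: int, l2_weight_decay: float = 0.0001) -> float:
    # L2 regularization is delegated to the optimizer's weight_decay argument.
    return l2_weight_decay if code in (L2_REGULARIZATION_CODE, L1_AND_L2_REGULARIZATION_CODE) else 0.0


def apply_l1(loss: torch.Tensor, model: nn.Module, code: int, l1_lambda: float = 0.0001) -> torch.Tensor:
    # L1 regularization is added to the loss explicitly inside training_step.
    if code in (L1_REGULARIZATION_CODE, L1_AND_L2_REGULARIZATION_CODE):
        loss = loss + l1_lambda * sum(p.abs().sum() for p in model.parameters())
    return loss
```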
failed_app_v4.py ADDED
@@ -0,0 +1,435 @@
+import logging
+import os
+import time
+from typing import Any
+
+from huggingface_hub import PyTorchModelHubMixin
+from pytorch_lightning import Trainer, LightningModule, LightningDataModule
+from pytorch_lightning.utilities.types import OptimizerLRScheduler, STEP_OUTPUT, EVAL_DATALOADERS
+from torch.utils.data import DataLoader, Dataset, IterableDataset
+from torcheval.metrics import BinaryAccuracy, BinaryAUROC, BinaryF1Score, BinaryPrecision, BinaryRecall
+from transformers import BertModel, BatchEncoding, BertTokenizer, TrainingArguments
+from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
+import torch
+from torch import nn
+from datasets import load_dataset
+
+timber = logging.getLogger()
+# logging.basicConfig(level=logging.DEBUG)
+logging.basicConfig(level=logging.INFO)  # change to level=logging.DEBUG to print more logs...
+
+NO_REGULARIZATION = 0
+L1_REGULARIZATION_CODE = 1
+L2_REGULARIZATION_CODE = 2
+L1_AND_L2_REGULARIZATION_CODE = 3
+
+black = "\u001b[30m"
+red = "\u001b[31m"
+green = "\u001b[32m"
+yellow = "\u001b[33m"
+blue = "\u001b[34m"
+magenta = "\u001b[35m"
+cyan = "\u001b[36m"
+white = "\u001b[37m"
+
+FORWARD = "FORWARD_INPUT"
+BACKWARD = "BACKWARD_INPUT"
+
+DNA_BERT_6 = "zhihan1996/DNA_bert_6"
+
+
+class CommonAttentionLayer(nn.Module):
+    def __init__(self, hidden_size, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.attention_linear = nn.Linear(hidden_size, 1)
+        pass
+
+    def forward(self, hidden_states):
+        # Apply linear layer
+        attn_weights = self.attention_linear(hidden_states)
+        # Apply softmax to get attention scores
+        attn_weights = torch.softmax(attn_weights, dim=1)
+        # Apply attention weights to hidden states
+        context_vector = torch.sum(attn_weights * hidden_states, dim=1)
+        return context_vector, attn_weights
+
+
+class DNABert6MqtlClassifier(nn.Module, PyTorchModelHubMixin):
+    def __init__(self,
+                 bert_model=BertModel.from_pretrained(pretrained_model_name_or_path=DNA_BERT_6),
+                 hidden_size=768,  # I got mat-mul error, looks like this will be 12 times :/
+                 num_classes=1,
+                 *args,
+                 **kwargs):
+        super().__init__(*args, **kwargs)
+        self.model_name = "DNABert6MqtlClassifier"
+        self.bert_model = bert_model
+        self.attention = CommonAttentionLayer(hidden_size)  # Optional if you want to use attention
+
+        classifier_input_size = 8  # cz mat-mul error
+        self.classifier = nn.Linear(classifier_input_size, num_classes)
+
+    def forward(self, input_ids, attention_mask, token_type_ids):
+        # Run BERT on each sub-sequence and collect the embeddings
+        embeddings = []
+        for i in range(input_ids.size(0)):  # Iterate over sub-sequences
+            outputs = self.bert_model(
+                input_ids=input_ids[i],
+                attention_mask=attention_mask[i],
+                token_type_ids=token_type_ids[i] if token_type_ids is not None else None
+            )
+            last_hidden_state = outputs.last_hidden_state
+            embedding = last_hidden_state.mean(dim=1)  # Example: taking the mean of hidden states
+            embeddings.append(embedding)
+
+        # Concatenate embeddings from all sub-sequences
+        concatenated_embedding = torch.cat(embeddings, dim=1)
+
+        # apply attention here
+        context_vector, _ = self.attention(concatenated_embedding)
+
+        # Classify
+        y_probability = self.classifier(context_vector)
+        return y_probability  # float / double
+
+
+class TorchMetrics:
+    def __init__(self):
+        self.binary_accuracy = BinaryAccuracy()  # .to(device)
+        self.binary_auc = BinaryAUROC()  # .to(device)
+        self.binary_f1_score = BinaryF1Score()  # .to(device)
+        self.binary_precision = BinaryPrecision()  # .to(device)
+        self.binary_recall = BinaryRecall()  # .to(device)
+        pass
+
+    def update_on_each_step(self, batch_predicted_labels, batch_actual_labels):  # todo: Add log if needed
+        # it looks like the library maintainers changed preds to input, ie, before: preds, now: input
+        self.binary_accuracy.update(input=batch_predicted_labels, target=batch_actual_labels)
+        self.binary_auc.update(input=batch_predicted_labels, target=batch_actual_labels)
+        self.binary_f1_score.update(input=batch_predicted_labels, target=batch_actual_labels)
+        self.binary_precision.update(input=batch_predicted_labels, target=batch_actual_labels)
+        self.binary_recall.update(input=batch_predicted_labels, target=batch_actual_labels)
+        pass
+
+    def compute_metrics_and_log(self, log, log_prefix: str, log_color: str = green):
+        b_accuracy = self.binary_accuracy.compute()
+        b_auc = self.binary_auc.compute()
+        b_f1_score = self.binary_f1_score.compute()
+        b_precision = self.binary_precision.compute()
+        b_recall = self.binary_recall.compute()
+        timber.info(
+            log_color + f"{log_prefix}_acc = {b_accuracy}, {log_prefix}_auc = {b_auc}, {log_prefix}_f1_score = {b_f1_score}, {log_prefix}_precision = {b_precision}, {log_prefix}_recall = {b_recall}")
+        log(f"{log_prefix}_accuracy", b_accuracy)
+        log(f"{log_prefix}_auc", b_auc)
+        log(f"{log_prefix}_f1_score", b_f1_score)
+        log(f"{log_prefix}_precision", b_precision)
+        log(f"{log_prefix}_recall", b_recall)
+
+        pass
+
+    def reset_on_epoch_end(self):
+        self.binary_accuracy.reset()
+        self.binary_auc.reset()
+        self.binary_f1_score.reset()
+        self.binary_precision.reset()
+        self.binary_recall.reset()
+
+
+class MQtlBertClassifierLightningModule(LightningModule):
+    def __init__(self,
+                 classifier: nn.Module,
+                 criterion=nn.BCEWithLogitsLoss(),
+                 regularization: int = L2_REGULARIZATION_CODE,
+                 # 1 == L1, 2 == L2, 3 (== 1 | 2) == both l1 and l2, else ignore / don't care
+                 l1_lambda=0.0001,
+                 l2_wright_decay=0.0001,
+                 *args: Any,
+                 **kwargs: Any):
+        super().__init__(*args, **kwargs)
+        self.classifier = classifier
+        self.criterion = criterion
+        self.train_metrics = TorchMetrics()
+        self.validate_metrics = TorchMetrics()
+        self.test_metrics = TorchMetrics()
+
+        self.regularization = regularization
+        self.l1_lambda = l1_lambda
+        self.l2_weight_decay = l2_wright_decay
+        pass
+
+    def forward(self, input_ids, attention_mask, token_type_ids, *args: Any, **kwargs: Any) -> Any:
+        # print(f"\n{ type(input_ids) = }, {input_ids = }")
+        # print(f"{ type(attention_mask) = }, { attention_mask = }")
+        # print(f"{ type(token_type_ids) = }, { token_type_ids = }")
+
+        return self.classifier.forward(input_ids, attention_mask, token_type_ids)
+
+    def configure_optimizers(self) -> OptimizerLRScheduler:
+        # Here we add weight decay (L2 regularization) to the optimizer
+        weight_decay = 0.0
+        if self.regularization == L2_REGULARIZATION_CODE or self.regularization == L1_AND_L2_REGULARIZATION_CODE:
+            weight_decay = self.l2_weight_decay
+        return torch.optim.Adam(self.parameters(), lr=1e-5, weight_decay=weight_decay)  # , weight_decay=0.005)
+
+    def training_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+        # Accuracy on training batch data
+        input_ids, attention_mask, token_type_ids, y = batch
+        probability = self.forward(input_ids, attention_mask, token_type_ids)
+        # prediction
+        predicted_class = (probability >= 0.5).int()  # Convert to binary and cast to int
+
+        loss = self.criterion(probability, y.float())
+
+        if self.regularization == L1_REGULARIZATION_CODE or self.regularization == L1_AND_L2_REGULARIZATION_CODE:  # apply l1 regularization
+            l1_norm = sum(p.abs().sum() for p in self.parameters())
+            loss += self.l1_lambda * l1_norm
+
+        self.log("train_loss", loss)
+        # calculate the scores start
+        self.train_metrics.update_on_each_step(batch_predicted_labels=predicted_class, batch_actual_labels=y)
+        self.train_metrics.compute_metrics_and_log(log=self.log, log_prefix="train")
+        # self.train_metrics.compute_and_log_on_each_step(log=self.log, log_prefix="train")
+        # calculate the scores end
+        return loss
+
+    def on_train_epoch_end(self) -> None:
+        self.train_metrics.compute_metrics_and_log(log=self.log, log_prefix="train")
+        self.train_metrics.reset_on_epoch_end()
+        pass
+
+    def validation_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+        # Accuracy on validation batch data
+        # print(f"debug { batch = }")
+        input_ids, attention_mask, token_type_ids, y = batch
+        probability = self.forward(input_ids, attention_mask, token_type_ids)
+        # prediction
+        predicted_class = (probability >= 0.5).int()  # Convert to binary and cast to int
+
+        # print(blue+f"{x.shape = }")
+        # x should have [32, sth...]
+        loss = self.criterion(probability, y.float())
+        """ loss = 0  # <------------------------- maybe the loss calculation is problematic """
+        self.log("valid_loss", loss)
+        # calculate the scores start
+        self.validate_metrics.update_on_each_step(batch_predicted_labels=predicted_class, batch_actual_labels=y)
+        self.validate_metrics.compute_metrics_and_log(log=self.log, log_prefix="validate", log_color=blue)
+        # self.validate_metrics.compute_and_log_on_each_step(log=self.log, log_prefix="validate", log_color=blue)
+
+        # calculate the scores end
+        return loss
+
+    def on_validation_epoch_end(self) -> None:
+        self.validate_metrics.compute_metrics_and_log(log=self.log, log_prefix="validate", log_color=blue)
+        self.validate_metrics.reset_on_epoch_end()
+        return None
+
+    def test_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+        # Accuracy on validation batch data
+        input_ids, attention_mask, token_type_ids, y = batch
+        probability = self.forward(input_ids, attention_mask, token_type_ids)
+        # prediction
+        predicted_class = (probability >= 0.5).int()  # Convert to binary and cast to int
+
+        loss = self.criterion(probability, y.float())
+        self.log("test_loss", loss)  # do we need this?
+        # calculate the scores start
+        self.test_metrics.update_on_each_step(batch_predicted_labels=predicted_class, batch_actual_labels=y)
+        self.test_metrics.compute_metrics_and_log(log=self.log, log_prefix="test", log_color=magenta)
+        # self.test_metrics.compute_and_log_on_each_step(log=self.log, log_prefix="test", log_color=magenta)
+
+        # calculate the scores end
+        return loss
+
+    def on_test_epoch_end(self) -> None:
+        self.test_metrics.compute_metrics_and_log(log=self.log, log_prefix="test", log_color=magenta)
+        self.test_metrics.reset_on_epoch_end()
+        return None
+
+    pass
+
+
+class PagingMQTLDnaBertDataset(IterableDataset):
+    def __init__(self, dataset, tokenizer, max_length=512):
+        self.dataset = dataset
+        self.bert_tokenizer = tokenizer
+        self.max_length = max_length
+
+    def __iter__(self):
+        for row in self.dataset:
+            processed = self.preprocess(row)
+            if processed is not None:
+                yield processed
+
+    def preprocess(self, row):
+        sequence = row['sequence']
+        label = row['label']
+
+        # Split the sequence into chunks of size max_length (512)
+        chunks = [sequence[i:i + self.max_length] for i in range(0, len(sequence), self.max_length)]
+
+        # Tokenize each chunk and return the tokenized inputs
+        tokenized_inputs = {
+            'input_ids': [],
+            'attention_mask': [],
+            'token_type_ids': []  # If needed for DNABERT
+        }
+
+        for chunk in chunks:
+            encoded_chunk = self.bert_tokenizer(
+                chunk,
+                truncation=True,
+                padding='max_length',
+                max_length=self.max_length,
+                return_tensors='pt'
+            )
+
+            tokenized_inputs['input_ids'].append(encoded_chunk['input_ids'].squeeze(0))
+            tokenized_inputs['attention_mask'].append(encoded_chunk['attention_mask'].squeeze(0))
+            tokenized_inputs['token_type_ids'].append(
+                encoded_chunk['token_type_ids'].squeeze(0) if 'token_type_ids' in encoded_chunk else None)
+
+        # Convert list of tensors to tensors with an extra batch dimension
+        tokenized_inputs = {k: torch.stack(v) for k, v in tokenized_inputs.items() if v[0] is not None}
+
+        input_ids = tokenized_inputs['input_ids']
+        attention_mask = tokenized_inputs['attention_mask']
+        token_type_ids = tokenized_inputs['token_type_ids']
+
+        # print(f"{type(input_ids) }")
+        # print(f"{type(attention_mask) }")
+        # print(f"{type(token_type_ids) }")
+
+        # Concatenate these tensors along a new dimension
+        # Result will be shape [3, num_chunks, 512]
+        # stacked_inputs = torch.stack([input_ids, attention_mask, token_type_ids], dim=0)
+
+        # return stacked_inputs, torch.tensor(label)
+        return input_ids, attention_mask, token_type_ids, torch.tensor(label).int()
+
+
+class DNABERTDataModule(LightningDataModule):
+    def __init__(self, model_name=DNA_BERT_6, batch_size=8, WINDOW=-1, is_local=False):
+        super().__init__()
+        self.tokenized_dataset = None
+        self.dataset = None
+        self.train_dataset: PagingMQTLDnaBertDataset = None
+        self.validate_dataset: PagingMQTLDnaBertDataset = None
+        self.test_dataset: PagingMQTLDnaBertDataset = None
+        self.tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)
+        self.batch_size = batch_size
+        self.is_local = is_local
+        self.window = WINDOW
+
+    def prepare_data(self):
+        # Download and prepare dataset
+        data_files = {
+            # small samples
+            "train_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_train_binned.csv",
+            "validate_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_validate_binned.csv",
+            "test_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_test_binned.csv",
+            # medium samples
+            "train_binned_1000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_1000_train_binned.csv",
+            "validate_binned_1000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_1000_validate_binned.csv",
+            "test_binned_1000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_1000_test_binned.csv",
+
+            # large samples
+            "train_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_train_binned.csv",
+            "validate_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_validate_binned.csv",
+            "test_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_test_binned.csv",
+
+            # really tiny
+            # "tiny_train": "/home/soumic/Codes/mqtl-classification/src/inputdata/tiny_dataset_4000_train_binned.csv",
+            # "tiny_validate": "/home/soumic/Codes/mqtl-classification/src/inputdata/tiny_dataset_4000_validate_binned.csv",
+            # "tiny_test": "/home/soumic/Codes/mqtl-classification/src/inputdata/tiny_dataset_4000_test_binned.csv",
+
+            "tiny_train": "/home/soumic/Codes/mqtl-classification/src/inputdata/medium_dataset_4000_train_binned.csv",
+            "tiny_validate": "/home/soumic/Codes/mqtl-classification/src/inputdata/medium_dataset_4000_validate_binned.csv",
+            "tiny_test": "/home/soumic/Codes/mqtl-classification/src/inputdata/medium_dataset_4000_test_binned.csv",
+        }
+        if self.is_local:
+            self.dataset = load_dataset("csv", data_files=data_files, streaming=True)
+        else:
+            self.dataset = load_dataset("fahimfarhan/mqtl-classification-datasets")
+
+    def setup(self, stage=None):
+        self.train_dataset = PagingMQTLDnaBertDataset(self.dataset['tiny_test'], self.tokenizer)
+        self.validate_dataset = PagingMQTLDnaBertDataset(self.dataset['tiny_validate'], self.tokenizer)
+        self.test_dataset = PagingMQTLDnaBertDataset(self.dataset['tiny_test'], self.tokenizer)
+
+    def train_dataloader(self):
+        return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=1)
+
+    def val_dataloader(self):
+        return DataLoader(self.validate_dataset, batch_size=self.batch_size, num_workers=1)
+
+    def test_dataloader(self) -> EVAL_DATALOADERS:
+        return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=1)
+
+
+def start_bert(classifier_model, model_save_path, criterion, WINDOW, batch_size=4,
+               is_binned=True, is_debug=False, max_epochs=10, regularization_code=L2_REGULARIZATION_CODE):
+    is_my_laptop = os.path.isfile("/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_train_binned.csv")
+    model_local_directory = f"my-awesome-model-{WINDOW}"
+    model_remote_repository = f"fahimfarhan/dnabert-6-mqtl-classifier-{WINDOW}"
+    file_suffix = ""
+    if is_binned:
+        file_suffix = "_binned"
+
+    data_module = DNABERTDataModule(batch_size=batch_size, WINDOW=WINDOW, is_local=is_my_laptop)
+
+    # classifier_model = classifier_model.to(DEVICE)
+
+    classifier_module = MQtlBertClassifierLightningModule(
+        classifier=classifier_model,
+        regularization=regularization_code, criterion=criterion)
+
+    # if os.path.exists(model_save_path):
+    #     classifier_module.load_state_dict(torch.load(model_save_path))
+
+    classifier_module = classifier_module  # .double()
+
+    # Prepare data using the DataModule
+    data_module.prepare_data()
+    data_module.setup()
+
+    trainer = Trainer(max_epochs=max_epochs, precision="32")
+
+    # Train the model
+    trainer.fit(model=classifier_module, datamodule=data_module)
+    trainer.test(model=classifier_module, datamodule=data_module)
+    torch.save(classifier_module.state_dict(), model_save_path)
+
+    # classifier_module.push_to_hub("fahimfarhan/mqtl-classifier-model")
+
+    classifier_model.save_pretrained(save_directory=model_local_directory, safe_serialization=False)
+    # push to the hub
+    commit_message = f":tada: Push model for window size {WINDOW} from huggingface space"
+    if is_my_laptop:
+        commit_message = f":tada: Push model for window size {WINDOW} from zephyrus"
+
+    classifier_model.push_to_hub(
+        repo_id=model_remote_repository,
+        # subfolder=f"my-awesome-model-{WINDOW}", subfolder didn't work :/
+        commit_message=commit_message,  # f":tada: Push model for window size {WINDOW}"
+        # safe_serialization=False
+    )
+    pass
+
+
+if __name__ == "__main__":
+    start_time = time.time()
+
+    dataset_folder_prefix = "inputdata/"
+    pytorch_model = DNABert6MqtlClassifier()
+    start_bert(classifier_model=pytorch_model, model_save_path=f"weights_{pytorch_model.model_name}.pth",
+               criterion=nn.BCEWithLogitsLoss(), WINDOW=4000, batch_size=1,  # 12, # max 14 on my laptop...
+               max_epochs=1, regularization_code=L2_REGULARIZATION_CODE)
+
+    # Record the end time
+    end_time = time.time()
+    # Calculate the duration
+    duration = end_time - start_time
+    # Print the runtime
+    print(f"Runtime: {duration:.2f} seconds")
+
+    pass
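One assumption worth flagging about PagingMQTLDnaBertDataset: the zhihan1996/DNA_bert_6 tokenizer is built around a 6-mer vocabulary, and its published examples feed space-separated overlapping 6-mers rather than raw ACGT strings. A hedged sketch of that conversion step (an assumption about the intended preprocessing, not what failed_app_v4.py currently does):

```python
def seq_to_kmers(sequence: str, k: int = 6) -> str:
    # Turn "ACGTACGT" into "ACGTAC CGTACG GTACGT": overlapping k-mers with stride 1.
    return " ".join(sequence[i:i + k] for i in range(len(sequence) - k + 1))

# Usage with the tokenizer the data module already loads (illustrative only):
# tokenizer = BertTokenizer.from_pretrained("zhihan1996/DNA_bert_6")
# encoded = tokenizer(seq_to_kmers(chunk), truncation=True, padding="max_length",
#                     max_length=512, return_tensors="pt")
```

Whether this matches what the model expects is worth verifying against the DNABERT repository linked in the README TODO.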