Soumic committed on
Commit
a491ee5
1 Parent(s): 8cac562

:hammer_and_pick: Update dnabert6 classifier to run on huggingface

Browse files
Files changed (6) hide show
  1. .env_sample +1 -0
  2. .gitignore +171 -0
  3. README.md +2 -2
  4. app.py +310 -151
  5. app_v1_backup.py +337 -0
  6. requirements.txt +1 -2
.env_sample ADDED
@@ -0,0 +1 @@
 
 
1
+ HF_TOKEN=hf_YOUR_AWESOME_TOKEN
.gitignore ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
161
+
162
+ # c++ generated files
163
+ *.out
164
+ *.exe
165
+
166
+ # my custom gitignores
167
+ lightning_logs/
168
+ *.pth
169
+ my-awesome-model/
170
+ my-awesome-model-200/
171
+ my-awesome-model-4000/
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
  title: Mqtl Classification Using Dnabert 6
3
  emoji: 👁
4
- colorFrom: pink
5
- colorTo: yellow
6
  sdk: docker
7
  pinned: false
8
  license: creativeml-openrail-m
 
1
  ---
2
  title: Mqtl Classification Using Dnabert 6
3
  emoji: 👁
4
+ colorFrom: blue
5
+ colorTo: white
6
  sdk: docker
7
  pinned: false
8
  license: creativeml-openrail-m
app.py CHANGED
@@ -1,14 +1,28 @@
 
 
 
1
  from typing import Any
2
 
 
 
3
  from pytorch_lightning import Trainer, LightningModule, LightningDataModule
4
- from pytorch_lightning.utilities.types import OptimizerLRScheduler, STEP_OUTPUT, EVAL_DATALOADERS
 
5
  from torch.utils.data import DataLoader, Dataset
6
- from torcheval.metrics import BinaryAccuracy, BinaryAUROC, BinaryF1Score, BinaryPrecision, BinaryRecall
7
  from transformers import BertModel, BatchEncoding, BertTokenizer, TrainingArguments
8
  from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
9
  import torch
10
  from torch import nn
11
- from datasets import load_dataset
 
 
 
 
 
 
 
 
12
 
13
  black = "\u001b[30m"
14
  red = "\u001b[31m"
@@ -22,83 +36,101 @@ white = "\u001b[37m"
22
  FORWARD = "FORWARD_INPUT"
23
  BACKWARD = "BACKWARD_INPUT"
24
 
25
- DNA_BERT_6 = "zhihan1996/DNA_bert_6"
26
 
27
 
28
- class CommonAttentionLayer(nn.Module):
29
- def __init__(self, hidden_size, *args, **kwargs):
30
- super().__init__(*args, **kwargs)
31
- self.attention_linear = nn.Linear(hidden_size, 1)
32
- pass
 
 
33
 
34
- def forward(self, hidden_states):
35
- # Apply linear layer
36
- attn_weights = self.attention_linear(hidden_states)
37
- # Apply softmax to get attention scores
38
- attn_weights = torch.softmax(attn_weights, dim=1)
39
- # Apply attention weights to hidden states
40
- context_vector = torch.sum(attn_weights * hidden_states, dim=1)
41
- return context_vector, attn_weights
42
 
 
 
43
 
44
- class ReshapedBCEWithLogitsLoss(nn.BCEWithLogitsLoss):
45
- def forward(self, input, target):
46
- return super().forward(input.squeeze(), target.float())
47
 
 
 
 
48
 
49
- class MQtlDnaBERT6Classifier(nn.Module):
50
- def __init__(self,
51
- bert_model=BertModel.from_pretrained(pretrained_model_name_or_path=DNA_BERT_6),
52
- hidden_size=768,
53
- num_classes=1,
54
- *args,
55
- **kwargs
56
- ):
57
- super().__init__(*args, **kwargs)
58
 
59
- self.model_name = "MQtlDnaBERT6Classifier"
 
 
 
 
 
 
60
 
61
- self.bert_model = bert_model
62
- self.attention = CommonAttentionLayer(hidden_size)
63
- self.classifier = nn.Linear(hidden_size, num_classes)
64
- pass
65
 
66
- def forward(self, input_ids: torch.tensor, attention_mask: torch.tensor, token_type_ids):
67
- """
68
- # torch.Size([128, 1, 512]) --> [128, 512]
69
- input_ids = input_ids.squeeze(dim=1).to(DEVICE)
70
- # torch.Size([16, 1, 512]) --> [16, 512]
71
- attention_mask = attention_mask.squeeze(dim=1).to(DEVICE)
72
- token_type_ids = token_type_ids.squeeze(dim=1).to(DEVICE)
73
- """
74
- bert_output: BaseModelOutputWithPoolingAndCrossAttentions = self.bert_model(
75
- input_ids=input_ids,
76
- attention_mask=attention_mask,
77
- token_type_ids=token_type_ids
78
- )
 
 
 
 
 
 
 
79
 
80
- last_hidden_state = bert_output.last_hidden_state
81
- context_vector, ignore_attention_weight = self.attention(last_hidden_state)
82
- y = self.classifier(context_vector)
83
- return y
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
 
86
  class TorchMetrics:
87
- def __init__(self):
88
- self.binary_accuracy = BinaryAccuracy() #.to(device)
89
- self.binary_auc = BinaryAUROC() # .to(device)
90
- self.binary_f1_score = BinaryF1Score() # .to(device)
91
- self.binary_precision = BinaryPrecision() # .to(device)
92
- self.binary_recall = BinaryRecall() # .to(device)
93
  pass
94
 
95
  def update_on_each_step(self, batch_predicted_labels, batch_actual_labels): # todo: Add log if needed
96
- # it looks like the library maintainers changed preds to input, ie, before: preds, now: input
97
- self.binary_accuracy.update(input=batch_predicted_labels, target=batch_actual_labels)
98
- self.binary_auc.update(input=batch_predicted_labels, target=batch_actual_labels)
99
- self.binary_f1_score.update(input=batch_predicted_labels, target=batch_actual_labels)
100
- self.binary_precision.update(input=batch_predicted_labels, target=batch_actual_labels)
101
- self.binary_recall.update(input=batch_predicted_labels, target=batch_actual_labels)
102
  pass
103
 
104
  def compute_and_reset_on_epoch_end(self, log, log_prefix: str, log_color: str = green):
@@ -107,7 +139,8 @@ class TorchMetrics:
107
  b_f1_score = self.binary_f1_score.compute()
108
  b_precision = self.binary_precision.compute()
109
  b_recall = self.binary_recall.compute()
110
- # timber.info( log_color + f"{log_prefix}_acc = {b_accuracy}, {log_prefix}_auc = {b_auc}, {log_prefix}_f1_score = {b_f1_score}, {log_prefix}_precision = {b_precision}, {log_prefix}_recall = {b_recall}")
 
111
  log(f"{log_prefix}_accuracy", b_accuracy)
112
  log(f"{log_prefix}_auc", b_auc)
113
  log(f"{log_prefix}_f1_score", b_f1_score)
@@ -122,6 +155,95 @@ class TorchMetrics:
122
  pass
123
 
124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  class MQtlBertClassifierLightningModule(LightningModule):
126
  def __init__(self,
127
  classifier: nn.Module,
@@ -185,7 +307,7 @@ class MQtlBertClassifierLightningModule(LightningModule):
185
  # print(f"debug { batch = }")
186
  x, y = batch
187
  preds = self.forward(x)
188
- loss = 0 # self.criterion(preds, y)
189
  self.log("valid_loss", loss)
190
  # calculate the scores start
191
  self.validate_metrics.update_on_each_step(batch_predicted_labels=preds.squeeze(), batch_actual_labels=y)
@@ -214,76 +336,117 @@ class MQtlBertClassifierLightningModule(LightningModule):
214
  pass
215
 
216
 
217
- class DNABERTDataset(Dataset):
218
- def __init__(self, dataset, tokenizer, max_length=512):
219
- self.dataset = dataset
220
- self.bert_tokenizer = tokenizer
221
- self.max_length = max_length
222
-
223
- def __len__(self):
224
- return len(self.dataset)
225
 
226
- def __getitem__(self, idx):
227
- sequence = self.dataset[idx]['sequence'] # Fetch the 'sequence' column
228
- label = self.dataset[idx]['label'] # Fetch the 'label' column (or whatever target you use)
229
 
230
- # Tokenize the sequence
231
- encoded_sequence: BatchEncoding = self.bert_tokenizer(
232
- sequence,
233
- truncation=True,
234
- padding='max_length',
235
- max_length=self.max_length,
236
- return_tensors='pt'
237
- )
238
-
239
- encoded_sequence_squeezed = {key: val.squeeze() for key, val in encoded_sequence.items()}
240
- return encoded_sequence_squeezed, label
241
 
 
 
 
 
 
 
 
 
242
 
243
- class DNABERTDataModule(LightningDataModule):
244
- def __init__(self, model_name=DNA_BERT_6, batch_size=8):
245
- super().__init__()
246
- self.tokenized_dataset = None
247
- self.dataset = None
248
- self.train_dataset: DNABERTDataset = None
249
- self.validate_dataset: DNABERTDataset = None
250
- self.test_dataset: DNABERTDataset = None
251
- self.tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=DNA_BERT_6)
252
- self.batch_size = batch_size
253
 
254
- def prepare_data(self):
255
- # Download and prepare dataset
256
- self.dataset = load_dataset("fahimfarhan/mqtl-classification-dataset-binned-200")
257
 
258
- def setup(self, stage=None):
259
- self.train_dataset = DNABERTDataset(self.dataset['train'], self.tokenizer)
260
- self.validate_dataset = DNABERTDataset(self.dataset['validate'], self.tokenizer)
261
- self.test_dataset = DNABERTDataset(self.dataset['test'], self.tokenizer)
262
 
263
- def train_dataloader(self):
264
- return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=15)
 
 
 
 
 
 
 
 
 
 
265
 
266
- def val_dataloader(self):
267
- return DataLoader(self.validate_dataset, batch_size=self.batch_size, num_workers=15)
268
 
269
- def test_dataloader(self) -> EVAL_DATALOADERS:
270
- return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=15)
 
 
271
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
 
273
- # Initialize DataModule
274
- model_name = "zhihan1996/DNABERT-6"
275
- data_module = DNABERTDataModule(model_name=model_name, batch_size=8)
 
276
 
277
 
278
- def start_bert(classifier_model, model_save_path, criterion, WINDOW=200, batch_size=4,
279
- dataset_folder_prefix="inputdata/", is_binned=True, is_debug=False, max_epochs=10):
280
  file_suffix = ""
281
  if is_binned:
282
  file_suffix = "_binned"
283
 
284
- data_module = DNABERTDataModule(batch_size=batch_size)
285
-
286
- # classifier_model = classifier_model.to(DEVICE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
 
288
  classifier_module = MQtlBertClassifierLightningModule(
289
  classifier=classifier_model,
@@ -294,44 +457,40 @@ def start_bert(classifier_model, model_save_path, criterion, WINDOW=200, batch_s
294
 
295
  classifier_module = classifier_module # .double()
296
 
297
- # Set up training arguments
298
- training_args = TrainingArguments(
299
- output_dir='./results',
300
- evaluation_strategy="epoch",
301
- per_device_train_batch_size=batch_size,
302
- per_device_eval_batch_size=batch_size,
303
- num_train_epochs=max_epochs,
304
- logging_dir='./logs',
305
- report_to="none", # Disable reporting to WandB, etc.
306
- )
307
-
308
- # Prepare data using the DataModule
309
- data_module.prepare_data()
310
- data_module.setup()
311
-
312
- # Initialize Trainer
313
- # trainer = Trainer(
314
- # model=classifier_module,
315
- # args=training_args,
316
- # train_dataset=data_module.tokenized_dataset["train"],
317
- # eval_dataset=data_module.tokenized_dataset["test"],
318
- # )
319
-
320
  trainer = Trainer(max_epochs=max_epochs, precision="32")
321
-
322
- # Train the model
323
  trainer.fit(model=classifier_module, datamodule=data_module)
 
324
  trainer.test(model=classifier_module, datamodule=data_module)
325
- torch.save(classifier_module.state_dict(), model_save_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
326
 
327
- classifier_module.push_to_hub("fahimfarhan/mqtl-classifier-model")
328
  pass
329
 
330
 
331
- if __name__ == "__main__":
332
- dataset_folder_prefix = "inputdata/"
333
- pytorch_model = MQtlDnaBERT6Classifier()
334
- start_bert(classifier_model=pytorch_model, model_save_path=f"weights_{pytorch_model.model_name}.pth",
335
- criterion=ReshapedBCEWithLogitsLoss(), WINDOW=200, batch_size=4,
336
- dataset_folder_prefix=dataset_folder_prefix, max_epochs=2)
 
 
337
  pass
 
1
+ import logging
2
+ import os
3
+ import random
4
  from typing import Any
5
 
6
+ import numpy as np
7
+ import pandas as pd
8
  from pytorch_lightning import Trainer, LightningModule, LightningDataModule
9
+ from pytorch_lightning.utilities.types import OptimizerLRScheduler, STEP_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS
10
+ from torch.nn.utils.rnn import pad_sequence
11
  from torch.utils.data import DataLoader, Dataset
12
+ from torchmetrics.classification import BinaryAccuracy, BinaryAUROC, BinaryF1Score, BinaryPrecision, BinaryRecall
13
  from transformers import BertModel, BatchEncoding, BertTokenizer, TrainingArguments
14
  from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
15
  import torch
16
  from torch import nn
17
+ from datasets import load_dataset, IterableDataset
18
+ from huggingface_hub import PyTorchModelHubMixin
19
+
20
+ from dotenv import load_dotenv
21
+ from huggingface_hub import login
22
+
23
+ timber = logging.getLogger()
24
+ # logging.basicConfig(level=logging.DEBUG)
25
+ logging.basicConfig(level=logging.INFO) # change to level=logging.DEBUG to print more logs...
26
 
27
  black = "\u001b[30m"
28
  red = "\u001b[31m"
 
36
  FORWARD = "FORWARD_INPUT"
37
  BACKWARD = "BACKWARD_INPUT"
38
 
39
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40
 
41
 
42
+ def login_inside_huggingface_virtualmachine():
43
+ # Load the .env file, but don't crash if it's not found (e.g., in Hugging Face Space)
44
+ try:
45
+ load_dotenv() # Only useful on your laptop if .env exists
46
+ print(".env file loaded successfully.")
47
+ except Exception as e:
48
+ print(f"Warning: Could not load .env file. Exception: {e}")
49
 
50
+ # Try to get the token from environment variables
51
+ try:
52
+ token = os.getenv("HF_TOKEN")
 
 
 
 
 
53
 
54
+ if not token:
55
+ raise ValueError("HF_TOKEN not found. Make sure to set it in the environment variables or .env file.")
56
 
57
+ # Log in to Hugging Face Hub
58
+ login(token)
59
+ print("Logged in to Hugging Face Hub successfully.")
60
 
61
+ except Exception as e:
62
+ print(f"Error during Hugging Face login: {e}")
63
+ # Handle the error appropriately (e.g., exit or retry)
64
 
 
 
 
 
 
 
 
 
 
65
 
66
+ def one_hot_e(dna_seq: str) -> np.ndarray:
67
+ mydict = {'A': np.asarray([1.0, 0.0, 0.0, 0.0]), 'C': np.asarray([0.0, 1.0, 0.0, 0.0]),
68
+ 'G': np.asarray([0.0, 0.0, 1.0, 0.0]), 'T': np.asarray([0.0, 0.0, 0.0, 1.0]),
69
+ 'N': np.asarray([0.0, 0.0, 0.0, 0.0]), 'H': np.asarray([0.0, 0.0, 0.0, 0.0]),
70
+ 'a': np.asarray([1.0, 0.0, 0.0, 0.0]), 'c': np.asarray([0.0, 1.0, 0.0, 0.0]),
71
+ 'g': np.asarray([0.0, 0.0, 1.0, 0.0]), 't': np.asarray([0.0, 0.0, 0.0, 1.0]),
72
+ 'n': np.asarray([0.0, 0.0, 0.0, 0.0]), '-': np.asarray([0.0, 0.0, 0.0, 0.0])}
73
 
74
+ size_of_a_seq: int = len(dna_seq)
 
 
 
75
 
76
+ # forward = np.zeros(shape=(size_of_a_seq, 4))
77
+
78
+ forward_list: list = [mydict[dna_seq[i]] for i in range(0, size_of_a_seq)]
79
+ encoded = np.asarray(forward_list)
80
+ encoded_transposed = encoded.transpose() # todo: Needs review
81
+ return encoded_transposed
82
+
83
+
84
+ def one_hot_e_column(column: pd.Series) -> np.ndarray:
85
+ tmp_list: list = [one_hot_e(seq) for seq in column]
86
+ encoded_column = np.asarray(tmp_list).astype(np.float32)
87
+ return encoded_column
88
+
89
+
90
+ def reverse_dna_seq(dna_seq: str) -> str:
91
+ # m_reversed = ""
92
+ # for i in range(0, len(dna_seq)):
93
+ # m_reversed = dna_seq[i] + m_reversed
94
+ # return m_reversed
95
+ return dna_seq[::-1]
96
 
97
+
98
+ def complement_dna_seq(dna_seq: str) -> str:
99
+ comp_map = {"A": "T", "C": "G", "T": "A", "G": "C",
100
+ "a": "t", "c": "g", "t": "a", "g": "c",
101
+ "N": "N", "H": "H", "-": "-",
102
+ "n": "n", "h": "h"
103
+ }
104
+
105
+ comp_dna_seq_list: list = [comp_map[nucleotide] for nucleotide in dna_seq]
106
+ comp_dna_seq: str = "".join(comp_dna_seq_list)
107
+ return comp_dna_seq
108
+
109
+
110
+ def reverse_complement_dna_seq(dna_seq: str) -> str:
111
+ return reverse_dna_seq(complement_dna_seq(dna_seq))
112
+
113
+
114
+ def reverse_complement_column(column: pd.Series) -> np.ndarray:
115
+ rc_column: list = [reverse_complement_dna_seq(seq) for seq in column]
116
+ return rc_column
117
 
118
 
119
  class TorchMetrics:
120
+ def __init__(self, device=DEVICE):
121
+ self.binary_accuracy = BinaryAccuracy().to(device)
122
+ self.binary_auc = BinaryAUROC().to(device)
123
+ self.binary_f1_score = BinaryF1Score().to(device)
124
+ self.binary_precision = BinaryPrecision().to(device)
125
+ self.binary_recall = BinaryRecall().to(device)
126
  pass
127
 
128
  def update_on_each_step(self, batch_predicted_labels, batch_actual_labels): # todo: Add log if needed
129
+ self.binary_accuracy.update(preds=batch_predicted_labels, target=batch_actual_labels)
130
+ self.binary_auc.update(preds=batch_predicted_labels, target=batch_actual_labels)
131
+ self.binary_f1_score.update(preds=batch_predicted_labels, target=batch_actual_labels)
132
+ self.binary_precision.update(preds=batch_predicted_labels, target=batch_actual_labels)
133
+ self.binary_recall.update(preds=batch_predicted_labels, target=batch_actual_labels)
 
134
  pass
135
 
136
  def compute_and_reset_on_epoch_end(self, log, log_prefix: str, log_color: str = green):
 
139
  b_f1_score = self.binary_f1_score.compute()
140
  b_precision = self.binary_precision.compute()
141
  b_recall = self.binary_recall.compute()
142
+ timber.info(
143
+ log_color + f"{log_prefix}_acc = {b_accuracy}, {log_prefix}_auc = {b_auc}, {log_prefix}_f1_score = {b_f1_score}, {log_prefix}_precision = {b_precision}, {log_prefix}_recall = {b_recall}")
144
  log(f"{log_prefix}_accuracy", b_accuracy)
145
  log(f"{log_prefix}_auc", b_auc)
146
  log(f"{log_prefix}_f1_score", b_f1_score)
 
155
  pass
156
 
157
 
158
+ def insert_debug_motif_at_random_position(seq, DEBUG_MOTIF):
159
+ start = 0
160
+ end = len(seq)
161
+ rand_pos = random.randrange(start, (end - len(DEBUG_MOTIF)))
162
+ random_end = rand_pos + len(DEBUG_MOTIF)
163
+ output = seq[start: rand_pos] + DEBUG_MOTIF + seq[random_end: end]
164
+ assert len(seq) == len(output)
165
+ return output
166
+
167
+
168
+ class PagingMQTLDataset(IterableDataset):
169
+ def __init__(self,
170
+ m_dataset,
171
+ seq_len,
172
+ tokenizer,
173
+ max_length=512,
174
+ check_if_pipeline_is_ok_by_inserting_debug_motif=False):
175
+ self.dataset = m_dataset
176
+ self.check_if_pipeline_is_ok_by_inserting_debug_motif = check_if_pipeline_is_ok_by_inserting_debug_motif
177
+ self.debug_motif = "ATCGCCTA"
178
+ self.seq_len = seq_len
179
+
180
+ self.bert_tokenizer = tokenizer
181
+ self.max_length = max_length
182
+ pass
183
+
184
+ def __iter__(self):
185
+ for row in self.dataset:
186
+ processed = self.preprocess(row)
187
+ if processed is not None:
188
+ yield processed
189
+
190
+ def preprocess(self, row):
191
+ sequence = row['sequence'] # Fetch the 'sequence' column
192
+ if len(sequence) != self.seq_len:
193
+ return None # skip problematic row!
194
+ label = row['label'] # Fetch the 'label' column (or whatever target you use)
195
+ if label == 1 and self.check_if_pipeline_is_ok_by_inserting_debug_motif:
196
+ sequence = insert_debug_motif_at_random_position(seq=sequence, DEBUG_MOTIF=self.debug_motif)
197
+ # Tokenize the sequence
198
+ encoded_sequence: BatchEncoding = self.bert_tokenizer(
199
+ sequence,
200
+ truncation=True,
201
+ padding='max_length',
202
+ max_length=self.max_length,
203
+ return_tensors='pt'
204
+ )
205
+ encoded_sequence_squeezed = {key: val.squeeze() for key, val in encoded_sequence.items()}
206
+ return encoded_sequence_squeezed, label
207
+
208
+
209
+ class MqtlDataModule(LightningDataModule):
210
+ def __init__(self, train_ds, val_ds, test_ds, batch_size=16):
211
+ super().__init__()
212
+ self.batch_size = batch_size
213
+ self.train_loader = DataLoader(train_ds, batch_size=self.batch_size, shuffle=False,
214
+ # collate_fn=collate_fn,
215
+ num_workers=1,
216
+ # persistent_workers=True
217
+ )
218
+ self.validate_loader = DataLoader(val_ds, batch_size=self.batch_size, shuffle=False,
219
+ # collate_fn=collate_fn,
220
+ num_workers=1,
221
+ # persistent_workers=True
222
+ )
223
+ self.test_loader = DataLoader(test_ds, batch_size=self.batch_size, shuffle=False,
224
+ # collate_fn=collate_fn,
225
+ num_workers=1,
226
+ # persistent_workers=True
227
+ )
228
+ pass
229
+
230
+ def prepare_data(self):
231
+ pass
232
+
233
+ def setup(self, stage: str) -> None:
234
+ timber.info(f"inside setup: {stage = }")
235
+ pass
236
+
237
+ def train_dataloader(self) -> TRAIN_DATALOADERS:
238
+ return self.train_loader
239
+
240
+ def val_dataloader(self) -> EVAL_DATALOADERS:
241
+ return self.validate_loader
242
+
243
+ def test_dataloader(self) -> EVAL_DATALOADERS:
244
+ return self.test_loader
245
+
246
+
247
  class MQtlBertClassifierLightningModule(LightningModule):
248
  def __init__(self,
249
  classifier: nn.Module,
 
307
  # print(f"debug { batch = }")
308
  x, y = batch
309
  preds = self.forward(x)
310
+ loss = self.criterion(preds, y)
311
  self.log("valid_loss", loss)
312
  # calculate the scores start
313
  self.validate_metrics.update_on_each_step(batch_predicted_labels=preds.squeeze(), batch_actual_labels=y)
 
336
  pass
337
 
338
 
339
+ DNA_BERT_6 = "zhihan1996/DNA_bert_6"
 
 
 
 
 
 
 
340
 
 
 
 
341
 
342
+ class CommonAttentionLayer(nn.Module):
343
+ def __init__(self, hidden_size, *args, **kwargs):
344
+ super().__init__(*args, **kwargs)
345
+ self.attention_linear = nn.Linear(hidden_size, 1)
346
+ pass
 
 
 
 
 
 
347
 
348
+ def forward(self, hidden_states):
349
+ # Apply linear layer
350
+ attn_weights = self.attention_linear(hidden_states)
351
+ # Apply softmax to get attention scores
352
+ attn_weights = torch.softmax(attn_weights, dim=1)
353
+ # Apply attention weights to hidden states
354
+ context_vector = torch.sum(attn_weights * hidden_states, dim=1)
355
+ return context_vector, attn_weights
356
 
 
 
 
 
 
 
 
 
 
 
357
 
358
+ class ReshapedBCEWithLogitsLoss(nn.BCEWithLogitsLoss):
359
+ def forward(self, input, target):
360
+ return super().forward(input.squeeze(), target.float())
361
 
 
 
 
 
362
 
363
+ class DnaBert6MQTLClassifier(nn.Module):
364
+ def __init__(self,
365
+ seq_len: int, model_repository_name: str,
366
+ bert_model=BertModel.from_pretrained(pretrained_model_name_or_path=DNA_BERT_6),
367
+ hidden_size=768,
368
+ num_classes=1,
369
+ *args,
370
+ **kwargs
371
+ ):
372
+ super().__init__(*args, **kwargs)
373
+ self.seq_len = seq_len
374
+ self.model_repository_name = model_repository_name
375
 
376
+ self.model_name = "MQtlDnaBERT6Classifier"
 
377
 
378
+ self.bert_model = bert_model
379
+ self.attention = CommonAttentionLayer(hidden_size)
380
+ self.classifier = nn.Linear(hidden_size, num_classes)
381
+ pass
382
 
383
+ def forward(self, input_ids: torch.tensor, attention_mask: torch.tensor, token_type_ids):
384
+ """
385
+ # torch.Size([128, 1, 512]) --> [128, 512]
386
+ input_ids = input_ids.squeeze(dim=1).to(DEVICE)
387
+ # torch.Size([16, 1, 512]) --> [16, 512]
388
+ attention_mask = attention_mask.squeeze(dim=1).to(DEVICE)
389
+ token_type_ids = token_type_ids.squeeze(dim=1).to(DEVICE)
390
+ """
391
+ bert_output: BaseModelOutputWithPoolingAndCrossAttentions = self.bert_model(
392
+ input_ids=input_ids,
393
+ attention_mask=attention_mask,
394
+ token_type_ids=token_type_ids
395
+ )
396
 
397
+ last_hidden_state = bert_output.last_hidden_state
398
+ context_vector, ignore_attention_weight = self.attention(last_hidden_state)
399
+ y = self.classifier(context_vector)
400
+ return y
401
 
402
 
403
+ def start_bert(classifier_model, criterion, m_optimizer=torch.optim.Adam, WINDOW=200,
404
+ is_binned=True, is_debug=False, max_epochs=10, batch_size=8):
405
  file_suffix = ""
406
  if is_binned:
407
  file_suffix = "_binned"
408
 
409
+ data_files = {
410
+ # small samples
411
+ "train_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_train_binned.csv",
412
+ "validate_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_validate_binned.csv",
413
+ "test_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_test_binned.csv",
414
+ # large samples
415
+ "train_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_train_binned.csv",
416
+ "validate_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_validate_binned.csv",
417
+ "test_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_test_binned.csv",
418
+ }
419
+
420
+ dataset_map = None
421
+ is_my_laptop = os.path.isfile("/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_test_binned.csv")
422
+ if is_my_laptop:
423
+ dataset_map = load_dataset("csv", data_files=data_files, streaming=True)
424
+ else:
425
+ dataset_map = load_dataset("fahimfarhan/mqtl-classification-datasets", streaming=True)
426
+
427
+ tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=DNA_BERT_6)
428
+
429
+ train_dataset = PagingMQTLDataset(dataset_map[f"train_binned_{WINDOW}"],
430
+ check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug,
431
+ tokenizer=tokenizer,
432
+ seq_len=WINDOW
433
+ )
434
+ val_dataset = PagingMQTLDataset(dataset_map[f"validate_binned_{WINDOW}"],
435
+ check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug,
436
+ tokenizer=tokenizer,
437
+ seq_len=WINDOW)
438
+ test_dataset = PagingMQTLDataset(dataset_map[f"test_binned_{WINDOW}"],
439
+ check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug,
440
+ tokenizer=tokenizer,
441
+ seq_len=WINDOW)
442
+
443
+ data_module = MqtlDataModule(train_ds=train_dataset, val_ds=val_dataset, test_ds=test_dataset, batch_size=batch_size)
444
+
445
+ classifier_model = classifier_model #.to(DEVICE)
446
+ try:
447
+ classifier_model = classifier_model.from_pretrained(classifier_model.model_repository_name)
448
+ except Exception as x:
449
+ print(x)
450
 
451
  classifier_module = MQtlBertClassifierLightningModule(
452
  classifier=classifier_model,
 
457
 
458
  classifier_module = classifier_module # .double()
459
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
460
  trainer = Trainer(max_epochs=max_epochs, precision="32")
 
 
461
  trainer.fit(model=classifier_module, datamodule=data_module)
462
+ timber.info("\n\n")
463
  trainer.test(model=classifier_module, datamodule=data_module)
464
+ timber.info("\n\n")
465
+ # torch.save(classifier_module.state_dict(), model_save_path) # deprecated, use classifier_model.save_pretrained(model_subdirectory) instead
466
+
467
+ # save locally
468
+ model_subdirectory = classifier_model.model_repository_name
469
+ classifier_model.save_pretrained(model_subdirectory)
470
+
471
+ # push to the hub
472
+ commit_message = f":tada: Push model for window size {WINDOW} from huggingface space"
473
+ if is_my_laptop:
474
+ commit_message = f":tada: Push model for window size {WINDOW} from zephyrus"
475
+
476
+ classifier_model.push_to_hub(
477
+ repo_id=f"fahimfarhan/{classifier_model.model_repository_name}",
478
+ # subfolder=f"my-awesome-model-{WINDOW}", subfolder didn't work :/
479
+ commit_message=commit_message # f":tada: Push model for window size {WINDOW}"
480
+ )
481
+
482
+ # reload
483
+ classifier_model = classifier_model.from_pretrained(f"my-awesome-model-{WINDOW}")
484
 
 
485
  pass
486
 
487
 
488
+ if __name__ == '__main__':
489
+ login_inside_huggingface_virtualmachine()
490
+
491
+ WINDOW = 200
492
+ some_model = DnaBert6MQTLClassifier(seq_len=WINDOW, model_repository_name="dnabert-6-mqtl-classifier")
493
+ criterion = ReshapedBCEWithLogitsLoss()
494
+
495
+ start_bert(classifier_model=some_model, criterion=criterion, WINDOW=WINDOW, is_debug=True, max_epochs=2)
496
  pass
app_v1_backup.py ADDED
@@ -0,0 +1,337 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+
3
+ from pytorch_lightning import Trainer, LightningModule, LightningDataModule
4
+ from pytorch_lightning.utilities.types import OptimizerLRScheduler, STEP_OUTPUT, EVAL_DATALOADERS
5
+ from torch.utils.data import DataLoader, Dataset
6
+ from torcheval.metrics import BinaryAccuracy, BinaryAUROC, BinaryF1Score, BinaryPrecision, BinaryRecall
7
+ from transformers import BertModel, BatchEncoding, BertTokenizer, TrainingArguments
8
+ from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
9
+ import torch
10
+ from torch import nn
11
+ from datasets import load_dataset
12
+
13
# ANSI foreground colour escape sequences (SGR codes 30..37), in code order.
(black, red, green, yellow,
 blue, magenta, cyan, white) = (f"\u001b[{sgr_code}m" for sgr_code in range(30, 38))

# Labels for the two input directions.
FORWARD = "FORWARD_INPUT"
BACKWARD = "BACKWARD_INPUT"

# Hugging Face repository id of the pretrained DNABERT (6-mer) checkpoint.
DNA_BERT_6 = "zhihan1996/DNA_bert_6"
26
+
27
+
28
class CommonAttentionLayer(nn.Module):
    """Attention pooling over axis 1 of a batch of hidden states.

    A single linear layer produces one score per position; the scores are
    soft-maxed along ``dim=1`` and used as weights for a sum of the hidden
    states along that same axis.
    """

    def __init__(self, hidden_size, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Maps each hidden vector to a single scalar score.
        self.attention_linear = nn.Linear(hidden_size, 1)

    def forward(self, hidden_states, attention_mask=None):
        """Pool ``hidden_states`` into one context vector per batch item.

        Args:
            hidden_states: tensor indexed as (batch, position, hidden_size).
            attention_mask: optional tensor indexed as (batch, position).
                Positions whose mask value is 0/False are excluded from the
                softmax, so padding receives zero weight. When omitted, every
                position participates.

        Returns:
            Tuple ``(context_vector, attn_weights)``: the weighted sum over
            ``dim=1`` and the normalised weights (trailing dimension of 1).
        """
        scores = self.attention_linear(hidden_states)

        if attention_mask is not None:
            keep = attention_mask.to(torch.bool).unsqueeze(-1)
            # Use the dtype's minimum rather than -inf: a fully masked row then
            # yields uniform weights instead of NaN.
            scores = scores.masked_fill(~keep, torch.finfo(scores.dtype).min)

        attn_weights = torch.softmax(scores, dim=1)
        context_vector = torch.sum(attn_weights * hidden_states, dim=1)
        return context_vector, attn_weights
42
+
43
+
44
class ReshapedBCEWithLogitsLoss(nn.BCEWithLogitsLoss):
    """``BCEWithLogitsLoss`` that aligns the logits with the target's shape.

    The target is cast to float. The logits are reshaped to the target's
    shape whenever both hold the same number of elements.
    """

    def forward(self, input, target):
        """Return the BCE-with-logits loss of ``input`` against ``target``.

        A bare ``input.squeeze()`` would collapse a ``(1, 1)`` batch to a 0-d
        tensor and make a batch of size 1 fail with a size mismatch, so the
        logits are reshaped to the target's shape instead.
        """
        target = target.float()
        # Only reshape when it is lossless; otherwise let the parent class
        # raise its usual size-mismatch error.
        if input.shape != target.shape and input.numel() == target.numel():
            input = input.reshape(target.shape)
        return super().forward(input, target)
47
+
48
+
49
class MQtlDnaBERT6Classifier(nn.Module):
    """BERT encoder followed by attention pooling and a linear classification head."""

    def __init__(self,
                 bert_model=None,
                 hidden_size=768,
                 num_classes=1,
                 *args,
                 **kwargs
                 ):
        """Build the classifier.

        Args:
            bert_model: encoder to use. When ``None`` the pretrained
                ``DNA_BERT_6`` checkpoint is loaded here, at construction
                time. Loading it in the signature default would run on import
                and share one encoder object between all instances.
            hidden_size: width of the encoder's hidden states.
            num_classes: number of output logits per example.
        """
        super().__init__(*args, **kwargs)

        self.model_name = "MQtlDnaBERT6Classifier"

        if bert_model is None:
            bert_model = BertModel.from_pretrained(pretrained_model_name_or_path=DNA_BERT_6)
        self.bert_model = bert_model
        self.attention = CommonAttentionLayer(hidden_size)
        self.classifier = nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids):
        """Return the classification logits for a tokenised batch.

        The three arguments are forwarded unchanged to the BERT encoder. Its
        ``last_hidden_state`` is pooled by the attention layer and projected
        by the linear head.
        """
        bert_output: BaseModelOutputWithPoolingAndCrossAttentions = self.bert_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )

        last_hidden_state = bert_output.last_hidden_state
        # The attention weights are not needed for classification.
        context_vector, _ = self.attention(last_hidden_state)
        return self.classifier(context_vector)
84
+
85
+
86
class TorchMetrics:
    """Bundle of the five binary classification metrics tracked per phase."""

    def __init__(self):
        self.binary_accuracy = BinaryAccuracy()
        self.binary_auc = BinaryAUROC()
        self.binary_f1_score = BinaryF1Score()
        self.binary_precision = BinaryPrecision()
        self.binary_recall = BinaryRecall()

    def _tracked(self):
        """Return the metrics in their fixed reporting order."""
        return (
            self.binary_accuracy,
            self.binary_auc,
            self.binary_f1_score,
            self.binary_precision,
            self.binary_recall,
        )

    def update_on_each_step(self, batch_predicted_labels, batch_actual_labels):  # todo: Add log if needed
        """Feed one batch of predictions and targets to every metric."""
        # torcheval names the prediction argument `input` (formerly `preds`).
        for metric in self._tracked():
            metric.update(input=batch_predicted_labels, target=batch_actual_labels)

    def compute_and_reset_on_epoch_end(self, log, log_prefix: str, log_color: str = green):
        """Compute every metric, report each through ``log``, then reset them all.

        ``log`` is called as ``log(name, value)`` with names of the form
        ``"<log_prefix>_<metric>"``. ``log_color`` is accepted but unused.
        """
        name_suffixes = ("accuracy", "auc", "f1_score", "precision", "recall")
        tracked = self._tracked()

        # Compute everything before logging anything, then reset last.
        computed = [metric.compute() for metric in tracked]
        for suffix, value in zip(name_suffixes, computed):
            log(f"{log_prefix}_{suffix}", value)

        for metric in tracked:
            metric.reset()
123
+
124
+
125
class MQtlBertClassifierLightningModule(LightningModule):
    """Lightning wrapper that trains and evaluates a binary classifier.

    Each batch is an ``(x, y)`` pair where ``x`` is a mapping with the keys
    ``input_ids``, ``attention_mask`` and ``token_type_ids``. The wrapped
    classifier returns logits. The loss is computed on the logits, while the
    metrics receive sigmoid probabilities flattened to one dimension.
    """

    def __init__(self,
                 classifier: nn.Module,
                 criterion=None,  # nn.BCEWithLogitsLoss(),
                 regularization: int = 2,  # 1 == L1, 2 == L2, 3 (== 1 | 2) == both l1 and l2, else ignore / don't care
                 l1_lambda=0.001,
                 l2_wright_decay=0.001,
                 *args: Any,
                 **kwargs: Any):
        super().__init__(*args, **kwargs)
        self.classifier = classifier
        self.criterion = criterion
        self.train_metrics = TorchMetrics()
        self.validate_metrics = TorchMetrics()
        self.test_metrics = TorchMetrics()

        self.regularization = regularization
        self.l1_lambda = l1_lambda
        # The keyword is spelled ``l2_wright_decay`` (sic); the name is kept so
        # existing callers continue to work.
        self.l2_weight_decay = l2_wright_decay

    @staticmethod
    def _metric_scores(logits):
        """Convert logits to 1-D probabilities for the metrics.

        The sigmoid makes a 0.5 metric threshold correspond to a logit of 0.
        ``reshape(-1)`` keeps a batch of size 1 one-dimensional, whereas a bare
        ``squeeze()`` would produce a 0-d tensor.
        """
        return torch.sigmoid(logits.detach()).reshape(-1)

    def forward(self, x, *args: Any, **kwargs: Any) -> Any:
        """Run the wrapped classifier on the three tokeniser outputs in ``x``."""
        input_ids: torch.Tensor = x["input_ids"]
        attention_mask: torch.Tensor = x["attention_mask"]
        token_type_ids: torch.Tensor = x["token_type_ids"]
        # Call the module rather than .forward() so registered hooks run.
        return self.classifier(input_ids, attention_mask, token_type_ids)

    def configure_optimizers(self) -> OptimizerLRScheduler:
        """Return Adam; weight decay (L2) is enabled for regularization 2 or 3."""
        weight_decay = self.l2_weight_decay if self.regularization in (2, 3) else 0.0
        return torch.optim.Adam(self.parameters(), lr=1e-3, weight_decay=weight_decay)

    def training_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        """Compute the training loss, plus the L1 penalty for regularization 1 or 3."""
        x, y = batch
        preds = self.forward(x)
        loss = self.criterion(preds, y)

        if self.regularization in (1, 3):  # apply l1 regularization
            l1_norm = sum(p.abs().sum() for p in self.parameters())
            loss = loss + self.l1_lambda * l1_norm

        self.log("train_loss", loss)
        self.train_metrics.update_on_each_step(
            batch_predicted_labels=self._metric_scores(preds), batch_actual_labels=y)
        return loss

    def on_train_epoch_end(self) -> None:
        self.train_metrics.compute_and_reset_on_epoch_end(log=self.log, log_prefix="train")

    def validation_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        """Log the validation loss and update the validation metrics."""
        x, y = batch
        preds = self.forward(x)
        # Log the real loss instead of a constant 0. Fall back to 0 only when
        # no criterion was supplied.
        loss = self.criterion(preds, y) if self.criterion is not None else 0
        self.log("valid_loss", loss)
        self.validate_metrics.update_on_each_step(
            batch_predicted_labels=self._metric_scores(preds), batch_actual_labels=y)
        return loss

    def on_validation_epoch_end(self) -> None:
        self.validate_metrics.compute_and_reset_on_epoch_end(log=self.log, log_prefix="validate", log_color=blue)
        return None

    def test_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        """Log the test loss and update the test metrics."""
        x, y = batch
        preds = self.forward(x)
        loss = self.criterion(preds, y)
        self.log("test_loss", loss)
        self.test_metrics.update_on_each_step(
            batch_predicted_labels=self._metric_scores(preds), batch_actual_labels=y)
        return loss

    def on_test_epoch_end(self) -> None:
        self.test_metrics.compute_and_reset_on_epoch_end(log=self.log, log_prefix="test", log_color=magenta)
        return None
215
+
216
+
217
class DNABERTDataset(Dataset):
    """Map-style dataset that tokenises the ``sequence`` field of each row on access."""

    def __init__(self, dataset, tokenizer, max_length=512):
        """Store the row container, the tokenizer and the padded token length.

        ``dataset`` must support ``len()`` and integer indexing, and each row
        must provide the keys ``'sequence'`` and ``'label'``.
        """
        self.dataset = dataset
        self.bert_tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(encoded, label)`` for row ``idx``.

        ``encoded`` maps each tokenizer output name to its tensor with the
        leading batch axis removed.
        """
        # Fetch the row once; indexing the backing dataset may be costly.
        row = self.dataset[idx]
        sequence = row['sequence']
        label = row['label']

        encoded_sequence: BatchEncoding = self.bert_tokenizer(
            sequence,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        # Drop only the leading batch axis. A bare squeeze() would also
        # collapse any other axis that happens to have length 1.
        encoded_sequence_squeezed = {key: val.squeeze(0) for key, val in encoded_sequence.items()}
        return encoded_sequence_squeezed, label
241
+
242
+
243
class DNABERTDataModule(LightningDataModule):
    """Lightning data module serving the train/validate/test splits of the mQTL dataset."""

    # Hugging Face dataset repository that provides the three splits.
    DATASET_NAME = "fahimfarhan/mqtl-classification-dataset-binned-200"

    def __init__(self, model_name=DNA_BERT_6, batch_size=8, num_workers=15):
        """Create the tokenizer and remember the loader settings.

        Args:
            model_name: accepted for interface compatibility.
                NOTE(review): it is not used. The tokenizer is always loaded
                from ``DNA_BERT_6``, and honouring this argument would change
                behaviour for callers that pass a different id.
            batch_size: batch size for all three loaders.
            num_workers: worker processes per loader. The default of 15
                matches the previously hard-coded value.
        """
        super().__init__()
        self.tokenized_dataset = None
        self.dataset = None
        self.train_dataset: DNABERTDataset = None
        self.validate_dataset: DNABERTDataset = None
        self.test_dataset: DNABERTDataset = None
        self.tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=DNA_BERT_6)
        self.batch_size = batch_size
        self.num_workers = num_workers

    def prepare_data(self):
        """Download and load the dataset."""
        self.dataset = load_dataset(self.DATASET_NAME)

    def setup(self, stage=None):
        """Wrap the three splits in ``DNABERTDataset``."""
        # prepare_data() may not have run on this object, so do not rely on
        # state it assigned; load the dataset here if it is still missing.
        if self.dataset is None:
            self.dataset = load_dataset(self.DATASET_NAME)

        self.train_dataset = DNABERTDataset(self.dataset['train'], self.tokenizer)
        self.validate_dataset = DNABERTDataset(self.dataset['validate'], self.tokenizer)
        self.test_dataset = DNABERTDataset(self.dataset['test'], self.tokenizer)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True,
                          num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.validate_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self) -> EVAL_DATALOADERS:
        return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers)
271
+
272
+
273
# Initialize DataModule
# Use the file's own checkpoint constant so the id stays consistent with the
# model and tokenizer loaded elsewhere in this module.
model_name = DNA_BERT_6
data_module = DNABERTDataModule(model_name=model_name, batch_size=8)
276
+
277
+
278
def start_bert(classifier_model, model_save_path, criterion, WINDOW=200, batch_size=4,
               dataset_folder_prefix="inputdata/", is_binned=True, is_debug=False, max_epochs=10):
    """Train and test ``classifier_model``, save its weights, then try to publish it.

    Args:
        classifier_model: the model to wrap in the Lightning module.
        model_save_path: file path for the Lightning module's ``state_dict``.
        criterion: loss function passed to the Lightning module.
        batch_size: batch size for the data module.
        max_epochs: number of training epochs.
        WINDOW, dataset_folder_prefix, is_binned, is_debug: accepted for
            interface compatibility; not used by this function.
    """
    data_module = DNABERTDataModule(batch_size=batch_size)

    classifier_module = MQtlBertClassifierLightningModule(
        classifier=classifier_model,
        regularization=2, criterion=criterion)

    trainer = Trainer(max_epochs=max_epochs, precision="32")

    # The Trainer calls prepare_data() and setup() on the data module itself,
    # so they are not invoked manually here.
    trainer.fit(model=classifier_module, datamodule=data_module)
    trainer.test(model=classifier_module, datamodule=data_module)
    torch.save(classifier_module.state_dict(), model_save_path)

    # A plain LightningModule does not define push_to_hub, so calling it
    # unconditionally would raise AttributeError after training. Publish only
    # when the method exists; the weights are already saved locally above.
    push_to_hub = getattr(classifier_module, "push_to_hub", None)
    if callable(push_to_hub):
        push_to_hub("fahimfarhan/mqtl-classifier-model")
    else:
        print(f"push_to_hub is not available on {type(classifier_module).__name__}; "
              f"weights were saved locally to {model_save_path}")
329
+
330
+
331
if __name__ == "__main__":
    # Script entry point: a short two-epoch run with the default DNABERT-6 classifier.
    dataset_folder_prefix = "inputdata/"
    pytorch_model = MQtlDnaBERT6Classifier()
    weights_path = f"weights_{pytorch_model.model_name}.pth"
    start_bert(
        classifier_model=pytorch_model,
        model_save_path=weights_path,
        criterion=ReshapedBCEWithLogitsLoss(),
        WINDOW=200,
        batch_size=4,
        dataset_folder_prefix=dataset_folder_prefix,
        max_epochs=2,
    )
requirements.txt CHANGED
@@ -30,5 +30,4 @@ gReLU # luckily now available in pip!
30
  # gReLU @ git+https://github.com/Genentech/gReLU # @623fee8023aabcef89f0afeedbeafff4b71453af
31
  # lightning[extra] # cz I got a stupid warning in the console logs
32
  torchmetrics
33
- transformers
34
- huggingface-hub
 
30
  # gReLU @ git+https://github.com/Genentech/gReLU # @623fee8023aabcef89f0afeedbeafff4b71453af
31
  # lightning[extra] # cz I got a stupid warning in the console logs
32
  torchmetrics
33
+ python-dotenv