Soumic committed
Commit d6f88f6 · 2 Parent(s): ab311be b3624c4

:twisted_rightwards_arrows: Merge branch 'develop'
Files changed (4)
  1. .env_sample +1 -0
  2. .gitignore +169 -1
  3. app.py +116 -28
  4. requirements.txt +2 -1
.env_sample ADDED
@@ -0,0 +1 @@
+ HF_TOKEN=hf_YOUR_AWESOME_TOKEN
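For anyone reproducing this setup: the placeholder token above is read at runtime through python-dotenv (added to requirements.txt in this same commit). A minimal sketch of that flow, assuming you copy .env_sample to .env and fill in a real token; the final print line is illustrative only:

# sketch: load HF_TOKEN from a local .env file (copy .env_sample to .env first)
import os
from dotenv import load_dotenv

load_dotenv()                  # merges key=value pairs from ./.env into os.environ; no-op if the file is absent
token = os.getenv("HF_TOKEN")  # None when the variable is not set
print("HF_TOKEN present:", token is not None)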
.gitignore CHANGED
@@ -1,3 +1,171 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ # c++ generated files
+ *.out
+ *.exe
+
+ # my custom gitignores
  lightning_logs/
  *.pth
- my-awesome-model/
+ my-awesome-model/
+ my-awesome-model-200/
+ my-awesome-model-4000/
app.py CHANGED
@@ -1,4 +1,5 @@
  import logging
+ import os
  import random
  from typing import Any

@@ -6,15 +7,19 @@ import numpy as np
  import pandas as pd
  from pytorch_lightning import Trainer, LightningModule, LightningDataModule
  from pytorch_lightning.utilities.types import OptimizerLRScheduler, STEP_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS
+ from torch.nn.utils.rnn import pad_sequence
  from torch.utils.data import DataLoader, Dataset
  from torchmetrics.classification import BinaryAccuracy, BinaryAUROC, BinaryF1Score, BinaryPrecision, BinaryRecall
  from transformers import BertModel, BatchEncoding, BertTokenizer, TrainingArguments
  from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
  import torch
  from torch import nn
- from datasets import load_dataset
+ from datasets import load_dataset, IterableDataset
  from huggingface_hub import PyTorchModelHubMixin

+ from dotenv import load_dotenv
+ from huggingface_hub import login
+
  timber = logging.getLogger()
  # logging.basicConfig(level=logging.DEBUG)
  logging.basicConfig(level=logging.INFO)  # change to level=logging.DEBUG to print more logs...
 
@@ -34,6 +39,30 @@ BACKWARD = "BACKWARD_INPUT"
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


+ def login_inside_huggingface_virtualmachine():
+     # Load the .env file, but don't crash if it's not found (e.g., in Hugging Face Space)
+     try:
+         load_dotenv()  # Only useful on your laptop if .env exists
+         print(".env file loaded successfully.")
+     except Exception as e:
+         print(f"Warning: Could not load .env file. Exception: {e}")
+
+     # Try to get the token from environment variables
+     try:
+         token = os.getenv("HF_TOKEN")
+
+         if not token:
+             raise ValueError("HF_TOKEN not found. Make sure to set it in the environment variables or .env file.")
+
+         # Log in to Hugging Face Hub
+         login(token)
+         print("Logged in to Hugging Face Hub successfully.")
+
+     except Exception as e:
+         print(f"Error during Hugging Face login: {e}")
+         # Handle the error appropriately (e.g., exit or retry)
+
+
  def one_hot_e(dna_seq: str) -> np.ndarray:
      mydict = {'A': np.asarray([1.0, 0.0, 0.0, 0.0]), 'C': np.asarray([0.0, 1.0, 0.0, 0.0]),
                'G': np.asarray([0.0, 0.0, 1.0, 0.0]), 'T': np.asarray([0.0, 0.0, 0.0, 1.0]),
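Note that login_inside_huggingface_virtualmachine() above catches every exception and only prints it, so a missing or invalid token surfaces much later, at push time. A stricter variant would fail fast; this is a sketch of an alternative, not something the commit adds:

# sketch: fail-fast login variant (illustrative; the commit deliberately keeps the lenient version)
import os
from dotenv import load_dotenv
from huggingface_hub import login

def strict_login():
    load_dotenv()  # harmless no-op when no .env exists, e.g. inside a Hugging Face Space
    token = os.getenv("HF_TOKEN")
    if not token:
        raise RuntimeError("HF_TOKEN is not set; copy .env_sample to .env or configure a Space secret")
    login(token)   # raises on an invalid token instead of swallowing the error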
 
@@ -136,19 +165,25 @@ def insert_debug_motif_at_random_position(seq, DEBUG_MOTIF):
      return output


- class MQTLDataset(Dataset):
-     def __init__(self, dataset, check_if_pipeline_is_ok_by_inserting_debug_motif=False):
-         self.dataset = dataset
+ class MQTLDataset(IterableDataset):
+     def __init__(self, m_dataset, seq_len, check_if_pipeline_is_ok_by_inserting_debug_motif=False):
+         self.dataset = m_dataset
          self.check_if_pipeline_is_ok_by_inserting_debug_motif = check_if_pipeline_is_ok_by_inserting_debug_motif
          self.debug_motif = "ATCGCCTA"
+         self.seq_len = seq_len
          pass

-     def __len__(self):
-         return len(self.dataset)
-
-     def __getitem__(self, idx):
-         seq = self.dataset[idx]['sequence']  # Fetch the 'sequence' column
-         label = self.dataset[idx]['label']  # Fetch the 'label' column (or whatever target you use)
+     def __iter__(self):
+         for row in self.dataset:
+             processed = self.preprocess(row)
+             if processed is not None:
+                 yield processed
+
+     def preprocess(self, row):
+         seq = row['sequence']  # Fetch the 'sequence' column
+         if len(seq) != self.seq_len:
+             return None  # skip problematic row!
+         label = row['label']  # Fetch the 'label' column (or whatever target you use)
          if label == 1 and self.check_if_pipeline_is_ok_by_inserting_debug_motif:
              seq = insert_debug_motif_at_random_position(seq=seq, DEBUG_MOTIF=self.debug_motif)
          seq_rc = reverse_complement_dna_seq(seq)
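One caveat with the new __iter__-based MQTLDataset: when a DataLoader spawns multiple workers (num_workers=15 below), each worker process can replay the full stream, duplicating rows. A common guard shards the stream by worker id; this is a sketch of the standard pattern, not part of this commit:

# sketch: shard an iterable stream across DataLoader workers to avoid duplicated rows (not in this commit)
from torch.utils.data import IterableDataset, get_worker_info

class ShardedStream(IterableDataset):
    def __init__(self, source):
        self.source = source  # any iterable of rows, e.g. a streamed Hugging Face split

    def __iter__(self):
        info = get_worker_info()          # None in the main process, per-worker info otherwise
        workers = info.num_workers if info else 1
        worker_id = info.id if info else 0
        for i, row in enumerate(self.source):
            if i % workers == worker_id:  # each worker keeps every workers-th row
                yield row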
 
@@ -162,16 +197,36 @@ class MQTLDataset(Dataset):
      return [ohe_seq, ohe_seq_rc], label_np_array


+ # def collate_fn(batch):
+ #     sequences, labels = zip(*batch)
+ #     ohe_seq, ohe_seq_rc = sequences[0], sequences[1]
+ #     # Pad sequences to the maximum length in this batch
+ #     padded_sequences = pad_sequence(ohe_seq, batch_first=True, padding_value=0)
+ #     padded_sequences_rc = pad_sequence(ohe_seq_rc, batch_first=True, padding_value=0)
+ #     # Convert labels to a tensor
+ #     labels = torch.stack(labels)
+ #     return [padded_sequences, padded_sequences_rc], labels
+
+
  class MqtlDataModule(LightningDataModule):
      def __init__(self, train_ds: Dataset, val_ds: Dataset, test_ds: Dataset, batch_size=16):
          super().__init__()
          self.batch_size = batch_size
-         self.train_loader = DataLoader(train_ds, batch_size=self.batch_size, shuffle=True, num_workers=15,
-                                        persistent_workers=True)
-         self.validate_loader = DataLoader(val_ds, batch_size=self.batch_size, shuffle=False, num_workers=15,
-                                           persistent_workers=True)
-         self.test_loader = DataLoader(test_ds, batch_size=self.batch_size, shuffle=False, num_workers=15,
-                                       persistent_workers=True)
+         self.train_loader = DataLoader(train_ds, batch_size=self.batch_size, shuffle=False,
+                                        # collate_fn=collate_fn,
+                                        num_workers=15,
+                                        # persistent_workers=True
+                                        )
+         self.validate_loader = DataLoader(val_ds, batch_size=self.batch_size, shuffle=False,
+                                           # collate_fn=collate_fn,
+                                           num_workers=15,
+                                           # persistent_workers=True
+                                           )
+         self.test_loader = DataLoader(test_ds, batch_size=self.batch_size, shuffle=False,
+                                       # collate_fn=collate_fn,
+                                       num_workers=15,
+                                       # persistent_workers=True
+                                       )
          pass

      def prepare_data(self):
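The train loader drops shuffle=True presumably because a streamed source has no length or random access for a sampler to index into (torch's DataLoader raises when shuffle=True is combined with an iterable-style dataset). If some shuffling is still wanted, Hugging Face datasets supports buffer-based shuffling on the streamed split itself; a sketch building on names defined later in start(), not part of this commit:

# sketch: approximate shuffling for a streamed split via a shuffle buffer (not in this commit)
shuffled_split = dataset_map["train_binned_200"].shuffle(seed=42, buffer_size=1000)
train_ds = MQTLDataset(shuffled_split, seq_len=200, check_if_pipeline_is_ok_by_inserting_debug_motif=False)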
 
@@ -376,17 +431,40 @@ def start(classifier_model, model_save_path, is_attention_model=False, m_optimiz
      if is_binned:
          file_suffix = "_binned"

-     dataset_map = load_dataset("fahimfarhan/mqtl-classification-dataset-binned-200")
-
-     train_dataset = MQTLDataset(dataset_map["train"], check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug)
-     val_dataset = MQTLDataset(dataset_map["validate"], check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug)
-     test_dataset = MQTLDataset(dataset_map["test"], check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug)
+     data_files = {
+         # small samples
+         "train_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_train_binned.csv",
+         "validate_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_validate_binned.csv",
+         "test_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_test_binned.csv",
+         # large samples
+         "train_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_train_binned.csv",
+         "validate_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_validate_binned.csv",
+         "test_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_test_binned.csv",
+     }
+
+     dataset_map = None
+     is_my_laptop = os.path.isfile("/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_test_binned.csv")
+     if is_my_laptop:
+         dataset_map = load_dataset("csv", data_files=data_files, streaming=True)
+     else:
+         dataset_map = load_dataset("fahimfarhan/mqtl-classification-datasets", streaming=True)
+
+     train_dataset = MQTLDataset(dataset_map[f"train_binned_{WINDOW}"],
+                                 check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug,
+                                 seq_len=WINDOW
+                                 )
+     val_dataset = MQTLDataset(dataset_map[f"validate_binned_{WINDOW}"],
+                               check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug,
+                               seq_len=WINDOW)
+     test_dataset = MQTLDataset(dataset_map[f"test_binned_{WINDOW}"],
+                                check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug,
+                                seq_len=WINDOW)

      data_module = MqtlDataModule(train_ds=train_dataset, val_ds=val_dataset, test_ds=test_dataset)

      classifier_model = classifier_model  #.to(DEVICE)
      try:
-         classifier_model = classifier_model.from_pretrained("my-awesome-model")
+         classifier_model = classifier_model.from_pretrained(f"my-awesome-model-{WINDOW}")
      except Exception as x:
          print(x)
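With streaming=True, load_dataset returns lazily-iterated splits (keyed here by the data_files names) that yield one plain dict per CSV row instead of materializing whole files in memory. A tiny illustration, with a shortened relative path standing in for the absolute laptop paths above:

# sketch: what a streamed CSV split yields (illustrative path, one dict per row)
from datasets import load_dataset

streamed = load_dataset("csv",
                        data_files={"train_binned_200": "inputdata/dataset_200_train_binned.csv"},
                        streaming=True)
first_row = next(iter(streamed["train_binned_200"]))
print(first_row.keys())  # expected to include 'sequence' and 'label'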
 
 
@@ -406,13 +484,18 @@ def start(classifier_model, model_save_path, is_attention_model=False, m_optimiz
      torch.save(classifier_module.state_dict(), model_save_path)

      # save locally
-     classifier_model.save_pretrained("my-awesome-model")
+     model_subdirectory = f"my-awesome-model-{WINDOW}"
+     classifier_model.save_pretrained(model_subdirectory)

      # push to the hub
-     classifier_model.push_to_hub(repo_id="fahimfarhan/mqtl-classifier-model", commit_message=":tada: Push model using huggingface_hub")
+     classifier_model.push_to_hub(
+         repo_id="fahimfarhan/mqtl-classifier-model",
+         # subfolder=f"my-awesome-model-{WINDOW}", subfolder didn't work :/
+         commit_message=f":tada: Push model for window size {WINDOW}"
+     )

      # reload
-     model = classifier_model.from_pretrained("my-awesome-model")
+     model = classifier_model.from_pretrained(f"my-awesome-model-{WINDOW}")
      # repo_url = "https://huggingface.co/fahimfarhan/mqtl-classifier-model"
      #
      # push_to_hub(
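The save/push/reload trio comes from huggingface_hub's PyTorchModelHubMixin, imported near the top of app.py; from_pretrained is a classmethod, which is why calling it on the classifier_model instance works. A self-contained sketch of the mixin pattern behind these calls (TinyModel and the directory name are illustrative, not from this repo):

# sketch: minimal PyTorchModelHubMixin usage pattern behind the calls above (names are illustrative)
import torch
from torch import nn
from huggingface_hub import PyTorchModelHubMixin

class TinyModel(nn.Module, PyTorchModelHubMixin):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 1)

    def forward(self, x):
        return self.linear(x)

model = TinyModel()
model.save_pretrained("tiny-model-dir")                  # writes the model weights (and config, if any) locally
reloaded = TinyModel.from_pretrained("tiny-model-dir")   # classmethod: rebuilds the module from the saved files
# model.push_to_hub(repo_id="user/tiny-model")           # the same mixin provides the Hub upload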
 
@@ -429,18 +512,23 @@


  if __name__ == '__main__':
+     login_inside_huggingface_virtualmachine()
+
      WINDOW = 200
      simple_cnn = Cnn1dClassifier(seq_len=WINDOW)
      simple_cnn.enable_logging = True

      start(classifier_model=simple_cnn, model_save_path=simple_cnn.file_name, WINDOW=WINDOW,
-           dataset_folder_prefix="inputdata/", is_debug=True, max_epochs=3)
+           dataset_folder_prefix="inputdata/", is_debug=True, max_epochs=10)

      pass

-
  """
  lightning_logs/
  *.pth
  my-awesome-model
- """
+
+ INFO:root:validate_acc = 0.5625, validate_auc = 0.5490195751190186, validate_f1_score = 0.30000001192092896, validate_precision = 0.6000000238418579, validate_recall = 0.20000000298023224
+ /home/soumic/Codes/mqtl-classification/venv/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
+
+ """
requirements.txt CHANGED
@@ -29,4 +29,5 @@ torchviz
  gReLU  # luckily now available in pip!
  # gReLU @ git+https://github.com/Genentech/gReLU  # @623fee8023aabcef89f0afeedbeafff4b71453af
  # lightning[extra]  # cz I got a stupid warning in the console logs
- torchmetrics
+ torchmetrics
+ python-dotenv