Soumic committed on
Commit c334cb2 · 1 Parent(s): 5b23ff9

:hammer: Create another submodule to test the huggingface pipeline

Files changed (4)
  1. .gitignore +2 -0
  2. Dockerfile +36 -0
  3. app.py +418 -0
  4. requirements.txt +32 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
+ lightning_logs/
+ *.pth
Dockerfile ADDED
@@ -0,0 +1,36 @@
+ # Use the official PyTorch Docker image as a base (includes CUDA and PyTorch)
+ FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime
+
+ # Install required dependencies (add any additional system dependencies you need)
+ RUN apt update && apt install -y ffmpeg
+
+ # Create a non-root user with a home directory
+ RUN useradd -m -u 1000 user
+
+ # Switch to the new non-root user
+ USER user
+
+ # Set environment variables for the new user
+ ENV HOME=/home/user \
+     PATH=/home/user/.local/bin:$PATH
+
+ # Set a working directory
+ WORKDIR $HOME/app
+
+ # Set the TRANSFORMERS_CACHE directory to be within the user's home directory
+ ENV TRANSFORMERS_CACHE=$HOME/cache
+
+ # Copy the app code and set ownership to the non-root user
+ COPY --chown=user . $HOME/app
+
+ # Create a Python virtual environment and put it on PATH
+ RUN python -m venv /home/user/venv
+ ENV PATH="/home/user/venv/bin:$PATH"
+
+ # Install pip dependencies within the virtual environment
+ COPY requirements.txt .
+ RUN pip install --upgrade pip
+ RUN pip install -r requirements.txt
+
+ # Run the training script
+ CMD ["python", "app.py"]
app.py ADDED
@@ -0,0 +1,418 @@
+ import logging
+ import random
+ from typing import Any
+
+ import numpy as np
+ import pandas as pd
+ from pytorch_lightning import Trainer, LightningModule, LightningDataModule
+ from pytorch_lightning.utilities.types import OptimizerLRScheduler, STEP_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS
+ from torch.utils.data import DataLoader, Dataset
+ from torchmetrics.classification import BinaryAccuracy, BinaryAUROC, BinaryF1Score, BinaryPrecision, BinaryRecall
+ from transformers import BertModel, BatchEncoding, BertTokenizer, TrainingArguments
+ from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
+ import torch
+ from torch import nn
+ from datasets import load_dataset
+
+ timber = logging.getLogger()
+ # logging.basicConfig(level=logging.DEBUG)
+ logging.basicConfig(level=logging.INFO)  # change to level=logging.DEBUG to print more logs...
+
+ black = "\u001b[30m"
+ red = "\u001b[31m"
+ green = "\u001b[32m"
+ yellow = "\u001b[33m"
+ blue = "\u001b[34m"
+ magenta = "\u001b[35m"
+ cyan = "\u001b[36m"
+ white = "\u001b[37m"
+
+ FORWARD = "FORWARD_INPUT"
+ BACKWARD = "BACKWARD_INPUT"
+
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+ def one_hot_e(dna_seq: str) -> np.ndarray:
+     mydict = {'A': np.asarray([1.0, 0.0, 0.0, 0.0]), 'C': np.asarray([0.0, 1.0, 0.0, 0.0]),
+               'G': np.asarray([0.0, 0.0, 1.0, 0.0]), 'T': np.asarray([0.0, 0.0, 0.0, 1.0]),
+               'N': np.asarray([0.0, 0.0, 0.0, 0.0]), 'H': np.asarray([0.0, 0.0, 0.0, 0.0]),
+               'a': np.asarray([1.0, 0.0, 0.0, 0.0]), 'c': np.asarray([0.0, 1.0, 0.0, 0.0]),
+               'g': np.asarray([0.0, 0.0, 1.0, 0.0]), 't': np.asarray([0.0, 0.0, 0.0, 1.0]),
+               'n': np.asarray([0.0, 0.0, 0.0, 0.0]), '-': np.asarray([0.0, 0.0, 0.0, 0.0])}
+
+     size_of_a_seq: int = len(dna_seq)
+
+     # forward = np.zeros(shape=(size_of_a_seq, 4))
+
+     forward_list: list = [mydict[dna_seq[i]] for i in range(0, size_of_a_seq)]
+     encoded = np.asarray(forward_list)
+     encoded_transposed = encoded.transpose()  # todo: Needs review
+     return encoded_transposed
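+
+ # Quick sanity check (illustrative, not called anywhere in this file):
+ # one_hot_e("ACGT") has shape (4, 4) and one_hot_e("ACGTA") has shape (4, 5),
+ # i.e. a channels-first (4, seq_len) layout, which is what nn.Conv1d expects;
+ # row 0 is the A channel, so one_hot_e("ACGT")[0] is [1., 0., 0., 0.].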
+
+
+ def one_hot_e_column(column: pd.Series) -> np.ndarray:
+     tmp_list: list = [one_hot_e(seq) for seq in column]
+     encoded_column = np.asarray(tmp_list).astype(np.float32)
+     return encoded_column
+
+
+ def reverse_dna_seq(dna_seq: str) -> str:
+     # m_reversed = ""
+     # for i in range(0, len(dna_seq)):
+     #     m_reversed = dna_seq[i] + m_reversed
+     # return m_reversed
+     return dna_seq[::-1]
+
+
+ def complement_dna_seq(dna_seq: str) -> str:
+     comp_map = {"A": "T", "C": "G", "T": "A", "G": "C",
+                 "a": "t", "c": "g", "t": "a", "g": "c",
+                 "N": "N", "H": "H", "-": "-",
+                 "n": "n", "h": "h"
+                 }
+
+     comp_dna_seq_list: list = [comp_map[nucleotide] for nucleotide in dna_seq]
+     comp_dna_seq: str = "".join(comp_dna_seq_list)
+     return comp_dna_seq
+
+
+ def reverse_complement_dna_seq(dna_seq: str) -> str:
+     return reverse_dna_seq(complement_dna_seq(dna_seq))
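+
+ # Worked example (illustrative): complement_dna_seq("ATCG") == "TAGC", and
+ # reverse_complement_dna_seq("ATCG") == "CGAT" (complement first, then reverse).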
+
+
+ def reverse_complement_column(column: pd.Series) -> list:
+     rc_column: list = [reverse_complement_dna_seq(seq) for seq in column]
+     return rc_column
+
+
+ class TorchMetrics:
+     def __init__(self, device=DEVICE):
+         self.binary_accuracy = BinaryAccuracy().to(device)
+         self.binary_auc = BinaryAUROC().to(device)
+         self.binary_f1_score = BinaryF1Score().to(device)
+         self.binary_precision = BinaryPrecision().to(device)
+         self.binary_recall = BinaryRecall().to(device)
+         pass
+
+     def update_on_each_step(self, batch_predicted_labels, batch_actual_labels):  # todo: Add log if needed
+         self.binary_accuracy.update(preds=batch_predicted_labels, target=batch_actual_labels)
+         self.binary_auc.update(preds=batch_predicted_labels, target=batch_actual_labels)
+         self.binary_f1_score.update(preds=batch_predicted_labels, target=batch_actual_labels)
+         self.binary_precision.update(preds=batch_predicted_labels, target=batch_actual_labels)
+         self.binary_recall.update(preds=batch_predicted_labels, target=batch_actual_labels)
+         pass
+
+     def compute_and_reset_on_epoch_end(self, log, log_prefix: str, log_color: str = green):
+         b_accuracy = self.binary_accuracy.compute()
+         b_auc = self.binary_auc.compute()
+         b_f1_score = self.binary_f1_score.compute()
+         b_precision = self.binary_precision.compute()
+         b_recall = self.binary_recall.compute()
+         timber.info(
+             log_color + f"{log_prefix}_acc = {b_accuracy}, {log_prefix}_auc = {b_auc}, {log_prefix}_f1_score = {b_f1_score}, {log_prefix}_precision = {b_precision}, {log_prefix}_recall = {b_recall}")
+         log(f"{log_prefix}_accuracy", b_accuracy)
+         log(f"{log_prefix}_auc", b_auc)
+         log(f"{log_prefix}_f1_score", b_f1_score)
+         log(f"{log_prefix}_precision", b_precision)
+         log(f"{log_prefix}_recall", b_recall)
+
+         self.binary_accuracy.reset()
+         self.binary_auc.reset()
+         self.binary_f1_score.reset()
+         self.binary_precision.reset()
+         self.binary_recall.reset()
+         pass
+
+
+ def insert_debug_motif_at_random_position(seq, DEBUG_MOTIF):
+     start = 0
+     end = len(seq)
+     rand_pos = random.randrange(start, (end - len(DEBUG_MOTIF)))
+     random_end = rand_pos + len(DEBUG_MOTIF)
+     output = seq[start: rand_pos] + DEBUG_MOTIF + seq[random_end: end]
+     assert len(seq) == len(output)
+     return output
+
+
+ class MQTLDataset(Dataset):
+     def __init__(self, dataset, check_if_pipeline_is_ok_by_inserting_debug_motif=False):
+         self.dataset = dataset
+         self.check_if_pipeline_is_ok_by_inserting_debug_motif = check_if_pipeline_is_ok_by_inserting_debug_motif
+         self.debug_motif = "ATCGCCTA"
+         pass
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def __getitem__(self, idx):
+         seq = self.dataset[idx]['sequence']  # Fetch the 'sequence' column
+         label = self.dataset[idx]['label']  # Fetch the 'label' column (or whatever target you use)
+         if label == 1 and self.check_if_pipeline_is_ok_by_inserting_debug_motif:
+             seq = insert_debug_motif_at_random_position(seq=seq, DEBUG_MOTIF=self.debug_motif)
+         seq_rc = reverse_complement_dna_seq(seq)
+         ohe_seq = one_hot_e(dna_seq=seq)
+         # print(f"shape fafafa = { ohe_seq.shape = }")
+         ohe_seq_rc = one_hot_e(dna_seq=seq_rc)
+
+         label_number = label * 1.0
+         label_np_array = np.asarray([label_number]).astype(np.float32)
+         # return ohe_seq, ohe_seq_rc, label
+         return [ohe_seq, ohe_seq_rc], label_np_array
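+
+ # Note: each item is ([forward_one_hot, reverse_complement_one_hot], label_array),
+ # i.e. the two-branch input that Cnn1dClassifier.forward later unpacks as x[0], x[1].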
+
+
+ class MqtlDataModule(LightningDataModule):
+     def __init__(self, train_ds: Dataset, val_ds: Dataset, test_ds: Dataset, batch_size=16):
+         super().__init__()
+         self.batch_size = batch_size
+         self.train_loader = DataLoader(train_ds, batch_size=self.batch_size, shuffle=True, num_workers=15,
+                                        persistent_workers=True)
+         self.validate_loader = DataLoader(val_ds, batch_size=self.batch_size, shuffle=False, num_workers=15,
+                                           persistent_workers=True)
+         self.test_loader = DataLoader(test_ds, batch_size=self.batch_size, shuffle=False, num_workers=15,
+                                       persistent_workers=True)
+         pass
+
+     def prepare_data(self):
+         pass
+
+     def setup(self, stage: str) -> None:
+         timber.info(f"inside setup: {stage = }")
+         pass
+
+     def train_dataloader(self) -> TRAIN_DATALOADERS:
+         return self.train_loader
+
+     def val_dataloader(self) -> EVAL_DATALOADERS:
+         return self.validate_loader
+
+     def test_dataloader(self) -> EVAL_DATALOADERS:
+         return self.test_loader
+
+
+ class MQtlClassifierLightningModule(LightningModule):
+     def __init__(self,
+                  classifier: nn.Module,
+                  criterion=nn.BCELoss(),  # nn.BCEWithLogitsLoss(),
+                  regularization: int = 2,  # 1 == L1, 2 == L2, 3 (== 1 | 2) == both l1 and l2, else ignore / don't care
+                  l1_lambda=0.001,
+                  l2_weight_decay=0.001,
+                  m_optimizer=torch.optim.Adam,
+                  *args: Any,
+                  **kwargs: Any):
+         super().__init__(*args, **kwargs)
+         self.classifier = classifier
+         self.criterion = criterion
+         self.train_metrics = TorchMetrics()
+         self.validate_metrics = TorchMetrics()
+         self.test_metrics = TorchMetrics()
+
+         self.regularization = regularization
+         self.l1_lambda = l1_lambda
+         self.l2_weight_decay = l2_weight_decay
+         self.m_optimizer = m_optimizer
+         pass
+
+     def forward(self, x, *args: Any, **kwargs: Any) -> Any:
+         return self.classifier.forward(x)
+
+     def configure_optimizers(self) -> OptimizerLRScheduler:
+         # Here we add weight decay (L2 regularization) to the optimizer
+         weight_decay = 0.0
+         if self.regularization == 2 or self.regularization == 3:
+             weight_decay = self.l2_weight_decay
+         return self.m_optimizer(self.parameters(), lr=1e-3, weight_decay=weight_decay)  # , weight_decay=0.005)
+
+     def training_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+         # Accuracy on training batch data
+         x, y = batch
+         x = [i.float() for i in x]
+         preds = self.forward(x)
+         loss = self.criterion(preds, y)
+
+         if self.regularization == 1 or self.regularization == 3:  # apply l1 regularization
+             l1_norm = sum(p.abs().sum() for p in self.parameters())
+             loss += self.l1_lambda * l1_norm
+
+         self.log("train_loss", loss)
+         # calculate the scores start
+         self.train_metrics.update_on_each_step(batch_predicted_labels=preds, batch_actual_labels=y)
+         # calculate the scores end
+         return loss
+
+     def on_train_epoch_end(self) -> None:
+         timber.info(green + "on_train_epoch_end")
+         self.train_metrics.compute_and_reset_on_epoch_end(log=self.log, log_prefix="train")
+         pass
+
+     def validation_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+         # Accuracy on validation batch data
+         x, y = batch
+         x = [i.float() for i in x]
+
+         preds = self.forward(x)
+         loss = self.criterion(preds, y)
+         self.log("valid_loss", loss)
+         # calculate the scores start
+         self.validate_metrics.update_on_each_step(batch_predicted_labels=preds, batch_actual_labels=y)
+         # calculate the scores end
+         return loss
+
+     def on_validation_epoch_end(self) -> None:
+         timber.info(blue + "on_validation_epoch_end")
+         self.validate_metrics.compute_and_reset_on_epoch_end(log=self.log, log_prefix="validate", log_color=blue)
+         return None
+
+     def test_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+         # Accuracy on test batch data
+         x, y = batch
+         x = [i.float() for i in x]
+
+         preds = self.forward(x)
+         loss = self.criterion(preds, y)
+         self.log("test_loss", loss)  # do we need this?
+         # calculate the scores start
+         self.test_metrics.update_on_each_step(batch_predicted_labels=preds, batch_actual_labels=y)
+         # calculate the scores end
+         return loss
+
+     def on_test_epoch_end(self) -> None:
+         timber.info(magenta + "on_test_epoch_end")
+         self.test_metrics.compute_and_reset_on_epoch_end(log=self.log, log_prefix="test", log_color=magenta)
+         return None
+
+     pass
+
+
+ # Some more util functions!
+ def create_conv_sequence(in_channel_num_of_nucleotides, num_filters, kernel_size_k_mer_motif) -> nn.Sequential:
+     conv1d = nn.Conv1d(in_channels=in_channel_num_of_nucleotides, out_channels=num_filters,
+                        kernel_size=kernel_size_k_mer_motif,
+                        padding="same")  # stride = 2, just don't use stride, keep it simple for now
+     activation = nn.ReLU(inplace=False)  # (inplace=True) will mess with interpretability
+     pooling = nn.MaxPool1d(
+         kernel_size=kernel_size_k_mer_motif)  # stride = 2, just don't use stride, keep it simple for now
+
+     return nn.Sequential(conv1d, activation, pooling)
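+
+ # Shape sketch (strides left at their defaults): padding="same" keeps the sequence
+ # length, and MaxPool1d(kernel_size=k) then divides it by k, so an input of shape
+ # (batch, 4, seq_len) leaves this block as (batch, num_filters, seq_len // k).
+ # Cnn1dClassifier relies on this when sizing its dense layer below.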
+
+
+ class Cnn1dClassifier(nn.Module):
+     def __init__(self,
+                  seq_len,
+                  in_channel_num_of_nucleotides=4,
+                  kernel_size_k_mer_motif=4,
+                  num_filters=32,
+                  lstm_hidden_size=128,
+                  dnn_size=128,
+                  conv_seq_list_size=3,
+                  *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.file_name = f"weights_Cnn1dClassifier_seqlen_{seq_len}.pth"
+
+         self.seq_layer_forward = create_conv_sequence(in_channel_num_of_nucleotides, num_filters,
+                                                       kernel_size_k_mer_motif)
+         self.seq_layer_backward = create_conv_sequence(in_channel_num_of_nucleotides, num_filters,
+                                                        kernel_size_k_mer_motif)
+
+         self.flatten = nn.Flatten()
+
+         dnn_in_features = int(num_filters * (seq_len * 2) / kernel_size_k_mer_motif)
+         # each conv branch yields (num_filters, seq_len // kernel_size); the two branches
+         # (forward_sequence and backward_sequence) are concatenated, hence the factor of 2
+         self.dnn = nn.Linear(in_features=dnn_in_features, out_features=dnn_size)
+         self.dnn_activation = nn.ReLU(inplace=False)  # inplace = true messes with interpretability!
+         self.dropout = nn.Dropout(p=0.33)
+
+         self.output_layer = nn.Linear(in_features=dnn_size, out_features=1)
+         self.output_activation = torch.sigmoid  # not needed if using nn.BCEWithLogitsLoss()
+
+         self.layer_output_logger: dict = {}
+         pass
+
+     def forward(self, x):
+         xf, xb = x[0], x[1]
+
+         hf = self.seq_layer_forward(xf)
+         timber.debug(red + f"1{ hf.shape = }")
+         hb = self.seq_layer_backward(xb)
+         timber.debug(green + f"2{ hb.shape = }")
+
+         h = torch.concatenate(tensors=(hf, hb), dim=2)
+         timber.debug(yellow + f"4{ h.shape = } concat")
+
+         h = self.flatten(h)
+         timber.debug(yellow + f"5{ h.shape = } flatten")
+
+         h = self.dnn(h)
+         timber.debug(yellow + f"8{ h.shape = } dnn")
+
+         h = self.dnn_activation(h)
+         timber.debug(blue + f"9{ h.shape = } dnn_activation")
+         h = self.dropout(h)
+         timber.debug(blue + f"10{ h.shape = } dropout")
+         h = self.output_layer(h)
+         timber.debug(blue + f"11{ h.shape = } output_layer")
+         h = self.output_activation(h)
+         timber.debug(blue + f"12{ h.shape = } output_activation")
+         return h
+
+
+ def start(classifier_model, model_save_path, is_attention_model=False, m_optimizer=torch.optim.Adam, WINDOW=200,
+           dataset_folder_prefix="inputdata/", is_binned=True, is_debug=False, max_epochs=10):
+     # experiment = 'tutorial_3'
+     # if not os.path.exists(experiment):
+     #     os.makedirs(experiment)
+     """
+     x_train, x_tmp, y_train, y_tmp = train_test_split(df["sequence"], df["label"], test_size=0.2)
+     x_test, x_val, y_test, y_val = train_test_split(x_tmp, y_tmp, test_size=0.5)
+
+     train_dataset = MyDataSet(x_train, y_train)
+     val_dataset = MyDataSet(x_val, y_val)
+     test_dataset = MyDataSet(x_test, y_test)
+     """
+     file_suffix = ""
+     if is_binned:
+         file_suffix = "_binned"
+
+     dataset_map = load_dataset("fahimfarhan/mqtl-classification-dataset-binned-200")
+
+     train_dataset = MQTLDataset(dataset_map["train"], check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug)
+     val_dataset = MQTLDataset(dataset_map["validate"], check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug)
+     test_dataset = MQTLDataset(dataset_map["test"], check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug)
+
+     data_module = MqtlDataModule(train_ds=train_dataset, val_ds=val_dataset, test_ds=test_dataset)
+
+     classifier_model = classifier_model  # .to(DEVICE)
+
+     classifier_module = MQtlClassifierLightningModule(classifier=classifier_model, regularization=2,
+                                                       m_optimizer=m_optimizer)
+
+     # if os.path.exists(model_save_path):
+     #     classifier_module.load_state_dict(torch.load(model_save_path))
+
+     classifier_module = classifier_module  # .double()
+
+     trainer = Trainer(max_epochs=max_epochs, precision="32")
+     trainer.fit(model=classifier_module, datamodule=data_module)
+     timber.info("\n\n")
+     trainer.test(model=classifier_module, datamodule=data_module)
+     timber.info("\n\n")
+     torch.save(classifier_module.state_dict(), model_save_path)
+
+     # note: push_to_hub is not a pytorch_lightning Trainer method, so this call will likely
+     # fail at runtime; uploading the saved weights via the huggingface_hub API may be needed instead
+     trainer.push_to_hub("fahimfarhan/mqtl-classifier-model")
+
+     # start_interpreting_ig_and_dl(classifier_model, WINDOW, dataset_folder_prefix=dataset_folder_prefix)
+     # start_interpreting_with_dlshap(classifier_model, WINDOW, dataset_folder_prefix=dataset_folder_prefix)
+     # if is_attention_model:  # todo: repair it later
+     #     start_interpreting_attention_failed(classifier_model)
+     pass
+
+
+ if __name__ == '__main__':
+     WINDOW = 200
+     simple_cnn = Cnn1dClassifier(seq_len=WINDOW)
+     simple_cnn.enable_logging = True
+
+     start(classifier_model=simple_cnn, model_save_path=simple_cnn.file_name, WINDOW=WINDOW,
+           dataset_folder_prefix="inputdata/", is_debug=True, max_epochs=3)
+
+     pass
requirements.txt ADDED
@@ -0,0 +1,32 @@
+ accelerate  # required by HyenaDNA
+ datasets
+ pandas
+ polars
+ numpy
+ matplotlib
+ scipy
+ shap
+ scikit-learn
+ skorch==1.0.0
+ six
+ hyperopt
+ requests
+ pyyaml
+ Bio
+ plotly
+ Levenshtein
+ # pytorch
+ captum
+ torch==2.4.0
+ torchvision
+ torchaudio
+ torchsummary
+ torcheval
+ pydot
+ pydotplus
+ PySide2  # matplotlib dependency on ubuntu; you may need something else for other OS/env setups
+ torchviz
+ gReLU  # luckily now available in pip!
+ # gReLU @ git+https://github.com/Genentech/gReLU  # @623fee8023aabcef89f0afeedbeafff4b71453af
+ # lightning[extra]  # because I got a warning in the console logs
+ torchmetrics