Spaces: Runtime error
Soumic committed · Commit c334cb2 · 1 parent: 5b23ff9
:hammer: Create another submodule to test the huggingface pipeline
Files changed:
- .gitignore +2 -0
- Dockerfile +36 -0
- app.py +418 -0
- requirements.txt +32 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
+lightning_logs/
+*.pth
Dockerfile ADDED
@@ -0,0 +1,36 @@
+# Use the official PyTorch Docker image as a base (includes CUDA and PyTorch)
+FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime
+
+# Install required dependencies (add any additional system dependencies you need)
+RUN apt update && apt install -y ffmpeg
+
+# Create a non-root user with a home directory
+RUN useradd -m -u 1000 user
+
+# Switch to the new non-root user
+USER user
+
+# Set environment variables for the new user
+ENV HOME=/home/user \
+    PATH=/home/user/.local/bin:$PATH
+
+# Set a working directory
+WORKDIR $HOME/app
+
+# Set the TRANSFORMERS_CACHE directory to be within the user's home directory
+ENV TRANSFORMERS_CACHE=$HOME/cache
+
+# Copy the app code and set ownership to the non-root user
+COPY --chown=user . $HOME/app
+
+# Create a Python virtual environment for the dependencies
+RUN python -m venv /home/user/venv
+ENV PATH="/home/user/venv/bin:$PATH"
+
+# Install pip dependencies within the virtual environment
+COPY requirements.txt .
+RUN pip install --upgrade pip
+RUN pip install -r requirements.txt
+
+# Run the training script
+CMD ["python", "app.py"]
app.py ADDED
@@ -0,0 +1,418 @@
+import logging
+import random
+from typing import Any
+
+import numpy as np
+import pandas as pd
+from pytorch_lightning import Trainer, LightningModule, LightningDataModule
+from pytorch_lightning.utilities.types import OptimizerLRScheduler, STEP_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS
+from torch.utils.data import DataLoader, Dataset
+from torchmetrics.classification import BinaryAccuracy, BinaryAUROC, BinaryF1Score, BinaryPrecision, BinaryRecall
+from transformers import BertModel, BatchEncoding, BertTokenizer, TrainingArguments
+from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
+import torch
+from torch import nn
+from datasets import load_dataset
+
+timber = logging.getLogger()
+# logging.basicConfig(level=logging.DEBUG)
+logging.basicConfig(level=logging.INFO)  # change to level=logging.DEBUG to print more logs...
+
+black = "\u001b[30m"
+red = "\u001b[31m"
+green = "\u001b[32m"
+yellow = "\u001b[33m"
+blue = "\u001b[34m"
+magenta = "\u001b[35m"
+cyan = "\u001b[36m"
+white = "\u001b[37m"
+
+FORWARD = "FORWARD_INPUT"
+BACKWARD = "BACKWARD_INPUT"
+
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+def one_hot_e(dna_seq: str) -> np.ndarray:
+    mydict = {'A': np.asarray([1.0, 0.0, 0.0, 0.0]), 'C': np.asarray([0.0, 1.0, 0.0, 0.0]),
+              'G': np.asarray([0.0, 0.0, 1.0, 0.0]), 'T': np.asarray([0.0, 0.0, 0.0, 1.0]),
+              'N': np.asarray([0.0, 0.0, 0.0, 0.0]), 'H': np.asarray([0.0, 0.0, 0.0, 0.0]),
+              'a': np.asarray([1.0, 0.0, 0.0, 0.0]), 'c': np.asarray([0.0, 1.0, 0.0, 0.0]),
+              'g': np.asarray([0.0, 0.0, 1.0, 0.0]), 't': np.asarray([0.0, 0.0, 0.0, 1.0]),
+              'n': np.asarray([0.0, 0.0, 0.0, 0.0]), '-': np.asarray([0.0, 0.0, 0.0, 0.0])}
+
+    size_of_a_seq: int = len(dna_seq)
+
+    # forward = np.zeros(shape=(size_of_a_seq, 4))
+
+    forward_list: list = [mydict[dna_seq[i]] for i in range(0, size_of_a_seq)]
+    encoded = np.asarray(forward_list)
+    encoded_transposed = encoded.transpose()  # todo: Needs review
+    return encoded_transposed
+
+
+def one_hot_e_column(column: pd.Series) -> np.ndarray:
+    tmp_list: list = [one_hot_e(seq) for seq in column]
+    encoded_column = np.asarray(tmp_list).astype(np.float32)
+    return encoded_column
+
+
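Note: one_hot_e returns a channels-first array of shape (4, L) for a length-L sequence, and one_hot_e_column stacks those into (N, 4, L). A small sanity check (not part of the commit):

    import numpy as np
    enc = one_hot_e("ACGT")
    assert enc.shape == (4, 4)  # four channels (A, C, G, T) by four positions
    assert np.allclose(enc[:, 0], [1.0, 0.0, 0.0, 0.0])  # 'A' lights only the first channel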
+def reverse_dna_seq(dna_seq: str) -> str:
+    # m_reversed = ""
+    # for i in range(0, len(dna_seq)):
+    #     m_reversed = dna_seq[i] + m_reversed
+    # return m_reversed
+    return dna_seq[::-1]
+
+
+def complement_dna_seq(dna_seq: str) -> str:
+    comp_map = {"A": "T", "C": "G", "T": "A", "G": "C",
+                "a": "t", "c": "g", "t": "a", "g": "c",
+                "N": "N", "H": "H", "-": "-",
+                "n": "n", "h": "h"
+                }
+
+    comp_dna_seq_list: list = [comp_map[nucleotide] for nucleotide in dna_seq]
+    comp_dna_seq: str = "".join(comp_dna_seq_list)
+    return comp_dna_seq
+
+
+def reverse_complement_dna_seq(dna_seq: str) -> str:
+    return reverse_dna_seq(complement_dna_seq(dna_seq))
+
+
+def reverse_complement_column(column: pd.Series) -> list:  # returns a list of strings, not an ndarray
+    rc_column: list = [reverse_complement_dna_seq(seq) for seq in column]
+    return rc_column
+
+
+
class TorchMetrics:
|
90 |
+
def __init__(self, device=DEVICE):
|
91 |
+
self.binary_accuracy = BinaryAccuracy().to(device)
|
92 |
+
self.binary_auc = BinaryAUROC().to(device)
|
93 |
+
self.binary_f1_score = BinaryF1Score().to(device)
|
94 |
+
self.binary_precision = BinaryPrecision().to(device)
|
95 |
+
self.binary_recall = BinaryRecall().to(device)
|
96 |
+
pass
|
97 |
+
|
98 |
+
def update_on_each_step(self, batch_predicted_labels, batch_actual_labels): # todo: Add log if needed
|
99 |
+
self.binary_accuracy.update(preds=batch_predicted_labels, target=batch_actual_labels)
|
100 |
+
self.binary_auc.update(preds=batch_predicted_labels, target=batch_actual_labels)
|
101 |
+
self.binary_f1_score.update(preds=batch_predicted_labels, target=batch_actual_labels)
|
102 |
+
self.binary_precision.update(preds=batch_predicted_labels, target=batch_actual_labels)
|
103 |
+
self.binary_recall.update(preds=batch_predicted_labels, target=batch_actual_labels)
|
104 |
+
pass
|
105 |
+
|
106 |
+
def compute_and_reset_on_epoch_end(self, log, log_prefix: str, log_color: str = green):
|
107 |
+
b_accuracy = self.binary_accuracy.compute()
|
108 |
+
b_auc = self.binary_auc.compute()
|
109 |
+
b_f1_score = self.binary_f1_score.compute()
|
110 |
+
b_precision = self.binary_precision.compute()
|
111 |
+
b_recall = self.binary_recall.compute()
|
112 |
+
timber.info(
|
113 |
+
log_color + f"{log_prefix}_acc = {b_accuracy}, {log_prefix}_auc = {b_auc}, {log_prefix}_f1_score = {b_f1_score}, {log_prefix}_precision = {b_precision}, {log_prefix}_recall = {b_recall}")
|
114 |
+
log(f"{log_prefix}_accuracy", b_accuracy)
|
115 |
+
log(f"{log_prefix}_auc", b_auc)
|
116 |
+
log(f"{log_prefix}_f1_score", b_f1_score)
|
117 |
+
log(f"{log_prefix}_precision", b_precision)
|
118 |
+
log(f"{log_prefix}_recall", b_recall)
|
119 |
+
|
120 |
+
self.binary_accuracy.reset()
|
121 |
+
self.binary_auc.reset()
|
122 |
+
self.binary_f1_score.reset()
|
123 |
+
self.binary_precision.reset()
|
124 |
+
self.binary_recall.reset()
|
125 |
+
pass
|
126 |
+
|
127 |
+
|
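Note: the wrapper follows torchmetrics' usual update/compute/reset cycle: accumulate per batch, read once at epoch end, then clear state. A minimal sketch with toy values (not part of the commit):

    import torch
    m = TorchMetrics(device=torch.device("cpu"))
    m.update_on_each_step(torch.tensor([0.9, 0.2]), torch.tensor([1, 0]))
    print(m.binary_accuracy.compute())  # tensor(1.): both toy predictions clear the 0.5 threshold correctly
    m.binary_accuracy.reset()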
+def insert_debug_motif_at_random_position(seq, DEBUG_MOTIF):
+    start = 0
+    end = len(seq)
+    rand_pos = random.randrange(start, (end - len(DEBUG_MOTIF)))
+    random_end = rand_pos + len(DEBUG_MOTIF)
+    output = seq[start: rand_pos] + DEBUG_MOTIF + seq[random_end: end]
+    assert len(seq) == len(output)
+    return output
+
+
+class MQTLDataset(Dataset):
+    def __init__(self, dataset, check_if_pipeline_is_ok_by_inserting_debug_motif=False):
+        self.dataset = dataset
+        self.check_if_pipeline_is_ok_by_inserting_debug_motif = check_if_pipeline_is_ok_by_inserting_debug_motif
+        self.debug_motif = "ATCGCCTA"
+        pass
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, idx):
+        seq = self.dataset[idx]['sequence']  # fetch the 'sequence' column
+        label = self.dataset[idx]['label']  # fetch the 'label' column (or whatever target you use)
+        if label == 1 and self.check_if_pipeline_is_ok_by_inserting_debug_motif:
+            seq = insert_debug_motif_at_random_position(seq=seq, DEBUG_MOTIF=self.debug_motif)
+        seq_rc = reverse_complement_dna_seq(seq)
+        ohe_seq = one_hot_e(dna_seq=seq)
+        # print(f"shape fafafa = { ohe_seq.shape = }")
+        ohe_seq_rc = one_hot_e(dna_seq=seq_rc)
+
+        label_number = label * 1.0
+        label_np_array = np.asarray([label_number]).astype(np.float32)
+        # return ohe_seq, ohe_seq_rc, label
+        return [ohe_seq, ohe_seq_rc], label_np_array
+
+
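Note: the debug motif spikes a fixed 8-mer into every positive sequence, so a classifier that still scores near chance signals a broken pipeline rather than a hard dataset. A quick check (toy 20-mer, not part of the commit) that the splice preserves length:

    import random
    random.seed(0)  # deterministic position for the example
    seq = "A" * 20
    spiked = insert_debug_motif_at_random_position(seq, "ATCGCCTA")
    assert len(spiked) == len(seq) and "ATCGCCTA" in spiked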
+class MqtlDataModule(LightningDataModule):
+    def __init__(self, train_ds: Dataset, val_ds: Dataset, test_ds: Dataset, batch_size=16):
+        super().__init__()
+        self.batch_size = batch_size
+        self.train_loader = DataLoader(train_ds, batch_size=self.batch_size, shuffle=True, num_workers=15,  # num_workers is hardcoded; tune it to the host's CPU count
+                                       persistent_workers=True)
+        self.validate_loader = DataLoader(val_ds, batch_size=self.batch_size, shuffle=False, num_workers=15,
+                                          persistent_workers=True)
+        self.test_loader = DataLoader(test_ds, batch_size=self.batch_size, shuffle=False, num_workers=15,
+                                      persistent_workers=True)
+        pass
+
+    def prepare_data(self):
+        pass
+
+    def setup(self, stage: str) -> None:
+        timber.info(f"inside setup: {stage = }")
+        pass
+
+    def train_dataloader(self) -> TRAIN_DATALOADERS:
+        return self.train_loader
+
+    def val_dataloader(self) -> EVAL_DATALOADERS:
+        return self.validate_loader
+
+    def test_dataloader(self) -> EVAL_DATALOADERS:
+        return self.test_loader
+
+
+class MQtlClassifierLightningModule(LightningModule):
+    def __init__(self,
+                 classifier: nn.Module,
+                 criterion=nn.BCELoss(),  # nn.BCEWithLogitsLoss(),
+                 regularization: int = 2,  # 1 == L1, 2 == L2, 3 (== 1 | 2) == both L1 and L2, else ignore / don't care
+                 l1_lambda=0.001,
+                 l2_weight_decay=0.001,
+                 m_optimizer=torch.optim.Adam,
+                 *args: Any,
+                 **kwargs: Any):
+        super().__init__(*args, **kwargs)
+        self.classifier = classifier
+        self.criterion = criterion
+        self.train_metrics = TorchMetrics()
+        self.validate_metrics = TorchMetrics()
+        self.test_metrics = TorchMetrics()
+
+        self.regularization = regularization
+        self.l1_lambda = l1_lambda
+        self.l2_weight_decay = l2_weight_decay
+        self.m_optimizer = m_optimizer
+        pass
+
+    def forward(self, x, *args: Any, **kwargs: Any) -> Any:
+        return self.classifier.forward(x)
+
+    def configure_optimizers(self) -> OptimizerLRScheduler:
+        # Here we add weight decay (L2 regularization) to the optimizer
+        weight_decay = 0.0
+        if self.regularization == 2 or self.regularization == 3:
+            weight_decay = self.l2_weight_decay
+        return self.m_optimizer(self.parameters(), lr=1e-3, weight_decay=weight_decay)
+
+    def training_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+        # Accuracy on training batch data
+        x, y = batch
+        x = [i.float() for i in x]
+        preds = self.forward(x)
+        loss = self.criterion(preds, y)
+
+        if self.regularization == 1 or self.regularization == 3:  # apply L1 regularization
+            l1_norm = sum(p.abs().sum() for p in self.parameters())
+            loss += self.l1_lambda * l1_norm
+
+        self.log("train_loss", loss)
+        # calculate the scores start
+        self.train_metrics.update_on_each_step(batch_predicted_labels=preds, batch_actual_labels=y)
+        # calculate the scores end
+        return loss
+
+    def on_train_epoch_end(self) -> None:
+        timber.info(green + "on_train_epoch_end")
+        self.train_metrics.compute_and_reset_on_epoch_end(log=self.log, log_prefix="train")
+        pass
+
+    def validation_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+        # Accuracy on validation batch data
+        x, y = batch
+        x = [i.float() for i in x]
+
+        preds = self.forward(x)
+        loss = self.criterion(preds, y)
+        self.log("valid_loss", loss)
+        # calculate the scores start
+        self.validate_metrics.update_on_each_step(batch_predicted_labels=preds, batch_actual_labels=y)
+        # calculate the scores end
+        return loss
+
+    def on_validation_epoch_end(self) -> None:
+        timber.info(blue + "on_validation_epoch_end")
+        self.validate_metrics.compute_and_reset_on_epoch_end(log=self.log, log_prefix="validate", log_color=blue)
+        return None
+
+    def test_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+        # Accuracy on test batch data
+        x, y = batch
+        x = [i.float() for i in x]
+
+        preds = self.forward(x)
+        loss = self.criterion(preds, y)
+        self.log("test_loss", loss)  # do we need this?
+        # calculate the scores start
+        self.test_metrics.update_on_each_step(batch_predicted_labels=preds, batch_actual_labels=y)
+        # calculate the scores end
+        return loss
+
+    def on_test_epoch_end(self) -> None:
+        timber.info(magenta + "on_test_epoch_end")
+        self.test_metrics.compute_and_reset_on_epoch_end(log=self.log, log_prefix="test", log_color=magenta)
+        return None
+
+    pass
+
+
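Note: the default criterion is nn.BCELoss, which expects probabilities in [0, 1]; that is why Cnn1dClassifier below ends with a sigmoid. A small sketch (illustrative values, not part of the commit) of the equivalence with the logits-based alternative named in the comment:

    import torch
    from torch import nn

    logits = torch.tensor([0.3, -1.2])
    targets = torch.tensor([1.0, 0.0])
    # BCELoss on sigmoid(logits) matches BCEWithLogitsLoss on the raw logits
    assert torch.allclose(nn.BCELoss()(torch.sigmoid(logits), targets),
                          nn.BCEWithLogitsLoss()(logits, targets))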
+# Some more util functions!
+def create_conv_sequence(in_channel_num_of_nucleotides, num_filters, kernel_size_k_mer_motif) -> nn.Sequential:
+    conv1d = nn.Conv1d(in_channels=in_channel_num_of_nucleotides, out_channels=num_filters,
+                       kernel_size=kernel_size_k_mer_motif,
+                       padding="same")  # don't use stride for now, keep it simple
+    activation = nn.ReLU(inplace=False)  # (inplace=True) will mess with interpretability
+    pooling = nn.MaxPool1d(
+        kernel_size=kernel_size_k_mer_motif)  # don't use stride for now, keep it simple
+
+    return nn.Sequential(conv1d, activation, pooling)
+
+
+class Cnn1dClassifier(nn.Module):
+    def __init__(self,
+                 seq_len,
+                 in_channel_num_of_nucleotides=4,
+                 kernel_size_k_mer_motif=4,
+                 num_filters=32,
+                 lstm_hidden_size=128,
+                 dnn_size=128,
+                 conv_seq_list_size=3,  # note: lstm_hidden_size and conv_seq_list_size are currently unused
+                 *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.file_name = f"weights_Cnn1dClassifier_seqlen_{seq_len}.pth"
+
+        self.seq_layer_forward = create_conv_sequence(in_channel_num_of_nucleotides, num_filters,
+                                                      kernel_size_k_mer_motif)
+        self.seq_layer_backward = create_conv_sequence(in_channel_num_of_nucleotides, num_filters,
+                                                       kernel_size_k_mer_motif)
+
+        self.flatten = nn.Flatten()
+
+        dnn_in_features = int(num_filters * (seq_len * 2) / kernel_size_k_mer_motif)  # "same" padding keeps seq_len; max-pooling divides it by the kernel size
+        # two because the forward and backward sequences are concatenated
+        self.dnn = nn.Linear(in_features=dnn_in_features, out_features=dnn_size)
+        self.dnn_activation = nn.ReLU(inplace=False)  # inplace = true messes with interpretability!
+        self.dropout = nn.Dropout(p=0.33)
+
+        self.output_layer = nn.Linear(in_features=dnn_size, out_features=1)
+        self.output_activation = torch.sigmoid  # not needed if using nn.BCEWithLogitsLoss()
+
+        self.layer_output_logger: dict = {}
+        pass
+
+    def forward(self, x):
+        xf, xb = x[0], x[1]
+
+        hf = self.seq_layer_forward(xf)
+        timber.debug(red + f"1{ hf.shape = }")
+        hb = self.seq_layer_backward(xb)
+        timber.debug(green + f"2{ hb.shape = }")
+
+        h = torch.concatenate(tensors=(hf, hb), dim=2)
+        timber.debug(yellow + f"4{ h.shape = } concat")
+
+        h = self.flatten(h)
+        timber.debug(yellow + f"5{ h.shape = } flatten")
+
+        h = self.dnn(h)
+        timber.debug(yellow + f"8{ h.shape = } dnn")
+
+        h = self.dnn_activation(h)
+        timber.debug(blue + f"9{ h.shape = } dnn_activation")
+        h = self.dropout(h)
+        timber.debug(blue + f"10{ h.shape = } dropout")
+        h = self.output_layer(h)
+        timber.debug(blue + f"11{ h.shape = } output_layer")
+        h = self.output_activation(h)
+        timber.debug(blue + f"12{ h.shape = } output_activation")
+        return h
+
+
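Note: a smoke test of the forward pass with the defaults (not part of the commit), confirming the dnn_in_features arithmetic: 32 filters * (200 * 2) / 4 = 3200 flattened features feeding the linear layer.

    import torch
    model = Cnn1dClassifier(seq_len=200)
    xf = torch.randn(8, 4, 200)  # (batch, nucleotide channels, seq_len), forward strand
    xb = torch.randn(8, 4, 200)  # reverse-complement strand
    out = model([xf, xb])
    print(out.shape)  # torch.Size([8, 1]): per-sequence probabilities from the sigmoid head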
+def start(classifier_model, model_save_path, is_attention_model=False, m_optimizer=torch.optim.Adam, WINDOW=200,
+          dataset_folder_prefix="inputdata/", is_binned=True, is_debug=False, max_epochs=10):
+    # experiment = 'tutorial_3'
+    # if not os.path.exists(experiment):
+    #     os.makedirs(experiment)
+    """
+    x_train, x_tmp, y_train, y_tmp = train_test_split(df["sequence"], df["label"], test_size=0.2)
+    x_test, x_val, y_test, y_val = train_test_split(x_tmp, y_tmp, test_size=0.5)
+
+    train_dataset = MyDataSet(x_train, y_train)
+    val_dataset = MyDataSet(x_val, y_val)
+    test_dataset = MyDataSet(x_test, y_test)
+    """
+    file_suffix = ""
+    if is_binned:
+        file_suffix = "_binned"  # note: currently unused; the dataset name below is hardcoded
+
+    dataset_map = load_dataset("fahimfarhan/mqtl-classification-dataset-binned-200")
+
+    train_dataset = MQTLDataset(dataset_map["train"], check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug)
+    val_dataset = MQTLDataset(dataset_map["validate"], check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug)
+    test_dataset = MQTLDataset(dataset_map["test"], check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug)
+
+    data_module = MqtlDataModule(train_ds=train_dataset, val_ds=val_dataset, test_ds=test_dataset)
+
+    classifier_model = classifier_model  # .to(DEVICE)
+
+    classifier_module = MQtlClassifierLightningModule(classifier=classifier_model, regularization=2,
+                                                      m_optimizer=m_optimizer)
+
+    # if os.path.exists(model_save_path):
+    #     classifier_module.load_state_dict(torch.load(model_save_path))
+
+    classifier_module = classifier_module  # .double()
+
+    trainer = Trainer(max_epochs=max_epochs, precision="32")
+    trainer.fit(model=classifier_module, datamodule=data_module)
+    timber.info("\n\n")
+    trainer.test(model=classifier_module, datamodule=data_module)
+    timber.info("\n\n")
+    torch.save(classifier_module.state_dict(), model_save_path)
+
+    trainer.push_to_hub("fahimfarhan/mqtl-classifier-model")  # warning: pytorch_lightning's Trainer has no push_to_hub method; this raises AttributeError at runtime
+
+    # start_interpreting_ig_and_dl(classifier_model, WINDOW, dataset_folder_prefix=dataset_folder_prefix)
+    # start_interpreting_with_dlshap(classifier_model, WINDOW, dataset_folder_prefix=dataset_folder_prefix)
+    # if is_attention_model:  # todo: repair it later
+    #     start_interpreting_attention_failed(classifier_model)
+    pass
+
+
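Note: since pytorch_lightning's Trainer does not expose push_to_hub (that method belongs to transformers.Trainer), a hedged alternative using huggingface_hub could upload the saved weights instead; the repo id is taken from the original call, and a configured Hugging Face token is assumed:

    from huggingface_hub import HfApi

    api = HfApi()  # assumes HF_TOKEN (or a cached login) is available
    api.upload_file(
        path_or_fileobj="weights_Cnn1dClassifier_seqlen_200.pth",  # model_save_path from start()
        path_in_repo="weights_Cnn1dClassifier_seqlen_200.pth",
        repo_id="fahimfarhan/mqtl-classifier-model",
    )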
+if __name__ == '__main__':
+    WINDOW = 200
+    simple_cnn = Cnn1dClassifier(seq_len=WINDOW)
+    simple_cnn.enable_logging = True
+
+    start(classifier_model=simple_cnn, model_save_path=simple_cnn.file_name, WINDOW=WINDOW,
+          dataset_folder_prefix="inputdata/", is_debug=True, max_epochs=3)
+
+    pass
requirements.txt ADDED
@@ -0,0 +1,32 @@
+accelerate  # required by HyenaDNA
+datasets
+pandas
+polars
+numpy
+matplotlib
+scipy
+shap
+scikit-learn
+skorch==1.0.0
+six
+hyperopt
+requests
+pyyaml
+Bio
+plotly
+Levenshtein
+# pytorch
+captum
+torch==2.4.0
+torchvision
+torchaudio
+torchsummary
+torcheval
+pydot
+pydotplus
+PySide2  # matplotlib dependency on ubuntu; you may need something else for another os/env setup
+torchviz
+gReLU  # luckily now available in pip!
+# gReLU @ git+https://github.com/Genentech/gReLU  # @623fee8023aabcef89f0afeedbeafff4b71453af
+# lightning[extra]  # because I got a spurious warning in the console logs
+torchmetrics