ribesstefano committed on
Commit
b09510c
·
1 Parent(s): 5c27a23

Added script file for hparam CV training

Browse files
notebooks/protac_degradation_predictor.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
notebooks/protac_degradation_predictor.py ADDED
@@ -0,0 +1,988 @@
1
+ # %% [markdown]
2
+ # # PROTAC-Degradation-Predictor
3
+
4
+ # %%
5
+ import pandas as pd
6
+
7
+ protac_df = pd.read_csv('../data/PROTAC-Degradation-DB.csv')
8
+ protac_df.head()
9
+
10
+ # %%
11
+ # Get the unique Article IDs of the entries with NaN values in the Active column
12
+ nan_active = protac_df[protac_df['Active'].isna()]['Article DOI'].unique()
13
+ nan_active
14
+
15
+ # %%
16
+ # Map E3 Ligase Iap to IAP
17
+ protac_df['E3 Ligase'] = protac_df['E3 Ligase'].str.replace('Iap', 'IAP')
18
+
19
+ # %%
20
+ protac_df.columns
21
+
22
+ # %%
23
+ cells = sorted(protac_df['Cell Type'].dropna().unique().tolist())
24
+ print(f'Number of non-cleaned cell lines: {len(cells)}')
25
+
26
+ # %%
27
+ cells = sorted(protac_df['Cell Line Identifier'].dropna().unique().tolist())
28
+ print(f'Number of cleaned cell lines: {len(cells)}')
29
+
30
+ # %%
31
+ unlabeled_df = protac_df[protac_df['Active'].isna()]
32
+ print(f'Number of compounds in test set: {len(unlabeled_df)}')
33
+
34
+ # %% [markdown]
35
+ # ## Load Protein Embeddings
36
+
37
+ # %% [markdown]
38
+ # Protein embeddings downloaded from [Uniprot](https://www.uniprot.org/help/embeddings).
39
+ #
40
+ # Please note that running the following cell the first time might take a while.
41
+
42
+ # %%
43
+ import os
44
+ import urllib.request
45
+
46
+ download_link = "https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/embeddings/UP000005640_9606/per-protein.h5"
47
+ embeddings_path = "../data/uniprot2embedding.h5"
48
+ if not os.path.exists(embeddings_path):
49
+ # Download the file
50
+ print(f'Downloading embeddings from {download_link}')
51
+ urllib.request.urlretrieve(download_link, embeddings_path)
52
+
53
+ # %%
54
+ import h5py
55
+ import numpy as np
56
+ from tqdm.auto import tqdm
57
+
58
+ protein_embeddings = {}
59
+ with h5py.File("../data/uniprot2embedding.h5", "r") as file:
60
+ print(f"number of entries: {len(file.items()):,}")
61
+ uniprots = protac_df['Uniprot'].unique().tolist()
62
+ uniprots += protac_df['E3 Ligase Uniprot'].unique().tolist()
63
+ for i, sequence_id in tqdm(enumerate(uniprots), desc='Loading protein embeddings'):
64
+ try:
65
+ embedding = file[sequence_id][:]
66
+ protein_embeddings[sequence_id] = np.array(embedding)
67
+ if i < 10:
68
+ print(
69
+ f"\tid: {sequence_id}, "
70
+ f"\tembeddings shape: {embedding.shape}, "
71
+ f"\tembeddings mean: {np.array(embedding).mean()}"
72
+ )
73
+ except KeyError:
74
+ print(f'KeyError for {sequence_id}')
75
+ protein_embeddings[sequence_id] = np.zeros((1024,))
76
+
77
+ # %% [markdown]
78
+ # ## Load Cell Embeddings
79
+
80
+ # %%
81
+ import pickle
82
+
83
+ cell2embedding_filepath = '../data/cell2embedding.pkl'
84
+ with open(cell2embedding_filepath, 'rb') as f:
85
+ cell2embedding = pickle.load(f)
86
+ print(f'Loaded {len(cell2embedding)} cell lines')
87
+
88
+ # %%
89
+ emb_shape = cell2embedding[list(cell2embedding.keys())[0]].shape
90
+ # Assign all-zero vectors to cell lines that are not in the embedding file
91
+ for cell_line in protac_df['Cell Line Identifier'].unique():
92
+ if cell_line not in cell2embedding:
93
+ cell2embedding[cell_line] = np.zeros(emb_shape)
94
+
95
+ # %% [markdown]
96
+ # ## Precompute Molecular Fingerprints
97
+
98
+ # %%
99
+ from rdkit import Chem
100
+ from rdkit.Chem import AllChem
101
+ from rdkit.Chem import Draw
102
+
103
+ morgan_radius = 15
104
+ n_bits = 1024
105
+
106
+ # fpgen = AllChem.GetAtomPairGenerator()
107
+ rdkit_fpgen = AllChem.GetRDKitFPGenerator(maxPath=5, fpSize=512)
108
+ morgan_fpgen = AllChem.GetMorganGenerator(radius=morgan_radius, fpSize=n_bits)
109
+
110
+ smiles2fp = {}
111
+ for smiles in tqdm(protac_df['Smiles'].unique().tolist(), desc='Precomputing fingerprints'):
112
+ # Get the fingerprint as a bit vector
113
+ morgan_fp = morgan_fpgen.GetFingerprint(Chem.MolFromSmiles(smiles))
114
+ # rdkit_fp = rdkit_fpgen.GetFingerprint(Chem.MolFromSmiles(smiles))
115
+ # fp = np.concatenate([morgan_fp, rdkit_fp])
116
+ smiles2fp[smiles] = morgan_fp
117
+
118
+ # Count the number of unique SMILES and the number of unique Morgan fingerprints
119
+ print(f'Number of unique SMILES: {len(smiles2fp)}')
120
+ print(f'Number of unique fingerprints: {len(set([tuple(fp) for fp in smiles2fp.values()]))}')
121
+ # Get the list of SMILES with overlapping fingerprints
122
+ overlapping_smiles = []
123
+ unique_fps = set()
124
+ for smiles, fp in smiles2fp.items():
125
+ if tuple(fp) in unique_fps:
126
+ overlapping_smiles.append(smiles)
127
+ else:
128
+ unique_fps.add(tuple(fp))
129
+ print(f'Number of SMILES with overlapping fingerprints: {len(overlapping_smiles)}')
130
+ print(f'Number of overlapping SMILES in protac_df: {len(protac_df[protac_df["Smiles"].isin(overlapping_smiles)])}')
131
+
132
+ # %%
133
+ # Get the pair-wise tanimoto similarity between the PROTAC fingerprints
134
+ from rdkit import DataStructs
135
+ from collections import defaultdict
136
+
137
+ tanimoto_matrix = defaultdict(list)
138
+ for i, smiles1 in enumerate(tqdm(protac_df['Smiles'].unique(), desc='Computing Tanimoto similarity')):
139
+ fp1 = smiles2fp[smiles1]
140
+ # TODO: Use BulkTanimotoSimilarity
141
+ for j, smiles2 in enumerate(protac_df['Smiles'].unique()):
142
+ if j < i:
143
+ continue
144
+ fp2 = smiles2fp[smiles2]
145
+ tanimoto_dist = DataStructs.TanimotoSimilarity(fp1, fp2)
146
+ tanimoto_matrix[smiles1].append(tanimoto_dist)
147
+ avg_tanimoto = {k: np.mean(v) for k, v in tanimoto_matrix.items()}
148
+ protac_df['Avg Tanimoto'] = protac_df['Smiles'].map(avg_tanimoto)
149
+
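The TODO above points at RDKit's `BulkTanimotoSimilarity` as a faster alternative to the Python-level inner loop. A minimal sketch of that replacement (not part of the committed script, and assuming `smiles2fp` still holds RDKit bit vectors, i.e. before the later conversion to numpy arrays):

from rdkit import DataStructs
import numpy as np

unique_smiles_list = protac_df['Smiles'].unique().tolist()
fps = [smiles2fp[s] for s in unique_smiles_list]
avg_tanimoto = {}
for i, smiles in enumerate(unique_smiles_list):
    # Similarity of fps[i] against itself and all following fingerprints,
    # mirroring the j >= i behaviour of the loop above
    sims = DataStructs.BulkTanimotoSimilarity(fps[i], fps[i:])
    avg_tanimoto[smiles] = float(np.mean(sims))
protac_df['Avg Tanimoto'] = protac_df['Smiles'].map(avg_tanimoto)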
150
+ # %%
151
+ # # Plot the distribution of the average Tanimoto similarity
152
+ # import seaborn as sns
153
+ # import matplotlib.pyplot as plt
154
+
155
+ # sns.histplot(protac_df['Avg Tanimoto'], bins=50)
156
+ # plt.xlabel('Average Tanimoto similarity')
157
+ # plt.ylabel('Count')
158
+ # plt.title('Distribution of average Tanimoto similarity')
159
+ # plt.grid(axis='y', alpha=0.5)
160
+ # plt.show()
161
+
162
+ # %%
163
+ smiles2fp = {s: np.array(fp) for s, fp in smiles2fp.items()}
164
+
165
+ # %% [markdown]
166
+ # ## Set the Column to Predict
167
+
168
+ # %%
169
+ # active_col = 'Active'
170
+ active_col = 'Active - OR'
171
+
172
+
173
+ from sklearn.preprocessing import StandardScaler
174
+
175
+ # %% [markdown]
176
+ # ## Define Torch Dataset
177
+
178
+ # %%
179
+ from imblearn.over_sampling import SMOTE, ADASYN
180
+ from sklearn.preprocessing import LabelEncoder
181
+ import pandas as pd
182
+ import numpy as np
183
+
184
+ # %%
185
+ from torch.utils.data import Dataset, DataLoader
186
+
187
+
188
+ class PROTAC_Dataset(Dataset):
189
+ def __init__(
190
+ self,
191
+ protac_df,
192
+ protein_embeddings=protein_embeddings,
193
+ cell2embedding=cell2embedding,
194
+ smiles2fp=smiles2fp,
195
+ use_smote=False,
196
+ oversampler=None,
197
+ use_ored_activity=False,
198
+ ):
199
+ """ Initialize the PROTAC dataset
200
+
201
+ Args:
202
+ protac_df (pd.DataFrame): The PROTAC dataframe
203
+ protein_embeddings (dict): Dictionary of protein embeddings
204
+ cell2embedding (dict): Dictionary of cell line embeddings
205
+ smiles2fp (dict): Dictionary of SMILES to fingerprint
206
+ use_smote (bool): Whether to use SMOTE for oversampling
207
+ use_ored_activity (bool): Whether to use the 'Active - OR' column
208
+ """
209
+ # NOTE: Entries with NaN in the 'Active' column are expected to be filtered out upstream
210
+ self.data = protac_df # [~protac_df['Active'].isna()]
211
+ self.protein_embeddings = protein_embeddings
212
+ self.cell2embedding = cell2embedding
213
+ self.smiles2fp = smiles2fp
214
+
215
+ self.smiles_emb_dim = smiles2fp[list(smiles2fp.keys())[0]].shape[0]
216
+ self.protein_emb_dim = protein_embeddings[list(
217
+ protein_embeddings.keys())[0]].shape[0]
218
+ self.cell_emb_dim = cell2embedding[list(
219
+ cell2embedding.keys())[0]].shape[0]
220
+
221
+ self.active_label = 'Active - OR' if use_ored_activity else 'Active'
222
+
223
+ self.use_smote = use_smote
224
+ self.oversampler = oversampler
225
+ # Apply SMOTE
226
+ if self.use_smote:
227
+ self.apply_smote()
228
+
229
+ def apply_smote(self):
230
+ # Prepare the dataset for SMOTE
231
+ features = []
232
+ labels = []
233
+ for _, row in self.data.iterrows():
234
+ smiles_emb = smiles2fp[row['Smiles']]
235
+ poi_emb = protein_embeddings[row['Uniprot']]
236
+ e3_emb = protein_embeddings[row['E3 Ligase Uniprot']]
237
+ cell_emb = cell2embedding[row['Cell Line Identifier']]
238
+ features.append(np.hstack([
239
+ smiles_emb.astype(np.float32),
240
+ poi_emb.astype(np.float32),
241
+ e3_emb.astype(np.float32),
242
+ cell_emb.astype(np.float32),
243
+ ]))
244
+ labels.append(row[self.active_label])
245
+
246
+ # Convert to numpy array
247
+ features = np.array(features).astype(np.float32)
248
+ labels = np.array(labels).astype(np.float32)
249
+
250
+ # Initialize SMOTE and fit
251
+ if self.oversampler is None:
252
+ oversampler = SMOTE(random_state=42)
253
+ else:
254
+ oversampler = self.oversampler
255
+ features_smote, labels_smote = oversampler.fit_resample(features, labels)
256
+
257
+ # Separate the features back into their respective embeddings
258
+ smiles_embs = features_smote[:, :self.smiles_emb_dim]
259
+ poi_embs = features_smote[:,
260
+ self.smiles_emb_dim:self.smiles_emb_dim+self.protein_emb_dim]
261
+ e3_embs = features_smote[:, self.smiles_emb_dim +
262
+ self.protein_emb_dim:self.smiles_emb_dim+2*self.protein_emb_dim]
263
+ cell_embs = features_smote[:, -self.cell_emb_dim:]
264
+
265
+ # Reconstruct the dataframe with oversampled data
266
+ df_smote = pd.DataFrame({
267
+ 'Smiles': list(smiles_embs),
268
+ 'Uniprot': list(poi_embs),
269
+ 'E3 Ligase Uniprot': list(e3_embs),
270
+ 'Cell Line Identifier': list(cell_embs),
271
+ self.active_label: labels_smote
272
+ })
273
+ self.data = df_smote
274
+
275
+ def __len__(self):
276
+ return len(self.data)
277
+
278
+ def __getitem__(self, idx):
279
+ if self.use_smote:
280
+ # NOTE: We do not need to look up the embeddings anymore
281
+ elem = {
282
+ 'smiles_emb': self.data['Smiles'].iloc[idx],
283
+ 'poi_emb': self.data['Uniprot'].iloc[idx],
284
+ 'e3_emb': self.data['E3 Ligase Uniprot'].iloc[idx],
285
+ 'cell_emb': self.data['Cell Line Identifier'].iloc[idx],
286
+ 'active': self.data[self.active_label].iloc[idx],
287
+ }
288
+ else:
289
+ elem = {
290
+ 'smiles_emb': self.smiles2fp[self.data['Smiles'].iloc[idx]].astype(np.float32),
291
+ 'poi_emb': self.protein_embeddings[self.data['Uniprot'].iloc[idx]].astype(np.float32),
292
+ 'e3_emb': self.protein_embeddings[self.data['E3 Ligase Uniprot'].iloc[idx]].astype(np.float32),
293
+ 'cell_emb': self.cell2embedding[self.data['Cell Line Identifier'].iloc[idx]].astype(np.float32),
294
+ 'active': 1. if self.data[self.active_label].iloc[idx] else 0.,
295
+ }
296
+ return elem
297
+
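A quick usage sketch for the dataset class above (illustrative only, not part of the committed script; it assumes `protac_df`, `active_col`, `protein_embeddings`, `cell2embedding` and `smiles2fp` were built as in the previous cells):

from torch.utils.data import DataLoader

labeled_df = protac_df[protac_df[active_col].notna()]
dataset = PROTAC_Dataset(
    labeled_df,
    protein_embeddings=protein_embeddings,
    cell2embedding=cell2embedding,
    smiles2fp=smiles2fp,
    use_smote=False,
    use_ored_activity=(active_col != 'Active'),
)
loader = DataLoader(dataset, batch_size=32, shuffle=True)
batch = next(iter(loader))
print(batch['smiles_emb'].shape, batch['cell_emb'].shape, batch['active'].shape)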
298
+ # %%
299
+ import warnings
300
+ import torch
301
+ import torch.nn as nn
302
+ import torch.nn.functional as F
303
+ import torch.optim as optim
304
+ import pytorch_lightning as pl
305
+ from torchmetrics import (
306
+ Accuracy,
307
+ AUROC,
308
+ Precision,
309
+ Recall,
310
+ F1Score,
311
+ )
312
+ from torchmetrics import MetricCollection
313
+
314
+ # Ignore UserWarning from PyTorch Lightning
315
+ warnings.filterwarnings("ignore", ".*does not have many workers.*")
316
+
317
+ class PROTAC_Model(pl.LightningModule):
318
+
319
+ def __init__(
320
+ self,
321
+ hidden_dim,
322
+ smiles_emb_dim=1024,
323
+ poi_emb_dim=1024,
324
+ e3_emb_dim=1024,
325
+ cell_emb_dim=768,
326
+ batch_size=32,
327
+ learning_rate=1e-3,
328
+ dropout=0.2,
329
+ train_dataset=None,
330
+ val_dataset=None,
331
+ test_dataset=None,
332
+ disabled_embeddings=[],
333
+ ):
334
+ super().__init__()
335
+ self.poi_emb_dim = poi_emb_dim
336
+ self.e3_emb_dim = e3_emb_dim
337
+ self.cell_emb_dim = cell_emb_dim
338
+ self.smiles_emb_dim = smiles_emb_dim
339
+ self.hidden_dim = hidden_dim
340
+ self.batch_size = batch_size
341
+ self.learning_rate = learning_rate
342
+ self.train_dataset = train_dataset
343
+ self.val_dataset = val_dataset
344
+ self.test_dataset = test_dataset
345
+ self.disabled_embeddings = disabled_embeddings
346
+ # Set our init args as class attributes
347
+ self.__dict__.update(locals()) # Add arguments as attributes
348
+ # Save the arguments passed to init
349
+ ignore_args_as_hyperparams = [
350
+ 'train_dataset',
351
+ 'test_dataset',
352
+ 'val_dataset',
353
+ ]
354
+ self.save_hyperparameters(ignore=ignore_args_as_hyperparams)
355
+
356
+ if 'poi' not in self.disabled_embeddings:
357
+ self.poi_emb = nn.Linear(poi_emb_dim, hidden_dim)
358
+ # # Set the POI surrogate model as a Sequential model
359
+ # self.poi_emb = nn.Sequential(
360
+ # nn.Linear(poi_emb_dim, hidden_dim),
361
+ # nn.GELU(),
362
+ # nn.Dropout(p=dropout),
363
+ # nn.Linear(hidden_dim, hidden_dim),
364
+ # # nn.ReLU(),
365
+ # # nn.Dropout(p=dropout),
366
+ # )
367
+ if 'e3' not in self.disabled_embeddings:
368
+ self.e3_emb = nn.Linear(e3_emb_dim, hidden_dim)
369
+ # self.e3_emb = nn.Sequential(
370
+ # nn.Linear(e3_emb_dim, hidden_dim),
371
+ # # nn.ReLU(),
372
+ # nn.Dropout(p=dropout),
373
+ # # nn.Linear(hidden_dim, hidden_dim),
374
+ # # nn.ReLU(),
375
+ # # nn.Dropout(p=dropout),
376
+ # )
377
+ if 'cell' not in self.disabled_embeddings:
378
+ self.cell_emb = nn.Linear(cell_emb_dim, hidden_dim)
379
+ # self.cell_emb = nn.Sequential(
380
+ # nn.Linear(cell_emb_dim, hidden_dim),
381
+ # # nn.ReLU(),
382
+ # nn.Dropout(p=dropout),
383
+ # # nn.Linear(hidden_dim, hidden_dim),
384
+ # # nn.ReLU(),
385
+ # # nn.Dropout(p=dropout),
386
+ # )
387
+ if 'smiles' not in self.disabled_embeddings:
388
+ self.smiles_emb = nn.Linear(smiles_emb_dim, hidden_dim)
389
+ # self.smiles_emb = nn.Sequential(
390
+ # nn.Linear(smiles_emb_dim, hidden_dim),
391
+ # # nn.ReLU(),
392
+ # nn.Dropout(p=dropout),
393
+ # # nn.Linear(hidden_dim, hidden_dim),
394
+ # # nn.ReLU(),
395
+ # # nn.Dropout(p=dropout),
396
+ # )
397
+
398
+ self.fc1 = nn.Linear(
399
+ hidden_dim * (4 - len(self.disabled_embeddings)), hidden_dim)
400
+ self.fc2 = nn.Linear(hidden_dim, hidden_dim)
401
+ self.fc3 = nn.Linear(hidden_dim, 1)
402
+
403
+ self.dropout = nn.Dropout(p=dropout)
404
+
405
+ stages = ['train_metrics', 'val_metrics', 'test_metrics']
406
+ self.metrics = nn.ModuleDict({s: MetricCollection({
407
+ 'acc': Accuracy(task='binary'),
408
+ 'roc_auc': AUROC(task='binary'),
409
+ 'precision': Precision(task='binary'),
410
+ 'recall': Recall(task='binary'),
411
+ 'f1_score': F1Score(task='binary'),
412
+ 'opt_score': Accuracy(task='binary') + F1Score(task='binary'),
413
+ 'hp_metric': Accuracy(task='binary'),
414
+ }, prefix=s.replace('metrics', '')) for s in stages})
415
+
416
+ # Misc settings
417
+ self.missing_dataset_error = \
418
+ '''Class variable `{0}` is None. If the model was loaded from a checkpoint, the dataset must be set manually:
419
+
420
+ model = {1}.load_from_checkpoint('checkpoint.ckpt')
421
+ model.{0} = my_{0}
422
+ '''
423
+
424
+ def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb):
425
+ embeddings = []
426
+ if 'poi' not in self.disabled_embeddings:
427
+ embeddings.append(self.poi_emb(poi_emb))
428
+ if 'e3' not in self.disabled_embeddings:
429
+ embeddings.append(self.e3_emb(e3_emb))
430
+ if 'cell' not in self.disabled_embeddings:
431
+ embeddings.append(self.cell_emb(cell_emb))
432
+ if 'smiles' not in self.disabled_embeddings:
433
+ embeddings.append(self.smiles_emb(smiles_emb))
434
+ x = torch.cat(embeddings, dim=1)
435
+ x = self.dropout(F.gelu(self.fc1(x)))
436
+ x = self.dropout(F.gelu(self.fc2(x)))
437
+ x = self.fc3(x)
438
+ return x
439
+
440
+ def step(self, batch, batch_idx, stage):
441
+ poi_emb = batch['poi_emb']
442
+ e3_emb = batch['e3_emb']
443
+ cell_emb = batch['cell_emb']
444
+ smiles_emb = batch['smiles_emb']
445
+ y = batch['active'].float().unsqueeze(1)
446
+
447
+ y_hat = self.forward(poi_emb, e3_emb, cell_emb, smiles_emb)
448
+ loss = F.binary_cross_entropy_with_logits(y_hat, y)
449
+
450
+ self.metrics[f'{stage}_metrics'].update(y_hat, y)
451
+ self.log(f'{stage}_loss', loss, on_epoch=True, prog_bar=True)
452
+ self.log_dict(self.metrics[f'{stage}_metrics'], on_epoch=True)
453
+
454
+ return loss
455
+
456
+ def training_step(self, batch, batch_idx):
457
+ return self.step(batch, batch_idx, 'train')
458
+
459
+ def validation_step(self, batch, batch_idx):
460
+ return self.step(batch, batch_idx, 'val')
461
+
462
+ def test_step(self, batch, batch_idx):
463
+ return self.step(batch, batch_idx, 'test')
464
+
465
+ def configure_optimizers(self):
466
+ return optim.Adam(self.parameters(), lr=self.learning_rate)
467
+
468
+ def predict_step(self, batch, batch_idx):
469
+ poi_emb = batch['poi_emb']
470
+ e3_emb = batch['e3_emb']
471
+ cell_emb = batch['cell_emb']
472
+ smiles_emb = batch['smiles_emb']
473
+
474
+ y_hat = self.forward(poi_emb, e3_emb, cell_emb, smiles_emb)
475
+ return torch.sigmoid(y_hat)
476
+
477
+ def train_dataloader(self):
478
+ if self.train_dataset is None:
479
+ format = 'train_dataset', self.__class__.__name__
480
+ raise ValueError(self.missing_dataset_error.format(*format))
481
+ return DataLoader(
482
+ self.train_dataset,
483
+ batch_size=self.batch_size,
484
+ shuffle=True,
485
+ # drop_last=True,
486
+ )
487
+
488
+ def val_dataloader(self):
489
+ if self.val_dataset is None:
490
+ format = 'val_dataset', self.__class__.__name__
491
+ raise ValueError(self.missing_dataset_error.format(*format))
492
+ return DataLoader(
493
+ self.val_dataset,
494
+ batch_size=self.batch_size,
495
+ shuffle=False,
496
+ )
497
+
498
+ def test_dataloader(self):
499
+ if self.test_dataset is None:
500
+ format = 'test_dataset', self.__class__.__name__
501
+ raise ValueError(self.missing_dataset_error.format(*format))
502
+ return DataLoader(
503
+ self.test_dataset,
504
+ batch_size=self.batch_size,
505
+ shuffle=False,
506
+ )
507
+
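The `missing_dataset_error` message above spells out the intended reload pattern for checkpointed models; a short sketch of it (hypothetical checkpoint path and dataset variables, not part of the committed script):

# Hypothetical reload of a trained checkpoint; datasets must be reattached by hand.
model = PROTAC_Model.load_from_checkpoint('checkpoint.ckpt')
model.train_dataset = my_train_dataset  # PROTAC_Dataset instances built as above
model.val_dataset = my_val_dataset
model.test_dataset = my_test_dataset
trainer = pl.Trainer()
trainer.validate(model)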
508
+ # %% [markdown]
509
+ # ## Test Sets
510
+
511
+ # %% [markdown]
512
+ # We want a different test set per Cross-Validation (CV) experiment (see further down). We are interested in three scenarios:
513
+ # * Randomly splitting the data into training and test sets. Hence, the test set shall contain unique SMILES and Uniprots
514
+ # * Splitting the data according to their Uniprot. Hence, the test set shall contain unique Uniprots
515
+ # * Splitting the data according to their SMILES, _i.e._, the test set shall contain unique SMILES
516
+
517
+ # %%
518
+ test_indeces = {}
519
+
520
+ # %% [markdown]
521
+ # Isolating the unique SMILES and Uniprots:
522
+
523
+ # %%
524
+ active_df = protac_df[protac_df[active_col].notna()].copy()
525
+
526
+ # Get the unique SMILES and Uniprot
527
+ unique_smiles = active_df['Smiles'].value_counts() == 1
528
+ unique_uniprot = active_df['Uniprot'].value_counts() == 1
529
+ print(f'Number of unique SMILES: {unique_smiles.sum()}')
530
+ print(f'Number of unique Uniprot: {unique_uniprot.sum()}')
531
+ # Sample 5% of len(active_df) from the unique SMILES and Uniprot entries and get the
532
+ # indices for a test set
533
+ n = int(0.05 * len(active_df)) // 2
534
+ unique_smiles = unique_smiles[unique_smiles].sample(n=n, random_state=42)
535
+ # unique_uniprot = unique_uniprot[unique_uniprot].sample(n=, random_state=42)
536
+ unique_indices = active_df[
537
+ active_df['Smiles'].isin(unique_smiles.index) &
538
+ active_df['Uniprot'].isin(unique_uniprot.index)
539
+ ].index
540
+ print(f'Number of unique indices: {len(unique_indices)} ({len(unique_indices) / len(active_df):.1%})')
541
+
542
+ test_indeces['random'] = unique_indices
543
+
544
+ # # Get the test set
545
+ # test_df = active_df.loc[unique_indices]
546
+ # # Bar plot of the test Active distribution as percentage
547
+ # test_df['Active'].value_counts(normalize=True).plot(kind='bar')
548
+ # plt.title('Test set Active distribution')
549
+ # plt.show()
550
+ # # Bar plot of the test Active - OR distribution as percentage
551
+ # test_df['Active - OR'].value_counts(normalize=True).plot(kind='bar')
552
+ # plt.title('Test set Active - OR distribution')
553
+ # plt.show()
554
+
555
+ # %% [markdown]
556
+ # Isolating the unique Uniprots:
557
+
558
+ # %%
559
+ active_df = protac_df[protac_df[active_col].notna()].copy()
560
+
561
+ unique_uniprot = active_df['Uniprot'].value_counts() == 1
562
+ print(f'Number of unique Uniprot: {unique_uniprot.sum()}')
563
+
564
+ # NOTE: Since there are very few of them, all unique Uniprots are used as the test set.
565
+ # Get the indices for a test set
566
+ unique_indices = active_df[active_df['Uniprot'].isin(unique_uniprot.index)].index
567
+
568
+
569
+ test_indeces['uniprot'] = unique_indices
570
+ print(f'Number of unique indices: {len(unique_indices)} ({len(unique_indices) / len(active_df):.1%})')
571
+
572
+ # %% [markdown]
573
+ # DEPRECATED (the following results in too small a test set): Before starting any training, we isolate a small group of test data. Each element in the test set is selected so that all the following conditions are met:
574
+ # * its SMILES is unique
575
+ # * its POI is unique
576
+ # * its (SMILES, POI) pair is unique
577
+
578
+ # %%
579
+ active_df = protac_df[protac_df[active_col].notna()]
580
+
581
+ # Find the samples that:
582
+ # * have their SMILES appearing only once in the dataframe
583
+ # * have their Uniprot appearing only once in the dataframe
584
+ # * have their (Smiles, Uniprot) pair appearing only once in the dataframe
585
+ unique_smiles = active_df['Smiles'].value_counts() == 1
586
+ unique_uniprot = active_df['Uniprot'].value_counts() == 1
587
+ unique_smiles_uniprot = active_df.groupby(['Smiles', 'Uniprot']).size() == 1
588
+
589
+ # Get the indices of the unique samples
590
+ unique_smiles_idx = active_df['Smiles'].map(unique_smiles)
591
+ unique_uniprot_idx = active_df['Uniprot'].map(unique_uniprot)
592
+ unique_smiles_uniprot_idx = active_df.set_index(['Smiles', 'Uniprot']).index.map(unique_smiles_uniprot)
593
+
594
+ # Cross the indices to get the unique samples
595
+ # unique_samples = active_df[unique_smiles_idx & unique_uniprot_idx & unique_smiles_uniprot_idx].index
596
+ unique_samples = active_df[unique_smiles_idx & unique_uniprot_idx].index
597
+ test_df = active_df.loc[unique_samples]
598
+
599
+ warnings.filterwarnings("ignore", ".*FixedLocator.*")
600
+
601
+ # %% [markdown]
602
+ # ## Cross-Validation Training
603
+
604
+ # %% [markdown]
605
+ # Cross validation training with 5 splits. The split operation is done in three different ways:
606
+ #
607
+ # * Random split
608
+ # * POI-wise: a given POI never appears in both the training and validation splits
609
+ # * Least Tanimoto similarity PROTAC-wise
610
+
611
+ # %% [markdown]
612
+ # ### Plotting CV Folds
613
+
614
+ # %%
615
+ from sklearn.model_selection import (
616
+ StratifiedKFold,
617
+ StratifiedGroupKFold,
618
+ )
619
+ from sklearn.preprocessing import OrdinalEncoder
620
+
621
+ # NOTE: When set to 60, it will result in 29 groups, with nice distributions of
622
+ # the number of unique groups in the train and validation sets, together with
623
+ # the number of active and inactive PROTACs.
624
+ n_bins_tanimoto = 60 if active_col == 'Active' else 400
625
+ n_splits = 5
626
+ # The train and validation sets will be created from the active PROTACs only,
627
+ # i.e., the ones with 'Active' column not NaN, and that are NOT in the test set
628
+ active_df = protac_df[protac_df[active_col].notna()]
629
+ train_val_df = active_df[~active_df.index.isin(test_df.index)].copy()
630
+
631
+ # Make three groups for CV:
632
+ # * Random split
633
+ # * Split by Uniprot (POI)
634
+ # * Split by least tanimoto similarity PROTAC-wise
635
+ group_types = [
636
+ 'random',
637
+ 'uniprot',
638
+ 'tanimoto',
639
+ ]
640
+ for group_type in group_types:
641
+ if group_type == 'random':
642
+ kf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
643
+ groups = None
644
+ elif group_type == 'uniprot':
645
+ # Split by Uniprot
646
+ kf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=42)
647
+ encoder = OrdinalEncoder()
648
+ groups = encoder.fit_transform(train_val_df['Uniprot'].values.reshape(-1, 1))
649
+ print(f'Number of unique groups: {len(encoder.categories_[0])}')
650
+ elif group_type == 'tanimoto':
651
+ # Split by Tanimoto similarity, i.e., group PROTACs with similar Avg Tanimoto
652
+ kf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=42)
653
+ tanimoto_groups = pd.cut(train_val_df['Avg Tanimoto'], bins=n_bins_tanimoto).copy()
654
+ encoder = OrdinalEncoder()
655
+ groups = encoder.fit_transform(tanimoto_groups.values.reshape(-1, 1))
656
+ print(f'Number of unique groups: {len(encoder.categories_[0])}')
657
+
658
+
659
+ X = train_val_df.drop(columns=active_col)
660
+ y = train_val_df[active_col].tolist()
661
+
662
+ # print(f'Group: {group_type}')
663
+ # fig, ax = plt.subplots(figsize=(6, 3))
664
+ # plot_cv_indices(kf, X=X, y=y, group=groups, ax=ax, n_splits=n_splits)
665
+ # plt.tight_layout()
666
+ # plt.show()
667
+
668
+ stats = []
669
+ for k, (train_index, val_index) in enumerate(kf.split(X, y, groups)):
670
+ train_df = train_val_df.iloc[train_index]
671
+ val_df = train_val_df.iloc[val_index]
672
+ stat = {
673
+ 'fold': k,
674
+ 'train_len': len(train_df),
675
+ 'val_len': len(val_df),
676
+ 'train_perc': len(train_df) / len(train_val_df),
677
+ 'val_perc': len(val_df) / len(train_val_df),
678
+ 'train_active (%)': train_df[active_col].sum() / len(train_df) * 100,
679
+ 'train_inactive (%)': (len(train_df) - train_df[active_col].sum()) / len(train_df) * 100,
680
+ 'val_active (%)': val_df[active_col].sum() / len(val_df) * 100,
681
+ 'val_inactive (%)': (len(val_df) - val_df[active_col].sum()) / len(val_df) * 100,
682
+ 'num_leaking_uniprot': len(set(train_df['Uniprot']).intersection(set(val_df['Uniprot']))),
683
+ 'num_leaking_smiles': len(set(train_df['Smiles']).intersection(set(val_df['Smiles']))),
684
+ }
685
+ if group_type != 'random':
686
+ stat['train_unique_groups'] = len(np.unique(groups[train_index]))
687
+ stat['val_unique_groups'] = len(np.unique(groups[val_index]))
688
+ stats.append(stat)
689
+ print('-' * 120)
690
+
691
+ # %% [markdown]
692
+ # ### Run CV
693
+
694
+ # %%
695
+ import warnings
696
+
697
+ # Seed everything in pytorch lightning
698
+ pl.seed_everything(42)
699
+
700
+
701
+ def train_model(
702
+ train_df,
703
+ val_df,
704
+ test_df=None,
705
+ hidden_dim=768,
706
+ batch_size=8,
707
+ learning_rate=2e-5,
708
+ max_epochs=50,
709
+ smiles_emb_dim=1024,
710
+ smote_n_neighbors=5,
711
+ use_ored_activity=False if active_col == 'Active' else True,
712
+ fast_dev_run=False,
713
+ disabled_embeddings=[],
714
+ ) -> tuple:
715
+ """ Train a PROTAC model using the given datasets and hyperparameters.
716
+
717
+ Args:
718
+ train_df (pd.DataFrame): The training set.
719
+ val_df (pd.DataFrame): The validation set.
720
+ test_df (pd.DataFrame): The test set.
721
+ hidden_dim (int): The hidden dimension of the model.
722
+ batch_size (int): The batch size.
723
+ learning_rate (float): The learning rate.
724
+ max_epochs (int): The maximum number of epochs.
725
+ smiles_emb_dim (int): The dimension of the SMILES embeddings.
726
+ smote_n_neighbors (int): The number of neighbors for the SMOTE oversampler.
727
+ use_ored_activity (bool): Whether to use the ORED activity column.
728
+ fast_dev_run (bool): Whether to run a fast development run.
729
+ disabled_embeddings (list): The list of disabled embeddings.
730
+
731
+ Returns:
732
+ tuple: The trained model, the trainer, and the metrics.
733
+ """
734
+ oversampler = SMOTE(k_neighbors=smote_n_neighbors, random_state=42)
735
+ train_ds = PROTAC_Dataset(
736
+ train_df,
737
+ protein_embeddings,
738
+ cell2embedding,
739
+ smiles2fp,
740
+ use_smote=True,
741
+ oversampler=oversampler,
742
+ use_ored_activity=use_ored_activity,
743
+ )
744
+ val_ds = PROTAC_Dataset(
745
+ val_df,
746
+ protein_embeddings,
747
+ cell2embedding,
748
+ smiles2fp,
749
+ use_ored_activity=use_ored_activity,
750
+ )
751
+ if test_df is not None:
752
+ test_ds = PROTAC_Dataset(
753
+ test_df,
754
+ protein_embeddings,
755
+ cell2embedding,
756
+ smiles2fp,
757
+ use_ored_activity=use_ored_activity,
758
+ )
759
+ logger = pl.loggers.TensorBoardLogger(
760
+ save_dir='../logs',
761
+ name='protac',
762
+ )
763
+ callbacks = [
764
+ pl.callbacks.EarlyStopping(
765
+ monitor='train_loss',
766
+ patience=10,
767
+ mode='min',
768
+ verbose=True,
769
+ ),
770
+ # pl.callbacks.ModelCheckpoint(
771
+ # monitor='val_acc',
772
+ # mode='max',
773
+ # verbose=True,
774
+ # filename='{epoch}-{val_metrics_opt_score:.4f}',
775
+ # ),
776
+ ]
777
+ # Define Trainer
778
+ trainer = pl.Trainer(
779
+ logger=logger,
780
+ callbacks=callbacks,
781
+ max_epochs=max_epochs,
782
+ fast_dev_run=fast_dev_run,
783
+ enable_model_summary=False,
784
+ enable_checkpointing=False,
785
+ )
786
+ model = PROTAC_Model(
787
+ hidden_dim=hidden_dim,
788
+ smiles_emb_dim=smiles_emb_dim,
789
+ poi_emb_dim=1024,
790
+ e3_emb_dim=1024,
791
+ cell_emb_dim=768,
792
+ batch_size=batch_size,
793
+ learning_rate=learning_rate,
794
+ train_dataset=train_ds,
795
+ val_dataset=val_ds,
796
+ test_dataset=test_ds if test_df is not None else None,
797
+ disabled_embeddings=disabled_embeddings,
798
+ )
799
+ with warnings.catch_warnings():
800
+ warnings.simplefilter("ignore")
801
+ trainer.fit(model)
802
+ metrics = trainer.validate(model, verbose=False)[0]
803
+ if test_df is not None:
804
+ test_metrics = trainer.test(model, verbose=False)[0]
805
+ metrics.update(test_metrics)
806
+ return model, trainer, metrics
807
+
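A single-split sanity check of `train_model` might look like the sketch below (illustrative only, not part of the committed script; it reuses `active_df`, `active_col` and `test_df` from the cells above and an arbitrary 90/10 stratified split):

from sklearn.model_selection import train_test_split

train_val = active_df[~active_df.index.isin(test_df.index)]
tr_df, vl_df = train_test_split(
    train_val,
    test_size=0.1,
    stratify=train_val[active_col],
    random_state=42,
)
model, trainer, metrics = train_model(
    tr_df,
    vl_df,
    test_df=test_df,
    fast_dev_run=True,  # single-batch run, just to check the wiring
)
print(metrics)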
808
+ # %% [markdown]
809
+ # Setup hyperparameter optimization:
810
+
811
+ # %%
812
+ import optuna
813
+ import pandas as pd
814
+
815
+
816
+ def objective(
817
+ trial,
818
+ train_df,
819
+ val_df,
820
+ hidden_dim_options,
821
+ batch_size_options,
822
+ learning_rate_options,
823
+ max_epochs_options,
824
+ fast_dev_run=False,
825
+ ) -> float:
826
+ # Generate the hyperparameters
827
+ hidden_dim = trial.suggest_categorical('hidden_dim', hidden_dim_options)
828
+ batch_size = trial.suggest_categorical('batch_size', batch_size_options)
829
+ learning_rate = trial.suggest_float('learning_rate', *learning_rate_options, log=True)
830
+ max_epochs = trial.suggest_categorical('max_epochs', max_epochs_options)
831
+
832
+ # Train the model with the current set of hyperparameters
833
+ _, _, metrics = train_model(
834
+ train_df,
835
+ val_df,
836
+ hidden_dim=hidden_dim,
837
+ batch_size=batch_size,
838
+ learning_rate=learning_rate,
839
+ max_epochs=max_epochs,
840
+ fast_dev_run=fast_dev_run,
841
+ )
842
+
843
+ # Metrics is a dictionary containing at least the validation loss
844
+ val_loss = metrics['val_loss']
845
+ val_acc = metrics['val_acc']
846
+ val_roc_auc = metrics['val_roc_auc']
847
+
848
+ # Optuna aims to minimize the objective
849
+ return val_loss - val_acc - val_roc_auc
850
+
851
+
852
+ def hyperparameter_tuning_and_training(
853
+ train_df,
854
+ val_df,
855
+ test_df,
856
+ fast_dev_run=False,
857
+ n_trials=20,
858
+ ) -> tuple:
859
+ """ Hyperparameter tuning and training of a PROTAC model.
860
+
861
+ Args:
862
+ train_df (pd.DataFrame): The training set.
863
+ val_df (pd.DataFrame): The validation set.
864
+ test_df (pd.DataFrame): The test set.
865
+ fast_dev_run (bool): Whether to run a fast development run.
866
+
867
+ Returns:
868
+ tuple: The trained model, the trainer, and the best metrics.
869
+ """
870
+ # Define the search space
871
+ hidden_dim_options = [256, 512, 768]
872
+ batch_size_options = [8, 16, 32]
873
+ learning_rate_options = (1e-5, 1e-3) # min and max values for loguniform distribution
874
+ max_epochs_options = [10, 20, 50]
875
+
876
+ # Create an Optuna study object
877
+ study = optuna.create_study(direction='minimize')
878
+ study.optimize(lambda trial: objective(
879
+ trial,
880
+ train_df,
881
+ val_df,
882
+ hidden_dim_options,
883
+ batch_size_options,
884
+ learning_rate_options,
885
+ max_epochs_options,
886
+ fast_dev_run=fast_dev_run,),
887
+ n_trials=n_trials,
888
+ )
889
+
890
+ # Retrieve the best hyperparameters
891
+ best_params = study.best_params
892
+ best_hidden_dim = best_params['hidden_dim']
893
+ best_batch_size = best_params['batch_size']
894
+ best_learning_rate = best_params['learning_rate']
895
+ best_max_epochs = best_params['max_epochs']
896
+
897
+ # Retrain the model with the best hyperparameters
898
+ model, trainer, metrics = train_model(
899
+ train_df,
900
+ val_df,
901
+ test_df,
902
+ hidden_dim=best_hidden_dim,
903
+ batch_size=best_batch_size,
904
+ learning_rate=best_learning_rate,
905
+ max_epochs=best_max_epochs,
906
+ fast_dev_run=fast_dev_run,
907
+ )
908
+
909
+ # Return the best metrics
910
+ return model, trainer, metrics
911
+
912
+ # Example usage
913
+ # train_df, val_df, test_df = load_your_data() # You need to load your datasets here
914
+ # model, trainer, best_metrics = hyperparameter_tuning_and_training(train_df, val_df, test_df)
915
+
916
+ # %% [markdown]
917
+ # Loop over the different splits and train the model:
918
+
919
+ # %%
920
+ n_splits = 5
921
+ report = []
922
+ active_df = protac_df[protac_df[active_col].notna()]
923
+ train_val_df = active_df[~active_df.index.isin(unique_samples)]
924
+
925
+ # Make directory ../reports if it does not exist
926
+ if not os.path.exists('../reports'):
927
+ os.makedirs('../reports')
928
+
929
+ for group_type in ['random', 'uniprot', 'tanimoto']:
930
+ print(f'Starting CV for group type: {group_type}')
931
+ # Setup CV iterator and groups
932
+ if group_type == 'random':
933
+ kf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
934
+ groups = None
935
+ elif group_type == 'uniprot':
936
+ # Split by Uniprot
937
+ kf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=42)
938
+ encoder = OrdinalEncoder()
939
+ groups = encoder.fit_transform(train_val_df['Uniprot'].values.reshape(-1, 1))
940
+ elif group_type == 'tanimoto':
941
+ # Split by Tanimoto similarity, i.e., group PROTACs with similar Avg Tanimoto
942
+ kf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=42)
943
+ tanimoto_groups = pd.cut(train_val_df['Avg Tanimoto'], bins=n_bins_tanimoto).copy()
944
+ encoder = OrdinalEncoder()
945
+ groups = encoder.fit_transform(tanimoto_groups.values.reshape(-1, 1))
946
+ # Start the CV over the folds
947
+ X = train_val_df.drop(columns=active_col)
948
+ y = train_val_df[active_col].tolist()
949
+ for k, (train_index, val_index) in enumerate(kf.split(X, y, groups)):
950
+ train_df = train_val_df.iloc[train_index]
951
+ val_df = train_val_df.iloc[val_index]
952
+ stats = {
953
+ 'fold': k,
954
+ 'group_type': group_type,
955
+ 'train_len': len(train_df),
956
+ 'val_len': len(val_df),
957
+ 'train_perc': len(train_df) / len(train_val_df),
958
+ 'val_perc': len(val_df) / len(train_val_df),
959
+ 'train_active_perc': train_df[active_col].sum() / len(train_df),
960
+ 'train_inactive_perc': (len(train_df) - train_df[active_col].sum()) / len(train_df),
961
+ 'val_active_perc': val_df[active_col].sum() / len(val_df),
962
+ 'val_inactive_perc': (len(val_df) - val_df[active_col].sum()) / len(val_df),
963
+ 'test_active_perc': test_df[active_col].sum() / len(test_df),
964
+ 'test_inactive_perc': (len(test_df) - test_df[active_col].sum()) / len(test_df),
965
+ 'num_leaking_uniprot': len(set(train_df['Uniprot']).intersection(set(val_df['Uniprot']))),
966
+ 'num_leaking_smiles': len(set(train_df['Smiles']).intersection(set(val_df['Smiles']))),
967
+ }
968
+ if group_type != 'random':
969
+ stats['train_unique_groups'] = len(np.unique(groups[train_index]))
970
+ stats['val_unique_groups'] = len(np.unique(groups[val_index]))
971
+ # Train and evaluate the model
972
+ # model, trainer, metrics = train_model(train_df, val_df, test_df)
973
+ model, trainer, metrics = hyperparameter_tuning_and_training(
974
+ train_df,
975
+ val_df,
976
+ test_df,
977
+ fast_dev_run=False,
978
+ n_trials=50,
979
+ )
980
+ stats.update(metrics)
981
+ del model
982
+ del trainer
983
+ report.append(stats)
984
+ report = pd.DataFrame(report)
985
+ report.to_csv(
986
+ f'../reports/cv_report_hparam_search_{n_splits}-splits.csv', index=False,
987
+ )
988
+