ribesstefano committed
Commit ab45a22 · 1 Parent(s): acd572a

Finalized predictor script
notebooks/protac_degradation_predictor.py CHANGED
@@ -1,19 +1,20 @@
1
- import optuna
2
- from optuna.samplers import TPESampler
3
- import h5py
4
  import os
5
  import pickle
6
  import warnings
7
  import logging
8
  import pandas as pd
9
  import numpy as np
10
- import urllib.request
11
 
12
  from rdkit import Chem
13
  from rdkit.Chem import AllChem
14
  from rdkit import DataStructs
15
- from collections import defaultdict
16
- from typing import Literal
17
  from jsonargparse import CLI
18
  from tqdm.auto import tqdm
19
  from imblearn.over_sampling import SMOTE, ADASYN
@@ -44,25 +45,39 @@ warnings.filterwarnings("ignore", ".*FixedLocator*")
44
  # Ignore UserWarning from PyTorch Lightning
45
  warnings.filterwarnings("ignore", ".*does not have many workers.*")
46
 
47
-
48
  protac_df = pd.read_csv('../data/PROTAC-Degradation-DB.csv')
49
- protac_df.head()
50
-
51
- # Get the unique Article IDs of the entries with NaN values in the Active column
52
- nan_active = protac_df[protac_df['Active'].isna()]['Article DOI'].unique()
53
- nan_active
54
 
55
  # Map E3 Ligase Iap to IAP
56
  protac_df['E3 Ligase'] = protac_df['E3 Ligase'].str.replace('Iap', 'IAP')
57
 
58
- cells = sorted(protac_df['Cell Type'].dropna().unique().tolist())
59
- print(f'Number of non-cleaned cell lines: {len(cells)}')
60
-
61
- cells = sorted(protac_df['Cell Line Identifier'].dropna().unique().tolist())
62
- print(f'Number of cleaned cell lines: {len(cells)}')
63
-
64
- unlabeled_df = protac_df[protac_df['Active'].isna()]
65
- print(f'Number of compounds in test set: {len(unlabeled_df)}')
66
 
67
  # ## Load Protein Embeddings
68
 
@@ -101,9 +116,10 @@ for cell_line in protac_df['Cell Line Identifier'].unique():
101
  cell2embedding[cell_line] = np.zeros(emb_shape)
102
 
103
  ## Precompute Molecular Fingerprints
 
104
  morgan_fpgen = AllChem.GetMorganGenerator(
105
  radius=15,
106
- fpSize=1024,
107
  includeChirality=True,
108
  )
109
 
@@ -131,7 +147,7 @@ print(f'Number of overlapping SMILES in protac_df: {len(protac_df[protac_df["Smi
131
  tanimoto_matrix = defaultdict(list)
132
  for i, smiles1 in enumerate(tqdm(protac_df['Smiles'].unique(), desc='Computing Tanimoto similarity')):
133
  fp1 = smiles2fp[smiles1]
134
- # TODO: Use BulkTanimotoSimilarity
135
  for j, smiles2 in enumerate(protac_df['Smiles'].unique()):
136
  if j < i:
137
  continue
@@ -153,7 +169,8 @@ class PROTAC_Dataset(Dataset):
153
  smiles2fp=smiles2fp,
154
  use_smote=False,
155
  oversampler=None,
156
- use_ored_activity=False,
 
157
  ):
158
  """ Initialize the PROTAC dataset
159
 
@@ -165,11 +182,13 @@ class PROTAC_Dataset(Dataset):
165
  use_smote (bool): Whether to use SMOTE for oversampling
166
  use_ored_activity (bool): Whether to use the 'Active - OR' column
167
  """
168
- # Filter out examples with NaN in 'Active' column
169
- self.data = protac_df # [~protac_df['Active'].isna()]
170
  self.protein_embeddings = protein_embeddings
171
  self.cell2embedding = cell2embedding
172
  self.smiles2fp = smiles2fp
 
 
173
 
174
  self.smiles_emb_dim = smiles2fp[list(smiles2fp.keys())[0]].shape[0]
175
  self.protein_emb_dim = protein_embeddings[list(
@@ -177,11 +196,18 @@ class PROTAC_Dataset(Dataset):
177
  self.cell_emb_dim = cell2embedding[list(
178
  cell2embedding.keys())[0]].shape[0]
179
 
180
- self.active_label = 'Active - OR' if use_ored_activity else 'Active'
181
 
 
182
  self.use_smote = use_smote
183
  self.oversampler = oversampler
184
- # Apply SMOTE
185
  if self.use_smote:
186
  self.apply_smote()
187
 
@@ -190,15 +216,11 @@ class PROTAC_Dataset(Dataset):
190
  features = []
191
  labels = []
192
  for _, row in self.data.iterrows():
193
- smiles_emb = smiles2fp[row['Smiles']]
194
- poi_emb = protein_embeddings[row['Uniprot']]
195
- e3_emb = protein_embeddings[row['E3 Ligase Uniprot']]
196
- cell_emb = cell2embedding[row['Cell Line Identifier']]
197
  features.append(np.hstack([
198
- smiles_emb.astype(np.float32),
199
- poi_emb.astype(np.float32),
200
- e3_emb.astype(np.float32),
201
- cell_emb.astype(np.float32),
202
  ]))
203
  labels.append(row[self.active_label])
204
 
@@ -231,27 +253,74 @@ class PROTAC_Dataset(Dataset):
231
  })
232
  self.data = df_smote
233
 
234
  def __len__(self):
235
  return len(self.data)
236
 
237
  def __getitem__(self, idx):
238
- if self.use_smote:
239
- # NOTE: We do not need to look up the embeddings anymore
240
- elem = {
241
- 'smiles_emb': self.data['Smiles'].iloc[idx],
242
- 'poi_emb': self.data['Uniprot'].iloc[idx],
243
- 'e3_emb': self.data['E3 Ligase Uniprot'].iloc[idx],
244
- 'cell_emb': self.data['Cell Line Identifier'].iloc[idx],
245
- 'active': self.data[self.active_label].iloc[idx],
246
- }
247
- else:
248
- elem = {
249
- 'smiles_emb': self.smiles2fp[self.data['Smiles'].iloc[idx]].astype(np.float32),
250
- 'poi_emb': self.protein_embeddings[self.data['Uniprot'].iloc[idx]].astype(np.float32),
251
- 'e3_emb': self.protein_embeddings[self.data['E3 Ligase Uniprot'].iloc[idx]].astype(np.float32),
252
- 'cell_emb': self.cell2embedding[self.data['Cell Line Identifier'].iloc[idx]].astype(np.float32),
253
- 'active': 1. if self.data[self.active_label].iloc[idx] else 0.,
254
- }
255
  return elem
256
 
257
 
@@ -260,18 +329,19 @@ class PROTAC_Model(pl.LightningModule):
260
  def __init__(
261
  self,
262
  hidden_dim: int,
263
- smiles_emb_dim: int = 1024,
264
  poi_emb_dim: int = 1024,
265
  e3_emb_dim: int = 1024,
266
  cell_emb_dim: int = 768,
267
  batch_size: int = 32,
268
  learning_rate: float = 1e-3,
269
  dropout: float = 0.2,
270
- join_embeddings: Literal['concat', 'sum'] = 'concat',
271
  train_dataset: PROTAC_Dataset = None,
272
  val_dataset: PROTAC_Dataset = None,
273
  test_dataset: PROTAC_Dataset = None,
274
  disabled_embeddings: list = [],
 
275
  ):
276
  super().__init__()
277
  self.poi_emb_dim = poi_emb_dim
@@ -286,6 +356,7 @@ class PROTAC_Model(pl.LightningModule):
286
  self.val_dataset = val_dataset
287
  self.test_dataset = test_dataset
288
  self.disabled_embeddings = disabled_embeddings
 
289
  # Set our init args as class attributes
290
  self.__dict__.update(locals()) # Add arguments as attributes
291
  # Save the arguments passed to init
@@ -296,19 +367,29 @@ class PROTAC_Model(pl.LightningModule):
296
  ]
297
  self.save_hyperparameters(ignore=ignore_args_as_hyperparams)
298
 
299
- if 'poi' not in self.disabled_embeddings:
300
- self.poi_emb = nn.Linear(poi_emb_dim, hidden_dim)
301
- if 'e3' not in self.disabled_embeddings:
302
- self.e3_emb = nn.Linear(e3_emb_dim, hidden_dim)
303
- if 'cell' not in self.disabled_embeddings:
304
- self.cell_emb = nn.Linear(cell_emb_dim, hidden_dim)
305
- if 'smiles' not in self.disabled_embeddings:
306
- self.smiles_emb = nn.Linear(smiles_emb_dim, hidden_dim)
307
-
308
- if self.join_embeddings == 'concat':
309
  joint_dim = hidden_dim * (4 - len(self.disabled_embeddings))
310
  elif self.join_embeddings == 'sum':
311
  joint_dim = hidden_dim
 
 
312
  self.fc1 = nn.Linear(joint_dim, hidden_dim)
313
  self.fc2 = nn.Linear(hidden_dim, hidden_dim)
314
  self.fc3 = nn.Linear(hidden_dim, 1)
@@ -333,25 +414,46 @@ class PROTAC_Model(pl.LightningModule):
333
  model = {1}.load_from_checkpoint('checkpoint.ckpt')
334
  model.{0} = my_{0}
335
  '''
336
 
337
  def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb):
338
  embeddings = []
339
- if 'poi' not in self.disabled_embeddings:
340
- embeddings.append(self.poi_emb(poi_emb))
341
- if 'e3' not in self.disabled_embeddings:
342
- embeddings.append(self.e3_emb(e3_emb))
343
- if 'cell' not in self.disabled_embeddings:
344
- embeddings.append(self.cell_emb(cell_emb))
345
- if 'smiles' not in self.disabled_embeddings:
346
- embeddings.append(self.smiles_emb(smiles_emb))
347
- if self.join_embeddings == 'concat':
348
  x = torch.cat(embeddings, dim=1)
349
- elif self.join_embeddings == 'sum':
350
- if len(embeddings) > 1:
351
- embeddings = torch.stack(embeddings, dim=1)
352
- x = torch.sum(embeddings, dim=1)
353
- else:
354
- x = embeddings[0]
355
  x = self.dropout(F.relu(self.fc1(x)))
356
  x = self.dropout(F.relu(self.fc2(x)))
357
  x = self.fc3(x)
@@ -391,6 +493,25 @@ class PROTAC_Model(pl.LightningModule):
391
  cell_emb = batch['cell_emb']
392
  smiles_emb = batch['smiles_emb']
393
 
394
  y_hat = self.forward(poi_emb, e3_emb, cell_emb, smiles_emb)
395
  return torch.sigmoid(y_hat)
396
 
@@ -398,6 +519,7 @@ class PROTAC_Model(pl.LightningModule):
398
  if self.train_dataset is None:
399
  format = 'train_dataset', self.__class__.__name__
400
  raise ValueError(self.missing_dataset_error.format(*format))
 
401
  return DataLoader(
402
  self.train_dataset,
403
  batch_size=self.batch_size,
@@ -425,23 +547,25 @@ class PROTAC_Model(pl.LightningModule):
425
  shuffle=False,
426
  )
427
 
428
-
429
  def train_model(
430
- train_df,
431
- val_df,
432
- test_df=None,
433
- hidden_dim=768,
434
- batch_size=8,
435
- learning_rate=2e-5,
436
- max_epochs=50,
437
- smiles_emb_dim=1024,
438
- join_embeddings='concat',
439
- smote_k_neighbors=5,
440
- use_ored_activity=True,
441
- fast_dev_run=False,
442
- use_logger=True,
443
- logger_name='protac',
444
- disabled_embeddings=[],
 
 
 
445
  ) -> tuple:
446
  """ Train a PROTAC model using the given datasets and hyperparameters.
447
 
@@ -455,7 +579,6 @@ def train_model(
455
  max_epochs (int): The maximum number of epochs.
456
  smiles_emb_dim (int): The dimension of the SMILES embeddings.
457
  smote_k_neighbors (int): The number of neighbors for the SMOTE oversampler.
458
- use_ored_activity (bool): Whether to use the ORED activity column, i.e., "Active - OR" column.
459
  fast_dev_run (bool): Whether to run a fast development run.
460
  disabled_embeddings (list): The list of disabled embeddings.
461
 
@@ -468,16 +591,16 @@ def train_model(
468
  protein_embeddings,
469
  cell2embedding,
470
  smiles2fp,
471
- use_smote=True,
472
- oversampler=oversampler,
473
- use_ored_activity=use_ored_activity,
474
  )
475
  val_ds = PROTAC_Dataset(
476
  val_df,
477
  protein_embeddings,
478
  cell2embedding,
479
  smiles2fp,
480
- use_ored_activity=use_ored_activity,
481
  )
482
  if test_df is not None:
483
  test_ds = PROTAC_Dataset(
@@ -485,7 +608,7 @@ def train_model(
485
  protein_embeddings,
486
  cell2embedding,
487
  smiles2fp,
488
- use_ored_activity=use_ored_activity,
489
  )
490
  logger = pl.loggers.TensorBoardLogger(
491
  save_dir='../logs',
@@ -495,6 +618,18 @@ def train_model(
495
  pl.callbacks.EarlyStopping(
496
  monitor='train_loss',
497
  patience=10,
498
  mode='max',
499
  verbose=True,
500
  ),
@@ -514,6 +649,8 @@ def train_model(
514
  enable_model_summary=False,
515
  enable_checkpointing=False,
516
  enable_progress_bar=False,
 
 
517
  )
518
  model = PROTAC_Model(
519
  hidden_dim=hidden_dim,
@@ -522,8 +659,10 @@ def train_model(
522
  e3_emb_dim=1024,
523
  cell_emb_dim=768,
524
  batch_size=batch_size,
525
- learning_rate=learning_rate,
526
  join_embeddings=join_embeddings,
 
 
 
527
  train_dataset=train_ds,
528
  val_dataset=val_ds,
529
  test_dataset=test_ds if test_df is not None else None,
@@ -541,25 +680,42 @@ def train_model(
541
  # Setup hyperparameter optimization:
542
 
543
  def objective(
544
- trial,
545
- train_df,
546
- val_df,
547
- hidden_dim_options,
548
- batch_size_options,
549
- learning_rate_options,
550
- max_epochs_options,
551
- smote_k_neighbors_options,
552
- fast_dev_run=False,
553
- use_ored_activity=True,
554
- disabled_embeddings=[],
555
  ) -> float:
556
  # Generate the hyperparameters
557
  hidden_dim = trial.suggest_categorical('hidden_dim', hidden_dim_options)
558
  batch_size = trial.suggest_categorical('batch_size', batch_size_options)
559
  learning_rate = trial.suggest_float('learning_rate', *learning_rate_options, log=True)
560
- max_epochs = trial.suggest_categorical('max_epochs', max_epochs_options)
561
- join_embeddings = trial.suggest_categorical('join_embeddings', ['concat', 'sum'])
562
  smote_k_neighbors = trial.suggest_categorical('smote_k_neighbors', smote_k_neighbors_options)
 
 
 
563
 
564
  # Train the model with the current set of hyperparameters
565
  _, _, metrics = train_model(
@@ -569,11 +725,14 @@ def objective(
569
  batch_size=batch_size,
570
  join_embeddings=join_embeddings,
571
  learning_rate=learning_rate,
572
- max_epochs=max_epochs,
 
573
  smote_k_neighbors=smote_k_neighbors,
 
 
574
  use_logger=False,
575
  fast_dev_run=fast_dev_run,
576
- use_ored_activity=use_ored_activity,
577
  disabled_embeddings=disabled_embeddings,
578
  )
579
 
@@ -587,14 +746,14 @@ def objective(
587
 
588
 
589
  def hyperparameter_tuning_and_training(
590
- train_df,
591
- val_df,
592
- test_df,
593
- fast_dev_run=False,
594
- n_trials=20,
595
- logger_name='protac_hparam_search',
596
- use_ored_activity=True,
597
- disabled_embeddings=[],
598
  ) -> tuple:
599
  """ Hyperparameter tuning and training of a PROTAC model.
600
 
@@ -603,6 +762,10 @@ def hyperparameter_tuning_and_training(
603
  val_df (pd.DataFrame): The validation set.
604
  test_df (pd.DataFrame): The test set.
605
  fast_dev_run (bool): Whether to run a fast development run.
606
 
607
  Returns:
608
  tuple: The trained model, the trainer, and the best metrics.
@@ -611,7 +774,6 @@ def hyperparameter_tuning_and_training(
611
  hidden_dim_options = [256, 512, 768]
612
  batch_size_options = [8, 16, 32]
613
  learning_rate_options = (1e-5, 1e-3) # min and max values for loguniform distribution
614
- max_epochs_options = [10, 20, 50]
615
  smote_k_neighbors_options = list(range(3, 16))
616
 
617
  # Set the verbosity of Optuna
@@ -624,13 +786,12 @@ def hyperparameter_tuning_and_training(
624
  trial,
625
  train_df,
626
  val_df,
627
- hidden_dim_options,
628
- batch_size_options,
629
- learning_rate_options,
630
- max_epochs_options,
631
  smote_k_neighbors_options=smote_k_neighbors_options,
632
  fast_dev_run=fast_dev_run,
633
- use_ored_activity=use_ored_activity,
634
  disabled_embeddings=disabled_embeddings,
635
  ),
636
  n_trials=n_trials,
@@ -644,7 +805,7 @@ def hyperparameter_tuning_and_training(
644
  use_logger=True,
645
  logger_name=logger_name,
646
  fast_dev_run=fast_dev_run,
647
- use_ored_activity=use_ored_activity,
648
  disabled_embeddings=disabled_embeddings,
649
  **study.best_params,
650
  )
@@ -657,10 +818,11 @@ def hyperparameter_tuning_and_training(
657
 
658
 
659
  def main(
660
- use_ored_activity: bool = True,
661
  n_trials: int = 50,
662
- n_splits: int = 5,
663
  fast_dev_run: bool = False,
 
 
664
  ):
665
  """ Train a PROTAC model using the given datasets and hyperparameters.
666
 
@@ -671,101 +833,178 @@ def main(
671
  fast_dev_run (bool): Whether to run a fast development run.
672
  """
673
  ## Set the Column to Predict
674
- active_col = 'Active - OR' if use_ored_activity else 'Active'
675
- active_name = active_col.replace(' ', '').lower()
676
- active_name = 'active-and' if active_name == 'active' else active_name
677
-
678
- ## Test Sets
679
 
680
- active_df = protac_df[protac_df[active_col].notna()]
681
- # Before starting any training, we isolate a small group of test data. Each element in the test set is selected so that all the following conditions are met:
682
- # * its SMILES appears only once in the dataframe
683
- # * its Uniprot appears only once in the dataframe
684
- # * its (Smiles, Uniprot) pair appears only once in the dataframe
685
- unique_smiles = active_df['Smiles'].value_counts() == 1
686
- unique_uniprot = active_df['Uniprot'].value_counts() == 1
687
- unique_smiles_uniprot = active_df.groupby(['Smiles', 'Uniprot']).size() == 1
688
-
689
- # Get the indices of the unique samples
690
- unique_smiles_idx = active_df['Smiles'].map(unique_smiles)
691
- unique_uniprot_idx = active_df['Uniprot'].map(unique_uniprot)
692
- unique_smiles_uniprot_idx = active_df.set_index(['Smiles', 'Uniprot']).index.map(unique_smiles_uniprot)
693
-
694
- # Cross the indices to get the unique samples
695
- unique_samples = active_df[unique_smiles_idx & unique_uniprot_idx & unique_smiles_uniprot_idx].index
696
- test_df = active_df.loc[unique_samples]
697
- train_val_df = active_df[~active_df.index.isin(unique_samples)]
698
 
699
- ## Cross-Validation Training
 
 
700
 
701
- # Cross validation training with 5 splits. The split operation is done in three different ways:
702
- #
703
- # * Random split
704
- # * POI-wise: some POIs never in both splits
705
- # * Least Tanimoto similarity PROTAC-wise
706
 
707
- # NOTE: When set to 60, it will result in 29 groups, with nice distributions of
708
- # the number of unique groups in the train and validation sets, together with
709
- # the number of active and inactive PROTACs.
710
- n_bins_tanimoto = 60 if active_col == 'Active' else 400
711
 
712
  # Make directory ../reports if it does not exist
713
  if not os.path.exists('../reports'):
714
  os.makedirs('../reports')
715
 
716
- # Seed everything in pytorch lightning
717
- pl.seed_everything(42)
718
-
719
- # Loop over the different splits and train the model:
720
- report = []
721
- for group_type in ['random', 'uniprot', 'tanimoto']:
722
- print('-' * 100)
723
- print(f'Starting CV for group type: {group_type}')
724
- print('-' * 100)
725
- # Setup CV iterator and groups
726
- if group_type == 'random':
727
- kf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
728
- groups = None
729
- elif group_type == 'uniprot':
730
- # Split by Uniprot
731
- kf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=42)
732
- encoder = OrdinalEncoder()
733
- groups = encoder.fit_transform(train_val_df['Uniprot'].values.reshape(-1, 1))
734
- elif group_type == 'tanimoto':
735
- # Split by tanimoto similarity, i.e., group_type PROTACs with similar Avg Tanimoto
736
- kf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=42)
737
- tanimoto_groups = pd.cut(train_val_df['Avg Tanimoto'], bins=n_bins_tanimoto).copy()
738
- encoder = OrdinalEncoder()
739
- groups = encoder.fit_transform(tanimoto_groups.values.reshape(-1, 1))
740
  # Start the CV over the folds
741
  X = train_val_df.drop(columns=active_col)
742
  y = train_val_df[active_col].tolist()
743
- for k, (train_index, val_index) in enumerate(kf.split(X, y, groups)):
 
744
  print('-' * 100)
745
- print(f'Starting CV for group type: {group_type}, fold: {k}')
746
  print('-' * 100)
747
  train_df = train_val_df.iloc[train_index]
748
  val_df = train_val_df.iloc[val_index]
749
  stats = {
750
  'fold': k,
751
- 'group_type': group_type,
752
  'train_len': len(train_df),
753
  'val_len': len(val_df),
754
  'train_perc': len(train_df) / len(train_val_df),
755
  'val_perc': len(val_df) / len(train_val_df),
756
- 'train_active_perc': train_df[active_col].sum() / len(train_df),
757
- 'train_inactive_perc': (len(train_df) - train_df[active_col].sum()) / len(train_df),
758
- 'val_active_perc': val_df[active_col].sum() / len(val_df),
759
- 'val_inactive_perc': (len(val_df) - val_df[active_col].sum()) / len(val_df),
760
- 'test_active_perc': test_df[active_col].sum() / len(test_df),
761
- 'test_inactive_perc': (len(test_df) - test_df[active_col].sum()) / len(test_df),
762
- 'num_leaking_uniprot': len(set(train_df['Uniprot']).intersection(set(val_df['Uniprot']))),
763
- 'num_leaking_smiles': len(set(train_df['Smiles']).intersection(set(val_df['Smiles']))),
764
- 'disabled_embeddings': np.nan,
765
  }
766
- if group_type != 'random':
767
- stats['train_unique_groups'] = len(np.unique(groups[train_index]))
768
- stats['val_unique_groups'] = len(np.unique(groups[val_index]))
 
 
769
  # Train and evaluate the model
770
  model, trainer, metrics = hyperparameter_tuning_and_training(
771
  train_df,
@@ -773,8 +1012,8 @@ def main(
773
  test_df,
774
  fast_dev_run=fast_dev_run,
775
  n_trials=n_trials,
776
- logger_name=f'protac_{active_name}_{group_type}_fold_{k}',
777
- use_ored_activity=use_ored_activity,
778
  )
779
  hparams = {p.strip('hparam_'): v for p, v in stats.items() if p.startswith('hparam_')}
780
  stats.update(metrics)
@@ -793,8 +1032,8 @@ def main(
793
  val_df,
794
  test_df,
795
  fast_dev_run=fast_dev_run,
796
- logger_name=f'protac_{active_name}_{group_type}_fold_{k}_disabled-{"-".join(disabled_embeddings)}',
797
- use_ored_activity=use_ored_activity,
798
  disabled_embeddings=disabled_embeddings,
799
  **hparams,
800
  )
@@ -803,11 +1042,11 @@ def main(
803
  del model
804
  del trainer
805
 
806
- report = pd.DataFrame(report)
807
- report.to_csv(
808
- f'../reports/cv_report_hparam_search_{n_splits}-splits_{active_name}.csv',
809
- index=False,
810
- )
811
 
812
 
813
  if __name__ == '__main__':
 
 
 
 
1
  import os
2
  import pickle
3
  import warnings
4
  import logging
5
+ from collections import defaultdict
6
+ from typing import Literal, List, Tuple, Optional
7
+ import urllib.request
8
+
9
+ import optuna
10
+ from optuna.samplers import TPESampler
11
+ import h5py
12
  import pandas as pd
13
  import numpy as np
 
14
 
15
  from rdkit import Chem
16
  from rdkit.Chem import AllChem
17
  from rdkit import DataStructs
 
 
18
  from jsonargparse import CLI
19
  from tqdm.auto import tqdm
20
  from imblearn.over_sampling import SMOTE, ADASYN
 
45
  # Ignore UserWarning from PyTorch Lightning
46
  warnings.filterwarnings("ignore", ".*does not have many workers.*")
47
 
 
48
  protac_df = pd.read_csv('../data/PROTAC-Degradation-DB.csv')
49
 
50
  # Map E3 Ligase Iap to IAP
51
  protac_df['E3 Ligase'] = protac_df['E3 Ligase'].str.replace('Iap', 'IAP')
52
 
53
+ def is_active(DC50: float, Dmax: float, oring=False, pDC50_threshold=7.0, Dmax_threshold=0.8) -> bool:
54
+ """ Check if a PROTAC is active based on DC50 and Dmax.
55
+ Args:
56
+ DC50(float): DC50 in nM
57
+ Dmax(float): Dmax in %
58
+ Returns:
59
+ bool: True if active, False if inactive, np.nan if either DC50 or Dmax is NaN
60
+ """
61
+ pDC50 = -np.log10(DC50 * 1e-9) if pd.notnull(DC50) else np.nan
62
+ Dmax = Dmax / 100
63
+ if pd.notnull(pDC50):
64
+ if pDC50 < pDC50_threshold:
65
+ return False
66
+ if pd.notnull(Dmax):
67
+ if Dmax < Dmax_threshold:
68
+ return False
69
+ if oring:
70
+ if pd.notnull(pDC50):
71
+ return True if pDC50 >= pDC50_threshold else False
72
+ elif pd.notnull(Dmax):
73
+ return True if Dmax >= Dmax_threshold else False
74
+ else:
75
+ return np.nan
76
+ else:
77
+ if pd.notnull(pDC50) and pd.notnull(Dmax):
78
+ return True if pDC50 >= pDC50_threshold and Dmax >= Dmax_threshold else False
79
+ else:
80
+ return np.nan
81
 
82
  # ## Load Protein Embeddings
83
 
 
116
  cell2embedding[cell_line] = np.zeros(emb_shape)
117
 
118
  ## Precompute Molecular Fingerprints
119
+ fingerprint_size = 224
120
  morgan_fpgen = AllChem.GetMorganGenerator(
121
  radius=15,
122
+ fpSize=fingerprint_size,
123
  includeChirality=True,
124
  )
125
 
 
147
  tanimoto_matrix = defaultdict(list)
148
  for i, smiles1 in enumerate(tqdm(protac_df['Smiles'].unique(), desc='Computing Tanimoto similarity')):
149
  fp1 = smiles2fp[smiles1]
150
+ # TODO: Use BulkTanimotoSimilarity for better performance
151
  for j, smiles2 in enumerate(protac_df['Smiles'].unique()):
152
  if j < i:
153
  continue
 
169
  smiles2fp=smiles2fp,
170
  use_smote=False,
171
  oversampler=None,
172
+ active_label='Active',
173
+ include_mol_graphs=False,
174
  ):
175
  """ Initialize the PROTAC dataset
176
 
 
182
  use_smote (bool): Whether to use SMOTE for oversampling
183
  use_ored_activity (bool): Whether to use the 'Active - OR' column
184
  """
185
+ # Filter out examples with NaN in active_col column
186
+ self.data = protac_df # [~protac_df[active_col].isna()]
187
  self.protein_embeddings = protein_embeddings
188
  self.cell2embedding = cell2embedding
189
  self.smiles2fp = smiles2fp
190
+ self.active_label = active_label
191
+ self.include_mol_graphs = include_mol_graphs
192
 
193
  self.smiles_emb_dim = smiles2fp[list(smiles2fp.keys())[0]].shape[0]
194
  self.protein_emb_dim = protein_embeddings[list(
 
196
  self.cell_emb_dim = cell2embedding[list(
197
  cell2embedding.keys())[0]].shape[0]
198
 
199
+ # Look up the embeddings
200
+ self.data = pd.DataFrame({
201
+ 'Smiles': self.data['Smiles'].apply(lambda x: smiles2fp[x].astype(np.float32)).tolist(),
202
+ 'Uniprot': self.data['Uniprot'].apply(lambda x: protein_embeddings[x].astype(np.float32)).tolist(),
203
+ 'E3 Ligase Uniprot': self.data['E3 Ligase Uniprot'].apply(lambda x: protein_embeddings[x].astype(np.float32)).tolist(),
204
+ 'Cell Line Identifier': self.data['Cell Line Identifier'].apply(lambda x: cell2embedding[x].astype(np.float32)).tolist(),
205
+ self.active_label: self.data[self.active_label].astype(np.float32).tolist(),
206
+ })
207
 
208
+ # Apply SMOTE
209
  self.use_smote = use_smote
210
  self.oversampler = oversampler
 
211
  if self.use_smote:
212
  self.apply_smote()
213
 
 
216
  features = []
217
  labels = []
218
  for _, row in self.data.iterrows():
219
  features.append(np.hstack([
220
+ row['Smiles'],
221
+ row['Uniprot'],
222
+ row['E3 Ligase Uniprot'],
223
+ row['Cell Line Identifier'],
224
  ]))
225
  labels.append(row[self.active_label])
226
 
 
253
  })
254
  self.data = df_smote
255
 
256
+ def fit_scaling(self, use_single_scaler=False, **scaler_kwargs) -> dict:
257
+ """ Fit the scalers for the data.
258
+
259
+ Returns:
260
+ dict: The fitted scalers.
261
+ """
262
+ if use_single_scaler:
263
+ scaler = StandardScaler(**scaler_kwargs)
264
+ embeddings = np.hstack([
265
+ np.array(self.data['Smiles'].tolist()),
266
+ np.array(self.data['Uniprot'].tolist()),
267
+ np.array(self.data['E3 Ligase Uniprot'].tolist()),
268
+ np.array(self.data['Cell Line Identifier'].tolist()),
269
+ ])
270
+ scaler.fit(embeddings)
271
+ return scaler
272
+ else:
273
+ scalers = {}
274
+ scalers['Smiles'] = StandardScaler(**scaler_kwargs)
275
+ scalers['Uniprot'] = StandardScaler(**scaler_kwargs)
276
+ scalers['E3 Ligase Uniprot'] = StandardScaler(**scaler_kwargs)
277
+ scalers['Cell Line Identifier'] = StandardScaler(**scaler_kwargs)
278
+
279
+ scalers['Smiles'].fit(np.stack(self.data['Smiles'].to_numpy()))
280
+ scalers['Uniprot'].fit(np.stack(self.data['Uniprot'].to_numpy()))
281
+ scalers['E3 Ligase Uniprot'].fit(np.stack(self.data['E3 Ligase Uniprot'].to_numpy()))
282
+ scalers['Cell Line Identifier'].fit(np.stack(self.data['Cell Line Identifier'].to_numpy()))
283
+
284
+ return scalers
285
+
286
+ def apply_scaling(self, scalers: dict, use_single_scaler=False):
287
+ """ Apply scaling to the data.
288
+
289
+ Args:
290
+ scalers (dict): The scalers for each feature.
291
+ """
292
+ if use_single_scaler:
293
+ embeddings = np.hstack([
294
+ np.array(self.data['Smiles'].tolist()),
295
+ np.array(self.data['Uniprot'].tolist()),
296
+ np.array(self.data['E3 Ligase Uniprot'].tolist()),
297
+ np.array(self.data['Cell Line Identifier'].tolist()),
298
+ ])
299
+ scaled_embeddings = scalers.transform(embeddings)
300
+ self.data = pd.DataFrame({
301
+ 'Smiles': list(scaled_embeddings[:, :self.smiles_emb_dim]),
302
+ 'Uniprot': list(scaled_embeddings[:, self.smiles_emb_dim:self.smiles_emb_dim+self.protein_emb_dim]),
303
+ 'E3 Ligase Uniprot': list(scaled_embeddings[:, self.smiles_emb_dim+self.protein_emb_dim:self.smiles_emb_dim+2*self.protein_emb_dim]),
304
+ 'Cell Line Identifier': list(scaled_embeddings[:, -self.cell_emb_dim:]),
305
+ self.active_label: self.data[self.active_label]
306
+ })
307
+ else:
308
+ self.data['Smiles'] = self.data['Smiles'].apply(lambda x: scalers['Smiles'].transform(x[np.newaxis, :])[0])
309
+ self.data['Uniprot'] = self.data['Uniprot'].apply(lambda x: scalers['Uniprot'].transform(x[np.newaxis, :])[0])
310
+ self.data['E3 Ligase Uniprot'] = self.data['E3 Ligase Uniprot'].apply(lambda x: scalers['E3 Ligase Uniprot'].transform(x[np.newaxis, :])[0])
311
+ self.data['Cell Line Identifier'] = self.data['Cell Line Identifier'].apply(lambda x: scalers['Cell Line Identifier'].transform(x[np.newaxis, :])[0])
312
+
313
  def __len__(self):
314
  return len(self.data)
315
 
316
  def __getitem__(self, idx):
317
+ elem = {
318
+ 'smiles_emb': self.data['Smiles'].iloc[idx],
319
+ 'poi_emb': self.data['Uniprot'].iloc[idx],
320
+ 'e3_emb': self.data['E3 Ligase Uniprot'].iloc[idx],
321
+ 'cell_emb': self.data['Cell Line Identifier'].iloc[idx],
322
+ 'active': self.data[self.active_label].iloc[idx],
323
+ }
324
  return elem
325
 
326
 
 
329
  def __init__(
330
  self,
331
  hidden_dim: int,
332
+ smiles_emb_dim: int = fingerprint_size,
333
  poi_emb_dim: int = 1024,
334
  e3_emb_dim: int = 1024,
335
  cell_emb_dim: int = 768,
336
  batch_size: int = 32,
337
  learning_rate: float = 1e-3,
338
  dropout: float = 0.2,
339
+ join_embeddings: Literal['beginning', 'concat', 'sum'] = 'concat',
340
  train_dataset: PROTAC_Dataset = None,
341
  val_dataset: PROTAC_Dataset = None,
342
  test_dataset: PROTAC_Dataset = None,
343
  disabled_embeddings: list = [],
344
+ apply_scaling: bool = False,
345
  ):
346
  super().__init__()
347
  self.poi_emb_dim = poi_emb_dim
 
356
  self.val_dataset = val_dataset
357
  self.test_dataset = test_dataset
358
  self.disabled_embeddings = disabled_embeddings
359
+ self.apply_scaling = apply_scaling
360
  # Set our init args as class attributes
361
  self.__dict__.update(locals()) # Add arguments as attributes
362
  # Save the arguments passed to init
 
367
  ]
368
  self.save_hyperparameters(ignore=ignore_args_as_hyperparams)
369
 
370
+ # Define "surrogate models" branches
371
+ if self.join_embeddings != 'beginning':
372
+ if 'poi' not in self.disabled_embeddings:
373
+ self.poi_emb = nn.Linear(poi_emb_dim, hidden_dim)
374
+ if 'e3' not in self.disabled_embeddings:
375
+ self.e3_emb = nn.Linear(e3_emb_dim, hidden_dim)
376
+ if 'cell' not in self.disabled_embeddings:
377
+ self.cell_emb = nn.Linear(cell_emb_dim, hidden_dim)
378
+ if 'smiles' not in self.disabled_embeddings:
379
+ self.smiles_emb = nn.Linear(smiles_emb_dim, hidden_dim)
380
+
381
+ # Define hidden dimension for joining layer
382
+ if self.join_embeddings == 'beginning':
383
+ joint_dim = smiles_emb_dim if 'smiles' not in self.disabled_embeddings else 0
384
+ joint_dim += poi_emb_dim if 'poi' not in self.disabled_embeddings else 0
385
+ joint_dim += e3_emb_dim if 'e3' not in self.disabled_embeddings else 0
386
+ joint_dim += cell_emb_dim if 'cell' not in self.disabled_embeddings else 0
387
+ elif self.join_embeddings == 'concat':
388
  joint_dim = hidden_dim * (4 - len(self.disabled_embeddings))
389
  elif self.join_embeddings == 'sum':
390
  joint_dim = hidden_dim
391
+
392
+ self.fc0 = nn.Linear(joint_dim, joint_dim)
393
  self.fc1 = nn.Linear(joint_dim, hidden_dim)
394
  self.fc2 = nn.Linear(hidden_dim, hidden_dim)
395
  self.fc3 = nn.Linear(hidden_dim, 1)
 
414
  model = {1}.load_from_checkpoint('checkpoint.ckpt')
415
  model.{0} = my_{0}
416
  '''
417
+
418
+ # Apply scaling in datasets
419
+ if self.apply_scaling:
420
+ use_single_scaler = True if self.join_embeddings == 'beginning' else False
421
+ self.scalers = self.train_dataset.fit_scaling(use_single_scaler)
422
+ self.train_dataset.apply_scaling(self.scalers, use_single_scaler)
423
+ self.val_dataset.apply_scaling(self.scalers, use_single_scaler)
424
+ if self.test_dataset:
425
+ self.test_dataset.apply_scaling(self.scalers, use_single_scaler)
426
 
427
  def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb):
428
  embeddings = []
429
+ if self.join_embeddings == 'beginning':
430
+ if 'poi' not in self.disabled_embeddings:
431
+ embeddings.append(poi_emb)
432
+ if 'e3' not in self.disabled_embeddings:
433
+ embeddings.append(e3_emb)
434
+ if 'cell' not in self.disabled_embeddings:
435
+ embeddings.append(cell_emb)
436
+ if 'smiles' not in self.disabled_embeddings:
437
+ embeddings.append(smiles_emb)
438
  x = torch.cat(embeddings, dim=1)
439
+ x = self.dropout(F.relu(self.fc0(x)))
440
+ else:
441
+ if 'poi' not in self.disabled_embeddings:
442
+ embeddings.append(self.poi_emb(poi_emb))
443
+ if 'e3' not in self.disabled_embeddings:
444
+ embeddings.append(self.e3_emb(e3_emb))
445
+ if 'cell' not in self.disabled_embeddings:
446
+ embeddings.append(self.cell_emb(cell_emb))
447
+ if 'smiles' not in self.disabled_embeddings:
448
+ embeddings.append(self.smiles_emb(smiles_emb))
449
+ if self.join_embeddings == 'concat':
450
+ x = torch.cat(embeddings, dim=1)
451
+ elif self.join_embeddings == 'sum':
452
+ if len(embeddings) > 1:
453
+ embeddings = torch.stack(embeddings, dim=1)
454
+ x = torch.sum(embeddings, dim=1)
455
+ else:
456
+ x = embeddings[0]
457
  x = self.dropout(F.relu(self.fc1(x)))
458
  x = self.dropout(F.relu(self.fc2(x)))
459
  x = self.fc3(x)
 
493
  cell_emb = batch['cell_emb']
494
  smiles_emb = batch['smiles_emb']
495
 
496
+ if self.apply_scaling:
497
+ if self.join_embeddings == 'beginning':
498
+ embeddings = np.hstack([
499
+ np.array(smiles_emb.tolist()),
500
+ np.array(poi_emb.tolist()),
501
+ np.array(e3_emb.tolist()),
502
+ np.array(cell_emb.tolist()),
503
+ ])
504
+ embeddings = self.scalers.transform(embeddings)
505
+ smiles_emb = embeddings[:, :self.smiles_emb_dim]
506
+ poi_emb = embeddings[:, self.smiles_emb_dim:self.smiles_emb_dim+self.poi_emb_dim]
507
+ e3_emb = embeddings[:, self.smiles_emb_dim+self.poi_emb_dim:self.smiles_emb_dim+2*self.poi_emb_dim]
508
+ cell_emb = embeddings[:, -self.cell_emb_dim:]
509
+ else:
510
+ poi_emb = self.scalers['Uniprot'].transform(poi_emb)
511
+ e3_emb = self.scalers['E3 Ligase Uniprot'].transform(e3_emb)
512
+ cell_emb = self.scalers['Cell Line Identifier'].transform(cell_emb)
513
+ smiles_emb = self.scalers['Smiles'].transform(smiles_emb)
514
+
515
  y_hat = self.forward(poi_emb, e3_emb, cell_emb, smiles_emb)
516
  return torch.sigmoid(y_hat)
517
 
 
519
  if self.train_dataset is None:
520
  format = 'train_dataset', self.__class__.__name__
521
  raise ValueError(self.missing_dataset_error.format(*format))
522
+
523
  return DataLoader(
524
  self.train_dataset,
525
  batch_size=self.batch_size,
 
547
  shuffle=False,
548
  )
549
 
 
550
  def train_model(
551
+ train_df: pd.DataFrame,
552
+ val_df: pd.DataFrame,
553
+ test_df: Optional[pd.DataFrame] = None,
554
+ hidden_dim: int = 768,
555
+ batch_size: int = 8,
556
+ learning_rate: float = 2e-5,
557
+ dropout: float = 0.2,
558
+ max_epochs: int = 50,
559
+ smiles_emb_dim: int = fingerprint_size,
560
+ join_embeddings: Literal['beginning', 'concat', 'sum'] = 'concat',
561
+ smote_k_neighbors:int = 5,
562
+ use_smote: bool = True,
563
+ apply_scaling: bool = False,
564
+ active_label:str = 'Active',
565
+ fast_dev_run: bool = False,
566
+ use_logger: bool = True,
567
+ logger_name: str = 'protac',
568
+ disabled_embeddings: List[str] = [],
569
  ) -> tuple:
570
  """ Train a PROTAC model using the given datasets and hyperparameters.
571
 
 
579
  max_epochs (int): The maximum number of epochs.
580
  smiles_emb_dim (int): The dimension of the SMILES embeddings.
581
  smote_k_neighbors (int): The number of neighbors for the SMOTE oversampler.
 
582
  fast_dev_run (bool): Whether to run a fast development run.
583
  disabled_embeddings (list): The list of disabled embeddings.
584
 
 
591
  protein_embeddings,
592
  cell2embedding,
593
  smiles2fp,
594
+ use_smote=use_smote,
595
+ oversampler=oversampler if use_smote else None,
596
+ active_label=active_label,
597
  )
598
  val_ds = PROTAC_Dataset(
599
  val_df,
600
  protein_embeddings,
601
  cell2embedding,
602
  smiles2fp,
603
+ active_label=active_label,
604
  )
605
  if test_df is not None:
606
  test_ds = PROTAC_Dataset(
 
608
  protein_embeddings,
609
  cell2embedding,
610
  smiles2fp,
611
+ active_label=active_label,
612
  )
613
  logger = pl.loggers.TensorBoardLogger(
614
  save_dir='../logs',
 
618
  pl.callbacks.EarlyStopping(
619
  monitor='train_loss',
620
  patience=10,
621
+ mode='min',
622
+ verbose=True,
623
+ ),
624
+ pl.callbacks.EarlyStopping(
625
+ monitor='val_loss',
626
+ patience=5,
627
+ mode='min',
628
+ verbose=True,
629
+ ),
630
+ pl.callbacks.EarlyStopping(
631
+ monitor='val_acc',
632
+ patience=10,
633
  mode='max',
634
  verbose=True,
635
  ),
 
649
  enable_model_summary=False,
650
  enable_checkpointing=False,
651
  enable_progress_bar=False,
652
+ devices=1,
653
+ num_nodes=1,
654
  )
655
  model = PROTAC_Model(
656
  hidden_dim=hidden_dim,
 
659
  e3_emb_dim=1024,
660
  cell_emb_dim=768,
661
  batch_size=batch_size,
 
662
  join_embeddings=join_embeddings,
663
+ dropout=dropout,
664
+ learning_rate=learning_rate,
665
+ apply_scaling=apply_scaling,
666
  train_dataset=train_ds,
667
  val_dataset=val_ds,
668
  test_dataset=test_ds if test_df is not None else None,
 
680
  # Setup hyperparameter optimization:
681
 
682
  def objective(
683
+ trial: optuna.Trial,
684
+ train_df: pd.DataFrame,
685
+ val_df: pd.DataFrame,
686
+ hidden_dim_options: List[int] = [256, 512, 768],
687
+ batch_size_options: List[int] = [8, 16, 32],
688
+ learning_rate_options: Tuple[float, float] = (1e-5, 1e-3),
689
+ smote_k_neighbors_options: List[int] = list(range(3, 16)),
690
+ dropout_options: Tuple[float, float] = (0.1, 0.5),
691
+ fast_dev_run: bool = False,
692
+ active_label: str = 'Active',
693
+ disabled_embeddings: List[str] = [],
694
  ) -> float:
695
+ """ Objective function for hyperparameter optimization.
696
+
697
+ Args:
698
+ trial (optuna.Trial): The Optuna trial object.
699
+ train_df (pd.DataFrame): The training set.
700
+ val_df (pd.DataFrame): The validation set.
701
+ hidden_dim_options (List[int]): The hidden dimension options.
702
+ batch_size_options (List[int]): The batch size options.
703
+ learning_rate_options (Tuple[float, float]): The learning rate options.
704
+ smote_k_neighbors_options (List[int]): The SMOTE k neighbors options.
705
+ dropout_options (Tuple[float, float]): The dropout options.
706
+ fast_dev_run (bool): Whether to run a fast development run.
707
+ active_label (str): The active label column.
708
+ disabled_embeddings (List[str]): The list of disabled embeddings.
709
+ """
710
  # Generate the hyperparameters
711
  hidden_dim = trial.suggest_categorical('hidden_dim', hidden_dim_options)
712
  batch_size = trial.suggest_categorical('batch_size', batch_size_options)
713
  learning_rate = trial.suggest_float('learning_rate', *learning_rate_options, log=True)
714
+ join_embeddings = trial.suggest_categorical('join_embeddings', ['beginning', 'concat', 'sum'])
 
715
  smote_k_neighbors = trial.suggest_categorical('smote_k_neighbors', smote_k_neighbors_options)
716
+ use_smote = trial.suggest_categorical('use_smote', [True, False])
717
+ apply_scaling = trial.suggest_categorical('apply_scaling', [True, False])
718
+ dropout = trial.suggest_float('dropout', *dropout_options)
719
 
720
  # Train the model with the current set of hyperparameters
721
  _, _, metrics = train_model(
 
725
  batch_size=batch_size,
726
  join_embeddings=join_embeddings,
727
  learning_rate=learning_rate,
728
+ dropout=dropout,
729
+ max_epochs=100,
730
  smote_k_neighbors=smote_k_neighbors,
731
+ apply_scaling=apply_scaling,
732
+ use_smote=use_smote,
733
  use_logger=False,
734
  fast_dev_run=fast_dev_run,
735
+ active_label=active_label,
736
  disabled_embeddings=disabled_embeddings,
737
  )
738
 
 
746
 
747
 
748
  def hyperparameter_tuning_and_training(
749
+ train_df: pd.DataFrame,
750
+ val_df: pd.DataFrame,
751
+ test_df: pd.DataFrame,
752
+ fast_dev_run: bool = False,
753
+ n_trials: int = 50,
754
+ logger_name: str = 'protac_hparam_search',
755
+ active_label: str = 'Active',
756
+ disabled_embeddings: List[str] = [],
757
  ) -> tuple:
758
  """ Hyperparameter tuning and training of a PROTAC model.
759
 
 
762
  val_df (pd.DataFrame): The validation set.
763
  test_df (pd.DataFrame): The test set.
764
  fast_dev_run (bool): Whether to run a fast development run.
765
+ n_trials (int): The number of hyperparameter optimization trials.
766
+ logger_name (str): The name of the logger.
767
+ active_label (str): The active label column.
768
+ disabled_embeddings (List[str]): The list of disabled embeddings.
769
 
770
  Returns:
771
  tuple: The trained model, the trainer, and the best metrics.
 
774
  hidden_dim_options = [256, 512, 768]
775
  batch_size_options = [8, 16, 32]
776
  learning_rate_options = (1e-5, 1e-3) # min and max values for loguniform distribution
 
777
  smote_k_neighbors_options = list(range(3, 16))
778
 
779
  # Set the verbosity of Optuna
 
786
  trial,
787
  train_df,
788
  val_df,
789
+ hidden_dim_options=hidden_dim_options,
790
+ batch_size_options=batch_size_options,
791
+ learning_rate_options=learning_rate_options,
 
792
  smote_k_neighbors_options=smote_k_neighbors_options,
793
  fast_dev_run=fast_dev_run,
794
+ active_label=active_label,
795
  disabled_embeddings=disabled_embeddings,
796
  ),
797
  n_trials=n_trials,
 
805
  use_logger=True,
806
  logger_name=logger_name,
807
  fast_dev_run=fast_dev_run,
808
+ active_label=active_label,
809
  disabled_embeddings=disabled_embeddings,
810
  **study.best_params,
811
  )
 
818
 
819
 
820
  def main(
821
+ active_col: str = 'Active (Dmax 0.6, pDC50 6.0)',
822
  n_trials: int = 50,
 
823
  fast_dev_run: bool = False,
824
+ test_split: float = 0.2,
825
+ cv_n_splits: int = 5,
826
  ):
827
  """ Train a PROTAC model using the given datasets and hyperparameters.
828
 
 
833
  fast_dev_run (bool): Whether to run a fast development run.
834
  """
835
  ## Set the Column to Predict
836
+ active_name = active_col.replace(' ', '_').strip('(').strip(')').strip(',')
837
 
838
+ # Get Dmax_threshold from the active_col
839
+ Dmax_threshold = float(active_col.split('Dmax')[1].split(',')[0].strip('(').strip(')').strip())
840
+ pDC50_threshold = float(active_col.split('pDC50')[1].strip('(').strip(')').strip())
841
 
842
+ protac_df[active_col] = protac_df.apply(
843
+ lambda x: is_active(x['DC50 (nM)'], x['Dmax (%)'], pDC50_threshold=pDC50_threshold, Dmax_threshold=Dmax_threshold), axis=1
844
+ )
845
 
846
+ ## Test Sets
847
 
848
+ test_indeces = {}
849
+
850
+ ### Random Split
851
+
852
+ # Randomly select 20% of the active PROTACs as the test set
853
+ active_df = protac_df[protac_df[active_col].notna()].copy()
854
+ test_df = active_df.sample(frac=test_split, random_state=42)
855
+ test_indeces['random'] = test_df.index
856
+
857
+ ### E3-based Split
858
+
859
+ encoder = OrdinalEncoder()
860
+ protac_df['E3 Group'] = encoder.fit_transform(protac_df[['E3 Ligase']]).astype(int)
861
+ active_df = protac_df[protac_df[active_col].notna()].copy()
862
+ test_df = active_df[(active_df['E3 Ligase'] != 'VHL') & (active_df['E3 Ligase'] != 'CRBN')]
863
+ test_indeces['e3_ligase'] = test_df.index
864
+
865
+ ### Tanimoto-based Split
866
+
867
+ n_bins_tanimoto = 200
868
+ tanimoto_groups = pd.cut(protac_df['Avg Tanimoto'], bins=n_bins_tanimoto).copy()
869
+ encoder = OrdinalEncoder()
870
+ protac_df['Tanimoto Group'] = encoder.fit_transform(tanimoto_groups.values.reshape(-1, 1)).astype(int)
871
+ active_df = protac_df[protac_df[active_col].notna()].copy()
872
+
873
+ test_df = []
874
+ # For each group, get the number of active and inactive entries. Then, add those
875
+ # entries to the test_df if: 1) the test_df length + the group entries is less than
876
+ # 20% of the active_df length, and 2) the percentage of True and False entries
877
+ # in the active_col in test_df is roughly 50%.
878
+ # Start the loop from the groups containing the smallest number of entries.
879
+ for group in reversed(active_df['Tanimoto Group'].value_counts().index):
880
+ group_df = active_df[active_df['Tanimoto Group'] == group]
881
+ if test_df == []:
882
+ test_df.append(group_df)
883
+ continue
884
+
885
+ num_entries = len(group_df)
886
+ num_active_group = group_df[active_col].sum()
887
+ num_inactive_group = num_entries - num_active_group
888
+
889
+ tmp_test_df = pd.concat(test_df)
890
+ num_entries_test = len(tmp_test_df)
891
+ num_active_test = tmp_test_df[active_col].sum()
892
+ num_inactive_test = num_entries_test - num_active_test
893
+
894
+ # Check if the group entries can be added to the test_df
895
+ if num_entries_test + num_entries < test_split * len(active_df):
896
+ # Add anything at the beginning
897
+ if num_entries_test + num_entries < test_split / 2 * len(active_df):
898
+ test_df.append(group_df)
899
+ continue
900
+ # Be more selective and make sure that the percentage of active and
901
+ # inactive is balanced
902
+ if (num_active_group + num_active_test) / (num_entries_test + num_entries) < 0.6:
903
+ if (num_inactive_group + num_inactive_test) / (num_entries_test + num_entries) < 0.6:
904
+ test_df.append(group_df)
905
+ test_df = pd.concat(test_df)
906
+ # Save to global dictionary of test indices
907
+ test_indeces['tanimoto'] = test_df.index
908
+
909
+ ### Target-based Split
910
+
911
+ encoder = OrdinalEncoder()
912
+ protac_df['Uniprot Group'] = encoder.fit_transform(protac_df[['Uniprot']]).astype(int)
913
+ active_df = protac_df[protac_df[active_col].notna()].copy()
914
+
915
+ test_df = []
916
+ # For each group, get the number of active and inactive entries. Then, add those
917
+ # entries to the test_df if: 1) the test_df length + the group entries is less than
918
+ # 20% of the active_df length, and 2) the percentage of True and False entries
919
+ # in the active_col in test_df is roughly 50%.
920
+ # Start the loop from the groups containing the smallest number of entries.
921
+ for group in reversed(active_df['Uniprot'].value_counts().index):
922
+ group_df = active_df[active_df['Uniprot'] == group]
923
+ if test_df == []:
924
+ test_df.append(group_df)
925
+ continue
926
+
927
+ num_entries = len(group_df)
928
+ num_active_group = group_df[active_col].sum()
929
+ num_inactive_group = num_entries - num_active_group
930
+
931
+ tmp_test_df = pd.concat(test_df)
932
+ num_entries_test = len(tmp_test_df)
933
+ num_active_test = tmp_test_df[active_col].sum()
934
+ num_inactive_test = num_entries_test - num_active_test
935
+
936
+ # Check if the group entries can be added to the test_df
937
+ if num_entries_test + num_entries < test_split * len(active_df):
938
+ # Add anything at the beginning
939
+ if num_entries_test + num_entries < test_split / 2 * len(active_df):
940
+ test_df.append(group_df)
941
+ continue
942
+ # Be more selective and make sure that the percentage of active and
943
+ # inactive is balanced
944
+ if (num_active_group + num_active_test) / (num_entries_test + num_entries) < 0.6:
945
+ if (num_inactive_group + num_inactive_test) / (num_entries_test + num_entries) < 0.6:
946
+ test_df.append(group_df)
947
+ test_df = pd.concat(test_df)
948
+ # Save to global dictionary of test indices
949
+ test_indeces['uniprot'] = test_df.index
950
 
951
+ ## Cross-Validation Training
952
+
953
  # Make directory ../reports if it does not exist
954
  if not os.path.exists('../reports'):
955
  os.makedirs('../reports')
956
 
957
+ for split_type, indeces in test_indeces.items():
958
+ active_df = protac_df[protac_df[active_col].notna()].copy()
959
+ test_df = active_df.loc[indeces]
960
+ train_val_df = active_df[~active_df.index.isin(test_df.index)]
961
+
962
+ if split_type == 'random':
963
+ kf = StratifiedKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
964
+ group = None
965
+ elif split_type == 'e3_ligase':
966
+ kf = StratifiedKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
967
+ group = train_val_df['E3 Group'].to_numpy()
968
+ elif split_type == 'tanimoto':
969
+ kf = StratifiedGroupKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
970
+ group = train_val_df['Tanimoto Group'].to_numpy()
971
+ elif split_type == 'uniprot':
972
+ kf = StratifiedGroupKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
973
+ group = train_val_df['Uniprot Group'].to_numpy()
974
  # Start the CV over the folds
975
  X = train_val_df.drop(columns=active_col)
976
  y = train_val_df[active_col].tolist()
977
+ report = []
978
+ for k, (train_index, val_index) in enumerate(kf.split(X, y, group)):
979
  print('-' * 100)
980
+ print(f'Starting CV for group type: {split_type}, fold: {k}')
981
  print('-' * 100)
982
  train_df = train_val_df.iloc[train_index]
983
  val_df = train_val_df.iloc[val_index]
984
+
985
+ leaking_uniprot = list(set(train_df['Uniprot']).intersection(set(val_df['Uniprot'])))
986
+ leaking_smiles = list(set(train_df['Smiles']).intersection(set(val_df['Smiles'])))
987
+
988
  stats = {
989
  'fold': k,
 
990
  'train_len': len(train_df),
991
  'val_len': len(val_df),
992
  'train_perc': len(train_df) / len(train_val_df),
993
  'val_perc': len(val_df) / len(train_val_df),
994
+ 'train_active (%)': train_df[active_col].sum() / len(train_df) * 100,
995
+ 'train_inactive (%)': (len(train_df) - train_df[active_col].sum()) / len(train_df) * 100,
996
+ 'val_active (%)': val_df[active_col].sum() / len(val_df) * 100,
997
+ 'val_inactive (%)': (len(val_df) - val_df[active_col].sum()) / len(val_df) * 100,
998
+ 'num_leaking_uniprot': len(leaking_uniprot),
999
+ 'num_leaking_smiles': len(leaking_smiles),
1000
+ 'train_leaking_uniprot (%)': len(train_df[train_df['Uniprot'].isin(leaking_uniprot)]) / len(train_df) * 100,
1001
+ 'train_leaking_smiles (%)': len(train_df[train_df['Smiles'].isin(leaking_smiles)]) / len(train_df) * 100,
 
1002
  }
1003
+ if split_type != 'random':
1004
+ stats['train_unique_groups'] = len(np.unique(group[train_index]))
1005
+ stats['val_unique_groups'] = len(np.unique(group[val_index]))
1006
+ report.append(stats)
1007
+
1008
  # Train and evaluate the model
1009
  model, trainer, metrics = hyperparameter_tuning_and_training(
1010
  train_df,
 
1012
  test_df,
1013
  fast_dev_run=fast_dev_run,
1014
  n_trials=n_trials,
1015
+ logger_name=f'protac_{active_name}_{split_type}_fold_{k}',
1016
+ active_label=active_col,
1017
  )
1018
  hparams = {p.strip('hparam_'): v for p, v in stats.items() if p.startswith('hparam_')}
1019
  stats.update(metrics)
 
1032
  val_df,
1033
  test_df,
1034
  fast_dev_run=fast_dev_run,
1035
+ logger_name=f'protac_{active_name}_{split_type}_fold_{k}_disabled-{"-".join(disabled_embeddings)}',
1036
+ active_label=active_col,
1037
  disabled_embeddings=disabled_embeddings,
1038
  **hparams,
1039
  )
 
1042
  del model
1043
  del trainer
1044
 
1045
+ report = pd.DataFrame(report)
1046
+ report.to_csv(
1047
+ f'../reports/cv_report_hparam_search_{cv_n_splits}-splits_{active_name}_test_split_{test_split}.csv',
1048
+ index=False,
1049
+ )
1050
 
1051
 
1052
  if __name__ == '__main__':
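
For reference, a minimal sketch (not part of the commit) of how the new is_active labelling and the threshold parsing in main() behave for the default active_col value shown in the diff; the measurement values below are hypothetical.

import numpy as np

# Default column name from main() in this commit
active_col = 'Active (Dmax 0.6, pDC50 6.0)'

# Thresholds parsed out of the column name, mirroring the string handling in main()
Dmax_threshold = float(active_col.split('Dmax')[1].split(',')[0].strip('(').strip(')').strip())  # 0.6
pDC50_threshold = float(active_col.split('pDC50')[1].strip('(').strip(')').strip())              # 6.0

# is_active() converts DC50 (nM) to pDC50 and Dmax (%) to a fraction before thresholding
DC50, Dmax = 10.0, 95.0                # hypothetical measurements
pDC50 = -np.log10(DC50 * 1e-9)         # 8.0
print(pDC50 >= pDC50_threshold and Dmax / 100 >= Dmax_threshold)  # True -> labelled active
# DC50 = 5000 nM gives pDC50 of about 5.3 < 6.0, so that compound would be labelled inactive;
# with oring=False, a missing DC50 or Dmax that does not already fail a threshold yields NaN (unlabelled).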