ribesstefano committed
Commit 4d17fea · 1 Parent(s): 165d38a

Updated protac dataset to handle missing values in embedding dictionaries

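The core of the change is a dictionary-lookup fallback: entries missing from an embedding dictionary now map to a zero vector instead of raising a KeyError. A minimal standalone sketch of the pattern (toy dictionary and dimensions, not repo code):

```python
import numpy as np

# Toy fingerprint dictionary; the 2048-dim size is illustrative only.
smiles2fp = {'CCO': np.ones(2048)}
default_smiles_emb = np.zeros(2048)

# dict.get() returns the zero vector for unseen keys instead of
# raising a KeyError, so rows with unknown SMILES survive the lookup.
fp_known = smiles2fp.get('CCO', default_smiles_emb)
fp_missing = smiles2fp.get('c1ccccc1', default_smiles_emb)
assert fp_known.sum() == 2048.0
assert fp_missing.sum() == 0.0
```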
protac_degradation_predictor/optuna_utils.py CHANGED

```diff
@@ -1,5 +1,6 @@
 import os
 from typing import Literal, List, Tuple, Optional, Dict
+import logging
 
 from .pytorch_models import train_model
 from .sklearn_models import (
@@ -141,7 +142,7 @@ def hyperparameter_tuning_and_training(
     if os.path.exists(study_filename):
         study = joblib.load(study_filename)
         study_loaded = True
-        print(f'Loaded study from {study_filename}')
+        logging.info(f'Loaded study from {study_filename}')
 
     if not study_loaded:
         study.optimize(
@@ -253,9 +254,23 @@ def hyperparameter_tuning_and_training_sklearn(
     model_type: Literal['RandomForest', 'SVC', 'LogisticRegression', 'GradientBoosting'] = 'RandomForest',
     active_label: str = 'Active',
     n_trials: int = 50,
-    logger_name: str = 'protac_hparam_search',
+    logger_name: str = 'protac_hparam_search_sklearn',
     study_filename: Optional[str] = None,
 ) -> Tuple:
+    """ Hyperparameter tuning and training of a PROTAC model.
+
+    Args:
+        train_df (pd.DataFrame): The training set.
+        val_df (pd.DataFrame): The validation set.
+        test_df (pd.DataFrame): The test set.
+        model_type (str): The model type.
+        n_trials (int): The number of hyperparameter optimization trials.
+        logger_name (str): The name of the logger. Unused, for compatibility with hyperparameter_tuning_and_training.
+        active_label (str): The active label column.
+
+    Returns:
+        tuple: The trained model and the best metrics.
+    """
     # Set the verbosity of Optuna
     optuna.logging.set_verbosity(optuna.logging.WARNING)
     # Create an Optuna study object
@@ -267,7 +282,7 @@ def hyperparameter_tuning_and_training_sklearn(
     if os.path.exists(study_filename):
         study = joblib.load(study_filename)
         study_loaded = True
-        print(f'Loaded study from {study_filename}')
+        logging.info(f'Loaded study from {study_filename}')
 
     if not study_loaded:
         study.optimize(
```
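Note that the print-to-logging.info swap makes these messages silent by default: the root logger only emits WARNING and above unless the caller configures it. A minimal sketch of such a setup (not part of this commit):

```python
import logging

# Without this, the logging.info() calls added above produce no
# console output: the root logger defaults to WARNING level.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s %(message)s',
)
```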
protac_degradation_predictor/protac_dataset.py CHANGED

```diff
@@ -41,13 +41,17 @@ class PROTAC_Dataset(Dataset):
             protein2embedding.keys())[0]].shape[0]
         self.cell_emb_dim = cell2embedding[list(
             cell2embedding.keys())[0]].shape[0]
+
+        self.default_smiles_emb = np.zeros(self.smiles_emb_dim)
+        self.default_protein_emb = np.zeros(self.protein_emb_dim)
+        self.default_cell_emb = np.zeros(self.cell_emb_dim)
 
         # Look up the embeddings
         self.data = pd.DataFrame({
-            'Smiles': self.data['Smiles'].apply(lambda x: smiles2fp[x].astype(np.float32)).tolist(),
-            'Uniprot': self.data['Uniprot'].apply(lambda x: protein2embedding[x].astype(np.float32)).tolist(),
-            'E3 Ligase Uniprot': self.data['E3 Ligase Uniprot'].apply(lambda x: protein2embedding[x].astype(np.float32)).tolist(),
-            'Cell Line Identifier': self.data['Cell Line Identifier'].apply(lambda x: cell2embedding[x].astype(np.float32)).tolist(),
+            'Smiles': self.data['Smiles'].apply(lambda x: smiles2fp.get(x, self.default_smiles_emb).astype(np.float32)).tolist(),
+            'Uniprot': self.data['Uniprot'].apply(lambda x: protein2embedding.get(x, self.default_protein_emb).astype(np.float32)).tolist(),
+            'E3 Ligase Uniprot': self.data['E3 Ligase Uniprot'].apply(lambda x: protein2embedding.get(x, self.default_protein_emb).astype(np.float32)).tolist(),
+            'Cell Line Identifier': self.data['Cell Line Identifier'].apply(lambda x: cell2embedding.get(x, self.default_cell_emb).astype(np.float32)).tolist(),
             self.active_label: self.data[self.active_label].astype(np.float32).tolist(),
         })
 
```
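Seen from the DataFrame side, the new lookup behaves like this; a hedged sketch assuming a toy cell2embedding table and the column layout used by PROTAC_Dataset:

```python
import numpy as np
import pandas as pd

# Toy embedding table; the 768-dim size is illustrative only.
cell2embedding = {'HEK293': np.random.rand(768)}
default_cell_emb = np.zeros(768)

df = pd.DataFrame({'Cell Line Identifier': ['HEK293', 'UNSEEN-LINE']})

# Mirrors the lookup in PROTAC_Dataset: unseen identifiers map to the
# zero vector. Note that .astype(np.float32) copies, so the shared
# default array cannot be mutated through the resulting rows.
embeddings = df['Cell Line Identifier'].apply(
    lambda x: cell2embedding.get(x, default_cell_emb).astype(np.float32)
).tolist()
assert embeddings[1].sum() == 0.0
```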
src/run_experiments.py CHANGED

```diff
@@ -2,7 +2,7 @@ import os
 import sys
 from collections import defaultdict
 import warnings
-
+import logging
 
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
 
@@ -201,6 +201,7 @@ def main(
     fast_dev_run: bool = False,
     test_split: float = 0.2,
     cv_n_splits: int = 5,
+    run_sklearn: bool = False,
 ):
     """ Train a PROTAC model using the given datasets and hyperparameters.
 
@@ -245,9 +246,8 @@ def main(
     # Cross-Validation Training
     report = []
     for split_type, indeces in test_indeces.items():
-        # active_df = protac_df[protac_df[active_col].notna()].copy()
-        test_df = active_df.loc[indeces]
-        train_val_df = active_df[~active_df.index.isin(test_df.index)]
+        test_df = active_df.loc[indeces].copy()
+        train_val_df = active_df[~active_df.index.isin(test_df.index)].copy()
 
         # Get the CV object
         if split_type == 'random':
@@ -297,57 +297,86 @@ def main(
             if split_type != 'random':
                 stats['train_unique_groups'] = len(np.unique(group[train_index]))
                 stats['val_unique_groups'] = len(np.unique(group[val_index]))
-
-            print(stats)
-            # # Train and evaluate the model
-            # model, trainer, metrics = pdp.hyperparameter_tuning_and_training(
-            #     protein2embedding,
-            #     cell2embedding,
-            #     smiles2fp,
-            #     train_df,
-            #     val_df,
-            #     test_df,
-            #     fast_dev_run=fast_dev_run,
-            #     n_trials=n_trials,
-            #     logger_name=f'protac_{active_name}_{split_type}_fold_{k}_test_split_{test_split}',
-            #     active_label=active_col,
-            #     study_filename=f'../reports/study_{active_name}_{split_type}_fold_{k}_test_split_{test_split}.pkl',
-            # )
-            # hparams = {p.replace('hparam_', ''): v for p, v in stats.items() if p.startswith('hparam_')}
-            # stats.update(metrics)
-            # report.append(stats.copy())
-            # del model
-            # del trainer
-
-            # # Ablation study: disable embeddings at a time
-            # for disabled_embeddings in [['e3'], ['poi'], ['cell'], ['smiles'], ['e3', 'cell'], ['poi', 'e3', 'cell']]:
-            #     print('-' * 100)
-            #     print(f'Ablation study with disabled embeddings: {disabled_embeddings}')
-            #     print('-' * 100)
-            #     stats['disabled_embeddings'] = 'disabled ' + ' '.join(disabled_embeddings)
-            #     model, trainer, metrics = pdp.train_model(
-            #         protein2embedding,
-            #         cell2embedding,
-            #         smiles2fp,
-            #         train_df,
-            #         val_df,
-            #         test_df,
-            #         fast_dev_run=fast_dev_run,
-            #         logger_name=f'protac_{active_name}_{split_type}_fold_{k}_disabled-{"-".join(disabled_embeddings)}',
-            #         active_label=active_col,
-            #         disabled_embeddings=disabled_embeddings,
-            #         **hparams,
-            #     )
-            #     stats.update(metrics)
-            #     report.append(stats.copy())
-            #     del model
-            #     del trainer
-
-    # report_df = pd.DataFrame(report)
-    # report_df.to_csv(
-    #     f'../reports/cv_report_hparam_search_{cv_n_splits}-splits_{active_name}_test_split_{test_split}_sklearn.csv',
-    #     index=False,
-    # )
+
+            # At each fold, train and evaluate the Pytorch model
+            if split_type != 'tanimoto' or run_sklearn:
+                logging.info(f'Skipping Pytorch model training on fold {k} with split type {split_type} and test split {test_split}.')
+                continue
+            else:
+                logging.info(f'Starting Pytorch model training on fold {k} with split type {split_type} and test split {test_split}.')
+                # Train and evaluate the model
+                model, trainer, metrics = pdp.hyperparameter_tuning_and_training(
+                    protein2embedding,
+                    cell2embedding,
+                    smiles2fp,
+                    train_df,
+                    val_df,
+                    test_df,
+                    fast_dev_run=fast_dev_run,
+                    n_trials=n_trials,
+                    logger_name=f'protac_{active_name}_{split_type}_fold_{k}_test_split_{test_split}',
+                    active_label=active_col,
+                    study_filename=f'../reports/study_{active_name}_{split_type}_fold_{k}_test_split_{test_split}.pkl',
+                )
+                hparams = {p.replace('hparam_', ''): v for p, v in stats.items() if p.startswith('hparam_')}
+                stats.update(metrics)
+                stats['model_type'] = 'Pytorch'
+                report.append(stats.copy())
+                del model
+                del trainer
+
+                # Ablation study: disable embeddings at a time
+                for disabled_embeddings in [['e3'], ['poi'], ['cell'], ['smiles'], ['e3', 'cell'], ['poi', 'e3', 'cell']]:
+                    print('-' * 100)
+                    print(f'Ablation study with disabled embeddings: {disabled_embeddings}')
+                    print('-' * 100)
+                    stats['disabled_embeddings'] = 'disabled ' + ' '.join(disabled_embeddings)
+                    model, trainer, metrics = pdp.train_model(
+                        protein2embedding,
+                        cell2embedding,
+                        smiles2fp,
+                        train_df,
+                        val_df,
+                        test_df,
+                        fast_dev_run=fast_dev_run,
+                        logger_name=f'protac_{active_name}_{split_type}_fold_{k}_disabled-{"-".join(disabled_embeddings)}',
+                        active_label=active_col,
+                        disabled_embeddings=disabled_embeddings,
+                        **hparams,
+                    )
+                    stats.update(metrics)
+                    report.append(stats.copy())
+                    del model
+                    del trainer
+
+            # At each fold, train and evaluate sklearn models
+            if run_sklearn:
+                for model_type in ['RandomForest', 'SVC', 'LogisticRegression', 'GradientBoosting']:
+                    logging.info(f'Starting sklearn model {model_type} training on fold {k} with split type {split_type} and test split {test_split}.')
+                    # Train and evaluate sklearn models
+                    model, metrics = pdp.hyperparameter_tuning_and_training_sklearn(
+                        protein2embedding=protein2embedding,
+                        cell2embedding=cell2embedding,
+                        smiles2fp=smiles2fp,
+                        train_df=train_df,
+                        val_df=val_df,
+                        test_df=test_df,
+                        model_type=model_type,
+                        active_label=active_col,
+                        n_trials=n_trials,
+                        study_filename=f'../reports/study_{active_name}_{split_type}_fold_{k}_test_split_{test_split}_{model_type.lower()}.pkl',
+                    )
+                    hparams = {p.replace('hparam_', ''): v for p, v in stats.items() if p.startswith('hparam_')}
+                    stats['model_type'] = model_type
+                    stats.update(metrics)
+                    report.append(stats.copy())
+
+        # Save the report at the end of each split type
+        report_df = pd.DataFrame(report)
+        report_df.to_csv(
+            f'../reports/cv_report_hparam_search_{cv_n_splits}-splits_{active_name}_test_split_{test_split}{"_sklearn" if run_sklearn else ""}.csv',
+            index=False,
+        )
 
 
 if __name__ == '__main__':
```
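The gate at the top of the fold loop is easy to misread; enumerating the condition in isolation shows when the Pytorch branch actually runs (an illustration of the committed logic, not repo code; split-type names other than 'random' and 'tanimoto' are placeholders):

```python
# The fold loop now begins with:
#     if split_type != 'tanimoto' or run_sklearn:
#         continue
# so the Pytorch branch runs only for tanimoto splits when
# run_sklearn is False; every other combination skips the fold.
for split_type in ['random', 'tanimoto', 'other']:
    for run_sklearn in (False, True):
        skipped = split_type != 'tanimoto' or run_sklearn
        action = 'skip Pytorch training' if skipped else 'train Pytorch model'
        print(f'{split_type:>8}  run_sklearn={run_sklearn!s:<5}  -> {action}')
```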