ribesstefano commited on
Commit
d6ec1f3
·
1 Parent(s): d36ec1d

Added chirality to Morgan FP and improved logging

Browse files
notebooks/protac_degradation_predictor.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
notebooks/protac_degradation_predictor.py CHANGED
@@ -2,6 +2,8 @@ import optuna
2
  import pandas as pd
3
  from rdkit import Chem
4
  from rdkit.Chem import AllChem
 
 
5
 
6
  import h5py
7
  import numpy as np
@@ -113,19 +115,16 @@ for cell_line in protac_df['Cell Line Identifier'].unique():
113
 
114
  # ## Precompute Molecular Fingerprints
115
 
116
- morgan_radius = 15
117
- n_bits = 1024
118
-
119
- # fpgen = AllChem.GetAtomPairGenerator()
120
- # rdkit_fpgen = AllChem.GetRDKitFPGenerator(maxPath=5, fpSize=512)
121
- morgan_fpgen = AllChem.GetMorganGenerator(radius=morgan_radius, fpSize=n_bits)
122
 
123
  smiles2fp = {}
124
  for smiles in tqdm(protac_df['Smiles'].unique().tolist(), desc='Precomputing fingerprints'):
125
  # Get the fingerprint as a bit vector
126
  morgan_fp = morgan_fpgen.GetFingerprint(Chem.MolFromSmiles(smiles))
127
- # rdkit_fp = rdkit_fpgen.GetFingerprint(Chem.MolFromSmiles(smiles))
128
- # fp = np.concatenate([morgan_fp, rdkit_fp])
129
  smiles2fp[smiles] = morgan_fp
130
 
131
  # Count the number of unique SMILES and the number of unique Morgan fingerprints
@@ -143,8 +142,6 @@ print(f'Number of SMILES with overlapping fingerprints: {len(overlapping_smiles)
143
  print(f'Number of overlapping SMILES in protac_df: {len(protac_df[protac_df["Smiles"].isin(overlapping_smiles)])}')
144
 
145
  # Get the pair-wise tanimoto similarity between the PROTAC fingerprints
146
- from rdkit import DataStructs
147
- from collections import defaultdict
148
 
149
  tanimoto_matrix = defaultdict(list)
150
  for i, smiles1 in enumerate(tqdm(protac_df['Smiles'].unique(), desc='Computing Tanimoto similarity')):
@@ -656,6 +653,8 @@ def train_model(
656
  smote_k_neighbors=5,
657
  use_ored_activity=False if active_col == 'Active' else True,
658
  fast_dev_run=False,
 
 
659
  disabled_embeddings=[],
660
  ) -> tuple:
661
  """ Train a PROTAC model using the given datasets and hyperparameters.
@@ -704,7 +703,7 @@ def train_model(
704
  )
705
  logger = pl.loggers.TensorBoardLogger(
706
  save_dir='../logs',
707
- name='protac',
708
  )
709
  callbacks = [
710
  pl.callbacks.EarlyStopping(
@@ -722,12 +721,13 @@ def train_model(
722
  ]
723
  # Define Trainer
724
  trainer = pl.Trainer(
725
- logger=logger,
726
  callbacks=callbacks,
727
  max_epochs=max_epochs,
728
  fast_dev_run=fast_dev_run,
729
  enable_model_summary=False,
730
  enable_checkpointing=False,
 
731
  )
732
  model = PROTAC_Model(
733
  hidden_dim=hidden_dim,
@@ -780,6 +780,7 @@ def objective(
780
  learning_rate=learning_rate,
781
  max_epochs=max_epochs,
782
  smote_k_neighbors=smote_k_neighbors,
 
783
  fast_dev_run=fast_dev_run,
784
  )
785
 
@@ -798,6 +799,7 @@ def hyperparameter_tuning_and_training(
798
  test_df,
799
  fast_dev_run=False,
800
  n_trials=20,
 
801
  ) -> tuple:
802
  """ Hyperparameter tuning and training of a PROTAC model.
803
 
@@ -849,6 +851,8 @@ def hyperparameter_tuning_and_training(
849
  batch_size=best_batch_size,
850
  learning_rate=best_learning_rate,
851
  max_epochs=best_max_epochs,
 
 
852
  fast_dev_run=fast_dev_run,
853
  )
854
 
@@ -927,6 +931,7 @@ for group_type in ['random', 'uniprot', 'tanimoto']:
927
  test_df,
928
  fast_dev_run=False,
929
  n_trials=50,
 
930
  )
931
  stats.update(metrics)
932
  del model
 
2
  import pandas as pd
3
  from rdkit import Chem
4
  from rdkit.Chem import AllChem
5
+ from rdkit import DataStructs
6
+ from collections import defaultdict
7
 
8
  import h5py
9
  import numpy as np
 
115
 
116
  # ## Precompute Molecular Fingerprints
117
 
118
+ morgan_fpgen = AllChem.GetMorganGenerator(
119
+ radius=15,
120
+ fpSize=1024,
121
+ includeChirality=True,
122
+ )
 
123
 
124
  smiles2fp = {}
125
  for smiles in tqdm(protac_df['Smiles'].unique().tolist(), desc='Precomputing fingerprints'):
126
  # Get the fingerprint as a bit vector
127
  morgan_fp = morgan_fpgen.GetFingerprint(Chem.MolFromSmiles(smiles))
 
 
128
  smiles2fp[smiles] = morgan_fp
129
 
130
  # Count the number of unique SMILES and the number of unique Morgan fingerprints
 
142
  print(f'Number of overlapping SMILES in protac_df: {len(protac_df[protac_df["Smiles"].isin(overlapping_smiles)])}')
143
 
144
  # Get the pair-wise tanimoto similarity between the PROTAC fingerprints
 
 
145
 
146
  tanimoto_matrix = defaultdict(list)
147
  for i, smiles1 in enumerate(tqdm(protac_df['Smiles'].unique(), desc='Computing Tanimoto similarity')):
 
653
  smote_k_neighbors=5,
654
  use_ored_activity=False if active_col == 'Active' else True,
655
  fast_dev_run=False,
656
+ use_logger=True,
657
+ logger_name='protac',
658
  disabled_embeddings=[],
659
  ) -> tuple:
660
  """ Train a PROTAC model using the given datasets and hyperparameters.
 
703
  )
704
  logger = pl.loggers.TensorBoardLogger(
705
  save_dir='../logs',
706
+ name=logger_name,
707
  )
708
  callbacks = [
709
  pl.callbacks.EarlyStopping(
 
721
  ]
722
  # Define Trainer
723
  trainer = pl.Trainer(
724
+ logger=logger if use_logger else False,
725
  callbacks=callbacks,
726
  max_epochs=max_epochs,
727
  fast_dev_run=fast_dev_run,
728
  enable_model_summary=False,
729
  enable_checkpointing=False,
730
+ enable_progress_bar=False,
731
  )
732
  model = PROTAC_Model(
733
  hidden_dim=hidden_dim,
 
780
  learning_rate=learning_rate,
781
  max_epochs=max_epochs,
782
  smote_k_neighbors=smote_k_neighbors,
783
+ use_logger=False,
784
  fast_dev_run=fast_dev_run,
785
  )
786
 
 
799
  test_df,
800
  fast_dev_run=False,
801
  n_trials=20,
802
+ logger_name='protac_hparam_search',
803
  ) -> tuple:
804
  """ Hyperparameter tuning and training of a PROTAC model.
805
 
 
851
  batch_size=best_batch_size,
852
  learning_rate=best_learning_rate,
853
  max_epochs=best_max_epochs,
854
+ use_logger=True,
855
+ logger_name=logger_name,
856
  fast_dev_run=fast_dev_run,
857
  )
858
 
 
931
  test_df,
932
  fast_dev_run=False,
933
  n_trials=50,
934
+ logger_name=f'protac_{group_type}_fold_{k}',
935
  )
936
  stats.update(metrics)
937
  del model