Commit 4d17fea · Parent: 165d38a

Updated protac dataset to handle missing values in embedding dictionaries
protac_degradation_predictor/optuna_utils.py (CHANGED)
@@ -1,5 +1,6 @@
 import os
 from typing import Literal, List, Tuple, Optional, Dict
+import logging
 
 from .pytorch_models import train_model
 from .sklearn_models import (

@@ -141,7 +142,7 @@ def hyperparameter_tuning_and_training(
     if os.path.exists(study_filename):
         study = joblib.load(study_filename)
         study_loaded = True
-
+        logging.info(f'Loaded study from {study_filename}')
 
     if not study_loaded:
         study.optimize(

@@ -253,9 +254,23 @@ def hyperparameter_tuning_and_training_sklearn(
     model_type: Literal['RandomForest', 'SVC', 'LogisticRegression', 'GradientBoosting'] = 'RandomForest',
     active_label: str = 'Active',
     n_trials: int = 50,
-    logger_name: str = '
+    logger_name: str = 'protac_hparam_search_sklearn',
     study_filename: Optional[str] = None,
 ) -> Tuple:
+    """ Hyperparameter tuning and training of a PROTAC model.
+
+    Args:
+        train_df (pd.DataFrame): The training set.
+        val_df (pd.DataFrame): The validation set.
+        test_df (pd.DataFrame): The test set.
+        model_type (str): The model type.
+        n_trials (int): The number of hyperparameter optimization trials.
+        logger_name (str): The name of the logger. Unused, for compatibility with hyperparameter_tuning_and_training.
+        active_label (str): The active label column.
+
+    Returns:
+        tuple: The trained model and the best metrics.
+    """
     # Set the verbosity of Optuna
     optuna.logging.set_verbosity(optuna.logging.WARNING)
     # Create an Optuna study object

@@ -267,7 +282,7 @@ def hyperparameter_tuning_and_training_sklearn(
     if os.path.exists(study_filename):
         study = joblib.load(study_filename)
         study_loaded = True
-
+        logging.info(f'Loaded study from {study_filename}')
 
     if not study_loaded:
         study.optimize(
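The two identical hunks above instrument an existing load-or-resume pattern: if a pickled Optuna study is found on disk, it is reused (and the fact is now logged); otherwise the search runs from scratch. Below is a minimal self-contained sketch of that pattern. The helper `load_or_run_study` and the toy objective are illustrative, not part of the repository, and the dump-on-finish step is an assumption added to make the sketch round-trip.

```python
import os
import logging
import joblib
import optuna

def load_or_run_study(objective, study_filename=None, n_trials=50):
    """Reuse a pickled Optuna study when available, else run the search."""
    if study_filename is not None and os.path.exists(study_filename):
        study = joblib.load(study_filename)
        logging.info(f'Loaded study from {study_filename}')
        return study
    study = optuna.create_study(direction='maximize')
    study.optimize(objective, n_trials=n_trials)
    if study_filename is not None:
        joblib.dump(study, study_filename)  # cache the finished study for later runs
    return study

# Toy objective, maximized at x = 2.
study = load_or_run_study(
    lambda trial: -(trial.suggest_float('x', -10, 10) - 2) ** 2,
    study_filename='study.pkl',
    n_trials=20,
)
print(study.best_params)
```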
protac_degradation_predictor/protac_dataset.py (CHANGED)
@@ -41,13 +41,17 @@ class PROTAC_Dataset(Dataset):
             protein2embedding.keys())[0]].shape[0]
         self.cell_emb_dim = cell2embedding[list(
             cell2embedding.keys())[0]].shape[0]
+
+        self.default_smiles_emb = np.zeros(self.smiles_emb_dim)
+        self.default_protein_emb = np.zeros(self.protein_emb_dim)
+        self.default_cell_emb = np.zeros(self.cell_emb_dim)
 
         # Look up the embeddings
         self.data = pd.DataFrame({
-            'Smiles': self.data['Smiles'].apply(lambda x: smiles2fp[x].astype(np.float32)).tolist(),
-            'Uniprot': self.data['Uniprot'].apply(lambda x: protein2embedding[x].astype(np.float32)).tolist(),
-            'E3 Ligase Uniprot': self.data['E3 Ligase Uniprot'].apply(lambda x: protein2embedding[x].astype(np.float32)).tolist(),
-            'Cell Line Identifier': self.data['Cell Line Identifier'].apply(lambda x: cell2embedding[x].astype(np.float32)).tolist(),
+            'Smiles': self.data['Smiles'].apply(lambda x: smiles2fp.get(x, self.default_smiles_emb).astype(np.float32)).tolist(),
+            'Uniprot': self.data['Uniprot'].apply(lambda x: protein2embedding.get(x, self.default_protein_emb).astype(np.float32)).tolist(),
+            'E3 Ligase Uniprot': self.data['E3 Ligase Uniprot'].apply(lambda x: protein2embedding.get(x, self.default_protein_emb).astype(np.float32)).tolist(),
+            'Cell Line Identifier': self.data['Cell Line Identifier'].apply(lambda x: cell2embedding.get(x, self.default_cell_emb).astype(np.float32)).tolist(),
             self.active_label: self.data[self.active_label].astype(np.float32).tolist(),
         })
 
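This hunk is the core of the commit: the removed lines suggest the lookups previously indexed the dictionaries directly, so any SMILES string, Uniprot ID, or cell line missing from the precomputed tables raised a KeyError during dataset construction; they now fall back to a zero vector of the matching dimensionality. A small sketch of the new behavior with a toy embedding table (the values and the 4-dimensional size are made up):

```python
import numpy as np
import pandas as pd

# Toy cell-line embedding table; the real tables are precomputed elsewhere.
cell2embedding = {'HEK293': np.array([0.1, 0.2, 0.3, 0.4])}
cell_emb_dim = cell2embedding[list(cell2embedding.keys())[0]].shape[0]
default_cell_emb = np.zeros(cell_emb_dim)

df = pd.DataFrame({'Cell Line Identifier': ['HEK293', 'UNSEEN-LINE']})

# cell2embedding['UNSEEN-LINE'] would raise KeyError; .get() substitutes zeros.
embs = df['Cell Line Identifier'].apply(
    lambda x: cell2embedding.get(x, default_cell_emb).astype(np.float32)
).tolist()
print(embs[1])  # [0. 0. 0. 0.]
```

One consequence of this design choice: the zero vector is a silent sentinel, so a row with a missing entry is indistinguishable from one whose true embedding happens to be all zeros.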
src/run_experiments.py (CHANGED)
@@ -2,7 +2,7 @@ import os
 import sys
 from collections import defaultdict
 import warnings
-
+import logging
 
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
 

@@ -201,6 +201,7 @@ def main(
     fast_dev_run: bool = False,
     test_split: float = 0.2,
     cv_n_splits: int = 5,
+    run_sklearn: bool = False,
 ):
     """ Train a PROTAC model using the given datasets and hyperparameters.
 

@@ -245,9 +246,8 @@ def main(
     # Cross-Validation Training
     report = []
     for split_type, indeces in test_indeces.items():
-
-
-        train_val_df = active_df[~active_df.index.isin(test_df.index)]
+        test_df = active_df.loc[indeces].copy()
+        train_val_df = active_df[~active_df.index.isin(test_df.index)].copy()
 
         # Get the CV object
         if split_type == 'random':

@@ -297,57 +297,86 @@ def main(
             if split_type != 'random':
                 stats['train_unique_groups'] = len(np.unique(group[train_index]))
                 stats['val_unique_groups'] = len(np.unique(group[val_index]))
-            [51 removed lines of the previous fold-training block; content not shown in this view]
+
+            # At each fold, train and evaluate the Pytorch model
+            if split_type != 'tanimoto' or run_sklearn:
+                logging.info(f'Skipping Pytorch model training on fold {k} with split type {split_type} and test split {test_split}.')
+                continue
+            else:
+                logging.info(f'Starting Pytorch model training on fold {k} with split type {split_type} and test split {test_split}.')
+                # Train and evaluate the model
+                model, trainer, metrics = pdp.hyperparameter_tuning_and_training(
+                    protein2embedding,
+                    cell2embedding,
+                    smiles2fp,
+                    train_df,
+                    val_df,
+                    test_df,
+                    fast_dev_run=fast_dev_run,
+                    n_trials=n_trials,
+                    logger_name=f'protac_{active_name}_{split_type}_fold_{k}_test_split_{test_split}',
+                    active_label=active_col,
+                    study_filename=f'../reports/study_{active_name}_{split_type}_fold_{k}_test_split_{test_split}.pkl',
+                )
+                hparams = {p.replace('hparam_', ''): v for p, v in stats.items() if p.startswith('hparam_')}
+                stats.update(metrics)
+                stats['model_type'] = 'Pytorch'
+                report.append(stats.copy())
+                del model
+                del trainer
+
+                # Ablation study: disable embeddings at a time
+                for disabled_embeddings in [['e3'], ['poi'], ['cell'], ['smiles'], ['e3', 'cell'], ['poi', 'e3', 'cell']]:
+                    print('-' * 100)
+                    print(f'Ablation study with disabled embeddings: {disabled_embeddings}')
+                    print('-' * 100)
+                    stats['disabled_embeddings'] = 'disabled ' + ' '.join(disabled_embeddings)
+                    model, trainer, metrics = pdp.train_model(
+                        protein2embedding,
+                        cell2embedding,
+                        smiles2fp,
+                        train_df,
+                        val_df,
+                        test_df,
+                        fast_dev_run=fast_dev_run,
+                        logger_name=f'protac_{active_name}_{split_type}_fold_{k}_disabled-{"-".join(disabled_embeddings)}',
+                        active_label=active_col,
+                        disabled_embeddings=disabled_embeddings,
+                        **hparams,
+                    )
+                    stats.update(metrics)
+                    report.append(stats.copy())
+                    del model
+                    del trainer
+
+            # At each fold, train and evaluate sklearn models
+            if run_sklearn:
+                for model_type in ['RandomForest', 'SVC', 'LogisticRegression', 'GradientBoosting']:
+                    logging.info(f'Starting sklearn model {model_type} training on fold {k} with split type {split_type} and test split {test_split}.')
+                    # Train and evaluate sklearn models
+                    model, metrics = pdp.hyperparameter_tuning_and_training_sklearn(
+                        protein2embedding=protein2embedding,
+                        cell2embedding=cell2embedding,
+                        smiles2fp=smiles2fp,
+                        train_df=train_df,
+                        val_df=val_df,
+                        test_df=test_df,
+                        model_type=model_type,
+                        active_label=active_col,
+                        n_trials=n_trials,
+                        study_filename=f'../reports/study_{active_name}_{split_type}_fold_{k}_test_split_{test_split}_{model_type.lower()}.pkl',
+                    )
+                    hparams = {p.replace('hparam_', ''): v for p, v in stats.items() if p.startswith('hparam_')}
+                    stats['model_type'] = model_type
+                    stats.update(metrics)
+                    report.append(stats.copy())
+
+        # Save the report at the end of each split type
+        report_df = pd.DataFrame(report)
+        report_df.to_csv(
+            f'../reports/cv_report_hparam_search_{cv_n_splits}-splits_{active_name}_test_split_{test_split}{"_sklearn" if run_sklearn else ""}.csv',
+            index=False,
+        )
 
 
 if __name__ == '__main__':
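One pattern in the new loop body is worth spelling out: the hyperparameters found by the PyTorch search are recovered from the per-fold `stats` dict, where they sit under `hparam_`-prefixed keys, then stripped of the prefix and splatted back as keyword arguments into `pdp.train_model` for the ablation runs. A toy illustration of just that mechanism; the keys and the `train_model_stub` stand-in are made up, not from the repository:

```python
# Per-fold stats mixing bookkeeping fields with tuned hyperparameters
# (keys shown here are illustrative).
stats = {'fold': 0, 'hparam_hidden_dim': 256, 'hparam_learning_rate': 1e-3}

# Strip the 'hparam_' prefix so the values can be passed as kwargs.
hparams = {p.replace('hparam_', ''): v for p, v in stats.items() if p.startswith('hparam_')}
print(hparams)  # {'hidden_dim': 256, 'learning_rate': 0.001}

def train_model_stub(**kwargs):
    # Hypothetical stand-in for pdp.train_model.
    print('training with', kwargs)

train_model_stub(**hparams)
```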