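"""Optuna-based hyperparameter tuning utilities for PROTAC activity prediction.

This module defines cross-validated Optuna objectives and end-to-end tuning and
training entry points for the PyTorch model (via `train_model`) and for
scikit-learn baselines (via `train_sklearn_model`), together with helpers for
dataset statistics and majority-vote test metrics.
"""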
import os
from typing import Literal, List, Tuple, Optional, Dict
import logging
from .pytorch_models import (
train_model,
PROTAC_Model,
evaluate_model,
)
from .protac_dataset import get_datasets
from .sklearn_models import (
train_sklearn_model,
suggest_random_forest,
suggest_logistic_regression,
suggest_svc,
suggest_gradient_boosting,
)
import torch
import optuna
from optuna.samplers import TPESampler
import joblib
import pandas as pd
from sklearn.ensemble import (
RandomForestClassifier,
GradientBoostingClassifier,
)
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.model_selection import (
StratifiedKFold,
StratifiedGroupKFold,
)
import numpy as np
import pytorch_lightning as pl
from torchmetrics import (
Accuracy,
AUROC,
Precision,
Recall,
F1Score,
)
def get_dataframe_stats(
        train_df: Optional[pd.DataFrame] = None,
        val_df: Optional[pd.DataFrame] = None,
        test_df: Optional[pd.DataFrame] = None,
        active_label: str = 'Active',
) -> Dict:
""" Get some statistics from the dataframes.
Args:
train_df (pd.DataFrame): The training set.
val_df (pd.DataFrame): The validation set.
test_df (pd.DataFrame): The test set.
"""
stats = {}
if train_df is not None:
stats['train_len'] = len(train_df)
stats['train_active_perc'] = train_df[active_label].sum() / len(train_df)
stats['train_inactive_perc'] = (len(train_df) - train_df[active_label].sum()) / len(train_df)
stats['train_avg_tanimoto_dist'] = train_df['Avg Tanimoto'].mean()
if val_df is not None:
stats['val_len'] = len(val_df)
stats['val_active_perc'] = val_df[active_label].sum() / len(val_df)
stats['val_inactive_perc'] = (len(val_df) - val_df[active_label].sum()) / len(val_df)
stats['val_avg_tanimoto_dist'] = val_df['Avg Tanimoto'].mean()
if test_df is not None:
stats['test_len'] = len(test_df)
stats['test_active_perc'] = test_df[active_label].sum() / len(test_df)
stats['test_inactive_perc'] = (len(test_df) - test_df[active_label].sum()) / len(test_df)
stats['test_avg_tanimoto_dist'] = test_df['Avg Tanimoto'].mean()
if train_df is not None and val_df is not None:
leaking_uniprot = list(set(train_df['Uniprot']).intersection(set(val_df['Uniprot'])))
leaking_smiles = list(set(train_df['Smiles']).intersection(set(val_df['Smiles'])))
stats['num_leaking_uniprot_train_val'] = len(leaking_uniprot)
stats['num_leaking_smiles_train_val'] = len(leaking_smiles)
stats['perc_leaking_uniprot_train_val'] = len(train_df[train_df['Uniprot'].isin(leaking_uniprot)]) / len(train_df)
stats['perc_leaking_smiles_train_val'] = len(train_df[train_df['Smiles'].isin(leaking_smiles)]) / len(train_df)
if train_df is not None and test_df is not None:
leaking_uniprot = list(set(train_df['Uniprot']).intersection(set(test_df['Uniprot'])))
leaking_smiles = list(set(train_df['Smiles']).intersection(set(test_df['Smiles'])))
stats['num_leaking_uniprot_train_test'] = len(leaking_uniprot)
stats['num_leaking_smiles_train_test'] = len(leaking_smiles)
stats['perc_leaking_uniprot_train_test'] = len(train_df[train_df['Uniprot'].isin(leaking_uniprot)]) / len(train_df)
stats['perc_leaking_smiles_train_test'] = len(train_df[train_df['Smiles'].isin(leaking_smiles)]) / len(train_df)
return stats
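# Example usage (illustrative only): the dataframes below are assumed to contain the
# 'Active', 'Uniprot', 'Smiles', and 'Avg Tanimoto' columns expected by this helper.
#
#   stats = get_dataframe_stats(train_df, val_df, test_df, active_label='Active')
#   print(stats['train_active_perc'], stats['num_leaking_smiles_train_val'])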
def get_majority_vote_metrics(
test_preds: List,
test_df: pd.DataFrame,
active_label: str = 'Active',
) -> Dict:
""" Get the majority vote metrics. """
test_preds = torch.stack(test_preds)
test_preds, _ = torch.mode(test_preds, dim=0)
y = torch.tensor(test_df[active_label].tolist())
# Measure the test accuracy and ROC AUC
majority_vote_metrics = {
'test_acc': Accuracy(task='binary')(test_preds, y).item(),
'test_roc_auc': AUROC(task='binary')(test_preds, y).item(),
'test_precision': Precision(task='binary')(test_preds, y).item(),
'test_recall': Recall(task='binary')(test_preds, y).item(),
'test_f1_score': F1Score(task='binary')(test_preds, y).item(),
}
return majority_vote_metrics
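# Example (illustrative): aggregating hard per-fold predictions into an ensemble
# estimate. `fold_preds` is assumed to be a list of 1-D tensors of 0/1 predictions,
# one per CV fold, aligned with the rows of `test_df`.
#
#   metrics = get_majority_vote_metrics(fold_preds, test_df, active_label='Active')
#   print(metrics['test_acc'], metrics['test_roc_auc'])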
def pytorch_model_objective(
trial: optuna.Trial,
protein2embedding: Dict,
cell2embedding: Dict,
smiles2fp: Dict,
train_val_df: pd.DataFrame,
kf: StratifiedKFold | StratifiedGroupKFold,
        groups: Optional[np.ndarray] = None,
test_df: Optional[pd.DataFrame] = None,
hidden_dim_options: List[int] = [256, 512, 768],
batch_size_options: List[int] = [8, 16, 32],
learning_rate_options: Tuple[float, float] = (1e-5, 1e-3),
smote_k_neighbors_options: List[int] = list(range(3, 16)),
dropout_options: Tuple[float, float] = (0.1, 0.5),
fast_dev_run: bool = False,
active_label: str = 'Active',
disabled_embeddings: List[str] = [],
max_epochs: int = 100,
use_logger: bool = False,
logger_save_dir: str = 'logs',
logger_name: str = 'cv_model',
enable_checkpointing: bool = False,
) -> float:
""" Objective function for hyperparameter optimization.
Args:
trial (optuna.Trial): The Optuna trial object.
train_df (pd.DataFrame): The training set.
val_df (pd.DataFrame): The validation set.
hidden_dim_options (List[int]): The hidden dimension options.
batch_size_options (List[int]): The batch size options.
learning_rate_options (Tuple[float, float]): The learning rate options.
smote_k_neighbors_options (List[int]): The SMOTE k neighbors options.
dropout_options (Tuple[float, float]): The dropout options.
fast_dev_run (bool): Whether to run a fast development run.
active_label (str): The active label column.
disabled_embeddings (List[str]): The list of disabled embeddings.
"""
    # Suggest hyperparameters to be used across the CV folds
hidden_dim = trial.suggest_categorical('hidden_dim', hidden_dim_options)
batch_size = 128 # trial.suggest_categorical('batch_size', batch_size_options)
learning_rate = trial.suggest_float('learning_rate', *learning_rate_options, log=True)
smote_k_neighbors = trial.suggest_categorical('smote_k_neighbors', smote_k_neighbors_options)
use_smote = trial.suggest_categorical('use_smote', [True, False])
apply_scaling = True # trial.suggest_categorical('apply_scaling', [True, False])
dropout = trial.suggest_float('dropout', *dropout_options)
use_batch_norm = trial.suggest_categorical('use_batch_norm', [True, False])
# Start the CV over the folds
X = train_val_df.copy().drop(columns=active_label)
y = train_val_df[active_label].tolist()
report = []
val_preds = []
test_preds = []
for k, (train_index, val_index) in enumerate(kf.split(X, y, groups)):
logging.info(f'Fold {k + 1}/{kf.get_n_splits()}')
# Get the train and val sets
train_df = train_val_df.iloc[train_index]
val_df = train_val_df.iloc[val_index]
# Get some statistics from the dataframes
stats = {
'model_type': 'Pytorch',
'fold': k,
'train_len': len(train_df),
'val_len': len(val_df),
'train_perc': len(train_df) / len(train_val_df),
'val_perc': len(val_df) / len(train_val_df),
}
stats.update(get_dataframe_stats(train_df, val_df, test_df, active_label))
if groups is not None:
stats['train_unique_groups'] = len(np.unique(groups[train_index]))
stats['val_unique_groups'] = len(np.unique(groups[val_index]))
# At each fold, train and evaluate the Pytorch model
# Train the model with the current set of hyperparameters
ret = train_model(
protein2embedding=protein2embedding,
cell2embedding=cell2embedding,
smiles2fp=smiles2fp,
train_df=train_df,
val_df=val_df,
test_df=test_df,
hidden_dim=hidden_dim,
batch_size=batch_size,
learning_rate=learning_rate,
dropout=dropout,
use_batch_norm=use_batch_norm,
max_epochs=max_epochs,
smote_k_neighbors=smote_k_neighbors,
apply_scaling=apply_scaling,
use_smote=use_smote,
fast_dev_run=fast_dev_run,
active_label=active_label,
return_predictions=True,
disabled_embeddings=disabled_embeddings,
use_logger=use_logger,
logger_save_dir=logger_save_dir,
logger_name=f'{logger_name}_fold{k}',
enable_checkpointing=enable_checkpointing,
)
if test_df is not None:
_, _, metrics, val_pred, test_pred = ret
test_preds.append(test_pred)
else:
_, _, metrics, val_pred = ret
stats.update(metrics)
report.append(stats.copy())
val_preds.append(val_pred)
# Save the report in the trial
trial.set_user_attr('report', report)
# Get the majority vote for the test predictions
if test_df is not None and not fast_dev_run:
majority_vote_metrics = get_majority_vote_metrics(test_preds, test_df, active_label)
majority_vote_metrics.update(get_dataframe_stats(train_df, val_df, test_df, active_label))
trial.set_user_attr('majority_vote_metrics', majority_vote_metrics)
logging.info(f'Majority vote metrics: {majority_vote_metrics}')
    # Get the average validation ROC AUC across the folds
    val_roc_auc = np.mean([r['val_roc_auc'] for r in report])
    # Optuna minimizes the objective, so return the negated ROC AUC
    return -val_roc_auc
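# Example (illustrative sketch): this objective is designed to be wrapped in a lambda
# and handed to Optuna, as done in `hyperparameter_tuning_and_training` below. The
# embedding dictionaries, `train_val_df`, `kf`, and `groups` are assumed to be
# prepared by the caller.
#
#   study = optuna.create_study(direction='minimize', sampler=TPESampler(seed=42))
#   study.optimize(
#       lambda t: pytorch_model_objective(
#           trial=t,
#           protein2embedding=protein2embedding,
#           cell2embedding=cell2embedding,
#           smiles2fp=smiles2fp,
#           train_val_df=train_val_df,
#           kf=kf,
#           groups=groups,
#       ),
#       n_trials=10,
#   )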
def hyperparameter_tuning_and_training(
protein2embedding: Dict,
cell2embedding: Dict,
smiles2fp: Dict,
train_val_df: pd.DataFrame,
test_df: pd.DataFrame,
kf: StratifiedKFold | StratifiedGroupKFold,
        groups: Optional[np.ndarray] = None,
split_type: str = 'random',
n_models_for_test: int = 3,
fast_dev_run: bool = False,
n_trials: int = 50,
logger_save_dir: str = 'logs',
logger_name: str = 'protac_hparam_search',
active_label: str = 'Active',
max_epochs: int = 100,
study_filename: Optional[str] = None,
force_study: bool = False,
) -> dict:
""" Hyperparameter tuning and training of a PROTAC model.
Args:
train_df (pd.DataFrame): The training set.
val_df (pd.DataFrame): The validation set.
test_df (pd.DataFrame): The test set.
fast_dev_run (bool): Whether to run a fast development run.
n_trials (int): The number of hyperparameter optimization trials.
logger_name (str): The name of the logger.
active_label (str): The active label column.
disabled_embeddings (List[str]): The list of disabled embeddings.
Returns:
tuple: The trained model, the trainer, and the best metrics.
"""
pl.seed_everything(42)
# Define the search space
hidden_dim_options = [16, 32, 64, 128, 256] #, 512]
batch_size_options = [128, 128] # [4, 8, 16, 32, 64, 128]
learning_rate_options = (1e-6, 1e-3) # min and max values for loguniform distribution
smote_k_neighbors_options = list(range(3, 16))
# NOTE: We want Optuna to explore the combination (very low dropout, very
# small hidden_dim)
dropout_options = (0, 0.5)
# Set the verbosity of Optuna
optuna.logging.set_verbosity(optuna.logging.WARNING)
# Create an Optuna study object
sampler = TPESampler(seed=42, multivariate=True)
study = optuna.create_study(direction='minimize', sampler=sampler)
study_loaded = False
if study_filename and not force_study:
if os.path.exists(study_filename):
study = joblib.load(study_filename)
study_loaded = True
logging.info(f'Loaded study from {study_filename}')
logging.info(f'Study best params: {study.best_params}')
if not study_loaded or force_study:
study.optimize(
lambda trial: pytorch_model_objective(
trial=trial,
protein2embedding=protein2embedding,
cell2embedding=cell2embedding,
smiles2fp=smiles2fp,
train_val_df=train_val_df,
kf=kf,
groups=groups,
test_df=test_df,
hidden_dim_options=hidden_dim_options,
batch_size_options=batch_size_options,
learning_rate_options=learning_rate_options,
smote_k_neighbors_options=smote_k_neighbors_options,
dropout_options=dropout_options,
fast_dev_run=fast_dev_run,
active_label=active_label,
max_epochs=max_epochs,
disabled_embeddings=[],
),
n_trials=n_trials,
)
if study_filename:
joblib.dump(study, study_filename)
cv_report = pd.DataFrame(study.best_trial.user_attrs['report'])
hparam_report = pd.DataFrame([study.best_params])
# Train the best CV models and store their checkpoints by running the objective
pytorch_model_objective(
trial=study.best_trial,
protein2embedding=protein2embedding,
cell2embedding=cell2embedding,
smiles2fp=smiles2fp,
train_val_df=train_val_df,
kf=kf,
groups=groups,
test_df=test_df,
hidden_dim_options=hidden_dim_options,
batch_size_options=batch_size_options,
learning_rate_options=learning_rate_options,
smote_k_neighbors_options=smote_k_neighbors_options,
dropout_options=dropout_options,
fast_dev_run=fast_dev_run,
active_label=active_label,
max_epochs=max_epochs,
disabled_embeddings=[],
use_logger=True,
logger_save_dir=logger_save_dir,
logger_name=f'{logger_name}_{split_type}_cv_model',
enable_checkpointing=True,
)
    # Retrain N models with the best hyperparameters to estimate model uncertainty
best_models = []
test_report = []
test_preds = []
dfs_stats = get_dataframe_stats(train_val_df, test_df=test_df, active_label=active_label)
for i in range(n_models_for_test):
pl.seed_everything(42 + i + 1)
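        # NOTE: The test set is passed below as `val_df`, so the model is trained on
        # the full train+val data and evaluated on the test set; the resulting
        # 'val_*' metrics are renamed to 'test_*' right after training.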
model, trainer, metrics, test_pred = train_model(
protein2embedding=protein2embedding,
cell2embedding=cell2embedding,
smiles2fp=smiles2fp,
train_df=train_val_df,
val_df=test_df,
fast_dev_run=fast_dev_run,
active_label=active_label,
max_epochs=max_epochs,
disabled_embeddings=[],
use_logger=True,
logger_save_dir=logger_save_dir,
logger_name=f'{logger_name}_best_model_n{i}',
enable_checkpointing=True,
checkpoint_model_name=f'best_model_n{i}_{split_type}',
return_predictions=True,
batch_size=128,
apply_scaling=True,
**study.best_params,
)
# Rename the keys in the metrics dictionary
metrics = {k.replace('val_', 'test_'): v for k, v in metrics.items()}
metrics['model_type'] = 'Pytorch'
metrics['test_model_id'] = i
metrics.update(dfs_stats)
test_report.append(metrics.copy())
test_preds.append(test_pred)
best_models.append({'model': model, 'trainer': trainer})
test_report = pd.DataFrame(test_report)
# Get the majority vote for the test predictions
if not fast_dev_run:
majority_vote_metrics = get_majority_vote_metrics(test_preds, test_df, active_label)
majority_vote_metrics.update(get_dataframe_stats(train_val_df, test_df=test_df, active_label=active_label))
majority_vote_metrics_cv = study.best_trial.user_attrs['majority_vote_metrics']
majority_vote_metrics_cv['cv_models'] = True
majority_vote_report = pd.DataFrame([
majority_vote_metrics,
majority_vote_metrics_cv,
])
majority_vote_report['model_type'] = 'Pytorch'
majority_vote_report['split_type'] = split_type
    # Ablation study: disable one or more embeddings at a time
ablation_report = []
dfs_stats = get_dataframe_stats(train_val_df, test_df=test_df, active_label=active_label)
disabled_embeddings_combinations = [
['e3'],
['poi'],
['cell'],
['smiles'],
['e3', 'cell'],
['poi', 'e3'],
['poi', 'e3', 'cell'],
]
for disabled_embeddings in disabled_embeddings_combinations:
logging.info('-' * 100)
logging.info(f'Ablation study with disabled embeddings: {disabled_embeddings}')
logging.info('-' * 100)
disabled_embeddings_str = 'disabled ' + ' '.join(disabled_embeddings)
test_preds = []
for i, model_trainer in enumerate(best_models):
logging.info(f'Evaluating model n.{i} on {disabled_embeddings_str}.')
model = model_trainer['model']
trainer = model_trainer['trainer']
_, test_ds, _ = get_datasets(
protein2embedding=protein2embedding,
cell2embedding=cell2embedding,
smiles2fp=smiles2fp,
train_df=train_val_df,
val_df=test_df,
disabled_embeddings=disabled_embeddings,
active_label=active_label,
scaler=model.scalers,
use_single_scaler=model.join_embeddings == 'beginning',
)
            # NOTE: The test set is passed as the validation set argument, so the
            # returned 'val_*' entries actually refer to the test set.
            ret = evaluate_model(model, trainer, test_ds, batch_size=128)
            test_preds.append(ret['val_pred'])
            # Rename the keys in the metrics dictionary accordingly
            ret['val_metrics'] = {k.replace('val_', 'test_'): v for k, v in ret['val_metrics'].items()}
ret['val_metrics'].update(dfs_stats)
ret['val_metrics']['majority_vote'] = False
ret['val_metrics']['model_type'] = 'Pytorch'
ret['val_metrics']['disabled_embeddings'] = disabled_embeddings_str
ablation_report.append(ret['val_metrics'].copy())
# Get the majority vote for the test predictions
if not fast_dev_run:
majority_vote_metrics = get_majority_vote_metrics(test_preds, test_df, active_label)
majority_vote_metrics.update(dfs_stats)
majority_vote_metrics['majority_vote'] = True
majority_vote_metrics['model_type'] = 'Pytorch'
majority_vote_metrics['disabled_embeddings'] = disabled_embeddings_str
ablation_report.append(majority_vote_metrics.copy())
# _, _, metrics = train_model(
# protein2embedding=protein2embedding,
# cell2embedding=cell2embedding,
# smiles2fp=smiles2fp,
# train_df=train_val_df,
# val_df=test_df,
# fast_dev_run=fast_dev_run,
# active_label=active_label,
# max_epochs=max_epochs,
# use_logger=False,
# logger_save_dir=logger_save_dir,
# logger_name=f'{logger_name}_disabled-{"-".join(disabled_embeddings)}',
# disabled_embeddings=disabled_embeddings,
# batch_size=128,
# apply_scaling=True,
# **study.best_params,
# )
# # Rename the keys in the metrics dictionary
# metrics = {k.replace('val_', 'test_'): v for k, v in metrics.items()}
# metrics['disabled_embeddings'] = disabled_embeddings_str
# metrics['model_type'] = 'Pytorch'
# metrics.update(dfs_stats)
# # Add the training metrics
# train_metrics = {m: v.item() for m, v in trainer.callback_metrics.items() if 'train' in m}
# metrics.update(train_metrics)
# ablation_report.append(metrics.copy())
ablation_report = pd.DataFrame(ablation_report)
# Add a column with the split_type to all reports
for report in [cv_report, hparam_report, test_report, ablation_report]:
report['split_type'] = split_type
# Return the reports
ret = {
'cv_report': cv_report,
'hparam_report': hparam_report,
'test_report': test_report,
'ablation_report': ablation_report,
}
if not fast_dev_run:
ret['majority_vote_report'] = majority_vote_report
return ret
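# Example (illustrative sketch): running the full search and retraining pipeline.
# The embedding dictionaries and dataframes are assumed to be loaded by the caller;
# grouping by 'Uniprot' is just one plausible choice for a target-based split, and
# the study filename is only an example.
#
#   kf = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=42)
#   groups = train_val_df['Uniprot'].to_numpy()
#   reports = hyperparameter_tuning_and_training(
#       protein2embedding, cell2embedding, smiles2fp, train_val_df, test_df,
#       kf=kf, groups=groups, split_type='uniprot', n_trials=50,
#       study_filename='pytorch_study.pkl',
#   )
#   reports['cv_report'].to_csv('cv_report.csv', index=False)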
def sklearn_model_objective(
trial: optuna.Trial,
protein2embedding: Dict,
cell2embedding: Dict,
smiles2fp: Dict,
train_df: pd.DataFrame,
val_df: pd.DataFrame,
model_type: Literal['RandomForest', 'SVC', 'LogisticRegression', 'GradientBoosting'] = 'RandomForest',
active_label: str = 'Active',
) -> float:
""" Objective function for hyperparameter optimization.
Args:
trial (optuna.Trial): The Optuna trial object.
train_df (pd.DataFrame): The training set.
val_df (pd.DataFrame): The validation set.
model_type (str): The model type.
hyperparameters (Dict): The hyperparameters for the model.
fast_dev_run (bool): Whether to run a fast development run.
active_label (str): The active label column.
"""
# Generate the hyperparameters
use_single_scaler = trial.suggest_categorical('use_single_scaler', [True, False])
if model_type == 'RandomForest':
clf = suggest_random_forest(trial)
elif model_type == 'SVC':
clf = suggest_svc(trial)
elif model_type == 'LogisticRegression':
clf = suggest_logistic_regression(trial)
elif model_type == 'GradientBoosting':
clf = suggest_gradient_boosting(trial)
else:
raise ValueError(f'Invalid model type: {model_type}. Available: RandomForest, SVC, LogisticRegression, GradientBoosting.')
# Train the model with the current set of hyperparameters
_, metrics = train_sklearn_model(
clf=clf,
protein2embedding=protein2embedding,
cell2embedding=cell2embedding,
smiles2fp=smiles2fp,
train_df=train_df,
val_df=val_df,
active_label=active_label,
use_single_scaler=use_single_scaler,
)
    # `metrics` is a dictionary containing at least the validation accuracy and ROC AUC
    val_acc = metrics['val_acc']
    val_roc_auc = metrics['val_roc_auc']
    # Optuna minimizes the objective, so return the negated sum of the two metrics
    return -val_acc - val_roc_auc
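# Example (illustrative sketch): this objective is meant to be passed to Optuna via a
# lambda, as done in `hyperparameter_tuning_and_training_sklearn` below. The
# dataframes and embedding dictionaries are assumed to be prepared by the caller.
#
#   study = optuna.create_study(direction='minimize', sampler=TPESampler(seed=42))
#   study.optimize(
#       lambda t: sklearn_model_objective(
#           t, protein2embedding, cell2embedding, smiles2fp, train_df, val_df,
#           model_type='SVC',
#       ),
#       n_trials=20,
#   )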
def hyperparameter_tuning_and_training_sklearn(
protein2embedding: Dict,
cell2embedding: Dict,
smiles2fp: Dict,
train_df: pd.DataFrame,
val_df: pd.DataFrame,
test_df: Optional[pd.DataFrame] = None,
model_type: Literal['RandomForest', 'SVC', 'LogisticRegression', 'GradientBoosting'] = 'RandomForest',
active_label: str = 'Active',
n_trials: int = 50,
logger_name: str = 'protac_hparam_search_sklearn',
study_filename: Optional[str] = None,
) -> Tuple:
""" Hyperparameter tuning and training of a PROTAC model.
Args:
train_df (pd.DataFrame): The training set.
val_df (pd.DataFrame): The validation set.
test_df (pd.DataFrame): The test set.
model_type (str): The model type.
n_trials (int): The number of hyperparameter optimization trials.
logger_name (str): The name of the logger. Unused, for compatibility with hyperparameter_tuning_and_training.
active_label (str): The active label column.
Returns:
tuple: The trained model and the best metrics.
"""
# Set the verbosity of Optuna
optuna.logging.set_verbosity(optuna.logging.WARNING)
# Create an Optuna study object
sampler = TPESampler(seed=42, multivariate=True)
study = optuna.create_study(direction='minimize', sampler=sampler)
study_loaded = False
if study_filename:
if os.path.exists(study_filename):
study = joblib.load(study_filename)
study_loaded = True
logging.info(f'Loaded study from {study_filename}')
if not study_loaded:
study.optimize(
lambda trial: sklearn_model_objective(
trial=trial,
protein2embedding=protein2embedding,
cell2embedding=cell2embedding,
smiles2fp=smiles2fp,
train_df=train_df,
val_df=val_df,
model_type=model_type,
active_label=active_label,
),
n_trials=n_trials,
)
if study_filename:
joblib.dump(study, study_filename)
# Retrain the model with the best hyperparameters
best_hyperparameters = {k.replace('model_', ''): v for k, v in study.best_params.items() if k.startswith('model_')}
if model_type == 'RandomForest':
clf = RandomForestClassifier(random_state=42, **best_hyperparameters)
elif model_type == 'SVC':
clf = SVC(random_state=42, probability=True, **best_hyperparameters)
elif model_type == 'LogisticRegression':
clf = LogisticRegression(random_state=42, max_iter=1000, **best_hyperparameters)
elif model_type == 'GradientBoosting':
clf = GradientBoostingClassifier(random_state=42, **best_hyperparameters)
else:
raise ValueError(f'Invalid model type: {model_type}. Available: RandomForest, SVC, LogisticRegression, GradientBoosting.')
model, metrics = train_sklearn_model(
clf=clf,
protein2embedding=protein2embedding,
cell2embedding=cell2embedding,
smiles2fp=smiles2fp,
train_df=train_df,
val_df=val_df,
test_df=test_df,
active_label=active_label,
use_single_scaler=study.best_params['use_single_scaler'],
)
# Report the best hyperparameters found
metrics.update({f'hparam_{k}': v for k, v in study.best_params.items()})
# Return the best metrics
return model, metrics
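# Example (illustrative sketch): tuning and evaluating a scikit-learn baseline. The
# embedding dictionaries and dataframes are assumed to be prepared by the caller;
# the study filename is only an example.
#
#   model, metrics = hyperparameter_tuning_and_training_sklearn(
#       protein2embedding, cell2embedding, smiles2fp, train_df, val_df,
#       test_df=test_df, model_type='RandomForest', n_trials=50,
#       study_filename='rf_study.pkl',
#   )
#   print({k: v for k, v in metrics.items() if k.startswith('hparam_')})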