from typing import Literal, List, Tuple, Optional, Dict

import optuna
import pandas as pd
import torch
import torch.nn as nn
from sklearn.base import ClassifierMixin
from sklearn.ensemble import (
    RandomForestClassifier,
    GradientBoostingClassifier,
)
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from torchmetrics import (
    Accuracy,
    AUROC,
    Precision,
    Recall,
    F1Score,
    MetricCollection,
)

from .protac_dataset import PROTAC_Dataset


def train_sklearn_model(
    clf: ClassifierMixin,
    protein2embedding: Dict,
    cell2embedding: Dict,
    smiles2fp: Dict,
    train_df: pd.DataFrame,
    val_df: pd.DataFrame,
    test_df: Optional[pd.DataFrame] = None,
    active_label: str = 'Active',
    use_single_scaler: bool = True,
) -> Tuple[ClassifierMixin, Dict]:
    """ Train a classifier on the training set and evaluate it on the validation and, optionally, test sets.

    Args:
        clf: The classifier model to train and evaluate.
        protein2embedding (Dict): Mapping from protein identifier to embedding.
        cell2embedding (Dict): Mapping from cell line identifier to embedding.
        smiles2fp (Dict): Mapping from SMILES string to fingerprint.
        train_df (pd.DataFrame): The training set.
        val_df (pd.DataFrame): The validation set.
        test_df (Optional[pd.DataFrame]): The test set.
        active_label (str): Name of the column holding the activity label.
        use_single_scaler (bool): Whether to fit a single scaler on all features.

    Returns:
        Tuple[ClassifierMixin, Dict]: The trained model and the computed metrics.
    """

    # Build the datasets and scale features. Scalers are fit on the training
    # set only, then applied unchanged to the validation and test sets, so no
    # information leaks across the splits.
    train_ds = PROTAC_Dataset(
        train_df,
        protein2embedding,
        cell2embedding,
        smiles2fp,
        active_label=active_label,
        use_smote=False,
    )
    scaler = train_ds.fit_scaling(use_single_scaler=use_single_scaler)
    train_ds.apply_scaling(scaler, use_single_scaler=use_single_scaler)
    val_ds = PROTAC_Dataset(
        val_df,
        protein2embedding,
        cell2embedding,
        smiles2fp,
        active_label=active_label,
        use_smote=False,
    )
    val_ds.apply_scaling(scaler, use_single_scaler=use_single_scaler)
    if test_df is not None:
        test_ds = PROTAC_Dataset(
            test_df,
            protein2embedding,
            cell2embedding,
            smiles2fp,
            active_label=active_label,
            use_smote=False,
        )
        test_ds.apply_scaling(scaler, use_single_scaler=use_single_scaler)

    # Extract numpy arrays for scikit-learn.
    X_train, y_train = train_ds.get_numpy_arrays()
    X_val, y_val = val_ds.get_numpy_arrays()
    if test_df is not None:
        X_test, y_test = test_ds.get_numpy_arrays()

    # Fit the classifier.
    clf.fit(X_train, y_train)

    # One MetricCollection per stage; the prefix turns e.g. 'acc' into 'train_acc'.
    stages = ['train_metrics', 'val_metrics', 'test_metrics']
    metrics = nn.ModuleDict({s: MetricCollection({
        'acc': Accuracy(task='binary'),
        'roc_auc': AUROC(task='binary'),
        'precision': Precision(task='binary'),
        'recall': Recall(task='binary'),
        'f1_score': F1Score(task='binary'),
        'opt_score': Accuracy(task='binary') + F1Score(task='binary'),
        'hp_metric': Accuracy(task='binary'),
    }, prefix=s.replace('metrics', '')) for s in stages})

    metrics_out = {}

    # Evaluate on the training set.
    y_pred = torch.tensor(clf.predict_proba(X_train)[:, 1])
    y_true = torch.tensor(y_train)
    metrics['train_metrics'].update(y_pred, y_true)
    metrics_out.update(metrics['train_metrics'].compute())

    # Evaluate on the validation set.
    y_pred = torch.tensor(clf.predict_proba(X_val)[:, 1])
    y_true = torch.tensor(y_val)
    metrics['val_metrics'].update(y_pred, y_true)
    metrics_out.update(metrics['val_metrics'].compute())

    # Evaluate on the test set, if provided.
    if test_df is not None:
        y_pred = torch.tensor(clf.predict_proba(X_test)[:, 1])
        y_true = torch.tensor(y_test)
        metrics['test_metrics'].update(y_pred, y_true)
        metrics_out.update(metrics['test_metrics'].compute())

    return clf, metrics_out
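
# Example usage (a sketch: the embedding dictionaries and DataFrame splits are
# assumed to be loaded elsewhere, and the names below are illustrative):
#
#     clf = RandomForestClassifier(random_state=42)
#     clf, metrics = train_sklearn_model(
#         clf, protein2embedding, cell2embedding, smiles2fp,
#         train_df, val_df, test_df=test_df,
#     )
#     print(metrics['val_acc'], metrics['val_roc_auc'])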


def suggest_random_forest(
    trial: optuna.Trial,
) -> ClassifierMixin:
    """ Suggest hyperparameters for a Random Forest classifier.

    Args:
        trial (optuna.Trial): The Optuna trial object.

    Returns:
        ClassifierMixin: The Random Forest classifier with the suggested hyperparameters.
    """
    n_estimators = trial.suggest_int('model_n_estimators', 10, 1000)
    max_depth = trial.suggest_int('model_max_depth', 2, 100)
    min_samples_split = trial.suggest_int('model_min_samples_split', 2, 10)
    min_samples_leaf = trial.suggest_int('model_min_samples_leaf', 1, 10)
    max_features = trial.suggest_categorical('model_max_features', [None, 'sqrt', 'log2'])
    criterion = trial.suggest_categorical('model_criterion', ['gini', 'entropy'])

    clf = RandomForestClassifier(
        n_estimators=n_estimators,
        max_depth=max_depth,
        min_samples_split=min_samples_split,
        min_samples_leaf=min_samples_leaf,
        max_features=max_features,
        criterion=criterion,
        random_state=42,
    )

    return clf


def suggest_logistic_regression(
    trial: optuna.Trial,
) -> ClassifierMixin:
    """ Suggest hyperparameters for a Logistic Regression classifier.

    Args:
        trial (optuna.Trial): The Optuna trial object.

    Returns:
        ClassifierMixin: The Logistic Regression classifier with the suggested hyperparameters.
    """
    C = trial.suggest_float('model_C', 1e-4, 1e2, log=True)
    penalty = trial.suggest_categorical('model_penalty', ['l1', 'l2', 'elasticnet', None])
    solver = trial.suggest_categorical('model_solver', ['newton-cholesky', 'lbfgs', 'liblinear', 'sag', 'saga'])

    # Prune solver/penalty combinations that scikit-learn does not support.
    if penalty == 'l1' and solver not in ['liblinear', 'saga']:
        raise optuna.exceptions.TrialPruned()
    if penalty == 'elasticnet' and solver != 'saga':
        raise optuna.exceptions.TrialPruned()
    if penalty is None and solver not in ['newton-cholesky', 'lbfgs', 'sag']:
        raise optuna.exceptions.TrialPruned()

    # The 'elasticnet' penalty additionally requires an l1_ratio.
    l1_ratio = trial.suggest_float('model_l1_ratio', 0.0, 1.0) if penalty == 'elasticnet' else None

    clf = LogisticRegression(
        C=C,
        penalty=penalty,
        solver=solver,
        l1_ratio=l1_ratio,
        max_iter=1000,
        random_state=42,
    )

    return clf


def suggest_svc(
    trial: optuna.Trial,
) -> ClassifierMixin:
    """ Suggest hyperparameters for an SVC classifier.

    Args:
        trial (optuna.Trial): The Optuna trial object.

    Returns:
        ClassifierMixin: The SVC classifier with the suggested hyperparameters.
    """
    C = trial.suggest_float('model_C', 1e-4, 1e2, log=True)
    kernel = trial.suggest_categorical('model_kernel', ['linear', 'poly', 'rbf', 'sigmoid'])
    gamma = trial.suggest_categorical('model_gamma', ['scale', 'auto'])
    # The polynomial degree is only meaningful for the 'poly' kernel.
    degree = trial.suggest_int('model_degree', 2, 5) if kernel == 'poly' else 3

    clf = SVC(
        C=C,
        kernel=kernel,
        gamma=gamma,
        degree=degree,
        probability=True,  # required for predict_proba in train_sklearn_model
        random_state=42,
    )

    return clf


def suggest_gradient_boosting(
    trial: optuna.Trial,
) -> ClassifierMixin:
    """ Suggest hyperparameters for a Gradient Boosting classifier.

    Args:
        trial (optuna.Trial): The Optuna trial object.

    Returns:
        ClassifierMixin: The Gradient Boosting classifier with the suggested hyperparameters.
    """
    n_estimators = trial.suggest_int('model_n_estimators', 50, 500)
    learning_rate = trial.suggest_float('model_learning_rate', 0.01, 1, log=True)
    max_depth = trial.suggest_int('model_max_depth', 3, 10)
    min_samples_split = trial.suggest_int('model_min_samples_split', 2, 10)
    min_samples_leaf = trial.suggest_int('model_min_samples_leaf', 1, 10)
    max_features = trial.suggest_categorical('model_max_features', ['sqrt', 'log2', None])

    clf = GradientBoostingClassifier(
        n_estimators=n_estimators,
        learning_rate=learning_rate,
        max_depth=max_depth,
        min_samples_split=min_samples_split,
        min_samples_leaf=min_samples_leaf,
        max_features=max_features,
        random_state=42,
    )

    return clf
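

# A minimal sketch of how the pieces above fit together in an Optuna study.
# The embedding dictionaries and DataFrame splits are assumed to be loaded
# elsewhere; swapping in any other suggest_* function works the same way:
#
#     def objective(trial: optuna.Trial) -> float:
#         clf = suggest_random_forest(trial)
#         _, metrics = train_sklearn_model(
#             clf, protein2embedding, cell2embedding, smiles2fp,
#             train_df, val_df,
#         )
#         return metrics['val_acc'].item()
#
#     study = optuna.create_study(direction='maximize')
#     study.optimize(objective, n_trials=50)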