from typing import Literal, List, Tuple, Optional, Dict

import numpy as np
import pandas as pd
from imblearn.over_sampling import SMOTE, ADASYN
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset


class PROTAC_Dataset(Dataset):
    def __init__(
        self,
        protac_df: pd.DataFrame,
        protein2embedding: Dict,
        cell2embedding: Dict,
        smiles2fp: Dict,
        use_smote: bool = False,
        oversampler: Optional[SMOTE | ADASYN] = None,
        active_label: str = 'Active',
    ):
        """ Initialize the PROTAC dataset.

        Args:
            protac_df (pd.DataFrame): The PROTAC dataframe. It must contain the
                columns 'Smiles', 'Uniprot', 'E3 Ligase Uniprot',
                'Cell Line Identifier', and the active label column.
            protein2embedding (Dict): Dictionary mapping Uniprot IDs to protein embeddings.
            cell2embedding (Dict): Dictionary mapping cell line identifiers to embeddings.
            smiles2fp (Dict): Dictionary mapping SMILES strings to fingerprints.
            use_smote (bool): Whether to oversample the minority class at initialization.
            oversampler (Optional[SMOTE | ADASYN]): Oversampler to use. Defaults to SMOTE(random_state=42) when None.
            active_label (str): Name of the column holding the binary activity label.
        """
        self.data = protac_df
        self.protein2embedding = protein2embedding
        self.cell2embedding = cell2embedding
        self.smiles2fp = smiles2fp
        self.active_label = active_label
        self.use_single_scaler = None

        # Infer embedding dimensions from the first entry of each lookup table.
        self.smiles_emb_dim = next(iter(smiles2fp.values())).shape[0]
        self.protein_emb_dim = next(iter(protein2embedding.values())).shape[0]
        self.cell_emb_dim = next(iter(cell2embedding.values())).shape[0]

        # Replace raw identifiers with their precomputed embeddings/fingerprints,
        # stored as float32 arrays so they can be fed directly to a model.
        self.data = pd.DataFrame({
            'Smiles': self.data['Smiles'].apply(lambda x: smiles2fp[x].astype(np.float32)).tolist(),
            'Uniprot': self.data['Uniprot'].apply(lambda x: protein2embedding[x].astype(np.float32)).tolist(),
            'E3 Ligase Uniprot': self.data['E3 Ligase Uniprot'].apply(lambda x: protein2embedding[x].astype(np.float32)).tolist(),
            'Cell Line Identifier': self.data['Cell Line Identifier'].apply(lambda x: cell2embedding[x].astype(np.float32)).tolist(),
            self.active_label: self.data[self.active_label].astype(np.float32).tolist(),
        })

        self.use_smote = use_smote
        self.oversampler = oversampler
        if self.use_smote:
            self.apply_smote()
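
    # NOTE: several methods below flatten each sample into a single vector laid
    # out as [Smiles fingerprint | POI embedding | E3 ligase embedding | cell
    # embedding]; the slicing they perform relies on the *_emb_dim attributes
    # set in __init__.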

    def apply_smote(self):
        """ Oversample the minority class by resampling in the concatenated embedding space. """
        # Flatten each sample into a single feature vector for the oversampler.
        features = []
        labels = []
        for _, row in self.data.iterrows():
            features.append(np.hstack([
                row['Smiles'],
                row['Uniprot'],
                row['E3 Ligase Uniprot'],
                row['Cell Line Identifier'],
            ]))
            labels.append(row[self.active_label])

        features = np.array(features).astype(np.float32)
        labels = np.array(labels).astype(np.float32)

        # Default to SMOTE when no oversampler is provided.
        if self.oversampler is None:
            oversampler = SMOTE(random_state=42)
        else:
            oversampler = self.oversampler
        features_smote, labels_smote = oversampler.fit_resample(features, labels)

        # Slice the resampled feature matrix back into its original embedding blocks.
        smiles_embs = features_smote[:, :self.smiles_emb_dim]
        poi_embs = features_smote[:, self.smiles_emb_dim:self.smiles_emb_dim + self.protein_emb_dim]
        e3_embs = features_smote[:, self.smiles_emb_dim + self.protein_emb_dim:self.smiles_emb_dim + 2 * self.protein_emb_dim]
        cell_embs = features_smote[:, -self.cell_emb_dim:]

        # Rebuild the dataframe from the resampled data.
        df_smote = pd.DataFrame({
            'Smiles': list(smiles_embs),
            'Uniprot': list(poi_embs),
            'E3 Ligase Uniprot': list(e3_embs),
            'Cell Line Identifier': list(cell_embs),
            self.active_label: labels_smote,
        })
        self.data = df_smote
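
    # Illustrative: oversampling runs automatically at construction time when
    # use_smote=True, e.g.
    #   PROTAC_Dataset(train_df, protein2embedding, cell2embedding, smiles2fp,
    #                  use_smote=True, oversampler=ADASYN(random_state=42))
    # (train_df and the embedding dictionaries are placeholders loaded elsewhere).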

    def fit_scaling(self, use_single_scaler: bool = False, **scaler_kwargs) -> dict | StandardScaler:
        """ Fit the scaler(s) for the data.

        Args:
            use_single_scaler (bool): Whether to use a single scaler for all features.
            scaler_kwargs: Keyword arguments for the StandardScaler.

        Returns:
            dict | StandardScaler: A dictionary of per-feature scalers, or a single
                scaler fitted on the concatenated features when use_single_scaler is True.
        """
        if use_single_scaler:
            self.use_single_scaler = True
            scaler = StandardScaler(**scaler_kwargs)
            # Fit one scaler on the concatenated feature matrix.
            embeddings = np.hstack([
                np.array(self.data['Smiles'].tolist()),
                np.array(self.data['Uniprot'].tolist()),
                np.array(self.data['E3 Ligase Uniprot'].tolist()),
                np.array(self.data['Cell Line Identifier'].tolist()),
            ])
            scaler.fit(embeddings)
            return scaler
        else:
            self.use_single_scaler = False
            scalers = {}
            scalers['Smiles'] = StandardScaler(**scaler_kwargs)
            scalers['Uniprot'] = StandardScaler(**scaler_kwargs)
            scalers['E3 Ligase Uniprot'] = StandardScaler(**scaler_kwargs)
            scalers['Cell Line Identifier'] = StandardScaler(**scaler_kwargs)

            # Fit each scaler on its own feature block.
            scalers['Smiles'].fit(np.stack(self.data['Smiles'].to_numpy()))
            scalers['Uniprot'].fit(np.stack(self.data['Uniprot'].to_numpy()))
            scalers['E3 Ligase Uniprot'].fit(np.stack(self.data['E3 Ligase Uniprot'].to_numpy()))
            scalers['Cell Line Identifier'].fit(np.stack(self.data['Cell Line Identifier'].to_numpy()))

            return scalers

    def apply_scaling(self, scalers: dict | StandardScaler, use_single_scaler: bool = False):
        """ Apply scaling to the data in place.

        Args:
            scalers (dict | StandardScaler): The scaler(s) returned by fit_scaling:
                a dictionary of per-feature scalers, or a single scaler fitted on
                the concatenated features.
            use_single_scaler (bool): Whether to use a single scaler for all features.
                Must match the value passed to fit_scaling.
        """
        if self.use_single_scaler is None:
            raise ValueError(
                "The fit_scaling method must be called before apply_scaling.")
        if use_single_scaler != self.use_single_scaler:
            raise ValueError(
                "The use_single_scaler parameter must match the one used in "
                f"fit_scaling. Got {use_single_scaler}, previously "
                f"{self.use_single_scaler}."
            )
        if use_single_scaler:
            # Scale the concatenated features with the single scaler, then split
            # them back into their original embedding blocks.
            embeddings = np.hstack([
                np.array(self.data['Smiles'].tolist()),
                np.array(self.data['Uniprot'].tolist()),
                np.array(self.data['E3 Ligase Uniprot'].tolist()),
                np.array(self.data['Cell Line Identifier'].tolist()),
            ])
            scaled_embeddings = scalers.transform(embeddings)
            self.data = pd.DataFrame({
                'Smiles': list(scaled_embeddings[:, :self.smiles_emb_dim]),
                'Uniprot': list(scaled_embeddings[:, self.smiles_emb_dim:self.smiles_emb_dim + self.protein_emb_dim]),
                'E3 Ligase Uniprot': list(scaled_embeddings[:, self.smiles_emb_dim + self.protein_emb_dim:self.smiles_emb_dim + 2 * self.protein_emb_dim]),
                'Cell Line Identifier': list(scaled_embeddings[:, -self.cell_emb_dim:]),
                self.active_label: self.data[self.active_label],
            })
        else:
            # Scale each feature block with its own scaler.
            self.data['Smiles'] = self.data['Smiles'].apply(lambda x: scalers['Smiles'].transform(x[np.newaxis, :])[0])
            self.data['Uniprot'] = self.data['Uniprot'].apply(lambda x: scalers['Uniprot'].transform(x[np.newaxis, :])[0])
            self.data['E3 Ligase Uniprot'] = self.data['E3 Ligase Uniprot'].apply(lambda x: scalers['E3 Ligase Uniprot'].transform(x[np.newaxis, :])[0])
            self.data['Cell Line Identifier'] = self.data['Cell Line Identifier'].apply(lambda x: scalers['Cell Line Identifier'].transform(x[np.newaxis, :])[0])
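
    # Illustrative workflow: fit scalers on the training dataset, then re-apply the
    # returned scalers to validation/test datasets with the same use_single_scaler
    # setting, so every split is normalized with the training statistics.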

    def get_numpy_arrays(self):
        """ Return the dataset as a flat feature matrix X and a label vector y. """
        X = np.hstack([
            np.array(self.data['Smiles'].tolist()),
            np.array(self.data['Uniprot'].tolist()),
            np.array(self.data['E3 Ligase Uniprot'].tolist()),
            np.array(self.data['Cell Line Identifier'].tolist()),
        ]).copy()
        y = self.data[self.active_label].values.copy()
        return X, y

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        elem = {
            'smiles_emb': self.data['Smiles'].iloc[idx],
            'poi_emb': self.data['Uniprot'].iloc[idx],
            'e3_emb': self.data['E3 Ligase Uniprot'].iloc[idx],
            'cell_emb': self.data['Cell Line Identifier'].iloc[idx],
            'active': self.data[self.active_label].iloc[idx],
        }
        return elem
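

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the dataset class).
# The tiny random embeddings, SMILES strings, and identifiers below are made-up
# placeholders; they only demonstrate the expected input shapes and the
# fit/apply scaling workflow together with a standard PyTorch DataLoader.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    rng = np.random.default_rng(42)
    # Hypothetical lookup tables: two proteins, two cell lines, two compounds.
    protein2embedding = {'P00533': rng.random(1024), 'Q00987': rng.random(1024)}
    cell2embedding = {'HEK293': rng.random(768), 'HeLa': rng.random(768)}
    smiles2fp = {'CCO': rng.random(2048), 'CCN': rng.random(2048)}

    train_df = pd.DataFrame({
        'Smiles': ['CCO', 'CCN', 'CCO', 'CCN'],
        'Uniprot': ['P00533', 'Q00987', 'P00533', 'Q00987'],
        'E3 Ligase Uniprot': ['Q00987', 'P00533', 'Q00987', 'P00533'],
        'Cell Line Identifier': ['HEK293', 'HeLa', 'HeLa', 'HEK293'],
        'Active': [1, 0, 1, 0],
    })

    train_ds = PROTAC_Dataset(train_df, protein2embedding, cell2embedding, smiles2fp)

    # Fit per-feature scalers on the training data, then apply them in place.
    scalers = train_ds.fit_scaling(use_single_scaler=False)
    train_ds.apply_scaling(scalers, use_single_scaler=False)

    # The dataset plugs directly into a PyTorch DataLoader.
    loader = DataLoader(train_ds, batch_size=2, shuffle=True)
    batch = next(iter(loader))
    print(batch['smiles_emb'].shape, batch['active'])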