ribesstefano's picture
Started working on packaging the repository
5e01175
raw
history blame
8.75 kB
from typing import Literal, List, Tuple, Optional, Dict
from torch.utils.data import Dataset
import numpy as np
from imblearn.over_sampling import SMOTE, ADASYN
import pandas as pd
from sklearn.preprocessing import StandardScaler
class PROTAC_Dataset(Dataset):
    """Torch dataset holding PROTAC samples as pre-looked-up embeddings.

    At construction time every identifier in the input dataframe is replaced
    by its embedding, so that ``self.data`` has one array-valued column per
    feature plus the activity label column.
    """

    # Feature columns, in the order in which they are concatenated whenever a
    # single flat feature matrix is needed (oversampling, scaling, export).
    _FEATURE_COLUMNS = (
        'Smiles',
        'Uniprot',
        'E3 Ligase Uniprot',
        'Cell Line Identifier',
    )

    def __init__(
        self,
        protac_df: pd.DataFrame,
        protein2embedding: Dict,
        cell2embedding: Dict,
        smiles2fp: Dict,
        use_smote: bool = False,
        # Forward reference: the union is not evaluated at definition time,
        # so the module also imports on Python versions without `type | type`.
        oversampler: Optional["SMOTE | ADASYN"] = None,
        active_label: str = 'Active',
    ):
        """ Initialize the PROTAC dataset

        Args:
            protac_df (pd.DataFrame): The PROTAC dataframe. Must contain the
                columns 'Smiles', 'Uniprot', 'E3 Ligase Uniprot',
                'Cell Line Identifier' and `active_label`.
            protein2embedding (dict): Dictionary of protein embeddings, used
                for both the 'Uniprot' and 'E3 Ligase Uniprot' columns.
            cell2embedding (dict): Dictionary of cell line embeddings
            smiles2fp (dict): Dictionary of SMILES to fingerprint
            use_smote (bool): Whether to oversample the data right away.
            oversampler (SMOTE | ADASYN | None): Oversampler used when
                `use_smote` is set. Defaults to `SMOTE(random_state=42)`.
            active_label (str): Name of the label column in `protac_df`.

        Raises:
            IndexError: If one of the embedding dictionaries is empty.
            KeyError: If an identifier in `protac_df` has no embedding.
        """
        self.protein2embedding = protein2embedding
        self.cell2embedding = cell2embedding
        self.smiles2fp = smiles2fp
        self.active_label = active_label
        # Set by fit_scaling(); None means no scaler was fitted on this dataset.
        self.use_single_scaler = None

        # The embedding sizes are inferred from the first entry of each dict.
        self.smiles_emb_dim = self._embedding_dim(smiles2fp, 'smiles2fp')
        self.protein_emb_dim = self._embedding_dim(
            protein2embedding, 'protein2embedding')
        self.cell_emb_dim = self._embedding_dim(
            cell2embedding, 'cell2embedding')

        # Look up the embeddings. Rows with a NaN label are NOT filtered out.
        lookups = {
            'Smiles': smiles2fp,
            'Uniprot': protein2embedding,
            'E3 Ligase Uniprot': protein2embedding,
            'Cell Line Identifier': cell2embedding,
        }
        columns = {
            column: [
                lookups[column][key].astype(np.float32)
                for key in protac_df[column]
            ]
            for column in self._FEATURE_COLUMNS
        }
        columns[self.active_label] = (
            protac_df[self.active_label].astype(np.float32).tolist())
        self.data = pd.DataFrame(columns)

        # Apply SMOTE
        self.use_smote = use_smote
        self.oversampler = oversampler
        if self.use_smote:
            self.apply_smote()

    @staticmethod
    def _embedding_dim(mapping: Dict, name: str) -> int:
        """Return the length of the first embedding stored in `mapping`."""
        try:
            first = next(iter(mapping.values()))
        except StopIteration:
            raise IndexError(
                f"{name} is empty: cannot infer the embedding dimension."
            ) from None
        return first.shape[0]

    def _column_matrix(self, column: str) -> np.ndarray:
        """Return one feature column as a 2D (samples x dim) array."""
        return np.array(self.data[column].tolist())

    def _stack_features(self) -> np.ndarray:
        """Concatenate all feature columns side by side into one matrix."""
        return np.hstack(
            [self._column_matrix(column) for column in self._FEATURE_COLUMNS])

    def _split_features(self, matrix: np.ndarray) -> Dict[str, np.ndarray]:
        """Split a matrix built by `_stack_features` back into its columns."""
        widths = (
            self.smiles_emb_dim,
            self.protein_emb_dim,
            self.protein_emb_dim,
            self.cell_emb_dim,
        )
        blocks = {}
        start = 0
        # Explicit offsets (rather than a negative slice for the last block)
        # stay correct even if an embedding dimension is zero.
        for column, width in zip(self._FEATURE_COLUMNS, widths):
            blocks[column] = matrix[:, start:start + width]
            start += width
        return blocks

    def _build_frame(self, blocks: Dict[str, np.ndarray], labels) -> pd.DataFrame:
        """Build a dataframe with one embedding per cell plus the labels."""
        columns = {
            column: list(blocks[column]) for column in self._FEATURE_COLUMNS
        }
        columns[self.active_label] = labels
        return pd.DataFrame(columns)

    def apply_smote(self):
        """ Oversample the dataset and replace `self.data` with the result.

        The feature columns are concatenated into one matrix, resampled
        together with the labels via `fit_resample`, and split back into
        their respective embeddings.
        """
        features = self._stack_features().astype(np.float32)
        labels = self.data[self.active_label].to_numpy().astype(np.float32)

        # Fall back to a seeded SMOTE when no oversampler was provided
        if self.oversampler is None:
            oversampler = SMOTE(random_state=42)
        else:
            oversampler = self.oversampler
        features_smote, labels_smote = oversampler.fit_resample(
            features, labels)

        # Reconstruct the dataframe with oversampled data
        self.data = self._build_frame(
            self._split_features(features_smote), labels_smote)

    def fit_scaling(
        self,
        use_single_scaler: bool = False,
        **scaler_kwargs,
    ) -> "StandardScaler | Dict[str, StandardScaler]":
        """ Fit the scalers for the data.

        The data itself is left untouched; use `apply_scaling` to transform it.

        Args:
            use_single_scaler (bool): Whether to use a single scaler for all features.
            scaler_kwargs: Keyword arguments for the StandardScaler.

        Returns:
            A single fitted StandardScaler if `use_single_scaler` is set,
            otherwise a dict mapping each feature column name to its own
            fitted StandardScaler.
        """
        self.use_single_scaler = bool(use_single_scaler)
        if use_single_scaler:
            scaler = StandardScaler(**scaler_kwargs)
            scaler.fit(self._stack_features())
            return scaler

        scalers = {}
        for column in self._FEATURE_COLUMNS:
            scalers[column] = StandardScaler(**scaler_kwargs)
            scalers[column].fit(self._column_matrix(column))
        return scalers

    def apply_scaling(self, scalers: dict, use_single_scaler: bool = False):
        """ Apply scaling to the data, replacing `self.data`.

        The scalers do not have to be fitted on this dataset: scalers fitted
        on another dataset (e.g., a training split) can be applied directly.

        Args:
            scalers: A single scaler if `use_single_scaler` is set, otherwise
                a dict with one scaler per feature column name.
            use_single_scaler (bool): Whether to use a single scaler for all features.

        Raises:
            ValueError: If `fit_scaling` was called on this dataset with a
                different `use_single_scaler` value.
        """
        fitted_mode = self.use_single_scaler
        # Only a dataset that fitted scalers itself has a mode to contradict.
        if fitted_mode is not None and use_single_scaler != fitted_mode:
            raise ValueError(
                f"The use_single_scaler parameter must be the same as the one used in the fit_scaling method. Got {use_single_scaler}, previously {self.use_single_scaler}.")

        if use_single_scaler:
            blocks = self._split_features(
                scalers.transform(self._stack_features()))
        else:
            # One transform call per column instead of one call per row.
            blocks = {
                column: scalers[column].transform(self._column_matrix(column))
                for column in self._FEATURE_COLUMNS
            }
        self.data = self._build_frame(
            blocks, self.data[self.active_label].to_numpy())

    def get_numpy_arrays(self):
        """ Return the dataset as a `(X, y)` pair of independent numpy arrays.

        `X` holds the concatenated feature columns, one row per sample, and
        `y` holds the labels.
        """
        X = self._stack_features().copy()
        y = self.data[self.active_label].values.copy()
        return X, y

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the embeddings and the label of the sample at position `idx`."""
        row = self.data.iloc[idx]
        return {
            'smiles_emb': row['Smiles'],
            'poi_emb': row['Uniprot'],
            'e3_emb': row['E3 Ligase Uniprot'],
            'cell_emb': row['Cell Line Identifier'],
            'active': row[self.active_label],
        }