Commit b09510c · Parent: 5c27a23
Added script file for hparam CV training
notebooks/protac_degradation_predictor.ipynb
CHANGED
The diff for this file is too large to render. See raw diff.
notebooks/protac_degradation_predictor.py
ADDED
@@ -0,0 +1,988 @@
# %% [markdown]
# # PROTAC-Degradation-Predictor

# %%
import pandas as pd

protac_df = pd.read_csv('../data/PROTAC-Degradation-DB.csv')
protac_df.head()

# %%
# Get the unique Article IDs of the entries with NaN values in the Active column
nan_active = protac_df[protac_df['Active'].isna()]['Article DOI'].unique()
nan_active

# %%
# Map E3 Ligase Iap to IAP
protac_df['E3 Ligase'] = protac_df['E3 Ligase'].str.replace('Iap', 'IAP')

# %%
protac_df.columns

# %%
cells = sorted(protac_df['Cell Type'].dropna().unique().tolist())
print(f'Number of non-cleaned cell lines: {len(cells)}')

# %%
cells = sorted(protac_df['Cell Line Identifier'].dropna().unique().tolist())
print(f'Number of cleaned cell lines: {len(cells)}')

# %%
unlabeled_df = protac_df[protac_df['Active'].isna()]
print(f'Number of compounds in test set: {len(unlabeled_df)}')

# %% [markdown]
# ## Load Protein Embeddings

# %% [markdown]
# Protein embeddings downloaded from [Uniprot](https://www.uniprot.org/help/embeddings).
#
# Please note that running the following cell the first time might take a while.

# %%
import os
import urllib.request

download_link = "https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/embeddings/UP000005640_9606/per-protein.h5"
embeddings_path = "../data/uniprot2embedding.h5"
if not os.path.exists(embeddings_path):
    # Download the file
    print(f'Downloading embeddings from {download_link}')
    urllib.request.urlretrieve(download_link, embeddings_path)

# %%
import h5py
import numpy as np
from tqdm.auto import tqdm

protein_embeddings = {}
with h5py.File("../data/uniprot2embedding.h5", "r") as file:
    print(f"number of entries: {len(file.items()):,}")
    uniprots = protac_df['Uniprot'].unique().tolist()
    uniprots += protac_df['E3 Ligase Uniprot'].unique().tolist()
    for i, sequence_id in tqdm(enumerate(uniprots), desc='Loading protein embeddings'):
        try:
            embedding = file[sequence_id][:]
            protein_embeddings[sequence_id] = np.array(embedding)
            if i < 10:
                print(
                    f"\tid: {sequence_id}, "
                    f"\tembeddings shape: {embedding.shape}, "
                    f"\tembeddings mean: {np.array(embedding).mean()}"
                )
        except KeyError:
            print(f'KeyError for {sequence_id}')
            protein_embeddings[sequence_id] = np.zeros((1024,))

# %% [markdown]
# ## Load Cell Embeddings

# %%
import pickle

cell2embedding_filepath = '../data/cell2embedding.pkl'
with open(cell2embedding_filepath, 'rb') as f:
    cell2embedding = pickle.load(f)
print(f'Loaded {len(cell2embedding)} cell lines')

# %%
emb_shape = cell2embedding[list(cell2embedding.keys())[0]].shape
# Assign all-zero vectors to cell lines that are not in the embedding file
for cell_line in protac_df['Cell Line Identifier'].unique():
    if cell_line not in cell2embedding:
        cell2embedding[cell_line] = np.zeros(emb_shape)

# %% [markdown]
# ## Precompute Molecular Fingerprints

# %%
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import Draw

morgan_radius = 15
n_bits = 1024

# fpgen = AllChem.GetAtomPairGenerator()
rdkit_fpgen = AllChem.GetRDKitFPGenerator(maxPath=5, fpSize=512)
morgan_fpgen = AllChem.GetMorganGenerator(radius=morgan_radius, fpSize=n_bits)

smiles2fp = {}
for smiles in tqdm(protac_df['Smiles'].unique().tolist(), desc='Precomputing fingerprints'):
    # Get the fingerprint as a bit vector
    morgan_fp = morgan_fpgen.GetFingerprint(Chem.MolFromSmiles(smiles))
    # rdkit_fp = rdkit_fpgen.GetFingerprint(Chem.MolFromSmiles(smiles))
    # fp = np.concatenate([morgan_fp, rdkit_fp])
    smiles2fp[smiles] = morgan_fp

# Count the number of unique SMILES and the number of unique Morgan fingerprints
print(f'Number of unique SMILES: {len(smiles2fp)}')
print(f'Number of unique fingerprints: {len(set([tuple(fp) for fp in smiles2fp.values()]))}')
# Get the list of SMILES with overlapping fingerprints
overlapping_smiles = []
unique_fps = set()
for smiles, fp in smiles2fp.items():
    if tuple(fp) in unique_fps:
        overlapping_smiles.append(smiles)
    else:
        unique_fps.add(tuple(fp))
print(f'Number of SMILES with overlapping fingerprints: {len(overlapping_smiles)}')
print(f'Number of overlapping SMILES in protac_df: {len(protac_df[protac_df["Smiles"].isin(overlapping_smiles)])}')

# %%
# Get the pair-wise Tanimoto similarity between the PROTAC fingerprints
from rdkit import DataStructs
from collections import defaultdict

tanimoto_matrix = defaultdict(list)
for i, smiles1 in enumerate(tqdm(protac_df['Smiles'].unique(), desc='Computing Tanimoto similarity')):
    fp1 = smiles2fp[smiles1]
    # TODO: Use BulkTanimotoSimilarity
    for j, smiles2 in enumerate(protac_df['Smiles'].unique()):
        if j < i:
            continue
        fp2 = smiles2fp[smiles2]
        tanimoto_dist = DataStructs.TanimotoSimilarity(fp1, fp2)
        tanimoto_matrix[smiles1].append(tanimoto_dist)
avg_tanimoto = {k: np.mean(v) for k, v in tanimoto_matrix.items()}
protac_df['Avg Tanimoto'] = protac_df['Smiles'].map(avg_tanimoto)

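# %% [markdown]
# The nested loop above is quadratic in pure Python; as its TODO suggests, RDKit's
# `DataStructs.BulkTanimotoSimilarity` compares one fingerprint against a whole list
# in a single call. The cell below is an optional, illustrative sketch of that
# variant: the names `unique_smiles_list`, `fps`, and `avg_tanimoto_bulk` are
# introduced here only and are not used elsewhere. It assumes `smiles2fp` still
# holds RDKit bit vectors, as it does at this point in the script.

# %%
unique_smiles_list = protac_df['Smiles'].unique().tolist()
fps = [smiles2fp[s] for s in unique_smiles_list]
avg_tanimoto_bulk = {}
for i, s in enumerate(unique_smiles_list):
    # Similarities of fingerprint i against fingerprints i, i+1, ..., N-1
    # (same upper-triangle averaging as the loop above)
    sims = DataStructs.BulkTanimotoSimilarity(fps[i], fps[i:])
    avg_tanimoto_bulk[s] = float(np.mean(sims))
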
# %%
# # Plot the distribution of the average Tanimoto similarity
# import seaborn as sns
# import matplotlib.pyplot as plt

# sns.histplot(protac_df['Avg Tanimoto'], bins=50)
# plt.xlabel('Average Tanimoto similarity')
# plt.ylabel('Count')
# plt.title('Distribution of average Tanimoto similarity')
# plt.grid(axis='y', alpha=0.5)
# plt.show()

# %%
smiles2fp = {s: np.array(fp) for s, fp in smiles2fp.items()}

# %% [markdown]
# ## Set the Column to Predict

# %%
# active_col = 'Active'
active_col = 'Active - OR'


from sklearn.preprocessing import StandardScaler

# %% [markdown]
# ## Define Torch Dataset

# %%
from imblearn.over_sampling import SMOTE, ADASYN
from sklearn.preprocessing import LabelEncoder
import pandas as pd
import numpy as np

# %%
from torch.utils.data import Dataset, DataLoader


class PROTAC_Dataset(Dataset):
    def __init__(
        self,
        protac_df,
        protein_embeddings=protein_embeddings,
        cell2embedding=cell2embedding,
        smiles2fp=smiles2fp,
        use_smote=False,
        oversampler=None,
        use_ored_activity=False,
    ):
        """ Initialize the PROTAC dataset

        Args:
            protac_df (pd.DataFrame): The PROTAC dataframe
            protein_embeddings (dict): Dictionary of protein embeddings
            cell2embedding (dict): Dictionary of cell line embeddings
            smiles2fp (dict): Dictionary of SMILES to fingerprint
            use_smote (bool): Whether to use SMOTE for oversampling
            use_ored_activity (bool): Whether to use the 'Active - OR' column
        """
        # Filter out examples with NaN in 'Active' column
        self.data = protac_df  # [~protac_df['Active'].isna()]
        self.protein_embeddings = protein_embeddings
        self.cell2embedding = cell2embedding
        self.smiles2fp = smiles2fp

        self.smiles_emb_dim = smiles2fp[list(smiles2fp.keys())[0]].shape[0]
        self.protein_emb_dim = protein_embeddings[list(
            protein_embeddings.keys())[0]].shape[0]
        self.cell_emb_dim = cell2embedding[list(
            cell2embedding.keys())[0]].shape[0]

        self.active_label = 'Active - OR' if use_ored_activity else 'Active'

        self.use_smote = use_smote
        self.oversampler = oversampler
        # Apply SMOTE
        if self.use_smote:
            self.apply_smote()

    def apply_smote(self):
        # Prepare the dataset for SMOTE
        features = []
        labels = []
        for _, row in self.data.iterrows():
            smiles_emb = smiles2fp[row['Smiles']]
            poi_emb = protein_embeddings[row['Uniprot']]
            e3_emb = protein_embeddings[row['E3 Ligase Uniprot']]
            cell_emb = cell2embedding[row['Cell Line Identifier']]
            features.append(np.hstack([
                smiles_emb.astype(np.float32),
                poi_emb.astype(np.float32),
                e3_emb.astype(np.float32),
                cell_emb.astype(np.float32),
            ]))
            labels.append(row[self.active_label])

        # Convert to numpy array
        features = np.array(features).astype(np.float32)
        labels = np.array(labels).astype(np.float32)

        # Initialize SMOTE and fit
        if self.oversampler is None:
            oversampler = SMOTE(random_state=42)
        else:
            oversampler = self.oversampler
        features_smote, labels_smote = oversampler.fit_resample(features, labels)

        # Separate the features back into their respective embeddings
        smiles_embs = features_smote[:, :self.smiles_emb_dim]
        poi_embs = features_smote[:, self.smiles_emb_dim:self.smiles_emb_dim + self.protein_emb_dim]
        e3_embs = features_smote[:, self.smiles_emb_dim + self.protein_emb_dim:self.smiles_emb_dim + 2 * self.protein_emb_dim]
        cell_embs = features_smote[:, -self.cell_emb_dim:]

        # Reconstruct the dataframe with oversampled data
        df_smote = pd.DataFrame({
            'Smiles': list(smiles_embs),
            'Uniprot': list(poi_embs),
            'E3 Ligase Uniprot': list(e3_embs),
            'Cell Line Identifier': list(cell_embs),
            self.active_label: labels_smote
        })
        self.data = df_smote

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if self.use_smote:
            # NOTE: We do not need to look up the embeddings anymore
            elem = {
                'smiles_emb': self.data['Smiles'].iloc[idx],
                'poi_emb': self.data['Uniprot'].iloc[idx],
                'e3_emb': self.data['E3 Ligase Uniprot'].iloc[idx],
                'cell_emb': self.data['Cell Line Identifier'].iloc[idx],
                'active': self.data[self.active_label].iloc[idx],
            }
        else:
            elem = {
                'smiles_emb': self.smiles2fp[self.data['Smiles'].iloc[idx]].astype(np.float32),
                'poi_emb': self.protein_embeddings[self.data['Uniprot'].iloc[idx]].astype(np.float32),
                'e3_emb': self.protein_embeddings[self.data['E3 Ligase Uniprot'].iloc[idx]].astype(np.float32),
                'cell_emb': self.cell2embedding[self.data['Cell Line Identifier'].iloc[idx]].astype(np.float32),
                'active': 1. if self.data[self.active_label].iloc[idx] else 0.,
            }
        return elem

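# %% [markdown]
# A minimal smoke test of the dataset class (an illustrative sketch, not part of the
# original pipeline): build a dataset from the labelled rows loaded above and inspect
# the shapes of one batch coming out of a `DataLoader`. The `_demo_*` names are
# placeholders introduced here only.

# %%
_demo_ds = PROTAC_Dataset(
    protac_df[protac_df[active_col].notna()],
    protein_embeddings,
    cell2embedding,
    smiles2fp,
    use_ored_activity=(active_col == 'Active - OR'),
)
_demo_batch = next(iter(DataLoader(_demo_ds, batch_size=4, shuffle=False)))
for _key, _value in _demo_batch.items():
    print(_key, _value.shape)
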
# %%
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pytorch_lightning as pl
from torchmetrics import (
    Accuracy,
    AUROC,
    Precision,
    Recall,
    F1Score,
)
from torchmetrics import MetricCollection

# Ignore UserWarning from PyTorch Lightning
warnings.filterwarnings("ignore", ".*does not have many workers.*")


class PROTAC_Model(pl.LightningModule):

    def __init__(
        self,
        hidden_dim,
        smiles_emb_dim=1024,
        poi_emb_dim=1024,
        e3_emb_dim=1024,
        cell_emb_dim=768,
        batch_size=32,
        learning_rate=1e-3,
        dropout=0.2,
        train_dataset=None,
        val_dataset=None,
        test_dataset=None,
        disabled_embeddings=[],
    ):
        super().__init__()
        self.poi_emb_dim = poi_emb_dim
        self.e3_emb_dim = e3_emb_dim
        self.cell_emb_dim = cell_emb_dim
        self.smiles_emb_dim = smiles_emb_dim
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        self.disabled_embeddings = disabled_embeddings
        # Set our init args as class attributes
        self.__dict__.update(locals())  # Add arguments as attributes
        # Save the arguments passed to init
        ignore_args_as_hyperparams = [
            'train_dataset',
            'test_dataset',
            'val_dataset',
        ]
        self.save_hyperparameters(ignore=ignore_args_as_hyperparams)

        if 'poi' not in self.disabled_embeddings:
            self.poi_emb = nn.Linear(poi_emb_dim, hidden_dim)
            # # Set the POI surrogate model as a Sequential model
            # self.poi_emb = nn.Sequential(
            #     nn.Linear(poi_emb_dim, hidden_dim),
            #     nn.GELU(),
            #     nn.Dropout(p=dropout),
            #     nn.Linear(hidden_dim, hidden_dim),
            #     # nn.ReLU(),
            #     # nn.Dropout(p=dropout),
            # )
        if 'e3' not in self.disabled_embeddings:
            self.e3_emb = nn.Linear(e3_emb_dim, hidden_dim)
            # self.e3_emb = nn.Sequential(
            #     nn.Linear(e3_emb_dim, hidden_dim),
            #     # nn.ReLU(),
            #     nn.Dropout(p=dropout),
            #     # nn.Linear(hidden_dim, hidden_dim),
            #     # nn.ReLU(),
            #     # nn.Dropout(p=dropout),
            # )
        if 'cell' not in self.disabled_embeddings:
            self.cell_emb = nn.Linear(cell_emb_dim, hidden_dim)
            # self.cell_emb = nn.Sequential(
            #     nn.Linear(cell_emb_dim, hidden_dim),
            #     # nn.ReLU(),
            #     nn.Dropout(p=dropout),
            #     # nn.Linear(hidden_dim, hidden_dim),
            #     # nn.ReLU(),
            #     # nn.Dropout(p=dropout),
            # )
        if 'smiles' not in self.disabled_embeddings:
            self.smiles_emb = nn.Linear(smiles_emb_dim, hidden_dim)
            # self.smiles_emb = nn.Sequential(
            #     nn.Linear(smiles_emb_dim, hidden_dim),
            #     # nn.ReLU(),
            #     nn.Dropout(p=dropout),
            #     # nn.Linear(hidden_dim, hidden_dim),
            #     # nn.ReLU(),
            #     # nn.Dropout(p=dropout),
            # )

        self.fc1 = nn.Linear(
            hidden_dim * (4 - len(self.disabled_embeddings)), hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 1)

        self.dropout = nn.Dropout(p=dropout)

        stages = ['train_metrics', 'val_metrics', 'test_metrics']
        self.metrics = nn.ModuleDict({s: MetricCollection({
            'acc': Accuracy(task='binary'),
            'roc_auc': AUROC(task='binary'),
            'precision': Precision(task='binary'),
            'recall': Recall(task='binary'),
            'f1_score': F1Score(task='binary'),
            'opt_score': Accuracy(task='binary') + F1Score(task='binary'),
            'hp_metric': Accuracy(task='binary'),
        }, prefix=s.replace('metrics', '')) for s in stages})

        # Misc settings
        self.missing_dataset_error = \
            '''Class variable `{0}` is None. If the model was loaded from a checkpoint, the dataset must be set manually:

            model = {1}.load_from_checkpoint('checkpoint.ckpt')
            model.{0} = my_{0}
            '''

    def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb):
        embeddings = []
        if 'poi' not in self.disabled_embeddings:
            embeddings.append(self.poi_emb(poi_emb))
        if 'e3' not in self.disabled_embeddings:
            embeddings.append(self.e3_emb(e3_emb))
        if 'cell' not in self.disabled_embeddings:
            embeddings.append(self.cell_emb(cell_emb))
        if 'smiles' not in self.disabled_embeddings:
            embeddings.append(self.smiles_emb(smiles_emb))
        x = torch.cat(embeddings, dim=1)
        x = self.dropout(F.gelu(self.fc1(x)))
        x = self.dropout(F.gelu(self.fc2(x)))
        x = self.fc3(x)
        return x

    def step(self, batch, batch_idx, stage):
        poi_emb = batch['poi_emb']
        e3_emb = batch['e3_emb']
        cell_emb = batch['cell_emb']
        smiles_emb = batch['smiles_emb']
        y = batch['active'].float().unsqueeze(1)

        y_hat = self.forward(poi_emb, e3_emb, cell_emb, smiles_emb)
        loss = F.binary_cross_entropy_with_logits(y_hat, y)

        self.metrics[f'{stage}_metrics'].update(y_hat, y)
        self.log(f'{stage}_loss', loss, on_epoch=True, prog_bar=True)
        self.log_dict(self.metrics[f'{stage}_metrics'], on_epoch=True)

        return loss

    def training_step(self, batch, batch_idx):
        return self.step(batch, batch_idx, 'train')

    def validation_step(self, batch, batch_idx):
        return self.step(batch, batch_idx, 'val')

    def test_step(self, batch, batch_idx):
        return self.step(batch, batch_idx, 'test')

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.learning_rate)

    def predict_step(self, batch, batch_idx):
        poi_emb = batch['poi_emb']
        e3_emb = batch['e3_emb']
        cell_emb = batch['cell_emb']
        smiles_emb = batch['smiles_emb']

        y_hat = self.forward(poi_emb, e3_emb, cell_emb, smiles_emb)
        return torch.sigmoid(y_hat)

    def train_dataloader(self):
        if self.train_dataset is None:
            format = 'train_dataset', self.__class__.__name__
            raise ValueError(self.missing_dataset_error.format(*format))
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            # drop_last=True,
        )

    def val_dataloader(self):
        if self.val_dataset is None:
            format = 'val_dataset', self.__class__.__name__
            raise ValueError(self.missing_dataset_error.format(*format))
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
        )

    def test_dataloader(self):
        if self.test_dataset is None:
            format = 'test_dataset', self.__class__.__name__
            raise ValueError(self.missing_dataset_error.format(*format))
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
        )

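# %% [markdown]
# A quick shape sanity check of the model (illustrative sketch, not part of the
# original pipeline): one forward pass on random tensors with the default embedding
# sizes should return a single logit per sample. The `_dummy_*` names are
# placeholders introduced here only.

# %%
_dummy_model = PROTAC_Model(hidden_dim=256)
with torch.no_grad():
    _dummy_logits = _dummy_model(
        torch.randn(2, 1024),  # POI embedding
        torch.randn(2, 1024),  # E3 ligase embedding
        torch.randn(2, 768),   # cell line embedding
        torch.randn(2, 1024),  # SMILES fingerprint
    )
print(_dummy_logits.shape)  # expected: torch.Size([2, 1])
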
# %% [markdown]
# ## Test Sets

# %% [markdown]
# We want a different test set per Cross-Validation (CV) experiment (see further down). We are interested in three scenarios:
# * Randomly splitting the data into training and test sets. Hence, the test set shall contain unique SMILES and Uniprots
# * Splitting the data according to their Uniprot. Hence, the test set shall contain unique Uniprots
# * Splitting the data according to their SMILES, _i.e._, the test set shall contain unique SMILES

# %%
test_indeces = {}

# %% [markdown]
# Isolating the unique SMILES and Uniprots:

# %%
active_df = protac_df[protac_df[active_col].notna()].copy()

# Get the unique SMILES and Uniprot
unique_smiles = active_df['Smiles'].value_counts() == 1
unique_uniprot = active_df['Uniprot'].value_counts() == 1
print(f'Number of unique SMILES: {unique_smiles.sum()}')
print(f'Number of unique Uniprot: {unique_uniprot.sum()}')
# Sample 5% of len(active_df) from the unique SMILES and Uniprot and get the
# indices for a test set
n = int(0.05 * len(active_df)) // 2
unique_smiles = unique_smiles[unique_smiles].sample(n=n, random_state=42)
# unique_uniprot = unique_uniprot[unique_uniprot].sample(n=, random_state=42)
unique_indices = active_df[
    active_df['Smiles'].isin(unique_smiles.index) &
    active_df['Uniprot'].isin(unique_uniprot.index)
].index
print(f'Number of unique indices: {len(unique_indices)} ({len(unique_indices) / len(active_df):.1%})')

test_indeces['random'] = unique_indices

# # Get the test set
# test_df = active_df.loc[unique_indices]
# # Bar plot of the test Active distribution as percentage
# test_df['Active'].value_counts(normalize=True).plot(kind='bar')
# plt.title('Test set Active distribution')
# plt.show()
# # Bar plot of the test Active - OR distribution as percentage
# test_df['Active - OR'].value_counts(normalize=True).plot(kind='bar')
# plt.title('Test set Active - OR distribution')
# plt.show()

# %% [markdown]
# Isolating the unique Uniprots:

# %%
active_df = protac_df[protac_df[active_col].notna()].copy()

unique_uniprot = active_df['Uniprot'].value_counts() == 1
print(f'Number of unique Uniprot: {unique_uniprot.sum()}')

# NOTE: Since they are very few, all unique Uniprots will be used as the test set.
# Get the indices for a test set
unique_indices = active_df[active_df['Uniprot'].isin(unique_uniprot.index)].index


test_indeces['uniprot'] = unique_indices
print(f'Number of unique indices: {len(unique_indices)} ({len(unique_indices) / len(active_df):.1%})')

# %% [markdown]
# DEPRECATED: the following selection is too restrictive. Before starting any training, we isolate a small group of test data. Each element in the test set is selected so that all the following conditions are met:
# * its SMILES is unique
# * its POI is unique
# * its (SMILES, POI) pair is unique

# %%
active_df = protac_df[protac_df[active_col].notna()]

# Find the samples that:
# * have their SMILES appearing only once in the dataframe
# * have their Uniprot appearing only once in the dataframe
# * have their (Smiles, Uniprot) pair appearing only once in the dataframe
unique_smiles = active_df['Smiles'].value_counts() == 1
unique_uniprot = active_df['Uniprot'].value_counts() == 1
unique_smiles_uniprot = active_df.groupby(['Smiles', 'Uniprot']).size() == 1

# Get the indices of the unique samples
unique_smiles_idx = active_df['Smiles'].map(unique_smiles)
unique_uniprot_idx = active_df['Uniprot'].map(unique_uniprot)
unique_smiles_uniprot_idx = active_df.set_index(['Smiles', 'Uniprot']).index.map(unique_smiles_uniprot)

# Cross the indices to get the unique samples
# unique_samples = active_df[unique_smiles_idx & unique_uniprot_idx & unique_smiles_uniprot_idx].index
unique_samples = active_df[unique_smiles_idx & unique_uniprot_idx].index
test_df = active_df.loc[unique_samples]

warnings.filterwarnings("ignore", ".*FixedLocator*")

# %% [markdown]
# ## Cross-Validation Training

# %% [markdown]
# Cross-validation training with 5 splits. The split operation is done in three different ways:
#
# * Random split
# * POI-wise: a given POI never appears in both splits
# * Least Tanimoto similarity PROTAC-wise

# %% [markdown]
# ### Plotting CV Folds

# %%
from sklearn.model_selection import (
    StratifiedKFold,
    StratifiedGroupKFold,
)
from sklearn.preprocessing import OrdinalEncoder

# NOTE: When set to 60, it will result in 29 groups, with nice distributions of
# the number of unique groups in the train and validation sets, together with
# the number of active and inactive PROTACs.
n_bins_tanimoto = 60 if active_col == 'Active' else 400
n_splits = 5
# The train and validation sets will be created from the active PROTACs only,
# i.e., the ones with 'Active' column not NaN, and that are NOT in the test set
active_df = protac_df[protac_df[active_col].notna()]
train_val_df = active_df[~active_df.index.isin(test_df.index)].copy()

# Make three groups for CV:
# * Random split
# * Split by Uniprot (POI)
# * Split by least Tanimoto similarity PROTAC-wise
groups = [
    'random',
    'uniprot',
    'tanimoto',
]
for group_type in groups:
    if group_type == 'random':
        kf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
        groups = None
    elif group_type == 'uniprot':
        # Split by Uniprot
        kf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=42)
        encoder = OrdinalEncoder()
        groups = encoder.fit_transform(train_val_df['Uniprot'].values.reshape(-1, 1))
        print(f'Number of unique groups: {len(encoder.categories_[0])}')
    elif group_type == 'tanimoto':
        # Split by Tanimoto similarity, i.e., group PROTACs with similar Avg Tanimoto
        kf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=42)
        tanimoto_groups = pd.cut(train_val_df['Avg Tanimoto'], bins=n_bins_tanimoto).copy()
        encoder = OrdinalEncoder()
        groups = encoder.fit_transform(tanimoto_groups.values.reshape(-1, 1))
        print(f'Number of unique groups: {len(encoder.categories_[0])}')

    X = train_val_df.drop(columns=active_col)
    y = train_val_df[active_col].tolist()

    # print(f'Group: {group_type}')
    # fig, ax = plt.subplots(figsize=(6, 3))
    # plot_cv_indices(kf, X=X, y=y, group=groups, ax=ax, n_splits=n_splits)
    # plt.tight_layout()
    # plt.show()

    stats = []
    for k, (train_index, val_index) in enumerate(kf.split(X, y, groups)):
        train_df = train_val_df.iloc[train_index]
        val_df = train_val_df.iloc[val_index]
        stat = {
            'fold': k,
            'train_len': len(train_df),
            'val_len': len(val_df),
            'train_perc': len(train_df) / len(train_val_df),
            'val_perc': len(val_df) / len(train_val_df),
            'train_active (%)': train_df[active_col].sum() / len(train_df) * 100,
            'train_inactive (%)': (len(train_df) - train_df[active_col].sum()) / len(train_df) * 100,
            'val_active (%)': val_df[active_col].sum() / len(val_df) * 100,
            'val_inactive (%)': (len(val_df) - val_df[active_col].sum()) / len(val_df) * 100,
            'num_leaking_uniprot': len(set(train_df['Uniprot']).intersection(set(val_df['Uniprot']))),
            'num_leaking_smiles': len(set(train_df['Smiles']).intersection(set(val_df['Smiles']))),
        }
        if group_type != 'random':
            stat['train_unique_groups'] = len(np.unique(groups[train_index]))
            stat['val_unique_groups'] = len(np.unique(groups[val_index]))
        stats.append(stat)
    print('-' * 120)

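# %% [markdown]
# (Sketch) `stats` now holds the per-fold summary for the last split strategy in the
# loop above; turning it into a dataframe gives a quick view of class balance and
# group leakage per fold. Purely for inspection, not used further below.

# %%
pd.DataFrame(stats)
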
# %% [markdown]
# ### Run CV

# %%
import warnings

# Seed everything in pytorch lightning
pl.seed_everything(42)


def train_model(
    train_df,
    val_df,
    test_df=None,
    hidden_dim=768,
    batch_size=8,
    learning_rate=2e-5,
    max_epochs=50,
    smiles_emb_dim=1024,
    smote_n_neighbors=5,
    use_ored_activity=False if active_col == 'Active' else True,
    fast_dev_run=False,
    disabled_embeddings=[],
) -> tuple:
    """ Train a PROTAC model using the given datasets and hyperparameters.

    Args:
        train_df (pd.DataFrame): The training set.
        val_df (pd.DataFrame): The validation set.
        test_df (pd.DataFrame): The test set.
        hidden_dim (int): The hidden dimension of the model.
        batch_size (int): The batch size.
        learning_rate (float): The learning rate.
        max_epochs (int): The maximum number of epochs.
        smiles_emb_dim (int): The dimension of the SMILES embeddings.
        smote_n_neighbors (int): The number of neighbors for the SMOTE oversampler.
        use_ored_activity (bool): Whether to use the ORed activity column.
        fast_dev_run (bool): Whether to run a fast development run.
        disabled_embeddings (list): The list of disabled embeddings.

    Returns:
        tuple: The trained model, the trainer, and the metrics.
    """
    oversampler = SMOTE(k_neighbors=smote_n_neighbors, random_state=42)
    train_ds = PROTAC_Dataset(
        train_df,
        protein_embeddings,
        cell2embedding,
        smiles2fp,
        use_smote=True,
        oversampler=oversampler,
        use_ored_activity=use_ored_activity,
    )
    val_ds = PROTAC_Dataset(
        val_df,
        protein_embeddings,
        cell2embedding,
        smiles2fp,
        use_ored_activity=use_ored_activity,
    )
    if test_df is not None:
        test_ds = PROTAC_Dataset(
            test_df,
            protein_embeddings,
            cell2embedding,
            smiles2fp,
            use_ored_activity=use_ored_activity,
        )
    logger = pl.loggers.TensorBoardLogger(
        save_dir='../logs',
        name='protac',
    )
    callbacks = [
        pl.callbacks.EarlyStopping(
            monitor='train_loss',
            patience=10,
            mode='min',  # stop when the monitored loss stops decreasing
            verbose=True,
        ),
        # pl.callbacks.ModelCheckpoint(
        #     monitor='val_acc',
        #     mode='max',
        #     verbose=True,
        #     filename='{epoch}-{val_metrics_opt_score:.4f}',
        # ),
    ]
    # Define Trainer
    trainer = pl.Trainer(
        logger=logger,
        callbacks=callbacks,
        max_epochs=max_epochs,
        fast_dev_run=fast_dev_run,
        enable_model_summary=False,
        enable_checkpointing=False,
    )
    model = PROTAC_Model(
        hidden_dim=hidden_dim,
        smiles_emb_dim=smiles_emb_dim,
        poi_emb_dim=1024,
        e3_emb_dim=1024,
        cell_emb_dim=768,
        batch_size=batch_size,
        learning_rate=learning_rate,
        train_dataset=train_ds,
        val_dataset=val_ds,
        test_dataset=test_ds if test_df is not None else None,
        disabled_embeddings=disabled_embeddings,
    )
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        trainer.fit(model)
    metrics = trainer.validate(model, verbose=False)[0]
    if test_df is not None:
        test_metrics = trainer.test(model, verbose=False)[0]
        metrics.update(test_metrics)
    return model, trainer, metrics

# %% [markdown]
# Set up hyperparameter optimization:

# %%
import optuna
import pandas as pd


def objective(
    trial,
    train_df,
    val_df,
    hidden_dim_options,
    batch_size_options,
    learning_rate_options,
    max_epochs_options,
    fast_dev_run=False,
) -> float:
    # Generate the hyperparameters
    hidden_dim = trial.suggest_categorical('hidden_dim', hidden_dim_options)
    batch_size = trial.suggest_categorical('batch_size', batch_size_options)
    learning_rate = trial.suggest_loguniform('learning_rate', *learning_rate_options)
    max_epochs = trial.suggest_categorical('max_epochs', max_epochs_options)

    # Train the model with the current set of hyperparameters
    _, _, metrics = train_model(
        train_df,
        val_df,
        hidden_dim=hidden_dim,
        batch_size=batch_size,
        learning_rate=learning_rate,
        max_epochs=max_epochs,
        fast_dev_run=fast_dev_run,
    )

    # Metrics is a dictionary containing at least the validation loss
    val_loss = metrics['val_loss']
    val_acc = metrics['val_acc']
    val_roc_auc = metrics['val_roc_auc']

    # Optuna aims to minimize the objective
    return val_loss - val_acc - val_roc_auc


def hyperparameter_tuning_and_training(
    train_df,
    val_df,
    test_df,
    fast_dev_run=False,
    n_trials=20,
) -> tuple:
    """ Hyperparameter tuning and training of a PROTAC model.

    Args:
        train_df (pd.DataFrame): The training set.
        val_df (pd.DataFrame): The validation set.
        test_df (pd.DataFrame): The test set.
        fast_dev_run (bool): Whether to run a fast development run.

    Returns:
        tuple: The trained model, the trainer, and the best metrics.
    """
    # Define the search space
    hidden_dim_options = [256, 512, 768]
    batch_size_options = [8, 16, 32]
    learning_rate_options = (1e-5, 1e-3)  # min and max values for loguniform distribution
    max_epochs_options = [10, 20, 50]

    # Create an Optuna study object
    study = optuna.create_study(direction='minimize')
    study.optimize(
        lambda trial: objective(
            trial,
            train_df,
            val_df,
            hidden_dim_options,
            batch_size_options,
            learning_rate_options,
            max_epochs_options,
            fast_dev_run=fast_dev_run,
        ),
        n_trials=n_trials,
    )

    # Retrieve the best hyperparameters
    best_params = study.best_params
    best_hidden_dim = best_params['hidden_dim']
    best_batch_size = best_params['batch_size']
    best_learning_rate = best_params['learning_rate']
    best_max_epochs = best_params['max_epochs']

    # Retrain the model with the best hyperparameters
    model, trainer, metrics = train_model(
        train_df,
        val_df,
        test_df,
        hidden_dim=best_hidden_dim,
        batch_size=best_batch_size,
        learning_rate=best_learning_rate,
        max_epochs=best_max_epochs,
        fast_dev_run=fast_dev_run,
    )

    # Return the best metrics
    return model, trainer, metrics

# Example usage
# train_df, val_df, test_df = load_your_data()  # You need to load your datasets here
# model, trainer, best_metrics = hyperparameter_tuning_and_training(train_df, val_df, test_df)

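# %% [markdown]
# (Optional sketch) Before launching the full cross-validation with hyperparameter
# search below, a single `fast_dev_run` call to `train_model` is a cheap end-to-end
# check of the pipeline. The `_smoke_*` names and the 80/20 random split are
# illustrative placeholders, not part of the actual experiments.

# %%
# _smoke_df = protac_df[protac_df[active_col].notna()]
# _smoke_train = _smoke_df.sample(frac=0.8, random_state=42)
# _smoke_val = _smoke_df.drop(_smoke_train.index)
# _, _, _smoke_metrics = train_model(_smoke_train, _smoke_val, fast_dev_run=True)
# print(_smoke_metrics)
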
# %% [markdown]
# Loop over the different splits and train the model:

# %%
n_splits = 5
report = []
active_df = protac_df[protac_df[active_col].notna()]
train_val_df = active_df[~active_df.index.isin(unique_samples)]

# Make directory ../reports if it does not exist
if not os.path.exists('../reports'):
    os.makedirs('../reports')

for group_type in ['random', 'uniprot', 'tanimoto']:
    print(f'Starting CV for group type: {group_type}')
    # Setup CV iterator and groups
    if group_type == 'random':
        kf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
        groups = None
    elif group_type == 'uniprot':
        # Split by Uniprot
        kf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=42)
        encoder = OrdinalEncoder()
        groups = encoder.fit_transform(train_val_df['Uniprot'].values.reshape(-1, 1))
    elif group_type == 'tanimoto':
        # Split by Tanimoto similarity, i.e., group PROTACs with similar Avg Tanimoto
        kf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=42)
        tanimoto_groups = pd.cut(train_val_df['Avg Tanimoto'], bins=n_bins_tanimoto).copy()
        encoder = OrdinalEncoder()
        groups = encoder.fit_transform(tanimoto_groups.values.reshape(-1, 1))
    # Start the CV over the folds
    X = train_val_df.drop(columns=active_col)
    y = train_val_df[active_col].tolist()
    for k, (train_index, val_index) in enumerate(kf.split(X, y, groups)):
        train_df = train_val_df.iloc[train_index]
        val_df = train_val_df.iloc[val_index]
        stats = {
            'fold': k,
            'group_type': group_type,
            'train_len': len(train_df),
            'val_len': len(val_df),
            'train_perc': len(train_df) / len(train_val_df),
            'val_perc': len(val_df) / len(train_val_df),
            'train_active_perc': train_df[active_col].sum() / len(train_df),
            'train_inactive_perc': (len(train_df) - train_df[active_col].sum()) / len(train_df),
            'val_active_perc': val_df[active_col].sum() / len(val_df),
            'val_inactive_perc': (len(val_df) - val_df[active_col].sum()) / len(val_df),
            'test_active_perc': test_df[active_col].sum() / len(test_df),
            'test_inactive_perc': (len(test_df) - test_df[active_col].sum()) / len(test_df),
            'num_leaking_uniprot': len(set(train_df['Uniprot']).intersection(set(val_df['Uniprot']))),
            'num_leaking_smiles': len(set(train_df['Smiles']).intersection(set(val_df['Smiles']))),
        }
        if group_type != 'random':
            stats['train_unique_groups'] = len(np.unique(groups[train_index]))
            stats['val_unique_groups'] = len(np.unique(groups[val_index]))
        # Train and evaluate the model
        # model, trainer, metrics = train_model(train_df, val_df, test_df)
        model, trainer, metrics = hyperparameter_tuning_and_training(
            train_df,
            val_df,
            test_df,
            fast_dev_run=False,
            n_trials=50,
        )
        stats.update(metrics)
        del model
        del trainer
        report.append(stats)
report = pd.DataFrame(report)
report.to_csv(
    f'../reports/cv_report_hparam_search_{n_splits}-splits.csv', index=False,
)