Spaces: saicharan2804 / molgenevalmetric · Sleeping
saicharan2804 committed
Commit · 59500aa
1 Parent(s): 0f253ff

Fixed error
Browse files
- app.py +0 -50
- molgenevalmetric.py +338 -1
app.py
CHANGED
@@ -1,56 +1,6 @@
 import evaluate
 from evaluate.utils import launch_gradio_widget
 import gradio as gr
-# from pathlib import Path
-# import sys
-# import os
-
-# from .logging import get_logger
-# logger = get_logger(__name__)
-
-# ###
-# def launch_gradio_widget(metric):
-#     """Launches `metric` widget with Gradio."""
-
-#     try:
-#         import gradio as gr
-#     except ImportError as error:
-#         logger.error("To create a metric widget with Gradio make sure gradio is installed.")
-#         raise error
-
-#     local_path = Path(sys.path[0])
-#     # if there are several input types, use first as default.
-#     if isinstance(metric.features, list):
-#         (feature_names, feature_types) = zip(*metric.features[0].items())
-#     else:
-#         (feature_names, feature_types) = zip(*metric.features.items())
-#     gradio_input_types = infer_gradio_input_types(feature_types)
-
-#     def compute(data):
-#         return metric.compute(**parse_gradio_data(data, gradio_input_types))
-
-#     iface = gr.Interface(
-#         fn=compute,
-#         inputs=gr.Dataframe(
-#             headers=feature_names,
-#             col_count=len(feature_names),
-#             row_count=1,
-#             datatype=json_to_string_type(gradio_input_types),
-#         ),
-#         outputs=gr.Textbox(label=metric.name),
-#         description=(
-#             metric.info.description + "\nIf this is a text-based metric, make sure to wrap you input in double quotes."
-#             " Alternatively you can use a JSON-formatted list as input."
-#         ),
-#         title=f"Metric: {metric.name}",
-#         article=parse_readme(local_path / "README.md"),
-#         # TODO: load test cases and use them to populate examples
-#         # examples=[parse_test_cases(test_cases, feature_names, gradio_input_types)]
-#     )
-
-#     iface.launch()
-# ###
-
 
 module = evaluate.load("saicharan2804/molgenevalmetric")
 # launch_gradio_widget(module)
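With the widget helper commented out, app.py now only loads the metric module. As a rough sketch of how the loaded module would be exercised from Python (the keyword names gen and train below are assumptions that mirror the parameter names used throughout molgenevalmetric.py; the actual input names are defined by the metric's feature specification):

import evaluate

# Load the metric implemented by this Space.
module = evaluate.load("saicharan2804/molgenevalmetric")

# compute() forwards keyword arguments to the metric's _compute();
# gen/train are assumed names -- check the metric's features for the real ones.
results = module.compute(
    gen=["CCO", "c1ccccc1", "CC(=O)O"],  # generated SMILES
    train=["CCN", "c1ccncc1"],           # reference / training SMILES
)
print(results)
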
molgenevalmetric.py
CHANGED
@@ -5,7 +5,344 @@ import datasets
 import pandas as pd
 # from tdc import Evaluator
 # from tdc import Oracle
-from metrics import novelty, fraction_valid, fraction_unique, SAscore, internal_diversity,fcd_metric, SYBAscore, oracles
+# from metrics import novelty, fraction_valid, fraction_unique, SAscore, internal_diversity,fcd_metric, SYBAscore, oracles
+
+import os
+from collections import Counter
+from functools import partial
+import numpy as np
+import pandas as pd
+import scipy.sparse
+import torch
+from rdkit import Chem
+from rdkit.Chem import AllChem
+from rdkit.Chem import MACCSkeys
+from rdkit.Chem.AllChem import GetMorganFingerprintAsBitVect as Morgan
+from rdkit.Chem.QED import qed
+from rdkit.Chem.Scaffolds import MurckoScaffold
+from rdkit.Chem import Descriptors
+
+import random
+from multiprocessing import Pool
+from collections import UserList, defaultdict
+import numpy as np
+import pandas as pd
+from rdkit import rdBase
+import sys
+
+from rdkit.Chem import RDConfig
+import os
+sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
+import sascorer
+import pandas as pd
+from fcd_torch import FCD
+from syba.syba import SybaClassifier
+
+from tdc import Evaluator
+from tdc import Oracle
+
+
+def get_mol(smiles_or_mol):
+    '''
+    Loads SMILES/molecule into RDKit's object
+    '''
+    if isinstance(smiles_or_mol, str):
+        if len(smiles_or_mol) == 0:
+            return None
+        mol = Chem.MolFromSmiles(smiles_or_mol)
+        if mol is None:
+            return None
+        try:
+            Chem.SanitizeMol(mol)
+        except ValueError:
+            return None
+        return mol
+    return smiles_or_mol
+
+def mapper(n_jobs):
+    '''
+    Returns function for map call.
+    If n_jobs == 1, will use standard map
+    If n_jobs > 1, will use multiprocessing pool
+    If n_jobs is a pool object, will return its map function
+    '''
+    if n_jobs == 1:
+        def _mapper(*args, **kwargs):
+            return list(map(*args, **kwargs))
+
+        return _mapper
+    if isinstance(n_jobs, int):
+        pool = Pool(n_jobs)
+
+        def _mapper(*args, **kwargs):
+            try:
+                result = pool.map(*args, **kwargs)
+            finally:
+                pool.terminate()
+            return result
+
+        return _mapper
+    return n_jobs.map
+
+def fraction_valid(gen, n_jobs=1):
+    """
+    Computes a number of valid molecules
+    Parameters:
+        gen: list of SMILES
+        n_jobs: number of threads for calculation
+    """
+    gen = mapper(n_jobs)(get_mol, gen)
+    return 1 - gen.count(None) / len(gen)
+
+
+def canonic_smiles(smiles_or_mol):
+    mol = get_mol(smiles_or_mol)
+    if mol is None:
+        return None
+    return Chem.MolToSmiles(mol)
+
+def fraction_unique(gen, k=None, n_jobs=1, check_validity=True):
+    """
+    Computes a number of unique molecules
+    Parameters:
+        gen: list of SMILES
+        k: compute unique@k
+        n_jobs: number of threads for calculation
+        check_validity: raises ValueError if invalid molecules are present
+    """
+    if k is not None:
+        if len(gen) < k:
+            warnings.warn(
+                "Can't compute unique@{}.".format(k) +
+                "gen contains only {} molecules".format(len(gen))
+            )
+        gen = gen[:k]
+    canonic = set(mapper(n_jobs)(canonic_smiles, gen))
+
+    if None in canonic and check_validity:
+        raise ValueError("Invalid molecule passed to unique@k")
+    return len(canonic) / len(gen)
+
+def novelty(gen, train, n_jobs=1):
+    gen_smiles = mapper(n_jobs)(canonic_smiles, gen)
+    gen_smiles_set = set(gen_smiles) - {None}
+    train_set = set(train)
+    return len(gen_smiles_set - train_set) / len(gen_smiles_set)
+
+
+def SAscore(gen):
+    """
+    Calculate the average Synthetic Accessibility Score (SAscore) for a list of molecules represented by their SMILES strings.
+
+    Parameters:
+    - smiles_list (list of str): A list containing the SMILES representations of the molecules.
+
+    Returns:
+    - float: The average Synthetic Accessibility Score for the valid molecules in the list. Returns None if no valid molecules are found.
+    """
+    scores = []
+    for smiles in gen:
+        mol = Chem.MolFromSmiles(smiles)
+        if mol:  # Ensures the molecule could be parsed from the SMILES string
+            score = sascorer.calculateScore(mol)
+            scores.append(score)
+
+    if scores:  # Checks if there are any scores calculated
+        return np.mean(scores)
+    else:
+        return None
+
+
+
+def average_agg_tanimoto(stock_vecs, gen_vecs,
+                         batch_size=5000, agg='max',
+                         device='cpu', p=1):
+    """
+    For each molecule in gen_vecs finds closest molecule in stock_vecs.
+    Returns average tanimoto score for between these molecules
+
+    Parameters:
+        stock_vecs: numpy array <n_vectors x dim>
+        gen_vecs: numpy array <n_vectors' x dim>
+        agg: max or mean
+        p: power for averaging: (mean x^p)^(1/p)
+    """
+    assert agg in ['max', 'mean'], "Can aggregate only max or mean"
+    agg_tanimoto = np.zeros(len(gen_vecs))
+    total = np.zeros(len(gen_vecs))
+    for j in range(0, stock_vecs.shape[0], batch_size):
+        x_stock = torch.tensor(stock_vecs[j:j + batch_size]).to(device).float()
+        for i in range(0, gen_vecs.shape[0], batch_size):
+            y_gen = torch.tensor(gen_vecs[i:i + batch_size]).to(device).float()
+            y_gen = y_gen.transpose(0, 1)
+            tp = torch.mm(x_stock, y_gen)
+            jac = (tp / (x_stock.sum(1, keepdim=True) +
+                         y_gen.sum(0, keepdim=True) - tp)).cpu().numpy()
+            jac[np.isnan(jac)] = 1
+            if p != 1:
+                jac = jac**p
+            if agg == 'max':
+                agg_tanimoto[i:i + y_gen.shape[1]] = np.maximum(
+                    agg_tanimoto[i:i + y_gen.shape[1]], jac.max(0))
+            elif agg == 'mean':
+                agg_tanimoto[i:i + y_gen.shape[1]] += jac.sum(0)
+                total[i:i + y_gen.shape[1]] += jac.shape[0]
+    if agg == 'mean':
+        agg_tanimoto /= total
+    if p != 1:
+        agg_tanimoto = (agg_tanimoto)**(1/p)
+    return np.mean(agg_tanimoto)
+
+def fingerprint(smiles_or_mol, fp_type='maccs', dtype=None, morgan__r=2,
+                morgan__n=1024, *args, **kwargs):
+    """
+    Generates fingerprint for SMILES
+    If smiles is invalid, returns None
+    Returns numpy array of fingerprint bits
+
+    Parameters:
+        smiles: SMILES string
+        type: type of fingerprint: [MACCS|morgan]
+        dtype: if not None, specifies the dtype of returned array
+    """
+    fp_type = fp_type.lower()
+    molecule = get_mol(smiles_or_mol, *args, **kwargs)
+    if molecule is None:
+        return None
+    if fp_type == 'maccs':
+        keys = MACCSkeys.GenMACCSKeys(molecule)
+        keys = np.array(keys.GetOnBits())
+        fingerprint = np.zeros(166, dtype='uint8')
+        if len(keys) != 0:
+            fingerprint[keys - 1] = 1  # We drop 0-th key that is always zero
+    elif fp_type == 'morgan':
+        fingerprint = np.asarray(Morgan(molecule, morgan__r, nBits=morgan__n),
+                                 dtype='uint8')
+    else:
+        raise ValueError("Unknown fingerprint type {}".format(fp_type))
+    if dtype is not None:
+        fingerprint = fingerprint.astype(dtype)
+    return fingerprint
+
+
+def fingerprints(smiles_mols_array, n_jobs=1, already_unique=False, *args,
+                 **kwargs):
+    '''
+    Computes fingerprints of smiles np.array/list/pd.Series with n_jobs workers
+    e.g.fingerprints(smiles_mols_array, type='morgan', n_jobs=10)
+    Inserts np.NaN to rows corresponding to incorrect smiles.
+    IMPORTANT: if there is at least one np.NaN, the dtype would be float
+    Parameters:
+        smiles_mols_array: list/array/pd.Series of smiles or already computed
+            RDKit molecules
+        n_jobs: number of parralel workers to execute
+        already_unique: flag for performance reasons, if smiles array is big
+            and already unique. Its value is set to True if smiles_mols_array
+            contain RDKit molecules already.
+    '''
+    if isinstance(smiles_mols_array, pd.Series):
+        smiles_mols_array = smiles_mols_array.values
+    else:
+        smiles_mols_array = np.asarray(smiles_mols_array)
+    if not isinstance(smiles_mols_array[0], str):
+        already_unique = True
+
+    if not already_unique:
+        smiles_mols_array, inv_index = np.unique(smiles_mols_array,
+                                                 return_inverse=True)
+
+    fps = mapper(n_jobs)(
+        partial(fingerprint, *args, **kwargs), smiles_mols_array
+    )
+
+    length = 1
+    for fp in fps:
+        if fp is not None:
+            length = fp.shape[-1]
+            first_fp = fp
+            break
+    fps = [fp if fp is not None else np.array([np.NaN]).repeat(length)[None, :]
+           for fp in fps]
+    if scipy.sparse.issparse(first_fp):
+        fps = scipy.sparse.vstack(fps).tocsr()
+    else:
+        fps = np.vstack(fps)
+    if not already_unique:
+        return fps[inv_index]
+    return fps
+
+def internal_diversity(gen, n_jobs=1, device='cpu', fp_type='morgan',
+                       gen_fps=None, p=1):
+    """
+    Computes internal diversity as:
+    1/|A|^2 sum_{x, y in AxA} (1-tanimoto(x, y))
+    """
+    if gen_fps is None:
+        gen_fps = fingerprints(gen, fp_type=fp_type, n_jobs=n_jobs)
+    return 1 - (average_agg_tanimoto(gen_fps, gen_fps,
+                                     agg='mean', device=device, p=p)).mean()
+
+
+def fcd_metric(gen, train, n_jobs = 8, device = 'cuda:0'):
+    fcd = FCD(device=device, n_jobs= n_jobs)
+    return fcd(gen, train)
+
+def SYBAscore(gen):
+    """
+    Compute the average SYBA score for a list of SMILES strings.
+
+    Parameters:
+    - smiles_list (list of str): A list of SMILES strings representing molecules.
+
+    Returns:
+    - float: The average SYBA score for the list of molecules.
+    """
+    syba = SybaClassifier()
+    syba.fitDefaultScore()
+    scores = []
+
+    for smiles in gen:
+        try:
+            score = syba.predict(smi=smiles)
+            scores.append(score)
+        except Exception as e:
+            print(f"Error processing SMILES '{smiles}': {e}")
+            continue
+
+    if scores:
+        return sum(scores) / len(scores)
+    else:
+        return None  # Or handle empty list or all failed predictions as needed
+
+def oracles(gen, train):
+    Result = {}
+    # evaluator = Evaluator(name = 'KL_Divergence')
+    # KL_Divergence = evaluator(gen, train)
+
+    # Result["KL_Divergence"]: KL_Divergence
+
+
+    oracle_list = [
+        'QED', 'SA', 'MPO', 'GSK3B', 'JNK3',
+        'DRD2', 'LogP', 'Rediscovery', 'Similarity',
+        'Median', 'Isomers', 'Valsartan_SMARTS', 'Hop'
+    ]
+
+    for oracle_name in oracle_list:
+        oracle = Oracle(name=oracle_name)
+        if oracle_name in ['Rediscovery', 'MPO', 'Similarity', 'Median', 'Isomers', 'Hop']:
+            score = oracle(gen)
+            if isinstance(score, dict):
+                score = {key: sum(values)/len(values) for key, values in score.items()}
+        else:
+            score = oracle(gen)
+            if isinstance(score, list):
+                score = sum(score) / len(score)
+
+        Result[f"{oracle_name}"] = score
+
+    return Result
+
 
 
 _DESCRIPTION = """
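The helpers inlined above follow the MOSES-style metric utilities and can also be called directly, outside the evaluate wrapper. Below is a minimal usage sketch, assuming the module's top-level dependencies (rdkit, torch, fcd_torch, syba, PyTDC and RDKit's SA_Score contrib) are installed and that the file is importable as molgenevalmetric. Note that fraction_unique's unique@k warning path calls warnings.warn although the committed file never imports warnings, so that path would currently fail with a NameError.

from molgenevalmetric import (
    fraction_valid, fraction_unique, novelty, SAscore, internal_diversity,
)

gen = ["CCO", "CCO", "c1ccccc1", "not_a_smiles"]  # toy generated set
train = ["CCN", "CCO"]                            # toy reference set

# Share of parseable SMILES (3 of 4 here).
print(fraction_valid(gen))
# Unique canonical SMILES / total; validity check disabled for the toy set.
print(fraction_unique(gen, check_validity=False))
# Fraction of valid canonical SMILES not present in the reference set.
print(novelty(gen, train))
# Mean synthetic accessibility over the parseable molecules.
print(SAscore(gen))
# 1 - mean pairwise Tanimoto over Morgan fingerprints of the generated set.
print(internal_diversity(["CCO", "c1ccccc1", "CCN"]))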