saicharan2804 committed on
Commit 59500aa · 1 Parent(s): 0f253ff

Fixed error

Files changed (2):
  1. app.py +0 -50
  2. molgenevalmetric.py +338 -1
app.py CHANGED
@@ -1,56 +1,6 @@
  import evaluate
  from evaluate.utils import launch_gradio_widget
  import gradio as gr
- # from pathlib import Path
- # import sys
- # import os
-
- # from .logging import get_logger
- # logger = get_logger(__name__)
-
- # ###
- # def launch_gradio_widget(metric):
- #     """Launches `metric` widget with Gradio."""
-
- #     try:
- #         import gradio as gr
- #     except ImportError as error:
- #         logger.error("To create a metric widget with Gradio make sure gradio is installed.")
- #         raise error
-
- #     local_path = Path(sys.path[0])
- #     # if there are several input types, use first as default.
- #     if isinstance(metric.features, list):
- #         (feature_names, feature_types) = zip(*metric.features[0].items())
- #     else:
- #         (feature_names, feature_types) = zip(*metric.features.items())
- #     gradio_input_types = infer_gradio_input_types(feature_types)
-
- #     def compute(data):
- #         return metric.compute(**parse_gradio_data(data, gradio_input_types))
-
- #     iface = gr.Interface(
- #         fn=compute,
- #         inputs=gr.Dataframe(
- #             headers=feature_names,
- #             col_count=len(feature_names),
- #             row_count=1,
- #             datatype=json_to_string_type(gradio_input_types),
- #         ),
- #         outputs=gr.Textbox(label=metric.name),
- #         description=(
- #             metric.info.description + "\nIf this is a text-based metric, make sure to wrap you input in double quotes."
- #             " Alternatively you can use a JSON-formatted list as input."
- #         ),
- #         title=f"Metric: {metric.name}",
- #         article=parse_readme(local_path / "README.md"),
- #         # TODO: load test cases and use them to populate examples
- #         # examples=[parse_test_cases(test_cases, feature_names, gradio_input_types)]
- #     )
-
- #     iface.launch()
- # ###
-

  module = evaluate.load("saicharan2804/molgenevalmetric")
  # launch_gradio_widget(module)
 
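With this change, app.py is a thin wrapper that only loads the metric module. As a point of reference, a minimal sketch of how the loaded module could be exercised locally, assuming it exposes the standard `evaluate` compute() interface; the keyword names `gen` and `train` below are assumptions based on the helper functions added in molgenevalmetric.py and are not taken from this diff:

    import evaluate

    # Load the metric module from the Hub, exactly as app.py does.
    module = evaluate.load("saicharan2804/molgenevalmetric")

    # Hypothetical call: the argument names are assumed; check the module's
    # _compute() signature for the actual input feature names.
    results = module.compute(
        gen=["CCO", "c1ccccc1"],   # generated SMILES (toy examples)
        train=["CCO", "CCN"],      # reference/training SMILES (toy examples)
    )
    print(results)
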
molgenevalmetric.py CHANGED
@@ -5,7 +5,344 @@ import datasets
  import pandas as pd
  # from tdc import Evaluator
  # from tdc import Oracle
- from metrics import novelty, fraction_valid, fraction_unique, SAscore, internal_diversity,fcd_metric, SYBAscore, oracles
+ # from metrics import novelty, fraction_valid, fraction_unique, SAscore, internal_diversity,fcd_metric, SYBAscore, oracles
+
+ import os
+ import warnings  # needed by fraction_unique's warnings.warn call
+ from collections import Counter
+ from functools import partial
+ import numpy as np
+ import pandas as pd
+ import scipy.sparse
+ import torch
+ from rdkit import Chem
+ from rdkit.Chem import AllChem
+ from rdkit.Chem import MACCSkeys
+ from rdkit.Chem.AllChem import GetMorganFingerprintAsBitVect as Morgan
+ from rdkit.Chem.QED import qed
+ from rdkit.Chem.Scaffolds import MurckoScaffold
+ from rdkit.Chem import Descriptors
+
+ import random
+ from multiprocessing import Pool
+ from collections import UserList, defaultdict
+ import numpy as np
+ import pandas as pd
+ from rdkit import rdBase
+ import sys
+
+ from rdkit.Chem import RDConfig
+ import os
+ sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
+ import sascorer
+ import pandas as pd
+ from fcd_torch import FCD
+ from syba.syba import SybaClassifier
+
+ from tdc import Evaluator
+ from tdc import Oracle
+
+
+ def get_mol(smiles_or_mol):
+     '''
+     Loads SMILES/molecule into RDKit's object
+     '''
+     if isinstance(smiles_or_mol, str):
+         if len(smiles_or_mol) == 0:
+             return None
+         mol = Chem.MolFromSmiles(smiles_or_mol)
+         if mol is None:
+             return None
+         try:
+             Chem.SanitizeMol(mol)
+         except ValueError:
+             return None
+         return mol
+     return smiles_or_mol
+
+ def mapper(n_jobs):
+     '''
+     Returns function for map call.
+     If n_jobs == 1, will use standard map
+     If n_jobs > 1, will use multiprocessing pool
+     If n_jobs is a pool object, will return its map function
+     '''
+     if n_jobs == 1:
+         def _mapper(*args, **kwargs):
+             return list(map(*args, **kwargs))
+
+         return _mapper
+     if isinstance(n_jobs, int):
+         pool = Pool(n_jobs)
+
+         def _mapper(*args, **kwargs):
+             try:
+                 result = pool.map(*args, **kwargs)
+             finally:
+                 pool.terminate()
+             return result
+
+         return _mapper
+     return n_jobs.map
+
+ def fraction_valid(gen, n_jobs=1):
+     """
+     Computes the fraction of valid molecules
+     Parameters:
+         gen: list of SMILES
+         n_jobs: number of threads for calculation
+     """
+     gen = mapper(n_jobs)(get_mol, gen)
+     return 1 - gen.count(None) / len(gen)
+
+
+ def canonic_smiles(smiles_or_mol):
+     mol = get_mol(smiles_or_mol)
+     if mol is None:
+         return None
+     return Chem.MolToSmiles(mol)
+
+ def fraction_unique(gen, k=None, n_jobs=1, check_validity=True):
+     """
+     Computes the fraction of unique molecules
+     Parameters:
+         gen: list of SMILES
+         k: compute unique@k
+         n_jobs: number of threads for calculation
+         check_validity: raises ValueError if invalid molecules are present
+     """
+     if k is not None:
+         if len(gen) < k:
+             warnings.warn(
+                 "Can't compute unique@{}. ".format(k) +
+                 "gen contains only {} molecules".format(len(gen))
+             )
+         gen = gen[:k]
+     canonic = set(mapper(n_jobs)(canonic_smiles, gen))
+
+     if None in canonic and check_validity:
+         raise ValueError("Invalid molecule passed to unique@k")
+     return len(canonic) / len(gen)
+
+ def novelty(gen, train, n_jobs=1):
+     gen_smiles = mapper(n_jobs)(canonic_smiles, gen)
+     gen_smiles_set = set(gen_smiles) - {None}
+     train_set = set(train)
+     return len(gen_smiles_set - train_set) / len(gen_smiles_set)
+
+
+ def SAscore(gen):
+     """
+     Calculate the average Synthetic Accessibility Score (SAscore) for a list of molecules represented by their SMILES strings.
+
+     Parameters:
+     - gen (list of str): A list containing the SMILES representations of the molecules.
+
+     Returns:
+     - float: The average Synthetic Accessibility Score for the valid molecules in the list. Returns None if no valid molecules are found.
+     """
+     scores = []
+     for smiles in gen:
+         mol = Chem.MolFromSmiles(smiles)
+         if mol:  # Ensures the molecule could be parsed from the SMILES string
+             score = sascorer.calculateScore(mol)
+             scores.append(score)
+
+     if scores:  # Checks if there are any scores calculated
+         return np.mean(scores)
+     else:
+         return None
+
+
+
+ def average_agg_tanimoto(stock_vecs, gen_vecs,
+                          batch_size=5000, agg='max',
+                          device='cpu', p=1):
+     """
+     For each molecule in gen_vecs finds closest molecule in stock_vecs.
+     Returns average tanimoto score between these molecules
+
+     Parameters:
+         stock_vecs: numpy array <n_vectors x dim>
+         gen_vecs: numpy array <n_vectors' x dim>
+         agg: max or mean
+         p: power for averaging: (mean x^p)^(1/p)
+     """
+     assert agg in ['max', 'mean'], "Can aggregate only max or mean"
+     agg_tanimoto = np.zeros(len(gen_vecs))
+     total = np.zeros(len(gen_vecs))
+     for j in range(0, stock_vecs.shape[0], batch_size):
+         x_stock = torch.tensor(stock_vecs[j:j + batch_size]).to(device).float()
+         for i in range(0, gen_vecs.shape[0], batch_size):
+             y_gen = torch.tensor(gen_vecs[i:i + batch_size]).to(device).float()
+             y_gen = y_gen.transpose(0, 1)
+             tp = torch.mm(x_stock, y_gen)
+             jac = (tp / (x_stock.sum(1, keepdim=True) +
+                          y_gen.sum(0, keepdim=True) - tp)).cpu().numpy()
+             jac[np.isnan(jac)] = 1
+             if p != 1:
+                 jac = jac**p
+             if agg == 'max':
+                 agg_tanimoto[i:i + y_gen.shape[1]] = np.maximum(
+                     agg_tanimoto[i:i + y_gen.shape[1]], jac.max(0))
+             elif agg == 'mean':
+                 agg_tanimoto[i:i + y_gen.shape[1]] += jac.sum(0)
+                 total[i:i + y_gen.shape[1]] += jac.shape[0]
+     if agg == 'mean':
+         agg_tanimoto /= total
+     if p != 1:
+         agg_tanimoto = (agg_tanimoto)**(1/p)
+     return np.mean(agg_tanimoto)
+
+ def fingerprint(smiles_or_mol, fp_type='maccs', dtype=None, morgan__r=2,
+                 morgan__n=1024, *args, **kwargs):
+     """
+     Generates fingerprint for SMILES
+     If smiles is invalid, returns None
+     Returns numpy array of fingerprint bits
+
+     Parameters:
+         smiles: SMILES string
+         type: type of fingerprint: [MACCS|morgan]
+         dtype: if not None, specifies the dtype of returned array
+     """
+     fp_type = fp_type.lower()
+     molecule = get_mol(smiles_or_mol, *args, **kwargs)
+     if molecule is None:
+         return None
+     if fp_type == 'maccs':
+         keys = MACCSkeys.GenMACCSKeys(molecule)
+         keys = np.array(keys.GetOnBits())
+         fingerprint = np.zeros(166, dtype='uint8')
+         if len(keys) != 0:
+             fingerprint[keys - 1] = 1  # We drop 0-th key that is always zero
+     elif fp_type == 'morgan':
+         fingerprint = np.asarray(Morgan(molecule, morgan__r, nBits=morgan__n),
+                                  dtype='uint8')
+     else:
+         raise ValueError("Unknown fingerprint type {}".format(fp_type))
+     if dtype is not None:
+         fingerprint = fingerprint.astype(dtype)
+     return fingerprint
+
+
+ def fingerprints(smiles_mols_array, n_jobs=1, already_unique=False, *args,
+                  **kwargs):
+     '''
+     Computes fingerprints of smiles np.array/list/pd.Series with n_jobs workers
+     e.g. fingerprints(smiles_mols_array, type='morgan', n_jobs=10)
+     Inserts np.NaN to rows corresponding to incorrect smiles.
+     IMPORTANT: if there is at least one np.NaN, the dtype would be float
+     Parameters:
+         smiles_mols_array: list/array/pd.Series of smiles or already computed
+             RDKit molecules
+         n_jobs: number of parallel workers to execute
+         already_unique: flag for performance reasons, if smiles array is big
+             and already unique. Its value is set to True if smiles_mols_array
+             contain RDKit molecules already.
+     '''
+     if isinstance(smiles_mols_array, pd.Series):
+         smiles_mols_array = smiles_mols_array.values
+     else:
+         smiles_mols_array = np.asarray(smiles_mols_array)
+     if not isinstance(smiles_mols_array[0], str):
+         already_unique = True
+
+     if not already_unique:
+         smiles_mols_array, inv_index = np.unique(smiles_mols_array,
+                                                  return_inverse=True)
+
+     fps = mapper(n_jobs)(
+         partial(fingerprint, *args, **kwargs), smiles_mols_array
+     )
+
+     length = 1
+     for fp in fps:
+         if fp is not None:
+             length = fp.shape[-1]
+             first_fp = fp
+             break
+     fps = [fp if fp is not None else np.array([np.NaN]).repeat(length)[None, :]
+            for fp in fps]
+     if scipy.sparse.issparse(first_fp):
+         fps = scipy.sparse.vstack(fps).tocsr()
+     else:
+         fps = np.vstack(fps)
+     if not already_unique:
+         return fps[inv_index]
+     return fps
+
+ def internal_diversity(gen, n_jobs=1, device='cpu', fp_type='morgan',
+                        gen_fps=None, p=1):
+     """
+     Computes internal diversity as:
+     1/|A|^2 sum_{x, y in AxA} (1-tanimoto(x, y))
+     """
+     if gen_fps is None:
+         gen_fps = fingerprints(gen, fp_type=fp_type, n_jobs=n_jobs)
+     return 1 - (average_agg_tanimoto(gen_fps, gen_fps,
+                                      agg='mean', device=device, p=p)).mean()
+
+
+ def fcd_metric(gen, train, n_jobs=8, device='cuda:0'):
+     fcd = FCD(device=device, n_jobs=n_jobs)
+     return fcd(gen, train)
+
+ def SYBAscore(gen):
+     """
+     Compute the average SYBA score for a list of SMILES strings.
+
+     Parameters:
+     - gen (list of str): A list of SMILES strings representing molecules.
+
+     Returns:
+     - float: The average SYBA score for the list of molecules.
+     """
+     syba = SybaClassifier()
+     syba.fitDefaultScore()
+     scores = []
+
+     for smiles in gen:
+         try:
+             score = syba.predict(smi=smiles)
+             scores.append(score)
+         except Exception as e:
+             print(f"Error processing SMILES '{smiles}': {e}")
+             continue
+
+     if scores:
+         return sum(scores) / len(scores)
+     else:
+         return None  # Or handle empty list or all failed predictions as needed
+
+ def oracles(gen, train):
+     Result = {}
+     # evaluator = Evaluator(name='KL_Divergence')
+     # KL_Divergence = evaluator(gen, train)
+
+     # Result["KL_Divergence"] = KL_Divergence
+
+
+     oracle_list = [
+         'QED', 'SA', 'MPO', 'GSK3B', 'JNK3',
+         'DRD2', 'LogP', 'Rediscovery', 'Similarity',
+         'Median', 'Isomers', 'Valsartan_SMARTS', 'Hop'
+     ]
+
+     for oracle_name in oracle_list:
+         oracle = Oracle(name=oracle_name)
+         if oracle_name in ['Rediscovery', 'MPO', 'Similarity', 'Median', 'Isomers', 'Hop']:
+             score = oracle(gen)
+             if isinstance(score, dict):
+                 score = {key: sum(values)/len(values) for key, values in score.items()}
+         else:
+             score = oracle(gen)
+             if isinstance(score, list):
+                 score = sum(score) / len(score)
+
+         Result[f"{oracle_name}"] = score
+
+     return Result
+


  _DESCRIPTION = """
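
As a quick illustration of the standalone helpers added in this commit (not part of the diff itself), a minimal sketch assuming RDKit is installed and using toy SMILES inputs:

    # Toy inputs; "not_a_smiles" exercises the invalid-molecule path.
    gen = ["CCO", "c1ccccc1", "not_a_smiles"]
    train = ["CCO", "CCN"]

    print(fraction_valid(gen))                         # 2 of 3 parse -> ~0.667
    print(fraction_unique(gen, check_validity=False))  # unique canonical SMILES (invalid entries map to None)
    print(novelty(gen, train))                         # valid gen SMILES not in train -> 0.5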