saicharan2804 commited on
Commit
0f253ff
·
1 Parent(s): 389f170

Full implementations

Browse files
Files changed (5) hide show
  1. metrics.py +335 -0
  2. molgenevalmetric.py +42 -31
  3. requirements.txt +6 -2
  4. scscore +1 -0
  5. syba_test.py +34 -0
metrics.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from collections import Counter
3
+ from functools import partial
4
+ import numpy as np
5
+ import pandas as pd
6
+ import scipy.sparse
7
+ import torch
8
+ from rdkit import Chem
9
+ from rdkit.Chem import AllChem
10
+ from rdkit.Chem import MACCSkeys
11
+ from rdkit.Chem.AllChem import GetMorganFingerprintAsBitVect as Morgan
12
+ from rdkit.Chem.QED import qed
13
+ from rdkit.Chem.Scaffolds import MurckoScaffold
14
+ from rdkit.Chem import Descriptors
15
+
16
+ import random
17
+ from multiprocessing import Pool
18
+ from collections import UserList, defaultdict
19
+ import numpy as np
20
+ import pandas as pd
21
+ from rdkit import rdBase
22
+ import sys
23
+
24
+ from rdkit.Chem import RDConfig
25
+ import os
26
+ sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
27
+ import sascorer
28
+ import pandas as pd
29
+ from fcd_torch import FCD
30
+ from syba.syba import SybaClassifier
31
+
32
+ from tdc import Evaluator
33
+ from tdc import Oracle
34
+
35
+
36
+ def get_mol(smiles_or_mol):
37
+ '''
38
+ Loads SMILES/molecule into RDKit's object
39
+ '''
40
+ if isinstance(smiles_or_mol, str):
41
+ if len(smiles_or_mol) == 0:
42
+ return None
43
+ mol = Chem.MolFromSmiles(smiles_or_mol)
44
+ if mol is None:
45
+ return None
46
+ try:
47
+ Chem.SanitizeMol(mol)
48
+ except ValueError:
49
+ return None
50
+ return mol
51
+ return smiles_or_mol
52
+
53
+ def mapper(n_jobs):
54
+ '''
55
+ Returns function for map call.
56
+ If n_jobs == 1, will use standard map
57
+ If n_jobs > 1, will use multiprocessing pool
58
+ If n_jobs is a pool object, will return its map function
59
+ '''
60
+ if n_jobs == 1:
61
+ def _mapper(*args, **kwargs):
62
+ return list(map(*args, **kwargs))
63
+
64
+ return _mapper
65
+ if isinstance(n_jobs, int):
66
+ pool = Pool(n_jobs)
67
+
68
+ def _mapper(*args, **kwargs):
69
+ try:
70
+ result = pool.map(*args, **kwargs)
71
+ finally:
72
+ pool.terminate()
73
+ return result
74
+
75
+ return _mapper
76
+ return n_jobs.map
77
+
78
+ def fraction_valid(gen, n_jobs=1):
79
+ """
80
+ Computes a number of valid molecules
81
+ Parameters:
82
+ gen: list of SMILES
83
+ n_jobs: number of threads for calculation
84
+ """
85
+ gen = mapper(n_jobs)(get_mol, gen)
86
+ return 1 - gen.count(None) / len(gen)
87
+
88
+
89
+ def canonic_smiles(smiles_or_mol):
90
+ mol = get_mol(smiles_or_mol)
91
+ if mol is None:
92
+ return None
93
+ return Chem.MolToSmiles(mol)
94
+
95
+ def fraction_unique(gen, k=None, n_jobs=1, check_validity=True):
96
+ """
97
+ Computes a number of unique molecules
98
+ Parameters:
99
+ gen: list of SMILES
100
+ k: compute unique@k
101
+ n_jobs: number of threads for calculation
102
+ check_validity: raises ValueError if invalid molecules are present
103
+ """
104
+ if k is not None:
105
+ if len(gen) < k:
106
+ warnings.warn(
107
+ "Can't compute unique@{}.".format(k) +
108
+ "gen contains only {} molecules".format(len(gen))
109
+ )
110
+ gen = gen[:k]
111
+ canonic = set(mapper(n_jobs)(canonic_smiles, gen))
112
+
113
+ if None in canonic and check_validity:
114
+ raise ValueError("Invalid molecule passed to unique@k")
115
+ return len(canonic) / len(gen)
116
+
117
+ def novelty(gen, train, n_jobs=1):
118
+ gen_smiles = mapper(n_jobs)(canonic_smiles, gen)
119
+ gen_smiles_set = set(gen_smiles) - {None}
120
+ train_set = set(train)
121
+ return len(gen_smiles_set - train_set) / len(gen_smiles_set)
122
+
123
+
124
+ def SAscore(gen):
125
+ """
126
+ Calculate the average Synthetic Accessibility Score (SAscore) for a list of molecules represented by their SMILES strings.
127
+
128
+ Parameters:
129
+ - smiles_list (list of str): A list containing the SMILES representations of the molecules.
130
+
131
+ Returns:
132
+ - float: The average Synthetic Accessibility Score for the valid molecules in the list. Returns None if no valid molecules are found.
133
+ """
134
+ scores = []
135
+ for smiles in gen:
136
+ mol = Chem.MolFromSmiles(smiles)
137
+ if mol: # Ensures the molecule could be parsed from the SMILES string
138
+ score = sascorer.calculateScore(mol)
139
+ scores.append(score)
140
+
141
+ if scores: # Checks if there are any scores calculated
142
+ return np.mean(scores)
143
+ else:
144
+ return None
145
+
146
+
147
+
148
+ def average_agg_tanimoto(stock_vecs, gen_vecs,
149
+ batch_size=5000, agg='max',
150
+ device='cpu', p=1):
151
+ """
152
+ For each molecule in gen_vecs finds closest molecule in stock_vecs.
153
+ Returns average tanimoto score for between these molecules
154
+
155
+ Parameters:
156
+ stock_vecs: numpy array <n_vectors x dim>
157
+ gen_vecs: numpy array <n_vectors' x dim>
158
+ agg: max or mean
159
+ p: power for averaging: (mean x^p)^(1/p)
160
+ """
161
+ assert agg in ['max', 'mean'], "Can aggregate only max or mean"
162
+ agg_tanimoto = np.zeros(len(gen_vecs))
163
+ total = np.zeros(len(gen_vecs))
164
+ for j in range(0, stock_vecs.shape[0], batch_size):
165
+ x_stock = torch.tensor(stock_vecs[j:j + batch_size]).to(device).float()
166
+ for i in range(0, gen_vecs.shape[0], batch_size):
167
+ y_gen = torch.tensor(gen_vecs[i:i + batch_size]).to(device).float()
168
+ y_gen = y_gen.transpose(0, 1)
169
+ tp = torch.mm(x_stock, y_gen)
170
+ jac = (tp / (x_stock.sum(1, keepdim=True) +
171
+ y_gen.sum(0, keepdim=True) - tp)).cpu().numpy()
172
+ jac[np.isnan(jac)] = 1
173
+ if p != 1:
174
+ jac = jac**p
175
+ if agg == 'max':
176
+ agg_tanimoto[i:i + y_gen.shape[1]] = np.maximum(
177
+ agg_tanimoto[i:i + y_gen.shape[1]], jac.max(0))
178
+ elif agg == 'mean':
179
+ agg_tanimoto[i:i + y_gen.shape[1]] += jac.sum(0)
180
+ total[i:i + y_gen.shape[1]] += jac.shape[0]
181
+ if agg == 'mean':
182
+ agg_tanimoto /= total
183
+ if p != 1:
184
+ agg_tanimoto = (agg_tanimoto)**(1/p)
185
+ return np.mean(agg_tanimoto)
186
+
187
+ def fingerprint(smiles_or_mol, fp_type='maccs', dtype=None, morgan__r=2,
188
+ morgan__n=1024, *args, **kwargs):
189
+ """
190
+ Generates fingerprint for SMILES
191
+ If smiles is invalid, returns None
192
+ Returns numpy array of fingerprint bits
193
+
194
+ Parameters:
195
+ smiles: SMILES string
196
+ type: type of fingerprint: [MACCS|morgan]
197
+ dtype: if not None, specifies the dtype of returned array
198
+ """
199
+ fp_type = fp_type.lower()
200
+ molecule = get_mol(smiles_or_mol, *args, **kwargs)
201
+ if molecule is None:
202
+ return None
203
+ if fp_type == 'maccs':
204
+ keys = MACCSkeys.GenMACCSKeys(molecule)
205
+ keys = np.array(keys.GetOnBits())
206
+ fingerprint = np.zeros(166, dtype='uint8')
207
+ if len(keys) != 0:
208
+ fingerprint[keys - 1] = 1 # We drop 0-th key that is always zero
209
+ elif fp_type == 'morgan':
210
+ fingerprint = np.asarray(Morgan(molecule, morgan__r, nBits=morgan__n),
211
+ dtype='uint8')
212
+ else:
213
+ raise ValueError("Unknown fingerprint type {}".format(fp_type))
214
+ if dtype is not None:
215
+ fingerprint = fingerprint.astype(dtype)
216
+ return fingerprint
217
+
218
+
219
+ def fingerprints(smiles_mols_array, n_jobs=1, already_unique=False, *args,
220
+ **kwargs):
221
+ '''
222
+ Computes fingerprints of smiles np.array/list/pd.Series with n_jobs workers
223
+ e.g.fingerprints(smiles_mols_array, type='morgan', n_jobs=10)
224
+ Inserts np.NaN to rows corresponding to incorrect smiles.
225
+ IMPORTANT: if there is at least one np.NaN, the dtype would be float
226
+ Parameters:
227
+ smiles_mols_array: list/array/pd.Series of smiles or already computed
228
+ RDKit molecules
229
+ n_jobs: number of parralel workers to execute
230
+ already_unique: flag for performance reasons, if smiles array is big
231
+ and already unique. Its value is set to True if smiles_mols_array
232
+ contain RDKit molecules already.
233
+ '''
234
+ if isinstance(smiles_mols_array, pd.Series):
235
+ smiles_mols_array = smiles_mols_array.values
236
+ else:
237
+ smiles_mols_array = np.asarray(smiles_mols_array)
238
+ if not isinstance(smiles_mols_array[0], str):
239
+ already_unique = True
240
+
241
+ if not already_unique:
242
+ smiles_mols_array, inv_index = np.unique(smiles_mols_array,
243
+ return_inverse=True)
244
+
245
+ fps = mapper(n_jobs)(
246
+ partial(fingerprint, *args, **kwargs), smiles_mols_array
247
+ )
248
+
249
+ length = 1
250
+ for fp in fps:
251
+ if fp is not None:
252
+ length = fp.shape[-1]
253
+ first_fp = fp
254
+ break
255
+ fps = [fp if fp is not None else np.array([np.NaN]).repeat(length)[None, :]
256
+ for fp in fps]
257
+ if scipy.sparse.issparse(first_fp):
258
+ fps = scipy.sparse.vstack(fps).tocsr()
259
+ else:
260
+ fps = np.vstack(fps)
261
+ if not already_unique:
262
+ return fps[inv_index]
263
+ return fps
264
+
265
+ def internal_diversity(gen, n_jobs=1, device='cpu', fp_type='morgan',
266
+ gen_fps=None, p=1):
267
+ """
268
+ Computes internal diversity as:
269
+ 1/|A|^2 sum_{x, y in AxA} (1-tanimoto(x, y))
270
+ """
271
+ if gen_fps is None:
272
+ gen_fps = fingerprints(gen, fp_type=fp_type, n_jobs=n_jobs)
273
+ return 1 - (average_agg_tanimoto(gen_fps, gen_fps,
274
+ agg='mean', device=device, p=p)).mean()
275
+
276
+
277
+ def fcd_metric(gen, train, n_jobs = 8, device = 'cuda:0'):
278
+ fcd = FCD(device=device, n_jobs= n_jobs)
279
+ return fcd(gen, train)
280
+
281
+ def SYBAscore(gen):
282
+ """
283
+ Compute the average SYBA score for a list of SMILES strings.
284
+
285
+ Parameters:
286
+ - smiles_list (list of str): A list of SMILES strings representing molecules.
287
+
288
+ Returns:
289
+ - float: The average SYBA score for the list of molecules.
290
+ """
291
+ syba = SybaClassifier()
292
+ syba.fitDefaultScore()
293
+ scores = []
294
+
295
+ for smiles in gen:
296
+ try:
297
+ score = syba.predict(smi=smiles)
298
+ scores.append(score)
299
+ except Exception as e:
300
+ print(f"Error processing SMILES '{smiles}': {e}")
301
+ continue
302
+
303
+ if scores:
304
+ return sum(scores) / len(scores)
305
+ else:
306
+ return None # Or handle empty list or all failed predictions as needed
307
+
308
+ def oracles(gen, train):
309
+ Result = {}
310
+ # evaluator = Evaluator(name = 'KL_Divergence')
311
+ # KL_Divergence = evaluator(gen, train)
312
+
313
+ # Result["KL_Divergence"]: KL_Divergence
314
+
315
+
316
+ oracle_list = [
317
+ 'QED', 'SA', 'MPO', 'GSK3B', 'JNK3',
318
+ 'DRD2', 'LogP', 'Rediscovery', 'Similarity',
319
+ 'Median', 'Isomers', 'Valsartan_SMARTS', 'Hop'
320
+ ]
321
+
322
+ for oracle_name in oracle_list:
323
+ oracle = Oracle(name=oracle_name)
324
+ if oracle_name in ['Rediscovery', 'MPO', 'Similarity', 'Median', 'Isomers', 'Hop']:
325
+ score = oracle(gen)
326
+ if isinstance(score, dict):
327
+ score = {key: sum(values)/len(values) for key, values in score.items()}
328
+ else:
329
+ score = oracle(gen)
330
+ if isinstance(score, list):
331
+ score = sum(score) / len(score)
332
+
333
+ Result[f"{oracle_name}"] = score
334
+
335
+ return Result
molgenevalmetric.py CHANGED
@@ -1,10 +1,11 @@
1
  import evaluate
2
  import datasets
3
  # import moses
4
- from moses import metrics
5
  import pandas as pd
6
- from tdc import Evaluator
7
- from tdc import Oracle
 
8
 
9
 
10
  _DESCRIPTION = """
@@ -77,42 +78,52 @@ class molgenevalmetric(evaluate.Metric):
77
  reference_urls=["https://github.com/molecularsets/moses", "https://tdcommons.ai/functions/oracles/"],
78
  )
79
 
80
- def _compute(self, generated_smiles, train_smiles = None):
 
 
 
 
 
 
 
 
 
 
 
81
 
82
- Results = metrics.get_all_metrics(gen = generated_smiles, train= train_smiles)
83
 
84
- generated_smiles = [s for s in generated_smiles if s != '']
85
 
86
- evaluator = Evaluator(name = 'KL_Divergence')
87
- KL_Divergence = evaluator(generated_smiles, train_smiles)
88
 
89
- Results.update({
90
- "KL_Divergence": KL_Divergence,
91
- })
92
 
93
 
94
- oracle_list = [
95
- 'QED', 'SA', 'MPO', 'GSK3B', 'JNK3',
96
- 'DRD2', 'LogP', 'Rediscovery', 'Similarity',
97
- 'Median', 'Isomers', 'Valsartan_SMARTS', 'Hop'
98
- ]
99
 
100
- for oracle_name in oracle_list:
101
- oracle = Oracle(name=oracle_name)
102
- if oracle_name in ['Rediscovery', 'MPO', 'Similarity', 'Median', 'Isomers', 'Hop']:
103
- score = oracle(generated_smiles)
104
- if isinstance(score, dict):
105
- score = {key: sum(values)/len(values) for key, values in score.items()}
106
- else:
107
- score = oracle(generated_smiles)
108
- if isinstance(score, list):
109
- score = sum(score) / len(score)
110
 
111
- Results.update({f"{oracle_name}": score})
112
 
113
- keys_to_remove = ["FCD/TestSF", "SNN/TestSF", "Frag/TestSF", "Scaf/TestSF"]
114
- for key in keys_to_remove:
115
- Results.pop(key, None)
116
 
117
- return {"results": Results}
118
 
 
1
  import evaluate
2
  import datasets
3
  # import moses
4
+ # from moses import metrics
5
  import pandas as pd
6
+ # from tdc import Evaluator
7
+ # from tdc import Oracle
8
+ from metrics import novelty, fraction_valid, fraction_unique, SAscore, internal_diversity,fcd_metric, SYBAscore, oracles
9
 
10
 
11
  _DESCRIPTION = """
 
78
  reference_urls=["https://github.com/molecularsets/moses", "https://tdcommons.ai/functions/oracles/"],
79
  )
80
 
81
+ def _compute(self, gensmi, trainsmi):
82
+
83
+ metrics = {}
84
+ metrics['novelty'] = novelty(gen = gensmi, train = trainsmi)
85
+ metrics['valid'] = fraction_valid(gen=gensmi)
86
+ metrics['unique'] = fraction_unique(gen=gensmi)
87
+ metrics['IntDiv'] = internal_diversity(gen=gensmi)
88
+ metrics['FCD'] = fcd_metric(gen = gensmi, train = trainsmi)
89
+ metrics['SA'] = SAscore(gen=gensmi)
90
+ metrics['SCS'] = SAscore(gen=trainsmi)
91
+
92
+ return metrics
93
 
 
94
 
95
+ # generated_smiles = [s for s in generated_smiles if s != '']
96
 
97
+ # evaluator = Evaluator(name = 'KL_Divergence')
98
+ # KL_Divergence = evaluator(generated_smiles, train_smiles)
99
 
100
+ # Results.update({
101
+ # "KL_Divergence": KL_Divergence,
102
+ # })
103
 
104
 
105
+ # oracle_list = [
106
+ # 'QED', 'SA', 'MPO', 'GSK3B', 'JNK3',
107
+ # 'DRD2', 'LogP', 'Rediscovery', 'Similarity',
108
+ # 'Median', 'Isomers', 'Valsartan_SMARTS', 'Hop'
109
+ # ]
110
 
111
+ # for oracle_name in oracle_list:
112
+ # oracle = Oracle(name=oracle_name)
113
+ # if oracle_name in ['Rediscovery', 'MPO', 'Similarity', 'Median', 'Isomers', 'Hop']:
114
+ # score = oracle(generated_smiles)
115
+ # if isinstance(score, dict):
116
+ # score = {key: sum(values)/len(values) for key, values in score.items()}
117
+ # else:
118
+ # score = oracle(generated_smiles)
119
+ # if isinstance(score, list):
120
+ # score = sum(score) / len(score)
121
 
122
+ # Results.update({f"{oracle_name}": score})
123
 
124
+ # # keys_to_remove = ["FCD/TestSF", "SNN/TestSF", "Frag/TestSF", "Scaf/TestSF"]
125
+ # # for key in keys_to_remove:
126
+ # # Results.pop(key, None)
127
 
128
+ # return {"results": Results}
129
 
requirements.txt CHANGED
@@ -1,5 +1,9 @@
1
  git+https://github.com/huggingface/evaluate@main
2
  git+https://github.com/molecularsets/moses.git
 
 
 
 
3
  rdkit
4
- pandas==1.5.3
5
- PyTDC
 
1
  git+https://github.com/huggingface/evaluate@main
2
  git+https://github.com/molecularsets/moses.git
3
+ numpy
4
+ pandas
5
+ scipy
6
+ torch
7
  rdkit
8
+ pyarrow
9
+ fcd-torch
scscore ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit ba64e5a524ccc7d9b2000b34b50d8372178ca5d6
syba_test.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from syba.syba import SybaClassifier
2
+
3
+ def SYBAscore(smiles_list):
4
+ """
5
+ Compute the average SYBA score for a list of SMILES strings.
6
+
7
+ Parameters:
8
+ - smiles_list (list of str): A list of SMILES strings representing molecules.
9
+
10
+ Returns:
11
+ - float: The average SYBA score for the list of molecules.
12
+ """
13
+ syba = SybaClassifier()
14
+ syba.fitDefaultScore()
15
+ scores = []
16
+
17
+ for smiles in smiles_list:
18
+ try:
19
+ score = syba.predict(smi=smiles)
20
+ scores.append(score)
21
+ except Exception as e:
22
+ print(f"Error processing SMILES '{smiles}': {e}")
23
+ continue
24
+
25
+ if scores:
26
+ return sum(scores) / len(scores)
27
+ else:
28
+ return None # Or handle empty list or all failed predictions as needed
29
+
30
+
31
+ syba = SybaClassifier()
32
+ syba.fitDefaultScore()
33
+ smi = "O=C(C)Oc1ccccc1C(=O)O"
34
+ print(syba.predict(smi))