Spaces:
Sleeping
Sleeping
saicharan2804
commited on
Commit
·
e57187a
1
Parent(s):
d3fd51f
documentation
Browse files- molgenevalmetric.py +132 -47
molgenevalmetric.py
CHANGED
@@ -43,9 +43,18 @@ from fcd_torch import FCD
|
|
43 |
|
44 |
|
45 |
def get_mol(smiles_or_mol):
|
46 |
-
|
47 |
-
|
48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
if isinstance(smiles_or_mol, str):
|
50 |
if len(smiles_or_mol) == 0:
|
51 |
return None
|
@@ -60,12 +69,17 @@ def get_mol(smiles_or_mol):
|
|
60 |
return smiles_or_mol
|
61 |
|
62 |
def mapper(n_jobs):
|
63 |
-
|
64 |
-
Returns function for
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
|
|
|
|
|
|
|
|
|
|
69 |
if n_jobs == 1:
|
70 |
def _mapper(*args, **kwargs):
|
71 |
return list(map(*args, **kwargs))
|
@@ -86,16 +100,30 @@ def mapper(n_jobs):
|
|
86 |
|
87 |
def fraction_valid(gen, n_jobs=1):
|
88 |
"""
|
89 |
-
|
|
|
90 |
Parameters:
|
91 |
-
|
92 |
-
|
|
|
|
|
|
|
93 |
"""
|
94 |
gen = mapper(n_jobs)(get_mol, gen)
|
95 |
return 1 - gen.count(None) / len(gen)
|
96 |
|
97 |
|
98 |
def canonic_smiles(smiles_or_mol):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
mol = get_mol(smiles_or_mol)
|
100 |
if mol is None:
|
101 |
return None
|
@@ -103,12 +131,16 @@ def canonic_smiles(smiles_or_mol):
|
|
103 |
|
104 |
def fraction_unique(gen, k=None, n_jobs=1, check_validity=True):
|
105 |
"""
|
106 |
-
|
|
|
107 |
Parameters:
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
|
|
|
|
|
|
112 |
"""
|
113 |
if k is not None:
|
114 |
if len(gen) < k:
|
@@ -124,6 +156,18 @@ def fraction_unique(gen, k=None, n_jobs=1, check_validity=True):
|
|
124 |
return len(canonic) / len(gen)
|
125 |
|
126 |
def novelty(gen, train, n_jobs=1):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
gen_smiles = mapper(n_jobs)(canonic_smiles, gen)
|
128 |
gen_smiles_set = set(gen_smiles) - {None}
|
129 |
train_set = set(train)
|
@@ -158,15 +202,20 @@ def average_agg_tanimoto(stock_vecs, gen_vecs,
|
|
158 |
batch_size=5000, agg='max',
|
159 |
device='cpu', p=1):
|
160 |
"""
|
161 |
-
|
162 |
-
Returns average tanimoto score for between these molecules
|
163 |
|
164 |
Parameters:
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
|
|
|
|
|
|
|
|
|
|
169 |
"""
|
|
|
170 |
assert agg in ['max', 'mean'], "Can aggregate only max or mean"
|
171 |
agg_tanimoto = np.zeros(len(gen_vecs))
|
172 |
total = np.zeros(len(gen_vecs))
|
@@ -276,6 +325,17 @@ def internal_diversity(gen, n_jobs=1, device='cpu', fp_type='morgan',
|
|
276 |
"""
|
277 |
Computes internal diversity as:
|
278 |
1/|A|^2 sum_{x, y in AxA} (1-tanimoto(x, y))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
279 |
"""
|
280 |
if gen_fps is None:
|
281 |
gen_fps = fingerprints(gen, fp_type=fp_type, n_jobs=n_jobs)
|
@@ -284,6 +344,19 @@ def internal_diversity(gen, n_jobs=1, device='cpu', fp_type='morgan',
|
|
284 |
|
285 |
|
286 |
def fcd_metric(gen, train, n_jobs = 8, device = 'cuda:0'):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
287 |
fcd = FCD(device=device, n_jobs= n_jobs)
|
288 |
return fcd(gen, train)
|
289 |
|
@@ -315,33 +388,45 @@ def fcd_metric(gen, train, n_jobs = 8, device = 'cuda:0'):
|
|
315 |
# return None # Or handle empty list or all failed predictions as needed
|
316 |
|
317 |
def oracles(gen, train):
|
318 |
-
Result = {}
|
319 |
-
# evaluator = Evaluator(name = 'KL_Divergence')
|
320 |
-
# KL_Divergence = evaluator(gen, train)
|
321 |
-
|
322 |
-
# Result["KL_Divergence"]: KL_Divergence
|
323 |
|
|
|
|
|
324 |
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
score = {key: sum(values)/len(values) for key, values in score.items()}
|
337 |
-
else:
|
338 |
-
score = oracle(gen)
|
339 |
-
if isinstance(score, list):
|
340 |
-
score = sum(score) / len(score)
|
341 |
|
342 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
343 |
|
344 |
-
|
|
|
|
|
345 |
|
346 |
|
347 |
|
@@ -424,7 +509,7 @@ class molgenevalmetric(evaluate.Metric):
|
|
424 |
metrics['IntDiv'] = internal_diversity(gen=gensmi)
|
425 |
metrics['FCD'] = fcd_metric(gen = gensmi, train = trainsmi)
|
426 |
metrics['Oracles'] = oracles(gen = gensmi, train = trainsmi)
|
427 |
-
|
428 |
# metrics['SA'] = SAscore(gen=gensmi)
|
429 |
# metrics['SCS'] = SAscore(gen=trainsmi)
|
430 |
|
|
|
43 |
|
44 |
|
45 |
def get_mol(smiles_or_mol):
|
46 |
+
"""
|
47 |
+
Converts a SMILES string or RDKit molecule object to an RDKit molecule object.
|
48 |
+
If the input is already an RDKit molecule object, it returns it directly.
|
49 |
+
For a SMILES string, it attempts to create an RDKit molecule object.
|
50 |
+
|
51 |
+
Parameters:
|
52 |
+
- smiles_or_mol (str or Mol): The SMILES string of the molecule or an RDKit molecule object.
|
53 |
+
|
54 |
+
Returns:
|
55 |
+
- Mol or None: The RDKit molecule object or None if conversion fails.
|
56 |
+
"""
|
57 |
+
|
58 |
if isinstance(smiles_or_mol, str):
|
59 |
if len(smiles_or_mol) == 0:
|
60 |
return None
|
|
|
69 |
return smiles_or_mol
|
70 |
|
71 |
def mapper(n_jobs):
|
72 |
+
"""
|
73 |
+
Returns a mapping function suitable for parallel or sequential execution
|
74 |
+
based on the value of n_jobs.
|
75 |
+
|
76 |
+
Parameters:
|
77 |
+
- n_jobs (int or Pool): Number of jobs for parallel execution or a multiprocessing Pool object.
|
78 |
+
|
79 |
+
Returns:
|
80 |
+
- Function: A mapping function that can be used for applying a function over a sequence.
|
81 |
+
"""
|
82 |
+
|
83 |
if n_jobs == 1:
|
84 |
def _mapper(*args, **kwargs):
|
85 |
return list(map(*args, **kwargs))
|
|
|
100 |
|
101 |
def fraction_valid(gen, n_jobs=1):
|
102 |
"""
|
103 |
+
Calculates the fraction of valid molecules in a list of SMILES strings.
|
104 |
+
|
105 |
Parameters:
|
106 |
+
- gen (list of str): List of SMILES strings.
|
107 |
+
- n_jobs (int): Number of parallel jobs to use for computation.
|
108 |
+
|
109 |
+
Returns:
|
110 |
+
- float: Fraction of valid molecules.
|
111 |
"""
|
112 |
gen = mapper(n_jobs)(get_mol, gen)
|
113 |
return 1 - gen.count(None) / len(gen)
|
114 |
|
115 |
|
116 |
def canonic_smiles(smiles_or_mol):
|
117 |
+
"""
|
118 |
+
Converts a molecule into its canonical SMILES representation.
|
119 |
+
|
120 |
+
Parameters:
|
121 |
+
- smiles_or_mol (str or Mol): SMILES string or RDKit molecule object.
|
122 |
+
|
123 |
+
Returns:
|
124 |
+
- str or None: Canonical SMILES string, or None if conversion fails.
|
125 |
+
"""
|
126 |
+
|
127 |
mol = get_mol(smiles_or_mol)
|
128 |
if mol is None:
|
129 |
return None
|
|
|
131 |
|
132 |
def fraction_unique(gen, k=None, n_jobs=1, check_validity=True):
|
133 |
"""
|
134 |
+
Calculates the fraction of unique molecules in a list of SMILES strings.
|
135 |
+
|
136 |
Parameters:
|
137 |
+
- gen (list of str): List of SMILES strings.
|
138 |
+
- k (int, optional): Number of top molecules to consider for uniqueness. If None, considers all.
|
139 |
+
- n_jobs (int): Number of parallel jobs to use for computation.
|
140 |
+
- check_validity (bool): If True, checks for the validity of molecules.
|
141 |
+
|
142 |
+
Returns:
|
143 |
+
- float: Fraction of unique molecules.
|
144 |
"""
|
145 |
if k is not None:
|
146 |
if len(gen) < k:
|
|
|
156 |
return len(canonic) / len(gen)
|
157 |
|
158 |
def novelty(gen, train, n_jobs=1):
|
159 |
+
"""
|
160 |
+
Computes the novelty of generated molecules compared to a training set.
|
161 |
+
|
162 |
+
Parameters:
|
163 |
+
- gen (List[str]): List of generated SMILES strings.
|
164 |
+
- train (List[str]): List of SMILES strings from the training set.
|
165 |
+
- n_jobs (int): Number of parallel jobs to use for computation.
|
166 |
+
|
167 |
+
Returns:
|
168 |
+
- float: Novelty score.
|
169 |
+
"""
|
170 |
+
|
171 |
gen_smiles = mapper(n_jobs)(canonic_smiles, gen)
|
172 |
gen_smiles_set = set(gen_smiles) - {None}
|
173 |
train_set = set(train)
|
|
|
202 |
batch_size=5000, agg='max',
|
203 |
device='cpu', p=1):
|
204 |
"""
|
205 |
+
Calculates the average aggregate Tanimoto similarity between two sets of molecule fingerprints.
|
|
|
206 |
|
207 |
Parameters:
|
208 |
+
- stock_vecs (numpy array): Fingerprint vectors for the reference molecule set.
|
209 |
+
- gen_vecs (numpy array): Fingerprint vectors for the generated molecule set.
|
210 |
+
- batch_size (int): The size of batches to process similarities (reduces memory usage).
|
211 |
+
- agg (str): Aggregation method, either 'max' or 'mean'.
|
212 |
+
- device (str): The computation device ('cpu' or 'cuda:0', etc.).
|
213 |
+
- p (float): The power for averaging, used in generalized mean calculation.
|
214 |
+
|
215 |
+
Returns:
|
216 |
+
- float: Average aggregate Tanimoto similarity score.
|
217 |
"""
|
218 |
+
|
219 |
assert agg in ['max', 'mean'], "Can aggregate only max or mean"
|
220 |
agg_tanimoto = np.zeros(len(gen_vecs))
|
221 |
total = np.zeros(len(gen_vecs))
|
|
|
325 |
"""
|
326 |
Computes internal diversity as:
|
327 |
1/|A|^2 sum_{x, y in AxA} (1-tanimoto(x, y))
|
328 |
+
|
329 |
+
Parameters:
|
330 |
+
- gen (List[str]): List of generated SMILES strings.
|
331 |
+
- n_jobs (int): Number of parallel jobs for fingerprint computation.
|
332 |
+
- device (str): Computation device ('cpu' or 'cuda:0', etc.).
|
333 |
+
- fp_type (str): Type of fingerprint to use ('morgan', etc.).
|
334 |
+
- gen_fps (Optional[np.ndarray]): Precomputed fingerprints of generated molecules. If None, will be computed.
|
335 |
+
|
336 |
+
Returns:
|
337 |
+
- float: Internal diversity score.
|
338 |
+
|
339 |
"""
|
340 |
if gen_fps is None:
|
341 |
gen_fps = fingerprints(gen, fp_type=fp_type, n_jobs=n_jobs)
|
|
|
344 |
|
345 |
|
346 |
def fcd_metric(gen, train, n_jobs = 8, device = 'cuda:0'):
|
347 |
+
"""
|
348 |
+
Computes the Fréchet ChemNet Distance (FCD) between two sets of molecules.
|
349 |
+
|
350 |
+
Parameters:
|
351 |
+
- gen (List[str]): List of generated SMILES strings.
|
352 |
+
- train (List[str]): List of training set SMILES strings.
|
353 |
+
- n_jobs (int): Number of parallel jobs for computation.
|
354 |
+
- device (str): Computation device for the FCD calculation.
|
355 |
+
|
356 |
+
Returns:
|
357 |
+
- float: FCD score.
|
358 |
+
"""
|
359 |
+
|
360 |
fcd = FCD(device=device, n_jobs= n_jobs)
|
361 |
return fcd(gen, train)
|
362 |
|
|
|
388 |
# return None # Or handle empty list or all failed predictions as needed
|
389 |
|
390 |
def oracles(gen, train):
|
|
|
|
|
|
|
|
|
|
|
391 |
|
392 |
+
"""
|
393 |
+
Computes scores from various oracles for a list of generated molecules.
|
394 |
|
395 |
+
Parameters:
|
396 |
+
- gen (List[str]): List of generated SMILES strings.
|
397 |
+
- train (List[str]): List of training set SMILES strings.
|
398 |
+
|
399 |
+
Returns:
|
400 |
+
- Dict[str, Any]: A dictionary with oracle names as keys and their corresponding scores as values.
|
401 |
+
"""
|
402 |
+
|
403 |
+
Result = {}
|
404 |
+
# evaluator = Evaluator(name = 'KL_Divergence')
|
405 |
+
# KL_Divergence = evaluator(gen, train)
|
|
|
|
|
|
|
|
|
|
|
406 |
|
407 |
+
# Result["KL_Divergence"]: KL_Divergence
|
408 |
+
|
409 |
+
|
410 |
+
oracle_list = [
|
411 |
+
'QED', 'SA', 'MPO', 'GSK3B', 'JNK3',
|
412 |
+
'DRD2', 'LogP', 'Rediscovery', 'Similarity',
|
413 |
+
'Median', 'Isomers', 'Valsartan_SMARTS', 'Hop'
|
414 |
+
]
|
415 |
+
|
416 |
+
for oracle_name in oracle_list:
|
417 |
+
oracle = Oracle(name=oracle_name)
|
418 |
+
if oracle_name in ['Rediscovery', 'MPO', 'Similarity', 'Median', 'Isomers', 'Hop']:
|
419 |
+
score = oracle(gen)
|
420 |
+
if isinstance(score, dict):
|
421 |
+
score = {key: sum(values)/len(values) for key, values in score.items()}
|
422 |
+
else:
|
423 |
+
score = oracle(gen)
|
424 |
+
if isinstance(score, list):
|
425 |
+
score = sum(score) / len(score)
|
426 |
|
427 |
+
Result[f"{oracle_name}"] = score
|
428 |
+
|
429 |
+
return Result
|
430 |
|
431 |
|
432 |
|
|
|
509 |
metrics['IntDiv'] = internal_diversity(gen=gensmi)
|
510 |
metrics['FCD'] = fcd_metric(gen = gensmi, train = trainsmi)
|
511 |
metrics['Oracles'] = oracles(gen = gensmi, train = trainsmi)
|
512 |
+
|
513 |
# metrics['SA'] = SAscore(gen=gensmi)
|
514 |
# metrics['SCS'] = SAscore(gen=trainsmi)
|
515 |
|