saicharan2804 committed on
Commit
e32ec06
1 Parent(s): 56d9bde

improved documentation

Browse files
Files changed (1) hide show
  1. molgenevalmetric.py +179 -11
molgenevalmetric.py CHANGED
@@ -4,19 +4,25 @@ import datasets
4
  import pandas as pd
5
  import numpy as np
6
  import scipy.sparse
 
 
7
  import torch
8
  import warnings
9
  from multiprocessing import Pool
10
  from functools import partial
11
  from fcd_torch import FCD
 
 
12
 
13
  from tdc import Oracle
14
  from rdkit.Chem.Crippen import MolLogP
15
  from rdkit import Chem
16
  from rdkit.Chem import MACCSkeys
 
17
  from rdkit.Chem.AllChem import GetMorganFingerprintAsBitVect as Morgan
18
  from rdkit.Chem.QED import qed
19
  from rdkit.Contrib.SA_Score import sascorer
 
20
 
21
  from syba.syba import SybaClassifier
22
  from myscscore.SCScore import SCScorer
@@ -157,6 +163,9 @@ def novelty(gen, train, n_jobs=1):
157
  def synthetic_complexity_score(gen):
158
  """
159
  Calculate the average Synthetic Complexity Score (SCScore) for a list of molecules represented by their SMILES strings.
 
 
 
160
 
161
  Parameters:
162
  - gen (list of str): A list containing the SMILES representations of the molecules.
@@ -174,7 +183,10 @@ def synthetic_complexity_score(gen):
174
  def calculate_sa_score(smiles):
175
  """
176
  Calculates the SA score for a single SMILES string.
177
-
 
 
 
178
  Parameters:
179
  - smiles (str): SMILES string of the molecule.
180
 
@@ -189,11 +201,10 @@ def calculate_sa_score(smiles):
189
 
190
  def average_sascore(gen, n_jobs=1):
191
  """
192
- Computes the average synthetic accessibility score for a list of molecules
193
- using parallel or sequential execution based on the n_jobs parameter.
194
 
195
  Parameters:
196
- - molecules (List[str]): List of generated SMILES strings.
197
  - n_jobs (int): Number of parallel jobs to use for computation.
198
 
199
  Returns:
@@ -358,6 +369,8 @@ def internal_diversity(gen, n_jobs=1, device='cpu', fp_type='morgan',
358
  def fcd_metric(gen, train, n_jobs = 1, device = None):
359
  """
360
  Computes the Fréchet ChemNet Distance (FCD) between two sets of molecules.
 
 
361
 
362
  Parameters:
363
  - gen (List[str]): List of generated SMILES strings.
@@ -380,10 +393,13 @@ def fcd_metric(gen, train, n_jobs = 1, device = None):
380
 
381
  def SYBAscore(gen):
382
  """
383
- Compute the average SYBA score for a list of SMILES strings.
 
 
 
384
 
385
  Parameters:
386
- - smiles_list (list of str): A list of SMILES strings representing molecules.
387
 
388
  Returns:
389
  - float: The average SYBA score for the list of molecules.
@@ -407,7 +423,18 @@ def SYBAscore(gen):
407
 
408
  def qed_metric(gen):
409
  """
410
- Computes RDKit's QED score
 
 
 
 
 
 
 
 
 
 
 
411
  """
412
  if not gen:
413
  return 0.0 # Return 0 or suitable value for empty list
@@ -433,7 +460,7 @@ def logP_metric(gen):
433
  Computes the average RDKit's logP value for a list of SMILES strings.
434
 
435
  Parameters:
436
- - mols (List[str]): List of SMILES strings representing the molecules.
437
 
438
  Returns:
439
  - float: Average logP value for the list of molecules.
@@ -463,7 +490,7 @@ def penalized_logp(gen):
463
  Computes the average PyTDC's penalized logP value for a list of SMILES strings.
464
 
465
  Parameters:
466
- - mols (List[str]): List of SMILES strings representing the molecules.
467
 
468
  Returns:
469
  - float: Average penalized logP value for the list of molecules.
@@ -477,8 +504,6 @@ def penalized_logp(gen):
477
  return score
478
 
479
 
480
-
481
-
482
  _DESCRIPTION = """
483
 
484
  Comprehensive suite of metrics designed to assess the performance of molecular generation models, for understanding how well a model can produce novel, chemically valid molecules that are relevant to specific research objectives.
@@ -541,3 +566,146 @@ class molgenevalmetric(evaluate.Metric):
541
 
542
  return metrics
543
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import pandas as pd
5
  import numpy as np
6
  import scipy.sparse
7
+ from scipy.spatial.distance import cosine as cos_distance
8
+ from scipy.stats import wasserstein_distance
9
  import torch
10
  import warnings
11
  from multiprocessing import Pool
12
  from functools import partial
13
  from fcd_torch import FCD
14
+ from collections import Counter
15
+
16
 
17
  from tdc import Oracle
18
  from rdkit.Chem.Crippen import MolLogP
19
  from rdkit import Chem
20
  from rdkit.Chem import MACCSkeys
21
+ from rdkit.Chem import AllChem
22
  from rdkit.Chem.AllChem import GetMorganFingerprintAsBitVect as Morgan
23
  from rdkit.Chem.QED import qed
24
  from rdkit.Contrib.SA_Score import sascorer
25
+ from rdkit.Chem.Scaffolds import MurckoScaffold
26
 
27
  from syba.syba import SybaClassifier
28
  from myscscore.SCScore import SCScorer
 
163
  def synthetic_complexity_score(gen):
164
  """
165
  Calculate the average Synthetic Complexity Score (SCScore) for a list of molecules represented by their SMILES strings.
166
+ The SCScore model rates the synthetic complexity of molecules on a scale from 1 to 5.
167
+ It is based on the premise that, on average, the products of published chemical reactions should be more synthetically complex than their corresponding reactants.
168
+
169
 
170
  Parameters:
171
  - gen (list of str): A list containing the SMILES representations of the molecules.
 
183
  def calculate_sa_score(smiles):
184
  """
185
  Calculates the SA score for a single SMILES string.
186
+ Evaluates the ease of synthesizing drug-like molecules in virtual screening.
187
+ Ranges from 1 (easy to synthesize) to 10 (hard to synthesize).
188
+ This score reflects the presence of common fragments in a molecule and structural complexities.
189
+
190
  Parameters:
191
  - smiles (str): SMILES string of the molecule.
192
 
 
201
 
202
  def average_sascore(gen, n_jobs=1):
203
  """
204
+ Computes the average synthetic accessibility score for a list of molecules.
 
205
 
206
  Parameters:
207
+ - gen (List[str]): List of generated SMILES strings.
208
  - n_jobs (int): Number of parallel jobs to use for computation.
209
 
210
  Returns:
 
369
  def fcd_metric(gen, train, n_jobs = 1, device = None):
370
  """
371
  Computes the Fréchet ChemNet Distance (FCD) between two sets of molecules.
372
+ FCD is calculated using the Fréchet Distance between feature vectors of generated and real molecules obtained from ChemNet.
373
+ A lower FCD score indicates higher chemical realism and diversity in the molecules generated by a model.
374
 
375
  Parameters:
376
  - gen (List[str]): List of generated SMILES strings.
 
393
 
394
  def SYBAscore(gen):
395
  """
396
+ Compute the average SYBA (SYnthetic Bayesian Accessibility) score for a list of SMILES strings.
397
+ It is a fragment-based method for the rapid classification of organic compounds as easy- (ES) or hard-to-synthesize (HS).
398
+ It is based on a Bernoulli naïve Bayes classifier that assigns SYBA score contributions to individual fragments based on their frequencies in the database of ES and HS molecules.
399
+ The classifier is trained on ES molecules available in the ZINC15 database and on HS molecules generated by the Nonpher methodology.
400
 
401
  Parameters:
402
+ - gen (List[str]): List of generated SMILES strings.
403
 
404
  Returns:
405
  - float: The average SYBA score for the list of molecules.
 
423
 
424
  def qed_metric(gen):
425
  """
426
+ Computes RDKit's QED score.
427
+ QED is a value in [0, 1] estimating how likely a molecule is to be a viable drug candidate.
428
+ QED is meant to capture certain desirable traits that successful drug molecules tend to possess.
429
+
430
+
431
+
432
+ Parameters:
433
+ - gen (List[str]): List of generated SMILES strings.
434
+
435
+ Returns:
436
+ - float: The average QED score for the list of molecules.
437
+
438
  """
439
  if not gen:
440
  return 0.0 # Return 0 or suitable value for empty list
 
460
  Computes the average RDKit's logP value for a list of SMILES strings.
461
 
462
  Parameters:
463
+ - gen (List[str]): List of generated SMILES strings.
464
 
465
  Returns:
466
  - float: Average logP value for the list of molecules.
 
490
  Computes the average PyTDC's penalized logP value for a list of SMILES strings.
491
 
492
  Parameters:
493
+ - gen (List[str]): List of generated SMILES strings.
494
 
495
  Returns:
496
  - float: Average penalized logP value for the list of molecules.
 
504
  return score
505
 
506
 
 
 
507
  _DESCRIPTION = """
508
 
509
  Comprehensive suite of metrics designed to assess the performance of molecular generation models, for understanding how well a model can produce novel, chemically valid molecules that are relevant to specific research objectives.
 
566
 
567
  return metrics
568
 
569
+
570
+ # def get_n_rings(mol):
571
+ # """
572
+ # Computes the number of rings in a molecule
573
+ # """
574
+ # return mol.GetRingInfo().NumRings()
575
+
576
+ # def fragmenter(mol):
577
+ # """
578
+ # fragment mol using BRICS and return smiles list
579
+ # """
580
+ # fgs = AllChem.FragmentOnBRICSBonds(get_mol(mol))
581
+ # fgs_smi = Chem.MolToSmiles(fgs).split(".")
582
+ # return fgs_smi
583
+
584
+ # def compute_fragments(mol_list, n_jobs=1):
585
+ # """
586
+ # fragment list of mols using BRICS and return smiles list
587
+ # """
588
+ # fragments = Counter()
589
+ # for mol_frag in mapper(n_jobs)(fragmenter, mol_list):
590
+ # fragments.update(mol_frag)
591
+ # return fragments
592
+
593
+ # def compute_scaffolds(mol_list, n_jobs=1, min_rings=2):
594
+ # """
595
+ # Extracts a scaffold from a molecule in the form of a canonical SMILES
596
+ # """
597
+ # scaffolds = Counter()
598
+ # map_ = mapper(n_jobs)
599
+ # scaffolds = Counter(
600
+ # map_(partial(compute_scaffold, min_rings=min_rings), mol_list))
601
+ # if None in scaffolds:
602
+ # scaffolds.pop(None)
603
+ # return scaffolds
604
+
605
+ # def compute_scaffold(mol, min_rings=2):
606
+ # mol = get_mol(mol)
607
+ # try:
608
+ # scaffold = MurckoScaffold.GetScaffoldForMol(mol)
609
+ # except (ValueError, RuntimeError):
610
+ # return None
611
+ # n_rings = get_n_rings(scaffold)
612
+ # scaffold_smiles = Chem.MolToSmiles(scaffold)
613
+ # if scaffold_smiles == '' or n_rings < min_rings:
614
+ # return None
615
+ # return scaffold_smiles
616
+
617
+ # class Metric:
618
+ # def __init__(self, n_jobs=1, device='cpu', batch_size=512, **kwargs):
619
+ # self.n_jobs = n_jobs
620
+ # self.device = device
621
+ # self.batch_size = batch_size
622
+ # for k, v in kwargs.values():
623
+ # setattr(self, k, v)
624
+
625
+ # def __call__(self, ref=None, gen=None, pref=None, pgen=None):
626
+ # assert (ref is None) != (pref is None), "specify ref xor pref"
627
+ # assert (gen is None) != (pgen is None), "specify gen xor pgen"
628
+ # if pref is None:
629
+ # pref = self.precalc(ref)
630
+ # if pgen is None:
631
+ # pgen = self.precalc(gen)
632
+ # return self.metric(pref, pgen)
633
+
634
+ # def precalc(self, moleclues):
635
+ # raise NotImplementedError
636
+
637
+ # def metric(self, pref, pgen):
638
+ # raise NotImplementedError
639
+
640
+
641
+ # class SNNMetric(Metric):
642
+ # """
643
+ # Computes average max similarities of gen SMILES to ref SMILES
644
+ # """
645
+
646
+ # def __init__(self, fp_type='morgan', **kwargs):
647
+ # self.fp_type = fp_type
648
+ # super().__init__(**kwargs)
649
+
650
+ # def precalc(self, mols):
651
+ # return {'fps': fingerprints(mols, n_jobs=self.n_jobs,
652
+ # fp_type=self.fp_type)}
653
+
654
+ # def metric(self, pref, pgen):
655
+ # return average_agg_tanimoto(pref['fps'], pgen['fps'],
656
+ # device=self.device)
657
+
658
+
659
+ # def cos_similarity(ref_counts, gen_counts):
660
+ # """
661
+ # Computes cosine similarity between
662
+ # dictionaries of form {name: count}. Non-present
663
+ # elements are considered zero:
664
+
665
+ # sim = <r, g> / ||r|| / ||g||
666
+ # """
667
+ # if len(ref_counts) == 0 or len(gen_counts) == 0:
668
+ # return np.nan
669
+ # keys = np.unique(list(ref_counts.keys()) + list(gen_counts.keys()))
670
+ # ref_vec = np.array([ref_counts.get(k, 0) for k in keys])
671
+ # gen_vec = np.array([gen_counts.get(k, 0) for k in keys])
672
+ # return 1 - cos_distance(ref_vec, gen_vec)
673
+
674
+
675
+ # class FragMetric(Metric):
676
+ # def precalc(self, mols):
677
+ # return {'frag': compute_fragments(mols, n_jobs=self.n_jobs)}
678
+
679
+ # def metric(self, pref, pgen):
680
+ # return cos_similarity(pref['frag'], pgen['frag'])
681
+
682
+
683
+ # class ScafMetric(Metric):
684
+ # def precalc(self, mols):
685
+ # return {'scaf': compute_scaffolds(mols, n_jobs=self.n_jobs)}
686
+
687
+ # def metric(self, pref, pgen):
688
+ # return cos_similarity(pref['scaf'], pgen['scaf'])
689
+
690
+
691
+ # class WassersteinMetric(Metric):
692
+ # def __init__(self, func=None, **kwargs):
693
+ # self.func = func
694
+ # super().__init__(**kwargs)
695
+
696
+ # def precalc(self, mols):
697
+ # if self.func is not None:
698
+ # values = mapper(self.n_jobs)(self.func, mols)
699
+ # else:
700
+ # values = mols
701
+ # return {'values': values}
702
+
703
+ # def metric(self, pref, pgen):
704
+ # return wasserstein_distance(
705
+ # pref['values'], pgen['values']
706
+ # )
707
+
708
+
709
+ # def get_frag(gen):
710
+ # mols = mapper(pool)(get_mol, gen)
711
+ # kwargs = {'n_jobs': pool, 'device': device, 'batch_size': batch_size}