saicharan2804 commited on
Commit
e57187a
·
1 Parent(s): d3fd51f

documentation

Browse files
Files changed (1) hide show
  1. molgenevalmetric.py +132 -47
molgenevalmetric.py CHANGED
@@ -43,9 +43,18 @@ from fcd_torch import FCD
43
 
44
 
45
  def get_mol(smiles_or_mol):
46
- '''
47
- Loads SMILES/molecule into RDKit's object
48
- '''
 
 
 
 
 
 
 
 
 
49
  if isinstance(smiles_or_mol, str):
50
  if len(smiles_or_mol) == 0:
51
  return None
@@ -60,12 +69,17 @@ def get_mol(smiles_or_mol):
60
  return smiles_or_mol
61
 
62
  def mapper(n_jobs):
63
- '''
64
- Returns function for map call.
65
- If n_jobs == 1, will use standard map
66
- If n_jobs > 1, will use multiprocessing pool
67
- If n_jobs is a pool object, will return its map function
68
- '''
 
 
 
 
 
69
  if n_jobs == 1:
70
  def _mapper(*args, **kwargs):
71
  return list(map(*args, **kwargs))
@@ -86,16 +100,30 @@ def mapper(n_jobs):
86
 
87
  def fraction_valid(gen, n_jobs=1):
88
  """
89
- Computes a number of valid molecules
 
90
  Parameters:
91
- gen: list of SMILES
92
- n_jobs: number of threads for calculation
 
 
 
93
  """
94
  gen = mapper(n_jobs)(get_mol, gen)
95
  return 1 - gen.count(None) / len(gen)
96
 
97
 
98
  def canonic_smiles(smiles_or_mol):
 
 
 
 
 
 
 
 
 
 
99
  mol = get_mol(smiles_or_mol)
100
  if mol is None:
101
  return None
@@ -103,12 +131,16 @@ def canonic_smiles(smiles_or_mol):
103
 
104
  def fraction_unique(gen, k=None, n_jobs=1, check_validity=True):
105
  """
106
- Computes a number of unique molecules
 
107
  Parameters:
108
- gen: list of SMILES
109
- k: compute unique@k
110
- n_jobs: number of threads for calculation
111
- check_validity: raises ValueError if invalid molecules are present
 
 
 
112
  """
113
  if k is not None:
114
  if len(gen) < k:
@@ -124,6 +156,18 @@ def fraction_unique(gen, k=None, n_jobs=1, check_validity=True):
124
  return len(canonic) / len(gen)
125
 
126
  def novelty(gen, train, n_jobs=1):
 
 
 
 
 
 
 
 
 
 
 
 
127
  gen_smiles = mapper(n_jobs)(canonic_smiles, gen)
128
  gen_smiles_set = set(gen_smiles) - {None}
129
  train_set = set(train)
@@ -158,15 +202,20 @@ def average_agg_tanimoto(stock_vecs, gen_vecs,
158
  batch_size=5000, agg='max',
159
  device='cpu', p=1):
160
  """
161
- For each molecule in gen_vecs finds closest molecule in stock_vecs.
162
- Returns average tanimoto score for between these molecules
163
 
164
  Parameters:
165
- stock_vecs: numpy array <n_vectors x dim>
166
- gen_vecs: numpy array <n_vectors' x dim>
167
- agg: max or mean
168
- p: power for averaging: (mean x^p)^(1/p)
 
 
 
 
 
169
  """
 
170
  assert agg in ['max', 'mean'], "Can aggregate only max or mean"
171
  agg_tanimoto = np.zeros(len(gen_vecs))
172
  total = np.zeros(len(gen_vecs))
@@ -276,6 +325,17 @@ def internal_diversity(gen, n_jobs=1, device='cpu', fp_type='morgan',
276
  """
277
  Computes internal diversity as:
278
  1/|A|^2 sum_{x, y in AxA} (1-tanimoto(x, y))
 
 
 
 
 
 
 
 
 
 
 
279
  """
280
  if gen_fps is None:
281
  gen_fps = fingerprints(gen, fp_type=fp_type, n_jobs=n_jobs)
@@ -284,6 +344,19 @@ def internal_diversity(gen, n_jobs=1, device='cpu', fp_type='morgan',
284
 
285
 
286
  def fcd_metric(gen, train, n_jobs = 8, device = 'cuda:0'):
 
 
 
 
 
 
 
 
 
 
 
 
 
287
  fcd = FCD(device=device, n_jobs= n_jobs)
288
  return fcd(gen, train)
289
 
@@ -315,33 +388,45 @@ def fcd_metric(gen, train, n_jobs = 8, device = 'cuda:0'):
315
  # return None # Or handle empty list or all failed predictions as needed
316
 
317
  def oracles(gen, train):
318
- Result = {}
319
- # evaluator = Evaluator(name = 'KL_Divergence')
320
- # KL_Divergence = evaluator(gen, train)
321
-
322
- # Result["KL_Divergence"]: KL_Divergence
323
 
 
 
324
 
325
- oracle_list = [
326
- 'QED', 'SA', 'MPO', 'GSK3B', 'JNK3',
327
- 'DRD2', 'LogP', 'Rediscovery', 'Similarity',
328
- 'Median', 'Isomers', 'Valsartan_SMARTS', 'Hop'
329
- ]
330
-
331
- for oracle_name in oracle_list:
332
- oracle = Oracle(name=oracle_name)
333
- if oracle_name in ['Rediscovery', 'MPO', 'Similarity', 'Median', 'Isomers', 'Hop']:
334
- score = oracle(gen)
335
- if isinstance(score, dict):
336
- score = {key: sum(values)/len(values) for key, values in score.items()}
337
- else:
338
- score = oracle(gen)
339
- if isinstance(score, list):
340
- score = sum(score) / len(score)
341
 
342
- Result[f"{oracle_name}"] = score
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
343
 
344
- return Result
 
 
345
 
346
 
347
 
@@ -424,7 +509,7 @@ class molgenevalmetric(evaluate.Metric):
424
  metrics['IntDiv'] = internal_diversity(gen=gensmi)
425
  metrics['FCD'] = fcd_metric(gen = gensmi, train = trainsmi)
426
  metrics['Oracles'] = oracles(gen = gensmi, train = trainsmi)
427
-
428
  # metrics['SA'] = SAscore(gen=gensmi)
429
  # metrics['SCS'] = SAscore(gen=trainsmi)
430
 
 
43
 
44
 
45
  def get_mol(smiles_or_mol):
46
+ """
47
+ Converts a SMILES string or RDKit molecule object to an RDKit molecule object.
48
+ If the input is already an RDKit molecule object, it returns it directly.
49
+ For a SMILES string, it attempts to create an RDKit molecule object.
50
+
51
+ Parameters:
52
+ - smiles_or_mol (str or Mol): The SMILES string of the molecule or an RDKit molecule object.
53
+
54
+ Returns:
55
+ - Mol or None: The RDKit molecule object or None if conversion fails.
56
+ """
57
+
58
  if isinstance(smiles_or_mol, str):
59
  if len(smiles_or_mol) == 0:
60
  return None
 
69
  return smiles_or_mol
70
 
71
  def mapper(n_jobs):
72
+ """
73
+ Returns a mapping function suitable for parallel or sequential execution
74
+ based on the value of n_jobs.
75
+
76
+ Parameters:
77
+ - n_jobs (int or Pool): Number of jobs for parallel execution or a multiprocessing Pool object.
78
+
79
+ Returns:
80
+ - Function: A mapping function that can be used for applying a function over a sequence.
81
+ """
82
+
83
  if n_jobs == 1:
84
  def _mapper(*args, **kwargs):
85
  return list(map(*args, **kwargs))
 
100
 
101
  def fraction_valid(gen, n_jobs=1):
102
  """
103
+ Calculates the fraction of valid molecules in a list of SMILES strings.
104
+
105
  Parameters:
106
+ - gen (list of str): List of SMILES strings.
107
+ - n_jobs (int): Number of parallel jobs to use for computation.
108
+
109
+ Returns:
110
+ - float: Fraction of valid molecules.
111
  """
112
  gen = mapper(n_jobs)(get_mol, gen)
113
  return 1 - gen.count(None) / len(gen)
114
 
115
 
116
  def canonic_smiles(smiles_or_mol):
117
+ """
118
+ Converts a molecule into its canonical SMILES representation.
119
+
120
+ Parameters:
121
+ - smiles_or_mol (str or Mol): SMILES string or RDKit molecule object.
122
+
123
+ Returns:
124
+ - str or None: Canonical SMILES string, or None if conversion fails.
125
+ """
126
+
127
  mol = get_mol(smiles_or_mol)
128
  if mol is None:
129
  return None
 
131
 
132
  def fraction_unique(gen, k=None, n_jobs=1, check_validity=True):
133
  """
134
+ Calculates the fraction of unique molecules in a list of SMILES strings.
135
+
136
  Parameters:
137
+ - gen (list of str): List of SMILES strings.
138
+ - k (int, optional): Number of top molecules to consider for uniqueness. If None, considers all.
139
+ - n_jobs (int): Number of parallel jobs to use for computation.
140
+ - check_validity (bool): If True, checks for the validity of molecules.
141
+
142
+ Returns:
143
+ - float: Fraction of unique molecules.
144
  """
145
  if k is not None:
146
  if len(gen) < k:
 
156
  return len(canonic) / len(gen)
157
 
158
  def novelty(gen, train, n_jobs=1):
159
+ """
160
+ Computes the novelty of generated molecules compared to a training set.
161
+
162
+ Parameters:
163
+ - gen (List[str]): List of generated SMILES strings.
164
+ - train (List[str]): List of SMILES strings from the training set.
165
+ - n_jobs (int): Number of parallel jobs to use for computation.
166
+
167
+ Returns:
168
+ - float: Novelty score.
169
+ """
170
+
171
  gen_smiles = mapper(n_jobs)(canonic_smiles, gen)
172
  gen_smiles_set = set(gen_smiles) - {None}
173
  train_set = set(train)
 
202
  batch_size=5000, agg='max',
203
  device='cpu', p=1):
204
  """
205
+ Calculates the average aggregate Tanimoto similarity between two sets of molecule fingerprints.
 
206
 
207
  Parameters:
208
+ - stock_vecs (numpy array): Fingerprint vectors for the reference molecule set.
209
+ - gen_vecs (numpy array): Fingerprint vectors for the generated molecule set.
210
+ - batch_size (int): The size of batches to process similarities (reduces memory usage).
211
+ - agg (str): Aggregation method, either 'max' or 'mean'.
212
+ - device (str): The computation device ('cpu' or 'cuda:0', etc.).
213
+ - p (float): The power for averaging, used in generalized mean calculation.
214
+
215
+ Returns:
216
+ - float: Average aggregate Tanimoto similarity score.
217
  """
218
+
219
  assert agg in ['max', 'mean'], "Can aggregate only max or mean"
220
  agg_tanimoto = np.zeros(len(gen_vecs))
221
  total = np.zeros(len(gen_vecs))
 
325
  """
326
  Computes internal diversity as:
327
  1/|A|^2 sum_{x, y in AxA} (1-tanimoto(x, y))
328
+
329
+ Parameters:
330
+ - gen (List[str]): List of generated SMILES strings.
331
+ - n_jobs (int): Number of parallel jobs for fingerprint computation.
332
+ - device (str): Computation device ('cpu' or 'cuda:0', etc.).
333
+ - fp_type (str): Type of fingerprint to use ('morgan', etc.).
334
+ - gen_fps (Optional[np.ndarray]): Precomputed fingerprints of generated molecules. If None, will be computed.
335
+
336
+ Returns:
337
+ - float: Internal diversity score.
338
+
339
  """
340
  if gen_fps is None:
341
  gen_fps = fingerprints(gen, fp_type=fp_type, n_jobs=n_jobs)
 
344
 
345
 
346
  def fcd_metric(gen, train, n_jobs = 8, device = 'cuda:0'):
347
+ """
348
+ Computes the Fréchet ChemNet Distance (FCD) between two sets of molecules.
349
+
350
+ Parameters:
351
+ - gen (List[str]): List of generated SMILES strings.
352
+ - train (List[str]): List of training set SMILES strings.
353
+ - n_jobs (int): Number of parallel jobs for computation.
354
+ - device (str): Computation device for the FCD calculation.
355
+
356
+ Returns:
357
+ - float: FCD score.
358
+ """
359
+
360
  fcd = FCD(device=device, n_jobs= n_jobs)
361
  return fcd(gen, train)
362
 
 
388
  # return None # Or handle empty list or all failed predictions as needed
389
 
390
  def oracles(gen, train):
 
 
 
 
 
391
 
392
+ """
393
+ Computes scores from various oracles for a list of generated molecules.
394
 
395
+ Parameters:
396
+ - gen (List[str]): List of generated SMILES strings.
397
+ - train (List[str]): List of training set SMILES strings.
398
+
399
+ Returns:
400
+ - Dict[str, Any]: A dictionary with oracle names as keys and their corresponding scores as values.
401
+ """
402
+
403
+ Result = {}
404
+ # evaluator = Evaluator(name = 'KL_Divergence')
405
+ # KL_Divergence = evaluator(gen, train)
 
 
 
 
 
406
 
407
+ # Result["KL_Divergence"]: KL_Divergence
408
+
409
+
410
+ oracle_list = [
411
+ 'QED', 'SA', 'MPO', 'GSK3B', 'JNK3',
412
+ 'DRD2', 'LogP', 'Rediscovery', 'Similarity',
413
+ 'Median', 'Isomers', 'Valsartan_SMARTS', 'Hop'
414
+ ]
415
+
416
+ for oracle_name in oracle_list:
417
+ oracle = Oracle(name=oracle_name)
418
+ if oracle_name in ['Rediscovery', 'MPO', 'Similarity', 'Median', 'Isomers', 'Hop']:
419
+ score = oracle(gen)
420
+ if isinstance(score, dict):
421
+ score = {key: sum(values)/len(values) for key, values in score.items()}
422
+ else:
423
+ score = oracle(gen)
424
+ if isinstance(score, list):
425
+ score = sum(score) / len(score)
426
 
427
+ Result[f"{oracle_name}"] = score
428
+
429
+ return Result
430
 
431
 
432
 
 
509
  metrics['IntDiv'] = internal_diversity(gen=gensmi)
510
  metrics['FCD'] = fcd_metric(gen = gensmi, train = trainsmi)
511
  metrics['Oracles'] = oracles(gen = gensmi, train = trainsmi)
512
+
513
  # metrics['SA'] = SAscore(gen=gensmi)
514
  # metrics['SCS'] = SAscore(gen=trainsmi)
515