Spaces:
Sleeping
Sleeping
saicharan2804
committed on
Commit
·
0f253ff
1
Parent(s):
389f170
Full implementations
Browse files- metrics.py +335 -0
- molgenevalmetric.py +42 -31
- requirements.txt +6 -2
- scscore +1 -0
- syba_test.py +34 -0
metrics.py
ADDED
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from collections import Counter
|
3 |
+
from functools import partial
|
4 |
+
import numpy as np
|
5 |
+
import pandas as pd
|
6 |
+
import scipy.sparse
|
7 |
+
import torch
|
8 |
+
from rdkit import Chem
|
9 |
+
from rdkit.Chem import AllChem
|
10 |
+
from rdkit.Chem import MACCSkeys
|
11 |
+
from rdkit.Chem.AllChem import GetMorganFingerprintAsBitVect as Morgan
|
12 |
+
from rdkit.Chem.QED import qed
|
13 |
+
from rdkit.Chem.Scaffolds import MurckoScaffold
|
14 |
+
from rdkit.Chem import Descriptors
|
15 |
+
|
16 |
+
import random
|
17 |
+
from multiprocessing import Pool
|
18 |
+
from collections import UserList, defaultdict
|
19 |
+
import numpy as np
|
20 |
+
import pandas as pd
|
21 |
+
from rdkit import rdBase
|
22 |
+
import sys
|
23 |
+
|
24 |
+
from rdkit.Chem import RDConfig
|
25 |
+
import os
|
26 |
+
sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
|
27 |
+
import sascorer
|
28 |
+
import pandas as pd
|
29 |
+
from fcd_torch import FCD
|
30 |
+
from syba.syba import SybaClassifier
|
31 |
+
|
32 |
+
from tdc import Evaluator
|
33 |
+
from tdc import Oracle
|
34 |
+
|
35 |
+
|
36 |
+
def get_mol(smiles_or_mol):
    '''
    Turn a SMILES string into an RDKit molecule object.

    Anything that is not a ``str`` is handed back unchanged. For a string,
    ``None`` is returned when it is empty, when RDKit cannot parse it, or
    when sanitization raises ``ValueError``.
    '''
    # Non-string inputs are passed straight through.
    if not isinstance(smiles_or_mol, str):
        return smiles_or_mol
    if len(smiles_or_mol) == 0:
        return None
    parsed = Chem.MolFromSmiles(smiles_or_mol)
    if parsed is None:
        return None
    try:
        Chem.SanitizeMol(parsed)
    except ValueError:
        return None
    return parsed
|
52 |
+
|
53 |
+
def mapper(n_jobs):
    '''
    Build a map-like callable according to ``n_jobs``.

    * ``n_jobs == 1``: wraps the builtin ``map`` and returns a list.
    * any other ``int``: creates a ``multiprocessing.Pool`` of that size;
      the returned callable runs ``pool.map`` and always terminates the
      pool afterwards, so it can only be used for a single call.
    * otherwise ``n_jobs`` is treated as a pool-like object and its own
      ``map`` method is returned.
    '''
    if n_jobs == 1:
        def _serial_map(*args, **kwargs):
            return list(map(*args, **kwargs))

        return _serial_map

    if not isinstance(n_jobs, int):
        # Caller supplied a pool-like object; reuse its map method.
        return n_jobs.map

    workers = Pool(n_jobs)

    def _pool_map(*args, **kwargs):
        try:
            return workers.map(*args, **kwargs)
        finally:
            # The pool is torn down whether or not the map succeeded.
            workers.terminate()

    return _pool_map
|
77 |
+
|
78 |
+
def fraction_valid(gen, n_jobs=1):
    """
    Computes the fraction of valid molecules
    Parameters:
        gen: list of SMILES
        n_jobs: forwarded to mapper() (1, a worker count, or a pool object)
    """
    mols = mapper(n_jobs)(get_mol, gen)
    # get_mol yields None for every entry it could not load.
    n_invalid = mols.count(None)
    return 1 - n_invalid / len(mols)
|
87 |
+
|
88 |
+
|
89 |
+
def canonic_smiles(smiles_or_mol):
    """
    Return the SMILES string RDKit writes for the given SMILES/molecule,
    or None when get_mol() cannot load it.
    """
    loaded = get_mol(smiles_or_mol)
    return None if loaded is None else Chem.MolToSmiles(loaded)
|
94 |
+
|
95 |
+
def fraction_unique(gen, k=None, n_jobs=1, check_validity=True):
    """
    Computes the fraction of unique molecules
    Parameters:
        gen: list of SMILES
        k: compute unique@k (only the first k entries of gen are used)
        n_jobs: forwarded to mapper() (1, a worker count, or a pool object)
        check_validity: raises ValueError if invalid molecules are present
    Raises:
        ValueError: if check_validity is set and at least one entry of gen
            could not be converted to a canonical SMILES
    """
    # Imported locally: the module's import block does not provide
    # `warnings`, so the warning path used to fail with NameError.
    import warnings

    if k is not None:
        if len(gen) < k:
            warnings.warn(
                "Can't compute unique@{}. ".format(k) +
                "gen contains only {} molecules".format(len(gen))
            )
        gen = gen[:k]
    canonic = set(mapper(n_jobs)(canonic_smiles, gen))

    if None in canonic and check_validity:
        raise ValueError("Invalid molecule passed to unique@k")
    return len(canonic) / len(gen)
|
116 |
+
|
117 |
+
def novelty(gen, train, n_jobs=1):
    """
    Fraction of distinct valid generated molecules that are absent from train.

    Generated SMILES are canonicalised with canonic_smiles(); entries of
    `train` are compared as given (they are not canonicalised here).
    """
    canonical = mapper(n_jobs)(canonic_smiles, gen)
    # Drop the None placeholder produced for invalid SMILES.
    valid_generated = set(canonical) - {None}
    known = set(train)
    unseen = valid_generated - known
    return len(unseen) / len(valid_generated)
|
122 |
+
|
123 |
+
|
124 |
+
def SAscore(gen):
    """
    Average the sascorer score over the parsable SMILES strings in `gen`.

    Parameters:
    - gen (iterable of str): SMILES representations of the molecules.

    Returns:
    - The mean of sascorer.calculateScore over all molecules RDKit could
      parse, or None if no molecule could be parsed.
    """
    # Lazily parse each SMILES; unparsable entries are falsy and are skipped.
    parsed = (Chem.MolFromSmiles(smiles) for smiles in gen)
    sa_values = [sascorer.calculateScore(mol) for mol in parsed if mol]
    return np.mean(sa_values) if sa_values else None
|
145 |
+
|
146 |
+
|
147 |
+
|
148 |
+
def average_agg_tanimoto(stock_vecs, gen_vecs,
                         batch_size=5000, agg='max',
                         device='cpu', p=1):
    """
    Aggregates, for every row of gen_vecs, its Tanimoto similarity to the
    rows of stock_vecs (max or mean) and returns the average over gen_vecs.

    Parameters:
        stock_vecs: numpy array <n_vectors x dim>
        gen_vecs: numpy array <n_vectors' x dim>
        batch_size: number of rows processed per block on each side
        agg: max or mean
        device: torch device the block products are computed on
        p: power for averaging: (mean x^p)^(1/p)
    """
    assert agg in ['max', 'mean'], "Can aggregate only max or mean"
    n_gen = len(gen_vecs)
    scores = np.zeros(n_gen)
    counts = np.zeros(n_gen)
    for s_start in range(0, stock_vecs.shape[0], batch_size):
        stock_block = torch.tensor(
            stock_vecs[s_start:s_start + batch_size]).to(device).float()
        stock_bits = stock_block.sum(1, keepdim=True)
        for g_start in range(0, gen_vecs.shape[0], batch_size):
            gen_block = torch.tensor(
                gen_vecs[g_start:g_start + batch_size]).to(device).float()
            gen_block = gen_block.transpose(0, 1)
            common = torch.mm(stock_block, gen_block)
            union = stock_bits + gen_block.sum(0, keepdim=True) - common
            sim = (common / union).cpu().numpy()
            # 0/0 (both vectors empty) is counted as similarity 1.
            sim[np.isnan(sim)] = 1
            if p != 1:
                sim = sim**p
            window = slice(g_start, g_start + gen_block.shape[1])
            if agg == 'max':
                scores[window] = np.maximum(scores[window], sim.max(0))
            elif agg == 'mean':
                scores[window] += sim.sum(0)
                counts[window] += sim.shape[0]
    if agg == 'mean':
        scores /= counts
    if p != 1:
        scores = (scores)**(1/p)
    return np.mean(scores)
|
186 |
+
|
187 |
+
def fingerprint(smiles_or_mol, fp_type='maccs', dtype=None, morgan__r=2,
                morgan__n=1024, *args, **kwargs):
    """
    Generates a fingerprint for a SMILES string or molecule.
    Returns None if get_mol() cannot load the input, otherwise a numpy
    array of fingerprint bits.

    Parameters:
        smiles_or_mol: SMILES string or molecule object
        fp_type: type of fingerprint: [MACCS|morgan] (case-insensitive)
        dtype: if not None, specifies the dtype of returned array
        morgan__r: radius passed to the Morgan fingerprint
        morgan__n: number of bits of the Morgan fingerprint
    Raises:
        ValueError: for an unknown fp_type (only reached for loadable input)
    """
    kind = fp_type.lower()
    mol = get_mol(smiles_or_mol, *args, **kwargs)
    if mol is None:
        return None
    if kind == 'maccs':
        on_bits = np.array(MACCSkeys.GenMACCSKeys(mol).GetOnBits())
        bits = np.zeros(166, dtype='uint8')
        if len(on_bits) != 0:
            bits[on_bits - 1] = 1  # We drop 0-th key that is always zero
    elif kind == 'morgan':
        bits = np.asarray(Morgan(mol, morgan__r, nBits=morgan__n),
                          dtype='uint8')
    else:
        raise ValueError("Unknown fingerprint type {}".format(kind))
    return bits if dtype is None else bits.astype(dtype)
|
217 |
+
|
218 |
+
|
219 |
+
def fingerprints(smiles_mols_array, n_jobs=1, already_unique=False, *args,
                 **kwargs):
    '''
    Computes fingerprints of smiles np.array/list/pd.Series with n_jobs workers
    e.g.fingerprints(smiles_mols_array, fp_type='morgan', n_jobs=10)
    Inserts np.nan to rows corresponding to incorrect smiles.
    IMPORTANT: if there is at least one np.nan, the dtype would be float
    Parameters:
        smiles_mols_array: list/array/pd.Series of smiles or already computed
            RDKit molecules (must be non-empty: the first element is
            inspected to decide whether the input holds strings)
        n_jobs: number of parallel workers to execute
        already_unique: flag for performance reasons, if smiles array is big
            and already unique. Its value is set to True if smiles_mols_array
            contain RDKit molecules already.
        *args, **kwargs: forwarded to fingerprint()
    Returns:
        Stacked fingerprints, one row per input element (a CSR matrix when
        the fingerprints are scipy sparse, a numpy array otherwise).
    '''
    if isinstance(smiles_mols_array, pd.Series):
        smiles_mols_array = smiles_mols_array.values
    else:
        smiles_mols_array = np.asarray(smiles_mols_array)
    if not isinstance(smiles_mols_array[0], str):
        already_unique = True

    if not already_unique:
        # Fingerprint each distinct SMILES once; inv_index restores the
        # original order and multiplicity at the end.
        smiles_mols_array, inv_index = np.unique(smiles_mols_array,
                                                 return_inverse=True)

    fps = mapper(n_jobs)(
        partial(fingerprint, *args, **kwargs), smiles_mols_array
    )

    # Find the first successfully computed fingerprint to learn the row
    # width. first_fp stays None when every input was invalid, so the
    # sparse check below no longer hits an unbound local.
    length = 1
    first_fp = None
    for fp in fps:
        if fp is not None:
            length = fp.shape[-1]
            first_fp = fp
            break
    # np.nan instead of np.NaN: the capitalised alias was removed in NumPy 2.
    fps = [fp if fp is not None else np.array([np.nan]).repeat(length)[None, :]
           for fp in fps]
    if scipy.sparse.issparse(first_fp):
        fps = scipy.sparse.vstack(fps).tocsr()
    else:
        fps = np.vstack(fps)
    if not already_unique:
        return fps[inv_index]
    return fps
|
264 |
+
|
265 |
+
def internal_diversity(gen, n_jobs=1, device='cpu', fp_type='morgan',
                       gen_fps=None, p=1):
    """
    Computes internal diversity as:
    1/|A|^2 sum_{x, y in AxA} (1-tanimoto(x, y))

    Fingerprints are computed from `gen` with fingerprints() unless
    precomputed ones are supplied through `gen_fps`.
    """
    if gen_fps is None:
        gen_fps = fingerprints(gen, fp_type=fp_type, n_jobs=n_jobs)
    mean_similarity = average_agg_tanimoto(
        gen_fps, gen_fps, agg='mean', device=device, p=p
    )
    return 1 - mean_similarity.mean()
|
275 |
+
|
276 |
+
|
277 |
+
def fcd_metric(gen, train, n_jobs = 8, device = 'cuda:0'):
    """
    Compute the FCD metric (fcd_torch.FCD) between `gen` and `train`.

    Parameters:
        gen: generated SMILES, passed as first argument to the FCD callable
        train: reference SMILES, passed as second argument
        n_jobs: forwarded to FCD
        device: torch device forwarded to FCD; a CUDA device is replaced by
            'cpu' when CUDA is not available on this host

    Returns:
        Whatever the FCD callable returns for (gen, train).
    """
    # The default device is CUDA; on a CPU-only host that would fail, so
    # fall back to the CPU instead.
    if str(device).startswith('cuda') and not torch.cuda.is_available():
        device = 'cpu'
    fcd = FCD(device=device, n_jobs=n_jobs)
    return fcd(gen, train)
|
280 |
+
|
281 |
+
def SYBAscore(gen):
    """
    Average the SYBA prediction over a collection of SMILES strings.

    Parameters:
    - gen (iterable of str): SMILES strings representing molecules.

    Returns:
    - The mean of SybaClassifier.predict over the SMILES that could be
      scored, or None if none could. A SMILES whose prediction raises is
      reported on stdout and skipped.
    """
    # A fresh classifier is built and fitted on every call.
    classifier = SybaClassifier()
    classifier.fitDefaultScore()

    collected = []
    for smiles in gen:
        try:
            collected.append(classifier.predict(smi=smiles))
        except Exception as e:
            print(f"Error processing SMILES '{smiles}': {e}")

    if not collected:
        return None  # empty input or every prediction failed
    return sum(collected) / len(collected)
|
307 |
+
|
308 |
+
def oracles(gen, train):
    """
    Evaluate `gen` with a fixed list of TDC oracles.

    Parameters:
        gen: generated SMILES handed to every oracle
        train: currently unused (the KL_Divergence evaluator that would
            need it is disabled)

    Returns:
        dict mapping oracle name to its score. For the names in
        `dict_averaged` a dict result is reduced to per-key means; for all
        other names a list result is reduced to its mean. Any other result
        type is stored unchanged.
    """
    # Oracles whose dict results are averaged key by key.
    dict_averaged = ['Rediscovery', 'MPO', 'Similarity', 'Median', 'Isomers', 'Hop']
    oracle_names = [
        'QED', 'SA', 'MPO', 'GSK3B', 'JNK3',
        'DRD2', 'LogP', 'Rediscovery', 'Similarity',
        'Median', 'Isomers', 'Valsartan_SMARTS', 'Hop'
    ]

    Result = {}
    for oracle_name in oracle_names:
        scorer = Oracle(name=oracle_name)
        score = scorer(gen)
        if oracle_name in dict_averaged:
            if isinstance(score, dict):
                score = {key: sum(values)/len(values) for key, values in score.items()}
        elif isinstance(score, list):
            score = sum(score) / len(score)

        Result[f"{oracle_name}"] = score

    return Result
|
molgenevalmetric.py
CHANGED
@@ -1,10 +1,11 @@
|
|
1 |
import evaluate
|
2 |
import datasets
|
3 |
# import moses
|
4 |
-
from moses import metrics
|
5 |
import pandas as pd
|
6 |
-
from tdc import Evaluator
|
7 |
-
from tdc import Oracle
|
|
|
8 |
|
9 |
|
10 |
_DESCRIPTION = """
|
@@ -77,42 +78,52 @@ class molgenevalmetric(evaluate.Metric):
|
|
77 |
reference_urls=["https://github.com/molecularsets/moses", "https://tdcommons.ai/functions/oracles/"],
|
78 |
)
|
79 |
|
80 |
-
def _compute(self,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
|
82 |
-
Results = metrics.get_all_metrics(gen = generated_smiles, train= train_smiles)
|
83 |
|
84 |
-
generated_smiles = [s for s in generated_smiles if s != '']
|
85 |
|
86 |
-
evaluator = Evaluator(name = 'KL_Divergence')
|
87 |
-
KL_Divergence = evaluator(generated_smiles, train_smiles)
|
88 |
|
89 |
-
Results.update({
|
90 |
-
|
91 |
-
})
|
92 |
|
93 |
|
94 |
-
oracle_list = [
|
95 |
-
'QED', 'SA', 'MPO', 'GSK3B', 'JNK3',
|
96 |
-
'DRD2', 'LogP', 'Rediscovery', 'Similarity',
|
97 |
-
'Median', 'Isomers', 'Valsartan_SMARTS', 'Hop'
|
98 |
-
]
|
99 |
|
100 |
-
for oracle_name in oracle_list:
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
|
111 |
-
|
112 |
|
113 |
-
keys_to_remove = ["FCD/TestSF", "SNN/TestSF", "Frag/TestSF", "Scaf/TestSF"]
|
114 |
-
for key in keys_to_remove:
|
115 |
-
|
116 |
|
117 |
-
return {"results": Results}
|
118 |
|
|
|
1 |
import evaluate
|
2 |
import datasets
|
3 |
# import moses
|
4 |
+
# from moses import metrics
|
5 |
import pandas as pd
|
6 |
+
# from tdc import Evaluator
|
7 |
+
# from tdc import Oracle
|
8 |
+
from metrics import novelty, fraction_valid, fraction_unique, SAscore, internal_diversity,fcd_metric, SYBAscore, oracles
|
9 |
|
10 |
|
11 |
_DESCRIPTION = """
|
|
|
78 |
reference_urls=["https://github.com/molecularsets/moses", "https://tdcommons.ai/functions/oracles/"],
|
79 |
)
|
80 |
|
81 |
+
def _compute(self, gensmi, trainsmi):
    """
    Compute the metric dictionary for generated vs. training SMILES.

    Parameters:
        gensmi: generated SMILES
        trainsmi: training/reference SMILES

    Returns:
        dict with keys 'novelty', 'valid', 'unique', 'IntDiv', 'FCD',
        'SA' and 'SCS'.
    """
    # NOTE(review): 'SCS' is filled by SAscore on the *training* SMILES,
    # not by an SCScore computation — confirm whether this is intended.
    return {
        'novelty': novelty(gen=gensmi, train=trainsmi),
        'valid': fraction_valid(gen=gensmi),
        'unique': fraction_unique(gen=gensmi),
        'IntDiv': internal_diversity(gen=gensmi),
        'FCD': fcd_metric(gen=gensmi, train=trainsmi),
        'SA': SAscore(gen=gensmi),
        'SCS': SAscore(gen=trainsmi),
    }
|
93 |
|
|
|
94 |
|
95 |
+
# generated_smiles = [s for s in generated_smiles if s != '']
|
96 |
|
97 |
+
# evaluator = Evaluator(name = 'KL_Divergence')
|
98 |
+
# KL_Divergence = evaluator(generated_smiles, train_smiles)
|
99 |
|
100 |
+
# Results.update({
|
101 |
+
# "KL_Divergence": KL_Divergence,
|
102 |
+
# })
|
103 |
|
104 |
|
105 |
+
# oracle_list = [
|
106 |
+
# 'QED', 'SA', 'MPO', 'GSK3B', 'JNK3',
|
107 |
+
# 'DRD2', 'LogP', 'Rediscovery', 'Similarity',
|
108 |
+
# 'Median', 'Isomers', 'Valsartan_SMARTS', 'Hop'
|
109 |
+
# ]
|
110 |
|
111 |
+
# for oracle_name in oracle_list:
|
112 |
+
# oracle = Oracle(name=oracle_name)
|
113 |
+
# if oracle_name in ['Rediscovery', 'MPO', 'Similarity', 'Median', 'Isomers', 'Hop']:
|
114 |
+
# score = oracle(generated_smiles)
|
115 |
+
# if isinstance(score, dict):
|
116 |
+
# score = {key: sum(values)/len(values) for key, values in score.items()}
|
117 |
+
# else:
|
118 |
+
# score = oracle(generated_smiles)
|
119 |
+
# if isinstance(score, list):
|
120 |
+
# score = sum(score) / len(score)
|
121 |
|
122 |
+
# Results.update({f"{oracle_name}": score})
|
123 |
|
124 |
+
# # keys_to_remove = ["FCD/TestSF", "SNN/TestSF", "Frag/TestSF", "Scaf/TestSF"]
|
125 |
+
# # for key in keys_to_remove:
|
126 |
+
# # Results.pop(key, None)
|
127 |
|
128 |
+
# return {"results": Results}
|
129 |
|
requirements.txt
CHANGED
@@ -1,5 +1,9 @@
|
|
1 |
git+https://github.com/huggingface/evaluate@main
|
2 |
git+https://github.com/molecularsets/moses.git
|
|
|
|
|
|
|
|
|
3 |
rdkit
|
4 |
-
|
5 |
-
|
|
|
1 |
git+https://github.com/huggingface/evaluate@main
|
2 |
git+https://github.com/molecularsets/moses.git
|
3 |
+
numpy
|
4 |
+
pandas
|
5 |
+
scipy
|
6 |
+
torch
|
7 |
rdkit
|
8 |
+
pyarrow
|
9 |
+
fcd-torch
|
scscore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Subproject commit ba64e5a524ccc7d9b2000b34b50d8372178ca5d6
|
syba_test.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from syba.syba import SybaClassifier
|
2 |
+
|
3 |
+
def SYBAscore(smiles_list):
    """
    Average the SYBA prediction over a collection of SMILES strings.

    Parameters:
    - smiles_list (iterable of str): SMILES strings representing molecules.

    Returns:
    - The mean of SybaClassifier.predict over the SMILES that could be
      scored, or None if none could. A SMILES whose prediction raises is
      reported on stdout and skipped.
    """
    # A fresh classifier is built and fitted on every call.
    classifier = SybaClassifier()
    classifier.fitDefaultScore()

    collected = []
    for smiles in smiles_list:
        try:
            collected.append(classifier.predict(smi=smiles))
        except Exception as e:
            print(f"Error processing SMILES '{smiles}': {e}")

    if not collected:
        return None  # empty input or every prediction failed
    return sum(collected) / len(collected)
|
29 |
+
|
30 |
+
|
31 |
+
# Manual smoke test, executed at import time: build a classifier, call
# fitDefaultScore(), and print the prediction for a single SMILES string.
syba = SybaClassifier()
syba.fitDefaultScore()
smi = "O=C(C)Oc1ccccc1C(=O)O"  # aspirin (acetylsalicylic acid)
print(syba.predict(smi))
|