Spaces:
Sleeping
Sleeping
File size: 10,971 Bytes
0f253ff |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 |
import os
from collections import Counter
from functools import partial
import numpy as np
import pandas as pd
import scipy.sparse
import torch
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import MACCSkeys
from rdkit.Chem.AllChem import GetMorganFingerprintAsBitVect as Morgan
from rdkit.Chem.QED import qed
from rdkit.Chem.Scaffolds import MurckoScaffold
from rdkit.Chem import Descriptors
import random
from multiprocessing import Pool
from collections import UserList, defaultdict
import numpy as np
import pandas as pd
from rdkit import rdBase
import sys
from rdkit.Chem import RDConfig
import os
sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
import sascorer
import pandas as pd
from fcd_torch import FCD
from syba.syba import SybaClassifier
from tdc import Evaluator
from tdc import Oracle
def get_mol(smiles_or_mol):
    '''
    Turn a SMILES string into an RDKit molecule.

    Anything that is not a string is assumed to already be a molecule (or
    None) and is handed back untouched.

    Parameters:
        smiles_or_mol: SMILES string or an already-built RDKit molecule
    Returns:
        RDKit molecule, or None if the string is empty, does not parse, or
        fails Chem.SanitizeMol with a ValueError.
    '''
    if not isinstance(smiles_or_mol, str):
        return smiles_or_mol
    if not smiles_or_mol:
        return None
    mol = Chem.MolFromSmiles(smiles_or_mol)
    if mol is None:
        return None
    try:
        Chem.SanitizeMol(mol)
    except ValueError:
        return None
    return mol
def mapper(n_jobs):
    '''
    Returns a map-like callable ``f(func, iterable) -> list``.

    If n_jobs == 1, uses the built-in map in the current process.
    If n_jobs is any other int, uses a multiprocessing.Pool with n_jobs
    workers; a fresh pool is created for every call and shut down when the
    call finishes, so the returned callable can be reused.
    Otherwise n_jobs is treated as a pool-like object and its ``map`` method
    is returned directly.
    '''
    if n_jobs == 1:
        def _mapper(*args, **kwargs):
            return list(map(*args, **kwargs))
        return _mapper
    if isinstance(n_jobs, int):
        def _mapper(*args, **kwargs):
            # The pool is built lazily, per call. Building it eagerly and
            # terminating it after the first call made the mapper single-use
            # ("Pool not running" on the second call) and leaked workers if
            # the mapper was never called. ``with Pool`` terminates on exit.
            with Pool(n_jobs) as pool:
                return pool.map(*args, **kwargs)
        return _mapper
    return n_jobs.map
def fraction_valid(gen, n_jobs=1):
    """
    Share of entries in gen that get_mol turns into a molecule.

    Parameters:
        gen: list of SMILES
        n_jobs: number of parallel workers, or a pool object (see mapper)
    Returns:
        float in [0, 1]; raises ZeroDivisionError when gen is empty.
    """
    parsed = mapper(n_jobs)(get_mol, gen)
    n_invalid = parsed.count(None)
    return 1 - n_invalid / len(parsed)
def canonic_smiles(smiles_or_mol):
    """
    Canonical SMILES (Chem.MolToSmiles) for a SMILES string or molecule.

    Returns None when get_mol cannot build a molecule from the input.
    """
    mol = get_mol(smiles_or_mol)
    return None if mol is None else Chem.MolToSmiles(mol)
def fraction_unique(gen, k=None, n_jobs=1, check_validity=True):
    """
    Computes the fraction of unique molecules (unique@k).

    Parameters:
        gen: list of SMILES
        k: if given, only gen[:k] is considered; a warning is emitted when
           gen holds fewer than k entries
        n_jobs: number of parallel workers, or a pool object (see mapper)
        check_validity: if True, raise ValueError when any entry is invalid.
           If False, all invalid entries collapse into a single None that is
           counted as one unique item.
    Returns:
        Number of distinct canonical SMILES divided by the number of
        (possibly truncated) entries.
    Raises:
        ValueError: invalid molecule present and check_validity is True.
    """
    # Imported locally: `warnings` is not among the module-level imports,
    # so the short-input branch previously raised NameError.
    import warnings

    if k is not None:
        if len(gen) < k:
            warnings.warn(
                "Can't compute unique@{}. ".format(k) +
                "gen contains only {} molecules".format(len(gen))
            )
        gen = gen[:k]
    canonic = set(mapper(n_jobs)(canonic_smiles, gen))
    if None in canonic and check_validity:
        raise ValueError("Invalid molecule passed to unique@k")
    return len(canonic) / len(gen)
def novelty(gen, train, n_jobs=1):
    """
    Fraction of distinct valid generated molecules that are absent from train.

    gen entries are canonicalized with canonic_smiles and invalid ones are
    dropped. train is compared as-is, without canonicalization; presumably it
    must already hold canonical SMILES for the comparison to be meaningful
    (TODO confirm with callers).

    Raises ZeroDivisionError when gen contains no valid molecule.
    """
    canonical_gen = {
        smi for smi in mapper(n_jobs)(canonic_smiles, gen) if smi is not None
    }
    novel = canonical_gen.difference(train)
    return len(novel) / len(canonical_gen)
def SAscore(gen):
    """
    Average Synthetic Accessibility Score (SAscore) over the valid molecules in gen.

    Parameters:
    - gen (iterable of str or RDKit molecules): SMILES strings, or
      already-parsed molecules, to score.
    Returns:
    - float: mean of sascorer.calculateScore over entries that get_mol could
      turn into a molecule, or None if no entry was valid.
    """
    scores = []
    for smiles in gen:
        # get_mol (shared with the other metrics) rejects empty strings, which
        # MolFromSmiles would turn into an atom-less molecule, and passes
        # pre-built molecules through.
        mol = get_mol(smiles)
        if mol is not None:
            scores.append(sascorer.calculateScore(mol))
    return np.mean(scores) if scores else None
def average_agg_tanimoto(stock_vecs, gen_vecs,
                         batch_size=5000, agg='max',
                         device='cpu', p=1):
    """
    For every row of gen_vecs, aggregate its Tanimoto similarity to all rows
    of stock_vecs, then return the mean of those per-row aggregates.

    Parameters:
        stock_vecs: numpy array <n_vectors x dim>
        gen_vecs: numpy array <n_vectors' x dim>
        batch_size: rows per block on each side of the similarity matrix
        agg: 'max' (closest stock vector) or 'mean' (average over stock)
        device: torch device for the matrix products
        p: power for averaging: (mean x^p)^(1/p)

    A pair of all-zero vectors (0/0) is given similarity 1.
    """
    assert agg in ['max', 'mean'], "Can aggregate only max or mean"
    n_gen = len(gen_vecs)
    scores = np.zeros(n_gen)
    counts = np.zeros(n_gen)
    for s0 in range(0, stock_vecs.shape[0], batch_size):
        stock = torch.tensor(stock_vecs[s0:s0 + batch_size]).to(device).float()
        stock_bits = stock.sum(1, keepdim=True)
        for g0 in range(0, gen_vecs.shape[0], batch_size):
            gen_t = torch.tensor(gen_vecs[g0:g0 + batch_size]).to(device).float()
            gen_t = gen_t.transpose(0, 1)
            inter = torch.mm(stock, gen_t)
            union = stock_bits + gen_t.sum(0, keepdim=True) - inter
            sim = (inter / union).cpu().numpy()
            sim[np.isnan(sim)] = 1
            if p != 1:
                sim = sim ** p
            cols = slice(g0, g0 + gen_t.shape[1])
            if agg == 'max':
                scores[cols] = np.maximum(scores[cols], sim.max(0))
            elif agg == 'mean':
                scores[cols] += sim.sum(0)
                counts[cols] += sim.shape[0]
    if agg == 'mean':
        scores /= counts
    if p != 1:
        scores = scores ** (1 / p)
    return np.mean(scores)
def fingerprint(smiles_or_mol, fp_type='maccs', dtype=None, morgan__r=2,
                morgan__n=1024, *args, **kwargs):
    """
    Binary fingerprint of a single SMILES string or RDKit molecule.

    Parameters:
        smiles_or_mol: SMILES string or RDKit molecule
        fp_type: 'maccs' (166 bits) or 'morgan' (case-insensitive)
        dtype: if not None, the returned array is cast to this dtype
        morgan__r: Morgan radius (fp_type='morgan' only)
        morgan__n: number of Morgan bits (fp_type='morgan' only)
        *args, **kwargs: forwarded to get_mol
    Returns:
        1-D numpy array (uint8 unless dtype is given), or None when the input
        is not a valid molecule.
    Raises:
        ValueError: unknown fp_type (checked only for valid input).
    """
    kind = fp_type.lower()
    mol = get_mol(smiles_or_mol, *args, **kwargs)
    if mol is None:
        return None
    if kind == 'maccs':
        on_bits = np.array(MACCSkeys.GenMACCSKeys(mol).GetOnBits())
        bits = np.zeros(166, dtype='uint8')
        if on_bits.size:
            # Key 0 is always zero and is dropped, so keys 1..166 -> 0..165.
            bits[on_bits - 1] = 1
    elif kind == 'morgan':
        bits = np.asarray(Morgan(mol, morgan__r, nBits=morgan__n),
                          dtype='uint8')
    else:
        raise ValueError("Unknown fingerprint type {}".format(kind))
    return bits if dtype is None else bits.astype(dtype)
def fingerprints(smiles_mols_array, n_jobs=1, already_unique=False, *args,
                 **kwargs):
    '''
    Computes fingerprints of smiles np.array/list/pd.Series with n_jobs workers
    e.g. fingerprints(smiles_mols_array, fp_type='morgan', n_jobs=10)
    Inserts NaN rows for entries whose SMILES/molecule is invalid.
    IMPORTANT: if there is at least one NaN row, the dtype is float
    Parameters:
        smiles_mols_array: list/array/pd.Series of smiles or already computed
            RDKit molecules
        n_jobs: number of parallel workers (or a pool object, see mapper)
        already_unique: flag for performance reasons, if smiles array is big
            and already unique. Its value is set to True if smiles_mols_array
            contain RDKit molecules already.
        **kwargs: forwarded to fingerprint (e.g. fp_type, dtype)
    Returns:
        One row per input entry, in input order: a 2-D numpy array, or a CSR
        matrix if fingerprint produced sparse rows. If every entry is invalid,
        each row is a single NaN.
    '''
    if isinstance(smiles_mols_array, pd.Series):
        smiles_mols_array = smiles_mols_array.values
    else:
        smiles_mols_array = np.asarray(smiles_mols_array)
    if not isinstance(smiles_mols_array[0], str):
        already_unique = True
    if not already_unique:
        smiles_mols_array, inv_index = np.unique(smiles_mols_array,
                                                 return_inverse=True)
    fps = mapper(n_jobs)(
        partial(fingerprint, *args, **kwargs), smiles_mols_array
    )
    # The first valid fingerprint fixes the row width. It stays None when no
    # entry is valid; previously that case raised UnboundLocalError below.
    first_fp = next((fp for fp in fps if fp is not None), None)
    length = 1 if first_fp is None else first_fp.shape[-1]
    # np.nan, not np.NaN: the NaN alias was removed in NumPy 2.0.
    fps = [fp if fp is not None else np.full((1, length), np.nan)
           for fp in fps]
    if scipy.sparse.issparse(first_fp):
        fps = scipy.sparse.vstack(fps).tocsr()
    else:
        fps = np.vstack(fps)
    if not already_unique:
        return fps[inv_index]
    return fps
def internal_diversity(gen, n_jobs=1, device='cpu', fp_type='morgan',
                       gen_fps=None, p=1):
    """
    Internal diversity of a set of molecules A:

        1 - 1/|A|^2 sum_{x, y in AxA} tanimoto(x, y)

    For p != 1 the per-molecule mean similarity is a power mean, as
    implemented by average_agg_tanimoto.

    Parameters:
        gen: SMILES/molecules; ignored when gen_fps is supplied
        n_jobs: workers used to compute fingerprints
        device: torch device for the similarity computation
        fp_type: fingerprint type passed to fingerprints
        gen_fps: precomputed fingerprint matrix for gen
        p: power for the power mean
    """
    if gen_fps is not None:
        fps = gen_fps
    else:
        fps = fingerprints(gen, fp_type=fp_type, n_jobs=n_jobs)
    mean_similarity = average_agg_tanimoto(fps, fps, agg='mean',
                                           device=device, p=p)
    return 1 - mean_similarity.mean()
def fcd_metric(gen, train, n_jobs=8, device='cuda:0'):
    """
    Frechet ChemNet Distance between gen and train, computed with
    fcd_torch.FCD.

    Parameters:
        gen, train: collections of SMILES
        n_jobs: number of workers handed to FCD
        device: torch device handed to FCD; the default targets the first
            CUDA device
    """
    scorer = FCD(device=device, n_jobs=n_jobs)
    return scorer(gen, train)
def SYBAscore(gen):
    """
    Average SYBA score over the SMILES strings in gen.

    A new SybaClassifier is created and fitted with fitDefaultScore on every
    call. SMILES whose prediction raises are reported on stdout and skipped.

    Parameters:
    - gen (iterable of str): SMILES strings to score.
    Returns:
    - float: mean of the successful predictions, or None if none succeeded.
    """
    classifier = SybaClassifier()
    classifier.fitDefaultScore()
    predictions = []
    for smi in gen:
        try:
            predictions.append(classifier.predict(smi=smi))
        except Exception as err:
            print(f"Error processing SMILES '{smi}': {err}")
    if not predictions:
        return None
    return sum(predictions) / len(predictions)
def oracles(gen, train):
    """
    Evaluate gen with a fixed list of TDC oracles and average the results.

    Parameters:
        gen: list of SMILES passed to each tdc.Oracle
        train: currently unused; kept for interface compatibility (it was
            meant for a KL-divergence evaluator that is disabled)
    Returns:
        dict mapping oracle name to its averaged score. When an oracle returns
        a dict of per-target score lists, each target is averaged separately
        and the value is a dict. An empty score list averages to None,
        matching SAscore/SYBAscore.
    """
    def _mean(values):
        # sum/len as before, but an empty batch yields None instead of
        # raising ZeroDivisionError.
        values = list(values)
        return sum(values) / len(values) if values else None

    result = {}
    oracle_list = [
        'QED', 'SA', 'MPO', 'GSK3B', 'JNK3',
        'DRD2', 'LogP', 'Rediscovery', 'Similarity',
        'Median', 'Isomers', 'Valsartan_SMARTS', 'Hop'
    ]
    for oracle_name in oracle_list:
        score = Oracle(name=oracle_name)(gen)
        # Handle both result shapes for every oracle. Previously, list results
        # from the "grouped" names and dict results from the others were
        # stored without averaging.
        if isinstance(score, dict):
            score = {key: _mean(values) for key, values in score.items()}
        elif isinstance(score, (list, tuple, np.ndarray)):
            score = _mean(score)
        result[oracle_name] = score
    return result
|