Spaces:
Sleeping
Sleeping
import evaluate | |
import datasets | |
import pandas as pd | |
import numpy as np | |
import scipy.sparse | |
from scipy.spatial.distance import cosine as cos_distance | |
from scipy.stats import wasserstein_distance | |
import torch | |
import warnings | |
from multiprocessing import Pool | |
from functools import partial | |
from fcd_torch import FCD | |
from collections import Counter | |
from tdc import Oracle | |
from rdkit.Chem.Crippen import MolLogP | |
from rdkit import Chem | |
from rdkit.Chem import MACCSkeys | |
from rdkit.Chem import AllChem | |
from rdkit.Chem.AllChem import GetMorganFingerprintAsBitVect as Morgan | |
from rdkit.Chem.QED import qed | |
from rdkit.Contrib.SA_Score import sascorer | |
from rdkit.Chem.Scaffolds import MurckoScaffold | |
from syba.syba import SybaClassifier | |
from myscscore.SCScore import SCScorer | |
def get_mol(smiles_or_mol):
    """
    Turns a SMILES string into a sanitized RDKit molecule.

    Anything that is not a ``str`` (e.g. an existing RDKit ``Mol``) is passed
    through unchanged without inspection.

    Parameters:
    - smiles_or_mol (str or Mol): SMILES string or an RDKit molecule object.

    Returns:
    - Mol or None: The molecule, or None when the string is empty, cannot be
      parsed by ``Chem.MolFromSmiles``, or fails ``Chem.SanitizeMol``.
    """
    # Non-strings are trusted as already-built molecules.
    if not isinstance(smiles_or_mol, str):
        return smiles_or_mol
    if not smiles_or_mol:
        return None
    molecule = Chem.MolFromSmiles(smiles_or_mol)
    if molecule is None:
        return None
    try:
        Chem.SanitizeMol(molecule)
    except ValueError:
        return None
    return molecule
def mapper(n_jobs):
    """
    Returns a map-like callable for sequential or parallel execution.

    Parameters:
    - n_jobs (int or Pool): ``1`` for plain sequential ``map``; any other int
      for a ``multiprocessing.Pool`` of that many workers; or an existing
      pool-like object, whose ``map`` method is returned as-is.

    Returns:
    - Function: ``f(func, iterable, ...)`` returning a list of results.
    """
    if n_jobs == 1:
        def _mapper(*args, **kwargs):
            return list(map(*args, **kwargs))
        return _mapper
    if isinstance(n_jobs, int):
        def _mapper(*args, **kwargs):
            # Create the pool per call: nothing is leaked if the mapper is
            # never invoked, and the mapper stays usable after its first call
            # (previously the pool was terminated after one use). Exiting the
            # ``with`` block terminates the pool once ``map`` has returned.
            with Pool(n_jobs) as pool:
                return pool.map(*args, **kwargs)
        return _mapper
    return n_jobs.map
def fraction_valid(gen, n_jobs=1):
    """
    Calculates the fraction of valid molecules in a list of SMILES strings.

    A molecule is valid when ``get_mol`` returns something other than None.

    Parameters:
    - gen (list of str): List of SMILES strings.
    - n_jobs (int): Number of parallel jobs to use for computation.

    Returns:
    - float: Fraction of valid molecules; 0.0 for an empty input (instead of
      raising ZeroDivisionError).
    """
    # len() rather than truthiness so numpy arrays / pandas Series also work.
    if len(gen) == 0:
        return 0.0
    mols = mapper(n_jobs)(get_mol, gen)
    return 1 - mols.count(None) / len(mols)
def canonic_smiles(smiles_or_mol):
    """
    Produces RDKit's canonical SMILES for a molecule.

    Parameters:
    - smiles_or_mol (str or Mol): SMILES string or RDKit molecule object.

    Returns:
    - str or None: ``Chem.MolToSmiles`` of the molecule, or None when
      ``get_mol`` could not build it.
    """
    molecule = get_mol(smiles_or_mol)
    return None if molecule is None else Chem.MolToSmiles(molecule)
def fraction_unique(gen, k=None, n_jobs=1, check_validity=False):
    """
    Calculates the fraction of unique molecules in a list of SMILES strings.

    Uniqueness is judged on canonical SMILES. When ``check_validity`` is
    False, all invalid molecules collapse into a single ``None`` entry and
    therefore count as one "unique" molecule.

    Parameters:
    - gen (list of str): List of SMILES strings.
    - k (int, optional): If given, only the first ``k`` molecules are
      considered; a warning is emitted when fewer than ``k`` are available.
    - n_jobs (int): Number of parallel jobs to use for computation.
    - check_validity (bool): If True, raise when an invalid molecule is found.

    Returns:
    - float: Fraction of unique molecules; 0.0 when there is nothing to
      evaluate (empty input or ``k == 0``) instead of ZeroDivisionError.

    Raises:
    - ValueError: If ``check_validity`` is True and an invalid molecule occurs.
    """
    if k is not None:
        if len(gen) < k:
            warnings.warn(
                "Can't compute unique@{}. ".format(k) +
                "gen contains only {} molecules".format(len(gen))
            )
        gen = gen[:k]
    # Guard after slicing so that k=0 is covered as well.
    if len(gen) == 0:
        return 0.0
    canonic = set(mapper(n_jobs)(canonic_smiles, gen))
    if None in canonic and check_validity:
        raise ValueError("Invalid molecule passed to unique@k")
    return len(canonic) / len(gen)
def novelty(gen, train, n_jobs=1):
    """
    Computes the novelty of generated molecules compared to a training set.

    Novelty is the share of distinct, valid generated molecules whose
    canonical SMILES does not occur in the training set. Both sides are
    canonicalized, so a training SMILES written in a non-canonical form still
    matches its generated counterpart (previously only ``gen`` was
    canonicalized, which overstated novelty).

    Parameters:
    - gen (List[str]): List of generated SMILES strings.
    - train (List[str]): List of SMILES strings from the training set.
    - n_jobs (int): Number of parallel jobs to use for computation.

    Returns:
    - float: Novelty score; 0.0 when ``gen`` contains no valid molecule
      (instead of raising ZeroDivisionError).
    """
    if len(gen) == 0:
        return 0.0
    gen_smiles_set = set(mapper(n_jobs)(canonic_smiles, gen)) - {None}
    if not gen_smiles_set:
        return 0.0
    # Deduplicate before canonicalizing to avoid redundant RDKit work. A new
    # mapper is requested on purpose: a pool-backed mapper may be single-use.
    unique_train = list(set(train))
    train_set = set(mapper(n_jobs)(canonic_smiles, unique_train)) - {None}
    return len(gen_smiles_set - train_set) / len(gen_smiles_set)
def synthetic_complexity_score(gen):
    """
    Averages the Synthetic Complexity Score (SCScore) over a list of SMILES.

    SCScore rates synthetic complexity on a 1-5 scale. It is built on the
    premise that, on average, the products of published reactions are more
    synthetically complex than their reactants.

    Parameters:
    - gen (list of str): SMILES strings of the molecules to score.

    Returns:
    - Whatever ``SCScorer.get_avg_score`` returns for ``gen`` — presumably
      the mean SCScore (NOTE(review): confirm how myscscore treats invalid
      SMILES and empty input).
    """
    # A fresh scorer is built and its model weights restored on every call.
    scorer = SCScorer()
    scorer.restore()
    return scorer.get_avg_score(gen)
def calculate_sa_score(smiles):
    """
    Computes the synthetic accessibility (SA) score of one molecule.

    The SA score reflects how common a molecule's fragments are and how
    structurally complex it is, ranging from 1 (easy to synthesize) to 10
    (hard to synthesize).

    Parameters:
    - smiles (str): SMILES string of the molecule.

    Returns:
    - float or None: ``sascorer.calculateScore`` of the molecule, or None if
      ``get_mol`` could not build it.
    """
    molecule = get_mol(smiles)
    if not molecule:
        return None
    return sascorer.calculateScore(molecule)
def average_sascore(gen, n_jobs=1):
    """
    Averages the synthetic accessibility score over a list of molecules.

    Molecules for which no score could be computed (invalid SMILES) are
    left out of the average.

    Parameters:
    - gen (List[str]): List of generated SMILES strings.
    - n_jobs (int): Number of parallel jobs to use for computation.

    Returns:
    - float or None: Mean SA score of the scorable molecules, or None when
      none could be scored.
    """
    per_molecule = mapper(n_jobs)(calculate_sa_score, gen)
    usable = [value for value in per_molecule if value is not None]
    if not usable:
        return None
    return sum(usable) / len(usable)
def average_agg_tanimoto(stock_vecs, gen_vecs,
                         batch_size=5000, agg='max',
                         device='cpu', p=1):
    """
    Averages, over generated fingerprints, their aggregated Tanimoto
    similarity to a set of reference fingerprints.

    For every generated vector the similarities to all reference vectors are
    aggregated (maximum, or mean), and the result is averaged over the
    generated set. Both inputs are processed in ``batch_size`` chunks.

    Parameters:
    - stock_vecs (numpy array): Fingerprint vectors for the reference set,
      one row per molecule.
    - gen_vecs (numpy array): Fingerprint vectors for the generated set.
    - batch_size (int): Rows per chunk, to bound memory usage.
    - agg (str): Aggregation method, either 'max' or 'mean'.
    - device (str): Torch device for the matrix products ('cpu', 'cuda:0', ...).
    - p (float): Power for the generalized mean; similarities are raised to
      ``p`` before aggregation and the aggregate to ``1/p`` afterwards.

    Returns:
    - float: Average aggregated Tanimoto similarity.
    """
    assert agg in ['max', 'mean'], "Can aggregate only max or mean"
    per_gen = np.zeros(len(gen_vecs))
    pair_counts = np.zeros(len(gen_vecs))
    for stock_start in range(0, stock_vecs.shape[0], batch_size):
        stock_block = torch.tensor(
            stock_vecs[stock_start:stock_start + batch_size]
        ).to(device).float()
        # Popcount of each reference row; constant across the inner loop.
        stock_bits = stock_block.sum(1, keepdim=True)
        for gen_start in range(0, gen_vecs.shape[0], batch_size):
            gen_block = torch.tensor(
                gen_vecs[gen_start:gen_start + batch_size]
            ).to(device).float().transpose(0, 1)
            shared = torch.mm(stock_block, gen_block)
            union = stock_bits + gen_block.sum(0, keepdim=True) - shared
            sim = (shared / union).cpu().numpy()
            # 0/0 (both fingerprints empty) is treated as identical. Note that
            # any NaN in an input row also ends up as similarity 1 here.
            sim[np.isnan(sim)] = 1
            if p != 1:
                sim = sim ** p
            window = slice(gen_start, gen_start + gen_block.shape[1])
            if agg == 'max':
                per_gen[window] = np.maximum(per_gen[window], sim.max(0))
            elif agg == 'mean':
                per_gen[window] += sim.sum(0)
                pair_counts[window] += sim.shape[0]
    if agg == 'mean':
        per_gen /= pair_counts
    if p != 1:
        per_gen = per_gen ** (1 / p)
    return np.mean(per_gen)
def fingerprint(smiles_or_mol, fp_type='maccs', dtype=None, morgan__r=2,
                morgan__n=1024, *args, **kwargs):
    """
    Builds a bit-vector fingerprint for a single molecule.

    Parameters:
    - smiles_or_mol (str or Mol): SMILES string or RDKit molecule.
    - fp_type (str): 'maccs' (166 bits) or 'morgan'; case-insensitive.
    - dtype: If not None, the dtype the returned array is cast to
      (otherwise uint8).
    - morgan__r (int): Morgan fingerprint radius.
    - morgan__n (int): Morgan fingerprint length in bits.
    - *args, **kwargs: Forwarded to ``get_mol``.

    Returns:
    - numpy array or None: Fingerprint bits, or None if the molecule is invalid.

    Raises:
    - ValueError: For an unknown ``fp_type`` (checked only for valid molecules).
    """
    fp_type = fp_type.lower()
    molecule = get_mol(smiles_or_mol, *args, **kwargs)
    if molecule is None:
        return None
    if fp_type == 'maccs':
        on_bits = np.array(MACCSkeys.GenMACCSKeys(molecule).GetOnBits())
        fp = np.zeros(166, dtype='uint8')
        if len(on_bits) != 0:
            # Shift by one: key 0 of the MACCS vector is always zero and dropped.
            fp[on_bits - 1] = 1
    elif fp_type == 'morgan':
        fp = np.asarray(Morgan(molecule, morgan__r, nBits=morgan__n),
                        dtype='uint8')
    else:
        raise ValueError("Unknown fingerprint type {}".format(fp_type))
    return fp if dtype is None else fp.astype(dtype)
def fingerprints(smiles_mols_array, n_jobs=1, already_unique=False, *args,
                 **kwargs):
    '''
    Computes fingerprints of smiles np.array/list/pd.Series with n_jobs workers
    e.g. fingerprints(smiles_mols_array, fp_type='morgan', n_jobs=10)

    Rows corresponding to invalid SMILES are filled with NaN.
    IMPORTANT: if there is at least one NaN row, the dtype is float.

    Parameters:
        smiles_mols_array: list/array/pd.Series of smiles or already computed
            RDKit molecules
        n_jobs: number of parallel workers to execute
        already_unique: flag for performance reasons, if smiles array is big
            and already unique. Its value is set to True if smiles_mols_array
            contains RDKit molecules already.
        *args, **kwargs: forwarded to ``fingerprint``.

    Returns:
        2-D array with one row per input; ``np.empty((0, 1))`` for empty input.
        When every input is invalid, the rows are single-column NaN.
    '''
    if isinstance(smiles_mols_array, pd.Series):
        smiles_mols_array = smiles_mols_array.values
    else:
        smiles_mols_array = np.asarray(smiles_mols_array)
    # Empty input previously raised IndexError on ``smiles_mols_array[0]``.
    if len(smiles_mols_array) == 0:
        return np.empty((0, 1))
    if not isinstance(smiles_mols_array[0], str):
        already_unique = True
    if not already_unique:
        smiles_mols_array, inv_index = np.unique(smiles_mols_array,
                                                 return_inverse=True)
    fps = mapper(n_jobs)(
        partial(fingerprint, *args, **kwargs), smiles_mols_array
    )
    # First valid fingerprint fixes the row width; None if all are invalid
    # (previously ``first_fp`` was then unbound).
    first_fp = next((fp for fp in fps if fp is not None), None)
    length = 1 if first_fp is None else first_fp.shape[-1]
    # ``np.nan`` rather than ``np.NaN``: the latter was removed in NumPy 2.0.
    fps = [fp if fp is not None else np.full((1, length), np.nan)
           for fp in fps]
    if first_fp is not None and scipy.sparse.issparse(first_fp):
        fps = scipy.sparse.vstack(fps).tocsr()
    else:
        fps = np.vstack(fps)
    if not already_unique:
        return fps[inv_index]
    return fps
def internal_diversity(gen, n_jobs=1, device='cpu', fp_type='morgan',
                       gen_fps=None, p=1):
    """
    Computes internal diversity as:
    1/|A|^2 sum_{x, y in AxA} (1-tanimoto(x, y))

    Only valid molecules take part. ``fingerprints`` marks invalid molecules
    with all-NaN rows, and ``average_agg_tanimoto`` would turn those into
    similarity 1 against every molecule, artificially lowering the score. Such
    rows are therefore dropped first.

    Parameters:
    - gen (List[str]): List of generated SMILES strings.
    - n_jobs (int): Number of parallel jobs for fingerprint computation.
    - device (str): Computation device ('cpu' or 'cuda:0', etc.).
    - fp_type (str): Type of fingerprint to use ('morgan', etc.).
    - gen_fps (Optional[np.ndarray]): Precomputed fingerprints of generated
      molecules, one row per molecule. If None, they are computed from ``gen``.
    - p (float): Power for the generalized mean of the similarities.

    Returns:
    - float: Internal diversity score; 0.0 when no valid fingerprint remains.
    """
    if gen_fps is None:
        gen_fps = fingerprints(gen, fp_type=fp_type, n_jobs=n_jobs)
    gen_fps = np.asarray(gen_fps)
    # Only float arrays can hold NaN placeholders for invalid molecules.
    if gen_fps.ndim == 2 and np.issubdtype(gen_fps.dtype, np.floating):
        gen_fps = gen_fps[~np.isnan(gen_fps).any(axis=1)]
    if len(gen_fps) == 0:
        return 0.0
    return 1 - average_agg_tanimoto(gen_fps, gen_fps,
                                    agg='mean', device=device, p=p)
def fcd_metric(gen, train, n_jobs=1, device=None):
    """
    Computes the Fréchet ChemNet Distance (FCD) between two sets of molecules.

    FCD is the Fréchet distance between ChemNet feature vectors of the
    generated and the reference molecules. Lower values indicate that the
    generated set is closer to the reference set.

    Parameters:
    - gen (List[str]): List of generated SMILES strings.
    - train (List[str]): List of training set SMILES strings.
    - n_jobs (int): Number of parallel jobs passed to ``fcd_torch.FCD``.
    - device (str, torch.device or None): Requested computation device.
      None selects 'cuda' when CUDA is available, otherwise 'cpu'. A CUDA
      device is honoured only if CUDA is available; any other request falls
      back to 'cpu'.

    Returns:
    - The FCD value computed by ``fcd_torch.FCD``.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    else:
        # str() lets torch.device objects through as well
        # (``'cuda' in torch.device(...)`` raises TypeError), and keeps the
        # value passed to FCD consistently a string.
        requested = str(device)
        device = (requested
                  if 'cuda' in requested and torch.cuda.is_available()
                  else 'cpu')
    fcd = FCD(device=device, n_jobs=n_jobs)
    return fcd(gen, train)
def SYBAscore(gen):
    """
    Averages the SYBA (SYnthetic Bayesian Accessibility) score over SMILES.

    SYBA is a fragment-based Bernoulli naive Bayes classifier that separates
    easy-to-synthesize (ES) from hard-to-synthesize (HS) compounds. It was
    trained on ES molecules from ZINC15 and on HS molecules generated with
    the Nonpher methodology.

    Parameters:
    - gen (List[str]): List of generated SMILES strings.

    Returns:
    - float or None: Mean SYBA score over the molecules SYBA could score, or
      None if none could be scored. Failures are printed and skipped.
    """
    classifier = SybaClassifier()
    classifier.fitDefaultScore()
    collected = []
    for smi in gen:
        try:
            collected.append(classifier.predict(smi=smi))
        except Exception as err:
            print(f"Error processing SMILES '{smi}': {err}")
    return sum(collected) / len(collected) if collected else None
def qed_metric(gen):
    """
    Averages RDKit's QED (quantitative estimate of drug-likeness) score.

    QED is a value in [0, 1] estimating how closely a molecule matches
    traits typical of successful drug molecules.

    Parameters:
    - gen (List[str]): List of generated SMILES strings.

    Returns:
    - float: Mean QED over the valid molecules; 0.0 for an empty list or when
      no molecule could be scored. Failures are printed and skipped.
    """
    if not gen:
        return 0.0
    values = []
    for smi in gen:
        try:
            molecule = get_mol(smi)
            if molecule:
                values.append(qed(molecule))
        except Exception as exc:
            print(f"Error processing molecule {smi}: {str(exc)}")
    return sum(values) / len(values) if values else 0.0
def logP_metric(gen):
    """
    Averages RDKit's Crippen logP estimate over a list of SMILES strings.

    LogP is the log of the octanol/water partition coefficient; values
    between 0 and 5 are commonly cited for orally administered small-molecule
    drugs. Computed with RDKit's Crippen (Wildman and Crippen, 1999) method.

    Parameters:
    - gen (List[str]): List of generated SMILES strings.

    Returns:
    - float: Mean logP over the valid molecules; 0.0 for an empty list or when
      no molecule could be scored. Failures are printed and skipped.
    """
    if not gen:
        return 0.0
    values = []
    for smi in gen:
        try:
            molecule = get_mol(smi)
            if molecule:
                values.append(MolLogP(molecule))
        except Exception as exc:
            print(f"Error processing molecule {smi}: {str(exc)}")
    return sum(values) / len(values) if values else 0.0
def penalized_logp(gen):
    """
    Averages the score of PyTDC's ``Oracle('LogP')`` over a list of SMILES.

    This module uses that oracle as penalized logP, i.e. logP combined with
    SA and ring-size penalties. NOTE(review): confirm against the installed
    PyTDC version which logP variant the 'LogP' oracle implements.

    Parameters:
    - gen (List[str]): List of generated SMILES strings.

    Returns:
    - float: Mean oracle score when the oracle returns a list (or the oracle's
      scalar result otherwise); 0.0 for an empty input instead of raising
      ZeroDivisionError.
    """
    # Bail out before constructing the oracle: nothing to score.
    if len(gen) == 0:
        return 0.0
    oracle = Oracle('LogP')
    score = oracle(gen)
    if isinstance(score, list):
        score = sum(score) / len(score)
    return score
_DESCRIPTION = """ | |
Comprehensive suite of metrics designed to assess the performance of molecular generation models, for understanding how well a model can produce novel, chemically valid molecules that are relevant to specific research objectives. | |
""" | |
_KWARGS_DESCRIPTION = """ | |
Args: | |
generated_smiles (`list` of `string`): A collection of SMILES (Simplified Molecular Input Line Entry System) strings generated by the model, ideally encompassing more than 30,000 samples. | |
train_smiles (`list` of `string`): The dataset of SMILES strings used to train the model, serving as a reference to evaluate the novelty and diversity of the generated molecules. | |
Returns: | |
Dectionary item containing various metrics to evaluate model performance | |
""" | |
_CITATION = """ | |
""" | |
class molgenevalmetric(evaluate.Metric):
    """Hugging Face ``evaluate`` metric bundling the molecule-generation metrics of this module."""

    def _info(self):
        """Describes the metric and its input features.

        The "multilabel" config expects a sequence of strings per example;
        every other config expects a single string per example.
        """
        if self.config_name == "multilabel":
            columns = {
                "gensmi": datasets.Sequence(datasets.Value("string")),
                "trainsmi": datasets.Sequence(datasets.Value("string")),
            }
        else:
            columns = {
                "gensmi": datasets.Value("string"),
                "trainsmi": datasets.Value("string"),
            }
        sources = [
            "https://github.com/molecularsets/moses",
            "https://tdcommons.ai/functions/oracles/",
            "https://github.com/lich-uct/syba",
            "https://github.com/connorcoley/scscore",
        ]
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(columns),
            reference_urls=sources,
        )

    def _compute(self, gensmi, trainsmi):
        """Runs every metric on the generated SMILES and returns them by name.

        Only 'Novelty' and 'FCD' use ``trainsmi``; the remaining metrics look
        at ``gensmi`` alone. Metrics are evaluated in the listed order.
        """
        # An 'Oracles' entry (oracles(gen=gensmi, train=trainsmi)) is disabled.
        return {
            'Novelty': novelty(gen=gensmi, train=trainsmi),
            'Valid': fraction_valid(gen=gensmi),
            'Unique': fraction_unique(gen=gensmi),
            'IntDiv': internal_diversity(gen=gensmi),
            'FCD': fcd_metric(gen=gensmi, train=trainsmi),
            'QED': qed_metric(gen=gensmi),
            'LogP': logP_metric(gen=gensmi),
            'Penalized LogP': penalized_logp(gen=gensmi),
            'SA': average_sascore(gen=gensmi),
            'SCScore': synthetic_complexity_score(gen=gensmi),
            'SYBA': SYBAscore(gen=gensmi),
        }
# def get_n_rings(mol): | |
# """ | |
# Computes the number of rings in a molecule | |
# """ | |
# return mol.GetRingInfo().NumRings() | |
# def fragmenter(mol): | |
# """ | |
# fragment mol using BRICS and return smiles list | |
# """ | |
# fgs = AllChem.FragmentOnBRICSBonds(get_mol(mol)) | |
# fgs_smi = Chem.MolToSmiles(fgs).split(".") | |
# return fgs_smi | |
# def compute_fragments(mol_list, n_jobs=1): | |
# """ | |
# fragment list of mols using BRICS and return smiles list | |
# """ | |
# fragments = Counter() | |
# for mol_frag in mapper(n_jobs)(fragmenter, mol_list): | |
# fragments.update(mol_frag) | |
# return fragments | |
# def compute_scaffolds(mol_list, n_jobs=1, min_rings=2): | |
# """ | |
# Extracts a scafold from a molecule in a form of a canonic SMILES | |
# """ | |
# scaffolds = Counter() | |
# map_ = mapper(n_jobs) | |
# scaffolds = Counter( | |
# map_(partial(compute_scaffold, min_rings=min_rings), mol_list)) | |
# if None in scaffolds: | |
# scaffolds.pop(None) | |
# return scaffolds | |
# def compute_scaffold(mol, min_rings=2): | |
# mol = get_mol(mol) | |
# try: | |
# scaffold = MurckoScaffold.GetScaffoldForMol(mol) | |
# except (ValueError, RuntimeError): | |
# return None | |
# n_rings = get_n_rings(scaffold) | |
# scaffold_smiles = Chem.MolToSmiles(scaffold) | |
# if scaffold_smiles == '' or n_rings < min_rings: | |
# return None | |
# return scaffold_smiles | |
# class Metric: | |
# def __init__(self, n_jobs=1, device='cpu', batch_size=512, **kwargs): | |
# self.n_jobs = n_jobs | |
# self.device = device | |
# self.batch_size = batch_size | |
# for k, v in kwargs.values(): | |
# setattr(self, k, v) | |
# def __call__(self, ref=None, gen=None, pref=None, pgen=None): | |
# assert (ref is None) != (pref is None), "specify ref xor pref" | |
# assert (gen is None) != (pgen is None), "specify gen xor pgen" | |
# if pref is None: | |
# pref = self.precalc(ref) | |
# if pgen is None: | |
# pgen = self.precalc(gen) | |
# return self.metric(pref, pgen) | |
# def precalc(self, moleclues): | |
# raise NotImplementedError | |
# def metric(self, pref, pgen): | |
# raise NotImplementedError | |
# class SNNMetric(Metric): | |
# """ | |
# Computes average max similarities of gen SMILES to ref SMILES | |
# """ | |
# def __init__(self, fp_type='morgan', **kwargs): | |
# self.fp_type = fp_type | |
# super().__init__(**kwargs) | |
# def precalc(self, mols): | |
# return {'fps': fingerprints(mols, n_jobs=self.n_jobs, | |
# fp_type=self.fp_type)} | |
# def metric(self, pref, pgen): | |
# return average_agg_tanimoto(pref['fps'], pgen['fps'], | |
# device=self.device) | |
# def cos_similarity(ref_counts, gen_counts): | |
# """ | |
# Computes cosine similarity between | |
# dictionaries of form {name: count}. Non-present | |
# elements are considered zero: | |
# sim = <r, g> / ||r|| / ||g|| | |
# """ | |
# if len(ref_counts) == 0 or len(gen_counts) == 0: | |
# return np.nan | |
# keys = np.unique(list(ref_counts.keys()) + list(gen_counts.keys())) | |
# ref_vec = np.array([ref_counts.get(k, 0) for k in keys]) | |
# gen_vec = np.array([gen_counts.get(k, 0) for k in keys]) | |
# return 1 - cos_distance(ref_vec, gen_vec) | |
# class FragMetric(Metric): | |
# def precalc(self, mols): | |
# return {'frag': compute_fragments(mols, n_jobs=self.n_jobs)} | |
# def metric(self, pref, pgen): | |
# return cos_similarity(pref['frag'], pgen['frag']) | |
# class ScafMetric(Metric): | |
# def precalc(self, mols): | |
# return {'scaf': compute_scaffolds(mols, n_jobs=self.n_jobs)} | |
# def metric(self, pref, pgen): | |
# return cos_similarity(pref['scaf'], pgen['scaf']) | |
# class WassersteinMetric(Metric): | |
# def __init__(self, func=None, **kwargs): | |
# self.func = func | |
# super().__init__(**kwargs) | |
# def precalc(self, mols): | |
# if self.func is not None: | |
# values = mapper(self.n_jobs)(self.func, mols) | |
# else: | |
# values = mols | |
# return {'values': values} | |
# def metric(self, pref, pgen): | |
# return wasserstein_distance( | |
# pref['values'], pgen['values'] | |
# ) | |
# def get_frag(gen): | |
# mols = mapper(pool)(get_mol, gen) | |
# kwargs = {'n_jobs': pool, 'device': device, 'batch_size': batch_size} | |