Spaces:
Sleeping
Sleeping
from statistics import mean | |
import os | |
import math | |
import time | |
import datetime | |
from rdkit import DataStructs | |
from rdkit import Chem | |
from rdkit import RDLogger | |
from rdkit.Chem import AllChem | |
from rdkit.Chem import Draw | |
from rdkit.Chem.Scaffolds import MurckoScaffold | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from matplotlib.lines import Line2D | |
import torch | |
import wandb | |
RDLogger.DisableLog('rdApp.*') | |
import warnings | |
from multiprocessing import Pool | |
class Metrics(object): | |
def valid(x): | |
return x is not None and Chem.MolToSmiles(x) != '' | |
def tanimoto_sim_1v2(data1, data2): | |
min_len = data1.size if data1.size > data2.size else data2 | |
sims = [] | |
for i in range(min_len): | |
sim = DataStructs.FingerprintSimilarity(data1[i], data2[i]) | |
sims.append(sim) | |
mean_sim = mean(sim) | |
return mean_sim | |
def mol_length(x): | |
if x is not None: | |
return len([char for char in max(x.split(sep =".")).upper() if char.isalpha()]) | |
else: | |
return 0 | |
def max_component(data, max_len): | |
return ((np.array(list(map(Metrics.mol_length, data)), dtype=np.float32)/max_len).mean()) | |
def mean_atom_type(data): | |
atom_types_used = [] | |
for i in data: | |
atom_types_used.append(len(i.unique().tolist())) | |
av_type = np.mean(atom_types_used) - 1 | |
return av_type | |
def sim_reward(mol_gen, fps_r): | |
gen_scaf = [] | |
for x in mol_gen: | |
if x is not None: | |
try: | |
gen_scaf.append(MurckoScaffold.GetScaffoldForMol(x)) | |
except: | |
pass | |
if len(gen_scaf) == 0: | |
rew = 1 | |
else: | |
fps = [Chem.RDKFingerprint(x) for x in gen_scaf] | |
fps = np.array(fps) | |
fps_r = np.array(fps_r) | |
rew = average_agg_tanimoto(fps_r,fps) | |
if math.isnan(rew): | |
rew = 1 | |
return rew ## change this to penalty | |
########################################## | |
########################################## | |
########################################## | |
def mols2grid_image(mols,path): | |
mols = [e if e is not None else Chem.RWMol() for e in mols] | |
for i in range(len(mols)): | |
if Metrics.valid(mols[i]): | |
AllChem.Compute2DCoords(mols[i]) | |
Draw.MolToFile(mols[i], os.path.join(path,"{}.png".format(i+1)), size=(1200,1200)) | |
#wandb.save(os.path.join(path,"{}.png".format(i+1))) | |
else: | |
continue | |
def save_smiles_matrices(mols,edges_hard, nodes_hard, path, data_source = None): | |
mols = [e if e is not None else Chem.RWMol() for e in mols] | |
for i in range(len(mols)): | |
if Metrics.valid(mols[i]): | |
save_path = os.path.join(path,"{}.txt".format(i+1)) | |
with open(save_path, "a") as f: | |
np.savetxt(f, edges_hard[i].cpu().numpy(), header="edge matrix:\n",fmt='%1.2f') | |
f.write("\n") | |
np.savetxt(f, nodes_hard[i].cpu().numpy(), header="node matrix:\n", footer="\nsmiles:",fmt='%1.2f') | |
f.write("\n") | |
#f.write(m0) | |
f.write("\n") | |
print(Chem.MolToSmiles(mols[i]), file=open(save_path,"a")) | |
#wandb.save(save_path) | |
else: | |
continue | |
########################################## | |
########################################## | |
########################################## | |
def dense_to_sparse_with_attr(adj): | |
assert adj.dim() >= 2 and adj.dim() <= 3 | |
assert adj.size(-1) == adj.size(-2) | |
index = adj.nonzero(as_tuple=True) | |
edge_attr = adj[index] | |
if len(index) == 3: | |
batch = index[0] * adj.size(-1) | |
index = (batch + index[1], batch + index[2]) | |
#index = torch.stack(index, dim=0) | |
return index, edge_attr | |
def label2onehot(labels, dim, device): | |
"""Convert label indices to one-hot vectors.""" | |
out = torch.zeros(list(labels.size())+[dim]).to(device) | |
out.scatter_(len(out.size())-1,labels.unsqueeze(-1),1.) | |
return out.float() | |
def mol_sample(sample_directory, edges, nodes, idx, i,matrices2mol, dataset_name): | |
sample_path = os.path.join(sample_directory,"{}_{}-epoch_iteration".format(idx+1, i+1)) | |
g_edges_hat_sample = torch.max(edges, -1)[1] | |
g_nodes_hat_sample = torch.max(nodes , -1)[1] | |
mol = [matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=True, file_name=dataset_name) | |
for e_, n_ in zip(g_edges_hat_sample, g_nodes_hat_sample)] | |
if not os.path.exists(sample_path): | |
os.makedirs(sample_path) | |
mols2grid_image(mol,sample_path) | |
save_smiles_matrices(mol,g_edges_hat_sample.detach(), g_nodes_hat_sample.detach(), sample_path) | |
if len(os.listdir(sample_path)) == 0: | |
os.rmdir(sample_path) | |
print("Valid molecules are saved.") | |
print("Valid matrices and smiles are saved") | |
def logging(log_path, start_time, i, idx, loss, save_path, drug_smiles, edge, node, | |
matrices2mol, dataset_name, real_adj, real_annot, drug_vecs): | |
g_edges_hat_sample = torch.max(edge, -1)[1] | |
g_nodes_hat_sample = torch.max(node , -1)[1] | |
a_tensor_sample = torch.max(real_adj, -1)[1].float() | |
x_tensor_sample = torch.max(real_annot, -1)[1].float() | |
mols = [matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=True, file_name=dataset_name) | |
for e_, n_ in zip(g_edges_hat_sample, g_nodes_hat_sample)] | |
real_mol = [matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=True, file_name=dataset_name) | |
for e_, n_ in zip(a_tensor_sample, x_tensor_sample)] | |
atom_types_average = Metrics.mean_atom_type(g_nodes_hat_sample) | |
real_smiles = [Chem.MolToSmiles(x) for x in real_mol if x is not None] | |
gen_smiles = [] | |
uniq_smiles = [] | |
for line in mols: | |
if line is not None: | |
gen_smiles.append(Chem.MolToSmiles(line)) | |
uniq_smiles.append(Chem.MolToSmiles(line)) | |
elif line is None: | |
gen_smiles.append(None) | |
gen_smiles_saves = [None if x is None else max(x.split('.'), key=len) for x in gen_smiles] | |
uniq_smiles_saves = [None if x is None else max(x.split('.'), key=len) for x in uniq_smiles] | |
sample_save_dir = os.path.join(save_path, "samples.txt") | |
with open(sample_save_dir, "a") as f: | |
for idxs in range(len(gen_smiles_saves)): | |
if gen_smiles_saves[idxs] is not None: | |
f.write(gen_smiles_saves[idxs]) | |
f.write("\n") | |
k = len(set(uniq_smiles_saves) - {None}) | |
et = time.time() - start_time | |
et = str(datetime.timedelta(seconds=et))[:-7] | |
log = "Elapsed [{}], Epoch/Iteration [{}/{}]".format(et, idx, i+1) | |
gen_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in mols if x is not None] | |
chembl_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in real_mol if x is not None] | |
# Log update | |
#m0 = get_all_metrics(gen = gen_smiles, train = train_smiles, batch_size=batch_size, k = valid_mol_num, device=self.device) | |
valid = fraction_valid(gen_smiles_saves) | |
unique = fraction_unique(uniq_smiles_saves, k, check_validity=False) | |
novel_starting_mol = novelty(gen_smiles_saves, real_smiles) | |
novel_akt = novelty(gen_smiles_saves, drug_smiles) | |
if (len(uniq_smiles_saves) == 0): | |
snn_chembl = 0 | |
snn_akt = 0 | |
maxlen = 0 | |
else: | |
snn_chembl = average_agg_tanimoto(np.array(chembl_vecs),np.array(gen_vecs)) | |
snn_akt = average_agg_tanimoto(np.array(drug_vecs),np.array(gen_vecs)) | |
maxlen = Metrics.max_component(uniq_smiles_saves, 45) | |
loss.update({'Validity': valid}) | |
loss.update({'Uniqueness': unique}) | |
loss.update({'Novelty': novel_starting_mol}) | |
loss.update({'Novelty_akt': novel_akt}) | |
loss.update({'SNN_chembl': snn_chembl}) | |
loss.update({'SNN_akt': snn_akt}) | |
loss.update({'MaxLen': maxlen}) | |
loss.update({'Atom_types': atom_types_average}) | |
wandb.log({"Validity": valid, "Uniqueness": unique, "Novelty": novel_starting_mol, | |
"Novelty_akt": novel_akt, "SNN_chembl": snn_chembl, "SNN_akt": snn_akt, | |
"MaxLen": maxlen, "Atom_types": atom_types_average}) | |
for tag, value in loss.items(): | |
log += ", {}: {:.4f}".format(tag, value) | |
with open(log_path, "a") as f: | |
f.write(log) | |
f.write("\n") | |
print(log) | |
print("\n") | |
def plot_grad_flow(named_parameters, model, itera, epoch,grad_flow_directory): | |
# Based on https://discuss.pytorch.org/t/check-gradient-flow-in-network/15063/10 | |
'''Plots the gradients flowing through different layers in the net during training. | |
Can be used for checking for possible gradient vanishing / exploding problems. | |
Usage: Plug this function in Trainer class after loss.backwards() as | |
"plot_grad_flow(self.model.named_parameters())" to visualize the gradient flow''' | |
ave_grads = [] | |
max_grads= [] | |
layers = [] | |
for n, p in named_parameters: | |
if(p.requires_grad) and ("bias" not in n): | |
#print(p.grad,n) | |
layers.append(n) | |
ave_grads.append(p.grad.abs().mean().cpu()) | |
max_grads.append(p.grad.abs().max().cpu()) | |
plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.1, lw=1, color="c") | |
plt.bar(np.arange(len(max_grads)), ave_grads, alpha=0.1, lw=1, color="b") | |
plt.hlines(0, 0, len(ave_grads)+1, lw=2, color="k" ) | |
plt.xticks(range(0,len(ave_grads), 1), layers, rotation="vertical") | |
plt.xlim(left=0, right=len(ave_grads)) | |
plt.ylim(bottom = -0.001, top=1) # zoom in on the lower gradient regions | |
plt.xlabel("Layers") | |
plt.ylabel("average gradient") | |
plt.title("Gradient flow") | |
plt.grid(True) | |
plt.legend([Line2D([0], [0], color="c", lw=4), | |
Line2D([0], [0], color="b", lw=4), | |
Line2D([0], [0], color="k", lw=4)], ['max-gradient', 'mean-gradient', 'zero-gradient']) | |
pltsavedir = grad_flow_directory | |
plt.savefig(os.path.join(pltsavedir, "weights_" + model + "_" + str(itera) + "_" + str(epoch) + ".png"), dpi= 500,bbox_inches='tight') | |
def get_mol(smiles_or_mol): | |
''' | |
Loads SMILES/molecule into RDKit's object | |
''' | |
if isinstance(smiles_or_mol, str): | |
if len(smiles_or_mol) == 0: | |
return None | |
mol = Chem.MolFromSmiles(smiles_or_mol) | |
if mol is None: | |
return None | |
try: | |
Chem.SanitizeMol(mol) | |
except ValueError: | |
return None | |
return mol | |
return smiles_or_mol | |
def mapper(n_jobs): | |
''' | |
Returns function for map call. | |
If n_jobs == 1, will use standard map | |
If n_jobs > 1, will use multiprocessing pool | |
If n_jobs is a pool object, will return its map function | |
''' | |
if n_jobs == 1: | |
def _mapper(*args, **kwargs): | |
return list(map(*args, **kwargs)) | |
return _mapper | |
if isinstance(n_jobs, int): | |
pool = Pool(n_jobs) | |
def _mapper(*args, **kwargs): | |
try: | |
result = pool.map(*args, **kwargs) | |
finally: | |
pool.terminate() | |
return result | |
return _mapper | |
return n_jobs.map | |
def remove_invalid(gen, canonize=True, n_jobs=1): | |
""" | |
Removes invalid molecules from the dataset | |
""" | |
if not canonize: | |
mols = mapper(n_jobs)(get_mol, gen) | |
return [gen_ for gen_, mol in zip(gen, mols) if mol is not None] | |
return [x for x in mapper(n_jobs)(canonic_smiles, gen) if | |
x is not None] | |
def fraction_valid(gen, n_jobs=1): | |
""" | |
Computes a number of valid molecules | |
Parameters: | |
gen: list of SMILES | |
n_jobs: number of threads for calculation | |
""" | |
gen = mapper(n_jobs)(get_mol, gen) | |
return 1 - gen.count(None) / len(gen) | |
def canonic_smiles(smiles_or_mol): | |
mol = get_mol(smiles_or_mol) | |
if mol is None: | |
return None | |
return Chem.MolToSmiles(mol) | |
def fraction_unique(gen, k=None, n_jobs=1, check_validity=False): | |
""" | |
Computes a number of unique molecules | |
Parameters: | |
gen: list of SMILES | |
k: compute unique@k | |
n_jobs: number of threads for calculation | |
check_validity: raises ValueError if invalid molecules are present | |
""" | |
if k is not None: | |
if len(gen) < k: | |
warnings.warn( | |
"Can't compute unique@{}.".format(k) + | |
"gen contains only {} molecules".format(len(gen)) | |
) | |
gen = gen[:k] | |
canonic = set(mapper(n_jobs)(canonic_smiles, gen)) | |
if None in canonic and check_validity: | |
#canonic = [i for i in canonic if i is not None] | |
raise ValueError("Invalid molecule passed to unique@k") | |
return 0 if len(gen) == 0 else len(canonic) / len(gen) | |
def novelty(gen, train, n_jobs=1): | |
gen_smiles = mapper(n_jobs)(canonic_smiles, gen) | |
gen_smiles_set = set(gen_smiles) - {None} | |
train_set = set(train) | |
return 0 if len(gen_smiles_set) == 0 else len(gen_smiles_set - train_set) / len(gen_smiles_set) | |
def average_agg_tanimoto(stock_vecs, gen_vecs, | |
batch_size=5000, agg='max', | |
device='cpu', p=1): | |
""" | |
For each molecule in gen_vecs finds closest molecule in stock_vecs. | |
Returns average tanimoto score for between these molecules | |
Parameters: | |
stock_vecs: numpy array <n_vectors x dim> | |
gen_vecs: numpy array <n_vectors' x dim> | |
agg: max or mean | |
p: power for averaging: (mean x^p)^(1/p) | |
""" | |
assert agg in ['max', 'mean'], "Can aggregate only max or mean" | |
agg_tanimoto = np.zeros(len(gen_vecs)) | |
total = np.zeros(len(gen_vecs)) | |
for j in range(0, stock_vecs.shape[0], batch_size): | |
x_stock = torch.tensor(stock_vecs[j:j + batch_size]).to(device).float() | |
for i in range(0, gen_vecs.shape[0], batch_size): | |
y_gen = torch.tensor(gen_vecs[i:i + batch_size]).to(device).float() | |
y_gen = y_gen.transpose(0, 1) | |
tp = torch.mm(x_stock, y_gen) | |
jac = (tp / (x_stock.sum(1, keepdim=True) + | |
y_gen.sum(0, keepdim=True) - tp)).cpu().numpy() | |
jac[np.isnan(jac)] = 1 | |
if p != 1: | |
jac = jac**p | |
if agg == 'max': | |
agg_tanimoto[i:i + y_gen.shape[1]] = np.maximum( | |
agg_tanimoto[i:i + y_gen.shape[1]], jac.max(0)) | |
elif agg == 'mean': | |
agg_tanimoto[i:i + y_gen.shape[1]] += jac.sum(0) | |
total[i:i + y_gen.shape[1]] += jac.shape[0] | |
if agg == 'mean': | |
agg_tanimoto /= total | |
if p != 1: | |
agg_tanimoto = (agg_tanimoto)**(1/p) | |
return np.mean(agg_tanimoto) | |
def str2bool(v): | |
return v.lower() in ('true') |