GenFBDD / datasets /pdbbind.py
libokj's picture
Minify
c17cba8
import binascii
import glob
import os
import pickle
from collections import defaultdict
from multiprocessing import Pool
import random
import copy
import torch.nn.functional as F
import numpy as np
import torch
from rdkit import Chem
from rdkit.Chem import MolFromSmiles, AddHs
from torch_geometric.data import Dataset, HeteroData
from torch_geometric.transforms import BaseTransform
from tqdm import tqdm
from rdkit.Chem import RemoveAllHs
from datasets.process_mols import read_molecule, get_lig_graph_with_matching, generate_conformer, moad_extract_receptor_structure
from utils.diffusion_utils import modify_conformer, set_time
from utils.utils import read_strings_from_txt, crop_beyond
from utils import so3, torus
class NoiseTransform(BaseTransform):
def __init__(self, t_to_sigma, no_torsion, all_atom, alpha=1, beta=1,
include_miscellaneous_atoms=False, crop_beyond_cutoff=None, time_independent=False, rmsd_cutoff=0,
minimum_t=0, sampling_mixing_coeff=0):
self.t_to_sigma = t_to_sigma
self.no_torsion = no_torsion
self.all_atom = all_atom
self.include_miscellaneous_atoms = include_miscellaneous_atoms
self.minimum_t = minimum_t
self.mixing_coeff = sampling_mixing_coeff
self.alpha = alpha
self.beta = beta
self.crop_beyond_cutoff = crop_beyond_cutoff
self.rmsd_cutoff = rmsd_cutoff
self.time_independent = time_independent
def __call__(self, data):
t_tr, t_rot, t_tor, t = self.get_time()
return self.apply_noise(data, t_tr, t_rot, t_tor, t)
def get_time(self):
if self.time_independent:
t = np.random.beta(self.alpha, self.beta)
t_tr, t_rot, t_tor = t,t,t
else:
t = None
if self.mixing_coeff == 0:
t = np.random.beta(self.alpha, self.beta)
t = self.minimum_t + t * (1 - self.minimum_t)
else:
choice = np.random.binomial(1, self.mixing_coeff)
t1 = np.random.beta(self.alpha, self.beta)
t1 = t1 * self.minimum_t
t2 = np.random.beta(self.alpha, self.beta)
t2 = self.minimum_t + t2 * (1 - self.minimum_t)
t = choice * t1 + (1 - choice) * t2
t_tr, t_rot, t_tor = t,t,t
return t_tr, t_rot, t_tor, t
def apply_noise(self, data, t_tr, t_rot, t_tor, t, tr_update = None, rot_update=None, torsion_updates=None):
if not torch.is_tensor(data['ligand'].pos):
data['ligand'].pos = random.choice(data['ligand'].pos)
if self.time_independent:
orig_complex_graph = copy.deepcopy(data)
tr_sigma, rot_sigma, tor_sigma = self.t_to_sigma(t_tr, t_rot, t_tor)
if self.time_independent:
set_time(data, 0, 0, 0, 0, 1, self.all_atom, device=None, include_miscellaneous_atoms=self.include_miscellaneous_atoms)
else:
set_time(data, t, t_tr, t_rot, t_tor, 1, self.all_atom, device=None, include_miscellaneous_atoms=self.include_miscellaneous_atoms)
tr_update = torch.normal(mean=0, std=tr_sigma, size=(1, 3)) if tr_update is None else tr_update
rot_update = so3.sample_vec(eps=rot_sigma) if rot_update is None else rot_update
torsion_updates = np.random.normal(loc=0.0, scale=tor_sigma, size=data['ligand'].edge_mask.sum()) if torsion_updates is None else torsion_updates
torsion_updates = None if self.no_torsion else torsion_updates
try:
modify_conformer(data, tr_update, torch.from_numpy(rot_update).float(), torsion_updates)
except Exception as e:
print("failed modify conformer")
print(e)
if self.time_independent:
if self.no_torsion:
orig_complex_graph['ligand'].orig_pos = (orig_complex_graph['ligand'].pos.cpu().numpy() + orig_complex_graph.original_center.cpu().numpy())
filterHs = torch.not_equal(data['ligand'].x[:, 0], 0).cpu().numpy()
if isinstance(orig_complex_graph['ligand'].orig_pos, list):
orig_complex_graph['ligand'].orig_pos = orig_complex_graph['ligand'].orig_pos[0]
ligand_pos = data['ligand'].pos.cpu().numpy()[filterHs]
orig_ligand_pos = orig_complex_graph['ligand'].orig_pos[filterHs] - orig_complex_graph.original_center.cpu().numpy()
rmsd = np.sqrt(((ligand_pos - orig_ligand_pos) ** 2).sum(axis=1).mean(axis=0))
data.y = torch.tensor(rmsd < self.rmsd_cutoff).float().unsqueeze(0)
data.atom_y = data.y
return data
data.tr_score = -tr_update / tr_sigma ** 2
data.rot_score = torch.from_numpy(so3.score_vec(vec=rot_update, eps=rot_sigma)).float().unsqueeze(0)
data.tor_score = None if self.no_torsion else torch.from_numpy(torus.score(torsion_updates, tor_sigma)).float()
data.tor_sigma_edge = None if self.no_torsion else np.ones(data['ligand'].edge_mask.sum()) * tor_sigma
if data['ligand'].pos.shape[0] == 1:
# if the ligand is a single atom, the rotational score is always 0
data.rot_score = data.rot_score * 0
if self.crop_beyond_cutoff is not None:
crop_beyond(data, tr_sigma * 3 + self.crop_beyond_cutoff, self.all_atom)
set_time(data, t, t_tr, t_rot, t_tor, 1, self.all_atom, device=None, include_miscellaneous_atoms=self.include_miscellaneous_atoms)
return data
class PDBBind(Dataset):
def __init__(self, root, transform=None, cache_path='data/cache', split_path='data/', limit_complexes=0, chain_cutoff=10,
receptor_radius=30, num_workers=1, c_alpha_max_neighbors=None, popsize=15, maxiter=15,
matching=True, keep_original=False, max_lig_size=None, remove_hs=False, num_conformers=1, all_atoms=False,
atom_radius=5, atom_max_neighbors=None, esm_embeddings_path=None, require_ligand=False,
include_miscellaneous_atoms=False,
protein_path_list=None, ligand_descriptions=None, keep_local_structures=False,
protein_file="protein_processed", ligand_file="ligand",
knn_only_graph=False, matching_tries=1, dataset='PDBBind'):
super(PDBBind, self).__init__(root, transform)
self.pdbbind_dir = root
self.include_miscellaneous_atoms = include_miscellaneous_atoms
self.max_lig_size = max_lig_size
self.split_path = split_path
self.limit_complexes = limit_complexes
self.chain_cutoff = chain_cutoff
self.receptor_radius = receptor_radius
self.num_workers = num_workers
self.c_alpha_max_neighbors = c_alpha_max_neighbors
self.remove_hs = remove_hs
self.esm_embeddings_path = esm_embeddings_path
self.use_old_wrong_embedding_order = False
self.require_ligand = require_ligand
self.protein_path_list = protein_path_list
self.ligand_descriptions = ligand_descriptions
self.keep_local_structures = keep_local_structures
self.protein_file = protein_file
self.fixed_knn_radius_graph = True
self.knn_only_graph = knn_only_graph
self.matching_tries = matching_tries
self.ligand_file = ligand_file
self.dataset = dataset
assert knn_only_graph or (not all_atoms)
self.all_atoms = all_atoms
if matching or protein_path_list is not None and ligand_descriptions is not None:
cache_path += '_torsion'
if all_atoms:
cache_path += '_allatoms'
self.full_cache_path = os.path.join(cache_path, f'{dataset}3_limit{self.limit_complexes}'
f'_INDEX{os.path.splitext(os.path.basename(self.split_path))[0]}'
f'_maxLigSize{self.max_lig_size}_H{int(not self.remove_hs)}'
f'_recRad{self.receptor_radius}_recMax{self.c_alpha_max_neighbors}'
f'_chainCutoff{self.chain_cutoff if self.chain_cutoff is None else int(self.chain_cutoff)}'
+ (''if not all_atoms else f'_atomRad{atom_radius}_atomMax{atom_max_neighbors}')
+ (''if not matching or num_conformers == 1 else f'_confs{num_conformers}')
+ ('' if self.esm_embeddings_path is None else f'_esmEmbeddings')
+ '_full'
+ ('' if not keep_local_structures else f'_keptLocalStruct')
+ ('' if protein_path_list is None or ligand_descriptions is None else str(binascii.crc32(''.join(ligand_descriptions + protein_path_list).encode())))
+ ('' if protein_file == "protein_processed" else '_' + protein_file)
+ ('' if not self.fixed_knn_radius_graph else (f'_fixedKNN' if not self.knn_only_graph else '_fixedKNNonly'))
+ ('' if not self.include_miscellaneous_atoms else '_miscAtoms')
+ ('' if self.use_old_wrong_embedding_order else '_chainOrd')
+ ('' if self.matching_tries == 1 else f'_tries{matching_tries}'))
self.popsize, self.maxiter = popsize, maxiter
self.matching, self.keep_original = matching, keep_original
self.num_conformers = num_conformers
self.atom_radius, self.atom_max_neighbors = atom_radius, atom_max_neighbors
if not self.check_all_complexes():
os.makedirs(self.full_cache_path, exist_ok=True)
if protein_path_list is None or ligand_descriptions is None:
self.preprocessing()
else:
self.inference_preprocessing()
self.complex_graphs, self.rdkit_ligands = self.collect_all_complexes()
print_statistics(self.complex_graphs)
list_names = [complex['name'] for complex in self.complex_graphs]
with open(os.path.join(self.full_cache_path, f'pdbbind_{os.path.splitext(os.path.basename(self.split_path))[0][:3]}_names.txt'), 'w') as f:
f.write('\n'.join(list_names))
def len(self):
return len(self.complex_graphs)
def get(self, idx):
complex_graph = copy.deepcopy(self.complex_graphs[idx])
if self.require_ligand:
complex_graph.mol = RemoveAllHs(copy.deepcopy(self.rdkit_ligands[idx]))
for a in ['random_coords', 'coords', 'seq', 'sequence', 'mask', 'rmsd_matching', 'cluster', 'orig_seq', 'to_keep', 'chain_ids']:
if hasattr(complex_graph, a):
delattr(complex_graph, a)
if hasattr(complex_graph['receptor'], a):
delattr(complex_graph['receptor'], a)
return complex_graph
def preprocessing(self):
print(f'Processing complexes from [{self.split_path}] and saving it to [{self.full_cache_path}]')
complex_names_all = read_strings_from_txt(self.split_path)
if self.limit_complexes is not None and self.limit_complexes != 0:
complex_names_all = complex_names_all[:self.limit_complexes]
print(f'Loading {len(complex_names_all)} complexes.')
if self.esm_embeddings_path is not None:
id_to_embeddings = torch.load(self.esm_embeddings_path)
chain_embeddings_dictlist = defaultdict(list)
chain_indices_dictlist = defaultdict(list)
for key, embedding in id_to_embeddings.items():
key_name = key.split('_chain_')[0]
if key_name in complex_names_all:
chain_embeddings_dictlist[key_name].append(embedding)
chain_indices_dictlist[key_name].append(int(key.split('_chain_')[1]))
lm_embeddings_chains_all = []
for name in complex_names_all:
complex_chains_embeddings = chain_embeddings_dictlist[name]
complex_chains_indices = chain_indices_dictlist[name]
chain_reorder_idx = np.argsort(complex_chains_indices)
reordered_chains = [complex_chains_embeddings[i] for i in chain_reorder_idx]
lm_embeddings_chains_all.append(reordered_chains)
else:
lm_embeddings_chains_all = [None] * len(complex_names_all)
# running preprocessing in parallel on multiple workers and saving the progress every 1000 complexes
list_indices = list(range(len(complex_names_all)//1000+1))
random.shuffle(list_indices)
for i in list_indices:
if os.path.exists(os.path.join(self.full_cache_path, f"heterographs{i}.pkl")):
continue
complex_names = complex_names_all[1000*i:1000*(i+1)]
lm_embeddings_chains = lm_embeddings_chains_all[1000*i:1000*(i+1)]
complex_graphs, rdkit_ligands = [], []
if self.num_workers > 1:
p = Pool(self.num_workers, maxtasksperchild=1)
p.__enter__()
with tqdm(total=len(complex_names), desc=f'loading complexes {i}/{len(complex_names_all)//1000+1}') as pbar:
map_fn = p.imap_unordered if self.num_workers > 1 else map
for t in map_fn(self.get_complex, zip(complex_names, lm_embeddings_chains, [None] * len(complex_names), [None] * len(complex_names))):
complex_graphs.extend(t[0])
rdkit_ligands.extend(t[1])
pbar.update()
if self.num_workers > 1: p.__exit__(None, None, None)
with open(os.path.join(self.full_cache_path, f"heterographs{i}.pkl"), 'wb') as f:
pickle.dump((complex_graphs), f)
with open(os.path.join(self.full_cache_path, f"rdkit_ligands{i}.pkl"), 'wb') as f:
pickle.dump((rdkit_ligands), f)
def inference_preprocessing(self):
ligands_list = []
print('Reading molecules and generating local structures with RDKit')
for ligand_description in tqdm(self.ligand_descriptions):
mol = MolFromSmiles(ligand_description) # check if it is a smiles or a path
if mol is not None:
mol = AddHs(mol)
generate_conformer(mol)
ligands_list.append(mol)
else:
mol = read_molecule(ligand_description, remove_hs=False, sanitize=True)
if not self.keep_local_structures:
mol.RemoveAllConformers()
mol = AddHs(mol)
generate_conformer(mol)
ligands_list.append(mol)
if self.esm_embeddings_path is not None:
print('Reading language model embeddings.')
lm_embeddings_chains_all = []
if not os.path.exists(self.esm_embeddings_path): raise Exception('ESM embeddings path does not exist: ',self.esm_embeddings_path)
for protein_path in self.protein_path_list:
embeddings_paths = sorted(glob.glob(os.path.join(self.esm_embeddings_path, os.path.basename(protein_path)) + '*'))
lm_embeddings_chains = []
for embeddings_path in embeddings_paths:
lm_embeddings_chains.append(torch.load(embeddings_path)['representations'][33])
lm_embeddings_chains_all.append(lm_embeddings_chains)
else:
lm_embeddings_chains_all = [None] * len(self.protein_path_list)
print('Generating graphs for ligands and proteins')
# running preprocessing in parallel on multiple workers and saving the progress every 1000 complexes
list_indices = list(range(len(self.protein_path_list)//1000+1))
random.shuffle(list_indices)
for i in list_indices:
if os.path.exists(os.path.join(self.full_cache_path, f"heterographs{i}.pkl")):
continue
protein_paths_chunk = self.protein_path_list[1000*i:1000*(i+1)]
ligand_description_chunk = self.ligand_descriptions[1000*i:1000*(i+1)]
ligands_chunk = ligands_list[1000 * i:1000 * (i + 1)]
lm_embeddings_chains = lm_embeddings_chains_all[1000*i:1000*(i+1)]
complex_graphs, rdkit_ligands = [], []
if self.num_workers > 1:
p = Pool(self.num_workers, maxtasksperchild=1)
p.__enter__()
with tqdm(total=len(protein_paths_chunk), desc=f'loading complexes {i}/{len(protein_paths_chunk)//1000+1}') as pbar:
map_fn = p.imap_unordered if self.num_workers > 1 else map
for t in map_fn(self.get_complex, zip(protein_paths_chunk, lm_embeddings_chains, ligands_chunk,ligand_description_chunk)):
complex_graphs.extend(t[0])
rdkit_ligands.extend(t[1])
pbar.update()
if self.num_workers > 1: p.__exit__(None, None, None)
with open(os.path.join(self.full_cache_path, f"heterographs{i}.pkl"), 'wb') as f:
pickle.dump((complex_graphs), f)
with open(os.path.join(self.full_cache_path, f"rdkit_ligands{i}.pkl"), 'wb') as f:
pickle.dump((rdkit_ligands), f)
def check_all_complexes(self):
if os.path.exists(os.path.join(self.full_cache_path, f"heterographs.pkl")):
return True
complex_names_all = read_strings_from_txt(self.split_path)
if self.limit_complexes is not None and self.limit_complexes != 0:
complex_names_all = complex_names_all[:self.limit_complexes]
for i in range(len(complex_names_all) // 1000 + 1):
if not os.path.exists(os.path.join(self.full_cache_path, f"heterographs{i}.pkl")):
return False
return True
def collect_all_complexes(self):
print('Collecting all complexes from cache', self.full_cache_path)
if os.path.exists(os.path.join(self.full_cache_path, f"heterographs.pkl")):
with open(os.path.join(self.full_cache_path, "heterographs.pkl"), 'rb') as f:
complex_graphs = pickle.load(f)
if self.require_ligand:
with open(os.path.join(self.full_cache_path, "rdkit_ligands.pkl"), 'rb') as f:
rdkit_ligands = pickle.load(f)
else:
rdkit_ligands = None
return complex_graphs, rdkit_ligands
complex_names_all = read_strings_from_txt(self.split_path)
if self.limit_complexes is not None and self.limit_complexes != 0:
complex_names_all = complex_names_all[:self.limit_complexes]
complex_graphs_all = []
for i in range(len(complex_names_all) // 1000 + 1):
with open(os.path.join(self.full_cache_path, f"heterographs{i}.pkl"), 'rb') as f:
print(i)
l = pickle.load(f)
complex_graphs_all.extend(l)
rdkit_ligands_all = []
for i in range(len(complex_names_all) // 1000 + 1):
with open(os.path.join(self.full_cache_path, f"rdkit_ligands{i}.pkl"), 'rb') as f:
l = pickle.load(f)
rdkit_ligands_all.extend(l)
return complex_graphs_all, rdkit_ligands_all
def get_complex(self, par):
name, lm_embedding_chains, ligand, ligand_description = par
if not os.path.exists(os.path.join(self.pdbbind_dir, name)) and ligand is None:
print("Folder not found", name)
return [], []
try:
lig = read_mol(self.pdbbind_dir, name, suffix=self.ligand_file, remove_hs=False)
if self.max_lig_size != None and lig.GetNumHeavyAtoms() > self.max_lig_size:
print(f'Ligand with {lig.GetNumHeavyAtoms()} heavy atoms is larger than max_lig_size {self.max_lig_size}. Not including {name} in preprocessed data.')
return [], []
complex_graph = HeteroData()
complex_graph['name'] = name
get_lig_graph_with_matching(lig, complex_graph, self.popsize, self.maxiter, self.matching, self.keep_original,
self.num_conformers, remove_hs=self.remove_hs, tries=self.matching_tries)
moad_extract_receptor_structure(path=os.path.join(self.pdbbind_dir, name, f'{name}_{self.protein_file}.pdb'),
complex_graph=complex_graph,
neighbor_cutoff=self.receptor_radius,
max_neighbors=self.c_alpha_max_neighbors,
lm_embeddings=lm_embedding_chains,
knn_only_graph=self.knn_only_graph,
all_atoms=self.all_atoms,
atom_cutoff=self.atom_radius,
atom_max_neighbors=self.atom_max_neighbors)
except Exception as e:
print(f'Skipping {name} because of the error:')
print(e)
return [], []
if self.dataset == 'posebusters':
other_positions = []
all_mol_file = os.path.join(self.pdbbind_dir, name, f'{name}_ligands.sdf')
supplier = Chem.SDMolSupplier(all_mol_file, sanitize=False, removeHs=False)
for mol in supplier:
Chem.SanitizeMol(mol)
all_mol = RemoveAllHs(mol)
for conf in all_mol.GetConformers():
other_positions.append(conf.GetPositions())
print(f'Found {len(other_positions)} alternative poses for {name}')
complex_graph['ligand'].orig_pos = np.asarray(other_positions)
protein_center = torch.mean(complex_graph['receptor'].pos, dim=0, keepdim=True)
complex_graph['receptor'].pos -= protein_center
if self.all_atoms:
complex_graph['atom'].pos -= protein_center
if (not self.matching) or self.num_conformers == 1:
complex_graph['ligand'].pos -= protein_center
else:
for p in complex_graph['ligand'].pos:
p -= protein_center
complex_graph.original_center = protein_center
complex_graph['receptor_name'] = name
return [complex_graph], [lig]
def print_statistics(complex_graphs):
statistics = ([], [], [], [], [], [])
receptor_sizes = []
for complex_graph in complex_graphs:
lig_pos = complex_graph['ligand'].pos if torch.is_tensor(complex_graph['ligand'].pos) else complex_graph['ligand'].pos[0]
receptor_sizes.append(complex_graph['receptor'].pos.shape[0])
radius_protein = torch.max(torch.linalg.vector_norm(complex_graph['receptor'].pos, dim=1))
molecule_center = torch.mean(lig_pos, dim=0)
radius_molecule = torch.max(
torch.linalg.vector_norm(lig_pos - molecule_center.unsqueeze(0), dim=1))
distance_center = torch.linalg.vector_norm(molecule_center)
statistics[0].append(radius_protein)
statistics[1].append(radius_molecule)
statistics[2].append(distance_center)
if "rmsd_matching" in complex_graph:
statistics[3].append(complex_graph.rmsd_matching)
else:
statistics[3].append(0)
statistics[4].append(int(complex_graph.random_coords) if "random_coords" in complex_graph else -1)
if "random_coords" in complex_graph and complex_graph.random_coords and "rmsd_matching" in complex_graph:
statistics[5].append(complex_graph.rmsd_matching)
if len(statistics[5]) == 0:
statistics[5].append(-1)
name = ['radius protein', 'radius molecule', 'distance protein-mol', 'rmsd matching', 'random coordinates', 'random rmsd matching']
print('Number of complexes: ', len(complex_graphs))
for i in range(len(name)):
array = np.asarray(statistics[i])
print(f"{name[i]}: mean {np.mean(array)}, std {np.std(array)}, max {np.max(array)}")
return
def read_mol(pdbbind_dir, name, suffix='ligand', remove_hs=False):
lig = read_molecule(os.path.join(pdbbind_dir, name, f'{name}_{suffix}.sdf'), remove_hs=remove_hs, sanitize=True)
if lig is None: # read mol2 file if sdf file cannot be sanitized
lig = read_molecule(os.path.join(pdbbind_dir, name, f'{name}_{suffix}.mol2'), remove_hs=remove_hs, sanitize=True)
return lig
def read_mols(pdbbind_dir, name, remove_hs=False):
ligs = []
for file in os.listdir(os.path.join(pdbbind_dir, name)):
if file.endswith(".sdf") and 'rdkit' not in file:
lig = read_molecule(os.path.join(pdbbind_dir, name, file), remove_hs=remove_hs, sanitize=True)
if lig is None and os.path.exists(os.path.join(pdbbind_dir, name, file[:-4] + ".mol2")): # read mol2 file if sdf file cannot be sanitized
print('Using the .sdf file failed. We found a .mol2 file instead and are trying to use that.')
lig = read_molecule(os.path.join(pdbbind_dir, name, file[:-4] + ".mol2"), remove_hs=remove_hs, sanitize=True)
if lig is not None:
ligs.append(lig)
return ligs