# Significant contribution from Ben Fry
import copy
import os.path
import pickle
import random
from multiprocessing import Pool

import numpy as np
import pandas as pd
import torch
from rdkit import Chem
from rdkit.Chem import AllChem, MolFromSmiles
from scipy.spatial.distance import pdist, squareform
from torch_geometric.data import Dataset, HeteroData
from torch_geometric.utils import subgraph
from tqdm import tqdm

from datasets.constants import aa_to_cg_indices, amino_acid_smiles, cg_rdkit_indices
from datasets.parse_chi import aa_long2short, atom_order
from datasets.process_mols import new_extract_receptor_structure, get_lig_graph, generate_conformer
from utils.torsion import get_transformation_mask


def read_strings_from_txt(path):
    # every line will be one element of the returned list
    with open(path) as file:
        lines = file.readlines()
        return [line.rstrip() for line in lines]


def compute_num_ca_neighbors(coords, cg_coords, idx, is_valid_bb_node, max_dist=5, buffer_residue_num=7):
    """
    Counts the number of residues with heavy atoms within max_dist (Angstroms) of this sidechain
    that are not within +/- buffer_residue_num of it in primary sequence.
    From Ben's code. Note: Gabriele removed the chain_index.
    """
    # Extract coordinates of all residues in the protein.
    bb_coords = coords

    # Compute the indices whose interactions we should not count.
    excluded_neighbors = [idx - x for x in reversed(range(0, buffer_residue_num + 1)) if (idx - x) >= 0]
    excluded_neighbors.extend([idx + x for x in range(1, buffer_residue_num + 1)])

    # Create indices of an N x M distance matrix where N is num BB nodes and M is num CG nodes.
    e_idx = torch.stack([
        torch.arange(bb_coords.shape[0]).unsqueeze(-1).expand((-1, cg_coords.shape[0])).flatten(),
        torch.arange(cg_coords.shape[0]).unsqueeze(0).expand((bb_coords.shape[0], -1)).flatten()
    ])

    # Expand bb_coords and cg_coords into the same dimensionality.
    bb_coords_exp = bb_coords[e_idx[0]]
    cg_coords_exp = cg_coords[e_idx[1]].unsqueeze(1)

    # Every row is the distance of the chemical group to each atom in a backbone coordinate frame.
    bb_exp_idces, _ = (torch.cdist(bb_coords_exp, cg_coords_exp).squeeze(-1) < max_dist).nonzero(as_tuple=True)
    bb_idces_within_thresh = torch.unique(e_idx[0][bb_exp_idces])

    # Only count residues that are not the origin residue or adjacent to it in primary sequence
    # and are valid backbone residues (fully resolved coordinate frame).
    bb_idces_within_thresh = bb_idces_within_thresh[
        ~torch.isin(bb_idces_within_thresh, torch.tensor(excluded_neighbors))
        & is_valid_bb_node[bb_idces_within_thresh]
    ]
    return len(bb_idces_within_thresh)
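
# Editor's sketch (not part of the original pipeline): a minimal, self-contained call to
# compute_num_ca_neighbors on random placeholder coordinates, documenting the expected
# shapes: (R, 14, 3) all-atom coordinates and an (R,) boolean backbone-validity mask.
def _example_num_ca_neighbors():
    coords = torch.randn(50, 14, 3)                       # 50 residues, 14 atom slots each
    is_valid_bb_node = torch.ones(50, dtype=torch.bool)   # pretend every frame is resolved
    cg_coords = coords[20][:5]                            # 5 atoms of residue 20 as the chemical group
    return compute_num_ca_neighbors(coords, cg_coords, idx=20, is_valid_bb_node=is_valid_bb_node)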

def identify_valid_vandermers(args):
    """
    Constructs a tensor with the number of contacts of each residue that can be sampled from for
    chemical groups. By using every sidechain as a chemical group, we can load the actual chemical
    groups at training time. Once divided by their sum, the counts can be used as sampling
    probabilities.
    """
    complex_graph, max_dist, buffer_residue_num = args

    # Constructs a mask tracking whether each index is a valid coordinate frame / residue label to train over.
    # is_in_residue_vocabulary = torch.tensor([x in aa_short2long for x in data['seq']]).bool()
    coords, seq = complex_graph.coords, complex_graph.seq
    is_valid_bb_node = (coords[:, :4].isnan().sum(dim=(1, 2)) == 0).bool()  # * is_in_residue_vocabulary

    valid_cg_idces = []
    for idx, aa in enumerate(seq):
        if aa not in aa_to_cg_indices:
            valid_cg_idces.append(0)
        else:
            indices = aa_to_cg_indices[aa]
            cg_coordinates = coords[idx][indices]

            # skip chemical group residues that aren't fully resolved
            if torch.any(cg_coordinates.isnan()).item():
                valid_cg_idces.append(0)
                continue

            nbr_count = compute_num_ca_neighbors(coords, cg_coordinates, idx, is_valid_bb_node,
                                                 max_dist=max_dist, buffer_residue_num=buffer_residue_num)
            valid_cg_idces.append(nbr_count)

    return complex_graph.name, torch.tensor(valid_cg_idces)


def fast_identify_valid_vandermers(coords, seq, max_dist=5, buffer_residue_num=7):
    # large offset used to push NaN and excluded entries far beyond max_dist
    offset = 10000 + max_dist
    R = coords.shape[0]
    coords = coords.numpy().reshape(-1, 3)

    # compute pairwise distances between the closest atoms of every residue pair
    pdist_mat = squareform(pdist(coords))
    pdist_mat = pdist_mat.reshape((R, 14, R, 14))
    pdist_mat = np.nan_to_num(pdist_mat, nan=offset)
    pdist_mat = np.min(pdist_mat, axis=(1, 3))

    # exclude self-distances and residues within buffer_residue_num in primary sequence
    pdist_mat = pdist_mat + np.diag(np.ones(len(seq)) * offset)
    for i in range(1, buffer_residue_num + 1):
        pdist_mat += np.diag(np.ones(len(seq) - i) * offset, k=i) + np.diag(np.ones(len(seq) - i) * offset, k=-i)

    # get the number of residues that are within max_dist of each other
    nbr_count = np.sum(pdist_mat < max_dist, axis=1)
    return torch.tensor(nbr_count)


def compute_cg_features(aa, aa_smile):
    """
    Given an amino acid and a SMILES string, returns the stacked tensor of chemical group atom
    encodings. The order of the output tensor rows corresponds to the index at which the atoms
    appear in aa_to_cg_indices from constants.
    """
    # Handle any residues that we don't have chemical groups for (ex: GLY if not using bb_cnh and bb_cco).
    aa_short = aa_long2short[aa]
    if aa_short not in aa_to_cg_indices:
        return None

    # Create an rdkit molecule from the SMILES string.
    mol = Chem.MolFromSmiles(aa_smile)

    complex_graph = HeteroData()
    get_lig_graph(mol, complex_graph)

    atoms_to_keep = torch.tensor([i for i, _ in cg_rdkit_indices[aa].items()]).long()
    complex_graph['ligand', 'ligand'].edge_index, complex_graph['ligand', 'ligand'].edge_attr = \
        subgraph(atoms_to_keep, complex_graph['ligand', 'ligand'].edge_index,
                 complex_graph['ligand', 'ligand'].edge_attr, relabel_nodes=True)
    complex_graph['ligand'].x = complex_graph['ligand'].x[atoms_to_keep]

    edge_mask, mask_rotate = get_transformation_mask(complex_graph)
    complex_graph['ligand'].edge_mask = torch.tensor(edge_mask)
    complex_graph['ligand'].mask_rotate = mask_rotate
    return complex_graph
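
# Editor's sketch (not part of the original pipeline): PDBSidechain.__init__ below builds its
# chemical-group lookup essentially like this, keyed by one-letter residue codes; amino acids
# without a chemical-group definition come back as None from compute_cg_features.
def _example_cg_lookup():
    return {aa_long2short[aa]: compute_cg_features(aa, aa_smile)
            for aa, aa_smile in amino_acid_smiles.items()}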

class PDBSidechain(Dataset):
    def __init__(self, root, transform=None, cache_path='data/cache', split='train', limit_complexes=0,
                 receptor_radius=30, num_workers=1, c_alpha_max_neighbors=None, remove_hs=True, all_atoms=False,
                 atom_radius=5, atom_max_neighbors=None, sequences_to_embeddings=None, knn_only_graph=True,
                 multiplicity=1, vandermers_max_dist=5, vandermers_buffer_residue_num=7, vandermers_min_contacts=5,
                 remove_second_segment=False, merge_clusters=1, vandermers_extraction=True, add_random_ligand=False):
        super(PDBSidechain, self).__init__(root, transform)
        assert remove_hs == True, "not implemented yet"
        self.root = root
        self.split = split
        self.limit_complexes = limit_complexes
        self.receptor_radius = receptor_radius
        self.knn_only_graph = knn_only_graph
        self.multiplicity = multiplicity
        self.c_alpha_max_neighbors = c_alpha_max_neighbors
        self.num_workers = num_workers
        self.sequences_to_embeddings = sequences_to_embeddings
        self.remove_second_segment = remove_second_segment
        self.merge_clusters = merge_clusters
        self.vandermers_extraction = vandermers_extraction
        self.add_random_ligand = add_random_ligand
        self.all_atoms = all_atoms
        self.atom_radius = atom_radius
        self.atom_max_neighbors = atom_max_neighbors

        if vandermers_extraction:
            self.cg_node_feature_lookup_dict = {aa_long2short[aa]: compute_cg_features(aa, aa_smile)
                                                for aa, aa_smile in amino_acid_smiles.items()}

        self.cache_path = os.path.join(cache_path, f'PDB3_limit{self.limit_complexes}_INDEX{self.split}'
                                                   f'_recRad{self.receptor_radius}_recMax{self.c_alpha_max_neighbors}'
                                       + ('' if not all_atoms else f'_atomRad{atom_radius}_atomMax{atom_max_neighbors}')
                                       + ('' if not self.knn_only_graph else '_knnOnly'))

        self.read_split()

        if not self.check_all_proteins():
            os.makedirs(self.cache_path, exist_ok=True)
            self.preprocess()

        self.vandermers_max_dist = vandermers_max_dist
        self.vandermers_buffer_residue_num = vandermers_buffer_residue_num
        self.vandermers_min_contacts = vandermers_min_contacts
        self.collect_proteins()

        filtered_proteins = []
        if vandermers_extraction:
            for complex_graph in tqdm(self.protein_graphs):
                if complex_graph.name in self.vandermers and torch.any(self.vandermers[complex_graph.name] >= 10):
                    filtered_proteins.append(complex_graph)
            print(f"Computed vandermers and kept {len(filtered_proteins)} proteins out of {len(self.protein_graphs)}")
        else:
            filtered_proteins = self.protein_graphs

        second_filter = []
        for complex_graph in tqdm(filtered_proteins):
            if sequences_to_embeddings is None or complex_graph.orig_seq in sequences_to_embeddings:
                second_filter.append(complex_graph)
        print(f"Checked embeddings available and kept {len(second_filter)} proteins out of {len(filtered_proteins)}")
        self.protein_graphs = second_filter

        # filter out clusters that have no protein graphs
        self.split_clusters = list(set([g.cluster for g in self.protein_graphs]))
        self.cluster_to_complexes = {c: [] for c in self.split_clusters}
        for p in self.protein_graphs:
            self.cluster_to_complexes[p['cluster']].append(p)
        self.split_clusters = [c for c in self.split_clusters if len(self.cluster_to_complexes[c]) > 0]
        print("Total elements in set", len(self.split_clusters) * self.multiplicity // self.merge_clusters)

        self.name_to_complex = {p.name: p for p in self.protein_graphs}
        self.define_probabilities()

        if self.add_random_ligand:
            # read the CSV with all the SMILES
            with open('data/smiles_list.csv', 'r') as f:
                self.smiles_list = f.readlines()
            self.smiles_list = [s.split(',')[0] for s in self.smiles_list]

    def define_probabilities(self):
        if not self.vandermers_extraction:
            return

        if self.vandermers_min_contacts is not None:
            self.probabilities = torch.arange(1000) - self.vandermers_min_contacts + 1
            self.probabilities[:self.vandermers_min_contacts] = 0
        else:
            with open('data/pdbbind_counts.pkl', 'rb') as f:
                pdbbind_counts = pickle.load(f)

            pdb_counts = torch.ones(1000)
            for contacts in self.vandermers.values():
                pdb_counts.index_add_(0, contacts, torch.ones(contacts.shape))
            print(pdbbind_counts[:30])
            print(pdb_counts[:30])

            self.probabilities = pdbbind_counts / pdb_counts
            self.probabilities[:7] = 0

    def len(self):
        return len(self.split_clusters) * self.multiplicity // self.merge_clusters
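
    # Editor's note (illustrative): with the default vandermers_min_contacts=5,
    # define_probabilities above gives a residue with k contacts the sampling weight
    # max(0, k - 4), e.g. k=5 -> 1 and k=10 -> 6, so more buried sidechains are
    # drawn proportionally more often.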
    def get(self, idx=None, protein=None, smiles=None):
        assert idx is not None or (protein is not None and smiles is not None), \
            "provide either idx or both protein and smiles"

        if protein is None or smiles is None:
            idx = idx % len(self.split_clusters)
            if self.merge_clusters > 1:
                idx = idx * self.merge_clusters
                idx = idx + random.randint(0, self.merge_clusters - 1)
                idx = min(idx, len(self.split_clusters) - 1)
            cluster = self.split_clusters[idx]
            protein_graph = copy.deepcopy(random.choice(self.cluster_to_complexes[cluster]))
        else:
            protein_graph = copy.deepcopy(self.name_to_complex[protein])

        if self.sequences_to_embeddings is not None:
            # print(self.sequences_to_embeddings[protein_graph.orig_seq].shape, len(protein_graph.orig_seq), protein_graph.to_keep.shape)
            if len(protein_graph.orig_seq) != len(self.sequences_to_embeddings[protein_graph.orig_seq]):
                print('problem with ESM embeddings')
                return self.get(random.randint(0, self.len()))
            lm_embeddings = self.sequences_to_embeddings[protein_graph.orig_seq][protein_graph.to_keep]
            protein_graph['receptor'].x = torch.cat([protein_graph['receptor'].x, lm_embeddings], dim=1)

        if self.vandermers_extraction:
            # select the sidechain to remove
            vandermers_contacts = self.vandermers[protein_graph.name]
            vandermers_probs = self.probabilities[vandermers_contacts].numpy()

            if not np.any(vandermers_contacts.numpy() >= 10):
                print('no vandermers >= 10, retrying with a new one')
                return self.get(random.randint(0, self.len()))

            sidechain_idx = np.random.choice(np.arange(len(vandermers_probs)),
                                             p=vandermers_probs / np.sum(vandermers_probs))

            # remove part of the sequence
            residues_to_keep = np.ones(len(protein_graph.seq), dtype=bool)
            residues_to_keep[max(0, sidechain_idx - self.vandermers_buffer_residue_num):
                             min(sidechain_idx + self.vandermers_buffer_residue_num + 1, len(protein_graph.seq))] = False

            if self.remove_second_segment:
                pos_idx = protein_graph['receptor'].pos[sidechain_idx]
                limit_closeness = 10
                far_enough = torch.sum((protein_graph['receptor'].pos - pos_idx[None, :]) ** 2, dim=-1) > limit_closeness ** 2
                vandermers_probs = vandermers_probs * far_enough.float().numpy()
                vandermers_probs[max(0, sidechain_idx - self.vandermers_buffer_residue_num):
                                 min(sidechain_idx + self.vandermers_buffer_residue_num + 1, len(protein_graph.seq))] = 0
                if np.all(vandermers_probs <= 0):
                    print('no second vandermer available, retrying with a new one')
                    return self.get(random.randint(0, self.len()))

                sc2_idx = np.random.choice(np.arange(len(vandermers_probs)),
                                           p=vandermers_probs / np.sum(vandermers_probs))
                residues_to_keep[max(0, sc2_idx - self.vandermers_buffer_residue_num):
                                 min(sc2_idx + self.vandermers_buffer_residue_num + 1, len(protein_graph.seq))] = False

            residues_to_keep = torch.from_numpy(residues_to_keep)
            protein_graph['receptor'].pos = protein_graph['receptor'].pos[residues_to_keep]
            protein_graph['receptor'].x = protein_graph['receptor'].x[residues_to_keep]
            protein_graph['receptor'].side_chain_vecs = protein_graph['receptor'].side_chain_vecs[residues_to_keep]
            protein_graph['receptor', 'rec_contact', 'receptor'].edge_index = \
                subgraph(residues_to_keep, protein_graph['receptor', 'rec_contact', 'receptor'].edge_index,
                         relabel_nodes=True)[0]

            # create the sidechain ligand
            sidechain_aa = protein_graph.seq[sidechain_idx]
            ligand_graph = self.cg_node_feature_lookup_dict[sidechain_aa]
            ligand_graph['ligand'].pos = protein_graph.coords[sidechain_idx][protein_graph.mask[sidechain_idx]]

            for type in ligand_graph.node_types + ligand_graph.edge_types:
                for key, value in ligand_graph[type].items():
                    protein_graph[type][key] = value
            protein_graph['ligand'].orig_pos = protein_graph['ligand'].pos.numpy()

            protein_center = torch.mean(protein_graph['receptor'].pos, dim=0, keepdim=True)
            protein_graph['receptor'].pos = protein_graph['receptor'].pos - protein_center
            protein_graph['ligand'].pos = protein_graph['ligand'].pos - protein_center
            protein_graph.original_center = protein_center
            protein_graph['receptor_name'] = protein_graph.name
        else:
            protein_center = torch.mean(protein_graph['receptor'].pos, dim=0, keepdim=True)
            protein_graph['receptor'].pos = protein_graph['receptor'].pos - protein_center
            protein_graph.original_center = protein_center
            protein_graph['receptor_name'] = protein_graph.name
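        # Editor's note (illustrative): in the vandermers_extraction branch above, with the
        # default buffer_residue_num=7 and e.g. sidechain_idx=20 on a 100-residue chain,
        # residues_to_keep masks out indices 13..27, so the extracted sidechain "ligand"
        # never sees its own local backbone segment in the receptor graph.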
        if self.add_random_ligand:
            if smiles is not None:
                mol = MolFromSmiles(smiles)
                try:
                    generate_conformer(mol)
                except Exception as e:
                    print("failed to generate the given ligand, returning None", e)
                    return None
            else:
                success = False
                while not success:
                    smiles = random.choice(self.smiles_list)
                    mol = MolFromSmiles(smiles)
                    try:
                        success = not generate_conformer(mol)
                    except Exception as e:
                        print(e, "changing ligand")

            lig_graph = HeteroData()
            get_lig_graph(mol, lig_graph)

            edge_mask, mask_rotate = get_transformation_mask(lig_graph)
            lig_graph['ligand'].edge_mask = torch.tensor(edge_mask)
            lig_graph['ligand'].mask_rotate = mask_rotate
            lig_graph['ligand'].smiles = smiles
            lig_graph['ligand'].pos = lig_graph['ligand'].pos - torch.mean(lig_graph['ligand'].pos, dim=0, keepdim=True)

            for type in lig_graph.node_types + lig_graph.edge_types:
                for key, value in lig_graph[type].items():
                    protein_graph[type][key] = value

        for a in ['random_coords', 'coords', 'seq', 'sequence', 'mask', 'rmsd_matching', 'cluster', 'orig_seq',
                  'to_keep', 'chain_ids']:
            if hasattr(protein_graph, a):
                delattr(protein_graph, a)
            if hasattr(protein_graph['receptor'], a):
                delattr(protein_graph['receptor'], a)

        return protein_graph

    def read_split(self):
        # read the CSV file
        df = pd.read_csv(self.root + "/list.csv")
        print("Loaded list CSV file")

        # get the clusters and filter them by split
        if self.split == "train":
            val_clusters = set(read_strings_from_txt(self.root + "/valid_clusters.txt"))
            test_clusters = set(read_strings_from_txt(self.root + "/test_clusters.txt"))
            clusters = df["CLUSTER"].unique()
            clusters = [int(c) for c in clusters if c not in val_clusters and c not in test_clusters]
        elif self.split == "val":
            clusters = [int(s) for s in read_strings_from_txt(self.root + "/valid_clusters.txt")]
        elif self.split == "test":
            clusters = [int(s) for s in read_strings_from_txt(self.root + "/test_clusters.txt")]
        else:
            raise ValueError("Split must be train, val or test")
        print(self.split, "clusters", len(clusters))
        clusters = set(clusters)

        self.chains_in_cluster = []
        complexes_in_cluster = set()
        for chain, cluster in zip(df["CHAINID"], df["CLUSTER"]):
            if cluster not in clusters:
                continue
            # limit to one chain per complex
            if chain[:4] not in complexes_in_cluster:
                self.chains_in_cluster.append((chain, cluster))
                complexes_in_cluster.add(chain[:4])
        print("Filtered chains in cluster", len(self.chains_in_cluster))

        if self.limit_complexes > 0:
            self.chains_in_cluster = self.chains_in_cluster[:self.limit_complexes]

    def check_all_proteins(self):
        for i in range(len(self.chains_in_cluster) // 10000 + 1):
            if not os.path.exists(os.path.join(self.cache_path, f"protein_graphs{i}.pkl")):
                return False
        return True
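
    # Editor's note: the cache produced by preprocess() and consumed by collect_proteins()
    # below is chunked: every 10000 chains yield one protein_graphs{i}.pkl and, when
    # vandermers_extraction is enabled, one vandermers{i}_{max_dist}_{buffer}.pkl in cache_path.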
    def collect_proteins(self):
        self.protein_graphs = []
        self.vandermers = {}
        total_recovered = 0
        print(f'Loading {len(self.chains_in_cluster)} protein graphs.')
        list_indices = list(range(len(self.chains_in_cluster) // 10000 + 1))
        random.shuffle(list_indices)
        for i in list_indices:
            with open(os.path.join(self.cache_path, f"protein_graphs{i}.pkl"), 'rb') as f:
                print(i)
                l = pickle.load(f)
                total_recovered += len(l)
                self.protein_graphs.extend(l)

            if not self.vandermers_extraction:
                continue

            if os.path.exists(os.path.join(self.cache_path, f'vandermers{i}_{self.vandermers_max_dist}_{self.vandermers_buffer_residue_num}.pkl')):
                with open(os.path.join(self.cache_path, f'vandermers{i}_{self.vandermers_max_dist}_{self.vandermers_buffer_residue_num}.pkl'), 'rb') as f:
                    vandermers = pickle.load(f)
                self.vandermers.update(vandermers)
                continue

            vandermers = {}
            if self.num_workers > 1:
                p = Pool(self.num_workers, maxtasksperchild=1)
                p.__enter__()
            with tqdm(total=len(l), desc=f'computing vandermers {i}') as pbar:
                map_fn = p.imap_unordered if self.num_workers > 1 else map
                arguments = zip(l, [self.vandermers_max_dist] * len(l), [self.vandermers_buffer_residue_num] * len(l))
                for t in map_fn(identify_valid_vandermers, arguments):
                    if t is not None:
                        vandermers[t[0]] = t[1]
                    pbar.update()
            if self.num_workers > 1:
                p.__exit__(None, None, None)

            with open(os.path.join(self.cache_path, f'vandermers{i}_{self.vandermers_max_dist}_{self.vandermers_buffer_residue_num}.pkl'), 'wb') as f:
                pickle.dump(vandermers, f)
            self.vandermers.update(vandermers)

        print(f"Kept {len(self.protein_graphs)} proteins out of {len(self.chains_in_cluster)} total")
        return

    def preprocess(self):
        # run the preprocessing in parallel on multiple workers, saving the progress every 10000 proteins
        list_indices = list(range(len(self.chains_in_cluster) // 10000 + 1))
        random.shuffle(list_indices)
        for i in list_indices:
            if os.path.exists(os.path.join(self.cache_path, f"protein_graphs{i}.pkl")):
                continue

            chains_names = self.chains_in_cluster[10000 * i:10000 * (i + 1)]
            protein_graphs = []
            if self.num_workers > 1:
                p = Pool(self.num_workers, maxtasksperchild=1)
                p.__enter__()
            with tqdm(total=len(chains_names), desc=f'loading protein batch {i}/{len(self.chains_in_cluster) // 10000 + 1}') as pbar:
                map_fn = p.imap_unordered if self.num_workers > 1 else map
                for t in map_fn(self.load_chain, chains_names):
                    if t is not None:
                        protein_graphs.append(t)
                    pbar.update()
            if self.num_workers > 1:
                p.__exit__(None, None, None)

            with open(os.path.join(self.cache_path, f"protein_graphs{i}.pkl"), 'wb') as f:
                pickle.dump(protein_graphs, f)

        print("Finished preprocessing and saving protein graphs")

    def load_chain(self, c):
        chain, cluster = c
        if not os.path.exists(self.root + f"/pdb/{chain[1:3]}/{chain}.pt"):
            print("File not found", chain)
            return None

        data = torch.load(self.root + f"/pdb/{chain[1:3]}/{chain}.pt")
        complex_graph = HeteroData()
        complex_graph['name'] = chain
        orig_seq = data["seq"]
        coords = data["xyz"]
        mask = data["mask"].bool()

        # remove residues with NaN backbone coordinates
        to_keep = torch.logical_not(torch.any(torch.isnan(coords[:, :4, 0]), dim=1))
        coords = coords[to_keep]
        seq = ''.join(np.asarray(list(orig_seq))[to_keep.numpy()].tolist())
        mask = mask[to_keep]

        if len(coords) == 0:
            print("All coords were NaN", chain)
            return None

        try:
            new_extract_receptor_structure(seq, coords.numpy(), complex_graph=complex_graph,
                                           neighbor_cutoff=self.receptor_radius, max_neighbors=self.c_alpha_max_neighbors,
                                           knn_only_graph=self.knn_only_graph, all_atoms=self.all_atoms,
                                           atom_cutoff=self.atom_radius, atom_max_neighbors=self.atom_max_neighbors)
        except Exception as e:
            print("Error in extracting receptor", chain)
            print(e)
            return None

        if torch.any(torch.isnan(complex_graph['receptor'].pos)):
            print("NaN in pos receptor", chain)
            return None

        complex_graph.coords = coords
        complex_graph.seq = seq
        complex_graph.mask = mask
        complex_graph.cluster = cluster
        complex_graph.orig_seq = orig_seq
        complex_graph.to_keep = to_keep
        return complex_graph


if __name__ == "__main__":
    dataset = PDBSidechain(root="data/pdb_2021aug02_sample", split="train", multiplicity=1, limit_complexes=150)
    print(len(dataset))
    print(dataset[0])
    for p in dataset:
        print(p)
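
# Editor's note on the expected on-disk layout under `root` (inferred from read_split and
# load_chain above):
#   <root>/list.csv                       CSV with at least CHAINID and CLUSTER columns
#   <root>/valid_clusters.txt             one validation cluster id per line
#   <root>/test_clusters.txt              one test cluster id per line
#   <root>/pdb/<chain[1:3]>/<chain>.pt    torch files with "seq", "xyz" and "mask" entries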