import torch
from torch_geometric.data import Dataset

from datasets.dataloader import DataLoader, DataListLoader
from datasets.moad import MOAD
from datasets.pdb import PDBSidechain
from datasets.pdbbind import NoiseTransform, PDBBind
from utils.utils import read_strings_from_txt


class CombineDatasets(Dataset):
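    """Concatenates two datasets behind a single torch_geometric Dataset interface."""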
    def __init__(self, dataset1, dataset2):
        super().__init__()
        self.dataset1 = dataset1
        self.dataset2 = dataset2

    def len(self):
        return len(self.dataset1) + len(self.dataset2)

    def get(self, idx):
        if idx < len(self.dataset1):
            return self.dataset1[idx]
        else:
            return self.dataset2[idx - len(self.dataset1)]

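    # New complexes can only be appended to the first dataset.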
    def add_complexes(self, new_complex_list):
        self.dataset1.add_complexes(new_complex_list)


def construct_loader(args, t_to_sigma, device):
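    """Build the training and validation dataloaders for PDBBind, MOAD, PDBSidechain, or combinations thereof."""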
    val_dataset2 = None
    # Noise transform applied on the fly to each complex during training.
    transform = NoiseTransform(t_to_sigma=t_to_sigma, no_torsion=args.no_torsion,
                               all_atom=args.all_atoms, alpha=args.sampling_alpha, beta=args.sampling_beta,
                               include_miscellaneous_atoms=getattr(args, 'include_miscellaneous_atoms', False),
                               crop_beyond_cutoff=args.crop_beyond)
    if args.triple_training:
        assert args.combined_training, 'triple_training requires combined_training'

    sequences_to_embeddings = None
    if args.dataset == 'pdbsidechain' or args.triple_training:
        if args.pdbsidechain_esm_embeddings_path is not None:
            print('Loading ESM embeddings')
            id_to_embeddings = torch.load(args.pdbsidechain_esm_embeddings_path)
            sequences_list = read_strings_from_txt(args.pdbsidechain_esm_embeddings_sequences_path)
            sequences_to_embeddings = {}
            for i, seq in enumerate(sequences_list):
                if str(i) in id_to_embeddings:
                    sequences_to_embeddings[seq] = id_to_embeddings[str(i)]

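        # Shared constructor arguments for the PDBSidechain splits.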
        common_args = {'root': args.pdbsidechain_dir, 'transform': transform, 'limit_complexes': args.limit_complexes,
                       'receptor_radius': args.receptor_radius,
                       'c_alpha_max_neighbors': args.c_alpha_max_neighbors,
                       'remove_hs': args.remove_hs, 'num_workers': args.num_workers, 'all_atoms': args.all_atoms,
                       'atom_radius': args.atom_radius, 'atom_max_neighbors': args.atom_max_neighbors,
                       'knn_only_graph': not args.not_knn_only_graph, 'sequences_to_embeddings': sequences_to_embeddings,
                       'vandermers_max_dist': args.vandermers_max_dist,
                       'vandermers_buffer_residue_num': args.vandermers_buffer_residue_num,
                       'vandermers_min_contacts': args.vandermers_min_contacts,
                       'remove_second_segment': args.remove_second_segment,
                       'merge_clusters': args.merge_clusters}
        train_dataset3 = PDBSidechain(cache_path=args.cache_path, split='train', multiplicity=args.train_multiplicity, **common_args)

        if args.dataset == 'pdbsidechain':
            train_dataset = train_dataset3
            val_dataset = PDBSidechain(cache_path=args.cache_path, split='val', multiplicity=args.val_multiplicity, **common_args)
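        # List-based loader on GPU (keeps per-graph Data objects, e.g. for DataParallel); standard loader on CPU.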
        loader_class = DataListLoader if torch.cuda.is_available() else DataLoader

    if args.dataset in ['pdbbind', 'moad', 'generalisation', 'distillation']:
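        # Shared constructor arguments for the protein-ligand complex datasets (PDBBind / MOAD).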
        common_args = {'transform': transform, 'limit_complexes': args.limit_complexes,
                       'chain_cutoff': args.chain_cutoff, 'receptor_radius': args.receptor_radius,
                       'c_alpha_max_neighbors': args.c_alpha_max_neighbors,
                       'remove_hs': args.remove_hs, 'max_lig_size': args.max_lig_size,
                       'matching': not args.no_torsion, 'popsize': args.matching_popsize, 'maxiter': args.matching_maxiter,
                       'num_workers': args.num_workers, 'all_atoms': args.all_atoms,
                       'atom_radius': args.atom_radius, 'atom_max_neighbors': args.atom_max_neighbors,
                       'knn_only_graph': not getattr(args, 'not_knn_only_graph', True),
                       'include_miscellaneous_atoms': getattr(args, 'include_miscellaneous_atoms', False),
                       'matching_tries': args.matching_tries}

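        # PDBBind training split (also built when combined training is enabled).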
        if args.dataset == 'pdbbind' or args.dataset == 'generalisation' or args.combined_training:
            train_dataset = PDBBind(cache_path=args.cache_path, split_path=args.split_train, keep_original=True,
                                    num_conformers=args.num_conformers, root=args.pdbbind_dir,
                                    esm_embeddings_path=args.pdbbind_esm_embeddings_path,
                                    protein_file=args.protein_file, **common_args)

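        # Binding MOAD training split; under combined training it is concatenated with PDBBind.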
        if args.dataset == 'moad' or args.combined_training:
            train_dataset2 = MOAD(cache_path=args.cache_path, split='train', keep_original=True,
                                  num_conformers=args.num_conformers, max_receptor_size=args.max_receptor_size,
                                  remove_promiscuous_targets=args.remove_promiscuous_targets, min_ligand_size=args.min_ligand_size,
                                  multiplicity=args.train_multiplicity, unroll_clusters=args.unroll_clusters,
                                  esm_embeddings_sequences_path=args.moad_esm_embeddings_sequences_path,
                                  root=args.moad_dir, esm_embeddings_path=args.moad_esm_embeddings_path,
                                  enforce_timesplit=args.enforce_timesplit, **common_args)

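            # Combined training puts MOAD first in CombineDatasets (so add_complexes forwards to it);
            # triple training additionally appends the PDBSidechain set.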
            if args.combined_training:
                train_dataset = CombineDatasets(train_dataset2, train_dataset)
                if args.triple_training:
                    train_dataset = CombineDatasets(train_dataset, train_dataset3)
            else:
                train_dataset = train_dataset2

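        # PDBBind validation split; with double_val it is also kept as a second validation dataset.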
        if args.dataset == 'pdbbind' or args.double_val:
            val_dataset = PDBBind(cache_path=args.cache_path, split_path=args.split_val, keep_original=True,
                                  esm_embeddings_path=args.pdbbind_esm_embeddings_path, root=args.pdbbind_dir,
                                  protein_file=args.protein_file, require_ligand=True, **common_args)
            if args.double_val:
                val_dataset2 = val_dataset

        if args.dataset == 'moad' or args.dataset == 'generalisation':
            val_dataset = MOAD(cache_path=args.cache_path, split='val', keep_original=True,
                               multiplicity=args.val_multiplicity, max_receptor_size=args.max_receptor_size,
                               remove_promiscuous_targets=args.remove_promiscuous_targets, min_ligand_size=args.min_ligand_size,
                               esm_embeddings_sequences_path=args.moad_esm_embeddings_sequences_path,
                               unroll_clusters=args.unroll_clusters, root=args.moad_dir,
                               esm_embeddings_path=args.moad_esm_embeddings_path, require_ligand=True, **common_args)

        loader_class = DataListLoader if torch.cuda.is_available() else DataLoader

    train_loader = loader_class(dataset=train_dataset, batch_size=args.batch_size,
                                num_workers=args.num_dataloader_workers, shuffle=True,
                                pin_memory=args.pin_memory, drop_last=args.dataloader_drop_last)
    val_loader = loader_class(dataset=val_dataset, batch_size=args.batch_size,
                              num_workers=args.num_dataloader_workers, shuffle=False,
                              pin_memory=args.pin_memory, drop_last=args.dataloader_drop_last)
    return train_loader, val_loader, val_dataset2
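
# Minimal usage sketch (illustrative). The diffusion schedule helper is typically partially
# applied to the run arguments before being passed in; the import path below is an assumption.
#
#   from functools import partial
#   from utils.diffusion_utils import t_to_sigma as t_to_sigma_compl  # assumed location
#
#   t_to_sigma = partial(t_to_sigma_compl, args=args)
#   train_loader, val_loader, val_dataset2 = construct_loader(args, t_to_sigma, device)
#   for batch in train_loader:
#       pass  # training step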