prajwalsahu5's picture
Upload 141 files
b3c2eb7
raw
history blame
6.02 kB
from rdkit import Chem
import codecs
import numpy as np
import pandas as pd
from rdkit.Chem import Descriptors
from subword_nmt.apply_bpe import BPE
import deepchem as dc
import deepchem.molnet as dcm
from rdkit import RDLogger
from torch.utils.data import DataLoader, Dataset, Subset
from pytorch_lightning import LightningDataModule
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)
def smiles2index(s1, words2idx, bpe):
t1 = bpe.process_line(s1).split()
i1 = [words2idx[i] for i in t1]
return i1
def index2multi_hot(i1, idx2word):
v1 = np.zeros(len(idx2word))
v1[i1] = 1
return v1
def index2multi_hot_fg(molecule, fgroups_list):
v1 = np.zeros(len(fgroups_list))
for idx in range(len(fgroups_list)):
if molecule.HasSubstructMatch(fgroups_list[idx]):
v1[idx] = 1
return v1
def smiles2vector_fgr(s1, words2idx, bpe, idx2word, fgroups_list):
i1 = smiles2index(s1, words2idx, bpe)
mfg = index2multi_hot(i1, idx2word)
molecule = Chem.MolFromSmiles(s1)
fg = index2multi_hot_fg(molecule, fgroups_list)
return fg, mfg
class FsrFgDataset(Dataset):
def __init__(self, data, idx2word, fgroups_list, words2idx, bpe, descriptor_funcs):
self.mols = data.X
self.y = data.y
self.smiles = data.ids
self.idx2word = idx2word
self.fgroups_list = fgroups_list
self.words2idx = words2idx
self.bpe = bpe
self.descriptor_funcs = descriptor_funcs
def __len__(self):
return len(self.y)
def __getitem__(self, idx):
mol = self.mols[idx]
smile = self.smiles[idx]
target = self.y[idx]
fg, mfg = smiles2vector_fgr(smile, self.words2idx, self.bpe, self.idx2word, self.fgroups_list)
num_features = np.asarray([self.descriptor_funcs[key](mol) for key in self.descriptor_funcs.keys()])
return np.float32(fg), np.float32(mfg), np.float32(num_features), int(target)
class FsrFgDataModule(LightningDataModule):
def __init__(self, root: str = 'Data', task_name: str = 'ecoli', batch_size: int = 8, num_workers: int = 8,
pin_memory: bool = True, split_type: str = 'scaffold', num_folds: int = 5, fold_index: int = 0):
super().__init__()
self.root = root
self.task_name = task_name
self.num_workers = num_workers
self.pin_memory = pin_memory
self.batch_size = batch_size
self.split_type = split_type
self.num_folds = num_folds
self.fold_index = fold_index
self.test_fold = None
self.val_fold = None
self.train_fold = None
self.fgroups_list = None
self.descriptor_funcs = None
self.bpe = None
self.idx2word = None
self.words2idx = None
self.val_data = None
self.test_data = None
self.train_data = None
self.splits = None
def prepare_data(self) -> None:
df = pd.read_csv('Data/ecoli_inhibition.csv')
ids = df['smiles']
labels = df['Encoded_Activity']
X = [Chem.MolFromSmiles(smiles) for smiles in ids]
data = dc.data.DiskDataset.from_numpy(X=X, y=labels, ids=ids)
fgroups = pd.read_csv(self.root + '/Functional_groups_filtered.csv')
fgroups_list = list(map(lambda x: Chem.MolFromSmarts(x), fgroups['SMARTS'].tolist()))
self.fgroups_list = [i for i in fgroups_list if i]
self.descriptor_funcs = {name: func for name, func in Descriptors.descList}
vocab_path = self.root + '/codes_drug_chembl_1500.txt'
bpe_codes_fin = codecs.open(vocab_path)
self.bpe = BPE(bpe_codes_fin, merges=-1, separator='')
vocab_map = pd.read_csv(self.root + '/subword_units_map_drug_chembl_1500.csv')
self.idx2word = vocab_map['index'].values
self.words2idx = dict(zip(self.idx2word, range(0, len(self.idx2word))))
self.weights = data.w
if self.split_type == 'scaffold':
splitter = dc.splits.ScaffoldSplitter()
self.splits = []
for i in range(self.num_folds):
new_data = data.complete_shuffle()
self.splits.append((new_data, splitter.split(dataset=new_data, seed=i)))
self.dataset = FsrFgDataset(self.splits[self.fold_index][0], self.idx2word, self.fgroups_list, self.words2idx, self.bpe,
self.descriptor_funcs)
self.train_ind, self.val_ind, self.test_ind = self.splits[self.fold_index][1]
else:
splitter = dc.splits.RandomStratifiedSplitter()
self.splits = [splitter.split(dataset=data, seed=fold_num) for fold_num in range(self.num_folds)]
self.dataset = FsrFgDataset(data, self.idx2word, self.fgroups_list, self.words2idx, self.bpe,
self.descriptor_funcs)
self.train_ind, self.val_ind, self.test_ind = self.splits[self.fold_index]
def setup(self, stage=None):
"""define train, test and validation datasets """
self.train_fold = Subset(self.dataset, self.train_ind)
self.val_fold = Subset(self.dataset, self.val_ind)
self.test_fold = Subset(self.dataset, self.test_ind)
def train_dataloader(self):
"""returns train dataloader"""
loader = DataLoader(self.train_fold, batch_size=self.batch_size, shuffle=False,
num_workers=self.num_workers, pin_memory=self.pin_memory, drop_last=True)
return loader
def val_dataloader(self):
"""returns val dataloader"""
loader = DataLoader(self.val_fold, batch_size=1, shuffle=False,
num_workers=self.num_workers, pin_memory=self.pin_memory)
return loader
def test_dataloader(self):
"""returns test dataloader"""
loader = DataLoader(self.test_fold, batch_size=1, shuffle=False,
num_workers=self.num_workers, pin_memory=self.pin_memory)
return loader