|
|
|
import torch
|
|
|
|
|
|
from pubchem_encoder import Encoder
|
|
from datasets import concatenate_datasets, load_dataset
|
|
|
|
|
|
import os
|
|
import getpass
|
|
import glob
|
|
|
|
|
|
class MoleculeModule:
    """Loads a SMILES corpus (ZINC, PubChem, or both) through ``datasets.load_dataset``.

    Args:
        max_len: Passed to ``Encoder(max_len)``.
        dataset: Selector string, matched by substring in ``setup``:
            ``'ZINC'``/``'zinc'`` -> ZINC, ``'pubchem'`` -> PubChem,
            ``'both'``/``'Both'``/``'BOTH'`` -> ZINC and PubChem concatenated.
        data_path: PubChem data file. If it contains ``'canonical'``
            (case-insensitive), ``./pubchem_canon_script.py`` is used as the
            loader script instead of ``./pubchem_script.py``.
    """

    def __init__(self, max_len, dataset, data_path):
        super().__init__()
        self.dataset = dataset
        self.data_path = data_path
        self.text_encoder = Encoder(max_len)

    def prepare_data(self):
        """No-op; all loading is done in ``setup``."""
        pass

    def get_vocab(self):
        """Return the encoder's ``char2id`` mapping."""
        return self.text_encoder.char2id

    def get_cache(self):
        """Return the cache roots recorded by ``setup``.

        Raises:
            AttributeError: If ``setup`` has not been called yet.
        """
        return self.cache_files

    @staticmethod
    def _cache_root(filename):
        """Return the first three components of an absolute path.

        For a file under ``/tmp/<user>/<name>/...`` this is ``/tmp/<user>/<name>``.
        """
        return '/'.join(filename.split('/')[:4])

    @staticmethod
    def _user_cache_dir(name):
        """Return the per-user cache directory ``/tmp/<user>/<name>``."""
        return os.path.join('/tmp', getpass.getuser(), name)

    def _load_pubchem(self, pubchem_script, pubchem_path):
        """Load the PubChem ``train`` split with the given loader script."""
        return load_dataset(pubchem_script, data_files=pubchem_path,
                            cache_dir=self._user_cache_dir('pubchem'), split='train')

    def _load_zinc(self, zinc_path):
        """Load every ``*.smi`` file in *zinc_path* as one ``train`` split.

        Raises:
            FileNotFoundError: If *zinc_path* contains no ``*.smi`` files.
        """
        # Sorted so the row order does not depend on directory listing order.
        zinc_files = sorted(glob.glob(os.path.join(zinc_path, '*.smi')))
        if not zinc_files:
            raise FileNotFoundError(f'no *.smi files found in {zinc_path!r}')
        for zfile in zinc_files:
            print(zfile)
        # A local dict is used for data_files so that self.dataset keeps the
        # original selector and repeated setup() calls pick the same source.
        return load_dataset('./zinc_script.py', data_files={'train': zinc_files},
                            cache_dir=self._user_cache_dir('zinc'), split='train')

    def setup(self, stage=None):
        """Load the dataset selected by ``self.dataset`` and record its caches.

        Branches are checked in order: ZINC, then PubChem, then both (ZINC rows
        first, PubChem rows after). ``self.dataset`` is not modified, so repeated
        calls select the same source.

        Sets ``self.pubchem`` to the loaded split and ``self.cache_files`` to one
        cache root per entry of the split's ``cache_files``.

        Args:
            stage: Unused.

        Raises:
            ValueError: If ``self.dataset`` matches none of the known selectors.
            FileNotFoundError: If ZINC is selected and no ``*.smi`` files exist.
        """
        # Also published as a module-level global, as before.
        # NOTE(review): nothing in this file reads it; drop once confirmed unused elsewhere.
        global dataset_dict

        pubchem_path = {'train': self.data_path}
        if 'canonical' in self.data_path.lower():
            pubchem_script = './pubchem_canon_script.py'
        else:
            pubchem_script = './pubchem_script.py'
        zinc_path = './data/ZINC'

        selector = self.dataset
        if 'ZINC' in selector or 'zinc' in selector:
            dataset_dict = self._load_zinc(zinc_path)
        elif 'pubchem' in selector:
            dataset_dict = self._load_pubchem(pubchem_script, pubchem_path)
        elif 'both' in selector or 'Both' in selector or 'BOTH' in selector:
            dataset_dict_pubchem = self._load_pubchem(pubchem_script, pubchem_path)
            dataset_dict_zinc = self._load_zinc(zinc_path)
            dataset_dict = concatenate_datasets([dataset_dict_zinc, dataset_dict_pubchem])
        else:
            # Without this, an unknown selector would silently reuse whatever
            # dataset an earlier call left in the global, or hit a NameError.
            raise ValueError(
                f"unknown dataset {selector!r}; expected one containing "
                "'zinc', 'pubchem' or 'both'"
            )

        self.pubchem = dataset_dict
        print(dataset_dict.cache_files)
        self.cache_files = [self._cache_root(cache['filename'])
                            for cache in dataset_dict.cache_files]
|
|
|
|
|
|
def get_optim_groups(module, weight_decay=0.0):
    """Split *module*'s parameters into decay / no-decay optimizer param groups.

    Classification by parameter name and the module that owns it:
      * names ending in ``bias``                                    -> no decay
      * ``weight`` of a ``torch.nn.Linear`` (or subclass)            -> decay
      * ``weight`` of ``torch.nn.LayerNorm`` / ``torch.nn.Embedding`` -> no decay
      * anything else (e.g. ``MultiheadAttention.in_proj_weight``)   -> no decay

    Every parameter of *module* appears in exactly one group.

    Args:
        module: The ``torch.nn.Module`` whose parameters are grouped.
        weight_decay: Weight decay for the decay group. The default ``0.0``
            reproduces the previous behaviour, where both groups had zero decay.

    Returns:
        ``[decay_group, no_decay_group]``: dicts with ``"params"`` (sorted by
        parameter name) and ``"weight_decay"`` keys, usable as ``torch.optim``
        param groups.
    """
    whitelist_weight_modules = (torch.nn.Linear,)
    blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)

    decay = set()
    no_decay = set()
    for mn, m in module.named_modules():
        for pn, _ in m.named_parameters():
            fpn = f'{mn}.{pn}' if mn else pn  # fully qualified parameter name
            if pn.endswith('bias'):
                no_decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                no_decay.add(fpn)

    # named_parameters() lists a shared (tied) parameter only once, under its
    # first name. Drop alias names so the lookups below cannot raise KeyError
    # and no tensor ends up in two groups.
    param_dict = dict(module.named_parameters())
    decay = {name for name in decay if name in param_dict}
    no_decay = {name for name in no_decay if name in param_dict}

    # Parameters matched by no rule would otherwise be left out of every group
    # and never be updated by the optimizer; train them without weight decay.
    no_decay |= set(param_dict) - decay - no_decay

    return [
        {"params": [param_dict[pn] for pn in sorted(decay)], "weight_decay": weight_decay},
        {"params": [param_dict[pn] for pn in sorted(no_decay)], "weight_decay": 0.0},
    ]