# Uploaded by eduardosoares99 — commit 6360d19 ("upload smi_ssed", verified).
# Deep learning
import torch
from smi_ssed.load import Smi_ssed
from trainer import TrainerDirectDecoder
# Parallel
from torch.utils.data.distributed import DistributedSampler
from torch.distributed import init_process_group, destroy_process_group
# Data
from utils import MoleculeModule
from torch.utils.data import DataLoader
# Standard library
import os
import args
def ddp_setup():
    """Join the NCCL process group and pin this process to its local GPU.

    The GPU index is read from the ``LOCAL_RANK`` environment variable after
    the process group has been initialised. That variable is presumably
    exported by the distributed launcher (e.g. ``torchrun``) — confirm
    against how this script is launched.
    """
    init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
def load_train_objs(config):
    """Build the distributed training loader, the Smi_ssed model and its optimizer.

    Args:
        config: run configuration object. Attributes read here: ``max_len``,
            ``train_load``, ``data_root``, ``n_batch``, ``n_workers`` and
            ``lr_decoder`` (plus whatever ``Smi_ssed`` reads from it).

    Returns:
        A ``(loader, model, optimizer)`` tuple.
    """
    # MoleculeModule provides the dataset (its ``pubchem`` attribute after
    # ``setup()``), the batch collate function and the vocabulary.
    data_module = MoleculeModule(config.max_len, config.train_load, config.data_root)
    data_module.setup()
    dataset = data_module.pubchem

    # A DistributedSampler is supplied, so DataLoader's own ``shuffle`` must
    # stay False; any shuffling is the sampler's responsibility.
    loader = DataLoader(
        dataset,
        batch_size=config.n_batch,
        pin_memory=True,
        shuffle=False,
        collate_fn=data_module.text_encoder.process,
        sampler=DistributedSampler(dataset),
        num_workers=config.n_workers,
    )

    # Model, with weights (re)initialised by the model's own ``_init_weights``.
    model = Smi_ssed(config, data_module.get_vocab())
    model.apply(model._init_weights)

    # Every model parameter is optimised at the ``lr_decoder`` learning rate.
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=config.lr_decoder,
        betas=(0.9, 0.99),
    )
    return loader, model, optimizer
def main(
    config,
    save_every: int,
    total_epochs: int,
    save_checkpoint_path: str,
    load_checkpoint_path: str
):
    """Set up DDP on this rank, train with TrainerDirectDecoder, then tear down.

    Args:
        config: run configuration, forwarded to ``load_train_objs`` and the trainer.
        save_every: checkpoint interval passed to the trainer.
        total_epochs: number of epochs passed to ``trainer.train``.
        save_checkpoint_path: where the trainer writes checkpoints.
        load_checkpoint_path: checkpoint the trainer may resume from.
    """
    # Initialise outside the ``try``: if initialisation itself fails there is
    # no process group to destroy, and doing so would mask the real error.
    ddp_setup()
    try:
        train_data, model, optimizer = load_train_objs(config)
        trainer = TrainerDirectDecoder(
            model,
            train_data,
            optimizer,
            save_every,
            save_checkpoint_path,
            load_checkpoint_path,
            config,
        )
        trainer.train(total_epochs)
    finally:
        # Release the process group even when loading or training raises.
        destroy_process_group()
if __name__ == '__main__':
    # Parse CLI options into ``config`` rather than ``args``: rebinding ``args``
    # here would shadow the imported ``args`` module for the rest of the file.
    parser = args.get_parser()
    config = parser.parse_args()
    main(
        config,
        config.checkpoint_every,
        config.max_epochs,
        save_checkpoint_path=config.save_checkpoint_path,
        load_checkpoint_path=config.load_checkpoint_path,
    )