import os

import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.distributed import init_process_group, destroy_process_group

from smi_ssed.load import Smi_ssed
from trainer import TrainerDirectDecoder
from utils import MoleculeModule
import args

def ddp_setup():
    """Initialize torch.distributed and pin this process to its local GPU."""
    # The launcher (e.g. torchrun) populates LOCAL_RANK/RANK/WORLD_SIZE.
    init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

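# Example launch (the script name below is illustrative, not part of this repo):
#   torchrun --nproc_per_node=<num_gpus> train_decoder.py --max_epochs 10
# torchrun sets the LOCAL_RANK environment variable that ddp_setup() reads.
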
def load_train_objs(config):
    """Build the training DataLoader, the model, and its optimizer."""
    train_loader = MoleculeModule(
        config.max_len,
        config.train_load,
        config.data_root
    )
    train_loader.setup()
    train_dataset = train_loader.pubchem

    # DistributedSampler shards the dataset across ranks and handles shuffling,
    # so the DataLoader itself must keep shuffle=False.
    loader = DataLoader(
        train_dataset,
        batch_size=config.n_batch,
        pin_memory=True,
        shuffle=False,
        collate_fn=train_loader.text_encoder.process,
        sampler=DistributedSampler(train_dataset),
        num_workers=config.n_workers
    )

    model = Smi_ssed(config, train_loader.get_vocab())
    model.apply(model._init_weights)

    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr_decoder, betas=(0.9, 0.99))

    return loader, model, optimizer

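# Assumption: TrainerDirectDecoder calls loader.sampler.set_epoch(epoch) at the
# start of each epoch; DistributedSampler needs that call to reshuffle data
# across epochs rather than repeating the same ordering.
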
def main(
    config,
    save_every: int,
    total_epochs: int,
    save_checkpoint_path: str,
    load_checkpoint_path: str
):
    """Run DDP training end to end: process-group setup, training, teardown."""
    ddp_setup()
    train_data, model, optimizer = load_train_objs(config)
    trainer = TrainerDirectDecoder(
        model,
        train_data,
        optimizer,
        save_every,
        save_checkpoint_path,
        load_checkpoint_path,
        config
    )
    trainer.train(total_epochs)
    destroy_process_group()

if __name__ == '__main__':
    parser = args.get_parser()
    # Bind the parsed namespace to a distinct name so it does not shadow the
    # imported args module.
    parsed_args = parser.parse_args()
    main(
        parsed_args,
        parsed_args.checkpoint_every,
        parsed_args.max_epochs,
        save_checkpoint_path=parsed_args.save_checkpoint_path,
        load_checkpoint_path=parsed_args.load_checkpoint_path,
    )