# Lang2mol-Diff / train.py
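"""Training entry point for Lang2mol-Diff.

Spawns one worker per process, builds the TransformerNetModel denoiser and a
SpacedDiffusion process, and trains end-to-end on the Lang2mol text-to-molecule
training split via the improved-diffusion TrainLoop.
"""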
import argparse
import os
import warnings

import torch.distributed as dist
import torch.multiprocessing as mp
import wandb
from transformers import set_seed

from src.improved_diffusion import dist_util
from src.improved_diffusion import gaussian_diffusion as gd
from src.improved_diffusion.resample import create_named_schedule_sampler
from src.improved_diffusion.respace import SpacedDiffusion
from src.improved_diffusion.script_util import (
    add_dict_to_argparser,
    model_and_diffusion_defaults,
)
from src.improved_diffusion.train_util import TrainLoop
from src.improved_diffusion.transformer_model import TransformerNetModel
from src.scripts.mydatasets import Lang2molDataset_train, get_dataloader
from src.scripts.mytokenizers import Tokenizer


def main_worker(rank, world_size):
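    """Run training on a single distributed worker identified by `rank`."""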
    args = create_argparser().parse_args()
    set_seed(args.seed)  # seed comes from the CLI (defaults to 42)
    wandb.login(key=args.wandb_token)
    wandb.init(
        project="ACL_Lang2Mol",
        config=args.__dict__,
    )
    dist_util.setup_dist(rank, world_size)

    tokenizer = Tokenizer()
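    # Denoising network: a Transformer over token embeddings, with the
    # vocabulary size taken from the tokenizer.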
    model = TransformerNetModel(
        in_channels=args.model_in_channels,
        model_channels=args.model_model_channels,
        dropout=args.model_dropout,
        vocab_size=len(tokenizer),
        hidden_size=args.model_hidden_size,
        num_attention_heads=args.model_num_attention_heads,
        num_hidden_layers=args.model_num_hidden_layers,
    )
    if args.model_path != "":
        model.load_state_dict(
            dist_util.load_state_dict(args.model_path, map_location="cpu")
        )
    model.train()
print("Total params:", sum(p.numel() for p in model.parameters()))
print(
"Total trainable params:",
sum(p.numel() for p in model.parameters() if p.requires_grad),
)
print("Tokenizer vocab length:", len(tokenizer))
    diffusion = SpacedDiffusion(
        use_timesteps=list(range(args.diffusion_steps)),
        betas=gd.get_named_beta_schedule("sqrt", args.diffusion_steps),
        model_mean_type=gd.ModelMeanType.START_X,
        model_var_type=gd.ModelVarType.FIXED_LARGE,
        loss_type=gd.LossType.E2E_MSE,
        rescale_timesteps=True,
        model_arch="transformer",
        training_mode="e2e",
    )
    schedule_sampler = create_named_schedule_sampler("uniform", diffusion)
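    # Training data: the dataloader receives rank and world_size so each worker
    # can draw its own portion of the dataset.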
print("Loading data...")
train_dataset = Lang2molDataset_train(
dir=args.dataset_path,
tokenizer=tokenizer,
split="train",
corrupt_prob=0.0,
token_max_length=512,
dataset_name=args.dataset_name,
)
dataloader = get_dataloader(train_dataset, args.batch_size, rank, world_size)
print("Finish loading data")
    TrainLoop(
        model=model,
        diffusion=diffusion,
        data=dataloader,
        batch_size=args.batch_size,
        microbatch=args.microbatch,
        lr=args.lr,
        ema_rate=args.ema_rate,
        log_interval=args.log_interval,
        save_interval=args.save_interval,
        resume_checkpoint=args.resume_checkpoint,
        use_fp16=args.use_fp16,
        fp16_scale_growth=args.fp16_scale_growth,
        schedule_sampler=schedule_sampler,
        weight_decay=args.weight_decay,
        lr_anneal_steps=args.lr_anneal_steps,
        checkpoint_path=args.checkpoint_path,
        gradient_clipping=args.gradient_clipping,
        eval_data=None,
        eval_interval=args.eval_interval,
    ).run_loop()

    dist.destroy_process_group()


def create_argparser():
    defaults = dict()
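    # Every key below becomes a command-line flag via add_dict_to_argparser;
    # the model_* entries configure the TransformerNetModel built in main_worker.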
    text_defaults = dict(
        wandb_token="",
        batch_size=16,
        cache_mode="no",
        checkpoint_path="checkpoints",
        class_cond=False,
        config="ll",
        config_name="QizhiPei/biot5-base-text2mol",
        dataset_path="dataset",
        diffusion_steps=2000,
        dropout=0.01,
        e2e_train="",
        ema_rate="0.9999",
        emb_scale_factor=1.0,
        eval_interval=2000,
        experiment="random",
        experiment_mode="lm",
        fp16_scale_growth=0.001,
        gradient_clipping=2.4,
        image_size=8,
        in_channel=16,
        learn_sigma=False,
        log_interval=1000,
        logits_mode=1,
        lr=0.00005,
        lr_anneal_steps=500000,
        microbatch=-1,
        modality="e2e-tgt",
        model_arch="transformer",
        noise_level=0.0,
        noise_schedule="sqrt",
        num_channels=128,
        num_heads=4,
        num_heads_upsample=-1,
        num_res_blocks=2,
        out_channel=16,
        padding_mode="pad",
        predict_xstart=True,
        preprocessing_num_workers=1,
        rescale_learned_sigmas=True,
        rescale_timesteps=True,
        resume_checkpoint="",
        save_interval=50000,
        schedule_sampler="uniform",
        seed=42,
        timestep_respacing="",
        training_mode="e2e",
        use_bert_tokenizer="no",
        use_checkpoint=False,
        use_fp16=False,
        use_kl=False,
        use_scale_shift_norm=True,
        weight_decay=0.0,
        model_in_channels=32,
        model_model_channels=128,
        model_dropout=0.01,
        model_hidden_size=1024,
        model_num_attention_heads=16,
        model_num_hidden_layers=12,
        dataset_name="",
        model_path="",
    )
    defaults.update(model_and_diffusion_defaults())
    defaults.update(text_defaults)
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, defaults)
    return parser
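

# Single-process launch by default; increasing world_size spawns one worker per
# process via torch.multiprocessing (dist_util.setup_dist initializes each rank).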
if __name__ == "__main__":
    world_size = 1
    mp.spawn(main_worker, args=(world_size,), nprocs=world_size, join=True)