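"""Training entry point for the Lang2Mol text-to-molecule diffusion model."""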
import argparse
import os
import warnings

import torch.distributed as dist
import torch.multiprocessing as mp
import wandb
from transformers import set_seed

from src.improved_diffusion import dist_util
from src.improved_diffusion import gaussian_diffusion as gd
from src.improved_diffusion.resample import create_named_schedule_sampler
from src.improved_diffusion.respace import SpacedDiffusion
from src.improved_diffusion.script_util import (
    add_dict_to_argparser,
    model_and_diffusion_defaults,
)
from src.improved_diffusion.train_util import TrainLoop
from src.improved_diffusion.transformer_model import TransformerNetModel
from src.scripts.mydatasets import Lang2molDataset_train, get_dataloader
from src.scripts.mytokenizers import Tokenizer
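

# One worker per process/rank: builds the tokenizer, model, and diffusion
# process, then hands everything to TrainLoop.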
def main_worker(rank, world_size):
    args = create_argparser().parse_args()
    set_seed(args.seed)
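
    # Log the full run configuration to Weights & Biases.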
    wandb.login(key=args.wandb_token)
    wandb.init(
        project="ACL_Lang2Mol",
        config=args.__dict__,
    )

    dist_util.setup_dist(rank, world_size)
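
    # The embedding layer is sized from the tokenizer, so the tokenizer and
    # model must be built together.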
    tokenizer = Tokenizer()
    model = TransformerNetModel(
        in_channels=args.model_in_channels,
        model_channels=args.model_model_channels,
        dropout=args.model_dropout,
        vocab_size=len(tokenizer),
        hidden_size=args.model_hidden_size,
        num_attention_heads=args.model_num_attention_heads,
        num_hidden_layers=args.model_num_hidden_layers,
    )
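
    # Optionally warm-start from a pretrained checkpoint.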
    if args.model_path != "":
        model.load_state_dict(
            dist_util.load_state_dict(args.model_path, map_location="cpu")
        )
    model.train()
print("Total params:", sum(p.numel() for p in model.parameters())) | |
print( | |
"Total trainable params:", | |
sum(p.numel() for p in model.parameters() if p.requires_grad), | |
) | |
print("Tokenizer vocab length:", len(tokenizer)) | |
    diffusion = SpacedDiffusion(
        use_timesteps=list(range(args.diffusion_steps)),
        betas=gd.get_named_beta_schedule("sqrt", args.diffusion_steps),
        model_mean_type=gd.ModelMeanType.START_X,
        model_var_type=gd.ModelVarType.FIXED_LARGE,
        loss_type=gd.LossType.E2E_MSE,
        rescale_timesteps=True,
        model_arch="transformer",
        training_mode="e2e",
    )
    schedule_sampler = create_named_schedule_sampler("uniform", diffusion)
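
    # Each rank reads its own shard of the training set; get_dataloader takes
    # the rank and world size to split the data across processes.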
print("Loading data...") | |
train_dataset = Lang2molDataset_train( | |
dir=args.dataset_path, | |
tokenizer=tokenizer, | |
split="train", | |
corrupt_prob=0.0, | |
token_max_length=512, | |
dataset_name=args.dataset_name, | |
) | |
dataloader = get_dataloader(train_dataset, args.batch_size, rank, world_size) | |
print("Finish loading data") | |
    TrainLoop(
        model=model,
        diffusion=diffusion,
        data=dataloader,
        batch_size=args.batch_size,
        microbatch=args.microbatch,
        lr=args.lr,
        ema_rate=args.ema_rate,
        log_interval=args.log_interval,
        save_interval=args.save_interval,
        resume_checkpoint=args.resume_checkpoint,
        use_fp16=args.use_fp16,
        fp16_scale_growth=args.fp16_scale_growth,
        schedule_sampler=schedule_sampler,
        weight_decay=args.weight_decay,
        lr_anneal_steps=args.lr_anneal_steps,
        checkpoint_path=args.checkpoint_path,
        gradient_clipping=args.gradient_clipping,
        eval_data=None,
        eval_interval=args.eval_interval,
    ).run_loop()
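
    # Tear down the process group once training finishes.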
    dist.destroy_process_group()
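

# Every key below becomes a --flag via add_dict_to_argparser, so all defaults
# can be overridden from the command line.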
def create_argparser():
    defaults = dict()
    text_defaults = dict(
        wandb_token="",
        batch_size=16,
        cache_mode="no",
        checkpoint_path="checkpoints",
        class_cond=False,
        config="ll",
        config_name="QizhiPei/biot5-base-text2mol",
        dataset_path="dataset",
        diffusion_steps=2000,
        dropout=0.01,
        e2e_train="",
        ema_rate="0.9999",
        emb_scale_factor=1.0,
        eval_interval=2000,
        experiment="random",
        experiment_mode="lm",
        fp16_scale_growth=0.001,
        gradient_clipping=2.4,
        image_size=8,
        in_channel=16,
        learn_sigma=False,
        log_interval=1000,
        logits_mode=1,
        lr=0.00005,
        lr_anneal_steps=500000,
        microbatch=-1,
        modality="e2e-tgt",
        model_arch="transformer",
        noise_level=0.0,
        noise_schedule="sqrt",
        num_channels=128,
        num_heads=4,
        num_heads_upsample=-1,
        num_res_blocks=2,
        out_channel=16,
        padding_mode="pad",
        predict_xstart=True,
        preprocessing_num_workers=1,
        rescale_learned_sigmas=True,
        rescale_timesteps=True,
        resume_checkpoint="",
        save_interval=50000,
        schedule_sampler="uniform",
        seed=42,
        timestep_respacing="",
        training_mode="e2e",
        use_bert_tokenizer="no",
        use_checkpoint=False,
        use_fp16=False,
        use_kl=False,
        use_scale_shift_norm=True,
        weight_decay=0.0,
        model_in_channels=32,
        model_model_channels=128,
        model_dropout=0.01,
        model_hidden_size=1024,
        model_num_attention_heads=16,
        model_num_hidden_layers=12,
        dataset_name="",
        model_path="",
    )
    defaults.update(model_and_diffusion_defaults())
    defaults.update(text_defaults)
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, defaults)
    return parser
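

# Single-process launch. To train on N GPUs, set world_size = N (assuming
# dist_util.setup_dist assigns each rank its own device).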
if __name__ == "__main__":
    world_size = 1
    mp.spawn(main_worker, args=(world_size,), nprocs=world_size, join=True)