"""Training entry point for the latent ``Diffusion`` model.

Executing this module builds the train/validation dataloaders from
``config``, loads the tokenizer for ``config.MODEL_NAME``, constructs the
``Diffusion`` model and fits it with a single-GPU Lightning ``Trainer``.
The checkpoint with the lowest monitored ``val_loss`` is kept on disk.
"""

# NOTE: import order is kept as in the original script; the project-local
# modules may rely on it (import-time side effects) -- confirm before sorting.
import pytorch_lightning as L
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.callbacks import ModelCheckpoint
import config
from data_loader import get_dataloaders
from esm_utils import load_esm2_model
from diffusion import Diffusion
import wandb
import sys

# Get dataloaders (the third returned loader is not used for training).
train_loader, val_loader, _ = get_dataloaders(config)

# Initialize ESM tokenizer; the other two values returned by the loader are
# discarded here.
tokenizer, _, _ = load_esm2_model(config.MODEL_NAME)

# Initialize diffusion model
latent_diffusion_model = Diffusion(
    config,
    latent_dim=config.LATENT_DIM,
    tokenizer=tokenizer,
)
print(latent_diffusion_model)
sys.stdout.flush()

# Define checkpoints to save best model by minimum validation loss.
# Assumes the model logs a metric named 'val_loss' -- TODO confirm in Diffusion.
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    save_top_k=1,
    mode='min',
    dirpath="/workspace/a03-sgoel/MDpLM/",
    filename="best_model_epoch{epoch:02d}",
    # With the default (True) Lightning expands "{epoch:02d}" to
    # "epoch={epoch:02d}", yielding "best_model_epochepoch=05.ckpt".
    # Disabling it gives the intended "best_model_epoch05.ckpt".
    auto_insert_metric_name=False,
)

# Initialize trainer
trainer = L.Trainer(
    max_epochs=config.Training.NUM_EPOCHS,
    precision=config.Training.PRECISION,
    devices=1,
    accelerator='gpu',
    strategy=DDPStrategy(find_unused_parameters=False),
    accumulate_grad_batches=config.Training.ACCUMULATE_GRAD_BATCHES,
    default_root_dir=config.Training.SAVE_DIR,
    callbacks=[checkpoint_callback],
)
print(trainer)
print("Training model...")
sys.stdout.flush()

# Train the model. The wandb run is not started in this script (presumably it
# is created elsewhere, e.g. by the model or a logger -- confirm); make sure
# it is closed even when training raises so the run is not left dangling.
try:
    trainer.fit(latent_diffusion_model, train_loader, val_loader)
finally:
    wandb.finish()