# MeMDLM/scripts/train.py

import pytorch_lightning as L
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.callbacks import ModelCheckpoint
import config
from data_loader import get_dataloaders
from esm_utils import load_esm2_model
from diffusion import Diffusion
import wandb
import sys
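
# The imported `config` module is expected to expose MODEL_NAME, LATENT_DIM,
# and a Training namespace with NUM_EPOCHS, PRECISION, ACCUMULATE_GRAD_BATCHES,
# and SAVE_DIR, judging from the attribute accesses below.
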
# Get dataloaders
train_loader, val_loader, _ = get_dataloaders(config)

# Load the ESM-2 tokenizer (the pretrained model returned by load_esm2_model is unused here)
tokenizer, _, _ = load_esm2_model(config.MODEL_NAME)

# Initialize diffusion model
latent_diffusion_model = Diffusion(config, latent_dim=config.LATENT_DIM, tokenizer=tokenizer)
print(latent_diffusion_model)
sys.stdout.flush()
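
# `Diffusion` is assumed to be a LightningModule subclass, since it is passed
# directly to trainer.fit() below.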

# Define checkpoint callback to save the best model by minimum validation loss
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    save_top_k=1,
    mode='min',
    dirpath="/workspace/a03-sgoel/MDpLM/",
    filename="best_model_epoch{epoch:02d}"
)
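
# Note: with devices=1, DDPStrategy effectively runs a single process;
# find_unused_parameters=False skips the per-step scan for unused parameters,
# which is safe as long as every registered parameter receives gradients.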

# Initialize trainer
trainer = L.Trainer(
    max_epochs=config.Training.NUM_EPOCHS,
    precision=config.Training.PRECISION,
    devices=1,
    accelerator='gpu',
    strategy=DDPStrategy(find_unused_parameters=False),
    accumulate_grad_batches=config.Training.ACCUMULATE_GRAD_BATCHES,
    default_root_dir=config.Training.SAVE_DIR,
    callbacks=[checkpoint_callback]
)

print(trainer)
print("Training model...")
sys.stdout.flush()
# Train the model
trainer.fit(latent_diffusion_model, train_loader, val_loader)
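
# wandb.finish() closes any active Weights & Biases run; note that this script
# never calls wandb.init() or attaches a WandbLogger, so it is effectively a
# no-op unless a run is started elsewhere (e.g., inside the Diffusion module).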
wandb.finish()