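"""Train the latent diffusion model from diffusion.py with PyTorch Lightning,
checkpointing the best model by validation loss."""
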
import sys

import pytorch_lightning as L
import wandb
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.strategies import DDPStrategy

import config
from data_loader import get_dataloaders
from diffusion import Diffusion
from esm_utils import load_esm2_model

# Build train/validation dataloaders (the test loader is unused here)
train_loader, val_loader, _ = get_dataloaders(config)
# Load the ESM-2 tokenizer (the ESM model itself is not needed here)
tokenizer, _, _ = load_esm2_model(config.MODEL_NAME)
# Initialize diffusion model
latent_diffusion_model = Diffusion(config, latent_dim=config.LATENT_DIM, tokenizer=tokenizer)
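# Print the architecture and flush stdout so it shows up promptly in buffered logs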
print(latent_diffusion_model)
sys.stdout.flush()
# Checkpoint callback: keep only the single best model by minimum validation loss
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    save_top_k=1,
    mode='min',
    dirpath="/workspace/a03-sgoel/MDpLM/",
    filename="best_model_epoch{epoch:02d}"
)
# Initialize trainer
trainer = L.Trainer(
    max_epochs=config.Training.NUM_EPOCHS,
    precision=config.Training.PRECISION,
    devices=1,
    accelerator='gpu',
    strategy=DDPStrategy(find_unused_parameters=False),
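    # NOTE: with devices=1, DDP adds no parallelism; raise `devices` for multi-GPU runs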
    accumulate_grad_batches=config.Training.ACCUMULATE_GRAD_BATCHES,
    default_root_dir=config.Training.SAVE_DIR,
    callbacks=[checkpoint_callback]
)
print(trainer)
print("Training model...")
sys.stdout.flush()
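# Note: with multi-device DDP, Lightning relaunches this script once per rank,
# so all of the module-level setup above runs in every process.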
# Train the model
trainer.fit(latent_diffusion_model, train_loader, val_loader)
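# Close the W&B run if one was opened elsewhere (e.g. wandb.init() in the model
# or a WandbLogger); wandb.finish() is a no-op when no run is active.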
wandb.finish()