import sys

import pytorch_lightning as L
import wandb
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.strategies import DDPStrategy

import config
from data_loader import get_dataloaders
from diffusion import Diffusion
from esm_utils import load_esm2_model

# Build the train/validation dataloaders from the shared config module
# (the third returned loader is unused here).
train_loader, val_loader, _ = get_dataloaders(config)

# Only the tokenizer is needed in this script; the remaining return values
# of load_esm2_model are discarded.
tokenizer, _, _ = load_esm2_model(config.MODEL_NAME)

# Instantiate the latent diffusion model and echo its architecture.
latent_diffusion_model = Diffusion(config, latent_dim=config.LATENT_DIM, tokenizer=tokenizer)
print(latent_diffusion_model)
sys.stdout.flush()

# Keep only the single best checkpoint, ranked by validation loss.
# Note: dirpath is hardcoded rather than taken from config.Training.SAVE_DIR
# used below.
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    save_top_k=1,
    mode='min',
    dirpath="/workspace/a03-sgoel/MDpLM/",
    # Lightning's default auto_insert_metric_name=True renders {epoch:02d} as
    # "epoch=NN", so the literal "epoch" prefix is dropped here to avoid
    # filenames like "best_model_epochepoch=NN".
    filename="best_model_{epoch:02d}"
)
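
# Monitoring 'val_loss' assumes the LightningModule logs that key via
# self.log('val_loss', ...) during validation; if the key never appears in the
# logged metrics, ModelCheckpoint raises a MisconfigurationException.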

# With devices=1, DDP adds no parallelism; DDPStrategy is presumably kept so
# the same script scales to multi-GPU runs by raising `devices`.
trainer = L.Trainer(
    max_epochs=config.Training.NUM_EPOCHS,
    precision=config.Training.PRECISION,
    devices=1,
    accelerator='gpu',
    strategy=DDPStrategy(find_unused_parameters=False),
    accumulate_grad_batches=config.Training.ACCUMULATE_GRAD_BATCHES,
    default_root_dir=config.Training.SAVE_DIR,
    callbacks=[checkpoint_callback]
)
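
# Hypothetical sketch: no wandb run is started in this script, so unless one is
# initialized inside the model or data code, wandb.finish() at the end is a
# no-op. If W&B logging is intended at the Trainer level, a WandbLogger could
# be attached above (the project name "MDpLM" is an assumption, not taken from
# config):
#
#     from pytorch_lightning.loggers import WandbLogger
#     trainer = L.Trainer(..., logger=WandbLogger(project="MDpLM"))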

print(trainer)
print("Training model...")
sys.stdout.flush()

# Train; validation runs on val_loader at Lightning's default cadence
# (once per epoch).
trainer.fit(latent_diffusion_model, train_loader, val_loader)
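
# The best checkpoint's path is available afterwards via
# checkpoint_callback.best_model_path if downstream code needs to reload it.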

# Close out the W&B run if one was started elsewhere; wandb.finish() is safe
# to call when no run is active.
wandb.finish()