sgoel30 committed on
Commit
5f9a93d
·
verified ·
1 Parent(s): 9875b25

Upload train.py

Browse files
Files changed (1) hide show
  1. scripts/train.py +35 -11
scripts/train.py CHANGED
@@ -1,24 +1,48 @@
1
  import pytorch_lightning as L
2
  from pytorch_lightning.strategies import DDPStrategy
3
- from configs.config import Config
4
- from utils.data_loader import get_dataloaders
5
- from models.diffusion import Diffusion
 
 
 
6
 
7
  # Get dataloaders
8
- train_loader, val_loader, _ = get_dataloaders(Config)
9
 
10
- # Initialize model
11
- latent_diffusion_model = Diffusion(Config, latent_dim=Config.latent_dim)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  # Initialize trainer
14
  trainer = L.Trainer(
15
- max_epochs=Config.training["epochs"],
16
- gpus=Config.training["gpus"],
17
- precision=Config.training["precision"],
 
18
  strategy=DDPStrategy(find_unused_parameters=False),
19
- accumulate_grad_batches=Config.training["accumulate_grad_batches"],
20
- default_root_dir=Config.training["save_dir"]
 
21
  )
22
 
 
 
 
 
23
  # Train the model
24
  trainer.fit(latent_diffusion_model, train_loader, val_loader)
 
 
1
  import pytorch_lightning as L
2
  from pytorch_lightning.strategies import DDPStrategy
3
+ from pytorch_lightning.callbacks import ModelCheckpoint
4
+ import config
5
+ from data_loader import get_dataloaders
6
+ from esm_utils import load_esm2_model
7
+ from diffusion import Diffusion
8
+ import sys
9
 
10
  # Get dataloaders
11
+ train_loader, val_loader, _ = get_dataloaders(config)
12
 
13
+ # Initialize ESM tokenizer and model
14
+ tokenizer, model = load_esm2_model(config.MODEL_NAME)
15
+
16
+ # Initialize diffusion model
17
+ latent_diffusion_model = Diffusion(config, latent_dim=config.LATENT_DIM, tokenizer=tokenizer)
18
+ print(latent_diffusion_model)
19
+ sys.stdout.flush()
20
+
21
+ # Define checkpoints to save best model by minimum validation loss
22
+ checkpoint_callback = ModelCheckpoint(
23
+ monitor='val_loss',
24
+ save_top_k=1,
25
+ mode='min',
26
+ dirpath="/workspace/a03-sgoel/MDpLM/",
27
+ filename="best_model_epoch{epoch:02d}"
28
+ )
29
 
30
  # Initialize trainer
31
  trainer = L.Trainer(
32
+ max_epochs=config.Training.NUM_EPOCHS,
33
+ precision=config.Training.PRECISION,
34
+ devices=1,
35
+ accelerator='gpu',
36
  strategy=DDPStrategy(find_unused_parameters=False),
37
+ accumulate_grad_batches=config.Training.ACCUMULATE_GRAD_BATCHES,
38
+ default_root_dir=config.Training.SAVE_DIR,
39
+ callbacks=[checkpoint_callback]
40
  )
41
 
42
+ print(trainer)
43
+ print("Training model...")
44
+ sys.stdout.flush()
45
+
46
  # Train the model
47
  trainer.fit(latent_diffusion_model, train_loader, val_loader)
48
+