File size: 3,111 Bytes
8646273
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104

import torch

from data_preparation import augment, collation_fn, my_split_by_node
from model import Onset_picker, Updated_onset_picker

import webdataset as wds

from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
from lightning.pytorch.strategies import DDPStrategy
from lightning import seed_everything
import lightning as pl

# Reproducibility: seed Python, NumPy and torch RNGs. workers=True also seeds
# each DataLoader worker process — necessary here because the train loader
# below runs with num_workers > 0 (workers=False left those RNGs unseeded).
seed_everything(42, workers=True)
# Allow TF32 matmuls on Ampere+ GPUs: large speedup for a small precision loss.
torch.set_float32_matmul_precision('medium')

batch_size = 256        # samples per batch (batching happens in the wds pipeline)
num_workers = 16        # DataLoader worker processes; could be int(os.cpu_count())
n_iters_in_epoch = 5000 # nominal batches per training epoch across all workers

# Training pipeline: decode raw shard samples, augment, shuffle with a
# 5000-sample buffer, then batch. partial=False drops the last short batch.
# with_epoch caps each worker's epoch so the combined workers yield roughly
# n_iters_in_epoch batches per epoch.
_train_shards = "data/sample/shard-00{0000..0001}.tar"
train_dataset = (
    wds.WebDataset(_train_shards,
                   # splitter=my_split_by_worker,
                   nodesplitter=my_split_by_node)
    .decode()
    .map(augment)
    .shuffle(5000)
    .batched(batchsize=batch_size, collation_fn=collation_fn, partial=False)
    .with_epoch(n_iters_in_epoch // num_workers)
)


# Validation pipeline: same decode/augment/batch steps but no shuffling;
# .repeat() makes the stream infinite and with_epoch(100) truncates each
# validation pass to exactly 100 batches.
_val_shards = "data/sample/shard-00{0000..0000}.tar"
val_dataset = (
    wds.WebDataset(_val_shards,
                   # splitter=my_split_by_worker,
                   nodesplitter=my_split_by_node)
    .decode()
    .map(augment)
    .repeat()
    .batched(batchsize=batch_size, collation_fn=collation_fn, partial=False)
    .with_epoch(100)
)


# Batching and shuffling are already done inside the webdataset pipelines,
# so both loaders pass batch_size=None and shuffle=False.
train_loader = wds.WebLoader(
    train_dataset,
    batch_size=None,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True,
)

# Validation samples are loaded in the main process (num_workers=0).
val_loader = wds.WebLoader(
    val_dataset,
    batch_size=None,
    shuffle=False,
    num_workers=0,
    pin_memory=True,
)



# model
# Model under training. torch.compile is available as an opt-in speedup.
model = Onset_picker(
    picker=Updated_onset_picker(),
    learning_rate=3e-4,
)
# model = torch.compile(model, mode="reduce-overhead")

# TensorBoard runs land in tensorboard_logdir/FAST/version_*.
logger = TensorBoardLogger("tensorboard_logdir", name="FAST")

# Keep only the single best checkpoint, ranked by the logged "Loss/val" metric.
checkpoint_callback = ModelCheckpoint(
    monitor="Loss/val",
    save_top_k=1,
    filename="chkp-{epoch:02d}",
)
# Log the learning rate once per epoch.
lr_callback = LearningRateMonitor(logging_interval='epoch')
# swa_callback = StochasticWeightAveraging(swa_lrs=0.05)
    
# # train model
# Lightning Trainer: mixed-precision, multi-device DDP training run.
trainer = pl.Trainer(
            # 16-bit automatic mixed precision (fp32 master weights).
            precision='16-mixed',
            
            callbacks=[checkpoint_callback, lr_callback],
    
            # Auto-select accelerator (GPU/CPU) and use all visible devices.
            devices='auto', 
            accelerator='auto', 
    
            # static_graph/gradient_as_bucket_view are DDP optimizations that
            # assume the autograd graph is identical every step — presumably
            # true for this model; confirm if the forward pass ever branches.
            strategy=DDPStrategy(find_unused_parameters=False,
                                 static_graph=True,
                                 gradient_as_bucket_view=True),
            # Enables cudnn benchmark mode — faster for fixed input shapes.
            benchmark=True,
            
            # Clip gradient norm at 0.5.
            gradient_clip_val=0.5,
            # NOTE(review): in Lightning 2.x resuming is done via
            # trainer.fit(..., ckpt_path=...); ckpt_path is not a Trainer arg.
            # ckpt_path='path/to/saved/checkpoints/chkp.ckpt',

            # fast_dev_run=True,

            logger=logger,
            log_every_n_steps=50,
            enable_progress_bar=True,
    
            max_epochs=300,
        )

# Launch training; validation runs on val_loader per the Trainer schedule.
trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)