File size: 1,033 Bytes
a6df73d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import sys

import torch
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.strategies import DDPStrategy

# Keyword arguments intended for the PyTorch Lightning Trainer.
trainer = {
    # Train on GPU; -1 asks Lightning to use every visible device.
    "accelerator": "gpu",
    "devices": -1,
    "gradient_clip_val": 0.5,
    "log_every_n_steps": 10,
    # Validation is step-based: check every 5000 steps, with the
    # per-epoch schedule switched off.
    "val_check_interval": 5000,
    "check_val_every_n_epoch": None,
    "max_steps": 300000,
    # Caution: fs2 training can produce NaNs at 16-bit precision;
    # switch to bf16 or fp32 if that happens.
    "precision": 16,
    "callbacks": [
        # Checkpoint every 5000 steps and keep all of them (save_top_k=-1).
        # NOTE(review): the filename template expects a logged "valid_loss"
        # metric — confirm the LightningModule logs it under that name.
        ModelCheckpoint(
            every_n_train_steps=5000,
            save_top_k=-1,
            filename="{epoch}-{step}-{valid_loss:.2f}",
        ),
        LearningRateMonitor(logging_interval="step"),
    ],
}

# With more than one CUDA device, train with an explicit DDP strategy;
# otherwise the trainer config is left without a "strategy" entry.
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    # NCCL is not supported on Windows, so gloo is used there instead.
    if sys.platform == "win32":
        process_group_backend = "gloo"
    else:
        process_group_backend = "nccl"

    trainer["strategy"] = DDPStrategy(
        find_unused_parameters=True,
        process_group_backend=process_group_backend,
    )