Spaces:
No application file
No application file
File size: 1,033 Bytes
a6df73d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
import sys
import torch
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.strategies import DDPStrategy
# Keyword arguments for the Lightning ``Trainer``; presumably unpacked as
# ``Trainer(**trainer)`` by the caller (not visible here — confirm).
trainer = {
    "accelerator": "gpu",
    # -1 asks Lightning to use every visible device of the chosen accelerator.
    "devices": -1,
    "gradient_clip_val": 0.5,
    "log_every_n_steps": 10,
    # With check_val_every_n_epoch=None, val_check_interval counts training
    # batches across epoch boundaries instead of being reset every epoch.
    "val_check_interval": 5000,
    "check_val_every_n_epoch": None,
    "max_steps": 300000,
    # Warning: If you are training the model with fs2 (and see nan), you should either use bf16 or fp32
    "precision": 16,
    "callbacks": [
        # save_top_k=-1 keeps every checkpoint written every 5000 train steps.
        # NOTE(review): step-triggered saves may run before the validation
        # pass at the same step, so ``valid_loss`` in the filename could
        # reflect the previous validation run — verify against the
        # installed Lightning version.
        ModelCheckpoint(
            filename="{epoch}-{step}-{valid_loss:.2f}",
            every_n_train_steps=5000,
            save_top_k=-1,
        ),
        # Log the learning rate at every optimizer step, not once per epoch.
        LearningRateMonitor(logging_interval="step"),
    ],
}
# Multi-GPU runs get an explicit DDP strategy; otherwise the Trainer keeps
# whatever strategy Lightning picks by default.
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    # PyTorch's distributed package supports only Gloo on Windows, so NCCL
    # is used everywhere else.
    if sys.platform == "win32":
        process_group_backend = "gloo"
    else:
        process_group_backend = "nccl"
    trainer.update(
        strategy=DDPStrategy(
            process_group_backend=process_group_backend,
            # Tolerate parameters that receive no gradient in a step
            # (at some synchronization cost).
            find_unused_parameters=True,
        )
    )
|