import sys

sys.path.append(".")  # make the local `opensora` package importable when running from the repo root
import random
from dataclasses import dataclass, field
from typing import Optional

import numpy as np
import pytorch_lightning as pl
import torch
import torch.distributed as dist
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader
from transformers import HfArgumentParser

from opensora.models.ae.videobase import CausalVAEModel
from opensora.models.ae.videobase.dataset_videobase import VideoDataset

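# HfArgumentParser turns every field of the dataclass below into a CLI flag,
# e.g. `--batch_size 4 --video_num_frames 17 --output_dir results/causalvae`.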
@dataclass
class TrainingArguments:
    exp_name: str = field(default="causalvae")
    batch_size: int = field(default=1)
    precision: str = field(default="bf16")
    max_steps: int = field(default=100000)
    save_steps: int = field(default=2000)
    output_dir: str = field(default="results/causalvae")
    video_path: str = field(default="/remote-home1/dataset/data_split_tt")
    video_num_frames: int = field(default=17)  # frames sampled per clip
    sample_rate: int = field(default=1)  # temporal stride between sampled frames
    dynamic_sample: bool = field(default=False)  # randomize the sample rate per clip
    model_config: str = field(default="scripts/causalvae/288.yaml")
    n_nodes: int = field(default=1)
    devices: int = field(default=8)
    resolution: int = field(default=64)  # spatial resolution of the training clips
    num_workers: int = field(default=8)
    resume_from_checkpoint: Optional[str] = field(default=None)
    
def set_seed(seed=1006):
    # Seed all RNGs used here for reproducibility (CUDA kernels may still be nondeterministic).
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

def load_callbacks_and_logger(args):
    # Checkpoint every `save_steps` optimizer steps; save_top_k=-1 keeps every checkpoint.
    checkpoint_callback = ModelCheckpoint(
        dirpath=args.output_dir,
        filename="model-{epoch:02d}-{step}",
        every_n_train_steps=args.save_steps,
        save_top_k=-1,
        save_on_train_epoch_end=False,
    )
    lr_monitor = LearningRateMonitor(logging_interval="step")
    logger = WandbLogger(name=args.exp_name, log_model=False)
    return [checkpoint_callback, lr_monitor], logger

def train(args):
    set_seed()
    # Load model: resume weights from a pretrained checkpoint if given,
    # otherwise build a fresh model from the YAML config.
    if args.resume_from_checkpoint is not None:
        model = CausalVAEModel.from_pretrained(args.resume_from_checkpoint)
    else:
        model = CausalVAEModel.from_config(args.model_config)

    # Print the architecture once (rank 0 only when torch.distributed is initialized).
    if not dist.is_initialized() or dist.get_rank() == 0:
        print(model)
    # Load Dataset
    dataset = VideoDataset(
        args.video_path,
        sequence_length=args.video_num_frames,
        resolution=args.resolution,
        sample_rate=args.sample_rate,
        dynamic_sample=args.dynamic_sample,
    )
    train_loader = DataLoader(
        dataset,
        shuffle=True,
        num_workers=args.num_workers,
        batch_size=args.batch_size,
        pin_memory=True,
    )
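    # Note: under DDP, Lightning wraps this loader's sampler in a DistributedSampler
    # (use_distributed_sampler defaults to True in PL 2.x), so shuffle=True yields
    # per-rank shuffling rather than duplicated batches across ranks.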
    # Load Callbacks and Logger
    callbacks, logger = load_callbacks_and_logger(args)
    # Load Trainer
    trainer = pl.Trainer(
        accelerator="cuda",
        devices=args.devices,
        num_nodes=args.n_nodes,
        callbacks=callbacks,
        logger=logger,
        log_every_n_steps=5,
        precision=args.precision,
        max_steps=args.max_steps,
        strategy="ddp_find_unused_parameters_true",
    )
    # When resuming, also hand the checkpoint to Lightning so optimizer state,
    # LR schedulers, and the global step are restored, not just the weights.
    trainer_kwargs = {}
    if args.resume_from_checkpoint:
        trainer_kwargs["ckpt_path"] = args.resume_from_checkpoint

    trainer.fit(model, train_loader, **trainer_kwargs)

if __name__ == "__main__":
    parser = HfArgumentParser(TrainingArguments)
    (args,) = parser.parse_args_into_dataclasses()
    train(args)
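
# Example launch (illustrative only; the script filename and paths below are
# assumptions, adjust them to your setup):
#   python train_causalvae.py \
#       --exp_name causalvae \
#       --video_path /path/to/video/clips \
#       --model_config scripts/causalvae/288.yaml \
#       --devices 8 --n_nodes 1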