"""Training entry point for the CausalVAE video autoencoder.

Parses ``TrainingArguments`` from the command line, builds the model, the
video dataset and a PyTorch Lightning trainer, then runs ``trainer.fit``.
"""
import sys

# Must run before the ``opensora`` imports below so that the package can be
# resolved when the script is launched from the repository root.
sys.path.append(".")
import torch
import random
import numpy as np
from opensora.models.ae.videobase import (
    CausalVAEModel,
)
from torch.utils.data import DataLoader
from opensora.models.ae.videobase.dataset_videobase import VideoDataset
import argparse
from transformers import HfArgumentParser
from dataclasses import dataclass, field, asdict
from typing import Optional
import torch.distributed as dist
import os
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor


@dataclass
class TrainingArguments:
    """Command-line options for a CausalVAE training run.

    Parsed by ``HfArgumentParser``; every field becomes a ``--<name>`` flag.
    """

    # Run name handed to the Weights & Biases logger.
    exp_name: str = field(default="causalvae")
    # Batch size passed to the training DataLoader.
    batch_size: int = field(default=1)
    # Forwarded unchanged to ``pl.Trainer(precision=...)``.
    precision: str = field(default="bf16")
    # Forwarded to ``pl.Trainer(max_steps=...)``.
    max_steps: int = field(default=100000)
    # A checkpoint is written every ``save_steps`` training steps.
    save_steps: int = field(default=2000)
    # Directory in which checkpoints are written.
    output_dir: str = field(default="results/causalvae")
    # Passed as the first argument to ``VideoDataset``.
    video_path: str = field(default="/remote-home1/dataset/data_split_tt")
    # Passed to ``VideoDataset`` as ``sequence_length``.
    video_num_frames: int = field(default=17)
    # Passed to ``VideoDataset`` as ``sample_rate``.
    sample_rate: int = field(default=1)
    # Passed to ``VideoDataset`` as ``dynamic_sample``.
    dynamic_sample: bool = field(default=False)
    # Passed to ``CausalVAEModel.from_config`` when not resuming.
    model_config: str = field(default="scripts/causalvae/288.yaml")
    # Forwarded to ``pl.Trainer(num_nodes=...)``.
    n_nodes: int = field(default=1)
    # Forwarded to ``pl.Trainer(devices=...)``.
    devices: int = field(default=8)
    # Passed to ``VideoDataset`` as ``resolution``.
    resolution: int = field(default=64)
    # Number of DataLoader worker processes.
    num_workers: int = field(default=8)
    # When set, the model is loaded from this path and the same path is given
    # to ``trainer.fit`` as ``ckpt_path``. ``None`` (or empty) starts fresh.
    resume_from_checkpoint: Optional[str] = field(default=None)


def set_seed(seed=1006):
    """Seed the torch, NumPy and Python ``random`` generators with ``seed``."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)


def load_callbacks_and_logger(args):
    """Build the Lightning callbacks and the experiment logger.

    Args:
        args: Parsed ``TrainingArguments``; ``output_dir``, ``save_steps`` and
            ``exp_name`` are read.

    Returns:
        A ``(callbacks, logger)`` tuple: a list holding the checkpoint callback
        and the learning-rate monitor, and a ``WandbLogger``.
    """
    checkpoint_callback = ModelCheckpoint(
        dirpath=args.output_dir,
        filename="model-{epoch:02d}-{step}",
        every_n_train_steps=args.save_steps,
        save_top_k=-1,  # -1 keeps every checkpoint instead of only the best k
        save_on_train_epoch_end=False,
    )
    lr_monitor = LearningRateMonitor(logging_interval="step")
    logger = WandbLogger(name=args.exp_name, log_model=False)
    return [checkpoint_callback, lr_monitor], logger


def train(args):
    """Run a full CausalVAE training job.

    Seeds the RNGs, builds the model (from a checkpoint when
    ``args.resume_from_checkpoint`` is set, otherwise from
    ``args.model_config``), creates the video DataLoader and the Lightning
    trainer, and calls ``trainer.fit``.

    Args:
        args: Parsed ``TrainingArguments``.
    """
    set_seed()

    # Load Model
    # The model is built exactly once; a single truthiness check decides both
    # how the weights are obtained and whether the trainer state is restored,
    # so an empty string is treated the same as ``None`` in both places.
    # NOTE(review): the same path is given to ``from_pretrained`` and to
    # Lightning's ``ckpt_path`` - confirm both accept the same artifact.
    resume_path = args.resume_from_checkpoint or None
    if resume_path is not None:
        model = CausalVAEModel.from_pretrained(resume_path)
    else:
        model = CausalVAEModel.from_config(args.model_config)

    # Print the architecture only once: either no process group exists or this
    # process is rank 0.
    if not dist.is_initialized() or dist.get_rank() == 0:
        print(model)

    # Load Dataset
    dataset = VideoDataset(
        args.video_path,
        sequence_length=args.video_num_frames,
        resolution=args.resolution,
        sample_rate=args.sample_rate,
        dynamic_sample=args.dynamic_sample,
    )
    train_loader = DataLoader(
        dataset,
        shuffle=True,
        num_workers=args.num_workers,
        batch_size=args.batch_size,
        pin_memory=True,
    )

    # Load Callbacks and Logger
    callbacks, logger = load_callbacks_and_logger(args)

    # Load Trainer
    trainer = pl.Trainer(
        accelerator="cuda",
        devices=args.devices,
        num_nodes=args.n_nodes,
        callbacks=callbacks,
        logger=logger,
        log_every_n_steps=5,
        precision=args.precision,
        max_steps=args.max_steps,
        strategy="ddp_find_unused_parameters_true",
    )

    # Only pass ``ckpt_path`` when resuming so a fresh run uses the default.
    trainer_kwargs = {}
    if resume_path is not None:
        trainer_kwargs["ckpt_path"] = resume_path

    trainer.fit(model, train_loader, **trainer_kwargs)


if __name__ == "__main__":
    parser = HfArgumentParser(TrainingArguments)
    # ``parse_args_into_dataclasses`` returns one instance per dataclass type
    # given to the parser; only ``TrainingArguments`` was registered.
    args = parser.parse_args_into_dataclasses()
    train(args[0])