Spaces:
Build error
Build error
File size: 2,415 Bytes
d7a991a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
from typing import Dict, Optional
import torch
import numpy as np
import pytorch_lightning as pl
from yacs.config import CfgNode
from ..configs import to_lower
from .dataset import Dataset
class HAMERDataModule(pl.LightningDataModule):
    """LightningDataModule that builds the datasets and dataloaders for HAMER training."""

    def __init__(self, cfg: CfgNode, dataset_cfg: CfgNode) -> None:
        """
        Initialize LightningDataModule for HAMER training
        Args:
            cfg (CfgNode): Config file as a yacs CfgNode containing necessary dataset info.
            dataset_cfg (CfgNode): Dataset configuration file
        """
        super().__init__()
        self.cfg = cfg
        self.dataset_cfg = dataset_cfg
        # Datasets are created lazily in `setup`; `test_dataset` is never
        # assigned anywhere else in this class.
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None
        self.mocap_dataset = None

    def setup(self, stage: Optional[str] = None) -> None:
        """
        Load datasets necessary for training
        Args:
            stage (Optional[str]): Stage name passed in by the caller. It is
                accepted for interface compatibility but not used here; the
                same datasets are built regardless of its value.
        """
        # Build the datasets only once so repeated `setup` calls do not
        # recreate them. `is None` is the identity check; `== None` would
        # dispatch to the dataset object's `__eq__`.
        if self.train_dataset is None:
            # NOTE(review): `MixedWebDataset` and `MoCapDataset` are not among
            # the imports visible at the top of this file -- confirm they are
            # brought into scope, otherwise this raises NameError.
            self.train_dataset = (
                MixedWebDataset(self.cfg, self.dataset_cfg, train=True)
                .with_epoch(100_000)
                .shuffle(4000)
            )
            self.val_dataset = MixedWebDataset(
                self.cfg, self.dataset_cfg, train=False
            ).shuffle(4000)
            # The mocap dataset is configured from the `dataset_cfg` entry
            # named by `cfg.DATASETS.MOCAP`, passed through `to_lower` and
            # expanded as keyword arguments.
            self.mocap_dataset = MoCapDataset(
                **to_lower(self.dataset_cfg[self.cfg.DATASETS.MOCAP])
            )

    def train_dataloader(self) -> Dict:
        """
        Setup training data loader.
        Returns:
            Dict: Dictionary containing image and mocap data dataloaders
                under the keys 'img' and 'mocap'.
        """
        num_workers = self.cfg.GENERAL.NUM_WORKERS
        loader_kwargs = {}
        # DataLoader raises ValueError when `prefetch_factor` is given while
        # loading in the main process (num_workers == 0), so only forward it
        # when worker processes are actually used.
        if num_workers > 0:
            loader_kwargs['prefetch_factor'] = self.cfg.GENERAL.PREFETCH_FACTOR
        train_dataloader = torch.utils.data.DataLoader(
            self.train_dataset,
            self.cfg.TRAIN.BATCH_SIZE,
            drop_last=True,
            num_workers=num_workers,
            **loader_kwargs,
        )
        # The mocap loader's batch size is NUM_TRAIN_SAMPLES * BATCH_SIZE.
        mocap_dataloader = torch.utils.data.DataLoader(
            self.mocap_dataset,
            self.cfg.TRAIN.NUM_TRAIN_SAMPLES * self.cfg.TRAIN.BATCH_SIZE,
            shuffle=True,
            drop_last=True,
            num_workers=1,
        )
        return {'img': train_dataloader, 'mocap': mocap_dataloader}

    def val_dataloader(self) -> torch.utils.data.DataLoader:
        """
        Setup val data loader.
        Returns:
            torch.utils.data.DataLoader: Validation dataloader
        """
        val_dataloader = torch.utils.data.DataLoader(
            self.val_dataset,
            self.cfg.TRAIN.BATCH_SIZE,
            drop_last=True,
            num_workers=self.cfg.GENERAL.NUM_WORKERS,
        )
        return val_dataloader
|