import torch
from torch.utils.data import ConcatDataset, DataLoader
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from dataset.it_dataset import ITImgTrainDataset, ITVidTrainDataset


def get_media_type(dataset_config):
    """Infer the media type of a dataset entry from its trailing element."""
    if len(dataset_config) == 3 and dataset_config[2] == "video":
        return "video"
    elif dataset_config[-1] == "only_video":
        return "only_video"
    else:
        return "image"


def create_dataset(dataset_type, config):
    # Pick the normalization statistics that match the vision encoder.
    if "clip" in config.model.get("vit_model", "vit"):
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)
    else:
        vision_enc_name = config.model.vision_encoder.name
        if "swin" in vision_enc_name or "vit" in vision_enc_name:
            mean = (0.485, 0.456, 0.406)
            std = (0.229, 0.224, 0.225)
        elif "beit" in vision_enc_name:
            mean = (0.5, 0.5, 0.5)  # for all beit models except IN1K finetuning
            std = (0.5, 0.5, 0.5)
        elif "clip" in vision_enc_name:
            mean = (0.48145466, 0.4578275, 0.40821073)
            std = (0.26862954, 0.26130258, 0.27577711)
        else:
            raise ValueError(f"Unknown vision encoder: {vision_enc_name}")
    normalize = transforms.Normalize(mean, std)

    # Loaded images and videos are torch.uint8 tensors,
    # ordered as (T, 1 or 3, H, W) where T=1 for images.
    type_transform = transforms.Lambda(lambda x: x.float().div(255.0))

    if config.inputs.video_input.random_aug:
        # RandAugment runs on the raw uint8 tensors: it precedes
        # type_transform (the float conversion) in the pipeline below.
        aug_transform = transforms.RandAugment()
    else:
        aug_transform = transforms.Lambda(lambda x: x)

    train_transform = transforms.Compose(
        [
            aug_transform,
            transforms.RandomResizedCrop(
                config.inputs.image_res,
                scale=(0.5, 1.0),
                interpolation=InterpolationMode.BICUBIC,
            ),
            transforms.RandomHorizontalFlip(),
            type_transform,
            normalize,
        ]
    )
    test_transform = transforms.Compose(
        [
            transforms.Resize(
                (config.inputs.image_res, config.inputs.image_res),
                interpolation=InterpolationMode.BICUBIC,
            ),
            type_transform,
            normalize,
        ]
    )
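
    # Illustrative shape check (a sketch, not used below): both pipelines take
    # the uint8 (T, C, H, W) tensors described above and return float tensors
    # of shape (T, C, image_res, image_res), e.g.
    #   clip = torch.randint(0, 256, (8, 3, 240, 320), dtype=torch.uint8)
    #   out = train_transform(clip)  # float32, (8, 3, image_res, image_res)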
    video_reader_type = config.inputs.video_input.get("video_reader_type", "decord")
    video_only_dataset_kwargs_train = dict(
        video_reader_type=video_reader_type,
        sample_type=config.inputs.video_input.sample_type,
        num_frames=config.inputs.video_input.num_frames,
        num_tries=3,  # fault tolerance: retry unreadable videos up to 3 times
    )
if dataset_type == "pt_train":
raise ValueError("NOT PRETRAINING YET")
elif dataset_type in ["it_train"]:
# convert to list of lists
train_files = (
[config.train_file] if isinstance(config.train_file[0], str) else config.train_file
)
train_media_types = sorted(list({get_media_type(e) for e in train_files}))
train_datasets = []
for m in train_media_types:
dataset_cls = ITImgTrainDataset if m == "image" else ITVidTrainDataset
# dataset of the same media_type will be mixed in a single Dataset object
_train_files = [e for e in train_files if get_media_type(e) == m]
datasets = []
for train_file in _train_files:
dataset_kwargs = dict(
ann_file=train_file,
transform=train_transform,
mm_alone=config.preprocess.get("mm_alone", True),
add_second_msg=config.preprocess.get("add_second_msg", True),
skip_short_sample=config.preprocess.get("skip_short_sample", False),
clip_transform=config.preprocess.get("clip_transform", False),
random_shuffle=config.preprocess.get("random_shuffle", True),
system=config.preprocess.get("system", ""),
role=config.preprocess.get('roles', ("Human", "Assistant")),
end_signal=config.preprocess.get('end_signal', "###"),
begin_signal=config.preprocess.get('begin_signal', ""),
)
if m == "video":
video_only_dataset_kwargs_train.update({
"start_token": config.model.get("start_token", "<Video>"),
"end_token": config.model.get("end_token", "</Video>"),
})
dataset_kwargs.update(video_only_dataset_kwargs_train)
if "tgif" in train_file[1]:
video_only_dataset_kwargs_train.update({
"video_reader_type": "gif"
})
dataset_kwargs.update(video_only_dataset_kwargs_train)
elif "webvid" in train_file[1]:
video_only_dataset_kwargs_train.update({
"video_reader_type": "hdfs"
})
else:
video_only_dataset_kwargs_train.update({
"video_reader_type": "decord"
})
dataset_kwargs.update(video_only_dataset_kwargs_train)
datasets.append(dataset_cls(**dataset_kwargs))
dataset = ConcatDataset(datasets)
train_datasets.append(dataset)
return train_datasets


def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
    """Build one DataLoader per dataset; all other arguments are per-dataset lists."""
    loaders = []
    for dataset, sampler, bs, n_worker, is_train, collate_fn in zip(
        datasets, samplers, batch_size, num_workers, is_trains, collate_fns
    ):
        if is_train:
            # Shuffle only when no sampler is given; a sampler owns its ordering.
            shuffle = sampler is None
            drop_last = True
        else:
            shuffle = False
            drop_last = False
        loader = DataLoader(
            dataset,
            batch_size=bs,
            num_workers=n_worker,
            pin_memory=False,
            sampler=sampler,
            shuffle=shuffle,
            collate_fn=collate_fn,
            drop_last=drop_last,
            persistent_workers=n_worker > 0,
        )
        loaders.append(loader)
    return loaders
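

# Minimal usage sketch, not part of the original module: it assumes an
# OmegaConf-style config (attribute access plus .get(), matching how `config`
# is used above) and a hypothetical config path; the per-dataset settings
# below are illustrative only.
if __name__ == "__main__":
    from omegaconf import OmegaConf  # assumption: configs are OmegaConf objects

    config = OmegaConf.load("configs/it_train.yaml")  # hypothetical path
    train_datasets = create_dataset("it_train", config)
    n = len(train_datasets)  # one ConcatDataset per media type
    train_loaders = create_loader(
        train_datasets,
        samplers=[None] * n,    # e.g. DistributedSampler under DDP
        batch_size=[8] * n,     # illustrative; read from config in practice
        num_workers=[4] * n,
        is_trains=[True] * n,
        collate_fns=[None] * n, # PyTorch's default collate
    )
    print([len(loader) for loader in train_loaders])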