import torch
from torch.utils.data import ConcatDataset, DataLoader
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from dataset.it_dataset import ITImgTrainDataset, ITVidTrainDataset


def get_media_type(dataset_config):
    """Infer the media type of a dataset config entry laid out as
    (annotation_file, data_root[, media_type])."""
    if len(dataset_config) == 3 and dataset_config[2] == "video":
        return "video"
    elif dataset_config[-1] == "only_video":
        return "only_video"
    else:
        return "image"


def create_dataset(dataset_type, config):
    """Build the training datasets described by config for the given dataset_type."""
    if "clip" in config.model.get("vit_model", "vit"):
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)
    else:
        vision_enc_name = config.model.vision_encoder.name
        if "swin" in vision_enc_name or "vit" in vision_enc_name:
            mean = (0.485, 0.456, 0.406)
            std = (0.229, 0.224, 0.225)
        elif "beit" in vision_enc_name:
            mean = (0.5, 0.5, 0.5)  # all BEiT models except IN1K-finetuned ones
            std = (0.5, 0.5, 0.5)
        elif "clip" in vision_enc_name:
            mean = (0.48145466, 0.4578275, 0.40821073)
            std = (0.26862954, 0.26130258, 0.27577711)
        else:
            raise ValueError(f"unknown vision encoder: {vision_enc_name}")

    normalize = transforms.Normalize(mean, std)

    # loaded images and videos are torch.Tensor of torch.uint8 format,
    # ordered as (T, 1 or 3, H, W) where T=1 for image
    type_transform = transforms.Lambda(lambda x: x.float().div(255.0))

    if config.inputs.video_input.random_aug:
        aug_transform = transforms.RandAugment()
    else:
        aug_transform = transforms.Lambda(lambda x: x)

    train_transform = transforms.Compose(
        [
            aug_transform,
            transforms.RandomResizedCrop(
                config.inputs.image_res,
                scale=(0.5, 1.0),
                interpolation=InterpolationMode.BICUBIC,
            ),
            transforms.RandomHorizontalFlip(),
            type_transform,
            normalize,
        ]
    )
    test_transform = transforms.Compose(
        [
            transforms.Resize(
                (config.inputs.image_res, config.inputs.image_res),
                interpolation=InterpolationMode.BICUBIC,
            ),
            type_transform,
            normalize,
        ]
    )

    video_reader_type = config.inputs.video_input.get("video_reader_type", "decord")
    video_only_dataset_kwargs_train = dict(
        video_reader_type=video_reader_type,
        sample_type=config.inputs.video_input.sample_type,
        num_frames=config.inputs.video_input.num_frames,
        num_tries=3,  # fault tolerance: retry a few times on unreadable videos
    )
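    # e.g. (values illustrative, taken from config.inputs.video_input) this
    # typically resolves to something like:
    #   dict(video_reader_type="decord", sample_type="rand", num_frames=4, num_tries=3)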

    if dataset_type == "pt_train":
        raise ValueError("pre-training datasets are not supported yet")
    elif dataset_type == "it_train":
        # normalize train_file to a list of (annotation, data_root[, media_type]) entries
        train_files = (
            [config.train_file] if isinstance(config.train_file[0], str) else config.train_file
        )
        train_media_types = sorted({get_media_type(e) for e in train_files})

        train_datasets = []
        for m in train_media_types:
            dataset_cls = ITImgTrainDataset if m == "image" else ITVidTrainDataset
            # datasets sharing a media type are concatenated into one ConcatDataset below
            _train_files = [e for e in train_files if get_media_type(e) == m]

            datasets = []
            for train_file in _train_files:
                dataset_kwargs = dict(
                    ann_file=train_file,
                    transform=train_transform,
                    mm_alone=config.preprocess.get("mm_alone", True),
                    add_second_msg=config.preprocess.get("add_second_msg", True),
                    skip_short_sample=config.preprocess.get("skip_short_sample", False),
                    clip_transform=config.preprocess.get("clip_transform", False),
                    random_shuffle=config.preprocess.get("random_shuffle", True),
                    system=config.preprocess.get("system", ""),
                    role=config.preprocess.get("roles", ("Human", "Assistant")),
                    end_signal=config.preprocess.get("end_signal", "###"),
                    begin_signal=config.preprocess.get("begin_signal", ""),
                )
                if m == "video":
                    video_only_dataset_kwargs_train.update({
                        "start_token": config.model.get("start_token", "<Video>"),
                        "end_token": config.model.get("end_token", "</Video>"),
                    })
                    dataset_kwargs.update(video_only_dataset_kwargs_train)
                    if "tgif" in train_file[1]:
                        video_only_dataset_kwargs_train.update({
                            "video_reader_type": "gif"
                        })
                        dataset_kwargs.update(video_only_dataset_kwargs_train)
                    elif "webvid" in train_file[1]:
                        video_only_dataset_kwargs_train.update({
                            "video_reader_type": "hdfs"
                        })
                    else:
                        video_only_dataset_kwargs_train.update({
                            "video_reader_type": "decord"
                        })
                    dataset_kwargs.update(video_only_dataset_kwargs_train)
                datasets.append(dataset_cls(**dataset_kwargs))
            dataset = ConcatDataset(datasets)
            train_datasets.append(dataset)
        return train_datasets
    else:
        raise ValueError(f"dataset_type {dataset_type} is not supported")
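

def _example_type_transform():
    """Minimal sketch (shapes illustrative): loaded frames arrive as uint8
    tensors ordered (T, C, H, W) with T=1 for images, and are rescaled to
    [0, 1] floats before normalization, as create_dataset's type_transform does."""
    frames = torch.randint(0, 256, (1, 3, 224, 224), dtype=torch.uint8)
    type_transform = transforms.Lambda(lambda x: x.float().div(255.0))
    out = type_transform(frames)
    assert out.dtype == torch.float32
    assert out.min() >= 0.0 and out.max() <= 1.0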


def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
    """Build one DataLoader per dataset; every other argument is a parallel list."""
    loaders = []
    for dataset, sampler, bs, n_worker, is_train, collate_fn in zip(
        datasets, samplers, batch_size, num_workers, is_trains, collate_fns
    ):
        if is_train:
            shuffle = sampler is None
            drop_last = True
        else:
            shuffle = False
            drop_last = False
        loader = DataLoader(
            dataset,
            batch_size=bs,
            num_workers=n_worker,
            pin_memory=False,
            sampler=sampler,
            shuffle=shuffle,
            collate_fn=collate_fn,
            drop_last=drop_last,
            persistent_workers=n_worker > 0,  # keep workers alive across epochs
        )
        loaders.append(loader)
    return loaders
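

def _example_build_loaders(config):
    """Wiring sketch (hypothetical config object and values; not part of this
    module): build one training loader per media type from create_dataset."""
    train_datasets = create_dataset("it_train", config)
    n = len(train_datasets)
    return create_loader(
        train_datasets,
        samplers=[None] * n,  # or one DistributedSampler per dataset
        batch_size=[config.batch_size] * n,  # config.batch_size is assumed here
        num_workers=[4] * n,
        is_trains=[True] * n,
        collate_fns=[None] * n,
    )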