File size: 3,144 Bytes
b84549f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch
import os
import numpy as np
from .ab_dataset import ABDataset


class _SplitDataset(torch.utils.data.Dataset):
    """View over ``underlying_dataset`` restricted to the indices in ``keys``.

    Position ``i`` of this view maps to ``underlying_dataset[keys[i]]``; no
    samples are copied. Instances are produced by ``split_dataset``.
    """

    def __init__(self, underlying_dataset, keys):
        super().__init__()
        # Keep references to the wrapped dataset and the index mapping as given.
        self.underlying_dataset = underlying_dataset
        self.keys = keys

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, key):
        source_index = self.keys[key]
        return self.underlying_dataset[source_index]


def _load_cached_split(cache_p, n, num_samples):
    """Return the cached ``(keys_1, keys_2)`` stored at *cache_p*, or ``None`` if unusable."""
    if not os.path.exists(cache_p):
        return None
    try:
        keys_1, keys_2 = torch.load(cache_p)
        # Reject a file that does not describe a split of the expected shape
        # (e.g. truncated by an interrupted write or produced by something else).
        if len(keys_1) != n or len(keys_1) + len(keys_2) != num_samples:
            return None
    except Exception:  # best-effort cache: any unreadable file is simply recomputed
        return None
    return keys_1, keys_2


def _save_cached_split(cache_p, split):
    """Atomically write *split* to *cache_p*; failures are ignored because the cache is optional."""
    tmp_p = f'{cache_p}.{os.getpid()}.tmp'
    try:
        torch.save(split, tmp_p)
        # Rename into place so concurrent readers never observe a half-written file.
        os.replace(tmp_p, cache_p)
    except (OSError, RuntimeError):
        # e.g. read-only or missing home directory: carry on without caching.
        try:
            os.remove(tmp_p)
        except OSError:
            pass


def split_dataset(dataset, n, seed=0, transform=None):
    """
    Return a pair of datasets corresponding to a random split of the given
    dataset, with n datapoints in the first dataset and the rest in the last,
    using the given random seed.

    ``ABDataset`` instances whose ``task_type`` is ``'Object Detection'`` or
    ``'MM Object Detection'`` are delegated to ``split_dataset_det`` /
    ``split_dataset_det_mm``. Those return ``(subset, None)`` instead of two
    datasets, and ``transform`` is only forwarded to the MM variant.

    For every other dataset the shuffled index split is cached in the user's
    home directory, keyed by ``(n, seed, len(dataset))``. The split depends only
    on those three values, so the cache is purely an optimisation. An
    unreadable cache file is recomputed, and an unwritable location is skipped.

    Args:
        dataset: indexable dataset supporting ``len()``.
        n: size of the first split; must satisfy ``0 <= n <= len(dataset)``.
        seed: seed for the ``np.random.RandomState`` that shuffles the indices.
        transform: only used by the MM object detection path.

    Returns:
        ``(first, rest)``, two ``_SplitDataset`` views of sizes ``n`` and
        ``len(dataset) - n``.

    Raises:
        AssertionError: if ``n`` is outside ``[0, len(dataset)]``.
    """
    if isinstance(dataset, ABDataset):
        if dataset.task_type == 'Object Detection':
            return split_dataset_det(dataset, n, seed)
        if dataset.task_type == 'MM Object Detection':
            return split_dataset_det_mm(dataset, n, seed, transform=transform)

    num_samples = len(dataset)
    # A negative n would otherwise slice keys[:n] from the end and produce a bogus split.
    assert 0 <= n <= num_samples, f'{n}_{num_samples}'

    cache_p = os.path.join(
        os.path.expanduser('~'),
        f'.domain_benchmark_split_dataset_cache_{n}_{seed}_{num_samples}')

    split = _load_cached_split(cache_p, n, num_samples)
    if split is None:
        keys = list(range(num_samples))
        np.random.RandomState(seed).shuffle(keys)
        split = (keys[:n], keys[n:])
        _save_cached_split(cache_p, split)

    keys_1, keys_2 = split
    return _SplitDataset(dataset, keys_1), _SplitDataset(dataset, keys_2)


def train_val_split(dataset, split):
    """Return the 'train' (first 80%, rounded down) or 'val' (remaining) part of *dataset*.

    Delegates to ``split_dataset`` with its default seed, so the same dataset
    length always yields the same partition.
    """
    assert split in ['train', 'val']
    num_train = int(len(dataset) * 0.8)
    part_index = 0 if split == 'train' else 1
    return split_dataset(dataset, num_train)[part_index]

    
def train_val_test_split(dataset, split):
    """Return the requested part of a nested 80/20 split of *dataset*.

    The dataset is first split 80% / 20% into (train+val) / test. The first
    part is then split again 80% / 20% into train / val. All three parts are
    computed on every call, and only the one named by *split* is returned.
    """
    assert split in ['train', 'val', 'test']

    train_val_set, test_set = split_dataset(dataset, int(len(dataset) * 0.8))
    train_set, val_set = split_dataset(train_val_set, int(len(train_val_set) * 0.8))

    parts = dict(train=train_set, val=val_set, test=test_set)
    return parts[split]


def split_dataset_det(dataset: ABDataset, n, seed=0):
    """Build a YOLOX COCO dataset from a ratio-based subset of *dataset*'s annotations.

    The annotation JSON for the dataset's split is passed to ``coco_split``
    with ``ratio=n / len(dataset)``, and only the first path it returns is used.
    ``seed`` is accepted for signature compatibility with ``split_dataset`` but
    is not used here.

    Returns:
        ``(subset, None)``. No complementary split is produced.
    """
    ann_path = dataset.ann_json_file_path_for_split
    from .object_detection.yolox_data_util.api import coco_split, get_default_yolox_coco_dataset

    keep_ratio = n / len(dataset)
    subset_ann_path = coco_split(ann_path, ratio=keep_ratio)[0]

    subset = get_default_yolox_coco_dataset(
        dataset.root_dir,
        subset_ann_path,
        train=(dataset.split == 'train'),
    )
    return subset, None

def split_dataset_det_mm(dataset: ABDataset, n, seed=0, transform=None):
    """Build a captioned YOLOX COCO dataset from a ratio-based subset of annotations.

    Same as ``split_dataset_det``, except that the subset is loaded through
    ``get_yolox_coco_dataset_with_caption``, which receives ``transform`` and
    ``dataset.classes``. ``seed`` is accepted for signature compatibility but
    is not used here.

    Returns:
        ``(subset, None)``. No complementary split is produced.
    """
    ann_path = dataset.ann_json_file_path_for_split
    from .object_detection.yolox_data_util.api import coco_split, get_yolox_coco_dataset_with_caption

    keep_ratio = n / len(dataset)
    subset_ann_path = coco_split(ann_path, ratio=keep_ratio)[0]

    subset = get_yolox_coco_dataset_with_caption(
        dataset.root_dir,
        subset_ann_path,
        transform=transform,
        train=(dataset.split == 'train'),
        classes=dataset.classes,
    )
    return subset, None