import torch
from torch.utils.data import DataLoader
from datasets.citysundepth import CityScapesSunDepth
from datasets.citysunrgb import CityScapesSunRGB
from datasets.citysunrgbd import CityScapesSunRGBD
from datasets.preprocessors import DepthTrainPre, DepthValPre, NYURGBDTrainPre, NYURGBDValPre, RGBDTrainPre, RGBDValPre, RGBTrainPre, RGBValPre
from datasets.tfnyu import TFNYU
from utils.constants import Constants as C

def get_dataset(args):
    """Select the dataset class matching args.data and args.modalities."""
    datasetClass = None
    if args.data == "nyudv2":
        return TFNYU
    if args.data in ("city", "sunrgbd", "stanford_indoor"):
        if len(args.modalities) == 1 and args.modalities[0] == 'rgb':
            datasetClass = CityScapesSunRGB
        elif len(args.modalities) == 1 and args.modalities[0] == 'depth':
            datasetClass = CityScapesSunDepth
        elif len(args.modalities) == 2 and args.modalities[0] == 'rgb' and args.modalities[1] == 'depth':
            datasetClass = CityScapesSunRGBD
        else:
            raise Exception(f"{args.modalities} not configured in get_dataset function.")
    else:
        raise Exception(f"{args.data} not configured in get_dataset function.")
    return datasetClass
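
# A minimal usage sketch (hypothetical field values; get_dataset only reads
# args.data and args.modalities, so a bare namespace is enough):
#
#   from argparse import Namespace
#   dataset_cls = get_dataset(Namespace(data="city", modalities=["rgb", "depth"]))
#   # dataset_cls is CityScapesSunRGBD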

def get_preprocessors(args, dataset_settings, mode):
    """Build the preprocessing pipeline for the requested dataset, modalities, and mode."""
    if args.data == "nyudv2" and len(args.modalities) == 2 and args.modalities[0] == 'rgb' and args.modalities[1] == 'depth':
        if mode == 'train':
            return NYURGBDTrainPre(C.pytorch_mean, C.pytorch_std, dataset_settings)
        elif mode == 'val':
            return NYURGBDValPre(C.pytorch_mean, C.pytorch_std, dataset_settings)

    if len(args.modalities) == 1 and args.modalities[0] == 'rgb':
        if mode == 'train':
            return RGBTrainPre(C.pytorch_mean, C.pytorch_std, dataset_settings)
        elif mode == 'val':
            return RGBValPre(C.pytorch_mean, C.pytorch_std, dataset_settings)
        else:
            raise Exception(f"{mode} mode not defined")
    elif len(args.modalities) == 1 and args.modalities[0] == 'depth':
        if mode == 'train':
            return DepthTrainPre(dataset_settings)
        elif mode == 'val':
            return DepthValPre(dataset_settings)
        else:
            raise Exception(f"{mode} mode not defined")
    elif len(args.modalities) == 2 and args.modalities[0] == 'rgb' and args.modalities[1] == 'depth':
        if mode == 'train':
            return RGBDTrainPre(C.pytorch_mean, C.pytorch_std, dataset_settings)
        elif mode == 'val':
            return RGBDValPre(C.pytorch_mean, C.pytorch_std, dataset_settings)
        else:
            raise Exception(f"{mode} mode not defined")
    else:
        raise Exception("%s not configured for preprocessing" % args.modalities)

def get_train_loader(datasetClass, args, train_source, unsupervised = False):
    """Build a DistributedSampler-backed training DataLoader for datasetClass."""
    dataset_settings = {'rgb_root': args.rgb_root,
                        'gt_root': args.gt_root,
                        'depth_root': args.depth_root,
                        'train_source': train_source,
                        'eval_source': args.eval_source,
                        'required_length': args.total_train_imgs, # every loader runs total_train_imgs / batch_size iterations, keeping loaders consistent
                        'train_scale_array': args.train_scale_array,
                        'image_height': args.image_height,
                        'image_width': args.image_width,
                        'modalities': args.modalities}

    preprocessing = get_preprocessors(args, dataset_settings, "train")
    train_dataset = datasetClass(dataset_settings, "train", unsupervised, preprocessing)
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas = args.world_size, rank = args.rank)
    if unsupervised and "unsup_batch_size" in args:
        batch_size = args.unsup_batch_size
    else:
        batch_size = args.batch_size
    train_loader = DataLoader(train_dataset, 
                                batch_size = args.batch_size // args.world_size,
                                num_workers = args.num_workers, 
                                drop_last = True, 
                                shuffle = False,
                                sampler = train_sampler)
    return train_loader
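
# A minimal end-to-end sketch under DDP (the args fields below are assumptions
# about the surrounding training script, not taken from this file):
#
#   datasetClass = get_dataset(args)
#   train_loader = get_train_loader(datasetClass, args, args.train_source)
#   for epoch in range(args.epochs):
#       train_loader.sampler.set_epoch(epoch)  # reshuffle across ranks each epoch
#       for batch in train_loader:
#           ...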
    
def get_val_loader(datasetClass, args):
    """Build one validation DataLoader per evaluation source list."""
    dataset_settings = {'rgb_root': args.rgb_root,
                        'gt_root': args.gt_root,
                        'depth_root': args.depth_root,
                        'train_source': None,
                        'eval_source': args.eval_source, 
                        'required_length': None,
                        'max_samples': None,
                        'train_scale_array': args.train_scale_array,
                        'image_height': args.image_height,
                        'image_width': args.image_width,
                        'modalities': args.modalities}
    if args.data == 'sunrgbd':
        # SUN RGB-D images come in several resolutions; evaluation uses one
        # shape-specific source list per resolution.
        eval_sources = []
        for shape in ['427_561', '441_591', '530_730', '531_681']:
            eval_sources.append(dataset_settings['eval_source'].split('.')[0] + '_' + shape + '.txt')
    else:
        eval_sources = [args.eval_source]

    preprocessing = get_preprocessors(args, dataset_settings, "val")
    if args.sliding_eval:
        collate_fn = _sliding_collate_fn
    else:
        collate_fn = None

    val_loaders = []
    for eval_source in eval_sources:
        dataset_settings['eval_source'] = eval_source
        val_dataset = datasetClass(dataset_settings, "val", False, preprocessing, args.sliding_eval, args.stride_rate)
        if args.rank is not None: #DDP Evaluation
            val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, num_replicas = args.world_size, rank = args.rank)
            batch_size = args.val_batch_size // args.world_size
        else: #DP Evaluation
            val_sampler = None
            batch_size = args.val_batch_size

        val_loader = DataLoader(val_dataset,
                            batch_size = batch_size,
                            num_workers = 4, 
                            drop_last = False, 
                            shuffle = False, 
                            collate_fn = collate_fn, 
                            sampler = val_sampler)
        val_loaders.append(val_loader)
    return val_loaders
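
# A minimal evaluation sketch (hypothetical): for SUN RGB-D this returns one
# loader per image resolution, so results are typically accumulated across all
# of them before computing metrics:
#
#   for val_loader in get_val_loader(datasetClass, args):
#       for batch in val_loader:
#           ...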
    

def _sliding_collate_fn(batch):
    """Collate sliding-window samples, checking that crop positions and margins match across the batch."""
    gt = torch.stack([b['gt'] for b in batch])
    sliding_output = []
    num_modalities = len(batch[0]['sliding_output'][0][0])
    for i in range(len(batch[0]['sliding_output'])): # i iterates over window positions
        imgs = [torch.stack([b['sliding_output'][i][0][m] for b in batch]) for m in range(num_modalities)]
        pos = batch[0]['sliding_output'][i][1]
        pos_compare = [(b['sliding_output'][i][1] == pos).all() for b in batch]
        assert all(pos_compare), f"Position not the same for all samples in the batch: {pos_compare}, {[b['sliding_output'][i][1] for b in batch]}"
        margin = batch[0]['sliding_output'][i][2]
        margin_compare = [(b['sliding_output'][i][2] == margin).all() for b in batch]
        assert all(margin_compare), f"Margin not the same for all samples in the batch: {margin_compare}, {[b['sliding_output'][i][2] for b in batch]}"
        sliding_output.append((imgs, pos, margin))
    return {"gt": gt, "sliding_output": sliding_output}