File size: 2,849 Bytes
d4ebf73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
import numpy as np

from datasets.base_dataset import BaseDataset


class CityScapesSunRGBD(BaseDataset):
    """RGB-D segmentation dataset driven by the per-sample path dicts in ``self._file_names``.

    Each entry of ``self._file_names`` is a dict with keys ``'rgb'``, ``'depth'``
    and, when supervised, ``'gt'``. ``__getitem__`` returns one of two layouts:

    * non-sliding: ``{'data': [rgb, depth], 'name': item_name}`` plus ``'gt'``
      when ground truth is available;
    * sliding: ``{'gt': gt_or_None, 'sliding_output': [([window], pos, margin), ...]}``.

    Only the ``'train'`` and ``'val'`` modes are supported by ``__getitem__``.
    """

    # Modes for which __getitem__ can build an item.
    _SUPPORTED_MODES = ('train', 'val')

    def __init__(self, dataset_settings, mode, unsupervised, preprocess, sliding=False, stride_rate=None):
        """Create the dataset.

        Args:
            dataset_settings: forwarded unchanged to ``BaseDataset``.
            mode: split name, forwarded to ``BaseDataset``. ``__getitem__`` reads it
                back as ``self._mode``, presumably set by the base class.
            unsupervised: forwarded to ``BaseDataset``. When ``self.unsupervised`` is
                truthy, no ground truth is loaded.
            preprocess: optional callable. Non-sliding: ``preprocess(rgb, depth, gt)``
                must return ``(rgb, depth, gt)``. Sliding: ``preprocess(window, None)``
                must return a 2-tuple whose first element is used.
            sliding: if True, ``__getitem__`` returns sliding-window crops of the RGB image.
            stride_rate: passed to ``slide_over_image`` in sliding mode.
        """
        super(CityScapesSunRGBD, self).__init__(dataset_settings, mode, unsupervised)
        self.preprocess = preprocess
        self.sliding = sliding
        self.stride_rate = stride_rate
        if self.sliding and self._mode == 'train':
            print("Ensure correct preprocessing is being done!")

    @staticmethod
    def _as_float_tensor(array):
        """Convert an array-like to a contiguous float32 torch tensor."""
        return torch.from_numpy(np.ascontiguousarray(array)).float()

    @staticmethod
    def _as_long_tensor(array):
        """Convert an array-like to a contiguous int64 torch tensor (label maps)."""
        return torch.from_numpy(np.ascontiguousarray(array)).long()

    def __getitem__(self, index):
        """Load and return the sample at ``index``.

        Raises:
            ValueError: if ``self._mode`` is not ``'train'`` or ``'val'``.
        """
        # Reject unsupported modes before any file I/O. Previously the sliding path
        # silently returned an empty dict here while the non-sliding path raised.
        if self._mode not in self._SUPPORTED_MODES:
            raise ValueError(f"{self._mode} not supported in CityScapesSunRGBD")

        names = self._file_names[index]
        rgb_path = self._rgb_path + names['rgb']
        # NOTE(review): the depth path is joined onto self._rgb_path, not a
        # depth-specific root; kept as-is -- confirm against BaseDataset.
        depth_path = self._rgb_path + names['depth']
        gt_path = None if self.unsupervised else self._gt_path + names['gt']
        # Base name of the RGB file up to its first '.', e.g. "dir/img.png" -> "img".
        item_name = names['rgb'].split("/")[-1].split(".")[0]

        rgb = self._open_rgb(rgb_path)
        depth = self._open_depth(depth_path)
        gt = None if gt_path is None else self._open_gt(gt_path)

        if self.sliding:
            return self._sliding_item(rgb, gt)
        return self._full_item(rgb, depth, gt, item_name)

    def _full_item(self, rgb, depth, gt, item_name):
        """Build the non-sliding output dict: optional preprocess, then tensor conversion."""
        if self.preprocess is not None:
            rgb, depth, gt = self.preprocess(rgb, depth, gt)

        output_dict = dict(
            data=[self._as_float_tensor(rgb), self._as_float_tensor(depth)],
            name=item_name,
        )
        # The 'gt' key is omitted entirely when there is no ground truth.
        if gt is not None:
            output_dict['gt'] = self._as_long_tensor(gt)
        return output_dict

    def _sliding_item(self, rgb, gt):
        """Build the sliding-window output dict for the RGB image.

        Only the RGB image is windowed; depth is not part of this layout. Unlike the
        non-sliding layout, ``'gt'`` is always present and is None when unsupervised.
        This is left unchanged so existing consumers keep working.
        """
        # NOTE(review): model_input_shape and slide_over_image are not defined in this
        # class; presumably provided by BaseDataset.
        windows = self.slide_over_image(rgb, self.model_input_shape, self.stride_rate)
        output_dict = {'gt': None if gt is None else self._as_long_tensor(gt)}

        sliding_output = []
        for window, pos, margin in windows:
            if self.preprocess is not None:
                # NOTE(review): 2-argument / 2-result preprocess signature here, unlike
                # the 3-argument one in the non-sliding path -- confirm this is intended.
                window, _ = self.preprocess(window, None)
            sliding_output.append(([self._as_float_tensor(window)], pos, margin))
        output_dict['sliding_output'] = sliding_output
        return output_dict