import torch
import torch.utils.data as data
import numpy as np
import cv2
from PIL import Image
from utils.img_utils import pad_image_to_shape

class BaseDataset(data.Dataset):
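    """Base dataset for multi-modal (e.g. RGB / depth) samples with optional ground-truth labels.

    Reads tab-separated file lists, resolves per-modality file names, and provides
    shared image-loading and sliding-window helpers. Subclasses are expected to
    implement __getitem__ (not defined here).
    """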

    def __init__(self, dataset_settings, mode, unsupervised):
        self._mode = mode
        self.unsupervised = unsupervised
        self._rgb_path = dataset_settings['rgb_root']
        self._depth_path = dataset_settings['depth_root']
        self._gt_path = dataset_settings['gt_root']
        self._train_source = dataset_settings['train_source']
        self._eval_source = dataset_settings['eval_source']
        self.modalities = dataset_settings['modalities']
        # self._file_length = dataset_settings['max_samples'] 
        self._required_length = dataset_settings['required_length'] 
        self._file_names = self._get_file_names(mode)
        self.model_input_shape = (dataset_settings['image_height'], dataset_settings['image_width'])
        
    def __len__(self):
        if self._required_length is not None:
            return self._required_length
        return len(self._file_names) # when mode == "val" (no required length set)

    def _get_file_names(self, mode):
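        """Read the train or eval file list and return a list of per-sample name dicts."""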
        assert mode in ['train', 'val']
        source = self._train_source
        if mode == "val":
            source = self._eval_source

        file_names = []
        with open(source) as f:
            files = f.readlines()

        for item in files:
            names = self._process_item_names(item)
            file_names.append(names)
        
        if mode == "val":
            return file_names
        elif self._required_length <= len(file_names):
            return file_names[:self._required_length]
        else:
            return self._construct_new_file_names(file_names, self._required_length)

    def _construct_new_file_names(self, file_names, length):
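        """Repeat the file list and top it up with randomly chosen entries so it has exactly `length` items."""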
        assert isinstance(length, int)
        files_len = len(file_names)

        new_file_names = file_names * (length // files_len) # length % files_len items still needed

        rand_indices = torch.randperm(files_len).tolist()
        new_indices = rand_indices[:length % files_len]

        new_file_names += [file_names[i] for i in new_indices]

        return new_file_names

    def _process_item_names(self, item):
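        """Split one tab-separated line into a {modality: path} dict with an extra 'gt' entry."""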
        item = item.strip()
        item = item.split('\t')
        num_modalities = len(self.modalities)
        num_items = len(item)
        names = {}
        if not self.unsupervised:
            assert num_modalities + 1 == num_items, f"Number of modalities and number of items in the file list line don't match: len(modalities) = {num_modalities}, len(item) = {num_items}, first entry = {item[0]}"
            for i, modality in enumerate(self.modalities):
                names[modality] = item[i]
            names['gt'] = item[-1]
        else:
            assert num_modalities == num_items, f"Number of modalities and number of items in file name don't match, len(modalities) = {num_modalities} and len(item) = {num_items}"
            for i, modality in enumerate(self.modalities):
                names[modality] = item[i]
            names['gt'] = None

        return names

    def _open_rgb(self, rgb_path, dtype = None):
        bgr = cv2.imread(rgb_path, cv2.IMREAD_COLOR) #cv2 reads in BGR format, HxWxC
        rgb = np.array(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB), dtype=dtype) #Pretrained PyTorch model accepts image in RGB
        return rgb

    def _open_depth(self, depth_path, dtype = None): # returns HxWx3; a grayscale input is replicated across all 3 channels
        img_arr = np.array(Image.open(depth_path))
        if len(img_arr.shape) == 2:  # grayscale
            img_arr = np.array(np.tile(img_arr, [3, 1, 1]).transpose(1, 2, 0), dtype = dtype)
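        # min-max scale to the [0, 255] range (result is float; assumes max > min)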
        img_arr = (img_arr - img_arr.min()) * 255.0 / (img_arr.max() - img_arr.min())
        return img_arr
    
    def _open_depth_tf_nyu(self, depth_path, dtype = None): # returns HxWx3; a grayscale input is replicated across all 3 channels, without rescaling
        img_arr = np.array(Image.open(depth_path))
        if len(img_arr.shape) == 2:  # grayscale
            img_arr = np.tile(img_arr, [3, 1, 1]).transpose(1, 2, 0)
        return img_arr

    def _open_gt(self, gt_path, dtype = None):
        return np.array(cv2.imread(gt_path,  cv2.IMREAD_GRAYSCALE), dtype=dtype)

    def slide_over_image(self, img, crop_size, stride_rate):
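        """Slide a crop_size window over img with the given stride_rate.

        Returns a list of (padded_crop, np.array([s_y, e_y, s_x, e_x]), margin) tuples,
        one per window position.
        """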
        H, W, C = img.shape
        long_size = H if H > W else W
        output = []
        if long_size <= min(crop_size[0], crop_size[1]):
            raise Exception("Crop size is at least as large as the image itself; this case is not handled yet")
        
        else:
            stride_0 = int(np.ceil(crop_size[0] * stride_rate))
            stride_1 = int(np.ceil(crop_size[1] * stride_rate))
            r_grid = int(np.ceil((H - crop_size[0]) / stride_0)) + 1
            c_grid = int(np.ceil((W - crop_size[1]) / stride_1)) + 1

            for grid_yidx in range(r_grid):
                for grid_xidx in range(c_grid):
                    s_x = grid_xidx * stride_1
                    s_y = grid_yidx * stride_0
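                    # clamp the window to the image border, then shift its start back so every crop keeps the full crop_size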
                    e_x = min(s_x + crop_size[1], W)
                    e_y = min(s_y + crop_size[0], H)
                    s_x = e_x - crop_size[1]
                    s_y = e_y - crop_size[0]
                    img_sub = img[s_y:e_y, s_x: e_x, :]
                    img_sub, margin = pad_image_to_shape(img_sub, crop_size, cv2.BORDER_CONSTANT, value=0)
                    output.append((img_sub, np.array([s_y, e_y, s_x, e_x]), margin))
        
        return output
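
# Minimal usage sketch (illustrative only): the settings keys below mirror what
# __init__ reads, but the concrete values and file layout are assumptions, and a
# real subclass would still need to implement __getitem__.
#
# dataset_settings = {
#     'rgb_root': 'data/rgb',
#     'depth_root': 'data/depth',
#     'gt_root': 'data/labels',
#     'train_source': 'data/train.txt',  # tab-separated per line: <rgb>\t<depth>\t<gt>
#     'eval_source': 'data/val.txt',
#     'modalities': ['rgb', 'depth'],
#     'required_length': 1000,
#     'image_height': 480,
#     'image_width': 640,
# }
# train_set = BaseDataset(dataset_settings, mode='train', unsupervised=False)
# print(len(train_set))  # -> 1000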