# Source: M3L / datasets/citysunrgbd.py
# Repository page header (scraped): author harshm121, commit d4ebf73 "Working demo",
# file size 2.85 kB (page links: raw / history / blame).
import torch
import numpy as np
from datasets.base_dataset import BaseDataset
class CityScapesSunRGBD(BaseDataset):
    """Dataset that loads an RGB image, a depth image and, when supervised, a ground truth.

    Two output layouts are produced by ``__getitem__`` depending on ``sliding``:
    a whole-image sample (``data``/``name``/optional ``gt``) or a list of
    sliding-window crops taken from the RGB image (``sliding_output``/optional ``gt``).
    """

    def __init__(self, dataset_settings, mode, unsupervised, preprocess, sliding=False, stride_rate=None):
        """Initialise the base dataset and remember the sampling options.

        Args:
            dataset_settings: Forwarded unchanged to ``BaseDataset.__init__``.
            mode: Forwarded to ``BaseDataset.__init__``; presumably what the
                base class stores as ``self._mode`` -- TODO confirm.
            unsupervised: Forwarded to ``BaseDataset.__init__``; when
                ``self.unsupervised`` is truthy no ground truth is loaded.
            preprocess: Optional callable applied to the loaded arrays, or
                ``None`` to skip preprocessing.
            sliding: When truthy, ``__getitem__`` returns sliding-window crops
                instead of a whole-image sample.
            stride_rate: Passed to ``self.slide_over_image`` in sliding mode.
        """
        super().__init__(dataset_settings, mode, unsupervised)
        self.preprocess = preprocess
        self.sliding = sliding
        self.stride_rate = stride_rate
        # In sliding mode ``preprocess`` is invoked per crop as
        # ``preprocess(crop, None)`` and must return two values, unlike the
        # whole-image call which passes and returns three.
        if self.sliding and self._mode == 'train':
            print("Ensure correct preprocessing is being done!")

    def __getitem__(self, index):
        """Load and package the sample stored at ``index``.

        Returns:
            Whole-image mode: ``{'data': [rgb, depth], 'name': item_name}`` with
            float tensors, plus ``'gt'`` (long tensor) when a ground truth exists.
            Sliding mode: ``{'sliding_output': [([crop], pos, margin), ...]}``
            plus ``'gt'`` when the mode is ``'train'``/``'val'`` and a ground
            truth exists.

        Raises:
            Exception: In whole-image mode when ``self._mode`` is neither
                ``'train'`` nor ``'val'``.
        """
        entry = self._file_names[index]

        # NOTE(review): the depth file is resolved against ``_rgb_path`` as
        # well -- confirm that both modalities really share one root.
        rgb_file = self._rgb_path + entry['rgb']
        depth_file = self._rgb_path + entry['depth']
        supervised = not self.unsupervised
        gt_file = self._gt_path + entry['gt'] if supervised else None

        # File name without directory and without anything after the first dot.
        item_name = entry['rgb'].rsplit("/", 1)[-1].split(".", 1)[0]

        rgb = self._open_rgb(rgb_file)
        depth = self._open_depth(depth_file)
        gt = self._open_gt(gt_file) if supervised else None

        if self.sliding:
            # Only the RGB image is windowed; the loaded depth image is not
            # part of the sliding output.
            # NOTE(review): ``model_input_shape`` is not assigned in this
            # class -- presumably provided by ``BaseDataset``; verify.
            windows = self.slide_over_image(rgb, self.model_input_shape, self.stride_rate)

            sample = {}
            if self._mode in ('train', 'val') and gt is not None:
                sample['gt'] = torch.from_numpy(np.ascontiguousarray(gt)).long()

            crops = []
            sample['sliding_output'] = crops
            for crop, pos, margin in windows:
                if self.preprocess is not None:
                    crop, _ = self.preprocess(crop, None)
                crop_tensor = torch.from_numpy(np.ascontiguousarray(crop)).float()
                crops.append(([crop_tensor], pos, margin))
            return sample

        if self.preprocess is not None:
            rgb, depth, gt = self.preprocess(rgb, depth, gt)

        if self._mode not in ('train', 'val'):
            raise Exception(f"{self._mode} not supported in CityScapesSunRGB")

        # ``ascontiguousarray`` guards against non-contiguous views before
        # handing the buffers to torch.
        rgb = torch.from_numpy(np.ascontiguousarray(rgb)).float()
        depth = torch.from_numpy(np.ascontiguousarray(depth)).float()

        sample = {'data': [rgb, depth], 'name': item_name}
        if gt is not None:
            sample['gt'] = torch.from_numpy(np.ascontiguousarray(gt)).long()
        return sample