import json
import glob
import random

import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms

try:
    from torchvision.transforms import InterpolationMode

    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC

# NOTE: the original file used RandAugment without importing it. We assume
# torchvision's built-in implementation (torchvision >= 0.11) here; the
# original repo may define its own RandAugment instead.
from torchvision.transforms import RandAugment

# Transformations are modified for ViT models and for the
# crop-person images of the WIDER dataset.


###### Base data loader ######
class DataSet(Dataset):
    def __init__(
        self,
        ann_files,
        augs,
        img_size,
        dataset,
    ):
        self.dataset = dataset
        self.ann_files = ann_files
        self.augment = self.augs_function(augs, img_size)
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1])]
            # In this paper, we normalize the image data to [0, 1].
            # You can also use the standard 'ImageNet' normalization.
        )
        self.anns = []
        self.load_anns()
        print(self.augment)

        # In the WIDER dataset we use ViT models,
        # so the transformation differs.
        if self.dataset == "wider":
            self.transform = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
                ]
            )

    def augs_function(self, augs, img_size):
        t = []
        if "randomflip" in augs:
            t.append(transforms.RandomHorizontalFlip())
        if "ColorJitter" in augs:
            t.append(
                transforms.ColorJitter(
                    brightness=0.5, contrast=0.5, saturation=0.5, hue=0
                )
            )
        if "resizedcrop" in augs:
            t.append(transforms.RandomResizedCrop(img_size, scale=(0.7, 1.0)))
        if "RandAugment" in augs:
            t.append(RandAugment())
        t.append(transforms.Resize((img_size, img_size)))
        return transforms.Compose(t)

    def load_anns(self):
        self.anns = []
        for ann_file in self.ann_files:
            # use a context manager so the file handle is closed promptly
            with open(ann_file, "r") as f:
                self.anns += json.load(f)

    def __len__(self):
        return len(self.anns)

    def __getitem__(self, idx):
        idx = idx % len(self)
        ann = self.anns[idx]
        img = Image.open(ann["img_path"]).convert("RGB")

        if self.dataset == "wider":
            # WIDER annotations are per person: crop the person box first.
            x, y, w, h = ann["bbox"]
            img_area = img.crop((x, y, x + w, y + h))
            img_area = self.augment(img_area)
            img_area = self.transform(img_area)
            message = {
                "img_path": ann["img_path"],
                "target": torch.Tensor(ann["target"]),
                "img": img_area,
            }
        else:  # VOC and COCO
            img = self.augment(img)
            img = self.transform(img)
            message = {
                "img_path": ann["img_path"],
                "target": torch.Tensor(ann["target"]),
                "img": img,
            }

        return message
        # Finally, iterating over a DataLoader built on this dataset yields
        # {
        #     "img_path": list,   # length = batch_size
        #     "target": Tensor,   # shape: batch_size * num_classes
        #     "img": Tensor,      # shape: batch_size * 3 * img_size * img_size
        # }
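
# A minimal usage sketch (not part of the original file): wraps DataSet in a
# DataLoader and inspects one batch. The annotation path, augmentation names,
# and batch size below are hypothetical placeholders; each annotation JSON is
# assumed to hold a list of records with "img_path" and "target" keys
# (plus "bbox" for WIDER).
def _example_dataset_usage():
    from torch.utils.data import DataLoader

    dataset = DataSet(
        ann_files=["./data/coco/train_anns.json"],  # hypothetical path
        augs=["randomflip", "resizedcrop"],
        img_size=224,
        dataset="coco",
    )
    loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
    batch = next(iter(loader))
    # batch["img"]: (32, 3, 224, 224), batch["target"]: (32, num_classes),
    # batch["img_path"]: list of 32 strings.
    return batch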
def preprocess_scribble(img, img_size):
    transform = transforms.Compose(
        [
            transforms.Resize(img_size, interpolation=BICUBIC),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
        ]
    )
    return transform(img)


class DataSetMaskSup(Dataset):
    """
    Data loader with scribbles.
    """

    def __init__(
        self,
        ann_files,
        augs,
        img_size,
        dataset,
    ):
        self.dataset = dataset
        self.ann_files = ann_files
        self.img_size = img_size
        self.augment = self.augs_function(augs, img_size)
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1])]
            # In this paper, we normalize the image data to [0, 1].
            # You can also use the standard 'ImageNet' normalization.
        )
        self.anns = []
        self.load_anns()
        print(self.augment)

        # scribbles
        self._scribbles_folder = "./datasets/SCRIBBLES"
        # The type of masks to use is hardcoded, since we find that high
        # masks work better in MSL. See the paper for details.
        # For low masks:
        # self._scribbles = sorted(glob.glob(self._scribbles_folder + "/*.png"))[
        #     :1000
        # ]
        # For high masks:
        self._scribbles = sorted(glob.glob(self._scribbles_folder + "/*.png"))[::-1][
            :1000
        ]

        # In the WIDER dataset we use ViT models,
        # so the transformation differs.
        if self.dataset == "wider":
            self.transform = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
                ]
            )

    def augs_function(self, augs, img_size):
        t = []
        if "randomflip" in augs:
            t.append(transforms.RandomHorizontalFlip())
        if "ColorJitter" in augs:
            t.append(
                transforms.ColorJitter(
                    brightness=0.5, contrast=0.5, saturation=0.5, hue=0
                )
            )
        if "resizedcrop" in augs:
            t.append(transforms.RandomResizedCrop(img_size, scale=(0.7, 1.0)))
        if "RandAugment" in augs:
            t.append(RandAugment())
        t.append(transforms.Resize((img_size, img_size)))
        return transforms.Compose(t)

    def load_anns(self):
        self.anns = []
        for ann_file in self.ann_files:
            # use a context manager so the file handle is closed promptly
            with open(ann_file, "r") as f:
                self.anns += json.load(f)

    def __len__(self):
        return len(self.anns)

    def __getitem__(self, idx):
        idx = idx % len(self)
        ann = self.anns[idx]
        img = Image.open(ann["img_path"]).convert("RGB")

        # Get a scribble. Note that randint's bounds are inclusive, so this
        # samples from the first 951 of the 1000 loaded scribbles.
        scribble_path = self._scribbles[random.randint(0, 950)]
        scribble = Image.open(scribble_path).convert("P")
        scribble = preprocess_scribble(scribble, self.img_size)
        scribble_t = (scribble > 0).float()  # threshold to {0, 1}
        inv_scribble = torch.max(scribble_t) - scribble_t  # inverted scribble

        if self.dataset == "wider":
            x, y, w, h = ann["bbox"]
            img_area = img.crop((x, y, x + w, y + h))
            img_area = self.augment(img_area)
            img_area = self.transform(img_area)
            # masked image
            masked_image = img_area * inv_scribble
            message = {
                "img_path": ann["img_path"],
                "target": torch.Tensor(ann["target"]),
                "img": img_area,
                "masked_img": masked_image,
                # "scribble": inv_scribble,
            }
        else:  # VOC and COCO
            img = self.augment(img)
            img = self.transform(img)
            # masked image
            masked_image = img * inv_scribble
            message = {
                "img_path": ann["img_path"],
                "target": torch.Tensor(ann["target"]),
                "img": img,
                "masked_img": masked_image,
                # "scribble": inv_scribble,
            }

        return message
        # Finally, iterating over a DataLoader built on this dataset yields
        # {
        #     "img_path": list,        # length = batch_size
        #     "target": Tensor,        # shape: batch_size * num_classes
        #     "img": Tensor,           # shape: batch_size * 3 * img_size * img_size
        #     "masked_img": Tensor,    # shape: batch_size * 3 * img_size * img_size
        # }
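
# A minimal usage sketch (not part of the original file): checks that
# DataSetMaskSup pairs each image with a scribble-masked copy. The annotation
# path and parameters are hypothetical placeholders, and ./datasets/SCRIBBLES
# must contain the scribble PNGs described above.
def _example_masksup_usage():
    from torch.utils.data import DataLoader

    dataset = DataSetMaskSup(
        ann_files=["./data/voc/train_anns.json"],  # hypothetical path
        augs=["randomflip"],
        img_size=224,
        dataset="voc",
    )
    loader = DataLoader(dataset, batch_size=16, shuffle=True)
    batch = next(iter(loader))
    # With the [0, 1] normalization used for VOC/COCO, masking only zeroes
    # out pixels, so the masked image is element-wise <= the original.
    assert torch.all(batch["masked_img"] <= batch["img"])
    return batch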