WSSS_ResNet50 / core /datasets.py
kittendev's picture
Upload 176 files
c20a1af verified
import os
import cv2
import glob
import torch
import math
import imageio
import numpy as np
from PIL import Image
from core.aff_utils import *
from tools.ai.augment_utils import *
from tools.ai.torch_utils import one_hot_embedding
from tools.general.xml_utils import read_xml
from tools.general.json_utils import read_json
from tools.dataset.voc_utils import get_color_map_dic
class Iterator:
def __init__(self, loader):
self.loader = loader
self.init()
def init(self):
self.iterator = iter(self.loader)
def get(self):
try:
data = next(self.iterator)
except StopIteration:
self.init()
data = next(self.iterator)
return data
class VOC_Dataset(torch.utils.data.Dataset):
def __init__(self, root_dir, domain, with_id=False, with_tags=False, with_mask=False):
self.root_dir = root_dir
self.image_dir = self.root_dir + 'JPEGImages/'
self.xml_dir = self.root_dir + 'Annotations/'
self.mask_dir = self.root_dir + 'SegmentationClass/'
self.image_id_list = [image_id.strip() for image_id in open('./data/%s.txt'%domain).readlines()]
self.with_id = with_id
self.with_tags = with_tags
self.with_mask = with_mask
def __len__(self):
return len(self.image_id_list)
def get_image(self, image_id):
image = Image.open(self.image_dir + image_id + '.jpg').convert('RGB')
return image
def get_mask(self, image_id):
mask_path = self.mask_dir + image_id + '.png'
if os.path.isfile(mask_path):
mask = Image.open(mask_path)
else:
mask = None
return mask
def get_tags(self, image_id):
_, tags = read_xml(self.xml_dir + image_id + '.xml')
return tags
def __getitem__(self, index):
image_id = self.image_id_list[index]
data_list = [self.get_image(image_id)]
if self.with_id:
data_list.append(image_id)
if self.with_tags:
data_list.append(self.get_tags(image_id))
if self.with_mask:
data_list.append(self.get_mask(image_id))
return data_list
class VOC_Dataset_For_Classification(VOC_Dataset):
def __init__(self, root_dir, domain, transform=None):
super().__init__(root_dir, domain, with_tags=True)
self.transform = transform
data = read_json('./data/VOC_2012.json')
self.class_dic = data['class_dic']
self.classes = data['classes']
def __getitem__(self, index):
image, tags = super().__getitem__(index)
if self.transform is not None:
image = self.transform(image)
label = one_hot_embedding([self.class_dic[tag] for tag in tags], self.classes)
return image, label
class VOC_Dataset_For_Segmentation(VOC_Dataset):
def __init__(self, root_dir, domain, transform=None):
super().__init__(root_dir, domain, with_mask=True)
self.transform = transform
cmap_dic, _, class_names = get_color_map_dic()
self.colors = np.asarray([cmap_dic[class_name] for class_name in class_names])
def __getitem__(self, index):
image, mask = super().__getitem__(index)
if self.transform is not None:
input_dic = {'image':image, 'mask':mask}
output_dic = self.transform(input_dic)
image = output_dic['image']
mask = output_dic['mask']
return image, mask
class VOC_Dataset_For_Evaluation(VOC_Dataset):
def __init__(self, root_dir, domain, transform=None):
super().__init__(root_dir, domain, with_id=True, with_mask=True)
self.transform = transform
cmap_dic, _, class_names = get_color_map_dic()
self.colors = np.asarray([cmap_dic[class_name] for class_name in class_names])
def __getitem__(self, index):
image, image_id, mask = super().__getitem__(index)
if self.transform is not None:
input_dic = {'image':image, 'mask':mask}
output_dic = self.transform(input_dic)
image = output_dic['image']
mask = output_dic['mask']
return image, image_id, mask
class VOC_Dataset_For_WSSS(VOC_Dataset):
def __init__(self, root_dir, domain, pred_dir, transform=None):
super().__init__(root_dir, domain, with_id=True)
self.pred_dir = pred_dir
self.transform = transform
cmap_dic, _, class_names = get_color_map_dic()
self.colors = np.asarray([cmap_dic[class_name] for class_name in class_names])
def __getitem__(self, index):
image, image_id = super().__getitem__(index)
mask = Image.open(self.pred_dir + image_id + '.png')
if self.transform is not None:
input_dic = {'image':image, 'mask':mask}
output_dic = self.transform(input_dic)
image = output_dic['image']
mask = output_dic['mask']
return image, mask
class VOC_Dataset_For_Testing_CAM(VOC_Dataset):
def __init__(self, root_dir, domain, transform=None):
super().__init__(root_dir, domain, with_tags=True, with_mask=True)
self.transform = transform
cmap_dic, _, class_names = get_color_map_dic()
self.colors = np.asarray([cmap_dic[class_name] for class_name in class_names])
data = read_json('./data/VOC_2012.json')
self.class_dic = data['class_dic']
self.classes = data['classes']
def __getitem__(self, index):
image, tags, mask = super().__getitem__(index)
if self.transform is not None:
input_dic = {'image':image, 'mask':mask}
output_dic = self.transform(input_dic)
image = output_dic['image']
mask = output_dic['mask']
label = one_hot_embedding([self.class_dic[tag] for tag in tags], self.classes)
return image, label, mask
class VOC_Dataset_For_Making_CAM(VOC_Dataset):
def __init__(self, root_dir, domain):
super().__init__(root_dir, domain, with_id=True, with_tags=True, with_mask=True)
cmap_dic, _, class_names = get_color_map_dic()
self.colors = np.asarray([cmap_dic[class_name] for class_name in class_names])
data = read_json('./data/VOC_2012.json')
self.class_names = np.asarray(class_names[1:21])
self.class_dic = data['class_dic']
self.classes = data['classes']
def __getitem__(self, index):
image, image_id, tags, mask = super().__getitem__(index)
label = one_hot_embedding([self.class_dic[tag] for tag in tags], self.classes)
return image, image_id, label, mask
class VOC_Dataset_For_Affinity(VOC_Dataset):
def __init__(self, root_dir, domain, path_index, label_dir, transform=None):
super().__init__(root_dir, domain, with_id=True)
data = read_json('./data/VOC_2012.json')
self.class_dic = data['class_dic']
self.classes = data['classes']
self.transform = transform
self.label_dir = label_dir
self.path_index = path_index
self.extract_aff_lab_func = GetAffinityLabelFromIndices(self.path_index.src_indices, self.path_index.dst_indices)
def __getitem__(self, idx):
image, image_id = super().__getitem__(idx)
label = imageio.imread(self.label_dir + image_id + '.png')
label = Image.fromarray(label)
output_dic = self.transform({'image':image, 'mask':label})
image, label = output_dic['image'], output_dic['mask']
return image, self.extract_aff_lab_func(label)