|
import albumentations as A |
|
|
|
from torch.utils.data import Dataset, DataLoader |
|
import pycocotools.mask as maskUtils |
|
from pycocotools.coco import COCO |
|
import random |
|
import os.path as osp |
|
import cv2 |
|
import numpy as np |
|
from scipy.ndimage import distance_transform_bf, distance_transform_edt, distance_transform_cdt |
|
|
|
|
|
def is_grey(img: np.ndarray):
    """Return True unless *img* is a 3-channel colour array of shape (H, W, 3)."""
    return not (img.ndim == 3 and img.shape[2] == 3)
|
|
|
|
|
def square_pad_resize(img: np.ndarray, tgt_size: int, pad_value = (0, 0, 0)):
    """Pad *img* (bottom/right only) to a square, then resize to tgt_size.

    Args:
        img: HxW (single-channel) or HxWx3 image.
        tgt_size: side length of the returned square image.
        pad_value: border fill value, an int or a 3-tuple; it is coerced to
            match the image (scalar for single-channel, 3-tuple for colour).

    Returns:
        Tuple ``(img, resize_ratio, pad_h, pad_w)``: the tgt_size x tgt_size
        image, the scale factor ``tgt_size / padded_side`` applied by the
        resize, and the bottom/right padding in pre-resize pixels.
    """
    h, w = img.shape[:2]
    pad_h, pad_w = 0, 0

    # Pad the shorter side so the image becomes square.
    if w < h:
        pad_w = h - w
        w += pad_w
    elif h < w:
        pad_h = w - h
        h += pad_h

    # If the square is still smaller than the target, pad both dimensions out
    # to tgt_size; in that case the resize below becomes a no-op (ratio == 1).
    pad_size = tgt_size - h
    if pad_size > 0:
        pad_h += pad_size
        pad_w += pad_size

    if pad_h > 0 or pad_w > 0:
        # Coerce pad_value to the image's channel count before the border op.
        if is_grey(img):
            if isinstance(pad_value, tuple):
                pad_value = pad_value[0]
        else:
            if isinstance(pad_value, int):
                pad_value = (pad_value, pad_value, pad_value)

        img = cv2.copyMakeBorder(img, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=pad_value)

    # Shrink with INTER_AREA (anti-aliased); enlarge with INTER_LINEAR.
    resize_ratio = tgt_size / img.shape[0]
    if resize_ratio < 1:
        img = cv2.resize(img, (tgt_size, tgt_size), interpolation=cv2.INTER_AREA)
    elif resize_ratio > 1:
        img = cv2.resize(img, (tgt_size, tgt_size), interpolation=cv2.INTER_LINEAR)

    return img, resize_ratio, pad_h, pad_w
|
|
|
|
|
class MaskRefineDataset(Dataset):
    """Mask-refinement training dataset backed by a COCO-format annotation file.

    Each sample pairs an image with the binary mask of one randomly chosen
    annotation on that image.  Optionally a (possibly perturbed) instance mask
    is stacked onto the image as an extra channel, so a refinement network can
    learn to correct coarse instance predictions.  Samples are dicts with keys
    'img' ((C, H, W) float32 in [0, 1]), 'mask' ((1, H, W)) and, when
    ``with_distance`` is set, 'dist_weight' ((1, H, W)) — a per-pixel loss
    weight emphasising pixels near the mask boundary.
    """

    def __init__(self,
                 refine_ann_path: str,
                 data_root: str,
                 load_instance_mask: bool = True,
                 aug_ins_prob: float = 0.,
                 ins_rect_prob: float = 0.,
                 output_size: int = 720,
                 augmentation: bool = False,
                 with_distance: bool = False):
        """
        Args:
            refine_ann_path: path to the COCO-style annotation json.
            data_root: directory containing the images referenced by each
                annotation's 'file_name' field.
            load_instance_mask: if True, an instance-mask channel is appended
                to the image (see _load_with_instance).
            aug_ins_prob: probability of dilating/eroding the instance mask
                (only applies when ``augmentation`` is True).
            ins_rect_prob: probability of replacing the instance mask with the
                annotation's bounding-box rectangle (augmentation only).
            output_size: side length of the square output image/mask.
            augmentation: enable the albumentations pipeline below.
            with_distance: also compute the 'dist_weight' map per sample.
        """
        self.load_instance_mask = load_instance_mask
        self.ann_util = COCO(refine_ann_path)
        self.img_ids = self.ann_util.getImgIds()
        # Bind self._load_mask to the with/without-instance loader.
        self.set_load_method(load_instance_mask)
        self.data_root = data_root

        self.ins_rect_prob = ins_rect_prob
        self.aug_ins_prob = aug_ins_prob
        self.augmentation = augmentation
        if augmentation:
            # Geometric transforms carry masks along via the `masks=` argument
            # in transform(); mask/border fill values are forced to 0.
            transform = [
                A.OpticalDistortion(),
                A.HorizontalFlip(),
                A.CLAHE(),
                A.Posterize(),
                A.CropAndPad(percent=0.1, p=0.3, pad_mode=cv2.BORDER_CONSTANT, pad_cval=0, pad_cval_mask=0, keep_size=True),
                A.RandomContrast(),
                A.Rotate(30, p=0.3, mask_value=0, border_mode=cv2.BORDER_CONSTANT)
            ]
            self._aug_transform = A.Compose(transform)
        else:
            self._aug_transform = None

        self.output_size = output_size
        self.with_distance = with_distance

    def set_output_size(self, size: int):
        # Output resolution can be changed at any time (e.g. between epochs)
        # without rebuilding the dataset.
        self.output_size = size

    def set_load_method(self, load_instance_mask: bool):
        """Select whether __getitem__ also produces an instance-mask channel."""
        if load_instance_mask:
            self._load_mask = self._load_with_instance
        else:
            self._load_mask = self._load_without_instance

    def __getitem__(self, idx: int) -> dict:
        img_id = self.img_ids[idx]
        img_meta = self.ann_util.imgs[img_id]
        img_path = osp.join(self.data_root, img_meta['file_name'])
        # NOTE(review): assumes the image file exists — cv2.imread returns
        # None on failure and the transform below would then raise.
        img = cv2.imread(img_path)

        # Train on one randomly picked annotation of this image (if any).
        annids = self.ann_util.getAnnIds([img_id])
        if len(annids) > 0:
            ann = random.choice(annids)
            ann = self.ann_util.anns[ann]
            assert ann['image_id'] == img_id
        else:
            ann = None

        return self._load_mask(img, ann)

    def transform(self, img: np.ndarray, mask: np.ndarray, ins_seg: np.ndarray = None) -> dict:
        """Augment, square-pad/resize and tensor-ify one (img, mask[, ins_seg])."""
        if ins_seg is not None:
            use_seg = True
        else:
            use_seg = False

        if self.augmentation:
            # Run image and mask(s) through the same (geometric) transforms.
            masks = [mask]
            if use_seg:
                masks.append(ins_seg)
            data = self._aug_transform(image=img, masks=masks)
            img = data['image']
            masks = data['masks']
            mask = masks[0]
            if use_seg:
                ins_seg = masks[1]

        # Image is padded with a random grey value so the network cannot key
        # on a constant border colour; masks are padded with 0 (background).
        img = square_pad_resize(img, self.output_size, random.randint(0, 255))[0]
        mask = square_pad_resize(mask, self.output_size, 0)[0]
        if ins_seg is not None:
            ins_seg = square_pad_resize(ins_seg, self.output_size, 0)[0]

        # HWC uint8 -> CHW float32 in [0, 1].
        img = (img.astype(np.float32) / 255.).transpose((2, 0, 1))
        mask = mask[None, ...]

        if use_seg:
            # Stack the instance mask onto the image as an extra input channel.
            ins_seg = ins_seg[None, ...]
            img = np.concatenate((img, ins_seg), axis=0)

        data = {'img': img, 'mask': mask}
        if self.with_distance:
            # Per-pixel loss weight: euclidean distance to the nearest
            # background pixel, normalised and inverted so pixels at/near the
            # mask boundary get the largest weight, floored at 0.2, rescaled
            # to mean ~1 over the map, then clipped to [0, 20].
            dist = distance_transform_edt(mask[0])
            dist_max = dist.max()
            if dist_max != 0:
                dist = 1 - dist / dist_max

                dist = dist + 0.2
                dist = dist.size / (dist.sum() + 1) * dist
                dist = np.clip(dist, 0, 20)
            else:
                # Empty mask: fall back to uniform weights.
                dist = np.ones_like(dist)

            data['dist_weight'] = dist[None, ...]
        return data

    def _load_with_instance(self, img: np.ndarray, ann: dict):
        """Build (gt mask, instance hint) for one annotation and transform them.

        The instance hint is, in priority order: the annotation's bbox as a
        filled rectangle (prob ins_rect_prob, augmentation only), a randomly
        chosen entry of ann['pred_segmentations'], or the ground-truth mask
        itself; it may additionally be dilated/eroded (prob aug_ins_prob).
        """
        if ann is None:
            # No annotation for this image: empty target and empty hint.
            mask = np.zeros(img.shape[:2], dtype=np.float32)
            ins_seg = mask
        else:
            mask = maskUtils.decode(ann['segmentation']).astype(np.float32)
            if self.augmentation and random.random() < self.ins_rect_prob:
                # Degrade the hint to the annotation's bounding box ([x, y, w, h]).
                ins_seg = np.zeros_like(mask)
                bbox = [int(b) for b in ann['bbox']]
                ins_seg[bbox[1]: bbox[1] + bbox[3], bbox[0]: bbox[0] + bbox[2]] = 1
            elif len(ann['pred_segmentations']) > 0:
                ins_seg = random.choice(ann['pred_segmentations'])
                ins_seg = maskUtils.decode(ins_seg).astype(np.float32)
            else:
                ins_seg = mask
            if self.augmentation and random.random() < self.aug_ins_prob:
                # Randomly thicken or thin the hint with an elliptic kernel;
                # effective ksize is one of {3, 7, 11, 15}.
                ksize = random.choice([1, 3, 5, 7])
                ksize = ksize * 2 + 1
                kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, ksize=(ksize, ksize))
                if random.random() < 0.5:
                    ins_seg = cv2.dilate(ins_seg, kernel)
                else:
                    ins_seg = cv2.erode(ins_seg, kernel)

        return self.transform(img, mask, ins_seg)

    def _load_without_instance(self, img: np.ndarray, ann: dict):
        """Transform image + ground-truth mask only (no instance channel)."""
        if ann is None:
            mask = np.zeros(img.shape[:2], dtype=np.float32)
        else:
            mask = maskUtils.decode(ann['segmentation']).astype(np.float32)
        return self.transform(img, mask)

    def __len__(self):
        # One sample per image id; a random annotation is drawn on each access.
        return len(self.img_ids)
|
|
|
|
|
if __name__ == '__main__':
    # Visual smoke test: iterate the dataset through a DataLoader and display
    # each sample's channels with OpenCV (press a key to advance).
    # (Fixed: ann_path/data_root were previously assigned twice.)
    ann_path = r'workspace/test_syndata/annotations/refine_train.json'
    data_root = r'workspace/test_syndata/train'
    aug_ins_prob = 0.5
    load_instance_mask = True
    ins_rect_prob = 0.25
    output_size = 640
    augmentation = True

    # Seed python's RNG so annotation choice / augmentation draws repeat.
    random.seed(0)

    md = MaskRefineDataset(ann_path, data_root, load_instance_mask, aug_ins_prob, ins_rect_prob, output_size, augmentation, with_distance=True)

    dl = DataLoader(md, batch_size=1, shuffle=False, persistent_workers=True,
                    num_workers=1, pin_memory=True)
    for data in dl:
        img = data['img'].cpu().numpy()
        # First three channels are the image; scale back to uint8 for display.
        img = (img[0, :3].transpose((1, 2, 0)) * 255).astype(np.uint8)
        mask = (data['mask'].cpu().numpy()[0][0] * 255).astype(np.uint8)
        if load_instance_mask:
            # Channel 3 is the instance-mask hint appended in transform().
            ins = (data['img'].cpu().numpy()[0][3] * 255).astype(np.uint8)
            cv2.imshow('ins', ins)
        dist = data['dist_weight'].cpu().numpy()[0][0]
        # Normalise the weight map to [0, 255] for visualisation.
        dist = (dist / dist.max() * 255).astype(np.uint8)
        cv2.imshow('img', img)
        cv2.imshow('mask', mask)
        cv2.imshow('dist_weight', dist)
        cv2.waitKey(0)
|