|
import os.path as osp |
|
import numpy as np |
|
from typing import List, Optional, Sequence, Tuple, Union |
|
import copy |
|
from time import time |
|
import mmcv |
|
from mmcv.transforms import to_tensor |
|
from mmdet.datasets.transforms import LoadAnnotations, RandomCrop, PackDetInputs, Mosaic, CachedMosaic, CachedMixUp, FilterAnnotations |
|
from mmdet.structures.mask import BitmapMasks, PolygonMasks |
|
from mmdet.datasets import CocoDataset |
|
from mmdet.registry import DATASETS, TRANSFORMS |
|
from numpy import random |
|
from mmdet.structures.bbox import autocast_box_type, BaseBoxes |
|
from mmengine.structures import InstanceData, PixelData |
|
from mmdet.structures import DetDataSample |
|
from utils.io_utils import bbox_overlap_xy |
|
from utils.logger import LOGGER |
|
|
|
@DATASETS.register_module()
class AnimeMangaMixedDataset(CocoDataset):
    """COCO-style dataset that concatenates Manga109 and anime-instance data.

    Each source is optional and is loaded only when its root directory is
    given. Both annotation files are read through the COCO API and the parsed
    samples are appended to one data list (Manga109 first).

    Args:
        animeins_root (str, optional): Root directory of the anime instance
            data. ``animeins_annfile`` and ``data_prefix['img']`` are joined
            onto it.
        animeins_annfile (str, optional): Annotation file, relative to
            ``animeins_root``.
        manga109_annfile (str, optional): Path of the Manga109 annotation
            file (used as given).
        manga109_root (str, optional): Root directory of Manga109; images are
            looked up in its ``images`` sub-directory.
        *args: Forwarded to ``CocoDataset``.
        **kwargs: Forwarded to ``CocoDataset``.
    """

    def __init__(self, animeins_root: str = None, animeins_annfile: str = None, manga109_annfile: str = None, manga109_root: str = None, *args, **kwargs) -> None:
        self.animeins_annfile = animeins_annfile
        self.animeins_root = animeins_root
        self.manga109_annfile = manga109_annfile
        self.manga109_root = manga109_root
        # Initialised before the base constructor runs; `_data_list` fills
        # them (presumably the base class triggers `load_data_list` from its
        # constructor -- confirm against the mmengine version in use).
        self.cat_ids = []
        self.cat_img_map = {}
        super().__init__(*args, **kwargs)
        LOGGER.info(f'total num data: {len(self.data_list)}')

    def parse_data_info(self, raw_data_info: dict, data_prefix: str) -> Union[dict, List[dict]]:
        """Parse raw annotation to target format.

        Args:
            raw_data_info (dict): Raw data information loaded from an
                annotation file, with keys ``raw_img_info`` and
                ``raw_ann_info``.
            data_prefix (str): Directory that ``file_name`` is joined onto to
                build ``img_path``.

        Returns:
            Union[dict, List[dict]]: Parsed annotation with keys ``img_path``,
            ``img_id``, ``seg_map_path``, ``height``, ``width`` and
            ``instances``.
        """
        img_info = raw_data_info['raw_img_info']
        ann_info = raw_data_info['raw_ann_info']

        data_info = {}

        img_path = osp.join(data_prefix, img_info['file_name'])
        if self.data_prefix.get('seg', None):
            seg_map_path = osp.join(
                self.data_prefix['seg'],
                img_info['file_name'].rsplit('.', 1)[0] + self.seg_map_suffix)
        else:
            seg_map_path = None
        data_info['img_path'] = img_path
        data_info['img_id'] = img_info['img_id']
        data_info['seg_map_path'] = seg_map_path
        data_info['height'] = img_info['height']
        data_info['width'] = img_info['width']

        instances = []
        for ann in ann_info:
            if ann.get('ignore', False):
                continue
            x1, y1, w, h = ann['bbox']
            # Drop boxes that do not overlap the image at all.
            inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
            inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
            if inter_w * inter_h == 0:
                continue
            # Drop degenerate boxes.
            if ann['area'] <= 0 or w < 1 or h < 1:
                continue
            # Drop categories that are not part of the configured classes.
            if ann['category_id'] not in self.cat_ids:
                continue

            instance = {
                'ignore_flag': 1 if ann.get('iscrowd', False) else 0,
                # xywh -> xyxy; the box is NOT clipped to the image here.
                'bbox': [x1, y1, x1 + w, y1 + h],
                'bbox_label': self.cat2label[ann['category_id']],
            }

            # 'mask' is only present when the annotation carries a
            # non-empty segmentation.
            if ann.get('segmentation', None):
                instance['mask'] = ann['segmentation']

            instances.append(instance)
        data_info['instances'] = instances
        return data_info

    def load_data_list(self) -> List[dict]:
        """Load and concatenate the samples of every configured source.

        Returns:
            List[dict]: Manga109 samples followed by anime-instance samples.
        """
        data_lst = []
        if self.manga109_root is not None:
            data_lst += self._data_list(self.manga109_annfile, osp.join(self.manga109_root, 'images'))
            LOGGER.info(f'num data from manga109: {len(data_lst)}')
        if self.animeins_root is not None:
            animeins_annfile = osp.join(self.animeins_root, self.animeins_annfile)
            data_prefix = osp.join(self.animeins_root, self.data_prefix['img'])
            anime_lst = self._data_list(animeins_annfile, data_prefix)
            data_lst += anime_lst
            # Report only the animeins contribution, not the running total.
            LOGGER.info(f'num data from animeins: {len(anime_lst)}')
        return data_lst

    def _data_list(self, annfile: str, data_prefix: str) -> List[dict]:
        """Load annotations from the annotation file ``annfile``.

        Side effects: overwrites ``self.cat_ids`` and ``self.cat2label`` with
        the values from this file and merges the file's category-to-image map
        into ``self.cat_img_map``.

        Args:
            annfile (str): Path of the COCO-format annotation file.
            data_prefix (str): Image directory passed to ``parse_data_info``.

        Returns:
            List[dict]: A list of annotation.
        """
        with self.file_client.get_local_path(annfile) as local_path:
            self.coco = self.COCOAPI(local_path)

        self.cat_ids = self.coco.get_cat_ids(
            cat_names=self.metainfo['classes'])
        self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
        # Deep-copied so the merged map does not alias the COCO object that
        # is deleted below.
        cat_img_map = copy.deepcopy(self.coco.cat_img_map)
        for key, val in cat_img_map.items():
            if key in self.cat_img_map:
                self.cat_img_map[key] += val
            else:
                self.cat_img_map[key] = val

        img_ids = self.coco.get_img_ids()
        data_list = []
        total_ann_ids = []
        for img_id in img_ids:
            raw_img_info = self.coco.load_imgs([img_id])[0]
            raw_img_info['img_id'] = img_id

            ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
            raw_ann_info = self.coco.load_anns(ann_ids)
            total_ann_ids.extend(ann_ids)

            parsed_data_info = self.parse_data_info({
                'raw_ann_info':
                raw_ann_info,
                'raw_img_info':
                raw_img_info
            }, data_prefix)
            data_list.append(parsed_data_info)
        if self.ANN_ID_UNIQUE:
            assert len(set(total_ann_ids)) == len(
                total_ann_ids
            ), f"Annotation ids in '{annfile}' are not unique!"

        del self.coco

        return data_list
|
|
|
|
|
|
|
@TRANSFORMS.register_module()
class LoadAnnotationsNoSegs(LoadAnnotations):
    """Annotation loader that tolerates instances without a usable mask.

    When ``poly2mask`` is enabled, an instance whose mask is missing or is a
    dict without valid ``counts``/``size`` is kept: its mask is replaced by
    its bounding-box rectangle and the instance is marked in
    ``results['gt_ignore_mask_flags']``.
    """

    def _process_masks(self, results: dict) -> list:
        """Process gt_masks and filter invalid polygons.

        Writes ``results['gt_ignore_flags']`` and
        ``results['gt_ignore_mask_flags']`` (both bool arrays, one entry per
        instance).

        Args:
            results (dict): Result dict from :obj:``mmengine.BaseDataset``.

        Returns:
            list: Processed gt_masks.
        """
        gt_masks = []
        gt_ignore_flags = []
        gt_ignore_mask_flags = []
        for instance in results.get('instances', []):
            # Instances parsed from an annotation without a segmentation
            # carry no 'mask' key; treat that like an unusable mask instead
            # of raising KeyError.
            gt_mask = instance.get('mask')
            ignore_mask = False

            if isinstance(gt_mask, list):
                # Keep only well-formed polygons (even length, >= 3 points).
                gt_mask = [
                    np.array(polygon) for polygon in gt_mask
                    if len(polygon) % 2 == 0 and len(polygon) >= 6
                ]
                if len(gt_mask) == 0:
                    # Ignore this instance and substitute a fake polygon.
                    instance['ignore_flag'] = 1
                    gt_mask = [np.zeros(6)]
            elif not self.poly2mask:
                # Without poly2mask only polygon lists are accepted, so any
                # other format makes the whole instance ignored.
                instance['ignore_flag'] = 1
                gt_mask = [np.zeros(6)]
            elif gt_mask is None or (
                    isinstance(gt_mask, dict) and
                    not (gt_mask.get('counts') is not None and
                         gt_mask.get('size') is not None and
                         isinstance(gt_mask['counts'], (list, str)))):
                # Unusable mask: keep the instance (ignore_flag untouched)
                # but flag its mask; `_load_masks` fills in the bbox area.
                ignore_mask = True
                gt_mask = [np.zeros(6)]
            gt_masks.append(gt_mask)

            # Re-read ignore_flag, it may have been changed above.
            gt_ignore_flags.append(instance['ignore_flag'])
            gt_ignore_mask_flags.append(ignore_mask)
        results['gt_ignore_flags'] = np.array(gt_ignore_flags, dtype=bool)
        results['gt_ignore_mask_flags'] = np.array(gt_ignore_mask_flags, dtype=bool)
        return gt_masks

    def _load_masks(self, results: dict) -> None:
        """Private function to load mask annotations.

        Stores the result in ``results['gt_masks']`` as ``BitmapMasks`` when
        ``poly2mask`` is set, otherwise as ``PolygonMasks``.

        Args:
            results (dict): Result dict from :obj:``mmengine.BaseDataset``.
        """
        h, w = results['ori_shape']
        gt_masks = self._process_masks(results)
        if self.poly2mask:
            p2masks = []
            if len(gt_masks) > 0:
                for ins, mask, ignore_mask in zip(results['instances'], gt_masks, results['gt_ignore_mask_flags']):
                    if ignore_mask:
                        # Clamp to >= 0: a negative start index would wrap
                        # around and select the wrong (usually empty) region.
                        # Coordinates past the image edge are truncated by
                        # the slice itself.
                        x1, y1, x2, y2 = (max(int(c), 0) for c in ins['bbox'])
                        m = np.zeros((h, w), dtype=np.uint8)
                        m[y1:y2, x1:x2] = 255
                        p2masks.append(m)
                    else:
                        p2masks.append(self._poly2mask(mask, h, w))
            gt_masks = BitmapMasks(p2masks, h, w)
        else:
            gt_masks = PolygonMasks(list(gt_masks), h, w)
        results['gt_masks'] = gt_masks

    def transform(self, results: dict) -> dict:
        """Function to load multiple types annotations.

        Args:
            results (dict): Result dict from :obj:``mmengine.BaseDataset``.

        Returns:
            dict: The dict contains loaded bounding box, label and
            semantic segmentation.
        """
        if self.with_bbox:
            self._load_bboxes(results)
        if self.with_label:
            self._load_labels(results)
        if self.with_mask:
            self._load_masks(results)
        if self.with_seg:
            self._load_seg_map(results)

        return results
|
|
|
|
|
|
|
@TRANSFORMS.register_module()
class PackDetIputsNoSeg(PackDetInputs):
    """Packing transform that also forwards ``gt_ignore_mask_flags``.

    The flags are exposed on the packed instances as ``ignore_mask``.
    """

    # results key -> attribute name on the packed InstanceData
    mapping_table = {
        'gt_bboxes': 'bboxes',
        'gt_bboxes_labels': 'labels',
        'gt_ignore_mask_flags': 'ignore_mask',
        'gt_masks': 'masks'
    }

    def transform(self, results: dict) -> dict:
        """Method to pack the input data.

        Args:
            results (dict): Result dict from the data pipeline.

        Returns:
            dict:

            - 'inputs' (obj:`torch.Tensor`): The forward data of models.
            - 'data_sample' (obj:`DetDataSample`): The annotation info of the
                sample.
        """
        packed_results = dict()
        if 'img' in results:
            image = results['img']
            if len(image.shape) < 3:
                # Add a trailing channel axis to 2-D images.
                image = np.expand_dims(image, -1)
            # HWC -> CHW
            image = np.ascontiguousarray(image.transpose(2, 0, 1))
            packed_results['inputs'] = to_tensor(image)

        has_ignore_flags = 'gt_ignore_flags' in results
        if has_ignore_flags:
            valid_idx = np.where(results['gt_ignore_flags'] == 0)[0]
            ignore_idx = np.where(results['gt_ignore_flags'] == 1)[0]

        data_sample = DetDataSample()
        instance_data = InstanceData()
        ignore_instance_data = InstanceData()

        for src_key, dst_key in self.mapping_table.items():
            if src_key not in results:
                continue
            value = results[src_key]
            # Masks and box objects are stored as-is; everything else is
            # converted to a tensor.
            keep_raw = src_key == 'gt_masks' or isinstance(value, BaseBoxes)

            if not has_ignore_flags:
                instance_data[dst_key] = value if keep_raw else to_tensor(value)
                continue

            kept, ignored = value[valid_idx], value[ignore_idx]
            if not keep_raw:
                kept, ignored = to_tensor(kept), to_tensor(ignored)
            instance_data[dst_key] = kept
            ignore_instance_data[dst_key] = ignored

        data_sample.gt_instances = instance_data
        data_sample.ignored_instances = ignore_instance_data

        if 'proposals' in results:
            data_sample.proposals = InstanceData(
                bboxes=to_tensor(results['proposals']),
                scores=to_tensor(results['proposals_scores']))

        if 'gt_seg_map' in results:
            sem_seg = to_tensor(results['gt_seg_map'][None, ...].copy())
            data_sample.gt_sem_seg = PixelData(sem_seg=sem_seg)

        img_meta = {}
        for key in self.meta_keys:
            assert key in results, f'`{key}` is not found in `results`, ' \
                f'the valid keys are {list(results)}.'
            img_meta[key] = results[key]

        data_sample.set_metainfo(img_meta)
        packed_results['data_samples'] = data_sample

        return packed_results
|
|
|
|
|
|
|
def translate_bitmapmask(bitmap_masks: BitmapMasks,
                         out_shape,
                         offset_x,
                         offset_y,):
    """Paste a stack of bitmap masks onto a zero canvas at a 2-D offset.

    Args:
        bitmap_masks (BitmapMasks): Masks to move.
        out_shape: ``(height, width)`` of the output canvas.
        offset_x: Horizontal position of the masks' left edge on the canvas
            (may be negative).
        offset_y: Vertical position of the masks' top edge on the canvas
            (may be negative).

    Returns:
        BitmapMasks: Masks of size ``out_shape``. They are left all-zero
        unless both overlap extents reported by ``bbox_overlap_xy`` exceed 2.
    """
    if len(bitmap_masks.masks) == 0:
        return BitmapMasks(np.empty((0, *out_shape), dtype=np.uint8), *out_shape)

    src = bitmap_masks.masks
    out_h, out_w = out_shape
    src_h, src_w = src.shape[1:]

    dst = np.zeros((src.shape[0], *out_shape), dtype=src.dtype)

    ix, iy = bbox_overlap_xy([0, 0, out_w, out_h], [offset_x, offset_y, src_w, src_h])
    if ix > 2 and iy > 2:
        # (source start, destination start) per axis: a positive offset
        # shifts the destination, a non-positive one crops the source.
        sx1, dx1 = (0, offset_x) if offset_x > 0 else (-offset_x, 0)
        sx2 = min(out_w - offset_x, src_w)
        dx2 = dx1 + sx2 - sx1

        sy1, dy1 = (0, offset_y) if offset_y > 0 else (-offset_y, 0)
        sy2 = min(out_h - offset_y, src_h)
        dy2 = dy1 + sy2 - sy1

        dst[:, dy1: dy2, dx1: dx2] = src[:, sy1: sy2, sx1: sx2]

    return BitmapMasks(dst, *out_shape)
|
|
|
|
|
@TRANSFORMS.register_module()
class CachedMosaicNoSeg(CachedMosaic):
    """Cached mosaic that also carries ``gt_ignore_mask_flags`` and masks.

    Compared with the parent transform, patches whose ``img_id`` exceeds
    900000000 are treated as manga pages: at most three of them are used per
    mosaic and wide ones are randomly cropped to one half with ``hcrop``.
    """

    @autocast_box_type()
    def transform(self, results: dict) -> dict:
        """Mosaic transform function.

        Args:
            results (dict): Result dict.

        Returns:
            dict: Updated result dict.
        """
        # Cache the incoming sample and evict one if the cache is full.
        self.results_cache.append(copy.deepcopy(results))
        if len(self.results_cache) > self.max_cached_images:
            if self.random_pop:
                index = random.randint(0, len(self.results_cache) - 1)
            else:
                index = 0
            self.results_cache.pop(index)

        if len(self.results_cache) <= 4:
            return results

        if random.uniform(0, 1) > self.prob:
            return results
        indices = self.get_indexes(self.results_cache)
        mix_results = [copy.deepcopy(self.results_cache[i]) for i in indices]

        mosaic_bboxes = []
        mosaic_bboxes_labels = []
        mosaic_ignore_flags = []
        mosaic_masks = []
        mosaic_ignore_mask_flags = []
        with_mask = 'gt_masks' in results

        # `img_scale` is indexed as (width, height); the canvas is twice
        # that size and laid out as (rows, cols).
        mosaic_h = int(self.img_scale[1] * 2)
        mosaic_w = int(self.img_scale[0] * 2)
        if len(results['img'].shape) == 3:
            canvas_shape = (mosaic_h, mosaic_w, 3)
        else:
            canvas_shape = (mosaic_h, mosaic_w)
        mosaic_img = np.full(
            canvas_shape, self.pad_val, dtype=results['img'].dtype)

        # Mosaic center (x, y) where the four patches meet.
        center_x = int(
            random.uniform(*self.center_ratio_range) * self.img_scale[0])
        center_y = int(
            random.uniform(*self.center_ratio_range) * self.img_scale[1])
        center_position = (center_x, center_y)

        loc_strs = ('top_left', 'top_right', 'bottom_left', 'bottom_right')

        n_manga = 0
        for i, loc in enumerate(loc_strs):
            if loc == 'top_left':
                results_patch = copy.deepcopy(results)
            else:
                results_patch = copy.deepcopy(mix_results[i - 1])

            # NOTE(review): presumably manga image ids are offset above
            # 900000000 when the annotation file is built -- confirm.
            is_manga = results_patch['img_id'] > 900000000
            if is_manga:
                n_manga += 1
                if n_manga > 3:
                    # Leave this quadrant as padding.
                    continue
                im_h, im_w = results_patch['img'].shape[:2]
                if im_w > im_h and random.random() < 0.75:
                    # Landscape manga image: keep a random horizontal half.
                    results_patch = hcrop(results_patch, (im_h, im_w // 2), True)

            img_i = results_patch['img']
            h_i, w_i = img_i.shape[:2]
            # Keep-ratio resize so the patch fits inside one `img_scale` cell.
            scale_ratio_i = min(self.img_scale[1] / h_i,
                                self.img_scale[0] / w_i)
            img_i = mmcv.imresize(
                img_i, (int(w_i * scale_ratio_i), int(h_i * scale_ratio_i)))

            # Where to paste on the canvas and which part of the patch to use.
            paste_coord, crop_coord = self._mosaic_combine(
                loc, center_position, img_i.shape[:2][::-1])
            x1_p, y1_p, x2_p, y2_p = paste_coord
            x1_c, y1_c, x2_c, y2_c = crop_coord

            mosaic_img[y1_p:y2_p, x1_p:x2_p] = img_i[y1_c:y2_c, x1_c:x2_c]

            # Move the annotations into canvas coordinates.
            gt_bboxes_i = results_patch['gt_bboxes']
            gt_bboxes_labels_i = results_patch['gt_bboxes_labels']
            gt_ignore_flags_i = results_patch['gt_ignore_flags']
            gt_ignore_mask_i = results_patch['gt_ignore_mask_flags']

            padw = x1_p - x1_c
            padh = y1_p - y1_c
            gt_bboxes_i.rescale_([scale_ratio_i, scale_ratio_i])
            gt_bboxes_i.translate_([padw, padh])
            mosaic_bboxes.append(gt_bboxes_i)
            mosaic_bboxes_labels.append(gt_bboxes_labels_i)
            mosaic_ignore_flags.append(gt_ignore_flags_i)
            mosaic_ignore_mask_flags.append(gt_ignore_mask_i)
            if with_mask and results_patch.get('gt_masks', None) is not None:
                gt_masks_i = results_patch['gt_masks']
                gt_masks_i = gt_masks_i.rescale(float(scale_ratio_i))
                # out_shape is (height, width) and must match the canvas;
                # passing (width, height) gives wrongly shaped masks whenever
                # `img_scale` is not square.
                gt_masks_i = translate_bitmapmask(gt_masks_i,
                                                  out_shape=(mosaic_h, mosaic_w),
                                                  offset_x=padw, offset_y=padh)
                mosaic_masks.append(gt_masks_i)

        mosaic_bboxes = mosaic_bboxes[0].cat(mosaic_bboxes, 0)
        mosaic_bboxes_labels = np.concatenate(mosaic_bboxes_labels, 0)
        mosaic_ignore_flags = np.concatenate(mosaic_ignore_flags, 0)
        mosaic_ignore_mask_flags = np.concatenate(mosaic_ignore_mask_flags, 0)

        if self.bbox_clip_border:
            mosaic_bboxes.clip_([2 * self.img_scale[1], 2 * self.img_scale[0]])

        # Drop boxes that ended up outside the canvas.
        inside_inds = mosaic_bboxes.is_inside(
            [2 * self.img_scale[1], 2 * self.img_scale[0]]).numpy()

        mosaic_bboxes = mosaic_bboxes[inside_inds]
        mosaic_bboxes_labels = mosaic_bboxes_labels[inside_inds]
        mosaic_ignore_flags = mosaic_ignore_flags[inside_inds]
        mosaic_ignore_mask_flags = mosaic_ignore_mask_flags[inside_inds]

        results['img'] = mosaic_img
        results['img_shape'] = mosaic_img.shape
        results['gt_bboxes'] = mosaic_bboxes
        results['gt_bboxes_labels'] = mosaic_bboxes_labels
        results['gt_ignore_flags'] = mosaic_ignore_flags
        results['gt_ignore_mask_flags'] = mosaic_ignore_mask_flags

        if with_mask:
            total_instances = len(inside_inds)
            assert total_instances == np.array([m.masks.shape[0] for m in mosaic_masks]).sum()
            if total_instances > 10:
                # Many instances: copy only the kept masks into one
                # pre-sized buffer instead of concatenating and indexing.
                masks = np.empty((inside_inds.sum(), mosaic_masks[0].height, mosaic_masks[0].width), dtype=np.uint8)
                msk_idx = 0   # position in the concatenated instance order
                mmsk_idx = 0  # position in the output buffer
                for m in mosaic_masks:
                    for ii in range(m.masks.shape[0]):
                        if inside_inds[msk_idx]:
                            masks[mmsk_idx] = m.masks[ii]
                            mmsk_idx += 1
                        msk_idx += 1
                results['gt_masks'] = BitmapMasks(masks, mosaic_masks[0].height, mosaic_masks[0].width)
            else:
                mosaic_masks = mosaic_masks[0].cat(mosaic_masks)
                results['gt_masks'] = mosaic_masks[inside_inds]

        return results
|
|
|
@TRANSFORMS.register_module()
class FilterAnnotationsNoSeg(FilterAnnotations):
    """Filter small instances, keeping ``gt_ignore_mask_flags`` in sync.

    Args:
        min_gt_bbox_wh (Tuple[int, int]): Boxes are kept only when their
            width and height are strictly greater than these values.
        min_gt_mask_area (int): Masks are kept only when their area is at
            least this value.
        by_box (bool): Filter by box size.
        by_mask (bool): Filter by mask area.
        keep_empty (bool): Stored on the instance; ``transform`` in this
            class does not consult it and always returns the result dict.
    """

    def __init__(self,
                 min_gt_bbox_wh: Tuple[int, int] = (1, 1),
                 min_gt_mask_area: int = 1,
                 by_box: bool = True,
                 by_mask: bool = False,
                 keep_empty: bool = True) -> None:
        assert by_box or by_mask
        self.min_gt_bbox_wh = min_gt_bbox_wh
        self.min_gt_mask_area = min_gt_mask_area
        self.by_box = by_box
        self.by_mask = by_mask
        self.keep_empty = keep_empty

    @autocast_box_type()
    def transform(self, results: dict) -> Union[dict, None]:
        """Transform function to filter annotations.

        Args:
            results (dict): Result dict.

        Returns:
            dict: Updated result dict.
        """
        assert 'gt_bboxes' in results
        gt_bboxes = results['gt_bboxes']
        if gt_bboxes.shape[0] == 0:
            return results

        tests = []
        if self.by_box:
            tests.append(
                ((gt_bboxes.widths > self.min_gt_bbox_wh[0]) &
                 (gt_bboxes.heights > self.min_gt_bbox_wh[1])).numpy())

        if self.by_mask:
            assert 'gt_masks' in results
            gt_masks = results['gt_masks']
            tests.append(gt_masks.areas >= self.min_gt_mask_area)

        # An instance survives only if it passes every enabled test.
        keep = tests[0]
        for t in tests[1:]:
            keep = keep & t

        assert len(results['gt_ignore_flags']) == len(results['gt_ignore_mask_flags'])
        keys = ('gt_bboxes', 'gt_bboxes_labels', 'gt_masks', 'gt_ignore_flags', 'gt_ignore_mask_flags')
        for key in keys:
            if key in results:
                # Indexing errors propagate unchanged to the caller.
                results[key] = results[key][keep]

        return results
|
|
|
|
|
def hcrop(results: dict, crop_size: Tuple[int, int],
          allow_negative_crop: bool) -> Union[dict, None]:
    """Crop ``results`` at one of two fixed horizontal positions.

    The window starts at row 0 and, chosen at random, at column 0 or at
    column ``crop_size[1]``. Image, boxes, flags, labels, masks and seg map
    are cropped; ``results`` is modified in place.

    Args:
        results (dict): Result dict containing at least ``img``.
        crop_size (Tuple[int, int]): ``(height, width)`` of the window.
        allow_negative_crop (bool): Whether a crop that keeps no box is
            acceptable.

    Returns:
        Union[dict, None]: The updated dict, or ``None`` when no box survives
        and ``allow_negative_crop`` is False.
    """
    assert crop_size[0] > 0 and crop_size[1] > 0
    crop_h, crop_w = crop_size[0], crop_size[1]
    image = results['img']
    top = 0
    left = random.choice([0, crop_w])
    bottom, right = top + crop_h, left + crop_w

    # Record the translation, composing with any earlier homography.
    shift = np.array(
        [[1, 0, -left], [0, 1, -top], [0, 0, 1]],
        dtype=np.float32)
    previous = results.get('homography_matrix', None)
    results['homography_matrix'] = shift if previous is None else shift @ previous

    image = image[top:bottom, left:right, ...]
    cropped_shape = image.shape
    results['img'] = image
    results['img_shape'] = cropped_shape

    if results.get('gt_bboxes', None) is not None:
        boxes = results['gt_bboxes']
        boxes.translate_([-left, -top])
        boxes.clip_(cropped_shape[:2])
        keep = boxes.is_inside(cropped_shape[:2]).numpy()

        if not (keep.any() or allow_negative_crop):
            return None

        results['gt_bboxes'] = boxes[keep]

        # Per-instance arrays are filtered with the same boolean mask.
        for key in ('gt_ignore_flags', 'gt_ignore_mask_flags',
                    'gt_bboxes_labels'):
            if results.get(key, None) is not None:
                results[key] = results[key][keep]

        if results.get('gt_masks', None) is not None:
            window = np.asarray([left, top, right, bottom])
            results['gt_masks'] = results['gt_masks'][
                keep.nonzero()[0]].crop(window)
            # Boxes are always recomputed from the cropped masks here.
            results['gt_bboxes'] = results['gt_masks'].get_bboxes(
                type(results['gt_bboxes']))

    if results.get('gt_seg_map', None) is not None:
        results['gt_seg_map'] = results['gt_seg_map'][top:bottom,
                                                      left:right]

    return results
|
|
|
|
|
@TRANSFORMS.register_module()
class RandomCropNoSeg(RandomCrop):
    """Random crop that also filters ``gt_ignore_mask_flags``."""

    def _crop_data(self, results: dict, crop_size: Tuple[int, int],
                   allow_negative_crop: bool) -> Union[dict, None]:
        """Crop image and annotations at a random offset.

        Args:
            results (dict): Result dict containing at least ``img``.
            crop_size (Tuple[int, int]): ``(height, width)`` of the window.
            allow_negative_crop (bool): Whether a crop that keeps no box is
                acceptable.

        Returns:
            Union[dict, None]: The updated dict, or ``None`` when no box
            survives and ``allow_negative_crop`` is False.
        """
        assert crop_size[0] > 0 and crop_size[1] > 0
        crop_h, crop_w = crop_size[0], crop_size[1]
        image = results['img']
        # Free room for the window's top-left corner along each axis.
        slack_h = max(image.shape[0] - crop_h, 0)
        slack_w = max(image.shape[1] - crop_w, 0)
        top, left = self._rand_offset((slack_h, slack_w))
        bottom, right = top + crop_h, left + crop_w

        # Record the translation, composing with any earlier homography.
        shift = np.array(
            [[1, 0, -left], [0, 1, -top], [0, 0, 1]],
            dtype=np.float32)
        previous = results.get('homography_matrix', None)
        results['homography_matrix'] = shift if previous is None else shift @ previous

        image = image[top:bottom, left:right, ...]
        cropped_shape = image.shape
        results['img'] = image
        results['img_shape'] = cropped_shape

        if results.get('gt_bboxes', None) is not None:
            boxes = results['gt_bboxes']
            boxes.translate_([-left, -top])
            if self.bbox_clip_border:
                boxes.clip_(cropped_shape[:2])
            keep = boxes.is_inside(cropped_shape[:2]).numpy()

            if not (keep.any() or allow_negative_crop):
                return None

            results['gt_bboxes'] = boxes[keep]

            # Per-instance arrays are filtered with the same boolean mask.
            for key in ('gt_ignore_flags', 'gt_ignore_mask_flags',
                        'gt_bboxes_labels'):
                if results.get(key, None) is not None:
                    results[key] = results[key][keep]

            if results.get('gt_masks', None) is not None:
                window = np.asarray([left, top, right, bottom])
                results['gt_masks'] = results['gt_masks'][
                    keep.nonzero()[0]].crop(window)
                if self.recompute_bbox:
                    results['gt_bboxes'] = results['gt_masks'].get_bboxes(
                        type(results['gt_bboxes']))

        if results.get('gt_seg_map', None) is not None:
            results['gt_seg_map'] = results['gt_seg_map'][top:bottom,
                                                          left:right]

        return results
|
|
|
|
|
|
|
@TRANSFORMS.register_module()
class CachedMixUpNoSeg(CachedMixUp):
    """Cached MixUp that also mixes ``gt_ignore_mask_flags`` and masks."""

    @autocast_box_type()
    def transform(self, results: dict) -> dict:
        """MixUp transform function.

        Args:
            results (dict): Result dict.

        Returns:
            dict: Updated result dict.
        """
        # Cache the incoming sample and evict one if the cache is full.
        self.results_cache.append(copy.deepcopy(results))
        if len(self.results_cache) > self.max_cached_images:
            if self.random_pop:
                index = random.randint(0, len(self.results_cache) - 1)
            else:
                index = 0
            self.results_cache.pop(index)

        if len(self.results_cache) <= 1:
            return results

        if random.uniform(0, 1) > self.prob:
            return results

        index = self.get_indexes(self.results_cache)
        retrieve_results = copy.deepcopy(self.results_cache[index])

        if retrieve_results['gt_bboxes'].shape[0] == 0:
            # Nothing to mix in from an image without boxes.
            return results

        retrieve_img = retrieve_results['img']
        with_mask = 'gt_masks' in results

        jit_factor = random.uniform(*self.ratio_range)
        is_flip = random.uniform(0, 1) > self.flip_ratio

        # Padded canvas; `dynamic_scale` is indexed as (width, height).
        if len(retrieve_img.shape) == 3:
            out_img = np.ones(
                (self.dynamic_scale[1], self.dynamic_scale[0], 3),
                dtype=retrieve_img.dtype) * self.pad_val
        else:
            out_img = np.ones(
                self.dynamic_scale[::-1],
                dtype=retrieve_img.dtype) * self.pad_val

        # 1. keep-ratio resize of the retrieved image
        scale_ratio = min(self.dynamic_scale[1] / retrieve_img.shape[0],
                          self.dynamic_scale[0] / retrieve_img.shape[1])
        retrieve_img = mmcv.imresize(
            retrieve_img, (int(retrieve_img.shape[1] * scale_ratio),
                           int(retrieve_img.shape[0] * scale_ratio)))

        # 2. paste it at the top-left corner of the canvas
        out_img[:retrieve_img.shape[0], :retrieve_img.shape[1]] = retrieve_img

        # 3. scale jitter applied to the whole canvas
        scale_ratio *= jit_factor
        out_img = mmcv.imresize(out_img, (int(out_img.shape[1] * jit_factor),
                                          int(out_img.shape[0] * jit_factor)))

        # 4. horizontal flip of the whole canvas
        if is_flip:
            out_img = out_img[:, ::-1, :]

        # 5. pad up to the target size and take a random crop of that size
        ori_img = results['img']
        origin_h, origin_w = out_img.shape[:2]
        target_h, target_w = ori_img.shape[:2]
        padded_img = np.ones((max(origin_h, target_h), max(
            origin_w, target_w), 3)) * self.pad_val
        padded_img = padded_img.astype(np.uint8)
        padded_img[:origin_h, :origin_w] = out_img

        x_offset, y_offset = 0, 0
        if padded_img.shape[0] > target_h:
            y_offset = random.randint(0, padded_img.shape[0] - target_h)
        if padded_img.shape[1] > target_w:
            x_offset = random.randint(0, padded_img.shape[1] - target_w)
        padded_cropped_img = padded_img[y_offset:y_offset + target_h,
                                        x_offset:x_offset + target_w]

        # 6. apply the same geometry to the retrieved annotations
        retrieve_gt_bboxes = retrieve_results['gt_bboxes']
        retrieve_gt_bboxes.rescale_([scale_ratio, scale_ratio])
        if with_mask:
            retrieve_gt_masks = retrieve_results['gt_masks'].rescale(
                scale_ratio)

        if self.bbox_clip_border:
            retrieve_gt_bboxes.clip_([origin_h, origin_w])

        # Extra x shift that lines flipped masks up with the flipped canvas.
        flip_shift = 0
        if is_flip:
            retrieve_gt_bboxes.flip_([origin_h, origin_w],
                                     direction='horizontal')
            if with_mask:
                retrieve_gt_masks = retrieve_gt_masks.flip()
                # Boxes are mirrored about the canvas width `origin_w`, but
                # the masks are mirrored inside their own width. If the
                # masks are narrower than the canvas (retrieved image padded
                # on the right before the flip), shift them by the
                # difference so they follow the image content.
                flip_shift = origin_w - retrieve_gt_masks.width

        # 7. move into the cropped frame and clip
        cp_retrieve_gt_bboxes = retrieve_gt_bboxes.clone()
        cp_retrieve_gt_bboxes.translate_([-x_offset, -y_offset])
        if with_mask:
            retrieve_gt_masks = translate_bitmapmask(retrieve_gt_masks,
                                                     out_shape=(target_h, target_w),
                                                     offset_x=flip_shift - x_offset,
                                                     offset_y=-y_offset)

        if self.bbox_clip_border:
            cp_retrieve_gt_bboxes.clip_([target_h, target_w])

        # 8. blend images with equal weights and merge annotations
        ori_img = ori_img.astype(np.float32)
        mixup_img = 0.5 * ori_img + 0.5 * padded_cropped_img.astype(np.float32)

        retrieve_gt_bboxes_labels = retrieve_results['gt_bboxes_labels']
        retrieve_gt_ignore_flags = retrieve_results['gt_ignore_flags']
        retrieve_gt_ignore_mask_flags = retrieve_results['gt_ignore_mask_flags']

        mixup_gt_bboxes = cp_retrieve_gt_bboxes.cat(
            (results['gt_bboxes'], cp_retrieve_gt_bboxes), dim=0)
        mixup_gt_bboxes_labels = np.concatenate(
            (results['gt_bboxes_labels'], retrieve_gt_bboxes_labels), axis=0)
        mixup_gt_ignore_flags = np.concatenate(
            (results['gt_ignore_flags'], retrieve_gt_ignore_flags), axis=0)
        mixup_gt_ignore_mask_flags = np.concatenate(
            (results['gt_ignore_mask_flags'], retrieve_gt_ignore_mask_flags), axis=0)

        if with_mask:
            mixup_gt_masks = retrieve_gt_masks.cat(
                [results['gt_masks'], retrieve_gt_masks])

        # Drop instances whose box is outside the target image.
        inside_inds = mixup_gt_bboxes.is_inside([target_h, target_w]).numpy()
        mixup_gt_bboxes = mixup_gt_bboxes[inside_inds]
        mixup_gt_bboxes_labels = mixup_gt_bboxes_labels[inside_inds]
        mixup_gt_ignore_flags = mixup_gt_ignore_flags[inside_inds]
        mixup_gt_ignore_mask_flags = mixup_gt_ignore_mask_flags[inside_inds]
        if with_mask:
            mixup_gt_masks = mixup_gt_masks[inside_inds]

        results['img'] = mixup_img.astype(np.uint8)
        results['img_shape'] = mixup_img.shape
        results['gt_bboxes'] = mixup_gt_bboxes
        results['gt_bboxes_labels'] = mixup_gt_bboxes_labels
        results['gt_ignore_flags'] = mixup_gt_ignore_flags
        results['gt_ignore_mask_flags'] = mixup_gt_ignore_mask_flags
        if with_mask:
            results['gt_masks'] = mixup_gt_masks
        return results