|
import numpy as np |
|
from typing import List, Union, Tuple, Dict |
|
import random |
|
from PIL import Image |
|
import cv2 |
|
import imageio, os |
|
import os.path as osp |
|
from tqdm import tqdm |
|
from panopticapi.utils import rgb2id |
|
import traceback |
|
|
|
from utils.io_utils import mask2rle, dict2json, fgbg_hist_matching |
|
from utils.logger import LOGGER |
|
from utils.constants import CATEGORIES, IMAGE_ID_ZFILL |
|
from .transforms import get_fg_transforms, get_bg_transforms, quantize_image, resize2height, rotate_image |
|
from .sampler import random_load_valid_bg, random_load_valid_fg, NameSampler, NormalSampler, PossionSampler, PersonBBoxSampler |
|
from .paste_methods import regular_paste, partition_paste |
|
|
|
|
|
def syn_animecoco_dataset( |
|
bg_list: List, fg_info_list: List[Dict], dataset_save_dir: str, policy: str='train', |
|
tgt_size=640, syn_num_multiplier=2.5, regular_paste_prob=0.4, person_paste_prob=0.4, |
|
max_syn_num=-1, image_id_start=0, obj_id_start=0, hist_match_prob=0.2, quantize_prob=0.25): |
|
|
|
LOGGER.info(f'syn data policy: {policy}') |
|
LOGGER.info(f'background: {len(bg_list)} foreground: {len(fg_info_list)}') |
|
|
|
numfg_sampler = PossionSampler(min_val=1, max_val=9, lam=2.5) |
|
numfg_regpaste_sampler = PossionSampler(min_val=2, max_val=9, lam=3.5) |
|
regpaste_size_sampler = NormalSampler(scalar=tgt_size, to_int=True, max_scale=0.75) |
|
color_correction_sampler = NameSampler({'hist_match': hist_match_prob, 'quantize': quantize_prob}, ) |
|
paste_method_sampler = NameSampler({'regular': regular_paste_prob, 'personbbox': person_paste_prob, |
|
'partition': 1-regular_paste_prob-person_paste_prob}) |
|
|
|
fg_transform = get_fg_transforms(tgt_size, transform_variant=policy) |
|
fg_distort_transform = get_fg_transforms(tgt_size, transform_variant='distort_only') |
|
bg_transform = get_bg_transforms('train', tgt_size) |
|
|
|
image_id = image_id_start + 1 |
|
obj_id = obj_id_start + 1 |
|
|
|
det_annotations, image_meta = [], [] |
|
|
|
syn_num = int(syn_num_multiplier * len(fg_info_list)) |
|
if max_syn_num > 0: |
|
syn_num = max_syn_num |
|
|
|
ann_save_dir = osp.join(dataset_save_dir, 'annotations') |
|
image_save_dir = osp.join(dataset_save_dir, policy) |
|
|
|
if not osp.exists(image_save_dir): |
|
os.makedirs(image_save_dir) |
|
if not osp.exists(ann_save_dir): |
|
os.makedirs(ann_save_dir) |
|
|
|
is_train = policy == 'train' |
|
if is_train: |
|
jpg_save_quality = [75, 85, 95] |
|
else: |
|
jpg_save_quality = [95] |
|
|
|
if isinstance(fg_info_list[0], str): |
|
for ii, fgp in enumerate(fg_info_list): |
|
if isinstance(fgp, str): |
|
fg_info_list[ii] = {'file_path': fgp, 'tag_string': [], 'danbooru': False, 'category_id': 0} |
|
|
|
if person_paste_prob > 0: |
|
personbbox_sampler = PersonBBoxSampler( |
|
'data/cocoperson_bbox_samples.json', fg_info_list, |
|
fg_transform=fg_distort_transform if is_train else None, is_train=is_train) |
|
|
|
total = tqdm(range(syn_num)) |
|
for fin in total: |
|
try: |
|
paste_method = paste_method_sampler.sample() |
|
|
|
fgs = [] |
|
if paste_method == 'regular': |
|
num_fg = numfg_regpaste_sampler.sample() |
|
size = regpaste_size_sampler.sample() |
|
while len(fgs) < num_fg: |
|
tgt_height = int(random.uniform(0.7, 1.2) * size) |
|
fg, fginfo = random_load_valid_fg(fg_info_list) |
|
fg = resize2height(fg, tgt_height) |
|
if is_train: |
|
fg = fg_distort_transform(image=fg)['image'] |
|
rotate_deg = random.randint(-40, 40) |
|
else: |
|
rotate_deg = random.randint(-30, 30) |
|
if random.random() < 0.3: |
|
fg = rotate_image(fg, rotate_deg, alpha_crop=True) |
|
fgs.append({'image': fg, 'fginfo': fginfo}) |
|
while len(fgs) < num_fg and random.random() < 0.15: |
|
fgs.append({'image': fg, 'fginfo': fginfo}) |
|
elif paste_method == 'personbbox': |
|
fgs = personbbox_sampler.sample_matchfg(tgt_size) |
|
else: |
|
num_fg = numfg_sampler.sample() |
|
fgs = [] |
|
for ii in range(num_fg): |
|
fg, fginfo = random_load_valid_fg(fg_info_list) |
|
fg = fg_transform(image=fg)['image'] |
|
h, w = fg.shape[:2] |
|
if num_fg > 6: |
|
downscale = min(tgt_size / 2.5 / w, tgt_size / 2.5 / h) |
|
if downscale < 1: |
|
fg = cv2.resize(fg, (int(w * downscale), int(h * downscale)), interpolation=cv2.INTER_AREA) |
|
fgs.append({'image': fg, 'fginfo': fginfo}) |
|
|
|
bg = random_load_valid_bg(bg_list) |
|
bg = bg_transform(image=bg)['image'] |
|
|
|
color_correct = color_correction_sampler.sample() |
|
|
|
if color_correct == 'hist_match': |
|
fgbg_hist_matching(fgs, bg) |
|
|
|
bg: Image = Image.fromarray(bg) |
|
|
|
if paste_method == 'regular': |
|
segments_info, segments = regular_paste(fgs, bg, regen_bboxes=True) |
|
elif paste_method == 'personbbox': |
|
segments_info, segments = regular_paste(fgs, bg, regen_bboxes=False) |
|
elif paste_method == 'partition': |
|
segments_info, segments = partition_paste(fgs, bg, ) |
|
else: |
|
print(f'invalid paste method: {paste_method}') |
|
raise NotImplementedError |
|
|
|
image = np.array(bg) |
|
if color_correct == 'quantize': |
|
mask = cv2.inRange(segments, np.array([0,0,0]), np.array([0,0,0])) |
|
|
|
image = quantize_image(image, random.choice([12, 16, 32]), 'kmeans', mask=mask)[0] |
|
|
|
|
|
for ii, segi in enumerate(segments_info): |
|
if segi['area'] == 0: |
|
continue |
|
x, y, w, h = segi['bbox'] |
|
x2, y2 = x+w, y+h |
|
c = segments[y: y2, x: x2] |
|
pan_png = rgb2id(c) |
|
cmask = (pan_png == segi['id']) |
|
area = cmask.sum() |
|
|
|
if paste_method != 'partition' and \ |
|
area / (fgs[ii]['image'][..., 3] > 30).sum() < 0.25: |
|
|
|
|
|
|
|
|
|
cmask_ids = np.where(cmask) |
|
segments[y: y2, x: x2][cmask_ids] = 0 |
|
image[y: y2, x: x2][cmask_ids] = (127, 127, 127) |
|
continue |
|
|
|
cmask = cmask.astype(np.uint8) * 255 |
|
dx, dy, w, h = cv2.boundingRect(cv2.findNonZero(cmask)) |
|
_bbox = [dx + x, dy + y, w, h] |
|
|
|
seg = cv2.copyMakeBorder(cmask, y, tgt_size-y2, x, tgt_size-x2, cv2.BORDER_CONSTANT) > 0 |
|
assert seg.shape[0] == tgt_size and seg.shape[1] == tgt_size |
|
segmentation = mask2rle(seg) |
|
|
|
det_annotations.append({ |
|
'id': obj_id, |
|
'category_id': fgs[ii]['fginfo']['category_id'], |
|
'iscrowd': 0, |
|
'segmentation': segmentation, |
|
'image_id': image_id, |
|
'area': area, |
|
'tag_string': fgs[ii]['fginfo']['tag_string'], |
|
'tag_string_character': fgs[ii]['fginfo']['tag_string_character'], |
|
'bbox': [float(c) for c in _bbox] |
|
}) |
|
|
|
obj_id += 1 |
|
|
|
|
|
|
|
|
|
image_id_str = str(image_id).zfill(IMAGE_ID_ZFILL) |
|
image_file_name = image_id_str + '.jpg' |
|
image_meta.append({ |
|
"id": image_id,"height": tgt_size,"width": tgt_size, "file_name": image_file_name, "id": image_id |
|
}) |
|
|
|
|
|
|
|
|
|
|
|
|
|
imageio.imwrite(osp.join(image_save_dir, image_file_name), image, quality=random.choice(jpg_save_quality)) |
|
image_id += 1 |
|
|
|
except: |
|
LOGGER.error(traceback.format_exc()) |
|
continue |
|
|
|
det_meta = { |
|
"info": {}, |
|
"licenses": [], |
|
"images": image_meta, |
|
"annotations": det_annotations, |
|
"categories": CATEGORIES |
|
} |
|
|
|
detp = osp.join(ann_save_dir, f'det_{policy}.json') |
|
dict2json(det_meta, detp) |
|
LOGGER.info(f'annotations saved to {detp}') |
|
|
|
return image_id, obj_id |