svjack committed on
Commit 10f36ff · verified · 1 Parent(s): 295f20c

Upload 46 files

Files changed (47)
  1. .gitattributes +8 -0
  2. Genshin_Impact_Images/400146081-d2353d35-0066-469e-8e4c-748aefa2b73c.png +3 -0
  3. Genshin_Impact_Images/400146081-d2353d35-0066-469e-8e4c-748aefa2b73c1.png +3 -0
  4. Genshin_Impact_Images/400146161-071aa7c1-54e1-4d6d-94df-1d86ffc28607.png +3 -0
  5. Genshin_Impact_Images/400146171-361d6a62-3e06-47a8-94a5-86004d301f51.png +3 -0
  6. Genshin_Impact_Images/400146192-4f42f25a-24b6-40f9-8a33-e78ff93eb1b7.png +3 -0
  7. Genshin_Impact_Images/400146194-36b0827d-d0bb-420a-ac2c-8a25d6bfa249.png +3 -0
  8. Genshin_Impact_Images/400198228-d57aaf90-7fdd-432f-ac65-ac4f9ec14f7f.png +3 -0
  9. Genshin_Impact_Images/400203947-c37bfa21-ca4b-4671-b2da-7a61fce737e7.png +3 -0
  10. README.md +23 -5
  11. animeins_app.py +81 -0
  12. animeinsseg/__init__.py +708 -0
  13. animeinsseg/__pycache__/__init__.cpython-311.pyc +0 -0
  14. animeinsseg/anime_instances.py +301 -0
  15. animeinsseg/data/__init__.py +2 -0
  16. animeinsseg/data/dataset.py +929 -0
  17. animeinsseg/data/maskrefine_dataset.py +235 -0
  18. animeinsseg/data/metrics.py +348 -0
  19. animeinsseg/data/paste_methods.py +327 -0
  20. animeinsseg/data/sampler.py +226 -0
  21. animeinsseg/data/syndataset.py +213 -0
  22. animeinsseg/data/transforms.py +299 -0
  23. animeinsseg/inpainting/__init__.py +0 -0
  24. animeinsseg/inpainting/ldm_inpaint.py +353 -0
  25. animeinsseg/inpainting/patch_match.py +203 -0
  26. animeinsseg/models/__init__.py +7 -0
  27. animeinsseg/models/animeseg_refine/__init__.py +189 -0
  28. animeinsseg/models/animeseg_refine/encoders.py +51 -0
  29. animeinsseg/models/animeseg_refine/isnet.py +645 -0
  30. animeinsseg/models/animeseg_refine/models.py +0 -0
  31. animeinsseg/models/animeseg_refine/modnet.py +667 -0
  32. animeinsseg/models/animeseg_refine/u2net.py +228 -0
  33. animeinsseg/models/rtmdet_inshead_custom.py +370 -0
  34. ccip.py +238 -0
  35. palette_app.py +134 -0
  36. requirements.txt +34 -0
  37. text_app.py +119 -0
  38. utils/__init__.py +0 -0
  39. utils/booru_tagger.py +116 -0
  40. utils/constants.py +82 -0
  41. utils/cupy_utils.py +122 -0
  42. utils/effects.py +182 -0
  43. utils/env_utils.py +65 -0
  44. utils/helper_math.h +1449 -0
  45. utils/io_utils.py +473 -0
  46. utils/logger.py +20 -0
  47. utils/mmdet_custom_hooks.py +223 -0
.gitattributes CHANGED
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ Genshin_Impact_Images/400146081-d2353d35-0066-469e-8e4c-748aefa2b73c.png filter=lfs diff=lfs merge=lfs -text
37
+ Genshin_Impact_Images/400146081-d2353d35-0066-469e-8e4c-748aefa2b73c1.png filter=lfs diff=lfs merge=lfs -text
38
+ Genshin_Impact_Images/400146161-071aa7c1-54e1-4d6d-94df-1d86ffc28607.png filter=lfs diff=lfs merge=lfs -text
39
+ Genshin_Impact_Images/400146171-361d6a62-3e06-47a8-94a5-86004d301f51.png filter=lfs diff=lfs merge=lfs -text
40
+ Genshin_Impact_Images/400146192-4f42f25a-24b6-40f9-8a33-e78ff93eb1b7.png filter=lfs diff=lfs merge=lfs -text
41
+ Genshin_Impact_Images/400146194-36b0827d-d0bb-420a-ac2c-8a25d6bfa249.png filter=lfs diff=lfs merge=lfs -text
42
+ Genshin_Impact_Images/400198228-d57aaf90-7fdd-432f-ac65-ac4f9ec14f7f.png filter=lfs diff=lfs merge=lfs -text
43
+ Genshin_Impact_Images/400203947-c37bfa21-ca4b-4671-b2da-7a61fce737e7.png filter=lfs diff=lfs merge=lfs -text
Genshin_Impact_Images/400146081-d2353d35-0066-469e-8e4c-748aefa2b73c.png ADDED

Git LFS Details

  • SHA256: ace0097bbcf042cd32ceb5b9786c07b949f8eb7671e5e4265153df285834e692
  • Pointer size: 132 Bytes
  • Size of remote file: 1.22 MB
Genshin_Impact_Images/400146081-d2353d35-0066-469e-8e4c-748aefa2b73c1.png ADDED

Git LFS Details

  • SHA256: cfc5787e01c42dd1bdb75c91ae457158d3c5fee06b730acee759d87cfafe1596
  • Pointer size: 132 Bytes
  • Size of remote file: 1.25 MB
Genshin_Impact_Images/400146161-071aa7c1-54e1-4d6d-94df-1d86ffc28607.png ADDED

Git LFS Details

  • SHA256: 5ba68d14ab938fa368fdd7f4b668879fc4f1e46c274b3155459029b829e8f4af
  • Pointer size: 132 Bytes
  • Size of remote file: 1.33 MB
Genshin_Impact_Images/400146171-361d6a62-3e06-47a8-94a5-86004d301f51.png ADDED

Git LFS Details

  • SHA256: 1100ff7c13a20e80e7453694ee4e0e82b469a43c3ef201601bc7ca883ee2ec23
  • Pointer size: 132 Bytes
  • Size of remote file: 1.49 MB
Genshin_Impact_Images/400146192-4f42f25a-24b6-40f9-8a33-e78ff93eb1b7.png ADDED

Git LFS Details

  • SHA256: ec3495b118d9aaa9b4a1f45465e466ee7abd77bec0be522be060199495799c7d
  • Pointer size: 132 Bytes
  • Size of remote file: 1.22 MB
Genshin_Impact_Images/400146194-36b0827d-d0bb-420a-ac2c-8a25d6bfa249.png ADDED

Git LFS Details

  • SHA256: c04eb66827fb904889c866baeb1420364d867ba8888456cde5b6dda7f3255ef6
  • Pointer size: 132 Bytes
  • Size of remote file: 1.33 MB
Genshin_Impact_Images/400198228-d57aaf90-7fdd-432f-ac65-ac4f9ec14f7f.png ADDED

Git LFS Details

  • SHA256: 1a91689aaeb92b21b414631a054f0c6e93fe406d7f624954f63198f942bf1c06
  • Pointer size: 132 Bytes
  • Size of remote file: 1.56 MB
Genshin_Impact_Images/400203947-c37bfa21-ca4b-4671-b2da-7a61fce737e7.png ADDED

Git LFS Details

  • SHA256: 68454a564c61f850fec7481038b9ac104aad214beef42cb344f5a09dd1cbca0d
  • Pointer size: 132 Bytes
  • Size of remote file: 1.5 MB
README.md CHANGED
@@ -1,12 +1,30 @@
1
  ---
2
- title: Genshin Impact Animeins Ccip
3
- emoji: 📉
4
- colorFrom: red
5
- colorTo: gray
6
  sdk: gradio
7
- sdk_version: 5.10.0
8
  app_file: app.py
9
  pinned: false
10
  ---
11

12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: AnimeIns CPU
3
+ emoji: 💻
4
+ colorFrom: indigo
5
+ colorTo: yellow
6
  sdk: gradio
7
+ sdk_version: 4.8.0
8
  app_file: app.py
9
  pinned: false
10
+ license: mit
11
  ---
12
 
13
+ The CPU-based implementation of the subject segmentation model from [Instance-guided Cartoon Editing with a Large-scale Dataset
14
+ ](https://arxiv.org/abs/2312.01943).
15
+
16
+ Cite this work:
17
+
18
+ @article{animeins,
19
+ Author = {Jian Lin and Chengze Li and Xueting Liu and Zhongping Ge},
20
+ Title = {Instance-guided Cartoon Editing with a Large-scale Dataset},
21
+ Eprint = {2312.01943v1},
22
+ ArchivePrefix = {arXiv},
23
+ PrimaryClass = {cs.CV},
24
+ Year = {2023},
25
+ Month = {Dec},
26
+ Url = {http://arxiv.org/abs/2312.01943v1},
27
+ File = {2312.01943v1.pdf}
28
+ }
29
+
30
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
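For orientation, a minimal usage sketch of the model the README describes, mirroring the flow of animeins_app.py below. The checkpoint path and example image names come from this commit; it assumes the dreMaz/AnimeInstanceSegmentation weights have already been cloned into models/, so treat it as a sketch rather than the canonical entry point.

```python
import cv2
from animeinsseg import AnimeInsSeg, AnimeInstances

# Checkpoint path as used in animeins_app.py; assumes the weights repo is cloned locally.
ckpt = 'models/AnimeInstanceSegmentation/rtmdetl_e60.ckpt'

# refine_kwargs=None would skip the ISNet-based mask refinement pass.
net = AnimeInsSeg(ckpt, mask_thr=0.3, refine_kwargs={'refine_method': 'refinenet_isnet'})

img = cv2.imread('Genshin_Impact_Images/400146081-d2353d35-0066-469e-8e4c-748aefa2b73c.png')
instances: AnimeInstances = net.infer(img, output_type='numpy', pred_score_thr=0.3)

# bboxes and masks are None when nothing is detected.
if instances.bboxes is not None:
    print(f'{len(instances)} instances, first bbox (xywh): {instances.bboxes[0]}')
```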
animeins_app.py ADDED
@@ -0,0 +1,81 @@
1
+ import gradio as gr
2
+ import os
3
+
4
+ os.system("mim install mmengine")
5
+ os.system('mim install mmcv==2.1.0')
6
+ os.system("mim install mmdet==3.2.0")
7
+
8
+ import cv2
9
+ from PIL import Image
10
+ import numpy as np
11
+
12
+ from animeinsseg import AnimeInsSeg, AnimeInstances
13
+ from animeinsseg.anime_instances import get_color
14
+
15
+ if not os.path.exists("models"):
16
+ os.mkdir("models")
17
+
18
+ os.system("huggingface-cli lfs-enable-largefiles .")
19
+ os.system("git clone https://huggingface.co/dreMaz/AnimeInstanceSegmentation models/AnimeInstanceSegmentation")
20
+
21
+ ckpt = r'models/AnimeInstanceSegmentation/rtmdetl_e60.ckpt'
22
+
23
+ mask_thres = 0.3
24
+ instance_thres = 0.3
25
+ refine_kwargs = {'refine_method': 'refinenet_isnet'} # set to None if not using refinenet
26
+ # refine_kwargs = None
27
+
28
+ net = AnimeInsSeg(ckpt, mask_thr=mask_thres, refine_kwargs=refine_kwargs)
29
+
30
+ def fn(image):
31
+ img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
32
+ instances: AnimeInstances = net.infer(
33
+ img,
34
+ output_type='numpy',
35
+ pred_score_thr=instance_thres
36
+ )
37
+
38
+ drawed = img.copy()
39
+ im_h, im_w = img.shape[:2]
40
+
41
+ # instances.bboxes, instances.masks will be None, None if no obj is detected
42
+ if instances.bboxes is None:
43
+ return Image.fromarray(drawed[..., ::-1])
44
+
45
+ for ii, (xywh, mask) in enumerate(zip(instances.bboxes, instances.masks)):
46
+ color = get_color(ii)
47
+
48
+ mask_alpha = 0.5
49
+ linewidth = max(round(sum(img.shape) / 2 * 0.003), 2)
50
+
51
+ # draw bbox
52
+ p1, p2 = (int(xywh[0]), int(xywh[1])), (int(xywh[2] + xywh[0]), int(xywh[3] + xywh[1]))
53
+ cv2.rectangle(drawed, p1, p2, color, thickness=linewidth, lineType=cv2.LINE_AA)
54
+
55
+ # draw mask
56
+ p = mask.astype(np.float32)
57
+ blend_mask = np.full((im_h, im_w, 3), color, dtype=np.float32)
58
+ alpha_msk = (mask_alpha * p)[..., None]
59
+ alpha_ori = 1 - alpha_msk
60
+ drawed = drawed * alpha_ori + alpha_msk * blend_mask
61
+
62
+ drawed = drawed.astype(np.uint8)
63
+
64
+ return Image.fromarray(drawed[..., ::-1])
65
+
66
+ import pathlib
67
+ genshin_impact_exps = list(map(str ,pathlib.Path("Genshin_Impact_Images").rglob("*.png")))
68
+
69
+ iface = gr.Interface(
70
+ # design titles and text descriptions
71
+ title="Anime Subject Instance Segmentation",
72
+ description="Segment image subjects with the proposed model in the paper [*Instance-guided Cartoon Editing with a Large-scale Dataset*](https://cartoonsegmentation.github.io/).",
73
+ fn=fn,
74
+ inputs=gr.Image(type="numpy"),
75
+ outputs=gr.Image(type="pil"),
76
+ examples = genshin_impact_exps
77
+ #examples=["1562990.jpg", "612989.jpg", "sample_3.jpg"]
78
+ )
79
+
80
+ iface.launch(share = True)
81
+
animeinsseg/__init__.py ADDED
@@ -0,0 +1,708 @@
1
+ import mmcv, torch
2
+ from tqdm import tqdm
3
+ from einops import rearrange
4
+ import os
5
+ import os.path as osp
6
+ import cv2
7
+ import gc
8
+ import math
9
+
10
+ from .anime_instances import AnimeInstances
11
+ import numpy as np
12
+ from typing import List, Tuple, Union, Optional, Callable
13
+ from mmengine import Config
14
+ from mmengine.model.utils import revert_sync_batchnorm
15
+ from mmdet.utils import register_all_modules, get_test_pipeline_cfg
16
+ from mmdet.apis import init_detector
17
+ from mmdet.registry import MODELS
18
+ from mmdet.structures import DetDataSample, SampleList
19
+ from mmdet.structures.bbox.transforms import scale_boxes, get_box_wh
20
+ from mmdet.models.dense_heads.rtmdet_ins_head import RTMDetInsHead
21
+ from pycocotools.coco import COCO
22
+ from mmcv.transforms import Compose
23
+ from mmdet.models.detectors.single_stage import SingleStageDetector
24
+
25
+ from utils.logger import LOGGER
26
+ from utils.io_utils import square_pad_resize, find_all_imgs, imglist2grid, mask2rle, dict2json, scaledown_maxsize, resize_pad
27
+ from utils.constants import DEFAULT_DEVICE, CATEGORIES
28
+ from utils.booru_tagger import Tagger
29
+
30
+ from .models.animeseg_refine import AnimeSegmentation, load_refinenet, get_mask
31
+ from .models.rtmdet_inshead_custom import RTMDetInsSepBNHeadCustom
32
+
33
+ from torchvision.ops.boxes import box_iou
34
+ import torch.nn.functional as F
35
+
36
+
37
+ def prepare_refine_batch(segmentations: np.ndarray, img: np.ndarray, max_batch_size: int = 4, device: str = 'cpu', input_size: int = 720):
38
+
39
+ img, (pt, pb, pl, pr) = resize_pad(img, input_size, pad_value=(0, 0, 0))
40
+
41
+ img = img.transpose((2, 0, 1)).astype(np.float32) / 255.
42
+
43
+ batch = []
44
+ num_seg = len(segmentations)
45
+
46
+ for ii, seg in enumerate(segmentations):
47
+ seg, _ = resize_pad(seg, input_size, 0)
48
+ seg = seg[None, ...]
49
+ batch.append(np.concatenate((img, seg)))
50
+
51
+ if ii == num_seg - 1:
52
+ yield torch.from_numpy(np.array(batch)).to(device), (pt, pb, pl, pr)
53
+ elif len(batch) >= max_batch_size:
54
+ yield torch.from_numpy(np.array(batch)).to(device), (pt, pb, pl, pr)
55
+ batch = []
56
+
57
+
58
+ VALID_REFINEMETHODS = {'animeseg', 'none'}
59
+
60
+ register_all_modules()
61
+
62
+
63
+ def single_image_preprocess(img: Union[str, np.ndarray], pipeline: Compose):
64
+ if isinstance(img, str):
65
+ img = mmcv.imread(img)
66
+ elif not isinstance(img, np.ndarray):
67
+ raise NotImplementedError
68
+
69
+ # img = square_pad_resize(img, 1024)[0]
70
+
71
+ data_ = dict(img=img, img_id=0)
72
+ data_ = pipeline(data_)
73
+ data_['inputs'] = [data_['inputs']]
74
+ data_['data_samples'] = [data_['data_samples']]
75
+
76
+ return data_, img
77
+
78
+ def animeseg_refine(det_pred: DetDataSample, img: np.ndarray, net: AnimeSegmentation, to_rgb=True, input_size: int = 1024):
79
+
80
+ num_pred = len(det_pred.pred_instances)
81
+ if num_pred < 1:
82
+ return
83
+
84
+ with torch.no_grad():
85
+ if to_rgb:
86
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
87
+ seg_thr = 0.5
88
+ mask = get_mask(net, img, s=input_size)[..., 0]
89
+ mask = (mask > seg_thr)
90
+
91
+ ins_masks = det_pred.pred_instances.masks
92
+ to_tensor = False  # ensure defined even when the masks are already numpy arrays
93
+ if isinstance(ins_masks, torch.Tensor):
94
+ tensor_device = ins_masks.device
95
+ tensor_dtype = ins_masks.dtype
96
+ to_tensor = True
97
+ ins_masks = ins_masks.cpu().numpy()
98
+
99
+ area_original = np.sum(ins_masks, axis=(1, 2))
100
+ masks_refined = np.bitwise_and(ins_masks, mask[None, ...])
101
+ area_refined = np.sum(masks_refined, axis=(1, 2))
102
+
103
+ for ii in range(num_pred):
104
+ if area_refined[ii] / area_original[ii] > 0.3:
105
+ ins_masks[ii] = masks_refined[ii]
106
+ ins_masks = np.ascontiguousarray(ins_masks)
107
+
108
+ # for ii, insm in enumerate(ins_masks):
109
+ # cv2.imwrite(f'{ii}.png', insm.astype(np.uint8) * 255)
110
+
111
+ if to_tensor:
112
+ ins_masks = torch.from_numpy(ins_masks).to(dtype=tensor_dtype).to(device=tensor_device)
113
+
114
+ det_pred.pred_instances.masks = ins_masks
115
+ # rst = np.concatenate((mask * img + 1 - mask, mask * 255), axis=2).astype(np.uint8)
116
+ # cv2.imwrite('rst.png', rst)
117
+
118
+
119
+ # def refinenet_forward(det_pred: DetDataSample, img: np.ndarray, net: AnimeSegmentation, to_rgb=True, input_size: int = 1024):
120
+
121
+ # num_pred = len(det_pred.pred_instances)
122
+ # if num_pred < 1:
123
+ # return
124
+
125
+ # with torch.no_grad():
126
+ # if to_rgb:
127
+ # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
128
+ # seg_thr = 0.5
129
+
130
+ # h0, w0 = h, w = img.shape[0], img.shape[1]
131
+ # if h > w:
132
+ # h, w = input_size, int(input_size * w / h)
133
+ # else:
134
+ # h, w = int(input_size * h / w), input_size
135
+ # ph, pw = input_size - h, input_size - w
136
+ # tmpImg = np.zeros([s, s, 3], dtype=np.float32)
137
+ # tmpImg[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w] = cv2.resize(input_img, (w, h)) / 255
138
+ # tmpImg = tmpImg.transpose((2, 0, 1))
139
+ # tmpImg = torch.from_numpy(tmpImg).unsqueeze(0).type(torch.FloatTensor).to(model.device)
140
+ # with torch.no_grad():
141
+ # if use_amp:
142
+ # with amp.autocast():
143
+ # pred = model(tmpImg)
144
+ # pred = pred.to(dtype=torch.float32)
145
+ # else:
146
+ # pred = model(tmpImg)
147
+ # pred = pred[0, :, ph // 2:ph // 2 + h, pw // 2:pw // 2 + w]
148
+ # pred = cv2.resize(pred.cpu().numpy().transpose((1, 2, 0)), (w0, h0))[:, :, np.newaxis]
149
+ # return pred
150
+
151
+ # mask = (mask > seg_thr)
152
+
153
+ # ins_masks = det_pred.pred_instances.masks
154
+
155
+ # if isinstance(ins_masks, torch.Tensor):
156
+ # tensor_device = ins_masks.device
157
+ # tensor_dtype = ins_masks.dtype
158
+ # to_tensor = True
159
+ # ins_masks = ins_masks.cpu().numpy()
160
+
161
+ # area_original = np.sum(ins_masks, axis=(1, 2))
162
+ # masks_refined = np.bitwise_and(ins_masks, mask[None, ...])
163
+ # area_refined = np.sum(masks_refined, axis=(1, 2))
164
+
165
+ # for ii in range(num_pred):
166
+ # if area_refined[ii] / area_original[ii] > 0.3:
167
+ # ins_masks[ii] = masks_refined[ii]
168
+ # ins_masks = np.ascontiguousarray(ins_masks)
169
+
170
+ # # for ii, insm in enumerate(ins_masks):
171
+ # # cv2.imwrite(f'{ii}.png', insm.astype(np.uint8) * 255)
172
+
173
+ # if to_tensor:
174
+ # ins_masks = torch.from_numpy(ins_masks).to(dtype=tensor_dtype).to(device=tensor_device)
175
+
176
+ # det_pred.pred_instances.masks = ins_masks
177
+
178
+
179
+ def read_imglst_from_txt(filep) -> List[str]:
180
+ with open(filep, 'r', encoding='utf8') as f:
181
+ lines = f.read().splitlines()
182
+ return lines
183
+
184
+
185
+ class AnimeInsSeg:
186
+
187
+ def __init__(self, ckpt: str, default_det_size: int = 640, device: str = None,
188
+ refine_kwargs: dict = {'refine_method': 'refinenet_isnet'},
189
+ tagger_path: str = 'models/wd-v1-4-swinv2-tagger-v2/model.onnx', mask_thr=0.3) -> None:
190
+ self.ckpt = ckpt
191
+ self.default_det_size = default_det_size
192
+ self.device = DEFAULT_DEVICE if device is None else device
193
+
194
+ # init detector in mmdet's way
195
+
196
+ ckpt = torch.load(ckpt, map_location='cpu')
197
+ cfg = Config.fromstring(ckpt['meta']['cfg'].replace('file_client_args', 'backend_args'), file_format='.py')
198
+ cfg.visualizer = []
199
+ cfg.vis_backends = {}
200
+ cfg.default_hooks.pop('visualization')
201
+
202
+
203
+ # self.model: SingleStageDetector = init_detector(cfg, checkpoint=None, device='cpu')
204
+ model = MODELS.build(cfg.model)
205
+ model = revert_sync_batchnorm(model)
206
+
207
+ self.model = model.to(self.device).eval()
208
+ self.model.load_state_dict(ckpt['state_dict'], strict=False)
209
+ self.model = self.model.to(self.device).eval()
210
+ self.cfg = cfg.copy()
211
+
212
+ test_pipeline = get_test_pipeline_cfg(self.cfg.copy())
213
+ test_pipeline[0].type = 'mmdet.LoadImageFromNDArray'
214
+ test_pipeline = Compose(test_pipeline)
215
+ self.default_data_pipeline = test_pipeline
216
+
217
+ self.refinenet = None
218
+ self.refinenet_animeseg: AnimeSegmentation = None
219
+ self.postprocess_refine: Callable = None
220
+
221
+ if refine_kwargs is not None:
222
+ self.set_refine_method(**refine_kwargs)
223
+
224
+ self.tagger = None
225
+ self.tagger_path = tagger_path
226
+
227
+ self.mask_thr = mask_thr
228
+
229
+ def init_tagger(self, tagger_path: str = None):
230
+ tagger_path = self.tagger_path if tagger_path is None else tagger_path
231
+ self.tagger = Tagger(tagger_path)
232
+
233
+ def infer_tags(self, instances: AnimeInstances, img: np.ndarray, infer_grey: bool = False):
234
+ if self.tagger is None:
235
+ self.init_tagger()
236
+
237
+ if infer_grey:
238
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[..., None][..., [0, 0, 0]]
239
+
240
+ num_ins = len(instances)
241
+ for ii in range(num_ins):
242
+ bbox = instances.bboxes[ii]
243
+ mask = instances.masks[ii]
244
+ if isinstance(bbox, torch.Tensor):
245
+ bbox = bbox.cpu().numpy()
246
+ mask = mask.cpu().numpy()
247
+ bbox = bbox.astype(np.int32)
248
+
249
+ crop = img[bbox[1]: bbox[3] + bbox[1], bbox[0]: bbox[2] + bbox[0]].copy()
250
+ mask = mask[bbox[1]: bbox[3] + bbox[1], bbox[0]: bbox[2] + bbox[0]]
251
+ crop[mask == 0] = 255
252
+ tags, character_tags = self.tagger.label_cv2_bgr(crop)
253
+ exclude_tags = ['simple_background', 'white_background']
254
+ valid_tags = []
255
+ for tag in tags:
256
+ if tag in exclude_tags:
257
+ continue
258
+ valid_tags.append(tag)
259
+ instances.tags[ii] = ' '.join(valid_tags)
260
+ instances.character_tags[ii] = character_tags
261
+
262
+ @torch.no_grad()
263
+ def infer_embeddings(self, imgs, det_size = None):
264
+
265
+ def hijack_bbox_mask_post_process(
266
+ self,
267
+ results,
268
+ mask_feat,
269
+ cfg,
270
+ rescale: bool = False,
271
+ with_nms: bool = True,
272
+ img_meta: Optional[dict] = None):
273
+
274
+ stride = self.prior_generator.strides[0][0]
275
+ if rescale:
276
+ assert img_meta.get('scale_factor') is not None
277
+ scale_factor = [1 / s for s in img_meta['scale_factor']]
278
+ results.bboxes = scale_boxes(results.bboxes, scale_factor)
279
+
280
+ if hasattr(results, 'score_factors'):
281
+ # TODO: Add sqrt operation in order to be consistent with
282
+ # the paper.
283
+ score_factors = results.pop('score_factors')
284
+ results.scores = results.scores * score_factors
285
+
286
+ # filter small size bboxes
287
+ if cfg.get('min_bbox_size', -1) >= 0:
288
+ w, h = get_box_wh(results.bboxes)
289
+ valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size)
290
+ if not valid_mask.all():
291
+ results = results[valid_mask]
292
+
293
+ # results.mask_feat = mask_feat
294
+ return results, mask_feat
295
+
296
+ def hijack_detector_predict(self: SingleStageDetector,
297
+ batch_inputs: torch.Tensor,
298
+ batch_data_samples: SampleList,
299
+ rescale: bool = True) -> SampleList:
300
+ x = self.extract_feat(batch_inputs)
301
+
302
+ bbox_head: RTMDetInsSepBNHeadCustom = self.bbox_head
303
+ old_postprocess = RTMDetInsSepBNHeadCustom._bbox_mask_post_process
304
+ RTMDetInsSepBNHeadCustom._bbox_mask_post_process = hijack_bbox_mask_post_process
305
+ # results_list = bbox_head.predict(
306
+ # x, batch_data_samples, rescale=rescale)
307
+
308
+ batch_img_metas = [
309
+ data_samples.metainfo for data_samples in batch_data_samples
310
+ ]
311
+
312
+ outs = bbox_head(x)
313
+
314
+ results_list = bbox_head.predict_by_feat(
315
+ *outs, batch_img_metas=batch_img_metas, rescale=rescale)
316
+
317
+ # batch_data_samples = self.add_pred_to_datasample(
318
+ # batch_data_samples, results_list)
319
+
320
+ RTMDetInsSepBNHeadCustom._bbox_mask_post_process = old_postprocess
321
+ return results_list
322
+
323
+ old_predict = SingleStageDetector.predict
324
+ SingleStageDetector.predict = hijack_detector_predict
325
+ test_pipeline, imgs, _ = self.prepare_data_pipeline(imgs, det_size)
326
+
327
+ if len(imgs) > 1:
328
+ imgs = tqdm(imgs)
329
+ model = self.model
330
+ img = imgs[0]
331
+ data_, img = test_pipeline(img)
332
+ data = model.data_preprocessor(data_, False)
333
+ instance_data, mask_feat = model(**data, mode='predict')[0]
334
+ SingleStageDetector.predict = old_predict
335
+
336
+ # print((instance_data.scores > 0.9).sum())
337
+ return img, instance_data, mask_feat
338
+
339
+ def segment_with_bboxes(self, img, bboxes: torch.Tensor, instance_data, mask_feat: torch.Tensor):
340
+ # instance_data.bboxes: x1, y1, x2, y2
341
+ maxidx = torch.argmax(instance_data.scores)
342
+ bbox = instance_data.bboxes[maxidx].cpu().numpy()
343
+ p1, p2 = (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3]))
344
+ tgt_bboxes = instance_data.bboxes
345
+
346
+ im_h, im_w = img.shape[:2]
347
+ long_side = max(im_h, im_w)
348
+ bbox_head: RTMDetInsSepBNHeadCustom = self.model.bbox_head
349
+ priors, kernels = instance_data.priors, instance_data.kernels
350
+ stride = bbox_head.prior_generator.strides[0][0]
351
+
352
+ ins_bboxes, ins_segs, scores = [], [], []
353
+ for bbox in bboxes:
354
+ bbox = torch.from_numpy(np.array([bbox])).to(tgt_bboxes.dtype).to(tgt_bboxes.device)
355
+ ioulst = box_iou(bbox, tgt_bboxes).squeeze()
356
+ matched_idx = torch.argmax(ioulst)
357
+
358
+ mask_logits = bbox_head._mask_predict_by_feat_single(
359
+ mask_feat, kernels[matched_idx][None, ...], priors[matched_idx][None, ...])
360
+
361
+ mask_logits = F.interpolate(
362
+ mask_logits.unsqueeze(0), scale_factor=stride, mode='bilinear')
363
+
364
+ mask_logits = F.interpolate(
365
+ mask_logits,
366
+ size=[long_side, long_side],
367
+ mode='bilinear',
368
+ align_corners=False)[..., :im_h, :im_w]
369
+ mask = mask_logits.sigmoid().squeeze()
370
+ mask = mask > 0.5
371
+ mask = mask.cpu().numpy()
372
+ ins_segs.append(mask)
373
+
374
+ matched_iou_score = ioulst[matched_idx]
375
+ matched_score = instance_data.scores[matched_idx]
376
+ scores.append(matched_score.cpu().item())
377
+ matched_bbox = tgt_bboxes[matched_idx]
378
+
379
+ ins_bboxes.append(matched_bbox.cpu().numpy())
380
+ # p1, p2 = (int(matched_bbox[0]), int(matched_bbox[1])), (int(matched_bbox[2]), int(matched_bbox[3]))
381
+
382
+ if len(ins_bboxes) > 0:
383
+ ins_bboxes = np.array(ins_bboxes).astype(np.int32)
384
+ ins_bboxes[:, 2:] -= ins_bboxes[:, :2]
385
+ ins_segs = np.array(ins_segs)
386
+ instances = AnimeInstances(ins_segs, ins_bboxes, scores)
387
+
388
+ self._postprocess_refine(instances, img)
389
+ drawed = instances.draw_instances(img)
390
+ # cv2.imshow('drawed', drawed)
391
+ # cv2.waitKey(0)
392
+
393
+ return instances
394
+
395
+ def set_detect_size(self, det_size: Union[int, Tuple]):
396
+ if isinstance(det_size, int):
397
+ det_size = (det_size, det_size)
398
+ self.default_data_pipeline.transforms[1].scale = det_size
399
+ self.default_data_pipeline.transforms[2].size = det_size
400
+
401
+ @torch.no_grad()
402
+ def infer(self, imgs: Union[List, str, np.ndarray],
403
+ pred_score_thr: float = 0.3,
404
+ refine_kwargs: dict = None,
405
+ output_type: str="tensor",
406
+ det_size: int = None,
407
+ save_dir: str = '',
408
+ save_visualization: bool = False,
409
+ save_annotation: str = '',
410
+ infer_tags: bool = False,
411
+ obj_id_start: int = -1,
412
+ img_id_start: int = -1,
413
+ verbose: bool = False,
414
+ infer_grey: bool = False,
415
+ save_mask_only: bool = False,
416
+ val_dir=None,
417
+ max_instances: int = 100,
418
+ **kwargs) -> Union[List[AnimeInstances], AnimeInstances, None]:
419
+
420
+ """
421
+ Args:
422
+ imgs (str, ndarray, Sequence[str/ndarray]):
423
+ Either image files or loaded images.
424
+
425
+ Returns:
426
+ :obj:`AnimeInstances` or list[:obj:`AnimeInstances`]:
427
+ If save_annotation or save_visualization is set, returns None.
428
+ """
429
+
430
+ if det_size is not None:
431
+ self.set_detect_size(det_size)
432
+ if refine_kwargs is not None:
433
+ self.set_refine_method(**refine_kwargs)
434
+
435
+ self.set_max_instance(max_instances)
436
+
437
+ if isinstance(imgs, str):
438
+ if imgs.endswith('.txt'):
439
+ imgs = read_imglst_from_txt(imgs)
440
+
441
+ if save_annotation or save_visualization:
442
+ return self._infer_save_annotations(imgs, pred_score_thr, det_size, save_dir, save_visualization, \
443
+ save_annotation, infer_tags, obj_id_start, img_id_start, val_dir=val_dir)
444
+ else:
445
+ return self._infer_simple(imgs, pred_score_thr, det_size, output_type, infer_tags, verbose=verbose, infer_grey=infer_grey)
446
+
447
+ def _det_forward(self, img, test_pipeline, pred_score_thr: float = 0.3) -> Tuple[AnimeInstances, np.ndarray]:
448
+ data_, img = test_pipeline(img)
449
+ with torch.no_grad():
450
+ results: DetDataSample = self.model.test_step(data_)[0]
451
+ pred_instances = results.pred_instances
452
+ pred_instances = pred_instances[pred_instances.scores > pred_score_thr]
453
+ if len(pred_instances) < 1:
454
+ return AnimeInstances(), img
455
+
456
+ del data_
457
+
458
+ bboxes = pred_instances.bboxes.to(torch.int32)
459
+ bboxes[:, 2:] -= bboxes[:, :2]
460
+ masks = pred_instances.masks
461
+ scores = pred_instances.scores
462
+ return AnimeInstances(masks, bboxes, scores), img
463
+
464
+ def _infer_simple(self, imgs: Union[List, str, np.ndarray],
465
+ pred_score_thr: float = 0.3,
466
+ det_size: int = None,
467
+ output_type: str = "tensor",
468
+ infer_tags: bool = False,
469
+ infer_grey: bool = False,
470
+ verbose: bool = False) -> Union[DetDataSample, List[DetDataSample]]:
471
+
472
+ if isinstance(imgs, List):
473
+ return_list = True
474
+ else:
475
+ return_list = False
476
+
477
+ assert output_type in {'tensor', 'numpy'}
478
+
479
+ test_pipeline, imgs, _ = self.prepare_data_pipeline(imgs, det_size)
480
+ predictions = []
481
+
482
+ if len(imgs) > 1:
483
+ imgs = tqdm(imgs)
484
+
485
+ for img in imgs:
486
+ instances, img = self._det_forward(img, test_pipeline, pred_score_thr)
487
+ # drawed = instances.draw_instances(img)
488
+ # cv2.imwrite('drawed.jpg', drawed)
489
+ self.postprocess_results(instances, img)
490
+ # drawed = instances.draw_instances(img)
491
+ # cv2.imwrite('drawed_post.jpg', drawed)
492
+
493
+ if infer_tags:
494
+ self.infer_tags(instances, img, infer_grey)
495
+
496
+ if output_type == 'numpy':
497
+ instances.to_numpy()
498
+
499
+ predictions.append(instances)
500
+
501
+ if return_list:
502
+ return predictions
503
+ else:
504
+ return predictions[0]
505
+
506
+ def _infer_save_annotations(self, imgs: Union[List, str, np.ndarray],
507
+ pred_score_thr: float = 0.3,
508
+ det_size: int = None,
509
+ save_dir: str = '',
510
+ save_visualization: bool = False,
511
+ save_annotation: str = '',
512
+ infer_tags: bool = False,
513
+ obj_id_start: int = 100000000000,
514
+ img_id_start: int = 100000000000,
515
+ save_mask_only: bool = False,
516
+ val_dir = None,
517
+ **kwargs) -> None:
518
+
519
+ coco_api = None
520
+ if isinstance(imgs, str) and imgs.endswith('.json'):
521
+ coco_api = COCO(imgs)
522
+
523
+ if val_dir is None:
524
+ val_dir = osp.join(osp.dirname(osp.dirname(imgs)), 'val')
525
+ imgs = coco_api.getImgIds()
526
+ imgp2ids = {}
527
+ imgps, coco_imgmetas = [], []
528
+ for imgid in imgs:
529
+ imeta = coco_api.loadImgs(imgid)[0]
530
+ imgname = imeta['file_name']
531
+ imgp = osp.join(val_dir, imgname)
532
+ imgp2ids[imgp] = imgid
533
+ imgps.append(imgp)
534
+ coco_imgmetas.append(imeta)
535
+ imgs = imgps
536
+
537
+ test_pipeline, imgs, target_dir = self.prepare_data_pipeline(imgs, det_size)
538
+ if save_dir == '':
539
+ save_dir = osp.join(target_dir, \
540
+ osp.basename(self.ckpt).replace('.ckpt', '').replace('.pth', '').replace('.pt', ''))
541
+
542
+ if not osp.exists(save_dir):
543
+ os.makedirs(save_dir)
544
+
545
+ det_annotations = []
546
+ image_meta = []
547
+ obj_id = obj_id_start + 1
548
+ image_id = img_id_start + 1
549
+
550
+ for ii, img in enumerate(tqdm(imgs)):
551
+ # prepare data
552
+ if isinstance(img, str):
553
+ img_name = osp.basename(img)
554
+ else:
555
+ img_name = f'{ii}'.zfill(12) + '.jpg'
556
+
557
+ if coco_api is not None:
558
+ image_id = imgp2ids[img]
559
+
560
+ try:
561
+ instances, img = self._det_forward(img, test_pipeline, pred_score_thr)
562
+ except Exception as e:
563
+ raise e
564
+ if isinstance(e, torch.cuda.OutOfMemoryError):
565
+ gc.collect()
566
+ torch.cuda.empty_cache()
567
+ torch.cuda.ipc_collect()
568
+ try:
569
+ instances, img = self._det_forward(img, test_pipeline, pred_score_thr)
570
+ except:
571
+ LOGGER.warning(f'cuda out of memory: {img_name}')
572
+ if isinstance(img, str):
573
+ img = cv2.imread(img)
574
+ instances = None
575
+
576
+ if instances is not None:
577
+ self.postprocess_results(instances, img)
578
+
579
+ if infer_tags:
580
+ self.infer_tags(instances, img)
581
+
582
+ if save_visualization:
583
+ out_file = osp.join(save_dir, img_name)
584
+ self.save_visualization(out_file, img, instances)
585
+
586
+ if save_annotation:
587
+ im_h, im_w = img.shape[:2]
588
+ image_meta.append({
589
+ "id": image_id,"height": im_h,"width": im_w,
590
+ "file_name": img_name, "id": image_id
591
+ })
592
+ if instances is not None:
593
+ for ii in range(len(instances)):
594
+ segmentation = instances.masks[ii].squeeze().cpu().numpy().astype(np.uint8)
595
+ area = segmentation.sum()
596
+ segmentation *= 255
597
+ if save_mask_only:
598
+ cv2.imwrite(osp.join(save_dir, 'mask_' + str(ii).zfill(3) + '_' +img_name+'.png'), segmentation)
599
+ else:
600
+ score = instances.scores[ii]
601
+ if isinstance(score, torch.Tensor):
602
+ score = score.item()
603
+ score = float(score)
604
+ bbox = instances.bboxes[ii].cpu().numpy()
605
+ bbox = bbox.astype(np.float32).tolist()
606
+ segmentation = mask2rle(segmentation)
607
+ tag_string = instances.tags[ii]
608
+ tag_string_character = instances.character_tags[ii]
609
+ det_annotations.append({'id': obj_id, 'category_id': 0, 'iscrowd': 0, 'score': score,
610
+ 'segmentation': segmentation, 'image_id': image_id, 'area': area,
611
+ 'tag_string': tag_string, 'tag_string_character': tag_string_character, 'bbox': bbox
612
+ })
613
+ obj_id += 1
614
+ image_id += 1
615
+
616
+ if save_annotation != '' and not save_mask_only:
617
+ det_meta = {"info": {},"licenses": [], "images": image_meta,
618
+ "annotations": det_annotations, "categories": CATEGORIES}
619
+ detp = save_annotation
620
+ dict2json(det_meta, detp)
621
+ LOGGER.info(f'annotations saved to {detp}')
622
+
623
+ def set_refine_method(self, refine_method: str = 'none', refine_size: int = 720):
624
+ if refine_method == 'none':
625
+ self.postprocess_refine = None
626
+ elif refine_method == 'animeseg':
627
+ if self.refinenet_animeseg is None:
628
+ self.refinenet_animeseg = load_refinenet(refine_method)
629
+ self.postprocess_refine = lambda det_pred, img: \
630
+ animeseg_refine(det_pred, img, self.refinenet_animeseg, True, refine_size)
631
+ elif refine_method == 'refinenet_isnet':
632
+ if self.refinenet is None:
633
+ self.refinenet = load_refinenet(refine_method)
634
+ self.postprocess_refine = self._postprocess_refine
635
+ else:
636
+ raise NotImplementedError(f'Invalid refine method: {refine_method}')
637
+
638
+ def _postprocess_refine(self, instances: AnimeInstances, img: np.ndarray, refine_size: int = 720, max_refine_batch: int = 4, **kwargs):
639
+
640
+ if instances.is_empty:
641
+ return
642
+
643
+ segs = instances.masks
644
+ is_tensor = instances.is_tensor
645
+ if is_tensor:
646
+ segs = segs.cpu().numpy()
647
+ segs = segs.astype(np.float32)
648
+ im_h, im_w = img.shape[:2]
649
+
650
+ masks = []
651
+ with torch.no_grad():
652
+ for batch, (pt, pb, pl, pr) in prepare_refine_batch(segs, img, max_refine_batch, self.device, refine_size):
653
+ preds = self.refinenet(batch)[0][0].sigmoid()
654
+ if pb == 0:
655
+ pb = -im_h
656
+ if pr == 0:
657
+ pr = -im_w
658
+ preds = preds[..., pt: -pb, pl: -pr]
659
+ preds = torch.nn.functional.interpolate(preds, (im_h, im_w), mode='bilinear', align_corners=True)
660
+ masks.append(preds.cpu()[:, 0])
661
+
662
+ masks = (torch.concat(masks, dim=0) > self.mask_thr).to(self.device)
663
+ if not is_tensor:
664
+ masks = masks.cpu().numpy()
665
+ instances.masks = masks
666
+
667
+
668
+ def prepare_data_pipeline(self, imgs: Union[str, np.ndarray, List], det_size: int) -> Tuple[Compose, List, str]:
669
+
670
+ if det_size is None:
671
+ det_size = self.default_det_size
672
+
673
+ target_dir = './workspace/output'
674
+ # cast imgs to a list of np.ndarray or image_file_path if necessary
675
+ if isinstance(imgs, str):
676
+ if osp.isdir(imgs):
677
+ target_dir = imgs
678
+ imgs = find_all_imgs(imgs, abs_path=True)
679
+ elif osp.isfile(imgs):
680
+ target_dir = osp.dirname(imgs)
681
+ imgs = [imgs]
682
+ elif isinstance(imgs, np.ndarray) or isinstance(imgs, str):
683
+ imgs = [imgs]
684
+ elif isinstance(imgs, List):
685
+ if len(imgs) > 0:
686
+ if isinstance(imgs[0], np.ndarray) or isinstance(imgs[0], str):
687
+ pass
688
+ else:
689
+ raise NotImplementedError
690
+ else:
691
+ raise NotImplementedError
692
+
693
+ test_pipeline = lambda img: single_image_preprocess(img, pipeline=self.default_data_pipeline)
694
+ return test_pipeline, imgs, target_dir
695
+
696
+ def save_visualization(self, out_file: str, img: np.ndarray, instances: AnimeInstances):
697
+ drawed = instances.draw_instances(img)
698
+ mmcv.imwrite(drawed, out_file)
699
+
700
+ def postprocess_results(self, results: DetDataSample, img: np.ndarray) -> None:
701
+ if self.postprocess_refine is not None:
702
+ self.postprocess_refine(results, img)
703
+
704
+ def set_mask_threshold(self, mask_thr: float):
705
+ self.model.bbox_head.test_cfg['mask_thr_binary'] = mask_thr
706
+
707
+ def set_max_instance(self, num_ins):
708
+ self.model.bbox_head.test_cfg['max_per_img'] = num_ins
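As a quick reference for the class added above, a hedged sketch of the runtime knobs AnimeInsSeg exposes. Method names are taken from the diff; `net` and `img` refer to the earlier sketch, and the booru tagger ONNX model under models/ is not part of this commit.

```python
# Assumes `net` and `img` from the earlier AnimeInsSeg sketch.
net.set_detect_size(640)                      # rescale size used by the test pipeline
net.set_mask_threshold(0.3)                   # binarization threshold for predicted masks
net.set_refine_method(refine_method='none')   # disable the post-hoc mask refinement

# infer_tags=True would additionally run the wd-v1-4 booru tagger on each instance crop,
# which requires the ONNX model referenced by tagger_path to exist locally.
instances = net.infer(img, output_type='numpy', pred_score_thr=0.3, infer_tags=False)
```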
animeinsseg/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (35.9 kB).
animeinsseg/anime_instances.py ADDED
@@ -0,0 +1,301 @@
1
+
2
+ import numpy as np
3
+ from typing import List, Union, Tuple
4
+ import torch
5
+ from utils.constants import COLOR_PALETTE
6
+ from utils.constants import get_color
7
+ import cv2
8
+
9
+ def tags2multilines(tags: Union[str, List], lw, tf, max_width):
10
+ if isinstance(tags, str):
11
+ taglist = tags.split(' ')
12
+ else:
13
+ taglist = tags
14
+
15
+ sz = cv2.getTextSize(' ', 0, lw / 3, tf)
16
+ line_height = sz[0][1]
17
+ line_width = 0
18
+ if len(taglist) > 0:
19
+ lines = [taglist[0]]
20
+ if len(taglist) > 1:
21
+ for t in taglist[1:]:
22
+ textl = len(t) * line_height
23
+ if line_width + line_height + textl > max_width:
24
+ lines.append(t)
25
+ line_width = 0
26
+ else:
27
+ line_width = line_width + line_height + textl
28
+ lines[-1] = lines[-1] + ' ' + t
29
+ return lines, line_height
30
+
31
+ class AnimeInstances:
32
+
33
+ def __init__(self,
34
+ masks: Union[np.ndarray, torch.Tensor ]= None,
35
+ bboxes: Union[np.ndarray, torch.Tensor ] = None,
36
+ scores: Union[np.ndarray, torch.Tensor ] = None,
37
+ tags: List[str] = None, character_tags: List[str] = None) -> None:
38
+ self.masks = masks
39
+ self.tags = tags
40
+ self.bboxes = bboxes
41
+
42
+
43
+ if scores is None:
44
+ scores = [1.] * len(self)
45
+ if self.is_numpy:
46
+ scores = np.array(scores)
47
+ elif self.is_tensor:
48
+ scores = torch.tensor(scores)
49
+
50
+ self.scores = scores
51
+
52
+ if tags is None:
53
+ self.tags = [''] * len(self)
54
+ self.character_tags = [''] * len(self)
55
+ else:
56
+ self.tags = tags
57
+ self.character_tags = character_tags
58
+
59
+ @property
60
+ def is_cuda(self):
61
+ if isinstance(self.masks, torch.Tensor) and self.masks.is_cuda:
62
+ return True
63
+ else:
64
+ return False
65
+
66
+ @property
67
+ def is_tensor(self):
68
+ if self.is_empty:
69
+ return False
70
+ else:
71
+ return isinstance(self.masks, torch.Tensor)
72
+
73
+ @property
74
+ def is_numpy(self):
75
+ if self.is_empty:
76
+ return True
77
+ else:
78
+ return isinstance(self.masks, np.ndarray)
79
+
80
+ @property
81
+ def is_empty(self):
82
+ return self.masks is None or len(self.masks) == 0
83
+
84
+ def remove_duplicated(self):
85
+
86
+ num_masks = len(self)
87
+ if num_masks < 2:
88
+ return
89
+
90
+ need_cvt = False
91
+ if self.is_numpy:
92
+ need_cvt = True
93
+ self.to_tensor()
94
+
95
+ mask_areas = torch.Tensor([mask.sum() for mask in self.masks])
96
+ sids = torch.argsort(mask_areas, descending=True)
97
+ sids = sids.cpu().numpy().tolist()
98
+ mask_areas = mask_areas[sids]
99
+ masks = self.masks[sids]
100
+ bboxes = self.bboxes[sids]
101
+ tags = [self.tags[sid] for sid in sids]
102
+ scores = self.scores[sids]
103
+
104
+ canvas = masks[0]
105
+
106
+ valid_ids: List = np.arange(num_masks).tolist()
107
+ for ii, mask in enumerate(masks[1:]):
108
+
109
+ mask_id = ii + 1
110
+ canvas_and = torch.bitwise_and(canvas, mask)
111
+
112
+ and_area = canvas_and.sum()
113
+ mask_area = mask_areas[mask_id]
114
+
115
+ if and_area / mask_area > 0.8:
116
+ valid_ids.remove(mask_id)
117
+ elif mask_id != num_masks - 1:
118
+ canvas = torch.bitwise_or(canvas, mask)
119
+
120
+ sids = valid_ids
121
+ self.masks = masks[sids]
122
+ self.bboxes = bboxes[sids]
123
+ self.tags = [tags[sid] for sid in sids]
124
+ self.scores = scores[sids]
125
+
126
+ if need_cvt:
127
+ self.to_numpy()
128
+
129
+ # sids =
130
+
131
+ def draw_instances(self,
132
+ img: np.ndarray,
133
+ draw_bbox: bool = True,
134
+ draw_ins_mask: bool = True,
135
+ draw_ins_contour: bool = True,
136
+ draw_tags: bool = False,
137
+ draw_indices: List = None,
138
+ mask_alpha: float = 0.4):
139
+
140
+ mask_alpha = 0.75
141
+
142
+
143
+ drawed = img.copy()
144
+
145
+ if self.is_empty:
146
+ return drawed
147
+
148
+ im_h, im_w = img.shape[:2]
149
+
150
+ mask_shape = self.masks[0].shape
151
+ if mask_shape[0] != im_h or mask_shape[1] != im_w:
152
+ drawed = cv2.resize(drawed, (mask_shape[1], mask_shape[0]), interpolation=cv2.INTER_AREA)
153
+ im_h, im_w = mask_shape[0], mask_shape[1]
154
+
155
+ if draw_indices is None:
156
+ draw_indices = list(range(len(self)))
157
+ ins_dict = {'mask': [], 'tags': [], 'score': [], 'bbox': [], 'character_tags': []}
158
+ colors = []
159
+ for idx in draw_indices:
160
+ ins = self.get_instance(idx, out_type='numpy')
161
+ for key, data in ins.items():
162
+ ins_dict[key].append(data)
163
+ colors.append(get_color(idx))
164
+
165
+ if draw_bbox:
166
+ lw = max(round(sum(drawed.shape) / 2 * 0.003), 2)
167
+ for color, bbox in zip(colors, ins_dict['bbox']):
168
+ p1, p2 = (int(bbox[0]), int(bbox[1])), (int(bbox[2] + bbox[0]), int(bbox[3] + bbox[1]))
169
+ cv2.rectangle(drawed, p1, p2, color, thickness=lw, lineType=cv2.LINE_AA)
170
+
171
+ if draw_ins_mask:
172
+ drawed = drawed.astype(np.float32)
173
+ for color, mask in zip(colors, ins_dict['mask']):
174
+ p = mask.astype(np.float32)
175
+ blend_mask = np.full((im_h, im_w, 3), color, dtype=np.float32)
176
+ alpha_msk = (mask_alpha * p)[..., None]
177
+ alpha_ori = 1 - alpha_msk
178
+ drawed = drawed * alpha_ori + alpha_msk * blend_mask
179
+ drawed = drawed.astype(np.uint8)
180
+
181
+ if draw_tags:
182
+ lw = max(round(sum(drawed.shape) / 2 * 0.002), 2)
183
+ tf = max(lw - 1, 1)
184
+ for color, tags, bbox in zip(colors, ins_dict['tags'], ins_dict['bbox']):
185
+ if not tags:
186
+ continue
187
+ lines, line_height = tags2multilines(tags, lw, tf, bbox[2])
188
+ for ii, l in enumerate(lines):
189
+ xy = (bbox[0], bbox[1] + line_height + int(line_height * 1.2 * ii))
190
+ cv2.putText(drawed, l, xy, 0, lw / 3, color, thickness=tf, lineType=cv2.LINE_AA)
191
+
192
+ # cv2.imshow('canvas', drawed)
193
+ # cv2.waitKey(0)
194
+ return drawed
195
+
196
+
197
+ def cuda(self):
198
+ if self.is_empty:
199
+ return self
200
+ self.to_tensor(device='cuda')
201
+ return self
202
+
203
+ def cpu(self):
204
+ if not self.is_tensor or not self.is_cuda:
205
+ return self
206
+ self.masks = self.masks.cpu()
207
+ self.scores = self.scores.cpu()
208
+ self.bboxes = self.bboxes.cpu()
209
+ return self
210
+
211
+ def to_tensor(self, device: str = 'cpu'):
212
+ if self.is_empty:
213
+ return self
214
+ elif self.is_tensor and self.masks.device == device:
215
+ return self
216
+ self.masks = torch.from_numpy(self.masks).to(device)
217
+ self.bboxes = torch.from_numpy(self.bboxes).to(device)
218
+ self.scores = torch.from_numpy(self.scores ).to(device)
219
+ return self
220
+
221
+ def to_numpy(self):
222
+ if self.is_numpy:
223
+ return self
224
+ if self.is_cuda:
225
+ self.masks = self.masks.cpu().numpy()
226
+ self.scores = self.scores.cpu().numpy()
227
+ self.bboxes = self.bboxes.cpu().numpy()
228
+ else:
229
+ self.masks = self.masks.numpy()
230
+ self.scores = self.scores.numpy()
231
+ self.bboxes = self.bboxes.numpy()
232
+ return self
233
+
234
+ def get_instance(self, ins_idx: int, out_type: str = None, device: str = None):
235
+ mask = self.masks[ins_idx]
236
+ tags = self.tags[ins_idx]
237
+ character_tags = self.character_tags[ins_idx]
238
+ bbox = self.bboxes[ins_idx]
239
+ score = self.scores[ins_idx]
240
+ if out_type is not None:
241
+ if out_type == 'numpy' and not self.is_numpy:
242
+ mask = mask.cpu().numpy()
243
+ bbox = bbox.cpu().numpy()
244
+ score = score.cpu().numpy()
245
+ if out_type == 'tensor' and not self.is_tensor:
246
+ mask = torch.from_numpy(mask)
247
+ bbox = torch.from_numpy(bbox)
248
+ score = torch.from_numpy(score)
249
+ if isinstance(mask, torch.Tensor) and device is not None and mask.device != device:
250
+ mask = mask.to(device)
251
+ bbox = bbox.to(device)
252
+ score = score.to(device)
253
+
254
+ return {
255
+ 'mask': mask,
256
+ 'tags': tags,
257
+ 'character_tags': character_tags,
258
+ 'bbox': bbox,
259
+ 'score': score
260
+ }
261
+
262
+ def __len__(self):
263
+ if self.is_empty:
264
+ return 0
265
+ else:
266
+ return len(self.masks)
267
+
268
+ def resize(self, h, w, mode = 'area'):
269
+ if self.is_empty:
270
+ return
271
+ if self.is_tensor:
272
+ masks = self.masks.to(torch.float).unsqueeze(1)
273
+ oh, ow = masks.shape[2], masks.shape[3]
274
+ hs, ws = h / oh, w / ow
275
+ bboxes = self.bboxes.float()
276
+ bboxes[:, ::2] *= hs
277
+ bboxes[:, 1::2] *= ws
278
+ self.bboxes = torch.round(bboxes).int()
279
+ masks = torch.nn.functional.interpolate(masks, (h, w), mode=mode)
280
+ self.masks = masks.squeeze(1) > 0.3
281
+
282
+ def compose_masks(self, output_type=None):
283
+ if self.is_empty:
284
+ return None
285
+ else:
286
+ mask = self.masks[0]
287
+ if len(self.masks) > 1:
288
+ for m in self.masks[1:]:
289
+ if self.is_numpy:
290
+ mask = np.logical_or(mask, m)
291
+ else:
292
+ mask = torch.logical_or(mask, m)
293
+ if output_type is not None:
294
+ if output_type == 'numpy' and not self.is_numpy:
295
+ mask = mask.cpu().numpy()
296
+ if output_type == 'tensor' and not self.is_tensor:
297
+ mask = torch.from_numpy(mask)
298
+ return mask
299
+
300
+
301
+
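For orientation, a small sketch of the container's post-processing helpers, compose_masks and draw_instances, as defined in the class above. It assumes an AnimeInstances result in numpy form (e.g. from net.infer(img, output_type='numpy') in the earlier sketch), so the input variables are assumptions rather than part of this commit.

```python
import cv2
import numpy as np

# `instances` and `img` are assumed to come from the earlier AnimeInsSeg sketch.
if not instances.is_empty:
    # Union of all per-instance masks as a single boolean foreground mask.
    fg = instances.compose_masks(output_type='numpy')
    cv2.imwrite('foreground.png', fg.astype(np.uint8) * 255)

    # Overlay bounding boxes and colored masks on the source image.
    vis = instances.draw_instances(img, draw_bbox=True, draw_ins_mask=True)
    cv2.imwrite('visualization.png', vis)
```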
animeinsseg/data/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # from .dataset import *
2
+ # from .syndataset import *
animeinsseg/data/dataset.py ADDED
@@ -0,0 +1,929 @@
1
+ import os.path as osp
2
+ import numpy as np
3
+ from typing import List, Optional, Sequence, Tuple, Union
4
+ import copy
5
+ from time import time
6
+ import mmcv
7
+ from mmcv.transforms import to_tensor
8
+ from mmdet.datasets.transforms import LoadAnnotations, RandomCrop, PackDetInputs, Mosaic, CachedMosaic, CachedMixUp, FilterAnnotations
9
+ from mmdet.structures.mask import BitmapMasks, PolygonMasks
10
+ from mmdet.datasets import CocoDataset
11
+ from mmdet.registry import DATASETS, TRANSFORMS
12
+ from numpy import random
13
+ from mmdet.structures.bbox import autocast_box_type, BaseBoxes
14
+ from mmengine.structures import InstanceData, PixelData
15
+ from mmdet.structures import DetDataSample
16
+ from utils.io_utils import bbox_overlap_xy
17
+ from utils.logger import LOGGER
18
+
19
+ @DATASETS.register_module()
20
+ class AnimeMangaMixedDataset(CocoDataset):
21
+
22
+ def __init__(self, animeins_root: str = None, animeins_annfile: str = None, manga109_annfile: str = None, manga109_root: str = None, *args, **kwargs) -> None:
23
+ self.animeins_annfile = animeins_annfile
24
+ self.animeins_root = animeins_root
25
+ self.manga109_annfile = manga109_annfile
26
+ self.manga109_root = manga109_root
27
+ self.cat_ids = []
28
+ self.cat_img_map = {}
29
+ super().__init__(*args, **kwargs)
30
+ LOGGER.info(f'total num data: {len(self.data_list)}')
31
+
32
+
33
+ def parse_data_info(self, raw_data_info: dict, data_prefix: str) -> Union[dict, List[dict]]:
34
+ """Parse raw annotation to target format.
35
+
36
+ Args:
37
+ raw_data_info (dict): Raw data information load from ``ann_file``
38
+
39
+ Returns:
40
+ Union[dict, List[dict]]: Parsed annotation.
41
+ """
42
+ img_info = raw_data_info['raw_img_info']
43
+ ann_info = raw_data_info['raw_ann_info']
44
+
45
+ data_info = {}
46
+
47
+ # TODO: need to change data_prefix['img'] to data_prefix['img_path']
48
+ img_path = osp.join(data_prefix, img_info['file_name'])
49
+ if self.data_prefix.get('seg', None):
50
+ seg_map_path = osp.join(
51
+ self.data_prefix['seg'],
52
+ img_info['file_name'].rsplit('.', 1)[0] + self.seg_map_suffix)
53
+ else:
54
+ seg_map_path = None
55
+ data_info['img_path'] = img_path
56
+ data_info['img_id'] = img_info['img_id']
57
+ data_info['seg_map_path'] = seg_map_path
58
+ data_info['height'] = img_info['height']
59
+ data_info['width'] = img_info['width']
60
+
61
+ instances = []
62
+ for i, ann in enumerate(ann_info):
63
+ instance = {}
64
+
65
+ if ann.get('ignore', False):
66
+ continue
67
+ x1, y1, w, h = ann['bbox']
68
+ inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
69
+ inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
70
+ if inter_w * inter_h == 0:
71
+ continue
72
+ if ann['area'] <= 0 or w < 1 or h < 1:
73
+ continue
74
+ if ann['category_id'] not in self.cat_ids:
75
+ continue
76
+ bbox = [x1, y1, x1 + w, y1 + h]
77
+
78
+ if ann.get('iscrowd', False):
79
+ instance['ignore_flag'] = 1
80
+ else:
81
+ instance['ignore_flag'] = 0
82
+ instance['bbox'] = bbox
83
+ instance['bbox_label'] = self.cat2label[ann['category_id']]
84
+
85
+ if ann.get('segmentation', None):
86
+ instance['mask'] = ann['segmentation']
87
+
88
+ instances.append(instance)
89
+ data_info['instances'] = instances
90
+ return data_info
91
+
92
+
93
+ def load_data_list(self) -> List[dict]:
94
+ data_lst = []
95
+ if self.manga109_root is not None:
96
+ data_lst += self._data_list(self.manga109_annfile, osp.join(self.manga109_root, 'images'))
97
+ # if len(data_lst) > 8000:
98
+ # data_lst = data_lst[:500]
99
+ LOGGER.info(f'num data from manga109: {len(data_lst)}')
100
+ if self.animeins_root is not None:
101
+ animeins_annfile = osp.join(self.animeins_root, self.animeins_annfile)
102
+ data_prefix = osp.join(self.animeins_root, self.data_prefix['img'])
103
+ anime_lst = self._data_list(animeins_annfile, data_prefix)
104
+ # if len(anime_lst) > 8000:
105
+ # anime_lst = anime_lst[:500]
106
+ data_lst += anime_lst
107
+ LOGGER.info(f'num data from animeins: {len(data_lst)}')
108
+ return data_lst
109
+
110
+ def _data_list(self, annfile: str, data_prefix: str) -> List[dict]:
111
+ """Load annotations from an annotation file named as ``ann_file``
112
+
113
+ Returns:
114
+ List[dict]: A list of annotation.
115
+ """ # noqa: E501
116
+ with self.file_client.get_local_path(annfile) as local_path:
117
+ self.coco = self.COCOAPI(local_path)
118
+ # The order of returned `cat_ids` will not
119
+ # change with the order of the `classes`
120
+ self.cat_ids = self.coco.get_cat_ids(
121
+ cat_names=self.metainfo['classes'])
122
+ self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
123
+ cat_img_map = copy.deepcopy(self.coco.cat_img_map)
124
+ for key, val in cat_img_map.items():
125
+ if key in self.cat_img_map:
126
+ self.cat_img_map[key] += val
127
+ else:
128
+ self.cat_img_map[key] = val
129
+
130
+ img_ids = self.coco.get_img_ids()
131
+ data_list = []
132
+ total_ann_ids = []
133
+ for img_id in img_ids:
134
+ raw_img_info = self.coco.load_imgs([img_id])[0]
135
+ raw_img_info['img_id'] = img_id
136
+
137
+ ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
138
+ raw_ann_info = self.coco.load_anns(ann_ids)
139
+ total_ann_ids.extend(ann_ids)
140
+
141
+ parsed_data_info = self.parse_data_info({
142
+ 'raw_ann_info':
143
+ raw_ann_info,
144
+ 'raw_img_info':
145
+ raw_img_info
146
+ }, data_prefix)
147
+ data_list.append(parsed_data_info)
148
+ if self.ANN_ID_UNIQUE:
149
+ assert len(set(total_ann_ids)) == len(
150
+ total_ann_ids
151
+ ), f"Annotation ids in '{annfile}' are not unique!"
152
+
153
+ del self.coco
154
+
155
+ return data_list
156
+
157
+
158
+
159
+ @TRANSFORMS.register_module()
160
+ class LoadAnnotationsNoSegs(LoadAnnotations):
161
+
162
+ def _process_masks(self, results: dict) -> list:
163
+ """Process gt_masks and filter invalid polygons.
164
+
165
+ Args:
166
+ results (dict): Result dict from :obj:``mmengine.BaseDataset``.
167
+
168
+ Returns:
169
+ list: Processed gt_masks.
170
+ """
171
+ gt_masks = []
172
+ gt_ignore_flags = []
173
+ gt_ignore_mask_flags = []
174
+ for instance in results.get('instances', []):
175
+ gt_mask = instance['mask']
176
+ ignore_mask = False
177
+ # If the annotation of segmentation mask is invalid,
178
+ # ignore the whole instance.
179
+ if isinstance(gt_mask, list):
180
+ gt_mask = [
181
+ np.array(polygon) for polygon in gt_mask
182
+ if len(polygon) % 2 == 0 and len(polygon) >= 6
183
+ ]
184
+ if len(gt_mask) == 0:
185
+ # ignore this instance and set gt_mask to a fake mask
186
+ instance['ignore_flag'] = 1
187
+ gt_mask = [np.zeros(6)]
188
+ elif not self.poly2mask:
189
+ # `PolygonMasks` requires a ploygon of format List[np.array],
190
+ # other formats are invalid.
191
+ instance['ignore_flag'] = 1
192
+ gt_mask = [np.zeros(6)]
193
+ elif isinstance(gt_mask, dict) and \
194
+ not (gt_mask.get('counts') is not None and
195
+ gt_mask.get('size') is not None and
196
+ isinstance(gt_mask['counts'], (list, str))):
197
+ # if gt_mask is a dict, it should include `counts` and `size`,
198
+ # so that `BitmapMasks` can uncompressed RLE
199
+ # instance['ignore_flag'] = 1
200
+ ignore_mask = True
201
+ gt_mask = [np.zeros(6)]
202
+ gt_masks.append(gt_mask)
203
+ # re-process gt_ignore_flags
204
+ gt_ignore_flags.append(instance['ignore_flag'])
205
+ gt_ignore_mask_flags.append(ignore_mask)
206
+ results['gt_ignore_flags'] = np.array(gt_ignore_flags, dtype=bool)
207
+ results['gt_ignore_mask_flags'] = np.array(gt_ignore_mask_flags, dtype=bool)
208
+ return gt_masks
209
+
210
+ def _load_masks(self, results: dict) -> None:
211
+ """Private function to load mask annotations.
212
+
213
+ Args:
214
+ results (dict): Result dict from :obj:``mmengine.BaseDataset``.
215
+ """
216
+ h, w = results['ori_shape']
217
+ gt_masks = self._process_masks(results)
218
+ if self.poly2mask:
219
+ p2masks = []
220
+ if len(gt_masks) > 0:
221
+ for ins, mask, ignore_mask in zip(results['instances'], gt_masks, results['gt_ignore_mask_flags']):
222
+ bbox = [int(c) for c in ins['bbox']]
223
+ if ignore_mask:
224
+ m = np.zeros((h, w), dtype=np.uint8)
225
+ m[bbox[1]:bbox[3], bbox[0]: bbox[2]] = 255
226
+ # m[bbox[1]:bbox[3], bbox[0]: bbox[2]]
227
+ p2masks.append(m)
228
+ else:
229
+ p2masks.append(self._poly2mask(mask, h, w))
230
+ # import cv2
231
+ # # cv2.imwrite('tmp_mask.png', p2masks[-1] * 255)
232
+ # cv2.imwrite('tmp_img.png', results['img'])
233
+ # cv2.imwrite('tmp_bbox.png', m * 225)
234
+ # print(p2masks[-1].shape, p2masks[-1].dtype)
235
+ gt_masks = BitmapMasks(p2masks, h, w)
236
+ else:
237
+ # fake polygon masks will be ignored in `PackDetInputs`
238
+ gt_masks = PolygonMasks([mask for mask in gt_masks], h, w)
239
+ results['gt_masks'] = gt_masks
240
+
241
+ def transform(self, results: dict) -> dict:
242
+ """Function to load multiple types annotations.
243
+
244
+ Args:
245
+ results (dict): Result dict from :obj:``mmengine.BaseDataset``.
246
+
247
+ Returns:
248
+ dict: The dict containing the loaded bounding boxes, labels and
249
+ semantic segmentation.
250
+ """
251
+
252
+ if self.with_bbox:
253
+ self._load_bboxes(results)
254
+ if self.with_label:
255
+ self._load_labels(results)
256
+ if self.with_mask:
257
+ self._load_masks(results)
258
+ if self.with_seg:
259
+ self._load_seg_map(results)
260
+
261
+ return results
262
+
263
+
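+ # Usage sketch (assumption, not taken from this repo's configs): being
+ # registered in TRANSFORMS, the loader above is presumably enabled from an
+ # mmdet-style pipeline, e.g.
+ #   train_pipeline = [
+ #       dict(type='LoadImageFromFile'),
+ #       dict(type='LoadAnnotationsNoSegs', with_bbox=True,
+ #            with_mask=True, poly2mask=True),
+ #   ]
+ # Instances with an invalid or missing RLE mask keep their bbox but are
+ # marked in `gt_ignore_mask_flags`; `_load_masks` substitutes a bbox-filled
+ # bitmap for them instead of dropping the instance.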
264
+
265
+ @TRANSFORMS.register_module()
266
+ class PackDetIputsNoSeg(PackDetInputs):
267
+
268
+ mapping_table = {
269
+ 'gt_bboxes': 'bboxes',
270
+ 'gt_bboxes_labels': 'labels',
271
+ 'gt_ignore_mask_flags': 'ignore_mask',
272
+ 'gt_masks': 'masks'
273
+ }
274
+
275
+ def transform(self, results: dict) -> dict:
276
+ """Method to pack the input data.
277
+
278
+ Args:
279
+ results (dict): Result dict from the data pipeline.
280
+
281
+ Returns:
282
+ dict:
283
+
284
+ - 'inputs' (obj:`torch.Tensor`): The forward data of models.
285
+ - 'data_sample' (obj:`DetDataSample`): The annotation info of the
286
+ sample.
287
+ """
288
+ packed_results = dict()
289
+ if 'img' in results:
290
+ img = results['img']
291
+ if len(img.shape) < 3:
292
+ img = np.expand_dims(img, -1)
293
+ img = np.ascontiguousarray(img.transpose(2, 0, 1))
294
+ packed_results['inputs'] = to_tensor(img)
295
+
296
+ if 'gt_ignore_flags' in results:
297
+ valid_idx = np.where(results['gt_ignore_flags'] == 0)[0]
298
+ ignore_idx = np.where(results['gt_ignore_flags'] == 1)[0]
299
+
300
+ data_sample = DetDataSample()
301
+ instance_data = InstanceData()
302
+ ignore_instance_data = InstanceData()
303
+
304
+ for key in self.mapping_table.keys():
305
+ if key not in results:
306
+ continue
307
+ if key == 'gt_masks' or isinstance(results[key], BaseBoxes):
308
+ if 'gt_ignore_flags' in results:
309
+ instance_data[
310
+ self.mapping_table[key]] = results[key][valid_idx]
311
+ ignore_instance_data[
312
+ self.mapping_table[key]] = results[key][ignore_idx]
313
+ else:
314
+ instance_data[self.mapping_table[key]] = results[key]
315
+ else:
316
+ if 'gt_ignore_flags' in results:
317
+ instance_data[self.mapping_table[key]] = to_tensor(
318
+ results[key][valid_idx])
319
+ ignore_instance_data[self.mapping_table[key]] = to_tensor(
320
+ results[key][ignore_idx])
321
+ else:
322
+ instance_data[self.mapping_table[key]] = to_tensor(
323
+ results[key])
324
+ data_sample.gt_instances = instance_data
325
+ data_sample.ignored_instances = ignore_instance_data
326
+
327
+ if 'proposals' in results:
328
+ proposals = InstanceData(
329
+ bboxes=to_tensor(results['proposals']),
330
+ scores=to_tensor(results['proposals_scores']))
331
+ data_sample.proposals = proposals
332
+
333
+ if 'gt_seg_map' in results:
334
+ gt_sem_seg_data = dict(
335
+ sem_seg=to_tensor(results['gt_seg_map'][None, ...].copy()))
336
+ data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data)
337
+
338
+ img_meta = {}
339
+ for key in self.meta_keys:
340
+ assert key in results, f'`{key}` is not found in `results`, ' \
341
+ f'the valid keys are {list(results)}.'
342
+ img_meta[key] = results[key]
343
+
344
+ data_sample.set_metainfo(img_meta)
345
+ packed_results['data_samples'] = data_sample
346
+
347
+ return packed_results
348
+
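+ # Note on `PackDetIputsNoSeg` above: compared with the stock `PackDetInputs`,
+ # its `mapping_table` additionally packs `gt_ignore_mask_flags` into
+ # `data_sample.gt_instances.ignore_mask` (and into `ignored_instances` for
+ # ignored boxes), so a mask head can presumably down-weight or skip mask
+ # supervision for instances that only carry a bbox-derived mask.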
349
+
350
+
351
+ def translate_bitmapmask(bitmap_masks: BitmapMasks,
352
+ out_shape,
353
+ offset_x,
354
+ offset_y,):
355
+
356
+ if len(bitmap_masks.masks) == 0:
357
+ translated_masks = np.empty((0, *out_shape), dtype=np.uint8)
358
+ else:
359
+ masks = bitmap_masks.masks
360
+ out_h, out_w = out_shape
361
+ mask_h, mask_w = masks.shape[1:]
362
+
363
+ translated_masks = np.zeros((masks.shape[0], *out_shape),
364
+ dtype=masks.dtype)
365
+
366
+ ix, iy = bbox_overlap_xy([0, 0, out_w, out_h], [offset_x, offset_y, mask_w, mask_h])
367
+ if ix > 2 and iy > 2:
368
+ if offset_x > 0:
369
+ mx1 = 0
370
+ tx1 = offset_x
371
+ else:
372
+ mx1 = -offset_x
373
+ tx1 = 0
374
+ mx2 = min(out_w - offset_x, mask_w)
375
+ tx2 = tx1 + mx2 - mx1
376
+
377
+ if offset_y > 0:
378
+ my1 = 0
379
+ ty1 = offset_y
380
+ else:
381
+ my1 = -offset_y
382
+ ty1 = 0
383
+ my2 = min(out_h - offset_y, mask_h)
384
+ ty2 = ty1 + my2 - my1
385
+
386
+ translated_masks[:, ty1: ty2, tx1: tx2] = \
387
+ masks[:, my1: my2, mx1: mx2]
388
+
389
+ return BitmapMasks(translated_masks, *out_shape)
390
+
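+ # Worked example (sketch, assuming `bbox_overlap_xy` returns the x/y overlap
+ # extents in pixels): for masks of shape (1, 10, 10), out_shape=(20, 20),
+ # offset_x=5 and offset_y=3 the overlap is 10x10, so the result satisfies
+ #   translated.masks[:, 3:13, 5:15] == bitmap_masks.masks[:, 0:10, 0:10]
+ # Negative offsets crop the mask instead, and overlaps of 2 px or less in
+ # either direction leave the output all zeros. Unlike `BitmapMasks.translate`,
+ # this helper shifts both axes in a single call.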
391
+
392
+ @TRANSFORMS.register_module()
393
+ class CachedMosaicNoSeg(CachedMosaic):
394
+
395
+ @autocast_box_type()
396
+ def transform(self, results: dict) -> dict:
397
+
398
+ """Mosaic transform function.
399
+
400
+ Args:
401
+ results (dict): Result dict.
402
+
403
+ Returns:
404
+ dict: Updated result dict.
405
+ """
406
+ # cache and pop images
407
+ self.results_cache.append(copy.deepcopy(results))
408
+ if len(self.results_cache) > self.max_cached_images:
409
+ if self.random_pop:
410
+ index = random.randint(0, len(self.results_cache) - 1)
411
+ else:
412
+ index = 0
413
+ self.results_cache.pop(index)
414
+
415
+ if len(self.results_cache) <= 4:
416
+ return results
417
+
418
+ if random.uniform(0, 1) > self.prob:
419
+ return results
420
+ indices = self.get_indexes(self.results_cache)
421
+ mix_results = [copy.deepcopy(self.results_cache[i]) for i in indices]
422
+
423
+ # TODO: refactor mosaic to reuse these code.
424
+ mosaic_bboxes = []
425
+ mosaic_bboxes_labels = []
426
+ mosaic_ignore_flags = []
427
+ mosaic_masks = []
428
+ mosaic_ignore_mask_flags = []
429
+ with_mask = True if 'gt_masks' in results else False
430
+
431
+ if len(results['img'].shape) == 3:
432
+ mosaic_img = np.full(
433
+ (int(self.img_scale[1] * 2), int(self.img_scale[0] * 2), 3),
434
+ self.pad_val,
435
+ dtype=results['img'].dtype)
436
+ else:
437
+ mosaic_img = np.full(
438
+ (int(self.img_scale[1] * 2), int(self.img_scale[0] * 2)),
439
+ self.pad_val,
440
+ dtype=results['img'].dtype)
441
+
442
+ # mosaic center x, y
443
+ center_x = int(
444
+ random.uniform(*self.center_ratio_range) * self.img_scale[0])
445
+ center_y = int(
446
+ random.uniform(*self.center_ratio_range) * self.img_scale[1])
447
+ center_position = (center_x, center_y)
448
+
449
+ loc_strs = ('top_left', 'top_right', 'bottom_left', 'bottom_right')
450
+
451
+ n_manga = 0
452
+ for i, loc in enumerate(loc_strs):
453
+ if loc == 'top_left':
454
+ results_patch = copy.deepcopy(results)
455
+ else:
456
+ results_patch = copy.deepcopy(mix_results[i - 1])
457
+
458
+ is_manga = results_patch['img_id'] > 900000000
459
+ if is_manga:
460
+ n_manga += 1
461
+ if n_manga > 3:
462
+ continue
463
+ im_h, im_w = results_patch['img'].shape[:2]
464
+ if im_w > im_h and random.random() < 0.75:
465
+ results_patch = hcrop(results_patch, (im_h, im_w // 2), True)
466
+
467
+ img_i = results_patch['img']
468
+ h_i, w_i = img_i.shape[:2]
469
+ # keep_ratio resize
470
+ scale_ratio_i = min(self.img_scale[1] / h_i,
471
+ self.img_scale[0] / w_i)
472
+ img_i = mmcv.imresize(
473
+ img_i, (int(w_i * scale_ratio_i), int(h_i * scale_ratio_i)))
474
+
475
+ # compute the combine parameters
476
+ paste_coord, crop_coord = self._mosaic_combine(
477
+ loc, center_position, img_i.shape[:2][::-1])
478
+ x1_p, y1_p, x2_p, y2_p = paste_coord
479
+ x1_c, y1_c, x2_c, y2_c = crop_coord
480
+
481
+ # crop and paste image
482
+ mosaic_img[y1_p:y2_p, x1_p:x2_p] = img_i[y1_c:y2_c, x1_c:x2_c]
483
+
484
+ # adjust coordinate
485
+ gt_bboxes_i = results_patch['gt_bboxes']
486
+ gt_bboxes_labels_i = results_patch['gt_bboxes_labels']
487
+ gt_ignore_flags_i = results_patch['gt_ignore_flags']
488
+ gt_ignore_mask_i = results_patch['gt_ignore_mask_flags']
489
+
490
+ padw = x1_p - x1_c
491
+ padh = y1_p - y1_c
492
+ gt_bboxes_i.rescale_([scale_ratio_i, scale_ratio_i])
493
+ gt_bboxes_i.translate_([padw, padh])
494
+ mosaic_bboxes.append(gt_bboxes_i)
495
+ mosaic_bboxes_labels.append(gt_bboxes_labels_i)
496
+ mosaic_ignore_flags.append(gt_ignore_flags_i)
497
+ mosaic_ignore_mask_flags.append(gt_ignore_mask_i)
498
+ if with_mask and results_patch.get('gt_masks', None) is not None:
499
+
500
+ gt_masks_i = results_patch['gt_masks']
501
+ gt_masks_i = gt_masks_i.rescale(float(scale_ratio_i))
502
+
503
+ gt_masks_i = translate_bitmapmask(gt_masks_i,
504
+ out_shape=(int(self.img_scale[0] * 2),
505
+ int(self.img_scale[1] * 2)),
506
+ offset_x=padw, offset_y=padh)
507
+
508
+ # gt_masks_i = gt_masks_i.translate(
509
+ # out_shape=(int(self.img_scale[0] * 2),
510
+ # int(self.img_scale[1] * 2)),
511
+ # offset=padw,
512
+ # direction='horizontal')
513
+ # gt_masks_i = gt_masks_i.translate(
514
+ # out_shape=(int(self.img_scale[0] * 2),
515
+ # int(self.img_scale[1] * 2)),
516
+ # offset=padh,
517
+ # direction='vertical')
518
+ mosaic_masks.append(gt_masks_i)
519
+
520
+ mosaic_bboxes = mosaic_bboxes[0].cat(mosaic_bboxes, 0)
521
+ mosaic_bboxes_labels = np.concatenate(mosaic_bboxes_labels, 0)
522
+ mosaic_ignore_flags = np.concatenate(mosaic_ignore_flags, 0)
523
+ mosaic_ignore_mask_flags = np.concatenate(mosaic_ignore_mask_flags, 0)
524
+
525
+ if self.bbox_clip_border:
526
+ mosaic_bboxes.clip_([2 * self.img_scale[1], 2 * self.img_scale[0]])
527
+ # remove outside bboxes
528
+ inside_inds = mosaic_bboxes.is_inside(
529
+ [2 * self.img_scale[1], 2 * self.img_scale[0]]).numpy()
530
+
531
+ mosaic_bboxes = mosaic_bboxes[inside_inds]
532
+ mosaic_bboxes_labels = mosaic_bboxes_labels[inside_inds]
533
+ mosaic_ignore_flags = mosaic_ignore_flags[inside_inds]
534
+ mosaic_ignore_mask_flags = mosaic_ignore_mask_flags[inside_inds]
535
+
536
+ results['img'] = mosaic_img
537
+ results['img_shape'] = mosaic_img.shape
538
+ results['gt_bboxes'] = mosaic_bboxes
539
+ results['gt_bboxes_labels'] = mosaic_bboxes_labels
540
+ results['gt_ignore_flags'] = mosaic_ignore_flags
541
+ results['gt_ignore_mask_flags'] = mosaic_ignore_mask_flags
542
+
543
+
544
+ if with_mask:
545
+ total_instances = len(inside_inds)
546
+ assert total_instances == np.array([m.masks.shape[0] for m in mosaic_masks]).sum()
547
+ if total_instances > 10:
548
+ masks = np.empty((inside_inds.sum(), mosaic_masks[0].height, mosaic_masks[0].width), dtype=np.uint8)
549
+ msk_idx = 0
550
+ mmsk_idx = 0
551
+ for m in mosaic_masks:
552
+ for ii in range(m.masks.shape[0]):
553
+ if inside_inds[msk_idx]:
554
+ masks[mmsk_idx] = m.masks[ii]
555
+ mmsk_idx += 1
556
+ msk_idx += 1
557
+ results['gt_masks'] = BitmapMasks(masks, mosaic_masks[0].height, mosaic_masks[0].width)
558
+ else:
559
+ mosaic_masks = mosaic_masks[0].cat(mosaic_masks)
560
+ results['gt_masks'] = mosaic_masks[inside_inds]
561
+ # assert np.all(results['gt_masks'].masks == masks) and results['gt_masks'].masks.shape == masks.shape
562
+
563
+ # assert inside_inds.sum() == results['gt_masks'].masks.shape[0]
564
+ return results
565
+
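+ # Note on `CachedMosaicNoSeg` above (behaviour as read from the code): manga
+ # pages are recognised by an `img_id > 900000000` convention, at most three of
+ # the four mosaic tiles may come from manga, and a landscape manga page is
+ # apparently cropped to one half of its width via `hcrop` below with
+ # probability 0.75 before being resized and pasted. Ignore-mask flags travel
+ # with the boxes, and masks are shifted with `translate_bitmapmask`.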
566
+ @TRANSFORMS.register_module()
567
+ class FilterAnnotationsNoSeg(FilterAnnotations):
568
+
569
+ def __init__(self,
570
+ min_gt_bbox_wh: Tuple[int, int] = (1, 1),
571
+ min_gt_mask_area: int = 1,
572
+ by_box: bool = True,
573
+ by_mask: bool = False,
574
+ keep_empty: bool = True) -> None:
575
+ # TODO: add more filter options
576
+ assert by_box or by_mask
577
+ self.min_gt_bbox_wh = min_gt_bbox_wh
578
+ self.min_gt_mask_area = min_gt_mask_area
579
+ self.by_box = by_box
580
+ self.by_mask = by_mask
581
+ self.keep_empty = keep_empty
582
+
583
+ @autocast_box_type()
584
+ def transform(self, results: dict) -> Union[dict, None]:
585
+ """Transform function to filter annotations.
586
+
587
+ Args:
588
+ results (dict): Result dict.
589
+
590
+ Returns:
591
+ dict: Updated result dict.
592
+ """
593
+ assert 'gt_bboxes' in results
594
+ gt_bboxes = results['gt_bboxes']
595
+ if gt_bboxes.shape[0] == 0:
596
+ return results
597
+
598
+ tests = []
599
+ if self.by_box:
600
+ tests.append(
601
+ ((gt_bboxes.widths > self.min_gt_bbox_wh[0]) &
602
+ (gt_bboxes.heights > self.min_gt_bbox_wh[1])).numpy())
603
+
604
+ if self.by_mask:
605
+ assert 'gt_masks' in results
606
+ gt_masks = results['gt_masks']
607
+ tests.append(gt_masks.areas >= self.min_gt_mask_area)
608
+
609
+ keep = tests[0]
610
+ for t in tests[1:]:
611
+ keep = keep & t
612
+
613
+ # if not keep.any():
614
+ # if self.keep_empty:
615
+ # return None
616
+
617
+ assert len(results['gt_ignore_flags']) == len(results['gt_ignore_mask_flags'])
618
+ keys = ('gt_bboxes', 'gt_bboxes_labels', 'gt_masks', 'gt_ignore_flags', 'gt_ignore_mask_flags')
619
+ for key in keys:
620
+ if key in results:
621
+ try:
622
+ results[key] = results[key][keep]
623
+ except Exception as e:
624
+ raise e
625
+
626
+ return results
627
+
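+ # Note on `FilterAnnotationsNoSeg` above: unlike the stock `FilterAnnotations`
+ # it never discards the whole sample when nothing passes the size test (the
+ # `keep_empty` early return is commented out), and it filters
+ # `gt_ignore_mask_flags` alongside the other per-instance keys.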
628
+
629
+ def hcrop(results: dict, crop_size: Tuple[int, int],
630
+ allow_negative_crop: bool) -> Union[dict, None]:
631
+
632
+ assert crop_size[0] > 0 and crop_size[1] > 0
633
+ img = results['img']
634
+ offset_h, offset_w = 0, random.choice([0, crop_size[1]])
635
+ crop_y1, crop_y2 = offset_h, offset_h + crop_size[0]
636
+ crop_x1, crop_x2 = offset_w, offset_w + crop_size[1]
637
+
638
+ # Record the homography matrix for the RandomCrop
639
+ homography_matrix = np.array(
640
+ [[1, 0, -offset_w], [0, 1, -offset_h], [0, 0, 1]],
641
+ dtype=np.float32)
642
+ if results.get('homography_matrix', None) is None:
643
+ results['homography_matrix'] = homography_matrix
644
+ else:
645
+ results['homography_matrix'] = homography_matrix @ results[
646
+ 'homography_matrix']
647
+
648
+ # crop the image
649
+ img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]
650
+ img_shape = img.shape
651
+ results['img'] = img
652
+ results['img_shape'] = img_shape
653
+
654
+ # crop bboxes accordingly and clip to the image boundary
655
+ if results.get('gt_bboxes', None) is not None:
656
+ bboxes = results['gt_bboxes']
657
+ bboxes.translate_([-offset_w, -offset_h])
658
+ bboxes.clip_(img_shape[:2])
659
+ valid_inds = bboxes.is_inside(img_shape[:2]).numpy()
660
+ # If the crop does not contain any gt-bbox area and
661
+ # allow_negative_crop is False, skip this image.
662
+ if (not valid_inds.any() and not allow_negative_crop):
663
+ return None
664
+
665
+ results['gt_bboxes'] = bboxes[valid_inds]
666
+
667
+ if results.get('gt_ignore_flags', None) is not None:
668
+ results['gt_ignore_flags'] = \
669
+ results['gt_ignore_flags'][valid_inds]
670
+
671
+ if results.get('gt_ignore_mask_flags', None) is not None:
672
+ results['gt_ignore_mask_flags'] = \
673
+ results['gt_ignore_mask_flags'][valid_inds]
674
+
675
+ if results.get('gt_bboxes_labels', None) is not None:
676
+ results['gt_bboxes_labels'] = \
677
+ results['gt_bboxes_labels'][valid_inds]
678
+
679
+ if results.get('gt_masks', None) is not None:
680
+ results['gt_masks'] = results['gt_masks'][
681
+ valid_inds.nonzero()[0]].crop(
682
+ np.asarray([crop_x1, crop_y1, crop_x2, crop_y2]))
683
+ results['gt_bboxes'] = results['gt_masks'].get_bboxes(
684
+ type(results['gt_bboxes']))
685
+
686
+ # crop semantic seg
687
+ if results.get('gt_seg_map', None) is not None:
688
+ results['gt_seg_map'] = results['gt_seg_map'][crop_y1:crop_y2,
689
+ crop_x1:crop_x2]
690
+
691
+ return results
692
+
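+ # Note on `hcrop` above: called with crop_size=(im_h, im_w // 2) it keeps
+ # either the left or the right half of the image (`offset_w` is randomly 0 or
+ # im_w // 2), which is how `CachedMosaicNoSeg` splits wide, presumably
+ # double-page, manga images before building the mosaic.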
693
+
694
+ @TRANSFORMS.register_module()
695
+ class RandomCropNoSeg(RandomCrop):
696
+
697
+ def _crop_data(self, results: dict, crop_size: Tuple[int, int],
698
+ allow_negative_crop: bool) -> Union[dict, None]:
699
+
700
+ assert crop_size[0] > 0 and crop_size[1] > 0
701
+ img = results['img']
702
+ margin_h = max(img.shape[0] - crop_size[0], 0)
703
+ margin_w = max(img.shape[1] - crop_size[1], 0)
704
+ offset_h, offset_w = self._rand_offset((margin_h, margin_w))
705
+ crop_y1, crop_y2 = offset_h, offset_h + crop_size[0]
706
+ crop_x1, crop_x2 = offset_w, offset_w + crop_size[1]
707
+
708
+ # Record the homography matrix for the RandomCrop
709
+ homography_matrix = np.array(
710
+ [[1, 0, -offset_w], [0, 1, -offset_h], [0, 0, 1]],
711
+ dtype=np.float32)
712
+ if results.get('homography_matrix', None) is None:
713
+ results['homography_matrix'] = homography_matrix
714
+ else:
715
+ results['homography_matrix'] = homography_matrix @ results[
716
+ 'homography_matrix']
717
+
718
+ # crop the image
719
+ img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]
720
+ img_shape = img.shape
721
+ results['img'] = img
722
+ results['img_shape'] = img_shape
723
+
724
+ # crop bboxes accordingly and clip to the image boundary
725
+ if results.get('gt_bboxes', None) is not None:
726
+ bboxes = results['gt_bboxes']
727
+ bboxes.translate_([-offset_w, -offset_h])
728
+ if self.bbox_clip_border:
729
+ bboxes.clip_(img_shape[:2])
730
+ valid_inds = bboxes.is_inside(img_shape[:2]).numpy()
731
+ # If the crop does not contain any gt-bbox area and
732
+ # allow_negative_crop is False, skip this image.
733
+ if (not valid_inds.any() and not allow_negative_crop):
734
+ return None
735
+
736
+ results['gt_bboxes'] = bboxes[valid_inds]
737
+
738
+ if results.get('gt_ignore_flags', None) is not None:
739
+ results['gt_ignore_flags'] = \
740
+ results['gt_ignore_flags'][valid_inds]
741
+
742
+ if results.get('gt_ignore_mask_flags', None) is not None:
743
+ results['gt_ignore_mask_flags'] = \
744
+ results['gt_ignore_mask_flags'][valid_inds]
745
+
746
+ if results.get('gt_bboxes_labels', None) is not None:
747
+ results['gt_bboxes_labels'] = \
748
+ results['gt_bboxes_labels'][valid_inds]
749
+
750
+ if results.get('gt_masks', None) is not None:
751
+ results['gt_masks'] = results['gt_masks'][
752
+ valid_inds.nonzero()[0]].crop(
753
+ np.asarray([crop_x1, crop_y1, crop_x2, crop_y2]))
754
+ if self.recompute_bbox:
755
+ results['gt_bboxes'] = results['gt_masks'].get_bboxes(
756
+ type(results['gt_bboxes']))
757
+
758
+ # crop semantic seg
759
+ if results.get('gt_seg_map', None) is not None:
760
+ results['gt_seg_map'] = results['gt_seg_map'][crop_y1:crop_y2,
761
+ crop_x1:crop_x2]
762
+
763
+ return results
764
+
765
+
766
+
767
+ @TRANSFORMS.register_module()
768
+ class CachedMixUpNoSeg(CachedMixUp):
769
+
770
+ @autocast_box_type()
771
+ def transform(self, results: dict) -> dict:
772
+ """MixUp transform function.
773
+
774
+ Args:
775
+ results (dict): Result dict.
776
+
777
+ Returns:
778
+ dict: Updated result dict.
779
+ """
780
+ # cache and pop images
781
+ self.results_cache.append(copy.deepcopy(results))
782
+ if len(self.results_cache) > self.max_cached_images:
783
+ if self.random_pop:
784
+ index = random.randint(0, len(self.results_cache) - 1)
785
+ else:
786
+ index = 0
787
+ self.results_cache.pop(index)
788
+
789
+ if len(self.results_cache) <= 1:
790
+ return results
791
+
792
+ if random.uniform(0, 1) > self.prob:
793
+ return results
794
+
795
+ index = self.get_indexes(self.results_cache)
796
+ retrieve_results = copy.deepcopy(self.results_cache[index])
797
+
798
+ # TODO: refactor mixup to reuse these code.
799
+ if retrieve_results['gt_bboxes'].shape[0] == 0:
800
+ # empty bbox
801
+ return results
802
+
803
+ retrieve_img = retrieve_results['img']
804
+ with_mask = True if 'gt_masks' in results else False
805
+
806
+ jit_factor = random.uniform(*self.ratio_range)
807
+ is_filp = random.uniform(0, 1) > self.flip_ratio
808
+
809
+ if len(retrieve_img.shape) == 3:
810
+ out_img = np.ones(
811
+ (self.dynamic_scale[1], self.dynamic_scale[0], 3),
812
+ dtype=retrieve_img.dtype) * self.pad_val
813
+ else:
814
+ out_img = np.ones(
815
+ self.dynamic_scale[::-1],
816
+ dtype=retrieve_img.dtype) * self.pad_val
817
+
818
+ # 1. keep_ratio resize
819
+ scale_ratio = min(self.dynamic_scale[1] / retrieve_img.shape[0],
820
+ self.dynamic_scale[0] / retrieve_img.shape[1])
821
+ retrieve_img = mmcv.imresize(
822
+ retrieve_img, (int(retrieve_img.shape[1] * scale_ratio),
823
+ int(retrieve_img.shape[0] * scale_ratio)))
824
+
825
+ # 2. paste
826
+ out_img[:retrieve_img.shape[0], :retrieve_img.shape[1]] = retrieve_img
827
+
828
+ # 3. scale jit
829
+ scale_ratio *= jit_factor
830
+ out_img = mmcv.imresize(out_img, (int(out_img.shape[1] * jit_factor),
831
+ int(out_img.shape[0] * jit_factor)))
832
+
833
+ # 4. flip
834
+ if is_filp:
835
+ out_img = out_img[:, ::-1, :]
836
+
837
+ # 5. random crop
838
+ ori_img = results['img']
839
+ origin_h, origin_w = out_img.shape[:2]
840
+ target_h, target_w = ori_img.shape[:2]
841
+ padded_img = np.ones((max(origin_h, target_h), max(
842
+ origin_w, target_w), 3)) * self.pad_val
843
+ padded_img = padded_img.astype(np.uint8)
844
+ padded_img[:origin_h, :origin_w] = out_img
845
+
846
+ x_offset, y_offset = 0, 0
847
+ if padded_img.shape[0] > target_h:
848
+ y_offset = random.randint(0, padded_img.shape[0] - target_h)
849
+ if padded_img.shape[1] > target_w:
850
+ x_offset = random.randint(0, padded_img.shape[1] - target_w)
851
+ padded_cropped_img = padded_img[y_offset:y_offset + target_h,
852
+ x_offset:x_offset + target_w]
853
+
854
+ # 6. adjust bbox
855
+ retrieve_gt_bboxes = retrieve_results['gt_bboxes']
856
+ retrieve_gt_bboxes.rescale_([scale_ratio, scale_ratio])
857
+ if with_mask:
858
+ retrieve_gt_masks = retrieve_results['gt_masks'].rescale(
859
+ scale_ratio)
860
+
861
+ if self.bbox_clip_border:
862
+ retrieve_gt_bboxes.clip_([origin_h, origin_w])
863
+
864
+ if is_filp:
865
+ retrieve_gt_bboxes.flip_([origin_h, origin_w],
866
+ direction='horizontal')
867
+ if with_mask:
868
+ retrieve_gt_masks = retrieve_gt_masks.flip()
869
+
870
+ # 7. filter
871
+ cp_retrieve_gt_bboxes = retrieve_gt_bboxes.clone()
872
+ cp_retrieve_gt_bboxes.translate_([-x_offset, -y_offset])
873
+ if with_mask:
874
+
875
+ retrieve_gt_masks = translate_bitmapmask(retrieve_gt_masks,
876
+ out_shape=(target_h, target_w),
877
+ offset_x=-x_offset, offset_y=-y_offset)
878
+
879
+ # retrieve_gt_masks = retrieve_gt_masks.translate(
880
+ # out_shape=(target_h, target_w),
881
+ # offset=-x_offset,
882
+ # direction='horizontal')
883
+ # retrieve_gt_masks = retrieve_gt_masks.translate(
884
+ # out_shape=(target_h, target_w),
885
+ # offset=-y_offset,
886
+ # direction='vertical')
887
+
888
+ if self.bbox_clip_border:
889
+ cp_retrieve_gt_bboxes.clip_([target_h, target_w])
890
+
891
+ # 8. mix up
892
+ ori_img = ori_img.astype(np.float32)
893
+ mixup_img = 0.5 * ori_img + 0.5 * padded_cropped_img.astype(np.float32)
894
+
895
+ retrieve_gt_bboxes_labels = retrieve_results['gt_bboxes_labels']
896
+ retrieve_gt_ignore_flags = retrieve_results['gt_ignore_flags']
897
+ retrieve_gt_ignore_mask_flags = retrieve_results['gt_ignore_mask_flags']
898
+
899
+ mixup_gt_bboxes = cp_retrieve_gt_bboxes.cat(
900
+ (results['gt_bboxes'], cp_retrieve_gt_bboxes), dim=0)
901
+ mixup_gt_bboxes_labels = np.concatenate(
902
+ (results['gt_bboxes_labels'], retrieve_gt_bboxes_labels), axis=0)
903
+ mixup_gt_ignore_flags = np.concatenate(
904
+ (results['gt_ignore_flags'], retrieve_gt_ignore_flags), axis=0)
905
+ mixup_gt_ignore_mask_flags = np.concatenate(
906
+ (results['gt_ignore_mask_flags'], retrieve_gt_ignore_mask_flags), axis=0)
907
+
908
+ if with_mask:
909
+ mixup_gt_masks = retrieve_gt_masks.cat(
910
+ [results['gt_masks'], retrieve_gt_masks])
911
+
912
+ # remove outside bbox
913
+ inside_inds = mixup_gt_bboxes.is_inside([target_h, target_w]).numpy()
914
+ mixup_gt_bboxes = mixup_gt_bboxes[inside_inds]
915
+ mixup_gt_bboxes_labels = mixup_gt_bboxes_labels[inside_inds]
916
+ mixup_gt_ignore_flags = mixup_gt_ignore_flags[inside_inds]
917
+ mixup_gt_ignore_mask_flags = mixup_gt_ignore_mask_flags[inside_inds]
918
+ if with_mask:
919
+ mixup_gt_masks = mixup_gt_masks[inside_inds]
920
+
921
+ results['img'] = mixup_img.astype(np.uint8)
922
+ results['img_shape'] = mixup_img.shape
923
+ results['gt_bboxes'] = mixup_gt_bboxes
924
+ results['gt_bboxes_labels'] = mixup_gt_bboxes_labels
925
+ results['gt_ignore_flags'] = mixup_gt_ignore_flags
926
+ results['gt_ignore_mask_flags'] = mixup_gt_ignore_mask_flags
927
+ if with_mask:
928
+ results['gt_masks'] = mixup_gt_masks
929
+ return results
animeinsseg/data/maskrefine_dataset.py ADDED
@@ -0,0 +1,235 @@
1
+ import albumentations as A
2
+
3
+ from torch.utils.data import Dataset, DataLoader
4
+ import pycocotools.mask as maskUtils
5
+ from pycocotools.coco import COCO
6
+ import random
7
+ import os.path as osp
8
+ import cv2
9
+ import numpy as np
10
+ from scipy.ndimage import distance_transform_bf, distance_transform_edt, distance_transform_cdt
11
+
12
+
13
+ def is_grey(img: np.ndarray):
14
+ if len(img.shape) == 3 and img.shape[2] == 3:
15
+ return False
16
+ else:
17
+ return True
18
+
19
+
20
+ def square_pad_resize(img: np.ndarray, tgt_size: int, pad_value = (0, 0, 0)):
21
+ h, w = img.shape[:2]
22
+ pad_h, pad_w = 0, 0
23
+
24
+ # make square image
25
+ if w < h:
26
+ pad_w = h - w
27
+ w += pad_w
28
+ elif h < w:
29
+ pad_h = w - h
30
+ h += pad_h
31
+
32
+ pad_size = tgt_size - h
33
+ if pad_size > 0:
34
+ pad_h += pad_size
35
+ pad_w += pad_size
36
+
37
+ if pad_h > 0 or pad_w > 0:
38
+ c = 1
39
+ if is_grey(img):
40
+ if isinstance(pad_value, tuple):
41
+ pad_value = pad_value[0]
42
+ else:
43
+ if isinstance(pad_value, int):
44
+ pad_value = (pad_value, pad_value, pad_value)
45
+
46
+ img = cv2.copyMakeBorder(img, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=pad_value)
47
+
48
+ resize_ratio = tgt_size / img.shape[0]
49
+ if resize_ratio < 1:
50
+ img = cv2.resize(img, (tgt_size, tgt_size), interpolation=cv2.INTER_AREA)
51
+ elif resize_ratio > 1:
52
+ img = cv2.resize(img, (tgt_size, tgt_size), interpolation=cv2.INTER_LINEAR)
53
+
54
+ return img, resize_ratio, pad_h, pad_w
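+ # Worked example (sketch): for an RGB image of shape (100, 50, 3) and
+ # tgt_size=64 the image is padded on the right to 100x100 (pad_h=0, pad_w=50,
+ # filled with `pad_value`), then resized to 64x64 with INTER_AREA, so the call
+ # returns (img_64x64, 0.64, 0, 50). Images smaller than tgt_size are padded
+ # straight up to tgt_size x tgt_size, so no upscaling is needed.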
55
+
56
+
57
+ class MaskRefineDataset(Dataset):
58
+
59
+ def __init__(self,
60
+ refine_ann_path: str,
61
+ data_root: str,
62
+ load_instance_mask: bool = True,
63
+ aug_ins_prob: float = 0.,
64
+ ins_rect_prob: float = 0.,
65
+ output_size: int = 720,
66
+ augmentation: bool = False,
67
+ with_distance: bool = False):
68
+ self.load_instance_mask = load_instance_mask
69
+ self.ann_util = COCO(refine_ann_path)
70
+ self.img_ids = self.ann_util.getImgIds()
71
+ self.set_load_method(load_instance_mask)
72
+ self.data_root = data_root
73
+
74
+ self.ins_rect_prob = ins_rect_prob
75
+ self.aug_ins_prob = aug_ins_prob
76
+ self.augmentation = augmentation
77
+ if augmentation:
78
+ transform = [
79
+ A.OpticalDistortion(),
80
+ A.HorizontalFlip(),
81
+ A.CLAHE(),
82
+ A.Posterize(),
83
+ A.CropAndPad(percent=0.1, p=0.3, pad_mode=cv2.BORDER_CONSTANT, pad_cval=0, pad_cval_mask=0, keep_size=True),
84
+ A.RandomContrast(),
85
+ A.Rotate(30, p=0.3, mask_value=0, border_mode=cv2.BORDER_CONSTANT)
86
+ ]
87
+ self._aug_transform = A.Compose(transform)
88
+ else:
89
+ self._aug_transform = None
90
+
91
+ self.output_size = output_size
92
+ self.with_distance = with_distance
93
+
94
+ def set_output_size(self, size: int):
95
+ self.output_size = size
96
+
97
+ def set_load_method(self, load_instance_mask: bool):
98
+ if load_instance_mask:
99
+ self._load_mask = self._load_with_instance
100
+ else:
101
+ self._load_mask = self._load_without_instance
102
+
103
+ def __getitem__(self, idx: int):
104
+ img_id = self.img_ids[idx]
105
+ img_meta = self.ann_util.imgs[img_id]
106
+ img_path = osp.join(self.data_root, img_meta['file_name'])
107
+ img = cv2.imread(img_path)
108
+
109
+ annids = self.ann_util.getAnnIds([img_id])
110
+ if len(annids) > 0:
111
+ ann = random.choice(annids)
112
+ ann = self.ann_util.anns[ann]
113
+ assert ann['image_id'] == img_id
114
+ else:
115
+ ann = None
116
+
117
+ return self._load_mask(img, ann)
118
+
119
+ def transform(self, img: np.ndarray, mask: np.ndarray, ins_seg: np.ndarray = None) -> dict:
120
+ if ins_seg is not None:
121
+ use_seg = True
122
+ else:
123
+ use_seg = False
124
+
125
+ if self.augmentation:
126
+ masks = [mask]
127
+ if use_seg:
128
+ masks.append(ins_seg)
129
+ data = self._aug_transform(image=img, masks=masks)
130
+ img = data['image']
131
+ masks = data['masks']
132
+ mask = masks[0]
133
+ if use_seg:
134
+ ins_seg = masks[1]
135
+
136
+ img = square_pad_resize(img, self.output_size, random.randint(0, 255))[0]
137
+ mask = square_pad_resize(mask, self.output_size, 0)[0]
138
+ if ins_seg is not None:
139
+ ins_seg = square_pad_resize(ins_seg, self.output_size, 0)[0]
140
+
141
+ img = (img.astype(np.float32) / 255.).transpose((2, 0, 1))
142
+ mask = mask[None, ...]
143
+
144
+
145
+ if use_seg:
146
+ ins_seg = ins_seg[None, ...]
147
+ img = np.concatenate((img, ins_seg), axis=0)
148
+
149
+ data = {'img': img, 'mask': mask}
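+ # The optional `dist_weight` below is a per-pixel loss weight: the Euclidean
+ # distance transform of the GT mask is inverted so pixels near the mask
+ # boundary (and the background) get roughly 1.2 while the mask interior decays
+ # towards 0.2, then the map is rescaled to have mean ~1 and clipped at 20.
+ # Presumably this focuses the refinement loss on boundary regions.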
150
+ if self.with_distance:
151
+ dist = distance_transform_edt(mask[0])
152
+ dist_max = dist.max()
153
+ if dist_max != 0:
154
+ dist = 1 - dist / dist_max
155
+ # diff_mat = cv2.bitwise_xor(mask[0], ins_seg[0])
156
+ # dist = dist + diff_mat + 0.2
157
+ dist = dist + 0.2
158
+ dist = dist.size / (dist.sum() + 1) * dist
159
+ dist = np.clip(dist, 0, 20)
160
+ else:
161
+ dist = np.ones_like(dist)
162
+ # print(dist.max(), dist.min())
163
+ data['dist_weight'] = dist[None, ...]
164
+ return data
165
+
166
+ def _load_with_instance(self, img: np.ndarray, ann: dict):
167
+ if ann is None:
168
+ mask = np.zeros(img.shape[:2], dtype=np.float32)
169
+ ins_seg = mask
170
+ else:
171
+ mask = maskUtils.decode(ann['segmentation']).astype(np.float32)
172
+ if self.augmentation and random.random() < self.ins_rect_prob:
173
+ ins_seg = np.zeros_like(mask)
174
+ bbox = [int(b) for b in ann['bbox']]
175
+ ins_seg[bbox[1]: bbox[1] + bbox[3], bbox[0]: bbox[0] + bbox[2]] = 1
176
+ elif len(ann['pred_segmentations']) > 0:
177
+ ins_seg = random.choice(ann['pred_segmentations'])
178
+ ins_seg = maskUtils.decode(ins_seg).astype(np.float32)
179
+ else:
180
+ ins_seg = mask
181
+ if self.augmentation and random.random() < self.aug_ins_prob:
182
+ ksize = random.choice([1, 3, 5, 7])
183
+ ksize = ksize * 2 + 1
184
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, ksize=(ksize, ksize))
185
+ if random.random() < 0.5:
186
+ ins_seg = cv2.dilate(ins_seg, kernel)
187
+ else:
188
+ ins_seg = cv2.erode(ins_seg, kernel)
189
+
190
+ return self.transform(img, mask, ins_seg)
191
+
192
+ def _load_without_instance(self, img: np.ndarray, ann: dict):
193
+ if ann is None:
194
+ mask = np.zeros(img.shape[:2], dtype=np.float32)
195
+ else:
196
+ mask = maskUtils.decode(ann['segmentation']).astype(np.float32)
197
+ return self.transform(img, mask)
198
+
199
+ def __len__(self):
200
+ return len(self.img_ids)
201
+
202
+
203
+ if __name__ == '__main__':
204
+ ann_path = r'workspace/test_syndata/annotations/refine_train.json'
205
+ data_root = r'workspace/test_syndata/train'
206
+
207
+ ann_path = r'workspace/test_syndata/annotations/refine_train.json'
208
+ data_root = r'workspace/test_syndata/train'
209
+ aug_ins_prob = 0.5
210
+ load_instance_mask = True
211
+ ins_rect_prob = 0.25
212
+ output_size = 640
213
+ augmentation = True
214
+
215
+ random.seed(0)
216
+
217
+ md = MaskRefineDataset(ann_path, data_root, load_instance_mask, aug_ins_prob, ins_rect_prob, output_size, augmentation, with_distance=True)
218
+
219
+ dl = DataLoader(md, batch_size=1, shuffle=False, persistent_workers=True,
220
+ num_workers=1, pin_memory=True)
221
+ for data in dl:
222
+ img = data['img'].cpu().numpy()
223
+ img = (img[0, :3].transpose((1, 2, 0)) * 255).astype(np.uint8)
224
+ mask = (data['mask'].cpu().numpy()[0][0] * 255).astype(np.uint8)
225
+ if load_instance_mask:
226
+ ins = (data['img'].cpu().numpy()[0][3] * 255).astype(np.uint8)
227
+ cv2.imshow('ins', ins)
228
+ dist = data['dist_weight'].cpu().numpy()[0][0]
229
+ dist = (dist / dist.max() * 255).astype(np.uint8)
230
+ cv2.imshow('img', img)
231
+ cv2.imshow('mask', mask)
232
+ cv2.imshow('dist_weight', dist)
233
+ cv2.waitKey(0)
234
+
235
+ # cv2.imwrite('')
animeinsseg/data/metrics.py ADDED
@@ -0,0 +1,348 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import datetime
3
+ import itertools
4
+ import os.path as osp
5
+ import tempfile
6
+ from collections import OrderedDict
7
+ from typing import Dict, List, Optional, Sequence, Union
8
+
9
+ import numpy as np
10
+ import torch
11
+ from mmengine.evaluator import BaseMetric
12
+ from mmengine.fileio import FileClient, dump, load
13
+ from mmengine.logging import MMLogger
14
+ from terminaltables import AsciiTable
15
+
16
+ from mmdet.datasets.api_wrappers import COCO, COCOeval
17
+ from mmdet.registry import METRICS
18
+ from mmdet.structures.mask import encode_mask_results
19
+ # from ..functional import eval_recalls
20
+ from mmdet.evaluation.metrics import CocoMetric
21
+
22
+
23
+ @METRICS.register_module()
24
+ class AnimeMangaMetric(CocoMetric):
25
+
26
+ def __init__(self,
27
+ manga109_annfile=None,
28
+ animeins_annfile=None,
29
+ ann_file: Optional[str] = None,
30
+ metric: Union[str, List[str]] = 'bbox',
31
+ classwise: bool = False,
32
+ proposal_nums: Sequence[int] = (100, 300, 1000),
33
+ iou_thrs: Optional[Union[float, Sequence[float]]] = None,
34
+ metric_items: Optional[Sequence[str]] = None,
35
+ format_only: bool = False,
36
+ outfile_prefix: Optional[str] = None,
37
+ file_client_args: dict = dict(backend='disk'),
38
+ collect_device: str = 'cpu',
39
+ prefix: Optional[str] = None,
40
+ sort_categories: bool = False) -> None:
41
+
42
+ super().__init__(ann_file, metric, classwise, proposal_nums, iou_thrs, metric_items, format_only, outfile_prefix, file_client_args, collect_device, prefix, sort_categories)
43
+
44
+ self.manga109_img_ids = set()
45
+ if manga109_annfile is not None:
46
+ with self.file_client.get_local_path(manga109_annfile) as local_path:
47
+ self._manga109_coco_api = COCO(local_path)
48
+ if sort_categories:
49
+ # 'categories' list in objects365_train.json and
50
+ # objects365_val.json is inconsistent, need sort
51
+ # list(or dict) before get cat_ids.
52
+ cats = self._manga109_coco_api.cats
53
+ sorted_cats = {i: cats[i] for i in sorted(cats)}
54
+ self._manga109_coco_api.cats = sorted_cats
55
+ categories = self._manga109_coco_api.dataset['categories']
56
+ sorted_categories = sorted(
57
+ categories, key=lambda i: i['id'])
58
+ self._manga109_coco_api.dataset['categories'] = sorted_categories
59
+ self.manga109_img_ids = set(self._manga109_coco_api.get_img_ids())
60
+ else:
61
+ self._manga109_coco_api = None
62
+
63
+ self.animeins_img_ids = set()
64
+ if animeins_annfile is not None:
65
+ with self.file_client.get_local_path(animeins_annfile) as local_path:
66
+ self._animeins_coco_api = COCO(local_path)
67
+ if sort_categories:
68
+ # 'categories' list in objects365_train.json and
69
+ # objects365_val.json is inconsistent, need sort
70
+ # list(or dict) before get cat_ids.
71
+ cats = self._animeins_coco_api.cats
72
+ sorted_cats = {i: cats[i] for i in sorted(cats)}
73
+ self._animeins_coco_api.cats = sorted_cats
74
+ categories = self._animeins_coco_api.dataset['categories']
75
+ sorted_categories = sorted(
76
+ categories, key=lambda i: i['id'])
77
+ self._animeins_coco_api.dataset['categories'] = sorted_categories
78
+ self.animeins_img_ids = set(self._animeins_coco_api.get_img_ids())
79
+ else:
80
+ self._animeins_coco_api = None
81
+
82
+ if self._animeins_coco_api is not None:
83
+ self._coco_api = self._animeins_coco_api
84
+ else:
85
+ self._coco_api = self._manga109_coco_api
86
+
87
+
88
+ def compute_metrics(self, results: list) -> Dict[str, float]:
89
+
90
+ # split gt and prediction list
91
+ gts, preds = zip(*results)
92
+
93
+ manga109_gts, animeins_gts = [], []
94
+ manga109_preds, animeins_preds = [], []
95
+ for gt, pred in zip(gts, preds):
96
+ if gt['img_id'] in self.manga109_img_ids:
97
+ manga109_gts.append(gt)
98
+ manga109_preds.append(pred)
99
+ else:
100
+ animeins_gts.append(gt)
101
+ animeins_preds.append(pred)
102
+
103
+ tmp_dir = None
104
+ if self.outfile_prefix is None:
105
+ tmp_dir = tempfile.TemporaryDirectory()
106
+ outfile_prefix = osp.join(tmp_dir.name, 'results')
107
+ else:
108
+ outfile_prefix = self.outfile_prefix
109
+
110
+ eval_results = OrderedDict()
111
+
112
+ if len(manga109_gts) > 0:
113
+ metrics = []
114
+ for m in self.metrics:
115
+ if m != 'segm':
116
+ metrics.append(m)
117
+
118
+ self.cat_ids = self._manga109_coco_api.get_cat_ids(cat_names=self.dataset_meta['classes'])
119
+ self.img_ids = self._manga109_coco_api.get_img_ids()
120
+ rst = self._compute_metrics(metrics, self._manga109_coco_api, manga109_preds, outfile_prefix, tmp_dir)
121
+ for key, item in rst.items():
122
+ eval_results['manga109_'+key] = item
123
+
124
+ if len(animeins_gts) > 0:
125
+ self.cat_ids = self._animeins_coco_api.get_cat_ids(cat_names=self.dataset_meta['classes'])
126
+ self.img_ids = self._animeins_coco_api.get_img_ids()
127
+ rst = self._compute_metrics(self.metrics, self._animeins_coco_api, animeins_preds, outfile_prefix, tmp_dir)
128
+ for key, item in rst.items():
129
+ eval_results['animeins_'+key] = item
130
+
131
+ return eval_results
132
+
133
+ def results2json(self, results: Sequence[dict],
134
+ outfile_prefix: str) -> dict:
135
+ """Dump the detection results to a COCO style json file.
136
+
137
+ There are 3 types of results: proposals, bbox predictions, mask
138
+ predictions, and they have different data types. This method will
139
+ automatically recognize the type, and dump them to json files.
140
+
141
+ Args:
142
+ results (Sequence[dict]): Testing results of the
143
+ dataset.
144
+ outfile_prefix (str): The filename prefix of the json files. If the
145
+ prefix is "somepath/xxx", the json files will be named
146
+ "somepath/xxx.bbox.json", "somepath/xxx.segm.json",
147
+ "somepath/xxx.proposal.json".
148
+
149
+ Returns:
150
+ dict: Possible keys are "bbox", "segm", "proposal", and
151
+ values are corresponding filenames.
152
+ """
153
+ bbox_json_results = []
154
+ segm_json_results = [] if 'masks' in results[0] else None
155
+ for idx, result in enumerate(results):
156
+ image_id = result.get('img_id', idx)
157
+ labels = result['labels']
158
+ bboxes = result['bboxes']
159
+ scores = result['scores']
160
+ # bbox results
161
+ for i, label in enumerate(labels):
162
+ data = dict()
163
+ data['image_id'] = image_id
164
+ data['bbox'] = self.xyxy2xywh(bboxes[i])
165
+ data['score'] = float(scores[i])
166
+ data['category_id'] = self.cat_ids[label]
167
+ bbox_json_results.append(data)
168
+
169
+ if segm_json_results is None:
170
+ continue
171
+
172
+ # segm results
173
+ masks = result['masks']
174
+ mask_scores = result.get('mask_scores', scores)
175
+ for i, label in enumerate(labels):
176
+ data = dict()
177
+ data['image_id'] = image_id
178
+ data['bbox'] = self.xyxy2xywh(bboxes[i])
179
+ data['score'] = float(mask_scores[i])
180
+ data['category_id'] = self.cat_ids[label]
181
+ if isinstance(masks[i]['counts'], bytes):
182
+ masks[i]['counts'] = masks[i]['counts'].decode()
183
+ data['segmentation'] = masks[i]
184
+ segm_json_results.append(data)
185
+
186
+ logger: MMLogger = MMLogger.get_current_instance()
187
+ logger.info('dumping predictions ... ')
188
+ result_files = dict()
189
+ result_files['bbox'] = f'{outfile_prefix}.bbox.json'
190
+ result_files['proposal'] = f'{outfile_prefix}.bbox.json'
191
+ dump(bbox_json_results, result_files['bbox'])
192
+
193
+ if segm_json_results is not None:
194
+ result_files['segm'] = f'{outfile_prefix}.segm.json'
195
+ dump(segm_json_results, result_files['segm'])
196
+
197
+ return result_files
198
+
199
+ def _compute_metrics(self, metrics, tgt_api, preds, outfile_prefix, tmp_dir):
200
+ logger: MMLogger = MMLogger.get_current_instance()
201
+
202
+ result_files = self.results2json(preds, outfile_prefix)
203
+
204
+ eval_results = OrderedDict()
205
+ if self.format_only:
206
+ logger.info('results are saved in '
207
+ f'{osp.dirname(outfile_prefix)}')
208
+ return eval_results
209
+
210
+ for metric in metrics:
211
+ logger.info(f'Evaluating {metric}...')
212
+
213
+ # TODO: May refactor fast_eval_recall to an independent metric?
214
+ # fast eval recall
215
+ if metric == 'proposal_fast':
216
+ ar = self.fast_eval_recall(
217
+ preds, self.proposal_nums, self.iou_thrs, logger=logger)
218
+ log_msg = []
219
+ for i, num in enumerate(self.proposal_nums):
220
+ eval_results[f'AR@{num}'] = ar[i]
221
+ log_msg.append(f'\nAR@{num}\t{ar[i]:.4f}')
222
+ log_msg = ''.join(log_msg)
223
+ logger.info(log_msg)
224
+ continue
225
+
226
+ # evaluate proposal, bbox and segm
227
+ iou_type = 'bbox' if metric == 'proposal' else metric
228
+ if metric not in result_files:
229
+ raise KeyError(f'{metric} is not in results')
230
+ try:
231
+ predictions = load(result_files[metric])
232
+ if iou_type == 'segm':
233
+ # Refer to https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/coco.py#L331 # noqa
234
+ # When evaluating mask AP, if the results contain bbox,
235
+ # cocoapi will use the box area instead of the mask area
236
+ # for calculating the instance area. Though the overall AP
237
+ # is not affected, this leads to different
238
+ # small/medium/large mask AP results.
239
+ for x in predictions:
240
+ x.pop('bbox')
241
+ coco_dt = tgt_api.loadRes(predictions)
242
+
243
+ except IndexError:
244
+ logger.error(
245
+ 'The testing results of the whole dataset is empty.')
246
+ break
247
+
248
+ coco_eval = COCOeval(tgt_api, coco_dt, iou_type)
249
+
250
+ coco_eval.params.catIds = self.cat_ids
251
+ coco_eval.params.imgIds = self.img_ids
252
+ coco_eval.params.maxDets = list(self.proposal_nums)
253
+ coco_eval.params.iouThrs = self.iou_thrs
254
+
255
+ # mapping of cocoEval.stats
256
+ coco_metric_names = {
257
+ 'mAP': 0,
258
+ 'mAP_50': 1,
259
+ 'mAP_75': 2,
260
+ 'mAP_s': 3,
261
+ 'mAP_m': 4,
262
+ 'mAP_l': 5,
263
+ 'AR@100': 6,
264
+ 'AR@300': 7,
265
+ 'AR@1000': 8,
266
+ 'AR_s@1000': 9,
267
+ 'AR_m@1000': 10,
268
+ 'AR_l@1000': 11
269
+ }
270
+ metric_items = self.metric_items
271
+ if metric_items is not None:
272
+ for metric_item in metric_items:
273
+ if metric_item not in coco_metric_names:
274
+ raise KeyError(
275
+ f'metric item "{metric_item}" is not supported')
276
+
277
+ if metric == 'proposal':
278
+ coco_eval.params.useCats = 0
279
+ coco_eval.evaluate()
280
+ coco_eval.accumulate()
281
+ coco_eval.summarize()
282
+ if metric_items is None:
283
+ metric_items = [
284
+ 'AR@100', 'AR@300', 'AR@1000', 'AR_s@1000',
285
+ 'AR_m@1000', 'AR_l@1000'
286
+ ]
287
+
288
+ for item in metric_items:
289
+ val = float(
290
+ f'{coco_eval.stats[coco_metric_names[item]]:.3f}')
291
+ eval_results[item] = val
292
+ else:
293
+ coco_eval.evaluate()
294
+ coco_eval.accumulate()
295
+ coco_eval.summarize()
296
+ if self.classwise: # Compute per-category AP
297
+ # Compute per-category AP
298
+ # from https://github.com/facebookresearch/detectron2/
299
+ precisions = coco_eval.eval['precision']
300
+ # precision: (iou, recall, cls, area range, max dets)
301
+ assert len(self.cat_ids) == precisions.shape[2]
302
+
303
+ results_per_category = []
304
+ for idx, cat_id in enumerate(self.cat_ids):
305
+ # area range index 0: all area ranges
306
+ # max dets index -1: typically 100 per image
307
+ nm = tgt_api.loadCats(cat_id)[0]
308
+ precision = precisions[:, :, idx, 0, -1]
309
+ precision = precision[precision > -1]
310
+ if precision.size:
311
+ ap = np.mean(precision)
312
+ else:
313
+ ap = float('nan')
314
+ results_per_category.append(
315
+ (f'{nm["name"]}', f'{round(ap, 3)}'))
316
+ eval_results[f'{nm["name"]}_precision'] = round(ap, 3)
317
+
318
+ num_columns = min(6, len(results_per_category) * 2)
319
+ results_flatten = list(
320
+ itertools.chain(*results_per_category))
321
+ headers = ['category', 'AP'] * (num_columns // 2)
322
+ results_2d = itertools.zip_longest(*[
323
+ results_flatten[i::num_columns]
324
+ for i in range(num_columns)
325
+ ])
326
+ table_data = [headers]
327
+ table_data += [result for result in results_2d]
328
+ table = AsciiTable(table_data)
329
+ logger.info('\n' + table.table)
330
+
331
+ if metric_items is None:
332
+ metric_items = [
333
+ 'mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l'
334
+ ]
335
+
336
+ for metric_item in metric_items:
337
+ key = f'{metric}_{metric_item}'
338
+ val = coco_eval.stats[coco_metric_names[metric_item]]
339
+ eval_results[key] = float(f'{round(val, 3)}')
340
+
341
+ ap = coco_eval.stats[:6]
342
+ logger.info(f'{metric}_mAP_copypaste: {ap[0]:.3f} '
343
+ f'{ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} '
344
+ f'{ap[4]:.3f} {ap[5]:.3f}')
345
+
346
+ if tmp_dir is not None:
347
+ tmp_dir.cleanup()
348
+ return eval_results
animeinsseg/data/paste_methods.py ADDED
@@ -0,0 +1,327 @@
1
+ import numpy as np
2
+ from typing import List, Union, Tuple, Dict
3
+ import random
4
+ from PIL import Image
5
+ import cv2
6
+ import os.path as osp
7
+ from tqdm import tqdm
8
+ from panopticapi.utils import rgb2id, id2rgb
9
+ from time import time
10
+ import traceback
11
+
12
+ from utils.io_utils import bbox_overlap_area
13
+ from utils.logger import LOGGER
14
+ from utils.constants import COLOR_PALETTE
15
+
16
+
17
+
18
+ class PartitionTree:
19
+
20
+ def __init__(self, bleft: int, btop: int, bright: int, bbottom: int, parent = None) -> None:
21
+ self.left: PartitionTree = None
22
+ self.right: PartitionTree = None
23
+ self.top: PartitionTree = None
24
+ self.bottom: PartitionTree = None
25
+
26
+ if bright < bleft:
27
+ bright = bleft
28
+ if bbottom < btop:
29
+ bbottom = btop
30
+
31
+ self.bleft = bleft
32
+ self.bright = bright
33
+ self.btop = btop
34
+ self.bbottom = bbottom
35
+ self.parent: PartitionTree = parent
36
+
37
+ def is_leaf(self):
38
+ return self.left is None
39
+
40
+ def new_partition(self, new_rect: List):
41
+ self.left = PartitionTree(self.bleft, self.btop, new_rect[0], self.bbottom, self)
42
+ self.top = PartitionTree(self.bleft, self.btop, self.bright, new_rect[1], self)
43
+ self.right = PartitionTree(new_rect[2], self.btop, self.bright, self.bbottom, self)
44
+ self.bottom = PartitionTree(self.bleft, new_rect[3], self.bright, self.bbottom, self)
45
+ if self.parent is not None:
46
+ self.root_update_rect(new_rect)
47
+
48
+ def root_update_rect(self, rect):
49
+ root = self.get_root()
50
+ root.update_child_rect(rect)
51
+
52
+ def update_child_rect(self, rect: List):
53
+ if self.is_leaf():
54
+ self.update_from_rect(rect)
55
+ else:
56
+ self.left.update_child_rect(rect)
57
+ self.right.update_child_rect(rect)
58
+ self.top.update_child_rect(rect)
59
+ self.bottom.update_child_rect(rect)
60
+
61
+ def get_root(self):
62
+ if self.parent is not None:
63
+ return self.parent.get_root()
64
+ else:
65
+ return self
66
+
67
+
68
+ def update_from_rect(self, rect: List):
69
+ if not self.is_leaf():
70
+ return
71
+ ix = min(self.bright, rect[2]) - max(self.bleft, rect[0])
72
+ iy = min(self.bbottom, rect[3]) - max(self.btop, rect[1])
73
+ if not (ix > 0 and iy > 0):
74
+ return
75
+
76
+ new_ltrb0 = np.array([self.bleft, self.btop, self.bright, self.bbottom])
77
+ new_ltrb1 = new_ltrb0.copy()
78
+
79
+ if rect[0] > self.bleft and rect[0] < self.bright:
80
+ new_ltrb0[2] = rect[0]
81
+ else:
82
+ new_ltrb0[0] = rect[2]
83
+
84
+ if rect[1] > self.btop and rect[1] < self.bbottom:
85
+ new_ltrb1[3]= rect[1]
86
+ else:
87
+ new_ltrb1[1] = rect[3]
88
+
89
+ if (new_ltrb0[2:] - new_ltrb0[:2]).prod() > (new_ltrb1[2:] - new_ltrb1[:2]).prod():
90
+ self.bleft, self.btop, self.bright, self.bbottom = new_ltrb0
91
+ else:
92
+ self.bleft, self.btop, self.bright, self.bbottom = new_ltrb1
93
+
94
+ @property
95
+ def width(self) -> int:
96
+ return self.bright - self.bleft
97
+
98
+ @property
99
+ def height(self) -> int:
100
+ return self.bbottom - self.btop
101
+
102
+ def prefer_partition(self, tgt_h: int, tgt_w: int):
103
+ if self.is_leaf():
104
+ return self, min(self.width / tgt_w, 1.2) * min(self.height / tgt_h, 1.2)
105
+ else:
106
+ lp, ls = self.left.prefer_partition(tgt_h, tgt_w)
107
+ rp, rs = self.right.prefer_partition(tgt_h, tgt_w)
108
+ tp, ts = self.top.prefer_partition(tgt_h, tgt_w)
109
+ bp, bs = self.bottom.prefer_partition(tgt_h, tgt_w)
110
+ preferp = [(p, s) for s, p in sorted(zip([ls, rs, ts, bs],[lp, rp, tp, bp]), key=lambda pair: pair[0], reverse=True)][0]
111
+ return preferp
112
+
113
+ def new_random_pos(self, fg_h: int, fg_w: int, im_h: int, im_w: int, random_sample: bool = False):
114
+ extx, exty = int(fg_w / 3), int(fg_h / 3)
115
+ extxb, extyb = int(fg_w / 10), int(fg_h / 10)
116
+ region_w, region_h = self.width + extx, self.height + exty
117
+ downscale_ratio = max(min(region_w / fg_w, region_h / fg_h), 0.8)
118
+ if downscale_ratio < 1:
119
+ fg_h = int(downscale_ratio * fg_h)
120
+ fg_w = int(downscale_ratio * fg_w)
121
+
122
+ max_x, max_y = self.bright + extx - fg_w, self.bbottom + exty - fg_h
123
+ max_x = min(im_w+extxb-fg_w, max_x)
124
+ max_y = min(im_h+extyb-fg_h, max_y)
125
+ min_x = max(min(self.bright + extx - fg_w, self.bleft - extx), -extx)
126
+ min_x = max(-extxb, min_x)
127
+ min_y = max(min(self.bbottom + exty - fg_h, self.btop - exty), -exty)
128
+ min_y = max(-extyb, min_y)
129
+ px, py = min_x, min_y
130
+ if min_x < max_x:
131
+ if random_sample:
132
+ px = random.randint(min_x, max_x)
133
+ else:
134
+ px = int((min_x + max_x) / 2)
135
+ if min_y < max_y:
136
+ if random_sample:
137
+ py = random.randint(min_y, max_y)
138
+ else:
139
+ py = int((min_y + max_y) / 2)
140
+ return px, py, downscale_ratio
141
+
142
+ def drawpartition(self, image: np.ndarray, color = None):
143
+ if color is None:
144
+ color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
145
+ if not self.is_leaf():
146
+ cv2.rectangle(image, (self.bleft, self.btop), (self.bright, self.bbottom), color, 2)
147
+ if not self.is_leaf():
148
+ c = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
149
+ self.left.drawpartition(image, c)
150
+ self.right.drawpartition(image, c)
151
+ self.top.drawpartition(image, c)
152
+ self.bottom.drawpartition(image, c)
153
+
154
+
155
+ def paste_one_fg(fg_pil: Image, bg: Image, segments: np.ndarray, px: int, py: int, seg_color: Tuple, cal_area=True):
156
+
157
+ fg_h, fg_w = fg_pil.height, fg_pil.width
158
+ im_h, im_w = bg.height, bg.width
159
+
160
+ bg.paste(fg_pil, (px, py), mask=fg_pil)
161
+
162
+
163
+ bgx1, bgx2, bgy1, bgy2 = px, px+fg_w, py, py+fg_h
164
+ fgx1, fgx2, fgy1, fgy2 = 0, fg_w, 0, fg_h
165
+ if bgx1 < 0:
166
+ fgx1 = -bgx1
167
+ bgx1 = 0
168
+ if bgy1 < 0:
169
+ fgy1 = -bgy1
170
+ bgy1 = 0
171
+ if bgx2 > im_w:
172
+ fgx2 = im_w - bgx2
173
+ bgx2 = im_w
174
+ if bgy2 > im_h:
175
+ fgy2 = im_h - bgy2
176
+ bgy2 = im_h
177
+
178
+ fg_mask = np.array(fg_pil)[fgy1: fgy2, fgx1: fgx2, 3] > 30
179
+ segments[bgy1: bgy2, bgx1: bgx2][np.where(fg_mask)] = seg_color
180
+
181
+ if cal_area:
182
+ area = fg_mask.sum()
183
+ else:
184
+ area = 1
185
+ bbox = [bgx1, bgy1, bgx2-bgx1, bgy2-bgy1]
186
+ return area, bbox, [bgx1, bgy1, bgx2, bgy2]
187
+
188
+
189
+ def partition_paste(fg_list, bg: Image):
190
+ segments_info = []
191
+
192
+ fg_list.sort(key = lambda x: x['image'].shape[0] * x['image'].shape[1], reverse=True)
193
+ pnode: PartitionTree = None
194
+ im_h, im_w = bg.height, bg.width
195
+
196
+ ptree = PartitionTree(0, 0, bg.width, bg.height)
197
+
198
+ segments = np.zeros((im_h, im_w, 3), np.uint8)
199
+ for ii, fg_dict in enumerate(fg_list):
200
+ fg = fg_dict['image']
201
+ fg_h, fg_w = fg.shape[:2]
202
+ pnode, _ = ptree.prefer_partition(fg_h, fg_w)
203
+ px, py, downscale_ratio = pnode.new_random_pos(fg_h, fg_w, im_h, im_w, True)
204
+
205
+ fg_pil = Image.fromarray(fg)
206
+ if downscale_ratio < 1:
207
+ fg_pil = fg_pil.resize((int(fg_w * downscale_ratio), int(fg_h * downscale_ratio)), resample=Image.Resampling.LANCZOS)
208
+ # fg_h, fg_w = fg_pil.height, fg_pil.width
209
+
210
+ seg_color = COLOR_PALETTE[ii]
211
+ area, bbox, xyxy = paste_one_fg(fg_pil, bg, segments, px,py, seg_color, cal_area=False)
212
+ pnode.new_partition(xyxy)
213
+
214
+ segments_info.append({
215
+ 'id': rgb2id(seg_color),
216
+ 'bbox': bbox,
217
+ 'area': area
218
+ })
219
+
220
+ return segments_info, segments
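+ # Note on `partition_paste` above: foregrounds are sorted largest-first, each
+ # is placed into the `PartitionTree` leaf whose free rectangle best matches
+ # its size, and that leaf is then split into left/top/right/bottom children
+ # around the pasted box. Instance ids are written into `segments` as colours
+ # from COLOR_PALETTE and encoded with panopticapi's rgb2id, in the COCO
+ # panoptic style.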
221
+ # if downscale_ratio < 1:
222
+ # fg_pil = fg_pil.resize((int(fg_w * downscale_ratio), int(fg_h * downscale_ratio)), resample=Image.Resampling.LANCZOS)
223
+ # fg_h, fg_w = fg_pil.height, fg_pil.width
224
+
225
+
226
+ def gen_fg_regbboxes(fg_list: List[Dict], tgt_size: int, min_overlap=0.15, max_overlap=0.8):
227
+
228
+ def _sample_y(h):
229
+ y = (tgt_size - h) // 2
230
+ if y > 0:
231
+ yrange = min(y, h // 4)
232
+ y += random.randint(-yrange, yrange)
233
+ return y
234
+ else:
235
+ return 0
236
+
237
+ shape_list = []
238
+ depth_list = []
239
+
240
+
241
+ for fg_dict in fg_list:
242
+ shape_list.append(fg_dict['image'].shape[:2])
243
+
244
+ shape_list = np.array(shape_list)
245
+ depth_list = np.random.random(len(fg_list))
246
+ depth_list[shape_list[..., 1] > 0.6 * tgt_size] += 1
247
+
248
+ # num_fg = len(fg_list)
249
+ # grid_sample = random.random() < 0.4 or num_fg > 6
250
+ # grid_sample = grid_sample and num_fg < 9 and num_fg > 3
251
+ # grid_sample = False
252
+ # if grid_sample:
253
+ # grid_pos = np.arange(9)
254
+ # np.random.shuffle(grid_pos)
255
+ # grid_pos = grid_pos[: num_fg]
256
+ # grid_x = grid_pos % 3
257
+ # grid_y = grid_pos // 3
258
+
259
+ # else:
260
+ pos_list = [[0, _sample_y(shape_list[0][0])]]
261
+ pre_overlap = 0
262
+ for ii, ((h, w), d) in enumerate(zip(shape_list[1:], depth_list[1:])):
263
+ (preh, prew), predepth, (prex, prey) = shape_list[ii], depth_list[ii], pos_list[ii]
264
+
265
+ isfg = d < predepth
266
+ y = _sample_y(h)
267
+ x = prex+prew
268
+ if isfg:
269
+ min_x = max_x = x
270
+ if pre_overlap < max_overlap:
271
+ min_x -= (max_overlap - pre_overlap) * prew
272
+ min_x = int(min_x)
273
+ if pre_overlap < min_overlap:
274
+ max_x -= (min_overlap - pre_overlap) * prew
275
+ max_x = int(max_x)
276
+ x = random.randint(min_x, max_x)
277
+ pre_overlap = 0
278
+ else:
279
+ overlap = random.uniform(min_overlap, max_overlap)
280
+ x -= int(overlap * w)
281
+ area = h * w
282
+ overlap_area = bbox_overlap_area([x, y, w, h], [prex, prey, prew, preh])
283
+ pre_overlap = overlap_area / area
284
+
285
+ pos_list.append([x, y])
286
+
287
+ pos_list = np.array(pos_list)
288
+ last_x2 = pos_list[-1][0] + shape_list[-1][1]
289
+ valid_shiftx = tgt_size - last_x2
290
+ if valid_shiftx > 0:
291
+ shiftx = random.randint(0, valid_shiftx)
292
+ pos_list[:, 0] += shiftx
293
+ else:
294
+ pos_list[:, 0] += valid_shiftx // 2
295
+
296
+ for pos, fg_dict, depth in zip(pos_list, fg_list, depth_list):
297
+ fg_dict['pos'] = pos
298
+ fg_dict['depth'] = depth
299
+ fg_list.sort(key=lambda x: x['depth'], reverse=True)
300
+
301
+
302
+
303
+ def regular_paste(fg_list, bg: Image, regen_bboxes=False):
304
+ segments_info = []
305
+ im_h, im_w = bg.height, bg.width
306
+
307
+ if regen_bboxes:
308
+ random.shuffle(fg_list)
309
+ gen_fg_regbboxes(fg_list, im_h)
310
+
311
+ segments = np.zeros((im_h, im_w, 3), np.uint8)
312
+ for ii, fg_dict in enumerate(fg_list):
313
+ fg = fg_dict['image']
314
+
315
+ px, py = fg_dict.pop('pos')
316
+ fg_pil = Image.fromarray(fg)
317
+
318
+ seg_color = COLOR_PALETTE[ii]
319
+ area, bbox, xyxy = paste_one_fg(fg_pil, bg, segments, px,py, seg_color, cal_area=True)
320
+
321
+ segments_info.append({
322
+ 'id': rgb2id(seg_color),
323
+ 'bbox': bbox,
324
+ 'area': area
325
+ })
326
+
327
+ return segments_info, segments
animeinsseg/data/sampler.py ADDED
@@ -0,0 +1,226 @@
1
+ import numpy as np
2
+ from random import choice as rchoice
3
+ from random import randint
4
+ import random
5
+ import cv2, traceback, imageio
6
+ import os.path as osp
7
+
8
+ from typing import Optional, List, Union, Tuple, Dict
9
+ from utils.io_utils import imread_nogrey_rgb, json2dict
10
+ from .transforms import rotate_image
11
+ from utils.logger import LOGGER
12
+
13
+
14
+ class NameSampler:
15
+
16
+ def __init__(self, name_prob_dict, sample_num=2048) -> None:
17
+ self.name_prob_dict = name_prob_dict
18
+ self._id2name = list(name_prob_dict.keys())
19
+ self.sample_ids = []
20
+
21
+ total_prob = 0.
22
+ for ii, (_, prob) in enumerate(name_prob_dict.items()):
23
+ tgt_num = int(prob * sample_num)
24
+ total_prob += prob
25
+ if tgt_num > 0:
26
+ self.sample_ids += [ii] * tgt_num
27
+
28
+ nsamples = len(self.sample_ids)
29
+ assert total_prob <= 1
30
+ if total_prob < 1 and nsamples < sample_num:
31
+ self.sample_ids += [len(self._id2name)] * (sample_num - nsamples)
32
+ self._id2name.append('_')
33
+
34
+ def sample(self) -> str:
35
+ return self._id2name[rchoice(self.sample_ids)]
36
+
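+ # Minimal usage sketch (illustrative only): probabilities that do not sum to 1 leave the
+ # remaining mass on the reserved name '_'.
+ #   sampler = NameSampler({'hist_match': 0.2, 'quantize': 0.25})
+ #   method = sampler.sample()   # 'hist_match', 'quantize', or '_' (no-op)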
37
+
38
+ class PossionSampler:
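+ # Pre-draws 1024 Poisson(lam) samples; values outside [min_val, max_val] are re-rolled
+ # uniformly in [min_val, max_val), so sample() is a cheap choice from the cached pool.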
39
+ def __init__(self, lam=3, min_val=1, max_val=8) -> None:
40
+ self._distr = np.random.poisson(lam, 1024)
41
+ invalid = np.where(np.logical_or(self._distr<min_val, self._distr > max_val))
42
+ self._distr[invalid] = np.random.randint(min_val, max_val, len(invalid[0]))
43
+
44
+ def sample(self) -> int:
45
+ return rchoice(self._distr)
46
+
47
+
48
+ class NormalSampler:
49
+ def __init__(self, loc=0.33, std=0.2, min_scale=0.15, max_scale=0.85, scalar=1, to_int = True):
50
+ s = np.random.normal(loc, std, 4096)
51
+ valid = np.where(np.logical_and(s>min_scale, s<max_scale))
52
+ self._distr = s[valid] * scalar
53
+ if to_int:
54
+ self._distr = self._distr.astype(np.int32)
55
+
56
+ def sample(self) -> int:
57
+ return rchoice(self._distr)
58
+
59
+
60
+ class PersonBBoxSampler:
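+ # Samples person-bbox layouts recorded from COCO images and fits anime foregrounds into them,
+ # so synthesized scenes inherit natural person placement statistics.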
61
+
62
+ def __init__(self, sample_path: Union[str, List]='data/cocoperson_bbox_samples.json', fg_info_list: List = None, fg_transform=None, is_train=True) -> None:
63
+ if isinstance(sample_path, str):
64
+ sample_path = [sample_path]
65
+ self.bbox_list = []
66
+ for sp in sample_path:
67
+ bboxlist = json2dict(sp)
68
+ for bboxes in bboxlist:
69
+ if isinstance(bboxes, dict):
70
+ bboxes = bboxes['bboxes']
71
+ bboxes = np.array(bboxes)
72
+ bboxes[:, [0, 1]] -= bboxes[:, [0, 1]].min(axis=0)
73
+ self.bbox_list.append(bboxes)
74
+
75
+ self.fg_info_list = fg_info_list
76
+ self.fg_transform = fg_transform
77
+ self.is_train = is_train
78
+
79
+ def sample(self, tgt_size: int, scale_range=(1, 1), size_thres=(0.02, 0.85)) -> List[np.ndarray]:
80
+ bboxes_normalized = rchoice(self.bbox_list)
81
+ if scale_range[0] != 1 or scale_range[1] != 1:
82
+ bbox_scale = random.uniform(scale_range[0], scale_range[1])
83
+ else:
84
+ bbox_scale = 1
85
+ bboxes = (bboxes_normalized * tgt_size * bbox_scale).astype(np.int32)
86
+
87
+ xyxy_array = np.copy(bboxes)
88
+ xyxy_array[:, [2, 3]] += xyxy_array[:, [0, 1]]
89
+ x_max, y_max = xyxy_array[:, 2].max(), xyxy_array[:, 3].max()
90
+
91
+ x_shift = tgt_size - x_max
92
+ x_shift = randint(0, x_shift) if x_shift > 0 else 0
93
+ y_shift = tgt_size - y_max
94
+ y_shift = randint(0, y_shift) if y_shift > 0 else 0
95
+
96
+ bboxes[:, [0, 1]] += [x_shift, y_shift]
97
+ valid_bboxes = []
98
+ max_size = size_thres[1] * tgt_size
99
+ min_size = size_thres[0] * tgt_size
100
+ for bbox in bboxes:
101
+ w = min(bbox[2], tgt_size - bbox[0])
102
+ h = min(bbox[3], tgt_size - bbox[1])
103
+ if max(h, w) < max_size and min(h, w) > min_size:
104
+ valid_bboxes.append(bbox)
105
+ return valid_bboxes
106
+
107
+ def sample_matchfg(self, tgt_size: int):
108
+ while True:
109
+ bboxes = self.sample(tgt_size, (1.1, 1.8))
110
+ if len(bboxes) > 0:
111
+ break
112
+ MIN_FG_SIZE = 20
113
+ num_fg = len(bboxes)
114
+ rotate = 20 if self.is_train else 15
115
+ fgs = random_load_nfg(num_fg, self.fg_info_list, random_rotate_prob=0.33, random_rotate=rotate)
116
+ assert len(fgs) == num_fg
117
+
118
+ bboxes.sort(key=lambda x: x[2] / x[3])
119
+ fgs.sort(key=lambda x: x['asp_ratio'])
120
+
121
+ for fg, bbox in zip(fgs, bboxes):
122
+ x, y, w, h = bbox
123
+ img = fg['image']
124
+ im_h, im_w = img.shape[:2]
125
+ if im_h < h and im_w < w:
126
+ scale = min(h / im_h, w / im_w)
127
+ new_h, new_w = int(scale * im_h), int(scale * im_w)
128
+ img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
129
+ else:
130
+ scale_h, scale_w = min(1, h / im_h), min(1, w / im_w)
131
+ scale = (scale_h + scale_w) / 2
132
+ if scale < 1:
133
+ new_h, new_w = max(int(scale * im_h), MIN_FG_SIZE), max(int(scale * im_w), MIN_FG_SIZE)
134
+ img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
135
+
136
+ if self.fg_transform is not None:
137
+ img = self.fg_transform(image=img)['image']
138
+
139
+ im_h, im_w = img.shape[:2]
140
+ fg['image'] = img
141
+ px = int(x + w / 2 - im_w / 2)
142
+ py = int(y + h / 2 - im_h / 2)
143
+ fg['pos'] = (px, py)
144
+
145
+ random.shuffle(fgs)
146
+
147
+ slist, llist = [], []
148
+ large_size = int(tgt_size * 0.55)
149
+ for fg in fgs:
150
+ if max(fg['image'].shape[:2]) > large_size:
151
+ llist.append(fg)
152
+ else:
153
+ slist.append(fg)
154
+ return llist + slist
155
+
156
+
157
+ def random_load_nfg(num_fg: int, fg_info_list: List[Union[Dict, str]], random_rotate=0, random_rotate_prob=0.):
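+ # Load num_fg random foregrounds (optionally rotated); after each load there is a small chance
+ # (0.12, repeatable) that the same crop is appended again to mimic duplicated characters.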
158
+ fgs = []
159
+ while len(fgs) < num_fg:
160
+ fg, fginfo = random_load_valid_fg(fg_info_list)
161
+ if random.random() < random_rotate_prob:
162
+ rotate_deg = randint(-random_rotate, random_rotate)
163
+ fg = rotate_image(fg, rotate_deg, alpha_crop=True)
164
+
165
+ asp_ratio = fg.shape[1] / fg.shape[0]
166
+ fgs.append({'image': fg, 'asp_ratio': asp_ratio, 'fginfo': fginfo})
167
+ while len(fgs) < num_fg and random.random() < 0.12:
168
+ fgs.append({'image': fg, 'asp_ratio': asp_ratio, 'fginfo': fginfo})
169
+
170
+ return fgs
171
+
172
+
173
+ def random_load_valid_fg(fg_info_list: List[Union[Dict, str]]) -> Tuple[np.ndarray, Dict]:
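+ # Keep picking until a readable RGBA foreground is found; entries that fail to load or lack an
+ # alpha channel are dropped from fg_info_list. The alpha channel is blurred and thresholded to
+ # find a tight bounding box, which is cached in fginfo['xyxy'] and used to crop the sprite.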
174
+ while True:
175
+ item = fginfo = rchoice(fg_info_list)
176
+
177
+ file_path = fginfo['file_path']
178
+ if 'root_dir' in fginfo and fginfo['root_dir']:
179
+ file_path = osp.join(fginfo['root_dir'], file_path)
180
+
181
+ try:
182
+ fg = imageio.imread(file_path)
183
+ except:
184
+ LOGGER.error(traceback.format_exc())
185
+ LOGGER.error(f'invalid fg: {file_path}')
186
+ fg_info_list.remove(item)
187
+ continue
188
+
189
+ c = 1
190
+ if len(fg.shape) == 3:
191
+ c = fg.shape[-1]
192
+ if c != 4:
193
+ LOGGER.warning(f'fg {file_path} does not have an alpha channel')
194
+ fg_info_list.remove(item)
195
+ else:
196
+ if 'xyxy' in fginfo:
197
+ x1, y1, x2, y2 = fginfo['xyxy']
198
+ else:
199
+ oh, ow = fg.shape[:2]
200
+ ksize = 5
201
+ mask = cv2.blur(fg[..., 3], (ksize,ksize))
202
+ _, mask = cv2.threshold(mask, 20, 255, cv2.THRESH_BINARY)
203
+
204
+ x1, y1, w, h = cv2.boundingRect(cv2.findNonZero(mask))
205
+ x2, y2 = x1 + w, y1 + h
206
+ if oh - h > 15 or ow - w > 15:
207
+ crop = True
208
+ else:
209
+ x1 = y1 = 0
210
+ x2, y2 = ow, oh
211
+
212
+ fginfo['xyxy'] = [x1, y1, x2, y2]
213
+ fg = fg[y1: y2, x1: x2]
214
+ return fg, fginfo
215
+
216
+
217
+ def random_load_valid_bg(bg_list: List[str]) -> np.ndarray:
218
+ while True:
219
+ try:
220
+ bgp = rchoice(bg_list)
221
+ return imread_nogrey_rgb(bgp)
222
+ except:
223
+ LOGGER.error(traceback.format_exc())
224
+ LOGGER.error(f'invalid bg: {bgp}')
225
+ bg_list.remove(bgp)
226
+ continue
animeinsseg/data/syndataset.py ADDED
@@ -0,0 +1,213 @@
1
+ import numpy as np
2
+ from typing import List, Union, Tuple, Dict
3
+ import random
4
+ from PIL import Image
5
+ import cv2
6
+ import imageio, os
7
+ import os.path as osp
8
+ from tqdm import tqdm
9
+ from panopticapi.utils import rgb2id
10
+ import traceback
11
+
12
+ from utils.io_utils import mask2rle, dict2json, fgbg_hist_matching
13
+ from utils.logger import LOGGER
14
+ from utils.constants import CATEGORIES, IMAGE_ID_ZFILL
15
+ from .transforms import get_fg_transforms, get_bg_transforms, quantize_image, resize2height, rotate_image
16
+ from .sampler import random_load_valid_bg, random_load_valid_fg, NameSampler, NormalSampler, PossionSampler, PersonBBoxSampler
17
+ from .paste_methods import regular_paste, partition_paste
18
+
19
+
20
+ def syn_animecoco_dataset(
21
+ bg_list: List, fg_info_list: List[Dict], dataset_save_dir: str, policy: str='train',
22
+ tgt_size=640, syn_num_multiplier=2.5, regular_paste_prob=0.4, person_paste_prob=0.4,
23
+ max_syn_num=-1, image_id_start=0, obj_id_start=0, hist_match_prob=0.2, quantize_prob=0.25):
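+ # Copy-paste synthesis loop: pick a paste strategy (regular row layout, COCO person-bbox layout,
+ # or partition-tree layout), composite random foregrounds onto a random background, optionally
+ # histogram-match or colour-quantize, then emit COCO-style RLE annotations per visible instance.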
24
+
25
+ LOGGER.info(f'syn data policy: {policy}')
26
+ LOGGER.info(f'background: {len(bg_list)} foreground: {len(fg_info_list)}')
27
+
28
+ numfg_sampler = PossionSampler(min_val=1, max_val=9, lam=2.5)
29
+ numfg_regpaste_sampler = PossionSampler(min_val=2, max_val=9, lam=3.5)
30
+ regpaste_size_sampler = NormalSampler(scalar=tgt_size, to_int=True, max_scale=0.75)
31
+ color_correction_sampler = NameSampler({'hist_match': hist_match_prob, 'quantize': quantize_prob}, )
32
+ paste_method_sampler = NameSampler({'regular': regular_paste_prob, 'personbbox': person_paste_prob,
33
+ 'partition': 1-regular_paste_prob-person_paste_prob})
34
+
35
+ fg_transform = get_fg_transforms(tgt_size, transform_variant=policy)
36
+ fg_distort_transform = get_fg_transforms(tgt_size, transform_variant='distort_only')
37
+ bg_transform = get_bg_transforms('train', tgt_size)
38
+
39
+ image_id = image_id_start + 1
40
+ obj_id = obj_id_start + 1
41
+
42
+ det_annotations, image_meta = [], []
43
+
44
+ syn_num = int(syn_num_multiplier * len(fg_info_list))
45
+ if max_syn_num > 0:
46
+ syn_num = max_syn_num
47
+
48
+ ann_save_dir = osp.join(dataset_save_dir, 'annotations')
49
+ image_save_dir = osp.join(dataset_save_dir, policy)
50
+
51
+ if not osp.exists(image_save_dir):
52
+ os.makedirs(image_save_dir)
53
+ if not osp.exists(ann_save_dir):
54
+ os.makedirs(ann_save_dir)
55
+
56
+ is_train = policy == 'train'
57
+ if is_train:
58
+ jpg_save_quality = [75, 85, 95]
59
+ else:
60
+ jpg_save_quality = [95]
61
+
62
+ if isinstance(fg_info_list[0], str):
63
+ for ii, fgp in enumerate(fg_info_list):
64
+ if isinstance(fgp, str):
65
+ fg_info_list[ii] = {'file_path': fgp, 'tag_string': [], 'tag_string_character': [], 'danbooru': False, 'category_id': 0}
66
+
67
+ if person_paste_prob > 0:
68
+ personbbox_sampler = PersonBBoxSampler(
69
+ 'data/cocoperson_bbox_samples.json', fg_info_list,
70
+ fg_transform=fg_distort_transform if is_train else None, is_train=is_train)
71
+
72
+ total = tqdm(range(syn_num))
73
+ for fin in total:
74
+ try:
75
+ paste_method = paste_method_sampler.sample()
76
+
77
+ fgs = []
78
+ if paste_method == 'regular':
79
+ num_fg = numfg_regpaste_sampler.sample()
80
+ size = regpaste_size_sampler.sample()
81
+ while len(fgs) < num_fg:
82
+ tgt_height = int(random.uniform(0.7, 1.2) * size)
83
+ fg, fginfo = random_load_valid_fg(fg_info_list)
84
+ fg = resize2height(fg, tgt_height)
85
+ if is_train:
86
+ fg = fg_distort_transform(image=fg)['image']
87
+ rotate_deg = random.randint(-40, 40)
88
+ else:
89
+ rotate_deg = random.randint(-30, 30)
90
+ if random.random() < 0.3:
91
+ fg = rotate_image(fg, rotate_deg, alpha_crop=True)
92
+ fgs.append({'image': fg, 'fginfo': fginfo})
93
+ while len(fgs) < num_fg and random.random() < 0.15:
94
+ fgs.append({'image': fg, 'fginfo': fginfo})
95
+ elif paste_method == 'personbbox':
96
+ fgs = personbbox_sampler.sample_matchfg(tgt_size)
97
+ else:
98
+ num_fg = numfg_sampler.sample()
99
+ fgs = []
100
+ for ii in range(num_fg):
101
+ fg, fginfo = random_load_valid_fg(fg_info_list)
102
+ fg = fg_transform(image=fg)['image']
103
+ h, w = fg.shape[:2]
104
+ if num_fg > 6:
105
+ downscale = min(tgt_size / 2.5 / w, tgt_size / 2.5 / h)
106
+ if downscale < 1:
107
+ fg = cv2.resize(fg, (int(w * downscale), int(h * downscale)), interpolation=cv2.INTER_AREA)
108
+ fgs.append({'image': fg, 'fginfo': fginfo})
109
+
110
+ bg = random_load_valid_bg(bg_list)
111
+ bg = bg_transform(image=bg)['image']
112
+
113
+ color_correct = color_correction_sampler.sample()
114
+
115
+ if color_correct == 'hist_match':
116
+ fgbg_hist_matching(fgs, bg)
117
+
118
+ bg: Image = Image.fromarray(bg)
119
+
120
+ if paste_method == 'regular':
121
+ segments_info, segments = regular_paste(fgs, bg, regen_bboxes=True)
122
+ elif paste_method == 'personbbox':
123
+ segments_info, segments = regular_paste(fgs, bg, regen_bboxes=False)
124
+ elif paste_method == 'partition':
125
+ segments_info, segments = partition_paste(fgs, bg, )
126
+ else:
127
+ print(f'invalid paste method: {paste_method}')
128
+ raise NotImplementedError
129
+
130
+ image = np.array(bg)
131
+ if color_correct == 'quantize':
132
+ mask = cv2.inRange(segments, np.array([0,0,0]), np.array([0,0,0]))
133
+ # cv2.imshow("mask", mask)
134
+ image = quantize_image(image, random.choice([12, 16, 32]), 'kmeans', mask=mask)[0]
135
+
136
+ # postprocess & check if instance is valid
137
+ for ii, segi in enumerate(segments_info):
138
+ if segi['area'] == 0:
139
+ continue
140
+ x, y, w, h = segi['bbox']
141
+ x2, y2 = x+w, y+h
142
+ c = segments[y: y2, x: x2]
143
+ pan_png = rgb2id(c)
144
+ cmask = (pan_png == segi['id'])
145
+ area = cmask.sum()
146
+
147
+ if paste_method != 'partition' and \
148
+ area / (fgs[ii]['image'][..., 3] > 30).sum() < 0.25:
149
+ # cv2.imshow('im', fgs[ii]['image'])
150
+ # cv2.imshow('mask', fgs[ii]['image'][..., 3])
151
+ # cv2.imshow('seg', segments)
152
+ # cv2.waitKey(0)
153
+ cmask_ids = np.where(cmask)
154
+ segments[y: y2, x: x2][cmask_ids] = 0
155
+ image[y: y2, x: x2][cmask_ids] = (127, 127, 127)
156
+ continue
157
+
158
+ cmask = cmask.astype(np.uint8) * 255
159
+ dx, dy, w, h = cv2.boundingRect(cv2.findNonZero(cmask))
160
+ _bbox = [dx + x, dy + y, w, h]
161
+
162
+ seg = cv2.copyMakeBorder(cmask, y, tgt_size-y2, x, tgt_size-x2, cv2.BORDER_CONSTANT) > 0
163
+ assert seg.shape[0] == tgt_size and seg.shape[1] == tgt_size
164
+ segmentation = mask2rle(seg)
165
+
166
+ det_annotations.append({
167
+ 'id': obj_id,
168
+ 'category_id': fgs[ii]['fginfo']['category_id'],
169
+ 'iscrowd': 0,
170
+ 'segmentation': segmentation,
171
+ 'image_id': image_id,
172
+ 'area': area,
173
+ 'tag_string': fgs[ii]['fginfo']['tag_string'],
174
+ 'tag_string_character': fgs[ii]['fginfo']['tag_string_character'],
175
+ 'bbox': [float(c) for c in _bbox]
176
+ })
177
+
178
+ obj_id += 1
179
+ # cv2.imshow('c', cv2.cvtColor(c, cv2.COLOR_RGB2BGR))
180
+ # cv2.imshow('cmask', cmask)
181
+ # cv2.waitKey(0)
182
+
183
+ image_id_str = str(image_id).zfill(IMAGE_ID_ZFILL)
184
+ image_file_name = image_id_str + '.jpg'
185
+ image_meta.append({
186
+ "id": image_id,"height": tgt_size,"width": tgt_size, "file_name": image_file_name, "id": image_id
187
+ })
188
+
189
+ # LOGGER.info(f'paste method: {paste_method} color correct: {color_correct}')
190
+ # cv2.imshow('image', cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
191
+ # cv2.imshow('segments', cv2.cvtColor(segments, cv2.COLOR_RGB2BGR))
192
+ # cv2.waitKey(0)
193
+
194
+ imageio.imwrite(osp.join(image_save_dir, image_file_name), image, quality=random.choice(jpg_save_quality))
195
+ image_id += 1
196
+
197
+ except:
198
+ LOGGER.error(traceback.format_exc())
199
+ continue
200
+
201
+ det_meta = {
202
+ "info": {},
203
+ "licenses": [],
204
+ "images": image_meta,
205
+ "annotations": det_annotations,
206
+ "categories": CATEGORIES
207
+ }
208
+
209
+ detp = osp.join(ann_save_dir, f'det_{policy}.json')
210
+ dict2json(det_meta, detp)
211
+ LOGGER.info(f'annotations saved to {detp}')
212
+
213
+ return image_id, obj_id
animeinsseg/data/transforms.py ADDED
@@ -0,0 +1,299 @@
1
+ import albumentations as A
2
+ from albumentations import DualIAATransform, to_tuple
3
+ import imgaug.augmenters as iaa
4
+ import cv2
5
+ from tqdm import tqdm
6
+ from sklearn.cluster import KMeans
7
+ from sklearn.metrics import pairwise_distances_argmin
8
+ from sklearn.utils import shuffle
9
+ import numpy as np
10
+
11
+ class IAAAffine2(DualIAATransform):
12
+ """Place a regular grid of points on the input and randomly move the neighbourhood of these point around
13
+ via affine transformations.
14
+ Note: This class introduce interpolation artifacts to mask if it has values other than {0;1}
15
+ Args:
16
+ p (float): probability of applying the transform. Default: 0.5.
17
+ Targets:
18
+ image, mask
19
+ """
20
+
21
+ def __init__(
22
+ self,
23
+ scale=(0.7, 1.3),
24
+ translate_percent=None,
25
+ translate_px=None,
26
+ rotate=0.0,
27
+ shear=(-0.1, 0.1),
28
+ order=1,
29
+ cval=0,
30
+ mode="reflect",
31
+ always_apply=False,
32
+ p=0.5,
33
+ ):
34
+ super(IAAAffine2, self).__init__(always_apply, p)
35
+ self.scale = dict(x=scale, y=scale)
36
+ self.translate_percent = to_tuple(translate_percent, 0)
37
+ self.translate_px = to_tuple(translate_px, 0)
38
+ self.rotate = to_tuple(rotate)
39
+ self.shear = dict(x=shear, y=shear)
40
+ self.order = order
41
+ self.cval = cval
42
+ self.mode = mode
43
+
44
+ @property
45
+ def processor(self):
46
+ return iaa.Affine(
47
+ self.scale,
48
+ self.translate_percent,
49
+ self.translate_px,
50
+ self.rotate,
51
+ self.shear,
52
+ self.order,
53
+ self.cval,
54
+ self.mode,
55
+ )
56
+
57
+ def get_transform_init_args_names(self):
58
+ return ("scale", "translate_percent", "translate_px", "rotate", "shear", "order", "cval", "mode")
59
+
60
+
61
+ class IAAPerspective2(DualIAATransform):
62
+ """Perform a random four point perspective transform of the input.
63
+ Note: This class introduces interpolation artifacts to the mask if it has values other than {0;1}
64
+ Args:
65
+ scale ((float, float)): standard deviation of the normal distributions. These are used to sample
66
+ the random distances of the subimage's corners from the full image's corners. Default: (0.05, 0.1).
67
+ p (float): probability of applying the transform. Default: 0.5.
68
+ Targets:
69
+ image, mask
70
+ """
71
+
72
+ def __init__(self, scale=(0.05, 0.1), keep_size=True, always_apply=False, p=0.5,
73
+ order=1, cval=0, mode="replicate"):
74
+ super(IAAPerspective2, self).__init__(always_apply, p)
75
+ self.scale = to_tuple(scale, 1.0)
76
+ self.keep_size = keep_size
77
+ self.cval = cval
78
+ self.mode = mode
79
+
80
+ @property
81
+ def processor(self):
82
+ return iaa.PerspectiveTransform(self.scale, keep_size=self.keep_size, mode=self.mode, cval=self.cval)
83
+
84
+ def get_transform_init_args_names(self):
85
+ return ("scale", "keep_size")
86
+
87
+
88
+ def get_bg_transforms(transform_variant, out_size):
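+ # Background pipeline: for training, resize so the shorter side is 1.2x the target, then
+ # RandomResizedCrop to out_size x out_size; otherwise a plain resize plus random crop.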
89
+ max_size = int(out_size * 1.2)
90
+ if transform_variant == 'train':
91
+ transform = [
92
+ A.SmallestMaxSize(max_size, always_apply=True, interpolation=cv2.INTER_AREA),
93
+ A.RandomResizedCrop(out_size, out_size, scale=(0.9, 1.5), p=1, ratio=(0.9, 1.1)),
94
+ ]
95
+ else:
96
+ transform = [
97
+ A.SmallestMaxSize(out_size, always_apply=True),
98
+ A.RandomCrop(out_size, out_size, True),
99
+ ]
100
+ return A.Compose(transform)
101
+
102
+
103
+ def get_fg_transforms(out_size, scale_limit=(-0.85, -0.3), transform_variant='train'):
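+ # Foreground pipeline: 'train' resizes to at most out_size, randomly downscales, then applies
+ # mild affine/perspective, horizontal flip, elastic and grid distortions; 'distort_only' keeps
+ # the distortions but skips the resizing; any other variant only resizes/downscales.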
104
+ if transform_variant == 'train':
105
+ transform = [
106
+ A.LongestMaxSize(out_size),
107
+ A.RandomScale(scale_limit=scale_limit, always_apply=True, interpolation=cv2.INTER_AREA),
108
+ IAAAffine2(scale=(1, 1),
109
+ rotate=(-15, 15),
110
+ shear=(-0.1, 0.1), p=0.3, mode='constant'),
111
+ IAAPerspective2(scale=(0.0, 0.06), p=0.3, mode='constant'),
112
+ A.HorizontalFlip(),
113
+ A.ElasticTransform(alpha=0.3, sigma=15, alpha_affine=15, border_mode=cv2.BORDER_CONSTANT, p=0.3),
114
+ A.GridDistortion(border_mode=cv2.BORDER_CONSTANT, p=0.3)
115
+ ]
116
+ elif transform_variant == 'distort_only':
117
+ transform = [
118
+ IAAAffine2(scale=(1, 1),
119
+ shear=(-0.1, 0.1), p=0.3, mode='constant'),
120
+ IAAPerspective2(scale=(0.0, 0.06), p=0.3, mode='constant'),
121
+ A.HorizontalFlip(),
122
+ A.ElasticTransform(alpha=0.3, sigma=15, alpha_affine=15, border_mode=cv2.BORDER_CONSTANT, p=0.3),
123
+ A.GridDistortion(border_mode=cv2.BORDER_CONSTANT, p=0.3)
124
+ ]
125
+ else:
126
+ transform = [
127
+ A.LongestMaxSize(out_size),
128
+ A.RandomScale(scale_limit=scale_limit, always_apply=True, interpolation=cv2.INTER_LINEAR)
129
+ ]
130
+ return A.Compose(transform)
131
+
132
+
133
+ def get_transforms(transform_variant, out_size, to_float=True):
134
+ if transform_variant == 'distortions':
135
+ transform = [
136
+ IAAAffine2(scale=(1, 1.3),
137
+ rotate=(-20, 20),
138
+ shear=(-0.1, 0.1), p=1, mode='constant'),
139
+ IAAPerspective2(scale=(0.0, 0.06), p=0.3, mode='constant'),
140
+ A.OpticalDistortion(),
141
+ A.HorizontalFlip(),
142
+ A.Sharpen(p=0.3),
143
+ A.CLAHE(),
144
+ A.GaussNoise(p=0.3),
145
+ A.Posterize(),
146
+ A.ElasticTransform(alpha=0.3, sigma=15, alpha_affine=15, border_mode=cv2.BORDER_CONSTANT),
147
+ ]
148
+ elif transform_variant == 'default':
149
+ transform = [
150
+ A.HorizontalFlip(),
151
+ A.Rotate(20, p=0.3)
152
+ ]
153
+ elif transform_variant == 'identity':
154
+ transform = []
155
+ else:
156
+ raise ValueError(f'Unexpected transform_variant {transform_variant}')
157
+ if to_float:
158
+ transform.append(A.ToFloat())
159
+ return A.Compose(transform)
160
+
161
+
162
+ def get_template_transforms(transform_variant, out_size, to_float=True):
163
+ if transform_variant == 'distortions':
164
+ transform = [
165
+ A.Cutout(p=0.3, max_w_size=30, max_h_size=30, num_holes=1),
166
+ IAAAffine2(scale=(1, 1.3),
167
+ rotate=(-20, 20),
168
+ shear=(-0.1, 0.1), p=1, mode='constant'),
169
+ IAAPerspective2(scale=(0.0, 0.06), p=0.3, mode='constant'),
170
+ A.OpticalDistortion(),
171
+ A.HorizontalFlip(),
172
+ A.Sharpen(p=0.3),
173
+ A.CLAHE(),
174
+ A.GaussNoise(p=0.3),
175
+ A.Posterize(),
176
+ A.ElasticTransform(alpha=0.3, sigma=15, alpha_affine=15, border_mode=cv2.BORDER_CONSTANT),
177
+ ]
178
+ elif transform_variant == 'identity':
179
+ transform = []
180
+ else:
181
+ raise ValueError(f'Unexpected transform_variant {transform_variant}')
182
+ if to_float:
183
+ transform.append(A.ToFloat())
184
+ return A.Compose(transform)
185
+
186
+
187
+ def rotate_image(mat: np.ndarray, angle: float, alpha_crop: bool = False) -> np.ndarray:
188
+ """
189
+ Rotates an image (angle in degrees) and expands image to avoid cropping
190
+ # https://stackoverflow.com/questions/43892506/opencv-python-rotate-image-without-cropping-sides
191
+ """
192
+
193
+ height, width = mat.shape[:2] # image shape has 3 dimensions
194
+ image_center = (width/2, height/2) # getRotationMatrix2D needs coordinates in reverse order (width, height) compared to shape
195
+
196
+ rotation_mat = cv2.getRotationMatrix2D(image_center, angle, 1.)
197
+
198
+ # rotation calculates the cos and sin, taking absolutes of those.
199
+ abs_cos = abs(rotation_mat[0,0])
200
+ abs_sin = abs(rotation_mat[0,1])
201
+
202
+ # find the new width and height bounds
203
+ bound_w = int(height * abs_sin + width * abs_cos)
204
+ bound_h = int(height * abs_cos + width * abs_sin)
205
+
206
+ # subtract the old image center (bringing the image back to the origin) and add the new image center coordinates
207
+ rotation_mat[0, 2] += bound_w/2 - image_center[0]
208
+ rotation_mat[1, 2] += bound_h/2 - image_center[1]
209
+
210
+ # rotate image with the new bounds and translated rotation matrix
211
+ rotated_mat = cv2.warpAffine(mat, rotation_mat, (bound_w, bound_h))
212
+
213
+ if alpha_crop and len(rotated_mat.shape) == 3 and rotated_mat.shape[-1] == 4:
214
+ x, y, w, h = cv2.boundingRect(rotated_mat[..., -1])
215
+ rotated_mat = rotated_mat[y: y+h, x: x+w]
216
+
217
+ return rotated_mat
218
+
219
+
220
+ def recreate_image(codebook, labels, w, h):
221
+ """Recreate the (compressed) image from the code book & labels"""
222
+ return (codebook[labels].reshape(w, h, -1) * 255).astype(np.uint8)
223
+
224
+ def quantize_image(image: np.ndarray, n_colors: int, method='kmeans', mask=None):
225
+ # https://scikit-learn.org/stable/auto_examples/cluster/plot_color_quantization.html
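+ # When a mask is given (non-zero = background), KMeans is fitted on a sample biased toward
+ # foreground pixels so instance colours dominate the palette; otherwise 2048 random pixels are used.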
226
+ image = np.array(image, dtype=np.float64) / 255
227
+
228
+ if len(image.shape) == 3:
229
+ w, h, d = tuple(image.shape)
230
+ else:
231
+ w, h = image.shape
232
+ d = 1
233
+
234
+ # assert d == 3
235
+ image_array = image.reshape(-1, d)
236
+
237
+ if method == 'kmeans':
238
+
239
+ image_array_sample = None
240
+ if mask is not None:
241
+ ids = np.where(mask)
242
+ if len(ids[0]) > 10:
243
+ bg = image[ids][::2]
244
+ fg = image[np.where(mask == 0)]
245
+ max_bg_num = int(fg.shape[0] * 1.5)
246
+ if bg.shape[0] > max_bg_num:
247
+ bg = shuffle(bg, random_state=0, n_samples=max_bg_num)
248
+ image_array_sample = np.concatenate((fg, bg), axis=0)
249
+ if image_array_sample.shape[0] > 2048:
250
+ image_array_sample = shuffle(image_array_sample, random_state=0, n_samples=2048)
251
+ else:
252
+ image_array_sample = None
253
+
254
+ if image_array_sample is None:
255
+ image_array_sample = shuffle(image_array, random_state=0, n_samples=2048)
256
+
257
+ kmeans = KMeans(n_clusters=n_colors, n_init=10, random_state=0).fit(
258
+ image_array_sample
259
+ )
260
+
261
+ labels = kmeans.predict(image_array)
262
+ quantized = recreate_image(kmeans.cluster_centers_, labels, w, h)
263
+ return quantized, kmeans.cluster_centers_, labels
264
+
265
+ else:
266
+
267
+ codebook_random = shuffle(image_array, random_state=0, n_samples=n_colors)
268
+ labels_random = pairwise_distances_argmin(codebook_random, image_array, axis=0)
269
+
270
+ return [recreate_image(codebook_random, labels_random, w, h)]
271
+
272
+
273
+ def resize2height(img: np.ndarray, height: int):
274
+ im_h, im_w = img.shape[:2]
275
+ if im_h > height:
276
+ interpolation = cv2.INTER_AREA
277
+ else:
278
+ interpolation = cv2.INTER_LINEAR
279
+ if im_h != height:
280
+ img = cv2.resize(img, (int(height / im_h * im_w), height), interpolation=interpolation)
281
+ return img
282
+
283
+ if __name__ == '__main__':
284
+ import os.path as osp
285
+
286
+ img_path = r'tmp\megumin.png'
287
+ save_dir = r'tmp'
288
+ sample_num = 24
289
+
290
+ tv = 'distortions'
291
+ out_size = 224
292
+ transforms = get_transforms(tv, out_size ,to_float=False)
293
+ img = cv2.imread(img_path)
294
+ for idx in tqdm(range(sample_num)):
295
+ transformed = transforms(image=img)['image']
296
+ print(transformed.shape)
297
+ cv2.imwrite(osp.join(save_dir, str(idx)+'-transform.jpg'), transformed)
298
+ # cv2.waitKey(0)
299
+ pass
animeinsseg/inpainting/__init__.py ADDED
File without changes
animeinsseg/inpainting/ldm_inpaint.py ADDED
@@ -0,0 +1,353 @@
1
+ import torch
2
+ import numpy as np
3
+ from tqdm import tqdm
4
+ from omegaconf import OmegaConf
5
+ import safetensors
6
+ import os
7
+ import einops
8
+ import cv2
9
+ from PIL import Image, ImageFilter, ImageOps
10
+ from utils.io_utils import resize_pad2divisior
11
+ import os
12
+ from utils.io_utils import submit_request, img2b64
13
+ import json
14
+ # Debug by Francis
15
+ # from ldm.util import instantiate_from_config
16
+ # from ldm.models.diffusion.ddpm import LatentDiffusion
17
+ # from ldm.models.diffusion.ddim import DDIMSampler
18
+ # from ldm.modules.diffusionmodules.util import noise_like
19
+ import io
20
+ import base64
21
+ from requests.auth import HTTPBasicAuth
22
+
23
+ # Debug by Francis
24
+ # def create_model(config_path):
25
+ # config = OmegaConf.load(config_path)
26
+ # model = instantiate_from_config(config.model).cpu()
27
+ # return model
28
+ #
29
+ # def get_state_dict(d):
30
+ # return d.get('state_dict', d)
31
+ #
32
+ # def load_state_dict(ckpt_path, location='cpu'):
33
+ # _, extension = os.path.splitext(ckpt_path)
34
+ # if extension.lower() == ".safetensors":
35
+ # import safetensors.torch
36
+ # state_dict = safetensors.torch.load_file(ckpt_path, device=location)
37
+ # else:
38
+ # state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
39
+ # state_dict = get_state_dict(state_dict)
40
+ # return state_dict
41
+ #
42
+ #
43
+ # def load_ldm_sd(model, path) :
44
+ # if path.endswith('.safetensor') :
45
+ # sd = safetensors.torch.load_file(path)
46
+ # else :
47
+ # sd = load_state_dict(path)
48
+ # model.load_state_dict(sd, strict = False)
49
+ #
50
+ # def fill_mask_input(image, mask):
51
+ # """fills masked regions with colors from image using blur. Not extremely effective."""
52
+ #
53
+ # image_mod = Image.new('RGBA', (image.width, image.height))
54
+ #
55
+ # image_masked = Image.new('RGBa', (image.width, image.height))
56
+ # image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert('L')))
57
+ #
58
+ # image_masked = image_masked.convert('RGBa')
59
+ #
60
+ # for radius, repeats in [(256, 1), (64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]:
61
+ # blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA')
62
+ # for _ in range(repeats):
63
+ # image_mod.alpha_composite(blurred)
64
+ #
65
+ # return image_mod.convert("RGB")
66
+ #
67
+ #
68
+ # def get_inpainting_image_condition(model, image, mask) :
69
+ # conditioning_mask = np.array(mask.convert("L"))
70
+ # conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
71
+ # conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
72
+ # conditioning_mask = torch.round(conditioning_mask)
73
+ # conditioning_mask = conditioning_mask.to(device=image.device, dtype=image.dtype)
74
+ # conditioning_image = torch.lerp(
75
+ # image,
76
+ # image * (1.0 - conditioning_mask),
77
+ # 1
78
+ # )
79
+ # conditioning_image = model.get_first_stage_encoding(model.encode_first_stage(conditioning_image))
80
+ # conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=conditioning_image.shape[-2:])
81
+ # conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
82
+ # image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
83
+ # return image_conditioning
84
+ #
85
+ #
86
+ # class GuidedLDM(LatentDiffusion):
87
+ # def __init__(self, *args, **kwargs):
88
+ # super().__init__(*args, **kwargs)
89
+ #
90
+ # @torch.no_grad()
91
+ # def img2img_inpaint(
92
+ # self,
93
+ # image: Image.Image,
94
+ # c_text: str,
95
+ # uc_text: str,
96
+ # mask: Image.Image,
97
+ # ddim_steps = 50,
98
+ # mask_blur: int = 0,
99
+ # use_cuda: bool = True,
100
+ # **kwargs) -> Image.Image :
101
+ # ddim_sampler = GuidedDDIMSample(self)
102
+ # if use_cuda :
103
+ # self.cond_stage_model.cuda()
104
+ # self.first_stage_model.cuda()
105
+ # c_text = self.get_learned_conditioning([c_text])
106
+ # uc_text = self.get_learned_conditioning([uc_text])
107
+ # cond = {"c_crossattn": [c_text]}
108
+ # uc_cond = {"c_crossattn": [uc_text]}
109
+ #
110
+ # if use_cuda :
111
+ # device = torch.device('cuda:0')
112
+ # else :
113
+ # device = torch.device('cpu')
114
+ #
115
+ # image_mask = mask
116
+ # image_mask = image_mask.convert('L')
117
+ # image_mask = image_mask.filter(ImageFilter.GaussianBlur(mask_blur))
118
+ # latent_mask = image_mask
119
+ # # image = fill_mask_input(image, latent_mask)
120
+ # # image.save('image_fill.png')
121
+ # image = np.array(image).astype(np.float32) / 127.5 - 1.0
122
+ # image = np.moveaxis(image, 2, 0)
123
+ # image = torch.from_numpy(image).to(device)[None]
124
+ # init_latent = self.get_first_stage_encoding(self.encode_first_stage(image))
125
+ # init_mask = latent_mask
126
+ # latmask = init_mask.convert('RGB').resize((init_latent.shape[3], init_latent.shape[2]))
127
+ # latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
128
+ # latmask = latmask[0]
129
+ # latmask = np.around(latmask)
130
+ # latmask = np.tile(latmask[None], (4, 1, 1))
131
+ # nmask = torch.asarray(latmask).to(init_latent.device).float()
132
+ # init_latent = (1 - nmask) * init_latent + nmask * torch.randn_like(init_latent)
133
+ #
134
+ # denoising_strength = 1
135
+ # if self.model.conditioning_key == 'hybrid' :
136
+ # image_cdt = get_inpainting_image_condition(self, image, image_mask)
137
+ # cond["c_concat"] = [image_cdt]
138
+ # uc_cond["c_concat"] = [image_cdt]
139
+ #
140
+ # steps = ddim_steps
141
+ # t_enc = int(min(denoising_strength, 0.999) * steps)
142
+ # eta = 0
143
+ #
144
+ # noise = torch.randn_like(init_latent)
145
+ # ddim_sampler.make_schedule(ddim_num_steps=steps, ddim_eta=eta, ddim_discretize="uniform", verbose=False)
146
+ # x1 = ddim_sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * int(init_latent.shape[0])).to(device), noise=noise)
147
+ #
148
+ # if use_cuda :
149
+ # self.cond_stage_model.cpu()
150
+ # self.first_stage_model.cpu()
151
+ #
152
+ # if use_cuda :
153
+ # self.model.cuda()
154
+ # decoded = ddim_sampler.decode(x1, cond,t_enc,init_latent=init_latent,nmask=nmask,unconditional_guidance_scale=7,unconditional_conditioning=uc_cond)
155
+ # if use_cuda :
156
+ # self.model.cpu()
157
+ #
158
+ # if mask is not None :
159
+ # decoded = init_latent * (1 - nmask) + decoded * nmask
160
+ #
161
+ # if use_cuda :
162
+ # self.first_stage_model.cuda()
163
+ # with torch.cuda.amp.autocast(enabled=False):
164
+ # x_samples = self.decode_first_stage(decoded.to(torch.float32))
165
+ # if use_cuda :
166
+ # self.first_stage_model.cpu()
167
+ # return torch.clip(x_samples, -1, 1)
168
+ #
169
+ #
170
+ #
171
+ # class GuidedDDIMSample(DDIMSampler) :
172
+ # def __init__(self, *args, **kwargs):
173
+ # super().__init__(*args, **kwargs)
174
+ #
175
+ # @torch.no_grad()
176
+ # def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
177
+ # temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
178
+ # unconditional_guidance_scale=1., unconditional_conditioning=None,
179
+ # dynamic_threshold=None):
180
+ # b, *_, device = *x.shape, x.device
181
+ #
182
+ # if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
183
+ # model_output = self.model.apply_model(x, t, c)
184
+ # else:
185
+ # x_in = torch.cat([x] * 2)
186
+ # t_in = torch.cat([t] * 2)
187
+ # if isinstance(c, dict):
188
+ # assert isinstance(unconditional_conditioning, dict)
189
+ # c_in = dict()
190
+ # for k in c:
191
+ # if isinstance(c[k], list):
192
+ # c_in[k] = [torch.cat([
193
+ # unconditional_conditioning[k][i],
194
+ # c[k][i]]) for i in range(len(c[k]))]
195
+ # else:
196
+ # c_in[k] = torch.cat([
197
+ # unconditional_conditioning[k],
198
+ # c[k]])
199
+ # elif isinstance(c, list):
200
+ # c_in = list()
201
+ # assert isinstance(unconditional_conditioning, list)
202
+ # for i in range(len(c)):
203
+ # c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
204
+ # else:
205
+ # c_in = torch.cat([unconditional_conditioning, c])
206
+ # model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
207
+ # model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
208
+ #
209
+ # e_t = model_output
210
+ #
211
+ # alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
212
+ # alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
213
+ # sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
214
+ # sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
215
+ # # select parameters corresponding to the currently considered timestep
216
+ # a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
217
+ # a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
218
+ # sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
219
+ # sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
220
+ #
221
+ # # current prediction for x_0
222
+ # pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
223
+ #
224
+ # # direction pointing to x_t
225
+ # dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
226
+ # noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
227
+ # if noise_dropout > 0.:
228
+ # noise = torch.nn.functional.dropout(noise, p=noise_dropout)
229
+ # x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
230
+ # return x_prev, pred_x0
231
+ #
232
+ # @torch.no_grad()
233
+ # def decode(self, x_latent, cond, t_start, init_latent=None, nmask=None, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
234
+ # use_original_steps=False, callback=None):
235
+ #
236
+ # timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
237
+ # total_steps = len(timesteps)
238
+ # timesteps = timesteps[:t_start]
239
+ #
240
+ # time_range = np.flip(timesteps)
241
+ # total_steps = timesteps.shape[0]
242
+ # print(f"Running Guided DDIM Sampling with {len(timesteps)} timesteps, t_start={t_start}")
243
+ # iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
244
+ # x_dec = x_latent
245
+ # for i, step in enumerate(iterator):
246
+ # p = (i + (total_steps - t_start) + 1) / (total_steps)
247
+ # index = total_steps - i - 1
248
+ # ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
249
+ # if nmask is not None :
250
+ # noised_input = self.model.q_sample(init_latent.to(x_latent.device), ts.to(x_latent.device))
251
+ # x_dec = (1 - nmask) * noised_input + nmask * x_dec
252
+ # x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
253
+ # unconditional_guidance_scale=unconditional_guidance_scale,
254
+ # unconditional_conditioning=unconditional_conditioning)
255
+ # if callback: callback(i)
256
+ # return x_dec
257
+ #
258
+ #
259
+ # def ldm_inpaint(model, img, mask, inpaint_size=720, pos_prompt='', neg_prompt = '', use_cuda=True):
260
+ # img_original = np.copy(img)
261
+ # im_h, im_w = img.shape[:2]
262
+ # img_resized, (pad_h, pad_w) = resize_pad2divisior(img, inpaint_size)
263
+ #
264
+ # mask_original = np.copy(mask)
265
+ # mask_original[mask_original < 127] = 0
266
+ # mask_original[mask_original >= 127] = 1
267
+ # mask_original = mask_original[:, :, None]
268
+ # mask, _ = resize_pad2divisior(mask, inpaint_size)
269
+ #
270
+ # # cv2.imwrite('img_resized.png', img_resized)
271
+ # # cv2.imwrite('mask_resized.png', mask)
272
+ #
273
+ #
274
+ # if use_cuda :
275
+ # with torch.autocast(enabled = True, device_type = 'cuda') :
276
+ # img = model.img2img_inpaint(
277
+ # image = Image.fromarray(img_resized),
278
+ # c_text = pos_prompt,
279
+ # uc_text = neg_prompt,
280
+ # mask = Image.fromarray(mask),
281
+ # use_cuda = True
282
+ # )
283
+ # else :
284
+ # img = model.img2img_inpaint(
285
+ # image = Image.fromarray(img_resized),
286
+ # c_text = pos_prompt,
287
+ # uc_text = neg_prompt,
288
+ # mask = Image.fromarray(mask),
289
+ # use_cuda = False
290
+ # )
291
+ #
292
+ # img_inpainted = (einops.rearrange(img, '1 c h w -> h w c').cpu().numpy() * 127.5 + 127.5).astype(np.uint8)
293
+ # if pad_h != 0:
294
+ # img_inpainted = img_inpainted[:-pad_h]
295
+ # if pad_w != 0:
296
+ # img_inpainted = img_inpainted[:, :-pad_w]
297
+ #
298
+ #
299
+ # if img_inpainted.shape[0] != im_h or img_inpainted.shape[1] != im_w:
300
+ # img_inpainted = cv2.resize(img_inpainted, (im_w, im_h), interpolation = cv2.INTER_LINEAR)
301
+ # ans = img_inpainted * mask_original + img_original * (1 - mask_original)
302
+ # ans = img_inpainted
303
+ # return ans
304
+
305
+
306
+
307
+
308
+ import requests
309
+ from PIL import Image
310
+ def ldm_inpaint_webui(
311
+ img, mask, resolution: int, url: str, prompt: str = '', neg_prompt: str = '',
312
+ **inpaint_ldm_options):
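+ # Proxies inpainting to a running Stable Diffusion web UI: the image and mask are base64-encoded
+ # into an img2img-style JSON payload submitted to `url`, and the first returned image is resized
+ # back to the input resolution. The shorter side is set to `resolution` and the longer side is
+ # rounded down to a multiple of 32.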
313
+ if isinstance(img, np.ndarray):
314
+ img = Image.fromarray(img)
315
+
316
+ im_h, im_w = img.height, img.width
317
+
318
+ if img.height > img.width:
319
+ W = resolution
320
+ H = (img.height / img.width * resolution) // 32 * 32
321
+ H = int(H)
322
+ else:
323
+ H = resolution
324
+ W = (img.width / img.height * resolution) // 32 * 32
325
+ W = int(W)
326
+
327
+ auth = None
328
+ if 'username' in inpaint_ldm_options:
329
+ username = inpaint_ldm_options.pop('username')
330
+ password = inpaint_ldm_options.pop('password')
331
+ auth = HTTPBasicAuth(username, password)
332
+
333
+ img_b64 = img2b64(img)
334
+ mask_b64 = img2b64(mask)
335
+ data = {
336
+ "init_images": [img_b64],
337
+ "mask": mask_b64,
338
+ "prompt": prompt,
339
+ "negative_prompt": neg_prompt,
340
+ "width": W,
341
+ "height": H,
342
+ **inpaint_ldm_options,
343
+ }
344
+ data = json.dumps(data)
345
+
346
+ response = submit_request(url, data, auth=auth)
347
+
348
+ inpainted_b64 = response.json()['images'][0]
349
+ inpainted = Image.open(io.BytesIO(base64.b64decode(inpainted_b64)))
350
+ if inpainted.height != im_h or inpainted.width != im_w:
351
+ inpainted = inpainted.resize((im_w, im_h), resample=Image.Resampling.LANCZOS)
352
+ inpainted = np.array(inpainted)
353
+ return inpainted
animeinsseg/inpainting/patch_match.py ADDED
@@ -0,0 +1,203 @@
1
+ #! /usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # File : patch_match.py
4
+ # Author : Jiayuan Mao
5
+ # Email : [email protected]
6
+ # Date : 01/09/2020
7
+ #
8
+ # Distributed under terms of the MIT license.
9
+
10
+ import ctypes, os
11
+ import os.path as osp
12
+ from typing import Optional, Union
13
+
14
+ import numpy as np
15
+ from PIL import Image
16
+
17
+ # try:
18
+ # # If the Jacinle library (https://github.com/vacancy/Jacinle) is present, use its auto_travis feature.
19
+ # from jacinle.jit.cext import auto_travis
20
+ # auto_travis(__file__, required_files=['*.so'])
21
+ # except ImportError as e:
22
+ # # Otherwise, fall back to the subprocess.
23
+ # import subprocess
24
+ # print('Compiling and loading c extensions from "{}".'.format(osp.realpath(osp.dirname(__file__))))
25
+ # subprocess.check_call(['./travis.sh'], cwd=osp.dirname(__file__))
26
+
27
+
28
+ __all__ = ['set_random_seed', 'set_verbose', 'inpaint', 'inpaint_regularity']
29
+
30
+
31
+ class CShapeT(ctypes.Structure):
32
+ _fields_ = [
33
+ ('width', ctypes.c_int),
34
+ ('height', ctypes.c_int),
35
+ ('channels', ctypes.c_int),
36
+ ]
37
+
38
+ class CMatT(ctypes.Structure):
39
+ _fields_ = [
40
+ ('data_ptr', ctypes.c_void_p),
41
+ ('shape', CShapeT),
42
+ ('dtype', ctypes.c_int)
43
+ ]
44
+
45
+ import sys
46
+ if sys.platform == 'linux':
47
+ PMLIB = ctypes.CDLL('data/libs/libpatchmatch_inpaint.so')
48
+ else:
49
+ PMLIB = ctypes.CDLL('data/libs/libpatchmatch.dll')
50
+
51
+ PMLIB.PM_set_random_seed.argtypes = [ctypes.c_uint]
52
+ PMLIB.PM_set_verbose.argtypes = [ctypes.c_int]
53
+ PMLIB.PM_free_pymat.argtypes = [CMatT]
54
+ PMLIB.PM_inpaint.argtypes = [CMatT, CMatT, ctypes.c_int]
55
+ PMLIB.PM_inpaint.restype = CMatT
56
+ PMLIB.PM_inpaint_regularity.argtypes = [CMatT, CMatT, CMatT, ctypes.c_int, ctypes.c_float]
57
+ PMLIB.PM_inpaint_regularity.restype = CMatT
58
+ PMLIB.PM_inpaint2.argtypes = [CMatT, CMatT, CMatT, ctypes.c_int]
59
+ PMLIB.PM_inpaint2.restype = CMatT
60
+ PMLIB.PM_inpaint2_regularity.argtypes = [CMatT, CMatT, CMatT, CMatT, ctypes.c_int, ctypes.c_float]
61
+ PMLIB.PM_inpaint2_regularity.restype = CMatT
62
+
63
+
64
+ def set_random_seed(seed: int):
65
+ PMLIB.PM_set_random_seed(ctypes.c_uint(seed))
66
+
67
+
68
+ def set_verbose(verbose: bool):
69
+ PMLIB.PM_set_verbose(ctypes.c_int(verbose))
70
+
71
+
72
+ def inpaint(
73
+ image: Union[np.ndarray, Image.Image],
74
+ mask: Optional[Union[np.ndarray, Image.Image]] = None,
75
+ *,
76
+ global_mask: Optional[Union[np.ndarray, Image.Image]] = None,
77
+ patch_size: int = 15
78
+ ) -> np.ndarray:
79
+ """
80
+ PatchMatch based inpainting proposed in:
81
+
82
+ PatchMatch : A Randomized Correspondence Algorithm for Structural Image Editing
83
+ C.Barnes, E.Shechtman, A.Finkelstein and Dan B.Goldman
84
+ SIGGRAPH 2009
85
+
86
+ Args:
87
+ image (Union[np.ndarray, Image.Image]): the input image, should be 3-channel RGB/BGR.
88
+ mask (Union[np.array, Image.Image], optional): the mask of the hole(s) to be filled, should be 1-channel.
89
+ If not provided (None), the algorithm will treat all purely white pixels as the holes (255, 255, 255).
90
+ global_mask (Union[np.array, Image.Image], optional): the target mask of the output image.
91
+ patch_size (int): the patch size for the inpainting algorithm.
92
+
93
+ Return:
94
+ result (np.ndarray): the repaired image, of the same size as the input image.
95
+ """
96
+
97
+ if isinstance(image, Image.Image):
98
+ image = np.array(image)
99
+ image = np.ascontiguousarray(image)
100
+ assert image.ndim == 3 and image.shape[2] == 3 and image.dtype == 'uint8'
101
+
102
+ if mask is None:
103
+ mask = (image == (255, 255, 255)).all(axis=2, keepdims=True).astype('uint8')
104
+ mask = np.ascontiguousarray(mask)
105
+ else:
106
+ mask = _canonize_mask_array(mask)
107
+
108
+ if global_mask is None:
109
+ ret_pymat = PMLIB.PM_inpaint(np_to_pymat(image), np_to_pymat(mask), ctypes.c_int(patch_size))
110
+ else:
111
+ global_mask = _canonize_mask_array(global_mask)
112
+ ret_pymat = PMLIB.PM_inpaint2(np_to_pymat(image), np_to_pymat(mask), np_to_pymat(global_mask), ctypes.c_int(patch_size))
113
+
114
+ ret_npmat = pymat_to_np(ret_pymat)
115
+ PMLIB.PM_free_pymat(ret_pymat)
116
+
117
+ return ret_npmat
118
+
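+ # Illustrative usage (file names are placeholders, not part of this module): non-zero pixels in
+ # an 8-bit, single-channel mask mark the holes to be filled from surrounding patches.
+ #   img = np.array(Image.open('scene.png').convert('RGB'))
+ #   hole = np.array(Image.open('hole_mask.png').convert('L'))
+ #   filled = inpaint(img, hole, patch_size=15)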
119
+
120
+ def inpaint_regularity(
121
+ image: Union[np.ndarray, Image.Image],
122
+ mask: Optional[Union[np.ndarray, Image.Image]],
123
+ ijmap: np.ndarray,
124
+ *,
125
+ global_mask: Optional[Union[np.ndarray, Image.Image]] = None,
126
+ patch_size: int = 15, guide_weight: float = 0.25
127
+ ) -> np.ndarray:
128
+ if isinstance(image, Image.Image):
129
+ image = np.array(image)
130
+ image = np.ascontiguousarray(image)
131
+
132
+ assert isinstance(ijmap, np.ndarray) and ijmap.ndim == 3 and ijmap.shape[2] == 3 and ijmap.dtype == 'float32'
133
+ ijmap = np.ascontiguousarray(ijmap)
134
+
135
+ assert image.ndim == 3 and image.shape[2] == 3 and image.dtype == 'uint8'
136
+ if mask is None:
137
+ mask = (image == (255, 255, 255)).all(axis=2, keepdims=True).astype('uint8')
138
+ mask = np.ascontiguousarray(mask)
139
+ else:
140
+ mask = _canonize_mask_array(mask)
141
+
142
+
143
+ if global_mask is None:
144
+ ret_pymat = PMLIB.PM_inpaint_regularity(np_to_pymat(image), np_to_pymat(mask), np_to_pymat(ijmap), ctypes.c_int(patch_size), ctypes.c_float(guide_weight))
145
+ else:
146
+ global_mask = _canonize_mask_array(global_mask)
147
+ ret_pymat = PMLIB.PM_inpaint2_regularity(np_to_pymat(image), np_to_pymat(mask), np_to_pymat(global_mask), np_to_pymat(ijmap), ctypes.c_int(patch_size), ctypes.c_float(guide_weight))
148
+
149
+ ret_npmat = pymat_to_np(ret_pymat)
150
+ PMLIB.PM_free_pymat(ret_pymat)
151
+
152
+ return ret_npmat
153
+
154
+
155
+ def _canonize_mask_array(mask):
156
+ if isinstance(mask, Image.Image):
157
+ mask = np.array(mask)
158
+ if mask.ndim == 2 and mask.dtype == 'uint8':
159
+ mask = mask[..., np.newaxis]
160
+ assert mask.ndim == 3 and mask.shape[2] == 1 and mask.dtype == 'uint8'
161
+ return np.ascontiguousarray(mask)
162
+
163
+
164
+ dtype_pymat_to_ctypes = [
165
+ ctypes.c_uint8,
166
+ ctypes.c_int8,
167
+ ctypes.c_uint16,
168
+ ctypes.c_int16,
169
+ ctypes.c_int32,
170
+ ctypes.c_float,
171
+ ctypes.c_double,
172
+ ]
173
+
174
+
175
+ dtype_np_to_pymat = {
176
+ 'uint8': 0,
177
+ 'int8': 1,
178
+ 'uint16': 2,
179
+ 'int16': 3,
180
+ 'int32': 4,
181
+ 'float32': 5,
182
+ 'float64': 6,
183
+ }
184
+
185
+
186
+ def np_to_pymat(npmat):
187
+ assert npmat.ndim == 3
188
+ return CMatT(
189
+ ctypes.cast(npmat.ctypes.data, ctypes.c_void_p),
190
+ CShapeT(npmat.shape[1], npmat.shape[0], npmat.shape[2]),
191
+ dtype_np_to_pymat[str(npmat.dtype)]
192
+ )
193
+
194
+
195
+ def pymat_to_np(pymat):
196
+ npmat = np.ctypeslib.as_array(
197
+ ctypes.cast(pymat.data_ptr, ctypes.POINTER(dtype_pymat_to_ctypes[pymat.dtype])),
198
+ (pymat.shape.height, pymat.shape.width, pymat.shape.channels)
199
+ )
200
+ ret = np.empty(npmat.shape, npmat.dtype)
201
+ ret[:] = npmat
202
+ return ret
203
+
animeinsseg/models/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ import torch
2
+ import numpy as np
3
+ import cv2
4
+ from typing import Union
5
+
6
+
7
+
animeinsseg/models/animeseg_refine/__init__.py ADDED
@@ -0,0 +1,189 @@
1
+ # modified from https://github.com/SkyTNT/anime-segmentation/blob/main/train.py
2
+ import os
3
+
4
+ import argparse
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from pytorch_lightning import Trainer
8
+ from pytorch_lightning.callbacks import ModelCheckpoint
9
+ from torch.utils.data import Dataset, DataLoader
10
+ import torch.optim as optim
11
+ import numpy as np
12
+ import cv2
13
+ from torch.cuda import amp
14
+
15
+ from utils.constants import DEFAULT_DEVICE
16
+ # from data_loader import create_training_datasets
17
+
18
+
19
+ import pytorch_lightning as pl
20
+ import warnings
21
+
22
+ from .isnet import ISNetDIS, ISNetGTEncoder
23
+ from .u2net import U2NET, U2NET_full, U2NET_full2, U2NET_lite2
24
+ from .modnet import MODNet
25
+
26
+ # warnings.filterwarnings("ignore")
27
+
28
+ def get_net(net_name):
29
+ if net_name == "isnet":
30
+ return ISNetDIS()
31
+ elif net_name == "isnet_is":
32
+ return ISNetDIS()
33
+ elif net_name == "isnet_gt":
34
+ return ISNetGTEncoder()
35
+ elif net_name == "u2net":
36
+ return U2NET_full2()
37
+ elif net_name == "u2netl":
38
+ return U2NET_lite2()
39
+ elif net_name == "modnet":
40
+ return MODNet()
41
+ raise NotImplementedError
42
+
43
+
44
+ def f1_torch(pred, gt):
45
+ # micro F-measure (beta^2 = 0.3) computed jointly over foreground and background
46
+ pred = pred.float().view(pred.shape[0], -1)
47
+ gt = gt.float().view(gt.shape[0], -1)
48
+ tp1 = torch.sum(pred * gt, dim=1)
49
+ tp_fp1 = torch.sum(pred, dim=1)
50
+ tp_fn1 = torch.sum(gt, dim=1)
51
+ pred = 1 - pred
52
+ gt = 1 - gt
53
+ tp2 = torch.sum(pred * gt, dim=1)
54
+ tp_fp2 = torch.sum(pred, dim=1)
55
+ tp_fn2 = torch.sum(gt, dim=1)
56
+ precision = (tp1 + tp2) / (tp_fp1 + tp_fp2 + 0.0001)
57
+ recall = (tp1 + tp2) / (tp_fn1 + tp_fn2 + 0.0001)
58
+ f1 = (1 + 0.3) * precision * recall / (0.3 * precision + recall + 0.0001)
59
+ return precision, recall, f1
60
+
61
+
62
+ class AnimeSegmentation(pl.LightningModule):
63
+
64
+ def __init__(self, net_name):
65
+ super().__init__()
66
+ assert net_name in ["isnet_is", "isnet", "isnet_gt", "u2net", "u2netl", "modnet"]
67
+ self.net = get_net(net_name)
68
+ if net_name == "isnet_is":
69
+ self.gt_encoder = get_net("isnet_gt")
70
+ self.gt_encoder.requires_grad_(False)
71
+ else:
72
+ self.gt_encoder = None
73
+
74
+ @classmethod
75
+ def try_load(cls, net_name, ckpt_path, map_location=None):
76
+ state_dict = torch.load(ckpt_path, map_location=map_location)
77
+ if "epoch" in state_dict:
78
+ return cls.load_from_checkpoint(ckpt_path, net_name=net_name, map_location=map_location)
79
+ else:
80
+ model = cls(net_name)
81
+ if any([k.startswith("net.") for k, v in state_dict.items()]):
82
+ model.load_state_dict(state_dict)
83
+ else:
84
+ model.net.load_state_dict(state_dict)
85
+ return model
86
+
87
+ def configure_optimizers(self):
88
+ optimizer = optim.Adam(self.net.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
89
+ return optimizer
90
+
91
+ def forward(self, x):
92
+ if isinstance(self.net, ISNetDIS):
93
+ return self.net(x)[0][0].sigmoid()
94
+ if isinstance(self.net, ISNetGTEncoder):
95
+ return self.net(x)[0][0].sigmoid()
96
+ elif isinstance(self.net, U2NET):
97
+ return self.net(x)[0].sigmoid()
98
+ elif isinstance(self.net, MODNet):
99
+ return self.net(x, True)[2]
100
+ raise NotImplementedError
101
+
102
+ def training_step(self, batch, batch_idx):
103
+ images, labels = batch["image"], batch["label"]
104
+ if isinstance(self.net, ISNetDIS):
105
+ ds, dfs = self.net(images)
106
+ loss_args = [ds, dfs, labels]
107
+ elif isinstance(self.net, ISNetGTEncoder):
108
+ ds = self.net(labels)[0]
109
+ loss_args = [ds, labels]
110
+ elif isinstance(self.net, U2NET):
111
+ ds = self.net(images)
112
+ loss_args = [ds, labels]
113
+ elif isinstance(self.net, MODNet):
114
+ trimaps = batch["trimap"]
115
+ pred_semantic, pred_detail, pred_matte = self.net(images, False)
116
+ loss_args = [pred_semantic, pred_detail, pred_matte, images, trimaps, labels]
117
+ else:
118
+ raise NotImplementedError
119
+ if self.gt_encoder is not None:
120
+ fs = self.gt_encoder(labels)[1]
121
+ loss_args.append(fs)
122
+
123
+ loss0, loss = self.net.compute_loss(loss_args)
124
+ self.log_dict({"train/loss": loss, "train/loss_tar": loss0})
125
+ return loss
126
+
127
+ def validation_step(self, batch, batch_idx):
128
+ images, labels = batch["image"], batch["label"]
129
+ if isinstance(self.net, ISNetGTEncoder):
130
+ preds = self.forward(labels)
131
+ else:
132
+ preds = self.forward(images)
133
+ pre, rec, f1, = f1_torch(preds.nan_to_num(nan=0, posinf=1, neginf=0), labels)
134
+ mae_m = F.l1_loss(preds, labels, reduction="mean")
135
+ pre_m = pre.mean()
136
+ rec_m = rec.mean()
137
+ f1_m = f1.mean()
138
+ self.log_dict({"val/precision": pre_m, "val/recall": rec_m, "val/f1": f1_m, "val/mae": mae_m}, sync_dist=True)
139
+
140
+
141
+ def get_gt_encoder(train_dataloader, val_dataloader, opt):
142
+ print("---start train ground truth encoder---")
143
+ gt_encoder = AnimeSegmentation("isnet_gt")
144
+ trainer = Trainer(precision=32 if opt.fp32 else 16, accelerator=opt.accelerator,
145
+ devices=opt.devices, max_epochs=opt.gt_epoch,
146
+ benchmark=opt.benchmark, accumulate_grad_batches=opt.acc_step,
147
+ check_val_every_n_epoch=opt.val_epoch, log_every_n_steps=opt.log_step,
148
+ strategy="ddp_find_unused_parameters_false" if opt.devices > 1 else None,
149
+ )
150
+ trainer.fit(gt_encoder, train_dataloader, val_dataloader)
151
+ return gt_encoder.net
152
+
153
+
154
+ def load_refinenet(refine_method = 'animeseg', device: str = None) -> AnimeSegmentation:
155
+ if device is None:
156
+ device = DEFAULT_DEVICE
157
+ if refine_method == 'animeseg':
158
+ model = AnimeSegmentation.try_load('isnet_is', 'models/anime-seg/isnetis.ckpt', device)
159
+ elif refine_method == 'refinenet_isnet':
160
+ model = ISNetDIS(in_ch=4)
161
+ sd = torch.load('models/AnimeInstanceSegmentation/refine_last.ckpt', map_location='cpu')
162
+ # sd = torch.load('models/AnimeInstanceSegmentation/refine_noweight_dist.ckpt', map_location='cpu')
163
+ # sd = torch.load('models/AnimeInstanceSegmentation/refine_f3loss.ckpt', map_location='cpu')
164
+ model.load_state_dict(sd)
165
+ else:
166
+ raise NotImplementedError
167
+ return model.eval().to(device)
168
+
169
+ def get_mask(model, input_img, use_amp=True, s=640):
170
+ h0, w0 = h, w = input_img.shape[0], input_img.shape[1]
171
+ if h > w:
172
+ h, w = s, int(s * w / h)
173
+ else:
174
+ h, w = int(s * h / w), s
175
+ ph, pw = s - h, s - w
176
+ tmpImg = np.zeros([s, s, 3], dtype=np.float32)
177
+ tmpImg[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w] = cv2.resize(input_img, (w, h)) / 255
178
+ tmpImg = tmpImg.transpose((2, 0, 1))
179
+ tmpImg = torch.from_numpy(tmpImg).unsqueeze(0).type(torch.FloatTensor).to(model.device)
180
+ with torch.no_grad():
181
+ if use_amp:
182
+ with amp.autocast():
183
+ pred = model(tmpImg)
184
+ pred = pred.to(dtype=torch.float32)
185
+ else:
186
+ pred = model(tmpImg)
187
+ pred = pred[0, :, ph // 2:ph // 2 + h, pw // 2:pw // 2 + w]
188
+ pred = cv2.resize(pred.cpu().numpy().transpose((1, 2, 0)), (w0, h0))[:, :, np.newaxis]
189
+ return pred
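A minimal end-to-end sketch of the mask-refinement path, assuming the checkpoint paths used in load_refinenet are present locally and the hypothetical input image can be read by OpenCV; use_amp is disabled here since autocast is mainly useful on CUDA:

import cv2

refinenet = load_refinenet(refine_method='animeseg', device='cpu')
img = cv2.imread('input.png')                          # hypothetical input image, shape (H, W, 3)
mask = get_mask(refinenet, img, use_amp=False, s=640)  # float mask in [0, 1], shape (H, W, 1)
cv2.imwrite('mask.png', (mask[:, :, 0] * 255).astype('uint8'))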
animeinsseg/models/animeseg_refine/encoders.py ADDED
@@ -0,0 +1,51 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.utils.checkpoint import checkpoint
4
+
5
+ from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
6
+
7
+
8
+ class AbstractEncoder(nn.Module):
9
+ def __init__(self):
10
+ super().__init__()
11
+
12
+ def encode(self, *args, **kwargs):
13
+ raise NotImplementedError
14
+
15
+
16
+ class IdentityEncoder(AbstractEncoder):
17
+
18
+ def encode(self, x):
19
+ return x
20
+
21
+
22
+ class ClassEmbedder(nn.Module):
23
+ def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
24
+ super().__init__()
25
+ self.key = key
26
+ self.embedding = nn.Embedding(n_classes, embed_dim)
27
+ self.n_classes = n_classes
28
+ self.ucg_rate = ucg_rate
29
+
30
+ def forward(self, batch, key=None, disable_dropout=False):
31
+ if key is None:
32
+ key = self.key
33
+ # this is for use in crossattn
34
+ c = batch[key][:, None]
35
+ if self.ucg_rate > 0. and not disable_dropout:
36
+ mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
37
+ c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1)
38
+ c = c.long()
39
+ c = self.embedding(c)
40
+ return c
41
+
42
+ def get_unconditional_conditioning(self, bs, device="cuda"):
43
+ uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
44
+ uc = torch.ones((bs,), device=device) * uc_class
45
+ uc = {self.key: uc}
46
+ return uc
47
+
48
+
49
+ class DanbooruEmbedder(AbstractEncoder):
50
+ def __init__(self):
51
+ super().__init__()
animeinsseg/models/animeseg_refine/isnet.py ADDED
@@ -0,0 +1,645 @@
1
+ # Codes are borrowed from
2
+ # https://github.com/xuebinqin/DIS/blob/main/IS-Net/models/isnet.py
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from torchvision import models
7
+ import torch.nn.functional as F
8
+
9
+ _bce_loss = nn.BCEWithLogitsLoss(reduction="mean")
10
+ _bce_loss_none = nn.BCEWithLogitsLoss(reduction='none')
11
+
12
+ def bce_loss(p, t, weights=None):
13
+ if weights is None:
14
+ return _bce_loss(p, t)
15
+ else:
16
+ loss = _bce_loss_none(p, t)
17
+ loss = loss * weights
18
+ return loss.mean()
19
+
20
+
21
+ _fea_loss = nn.MSELoss(reduction="mean")
22
+ _fea_loss_none = nn.MSELoss(reduction="none")
23
+
24
+ def fea_loss(p, t, weights=None):
25
+ return _fea_loss(p, t)
26
+
27
+ kl_loss = nn.KLDivLoss(reduction="mean")
28
+ l1_loss = nn.L1Loss(reduction="mean")
29
+ smooth_l1_loss = nn.SmoothL1Loss(reduction="mean")
30
+
31
+
32
+ def structure_loss(pred, mask):
33
+ weit = 1+5*torch.abs(F.avg_pool2d(mask, kernel_size=15, stride=1, padding=7)-mask)
34
+ wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none')
35
+ wbce = (weit*wbce).sum(dim=(2,3))/weit.sum(dim=(2,3))
36
+
37
+ pred = torch.sigmoid(pred)
38
+ inter = ((pred*mask)*weit).sum(dim=(2,3))
39
+ union = ((pred+mask)*weit).sum(dim=(2,3))
40
+ wiou = 1-(inter+1)/(union-inter+1)
41
+ return (wbce+wiou).mean()
42
+
43
+
44
+ def muti_loss_fusion(preds, target, dist_weight=None, loss0_weight=1.0):
45
+ loss0 = 0.0
46
+ loss = 0.0
47
+
48
+ for i in range(0, len(preds)):
49
+ weight = dist_weight if i == 0 else None
50
+ if preds[i].shape[2] != target.shape[2] or preds[i].shape[3] != target.shape[3]:
51
+ tmp_target = F.interpolate(target, size=preds[i].size()[2:], mode='bilinear', align_corners=True)
52
+ loss = loss + structure_loss(preds[i], tmp_target)
53
+ else:
54
+ # loss = loss + bce_loss(preds[i], target, weight)
55
+ loss = loss + structure_loss(preds[i], target)
56
+ if i == 0:
57
+ loss *= loss0_weight
58
+ loss0 = loss
59
+ return loss0, loss
60
+
61
+
62
+
63
+ def muti_loss_fusion_kl(preds, target, dfs, fs, mode='MSE', dist_weight=None, loss0_weight=1.0):
64
+ loss0 = 0.0
65
+ loss = 0.0
66
+
67
+ for i in range(0, len(preds)):
68
+ weight = dist_weight if i == 0 else None
69
+ if preds[i].shape[2] != target.shape[2] or preds[i].shape[3] != target.shape[3]:
70
+ tmp_target = F.interpolate(target, size=preds[i].size()[2:], mode='bilinear', align_corners=True)
71
+ # loss = loss + bce_loss(preds[i], tmp_target, weight)
72
+ loss = loss + structure_loss(preds[i], tmp_target)
73
+ else:
74
+ # loss = loss + bce_loss(preds[i], target, weight)
75
+ loss = loss + structure_loss(preds[i], target)
76
+ if i == 0:
77
+ loss *= loss0_weight
78
+ loss0 = loss
79
+
80
+ for i in range(0, len(dfs)):
81
+ df = dfs[i]
82
+ fs_i = fs[i]
83
+ if mode == 'MSE':
84
+ loss = loss + fea_loss(df, fs_i, dist_weight) ### add the mse loss of features as additional constraints
85
+ elif mode == 'KL':
86
+ loss = loss + kl_loss(F.log_softmax(df, dim=1), F.softmax(fs_i, dim=1))
87
+ elif mode == 'MAE':
88
+ loss = loss + l1_loss(df, fs_i)
89
+ elif mode == 'SmoothL1':
90
+ loss = loss + smooth_l1_loss(df, fs_i)
91
+
92
+ return loss0, loss
93
+
94
+
95
+ class REBNCONV(nn.Module):
96
+ def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
97
+ super(REBNCONV, self).__init__()
98
+
99
+ self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride)
100
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
101
+ self.relu_s1 = nn.ReLU(inplace=True)
102
+
103
+ def forward(self, x):
104
+ hx = x
105
+ xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
106
+
107
+ return xout
108
+
109
+
110
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
111
+ def _upsample_like(src, tar):
112
+ src = F.interpolate(src, size=tar.shape[2:], mode='bilinear', align_corners=False)
113
+
114
+ return src
115
+
116
+
117
+ ### RSU-7 ###
118
+ class RSU7(nn.Module):
119
+
120
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
121
+ super(RSU7, self).__init__()
122
+
123
+ self.in_ch = in_ch
124
+ self.mid_ch = mid_ch
125
+ self.out_ch = out_ch
126
+
127
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) ## 1 -> 1/2
128
+
129
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
130
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
131
+
132
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
133
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
134
+
135
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
136
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
137
+
138
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
139
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
140
+
141
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
142
+ self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
143
+
144
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
145
+
146
+ self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
147
+
148
+ self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
149
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
150
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
151
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
152
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
153
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
154
+
155
+ def forward(self, x):
156
+ b, c, h, w = x.shape
157
+
158
+ hx = x
159
+ hxin = self.rebnconvin(hx)
160
+
161
+ hx1 = self.rebnconv1(hxin)
162
+ hx = self.pool1(hx1)
163
+
164
+ hx2 = self.rebnconv2(hx)
165
+ hx = self.pool2(hx2)
166
+
167
+ hx3 = self.rebnconv3(hx)
168
+ hx = self.pool3(hx3)
169
+
170
+ hx4 = self.rebnconv4(hx)
171
+ hx = self.pool4(hx4)
172
+
173
+ hx5 = self.rebnconv5(hx)
174
+ hx = self.pool5(hx5)
175
+
176
+ hx6 = self.rebnconv6(hx)
177
+
178
+ hx7 = self.rebnconv7(hx6)
179
+
180
+ hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
181
+ hx6dup = _upsample_like(hx6d, hx5)
182
+
183
+ hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
184
+ hx5dup = _upsample_like(hx5d, hx4)
185
+
186
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
187
+ hx4dup = _upsample_like(hx4d, hx3)
188
+
189
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
190
+ hx3dup = _upsample_like(hx3d, hx2)
191
+
192
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
193
+ hx2dup = _upsample_like(hx2d, hx1)
194
+
195
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
196
+
197
+ return hx1d + hxin
198
+
199
+
200
+ ### RSU-6 ###
201
+ class RSU6(nn.Module):
202
+
203
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
204
+ super(RSU6, self).__init__()
205
+
206
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
207
+
208
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
209
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
210
+
211
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
212
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
213
+
214
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
215
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
216
+
217
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
218
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
219
+
220
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
221
+
222
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
223
+
224
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
225
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
226
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
227
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
228
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
229
+
230
+ def forward(self, x):
231
+ hx = x
232
+
233
+ hxin = self.rebnconvin(hx)
234
+
235
+ hx1 = self.rebnconv1(hxin)
236
+ hx = self.pool1(hx1)
237
+
238
+ hx2 = self.rebnconv2(hx)
239
+ hx = self.pool2(hx2)
240
+
241
+ hx3 = self.rebnconv3(hx)
242
+ hx = self.pool3(hx3)
243
+
244
+ hx4 = self.rebnconv4(hx)
245
+ hx = self.pool4(hx4)
246
+
247
+ hx5 = self.rebnconv5(hx)
248
+
249
+ hx6 = self.rebnconv6(hx5)
250
+
251
+ hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
252
+ hx5dup = _upsample_like(hx5d, hx4)
253
+
254
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
255
+ hx4dup = _upsample_like(hx4d, hx3)
256
+
257
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
258
+ hx3dup = _upsample_like(hx3d, hx2)
259
+
260
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
261
+ hx2dup = _upsample_like(hx2d, hx1)
262
+
263
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
264
+
265
+ return hx1d + hxin
266
+
267
+
268
+ ### RSU-5 ###
269
+ class RSU5(nn.Module):
270
+
271
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
272
+ super(RSU5, self).__init__()
273
+
274
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
275
+
276
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
277
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
278
+
279
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
280
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
281
+
282
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
283
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
284
+
285
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
286
+
287
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
288
+
289
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
290
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
291
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
292
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
293
+
294
+ def forward(self, x):
295
+ hx = x
296
+
297
+ hxin = self.rebnconvin(hx)
298
+
299
+ hx1 = self.rebnconv1(hxin)
300
+ hx = self.pool1(hx1)
301
+
302
+ hx2 = self.rebnconv2(hx)
303
+ hx = self.pool2(hx2)
304
+
305
+ hx3 = self.rebnconv3(hx)
306
+ hx = self.pool3(hx3)
307
+
308
+ hx4 = self.rebnconv4(hx)
309
+
310
+ hx5 = self.rebnconv5(hx4)
311
+
312
+ hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
313
+ hx4dup = _upsample_like(hx4d, hx3)
314
+
315
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
316
+ hx3dup = _upsample_like(hx3d, hx2)
317
+
318
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
319
+ hx2dup = _upsample_like(hx2d, hx1)
320
+
321
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
322
+
323
+ return hx1d + hxin
324
+
325
+
326
+ ### RSU-4 ###
327
+ class RSU4(nn.Module):
328
+
329
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
330
+ super(RSU4, self).__init__()
331
+
332
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
333
+
334
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
335
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
336
+
337
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
338
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
339
+
340
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
341
+
342
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
343
+
344
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
345
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
346
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
347
+
348
+ def forward(self, x):
349
+ hx = x
350
+
351
+ hxin = self.rebnconvin(hx)
352
+
353
+ hx1 = self.rebnconv1(hxin)
354
+ hx = self.pool1(hx1)
355
+
356
+ hx2 = self.rebnconv2(hx)
357
+ hx = self.pool2(hx2)
358
+
359
+ hx3 = self.rebnconv3(hx)
360
+
361
+ hx4 = self.rebnconv4(hx3)
362
+
363
+ hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
364
+ hx3dup = _upsample_like(hx3d, hx2)
365
+
366
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
367
+ hx2dup = _upsample_like(hx2d, hx1)
368
+
369
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
370
+
371
+ return hx1d + hxin
372
+
373
+
374
+ ### RSU-4F ###
375
+ class RSU4F(nn.Module):
376
+
377
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
378
+ super(RSU4F, self).__init__()
379
+
380
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
381
+
382
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
383
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
384
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
385
+
386
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
387
+
388
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
389
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
390
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
391
+
392
+ def forward(self, x):
393
+ hx = x
394
+
395
+ hxin = self.rebnconvin(hx)
396
+
397
+ hx1 = self.rebnconv1(hxin)
398
+ hx2 = self.rebnconv2(hx1)
399
+ hx3 = self.rebnconv3(hx2)
400
+
401
+ hx4 = self.rebnconv4(hx3)
402
+
403
+ hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
404
+ hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
405
+ hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
406
+
407
+ return hx1d + hxin
408
+
409
+
410
+ class myrebnconv(nn.Module):
411
+ def __init__(self, in_ch=3,
412
+ out_ch=1,
413
+ kernel_size=3,
414
+ stride=1,
415
+ padding=1,
416
+ dilation=1,
417
+ groups=1):
418
+ super(myrebnconv, self).__init__()
419
+
420
+ self.conv = nn.Conv2d(in_ch,
421
+ out_ch,
422
+ kernel_size=kernel_size,
423
+ stride=stride,
424
+ padding=padding,
425
+ dilation=dilation,
426
+ groups=groups)
427
+ self.bn = nn.BatchNorm2d(out_ch)
428
+ self.rl = nn.ReLU(inplace=True)
429
+
430
+ def forward(self, x):
431
+ return self.rl(self.bn(self.conv(x)))
432
+
433
+
434
+ class ISNetGTEncoder(nn.Module):
435
+
436
+ def __init__(self, in_ch=1, out_ch=1):
437
+ super(ISNetGTEncoder, self).__init__()
438
+
439
+ self.conv_in = myrebnconv(in_ch, 16, 3, stride=2, padding=1) # nn.Conv2d(in_ch,64,3,stride=2,padding=1)
440
+
441
+ self.stage1 = RSU7(16, 16, 64)
442
+ self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
443
+
444
+ self.stage2 = RSU6(64, 16, 64)
445
+ self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
446
+
447
+ self.stage3 = RSU5(64, 32, 128)
448
+ self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
449
+
450
+ self.stage4 = RSU4(128, 32, 256)
451
+ self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
452
+
453
+ self.stage5 = RSU4F(256, 64, 512)
454
+ self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
455
+
456
+ self.stage6 = RSU4F(512, 64, 512)
457
+
458
+ self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
459
+ self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
460
+ self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
461
+ self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
462
+ self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
463
+ self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
464
+
465
+ @staticmethod
466
+ def compute_loss(args, dist_weight=None):
467
+ preds, targets = args
468
+ return muti_loss_fusion(preds, targets, dist_weight)
469
+
470
+ def forward(self, x):
471
+ hx = x
472
+
473
+ hxin = self.conv_in(hx)
474
+ # hx = self.pool_in(hxin)
475
+
476
+ # stage 1
477
+ hx1 = self.stage1(hxin)
478
+ hx = self.pool12(hx1)
479
+
480
+ # stage 2
481
+ hx2 = self.stage2(hx)
482
+ hx = self.pool23(hx2)
483
+
484
+ # stage 3
485
+ hx3 = self.stage3(hx)
486
+ hx = self.pool34(hx3)
487
+
488
+ # stage 4
489
+ hx4 = self.stage4(hx)
490
+ hx = self.pool45(hx4)
491
+
492
+ # stage 5
493
+ hx5 = self.stage5(hx)
494
+ hx = self.pool56(hx5)
495
+
496
+ # stage 6
497
+ hx6 = self.stage6(hx)
498
+
499
+ # side output
500
+ d1 = self.side1(hx1)
501
+ d1 = _upsample_like(d1, x)
502
+
503
+ d2 = self.side2(hx2)
504
+ d2 = _upsample_like(d2, x)
505
+
506
+ d3 = self.side3(hx3)
507
+ d3 = _upsample_like(d3, x)
508
+
509
+ d4 = self.side4(hx4)
510
+ d4 = _upsample_like(d4, x)
511
+
512
+ d5 = self.side5(hx5)
513
+ d5 = _upsample_like(d5, x)
514
+
515
+ d6 = self.side6(hx6)
516
+ d6 = _upsample_like(d6, x)
517
+
518
+ # d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
519
+
520
+ # return [torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)], [hx1, hx2, hx3, hx4, hx5, hx6]
521
+ return [d1, d2, d3, d4, d5, d6], [hx1, hx2, hx3, hx4, hx5, hx6]
522
+
523
+
524
+ class ISNetDIS(nn.Module):
525
+
526
+ def __init__(self, in_ch=3, out_ch=1):
527
+ super(ISNetDIS, self).__init__()
528
+
529
+ self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
530
+ self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)
531
+
532
+ self.stage1 = RSU7(64, 32, 64)
533
+ self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
534
+
535
+ self.stage2 = RSU6(64, 32, 128)
536
+ self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
537
+
538
+ self.stage3 = RSU5(128, 64, 256)
539
+ self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
540
+
541
+ self.stage4 = RSU4(256, 128, 512)
542
+ self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
543
+
544
+ self.stage5 = RSU4F(512, 256, 512)
545
+ self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
546
+
547
+ self.stage6 = RSU4F(512, 256, 512)
548
+
549
+ # decoder
550
+ self.stage5d = RSU4F(1024, 256, 512)
551
+ self.stage4d = RSU4(1024, 128, 256)
552
+ self.stage3d = RSU5(512, 64, 128)
553
+ self.stage2d = RSU6(256, 32, 64)
554
+ self.stage1d = RSU7(128, 16, 64)
555
+
556
+ self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
557
+ self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
558
+ self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
559
+ self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
560
+ self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
561
+ self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
562
+
563
+ # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
564
+
565
+ @staticmethod
566
+ def compute_loss_kl(preds, targets, dfs, fs, mode='MSE'):
567
+ return muti_loss_fusion_kl(preds, targets, dfs, fs, mode=mode, loss0_weight=5.0)
568
+
569
+ @staticmethod
570
+ def compute_loss(args, dist_weight=None):
571
+ if len(args) == 3:
572
+ ds, dfs, labels = args
573
+ return muti_loss_fusion(ds, labels, dist_weight, loss0_weight=5.0)
574
+ else:
575
+ ds, dfs, labels, fs = args
576
+ return muti_loss_fusion_kl(ds, labels, dfs, fs, mode="MSE", dist_weight=dist_weight, loss0_weight=5.0)
577
+
578
+ def forward(self, x):
579
+ hx = x
580
+
581
+ hxin = self.conv_in(hx)
582
+ hx = self.pool_in(hxin)
583
+
584
+ # stage 1
585
+ hx1 = self.stage1(hxin)
586
+ hx = self.pool12(hx1)
587
+
588
+ # stage 2
589
+ hx2 = self.stage2(hx)
590
+ hx = self.pool23(hx2)
591
+
592
+ # stage 3
593
+ hx3 = self.stage3(hx)
594
+ hx = self.pool34(hx3)
595
+
596
+ # stage 4
597
+ hx4 = self.stage4(hx)
598
+ hx = self.pool45(hx4)
599
+
600
+ # stage 5
601
+ hx5 = self.stage5(hx)
602
+ hx = self.pool56(hx5)
603
+
604
+ # stage 6
605
+ hx6 = self.stage6(hx)
606
+ hx6up = _upsample_like(hx6, hx5)
607
+
608
+ # -------------------- decoder --------------------
609
+ hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
610
+ hx5dup = _upsample_like(hx5d, hx4)
611
+
612
+ hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
613
+ hx4dup = _upsample_like(hx4d, hx3)
614
+
615
+ hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
616
+ hx3dup = _upsample_like(hx3d, hx2)
617
+
618
+ hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
619
+ hx2dup = _upsample_like(hx2d, hx1)
620
+
621
+ hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
622
+
623
+ # side output
624
+ d1 = self.side1(hx1d)
625
+ d1 = _upsample_like(d1, x)
626
+
627
+ d2 = self.side2(hx2d)
628
+ d2 = _upsample_like(d2, x)
629
+
630
+ d3 = self.side3(hx3d)
631
+ d3 = _upsample_like(d3, x)
632
+
633
+ d4 = self.side4(hx4d)
634
+ d4 = _upsample_like(d4, x)
635
+
636
+ d5 = self.side5(hx5d)
637
+ d5 = _upsample_like(d5, x)
638
+
639
+ d6 = self.side6(hx6)
640
+ d6 = _upsample_like(d6, x)
641
+
642
+ # d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
643
+
644
+ # return [torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
645
+ return [d1, d2, d3, d4, d5, d6], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
animeinsseg/models/animeseg_refine/models.py ADDED
File without changes
animeinsseg/models/animeseg_refine/modnet.py ADDED
@@ -0,0 +1,667 @@
1
+ # Codes are borrowed from
2
+ # https://github.com/ZHKKKe/MODNet/blob/master/src/trainer.py
3
+ # https://github.com/ZHKKKe/MODNet/blob/master/src/models/backbones/mobilenetv2.py
4
+ # https://github.com/ZHKKKe/MODNet/blob/master/src/models/modnet.py
5
+
6
+ import numpy as np
7
+ import scipy
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import os
12
+ import math
13
+ import torch
14
+ from scipy.ndimage import gaussian_filter
15
+
16
+
17
+ # ----------------------------------------------------------------------------------
18
+ # Loss Functions
19
+ # ----------------------------------------------------------------------------------
20
+
21
+
22
+ class GaussianBlurLayer(nn.Module):
23
+ """ Add Gaussian blur to a 4D tensor
24
+ This layer takes a 4D tensor of {N, C, H, W} as input.
25
+ The Gaussian blur is applied to each of the given number of channels (C) separately.
26
+ """
27
+
28
+ def __init__(self, channels, kernel_size):
29
+ """
30
+ Arguments:
31
+ channels (int): Channel for input tensor
32
+ kernel_size (int): Size of the kernel used in blurring
33
+ """
34
+
35
+ super(GaussianBlurLayer, self).__init__()
36
+ self.channels = channels
37
+ self.kernel_size = kernel_size
38
+ assert self.kernel_size % 2 != 0
39
+
40
+ self.op = nn.Sequential(
41
+ nn.ReflectionPad2d(math.floor(self.kernel_size / 2)),
42
+ nn.Conv2d(channels, channels, self.kernel_size,
43
+ stride=1, padding=0, bias=None, groups=channels)
44
+ )
45
+
46
+ self._init_kernel()
47
+
48
+ def forward(self, x):
49
+ """
50
+ Arguments:
51
+ x (torch.Tensor): input 4D tensor
52
+ Returns:
53
+ torch.Tensor: Blurred version of the input
54
+ """
55
+
56
+ if not len(list(x.shape)) == 4:
57
+ print('\'GaussianBlurLayer\' requires a 4D tensor as input\n')
58
+ exit()
59
+ elif not x.shape[1] == self.channels:
60
+ print('In \'GaussianBlurLayer\', the required channel ({0}) is'
61
+ ' not the same as input ({1})\n'.format(self.channels, x.shape[1]))
62
+ exit()
63
+
64
+ return self.op(x)
65
+
66
+ def _init_kernel(self):
67
+ sigma = 0.3 * ((self.kernel_size - 1) * 0.5 - 1) + 0.8
68
+
69
+ n = np.zeros((self.kernel_size, self.kernel_size))
70
+ i = math.floor(self.kernel_size / 2)
71
+ n[i, i] = 1
72
+ kernel = gaussian_filter(n, sigma)
73
+
74
+ for name, param in self.named_parameters():
75
+ param.data.copy_(torch.from_numpy(kernel))
76
+ param.requires_grad = False
77
+
78
+
79
+ blurer = GaussianBlurLayer(1, 3)
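The module-level blurer is a fixed, non-trainable depthwise Gaussian filter used by the semantic loss below. An illustrative call with a hypothetical single-channel matte:

import torch

matte = torch.rand(1, 1, 32, 32)   # hypothetical (N, C, H, W) tensor with C == 1
blurred = blurer(matte)            # same shape, Gaussian-smoothed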
80
+
81
+
82
+ def loss_func(pred_semantic, pred_detail, pred_matte, image, trimap, gt_matte,
83
+ semantic_scale=10.0, detail_scale=10.0, matte_scale=1.0):
84
+ """ loss of MODNet
85
+ Arguments:
86
+ blurer: GaussianBlurLayer
87
+ pred_semantic: model output
88
+ pred_detail: model output
89
+ pred_matte: model output
90
+ image : input RGB image; its pixel values should be normalized
91
+ trimap : trimap used to calculate the losses
92
+ its pixel values can be 0, 0.5, or 1
93
+ (foreground=1, background=0, unknown=0.5)
94
+ gt_matte: ground truth alpha matte; its pixel values are between [0, 1]
95
+ semantic_scale (float): scale of the semantic loss
96
+ NOTE: please adjust according to your dataset
97
+ detail_scale (float): scale of the detail loss
98
+ NOTE: please adjust according to your dataset
99
+ matte_scale (float): scale of the matte loss
100
+ NOTE: please adjust according to your dataset
101
+
102
+ Returns:
103
+ semantic_loss (torch.Tensor): loss of the semantic estimation [Low-Resolution (LR) Branch]
104
+ detail_loss (torch.Tensor): loss of the detail prediction [High-Resolution (HR) Branch]
105
+ matte_loss (torch.Tensor): loss of the semantic-detail fusion [Fusion Branch]
106
+ """
107
+
108
+ trimap = trimap.float()
109
+ # calculate the boundary mask from the trimap
110
+ boundaries = (trimap < 0.5) + (trimap > 0.5)
111
+
112
+ # calculate the semantic loss
113
+ gt_semantic = F.interpolate(gt_matte, scale_factor=1 / 16, mode='bilinear')
114
+ gt_semantic = blurer(gt_semantic)
115
+ semantic_loss = torch.mean(F.mse_loss(pred_semantic, gt_semantic))
116
+ semantic_loss = semantic_scale * semantic_loss
117
+
118
+ # calculate the detail loss
119
+ pred_boundary_detail = torch.where(boundaries, trimap, pred_detail.float())
120
+ gt_detail = torch.where(boundaries, trimap, gt_matte.float())
121
+ detail_loss = torch.mean(F.l1_loss(pred_boundary_detail, gt_detail.float()))
122
+ detail_loss = detail_scale * detail_loss
123
+
124
+ # calculate the matte loss
125
+ pred_boundary_matte = torch.where(boundaries, trimap, pred_matte.float())
126
+ matte_l1_loss = F.l1_loss(pred_matte, gt_matte) + 4.0 * F.l1_loss(pred_boundary_matte, gt_matte)
127
+ matte_compositional_loss = F.l1_loss(image * pred_matte, image * gt_matte) \
128
+ + 4.0 * F.l1_loss(image * pred_boundary_matte, image * gt_matte)
129
+ matte_loss = torch.mean(matte_l1_loss + matte_compositional_loss)
130
+ matte_loss = matte_scale * matte_loss
131
+
132
+ return semantic_loss, detail_loss, matte_loss
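A hedged sketch of how loss_func is typically called; the shapes are assumptions that follow the branch structure further below (the semantic branch predicts at 1/16 of the input resolution, the detail and matte branches at full resolution):

import torch

pred_semantic = torch.rand(1, 1, 32, 32)
pred_detail = torch.rand(1, 1, 512, 512)
pred_matte = torch.rand(1, 1, 512, 512)
image = torch.rand(1, 3, 512, 512)
trimap = torch.full((1, 1, 512, 512), 0.5)   # 0 = background, 1 = foreground, 0.5 = unknown
gt_matte = torch.rand(1, 1, 512, 512)

semantic_loss, detail_loss, matte_loss = loss_func(
    pred_semantic, pred_detail, pred_matte, image, trimap, gt_matte)
total = semantic_loss + detail_loss + matte_loss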
133
+
134
+
135
+ # ------------------------------------------------------------------------------
136
+ # Useful functions
137
+ # ------------------------------------------------------------------------------
138
+
139
+ def _make_divisible(v, divisor, min_value=None):
140
+ if min_value is None:
141
+ min_value = divisor
142
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
143
+ # Make sure that round down does not go down by more than 10%.
144
+ if new_v < 0.9 * v:
145
+ new_v += divisor
146
+ return new_v
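A few illustrative values for the channel-rounding helper; all of these hold under the implementation above:

assert _make_divisible(91, 8) == 88          # rounded to the nearest lower multiple of 8
assert _make_divisible(32 * 0.75, 8) == 24
assert _make_divisible(10, 8) == 16          # rounding down would drop more than 10%, so round up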
147
+
148
+
149
+ def conv_bn(inp, oup, stride):
150
+ return nn.Sequential(
151
+ nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
152
+ nn.BatchNorm2d(oup),
153
+ nn.ReLU6(inplace=True)
154
+ )
155
+
156
+
157
+ def conv_1x1_bn(inp, oup):
158
+ return nn.Sequential(
159
+ nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
160
+ nn.BatchNorm2d(oup),
161
+ nn.ReLU6(inplace=True)
162
+ )
163
+
164
+
165
+ # ------------------------------------------------------------------------------
166
+ # Class of Inverted Residual block
167
+ # ------------------------------------------------------------------------------
168
+
169
+ class InvertedResidual(nn.Module):
170
+ def __init__(self, inp, oup, stride, expansion, dilation=1):
171
+ super(InvertedResidual, self).__init__()
172
+ self.stride = stride
173
+ assert stride in [1, 2]
174
+
175
+ hidden_dim = round(inp * expansion)
176
+ self.use_res_connect = self.stride == 1 and inp == oup
177
+
178
+ if expansion == 1:
179
+ self.conv = nn.Sequential(
180
+ # dw
181
+ nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=False),
182
+ nn.BatchNorm2d(hidden_dim),
183
+ nn.ReLU6(inplace=True),
184
+ # pw-linear
185
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
186
+ nn.BatchNorm2d(oup),
187
+ )
188
+ else:
189
+ self.conv = nn.Sequential(
190
+ # pw
191
+ nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
192
+ nn.BatchNorm2d(hidden_dim),
193
+ nn.ReLU6(inplace=True),
194
+ # dw
195
+ nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=False),
196
+ nn.BatchNorm2d(hidden_dim),
197
+ nn.ReLU6(inplace=True),
198
+ # pw-linear
199
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
200
+ nn.BatchNorm2d(oup),
201
+ )
202
+
203
+ def forward(self, x):
204
+ if self.use_res_connect:
205
+ return x + self.conv(x)
206
+ else:
207
+ return self.conv(x)
208
+
209
+
210
+ # ------------------------------------------------------------------------------
211
+ # Class of MobileNetV2
212
+ # ------------------------------------------------------------------------------
213
+
214
+ class MobileNetV2(nn.Module):
215
+ def __init__(self, in_channels, alpha=1.0, expansion=6, num_classes=1000):
216
+ super(MobileNetV2, self).__init__()
217
+ self.in_channels = in_channels
218
+ self.num_classes = num_classes
219
+ input_channel = 32
220
+ last_channel = 1280
221
+ interverted_residual_setting = [
222
+ # t, c, n, s
223
+ [1, 16, 1, 1],
224
+ [expansion, 24, 2, 2],
225
+ [expansion, 32, 3, 2],
226
+ [expansion, 64, 4, 2],
227
+ [expansion, 96, 3, 1],
228
+ [expansion, 160, 3, 2],
229
+ [expansion, 320, 1, 1],
230
+ ]
231
+
232
+ # building first layer
233
+ input_channel = _make_divisible(input_channel * alpha, 8)
234
+ self.last_channel = _make_divisible(last_channel * alpha, 8) if alpha > 1.0 else last_channel
235
+ self.features = [conv_bn(self.in_channels, input_channel, 2)]
236
+
237
+ # building inverted residual blocks
238
+ for t, c, n, s in interverted_residual_setting:
239
+ output_channel = _make_divisible(int(c * alpha), 8)
240
+ for i in range(n):
241
+ if i == 0:
242
+ self.features.append(InvertedResidual(input_channel, output_channel, s, expansion=t))
243
+ else:
244
+ self.features.append(InvertedResidual(input_channel, output_channel, 1, expansion=t))
245
+ input_channel = output_channel
246
+
247
+ # building last several layers
248
+ self.features.append(conv_1x1_bn(input_channel, self.last_channel))
249
+
250
+ # make it nn.Sequential
251
+ self.features = nn.Sequential(*self.features)
252
+
253
+ # building classifier
254
+ if self.num_classes is not None:
255
+ self.classifier = nn.Sequential(
256
+ nn.Dropout(0.2),
257
+ nn.Linear(self.last_channel, num_classes),
258
+ )
259
+
260
+ # Initialize weights
261
+ self._init_weights()
262
+
263
+ def forward(self, x):
264
+ # Stage1
265
+ x = self.features[0](x)
266
+ x = self.features[1](x)
267
+ # Stage2
268
+ x = self.features[2](x)
269
+ x = self.features[3](x)
270
+ # Stage3
271
+ x = self.features[4](x)
272
+ x = self.features[5](x)
273
+ x = self.features[6](x)
274
+ # Stage4
275
+ x = self.features[7](x)
276
+ x = self.features[8](x)
277
+ x = self.features[9](x)
278
+ x = self.features[10](x)
279
+ x = self.features[11](x)
280
+ x = self.features[12](x)
281
+ x = self.features[13](x)
282
+ # Stage5
283
+ x = self.features[14](x)
284
+ x = self.features[15](x)
285
+ x = self.features[16](x)
286
+ x = self.features[17](x)
287
+ x = self.features[18](x)
288
+
289
+ # Classification
290
+ if self.num_classes is not None:
291
+ x = x.mean(dim=(2, 3))
292
+ x = self.classifier(x)
293
+
294
+ # Output
295
+ return x
296
+
297
+ def _load_pretrained_model(self, pretrained_file):
298
+ pretrain_dict = torch.load(pretrained_file, map_location='cpu')
299
+ model_dict = {}
300
+ state_dict = self.state_dict()
301
+ print("[MobileNetV2] Loading pretrained model...")
302
+ for k, v in pretrain_dict.items():
303
+ if k in state_dict:
304
+ model_dict[k] = v
305
+ else:
306
+ print(k, "is ignored")
307
+ state_dict.update(model_dict)
308
+ self.load_state_dict(state_dict)
309
+
310
+ def _init_weights(self):
311
+ for m in self.modules():
312
+ if isinstance(m, nn.Conv2d):
313
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
314
+ m.weight.data.normal_(0, math.sqrt(2. / n))
315
+ if m.bias is not None:
316
+ m.bias.data.zero_()
317
+ elif isinstance(m, nn.BatchNorm2d):
318
+ m.weight.data.fill_(1)
319
+ m.bias.data.zero_()
320
+ elif isinstance(m, nn.Linear):
321
+ n = m.weight.size(1)
322
+ m.weight.data.normal_(0, 0.01)
323
+ m.bias.data.zero_()
324
+
325
+
326
+ class BaseBackbone(nn.Module):
327
+ """ Superclass of Replaceable Backbone Model for Semantic Estimation
328
+ """
329
+
330
+ def __init__(self, in_channels):
331
+ super(BaseBackbone, self).__init__()
332
+ self.in_channels = in_channels
333
+
334
+ self.model = None
335
+ self.enc_channels = []
336
+
337
+ def forward(self, x):
338
+ raise NotImplementedError
339
+
340
+ def load_pretrained_ckpt(self):
341
+ raise NotImplementedError
342
+
343
+
344
+ class MobileNetV2Backbone(BaseBackbone):
345
+ """ MobileNetV2 Backbone
346
+ """
347
+
348
+ def __init__(self, in_channels):
349
+ super(MobileNetV2Backbone, self).__init__(in_channels)
350
+
351
+ self.model = MobileNetV2(self.in_channels, alpha=1.0, expansion=6, num_classes=None)
352
+ self.enc_channels = [16, 24, 32, 96, 1280]
353
+
354
+ def forward(self, x):
355
+ # x = reduce(lambda x, n: self.model.features[n](x), list(range(0, 2)), x)
356
+ x = self.model.features[0](x)
357
+ x = self.model.features[1](x)
358
+ enc2x = x
359
+
360
+ # x = reduce(lambda x, n: self.model.features[n](x), list(range(2, 4)), x)
361
+ x = self.model.features[2](x)
362
+ x = self.model.features[3](x)
363
+ enc4x = x
364
+
365
+ # x = reduce(lambda x, n: self.model.features[n](x), list(range(4, 7)), x)
366
+ x = self.model.features[4](x)
367
+ x = self.model.features[5](x)
368
+ x = self.model.features[6](x)
369
+ enc8x = x
370
+
371
+ # x = reduce(lambda x, n: self.model.features[n](x), list(range(7, 14)), x)
372
+ x = self.model.features[7](x)
373
+ x = self.model.features[8](x)
374
+ x = self.model.features[9](x)
375
+ x = self.model.features[10](x)
376
+ x = self.model.features[11](x)
377
+ x = self.model.features[12](x)
378
+ x = self.model.features[13](x)
379
+ enc16x = x
380
+
381
+ # x = reduce(lambda x, n: self.model.features[n](x), list(range(14, 19)), x)
382
+ x = self.model.features[14](x)
383
+ x = self.model.features[15](x)
384
+ x = self.model.features[16](x)
385
+ x = self.model.features[17](x)
386
+ x = self.model.features[18](x)
387
+ enc32x = x
388
+ return [enc2x, enc4x, enc8x, enc16x, enc32x]
389
+
390
+ def load_pretrained_ckpt(self):
391
+ # the pre-trained model is provided by https://github.com/thuyngch/Human-Segmentation-PyTorch
392
+ ckpt_path = './pretrained/mobilenetv2_human_seg.ckpt'
393
+ if not os.path.exists(ckpt_path):
394
+ print('cannot find the pretrained mobilenetv2 backbone')
395
+ exit()
396
+
397
+ ckpt = torch.load(ckpt_path)
398
+ self.model.load_state_dict(ckpt)
399
+
400
+
401
+ SUPPORTED_BACKBONES = {
402
+ 'mobilenetv2': MobileNetV2Backbone,
403
+ }
404
+
405
+
406
+ # ------------------------------------------------------------------------------
407
+ # MODNet Basic Modules
408
+ # ------------------------------------------------------------------------------
409
+
410
+ class IBNorm(nn.Module):
411
+ """ Combine Instance Norm and Batch Norm into One Layer
412
+ """
413
+
414
+ def __init__(self, in_channels):
415
+ super(IBNorm, self).__init__()
416
+ in_channels = in_channels
417
+ self.bnorm_channels = int(in_channels / 2)
418
+ self.inorm_channels = in_channels - self.bnorm_channels
419
+
420
+ self.bnorm = nn.BatchNorm2d(self.bnorm_channels, affine=True)
421
+ self.inorm = nn.InstanceNorm2d(self.inorm_channels, affine=False)
422
+
423
+ def forward(self, x):
424
+ bn_x = self.bnorm(x[:, :self.bnorm_channels, ...].contiguous())
425
+ in_x = self.inorm(x[:, self.bnorm_channels:, ...].contiguous())
426
+
427
+ return torch.cat((bn_x, in_x), 1)
428
+
429
+
430
+ class Conv2dIBNormRelu(nn.Module):
431
+ """ Convolution + IBNorm + ReLU
432
+ """
433
+
434
+ def __init__(self, in_channels, out_channels, kernel_size,
435
+ stride=1, padding=0, dilation=1, groups=1, bias=True,
436
+ with_ibn=True, with_relu=True):
437
+ super(Conv2dIBNormRelu, self).__init__()
438
+
439
+ layers = [
440
+ nn.Conv2d(in_channels, out_channels, kernel_size,
441
+ stride=stride, padding=padding, dilation=dilation,
442
+ groups=groups, bias=bias)
443
+ ]
444
+
445
+ if with_ibn:
446
+ layers.append(IBNorm(out_channels))
447
+ if with_relu:
448
+ layers.append(nn.ReLU(inplace=True))
449
+
450
+ self.layers = nn.Sequential(*layers)
451
+
452
+ def forward(self, x):
453
+ return self.layers(x)
454
+
455
+
456
+ class SEBlock(nn.Module):
457
+ """ SE Block Proposed in https://arxiv.org/pdf/1709.01507.pdf
458
+ """
459
+
460
+ def __init__(self, in_channels, out_channels, reduction=1):
461
+ super(SEBlock, self).__init__()
462
+ self.pool = nn.AdaptiveAvgPool2d(1)
463
+ self.fc = nn.Sequential(
464
+ nn.Linear(in_channels, int(in_channels // reduction), bias=False),
465
+ nn.ReLU(inplace=True),
466
+ nn.Linear(int(in_channels // reduction), out_channels, bias=False),
467
+ nn.Sigmoid()
468
+ )
469
+
470
+ def forward(self, x):
471
+ b, c, _, _ = x.size()
472
+ w = self.pool(x).view(b, c)
473
+ w = self.fc(w).view(b, c, 1, 1)
474
+
475
+ return x * w.expand_as(x)
476
+
477
+
478
+ # ------------------------------------------------------------------------------
479
+ # MODNet Branches
480
+ # ------------------------------------------------------------------------------
481
+
482
+ class LRBranch(nn.Module):
483
+ """ Low Resolution Branch of MODNet
484
+ """
485
+
486
+ def __init__(self, backbone):
487
+ super(LRBranch, self).__init__()
488
+
489
+ enc_channels = backbone.enc_channels
490
+
491
+ self.backbone = backbone
492
+ self.se_block = SEBlock(enc_channels[4], enc_channels[4], reduction=4)
493
+ self.conv_lr16x = Conv2dIBNormRelu(enc_channels[4], enc_channels[3], 5, stride=1, padding=2)
494
+ self.conv_lr8x = Conv2dIBNormRelu(enc_channels[3], enc_channels[2], 5, stride=1, padding=2)
495
+ self.conv_lr = Conv2dIBNormRelu(enc_channels[2], 1, kernel_size=3, stride=2, padding=1, with_ibn=False,
496
+ with_relu=False)
497
+
498
+ def forward(self, img, inference):
499
+ enc_features = self.backbone.forward(img)
500
+ enc2x, enc4x, enc32x = enc_features[0], enc_features[1], enc_features[4]
501
+
502
+ enc32x = self.se_block(enc32x)
503
+ lr16x = F.interpolate(enc32x, scale_factor=2, mode='bilinear', align_corners=False)
504
+ lr16x = self.conv_lr16x(lr16x)
505
+ lr8x = F.interpolate(lr16x, scale_factor=2, mode='bilinear', align_corners=False)
506
+ lr8x = self.conv_lr8x(lr8x)
507
+
508
+ pred_semantic = None
509
+ if not inference:
510
+ lr = self.conv_lr(lr8x)
511
+ pred_semantic = torch.sigmoid(lr)
512
+
513
+ return pred_semantic, lr8x, [enc2x, enc4x]
514
+
515
+
516
+ class HRBranch(nn.Module):
517
+ """ High Resolution Branch of MODNet
518
+ """
519
+
520
+ def __init__(self, hr_channels, enc_channels):
521
+ super(HRBranch, self).__init__()
522
+
523
+ self.tohr_enc2x = Conv2dIBNormRelu(enc_channels[0], hr_channels, 1, stride=1, padding=0)
524
+ self.conv_enc2x = Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=2, padding=1)
525
+
526
+ self.tohr_enc4x = Conv2dIBNormRelu(enc_channels[1], hr_channels, 1, stride=1, padding=0)
527
+ self.conv_enc4x = Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1)
528
+
529
+ self.conv_hr4x = nn.Sequential(
530
+ Conv2dIBNormRelu(3 * hr_channels + 3, 2 * hr_channels, 3, stride=1, padding=1),
531
+ Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),
532
+ Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1),
533
+ )
534
+
535
+ self.conv_hr2x = nn.Sequential(
536
+ Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),
537
+ Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1),
538
+ Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1),
539
+ Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1),
540
+ )
541
+
542
+ self.conv_hr = nn.Sequential(
543
+ Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=1, padding=1),
544
+ Conv2dIBNormRelu(hr_channels, 1, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False),
545
+ )
546
+
547
+ def forward(self, img, enc2x, enc4x, lr8x, inference):
548
+ img2x = F.interpolate(img, scale_factor=1 / 2, mode='bilinear', align_corners=False)
549
+ img4x = F.interpolate(img, scale_factor=1 / 4, mode='bilinear', align_corners=False)
550
+
551
+ enc2x = self.tohr_enc2x(enc2x)
552
+ hr4x = self.conv_enc2x(torch.cat((img2x, enc2x), dim=1))
553
+
554
+ enc4x = self.tohr_enc4x(enc4x)
555
+ hr4x = self.conv_enc4x(torch.cat((hr4x, enc4x), dim=1))
556
+
557
+ lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False)
558
+ hr4x = self.conv_hr4x(torch.cat((hr4x, lr4x, img4x), dim=1))
559
+
560
+ hr2x = F.interpolate(hr4x, scale_factor=2, mode='bilinear', align_corners=False)
561
+ hr2x = self.conv_hr2x(torch.cat((hr2x, enc2x), dim=1))
562
+
563
+ pred_detail = None
564
+ if not inference:
565
+ hr = F.interpolate(hr2x, scale_factor=2, mode='bilinear', align_corners=False)
566
+ hr = self.conv_hr(torch.cat((hr, img), dim=1))
567
+ pred_detail = torch.sigmoid(hr)
568
+
569
+ return pred_detail, hr2x
570
+
571
+
572
+ class FusionBranch(nn.Module):
573
+ """ Fusion Branch of MODNet
574
+ """
575
+
576
+ def __init__(self, hr_channels, enc_channels):
577
+ super(FusionBranch, self).__init__()
578
+ self.conv_lr4x = Conv2dIBNormRelu(enc_channels[2], hr_channels, 5, stride=1, padding=2)
579
+
580
+ self.conv_f2x = Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1)
581
+ self.conv_f = nn.Sequential(
582
+ Conv2dIBNormRelu(hr_channels + 3, int(hr_channels / 2), 3, stride=1, padding=1),
583
+ Conv2dIBNormRelu(int(hr_channels / 2), 1, 1, stride=1, padding=0, with_ibn=False, with_relu=False),
584
+ )
585
+
586
+ def forward(self, img, lr8x, hr2x):
587
+ lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False)
588
+ lr4x = self.conv_lr4x(lr4x)
589
+ lr2x = F.interpolate(lr4x, scale_factor=2, mode='bilinear', align_corners=False)
590
+
591
+ f2x = self.conv_f2x(torch.cat((lr2x, hr2x), dim=1))
592
+ f = F.interpolate(f2x, scale_factor=2, mode='bilinear', align_corners=False)
593
+ f = self.conv_f(torch.cat((f, img), dim=1))
594
+ pred_matte = torch.sigmoid(f)
595
+
596
+ return pred_matte
597
+
598
+
599
+ # ------------------------------------------------------------------------------
600
+ # MODNet
601
+ # ------------------------------------------------------------------------------
602
+
603
+ class MODNet(nn.Module):
604
+ """ Architecture of MODNet
605
+ """
606
+
607
+ def __init__(self, in_channels=3, hr_channels=32, backbone_arch='mobilenetv2', backbone_pretrained=False):
608
+ super(MODNet, self).__init__()
609
+
610
+ self.in_channels = in_channels
611
+ self.hr_channels = hr_channels
612
+ self.backbone_arch = backbone_arch
613
+ self.backbone_pretrained = backbone_pretrained
614
+
615
+ self.backbone = SUPPORTED_BACKBONES[self.backbone_arch](self.in_channels)
616
+
617
+ self.lr_branch = LRBranch(self.backbone)
618
+ self.hr_branch = HRBranch(self.hr_channels, self.backbone.enc_channels)
619
+ self.f_branch = FusionBranch(self.hr_channels, self.backbone.enc_channels)
620
+
621
+ for m in self.modules():
622
+ if isinstance(m, nn.Conv2d):
623
+ self._init_conv(m)
624
+ elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d):
625
+ self._init_norm(m)
626
+
627
+ if self.backbone_pretrained:
628
+ self.backbone.load_pretrained_ckpt()
629
+
630
+ def forward(self, img, inference):
631
+ pred_semantic, lr8x, [enc2x, enc4x] = self.lr_branch(img, inference)
632
+ pred_detail, hr2x = self.hr_branch(img, enc2x, enc4x, lr8x, inference)
633
+ pred_matte = self.f_branch(img, lr8x, hr2x)
634
+
635
+ return pred_semantic, pred_detail, pred_matte
636
+
637
+ @staticmethod
638
+ def compute_loss(args):
639
+ pred_semantic, pred_detail, pred_matte, image, trimap, gt_matte = args
640
+ semantic_loss, detail_loss, matte_loss = loss_func(pred_semantic, pred_detail, pred_matte,
641
+ image, trimap, gt_matte)
642
+ loss = semantic_loss + detail_loss + matte_loss
643
+ return matte_loss, loss
644
+
645
+ def freeze_norm(self):
646
+ norm_types = [nn.BatchNorm2d, nn.InstanceNorm2d]
647
+ for m in self.modules():
648
+ for n in norm_types:
649
+ if isinstance(m, n):
650
+ m.eval()
651
+ continue
652
+
653
+ def _init_conv(self, conv):
654
+ nn.init.kaiming_uniform_(
655
+ conv.weight, a=0, mode='fan_in', nonlinearity='relu')
656
+ if conv.bias is not None:
657
+ nn.init.constant_(conv.bias, 0)
658
+
659
+ def _init_norm(self, norm):
660
+ if norm.weight is not None:
661
+ nn.init.constant_(norm.weight, 1)
662
+ nn.init.constant_(norm.bias, 0)
663
+
664
+ def _apply(self, fn):
665
+ super(MODNet, self)._apply(fn)
666
+ blurer._apply(fn) # let blurer's device same as modnet
667
+ return self
animeinsseg/models/animeseg_refine/u2net.py ADDED
@@ -0,0 +1,228 @@
1
+ # Codes are borrowed from
2
+ # https://github.com/xuebinqin/U-2-Net/blob/master/model/u2net_refactor.py
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import math
8
+
9
+ __all__ = ['U2NET_full', 'U2NET_full2', 'U2NET_lite', 'U2NET_lite2', "U2NET"]
10
+
11
+ bce_loss = nn.BCEWithLogitsLoss(reduction='mean')
12
+
13
+
14
+ def _upsample_like(x, size):
15
+ return F.interpolate(x, size=size, mode='bilinear', align_corners=False)
16
+
17
+
18
+ def _size_map(x, height):
19
+ # {height: size} for Upsample
20
+ size = list(x.shape[-2:])
21
+ sizes = {}
22
+ for h in range(1, height):
23
+ sizes[h] = size
24
+ size = [math.ceil(w / 2) for w in size]
25
+ return sizes
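For a concrete sense of this helper: with a hypothetical 256x256 input and height 5, the returned map records the spatial size expected at each decoder level:

import torch

x = torch.zeros(1, 3, 256, 256)   # hypothetical input
sizes = _size_map(x, 5)
# sizes == {1: [256, 256], 2: [128, 128], 3: [64, 64], 4: [32, 32]}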
26
+
27
+
28
+ class REBNCONV(nn.Module):
29
+ def __init__(self, in_ch=3, out_ch=3, dilate=1):
30
+ super(REBNCONV, self).__init__()
31
+
32
+ self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dilate, dilation=1 * dilate)
33
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
34
+ self.relu_s1 = nn.ReLU(inplace=True)
35
+
36
+ def forward(self, x):
37
+ return self.relu_s1(self.bn_s1(self.conv_s1(x)))
38
+
39
+
40
+ class RSU(nn.Module):
41
+ def __init__(self, name, height, in_ch, mid_ch, out_ch, dilated=False):
42
+ super(RSU, self).__init__()
43
+ self.name = name
44
+ self.height = height
45
+ self.dilated = dilated
46
+ self._make_layers(height, in_ch, mid_ch, out_ch, dilated)
47
+
48
+ def forward(self, x):
49
+ sizes = _size_map(x, self.height)
50
+ x = self.rebnconvin(x)
51
+
52
+ # U-Net like symmetric encoder-decoder structure
53
+ def unet(x, height=1):
54
+ if height < self.height:
55
+ x1 = getattr(self, f'rebnconv{height}')(x)
56
+ if not self.dilated and height < self.height - 1:
57
+ x2 = unet(getattr(self, 'downsample')(x1), height + 1)
58
+ else:
59
+ x2 = unet(x1, height + 1)
60
+
61
+ x = getattr(self, f'rebnconv{height}d')(torch.cat((x2, x1), 1))
62
+ return _upsample_like(x, sizes[height - 1]) if not self.dilated and height > 1 else x
63
+ else:
64
+ return getattr(self, f'rebnconv{height}')(x)
65
+
66
+ return x + unet(x)
67
+
68
+ def _make_layers(self, height, in_ch, mid_ch, out_ch, dilated=False):
69
+ self.add_module('rebnconvin', REBNCONV(in_ch, out_ch))
70
+ self.add_module('downsample', nn.MaxPool2d(2, stride=2, ceil_mode=True))
71
+
72
+ self.add_module(f'rebnconv1', REBNCONV(out_ch, mid_ch))
73
+ self.add_module(f'rebnconv1d', REBNCONV(mid_ch * 2, out_ch))
74
+
75
+ for i in range(2, height):
76
+ dilate = 1 if not dilated else 2 ** (i - 1)
77
+ self.add_module(f'rebnconv{i}', REBNCONV(mid_ch, mid_ch, dilate=dilate))
78
+ self.add_module(f'rebnconv{i}d', REBNCONV(mid_ch * 2, mid_ch, dilate=dilate))
79
+
80
+ dilate = 2 if not dilated else 2 ** (height - 1)
81
+ self.add_module(f'rebnconv{height}', REBNCONV(mid_ch, mid_ch, dilate=dilate))
82
+
83
+
84
+ class U2NET(nn.Module):
85
+ def __init__(self, cfgs, out_ch):
86
+ super(U2NET, self).__init__()
87
+ self.out_ch = out_ch
88
+ self._make_layers(cfgs)
89
+
90
+ def forward(self, x):
91
+ sizes = _size_map(x, self.height)
92
+ maps = [] # storage for maps
93
+
94
+ # side saliency map
95
+ def unet(x, height=1):
96
+ if height < 6:
97
+ x1 = getattr(self, f'stage{height}')(x)
98
+ x2 = unet(getattr(self, 'downsample')(x1), height + 1)
99
+ x = getattr(self, f'stage{height}d')(torch.cat((x2, x1), 1))
100
+ side(x, height)
101
+ return _upsample_like(x, sizes[height - 1]) if height > 1 else x
102
+ else:
103
+ x = getattr(self, f'stage{height}')(x)
104
+ side(x, height)
105
+ return _upsample_like(x, sizes[height - 1])
106
+
107
+ def side(x, h):
108
+ # side output saliency map (before sigmoid)
109
+ x = getattr(self, f'side{h}')(x)
110
+ x = _upsample_like(x, sizes[1])
111
+ maps.append(x)
112
+
113
+ def fuse():
114
+ # fuse saliency probability maps
115
+ maps.reverse()
116
+ x = torch.cat(maps, 1)
117
+ x = getattr(self, 'outconv')(x)
118
+ maps.insert(0, x)
119
+ # return [torch.sigmoid(x) for x in maps]
120
+ return [x for x in maps]
121
+
122
+ unet(x)
123
+ maps = fuse()
124
+ return maps
125
+
126
+ @staticmethod
127
+ def compute_loss(args):
128
+ preds, labels_v = args
129
+ d0, d1, d2, d3, d4, d5, d6 = preds
130
+ loss0 = bce_loss(d0, labels_v)
131
+ loss1 = bce_loss(d1, labels_v)
132
+ loss2 = bce_loss(d2, labels_v)
133
+ loss3 = bce_loss(d3, labels_v)
134
+ loss4 = bce_loss(d4, labels_v)
135
+ loss5 = bce_loss(d5, labels_v)
136
+ loss6 = bce_loss(d6, labels_v)
137
+
138
+ loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
139
+
140
+ return loss0, loss
141
+
142
+ def _make_layers(self, cfgs):
143
+ self.height = int((len(cfgs) + 1) / 2)
144
+ self.add_module('downsample', nn.MaxPool2d(2, stride=2, ceil_mode=True))
145
+ for k, v in cfgs.items():
146
+ # build rsu block
147
+ self.add_module(k, RSU(v[0], *v[1]))
148
+ if v[2] > 0:
149
+ # build side layer
150
+ self.add_module(f'side{v[0][-1]}', nn.Conv2d(v[2], self.out_ch, 3, padding=1))
151
+ # build fuse layer
152
+ self.add_module('outconv', nn.Conv2d(int(self.height * self.out_ch), self.out_ch, 1))
153
+
154
+
155
+ def U2NET_full():
156
+ full = {
157
+ # cfgs for building RSUs and sides
158
+ # {stage : [name, (height(L), in_ch, mid_ch, out_ch, dilated), side]}
159
+ 'stage1': ['En_1', (7, 3, 32, 64), -1],
160
+ 'stage2': ['En_2', (6, 64, 32, 128), -1],
161
+ 'stage3': ['En_3', (5, 128, 64, 256), -1],
162
+ 'stage4': ['En_4', (4, 256, 128, 512), -1],
163
+ 'stage5': ['En_5', (4, 512, 256, 512, True), -1],
164
+ 'stage6': ['En_6', (4, 512, 256, 512, True), 512],
165
+ 'stage5d': ['De_5', (4, 1024, 256, 512, True), 512],
166
+ 'stage4d': ['De_4', (4, 1024, 128, 256), 256],
167
+ 'stage3d': ['De_3', (5, 512, 64, 128), 128],
168
+ 'stage2d': ['De_2', (6, 256, 32, 64), 64],
169
+ 'stage1d': ['De_1', (7, 128, 16, 64), 64],
170
+ }
171
+ return U2NET(cfgs=full, out_ch=1)
172
+
173
+
174
+ def U2NET_full2():
175
+ full = {
176
+ # cfgs for building RSUs and sides
177
+ # {stage : [name, (height(L), in_ch, mid_ch, out_ch, dilated), side]}
178
+ 'stage1': ['En_1', (8, 3, 32, 64), -1],
179
+ 'stage2': ['En_2', (7, 64, 32, 128), -1],
180
+ 'stage3': ['En_3', (6, 128, 64, 256), -1],
181
+ 'stage4': ['En_4', (5, 256, 128, 512), -1],
182
+ 'stage5': ['En_5', (5, 512, 256, 512, True), -1],
183
+ 'stage6': ['En_6', (5, 512, 256, 512, True), 512],
184
+ 'stage5d': ['De_5', (5, 1024, 256, 512, True), 512],
185
+ 'stage4d': ['De_4', (5, 1024, 128, 256), 256],
186
+ 'stage3d': ['De_3', (6, 512, 64, 128), 128],
187
+ 'stage2d': ['De_2', (7, 256, 32, 64), 64],
188
+ 'stage1d': ['De_1', (8, 128, 16, 64), 64],
189
+ }
190
+ return U2NET(cfgs=full, out_ch=1)
191
+
192
+
193
+ def U2NET_lite():
194
+ lite = {
195
+ # cfgs for building RSUs and sides
196
+ # {stage : [name, (height(L), in_ch, mid_ch, out_ch, dilated), side]}
197
+ 'stage1': ['En_1', (7, 3, 16, 64), -1],
198
+ 'stage2': ['En_2', (6, 64, 16, 64), -1],
199
+ 'stage3': ['En_3', (5, 64, 16, 64), -1],
200
+ 'stage4': ['En_4', (4, 64, 16, 64), -1],
201
+ 'stage5': ['En_5', (4, 64, 16, 64, True), -1],
202
+ 'stage6': ['En_6', (4, 64, 16, 64, True), 64],
203
+ 'stage5d': ['De_5', (4, 128, 16, 64, True), 64],
204
+ 'stage4d': ['De_4', (4, 128, 16, 64), 64],
205
+ 'stage3d': ['De_3', (5, 128, 16, 64), 64],
206
+ 'stage2d': ['De_2', (6, 128, 16, 64), 64],
207
+ 'stage1d': ['De_1', (7, 128, 16, 64), 64],
208
+ }
209
+ return U2NET(cfgs=lite, out_ch=1)
210
+
211
+
212
+ def U2NET_lite2():
213
+ lite = {
214
+ # cfgs for building RSUs and sides
215
+ # {stage : [name, (height(L), in_ch, mid_ch, out_ch, dilated), side]}
216
+ 'stage1': ['En_1', (8, 3, 16, 64), -1],
217
+ 'stage2': ['En_2', (7, 64, 16, 64), -1],
218
+ 'stage3': ['En_3', (6, 64, 16, 64), -1],
219
+ 'stage4': ['En_4', (5, 64, 16, 64), -1],
220
+ 'stage5': ['En_5', (5, 64, 16, 64, True), -1],
221
+ 'stage6': ['En_6', (5, 64, 16, 64, True), 64],
222
+ 'stage5d': ['De_5', (5, 128, 16, 64, True), 64],
223
+ 'stage4d': ['De_4', (5, 128, 16, 64), 64],
224
+ 'stage3d': ['De_3', (6, 128, 16, 64), 64],
225
+ 'stage2d': ['De_2', (7, 128, 16, 64), 64],
226
+ 'stage1d': ['De_1', (8, 128, 16, 64), 64],
227
+ }
228
+ return U2NET(cfgs=lite, out_ch=1)
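
For orientation, a minimal, hedged usage sketch of the factory functions above (the import path follows the repository layout; the input size is purely illustrative, and the helpers such as _size_map come from the top of this same file):

import torch
from animeinsseg.models.animeseg_refine.u2net import U2NET_lite

net = U2NET_lite().eval()             # 6-stage lite variant, out_ch=1
x = torch.randn(1, 3, 320, 320)       # dummy RGB input
with torch.no_grad():
    maps = net(x)                     # [fused, side1, ..., side6], raw logits
print(len(maps), maps[0].shape)       # 7  torch.Size([1, 1, 320, 320])
probs = [m.sigmoid() for m in maps]   # sigmoid is left to the caller in this fork
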
animeinsseg/models/rtmdet_inshead_custom.py ADDED
@@ -0,0 +1,370 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import copy
3
+ import math
4
+ from typing import List, Optional, Tuple
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from mmcv.cnn import ConvModule, is_norm
10
+ from mmcv.ops import batched_nms
11
+ from mmengine.model import (BaseModule, bias_init_with_prob, constant_init,
12
+ normal_init)
13
+ from mmengine.structures import InstanceData
14
+ from torch import Tensor
15
+
16
+ from mmdet.models.layers.transformer import inverse_sigmoid
17
+ from mmdet.models.utils import (filter_scores_and_topk, multi_apply,
18
+ select_single_mlvl, sigmoid_geometric_mean)
19
+ from mmdet.registry import MODELS
20
+ from mmdet.structures.bbox import (cat_boxes, distance2bbox, get_box_tensor,
21
+ get_box_wh, scale_boxes)
22
+ from mmdet.utils import ConfigType, InstanceList, OptInstanceList, reduce_mean
23
+ from mmdet.models.dense_heads.rtmdet_head import RTMDetHead
24
+ from mmdet.models.dense_heads.rtmdet_ins_head import RTMDetInsHead, RTMDetInsSepBNHead, MaskFeatModule
25
+
26
+ from mmdet.utils import AvoidCUDAOOM
27
+
28
+
29
+
30
+ def sthgoeswrong(logits):
31
+ return torch.any(torch.isnan(logits)) or torch.any(torch.isinf(logits))
32
+
33
+ from time import time
34
+
35
+ @MODELS.register_module(force=True)
36
+ class RTMDetInsHeadCustom(RTMDetInsHead):
37
+
38
+ def loss_by_feat(self,
39
+ cls_scores: List[Tensor],
40
+ bbox_preds: List[Tensor],
41
+ kernel_preds: List[Tensor],
42
+ mask_feat: Tensor,
43
+ batch_gt_instances: InstanceList,
44
+ batch_img_metas: List[dict],
45
+ batch_gt_instances_ignore: OptInstanceList = None):
46
+ """Compute losses of the head.
47
+
48
+ Args:
49
+ cls_scores (list[Tensor]): Box scores for each scale level
50
+ Has shape (N, num_anchors * num_classes, H, W)
51
+ bbox_preds (list[Tensor]): Decoded box for each scale
52
+ level with shape (N, num_anchors * 4, H, W) in
53
+ [tl_x, tl_y, br_x, br_y] format.
54
+ batch_gt_instances (list[:obj:`InstanceData`]): Batch of
55
+ gt_instance. It usually includes ``bboxes`` and ``labels``
56
+ attributes.
57
+ batch_img_metas (list[dict]): Meta information of each image, e.g.,
58
+ image size, scaling factor, etc.
59
+ batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
60
+ Batch of gt_instances_ignore. It includes ``bboxes`` attribute
61
+ data that is ignored during training and testing.
62
+ Defaults to None.
63
+
64
+ Returns:
65
+ dict[str, Tensor]: A dictionary of loss components.
66
+ """
67
+ num_imgs = len(batch_img_metas)
68
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
69
+ assert len(featmap_sizes) == self.prior_generator.num_levels
70
+
71
+ device = cls_scores[0].device
72
+ anchor_list, valid_flag_list = self.get_anchors(
73
+ featmap_sizes, batch_img_metas, device=device)
74
+ flatten_cls_scores = torch.cat([
75
+ cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
76
+ self.cls_out_channels)
77
+ for cls_score in cls_scores
78
+ ], 1)
79
+ flatten_kernels = torch.cat([
80
+ kernel_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
81
+ self.num_gen_params)
82
+ for kernel_pred in kernel_preds
83
+ ], 1)
84
+ decoded_bboxes = []
85
+ for anchor, bbox_pred in zip(anchor_list[0], bbox_preds):
86
+ anchor = anchor.reshape(-1, 4)
87
+ bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
88
+ bbox_pred = distance2bbox(anchor, bbox_pred)
89
+ decoded_bboxes.append(bbox_pred)
90
+
91
+ flatten_bboxes = torch.cat(decoded_bboxes, 1)
92
+ for gt_instances in batch_gt_instances:
93
+ gt_instances.masks = gt_instances.masks.to_tensor(
94
+ dtype=torch.bool, device=device)
95
+
96
+ cls_reg_targets = self.get_targets(
97
+ flatten_cls_scores,
98
+ flatten_bboxes,
99
+ anchor_list,
100
+ valid_flag_list,
101
+ batch_gt_instances,
102
+ batch_img_metas,
103
+ batch_gt_instances_ignore=batch_gt_instances_ignore)
104
+ (anchor_list, labels_list, label_weights_list, bbox_targets_list,
105
+ assign_metrics_list, sampling_results_list) = cls_reg_targets
106
+
107
+ losses_cls, losses_bbox,\
108
+ cls_avg_factors, bbox_avg_factors = multi_apply(
109
+ self.loss_by_feat_single,
110
+ cls_scores,
111
+ decoded_bboxes,
112
+ labels_list,
113
+ label_weights_list,
114
+ bbox_targets_list,
115
+ assign_metrics_list,
116
+ self.prior_generator.strides)
117
+
118
+ cls_avg_factor = reduce_mean(sum(cls_avg_factors)).clamp_(min=1).item()
119
+ losses_cls = list(map(lambda x: x / cls_avg_factor, losses_cls))
120
+
121
+ bbox_avg_factor = reduce_mean(
122
+ sum(bbox_avg_factors)).clamp_(min=1).item()
123
+ losses_bbox = list(map(lambda x: x / bbox_avg_factor, losses_bbox))
124
+
125
+ loss_mask = self.loss_mask_by_feat(mask_feat, flatten_kernels,
126
+ sampling_results_list,
127
+ batch_gt_instances)
128
+ loss = dict(
129
+ loss_cls=losses_cls, loss_bbox=losses_bbox, loss_mask=loss_mask)
130
+
131
+ return loss
132
+
133
+
134
+ def _mask_predict_by_feat_single(self, mask_feat: Tensor, kernels: Tensor,
135
+ priors: Tensor) -> Tensor:
136
+
137
+ ori_maskfeat = mask_feat
138
+
139
+ num_inst = priors.shape[0]
140
+ h, w = mask_feat.size()[-2:]
141
+ if num_inst < 1:
142
+ return torch.empty(
143
+ size=(num_inst, h, w),
144
+ dtype=mask_feat.dtype,
145
+ device=mask_feat.device)
146
+ if len(mask_feat.shape) < 4:
147
+ mask_feat = mask_feat.unsqueeze(0)  # assign the result: unsqueeze is not in-place
148
+
149
+ coord = self.prior_generator.single_level_grid_priors(
150
+ (h, w), level_idx=0, device=mask_feat.device).reshape(1, -1, 2)
151
+ num_inst = priors.shape[0]
152
+ points = priors[:, :2].reshape(-1, 1, 2)
153
+ strides = priors[:, 2:].reshape(-1, 1, 2)
154
+ relative_coord = (points - coord).permute(0, 2, 1) / (
155
+ strides[..., 0].reshape(-1, 1, 1) * 8)
156
+ relative_coord = relative_coord.reshape(num_inst, 2, h, w)
157
+
158
+ mask_feat = torch.cat(
159
+ [relative_coord,
160
+ mask_feat.repeat(num_inst, 1, 1, 1)], dim=1)
161
+ weights, biases = self.parse_dynamic_params(kernels)
162
+
163
+ fp16_used = weights[0].dtype == torch.float16
164
+
165
+ n_layers = len(weights)
166
+ x = mask_feat.reshape(1, -1, h, w)
167
+ for i, (weight, bias) in enumerate(zip(weights, biases)):
168
+ with torch.cuda.amp.autocast(enabled=False):
169
+ if fp16_used:
170
+ weight = weight.to(torch.float32)
171
+ bias = bias.to(torch.float32)
172
+ x = F.conv2d(
173
+ x, weight, bias=bias, stride=1, padding=0, groups=num_inst)
174
+ if i < n_layers - 1:
175
+ x = F.relu(x)
176
+
177
+ if fp16_used:
178
+ x = torch.clip(x, -8192, 8192)
179
+ if sthgoeswrong(x):
180
+ torch.save({'mask_feat': ori_maskfeat, 'kernels': kernels, 'priors': priors}, 'maskhead_nan_input.pt')
181
+ raise Exception('Mask Head NaN')
182
+
183
+ x = x.reshape(num_inst, h, w)
184
+ return x
185
+
186
+ def loss_mask_by_feat(self, mask_feats: Tensor, flatten_kernels: Tensor,
187
+ sampling_results_list: list,
188
+ batch_gt_instances: InstanceList) -> Tensor:
189
+ batch_pos_mask_logits = []
190
+ pos_gt_masks = []
191
+ ignore_masks = []
192
+ for idx, (mask_feat, kernels, sampling_results,
193
+ gt_instances) in enumerate(
194
+ zip(mask_feats, flatten_kernels, sampling_results_list,
195
+ batch_gt_instances)):
196
+ pos_priors = sampling_results.pos_priors
197
+ pos_inds = sampling_results.pos_inds
198
+ pos_kernels = kernels[pos_inds] # n_pos, num_gen_params
199
+ pos_mask_logits = self._mask_predict_by_feat_single(
200
+ mask_feat, pos_kernels, pos_priors)
201
+ if gt_instances.masks.numel() == 0:
202
+ gt_masks = torch.empty_like(gt_instances.masks)
203
+ if gt_masks.shape[0] > 0:
204
+ ignore = torch.zeros(gt_masks.shape[0], dtype=torch.bool).to(device=gt_masks.device)
205
+ ignore_masks.append(ignore)
206
+ else:
207
+ gt_masks = gt_instances.masks[
208
+ sampling_results.pos_assigned_gt_inds, :]
209
+ ignore_masks.append(gt_instances.ignore_mask[sampling_results.pos_assigned_gt_inds])
210
+ batch_pos_mask_logits.append(pos_mask_logits)
211
+ pos_gt_masks.append(gt_masks)
212
+
213
+ pos_gt_masks = torch.cat(pos_gt_masks, 0)
214
+ batch_pos_mask_logits = torch.cat(batch_pos_mask_logits, 0)
215
+ ignore_masks = torch.logical_not(torch.cat(ignore_masks, 0))
216
+
217
+ pos_gt_masks = pos_gt_masks[ignore_masks]
218
+ batch_pos_mask_logits = batch_pos_mask_logits[ignore_masks]
219
+
220
+
221
+ # avg_factor
222
+ num_pos = batch_pos_mask_logits.shape[0]
223
+ num_pos = reduce_mean(mask_feats.new_tensor([num_pos
224
+ ])).clamp_(min=1).item()
225
+
226
+ if batch_pos_mask_logits.shape[0] == 0:
227
+ return mask_feats.sum() * 0
228
+
229
+ scale = self.prior_generator.strides[0][0] // self.mask_loss_stride
230
+ # upsample pred masks
231
+ batch_pos_mask_logits = F.interpolate(
232
+ batch_pos_mask_logits.unsqueeze(0),
233
+ scale_factor=scale,
234
+ mode='bilinear',
235
+ align_corners=False).squeeze(0)
236
+ # downsample gt masks
237
+ pos_gt_masks = pos_gt_masks[:, self.mask_loss_stride //
238
+ 2::self.mask_loss_stride,
239
+ self.mask_loss_stride //
240
+ 2::self.mask_loss_stride]
241
+
242
+ loss_mask = self.loss_mask(
243
+ batch_pos_mask_logits,
244
+ pos_gt_masks,
245
+ weight=None,
246
+ avg_factor=num_pos)
247
+
248
+ return loss_mask
249
+
250
+
251
+ @MODELS.register_module()
252
+ class RTMDetInsSepBNHeadCustom(RTMDetInsSepBNHead):
253
+ def _mask_predict_by_feat_single(self, mask_feat: Tensor, kernels: Tensor,
254
+ priors: Tensor) -> Tensor:
255
+
256
+ ori_maskfeat = mask_feat
257
+
258
+ num_inst = priors.shape[0]
259
+ h, w = mask_feat.size()[-2:]
260
+ if num_inst < 1:
261
+ return torch.empty(
262
+ size=(num_inst, h, w),
263
+ dtype=mask_feat.dtype,
264
+ device=mask_feat.device)
265
+ if len(mask_feat.shape) < 4:
266
+ mask_feat = mask_feat.unsqueeze(0)  # assign the result: unsqueeze is not in-place
267
+
268
+ coord = self.prior_generator.single_level_grid_priors(
269
+ (h, w), level_idx=0, device=mask_feat.device).reshape(1, -1, 2)
270
+ num_inst = priors.shape[0]
271
+ points = priors[:, :2].reshape(-1, 1, 2)
272
+ strides = priors[:, 2:].reshape(-1, 1, 2)
273
+ relative_coord = (points - coord).permute(0, 2, 1) / (
274
+ strides[..., 0].reshape(-1, 1, 1) * 8)
275
+ relative_coord = relative_coord.reshape(num_inst, 2, h, w)
276
+
277
+ mask_feat = torch.cat(
278
+ [relative_coord,
279
+ mask_feat.repeat(num_inst, 1, 1, 1)], dim=1)
280
+ weights, biases = self.parse_dynamic_params(kernels)
281
+
282
+ fp16_used = weights[0].dtype == torch.float16
283
+
284
+ n_layers = len(weights)
285
+ x = mask_feat.reshape(1, -1, h, w)
286
+ for i, (weight, bias) in enumerate(zip(weights, biases)):
287
+ with torch.cuda.amp.autocast(enabled=False):
288
+ if fp16_used:
289
+ weight = weight.to(torch.float32)
290
+ bias = bias.to(torch.float32)
291
+ x = F.conv2d(
292
+ x, weight, bias=bias, stride=1, padding=0, groups=num_inst)
293
+ if i < n_layers - 1:
294
+ x = F.relu(x)
295
+
296
+ if fp16_used:
297
+ x = torch.clip(x, -8192, 8192)
298
+ if sthgoeswrong(x):
299
+ torch.save({'mask_feat': ori_maskfeat, 'kernels': kernels, 'priors': priors}, 'maskhead_nan_input.pt')
300
+ raise Exception('Mask Head NaN')
301
+
302
+ x = x.reshape(num_inst, h, w)
303
+ return x
304
+
305
+ @AvoidCUDAOOM.retry_if_cuda_oom
306
+ def loss_mask_by_feat(self, mask_feats: Tensor, flatten_kernels: Tensor,
307
+ sampling_results_list: list,
308
+ batch_gt_instances: InstanceList) -> Tensor:
309
+ batch_pos_mask_logits = []
310
+ pos_gt_masks = []
311
+ ignore_masks = []
312
+ for idx, (mask_feat, kernels, sampling_results,
313
+ gt_instances) in enumerate(
314
+ zip(mask_feats, flatten_kernels, sampling_results_list,
315
+ batch_gt_instances)):
316
+ pos_priors = sampling_results.pos_priors
317
+ pos_inds = sampling_results.pos_inds
318
+ pos_kernels = kernels[pos_inds] # n_pos, num_gen_params
319
+ pos_mask_logits = self._mask_predict_by_feat_single(
320
+ mask_feat, pos_kernels, pos_priors)
321
+ if gt_instances.masks.numel() == 0:
322
+ gt_masks = torch.empty_like(gt_instances.masks)
323
+ # if gt_masks.shape[0] > 0:
324
+ # ignore = torch.zeros(gt_masks.shape[0], dtype=torch.bool).to(device=gt_masks.device)
325
+ # ignore_masks.append(ignore)
326
+ else:
327
+ msk = torch.logical_not(gt_instances.ignore_mask[sampling_results.pos_assigned_gt_inds])
328
+ gt_masks = gt_instances.masks[
329
+ sampling_results.pos_assigned_gt_inds, :][msk]
330
+ pos_mask_logits = pos_mask_logits[msk]
331
+ # ignore_masks.append(gt_instances.ignore_mask[sampling_results.pos_assigned_gt_inds])
332
+ batch_pos_mask_logits.append(pos_mask_logits)
333
+ pos_gt_masks.append(gt_masks)
334
+
335
+ pos_gt_masks = torch.cat(pos_gt_masks, 0)
336
+ batch_pos_mask_logits = torch.cat(batch_pos_mask_logits, 0)
337
+ # ignore_masks = torch.logical_not(torch.cat(ignore_masks, 0))
338
+
339
+ # pos_gt_masks = pos_gt_masks[ignore_masks]
340
+ # batch_pos_mask_logits = batch_pos_mask_logits[ignore_masks]
341
+
342
+
343
+ # avg_factor
344
+ num_pos = batch_pos_mask_logits.shape[0]
345
+ num_pos = reduce_mean(mask_feats.new_tensor([num_pos
346
+ ])).clamp_(min=1).item()
347
+
348
+ if batch_pos_mask_logits.shape[0] == 0:
349
+ return mask_feats.sum() * 0
350
+
351
+ scale = self.prior_generator.strides[0][0] // self.mask_loss_stride
352
+ # upsample pred masks
353
+ batch_pos_mask_logits = F.interpolate(
354
+ batch_pos_mask_logits.unsqueeze(0),
355
+ scale_factor=scale,
356
+ mode='bilinear',
357
+ align_corners=False).squeeze(0)
358
+ # downsample gt masks
359
+ pos_gt_masks = pos_gt_masks[:, self.mask_loss_stride //
360
+ 2::self.mask_loss_stride,
361
+ self.mask_loss_stride //
362
+ 2::self.mask_loss_stride]
363
+
364
+ loss_mask = self.loss_mask(
365
+ batch_pos_mask_logits,
366
+ pos_gt_masks,
367
+ weight=None,
368
+ avg_factor=num_pos)
369
+
370
+ return loss_mask
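
Both heads above rely on the same dynamic-convolution trick inside _mask_predict_by_feat_single: the kernels parsed by parse_dynamic_params are applied to all instances with one grouped conv2d call instead of a Python loop. A standalone sketch of that mechanism, with illustrative channel counts rather than the head's real configuration and random tensors in place of predicted kernels:

import torch
import torch.nn.functional as F

num_inst, in_ch, out_ch, h, w = 4, 10, 8, 40, 40
feat = torch.randn(num_inst, in_ch, h, w)              # per-instance features (coords + mask feats)
weight = torch.randn(num_inst * out_ch, in_ch, 1, 1)   # one 1x1 kernel bank per instance
bias = torch.randn(num_inst * out_ch)

x = feat.reshape(1, num_inst * in_ch, h, w)            # fold instances into the channel dim
x = F.conv2d(x, weight, bias=bias, stride=1, padding=0, groups=num_inst)
print(x.shape)                                         # torch.Size([1, 32, 40, 40])
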
ccip.py ADDED
@@ -0,0 +1,238 @@
1
+ import json
2
+ import os.path
3
+ from functools import lru_cache
4
+ from typing import Union, List
5
+
6
+ import numpy as np
7
+ from PIL import Image
8
+ from huggingface_hub import hf_hub_download, HfFileSystem
9
+
10
+ try:
11
+ from typing import Literal
12
+ except (ModuleNotFoundError, ImportError):
13
+ from typing_extensions import Literal
14
+
15
+ from imgutils.data import MultiImagesTyping, load_images, ImageTyping
16
+ from imgutils.utils import open_onnx_model
17
+
18
+ hf_fs = HfFileSystem()
19
+
20
+
21
+ def _normalize(data, mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)):
22
+ mean, std = np.asarray(mean), np.asarray(std)
23
+ return (data - mean[:, None, None]) / std[:, None, None]
24
+
25
+
26
+ def _preprocess_image(image: Image.Image, size: int = 384):
27
+ image = image.resize((size, size), resample=Image.BILINEAR)
28
+ # noinspection PyTypeChecker
29
+ data = np.array(image).transpose(2, 0, 1).astype(np.float32) / 255.0
30
+ data = _normalize(data)
31
+
32
+ return data
33
+
34
+
35
+ @lru_cache()
36
+ def _open_feat_model(model):
37
+ return open_onnx_model(hf_hub_download(
38
+ f'deepghs/ccip_onnx',
39
+ f'{model}/model_feat.onnx',
40
+ ))
41
+
42
+
43
+ @lru_cache()
44
+ def _open_metric_model(model):
45
+ return open_onnx_model(hf_hub_download(
46
+ f'deepghs/ccip_onnx',
47
+ f'{model}/model_metrics.onnx',
48
+ ))
49
+
50
+
51
+ @lru_cache()
52
+ def _open_metrics(model):
53
+ with open(hf_hub_download(f'deepghs/ccip_onnx', f'{model}/metrics.json'), 'r') as f:
54
+ return json.load(f)
55
+
56
+
57
+ @lru_cache()
58
+ def _open_cluster_metrics(model):
59
+ with open(hf_hub_download(f'deepghs/ccip_onnx', f'{model}/cluster.json'), 'r') as f:
60
+ return json.load(f)
61
+
62
+
63
+ _VALID_MODEL_NAMES = [
64
+ os.path.basename(os.path.dirname(file)) for file in
65
+ hf_fs.glob('deepghs/ccip_onnx/*/model.ckpt')
66
+ ]
67
+ _DEFAULT_MODEL_NAMES = 'ccip-caformer-24-randaug-pruned'
68
+
69
+
70
+ def ccip_extract_feature(image: ImageTyping, size: int = 384, model: str = _DEFAULT_MODEL_NAMES):
71
+ """
72
+ Extracts the feature vector of the character from the given anime image.
73
+
74
+ :param image: The anime image containing a single character.
75
+ :type image: ImageTyping
76
+
77
+ :param size: The size of the input image to be used for feature extraction. (default: ``384``)
78
+ :type size: int
79
+
80
+ :param model: The name of the model to use for feature extraction. (default: ``ccip-caformer-24-randaug-pruned``)
81
+ The available model names are: ``ccip-caformer-24-randaug-pruned``,
82
+ ``ccip-caformer-6-randaug-pruned_fp32``, ``ccip-caformer-5_fp32``.
83
+ :type model: str
84
+
85
+ :return: The feature vector of the character.
86
+ :rtype: numpy.ndarray
87
+
88
+ Examples::
89
+ >>> from imgutils.metrics import ccip_extract_feature
90
+ >>>
91
+ >>> feat = ccip_extract_feature('ccip/1.jpg')
92
+ >>> feat.shape, feat.dtype
93
+ ((768,), dtype('float32'))
94
+ """
95
+ return ccip_batch_extract_features([image], size, model)[0]
96
+
97
+
98
+ def ccip_batch_extract_features(images: MultiImagesTyping, size: int = 384, model: str = _DEFAULT_MODEL_NAMES):
99
+ """
100
+ Extracts the feature vectors of multiple images using the specified model.
101
+
102
+ :param images: The input images from which to extract the feature vectors.
103
+ :type images: MultiImagesTyping
104
+
105
+ :param size: The size of the input image to be used for feature extraction. (default: ``384``)
106
+ :type size: int
107
+
108
+ :param model: The name of the model to use for feature extraction. (default: ``ccip-caformer-24-randaug-pruned``)
109
+ The available model names are: ``ccip-caformer-24-randaug-pruned``,
110
+ ``ccip-caformer-6-randaug-pruned_fp32``, ``ccip-caformer-5_fp32``.
111
+ :type model: str
112
+
113
+ :return: The feature vectors of the input images.
114
+ :rtype: numpy.ndarray
115
+
116
+ Examples::
117
+ >>> from imgutils.metrics import ccip_batch_extract_features
118
+ >>>
119
+ >>> feat = ccip_batch_extract_features(['ccip/1.jpg', 'ccip/2.jpg', 'ccip/6.jpg'])
120
+ >>> feat.shape, feat.dtype
121
+ ((3, 768), dtype('float32'))
122
+ """
123
+ images = load_images(images, mode='RGB')
124
+ data = np.stack([_preprocess_image(item, size=size) for item in images]).astype(np.float32)
125
+ output, = _open_feat_model(model).run(['output'], {'input': data})
126
+ return output
127
+
128
+
129
+ _FeatureOrImage = Union[ImageTyping, np.ndarray]
130
+
131
+
132
+ def _p_feature(x: _FeatureOrImage, size: int = 384, model: str = _DEFAULT_MODEL_NAMES):
133
+ if isinstance(x, np.ndarray): # if feature
134
+ return x
135
+ else: # is image or path
136
+ return ccip_extract_feature(x, size, model)
137
+
138
+
139
+ def ccip_default_threshold(model: str = _DEFAULT_MODEL_NAMES) -> float:
140
+ """
141
+ Retrieves the default threshold value obtained from model metrics in the Hugging Face model repository.
142
+
143
+ :param model: The name of the model to use for feature extraction. (default: ``ccip-caformer-24-randaug-pruned``)
144
+ The available model names are: ``ccip-caformer-24-randaug-pruned``,
145
+ ``ccip-caformer-6-randaug-pruned_fp32``, ``ccip-caformer-5_fp32``.
146
+ :type model: str
147
+
148
+ :return: The default threshold value obtained from model metrics.
149
+ :rtype: float
150
+
151
+ Examples::
152
+ >>> from imgutils.metrics import ccip_default_threshold
153
+ >>>
154
+ >>> ccip_default_threshold()
155
+ 0.17847511429108218
156
+ >>> ccip_default_threshold('ccip-caformer-6-randaug-pruned_fp32')
157
+ 0.1951224011983088
158
+ >>> ccip_default_threshold('ccip-caformer-5_fp32')
159
+ 0.18397327797685215
160
+ """
161
+ return _open_metrics(model)['threshold']
162
+
163
+
164
+ def ccip_difference(x: _FeatureOrImage, y: _FeatureOrImage,
165
+ size: int = 384, model: str = _DEFAULT_MODEL_NAMES) -> float:
166
+ """
167
+ Calculates the difference value between two anime characters based on their images or feature vectors.
168
+
169
+ :param x: The image or feature vector of the first anime character.
170
+ :type x: Union[ImageTyping, np.ndarray]
171
+
172
+ :param y: The image or feature vector of the second anime character.
173
+ :type y: Union[ImageTyping, np.ndarray]
174
+
175
+ :param size: The size of the input image to be used for feature extraction. (default: ``384``)
176
+ :type size: int
177
+
178
+ :param model: The name of the model to use for feature extraction. (default: ``ccip-caformer-24-randaug-pruned``)
179
+ The available model names are: ``ccip-caformer-24-randaug-pruned``,
180
+ ``ccip-caformer-6-randaug-pruned_fp32``, ``ccip-caformer-5_fp32``.
181
+ :type model: str
182
+
183
+ :return: The difference value between the two anime characters.
184
+ :rtype: float
185
+
186
+ Examples::
187
+ >>> from imgutils.metrics import ccip_difference
188
+ >>>
189
+ >>> ccip_difference('ccip/1.jpg', 'ccip/2.jpg') # same character
190
+ 0.16583099961280823
191
+ >>>
192
+ >>> # different characters
193
+ >>> ccip_difference('ccip/1.jpg', 'ccip/6.jpg')
194
+ 0.42947039008140564
195
+ >>> ccip_difference('ccip/1.jpg', 'ccip/7.jpg')
196
+ 0.4037521779537201
197
+ >>> ccip_difference('ccip/2.jpg', 'ccip/6.jpg')
198
+ 0.4371533691883087
199
+ >>> ccip_difference('ccip/2.jpg', 'ccip/7.jpg')
200
+ 0.40748104453086853
201
+ >>> ccip_difference('ccip/6.jpg', 'ccip/7.jpg')
202
+ 0.392294704914093
203
+ """
204
+ return ccip_batch_differences([x, y], size, model)[0, 1].item()
205
+
206
+
207
+ def ccip_batch_differences(images: List[_FeatureOrImage],
208
+ size: int = 384, model: str = _DEFAULT_MODEL_NAMES) -> np.ndarray:
209
+ """
210
+ Calculates the pairwise differences between a given list of images or feature vectors representing anime characters.
211
+
212
+ :param images: The list of images or feature vectors representing anime characters.
213
+ :type images: List[Union[ImageTyping, np.ndarray]]
214
+
215
+ :param size: The size of the input image to be used for feature extraction. (default: ``384``)
216
+ :type size: int
217
+
218
+ :param model: The name of the model to use for feature extraction. (default: ``ccip-caformer-24-randaug-pruned``)
219
+ The available model names are: ``ccip-caformer-24-randaug-pruned``,
220
+ ``ccip-caformer-6-randaug-pruned_fp32``, ``ccip-caformer-5_fp32``.
221
+ :type model: str
222
+
223
+ :return: The matrix of pairwise differences between the given images or feature vectors.
224
+ :rtype: np.ndarray
225
+
226
+ Examples::
227
+ >>> from imgutils.metrics import ccip_batch_differences
228
+ >>>
229
+ >>> ccip_batch_differences(['ccip/1.jpg', 'ccip/2.jpg', 'ccip/6.jpg', 'ccip/7.jpg'])
230
+ array([[6.5350548e-08, 1.6583106e-01, 4.2947042e-01, 4.0375218e-01],
231
+ [1.6583106e-01, 9.8025822e-08, 4.3715334e-01, 4.0748104e-01],
232
+ [4.2947042e-01, 4.3715334e-01, 3.2675274e-08, 3.9229470e-01],
233
+ [4.0375218e-01, 4.0748104e-01, 3.9229470e-01, 6.5350548e-08]],
234
+ dtype=float32)
235
+ """
236
+ input_ = np.stack([_p_feature(img, size, model) for img in images]).astype(np.float32)
237
+ output, = _open_metric_model(model).run(['output'], {'input': input_})
238
+ return output
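
A minimal usage sketch of the helpers above (file names are placeholders; the first call downloads the ONNX model from the Hugging Face hub): compare the CCIP difference of two character crops against the model's calibrated threshold.

from ccip import ccip_difference, ccip_default_threshold

diff = ccip_difference('crop_a.png', 'crop_b.png')      # placeholder paths
same_character = diff <= ccip_default_threshold()
print(f'diff={diff:.4f}, same_character={same_character}')
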
palette_app.py ADDED
@@ -0,0 +1,134 @@
1
+ import gradio as gr
2
+ import os
3
+ import cv2
4
+ from PIL import Image
5
+ import numpy as np
6
+ from animeinsseg import AnimeInsSeg, AnimeInstances
7
+ from animeinsseg.anime_instances import get_color
8
+ from ccip import _VALID_MODEL_NAMES, _DEFAULT_MODEL_NAMES, ccip_difference, ccip_default_threshold
9
+ from datasets import load_dataset
10
+ import pathlib
11
+
12
+ # Install the required libraries
13
+ os.system("mim install mmengine")
14
+ os.system('mim install mmcv==2.1.0')
15
+ os.system("mim install mmdet==3.2.0")
16
+
17
+ # Load the model
18
+ if not os.path.exists("models"):
19
+ os.mkdir("models")
20
+
21
+ os.system("huggingface-cli lfs-enable-largefiles .")
22
+ os.system("git clone https://huggingface.co/dreMaz/AnimeInstanceSegmentation models/AnimeInstanceSegmentation")
23
+
24
+ ckpt = r'models/AnimeInstanceSegmentation/rtmdetl_e60.ckpt'
25
+
26
+ mask_thres = 0.3
27
+ instance_thres = 0.3
28
+ refine_kwargs = {'refine_method': 'refinenet_isnet'} # set to None if not using refinenet
29
+ # refine_kwargs = None
30
+
31
+ net = AnimeInsSeg(ckpt, mask_thr=mask_thres, refine_kwargs=refine_kwargs)
32
+
33
+ # Load the dataset
34
+ Genshin_Impact_Illustration_ds = load_dataset("svjack/Genshin-Impact-Illustration")["train"]
35
+ ds_size = len(Genshin_Impact_Illustration_ds)
36
+ name_image_dict = {}
37
+ for i in range(ds_size):
38
+ row_dict = Genshin_Impact_Illustration_ds[i]
39
+ name_image_dict[row_dict["name"]] = row_dict["image"]
40
+
41
+ # Collect some example images (local PNG files)
42
+ example_images = list(map(str, list(pathlib.Path(".").rglob("*.png"))))
43
+
44
+ def fn(image, model_name):
45
+ img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
46
+ instances: AnimeInstances = net.infer(
47
+ img,
48
+ output_type='numpy',
49
+ pred_score_thr=instance_thres
50
+ )
51
+
52
+ drawed = img.copy()
53
+ im_h, im_w = img.shape[:2]
54
+
55
+ # instances.bboxes, instances.masks will be None, None if no obj is detected
56
+ if instances.bboxes is None:
57
+ return Image.fromarray(drawed[..., ::-1]), None  # second output is an Image component; return None when nothing is detected
58
+
59
+ # Store the top-5 match results for each bbox
60
+ top5_results = []
61
+
62
+ for ii, (xywh, mask) in enumerate(zip(instances.bboxes, instances.masks)):
63
+ color = get_color(ii)
64
+
65
+ mask_alpha = 0.5
66
+ linewidth = max(round(sum(img.shape) / 2 * 0.003), 2)
67
+
68
+ # Extract the bbox region
69
+ x1, y1, w, h = map(int, xywh)
70
+ x2, y2 = x1 + w, y1 + h
71
+ bbox_image = img[y1:y2, x1:x2]
72
+
73
+ # Compute the similarity
74
+ threshold = ccip_default_threshold(model_name)
75
+ results = []
76
+
77
+ for name, imagey in name_image_dict.items():
78
+ # 将数据集中的图片调整为与 bbox 区域相同的大小
79
+ imagey_resized = cv2.resize(imagey, (w, h))
80
+ diff = ccip_difference(bbox_image, imagey_resized)
81
+ result = (diff, 'Same' if diff <= threshold else 'Not Same', name)
82
+ results.append(result)
83
+
84
+ # Sort by the diff value
85
+ results.sort(key=lambda x: x[0])
86
+ top5_results.append(results[:5]) # keep the top-5 results
87
+
88
+ # Draw the bbox
89
+ p1, p2 = (x1, y1), (x2, y2)
90
+ cv2.rectangle(drawed, p1, p2, color, thickness=linewidth, lineType=cv2.LINE_AA)
91
+
92
+ # Draw the mask
93
+ p = mask.astype(np.float32)
94
+ blend_mask = np.full((im_h, im_w, 3), color, dtype=np.float32)
95
+ alpha_msk = (mask_alpha * p)[..., None]
96
+ alpha_ori = 1 - alpha_msk
97
+ drawed = drawed * alpha_ori + alpha_msk * blend_mask
98
+
99
+ drawed = drawed.astype(np.uint8)
100
+
101
+ # Create the palette image
102
+ palette_height = 100
103
+ palette_width = im_w
104
+ palette = np.zeros((palette_height, palette_width, 3), dtype=np.uint8)
105
+
106
+ # Draw the top-5 results for each bbox
107
+ for idx, results in enumerate(top5_results):
108
+ color = get_color(idx)
109
+ x_start = idx * (palette_width // len(top5_results))
110
+ x_end = (idx + 1) * (palette_width // len(top5_results))
111
+
112
+ # Fill with the instance color
113
+ palette[:, x_start:x_end] = color
114
+
115
+ # Draw the top-5 results on the palette
116
+ for i, (diff, pred, name) in enumerate(results):
117
+ text = f"{name}: {diff:.2f} ({pred})"
118
+ y_pos = 20 + i * 15
119
+ cv2.putText(palette, text, (x_start + 10, y_pos), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1, cv2.LINE_AA)
120
+
121
+ return Image.fromarray(drawed[..., ::-1]), Image.fromarray(palette)
122
+
123
+ # Build the Gradio interface
124
+ iface = gr.Interface(
125
+ # design titles and text descriptions
126
+ title="Anime Subject Instance Segmentation with Similarity Comparison",
127
+ description="Segment image subjects with the proposed model in the paper [*Instance-guided Cartoon Editing with a Large-scale Dataset*](https://cartoonsegmentation.github.io/).",
128
+ fn=fn,
129
+ inputs=[gr.Image(type="numpy"), gr.Dropdown(_VALID_MODEL_NAMES, value=_DEFAULT_MODEL_NAMES, label='Model')],
130
+ outputs=[gr.Image(type="pil", label="Segmentation Result"), gr.Image(type="pil", label="Top5 Results Palette")],
131
+ examples=example_images
132
+ )
133
+
134
+ iface.launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,34 @@
1
+ einops
2
+ imageio
3
+ git+https://github.com/cocodataset/panopticapi.git
4
+ pytorch-lightning
5
+ albumentations
6
+ huggingface_hub
7
+ onnxruntime
8
+
9
+ #mmcv<2.2.0
10
+ mmcv==2.1.0
11
+ mmdet==3.2.0
12
+
13
+ # For Web UI
14
+ #gradio
15
+ torch
16
+ torchvision
17
+ openmim
18
+
19
+ gradio==5.10.0
20
+ numpy
21
+ pillow
22
+ #huggingface_hub
23
+ scikit-image
24
+ pandas
25
+ opencv-python>=4.6.0
26
+ hbutils>=0.9.0
27
+ dghs-imgutils[gpu]>=0.2.3
28
+ #### pip install onnxruntime-gpu==1.17.0 --index-url=https://pkgs.dev.azure.com/onnxruntime/onnxruntime/_packaging/onnxruntime-cuda-12/pypi/simple
29
+ onnxruntime-gpu==1.17.0
30
+ #dghs-imgutils>=0.2.3
31
+ #onnxruntime
32
+ httpx
33
+
34
+ datasets
text_app.py ADDED
@@ -0,0 +1,119 @@
1
+ import gradio as gr
2
+ import os
3
+ import cv2
4
+ from PIL import Image
5
+ import numpy as np
6
+ from animeinsseg import AnimeInsSeg, AnimeInstances
7
+ from animeinsseg.anime_instances import get_color
8
+ from ccip import _VALID_MODEL_NAMES, _DEFAULT_MODEL_NAMES, ccip_difference, ccip_default_threshold
9
+ from datasets import load_dataset
10
+ import pathlib
11
+
12
+ # Install the required libraries
13
+ os.system("mim install mmengine")
14
+ os.system('mim install mmcv==2.1.0')
15
+ os.system("mim install mmdet==3.2.0")
16
+
17
+ # Load the model
18
+ if not os.path.exists("models"):
19
+ os.mkdir("models")
20
+
21
+ os.system("huggingface-cli lfs-enable-largefiles .")
22
+ os.system("git clone https://huggingface.co/dreMaz/AnimeInstanceSegmentation models/AnimeInstanceSegmentation")
23
+
24
+ ckpt = r'models/AnimeInstanceSegmentation/rtmdetl_e60.ckpt'
25
+
26
+ mask_thres = 0.3
27
+ instance_thres = 0.3
28
+ refine_kwargs = {'refine_method': 'refinenet_isnet'} # set to None if not using refinenet
29
+ # refine_kwargs = None
30
+
31
+ net = AnimeInsSeg(ckpt, mask_thr=mask_thres, refine_kwargs=refine_kwargs)
32
+
33
+ # Load the dataset
34
+ Genshin_Impact_Illustration_ds = load_dataset("svjack/Genshin-Impact-Illustration")["train"]
35
+ ds_size = len(Genshin_Impact_Illustration_ds)
36
+ name_image_dict = {}
37
+ for i in range(ds_size):
38
+ row_dict = Genshin_Impact_Illustration_ds[i]
39
+ name_image_dict[row_dict["name"]] = row_dict["image"]
40
+
41
+ # Collect some example images (local PNG files)
42
+ example_images = list(map(str, list(pathlib.Path(".").rglob("*.png"))))
43
+
44
+ def fn(image, model_name):
45
+ img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
46
+ instances: AnimeInstances = net.infer(
47
+ img,
48
+ output_type='numpy',
49
+ pred_score_thr=instance_thres
50
+ )
51
+
52
+ drawed = img.copy()
53
+ im_h, im_w = img.shape[:2]
54
+
55
+ # instances.bboxes, instances.masks will be None, None if no obj is detected
56
+ if instances.bboxes is None:
57
+ return Image.fromarray(drawed[..., ::-1]), "No instances detected"
58
+
59
+ # Store the top-1 match result for each bbox
60
+ top1_results = []
61
+
62
+ for ii, (xywh, mask) in enumerate(zip(instances.bboxes, instances.masks)):
63
+ color = get_color(ii)
64
+
65
+ mask_alpha = 0.5
66
+ linewidth = max(round(sum(img.shape) / 2 * 0.003), 2)
67
+
68
+ # Extract the bbox region
69
+ x1, y1, w, h = map(int, xywh)
70
+ x2, y2 = x1 + w, y1 + h
71
+ bbox_image = img[y1:y2, x1:x2]
72
+
73
+ # Compute the similarity
74
+ threshold = ccip_default_threshold(model_name)
75
+ results = []
76
+
77
+ for name, imagey in name_image_dict.items():
78
+ # Resize the dataset image to the same size as the bbox region (convert PIL to ndarray for cv2)
79
+ imagey_resized = cv2.resize(np.asarray(imagey), (w, h))
80
+ diff = ccip_difference(bbox_image, imagey_resized)
81
+ result = (diff, 'Same' if diff <= threshold else 'Not Same', name)
82
+ results.append(result)
83
+
84
+ # Sort by the diff value
85
+ results.sort(key=lambda x: x[0])
86
+ top1_result = results[0]
87
+ top1_results.append(top1_result)
88
+
89
+ # Draw the bbox
90
+ p1, p2 = (x1, y1), (x2, y2)
91
+ cv2.rectangle(drawed, p1, p2, color, thickness=linewidth, lineType=cv2.LINE_AA)
92
+
93
+ # Draw the mask
94
+ p = mask.astype(np.float32)
95
+ blend_mask = np.full((im_h, im_w, 3), color, dtype=np.float32)
96
+ alpha_msk = (mask_alpha * p)[..., None]
97
+ alpha_ori = 1 - alpha_msk
98
+ drawed = drawed * alpha_ori + alpha_msk * blend_mask
99
+
100
+ drawed = drawed.astype(np.uint8)
101
+
102
+ # Draw the top-1 result next to the bbox
103
+ text = f"Diff: {top1_result[0]:.2f}, {top1_result[1]}, Name: {top1_result[2]}"
104
+ cv2.putText(drawed, text, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1, cv2.LINE_AA)
105
+
106
+ return Image.fromarray(drawed[..., ::-1]), "\n".join([f"Bbox {i+1}: {res}" for i, res in enumerate(top1_results)])
107
+
108
+ # Build the Gradio interface
109
+ iface = gr.Interface(
110
+ # design titles and text descriptions
111
+ title="Anime Subject Instance Segmentation with Similarity Comparison",
112
+ description="Segment image subjects with the proposed model in the paper [*Instance-guided Cartoon Editing with a Large-scale Dataset*](https://cartoonsegmentation.github.io/).",
113
+ fn=fn,
114
+ inputs=[gr.Image(type="numpy"), gr.Dropdown(_VALID_MODEL_NAMES, value=_DEFAULT_MODEL_NAMES, label='Model')],
115
+ outputs=[gr.Image(type="pil"), gr.Textbox(label="Top1 Results for Each Bbox")],
116
+ examples=example_images
117
+ )
118
+
119
+ iface.launch(share=True)
utils/__init__.py ADDED
File without changes
utils/booru_tagger.py ADDED
@@ -0,0 +1,116 @@
1
+ import os
2
+ import gc
3
+ import pandas as pd
4
+ import numpy as np
5
+ from onnxruntime import InferenceSession
6
+ from typing import Tuple, List, Dict
7
+ from io import BytesIO
8
+ from PIL import Image
9
+
10
+ import cv2
11
+ from pathlib import Path
12
+
13
+ from tqdm import tqdm
14
+
15
+ def make_square(img, target_size):
16
+ old_size = img.shape[:2]
17
+ desired_size = max(old_size)
18
+ desired_size = max(desired_size, target_size)
19
+
20
+ delta_w = desired_size - old_size[1]
21
+ delta_h = desired_size - old_size[0]
22
+ top, bottom = delta_h // 2, delta_h - (delta_h // 2)
23
+ left, right = delta_w // 2, delta_w - (delta_w // 2)
24
+
25
+ color = [255, 255, 255]
26
+ new_im = cv2.copyMakeBorder(
27
+ img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color
28
+ )
29
+ return new_im
30
+
31
+
32
+ def smart_resize(img, size):
33
+ # Assumes the image has already gone through make_square
34
+ if img.shape[0] > size:
35
+ img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
36
+ elif img.shape[0] < size:
37
+ img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC)
38
+ return img
39
+
40
+ class Tagger:
41
+ def __init__(self, filename) -> None:
42
+ self.model = InferenceSession(filename, providers=['CUDAExecutionProvider'])
43
+ [root, _] = os.path.split(filename)
44
+ self.tags = pd.read_csv(os.path.join(root, 'selected_tags.csv') if root else 'selected_tags.csv')
45
+
46
+ _, self.height, _, _ = self.model.get_inputs()[0].shape
47
+
48
+ characters = self.tags.loc[self.tags['category'] == 4]
49
+ self.characters = set(characters['name'].values.tolist())
50
+
51
+ def label(self, image: Image.Image) -> Dict[str, float]:
52
+ # alpha to white
53
+ image = image.convert('RGBA')
54
+ new_image = Image.new('RGBA', image.size, 'WHITE')
55
+ new_image.paste(image, mask=image)
56
+ image = new_image.convert('RGB')
57
+ image = np.asarray(image)
58
+
59
+ # PIL RGB to OpenCV BGR
60
+ image = image[:, :, ::-1]
61
+
62
+ image = make_square(image, self.height)
63
+ image = smart_resize(image, self.height)
64
+ image = image.astype(np.float32)
65
+ image = np.expand_dims(image, 0)
66
+
67
+ # evaluate model
68
+ input_name = self.model.get_inputs()[0].name
69
+ label_name = self.model.get_outputs()[0].name
70
+ confidents = self.model.run([label_name], {input_name: image})[0]
71
+
72
+ tags = self.tags[:][['name']]
73
+ tags['confidents'] = confidents[0]
74
+
75
+ # first 4 items are for rating (general, sensitive, questionable, explicit)
76
+ ratings = dict(tags[:4].values)
77
+
78
+ # rest are regular tags
79
+ tags = dict(tags[4:].values)
80
+
81
+ tags = {t: v for t, v in tags.items() if v > 0.5}
82
+ return tags
83
+
84
+ def label_cv2_bgr(self, image: np.ndarray) -> Tuple[List[str], List[str]]:
85
+ # image in BGR u8
86
+ image = make_square(image, self.height)
87
+ image = smart_resize(image, self.height)
88
+ image = image.astype(np.float32)
89
+ image = np.expand_dims(image, 0)
90
+
91
+ # evaluate model
92
+ input_name = self.model.get_inputs()[0].name
93
+ label_name = self.model.get_outputs()[0].name
94
+ confidents = self.model.run([label_name], {input_name: image})[0]
95
+
96
+ tags = self.tags[:][['name']]
97
+ cats = self.tags[:][['category']]
98
+ tags['confidents'] = confidents[0]
99
+
100
+ # first 4 items are for rating (general, sensitive, questionable, explicit)
101
+ ratings = dict(tags[:4].values)
102
+
103
+ # rest are regular tags
104
+ tags = dict(tags[4:].values)
105
+
106
+ tags = [t for t, v in tags.items() if v > 0.5]
107
+ character_str = []
108
+ for t in tags:
109
+ if t in self.characters:
110
+ character_str.append(t)
111
+ return tags, character_str
112
+
113
+
114
+ if __name__ == '__main__':
115
+ modelp = r'models/wd-v1-4-swinv2-tagger-v2/model.onnx'
116
+ tagger = Tagger(modelp)
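
A short, hedged usage sketch of the Tagger wrapper above. It assumes the WD14 ONNX checkpoint and its selected_tags.csv sit in the same directory (mirroring the __main__ example) and that onnxruntime can create the session; the image path is a placeholder.

from PIL import Image
from utils.booru_tagger import Tagger

tagger = Tagger('models/wd-v1-4-swinv2-tagger-v2/model.onnx')
tags = tagger.label(Image.open('example.png'))   # {tag: confidence}, thresholded at 0.5
print(sorted(tags.items(), key=lambda kv: kv[1], reverse=True)[:10])
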
utils/constants.py ADDED
@@ -0,0 +1,82 @@
1
+ import torch
2
+
3
+ CATEGORIES = [
4
+ {"id": 0, "name": "object", "isthing": 1}
5
+ ]
6
+
7
+ IMAGE_ID_ZFILL = 12
8
+
9
+ COLOR_PALETTE = [
10
+ (220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230), (106, 0, 228),
11
+ (0, 60, 100), (0, 80, 100), (0, 0, 70), (0, 0, 192), (250, 170, 30),
12
+ (100, 170, 30), (220, 220, 0), (175, 116, 175), (250, 0, 30),
13
+ (165, 42, 42), (255, 77, 255), (0, 226, 252), (182, 182, 255),
14
+ (0, 82, 0), (120, 166, 157), (110, 76, 0), (174, 57, 255),
15
+ (199, 100, 0), (72, 0, 118), (255, 179, 240), (0, 125, 92),
16
+ (209, 0, 151), (188, 208, 182), (0, 220, 176), (255, 99, 164),
17
+ (92, 0, 73), (133, 129, 255), (78, 180, 255), (0, 228, 0),
18
+ (174, 255, 243), (45, 89, 255), (134, 134, 103), (145, 148, 174),
19
+ (255, 208, 186), (197, 226, 255), (171, 134, 1), (109, 63, 54),
20
+ (207, 138, 255), (151, 0, 95), (9, 80, 61), (84, 105, 51),
21
+ (74, 65, 105), (166, 196, 102), (208, 195, 210), (255, 109, 65),
22
+ (0, 143, 149), (179, 0, 194), (209, 99, 106), (5, 121, 0),
23
+ (227, 255, 205), (147, 186, 208), (153, 69, 1), (3, 95, 161),
24
+ (163, 255, 0), (119, 0, 170), (0, 182, 199), (0, 165, 120),
25
+ (183, 130, 88), (95, 32, 0), (130, 114, 135), (110, 129, 133),
26
+ (166, 74, 118), (219, 142, 185), (79, 210, 114), (178, 90, 62),
27
+ (65, 70, 15), (127, 167, 115), (59, 105, 106), (142, 108, 45),
28
+ (196, 172, 0), (95, 54, 80), (128, 76, 255), (201, 57, 1),
29
+ (246, 0, 122), (191, 162, 208), (255, 255, 128), (147, 211, 203),
30
+ (150, 100, 100), (168, 171, 172), (146, 112, 198), (210, 170, 100),
31
+ (92, 136, 89), (218, 88, 184), (241, 129, 0), (217, 17, 255),
32
+ (124, 74, 181), (70, 70, 70), (255, 228, 255), (154, 208, 0),
33
+ (193, 0, 92), (76, 91, 113), (255, 180, 195), (106, 154, 176),
34
+ (230, 150, 140), (60, 143, 255), (128, 64, 128), (92, 82, 55),
35
+ (254, 212, 124), (73, 77, 174), (255, 160, 98), (255, 255, 255),
36
+ (104, 84, 109), (169, 164, 131), (225, 199, 255), (137, 54, 74),
37
+ (135, 158, 223), (7, 246, 231), (107, 255, 200), (58, 41, 149),
38
+ (183, 121, 142), (255, 73, 97), (107, 142, 35), (190, 153, 153),
39
+ (146, 139, 141), (70, 130, 180), (134, 199, 156), (209, 226, 140),
40
+ (96, 36, 108), (96, 96, 96), (64, 170, 64), (152, 251, 152),
41
+ (208, 229, 228), (206, 186, 171), (152, 161, 64), (116, 112, 0),
42
+ (0, 114, 143), (102, 102, 156), (250, 141, 255)
43
+ ]
44
+
45
+ class Colors:
46
+ # Ultralytics color palette https://ultralytics.com/
47
+ def __init__(self):
48
+ # hex = matplotlib.colors.TABLEAU_COLORS.values()
49
+ hexs = ('FF1010', '10FF10', 'FFF010', '100FFF', '0018EC', 'FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
50
+ '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
51
+ self.palette = [self.hex2rgb(f'#{c}') for c in hexs]
52
+ self.n = len(self.palette)
53
+
54
+ def __call__(self, i, bgr=True):
55
+ c = self.palette[int(i) % self.n]
56
+ return (c[2], c[1], c[0]) if bgr else c
57
+
58
+ @staticmethod
59
+ def hex2rgb(h): # rgb order (PIL)
60
+ return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
61
+
62
+ colors = Colors()
63
+ def get_color(idx):
64
+ if idx == -1:
65
+ return 255
66
+ else:
67
+ return colors(idx)
68
+
69
+
70
+ MULTIPLE_TAGS = {'2girls', '3girls', '4girls', '5girls', '6+girls', 'multiple_girls',
71
+ '2boys', '3boys', '4boys', '5boys', '6+boys', 'multiple_boys',
72
+ '2others', '3others', '4others', '5others', '6+others', 'multiple_others'}
73
+
74
+ if hasattr(torch, 'cuda'):
75
+ DEFAULT_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
76
+ else:
77
+ DEFAULT_DEVICE = 'cpu'
78
+
79
+ DEFAULT_DETECTOR_CKPT = 'models/AnimeInstanceSegmentation/rtmdetl_e60.ckpt'
80
+ DEFAULT_DEPTHREFINE_CKPT = 'models/AnimeInstanceSegmentation/kenburns_depth_refinenet.ckpt'
81
+ DEFAULT_INPAINTNET_CKPT = 'models/AnimeInstanceSegmentation/kenburns_inpaintnet.ckpt'
82
+ DEPTH_ZOE_CKPT = 'models/AnimeInstanceSegmentation/ZoeD_M12_N.pt'
utils/cupy_utils.py ADDED
@@ -0,0 +1,122 @@
1
+ import re
2
+ import os
3
+ import cupy
4
+ import os.path as osp
5
+ import torch
6
+
7
+ @cupy.memoize(for_each_device=True)
8
+ def launch_kernel(strFunction, strKernel):
9
+ if 'CUDA_HOME' not in os.environ:
10
+ os.environ['CUDA_HOME'] = cupy.cuda.get_cuda_path()
11
+ # end
12
+ # , options=tuple([ '-I ' + os.environ['CUDA_HOME'], '-I ' + os.environ['CUDA_HOME'] + '/include' ])
13
+ return cupy.RawKernel(strKernel, strFunction)
14
+
15
+
16
+ def preprocess_kernel(strKernel, objVariables):
17
+ path_to_math_helper = osp.join(osp.dirname(osp.abspath(__file__)), 'helper_math.h')
18
+ strKernel = '''
19
+ #include <{{HELPER_PATH}}>
20
+
21
+ __device__ __forceinline__ float atomicMin(const float* buffer, float dblValue) {
22
+ int intValue = __float_as_int(*buffer);
23
+
24
+ while (__int_as_float(intValue) > dblValue) {
25
+ intValue = atomicCAS((int*) (buffer), intValue, __float_as_int(dblValue));
26
+ }
27
+
28
+ return __int_as_float(intValue);
29
+ }
30
+
31
+
32
+ __device__ __forceinline__ float atomicMax(const float* buffer, float dblValue) {
33
+ int intValue = __float_as_int(*buffer);
34
+
35
+ while (__int_as_float(intValue) < dblValue) {
36
+ intValue = atomicCAS((int*) (buffer), intValue, __float_as_int(dblValue));
37
+ }
38
+
39
+ return __int_as_float(intValue);
40
+ }
41
+ '''.replace('{{HELPER_PATH}}', path_to_math_helper) + strKernel
42
+ # end
43
+
44
+ for strVariable in objVariables:
45
+ objValue = objVariables[strVariable]
46
+
47
+ if type(objValue) == int:
48
+ strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))
49
+
50
+ elif type(objValue) == float:
51
+ strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))
52
+
53
+ elif type(objValue) == str:
54
+ strKernel = strKernel.replace('{{' + strVariable + '}}', objValue)
55
+
56
+ # end
57
+ # end
58
+
59
+ while True:
60
+ objMatch = re.search(r'(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel)
61
+
62
+ if objMatch is None:
63
+ break
64
+ # end
65
+
66
+ intArg = int(objMatch.group(2))
67
+
68
+ strTensor = objMatch.group(4)
69
+ intSizes = objVariables[strTensor].size()
70
+
71
+ strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg] if torch.is_tensor(intSizes[intArg]) == False else intSizes[intArg].item()))
72
+ # end
73
+
74
+ while True:
75
+ objMatch = re.search(r'(STRIDE_)([0-4])(\()([^\)]*)(\))', strKernel)
76
+
77
+ if objMatch is None:
78
+ break
79
+ # end
80
+
81
+ intArg = int(objMatch.group(2))
82
+
83
+ strTensor = objMatch.group(4)
84
+ intStrides = objVariables[strTensor].stride()
85
+
86
+ strKernel = strKernel.replace(objMatch.group(), str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()))
87
+ # end
88
+
89
+ while True:
90
+ objMatch = re.search(r'(OFFSET_)([0-4])(\()([^\)]+)(\))', strKernel)
91
+
92
+ if objMatch is None:
93
+ break
94
+ # end
95
+
96
+ intArgs = int(objMatch.group(2))
97
+ strArgs = objMatch.group(4).split(',')
98
+
99
+ strTensor = strArgs[0]
100
+ intStrides = objVariables[strTensor].stride()
101
+ strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')' for intArg in range(intArgs) ]
102
+
103
+ strKernel = strKernel.replace(objMatch.group(0), '(' + str.join('+', strIndex) + ')')
104
+ # end
105
+
106
+ while True:
107
+ objMatch = re.search(r'(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel)
108
+
109
+ if objMatch is None:
110
+ break
111
+ # end
112
+
113
+ intArgs = int(objMatch.group(2))
114
+ strArgs = objMatch.group(4).split(',')
115
+
116
+ strTensor = strArgs[0]
117
+ intStrides = objVariables[strTensor].stride()
118
+ strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')' for intArg in range(intArgs) ]
119
+
120
+ strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']')
121
+ # end
122
+ return strKernel
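
To illustrate the macro expansion performed by preprocess_kernel: SIZE_n(tensor) is replaced by the literal size of dimension n, and VALUE_k(tensor, i0, ..., i{k-1}) becomes a stride-weighted index into the flattened tensor. A small, hedged sketch (names are illustrative; cupy must be importable because this module imports it, but no kernel is launched and no GPU work happens for the string expansion itself):

import torch
from utils.cupy_utils import preprocess_kernel

x = torch.zeros(2, 3, 4, 5)
src = 'const int n = SIZE_0(x); float v = VALUE_4(x, i0, i1, i2, i3);'
expanded = preprocess_kernel(src, {'x': x})
# tail of the expanded source contains e.g.:
#   const int n = 2; float v = x[((i0)*60)+((i1)*20)+((i2)*5)+((i3)*1)];
print(expanded[-120:])
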
utils/effects.py ADDED
@@ -0,0 +1,182 @@
1
+ from numba import jit, njit
2
+ import numpy as np
3
+ import time
4
+ import cv2
5
+ import math
6
+ from pathlib import Path
7
+ import os.path as osp
8
+ import torch
9
+ from .cupy_utils import launch_kernel, preprocess_kernel
10
+ import cupy
11
+
12
+ def bokeh_filter_cupy(img, depth, dx, dy, im_h, im_w, num_samples=32):
13
+ blurred = img.clone()
14
+ n = im_h * im_w
15
+
16
+ str_kernel = '''
17
+ extern "C" __global__ void kernel_bokeh(
18
+ const int n,
19
+ const int h,
20
+ const int w,
21
+ const int nsamples,
22
+ const float dx,
23
+ const float dy,
24
+ const float* img,
25
+ const float* depth,
26
+ float* blurred
27
+ ) {
28
+
29
+ const int im_size = min(h, w);
30
+ const int sample_offset = nsamples / 2;
31
+ for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n * 3; intIndex += blockDim.x * gridDim.x) {
32
+
33
+ const int intSample = intIndex / 3;
34
+
35
+ const int c = intIndex % 3;
36
+ const int y = ( intSample / w) % h;
37
+ const int x = intSample % w;
38
+
39
+ const int flatten_xy = y * w + x;
40
+ const int fid = flatten_xy * 3 + c;
41
+ const float d = depth[flatten_xy];
42
+
43
+ const float _dx = dx * d;
44
+ const float _dy = dy * d;
45
+ float weight = 0;
46
+ float color = 0;
47
+ for (int s = 0; s < nsamples; s += 1) {
48
+
49
+ const int sp = (s - sample_offset) * im_size;
50
+ const int x_ = x + int(round(_dx * sp));
51
+ const int y_ = y + int(round(_dy * sp));
52
+
53
+ if ((x_ >= w) | (y_ >= h) | (x_ < 0) | (y_ < 0))
54
+ continue;
55
+
56
+ const int flatten_xy_ = y_ * w + x_;
57
+ const float w_ = depth[flatten_xy_];
58
+ weight += w_;
59
+ const int fid_ = flatten_xy_ * 3 + c;
60
+ color += img[fid_] * w_;
61
+ }
62
+
63
+ if (weight != 0) {
64
+ color /= weight;
65
+ }
66
+ else {
67
+ color = img[fid];
68
+ }
69
+
70
+ blurred[fid] = color;
71
+
72
+ }
73
+
74
+ }
75
+ '''
76
+ launch_kernel('kernel_bokeh', str_kernel)(
77
+ grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]),
78
+ block=tuple([ 512, 1, 1 ]),
79
+ args=[ cupy.int32(n), cupy.int32(im_h), cupy.int32(im_w), \
80
+ cupy.int32(num_samples), cupy.float32(dx), cupy.float32(dy),
81
+ img.data_ptr(), depth.data_ptr(), blurred.data_ptr() ]
82
+ )
83
+
84
+ return blurred
85
+
86
+
87
+ def np2flatten_tensor(arr: np.ndarray, to_cuda: bool = True) -> torch.Tensor:
88
+ c = 1
89
+ if len(arr.shape) == 3:
90
+ c = arr.shape[2]
91
+ else:
92
+ arr = arr[..., None]
93
+ arr = arr.transpose((2, 0, 1))[None, ...]
94
+ t = torch.from_numpy(arr).view(1, c, -1)
95
+
96
+ if to_cuda:
97
+ t = t.cuda()
98
+ return t
99
+
100
+ def ftensor2img(t: torch.Tensor, im_h, im_w):
101
+ t = t.detach().cpu().numpy().squeeze()
102
+ c = t.shape[0]
103
+ t = t.transpose((1, 0)).reshape((im_h, im_w, c))
104
+ return t
105
+
106
+
107
+ @njit
108
+ def bokeh_filter(img, depth, dx, dy, num_samples=32):
109
+
110
+ sample_offset = num_samples // 2
111
+ # _scale = 0.0005
112
+ # depth = depth * _scale
113
+
114
+ im_h, im_w = img.shape[0], img.shape[1]
115
+ im_size = min(im_h, im_w)
116
+ blured = np.zeros_like(img)
117
+ for x in range(im_w):
118
+ for y in range(im_h):
119
+ d = depth[y, x]
120
+ _color = np.array([0, 0, 0], dtype=np.float32)
121
+ _dx = dx * d
122
+ _dy = dy * d
123
+ weight = 0
124
+ for s in range(num_samples):
125
+ s = (s - sample_offset) * im_size
126
+ x_ = x + int(round(_dx * s))
127
+ y_ = y + int(round(_dy * s))
128
+ if x_ >= im_w or y_ >= im_h or x_ < 0 or y_ < 0:
129
+ continue
130
+ _w = depth[y_, x_]
131
+ weight += _w
132
+ _color += img[y_, x_] * _w
133
+ if weight == 0:
134
+ blured[y, x] = img[y, x]
135
+ else:
136
+ blured[y, x] = _color / np.array([weight, weight, weight], dtype=np.float32)
137
+
138
+ return blured
139
+
140
+
141
+
142
+
143
+ def bokeh_blur(img, depth, num_samples=32, lightness_factor=10, depth_factor=2, use_cuda=False, focal_plane=None):
144
+ img = np.ascontiguousarray(img)
145
+
146
+ if depth is not None:
147
+ depth = depth.astype(np.float32)
148
+ if focal_plane is not None:
149
+ depth = depth.max() - np.abs(depth - focal_plane)
150
+ if depth_factor != 1:
151
+ depth = np.power(depth, depth_factor)
152
+ depth = depth - depth.min()
153
+ depth = depth.astype(np.float32) / depth.max()
154
+ depth = 1 - depth
155
+
156
+ img = img.astype(np.float32) / 255
157
+ img_hightlighted = np.power(img, lightness_factor)
158
+
159
+ # img =
160
+ im_h, im_w = img.shape[:2]
161
+ PI = math.pi
162
+
163
+ _scale = 0.0005
164
+ depth = depth * _scale
165
+
166
+ if use_cuda:
167
+ img_hightlighted = np2flatten_tensor(img_hightlighted, True)
168
+ depth = np2flatten_tensor(depth, True)
169
+ vertical_blured = bokeh_filter_cupy(img_hightlighted, depth, 0, 1, im_h, im_w, num_samples)
170
+ diag_blured = bokeh_filter_cupy(vertical_blured, depth, math.cos(-PI/6), math.sin(-PI/6), im_h, im_w, num_samples)
171
+ rhom_blur = bokeh_filter_cupy(diag_blured, depth, math.cos(-PI * 5 /6), math.sin(-PI * 5 /6), im_h, im_w, num_samples)
172
+ blured = (diag_blured + rhom_blur) / 2
173
+ blured = ftensor2img(blured, im_h, im_w)
174
+ else:
175
+ vertical_blured = bokeh_filter(img_hightlighted, depth, 0, 1, num_samples)
176
+ diag_blured = bokeh_filter(vertical_blured, depth, math.cos(-PI/6), math.sin(-PI/6), num_samples)
177
+ rhom_blur = bokeh_filter(diag_blured, depth, math.cos(-PI * 5 /6), math.sin(-PI * 5 /6), num_samples)
178
+ blured = (diag_blured + rhom_blur) / 2
179
+ blured = np.power(blured, 1 / lightness_factor)
180
+ blured = (blured * 255).astype(np.uint8)
181
+
182
+ return blured
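
A hedged usage sketch for bokeh_blur (paths are placeholders; depth is a single-channel map aligned with the image, and use_cuda=False keeps everything on the numba CPU path):

import cv2
from utils.effects import bokeh_blur

img = cv2.imread('photo.png')                                # BGR uint8
depth = cv2.imread('photo_depth.png', cv2.IMREAD_GRAYSCALE)  # single-channel depth map
blurred = bokeh_blur(img, depth, num_samples=32, use_cuda=False)
cv2.imwrite('photo_bokeh.png', blurred)
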
utils/env_utils.py ADDED
@@ -0,0 +1,65 @@
1
+ import os
2
+ import platform
3
+ import warnings
4
+
5
+ import torch.multiprocessing as mp
6
+
7
+
8
+ def set_multi_processing(
9
+ mp_start_method: str = "fork", opencv_num_threads: int = 0, distributed: bool = True
10
+ ) -> None:
11
+ """Set multi-processing related environment.
12
+
13
+ This function is refered from https://github.com/open-mmlab/mmengine/blob/main/mmengine/utils/dl_utils/setup_env.py
14
+
15
+ Args:
16
+ mp_start_method (str): Set the method which should be used to start
17
+ child processes. Defaults to 'fork'.
18
+ opencv_num_threads (int): Number of threads for opencv.
19
+ Defaults to 0.
20
+ distributed (bool): True if distributed environment.
21
+ Defaults to False.
22
+ """ # noqa
23
+ # set multi-process start method as `fork` to speed up the training
24
+ if platform.system() != "Windows":
25
+ current_method = mp.get_start_method(allow_none=True)
26
+ if current_method is not None and current_method != mp_start_method:
27
+ warnings.warn(
28
+ f"Multi-processing start method `{mp_start_method}` is "
29
+ f"different from the previous setting `{current_method}`."
30
+ f"It will be force set to `{mp_start_method}`. You can "
31
+ "change this behavior by changing `mp_start_method` in "
32
+ "your config."
33
+ )
34
+ mp.set_start_method(mp_start_method, force=True)
35
+
36
+ try:
37
+ import cv2
38
+
39
+ # disable opencv multithreading to avoid system being overloaded
40
+ cv2.setNumThreads(opencv_num_threads)
41
+ except ImportError:
42
+ pass
43
+
44
+ # setup OMP threads
45
+ # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa
46
+ if "OMP_NUM_THREADS" not in os.environ and distributed:
47
+ omp_num_threads = 1
48
+ warnings.warn(
49
+ "Setting OMP_NUM_THREADS environment variable for each process"
50
+ f" to be {omp_num_threads} in default, to avoid your system "
51
+ "being overloaded, please further tune the variable for "
52
+ "optimal performance in your application as needed."
53
+ )
54
+ os.environ["OMP_NUM_THREADS"] = str(omp_num_threads)
55
+
56
+ # # setup MKL threads
57
+ if "MKL_NUM_THREADS" not in os.environ and distributed:
58
+ mkl_num_threads = 1
59
+ warnings.warn(
60
+ "Setting MKL_NUM_THREADS environment variable for each process"
61
+ f" to be {mkl_num_threads} in default, to avoid your system "
62
+ "being overloaded, please further tune the variable for "
63
+ "optimal performance in your application as needed."
64
+ )
65
+ os.environ["MKL_NUM_THREADS"] = str(mkl_num_threads)
utils/helper_math.h ADDED
@@ -0,0 +1,1449 @@
1
+ /**
2
+ * Copyright 1993-2012 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * Please refer to the NVIDIA end user license agreement (EULA) associated
5
+ * with this source code for terms and conditions that govern your use of
6
+ * this software. Any use, reproduction, disclosure, or distribution of
7
+ * this software and related documentation outside the terms of the EULA
8
+ * is strictly prohibited.
9
+ *
10
+ */
11
+
12
+ /*
13
+ * This file implements common mathematical operations on vector types
14
+ * (float3, float4 etc.) since these are not provided as standard by CUDA.
15
+ *
16
+ * The syntax is modeled on the Cg standard library.
17
+ *
18
+ * This is part of the Helper library includes
19
+ *
20
+ * Thanks to Linh Hah for additions and fixes.
21
+ */
22
+
23
+ #ifndef HELPER_MATH_H
24
+ #define HELPER_MATH_H
25
+
26
+ #include "cuda_runtime.h"
27
+
28
+ typedef unsigned int uint;
29
+ typedef unsigned short ushort;
30
+
31
+ #ifndef __CUDACC__
32
+ #include <math.h>
33
+
34
+ ////////////////////////////////////////////////////////////////////////////////
35
+ // host implementations of CUDA functions
36
+ ////////////////////////////////////////////////////////////////////////////////
37
+
38
+ inline float fminf(float a, float b)
39
+ {
40
+ return a < b ? a : b;
41
+ }
42
+
43
+ inline float fmaxf(float a, float b)
44
+ {
45
+ return a > b ? a : b;
46
+ }
47
+
48
+ inline int max(int a, int b)
49
+ {
50
+ return a > b ? a : b;
51
+ }
52
+
53
+ inline int min(int a, int b)
54
+ {
55
+ return a < b ? a : b;
56
+ }
57
+
58
+ inline float rsqrtf(float x)
59
+ {
60
+ return 1.0f / sqrtf(x);
61
+ }
62
+ #endif
63
+
64
+ ////////////////////////////////////////////////////////////////////////////////
65
+ // constructors
66
+ ////////////////////////////////////////////////////////////////////////////////
67
+
68
+ inline __host__ __device__ float2 make_float2(float s)
69
+ {
70
+ return make_float2(s, s);
71
+ }
72
+ inline __host__ __device__ float2 make_float2(float3 a)
73
+ {
74
+ return make_float2(a.x, a.y);
75
+ }
76
+ inline __host__ __device__ float2 make_float2(int2 a)
77
+ {
78
+ return make_float2(float(a.x), float(a.y));
79
+ }
80
+ inline __host__ __device__ float2 make_float2(uint2 a)
81
+ {
82
+ return make_float2(float(a.x), float(a.y));
83
+ }
84
+
85
+ inline __host__ __device__ int2 make_int2(int s)
86
+ {
87
+ return make_int2(s, s);
88
+ }
89
+ inline __host__ __device__ int2 make_int2(int3 a)
90
+ {
91
+ return make_int2(a.x, a.y);
92
+ }
93
+ inline __host__ __device__ int2 make_int2(uint2 a)
94
+ {
95
+ return make_int2(int(a.x), int(a.y));
96
+ }
97
+ inline __host__ __device__ int2 make_int2(float2 a)
98
+ {
99
+ return make_int2(int(a.x), int(a.y));
100
+ }
101
+
102
+ inline __host__ __device__ uint2 make_uint2(uint s)
103
+ {
104
+ return make_uint2(s, s);
105
+ }
106
+ inline __host__ __device__ uint2 make_uint2(uint3 a)
107
+ {
108
+ return make_uint2(a.x, a.y);
109
+ }
110
+ inline __host__ __device__ uint2 make_uint2(int2 a)
111
+ {
112
+ return make_uint2(uint(a.x), uint(a.y));
113
+ }
114
+
115
+ inline __host__ __device__ float3 make_float3(float s)
116
+ {
117
+ return make_float3(s, s, s);
118
+ }
119
+ inline __host__ __device__ float3 make_float3(float2 a)
120
+ {
121
+ return make_float3(a.x, a.y, 0.0f);
122
+ }
123
+ inline __host__ __device__ float3 make_float3(float2 a, float s)
124
+ {
125
+ return make_float3(a.x, a.y, s);
126
+ }
127
+ inline __host__ __device__ float3 make_float3(float4 a)
128
+ {
129
+ return make_float3(a.x, a.y, a.z);
130
+ }
131
+ inline __host__ __device__ float3 make_float3(int3 a)
132
+ {
133
+ return make_float3(float(a.x), float(a.y), float(a.z));
134
+ }
135
+ inline __host__ __device__ float3 make_float3(uint3 a)
136
+ {
137
+ return make_float3(float(a.x), float(a.y), float(a.z));
138
+ }
139
+
140
+ inline __host__ __device__ int3 make_int3(int s)
141
+ {
142
+ return make_int3(s, s, s);
143
+ }
144
+ inline __host__ __device__ int3 make_int3(int2 a)
145
+ {
146
+ return make_int3(a.x, a.y, 0);
147
+ }
148
+ inline __host__ __device__ int3 make_int3(int2 a, int s)
149
+ {
150
+ return make_int3(a.x, a.y, s);
151
+ }
152
+ inline __host__ __device__ int3 make_int3(uint3 a)
153
+ {
154
+ return make_int3(int(a.x), int(a.y), int(a.z));
155
+ }
156
+ inline __host__ __device__ int3 make_int3(float3 a)
157
+ {
158
+ return make_int3(int(a.x), int(a.y), int(a.z));
159
+ }
160
+
161
+ inline __host__ __device__ uint3 make_uint3(uint s)
162
+ {
163
+ return make_uint3(s, s, s);
164
+ }
165
+ inline __host__ __device__ uint3 make_uint3(uint2 a)
166
+ {
167
+ return make_uint3(a.x, a.y, 0);
168
+ }
169
+ inline __host__ __device__ uint3 make_uint3(uint2 a, uint s)
170
+ {
171
+ return make_uint3(a.x, a.y, s);
172
+ }
173
+ inline __host__ __device__ uint3 make_uint3(uint4 a)
174
+ {
175
+ return make_uint3(a.x, a.y, a.z);
176
+ }
177
+ inline __host__ __device__ uint3 make_uint3(int3 a)
178
+ {
179
+ return make_uint3(uint(a.x), uint(a.y), uint(a.z));
180
+ }
181
+
182
+ inline __host__ __device__ float4 make_float4(float s)
183
+ {
184
+ return make_float4(s, s, s, s);
185
+ }
186
+ inline __host__ __device__ float4 make_float4(float3 a)
187
+ {
188
+ return make_float4(a.x, a.y, a.z, 0.0f);
189
+ }
190
+ inline __host__ __device__ float4 make_float4(float3 a, float w)
191
+ {
192
+ return make_float4(a.x, a.y, a.z, w);
193
+ }
194
+ inline __host__ __device__ float4 make_float4(int4 a)
195
+ {
196
+ return make_float4(float(a.x), float(a.y), float(a.z), float(a.w));
197
+ }
198
+ inline __host__ __device__ float4 make_float4(uint4 a)
199
+ {
200
+ return make_float4(float(a.x), float(a.y), float(a.z), float(a.w));
201
+ }
202
+
203
+ inline __host__ __device__ int4 make_int4(int s)
204
+ {
205
+ return make_int4(s, s, s, s);
206
+ }
207
+ inline __host__ __device__ int4 make_int4(int3 a)
208
+ {
209
+ return make_int4(a.x, a.y, a.z, 0);
210
+ }
211
+ inline __host__ __device__ int4 make_int4(int3 a, int w)
212
+ {
213
+ return make_int4(a.x, a.y, a.z, w);
214
+ }
215
+ inline __host__ __device__ int4 make_int4(uint4 a)
216
+ {
217
+ return make_int4(int(a.x), int(a.y), int(a.z), int(a.w));
218
+ }
219
+ inline __host__ __device__ int4 make_int4(float4 a)
220
+ {
221
+ return make_int4(int(a.x), int(a.y), int(a.z), int(a.w));
222
+ }
223
+
224
+
225
+ inline __host__ __device__ uint4 make_uint4(uint s)
226
+ {
227
+ return make_uint4(s, s, s, s);
228
+ }
229
+ inline __host__ __device__ uint4 make_uint4(uint3 a)
230
+ {
231
+ return make_uint4(a.x, a.y, a.z, 0);
232
+ }
233
+ inline __host__ __device__ uint4 make_uint4(uint3 a, uint w)
234
+ {
235
+ return make_uint4(a.x, a.y, a.z, w);
236
+ }
237
+ inline __host__ __device__ uint4 make_uint4(int4 a)
238
+ {
239
+ return make_uint4(uint(a.x), uint(a.y), uint(a.z), uint(a.w));
240
+ }
241
+
242
+ ////////////////////////////////////////////////////////////////////////////////
243
+ // negate
244
+ ////////////////////////////////////////////////////////////////////////////////
245
+
246
+ inline __host__ __device__ float2 operator-(float2 &a)
247
+ {
248
+ return make_float2(-a.x, -a.y);
249
+ }
250
+ inline __host__ __device__ int2 operator-(int2 &a)
251
+ {
252
+ return make_int2(-a.x, -a.y);
253
+ }
254
+ inline __host__ __device__ float3 operator-(float3 &a)
255
+ {
256
+ return make_float3(-a.x, -a.y, -a.z);
257
+ }
258
+ inline __host__ __device__ int3 operator-(int3 &a)
259
+ {
260
+ return make_int3(-a.x, -a.y, -a.z);
261
+ }
262
+ inline __host__ __device__ float4 operator-(float4 &a)
263
+ {
264
+ return make_float4(-a.x, -a.y, -a.z, -a.w);
265
+ }
266
+ inline __host__ __device__ int4 operator-(int4 &a)
267
+ {
268
+ return make_int4(-a.x, -a.y, -a.z, -a.w);
269
+ }
270
+
271
+ ////////////////////////////////////////////////////////////////////////////////
272
+ // addition
273
+ ////////////////////////////////////////////////////////////////////////////////
274
+
275
+ inline __host__ __device__ float2 operator+(float2 a, float2 b)
276
+ {
277
+ return make_float2(a.x + b.x, a.y + b.y);
278
+ }
279
+ inline __host__ __device__ void operator+=(float2 &a, float2 b)
280
+ {
281
+ a.x += b.x;
282
+ a.y += b.y;
283
+ }
284
+ inline __host__ __device__ float2 operator+(float2 a, float b)
285
+ {
286
+ return make_float2(a.x + b, a.y + b);
287
+ }
288
+ inline __host__ __device__ float2 operator+(float b, float2 a)
289
+ {
290
+ return make_float2(a.x + b, a.y + b);
291
+ }
292
+ inline __host__ __device__ void operator+=(float2 &a, float b)
293
+ {
294
+ a.x += b;
295
+ a.y += b;
296
+ }
297
+
298
+ inline __host__ __device__ int2 operator+(int2 a, int2 b)
299
+ {
300
+ return make_int2(a.x + b.x, a.y + b.y);
301
+ }
302
+ inline __host__ __device__ void operator+=(int2 &a, int2 b)
303
+ {
304
+ a.x += b.x;
305
+ a.y += b.y;
306
+ }
307
+ inline __host__ __device__ int2 operator+(int2 a, int b)
308
+ {
309
+ return make_int2(a.x + b, a.y + b);
310
+ }
311
+ inline __host__ __device__ int2 operator+(int b, int2 a)
312
+ {
313
+ return make_int2(a.x + b, a.y + b);
314
+ }
315
+ inline __host__ __device__ void operator+=(int2 &a, int b)
316
+ {
317
+ a.x += b;
318
+ a.y += b;
319
+ }
320
+
321
+ inline __host__ __device__ uint2 operator+(uint2 a, uint2 b)
322
+ {
323
+ return make_uint2(a.x + b.x, a.y + b.y);
324
+ }
325
+ inline __host__ __device__ void operator+=(uint2 &a, uint2 b)
326
+ {
327
+ a.x += b.x;
328
+ a.y += b.y;
329
+ }
330
+ inline __host__ __device__ uint2 operator+(uint2 a, uint b)
331
+ {
332
+ return make_uint2(a.x + b, a.y + b);
333
+ }
334
+ inline __host__ __device__ uint2 operator+(uint b, uint2 a)
335
+ {
336
+ return make_uint2(a.x + b, a.y + b);
337
+ }
338
+ inline __host__ __device__ void operator+=(uint2 &a, uint b)
339
+ {
340
+ a.x += b;
341
+ a.y += b;
342
+ }
343
+
344
+
345
+ inline __host__ __device__ float3 operator+(float3 a, float3 b)
346
+ {
347
+ return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
348
+ }
349
+ inline __host__ __device__ void operator+=(float3 &a, float3 b)
350
+ {
351
+ a.x += b.x;
352
+ a.y += b.y;
353
+ a.z += b.z;
354
+ }
355
+ inline __host__ __device__ float3 operator+(float3 a, float b)
356
+ {
357
+ return make_float3(a.x + b, a.y + b, a.z + b);
358
+ }
359
+ inline __host__ __device__ void operator+=(float3 &a, float b)
360
+ {
361
+ a.x += b;
362
+ a.y += b;
363
+ a.z += b;
364
+ }
365
+
366
+ inline __host__ __device__ int3 operator+(int3 a, int3 b)
367
+ {
368
+ return make_int3(a.x + b.x, a.y + b.y, a.z + b.z);
369
+ }
370
+ inline __host__ __device__ void operator+=(int3 &a, int3 b)
371
+ {
372
+ a.x += b.x;
373
+ a.y += b.y;
374
+ a.z += b.z;
375
+ }
376
+ inline __host__ __device__ int3 operator+(int3 a, int b)
377
+ {
378
+ return make_int3(a.x + b, a.y + b, a.z + b);
379
+ }
380
+ inline __host__ __device__ void operator+=(int3 &a, int b)
381
+ {
382
+ a.x += b;
383
+ a.y += b;
384
+ a.z += b;
385
+ }
386
+
387
+ inline __host__ __device__ uint3 operator+(uint3 a, uint3 b)
388
+ {
389
+ return make_uint3(a.x + b.x, a.y + b.y, a.z + b.z);
390
+ }
391
+ inline __host__ __device__ void operator+=(uint3 &a, uint3 b)
392
+ {
393
+ a.x += b.x;
394
+ a.y += b.y;
395
+ a.z += b.z;
396
+ }
397
+ inline __host__ __device__ uint3 operator+(uint3 a, uint b)
398
+ {
399
+ return make_uint3(a.x + b, a.y + b, a.z + b);
400
+ }
401
+ inline __host__ __device__ void operator+=(uint3 &a, uint b)
402
+ {
403
+ a.x += b;
404
+ a.y += b;
405
+ a.z += b;
406
+ }
407
+
408
+ inline __host__ __device__ int3 operator+(int b, int3 a)
409
+ {
410
+ return make_int3(a.x + b, a.y + b, a.z + b);
411
+ }
412
+ inline __host__ __device__ uint3 operator+(uint b, uint3 a)
413
+ {
414
+ return make_uint3(a.x + b, a.y + b, a.z + b);
415
+ }
416
+ inline __host__ __device__ float3 operator+(float b, float3 a)
417
+ {
418
+ return make_float3(a.x + b, a.y + b, a.z + b);
419
+ }
420
+
421
+ inline __host__ __device__ float4 operator+(float4 a, float4 b)
422
+ {
423
+ return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
424
+ }
425
+ inline __host__ __device__ void operator+=(float4 &a, float4 b)
426
+ {
427
+ a.x += b.x;
428
+ a.y += b.y;
429
+ a.z += b.z;
430
+ a.w += b.w;
431
+ }
432
+ inline __host__ __device__ float4 operator+(float4 a, float b)
433
+ {
434
+ return make_float4(a.x + b, a.y + b, a.z + b, a.w + b);
435
+ }
436
+ inline __host__ __device__ float4 operator+(float b, float4 a)
437
+ {
438
+ return make_float4(a.x + b, a.y + b, a.z + b, a.w + b);
439
+ }
440
+ inline __host__ __device__ void operator+=(float4 &a, float b)
441
+ {
442
+ a.x += b;
443
+ a.y += b;
444
+ a.z += b;
445
+ a.w += b;
446
+ }
447
+
448
+ inline __host__ __device__ int4 operator+(int4 a, int4 b)
449
+ {
450
+ return make_int4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
451
+ }
452
+ inline __host__ __device__ void operator+=(int4 &a, int4 b)
453
+ {
454
+ a.x += b.x;
455
+ a.y += b.y;
456
+ a.z += b.z;
457
+ a.w += b.w;
458
+ }
459
+ inline __host__ __device__ int4 operator+(int4 a, int b)
460
+ {
461
+ return make_int4(a.x + b, a.y + b, a.z + b, a.w + b);
462
+ }
463
+ inline __host__ __device__ int4 operator+(int b, int4 a)
464
+ {
465
+ return make_int4(a.x + b, a.y + b, a.z + b, a.w + b);
466
+ }
467
+ inline __host__ __device__ void operator+=(int4 &a, int b)
468
+ {
469
+ a.x += b;
470
+ a.y += b;
471
+ a.z += b;
472
+ a.w += b;
473
+ }
474
+
475
+ inline __host__ __device__ uint4 operator+(uint4 a, uint4 b)
476
+ {
477
+ return make_uint4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
478
+ }
479
+ inline __host__ __device__ void operator+=(uint4 &a, uint4 b)
480
+ {
481
+ a.x += b.x;
482
+ a.y += b.y;
483
+ a.z += b.z;
484
+ a.w += b.w;
485
+ }
486
+ inline __host__ __device__ uint4 operator+(uint4 a, uint b)
487
+ {
488
+ return make_uint4(a.x + b, a.y + b, a.z + b, a.w + b);
489
+ }
490
+ inline __host__ __device__ uint4 operator+(uint b, uint4 a)
491
+ {
492
+ return make_uint4(a.x + b, a.y + b, a.z + b, a.w + b);
493
+ }
494
+ inline __host__ __device__ void operator+=(uint4 &a, uint b)
495
+ {
496
+ a.x += b;
497
+ a.y += b;
498
+ a.z += b;
499
+ a.w += b;
500
+ }
501
+
502
+ ////////////////////////////////////////////////////////////////////////////////
503
+ // subtract
504
+ ////////////////////////////////////////////////////////////////////////////////
505
+
506
+ inline __host__ __device__ float2 operator-(float2 a, float2 b)
507
+ {
508
+ return make_float2(a.x - b.x, a.y - b.y);
509
+ }
510
+ inline __host__ __device__ void operator-=(float2 &a, float2 b)
511
+ {
512
+ a.x -= b.x;
513
+ a.y -= b.y;
514
+ }
515
+ inline __host__ __device__ float2 operator-(float2 a, float b)
516
+ {
517
+ return make_float2(a.x - b, a.y - b);
518
+ }
519
+ inline __host__ __device__ float2 operator-(float b, float2 a)
520
+ {
521
+ return make_float2(b - a.x, b - a.y);
522
+ }
523
+ inline __host__ __device__ void operator-=(float2 &a, float b)
524
+ {
525
+ a.x -= b;
526
+ a.y -= b;
527
+ }
528
+
529
+ inline __host__ __device__ int2 operator-(int2 a, int2 b)
530
+ {
531
+ return make_int2(a.x - b.x, a.y - b.y);
532
+ }
533
+ inline __host__ __device__ void operator-=(int2 &a, int2 b)
534
+ {
535
+ a.x -= b.x;
536
+ a.y -= b.y;
537
+ }
538
+ inline __host__ __device__ int2 operator-(int2 a, int b)
539
+ {
540
+ return make_int2(a.x - b, a.y - b);
541
+ }
542
+ inline __host__ __device__ int2 operator-(int b, int2 a)
543
+ {
544
+ return make_int2(b - a.x, b - a.y);
545
+ }
546
+ inline __host__ __device__ void operator-=(int2 &a, int b)
547
+ {
548
+ a.x -= b;
549
+ a.y -= b;
550
+ }
551
+
552
+ inline __host__ __device__ uint2 operator-(uint2 a, uint2 b)
553
+ {
554
+ return make_uint2(a.x - b.x, a.y - b.y);
555
+ }
556
+ inline __host__ __device__ void operator-=(uint2 &a, uint2 b)
557
+ {
558
+ a.x -= b.x;
559
+ a.y -= b.y;
560
+ }
561
+ inline __host__ __device__ uint2 operator-(uint2 a, uint b)
562
+ {
563
+ return make_uint2(a.x - b, a.y - b);
564
+ }
565
+ inline __host__ __device__ uint2 operator-(uint b, uint2 a)
566
+ {
567
+ return make_uint2(b - a.x, b - a.y);
568
+ }
569
+ inline __host__ __device__ void operator-=(uint2 &a, uint b)
570
+ {
571
+ a.x -= b;
572
+ a.y -= b;
573
+ }
574
+
575
+ inline __host__ __device__ float3 operator-(float3 a, float3 b)
576
+ {
577
+ return make_float3(a.x - b.x, a.y - b.y, a.z - b.z);
578
+ }
579
+ inline __host__ __device__ void operator-=(float3 &a, float3 b)
580
+ {
581
+ a.x -= b.x;
582
+ a.y -= b.y;
583
+ a.z -= b.z;
584
+ }
585
+ inline __host__ __device__ float3 operator-(float3 a, float b)
586
+ {
587
+ return make_float3(a.x - b, a.y - b, a.z - b);
588
+ }
589
+ inline __host__ __device__ float3 operator-(float b, float3 a)
590
+ {
591
+ return make_float3(b - a.x, b - a.y, b - a.z);
592
+ }
593
+ inline __host__ __device__ void operator-=(float3 &a, float b)
594
+ {
595
+ a.x -= b;
596
+ a.y -= b;
597
+ a.z -= b;
598
+ }
599
+
600
+ inline __host__ __device__ int3 operator-(int3 a, int3 b)
601
+ {
602
+ return make_int3(a.x - b.x, a.y - b.y, a.z - b.z);
603
+ }
604
+ inline __host__ __device__ void operator-=(int3 &a, int3 b)
605
+ {
606
+ a.x -= b.x;
607
+ a.y -= b.y;
608
+ a.z -= b.z;
609
+ }
610
+ inline __host__ __device__ int3 operator-(int3 a, int b)
611
+ {
612
+ return make_int3(a.x - b, a.y - b, a.z - b);
613
+ }
614
+ inline __host__ __device__ int3 operator-(int b, int3 a)
615
+ {
616
+ return make_int3(b - a.x, b - a.y, b - a.z);
617
+ }
618
+ inline __host__ __device__ void operator-=(int3 &a, int b)
619
+ {
620
+ a.x -= b;
621
+ a.y -= b;
622
+ a.z -= b;
623
+ }
624
+
625
+ inline __host__ __device__ uint3 operator-(uint3 a, uint3 b)
626
+ {
627
+ return make_uint3(a.x - b.x, a.y - b.y, a.z - b.z);
628
+ }
629
+ inline __host__ __device__ void operator-=(uint3 &a, uint3 b)
630
+ {
631
+ a.x -= b.x;
632
+ a.y -= b.y;
633
+ a.z -= b.z;
634
+ }
635
+ inline __host__ __device__ uint3 operator-(uint3 a, uint b)
636
+ {
637
+ return make_uint3(a.x - b, a.y - b, a.z - b);
638
+ }
639
+ inline __host__ __device__ uint3 operator-(uint b, uint3 a)
640
+ {
641
+ return make_uint3(b - a.x, b - a.y, b - a.z);
642
+ }
643
+ inline __host__ __device__ void operator-=(uint3 &a, uint b)
644
+ {
645
+ a.x -= b;
646
+ a.y -= b;
647
+ a.z -= b;
648
+ }
649
+
650
+ inline __host__ __device__ float4 operator-(float4 a, float4 b)
651
+ {
652
+ return make_float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
653
+ }
654
+ inline __host__ __device__ void operator-=(float4 &a, float4 b)
655
+ {
656
+ a.x -= b.x;
657
+ a.y -= b.y;
658
+ a.z -= b.z;
659
+ a.w -= b.w;
660
+ }
661
+ inline __host__ __device__ float4 operator-(float4 a, float b)
662
+ {
663
+ return make_float4(a.x - b, a.y - b, a.z - b, a.w - b);
664
+ }
665
+ inline __host__ __device__ void operator-=(float4 &a, float b)
666
+ {
667
+ a.x -= b;
668
+ a.y -= b;
669
+ a.z -= b;
670
+ a.w -= b;
671
+ }
672
+
673
+ inline __host__ __device__ int4 operator-(int4 a, int4 b)
674
+ {
675
+ return make_int4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
676
+ }
677
+ inline __host__ __device__ void operator-=(int4 &a, int4 b)
678
+ {
679
+ a.x -= b.x;
680
+ a.y -= b.y;
681
+ a.z -= b.z;
682
+ a.w -= b.w;
683
+ }
684
+ inline __host__ __device__ int4 operator-(int4 a, int b)
685
+ {
686
+ return make_int4(a.x - b, a.y - b, a.z - b, a.w - b);
687
+ }
688
+ inline __host__ __device__ int4 operator-(int b, int4 a)
689
+ {
690
+ return make_int4(b - a.x, b - a.y, b - a.z, b - a.w);
691
+ }
692
+ inline __host__ __device__ void operator-=(int4 &a, int b)
693
+ {
694
+ a.x -= b;
695
+ a.y -= b;
696
+ a.z -= b;
697
+ a.w -= b;
698
+ }
699
+
700
+ inline __host__ __device__ uint4 operator-(uint4 a, uint4 b)
701
+ {
702
+ return make_uint4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
703
+ }
704
+ inline __host__ __device__ void operator-=(uint4 &a, uint4 b)
705
+ {
706
+ a.x -= b.x;
707
+ a.y -= b.y;
708
+ a.z -= b.z;
709
+ a.w -= b.w;
710
+ }
711
+ inline __host__ __device__ uint4 operator-(uint4 a, uint b)
712
+ {
713
+ return make_uint4(a.x - b, a.y - b, a.z - b, a.w - b);
714
+ }
715
+ inline __host__ __device__ uint4 operator-(uint b, uint4 a)
716
+ {
717
+ return make_uint4(b - a.x, b - a.y, b - a.z, b - a.w);
718
+ }
719
+ inline __host__ __device__ void operator-=(uint4 &a, uint b)
720
+ {
721
+ a.x -= b;
722
+ a.y -= b;
723
+ a.z -= b;
724
+ a.w -= b;
725
+ }
726
+
727
+ ////////////////////////////////////////////////////////////////////////////////
728
+ // multiply
729
+ ////////////////////////////////////////////////////////////////////////////////
730
+
731
+ inline __host__ __device__ float2 operator*(float2 a, float2 b)
732
+ {
733
+ return make_float2(a.x * b.x, a.y * b.y);
734
+ }
735
+ inline __host__ __device__ void operator*=(float2 &a, float2 b)
736
+ {
737
+ a.x *= b.x;
738
+ a.y *= b.y;
739
+ }
740
+ inline __host__ __device__ float2 operator*(float2 a, float b)
741
+ {
742
+ return make_float2(a.x * b, a.y * b);
743
+ }
744
+ inline __host__ __device__ float2 operator*(float b, float2 a)
745
+ {
746
+ return make_float2(b * a.x, b * a.y);
747
+ }
748
+ inline __host__ __device__ void operator*=(float2 &a, float b)
749
+ {
750
+ a.x *= b;
751
+ a.y *= b;
752
+ }
753
+
754
+ inline __host__ __device__ int2 operator*(int2 a, int2 b)
755
+ {
756
+ return make_int2(a.x * b.x, a.y * b.y);
757
+ }
758
+ inline __host__ __device__ void operator*=(int2 &a, int2 b)
759
+ {
760
+ a.x *= b.x;
761
+ a.y *= b.y;
762
+ }
763
+ inline __host__ __device__ int2 operator*(int2 a, int b)
764
+ {
765
+ return make_int2(a.x * b, a.y * b);
766
+ }
767
+ inline __host__ __device__ int2 operator*(int b, int2 a)
768
+ {
769
+ return make_int2(b * a.x, b * a.y);
770
+ }
771
+ inline __host__ __device__ void operator*=(int2 &a, int b)
772
+ {
773
+ a.x *= b;
774
+ a.y *= b;
775
+ }
776
+
777
+ inline __host__ __device__ uint2 operator*(uint2 a, uint2 b)
778
+ {
779
+ return make_uint2(a.x * b.x, a.y * b.y);
780
+ }
781
+ inline __host__ __device__ void operator*=(uint2 &a, uint2 b)
782
+ {
783
+ a.x *= b.x;
784
+ a.y *= b.y;
785
+ }
786
+ inline __host__ __device__ uint2 operator*(uint2 a, uint b)
787
+ {
788
+ return make_uint2(a.x * b, a.y * b);
789
+ }
790
+ inline __host__ __device__ uint2 operator*(uint b, uint2 a)
791
+ {
792
+ return make_uint2(b * a.x, b * a.y);
793
+ }
794
+ inline __host__ __device__ void operator*=(uint2 &a, uint b)
795
+ {
796
+ a.x *= b;
797
+ a.y *= b;
798
+ }
799
+
800
+ inline __host__ __device__ float3 operator*(float3 a, float3 b)
801
+ {
802
+ return make_float3(a.x * b.x, a.y * b.y, a.z * b.z);
803
+ }
804
+ inline __host__ __device__ void operator*=(float3 &a, float3 b)
805
+ {
806
+ a.x *= b.x;
807
+ a.y *= b.y;
808
+ a.z *= b.z;
809
+ }
810
+ inline __host__ __device__ float3 operator*(float3 a, float b)
811
+ {
812
+ return make_float3(a.x * b, a.y * b, a.z * b);
813
+ }
814
+ inline __host__ __device__ float3 operator*(float b, float3 a)
815
+ {
816
+ return make_float3(b * a.x, b * a.y, b * a.z);
817
+ }
818
+ inline __host__ __device__ void operator*=(float3 &a, float b)
819
+ {
820
+ a.x *= b;
821
+ a.y *= b;
822
+ a.z *= b;
823
+ }
824
+
825
+ inline __host__ __device__ int3 operator*(int3 a, int3 b)
826
+ {
827
+ return make_int3(a.x * b.x, a.y * b.y, a.z * b.z);
828
+ }
829
+ inline __host__ __device__ void operator*=(int3 &a, int3 b)
830
+ {
831
+ a.x *= b.x;
832
+ a.y *= b.y;
833
+ a.z *= b.z;
834
+ }
835
+ inline __host__ __device__ int3 operator*(int3 a, int b)
836
+ {
837
+ return make_int3(a.x * b, a.y * b, a.z * b);
838
+ }
839
+ inline __host__ __device__ int3 operator*(int b, int3 a)
840
+ {
841
+ return make_int3(b * a.x, b * a.y, b * a.z);
842
+ }
843
+ inline __host__ __device__ void operator*=(int3 &a, int b)
844
+ {
845
+ a.x *= b;
846
+ a.y *= b;
847
+ a.z *= b;
848
+ }
849
+
850
+ inline __host__ __device__ uint3 operator*(uint3 a, uint3 b)
851
+ {
852
+ return make_uint3(a.x * b.x, a.y * b.y, a.z * b.z);
853
+ }
854
+ inline __host__ __device__ void operator*=(uint3 &a, uint3 b)
855
+ {
856
+ a.x *= b.x;
857
+ a.y *= b.y;
858
+ a.z *= b.z;
859
+ }
860
+ inline __host__ __device__ uint3 operator*(uint3 a, uint b)
861
+ {
862
+ return make_uint3(a.x * b, a.y * b, a.z * b);
863
+ }
864
+ inline __host__ __device__ uint3 operator*(uint b, uint3 a)
865
+ {
866
+ return make_uint3(b * a.x, b * a.y, b * a.z);
867
+ }
868
+ inline __host__ __device__ void operator*=(uint3 &a, uint b)
869
+ {
870
+ a.x *= b;
871
+ a.y *= b;
872
+ a.z *= b;
873
+ }
874
+
875
+ inline __host__ __device__ float4 operator*(float4 a, float4 b)
876
+ {
877
+ return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
878
+ }
879
+ inline __host__ __device__ void operator*=(float4 &a, float4 b)
880
+ {
881
+ a.x *= b.x;
882
+ a.y *= b.y;
883
+ a.z *= b.z;
884
+ a.w *= b.w;
885
+ }
886
+ inline __host__ __device__ float4 operator*(float4 a, float b)
887
+ {
888
+ return make_float4(a.x * b, a.y * b, a.z * b, a.w * b);
889
+ }
890
+ inline __host__ __device__ float4 operator*(float b, float4 a)
891
+ {
892
+ return make_float4(b * a.x, b * a.y, b * a.z, b * a.w);
893
+ }
894
+ inline __host__ __device__ void operator*=(float4 &a, float b)
895
+ {
896
+ a.x *= b;
897
+ a.y *= b;
898
+ a.z *= b;
899
+ a.w *= b;
900
+ }
901
+
902
+ inline __host__ __device__ int4 operator*(int4 a, int4 b)
903
+ {
904
+ return make_int4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
905
+ }
906
+ inline __host__ __device__ void operator*=(int4 &a, int4 b)
907
+ {
908
+ a.x *= b.x;
909
+ a.y *= b.y;
910
+ a.z *= b.z;
911
+ a.w *= b.w;
912
+ }
913
+ inline __host__ __device__ int4 operator*(int4 a, int b)
914
+ {
915
+ return make_int4(a.x * b, a.y * b, a.z * b, a.w * b);
916
+ }
917
+ inline __host__ __device__ int4 operator*(int b, int4 a)
918
+ {
919
+ return make_int4(b * a.x, b * a.y, b * a.z, b * a.w);
920
+ }
921
+ inline __host__ __device__ void operator*=(int4 &a, int b)
922
+ {
923
+ a.x *= b;
924
+ a.y *= b;
925
+ a.z *= b;
926
+ a.w *= b;
927
+ }
928
+
929
+ inline __host__ __device__ uint4 operator*(uint4 a, uint4 b)
930
+ {
931
+ return make_uint4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
932
+ }
933
+ inline __host__ __device__ void operator*=(uint4 &a, uint4 b)
934
+ {
935
+ a.x *= b.x;
936
+ a.y *= b.y;
937
+ a.z *= b.z;
938
+ a.w *= b.w;
939
+ }
940
+ inline __host__ __device__ uint4 operator*(uint4 a, uint b)
941
+ {
942
+ return make_uint4(a.x * b, a.y * b, a.z * b, a.w * b);
943
+ }
944
+ inline __host__ __device__ uint4 operator*(uint b, uint4 a)
945
+ {
946
+ return make_uint4(b * a.x, b * a.y, b * a.z, b * a.w);
947
+ }
948
+ inline __host__ __device__ void operator*=(uint4 &a, uint b)
949
+ {
950
+ a.x *= b;
951
+ a.y *= b;
952
+ a.z *= b;
953
+ a.w *= b;
954
+ }
955
+
956
+ ////////////////////////////////////////////////////////////////////////////////
957
+ // divide
958
+ ////////////////////////////////////////////////////////////////////////////////
959
+
960
+ inline __host__ __device__ float2 operator/(float2 a, float2 b)
961
+ {
962
+ return make_float2(a.x / b.x, a.y / b.y);
963
+ }
964
+ inline __host__ __device__ void operator/=(float2 &a, float2 b)
965
+ {
966
+ a.x /= b.x;
967
+ a.y /= b.y;
968
+ }
969
+ inline __host__ __device__ float2 operator/(float2 a, float b)
970
+ {
971
+ return make_float2(a.x / b, a.y / b);
972
+ }
973
+ inline __host__ __device__ void operator/=(float2 &a, float b)
974
+ {
975
+ a.x /= b;
976
+ a.y /= b;
977
+ }
978
+ inline __host__ __device__ float2 operator/(float b, float2 a)
979
+ {
980
+ return make_float2(b / a.x, b / a.y);
981
+ }
982
+
983
+ inline __host__ __device__ float3 operator/(float3 a, float3 b)
984
+ {
985
+ return make_float3(a.x / b.x, a.y / b.y, a.z / b.z);
986
+ }
987
+ inline __host__ __device__ void operator/=(float3 &a, float3 b)
988
+ {
989
+ a.x /= b.x;
990
+ a.y /= b.y;
991
+ a.z /= b.z;
992
+ }
993
+ inline __host__ __device__ float3 operator/(float3 a, float b)
994
+ {
995
+ return make_float3(a.x / b, a.y / b, a.z / b);
996
+ }
997
+ inline __host__ __device__ void operator/=(float3 &a, float b)
998
+ {
999
+ a.x /= b;
1000
+ a.y /= b;
1001
+ a.z /= b;
1002
+ }
1003
+ inline __host__ __device__ float3 operator/(float b, float3 a)
1004
+ {
1005
+ return make_float3(b / a.x, b / a.y, b / a.z);
1006
+ }
1007
+
1008
+ inline __host__ __device__ float4 operator/(float4 a, float4 b)
1009
+ {
1010
+ return make_float4(a.x / b.x, a.y / b.y, a.z / b.z, a.w / b.w);
1011
+ }
1012
+ inline __host__ __device__ void operator/=(float4 &a, float4 b)
1013
+ {
1014
+ a.x /= b.x;
1015
+ a.y /= b.y;
1016
+ a.z /= b.z;
1017
+ a.w /= b.w;
1018
+ }
1019
+ inline __host__ __device__ float4 operator/(float4 a, float b)
1020
+ {
1021
+ return make_float4(a.x / b, a.y / b, a.z / b, a.w / b);
1022
+ }
1023
+ inline __host__ __device__ void operator/=(float4 &a, float b)
1024
+ {
1025
+ a.x /= b;
1026
+ a.y /= b;
1027
+ a.z /= b;
1028
+ a.w /= b;
1029
+ }
1030
+ inline __host__ __device__ float4 operator/(float b, float4 a)
1031
+ {
1032
+ return make_float4(b / a.x, b / a.y, b / a.z, b / a.w);
1033
+ }
1034
+
1035
+ ////////////////////////////////////////////////////////////////////////////////
1036
+ // min
1037
+ ////////////////////////////////////////////////////////////////////////////////
1038
+
1039
+ inline __host__ __device__ float2 fminf(float2 a, float2 b)
1040
+ {
1041
+ return make_float2(fminf(a.x,b.x), fminf(a.y,b.y));
1042
+ }
1043
+ inline __host__ __device__ float3 fminf(float3 a, float3 b)
1044
+ {
1045
+ return make_float3(fminf(a.x,b.x), fminf(a.y,b.y), fminf(a.z,b.z));
1046
+ }
1047
+ inline __host__ __device__ float4 fminf(float4 a, float4 b)
1048
+ {
1049
+ return make_float4(fminf(a.x,b.x), fminf(a.y,b.y), fminf(a.z,b.z), fminf(a.w,b.w));
1050
+ }
1051
+
1052
+ inline __host__ __device__ int2 min(int2 a, int2 b)
1053
+ {
1054
+ return make_int2(min(a.x,b.x), min(a.y,b.y));
1055
+ }
1056
+ inline __host__ __device__ int3 min(int3 a, int3 b)
1057
+ {
1058
+ return make_int3(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z));
1059
+ }
1060
+ inline __host__ __device__ int4 min(int4 a, int4 b)
1061
+ {
1062
+ return make_int4(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z), min(a.w,b.w));
1063
+ }
1064
+
1065
+ inline __host__ __device__ uint2 min(uint2 a, uint2 b)
1066
+ {
1067
+ return make_uint2(min(a.x,b.x), min(a.y,b.y));
1068
+ }
1069
+ inline __host__ __device__ uint3 min(uint3 a, uint3 b)
1070
+ {
1071
+ return make_uint3(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z));
1072
+ }
1073
+ inline __host__ __device__ uint4 min(uint4 a, uint4 b)
1074
+ {
1075
+ return make_uint4(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z), min(a.w,b.w));
1076
+ }
1077
+
1078
+ ////////////////////////////////////////////////////////////////////////////////
1079
+ // max
1080
+ ////////////////////////////////////////////////////////////////////////////////
1081
+
1082
+ inline __host__ __device__ float2 fmaxf(float2 a, float2 b)
1083
+ {
1084
+ return make_float2(fmaxf(a.x,b.x), fmaxf(a.y,b.y));
1085
+ }
1086
+ inline __host__ __device__ float3 fmaxf(float3 a, float3 b)
1087
+ {
1088
+ return make_float3(fmaxf(a.x,b.x), fmaxf(a.y,b.y), fmaxf(a.z,b.z));
1089
+ }
1090
+ inline __host__ __device__ float4 fmaxf(float4 a, float4 b)
1091
+ {
1092
+ return make_float4(fmaxf(a.x,b.x), fmaxf(a.y,b.y), fmaxf(a.z,b.z), fmaxf(a.w,b.w));
1093
+ }
1094
+
1095
+ inline __host__ __device__ int2 max(int2 a, int2 b)
1096
+ {
1097
+ return make_int2(max(a.x,b.x), max(a.y,b.y));
1098
+ }
1099
+ inline __host__ __device__ int3 max(int3 a, int3 b)
1100
+ {
1101
+ return make_int3(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z));
1102
+ }
1103
+ inline __host__ __device__ int4 max(int4 a, int4 b)
1104
+ {
1105
+ return make_int4(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z), max(a.w,b.w));
1106
+ }
1107
+
1108
+ inline __host__ __device__ uint2 max(uint2 a, uint2 b)
1109
+ {
1110
+ return make_uint2(max(a.x,b.x), max(a.y,b.y));
1111
+ }
1112
+ inline __host__ __device__ uint3 max(uint3 a, uint3 b)
1113
+ {
1114
+ return make_uint3(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z));
1115
+ }
1116
+ inline __host__ __device__ uint4 max(uint4 a, uint4 b)
1117
+ {
1118
+ return make_uint4(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z), max(a.w,b.w));
1119
+ }
1120
+
1121
+ ////////////////////////////////////////////////////////////////////////////////
1122
+ // lerp
1123
+ // - linear interpolation between a and b, based on value t in [0, 1] range
1124
+ ////////////////////////////////////////////////////////////////////////////////
1125
+
1126
+ inline __device__ __host__ float lerp(float a, float b, float t)
1127
+ {
1128
+ return a + t*(b-a);
1129
+ }
1130
+ inline __device__ __host__ float2 lerp(float2 a, float2 b, float t)
1131
+ {
1132
+ return a + t*(b-a);
1133
+ }
1134
+ inline __device__ __host__ float3 lerp(float3 a, float3 b, float t)
1135
+ {
1136
+ return a + t*(b-a);
1137
+ }
1138
+ inline __device__ __host__ float4 lerp(float4 a, float4 b, float t)
1139
+ {
1140
+ return a + t*(b-a);
1141
+ }
1142
+
1143
+ ////////////////////////////////////////////////////////////////////////////////
1144
+ // clamp
1145
+ // - clamp the value v to be in the range [a, b]
1146
+ ////////////////////////////////////////////////////////////////////////////////
1147
+
1148
+ inline __device__ __host__ float clamp(float f, float a, float b)
1149
+ {
1150
+ return fmaxf(a, fminf(f, b));
1151
+ }
1152
+ inline __device__ __host__ int clamp(int f, int a, int b)
1153
+ {
1154
+ return max(a, min(f, b));
1155
+ }
1156
+ inline __device__ __host__ uint clamp(uint f, uint a, uint b)
1157
+ {
1158
+ return max(a, min(f, b));
1159
+ }
1160
+
1161
+ inline __device__ __host__ float2 clamp(float2 v, float a, float b)
1162
+ {
1163
+ return make_float2(clamp(v.x, a, b), clamp(v.y, a, b));
1164
+ }
1165
+ inline __device__ __host__ float2 clamp(float2 v, float2 a, float2 b)
1166
+ {
1167
+ return make_float2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y));
1168
+ }
1169
+ inline __device__ __host__ float3 clamp(float3 v, float a, float b)
1170
+ {
1171
+ return make_float3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));
1172
+ }
1173
+ inline __device__ __host__ float3 clamp(float3 v, float3 a, float3 b)
1174
+ {
1175
+ return make_float3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));
1176
+ }
1177
+ inline __device__ __host__ float4 clamp(float4 v, float a, float b)
1178
+ {
1179
+ return make_float4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b));
1180
+ }
1181
+ inline __device__ __host__ float4 clamp(float4 v, float4 a, float4 b)
1182
+ {
1183
+ return make_float4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w));
1184
+ }
1185
+
1186
+ inline __device__ __host__ int2 clamp(int2 v, int a, int b)
1187
+ {
1188
+ return make_int2(clamp(v.x, a, b), clamp(v.y, a, b));
1189
+ }
1190
+ inline __device__ __host__ int2 clamp(int2 v, int2 a, int2 b)
1191
+ {
1192
+ return make_int2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y));
1193
+ }
1194
+ inline __device__ __host__ int3 clamp(int3 v, int a, int b)
1195
+ {
1196
+ return make_int3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));
1197
+ }
1198
+ inline __device__ __host__ int3 clamp(int3 v, int3 a, int3 b)
1199
+ {
1200
+ return make_int3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));
1201
+ }
1202
+ inline __device__ __host__ int4 clamp(int4 v, int a, int b)
1203
+ {
1204
+ return make_int4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b));
1205
+ }
1206
+ inline __device__ __host__ int4 clamp(int4 v, int4 a, int4 b)
1207
+ {
1208
+ return make_int4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w));
1209
+ }
1210
+
1211
+ inline __device__ __host__ uint2 clamp(uint2 v, uint a, uint b)
1212
+ {
1213
+ return make_uint2(clamp(v.x, a, b), clamp(v.y, a, b));
1214
+ }
1215
+ inline __device__ __host__ uint2 clamp(uint2 v, uint2 a, uint2 b)
1216
+ {
1217
+ return make_uint2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y));
1218
+ }
1219
+ inline __device__ __host__ uint3 clamp(uint3 v, uint a, uint b)
1220
+ {
1221
+ return make_uint3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));
1222
+ }
1223
+ inline __device__ __host__ uint3 clamp(uint3 v, uint3 a, uint3 b)
1224
+ {
1225
+ return make_uint3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));
1226
+ }
1227
+ inline __device__ __host__ uint4 clamp(uint4 v, uint a, uint b)
1228
+ {
1229
+ return make_uint4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b));
1230
+ }
1231
+ inline __device__ __host__ uint4 clamp(uint4 v, uint4 a, uint4 b)
1232
+ {
1233
+ return make_uint4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w));
1234
+ }
1235
+
1236
+ ////////////////////////////////////////////////////////////////////////////////
1237
+ // dot product
1238
+ ////////////////////////////////////////////////////////////////////////////////
1239
+
1240
+ inline __host__ __device__ float dot(float2 a, float2 b)
1241
+ {
1242
+ return a.x * b.x + a.y * b.y;
1243
+ }
1244
+ inline __host__ __device__ float dot(float3 a, float3 b)
1245
+ {
1246
+ return a.x * b.x + a.y * b.y + a.z * b.z;
1247
+ }
1248
+ inline __host__ __device__ float dot(float4 a, float4 b)
1249
+ {
1250
+ return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
1251
+ }
1252
+
1253
+ inline __host__ __device__ int dot(int2 a, int2 b)
1254
+ {
1255
+ return a.x * b.x + a.y * b.y;
1256
+ }
1257
+ inline __host__ __device__ int dot(int3 a, int3 b)
1258
+ {
1259
+ return a.x * b.x + a.y * b.y + a.z * b.z;
1260
+ }
1261
+ inline __host__ __device__ int dot(int4 a, int4 b)
1262
+ {
1263
+ return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
1264
+ }
1265
+
1266
+ inline __host__ __device__ uint dot(uint2 a, uint2 b)
1267
+ {
1268
+ return a.x * b.x + a.y * b.y;
1269
+ }
1270
+ inline __host__ __device__ uint dot(uint3 a, uint3 b)
1271
+ {
1272
+ return a.x * b.x + a.y * b.y + a.z * b.z;
1273
+ }
1274
+ inline __host__ __device__ uint dot(uint4 a, uint4 b)
1275
+ {
1276
+ return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
1277
+ }
1278
+
1279
+ ////////////////////////////////////////////////////////////////////////////////
1280
+ // length
1281
+ ////////////////////////////////////////////////////////////////////////////////
1282
+
1283
+ inline __host__ __device__ float length(float2 v)
1284
+ {
1285
+ return sqrtf(dot(v, v));
1286
+ }
1287
+ inline __host__ __device__ float length(float3 v)
1288
+ {
1289
+ return sqrtf(dot(v, v));
1290
+ }
1291
+ inline __host__ __device__ float length(float4 v)
1292
+ {
1293
+ return sqrtf(dot(v, v));
1294
+ }
1295
+
1296
+ ////////////////////////////////////////////////////////////////////////////////
1297
+ // normalize
1298
+ ////////////////////////////////////////////////////////////////////////////////
1299
+
1300
+ inline __host__ __device__ float2 normalize(float2 v)
1301
+ {
1302
+ float invLen = rsqrtf(dot(v, v));
1303
+ return v * invLen;
1304
+ }
1305
+ inline __host__ __device__ float3 normalize(float3 v)
1306
+ {
1307
+ float invLen = rsqrtf(dot(v, v));
1308
+ return v * invLen;
1309
+ }
1310
+ inline __host__ __device__ float4 normalize(float4 v)
1311
+ {
1312
+ float invLen = rsqrtf(dot(v, v));
1313
+ return v * invLen;
1314
+ }
1315
+
1316
+ ////////////////////////////////////////////////////////////////////////////////
1317
+ // floor
1318
+ ////////////////////////////////////////////////////////////////////////////////
1319
+
1320
+ inline __host__ __device__ float2 floorf(float2 v)
1321
+ {
1322
+ return make_float2(floorf(v.x), floorf(v.y));
1323
+ }
1324
+ inline __host__ __device__ float3 floorf(float3 v)
1325
+ {
1326
+ return make_float3(floorf(v.x), floorf(v.y), floorf(v.z));
1327
+ }
1328
+ inline __host__ __device__ float4 floorf(float4 v)
1329
+ {
1330
+ return make_float4(floorf(v.x), floorf(v.y), floorf(v.z), floorf(v.w));
1331
+ }
1332
+
1333
+ ////////////////////////////////////////////////////////////////////////////////
1334
+ // frac - returns the fractional portion of a scalar or each vector component
1335
+ ////////////////////////////////////////////////////////////////////////////////
1336
+
1337
+ inline __host__ __device__ float fracf(float v)
1338
+ {
1339
+ return v - floorf(v);
1340
+ }
1341
+ inline __host__ __device__ float2 fracf(float2 v)
1342
+ {
1343
+ return make_float2(fracf(v.x), fracf(v.y));
1344
+ }
1345
+ inline __host__ __device__ float3 fracf(float3 v)
1346
+ {
1347
+ return make_float3(fracf(v.x), fracf(v.y), fracf(v.z));
1348
+ }
1349
+ inline __host__ __device__ float4 fracf(float4 v)
1350
+ {
1351
+ return make_float4(fracf(v.x), fracf(v.y), fracf(v.z), fracf(v.w));
1352
+ }
1353
+
1354
+ ////////////////////////////////////////////////////////////////////////////////
1355
+ // fmod
1356
+ ////////////////////////////////////////////////////////////////////////////////
1357
+
1358
+ inline __host__ __device__ float2 fmodf(float2 a, float2 b)
1359
+ {
1360
+ return make_float2(fmodf(a.x, b.x), fmodf(a.y, b.y));
1361
+ }
1362
+ inline __host__ __device__ float3 fmodf(float3 a, float3 b)
1363
+ {
1364
+ return make_float3(fmodf(a.x, b.x), fmodf(a.y, b.y), fmodf(a.z, b.z));
1365
+ }
1366
+ inline __host__ __device__ float4 fmodf(float4 a, float4 b)
1367
+ {
1368
+ return make_float4(fmodf(a.x, b.x), fmodf(a.y, b.y), fmodf(a.z, b.z), fmodf(a.w, b.w));
1369
+ }
1370
+
1371
+ ////////////////////////////////////////////////////////////////////////////////
1372
+ // absolute value
1373
+ ////////////////////////////////////////////////////////////////////////////////
1374
+
1375
+ inline __host__ __device__ float2 fabs(float2 v)
1376
+ {
1377
+ return make_float2(fabs(v.x), fabs(v.y));
1378
+ }
1379
+ inline __host__ __device__ float3 fabs(float3 v)
1380
+ {
1381
+ return make_float3(fabs(v.x), fabs(v.y), fabs(v.z));
1382
+ }
1383
+ inline __host__ __device__ float4 fabs(float4 v)
1384
+ {
1385
+ return make_float4(fabs(v.x), fabs(v.y), fabs(v.z), fabs(v.w));
1386
+ }
1387
+
1388
+ inline __host__ __device__ int2 abs(int2 v)
1389
+ {
1390
+ return make_int2(abs(v.x), abs(v.y));
1391
+ }
1392
+ inline __host__ __device__ int3 abs(int3 v)
1393
+ {
1394
+ return make_int3(abs(v.x), abs(v.y), abs(v.z));
1395
+ }
1396
+ inline __host__ __device__ int4 abs(int4 v)
1397
+ {
1398
+ return make_int4(abs(v.x), abs(v.y), abs(v.z), abs(v.w));
1399
+ }
1400
+
1401
+ ////////////////////////////////////////////////////////////////////////////////
1402
+ // reflect
1403
+ // - returns reflection of incident ray I around surface normal N
1404
+ // - N should be normalized, reflected vector's length is equal to length of I
1405
+ ////////////////////////////////////////////////////////////////////////////////
1406
+
1407
+ inline __host__ __device__ float3 reflect(float3 i, float3 n)
1408
+ {
1409
+ return i - 2.0f * n * dot(n,i);
1410
+ }
1411
+
1412
+ ////////////////////////////////////////////////////////////////////////////////
1413
+ // cross product
1414
+ ////////////////////////////////////////////////////////////////////////////////
1415
+
1416
+ inline __host__ __device__ float3 cross(float3 a, float3 b)
1417
+ {
1418
+ return make_float3(a.y*b.z - a.z*b.y, a.z*b.x - a.x*b.z, a.x*b.y - a.y*b.x);
1419
+ }
1420
+
1421
+ ////////////////////////////////////////////////////////////////////////////////
1422
+ // smoothstep
1423
+ // - returns 0 if x < a
1424
+ // - returns 1 if x > b
1425
+ // - otherwise returns smooth interpolation between 0 and 1 based on x
1426
+ ////////////////////////////////////////////////////////////////////////////////
1427
+
1428
+ inline __device__ __host__ float smoothstep(float a, float b, float x)
1429
+ {
1430
+ float y = clamp((x - a) / (b - a), 0.0f, 1.0f);
1431
+ return (y*y*(3.0f - (2.0f*y)));
1432
+ }
1433
+ inline __device__ __host__ float2 smoothstep(float2 a, float2 b, float2 x)
1434
+ {
1435
+ float2 y = clamp((x - a) / (b - a), 0.0f, 1.0f);
1436
+ return (y*y*(make_float2(3.0f) - (make_float2(2.0f)*y)));
1437
+ }
1438
+ inline __device__ __host__ float3 smoothstep(float3 a, float3 b, float3 x)
1439
+ {
1440
+ float3 y = clamp((x - a) / (b - a), 0.0f, 1.0f);
1441
+ return (y*y*(make_float3(3.0f) - (make_float3(2.0f)*y)));
1442
+ }
1443
+ inline __device__ __host__ float4 smoothstep(float4 a, float4 b, float4 x)
1444
+ {
1445
+ float4 y = clamp((x - a) / (b - a), 0.0f, 1.0f);
1446
+ return (y*y*(make_float4(3.0f) - (make_float4(2.0f)*y)));
1447
+ }
1448
+
1449
+ #endif
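
helper_math.h only defines host/device vector-math overloads, so it is used by being included into CUDA kernel source at compile time. A hedged sketch of how a kernel in this repo might pull it in through CuPy (the kernel body, the include directory, and the nvcc backend choice are assumptions; only the header itself is part of this commit):

    # Hypothetical CuPy compilation of a kernel that includes helper_math.h.
    import cupy as cp

    kernel_src = r'''
    #include "helper_math.h"

    extern "C" __global__ void scale_points(float3 *pts, float s, int n)
    {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) pts[i] = pts[i] * s;   // float3 * float overload from helper_math.h
    }
    '''

    # backend='nvcc' so that the header's #include "cuda_runtime.h" resolves;
    # "-Iutils" points at the directory holding helper_math.h (assumed layout).
    module = cp.RawModule(code=kernel_src, options=("-Iutils",), backend="nvcc")
    scale_points = module.get_function("scale_points")
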
utils/io_utils.py ADDED
@@ -0,0 +1,473 @@
1
+
2
+ import json, os, sys
3
+ import os.path as osp
4
+ from typing import List, Union, Tuple, Dict
5
+ from pathlib import Path
6
+ import cv2
7
+ import numpy as np
8
+ from imageio import imread, imwrite
9
+ import pickle
10
+ import pycocotools.mask as maskUtils
11
+ from einops import rearrange
12
+ from tqdm import tqdm
13
+ from PIL import Image
14
+ import io
15
+ import requests
16
+ import traceback
17
+ import base64
18
+ import time
19
+
20
+
21
+ NP_BOOL_TYPES = (np.bool_, np.bool8)
22
+ NP_FLOAT_TYPES = (np.float_, np.float16, np.float32, np.float64)
23
+ NP_INT_TYPES = (np.int_, np.int8, np.int16, np.int32, np.int64, np.uint, np.uint8, np.uint16, np.uint32, np.uint64)
24
+
25
+ class NumpyEncoder(json.JSONEncoder):
26
+ def default(self, obj):
27
+ if isinstance(obj, np.ndarray):
28
+ return obj.tolist()
29
+ elif isinstance(obj, np.ScalarType):
30
+ if isinstance(obj, NP_BOOL_TYPES):
31
+ return bool(obj)
32
+ elif isinstance(obj, NP_FLOAT_TYPES):
33
+ return float(obj)
34
+ elif isinstance(obj, NP_INT_TYPES):
35
+ return int(obj)
36
+ return json.JSONEncoder.default(self, obj)
37
+
38
+
39
+ def json2dict(json_path: str):
40
+ with open(json_path, 'r', encoding='utf8') as f:
41
+ metadata = json.loads(f.read())
42
+ return metadata
43
+
44
+
45
+ def dict2json(adict: dict, json_path: str):
46
+ with open(json_path, "w", encoding="utf-8") as f:
47
+ f.write(json.dumps(adict, ensure_ascii=False, cls=NumpyEncoder))
48
+
49
+
50
+ def dict2pickle(dumped_path: str, tgt_dict: dict):
51
+ with open(dumped_path, "wb") as f:
52
+ pickle.dump(tgt_dict, f, protocol=pickle.HIGHEST_PROTOCOL)
53
+
54
+
55
+ def pickle2dict(pkl_path: str) -> Dict:
56
+ with open(pkl_path, "rb") as f:
57
+ dumped_data = pickle.load(f)
58
+ return dumped_data
59
+
60
+ def get_all_dirs(root_p: str) -> List[str]:
61
+ alldir = os.listdir(root_p)
62
+ dirlist = []
63
+ for dirp in alldir:
64
+ dirp = osp.join(root_p, dirp)
65
+ if osp.isdir(dirp):
66
+ dirlist.append(dirp)
67
+ return dirlist
68
+
69
+
70
+ def read_filelist(filelistp: str):
71
+ with open(filelistp, 'r', encoding='utf8') as f:
72
+ lines = f.readlines()
73
+ if len(lines) > 0 and lines[-1].strip() == '':
74
+ lines = lines[:-1]
75
+ return lines
76
+
77
+
78
+ VIDEO_EXTS = {'.flv', '.mp4', '.mkv', '.ts', '.mov', '.mpeg'}
79
+ def get_all_videos(video_dir: str, video_exts=VIDEO_EXTS, abs_path=False) -> List[str]:
80
+ filelist = os.listdir(video_dir)
81
+ vlist = []
82
+ for f in filelist:
83
+ if Path(f).suffix in video_exts:
84
+ if abs_path:
85
+ vlist.append(osp.join(video_dir, f))
86
+ else:
87
+ vlist.append(f)
88
+ return vlist
89
+
90
+
91
+ IMG_EXT = {'.bmp', '.jpg', '.png', '.jpeg'}
92
+ def find_all_imgs(img_dir, abs_path=False):
93
+ imglist = []
94
+ dir_list = os.listdir(img_dir)
95
+ for filename in dir_list:
96
+ file_suffix = Path(filename).suffix
97
+ if file_suffix.lower() not in IMG_EXT:
98
+ continue
99
+ if abs_path:
100
+ imglist.append(osp.join(img_dir, filename))
101
+ else:
102
+ imglist.append(filename)
103
+ return imglist
104
+
105
+
106
+ def find_all_files_recursive(tgt_dir: Union[List, str], ext, exclude_dirs={}):
107
+ if isinstance(tgt_dir, str):
108
+ tgt_dir = [tgt_dir]
109
+
110
+ filelst = []
111
+ for d in tgt_dir:
112
+ for root, _, files in os.walk(d):
113
+ if osp.basename(root) in exclude_dirs:
114
+ continue
115
+ for f in files:
116
+ if Path(f).suffix.lower() in ext:
117
+ filelst.append(osp.join(root, f))
118
+
119
+ return filelst
120
+
121
+
122
+ def danbooruid2relpath(id_str: str, file_ext='.jpg'):
123
+ if not isinstance(id_str, str):
124
+ id_str = str(id_str)
125
+ return id_str[-3:].zfill(4) + '/' + id_str + file_ext
126
+
127
+
128
+ def get_template_histvq(template: np.ndarray) -> Tuple[List[np.ndarray]]:
129
+ len_shape = len(template.shape)
130
+ num_c = 3
131
+ mask = None
132
+ if len_shape == 2:
133
+ num_c = 1
134
+ elif len_shape == 3 and template.shape[-1] == 4:
135
+ mask = np.where(template[..., -1])
136
+ template = template[..., :num_c][mask]
137
+
138
+ values, quantiles = [], []
139
+ for ii in range(num_c):
140
+ v, c = np.unique(template[..., ii].ravel(), return_counts=True)
141
+ q = np.cumsum(c).astype(np.float64)
142
+ if len(q) < 1:
143
+ return None, None
144
+ q /= q[-1]
145
+ values.append(v)
146
+ quantiles.append(q)
147
+ return values, quantiles
148
+
149
+
150
+ def inplace_hist_matching(img: np.ndarray, tv: List[np.ndarray], tq: List[np.ndarray]) -> None:
151
+ len_shape = len(img.shape)
152
+ num_c = 3
153
+ mask = None
154
+
155
+ tgtimg = img
156
+ if len_shape == 2:
157
+ num_c = 1
158
+ elif len_shape == 3 and img.shape[-1] == 4:
159
+ mask = np.where(img[..., -1])
160
+ tgtimg = img[..., :num_c][mask]
161
+
162
+ im_h, im_w = img.shape[:2]
163
+ oldtype = img.dtype
164
+ for ii in range(num_c):
165
+ _, bin_idx, s_counts = np.unique(tgtimg[..., ii].ravel(), return_inverse=True,
166
+ return_counts=True)
167
+ s_quantiles = np.cumsum(s_counts).astype(np.float64)
168
+ if len(s_quantiles) == 0:
169
+ return
170
+ s_quantiles /= s_quantiles[-1]
171
+ interp_t_values = np.interp(s_quantiles, tq[ii], tv[ii]).astype(oldtype)
172
+ if mask is not None:
173
+ img[..., ii][mask] = interp_t_values[bin_idx]
174
+ else:
175
+ img[..., ii] = interp_t_values[bin_idx].reshape((im_h, im_w))
176
+ # try:
177
+ # img[..., ii] = interp_t_values[bin_idx].reshape((im_h, im_w))
178
+ # except:
179
+ # LOGGER.error('##################### sth goes wrong')
180
+ # cv2.imshow('img', img)
181
+ # cv2.waitKey(0)
182
+
183
+
184
+ def fgbg_hist_matching(fg_list: List, bg: np.ndarray, min_tq_num=128):
185
+ btv, btq = get_template_histvq(bg)
186
+ ftv, ftq = get_template_histvq(fg_list[0]['image'])
187
+ num_fg = len(fg_list)
188
+ idx_matched = -1
189
+ if num_fg > 1:
190
+ _ftv, _ftq = get_template_histvq(fg_list[-1]['image'])
191
+ if _ftq is not None and ftq is not None:
192
+ if len(_ftq[0]) > len(ftq[0]):
193
+ idx_matched = num_fg - 1
194
+ ftv, ftq = _ftv, _ftq
195
+ else:
196
+ idx_matched = 0
197
+
198
+ if btq is not None and ftq is not None:
199
+ if len(btq[0]) > len(ftq[0]):
200
+ tv, tq = btv, btq
201
+ idx_matched = -1
202
+ else:
203
+ tv, tq = ftv, ftq
204
+ if len(tq[0]) > min_tq_num:
205
+ inplace_hist_matching(bg, tv, tq)
206
+
207
+ if len(tq[0]) > min_tq_num:
208
+ for ii, fg_dict in enumerate(fg_list):
209
+ fg = fg_dict['image']
210
+ if ii != idx_matched and len(tq[0]) > min_tq_num:
211
+ inplace_hist_matching(fg, tv, tq)
212
+
213
+
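
A small sketch of the two histogram helpers above, used the same way `fgbg_hist_matching` applies them to pasted foregrounds; the image paths are placeholders, not part of this commit:

    # Hypothetical use of get_template_histvq / inplace_hist_matching.
    import cv2
    from utils.io_utils import get_template_histvq, inplace_hist_matching

    bg = cv2.imread("background.png")
    fg = cv2.imread("foreground.png")          # modified in place below

    tv, tq = get_template_histvq(bg)           # per-channel values + cumulative quantiles
    if tv is not None:
        inplace_hist_matching(fg, tv, tq)      # remap fg so its histogram follows bg's
    cv2.imwrite("foreground_matched.png", fg)
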
214
+def imread_nogrey_rgb(imp: str) -> np.ndarray:
+    img: np.ndarray = imread(imp)
+    c = 1
+    if len(img.shape) == 3:
+        c = img.shape[-1]
+    if c == 1:
+        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
+    if c == 4:
+        img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
+    return img
+
+
+def square_pad_resize(img: np.ndarray, tgt_size: int, pad_value: Tuple = (114, 114, 114)):
+    h, w = img.shape[:2]
+    pad_h, pad_w = 0, 0
+
+    # make square image
+    if w < h:
+        pad_w = h - w
+        w += pad_w
+    elif h < w:
+        pad_h = w - h
+        h += pad_h
+
+    pad_size = tgt_size - h
+    if pad_size > 0:
+        pad_h += pad_size
+        pad_w += pad_size
+
+    if pad_h > 0 or pad_w > 0:
+        img = cv2.copyMakeBorder(img, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=pad_value)
+
+    down_scale_ratio = tgt_size / img.shape[0]
+    assert down_scale_ratio <= 1
+    if down_scale_ratio < 1:
+        img = cv2.resize(img, (tgt_size, tgt_size), interpolation=cv2.INTER_AREA)
+
+    return img, down_scale_ratio, pad_h, pad_w
+
+
+def scaledown_maxsize(img: np.ndarray, max_size: int, divisior: int = None):
+
+    im_h, im_w = img.shape[:2]
+    ori_h, ori_w = img.shape[:2]
+    resize_ratio = max_size / max(im_h, im_w)
+    if resize_ratio < 1:
+        if im_h > im_w:
+            im_h = max_size
+            im_w = max(1, int(round(im_w * resize_ratio)))
+        else:
+            im_w = max_size
+            im_h = max(1, int(round(im_h * resize_ratio)))
+    if divisior is not None:
+        im_w = int(np.ceil(im_w / divisior) * divisior)
+        im_h = int(np.ceil(im_h / divisior) * divisior)
+
+    if im_w != ori_w or im_h != ori_h:
+        img = cv2.resize(img, (im_w, im_h), interpolation=cv2.INTER_LINEAR)
+
+    return img
+
+
+def resize_pad(img: np.ndarray, tgt_size: int, pad_value: Tuple = (0, 0, 0)):
+    # downscale to tgt_size and pad to square
+    img = scaledown_maxsize(img, tgt_size)
+    padl, padr, padt, padb = 0, 0, 0, 0
+    h, w = img.shape[:2]
+    # padt = (tgt_size - h) // 2
+    # padb = tgt_size - h - padt
+    # padl = (tgt_size - w) // 2
+    # padr = tgt_size - w - padl
+    padb = tgt_size - h
+    padr = tgt_size - w
+
+    if padt + padb + padl + padr > 0:
+        img = cv2.copyMakeBorder(img, padt, padb, padl, padr, cv2.BORDER_CONSTANT, value=pad_value)
+
+    return img, (padt, padb, padl, padr)
+
+
+def resize_pad2divisior(img: np.ndarray, tgt_size: int, divisior: int = 64, pad_value: Tuple = (0, 0, 0)):
+    img = scaledown_maxsize(img, tgt_size)
+    img, (pad_h, pad_w) = pad2divisior(img, divisior, pad_value)
+    return img, (pad_h, pad_w)
+
+
+def img2grey(img: Union[np.ndarray, str], is_rgb: bool = False) -> np.ndarray:
+    if isinstance(img, np.ndarray):
+        if len(img.shape) == 3:
+            if img.shape[-1] != 1:
+                if is_rgb:
+                    img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
+                else:
+                    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+            else:
+                img = img[..., 0]
+        return img
+    elif isinstance(img, str):
+        return cv2.imread(img, cv2.IMREAD_GRAYSCALE)
+    else:
+        raise NotImplementedError
+
+
+def pad2divisior(img: np.ndarray, divisior: int, value=(0, 0, 0)) -> Tuple[np.ndarray, Tuple[int, int]]:
+    im_h, im_w = img.shape[:2]
+    pad_h = int(np.ceil(im_h / divisior)) * divisior - im_h
+    pad_w = int(np.ceil(im_w / divisior)) * divisior - im_w
+    if pad_h != 0 or pad_w != 0:
+        img = cv2.copyMakeBorder(img, 0, pad_h, 0, pad_w, value=value, borderType=cv2.BORDER_CONSTANT)
+    return img, (pad_h, pad_w)
+
+
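`square_pad_resize` letterboxes with bottom/right padding only, which keeps the top-left corner fixed and makes it easy to map detections back to the original image. A small sketch of that round trip; the box coordinates are illustrative:

```python
import numpy as np
from utils.io_utils import square_pad_resize

img = np.zeros((720, 1280, 3), dtype=np.uint8)
padded, ratio, pad_h, pad_w = square_pad_resize(img, tgt_size=640)
# padded is 640x640; ratio is the scale applied after padding (0.5 here).

box_in_padded = (100, 50, 300, 200)                     # hypothetical x1, y1, x2, y2
box_in_original = [v / ratio for v in box_in_padded]    # padding is bottom/right only, so no offset to subtract
```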
+def mask2rle(mask: np.ndarray, decode_for_json: bool = True) -> Dict:
+    mask_rle = maskUtils.encode(np.array(
+        mask[..., np.newaxis] > 0, order='F',
+        dtype='uint8'))[0]
+    if decode_for_json:
+        mask_rle['counts'] = mask_rle['counts'].decode()
+    return mask_rle
+
+
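`mask2rle` wraps the pycocotools RLE encoder so the result can go straight into JSON. A round-trip sketch; note the counts must be turned back into bytes before decoding:

```python
import numpy as np
import pycocotools.mask as maskUtils
from utils.io_utils import mask2rle

mask = np.zeros((64, 64), dtype=np.uint8)
mask[16:48, 16:48] = 1

rle = mask2rle(mask, decode_for_json=True)   # dict with str counts, JSON-serializable
rle['counts'] = rle['counts'].encode()       # pycocotools expects bytes when decoding
decoded = maskUtils.decode(rle)              # back to an HxW binary mask
assert (decoded > 0).sum() == (mask > 0).sum()
```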
+def bbox2xyxy(box) -> Tuple[int]:
+    x1, y1 = box[0], box[1]
+    return x1, y1, x1 + box[2], y1 + box[3]
+
+
+def bbox_overlap_area(abox, boxb) -> int:
+    ax1, ay1, ax2, ay2 = bbox2xyxy(abox)
+    bx1, by1, bx2, by2 = bbox2xyxy(boxb)
+
+    ix = min(ax2, bx2) - max(ax1, bx1)
+    iy = min(ay2, by2) - max(ay1, by1)
+
+    if ix > 0 and iy > 0:
+        return ix * iy
+    else:
+        return 0
+
+
+def bbox_overlap_xy(abox, boxb) -> Tuple[int]:
+    ax1, ay1, ax2, ay2 = bbox2xyxy(abox)
+    bx1, by1, bx2, by2 = bbox2xyxy(boxb)
+
+    ix = min(ax2, bx2) - max(ax1, bx1)
+    iy = min(ay2, by2) - max(ay1, by1)
+
+    return ix, iy
+
+
+def xyxy_overlap_area(axyxy, bxyxy) -> int:
+    ax1, ay1, ax2, ay2 = axyxy
+    bx1, by1, bx2, by2 = bxyxy
+
+    ix = min(ax2, bx2) - max(ax1, bx1)
+    iy = min(ay2, by2) - max(ay1, by1)
+
+    if ix > 0 and iy > 0:
+        return ix * iy
+    else:
+        return 0
+
+
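The overlap helpers take COCO-style `[x, y, w, h]` boxes (`bbox_overlap_area`, `bbox_overlap_xy`) or `[x1, y1, x2, y2]` boxes (`xyxy_overlap_area`) and return raw intersection areas, so IoU has to be derived by the caller. A small worked example:

```python
from utils.io_utils import bbox2xyxy, bbox_overlap_area, xyxy_overlap_area

a = [0, 0, 10, 10]   # x, y, w, h
b = [5, 5, 10, 10]

inter = bbox_overlap_area(a, b)              # 5 * 5 = 25
union = a[2] * a[3] + b[2] * b[3] - inter    # 100 + 100 - 25 = 175
iou = inter / union                          # ~0.143

assert inter == xyxy_overlap_area(bbox2xyxy(a), bbox2xyxy(b))
```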
+DIRNAME2TAG = {'rezero': 're:zero'}
+def dirname2charactername(dirname, start=6):
+    cname = dirname[start:]
+    for k, v in DIRNAME2TAG.items():
+        cname = cname.replace(k, v)
+    return cname
+
+
+def imglist2grid(imglist: List[np.ndarray], grid_size: int = 384, col=None) -> np.ndarray:
+    sqimlist = []
+    for img in imglist:
+        sqimlist.append(square_pad_resize(img, grid_size)[0])
+
+    nimg = len(imglist)
+    if nimg == 0:
+        return None
+    padn = 0
+    if col is None:
+        if nimg > 5:
+            row = int(np.round(np.sqrt(nimg)))
+            col = int(np.ceil(nimg / row))
+        else:
+            col = nimg
+
+    padn = int(np.ceil(nimg / col) * col) - nimg
+    if padn != 0:
+        padimg = np.zeros_like(sqimlist[0])
+        for _ in range(padn):
+            sqimlist.append(padimg)
+
+    return rearrange(sqimlist, '(row col) h w c -> (row h) (col w) c', col=col)
+
+def write_jsonlines(filep: str, dict_lst: List[Dict], progress_bar: bool = True):
+    with open(filep, 'w') as out:
+        if progress_bar:
+            lst = tqdm(dict_lst)
+        else:
+            lst = dict_lst
+        for ddict in lst:
+            jout = json.dumps(ddict) + '\n'
+            out.write(jout)
+
+def read_jsonlines(filep: str):
+    with open(filep, 'r', encoding='utf8') as f:
+        result = [json.loads(jline) for jline in f.read().splitlines()]
+    return result
+
+
+def _b64encode(x: bytes) -> str:
+    return base64.b64encode(x).decode("utf-8")
+
+
+def img2b64(img):
+    """
+    Convert a PIL image to a base64-encoded string.
+    """
+    if isinstance(img, np.ndarray):
+        img = Image.fromarray(img)
+    buffered = io.BytesIO()
+    img.save(buffered, format='PNG')
+    return _b64encode(buffered.getvalue())
+
+
+def save_encoded_image(b64_image: str, output_path: str):
+    with open(output_path, "wb") as image_file:
+        image_file.write(base64.b64decode(b64_image))
+
+def submit_request(url, data, exist_on_exception=True, auth=None, wait_time=30):
+    # Retries the POST indefinitely while wait_time > 0; with wait_time <= 0 a
+    # failure is re-raised and, if exist_on_exception is True, the process exits.
+    response = None
+    try:
+        while True:
+            try:
+                response = requests.post(url, data=data, auth=auth)
+                response.raise_for_status()
+                break
+            except Exception as e:
+                if wait_time > 0:
+                    print(traceback.format_exc(), file=sys.stderr)
+                    print(f'sleep {wait_time} sec...')
+                    time.sleep(wait_time)
+                    continue
+                else:
+                    raise e
+    except Exception as e:
+        print(traceback.format_exc(), file=sys.stderr)
+        if response is not None:
+            print('response content: ' + response.text)
+        if exist_on_exception:
+            exit()
+    return response
+
+
+# def resize_image(input_image, resolution):
+#     H, W = input_image.shape[:2]
+#     k = float(min(resolution)) / min(H, W)
+#     img = cv2.resize(input_image, resolution, interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
+#     return img
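The base64 helpers and `submit_request` are meant to be composed when talking to an HTTP image service; the endpoint URL and response schema below are placeholders, not something this repo defines:

```python
import json
import numpy as np
from utils.io_utils import img2b64, submit_request, save_encoded_image

img = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)
payload = json.dumps({'image': img2b64(img)})   # hypothetical request schema

# With wait_time > 0 the call retries forever; exist_on_exception=True makes a
# final failure call exit(), so long-running jobs may prefer these settings.
response = submit_request('http://localhost:8000/hypothetical-endpoint', payload,
                          exist_on_exception=False, wait_time=0)
if response is not None and response.ok:
    b64_out = response.json().get('image')      # hypothetical response schema
    if b64_out:
        save_encoded_image(b64_out, 'out.png')
```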
utils/logger.py ADDED
@@ -0,0 +1,20 @@
+import logging
+import os.path as osp
+from termcolor import colored
+
+def set_logging(name=None, verbose=True):
+    for handler in logging.root.handlers[:]:
+        logging.root.removeHandler(handler)
+    # Sets level and returns logger
+    # rank = int(os.getenv('RANK', -1))  # rank in world for Multi-GPU trainings
+    fmt = (
+        # colored("[%(name)s]", "magenta", attrs=["bold"])
+        colored("[%(asctime)s]", "blue")
+        + colored("%(levelname)s:", "green")
+        + colored("%(message)s", "white")
+    )
+    logging.basicConfig(format=fmt, level=logging.INFO if verbose else logging.WARNING)
+    return logging.getLogger(name)
+
+LOGGER = set_logging(__name__)  # define globally (used in train.py, val.py, detect.py, etc.)
+
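The module configures one colorized root logger at import time, so callers normally just import `LOGGER`; `set_logging` can be called again to change verbosity. A minimal usage sketch:

```python
from utils.logger import LOGGER, set_logging

LOGGER.info('loading detector weights...')
LOGGER.warning('CUDA not available, falling back to CPU')

quiet = set_logging('eval', verbose=False)   # WARNING level and above only
quiet.warning('this still prints')
```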
utils/mmdet_custom_hooks.py ADDED
@@ -0,0 +1,223 @@
+from mmengine.fileio import FileClient
+from mmengine.dist import master_only
+from einops import rearrange
+import torch
+import mmcv
+import numpy as np
+import os.path as osp
+import cv2
+from typing import Optional, Sequence
+import torch.nn as nn
+from mmdet.apis import inference_detector
+from mmcv.transforms import Compose
+from mmdet.engine import DetVisualizationHook
+from mmdet.registry import HOOKS
+from mmdet.structures import DetDataSample
+
+from utils.io_utils import find_all_imgs, square_pad_resize, imglist2grid
+
+# NOTE: this local definition shadows the inference_detector imported from
+# mmdet.apis above; it takes an already-built test_pipeline instead of reading
+# the pipeline from the model config.
+def inference_detector(
+    model: nn.Module,
+    imgs,
+    test_pipeline
+):
+
+    if isinstance(imgs, (list, tuple)):
+        is_batch = True
+    else:
+        imgs = [imgs]
+        is_batch = False
+
+    if len(imgs) == 0:
+        return []
+
+    test_pipeline = test_pipeline.copy()
+    if isinstance(imgs[0], np.ndarray):
+        # Calling this method across libraries will result
+        # in module unregistered error if not prefixed with mmdet.
+        test_pipeline[0].type = 'mmdet.LoadImageFromNDArray'
+
+    test_pipeline = Compose(test_pipeline)
+
+    result_list = []
+    for img in imgs:
+        # prepare data
+        if isinstance(img, np.ndarray):
+            # TODO: remove img_id.
+            data_ = dict(img=img, img_id=0)
+        else:
+            # TODO: remove img_id.
+            data_ = dict(img_path=img, img_id=0)
+        # build the data pipeline
+        data_ = test_pipeline(data_)
+
+        data_['inputs'] = [data_['inputs']]
+        data_['data_samples'] = [data_['data_samples']]
+
+        # forward the model
+        with torch.no_grad():
+            results = model.test_step(data_)[0]
+
+        result_list.append(results)
+
+    if not is_batch:
+        return result_list[0]
+    else:
+        return result_list
+
+
+@HOOKS.register_module()
+class InstanceSegVisualizationHook(DetVisualizationHook):
+
+    def __init__(self, visualize_samples: str = '',
+                 read_rgb: bool = False,
+                 draw: bool = False,
+                 interval: int = 50,
+                 score_thr: float = 0.3,
+                 show: bool = False,
+                 wait_time: float = 0.,
+                 test_out_dir: Optional[str] = None,
+                 file_client_args: dict = dict(backend='disk')):
+        super().__init__(draw, interval, score_thr, show, wait_time, test_out_dir, file_client_args)
+        self.vis_samples = []
+
+        if osp.exists(visualize_samples):
+            self.channel_order = channel_order = 'rgb' if read_rgb else 'bgr'
+            samples = find_all_imgs(visualize_samples, abs_path=True)
+            for imgp in samples:
+                img = mmcv.imread(imgp, channel_order=channel_order)
+                img, _, _, _ = square_pad_resize(img, 640)
+                self.vis_samples.append(img)
+
+    def before_val(self, runner) -> None:
+        total_curr_iter = runner.iter
+        self._visualize_data(total_curr_iter, runner)
+        return super().before_val(runner)
+
+    # def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
+    #                    outputs: Sequence[DetDataSample]) -> None:
+    #     """Run after every ``self.interval`` validation iterations.
+
+    #     Args:
+    #         runner (:obj:`Runner`): The runner of the validation process.
+    #         batch_idx (int): The index of the current batch in the val loop.
+    #         data_batch (dict): Data from dataloader.
+    #         outputs (Sequence[:obj:`DetDataSample`]]): A batch of data samples
+    #             that contain annotations and predictions.
+    #     """
+    #     # if self.draw is False:
+    #     #     return
+
+    #     if self.file_client is None:
+    #         self.file_client = FileClient(**self.file_client_args)
+
+    #     # There is no guarantee that the same batch of images
+    #     # is visualized for each evaluation.
+    #     total_curr_iter = runner.iter + batch_idx
+
+    #     # # Visualize only the first data
+    #     # img_path = outputs[0].img_path
+    #     # img_bytes = self.file_client.get(img_path)
+    #     # img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
+    #     if total_curr_iter % self.interval == 0 and self.vis_samples:
+    #         self._visualize_data(total_curr_iter, runner)
+
+    @master_only
+    def _visualize_data(self, total_curr_iter, runner):
+
+        tgt_size = 384
+
+        runner.model.eval()
+        outputs = inference_detector(runner.model, self.vis_samples, test_pipeline=runner.cfg.test_pipeline)
+        vis_results = []
+        for img, output in zip(self.vis_samples, outputs):
+            vis_img = self.add_datasample(
+                'val_img',
+                img,
+                data_sample=output,
+                show=self.show,
+                wait_time=self.wait_time,
+                pred_score_thr=self.score_thr,
+                draw_gt=False,
+                step=total_curr_iter)
+            vis_results.append(cv2.resize(vis_img, (tgt_size, tgt_size), interpolation=cv2.INTER_AREA))
+
+        drawn_img = imglist2grid(vis_results, tgt_size)
+        if drawn_img is None:
+            return
+        drawn_img = cv2.cvtColor(drawn_img, cv2.COLOR_BGR2RGB)
+        visualizer = self._visualizer
+        visualizer.set_image(drawn_img)
+        visualizer.add_image('val_img', drawn_img, total_curr_iter)
+
+
+    @master_only
+    def add_datasample(
+            self,
+            name: str,
+            image: np.ndarray,
+            data_sample: Optional['DetDataSample'] = None,
+            draw_gt: bool = True,
+            draw_pred: bool = True,
+            show: bool = False,
+            wait_time: float = 0,
+            # TODO: Supported in mmengine's Viusalizer.
+            out_file: Optional[str] = None,
+            pred_score_thr: float = 0.3,
+            step: int = 0) -> np.ndarray:
+        image = image.clip(0, 255).astype(np.uint8)
+        visualizer = self._visualizer
+        classes = visualizer.dataset_meta.get('classes', None)
+        palette = visualizer.dataset_meta.get('palette', None)
+
+        gt_img_data = None
+        pred_img_data = None
+
+        if data_sample is not None:
+            data_sample = data_sample.cpu()
+
+        if draw_gt and data_sample is not None:
+            gt_img_data = image
+            if 'gt_instances' in data_sample:
+                gt_img_data = visualizer._draw_instances(image,
+                                                         data_sample.gt_instances,
+                                                         classes, palette)
+
+            if 'gt_panoptic_seg' in data_sample:
+                assert classes is not None, 'class information is ' \
+                                            'not provided when ' \
+                                            'visualizing panoptic ' \
+                                            'segmentation results.'
+                gt_img_data = visualizer._draw_panoptic_seg(
+                    gt_img_data, data_sample.gt_panoptic_seg, classes)
+
+        if draw_pred and data_sample is not None:
+            pred_img_data = image
+            if 'pred_instances' in data_sample:
+                pred_instances = data_sample.pred_instances
+                pred_instances = pred_instances[
+                    pred_instances.scores > pred_score_thr]
+                pred_img_data = visualizer._draw_instances(image, pred_instances,
+                                                           classes, palette)
+            if 'pred_panoptic_seg' in data_sample:
+                assert classes is not None, 'class information is ' \
+                                            'not provided when ' \
+                                            'visualizing panoptic ' \
+                                            'segmentation results.'
+                pred_img_data = visualizer._draw_panoptic_seg(
+                    pred_img_data, data_sample.pred_panoptic_seg.numpy(),
+                    classes)
+
+        if gt_img_data is not None and pred_img_data is not None:
+            drawn_img = np.concatenate((gt_img_data, pred_img_data), axis=1)
+        elif gt_img_data is not None:
+            drawn_img = gt_img_data
+        elif pred_img_data is not None:
+            drawn_img = pred_img_data
+        else:
+            # Display the original image directly if nothing is drawn.
+            drawn_img = image
+
+        return drawn_img
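Since the hook is registered with `@HOOKS.register_module()`, it can be enabled from an MMDetection 3.x config through `custom_imports` plus `custom_hooks`. A hedged config sketch; the sample-image directory and interval below are placeholders:

```python
# Fragment of an MMEngine-style config file.
custom_imports = dict(imports=['utils.mmdet_custom_hooks'], allow_failed_imports=False)

custom_hooks = [
    dict(
        type='InstanceSegVisualizationHook',
        visualize_samples='data/vis_samples',   # folder of images rendered before each validation loop
        read_rgb=False,
        score_thr=0.3,
        interval=50,
    )
]
```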