ljsabc commited on
Commit
8aa4f1e
·
1 Parent(s): 395d300
utils/__init__.py ADDED
File without changes
utils/booru_tagger.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gc
3
+ import pandas as pd
4
+ import numpy as np
5
+ from onnxruntime import InferenceSession
6
+ from typing import Tuple, List, Dict
7
+ from io import BytesIO
8
+ from PIL import Image
9
+
10
+ import cv2
11
+ from pathlib import Path
12
+
13
+ from tqdm import tqdm
14
+
15
+ def make_square(img, target_size):
16
+ old_size = img.shape[:2]
17
+ desired_size = max(old_size)
18
+ desired_size = max(desired_size, target_size)
19
+
20
+ delta_w = desired_size - old_size[1]
21
+ delta_h = desired_size - old_size[0]
22
+ top, bottom = delta_h // 2, delta_h - (delta_h // 2)
23
+ left, right = delta_w // 2, delta_w - (delta_w // 2)
24
+
25
+ color = [255, 255, 255]
26
+ new_im = cv2.copyMakeBorder(
27
+ img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color
28
+ )
29
+ return new_im
30
+
31
+
32
+ def smart_resize(img, size):
33
+ # Assumes the image has already gone through make_square
34
+ if img.shape[0] > size:
35
+ img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
36
+ elif img.shape[0] < size:
37
+ img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC)
38
+ return img
39
+
40
+ class Tagger :
41
+ def __init__(self, filename) -> None:
42
+ self.model = InferenceSession(filename, providers=['CUDAExecutionProvider'])
43
+ [root, _] = os.path.split(filename)
44
+ self.tags = pd.read_csv(os.path.join(root, 'selected_tags.csv') if root else 'selected_tags.csv')
45
+
46
+ _, self.height, _, _ = self.model.get_inputs()[0].shape
47
+
48
+ characters = self.tags.loc[self.tags['category'] == 4]
49
+ self.characters = set(characters['name'].values.tolist())
50
+
51
+ def label(self, image: Image) -> Dict[str, float] :
52
+ # alpha to white
53
+ image = image.convert('RGBA')
54
+ new_image = Image.new('RGBA', image.size, 'WHITE')
55
+ new_image.paste(image, mask=image)
56
+ image = new_image.convert('RGB')
57
+ image = np.asarray(image)
58
+
59
+ # PIL RGB to OpenCV BGR
60
+ image = image[:, :, ::-1]
61
+
62
+ image = make_square(image, self.height)
63
+ image = smart_resize(image, self.height)
64
+ image = image.astype(np.float32)
65
+ image = np.expand_dims(image, 0)
66
+
67
+ # evaluate model
68
+ input_name = self.model.get_inputs()[0].name
69
+ label_name = self.model.get_outputs()[0].name
70
+ confidents = self.model.run([label_name], {input_name: image})[0]
71
+
72
+ tags = self.tags[:][['name']]
73
+ tags['confidents'] = confidents[0]
74
+
75
+ # first 4 items are for rating (general, sensitive, questionable, explicit)
76
+ ratings = dict(tags[:4].values)
77
+
78
+ # rest are regular tags
79
+ tags = dict(tags[4:].values)
80
+
81
+ tags = {t: v for t, v in tags.items() if v > 0.5}
82
+ return tags
83
+
84
+ def label_cv2_bgr(self, image: np.ndarray) -> Dict[str, float] :
85
+ # image in BGR u8
86
+ image = make_square(image, self.height)
87
+ image = smart_resize(image, self.height)
88
+ image = image.astype(np.float32)
89
+ image = np.expand_dims(image, 0)
90
+
91
+ # evaluate model
92
+ input_name = self.model.get_inputs()[0].name
93
+ label_name = self.model.get_outputs()[0].name
94
+ confidents = self.model.run([label_name], {input_name: image})[0]
95
+
96
+ tags = self.tags[:][['name']]
97
+ cats = self.tags[:][['category']]
98
+ tags['confidents'] = confidents[0]
99
+
100
+ # first 4 items are for rating (general, sensitive, questionable, explicit)
101
+ ratings = dict(tags[:4].values)
102
+
103
+ # rest are regular tags
104
+ tags = dict(tags[4:].values)
105
+
106
+ tags = [t for t, v in tags.items() if v > 0.5]
107
+ character_str = []
108
+ for t in tags:
109
+ if t in self.characters:
110
+ character_str.append(t)
111
+ return tags, character_str
112
+
113
+
114
+ if __name__ == '__main__':
115
+ modelp = r'models/wd-v1-4-swinv2-tagger-v2/model.onnx'
116
+ tagger = Tagger(modelp)
utils/constants.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ CATEGORIES = [
4
+ {"id": 0, "name": "object", "isthing": 1}
5
+ ]
6
+
7
+ IMAGE_ID_ZFILL = 12
8
+
9
+ COLOR_PALETTE = [
10
+ (220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230), (106, 0, 228),
11
+ (0, 60, 100), (0, 80, 100), (0, 0, 70), (0, 0, 192), (250, 170, 30),
12
+ (100, 170, 30), (220, 220, 0), (175, 116, 175), (250, 0, 30),
13
+ (165, 42, 42), (255, 77, 255), (0, 226, 252), (182, 182, 255),
14
+ (0, 82, 0), (120, 166, 157), (110, 76, 0), (174, 57, 255),
15
+ (199, 100, 0), (72, 0, 118), (255, 179, 240), (0, 125, 92),
16
+ (209, 0, 151), (188, 208, 182), (0, 220, 176), (255, 99, 164),
17
+ (92, 0, 73), (133, 129, 255), (78, 180, 255), (0, 228, 0),
18
+ (174, 255, 243), (45, 89, 255), (134, 134, 103), (145, 148, 174),
19
+ (255, 208, 186), (197, 226, 255), (171, 134, 1), (109, 63, 54),
20
+ (207, 138, 255), (151, 0, 95), (9, 80, 61), (84, 105, 51),
21
+ (74, 65, 105), (166, 196, 102), (208, 195, 210), (255, 109, 65),
22
+ (0, 143, 149), (179, 0, 194), (209, 99, 106), (5, 121, 0),
23
+ (227, 255, 205), (147, 186, 208), (153, 69, 1), (3, 95, 161),
24
+ (163, 255, 0), (119, 0, 170), (0, 182, 199), (0, 165, 120),
25
+ (183, 130, 88), (95, 32, 0), (130, 114, 135), (110, 129, 133),
26
+ (166, 74, 118), (219, 142, 185), (79, 210, 114), (178, 90, 62),
27
+ (65, 70, 15), (127, 167, 115), (59, 105, 106), (142, 108, 45),
28
+ (196, 172, 0), (95, 54, 80), (128, 76, 255), (201, 57, 1),
29
+ (246, 0, 122), (191, 162, 208), (255, 255, 128), (147, 211, 203),
30
+ (150, 100, 100), (168, 171, 172), (146, 112, 198), (210, 170, 100),
31
+ (92, 136, 89), (218, 88, 184), (241, 129, 0), (217, 17, 255),
32
+ (124, 74, 181), (70, 70, 70), (255, 228, 255), (154, 208, 0),
33
+ (193, 0, 92), (76, 91, 113), (255, 180, 195), (106, 154, 176),
34
+ (230, 150, 140), (60, 143, 255), (128, 64, 128), (92, 82, 55),
35
+ (254, 212, 124), (73, 77, 174), (255, 160, 98), (255, 255, 255),
36
+ (104, 84, 109), (169, 164, 131), (225, 199, 255), (137, 54, 74),
37
+ (135, 158, 223), (7, 246, 231), (107, 255, 200), (58, 41, 149),
38
+ (183, 121, 142), (255, 73, 97), (107, 142, 35), (190, 153, 153),
39
+ (146, 139, 141), (70, 130, 180), (134, 199, 156), (209, 226, 140),
40
+ (96, 36, 108), (96, 96, 96), (64, 170, 64), (152, 251, 152),
41
+ (208, 229, 228), (206, 186, 171), (152, 161, 64), (116, 112, 0),
42
+ (0, 114, 143), (102, 102, 156), (250, 141, 255)
43
+ ]
44
+
45
+ class Colors:
46
+ # Ultralytics color palette https://ultralytics.com/
47
+ def __init__(self):
48
+ # hex = matplotlib.colors.TABLEAU_COLORS.values()
49
+ hexs = ('FF1010', '10FF10', 'FFF010', '100FFF', '0018EC', 'FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
50
+ '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
51
+ self.palette = [self.hex2rgb(f'#{c}') for c in hexs]
52
+ self.n = len(self.palette)
53
+
54
+ def __call__(self, i, bgr=True):
55
+ c = self.palette[int(i) % self.n]
56
+ return (c[2], c[1], c[0]) if bgr else c
57
+
58
+ @staticmethod
59
+ def hex2rgb(h): # rgb order (PIL)
60
+ return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
61
+
62
+ colors = Colors()
63
+ def get_color(idx):
64
+ if idx == -1:
65
+ return 255
66
+ else:
67
+ return colors(idx)
68
+
69
+
70
+ MULTIPLE_TAGS = {'2girls', '3girls', '4girls', '5girls', '6+girls', 'multiple_girls',
71
+ '2boys', '3boys', '4boys', '5boys', '6+boys', 'multiple_boys',
72
+ '2others', '3others', '4others', '5others', '6+others', 'multiple_others'}
73
+
74
+ if hasattr(torch, 'cuda'):
75
+ DEFAULT_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
76
+ else:
77
+ DEFAULT_DEVICE = 'cpu'
78
+
79
+ DEFAULT_DETECTOR_CKPT = 'models/AnimeInstanceSegmentation/rtmdetl_e60.ckpt'
80
+ DEFAULT_DEPTHREFINE_CKPT = 'models/AnimeInstanceSegmentation/kenburns_depth_refinenet.ckpt'
81
+ DEFAULT_INPAINTNET_CKPT = 'models/AnimeInstanceSegmentation/kenburns_inpaintnet.ckpt'
82
+ DEPTH_ZOE_CKPT = 'models/AnimeInstanceSegmentation/ZoeD_M12_N.pt'
utils/cupy_utils.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import os
3
+ import cupy
4
+ import os.path as osp
5
+ import torch
6
+
7
+ @cupy.memoize(for_each_device=True)
8
+ def launch_kernel(strFunction, strKernel):
9
+ if 'CUDA_HOME' not in os.environ:
10
+ os.environ['CUDA_HOME'] = cupy.cuda.get_cuda_path()
11
+ # end
12
+ # , options=tuple([ '-I ' + os.environ['CUDA_HOME'], '-I ' + os.environ['CUDA_HOME'] + '/include' ])
13
+ return cupy.RawKernel(strKernel, strFunction)
14
+
15
+
16
+ def preprocess_kernel(strKernel, objVariables):
17
+ path_to_math_helper = osp.join(osp.dirname(osp.abspath(__file__)), 'helper_math.h')
18
+ strKernel = '''
19
+ #include <{{HELPER_PATH}}>
20
+
21
+ __device__ __forceinline__ float atomicMin(const float* buffer, float dblValue) {
22
+ int intValue = __float_as_int(*buffer);
23
+
24
+ while (__int_as_float(intValue) > dblValue) {
25
+ intValue = atomicCAS((int*) (buffer), intValue, __float_as_int(dblValue));
26
+ }
27
+
28
+ return __int_as_float(intValue);
29
+ }
30
+
31
+
32
+ __device__ __forceinline__ float atomicMax(const float* buffer, float dblValue) {
33
+ int intValue = __float_as_int(*buffer);
34
+
35
+ while (__int_as_float(intValue) < dblValue) {
36
+ intValue = atomicCAS((int*) (buffer), intValue, __float_as_int(dblValue));
37
+ }
38
+
39
+ return __int_as_float(intValue);
40
+ }
41
+ '''.replace('{{HELPER_PATH}}', path_to_math_helper) + strKernel
42
+ # end
43
+
44
+ for strVariable in objVariables:
45
+ objValue = objVariables[strVariable]
46
+
47
+ if type(objValue) == int:
48
+ strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))
49
+
50
+ elif type(objValue) == float:
51
+ strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))
52
+
53
+ elif type(objValue) == str:
54
+ strKernel = strKernel.replace('{{' + strVariable + '}}', objValue)
55
+
56
+ # end
57
+ # end
58
+
59
+ while True:
60
+ objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel)
61
+
62
+ if objMatch is None:
63
+ break
64
+ # end
65
+
66
+ intArg = int(objMatch.group(2))
67
+
68
+ strTensor = objMatch.group(4)
69
+ intSizes = objVariables[strTensor].size()
70
+
71
+ strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg] if torch.is_tensor(intSizes[intArg]) == False else intSizes[intArg].item()))
72
+ # end
73
+
74
+ while True:
75
+ objMatch = re.search('(STRIDE_)([0-4])(\()([^\)]*)(\))', strKernel)
76
+
77
+ if objMatch is None:
78
+ break
79
+ # end
80
+
81
+ intArg = int(objMatch.group(2))
82
+
83
+ strTensor = objMatch.group(4)
84
+ intStrides = objVariables[strTensor].stride()
85
+
86
+ strKernel = strKernel.replace(objMatch.group(), str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()))
87
+ # end
88
+
89
+ while True:
90
+ objMatch = re.search('(OFFSET_)([0-4])(\()([^\)]+)(\))', strKernel)
91
+
92
+ if objMatch is None:
93
+ break
94
+ # end
95
+
96
+ intArgs = int(objMatch.group(2))
97
+ strArgs = objMatch.group(4).split(',')
98
+
99
+ strTensor = strArgs[0]
100
+ intStrides = objVariables[strTensor].stride()
101
+ strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')' for intArg in range(intArgs) ]
102
+
103
+ strKernel = strKernel.replace(objMatch.group(0), '(' + str.join('+', strIndex) + ')')
104
+ # end
105
+
106
+ while True:
107
+ objMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel)
108
+
109
+ if objMatch is None:
110
+ break
111
+ # end
112
+
113
+ intArgs = int(objMatch.group(2))
114
+ strArgs = objMatch.group(4).split(',')
115
+
116
+ strTensor = strArgs[0]
117
+ intStrides = objVariables[strTensor].stride()
118
+ strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')' for intArg in range(intArgs) ]
119
+
120
+ strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']')
121
+ # end
122
+ return strKernel
utils/effects.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from numba import jit, njit
2
+ import numpy as np
3
+ import time
4
+ import cv2
5
+ import math
6
+ from pathlib import Path
7
+ import os.path as osp
8
+ import torch
9
+ from .cupy_utils import launch_kernel, preprocess_kernel
10
+ import cupy
11
+
12
+ def bokeh_filter_cupy(img, depth, dx, dy, im_h, im_w, num_samples=32):
13
+ blurred = img.clone()
14
+ n = im_h * im_w
15
+
16
+ str_kernel = '''
17
+ extern "C" __global__ void kernel_bokeh(
18
+ const int n,
19
+ const int h,
20
+ const int w,
21
+ const int nsamples,
22
+ const float dx,
23
+ const float dy,
24
+ const float* img,
25
+ const float* depth,
26
+ float* blurred
27
+ ) {
28
+
29
+ const int im_size = min(h, w);
30
+ const int sample_offset = nsamples / 2;
31
+ for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n * 3; intIndex += blockDim.x * gridDim.x) {
32
+
33
+ const int intSample = intIndex / 3;
34
+
35
+ const int c = intIndex % 3;
36
+ const int y = ( intSample / w) % h;
37
+ const int x = intSample % w;
38
+
39
+ const int flatten_xy = y * w + x;
40
+ const int fid = flatten_xy * 3 + c;
41
+ const float d = depth[flatten_xy];
42
+
43
+ const float _dx = dx * d;
44
+ const float _dy = dy * d;
45
+ float weight = 0;
46
+ float color = 0;
47
+ for (int s = 0; s < nsamples; s += 1) {
48
+
49
+ const int sp = (s - sample_offset) * im_size;
50
+ const int x_ = x + int(round(_dx * sp));
51
+ const int y_ = y + int(round(_dy * sp));
52
+
53
+ if ((x_ >= w) | (y_ >= h) | (x_ < 0) | (y_ < 0))
54
+ continue;
55
+
56
+ const int flatten_xy_ = y_ * w + x_;
57
+ const float w_ = depth[flatten_xy_];
58
+ weight += w_;
59
+ const int fid_ = flatten_xy_ * 3 + c;
60
+ color += img[fid_] * w_;
61
+ }
62
+
63
+ if (weight != 0) {
64
+ color /= weight;
65
+ }
66
+ else {
67
+ color = img[fid];
68
+ }
69
+
70
+ blurred[fid] = color;
71
+
72
+ }
73
+
74
+ }
75
+ '''
76
+ launch_kernel('kernel_bokeh', str_kernel)(
77
+ grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]),
78
+ block=tuple([ 512, 1, 1 ]),
79
+ args=[ cupy.int32(n), cupy.int32(im_h), cupy.int32(im_w), \
80
+ cupy.int32(num_samples), cupy.float32(dx), cupy.float32(dy),
81
+ img.data_ptr(), depth.data_ptr(), blurred.data_ptr() ]
82
+ )
83
+
84
+ return blurred
85
+
86
+
87
+ def np2flatten_tensor(arr: np.ndarray, to_cuda: bool = True) -> torch.Tensor:
88
+ c = 1
89
+ if len(arr.shape) == 3:
90
+ c = arr.shape[2]
91
+ else:
92
+ arr = arr[..., None]
93
+ arr = arr.transpose((2, 0, 1))[None, ...]
94
+ t = torch.from_numpy(arr).view(1, c, -1)
95
+
96
+ if to_cuda:
97
+ t = t.cuda()
98
+ return t
99
+
100
+ def ftensor2img(t: torch.Tensor, im_h, im_w):
101
+ t = t.detach().cpu().numpy().squeeze()
102
+ c = t.shape[0]
103
+ t = t.transpose((1, 0)).reshape((im_h, im_w, c))
104
+ return t
105
+
106
+
107
+ @njit
108
+ def bokeh_filter(img, depth, dx, dy, num_samples=32):
109
+
110
+ sample_offset = num_samples // 2
111
+ # _scale = 0.0005
112
+ # depth = depth * _scale
113
+
114
+ im_h, im_w = img.shape[0], img.shape[1]
115
+ im_size = min(im_h, im_w)
116
+ blured = np.zeros_like(img)
117
+ for x in range(im_w):
118
+ for y in range(im_h):
119
+ d = depth[y, x]
120
+ _color = np.array([0, 0, 0], dtype=np.float32)
121
+ _dx = dx * d
122
+ _dy = dy * d
123
+ weight = 0
124
+ for s in range(num_samples):
125
+ s = (s - sample_offset) * im_size
126
+ x_ = x + int(round(_dx * s))
127
+ y_ = y + int(round(_dy * s))
128
+ if x_ >= im_w or y_ >= im_h or x_ < 0 or y_ < 0:
129
+ continue
130
+ _w = depth[y_, x_]
131
+ weight += _w
132
+ _color += img[y_, x_] * _w
133
+ if weight == 0:
134
+ blured[y, x] = img[y, x]
135
+ else:
136
+ blured[y, x] = _color / np.array([weight, weight, weight], dtype=np.float32)
137
+
138
+ return blured
139
+
140
+
141
+
142
+
143
+ def bokeh_blur(img, depth, num_samples=32, lightness_factor=10, depth_factor=2, use_cuda=False, focal_plane=None):
144
+ img = np.ascontiguousarray(img)
145
+
146
+ if depth is not None:
147
+ depth = depth.astype(np.float32)
148
+ if focal_plane is not None:
149
+ depth = depth.max() - np.abs(depth - focal_plane)
150
+ if depth_factor != 1:
151
+ depth = np.power(depth, depth_factor)
152
+ depth = depth - depth.min()
153
+ depth = depth.astype(np.float32) / depth.max()
154
+ depth = 1 - depth
155
+
156
+ img = img.astype(np.float32) / 255
157
+ img_hightlighted = np.power(img, lightness_factor)
158
+
159
+ # img =
160
+ im_h, im_w = img.shape[:2]
161
+ PI = math.pi
162
+
163
+ _scale = 0.0005
164
+ depth = depth * _scale
165
+
166
+ if use_cuda:
167
+ img_hightlighted = np2flatten_tensor(img_hightlighted, True)
168
+ depth = np2flatten_tensor(depth, True)
169
+ vertical_blured = bokeh_filter_cupy(img_hightlighted, depth, 0, 1, im_h, im_w, num_samples)
170
+ diag_blured = bokeh_filter_cupy(vertical_blured, depth, math.cos(-PI/6), math.sin(-PI/6), im_h, im_w, num_samples)
171
+ rhom_blur = bokeh_filter_cupy(diag_blured, depth, math.cos(-PI * 5 /6), math.sin(-PI * 5 /6), im_h, im_w, num_samples)
172
+ blured = (diag_blured + rhom_blur) / 2
173
+ blured = ftensor2img(blured, im_h, im_w)
174
+ else:
175
+ vertical_blured = bokeh_filter(img_hightlighted, depth, 0, 1, num_samples)
176
+ diag_blured = bokeh_filter(vertical_blured, depth, math.cos(-PI/6), math.sin(-PI/6), num_samples)
177
+ rhom_blur = bokeh_filter(diag_blured, depth, math.cos(-PI * 5 /6), math.sin(-PI * 5 /6), num_samples)
178
+ blured = (diag_blured + rhom_blur) / 2
179
+ blured = np.power(blured, 1 / lightness_factor)
180
+ blured = (blured * 255).astype(np.uint8)
181
+
182
+ return blured
utils/env_utils.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import platform
3
+ import warnings
4
+
5
+ import torch.multiprocessing as mp
6
+
7
+
8
+ def set_multi_processing(
9
+ mp_start_method: str = "fork", opencv_num_threads: int = 0, distributed: bool = True
10
+ ) -> None:
11
+ """Set multi-processing related environment.
12
+
13
+ This function is refered from https://github.com/open-mmlab/mmengine/blob/main/mmengine/utils/dl_utils/setup_env.py
14
+
15
+ Args:
16
+ mp_start_method (str): Set the method which should be used to start
17
+ child processes. Defaults to 'fork'.
18
+ opencv_num_threads (int): Number of threads for opencv.
19
+ Defaults to 0.
20
+ distributed (bool): True if distributed environment.
21
+ Defaults to False.
22
+ """ # noqa
23
+ # set multi-process start method as `fork` to speed up the training
24
+ if platform.system() != "Windows":
25
+ current_method = mp.get_start_method(allow_none=True)
26
+ if current_method is not None and current_method != mp_start_method:
27
+ warnings.warn(
28
+ f"Multi-processing start method `{mp_start_method}` is "
29
+ f"different from the previous setting `{current_method}`."
30
+ f"It will be force set to `{mp_start_method}`. You can "
31
+ "change this behavior by changing `mp_start_method` in "
32
+ "your config."
33
+ )
34
+ mp.set_start_method(mp_start_method, force=True)
35
+
36
+ try:
37
+ import cv2
38
+
39
+ # disable opencv multithreading to avoid system being overloaded
40
+ cv2.setNumThreads(opencv_num_threads)
41
+ except ImportError:
42
+ pass
43
+
44
+ # setup OMP threads
45
+ # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa
46
+ if "OMP_NUM_THREADS" not in os.environ and distributed:
47
+ omp_num_threads = 1
48
+ warnings.warn(
49
+ "Setting OMP_NUM_THREADS environment variable for each process"
50
+ f" to be {omp_num_threads} in default, to avoid your system "
51
+ "being overloaded, please further tune the variable for "
52
+ "optimal performance in your application as needed."
53
+ )
54
+ os.environ["OMP_NUM_THREADS"] = str(omp_num_threads)
55
+
56
+ # # setup MKL threads
57
+ if "MKL_NUM_THREADS" not in os.environ and distributed:
58
+ mkl_num_threads = 1
59
+ warnings.warn(
60
+ "Setting MKL_NUM_THREADS environment variable for each process"
61
+ f" to be {mkl_num_threads} in default, to avoid your system "
62
+ "being overloaded, please further tune the variable for "
63
+ "optimal performance in your application as needed."
64
+ )
65
+ os.environ["MKL_NUM_THREADS"] = str(mkl_num_threads)
utils/helper_math.h ADDED
@@ -0,0 +1,1449 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /**
2
+ * Copyright 1993-2012 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * Please refer to the NVIDIA end user license agreement (EULA) associated
5
+ * with this source code for terms and conditions that govern your use of
6
+ * this software. Any use, reproduction, disclosure, or distribution of
7
+ * this software and related documentation outside the terms of the EULA
8
+ * is strictly prohibited.
9
+ *
10
+ */
11
+
12
+ /*
13
+ * This file implements common mathematical operations on vector types
14
+ * (float3, float4 etc.) since these are not provided as standard by CUDA.
15
+ *
16
+ * The syntax is modeled on the Cg standard library.
17
+ *
18
+ * This is part of the Helper library includes
19
+ *
20
+ * Thanks to Linh Hah for additions and fixes.
21
+ */
22
+
23
+ #ifndef HELPER_MATH_H
24
+ #define HELPER_MATH_H
25
+
26
+ #include "cuda_runtime.h"
27
+
28
+ typedef unsigned int uint;
29
+ typedef unsigned short ushort;
30
+
31
+ #ifndef __CUDACC__
32
+ #include <math.h>
33
+
34
+ ////////////////////////////////////////////////////////////////////////////////
35
+ // host implementations of CUDA functions
36
+ ////////////////////////////////////////////////////////////////////////////////
37
+
38
+ inline float fminf(float a, float b)
39
+ {
40
+ return a < b ? a : b;
41
+ }
42
+
43
+ inline float fmaxf(float a, float b)
44
+ {
45
+ return a > b ? a : b;
46
+ }
47
+
48
+ inline int max(int a, int b)
49
+ {
50
+ return a > b ? a : b;
51
+ }
52
+
53
+ inline int min(int a, int b)
54
+ {
55
+ return a < b ? a : b;
56
+ }
57
+
58
+ inline float rsqrtf(float x)
59
+ {
60
+ return 1.0f / sqrtf(x);
61
+ }
62
+ #endif
63
+
64
+ ////////////////////////////////////////////////////////////////////////////////
65
+ // constructors
66
+ ////////////////////////////////////////////////////////////////////////////////
67
+
68
+ inline __host__ __device__ float2 make_float2(float s)
69
+ {
70
+ return make_float2(s, s);
71
+ }
72
+ inline __host__ __device__ float2 make_float2(float3 a)
73
+ {
74
+ return make_float2(a.x, a.y);
75
+ }
76
+ inline __host__ __device__ float2 make_float2(int2 a)
77
+ {
78
+ return make_float2(float(a.x), float(a.y));
79
+ }
80
+ inline __host__ __device__ float2 make_float2(uint2 a)
81
+ {
82
+ return make_float2(float(a.x), float(a.y));
83
+ }
84
+
85
+ inline __host__ __device__ int2 make_int2(int s)
86
+ {
87
+ return make_int2(s, s);
88
+ }
89
+ inline __host__ __device__ int2 make_int2(int3 a)
90
+ {
91
+ return make_int2(a.x, a.y);
92
+ }
93
+ inline __host__ __device__ int2 make_int2(uint2 a)
94
+ {
95
+ return make_int2(int(a.x), int(a.y));
96
+ }
97
+ inline __host__ __device__ int2 make_int2(float2 a)
98
+ {
99
+ return make_int2(int(a.x), int(a.y));
100
+ }
101
+
102
+ inline __host__ __device__ uint2 make_uint2(uint s)
103
+ {
104
+ return make_uint2(s, s);
105
+ }
106
+ inline __host__ __device__ uint2 make_uint2(uint3 a)
107
+ {
108
+ return make_uint2(a.x, a.y);
109
+ }
110
+ inline __host__ __device__ uint2 make_uint2(int2 a)
111
+ {
112
+ return make_uint2(uint(a.x), uint(a.y));
113
+ }
114
+
115
+ inline __host__ __device__ float3 make_float3(float s)
116
+ {
117
+ return make_float3(s, s, s);
118
+ }
119
+ inline __host__ __device__ float3 make_float3(float2 a)
120
+ {
121
+ return make_float3(a.x, a.y, 0.0f);
122
+ }
123
+ inline __host__ __device__ float3 make_float3(float2 a, float s)
124
+ {
125
+ return make_float3(a.x, a.y, s);
126
+ }
127
+ inline __host__ __device__ float3 make_float3(float4 a)
128
+ {
129
+ return make_float3(a.x, a.y, a.z);
130
+ }
131
+ inline __host__ __device__ float3 make_float3(int3 a)
132
+ {
133
+ return make_float3(float(a.x), float(a.y), float(a.z));
134
+ }
135
+ inline __host__ __device__ float3 make_float3(uint3 a)
136
+ {
137
+ return make_float3(float(a.x), float(a.y), float(a.z));
138
+ }
139
+
140
+ inline __host__ __device__ int3 make_int3(int s)
141
+ {
142
+ return make_int3(s, s, s);
143
+ }
144
+ inline __host__ __device__ int3 make_int3(int2 a)
145
+ {
146
+ return make_int3(a.x, a.y, 0);
147
+ }
148
+ inline __host__ __device__ int3 make_int3(int2 a, int s)
149
+ {
150
+ return make_int3(a.x, a.y, s);
151
+ }
152
+ inline __host__ __device__ int3 make_int3(uint3 a)
153
+ {
154
+ return make_int3(int(a.x), int(a.y), int(a.z));
155
+ }
156
+ inline __host__ __device__ int3 make_int3(float3 a)
157
+ {
158
+ return make_int3(int(a.x), int(a.y), int(a.z));
159
+ }
160
+
161
+ inline __host__ __device__ uint3 make_uint3(uint s)
162
+ {
163
+ return make_uint3(s, s, s);
164
+ }
165
+ inline __host__ __device__ uint3 make_uint3(uint2 a)
166
+ {
167
+ return make_uint3(a.x, a.y, 0);
168
+ }
169
+ inline __host__ __device__ uint3 make_uint3(uint2 a, uint s)
170
+ {
171
+ return make_uint3(a.x, a.y, s);
172
+ }
173
+ inline __host__ __device__ uint3 make_uint3(uint4 a)
174
+ {
175
+ return make_uint3(a.x, a.y, a.z);
176
+ }
177
+ inline __host__ __device__ uint3 make_uint3(int3 a)
178
+ {
179
+ return make_uint3(uint(a.x), uint(a.y), uint(a.z));
180
+ }
181
+
182
+ inline __host__ __device__ float4 make_float4(float s)
183
+ {
184
+ return make_float4(s, s, s, s);
185
+ }
186
+ inline __host__ __device__ float4 make_float4(float3 a)
187
+ {
188
+ return make_float4(a.x, a.y, a.z, 0.0f);
189
+ }
190
+ inline __host__ __device__ float4 make_float4(float3 a, float w)
191
+ {
192
+ return make_float4(a.x, a.y, a.z, w);
193
+ }
194
+ inline __host__ __device__ float4 make_float4(int4 a)
195
+ {
196
+ return make_float4(float(a.x), float(a.y), float(a.z), float(a.w));
197
+ }
198
+ inline __host__ __device__ float4 make_float4(uint4 a)
199
+ {
200
+ return make_float4(float(a.x), float(a.y), float(a.z), float(a.w));
201
+ }
202
+
203
+ inline __host__ __device__ int4 make_int4(int s)
204
+ {
205
+ return make_int4(s, s, s, s);
206
+ }
207
+ inline __host__ __device__ int4 make_int4(int3 a)
208
+ {
209
+ return make_int4(a.x, a.y, a.z, 0);
210
+ }
211
+ inline __host__ __device__ int4 make_int4(int3 a, int w)
212
+ {
213
+ return make_int4(a.x, a.y, a.z, w);
214
+ }
215
+ inline __host__ __device__ int4 make_int4(uint4 a)
216
+ {
217
+ return make_int4(int(a.x), int(a.y), int(a.z), int(a.w));
218
+ }
219
+ inline __host__ __device__ int4 make_int4(float4 a)
220
+ {
221
+ return make_int4(int(a.x), int(a.y), int(a.z), int(a.w));
222
+ }
223
+
224
+
225
+ inline __host__ __device__ uint4 make_uint4(uint s)
226
+ {
227
+ return make_uint4(s, s, s, s);
228
+ }
229
+ inline __host__ __device__ uint4 make_uint4(uint3 a)
230
+ {
231
+ return make_uint4(a.x, a.y, a.z, 0);
232
+ }
233
+ inline __host__ __device__ uint4 make_uint4(uint3 a, uint w)
234
+ {
235
+ return make_uint4(a.x, a.y, a.z, w);
236
+ }
237
+ inline __host__ __device__ uint4 make_uint4(int4 a)
238
+ {
239
+ return make_uint4(uint(a.x), uint(a.y), uint(a.z), uint(a.w));
240
+ }
241
+
242
+ ////////////////////////////////////////////////////////////////////////////////
243
+ // negate
244
+ ////////////////////////////////////////////////////////////////////////////////
245
+
246
+ inline __host__ __device__ float2 operator-(float2 &a)
247
+ {
248
+ return make_float2(-a.x, -a.y);
249
+ }
250
+ inline __host__ __device__ int2 operator-(int2 &a)
251
+ {
252
+ return make_int2(-a.x, -a.y);
253
+ }
254
+ inline __host__ __device__ float3 operator-(float3 &a)
255
+ {
256
+ return make_float3(-a.x, -a.y, -a.z);
257
+ }
258
+ inline __host__ __device__ int3 operator-(int3 &a)
259
+ {
260
+ return make_int3(-a.x, -a.y, -a.z);
261
+ }
262
+ inline __host__ __device__ float4 operator-(float4 &a)
263
+ {
264
+ return make_float4(-a.x, -a.y, -a.z, -a.w);
265
+ }
266
+ inline __host__ __device__ int4 operator-(int4 &a)
267
+ {
268
+ return make_int4(-a.x, -a.y, -a.z, -a.w);
269
+ }
270
+
271
+ ////////////////////////////////////////////////////////////////////////////////
272
+ // addition
273
+ ////////////////////////////////////////////////////////////////////////////////
274
+
275
+ inline __host__ __device__ float2 operator+(float2 a, float2 b)
276
+ {
277
+ return make_float2(a.x + b.x, a.y + b.y);
278
+ }
279
+ inline __host__ __device__ void operator+=(float2 &a, float2 b)
280
+ {
281
+ a.x += b.x;
282
+ a.y += b.y;
283
+ }
284
+ inline __host__ __device__ float2 operator+(float2 a, float b)
285
+ {
286
+ return make_float2(a.x + b, a.y + b);
287
+ }
288
+ inline __host__ __device__ float2 operator+(float b, float2 a)
289
+ {
290
+ return make_float2(a.x + b, a.y + b);
291
+ }
292
+ inline __host__ __device__ void operator+=(float2 &a, float b)
293
+ {
294
+ a.x += b;
295
+ a.y += b;
296
+ }
297
+
298
+ inline __host__ __device__ int2 operator+(int2 a, int2 b)
299
+ {
300
+ return make_int2(a.x + b.x, a.y + b.y);
301
+ }
302
+ inline __host__ __device__ void operator+=(int2 &a, int2 b)
303
+ {
304
+ a.x += b.x;
305
+ a.y += b.y;
306
+ }
307
+ inline __host__ __device__ int2 operator+(int2 a, int b)
308
+ {
309
+ return make_int2(a.x + b, a.y + b);
310
+ }
311
+ inline __host__ __device__ int2 operator+(int b, int2 a)
312
+ {
313
+ return make_int2(a.x + b, a.y + b);
314
+ }
315
+ inline __host__ __device__ void operator+=(int2 &a, int b)
316
+ {
317
+ a.x += b;
318
+ a.y += b;
319
+ }
320
+
321
+ inline __host__ __device__ uint2 operator+(uint2 a, uint2 b)
322
+ {
323
+ return make_uint2(a.x + b.x, a.y + b.y);
324
+ }
325
+ inline __host__ __device__ void operator+=(uint2 &a, uint2 b)
326
+ {
327
+ a.x += b.x;
328
+ a.y += b.y;
329
+ }
330
+ inline __host__ __device__ uint2 operator+(uint2 a, uint b)
331
+ {
332
+ return make_uint2(a.x + b, a.y + b);
333
+ }
334
+ inline __host__ __device__ uint2 operator+(uint b, uint2 a)
335
+ {
336
+ return make_uint2(a.x + b, a.y + b);
337
+ }
338
+ inline __host__ __device__ void operator+=(uint2 &a, uint b)
339
+ {
340
+ a.x += b;
341
+ a.y += b;
342
+ }
343
+
344
+
345
+ inline __host__ __device__ float3 operator+(float3 a, float3 b)
346
+ {
347
+ return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
348
+ }
349
+ inline __host__ __device__ void operator+=(float3 &a, float3 b)
350
+ {
351
+ a.x += b.x;
352
+ a.y += b.y;
353
+ a.z += b.z;
354
+ }
355
+ inline __host__ __device__ float3 operator+(float3 a, float b)
356
+ {
357
+ return make_float3(a.x + b, a.y + b, a.z + b);
358
+ }
359
+ inline __host__ __device__ void operator+=(float3 &a, float b)
360
+ {
361
+ a.x += b;
362
+ a.y += b;
363
+ a.z += b;
364
+ }
365
+
366
+ inline __host__ __device__ int3 operator+(int3 a, int3 b)
367
+ {
368
+ return make_int3(a.x + b.x, a.y + b.y, a.z + b.z);
369
+ }
370
+ inline __host__ __device__ void operator+=(int3 &a, int3 b)
371
+ {
372
+ a.x += b.x;
373
+ a.y += b.y;
374
+ a.z += b.z;
375
+ }
376
+ inline __host__ __device__ int3 operator+(int3 a, int b)
377
+ {
378
+ return make_int3(a.x + b, a.y + b, a.z + b);
379
+ }
380
+ inline __host__ __device__ void operator+=(int3 &a, int b)
381
+ {
382
+ a.x += b;
383
+ a.y += b;
384
+ a.z += b;
385
+ }
386
+
387
+ inline __host__ __device__ uint3 operator+(uint3 a, uint3 b)
388
+ {
389
+ return make_uint3(a.x + b.x, a.y + b.y, a.z + b.z);
390
+ }
391
+ inline __host__ __device__ void operator+=(uint3 &a, uint3 b)
392
+ {
393
+ a.x += b.x;
394
+ a.y += b.y;
395
+ a.z += b.z;
396
+ }
397
+ inline __host__ __device__ uint3 operator+(uint3 a, uint b)
398
+ {
399
+ return make_uint3(a.x + b, a.y + b, a.z + b);
400
+ }
401
+ inline __host__ __device__ void operator+=(uint3 &a, uint b)
402
+ {
403
+ a.x += b;
404
+ a.y += b;
405
+ a.z += b;
406
+ }
407
+
408
+ inline __host__ __device__ int3 operator+(int b, int3 a)
409
+ {
410
+ return make_int3(a.x + b, a.y + b, a.z + b);
411
+ }
412
+ inline __host__ __device__ uint3 operator+(uint b, uint3 a)
413
+ {
414
+ return make_uint3(a.x + b, a.y + b, a.z + b);
415
+ }
416
+ inline __host__ __device__ float3 operator+(float b, float3 a)
417
+ {
418
+ return make_float3(a.x + b, a.y + b, a.z + b);
419
+ }
420
+
421
+ inline __host__ __device__ float4 operator+(float4 a, float4 b)
422
+ {
423
+ return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
424
+ }
425
+ inline __host__ __device__ void operator+=(float4 &a, float4 b)
426
+ {
427
+ a.x += b.x;
428
+ a.y += b.y;
429
+ a.z += b.z;
430
+ a.w += b.w;
431
+ }
432
+ inline __host__ __device__ float4 operator+(float4 a, float b)
433
+ {
434
+ return make_float4(a.x + b, a.y + b, a.z + b, a.w + b);
435
+ }
436
+ inline __host__ __device__ float4 operator+(float b, float4 a)
437
+ {
438
+ return make_float4(a.x + b, a.y + b, a.z + b, a.w + b);
439
+ }
440
+ inline __host__ __device__ void operator+=(float4 &a, float b)
441
+ {
442
+ a.x += b;
443
+ a.y += b;
444
+ a.z += b;
445
+ a.w += b;
446
+ }
447
+
448
+ inline __host__ __device__ int4 operator+(int4 a, int4 b)
449
+ {
450
+ return make_int4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
451
+ }
452
+ inline __host__ __device__ void operator+=(int4 &a, int4 b)
453
+ {
454
+ a.x += b.x;
455
+ a.y += b.y;
456
+ a.z += b.z;
457
+ a.w += b.w;
458
+ }
459
+ inline __host__ __device__ int4 operator+(int4 a, int b)
460
+ {
461
+ return make_int4(a.x + b, a.y + b, a.z + b, a.w + b);
462
+ }
463
+ inline __host__ __device__ int4 operator+(int b, int4 a)
464
+ {
465
+ return make_int4(a.x + b, a.y + b, a.z + b, a.w + b);
466
+ }
467
+ inline __host__ __device__ void operator+=(int4 &a, int b)
468
+ {
469
+ a.x += b;
470
+ a.y += b;
471
+ a.z += b;
472
+ a.w += b;
473
+ }
474
+
475
+ inline __host__ __device__ uint4 operator+(uint4 a, uint4 b)
476
+ {
477
+ return make_uint4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
478
+ }
479
+ inline __host__ __device__ void operator+=(uint4 &a, uint4 b)
480
+ {
481
+ a.x += b.x;
482
+ a.y += b.y;
483
+ a.z += b.z;
484
+ a.w += b.w;
485
+ }
486
+ inline __host__ __device__ uint4 operator+(uint4 a, uint b)
487
+ {
488
+ return make_uint4(a.x + b, a.y + b, a.z + b, a.w + b);
489
+ }
490
+ inline __host__ __device__ uint4 operator+(uint b, uint4 a)
491
+ {
492
+ return make_uint4(a.x + b, a.y + b, a.z + b, a.w + b);
493
+ }
494
+ inline __host__ __device__ void operator+=(uint4 &a, uint b)
495
+ {
496
+ a.x += b;
497
+ a.y += b;
498
+ a.z += b;
499
+ a.w += b;
500
+ }
501
+
502
+ ////////////////////////////////////////////////////////////////////////////////
503
+ // subtract
504
+ ////////////////////////////////////////////////////////////////////////////////
505
+
506
+ inline __host__ __device__ float2 operator-(float2 a, float2 b)
507
+ {
508
+ return make_float2(a.x - b.x, a.y - b.y);
509
+ }
510
+ inline __host__ __device__ void operator-=(float2 &a, float2 b)
511
+ {
512
+ a.x -= b.x;
513
+ a.y -= b.y;
514
+ }
515
+ inline __host__ __device__ float2 operator-(float2 a, float b)
516
+ {
517
+ return make_float2(a.x - b, a.y - b);
518
+ }
519
+ inline __host__ __device__ float2 operator-(float b, float2 a)
520
+ {
521
+ return make_float2(b - a.x, b - a.y);
522
+ }
523
+ inline __host__ __device__ void operator-=(float2 &a, float b)
524
+ {
525
+ a.x -= b;
526
+ a.y -= b;
527
+ }
528
+
529
+ inline __host__ __device__ int2 operator-(int2 a, int2 b)
530
+ {
531
+ return make_int2(a.x - b.x, a.y - b.y);
532
+ }
533
+ inline __host__ __device__ void operator-=(int2 &a, int2 b)
534
+ {
535
+ a.x -= b.x;
536
+ a.y -= b.y;
537
+ }
538
+ inline __host__ __device__ int2 operator-(int2 a, int b)
539
+ {
540
+ return make_int2(a.x - b, a.y - b);
541
+ }
542
+ inline __host__ __device__ int2 operator-(int b, int2 a)
543
+ {
544
+ return make_int2(b - a.x, b - a.y);
545
+ }
546
+ inline __host__ __device__ void operator-=(int2 &a, int b)
547
+ {
548
+ a.x -= b;
549
+ a.y -= b;
550
+ }
551
+
552
+ inline __host__ __device__ uint2 operator-(uint2 a, uint2 b)
553
+ {
554
+ return make_uint2(a.x - b.x, a.y - b.y);
555
+ }
556
+ inline __host__ __device__ void operator-=(uint2 &a, uint2 b)
557
+ {
558
+ a.x -= b.x;
559
+ a.y -= b.y;
560
+ }
561
+ inline __host__ __device__ uint2 operator-(uint2 a, uint b)
562
+ {
563
+ return make_uint2(a.x - b, a.y - b);
564
+ }
565
+ inline __host__ __device__ uint2 operator-(uint b, uint2 a)
566
+ {
567
+ return make_uint2(b - a.x, b - a.y);
568
+ }
569
+ inline __host__ __device__ void operator-=(uint2 &a, uint b)
570
+ {
571
+ a.x -= b;
572
+ a.y -= b;
573
+ }
574
+
575
+ inline __host__ __device__ float3 operator-(float3 a, float3 b)
576
+ {
577
+ return make_float3(a.x - b.x, a.y - b.y, a.z - b.z);
578
+ }
579
+ inline __host__ __device__ void operator-=(float3 &a, float3 b)
580
+ {
581
+ a.x -= b.x;
582
+ a.y -= b.y;
583
+ a.z -= b.z;
584
+ }
585
+ inline __host__ __device__ float3 operator-(float3 a, float b)
586
+ {
587
+ return make_float3(a.x - b, a.y - b, a.z - b);
588
+ }
589
+ inline __host__ __device__ float3 operator-(float b, float3 a)
590
+ {
591
+ return make_float3(b - a.x, b - a.y, b - a.z);
592
+ }
593
+ inline __host__ __device__ void operator-=(float3 &a, float b)
594
+ {
595
+ a.x -= b;
596
+ a.y -= b;
597
+ a.z -= b;
598
+ }
599
+
600
+ inline __host__ __device__ int3 operator-(int3 a, int3 b)
601
+ {
602
+ return make_int3(a.x - b.x, a.y - b.y, a.z - b.z);
603
+ }
604
+ inline __host__ __device__ void operator-=(int3 &a, int3 b)
605
+ {
606
+ a.x -= b.x;
607
+ a.y -= b.y;
608
+ a.z -= b.z;
609
+ }
610
+ inline __host__ __device__ int3 operator-(int3 a, int b)
611
+ {
612
+ return make_int3(a.x - b, a.y - b, a.z - b);
613
+ }
614
+ inline __host__ __device__ int3 operator-(int b, int3 a)
615
+ {
616
+ return make_int3(b - a.x, b - a.y, b - a.z);
617
+ }
618
+ inline __host__ __device__ void operator-=(int3 &a, int b)
619
+ {
620
+ a.x -= b;
621
+ a.y -= b;
622
+ a.z -= b;
623
+ }
624
+
625
+ inline __host__ __device__ uint3 operator-(uint3 a, uint3 b)
626
+ {
627
+ return make_uint3(a.x - b.x, a.y - b.y, a.z - b.z);
628
+ }
629
+ inline __host__ __device__ void operator-=(uint3 &a, uint3 b)
630
+ {
631
+ a.x -= b.x;
632
+ a.y -= b.y;
633
+ a.z -= b.z;
634
+ }
635
+ inline __host__ __device__ uint3 operator-(uint3 a, uint b)
636
+ {
637
+ return make_uint3(a.x - b, a.y - b, a.z - b);
638
+ }
639
+ inline __host__ __device__ uint3 operator-(uint b, uint3 a)
640
+ {
641
+ return make_uint3(b - a.x, b - a.y, b - a.z);
642
+ }
643
+ inline __host__ __device__ void operator-=(uint3 &a, uint b)
644
+ {
645
+ a.x -= b;
646
+ a.y -= b;
647
+ a.z -= b;
648
+ }
649
+
650
+ inline __host__ __device__ float4 operator-(float4 a, float4 b)
651
+ {
652
+ return make_float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
653
+ }
654
+ inline __host__ __device__ void operator-=(float4 &a, float4 b)
655
+ {
656
+ a.x -= b.x;
657
+ a.y -= b.y;
658
+ a.z -= b.z;
659
+ a.w -= b.w;
660
+ }
661
+ inline __host__ __device__ float4 operator-(float4 a, float b)
662
+ {
663
+ return make_float4(a.x - b, a.y - b, a.z - b, a.w - b);
664
+ }
665
+ inline __host__ __device__ void operator-=(float4 &a, float b)
666
+ {
667
+ a.x -= b;
668
+ a.y -= b;
669
+ a.z -= b;
670
+ a.w -= b;
671
+ }
672
+
673
+ inline __host__ __device__ int4 operator-(int4 a, int4 b)
674
+ {
675
+ return make_int4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
676
+ }
677
+ inline __host__ __device__ void operator-=(int4 &a, int4 b)
678
+ {
679
+ a.x -= b.x;
680
+ a.y -= b.y;
681
+ a.z -= b.z;
682
+ a.w -= b.w;
683
+ }
684
+ inline __host__ __device__ int4 operator-(int4 a, int b)
685
+ {
686
+ return make_int4(a.x - b, a.y - b, a.z - b, a.w - b);
687
+ }
688
+ inline __host__ __device__ int4 operator-(int b, int4 a)
689
+ {
690
+ return make_int4(b - a.x, b - a.y, b - a.z, b - a.w);
691
+ }
692
+ inline __host__ __device__ void operator-=(int4 &a, int b)
693
+ {
694
+ a.x -= b;
695
+ a.y -= b;
696
+ a.z -= b;
697
+ a.w -= b;
698
+ }
699
+
700
+ inline __host__ __device__ uint4 operator-(uint4 a, uint4 b)
701
+ {
702
+ return make_uint4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
703
+ }
704
+ inline __host__ __device__ void operator-=(uint4 &a, uint4 b)
705
+ {
706
+ a.x -= b.x;
707
+ a.y -= b.y;
708
+ a.z -= b.z;
709
+ a.w -= b.w;
710
+ }
711
+ inline __host__ __device__ uint4 operator-(uint4 a, uint b)
712
+ {
713
+ return make_uint4(a.x - b, a.y - b, a.z - b, a.w - b);
714
+ }
715
+ inline __host__ __device__ uint4 operator-(uint b, uint4 a)
716
+ {
717
+ return make_uint4(b - a.x, b - a.y, b - a.z, b - a.w);
718
+ }
719
+ inline __host__ __device__ void operator-=(uint4 &a, uint b)
720
+ {
721
+ a.x -= b;
722
+ a.y -= b;
723
+ a.z -= b;
724
+ a.w -= b;
725
+ }
726
+
727
+ ////////////////////////////////////////////////////////////////////////////////
728
+ // multiply
729
+ ////////////////////////////////////////////////////////////////////////////////
730
+
731
+ inline __host__ __device__ float2 operator*(float2 a, float2 b)
732
+ {
733
+ return make_float2(a.x * b.x, a.y * b.y);
734
+ }
735
+ inline __host__ __device__ void operator*=(float2 &a, float2 b)
736
+ {
737
+ a.x *= b.x;
738
+ a.y *= b.y;
739
+ }
740
+ inline __host__ __device__ float2 operator*(float2 a, float b)
741
+ {
742
+ return make_float2(a.x * b, a.y * b);
743
+ }
744
+ inline __host__ __device__ float2 operator*(float b, float2 a)
745
+ {
746
+ return make_float2(b * a.x, b * a.y);
747
+ }
748
+ inline __host__ __device__ void operator*=(float2 &a, float b)
749
+ {
750
+ a.x *= b;
751
+ a.y *= b;
752
+ }
753
+
754
+ inline __host__ __device__ int2 operator*(int2 a, int2 b)
755
+ {
756
+ return make_int2(a.x * b.x, a.y * b.y);
757
+ }
758
+ inline __host__ __device__ void operator*=(int2 &a, int2 b)
759
+ {
760
+ a.x *= b.x;
761
+ a.y *= b.y;
762
+ }
763
+ inline __host__ __device__ int2 operator*(int2 a, int b)
764
+ {
765
+ return make_int2(a.x * b, a.y * b);
766
+ }
767
+ inline __host__ __device__ int2 operator*(int b, int2 a)
768
+ {
769
+ return make_int2(b * a.x, b * a.y);
770
+ }
771
+ inline __host__ __device__ void operator*=(int2 &a, int b)
772
+ {
773
+ a.x *= b;
774
+ a.y *= b;
775
+ }
776
+
777
+ inline __host__ __device__ uint2 operator*(uint2 a, uint2 b)
778
+ {
779
+ return make_uint2(a.x * b.x, a.y * b.y);
780
+ }
781
+ inline __host__ __device__ void operator*=(uint2 &a, uint2 b)
782
+ {
783
+ a.x *= b.x;
784
+ a.y *= b.y;
785
+ }
786
+ inline __host__ __device__ uint2 operator*(uint2 a, uint b)
787
+ {
788
+ return make_uint2(a.x * b, a.y * b);
789
+ }
790
+ inline __host__ __device__ uint2 operator*(uint b, uint2 a)
791
+ {
792
+ return make_uint2(b * a.x, b * a.y);
793
+ }
794
+ inline __host__ __device__ void operator*=(uint2 &a, uint b)
795
+ {
796
+ a.x *= b;
797
+ a.y *= b;
798
+ }
799
+
800
+ inline __host__ __device__ float3 operator*(float3 a, float3 b)
801
+ {
802
+ return make_float3(a.x * b.x, a.y * b.y, a.z * b.z);
803
+ }
804
+ inline __host__ __device__ void operator*=(float3 &a, float3 b)
805
+ {
806
+ a.x *= b.x;
807
+ a.y *= b.y;
808
+ a.z *= b.z;
809
+ }
810
+ inline __host__ __device__ float3 operator*(float3 a, float b)
811
+ {
812
+ return make_float3(a.x * b, a.y * b, a.z * b);
813
+ }
814
+ inline __host__ __device__ float3 operator*(float b, float3 a)
815
+ {
816
+ return make_float3(b * a.x, b * a.y, b * a.z);
817
+ }
818
+ inline __host__ __device__ void operator*=(float3 &a, float b)
819
+ {
820
+ a.x *= b;
821
+ a.y *= b;
822
+ a.z *= b;
823
+ }
824
+
825
+ inline __host__ __device__ int3 operator*(int3 a, int3 b)
826
+ {
827
+ return make_int3(a.x * b.x, a.y * b.y, a.z * b.z);
828
+ }
829
+ inline __host__ __device__ void operator*=(int3 &a, int3 b)
830
+ {
831
+ a.x *= b.x;
832
+ a.y *= b.y;
833
+ a.z *= b.z;
834
+ }
835
+ inline __host__ __device__ int3 operator*(int3 a, int b)
836
+ {
837
+ return make_int3(a.x * b, a.y * b, a.z * b);
838
+ }
839
+ inline __host__ __device__ int3 operator*(int b, int3 a)
840
+ {
841
+ return make_int3(b * a.x, b * a.y, b * a.z);
842
+ }
843
+ inline __host__ __device__ void operator*=(int3 &a, int b)
844
+ {
845
+ a.x *= b;
846
+ a.y *= b;
847
+ a.z *= b;
848
+ }
849
+
850
+ inline __host__ __device__ uint3 operator*(uint3 a, uint3 b)
851
+ {
852
+ return make_uint3(a.x * b.x, a.y * b.y, a.z * b.z);
853
+ }
854
+ inline __host__ __device__ void operator*=(uint3 &a, uint3 b)
855
+ {
856
+ a.x *= b.x;
857
+ a.y *= b.y;
858
+ a.z *= b.z;
859
+ }
860
+ inline __host__ __device__ uint3 operator*(uint3 a, uint b)
861
+ {
862
+ return make_uint3(a.x * b, a.y * b, a.z * b);
863
+ }
864
+ inline __host__ __device__ uint3 operator*(uint b, uint3 a)
865
+ {
866
+ return make_uint3(b * a.x, b * a.y, b * a.z);
867
+ }
868
+ inline __host__ __device__ void operator*=(uint3 &a, uint b)
869
+ {
870
+ a.x *= b;
871
+ a.y *= b;
872
+ a.z *= b;
873
+ }
874
+
875
+ inline __host__ __device__ float4 operator*(float4 a, float4 b)
876
+ {
877
+ return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
878
+ }
879
+ inline __host__ __device__ void operator*=(float4 &a, float4 b)
880
+ {
881
+ a.x *= b.x;
882
+ a.y *= b.y;
883
+ a.z *= b.z;
884
+ a.w *= b.w;
885
+ }
886
+ inline __host__ __device__ float4 operator*(float4 a, float b)
887
+ {
888
+ return make_float4(a.x * b, a.y * b, a.z * b, a.w * b);
889
+ }
890
+ inline __host__ __device__ float4 operator*(float b, float4 a)
891
+ {
892
+ return make_float4(b * a.x, b * a.y, b * a.z, b * a.w);
893
+ }
894
+ inline __host__ __device__ void operator*=(float4 &a, float b)
895
+ {
896
+ a.x *= b;
897
+ a.y *= b;
898
+ a.z *= b;
899
+ a.w *= b;
900
+ }
901
+
902
+ inline __host__ __device__ int4 operator*(int4 a, int4 b)
903
+ {
904
+ return make_int4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
905
+ }
906
+ inline __host__ __device__ void operator*=(int4 &a, int4 b)
907
+ {
908
+ a.x *= b.x;
909
+ a.y *= b.y;
910
+ a.z *= b.z;
911
+ a.w *= b.w;
912
+ }
913
+ inline __host__ __device__ int4 operator*(int4 a, int b)
914
+ {
915
+ return make_int4(a.x * b, a.y * b, a.z * b, a.w * b);
916
+ }
917
+ inline __host__ __device__ int4 operator*(int b, int4 a)
918
+ {
919
+ return make_int4(b * a.x, b * a.y, b * a.z, b * a.w);
920
+ }
921
+ inline __host__ __device__ void operator*=(int4 &a, int b)
922
+ {
923
+ a.x *= b;
924
+ a.y *= b;
925
+ a.z *= b;
926
+ a.w *= b;
927
+ }
928
+
929
+ inline __host__ __device__ uint4 operator*(uint4 a, uint4 b)
930
+ {
931
+ return make_uint4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
932
+ }
933
+ inline __host__ __device__ void operator*=(uint4 &a, uint4 b)
934
+ {
935
+ a.x *= b.x;
936
+ a.y *= b.y;
937
+ a.z *= b.z;
938
+ a.w *= b.w;
939
+ }
940
+ inline __host__ __device__ uint4 operator*(uint4 a, uint b)
941
+ {
942
+ return make_uint4(a.x * b, a.y * b, a.z * b, a.w * b);
943
+ }
944
+ inline __host__ __device__ uint4 operator*(uint b, uint4 a)
945
+ {
946
+ return make_uint4(b * a.x, b * a.y, b * a.z, b * a.w);
947
+ }
948
+ inline __host__ __device__ void operator*=(uint4 &a, uint b)
949
+ {
950
+ a.x *= b;
951
+ a.y *= b;
952
+ a.z *= b;
953
+ a.w *= b;
954
+ }
955
+
956
+ ////////////////////////////////////////////////////////////////////////////////
957
+ // divide
958
+ ////////////////////////////////////////////////////////////////////////////////
959
+
960
+ inline __host__ __device__ float2 operator/(float2 a, float2 b)
961
+ {
962
+ return make_float2(a.x / b.x, a.y / b.y);
963
+ }
964
+ inline __host__ __device__ void operator/=(float2 &a, float2 b)
965
+ {
966
+ a.x /= b.x;
967
+ a.y /= b.y;
968
+ }
969
+ inline __host__ __device__ float2 operator/(float2 a, float b)
970
+ {
971
+ return make_float2(a.x / b, a.y / b);
972
+ }
973
+ inline __host__ __device__ void operator/=(float2 &a, float b)
974
+ {
975
+ a.x /= b;
976
+ a.y /= b;
977
+ }
978
+ inline __host__ __device__ float2 operator/(float b, float2 a)
979
+ {
980
+ return make_float2(b / a.x, b / a.y);
981
+ }
982
+
983
+ inline __host__ __device__ float3 operator/(float3 a, float3 b)
984
+ {
985
+ return make_float3(a.x / b.x, a.y / b.y, a.z / b.z);
986
+ }
987
+ inline __host__ __device__ void operator/=(float3 &a, float3 b)
988
+ {
989
+ a.x /= b.x;
990
+ a.y /= b.y;
991
+ a.z /= b.z;
992
+ }
993
+ inline __host__ __device__ float3 operator/(float3 a, float b)
994
+ {
995
+ return make_float3(a.x / b, a.y / b, a.z / b);
996
+ }
997
+ inline __host__ __device__ void operator/=(float3 &a, float b)
998
+ {
999
+ a.x /= b;
1000
+ a.y /= b;
1001
+ a.z /= b;
1002
+ }
1003
+ inline __host__ __device__ float3 operator/(float b, float3 a)
1004
+ {
1005
+ return make_float3(b / a.x, b / a.y, b / a.z);
1006
+ }
1007
+
1008
+ inline __host__ __device__ float4 operator/(float4 a, float4 b)
1009
+ {
1010
+ return make_float4(a.x / b.x, a.y / b.y, a.z / b.z, a.w / b.w);
1011
+ }
1012
+ inline __host__ __device__ void operator/=(float4 &a, float4 b)
1013
+ {
1014
+ a.x /= b.x;
1015
+ a.y /= b.y;
1016
+ a.z /= b.z;
1017
+ a.w /= b.w;
1018
+ }
1019
+ inline __host__ __device__ float4 operator/(float4 a, float b)
1020
+ {
1021
+ return make_float4(a.x / b, a.y / b, a.z / b, a.w / b);
1022
+ }
1023
+ inline __host__ __device__ void operator/=(float4 &a, float b)
1024
+ {
1025
+ a.x /= b;
1026
+ a.y /= b;
1027
+ a.z /= b;
1028
+ a.w /= b;
1029
+ }
1030
+ inline __host__ __device__ float4 operator/(float b, float4 a)
1031
+ {
1032
+ return make_float4(b / a.x, b / a.y, b / a.z, b / a.w);
1033
+ }
1034
+
1035
+ ////////////////////////////////////////////////////////////////////////////////
1036
+ // min
1037
+ ////////////////////////////////////////////////////////////////////////////////
1038
+
1039
+ inline __host__ __device__ float2 fminf(float2 a, float2 b)
1040
+ {
1041
+ return make_float2(fminf(a.x,b.x), fminf(a.y,b.y));
1042
+ }
1043
+ inline __host__ __device__ float3 fminf(float3 a, float3 b)
1044
+ {
1045
+ return make_float3(fminf(a.x,b.x), fminf(a.y,b.y), fminf(a.z,b.z));
1046
+ }
1047
+ inline __host__ __device__ float4 fminf(float4 a, float4 b)
1048
+ {
1049
+ return make_float4(fminf(a.x,b.x), fminf(a.y,b.y), fminf(a.z,b.z), fminf(a.w,b.w));
1050
+ }
1051
+
1052
+ inline __host__ __device__ int2 min(int2 a, int2 b)
1053
+ {
1054
+ return make_int2(min(a.x,b.x), min(a.y,b.y));
1055
+ }
1056
+ inline __host__ __device__ int3 min(int3 a, int3 b)
1057
+ {
1058
+ return make_int3(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z));
1059
+ }
1060
+ inline __host__ __device__ int4 min(int4 a, int4 b)
1061
+ {
1062
+ return make_int4(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z), min(a.w,b.w));
1063
+ }
1064
+
1065
+ inline __host__ __device__ uint2 min(uint2 a, uint2 b)
1066
+ {
1067
+ return make_uint2(min(a.x,b.x), min(a.y,b.y));
1068
+ }
1069
+ inline __host__ __device__ uint3 min(uint3 a, uint3 b)
1070
+ {
1071
+ return make_uint3(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z));
1072
+ }
1073
+ inline __host__ __device__ uint4 min(uint4 a, uint4 b)
1074
+ {
1075
+ return make_uint4(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z), min(a.w,b.w));
1076
+ }
1077
+
1078
+ ////////////////////////////////////////////////////////////////////////////////
1079
+ // max
1080
+ ////////////////////////////////////////////////////////////////////////////////
1081
+
1082
+ inline __host__ __device__ float2 fmaxf(float2 a, float2 b)
1083
+ {
1084
+ return make_float2(fmaxf(a.x,b.x), fmaxf(a.y,b.y));
1085
+ }
1086
+ inline __host__ __device__ float3 fmaxf(float3 a, float3 b)
1087
+ {
1088
+ return make_float3(fmaxf(a.x,b.x), fmaxf(a.y,b.y), fmaxf(a.z,b.z));
1089
+ }
1090
+ inline __host__ __device__ float4 fmaxf(float4 a, float4 b)
1091
+ {
1092
+ return make_float4(fmaxf(a.x,b.x), fmaxf(a.y,b.y), fmaxf(a.z,b.z), fmaxf(a.w,b.w));
1093
+ }
1094
+
1095
+ inline __host__ __device__ int2 max(int2 a, int2 b)
1096
+ {
1097
+ return make_int2(max(a.x,b.x), max(a.y,b.y));
1098
+ }
1099
+ inline __host__ __device__ int3 max(int3 a, int3 b)
1100
+ {
1101
+ return make_int3(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z));
1102
+ }
1103
+ inline __host__ __device__ int4 max(int4 a, int4 b)
1104
+ {
1105
+ return make_int4(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z), max(a.w,b.w));
1106
+ }
1107
+
1108
+ inline __host__ __device__ uint2 max(uint2 a, uint2 b)
1109
+ {
1110
+ return make_uint2(max(a.x,b.x), max(a.y,b.y));
1111
+ }
1112
+ inline __host__ __device__ uint3 max(uint3 a, uint3 b)
1113
+ {
1114
+ return make_uint3(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z));
1115
+ }
1116
+ inline __host__ __device__ uint4 max(uint4 a, uint4 b)
1117
+ {
1118
+ return make_uint4(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z), max(a.w,b.w));
1119
+ }
1120
+
1121
+ ////////////////////////////////////////////////////////////////////////////////
1122
+ // lerp
1123
+ // - linear interpolation between a and b, based on value t in [0, 1] range
1124
+ ////////////////////////////////////////////////////////////////////////////////
1125
+
1126
+ inline __device__ __host__ float lerp(float a, float b, float t)
1127
+ {
1128
+ return a + t*(b-a);
1129
+ }
1130
+ inline __device__ __host__ float2 lerp(float2 a, float2 b, float t)
1131
+ {
1132
+ return a + t*(b-a);
1133
+ }
1134
+ inline __device__ __host__ float3 lerp(float3 a, float3 b, float t)
1135
+ {
1136
+ return a + t*(b-a);
1137
+ }
1138
+ inline __device__ __host__ float4 lerp(float4 a, float4 b, float t)
1139
+ {
1140
+ return a + t*(b-a);
1141
+ }
1142
+
1143
+ ////////////////////////////////////////////////////////////////////////////////
1144
+ // clamp
1145
+ // - clamp the value v to be in the range [a, b]
1146
+ ////////////////////////////////////////////////////////////////////////////////
1147
+
1148
+ inline __device__ __host__ float clamp(float f, float a, float b)
1149
+ {
1150
+ return fmaxf(a, fminf(f, b));
1151
+ }
1152
+ inline __device__ __host__ int clamp(int f, int a, int b)
1153
+ {
1154
+ return max(a, min(f, b));
1155
+ }
1156
+ inline __device__ __host__ uint clamp(uint f, uint a, uint b)
1157
+ {
1158
+ return max(a, min(f, b));
1159
+ }
1160
+
1161
+ inline __device__ __host__ float2 clamp(float2 v, float a, float b)
1162
+ {
1163
+ return make_float2(clamp(v.x, a, b), clamp(v.y, a, b));
1164
+ }
1165
+ inline __device__ __host__ float2 clamp(float2 v, float2 a, float2 b)
1166
+ {
1167
+ return make_float2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y));
1168
+ }
1169
+ inline __device__ __host__ float3 clamp(float3 v, float a, float b)
1170
+ {
1171
+ return make_float3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));
1172
+ }
1173
+ inline __device__ __host__ float3 clamp(float3 v, float3 a, float3 b)
1174
+ {
1175
+ return make_float3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));
1176
+ }
1177
+ inline __device__ __host__ float4 clamp(float4 v, float a, float b)
1178
+ {
1179
+ return make_float4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b));
1180
+ }
1181
+ inline __device__ __host__ float4 clamp(float4 v, float4 a, float4 b)
1182
+ {
1183
+ return make_float4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w));
1184
+ }
1185
+
1186
+ inline __device__ __host__ int2 clamp(int2 v, int a, int b)
1187
+ {
1188
+ return make_int2(clamp(v.x, a, b), clamp(v.y, a, b));
1189
+ }
1190
+ inline __device__ __host__ int2 clamp(int2 v, int2 a, int2 b)
1191
+ {
1192
+ return make_int2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y));
1193
+ }
1194
+ inline __device__ __host__ int3 clamp(int3 v, int a, int b)
1195
+ {
1196
+ return make_int3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));
1197
+ }
1198
+ inline __device__ __host__ int3 clamp(int3 v, int3 a, int3 b)
1199
+ {
1200
+ return make_int3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));
1201
+ }
1202
+ inline __device__ __host__ int4 clamp(int4 v, int a, int b)
1203
+ {
1204
+ return make_int4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b));
1205
+ }
1206
+ inline __device__ __host__ int4 clamp(int4 v, int4 a, int4 b)
1207
+ {
1208
+ return make_int4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w));
1209
+ }
1210
+
1211
+ inline __device__ __host__ uint2 clamp(uint2 v, uint a, uint b)
1212
+ {
1213
+ return make_uint2(clamp(v.x, a, b), clamp(v.y, a, b));
1214
+ }
1215
+ inline __device__ __host__ uint2 clamp(uint2 v, uint2 a, uint2 b)
1216
+ {
1217
+ return make_uint2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y));
1218
+ }
1219
+ inline __device__ __host__ uint3 clamp(uint3 v, uint a, uint b)
1220
+ {
1221
+ return make_uint3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));
1222
+ }
1223
+ inline __device__ __host__ uint3 clamp(uint3 v, uint3 a, uint3 b)
1224
+ {
1225
+ return make_uint3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));
1226
+ }
1227
+ inline __device__ __host__ uint4 clamp(uint4 v, uint a, uint b)
1228
+ {
1229
+ return make_uint4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b));
1230
+ }
1231
+ inline __device__ __host__ uint4 clamp(uint4 v, uint4 a, uint4 b)
1232
+ {
1233
+ return make_uint4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w));
1234
+ }
1235
+
1236
+ ////////////////////////////////////////////////////////////////////////////////
1237
+ // dot product
1238
+ ////////////////////////////////////////////////////////////////////////////////
1239
+
1240
+ inline __host__ __device__ float dot(float2 a, float2 b)
1241
+ {
1242
+ return a.x * b.x + a.y * b.y;
1243
+ }
1244
+ inline __host__ __device__ float dot(float3 a, float3 b)
1245
+ {
1246
+ return a.x * b.x + a.y * b.y + a.z * b.z;
1247
+ }
1248
+ inline __host__ __device__ float dot(float4 a, float4 b)
1249
+ {
1250
+ return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
1251
+ }
1252
+
1253
+ inline __host__ __device__ int dot(int2 a, int2 b)
1254
+ {
1255
+ return a.x * b.x + a.y * b.y;
1256
+ }
1257
+ inline __host__ __device__ int dot(int3 a, int3 b)
1258
+ {
1259
+ return a.x * b.x + a.y * b.y + a.z * b.z;
1260
+ }
1261
+ inline __host__ __device__ int dot(int4 a, int4 b)
1262
+ {
1263
+ return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
1264
+ }
1265
+
1266
+ inline __host__ __device__ uint dot(uint2 a, uint2 b)
1267
+ {
1268
+ return a.x * b.x + a.y * b.y;
1269
+ }
1270
+ inline __host__ __device__ uint dot(uint3 a, uint3 b)
1271
+ {
1272
+ return a.x * b.x + a.y * b.y + a.z * b.z;
1273
+ }
1274
+ inline __host__ __device__ uint dot(uint4 a, uint4 b)
1275
+ {
1276
+ return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
1277
+ }
1278
+
1279
+ ////////////////////////////////////////////////////////////////////////////////
1280
+ // length
1281
+ ////////////////////////////////////////////////////////////////////////////////
1282
+
1283
+ inline __host__ __device__ float length(float2 v)
1284
+ {
1285
+ return sqrtf(dot(v, v));
1286
+ }
1287
+ inline __host__ __device__ float length(float3 v)
1288
+ {
1289
+ return sqrtf(dot(v, v));
1290
+ }
1291
+ inline __host__ __device__ float length(float4 v)
1292
+ {
1293
+ return sqrtf(dot(v, v));
1294
+ }
1295
+
1296
+ ////////////////////////////////////////////////////////////////////////////////
1297
+ // normalize
1298
+ ////////////////////////////////////////////////////////////////////////////////
1299
+
1300
+ inline __host__ __device__ float2 normalize(float2 v)
1301
+ {
1302
+ float invLen = rsqrtf(dot(v, v));
1303
+ return v * invLen;
1304
+ }
1305
+ inline __host__ __device__ float3 normalize(float3 v)
1306
+ {
1307
+ float invLen = rsqrtf(dot(v, v));
1308
+ return v * invLen;
1309
+ }
1310
+ inline __host__ __device__ float4 normalize(float4 v)
1311
+ {
1312
+ float invLen = rsqrtf(dot(v, v));
1313
+ return v * invLen;
1314
+ }
1315
+
1316
+ ////////////////////////////////////////////////////////////////////////////////
1317
+ // floor
1318
+ ////////////////////////////////////////////////////////////////////////////////
1319
+
1320
+ inline __host__ __device__ float2 floorf(float2 v)
1321
+ {
1322
+ return make_float2(floorf(v.x), floorf(v.y));
1323
+ }
1324
+ inline __host__ __device__ float3 floorf(float3 v)
1325
+ {
1326
+ return make_float3(floorf(v.x), floorf(v.y), floorf(v.z));
1327
+ }
1328
+ inline __host__ __device__ float4 floorf(float4 v)
1329
+ {
1330
+ return make_float4(floorf(v.x), floorf(v.y), floorf(v.z), floorf(v.w));
1331
+ }
1332
+
1333
+ ////////////////////////////////////////////////////////////////////////////////
1334
+ // frac - returns the fractional portion of a scalar or each vector component
1335
+ ////////////////////////////////////////////////////////////////////////////////
1336
+
1337
+ inline __host__ __device__ float fracf(float v)
1338
+ {
1339
+ return v - floorf(v);
1340
+ }
1341
+ inline __host__ __device__ float2 fracf(float2 v)
1342
+ {
1343
+ return make_float2(fracf(v.x), fracf(v.y));
1344
+ }
1345
+ inline __host__ __device__ float3 fracf(float3 v)
1346
+ {
1347
+ return make_float3(fracf(v.x), fracf(v.y), fracf(v.z));
1348
+ }
1349
+ inline __host__ __device__ float4 fracf(float4 v)
1350
+ {
1351
+ return make_float4(fracf(v.x), fracf(v.y), fracf(v.z), fracf(v.w));
1352
+ }
1353
+
1354
+ ////////////////////////////////////////////////////////////////////////////////
1355
+ // fmod
1356
+ ////////////////////////////////////////////////////////////////////////////////
1357
+
1358
+ inline __host__ __device__ float2 fmodf(float2 a, float2 b)
1359
+ {
1360
+ return make_float2(fmodf(a.x, b.x), fmodf(a.y, b.y));
1361
+ }
1362
+ inline __host__ __device__ float3 fmodf(float3 a, float3 b)
1363
+ {
1364
+ return make_float3(fmodf(a.x, b.x), fmodf(a.y, b.y), fmodf(a.z, b.z));
1365
+ }
1366
+ inline __host__ __device__ float4 fmodf(float4 a, float4 b)
1367
+ {
1368
+ return make_float4(fmodf(a.x, b.x), fmodf(a.y, b.y), fmodf(a.z, b.z), fmodf(a.w, b.w));
1369
+ }
1370
+
1371
+ ////////////////////////////////////////////////////////////////////////////////
1372
+ // absolute value
1373
+ ////////////////////////////////////////////////////////////////////////////////
1374
+
1375
+ inline __host__ __device__ float2 fabs(float2 v)
1376
+ {
1377
+ return make_float2(fabs(v.x), fabs(v.y));
1378
+ }
1379
+ inline __host__ __device__ float3 fabs(float3 v)
1380
+ {
1381
+ return make_float3(fabs(v.x), fabs(v.y), fabs(v.z));
1382
+ }
1383
+ inline __host__ __device__ float4 fabs(float4 v)
1384
+ {
1385
+ return make_float4(fabs(v.x), fabs(v.y), fabs(v.z), fabs(v.w));
1386
+ }
1387
+
1388
+ inline __host__ __device__ int2 abs(int2 v)
1389
+ {
1390
+ return make_int2(abs(v.x), abs(v.y));
1391
+ }
1392
+ inline __host__ __device__ int3 abs(int3 v)
1393
+ {
1394
+ return make_int3(abs(v.x), abs(v.y), abs(v.z));
1395
+ }
1396
+ inline __host__ __device__ int4 abs(int4 v)
1397
+ {
1398
+ return make_int4(abs(v.x), abs(v.y), abs(v.z), abs(v.w));
1399
+ }
1400
+
1401
+ ////////////////////////////////////////////////////////////////////////////////
1402
+ // reflect
1403
+ // - returns reflection of incident ray I around surface normal N
1404
+ // - N should be normalized, reflected vector's length is equal to length of I
1405
+ ////////////////////////////////////////////////////////////////////////////////
1406
+
1407
+ inline __host__ __device__ float3 reflect(float3 i, float3 n)
1408
+ {
1409
+ return i - 2.0f * n * dot(n,i);
1410
+ }
1411
+
1412
+ ////////////////////////////////////////////////////////////////////////////////
1413
+ // cross product
1414
+ ////////////////////////////////////////////////////////////////////////////////
1415
+
1416
+ inline __host__ __device__ float3 cross(float3 a, float3 b)
1417
+ {
1418
+ return make_float3(a.y*b.z - a.z*b.y, a.z*b.x - a.x*b.z, a.x*b.y - a.y*b.x);
1419
+ }
1420
+
1421
+ ////////////////////////////////////////////////////////////////////////////////
1422
+ // smoothstep
1423
+ // - returns 0 if x < a
1424
+ // - returns 1 if x > b
1425
+ // - otherwise returns smooth interpolation between 0 and 1 based on x
1426
+ ////////////////////////////////////////////////////////////////////////////////
1427
+
1428
+ inline __device__ __host__ float smoothstep(float a, float b, float x)
1429
+ {
1430
+ float y = clamp((x - a) / (b - a), 0.0f, 1.0f);
1431
+ return (y*y*(3.0f - (2.0f*y)));
1432
+ }
1433
+ inline __device__ __host__ float2 smoothstep(float2 a, float2 b, float2 x)
1434
+ {
1435
+ float2 y = clamp((x - a) / (b - a), 0.0f, 1.0f);
1436
+ return (y*y*(make_float2(3.0f) - (make_float2(2.0f)*y)));
1437
+ }
1438
+ inline __device__ __host__ float3 smoothstep(float3 a, float3 b, float3 x)
1439
+ {
1440
+ float3 y = clamp((x - a) / (b - a), 0.0f, 1.0f);
1441
+ return (y*y*(make_float3(3.0f) - (make_float3(2.0f)*y)));
1442
+ }
1443
+ inline __device__ __host__ float4 smoothstep(float4 a, float4 b, float4 x)
1444
+ {
1445
+ float4 y = clamp((x - a) / (b - a), 0.0f, 1.0f);
1446
+ return (y*y*(make_float4(3.0f) - (make_float4(2.0f)*y)));
1447
+ }
1448
+
1449
+ #endif
utils/io_utils.py ADDED
@@ -0,0 +1,473 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import json, os, sys
3
+ import os.path as osp
4
+ from typing import List, Union, Tuple, Dict
5
+ from pathlib import Path
6
+ import cv2
7
+ import numpy as np
8
+ from imageio import imread, imwrite
9
+ import pickle
10
+ import pycocotools.mask as maskUtils
11
+ from einops import rearrange
12
+ from tqdm import tqdm
13
+ from PIL import Image
14
+ import io
15
+ import requests
16
+ import traceback
17
+ import base64
18
+ import time
19
+
20
+
21
+ NP_BOOL_TYPES = (np.bool_, np.bool8)
22
+ NP_FLOAT_TYPES = (np.float_, np.float16, np.float32, np.float64)
23
+ NP_INT_TYPES = (np.int_, np.int8, np.int16, np.int32, np.int64, np.uint, np.uint8, np.uint16, np.uint32, np.uint64)
24
+
25
+ class NumpyEncoder(json.JSONEncoder):
26
+ def default(self, obj):
27
+ if isinstance(obj, np.ndarray):
28
+ return obj.tolist()
29
+ elif isinstance(obj, np.ScalarType):
30
+ if isinstance(obj, NP_BOOL_TYPES):
31
+ return bool(obj)
32
+ elif isinstance(obj, NP_FLOAT_TYPES):
33
+ return float(obj)
34
+ elif isinstance(obj, NP_INT_TYPES):
35
+ return int(obj)
36
+ return json.JSONEncoder.default(self, obj)
37
+
38
+
39
+ def json2dict(json_path: str):
40
+ with open(json_path, 'r', encoding='utf8') as f:
41
+ metadata = json.loads(f.read())
42
+ return metadata
43
+
44
+
45
+ def dict2json(adict: dict, json_path: str):
46
+ with open(json_path, "w", encoding="utf-8") as f:
47
+ f.write(json.dumps(adict, ensure_ascii=False, cls=NumpyEncoder))
48
+
49
+
50
+ def dict2pickle(dumped_path: str, tgt_dict: dict):
51
+ with open(dumped_path, "wb") as f:
52
+ pickle.dump(tgt_dict, f, protocol=pickle.HIGHEST_PROTOCOL)
53
+
54
+
55
+ def pickle2dict(pkl_path: str) -> Dict:
56
+ with open(pkl_path, "rb") as f:
57
+ dumped_data = pickle.load(f)
58
+ return dumped_data
59
+
60
+ def get_all_dirs(root_p: str) -> List[str]:
61
+ alldir = os.listdir(root_p)
62
+ dirlist = []
63
+ for dirp in alldir:
64
+ dirp = osp.join(root_p, dirp)
65
+ if osp.isdir(dirp):
66
+ dirlist.append(dirp)
67
+ return dirlist
68
+
69
+
70
+ def read_filelist(filelistp: str):
71
+ with open(filelistp, 'r', encoding='utf8') as f:
72
+ lines = f.readlines()
73
+ if len(lines) > 0 and lines[-1].strip() == '':
74
+ lines = lines[:-1]
75
+ return lines
76
+
77
+
78
+ VIDEO_EXTS = {'.flv', '.mp4', '.mkv', '.ts', '.mov', 'mpeg'}
79
+ def get_all_videos(video_dir: str, video_exts=VIDEO_EXTS, abs_path=False) -> List[str]:
80
+ filelist = os.listdir(video_dir)
81
+ vlist = []
82
+ for f in filelist:
83
+ if Path(f).suffix in video_exts:
84
+ if abs_path:
85
+ vlist.append(osp.join(video_dir, f))
86
+ else:
87
+ vlist.append(f)
88
+ return vlist
89
+
90
+
91
+ IMG_EXT = {'.bmp', '.jpg', '.png', '.jpeg'}
92
+ def find_all_imgs(img_dir, abs_path=False):
93
+ imglist = []
94
+ dir_list = os.listdir(img_dir)
95
+ for filename in dir_list:
96
+ file_suffix = Path(filename).suffix
97
+ if file_suffix.lower() not in IMG_EXT:
98
+ continue
99
+ if abs_path:
100
+ imglist.append(osp.join(img_dir, filename))
101
+ else:
102
+ imglist.append(filename)
103
+ return imglist
104
+
105
+
106
+ def find_all_files_recursive(tgt_dir: Union[List, str], ext, exclude_dirs={}):
107
+ if isinstance(tgt_dir, str):
108
+ tgt_dir = [tgt_dir]
109
+
110
+ filelst = []
111
+ for d in tgt_dir:
112
+ for root, _, files in os.walk(d):
113
+ if osp.basename(root) in exclude_dirs:
114
+ continue
115
+ for f in files:
116
+ if Path(f).suffix.lower() in ext:
117
+ filelst.append(osp.join(root, f))
118
+
119
+ return filelst
120
+
121
+
122
+ def danbooruid2relpath(id_str: str, file_ext='.jpg'):
123
+ if not isinstance(id_str, str):
124
+ id_str = str(id_str)
125
+ return id_str[-3:].zfill(4) + '/' + id_str + file_ext
126
+
127
+
128
+ def get_template_histvq(template: np.ndarray) -> Tuple[List[np.ndarray]]:
129
+ len_shape = len(template.shape)
130
+ num_c = 3
131
+ mask = None
132
+ if len_shape == 2:
133
+ num_c = 1
134
+ elif len_shape == 3 and template.shape[-1] == 4:
135
+ mask = np.where(template[..., -1])
136
+ template = template[..., :num_c][mask]
137
+
138
+ values, quantiles = [], []
139
+ for ii in range(num_c):
140
+ v, c = np.unique(template[..., ii].ravel(), return_counts=True)
141
+ q = np.cumsum(c).astype(np.float64)
142
+ if len(q) < 1:
143
+ return None, None
144
+ q /= q[-1]
145
+ values.append(v)
146
+ quantiles.append(q)
147
+ return values, quantiles
148
+
149
+
150
+ def inplace_hist_matching(img: np.ndarray, tv: List[np.ndarray], tq: List[np.ndarray]) -> None:
151
+ len_shape = len(img.shape)
152
+ num_c = 3
153
+ mask = None
154
+
155
+ tgtimg = img
156
+ if len_shape == 2:
157
+ num_c = 1
158
+ elif len_shape == 3 and img.shape[-1] == 4:
159
+ mask = np.where(img[..., -1])
160
+ tgtimg = img[..., :num_c][mask]
161
+
162
+ im_h, im_w = img.shape[:2]
163
+ oldtype = img.dtype
164
+ for ii in range(num_c):
165
+ _, bin_idx, s_counts = np.unique(tgtimg[..., ii].ravel(), return_inverse=True,
166
+ return_counts=True)
167
+ s_quantiles = np.cumsum(s_counts).astype(np.float64)
168
+ if len(s_quantiles) == 0:
169
+ return
170
+ s_quantiles /= s_quantiles[-1]
171
+ interp_t_values = np.interp(s_quantiles, tq[ii], tv[ii]).astype(oldtype)
172
+ if mask is not None:
173
+ img[..., ii][mask] = interp_t_values[bin_idx]
174
+ else:
175
+ img[..., ii] = interp_t_values[bin_idx].reshape((im_h, im_w))
176
+ # try:
177
+ # img[..., ii] = interp_t_values[bin_idx].reshape((im_h, im_w))
178
+ # except:
179
+ # LOGGER.error('##################### sth goes wrong')
180
+ # cv2.imshow('img', img)
181
+ # cv2.waitKey(0)
182
+
183
+
184
+ def fgbg_hist_matching(fg_list: List, bg: np.ndarray, min_tq_num=128):
185
+ btv, btq = get_template_histvq(bg)
186
+ ftv, ftq = get_template_histvq(fg_list[0]['image'])
187
+ num_fg = len(fg_list)
188
+ idx_matched = -1
189
+ if num_fg > 1:
190
+ _ftv, _ftq = get_template_histvq(fg_list[0]['image'])
191
+ if _ftq is not None and ftq is not None:
192
+ if len(_ftq[0]) > len(ftq[0]):
193
+ idx_matched = num_fg - 1
194
+ ftv, ftq = _ftv, _ftq
195
+ else:
196
+ idx_matched = 0
197
+
198
+ if btq is not None and ftq is not None:
199
+ if len(btq[0]) > len(ftq[0]):
200
+ tv, tq = btv, btq
201
+ idx_matched = -1
202
+ else:
203
+ tv, tq = ftv, ftq
204
+ if len(tq[0]) > min_tq_num:
205
+ inplace_hist_matching(bg, tv, tq)
206
+
207
+ if len(tq[0]) > min_tq_num:
208
+ for ii, fg_dict in enumerate(fg_list):
209
+ fg = fg_dict['image']
210
+ if ii != idx_matched and len(tq[0]) > min_tq_num:
211
+ inplace_hist_matching(fg, tv, tq)
212
+
213
+
214
+ def imread_nogrey_rgb(imp: str) -> np.ndarray:
215
+ img: np.ndarray = imread(imp)
216
+ c = 1
217
+ if len(img.shape) == 3:
218
+ c = img.shape[-1]
219
+ if c == 1:
220
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
221
+ if c == 4:
222
+ img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
223
+ return img
224
+
225
+
226
+ def square_pad_resize(img: np.ndarray, tgt_size: int, pad_value: Tuple = (114, 114, 114)):
227
+ h, w = img.shape[:2]
228
+ pad_h, pad_w = 0, 0
229
+
230
+ # make square image
231
+ if w < h:
232
+ pad_w = h - w
233
+ w += pad_w
234
+ elif h < w:
235
+ pad_h = w - h
236
+ h += pad_h
237
+
238
+ pad_size = tgt_size - h
239
+ if pad_size > 0:
240
+ pad_h += pad_size
241
+ pad_w += pad_size
242
+
243
+ if pad_h > 0 or pad_w > 0:
244
+ img = cv2.copyMakeBorder(img, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=pad_value)
245
+
246
+ down_scale_ratio = tgt_size / img.shape[0]
247
+ assert down_scale_ratio <= 1
248
+ if down_scale_ratio < 1:
249
+ img = cv2.resize(img, (tgt_size, tgt_size), interpolation=cv2.INTER_AREA)
250
+
251
+ return img, down_scale_ratio, pad_h, pad_w
252
+
253
+
254
+ def scaledown_maxsize(img: np.ndarray, max_size: int, divisior: int = None):
255
+
256
+ im_h, im_w = img.shape[:2]
257
+ ori_h, ori_w = img.shape[:2]
258
+ resize_ratio = max_size / max(im_h, im_w)
259
+ if resize_ratio < 1:
260
+ if im_h > im_w:
261
+ im_h = max_size
262
+ im_w = max(1, int(round(im_w * resize_ratio)))
263
+
264
+ else:
265
+ im_w = max_size
266
+ im_h = max(1, int(round(im_h * resize_ratio)))
267
+ if divisior is not None:
268
+ im_w = int(np.ceil(im_w / divisior) * divisior)
269
+ im_h = int(np.ceil(im_h / divisior) * divisior)
270
+
271
+ if im_w != ori_w or im_h != ori_h:
272
+ img = cv2.resize(img, (im_w, im_h), interpolation=cv2.INTER_LINEAR)
273
+
274
+ return img
275
+
276
+
277
+ def resize_pad(img: np.ndarray, tgt_size: int, pad_value: Tuple = (0, 0, 0)):
278
+ # downscale to tgt_size and pad to square
279
+ img = scaledown_maxsize(img, tgt_size)
280
+ padl, padr, padt, padb = 0, 0, 0, 0
281
+ h, w = img.shape[:2]
282
+ # padt = (tgt_size - h) // 2
283
+ # padb = tgt_size - h - padt
284
+ # padl = (tgt_size - w) // 2
285
+ # padr = tgt_size - w - padl
286
+ padb = tgt_size - h
287
+ padr = tgt_size - w
288
+
289
+ if padt + padb + padl + padr > 0:
290
+ img = cv2.copyMakeBorder(img, padt, padb, padl, padr, cv2.BORDER_CONSTANT, value=pad_value)
291
+
292
+ return img, (padt, padb, padl, padr)
293
+
294
+
295
+ def resize_pad2divisior(img: np.ndarray, tgt_size: int, divisior: int = 64, pad_value: Tuple = (0, 0, 0)):
296
+ img = scaledown_maxsize(img, tgt_size)
297
+ img, (pad_h, pad_w) = pad2divisior(img, divisior, pad_value)
298
+ return img, (pad_h, pad_w)
299
+
300
+
301
+ def img2grey(img: Union[np.ndarray, str], is_rgb: bool = False) -> np.ndarray:
302
+ if isinstance(img, np.ndarray):
303
+ if len(img.shape) == 3:
304
+ if img.shape[-1] != 1:
305
+ if is_rgb:
306
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
307
+ else:
308
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
309
+ else:
310
+ img = img[..., 0]
311
+ return img
312
+ elif isinstance(img, str):
313
+ return cv2.imread(img, cv2.IMREAD_GRAYSCALE)
314
+ else:
315
+ raise NotImplementedError
316
+
317
+
318
+ def pad2divisior(img: np.ndarray, divisior: int, value = (0, 0, 0)) -> np.ndarray:
319
+ im_h, im_w = img.shape[:2]
320
+ pad_h = int(np.ceil(im_h / divisior)) * divisior - im_h
321
+ pad_w = int(np.ceil(im_w / divisior)) * divisior - im_w
322
+ if pad_h != 0 or pad_w != 0:
323
+ img = cv2.copyMakeBorder(img, 0, pad_h, 0, pad_w, value=value, borderType=cv2.BORDER_CONSTANT)
324
+ return img, (pad_h, pad_w)
325
+
326
+
327
+ def mask2rle(mask: np.ndarray, decode_for_json: bool = True) -> Dict:
328
+ mask_rle = maskUtils.encode(np.array(
329
+ mask[..., np.newaxis] > 0, order='F',
330
+ dtype='uint8'))[0]
331
+ if decode_for_json:
332
+ mask_rle['counts'] = mask_rle['counts'].decode()
333
+ return mask_rle
334
+
335
+
336
+ def bbox2xyxy(box) -> Tuple[int]:
337
+ x1, y1 = box[0], box[1]
338
+ return x1, y1, x1+box[2], y1+box[3]
339
+
340
+
341
+ def bbox_overlap_area(abox, boxb) -> int:
342
+ ax1, ay1, ax2, ay2 = bbox2xyxy(abox)
343
+ bx1, by1, bx2, by2 = bbox2xyxy(boxb)
344
+
345
+ ix = min(ax2, bx2) - max(ax1, bx1)
346
+ iy = min(ay2, by2) - max(ay1, by1)
347
+
348
+ if ix > 0 and iy > 0:
349
+ return ix * iy
350
+ else:
351
+ return 0
352
+
353
+
354
+ def bbox_overlap_xy(abox, boxb) -> Tuple[int]:
355
+ ax1, ay1, ax2, ay2 = bbox2xyxy(abox)
356
+ bx1, by1, bx2, by2 = bbox2xyxy(boxb)
357
+
358
+ ix = min(ax2, bx2) - max(ax1, bx1)
359
+ iy = min(ay2, by2) - max(ay1, by1)
360
+
361
+ return ix, iy
362
+
363
+
364
+ def xyxy_overlap_area(axyxy, bxyxy) -> int:
365
+ ax1, ay1, ax2, ay2 = axyxy
366
+ bx1, by1, bx2, by2 = bxyxy
367
+
368
+ ix = min(ax2, bx2) - max(ax1, bx1)
369
+ iy = min(ay2, by2) - max(ay1, by1)
370
+
371
+ if ix > 0 and iy > 0:
372
+ return ix * iy
373
+ else:
374
+ return 0
375
+
376
+
377
+ DIRNAME2TAG = {'rezero': 're:zero'}
378
+ def dirname2charactername(dirname, start=6):
379
+ cname = dirname[start:]
380
+ for k, v in DIRNAME2TAG.items():
381
+ cname = cname.replace(k, v)
382
+ return cname
383
+
384
+
385
+ def imglist2grid(imglist: np.ndarray, grid_size: int = 384, col=None) -> np.ndarray:
386
+ sqimlist = []
387
+ for img in imglist:
388
+ sqimlist.append(square_pad_resize(img, grid_size)[0])
389
+
390
+ nimg = len(imglist)
391
+ if nimg == 0:
392
+ return None
393
+ padn = 0
394
+ if col is None:
395
+ if nimg > 5:
396
+ row = int(np.round(np.sqrt(nimg)))
397
+ col = int(np.ceil(nimg / row))
398
+ else:
399
+ col = nimg
400
+
401
+ padn = int(np.ceil(nimg / col) * col) - nimg
402
+ if padn != 0:
403
+ padimg = np.zeros_like(sqimlist[0])
404
+ for _ in range(padn):
405
+ sqimlist.append(padimg)
406
+
407
+ return rearrange(sqimlist, '(row col) h w c -> (row h) (col w) c', col=col)
408
+
409
+ def write_jsonlines(filep: str, dict_lst: List[str], progress_bar: bool = True):
410
+ with open(filep, 'w') as out:
411
+ if progress_bar:
412
+ lst = tqdm(dict_lst)
413
+ else:
414
+ lst = dict_lst
415
+ for ddict in lst:
416
+ jout = json.dumps(ddict) + '\n'
417
+ out.write(jout)
418
+
419
+ def read_jsonlines(filep: str):
420
+ with open(filep, 'r', encoding='utf8') as f:
421
+ result = [json.loads(jline) for jline in f.read().splitlines()]
422
+ return result
423
+
424
+
425
+ def _b64encode(x: bytes) -> str:
426
+ return base64.b64encode(x).decode("utf-8")
427
+
428
+
429
+ def img2b64(img):
430
+ """
431
+ Convert a PIL image to a base64-encoded string.
432
+ """
433
+ if isinstance(img, np.ndarray):
434
+ img = Image.fromarray(img)
435
+ buffered = io.BytesIO()
436
+ img.save(buffered, format='PNG')
437
+ return _b64encode(buffered.getvalue())
438
+
439
+
440
+ def save_encoded_image(b64_image: str, output_path: str):
441
+ with open(output_path, "wb") as image_file:
442
+ image_file.write(base64.b64decode(b64_image))
443
+
444
+ def submit_request(url, data, exist_on_exception=True, auth=None, wait_time = 30):
445
+ response = None
446
+ try:
447
+ while True:
448
+ try:
449
+ response = requests.post(url, data=data, auth=auth)
450
+ response.raise_for_status()
451
+ break
452
+ except Exception as e:
453
+ if wait_time > 0:
454
+ print(traceback.format_exc(), file=sys.stderr)
455
+ print(f'sleep {wait_time} sec...')
456
+ time.sleep(wait_time)
457
+ continue
458
+ else:
459
+ raise e
460
+ except Exception as e:
461
+ print(traceback.format_exc(), file=sys.stderr)
462
+ if response is not None:
463
+ print('response content: ' + response.text)
464
+ if exist_on_exception:
465
+ exit()
466
+ return response
467
+
468
+
469
+ # def resize_image(input_image, resolution):
470
+ # H, W = input_image.shape[:2]
471
+ # k = float(min(resolution)) / min(H, W)
472
+ # img = cv2.resize(input_image, resolution, interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
473
+ # return img
utils/logger.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os.path as osp
3
+ from termcolor import colored
4
+
5
+ def set_logging(name=None, verbose=True):
6
+ for handler in logging.root.handlers[:]:
7
+ logging.root.removeHandler(handler)
8
+ # Sets level and returns logger
9
+ # rank = int(os.getenv('RANK', -1)) # rank in world for Multi-GPU trainings
10
+ fmt = (
11
+ # colored("[%(name)s]", "magenta", attrs=["bold"])
12
+ colored("[%(asctime)s]", "blue")
13
+ + colored("%(levelname)s:", "green")
14
+ + colored("%(message)s", "white")
15
+ )
16
+ logging.basicConfig(format=fmt, level=logging.INFO if verbose else logging.WARNING)
17
+ return logging.getLogger(name)
18
+
19
+ LOGGER = set_logging(__name__) # define globally (used in train.py, val.py, detect.py, etc.)
20
+
utils/mmdet_custom_hooks.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from mmengine.fileio import FileClient
2
+ from mmengine.dist import master_only
3
+ from einops import rearrange
4
+ import torch
5
+ import mmcv
6
+ import numpy as np
7
+ import os.path as osp
8
+ import cv2
9
+ from typing import Optional, Sequence
10
+ import torch.nn as nn
11
+ from mmdet.apis import inference_detector
12
+ from mmcv.transforms import Compose
13
+ from mmdet.engine import DetVisualizationHook
14
+ from mmdet.registry import HOOKS
15
+ from mmdet.structures import DetDataSample
16
+
17
+ from utils.io_utils import find_all_imgs, square_pad_resize, imglist2grid
18
+
19
+ def inference_detector(
20
+ model: nn.Module,
21
+ imgs,
22
+ test_pipeline
23
+ ):
24
+
25
+ if isinstance(imgs, (list, tuple)):
26
+ is_batch = True
27
+ else:
28
+ imgs = [imgs]
29
+ is_batch = False
30
+
31
+ if len(imgs) == 0:
32
+ return []
33
+
34
+ test_pipeline = test_pipeline.copy()
35
+ if isinstance(imgs[0], np.ndarray):
36
+ # Calling this method across libraries will result
37
+ # in module unregistered error if not prefixed with mmdet.
38
+ test_pipeline[0].type = 'mmdet.LoadImageFromNDArray'
39
+
40
+ test_pipeline = Compose(test_pipeline)
41
+
42
+ result_list = []
43
+ for img in imgs:
44
+ # prepare data
45
+ if isinstance(img, np.ndarray):
46
+ # TODO: remove img_id.
47
+ data_ = dict(img=img, img_id=0)
48
+ else:
49
+ # TODO: remove img_id.
50
+ data_ = dict(img_path=img, img_id=0)
51
+ # build the data pipeline
52
+ data_ = test_pipeline(data_)
53
+
54
+ data_['inputs'] = [data_['inputs']]
55
+ data_['data_samples'] = [data_['data_samples']]
56
+
57
+ # forward the model
58
+ with torch.no_grad():
59
+ results = model.test_step(data_)[0]
60
+
61
+ result_list.append(results)
62
+
63
+ if not is_batch:
64
+ return result_list[0]
65
+ else:
66
+ return result_list
67
+
68
+
69
+ @HOOKS.register_module()
70
+ class InstanceSegVisualizationHook(DetVisualizationHook):
71
+
72
+ def __init__(self, visualize_samples: str = '',
73
+ read_rgb: bool = False,
74
+ draw: bool = False,
75
+ interval: int = 50,
76
+ score_thr: float = 0.3,
77
+ show: bool = False,
78
+ wait_time: float = 0.,
79
+ test_out_dir: Optional[str] = None,
80
+ file_client_args: dict = dict(backend='disk')):
81
+ super().__init__(draw, interval, score_thr, show, wait_time, test_out_dir, file_client_args)
82
+ self.vis_samples = []
83
+
84
+ if osp.exists(visualize_samples):
85
+ self.channel_order = channel_order = 'rgb' if read_rgb else 'bgr'
86
+ samples = find_all_imgs(visualize_samples, abs_path=True)
87
+ for imgp in samples:
88
+ img = mmcv.imread(imgp, channel_order=channel_order)
89
+ img, _, _, _ = square_pad_resize(img, 640)
90
+ self.vis_samples.append(img)
91
+
92
+ def before_val(self, runner) -> None:
93
+ total_curr_iter = runner.iter
94
+ self._visualize_data(total_curr_iter, runner)
95
+ return super().before_val(runner)
96
+
97
+ # def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
98
+ # outputs: Sequence[DetDataSample]) -> None:
99
+ # """Run after every ``self.interval`` validation iterations.
100
+
101
+ # Args:
102
+ # runner (:obj:`Runner`): The runner of the validation process.
103
+ # batch_idx (int): The index of the current batch in the val loop.
104
+ # data_batch (dict): Data from dataloader.
105
+ # outputs (Sequence[:obj:`DetDataSample`]]): A batch of data samples
106
+ # that contain annotations and predictions.
107
+ # """
108
+ # # if self.draw is False:
109
+ # # return
110
+
111
+ # if self.file_client is None:
112
+ # self.file_client = FileClient(**self.file_client_args)
113
+
114
+
115
+ # # There is no guarantee that the same batch of images
116
+ # # is visualized for each evaluation.
117
+ # total_curr_iter = runner.iter + batch_idx
118
+
119
+ # # # Visualize only the first data
120
+ # # img_path = outputs[0].img_path
121
+ # # img_bytes = self.file_client.get(img_path)
122
+ # # img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
123
+ # if total_curr_iter % self.interval == 0 and self.vis_samples:
124
+ # self._visualize_data(total_curr_iter, runner)
125
+
126
+
127
+ @master_only
128
+ def _visualize_data(self, total_curr_iter, runner):
129
+
130
+ tgt_size = 384
131
+
132
+ runner.model.eval()
133
+ outputs = inference_detector(runner.model, self.vis_samples, test_pipeline=runner.cfg.test_pipeline)
134
+ vis_results = []
135
+ for img, output in zip(self.vis_samples, outputs):
136
+ vis_img = self.add_datasample(
137
+ 'val_img',
138
+ img,
139
+ data_sample=output,
140
+ show=self.show,
141
+ wait_time=self.wait_time,
142
+ pred_score_thr=self.score_thr,
143
+ draw_gt=False,
144
+ step=total_curr_iter)
145
+ vis_results.append(cv2.resize(vis_img, (tgt_size, tgt_size), interpolation=cv2.INTER_AREA))
146
+
147
+ drawn_img = imglist2grid(vis_results, tgt_size)
148
+ if drawn_img is None:
149
+ return
150
+ drawn_img = cv2.cvtColor(drawn_img, cv2.COLOR_BGR2RGB)
151
+ visualizer = self._visualizer
152
+ visualizer.set_image(drawn_img)
153
+ visualizer.add_image('val_img', drawn_img, total_curr_iter)
154
+
155
+
156
+ @master_only
157
+ def add_datasample(
158
+ self,
159
+ name: str,
160
+ image: np.ndarray,
161
+ data_sample: Optional['DetDataSample'] = None,
162
+ draw_gt: bool = True,
163
+ draw_pred: bool = True,
164
+ show: bool = False,
165
+ wait_time: float = 0,
166
+ # TODO: Supported in mmengine's Viusalizer.
167
+ out_file: Optional[str] = None,
168
+ pred_score_thr: float = 0.3,
169
+ step: int = 0) -> np.ndarray:
170
+ image = image.clip(0, 255).astype(np.uint8)
171
+ visualizer = self._visualizer
172
+ classes = visualizer.dataset_meta.get('classes', None)
173
+ palette = visualizer.dataset_meta.get('palette', None)
174
+
175
+ gt_img_data = None
176
+ pred_img_data = None
177
+
178
+ if data_sample is not None:
179
+ data_sample = data_sample.cpu()
180
+
181
+ if draw_gt and data_sample is not None:
182
+ gt_img_data = image
183
+ if 'gt_instances' in data_sample:
184
+ gt_img_data = visualizer._draw_instances(image,
185
+ data_sample.gt_instances,
186
+ classes, palette)
187
+
188
+ if 'gt_panoptic_seg' in data_sample:
189
+ assert classes is not None, 'class information is ' \
190
+ 'not provided when ' \
191
+ 'visualizing panoptic ' \
192
+ 'segmentation results.'
193
+ gt_img_data = visualizer._draw_panoptic_seg(
194
+ gt_img_data, data_sample.gt_panoptic_seg, classes)
195
+
196
+ if draw_pred and data_sample is not None:
197
+ pred_img_data = image
198
+ if 'pred_instances' in data_sample:
199
+ pred_instances = data_sample.pred_instances
200
+ pred_instances = pred_instances[
201
+ pred_instances.scores > pred_score_thr]
202
+ pred_img_data = visualizer._draw_instances(image, pred_instances,
203
+ classes, palette)
204
+ if 'pred_panoptic_seg' in data_sample:
205
+ assert classes is not None, 'class information is ' \
206
+ 'not provided when ' \
207
+ 'visualizing panoptic ' \
208
+ 'segmentation results.'
209
+ pred_img_data = visualizer._draw_panoptic_seg(
210
+ pred_img_data, data_sample.pred_panoptic_seg.numpy(),
211
+ classes)
212
+
213
+ if gt_img_data is not None and pred_img_data is not None:
214
+ drawn_img = np.concatenate((gt_img_data, pred_img_data), axis=1)
215
+ elif gt_img_data is not None:
216
+ drawn_img = gt_img_data
217
+ elif pred_img_data is not None:
218
+ drawn_img = pred_img_data
219
+ else:
220
+ # Display the original image directly if nothing is drawn.
221
+ drawn_img = image
222
+
223
+ return drawn_img