nowsyn commited on
Commit
54a7220
1 Parent(s): 73d81d6

upload codes

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. LICENSE +21 -0
  2. annotator/canny/__init__.py +6 -0
  3. annotator/cielab/__init__.py +47 -0
  4. annotator/cielab/rayleigh/__init__.py +8 -0
  5. annotator/cielab/rayleigh/palette.py +132 -0
  6. annotator/cielab/rayleigh/util.py +270 -0
  7. annotator/content/__init__.py +23 -0
  8. annotator/entityseg/__init__.py +93 -0
  9. annotator/entityseg/configs/Base-Mask2Former.yaml +49 -0
  10. annotator/entityseg/configs/cropformer_hornet_3x.yaml +70 -0
  11. annotator/entityseg/mask2former/__init__.py +11 -0
  12. annotator/entityseg/mask2former/config.py +139 -0
  13. annotator/entityseg/mask2former/cropformer_model.py +678 -0
  14. annotator/entityseg/mask2former/data/__init__.py +1 -0
  15. annotator/entityseg/mask2former/data/dataset_mappers/__init__.py +1 -0
  16. annotator/entityseg/mask2former/data/dataset_mappers/crop_augmentations.py +421 -0
  17. annotator/entityseg/mask2former/maskformer_model.py +446 -0
  18. annotator/entityseg/mask2former/modeling/__init__.py +7 -0
  19. annotator/entityseg/mask2former/modeling/backbone/__init__.py +1 -0
  20. annotator/entityseg/mask2former/modeling/backbone/hornet.py +363 -0
  21. annotator/entityseg/mask2former/modeling/backbone/swin.py +770 -0
  22. annotator/entityseg/mask2former/modeling/criterion.py +263 -0
  23. annotator/entityseg/mask2former/modeling/criterion_view.py +288 -0
  24. annotator/entityseg/mask2former/modeling/matcher.py +189 -0
  25. annotator/entityseg/mask2former/modeling/matcher_view.py +194 -0
  26. annotator/entityseg/mask2former/modeling/meta_arch/__init__.py +1 -0
  27. annotator/entityseg/mask2former/modeling/meta_arch/mask_former_head.py +133 -0
  28. annotator/entityseg/mask2former/modeling/meta_arch/per_pixel_baseline.py +243 -0
  29. annotator/entityseg/mask2former/modeling/pixel_decoder/__init__.py +1 -0
  30. annotator/entityseg/mask2former/modeling/pixel_decoder/fpn.py +312 -0
  31. annotator/entityseg/mask2former/modeling/pixel_decoder/msdeformattn.py +358 -0
  32. annotator/entityseg/mask2former/modeling/pixel_decoder/ops/functions/__init__.py +13 -0
  33. annotator/entityseg/mask2former/modeling/pixel_decoder/ops/functions/ms_deform_attn_func.py +72 -0
  34. annotator/entityseg/mask2former/modeling/pixel_decoder/ops/make.sh +13 -0
  35. annotator/entityseg/mask2former/modeling/pixel_decoder/ops/modules/__init__.py +12 -0
  36. annotator/entityseg/mask2former/modeling/pixel_decoder/ops/modules/ms_deform_attn.py +125 -0
  37. annotator/entityseg/mask2former/modeling/pixel_decoder/ops/setup.py +78 -0
  38. annotator/entityseg/mask2former/modeling/pixel_decoder/ops/src/cpu/ms_deform_attn_cpu.cpp +46 -0
  39. annotator/entityseg/mask2former/modeling/pixel_decoder/ops/src/cpu/ms_deform_attn_cpu.h +38 -0
  40. annotator/entityseg/mask2former/modeling/pixel_decoder/ops/src/cuda/ms_deform_attn_cuda.cu +158 -0
  41. annotator/entityseg/mask2former/modeling/pixel_decoder/ops/src/cuda/ms_deform_attn_cuda.h +35 -0
  42. annotator/entityseg/mask2former/modeling/pixel_decoder/ops/src/cuda/ms_deform_im2col_cuda.cuh +1332 -0
  43. annotator/entityseg/mask2former/modeling/pixel_decoder/ops/src/ms_deform_attn.h +67 -0
  44. annotator/entityseg/mask2former/modeling/pixel_decoder/ops/src/vision.cpp +21 -0
  45. annotator/entityseg/mask2former/modeling/pixel_decoder/ops/test.py +92 -0
  46. annotator/entityseg/mask2former/modeling/transformer_decoder/__init__.py +5 -0
  47. annotator/entityseg/mask2former/modeling/transformer_decoder/cropformer_transformer_decoder.py +595 -0
  48. annotator/entityseg/mask2former/modeling/transformer_decoder/mask2former_transformer_decoder.py +461 -0
  49. annotator/entityseg/mask2former/modeling/transformer_decoder/maskformer_transformer_decoder.py +188 -0
  50. annotator/entityseg/mask2former/modeling/transformer_decoder/position_encoding.py +134 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 OpenMMLab
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
annotator/canny/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import cv2
2
+
3
+
4
+ class CannyDetector:
5
+ def __call__(self, img, low_threshold=100, high_threshold=200):
6
+ return cv2.Canny(img, low_threshold, high_threshold)
annotator/cielab/__init__.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ sys.path.append(os.getcwd())
4
+ sys.path.append(os.path.join(os.getcwd(), 'rayleigh'))
5
+
6
+ import numpy as np
7
+ from skimage.color import rgb2lab
8
+ from .rayleigh import Palette
9
+ from .rayleigh.util import histogram_colors_strict, smooth_histogram, color_hist_to_palette_image
10
+
11
+
12
+ class CIELabDetector:
13
+
14
+ MAX_DIMENSION = 240 + 1
15
+
16
+ def __init__(self, sigma=10, num_hues=11, num_light=5, num_sat=5):
17
+ self.sigma = sigma
18
+ self.palette = Palette(num_hues=num_hues, light_range=num_light, sat_range=num_sat)
19
+
20
+ def __call__(self, img):
21
+ # Handle grayscale and RGBA images.
22
+ # TODO: Should be smarter here in the future, but for now simply remove
23
+ # the alpha channel if present.
24
+ if img.ndim == 2:
25
+ img = np.tile(img[:, :, np.newaxis], (1, 1, 3))
26
+ elif img.ndim == 4:
27
+ img = img[:, :, :3]
28
+ img = img[:,:,:3]
29
+
30
+ h, w, d = tuple(img.shape)
31
+ h_stride = int(h / self.MAX_DIMENSION + 1)
32
+ w_stride = int(w / self.MAX_DIMENSION + 1)
33
+ img = img[::h_stride, ::w_stride, :]
34
+
35
+ # Convert to L*a*b colors.
36
+ h, w, d = img.shape
37
+ lab_array = rgb2lab(img).reshape((h * w, d))
38
+
39
+ # compute hist
40
+ hist = histogram_colors_strict(lab_array, self.palette)
41
+ hist = smooth_histogram(hist, self.palette, self.sigma)
42
+ return hist
43
+
44
+ def hist_to_palette(self, hist):
45
+ # hist to image
46
+ plt = color_hist_to_palette_image(hist, self.palette)
47
+ return (plt * 255).astype(np.uint8)
annotator/cielab/rayleigh/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Rayleigh is an open-source system for quickly searching medium-sized image
3
+ collections by multiple colors given as a palette or derived from a query image.
4
+ """
5
+
6
+
7
+ from .palette import Palette
8
+ from .util import *
annotator/cielab/rayleigh/palette.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Encapsulate the list of hex colors and array of Lab values representations
3
+ of a palette (codebook) of colors.
4
+
5
+ Provide methods to work with color conversion and the Palette class.
6
+
7
+ Provide a parametrized method to generate a palette that covers the range
8
+ of colors.
9
+ """
10
+
11
+ import os
12
+ import numpy as np
13
+ from skimage.color import hsv2rgb, rgb2lab
14
+ from skimage.io import imsave
15
+ from sklearn.metrics import euclidean_distances
16
+
17
+ from .util import rgb2hex
18
+
19
+
20
+ class Palette(object):
21
+ """
22
+ Create a color palette (codebook) in the form of a 2D grid of colors,
23
+ as described in the parameters list below.
24
+ Further, the rightmost column has num_hues gradations from black to white.
25
+
26
+ Parameters
27
+ ----------
28
+ num_hues : int
29
+ number of colors with full lightness and saturation, in the middle
30
+ sat_range : int
31
+ number of rows above middle row that show
32
+ the same hues with decreasing saturation.
33
+ light_range : int
34
+ number of rows below middle row that show
35
+ the same hues with decreasing lightness.
36
+
37
+ Returns
38
+ -------
39
+ palette: rayleigh.Palette
40
+ """
41
+
42
+ def __init__(self, num_hues=8, sat_range=2, light_range=2):
43
+ height = 1 + sat_range + (2 * light_range - 1)
44
+ # generate num_hues+1 hues, but don't take the last one:
45
+ # hues are on a circle, and we would be oversampling the origin
46
+ hues = np.tile(np.linspace(0, 1, num_hues + 1)[:-1], (height, 1))
47
+ if num_hues == 8:
48
+ hues = np.tile(np.array(
49
+ [0., 0.10, 0.15, 0.28, 0.51, 0.58, 0.77, 0.85]), (height, 1))
50
+ if num_hues == 9:
51
+ hues = np.tile(np.array(
52
+ [0., 0.10, 0.15, 0.28, 0.49, 0.54, 0.60, 0.7, 0.87]), (height, 1))
53
+ if num_hues == 10:
54
+ hues = np.tile(np.array(
55
+ [0., 0.10, 0.15, 0.28, 0.49, 0.54, 0.60, 0.66, 0.76, 0.87]), (height, 1))
56
+ elif num_hues == 11:
57
+ hues = np.tile(np.array(
58
+ [0.0, 0.0833, 0.166, 0.25,
59
+ 0.333, 0.5, 0.56333,
60
+ 0.666, 0.73, 0.803,
61
+ 0.916]), (height, 1))
62
+
63
+ sats = np.hstack((
64
+ np.linspace(0, 1, sat_range + 2)[1:-1],
65
+ 1,
66
+ [1] * (light_range),
67
+ [.4] * (light_range - 1),
68
+ ))
69
+ lights = np.hstack((
70
+ [1] * sat_range,
71
+ 1,
72
+ np.linspace(1, 0.2, light_range + 2)[1:-1],
73
+ np.linspace(1, 0.2, light_range + 2)[1:-2],
74
+ ))
75
+
76
+ sats = np.tile(np.atleast_2d(sats).T, (1, num_hues))
77
+ lights = np.tile(np.atleast_2d(lights).T, (1, num_hues))
78
+ colors = hsv2rgb(np.dstack((hues, sats, lights)))
79
+ grays = np.tile(
80
+ np.linspace(1, 0, height)[:, np.newaxis, np.newaxis], (1, 1, 3))
81
+
82
+ self.rgb_image = np.hstack((colors, grays))
83
+
84
+ # Make a nice histogram ordering of the hues and grays
85
+ h, w, d = colors.shape
86
+ color_array = colors.T.reshape((d, w * h)).T
87
+ h, w, d = grays.shape
88
+ gray_array = grays.T.reshape((d, w * h)).T
89
+
90
+ self.rgb_array = np.vstack((color_array, gray_array))
91
+ self.lab_array = rgb2lab(self.rgb_array[None, :, :]).squeeze()
92
+ self.hex_list = [rgb2hex(row) for row in self.rgb_array]
93
+ #assert(np.all(self.rgb_array == self.rgb_array[None, :, :].squeeze()))
94
+
95
+ self.distances = euclidean_distances(self.lab_array, squared=True)
96
+
97
+ def output(self, dirname, html=False):
98
+ """
99
+ Output an image of the palette, josn list of the hex
100
+ colors, and an HTML color picker for it.
101
+
102
+ Parameters
103
+ ----------
104
+ dirname : string
105
+ directory for the files to be output
106
+ """
107
+ def get_palette_html():
108
+ """
109
+ Return HTML for a color picker using the given palette.
110
+ """
111
+ html = """
112
+ <style>
113
+ span {
114
+ width: 20px;
115
+ height: 20px;
116
+ margin: 2px;
117
+ padding: 0px;
118
+ display: inline-block;
119
+ }
120
+ </style>
121
+ """
122
+ for row in self.rgb_image:
123
+ for rgb_color in row:
124
+ s = '<a id="{0}"><span style="background-color: {0}" /></a>\n'
125
+ html += s.format(rgb2hex(rgb_color))
126
+ html += "<br />\n"
127
+ return html
128
+
129
+ imsave(os.path.join(dirname, 'palette.png'), (self.rgb_image*255).astype(np.uint8))
130
+ if html:
131
+ with open(os.path.join(dirname, 'palette.html'), 'w') as f:
132
+ f.write(get_palette_html())
annotator/cielab/rayleigh/util.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import tempfile
4
+ import matplotlib.pyplot as plt
5
+ from sklearn.metrics import euclidean_distances
6
+ from skimage.io import imsave
7
+
8
+
9
+ def rgb2hex(rgb_number):
10
+ """
11
+ Args:
12
+ - rgb_number (sequence of float)
13
+
14
+ Returns:
15
+ - hex_number (string)
16
+ """
17
+ return '#%02x%02x%02x' % tuple([int(np.round(val * 255)) for val in rgb_number])
18
+
19
+
20
+ def hex2rgb(hexcolor_str):
21
+ """
22
+ Args:
23
+ - hexcolor_str (string): e.g. '#ffffff' or '33cc00'
24
+
25
+ Returns:
26
+ - rgb_color (sequence of floats): e.g. (0.2, 0.3, 0)
27
+ """
28
+ color = hexcolor_str.strip('#')
29
+ rgb = lambda x: round(int(x, 16) / 255., 5)
30
+ return (rgb(color[:2]), rgb(color[2:4]), rgb(color[4:6]))
31
+
32
+
33
+ def color_hist_to_palette_image(color_hist, palette, percentile=90,
34
+ width=200, height=50, filename=None):
35
+ """
36
+ Output the main colors in the histogram to a "palette image."
37
+
38
+ Parameters
39
+ ----------
40
+ color_hist : (K,) ndarray
41
+ palette : rayleigh.Palette
42
+ percentile : int, optional:
43
+ Output only colors above this percentile of prevalence in the histogram.
44
+ filename : string, optional:
45
+ If given, save the resulting image to file.
46
+
47
+ Returns
48
+ -------
49
+ rgb_image : ndarray
50
+ """
51
+ ind = np.argsort(-color_hist)
52
+ ind = ind[color_hist[ind] > np.percentile(color_hist, percentile)]
53
+ hex_list = np.take(palette.hex_list, ind)
54
+ values = color_hist[ind]
55
+ rgb_image = palette_query_to_rgb_image(dict(zip(hex_list, values)))
56
+ if filename:
57
+ imsave(filename, rgb_image)
58
+ return rgb_image
59
+
60
+
61
+ def palette_query_to_rgb_image(palette_query, width=200, height=50):
62
+ """
63
+ Convert a list of hex colors and their values to an RGB image of given
64
+ width and height.
65
+
66
+ Args:
67
+ - palette_query (dict):
68
+ a dictionary of hex colors to unnormalized values,
69
+ e.g. {'#ffffff': 20, '#33cc00': 0.4}.
70
+ """
71
+ hex_list, values = zip(*palette_query.items())
72
+ values = np.array(values)
73
+ values /= values.sum()
74
+ nums = np.array(values * width, dtype=int)
75
+ rgb_arrays = (np.tile(np.array(hex2rgb(x)), (num, 1))
76
+ for x, num in zip(hex_list, nums))
77
+ rgb_array = np.vstack(list(rgb_arrays))
78
+ rgb_image = rgb_array[np.newaxis, :, :]
79
+ rgb_image = np.tile(rgb_image, (height, 1, 1))
80
+ return rgb_image
81
+
82
+
83
+ def plot_histogram(color_hist, palette, plot_filename=None):
84
+ """
85
+ Return Figure containing the color palette histogram.
86
+
87
+ Args:
88
+ - color_hist (K, ndarray)
89
+
90
+ - palette (Palette)
91
+
92
+ - plot_filename (string) [default=None]:
93
+ Save histogram to this file, if given.
94
+
95
+ Returns:
96
+ - fig (Figure)
97
+ """
98
+ fig = plt.figure(figsize=(5, 3), dpi=150)
99
+ ax = fig.add_subplot(111)
100
+ ax.bar(
101
+ range(len(color_hist)), color_hist,
102
+ color=palette.hex_list, edgecolor='black')
103
+ ax.set_ylim((0, 0.3))
104
+ ax.xaxis.set_ticks([])
105
+ ax.set_xlim((0, len(palette.hex_list)))
106
+ if plot_filename:
107
+ fig.savefig(plot_filename, dpi=150, facecolor='none')
108
+ return fig
109
+
110
+
111
+ def output_histogram_base64(color_hist, palette):
112
+ """
113
+ Return base64-encoded image containing the color palette histogram.
114
+
115
+ Args:
116
+ - color_hist (K, ndarray)
117
+
118
+ - palette (Palette)
119
+
120
+ Returns:
121
+ - data_uri (base64 encoded string)
122
+ """
123
+ _, tfname = tempfile.mkstemp('.png')
124
+ plot_histogram(color_hist, palette, tfname)
125
+ data_uri = open(tfname, 'rb').read().encode('base64').replace('\n', '')
126
+ os.remove(tfname)
127
+ return data_uri
128
+
129
+
130
+ def histogram_colors_strict(lab_array, palette, plot_filename=None):
131
+ """
132
+ Return a palette histogram of colors in the image.
133
+
134
+ Parameters
135
+ ----------
136
+ lab_array : (N,3) ndarray
137
+ The L*a*b color of each of N pixels.
138
+ palette : rayleigh.Palette
139
+ Containing K colors.
140
+ plot_filename : string, optional
141
+ If given, save histogram to this filename.
142
+
143
+ Returns
144
+ -------
145
+ color_hist : (K,) ndarray
146
+ """
147
+ # This is the fastest way that I've found.
148
+ # >>> %%timeit -n 200 from sklearn.metrics import euclidean_distances
149
+ # >>> euclidean_distances(palette, lab_array, squared=True)
150
+ dist = euclidean_distances(palette.lab_array, lab_array, squared=True).T
151
+ min_ind = np.argmin(dist, axis=1)
152
+ num_colors = palette.lab_array.shape[0]
153
+ num_pixels = lab_array.shape[0]
154
+ color_hist = 1. * np.bincount(min_ind, minlength=num_colors) / num_pixels
155
+ if plot_filename is not None:
156
+ plot_histogram(color_hist, palette, plot_filename)
157
+ return color_hist
158
+
159
+
160
+ def histogram_colors_smoothed(lab_array, palette, sigma=10,
161
+ plot_filename=None, direct=True):
162
+ """
163
+ Returns a palette histogram of colors in the image, smoothed with
164
+ a Gaussian. Can smooth directly per-pixel, or after computing a strict
165
+ histogram.
166
+
167
+ Parameters
168
+ ----------
169
+ lab_array : (N,3) ndarray
170
+ The L*a*b color of each of N pixels.
171
+ palette : rayleigh.Palette
172
+ Containing K colors.
173
+ sigma : float
174
+ Variance of the smoothing Gaussian.
175
+ direct : bool, optional
176
+ If True, constructs a smoothed histogram directly from pixels.
177
+ If False, constructs a nearest-color histogram and then smoothes it.
178
+
179
+ Returns
180
+ -------
181
+ color_hist : (K,) ndarray
182
+ """
183
+ if direct:
184
+ color_hist_smooth = histogram_colors_with_smoothing(
185
+ lab_array, palette, sigma)
186
+ else:
187
+ color_hist_strict = histogram_colors_strict(lab_array, palette)
188
+ color_hist_smooth = smooth_histogram(color_hist_strict, palette, sigma)
189
+ if plot_filename is not None:
190
+ plot_histogram(color_hist_smooth, palette, plot_filename)
191
+ return color_hist_smooth
192
+
193
+
194
+ def smooth_histogram(color_hist, palette, sigma=10):
195
+ """
196
+ Smooth the given palette histogram with a Gaussian of variance sigma.
197
+
198
+ Parameters
199
+ ----------
200
+ color_hist : (K,) ndarray
201
+ palette : rayleigh.Palette
202
+ containing K colors.
203
+
204
+ Returns
205
+ -------
206
+ color_hist_smooth : (K,) ndarray
207
+ """
208
+ n = 2. * sigma ** 2
209
+ weights = np.exp(-palette.distances / n)
210
+ norm_weights = weights / weights.sum(1)[:, np.newaxis]
211
+ color_hist_smooth = (norm_weights * color_hist).sum(1)
212
+ color_hist_smooth[color_hist_smooth < 1e-5] = 0
213
+ return color_hist_smooth
214
+
215
+
216
+ def histogram_colors_with_smoothing(lab_array, palette, sigma=10):
217
+ """
218
+ Assign colors in the image to nearby colors in the palette, weighted by
219
+ distance in Lab color space.
220
+
221
+ Parameters
222
+ ----------
223
+ lab_array (N,3) ndarray:
224
+ N is the number of data points, columns are L, a, b values.
225
+ palette : rayleigh.Palette
226
+ containing K colors.
227
+ sigma : float
228
+ (0,1] value to control the steepness of exponential falloff.
229
+ To see the effect:
230
+
231
+ >>> from pylab import *
232
+ >>> ds = linspace(0,5000) # squared distance
233
+ >>> sigma=10; plot(ds, exp(-ds/(2*sigma**2)), label='$\sigma=%.1f$'%sigma)
234
+ >>> sigma=20; plot(ds, exp(-ds/(2*sigma**2)), label='$\sigma=%.1f$'%sigma)
235
+ >>> sigma=40; plot(ds, exp(-ds/(2*sigma**2)), label='$\sigma=%.1f$'%sigma)
236
+ >>> ylim([0,1]); legend();
237
+ >>> xlabel('Squared distance'); ylabel('Weight');
238
+ >>> title('Exponential smoothing')
239
+ >>> #plt.savefig('exponential_smoothing.png', dpi=300)
240
+
241
+ sigma=20 seems reasonable: hits 0 around squared distance of 4000.
242
+
243
+ Returns:
244
+ color_hist : (K,) ndarray
245
+ the normalized, smooth histogram of colors.
246
+ """
247
+ dist = euclidean_distances(palette.lab_array, lab_array, squared=True).T
248
+ n = 2. * sigma ** 2
249
+ weights = np.exp(-dist / n)
250
+
251
+ # normalize by sum: if a color is equally well represented by several colors
252
+ # it should not contribute much to the overall histogram
253
+ normalizing = weights.sum(1)
254
+ normalizing[normalizing == 0] = 1e16
255
+ normalized_weights = weights / normalizing[:, np.newaxis]
256
+
257
+ color_hist = normalized_weights.sum(0)
258
+ color_hist /= lab_array.shape[0]
259
+ color_hist[color_hist < 1e-5] = 0
260
+ return color_hist
261
+
262
+
263
+ def makedirs(dirname):
264
+ "Does what mkdir -p does, and returns dirname."
265
+ if not os.path.exists(dirname):
266
+ try:
267
+ os.makedirs(dirname)
268
+ except:
269
+ print("Exception on os.makedirs")
270
+ return dirname
annotator/content/__init__.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ from PIL import Image
4
+
5
+ import torch
6
+ from transformers import AutoProcessor, CLIPModel
7
+
8
+ from annotator.util import annotator_ckpts_path
9
+
10
+
11
+ class ContentDetector:
12
+ def __init__(self, model_name="openai/clip-vit-large-patch14"):
13
+
14
+ self.model = CLIPModel.from_pretrained(model_name, cache_dir=annotator_ckpts_path).cuda().eval()
15
+ self.processor = AutoProcessor.from_pretrained(model_name, cache_dir=annotator_ckpts_path)
16
+
17
+ def __call__(self, img):
18
+ with torch.no_grad():
19
+ img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
20
+ inputs = self.processor(images=[img], return_tensors="pt").to('cuda')
21
+ image_features = self.model.get_image_features(**inputs)
22
+ content_emb = image_features[0].detach().cpu().numpy()
23
+ return content_emb
annotator/entityseg/__init__.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Modified by Bowen Cheng from: https://github.com/facebookresearch/detectron2/blob/master/demo/demo.py
3
+ import argparse
4
+ import glob
5
+ import multiprocessing as mp
6
+ import os
7
+ import sys
8
+ sys.path.insert(1, os.getcwd())
9
+
10
+ import tempfile
11
+ import time
12
+ import warnings
13
+
14
+ import cv2
15
+ import numpy as np
16
+ import tqdm
17
+ import torch
18
+
19
+ from detectron2.config import get_cfg
20
+ from detectron2.data.detection_utils import read_image
21
+ from detectron2.projects.deeplab import add_deeplab_config
22
+ from detectron2.utils.logger import setup_logger
23
+
24
+ from mask2former import add_maskformer2_config
25
+ from predictor import VisualizationDemo
26
+
27
+ from annotator.util import annotator_ckpts_path
28
+
29
+
30
+ model_url = "https://huggingface.co/datasets/qqlu1992/Adobe_EntitySeg/resolve/main/CropFormer_model/Entity_Segmentation/CropFormer_hornet_3x.pth"
31
+
32
+
33
+ def make_colors():
34
+ from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES
35
+ colors = []
36
+ for cate in COCO_CATEGORIES:
37
+ colors.append(cate["color"])
38
+ return colors
39
+
40
+
41
+ class EntitysegDetector:
42
+
43
+ def __init__(self, confidence_threshold=0.5):
44
+ cfg = get_cfg()
45
+ add_deeplab_config(cfg)
46
+ add_maskformer2_config(cfg)
47
+
48
+ workdir = os.getcwd()
49
+ config_file = f"{workdir}/annotator/entityseg/configs/cropformer_hornet_3x.yaml"
50
+ model_path = f'{annotator_ckpts_path}/CropFormer_hornet_3x_03823a.pth'
51
+ # Authentication required
52
+ # if not os.path.exists(model_path):
53
+ # from basicsr.utils.download_util import load_file_from_url
54
+ # load_file_from_url(model_url, model_dir=annotator_ckpts_path)
55
+
56
+ cfg.merge_from_file(config_file)
57
+ opts = ['MODEL.WEIGHTS', model_path]
58
+ cfg.merge_from_list(opts)
59
+ cfg.freeze()
60
+
61
+ self.confidence_threshold = confidence_threshold
62
+
63
+ self.colors = make_colors()
64
+ self.demo = VisualizationDemo(cfg)
65
+
66
+
67
+ def __call__(self, image):
68
+ predictions = self.demo.run_on_image(image)
69
+ ##### color_mask
70
+ pred_masks = predictions["instances"].pred_masks
71
+ pred_scores = predictions["instances"].scores
72
+
73
+ # select by confidence threshold
74
+ selected_indexes = (pred_scores >= self.confidence_threshold)
75
+ selected_scores = pred_scores[selected_indexes]
76
+ selected_masks = pred_masks[selected_indexes]
77
+ _, m_H, m_W = selected_masks.shape
78
+ mask_id = np.zeros((m_H, m_W), dtype=np.uint8)
79
+
80
+ # rank
81
+ selected_scores, ranks = torch.sort(selected_scores)
82
+ ranks = ranks + 1
83
+ for index in ranks:
84
+ mask_id[(selected_masks[index-1]==1).cpu().numpy()] = int(index)
85
+ unique_mask_id = np.unique(mask_id)
86
+
87
+ color_mask = np.zeros(image.shape, dtype=np.uint8)
88
+ for count in unique_mask_id:
89
+ if count == 0:
90
+ continue
91
+ color_mask[mask_id==count] = self.colors[count % len(self.colors)]
92
+
93
+ return color_mask
annotator/entityseg/configs/Base-Mask2Former.yaml ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ENTITY:
2
+ ENABLE: True
3
+ MODEL:
4
+ BACKBONE:
5
+ FREEZE_AT: 0
6
+ NAME: "build_resnet_backbone"
7
+ WEIGHTS: "R-50.pkl"
8
+ PIXEL_MEAN: [123.675, 116.280, 103.530]
9
+ PIXEL_STD: [58.395, 57.120, 57.375]
10
+ RESNETS:
11
+ DEPTH: 50
12
+ STEM_TYPE: "basic" # not used
13
+ STEM_OUT_CHANNELS: 64
14
+ STRIDE_IN_1X1: False
15
+ OUT_FEATURES: ["res2", "res3", "res4", "res5"]
16
+ # NORM: "SyncBN"
17
+ RES5_MULTI_GRID: [1, 1, 1] # not used
18
+ DATASETS:
19
+ TRAIN: ("entityv2_entity_train_01",)
20
+ TEST: ("entityv2_entity_val_01",)
21
+ SOLVER:
22
+ STEPS: (30525, 33138)
23
+ MAX_ITER: 34375
24
+ IMS_PER_BATCH: 16
25
+ BASE_LR: 0.0001
26
+ WARMUP_FACTOR: 1.0
27
+ WARMUP_ITERS: 0
28
+ WEIGHT_DECAY: 0.05
29
+ OPTIMIZER: "ADAMW"
30
+ LR_SCHEDULER_NAME: "WarmupPolyLR"
31
+ BACKBONE_MULTIPLIER: 0.1
32
+ CLIP_GRADIENTS:
33
+ ENABLED: True
34
+ CLIP_TYPE: "full_model"
35
+ CLIP_VALUE: 0.01
36
+ NORM_TYPE: 2.0
37
+ AMP:
38
+ ENABLED: True
39
+ INPUT:
40
+ MASK_FORMAT: "bitmask"
41
+ FORMAT: "RGB"
42
+ MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
43
+ DATASET_MAPPER_NAME: "entity_crop"
44
+ TEST:
45
+ EVAL_PERIOD: 400000
46
+ DATALOADER:
47
+ FILTER_EMPTY_ANNOTATIONS: True
48
+ NUM_WORKERS: 32
49
+ VERSION: 2
annotator/entityseg/configs/cropformer_hornet_3x.yaml ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _BASE_: Base-Mask2Former.yaml
2
+ DATALOADER:
3
+ NUM_WORKERS: 32
4
+ DATASETS:
5
+ TRAIN: ("entityv2_entity_train_01","entityv2_entity_train_02","entityv2_entity_train_03",)
6
+ TEST: ("entityv2_entity_val_all",)
7
+ # TEST: ("entityv2_entity_val_all_lr",)
8
+ SOLVER:
9
+ # STEPS: (91575, 99414)
10
+ # MAX_ITER: 103125
11
+ IMS_PER_BATCH: 8
12
+ STEPS: (183150, 198828)
13
+ MAX_ITER: 206250
14
+ MODEL:
15
+ BACKBONE:
16
+ NAME: "D2HorNet"
17
+ PIXEL_MEAN: [123.675, 116.28, 103.53]
18
+ PIXEL_STD: [58.395, 57.120, 57.375]
19
+ SWIN:
20
+ EMBED_DIM: 192
21
+ DEPTHS: [2, 2, 18, 2]
22
+ NUM_HEADS: [6, 12, 24, 48]
23
+ WINDOW_SIZE: 7
24
+ APE: False
25
+ DROP_PATH_RATE: 0.3
26
+ PATCH_NORM: True
27
+ PRETRAIN_IMG_SIZE: 384
28
+ WEIGHTS: "hornet_l_pretrained.pth"
29
+ META_ARCHITECTURE: "CropFormer"
30
+ SEM_SEG_HEAD:
31
+ NAME: "MaskFormerHead"
32
+ IGNORE_VALUE: 255
33
+ NUM_CLASSES: 1
34
+ LOSS_WEIGHT: 1.0
35
+ CONVS_DIM: 256
36
+ MASK_DIM: 256
37
+ NORM: "GN"
38
+ # pixel decoder
39
+ PIXEL_DECODER_NAME: "MSDeformAttnPixelDecoder"
40
+ IN_FEATURES: ["res2", "res3", "res4", "res5"]
41
+ DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: ["res3", "res4", "res5"]
42
+ COMMON_STRIDE: 4
43
+ TRANSFORMER_ENC_LAYERS: 6
44
+ MASK_FORMER:
45
+ TRANSFORMER_DECODER_NAME: "CropSharedMultiScaleMaskedTransformerDecoder"
46
+ TRANSFORMER_IN_FEATURE: "multi_scale_pixel_decoder"
47
+ DEEP_SUPERVISION: True
48
+ NO_OBJECT_WEIGHT: 0.1
49
+ CLASS_WEIGHT: 2.0
50
+ MASK_WEIGHT: 5.0
51
+ DICE_WEIGHT: 5.0
52
+ HIDDEN_DIM: 256
53
+ NUM_OBJECT_QUERIES: 200
54
+ NHEADS: 8
55
+ DROPOUT: 0.0
56
+ DIM_FEEDFORWARD: 2048
57
+ ENC_LAYERS: 0
58
+ PRE_NORM: False
59
+ ENFORCE_INPUT_PROJ: False
60
+ SIZE_DIVISIBILITY: 32
61
+ DEC_LAYERS: 10 # 9 decoder layers, add one for the loss on learnable query
62
+ TRAIN_NUM_POINTS: 12544
63
+ OVERSAMPLE_RATIO: 3.0
64
+ IMPORTANCE_SAMPLE_RATIO: 0.75
65
+ TEST:
66
+ SEMANTIC_ON: False
67
+ INSTANCE_ON: True
68
+ PANOPTIC_ON: False
69
+ OVERLAP_THRESHOLD: 0.8
70
+ OBJECT_MASK_THRESHOLD: 0.8
annotator/entityseg/mask2former/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ from . import data # register all new datasets
3
+ from . import modeling
4
+
5
+ # config
6
+ from .config import add_maskformer2_config
7
+
8
+ # models
9
+ from .maskformer_model import MaskFormer
10
+ from .cropformer_model import CropFormer
11
+ from .test_time_augmentation import SemanticSegmentorWithTTA
annotator/entityseg/mask2former/config.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ from detectron2.config import CfgNode as CN
4
+
5
+
6
+ def add_maskformer2_config(cfg):
7
+ """
8
+ Add config for MASK_FORMER.
9
+ """
10
+ # NOTE: configs from original maskformer
11
+ # data config
12
+ # select the dataset mapper
13
+ cfg.INPUT.DATASET_MAPPER_NAME = "mask_former_semantic"
14
+ # Color augmentation
15
+ cfg.INPUT.COLOR_AUG_SSD = False
16
+ # We retry random cropping until no single category in semantic segmentation GT occupies more
17
+ # than `SINGLE_CATEGORY_MAX_AREA` part of the crop.
18
+ cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0
19
+ # Pad image and segmentation GT in dataset mapper.
20
+ cfg.INPUT.SIZE_DIVISIBILITY = -1
21
+
22
+ # solver config
23
+ # weight decay on embedding
24
+ cfg.SOLVER.WEIGHT_DECAY_EMBED = 0.0
25
+ # optimizer
26
+ cfg.SOLVER.OPTIMIZER = "ADAMW"
27
+ cfg.SOLVER.BACKBONE_MULTIPLIER = 0.1
28
+
29
+ # mask_former model config
30
+ cfg.MODEL.MASK_FORMER = CN()
31
+
32
+ # loss
33
+ cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION = True
34
+ cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT = 0.1
35
+ cfg.MODEL.MASK_FORMER.CLASS_WEIGHT = 1.0
36
+ cfg.MODEL.MASK_FORMER.DICE_WEIGHT = 1.0
37
+ cfg.MODEL.MASK_FORMER.MASK_WEIGHT = 20.0
38
+
39
+ # transformer config
40
+ cfg.MODEL.MASK_FORMER.NHEADS = 8
41
+ cfg.MODEL.MASK_FORMER.DROPOUT = 0.1
42
+ cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = 2048
43
+ cfg.MODEL.MASK_FORMER.ENC_LAYERS = 0
44
+ cfg.MODEL.MASK_FORMER.DEC_LAYERS = 6
45
+ cfg.MODEL.MASK_FORMER.PRE_NORM = False
46
+
47
+ cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256
48
+ cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = 100
49
+
50
+ cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE = "res5"
51
+ cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ = False
52
+
53
+ # mask_former inference config
54
+ cfg.MODEL.MASK_FORMER.TEST = CN()
55
+ cfg.MODEL.MASK_FORMER.TEST.SEMANTIC_ON = True
56
+ cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON = False
57
+ cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON = False
58
+ cfg.MODEL.MASK_FORMER.TEST.OBJECT_MASK_THRESHOLD = 0.0
59
+ cfg.MODEL.MASK_FORMER.TEST.OVERLAP_THRESHOLD = 0.0
60
+ cfg.MODEL.MASK_FORMER.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE = False
61
+
62
+ # Sometimes `backbone.size_divisibility` is set to 0 for some backbone (e.g. ResNet)
63
+ # you can use this config to override
64
+ cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY = 32
65
+
66
+ # pixel decoder config
67
+ cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256
68
+ # adding transformer in pixel decoder
69
+ cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 0
70
+ # pixel decoder
71
+ cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME = "BasePixelDecoder"
72
+
73
+ # swin transformer backbone
74
+ cfg.MODEL.SWIN = CN()
75
+ cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE = 224
76
+ cfg.MODEL.SWIN.PATCH_SIZE = 4
77
+ cfg.MODEL.SWIN.EMBED_DIM = 96
78
+ cfg.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
79
+ cfg.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
80
+ cfg.MODEL.SWIN.WINDOW_SIZE = 7
81
+ cfg.MODEL.SWIN.MLP_RATIO = 4.0
82
+ cfg.MODEL.SWIN.QKV_BIAS = True
83
+ cfg.MODEL.SWIN.QK_SCALE = None
84
+ cfg.MODEL.SWIN.DROP_RATE = 0.0
85
+ cfg.MODEL.SWIN.ATTN_DROP_RATE = 0.0
86
+ cfg.MODEL.SWIN.DROP_PATH_RATE = 0.3
87
+ cfg.MODEL.SWIN.APE = False
88
+ cfg.MODEL.SWIN.PATCH_NORM = True
89
+ cfg.MODEL.SWIN.OUT_FEATURES = ["res2", "res3", "res4", "res5"]
90
+ cfg.MODEL.SWIN.USE_CHECKPOINT = False
91
+
92
+ # NOTE: maskformer2 extra configs
93
+ # transformer module
94
+ cfg.MODEL.MASK_FORMER.TRANSFORMER_DECODER_NAME = "MultiScaleMaskedTransformerDecoder"
95
+
96
+ # LSJ aug
97
+ cfg.INPUT.IMAGE_SIZE = 1024
98
+ cfg.INPUT.MIN_SCALE = 0.1
99
+ cfg.INPUT.MAX_SCALE = 2.0
100
+
101
+ # MSDeformAttn encoder configs
102
+ cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = ["res3", "res4", "res5"]
103
+ cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_N_POINTS = 4
104
+ cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_N_HEADS = 8
105
+
106
+ # point loss configs
107
+ # Number of points sampled during training for a mask point head.
108
+ cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS = 112 * 112
109
+ # Oversampling parameter for PointRend point sampling during training. Parameter `k` in the
110
+ # original paper.
111
+ cfg.MODEL.MASK_FORMER.OVERSAMPLE_RATIO = 3.0
112
+ # Importance sampling parameter for PointRend point sampling during training. Parametr `beta` in
113
+ # the original paper.
114
+ cfg.MODEL.MASK_FORMER.IMPORTANCE_SAMPLE_RATIO = 0.75
115
+
116
+ ## For Entity
117
+ cfg.ENTITY = CN()
118
+ cfg.ENTITY.ENABLE = False
119
+ cfg.ENTITY.CROP_AREA_RATIO = 0.7
120
+ cfg.ENTITY.CROP_STRIDE_RATIO = 0.6
121
+ cfg.ENTITY.CROP_SAMPLE_NUM_TRAIN = 1
122
+ cfg.ENTITY.CROP_SAMPLE_NUM_TEST = 4
123
+
124
+ ## fuse frame embeddings to batch embedding
125
+ cfg.ENTITY.FUSE_NUM_LAYERS = 1
126
+ cfg.ENTITY.FUSE_ENC_HIDDIEN_DIM = 256
127
+ cfg.ENTITY.FUSE_ENC_NHEADS = 8
128
+ cfg.ENTITY.FUSE_ENC_PRE_NORM = False
129
+ cfg.ENTITY.FUSE_ENC_DIM_FEEDFORWARD = 2048
130
+ cfg.ENTITY.FUSE_ENC_LAST_LAYERS = 1
131
+ cfg.ENTITY.FUSE_DEC_NUM_LAYERS = 3
132
+
133
+ ## Hornet backbone
134
+ cfg.MODEL.HORNET = CN()
135
+ cfg.MODEL.HORNET.DEPTHS = [2, 3, 18, 2]
136
+ cfg.MODEL.HORNET.BASE_DIM = 192
137
+ cfg.MODEL.HORNET.GCONV = ['partial(gnconv, order=2, s=1/3)', 'partial(gnconv, order=3, s=1/3)', 'partial(gnconv, order=4, s=1/3, h=24, w=13, gflayer=GlobalLocalFilter)', 'partial(gnconv, order=5, s=1/3, h=12, w=7, gflayer=GlobalLocalFilter)']
138
+ cfg.MODEL.HORNET.DROP_PATH_RATE=0.6
139
+ cfg.MODEL.HORNET.OUT_FEATURES = ["res2", "res3", "res4", "res5"]
annotator/entityseg/mask2former/cropformer_model.py ADDED
@@ -0,0 +1,678 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ from typing import Tuple
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+ import pdb
8
+ import numpy as np
9
+ import cv2
10
+ import os
11
+
12
+ from detectron2.config import configurable
13
+ from detectron2.data import MetadataCatalog
14
+ from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, build_sem_seg_head
15
+ from detectron2.modeling.backbone import Backbone
16
+ from detectron2.modeling.postprocessing import sem_seg_postprocess
17
+ from detectron2.structures import Boxes, ImageList, Instances, BitMasks
18
+ from detectron2.utils.memory import retry_if_cuda_oom
19
+ from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES
20
+
21
+ from .modeling.criterion import SetCriterion
22
+ from .modeling.matcher import HungarianMatcher
23
+ from .modeling.criterion_view import ViewSetCriterion
24
+ from .modeling.matcher_view import ViewHungarianMatcher
25
+ import pdb
26
+ import copy
27
+
28
+ @META_ARCH_REGISTRY.register()
29
+ class CropFormer(nn.Module):
30
+ """
31
+ Main class for mask classification semantic segmentation architectures.
32
+ """
33
+ @configurable
34
+ def __init__(
35
+ self,
36
+ *,
37
+ cfg,
38
+ backbone: Backbone,
39
+ sem_seg_head: nn.Module,
40
+ criterion_2d: nn.Module,
41
+ criterion_3d: nn.Module,
42
+ num_queries: int,
43
+ object_mask_threshold: float,
44
+ overlap_threshold: float,
45
+ metadata,
46
+ size_divisibility: int,
47
+ sem_seg_postprocess_before_inference: bool,
48
+ pixel_mean: Tuple[float],
49
+ pixel_std: Tuple[float],
50
+ # inference
51
+ semantic_on: bool,
52
+ panoptic_on: bool,
53
+ instance_on: bool,
54
+ test_topk_per_image: int,
55
+ ):
56
+ """
57
+ Args:
58
+ backbone: a backbone module, must follow detectron2's backbone interface
59
+ sem_seg_head: a module that predicts semantic segmentation from backbone features
60
+ criterion: a module that defines the loss
61
+ num_queries: int, number of queries
62
+ object_mask_threshold: float, threshold to filter query based on classification score
63
+ for panoptic segmentation inference
64
+ overlap_threshold: overlap threshold used in general inference for panoptic segmentation
65
+ metadata: dataset meta, get `thing` and `stuff` category names for panoptic
66
+ segmentation inference
67
+ size_divisibility: Some backbones require the input height and width to be divisible by a
68
+ specific integer. We can use this to override such requirement.
69
+ sem_seg_postprocess_before_inference: whether to resize the prediction back
70
+ to original input size before semantic segmentation inference or after.
71
+ For high-resolution dataset like Mapillary, resizing predictions before
72
+ inference will cause OOM error.
73
+ pixel_mean, pixel_std: list or tuple with #channels element, representing
74
+ the per-channel mean and std to be used to normalize the input image
75
+ semantic_on: bool, whether to output semantic segmentation prediction
76
+ instance_on: bool, whether to output instance segmentation prediction
77
+ panoptic_on: bool, whether to output panoptic segmentation prediction
78
+ test_topk_per_image: int, instance segmentation parameter, keep topk instances per image
79
+ """
80
+ super().__init__()
81
+ self.cfg = cfg
82
+ self.backbone = backbone
83
+ self.sem_seg_head = sem_seg_head
84
+ self.criterion_2d = criterion_2d
85
+ self.criterion_3d = criterion_3d
86
+ ## colors
87
+ self.colors = [info["color"] for info in COCO_CATEGORIES]
88
+
89
+ self.num_queries = num_queries
90
+ self.overlap_threshold = overlap_threshold
91
+ self.object_mask_threshold = object_mask_threshold
92
+ self.metadata = metadata
93
+ if size_divisibility < 0:
94
+ # use backbone size_divisibility if not set
95
+ size_divisibility = self.backbone.size_divisibility
96
+ self.size_divisibility = size_divisibility
97
+ self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference
98
+ self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
99
+ self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
100
+
101
+ ## colors
102
+ self.colors = [info["color"] for info in COCO_CATEGORIES]
103
+
104
+ # additional args
105
+ self.semantic_on = semantic_on
106
+ self.instance_on = instance_on
107
+ self.panoptic_on = panoptic_on
108
+ self.test_topk_per_image = test_topk_per_image
109
+
110
+ if not self.semantic_on:
111
+ assert self.sem_seg_postprocess_before_inference
112
+
113
+ @classmethod
114
+ def from_config(cls, cfg):
115
+ backbone = build_backbone(cfg)
116
+ sem_seg_head = build_sem_seg_head(cfg, backbone.output_shape())
117
+
118
+ # Loss parameters:
119
+ deep_supervision = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION
120
+ no_object_weight = cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT
121
+
122
+ # loss weights
123
+ class_weight = cfg.MODEL.MASK_FORMER.CLASS_WEIGHT
124
+ dice_weight = cfg.MODEL.MASK_FORMER.DICE_WEIGHT
125
+ mask_weight = cfg.MODEL.MASK_FORMER.MASK_WEIGHT
126
+
127
+ # building criterion
128
+ matcher_2d = HungarianMatcher(
129
+ cost_class=class_weight,
130
+ cost_mask=mask_weight,
131
+ cost_dice=dice_weight,
132
+ num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
133
+ )
134
+
135
+ matcher_3d = ViewHungarianMatcher(
136
+ cost_class=class_weight,
137
+ cost_mask=mask_weight,
138
+ cost_dice=dice_weight,
139
+ num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
140
+ )
141
+
142
+ weight_dict = {"loss_ce": class_weight, "loss_mask": mask_weight, "loss_dice": dice_weight}
143
+
144
+ if deep_supervision:
145
+ dec_layers = cfg.MODEL.MASK_FORMER.DEC_LAYERS
146
+ aux_weight_dict = {}
147
+ for i in range(dec_layers - 1):
148
+ aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
149
+ weight_dict.update(aux_weight_dict)
150
+
151
+ losses = ["labels", "masks"]
152
+
153
+ criterion_2d = SetCriterion(
154
+ sem_seg_head.num_classes,
155
+ matcher=matcher_2d,
156
+ weight_dict=weight_dict,
157
+ eos_coef=no_object_weight,
158
+ losses=losses,
159
+ num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
160
+ oversample_ratio=cfg.MODEL.MASK_FORMER.OVERSAMPLE_RATIO,
161
+ importance_sample_ratio=cfg.MODEL.MASK_FORMER.IMPORTANCE_SAMPLE_RATIO,
162
+ )
163
+
164
+ criterion_3d = ViewSetCriterion(
165
+ sem_seg_head.num_classes,
166
+ matcher=matcher_3d,
167
+ weight_dict=weight_dict,
168
+ eos_coef=no_object_weight,
169
+ losses=losses,
170
+ num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
171
+ oversample_ratio=cfg.MODEL.MASK_FORMER.OVERSAMPLE_RATIO,
172
+ importance_sample_ratio=cfg.MODEL.MASK_FORMER.IMPORTANCE_SAMPLE_RATIO,
173
+ )
174
+
175
+ return {
176
+ "cfg": cfg,
177
+ "backbone": backbone,
178
+ "sem_seg_head": sem_seg_head,
179
+ "criterion_2d": criterion_2d,
180
+ "criterion_3d": criterion_3d,
181
+ "num_queries": cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES,
182
+ "object_mask_threshold": cfg.MODEL.MASK_FORMER.TEST.OBJECT_MASK_THRESHOLD,
183
+ "overlap_threshold": cfg.MODEL.MASK_FORMER.TEST.OVERLAP_THRESHOLD,
184
+ "metadata": MetadataCatalog.get(cfg.DATASETS.TRAIN[0]),
185
+ "size_divisibility": cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY,
186
+ "sem_seg_postprocess_before_inference": (
187
+ cfg.MODEL.MASK_FORMER.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE
188
+ or cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON
189
+ or cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON
190
+ ),
191
+ "pixel_mean": cfg.MODEL.PIXEL_MEAN,
192
+ "pixel_std": cfg.MODEL.PIXEL_STD,
193
+ # inference
194
+ "semantic_on": cfg.MODEL.MASK_FORMER.TEST.SEMANTIC_ON,
195
+ "instance_on": cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON,
196
+ "panoptic_on": cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON,
197
+ "test_topk_per_image": cfg.TEST.DETECTIONS_PER_IMAGE,
198
+ }
199
+
200
+ @property
201
+ def device(self):
202
+ return self.pixel_mean.device
203
+
204
+ def forward(self, batched_inputs):
205
+ """
206
+ Args:
207
+ batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
208
+ Each item in the list contains the inputs for one image.
209
+ For now, each item in the list is a dict that contains:
210
+ * "image": Tensor, image in (C, H, W) format.
211
+ * "instances": per-region ground truth
212
+ * Other information that's included in the original dicts, such as:
213
+ "height", "width" (int): the output resolution of the model (may be different
214
+ from input resolution), used in inference.
215
+ Returns:
216
+ list[dict]:
217
+ each dict has the results for one image. The dict contains the following keys:
218
+
219
+ * "sem_seg":
220
+ A Tensor that represents the
221
+ per-pixel segmentation prediced by the head.
222
+ The prediction has shape KxHxW that represents the logits of
223
+ each class for each pixel.
224
+ * "panoptic_seg":
225
+ A tuple that represent panoptic output
226
+ panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment.
227
+ segments_info (list[dict]): Describe each segment in `panoptic_seg`.
228
+ Each dict contains keys "id", "category_id", "isthing".
229
+ """
230
+ ## make new images
231
+ batched_inputs_new = []
232
+ for batched_input in batched_inputs:
233
+ ori_infos = {"height": batched_input["height"],
234
+ "width": batched_input["width"],
235
+ "image": batched_input["image"],
236
+ # "file_name": batched_input["file_name"],
237
+ }
238
+ if "instances" in batched_input.keys():
239
+ ori_instances = batched_input["instances"]
240
+ ori_instances.original_indices = torch.arange(0, len(ori_instances)).long()
241
+ ori_infos["instances"] = ori_instances
242
+ batched_inputs_new.append(ori_infos)
243
+ ## cropped patches
244
+ # pdb.set_trace()
245
+ crop_region = batched_input["crop_region"]
246
+ crop_images = batched_input["image_crop"]
247
+ crop_o_width = int(crop_region[0][2]-crop_region[0][0])
248
+ crop_o_height = int(crop_region[0][3]-crop_region[0][1])
249
+
250
+ if "instances_crop" in batched_input.keys():
251
+ crop_instances = batched_input["instances_crop"]
252
+ else:
253
+ crop_instances = None
254
+
255
+ for crop_index, crop_image in enumerate(crop_images):
256
+ crop_infos = {"height": crop_o_height, "width": crop_o_width, "image": crop_image}
257
+ if not crop_instances == None:
258
+ crop_instance = crop_instances[crop_index]
259
+ crop_instance.original_indices = torch.arange(0, len(crop_instance)).long()
260
+ crop_infos["instances"] = crop_instance
261
+ batched_inputs_new.append(crop_infos)
262
+
263
+ images = [x["image"].to(self.device) for x in batched_inputs_new]
264
+ ## +1 means
265
+ num_views = self.cfg.ENTITY.CROP_SAMPLE_NUM_TRAIN+1 if self.training else self.cfg.ENTITY.CROP_SAMPLE_NUM_TEST+1
266
+ for i in range(len(images)):
267
+ if i%num_views==0:
268
+ continue
269
+ _, c_h, c_w = images[i].shape
270
+ if "instances" in batched_inputs_new[i].keys():
271
+ batched_inputs_new[i]["instances"]._image_size = (c_h, c_w)
272
+
273
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
274
+ images = ImageList.from_tensors(images, self.size_divisibility)
275
+
276
+ features = self.backbone(images.tensor)
277
+ outputs_2d, outputs_3d = self.sem_seg_head(features)
278
+
279
+ if self.training:
280
+ if self.cfg.ENTITY.ENABLE:
281
+ for i in range(len(batched_inputs_new)):
282
+ batched_inputs_new[i]["instances"].gt_classes[:] = 0
283
+
284
+ if "instances" in batched_inputs[0]:
285
+ gt_instances = [x["instances"].to(self.device) for x in batched_inputs_new]
286
+ targets_2d = self.prepare_targets_2d(copy.deepcopy(gt_instances), copy.deepcopy(images))
287
+ targets_3d = self.prepare_targets_3d(copy.deepcopy(gt_instances), copy.deepcopy(images), num_views)
288
+ else:
289
+ targets = None
290
+
291
+ # bipartite matching-based loss
292
+ losses = {}
293
+ losses_2d = self.criterion_2d(outputs_2d, targets_2d)
294
+ losses_3d = self.criterion_3d(outputs_3d, targets_3d)
295
+
296
+ for k in list(losses_2d.keys()):
297
+ if k in self.criterion_2d.weight_dict:
298
+ losses[k+"_2d"] = losses_2d[k] * self.criterion_2d.weight_dict[k] * 0.5
299
+ else:
300
+ # remove this loss if not specified in `weight_dict`
301
+ losses_2d.pop(k)
302
+
303
+ for k in list(losses_3d.keys()):
304
+ if k in self.criterion_3d.weight_dict:
305
+ losses[k+"_3d"] = losses_3d[k] * self.criterion_3d.weight_dict[k]
306
+ else:
307
+ # remove this loss if not specified in `weight_dict`
308
+ losses_3d.pop(k)
309
+ return losses
310
+ else:
311
+ mask_cls_results_3d = outputs_3d["pred_logits"][0] ## 100,2
312
+ mask_pred_results_3d = outputs_3d["pred_masks"][0] ## 100,5,200, 304
313
+
314
+ mask_cls_results_2d = outputs_2d["pred_logits"]
315
+ mask_pred_results_2d = outputs_2d["pred_masks"]
316
+ # upsample masks
317
+
318
+ mask_pred_results_3d = retry_if_cuda_oom(F.interpolate)(
319
+ mask_pred_results_3d,
320
+ size=(images.tensor.shape[-2], images.tensor.shape[-1]),
321
+ mode="bilinear",
322
+ align_corners=False,
323
+ )
324
+
325
+ mask_pred_results_2d = F.interpolate(
326
+ mask_pred_results_2d,
327
+ size=(images.tensor.shape[-2], images.tensor.shape[-1]),
328
+ mode="bilinear",
329
+ align_corners=False,
330
+ )
331
+
332
+ del outputs_2d, outputs_3d
333
+
334
+ crop_regions = batched_input["crop_region"][:num_views-1]
335
+ processed_results = retry_if_cuda_oom(self.inference_whole_views)(
336
+ mask_cls_results_3d,
337
+ mask_pred_results_3d,
338
+ mask_cls_results_2d,
339
+ mask_pred_results_2d,
340
+ batched_inputs_new,
341
+ images.image_sizes,
342
+ crop_regions)
343
+
344
+ # processed_results = retry_if_cuda_oom(self.instance_inference_nonoverlap)(
345
+ # mask_cls_results_2d[0],
346
+ # mask_pred_results_2d[0],
347
+ # batched_inputs_new[0],
348
+ # images.image_sizes[0])
349
+
350
+ return [{"instances": processed_results}]
351
+
352
+ def prepare_targets_2d(self, targets, images):
353
+ h_pad, w_pad = images.tensor.shape[-2:]
354
+ new_targets = []
355
+ for targets_per_image in targets:
356
+ gt_masks = targets_per_image.gt_masks.tensor
357
+ gt_valid = targets_per_image.gt_boxes_valid
358
+ padded_masks = torch.zeros((gt_masks.shape[0], h_pad, w_pad), dtype=gt_masks.dtype, device=gt_masks.device)
359
+ padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks
360
+ valid_index = torch.nonzero(gt_valid).flatten()
361
+ new_targets.append(
362
+ {
363
+ "labels": targets_per_image.gt_classes[valid_index],
364
+ "masks": padded_masks[valid_index],
365
+ }
366
+ )
367
+ return new_targets
368
+
369
+ def prepare_targets_3d(self, targets_ori, images, num_views):
370
+ T = num_views
371
+ B = int(len(targets_ori) / T)
372
+ h_pad, w_pad = images.tensor.shape[-2:]
373
+
374
+ ## reshape to new targets
375
+ new_targets = []
376
+ for count, target in enumerate(targets_ori):
377
+ b_index, t_index = int(count // T), int(count % T)
378
+ if t_index == 0:
379
+ new_targets.append([target])
380
+ else:
381
+ new_targets[b_index].append(target)
382
+
383
+ gt_instances = []
384
+ for count, targets in enumerate(new_targets):
385
+ _num_instance = len(targets[0])
386
+ mask_shape = [_num_instance, T, h_pad, w_pad]
387
+ gt_masks_per_view = torch.zeros(mask_shape, dtype=torch.bool, device=self.device)
388
+
389
+ for v_i, targets_per_view in enumerate(targets):
390
+ assert torch.all(targets[0].original_indices == targets_per_view.original_indices)
391
+
392
+ gt_ids_per_view = []
393
+ gt_ids_per_valid = []
394
+ gt_ids_categories = []
395
+ ## view first, then entities
396
+ for v_i, targets_per_view in enumerate(targets):
397
+ targets_per_view = targets_per_view.to(self.device)
398
+ h, w = targets_per_view.image_size
399
+ for i_i, (instance_mask, instance_valid) in enumerate(zip(targets_per_view.gt_masks.tensor, targets_per_view.gt_boxes_valid)):
400
+ if instance_valid == 1:
401
+ gt_masks_per_view[i_i, v_i, :h, :w] = instance_mask
402
+ gt_ids_per_valid.append(targets_per_view.gt_boxes_valid[None,:])
403
+ gt_ids_per_view.append(targets_per_view.original_indices[None,:])
404
+ gt_ids_categories.append(targets_per_view.gt_classes[None, :])
405
+ ## (num_instances, num_views)
406
+ gt_ids_per_valid = torch.cat(gt_ids_per_valid, dim=0).permute((1,0))
407
+ gt_ids_per_view = torch.cat(gt_ids_per_view, dim=0).permute((1,0))
408
+ gt_ids_categories = torch.cat(gt_ids_categories, dim=0).permute((1,0))
409
+
410
+ gt_ids_per_view[gt_ids_per_valid == 0] = -1
411
+ valid_idx = (gt_ids_per_view != 1).any(dim=-1)
412
+ ## categoreis
413
+ gt_classes_per_group = gt_ids_categories[:,0] ## N
414
+ gt_ids_per_group = gt_ids_per_view ## N, num_views
415
+ gt_masks_per_group = gt_masks_per_view.float() ## N, num_views, H, W
416
+
417
+ ##
418
+ gt_instances.append({"labels": gt_classes_per_group,
419
+ "ids": gt_ids_per_group,
420
+ "masks": gt_masks_per_group})
421
+
422
+ return gt_instances
423
+
424
+ def semantic_inference(self, mask_cls, mask_pred):
425
+ mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
426
+ mask_pred = mask_pred.sigmoid()
427
+ semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
428
+ return semseg
429
+
430
+ def panoptic_inference(self, mask_cls, mask_pred):
431
+ scores, labels = F.softmax(mask_cls, dim=-1).max(-1)
432
+ mask_pred = mask_pred.sigmoid()
433
+
434
+ keep = labels.ne(self.sem_seg_head.num_classes) & (scores > self.object_mask_threshold)
435
+ cur_scores = scores[keep]
436
+ cur_classes = labels[keep]
437
+ cur_masks = mask_pred[keep]
438
+ cur_mask_cls = mask_cls[keep]
439
+ cur_mask_cls = cur_mask_cls[:, :-1]
440
+
441
+ cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks
442
+
443
+ h, w = cur_masks.shape[-2:]
444
+ panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device)
445
+ segments_info = []
446
+
447
+ current_segment_id = 0
448
+
449
+ if cur_masks.shape[0] == 0:
450
+ # We didn't detect any mask :(
451
+ return panoptic_seg, segments_info
452
+ else:
453
+ # take argmax
454
+ cur_mask_ids = cur_prob_masks.argmax(0)
455
+ stuff_memory_list = {}
456
+ for k in range(cur_classes.shape[0]):
457
+ pred_class = cur_classes[k].item()
458
+ isthing = pred_class in self.metadata.thing_dataset_id_to_contiguous_id.values()
459
+ mask_area = (cur_mask_ids == k).sum().item()
460
+ original_area = (cur_masks[k] >= 0.5).sum().item()
461
+ mask = (cur_mask_ids == k) & (cur_masks[k] >= 0.5)
462
+
463
+ if mask_area > 0 and original_area > 0 and mask.sum().item() > 0:
464
+ if mask_area / original_area < self.overlap_threshold:
465
+ continue
466
+
467
+ # merge stuff regions
468
+ if not isthing:
469
+ if int(pred_class) in stuff_memory_list.keys():
470
+ panoptic_seg[mask] = stuff_memory_list[int(pred_class)]
471
+ continue
472
+ else:
473
+ stuff_memory_list[int(pred_class)] = current_segment_id + 1
474
+
475
+ current_segment_id += 1
476
+ panoptic_seg[mask] = current_segment_id
477
+
478
+ segments_info.append(
479
+ {
480
+ "id": current_segment_id,
481
+ "isthing": bool(isthing),
482
+ "category_id": int(pred_class),
483
+ }
484
+ )
485
+ return panoptic_seg, segments_info
486
+
487
+ def instance_inference_nonoverlap(self, mask_cls, mask_pred):
488
+ # mask_pred is already processed to have the same shape as original input
489
+ image_size = mask_pred.shape[-2:]
490
+
491
+ # [Q, K]
492
+ scores = F.softmax(mask_cls, dim=-1)[:, :-1]
493
+ labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1)
494
+ # scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.num_queries, sorted=False)
495
+ scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.test_topk_per_image, sorted=False)
496
+ labels_per_image = labels[topk_indices]
497
+
498
+ topk_indices = topk_indices // self.sem_seg_head.num_classes
499
+ # mask_pred = mask_pred.unsqueeze(1).repeat(1, self.sem_seg_head.num_classes, 1).flatten(0, 1)
500
+ mask_pred = mask_pred[topk_indices]
501
+
502
+ ###### ranks
503
+ pred_masks = (mask_pred>0).float()
504
+ pred_masks_logits = mask_pred.sigmoid()
505
+ pred_scores = scores_per_image
506
+
507
+ _, m_H, m_W = pred_masks.shape
508
+ mask_id = torch.zeros((m_H, m_W), dtype=torch.int).to(pred_masks.device)
509
+ sorted_scores, ranks = torch.sort(pred_scores)
510
+ ranks = ranks + 1
511
+ for index in ranks:
512
+ mask_id[(pred_masks[index-1]==1)] = int(index)
513
+ # re-generate mask
514
+ new_scores = []
515
+ new_masks = []
516
+ new_masks_logits = []
517
+ entity_nums = len(ranks)
518
+ for ii in range(entity_nums):
519
+ index = int(ranks[entity_nums-ii-1])
520
+ score = sorted_scores[entity_nums-ii-1]
521
+ new_scores.append(score)
522
+ new_masks.append((mask_id==index).float())
523
+ new_masks_logits.append(pred_masks_logits[index-1])
524
+
525
+ new_scores = torch.stack(new_scores)
526
+ new_masks = torch.stack(new_masks)
527
+ new_masks_logits = torch.stack(new_masks_logits)
528
+
529
+ result = Instances(image_size)
530
+ # mask (before sigmoid)
531
+ result.pred_masks = new_masks
532
+ result.pred_boxes = Boxes(torch.zeros(new_masks.size(0), 4))
533
+ # Uncomment the following to get boxes from masks (this is slow)
534
+
535
+ # calculate average mask prob
536
+ mask_scores_per_image = (new_masks_logits.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / (result.pred_masks.flatten(1).sum(1) + 1e-6)
537
+ result.scores = new_scores * mask_scores_per_image
538
+ result.pred_classes = labels_per_image
539
+ return result
540
+
541
+ def instance_inference(self, mask_cls, mask_pred):
542
+ # mask_pred is already processed to have the same shape as original input
543
+ image_size = mask_pred.shape[-2:]
544
+
545
+ # [Q, K]
546
+ scores = F.softmax(mask_cls, dim=-1)[:, :-1]
547
+ labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1)
548
+ # scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.num_queries, sorted=False)
549
+ scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.test_topk_per_image, sorted=False)
550
+ labels_per_image = labels[topk_indices]
551
+
552
+ topk_indices = topk_indices // self.sem_seg_head.num_classes
553
+ # mask_pred = mask_pred.unsqueeze(1).repeat(1, self.sem_seg_head.num_classes, 1).flatten(0, 1)
554
+ mask_pred = mask_pred[topk_indices]
555
+
556
+ # if this is panoptic segmentation, we only keep the "thing" classes
557
+ if self.panoptic_on:
558
+ keep = torch.zeros_like(scores_per_image).bool()
559
+ for i, lab in enumerate(labels_per_image):
560
+ keep[i] = lab in self.metadata.thing_dataset_id_to_contiguous_id.values()
561
+
562
+ scores_per_image = scores_per_image[keep]
563
+ labels_per_image = labels_per_image[keep]
564
+ mask_pred = mask_pred[keep]
565
+
566
+ result = Instances(image_size)
567
+ # mask (before sigmoid)
568
+ result.pred_masks = (mask_pred > 0).float()
569
+ result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
570
+ # Uncomment the following to get boxes from masks (this is slow)
571
+ # result.pred_boxes = BitMasks(mask_pred > 0).get_bounding_boxes()
572
+
573
+ # calculate average mask prob
574
+ mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / (result.pred_masks.flatten(1).sum(1) + 1e-6)
575
+ # pdb.set_trace()
576
+ result.scores = scores_per_image * mask_scores_per_image
577
+ result.pred_classes = labels_per_image
578
+ return result
579
+
580
+ def inference_whole_views(self, pred_cls, pred_masks, pred_cls_2d, pred_masks_2d, batched_inputs, image_sizes, crop_regions):
581
+ ## pred_masks: [100, 5, 800, 1216]
582
+ ## pred_masks_2d: [5, 100, 800, 1216]
583
+ scores = F.softmax(pred_cls, dim=-1)[:,:-1] # 100,1
584
+ scores_2d = F.softmax(pred_cls_2d, dim=-1)[:, :, :-1] # 5, 100, 1
585
+
586
+ # scores = (scores+scores_2d[0])/2
587
+ labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1)
588
+ ### keep all the indices
589
+ scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.num_queries, sorted=False)
590
+ labels_per_image = labels[topk_indices]
591
+ # topk_indices = topk_indices // self.sem_seg_head.num_classes
592
+ topk_indices = torch.div(topk_indices, self.sem_seg_head.num_classes, rounding_mode="trunc")
593
+ pred_masks = pred_masks[topk_indices]
594
+ pred_masks = pred_masks.permute((1,0,2,3))
595
+
596
+ new_pred_masks = []
597
+ for view_index, (pred_masks_per_view, batched_input_per_view, image_size_per_view) in enumerate(zip(pred_masks, batched_inputs, image_sizes)):
598
+ O_H = batched_input_per_view["height"]
599
+ O_W = batched_input_per_view["width"]
600
+
601
+ SO_H, SO_W = image_size_per_view
602
+
603
+ pred_masks_per_view = pred_masks_per_view[..., : SO_H, :SO_W]
604
+ pred_masks_per_view = F.interpolate(pred_masks_per_view[None], size=(O_H, O_W), mode="bilinear", align_corners=False)
605
+
606
+ new_pred_masks.append(pred_masks_per_view[0].sigmoid())
607
+
608
+ ## fuse the masks
609
+ full_image_masks = new_pred_masks[0]
610
+
611
+ ## fuse crop image
612
+ fused_image_masks = torch.zeros_like(full_image_masks).float()
613
+ fused_image_masks_valid = torch.zeros_like(full_image_masks).float() + 1e-16
614
+ for crop_region_per_view, pred_masks_per_view in zip(crop_regions, new_pred_masks[1:]):
615
+ x0, y0, x1, y1 = crop_region_per_view
616
+ fused_image_masks[..., y0:y1, x0:x1] += pred_masks_per_view
617
+ fused_image_masks_valid[..., y0:y1, x0:x1] += 1
618
+
619
+ # add original masks
620
+ fused_image_masks += full_image_masks
621
+ fused_image_masks_valid += 1
622
+
623
+ ## average
624
+ fuse_image_masks = fused_image_masks / fused_image_masks_valid
625
+
626
+ ###### back to the single fused image; run non-overlap suppression
627
+ ## ranks
628
+ pred_masks_logits = fuse_image_masks
629
+ pred_masks = (fuse_image_masks>0.5).float()
630
+ pred_scores = scores_per_image
631
+
632
+ _, m_H, m_W = pred_masks.shape
633
+ ## for visualization
634
+ mask_id = torch.zeros((m_H, m_W), dtype=torch.int).to(pred_masks.device)
635
+
636
+ # mask_id_colors = np.zeros((m_H, m_W, 3), dtype=np.uint8)
637
+ # pred_masks_np = pred_masks.cpu().numpy()
638
+
639
+ sorted_scores, ranks = torch.sort(pred_scores)
640
+ ranks = ranks + 1
641
+ for index in ranks:
642
+ mask_id[(pred_masks[index-1]==1)] = int(index)
643
+ # mask_id_colors[(pred_masks_np[index-1]==1)] = self.colors[index]
644
+ # base_path = "/group/20018/gavinqi/vis_entityv2_release_debug"
645
+ # pdb.set_trace()
646
+ # file_name = batched_inputs[0]["file_name"]
647
+ # split_index, img_name = file_name.split("/")[-2:]
648
+ # save_name = img_name.split(".")[0]+".png"
649
+ # if not os.path.exists(os.path.join(base_path, save_name)):
650
+ # cv2.imwrite(os.path.join(base_path, save_name), mask_id_colors)
651
+ # re-generate mask
652
+ new_scores = []
653
+ new_masks = []
654
+ new_masks_logits = []
655
+ entity_nums = len(ranks)
656
+ for ii in range(entity_nums):
657
+ index = int(ranks[entity_nums-ii-1])
658
+ score = sorted_scores[entity_nums-ii-1]
659
+ new_scores.append(score)
660
+ new_masks.append((mask_id==index).float())
661
+ new_masks_logits.append(pred_masks_logits[index-1])
662
+
663
+ new_scores = torch.stack(new_scores)
664
+ new_masks = torch.stack(new_masks)
665
+ new_masks_logits = torch.stack(new_masks_logits)
666
+ # make result
667
+ image_size = (batched_inputs[0]["height"], batched_inputs[0]["width"])
668
+ result = Instances(image_size)
669
+ # mask (before sigmoid)
670
+ result.pred_masks = new_masks
671
+ result.pred_boxes = Boxes(torch.zeros(new_masks.size(0), 4))
672
+ # Uncomment the following to get boxes from masks (this is slow)
673
+
674
+ # calculate average mask prob
675
+ mask_scores_per_image = (new_masks_logits.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / (result.pred_masks.flatten(1).sum(1) + 1e-6)
676
+ result.scores = new_scores * mask_scores_per_image
677
+ result.pred_classes = labels_per_image
678
+ return result
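The `instance_inference_nonoverlap` and `inference_whole_views` paths above share one idea: masks are painted into a single id map in ascending score order, so higher-scoring entities overwrite lower-scoring ones, and disjoint masks are then read back from that map. A minimal self-contained sketch of that suppression step (the function name, tensor sizes, and random inputs are illustrative, not taken from this repository):

import torch

def non_overlap_suppression(pred_masks, scores):
    # pred_masks: (N, H, W) binary float masks, scores: (N,) confidence per mask.
    N, H, W = pred_masks.shape
    mask_id = torch.zeros((H, W), dtype=torch.int)
    sorted_scores, order = torch.sort(scores)             # ascending score order
    for rank, index in enumerate(order, start=1):
        mask_id[pred_masks[index] == 1] = rank            # later (higher) scores overwrite
    new_masks, new_scores = [], []
    for rank in range(N, 0, -1):                          # read back, highest score first
        new_masks.append((mask_id == rank).float())
        new_scores.append(sorted_scores[rank - 1])
    return torch.stack(new_masks), torch.stack(new_scores)

masks = (torch.rand(5, 64, 64) > 0.5).float()
scores = torch.rand(5)
disjoint_masks, ordered_scores = non_overlap_suppression(masks, scores)
assert (disjoint_masks.sum(0) <= 1).all()                 # each pixel belongs to at most one mask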
annotator/entityseg/mask2former/data/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
annotator/entityseg/mask2former/data/dataset_mappers/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
annotator/entityseg/mask2former/data/dataset_mappers/crop_augmentations.py ADDED
@@ -0,0 +1,421 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ """
4
+ Implement many useful :class:`Augmentation`.
5
+ """
6
+ import numpy as np
7
+ import sys
8
+ from typing import Tuple
9
+ from PIL import Image
10
+ import random
11
+
12
+ from fvcore.transforms.transform import NoOpTransform, Transform
13
+
14
+ from detectron2.data.transforms.augmentation import Augmentation
15
+ import pdb
16
+ import math
17
+
18
+ import logging
19
+ import numpy as np
20
+ import pycocotools.mask as mask_util
21
+ import torch
22
+ from PIL import Image
23
+ from collections import defaultdict
24
+ import copy
25
+ from detectron2.data import transforms as T
26
+ from detectron2.structures import (
27
+ BitMasks,
28
+ Boxes,
29
+ BoxMode,
30
+ Instances,
31
+ Keypoints,
32
+ PolygonMasks,
33
+ RotatedBoxes,
34
+ polygons_to_bitmask,
35
+ )
36
+ from detectron2.utils.file_io import PathManager
37
+
38
+ __all__ = [
39
+ "BatchResizeShortestEdge",
40
+ "EntityCrop",
41
+ ]
42
+
43
+ class BatchResizeTransform(Transform):
44
+ """
45
+ Resize the image to a target size.
46
+ """
47
+
48
+ def __init__(self, h, w, new_h, new_w, interp=None):
49
+ """
50
+ Args:
51
+ h, w (int): original image size
52
+ new_h, new_w (int): new image size
53
+ interp: PIL interpolation methods, defaults to bilinear.
54
+ """
55
+ # TODO decide on PIL vs opencv
56
+ super().__init__()
57
+ if interp is None:
58
+ interp = Image.BILINEAR
59
+ self._set_attributes(locals())
60
+
61
+ def apply_image(self, imgs, interp=None):
62
+ dim_num = len(imgs.shape)
63
+ assert dim_num == 4
64
+ interp_method = interp if interp is not None else self.interp
65
+ resized_imgs = []
66
+ for img in imgs:
67
+ if len(img.shape) > 2 and img.shape[2] == 1:
68
+ pil_image = Image.fromarray(img[:, :, 0], mode="L")
69
+ else:
70
+ pil_image = Image.fromarray(img)
71
+ pil_image = pil_image.resize((self.new_w, self.new_h), interp_method)
72
+ ret = np.asarray(pil_image)
73
+ if len(img.shape) > 2 and img.shape[2] == 1:
74
+ ret = np.expand_dims(ret, -1)
75
+ resized_imgs.append(ret)
76
+ resized_imgs = np.stack(resized_imgs)
77
+ return resized_imgs
78
+
79
+ def apply_coords(self, coords):
80
+ coords[:, 0] = coords[:, 0] * (self.new_w * 1.0 / self.w)
81
+ coords[:, 1] = coords[:, 1] * (self.new_h * 1.0 / self.h)
82
+ return coords
83
+
84
+ def apply_box(self, boxes):
85
+ boxes = boxes[0]
86
+ new_boxes = super(BatchResizeTransform, self).apply_box(boxes[:,:4])
87
+ boxes[...,:4] = new_boxes
88
+ return boxes[None]
89
+
90
+ def apply_segmentation(self, segmentation):
91
+ if len(segmentation.shape)==3:
92
+ segmentation = segmentation[..., None]
93
+ segmentation = self.apply_image(segmentation, interp=Image.NEAREST)
94
+ segmentation = segmentation[..., 0]
95
+ else:
96
+ segmentation = self.apply_image(segmentation, interp=Image.NEAREST)
97
+ return segmentation
98
+
99
+ class EntityCropTransform(Transform):
100
+ """
101
+ Consecutively crop the images
102
+ """
103
+ def __init__(self, crop_axises, crop_indexes):
104
+ super().__init__()
105
+ self._set_attributes(locals())
106
+
107
+ def apply_image(self, img):
108
+ """
109
+ Args:
110
+ img (ndarray): of shape NxHxWxC, or HxWxC or HxW. The array can be
111
+ of type uint8 in range [0, 255], or floating point in range
112
+ [0, 1] or [0, 255]
113
+ returns:
114
+ ndarray: cropped images
115
+ """
116
+ dim_num = len(img.shape)
117
+ imgs = []
118
+
119
+ for crop_axis in self.crop_axises:
120
+ x0, y0, x1, y1 = crop_axis
121
+ if dim_num <= 3:
122
+ crop_img = img[y0:y1, x0:x1]
123
+ else:
124
+ crop_img = img[..., y0:y1, x0:x1, :]
125
+ imgs.append(crop_img)
126
+
127
+ if dim_num <= 3:
128
+ imgs = np.stack(imgs, axis=0)
129
+ else:
130
+ imgs = np.concatenate(imgs, axis=0)
131
+ return imgs
132
+
133
+ def apply_coords(self, coords: np.ndarray, x0, y0):
134
+ coords[:, 0] -= x0
135
+ coords[:, 1] -= y0
136
+ return coords
137
+
138
+ def apply_box(self, box: np.ndarray) -> np.ndarray:
139
+ """
140
+ box: Nx4, [x0, y0, x1, y1]
141
+ """
142
+ idxs = np.array([(0, 1), (2, 1), (0, 3), (2, 3)]).flatten()
143
+ coords = np.asarray(box).reshape(-1, 4)[:, idxs].reshape(-1, 2)
144
+ split_boxes = []
145
+ crop_ws, crop_hs = [], []
146
+ for crop_axis in self.crop_axises:
147
+ startw, starth, endw, endh = crop_axis
148
+ coords_new = self.apply_coords(copy.deepcopy(coords), startw, starth).reshape((-1, 4, 2))
149
+ minxy = coords_new.min(axis=1)
150
+ maxxy = coords_new.max(axis=1)
151
+ trans_boxes = np.concatenate((minxy, maxxy), axis=1)
152
+
153
+ crop_ws.append(endw-startw)
154
+ crop_hs.append(endh-starth)
155
+ split_boxes.append(trans_boxes)
156
+ split_boxes = np.stack(split_boxes, axis=1)
157
+ ### clip to the image boundary
158
+ ## assert each crop size is equal
159
+ for crop_index, (crop_w, crop_h) in enumerate(zip(crop_ws, crop_hs)):
160
+ assert crop_w == crop_ws[0], "crop width is not equal, crop_{}: {}, crop_0: {}".format(crop_index, crop_w, crop_ws[0])
161
+ assert crop_h == crop_hs[0], "crop height is not equal, crop_{}: {}, crop_0: {}".format(crop_index, crop_h, crop_hs[0])
162
+ crop_w = crop_ws[0]
163
+ crop_h = crop_hs[0]
164
+ # pdb.set_trace()
165
+ split_boxes[...,0::2] = np.clip(split_boxes[...,0::2], 0, crop_w)
166
+ split_boxes[...,1::2] = np.clip(split_boxes[...,1::2], 0, crop_h)
167
+ valid_inds = (split_boxes[...,2]>split_boxes[...,0]) & (split_boxes[...,3]>split_boxes[...,1])
168
+ split_infos = np.concatenate((split_boxes, valid_inds[...,None]), axis=-1)
169
+ return split_infos
170
+
171
+ class BatchResizeShortestEdge(Augmentation):
172
+ """
173
+ Scale the shorter edge to the given size, with a limit of `max_size` on the longer edge.
174
+ If `max_size` is reached, then downscale so that the longer edge does not exceed max_size.
175
+ """
176
+
177
+ def __init__(
178
+ self, short_edge_length, max_size=sys.maxsize, sample_style="range", interp=Image.BILINEAR
179
+ ):
180
+ """
181
+ Args:
182
+ short_edge_length (list[int]): If ``sample_style=="range"``,
183
+ a [min, max] interval from which to sample the shortest edge length.
184
+ If ``sample_style=="choice"``, a list of shortest edge lengths to sample from.
185
+ max_size (int): maximum allowed longest edge length.
186
+ sample_style (str): either "range" or "choice".
187
+ """
188
+ super().__init__()
189
+ assert sample_style in ["range", "choice"], sample_style
190
+
191
+ self.is_range = sample_style == "range"
192
+ if isinstance(short_edge_length, int):
193
+ short_edge_length = (short_edge_length, short_edge_length)
194
+ if self.is_range:
195
+ assert len(short_edge_length) == 2, (
196
+ "short_edge_length must be two values using 'range' sample style."
197
+ f" Got {short_edge_length}!"
198
+ )
199
+ self._init(locals())
200
+
201
+ def get_transform(self, image):
202
+ dim_num = len(image.shape)
203
+ assert dim_num == 4, "the tensor should be in [B, H, W, C]"
204
+ h, w = image.shape[1:3]
205
+ if self.is_range:
206
+ size = np.random.randint(self.short_edge_length[0], self.short_edge_length[1] + 1)
207
+ else:
208
+ size = np.random.choice(self.short_edge_length)
209
+ if size == 0:
210
+ return NoOpTransform()
211
+
212
+ scale = size * 1.0 / min(h, w)
213
+ if h < w:
214
+ newh, neww = size, scale * w
215
+ else:
216
+ newh, neww = scale * h, size
217
+ if max(newh, neww) > self.max_size:
218
+ scale = self.max_size * 1.0 / max(newh, neww)
219
+ newh = newh * scale
220
+ neww = neww * scale
221
+ neww = int(neww + 0.5)
222
+ newh = int(newh + 0.5)
223
+ return BatchResizeTransform(h, w, newh, neww, self.interp)
224
+
225
+ class EntityCrop(Augmentation):
226
+ def __init__(self, crop_ratio, stride_ratio, sample_num, is_train):
227
+ super().__init__()
228
+ self._init(locals())
229
+
230
+ def get_transform(self, image):
231
+ h, w = image.shape[:2]
232
+ crop_axises, crop_indexes = self.get_crop_axises((h, w))
233
+ transform = EntityCropTransform(crop_axises, crop_indexes)
234
+ return transform
235
+
236
+ def get_crop_axises(self, image_size):
237
+ h, w = image_size
238
+ crop_w = int(self.crop_ratio*w)
239
+ crop_h = int(self.crop_ratio*h)
240
+ # if self.is_train:
241
+ stride_w = int(self.stride_ratio*w)
242
+ stride_h = int(self.stride_ratio*h)
243
+ # pdb.set_trace()
244
+
245
+ crop_axises = []
246
+ for starth in range(0, h, stride_h):
247
+ for startw in range(0, w, stride_w):
248
+ endh = min(starth+crop_h, h)
249
+ endw = min(startw+crop_w, w)
250
+ starth = int(endh-crop_h)
251
+ startw = int(endw-crop_w)
252
+ crop_axises.append([startw, starth, endw, endh])
253
+ if self.is_train:
254
+ crop_indexes = random.sample([i for i in range(len(crop_axises))], self.sample_num)
255
+ crop_axises = [crop_axises[i] for i in crop_indexes]
256
+ else:
257
+ crop_indexes = [i for i in range(self.sample_num)]
258
+ # left_upper = [0, 0, crop_w, crop_h]
259
+ # right_upper = [w-crop_w, 0, w, crop_h]
260
+ # left_bottom = [0, h-crop_h, crop_w, h]
261
+ # right_bottom = [w-crop_w, h-crop_h, w, h]
262
+
263
+ # crop_axises = [left_upper, right_upper, left_bottom, right_bottom]
264
+ # crop_indexes = [0,1,2,3]
265
+ assert len(crop_axises)==len(crop_indexes)
266
+ return crop_axises, crop_indexes
267
+
268
+ def transform_instance_annotations_crop(
269
+ annotation, transforms, image_size, *, keypoint_hflip_indices=None
270
+ ):
271
+ """
272
+ Apply transforms to box, segmentation and keypoints annotations of a single instance.
273
+
274
+ It will use `transforms.apply_box` for the box, and
275
+ `transforms.apply_coords` for segmentation polygons & keypoints.
276
+ If you need anything more specially designed for each data structure,
277
+ you'll need to implement your own version of this function or the transforms.
278
+
279
+ Args:
280
+ annotation (dict): dict of instance annotations for a single instance.
281
+ It will be modified in-place.
282
+ transforms (TransformList or list[Transform]):
283
+ image_size (tuple): the height, width of the transformed image
284
+ keypoint_hflip_indices (ndarray[int]): see `create_keypoint_hflip_indices`.
285
+
286
+ Returns:
287
+ dict:
288
+ the same input dict with fields "bbox", "segmentation", "keypoints"
289
+ transformed according to `transforms`.
290
+ The "bbox_mode" field will be set to XYXY_ABS.
291
+ """
292
+ if isinstance(transforms, (tuple, list)):
293
+ transforms = T.TransformList(transforms)
294
+ # bbox is 1d (per-instance bounding box)
295
+ bbox = BoxMode.convert(annotation["bbox"], annotation["bbox_mode"], BoxMode.XYXY_ABS)
296
+
297
+ # clip transformed bbox to image size
298
+ bboxes_info = transforms.apply_box(np.array([bbox]))[0].clip(min=0)
299
+ annotation["bbox"] = np.minimum(bbox, list(image_size + image_size)[::-1])
300
+ annotation["bbox"] = bboxes_info[...,:4]
301
+ annotation["bbox_mode"] = BoxMode.XYXY_ABS
302
+ annotation["bbox_valid"] = bboxes_info[...,4]
303
+ for transform_type in transforms:
304
+ if isinstance(transform_type, EntityCropTransform):
305
+ annotation["crop_axises"] = transform_type.crop_axises
306
+ annotation["crop_indexes"] = transform_type.crop_indexes
307
+
308
+ if "segmentation" in annotation:
309
+ segm = annotation["segmentation"]
310
+ assert isinstance(segm, dict), "requiring segmentation encoding -> RLE"
311
+ # RLE
312
+ mask = mask_util.decode(segm)
313
+ mask = transforms.apply_segmentation(mask)
314
+ annotation["segmentation"] = mask
315
+ return annotation
316
+
317
+ def annotations_to_instances_crop(annos, image_size, mask_format="polygon", return_indexes=False):
318
+ """
319
+ Create an :class:`Instances` object used by the models,
320
+ from instance annotations in the dataset dict.
321
+
322
+ Args:
323
+ annos (list[dict]): a list of instance annotations in one image, each
324
+ element for one instance.
325
+ image_size (tuple): height, width
326
+
327
+ Returns:
328
+ Instances:
329
+ It will contain fields "gt_boxes", "gt_classes",
330
+ "gt_masks", "gt_keypoints", if they can be obtained from `annos`.
331
+ This is the format that builtin models expect.
332
+ """
333
+ ###
334
+ all_boxes = []
335
+ all_boxes_valid = []
336
+ all_classes = []
337
+ all_segmentations = []
338
+ all_iscrowds = []
339
+ # pdb.set_trace()
340
+ annos_num = len(annos)
341
+ patches_num = len(annos[0]["bbox"])
342
+ for ann_index, obj in enumerate(annos):
343
+ for split_index in range(len(obj["bbox"])):
344
+ all_boxes.append(BoxMode.convert(obj["bbox"][split_index], obj["bbox_mode"], BoxMode.XYXY_ABS))
345
+ all_boxes_valid.append(obj["bbox_valid"][split_index])
346
+ all_classes.append(obj["category_id"])
347
+ all_segmentations.append(obj["segmentation"][split_index])
348
+ all_iscrowds.append(obj["iscrowd"])
349
+ # print("ann_index:{}, split_index:{}".format(ann_index, split_index))
350
+
351
+ new_targets = []
352
+ crop_axises = annos[0]["crop_axises"]
353
+ # pdb.set_trace()
354
+ crop_size = (crop_axises[0][3], crop_axises[0][2])
355
+ crop_axises = torch.tensor(crop_axises)
356
+
357
+ for split_index in range(patches_num):
358
+ new_targets.append(Instances(crop_size))
359
+ # pdb.set_trace()
360
+ ## boxes
361
+ new_targets[-1].gt_boxes = Boxes(all_boxes[split_index::patches_num])
362
+ new_targets[-1].gt_boxes_valid = torch.tensor(all_boxes_valid[split_index::patches_num], dtype=torch.int64)
363
+ ## categories
364
+ new_targets[-1].gt_classes = torch.tensor(all_classes[split_index::patches_num], dtype=torch.int64)
365
+
366
+ ## masks
367
+ if "segmentation" in annos[0]:
368
+ new_targets[-1].gt_masks = BitMasks(torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in all_segmentations[split_index::patches_num]]))
369
+
370
+ # pdb.set_trace()
371
+ if return_indexes:
372
+ return new_targets, crop_axises, annos[0]["crop_indexes"]
373
+ else:
374
+ return new_targets, crop_axises
375
+
376
+ class EntityCascadedCrop(Augmentation):
377
+ def __init__(self, crop_ratio, stride_ratio, sample_num, cascade_num, is_train):
378
+ super().__init__()
379
+ self._init(locals())
380
+
381
+ def get_transform(self, image):
382
+ h, w = image.shape[:2]
383
+ crop_axises, crop_indexes = self.get_crop_axises((h, w))
384
+ transform = EntityCropTransform(crop_axises, crop_indexes)
385
+ return transform
386
+
387
+ def get_crop_axises(self, image_size):
388
+ h, w = image_size
389
+ # for i in range(self.cascade_num):
390
+ # crop_w = int((self.crop_ratio**(i+1))*w)
391
+ # crop_h = int((self.crop_ratio**(i+1))*h)
392
+ # stride_w = int((self.stride_ratio**(i+1))*w)
393
+ # stride_h = int((self.stride_ratio**(i+1))*h)
394
+ # crop_axises = []
395
+ # if i==0:
396
+
397
+ # for starth in range(0, )
398
+
399
+
400
+ crop_axises = []
401
+ for starth in range(0, h, stride_h):
402
+ for startw in range(0, w, stride_w):
403
+ endh = min(starth+crop_h, h)
404
+ endw = min(startw+crop_w, w)
405
+ starth = int(endh-crop_h)
406
+ startw = int(endw-crop_w)
407
+ crop_axises.append([startw, starth, endw, endh])
408
+ if self.is_train:
409
+ crop_indexes = random.sample([i for i in range(len(crop_axises))], self.sample_num)
410
+ crop_axises = [crop_axises[i] for i in crop_indexes]
411
+ else:
412
+ crop_indexes = [i for i in range(self.sample_num)]
413
+ # left_upper = [0, 0, crop_w, crop_h]
414
+ # right_upper = [w-crop_w, 0, w, crop_h]
415
+ # left_bottom = [0, h-crop_h, crop_w, h]
416
+ # right_bottom = [w-crop_w, h-crop_h, w, h]
417
+
418
+ # crop_axises = [left_upper, right_upper, left_bottom, right_bottom]
419
+ # crop_indexes = [0,1,2,3]
420
+ assert len(crop_axises)==len(crop_indexes)
421
+ return crop_axises, crop_indexes
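To make the crop layout above concrete: `get_crop_axises` walks a stride grid and shifts any window that would cross the image border back inside, so every crop keeps the full crop size. A standalone sketch of that enumeration (the 0.5/0.25 ratios and the 8x8 image are illustrative values, not defaults taken from this configuration):

def sliding_crop_boxes(h, w, crop_ratio=0.5, stride_ratio=0.25):
    # Enumerate [x0, y0, x1, y1] windows; border-crossing windows are shifted back
    # so each one is exactly crop_h x crop_w and lies inside the image.
    crop_h, crop_w = int(crop_ratio * h), int(crop_ratio * w)
    stride_h, stride_w = int(stride_ratio * h), int(stride_ratio * w)
    boxes = []
    for starth in range(0, h, stride_h):
        for startw in range(0, w, stride_w):
            endh, endw = min(starth + crop_h, h), min(startw + crop_w, w)
            boxes.append([endw - crop_w, endh - crop_h, endw, endh])
    return boxes

print(sliding_crop_boxes(8, 8))  # 16 boxes, each 4x4, all inside the 8x8 image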
annotator/entityseg/mask2former/maskformer_model.py ADDED
@@ -0,0 +1,446 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ from typing import Tuple
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+ from detectron2.config import configurable
9
+ from detectron2.data import MetadataCatalog
10
+ from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, build_sem_seg_head
11
+ from detectron2.modeling.backbone import Backbone
12
+ from detectron2.modeling.postprocessing import sem_seg_postprocess
13
+ from detectron2.structures import Boxes, ImageList, Instances, BitMasks
14
+ from detectron2.utils.memory import retry_if_cuda_oom
15
+
16
+ from .modeling.criterion import SetCriterion
17
+ from .modeling.matcher import HungarianMatcher
18
+
19
+
20
+ @META_ARCH_REGISTRY.register()
21
+ class MaskFormer(nn.Module):
22
+ """
23
+ Main class for mask classification semantic segmentation architectures.
24
+ """
25
+
26
+ @configurable
27
+ def __init__(
28
+ self,
29
+ *,
30
+ cfg,
31
+ backbone: Backbone,
32
+ sem_seg_head: nn.Module,
33
+ criterion: nn.Module,
34
+ num_queries: int,
35
+ object_mask_threshold: float,
36
+ overlap_threshold: float,
37
+ metadata,
38
+ size_divisibility: int,
39
+ sem_seg_postprocess_before_inference: bool,
40
+ pixel_mean: Tuple[float],
41
+ pixel_std: Tuple[float],
42
+ # inference
43
+ semantic_on: bool,
44
+ panoptic_on: bool,
45
+ instance_on: bool,
46
+ test_topk_per_image: int,
47
+ ):
48
+ """
49
+ Args:
50
+ backbone: a backbone module, must follow detectron2's backbone interface
51
+ sem_seg_head: a module that predicts semantic segmentation from backbone features
52
+ criterion: a module that defines the loss
53
+ num_queries: int, number of queries
54
+ object_mask_threshold: float, threshold to filter query based on classification score
55
+ for panoptic segmentation inference
56
+ overlap_threshold: overlap threshold used in general inference for panoptic segmentation
57
+ metadata: dataset meta, get `thing` and `stuff` category names for panoptic
58
+ segmentation inference
59
+ size_divisibility: Some backbones require the input height and width to be divisible by a
60
+ specific integer. We can use this to override such requirement.
61
+ sem_seg_postprocess_before_inference: whether to resize the prediction back
62
+ to original input size before semantic segmentation inference or after.
63
+ For high-resolution dataset like Mapillary, resizing predictions before
64
+ inference will cause OOM error.
65
+ pixel_mean, pixel_std: list or tuple with #channels element, representing
66
+ the per-channel mean and std to be used to normalize the input image
67
+ semantic_on: bool, whether to output semantic segmentation prediction
68
+ instance_on: bool, whether to output instance segmentation prediction
69
+ panoptic_on: bool, whether to output panoptic segmentation prediction
70
+ test_topk_per_image: int, instance segmentation parameter, keep topk instances per image
71
+ """
72
+ super().__init__()
73
+ self.cfg = cfg
74
+ self.backbone = backbone
75
+ self.sem_seg_head = sem_seg_head
76
+ self.criterion = criterion
77
+ self.num_queries = num_queries
78
+ self.overlap_threshold = overlap_threshold
79
+ self.entity_enable = self.cfg.ENTITY.ENABLE
80
+ self.object_mask_threshold = object_mask_threshold
81
+ self.metadata = metadata
82
+ if size_divisibility < 0:
83
+ # use backbone size_divisibility if not set
84
+ size_divisibility = self.backbone.size_divisibility
85
+ self.size_divisibility = size_divisibility
86
+ self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference
87
+ self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
88
+ self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
89
+
90
+ # additional args
91
+ self.semantic_on = semantic_on
92
+ self.instance_on = instance_on
93
+ self.panoptic_on = panoptic_on
94
+ self.test_topk_per_image = test_topk_per_image
95
+
96
+ if not self.semantic_on:
97
+ assert self.sem_seg_postprocess_before_inference
98
+
99
+ @classmethod
100
+ def from_config(cls, cfg):
101
+ backbone = build_backbone(cfg)
102
+ sem_seg_head = build_sem_seg_head(cfg, backbone.output_shape())
103
+
104
+ # Loss parameters:
105
+ deep_supervision = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION
106
+ no_object_weight = cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT
107
+
108
+ # loss weights
109
+ class_weight = cfg.MODEL.MASK_FORMER.CLASS_WEIGHT
110
+ dice_weight = cfg.MODEL.MASK_FORMER.DICE_WEIGHT
111
+ mask_weight = cfg.MODEL.MASK_FORMER.MASK_WEIGHT
112
+
113
+ # building criterion
114
+ matcher = HungarianMatcher(
115
+ cost_class=class_weight,
116
+ cost_mask=mask_weight,
117
+ cost_dice=dice_weight,
118
+ num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
119
+ )
120
+
121
+ weight_dict = {"loss_ce": class_weight, "loss_mask": mask_weight, "loss_dice": dice_weight}
122
+
123
+ if deep_supervision:
124
+ dec_layers = cfg.MODEL.MASK_FORMER.DEC_LAYERS
125
+ aux_weight_dict = {}
126
+ for i in range(dec_layers - 1):
127
+ aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
128
+ weight_dict.update(aux_weight_dict)
129
+
130
+ losses = ["labels", "masks"]
131
+
132
+ criterion = SetCriterion(
133
+ sem_seg_head.num_classes,
134
+ matcher=matcher,
135
+ weight_dict=weight_dict,
136
+ eos_coef=no_object_weight,
137
+ losses=losses,
138
+ num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
139
+ oversample_ratio=cfg.MODEL.MASK_FORMER.OVERSAMPLE_RATIO,
140
+ importance_sample_ratio=cfg.MODEL.MASK_FORMER.IMPORTANCE_SAMPLE_RATIO,
141
+ )
142
+
143
+ return {
144
+ "cfg": cfg,
145
+ "backbone": backbone,
146
+ "sem_seg_head": sem_seg_head,
147
+ "criterion": criterion,
148
+ "num_queries": cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES,
149
+ "object_mask_threshold": cfg.MODEL.MASK_FORMER.TEST.OBJECT_MASK_THRESHOLD,
150
+ "overlap_threshold": cfg.MODEL.MASK_FORMER.TEST.OVERLAP_THRESHOLD,
151
+ "metadata": MetadataCatalog.get(cfg.DATASETS.TRAIN[0]),
152
+ "size_divisibility": cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY,
153
+ "sem_seg_postprocess_before_inference": (
154
+ cfg.MODEL.MASK_FORMER.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE
155
+ or cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON
156
+ or cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON
157
+ ),
158
+ "pixel_mean": cfg.MODEL.PIXEL_MEAN,
159
+ "pixel_std": cfg.MODEL.PIXEL_STD,
160
+ # inference
161
+ "semantic_on": cfg.MODEL.MASK_FORMER.TEST.SEMANTIC_ON,
162
+ "instance_on": cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON,
163
+ "panoptic_on": cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON,
164
+ "test_topk_per_image": cfg.TEST.DETECTIONS_PER_IMAGE,
165
+ }
166
+
167
+ @property
168
+ def device(self):
169
+ return self.pixel_mean.device
170
+
171
+ def forward(self, batched_inputs):
172
+ """
173
+ Args:
174
+ batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
175
+ Each item in the list contains the inputs for one image.
176
+ For now, each item in the list is a dict that contains:
177
+ * "image": Tensor, image in (C, H, W) format.
178
+ * "instances": per-region ground truth
179
+ * Other information that's included in the original dicts, such as:
180
+ "height", "width" (int): the output resolution of the model (may be different
181
+ from input resolution), used in inference.
182
+ Returns:
183
+ list[dict]:
184
+ each dict has the results for one image. The dict contains the following keys:
185
+
186
+ * "sem_seg":
187
+ A Tensor that represents the
188
+ per-pixel segmentation predicted by the head.
189
+ The prediction has shape KxHxW that represents the logits of
190
+ each class for each pixel.
191
+ * "panoptic_seg":
192
+ A tuple that represents the panoptic output
193
+ panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment.
194
+ segments_info (list[dict]): Describe each segment in `panoptic_seg`.
195
+ Each dict contains keys "id", "category_id", "isthing".
196
+ """
197
+ images = [x["image"].to(self.device) for x in batched_inputs]
198
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
199
+ images = ImageList.from_tensors(images, self.size_divisibility)
200
+
201
+ features = self.backbone(images.tensor)
202
+ outputs = self.sem_seg_head(features)
203
+
204
+ if self.training:
205
+ # mask classification target
206
+ if "instances" in batched_inputs[0]:
207
+ if self.cfg.ENTITY.ENABLE:
208
+ for i in range(len(batched_inputs)):
209
+ batched_inputs[i]["instances"].gt_classes[:] = 0
210
+ gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
211
+ targets = self.prepare_targets(gt_instances, images)
212
+ else:
213
+ targets = None
214
+
215
+ # bipartite matching-based loss
216
+ losses = self.criterion(outputs, targets)
217
+
218
+ for k in list(losses.keys()):
219
+ if k in self.criterion.weight_dict:
220
+ losses[k] *= self.criterion.weight_dict[k]
221
+ else:
222
+ # remove this loss if not specified in `weight_dict`
223
+ losses.pop(k)
224
+ return losses
225
+ else:
226
+ mask_cls_results = outputs["pred_logits"]
227
+ mask_pred_results = outputs["pred_masks"]
228
+ # upsample masks
229
+ mask_pred_results = F.interpolate(
230
+ mask_pred_results,
231
+ size=(images.tensor.shape[-2], images.tensor.shape[-1]),
232
+ mode="bilinear",
233
+ align_corners=False,
234
+ )
235
+
236
+ del outputs
237
+
238
+ processed_results = []
239
+ for mask_cls_result, mask_pred_result, input_per_image, image_size in zip(
240
+ mask_cls_results, mask_pred_results, batched_inputs, images.image_sizes
241
+ ):
242
+ height = input_per_image.get("height", image_size[0])
243
+ width = input_per_image.get("width", image_size[1])
244
+ processed_results.append({})
245
+
246
+ if self.sem_seg_postprocess_before_inference:
247
+ mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
248
+ mask_pred_result, image_size, height, width
249
+ )
250
+ mask_cls_result = mask_cls_result.to(mask_pred_result)
251
+
252
+ # semantic segmentation inference
253
+ if self.semantic_on:
254
+ r = retry_if_cuda_oom(self.semantic_inference)(mask_cls_result, mask_pred_result)
255
+ if not self.sem_seg_postprocess_before_inference:
256
+ r = retry_if_cuda_oom(sem_seg_postprocess)(r, image_size, height, width)
257
+ processed_results[-1]["sem_seg"] = r
258
+
259
+ # panoptic segmentation inference
260
+ if self.panoptic_on:
261
+ panoptic_r = retry_if_cuda_oom(self.panoptic_inference)(mask_cls_result, mask_pred_result)
262
+ processed_results[-1]["panoptic_seg"] = panoptic_r
263
+
264
+ # instance segmentation and entity segmentation inference
265
+ if self.instance_on and self.cfg.ENTITY.ENABLE:
266
+ instance_r = retry_if_cuda_oom(self.instance_inference_nonoverlap)(mask_cls_result, mask_pred_result)
267
+ processed_results[-1]["instances"] = instance_r
268
+ else:
269
+ instance_r = retry_if_cuda_oom(self.instance_inference)(mask_cls_result, mask_pred_result)
270
+ processed_results[-1]["instances"] = instance_r
271
+
272
+ return processed_results
273
+
274
+ def prepare_targets(self, targets, images):
275
+ h_pad, w_pad = images.tensor.shape[-2:]
276
+ new_targets = []
277
+ for targets_per_image in targets:
278
+ # pad gt
279
+ gt_masks = targets_per_image.gt_masks
280
+ padded_masks = torch.zeros((gt_masks.shape[0], h_pad, w_pad), dtype=gt_masks.dtype, device=gt_masks.device)
281
+ padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks
282
+ new_targets.append(
283
+ {
284
+ "labels": targets_per_image.gt_classes,
285
+ "masks": padded_masks,
286
+ }
287
+ )
288
+ return new_targets
289
+
290
+ def semantic_inference(self, mask_cls, mask_pred):
291
+ mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
292
+ mask_pred = mask_pred.sigmoid()
293
+ semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
294
+ return semseg
295
+
296
+ def panoptic_inference(self, mask_cls, mask_pred):
297
+ scores, labels = F.softmax(mask_cls, dim=-1).max(-1)
298
+ mask_pred = mask_pred.sigmoid()
299
+
300
+ keep = labels.ne(self.sem_seg_head.num_classes) & (scores > self.object_mask_threshold)
301
+ cur_scores = scores[keep]
302
+ cur_classes = labels[keep]
303
+ cur_masks = mask_pred[keep]
304
+ cur_mask_cls = mask_cls[keep]
305
+ cur_mask_cls = cur_mask_cls[:, :-1]
306
+
307
+ cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks
308
+
309
+ h, w = cur_masks.shape[-2:]
310
+ panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device)
311
+ segments_info = []
312
+
313
+ current_segment_id = 0
314
+
315
+ if cur_masks.shape[0] == 0:
316
+ # We didn't detect any mask :(
317
+ return panoptic_seg, segments_info
318
+ else:
319
+ # take argmax
320
+ cur_mask_ids = cur_prob_masks.argmax(0)
321
+ stuff_memory_list = {}
322
+ for k in range(cur_classes.shape[0]):
323
+ pred_class = cur_classes[k].item()
324
+ isthing = pred_class in self.metadata.thing_dataset_id_to_contiguous_id.values()
325
+ mask_area = (cur_mask_ids == k).sum().item()
326
+ original_area = (cur_masks[k] >= 0.5).sum().item()
327
+ mask = (cur_mask_ids == k) & (cur_masks[k] >= 0.5)
328
+
329
+ if mask_area > 0 and original_area > 0 and mask.sum().item() > 0:
330
+ if mask_area / original_area < self.overlap_threshold:
331
+ continue
332
+
333
+ # merge stuff regions
334
+ if not isthing:
335
+ if int(pred_class) in stuff_memory_list.keys():
336
+ panoptic_seg[mask] = stuff_memory_list[int(pred_class)]
337
+ continue
338
+ else:
339
+ stuff_memory_list[int(pred_class)] = current_segment_id + 1
340
+
341
+ current_segment_id += 1
342
+ panoptic_seg[mask] = current_segment_id
343
+
344
+ segments_info.append(
345
+ {
346
+ "id": current_segment_id,
347
+ "isthing": bool(isthing),
348
+ "category_id": int(pred_class),
349
+ }
350
+ )
351
+
352
+ return panoptic_seg, segments_info
353
+
354
+ def instance_inference(self, mask_cls, mask_pred):
355
+ # mask_pred is already processed to have the same shape as original input
356
+ image_size = mask_pred.shape[-2:]
357
+
358
+ # [Q, K]
359
+ scores = F.softmax(mask_cls, dim=-1)[:, :-1]
360
+ labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1)
361
+ # scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.num_queries, sorted=False)
362
+ scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.test_topk_per_image, sorted=False)
363
+ labels_per_image = labels[topk_indices]
364
+
365
+ # topk_indices = topk_indices // self.sem_seg_head.num_classes
366
+ topk_indices = torch.div(topk_indices, self.sem_seg_head.num_classes, rounding_mode='trunc')
367
+ # mask_pred = mask_pred.unsqueeze(1).repeat(1, self.sem_seg_head.num_classes, 1).flatten(0, 1)
368
+ mask_pred = mask_pred[topk_indices]
369
+
370
+ # if this is panoptic segmentation, we only keep the "thing" classes
371
+ if self.panoptic_on:
372
+ keep = torch.zeros_like(scores_per_image).bool()
373
+ for i, lab in enumerate(labels_per_image):
374
+ keep[i] = lab in self.metadata.thing_dataset_id_to_contiguous_id.values()
375
+
376
+ scores_per_image = scores_per_image[keep]
377
+ labels_per_image = labels_per_image[keep]
378
+ mask_pred = mask_pred[keep]
379
+
380
+ result = Instances(image_size)
381
+ # mask (before sigmoid)
382
+ result.pred_masks = (mask_pred > 0).float()
383
+ result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
384
+ # Uncomment the following to get boxes from masks (this is slow)
385
+ # result.pred_boxes = BitMasks(mask_pred > 0).get_bounding_boxes()
386
+
387
+ # calculate average mask prob
388
+ mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / (result.pred_masks.flatten(1).sum(1) + 1e-6)
389
+ result.scores = scores_per_image * mask_scores_per_image
390
+ result.pred_classes = labels_per_image
391
+ return result
392
+
393
+ def instance_inference_nonoverlap(self, mask_cls, mask_pred):
394
+ # mask_pred is already processed to have the same shape as original input
395
+ image_size = mask_pred.shape[-2:]
396
+
397
+ # [Q, K]
398
+ scores = F.softmax(mask_cls, dim=-1)[:, :-1]
399
+ labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1)
400
+ # scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.num_queries, sorted=False)
401
+ scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.test_topk_per_image, sorted=False)
402
+ labels_per_image = labels[topk_indices]
403
+
404
+ # topk_indices = topk_indices // self.sem_seg_head.num_classes
405
+ topk_indices = torch.div(topk_indices, self.sem_seg_head.num_classes, rounding_mode='trunc')
406
+ # mask_pred = mask_pred.unsqueeze(1).repeat(1, self.sem_seg_head.num_classes, 1).flatten(0, 1)
407
+ mask_pred = mask_pred[topk_indices]
408
+
409
+ ###### ranks
410
+ pred_masks = (mask_pred>0).float()
411
+ pred_masks_logits = mask_pred.sigmoid()
412
+ pred_scores = scores_per_image
413
+
414
+ _, m_H, m_W = pred_masks.shape
415
+ mask_id = torch.zeros((m_H, m_W), dtype=torch.int).to(pred_masks.device)
416
+ sorted_scores, ranks = torch.sort(pred_scores)
417
+ ranks = ranks + 1
418
+ for index in ranks:
419
+ mask_id[(pred_masks[index-1]==1)] = int(index)
420
+ # re-generate mask
421
+ new_scores = []
422
+ new_masks = []
423
+ new_masks_logits = []
424
+ entity_nums = len(ranks)
425
+ for ii in range(entity_nums):
426
+ index = int(ranks[entity_nums-ii-1])
427
+ score = sorted_scores[entity_nums-ii-1]
428
+ new_scores.append(score)
429
+ new_masks.append((mask_id==index).float())
430
+ new_masks_logits.append(pred_masks_logits[index-1])
431
+
432
+ new_scores = torch.stack(new_scores)
433
+ new_masks = torch.stack(new_masks)
434
+ new_masks_logits = torch.stack(new_masks_logits)
435
+
436
+ result = Instances(image_size)
437
+ # mask (before sigmoid)
438
+ result.pred_masks = new_masks
439
+ result.pred_boxes = Boxes(torch.zeros(new_masks.size(0), 4))
440
+ # Uncomment the following to get boxes from masks (this is slow)
441
+
442
+ # calculate average mask prob
443
+ mask_scores_per_image = (new_masks_logits.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / (result.pred_masks.flatten(1).sum(1) + 1e-6)
444
+ result.scores = new_scores * mask_scores_per_image
445
+ result.pred_classes = labels_per_image
446
+ return result
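The scoring in `instance_inference` combines the softmax confidence of the top-k (query, class) pairs with the mean mask probability inside each binarized mask. A self-contained sketch of that rescoring on dummy tensors (Q, K, topk and the random inputs are illustrative, not values from this configuration):

import torch
import torch.nn.functional as F

Q, K, H, W, topk = 100, 1, 32, 32, 10
mask_cls = torch.randn(Q, K + 1)                          # last column is the "no object" class
mask_pred = torch.randn(Q, H, W)                          # mask logits

scores = F.softmax(mask_cls, dim=-1)[:, :-1]              # (Q, K) class probabilities
labels = torch.arange(K).unsqueeze(0).repeat(Q, 1).flatten(0, 1)
scores_per_image, topk_indices = scores.flatten(0, 1).topk(topk, sorted=False)
labels_per_image = labels[topk_indices]
query_indices = torch.div(topk_indices, K, rounding_mode="trunc")
mask_pred = mask_pred[query_indices]                      # (topk, H, W)

binary_masks = (mask_pred > 0).float()
mask_probs = (mask_pred.sigmoid().flatten(1) * binary_masks.flatten(1)).sum(1) / (
    binary_masks.flatten(1).sum(1) + 1e-6)                # mean foreground probability per mask
final_scores = scores_per_image * mask_probs              # (topk,)
print(final_scores.shape, labels_per_image.shape)         # torch.Size([10]) torch.Size([10])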
annotator/entityseg/mask2former/modeling/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ from .backbone.swin import D2SwinTransformer
3
+ from .backbone.hornet import D2HorNet
4
+ from .pixel_decoder.fpn import BasePixelDecoder
5
+ from .pixel_decoder.msdeformattn import MSDeformAttnPixelDecoder
6
+ from .meta_arch.mask_former_head import MaskFormerHead
7
+ from .meta_arch.per_pixel_baseline import PerPixelBaselineHead, PerPixelBaselinePlusHead
annotator/entityseg/mask2former/modeling/backbone/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
annotator/entityseg/mask2former/modeling/backbone/hornet.py ADDED
@@ -0,0 +1,363 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ from functools import partial
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from timm.models.layers import trunc_normal_, DropPath
13
+ from timm.models.registry import register_model
14
+ import os
15
+ import sys
16
+ import torch.fft
17
+ import math
18
+
19
+ import traceback
20
+ import torch.utils.checkpoint as checkpoint
21
+ from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
22
+
23
+
24
+ if 'DWCONV_IMPL' in os.environ:
25
+ try:
26
+ sys.path.append(os.environ['DWCONV_IMPL'])
27
+ from depthwise_conv2d_implicit_gemm import DepthWiseConv2dImplicitGEMM
28
+ def get_dwconv(dim, kernel, bias):
29
+ return DepthWiseConv2dImplicitGEMM(dim, kernel, bias)
30
+ print('Using Megvii large kernel dw conv impl')
31
+ except:
32
+ print(traceback.format_exc())
33
+ def get_dwconv(dim, kernel, bias):
34
+ return nn.Conv2d(dim, dim, kernel_size=kernel, padding=(kernel-1)//2 ,bias=bias, groups=dim)
35
+
36
+ print('[fail to use Megvii Large kernel] Using PyTorch large kernel dw conv impl')
37
+ else:
38
+ def get_dwconv(dim, kernel, bias):
39
+ return nn.Conv2d(dim, dim, kernel_size=kernel, padding=(kernel-1)//2 ,bias=bias, groups=dim)
40
+
41
+ print('Using PyTorch large kernel dw conv impl')
42
+
43
+ class GlobalLocalFilter(nn.Module):
44
+ def __init__(self, dim, h=14, w=8):
45
+ super().__init__()
46
+ self.dw = nn.Conv2d(dim // 2, dim // 2, kernel_size=3, padding=1, bias=False, groups=dim // 2)
47
+ self.complex_weight = nn.Parameter(torch.randn(dim // 2, h, w, 2, dtype=torch.float32) * 0.02)
48
+ trunc_normal_(self.complex_weight, std=.02)
49
+ self.pre_norm = LayerNorm(dim, eps=1e-6, data_format='channels_first')
50
+ self.post_norm = LayerNorm(dim, eps=1e-6, data_format='channels_first')
51
+
52
+ def forward(self, x):
53
+ x = self.pre_norm(x)
54
+ x1, x2 = torch.chunk(x, 2, dim=1)
55
+ x1 = self.dw(x1)
56
+
57
+ x2 = x2.to(torch.float32)
58
+ B, C, a, b = x2.shape
59
+ x2 = torch.fft.rfft2(x2, dim=(2, 3), norm='ortho')
60
+
61
+ weight = self.complex_weight
62
+ if not weight.shape[1:3] == x2.shape[2:4]:
63
+ weight = F.interpolate(weight.permute(3,0,1,2), size=x2.shape[2:4], mode='bilinear', align_corners=True).permute(1,2,3,0)
64
+
65
+ weight = torch.view_as_complex(weight.contiguous())
66
+
67
+ x2 = x2 * weight
68
+ x2 = torch.fft.irfft2(x2, s=(a, b), dim=(2, 3), norm='ortho')
69
+
70
+ x = torch.cat([x1.unsqueeze(2), x2.unsqueeze(2)], dim=2).reshape(B, 2 * C, a, b)
71
+ x = self.post_norm(x)
72
+ return x
73
+
74
+
75
+ class gnconv(nn.Module):
76
+ def __init__(self, dim, order=5, gflayer=None, h=14, w=8, s=1.0):
77
+ super().__init__()
78
+ self.order = order
79
+ self.dims = [dim // 2 ** i for i in range(order)]
80
+ self.dims.reverse()
81
+ self.proj_in = nn.Conv2d(dim, 2*dim, 1)
82
+
83
+ if gflayer is None:
84
+ self.dwconv = get_dwconv(sum(self.dims), 7, True)
85
+ else:
86
+ self.dwconv = gflayer(sum(self.dims), h=h, w=w)
87
+
88
+ self.proj_out = nn.Conv2d(dim, dim, 1)
89
+
90
+ self.pws = nn.ModuleList(
91
+ [nn.Conv2d(self.dims[i], self.dims[i+1], 1) for i in range(order-1)]
92
+ )
93
+
94
+ self.scale = s
95
+
96
+ print('[gconv]', order, 'order with dims=', self.dims, 'scale=%.4f'%self.scale)
97
+
98
+
99
+ def forward(self, x, mask=None, dummy=False):
100
+ B, C, H, W = x.shape
101
+
102
+ fused_x = self.proj_in(x)
103
+ pwa, abc = torch.split(fused_x, (self.dims[0], sum(self.dims)), dim=1)
104
+
105
+ dw_abc = self.dwconv(abc) * self.scale
106
+
107
+ dw_list = torch.split(dw_abc, self.dims, dim=1)
108
+ x = pwa * dw_list[0]
109
+
110
+ for i in range(self.order -1):
111
+ x = self.pws[i](x) * dw_list[i+1]
112
+
113
+ x = self.proj_out(x)
114
+
115
+ return x
116
+
117
+ class Block(nn.Module):
118
+ r""" HorNet block
119
+ """
120
+ def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6, gnconv=gnconv):
121
+ super().__init__()
122
+
123
+ self.norm1 = LayerNorm(dim, eps=1e-6, data_format='channels_first')
124
+ self.gnconv = gnconv(dim) # depthwise conv
125
+ self.norm2 = LayerNorm(dim, eps=1e-6)
126
+ self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
127
+ self.act = nn.GELU()
128
+ self.pwconv2 = nn.Linear(4 * dim, dim)
129
+
130
+ self.gamma1 = nn.Parameter(layer_scale_init_value * torch.ones(dim),
131
+ requires_grad=True) if layer_scale_init_value > 0 else None
132
+
133
+ self.gamma2 = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
134
+ requires_grad=True) if layer_scale_init_value > 0 else None
135
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
136
+
137
+ def forward(self, x):
138
+ B, C, H, W = x.shape
139
+ if self.gamma1 is not None:
140
+ gamma1 = self.gamma1.view(C, 1, 1)
141
+ else:
142
+ gamma1 = 1
143
+ x = x + self.drop_path(gamma1 * self.gnconv(self.norm1(x)))
144
+
145
+ input = x
146
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
147
+ x = self.norm2(x)
148
+ x = self.pwconv1(x)
149
+ x = self.act(x)
150
+ x = self.pwconv2(x)
151
+ if self.gamma2 is not None:
152
+ x = self.gamma2 * x
153
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
154
+
155
+ x = input + self.drop_path(x)
156
+ return x
157
+
158
+
159
+ class HorNet(nn.Module):
160
+ r""" HorNet
161
+ A PyTorch impl of : `HorNet: Efficient High-Order Spatial Interactions with Recursive Gated Convolutions`
162
+
163
+ Args:
164
+ in_chans (int): Number of input image channels. Default: 3
165
+ num_classes (int): Number of classes for classification head. Default: 1000
166
+ depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
167
+ dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
168
+ drop_path_rate (float): Stochastic depth rate. Default: 0.
169
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
170
+ head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
171
+ """
172
+ def __init__(self, in_chans=3, num_classes=1000,
173
+ depths=[3, 3, 9, 3], base_dim=96, drop_path_rate=0.,
174
+ layer_scale_init_value=1e-6, head_init_scale=1.,
175
+ gnconv=gnconv, block=Block,
176
+ pretrained=None,
177
+ use_checkpoint=False,
178
+ ):
179
+ super().__init__()
180
+
181
+ self.pretrained = pretrained
182
+ self.use_checkpoint = use_checkpoint
183
+
184
+ dims = [base_dim, base_dim*2, base_dim*4, base_dim*8]
185
+
186
+ self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
187
+ stem = nn.Sequential(
188
+ nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
189
+ LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
190
+ )
191
+ self.downsample_layers.append(stem)
192
+ for i in range(3):
193
+ downsample_layer = nn.Sequential(
194
+ LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
195
+ nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
196
+ )
197
+ self.downsample_layers.append(downsample_layer)
198
+
199
+ self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
200
+ dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
201
+
202
+
203
+ if not isinstance(gnconv, list):
204
+ gnconv = [gnconv, gnconv, gnconv, gnconv]
205
+ else:
206
+ gnconv = gnconv
207
+ assert len(gnconv) == 4
208
+
209
+ if isinstance(gnconv[0], str):
210
+ print('[GConvNet]: convert str gconv to func')
211
+ gnconv = [eval(g) for g in gnconv]
212
+
213
+ if isinstance(block, str):
214
+ block = eval(block)
215
+
216
+ cur = 0
217
+ num_features = []
218
+ for i in range(4):
219
+ stage = nn.Sequential(
220
+ *[block(dim=dims[i], drop_path=dp_rates[cur + j],
221
+ layer_scale_init_value=layer_scale_init_value, gnconv=gnconv[i]) for j in range(depths[i])]
222
+ )
223
+ self.stages.append(stage)
224
+ cur += depths[i]
225
+ num_features.append(dims[i])
226
+ self.num_features = num_features
227
+
228
+ norm_layer = partial(LayerNorm, eps=1e-6, data_format="channels_first")
229
+ for i_layer in range(4):
230
+ layer = norm_layer(dims[i_layer])
231
+ layer_name = f'norm{i_layer}'
232
+ self.add_module(layer_name, layer)
233
+
234
+ def init_weights(self):
235
+ """Initialize the weights in backbone.
236
+ Args:
237
+ pretrained (str, optional): Path to pre-trained weights.
238
+ Defaults to None.
239
+ """
240
+ #pretrained = self.pretrained
241
+
242
+ def _init_weights(m):
243
+ if isinstance(m, nn.Linear):
244
+ trunc_normal_(m.weight, std=.02)
245
+ if isinstance(m, nn.Linear) and m.bias is not None:
246
+ nn.init.constant_(m.bias, 0)
247
+ elif isinstance(m, nn.LayerNorm):
248
+ nn.init.constant_(m.bias, 0)
249
+ nn.init.constant_(m.weight, 1.0)
250
+
251
+ #if isinstance(pretrained, str):
252
+ # self.apply(_init_weights)
253
+ # logger = get_root_logger()
254
+ # load_checkpoint(self, pretrained, strict=False, logger=logger)
255
+ #elif pretrained is None:
256
+ # raise NotImplementedError()
257
+ self.apply(_init_weights)
258
+ #else:
259
+ # raise TypeError('pretrained must be a str or None')
260
+
261
+ def forward_features(self, x):
262
+ outs = dict()
263
+ for i in range(4):
264
+ x = self.downsample_layers[i](x)
265
+ if self.use_checkpoint:
266
+ x = checkpoint.checkpoint_sequential(self.stages[i], len(self.stages[i]), x)
267
+ else:
268
+ x = self.stages[i](x)
269
+ norm_layer = getattr(self, f'norm{i}')
270
+ x_out = norm_layer(x)
271
+ outs["res%i"% (i+2)] = x_out
272
+ return outs #tuple(outs)
273
+
274
+ def forward(self, x):
275
+ x = self.forward_features(x)
276
+ return x
277
+
278
+
279
+ class LayerNorm(nn.Module):
280
+ r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
281
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
282
+ shape (batch_size, height, width, channels) while channels_first corresponds to inputs
283
+ with shape (batch_size, channels, height, width).
284
+ """
285
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
286
+ super().__init__()
287
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
288
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
289
+ self.eps = eps
290
+ self.data_format = data_format
291
+ if self.data_format not in ["channels_last", "channels_first"]:
292
+ raise NotImplementedError
293
+ self.normalized_shape = (normalized_shape, )
294
+
295
+ def forward(self, x):
296
+ if self.data_format == "channels_last":
297
+ return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
298
+ elif self.data_format == "channels_first":
299
+ u = x.mean(1, keepdim=True)
300
+ s = (x - u).pow(2).mean(1, keepdim=True)
301
+ x = (x - u) / torch.sqrt(s + self.eps)
302
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
303
+ return x
304
+
305
+ @BACKBONE_REGISTRY.register()
306
+ class D2HorNet(HorNet, Backbone):
307
+ def __init__(self, cfg, input_shape):
308
+
309
+ depths=cfg.MODEL.HORNET.DEPTHS
310
+ base_dim=cfg.MODEL.HORNET.BASE_DIM
311
+ gnconv=cfg.MODEL.HORNET.GCONV
312
+ drop_path_rate=cfg.MODEL.HORNET.DROP_PATH_RATE
313
+
314
+ super().__init__(
315
+ depths=depths,
316
+ base_dim=base_dim,
317
+ gnconv=gnconv,
318
+ drop_path_rate=drop_path_rate,
319
+ )
320
+
321
+ self._out_features = cfg.MODEL.HORNET.OUT_FEATURES
322
+
323
+ self._out_feature_strides = {
324
+ "res2": 4,
325
+ "res3": 8,
326
+ "res4": 16,
327
+ "res5": 32,
328
+ }
329
+ self._out_feature_channels = {
330
+ "res2": self.num_features[0],
331
+ "res3": self.num_features[1],
332
+ "res4": self.num_features[2],
333
+ "res5": self.num_features[3],
334
+ }
335
+
336
+ def forward(self, x):
337
+ """
338
+ Args:
339
+ x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
340
+ Returns:
341
+ dict[str->Tensor]: names and the corresponding features
342
+ """
343
+ assert (
344
+ x.dim() == 4
345
+ ), f"D2HorNet takes an input of shape (N, C, H, W). Got {x.shape} instead!"
346
+ outputs = {}
347
+ y = super().forward(x)
348
+ for k in y.keys():
349
+ if k in self._out_features:
350
+ outputs[k] = y[k]
351
+ return outputs
352
+
353
+ def output_shape(self):
354
+ return {
355
+ name: ShapeSpec(
356
+ channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
357
+ )
358
+ for name in self._out_features
359
+ }
360
+
361
+ @property
362
+ def size_divisibility(self):
363
+ return 32
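
Because D2HorNet is registered in BACKBONE_REGISTRY, it is normally built through detectron2's standard backbone machinery rather than instantiated directly. A rough sketch of that lookup; the cfg keys named below are the MODEL.BACKBONE.* / MODEL.HORNET.* ones consumed in __init__ above, and the concrete values live in the YAML configs, so treat this as illustrative only:

    from detectron2.modeling import BACKBONE_REGISTRY, build_backbone

    # Importing this module triggers the @BACKBONE_REGISTRY.register() decorator above.
    backbone_cls = BACKBONE_REGISTRY.get("D2HorNet")
    print(backbone_cls.__name__)   # "D2HorNet"

    # With a fully populated cfg (MODEL.BACKBONE.NAME = "D2HorNet" plus the
    # MODEL.HORNET.* keys read in __init__), the usual entry point would be:
    #   backbone = build_backbone(cfg)
    #   feats = backbone(images)   # dict with "res2".."res5" at strides 4/8/16/32
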
annotator/entityseg/mask2former/modeling/backbone/swin.py ADDED
@@ -0,0 +1,770 @@
1
+ # --------------------------------------------------------
2
+ # Swin Transformer
3
+ # Copyright (c) 2021 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Ze Liu, Yutong Lin, Yixuan Wei
6
+ # --------------------------------------------------------
7
+
8
+ # Copyright (c) Facebook, Inc. and its affiliates.
9
+ # Modified by Bowen Cheng from https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation/blob/main/mmseg/models/backbones/swin_transformer.py
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ import torch.utils.checkpoint as checkpoint
16
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
17
+
18
+ from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
19
+
20
+
21
+ class Mlp(nn.Module):
22
+ """Multilayer perceptron."""
23
+
24
+ def __init__(
25
+ self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0
26
+ ):
27
+ super().__init__()
28
+ out_features = out_features or in_features
29
+ hidden_features = hidden_features or in_features
30
+ self.fc1 = nn.Linear(in_features, hidden_features)
31
+ self.act = act_layer()
32
+ self.fc2 = nn.Linear(hidden_features, out_features)
33
+ self.drop = nn.Dropout(drop)
34
+
35
+ def forward(self, x):
36
+ x = self.fc1(x)
37
+ x = self.act(x)
38
+ x = self.drop(x)
39
+ x = self.fc2(x)
40
+ x = self.drop(x)
41
+ return x
42
+
43
+
44
+ def window_partition(x, window_size):
45
+ """
46
+ Args:
47
+ x: (B, H, W, C)
48
+ window_size (int): window size
49
+ Returns:
50
+ windows: (num_windows*B, window_size, window_size, C)
51
+ """
52
+ B, H, W, C = x.shape
53
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
54
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
55
+ return windows
56
+
57
+
58
+ def window_reverse(windows, window_size, H, W):
59
+ """
60
+ Args:
61
+ windows: (num_windows*B, window_size, window_size, C)
62
+ window_size (int): Window size
63
+ H (int): Height of image
64
+ W (int): Width of image
65
+ Returns:
66
+ x: (B, H, W, C)
67
+ """
68
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
69
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
70
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
71
+ return x
72
+
73
+
74
+ class WindowAttention(nn.Module):
75
+ """Window based multi-head self attention (W-MSA) module with relative position bias.
76
+ It supports both of shifted and non-shifted window.
77
+ Args:
78
+ dim (int): Number of input channels.
79
+ window_size (tuple[int]): The height and width of the window.
80
+ num_heads (int): Number of attention heads.
81
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
82
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
83
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
84
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
85
+ """
86
+
87
+ def __init__(
88
+ self,
89
+ dim,
90
+ window_size,
91
+ num_heads,
92
+ qkv_bias=True,
93
+ qk_scale=None,
94
+ attn_drop=0.0,
95
+ proj_drop=0.0,
96
+ ):
97
+
98
+ super().__init__()
99
+ self.dim = dim
100
+ self.window_size = window_size # Wh, Ww
101
+ self.num_heads = num_heads
102
+ head_dim = dim // num_heads
103
+ self.scale = qk_scale or head_dim ** -0.5
104
+
105
+ # define a parameter table of relative position bias
106
+ self.relative_position_bias_table = nn.Parameter(
107
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
108
+ ) # 2*Wh-1 * 2*Ww-1, nH
109
+
110
+ # get pair-wise relative position index for each token inside the window
111
+ coords_h = torch.arange(self.window_size[0])
112
+ coords_w = torch.arange(self.window_size[1])
113
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
114
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
115
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
116
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
117
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
118
+ relative_coords[:, :, 1] += self.window_size[1] - 1
119
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
120
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
121
+ self.register_buffer("relative_position_index", relative_position_index)
122
+
123
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
124
+ self.attn_drop = nn.Dropout(attn_drop)
125
+ self.proj = nn.Linear(dim, dim)
126
+ self.proj_drop = nn.Dropout(proj_drop)
127
+
128
+ trunc_normal_(self.relative_position_bias_table, std=0.02)
129
+ self.softmax = nn.Softmax(dim=-1)
130
+
131
+ def forward(self, x, mask=None):
132
+ """Forward function.
133
+ Args:
134
+ x: input features with shape of (num_windows*B, N, C)
135
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
136
+ """
137
+ B_, N, C = x.shape
138
+ qkv = (
139
+ self.qkv(x)
140
+ .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
141
+ .permute(2, 0, 3, 1, 4)
142
+ )
143
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
144
+
145
+ q = q * self.scale
146
+ attn = q @ k.transpose(-2, -1)
147
+
148
+ relative_position_bias = self.relative_position_bias_table[
149
+ self.relative_position_index.view(-1)
150
+ ].view(
151
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
152
+ ) # Wh*Ww,Wh*Ww,nH
153
+ relative_position_bias = relative_position_bias.permute(
154
+ 2, 0, 1
155
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
156
+ attn = attn + relative_position_bias.unsqueeze(0)
157
+
158
+ if mask is not None:
159
+ nW = mask.shape[0]
160
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
161
+ attn = attn.view(-1, self.num_heads, N, N)
162
+ attn = self.softmax(attn)
163
+ else:
164
+ attn = self.softmax(attn)
165
+
166
+ attn = self.attn_drop(attn)
167
+
168
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
169
+ x = self.proj(x)
170
+ x = self.proj_drop(x)
171
+ return x
172
+
173
+
174
+ class SwinTransformerBlock(nn.Module):
175
+ """Swin Transformer Block.
176
+ Args:
177
+ dim (int): Number of input channels.
178
+ num_heads (int): Number of attention heads.
179
+ window_size (int): Window size.
180
+ shift_size (int): Shift size for SW-MSA.
181
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
182
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
183
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
184
+ drop (float, optional): Dropout rate. Default: 0.0
185
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
186
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
187
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
188
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
189
+ """
190
+
191
+ def __init__(
192
+ self,
193
+ dim,
194
+ num_heads,
195
+ window_size=7,
196
+ shift_size=0,
197
+ mlp_ratio=4.0,
198
+ qkv_bias=True,
199
+ qk_scale=None,
200
+ drop=0.0,
201
+ attn_drop=0.0,
202
+ drop_path=0.0,
203
+ act_layer=nn.GELU,
204
+ norm_layer=nn.LayerNorm,
205
+ ):
206
+ super().__init__()
207
+ self.dim = dim
208
+ self.num_heads = num_heads
209
+ self.window_size = window_size
210
+ self.shift_size = shift_size
211
+ self.mlp_ratio = mlp_ratio
212
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
213
+
214
+ self.norm1 = norm_layer(dim)
215
+ self.attn = WindowAttention(
216
+ dim,
217
+ window_size=to_2tuple(self.window_size),
218
+ num_heads=num_heads,
219
+ qkv_bias=qkv_bias,
220
+ qk_scale=qk_scale,
221
+ attn_drop=attn_drop,
222
+ proj_drop=drop,
223
+ )
224
+
225
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
226
+ self.norm2 = norm_layer(dim)
227
+ mlp_hidden_dim = int(dim * mlp_ratio)
228
+ self.mlp = Mlp(
229
+ in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
230
+ )
231
+
232
+ self.H = None
233
+ self.W = None
234
+
235
+ def forward(self, x, mask_matrix):
236
+ """Forward function.
237
+ Args:
238
+ x: Input feature, tensor size (B, H*W, C).
239
+ H, W: Spatial resolution of the input feature.
240
+ mask_matrix: Attention mask for cyclic shift.
241
+ """
242
+ B, L, C = x.shape
243
+ H, W = self.H, self.W
244
+ assert L == H * W, "input feature has wrong size"
245
+
246
+ shortcut = x
247
+ x = self.norm1(x)
248
+ x = x.view(B, H, W, C)
249
+
250
+ # pad feature maps to multiples of window size
251
+ pad_l = pad_t = 0
252
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
253
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
254
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
255
+ _, Hp, Wp, _ = x.shape
256
+
257
+ # cyclic shift
258
+ if self.shift_size > 0:
259
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
260
+ attn_mask = mask_matrix
261
+ else:
262
+ shifted_x = x
263
+ attn_mask = None
264
+
265
+ # partition windows
266
+ x_windows = window_partition(
267
+ shifted_x, self.window_size
268
+ ) # nW*B, window_size, window_size, C
269
+ x_windows = x_windows.view(
270
+ -1, self.window_size * self.window_size, C
271
+ ) # nW*B, window_size*window_size, C
272
+
273
+ # W-MSA/SW-MSA
274
+ attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
275
+
276
+ # merge windows
277
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
278
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
279
+
280
+ # reverse cyclic shift
281
+ if self.shift_size > 0:
282
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
283
+ else:
284
+ x = shifted_x
285
+
286
+ if pad_r > 0 or pad_b > 0:
287
+ x = x[:, :H, :W, :].contiguous()
288
+
289
+ x = x.view(B, H * W, C)
290
+
291
+ # FFN
292
+ x = shortcut + self.drop_path(x)
293
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
294
+
295
+ return x
296
+
297
+
298
+ class PatchMerging(nn.Module):
299
+ """Patch Merging Layer
300
+ Args:
301
+ dim (int): Number of input channels.
302
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
303
+ """
304
+
305
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
306
+ super().__init__()
307
+ self.dim = dim
308
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
309
+ self.norm = norm_layer(4 * dim)
310
+
311
+ def forward(self, x, H, W):
312
+ """Forward function.
313
+ Args:
314
+ x: Input feature, tensor size (B, H*W, C).
315
+ H, W: Spatial resolution of the input feature.
316
+ """
317
+ B, L, C = x.shape
318
+ assert L == H * W, "input feature has wrong size"
319
+
320
+ x = x.view(B, H, W, C)
321
+
322
+ # padding
323
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
324
+ if pad_input:
325
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
326
+
327
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
328
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
329
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
330
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
331
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
332
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
333
+
334
+ x = self.norm(x)
335
+ x = self.reduction(x)
336
+
337
+ return x
338
+
339
+
340
+ class BasicLayer(nn.Module):
341
+ """A basic Swin Transformer layer for one stage.
342
+ Args:
343
+ dim (int): Number of feature channels
344
+ depth (int): Depths of this stage.
345
+ num_heads (int): Number of attention head.
346
+ window_size (int): Local window size. Default: 7.
347
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
348
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
349
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
350
+ drop (float, optional): Dropout rate. Default: 0.0
351
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
352
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
353
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
354
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
355
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
356
+ """
357
+
358
+ def __init__(
359
+ self,
360
+ dim,
361
+ depth,
362
+ num_heads,
363
+ window_size=7,
364
+ mlp_ratio=4.0,
365
+ qkv_bias=True,
366
+ qk_scale=None,
367
+ drop=0.0,
368
+ attn_drop=0.0,
369
+ drop_path=0.0,
370
+ norm_layer=nn.LayerNorm,
371
+ downsample=None,
372
+ use_checkpoint=False,
373
+ ):
374
+ super().__init__()
375
+ self.window_size = window_size
376
+ self.shift_size = window_size // 2
377
+ self.depth = depth
378
+ self.use_checkpoint = use_checkpoint
379
+
380
+ # build blocks
381
+ self.blocks = nn.ModuleList(
382
+ [
383
+ SwinTransformerBlock(
384
+ dim=dim,
385
+ num_heads=num_heads,
386
+ window_size=window_size,
387
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
388
+ mlp_ratio=mlp_ratio,
389
+ qkv_bias=qkv_bias,
390
+ qk_scale=qk_scale,
391
+ drop=drop,
392
+ attn_drop=attn_drop,
393
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
394
+ norm_layer=norm_layer,
395
+ )
396
+ for i in range(depth)
397
+ ]
398
+ )
399
+
400
+ # patch merging layer
401
+ if downsample is not None:
402
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
403
+ else:
404
+ self.downsample = None
405
+
406
+ def forward(self, x, H, W):
407
+ """Forward function.
408
+ Args:
409
+ x: Input feature, tensor size (B, H*W, C).
410
+ H, W: Spatial resolution of the input feature.
411
+ """
412
+
413
+ # calculate attention mask for SW-MSA
414
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
415
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
416
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
417
+ h_slices = (
418
+ slice(0, -self.window_size),
419
+ slice(-self.window_size, -self.shift_size),
420
+ slice(-self.shift_size, None),
421
+ )
422
+ w_slices = (
423
+ slice(0, -self.window_size),
424
+ slice(-self.window_size, -self.shift_size),
425
+ slice(-self.shift_size, None),
426
+ )
427
+ cnt = 0
428
+ for h in h_slices:
429
+ for w in w_slices:
430
+ img_mask[:, h, w, :] = cnt
431
+ cnt += 1
432
+
433
+ mask_windows = window_partition(
434
+ img_mask, self.window_size
435
+ ) # nW, window_size, window_size, 1
436
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
437
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
438
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
439
+ attn_mask == 0, float(0.0)
440
+ )
441
+
442
+ for blk in self.blocks:
443
+ blk.H, blk.W = H, W
444
+ if self.use_checkpoint:
445
+ x = checkpoint.checkpoint(blk, x, attn_mask)
446
+ else:
447
+ x = blk(x, attn_mask)
448
+ if self.downsample is not None:
449
+ x_down = self.downsample(x, H, W)
450
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
451
+ return x, H, W, x_down, Wh, Ww
452
+ else:
453
+ return x, H, W, x, H, W
454
+
455
+
456
+ class PatchEmbed(nn.Module):
457
+ """Image to Patch Embedding
458
+ Args:
459
+ patch_size (int): Patch token size. Default: 4.
460
+ in_chans (int): Number of input image channels. Default: 3.
461
+ embed_dim (int): Number of linear projection output channels. Default: 96.
462
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
463
+ """
464
+
465
+ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
466
+ super().__init__()
467
+ patch_size = to_2tuple(patch_size)
468
+ self.patch_size = patch_size
469
+
470
+ self.in_chans = in_chans
471
+ self.embed_dim = embed_dim
472
+
473
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
474
+ if norm_layer is not None:
475
+ self.norm = norm_layer(embed_dim)
476
+ else:
477
+ self.norm = None
478
+
479
+ def forward(self, x):
480
+ """Forward function."""
481
+ # padding
482
+ _, _, H, W = x.size()
483
+ if W % self.patch_size[1] != 0:
484
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
485
+ if H % self.patch_size[0] != 0:
486
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
487
+
488
+ x = self.proj(x) # B C Wh Ww
489
+ if self.norm is not None:
490
+ Wh, Ww = x.size(2), x.size(3)
491
+ x = x.flatten(2).transpose(1, 2)
492
+ x = self.norm(x)
493
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
494
+
495
+ return x
496
+
497
+
498
+ class SwinTransformer(nn.Module):
499
+ """Swin Transformer backbone.
500
+ A PyTorch impl of: `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
501
+ https://arxiv.org/pdf/2103.14030
502
+ Args:
503
+ pretrain_img_size (int): Input image size for training the pretrained model,
504
+ used in absolute position embedding. Default 224.
505
+ patch_size (int | tuple(int)): Patch size. Default: 4.
506
+ in_chans (int): Number of input image channels. Default: 3.
507
+ embed_dim (int): Number of linear projection output channels. Default: 96.
508
+ depths (tuple[int]): Depths of each Swin Transformer stage.
509
+ num_heads (tuple[int]): Number of attention head of each stage.
510
+ window_size (int): Window size. Default: 7.
511
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
512
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
513
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
514
+ drop_rate (float): Dropout rate.
515
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
516
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
517
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
518
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
519
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
520
+ out_indices (Sequence[int]): Output from which stages.
521
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
522
+ -1 means not freezing any parameters.
523
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
524
+ """
525
+
526
+ def __init__(
527
+ self,
528
+ pretrain_img_size=224,
529
+ patch_size=4,
530
+ in_chans=3,
531
+ embed_dim=96,
532
+ depths=[2, 2, 6, 2],
533
+ num_heads=[3, 6, 12, 24],
534
+ window_size=7,
535
+ mlp_ratio=4.0,
536
+ qkv_bias=True,
537
+ qk_scale=None,
538
+ drop_rate=0.0,
539
+ attn_drop_rate=0.0,
540
+ drop_path_rate=0.2,
541
+ norm_layer=nn.LayerNorm,
542
+ ape=False,
543
+ patch_norm=True,
544
+ out_indices=(0, 1, 2, 3),
545
+ frozen_stages=-1,
546
+ use_checkpoint=False,
547
+ ):
548
+ super().__init__()
549
+
550
+ self.pretrain_img_size = pretrain_img_size
551
+ self.num_layers = len(depths)
552
+ self.embed_dim = embed_dim
553
+ self.ape = ape
554
+ self.patch_norm = patch_norm
555
+ self.out_indices = out_indices
556
+ self.frozen_stages = frozen_stages
557
+
558
+ # split image into non-overlapping patches
559
+ self.patch_embed = PatchEmbed(
560
+ patch_size=patch_size,
561
+ in_chans=in_chans,
562
+ embed_dim=embed_dim,
563
+ norm_layer=norm_layer if self.patch_norm else None,
564
+ )
565
+
566
+ # absolute position embedding
567
+ if self.ape:
568
+ pretrain_img_size = to_2tuple(pretrain_img_size)
569
+ patch_size = to_2tuple(patch_size)
570
+ patches_resolution = [
571
+ pretrain_img_size[0] // patch_size[0],
572
+ pretrain_img_size[1] // patch_size[1],
573
+ ]
574
+
575
+ self.absolute_pos_embed = nn.Parameter(
576
+ torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
577
+ )
578
+ trunc_normal_(self.absolute_pos_embed, std=0.02)
579
+
580
+ self.pos_drop = nn.Dropout(p=drop_rate)
581
+
582
+ # stochastic depth
583
+ dpr = [
584
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
585
+ ] # stochastic depth decay rule
586
+
587
+ # build layers
588
+ self.layers = nn.ModuleList()
589
+ for i_layer in range(self.num_layers):
590
+ layer = BasicLayer(
591
+ dim=int(embed_dim * 2 ** i_layer),
592
+ depth=depths[i_layer],
593
+ num_heads=num_heads[i_layer],
594
+ window_size=window_size,
595
+ mlp_ratio=mlp_ratio,
596
+ qkv_bias=qkv_bias,
597
+ qk_scale=qk_scale,
598
+ drop=drop_rate,
599
+ attn_drop=attn_drop_rate,
600
+ drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
601
+ norm_layer=norm_layer,
602
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
603
+ use_checkpoint=use_checkpoint,
604
+ )
605
+ self.layers.append(layer)
606
+
607
+ num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
608
+ self.num_features = num_features
609
+
610
+ # add a norm layer for each output
611
+ for i_layer in out_indices:
612
+ layer = norm_layer(num_features[i_layer])
613
+ layer_name = f"norm{i_layer}"
614
+ self.add_module(layer_name, layer)
615
+
616
+ self._freeze_stages()
617
+
618
+ def _freeze_stages(self):
619
+ if self.frozen_stages >= 0:
620
+ self.patch_embed.eval()
621
+ for param in self.patch_embed.parameters():
622
+ param.requires_grad = False
623
+
624
+ if self.frozen_stages >= 1 and self.ape:
625
+ self.absolute_pos_embed.requires_grad = False
626
+
627
+ if self.frozen_stages >= 2:
628
+ self.pos_drop.eval()
629
+ for i in range(0, self.frozen_stages - 1):
630
+ m = self.layers[i]
631
+ m.eval()
632
+ for param in m.parameters():
633
+ param.requires_grad = False
634
+
635
+ def init_weights(self, pretrained=None):
636
+ """Initialize the weights in backbone.
637
+ Args:
638
+ pretrained (str, optional): Path to pre-trained weights.
639
+ Defaults to None.
640
+ """
641
+
642
+ def _init_weights(m):
643
+ if isinstance(m, nn.Linear):
644
+ trunc_normal_(m.weight, std=0.02)
645
+ if isinstance(m, nn.Linear) and m.bias is not None:
646
+ nn.init.constant_(m.bias, 0)
647
+ elif isinstance(m, nn.LayerNorm):
648
+ nn.init.constant_(m.bias, 0)
649
+ nn.init.constant_(m.weight, 1.0)
650
+
651
+ def forward(self, x):
652
+ """Forward function."""
653
+ x = self.patch_embed(x)
654
+
655
+ Wh, Ww = x.size(2), x.size(3)
656
+ if self.ape:
657
+ # interpolate the position embedding to the corresponding size
658
+ absolute_pos_embed = F.interpolate(
659
+ self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
660
+ )
661
+ x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
662
+ else:
663
+ x = x.flatten(2).transpose(1, 2)
664
+ x = self.pos_drop(x)
665
+
666
+ outs = {}
667
+ for i in range(self.num_layers):
668
+ layer = self.layers[i]
669
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
670
+
671
+ if i in self.out_indices:
672
+ norm_layer = getattr(self, f"norm{i}")
673
+ x_out = norm_layer(x_out)
674
+
675
+ out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
676
+ outs["res{}".format(i + 2)] = out
677
+
678
+ return outs
679
+
680
+ def train(self, mode=True):
681
+ """Convert the model into training mode while keeping layers frozen."""
682
+ super(SwinTransformer, self).train(mode)
683
+ self._freeze_stages()
684
+
685
+
686
+ @BACKBONE_REGISTRY.register()
687
+ class D2SwinTransformer(SwinTransformer, Backbone):
688
+ def __init__(self, cfg, input_shape):
689
+
690
+ pretrain_img_size = cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE
691
+ patch_size = cfg.MODEL.SWIN.PATCH_SIZE
692
+ in_chans = 3
693
+ embed_dim = cfg.MODEL.SWIN.EMBED_DIM
694
+ depths = cfg.MODEL.SWIN.DEPTHS
695
+ num_heads = cfg.MODEL.SWIN.NUM_HEADS
696
+ window_size = cfg.MODEL.SWIN.WINDOW_SIZE
697
+ mlp_ratio = cfg.MODEL.SWIN.MLP_RATIO
698
+ qkv_bias = cfg.MODEL.SWIN.QKV_BIAS
699
+ qk_scale = cfg.MODEL.SWIN.QK_SCALE
700
+ drop_rate = cfg.MODEL.SWIN.DROP_RATE
701
+ attn_drop_rate = cfg.MODEL.SWIN.ATTN_DROP_RATE
702
+ drop_path_rate = cfg.MODEL.SWIN.DROP_PATH_RATE
703
+ norm_layer = nn.LayerNorm
704
+ ape = cfg.MODEL.SWIN.APE
705
+ patch_norm = cfg.MODEL.SWIN.PATCH_NORM
706
+ use_checkpoint = cfg.MODEL.SWIN.USE_CHECKPOINT
707
+
708
+ super().__init__(
709
+ pretrain_img_size,
710
+ patch_size,
711
+ in_chans,
712
+ embed_dim,
713
+ depths,
714
+ num_heads,
715
+ window_size,
716
+ mlp_ratio,
717
+ qkv_bias,
718
+ qk_scale,
719
+ drop_rate,
720
+ attn_drop_rate,
721
+ drop_path_rate,
722
+ norm_layer,
723
+ ape,
724
+ patch_norm,
725
+ use_checkpoint=use_checkpoint,
726
+ )
727
+
728
+ self._out_features = cfg.MODEL.SWIN.OUT_FEATURES
729
+
730
+ self._out_feature_strides = {
731
+ "res2": 4,
732
+ "res3": 8,
733
+ "res4": 16,
734
+ "res5": 32,
735
+ }
736
+ self._out_feature_channels = {
737
+ "res2": self.num_features[0],
738
+ "res3": self.num_features[1],
739
+ "res4": self.num_features[2],
740
+ "res5": self.num_features[3],
741
+ }
742
+
743
+ def forward(self, x):
744
+ """
745
+ Args:
746
+ x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
747
+ Returns:
748
+ dict[str->Tensor]: names and the corresponding features
749
+ """
750
+ assert (
751
+ x.dim() == 4
752
+ ), f"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!"
753
+ outputs = {}
754
+ y = super().forward(x)
755
+ for k in y.keys():
756
+ if k in self._out_features:
757
+ outputs[k] = y[k]
758
+ return outputs
759
+
760
+ def output_shape(self):
761
+ return {
762
+ name: ShapeSpec(
763
+ channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
764
+ )
765
+ for name in self._out_features
766
+ }
767
+
768
+ @property
769
+ def size_divisibility(self):
770
+ return 32
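
Like the HorNet wrapper earlier in this commit, the Swin backbone returns a dict of multi-scale features keyed res2..res5. A minimal sketch of the raw SwinTransformer with its default (Swin-T) constructor arguments and a toy input; the spatial sizes follow from the strides 4/8/16/32 declared above:

    import torch

    backbone = SwinTransformer()                   # defaults: embed_dim=96, depths=[2, 2, 6, 2]
    backbone.eval()
    with torch.no_grad():
        feats = backbone(torch.randn(1, 3, 224, 224))
    for name, f in feats.items():
        print(name, tuple(f.shape))
    # res2 (1, 96, 56, 56), res3 (1, 192, 28, 28), res4 (1, 384, 14, 14), res5 (1, 768, 7, 7)
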
annotator/entityseg/mask2former/modeling/criterion.py ADDED
@@ -0,0 +1,263 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/detr.py
3
+ """
4
+ MaskFormer criterion.
5
+ """
6
+ import logging
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torch import nn
11
+
12
+ from detectron2.utils.comm import get_world_size
13
+ from detectron2.projects.point_rend.point_features import (
14
+ get_uncertain_point_coords_with_randomness,
15
+ point_sample,
16
+ )
17
+
18
+ from ..utils.misc import is_dist_avail_and_initialized, nested_tensor_from_tensor_list
19
+
20
+
21
+ def dice_loss(
22
+ inputs: torch.Tensor,
23
+ targets: torch.Tensor,
24
+ num_masks: float,
25
+ ):
26
+ """
27
+ Compute the DICE loss, similar to generalized IOU for masks
28
+ Args:
29
+ inputs: A float tensor of arbitrary shape.
30
+ The predictions for each example.
31
+ targets: A float tensor with the same shape as inputs. Stores the binary
32
+ classification label for each element in inputs
33
+ (0 for the negative class and 1 for the positive class).
34
+ """
35
+ inputs = inputs.sigmoid()
36
+ inputs = inputs.flatten(1)
37
+ numerator = 2 * (inputs * targets).sum(-1)
38
+ denominator = inputs.sum(-1) + targets.sum(-1)
39
+ loss = 1 - (numerator + 1) / (denominator + 1)
40
+ return loss.sum() / num_masks
41
+
42
+
43
+ dice_loss_jit = torch.jit.script(
44
+ dice_loss
45
+ ) # type: torch.jit.ScriptModule
46
+
47
+
48
+ def sigmoid_ce_loss(
49
+ inputs: torch.Tensor,
50
+ targets: torch.Tensor,
51
+ num_masks: float,
52
+ ):
53
+ """
54
+ Args:
55
+ inputs: A float tensor of arbitrary shape.
56
+ The predictions for each example.
57
+ targets: A float tensor with the same shape as inputs. Stores the binary
58
+ classification label for each element in inputs
59
+ (0 for the negative class and 1 for the positive class).
60
+ Returns:
61
+ Loss tensor
62
+ """
63
+ loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
64
+
65
+ return loss.mean(1).sum() / num_masks
66
+
67
+
68
+ sigmoid_ce_loss_jit = torch.jit.script(
69
+ sigmoid_ce_loss
70
+ ) # type: torch.jit.ScriptModule
71
+
72
+
73
+ def calculate_uncertainty(logits):
74
+ """
75
+ We estimate uncertainty as L1 distance between 0.0 and the logit prediction in 'logits' for the
76
+ foreground class in `classes`.
77
+ Args:
78
+ logits (Tensor): A tensor of shape (R, 1, ...) for class-specific or
79
+ class-agnostic, where R is the total number of predicted masks in all images and C is
80
+ the number of foreground classes. The values are logits.
81
+ Returns:
82
+ scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with
83
+ the most uncertain locations having the highest uncertainty score.
84
+ """
85
+ assert logits.shape[1] == 1
86
+ gt_class_logits = logits.clone()
87
+ return -(torch.abs(gt_class_logits))
88
+
89
+
90
+ class SetCriterion(nn.Module):
91
+ """This class computes the loss for DETR.
92
+ The process happens in two steps:
93
+ 1) we compute hungarian assignment between ground truth boxes and the outputs of the model
94
+ 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
95
+ """
96
+
97
+ def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses,
98
+ num_points, oversample_ratio, importance_sample_ratio):
99
+ """Create the criterion.
100
+ Parameters:
101
+ num_classes: number of object categories, omitting the special no-object category
102
+ matcher: module able to compute a matching between targets and proposals
103
+ weight_dict: dict containing as key the names of the losses and as values their relative weight.
104
+ eos_coef: relative classification weight applied to the no-object category
105
+ losses: list of all the losses to be applied. See get_loss for list of available losses.
106
+ """
107
+ super().__init__()
108
+ self.num_classes = num_classes
109
+ self.matcher = matcher
110
+ self.weight_dict = weight_dict
111
+ self.eos_coef = eos_coef
112
+ self.losses = losses
113
+ empty_weight = torch.ones(self.num_classes + 1)
114
+ empty_weight[-1] = self.eos_coef
115
+ self.register_buffer("empty_weight", empty_weight)
116
+
117
+ # pointwise mask loss parameters
118
+ self.num_points = num_points
119
+ self.oversample_ratio = oversample_ratio
120
+ self.importance_sample_ratio = importance_sample_ratio
121
+
122
+ def loss_labels(self, outputs, targets, indices, num_masks):
123
+ """Classification loss (NLL)
124
+ targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
125
+ """
126
+ assert "pred_logits" in outputs
127
+ src_logits = outputs["pred_logits"].float()
128
+
129
+ idx = self._get_src_permutation_idx(indices)
130
+ target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
131
+ target_classes = torch.full(
132
+ src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
133
+ )
134
+ target_classes[idx] = target_classes_o
135
+
136
+ loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
137
+ losses = {"loss_ce": loss_ce}
138
+ return losses
139
+
140
+ def loss_masks(self, outputs, targets, indices, num_masks):
141
+ """Compute the losses related to the masks: the focal loss and the dice loss.
142
+ targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
143
+ """
144
+ assert "pred_masks" in outputs
145
+
146
+ src_idx = self._get_src_permutation_idx(indices)
147
+ tgt_idx = self._get_tgt_permutation_idx(indices)
148
+ src_masks = outputs["pred_masks"]
149
+ src_masks = src_masks[src_idx]
150
+ masks = [t["masks"] for t in targets]
151
+ # TODO use valid to mask invalid areas due to padding in loss
152
+ target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
153
+ target_masks = target_masks.to(src_masks)
154
+ target_masks = target_masks[tgt_idx]
155
+
156
+ # No need to upsample predictions as we are using normalized coordinates :)
157
+ # N x 1 x H x W
158
+ src_masks = src_masks[:, None]
159
+ target_masks = target_masks[:, None]
160
+
161
+ with torch.no_grad():
162
+ # sample point_coords
163
+ point_coords = get_uncertain_point_coords_with_randomness(
164
+ src_masks,
165
+ lambda logits: calculate_uncertainty(logits),
166
+ self.num_points,
167
+ self.oversample_ratio,
168
+ self.importance_sample_ratio,
169
+ )
170
+ # get gt labels
171
+ point_labels = point_sample(
172
+ target_masks,
173
+ point_coords,
174
+ align_corners=False,
175
+ ).squeeze(1)
176
+
177
+ point_logits = point_sample(
178
+ src_masks,
179
+ point_coords,
180
+ align_corners=False,
181
+ ).squeeze(1)
182
+
183
+ losses = {
184
+ "loss_mask": sigmoid_ce_loss_jit(point_logits, point_labels, num_masks),
185
+ "loss_dice": dice_loss_jit(point_logits, point_labels, num_masks),
186
+ }
187
+
188
+ del src_masks
189
+ del target_masks
190
+ return losses
191
+
192
+ def _get_src_permutation_idx(self, indices):
193
+ # permute predictions following indices
194
+ batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
195
+ src_idx = torch.cat([src for (src, _) in indices])
196
+ return batch_idx, src_idx
197
+
198
+ def _get_tgt_permutation_idx(self, indices):
199
+ # permute targets following indices
200
+ batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
201
+ tgt_idx = torch.cat([tgt for (_, tgt) in indices])
202
+ return batch_idx, tgt_idx
203
+
204
+ def get_loss(self, loss, outputs, targets, indices, num_masks):
205
+ loss_map = {
206
+ 'labels': self.loss_labels,
207
+ 'masks': self.loss_masks,
208
+ }
209
+ assert loss in loss_map, f"do you really want to compute {loss} loss?"
210
+ return loss_map[loss](outputs, targets, indices, num_masks)
211
+
212
+ def forward(self, outputs, targets):
213
+ """This performs the loss computation.
214
+ Parameters:
215
+ outputs: dict of tensors, see the output specification of the model for the format
216
+ targets: list of dicts, such that len(targets) == batch_size.
217
+ The expected keys in each dict depends on the losses applied, see each loss' doc
218
+ """
219
+ outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"}
220
+
221
+ # Retrieve the matching between the outputs of the last layer and the targets
222
+ indices = self.matcher(outputs_without_aux, targets)
223
+
224
+ # Compute the average number of target boxes accross all nodes, for normalization purposes
225
+ num_masks = sum(len(t["labels"]) for t in targets)
226
+ num_masks = torch.as_tensor(
227
+ [num_masks], dtype=torch.float, device=next(iter(outputs.values())).device
228
+ )
229
+ if is_dist_avail_and_initialized():
230
+ torch.distributed.all_reduce(num_masks)
231
+ num_masks = torch.clamp(num_masks / get_world_size(), min=1).item()
232
+
233
+ # Compute all the requested losses
234
+ losses = {}
235
+ for loss in self.losses:
236
+ losses.update(self.get_loss(loss, outputs, targets, indices, num_masks))
237
+
238
+ # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
239
+ if "aux_outputs" in outputs:
240
+ for i, aux_outputs in enumerate(outputs["aux_outputs"]):
241
+ indices = self.matcher(aux_outputs, targets)
242
+ for loss in self.losses:
243
+ l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_masks)
244
+ l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
245
+ losses.update(l_dict)
246
+
247
+ return losses
248
+
249
+ def __repr__(self):
250
+ head = "Criterion " + self.__class__.__name__
251
+ body = [
252
+ "matcher: {}".format(self.matcher.__repr__(_repr_indent=8)),
253
+ "losses: {}".format(self.losses),
254
+ "weight_dict: {}".format(self.weight_dict),
255
+ "num_classes: {}".format(self.num_classes),
256
+ "eos_coef: {}".format(self.eos_coef),
257
+ "num_points: {}".format(self.num_points),
258
+ "oversample_ratio: {}".format(self.oversample_ratio),
259
+ "importance_sample_ratio: {}".format(self.importance_sample_ratio),
260
+ ]
261
+ _repr_indent = 4
262
+ lines = [head] + [" " * _repr_indent + line for line in body]
263
+ return "\n".join(lines)
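
A tiny numeric sketch of the two point-sampled losses defined at the top of this file, using made-up logits for two matched masks sampled at four points each (values are illustrative only):

    import torch

    logits = torch.tensor([[ 4.0,  4.0, -4.0, -4.0],
                           [-4.0,  4.0,  4.0,  4.0]])     # (num_masks, num_points)
    labels = torch.tensor([[1.0, 1.0, 0.0, 0.0],
                           [0.0, 1.0, 1.0, 1.0]])
    num_masks = 2.0

    print(dice_loss(logits, labels, num_masks))        # ~0.012, near-perfect overlap
    print(sigmoid_ce_loss(logits, labels, num_masks))  # ~0.018, predictions agree with labels
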
annotator/entityseg/mask2former/modeling/criterion_view.py ADDED
@@ -0,0 +1,288 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/detr.py
3
+ """
4
+ MaskFormer criterion.
5
+ """
6
+ import logging
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torch import nn
11
+
12
+ from detectron2.utils.comm import get_world_size
13
+ from detectron2.projects.point_rend.point_features import (
14
+ get_uncertain_point_coords_with_randomness,
15
+ point_sample,
16
+ )
17
+
18
+ from mask2former.utils.misc import is_dist_avail_and_initialized
19
+
20
+ import pdb
21
+
22
+
23
+ def dice_loss(
24
+ inputs: torch.Tensor,
25
+ targets: torch.Tensor,
26
+ num_masks: float,
27
+ ):
28
+ """
29
+ Compute the DICE loss, similar to generalized IOU for masks
30
+ Args:
31
+ inputs: A float tensor of arbitrary shape.
32
+ The predictions for each example.
33
+ targets: A float tensor with the same shape as inputs. Stores the binary
34
+ classification label for each element in inputs
35
+ (0 for the negative class and 1 for the positive class).
36
+ """
37
+ inputs = inputs.sigmoid()
38
+ inputs = inputs.flatten(1)
39
+ numerator = 2 * (inputs * targets).sum(-1)
40
+ denominator = inputs.sum(-1) + targets.sum(-1)
41
+ loss = 1 - (numerator + 1) / (denominator + 1)
42
+ return loss.sum() / num_masks
43
+
44
+
45
+ dice_loss_jit = torch.jit.script(
46
+ dice_loss
47
+ ) # type: torch.jit.ScriptModule
48
+
49
+
50
+ def sigmoid_ce_loss(
51
+ inputs: torch.Tensor,
52
+ targets: torch.Tensor,
53
+ num_masks: float,
54
+ ):
55
+ """
56
+ Args:
57
+ inputs: A float tensor of arbitrary shape.
58
+ The predictions for each example.
59
+ targets: A float tensor with the same shape as inputs. Stores the binary
60
+ classification label for each element in inputs
61
+ (0 for the negative class and 1 for the positive class).
62
+ Returns:
63
+ Loss tensor
64
+ """
65
+ loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
66
+
67
+ return loss.mean(1).sum() / num_masks
68
+
69
+
70
+ sigmoid_ce_loss_jit = torch.jit.script(
71
+ sigmoid_ce_loss
72
+ ) # type: torch.jit.ScriptModule
73
+
74
+
75
+ def calculate_uncertainty(logits):
76
+ """
77
+ We estimate uncertainty as L1 distance between 0.0 and the logit prediction in 'logits' for the
78
+ foreground class in `classes`.
79
+ Args:
80
+ logits (Tensor): A tensor of shape (R, 1, ...) for class-specific or
81
+ class-agnostic, where R is the total number of predicted masks in all images and C is
82
+ the number of foreground classes. The values are logits.
83
+ Returns:
84
+ scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with
85
+ the most uncertain locations having the highest uncertainty score.
86
+ """
87
+ assert logits.shape[1] == 1
88
+ gt_class_logits = logits.clone()
89
+ return -(torch.abs(gt_class_logits))
90
+
91
+
92
+ class ViewSetCriterion(nn.Module):
93
+ """This class computes the loss for DETR.
94
+ The process happens in two steps:
95
+ 1) we compute hungarian assignment between ground truth boxes and the outputs of the model
96
+ 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
97
+ """
98
+
99
+ def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses,
100
+ num_points, oversample_ratio, importance_sample_ratio):
101
+ """Create the criterion.
102
+ Parameters:
103
+ num_classes: number of object categories, omitting the special no-object category
104
+ matcher: module able to compute a matching between targets and proposals
105
+ weight_dict: dict containing as key the names of the losses and as values their relative weight.
106
+ eos_coef: relative classification weight applied to the no-object category
107
+ losses: list of all the losses to be applied. See get_loss for list of available losses.
108
+ """
109
+ super().__init__()
110
+ self.num_classes = num_classes
111
+ self.matcher = matcher
112
+ self.weight_dict = weight_dict
113
+ self.eos_coef = eos_coef
114
+ self.losses = losses
115
+ empty_weight = torch.ones(self.num_classes + 1)
116
+ empty_weight[-1] = self.eos_coef
117
+ self.register_buffer("empty_weight", empty_weight)
118
+
119
+ # pointwise mask loss parameters
120
+ self.num_points = num_points
121
+ self.oversample_ratio = oversample_ratio
122
+ self.importance_sample_ratio = importance_sample_ratio
123
+
124
+ def loss_labels(self, outputs, targets, indices, num_masks):
125
+ """Classification loss (NLL)
126
+ targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
127
+ """
128
+ assert "pred_logits" in outputs
129
+ src_logits = outputs["pred_logits"].float()
130
+ ## src_logits: torch.Size([2, 100, 41])
131
+
132
+ idx = self._get_src_permutation_idx(indices)
133
+ ## idx: (tensor([0, 0, 1, 1]), tensor([17, 84, 17, 76]))
134
+ target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
135
+ ### target_class_o: tensor([ 0, 26, 0, 11], device='cuda:0')
136
+ target_classes = torch.full(
137
+ src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
138
+ )
139
+ ## target_classes: torch.Size([2, 100]); every entry is 40, i.e. the background class
140
+ target_classes[idx] = target_classes_o
141
+ ##
142
+ ## src_logits: torch.Size([2, 41, 100])
143
+ ## target_classes: torch.Size([2, 100])
144
+ ## self.empty_weight: tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
145
+ ##1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
146
+ ##1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
147
+ ##1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
148
+ ##1.0000, 1.0000, 1.0000, 1.0000, 0.1000], device='cuda:0')
149
+ loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
150
+ losses = {"loss_ce": loss_ce}
151
+ return losses
152
+
153
+ def loss_masks(self, outputs, targets, indices, num_masks):
154
+ """Compute the losses related to the masks: the focal loss and the dice loss.
155
+ targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
156
+ """
157
+ assert "pred_masks" in outputs
158
+
159
+ src_idx = self._get_src_permutation_idx(indices)
160
+ ### src_idx: (tensor([0, 0, 1, 1]), tensor([34, 95, 32, 65]))
161
+ src_masks = outputs["pred_masks"]
162
+ ## src_masks: torch.Size([2, 100, 2, 120, 216])
163
+ src_masks = src_masks[src_idx]
164
+ ## src_masks: torch.Size([4, 2, 120, 216])
165
+
166
+ # Modified to handle video
167
+ target_masks = torch.cat([t['masks'][i] for t, (_, i) in zip(targets, indices)]).to(src_masks)
168
+ ### target_masks: torch.Size([4, 2, 480, 864])
169
+
170
+ # No need to upsample predictions as we are using normalized coordinates :)
171
+ # NT x 1 x H x W
172
+ src_masks = src_masks.flatten(0, 1)[:, None]
173
+ ## src_masks: torch.Size([8, 1, 120, 216])
174
+ target_masks = target_masks.flatten(0, 1)[:, None]
175
+ ## target_masks: torch.Size([8, 1, 480, 864])
176
+
177
+ with torch.no_grad():
178
+ # sample point_coords
179
+ point_coords = get_uncertain_point_coords_with_randomness(
180
+ src_masks,
181
+ lambda logits: calculate_uncertainty(logits),
182
+ self.num_points,
183
+ self.oversample_ratio,
184
+ self.importance_sample_ratio,
185
+ )
186
+ # get gt labels
187
+ point_labels = point_sample(
188
+ target_masks,
189
+ point_coords,
190
+ align_corners=False,
191
+ ).squeeze(1)
192
+
193
+ point_logits = point_sample(
194
+ src_masks,
195
+ point_coords,
196
+ align_corners=False,
197
+ ).squeeze(1)
198
+
199
+ losses = {
200
+ "loss_mask": sigmoid_ce_loss_jit(point_logits, point_labels, num_masks),
201
+ "loss_dice": dice_loss_jit(point_logits, point_labels, num_masks),
202
+ }
203
+
204
+ del src_masks
205
+ del target_masks
206
+ return losses
207
+
208
+ def _get_src_permutation_idx(self, indices):
209
+ # permute predictions following indices
210
+ batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
211
+ src_idx = torch.cat([src for (src, _) in indices])
212
+ return batch_idx, src_idx
213
+
214
+ def _get_tgt_permutation_idx(self, indices):
215
+ # permute targets following indices
216
+ batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
217
+ tgt_idx = torch.cat([tgt for (_, tgt) in indices])
218
+ return batch_idx, tgt_idx
219
+
220
+ def get_loss(self, loss, outputs, targets, indices, num_masks):
221
+ loss_map = {
222
+ 'labels': self.loss_labels,
223
+ 'masks': self.loss_masks,
224
+ }
225
+ assert loss in loss_map, f"do you really want to compute {loss} loss?"
226
+ return loss_map[loss](outputs, targets, indices, num_masks)
227
+
228
+ def forward(self, outputs, targets, return_indices=False):
229
+ """This performs the loss computation.
230
+ Parameters:
231
+ outputs: dict of tensors, see the output specification of the model for the format
232
+ targets: list of dicts, such that len(targets) == batch_size.
233
+ The expected keys in each dict depends on the losses applied, see each loss' doc
234
+ """
235
+ outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"}
236
+
237
+ # Retrieve the matching between the outputs of the last layer and the targets
238
+ indices = self.matcher(outputs_without_aux, targets)
239
+ indices_l = []
240
+ indices_l.append(indices)
241
+ # pdb.set_trace()
242
+
243
+ # Compute the average number of target boxes across all nodes, for normalization purposes
244
+ num_masks = sum(len(t["labels"]) for t in targets)
245
+ num_masks = torch.as_tensor(
246
+ [num_masks], dtype=torch.float, device=next(iter(outputs.values())).device
247
+ )
248
+ if is_dist_avail_and_initialized():
249
+ torch.distributed.all_reduce(num_masks)
250
+ num_masks = torch.clamp(num_masks / get_world_size(), min=1).item()
251
+
252
+ # Compute all the requested losses
253
+ losses = {}
254
+ for loss in self.losses:
255
+ losses.update(self.get_loss(loss, outputs, targets, indices, num_masks))
256
+
257
+ # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
258
+ if "aux_outputs" in outputs:
259
+ for i, aux_outputs in enumerate(outputs["aux_outputs"]):
260
+ indices = self.matcher(aux_outputs, targets)
261
+ indices_l.append(indices)
262
+ for loss in self.losses:
263
+ l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_masks)
264
+ l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
265
+ losses.update(l_dict)
266
+ indices_l.append(indices_l[0])
267
+ indices_l = indices_l[1:]
268
+
269
+ if return_indices:
270
+ return losses, indices_l
271
+ else:
272
+ return losses
273
+
274
+ def __repr__(self):
275
+ head = "Criterion " + self.__class__.__name__
276
+ body = [
277
+ "matcher: {}".format(self.matcher.__repr__(_repr_indent=8)),
278
+ "losses: {}".format(self.losses),
279
+ "weight_dict: {}".format(self.weight_dict),
280
+ "num_classes: {}".format(self.num_classes),
281
+ "eos_coef: {}".format(self.eos_coef),
282
+ "num_points: {}".format(self.num_points),
283
+ "oversample_ratio: {}".format(self.oversample_ratio),
284
+ "importance_sample_ratio: {}".format(self.importance_sample_ratio),
285
+ ]
286
+ _repr_indent = 4
287
+ lines = [head] + [" " * _repr_indent + line for line in body]
288
+ return "\n".join(lines)
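
The structural difference from SetCriterion above is the extra view/frame dimension on the masks, which loss_masks folds into the batch axis before point sampling. A one-line sketch of that reshape, using the illustrative sizes from the inline comments:

    import torch

    src_masks = torch.randn(4, 2, 120, 216)          # (matched queries, views, H, W)
    src_masks = src_masks.flatten(0, 1)[:, None]     # -> (8, 1, 120, 216): one mask per view
    print(src_masks.shape)
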
annotator/entityseg/mask2former/modeling/matcher.py ADDED
@@ -0,0 +1,189 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/matcher.py
3
+ """
4
+ Modules to compute the matching cost and solve the corresponding LSAP.
5
+ """
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from scipy.optimize import linear_sum_assignment
9
+ from torch import nn
10
+ from torch.cuda.amp import autocast
11
+
12
+ from detectron2.projects.point_rend.point_features import point_sample
13
+
14
+
15
+ def batch_dice_loss(inputs: torch.Tensor, targets: torch.Tensor):
16
+ """
17
+ Compute the DICE loss, similar to generalized IOU for masks
18
+ Args:
19
+ inputs: A float tensor of arbitrary shape.
20
+ The predictions for each example.
21
+ targets: A float tensor with the same shape as inputs. Stores the binary
22
+ classification label for each element in inputs
23
+ (0 for the negative class and 1 for the positive class).
24
+ """
25
+ inputs = inputs.sigmoid()
26
+ inputs = inputs.flatten(1)
27
+ numerator = 2 * torch.einsum("nc,mc->nm", inputs, targets)
28
+ denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :]
29
+ loss = 1 - (numerator + 1) / (denominator + 1)
30
+ return loss
31
+
32
+
33
+ batch_dice_loss_jit = torch.jit.script(
34
+ batch_dice_loss
35
+ ) # type: torch.jit.ScriptModule
36
+
37
+
38
+ def batch_sigmoid_ce_loss(inputs: torch.Tensor, targets: torch.Tensor):
39
+ """
40
+ Args:
41
+ inputs: A float tensor of arbitrary shape.
42
+ The predictions for each example.
43
+ targets: A float tensor with the same shape as inputs. Stores the binary
44
+ classification label for each element in inputs
45
+ (0 for the negative class and 1 for the positive class).
46
+ Returns:
47
+ Loss tensor
48
+ """
49
+ hw = inputs.shape[1]
50
+
51
+ pos = F.binary_cross_entropy_with_logits(
52
+ inputs, torch.ones_like(inputs), reduction="none"
53
+ )
54
+ neg = F.binary_cross_entropy_with_logits(
55
+ inputs, torch.zeros_like(inputs), reduction="none"
56
+ )
57
+
58
+ loss = torch.einsum("nc,mc->nm", pos, targets) + torch.einsum(
59
+ "nc,mc->nm", neg, (1 - targets)
60
+ )
61
+
62
+ return loss / hw
63
+
64
+
65
+ batch_sigmoid_ce_loss_jit = torch.jit.script(
66
+ batch_sigmoid_ce_loss
67
+ ) # type: torch.jit.ScriptModule
68
+
69
+
70
+ class HungarianMatcher(nn.Module):
71
+ """This class computes an assignment between the targets and the predictions of the network
72
+
73
+ For efficiency reasons, the targets don't include the no_object. Because of this, in general,
74
+ there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
75
+ while the others are un-matched (and thus treated as non-objects).
76
+ """
77
+
78
+ def __init__(self, cost_class: float = 1, cost_mask: float = 1, cost_dice: float = 1, num_points: int = 0):
79
+ """Creates the matcher
80
+
81
+ Params:
82
+ cost_class: This is the relative weight of the classification error in the matching cost
83
+ cost_mask: This is the relative weight of the focal loss of the binary mask in the matching cost
84
+ cost_dice: This is the relative weight of the dice loss of the binary mask in the matching cost
85
+ """
86
+ super().__init__()
87
+ self.cost_class = cost_class
88
+ self.cost_mask = cost_mask
89
+ self.cost_dice = cost_dice
90
+
91
+ assert cost_class != 0 or cost_mask != 0 or cost_dice != 0, "all costs can't be 0"
92
+
93
+ self.num_points = num_points
94
+
95
+ @torch.no_grad()
96
+ def memory_efficient_forward(self, outputs, targets):
97
+ """More memory-friendly matching"""
98
+ bs, num_queries = outputs["pred_logits"].shape[:2]
99
+
100
+ indices = []
101
+
102
+ # Iterate through batch size
103
+ for b in range(bs):
104
+
105
+ out_prob = outputs["pred_logits"][b].softmax(-1) # [num_queries, num_classes]
106
+ tgt_ids = targets[b]["labels"]
107
+
108
+ # Compute the classification cost. Contrary to the loss, we don't use the NLL,
109
+ # but approximate it in 1 - proba[target class].
110
+ # The 1 is a constant that doesn't change the matching, so it can be omitted.
111
+ cost_class = -out_prob[:, tgt_ids]
112
+
113
+ out_mask = outputs["pred_masks"][b] # [num_queries, H_pred, W_pred]
114
+ # gt masks are already padded when preparing target
115
+ tgt_mask = targets[b]["masks"].to(out_mask)
116
+
117
+ out_mask = out_mask[:, None]
118
+ tgt_mask = tgt_mask[:, None]
119
+ # all masks share the same set of points for efficient matching!
120
+ point_coords = torch.rand(1, self.num_points, 2, device=out_mask.device)
121
+ # get gt labels
122
+ tgt_mask = point_sample(
123
+ tgt_mask,
124
+ point_coords.repeat(tgt_mask.shape[0], 1, 1),
125
+ align_corners=False,
126
+ ).squeeze(1)
127
+
128
+ out_mask = point_sample(
129
+ out_mask,
130
+ point_coords.repeat(out_mask.shape[0], 1, 1),
131
+ align_corners=False,
132
+ ).squeeze(1)
133
+
134
+ with autocast(enabled=False):
135
+ out_mask = out_mask.float()
136
+ tgt_mask = tgt_mask.float()
137
+ # Compute the sigmoid cross-entropy loss between masks
138
+ cost_mask = batch_sigmoid_ce_loss(out_mask, tgt_mask)
139
+
140
+ # Compute the dice loss between masks
141
+ cost_dice = batch_dice_loss(out_mask, tgt_mask)
142
+
143
+ # Final cost matrix
144
+ C = (
145
+ self.cost_mask * cost_mask
146
+ + self.cost_class * cost_class
147
+ + self.cost_dice * cost_dice
148
+ )
149
+ C = C.reshape(num_queries, -1).cpu()
150
+
151
+ indices.append(linear_sum_assignment(C))
152
+
153
+ return [
154
+ (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
155
+ for i, j in indices
156
+ ]
157
+
158
+ @torch.no_grad()
159
+ def forward(self, outputs, targets):
160
+ """Performs the matching
161
+
162
+ Params:
163
+ outputs: This is a dict that contains at least these entries:
164
+ "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
165
+ "pred_masks": Tensor of dim [batch_size, num_queries, H_pred, W_pred] with the predicted masks
166
+
167
+ targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
168
+ "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
169
+ objects in the target) containing the class labels
170
+ "masks": Tensor of dim [num_target_boxes, H_gt, W_gt] containing the target masks
171
+
172
+ Returns:
173
+ A list of size batch_size, containing tuples of (index_i, index_j) where:
174
+ - index_i is the indices of the selected predictions (in order)
175
+ - index_j is the indices of the corresponding selected targets (in order)
176
+ For each batch element, it holds:
177
+ len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
178
+ """
179
+ return self.memory_efficient_forward(outputs, targets)
180
+
181
+ def __repr__(self, _repr_indent=4):
182
+ head = "Matcher " + self.__class__.__name__
183
+ body = [
184
+ "cost_class: {}".format(self.cost_class),
185
+ "cost_mask: {}".format(self.cost_mask),
186
+ "cost_dice: {}".format(self.cost_dice),
187
+ ]
188
+ lines = [head] + [" " * _repr_indent + line for line in body]
189
+ return "\n".join(lines)
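The pairwise costs above are easiest to see on toy tensors. Below is a minimal sketch (not part of the uploaded diff) with assumed shapes; toy_batch_dice_cost re-implements the dice cost from this file purely for illustration, and the final step mirrors the linear_sum_assignment call in memory_efficient_forward.

import torch
from scipy.optimize import linear_sum_assignment

def toy_batch_dice_cost(pred_logits, gt_masks):
    # pred_logits: [num_queries, num_points]; gt_masks: [num_gts, num_points] in {0, 1}
    probs = pred_logits.sigmoid()
    numerator = 2 * torch.einsum("nc,mc->nm", probs, gt_masks)
    denominator = probs.sum(-1)[:, None] + gt_masks.sum(-1)[None, :]
    return 1 - (numerator + 1) / (denominator + 1)  # pairwise cost, [num_queries, num_gts]

pred = torch.randn(5, 16)               # 5 queries, 16 sampled points each (toy sizes)
gt = (torch.rand(2, 16) > 0.5).float()  # 2 ground-truth masks
C = toy_batch_dice_cost(pred, gt)       # [5, 2] cost matrix
rows, cols = linear_sum_assignment(C.numpy())
print(rows, cols)                       # every ground truth is matched to exactly one query

Each ground-truth row ends up paired with one query index, which is exactly the (index_i, index_j) list the matcher returns per image.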
annotator/entityseg/mask2former/modeling/matcher_view.py ADDED
@@ -0,0 +1,194 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/matcher.py
3
+ """
4
+ Modules to compute the matching cost and solve the corresponding LSAP.
5
+ """
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from scipy.optimize import linear_sum_assignment
9
+ from torch import nn
10
+ from torch.cuda.amp import autocast
11
+
12
+ from detectron2.projects.point_rend.point_features import point_sample
13
+
14
+ def batch_dice_loss(inputs: torch.Tensor, targets: torch.Tensor):
15
+ """
16
+ Compute the DICE loss, similar to generalized IOU for masks
17
+ Args:
18
+ inputs: A float tensor of arbitrary shape.
19
+ The predictions for each example.
20
+ targets: A float tensor with the same shape as inputs. Stores the binary
21
+ classification label for each element in inputs
22
+ (0 for the negative class and 1 for the positive class).
23
+ """
24
+ inputs = inputs.sigmoid()
25
+ inputs = inputs.flatten(1)
26
+ numerator = 2 * torch.einsum("nc,mc->nm", inputs, targets)
27
+ denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :]
28
+ loss = 1 - (numerator + 1) / (denominator + 1)
29
+ return loss
30
+
31
+
32
+ batch_dice_loss_jit = torch.jit.script(
33
+ batch_dice_loss
34
+ ) # type: torch.jit.ScriptModule
35
+
36
+
37
+ def batch_sigmoid_ce_loss(inputs: torch.Tensor, targets: torch.Tensor):
38
+ """
39
+ Args:
40
+ inputs: A float tensor of arbitrary shape.
41
+ The predictions for each example.
42
+ targets: A float tensor with the same shape as inputs. Stores the binary
43
+ classification label for each element in inputs
44
+ (0 for the negative class and 1 for the positive class).
45
+ Returns:
46
+ Loss tensor
47
+ """
48
+ hw = inputs.shape[1]
49
+
50
+ pos = F.binary_cross_entropy_with_logits(
51
+ inputs, torch.ones_like(inputs), reduction="none"
52
+ )
53
+ neg = F.binary_cross_entropy_with_logits(
54
+ inputs, torch.zeros_like(inputs), reduction="none"
55
+ )
56
+
57
+ loss = torch.einsum("nc,mc->nm", pos, targets) + torch.einsum(
58
+ "nc,mc->nm", neg, (1 - targets)
59
+ )
60
+
61
+ return loss / hw
62
+
63
+
64
+ batch_sigmoid_ce_loss_jit = torch.jit.script(
65
+ batch_sigmoid_ce_loss
66
+ ) # type: torch.jit.ScriptModule
67
+
68
+
69
+ class ViewHungarianMatcher(nn.Module):
70
+ """This class computes an assignment between the targets and the predictions of the network
71
+
72
+ For efficiency reasons, the targets don't include the no_object. Because of this, in general,
73
+ there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
74
+ while the others are un-matched (and thus treated as non-objects).
75
+ """
76
+
77
+ def __init__(self, cost_class: float = 1, cost_mask: float = 1, cost_dice: float = 1, num_points: int = 0):
78
+ """Creates the matcher
79
+
80
+ Params:
81
+ cost_class: This is the relative weight of the classification error in the matching cost
82
+ cost_mask: This is the relative weight of the focal loss of the binary mask in the matching cost
83
+ cost_dice: This is the relative weight of the dice loss of the binary mask in the matching cost
84
+ """
85
+ super().__init__()
86
+ self.cost_class = cost_class
87
+ self.cost_mask = cost_mask
88
+ self.cost_dice = cost_dice
89
+
90
+ assert cost_class != 0 or cost_mask != 0 or cost_dice != 0, "all costs can't be 0"
91
+
92
+ self.num_points = num_points
93
+
94
+ @torch.no_grad()
95
+ def memory_efficient_forward(self, outputs, targets):
96
+ """More memory-friendly matching"""
97
+ ### outputs["pred_logits"]: torch.Size([2, 100, 41]); each query is shared across both frames, so there is no frame dimension here
98
+ ### outputs["pred_masks"]: torch.Size([2, 100, 2, 120, 160]); the 2 in the third dimension indexes the two frames
99
+ bs, num_queries = outputs["pred_logits"].shape[:2]
100
+
101
+ indices = []
102
+
103
+ # Iterate through batch size
104
+ for b in range(bs):
105
+ out_prob = outputs["pred_logits"][b].softmax(-1) # [num_queries, num_classes]
106
+ ## out_prob: [100, 41]; 100 queries, 40 classes plus one background class
107
+ tgt_ids = targets[b]["labels"]
108
+ ## tgt_ids: e.g. tensor([3, 10]), i.e. this sample has only two ground-truth instances
109
+
110
+ # Compute the classification cost. Contrary to the loss, we don't use the NLL,
111
+ # but approximate it in 1 - proba[target class].
112
+ # The 1 is a constant that doesn't change the matching, so it can be omitted.
113
+ cost_class = -out_prob[:, tgt_ids]
114
+
115
+ out_mask = outputs["pred_masks"][b] # [num_queries, T, H_pred, W_pred]
116
+ ### out_mask: torch.Size([100, 2, 120, 160])
117
+ # gt masks are already padded when preparing target
118
+ tgt_mask = targets[b]["masks"].to(out_mask) # [num_gts, T, H_pred, W_pred]
119
+ ## tgt_mask: torch.Size([2, 2, 480, 640])
120
+
121
+ # out_mask = out_mask[:, None]
122
+ # tgt_mask = tgt_mask[:, None]
123
+ # all masks share the same set of points for efficient matching!
124
+ point_coords = torch.rand(1, self.num_points, 2, device=out_mask.device)
125
+ # get gt labels
126
+ tgt_mask = point_sample(
127
+ tgt_mask,
128
+ point_coords.repeat(tgt_mask.shape[0], 1, 1), ## repeated per gt, torch.Size([2, 12544, 2]); every frame is sampled at the same locations
129
+ align_corners=False,
130
+ ).flatten(1)
131
+
132
+ out_mask = point_sample(
133
+ out_mask,
134
+ point_coords.repeat(out_mask.shape[0], 1, 1),
135
+ align_corners=False,
136
+ ).flatten(1)
137
+
138
+ with autocast(enabled=False):
139
+ out_mask = out_mask.float() ## out_mask: torch.Size([100, 25088])
140
+ tgt_mask = tgt_mask.float() ## tgt_mask: torch.Size([2, 25088])
141
+ # Compute the sigmoid cross-entropy loss between masks
142
+ cost_mask = batch_sigmoid_ce_loss_jit(out_mask, tgt_mask) ## cost_mask: torch.Size([100, 2])
143
+
144
+ # Compute the dice loss between masks
145
+ cost_dice = batch_dice_loss_jit(out_mask, tgt_mask) ## cost_dice: torch.Size([100, 2])
146
+
147
+ # Final cost matrix
148
+ C = (
149
+ self.cost_mask * cost_mask
150
+ + self.cost_class * cost_class
151
+ + self.cost_dice * cost_dice
152
+ )
153
+ C = C.reshape(num_queries, -1).cpu()
154
+
155
+ indices.append(linear_sum_assignment(C))
156
+ ## [(array([17, 33]), array([1, 0]), ...]
157
+
158
+ return [
159
+ (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
160
+ for i, j in indices
161
+ ]
162
+
163
+ @torch.no_grad()
164
+ def forward(self, outputs, targets):
165
+ """Performs the matching
166
+
167
+ Params:
168
+ outputs: This is a dict that contains at least these entries:
169
+ "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
170
+ "pred_masks": Tensor of dim [batch_size, num_queries, H_pred, W_pred] with the predicted masks
171
+
172
+ targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
173
+ "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
174
+ objects in the target) containing the class labels
175
+ "masks": Tensor of dim [num_target_boxes, H_gt, W_gt] containing the target masks
176
+
177
+ Returns:
178
+ A list of size batch_size, containing tuples of (index_i, index_j) where:
179
+ - index_i is the indices of the selected predictions (in order)
180
+ - index_j is the indices of the corresponding selected targets (in order)
181
+ For each batch element, it holds:
182
+ len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
183
+ """
184
+ return self.memory_efficient_forward(outputs, targets)
185
+
186
+ def __repr__(self, _repr_indent=4):
187
+ head = "Matcher " + self.__class__.__name__
188
+ body = [
189
+ "cost_class: {}".format(self.cost_class),
190
+ "cost_mask: {}".format(self.cost_mask),
191
+ "cost_dice: {}".format(self.cost_dice),
192
+ ]
193
+ lines = [head] + [" " * _repr_indent + line for line in body]
194
+ return "\n".join(lines)
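The view matcher differs from the single-frame one mainly in how masks are point-sampled: both frames share one set of random points, and the sampled values of both frames are flattened into a single vector per query before computing the cost. A rough, self-contained sketch of that step (plain F.grid_sample standing in for detectron2's point_sample, with assumed toy shapes):

import torch
import torch.nn.functional as F

def sample_points(masks, coords01):
    # masks: [N, T, H, W]; coords01: [1, P, 2] with (x, y) in [0, 1]
    grid = 2.0 * coords01 - 1.0                             # grid_sample expects [-1, 1]
    grid = grid.repeat(masks.shape[0], 1, 1).unsqueeze(2)   # [N, P, 1, 2]
    out = F.grid_sample(masks, grid, align_corners=False)   # [N, T, P, 1]
    return out.squeeze(-1).flatten(1)                       # frames flattened together: [N, T*P]

pred = torch.randn(100, 2, 120, 160)  # 100 queries, 2 frames
gt = torch.rand(2, 2, 480, 640)       # 2 ground truths, 2 frames
pts = torch.rand(1, 32, 2)            # the shared sampling locations
print(sample_points(pred, pts).shape, sample_points(gt, pts).shape)  # [100, 64], [2, 64]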
annotator/entityseg/mask2former/modeling/meta_arch/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
annotator/entityseg/mask2former/modeling/meta_arch/mask_former_head.py ADDED
@@ -0,0 +1,133 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import logging
3
+ from copy import deepcopy
4
+ from typing import Callable, Dict, List, Optional, Tuple, Union
5
+
6
+ import fvcore.nn.weight_init as weight_init
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+
10
+ from detectron2.config import configurable
11
+ from detectron2.layers import Conv2d, ShapeSpec, get_norm
12
+ from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
13
+
14
+ from ..transformer_decoder.maskformer_transformer_decoder import build_transformer_decoder
15
+ from ..pixel_decoder.fpn import build_pixel_decoder
16
+
17
+
18
+ @SEM_SEG_HEADS_REGISTRY.register()
19
+ class MaskFormerHead(nn.Module):
20
+
21
+ _version = 2
22
+
23
+ def _load_from_state_dict(
24
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
25
+ ):
26
+ version = local_metadata.get("version", None)
27
+ if version is None or version < 2:
28
+ # Do not warn if train from scratch
29
+ scratch = True
30
+ logger = logging.getLogger(__name__)
31
+ for k in list(state_dict.keys()):
32
+ newk = k
33
+ if "sem_seg_head" in k and not k.startswith(prefix + "predictor"):
34
+ # newk = k.replace(prefix, prefix + "pixel_decoder.")
35
+ newk = k.replace(prefix, prefix)
36
+ # logger.debug(f"{k} ==> {newk}")
37
+ if newk != k:
38
+ state_dict[newk] = state_dict[k]
39
+ del state_dict[k]
40
+ scratch = False
41
+
42
+ if not scratch:
43
+ logger.warning(
44
+ f"Weight format of {self.__class__.__name__} has changed! "
45
+ "Please upgrade your models. Applying automatic conversion now ..."
46
+ )
47
+
48
+ @configurable
49
+ def __init__(
50
+ self,
51
+ input_shape: Dict[str, ShapeSpec],
52
+ *,
53
+ num_classes: int,
54
+ pixel_decoder: nn.Module,
55
+ loss_weight: float = 1.0,
56
+ ignore_value: int = -1,
57
+ # extra parameters
58
+ transformer_predictor: nn.Module,
59
+ transformer_in_feature: str,
60
+ ):
61
+ """
62
+ NOTE: this interface is experimental.
63
+ Args:
64
+ input_shape: shapes (channels and stride) of the input features
65
+ num_classes: number of classes to predict
66
+ pixel_decoder: the pixel decoder module
67
+ loss_weight: loss weight
68
+ ignore_value: category id to be ignored during training.
69
+ transformer_predictor: the transformer decoder that makes prediction
70
+ transformer_in_feature: input feature name to the transformer_predictor
71
+ """
72
+ super().__init__()
73
+ input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
74
+ self.in_features = [k for k, v in input_shape]
75
+ feature_strides = [v.stride for k, v in input_shape]
76
+ feature_channels = [v.channels for k, v in input_shape]
77
+
78
+ self.ignore_value = ignore_value
79
+ self.common_stride = 4
80
+ self.loss_weight = loss_weight
81
+
82
+ self.pixel_decoder = pixel_decoder
83
+ self.predictor = transformer_predictor
84
+ self.transformer_in_feature = transformer_in_feature
85
+
86
+ self.num_classes = num_classes
87
+
88
+ @classmethod
89
+ def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
90
+ # figure out in_channels to transformer predictor
91
+ if cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "transformer_encoder":
92
+ transformer_predictor_in_channels = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
93
+ elif cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "pixel_embedding":
94
+ transformer_predictor_in_channels = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
95
+ elif cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "multi_scale_pixel_decoder": # for maskformer2
96
+ transformer_predictor_in_channels = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
97
+ else:
98
+ transformer_predictor_in_channels = input_shape[cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE].channels
99
+
100
+ return {
101
+ "input_shape": {
102
+ k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
103
+ },
104
+ "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
105
+ "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
106
+ "pixel_decoder": build_pixel_decoder(cfg, input_shape),
107
+ "loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT,
108
+ "transformer_in_feature": cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE,
109
+ "transformer_predictor": build_transformer_decoder(
110
+ cfg,
111
+ transformer_predictor_in_channels,
112
+ mask_classification=True,
113
+ ),
114
+ }
115
+
116
+ def forward(self, features, mask=None):
117
+ return self.layers(features, mask)
118
+
119
+ def layers(self, features, mask=None):
120
+ mask_features, transformer_encoder_features, multi_scale_features = self.pixel_decoder.forward_features(features)
121
+ if self.transformer_in_feature == "multi_scale_pixel_decoder":
122
+ predictions = self.predictor(multi_scale_features, mask_features, mask)
123
+ else:
124
+ if self.transformer_in_feature == "transformer_encoder":
125
+ assert (
126
+ transformer_encoder_features is not None
127
+ ), "Please use the TransformerEncoderPixelDecoder."
128
+ predictions = self.predictor(transformer_encoder_features, mask_features, mask)
129
+ elif self.transformer_in_feature == "pixel_embedding":
130
+ predictions = self.predictor(mask_features, mask_features, mask)
131
+ else:
132
+ predictions = self.predictor(features[self.transformer_in_feature], mask_features, mask)
133
+ return predictions
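One small detail worth making concrete: the head sorts its input ShapeSpecs by stride before anything else, which is what produces the res2-to-res5 ordering of self.in_features. A minimal sketch with assumed (typical ResNet) channels and strides:

from detectron2.layers import ShapeSpec

input_shape = {  # assumed backbone specs, for illustration only
    "res5": ShapeSpec(channels=2048, stride=32),
    "res3": ShapeSpec(channels=512, stride=8),
    "res2": ShapeSpec(channels=256, stride=4),
    "res4": ShapeSpec(channels=1024, stride=16),
}
ordered = sorted(input_shape.items(), key=lambda x: x[1].stride)
print([k for k, _ in ordered])  # ['res2', 'res3', 'res4', 'res5']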
annotator/entityseg/mask2former/modeling/meta_arch/per_pixel_baseline.py ADDED
@@ -0,0 +1,243 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import logging
3
+ from typing import Callable, Dict, List, Optional, Tuple, Union
4
+
5
+ import fvcore.nn.weight_init as weight_init
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+
9
+ from detectron2.config import configurable
10
+ from detectron2.layers import Conv2d, ShapeSpec, get_norm
11
+ from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
12
+
13
+ from ..transformer_decoder.maskformer_transformer_decoder import StandardTransformerDecoder
14
+ from ..pixel_decoder.fpn import build_pixel_decoder
15
+
16
+
17
+ @SEM_SEG_HEADS_REGISTRY.register()
18
+ class PerPixelBaselineHead(nn.Module):
19
+
20
+ _version = 2
21
+
22
+ def _load_from_state_dict(
23
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
24
+ ):
25
+ version = local_metadata.get("version", None)
26
+ if version is None or version < 2:
27
+ logger = logging.getLogger(__name__)
28
+ # Do not warn if train from scratch
29
+ scratch = True
30
+ logger = logging.getLogger(__name__)
31
+ for k in list(state_dict.keys()):
32
+ newk = k
33
+ if "sem_seg_head" in k and not k.startswith(prefix + "predictor"):
34
+ newk = k.replace(prefix, prefix + "pixel_decoder.")
35
+ # logger.warning(f"{k} ==> {newk}")
36
+ if newk != k:
37
+ state_dict[newk] = state_dict[k]
38
+ del state_dict[k]
39
+ scratch = False
40
+
41
+ if not scratch:
42
+ logger.warning(
43
+ f"Weight format of {self.__class__.__name__} has changed! "
44
+ "Please upgrade your models. Applying automatic conversion now ..."
45
+ )
46
+
47
+ @configurable
48
+ def __init__(
49
+ self,
50
+ input_shape: Dict[str, ShapeSpec],
51
+ *,
52
+ num_classes: int,
53
+ pixel_decoder: nn.Module,
54
+ loss_weight: float = 1.0,
55
+ ignore_value: int = -1,
56
+ ):
57
+ """
58
+ NOTE: this interface is experimental.
59
+ Args:
60
+ input_shape: shapes (channels and stride) of the input features
61
+ num_classes: number of classes to predict
62
+ pixel_decoder: the pixel decoder module
63
+ loss_weight: loss weight
64
+ ignore_value: category id to be ignored during training.
65
+ """
66
+ super().__init__()
67
+ input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
68
+ self.in_features = [k for k, v in input_shape]
69
+ feature_strides = [v.stride for k, v in input_shape]
70
+ feature_channels = [v.channels for k, v in input_shape]
71
+
72
+ self.ignore_value = ignore_value
73
+ self.common_stride = 4
74
+ self.loss_weight = loss_weight
75
+
76
+ self.pixel_decoder = pixel_decoder
77
+ self.predictor = Conv2d(
78
+ self.pixel_decoder.mask_dim, num_classes, kernel_size=1, stride=1, padding=0
79
+ )
80
+ weight_init.c2_msra_fill(self.predictor)
81
+
82
+ @classmethod
83
+ def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
84
+ return {
85
+ "input_shape": {
86
+ k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
87
+ },
88
+ "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
89
+ "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
90
+ "pixel_decoder": build_pixel_decoder(cfg, input_shape),
91
+ "loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT,
92
+ }
93
+
94
+ def forward(self, features, targets=None):
95
+ """
96
+ Returns:
97
+ In training, returns (None, dict of losses)
98
+ In inference, returns (CxHxW logits, {})
99
+ """
100
+ x = self.layers(features)
101
+ if self.training:
102
+ return None, self.losses(x, targets)
103
+ else:
104
+ x = F.interpolate(
105
+ x, scale_factor=self.common_stride, mode="bilinear", align_corners=False
106
+ )
107
+ return x, {}
108
+
109
+ def layers(self, features):
110
+ x, _, _ = self.pixel_decoder.forward_features(features)
111
+ x = self.predictor(x)
112
+ return x
113
+
114
+ def losses(self, predictions, targets):
115
+ predictions = predictions.float() # https://github.com/pytorch/pytorch/issues/48163
116
+ predictions = F.interpolate(
117
+ predictions, scale_factor=self.common_stride, mode="bilinear", align_corners=False
118
+ )
119
+ loss = F.cross_entropy(
120
+ predictions, targets, reduction="mean", ignore_index=self.ignore_value
121
+ )
122
+ losses = {"loss_sem_seg": loss * self.loss_weight}
123
+ return losses
124
+
125
+
126
+ @SEM_SEG_HEADS_REGISTRY.register()
127
+ class PerPixelBaselinePlusHead(PerPixelBaselineHead):
128
+ def _load_from_state_dict(
129
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
130
+ ):
131
+ version = local_metadata.get("version", None)
132
+ if version is None or version < 2:
133
+ # Do not warn if train from scratch
134
+ scratch = True
135
+ logger = logging.getLogger(__name__)
136
+ for k in list(state_dict.keys()):
137
+ newk = k
138
+ if "sem_seg_head" in k and not k.startswith(prefix + "predictor"):
139
+ newk = k.replace(prefix, prefix + "pixel_decoder.")
140
+ logger.debug(f"{k} ==> {newk}")
141
+ if newk != k:
142
+ state_dict[newk] = state_dict[k]
143
+ del state_dict[k]
144
+ scratch = False
145
+
146
+ if not scratch:
147
+ logger.warning(
148
+ f"Weight format of {self.__class__.__name__} has changed! "
149
+ "Please upgrade your models. Applying automatic conversion now ..."
150
+ )
151
+
152
+ @configurable
153
+ def __init__(
154
+ self,
155
+ input_shape: Dict[str, ShapeSpec],
156
+ *,
157
+ # extra parameters
158
+ transformer_predictor: nn.Module,
159
+ transformer_in_feature: str,
160
+ deep_supervision: bool,
161
+ # inherit parameters
162
+ num_classes: int,
163
+ pixel_decoder: nn.Module,
164
+ loss_weight: float = 1.0,
165
+ ignore_value: int = -1,
166
+ ):
167
+ """
168
+ NOTE: this interface is experimental.
169
+ Args:
170
+ input_shape: shapes (channels and stride) of the input features
171
+ transformer_predictor: the transformer decoder that makes prediction
172
+ transformer_in_feature: input feature name to the transformer_predictor
173
+ deep_supervision: whether or not to add supervision to the output of
174
+ every transformer decoder layer
175
+ num_classes: number of classes to predict
176
+ pixel_decoder: the pixel decoder module
177
+ loss_weight: loss weight
178
+ ignore_value: category id to be ignored during training.
179
+ """
180
+ super().__init__(
181
+ input_shape,
182
+ num_classes=num_classes,
183
+ pixel_decoder=pixel_decoder,
184
+ loss_weight=loss_weight,
185
+ ignore_value=ignore_value,
186
+ )
187
+
188
+ del self.predictor
189
+
190
+ self.predictor = transformer_predictor
191
+ self.transformer_in_feature = transformer_in_feature
192
+ self.deep_supervision = deep_supervision
193
+
194
+ @classmethod
195
+ def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
196
+ ret = super().from_config(cfg, input_shape)
197
+ ret["transformer_in_feature"] = cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE
198
+ if cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "transformer_encoder":
199
+ in_channels = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
200
+ else:
201
+ in_channels = input_shape[ret["transformer_in_feature"]].channels
202
+ ret["transformer_predictor"] = StandardTransformerDecoder(
203
+ cfg, in_channels, mask_classification=False
204
+ )
205
+ ret["deep_supervision"] = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION
206
+ return ret
207
+
208
+ def forward(self, features, targets=None):
209
+ """
210
+ Returns:
211
+ In training, returns (None, dict of losses)
212
+ In inference, returns (CxHxW logits, {})
213
+ """
214
+ x, aux_outputs = self.layers(features)
215
+ if self.training:
216
+ if self.deep_supervision:
217
+ losses = self.losses(x, targets)
218
+ for i, aux_output in enumerate(aux_outputs):
219
+ losses["loss_sem_seg" + f"_{i}"] = self.losses(
220
+ aux_output["pred_masks"], targets
221
+ )["loss_sem_seg"]
222
+ return None, losses
223
+ else:
224
+ return None, self.losses(x, targets)
225
+ else:
226
+ x = F.interpolate(
227
+ x, scale_factor=self.common_stride, mode="bilinear", align_corners=False
228
+ )
229
+ return x, {}
230
+
231
+ def layers(self, features):
232
+ mask_features, transformer_encoder_features, _ = self.pixel_decoder.forward_features(features)
233
+ if self.transformer_in_feature == "transformer_encoder":
234
+ assert (
235
+ transformer_encoder_features is not None
236
+ ), "Please use the TransformerEncoderPixelDecoder."
237
+ predictions = self.predictor(transformer_encoder_features, mask_features)
238
+ else:
239
+ predictions = self.predictor(features[self.transformer_in_feature], mask_features)
240
+ if self.deep_supervision:
241
+ return predictions["pred_masks"], predictions["aux_outputs"]
242
+ else:
243
+ return predictions["pred_masks"], None
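The per-pixel baseline's loss is nothing more than bilinear upsampling by common_stride followed by cross-entropy with an ignore index. A minimal sketch with assumed toy values (num_classes=19, ignore_value=255 are illustrative, not taken from the config):

import torch
import torch.nn.functional as F

logits = torch.randn(2, 19, 32, 32)  # stride-4 predictions for 19 classes (assumed)
logits = F.interpolate(logits, scale_factor=4, mode="bilinear", align_corners=False)
targets = torch.randint(0, 19, (2, 128, 128))
targets[:, :8, :8] = 255             # region to ignore (assumed ignore_value)
loss = F.cross_entropy(logits, targets, reduction="mean", ignore_index=255)
print(loss.item())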
annotator/entityseg/mask2former/modeling/pixel_decoder/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
annotator/entityseg/mask2former/modeling/pixel_decoder/fpn.py ADDED
@@ -0,0 +1,312 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import logging
3
+ import numpy as np
4
+ from typing import Callable, Dict, List, Optional, Tuple, Union
5
+
6
+ import fvcore.nn.weight_init as weight_init
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+ from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_
11
+ from torch.cuda.amp import autocast
12
+
13
+ from detectron2.config import configurable
14
+ from detectron2.layers import Conv2d, DeformConv, ShapeSpec, get_norm
15
+ from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
16
+
17
+ from ..transformer_decoder.position_encoding import PositionEmbeddingSine
18
+ from ..transformer_decoder.transformer import TransformerEncoder, TransformerEncoderLayer, _get_clones, _get_activation_fn
19
+
20
+
21
+ def build_pixel_decoder(cfg, input_shape):
22
+ """
23
+ Build a pixel decoder from `cfg.MODEL.MASK_FORMER.PIXEL_DECODER_NAME`.
24
+ """
25
+ name = cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME
26
+ model = SEM_SEG_HEADS_REGISTRY.get(name)(cfg, input_shape)
27
+ forward_features = getattr(model, "forward_features", None)
28
+ if not callable(forward_features):
29
+ raise ValueError(
30
+ "Only SEM_SEG_HEADS with forward_features method can be used as pixel decoder. "
31
+ f"Please implement forward_features for {name} to only return mask features."
32
+ )
33
+ return model
34
+
35
+
36
+ # This is a modified FPN decoder.
37
+ @SEM_SEG_HEADS_REGISTRY.register()
38
+ class BasePixelDecoder(nn.Module):
39
+ @configurable
40
+ def __init__(
41
+ self,
42
+ input_shape: Dict[str, ShapeSpec],
43
+ *,
44
+ conv_dim: int,
45
+ mask_dim: int,
46
+ norm: Optional[Union[str, Callable]] = None,
47
+ ):
48
+ """
49
+ NOTE: this interface is experimental.
50
+ Args:
51
+ input_shape: shapes (channels and stride) of the input features
52
+ conv_dims: number of output channels for the intermediate conv layers.
53
+ mask_dim: number of output channels for the final conv layer.
54
+ norm (str or callable): normalization for all conv layers
55
+ """
56
+ super().__init__()
57
+
58
+ input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
59
+ self.in_features = [k for k, v in input_shape] # starting from "res2" to "res5"
60
+ feature_channels = [v.channels for k, v in input_shape]
61
+
62
+ lateral_convs = []
63
+ output_convs = []
64
+
65
+ use_bias = norm == ""
66
+ for idx, in_channels in enumerate(feature_channels):
67
+ if idx == len(self.in_features) - 1:
68
+ output_norm = get_norm(norm, conv_dim)
69
+ output_conv = Conv2d(
70
+ in_channels,
71
+ conv_dim,
72
+ kernel_size=3,
73
+ stride=1,
74
+ padding=1,
75
+ bias=use_bias,
76
+ norm=output_norm,
77
+ activation=F.relu,
78
+ )
79
+ weight_init.c2_xavier_fill(output_conv)
80
+ self.add_module("layer_{}".format(idx + 1), output_conv)
81
+
82
+ lateral_convs.append(None)
83
+ output_convs.append(output_conv)
84
+ else:
85
+ lateral_norm = get_norm(norm, conv_dim)
86
+ output_norm = get_norm(norm, conv_dim)
87
+
88
+ lateral_conv = Conv2d(
89
+ in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm
90
+ )
91
+ output_conv = Conv2d(
92
+ conv_dim,
93
+ conv_dim,
94
+ kernel_size=3,
95
+ stride=1,
96
+ padding=1,
97
+ bias=use_bias,
98
+ norm=output_norm,
99
+ activation=F.relu,
100
+ )
101
+ weight_init.c2_xavier_fill(lateral_conv)
102
+ weight_init.c2_xavier_fill(output_conv)
103
+ self.add_module("adapter_{}".format(idx + 1), lateral_conv)
104
+ self.add_module("layer_{}".format(idx + 1), output_conv)
105
+
106
+ lateral_convs.append(lateral_conv)
107
+ output_convs.append(output_conv)
108
+ # Place convs into top-down order (from low to high resolution)
109
+ # to make the top-down computation in forward clearer.
110
+ self.lateral_convs = lateral_convs[::-1]
111
+ self.output_convs = output_convs[::-1]
112
+
113
+ self.mask_dim = mask_dim
114
+ self.mask_features = Conv2d(
115
+ conv_dim,
116
+ mask_dim,
117
+ kernel_size=3,
118
+ stride=1,
119
+ padding=1,
120
+ )
121
+ weight_init.c2_xavier_fill(self.mask_features)
122
+
123
+ self.maskformer_num_feature_levels = 3 # always use 3 scales
124
+
125
+ @classmethod
126
+ def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
127
+ ret = {}
128
+ ret["input_shape"] = {
129
+ k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
130
+ }
131
+ ret["conv_dim"] = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
132
+ ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
133
+ ret["norm"] = cfg.MODEL.SEM_SEG_HEAD.NORM
134
+ return ret
135
+
136
+ def forward_features(self, features):
137
+ multi_scale_features = []
138
+ num_cur_levels = 0
139
+ # Reverse feature maps into top-down order (from low to high resolution)
140
+ for idx, f in enumerate(self.in_features[::-1]):
141
+ x = features[f]
142
+ lateral_conv = self.lateral_convs[idx]
143
+ output_conv = self.output_convs[idx]
144
+ if lateral_conv is None:
145
+ y = output_conv(x)
146
+ else:
147
+ cur_fpn = lateral_conv(x)
148
+ # Following FPN implementation, we use nearest upsampling here
149
+ y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest")
150
+ y = output_conv(y)
151
+ if num_cur_levels < self.maskformer_num_feature_levels:
152
+ multi_scale_features.append(y)
153
+ num_cur_levels += 1
154
+ return self.mask_features(y), None, multi_scale_features
155
+
156
+ def forward(self, features, targets=None):
157
+ logger = logging.getLogger(__name__)
158
+ logger.warning("Calling forward() may cause unpredictable behavior of the PixelDecoder module.")
159
+ return self.forward_features(features)
160
+
161
+
162
+ class TransformerEncoderOnly(nn.Module):
163
+ def __init__(
164
+ self,
165
+ d_model=512,
166
+ nhead=8,
167
+ num_encoder_layers=6,
168
+ dim_feedforward=2048,
169
+ dropout=0.1,
170
+ activation="relu",
171
+ normalize_before=False,
172
+ ):
173
+ super().__init__()
174
+
175
+ encoder_layer = TransformerEncoderLayer(
176
+ d_model, nhead, dim_feedforward, dropout, activation, normalize_before
177
+ )
178
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
179
+ self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
180
+
181
+ self._reset_parameters()
182
+
183
+ self.d_model = d_model
184
+ self.nhead = nhead
185
+
186
+ def _reset_parameters(self):
187
+ for p in self.parameters():
188
+ if p.dim() > 1:
189
+ nn.init.xavier_uniform_(p)
190
+
191
+ def forward(self, src, mask, pos_embed):
192
+ # flatten NxCxHxW to HWxNxC
193
+ bs, c, h, w = src.shape
194
+ src = src.flatten(2).permute(2, 0, 1)
195
+ pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
196
+ if mask is not None:
197
+ mask = mask.flatten(1)
198
+
199
+ memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
200
+ return memory.permute(1, 2, 0).view(bs, c, h, w)
201
+
202
+
203
+ # This is a modified FPN decoder with extra Transformer encoder that processes the lowest-resolution feature map.
204
+ @SEM_SEG_HEADS_REGISTRY.register()
205
+ class TransformerEncoderPixelDecoder(BasePixelDecoder):
206
+ @configurable
207
+ def __init__(
208
+ self,
209
+ input_shape: Dict[str, ShapeSpec],
210
+ *,
211
+ transformer_dropout: float,
212
+ transformer_nheads: int,
213
+ transformer_dim_feedforward: int,
214
+ transformer_enc_layers: int,
215
+ transformer_pre_norm: bool,
216
+ conv_dim: int,
217
+ mask_dim: int,
218
+ norm: Optional[Union[str, Callable]] = None,
219
+ ):
220
+ """
221
+ NOTE: this interface is experimental.
222
+ Args:
223
+ input_shape: shapes (channels and stride) of the input features
224
+ transformer_dropout: dropout probability in transformer
225
+ transformer_nheads: number of heads in transformer
226
+ transformer_dim_feedforward: dimension of feedforward network
227
+ transformer_enc_layers: number of transformer encoder layers
228
+ transformer_pre_norm: whether to use pre-layernorm or not
229
+ conv_dims: number of output channels for the intermediate conv layers.
230
+ mask_dim: number of output channels for the final conv layer.
231
+ norm (str or callable): normalization for all conv layers
232
+ """
233
+ super().__init__(input_shape, conv_dim=conv_dim, mask_dim=mask_dim, norm=norm)
234
+
235
+ input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
236
+ self.in_features = [k for k, v in input_shape] # starting from "res2" to "res5"
237
+ feature_strides = [v.stride for k, v in input_shape]
238
+ feature_channels = [v.channels for k, v in input_shape]
239
+
240
+ in_channels = feature_channels[len(self.in_features) - 1]
241
+ self.input_proj = Conv2d(in_channels, conv_dim, kernel_size=1)
242
+ weight_init.c2_xavier_fill(self.input_proj)
243
+ self.transformer = TransformerEncoderOnly(
244
+ d_model=conv_dim,
245
+ dropout=transformer_dropout,
246
+ nhead=transformer_nheads,
247
+ dim_feedforward=transformer_dim_feedforward,
248
+ num_encoder_layers=transformer_enc_layers,
249
+ normalize_before=transformer_pre_norm,
250
+ )
251
+ N_steps = conv_dim // 2
252
+ self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
253
+
254
+ # update layer
255
+ use_bias = norm == ""
256
+ output_norm = get_norm(norm, conv_dim)
257
+ output_conv = Conv2d(
258
+ conv_dim,
259
+ conv_dim,
260
+ kernel_size=3,
261
+ stride=1,
262
+ padding=1,
263
+ bias=use_bias,
264
+ norm=output_norm,
265
+ activation=F.relu,
266
+ )
267
+ weight_init.c2_xavier_fill(output_conv)
268
+ delattr(self, "layer_{}".format(len(self.in_features)))
269
+ self.add_module("layer_{}".format(len(self.in_features)), output_conv)
270
+ self.output_convs[0] = output_conv
271
+
272
+ @classmethod
273
+ def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
274
+ ret = super().from_config(cfg, input_shape)
275
+ ret["transformer_dropout"] = cfg.MODEL.MASK_FORMER.DROPOUT
276
+ ret["transformer_nheads"] = cfg.MODEL.MASK_FORMER.NHEADS
277
+ ret["transformer_dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD
278
+ ret[
279
+ "transformer_enc_layers"
280
+ ] = cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS # a separate config
281
+ ret["transformer_pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM
282
+ return ret
283
+
284
+ def forward_features(self, features):
285
+ multi_scale_features = []
286
+ num_cur_levels = 0
287
+ # Reverse feature maps into top-down order (from low to high resolution)
288
+ for idx, f in enumerate(self.in_features[::-1]):
289
+ x = features[f]
290
+ lateral_conv = self.lateral_convs[idx]
291
+ output_conv = self.output_convs[idx]
292
+ if lateral_conv is None:
293
+ transformer = self.input_proj(x)
294
+ pos = self.pe_layer(x)
295
+ transformer = self.transformer(transformer, None, pos)
296
+ y = output_conv(transformer)
297
+ # save intermediate feature as input to Transformer decoder
298
+ transformer_encoder_features = transformer
299
+ else:
300
+ cur_fpn = lateral_conv(x)
301
+ # Following FPN implementation, we use nearest upsampling here
302
+ y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest")
303
+ y = output_conv(y)
304
+ if num_cur_levels < self.maskformer_num_feature_levels:
305
+ multi_scale_features.append(y)
306
+ num_cur_levels += 1
307
+ return self.mask_features(y), transformer_encoder_features, multi_scale_features
308
+
309
+ def forward(self, features, targets=None):
310
+ logger = logging.getLogger(__name__)
311
+ logger.warning("Calling forward() may cause unpredictable behavior of the PixelDecoder module.")
312
+ return self.forward_features(features)
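The top-down path in BasePixelDecoder is a standard FPN merge: a 1x1 lateral conv, nearest-neighbour upsampling of the coarser map, then a 3x3 output conv. The sketch below roughly mirrors that loop with plain nn.Conv2d layers and assumed backbone shapes (the real code skips the lateral conv at the coarsest level and uses detectron2's Conv2d with norm/activation):

import torch
import torch.nn.functional as F
from torch import nn

conv_dim = 256
feats = {  # assumed backbone outputs, coarse to fine
    "res5": torch.randn(1, 2048, 8, 8),
    "res4": torch.randn(1, 1024, 16, 16),
    "res3": torch.randn(1, 512, 32, 32),
}
lateral = {k: nn.Conv2d(v.shape[1], conv_dim, kernel_size=1) for k, v in feats.items()}
output = {k: nn.Conv2d(conv_dim, conv_dim, kernel_size=3, padding=1) for k in feats}

y = None
for name in ["res5", "res4", "res3"]:  # low to high resolution
    cur = lateral[name](feats[name])
    if y is not None:
        cur = cur + F.interpolate(y, size=cur.shape[-2:], mode="nearest")
    y = output[name](cur)
print(y.shape)  # torch.Size([1, 256, 32, 32])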
annotator/entityseg/mask2former/modeling/pixel_decoder/msdeformattn.py ADDED
@@ -0,0 +1,358 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import logging
3
+ import numpy as np
4
+ from typing import Callable, Dict, List, Optional, Tuple, Union
5
+
6
+ import fvcore.nn.weight_init as weight_init
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+ from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_
11
+ from torch.cuda.amp import autocast
12
+
13
+ from detectron2.config import configurable
14
+ from detectron2.layers import Conv2d, ShapeSpec, get_norm
15
+ from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
16
+
17
+ from ..transformer_decoder.position_encoding import PositionEmbeddingSine
18
+ from ..transformer_decoder.transformer import _get_clones, _get_activation_fn
19
+ from .ops.modules import MSDeformAttn
20
+
21
+
22
+ # MSDeformAttn Transformer encoder in deformable detr
23
+ class MSDeformAttnTransformerEncoderOnly(nn.Module):
24
+ def __init__(self, d_model=256, nhead=8,
25
+ num_encoder_layers=6, dim_feedforward=1024, dropout=0.1,
26
+ activation="relu",
27
+ num_feature_levels=4, enc_n_points=4,
28
+ ):
29
+ super().__init__()
30
+
31
+ self.d_model = d_model
32
+ self.nhead = nhead
33
+
34
+ encoder_layer = MSDeformAttnTransformerEncoderLayer(d_model, dim_feedforward,
35
+ dropout, activation,
36
+ num_feature_levels, nhead, enc_n_points)
37
+ self.encoder = MSDeformAttnTransformerEncoder(encoder_layer, num_encoder_layers)
38
+
39
+ self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
40
+
41
+ self._reset_parameters()
42
+
43
+ def _reset_parameters(self):
44
+ for p in self.parameters():
45
+ if p.dim() > 1:
46
+ nn.init.xavier_uniform_(p)
47
+ for m in self.modules():
48
+ if isinstance(m, MSDeformAttn):
49
+ m._reset_parameters()
50
+ normal_(self.level_embed)
51
+
52
+ def get_valid_ratio(self, mask):
53
+ _, H, W = mask.shape
54
+ valid_H = torch.sum(~mask[:, :, 0], 1)
55
+ valid_W = torch.sum(~mask[:, 0, :], 1)
56
+ valid_ratio_h = valid_H.float() / H
57
+ valid_ratio_w = valid_W.float() / W
58
+ valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
59
+ return valid_ratio
60
+
61
+ def forward(self, srcs, pos_embeds):
62
+ masks = [torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) for x in srcs]
63
+ # prepare input for encoder
64
+ src_flatten = []
65
+ mask_flatten = []
66
+ lvl_pos_embed_flatten = []
67
+ spatial_shapes = []
68
+ for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
69
+ bs, c, h, w = src.shape
70
+ spatial_shape = (h, w)
71
+ spatial_shapes.append(spatial_shape)
72
+ src = src.flatten(2).transpose(1, 2)
73
+ mask = mask.flatten(1)
74
+ pos_embed = pos_embed.flatten(2).transpose(1, 2)
75
+ lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
76
+ lvl_pos_embed_flatten.append(lvl_pos_embed)
77
+ src_flatten.append(src)
78
+ mask_flatten.append(mask)
79
+ src_flatten = torch.cat(src_flatten, 1)
80
+ mask_flatten = torch.cat(mask_flatten, 1)
81
+ lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
82
+ spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
83
+ level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
84
+ valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
85
+
86
+ # encoder
87
+ memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten)
88
+
89
+ return memory, spatial_shapes, level_start_index
90
+
91
+
92
+ class MSDeformAttnTransformerEncoderLayer(nn.Module):
93
+ def __init__(self,
94
+ d_model=256, d_ffn=1024,
95
+ dropout=0.1, activation="relu",
96
+ n_levels=4, n_heads=8, n_points=4):
97
+ super().__init__()
98
+
99
+ # self attention
100
+ self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
101
+ self.dropout1 = nn.Dropout(dropout)
102
+ self.norm1 = nn.LayerNorm(d_model)
103
+
104
+ # ffn
105
+ self.linear1 = nn.Linear(d_model, d_ffn)
106
+ self.activation = _get_activation_fn(activation)
107
+ self.dropout2 = nn.Dropout(dropout)
108
+ self.linear2 = nn.Linear(d_ffn, d_model)
109
+ self.dropout3 = nn.Dropout(dropout)
110
+ self.norm2 = nn.LayerNorm(d_model)
111
+
112
+ @staticmethod
113
+ def with_pos_embed(tensor, pos):
114
+ return tensor if pos is None else tensor + pos
115
+
116
+ def forward_ffn(self, src):
117
+ src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
118
+ src = src + self.dropout3(src2)
119
+ src = self.norm2(src)
120
+ return src
121
+
122
+ def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None):
123
+ # self attention
124
+ src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index, padding_mask)
125
+ src = src + self.dropout1(src2)
126
+ src = self.norm1(src)
127
+
128
+ # ffn
129
+ src = self.forward_ffn(src)
130
+
131
+ return src
132
+
133
+
134
+ class MSDeformAttnTransformerEncoder(nn.Module):
135
+ def __init__(self, encoder_layer, num_layers):
136
+ super().__init__()
137
+ self.layers = _get_clones(encoder_layer, num_layers)
138
+ self.num_layers = num_layers
139
+
140
+ @staticmethod
141
+ def get_reference_points(spatial_shapes, valid_ratios, device):
142
+ reference_points_list = []
143
+ for lvl, (H_, W_) in enumerate(spatial_shapes):
144
+
145
+ ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
146
+ torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))
147
+ ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
148
+ ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
149
+ ref = torch.stack((ref_x, ref_y), -1)
150
+ reference_points_list.append(ref)
151
+ reference_points = torch.cat(reference_points_list, 1)
152
+ reference_points = reference_points[:, :, None] * valid_ratios[:, None]
153
+ return reference_points
154
+
155
+ def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None):
156
+ output = src
157
+ reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)
158
+ for _, layer in enumerate(self.layers):
159
+ output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask)
160
+
161
+ return output
162
+
163
+
164
+ @SEM_SEG_HEADS_REGISTRY.register()
165
+ class MSDeformAttnPixelDecoder(nn.Module):
166
+ @configurable
167
+ def __init__(
168
+ self,
169
+ input_shape: Dict[str, ShapeSpec],
170
+ *,
171
+ transformer_dropout: float,
172
+ transformer_nheads: int,
173
+ transformer_dim_feedforward: int,
174
+ transformer_enc_layers: int,
175
+ conv_dim: int,
176
+ mask_dim: int,
177
+ norm: Optional[Union[str, Callable]] = None,
178
+ # deformable transformer encoder args
179
+ transformer_in_features: List[str],
180
+ common_stride: int,
181
+ ):
182
+ """
183
+ NOTE: this interface is experimental.
184
+ Args:
185
+ input_shape: shapes (channels and stride) of the input features
186
+ transformer_dropout: dropout probability in transformer
187
+ transformer_nheads: number of heads in transformer
188
+ transformer_dim_feedforward: dimension of feedforward network
189
+ transformer_enc_layers: number of transformer encoder layers
190
+ conv_dims: number of output channels for the intermediate conv layers.
191
+ mask_dim: number of output channels for the final conv layer.
192
+ norm (str or callable): normalization for all conv layers
193
+ """
194
+ super().__init__()
195
+ transformer_input_shape = {
196
+ k: v for k, v in input_shape.items() if k in transformer_in_features
197
+ }
198
+
199
+ # this is the input shape of pixel decoder
200
+ input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
201
+ self.in_features = [k for k, v in input_shape] # starting from "res2" to "res5"
202
+ self.feature_strides = [v.stride for k, v in input_shape]
203
+ self.feature_channels = [v.channels for k, v in input_shape]
204
+
205
+ # this is the input shape of the transformer encoder (it may use fewer features than the pixel decoder)
206
+ transformer_input_shape = sorted(transformer_input_shape.items(), key=lambda x: x[1].stride)
207
+ self.transformer_in_features = [k for k, v in transformer_input_shape] # starting from "res2" to "res5"
208
+ transformer_in_channels = [v.channels for k, v in transformer_input_shape]
209
+ self.transformer_feature_strides = [v.stride for k, v in transformer_input_shape] # to decide extra FPN layers
210
+
211
+ self.transformer_num_feature_levels = len(self.transformer_in_features)
212
+ if self.transformer_num_feature_levels > 1:
213
+ input_proj_list = []
214
+ # from low resolution to high resolution (res5 -> res2)
215
+ for in_channels in transformer_in_channels[::-1]:
216
+ input_proj_list.append(nn.Sequential(
217
+ nn.Conv2d(in_channels, conv_dim, kernel_size=1),
218
+ nn.GroupNorm(32, conv_dim),
219
+ ))
220
+ self.input_proj = nn.ModuleList(input_proj_list)
221
+ else:
222
+ self.input_proj = nn.ModuleList([
223
+ nn.Sequential(
224
+ nn.Conv2d(transformer_in_channels[-1], conv_dim, kernel_size=1),
225
+ nn.GroupNorm(32, conv_dim),
226
+ )])
227
+
228
+ for proj in self.input_proj:
229
+ nn.init.xavier_uniform_(proj[0].weight, gain=1)
230
+ nn.init.constant_(proj[0].bias, 0)
231
+
232
+ self.transformer = MSDeformAttnTransformerEncoderOnly(
233
+ d_model=conv_dim,
234
+ dropout=transformer_dropout,
235
+ nhead=transformer_nheads,
236
+ dim_feedforward=transformer_dim_feedforward,
237
+ num_encoder_layers=transformer_enc_layers,
238
+ num_feature_levels=self.transformer_num_feature_levels,
239
+ )
240
+ N_steps = conv_dim // 2
241
+ self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
242
+
243
+ self.mask_dim = mask_dim
244
+ # use 1x1 conv instead
245
+ self.mask_features = Conv2d(
246
+ conv_dim,
247
+ mask_dim,
248
+ kernel_size=1,
249
+ stride=1,
250
+ padding=0,
251
+ )
252
+ weight_init.c2_xavier_fill(self.mask_features)
253
+
254
+ self.maskformer_num_feature_levels = 3 # always use 3 scales
255
+ self.common_stride = common_stride
256
+
257
+ # extra fpn levels
258
+ stride = min(self.transformer_feature_strides)
259
+ self.num_fpn_levels = int(np.log2(stride) - np.log2(self.common_stride))
260
+
261
+ lateral_convs = []
262
+ output_convs = []
263
+
264
+ use_bias = norm == ""
265
+ for idx, in_channels in enumerate(self.feature_channels[:self.num_fpn_levels]):
266
+ lateral_norm = get_norm(norm, conv_dim)
267
+ output_norm = get_norm(norm, conv_dim)
268
+
269
+ lateral_conv = Conv2d(
270
+ in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm
271
+ )
272
+ output_conv = Conv2d(
273
+ conv_dim,
274
+ conv_dim,
275
+ kernel_size=3,
276
+ stride=1,
277
+ padding=1,
278
+ bias=use_bias,
279
+ norm=output_norm,
280
+ activation=F.relu,
281
+ )
282
+ weight_init.c2_xavier_fill(lateral_conv)
283
+ weight_init.c2_xavier_fill(output_conv)
284
+ self.add_module("adapter_{}".format(idx + 1), lateral_conv)
285
+ self.add_module("layer_{}".format(idx + 1), output_conv)
286
+
287
+ lateral_convs.append(lateral_conv)
288
+ output_convs.append(output_conv)
289
+ # Place convs into top-down order (from low to high resolution)
290
+ # to make the top-down computation in forward clearer.
291
+ self.lateral_convs = lateral_convs[::-1]
292
+ self.output_convs = output_convs[::-1]
293
+
294
+ @classmethod
295
+ def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
296
+ ret = {}
297
+ ret["input_shape"] = {
298
+ k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
299
+ }
300
+ ret["conv_dim"] = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
301
+ ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
302
+ ret["norm"] = cfg.MODEL.SEM_SEG_HEAD.NORM
303
+ ret["transformer_dropout"] = cfg.MODEL.MASK_FORMER.DROPOUT
304
+ ret["transformer_nheads"] = cfg.MODEL.MASK_FORMER.NHEADS
305
+ # ret["transformer_dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD
306
+ ret["transformer_dim_feedforward"] = 1024 # use 1024 for deformable transformer encoder
307
+ ret[
308
+ "transformer_enc_layers"
309
+ ] = cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS # a separate config
310
+ ret["transformer_in_features"] = cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES
311
+ ret["common_stride"] = cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE
312
+ return ret
313
+
314
+ @autocast(enabled=False)
315
+ def forward_features(self, features):
316
+ srcs = []
317
+ pos = []
318
+ # Reverse feature maps into top-down order (from low to high resolution)
319
+ for idx, f in enumerate(self.transformer_in_features[::-1]):
320
+ x = features[f].float() # deformable detr does not support half precision
321
+ srcs.append(self.input_proj[idx](x))
322
+ pos.append(self.pe_layer(x))
323
+
324
+ y, spatial_shapes, level_start_index = self.transformer(srcs, pos)
325
+ bs = y.shape[0]
326
+
327
+ split_size_or_sections = [None] * self.transformer_num_feature_levels
328
+ for i in range(self.transformer_num_feature_levels):
329
+ if i < self.transformer_num_feature_levels - 1:
330
+ split_size_or_sections[i] = level_start_index[i + 1] - level_start_index[i]
331
+ else:
332
+ split_size_or_sections[i] = y.shape[1] - level_start_index[i]
333
+ y = torch.split(y, split_size_or_sections, dim=1)
334
+
335
+ out = []
336
+ multi_scale_features = []
337
+ num_cur_levels = 0
338
+ for i, z in enumerate(y):
339
+ out.append(z.transpose(1, 2).view(bs, -1, spatial_shapes[i][0], spatial_shapes[i][1]))
340
+
341
+ # append `out` with extra FPN levels
342
+ # Reverse feature maps into top-down order (from low to high resolution)
343
+ for idx, f in enumerate(self.in_features[:self.num_fpn_levels][::-1]):
344
+ x = features[f].float()
345
+ lateral_conv = self.lateral_convs[idx]
346
+ output_conv = self.output_convs[idx]
347
+ cur_fpn = lateral_conv(x)
348
+ # Following FPN implementation, we use nearest upsampling here
349
+ y = cur_fpn + F.interpolate(out[-1], size=cur_fpn.shape[-2:], mode="bilinear", align_corners=False)
350
+ y = output_conv(y)
351
+ out.append(y)
352
+
353
+ for o in out:
354
+ if num_cur_levels < self.maskformer_num_feature_levels:
355
+ multi_scale_features.append(o)
356
+ num_cur_levels += 1
357
+
358
+ return self.mask_features(out[-1]), out[0], multi_scale_features
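forward_features flattens all transformer levels into one token sequence and later splits it back per level using spatial_shapes and level_start_index. A small, self-contained sketch with assumed spatial shapes showing how those indices are built and used:

import torch

spatial_shapes = torch.as_tensor([[8, 8], [16, 16], [32, 32]])  # assumed (H, W) per level
level_start_index = torch.cat(
    (spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])
)
print(level_start_index)  # tensor([  0,  64, 320])

bs, c = 2, 256
tokens = torch.randn(bs, int(spatial_shapes.prod(1).sum()), c)  # [bs, sum(H*W), C]
sizes = spatial_shapes.prod(1).tolist()
maps = [
    t.transpose(1, 2).reshape(bs, c, h, w)
    for t, (h, w) in zip(torch.split(tokens, sizes, dim=1), spatial_shapes.tolist())
]
print([m.shape for m in maps])  # per-level feature maps recovered from the flat sequence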
annotator/entityseg/mask2former/modeling/pixel_decoder/ops/functions/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ # Copyright (c) Facebook, Inc. and its affiliates.
10
+ # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
11
+
12
+ from .ms_deform_attn_func import MSDeformAttnFunction
13
+
annotator/entityseg/mask2former/modeling/pixel_decoder/ops/functions/ms_deform_attn_func.py ADDED
@@ -0,0 +1,72 @@
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ # Copyright (c) Facebook, Inc. and its affiliates.
10
+ # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
11
+
12
+ from __future__ import absolute_import
13
+ from __future__ import print_function
14
+ from __future__ import division
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch.autograd import Function
19
+ from torch.autograd.function import once_differentiable
20
+
21
+ try:
22
+ import MultiScaleDeformableAttention as MSDA
23
+ except ModuleNotFoundError as e:
24
+ info_string = (
25
+ "\n\nPlease compile MultiScaleDeformableAttention CUDA op with the following commands:\n"
26
+ "\t`cd mask2former/modeling/pixel_decoder/ops`\n"
27
+ "\t`sh make.sh`\n"
28
+ )
29
+ raise ModuleNotFoundError(info_string)
30
+
31
+
32
+ class MSDeformAttnFunction(Function):
33
+ @staticmethod
34
+ def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step):
35
+ ctx.im2col_step = im2col_step
36
+ output = MSDA.ms_deform_attn_forward(
37
+ value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step)
38
+ ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights)
39
+ return output
40
+
41
+ @staticmethod
42
+ @once_differentiable
43
+ def backward(ctx, grad_output):
44
+ value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
45
+ grad_value, grad_sampling_loc, grad_attn_weight = \
46
+ MSDA.ms_deform_attn_backward(
47
+ value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step)
48
+
49
+ return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
50
+
51
+
52
+ def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
53
+ # for debug and test only,
54
+ # need to use cuda version instead
55
+ N_, S_, M_, D_ = value.shape
56
+ _, Lq_, M_, L_, P_, _ = sampling_locations.shape
57
+ value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
58
+ sampling_grids = 2 * sampling_locations - 1
59
+ sampling_value_list = []
60
+ for lid_, (H_, W_) in enumerate(value_spatial_shapes):
61
+ # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
62
+ value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)
63
+ # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
64
+ sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
65
+ # N_*M_, D_, Lq_, P_
66
+ sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
67
+ mode='bilinear', padding_mode='zeros', align_corners=False)
68
+ sampling_value_list.append(sampling_value_l_)
69
+ # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
70
+ attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
71
+ output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
72
+ return output.transpose(1, 2).contiguous()
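The pure-PyTorch fallback above is shape-compatible with the CUDA op. Below is a minimal sketch (not part of the diff) that exercises it with dummy tensors; the import path is an assumption, and the compiled MultiScaleDeformableAttention extension must already be installed, since the try/except at the top of ms_deform_attn_func.py re-raises ModuleNotFoundError otherwise.

import torch
from annotator.entityseg.mask2former.modeling.pixel_decoder.ops.functions.ms_deform_attn_func import (
    ms_deform_attn_core_pytorch,
)

# Dummy multi-scale inputs: N batch, two feature levels, M heads, D channels per head.
N, M, D, Lq, P = 2, 8, 32, 100, 4
shapes = torch.as_tensor([(16, 16), (8, 8)], dtype=torch.long)
S = int((shapes[:, 0] * shapes[:, 1]).sum())
L = shapes.size(0)

value = torch.rand(N, S, M, D)
sampling_locations = torch.rand(N, Lq, M, L, P, 2)   # normalized (x, y) in [0, 1]
attention_weights = torch.softmax(torch.rand(N, Lq, M, L * P), -1).view(N, Lq, M, L, P)

out = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights)
assert out.shape == (N, Lq, M * D)                   # (2, 100, 256)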
annotator/entityseg/mask2former/modeling/pixel_decoder/ops/make.sh ADDED
@@ -0,0 +1,13 @@
1
+ #!/usr/bin/env bash
2
+ # ------------------------------------------------------------------------------------------------
3
+ # Deformable DETR
4
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------------------------------
7
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ # ------------------------------------------------------------------------------------------------
9
+
10
+ # Copyright (c) Facebook, Inc. and its affiliates.
11
+ # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
12
+
13
+ python setup.py build install
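make.sh simply invokes setup.py in this directory. A hedged post-build check (illustrative, not part of the diff) to confirm the compiled op is importable before relying on the CUDA path:

# Run after `sh make.sh`; the attribute name below is assumed from the extension's bindings.
try:
    import MultiScaleDeformableAttention as MSDA
    print("compiled op found:", hasattr(MSDA, "ms_deform_attn_forward"))
except ImportError:
    print("compiled op missing; MSDeformAttn will fall back to the pure-PyTorch path")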
annotator/entityseg/mask2former/modeling/pixel_decoder/ops/modules/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ # Copyright (c) Facebook, Inc. and its affiliates.
10
+ # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
11
+
12
+ from .ms_deform_attn import MSDeformAttn
annotator/entityseg/mask2former/modeling/pixel_decoder/ops/modules/ms_deform_attn.py ADDED
@@ -0,0 +1,125 @@
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ # Copyright (c) Facebook, Inc. and its affiliates.
10
+ # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
11
+
12
+ from __future__ import absolute_import
13
+ from __future__ import print_function
14
+ from __future__ import division
15
+
16
+ import warnings
17
+ import math
18
+
19
+ import torch
20
+ from torch import nn
21
+ import torch.nn.functional as F
22
+ from torch.nn.init import xavier_uniform_, constant_
23
+
24
+ from ..functions import MSDeformAttnFunction
25
+ from ..functions.ms_deform_attn_func import ms_deform_attn_core_pytorch
26
+
27
+
28
+ def _is_power_of_2(n):
29
+ if (not isinstance(n, int)) or (n < 0):
30
+ raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
31
+ return (n & (n-1) == 0) and n != 0
32
+
33
+
34
+ class MSDeformAttn(nn.Module):
35
+ def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
36
+ """
37
+ Multi-Scale Deformable Attention Module
38
+ :param d_model hidden dimension
39
+ :param n_levels number of feature levels
40
+ :param n_heads number of attention heads
41
+ :param n_points number of sampling points per attention head per feature level
42
+ """
43
+ super().__init__()
44
+ if d_model % n_heads != 0:
45
+ raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
46
+ _d_per_head = d_model // n_heads
47
+ # _d_per_head should be a power of 2, which is more efficient in our CUDA implementation
48
+ if not _is_power_of_2(_d_per_head):
49
+ warnings.warn("It is recommended to set d_model in MSDeformAttn so that the dimension of each attention head is a power of 2, "
50
+ "which is more efficient in our CUDA implementation.")
51
+
52
+ self.im2col_step = 128
53
+
54
+ self.d_model = d_model
55
+ self.n_levels = n_levels
56
+ self.n_heads = n_heads
57
+ self.n_points = n_points
58
+
59
+ self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
60
+ self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
61
+ self.value_proj = nn.Linear(d_model, d_model)
62
+ self.output_proj = nn.Linear(d_model, d_model)
63
+
64
+ self._reset_parameters()
65
+
66
+ def _reset_parameters(self):
67
+ constant_(self.sampling_offsets.weight.data, 0.)
68
+ thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
69
+ grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
70
+ grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)
71
+ for i in range(self.n_points):
72
+ grid_init[:, :, i, :] *= i + 1
73
+ with torch.no_grad():
74
+ self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
75
+ constant_(self.attention_weights.weight.data, 0.)
76
+ constant_(self.attention_weights.bias.data, 0.)
77
+ xavier_uniform_(self.value_proj.weight.data)
78
+ constant_(self.value_proj.bias.data, 0.)
79
+ xavier_uniform_(self.output_proj.weight.data)
80
+ constant_(self.output_proj.bias.data, 0.)
81
+
82
+ def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
83
+ """
84
+ :param query (N, Length_{query}, C)
85
+ :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
86
+ or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
87
+ :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
88
+ :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
89
+ :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
90
+ :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements
91
+
92
+ :return output (N, Length_{query}, C)
93
+ """
94
+ N, Len_q, _ = query.shape
95
+ N, Len_in, _ = input_flatten.shape
96
+ assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in
97
+
98
+ value = self.value_proj(input_flatten)
99
+ if input_padding_mask is not None:
100
+ value = value.masked_fill(input_padding_mask[..., None], float(0))
101
+ value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
102
+ sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
103
+ attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
104
+ attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
105
+ # N, Len_q, n_heads, n_levels, n_points, 2
106
+ if reference_points.shape[-1] == 2:
107
+ offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
108
+ sampling_locations = reference_points[:, :, None, :, None, :] \
109
+ + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
110
+ elif reference_points.shape[-1] == 4:
111
+ sampling_locations = reference_points[:, :, None, :, None, :2] \
112
+ + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
113
+ else:
114
+ raise ValueError(
115
+ 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1]))
116
+ try:
117
+ output = MSDeformAttnFunction.apply(
118
+ value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
119
+ except Exception:
120
+ # Fall back to the pure-PyTorch reference (e.g. CPU tensors or no compiled CUDA op)
121
+ output = ms_deform_attn_core_pytorch(value, input_spatial_shapes, sampling_locations, attention_weights)
122
+ # # For FLOPs calculation only
123
+ # output = ms_deform_attn_core_pytorch(value, input_spatial_shapes, sampling_locations, attention_weights)
124
+ output = self.output_proj(output)
125
+ return output
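A short usage sketch for MSDeformAttn (illustrative only; the import path and tensor sizes are assumptions). Once the compiled op is importable, the try/except in forward lets this also run on CPU by falling back to the pure-PyTorch reference.

import torch
from annotator.entityseg.mask2former.modeling.pixel_decoder.ops.modules import MSDeformAttn

# Four feature levels, flattened and concatenated along the length dimension.
shapes = torch.as_tensor([(32, 32), (16, 16), (8, 8), (4, 4)], dtype=torch.long)
level_start_index = torch.cat((shapes.new_zeros(1), (shapes[:, 0] * shapes[:, 1]).cumsum(0)[:-1]))
N, Lq, d_model = 2, 300, 256
Len_in = int((shapes[:, 0] * shapes[:, 1]).sum())

attn = MSDeformAttn(d_model=d_model, n_levels=shapes.size(0), n_heads=8, n_points=4)
query = torch.rand(N, Lq, d_model)
reference_points = torch.rand(N, Lq, shapes.size(0), 2)   # normalized (x, y) per level
memory = torch.rand(N, Len_in, d_model)

out = attn(query, reference_points, memory, shapes, level_start_index)
assert out.shape == (N, Lq, d_model)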
annotator/entityseg/mask2former/modeling/pixel_decoder/ops/setup.py ADDED
@@ -0,0 +1,78 @@
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ # Copyright (c) Facebook, Inc. and its affiliates.
10
+ # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
11
+
12
+ import os
13
+ import glob
14
+
15
+ import torch
16
+
17
+ from torch.utils.cpp_extension import CUDA_HOME
18
+ from torch.utils.cpp_extension import CppExtension
19
+ from torch.utils.cpp_extension import CUDAExtension
20
+
21
+ from setuptools import find_packages
22
+ from setuptools import setup
23
+
24
+ requirements = ["torch", "torchvision"]
25
+
26
+ def get_extensions():
27
+ this_dir = os.path.dirname(os.path.abspath(__file__))
28
+ extensions_dir = os.path.join(this_dir, "src")
29
+
30
+ main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
31
+ source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
32
+ source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
33
+
34
+ sources = main_file + source_cpu
35
+ extension = CppExtension
36
+ extra_compile_args = {"cxx": []}
37
+ define_macros = []
38
+
39
+ # Allow forcing a CUDA build via FORCE_CUDA, since torch.cuda.is_available() only checks for a device and may be False on build machines without a GPU.
40
+ if (os.environ.get('FORCE_CUDA') or torch.cuda.is_available()) and CUDA_HOME is not None:
41
+ extension = CUDAExtension
42
+ sources += source_cuda
43
+ define_macros += [("WITH_CUDA", None)]
44
+ extra_compile_args["nvcc"] = [
45
+ "-DCUDA_HAS_FP16=1",
46
+ "-D__CUDA_NO_HALF_OPERATORS__",
47
+ "-D__CUDA_NO_HALF_CONVERSIONS__",
48
+ "-D__CUDA_NO_HALF2_OPERATORS__",
49
+ ]
50
+ else:
51
+ if CUDA_HOME is None:
52
+ raise NotImplementedError('CUDA_HOME is None. Please set the CUDA_HOME environment variable.')
53
+ else:
54
+ raise NotImplementedError('No CUDA runtime was found. Set FORCE_CUDA=1 to force a CUDA build, or verify torch.cuda.is_available().')
55
+
56
+ sources = [os.path.join(extensions_dir, s) for s in sources]
57
+ include_dirs = [extensions_dir]
58
+ ext_modules = [
59
+ extension(
60
+ "MultiScaleDeformableAttention",
61
+ sources,
62
+ include_dirs=include_dirs,
63
+ define_macros=define_macros,
64
+ extra_compile_args=extra_compile_args,
65
+ )
66
+ ]
67
+ return ext_modules
68
+
69
+ setup(
70
+ name="MultiScaleDeformableAttention",
71
+ version="1.0",
72
+ author="Weijie Su",
73
+ url="https://github.com/fundamentalvision/Deformable-DETR",
74
+ description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention",
75
+ packages=find_packages(exclude=("configs", "tests",)),
76
+ ext_modules=get_extensions(),
77
+ cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
78
+ )
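A hedged preflight sketch (not part of the diff) mirroring the branch in get_extensions() above, useful for checking the environment before running python setup.py build install:

import os
import torch
from torch.utils.cpp_extension import CUDA_HOME

# Mirrors the condition in get_extensions(): CUDA sources are compiled only in this case.
if (os.environ.get("FORCE_CUDA") or torch.cuda.is_available()) and CUDA_HOME is not None:
    print("CUDA build will be used, CUDA_HOME =", CUDA_HOME)
else:
    print("CUDA build unavailable: set FORCE_CUDA=1 and/or CUDA_HOME before building")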
annotator/entityseg/mask2former/modeling/pixel_decoder/ops/src/cpu/ms_deform_attn_cpu.cpp ADDED
@@ -0,0 +1,46 @@
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ /*!
12
+ * Copyright (c) Facebook, Inc. and its affiliates.
13
+ * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
14
+ */
15
+
16
+ #include <vector>
17
+
18
+ #include <ATen/ATen.h>
19
+ #include <ATen/cuda/CUDAContext.h>
20
+
21
+
22
+ at::Tensor
23
+ ms_deform_attn_cpu_forward(
24
+ const at::Tensor &value,
25
+ const at::Tensor &spatial_shapes,
26
+ const at::Tensor &level_start_index,
27
+ const at::Tensor &sampling_loc,
28
+ const at::Tensor &attn_weight,
29
+ const int im2col_step)
30
+ {
31
+ AT_ERROR("Not implemented on CPU");
32
+ }
33
+
34
+ std::vector<at::Tensor>
35
+ ms_deform_attn_cpu_backward(
36
+ const at::Tensor &value,
37
+ const at::Tensor &spatial_shapes,
38
+ const at::Tensor &level_start_index,
39
+ const at::Tensor &sampling_loc,
40
+ const at::Tensor &attn_weight,
41
+ const at::Tensor &grad_output,
42
+ const int im2col_step)
43
+ {
44
+ AT_ERROR("Not implemented on CPU");
45
+ }
46
+
annotator/entityseg/mask2former/modeling/pixel_decoder/ops/src/cpu/ms_deform_attn_cpu.h ADDED
@@ -0,0 +1,38 @@
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ /*!
12
+ * Copyright (c) Facebook, Inc. and its affiliates.
13
+ * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
14
+ */
15
+
16
+ #pragma once
17
+ #include <torch/extension.h>
18
+
19
+ at::Tensor
20
+ ms_deform_attn_cpu_forward(
21
+ const at::Tensor &value,
22
+ const at::Tensor &spatial_shapes,
23
+ const at::Tensor &level_start_index,
24
+ const at::Tensor &sampling_loc,
25
+ const at::Tensor &attn_weight,
26
+ const int im2col_step);
27
+
28
+ std::vector<at::Tensor>
29
+ ms_deform_attn_cpu_backward(
30
+ const at::Tensor &value,
31
+ const at::Tensor &spatial_shapes,
32
+ const at::Tensor &level_start_index,
33
+ const at::Tensor &sampling_loc,
34
+ const at::Tensor &attn_weight,
35
+ const at::Tensor &grad_output,
36
+ const int im2col_step);
37
+
38
+
annotator/entityseg/mask2former/modeling/pixel_decoder/ops/src/cuda/ms_deform_attn_cuda.cu ADDED
@@ -0,0 +1,158 @@
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ /*!
12
+ * Copyright (c) Facebook, Inc. and its affiliates.
13
+ * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
14
+ */
15
+
16
+ #include <vector>
17
+ #include "cuda/ms_deform_im2col_cuda.cuh"
18
+
19
+ #include <ATen/ATen.h>
20
+ #include <ATen/cuda/CUDAContext.h>
21
+ #include <cuda.h>
22
+ #include <cuda_runtime.h>
23
+
24
+
25
+ at::Tensor ms_deform_attn_cuda_forward(
26
+ const at::Tensor &value,
27
+ const at::Tensor &spatial_shapes,
28
+ const at::Tensor &level_start_index,
29
+ const at::Tensor &sampling_loc,
30
+ const at::Tensor &attn_weight,
31
+ const int im2col_step)
32
+ {
33
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
34
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
35
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
36
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
37
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
38
+
39
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
40
+ AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
41
+ AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
42
+ AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
43
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
44
+
45
+ const int batch = value.size(0);
46
+ const int spatial_size = value.size(1);
47
+ const int num_heads = value.size(2);
48
+ const int channels = value.size(3);
49
+
50
+ const int num_levels = spatial_shapes.size(0);
51
+
52
+ const int num_query = sampling_loc.size(1);
53
+ const int num_point = sampling_loc.size(4);
54
+
55
+ const int im2col_step_ = std::min(batch, im2col_step);
56
+
57
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
58
+
59
+ auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
60
+
61
+ const int batch_n = im2col_step_;
62
+ auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
63
+ auto per_value_size = spatial_size * num_heads * channels;
64
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
65
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
66
+ for (int n = 0; n < batch/im2col_step_; ++n)
67
+ {
68
+ auto columns = output_n.select(0, n);
69
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
70
+ ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
71
+ value.data<scalar_t>() + n * im2col_step_ * per_value_size,
72
+ spatial_shapes.data<int64_t>(),
73
+ level_start_index.data<int64_t>(),
74
+ sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
75
+ attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
76
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
77
+ columns.data<scalar_t>());
78
+
79
+ }));
80
+ }
81
+
82
+ output = output.view({batch, num_query, num_heads*channels});
83
+
84
+ return output;
85
+ }
86
+
87
+
88
+ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
89
+ const at::Tensor &value,
90
+ const at::Tensor &spatial_shapes,
91
+ const at::Tensor &level_start_index,
92
+ const at::Tensor &sampling_loc,
93
+ const at::Tensor &attn_weight,
94
+ const at::Tensor &grad_output,
95
+ const int im2col_step)
96
+ {
97
+
98
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
99
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
100
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
101
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
102
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
103
+ AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
104
+
105
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
106
+ AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
107
+ AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
108
+ AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
109
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
110
+ AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
111
+
112
+ const int batch = value.size(0);
113
+ const int spatial_size = value.size(1);
114
+ const int num_heads = value.size(2);
115
+ const int channels = value.size(3);
116
+
117
+ const int num_levels = spatial_shapes.size(0);
118
+
119
+ const int num_query = sampling_loc.size(1);
120
+ const int num_point = sampling_loc.size(4);
121
+
122
+ const int im2col_step_ = std::min(batch, im2col_step);
123
+
124
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
125
+
126
+ auto grad_value = at::zeros_like(value);
127
+ auto grad_sampling_loc = at::zeros_like(sampling_loc);
128
+ auto grad_attn_weight = at::zeros_like(attn_weight);
129
+
130
+ const int batch_n = im2col_step_;
131
+ auto per_value_size = spatial_size * num_heads * channels;
132
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
133
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
134
+ auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
135
+
136
+ for (int n = 0; n < batch/im2col_step_; ++n)
137
+ {
138
+ auto grad_output_g = grad_output_n.select(0, n);
139
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
140
+ ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
141
+ grad_output_g.data<scalar_t>(),
142
+ value.data<scalar_t>() + n * im2col_step_ * per_value_size,
143
+ spatial_shapes.data<int64_t>(),
144
+ level_start_index.data<int64_t>(),
145
+ sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
146
+ attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
147
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
148
+ grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
149
+ grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
150
+ grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
151
+
152
+ }));
153
+ }
154
+
155
+ return {
156
+ grad_value, grad_sampling_loc, grad_attn_weight
157
+ };
158
+ }
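The repository also ships ops/test.py for correctness checks. A condensed, hedged sketch along the same lines (requires a GPU and the built op; import path and tensor sizes are assumptions) compares the CUDA forward above against the pure-PyTorch reference:

import torch
from annotator.entityseg.mask2former.modeling.pixel_decoder.ops.functions.ms_deform_attn_func import (
    MSDeformAttnFunction, ms_deform_attn_core_pytorch,
)

shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
level_start_index = torch.cat((shapes.new_zeros(1), (shapes[:, 0] * shapes[:, 1]).cumsum(0)[:-1]))
N, M, D, Lq, P = 1, 2, 2, 2, 2
S = int((shapes[:, 0] * shapes[:, 1]).sum())

value = torch.rand(N, S, M, D).cuda() * 0.01
sampling_locations = torch.rand(N, Lq, M, shapes.size(0), P, 2).cuda()
attention_weights = torch.rand(N, Lq, M, shapes.size(0), P).cuda() + 1e-5
attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)

out_cuda = MSDeformAttnFunction.apply(
    value.double(), shapes, level_start_index,
    sampling_locations.double(), attention_weights.double(), 2)
out_ref = ms_deform_attn_core_pytorch(
    value.double(), shapes, sampling_locations.double(), attention_weights.double())
print("max abs error:", (out_cuda - out_ref).abs().max().item())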
annotator/entityseg/mask2former/modeling/pixel_decoder/ops/src/cuda/ms_deform_attn_cuda.h ADDED
@@ -0,0 +1,35 @@
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ /*!
12
+ * Copyright (c) Facebook, Inc. and its affiliates.
13
+ * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
14
+ */
15
+
16
+ #pragma once
17
+ #include <torch/extension.h>
18
+
19
+ at::Tensor ms_deform_attn_cuda_forward(
20
+ const at::Tensor &value,
21
+ const at::Tensor &spatial_shapes,
22
+ const at::Tensor &level_start_index,
23
+ const at::Tensor &sampling_loc,
24
+ const at::Tensor &attn_weight,
25
+ const int im2col_step);
26
+
27
+ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
28
+ const at::Tensor &value,
29
+ const at::Tensor &spatial_shapes,
30
+ const at::Tensor &level_start_index,
31
+ const at::Tensor &sampling_loc,
32
+ const at::Tensor &attn_weight,
33
+ const at::Tensor &grad_output,
34
+ const int im2col_step);
35
+
annotator/entityseg/mask2former/modeling/pixel_decoder/ops/src/cuda/ms_deform_im2col_cuda.cuh ADDED
@@ -0,0 +1,1332 @@
1
+ /*!
2
+ **************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************
7
+ * Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
8
+ * Copyright (c) 2018 Microsoft
9
+ **************************************************************************
10
+ */
11
+
12
+ /*!
13
+ * Copyright (c) Facebook, Inc. and its affiliates.
14
+ * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
15
+ */
16
+
17
+ #include <cstdio>
18
+ #include <algorithm>
19
+ #include <cstring>
20
+
21
+ #include <ATen/ATen.h>
22
+ #include <ATen/cuda/CUDAContext.h>
23
+
24
+ #include <THC/THCAtomics.cuh>
25
+
26
+ #define CUDA_KERNEL_LOOP(i, n) \
27
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
28
+ i < (n); \
29
+ i += blockDim.x * gridDim.x)
30
+
31
+ const int CUDA_NUM_THREADS = 1024;
32
+ inline int GET_BLOCKS(const int N, const int num_threads)
33
+ {
34
+ return (N + num_threads - 1) / num_threads;
35
+ }
36
+
37
+
38
+ template <typename scalar_t>
39
+ __device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
40
+ const int &height, const int &width, const int &nheads, const int &channels,
41
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c)
42
+ {
43
+ const int h_low = floor(h);
44
+ const int w_low = floor(w);
45
+ const int h_high = h_low + 1;
46
+ const int w_high = w_low + 1;
47
+
48
+ const scalar_t lh = h - h_low;
49
+ const scalar_t lw = w - w_low;
50
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
51
+
52
+ const int w_stride = nheads * channels;
53
+ const int h_stride = width * w_stride;
54
+ const int h_low_ptr_offset = h_low * h_stride;
55
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
56
+ const int w_low_ptr_offset = w_low * w_stride;
57
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
58
+ const int base_ptr = m * channels + c;
59
+
60
+ scalar_t v1 = 0;
61
+ if (h_low >= 0 && w_low >= 0)
62
+ {
63
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
64
+ v1 = bottom_data[ptr1];
65
+ }
66
+ scalar_t v2 = 0;
67
+ if (h_low >= 0 && w_high <= width - 1)
68
+ {
69
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
70
+ v2 = bottom_data[ptr2];
71
+ }
72
+ scalar_t v3 = 0;
73
+ if (h_high <= height - 1 && w_low >= 0)
74
+ {
75
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
76
+ v3 = bottom_data[ptr3];
77
+ }
78
+ scalar_t v4 = 0;
79
+ if (h_high <= height - 1 && w_high <= width - 1)
80
+ {
81
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
82
+ v4 = bottom_data[ptr4];
83
+ }
84
+
85
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
86
+
87
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
88
+ return val;
89
+ }
90
+
91
+
92
+ template <typename scalar_t>
93
+ __device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
94
+ const int &height, const int &width, const int &nheads, const int &channels,
95
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
96
+ const scalar_t &top_grad,
97
+ const scalar_t &attn_weight,
98
+ scalar_t* &grad_value,
99
+ scalar_t* grad_sampling_loc,
100
+ scalar_t* grad_attn_weight)
101
+ {
102
+ const int h_low = floor(h);
103
+ const int w_low = floor(w);
104
+ const int h_high = h_low + 1;
105
+ const int w_high = w_low + 1;
106
+
107
+ const scalar_t lh = h - h_low;
108
+ const scalar_t lw = w - w_low;
109
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
110
+
111
+ const int w_stride = nheads * channels;
112
+ const int h_stride = width * w_stride;
113
+ const int h_low_ptr_offset = h_low * h_stride;
114
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
115
+ const int w_low_ptr_offset = w_low * w_stride;
116
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
117
+ const int base_ptr = m * channels + c;
118
+
119
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
120
+ const scalar_t top_grad_value = top_grad * attn_weight;
121
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
122
+
123
+ scalar_t v1 = 0;
124
+ if (h_low >= 0 && w_low >= 0)
125
+ {
126
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
127
+ v1 = bottom_data[ptr1];
128
+ grad_h_weight -= hw * v1;
129
+ grad_w_weight -= hh * v1;
130
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
131
+ }
132
+ scalar_t v2 = 0;
133
+ if (h_low >= 0 && w_high <= width - 1)
134
+ {
135
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
136
+ v2 = bottom_data[ptr2];
137
+ grad_h_weight -= lw * v2;
138
+ grad_w_weight += hh * v2;
139
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
140
+ }
141
+ scalar_t v3 = 0;
142
+ if (h_high <= height - 1 && w_low >= 0)
143
+ {
144
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
145
+ v3 = bottom_data[ptr3];
146
+ grad_h_weight += hw * v3;
147
+ grad_w_weight -= lh * v3;
148
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
149
+ }
150
+ scalar_t v4 = 0;
151
+ if (h_high <= height - 1 && w_high <= width - 1)
152
+ {
153
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
154
+ v4 = bottom_data[ptr4];
155
+ grad_h_weight += lw * v4;
156
+ grad_w_weight += lh * v4;
157
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
158
+ }
159
+
160
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
161
+ *grad_attn_weight = top_grad * val;
162
+ *grad_sampling_loc = width * grad_w_weight * top_grad_value;
163
+ *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
164
+ }
165
+
166
+
167
+ template <typename scalar_t>
168
+ __device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
169
+ const int &height, const int &width, const int &nheads, const int &channels,
170
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
171
+ const scalar_t &top_grad,
172
+ const scalar_t &attn_weight,
173
+ scalar_t* &grad_value,
174
+ scalar_t* grad_sampling_loc,
175
+ scalar_t* grad_attn_weight)
176
+ {
177
+ const int h_low = floor(h);
178
+ const int w_low = floor(w);
179
+ const int h_high = h_low + 1;
180
+ const int w_high = w_low + 1;
181
+
182
+ const scalar_t lh = h - h_low;
183
+ const scalar_t lw = w - w_low;
184
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
185
+
186
+ const int w_stride = nheads * channels;
187
+ const int h_stride = width * w_stride;
188
+ const int h_low_ptr_offset = h_low * h_stride;
189
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
190
+ const int w_low_ptr_offset = w_low * w_stride;
191
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
192
+ const int base_ptr = m * channels + c;
193
+
194
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
195
+ const scalar_t top_grad_value = top_grad * attn_weight;
196
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
197
+
198
+ scalar_t v1 = 0;
199
+ if (h_low >= 0 && w_low >= 0)
200
+ {
201
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
202
+ v1 = bottom_data[ptr1];
203
+ grad_h_weight -= hw * v1;
204
+ grad_w_weight -= hh * v1;
205
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
206
+ }
207
+ scalar_t v2 = 0;
208
+ if (h_low >= 0 && w_high <= width - 1)
209
+ {
210
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
211
+ v2 = bottom_data[ptr2];
212
+ grad_h_weight -= lw * v2;
213
+ grad_w_weight += hh * v2;
214
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
215
+ }
216
+ scalar_t v3 = 0;
217
+ if (h_high <= height - 1 && w_low >= 0)
218
+ {
219
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
220
+ v3 = bottom_data[ptr3];
221
+ grad_h_weight += hw * v3;
222
+ grad_w_weight -= lh * v3;
223
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
224
+ }
225
+ scalar_t v4 = 0;
226
+ if (h_high <= height - 1 && w_high <= width - 1)
227
+ {
228
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
229
+ v4 = bottom_data[ptr4];
230
+ grad_h_weight += lw * v4;
231
+ grad_w_weight += lh * v4;
232
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
233
+ }
234
+
235
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
236
+ atomicAdd(grad_attn_weight, top_grad * val);
237
+ atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
238
+ atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
239
+ }
240
+
241
+
242
+ template <typename scalar_t>
243
+ __global__ void ms_deformable_im2col_gpu_kernel(const int n,
244
+ const scalar_t *data_value,
245
+ const int64_t *data_spatial_shapes,
246
+ const int64_t *data_level_start_index,
247
+ const scalar_t *data_sampling_loc,
248
+ const scalar_t *data_attn_weight,
249
+ const int batch_size,
250
+ const int spatial_size,
251
+ const int num_heads,
252
+ const int channels,
253
+ const int num_levels,
254
+ const int num_query,
255
+ const int num_point,
256
+ scalar_t *data_col)
257
+ {
258
+ CUDA_KERNEL_LOOP(index, n)
259
+ {
260
+ int _temp = index;
261
+ const int c_col = _temp % channels;
262
+ _temp /= channels;
263
+ const int sampling_index = _temp;
264
+ const int m_col = _temp % num_heads;
265
+ _temp /= num_heads;
266
+ const int q_col = _temp % num_query;
267
+ _temp /= num_query;
268
+ const int b_col = _temp;
269
+
270
+ scalar_t *data_col_ptr = data_col + index;
271
+ int data_weight_ptr = sampling_index * num_levels * num_point;
272
+ int data_loc_w_ptr = data_weight_ptr << 1;
273
+ const int qid_stride = num_heads * channels;
274
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
275
+ scalar_t col = 0;
276
+
277
+ for (int l_col=0; l_col < num_levels; ++l_col)
278
+ {
279
+ const int level_start_id = data_level_start_index[l_col];
280
+ const int spatial_h_ptr = l_col << 1;
281
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
282
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
283
+ const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
284
+ for (int p_col=0; p_col < num_point; ++p_col)
285
+ {
286
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
287
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
288
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
289
+
290
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
291
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
292
+
293
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
294
+ {
295
+ col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
296
+ }
297
+
298
+ data_weight_ptr += 1;
299
+ data_loc_w_ptr += 2;
300
+ }
301
+ }
302
+ *data_col_ptr = col;
303
+ }
304
+ }
305
+
306
+ template <typename scalar_t, unsigned int blockSize>
307
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
308
+ const scalar_t *grad_col,
309
+ const scalar_t *data_value,
310
+ const int64_t *data_spatial_shapes,
311
+ const int64_t *data_level_start_index,
312
+ const scalar_t *data_sampling_loc,
313
+ const scalar_t *data_attn_weight,
314
+ const int batch_size,
315
+ const int spatial_size,
316
+ const int num_heads,
317
+ const int channels,
318
+ const int num_levels,
319
+ const int num_query,
320
+ const int num_point,
321
+ scalar_t *grad_value,
322
+ scalar_t *grad_sampling_loc,
323
+ scalar_t *grad_attn_weight)
324
+ {
325
+ CUDA_KERNEL_LOOP(index, n)
326
+ {
327
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
328
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
329
+ unsigned int tid = threadIdx.x;
330
+ int _temp = index;
331
+ const int c_col = _temp % channels;
332
+ _temp /= channels;
333
+ const int sampling_index = _temp;
334
+ const int m_col = _temp % num_heads;
335
+ _temp /= num_heads;
336
+ const int q_col = _temp % num_query;
337
+ _temp /= num_query;
338
+ const int b_col = _temp;
339
+
340
+ const scalar_t top_grad = grad_col[index];
341
+
342
+ int data_weight_ptr = sampling_index * num_levels * num_point;
343
+ int data_loc_w_ptr = data_weight_ptr << 1;
344
+ const int grad_sampling_ptr = data_weight_ptr;
345
+ grad_sampling_loc += grad_sampling_ptr << 1;
346
+ grad_attn_weight += grad_sampling_ptr;
347
+ const int grad_weight_stride = 1;
348
+ const int grad_loc_stride = 2;
349
+ const int qid_stride = num_heads * channels;
350
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
351
+
352
+ for (int l_col=0; l_col < num_levels; ++l_col)
353
+ {
354
+ const int level_start_id = data_level_start_index[l_col];
355
+ const int spatial_h_ptr = l_col << 1;
356
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
357
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
358
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
359
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
360
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
361
+
362
+ for (int p_col=0; p_col < num_point; ++p_col)
363
+ {
364
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
365
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
366
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
367
+
368
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
369
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
370
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
371
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
372
+ *(cache_grad_attn_weight+threadIdx.x)=0;
373
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
374
+ {
375
+ ms_deform_attn_col2im_bilinear(
376
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
377
+ top_grad, weight, grad_value_ptr,
378
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
379
+ }
380
+
381
+ __syncthreads();
382
+ if (tid == 0)
383
+ {
384
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
385
+ int sid=2;
386
+ for (unsigned int tid = 1; tid < blockSize; ++tid)
387
+ {
388
+ _grad_w += cache_grad_sampling_loc[sid];
389
+ _grad_h += cache_grad_sampling_loc[sid + 1];
390
+ _grad_a += cache_grad_attn_weight[tid];
391
+ sid += 2;
392
+ }
393
+
394
+
395
+ *grad_sampling_loc = _grad_w;
396
+ *(grad_sampling_loc + 1) = _grad_h;
397
+ *grad_attn_weight = _grad_a;
398
+ }
399
+ __syncthreads();
400
+
401
+ data_weight_ptr += 1;
402
+ data_loc_w_ptr += 2;
403
+ grad_attn_weight += grad_weight_stride;
404
+ grad_sampling_loc += grad_loc_stride;
405
+ }
406
+ }
407
+ }
408
+ }
409
+
410
+
411
+ template <typename scalar_t, unsigned int blockSize>
412
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
413
+ const scalar_t *grad_col,
414
+ const scalar_t *data_value,
415
+ const int64_t *data_spatial_shapes,
416
+ const int64_t *data_level_start_index,
417
+ const scalar_t *data_sampling_loc,
418
+ const scalar_t *data_attn_weight,
419
+ const int batch_size,
420
+ const int spatial_size,
421
+ const int num_heads,
422
+ const int channels,
423
+ const int num_levels,
424
+ const int num_query,
425
+ const int num_point,
426
+ scalar_t *grad_value,
427
+ scalar_t *grad_sampling_loc,
428
+ scalar_t *grad_attn_weight)
429
+ {
430
+ CUDA_KERNEL_LOOP(index, n)
431
+ {
432
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
433
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
434
+ unsigned int tid = threadIdx.x;
435
+ int _temp = index;
436
+ const int c_col = _temp % channels;
437
+ _temp /= channels;
438
+ const int sampling_index = _temp;
439
+ const int m_col = _temp % num_heads;
440
+ _temp /= num_heads;
441
+ const int q_col = _temp % num_query;
442
+ _temp /= num_query;
443
+ const int b_col = _temp;
444
+
445
+ const scalar_t top_grad = grad_col[index];
446
+
447
+ int data_weight_ptr = sampling_index * num_levels * num_point;
448
+ int data_loc_w_ptr = data_weight_ptr << 1;
449
+ const int grad_sampling_ptr = data_weight_ptr;
450
+ grad_sampling_loc += grad_sampling_ptr << 1;
451
+ grad_attn_weight += grad_sampling_ptr;
452
+ const int grad_weight_stride = 1;
453
+ const int grad_loc_stride = 2;
454
+ const int qid_stride = num_heads * channels;
455
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
456
+
457
+ for (int l_col=0; l_col < num_levels; ++l_col)
458
+ {
459
+ const int level_start_id = data_level_start_index[l_col];
460
+ const int spatial_h_ptr = l_col << 1;
461
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
462
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
463
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
464
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
465
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
466
+
467
+ for (int p_col=0; p_col < num_point; ++p_col)
468
+ {
469
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
470
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
471
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
472
+
473
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
474
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
475
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
476
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
477
+ *(cache_grad_attn_weight+threadIdx.x)=0;
478
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
479
+ {
480
+ ms_deform_attn_col2im_bilinear(
481
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
482
+ top_grad, weight, grad_value_ptr,
483
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
484
+ }
485
+
486
+ __syncthreads();
487
+
488
+ for (unsigned int s=blockSize/2; s>0; s>>=1)
489
+ {
490
+ if (tid < s) {
491
+ const unsigned int xid1 = tid << 1;
492
+ const unsigned int xid2 = (tid + s) << 1;
493
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
494
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
495
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
496
+ }
497
+ __syncthreads();
498
+ }
499
+
500
+ if (tid == 0)
501
+ {
502
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
503
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
504
+ *grad_attn_weight = cache_grad_attn_weight[0];
505
+ }
506
+ __syncthreads();
507
+
508
+ data_weight_ptr += 1;
509
+ data_loc_w_ptr += 2;
510
+ grad_attn_weight += grad_weight_stride;
511
+ grad_sampling_loc += grad_loc_stride;
512
+ }
513
+ }
514
+ }
515
+ }
516
+
517
+
518
+ template <typename scalar_t>
519
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
520
+ const scalar_t *grad_col,
521
+ const scalar_t *data_value,
522
+ const int64_t *data_spatial_shapes,
523
+ const int64_t *data_level_start_index,
524
+ const scalar_t *data_sampling_loc,
525
+ const scalar_t *data_attn_weight,
526
+ const int batch_size,
527
+ const int spatial_size,
528
+ const int num_heads,
529
+ const int channels,
530
+ const int num_levels,
531
+ const int num_query,
532
+ const int num_point,
533
+ scalar_t *grad_value,
534
+ scalar_t *grad_sampling_loc,
535
+ scalar_t *grad_attn_weight)
536
+ {
537
+ CUDA_KERNEL_LOOP(index, n)
538
+ {
539
+ extern __shared__ int _s[];
540
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
541
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
542
+ unsigned int tid = threadIdx.x;
543
+ int _temp = index;
544
+ const int c_col = _temp % channels;
545
+ _temp /= channels;
546
+ const int sampling_index = _temp;
547
+ const int m_col = _temp % num_heads;
548
+ _temp /= num_heads;
549
+ const int q_col = _temp % num_query;
550
+ _temp /= num_query;
551
+ const int b_col = _temp;
552
+
553
+ const scalar_t top_grad = grad_col[index];
554
+
555
+ int data_weight_ptr = sampling_index * num_levels * num_point;
556
+ int data_loc_w_ptr = data_weight_ptr << 1;
557
+ const int grad_sampling_ptr = data_weight_ptr;
558
+ grad_sampling_loc += grad_sampling_ptr << 1;
559
+ grad_attn_weight += grad_sampling_ptr;
560
+ const int grad_weight_stride = 1;
561
+ const int grad_loc_stride = 2;
562
+ const int qid_stride = num_heads * channels;
563
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
564
+
565
+ for (int l_col=0; l_col < num_levels; ++l_col)
566
+ {
567
+ const int level_start_id = data_level_start_index[l_col];
568
+ const int spatial_h_ptr = l_col << 1;
569
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
570
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
571
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
572
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
573
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
574
+
575
+ for (int p_col=0; p_col < num_point; ++p_col)
576
+ {
577
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
578
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
579
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
580
+
581
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
582
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
583
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
584
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
585
+ *(cache_grad_attn_weight+threadIdx.x)=0;
586
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
587
+ {
588
+ ms_deform_attn_col2im_bilinear(
589
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
590
+ top_grad, weight, grad_value_ptr,
591
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
592
+ }
593
+
594
+ __syncthreads();
595
+ if (tid == 0)
596
+ {
597
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
598
+ int sid=2;
599
+ for (unsigned int tid = 1; tid < blockDim.x; ++tid)
600
+ {
601
+ _grad_w += cache_grad_sampling_loc[sid];
602
+ _grad_h += cache_grad_sampling_loc[sid + 1];
603
+ _grad_a += cache_grad_attn_weight[tid];
604
+ sid += 2;
605
+ }
606
+
607
+
608
+ *grad_sampling_loc = _grad_w;
609
+ *(grad_sampling_loc + 1) = _grad_h;
610
+ *grad_attn_weight = _grad_a;
611
+ }
612
+ __syncthreads();
613
+
614
+ data_weight_ptr += 1;
615
+ data_loc_w_ptr += 2;
616
+ grad_attn_weight += grad_weight_stride;
617
+ grad_sampling_loc += grad_loc_stride;
618
+ }
619
+ }
620
+ }
621
+ }
622
+
623
+ template <typename scalar_t>
624
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
625
+ const scalar_t *grad_col,
626
+ const scalar_t *data_value,
627
+ const int64_t *data_spatial_shapes,
628
+ const int64_t *data_level_start_index,
629
+ const scalar_t *data_sampling_loc,
630
+ const scalar_t *data_attn_weight,
631
+ const int batch_size,
632
+ const int spatial_size,
633
+ const int num_heads,
634
+ const int channels,
635
+ const int num_levels,
636
+ const int num_query,
637
+ const int num_point,
638
+ scalar_t *grad_value,
639
+ scalar_t *grad_sampling_loc,
640
+ scalar_t *grad_attn_weight)
641
+ {
642
+ CUDA_KERNEL_LOOP(index, n)
643
+ {
644
+ extern __shared__ int _s[];
645
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
646
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
647
+ unsigned int tid = threadIdx.x;
648
+ int _temp = index;
649
+ const int c_col = _temp % channels;
650
+ _temp /= channels;
651
+ const int sampling_index = _temp;
652
+ const int m_col = _temp % num_heads;
653
+ _temp /= num_heads;
654
+ const int q_col = _temp % num_query;
655
+ _temp /= num_query;
656
+ const int b_col = _temp;
657
+
658
+ const scalar_t top_grad = grad_col[index];
659
+
660
+ int data_weight_ptr = sampling_index * num_levels * num_point;
661
+ int data_loc_w_ptr = data_weight_ptr << 1;
662
+ const int grad_sampling_ptr = data_weight_ptr;
663
+ grad_sampling_loc += grad_sampling_ptr << 1;
664
+ grad_attn_weight += grad_sampling_ptr;
665
+ const int grad_weight_stride = 1;
666
+ const int grad_loc_stride = 2;
667
+ const int qid_stride = num_heads * channels;
668
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
669
+
670
+ for (int l_col=0; l_col < num_levels; ++l_col)
671
+ {
672
+ const int level_start_id = data_level_start_index[l_col];
673
+ const int spatial_h_ptr = l_col << 1;
674
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
675
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
676
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
677
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
678
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
679
+
680
+ for (int p_col=0; p_col < num_point; ++p_col)
681
+ {
682
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
683
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
684
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
685
+
686
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
687
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
688
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
689
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
690
+ *(cache_grad_attn_weight+threadIdx.x)=0;
691
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
692
+ {
693
+ ms_deform_attn_col2im_bilinear(
694
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
695
+ top_grad, weight, grad_value_ptr,
696
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
697
+ }
698
+
699
+ __syncthreads();
700
+
701
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
702
+ {
703
+ if (tid < s) {
704
+ const unsigned int xid1 = tid << 1;
705
+ const unsigned int xid2 = (tid + s) << 1;
706
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
707
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
708
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
709
+ if (tid + (s << 1) < spre)
710
+ {
711
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
712
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
713
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
714
+ }
715
+ }
716
+ __syncthreads();
717
+ }
718
+
719
+ if (tid == 0)
720
+ {
721
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
722
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
723
+ *grad_attn_weight = cache_grad_attn_weight[0];
724
+ }
725
+ __syncthreads();
726
+
727
+ data_weight_ptr += 1;
728
+ data_loc_w_ptr += 2;
729
+ grad_attn_weight += grad_weight_stride;
730
+ grad_sampling_loc += grad_loc_stride;
731
+ }
732
+ }
733
+ }
734
+ }
735
+
736
+ template <typename scalar_t>
737
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
738
+ const scalar_t *grad_col,
739
+ const scalar_t *data_value,
740
+ const int64_t *data_spatial_shapes,
741
+ const int64_t *data_level_start_index,
742
+ const scalar_t *data_sampling_loc,
743
+ const scalar_t *data_attn_weight,
744
+ const int batch_size,
745
+ const int spatial_size,
746
+ const int num_heads,
747
+ const int channels,
748
+ const int num_levels,
749
+ const int num_query,
750
+ const int num_point,
751
+ scalar_t *grad_value,
752
+ scalar_t *grad_sampling_loc,
753
+ scalar_t *grad_attn_weight)
754
+ {
755
+ CUDA_KERNEL_LOOP(index, n)
756
+ {
757
+ extern __shared__ int _s[];
758
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
759
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
760
+ unsigned int tid = threadIdx.x;
761
+ int _temp = index;
762
+ const int c_col = _temp % channels;
763
+ _temp /= channels;
764
+ const int sampling_index = _temp;
765
+ const int m_col = _temp % num_heads;
766
+ _temp /= num_heads;
767
+ const int q_col = _temp % num_query;
768
+ _temp /= num_query;
769
+ const int b_col = _temp;
770
+
771
+ const scalar_t top_grad = grad_col[index];
772
+
773
+ int data_weight_ptr = sampling_index * num_levels * num_point;
774
+ int data_loc_w_ptr = data_weight_ptr << 1;
775
+ const int grad_sampling_ptr = data_weight_ptr;
776
+ grad_sampling_loc += grad_sampling_ptr << 1;
777
+ grad_attn_weight += grad_sampling_ptr;
778
+ const int grad_weight_stride = 1;
779
+ const int grad_loc_stride = 2;
780
+ const int qid_stride = num_heads * channels;
781
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
782
+
783
+ for (int l_col=0; l_col < num_levels; ++l_col)
784
+ {
785
+ const int level_start_id = data_level_start_index[l_col];
786
+ const int spatial_h_ptr = l_col << 1;
787
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
788
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
789
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
790
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
791
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
792
+
793
+ for (int p_col=0; p_col < num_point; ++p_col)
794
+ {
795
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
796
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
797
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
798
+
799
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
800
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
801
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
802
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
803
+ *(cache_grad_attn_weight+threadIdx.x)=0;
804
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
805
+ {
806
+ ms_deform_attn_col2im_bilinear(
807
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
808
+ top_grad, weight, grad_value_ptr,
809
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
810
+ }
811
+
812
+ __syncthreads();
813
+
814
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
815
+ {
816
+ if (tid < s) {
817
+ const unsigned int xid1 = tid << 1;
818
+ const unsigned int xid2 = (tid + s) << 1;
819
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
820
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
821
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
822
+ if (tid + (s << 1) < spre)
823
+ {
824
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
825
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
826
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
827
+ }
828
+ }
829
+ __syncthreads();
830
+ }
831
+
832
+ if (tid == 0)
833
+ {
834
+ atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
835
+ atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
836
+ atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
837
+ }
838
+ __syncthreads();
839
+
840
+ data_weight_ptr += 1;
841
+ data_loc_w_ptr += 2;
842
+ grad_attn_weight += grad_weight_stride;
843
+ grad_sampling_loc += grad_loc_stride;
844
+ }
845
+ }
846
+ }
847
+ }
848
+
849
+
850
+ template <typename scalar_t>
851
+ __global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
852
+ const scalar_t *grad_col,
853
+ const scalar_t *data_value,
854
+ const int64_t *data_spatial_shapes,
855
+ const int64_t *data_level_start_index,
856
+ const scalar_t *data_sampling_loc,
857
+ const scalar_t *data_attn_weight,
858
+ const int batch_size,
859
+ const int spatial_size,
860
+ const int num_heads,
861
+ const int channels,
862
+ const int num_levels,
863
+ const int num_query,
864
+ const int num_point,
865
+ scalar_t *grad_value,
866
+ scalar_t *grad_sampling_loc,
867
+ scalar_t *grad_attn_weight)
868
+ {
869
+ CUDA_KERNEL_LOOP(index, n)
870
+ {
871
+ int _temp = index;
872
+ const int c_col = _temp % channels;
873
+ _temp /= channels;
874
+ const int sampling_index = _temp;
875
+ const int m_col = _temp % num_heads;
876
+ _temp /= num_heads;
877
+ const int q_col = _temp % num_query;
878
+ _temp /= num_query;
879
+ const int b_col = _temp;
880
+
881
+ const scalar_t top_grad = grad_col[index];
882
+
883
+ int data_weight_ptr = sampling_index * num_levels * num_point;
884
+ int data_loc_w_ptr = data_weight_ptr << 1;
885
+ const int grad_sampling_ptr = data_weight_ptr;
886
+ grad_sampling_loc += grad_sampling_ptr << 1;
887
+ grad_attn_weight += grad_sampling_ptr;
888
+ const int grad_weight_stride = 1;
889
+ const int grad_loc_stride = 2;
890
+ const int qid_stride = num_heads * channels;
891
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
892
+
893
+ for (int l_col=0; l_col < num_levels; ++l_col)
894
+ {
895
+ const int level_start_id = data_level_start_index[l_col];
896
+ const int spatial_h_ptr = l_col << 1;
897
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
898
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
899
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
900
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
901
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
902
+
903
+ for (int p_col=0; p_col < num_point; ++p_col)
904
+ {
905
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
906
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
907
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
908
+
909
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
910
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
911
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
912
+ {
913
+ ms_deform_attn_col2im_bilinear_gm(
914
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
915
+ top_grad, weight, grad_value_ptr,
916
+ grad_sampling_loc, grad_attn_weight);
917
+ }
918
+ data_weight_ptr += 1;
919
+ data_loc_w_ptr += 2;
920
+ grad_attn_weight += grad_weight_stride;
921
+ grad_sampling_loc += grad_loc_stride;
922
+ }
923
+ }
924
+ }
925
+ }
926
+
927
+
928
+ template <typename scalar_t>
929
+ void ms_deformable_im2col_cuda(cudaStream_t stream,
930
+ const scalar_t* data_value,
931
+ const int64_t* data_spatial_shapes,
932
+ const int64_t* data_level_start_index,
933
+ const scalar_t* data_sampling_loc,
934
+ const scalar_t* data_attn_weight,
935
+ const int batch_size,
936
+ const int spatial_size,
937
+ const int num_heads,
938
+ const int channels,
939
+ const int num_levels,
940
+ const int num_query,
941
+ const int num_point,
942
+ scalar_t* data_col)
943
+ {
944
+ const int num_kernels = batch_size * num_query * num_heads * channels;
945
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
946
+ const int num_threads = CUDA_NUM_THREADS;
947
+ ms_deformable_im2col_gpu_kernel<scalar_t>
948
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
949
+ 0, stream>>>(
950
+ num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
951
+ batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
952
+
953
+ cudaError_t err = cudaGetLastError();
954
+ if (err != cudaSuccess)
955
+ {
956
+ printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
957
+ }
958
+
959
+ }
960
+
961
+ template <typename scalar_t>
962
+ void ms_deformable_col2im_cuda(cudaStream_t stream,
963
+ const scalar_t* grad_col,
964
+ const scalar_t* data_value,
965
+ const int64_t * data_spatial_shapes,
966
+ const int64_t * data_level_start_index,
967
+ const scalar_t * data_sampling_loc,
968
+ const scalar_t * data_attn_weight,
969
+ const int batch_size,
970
+ const int spatial_size,
971
+ const int num_heads,
972
+ const int channels,
973
+ const int num_levels,
974
+ const int num_query,
975
+ const int num_point,
976
+ scalar_t* grad_value,
977
+ scalar_t* grad_sampling_loc,
978
+ scalar_t* grad_attn_weight)
979
+ {
980
+ const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
981
+ const int num_kernels = batch_size * num_query * num_heads * channels;
982
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
983
+ if (channels > 1024)
984
+ {
985
+ if ((channels & 1023) == 0)
986
+ {
987
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
988
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
989
+ num_threads*3*sizeof(scalar_t), stream>>>(
990
+ num_kernels,
991
+ grad_col,
992
+ data_value,
993
+ data_spatial_shapes,
994
+ data_level_start_index,
995
+ data_sampling_loc,
996
+ data_attn_weight,
997
+ batch_size,
998
+ spatial_size,
999
+ num_heads,
1000
+ channels,
1001
+ num_levels,
1002
+ num_query,
1003
+ num_point,
1004
+ grad_value,
1005
+ grad_sampling_loc,
1006
+ grad_attn_weight);
1007
+ }
1008
+ else
1009
+ {
1010
+ ms_deformable_col2im_gpu_kernel_gm<scalar_t>
1011
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1012
+ 0, stream>>>(
1013
+ num_kernels,
1014
+ grad_col,
1015
+ data_value,
1016
+ data_spatial_shapes,
1017
+ data_level_start_index,
1018
+ data_sampling_loc,
1019
+ data_attn_weight,
1020
+ batch_size,
1021
+ spatial_size,
1022
+ num_heads,
1023
+ channels,
1024
+ num_levels,
1025
+ num_query,
1026
+ num_point,
1027
+ grad_value,
1028
+ grad_sampling_loc,
1029
+ grad_attn_weight);
1030
+ }
1031
+ }
1032
+ else{
1033
+ switch(channels)
1034
+ {
1035
+ case 1:
1036
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>
1037
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1038
+ 0, stream>>>(
1039
+ num_kernels,
1040
+ grad_col,
1041
+ data_value,
1042
+ data_spatial_shapes,
1043
+ data_level_start_index,
1044
+ data_sampling_loc,
1045
+ data_attn_weight,
1046
+ batch_size,
1047
+ spatial_size,
1048
+ num_heads,
1049
+ channels,
1050
+ num_levels,
1051
+ num_query,
1052
+ num_point,
1053
+ grad_value,
1054
+ grad_sampling_loc,
1055
+ grad_attn_weight);
1056
+ break;
1057
+ case 2:
1058
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>
1059
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1060
+ 0, stream>>>(
1061
+ num_kernels,
1062
+ grad_col,
1063
+ data_value,
1064
+ data_spatial_shapes,
1065
+ data_level_start_index,
1066
+ data_sampling_loc,
1067
+ data_attn_weight,
1068
+ batch_size,
1069
+ spatial_size,
1070
+ num_heads,
1071
+ channels,
1072
+ num_levels,
1073
+ num_query,
1074
+ num_point,
1075
+ grad_value,
1076
+ grad_sampling_loc,
1077
+ grad_attn_weight);
1078
+ break;
1079
+ case 4:
1080
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>
1081
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1082
+ 0, stream>>>(
1083
+ num_kernels,
1084
+ grad_col,
1085
+ data_value,
1086
+ data_spatial_shapes,
1087
+ data_level_start_index,
1088
+ data_sampling_loc,
1089
+ data_attn_weight,
1090
+ batch_size,
1091
+ spatial_size,
1092
+ num_heads,
1093
+ channels,
1094
+ num_levels,
1095
+ num_query,
1096
+ num_point,
1097
+ grad_value,
1098
+ grad_sampling_loc,
1099
+ grad_attn_weight);
1100
+ break;
1101
+ case 8:
1102
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>
1103
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1104
+ 0, stream>>>(
1105
+ num_kernels,
1106
+ grad_col,
1107
+ data_value,
1108
+ data_spatial_shapes,
1109
+ data_level_start_index,
1110
+ data_sampling_loc,
1111
+ data_attn_weight,
1112
+ batch_size,
1113
+ spatial_size,
1114
+ num_heads,
1115
+ channels,
1116
+ num_levels,
1117
+ num_query,
1118
+ num_point,
1119
+ grad_value,
1120
+ grad_sampling_loc,
1121
+ grad_attn_weight);
1122
+ break;
1123
+ case 16:
1124
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>
1125
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1126
+ 0, stream>>>(
1127
+ num_kernels,
1128
+ grad_col,
1129
+ data_value,
1130
+ data_spatial_shapes,
1131
+ data_level_start_index,
1132
+ data_sampling_loc,
1133
+ data_attn_weight,
1134
+ batch_size,
1135
+ spatial_size,
1136
+ num_heads,
1137
+ channels,
1138
+ num_levels,
1139
+ num_query,
1140
+ num_point,
1141
+ grad_value,
1142
+ grad_sampling_loc,
1143
+ grad_attn_weight);
1144
+ break;
1145
+ case 32:
1146
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>
1147
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1148
+ 0, stream>>>(
1149
+ num_kernels,
1150
+ grad_col,
1151
+ data_value,
1152
+ data_spatial_shapes,
1153
+ data_level_start_index,
1154
+ data_sampling_loc,
1155
+ data_attn_weight,
1156
+ batch_size,
1157
+ spatial_size,
1158
+ num_heads,
1159
+ channels,
1160
+ num_levels,
1161
+ num_query,
1162
+ num_point,
1163
+ grad_value,
1164
+ grad_sampling_loc,
1165
+ grad_attn_weight);
1166
+ break;
1167
+ case 64:
1168
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>
1169
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1170
+ 0, stream>>>(
1171
+ num_kernels,
1172
+ grad_col,
1173
+ data_value,
1174
+ data_spatial_shapes,
1175
+ data_level_start_index,
1176
+ data_sampling_loc,
1177
+ data_attn_weight,
1178
+ batch_size,
1179
+ spatial_size,
1180
+ num_heads,
1181
+ channels,
1182
+ num_levels,
1183
+ num_query,
1184
+ num_point,
1185
+ grad_value,
1186
+ grad_sampling_loc,
1187
+ grad_attn_weight);
1188
+ break;
1189
+ case 128:
1190
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>
1191
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1192
+ 0, stream>>>(
1193
+ num_kernels,
1194
+ grad_col,
1195
+ data_value,
1196
+ data_spatial_shapes,
1197
+ data_level_start_index,
1198
+ data_sampling_loc,
1199
+ data_attn_weight,
1200
+ batch_size,
1201
+ spatial_size,
1202
+ num_heads,
1203
+ channels,
1204
+ num_levels,
1205
+ num_query,
1206
+ num_point,
1207
+ grad_value,
1208
+ grad_sampling_loc,
1209
+ grad_attn_weight);
1210
+ break;
1211
+ case 256:
1212
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>
1213
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1214
+ 0, stream>>>(
1215
+ num_kernels,
1216
+ grad_col,
1217
+ data_value,
1218
+ data_spatial_shapes,
1219
+ data_level_start_index,
1220
+ data_sampling_loc,
1221
+ data_attn_weight,
1222
+ batch_size,
1223
+ spatial_size,
1224
+ num_heads,
1225
+ channels,
1226
+ num_levels,
1227
+ num_query,
1228
+ num_point,
1229
+ grad_value,
1230
+ grad_sampling_loc,
1231
+ grad_attn_weight);
1232
+ break;
1233
+ case 512:
1234
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>
1235
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1236
+ 0, stream>>>(
1237
+ num_kernels,
1238
+ grad_col,
1239
+ data_value,
1240
+ data_spatial_shapes,
1241
+ data_level_start_index,
1242
+ data_sampling_loc,
1243
+ data_attn_weight,
1244
+ batch_size,
1245
+ spatial_size,
1246
+ num_heads,
1247
+ channels,
1248
+ num_levels,
1249
+ num_query,
1250
+ num_point,
1251
+ grad_value,
1252
+ grad_sampling_loc,
1253
+ grad_attn_weight);
1254
+ break;
1255
+ case 1024:
1256
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>
1257
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1258
+ 0, stream>>>(
1259
+ num_kernels,
1260
+ grad_col,
1261
+ data_value,
1262
+ data_spatial_shapes,
1263
+ data_level_start_index,
1264
+ data_sampling_loc,
1265
+ data_attn_weight,
1266
+ batch_size,
1267
+ spatial_size,
1268
+ num_heads,
1269
+ channels,
1270
+ num_levels,
1271
+ num_query,
1272
+ num_point,
1273
+ grad_value,
1274
+ grad_sampling_loc,
1275
+ grad_attn_weight);
1276
+ break;
1277
+ default:
1278
+ if (channels < 64)
1279
+ {
1280
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
1281
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1282
+ num_threads*3*sizeof(scalar_t), stream>>>(
1283
+ num_kernels,
1284
+ grad_col,
1285
+ data_value,
1286
+ data_spatial_shapes,
1287
+ data_level_start_index,
1288
+ data_sampling_loc,
1289
+ data_attn_weight,
1290
+ batch_size,
1291
+ spatial_size,
1292
+ num_heads,
1293
+ channels,
1294
+ num_levels,
1295
+ num_query,
1296
+ num_point,
1297
+ grad_value,
1298
+ grad_sampling_loc,
1299
+ grad_attn_weight);
1300
+ }
1301
+ else
1302
+ {
1303
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
1304
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1305
+ num_threads*3*sizeof(scalar_t), stream>>>(
1306
+ num_kernels,
1307
+ grad_col,
1308
+ data_value,
1309
+ data_spatial_shapes,
1310
+ data_level_start_index,
1311
+ data_sampling_loc,
1312
+ data_attn_weight,
1313
+ batch_size,
1314
+ spatial_size,
1315
+ num_heads,
1316
+ channels,
1317
+ num_levels,
1318
+ num_query,
1319
+ num_point,
1320
+ grad_value,
1321
+ grad_sampling_loc,
1322
+ grad_attn_weight);
1323
+ }
1324
+ }
1325
+ }
1326
+ cudaError_t err = cudaGetLastError();
1327
+ if (err != cudaSuccess)
1328
+ {
1329
+ printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
1330
+ }
1331
+
1332
+ }
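
The launcher above picks the backward kernel from the channel count: power-of-two sizes up to 1024 get the block-size-aware shared-memory reductions, other sizes fall back to the generic shared-memory reductions, and counts above 1024 use either the multi-block shared-memory kernel (when divisible by 1024) or the global-memory atomicAdd kernel. A small, purely illustrative Python mirror of that decision table (the names refer to the kernel templates defined above; nothing here calls CUDA):

def col2im_kernel_variant(channels: int) -> str:
    """Name of the backward kernel template launched for a given channel count."""
    if channels > 1024:
        # Multi-block shared-memory reduction only when 1024 divides the size;
        # otherwise the global-memory (atomicAdd) kernel.
        return "shm_reduce_v2_multi_blocks" if channels % 1024 == 0 else "gm"
    if channels in (1, 2, 4, 8, 16, 32):
        return f"shm_blocksize_aware_reduce_v1<{channels}>"
    if channels in (64, 128, 256, 512, 1024):
        return f"shm_blocksize_aware_reduce_v2<{channels}>"
    # Remaining sizes use the generic shared-memory reductions.
    return "shm_reduce_v1" if channels < 64 else "shm_reduce_v2"


if __name__ == "__main__":
    for c in (16, 71, 256, 1025, 2048):
        print(c, "->", col2im_kernel_variant(c))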
annotator/entityseg/mask2former/modeling/pixel_decoder/ops/src/ms_deform_attn.h ADDED
@@ -0,0 +1,67 @@
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ /*!
12
+ * Copyright (c) Facebook, Inc. and its affiliates.
13
+ * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
14
+ */
15
+
16
+ #pragma once
17
+
18
+ #include "cpu/ms_deform_attn_cpu.h"
19
+
20
+ #ifdef WITH_CUDA
21
+ #include "cuda/ms_deform_attn_cuda.h"
22
+ #endif
23
+
24
+
25
+ at::Tensor
26
+ ms_deform_attn_forward(
27
+ const at::Tensor &value,
28
+ const at::Tensor &spatial_shapes,
29
+ const at::Tensor &level_start_index,
30
+ const at::Tensor &sampling_loc,
31
+ const at::Tensor &attn_weight,
32
+ const int im2col_step)
33
+ {
34
+ if (value.type().is_cuda())
35
+ {
36
+ #ifdef WITH_CUDA
37
+ return ms_deform_attn_cuda_forward(
38
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
39
+ #else
40
+ AT_ERROR("Not compiled with GPU support");
41
+ #endif
42
+ }
43
+ AT_ERROR("Not implemented on the CPU");
44
+ }
45
+
46
+ std::vector<at::Tensor>
47
+ ms_deform_attn_backward(
48
+ const at::Tensor &value,
49
+ const at::Tensor &spatial_shapes,
50
+ const at::Tensor &level_start_index,
51
+ const at::Tensor &sampling_loc,
52
+ const at::Tensor &attn_weight,
53
+ const at::Tensor &grad_output,
54
+ const int im2col_step)
55
+ {
56
+ if (value.type().is_cuda())
57
+ {
58
+ #ifdef WITH_CUDA
59
+ return ms_deform_attn_cuda_backward(
60
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
61
+ #else
62
+ AT_ERROR("Not compiled with GPU support");
63
+ #endif
64
+ }
65
+ AT_ERROR("Not implemented on the CPU");
66
+ }
67
+
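
The header above declares the two entry points that the Python side wraps in a torch.autograd.Function. Below is a minimal sketch of such a wrapper; the repo ships its own in functions/ms_deform_attn_func.py (part of this commit), which should be preferred, and the extension module name MultiScaleDeformableAttention is an assumption about how setup.py builds it.

import torch
from torch.autograd.function import once_differentiable

import MultiScaleDeformableAttention as MSDA  # assumed name of the compiled extension


class MSDeformAttnFunctionSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, value, spatial_shapes, level_start_index,
                sampling_loc, attn_weight, im2col_step):
        ctx.im2col_step = im2col_step
        output = MSDA.ms_deform_attn_forward(
            value, spatial_shapes, level_start_index,
            sampling_loc, attn_weight, im2col_step)
        ctx.save_for_backward(value, spatial_shapes, level_start_index,
                              sampling_loc, attn_weight)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        value, spatial_shapes, level_start_index, sampling_loc, attn_weight = ctx.saved_tensors
        grad_value, grad_sampling_loc, grad_attn_weight = MSDA.ms_deform_attn_backward(
            value, spatial_shapes, level_start_index,
            sampling_loc, attn_weight, grad_output.contiguous(), ctx.im2col_step)
        # No gradients for the integer shape/index inputs or for im2col_step.
        return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None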
annotator/entityseg/mask2former/modeling/pixel_decoder/ops/src/vision.cpp ADDED
@@ -0,0 +1,21 @@
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ /*!
12
+ * Copyright (c) Facebook, Inc. and its affiliates.
13
+ * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
14
+ */
15
+
16
+ #include "ms_deform_attn.h"
17
+
18
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
19
+ m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
20
+ m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
21
+ }
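
vision.cpp binds both ops through pybind11; the supported build path in this commit is make.sh / setup.py. As a hedged alternative for a quick smoke test, the same sources could be JIT-compiled with torch.utils.cpp_extension.load; the relative paths and the WITH_CUDA define below are assumptions, not part of this repo.

# Hedged sketch: JIT-building the sources above, run from the ops/ directory.
# Assumes a CUDA toolchain is available; `sh make.sh` remains the supported route.
from torch.utils.cpp_extension import load

msda = load(
    name="MultiScaleDeformableAttention",
    sources=[
        "src/vision.cpp",
        "src/cpu/ms_deform_attn_cpu.cpp",
        "src/cuda/ms_deform_attn_cuda.cu",
    ],
    extra_cflags=["-DWITH_CUDA"],
    extra_cuda_cflags=["-DWITH_CUDA"],
    verbose=True,
)
print(msda.ms_deform_attn_forward)  # bound by PYBIND11_MODULE above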
annotator/entityseg/mask2former/modeling/pixel_decoder/ops/test.py ADDED
@@ -0,0 +1,92 @@
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ # Copyright (c) Facebook, Inc. and its affiliates.
10
+ # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
11
+
12
+ from __future__ import absolute_import
13
+ from __future__ import print_function
14
+ from __future__ import division
15
+
16
+ import time
17
+ import torch
18
+ import torch.nn as nn
19
+ from torch.autograd import gradcheck
20
+
21
+ from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch
22
+
23
+
24
+ N, M, D = 1, 2, 2
25
+ Lq, L, P = 2, 2, 2
26
+ shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
27
+ level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1]))
28
+ S = sum([(H*W).item() for H, W in shapes])
29
+
30
+
31
+ torch.manual_seed(3)
32
+
33
+
34
+ @torch.no_grad()
35
+ def check_forward_equal_with_pytorch_double():
36
+ value = torch.rand(N, S, M, D).cuda() * 0.01
37
+ sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
38
+ attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
39
+ attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
40
+ im2col_step = 2
41
+ output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu()
42
+ output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu()
43
+ fwdok = torch.allclose(output_cuda, output_pytorch)
44
+ max_abs_err = (output_cuda - output_pytorch).abs().max()
45
+ max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
46
+
47
+ print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
48
+
49
+
50
+ @torch.no_grad()
51
+ def check_forward_equal_with_pytorch_float():
52
+ value = torch.rand(N, S, M, D).cuda() * 0.01
53
+ sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
54
+ attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
55
+ attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
56
+ im2col_step = 2
57
+ output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu()
58
+ output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu()
59
+ fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
60
+ max_abs_err = (output_cuda - output_pytorch).abs().max()
61
+ max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
62
+
63
+ print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
64
+
65
+
66
+ def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True):
67
+
68
+ value = torch.rand(N, S, M, channels).cuda() * 0.01
69
+ sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
70
+ attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
71
+ attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
72
+ im2col_step = 2
73
+ func = MSDeformAttnFunction.apply
74
+
75
+ value.requires_grad = grad_value
76
+ sampling_locations.requires_grad = grad_sampling_loc
77
+ attention_weights.requires_grad = grad_attn_weight
78
+
79
+ gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step))
80
+
81
+ print(f'* {gradok} check_gradient_numerical(D={channels})')
82
+
83
+
84
+ if __name__ == '__main__':
85
+ check_forward_equal_with_pytorch_double()
86
+ check_forward_equal_with_pytorch_float()
87
+
88
+ for channels in [30, 32, 64, 71, 1025, 2048, 3096]:
89
+ check_gradient_numerical(channels, True, True, True)
90
+
91
+
92
+
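
test.py verifies the forward pass against the pure-PyTorch reference and runs gradcheck on the CUDA op. A hedged companion check that compares the backward pass of both paths directly could look like the sketch below (same imports as test.py; the tolerances are illustrative, not prescribed by the repo).

import torch
from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch


def check_backward_equal_with_pytorch_double(N=1, M=2, D=8, Lq=2, L=2, P=2):
    shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
    level_start_index = torch.cat((shapes.new_zeros((1,)), shapes.prod(1).cumsum(0)[:-1]))
    S = int(sum(H * W for H, W in shapes))

    value = (torch.rand(N, S, M, D, dtype=torch.double).cuda() * 0.01).requires_grad_(True)
    loc = torch.rand(N, Lq, M, L, P, 2, dtype=torch.double).cuda().requires_grad_(True)
    w = torch.rand(N, Lq, M, L, P, dtype=torch.double).cuda() + 1e-5
    w = (w / w.sum(-1, keepdim=True).sum(-2, keepdim=True)).detach().requires_grad_(True)

    # Gradients through the pure-PyTorch reference.
    ms_deform_attn_core_pytorch(value, shapes, loc, w).sum().backward()
    ref_grads = [t.grad.clone() for t in (value, loc, w)]
    for t in (value, loc, w):
        t.grad = None

    # Gradients through the CUDA op.
    MSDeformAttnFunction.apply(value, shapes, level_start_index, loc, w, 2).sum().backward()
    for name, t, g in zip(("value", "sampling_loc", "attn_weight"), (value, loc, w), ref_grads):
        print(name, torch.allclose(t.grad, g, rtol=1e-4, atol=1e-6))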
annotator/entityseg/mask2former/modeling/transformer_decoder/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ from .maskformer_transformer_decoder import StandardTransformerDecoder
3
+ from .mask2former_transformer_decoder import MultiScaleMaskedTransformerDecoder
4
+ from .cropformer_transformer_decoder import CropSharedMultiScaleMaskedTransformerDecoder
5
+
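
Importing this package runs the @TRANSFORMER_DECODER_REGISTRY.register() decorators of all three decoder classes. A hedged sketch of looking one up by name through detectron2's Registry API follows; it assumes the deformable-attention op has already been built, since the package import pulls in the pixel decoder, and the config-driven builder in maskformer_transformer_decoder.py remains the intended entry point.

# Hedged sketch: fetch a registered decoder class by name after import-time registration.
from annotator.entityseg.mask2former.modeling.transformer_decoder.maskformer_transformer_decoder import (
    TRANSFORMER_DECODER_REGISTRY,
)

decoder_cls = TRANSFORMER_DECODER_REGISTRY.get("CropSharedMultiScaleMaskedTransformerDecoder")
print(decoder_cls.__name__)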
annotator/entityseg/mask2former/modeling/transformer_decoder/cropformer_transformer_decoder.py ADDED
@@ -0,0 +1,595 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py
3
+ import logging
4
+ import fvcore.nn.weight_init as weight_init
5
+ from typing import Optional
6
+ import torch
7
+ from torch import nn, Tensor
8
+ from torch.nn import functional as F
9
+
10
+ from detectron2.config import configurable
11
+ from detectron2.layers import Conv2d
12
+
13
+ from .position_encoding import PositionEmbeddingSine3D2D
14
+ from .maskformer_transformer_decoder import TRANSFORMER_DECODER_REGISTRY
15
+
16
+ import pdb
17
+
18
+ class SelfAttentionLayer(nn.Module):
19
+
20
+ def __init__(self, d_model, nhead, dropout=0.0,
21
+ activation="relu", normalize_before=False):
22
+ super().__init__()
23
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
24
+
25
+ self.norm = nn.LayerNorm(d_model)
26
+ self.dropout = nn.Dropout(dropout)
27
+
28
+ self.activation = _get_activation_fn(activation)
29
+ self.normalize_before = normalize_before
30
+
31
+ self._reset_parameters()
32
+
33
+ def _reset_parameters(self):
34
+ for p in self.parameters():
35
+ if p.dim() > 1:
36
+ nn.init.xavier_uniform_(p)
37
+
38
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
39
+ return tensor if pos is None else tensor + pos
40
+
41
+ def forward_post(self, tgt,
42
+ tgt_mask: Optional[Tensor] = None,
43
+ tgt_key_padding_mask: Optional[Tensor] = None,
44
+ query_pos: Optional[Tensor] = None):
45
+ q = k = self.with_pos_embed(tgt, query_pos)
46
+ tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
47
+ key_padding_mask=tgt_key_padding_mask)[0]
48
+ tgt = tgt + self.dropout(tgt2)
49
+ tgt = self.norm(tgt)
50
+
51
+ return tgt
52
+
53
+ def forward_pre(self, tgt,
54
+ tgt_mask: Optional[Tensor] = None,
55
+ tgt_key_padding_mask: Optional[Tensor] = None,
56
+ query_pos: Optional[Tensor] = None):
57
+ tgt2 = self.norm(tgt)
58
+ q = k = self.with_pos_embed(tgt2, query_pos)
59
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
60
+ key_padding_mask=tgt_key_padding_mask)[0]
61
+ tgt = tgt + self.dropout(tgt2)
62
+
63
+ return tgt
64
+
65
+ def forward(self, tgt,
66
+ tgt_mask: Optional[Tensor] = None,
67
+ tgt_key_padding_mask: Optional[Tensor] = None,
68
+ query_pos: Optional[Tensor] = None):
69
+ if self.normalize_before:
70
+ return self.forward_pre(tgt, tgt_mask,
71
+ tgt_key_padding_mask, query_pos)
72
+ return self.forward_post(tgt, tgt_mask,
73
+ tgt_key_padding_mask, query_pos)
74
+
75
+
76
+ class CrossAttentionLayer(nn.Module):
77
+
78
+ def __init__(self, d_model, nhead, dropout=0.0,
79
+ activation="relu", normalize_before=False):
80
+ super().__init__()
81
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
82
+
83
+ self.norm = nn.LayerNorm(d_model)
84
+ self.dropout = nn.Dropout(dropout)
85
+
86
+ self.activation = _get_activation_fn(activation)
87
+ self.normalize_before = normalize_before
88
+
89
+ self._reset_parameters()
90
+
91
+ def _reset_parameters(self):
92
+ for p in self.parameters():
93
+ if p.dim() > 1:
94
+ nn.init.xavier_uniform_(p)
95
+
96
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
97
+ return tensor if pos is None else tensor + pos
98
+
99
+ def forward_post(self, tgt, memory,
100
+ memory_mask: Optional[Tensor] = None,
101
+ memory_key_padding_mask: Optional[Tensor] = None,
102
+ pos: Optional[Tensor] = None,
103
+ query_pos: Optional[Tensor] = None):
104
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
105
+ key=self.with_pos_embed(memory, pos),
106
+ value=memory, attn_mask=memory_mask,
107
+ key_padding_mask=memory_key_padding_mask)[0]
108
+ tgt = tgt + self.dropout(tgt2)
109
+ tgt = self.norm(tgt)
110
+
111
+ return tgt
112
+
113
+ def forward_pre(self, tgt, memory,
114
+ memory_mask: Optional[Tensor] = None,
115
+ memory_key_padding_mask: Optional[Tensor] = None,
116
+ pos: Optional[Tensor] = None,
117
+ query_pos: Optional[Tensor] = None):
118
+ tgt2 = self.norm(tgt)
119
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
120
+ key=self.with_pos_embed(memory, pos),
121
+ value=memory, attn_mask=memory_mask,
122
+ key_padding_mask=memory_key_padding_mask)[0]
123
+ tgt = tgt + self.dropout(tgt2)
124
+
125
+ return tgt
126
+
127
+ def forward(self, tgt, memory,
128
+ memory_mask: Optional[Tensor] = None,
129
+ memory_key_padding_mask: Optional[Tensor] = None,
130
+ pos: Optional[Tensor] = None,
131
+ query_pos: Optional[Tensor] = None):
132
+ if self.normalize_before:
133
+ return self.forward_pre(tgt, memory, memory_mask,
134
+ memory_key_padding_mask, pos, query_pos)
135
+ return self.forward_post(tgt, memory, memory_mask,
136
+ memory_key_padding_mask, pos, query_pos)
137
+
138
+
139
+ class FFNLayer(nn.Module):
140
+
141
+ def __init__(self, d_model, dim_feedforward=2048, dropout=0.0,
142
+ activation="relu", normalize_before=False):
143
+ super().__init__()
144
+ # Implementation of Feedforward model
145
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
146
+ self.dropout = nn.Dropout(dropout)
147
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
148
+
149
+ self.norm = nn.LayerNorm(d_model)
150
+
151
+ self.activation = _get_activation_fn(activation)
152
+ self.normalize_before = normalize_before
153
+
154
+ self._reset_parameters()
155
+
156
+ def _reset_parameters(self):
157
+ for p in self.parameters():
158
+ if p.dim() > 1:
159
+ nn.init.xavier_uniform_(p)
160
+
161
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
162
+ return tensor if pos is None else tensor + pos
163
+
164
+ def forward_post(self, tgt):
165
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
166
+ tgt = tgt + self.dropout(tgt2)
167
+ tgt = self.norm(tgt)
168
+ return tgt
169
+
170
+ def forward_pre(self, tgt):
171
+ tgt2 = self.norm(tgt)
172
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
173
+ tgt = tgt + self.dropout(tgt2)
174
+ return tgt
175
+
176
+ def forward(self, tgt):
177
+ if self.normalize_before:
178
+ return self.forward_pre(tgt)
179
+ return self.forward_post(tgt)
180
+
181
+ def _get_activation_fn(activation):
182
+ """Return an activation function given a string"""
183
+ if activation == "relu":
184
+ return F.relu
185
+ if activation == "gelu":
186
+ return F.gelu
187
+ if activation == "glu":
188
+ return F.glu
189
+ raise RuntimeError(F"activation should be relu/gelu/glu, not {activation}.")
190
+
191
+
192
+ class MLP(nn.Module):
193
+ """ Very simple multi-layer perceptron (also called FFN)"""
194
+
195
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
196
+ super().__init__()
197
+ self.num_layers = num_layers
198
+ h = [hidden_dim] * (num_layers - 1)
199
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
200
+
201
+ def forward(self, x):
202
+ for i, layer in enumerate(self.layers):
203
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
204
+ return x
205
+
206
+ class Make3dQueries(nn.Module):
207
+ _version = 2
208
+ def __init__(self, cfg):
209
+ super().__init__()
210
+ self.cfg = cfg
211
+ self.enc_crosattn_3d = nn.ModuleList()
212
+ self.enc_selfattn_3d = nn.ModuleList()
213
+ self.enc_ffn_3d = nn.ModuleList()
214
+ self.num_layers_3d = cfg.ENTITY.FUSE_NUM_LAYERS
215
+ for _ in range(self.num_layers_3d):
216
+ self.enc_crosattn_3d.append(
217
+ CrossAttentionLayer(
218
+ d_model=cfg.ENTITY.FUSE_ENC_HIDDIEN_DIM,
219
+ nhead=cfg.ENTITY.FUSE_ENC_NHEADS,
220
+ dropout=0.0,
221
+ normalize_before=cfg.ENTITY.FUSE_ENC_PRE_NORM)
222
+ )
223
+ self.enc_selfattn_3d.append(
224
+ SelfAttentionLayer(
225
+ d_model=cfg.ENTITY.FUSE_ENC_HIDDIEN_DIM,
226
+ nhead=cfg.ENTITY.FUSE_ENC_NHEADS,
227
+ dropout=0.0,
228
+ normalize_before=cfg.ENTITY.FUSE_ENC_PRE_NORM)
229
+ )
230
+ self.enc_ffn_3d.append(
231
+ FFNLayer(
232
+ d_model=cfg.ENTITY.FUSE_ENC_HIDDIEN_DIM,
233
+ dim_feedforward=cfg.ENTITY.FUSE_ENC_DIM_FEEDFORWARD,
234
+ dropout=0.0,
235
+ normalize_before=cfg.ENTITY.FUSE_ENC_PRE_NORM,
236
+ )
237
+ )
238
+
239
+ def forward(self, output_2d, query_embed_2d, query_embed_3d):
240
+ Q, BT, C = query_embed_2d.shape
241
+ Q, B, C = query_embed_3d.shape
242
+ T = int(BT / B)
243
+
244
+ output_3d = output_2d[:,0::T,:]
245
+ ### (Q, B, T, C)
246
+ output_2d = output_2d.unflatten(1, (B, T)).permute((0,2,1,3)).flatten(0,1)
247
+ query_embed_2d = query_embed_2d.unflatten(1, (B, T)).permute((0,2,1,3)).flatten(0,1)
248
+
249
+ for i in range(self.num_layers_3d):
250
+ output_3d = self.enc_crosattn_3d[i](output_3d, output_2d, pos=query_embed_2d, query_pos=query_embed_3d)
251
+ output_3d = self.enc_selfattn_3d[i](output_3d)
252
+ output_3d = self.enc_ffn_3d[i](output_3d)
253
+
254
+ return output_3d
255
+
256
+
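# Note on Make3dQueries.forward (illustrative, with made-up sizes): the 2D
# queries arrive as (Q, B*T, C) with the T crop views of each sample stored
# contiguously, so column 0::T is the full-image view that seeds output_3d,
# while unflatten(1, (B, T)) regroups the remaining views per sample. Toy check:
#
#   import torch
#   Q, B, T, C = 4, 2, 3, 8
#   output_2d = torch.arange(Q * B * T * C, dtype=torch.float32).view(Q, B * T, C)
#   per_view = output_2d.unflatten(1, (B, T)).permute(0, 2, 1, 3)  # (Q, T, B, C)
#   assert torch.equal(output_2d[:, 0::T, :], per_view[:, 0])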
257
+ @TRANSFORMER_DECODER_REGISTRY.register()
258
+ class CropSharedMultiScaleMaskedTransformerDecoder(nn.Module):
259
+ _version = 2
260
+
261
+ def _load_from_state_dict(
262
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
263
+ ):
264
+ version = local_metadata.get("version", None)
265
+ if version is None or version < 2:
266
+ # Do not warn if train from scratch
267
+ scratch = True
268
+ logger = logging.getLogger(__name__)
269
+ for k in list(state_dict.keys()):
270
+ newk = k
271
+ if "static_query" in k:
272
+ newk = k.replace("static_query", "query_feat")
273
+ if newk != k:
274
+ state_dict[newk] = state_dict[k]
275
+ del state_dict[k]
276
+ scratch = False
277
+
278
+ if not scratch:
279
+ logger.warning(
280
+ f"Weight format of {self.__class__.__name__} has changed! "
281
+ "Please upgrade your models. Applying automatic conversion now ..."
282
+ )
283
+
284
+ @configurable
285
+ def __init__(
286
+ self,
287
+ cfg,
288
+ in_channels,
289
+ mask_classification=True,
290
+ *,
291
+ num_classes: int,
292
+ hidden_dim: int,
293
+ num_queries: int,
294
+ nheads: int,
295
+ dim_feedforward: int,
296
+ dec_layers: int,
297
+ pre_norm: bool,
298
+ mask_dim: int,
299
+ enforce_input_project: bool,
300
+ ):
301
+ """
302
+ NOTE: this interface is experimental.
303
+ Args:
304
+ in_channels: channels of the input features
305
+ mask_classification: whether to add mask classifier or not
306
+ num_classes: number of classes
307
+ hidden_dim: Transformer feature dimension
308
+ num_queries: number of queries
309
+ nheads: number of heads
310
+ dim_feedforward: feature dimension in feedforward network
311
+ enc_layers: number of Transformer encoder layers
312
+ dec_layers: number of Transformer decoder layers
313
+ pre_norm: whether to use pre-LayerNorm or not
314
+ mask_dim: mask feature dimension
315
+ enforce_input_project: add input project 1x1 conv even if input
316
+ channels and hidden dim is identical
317
+ """
318
+ super().__init__()
319
+
320
+ assert mask_classification, "Only support mask classification model"
321
+ self.cfg = cfg
322
+
323
+ self.mask_classification = mask_classification
324
+ # positional encoding
325
+ N_steps = hidden_dim // 2
326
+
327
+ self.pe_layer = PositionEmbeddingSine3D2D(N_steps, normalize=True)
328
+
329
+ # define Transformer decoder here
330
+ self.num_heads = nheads
331
+ self.num_layers = dec_layers
332
+ self.transformer_self_attention_layers = nn.ModuleList()
333
+ self.transformer_cross_attention_layers = nn.ModuleList()
334
+ self.transformer_ffn_layers = nn.ModuleList()
335
+
336
+ for _ in range(self.num_layers):
337
+ self.transformer_self_attention_layers.append(
338
+ SelfAttentionLayer(
339
+ d_model=hidden_dim,
340
+ nhead=nheads,
341
+ dropout=0.0,
342
+ normalize_before=pre_norm,
343
+ )
344
+ )
345
+
346
+ self.transformer_cross_attention_layers.append(
347
+ CrossAttentionLayer(
348
+ d_model=hidden_dim,
349
+ nhead=nheads,
350
+ dropout=0.0,
351
+ normalize_before=pre_norm,
352
+ )
353
+ )
354
+
355
+ self.transformer_ffn_layers.append(
356
+ FFNLayer(
357
+ d_model=hidden_dim,
358
+ dim_feedforward=dim_feedforward,
359
+ dropout=0.0,
360
+ normalize_before=pre_norm,
361
+ )
362
+ )
363
+
364
+ self.make_3d = Make3dQueries(cfg)
365
+ self.decoder_norm = nn.LayerNorm(hidden_dim)
366
+
367
+ self.num_queries = num_queries
368
+ # learnable query features
369
+ self.query_feat = nn.Embedding(num_queries, hidden_dim)
370
+ # learnable query p.e.
371
+ self.query_embed = nn.Embedding(num_queries, hidden_dim)
372
+
373
+ # level embedding (we always use 3 scales)
374
+ self.num_feature_levels = 3
375
+ self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
376
+ self.input_proj = nn.ModuleList()
377
+ for _ in range(self.num_feature_levels):
378
+ if in_channels != hidden_dim or enforce_input_project:
379
+ self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1))
380
+ weight_init.c2_xavier_fill(self.input_proj[-1])
382
+ else:
383
+ self.input_proj.append(nn.Sequential())
384
+
385
+ # output FFNs
386
+ if self.mask_classification:
387
+ self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
388
+ self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
389
+
390
+ @classmethod
391
+ def from_config(cls, cfg, in_channels, mask_classification):
392
+ ret = {}
393
+ ret["cfg"] = cfg
394
+ ret["in_channels"] = in_channels
395
+ ret["mask_classification"] = mask_classification
396
+
397
+ ret["num_classes"] = cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES
398
+ ret["hidden_dim"] = cfg.MODEL.MASK_FORMER.HIDDEN_DIM
399
+ ret["num_queries"] = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES
400
+ # Transformer parameters:
401
+ ret["nheads"] = cfg.MODEL.MASK_FORMER.NHEADS
402
+ ret["dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD
403
+
404
+ # NOTE: because we add learnable query features which requires supervision,
405
+ # we add minus 1 to decoder layers to be consistent with our loss
406
+ # implementation: that is, number of auxiliary losses is always
407
+ # equal to number of decoder layers. With learnable query features, the number of
408
+ # auxiliary losses equals number of decoders plus 1.
409
+ assert cfg.MODEL.MASK_FORMER.DEC_LAYERS >= 1
410
+ ret["dec_layers"] = cfg.MODEL.MASK_FORMER.DEC_LAYERS - 1
411
+ ret["pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM
412
+ ret["enforce_input_project"] = cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ
413
+
414
+ ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
415
+
416
+ return ret
417
+
418
+ def forward(self, x, mask_features, mask = None):
419
+ # x is a list of multi-scale feature
420
+ assert len(x) == self.num_feature_levels
421
+
422
+ bt, c_m, h_m, w_m = mask_features.shape
423
+ bs = bt // (self.cfg.ENTITY.CROP_SAMPLE_NUM_TRAIN+1) if self.training else 1
424
+ # bs = bt // self.num_views if self.training else 1
425
+ t_m = bt // bs
426
+ mask_features_2d = mask_features
427
+ mask_features_3d = mask_features.view(bs, t_m, c_m, h_m, w_m)
428
+
429
+ src_2d, src_3d = [], []
430
+ pos_2d, pos_3d = [], []
431
+ size_list = []
432
+
433
+ # disable mask, it does not affect performance
434
+ del mask
435
+
436
+ # pdb.set_trace()
437
+ for i in range(self.num_feature_levels):
438
+ size_list.append(x[i].shape[-2:])
439
+ pos_2d_, pos_3d_ = self.pe_layer(x[i].view(bs, t_m, -1, size_list[-1][0], size_list[-1][1]))
440
+
441
+ pos_3d.append(pos_3d_.flatten(3))
442
+ src_3d.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None])
443
+
444
+ pos_2d.append(pos_2d_.flatten(2))
445
+ src_2d.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None])
446
+
447
+ # NTxCxHW => NxTxCxHW => (TxHW)xNxC
448
+ _, c, hw = src_3d[-1].shape
449
+ pos_3d[-1] = pos_3d[-1].view(bs, t_m, c, hw).permute(1, 3, 0, 2).flatten(0, 1)
450
+ src_3d[-1] = src_3d[-1].view(bs, t_m, c, hw).permute(1, 3, 0, 2).flatten(0, 1)
451
+
452
+ pos_2d[-1] = pos_2d[-1].permute(2,0,1)
453
+ src_2d[-1] = src_2d[-1].permute(2,0,1)
454
+
455
+ # QxNxC
456
+ query_embed_2d = self.query_embed.weight.unsqueeze(1).repeat(1, bt, 1)
457
+ output_2d = self.query_feat.weight.unsqueeze(1).repeat(1, bt, 1)
458
+
459
+ predictions_class_2d = []
460
+ predictions_mask_2d = []
461
+
462
+ # prediction heads on learnable query features
463
+ outputs_class_2d, outputs_mask_2d, attn_mask_2d, embedding_2d = self.forward_prediction_heads(output_2d, mask_features_2d, output_type="2d", attn_mask_target_size=size_list[0])
464
+ predictions_class_2d.append(outputs_class_2d)
465
+ predictions_mask_2d.append(outputs_mask_2d)
466
+
467
+ # pdb.set_trace()
468
+ for i in range(self.num_layers):
469
+ level_index = i % self.num_feature_levels
470
+ attn_mask_2d[torch.where(attn_mask_2d.sum(-1) == attn_mask_2d.shape[-1])] = False
471
+ # attention: cross-attention first
472
+ output_2d = self.transformer_cross_attention_layers[i](
473
+ output_2d, src_2d[level_index],
474
+ memory_mask=attn_mask_2d,
475
+ memory_key_padding_mask=None, # here we do not apply masking on padded region
476
+ pos=pos_2d[level_index], query_pos=query_embed_2d
477
+ )
478
+
479
+ output_2d = self.transformer_self_attention_layers[i](
480
+ output_2d, tgt_mask=None,
481
+ tgt_key_padding_mask=None,
482
+ query_pos=query_embed_2d
483
+ )
484
+
485
+ # FFN
486
+ output_2d = self.transformer_ffn_layers[i](
487
+ output_2d
488
+ )
489
+
490
+ outputs_class_2d, outputs_mask_2d, attn_mask_2d, embedding_2d = self.forward_prediction_heads(output_2d, mask_features_2d, output_type="2d", attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels])
491
+ predictions_class_2d.append(outputs_class_2d)
492
+ predictions_mask_2d.append(outputs_mask_2d)
493
+
494
+ assert len(predictions_class_2d) == self.num_layers + 1
495
+
496
+ out_2d = {
497
+ 'pred_logits': predictions_class_2d[-1],
498
+ 'pred_masks': predictions_mask_2d[-1],
499
+ 'aux_outputs': self._set_aux_loss(
500
+ predictions_class_2d if self.mask_classification else None, predictions_mask_2d
501
+ )
502
+ }
503
+
504
+ predictions_class_3d = []
505
+ predictions_mask_3d = []
506
+
507
+ query_embed_3d = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
508
+
509
+ output_3d = self.make_3d(output_2d, query_embed_2d, query_embed_3d)
510
+
511
+ # self.fused
512
+ outputs_class_3d, outputs_mask_3d, attn_mask_3d, embedding_3d = self.forward_prediction_heads(output_3d, mask_features_3d, output_type="3d", attn_mask_target_size=size_list[0])
513
+ predictions_class_3d.append(outputs_class_3d)
514
+ predictions_mask_3d.append(outputs_mask_3d)
515
+
516
+ for i in range(self.num_layers):
517
+ level_index = i % self.num_feature_levels
518
+ attn_mask_3d[torch.where(attn_mask_3d.sum(-1) == attn_mask_3d.shape[-1])] = False
519
+ ################# 3d (unified) #############
520
+ # attention: cross-attention first
521
+ output_3d = self.transformer_cross_attention_layers[i](
522
+ output_3d, src_3d[level_index],
523
+ memory_mask=attn_mask_3d,
524
+ memory_key_padding_mask=None, # here we do not apply masking on padded region
525
+ pos=pos_3d[level_index], query_pos=query_embed_3d
526
+ )
527
+
528
+ output_3d = self.transformer_self_attention_layers[i](
529
+ output_3d, tgt_mask=None,
530
+ tgt_key_padding_mask=None,
531
+ query_pos=query_embed_3d
532
+ )
533
+
534
+ output_3d = self.transformer_ffn_layers[i](
535
+ output_3d
536
+ )
537
+
538
+ outputs_class_3d, outputs_mask_3d, attn_mask_3d, embedding_3d = self.forward_prediction_heads(output_3d, mask_features_3d, output_type="3d", attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels])
539
+ predictions_class_3d.append(outputs_class_3d)
540
+ predictions_mask_3d.append(outputs_mask_3d)
541
+
542
+ # assert len(predictions_class_3d) == self.num_layers + 1
543
+
544
+ out_3d = {
545
+ 'pred_logits': predictions_class_3d[-1],
546
+ 'pred_masks': predictions_mask_3d[-1],
547
+ 'aux_outputs': self._set_aux_loss(
548
+ predictions_class_3d if self.mask_classification else None, predictions_mask_3d
549
+ ),
550
+ }
551
+
552
+ return out_2d, out_3d
553
+
554
+ def forward_prediction_heads(self, output, mask_features, output_type, attn_mask_target_size):
555
+ decoder_output = self.decoder_norm(output)
556
+ decoder_output = decoder_output.transpose(0, 1)
557
+ outputs_class = self.class_embed(decoder_output)
558
+ mask_embed = self.mask_embed(decoder_output)
559
+ if output_type == "3d":
560
+ outputs_mask = torch.einsum("bqc,btchw->bqthw", mask_embed, mask_features)
561
+ b, q, t, _, _ = outputs_mask.shape
562
+ # NOTE: prediction is of higher-resolution
563
+ # [B, Q, T, H, W] -> [B, Q, T*H*W] -> [B, h, Q, T*H*W] -> [B*h, Q, T*HW]
564
+ attn_mask = F.interpolate(outputs_mask.flatten(0, 1), size=attn_mask_target_size, mode="bilinear", align_corners=False).view(
565
+ b, q, t, attn_mask_target_size[0], attn_mask_target_size[1])
566
+ # must use bool type
567
+ # If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged.
568
+ attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool()
569
+ attn_mask = attn_mask.detach()
570
+ elif output_type == "2d":
571
+ outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)
572
+ # NOTE: prediction is of higher-resolution
573
+ # [B, Q, H, W] -> [B, Q, H*W] -> [B, h, Q, H*W] -> [B*h, Q, HW]
574
+ attn_mask = F.interpolate(outputs_mask, size=attn_mask_target_size, mode="bilinear", align_corners=False)
575
+ # must use bool type
576
+ # If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged.
577
+ attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool()
578
+ attn_mask = attn_mask.detach()
579
+ else:
580
+ raise ValueError("output_type should be '2d' or '3d'")
581
+
582
+ return outputs_class, outputs_mask, attn_mask, decoder_output
583
+
584
+ @torch.jit.unused
585
+ def _set_aux_loss(self, outputs_class, outputs_seg_masks):
586
+ # this is a workaround to make torchscript happy, as torchscript
587
+ # doesn't support dictionary with non-homogeneous values, such
588
+ # as a dict having both a Tensor and a list.
589
+ if self.mask_classification:
590
+ return [
591
+ {"pred_logits": a, "pred_masks": b}
592
+ for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1])
593
+ ]
594
+ else:
595
+ return [{"pred_masks": b} for b in outputs_seg_masks[:-1]]
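
forward_prediction_heads above turns each query's predicted mask into the boolean attention mask consumed by the next cross-attention layer: locations whose sigmoid probability is below 0.5 are blocked (True), and the decoder loops re-enable any row that would otherwise block everything. A hedged toy illustration with made-up sizes:

# Toy illustration of the attention-mask rule used by forward_prediction_heads
# and the decoder loops above; shapes are (batch*heads, queries, locations).
import torch

logits = torch.randn(2 * 8, 5, 49)
attn_mask = logits.sigmoid() < 0.5                  # True = this location is NOT attended
fully_blocked = attn_mask.sum(-1) == attn_mask.shape[-1]
attn_mask[torch.where(fully_blocked)] = False       # same reset as in the decoder loops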
annotator/entityseg/mask2former/modeling/transformer_decoder/mask2former_transformer_decoder.py ADDED
@@ -0,0 +1,461 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py
3
+ import logging
4
+ import fvcore.nn.weight_init as weight_init
5
+ from typing import Optional
6
+ import torch
7
+ from torch import nn, Tensor
8
+ from torch.nn import functional as F
9
+
10
+ from detectron2.config import configurable
11
+ from detectron2.layers import Conv2d
12
+
13
+ from .position_encoding import PositionEmbeddingSine
14
+ from .maskformer_transformer_decoder import TRANSFORMER_DECODER_REGISTRY
15
+
16
+
17
+ class SelfAttentionLayer(nn.Module):
18
+
19
+ def __init__(self, d_model, nhead, dropout=0.0,
20
+ activation="relu", normalize_before=False):
21
+ super().__init__()
22
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
23
+
24
+ self.norm = nn.LayerNorm(d_model)
25
+ self.dropout = nn.Dropout(dropout)
26
+
27
+ self.activation = _get_activation_fn(activation)
28
+ self.normalize_before = normalize_before
29
+
30
+ self._reset_parameters()
31
+
32
+ def _reset_parameters(self):
33
+ for p in self.parameters():
34
+ if p.dim() > 1:
35
+ nn.init.xavier_uniform_(p)
36
+
37
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
38
+ return tensor if pos is None else tensor + pos
39
+
40
+ def forward_post(self, tgt,
41
+ tgt_mask: Optional[Tensor] = None,
42
+ tgt_key_padding_mask: Optional[Tensor] = None,
43
+ query_pos: Optional[Tensor] = None):
44
+ q = k = self.with_pos_embed(tgt, query_pos)
45
+ tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
46
+ key_padding_mask=tgt_key_padding_mask)[0]
47
+ tgt = tgt + self.dropout(tgt2)
48
+ tgt = self.norm(tgt)
49
+
50
+ return tgt
51
+
52
+ def forward_pre(self, tgt,
53
+ tgt_mask: Optional[Tensor] = None,
54
+ tgt_key_padding_mask: Optional[Tensor] = None,
55
+ query_pos: Optional[Tensor] = None):
56
+ tgt2 = self.norm(tgt)
57
+ q = k = self.with_pos_embed(tgt2, query_pos)
58
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
59
+ key_padding_mask=tgt_key_padding_mask)[0]
60
+ tgt = tgt + self.dropout(tgt2)
61
+
62
+ return tgt
63
+
64
+ def forward(self, tgt,
65
+ tgt_mask: Optional[Tensor] = None,
66
+ tgt_key_padding_mask: Optional[Tensor] = None,
67
+ query_pos: Optional[Tensor] = None):
68
+ if self.normalize_before:
69
+ return self.forward_pre(tgt, tgt_mask,
70
+ tgt_key_padding_mask, query_pos)
71
+ return self.forward_post(tgt, tgt_mask,
72
+ tgt_key_padding_mask, query_pos)
73
+
74
+
75
+ class CrossAttentionLayer(nn.Module):
76
+
77
+ def __init__(self, d_model, nhead, dropout=0.0,
78
+ activation="relu", normalize_before=False):
79
+ super().__init__()
80
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
81
+
82
+ self.norm = nn.LayerNorm(d_model)
83
+ self.dropout = nn.Dropout(dropout)
84
+
85
+ self.activation = _get_activation_fn(activation)
86
+ self.normalize_before = normalize_before
87
+
88
+ self._reset_parameters()
89
+
90
+ def _reset_parameters(self):
91
+ for p in self.parameters():
92
+ if p.dim() > 1:
93
+ nn.init.xavier_uniform_(p)
94
+
95
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
96
+ return tensor if pos is None else tensor + pos
97
+
98
+ def forward_post(self, tgt, memory,
99
+ memory_mask: Optional[Tensor] = None,
100
+ memory_key_padding_mask: Optional[Tensor] = None,
101
+ pos: Optional[Tensor] = None,
102
+ query_pos: Optional[Tensor] = None):
103
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
104
+ key=self.with_pos_embed(memory, pos),
105
+ value=memory, attn_mask=memory_mask,
106
+ key_padding_mask=memory_key_padding_mask)[0]
107
+ tgt = tgt + self.dropout(tgt2)
108
+ tgt = self.norm(tgt)
109
+
110
+ return tgt
111
+
112
+ def forward_pre(self, tgt, memory,
113
+ memory_mask: Optional[Tensor] = None,
114
+ memory_key_padding_mask: Optional[Tensor] = None,
115
+ pos: Optional[Tensor] = None,
116
+ query_pos: Optional[Tensor] = None):
117
+ tgt2 = self.norm(tgt)
118
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
119
+ key=self.with_pos_embed(memory, pos),
120
+ value=memory, attn_mask=memory_mask,
121
+ key_padding_mask=memory_key_padding_mask)[0]
122
+ tgt = tgt + self.dropout(tgt2)
123
+
124
+ return tgt
125
+
126
+ def forward(self, tgt, memory,
127
+ memory_mask: Optional[Tensor] = None,
128
+ memory_key_padding_mask: Optional[Tensor] = None,
129
+ pos: Optional[Tensor] = None,
130
+ query_pos: Optional[Tensor] = None):
131
+ if self.normalize_before:
132
+ return self.forward_pre(tgt, memory, memory_mask,
133
+ memory_key_padding_mask, pos, query_pos)
134
+ return self.forward_post(tgt, memory, memory_mask,
135
+ memory_key_padding_mask, pos, query_pos)
136
+
137
+
138
+ class FFNLayer(nn.Module):
139
+
140
+ def __init__(self, d_model, dim_feedforward=2048, dropout=0.0,
141
+ activation="relu", normalize_before=False):
142
+ super().__init__()
143
+ # Implementation of Feedforward model
144
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
145
+ self.dropout = nn.Dropout(dropout)
146
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
147
+
148
+ self.norm = nn.LayerNorm(d_model)
149
+
150
+ self.activation = _get_activation_fn(activation)
151
+ self.normalize_before = normalize_before
152
+
153
+ self._reset_parameters()
154
+
155
+ def _reset_parameters(self):
156
+ for p in self.parameters():
157
+ if p.dim() > 1:
158
+ nn.init.xavier_uniform_(p)
159
+
160
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
161
+ return tensor if pos is None else tensor + pos
162
+
163
+ def forward_post(self, tgt):
164
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
165
+ tgt = tgt + self.dropout(tgt2)
166
+ tgt = self.norm(tgt)
167
+ return tgt
168
+
169
+ def forward_pre(self, tgt):
170
+ tgt2 = self.norm(tgt)
171
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
172
+ tgt = tgt + self.dropout(tgt2)
173
+ return tgt
174
+
175
+ def forward(self, tgt):
176
+ if self.normalize_before:
177
+ return self.forward_pre(tgt)
178
+ return self.forward_post(tgt)
179
+
180
+
181
+ def _get_activation_fn(activation):
182
+ """Return an activation function given a string"""
183
+ if activation == "relu":
184
+ return F.relu
185
+ if activation == "gelu":
186
+ return F.gelu
187
+ if activation == "glu":
188
+ return F.glu
189
+ raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
190
+
191
+
192
+ class MLP(nn.Module):
193
+ """ Very simple multi-layer perceptron (also called FFN)"""
194
+
195
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
196
+ super().__init__()
197
+ self.num_layers = num_layers
198
+ h = [hidden_dim] * (num_layers - 1)
199
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
200
+
201
+ def forward(self, x):
202
+ for i, layer in enumerate(self.layers):
203
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
204
+ return x
205
+
206
+
207
+ @TRANSFORMER_DECODER_REGISTRY.register()
208
+ class MultiScaleMaskedTransformerDecoder(nn.Module):
209
+
210
+ _version = 2
211
+
212
+ def _load_from_state_dict(
213
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
214
+ ):
215
+ version = local_metadata.get("version", None)
216
+ if version is None or version < 2:
217
+ # Do not warn if training from scratch
218
+ scratch = True
219
+ logger = logging.getLogger(__name__)
220
+ for k in list(state_dict.keys()):
221
+ newk = k
222
+ if "static_query" in k:
223
+ newk = k.replace("static_query", "query_feat")
224
+ if newk != k:
225
+ state_dict[newk] = state_dict[k]
226
+ del state_dict[k]
227
+ scratch = False
228
+
229
+ if not scratch:
230
+ logger.warning(
231
+ f"Weight format of {self.__class__.__name__} have changed! "
232
+ "Please upgrade your models. Applying automatic conversion now ..."
233
+ )
234
+
235
+ @configurable
236
+ def __init__(
237
+ self,
238
+ in_channels,
239
+ mask_classification=True,
240
+ *,
241
+ num_classes: int,
242
+ hidden_dim: int,
243
+ num_queries: int,
244
+ nheads: int,
245
+ dim_feedforward: int,
246
+ dec_layers: int,
247
+ pre_norm: bool,
248
+ mask_dim: int,
249
+ enforce_input_project: bool,
250
+ ):
251
+ """
252
+ NOTE: this interface is experimental.
253
+ Args:
254
+ in_channels: channels of the input features
255
+ mask_classification: whether to add mask classifier or not
256
+ num_classes: number of classes
257
+ hidden_dim: Transformer feature dimension
258
+ num_queries: number of queries
259
+ nheads: number of heads
260
+ dim_feedforward: feature dimension in feedforward network
262
+ dec_layers: number of Transformer decoder layers
263
+ pre_norm: whether to use pre-LayerNorm or not
264
+ mask_dim: mask feature dimension
265
+ enforce_input_project: add input project 1x1 conv even if input
266
+ channels and hidden dim are identical
267
+ """
268
+ super().__init__()
269
+
270
+ assert mask_classification, "Only support mask classification model"
271
+ self.mask_classification = mask_classification
272
+
273
+ # positional encoding
274
+ N_steps = hidden_dim // 2
275
+ self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
276
+
277
+ # define Transformer decoder here
278
+ self.num_heads = nheads
279
+ self.num_layers = dec_layers
280
+ self.transformer_self_attention_layers = nn.ModuleList()
281
+ self.transformer_cross_attention_layers = nn.ModuleList()
282
+ self.transformer_ffn_layers = nn.ModuleList()
283
+
284
+ for _ in range(self.num_layers):
285
+ self.transformer_self_attention_layers.append(
286
+ SelfAttentionLayer(
287
+ d_model=hidden_dim,
288
+ nhead=nheads,
289
+ dropout=0.0,
290
+ normalize_before=pre_norm,
291
+ )
292
+ )
293
+
294
+ self.transformer_cross_attention_layers.append(
295
+ CrossAttentionLayer(
296
+ d_model=hidden_dim,
297
+ nhead=nheads,
298
+ dropout=0.0,
299
+ normalize_before=pre_norm,
300
+ )
301
+ )
302
+
303
+ self.transformer_ffn_layers.append(
304
+ FFNLayer(
305
+ d_model=hidden_dim,
306
+ dim_feedforward=dim_feedforward,
307
+ dropout=0.0,
308
+ normalize_before=pre_norm,
309
+ )
310
+ )
311
+
312
+ self.decoder_norm = nn.LayerNorm(hidden_dim)
313
+
314
+ self.num_queries = num_queries
315
+ # learnable query features
316
+ self.query_feat = nn.Embedding(num_queries, hidden_dim)
317
+ # learnable query p.e.
318
+ self.query_embed = nn.Embedding(num_queries, hidden_dim)
319
+
320
+ # level embedding (we always use 3 scales)
321
+ self.num_feature_levels = 3
322
+ self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
323
+ self.input_proj = nn.ModuleList()
324
+ for _ in range(self.num_feature_levels):
325
+ if in_channels != hidden_dim or enforce_input_project:
326
+ self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1))
327
+ weight_init.c2_xavier_fill(self.input_proj[-1])
328
+ else:
329
+ self.input_proj.append(nn.Sequential())
330
+
331
+ # output FFNs
332
+ if self.mask_classification:
333
+ self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
334
+ self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
335
+
336
+ @classmethod
337
+ def from_config(cls, cfg, in_channels, mask_classification):
338
+ ret = {}
339
+ ret["in_channels"] = in_channels
340
+ ret["mask_classification"] = mask_classification
341
+
342
+ ret["num_classes"] = cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES
343
+ ret["hidden_dim"] = cfg.MODEL.MASK_FORMER.HIDDEN_DIM
344
+ ret["num_queries"] = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES
345
+ # Transformer parameters:
346
+ ret["nheads"] = cfg.MODEL.MASK_FORMER.NHEADS
347
+ ret["dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD
348
+
349
+ # NOTE: the learnable query features are also supervised, so we use
+ # cfg.MODEL.MASK_FORMER.DEC_LAYERS - 1 decoder layers to stay consistent with the
+ # loss implementation: the number of auxiliary losses equals the number of decoder
+ # layers, and together with the prediction made directly on the learnable query
+ # features the total number of predictions equals DEC_LAYERS.
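+ # For example, with DEC_LAYERS = 10 in the config this builds 9 decoder layers and
+ # yields 10 sets of predictions (one on the learnable queries plus one after each
+ # layer): 9 auxiliary losses and 1 final loss.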
354
+ assert cfg.MODEL.MASK_FORMER.DEC_LAYERS >= 1
355
+ ret["dec_layers"] = cfg.MODEL.MASK_FORMER.DEC_LAYERS - 1
356
+ ret["pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM
357
+ ret["enforce_input_project"] = cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ
358
+
359
+ ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
360
+
361
+ return ret
362
+
363
+ def forward(self, x, mask_features, mask=None):
364
+ # x is a list of multi-scale feature
365
+ assert len(x) == self.num_feature_levels
366
+ src = []
367
+ pos = []
368
+ size_list = []
369
+
370
+ # disable mask, it does not affect performance
371
+ del mask
372
+
373
+ for i in range(self.num_feature_levels):
374
+ size_list.append(x[i].shape[-2:])
375
+ pos.append(self.pe_layer(x[i], None).flatten(2))
376
+ src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None])
377
+
378
+ # flatten NxCxHxW to HWxNxC
379
+ pos[-1] = pos[-1].permute(2, 0, 1)
380
+ src[-1] = src[-1].permute(2, 0, 1)
381
+
382
+ _, bs, _ = src[0].shape
383
+
384
+ # QxNxC
385
+ query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
386
+ output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)
387
+
388
+ predictions_class = []
389
+ predictions_mask = []
390
+
391
+ # prediction heads on learnable query features
392
+ outputs_class, outputs_mask, attn_mask = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[0])
393
+ predictions_class.append(outputs_class)
394
+ predictions_mask.append(outputs_mask)
395
+
396
+ for i in range(self.num_layers):
397
+ level_index = i % self.num_feature_levels
398
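+ # if a query's predicted mask would block every position at this resolution, unmask
+ # it entirely; an all-True row would leave the softmax with nothing to attend to and
+ # produce NaNs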
+ attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
399
+ # attention: cross-attention first
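+ # (Mask2Former runs masked cross-attention before self-attention so the queries
+ # gather image evidence before exchanging information with each other)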
400
+ output = self.transformer_cross_attention_layers[i](
401
+ output, src[level_index],
402
+ memory_mask=attn_mask,
403
+ memory_key_padding_mask=None, # here we do not apply masking on padded region
404
+ pos=pos[level_index], query_pos=query_embed
405
+ )
406
+
407
+ output = self.transformer_self_attention_layers[i](
408
+ output, tgt_mask=None,
409
+ tgt_key_padding_mask=None,
410
+ query_pos=query_embed
411
+ )
412
+
413
+ # FFN
414
+ output = self.transformer_ffn_layers[i](
415
+ output
416
+ )
417
+
418
+ outputs_class, outputs_mask, attn_mask = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels])
419
+ predictions_class.append(outputs_class)
420
+ predictions_mask.append(outputs_mask)
421
+
422
+ assert len(predictions_class) == self.num_layers + 1
423
+
424
+ out = {
425
+ 'pred_logits': predictions_class[-1],
426
+ 'pred_masks': predictions_mask[-1],
427
+ 'aux_outputs': self._set_aux_loss(
428
+ predictions_class if self.mask_classification else None, predictions_mask
429
+ )
430
+ }
431
+ return out
432
+
433
+ def forward_prediction_heads(self, output, mask_features, attn_mask_target_size):
434
+ decoder_output = self.decoder_norm(output)
435
+ decoder_output = decoder_output.transpose(0, 1)
436
+ outputs_class = self.class_embed(decoder_output)
437
+ mask_embed = self.mask_embed(decoder_output)
438
+ outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)
439
+
440
+ # NOTE: prediction is of higher-resolution
441
+ # [B, Q, H, W] -> [B, Q, H*W] -> [B, h, Q, H*W] -> [B*h, Q, HW]
442
+ attn_mask = F.interpolate(outputs_mask, size=attn_mask_target_size, mode="bilinear", align_corners=False)
443
+ # must use bool type
444
+ # If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged.
445
+ attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool()
446
+ attn_mask = attn_mask.detach()
447
+
448
+ return outputs_class, outputs_mask, attn_mask
449
+
450
+ @torch.jit.unused
451
+ def _set_aux_loss(self, outputs_class, outputs_seg_masks):
452
+ # this is a workaround to make torchscript happy, as torchscript
453
+ # doesn't support dictionary with non-homogeneous values, such
454
+ # as a dict having both a Tensor and a list.
455
+ if self.mask_classification:
456
+ return [
457
+ {"pred_logits": a, "pred_masks": b}
458
+ for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1])
459
+ ]
460
+ else:
461
+ return [{"pred_masks": b} for b in outputs_seg_masks[:-1]]
annotator/entityseg/mask2former/modeling/transformer_decoder/maskformer_transformer_decoder.py ADDED
@@ -0,0 +1,188 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py
3
+ import fvcore.nn.weight_init as weight_init
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+ from detectron2.config import configurable
9
+ from detectron2.layers import Conv2d
10
+ from detectron2.utils.registry import Registry
11
+
12
+ from .position_encoding import PositionEmbeddingSine
13
+ from .transformer import Transformer
14
+
15
+
16
+ TRANSFORMER_DECODER_REGISTRY = Registry("TRANSFORMER_MODULE")
17
+ TRANSFORMER_DECODER_REGISTRY.__doc__ = """
18
+ Registry for transformer module in MaskFormer.
19
+ """
20
+
21
+
22
+ def build_transformer_decoder(cfg, in_channels, mask_classification=True):
23
+ """
24
+ Build a transformer decoder from `cfg.MODEL.MASK_FORMER.TRANSFORMER_DECODER_NAME`.
25
+ """
26
+ name = cfg.MODEL.MASK_FORMER.TRANSFORMER_DECODER_NAME
27
+ return TRANSFORMER_DECODER_REGISTRY.get(name)(cfg, in_channels, mask_classification)
28
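+ # Illustrative: with cfg.MODEL.MASK_FORMER.TRANSFORMER_DECODER_NAME set to
+ # "MultiScaleMaskedTransformerDecoder" this returns the multi-scale decoder defined
+ # earlier, while "StandardTransformerDecoder" selects the class registered just below.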
+
29
+
30
+ @TRANSFORMER_DECODER_REGISTRY.register()
31
+ class StandardTransformerDecoder(nn.Module):
32
+ @configurable
33
+ def __init__(
34
+ self,
35
+ in_channels,
36
+ mask_classification=True,
37
+ *,
38
+ num_classes: int,
39
+ hidden_dim: int,
40
+ num_queries: int,
41
+ nheads: int,
42
+ dropout: float,
43
+ dim_feedforward: int,
44
+ enc_layers: int,
45
+ dec_layers: int,
46
+ pre_norm: bool,
47
+ deep_supervision: bool,
48
+ mask_dim: int,
49
+ enforce_input_project: bool,
50
+ ):
51
+ """
52
+ NOTE: this interface is experimental.
53
+ Args:
54
+ in_channels: channels of the input features
55
+ mask_classification: whether to add mask classifier or not
56
+ num_classes: number of classes
57
+ hidden_dim: Transformer feature dimension
58
+ num_queries: number of queries
59
+ nheads: number of heads
60
+ dropout: dropout in Transformer
61
+ dim_feedforward: feature dimension in feedforward network
62
+ enc_layers: number of Transformer encoder layers
63
+ dec_layers: number of Transformer decoder layers
64
+ pre_norm: whether to use pre-LayerNorm or not
65
+ deep_supervision: whether to add supervision to every decoder layer
66
+ mask_dim: mask feature dimension
67
+ enforce_input_project: add input project 1x1 conv even if input
68
+ channels and hidden dim are identical
69
+ """
70
+ super().__init__()
71
+
72
+ self.mask_classification = mask_classification
73
+
74
+ # positional encoding
75
+ N_steps = hidden_dim // 2
76
+ self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
77
+
78
+ transformer = Transformer(
79
+ d_model=hidden_dim,
80
+ dropout=dropout,
81
+ nhead=nheads,
82
+ dim_feedforward=dim_feedforward,
83
+ num_encoder_layers=enc_layers,
84
+ num_decoder_layers=dec_layers,
85
+ normalize_before=pre_norm,
86
+ return_intermediate_dec=deep_supervision,
87
+ )
88
+
89
+ self.num_queries = num_queries
90
+ self.transformer = transformer
91
+ hidden_dim = transformer.d_model
92
+
93
+ self.query_embed = nn.Embedding(num_queries, hidden_dim)
94
+
95
+ if in_channels != hidden_dim or enforce_input_project:
96
+ self.input_proj = Conv2d(in_channels, hidden_dim, kernel_size=1)
97
+ weight_init.c2_xavier_fill(self.input_proj)
98
+ else:
99
+ self.input_proj = nn.Sequential()
100
+ self.aux_loss = deep_supervision
101
+
102
+ # output FFNs
103
+ if self.mask_classification:
104
+ self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
105
+ self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
106
+
107
+ @classmethod
108
+ def from_config(cls, cfg, in_channels, mask_classification):
109
+ ret = {}
110
+ ret["in_channels"] = in_channels
111
+ ret["mask_classification"] = mask_classification
112
+
113
+ ret["num_classes"] = cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES
114
+ ret["hidden_dim"] = cfg.MODEL.MASK_FORMER.HIDDEN_DIM
115
+ ret["num_queries"] = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES
116
+ # Transformer parameters:
117
+ ret["nheads"] = cfg.MODEL.MASK_FORMER.NHEADS
118
+ ret["dropout"] = cfg.MODEL.MASK_FORMER.DROPOUT
119
+ ret["dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD
120
+ ret["enc_layers"] = cfg.MODEL.MASK_FORMER.ENC_LAYERS
121
+ ret["dec_layers"] = cfg.MODEL.MASK_FORMER.DEC_LAYERS
122
+ ret["pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM
123
+ ret["deep_supervision"] = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION
124
+ ret["enforce_input_project"] = cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ
125
+
126
+ ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
127
+
128
+ return ret
129
+
130
+ def forward(self, x, mask_features, mask=None):
131
+ if mask is not None:
132
+ mask = F.interpolate(mask[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
133
+ pos = self.pe_layer(x, mask)
134
+
135
+ src = x
136
+ hs, memory = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos)
137
+
138
+ if self.mask_classification:
139
+ outputs_class = self.class_embed(hs)
140
+ out = {"pred_logits": outputs_class[-1]}
141
+ else:
142
+ out = {}
143
+
144
+ if self.aux_loss:
145
+ # [l, bs, queries, embed]
146
+ mask_embed = self.mask_embed(hs)
147
+ outputs_seg_masks = torch.einsum("lbqc,bchw->lbqhw", mask_embed, mask_features)
148
+ out["pred_masks"] = outputs_seg_masks[-1]
149
+ out["aux_outputs"] = self._set_aux_loss(
150
+ outputs_class if self.mask_classification else None, outputs_seg_masks
151
+ )
152
+ else:
153
+ # FIXME h_boxes takes the last one computed, keep this in mind
154
+ # [bs, queries, embed]
155
+ mask_embed = self.mask_embed(hs[-1])
156
+ outputs_seg_masks = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)
157
+ out["pred_masks"] = outputs_seg_masks
158
+ return out
159
+
160
+ @torch.jit.unused
161
+ def _set_aux_loss(self, outputs_class, outputs_seg_masks):
162
+ # this is a workaround to make torchscript happy, as torchscript
163
+ # doesn't support dictionary with non-homogeneous values, such
164
+ # as a dict having both a Tensor and a list.
165
+ if self.mask_classification:
166
+ return [
167
+ {"pred_logits": a, "pred_masks": b}
168
+ for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1])
169
+ ]
170
+ else:
171
+ return [{"pred_masks": b} for b in outputs_seg_masks[:-1]]
172
+
173
+
174
+ class MLP(nn.Module):
175
+ """Very simple multi-layer perceptron (also called FFN)"""
176
+
177
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
178
+ super().__init__()
179
+ self.num_layers = num_layers
180
+ h = [hidden_dim] * (num_layers - 1)
181
+ self.layers = nn.ModuleList(
182
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
183
+ )
184
+
185
+ def forward(self, x):
186
+ for i, layer in enumerate(self.layers):
187
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
188
+ return x
annotator/entityseg/mask2former/modeling/transformer_decoder/position_encoding.py ADDED
@@ -0,0 +1,134 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py
3
+ """
4
+ Various positional encodings for the transformer.
5
+ """
6
+ import math
7
+
8
+ import torch
9
+ from torch import nn
10
+
11
+
12
+ class PositionEmbeddingSine(nn.Module):
13
+ """
14
+ This is a more standard version of the position embedding, very similar to the one
15
+ used by the Attention is all you need paper, generalized to work on images.
16
+ """
17
+
18
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
19
+ super().__init__()
20
+ self.num_pos_feats = num_pos_feats
21
+ self.temperature = temperature
22
+ self.normalize = normalize
23
+ if scale is not None and normalize is False:
24
+ raise ValueError("normalize should be True if scale is passed")
25
+ if scale is None:
26
+ scale = 2 * math.pi
27
+ self.scale = scale
28
+
29
+ def forward(self, x, mask=None):
30
+ if mask is None:
31
+ mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
32
+ not_mask = ~mask
33
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
34
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
35
+ if self.normalize:
36
+ eps = 1e-6
37
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
38
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
39
+
40
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
41
+ # dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
42
+ dim_t = self.temperature ** (2 * (torch.div(dim_t, 2, rounding_mode="trunc")) / self.num_pos_feats)
43
+
44
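+ # even channels receive sin(coord / temperature^(2i/num_pos_feats)) and odd channels
+ # the matching cos term, following the sinusoidal encoding of "Attention is all you need"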
+ pos_x = x_embed[:, :, :, None] / dim_t
45
+ pos_y = y_embed[:, :, :, None] / dim_t
46
+ pos_x = torch.stack(
47
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
48
+ ).flatten(3)
49
+ pos_y = torch.stack(
50
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
51
+ ).flatten(3)
52
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
53
+ return pos
54
+
55
+ def __repr__(self, _repr_indent=4):
56
+ head = "Positional encoding " + self.__class__.__name__
57
+ body = [
58
+ "num_pos_feats: {}".format(self.num_pos_feats),
59
+ "temperature: {}".format(self.temperature),
60
+ "normalize: {}".format(self.normalize),
61
+ "scale: {}".format(self.scale),
62
+ ]
63
+ # _repr_indent = 4
64
+ lines = [head] + [" " * _repr_indent + line for line in body]
65
+ return "\n".join(lines)
66
+
67
+ class PositionEmbeddingSine3D2D(nn.Module):
68
+ """
69
+ This is a more standard version of the position embedding, very similar to the one
70
+ used by the Attention is all you need paper, generalized to work on video inputs of shape (B, T, C, H, W).
71
+ """
72
+
73
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
74
+ super().__init__()
75
+ self.num_pos_feats = num_pos_feats
76
+ self.temperature = temperature
77
+ self.normalize = normalize
78
+ if scale is not None and normalize is False:
79
+ raise ValueError("normalize should be True if scale is passed")
80
+ if scale is None:
81
+ scale = 2 * math.pi
82
+ self.scale = scale
83
+
84
+ def forward(self, x, mask=None):
85
+ ## b, t, c, h, w
86
+ assert x.dim()==5, f"{x.shape} should be a 5-dimensional Tensor, got {x.dim()}-dimensional Tensor instead"
87
+ if mask is None:
88
+ mask = torch.zeros((x.size(0), x.size(1), x.size(3), x.size(4)), device=x.device, dtype=torch.bool)
89
+ not_mask = ~mask
90
+ z_embed = not_mask.cumsum(1, dtype=torch.float32)
91
+ y_embed = not_mask.cumsum(2, dtype=torch.float32)
92
+ x_embed = not_mask.cumsum(3, dtype=torch.float32)
93
+ if self.normalize:
94
+ eps = 1e-6
95
+ z_embed = z_embed / (z_embed[:, -1:, :, :] + eps) * self.scale
96
+ y_embed = y_embed / (y_embed[:, :, -1:, :] + eps) * self.scale
97
+ x_embed = x_embed / (x_embed[:, :, :, -1:] + eps) * self.scale
98
+
99
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
100
+ # dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
101
+ dim_t = self.temperature ** (2 * (torch.div(dim_t, 2, rounding_mode="trunc")) / self.num_pos_feats)
102
+
103
+ dim_t_z = torch.arange((self.num_pos_feats * 2), dtype=torch.float32, device=x.device)
104
+ # dim_t_z = self.temperature ** (2 * (dim_t_z // 2) / (self.num_pos_feats * 2))
105
+ dim_t_z = self.temperature ** (2 * (torch.div(dim_t_z, 2, rounding_mode="trunc")) / (self.num_pos_feats*2))
106
+
107
+ pos_x = x_embed[:, :, :, :, None] / dim_t
108
+ pos_y = y_embed[:, :, :, :, None] / dim_t
109
+ pos_z = z_embed[:, :, :, :, None] / dim_t_z
110
+
111
+ pos_x = torch.stack(
112
+ (pos_x[:, :, :, :, 0::2].sin(), pos_x[:, :, :, :, 1::2].cos()), dim=5
113
+ ).flatten(4)
114
+ pos_y = torch.stack(
115
+ (pos_y[:, :, :, :, 0::2].sin(), pos_y[:, :, :, :, 1::2].cos()), dim=5
116
+ ).flatten(4)
117
+ pos_z = torch.stack(
118
+ (pos_z[:, :, :, :, 0::2].sin(), pos_z[:, :, :, :, 1::2].cos()), dim=5
119
+ ).flatten(4)
120
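+ # pos2d: (B*T, 2*num_pos_feats, H, W) per-frame 2D encoding; pos3d: (B, T, 2*num_pos_feats, H, W)
+ # with the temporal (z) term added on top of the 2D encoding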
+ pos2d = torch.cat((pos_y, pos_x), dim=4).permute(0, 1, 4, 2, 3).flatten(0,1)
121
+ pos3d = (torch.cat((pos_y, pos_x), dim=4) + pos_z).permute(0, 1, 4, 2, 3)
122
+ return pos2d, pos3d
123
+
124
+ def __repr__(self, _repr_indent=4):
125
+ head = "Positional encoding " + self.__class__.__name__
126
+ body = [
127
+ "num_pos_feats: {}".format(self.num_pos_feats),
128
+ "temperature: {}".format(self.temperature),
129
+ "normalize: {}".format(self.normalize),
130
+ "scale: {}".format(self.scale),
131
+ ]
132
+ # _repr_indent = 4
133
+ lines = [head] + [" " * _repr_indent + line for line in body]
134
+ return "\n".join(lines)