PeiqingYang committed
Commit dcc8c59 · 1 Parent(s): 941f6aa
Files changed (46)
  1. .gitignore +8 -0
  2. LICENSE +42 -0
  3. README.md +5 -4
  4. hugging_face/app.py +967 -0
  5. hugging_face/matanyone_wrapper.py +73 -0
  6. hugging_face/tools/__init__.py +0 -0
  7. hugging_face/tools/base_segmenter.py +129 -0
  8. hugging_face/tools/download_util.py +109 -0
  9. hugging_face/tools/interact_tools.py +99 -0
  10. hugging_face/tools/mask_painter.py +288 -0
  11. hugging_face/tools/misc.py +131 -0
  12. hugging_face/tools/painter.py +215 -0
  13. matanyone/config/__init__.py +0 -0
  14. matanyone/config/eval_matanyone_config.yaml +47 -0
  15. matanyone/config/hydra/job_logging/custom-no-rank.yaml +22 -0
  16. matanyone/config/hydra/job_logging/custom.yaml +22 -0
  17. matanyone/config/model/base.yaml +58 -0
  18. matanyone/inference/__init__.py +0 -0
  19. matanyone/inference/image_feature_store.py +56 -0
  20. matanyone/inference/inference_core.py +407 -0
  21. matanyone/inference/kv_memory_store.py +348 -0
  22. matanyone/inference/memory_manager.py +457 -0
  23. matanyone/inference/object_info.py +24 -0
  24. matanyone/inference/object_manager.py +149 -0
  25. matanyone/inference/utils/__init__.py +0 -0
  26. matanyone/inference/utils/args_utils.py +30 -0
  27. matanyone/model/__init__.py +0 -0
  28. matanyone/model/aux_modules.py +93 -0
  29. matanyone/model/big_modules.py +358 -0
  30. matanyone/model/channel_attn.py +39 -0
  31. matanyone/model/group_modules.py +126 -0
  32. matanyone/model/matanyone.py +323 -0
  33. matanyone/model/modules.py +170 -0
  34. matanyone/model/transformer/__init__.py +0 -0
  35. matanyone/model/transformer/object_summarizer.py +89 -0
  36. matanyone/model/transformer/object_transformer.py +206 -0
  37. matanyone/model/transformer/positional_encoding.py +108 -0
  38. matanyone/model/transformer/transformer_layers.py +161 -0
  39. matanyone/model/utils/__init__.py +0 -0
  40. matanyone/model/utils/memory_utils.py +107 -0
  41. matanyone/model/utils/parameter_groups.py +72 -0
  42. matanyone/model/utils/resnet.py +179 -0
  43. matanyone/utils/__init__.py +0 -0
  44. matanyone/utils/get_default_model.py +23 -0
  45. matanyone/utils/tensor_utils.py +62 -0
  46. requirements.txt +35 -0
.gitignore ADDED
@@ -0,0 +1,8 @@
+ __pycache__/
+ .vscode/
+ .DS_Store
+ assets/
+ inputs/
+ test_sample/
+ results/
+ pretrained_models/
LICENSE ADDED
@@ -0,0 +1,42 @@
+ # S-Lab License 1.0
+
+ Copyright 2023 S-Lab
+
+ Redistribution and use for non-commercial purpose in source and
+ binary forms, with or without modification, are permitted provided
+ that the following conditions are met:
+
+ 1. Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ 2. Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in
+ the documentation and/or other materials provided with the
+ distribution.
+
+ 3. Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived
+ from this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+ In the event that redistribution and/or use for commercial purpose in
+ source or binary forms, with or without modification is required,
+ please contact the contributor(s) of the work.
+
+
+ ---
+ For inquiries about permission for commercial use, please consult our team:
+ Peiqing Yang ([email protected]),
+ Dr. Shangchen Zhou ([email protected]),
+ Prof. Chen Change Loy ([email protected])
README.md CHANGED
@@ -1,13 +1,14 @@
  ---
  title: MatAnyone
- emoji: 📈
- colorFrom: gray
- colorTo: pink
+ emoji: 🤡
+ colorFrom: red
+ colorTo: green
  sdk: gradio
  sdk_version: 5.16.0
- app_file: app.py
+ app_file: hugging_face/app.py
  pinned: false
  license: other
+ short_description: Gradio demo for MatAnyone
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
hugging_face/app.py ADDED
@@ -0,0 +1,967 @@
1
+ import sys
2
+ sys.path.append("../")
3
+ sys.path.append("../../")
4
+
5
+ import os
6
+ import json
7
+ import time
8
+ import psutil
9
+ import ffmpeg
10
+ import imageio
11
+ import argparse
12
+ from PIL import Image
13
+
14
+ import cv2
15
+ import torch
16
+ import numpy as np
17
+ import gradio as gr
18
+
19
+ from tools.painter import mask_painter
20
+ from tools.interact_tools import SamControler
21
+ from tools.misc import get_device
22
+ from tools.download_util import load_file_from_url
23
+
24
+ from matanyone_wrapper import matanyone
25
+ from matanyone.utils.get_default_model import get_matanyone_model
26
+ from matanyone.inference.inference_core import InferenceCore
27
+
28
+ def parse_augment():
29
+ parser = argparse.ArgumentParser()
30
+ parser.add_argument('--device', type=str, default=None)
31
+ parser.add_argument('--sam_model_type', type=str, default="vit_h")
32
+ parser.add_argument('--port', type=int, default=8000, help="only useful when running gradio applications")
33
+ parser.add_argument('--mask_save', default=False)
34
+ args = parser.parse_args()
35
+
36
+ if not args.device:
37
+ args.device = str(get_device())
38
+
39
+ return args
40
+
41
+ # SAM generator
42
+ class MaskGenerator():
43
+ def __init__(self, sam_checkpoint, args):
44
+ self.args = args
45
+ self.samcontroler = SamControler(sam_checkpoint, args.sam_model_type, args.device)
46
+
47
+ def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
48
+ mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask)
49
+ return mask, logit, painted_image
50
+
51
+ # convert points input to prompt state
52
+ def get_prompt(click_state, click_input):
53
+ inputs = json.loads(click_input)
54
+ points = click_state[0]
55
+ labels = click_state[1]
56
+ for click in inputs:
57
+ points.append(click[:2])
58
+ labels.append(click[2])
59
+ click_state[0] = points
60
+ click_state[1] = labels
61
+ prompt = {
62
+ "prompt_type":["click"],
63
+ "input_point":click_state[0],
64
+ "input_label":click_state[1],
65
+ "multimask_output":"True",
66
+ }
67
+ return prompt
68
+
69
+ def get_frames_from_image(image_input, image_state):
70
+ """
71
+ Args:
72
+ video_path:str
73
+ timestamp:float64
74
+ Return
75
+ [[0:nearest_frame], [nearest_frame:], nearest_frame]
76
+ """
77
+
78
+ user_name = time.time()
79
+ frames = [image_input] * 2 # hardcode: mimic a video with 2 frames
80
+ image_size = (frames[0].shape[0],frames[0].shape[1])
81
+ # initialize video_state
82
+ image_state = {
83
+ "user_name": user_name,
84
+ "image_name": "output.png",
85
+ "origin_images": frames,
86
+ "painted_images": frames.copy(),
87
+ "masks": [np.zeros((frames[0].shape[0],frames[0].shape[1]), np.uint8)]*len(frames),
88
+ "logits": [None]*len(frames),
89
+ "select_frame_number": 0,
90
+ "fps": None
91
+ }
92
+ image_info = "Image Name: N/A,\nFPS: N/A,\nTotal Frames: {},\nImage Size:{}".format(len(frames), image_size)
93
+ model.samcontroler.sam_controler.reset_image()
94
+ model.samcontroler.sam_controler.set_image(image_state["origin_images"][0])
95
+ return image_state, image_info, image_state["origin_images"][0], \
96
+ gr.update(visible=True, maximum=10, value=10), gr.update(visible=False, maximum=len(frames), value=len(frames)), \
97
+ gr.update(visible=True), gr.update(visible=True), \
98
+ gr.update(visible=True), gr.update(visible=True),\
99
+ gr.update(visible=True), gr.update(visible=True), \
100
+ gr.update(visible=True), gr.update(visible=True), \
101
+ gr.update(visible=True), gr.update(visible=True), \
102
+ gr.update(visible=True), gr.update(visible=True, value=[]), \
103
+ gr.update(visible=True)
104
+
105
+ # extract frames from upload video
106
+ def get_frames_from_video(video_input, video_state):
107
+ """
108
+ Args:
109
+ video_path:str
110
+ timestamp:float64
111
+ Return
112
+ [[0:nearest_frame], [nearest_frame:], nearest_frame]
113
+ """
114
+ video_path = video_input
115
+ frames = []
116
+ user_name = time.time()
117
+
118
+ # extract Audio
119
+ audio_path = "audio.wav"
120
+ audio_path = video_input.replace(".mp4", "_audio.wav")
121
+ try:
122
+ ffmpeg.input(video_path).output(audio_path, format='wav', acodec='pcm_s16le', ac=2, ar='44100').run(overwrite_output=True, quiet=True)
123
+ except Exception as e:
124
+ print(f"Audio extraction error: {str(e)}")
125
+ audio_path = "" # Set to "" if extraction fails
126
+ # print(f'audio_path: {audio_path}')
127
+
128
+ # extract frames
129
+ try:
130
+ cap = cv2.VideoCapture(video_path)
131
+ fps = cap.get(cv2.CAP_PROP_FPS)
132
+ while cap.isOpened():
133
+ ret, frame = cap.read()
134
+ if ret:
135
+ current_memory_usage = psutil.virtual_memory().percent
136
+ frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
137
+ if current_memory_usage > 90:
138
+ break
139
+ else:
140
+ break
141
+ except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e:
142
+ print("read_frame_source:{} error. {}\n".format(video_path, str(e)))
143
+ image_size = (frames[0].shape[0],frames[0].shape[1])
144
+
145
+ # initialize video_state
146
+ video_state = {
147
+ "user_name": user_name,
148
+ "video_name": os.path.split(video_path)[-1],
149
+ "origin_images": frames,
150
+ "painted_images": frames.copy(),
151
+ "masks": [np.zeros((frames[0].shape[0],frames[0].shape[1]), np.uint8)]*len(frames),
152
+ "logits": [None]*len(frames),
153
+ "select_frame_number": 0,
154
+ "fps": fps,
155
+ "audio": audio_path
156
+ }
157
+ video_info = "Video Name: {},\nFPS: {},\nTotal Frames: {},\nImage Size:{}".format(video_state["video_name"], round(video_state["fps"], 0), len(frames), image_size)
158
+ model.samcontroler.sam_controler.reset_image()
159
+ model.samcontroler.sam_controler.set_image(video_state["origin_images"][0])
160
+ return video_state, video_info, video_state["origin_images"][0], gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=False, maximum=len(frames), value=len(frames)), \
161
+ gr.update(visible=True), gr.update(visible=True), \
162
+ gr.update(visible=True), gr.update(visible=True),\
163
+ gr.update(visible=True), gr.update(visible=True), \
164
+ gr.update(visible=True), gr.update(visible=False), \
165
+ gr.update(visible=False), gr.update(visible=True), \
166
+ gr.update(visible=True)
167
+
168
+ # get the select frame from gradio slider
169
+ def select_video_template(image_selection_slider, video_state, interactive_state):
170
+
171
+ image_selection_slider -= 1
172
+ video_state["select_frame_number"] = image_selection_slider
173
+
174
+ # once a new template frame is selected, set it as the image in SAM
175
+ model.samcontroler.sam_controler.reset_image()
176
+ model.samcontroler.sam_controler.set_image(video_state["origin_images"][image_selection_slider])
177
+
178
+ return video_state["painted_images"][image_selection_slider], video_state, interactive_state
179
+
180
+ def select_image_template(image_selection_slider, video_state, interactive_state):
181
+
182
+ image_selection_slider = 0 # fixed for image
183
+ video_state["select_frame_number"] = image_selection_slider
184
+
185
+ # once a new template frame is selected, set it as the image in SAM
186
+ model.samcontroler.sam_controler.reset_image()
187
+ model.samcontroler.sam_controler.set_image(video_state["origin_images"][image_selection_slider])
188
+
189
+ return video_state["painted_images"][image_selection_slider], video_state, interactive_state
190
+
191
+ # set the tracking end frame
192
+ def get_end_number(track_pause_number_slider, video_state, interactive_state):
193
+ interactive_state["track_end_number"] = track_pause_number_slider
194
+
195
+ return video_state["painted_images"][track_pause_number_slider],interactive_state
196
+
197
+ # use sam to get the mask
198
+ def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr.SelectData):
199
+ """
200
+ Args:
201
+ template_frame: PIL.Image
202
+ point_prompt: flag for positive or negative button click
203
+ click_state: [[points], [labels]]
204
+ """
205
+ if point_prompt == "Positive":
206
+ coordinate = "[[{},{},1]]".format(evt.index[0], evt.index[1])
207
+ interactive_state["positive_click_times"] += 1
208
+ else:
209
+ coordinate = "[[{},{},0]]".format(evt.index[0], evt.index[1])
210
+ interactive_state["negative_click_times"] += 1
211
+
212
+ # prompt for sam model
213
+ model.samcontroler.sam_controler.reset_image()
214
+ model.samcontroler.sam_controler.set_image(video_state["origin_images"][video_state["select_frame_number"]])
215
+ prompt = get_prompt(click_state=click_state, click_input=coordinate)
216
+
217
+ mask, logit, painted_image = model.first_frame_click(
218
+ image=video_state["origin_images"][video_state["select_frame_number"]],
219
+ points=np.array(prompt["input_point"]),
220
+ labels=np.array(prompt["input_label"]),
221
+ multimask=prompt["multimask_output"],
222
+ )
223
+ video_state["masks"][video_state["select_frame_number"]] = mask
224
+ video_state["logits"][video_state["select_frame_number"]] = logit
225
+ video_state["painted_images"][video_state["select_frame_number"]] = painted_image
226
+
227
+ return painted_image, video_state, interactive_state
228
+
229
+ def add_multi_mask(video_state, interactive_state, mask_dropdown):
230
+ mask = video_state["masks"][video_state["select_frame_number"]]
231
+ interactive_state["multi_mask"]["masks"].append(mask)
232
+ interactive_state["multi_mask"]["mask_names"].append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
233
+ mask_dropdown.append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
234
+ select_frame = show_mask(video_state, interactive_state, mask_dropdown)
235
+
236
+ return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown), select_frame, [[],[]]
237
+
238
+ def clear_click(video_state, click_state):
239
+ click_state = [[],[]]
240
+ template_frame = video_state["origin_images"][video_state["select_frame_number"]]
241
+ return template_frame, click_state
242
+
243
+ def remove_multi_mask(interactive_state, mask_dropdown):
244
+ interactive_state["multi_mask"]["mask_names"]= []
245
+ interactive_state["multi_mask"]["masks"] = []
246
+
247
+ return interactive_state, gr.update(choices=[],value=[])
248
+
249
+ def show_mask(video_state, interactive_state, mask_dropdown):
250
+ mask_dropdown.sort()
251
+ if video_state["origin_images"]:
252
+ select_frame = video_state["origin_images"][video_state["select_frame_number"]]
253
+ for i in range(len(mask_dropdown)):
254
+ mask_number = int(mask_dropdown[i].split("_")[1]) - 1
255
+ mask = interactive_state["multi_mask"]["masks"][mask_number]
256
+ select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number+2)
257
+
258
+ return select_frame
259
+
260
+ # image matting
261
+ def image_matting(video_state, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size, refine_iter):
262
+ matanyone_processor.clear_memory()
263
+ if interactive_state["track_end_number"]:
264
+ following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
265
+ else:
266
+ following_frames = video_state["origin_images"][video_state["select_frame_number"]:]
267
+
268
+ if interactive_state["multi_mask"]["masks"]:
269
+ if len(mask_dropdown) == 0:
270
+ mask_dropdown = ["mask_001"]
271
+ mask_dropdown.sort()
272
+ template_mask = interactive_state["multi_mask"]["masks"][int(mask_dropdown[0].split("_")[1]) - 1] * (int(mask_dropdown[0].split("_")[1]))
273
+ for i in range(1,len(mask_dropdown)):
274
+ mask_number = int(mask_dropdown[i].split("_")[1]) - 1
275
+ template_mask = np.clip(template_mask+interactive_state["multi_mask"]["masks"][mask_number]*(mask_number+1), 0, mask_number+1)
276
+ video_state["masks"][video_state["select_frame_number"]]= template_mask
277
+ else:
278
+ template_mask = video_state["masks"][video_state["select_frame_number"]]
279
+
280
+ # operation error
281
+ if len(np.unique(template_mask))==1:
282
+ template_mask[0][0]=1
283
+ foreground, alpha = matanyone(matanyone_processor, following_frames, template_mask*255, r_erode=erode_kernel_size, r_dilate=dilate_kernel_size, n_warmup=refine_iter)
284
+ foreground_output = Image.fromarray(foreground[-1])
285
+ alpha_output = Image.fromarray(alpha[-1][:,:,0])
286
+ return foreground_output, alpha_output
287
+
288
+ # video matting
289
+ def video_matting(video_state, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size):
290
+ matanyone_processor.clear_memory()
291
+ if interactive_state["track_end_number"]:
292
+ following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
293
+ else:
294
+ following_frames = video_state["origin_images"][video_state["select_frame_number"]:]
295
+
296
+ if interactive_state["multi_mask"]["masks"]:
297
+ if len(mask_dropdown) == 0:
298
+ mask_dropdown = ["mask_001"]
299
+ mask_dropdown.sort()
300
+ template_mask = interactive_state["multi_mask"]["masks"][int(mask_dropdown[0].split("_")[1]) - 1] * (int(mask_dropdown[0].split("_")[1]))
301
+ for i in range(1,len(mask_dropdown)):
302
+ mask_number = int(mask_dropdown[i].split("_")[1]) - 1
303
+ template_mask = np.clip(template_mask+interactive_state["multi_mask"]["masks"][mask_number]*(mask_number+1), 0, mask_number+1)
304
+ video_state["masks"][video_state["select_frame_number"]]= template_mask
305
+ else:
306
+ template_mask = video_state["masks"][video_state["select_frame_number"]]
307
+ fps = video_state["fps"]
308
+
309
+ audio_path = video_state["audio"]
310
+
311
+ # operation error
312
+ if len(np.unique(template_mask))==1:
313
+ template_mask[0][0]=1
314
+ foreground, alpha = matanyone(matanyone_processor, following_frames, template_mask*255, r_erode=erode_kernel_size, r_dilate=dilate_kernel_size)
315
+
316
+ foreground_output = generate_video_from_frames(foreground, output_path="./results/{}_fg.mp4".format(video_state["video_name"]), fps=fps, audio_path=audio_path) # import video_input to name the output video
317
+ alpha_output = generate_video_from_frames(alpha, output_path="./results/{}_alpha.mp4".format(video_state["video_name"]), fps=fps, gray2rgb=True, audio_path=audio_path) # import video_input to name the output video
318
+
319
+ return foreground_output, alpha_output
320
+
321
+
322
+ def add_audio_to_video(video_path, audio_path, output_path):
323
+ try:
324
+ video_input = ffmpeg.input(video_path)
325
+ audio_input = ffmpeg.input(audio_path)
326
+
327
+ _ = (
328
+ ffmpeg
329
+ .output(video_input, audio_input, output_path, vcodec="copy", acodec="aac")
330
+ .run(overwrite_output=True, capture_stdout=True, capture_stderr=True)
331
+ )
332
+ return output_path
333
+ except ffmpeg.Error as e:
334
+ print(f"FFmpeg error:\n{e.stderr.decode()}")
335
+ return None
336
+
337
+
338
+ def generate_video_from_frames(frames, output_path, fps=30, gray2rgb=False, audio_path=""):
339
+ """
340
+ Generates a video from a list of frames.
341
+
342
+ Args:
343
+ frames (list of numpy arrays): The frames to include in the video.
344
+ output_path (str): The path to save the generated video.
345
+ fps (int, optional): The frame rate of the output video. Defaults to 30.
346
+ """
347
+ frames = torch.from_numpy(np.asarray(frames))
348
+ _, h, w, _ = frames.shape
349
+ if gray2rgb:
350
+ frames = np.repeat(frames, 3, axis=3)
351
+
352
+ if not os.path.exists(os.path.dirname(output_path)):
353
+ os.makedirs(os.path.dirname(output_path))
354
+ video_temp_path = output_path.replace(".mp4", "_temp.mp4")
355
+
356
+ # resize back to ensure input resolution
357
+ imageio.mimwrite(video_temp_path, frames, fps=fps, quality=7,
358
+ codec='libx264', ffmpeg_params=["-vf", f"scale={w}:{h}"])
359
+
360
+ # add audio to video if audio path exists
361
+ if audio_path != "" and os.path.exists(audio_path):
362
+ output_path = add_audio_to_video(video_temp_path, audio_path, output_path)
363
+ os.remove(video_temp_path)
364
+ return output_path
365
+ else:
366
+ return video_temp_path
367
+
368
+ # reset all states for a new input
369
+ def restart():
370
+ return {
371
+ "user_name": "",
372
+ "video_name": "",
373
+ "origin_images": None,
374
+ "painted_images": None,
375
+ "masks": None,
376
+ "inpaint_masks": None,
377
+ "logits": None,
378
+ "select_frame_number": 0,
379
+ "fps": 30
380
+ }, {
381
+ "inference_times": 0,
382
+ "negative_click_times" : 0,
383
+ "positive_click_times": 0,
384
+ "mask_save": args.mask_save,
385
+ "multi_mask": {
386
+ "mask_names": [],
387
+ "masks": []
388
+ },
389
+ "track_end_number": None,
390
+ }, [[],[]], None, None, \
391
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),\
392
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
393
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
394
+ gr.update(visible=False), gr.update(visible=False, choices=[], value=[]), "", gr.update(visible=False)
395
+
396
+ # args, defined in track_anything.py
397
+ args = parse_augment()
398
+ sam_checkpoint_url_dict = {
399
+ 'vit_h': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
400
+ 'vit_l': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
401
+ 'vit_b': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
402
+ }
403
+ checkpoint_folder = os.path.join('/home/user/app/', 'pretrained_models')
404
+
405
+ sam_checkpoint = load_file_from_url(sam_checkpoint_url_dict[args.sam_model_type], checkpoint_folder)
406
+ # initialize sams
407
+ model = MaskGenerator(sam_checkpoint, args)
408
+
409
+ # initialize matanyone
410
+ pretrain_model_url = "https://github.com/pq-yang/MatAnyone/releases/download/v1.0.0"
411
+ ckpt_path = load_file_from_url(os.path.join(pretrain_model_url, 'matanyone.pth'), checkpoint_folder)
412
+ matanyone_model = get_matanyone_model(ckpt_path, args.device)
413
+ matanyone_model = matanyone_model.to(args.device).eval()
414
+ matanyone_processor = InferenceCore(matanyone_model, cfg=matanyone_model.cfg)
415
+
416
+ # download test samples
417
+ media_url = "https://github.com/pq-yang/MatAnyone/releases/download/media/"
418
+ test_sample_path = os.path.join('/home/user/app/hugging_face/', "test_sample/")
419
+ load_file_from_url(os.path.join(media_url, 'test-sample0.mp4'), test_sample_path)
420
+ load_file_from_url(os.path.join(media_url, 'test-sample1.mp4'), test_sample_path)
421
+ load_file_from_url(os.path.join(media_url, 'test-sample2.mp4'), test_sample_path)
422
+ load_file_from_url(os.path.join(media_url, 'test-sample3.mp4'), test_sample_path)
423
+ load_file_from_url(os.path.join(media_url, 'test-sample4.mp4'), test_sample_path)
424
+ load_file_from_url(os.path.join(media_url, 'test-sample5.mp4'), test_sample_path)
425
+ load_file_from_url(os.path.join(media_url, 'test-sample6.mp4'), test_sample_path)
426
+ load_file_from_url(os.path.join(media_url, 'test-sample0.png'), test_sample_path)
427
+ load_file_from_url(os.path.join(media_url, 'test-sample1.png'), test_sample_path)
428
+
429
+ # download assets
430
+ assets_path = os.path.join('/home/user/app/hugging_face/', "assets/")
431
+ load_file_from_url(os.path.join(media_url, 'tutorial_single_target.mp4'), assets_path)
432
+ load_file_from_url(os.path.join(media_url, 'tutorial_multi_targets.mp4'), assets_path)
433
+
434
+ # documents
435
+ title = r"""<div class="multi-layer" align="center"><span>MatAnyone</span></div>
436
+ """
437
+ description = r"""
438
+ <b>Official Gradio demo</b> for <a href='https://github.com/pq-yang/MatAnyone' target='_blank'><b>MatAnyone: Stable Video Matting with Consistent Memory Propagation</b></a>.<br>
439
+ 🔥 MatAnyone is a practical human video matting framework supporting target assignment 🎯.<br>
440
+ 🎪 Drop your video/image, assign the target masks with a few clicks, and get the matting results 🤡!<br>
441
+ """
442
+ article = r"""
443
+ <b>If MatAnyone is helpful, please help to 🌟 the <a href='https://github.com/pq-yang/MatAnyone' target='_blank'>Github Repo</a>. Thanks!</b>
444
+
445
+ ---
446
+
447
+ 📑 **Citation**
448
+ <br>
449
+ If our work is useful for your research, please consider citing:
450
+ ```bibtex
451
+ @InProceedings{yang2025matanyone,
452
+ title = {{MatAnyone}: Stable Video Matting with Consistent Memory Propagation},
453
+ author = {Yang, Peiqing and Zhou, Shangchen and Zhao, Jixin and Tao, Qingyi and Loy, Chen Change},
454
+ booktitle = {arXiv preprint arXiv:2501.14677},
455
+ year = {2025}
456
+ }
457
+ ```
458
+ 📝 **License**
459
+ <br>
460
+ This project is licensed under <a rel="license" href="https://github.com/pq-yang/MatAnyone/blob/main/LICENSE">S-Lab License 1.0</a>.
461
+ Redistribution and use for non-commercial purposes should follow this license.
462
+ <br>
463
+ 📧 **Contact**
464
+ <br>
465
+ If you have any questions, please feel free to reach out to me at <b>[email protected]</b>.
466
+ <br>
467
+ 👏 **Acknowledgement**
468
+ <br>
469
+ The project is built upon [Cutie](https://github.com/hkchengrex/Cutie) and harnesses the capabilities of [Segment Anything](https://github.com/facebookresearch/segment-anything). Thanks for their awesome work!
470
+ """
471
+
472
+ my_custom_css = """
473
+ .gradio-container {width: 85% !important; margin: 0 auto;}
474
+ .gr-monochrome-group {border-radius: 5px !important; border: revert-layer !important; border-width: 2px !important; color: black !important}
475
+ button {border-radius: 8px !important;}
476
+ .new_button {background-color: #171717 !important; color: #ffffff !important; border: none !important;}
477
+ .green_button {background-color: #4CAF50 !important; color: #ffffff !important; border: none !important;}
478
+ .new_button:hover {background-color: #4b4b4b !important;}
479
+ .green_button:hover {background-color: #77bd79 !important;}
480
+
481
+ .mask_button_group {gap: 10px !important;}
482
+ .video .wrap.svelte-lcpz3o {
483
+ display: flex !important;
484
+ align-items: center !important;
485
+ justify-content: center !important;
486
+ height: auto !important;
487
+ max-height: 300px !important;
488
+ }
489
+ .video .wrap.svelte-lcpz3o > :first-child {
490
+ height: auto !important;
491
+ width: 100% !important;
492
+ object-fit: contain !important;
493
+ }
494
+ .video .container.svelte-sxyn79 {
495
+ display: none !important;
496
+ }
497
+ .margin_center {width: 50% !important; margin: auto !important;}
498
+ .jc_center {justify-content: center !important;}
499
+ .video-title {
500
+ margin-bottom: 5px !important;
501
+ }
502
+ .custom-bg {
503
+ background-color: #f0f0f0;
504
+ padding: 10px;
505
+ border-radius: 10px;
506
+ }
507
+
508
+ <style>
509
+ @import url('https://fonts.googleapis.com/css2?family=Sarpanch:wght@400;500;600;700;800;900&family=Sen:[email protected]&family=Sixtyfour+Convergence&family=Stardos+Stencil:wght@400;700&display=swap');
510
+ body {
511
+ display: flex;
512
+ justify-content: center;
513
+ align-items: center;
514
+ height: 100vh;
515
+ margin: 0;
516
+ background-color: #0d1117;
517
+ font-family: Arial, sans-serif;
518
+ font-size: 18px;
519
+ }
520
+ .title-container {
521
+ text-align: center;
522
+ padding: 0;
523
+ margin: 0;
524
+ background: white;
525
+ height: 5vh;
526
+ width: 80vw;
527
+ font-family: "Sarpanch", sans-serif;
528
+ font-weight: 60;
529
+ }
530
+ #custom-markdown {
531
+ font-family: "Roboto", sans-serif;
532
+ font-size: 18px;
533
+ color: #333333;
534
+ font-weight: bold;
535
+ }
536
+ small {
537
+ font-size: 60%;
538
+ }
539
+ </style>
540
+ """
541
+
542
+ with gr.Blocks(theme=gr.themes.Monochrome(), css=my_custom_css) as demo:
543
+ gr.HTML('''
544
+ <div class="title-container">
545
+ <h1 class="title is-2 publication-title"
546
+ style="font-size:50px; font-family: 'Sarpanch', serif;
547
+ background: linear-gradient(to right, #d231d8, #2dc464);
548
+ display: inline-block; -webkit-background-clip: text;
549
+ -webkit-text-fill-color: transparent;">
550
+ MatAnyone
551
+ </h1>
552
+ </div>
553
+ ''')
554
+
555
+ gr.Markdown(description)
556
+
557
+ with gr.Group(elem_classes="gr-monochrome-group", visible=True):
558
+ with gr.Row():
559
+ with gr.Accordion("📕 Video Tutorial (click to expand)", open=False, elem_classes="custom-bg"):
560
+ with gr.Row():
561
+ with gr.Column():
562
+ gr.Markdown("### Case 1: Single Target")
563
+ gr.Video(value="/home/user/app/hugging_face/assets/tutorial_single_target.mp4", elem_classes="video")
564
+
565
+ with gr.Column():
566
+ gr.Markdown("### Case 2: Multiple Targets")
567
+ gr.Video(value="/home/user/app/hugging_face/assets/tutorial_multi_targets.mp4", elem_classes="video")
568
+
569
+ with gr.Tabs():
570
+ with gr.TabItem("Video"):
571
+ click_state = gr.State([[],[]])
572
+
573
+ interactive_state = gr.State({
574
+ "inference_times": 0,
575
+ "negative_click_times" : 0,
576
+ "positive_click_times": 0,
577
+ "mask_save": args.mask_save,
578
+ "multi_mask": {
579
+ "mask_names": [],
580
+ "masks": []
581
+ },
582
+ "track_end_number": None,
583
+ }
584
+ )
585
+
586
+ video_state = gr.State(
587
+ {
588
+ "user_name": "",
589
+ "video_name": "",
590
+ "origin_images": None,
591
+ "painted_images": None,
592
+ "masks": None,
593
+ "inpaint_masks": None,
594
+ "logits": None,
595
+ "select_frame_number": 0,
596
+ "fps": 30,
597
+ "audio": "",
598
+ }
599
+ )
600
+
601
+ with gr.Group(elem_classes="gr-monochrome-group", visible=True):
602
+ with gr.Row():
603
+ with gr.Accordion('MatAnyone Settings (click to expand)', open=False):
604
+ with gr.Row():
605
+ erode_kernel_size = gr.Slider(label='Erode Kernel Size',
606
+ minimum=0,
607
+ maximum=30,
608
+ step=1,
609
+ value=10,
610
+ info="Erosion on the added mask",
611
+ interactive=True)
612
+ dilate_kernel_size = gr.Slider(label='Dilate Kernel Size',
613
+ minimum=0,
614
+ maximum=30,
615
+ step=1,
616
+ value=10,
617
+ info="Dilation on the added mask",
618
+ interactive=True)
619
+
620
+ with gr.Row():
621
+ image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Start Frame", info="Choose the start frame for target assignment and video matting", visible=False)
622
+ track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frame", visible=False)
623
+ with gr.Row():
624
+ point_prompt = gr.Radio(
625
+ choices=["Positive", "Negative"],
626
+ value="Positive",
627
+ label="Point Prompt",
628
+ info="Click to add positive or negative point for target mask",
629
+ interactive=True,
630
+ visible=False,
631
+ min_width=100,
632
+ scale=1)
633
+ mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask Selection", info="Choose 1~all mask(s) added in Step 2", visible=False)
634
+
635
+ gr.Markdown("---")
636
+
637
+ with gr.Column():
638
+ # input video
639
+ with gr.Row(equal_height=True):
640
+ with gr.Column(scale=2):
641
+ gr.Markdown("## Step1: Upload video")
642
+ with gr.Column(scale=2):
643
+ step2_title = gr.Markdown("## Step2: Add masks <small>(Several clicks then **`Add Mask`** <u>one by one</u>)</small>", visible=False)
644
+ with gr.Row(equal_height=True):
645
+ with gr.Column(scale=2):
646
+ video_input = gr.Video(label="Input Video", elem_classes="video")
647
+ extract_frames_button = gr.Button(value="Load Video", interactive=True, elem_classes="new_button")
648
+ with gr.Column(scale=2):
649
+ video_info = gr.Textbox(label="Video Info", visible=False)
650
+ template_frame = gr.Image(label="Start Frame", type="pil",interactive=True, elem_id="template_frame", visible=False, elem_classes="image")
651
+ with gr.Row(equal_height=True, elem_classes="mask_button_group"):
652
+ clear_button_click = gr.Button(value="Clear Clicks", interactive=True, visible=False, elem_classes="new_button", min_width=100)
653
+ add_mask_button = gr.Button(value="Add Mask", interactive=True, visible=False, elem_classes="new_button", min_width=100)
654
+ remove_mask_button = gr.Button(value="Remove Mask", interactive=True, visible=False, elem_classes="new_button", min_width=100) # no use
655
+ matting_button = gr.Button(value="Video Matting", interactive=True, visible=False, elem_classes="green_button", min_width=100)
656
+
657
+ gr.HTML('<hr style="border: none; height: 1.5px; background: linear-gradient(to right, #a566b4, #74a781);margin: 5px 0;">')
658
+
659
+ # output video
660
+ with gr.Row(equal_height=True):
661
+ with gr.Column(scale=2):
662
+ foreground_video_output = gr.Video(label="Foreground Output", visible=False, elem_classes="video")
663
+ foreground_output_button = gr.Button(value="Foreground Output", visible=False, elem_classes="new_button")
664
+ with gr.Column(scale=2):
665
+ alpha_video_output = gr.Video(label="Alpha Output", visible=False, elem_classes="video")
666
+ alpha_output_button = gr.Button(value="Alpha Mask Output", visible=False, elem_classes="new_button")
667
+
668
+
669
+ # first step: get the video information
670
+ extract_frames_button.click(
671
+ fn=get_frames_from_video,
672
+ inputs=[
673
+ video_input, video_state
674
+ ],
675
+ outputs=[video_state, video_info, template_frame,
676
+ image_selection_slider, track_pause_number_slider, point_prompt, clear_button_click, add_mask_button, matting_button, template_frame,
677
+ foreground_video_output, alpha_video_output, foreground_output_button, alpha_output_button, mask_dropdown, step2_title]
678
+ )
679
+
680
+ # second step: select images from slider
681
+ image_selection_slider.release(fn=select_video_template,
682
+ inputs=[image_selection_slider, video_state, interactive_state],
683
+ outputs=[template_frame, video_state, interactive_state], api_name="select_image")
684
+ track_pause_number_slider.release(fn=get_end_number,
685
+ inputs=[track_pause_number_slider, video_state, interactive_state],
686
+ outputs=[template_frame, interactive_state], api_name="end_image")
687
+
688
+ # click select image to get mask using sam
689
+ template_frame.select(
690
+ fn=sam_refine,
691
+ inputs=[video_state, point_prompt, click_state, interactive_state],
692
+ outputs=[template_frame, video_state, interactive_state]
693
+ )
694
+
695
+ # add different mask
696
+ add_mask_button.click(
697
+ fn=add_multi_mask,
698
+ inputs=[video_state, interactive_state, mask_dropdown],
699
+ outputs=[interactive_state, mask_dropdown, template_frame, click_state]
700
+ )
701
+
702
+ remove_mask_button.click(
703
+ fn=remove_multi_mask,
704
+ inputs=[interactive_state, mask_dropdown],
705
+ outputs=[interactive_state, mask_dropdown]
706
+ )
707
+
708
+ # video matting
709
+ matting_button.click(
710
+ fn=video_matting,
711
+ inputs=[video_state, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size],
712
+ outputs=[foreground_video_output, alpha_video_output]
713
+ )
714
+
715
+ # click to get mask
716
+ mask_dropdown.change(
717
+ fn=show_mask,
718
+ inputs=[video_state, interactive_state, mask_dropdown],
719
+ outputs=[template_frame]
720
+ )
721
+
722
+ # clear input
723
+ video_input.change(
724
+ fn=restart,
725
+ inputs=[],
726
+ outputs=[
727
+ video_state,
728
+ interactive_state,
729
+ click_state,
730
+ foreground_video_output, alpha_video_output,
731
+ template_frame,
732
+ image_selection_slider , track_pause_number_slider,point_prompt, clear_button_click,
733
+ add_mask_button, matting_button, template_frame, foreground_video_output, alpha_video_output, remove_mask_button, foreground_output_button, alpha_output_button, mask_dropdown, video_info, step2_title
734
+ ],
735
+ queue=False,
736
+ show_progress=False)
737
+
738
+ video_input.clear(
739
+ fn=restart,
740
+ inputs=[],
741
+ outputs=[
742
+ video_state,
743
+ interactive_state,
744
+ click_state,
745
+ foreground_video_output, alpha_video_output,
746
+ template_frame,
747
+ image_selection_slider , track_pause_number_slider,point_prompt, clear_button_click,
748
+ add_mask_button, matting_button, template_frame, foreground_video_output, alpha_video_output, remove_mask_button, foreground_output_button, alpha_output_button, mask_dropdown, video_info, step2_title
749
+ ],
750
+ queue=False,
751
+ show_progress=False)
752
+
753
+ # points clear
754
+ clear_button_click.click(
755
+ fn = clear_click,
756
+ inputs = [video_state, click_state,],
757
+ outputs = [template_frame,click_state],
758
+ )
759
+
760
+ # set example
761
+ gr.Markdown("---")
762
+ gr.Markdown("## Examples")
763
+ gr.Examples(
764
+ examples=[os.path.join(os.path.dirname(__file__), "./test_sample/", test_sample) for test_sample in ["test-sample0.mp4", "test-sample1.mp4", "test-sample2.mp4", "test-sample3.mp4", "test-sample4.mp4", "test-sample5.mp4", "test-sample6.mp4"]],
765
+ inputs=[video_input],
766
+ )
767
+
768
+ with gr.TabItem("Image"):
769
+ click_state = gr.State([[],[]])
770
+
771
+ interactive_state = gr.State({
772
+ "inference_times": 0,
773
+ "negative_click_times" : 0,
774
+ "positive_click_times": 0,
775
+ "mask_save": args.mask_save,
776
+ "multi_mask": {
777
+ "mask_names": [],
778
+ "masks": []
779
+ },
780
+ "track_end_number": None,
781
+ }
782
+ )
783
+
784
+ image_state = gr.State(
785
+ {
786
+ "user_name": "",
787
+ "image_name": "",
788
+ "origin_images": None,
789
+ "painted_images": None,
790
+ "masks": None,
791
+ "inpaint_masks": None,
792
+ "logits": None,
793
+ "select_frame_number": 0,
794
+ "fps": 30
795
+ }
796
+ )
797
+
798
+ with gr.Group(elem_classes="gr-monochrome-group", visible=True):
799
+ with gr.Row():
800
+ with gr.Accordion('MatAnyone Settings (click to expand)', open=False):
801
+ with gr.Row():
802
+ erode_kernel_size = gr.Slider(label='Erode Kernel Size',
803
+ minimum=0,
804
+ maximum=30,
805
+ step=1,
806
+ value=10,
807
+ info="Erosion on the added mask",
808
+ interactive=True)
809
+ dilate_kernel_size = gr.Slider(label='Dilate Kernel Size',
810
+ minimum=0,
811
+ maximum=30,
812
+ step=1,
813
+ value=10,
814
+ info="Dilation on the added mask",
815
+ interactive=True)
816
+
817
+ with gr.Row():
818
+ image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Num of Refinement Iterations", info="More iterations → More details & More time", visible=False)
819
+ track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frame", visible=False)
820
+ with gr.Row():
821
+ point_prompt = gr.Radio(
822
+ choices=["Positive", "Negative"],
823
+ value="Positive",
824
+ label="Point Prompt",
825
+ info="Click to add positive or negative point for target mask",
826
+ interactive=True,
827
+ visible=False,
828
+ min_width=100,
829
+ scale=1)
830
+ mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask Selection", info="Choose 1~all mask(s) added in Step 2", visible=False)
831
+
832
+ gr.Markdown("---")
833
+
834
+ with gr.Column():
835
+ # input image
836
+ with gr.Row(equal_height=True):
837
+ with gr.Column(scale=2):
838
+ gr.Markdown("## Step1: Upload image")
839
+ with gr.Column(scale=2):
840
+ step2_title = gr.Markdown("## Step2: Add masks <small>(Several clicks then **`Add Mask`** <u>one by one</u>)</small>", visible=False)
841
+ with gr.Row(equal_height=True):
842
+ with gr.Column(scale=2):
843
+ image_input = gr.Image(label="Input Image", elem_classes="image")
844
+ extract_frames_button = gr.Button(value="Load Image", interactive=True, elem_classes="new_button")
845
+ with gr.Column(scale=2):
846
+ image_info = gr.Textbox(label="Image Info", visible=False)
847
+ template_frame = gr.Image(type="pil", label="Start Frame", interactive=True, elem_id="template_frame", visible=False, elem_classes="image")
848
+ with gr.Row(equal_height=True, elem_classes="mask_button_group"):
849
+ clear_button_click = gr.Button(value="Clear Clicks", interactive=True, visible=False, elem_classes="new_button", min_width=100)
850
+ add_mask_button = gr.Button(value="Add Mask", interactive=True, visible=False, elem_classes="new_button", min_width=100)
851
+ remove_mask_button = gr.Button(value="Remove Mask", interactive=True, visible=False, elem_classes="new_button", min_width=100)
852
+ matting_button = gr.Button(value="Image Matting", interactive=True, visible=False, elem_classes="green_button", min_width=100)
853
+
854
+ gr.HTML('<hr style="border: none; height: 1.5px; background: linear-gradient(to right, #a566b4, #74a781);margin: 5px 0;">')
855
+
856
+ # output image
857
+ with gr.Row(equal_height=True):
858
+ with gr.Column(scale=2):
859
+ foreground_image_output = gr.Image(type="pil", label="Foreground Output", visible=False, elem_classes="image")
860
+ foreground_output_button = gr.Button(value="Foreground Output", visible=False, elem_classes="new_button")
861
+ with gr.Column(scale=2):
862
+ alpha_image_output = gr.Image(type="pil", label="Alpha Output", visible=False, elem_classes="image")
863
+ alpha_output_button = gr.Button(value="Alpha Mask Output", visible=False, elem_classes="new_button")
864
+
865
+ # first step: get the image information
866
+ extract_frames_button.click(
867
+ fn=get_frames_from_image,
868
+ inputs=[
869
+ image_input, image_state
870
+ ],
871
+ outputs=[image_state, image_info, template_frame,
872
+ image_selection_slider, track_pause_number_slider,point_prompt, clear_button_click, add_mask_button, matting_button, template_frame,
873
+ foreground_image_output, alpha_image_output, foreground_output_button, alpha_output_button, mask_dropdown, step2_title]
874
+ )
875
+
876
+ # second step: select images from slider
877
+ image_selection_slider.release(fn=select_image_template,
878
+ inputs=[image_selection_slider, image_state, interactive_state],
879
+ outputs=[template_frame, image_state, interactive_state], api_name="select_image")
880
+ track_pause_number_slider.release(fn=get_end_number,
881
+ inputs=[track_pause_number_slider, image_state, interactive_state],
882
+ outputs=[template_frame, interactive_state], api_name="end_image")
883
+
884
+ # click select image to get mask using sam
885
+ template_frame.select(
886
+ fn=sam_refine,
887
+ inputs=[image_state, point_prompt, click_state, interactive_state],
888
+ outputs=[template_frame, image_state, interactive_state]
889
+ )
890
+
891
+ # add different mask
892
+ add_mask_button.click(
893
+ fn=add_multi_mask,
894
+ inputs=[image_state, interactive_state, mask_dropdown],
895
+ outputs=[interactive_state, mask_dropdown, template_frame, click_state]
896
+ )
897
+
898
+ remove_mask_button.click(
899
+ fn=remove_multi_mask,
900
+ inputs=[interactive_state, mask_dropdown],
901
+ outputs=[interactive_state, mask_dropdown]
902
+ )
903
+
904
+ # image matting
905
+ matting_button.click(
906
+ fn=image_matting,
907
+ inputs=[image_state, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size, image_selection_slider],
908
+ outputs=[foreground_image_output, alpha_image_output]
909
+ )
910
+
911
+ # click to get mask
912
+ mask_dropdown.change(
913
+ fn=show_mask,
914
+ inputs=[image_state, interactive_state, mask_dropdown],
915
+ outputs=[template_frame]
916
+ )
917
+
918
+ # clear input
919
+ image_input.change(
920
+ fn=restart,
921
+ inputs=[],
922
+ outputs=[
923
+ image_state,
924
+ interactive_state,
925
+ click_state,
926
+ foreground_image_output, alpha_image_output,
927
+ template_frame,
928
+ image_selection_slider , track_pause_number_slider,point_prompt, clear_button_click,
929
+ add_mask_button, matting_button, template_frame, foreground_image_output, alpha_image_output, remove_mask_button, foreground_output_button, alpha_output_button, mask_dropdown, image_info, step2_title
930
+ ],
931
+ queue=False,
932
+ show_progress=False)
933
+
934
+ image_input.clear(
935
+ fn=restart,
936
+ inputs=[],
937
+ outputs=[
938
+ image_state,
939
+ interactive_state,
940
+ click_state,
941
+ foreground_image_output, alpha_image_output,
942
+ template_frame,
943
+ image_selection_slider , track_pause_number_slider,point_prompt, clear_button_click,
944
+ add_mask_button, matting_button, template_frame, foreground_image_output, alpha_image_output, remove_mask_button, foreground_output_button, alpha_output_button, mask_dropdown, image_info, step2_title
945
+ ],
946
+ queue=False,
947
+ show_progress=False)
948
+
949
+ # points clear
950
+ clear_button_click.click(
951
+ fn = clear_click,
952
+ inputs = [image_state, click_state,],
953
+ outputs = [template_frame,click_state],
954
+ )
955
+
956
+ # set example
957
+ gr.Markdown("---")
958
+ gr.Markdown("## Examples")
959
+ gr.Examples(
960
+ examples=[os.path.join(os.path.dirname(__file__), "./test_sample/", test_sample) for test_sample in ["test-sample0.png", "test-sample1.png"]],
961
+ inputs=[image_input],
962
+ )
963
+
964
+ gr.Markdown(article)
965
+
966
+ demo.queue()
967
+ demo.launch(debug=True)
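For readers who want to trace the flow without the UI, below is a minimal headless sketch of the same pipeline the handlers above wire to the buttons: download the checkpoints with `load_file_from_url`, turn one positive click into a first-frame mask via `SamControler` (the controller that `MaskGenerator` wraps), and propagate it with the `matanyone` wrapper. It is only a sketch, not part of the commit: the video path, click coordinate, and checkpoint folder are placeholders, a CUDA device is assumed (the wrapper moves tensors to the GPU), and it expects to be run from `hugging_face/` with the repository root on `sys.path`, as `app.py` arranges.

```python
# Headless sketch of the click-to-matte pipeline (placeholder paths, CUDA assumed).
import cv2
import numpy as np

from tools.download_util import load_file_from_url
from tools.interact_tools import SamControler          # same controller MaskGenerator wraps
from matanyone_wrapper import matanyone
from matanyone.utils.get_default_model import get_matanyone_model
from matanyone.inference.inference_core import InferenceCore

ckpt_dir = "./pretrained_models"                        # placeholder folder
sam_ckpt = load_file_from_url(
    "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", ckpt_dir)
mat_ckpt = load_file_from_url(
    "https://github.com/pq-yang/MatAnyone/releases/download/v1.0.0/matanyone.pth", ckpt_dir)

sam = SamControler(sam_ckpt, "vit_h", "cuda:0")
mat_model = get_matanyone_model(mat_ckpt, "cuda:0").to("cuda:0").eval()
processor = InferenceCore(mat_model, cfg=mat_model.cfg)

# read frames as RGB, mirroring get_frames_from_video
cap = cv2.VideoCapture("input.mp4")                     # placeholder video path
frames = []
while True:
    ret, frame = cap.read()
    if not ret:
        break
    frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
cap.release()

# one positive click on the first frame -> template mask (mirrors sam_refine)
mask, logit, painted = sam.first_frame_click(
    frames[0], np.array([[320, 240]]), np.array([1]), True)

# propagate the mask through the clip (mirrors video_matting)
fg_frames, alpha_frames = matanyone(processor, frames, mask * 255,
                                    r_erode=10, r_dilate=10)
```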
hugging_face/matanyone_wrapper.py ADDED
@@ -0,0 +1,73 @@
1
+ import tqdm
2
+ import torch
3
+ from torchvision.transforms.functional import to_tensor
4
+ import numpy as np
5
+ import random
6
+ import cv2
7
+
8
+ def gen_dilate(alpha, min_kernel_size, max_kernel_size):
9
+ kernel_size = random.randint(min_kernel_size, max_kernel_size)
10
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size,kernel_size))
11
+ fg_and_unknown = np.array(np.not_equal(alpha, 0).astype(np.float32))
12
+ dilate = cv2.dilate(fg_and_unknown, kernel, iterations=1)*255
13
+ return dilate.astype(np.float32)
14
+
15
+ def gen_erosion(alpha, min_kernel_size, max_kernel_size):
16
+ kernel_size = random.randint(min_kernel_size, max_kernel_size)
17
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size,kernel_size))
18
+ fg = np.array(np.equal(alpha, 255).astype(np.float32))
19
+ erode = cv2.erode(fg, kernel, iterations=1)*255
20
+ return erode.astype(np.float32)
21
+
22
+ @torch.inference_mode()
23
+ @torch.cuda.amp.autocast()
24
+ def matanyone(processor, frames_np, mask, r_erode=0, r_dilate=0, n_warmup=10):
25
+ """
26
+ Args:
27
+ frames_np: [(H,W,C)]*n, uint8
28
+ mask: (H,W), uint8
29
+ Outputs:
30
+ com: [(H,W,C)]*n, uint8
31
+ pha: [(H,W,1)]*n, uint8
32
+ """
33
+
34
+ # print(f'===== [r_erode] {r_erode}; [r_dilate] {r_dilate} =====')
35
+ bgr = (np.array([120, 255, 155], dtype=np.float32)/255).reshape((1, 1, 3))
36
+ objects = [1]
37
+
38
+ # [optional] erode & dilate on given seg mask
39
+ if r_dilate > 0:
40
+ mask = gen_dilate(mask, r_dilate, r_dilate)
41
+ if r_erode > 0:
42
+ mask = gen_erosion(mask, r_erode, r_erode)
43
+
44
+ mask = torch.from_numpy(mask).cuda()
45
+
46
+ frames_np = [frames_np[0]]* n_warmup + frames_np
47
+
48
+ frames = []
49
+ phas = []
50
+ for ti, frame_single in tqdm.tqdm(enumerate(frames_np)):
51
+ image = to_tensor(frame_single).cuda().float()
52
+
53
+ if ti == 0:
54
+ output_prob = processor.step(image, mask, objects=objects) # encode given mask
55
+ output_prob = processor.step(image, first_frame_pred=True) # clear past memory for warmup frames
56
+ else:
57
+ if ti <= n_warmup:
58
+ output_prob = processor.step(image, first_frame_pred=True) # clear past memory for warmup frames
59
+ else:
60
+ output_prob = processor.step(image)
61
+
62
+ # convert output probabilities to an object mask
63
+ mask = processor.output_prob_to_mask(output_prob)
64
+
65
+ pha = mask.unsqueeze(2).cpu().numpy()
66
+ com_np = frame_single / 255. * pha + bgr * (1 - pha)
67
+
68
+ # do NOT save the warm-up frames
69
+ if ti > (n_warmup-1):
70
+ frames.append((com_np*255).astype(np.uint8))
71
+ phas.append((pha*255).astype(np.uint8))
72
+
73
+ return frames, phas
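The wrapper treats a still image as a tiny constant "video": `get_frames_from_image` in `app.py` duplicates the uploaded frame, `image_matting` passes the refinement-iteration count as `n_warmup` (the repeated warm-up frames above), and only the last output is kept. A minimal sketch of that single-image path follows; it is an illustration under stated assumptions, not committed code — the checkpoint, image, and mask paths are placeholders, `seg_mask` is a binary (0/1) segmentation such as a SAM mask, and a CUDA device is required.

```python
import numpy as np
from PIL import Image

from matanyone_wrapper import matanyone                      # the wrapper defined above
from matanyone.utils.get_default_model import get_matanyone_model
from matanyone.inference.inference_core import InferenceCore

# Assumed inputs (placeholders): a local MatAnyone checkpoint, an RGB image, and a
# binary mask with 1 on the target and 0 elsewhere (e.g. produced by SAM).
model = get_matanyone_model("pretrained_models/matanyone.pth", "cuda:0").to("cuda:0").eval()
processor = InferenceCore(model, cfg=model.cfg)

image = np.array(Image.open("portrait.png").convert("RGB"))  # (H, W, 3) uint8
seg_mask = np.load("portrait_mask.npy").astype(np.uint8)     # (H, W), values in {0, 1}

frames = [image] * 2                       # mimic a 2-frame "video", as get_frames_from_image does
fg, pha = matanyone(processor, frames, seg_mask * 255,
                    r_erode=10, r_dilate=10, n_warmup=10)    # n_warmup = refinement iterations

Image.fromarray(fg[-1]).save("foreground.png")               # composite over the green background
Image.fromarray(pha[-1][:, :, 0]).save("alpha.png")          # single-channel alpha matte
```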
hugging_face/tools/__init__.py ADDED
File without changes
hugging_face/tools/base_segmenter.py ADDED
@@ -0,0 +1,129 @@
1
+ import time
2
+ import torch
3
+ import cv2
4
+ from PIL import Image, ImageDraw, ImageOps
5
+ import numpy as np
6
+ from typing import Union
7
+ from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
8
+ import matplotlib.pyplot as plt
9
+ import PIL
10
+ from .mask_painter import mask_painter
11
+
12
+
13
+ class BaseSegmenter:
14
+ def __init__(self, SAM_checkpoint, model_type, device='cuda:0'):
15
+ """
16
+ device: model device
17
+ SAM_checkpoint: path of SAM checkpoint
18
+ model_type: vit_b, vit_l, vit_h
19
+ """
20
+ print(f"Initializing BaseSegmenter to {device}")
21
+ assert model_type in ['vit_b', 'vit_l', 'vit_h'], 'model_type must be vit_b, vit_l, or vit_h'
22
+
23
+ self.device = device
24
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
25
+ self.model = sam_model_registry[model_type](checkpoint=SAM_checkpoint)
26
+ self.model.to(device=self.device)
27
+ self.predictor = SamPredictor(self.model)
28
+ self.embedded = False
29
+
30
+ @torch.no_grad()
31
+ def set_image(self, image: np.ndarray):
32
+ # PIL.open(image_path) 3channel: RGB
33
+ # image embedding: avoid encode the same image multiple times
34
+ self.orignal_image = image
35
+ if self.embedded:
36
+ print('Image already embedded; call reset_image() before setting a new image.')
37
+ return
38
+ self.predictor.set_image(image)
39
+ self.embedded = True
40
+ return
41
+
42
+ @torch.no_grad()
43
+ def reset_image(self):
44
+ # reset the image embedding
45
+ self.predictor.reset_image()
46
+ self.embedded = False
47
+
48
+ def predict(self, prompts, mode, multimask=True):
49
+ """
50
+ image: numpy array, h, w, 3
51
+ prompts: dictionary, 3 keys: 'point_coords', 'point_labels', 'mask_input'
52
+ prompts['point_coords']: numpy array [N,2]
53
+ prompts['point_labels']: numpy array [1,N]
54
+ prompts['mask_input']: numpy array [1,256,256]
55
+ mode: 'point' (points only), 'mask' (mask only), 'both' (consider both)
56
+ multimask: True (return 3 masks), False (return 1 mask only)
57
+ when multimask=True, mask_input=logits[np.argmax(scores), :, :][None, :, :]
58
+ """
59
+ assert self.embedded, 'prediction is called before set_image (feature embedding).'
60
+ assert mode in ['point', 'mask', 'both'], 'mode must be point, mask, or both'
61
+
62
+ if mode == 'point':
63
+ masks, scores, logits = self.predictor.predict(point_coords=prompts['point_coords'],
64
+ point_labels=prompts['point_labels'],
65
+ multimask_output=multimask)
66
+ elif mode == 'mask':
67
+ masks, scores, logits = self.predictor.predict(mask_input=prompts['mask_input'],
68
+ multimask_output=multimask)
69
+ elif mode == 'both': # both
70
+ masks, scores, logits = self.predictor.predict(point_coords=prompts['point_coords'],
71
+ point_labels=prompts['point_labels'],
72
+ mask_input=prompts['mask_input'],
73
+ multimask_output=multimask)
74
+ else:
75
+ raise("Not implement now!")
76
+ # masks (n, h, w), scores (n,), logits (n, 256, 256)
77
+ return masks, scores, logits
78
+
79
+
80
+ if __name__ == "__main__":
81
+ # load and show an image
82
+ image = cv2.imread('/hhd3/gaoshang/truck.jpg')
83
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # numpy array (h, w, 3)
84
+
85
+ # initialise BaseSegmenter
86
+ SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
87
+ model_type = 'vit_h'
88
+ device = "cuda:4"
89
+ base_segmenter = BaseSegmenter(SAM_checkpoint=SAM_checkpoint, model_type=model_type, device=device)
90
+
91
+ # image embedding (once embedded, multiple prompts can be applied)
92
+ base_segmenter.set_image(image)
93
+
94
+ # examples
95
+ # point only ------------------------
96
+ mode = 'point'
97
+ prompts = {
98
+ 'point_coords': np.array([[500, 375], [1125, 625]]),
99
+ 'point_labels': np.array([1, 1]),
100
+ }
101
+ masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=False) # masks (n, h, w), scores (n,), logits (n, 256, 256)
102
+ painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
103
+ painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
104
+ cv2.imwrite('/hhd3/gaoshang/truck_point.jpg', painted_image)
105
+
106
+ # both ------------------------
107
+ mode = 'both'
108
+ mask_input = logits[np.argmax(scores), :, :]
109
+ prompts = {'mask_input': mask_input [None, :, :]}
110
+ prompts = {
111
+ 'point_coords': np.array([[500, 375], [1125, 625]]),
112
+ 'point_labels': np.array([1, 0]),
113
+ 'mask_input': mask_input[None, :, :]
114
+ }
115
+ masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True) # masks (n, h, w), scores (n,), logits (n, 256, 256)
116
+ painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
117
+ painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
118
+ cv2.imwrite('/hhd3/gaoshang/truck_both.jpg', painted_image)
119
+
120
+ # mask only ------------------------
121
+ mode = 'mask'
122
+ mask_input = logits[np.argmax(scores), :, :]
123
+
124
+ prompts = {'mask_input': mask_input[None, :, :]}
125
+
126
+ masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True) # masks (n, h, w), scores (n,), logits (n, 256, 256)
127
+ painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
128
+ painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
129
+ cv2.imwrite('/hhd3/gaoshang/truck_mask.jpg', painted_image)
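The `__main__` block above relies on hard-coded server paths. As a complement, here is a hedged sketch of how a handful of clicks — the `[x, y]` points with 1/0 labels that `get_prompt` in `app.py` accumulates — can be fed to `BaseSegmenter.predict` directly, keeping the highest-scoring proposal. The image and checkpoint paths and the click coordinates are placeholders; this is an illustration of the class above, not the app's exact code path.

```python
import cv2
import numpy as np

from tools.base_segmenter import BaseSegmenter
from tools.mask_painter import mask_painter

# Placeholder paths: any RGB image plus a downloaded SAM checkpoint will do.
image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)
segmenter = BaseSegmenter("pretrained_models/sam_vit_h_4b8939.pth", "vit_h", device="cuda:0")
segmenter.set_image(image)                 # embed once, then reuse for any number of prompts

# Click state in the same layout the app accumulates: [[x, y], ...] and [1 or 0, ...]
points = [[320, 240], [400, 260]]          # two positive clicks (placeholders)
labels = [1, 1]

prompts = {
    "point_coords": np.array(points),
    "point_labels": np.array(labels),
}
masks, scores, logits = segmenter.predict(prompts, "point", multimask=True)

best = masks[np.argmax(scores)]                                    # highest-scoring proposal
overlay = mask_painter(image, best.astype("uint8"), background_alpha=0.8)
cv2.imwrite("overlay.jpg", cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
```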
hugging_face/tools/download_util.py ADDED
@@ -0,0 +1,109 @@
1
+ import math
2
+ import os
3
+ import requests
4
+ from torch.hub import download_url_to_file, get_dir
5
+ from tqdm import tqdm
6
+ from urllib.parse import urlparse
7
+
8
+ def sizeof_fmt(size, suffix='B'):
9
+ """Get human readable file size.
10
+
11
+ Args:
12
+ size (int): File size.
13
+ suffix (str): Suffix. Default: 'B'.
14
+
15
+ Return:
16
+ str: Formatted file size.
17
+ """
18
+ for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
19
+ if abs(size) < 1024.0:
20
+ return f'{size:3.1f} {unit}{suffix}'
21
+ size /= 1024.0
22
+ return f'{size:3.1f} Y{suffix}'
23
+
24
+
25
+ def download_file_from_google_drive(file_id, save_path):
26
+ """Download files from google drive.
27
+ Ref:
28
+ https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501
29
+ Args:
30
+ file_id (str): File id.
31
+ save_path (str): Save path.
32
+ """
33
+
34
+ session = requests.Session()
35
+ URL = 'https://docs.google.com/uc?export=download'
36
+ params = {'id': file_id}
37
+
38
+ response = session.get(URL, params=params, stream=True)
39
+ token = get_confirm_token(response)
40
+ if token:
41
+ params['confirm'] = token
42
+ response = session.get(URL, params=params, stream=True)
43
+
44
+ # get file size
45
+ response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
46
+ print(response_file_size)
47
+ if 'Content-Range' in response_file_size.headers:
48
+ file_size = int(response_file_size.headers['Content-Range'].split('/')[1])
49
+ else:
50
+ file_size = None
51
+
52
+ save_response_content(response, save_path, file_size)
53
+
54
+
55
+ def get_confirm_token(response):
56
+ for key, value in response.cookies.items():
57
+ if key.startswith('download_warning'):
58
+ return value
59
+ return None
60
+
61
+
62
+ def save_response_content(response, destination, file_size=None, chunk_size=32768):
63
+ if file_size is not None:
64
+ pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
65
+
66
+ readable_file_size = sizeof_fmt(file_size)
67
+ else:
68
+ pbar = None
69
+
70
+ with open(destination, 'wb') as f:
71
+ downloaded_size = 0
72
+ for chunk in response.iter_content(chunk_size):
73
+ downloaded_size += chunk_size
74
+ if pbar is not None:
75
+ pbar.update(1)
76
+ pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}')
77
+ if chunk: # filter out keep-alive new chunks
78
+ f.write(chunk)
79
+ if pbar is not None:
80
+ pbar.close()
81
+
82
+
83
+ def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
84
+ """Load file form http url, will download models if necessary.
85
+ Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
86
+ Args:
87
+ url (str): URL to be downloaded.
88
+ model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
89
+ Default: None.
90
+ progress (bool): Whether to show the download progress. Default: True.
91
+ file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
92
+ Returns:
93
+ str: The path to the downloaded file.
94
+ """
95
+ if model_dir is None: # use the pytorch hub_dir
96
+ hub_dir = get_dir()
97
+ model_dir = os.path.join(hub_dir, 'checkpoints')
98
+
99
+ os.makedirs(model_dir, exist_ok=True)
100
+
101
+ parts = urlparse(url)
102
+ filename = os.path.basename(parts.path)
103
+ if file_name is not None:
104
+ filename = file_name
105
+ cached_file = os.path.abspath(os.path.join(model_dir, filename))
106
+ if not os.path.exists(cached_file):
107
+ print(f'Downloading: "{url}" to {cached_file}\n')
108
+ download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
109
+ return cached_file
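For reference, a minimal usage sketch of `load_file_from_url` above (the URL and the target directory are placeholders, not values from this repository):

```python
import os
from hugging_face.tools.download_util import load_file_from_url

# Cache a checkpoint locally; the download is skipped if the file already exists.
ckpt_path = load_file_from_url(
    url='https://example.com/checkpoints/model.pth',   # placeholder URL
    model_dir=os.path.abspath('pretrained_models'),    # created if missing
    progress=True,
    file_name=None)                                     # keep the filename from the URL
print(ckpt_path)                                        # absolute path to the cached file
```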
hugging_face/tools/interact_tools.py ADDED
@@ -0,0 +1,99 @@
1
+ import time
2
+ import torch
3
+ import cv2
4
+ from PIL import Image, ImageDraw, ImageOps
5
+ import numpy as np
6
+ from typing import Union
7
+ from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
8
+ import matplotlib.pyplot as plt
9
+ import PIL
10
+ from .mask_painter import mask_painter as mask_painter2
11
+ from .base_segmenter import BaseSegmenter
12
+ from .painter import mask_painter, point_painter
13
+ import os
14
+ import requests
15
+ import sys
16
+
17
+
18
+ mask_color = 3
19
+ mask_alpha = 0.7
20
+ contour_color = 1
21
+ contour_width = 5
22
+ point_color_ne = 8
23
+ point_color_ps = 50
24
+ point_alpha = 0.9
25
+ point_radius = 15
26
+ contour_color = 2
27
+ contour_width = 5
28
+
29
+
30
+ class SamControler():
31
+ def __init__(self, SAM_checkpoint, model_type, device):
32
+ '''
33
+ initialize sam controler
34
+ '''
35
+ self.sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)
36
+
37
+
38
+ # def seg_again(self, image: np.ndarray):
39
+ # '''
40
+ # it is used when interact in video
41
+ # '''
42
+ # self.sam_controler.reset_image()
43
+ # self.sam_controler.set_image(image)
44
+ # return
45
+
46
+
47
+ def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True,mask_color=3):
48
+ '''
49
+ it is used in first frame in video
50
+ return: mask, logit, painted image(mask+point)
51
+ '''
52
+ # self.sam_controler.set_image(image)
53
+ origal_image = self.sam_controler.orignal_image
54
+ neg_flag = labels[-1]
55
+ if neg_flag==1:
56
+ #find neg
57
+ prompts = {
58
+ 'point_coords': points,
59
+ 'point_labels': labels,
60
+ }
61
+ masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
62
+ mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
63
+ prompts = {
64
+ 'point_coords': points,
65
+ 'point_labels': labels,
66
+ 'mask_input': logit[None, :, :]
67
+ }
68
+ masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask)
69
+ mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
70
+ else:
71
+ #find positive
72
+ prompts = {
73
+ 'point_coords': points,
74
+ 'point_labels': labels,
75
+ }
76
+ masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
77
+ mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
78
+
79
+
80
+ assert len(points)==len(labels)
81
+
82
+ painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
83
+ painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
84
+ painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
85
+ painted_image = Image.fromarray(painted_image)
86
+
87
+ return mask, logit, painted_image
88
+
89
+
90
+
91
+
92
+
93
+
94
+
95
+
96
+
97
+
98
+
99
+
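A minimal sketch of the click-to-mask flow exposed by `SamControler` (the checkpoint path, device, image file, and click coordinates are illustrative assumptions):

```python
import numpy as np
from PIL import Image
from hugging_face.tools.interact_tools import SamControler

controler = SamControler(SAM_checkpoint='pretrained_models/sam_vit_h_4b8939.pth',
                         model_type='vit_h',
                         device='cuda:0')

frame = np.array(Image.open('first_frame.jpg').convert('RGB'))  # H x W x 3, uint8
controler.sam_controler.set_image(frame)                        # embed the frame once

points = np.array([[500, 375]])  # one click at (x, y)
labels = np.array([1])           # 1 = positive, 0 = negative
mask, logit, painted = controler.first_frame_click(frame, points, labels, multimask=True)
painted.save('first_frame_click.png')  # PIL image with the mask and click overlaid
```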
hugging_face/tools/mask_painter.py ADDED
@@ -0,0 +1,288 @@
1
+ import cv2
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ import copy
6
+ import time
7
+
8
+
9
+ def colormap(rgb=True):
10
+ color_list = np.array(
11
+ [
12
+ 0.000, 0.000, 0.000,
13
+ 1.000, 1.000, 1.000,
14
+ 1.000, 0.498, 0.313,
15
+ 0.392, 0.581, 0.929,
16
+ 0.000, 0.447, 0.741,
17
+ 0.850, 0.325, 0.098,
18
+ 0.929, 0.694, 0.125,
19
+ 0.494, 0.184, 0.556,
20
+ 0.466, 0.674, 0.188,
21
+ 0.301, 0.745, 0.933,
22
+ 0.635, 0.078, 0.184,
23
+ 0.300, 0.300, 0.300,
24
+ 0.600, 0.600, 0.600,
25
+ 1.000, 0.000, 0.000,
26
+ 1.000, 0.500, 0.000,
27
+ 0.749, 0.749, 0.000,
28
+ 0.000, 1.000, 0.000,
29
+ 0.000, 0.000, 1.000,
30
+ 0.667, 0.000, 1.000,
31
+ 0.333, 0.333, 0.000,
32
+ 0.333, 0.667, 0.000,
33
+ 0.333, 1.000, 0.000,
34
+ 0.667, 0.333, 0.000,
35
+ 0.667, 0.667, 0.000,
36
+ 0.667, 1.000, 0.000,
37
+ 1.000, 0.333, 0.000,
38
+ 1.000, 0.667, 0.000,
39
+ 1.000, 1.000, 0.000,
40
+ 0.000, 0.333, 0.500,
41
+ 0.000, 0.667, 0.500,
42
+ 0.000, 1.000, 0.500,
43
+ 0.333, 0.000, 0.500,
44
+ 0.333, 0.333, 0.500,
45
+ 0.333, 0.667, 0.500,
46
+ 0.333, 1.000, 0.500,
47
+ 0.667, 0.000, 0.500,
48
+ 0.667, 0.333, 0.500,
49
+ 0.667, 0.667, 0.500,
50
+ 0.667, 1.000, 0.500,
51
+ 1.000, 0.000, 0.500,
52
+ 1.000, 0.333, 0.500,
53
+ 1.000, 0.667, 0.500,
54
+ 1.000, 1.000, 0.500,
55
+ 0.000, 0.333, 1.000,
56
+ 0.000, 0.667, 1.000,
57
+ 0.000, 1.000, 1.000,
58
+ 0.333, 0.000, 1.000,
59
+ 0.333, 0.333, 1.000,
60
+ 0.333, 0.667, 1.000,
61
+ 0.333, 1.000, 1.000,
62
+ 0.667, 0.000, 1.000,
63
+ 0.667, 0.333, 1.000,
64
+ 0.667, 0.667, 1.000,
65
+ 0.667, 1.000, 1.000,
66
+ 1.000, 0.000, 1.000,
67
+ 1.000, 0.333, 1.000,
68
+ 1.000, 0.667, 1.000,
69
+ 0.167, 0.000, 0.000,
70
+ 0.333, 0.000, 0.000,
71
+ 0.500, 0.000, 0.000,
72
+ 0.667, 0.000, 0.000,
73
+ 0.833, 0.000, 0.000,
74
+ 1.000, 0.000, 0.000,
75
+ 0.000, 0.167, 0.000,
76
+ 0.000, 0.333, 0.000,
77
+ 0.000, 0.500, 0.000,
78
+ 0.000, 0.667, 0.000,
79
+ 0.000, 0.833, 0.000,
80
+ 0.000, 1.000, 0.000,
81
+ 0.000, 0.000, 0.167,
82
+ 0.000, 0.000, 0.333,
83
+ 0.000, 0.000, 0.500,
84
+ 0.000, 0.000, 0.667,
85
+ 0.000, 0.000, 0.833,
86
+ 0.000, 0.000, 1.000,
87
+ 0.143, 0.143, 0.143,
88
+ 0.286, 0.286, 0.286,
89
+ 0.429, 0.429, 0.429,
90
+ 0.571, 0.571, 0.571,
91
+ 0.714, 0.714, 0.714,
92
+ 0.857, 0.857, 0.857
93
+ ]
94
+ ).astype(np.float32)
95
+ color_list = color_list.reshape((-1, 3)) * 255
96
+ if not rgb:
97
+ color_list = color_list[:, ::-1]
98
+ return color_list
99
+
100
+
101
+ color_list = colormap()
102
+ color_list = color_list.astype('uint8').tolist()
103
+
104
+
105
+ def vis_add_mask(image, background_mask, contour_mask, background_color, contour_color, background_alpha, contour_alpha):
106
+ background_color = np.array(background_color)
107
+ contour_color = np.array(contour_color)
108
+
109
+ # background_mask = 1 - background_mask
110
+ # contour_mask = 1 - contour_mask
111
+
112
+ for i in range(3):
113
+ image[:, :, i] = image[:, :, i] * (1-background_alpha+background_mask*background_alpha) \
114
+ + background_color[i] * (background_alpha-background_mask*background_alpha)
115
+
116
+ image[:, :, i] = image[:, :, i] * (1-contour_alpha+contour_mask*contour_alpha) \
117
+ + contour_color[i] * (contour_alpha-contour_mask*contour_alpha)
118
+
119
+ return image.astype('uint8')
120
+
121
+
122
+ def mask_generator_00(mask, background_radius, contour_radius):
123
+ # no background width when '00'
124
+ # distance map
125
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
126
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
127
+ dist_map = dist_transform_fore - dist_transform_back
128
+ # ...:::!!!:::...
129
+ contour_radius += 2
130
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
131
+ contour_mask = contour_mask / np.max(contour_mask)
132
+ contour_mask[contour_mask>0.5] = 1.
133
+
134
+ return mask, contour_mask
135
+
136
+
137
+ def mask_generator_01(mask, background_radius, contour_radius):
138
+ # no background width when '00'
139
+ # distance map
140
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
141
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
142
+ dist_map = dist_transform_fore - dist_transform_back
143
+ # ...:::!!!:::...
144
+ contour_radius += 2
145
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
146
+ contour_mask = contour_mask / np.max(contour_mask)
147
+ return mask, contour_mask
148
+
149
+
150
+ def mask_generator_10(mask, background_radius, contour_radius):
151
+ # distance map
152
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
153
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
154
+ dist_map = dist_transform_fore - dist_transform_back
155
+ # .....:::::!!!!!
156
+ background_mask = np.clip(dist_map, -background_radius, background_radius)
157
+ background_mask = (background_mask - np.min(background_mask))
158
+ background_mask = background_mask / np.max(background_mask)
159
+ # ...:::!!!:::...
160
+ contour_radius += 2
161
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
162
+ contour_mask = contour_mask / np.max(contour_mask)
163
+ contour_mask[contour_mask>0.5] = 1.
164
+ return background_mask, contour_mask
165
+
166
+
167
+ def mask_generator_11(mask, background_radius, contour_radius):
168
+ # distance map
169
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
170
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
171
+ dist_map = dist_transform_fore - dist_transform_back
172
+ # .....:::::!!!!!
173
+ background_mask = np.clip(dist_map, -background_radius, background_radius)
174
+ background_mask = (background_mask - np.min(background_mask))
175
+ background_mask = background_mask / np.max(background_mask)
176
+ # ...:::!!!:::...
177
+ contour_radius += 2
178
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
179
+ contour_mask = contour_mask / np.max(contour_mask)
180
+ return background_mask, contour_mask
181
+
182
+
183
+ def mask_painter(input_image, input_mask, background_alpha=0.5, background_blur_radius=7, contour_width=3, contour_color=3, contour_alpha=1, mode='11'):
184
+ """
185
+ Input:
186
+ input_image: numpy array
187
+ input_mask: numpy array
188
+ background_alpha: transparency of background, [0, 1], 1: all black, 0: do nothing
189
+ background_blur_radius: radius of background blur, must be odd number
190
+ contour_width: width of mask contour, must be odd number
191
+ contour_color: color index (in color map) of mask contour, 0: black, 1: white, >1: others
192
+ contour_alpha: transparency of mask contour, [0, 1], if 0: no contour highlighted
193
+ mode: painting mode, '00', no blur, '01' only blur contour, '10' only blur background, '11' blur both
194
+
195
+ Output:
196
+ painted_image: numpy array
197
+ """
198
+ assert input_image.shape[:2] == input_mask.shape, 'different shape'
199
+ assert background_blur_radius % 2 * contour_width % 2 > 0, 'background_blur_radius and contour_width must be ODD'
200
+ assert mode in ['00', '01', '10', '11'], 'mode should be 00, 01, 10, or 11'
201
+
202
+ # downsample input image and mask
203
+ width, height = input_image.shape[0], input_image.shape[1]
204
+ res = 1024
205
+ ratio = min(1.0 * res / max(width, height), 1.0)
206
+ input_image = cv2.resize(input_image, (int(height*ratio), int(width*ratio)))
207
+ input_mask = cv2.resize(input_mask, (int(height*ratio), int(width*ratio)))
208
+
209
+ # 0: background, 1: foreground
210
+ msk = np.clip(input_mask, 0, 1)
211
+
212
+ # generate masks for background and contour pixels
213
+ background_radius = (background_blur_radius - 1) // 2
214
+ contour_radius = (contour_width - 1) // 2
215
+ generator_dict = {'00':mask_generator_00, '01':mask_generator_01, '10':mask_generator_10, '11':mask_generator_11}
216
+ background_mask, contour_mask = generator_dict[mode](msk, background_radius, contour_radius)
217
+
218
+ # paint
219
+ painted_image = vis_add_mask\
220
+ (input_image, background_mask, contour_mask, color_list[0], color_list[contour_color], background_alpha, contour_alpha) # black for background
221
+
222
+ return painted_image
223
+
224
+
225
+ if __name__ == '__main__':
226
+
227
+ background_alpha = 0.7 # transparency of background 1: all black, 0: do nothing
228
+ background_blur_radius = 31 # radius of background blur, must be odd number
229
+ contour_width = 11 # contour width, must be odd number
230
+ contour_color = 3 # id in color map, 0: black, 1: white, >1: others
231
+ contour_alpha = 1 # transparency of background, 0: no contour highlighted
232
+
233
+ # load input image and mask
234
+ input_image = np.array(Image.open('./test_img/painter_input_image.jpg').convert('RGB'))
235
+ input_mask = np.array(Image.open('./test_img/painter_input_mask.jpg').convert('P'))
236
+
237
+ # paint
238
+ overall_time_1 = 0
239
+ overall_time_2 = 0
240
+ overall_time_3 = 0
241
+ overall_time_4 = 0
242
+ overall_time_5 = 0
243
+
244
+ for i in range(50):
245
+ t2 = time.time()
246
+ painted_image_00 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='00')
247
+ e2 = time.time()
248
+
249
+ t3 = time.time()
250
+ painted_image_10 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='10')
251
+ e3 = time.time()
252
+
253
+ t1 = time.time()
254
+ painted_image = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha)
255
+ e1 = time.time()
256
+
257
+ t4 = time.time()
258
+ painted_image_01 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='01')
259
+ e4 = time.time()
260
+
261
+ t5 = time.time()
262
+ painted_image_11 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='11')
263
+ e5 = time.time()
264
+
265
+ overall_time_1 += (e1 - t1)
266
+ overall_time_2 += (e2 - t2)
267
+ overall_time_3 += (e3 - t3)
268
+ overall_time_4 += (e4 - t4)
269
+ overall_time_5 += (e5 - t5)
270
+
271
+ print(f'average time w gaussian: {overall_time_1/50}')
272
+ print(f'average time w/o gaussian00: {overall_time_2/50}')
273
+ print(f'average time w/o gaussian10: {overall_time_3/50}')
274
+ print(f'average time w/o gaussian01: {overall_time_4/50}')
275
+ print(f'average time w/o gaussian11: {overall_time_5/50}')
276
+
277
+ # save
278
+ painted_image_00 = Image.fromarray(painted_image_00)
279
+ painted_image_00.save('./test_img/painter_output_image_00.png')
280
+
281
+ painted_image_10 = Image.fromarray(painted_image_10)
282
+ painted_image_10.save('./test_img/painter_output_image_10.png')
283
+
284
+ painted_image_01 = Image.fromarray(painted_image_01)
285
+ painted_image_01.save('./test_img/painter_output_image_01.png')
286
+
287
+ painted_image_11 = Image.fromarray(painted_image_11)
288
+ painted_image_11.save('./test_img/painter_output_image_11.png')
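A minimal sketch of calling `mask_painter` directly (file names are placeholders; `background_blur_radius` and `contour_width` must be odd, as the assertion above enforces):

```python
import numpy as np
from PIL import Image
from hugging_face.tools.mask_painter import mask_painter

image = np.array(Image.open('frame.jpg').convert('RGB'))                       # H x W x 3, uint8
mask = (np.array(Image.open('mask.png').convert('L')) > 127).astype('uint8')   # H x W, {0, 1}

painted = mask_painter(image, mask,
                       background_alpha=0.7,
                       background_blur_radius=31,
                       contour_width=11,
                       contour_color=3,
                       contour_alpha=1,
                       mode='11')
Image.fromarray(painted).save('painted.png')
```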
hugging_face/tools/misc.py ADDED
@@ -0,0 +1,131 @@
1
+ import os
2
+ import re
3
+ import random
4
+ import time
5
+ import torch
6
+ import torch.nn as nn
7
+ import logging
8
+ import numpy as np
9
+ from os import path as osp
10
+
11
+ def constant_init(module, val, bias=0):
12
+ if hasattr(module, 'weight') and module.weight is not None:
13
+ nn.init.constant_(module.weight, val)
14
+ if hasattr(module, 'bias') and module.bias is not None:
15
+ nn.init.constant_(module.bias, bias)
16
+
17
+ initialized_logger = {}
18
+ def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
19
+ """Get the root logger.
20
+ The logger will be initialized if it has not been initialized. By default a
21
+ StreamHandler will be added. If `log_file` is specified, a FileHandler will
22
+ also be added.
23
+ Args:
24
+ logger_name (str): root logger name. Default: 'basicsr'.
25
+ log_file (str | None): The log filename. If specified, a FileHandler
26
+ will be added to the root logger.
27
+ log_level (int): The root logger level. Note that only the process of
28
+ rank 0 is affected, while other processes will set the level to
29
+ "Error" and be silent most of the time.
30
+ Returns:
31
+ logging.Logger: The root logger.
32
+ """
33
+ logger = logging.getLogger(logger_name)
34
+ # if the logger has been initialized, just return it
35
+ if logger_name in initialized_logger:
36
+ return logger
37
+
38
+ format_str = '%(asctime)s %(levelname)s: %(message)s'
39
+ stream_handler = logging.StreamHandler()
40
+ stream_handler.setFormatter(logging.Formatter(format_str))
41
+ logger.addHandler(stream_handler)
42
+ logger.propagate = False
43
+
44
+ if log_file is not None:
45
+ logger.setLevel(log_level)
46
+ # add file handler
47
+ # file_handler = logging.FileHandler(log_file, 'w')
48
+ file_handler = logging.FileHandler(log_file, 'a') #Shangchen: keep the previous log
49
+ file_handler.setFormatter(logging.Formatter(format_str))
50
+ file_handler.setLevel(log_level)
51
+ logger.addHandler(file_handler)
52
+ initialized_logger[logger_name] = True
53
+ return logger
54
+
55
+
56
+ IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\
57
+ torch.__version__)[0][:3])] >= [1, 12, 0]
58
+
59
+ def gpu_is_available():
60
+ if IS_HIGH_VERSION:
61
+ if torch.backends.mps.is_available():
62
+ return True
63
+ return True if torch.cuda.is_available() and torch.backends.cudnn.is_available() else False
64
+
65
+ def get_device(gpu_id=None):
66
+ if gpu_id is None:
67
+ gpu_str = ''
68
+ elif isinstance(gpu_id, int):
69
+ gpu_str = f':{gpu_id}'
70
+ else:
71
+ raise TypeError('Input should be int value.')
72
+
73
+ if IS_HIGH_VERSION:
74
+ if torch.backends.mps.is_available():
75
+ return torch.device('mps'+gpu_str)
76
+ return torch.device('cuda'+gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else 'cpu')
77
+
78
+
79
+ def set_random_seed(seed):
80
+ """Set random seeds."""
81
+ random.seed(seed)
82
+ np.random.seed(seed)
83
+ torch.manual_seed(seed)
84
+ torch.cuda.manual_seed(seed)
85
+ torch.cuda.manual_seed_all(seed)
86
+
87
+
88
+ def get_time_str():
89
+ return time.strftime('%Y%m%d_%H%M%S', time.localtime())
90
+
91
+
92
+ def scandir(dir_path, suffix=None, recursive=False, full_path=False):
93
+ """Scan a directory to find the interested files.
94
+
95
+ Args:
96
+ dir_path (str): Path of the directory.
97
+ suffix (str | tuple(str), optional): File suffix that we are
98
+ interested in. Default: None.
99
+ recursive (bool, optional): If set to True, recursively scan the
100
+ directory. Default: False.
101
+ full_path (bool, optional): If set to True, include the dir_path.
102
+ Default: False.
103
+
104
+ Returns:
105
+ A generator for all the files of interest, with relative paths.
106
+ """
107
+
108
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
109
+ raise TypeError('"suffix" must be a string or tuple of strings')
110
+
111
+ root = dir_path
112
+
113
+ def _scandir(dir_path, suffix, recursive):
114
+ for entry in os.scandir(dir_path):
115
+ if not entry.name.startswith('.') and entry.is_file():
116
+ if full_path:
117
+ return_path = entry.path
118
+ else:
119
+ return_path = osp.relpath(entry.path, root)
120
+
121
+ if suffix is None:
122
+ yield return_path
123
+ elif return_path.endswith(suffix):
124
+ yield return_path
125
+ else:
126
+ if recursive:
127
+ yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
128
+ else:
129
+ continue
130
+
131
+ return _scandir(dir_path, suffix=suffix, recursive=recursive)
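A short sketch of the helpers above (the log file and the scanned directory are placeholders):

```python
from hugging_face.tools.misc import get_root_logger, get_device, scandir, set_random_seed

set_random_seed(42)
device = get_device()                          # cuda / mps / cpu, depending on availability
logger = get_root_logger(log_file='run.log')   # the file handler appends to an existing log
logger.info(f'using device: {device}')

for path in scandir('results', suffix='.png', recursive=True):
    logger.info(path)                          # paths are relative to 'results'
```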
hugging_face/tools/painter.py ADDED
@@ -0,0 +1,215 @@
1
+ # paint masks, contours, or points on images, with specified colors
2
+ import cv2
3
+ import torch
4
+ import numpy as np
5
+ from PIL import Image
6
+ import copy
7
+ import time
8
+
9
+
10
+ def colormap(rgb=True):
11
+ color_list = np.array(
12
+ [
13
+ 0.000, 0.000, 0.000,
14
+ 1.000, 1.000, 1.000,
15
+ 1.000, 0.498, 0.313,
16
+ 0.392, 0.581, 0.929,
17
+ 0.000, 0.447, 0.741,
18
+ 0.850, 0.325, 0.098,
19
+ 0.929, 0.694, 0.125,
20
+ 0.494, 0.184, 0.556,
21
+ 0.466, 0.674, 0.188,
22
+ 0.301, 0.745, 0.933,
23
+ 0.635, 0.078, 0.184,
24
+ 0.300, 0.300, 0.300,
25
+ 0.600, 0.600, 0.600,
26
+ 1.000, 0.000, 0.000,
27
+ 1.000, 0.500, 0.000,
28
+ 0.749, 0.749, 0.000,
29
+ 0.000, 1.000, 0.000,
30
+ 0.000, 0.000, 1.000,
31
+ 0.667, 0.000, 1.000,
32
+ 0.333, 0.333, 0.000,
33
+ 0.333, 0.667, 0.000,
34
+ 0.333, 1.000, 0.000,
35
+ 0.667, 0.333, 0.000,
36
+ 0.667, 0.667, 0.000,
37
+ 0.667, 1.000, 0.000,
38
+ 1.000, 0.333, 0.000,
39
+ 1.000, 0.667, 0.000,
40
+ 1.000, 1.000, 0.000,
41
+ 0.000, 0.333, 0.500,
42
+ 0.000, 0.667, 0.500,
43
+ 0.000, 1.000, 0.500,
44
+ 0.333, 0.000, 0.500,
45
+ 0.333, 0.333, 0.500,
46
+ 0.333, 0.667, 0.500,
47
+ 0.333, 1.000, 0.500,
48
+ 0.667, 0.000, 0.500,
49
+ 0.667, 0.333, 0.500,
50
+ 0.667, 0.667, 0.500,
51
+ 0.667, 1.000, 0.500,
52
+ 1.000, 0.000, 0.500,
53
+ 1.000, 0.333, 0.500,
54
+ 1.000, 0.667, 0.500,
55
+ 1.000, 1.000, 0.500,
56
+ 0.000, 0.333, 1.000,
57
+ 0.000, 0.667, 1.000,
58
+ 0.000, 1.000, 1.000,
59
+ 0.333, 0.000, 1.000,
60
+ 0.333, 0.333, 1.000,
61
+ 0.333, 0.667, 1.000,
62
+ 0.333, 1.000, 1.000,
63
+ 0.667, 0.000, 1.000,
64
+ 0.667, 0.333, 1.000,
65
+ 0.667, 0.667, 1.000,
66
+ 0.667, 1.000, 1.000,
67
+ 1.000, 0.000, 1.000,
68
+ 1.000, 0.333, 1.000,
69
+ 1.000, 0.667, 1.000,
70
+ 0.167, 0.000, 0.000,
71
+ 0.333, 0.000, 0.000,
72
+ 0.500, 0.000, 0.000,
73
+ 0.667, 0.000, 0.000,
74
+ 0.833, 0.000, 0.000,
75
+ 1.000, 0.000, 0.000,
76
+ 0.000, 0.167, 0.000,
77
+ 0.000, 0.333, 0.000,
78
+ 0.000, 0.500, 0.000,
79
+ 0.000, 0.667, 0.000,
80
+ 0.000, 0.833, 0.000,
81
+ 0.000, 1.000, 0.000,
82
+ 0.000, 0.000, 0.167,
83
+ 0.000, 0.000, 0.333,
84
+ 0.000, 0.000, 0.500,
85
+ 0.000, 0.000, 0.667,
86
+ 0.000, 0.000, 0.833,
87
+ 0.000, 0.000, 1.000,
88
+ 0.143, 0.143, 0.143,
89
+ 0.286, 0.286, 0.286,
90
+ 0.429, 0.429, 0.429,
91
+ 0.571, 0.571, 0.571,
92
+ 0.714, 0.714, 0.714,
93
+ 0.857, 0.857, 0.857
94
+ ]
95
+ ).astype(np.float32)
96
+ color_list = color_list.reshape((-1, 3)) * 255
97
+ if not rgb:
98
+ color_list = color_list[:, ::-1]
99
+ return color_list
100
+
101
+
102
+ color_list = colormap()
103
+ color_list = color_list.astype('uint8').tolist()
104
+
105
+
106
+ def vis_add_mask(image, mask, color, alpha):
107
+ color = np.array(color_list[color])
108
+ mask = mask > 0.5
109
+ image[mask] = image[mask] * (1-alpha) + color * alpha
110
+ return image.astype('uint8')
111
+
112
+ def point_painter(input_image, input_points, point_color=5, point_alpha=0.9, point_radius=15, contour_color=2, contour_width=5):
113
+ h, w = input_image.shape[:2]
114
+ point_mask = np.zeros((h, w)).astype('uint8')
115
+ for point in input_points:
116
+ point_mask[point[1], point[0]] = 1
117
+
118
+ kernel = cv2.getStructuringElement(2, (point_radius, point_radius))
119
+ point_mask = cv2.dilate(point_mask, kernel)
120
+
121
+ contour_radius = (contour_width - 1) // 2
122
+ dist_transform_fore = cv2.distanceTransform(point_mask, cv2.DIST_L2, 3)
123
+ dist_transform_back = cv2.distanceTransform(1-point_mask, cv2.DIST_L2, 3)
124
+ dist_map = dist_transform_fore - dist_transform_back
125
+ # ...:::!!!:::...
126
+ contour_radius += 2
127
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
128
+ contour_mask = contour_mask / np.max(contour_mask)
129
+ contour_mask[contour_mask>0.5] = 1.
130
+
131
+ # paint mask
132
+ painted_image = vis_add_mask(input_image.copy(), point_mask, point_color, point_alpha)
133
+ # paint contour
134
+ painted_image = vis_add_mask(painted_image.copy(), 1-contour_mask, contour_color, 1)
135
+ return painted_image
136
+
137
+ def mask_painter(input_image, input_mask, mask_color=5, mask_alpha=0.7, contour_color=1, contour_width=3):
138
+ assert input_image.shape[:2] == input_mask.shape, 'different shape between image and mask'
139
+ # 0: background, 1: foreground
140
+ mask = np.clip(input_mask, 0, 1)
141
+ contour_radius = (contour_width - 1) // 2
142
+
143
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
144
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
145
+ dist_map = dist_transform_fore - dist_transform_back
146
+ # ...:::!!!:::...
147
+ contour_radius += 2
148
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
149
+ contour_mask = contour_mask / np.max(contour_mask)
150
+ contour_mask[contour_mask>0.5] = 1.
151
+
152
+ # paint mask
153
+ painted_image = vis_add_mask(input_image.copy(), mask.copy(), mask_color, mask_alpha)
154
+ # paint contour
155
+ painted_image = vis_add_mask(painted_image.copy(), 1-contour_mask, contour_color, 1)
156
+
157
+ return painted_image
158
+
159
+ def background_remover(input_image, input_mask):
160
+ """
161
+ input_image: H, W, 3, np.array
162
+ input_mask: H, W, np.array
163
+
164
+ image_wo_background: PIL.Image
165
+ """
166
+ assert input_image.shape[:2] == input_mask.shape, 'different shape between image and mask'
167
+ # 0: background, 1: foreground
168
+ mask = np.expand_dims(np.clip(input_mask, 0, 1), axis=2)*255
169
+ image_wo_background = np.concatenate([input_image, mask], axis=2) # H, W, 4
170
+ image_wo_background = Image.fromarray(image_wo_background).convert('RGBA')
171
+
172
+ return image_wo_background
173
+
174
+ if __name__ == '__main__':
175
+ input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB'))
176
+ input_mask = np.array(Image.open('images/painter_input_mask.jpg').convert('P'))
177
+
178
+ # example of mask painter
179
+ mask_color = 3
180
+ mask_alpha = 0.7
181
+ contour_color = 1
182
+ contour_width = 5
183
+
184
+ # save
185
+ painted_image = Image.fromarray(input_image)
186
+ painted_image.save('images/original.png')
187
+
188
+ painted_image = mask_painter(input_image, input_mask, mask_color, mask_alpha, contour_color, contour_width)
189
+ # save
190
+ painted_image = Image.fromarray(painted_image)
191
+ painted_image.save('images/original1.png')
192
+
193
+ # example of point painter
194
+ input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB'))
195
+ input_points = np.array([[500, 375], [70, 600]]) # x, y
196
+ point_color = 5
197
+ point_alpha = 0.9
198
+ point_radius = 15
199
+ contour_color = 2
200
+ contour_width = 5
201
+ painted_image_1 = point_painter(input_image, input_points, point_color, point_alpha, point_radius, contour_color, contour_width)
202
+ # save
203
+ painted_image = Image.fromarray(painted_image_1)
204
+ painted_image.save('images/point_painter_1.png')
205
+
206
+ input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB'))
207
+ painted_image_2 = point_painter(input_image, input_points, point_color=9, point_radius=20, contour_color=29)
208
+ # save
209
+ painted_image = Image.fromarray(painted_image_2)
210
+ painted_image.save('images/point_painter_2.png')
211
+
212
+ # example of background remover
213
+ input_image = np.array(Image.open('images/original.png').convert('RGB'))
214
+ image_wo_background = background_remover(input_image, input_mask) # return PIL.Image
215
+ image_wo_background.save('images/image_wo_background.png')
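Note that `mask_color`, `point_color`, and `contour_color` are indices into `colormap()`, not RGB triples. A quick way to inspect what a given index maps to:

```python
from hugging_face.tools.painter import colormap

palette = colormap()   # (N, 3) float array with RGB values in [0, 255]
print(palette[3])      # approx. [100. 148. 237.] -- the blue-ish mask color used above
print(palette[1])      # [255. 255. 255.] -- white, the default contour color of mask_painter
```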
matanyone/config/__init__.py ADDED
File without changes
matanyone/config/eval_matanyone_config.yaml ADDED
@@ -0,0 +1,47 @@
1
+ defaults:
2
+ - _self_
3
+ - model: base
4
+ - override hydra/job_logging: custom-no-rank.yaml
5
+
6
+ hydra:
7
+ run:
8
+ dir: ../output/${exp_id}/${dataset}
9
+ output_subdir: ${now:%Y-%m-%d_%H-%M-%S}-hydra
10
+
11
+ amp: False
12
+ weights: pretrained_models/matanyone.pth # default (can be modified from outside)
13
+ output_dir: null # defaults to run_dir; specify this to override
14
+ flip_aug: False
15
+
16
+
17
+ # maximum shortest side of the input; -1 means no resizing
18
+ # With eval_vos.py, we usually just use the dataset's size (resizing done in dataloader)
19
+ # this parameter exists solely for the GUI in the current codebase
20
+ # InferenceCore will downsize the input and restore the output to the original size if needed
21
+ # if you are using this code for some other project, you can also utilize this parameter
22
+ max_internal_size: -1
23
+
24
+ # these parameters, when set, override the dataset's default; useful for debugging
25
+ save_all: True
26
+ use_all_masks: False
27
+ use_long_term: False
28
+ mem_every: 5
29
+
30
+ # only relevant when long_term is not enabled
31
+ max_mem_frames: 5
32
+
33
+ # only relevant when long_term is enabled
34
+ long_term:
35
+ count_usage: True
36
+ max_mem_frames: 10
37
+ min_mem_frames: 5
38
+ num_prototypes: 128
39
+ max_num_tokens: 10000
40
+ buffer_tokens: 2000
41
+
42
+ top_k: 30
43
+ stagger_updates: 5
44
+ chunk_size: -1 # number of objects to process in parallel; -1 means unlimited
45
+ save_scores: False
46
+ save_aux: False
47
+ visualize: False
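A sketch of composing this config programmatically with Hydra's compose API; the relative `config_path` and the overrides are assumptions that depend on where the calling script lives:

```python
from hydra import compose, initialize

with initialize(version_base=None, config_path='matanyone/config'):
    cfg = compose(config_name='eval_matanyone_config',
                  overrides=['weights=pretrained_models/matanyone.pth',
                             'mem_every=3',
                             'max_internal_size=480'])
print(cfg.weights, cfg.mem_every, cfg.long_term.max_mem_frames)
```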
matanyone/config/hydra/job_logging/custom-no-rank.yaml ADDED
@@ -0,0 +1,22 @@
1
+ # python logging configuration for tasks
2
+ version: 1
3
+ formatters:
4
+ simple:
5
+ format: '[%(asctime)s][%(levelname)s] - %(message)s'
6
+ datefmt: '%Y-%m-%d %H:%M:%S'
7
+ handlers:
8
+ console:
9
+ class: logging.StreamHandler
10
+ formatter: simple
11
+ stream: ext://sys.stdout
12
+ file:
13
+ class: logging.FileHandler
14
+ formatter: simple
15
+ # absolute file path
16
+ filename: ${hydra.runtime.output_dir}/${now:%Y-%m-%d_%H-%M-%S}-eval.log
17
+ mode: w
18
+ root:
19
+ level: INFO
20
+ handlers: [console, file]
21
+
22
+ disable_existing_loggers: false
matanyone/config/hydra/job_logging/custom.yaml ADDED
@@ -0,0 +1,22 @@
1
+ # python logging configuration for tasks
2
+ version: 1
3
+ formatters:
4
+ simple:
5
+ format: '[%(asctime)s][%(levelname)s][r${oc.env:LOCAL_RANK}] - %(message)s'
6
+ datefmt: '%Y-%m-%d %H:%M:%S'
7
+ handlers:
8
+ console:
9
+ class: logging.StreamHandler
10
+ formatter: simple
11
+ stream: ext://sys.stdout
12
+ file:
13
+ class: logging.FileHandler
14
+ formatter: simple
15
+ # absolute file path
16
+ filename: ${hydra.runtime.output_dir}/${now:%Y-%m-%d_%H-%M-%S}-rank${oc.env:LOCAL_RANK}.log
17
+ mode: w
18
+ root:
19
+ level: INFO
20
+ handlers: [console, file]
21
+
22
+ disable_existing_loggers: false
matanyone/config/model/base.yaml ADDED
@@ -0,0 +1,58 @@
1
+ pixel_mean: [0.485, 0.456, 0.406]
2
+ pixel_std: [0.229, 0.224, 0.225]
3
+
4
+ pixel_dim: 256
5
+ key_dim: 64
6
+ value_dim: 256
7
+ sensory_dim: 256
8
+ embed_dim: 256
9
+
10
+ pixel_encoder:
11
+ type: resnet50
12
+ ms_dims: [1024, 512, 256, 64, 3] # f16, f8, f4, f2, f1
13
+
14
+ mask_encoder:
15
+ type: resnet18
16
+ final_dim: 256
17
+
18
+ pixel_pe_scale: 32
19
+ pixel_pe_temperature: 128
20
+
21
+ object_transformer:
22
+ embed_dim: ${model.embed_dim}
23
+ ff_dim: 2048
24
+ num_heads: 8
25
+ num_blocks: 3
26
+ num_queries: 16
27
+ read_from_pixel:
28
+ input_norm: False
29
+ input_add_pe: False
30
+ add_pe_to_qkv: [True, True, False]
31
+ read_from_past:
32
+ add_pe_to_qkv: [True, True, False]
33
+ read_from_memory:
34
+ add_pe_to_qkv: [True, True, False]
35
+ read_from_query:
36
+ add_pe_to_qkv: [True, True, False]
37
+ output_norm: False
38
+ query_self_attention:
39
+ add_pe_to_qkv: [True, True, False]
40
+ pixel_self_attention:
41
+ add_pe_to_qkv: [True, True, False]
42
+
43
+ object_summarizer:
44
+ embed_dim: ${model.object_transformer.embed_dim}
45
+ num_summaries: ${model.object_transformer.num_queries}
46
+ add_pe: True
47
+
48
+ aux_loss:
49
+ sensory:
50
+ enabled: True
51
+ weight: 0.01
52
+ query:
53
+ enabled: True
54
+ weight: 0.01
55
+
56
+ mask_decoder:
57
+ # first value must equal embed_dim
58
+ up_dims: [256, 128, 128, 64, 16]
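The `${model.*}` entries are OmegaConf interpolations, so the transformer and summarizer dimensions follow `model.embed_dim` automatically. A small sketch of that behaviour (wrapping the file under a `model:` key mirrors how Hydra mounts it via the `model` config group):

```python
from omegaconf import OmegaConf

base = OmegaConf.load('matanyone/config/model/base.yaml')
cfg = OmegaConf.create({'model': base})           # interpolations are rooted at the full config
print(cfg.model.object_transformer.embed_dim)     # 256, follows model.embed_dim
print(cfg.model.object_summarizer.num_summaries)  # 16, follows object_transformer.num_queries

cfg.model.embed_dim = 128                         # interpolation is lazy, so this propagates
print(cfg.model.object_transformer.embed_dim)     # 128
```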
matanyone/inference/__init__.py ADDED
File without changes
matanyone/inference/image_feature_store.py ADDED
@@ -0,0 +1,56 @@
1
+ import warnings
2
+ from typing import Iterable
3
+ import torch
4
+ from matanyone.model.matanyone import MatAnyone
5
+
6
+
7
+ class ImageFeatureStore:
8
+ """
9
+ A cache for image features.
10
+ These features might be reused at different parts of the inference pipeline.
11
+ This class provides an interface for reusing these features.
12
+ It is the user's responsibility to delete redundant features.
13
+
14
+ Features of a frame should be associated with a unique index -- typically the frame id.
15
+ """
16
+ def __init__(self, network: MatAnyone, no_warning: bool = False):
17
+ self.network = network
18
+ self._store = {}
19
+ self.no_warning = no_warning
20
+
21
+ def _encode_feature(self, index: int, image: torch.Tensor, last_feats=None) -> None:
22
+ ms_features, pix_feat = self.network.encode_image(image, last_feats=last_feats)
23
+ key, shrinkage, selection = self.network.transform_key(ms_features[0])
24
+ self._store[index] = (ms_features, pix_feat, key, shrinkage, selection)
25
+
26
+ def get_all_features(self, images: torch.Tensor) -> (Iterable[torch.Tensor], torch.Tensor):
27
+ seq_length = images.shape[0]
28
+ ms_features, pix_feat = self.network.encode_image(images, seq_length)
29
+ key, shrinkage, selection = self.network.transform_key(ms_features[0])
30
+ for index in range(seq_length):
31
+ self._store[index] = ([f[index].unsqueeze(0) for f in ms_features], pix_feat[index].unsqueeze(0), key[index].unsqueeze(0), shrinkage[index].unsqueeze(0), selection[index].unsqueeze(0))
32
+
33
+ def get_features(self, index: int,
34
+ image: torch.Tensor, last_feats=None) -> (Iterable[torch.Tensor], torch.Tensor):
35
+ if index not in self._store:
36
+ self._encode_feature(index, image, last_feats)
37
+
38
+ return self._store[index][:2]
39
+
40
+ def get_key(self, index: int,
41
+ image: torch.Tensor, last_feats=None) -> (torch.Tensor, torch.Tensor, torch.Tensor):
42
+ if index not in self._store:
43
+ self._encode_feature(index, image, last_feats)
44
+
45
+ return self._store[index][2:]
46
+
47
+ def delete(self, index: int) -> None:
48
+ if index in self._store:
49
+ del self._store[index]
50
+
51
+ def __len__(self):
52
+ return len(self._store)
53
+
54
+ def __del__(self):
55
+ if len(self._store) > 0 and not self.no_warning:
56
+ warnings.warn(f'Leaking {self._store.keys()} in the image feature store')
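A sketch of the intended caching pattern (construction of the MatAnyone network and the input tensor is elided; `net` and `frame` are assumptions):

```python
from matanyone.inference.image_feature_store import ImageFeatureStore

# net: an initialised MatAnyone model; frame: a 1 x 3 x H x W float tensor on net's device.
store = ImageFeatureStore(net)

ms_feat, pix_feat = store.get_features(0, frame)      # encodes frame 0 and caches the result
key, shrinkage, selection = store.get_key(0, frame)   # reuses the cached encoding, no recompute

store.delete(0)   # the caller frees entries it no longer needs (see the __del__ warning)
```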
matanyone/inference/inference_core.py ADDED
@@ -0,0 +1,407 @@
1
+ from typing import List, Optional, Iterable, Dict
2
+ import logging
3
+ from omegaconf import DictConfig
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+ from matanyone.inference.memory_manager import MemoryManager
10
+ from matanyone.inference.object_manager import ObjectManager
11
+ from matanyone.inference.image_feature_store import ImageFeatureStore
12
+ from matanyone.model.matanyone import MatAnyone
13
+ from matanyone.utils.tensor_utils import pad_divide_by, unpad, aggregate
14
+
15
+ log = logging.getLogger()
16
+
17
+
18
+ class InferenceCore:
19
+
20
+ def __init__(self,
21
+ network: MatAnyone,
22
+ cfg: DictConfig,
23
+ *,
24
+ image_feature_store: ImageFeatureStore = None):
25
+ self.network = network
26
+ self.cfg = cfg
27
+ self.mem_every = cfg.mem_every
28
+ stagger_updates = cfg.stagger_updates
29
+ self.chunk_size = cfg.chunk_size
30
+ self.save_aux = cfg.save_aux
31
+ self.max_internal_size = cfg.max_internal_size
32
+ self.flip_aug = cfg.flip_aug
33
+
34
+ self.curr_ti = -1
35
+ self.last_mem_ti = 0
36
+ # at which time indices should we update the sensory memory
37
+ if stagger_updates >= self.mem_every:
38
+ self.stagger_ti = set(range(1, self.mem_every + 1))
39
+ else:
40
+ self.stagger_ti = set(
41
+ np.round(np.linspace(1, self.mem_every, stagger_updates)).astype(int))
42
+ self.object_manager = ObjectManager()
43
+ self.memory = MemoryManager(cfg=cfg, object_manager=self.object_manager)
44
+
45
+ if image_feature_store is None:
46
+ self.image_feature_store = ImageFeatureStore(self.network)
47
+ else:
48
+ self.image_feature_store = image_feature_store
49
+
50
+ self.last_mask = None
51
+ self.last_pix_feat = None
52
+ self.last_msk_value = None
53
+
54
+ def clear_memory(self):
55
+ self.curr_ti = -1
56
+ self.last_mem_ti = 0
57
+ self.memory = MemoryManager(cfg=self.cfg, object_manager=self.object_manager)
58
+
59
+ def clear_non_permanent_memory(self):
60
+ self.curr_ti = -1
61
+ self.last_mem_ti = 0
62
+ self.memory.clear_non_permanent_memory()
63
+
64
+ def clear_sensory_memory(self):
65
+ self.curr_ti = -1
66
+ self.last_mem_ti = 0
67
+ self.memory.clear_sensory_memory()
68
+
69
+ def update_config(self, cfg):
70
+ self.mem_every = cfg['mem_every']
71
+ self.memory.update_config(cfg)
72
+
73
+ def clear_temp_mem(self):
74
+ self.memory.clear_work_mem()
75
+ # self.object_manager = ObjectManager()
76
+ self.memory.clear_obj_mem()
77
+ # self.memory.clear_sensory_memory()
78
+
79
+ def _add_memory(self,
80
+ image: torch.Tensor,
81
+ pix_feat: torch.Tensor,
82
+ prob: torch.Tensor,
83
+ key: torch.Tensor,
84
+ shrinkage: torch.Tensor,
85
+ selection: torch.Tensor,
86
+ *,
87
+ is_deep_update: bool = True,
88
+ force_permanent: bool = False) -> None:
89
+ """
90
+ Memorize the given segmentation in all memory stores.
91
+
92
+ The batch dimension is 1 if flip augmentation is not used.
93
+ image: RGB image, (1/2)*3*H*W
94
+ pix_feat: from the key encoder, (1/2)*_*H*W
95
+ prob: (1/2)*num_objects*H*W, in [0, 1]
96
+ key/shrinkage/selection: for anisotropic l2, (1/2)*_*H*W
97
+ selection can be None if not using long-term memory
98
+ is_deep_update: whether to use deep update (e.g. with the mask encoder)
99
+ force_permanent: whether to force the memory to be permanent
100
+ """
101
+ if prob.shape[1] == 0:
102
+ # nothing to add
103
+ log.warning('Trying to add an empty object mask to memory!')
104
+ return
105
+
106
+ if force_permanent:
107
+ as_permanent = 'all'
108
+ else:
109
+ as_permanent = 'first'
110
+
111
+ self.memory.initialize_sensory_if_needed(key, self.object_manager.all_obj_ids)
112
+ msk_value, sensory, obj_value, _ = self.network.encode_mask(
113
+ image,
114
+ pix_feat,
115
+ self.memory.get_sensory(self.object_manager.all_obj_ids),
116
+ prob,
117
+ deep_update=is_deep_update,
118
+ chunk_size=self.chunk_size,
119
+ need_weights=self.save_aux)
120
+ self.memory.add_memory(key,
121
+ shrinkage,
122
+ msk_value,
123
+ obj_value,
124
+ self.object_manager.all_obj_ids,
125
+ selection=selection,
126
+ as_permanent=as_permanent)
127
+ self.last_mem_ti = self.curr_ti
128
+ if is_deep_update:
129
+ self.memory.update_sensory(sensory, self.object_manager.all_obj_ids)
130
+ self.last_msk_value = msk_value
131
+
132
+ def _segment(self,
133
+ key: torch.Tensor,
134
+ selection: torch.Tensor,
135
+ pix_feat: torch.Tensor,
136
+ ms_features: Iterable[torch.Tensor],
137
+ update_sensory: bool = True) -> torch.Tensor:
138
+ """
139
+ Produce a segmentation using the given features and the memory
140
+
141
+ The batch dimension is 1 if flip augmentation is not used.
142
+ key/selection: for anisotropic l2: (1/2) * _ * H * W
143
+ pix_feat: from the key encoder, (1/2) * _ * H * W
144
+ ms_features: an iterable of multiscale features from the encoder, each is (1/2)*_*H*W
145
+ with strides 16, 8, and 4 respectively
146
+ update_sensory: whether to update the sensory memory
147
+
148
+ Returns: (num_objects+1)*H*W normalized probability; the first channel is the background
149
+ """
150
+ bs = key.shape[0]
151
+ if self.flip_aug:
152
+ assert bs == 2
153
+ else:
154
+ assert bs == 1
155
+
156
+ if not self.memory.engaged:
157
+ log.warning('Trying to segment without any memory!')
158
+ return torch.zeros((1, key.shape[-2] * 16, key.shape[-1] * 16),
159
+ device=key.device,
160
+ dtype=key.dtype)
161
+
162
+ uncert_output = None
163
+
164
+ if self.curr_ti == 0: # ONLY for the first frame for prediction
165
+ memory_readout = self.memory.read_first_frame(self.last_msk_value, pix_feat, self.last_mask, self.network, uncert_output=uncert_output)
166
+ else:
167
+ memory_readout = self.memory.read(pix_feat, key, selection, self.last_mask, self.network, uncert_output=uncert_output, last_msk_value=self.last_msk_value, ti=self.curr_ti,
168
+ last_pix_feat=self.last_pix_feat, last_pred_mask=self.last_mask)
169
+ memory_readout = self.object_manager.realize_dict(memory_readout)
170
+
171
+ sensory, _, pred_prob_with_bg = self.network.segment(ms_features,
172
+ memory_readout,
173
+ self.memory.get_sensory(
174
+ self.object_manager.all_obj_ids),
175
+ chunk_size=self.chunk_size,
176
+ update_sensory=update_sensory)
177
+ # remove batch dim
178
+ if self.flip_aug:
179
+ # average predictions of the non-flipped and flipped version
180
+ pred_prob_with_bg = (pred_prob_with_bg[0] +
181
+ torch.flip(pred_prob_with_bg[1], dims=[-1])) / 2
182
+ else:
183
+ pred_prob_with_bg = pred_prob_with_bg[0]
184
+ if update_sensory:
185
+ self.memory.update_sensory(sensory, self.object_manager.all_obj_ids)
186
+ return pred_prob_with_bg
187
+
188
+ def pred_all_flow(self, images):
189
+ self.total_len = images.shape[0]
190
+ images, self.pad = pad_divide_by(images, 16)
191
+ images = images.unsqueeze(0) # add the batch dimension: (1,t,c,h,w)
192
+
193
+ self.flows_forward, self.flows_backward = self.network.pred_forward_backward_flow(images)
194
+
195
+ def encode_all_images(self, images):
196
+ images, self.pad = pad_divide_by(images, 16)
197
+ self.image_feature_store.get_all_features(images) # t c h w
198
+ return images
199
+
200
+ def step(self,
201
+ image: torch.Tensor,
202
+ mask: Optional[torch.Tensor] = None,
203
+ objects: Optional[List[int]] = None,
204
+ *,
205
+ idx_mask: bool = False,
206
+ end: bool = False,
207
+ delete_buffer: bool = True,
208
+ force_permanent: bool = False,
209
+ matting: bool = True,
210
+ first_frame_pred: bool = False) -> torch.Tensor:
211
+ """
212
+ Take a step with a new incoming image.
213
+ If there is an incoming mask with new objects, we will memorize them.
214
+ If there is no incoming mask, we will segment the image using the memory.
215
+ In both cases, we will update the memory and return a segmentation.
216
+
217
+ image: 3*H*W
218
+ mask: H*W (if idx mask) or len(objects)*H*W or None
219
+ objects: list of object ids that are valid in the mask Tensor.
220
+ The ids themselves do not need to be consecutive/in order, but they need to be
221
+ in the same position in the list as the corresponding mask
222
+ in the tensor in non-idx-mask mode.
223
+ objects is ignored if the mask is None.
224
+ If idx_mask is False and objects is None, we sequentially infer the object ids.
225
+ idx_mask: if True, mask is expected to contain an object id at every pixel.
226
+ If False, mask should have multiple channels with each channel representing one object.
227
+ end: if we are at the end of the sequence, we do not need to update memory
228
+ if unsure just set it to False
229
+ delete_buffer: whether to delete the image feature buffer after this step
230
+ force_permanent: the memory recorded this frame will be added to the permanent memory
231
+ """
232
+ if objects is None and mask is not None:
233
+ assert not idx_mask
234
+ objects = list(range(1, mask.shape[0] + 1))
235
+
236
+ # resize input if needed -- currently only used for the GUI
237
+ resize_needed = False
238
+ if self.max_internal_size > 0:
239
+ h, w = image.shape[-2:]
240
+ min_side = min(h, w)
241
+ if min_side > self.max_internal_size:
242
+ resize_needed = True
243
+ new_h = int(h / min_side * self.max_internal_size)
244
+ new_w = int(w / min_side * self.max_internal_size)
245
+ image = F.interpolate(image.unsqueeze(0),
246
+ size=(new_h, new_w),
247
+ mode='bilinear',
248
+ align_corners=False)[0]
249
+ if mask is not None:
250
+ if idx_mask:
251
+ mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0).float(),
252
+ size=(new_h, new_w),
253
+ mode='nearest-exact',
254
+ )[0, 0].round().long()  # align_corners is not supported for nearest modes
255
+ else:
256
+ mask = F.interpolate(mask.unsqueeze(0),
257
+ size=(new_h, new_w),
258
+ mode='bilinear',
259
+ align_corners=False)[0]
260
+
261
+ self.curr_ti += 1
262
+
263
+ image, self.pad = pad_divide_by(image, 16) # DONE already for 3DCNN!!
264
+ image = image.unsqueeze(0) # add the batch dimension
265
+ if self.flip_aug:
266
+ image = torch.cat([image, torch.flip(image, dims=[-1])], dim=0)
267
+
268
+ # whether to update the working memory
269
+ is_mem_frame = ((self.curr_ti - self.last_mem_ti >= self.mem_every) or
270
+ (mask is not None)) and (not end)
271
+ # segment when there is no input mask or when the input mask is incomplete
272
+ need_segment = (mask is None) or (self.object_manager.num_obj > 0
273
+ and not self.object_manager.has_all(objects))
274
+ update_sensory = ((self.curr_ti - self.last_mem_ti) in self.stagger_ti) and (not end)
275
+
276
+ # reinit if it is the first frame for prediction
277
+ if first_frame_pred:
278
+ self.curr_ti = 0
279
+ self.last_mem_ti = 0
280
+ is_mem_frame = True
281
+ need_segment = True
282
+ update_sensory = True
283
+
284
+ # encoding the image
285
+ ms_feat, pix_feat = self.image_feature_store.get_features(self.curr_ti, image)
286
+ key, shrinkage, selection = self.image_feature_store.get_key(self.curr_ti, image)
287
+
288
+ # segmentation from memory if needed
289
+ if need_segment:
290
+ pred_prob_with_bg = self._segment(key,
291
+ selection,
292
+ pix_feat,
293
+ ms_feat,
294
+ update_sensory=update_sensory)
295
+
296
+ # use the input mask if provided
297
+ if mask is not None:
298
+ # inform the manager of the new objects, and get a list of temporary id
299
+ # temporary ids -- indicates the position of objects in the tensor
300
+ # (starts with 1 due to the background channel)
301
+ corresponding_tmp_ids, _ = self.object_manager.add_new_objects(objects)
302
+
303
+ mask, _ = pad_divide_by(mask, 16)
304
+ if need_segment:
305
+ print("HERE!!!!!!!!!!!")
306
+ # merge predicted mask with the incomplete input mask
307
+ pred_prob_no_bg = pred_prob_with_bg[1:]
308
+ # use the mutual exclusivity of segmentation
309
+ if idx_mask:
310
+ pred_prob_no_bg[:, mask > 0] = 0
311
+ else:
312
+ pred_prob_no_bg[:, mask.max(0).values > 0.5] = 0
313
+
314
+ new_masks = []
315
+ for mask_id, tmp_id in enumerate(corresponding_tmp_ids):
316
+ if idx_mask:
317
+ this_mask = (mask == objects[mask_id]).type_as(pred_prob_no_bg)
318
+ else:
319
+ this_mask = mask[tmp_id]
320
+ if tmp_id > pred_prob_no_bg.shape[0]:
321
+ new_masks.append(this_mask.unsqueeze(0))
322
+ else:
323
+ # +1 for padding the background channel
324
+ pred_prob_no_bg[tmp_id - 1] = this_mask
325
+ # new_masks are always in the order of tmp_id
326
+ mask = torch.cat([pred_prob_no_bg, *new_masks], dim=0)
327
+ elif idx_mask:
328
+ # simply convert cls to one-hot representation
329
+ if len(objects) == 0:
330
+ if delete_buffer:
331
+ self.image_feature_store.delete(self.curr_ti)
332
+ log.warning('Trying to insert an empty mask as memory!')
333
+ return torch.zeros((1, key.shape[-2] * 16, key.shape[-1] * 16),
334
+ device=key.device,
335
+ dtype=key.dtype)
336
+ mask = torch.stack(
337
+ [mask == objects[mask_id] for mask_id, _ in enumerate(corresponding_tmp_ids)],
338
+ dim=0)
339
+ if matting:
340
+ mask = mask.unsqueeze(0).float() / 255.
341
+ pred_prob_with_bg = torch.cat([1-mask, mask], 0)
342
+ else:
343
+ pred_prob_with_bg = aggregate(mask, dim=0)
344
+ pred_prob_with_bg = torch.softmax(pred_prob_with_bg, dim=0)
345
+
346
+ self.last_mask = pred_prob_with_bg[1:].unsqueeze(0)
347
+ if self.flip_aug:
348
+ self.last_mask = torch.cat(
349
+ [self.last_mask, torch.flip(self.last_mask, dims=[-1])], dim=0)
350
+ self.last_pix_feat = pix_feat
351
+
352
+ # save as memory if needed
353
+ if is_mem_frame or force_permanent:
354
+ # clear the memory for given mask and add the first predicted mask
355
+ if first_frame_pred:
356
+ self.clear_temp_mem()
357
+ self._add_memory(image,
358
+ pix_feat,
359
+ self.last_mask,
360
+ key,
361
+ shrinkage,
362
+ selection,
363
+ force_permanent=force_permanent,
364
+ is_deep_update=True)
365
+ else: # compute self.last_msk_value for non-memory frame
366
+ msk_value, _, _, _ = self.network.encode_mask(
367
+ image,
368
+ pix_feat,
369
+ self.memory.get_sensory(self.object_manager.all_obj_ids),
370
+ self.last_mask,
371
+ deep_update=False,
372
+ chunk_size=self.chunk_size,
373
+ need_weights=self.save_aux)
374
+ self.last_msk_value = msk_value
375
+
376
+ if delete_buffer:
377
+ self.image_feature_store.delete(self.curr_ti)
378
+
379
+ output_prob = unpad(pred_prob_with_bg, self.pad)
380
+ if resize_needed:
381
+ # restore output to the original size
382
+ output_prob = F.interpolate(output_prob.unsqueeze(0),
383
+ size=(h, w),
384
+ mode='bilinear',
385
+ align_corners=False)[0]
386
+
387
+ return output_prob
388
+
389
+ def delete_objects(self, objects: List[int]) -> None:
390
+ """
391
+ Delete the given objects from the memory.
392
+ """
393
+ self.object_manager.delete_objects(objects)
394
+ self.memory.purge_except(self.object_manager.all_obj_ids)
395
+
396
+ def output_prob_to_mask(self, output_prob: torch.Tensor, matting: bool = True) -> torch.Tensor:
397
+ if matting:
398
+ new_mask = output_prob[1:].squeeze(0)
399
+ else:
400
+ mask = torch.argmax(output_prob, dim=0)
401
+
402
+ # index in tensor != object id -- remap the ids here
403
+ new_mask = torch.zeros_like(mask)
404
+ for tmp_id, obj in self.object_manager.tmp_id_to_obj.items():
405
+ new_mask[mask == tmp_id] = obj.id
406
+
407
+ return new_mask
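A sketch of the per-frame matting loop this class is built for; model/config construction and frame loading are elided, so `network`, `cfg`, `frames` (3 x H x W float tensors on the model's device), and `first_mask` (an H x W tensor in [0, 255] for the target object) are assumptions:

```python
import torch
from matanyone.inference.inference_core import InferenceCore

processor = InferenceCore(network, cfg=cfg)

with torch.inference_mode():
    for ti, frame in enumerate(frames):
        if ti == 0:
            # memorise the given first-frame mask and re-predict it
            prob = processor.step(frame, first_mask, objects=[1],
                                  matting=True, first_frame_pred=True)
        else:
            prob = processor.step(frame)                # propagate using the memory
        alpha = processor.output_prob_to_mask(prob)     # H x W alpha matte in [0, 1]
```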
matanyone/inference/kv_memory_store.py ADDED
@@ -0,0 +1,348 @@
1
+ from typing import Dict, List, Optional, Literal
2
+ from collections import defaultdict
3
+ import torch
4
+
5
+
6
+ def _add_last_dim(dictionary, key, new_value, prepend=False):
7
+ # append/prepend a new value to the last dimension of a tensor in a dictionary
8
+ # if the key does not exist, put the new value in
9
+ # append by default
10
+ if key in dictionary:
11
+ dictionary[key] = torch.cat([dictionary[key], new_value], -1)
12
+ else:
13
+ dictionary[key] = new_value
14
+
15
+
16
+ class KeyValueMemoryStore:
17
+ """
18
+ Works for key/value pairs type storage
19
+ e.g., working and long-term memory
20
+ """
21
+ def __init__(self, save_selection: bool = False, save_usage: bool = False):
22
+ """
23
+ We store keys and values of objects that first appear in the same frame in a bucket.
24
+ Each bucket contains a set of object ids.
25
+ Each bucket is associated with a single key tensor
26
+ and a dictionary of value tensors indexed by object id.
27
+
28
+ The keys and values are stored as the concatenation of a permanent part and a temporary part.
29
+ """
30
+ self.save_selection = save_selection
31
+ self.save_usage = save_usage
32
+
33
+ self.global_bucket_id = 0 # does not reduce even if buckets are removed
34
+ self.buckets: Dict[int, List[int]] = {} # indexed by bucket id
35
+ self.k: Dict[int, torch.Tensor] = {} # indexed by bucket id
36
+ self.v: Dict[int, torch.Tensor] = {} # indexed by object id
37
+
38
+ # indexed by bucket id; the end point of permanent memory
39
+ self.perm_end_pt: Dict[int, int] = defaultdict(int)
40
+
41
+ # shrinkage and selection are just like the keys
42
+ self.s = {}
43
+ if self.save_selection:
44
+ self.e = {} # does not contain the permanent memory part
45
+
46
+ # usage
47
+ if self.save_usage:
48
+ self.use_cnt = {} # indexed by bucket id, does not contain the permanent memory part
49
+ self.life_cnt = {} # indexed by bucket id, does not contain the permanent memory part
50
+
51
+ def add(self,
52
+ key: torch.Tensor,
53
+ values: Dict[int, torch.Tensor],
54
+ shrinkage: torch.Tensor,
55
+ selection: torch.Tensor,
56
+ supposed_bucket_id: int = -1,
57
+ as_permanent: Literal['no', 'first', 'all'] = 'no') -> None:
58
+ """
59
+ key: (1/2)*C*N
60
+ values: dict of values ((1/2)*C*N), object ids are used as keys
61
+ shrinkage: (1/2)*1*N
62
+ selection: (1/2)*C*N
63
+
64
+ supposed_bucket_id: used to sync the bucket id between working and long-term memory
65
+ if provided, the input should all be in a single bucket indexed by this id
66
+ as_permanent: whether to store the input as permanent memory
67
+ 'no': don't
68
+ 'first': only store it as permanent memory if the bucket is empty
69
+ 'all': always store it as permanent memory
70
+ """
71
+ bs = key.shape[0]
72
+ ne = key.shape[-1]
73
+ assert len(key.shape) == 3
74
+ assert len(shrinkage.shape) == 3
75
+ assert not self.save_selection or len(selection.shape) == 3
76
+ assert as_permanent in ['no', 'first', 'all']
77
+
78
+ # add the value and create new buckets if necessary
79
+ if supposed_bucket_id >= 0:
80
+ enabled_buckets = [supposed_bucket_id]
81
+ bucket_exist = supposed_bucket_id in self.buckets
82
+ for obj, value in values.items():
83
+ if bucket_exist:
84
+ assert obj in self.v
85
+ assert obj in self.buckets[supposed_bucket_id]
86
+ _add_last_dim(self.v, obj, value, prepend=(as_permanent == 'all'))
87
+ else:
88
+ assert obj not in self.v
89
+ self.v[obj] = value
90
+ self.buckets[supposed_bucket_id] = list(values.keys())
91
+ else:
92
+ new_bucket_id = None
93
+ enabled_buckets = set()
94
+ for obj, value in values.items():
95
+ assert len(value.shape) == 3
96
+ if obj in self.v:
97
+ _add_last_dim(self.v, obj, value, prepend=(as_permanent == 'all'))
98
+ bucket_used = [
99
+ bucket_id for bucket_id, object_ids in self.buckets.items()
100
+ if obj in object_ids
101
+ ]
102
+ assert len(bucket_used) == 1 # each object should only be in one bucket
103
+ enabled_buckets.add(bucket_used[0])
104
+ else:
105
+ self.v[obj] = value
106
+ if new_bucket_id is None:
107
+ # create new bucket
108
+ new_bucket_id = self.global_bucket_id
109
+ self.global_bucket_id += 1
110
+ self.buckets[new_bucket_id] = []
111
+ # put the new object into the corresponding bucket
112
+ self.buckets[new_bucket_id].append(obj)
113
+ enabled_buckets.add(new_bucket_id)
114
+
115
+ # increment the permanent size if necessary
116
+ add_as_permanent = {} # indexed by bucket id
117
+ for bucket_id in enabled_buckets:
118
+ add_as_permanent[bucket_id] = False
119
+ if as_permanent == 'all':
120
+ self.perm_end_pt[bucket_id] += ne
121
+ add_as_permanent[bucket_id] = True
122
+ elif as_permanent == 'first':
123
+ if self.perm_end_pt[bucket_id] == 0:
124
+ self.perm_end_pt[bucket_id] = ne
125
+ add_as_permanent[bucket_id] = True
126
+
127
+ # create new counters for usage if necessary
128
+ if self.save_usage and as_permanent != 'all':
129
+ new_count = torch.zeros((bs, ne), device=key.device, dtype=torch.float32)
130
+ new_life = torch.zeros((bs, ne), device=key.device, dtype=torch.float32) + 1e-7
131
+
132
+ # add the key to every bucket
133
+ for bucket_id in self.buckets:
134
+ if bucket_id not in enabled_buckets:
135
+ # if we are not adding new values to a bucket, we should skip it
136
+ continue
137
+
138
+ _add_last_dim(self.k, bucket_id, key, prepend=add_as_permanent[bucket_id])
139
+ _add_last_dim(self.s, bucket_id, shrinkage, prepend=add_as_permanent[bucket_id])
140
+ if not add_as_permanent[bucket_id]:
141
+ if self.save_selection:
142
+ _add_last_dim(self.e, bucket_id, selection)
143
+ if self.save_usage:
144
+ _add_last_dim(self.use_cnt, bucket_id, new_count)
145
+ _add_last_dim(self.life_cnt, bucket_id, new_life)
146
+
147
+ def update_bucket_usage(self, bucket_id: int, usage: torch.Tensor) -> None:
148
+ # increase all life count by 1
149
+ # increase use of indexed elements
150
+ if not self.save_usage:
151
+ return
152
+
153
+ usage = usage[:, self.perm_end_pt[bucket_id]:]
154
+ if usage.shape[-1] == 0:
155
+ # if there is no temporary memory, we don't need to update
156
+ return
157
+ self.use_cnt[bucket_id] += usage.view_as(self.use_cnt[bucket_id])
158
+ self.life_cnt[bucket_id] += 1
159
+
160
+ def sieve_by_range(self, bucket_id: int, start: int, end: int, min_size: int) -> None:
161
+ # keep only the temporary elements *outside* of this range (with some boundary conditions)
162
+ # the permanent elements are ignored in this computation
163
+ # i.e., concat (a[:start], a[end:])
164
+ # bucket with size <= min_size are not modified
165
+
166
+ assert start >= 0
167
+ assert end <= 0
168
+
169
+ object_ids = self.buckets[bucket_id]
170
+ bucket_num_elements = self.k[bucket_id].shape[-1] - self.perm_end_pt[bucket_id]
171
+ if bucket_num_elements <= min_size:
172
+ return
173
+
174
+ if end == 0:
175
+ # negative 0 would not work as the end index!
176
+ # effectively make the second part an empty slice
177
+ end = self.k[bucket_id].shape[-1] + 1
178
+
179
+ p_size = self.perm_end_pt[bucket_id]
180
+ start = start + p_size
181
+
182
+ k = self.k[bucket_id]
183
+ s = self.s[bucket_id]
184
+ if self.save_selection:
185
+ e = self.e[bucket_id]
186
+ if self.save_usage:
187
+ use_cnt = self.use_cnt[bucket_id]
188
+ life_cnt = self.life_cnt[bucket_id]
189
+
190
+ self.k[bucket_id] = torch.cat([k[:, :, :start], k[:, :, end:]], -1)
191
+ self.s[bucket_id] = torch.cat([s[:, :, :start], s[:, :, end:]], -1)
192
+ if self.save_selection:
193
+ self.e[bucket_id] = torch.cat([e[:, :, :start - p_size], e[:, :, end:]], -1)
194
+ if self.save_usage:
195
+ self.use_cnt[bucket_id] = torch.cat([use_cnt[:, :start - p_size], use_cnt[:, end:]], -1)
196
+ self.life_cnt[bucket_id] = torch.cat([life_cnt[:, :start - p_size], life_cnt[:, end:]],
197
+ -1)
198
+ for obj_id in object_ids:
199
+ v = self.v[obj_id]
200
+ self.v[obj_id] = torch.cat([v[:, :, :start], v[:, :, end:]], -1)
201
+
202
+ def remove_old_memory(self, bucket_id: int, max_len: int) -> None:
203
+ self.sieve_by_range(bucket_id, 0, -max_len, max_len)
204
+
205
+ def remove_obsolete_features(self, bucket_id: int, max_size: int) -> None:
206
+ # for long-term memory only
207
+ object_ids = self.buckets[bucket_id]
208
+
209
+ assert self.perm_end_pt[bucket_id] == 0 # permanent memory should be empty in LT memory
210
+
211
+ # normalize with life duration
212
+ usage = self.get_usage(bucket_id)
213
+ bs = usage.shape[0]
214
+
215
+ survivals = []
216
+
217
+ for bi in range(bs):
218
+ _, survived = torch.topk(usage[bi], k=max_size)
219
+ survivals.append(survived.flatten())
220
+ assert survived.shape[-1] == survivals[0].shape[-1]
221
+
222
+ self.k[bucket_id] = torch.stack(
223
+ [self.k[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0)
224
+ self.s[bucket_id] = torch.stack(
225
+ [self.s[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0)
226
+
227
+ if self.save_selection:
228
+ # Long-term memory does not store selection so this should not be needed
229
+ self.e[bucket_id] = torch.stack(
230
+ [self.e[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0)
231
+ for obj_id in object_ids:
232
+ self.v[obj_id] = torch.stack(
233
+ [self.v[obj_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0)
234
+
235
+ self.use_cnt[bucket_id] = torch.stack(
236
+ [self.use_cnt[bucket_id][bi, survived] for bi, survived in enumerate(survivals)], 0)
237
+ self.life_cnt[bucket_id] = torch.stack(
238
+ [self.life_cnt[bucket_id][bi, survived] for bi, survived in enumerate(survivals)], 0)
239
+
240
+ def get_usage(self, bucket_id: int) -> torch.Tensor:
241
+ # return normalized usage
242
+ if not self.save_usage:
243
+ raise RuntimeError('usage counting is disabled (save_usage=False)')
244
+ else:
245
+ usage = self.use_cnt[bucket_id] / self.life_cnt[bucket_id]
246
+ return usage
247
+
248
+ def get_all_sliced(
249
+ self, bucket_id: int, start: int, end: int
250
+ ) -> (torch.Tensor, torch.Tensor, torch.Tensor, Dict[int, torch.Tensor], torch.Tensor):
251
+ # return k, sk, ek, value, normalized usage in order, sliced by start and end
252
+ # this only queries the temporary memory
253
+
254
+ assert start >= 0
255
+ assert end <= 0
256
+
257
+ p_size = self.perm_end_pt[bucket_id]
258
+ start = start + p_size
259
+
260
+ if end == 0:
261
+ # negative 0 would not work as the end index!
262
+ k = self.k[bucket_id][:, :, start:]
263
+ sk = self.s[bucket_id][:, :, start:]
264
+ ek = self.e[bucket_id][:, :, start - p_size:] if self.save_selection else None
265
+ value = {obj_id: self.v[obj_id][:, :, start:] for obj_id in self.buckets[bucket_id]}
266
+ usage = self.get_usage(bucket_id)[:, start - p_size:] if self.save_usage else None
267
+ else:
268
+ k = self.k[bucket_id][:, :, start:end]
269
+ sk = self.s[bucket_id][:, :, start:end]
270
+ ek = self.e[bucket_id][:, :, start - p_size:end] if self.save_selection else None
271
+ value = {obj_id: self.v[obj_id][:, :, start:end] for obj_id in self.buckets[bucket_id]}
272
+ usage = self.get_usage(bucket_id)[:, start - p_size:end] if self.save_usage else None
273
+
274
+ return k, sk, ek, value, usage
275
+
276
+ def purge_except(self, obj_keep_idx: List[int]):
277
+ # purge all objects from the memory except the ones listed
278
+ obj_keep_idx = set(obj_keep_idx)
279
+
280
+ # remove objects that are not in the keep list from the buckets
281
+ buckets_to_remove = []
282
+ for bucket_id, object_ids in self.buckets.items():
283
+ self.buckets[bucket_id] = [obj_id for obj_id in object_ids if obj_id in obj_keep_idx]
284
+ if len(self.buckets[bucket_id]) == 0:
285
+ buckets_to_remove.append(bucket_id)
286
+
287
+ # remove object values that are not in the keep list
288
+ self.v = {k: v for k, v in self.v.items() if k in obj_keep_idx}
289
+
290
+ # remove buckets that are empty
291
+ for bucket_id in buckets_to_remove:
292
+ del self.buckets[bucket_id]
293
+ del self.k[bucket_id]
294
+ del self.s[bucket_id]
295
+ if self.save_selection:
296
+ del self.e[bucket_id]
297
+ if self.save_usage:
298
+ del self.use_cnt[bucket_id]
299
+ del self.life_cnt[bucket_id]
300
+
301
+ def clear_non_permanent_memory(self):
302
+ # clear all non-permanent memory
303
+ for bucket_id in self.buckets:
304
+ self.sieve_by_range(bucket_id, 0, 0, 0)
305
+
306
+ def get_v_size(self, obj_id: int) -> int:
307
+ return self.v[obj_id].shape[-1]
308
+
309
+ def size(self, bucket_id: int) -> int:
310
+ if bucket_id not in self.k:
311
+ return 0
312
+ else:
313
+ return self.k[bucket_id].shape[-1]
314
+
315
+ def perm_size(self, bucket_id: int) -> int:
316
+ return self.perm_end_pt[bucket_id]
317
+
318
+ def non_perm_size(self, bucket_id: int) -> int:
319
+ return self.size(bucket_id) - self.perm_size(bucket_id)
320
+
321
+ def engaged(self, bucket_id: Optional[int] = None) -> bool:
322
+ if bucket_id is None:
323
+ return len(self.buckets) > 0
324
+ else:
325
+ return bucket_id in self.buckets
326
+
327
+ @property
328
+ def num_objects(self) -> int:
329
+ return len(self.v)
330
+
331
+ @property
332
+ def key(self) -> Dict[int, torch.Tensor]:
333
+ return self.k
334
+
335
+ @property
336
+ def value(self) -> Dict[int, torch.Tensor]:
337
+ return self.v
338
+
339
+ @property
340
+ def shrinkage(self) -> Dict[int, torch.Tensor]:
341
+ return self.s
342
+
343
+ @property
344
+ def selection(self) -> Dict[int, torch.Tensor]:
345
+ return self.e
346
+
347
+ def __contains__(self, key):
348
+ return key in self.v
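The permanent/temporary split above is easiest to see on a toy tensor. A minimal sketch (not part of the commit; shapes, token counts and the permanent size are made up) of what `sieve_by_range` keeps when the first `perm_end_pt` tokens of a bucket are permanent:

    import torch

    perm_end_pt = 4                       # the first 4 tokens of this bucket are permanent
    k = torch.arange(12).view(1, 1, 12)   # bs x C x N, as stored per bucket

    def sieve(k: torch.Tensor, start: int, end: int, p_size: int) -> torch.Tensor:
        start = start + p_size            # the permanent part is never sliced away
        if end == 0:
            end = k.shape[-1] + 1         # a "negative zero" end would select nothing
        return torch.cat([k[:, :, :start], k[:, :, end:]], -1)

    # FIFO removal as in remove_old_memory(bucket_id, max_len=5):
    print(sieve(k, 0, -5, perm_end_pt))   # 4 permanent tokens + the 5 newest temporary tokens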
matanyone/inference/memory_manager.py ADDED
@@ -0,0 +1,457 @@
1
+ import logging
2
+ from omegaconf import DictConfig
3
+ from typing import List, Dict
4
+ import torch
5
+ import cv2
6
+
7
+ from matanyone.inference.object_manager import ObjectManager
8
+ from matanyone.inference.kv_memory_store import KeyValueMemoryStore
9
+ from matanyone.model.matanyone import MatAnyone
10
+ from matanyone.model.utils.memory_utils import *
11
+
12
+ log = logging.getLogger()
13
+
14
+
15
+ class MemoryManager:
16
+ """
17
+ Manages all three memory stores and the transition between working/long-term memory
18
+ """
19
+ def __init__(self, cfg: DictConfig, object_manager: ObjectManager):
20
+ self.object_manager = object_manager
21
+ self.sensory_dim = cfg.model.sensory_dim
22
+ self.top_k = cfg.top_k
23
+ self.chunk_size = cfg.chunk_size
24
+
25
+ self.save_aux = cfg.save_aux
26
+
27
+ self.use_long_term = cfg.use_long_term
28
+ self.count_long_term_usage = cfg.long_term.count_usage
29
+ # subtract 1 because the first-frame is now counted as "permanent memory"
30
+ # and is not counted towards max_mem_frames
31
+ # but we want to keep the hyperparameters consistent as before for the same behavior
32
+ if self.use_long_term:
33
+ self.max_mem_frames = cfg.long_term.max_mem_frames - 1
34
+ self.min_mem_frames = cfg.long_term.min_mem_frames - 1
35
+ self.num_prototypes = cfg.long_term.num_prototypes
36
+ self.max_long_tokens = cfg.long_term.max_num_tokens
37
+ self.buffer_tokens = cfg.long_term.buffer_tokens
38
+ else:
39
+ self.max_mem_frames = cfg.max_mem_frames - 1
40
+
41
+ # dimensions will be inferred from input later
42
+ self.CK = self.CV = None
43
+ self.H = self.W = None
44
+
45
+ # The sensory memory is stored as a dictionary indexed by object ids
46
+ # each of shape bs * C^h * H * W
47
+ self.sensory = {}
48
+
49
+ # a dictionary indexed by object ids, each of shape bs * T * Q * C
50
+ self.obj_v = {}
51
+
52
+ self.work_mem = KeyValueMemoryStore(save_selection=self.use_long_term,
53
+ save_usage=self.use_long_term)
54
+ if self.use_long_term:
55
+ self.long_mem = KeyValueMemoryStore(save_usage=self.count_long_term_usage)
56
+
57
+ self.config_stale = True
58
+ self.engaged = False
59
+
60
+ def update_config(self, cfg: DictConfig) -> None:
61
+ self.config_stale = True
62
+ self.top_k = cfg['top_k']
63
+
64
+ assert self.use_long_term == cfg.use_long_term, 'cannot update this'
65
+ assert self.count_long_term_usage == cfg.long_term.count_usage, 'cannot update this'
66
+
67
+ self.use_long_term = cfg.use_long_term
68
+ self.count_long_term_usage = cfg.long_term.count_usage
69
+ if self.use_long_term:
70
+ self.max_mem_frames = cfg.long_term.max_mem_frames - 1
71
+ self.min_mem_frames = cfg.long_term.min_mem_frames - 1
72
+ self.num_prototypes = cfg.long_term.num_prototypes
73
+ self.max_long_tokens = cfg.long_term.max_num_tokens
74
+ self.buffer_tokens = cfg.long_term.buffer_tokens
75
+ else:
76
+ self.max_mem_frames = cfg.max_mem_frames - 1
77
+
78
+ def _readout(self, affinity, v, uncert_mask=None) -> torch.Tensor:
79
+ # affinity: bs*N*HW
80
+ # v: bs*C*N or bs*num_objects*C*N
81
+ # returns bs*C*HW or bs*num_objects*C*HW
82
+ if len(v.shape) == 3:
83
+ # single object
84
+ if uncert_mask is not None:
85
+ return v @ affinity * uncert_mask
86
+ else:
87
+ return v @ affinity
88
+ else:
89
+ bs, num_objects, C, N = v.shape
90
+ v = v.view(bs, num_objects * C, N)
91
+ out = v @ affinity
92
+ if uncert_mask is not None:
93
+ uncert_mask = uncert_mask.flatten(start_dim=2).expand(-1, C, -1)
94
+ out = out * uncert_mask
95
+ return out.view(bs, num_objects, C, -1)
96
+
97
+ def _get_mask_by_ids(self, mask: torch.Tensor, obj_ids: List[int]) -> torch.Tensor:
98
+ # -1 because the mask does not contain the background channel
99
+ return mask[:, [self.object_manager.find_tmp_by_id(obj) - 1 for obj in obj_ids]]
100
+
101
+ def _get_sensory_by_ids(self, obj_ids: List[int]) -> torch.Tensor:
102
+ return torch.stack([self.sensory[obj] for obj in obj_ids], dim=1)
103
+
104
+ def _get_object_mem_by_ids(self, obj_ids: List[int]) -> torch.Tensor:
105
+ return torch.stack([self.obj_v[obj] for obj in obj_ids], dim=1)
106
+
107
+ def _get_visual_values_by_ids(self, obj_ids: List[int]) -> torch.Tensor:
108
+ # All the values that the object ids refer to should have the same shape
109
+ value = torch.stack([self.work_mem.value[obj] for obj in obj_ids], dim=1)
110
+ if self.use_long_term and obj_ids[0] in self.long_mem.value:
111
+ lt_value = torch.stack([self.long_mem.value[obj] for obj in obj_ids], dim=1)
112
+ value = torch.cat([lt_value, value], dim=-1)
113
+
114
+ return value
115
+
116
+ def read_first_frame(self, last_msk_value, pix_feat: torch.Tensor,
117
+ last_mask: torch.Tensor, network: MatAnyone, uncert_output=None) -> Dict[int, torch.Tensor]:
118
+ """
119
+ Read from all memory stores and returns a single memory readout tensor for each object
120
+
121
+ pix_feat: (1/2) x C x H x W
122
+ query_key: (1/2) x C^k x H x W
123
+ selection: (1/2) x C^k x H x W
124
+ last_mask: (1/2) x num_objects x H x W (at stride 16)
125
+ return a dict of memory readouts, indexed by object indices. Each readout is C*H*W
126
+ """
127
+ h, w = pix_feat.shape[-2:]
128
+ bs = pix_feat.shape[0]
129
+ assert last_mask.shape[0] == bs
130
+
131
+ uncert_mask = uncert_output["mask"] if uncert_output is not None else None
132
+
133
+ """
134
+ Compute affinity and perform readout
135
+ """
136
+ all_readout_mem = {}
137
+ buckets = self.work_mem.buckets
138
+ for bucket_id, bucket in buckets.items():
139
+
140
+ if self.chunk_size < 1:
141
+ object_chunks = [bucket]
142
+ else:
143
+ object_chunks = [
144
+ bucket[i:i + self.chunk_size] for i in range(0, len(bucket), self.chunk_size)
145
+ ]
146
+
147
+ for objects in object_chunks:
148
+ this_sensory = self._get_sensory_by_ids(objects)
149
+ this_last_mask = self._get_mask_by_ids(last_mask, objects)
150
+ this_msk_value = self._get_visual_values_by_ids(objects) # (1/2)*num_objects*C*N
151
+ pixel_readout = network.pixel_fusion(pix_feat, last_msk_value, this_sensory,
152
+ this_last_mask)
153
+ this_obj_mem = self._get_object_mem_by_ids(objects).unsqueeze(2)
154
+ readout_memory, aux_features = network.readout_query(pixel_readout, this_obj_mem)
155
+ for i, obj in enumerate(objects):
156
+ all_readout_mem[obj] = readout_memory[:, i]
157
+
158
+ if self.save_aux:
159
+ aux_output = {
160
+ # 'sensory': this_sensory,
161
+ # 'pixel_readout': pixel_readout,
162
+ 'q_logits': aux_features['logits'] if aux_features else None,
163
+ # 'q_weights': aux_features['q_weights'] if aux_features else None,
164
+ # 'p_weights': aux_features['p_weights'] if aux_features else None,
165
+ # 'attn_mask': aux_features['attn_mask'].float() if aux_features else None,
166
+ }
167
+ self.aux = aux_output
168
+
169
+ return all_readout_mem
170
+
171
+ def read(self, pix_feat: torch.Tensor, query_key: torch.Tensor, selection: torch.Tensor,
172
+ last_mask: torch.Tensor, network: MatAnyone, uncert_output=None, last_msk_value=None, ti=None,
173
+ last_pix_feat=None, last_pred_mask=None) -> Dict[int, torch.Tensor]:
174
+ """
175
+ Read from all memory stores and returns a single memory readout tensor for each object
176
+
177
+ pix_feat: (1/2) x C x H x W
178
+ query_key: (1/2) x C^k x H x W
179
+ selection: (1/2) x C^k x H x W
180
+ last_mask: (1/2) x num_objects x H x W (at stride 16)
181
+ return a dict of memory readouts, indexed by object indices. Each readout is C*H*W
182
+ """
183
+ h, w = pix_feat.shape[-2:]
184
+ bs = pix_feat.shape[0]
185
+ assert query_key.shape[0] == bs
186
+ assert selection.shape[0] == bs
187
+ assert last_mask.shape[0] == bs
188
+
189
+ uncert_mask = uncert_output["mask"] if uncert_output is not None else None
190
+
191
+ query_key = query_key.flatten(start_dim=2) # bs*C^k*HW
192
+ selection = selection.flatten(start_dim=2) # bs*C^k*HW
193
+ """
194
+ Compute affinity and perform readout
195
+ """
196
+ all_readout_mem = {}
197
+ buckets = self.work_mem.buckets
198
+ for bucket_id, bucket in buckets.items():
199
+ if self.use_long_term and self.long_mem.engaged(bucket_id):
200
+ # Use long-term memory
201
+ long_mem_size = self.long_mem.size(bucket_id)
202
+ memory_key = torch.cat([self.long_mem.key[bucket_id], self.work_mem.key[bucket_id]],
203
+ -1)
204
+ shrinkage = torch.cat(
205
+ [self.long_mem.shrinkage[bucket_id], self.work_mem.shrinkage[bucket_id]], -1)
206
+
207
+ similarity = get_similarity(memory_key, shrinkage, query_key, selection)
208
+ affinity, usage = do_softmax(similarity,
209
+ top_k=self.top_k,
210
+ inplace=True,
211
+ return_usage=True)
212
+ """
213
+ Record memory usage for working and long-term memory
214
+ """
215
+ # ignore the index return for long-term memory
216
+ work_usage = usage[:, long_mem_size:]
217
+ self.work_mem.update_bucket_usage(bucket_id, work_usage)
218
+
219
+ if self.count_long_term_usage:
220
+ # ignore the index return for working memory
221
+ long_usage = usage[:, :long_mem_size]
222
+ self.long_mem.update_bucket_usage(bucket_id, long_usage)
223
+ else:
224
+ # no long-term memory
225
+ memory_key = self.work_mem.key[bucket_id]
226
+ shrinkage = self.work_mem.shrinkage[bucket_id]
227
+ similarity = get_similarity(memory_key, shrinkage, query_key, selection, uncert_mask=uncert_mask)
228
+
229
+ if self.use_long_term:
230
+ affinity, usage = do_softmax(similarity,
231
+ top_k=self.top_k,
232
+ inplace=True,
233
+ return_usage=True)
234
+ self.work_mem.update_bucket_usage(bucket_id, usage)
235
+ else:
236
+ affinity = do_softmax(similarity, top_k=self.top_k, inplace=True)
237
+
238
+ if self.chunk_size < 1:
239
+ object_chunks = [bucket]
240
+ else:
241
+ object_chunks = [
242
+ bucket[i:i + self.chunk_size] for i in range(0, len(bucket), self.chunk_size)
243
+ ]
244
+
245
+ for objects in object_chunks:
246
+ this_sensory = self._get_sensory_by_ids(objects)
247
+ this_last_mask = self._get_mask_by_ids(last_mask, objects)
248
+ this_msk_value = self._get_visual_values_by_ids(objects) # (1/2)*num_objects*C*N
249
+ visual_readout = self._readout(affinity,
250
+ this_msk_value, uncert_mask).view(bs, len(objects), self.CV, h, w)
251
+
252
+ uncert_output = network.pred_uncertainty(last_pix_feat, pix_feat, last_pred_mask, visual_readout[:,0]-last_msk_value[:,0])
253
+
254
+ if uncert_output is not None:
255
+ uncert_prob = uncert_output["prob"].unsqueeze(1) # b n 1 h w
256
+ visual_readout = visual_readout*uncert_prob + last_msk_value*(1-uncert_prob)
257
+
258
+ pixel_readout = network.pixel_fusion(pix_feat, visual_readout, this_sensory,
259
+ this_last_mask)
260
+ this_obj_mem = self._get_object_mem_by_ids(objects).unsqueeze(2)
261
+ readout_memory, aux_features = network.readout_query(pixel_readout, this_obj_mem)
262
+ for i, obj in enumerate(objects):
263
+ all_readout_mem[obj] = readout_memory[:, i]
264
+
265
+ if self.save_aux:
266
+ aux_output = {
267
+ # 'sensory': this_sensory,
268
+ # 'pixel_readout': pixel_readout,
269
+ 'q_logits': aux_features['logits'] if aux_features else None,
270
+ # 'q_weights': aux_features['q_weights'] if aux_features else None,
271
+ # 'p_weights': aux_features['p_weights'] if aux_features else None,
272
+ # 'attn_mask': aux_features['attn_mask'].float() if aux_features else None,
273
+ }
274
+ self.aux = aux_output
275
+
276
+ return all_readout_mem
277
+
278
+ def add_memory(self,
279
+ key: torch.Tensor,
280
+ shrinkage: torch.Tensor,
281
+ msk_value: torch.Tensor,
282
+ obj_value: torch.Tensor,
283
+ objects: List[int],
284
+ selection: torch.Tensor = None,
285
+ *,
286
+ as_permanent: bool = False) -> None:
287
+ # key: (1/2)*C*H*W
288
+ # msk_value: (1/2)*num_objects*C*H*W
289
+ # obj_value: (1/2)*num_objects*Q*C
290
+ # objects contains a list of object ids corresponding to the objects in msk_value/obj_value
291
+ bs = key.shape[0]
292
+ assert shrinkage.shape[0] == bs
293
+ assert msk_value.shape[0] == bs
294
+ assert obj_value.shape[0] == bs
295
+
296
+ self.engaged = True
297
+ if self.H is None or self.config_stale:
298
+ self.config_stale = False
299
+ self.H, self.W = msk_value.shape[-2:]
300
+ self.HW = self.H * self.W
301
+ # convert from num. frames to num. tokens
302
+ self.max_work_tokens = self.max_mem_frames * self.HW
303
+ if self.use_long_term:
304
+ self.min_work_tokens = self.min_mem_frames * self.HW
305
+
306
+ # key: bs*C*N
307
+ # value: bs*num_objects*C*N
308
+ key = key.flatten(start_dim=2)
309
+ shrinkage = shrinkage.flatten(start_dim=2)
310
+ self.CK = key.shape[1]
311
+
312
+ msk_value = msk_value.flatten(start_dim=3)
313
+ self.CV = msk_value.shape[2]
314
+
315
+ if selection is not None:
316
+ # not used in non-long-term mode
317
+ selection = selection.flatten(start_dim=2)
318
+
319
+ # insert object values into object memory
320
+ for obj_id, obj in enumerate(objects):
321
+ if obj in self.obj_v:
322
+ """streaming average
323
+ each self.obj_v[obj] is (1/2)*num_summaries*(embed_dim+1)
324
+ first embed_dim keeps track of the sum of embeddings
325
+ the last dim keeps the total count
326
+ averaging is done inside the object transformer
327
+
328
+ incoming obj_value is (1/2)*num_objects*num_summaries*(embed_dim+1)
329
+ self.obj_v[obj] = torch.cat([self.obj_v[obj], obj_value[:, obj_id]], dim=0)
330
+ """
331
+ last_acc = self.obj_v[obj][:, :, -1]
332
+ new_acc = last_acc + obj_value[:, obj_id, :, -1]
333
+
334
+ self.obj_v[obj][:, :, :-1] = (self.obj_v[obj][:, :, :-1] +
335
+ obj_value[:, obj_id, :, :-1])
336
+ self.obj_v[obj][:, :, -1] = new_acc
337
+ else:
338
+ self.obj_v[obj] = obj_value[:, obj_id]
339
+
340
+ # convert mask value tensor into a dict for insertion
341
+ msk_values = {obj: msk_value[:, obj_id] for obj_id, obj in enumerate(objects)}
342
+ self.work_mem.add(key,
343
+ msk_values,
344
+ shrinkage,
345
+ selection=selection,
346
+ as_permanent=as_permanent)
347
+
348
+ for bucket_id in self.work_mem.buckets.keys():
349
+ # long-term memory cleanup
350
+ if self.use_long_term:
351
+ # Do memory compressed if needed
352
+ if self.work_mem.non_perm_size(bucket_id) >= self.max_work_tokens:
353
+ # Remove obsolete features if needed
354
+ if self.long_mem.non_perm_size(bucket_id) >= (self.max_long_tokens -
355
+ self.num_prototypes):
356
+ self.long_mem.remove_obsolete_features(
357
+ bucket_id,
358
+ self.max_long_tokens - self.num_prototypes - self.buffer_tokens)
359
+
360
+ self.compress_features(bucket_id)
361
+ else:
362
+ # FIFO
363
+ self.work_mem.remove_old_memory(bucket_id, self.max_work_tokens)
364
+
365
+ def purge_except(self, obj_keep_idx: List[int]) -> None:
366
+ # purge all objects from the memory except the ones listed
367
+ self.work_mem.purge_except(obj_keep_idx)
368
+ if self.use_long_term and self.long_mem.engaged():
369
+ self.long_mem.purge_except(obj_keep_idx)
370
+ self.sensory = {k: v for k, v in self.sensory.items() if k in obj_keep_idx}
371
+
372
+ if not self.work_mem.engaged():
373
+ # everything is removed!
374
+ self.engaged = False
375
+
376
+ def compress_features(self, bucket_id: int) -> None:
377
+ HW = self.HW
378
+
379
+ # perform memory consolidation
380
+ prototype_key, prototype_value, prototype_shrinkage = self.consolidation(
381
+ *self.work_mem.get_all_sliced(bucket_id, 0, -self.min_work_tokens))
382
+
383
+ # remove consolidated working memory
384
+ self.work_mem.sieve_by_range(bucket_id,
385
+ 0,
386
+ -self.min_work_tokens,
387
+ min_size=self.min_work_tokens)
388
+
389
+ # add to long-term memory
390
+ self.long_mem.add(prototype_key,
391
+ prototype_value,
392
+ prototype_shrinkage,
393
+ selection=None,
394
+ supposed_bucket_id=bucket_id)
395
+
396
+ def consolidation(self, candidate_key: torch.Tensor, candidate_shrinkage: torch.Tensor,
397
+ candidate_selection: torch.Tensor, candidate_value: Dict[int, torch.Tensor],
398
+ usage: torch.Tensor) -> (torch.Tensor, Dict[int, torch.Tensor], torch.Tensor):
399
+ # find the indices with max usage
400
+ bs = candidate_key.shape[0]
401
+ assert bs in [1, 2]
402
+
403
+ prototype_key = []
404
+ prototype_selection = []
405
+ for bi in range(bs):
406
+ _, max_usage_indices = torch.topk(usage[bi], k=self.num_prototypes, dim=-1, sorted=True)
407
+ prototype_indices = max_usage_indices.flatten()
408
+ prototype_key.append(candidate_key[bi, :, prototype_indices])
409
+ prototype_selection.append(candidate_selection[bi, :, prototype_indices])
410
+ prototype_key = torch.stack(prototype_key, dim=0)
411
+ prototype_selection = torch.stack(prototype_selection, dim=0)
412
+ """
413
+ Potentiation step
414
+ """
415
+ similarity = get_similarity(candidate_key, candidate_shrinkage, prototype_key,
416
+ prototype_selection)
417
+ affinity = do_softmax(similarity)
418
+
419
+ # readout the values
420
+ prototype_value = {k: self._readout(affinity, v) for k, v in candidate_value.items()}
421
+
422
+ # readout the shrinkage term
423
+ prototype_shrinkage = self._readout(affinity, candidate_shrinkage)
424
+
425
+ return prototype_key, prototype_value, prototype_shrinkage
426
+
427
+ def initialize_sensory_if_needed(self, sample_key: torch.Tensor, ids: List[int]):
428
+ for obj in ids:
429
+ if obj not in self.sensory:
430
+ # also initializes the sensory memory
431
+ bs, _, h, w = sample_key.shape
432
+ self.sensory[obj] = torch.zeros((bs, self.sensory_dim, h, w),
433
+ device=sample_key.device)
434
+
435
+ def update_sensory(self, sensory: torch.Tensor, ids: List[int]):
436
+ # sensory: 1*num_objects*C*H*W
437
+ for obj_id, obj in enumerate(ids):
438
+ self.sensory[obj] = sensory[:, obj_id]
439
+
440
+ def get_sensory(self, ids: List[int]):
441
+ # returns (1/2)*num_objects*C*H*W
442
+ return self._get_sensory_by_ids(ids)
443
+
444
+ def clear_non_permanent_memory(self):
445
+ self.work_mem.clear_non_permanent_memory()
446
+ if self.use_long_term:
447
+ self.long_mem.clear_non_permanent_memory()
448
+
449
+ def clear_sensory_memory(self):
450
+ self.sensory = {}
451
+
452
+ def clear_work_mem(self):
453
+ self.work_mem = KeyValueMemoryStore(save_selection=self.use_long_term,
454
+ save_usage=self.use_long_term)
455
+
456
+ def clear_obj_mem(self):
457
+ self.obj_v = {}
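The frame-to-token conversion in `add_memory` is plain arithmetic; a small worked example (not part of the commit; the input resolution and `max_mem_frames` value are hypothetical):

    H, W = 30, 54                             # stride-16 feature size of a 480x864 input (hypothetical)
    HW = H * W                                # 1620 tokens per memory frame
    cfg_max_mem_frames = 5                    # hypothetical cfg.max_mem_frames
    max_mem_frames = cfg_max_mem_frames - 1   # the permanent first frame is not counted
    max_work_tokens = max_mem_frames * HW
    print(HW, max_work_tokens)                # 1620 6480 -- beyond this, FIFO removal or long-term consolidation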
matanyone/inference/object_info.py ADDED
@@ -0,0 +1,24 @@
1
+ class ObjectInfo:
2
+ """
3
+ Store meta information for an object
4
+ """
5
+ def __init__(self, id: int):
6
+ self.id = id
7
+ self.poke_count = 0 # count number of detections missed
8
+
9
+ def poke(self) -> None:
10
+ self.poke_count += 1
11
+
12
+ def unpoke(self) -> None:
13
+ self.poke_count = 0
14
+
15
+ def __hash__(self):
16
+ return hash(self.id)
17
+
18
+ def __eq__(self, other):
19
+ if type(other) == int:
20
+ return self.id == other
21
+ return self.id == other.id
22
+
23
+ def __repr__(self):
24
+ return f'(ID: {self.id})'
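Because `ObjectInfo` hashes and compares by `id`, a plain int can be used to look up dictionaries keyed by `ObjectInfo`. A short sketch (not part of the commit; assumes the matanyone package is importable):

    from matanyone.inference.object_info import ObjectInfo

    obj = ObjectInfo(id=3)
    table = {obj: 'tmp_id 1'}
    print(3 in table)   # True: hash(3) == hash(obj) and ObjectInfo.__eq__ accepts ints
    print(table[3])     # 'tmp_id 1'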
matanyone/inference/object_manager.py ADDED
@@ -0,0 +1,149 @@
1
+ from typing import Union, List, Dict
2
+
3
+ import torch
4
+ from matanyone.inference.object_info import ObjectInfo
5
+
6
+
7
+ class ObjectManager:
8
+ """
9
+ Object IDs are immutable. The same ID always represent the same object.
10
+ Temporary IDs are the positions of each object in the tensor. It changes as objects get removed.
11
+ Temporary IDs start from 1.
12
+ """
13
+
14
+ def __init__(self):
15
+ self.obj_to_tmp_id: Dict[ObjectInfo, int] = {}
16
+ self.tmp_id_to_obj: Dict[int, ObjectInfo] = {}
17
+ self.obj_id_to_obj: Dict[int, ObjectInfo] = {}
18
+
19
+ self.all_historical_object_ids: List[int] = []
20
+
21
+ def _recompute_obj_id_to_obj_mapping(self) -> None:
22
+ self.obj_id_to_obj = {obj.id: obj for obj in self.obj_to_tmp_id}
23
+
24
+ def add_new_objects(
25
+ self, objects: Union[List[ObjectInfo], ObjectInfo,
26
+ List[int]]) -> (List[int], List[int]):
27
+ if not isinstance(objects, list):
28
+ objects = [objects]
29
+
30
+ corresponding_tmp_ids = []
31
+ corresponding_obj_ids = []
32
+ for obj in objects:
33
+ if isinstance(obj, int):
34
+ obj = ObjectInfo(id=obj)
35
+
36
+ if obj in self.obj_to_tmp_id:
37
+ # old object
38
+ corresponding_tmp_ids.append(self.obj_to_tmp_id[obj])
39
+ corresponding_obj_ids.append(obj.id)
40
+ else:
41
+ # new object
42
+ new_obj = ObjectInfo(id=obj.id)
43
+
44
+ # new object
45
+ new_tmp_id = len(self.obj_to_tmp_id) + 1
46
+ self.obj_to_tmp_id[new_obj] = new_tmp_id
47
+ self.tmp_id_to_obj[new_tmp_id] = new_obj
48
+ self.all_historical_object_ids.append(new_obj.id)
49
+ corresponding_tmp_ids.append(new_tmp_id)
50
+ corresponding_obj_ids.append(new_obj.id)
51
+
52
+ self._recompute_obj_id_to_obj_mapping()
53
+ assert corresponding_tmp_ids == sorted(corresponding_tmp_ids)
54
+ return corresponding_tmp_ids, corresponding_obj_ids
55
+
56
+ def delete_objects(self, obj_ids_to_remove: Union[int, List[int]]) -> None:
57
+ # delete an object or a list of objects
58
+ # re-sort the tmp ids
59
+ if isinstance(obj_ids_to_remove, int):
60
+ obj_ids_to_remove = [obj_ids_to_remove]
61
+
62
+ new_tmp_id = 1
63
+ total_num_id = len(self.obj_to_tmp_id)
64
+
65
+ local_obj_to_tmp_id = {}
66
+ local_tmp_to_obj_id = {}
67
+
68
+ for tmp_iter in range(1, total_num_id + 1):
69
+ obj = self.tmp_id_to_obj[tmp_iter]
70
+ if obj.id not in obj_ids_to_remove:
71
+ local_obj_to_tmp_id[obj] = new_tmp_id
72
+ local_tmp_to_obj_id[new_tmp_id] = obj
73
+ new_tmp_id += 1
74
+
75
+ self.obj_to_tmp_id = local_obj_to_tmp_id
76
+ self.tmp_id_to_obj = local_tmp_to_obj_id
77
+ self._recompute_obj_id_to_obj_mapping()
78
+
79
+ def purge_inactive_objects(self,
80
+ max_missed_detection_count: int) -> (bool, List[int], List[int]):
81
+ # remove tmp ids of objects that are removed
82
+ obj_id_to_be_deleted = []
83
+ tmp_id_to_be_deleted = []
84
+ tmp_id_to_keep = []
85
+ obj_id_to_keep = []
86
+
87
+ for obj in self.obj_to_tmp_id:
88
+ if obj.poke_count > max_missed_detection_count:
89
+ obj_id_to_be_deleted.append(obj.id)
90
+ tmp_id_to_be_deleted.append(self.obj_to_tmp_id[obj])
91
+ else:
92
+ tmp_id_to_keep.append(self.obj_to_tmp_id[obj])
93
+ obj_id_to_keep.append(obj.id)
94
+
95
+ purge_activated = len(obj_id_to_be_deleted) > 0
96
+ if purge_activated:
97
+ self.delete_objects(obj_id_to_be_deleted)
98
+ return purge_activated, tmp_id_to_keep, obj_id_to_keep
99
+
100
+ def tmp_to_obj_cls(self, mask) -> torch.Tensor:
101
+ # remap tmp id cls representation to the true object id representation
102
+ new_mask = torch.zeros_like(mask)
103
+ for tmp_id, obj in self.tmp_id_to_obj.items():
104
+ new_mask[mask == tmp_id] = obj.id
105
+ return new_mask
106
+
107
+ def get_tmp_to_obj_mapping(self) -> Dict[int, ObjectInfo]:
108
+ # returns the mapping in a dict format for saving it with pickle
109
+ return {obj.id: tmp_id for tmp_id, obj in self.tmp_id_to_obj.items()}
110
+
111
+ def realize_dict(self, obj_dict, dim=1) -> torch.Tensor:
112
+ # turns a dict indexed by obj id into a tensor, ordered by tmp IDs
113
+ output = []
114
+ for _, obj in self.tmp_id_to_obj.items():
115
+ if obj.id not in obj_dict:
116
+ raise NotImplementedError
117
+ output.append(obj_dict[obj.id])
118
+ output = torch.stack(output, dim=dim)
119
+ return output
120
+
121
+ def make_one_hot(self, cls_mask) -> torch.Tensor:
122
+ output = []
123
+ for _, obj in self.tmp_id_to_obj.items():
124
+ output.append(cls_mask == obj.id)
125
+ if len(output) == 0:
126
+ output = torch.zeros((0, *cls_mask.shape), dtype=torch.bool, device=cls_mask.device)
127
+ else:
128
+ output = torch.stack(output, dim=0)
129
+ return output
130
+
131
+ @property
132
+ def all_obj_ids(self) -> List[int]:
133
+ return [k.id for k in self.obj_to_tmp_id]
134
+
135
+ @property
136
+ def num_obj(self) -> int:
137
+ return len(self.obj_to_tmp_id)
138
+
139
+ def has_all(self, objects: List[int]) -> bool:
140
+ for obj in objects:
141
+ if obj not in self.obj_to_tmp_id:
142
+ return False
143
+ return True
144
+
145
+ def find_object_by_id(self, obj_id) -> ObjectInfo:
146
+ return self.obj_id_to_obj[obj_id]
147
+
148
+ def find_tmp_by_id(self, obj_id) -> int:
149
+ return self.obj_to_tmp_id[self.obj_id_to_obj[obj_id]]
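A short usage sketch of the id bookkeeping above (not part of the commit; assumes the matanyone package is importable, object ids are made up): object ids never change, while temporary ids are re-packed to stay contiguous after a deletion.

    from matanyone.inference.object_manager import ObjectManager

    om = ObjectManager()
    om.add_new_objects([10, 20, 30])                        # tmp ids 1, 2, 3
    om.delete_objects(20)
    print(om.all_obj_ids)                                   # [10, 30]
    print({o.id: t for o, t in om.obj_to_tmp_id.items()})   # {10: 1, 30: 2} -- tmp ids re-packed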
matanyone/inference/utils/__init__.py ADDED
File without changes
matanyone/inference/utils/args_utils.py ADDED
@@ -0,0 +1,30 @@
1
+ import logging
2
+ from omegaconf import DictConfig
3
+
4
+ log = logging.getLogger()
5
+
6
+
7
+ def get_dataset_cfg(cfg: DictConfig):
8
+ dataset_name = cfg.dataset
9
+ data_cfg = cfg.datasets[dataset_name]
10
+
11
+ potential_overrides = [
12
+ 'image_directory',
13
+ 'mask_directory',
14
+ 'json_directory',
15
+ 'size',
16
+ 'save_all',
17
+ 'use_all_masks',
18
+ 'use_long_term',
19
+ 'mem_every',
20
+ ]
21
+
22
+ for override in potential_overrides:
23
+ if cfg[override] is not None:
24
+ log.info(f'Overriding config {override} from {data_cfg[override]} to {cfg[override]}')
25
+ data_cfg[override] = cfg[override]
26
+ # escalate all potential overrides to the top-level config
27
+ if override in data_cfg:
28
+ cfg[override] = data_cfg[override]
29
+
30
+ return data_cfg
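The override rule above (a top-level null falls back to the per-dataset value, a set value wins and is escalated) can be reproduced on a toy config. A minimal sketch, not part of the commit; the keys and values below are made up:

    from omegaconf import OmegaConf

    cfg = OmegaConf.create({
        'dataset': 'demo',
        'size': 720,            # top-level override: takes effect
        'mem_every': None,      # null: the per-dataset value is kept and escalated
        'datasets': {'demo': {'size': 480, 'mem_every': 5}},
    })
    for key in ['size', 'mem_every']:
        if cfg[key] is not None:
            cfg.datasets.demo[key] = cfg[key]
        cfg[key] = cfg.datasets.demo[key]
    print(cfg.datasets.demo.size, cfg.mem_every)   # 720 5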
matanyone/model/__init__.py ADDED
File without changes
matanyone/model/aux_modules.py ADDED
@@ -0,0 +1,93 @@
1
+ """
2
+ For computing auxiliary outputs for auxiliary losses
3
+ """
4
+ from typing import Dict
5
+ from omegaconf import DictConfig
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from matanyone.model.group_modules import GConv2d
10
+ from matanyone.utils.tensor_utils import aggregate
11
+
12
+
13
+ class LinearPredictor(nn.Module):
14
+ def __init__(self, x_dim: int, pix_dim: int):
15
+ super().__init__()
16
+ self.projection = GConv2d(x_dim, pix_dim + 1, kernel_size=1)
17
+
18
+ def forward(self, pix_feat: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
19
+ # pixel_feat: B*pix_dim*H*W
20
+ # x: B*num_objects*x_dim*H*W
21
+ num_objects = x.shape[1]
22
+ x = self.projection(x)
23
+
24
+ pix_feat = pix_feat.unsqueeze(1).expand(-1, num_objects, -1, -1, -1)
25
+ logits = (pix_feat * x[:, :, :-1]).sum(dim=2) + x[:, :, -1]
26
+ return logits
27
+
28
+
29
+ class DirectPredictor(nn.Module):
30
+ def __init__(self, x_dim: int):
31
+ super().__init__()
32
+ self.projection = GConv2d(x_dim, 1, kernel_size=1)
33
+
34
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
35
+ # x: B*num_objects*x_dim*H*W
36
+ logits = self.projection(x).squeeze(2)
37
+ return logits
38
+
39
+
40
+ class AuxComputer(nn.Module):
41
+ def __init__(self, cfg: DictConfig):
42
+ super().__init__()
43
+
44
+ use_sensory_aux = cfg.model.aux_loss.sensory.enabled
45
+ self.use_query_aux = cfg.model.aux_loss.query.enabled
46
+ self.use_sensory_aux = use_sensory_aux
47
+
48
+ sensory_dim = cfg.model.sensory_dim
49
+ embed_dim = cfg.model.embed_dim
50
+
51
+ if use_sensory_aux:
52
+ self.sensory_aux = LinearPredictor(sensory_dim, embed_dim)
53
+
54
+ def _aggregate_with_selector(self, logits: torch.Tensor, selector: torch.Tensor) -> torch.Tensor:
55
+ prob = torch.sigmoid(logits)
56
+ if selector is not None:
57
+ prob = prob * selector
58
+ logits = aggregate(prob, dim=1)
59
+ return logits
60
+
61
+ def forward(self, pix_feat: torch.Tensor, aux_input: Dict[str, torch.Tensor],
62
+ selector: torch.Tensor, seg_pass=False) -> Dict[str, torch.Tensor]:
63
+ sensory = aux_input['sensory']
64
+ q_logits = aux_input['q_logits']
65
+
66
+ aux_output = {}
67
+ aux_output['attn_mask'] = aux_input['attn_mask']
68
+
69
+ if self.use_sensory_aux:
70
+ # B*num_objects*H*W
71
+ logits = self.sensory_aux(pix_feat, sensory)
72
+ aux_output['sensory_logits'] = self._aggregate_with_selector(logits, selector)
73
+ if self.use_query_aux:
74
+ # B*num_objects*num_levels*H*W
75
+ aux_output['q_logits'] = self._aggregate_with_selector(
76
+ torch.stack(q_logits, dim=2),
77
+ selector.unsqueeze(2) if selector is not None else None)
78
+
79
+ return aux_output
80
+
81
+ def compute_mask(self, aux_input: Dict[str, torch.Tensor],
82
+ selector: torch.Tensor) -> Dict[str, torch.Tensor]:
83
+ # sensory = aux_input['sensory']
84
+ q_logits = aux_input['q_logits']
85
+
86
+ aux_output = {}
87
+
88
+ # B*num_objects*num_levels*H*W
89
+ aux_output['q_logits'] = self._aggregate_with_selector(
90
+ torch.stack(q_logits, dim=2),
91
+ selector.unsqueeze(2) if selector is not None else None)
92
+
93
+ return aux_output
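For context on `_aggregate_with_selector`: the `aggregate` helper it calls lives in matanyone/utils/tensor_utils.py and is not shown in this diff. The sketch below (not part of the commit, an assumption about a typical soft-aggregation formulation) illustrates the kind of operation being applied: per-object probabilities gain a derived background channel and are converted back to logits.

    import torch

    def soft_aggregate(prob: torch.Tensor, dim: int) -> torch.Tensor:
        # prob: per-object probabilities in [0, 1]; adds a background channel, returns logits
        bg = torch.prod(1 - prob, dim=dim, keepdim=True)
        new_prob = torch.cat([bg, prob], dim=dim).clamp(1e-7, 1 - 1e-7)
        return torch.log(new_prob / (1 - new_prob))

    p = torch.sigmoid(torch.randn(1, 2, 8, 8))     # B x num_objects x H x W
    print(soft_aggregate(p, dim=1).shape)          # torch.Size([1, 3, 8, 8]) -- background added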
matanyone/model/big_modules.py ADDED
@@ -0,0 +1,358 @@
1
+ """
2
+ big_modules.py - This file stores higher-level network blocks.
3
+
4
+ x - usually denotes features that are shared between objects.
5
+ g - usually denotes features that are not shared between objects
6
+ with an extra "num_objects" dimension (batch_size * num_objects * num_channels * H * W).
7
+
8
+ The trailing number of a variable usually denotes the stride
9
+ """
10
+
11
+ from omegaconf import DictConfig
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+
16
+ from matanyone.model.group_modules import *
17
+ from matanyone.model.utils import resnet
18
+ from matanyone.model.modules import *
19
+
20
+ class UncertPred(nn.Module):
21
+ def __init__(self, model_cfg: DictConfig):
22
+ super().__init__()
23
+ self.conv1x1_v2 = nn.Conv2d(model_cfg.pixel_dim*2 + 1 + model_cfg.value_dim, 64, kernel_size=1, stride=1, bias=False)
24
+ self.bn1 = nn.BatchNorm2d(64)
25
+ self.relu = nn.ReLU(inplace=True)
26
+ self.conv3x3 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1, groups=1, bias=False, dilation=1)
27
+ self.bn2 = nn.BatchNorm2d(32)
28
+ self.conv3x3_out = nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1, groups=1, bias=False, dilation=1)
29
+
30
+ def forward(self, last_frame_feat: torch.Tensor, cur_frame_feat: torch.Tensor, last_mask: torch.Tensor, mem_val_diff:torch.Tensor):
31
+ last_mask = F.interpolate(last_mask, size=last_frame_feat.shape[-2:], mode='area')
32
+ x = torch.cat([last_frame_feat, cur_frame_feat, last_mask, mem_val_diff], dim=1)
33
+ x = self.conv1x1_v2(x)
34
+ x = self.bn1(x)
35
+ x = self.relu(x)
36
+ x = self.conv3x3(x)
37
+ x = self.bn2(x)
38
+ x = self.relu(x)
39
+ x = self.conv3x3_out(x)
40
+ return x
41
+
42
+ # override the default train() to freeze BN statistics
43
+ def train(self, mode=True):
44
+ self.training = False
45
+ for module in self.children():
46
+ module.train(False)
47
+ return self
48
+
49
+ class PixelEncoder(nn.Module):
50
+ def __init__(self, model_cfg: DictConfig):
51
+ super().__init__()
52
+
53
+ self.is_resnet = 'resnet' in model_cfg.pixel_encoder.type
54
+ if self.is_resnet:
55
+ if model_cfg.pixel_encoder.type == 'resnet18':
56
+ network = resnet.resnet18(pretrained=True)
57
+ elif model_cfg.pixel_encoder.type == 'resnet50':
58
+ network = resnet.resnet50(pretrained=True)
59
+ else:
60
+ raise NotImplementedError
61
+ self.conv1 = network.conv1
62
+ self.bn1 = network.bn1
63
+ self.relu = network.relu
64
+ self.maxpool = network.maxpool
65
+
66
+ self.res2 = network.layer1
67
+ self.layer2 = network.layer2
68
+ self.layer3 = network.layer3
69
+ else:
70
+ raise NotImplementedError
71
+
72
+ def forward(self, x: torch.Tensor, seq_length=None) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
73
+ f1 = x
74
+ x = self.conv1(x)
75
+ x = self.bn1(x)
76
+ x = self.relu(x)
77
+ f2 = x
78
+ x = self.maxpool(x)
79
+ f4 = self.res2(x)
80
+ f8 = self.layer2(f4)
81
+ f16 = self.layer3(f8)
82
+
83
+ return f16, f8, f4, f2, f1
84
+
85
+ # override the default train() to freeze BN statistics
86
+ def train(self, mode=True):
87
+ self.training = False
88
+ for module in self.children():
89
+ module.train(False)
90
+ return self
91
+
92
+
93
+ class KeyProjection(nn.Module):
94
+ def __init__(self, model_cfg: DictConfig):
95
+ super().__init__()
96
+ in_dim = model_cfg.pixel_encoder.ms_dims[0]
97
+ mid_dim = model_cfg.pixel_dim
98
+ key_dim = model_cfg.key_dim
99
+
100
+ self.pix_feat_proj = nn.Conv2d(in_dim, mid_dim, kernel_size=1)
101
+ self.key_proj = nn.Conv2d(mid_dim, key_dim, kernel_size=3, padding=1)
102
+ # shrinkage
103
+ self.d_proj = nn.Conv2d(mid_dim, 1, kernel_size=3, padding=1)
104
+ # selection
105
+ self.e_proj = nn.Conv2d(mid_dim, key_dim, kernel_size=3, padding=1)
106
+
107
+ nn.init.orthogonal_(self.key_proj.weight.data)
108
+ nn.init.zeros_(self.key_proj.bias.data)
109
+
110
+ def forward(self, x: torch.Tensor, *, need_s: bool,
111
+ need_e: bool) -> (torch.Tensor, torch.Tensor, torch.Tensor):
112
+ x = self.pix_feat_proj(x)
113
+ shrinkage = self.d_proj(x)**2 + 1 if (need_s) else None
114
+ selection = torch.sigmoid(self.e_proj(x)) if (need_e) else None
115
+
116
+ return self.key_proj(x), shrinkage, selection
117
+
118
+
119
+ class MaskEncoder(nn.Module):
120
+ def __init__(self, model_cfg: DictConfig, single_object=False):
121
+ super().__init__()
122
+ pixel_dim = model_cfg.pixel_dim
123
+ value_dim = model_cfg.value_dim
124
+ sensory_dim = model_cfg.sensory_dim
125
+ final_dim = model_cfg.mask_encoder.final_dim
126
+
127
+ self.single_object = single_object
128
+ extra_dim = 1 if single_object else 2
129
+
130
+ if model_cfg.mask_encoder.type == 'resnet18':
131
+ network = resnet.resnet18(pretrained=True, extra_dim=extra_dim)
132
+ elif model_cfg.mask_encoder.type == 'resnet50':
133
+ network = resnet.resnet50(pretrained=True, extra_dim=extra_dim)
134
+ else:
135
+ raise NotImplementedError
136
+ self.conv1 = network.conv1
137
+ self.bn1 = network.bn1
138
+ self.relu = network.relu
139
+ self.maxpool = network.maxpool
140
+
141
+ self.layer1 = network.layer1
142
+ self.layer2 = network.layer2
143
+ self.layer3 = network.layer3
144
+
145
+ self.distributor = MainToGroupDistributor()
146
+ self.fuser = GroupFeatureFusionBlock(pixel_dim, final_dim, value_dim)
147
+
148
+ self.sensory_update = SensoryDeepUpdater(value_dim, sensory_dim)
149
+
150
+ def forward(self,
151
+ image: torch.Tensor,
152
+ pix_feat: torch.Tensor,
153
+ sensory: torch.Tensor,
154
+ masks: torch.Tensor,
155
+ others: torch.Tensor,
156
+ *,
157
+ deep_update: bool = True,
158
+ chunk_size: int = -1) -> (torch.Tensor, torch.Tensor):
159
+ # ms_features are from the key encoder
160
+ # we only use the first one (lowest resolution), following XMem
161
+ if self.single_object:
162
+ g = masks.unsqueeze(2)
163
+ else:
164
+ g = torch.stack([masks, others], dim=2)
165
+
166
+ g = self.distributor(image, g)
167
+
168
+ batch_size, num_objects = g.shape[:2]
169
+ if chunk_size < 1 or chunk_size >= num_objects:
170
+ chunk_size = num_objects
171
+ fast_path = True
172
+ new_sensory = sensory
173
+ else:
174
+ if deep_update:
175
+ new_sensory = torch.empty_like(sensory)
176
+ else:
177
+ new_sensory = sensory
178
+ fast_path = False
179
+
180
+ # chunk-by-chunk inference
181
+ all_g = []
182
+ for i in range(0, num_objects, chunk_size):
183
+ if fast_path:
184
+ g_chunk = g
185
+ else:
186
+ g_chunk = g[:, i:i + chunk_size]
187
+ actual_chunk_size = g_chunk.shape[1]
188
+ g_chunk = g_chunk.flatten(start_dim=0, end_dim=1)
189
+
190
+ g_chunk = self.conv1(g_chunk)
191
+ g_chunk = self.bn1(g_chunk) # 1/2, 64
192
+ g_chunk = self.maxpool(g_chunk) # 1/4, 64
193
+ g_chunk = self.relu(g_chunk)
194
+
195
+ g_chunk = self.layer1(g_chunk) # 1/4
196
+ g_chunk = self.layer2(g_chunk) # 1/8
197
+ g_chunk = self.layer3(g_chunk) # 1/16
198
+
199
+ g_chunk = g_chunk.view(batch_size, actual_chunk_size, *g_chunk.shape[1:])
200
+ g_chunk = self.fuser(pix_feat, g_chunk)
201
+ all_g.append(g_chunk)
202
+ if deep_update:
203
+ if fast_path:
204
+ new_sensory = self.sensory_update(g_chunk, sensory)
205
+ else:
206
+ new_sensory[:, i:i + chunk_size] = self.sensory_update(
207
+ g_chunk, sensory[:, i:i + chunk_size])
208
+ g = torch.cat(all_g, dim=1)
209
+
210
+ return g, new_sensory
211
+
212
+ # override the default train() to freeze BN statistics
213
+ def train(self, mode=True):
214
+ self.training = False
215
+ for module in self.children():
216
+ module.train(False)
217
+ return self
218
+
219
+
220
+ class PixelFeatureFuser(nn.Module):
221
+ def __init__(self, model_cfg: DictConfig, single_object=False):
222
+ super().__init__()
223
+ value_dim = model_cfg.value_dim
224
+ sensory_dim = model_cfg.sensory_dim
225
+ pixel_dim = model_cfg.pixel_dim
226
+ embed_dim = model_cfg.embed_dim
227
+ self.single_object = single_object
228
+
229
+ self.fuser = GroupFeatureFusionBlock(pixel_dim, value_dim, embed_dim)
230
+ if self.single_object:
231
+ self.sensory_compress = GConv2d(sensory_dim + 1, value_dim, kernel_size=1)
232
+ else:
233
+ self.sensory_compress = GConv2d(sensory_dim + 2, value_dim, kernel_size=1)
234
+
235
+ def forward(self,
236
+ pix_feat: torch.Tensor,
237
+ pixel_memory: torch.Tensor,
238
+ sensory_memory: torch.Tensor,
239
+ last_mask: torch.Tensor,
240
+ last_others: torch.Tensor,
241
+ *,
242
+ chunk_size: int = -1) -> torch.Tensor:
243
+ batch_size, num_objects = pixel_memory.shape[:2]
244
+
245
+ if self.single_object:
246
+ last_mask = last_mask.unsqueeze(2)
247
+ else:
248
+ last_mask = torch.stack([last_mask, last_others], dim=2)
249
+
250
+ if chunk_size < 1:
251
+ chunk_size = num_objects
252
+
253
+ # chunk-by-chunk inference
254
+ all_p16 = []
255
+ for i in range(0, num_objects, chunk_size):
256
+ sensory_readout = self.sensory_compress(
257
+ torch.cat([sensory_memory[:, i:i + chunk_size], last_mask[:, i:i + chunk_size]], 2))
258
+ p16 = pixel_memory[:, i:i + chunk_size] + sensory_readout
259
+ p16 = self.fuser(pix_feat, p16)
260
+ all_p16.append(p16)
261
+ p16 = torch.cat(all_p16, dim=1)
262
+
263
+ return p16
264
+
265
+
266
+ class MaskDecoder(nn.Module):
267
+ def __init__(self, model_cfg: DictConfig):
268
+ super().__init__()
269
+ embed_dim = model_cfg.embed_dim
270
+ sensory_dim = model_cfg.sensory_dim
271
+ ms_image_dims = model_cfg.pixel_encoder.ms_dims
272
+ up_dims = model_cfg.mask_decoder.up_dims
273
+
274
+ assert embed_dim == up_dims[0]
275
+
276
+ self.sensory_update = SensoryUpdater_fullscale([up_dims[0], up_dims[1], up_dims[2], up_dims[3], up_dims[4] + 1], sensory_dim,
277
+ sensory_dim)
278
+
279
+ self.decoder_feat_proc = DecoderFeatureProcessor(ms_image_dims[1:], up_dims[:-1])
280
+ self.up_16_8 = MaskUpsampleBlock(up_dims[0], up_dims[1])
281
+ self.up_8_4 = MaskUpsampleBlock(up_dims[1], up_dims[2])
282
+ # newly add for alpha matte
283
+ self.up_4_2 = MaskUpsampleBlock(up_dims[2], up_dims[3])
284
+ self.up_2_1 = MaskUpsampleBlock(up_dims[3], up_dims[4])
285
+
286
+ self.pred_seg = nn.Conv2d(up_dims[-1], 1, kernel_size=3, padding=1)
287
+ self.pred_mat = nn.Conv2d(up_dims[-1], 1, kernel_size=3, padding=1)
288
+
289
+ def forward(self,
290
+ ms_image_feat: Iterable[torch.Tensor],
291
+ memory_readout: torch.Tensor,
292
+ sensory: torch.Tensor,
293
+ *,
294
+ chunk_size: int = -1,
295
+ update_sensory: bool = True,
296
+ seg_pass: bool = False,
297
+ last_mask=None,
298
+ sigmoid_residual=False) -> (torch.Tensor, torch.Tensor):
299
+
300
+ batch_size, num_objects = memory_readout.shape[:2]
301
+ f8, f4, f2, f1 = self.decoder_feat_proc(ms_image_feat[1:])
302
+ if chunk_size < 1 or chunk_size >= num_objects:
303
+ chunk_size = num_objects
304
+ fast_path = True
305
+ new_sensory = sensory
306
+ else:
307
+ if update_sensory:
308
+ new_sensory = torch.empty_like(sensory)
309
+ else:
310
+ new_sensory = sensory
311
+ fast_path = False
312
+
313
+ # chunk-by-chunk inference
314
+ all_logits = []
315
+ for i in range(0, num_objects, chunk_size):
316
+ if fast_path:
317
+ p16 = memory_readout
318
+ else:
319
+ p16 = memory_readout[:, i:i + chunk_size]
320
+ actual_chunk_size = p16.shape[1]
321
+
322
+ p8 = self.up_16_8(p16, f8)
323
+ p4 = self.up_8_4(p8, f4)
324
+ p2 = self.up_4_2(p4, f2)
325
+ p1 = self.up_2_1(p2, f1)
326
+ with torch.cuda.amp.autocast(enabled=False):
327
+ if seg_pass:
328
+ if last_mask is not None:
329
+ res = self.pred_seg(F.relu(p1.flatten(start_dim=0, end_dim=1).float()))
330
+ if sigmoid_residual:
331
+ res = (torch.sigmoid(res) - 0.5) * 2 # regularization: (-1, 1) change on last mask
332
+ logits = last_mask + res
333
+ else:
334
+ logits = self.pred_seg(F.relu(p1.flatten(start_dim=0, end_dim=1).float()))
335
+ else:
336
+ if last_mask is not None:
337
+ res = self.pred_mat(F.relu(p1.flatten(start_dim=0, end_dim=1).float()))
338
+ if sigmoid_residual:
339
+ res = (torch.sigmoid(res) - 0.5) * 2 # regularization: (-1, 1) change on last mask
340
+ logits = last_mask + res
341
+ else:
342
+ logits = self.pred_mat(F.relu(p1.flatten(start_dim=0, end_dim=1).float()))
343
+ ## SensoryUpdater_fullscale
344
+ if update_sensory:
345
+ p1 = torch.cat(
346
+ [p1, logits.view(batch_size, actual_chunk_size, 1, *logits.shape[-2:])], 2)
347
+ if fast_path:
348
+ new_sensory = self.sensory_update([p16, p8, p4, p2, p1], sensory)
349
+ else:
350
+ new_sensory[:,
351
+ i:i + chunk_size] = self.sensory_update([p16, p8, p4, p2, p1],
352
+ sensory[:,
353
+ i:i + chunk_size])
354
+ all_logits.append(logits)
355
+ logits = torch.cat(all_logits, dim=0)
356
+ logits = logits.view(batch_size, num_objects, *logits.shape[-2:])
357
+
358
+ return new_sensory, logits
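The sigmoid-residual branch in `MaskDecoder.forward` bounds the per-frame update around the previous mask rather than predicting it from scratch. A tiny numeric sketch (not part of the commit; values are made up):

    import torch

    raw = torch.tensor([-5.0, 0.0, 5.0])
    res = (torch.sigmoid(raw) - 0.5) * 2          # squashed to the open interval (-1, 1)
    last_mask = torch.tensor([0.2, 0.5, 0.9])
    print(last_mask + res)                        # bounded changes on top of the last mask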
matanyone/model/channel_attn.py ADDED
@@ -0,0 +1,39 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+
7
+ class CAResBlock(nn.Module):
8
+ def __init__(self, in_dim: int, out_dim: int, residual: bool = True):
9
+ super().__init__()
10
+ self.residual = residual
11
+ self.conv1 = nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1)
12
+ self.conv2 = nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1)
13
+
14
+ t = int((abs(math.log2(out_dim)) + 1) // 2)
15
+ k = t if t % 2 else t + 1
16
+ self.pool = nn.AdaptiveAvgPool2d(1)
17
+ self.conv = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False)
18
+
19
+ if self.residual:
20
+ if in_dim == out_dim:
21
+ self.downsample = nn.Identity()
22
+ else:
23
+ self.downsample = nn.Conv2d(in_dim, out_dim, kernel_size=1)
24
+
25
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
26
+ r = x
27
+ x = self.conv1(F.relu(x))
28
+ x = self.conv2(F.relu(x))
29
+
30
+ b, c = x.shape[:2]
31
+ w = self.pool(x).view(b, 1, c)
32
+ w = self.conv(w).transpose(-1, -2).unsqueeze(-1).sigmoid() # B*C*1*1
33
+
34
+ if self.residual:
35
+ x = x * w + self.downsample(r)
36
+ else:
37
+ x = x * w
38
+
39
+ return x
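A minimal usage sketch for the channel-attention residual block above (illustrative dimensions; assumes the file is importable as `matanyone.model.channel_attn`):

```python
import torch
from matanyone.model.channel_attn import CAResBlock

# CAResBlock keeps the spatial size and maps in_dim -> out_dim; a 1x1 shortcut
# conv is created automatically when the two dims differ.
block = CAResBlock(in_dim=64, out_dim=128)
x = torch.randn(2, 64, 32, 32)          # B * C * H * W feature map
y = block(x)                            # conv output gated by per-channel attention weights
print(y.shape)                          # torch.Size([2, 128, 32, 32])
```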
matanyone/model/group_modules.py ADDED
@@ -0,0 +1,126 @@
1
+ from typing import Optional
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from matanyone.model.channel_attn import CAResBlock
6
+
7
+ def interpolate_groups(g: torch.Tensor, ratio: float, mode: str,
8
+ align_corners: bool) -> torch.Tensor:
9
+ batch_size, num_objects = g.shape[:2]
10
+ g = F.interpolate(g.flatten(start_dim=0, end_dim=1),
11
+ scale_factor=ratio,
12
+ mode=mode,
13
+ align_corners=align_corners)
14
+ g = g.view(batch_size, num_objects, *g.shape[1:])
15
+ return g
16
+
17
+
18
+ def upsample_groups(g: torch.Tensor,
19
+ ratio: float = 2,
20
+ mode: str = 'bilinear',
21
+ align_corners: bool = False) -> torch.Tensor:
22
+ return interpolate_groups(g, ratio, mode, align_corners)
23
+
24
+
25
+ def downsample_groups(g: torch.Tensor,
26
+ ratio: float = 1 / 2,
27
+ mode: str = 'area',
28
+ align_corners: bool = None) -> torch.Tensor:
29
+ return interpolate_groups(g, ratio, mode, align_corners)
30
+
31
+
32
+ class GConv2d(nn.Conv2d):
33
+ def forward(self, g: torch.Tensor) -> torch.Tensor:
34
+ batch_size, num_objects = g.shape[:2]
35
+ g = super().forward(g.flatten(start_dim=0, end_dim=1))
36
+ return g.view(batch_size, num_objects, *g.shape[1:])
37
+
38
+
39
+ class GroupResBlock(nn.Module):
40
+ def __init__(self, in_dim: int, out_dim: int):
41
+ super().__init__()
42
+
43
+ if in_dim == out_dim:
44
+ self.downsample = nn.Identity()
45
+ else:
46
+ self.downsample = GConv2d(in_dim, out_dim, kernel_size=1)
47
+
48
+ self.conv1 = GConv2d(in_dim, out_dim, kernel_size=3, padding=1)
49
+ self.conv2 = GConv2d(out_dim, out_dim, kernel_size=3, padding=1)
50
+
51
+ def forward(self, g: torch.Tensor) -> torch.Tensor:
52
+ out_g = self.conv1(F.relu(g))
53
+ out_g = self.conv2(F.relu(out_g))
54
+
55
+ g = self.downsample(g)
56
+
57
+ return out_g + g
58
+
59
+
60
+ class MainToGroupDistributor(nn.Module):
61
+ def __init__(self,
62
+ x_transform: Optional[nn.Module] = None,
63
+ g_transform: Optional[nn.Module] = None,
64
+ method: str = 'cat',
65
+ reverse_order: bool = False):
66
+ super().__init__()
67
+
68
+ self.x_transform = x_transform
69
+ self.g_transform = g_transform
70
+ self.method = method
71
+ self.reverse_order = reverse_order
72
+
73
+ def forward(self, x: torch.Tensor, g: torch.Tensor, skip_expand: bool = False) -> torch.Tensor:
74
+ num_objects = g.shape[1]
75
+
76
+ if self.x_transform is not None:
77
+ x = self.x_transform(x)
78
+
79
+ if self.g_transform is not None:
80
+ g = self.g_transform(g)
81
+
82
+ if not skip_expand:
83
+ x = x.unsqueeze(1).expand(-1, num_objects, -1, -1, -1)
84
+ if self.method == 'cat':
85
+ if self.reverse_order:
86
+ g = torch.cat([g, x], 2)
87
+ else:
88
+ g = torch.cat([x, g], 2)
89
+ elif self.method == 'add':
90
+ g = x + g
91
+ elif self.method == 'mulcat':
92
+ g = torch.cat([x * g, g], dim=2)
93
+ elif self.method == 'muladd':
94
+ g = x * g + g
95
+ else:
96
+ raise NotImplementedError
97
+
98
+ return g
99
+
100
+
101
+ class GroupFeatureFusionBlock(nn.Module):
102
+ def __init__(self, x_in_dim: int, g_in_dim: int, out_dim: int):
103
+ super().__init__()
104
+
105
+ x_transform = nn.Conv2d(x_in_dim, out_dim, kernel_size=1)
106
+ g_transform = GConv2d(g_in_dim, out_dim, kernel_size=1)
107
+
108
+ self.distributor = MainToGroupDistributor(x_transform=x_transform,
109
+ g_transform=g_transform,
110
+ method='add')
111
+ self.block1 = CAResBlock(out_dim, out_dim)
112
+ self.block2 = CAResBlock(out_dim, out_dim)
113
+
114
+ def forward(self, x: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
115
+ batch_size, num_objects = g.shape[:2]
116
+
117
+ g = self.distributor(x, g)
118
+
119
+ g = g.flatten(start_dim=0, end_dim=1)
120
+
121
+ g = self.block1(g)
122
+ g = self.block2(g)
123
+
124
+ g = g.view(batch_size, num_objects, *g.shape[1:])
125
+
126
+ return g
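A shape check for the grouped modules above (illustrative sizes; assumes the import path `matanyone.model.group_modules`):

```python
import torch
from matanyone.model.group_modules import MainToGroupDistributor, GroupResBlock

B, N, C, H, W = 1, 3, 16, 24, 24
x = torch.randn(B, C, H, W)                   # frame-level feature
g = torch.randn(B, N, C, H, W)                # per-object group feature

dist = MainToGroupDistributor(method='cat')   # broadcast x to every object, concat on the channel dim
g = dist(x, g)                                # -> B * N * 2C * H * W
g = GroupResBlock(2 * C, C)(g)                # grouped residual block back to C channels
print(g.shape)                                # torch.Size([1, 3, 16, 24, 24])
```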
matanyone/model/matanyone.py ADDED
@@ -0,0 +1,323 @@
1
+ from typing import List, Dict
2
+ import logging
3
+ from omegaconf import DictConfig
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from matanyone.model.modules import *
8
+ from matanyone.model.big_modules import *
9
+ from matanyone.model.aux_modules import AuxComputer
10
+ from matanyone.model.utils.memory_utils import *
11
+ from matanyone.model.transformer.object_transformer import QueryTransformer
12
+ from matanyone.model.transformer.object_summarizer import ObjectSummarizer
13
+ from matanyone.utils.tensor_utils import aggregate
14
+
15
+ log = logging.getLogger()
16
+
17
+
18
+ class MatAnyone(nn.Module):
19
+
20
+ def __init__(self, cfg: DictConfig, *, single_object=False):
21
+ super().__init__()
22
+ self.cfg = cfg
23
+ model_cfg = cfg.model
24
+ self.ms_dims = model_cfg.pixel_encoder.ms_dims
25
+ self.key_dim = model_cfg.key_dim
26
+ self.value_dim = model_cfg.value_dim
27
+ self.sensory_dim = model_cfg.sensory_dim
28
+ self.pixel_dim = model_cfg.pixel_dim
29
+ self.embed_dim = model_cfg.embed_dim
30
+ self.single_object = single_object
31
+
32
+ log.info(f'Single object: {self.single_object}')
33
+
34
+ self.pixel_encoder = PixelEncoder(model_cfg)
35
+ self.pix_feat_proj = nn.Conv2d(self.ms_dims[0], self.pixel_dim, kernel_size=1)
36
+ self.key_proj = KeyProjection(model_cfg)
37
+ self.mask_encoder = MaskEncoder(model_cfg, single_object=single_object)
38
+ self.mask_decoder = MaskDecoder(model_cfg)
39
+ self.pixel_fuser = PixelFeatureFuser(model_cfg, single_object=single_object)
40
+ self.object_transformer = QueryTransformer(model_cfg)
41
+ self.object_summarizer = ObjectSummarizer(model_cfg)
42
+ self.aux_computer = AuxComputer(cfg)
43
+ self.uncert_pred = UncertPred(model_cfg)
44
+
45
+ self.register_buffer("pixel_mean", torch.Tensor(model_cfg.pixel_mean).view(-1, 1, 1), False)
46
+ self.register_buffer("pixel_std", torch.Tensor(model_cfg.pixel_std).view(-1, 1, 1), False)
47
+
48
+ def _get_others(self, masks: torch.Tensor) -> torch.Tensor:
49
+ # for each object, return the sum of masks of all other objects
50
+ if self.single_object:
51
+ return None
52
+
53
+ num_objects = masks.shape[1]
54
+ if num_objects >= 1:
55
+ others = (masks.sum(dim=1, keepdim=True) - masks).clamp(0, 1)
56
+ else:
57
+ others = torch.zeros_like(masks)
58
+ return others
59
+
60
+ def pred_uncertainty(self, last_pix_feat: torch.Tensor, cur_pix_feat: torch.Tensor, last_mask: torch.Tensor, mem_val_diff:torch.Tensor):
61
+ logits = self.uncert_pred(last_frame_feat=last_pix_feat,
62
+ cur_frame_feat=cur_pix_feat,
63
+ last_mask=last_mask,
64
+ mem_val_diff=mem_val_diff)
65
+
66
+ prob = torch.sigmoid(logits)
67
+ mask = (prob > 0) + 0
68
+
69
+ uncert_output = {"logits": logits,
70
+ "prob": prob,
71
+ "mask": mask}
72
+
73
+ return uncert_output
74
+
75
+ def encode_image(self, image: torch.Tensor, seq_length=None, last_feats=None) -> (Iterable[torch.Tensor], torch.Tensor): # type: ignore
76
+ image = (image - self.pixel_mean) / self.pixel_std
77
+ ms_image_feat = self.pixel_encoder(image, seq_length) # f16, f8, f4, f2, f1
78
+ return ms_image_feat, self.pix_feat_proj(ms_image_feat[0])
79
+
80
+ def encode_mask(
81
+ self,
82
+ image: torch.Tensor,
83
+ ms_features: List[torch.Tensor],
84
+ sensory: torch.Tensor,
85
+ masks: torch.Tensor,
86
+ *,
87
+ deep_update: bool = True,
88
+ chunk_size: int = -1,
89
+ need_weights: bool = False) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
90
+ image = (image - self.pixel_mean) / self.pixel_std
91
+ others = self._get_others(masks)
92
+ mask_value, new_sensory = self.mask_encoder(image,
93
+ ms_features,
94
+ sensory,
95
+ masks,
96
+ others,
97
+ deep_update=deep_update,
98
+ chunk_size=chunk_size)
99
+ object_summaries, object_logits = self.object_summarizer(masks, mask_value, need_weights)
100
+ return mask_value, new_sensory, object_summaries, object_logits
101
+
102
+ def transform_key(self,
103
+ final_pix_feat: torch.Tensor,
104
+ *,
105
+ need_sk: bool = True,
106
+ need_ek: bool = True) -> (torch.Tensor, torch.Tensor, torch.Tensor):
107
+ key, shrinkage, selection = self.key_proj(final_pix_feat, need_s=need_sk, need_e=need_ek)
108
+ return key, shrinkage, selection
109
+
110
+ # Used in training only.
111
+ # This step is replaced by MemoryManager in test time
112
+ def read_memory(self, query_key: torch.Tensor, query_selection: torch.Tensor,
113
+ memory_key: torch.Tensor, memory_shrinkage: torch.Tensor,
114
+ msk_value: torch.Tensor, obj_memory: torch.Tensor, pix_feat: torch.Tensor,
115
+ sensory: torch.Tensor, last_mask: torch.Tensor,
116
+ selector: torch.Tensor, uncert_output=None, seg_pass=False,
117
+ last_pix_feat=None, last_pred_mask=None) -> (torch.Tensor, Dict[str, torch.Tensor]):
118
+ """
119
+ query_key : B * CK * H * W
120
+ query_selection : B * CK * H * W
121
+ memory_key : B * CK * T * H * W
122
+ memory_shrinkage: B * 1 * T * H * W
123
+ msk_value : B * num_objects * CV * T * H * W
124
+ obj_memory : B * num_objects * T * num_summaries * C
125
+ pixel_feature : B * C * H * W
126
+ """
127
+ batch_size, num_objects = msk_value.shape[:2]
128
+
129
+ uncert_mask = uncert_output["mask"] if uncert_output is not None else None
130
+
131
+ # read using visual attention
132
+ with torch.cuda.amp.autocast(enabled=False):
133
+ affinity = get_affinity(memory_key.float(), memory_shrinkage.float(), query_key.float(),
134
+ query_selection.float(), uncert_mask=uncert_mask)
135
+
136
+ msk_value = msk_value.flatten(start_dim=1, end_dim=2).float()
137
+
138
+ # B * (num_objects*CV) * H * W
139
+ pixel_readout = readout(affinity, msk_value, uncert_mask)
140
+ pixel_readout = pixel_readout.view(batch_size, num_objects, self.value_dim,
141
+ *pixel_readout.shape[-2:])
142
+
143
+ uncert_output = self.pred_uncertainty(last_pix_feat, pix_feat, last_pred_mask, pixel_readout[:,0]-msk_value[:,:,-1])
144
+ uncert_prob = uncert_output["prob"].unsqueeze(1) # b n 1 h w
145
+ pixel_readout = pixel_readout*uncert_prob + msk_value[:,:,-1].unsqueeze(1)*(1-uncert_prob)
146
+
147
+ pixel_readout = self.pixel_fusion(pix_feat, pixel_readout, sensory, last_mask)
148
+
149
+
150
+ # read from query transformer
151
+ mem_readout, aux_features = self.readout_query(pixel_readout, obj_memory, selector=selector, seg_pass=seg_pass)
152
+
153
+ aux_output = {
154
+ 'sensory': sensory,
155
+ 'q_logits': aux_features['logits'] if aux_features else None,
156
+ 'attn_mask': aux_features['attn_mask'] if aux_features else None,
157
+ }
158
+
159
+ return mem_readout, aux_output, uncert_output
160
+
161
+ def read_first_frame_memory(self, pixel_readout,
162
+ obj_memory: torch.Tensor, pix_feat: torch.Tensor,
163
+ sensory: torch.Tensor, last_mask: torch.Tensor,
164
+ selector: torch.Tensor, seg_pass=False) -> (torch.Tensor, Dict[str, torch.Tensor]):
165
+ """
166
+ query_key : B * CK * H * W
167
+ query_selection : B * CK * H * W
168
+ memory_key : B * CK * T * H * W
169
+ memory_shrinkage: B * 1 * T * H * W
170
+ msk_value : B * num_objects * CV * T * H * W
171
+ obj_memory : B * num_objects * T * num_summaries * C
172
+ pixel_feature : B * C * H * W
173
+ """
174
+
175
+ pixel_readout = self.pixel_fusion(pix_feat, pixel_readout, sensory, last_mask)
176
+
177
+ # read from query transformer
178
+ mem_readout, aux_features = self.readout_query(pixel_readout, obj_memory, selector=selector, seg_pass=seg_pass)
179
+
180
+ aux_output = {
181
+ 'sensory': sensory,
182
+ 'q_logits': aux_features['logits'] if aux_features else None,
183
+ 'attn_mask': aux_features['attn_mask'] if aux_features else None,
184
+ }
185
+
186
+ return mem_readout, aux_output
187
+
188
+ def pixel_fusion(self,
189
+ pix_feat: torch.Tensor,
190
+ pixel: torch.Tensor,
191
+ sensory: torch.Tensor,
192
+ last_mask: torch.Tensor,
193
+ *,
194
+ chunk_size: int = -1) -> torch.Tensor:
195
+ last_mask = F.interpolate(last_mask, size=sensory.shape[-2:], mode='area')
196
+ last_others = self._get_others(last_mask)
197
+ fused = self.pixel_fuser(pix_feat,
198
+ pixel,
199
+ sensory,
200
+ last_mask,
201
+ last_others,
202
+ chunk_size=chunk_size)
203
+ return fused
204
+
205
+ def readout_query(self,
206
+ pixel_readout,
207
+ obj_memory,
208
+ *,
209
+ selector=None,
210
+ need_weights=False,
211
+ seg_pass=False) -> (torch.Tensor, Dict[str, torch.Tensor]):
212
+ return self.object_transformer(pixel_readout,
213
+ obj_memory,
214
+ selector=selector,
215
+ need_weights=need_weights,
216
+ seg_pass=seg_pass)
217
+
218
+ def segment(self,
219
+ ms_image_feat: List[torch.Tensor],
220
+ memory_readout: torch.Tensor,
221
+ sensory: torch.Tensor,
222
+ *,
223
+ selector: bool = None,
224
+ chunk_size: int = -1,
225
+ update_sensory: bool = True,
226
+ seg_pass: bool = False,
227
+ clamp_mat: bool = True,
228
+ last_mask=None,
229
+ sigmoid_residual=False,
230
+ seg_mat=False) -> (torch.Tensor, torch.Tensor, torch.Tensor):
231
+ """
232
+ multi_scale_features is from the key encoder for skip-connection
233
+ memory_readout is from working/long-term memory
234
+ sensory is the sensory memory
235
+ last_mask is the mask from the last frame, supplementing sensory memory
236
+ selector is 1 if an object exists, and 0 otherwise. We use it to filter padded objects
237
+ during training.
238
+ """
239
+ #### use mat head for seg data
240
+ if seg_mat:
241
+ assert seg_pass
242
+ seg_pass = False
243
+ ####
244
+ sensory, logits = self.mask_decoder(ms_image_feat,
245
+ memory_readout,
246
+ sensory,
247
+ chunk_size=chunk_size,
248
+ update_sensory=update_sensory,
249
+ seg_pass = seg_pass,
250
+ last_mask=last_mask,
251
+ sigmoid_residual=sigmoid_residual)
252
+ if seg_pass:
253
+ prob = torch.sigmoid(logits)
254
+ if selector is not None:
255
+ prob = prob * selector
256
+
257
+ # Softmax over all objects
258
+ logits = aggregate(prob, dim=1)
259
+ prob = F.softmax(logits, dim=1)
260
+ else:
261
+ if clamp_mat:
262
+ logits = logits.clamp(0.0, 1.0)
263
+ logits = torch.cat([torch.prod(1 - logits, dim=1, keepdim=True), logits], 1)
264
+ prob = logits
265
+
266
+ return sensory, logits, prob
267
+
268
+ def compute_aux(self, pix_feat: torch.Tensor, aux_inputs: Dict[str, torch.Tensor],
269
+ selector: torch.Tensor, seg_pass=False) -> Dict[str, torch.Tensor]:
270
+ return self.aux_computer(pix_feat, aux_inputs, selector, seg_pass=seg_pass)
271
+
272
+ def forward(self, *args, **kwargs):
273
+ raise NotImplementedError
274
+
275
+ def load_weights(self, src_dict, init_as_zero_if_needed=False) -> None:
276
+ if not self.single_object:
277
+ # Map single-object weight to multi-object weight (4->5 input channels in conv1)
278
+ for k in list(src_dict.keys()):
279
+ if k == 'mask_encoder.conv1.weight':
280
+ if src_dict[k].shape[1] == 4:
281
+ log.info(f'Converting {k} from single object to multiple objects.')
282
+ pads = torch.zeros((64, 1, 7, 7), device=src_dict[k].device)
283
+ if not init_as_zero_if_needed:
284
+ nn.init.orthogonal_(pads)
285
+ log.info(f'Randomly initialized padding for {k}.')
286
+ else:
287
+ log.info(f'Zero-initialized padding for {k}.')
288
+ src_dict[k] = torch.cat([src_dict[k], pads], 1)
289
+ elif k == 'pixel_fuser.sensory_compress.weight':
290
+ if src_dict[k].shape[1] == self.sensory_dim + 1:
291
+ log.info(f'Converting {k} from single object to multiple objects.')
292
+ pads = torch.zeros((self.value_dim, 1, 1, 1), device=src_dict[k].device)
293
+ if not init_as_zero_if_needed:
294
+ nn.init.orthogonal_(pads)
295
+ log.info(f'Randomly initialized padding for {k}.')
296
+ else:
297
+ log.info(f'Zero-initialized padding for {k}.')
298
+ src_dict[k] = torch.cat([src_dict[k], pads], 1)
299
+ elif self.single_object:
300
+ """
301
+ If the model is multiple-object and we are training in single-object,
302
+ we strip the last channel of conv1.
303
+ This is not supposed to happen in standard training except when users are trying to
304
+ finetune a trained model with single object datasets.
305
+ """
306
+ if src_dict['mask_encoder.conv1.weight'].shape[1] == 5:
307
+ log.warning(f'Converting mask_encoder.conv1.weight from multiple objects to single object.'
308
+ ' This is not supposed to happen in standard training.')
309
+ src_dict['mask_encoder.conv1.weight'] = src_dict['mask_encoder.conv1.weight'][:, :-1]
310
+ src_dict['pixel_fuser.sensory_compress.weight'] = src_dict['pixel_fuser.sensory_compress.weight'][:, :-1]
311
+
312
+ for k in src_dict:
313
+ if k not in self.state_dict():
314
+ log.info(f'Key {k} found in src_dict but not in self.state_dict()!!!')
315
+ for k in self.state_dict():
316
+ if k not in src_dict:
317
+ log.info(f'Key {k} found in self.state_dict() but not in src_dict!!!')
318
+
319
+ self.load_state_dict(src_dict, strict=False)
320
+
321
+ @property
322
+ def device(self) -> torch.device:
323
+ return self.pixel_mean.device
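A toy illustration of the 4-to-5 channel surgery performed by `load_weights` above when a single-object checkpoint is loaded into a multi-object model (standalone sketch; the real method edits the full checkpoint dict in place):

```python
import torch
import torch.nn as nn

w_single = torch.randn(64, 4, 7, 7)     # mask_encoder.conv1.weight trained with one object (image + mask)
pads = torch.zeros(64, 1, 7, 7)         # extra input channel for the "other objects" mask
nn.init.orthogonal_(pads)               # random init, as in the non-zero-init branch above
w_multi = torch.cat([w_single, pads], dim=1)
print(w_multi.shape)                    # torch.Size([64, 5, 7, 7])
```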
matanyone/model/modules.py ADDED
@@ -0,0 +1,170 @@
1
+ from typing import List, Iterable
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from matanyone.model.group_modules import *
6
+
7
+
8
+ class UpsampleBlock(nn.Module):
9
+ def __init__(self, in_dim: int, out_dim: int, scale_factor: int = 2):
10
+ super().__init__()
11
+ self.out_conv = ResBlock(in_dim, out_dim)
12
+ self.scale_factor = scale_factor
13
+
14
+ def forward(self, in_g: torch.Tensor, skip_f: torch.Tensor) -> torch.Tensor:
15
+ g = F.interpolate(in_g,
16
+ scale_factor=self.scale_factor,
17
+ mode='bilinear')
18
+ g = self.out_conv(g)
19
+ g = g + skip_f
20
+ return g
21
+
22
+ class MaskUpsampleBlock(nn.Module):
23
+ def __init__(self, in_dim: int, out_dim: int, scale_factor: int = 2):
24
+ super().__init__()
25
+ self.distributor = MainToGroupDistributor(method='add')
26
+ self.out_conv = GroupResBlock(in_dim, out_dim)
27
+ self.scale_factor = scale_factor
28
+
29
+ def forward(self, in_g: torch.Tensor, skip_f: torch.Tensor) -> torch.Tensor:
30
+ g = upsample_groups(in_g, ratio=self.scale_factor)
31
+ g = self.distributor(skip_f, g)
32
+ g = self.out_conv(g)
33
+ return g
34
+
35
+
36
+ class DecoderFeatureProcessor(nn.Module):
37
+ def __init__(self, decoder_dims: List[int], out_dims: List[int]):
38
+ super().__init__()
39
+ self.transforms = nn.ModuleList([
40
+ nn.Conv2d(d_dim, p_dim, kernel_size=1) for d_dim, p_dim in zip(decoder_dims, out_dims)
41
+ ])
42
+
43
+ def forward(self, multi_scale_features: Iterable[torch.Tensor]) -> List[torch.Tensor]:
44
+ outputs = [func(x) for x, func in zip(multi_scale_features, self.transforms)]
45
+ return outputs
46
+
47
+
48
+ # @torch.jit.script
49
+ def _recurrent_update(h: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
50
+ # h: batch_size * num_objects * hidden_dim * h * w
51
+ # values: batch_size * num_objects * (hidden_dim*3) * h * w
52
+ dim = values.shape[2] // 3
53
+ forget_gate = torch.sigmoid(values[:, :, :dim])
54
+ update_gate = torch.sigmoid(values[:, :, dim:dim * 2])
55
+ new_value = torch.tanh(values[:, :, dim * 2:])
56
+ new_h = forget_gate * h * (1 - update_gate) + update_gate * new_value
57
+ return new_h
58
+
59
+
60
+ class SensoryUpdater_fullscale(nn.Module):
61
+ # Used in the decoder, multi-scale feature + GRU
62
+ def __init__(self, g_dims: List[int], mid_dim: int, sensory_dim: int):
63
+ super().__init__()
64
+ self.g16_conv = GConv2d(g_dims[0], mid_dim, kernel_size=1)
65
+ self.g8_conv = GConv2d(g_dims[1], mid_dim, kernel_size=1)
66
+ self.g4_conv = GConv2d(g_dims[2], mid_dim, kernel_size=1)
67
+ self.g2_conv = GConv2d(g_dims[3], mid_dim, kernel_size=1)
68
+ self.g1_conv = GConv2d(g_dims[4], mid_dim, kernel_size=1)
69
+
70
+ self.transform = GConv2d(mid_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1)
71
+
72
+ nn.init.xavier_normal_(self.transform.weight)
73
+
74
+ def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
75
+ g = self.g16_conv(g[0]) + self.g8_conv(downsample_groups(g[1], ratio=1/2)) + \
76
+ self.g4_conv(downsample_groups(g[2], ratio=1/4)) + \
77
+ self.g2_conv(downsample_groups(g[3], ratio=1/8)) + \
78
+ self.g1_conv(downsample_groups(g[4], ratio=1/16))
79
+
80
+ with torch.cuda.amp.autocast(enabled=False):
81
+ g = g.float()
82
+ h = h.float()
83
+ values = self.transform(torch.cat([g, h], dim=2))
84
+ new_h = _recurrent_update(h, values)
85
+
86
+ return new_h
87
+
88
+ class SensoryUpdater(nn.Module):
89
+ # Used in the decoder, multi-scale feature + GRU
90
+ def __init__(self, g_dims: List[int], mid_dim: int, sensory_dim: int):
91
+ super().__init__()
92
+ self.g16_conv = GConv2d(g_dims[0], mid_dim, kernel_size=1)
93
+ self.g8_conv = GConv2d(g_dims[1], mid_dim, kernel_size=1)
94
+ self.g4_conv = GConv2d(g_dims[2], mid_dim, kernel_size=1)
95
+
96
+ self.transform = GConv2d(mid_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1)
97
+
98
+ nn.init.xavier_normal_(self.transform.weight)
99
+
100
+ def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
101
+ g = self.g16_conv(g[0]) + self.g8_conv(downsample_groups(g[1], ratio=1/2)) + \
102
+ self.g4_conv(downsample_groups(g[2], ratio=1/4))
103
+
104
+ with torch.cuda.amp.autocast(enabled=False):
105
+ g = g.float()
106
+ h = h.float()
107
+ values = self.transform(torch.cat([g, h], dim=2))
108
+ new_h = _recurrent_update(h, values)
109
+
110
+ return new_h
111
+
112
+
113
+ class SensoryDeepUpdater(nn.Module):
114
+ def __init__(self, f_dim: int, sensory_dim: int):
115
+ super().__init__()
116
+ self.transform = GConv2d(f_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1)
117
+
118
+ nn.init.xavier_normal_(self.transform.weight)
119
+
120
+ def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
121
+ with torch.cuda.amp.autocast(enabled=False):
122
+ g = g.float()
123
+ h = h.float()
124
+ values = self.transform(torch.cat([g, h], dim=2))
125
+ new_h = _recurrent_update(h, values)
126
+
127
+ return new_h
128
+
129
+
130
+ class ResBlock(nn.Module):
131
+ def __init__(self, in_dim: int, out_dim: int):
132
+ super().__init__()
133
+
134
+ if in_dim == out_dim:
135
+ self.downsample = nn.Identity()
136
+ else:
137
+ self.downsample = nn.Conv2d(in_dim, out_dim, kernel_size=1)
138
+
139
+ self.conv1 = nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1)
140
+ self.conv2 = nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1)
141
+
142
+ def forward(self, g: torch.Tensor) -> torch.Tensor:
143
+ out_g = self.conv1(F.relu(g))
144
+ out_g = self.conv2(F.relu(out_g))
145
+
146
+ g = self.downsample(g)
147
+
148
+ return out_g + g
149
+
150
+ class PPM(nn.Module):
+ def __init__(self, in_dim, reduction_dim, bins):
151
+ super(PPM, self).__init__()
152
+ self.features = []
153
+ for bin in bins:
154
+ self.features.append(nn.Sequential(
155
+ nn.AdaptiveAvgPool2d(bin),
156
+ nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
157
+ nn.PReLU()
158
+ ))
159
+ self.features = nn.ModuleList(self.features)
160
+ self.fuse = nn.Sequential(
161
+ nn.Conv2d(in_dim+reduction_dim*4, in_dim, kernel_size=3, padding=1, bias=False),
162
+ nn.PReLU())
163
+
164
+ def forward(self, x):
165
+ x_size = x.size()
166
+ out = [x]
167
+ for f in self.features:
168
+ out.append(F.interpolate(f(x), x_size[2:], mode='bilinear', align_corners=True))
169
+ out_feat = self.fuse(torch.cat(out, 1))
170
+ return out_feat
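An illustrative call of the GRU-style `_recurrent_update` defined above (toy sizes; assumes the file is importable as `matanyone.model.modules`):

```python
import torch
from matanyone.model.modules import _recurrent_update

B, N, D, H, W = 1, 2, 8, 4, 4
h = torch.zeros(B, N, D, H, W)            # current sensory hidden state
values = torch.randn(B, N, 3 * D, H, W)   # [forget gate | update gate | candidate] stacked on dim 2
new_h = _recurrent_update(h, values)      # gated blend of the old state and the tanh candidate
print(new_h.shape)                        # torch.Size([1, 2, 8, 4, 4])
```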
matanyone/model/transformer/__init__.py ADDED
File without changes
matanyone/model/transformer/object_summarizer.py ADDED
@@ -0,0 +1,89 @@
1
+ from typing import List, Dict, Optional
2
+ from omegaconf import DictConfig
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from matanyone.model.transformer.positional_encoding import PositionalEncoding
8
+
9
+
10
+ # @torch.jit.script
11
+ def _weighted_pooling(masks: torch.Tensor, value: torch.Tensor,
12
+ logits: torch.Tensor) -> (torch.Tensor, torch.Tensor):
13
+ # value: B*num_objects*H*W*value_dim
14
+ # logits: B*num_objects*H*W*num_summaries
15
+ # masks: B*num_objects*H*W*num_summaries: 1 if allowed
16
+ weights = logits.sigmoid() * masks
17
+ # B*num_objects*num_summaries*value_dim
18
+ sums = torch.einsum('bkhwq,bkhwc->bkqc', weights, value)
19
+ # B*num_objects*H*W*num_summaries -> B*num_objects*num_summaries*1
20
+ area = weights.flatten(start_dim=2, end_dim=3).sum(2).unsqueeze(-1)
21
+
22
+ # B*num_objects*num_summaries*value_dim
23
+ return sums, area
24
+
25
+
26
+ class ObjectSummarizer(nn.Module):
27
+ def __init__(self, model_cfg: DictConfig):
28
+ super().__init__()
29
+
30
+ this_cfg = model_cfg.object_summarizer
31
+ self.value_dim = model_cfg.value_dim
32
+ self.embed_dim = this_cfg.embed_dim
33
+ self.num_summaries = this_cfg.num_summaries
34
+ self.add_pe = this_cfg.add_pe
35
+ self.pixel_pe_scale = model_cfg.pixel_pe_scale
36
+ self.pixel_pe_temperature = model_cfg.pixel_pe_temperature
37
+
38
+ if self.add_pe:
39
+ self.pos_enc = PositionalEncoding(self.embed_dim,
40
+ scale=self.pixel_pe_scale,
41
+ temperature=self.pixel_pe_temperature)
42
+
43
+ self.input_proj = nn.Linear(self.value_dim, self.embed_dim)
44
+ self.feature_pred = nn.Sequential(
45
+ nn.Linear(self.embed_dim, self.embed_dim),
46
+ nn.ReLU(inplace=True),
47
+ nn.Linear(self.embed_dim, self.embed_dim),
48
+ )
49
+ self.weights_pred = nn.Sequential(
50
+ nn.Linear(self.embed_dim, self.embed_dim),
51
+ nn.ReLU(inplace=True),
52
+ nn.Linear(self.embed_dim, self.num_summaries),
53
+ )
54
+
55
+ def forward(self,
56
+ masks: torch.Tensor,
57
+ value: torch.Tensor,
58
+ need_weights: bool = False) -> (torch.Tensor, Optional[torch.Tensor]):
59
+ # masks: B*num_objects*(H0)*(W0)
60
+ # value: B*num_objects*value_dim*H*W
61
+ # -> B*num_objects*H*W*value_dim
62
+ h, w = value.shape[-2:]
63
+ masks = F.interpolate(masks, size=(h, w), mode='area')
64
+ masks = masks.unsqueeze(-1)
65
+ inv_masks = 1 - masks
66
+ repeated_masks = torch.cat([
67
+ masks.expand(-1, -1, -1, -1, self.num_summaries // 2),
68
+ inv_masks.expand(-1, -1, -1, -1, self.num_summaries // 2),
69
+ ],
70
+ dim=-1)
71
+
72
+ value = value.permute(0, 1, 3, 4, 2)
73
+ value = self.input_proj(value)
74
+ if self.add_pe:
75
+ pe = self.pos_enc(value)
76
+ value = value + pe
77
+
78
+ with torch.cuda.amp.autocast(enabled=False):
79
+ value = value.float()
80
+ feature = self.feature_pred(value)
81
+ logits = self.weights_pred(value)
82
+ sums, area = _weighted_pooling(repeated_masks, feature, logits)
83
+
84
+ summaries = torch.cat([sums, area], dim=-1)
85
+
86
+ if need_weights:
87
+ return summaries, logits
88
+ else:
89
+ return summaries, None
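A standalone shape check for `_weighted_pooling` above (illustrative sizes):

```python
import torch
from matanyone.model.transformer.object_summarizer import _weighted_pooling

B, N, H, W, C, Q = 1, 2, 8, 8, 16, 4
masks = torch.randint(0, 2, (B, N, H, W, Q)).float()   # 1 where a summary token is allowed to pool
value = torch.randn(B, N, H, W, C)
logits = torch.randn(B, N, H, W, Q)

sums, area = _weighted_pooling(masks, value, logits)
print(sums.shape, area.shape)            # torch.Size([1, 2, 4, 16]) torch.Size([1, 2, 4, 1])
```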
matanyone/model/transformer/object_transformer.py ADDED
@@ -0,0 +1,206 @@
1
+ from typing import Dict, Optional
2
+ from omegaconf import DictConfig
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from matanyone.model.group_modules import GConv2d
7
+ from matanyone.utils.tensor_utils import aggregate
8
+ from matanyone.model.transformer.positional_encoding import PositionalEncoding
9
+ from matanyone.model.transformer.transformer_layers import *
10
+
11
+
12
+ class QueryTransformerBlock(nn.Module):
13
+ def __init__(self, model_cfg: DictConfig):
14
+ super().__init__()
15
+
16
+ this_cfg = model_cfg.object_transformer
17
+ self.embed_dim = this_cfg.embed_dim
18
+ self.num_heads = this_cfg.num_heads
19
+ self.num_queries = this_cfg.num_queries
20
+ self.ff_dim = this_cfg.ff_dim
21
+
22
+ self.read_from_pixel = CrossAttention(self.embed_dim,
23
+ self.num_heads,
24
+ add_pe_to_qkv=this_cfg.read_from_pixel.add_pe_to_qkv)
25
+ self.self_attn = SelfAttention(self.embed_dim,
26
+ self.num_heads,
27
+ add_pe_to_qkv=this_cfg.query_self_attention.add_pe_to_qkv)
28
+ self.ffn = FFN(self.embed_dim, self.ff_dim)
29
+ self.read_from_query = CrossAttention(self.embed_dim,
30
+ self.num_heads,
31
+ add_pe_to_qkv=this_cfg.read_from_query.add_pe_to_qkv,
32
+ norm=this_cfg.read_from_query.output_norm)
33
+ self.pixel_ffn = PixelFFN(self.embed_dim)
34
+
35
+ def forward(
36
+ self,
37
+ x: torch.Tensor,
38
+ pixel: torch.Tensor,
39
+ query_pe: torch.Tensor,
40
+ pixel_pe: torch.Tensor,
41
+ attn_mask: torch.Tensor,
42
+ need_weights: bool = False) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
43
+ # x: (bs*num_objects)*num_queries*embed_dim
44
+ # pixel: bs*num_objects*C*H*W
45
+ # query_pe: (bs*num_objects)*num_queries*embed_dim
46
+ # pixel_pe: (bs*num_objects)*(H*W)*C
47
+ # attn_mask: (bs*num_objects*num_heads)*num_queries*(H*W)
48
+
49
+ # bs*num_objects*C*H*W -> (bs*num_objects)*(H*W)*C
50
+ pixel_flat = pixel.flatten(3, 4).flatten(0, 1).transpose(1, 2).contiguous()
51
+ x, q_weights = self.read_from_pixel(x,
52
+ pixel_flat,
53
+ query_pe,
54
+ pixel_pe,
55
+ attn_mask=attn_mask,
56
+ need_weights=need_weights)
57
+ x = self.self_attn(x, query_pe)
58
+ x = self.ffn(x)
59
+
60
+ pixel_flat, p_weights = self.read_from_query(pixel_flat,
61
+ x,
62
+ pixel_pe,
63
+ query_pe,
64
+ need_weights=need_weights)
65
+ pixel = self.pixel_ffn(pixel, pixel_flat)
66
+
67
+ if need_weights:
68
+ bs, num_objects, _, h, w = pixel.shape
69
+ q_weights = q_weights.view(bs, num_objects, self.num_heads, self.num_queries, h, w)
70
+ p_weights = p_weights.transpose(2, 3).view(bs, num_objects, self.num_heads,
71
+ self.num_queries, h, w)
72
+
73
+ return x, pixel, q_weights, p_weights
74
+
75
+
76
+ class QueryTransformer(nn.Module):
77
+ def __init__(self, model_cfg: DictConfig):
78
+ super().__init__()
79
+
80
+ this_cfg = model_cfg.object_transformer
81
+ self.value_dim = model_cfg.value_dim
82
+ self.embed_dim = this_cfg.embed_dim
83
+ self.num_heads = this_cfg.num_heads
84
+ self.num_queries = this_cfg.num_queries
85
+
86
+ # query initialization and embedding
87
+ self.query_init = nn.Embedding(self.num_queries, self.embed_dim)
88
+ self.query_emb = nn.Embedding(self.num_queries, self.embed_dim)
89
+
90
+ # projection from object summaries to query initialization and embedding
91
+ self.summary_to_query_init = nn.Linear(self.embed_dim, self.embed_dim)
92
+ self.summary_to_query_emb = nn.Linear(self.embed_dim, self.embed_dim)
93
+
94
+ self.pixel_pe_scale = model_cfg.pixel_pe_scale
95
+ self.pixel_pe_temperature = model_cfg.pixel_pe_temperature
96
+ self.pixel_init_proj = GConv2d(self.embed_dim, self.embed_dim, kernel_size=1)
97
+ self.pixel_emb_proj = GConv2d(self.embed_dim, self.embed_dim, kernel_size=1)
98
+ self.spatial_pe = PositionalEncoding(self.embed_dim,
99
+ scale=self.pixel_pe_scale,
100
+ temperature=self.pixel_pe_temperature,
101
+ channel_last=False,
102
+ transpose_output=True)
103
+
104
+ # transformer blocks
105
+ self.num_blocks = this_cfg.num_blocks
106
+ self.blocks = nn.ModuleList(
107
+ QueryTransformerBlock(model_cfg) for _ in range(self.num_blocks))
108
+ self.mask_pred = nn.ModuleList(
109
+ nn.Sequential(nn.ReLU(), GConv2d(self.embed_dim, 1, kernel_size=1))
110
+ for _ in range(self.num_blocks + 1))
111
+
112
+ self.act = nn.ReLU(inplace=True)
113
+
114
+ def forward(self,
115
+ pixel: torch.Tensor,
116
+ obj_summaries: torch.Tensor,
117
+ selector: Optional[torch.Tensor] = None,
118
+ need_weights: bool = False,
119
+ seg_pass=False) -> (torch.Tensor, Dict[str, torch.Tensor]):
120
+ # pixel: B*num_objects*embed_dim*H*W
121
+ # obj_summaries: B*num_objects*T*num_queries*embed_dim
122
+ T = obj_summaries.shape[2]
123
+ bs, num_objects, _, H, W = pixel.shape
124
+
125
+ # normalize object values
126
+ # the last channel is the cumulative area of the object
127
+ obj_summaries = obj_summaries.view(bs * num_objects, T, self.num_queries,
128
+ self.embed_dim + 1)
129
+ # sum over time
130
+ # during inference, T=1 as we already did streaming average in memory_manager
131
+ obj_sums = obj_summaries[:, :, :, :-1].sum(dim=1)
132
+ obj_area = obj_summaries[:, :, :, -1:].sum(dim=1)
133
+ obj_values = obj_sums / (obj_area + 1e-4)
134
+ obj_init = self.summary_to_query_init(obj_values)
135
+ obj_emb = self.summary_to_query_emb(obj_values)
136
+
137
+ # positional embeddings for object queries
138
+ query = self.query_init.weight.unsqueeze(0).expand(bs * num_objects, -1, -1) + obj_init
139
+ query_emb = self.query_emb.weight.unsqueeze(0).expand(bs * num_objects, -1, -1) + obj_emb
140
+
141
+ # positional embeddings for pixel features
142
+ pixel_init = self.pixel_init_proj(pixel)
143
+ pixel_emb = self.pixel_emb_proj(pixel)
144
+ pixel_pe = self.spatial_pe(pixel.flatten(0, 1))
145
+ pixel_emb = pixel_emb.flatten(3, 4).flatten(0, 1).transpose(1, 2).contiguous()
146
+ pixel_pe = pixel_pe.flatten(1, 2) + pixel_emb
147
+
148
+ pixel = pixel_init
149
+
150
+ # run the transformer
151
+ aux_features = {'logits': []}
152
+
153
+ # first aux output
154
+ aux_logits = self.mask_pred[0](pixel).squeeze(2)
155
+ attn_mask = self._get_aux_mask(aux_logits, selector, seg_pass=seg_pass)
156
+ aux_features['logits'].append(aux_logits)
157
+ for i in range(self.num_blocks):
158
+ query, pixel, q_weights, p_weights = self.blocks[i](query,
159
+ pixel,
160
+ query_emb,
161
+ pixel_pe,
162
+ attn_mask,
163
+ need_weights=need_weights)
164
+
165
+ if self.training or i <= self.num_blocks - 1 or need_weights:
166
+ aux_logits = self.mask_pred[i + 1](pixel).squeeze(2)
167
+ attn_mask = self._get_aux_mask(aux_logits, selector, seg_pass=seg_pass)
168
+ aux_features['logits'].append(aux_logits)
169
+
170
+ aux_features['q_weights'] = q_weights # last layer only
171
+ aux_features['p_weights'] = p_weights # last layer only
172
+
173
+ if self.training:
174
+ # no need to save all heads
175
+ aux_features['attn_mask'] = attn_mask.view(bs, num_objects, self.num_heads,
176
+ self.num_queries, H, W)[:, :, 0]
177
+
178
+ return pixel, aux_features
179
+
180
+ def _get_aux_mask(self, logits: torch.Tensor, selector: torch.Tensor, seg_pass=False) -> torch.Tensor:
181
+ # logits: batch_size*num_objects*H*W
182
+ # selector: batch_size*num_objects*1*1
183
+ # returns a mask of shape (batch_size*num_objects*num_heads)*num_queries*(H*W)
184
+ # where True means the attention is blocked
185
+
186
+ if selector is None:
187
+ prob = logits.sigmoid()
188
+ else:
189
+ prob = logits.sigmoid() * selector
190
+ logits = aggregate(prob, dim=1)
191
+
192
+ is_foreground = (logits[:, 1:] >= logits.max(dim=1, keepdim=True)[0])
193
+ foreground_mask = is_foreground.bool().flatten(start_dim=2)
194
+ inv_foreground_mask = ~foreground_mask
195
+ inv_background_mask = foreground_mask
196
+
197
+ aux_foreground_mask = inv_foreground_mask.unsqueeze(2).unsqueeze(2).repeat(
198
+ 1, 1, self.num_heads, self.num_queries // 2, 1).flatten(start_dim=0, end_dim=2)
199
+ aux_background_mask = inv_background_mask.unsqueeze(2).unsqueeze(2).repeat(
200
+ 1, 1, self.num_heads, self.num_queries // 2, 1).flatten(start_dim=0, end_dim=2)
201
+
202
+ aux_mask = torch.cat([aux_foreground_mask, aux_background_mask], dim=1)
203
+
204
+ aux_mask[torch.where(aux_mask.sum(-1) == aux_mask.shape[-1])] = False
205
+
206
+ return aux_mask
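A toy walkthrough of the summary normalization at the top of `QueryTransformer.forward` (hypothetical sizes; the real tensors come from the object summarizer and the memory manager):

```python
import torch

bs_objects, T, num_queries, embed_dim = 2, 4, 8, 16
obj_summaries = torch.randn(bs_objects, T, num_queries, embed_dim + 1)  # last channel = accumulated area

obj_sums = obj_summaries[:, :, :, :-1].sum(dim=1)   # weighted feature sums accumulated over time
obj_area = obj_summaries[:, :, :, -1:].sum(dim=1)   # matching areas accumulated over time
obj_values = obj_sums / (obj_area + 1e-4)           # area-normalized object tokens for query init/emb
print(obj_values.shape)                             # torch.Size([2, 8, 16])
```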
matanyone/model/transformer/positional_encoding.py ADDED
@@ -0,0 +1,108 @@
1
+ # Reference:
2
+ # https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/transformer_decoder/position_encoding.py
3
+ # https://github.com/tatp22/multidim-positional-encoding/blob/master/positional_encodings/torch_encodings.py
4
+
5
+ import math
6
+
7
+ import numpy as np
8
+ import torch
9
+ from torch import nn
10
+
11
+
12
+ def get_emb(sin_inp: torch.Tensor) -> torch.Tensor:
13
+ """
14
+ Gets a base embedding for one dimension with sin and cos intertwined
15
+ """
16
+ emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=-1)
17
+ return torch.flatten(emb, -2, -1)
18
+
19
+
20
+ class PositionalEncoding(nn.Module):
21
+ def __init__(self,
22
+ dim: int,
23
+ scale: float = math.pi * 2,
24
+ temperature: float = 10000,
25
+ normalize: bool = True,
26
+ channel_last: bool = True,
27
+ transpose_output: bool = False):
28
+ super().__init__()
29
+ dim = int(np.ceil(dim / 4) * 2)
30
+ self.dim = dim
31
+ inv_freq = 1.0 / (temperature**(torch.arange(0, dim, 2).float() / dim))
32
+ self.register_buffer("inv_freq", inv_freq)
33
+ self.normalize = normalize
34
+ self.scale = scale
35
+ self.eps = 1e-6
36
+ self.channel_last = channel_last
37
+ self.transpose_output = transpose_output
38
+
39
+ self.cached_penc = None # the cache is irrespective of the number of objects
40
+
41
+ def forward(self, tensor: torch.Tensor) -> torch.Tensor:
42
+ """
43
+ :param tensor: A 4/5d tensor of size
44
+ channel_last=True: (batch_size, h, w, c) or (batch_size, k, h, w, c)
45
+ channel_last=False: (batch_size, c, h, w) or (batch_size, k, c, h, w)
46
+ :return: positional encoding tensor that has the same shape as the input if the input is 4d
47
+ if the input is 5d, the output is broadcastable along the k-dimension
48
+ """
49
+ if len(tensor.shape) != 4 and len(tensor.shape) != 5:
50
+ raise RuntimeError(f'The input tensor has to be 4/5d, got {tensor.shape}!')
51
+
52
+ if len(tensor.shape) == 5:
53
+ # take a sample from the k dimension
54
+ num_objects = tensor.shape[1]
55
+ tensor = tensor[:, 0]
56
+ else:
57
+ num_objects = None
58
+
59
+ if self.channel_last:
60
+ batch_size, h, w, c = tensor.shape
61
+ else:
62
+ batch_size, c, h, w = tensor.shape
63
+
64
+ if self.cached_penc is not None and self.cached_penc.shape == tensor.shape:
65
+ if num_objects is None:
66
+ return self.cached_penc
67
+ else:
68
+ return self.cached_penc.unsqueeze(1)
69
+
70
+ self.cached_penc = None
71
+
72
+ pos_y = torch.arange(h, device=tensor.device, dtype=self.inv_freq.dtype)
73
+ pos_x = torch.arange(w, device=tensor.device, dtype=self.inv_freq.dtype)
74
+ if self.normalize:
75
+ pos_y = pos_y / (pos_y[-1] + self.eps) * self.scale
76
+ pos_x = pos_x / (pos_x[-1] + self.eps) * self.scale
77
+
78
+ sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq)
79
+ sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
80
+ emb_y = get_emb(sin_inp_y).unsqueeze(1)
81
+ emb_x = get_emb(sin_inp_x)
82
+
83
+ emb = torch.zeros((h, w, self.dim * 2), device=tensor.device, dtype=tensor.dtype)
84
+ emb[:, :, :self.dim] = emb_x
85
+ emb[:, :, self.dim:] = emb_y
86
+
87
+ if not self.channel_last and self.transpose_output:
88
+ # cancelled out
89
+ pass
90
+ elif (not self.channel_last) or (self.transpose_output):
91
+ emb = emb.permute(2, 0, 1)
92
+
93
+ self.cached_penc = emb.unsqueeze(0).repeat(batch_size, 1, 1, 1)
94
+ if num_objects is None:
95
+ return self.cached_penc
96
+ else:
97
+ return self.cached_penc.unsqueeze(1)
98
+
99
+
100
+ if __name__ == '__main__':
101
+ pe = PositionalEncoding(8).cuda()
102
+ input = torch.ones((1, 8, 8, 8)).cuda()
103
+ output = pe(input)
104
+ # print(output)
105
+ print(output[0, :, 0, 0])
106
+ print(output[0, :, 0, 5])
107
+ print(output[0, 0, :, 0])
108
+ print(output[0, 0, 0, :])
matanyone/model/transformer/transformer_layers.py ADDED
@@ -0,0 +1,161 @@
1
+ # Modified from PyTorch nn.Transformer
2
+
3
+ from typing import List, Callable
4
+
5
+ import torch
6
+ from torch import Tensor
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from matanyone.model.channel_attn import CAResBlock
10
+
11
+
12
+ class SelfAttention(nn.Module):
13
+ def __init__(self,
14
+ dim: int,
15
+ nhead: int,
16
+ dropout: float = 0.0,
17
+ batch_first: bool = True,
18
+ add_pe_to_qkv: List[bool] = [True, True, False]):
19
+ super().__init__()
20
+ self.self_attn = nn.MultiheadAttention(dim, nhead, dropout=dropout, batch_first=batch_first)
21
+ self.norm = nn.LayerNorm(dim)
22
+ self.dropout = nn.Dropout(dropout)
23
+ self.add_pe_to_qkv = add_pe_to_qkv
24
+
25
+ def forward(self,
26
+ x: torch.Tensor,
27
+ pe: torch.Tensor,
28
+ attn_mask: bool = None,
29
+ key_padding_mask: bool = None) -> torch.Tensor:
30
+ x = self.norm(x)
31
+ if any(self.add_pe_to_qkv):
32
+ x_with_pe = x + pe
33
+ q = x_with_pe if self.add_pe_to_qkv[0] else x
34
+ k = x_with_pe if self.add_pe_to_qkv[1] else x
35
+ v = x_with_pe if self.add_pe_to_qkv[2] else x
36
+ else:
37
+ q = k = v = x
38
+
39
+ r = x
40
+ x = self.self_attn(q, k, v, attn_mask=attn_mask, key_padding_mask=key_padding_mask)[0]
41
+ return r + self.dropout(x)
42
+
43
+
44
+ # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention
45
+ class CrossAttention(nn.Module):
46
+ def __init__(self,
47
+ dim: int,
48
+ nhead: int,
49
+ dropout: float = 0.0,
50
+ batch_first: bool = True,
51
+ add_pe_to_qkv: List[bool] = [True, True, False],
52
+ residual: bool = True,
53
+ norm: bool = True):
54
+ super().__init__()
55
+ self.cross_attn = nn.MultiheadAttention(dim,
56
+ nhead,
57
+ dropout=dropout,
58
+ batch_first=batch_first)
59
+ if norm:
60
+ self.norm = nn.LayerNorm(dim)
61
+ else:
62
+ self.norm = nn.Identity()
63
+ self.dropout = nn.Dropout(dropout)
64
+ self.add_pe_to_qkv = add_pe_to_qkv
65
+ self.residual = residual
66
+
67
+ def forward(self,
68
+ x: torch.Tensor,
69
+ mem: torch.Tensor,
70
+ x_pe: torch.Tensor,
71
+ mem_pe: torch.Tensor,
72
+ attn_mask: bool = None,
73
+ *,
74
+ need_weights: bool = False) -> (torch.Tensor, torch.Tensor):
75
+ x = self.norm(x)
76
+ if self.add_pe_to_qkv[0]:
77
+ q = x + x_pe
78
+ else:
79
+ q = x
80
+
81
+ if any(self.add_pe_to_qkv[1:]):
82
+ mem_with_pe = mem + mem_pe
83
+ k = mem_with_pe if self.add_pe_to_qkv[1] else mem
84
+ v = mem_with_pe if self.add_pe_to_qkv[2] else mem
85
+ else:
86
+ k = v = mem
87
+ r = x
88
+ x, weights = self.cross_attn(q,
89
+ k,
90
+ v,
91
+ attn_mask=attn_mask,
92
+ need_weights=need_weights,
93
+ average_attn_weights=False)
94
+
95
+ if self.residual:
96
+ return r + self.dropout(x), weights
97
+ else:
98
+ return self.dropout(x), weights
99
+
100
+
101
+ class FFN(nn.Module):
102
+ def __init__(self, dim_in: int, dim_ff: int, activation=F.relu):
103
+ super().__init__()
104
+ self.linear1 = nn.Linear(dim_in, dim_ff)
105
+ self.linear2 = nn.Linear(dim_ff, dim_in)
106
+ self.norm = nn.LayerNorm(dim_in)
107
+
108
+ if isinstance(activation, str):
109
+ self.activation = _get_activation_fn(activation)
110
+ else:
111
+ self.activation = activation
112
+
113
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
114
+ r = x
115
+ x = self.norm(x)
116
+ x = self.linear2(self.activation(self.linear1(x)))
117
+ x = r + x
118
+ return x
119
+
120
+
121
+ class PixelFFN(nn.Module):
122
+ def __init__(self, dim: int):
123
+ super().__init__()
124
+ self.dim = dim
125
+ self.conv = CAResBlock(dim, dim)
126
+
127
+ def forward(self, pixel: torch.Tensor, pixel_flat: torch.Tensor) -> torch.Tensor:
128
+ # pixel: batch_size * num_objects * dim * H * W
129
+ # pixel_flat: (batch_size*num_objects) * (H*W) * dim
130
+ bs, num_objects, _, h, w = pixel.shape
131
+ pixel_flat = pixel_flat.view(bs * num_objects, h, w, self.dim)
132
+ pixel_flat = pixel_flat.permute(0, 3, 1, 2).contiguous()
133
+
134
+ x = self.conv(pixel_flat)
135
+ x = x.view(bs, num_objects, self.dim, h, w)
136
+ return x
137
+
138
+
139
+ class OutputFFN(nn.Module):
140
+ def __init__(self, dim_in: int, dim_out: int, activation=F.relu):
141
+ super().__init__()
142
+ self.linear1 = nn.Linear(dim_in, dim_out)
143
+ self.linear2 = nn.Linear(dim_out, dim_out)
144
+
145
+ if isinstance(activation, str):
146
+ self.activation = _get_activation_fn(activation)
147
+ else:
148
+ self.activation = activation
149
+
150
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
151
+ x = self.linear2(self.activation(self.linear1(x)))
152
+ return x
153
+
154
+
155
+ def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
156
+ if activation == "relu":
157
+ return F.relu
158
+ elif activation == "gelu":
159
+ return F.gelu
160
+
161
+ raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
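A minimal sketch exercising the layers above (hypothetical dims; assumes the import path `matanyone.model.transformer.transformer_layers`):

```python
import torch
from matanyone.model.transformer.transformer_layers import SelfAttention, FFN

dim, nhead = 64, 4
tokens = torch.randn(2, 16, dim)    # (batch, num_queries, dim), batch_first layout
pe = torch.zeros(2, 16, dim)        # positional encoding, added to q/k by the default add_pe_to_qkv

x = SelfAttention(dim, nhead)(tokens, pe)
x = FFN(dim, dim * 4)(x)
print(x.shape)                      # torch.Size([2, 16, 64])
```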
matanyone/model/utils/__init__.py ADDED
File without changes
matanyone/model/utils/memory_utils.py ADDED
@@ -0,0 +1,107 @@
1
+ import math
2
+ import torch
3
+ from typing import Optional, Union, Tuple
4
+
5
+
6
+ # @torch.jit.script
7
+ def get_similarity(mk: torch.Tensor,
8
+ ms: torch.Tensor,
9
+ qk: torch.Tensor,
10
+ qe: torch.Tensor,
11
+ add_batch_dim: bool = False,
12
+ uncert_mask = None) -> torch.Tensor:
13
+ # used for training/inference and memory reading/memory potentiation
14
+ # mk: B x CK x [N] - Memory keys
15
+ # ms: B x 1 x [N] - Memory shrinkage
16
+ # qk: B x CK x [HW/P] - Query keys
17
+ # qe: B x CK x [HW/P] - Query selection
18
+ # Dimensions in [] are flattened
19
+ # Return: B*N*HW
20
+ if add_batch_dim:
21
+ mk, ms = mk.unsqueeze(0), ms.unsqueeze(0)
22
+ qk, qe = qk.unsqueeze(0), qe.unsqueeze(0)
23
+
24
+ CK = mk.shape[1]
25
+
26
+ mk = mk.flatten(start_dim=2)
27
+ ms = ms.flatten(start_dim=1).unsqueeze(2) if ms is not None else None
28
+ qk = qk.flatten(start_dim=2)
29
+ qe = qe.flatten(start_dim=2) if qe is not None else None
30
+
31
+ # query token selection based on temporal sparsity
32
+ if uncert_mask is not None:
33
+ uncert_mask = uncert_mask.flatten(start_dim=2)
34
+ uncert_mask = uncert_mask.expand(-1, 64, -1)
35
+ qk = qk * uncert_mask
36
+ qe = qe * uncert_mask
37
+
38
+ if qe is not None:
39
+ # See XMem's appendix for derivation
40
+ mk = mk.transpose(1, 2)
41
+ a_sq = (mk.pow(2) @ qe)
42
+ two_ab = 2 * (mk @ (qk * qe))
43
+ b_sq = (qe * qk.pow(2)).sum(1, keepdim=True)
44
+ similarity = (-a_sq + two_ab - b_sq)
45
+ else:
46
+ # similar to STCN if we don't have the selection term
47
+ a_sq = mk.pow(2).sum(1).unsqueeze(2)
48
+ two_ab = 2 * (mk.transpose(1, 2) @ qk)
49
+ similarity = (-a_sq + two_ab)
50
+
51
+ if ms is not None:
52
+ similarity = similarity * ms / math.sqrt(CK) # B*N*HW
53
+ else:
54
+ similarity = similarity / math.sqrt(CK) # B*N*HW
55
+
56
+ return similarity
57
+
58
+
59
+ def do_softmax(
60
+ similarity: torch.Tensor,
61
+ top_k: Optional[int] = None,
62
+ inplace: bool = False,
63
+ return_usage: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
64
+ # normalize similarity with top-k softmax
65
+ # similarity: B x N x [HW/P]
66
+ # use inplace with care
67
+ if top_k is not None:
68
+ values, indices = torch.topk(similarity, k=top_k, dim=1)
69
+
70
+ x_exp = values.exp_()
71
+ x_exp /= torch.sum(x_exp, dim=1, keepdim=True)
72
+ if inplace:
73
+ similarity.zero_().scatter_(1, indices, x_exp) # B*N*HW
74
+ affinity = similarity
75
+ else:
76
+ affinity = torch.zeros_like(similarity).scatter_(1, indices, x_exp) # B*N*HW
77
+ else:
78
+ maxes = torch.max(similarity, dim=1, keepdim=True)[0]
79
+ x_exp = torch.exp(similarity - maxes)
80
+ x_exp_sum = torch.sum(x_exp, dim=1, keepdim=True)
81
+ affinity = x_exp / x_exp_sum
82
+ indices = None
83
+
84
+ if return_usage:
85
+ return affinity, affinity.sum(dim=2)
86
+
87
+ return affinity
88
+
89
+
90
+ def get_affinity(mk: torch.Tensor, ms: torch.Tensor, qk: torch.Tensor,
91
+ qe: torch.Tensor, uncert_mask = None) -> torch.Tensor:
92
+ # shorthand used in training with no top-k
93
+ similarity = get_similarity(mk, ms, qk, qe, uncert_mask=uncert_mask)
94
+ affinity = do_softmax(similarity)
95
+ return affinity
96
+
97
+ def readout(affinity: torch.Tensor, mv: torch.Tensor, uncert_mask: torch.Tensor=None) -> torch.Tensor:
98
+ B, CV, T, H, W = mv.shape
99
+
100
+ mo = mv.view(B, CV, T * H * W)
101
+ mem = torch.bmm(mo, affinity)
102
+ if uncert_mask is not None:
103
+ uncert_mask = uncert_mask.flatten(start_dim=2).expand(-1, CV, -1)
104
+ mem = mem * uncert_mask
105
+ mem = mem.view(B, CV, H, W)
106
+
107
+ return mem
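A shape walkthrough for the memory read above with toy sizes and no top-k (assumes the import path `matanyone.model.utils.memory_utils`):

```python
import torch
from matanyone.model.utils.memory_utils import get_affinity, readout

B, CK, CV, T, H, W = 1, 64, 32, 3, 16, 16
mk = torch.randn(B, CK, T, H, W)    # memory keys
ms = torch.rand(B, 1, T, H, W)      # memory shrinkage
qk = torch.randn(B, CK, H, W)       # query keys
qe = torch.rand(B, CK, H, W)        # query selection
mv = torch.randn(B, CV, T, H, W)    # memory values

affinity = get_affinity(mk, ms, qk, qe)   # B x (T*H*W) x (H*W), softmax over memory positions
mem = readout(affinity, mv)               # B x CV x H x W weighted readout
print(affinity.shape, mem.shape)
```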
matanyone/model/utils/parameter_groups.py ADDED
@@ -0,0 +1,72 @@
1
+ import logging
2
+
3
+ log = logging.getLogger()
4
+
5
+
6
+ def get_parameter_groups(model, stage_cfg, print_log=False):
7
+ """
8
+ Assign different weight decays and learning rates to different parameters.
9
+ Returns a parameter group which can be passed to the optimizer.
10
+ """
11
+ weight_decay = stage_cfg.weight_decay
12
+ embed_weight_decay = stage_cfg.embed_weight_decay
13
+ backbone_lr_ratio = stage_cfg.backbone_lr_ratio
14
+ base_lr = stage_cfg.learning_rate
15
+
16
+ backbone_params = []
17
+ embed_params = []
18
+ other_params = []
19
+
20
+ embedding_names = ['summary_pos', 'query_init', 'query_emb', 'obj_pe']
21
+ embedding_names = [e + '.weight' for e in embedding_names]
22
+
23
+ # inspired by detectron2
24
+ memo = set()
25
+ for name, param in model.named_parameters():
26
+ if not param.requires_grad:
27
+ continue
28
+ # Avoid duplicating parameters
29
+ if param in memo:
30
+ continue
31
+ memo.add(param)
32
+
33
+ if name.startswith('module'):
34
+ name = name[7:]
35
+
36
+ inserted = False
37
+ if name.startswith('pixel_encoder.'):
38
+ backbone_params.append(param)
39
+ inserted = True
40
+ if print_log:
41
+ log.info(f'{name} counted as a backbone parameter.')
42
+ else:
43
+ for e in embedding_names:
44
+ if name.endswith(e):
45
+ embed_params.append(param)
46
+ inserted = True
47
+ if print_log:
48
+ log.info(f'{name} counted as an embedding parameter.')
49
+ break
50
+
51
+ if not inserted:
52
+ other_params.append(param)
53
+
54
+ parameter_groups = [
55
+ {
56
+ 'params': backbone_params,
57
+ 'lr': base_lr * backbone_lr_ratio,
58
+ 'weight_decay': weight_decay
59
+ },
60
+ {
61
+ 'params': embed_params,
62
+ 'lr': base_lr,
63
+ 'weight_decay': embed_weight_decay
64
+ },
65
+ {
66
+ 'params': other_params,
67
+ 'lr': base_lr,
68
+ 'weight_decay': weight_decay
69
+ },
70
+ ]
71
+
72
+ return parameter_groups
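A hypothetical use of `get_parameter_groups` with AdamW; the `stage_cfg` fields match the attributes read above, and the toy module stands in for the full network:

```python
from types import SimpleNamespace
import torch
from matanyone.model.utils.parameter_groups import get_parameter_groups

class ToyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.pixel_encoder = torch.nn.Conv2d(3, 8, 3)   # lands in the reduced-lr backbone group
        self.head = torch.nn.Linear(8, 1)               # lands in the default group

stage_cfg = SimpleNamespace(weight_decay=0.05, embed_weight_decay=0.0,
                            backbone_lr_ratio=0.1, learning_rate=1e-4)
groups = get_parameter_groups(ToyNet(), stage_cfg, print_log=True)
optimizer = torch.optim.AdamW(groups)
```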
matanyone/model/utils/resnet.py ADDED
@@ -0,0 +1,179 @@
1
+ """
2
+ resnet.py - A modified ResNet structure
3
+ We append extra channels to the first conv by some network surgery
4
+ """
5
+
6
+ from collections import OrderedDict
7
+ import math
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.utils import model_zoo
12
+
13
+
14
+ def load_weights_add_extra_dim(target, source_state, extra_dim=1):
15
+ new_dict = OrderedDict()
16
+
17
+ for k1, v1 in target.state_dict().items():
18
+ if not 'num_batches_tracked' in k1:
19
+ if k1 in source_state:
20
+ tar_v = source_state[k1]
21
+
22
+ if v1.shape != tar_v.shape:
23
+ # Init the new segmentation channel with zeros
24
+ # print(v1.shape, tar_v.shape)
25
+ c, _, w, h = v1.shape
26
+ pads = torch.zeros((c, extra_dim, w, h), device=tar_v.device)
27
+ nn.init.orthogonal_(pads)
28
+ tar_v = torch.cat([tar_v, pads], 1)
29
+
30
+ new_dict[k1] = tar_v
31
+
32
+ target.load_state_dict(new_dict)
33
+
34
+
35
+ model_urls = {
36
+     'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+     'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+ }
+
+
+ def conv3x3(in_planes, out_planes, stride=1, dilation=1):
+     return nn.Conv2d(in_planes,
+                      out_planes,
+                      kernel_size=3,
+                      stride=stride,
+                      padding=dilation,
+                      dilation=dilation,
+                      bias=False)
+
+
+ class BasicBlock(nn.Module):
+     expansion = 1
+
+     def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
+         super(BasicBlock, self).__init__()
+         self.conv1 = conv3x3(inplanes, planes, stride=stride, dilation=dilation)
+         self.bn1 = nn.BatchNorm2d(planes)
+         self.relu = nn.ReLU(inplace=True)
+         self.conv2 = conv3x3(planes, planes, stride=1, dilation=dilation)
+         self.bn2 = nn.BatchNorm2d(planes)
+         self.downsample = downsample
+         self.stride = stride
+
+     def forward(self, x):
+         residual = x
+
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+
+         if self.downsample is not None:
+             residual = self.downsample(x)
+
+         out += residual
+         out = self.relu(out)
+
+         return out
+
+
+ class Bottleneck(nn.Module):
+     expansion = 4
+
+     def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
+         super(Bottleneck, self).__init__()
+         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(planes)
+         self.conv2 = nn.Conv2d(planes,
+                                planes,
+                                kernel_size=3,
+                                stride=stride,
+                                dilation=dilation,
+                                padding=dilation,
+                                bias=False)
+         self.bn2 = nn.BatchNorm2d(planes)
+         self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
+         self.bn3 = nn.BatchNorm2d(planes * 4)
+         self.relu = nn.ReLU(inplace=True)
+         self.downsample = downsample
+         self.stride = stride
+
+     def forward(self, x):
+         residual = x
+
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+         out = self.relu(out)
+
+         out = self.conv3(out)
+         out = self.bn3(out)
+
+         if self.downsample is not None:
+             residual = self.downsample(x)
+
+         out += residual
+         out = self.relu(out)
+
+         return out
+
+
+ class ResNet(nn.Module):
+     def __init__(self, block, layers=(3, 4, 23, 3), extra_dim=0):
+         self.inplanes = 64
+         super(ResNet, self).__init__()
+         self.conv1 = nn.Conv2d(3 + extra_dim, 64, kernel_size=7, stride=2, padding=3, bias=False)
+         self.bn1 = nn.BatchNorm2d(64)
+         self.relu = nn.ReLU(inplace=True)
+         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+         self.layer1 = self._make_layer(block, 64, layers[0])
+         self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+         self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+         self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                 m.weight.data.normal_(0, math.sqrt(2. / n))
+             elif isinstance(m, nn.BatchNorm2d):
+                 m.weight.data.fill_(1)
+                 m.bias.data.zero_()
+
+     def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
+         downsample = None
+         if stride != 1 or self.inplanes != planes * block.expansion:
+             downsample = nn.Sequential(
+                 nn.Conv2d(self.inplanes,
+                           planes * block.expansion,
+                           kernel_size=1,
+                           stride=stride,
+                           bias=False),
+                 nn.BatchNorm2d(planes * block.expansion),
+             )
+
+         layers = [block(self.inplanes, planes, stride, downsample)]
+         self.inplanes = planes * block.expansion
+         for i in range(1, blocks):
+             layers.append(block(self.inplanes, planes, dilation=dilation))
+
+         return nn.Sequential(*layers)
+
+
+ def resnet18(pretrained=True, extra_dim=0):
+     model = ResNet(BasicBlock, [2, 2, 2, 2], extra_dim)
+     if pretrained:
+         load_weights_add_extra_dim(model, model_zoo.load_url(model_urls['resnet18']), extra_dim)
+     return model
+
+
+ def resnet50(pretrained=True, extra_dim=0):
+     model = ResNet(Bottleneck, [3, 4, 6, 3], extra_dim)
+     if pretrained:
+         load_weights_add_extra_dim(model, model_zoo.load_url(model_urls['resnet50']), extra_dim)
+     return model
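
Note that this ResNet defines no forward(); callers are expected to run the stem and the residual stages themselves to collect multi-scale features, and extra_dim widens the first convolution so additional channels (e.g. a mask) can be concatenated to the RGB input. A minimal sketch of that usage, assuming the matanyone package is importable; the extra_dim value and the dummy input shape below are illustrative, not taken from this commit:

    import torch
    from matanyone.model.utils.resnet import resnet18

    # extra_dim=1 and the 224x224 input are illustrative choices, not part of the commit
    net = resnet18(pretrained=False, extra_dim=1)      # first conv accepts 3 + 1 channels
    x = torch.randn(1, 4, 224, 224)                    # RGB plus one extra channel

    f = net.maxpool(net.relu(net.bn1(net.conv1(x))))   # stem, stride 4
    f4 = net.layer1(f)                                 # stride 4,  64 channels (BasicBlock)
    f8 = net.layer2(f4)                                # stride 8, 128 channels
    f16 = net.layer3(f8)                               # stride 16, 256 channels
    print(f4.shape, f8.shape, f16.shape)
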
matanyone/utils/__init__.py ADDED
File without changes
matanyone/utils/get_default_model.py ADDED
@@ -0,0 +1,23 @@
+ """
+ A helper function to get a default model for quick testing
+ """
+ from omegaconf import open_dict
+ from hydra import compose, initialize
+
+ import torch
+ from matanyone.model.matanyone import MatAnyone
+ from matanyone.inference.utils.args_utils import get_dataset_cfg
+
+ def get_matanyone_model(ckpt_path, device) -> MatAnyone:
+     initialize(version_base='1.3.2', config_path="../config", job_name="eval_our_config")
+     cfg = compose(config_name="eval_matanyone_config")
+
+     with open_dict(cfg):
+         cfg['weights'] = ckpt_path
+
+     # Load the network weights
+     matanyone = MatAnyone(cfg, single_object=True).to(device).eval()
+     model_weights = torch.load(cfg.weights, map_location=device)
+     matanyone.load_weights(model_weights)
+
+     return matanyone
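
For reference, a minimal call of this helper; the checkpoint path below is a hypothetical placeholder for any local MatAnyone weights file. Since Hydra's initialize() registers a global config search path, the helper is meant to be called once per process:

    import torch
    from matanyone.utils.get_default_model import get_matanyone_model

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # "pretrained_models/matanyone.pth" is a placeholder path, not shipped with this commit
    matanyone = get_matanyone_model("pretrained_models/matanyone.pth", device)
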
matanyone/utils/tensor_utils.py ADDED
@@ -0,0 +1,62 @@
+ from typing import List, Iterable
+ import torch
+ import torch.nn.functional as F
+
+
+ # STM
+ def pad_divide_by(in_img: torch.Tensor, d: int) -> (torch.Tensor, Iterable[int]):
+     h, w = in_img.shape[-2:]
+
+     if h % d > 0:
+         new_h = h + d - h % d
+     else:
+         new_h = h
+     if w % d > 0:
+         new_w = w + d - w % d
+     else:
+         new_w = w
+     lh, uh = int((new_h - h) / 2), int(new_h - h) - int((new_h - h) / 2)
+     lw, uw = int((new_w - w) / 2), int(new_w - w) - int((new_w - w) / 2)
+     pad_array = (int(lw), int(uw), int(lh), int(uh))
+     out = F.pad(in_img, pad_array)
+     return out, pad_array
+
+
+ def unpad(img: torch.Tensor, pad: Iterable[int]) -> torch.Tensor:
+     if len(img.shape) == 4:
+         if pad[2] + pad[3] > 0:
+             img = img[:, :, pad[2]:-pad[3], :]
+         if pad[0] + pad[1] > 0:
+             img = img[:, :, :, pad[0]:-pad[1]]
+     elif len(img.shape) == 3:
+         if pad[2] + pad[3] > 0:
+             img = img[:, pad[2]:-pad[3], :]
+         if pad[0] + pad[1] > 0:
+             img = img[:, :, pad[0]:-pad[1]]
+     elif len(img.shape) == 5:
+         if pad[2] + pad[3] > 0:
+             img = img[:, :, :, pad[2]:-pad[3], :]
+         if pad[0] + pad[1] > 0:
+             img = img[:, :, :, :, pad[0]:-pad[1]]
+     else:
+         raise NotImplementedError
+     return img
+
+
+ # @torch.jit.script
+ def aggregate(prob: torch.Tensor, dim: int) -> torch.Tensor:
+     with torch.cuda.amp.autocast(enabled=False):
+         prob = prob.float()
+         new_prob = torch.cat([torch.prod(1 - prob, dim=dim, keepdim=True), prob],
+                              dim).clamp(1e-7, 1 - 1e-7)
+         logits = torch.log((new_prob / (1 - new_prob)))  # (0, 1) --> (-inf, inf)
+
+     return logits
+
+
+ # @torch.jit.script
+ def cls_to_one_hot(cls_gt: torch.Tensor, num_objects: int) -> torch.Tensor:
+     # cls_gt: B*1*H*W
+     B, _, H, W = cls_gt.shape
+     one_hot = torch.zeros(B, num_objects + 1, H, W, device=cls_gt.device).scatter_(1, cls_gt, 1)
+     return one_hot
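
A small sketch of the intended pad_divide_by / unpad round trip: pad so that height and width are divisible by the network stride (16 here, an illustrative choice), run the model on the padded tensor, then crop the output back to the original size. The identity "forward pass" below is only a stand-in for a real network:

    import torch
    from matanyone.utils.tensor_utils import pad_divide_by, unpad

    img = torch.randn(1, 3, 270, 480)      # arbitrary example size
    padded, pad = pad_divide_by(img, 16)   # pad = (left, right, top, bottom), as expected by F.pad
    out = padded                           # placeholder for a network forward pass
    restored = unpad(out, pad)
    assert restored.shape == img.shape     # back to (1, 3, 270, 480)
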
requirements.txt ADDED
@@ -0,0 +1,35 @@
+ progressbar2
+ gdown >= 4.7.1
+ gitpython >= 3.1
+ git+https://github.com/cheind/py-thin-plate-spline
+ hickle >= 5.0
+ tensorboard >= 2.11
+ numpy >= 1.21
+ git+https://github.com/facebookresearch/segment-anything.git
+ gradio==4.31.0
+ fastapi==0.111.0
+ pydantic==2.7.1
+ opencv-python >= 4.8
+ matplotlib
+ pyyaml
+ av >= 0.5.2
+ openmim
+ tqdm >= 4.66.1
+ psutil
+ ffmpeg-python
+ cython
+ Pillow >= 9.5
+ scipy >= 1.7
+ pycocotools >= 2.0.7
+ einops >= 0.6
+ hydra-core >= 1.3.2
+ PySide6 >= 6.2.0
+ charset-normalizer >= 3.1.0
+ netifaces >= 0.11.0
+ cchardet >= 2.1.7
+ easydict
+ requests
+ pyqtdarktheme
+ imageio == 2.25.0
+ imageio[ffmpeg]
+ ffmpeg-python