Spaces: Running on L4
Commit dcc8c59 · Parent(s): 941f6aa

init
Browse files

- .gitignore +8 -0
- LICENSE +42 -0
- README.md +5 -4
- hugging_face/app.py +967 -0
- hugging_face/matanyone_wrapper.py +73 -0
- hugging_face/tools/__init__.py +0 -0
- hugging_face/tools/base_segmenter.py +129 -0
- hugging_face/tools/download_util.py +109 -0
- hugging_face/tools/interact_tools.py +99 -0
- hugging_face/tools/mask_painter.py +288 -0
- hugging_face/tools/misc.py +131 -0
- hugging_face/tools/painter.py +215 -0
- matanyone/config/__init__.py +0 -0
- matanyone/config/eval_matanyone_config.yaml +47 -0
- matanyone/config/hydra/job_logging/custom-no-rank.yaml +22 -0
- matanyone/config/hydra/job_logging/custom.yaml +22 -0
- matanyone/config/model/base.yaml +58 -0
- matanyone/inference/__init__.py +0 -0
- matanyone/inference/image_feature_store.py +56 -0
- matanyone/inference/inference_core.py +407 -0
- matanyone/inference/kv_memory_store.py +348 -0
- matanyone/inference/memory_manager.py +457 -0
- matanyone/inference/object_info.py +24 -0
- matanyone/inference/object_manager.py +149 -0
- matanyone/inference/utils/__init__.py +0 -0
- matanyone/inference/utils/args_utils.py +30 -0
- matanyone/model/__init__.py +0 -0
- matanyone/model/aux_modules.py +93 -0
- matanyone/model/big_modules.py +358 -0
- matanyone/model/channel_attn.py +39 -0
- matanyone/model/group_modules.py +126 -0
- matanyone/model/matanyone.py +323 -0
- matanyone/model/modules.py +170 -0
- matanyone/model/transformer/__init__.py +0 -0
- matanyone/model/transformer/object_summarizer.py +89 -0
- matanyone/model/transformer/object_transformer.py +206 -0
- matanyone/model/transformer/positional_encoding.py +108 -0
- matanyone/model/transformer/transformer_layers.py +161 -0
- matanyone/model/utils/__init__.py +0 -0
- matanyone/model/utils/memory_utils.py +107 -0
- matanyone/model/utils/parameter_groups.py +72 -0
- matanyone/model/utils/resnet.py +179 -0
- matanyone/utils/__init__.py +0 -0
- matanyone/utils/get_default_model.py +23 -0
- matanyone/utils/tensor_utils.py +62 -0
- requirements.txt +35 -0
.gitignore
ADDED
@@ -0,0 +1,8 @@
+__pycache__/
+.vscode/
+.DS_Store
+assets/
+inputs/
+test_sample/
+results/
+pretrained_models/
LICENSE
ADDED
@@ -0,0 +1,42 @@
+# S-Lab License 1.0
+
+Copyright 2023 S-Lab
+
+Redistribution and use for non-commercial purpose in source and
+binary forms, with or without modification, are permitted provided
+that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in
+   the documentation and/or other materials provided with the
+   distribution.
+
+3. Neither the name of the copyright holder nor the names of its
+   contributors may be used to endorse or promote products derived
+   from this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+In the event that redistribution and/or use for commercial purpose in
+source or binary forms, with or without modification is required,
+please contact the contributor(s) of the work.
+
+
+---
+For inquiries permission for commercial use, please consult our team:
+Peiqing Yang ([email protected]),
+Dr. Shangchen Zhou ([email protected]),
+Prof. Chen Change Loy ([email protected])
README.md
CHANGED
@@ -1,13 +1,14 @@
 ---
 title: MatAnyone
-emoji:
-colorFrom:
-colorTo:
+emoji: 🤡
+colorFrom: red
+colorTo: green
 sdk: gradio
 sdk_version: 5.16.0
-app_file: app.py
+app_file: hugging_face/app.py
 pinned: false
 license: other
+short_description: Gradio demo for MatAnyone
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
hugging_face/app.py
ADDED
@@ -0,0 +1,967 @@
+import sys
+sys.path.append("../")
+sys.path.append("../../")
+
+import os
+import json
+import time
+import psutil
+import ffmpeg
+import imageio
+import argparse
+from PIL import Image
+
+import cv2
+import torch
+import numpy as np
+import gradio as gr
+
+from tools.painter import mask_painter
+from tools.interact_tools import SamControler
+from tools.misc import get_device
+from tools.download_util import load_file_from_url
+
+from matanyone_wrapper import matanyone
+from matanyone.utils.get_default_model import get_matanyone_model
+from matanyone.inference.inference_core import InferenceCore
+
+def parse_augment():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--device', type=str, default=None)
+    parser.add_argument('--sam_model_type', type=str, default="vit_h")
+    parser.add_argument('--port', type=int, default=8000, help="only useful when running gradio applications")
+    parser.add_argument('--mask_save', default=False)
+    args = parser.parse_args()
+
+    if not args.device:
+        args.device = str(get_device())
+
+    return args
+
+# SAM generator
+class MaskGenerator():
+    def __init__(self, sam_checkpoint, args):
+        self.args = args
+        self.samcontroler = SamControler(sam_checkpoint, args.sam_model_type, args.device)
+
+    def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
+        mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask)
+        return mask, logit, painted_image
+
+# convert points input to prompt state
+def get_prompt(click_state, click_input):
+    inputs = json.loads(click_input)
+    points = click_state[0]
+    labels = click_state[1]
+    for input in inputs:
+        points.append(input[:2])
+        labels.append(input[2])
+    click_state[0] = points
+    click_state[1] = labels
+    prompt = {
+        "prompt_type":["click"],
+        "input_point":click_state[0],
+        "input_label":click_state[1],
+        "multimask_output":"True",
+    }
+    return prompt
+
+def get_frames_from_image(image_input, image_state):
+    """
+    Args:
+        video_path:str
+        timestamp:float64
+    Return
+        [[0:nearest_frame], [nearest_frame:], nearest_frame]
+    """
+
+    user_name = time.time()
+    frames = [image_input] * 2  # hardcode: mimic a video with 2 frames
+    image_size = (frames[0].shape[0],frames[0].shape[1])
+    # initialize video_state
+    image_state = {
+        "user_name": user_name,
+        "image_name": "output.png",
+        "origin_images": frames,
+        "painted_images": frames.copy(),
+        "masks": [np.zeros((frames[0].shape[0],frames[0].shape[1]), np.uint8)]*len(frames),
+        "logits": [None]*len(frames),
+        "select_frame_number": 0,
+        "fps": None
+    }
+    image_info = "Image Name: N/A,\nFPS: N/A,\nTotal Frames: {},\nImage Size:{}".format(len(frames), image_size)
+    model.samcontroler.sam_controler.reset_image()
+    model.samcontroler.sam_controler.set_image(image_state["origin_images"][0])
+    return image_state, image_info, image_state["origin_images"][0], \
+        gr.update(visible=True, maximum=10, value=10), gr.update(visible=False, maximum=len(frames), value=len(frames)), \
+        gr.update(visible=True), gr.update(visible=True), \
+        gr.update(visible=True), gr.update(visible=True),\
+        gr.update(visible=True), gr.update(visible=True), \
+        gr.update(visible=True), gr.update(visible=True), \
+        gr.update(visible=True), gr.update(visible=True), \
+        gr.update(visible=True), gr.update(visible=True, value=[]), \
+        gr.update(visible=True)
+
+# extract frames from upload video
+def get_frames_from_video(video_input, video_state):
+    """
+    Args:
+        video_path:str
+        timestamp:float64
+    Return
+        [[0:nearest_frame], [nearest_frame:], nearest_frame]
+    """
+    video_path = video_input
+    frames = []
+    user_name = time.time()
+
+    # extract Audio
+    audio_path = "audio.wav"
+    audio_path = video_input.replace(".mp4", "_audio.wav")
+    try:
+        ffmpeg.input(video_path).output(audio_path, format='wav', acodec='pcm_s16le', ac=2, ar='44100').run(overwrite_output=True, quiet=True)
+    except Exception as e:
+        print(f"Audio extraction error: {str(e)}")
+        audio_path = ""  # Set to "" if extraction fails
+    # print(f'audio_path: {audio_path}')
+
+    # extract frames
+    try:
+        cap = cv2.VideoCapture(video_path)
+        fps = cap.get(cv2.CAP_PROP_FPS)
+        while cap.isOpened():
+            ret, frame = cap.read()
+            if ret == True:
+                current_memory_usage = psutil.virtual_memory().percent
+                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+                if current_memory_usage > 90:
+                    break
+            else:
+                break
+    except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e:
+        print("read_frame_source:{} error. {}\n".format(video_path, str(e)))
+    image_size = (frames[0].shape[0],frames[0].shape[1])
+
+    # initialize video_state
+    video_state = {
+        "user_name": user_name,
+        "video_name": os.path.split(video_path)[-1],
+        "origin_images": frames,
+        "painted_images": frames.copy(),
+        "masks": [np.zeros((frames[0].shape[0],frames[0].shape[1]), np.uint8)]*len(frames),
+        "logits": [None]*len(frames),
+        "select_frame_number": 0,
+        "fps": fps,
+        "audio": audio_path
+    }
+    video_info = "Video Name: {},\nFPS: {},\nTotal Frames: {},\nImage Size:{}".format(video_state["video_name"], round(video_state["fps"], 0), len(frames), image_size)
+    model.samcontroler.sam_controler.reset_image()
+    model.samcontroler.sam_controler.set_image(video_state["origin_images"][0])
+    return video_state, video_info, video_state["origin_images"][0], gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=False, maximum=len(frames), value=len(frames)), \
+        gr.update(visible=True), gr.update(visible=True), \
+        gr.update(visible=True), gr.update(visible=True),\
+        gr.update(visible=True), gr.update(visible=True), \
+        gr.update(visible=True), gr.update(visible=False), \
+        gr.update(visible=False), gr.update(visible=True), \
+        gr.update(visible=True)
+
+# get the select frame from gradio slider
+def select_video_template(image_selection_slider, video_state, interactive_state):
+
+    image_selection_slider -= 1
+    video_state["select_frame_number"] = image_selection_slider
+
+    # once select a new template frame, set the image in sam
+    model.samcontroler.sam_controler.reset_image()
+    model.samcontroler.sam_controler.set_image(video_state["origin_images"][image_selection_slider])
+
+    return video_state["painted_images"][image_selection_slider], video_state, interactive_state
+
+def select_image_template(image_selection_slider, video_state, interactive_state):
+
+    image_selection_slider = 0  # fixed for image
+    video_state["select_frame_number"] = image_selection_slider
+
+    # once select a new template frame, set the image in sam
+    model.samcontroler.sam_controler.reset_image()
+    model.samcontroler.sam_controler.set_image(video_state["origin_images"][image_selection_slider])
+
+    return video_state["painted_images"][image_selection_slider], video_state, interactive_state
+
+# set the tracking end frame
+def get_end_number(track_pause_number_slider, video_state, interactive_state):
+    interactive_state["track_end_number"] = track_pause_number_slider
+
+    return video_state["painted_images"][track_pause_number_slider],interactive_state
+
+# use sam to get the mask
+def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr.SelectData):
+    """
+    Args:
+        template_frame: PIL.Image
+        point_prompt: flag for positive or negative button click
+        click_state: [[points], [labels]]
+    """
+    if point_prompt == "Positive":
+        coordinate = "[[{},{},1]]".format(evt.index[0], evt.index[1])
+        interactive_state["positive_click_times"] += 1
+    else:
+        coordinate = "[[{},{},0]]".format(evt.index[0], evt.index[1])
+        interactive_state["negative_click_times"] += 1
+
+    # prompt for sam model
+    model.samcontroler.sam_controler.reset_image()
+    model.samcontroler.sam_controler.set_image(video_state["origin_images"][video_state["select_frame_number"]])
+    prompt = get_prompt(click_state=click_state, click_input=coordinate)
+
+    mask, logit, painted_image = model.first_frame_click(
+        image=video_state["origin_images"][video_state["select_frame_number"]],
+        points=np.array(prompt["input_point"]),
+        labels=np.array(prompt["input_label"]),
+        multimask=prompt["multimask_output"],
+    )
+    video_state["masks"][video_state["select_frame_number"]] = mask
+    video_state["logits"][video_state["select_frame_number"]] = logit
+    video_state["painted_images"][video_state["select_frame_number"]] = painted_image
+
+    return painted_image, video_state, interactive_state
+
+def add_multi_mask(video_state, interactive_state, mask_dropdown):
+    mask = video_state["masks"][video_state["select_frame_number"]]
+    interactive_state["multi_mask"]["masks"].append(mask)
+    interactive_state["multi_mask"]["mask_names"].append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
+    mask_dropdown.append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
+    select_frame = show_mask(video_state, interactive_state, mask_dropdown)
+
+    return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown), select_frame, [[],[]]
+
+def clear_click(video_state, click_state):
+    click_state = [[],[]]
+    template_frame = video_state["origin_images"][video_state["select_frame_number"]]
+    return template_frame, click_state
+
+def remove_multi_mask(interactive_state, mask_dropdown):
+    interactive_state["multi_mask"]["mask_names"]= []
+    interactive_state["multi_mask"]["masks"] = []
+
+    return interactive_state, gr.update(choices=[],value=[])
+
+def show_mask(video_state, interactive_state, mask_dropdown):
+    mask_dropdown.sort()
+    if video_state["origin_images"]:
+        select_frame = video_state["origin_images"][video_state["select_frame_number"]]
+    for i in range(len(mask_dropdown)):
+        mask_number = int(mask_dropdown[i].split("_")[1]) - 1
+        mask = interactive_state["multi_mask"]["masks"][mask_number]
+        select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number+2)
+
+    return select_frame
+
+# image matting
+def image_matting(video_state, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size, refine_iter):
+    matanyone_processor.clear_memory()
+    if interactive_state["track_end_number"]:
+        following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
+    else:
+        following_frames = video_state["origin_images"][video_state["select_frame_number"]:]
+
+    if interactive_state["multi_mask"]["masks"]:
+        if len(mask_dropdown) == 0:
+            mask_dropdown = ["mask_001"]
+        mask_dropdown.sort()
+        template_mask = interactive_state["multi_mask"]["masks"][int(mask_dropdown[0].split("_")[1]) - 1] * (int(mask_dropdown[0].split("_")[1]))
+        for i in range(1,len(mask_dropdown)):
+            mask_number = int(mask_dropdown[i].split("_")[1]) - 1
+            template_mask = np.clip(template_mask+interactive_state["multi_mask"]["masks"][mask_number]*(mask_number+1), 0, mask_number+1)
+        video_state["masks"][video_state["select_frame_number"]]= template_mask
+    else:
+        template_mask = video_state["masks"][video_state["select_frame_number"]]
+
+    # operation error
+    if len(np.unique(template_mask))==1:
+        template_mask[0][0]=1
+    foreground, alpha = matanyone(matanyone_processor, following_frames, template_mask*255, r_erode=erode_kernel_size, r_dilate=dilate_kernel_size, n_warmup=refine_iter)
+    foreground_output = Image.fromarray(foreground[-1])
+    alpha_output = Image.fromarray(alpha[-1][:,:,0])
+    return foreground_output, alpha_output
+
+# video matting
+def video_matting(video_state, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size):
+    matanyone_processor.clear_memory()
+    if interactive_state["track_end_number"]:
+        following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
+    else:
+        following_frames = video_state["origin_images"][video_state["select_frame_number"]:]
+
+    if interactive_state["multi_mask"]["masks"]:
+        if len(mask_dropdown) == 0:
+            mask_dropdown = ["mask_001"]
+        mask_dropdown.sort()
+        template_mask = interactive_state["multi_mask"]["masks"][int(mask_dropdown[0].split("_")[1]) - 1] * (int(mask_dropdown[0].split("_")[1]))
+        for i in range(1,len(mask_dropdown)):
+            mask_number = int(mask_dropdown[i].split("_")[1]) - 1
+            template_mask = np.clip(template_mask+interactive_state["multi_mask"]["masks"][mask_number]*(mask_number+1), 0, mask_number+1)
+        video_state["masks"][video_state["select_frame_number"]]= template_mask
+    else:
+        template_mask = video_state["masks"][video_state["select_frame_number"]]
+    fps = video_state["fps"]
+
+    audio_path = video_state["audio"]
+
+    # operation error
+    if len(np.unique(template_mask))==1:
+        template_mask[0][0]=1
+    foreground, alpha = matanyone(matanyone_processor, following_frames, template_mask*255, r_erode=erode_kernel_size, r_dilate=dilate_kernel_size)
+
+    foreground_output = generate_video_from_frames(foreground, output_path="./results/{}_fg.mp4".format(video_state["video_name"]), fps=fps, audio_path=audio_path)  # import video_input to name the output video
+    alpha_output = generate_video_from_frames(alpha, output_path="./results/{}_alpha.mp4".format(video_state["video_name"]), fps=fps, gray2rgb=True, audio_path=audio_path)  # import video_input to name the output video
+
+    return foreground_output, alpha_output
+
+
+def add_audio_to_video(video_path, audio_path, output_path):
+    try:
+        video_input = ffmpeg.input(video_path)
+        audio_input = ffmpeg.input(audio_path)
+
+        _ = (
+            ffmpeg
+            .output(video_input, audio_input, output_path, vcodec="copy", acodec="aac")
+            .run(overwrite_output=True, capture_stdout=True, capture_stderr=True)
+        )
+        return output_path
+    except ffmpeg.Error as e:
+        print(f"FFmpeg error:\n{e.stderr.decode()}")
+        return None
+
+
+def generate_video_from_frames(frames, output_path, fps=30, gray2rgb=False, audio_path=""):
+    """
+    Generates a video from a list of frames.
+
+    Args:
+        frames (list of numpy arrays): The frames to include in the video.
+        output_path (str): The path to save the generated video.
+        fps (int, optional): The frame rate of the output video. Defaults to 30.
+    """
+    frames = torch.from_numpy(np.asarray(frames))
+    _, h, w, _ = frames.shape
+    if gray2rgb:
+        frames = np.repeat(frames, 3, axis=3)
+
+    if not os.path.exists(os.path.dirname(output_path)):
+        os.makedirs(os.path.dirname(output_path))
+    video_temp_path = output_path.replace(".mp4", "_temp.mp4")
+
+    # resize back to ensure input resolution
+    imageio.mimwrite(video_temp_path, frames, fps=fps, quality=7,
+                     codec='libx264', ffmpeg_params=["-vf", f"scale={w}:{h}"])
+
+    # add audio to video if audio path exists
+    if audio_path != "" and os.path.exists(audio_path):
+        output_path = add_audio_to_video(video_temp_path, audio_path, output_path)
+        os.remove(video_temp_path)
+        return output_path
+    else:
+        return video_temp_path
+
+# reset all states for a new input
+def restart():
+    return {
+        "user_name": "",
+        "video_name": "",
+        "origin_images": None,
+        "painted_images": None,
+        "masks": None,
+        "inpaint_masks": None,
+        "logits": None,
+        "select_frame_number": 0,
+        "fps": 30
+    }, {
+        "inference_times": 0,
+        "negative_click_times" : 0,
+        "positive_click_times": 0,
+        "mask_save": args.mask_save,
+        "multi_mask": {
+            "mask_names": [],
+            "masks": []
+        },
+        "track_end_number": None,
+    }, [[],[]], None, None, \
+        gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),\
+        gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
+        gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
+        gr.update(visible=False), gr.update(visible=False, choices=[], value=[]), "", gr.update(visible=False)
+
+# args, defined in track_anything.py
+args = parse_augment()
+sam_checkpoint_url_dict = {
+    'vit_h': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
+    'vit_l': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
+    'vit_b': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
+}
+checkpoint_folder = os.path.join('/home/user/app/', 'pretrained_models')
+
+sam_checkpoint = load_file_from_url(sam_checkpoint_url_dict[args.sam_model_type], checkpoint_folder)
+# initialize sams
+model = MaskGenerator(sam_checkpoint, args)
+
+# initialize matanyone
+pretrain_model_url = "https://github.com/pq-yang/MatAnyone/releases/download/v1.0.0"
+ckpt_path = load_file_from_url(os.path.join(pretrain_model_url, 'matanyone.pth'), checkpoint_folder)
+matanyone_model = get_matanyone_model(ckpt_path, args.device)
+matanyone_model = matanyone_model.to(args.device).eval()
+matanyone_processor = InferenceCore(matanyone_model, cfg=matanyone_model.cfg)
+
+# download test samples
+media_url = "https://github.com/pq-yang/MatAnyone/releases/download/media/"
+test_sample_path = os.path.join('/home/user/app/hugging_face/', "test_sample/")
+load_file_from_url(os.path.join(media_url, 'test-sample0.mp4'), test_sample_path)
+load_file_from_url(os.path.join(media_url, 'test-sample1.mp4'), test_sample_path)
+load_file_from_url(os.path.join(media_url, 'test-sample2.mp4'), test_sample_path)
+load_file_from_url(os.path.join(media_url, 'test-sample3.mp4'), test_sample_path)
+load_file_from_url(os.path.join(media_url, 'test-sample4.mp4'), test_sample_path)
+load_file_from_url(os.path.join(media_url, 'test-sample5.mp4'), test_sample_path)
+load_file_from_url(os.path.join(media_url, 'test-sample6.mp4'), test_sample_path)
+load_file_from_url(os.path.join(media_url, 'test-sample0.png'), test_sample_path)
+load_file_from_url(os.path.join(media_url, 'test-sample1.png'), test_sample_path)
+
+# download assets
+assets_path = os.path.join('/home/user/app/hugging_face/', "assets/")
+load_file_from_url(os.path.join(media_url, 'tutorial_single_target.mp4'), assets_path)
+load_file_from_url(os.path.join(media_url, 'tutorial_multi_targets.mp4'), assets_path)
+
+# documents
+title = r"""<div class="multi-layer" align="center"><span>MatAnyone</span></div>
+"""
+description = r"""
+<b>Official Gradio demo</b> for <a href='https://github.com/pq-yang/MatAnyone' target='_blank'><b>MatAnyone: Stable Video Matting with Consistent Memory Propagation</b></a>.<br>
+🔥 MatAnyone is a practical human video matting framework supporting target assignment 🎯.<br>
+🎪 Try to drop your video/image, assign the target masks with a few clicks, and get the matting results 🤡!<br>
+"""
+article = r"""
+<b>If MatAnyone is helpful, please help to 🌟 the <a href='https://github.com/pq-yang/MatAnyone' target='_blank'>Github Repo</a>. Thanks!</b>
+
+---
+
+📑 **Citation**
+<br>
+If our work is useful for your research, please consider citing:
+```bibtex
+@InProceedings{yang2025matanyone,
+    title     = {{MatAnyone}: Stable Video Matting with Consistent Memory Propagation},
+    author    = {Yang, Peiqing and Zhou, Shangchen and Zhao, Jixin and Tao, Qingyi and Loy, Chen Change},
+    booktitle = {arXiv preprint arXiv:2501.14677},
+    year      = {2025}
+}
+```
+📝 **License**
+<br>
+This project is licensed under <a rel="license" href="https://github.com/pq-yang/MatAnyone/blob/main/LICENSE">S-Lab License 1.0</a>.
+Redistribution and use for non-commercial purposes should follow this license.
+<br>
+📧 **Contact**
+<br>
+If you have any questions, please feel free to reach out to me at <b>[email protected]</b>.
+<br>
+👏 **Acknowledgement**
+<br>
+The project is developed upon [Cutie](https://github.com/hkchengrex/Cutie), and harnesses the capabilities from [Segment Anything](https://github.com/facebookresearch/segment-anything). Thanks for their awesome works!
+"""
+
+my_custom_css = """
+.gradio-container {width: 85% !important; margin: 0 auto;}
+.gr-monochrome-group {border-radius: 5px !important; border: revert-layer !important; border-width: 2px !important; color: black !important}
+button {border-radius: 8px !important;}
+.new_button {background-color: #171717 !important; color: #ffffff !important; border: none !important;}
+.green_button {background-color: #4CAF50 !important; color: #ffffff !important; border: none !important;}
+.new_button:hover {background-color: #4b4b4b !important;}
+.green_button:hover {background-color: #77bd79 !important;}
+
+.mask_button_group {gap: 10px !important;}
+.video .wrap.svelte-lcpz3o {
+    display: flex !important;
+    align-items: center !important;
+    justify-content: center !important;
+    height: auto !important;
+    max-height: 300px !important;
+}
+.video .wrap.svelte-lcpz3o > :first-child {
+    height: auto !important;
+    width: 100% !important;
+    object-fit: contain !important;
+}
+.video .container.svelte-sxyn79 {
+    display: none !important;
+}
+.margin_center {width: 50% !important; margin: auto !important;}
+.jc_center {justify-content: center !important;}
+.video-title {
+    margin-bottom: 5px !important;
+}
+.custom-bg {
+    background-color: #f0f0f0;
+    padding: 10px;
+    border-radius: 10px;
+}
+
+<style>
+@import url('https://fonts.googleapis.com/css2?family=Sarpanch:wght@400;500;600;700;800;900&family=Sen:[email protected]&family=Sixtyfour+Convergence&family=Stardos+Stencil:wght@400;700&display=swap');
+body {
+    display: flex;
+    justify-content: center;
+    align-items: center;
+    height: 100vh;
+    margin: 0;
+    background-color: #0d1117;
+    font-family: Arial, sans-serif;
+    font-size: 18px;
+}
+.title-container {
+    text-align: center;
+    padding: 0;
+    margin: 0;
+    background: white;
+    height: 5vh;
+    width: 80vw;
+    font-family: "Sarpanch", sans-serif;
+    font-weight: 60;
+}
+#custom-markdown {
+    font-family: "Roboto", sans-serif;
+    font-size: 18px;
+    color: #333333;
+    font-weight: bold;
+}
+small {
+    font-size: 60%;
+}
+</style>
+"""
+
+with gr.Blocks(theme=gr.themes.Monochrome(), css=my_custom_css) as demo:
+    gr.HTML('''
+        <div class="title-container">
+            <h1 class="title is-2 publication-title"
+                style="font-size:50px; font-family: 'Sarpanch', serif;
+                background: linear-gradient(to right, #d231d8, #2dc464);
+                display: inline-block; -webkit-background-clip: text;
+                -webkit-text-fill-color: transparent;">
+                MatAnyone
+            </h1>
+        </div>
+    ''')
+
+    gr.Markdown(description)
+
+    with gr.Group(elem_classes="gr-monochrome-group", visible=True):
+        with gr.Row():
+            with gr.Accordion("📕 Video Tutorial (click to expand)", open=False, elem_classes="custom-bg"):
+                with gr.Row():
+                    with gr.Column():
+                        gr.Markdown("### Case 1: Single Target")
+                        gr.Video(value="/home/user/app/hugging_face/assets/tutorial_single_target.mp4", elem_classes="video")
+
+                    with gr.Column():
+                        gr.Markdown("### Case 2: Multiple Targets")
+                        gr.Video(value="/home/user/app/hugging_face/assets/tutorial_multi_targets.mp4", elem_classes="video")
+
+    with gr.Tabs():
+        with gr.TabItem("Video"):
+            click_state = gr.State([[],[]])
+
+            interactive_state = gr.State({
+                "inference_times": 0,
+                "negative_click_times" : 0,
+                "positive_click_times": 0,
+                "mask_save": args.mask_save,
+                "multi_mask": {
+                    "mask_names": [],
+                    "masks": []
+                },
+                "track_end_number": None,
+                }
+            )
+
+            video_state = gr.State(
+                {
+                    "user_name": "",
+                    "video_name": "",
+                    "origin_images": None,
+                    "painted_images": None,
+                    "masks": None,
+                    "inpaint_masks": None,
+                    "logits": None,
+                    "select_frame_number": 0,
+                    "fps": 30,
+                    "audio": "",
+                }
+            )
+
+            with gr.Group(elem_classes="gr-monochrome-group", visible=True):
+                with gr.Row():
+                    with gr.Accordion('MatAnyone Settings (click to expand)', open=False):
+                        with gr.Row():
+                            erode_kernel_size = gr.Slider(label='Erode Kernel Size',
+                                                          minimum=0,
+                                                          maximum=30,
+                                                          step=1,
+                                                          value=10,
+                                                          info="Erosion on the added mask",
+                                                          interactive=True)
+                            dilate_kernel_size = gr.Slider(label='Dilate Kernel Size',
+                                                           minimum=0,
+                                                           maximum=30,
+                                                           step=1,
+                                                           value=10,
+                                                           info="Dilation on the added mask",
+                                                           interactive=True)
+
+                        with gr.Row():
+                            image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Start Frame", info="Choose the start frame for target assignment and video matting", visible=False)
+                            track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frame", visible=False)
+                        with gr.Row():
+                            point_prompt = gr.Radio(
+                                choices=["Positive", "Negative"],
+                                value="Positive",
+                                label="Point Prompt",
+                                info="Click to add positive or negative point for target mask",
+                                interactive=True,
+                                visible=False,
+                                min_width=100,
+                                scale=1)
+                            mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask Selection", info="Choose 1~all mask(s) added in Step 2", visible=False)
+
+            gr.Markdown("---")
+
+            with gr.Column():
+                # input video
+                with gr.Row(equal_height=True):
+                    with gr.Column(scale=2):
+                        gr.Markdown("## Step1: Upload video")
+                    with gr.Column(scale=2):
+                        step2_title = gr.Markdown("## Step2: Add masks <small>(Several clicks then **`Add Mask`** <u>one by one</u>)</small>", visible=False)
+                with gr.Row(equal_height=True):
+                    with gr.Column(scale=2):
+                        video_input = gr.Video(label="Input Video", elem_classes="video")
+                        extract_frames_button = gr.Button(value="Load Video", interactive=True, elem_classes="new_button")
+                    with gr.Column(scale=2):
+                        video_info = gr.Textbox(label="Video Info", visible=False)
+                        template_frame = gr.Image(label="Start Frame", type="pil",interactive=True, elem_id="template_frame", visible=False, elem_classes="image")
+                        with gr.Row(equal_height=True, elem_classes="mask_button_group"):
+                            clear_button_click = gr.Button(value="Clear Clicks", interactive=True, visible=False, elem_classes="new_button", min_width=100)
+                            add_mask_button = gr.Button(value="Add Mask", interactive=True, visible=False, elem_classes="new_button", min_width=100)
+                            remove_mask_button = gr.Button(value="Remove Mask", interactive=True, visible=False, elem_classes="new_button", min_width=100)  # no use
+                            matting_button = gr.Button(value="Video Matting", interactive=True, visible=False, elem_classes="green_button", min_width=100)
+
+                gr.HTML('<hr style="border: none; height: 1.5px; background: linear-gradient(to right, #a566b4, #74a781);margin: 5px 0;">')
+
+                # output video
+                with gr.Row(equal_height=True):
+                    with gr.Column(scale=2):
+                        foreground_video_output = gr.Video(label="Foreground Output", visible=False, elem_classes="video")
+                        foreground_output_button = gr.Button(value="Foreground Output", visible=False, elem_classes="new_button")
+                    with gr.Column(scale=2):
+                        alpha_video_output = gr.Video(label="Alpha Output", visible=False, elem_classes="video")
+                        alpha_output_button = gr.Button(value="Alpha Mask Output", visible=False, elem_classes="new_button")
+
+
+            # first step: get the video information
+            extract_frames_button.click(
+                fn=get_frames_from_video,
+                inputs=[
+                    video_input, video_state
+                ],
+                outputs=[video_state, video_info, template_frame,
+                         image_selection_slider, track_pause_number_slider, point_prompt, clear_button_click, add_mask_button, matting_button, template_frame,
+                         foreground_video_output, alpha_video_output, foreground_output_button, alpha_output_button, mask_dropdown, step2_title]
+            )
+
+            # second step: select images from slider
+            image_selection_slider.release(fn=select_video_template,
+                                           inputs=[image_selection_slider, video_state, interactive_state],
+                                           outputs=[template_frame, video_state, interactive_state], api_name="select_image")
+            track_pause_number_slider.release(fn=get_end_number,
+                                              inputs=[track_pause_number_slider, video_state, interactive_state],
+                                              outputs=[template_frame, interactive_state], api_name="end_image")
+
+            # click select image to get mask using sam
+            template_frame.select(
+                fn=sam_refine,
+                inputs=[video_state, point_prompt, click_state, interactive_state],
+                outputs=[template_frame, video_state, interactive_state]
+            )
+
+            # add different mask
+            add_mask_button.click(
+                fn=add_multi_mask,
+                inputs=[video_state, interactive_state, mask_dropdown],
+                outputs=[interactive_state, mask_dropdown, template_frame, click_state]
+            )
+
+            remove_mask_button.click(
+                fn=remove_multi_mask,
+                inputs=[interactive_state, mask_dropdown],
+                outputs=[interactive_state, mask_dropdown]
+            )
+
+            # video matting
+            matting_button.click(
+                fn=video_matting,
+                inputs=[video_state, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size],
+                outputs=[foreground_video_output, alpha_video_output]
+            )
+
+            # click to get mask
+            mask_dropdown.change(
+                fn=show_mask,
+                inputs=[video_state, interactive_state, mask_dropdown],
+                outputs=[template_frame]
+            )
+
+            # clear input
+            video_input.change(
+                fn=restart,
+                inputs=[],
+                outputs=[
+                    video_state,
+                    interactive_state,
+                    click_state,
+                    foreground_video_output, alpha_video_output,
+                    template_frame,
+                    image_selection_slider, track_pause_number_slider, point_prompt, clear_button_click,
+                    add_mask_button, matting_button, template_frame, foreground_video_output, alpha_video_output, remove_mask_button, foreground_output_button, alpha_output_button, mask_dropdown, video_info, step2_title
+                ],
+                queue=False,
+                show_progress=False)
+
+            video_input.clear(
+                fn=restart,
+                inputs=[],
+                outputs=[
+                    video_state,
+                    interactive_state,
+                    click_state,
+                    foreground_video_output, alpha_video_output,
+                    template_frame,
+                    image_selection_slider, track_pause_number_slider, point_prompt, clear_button_click,
+                    add_mask_button, matting_button, template_frame, foreground_video_output, alpha_video_output, remove_mask_button, foreground_output_button, alpha_output_button, mask_dropdown, video_info, step2_title
+                ],
+                queue=False,
+                show_progress=False)
+
+            # points clear
+            clear_button_click.click(
+                fn = clear_click,
+                inputs = [video_state, click_state,],
+                outputs = [template_frame,click_state],
+            )
+
+            # set example
+            gr.Markdown("---")
+            gr.Markdown("## Examples")
+            gr.Examples(
+                examples=[os.path.join(os.path.dirname(__file__), "./test_sample/", test_sample) for test_sample in ["test-sample0.mp4", "test-sample1.mp4", "test-sample2.mp4", "test-sample3.mp4", "test-sample4.mp4", "test-sample5.mp4", "test-sample6.mp4"]],
+                inputs=[video_input],
+            )
+
+        with gr.TabItem("Image"):
+            click_state = gr.State([[],[]])
+
+            interactive_state = gr.State({
+                "inference_times": 0,
+                "negative_click_times" : 0,
+                "positive_click_times": 0,
+                "mask_save": args.mask_save,
+                "multi_mask": {
+                    "mask_names": [],
+                    "masks": []
+                },
+                "track_end_number": None,
+                }
+            )
+
+            image_state = gr.State(
+                {
+                    "user_name": "",
+                    "image_name": "",
+                    "origin_images": None,
+                    "painted_images": None,
+                    "masks": None,
+                    "inpaint_masks": None,
+                    "logits": None,
+                    "select_frame_number": 0,
+                    "fps": 30
+                }
+            )
+
+            with gr.Group(elem_classes="gr-monochrome-group", visible=True):
+                with gr.Row():
+                    with gr.Accordion('MatAnyone Settings (click to expand)', open=False):
+                        with gr.Row():
+                            erode_kernel_size = gr.Slider(label='Erode Kernel Size',
+                                                          minimum=0,
+                                                          maximum=30,
+                                                          step=1,
+                                                          value=10,
+                                                          info="Erosion on the added mask",
+                                                          interactive=True)
+                            dilate_kernel_size = gr.Slider(label='Dilate Kernel Size',
+                                                           minimum=0,
+                                                           maximum=30,
+                                                           step=1,
+                                                           value=10,
+                                                           info="Dilation on the added mask",
+                                                           interactive=True)
+
+                        with gr.Row():
+                            image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Num of Refinement Iterations", info="More iterations → More details & More time", visible=False)
+                            track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frame", visible=False)
+                        with gr.Row():
+                            point_prompt = gr.Radio(
+                                choices=["Positive", "Negative"],
+                                value="Positive",
+                                label="Point Prompt",
+                                info="Click to add positive or negative point for target mask",
+                                interactive=True,
+                                visible=False,
+                                min_width=100,
+                                scale=1)
+                            mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask Selection", info="Choose 1~all mask(s) added in Step 2", visible=False)
+
+            gr.Markdown("---")
+
+            with gr.Column():
+                # input image
+                with gr.Row(equal_height=True):
+                    with gr.Column(scale=2):
+                        gr.Markdown("## Step1: Upload image")
+                    with gr.Column(scale=2):
+                        step2_title = gr.Markdown("## Step2: Add masks <small>(Several clicks then **`Add Mask`** <u>one by one</u>)</small>", visible=False)
+                with gr.Row(equal_height=True):
+                    with gr.Column(scale=2):
+                        image_input = gr.Image(label="Input Image", elem_classes="image")
+                        extract_frames_button = gr.Button(value="Load Image", interactive=True, elem_classes="new_button")
+                    with gr.Column(scale=2):
+                        image_info = gr.Textbox(label="Image Info", visible=False)
+                        template_frame = gr.Image(type="pil", label="Start Frame", interactive=True, elem_id="template_frame", visible=False, elem_classes="image")
+                        with gr.Row(equal_height=True, elem_classes="mask_button_group"):
+                            clear_button_click = gr.Button(value="Clear Clicks", interactive=True, visible=False, elem_classes="new_button", min_width=100)
+                            add_mask_button = gr.Button(value="Add Mask", interactive=True, visible=False, elem_classes="new_button", min_width=100)
+                            remove_mask_button = gr.Button(value="Remove Mask", interactive=True, visible=False, elem_classes="new_button", min_width=100)
+                            matting_button = gr.Button(value="Image Matting", interactive=True, visible=False, elem_classes="green_button", min_width=100)
+
+                gr.HTML('<hr style="border: none; height: 1.5px; background: linear-gradient(to right, #a566b4, #74a781);margin: 5px 0;">')
+
+                # output image
+                with gr.Row(equal_height=True):
+                    with gr.Column(scale=2):
+                        foreground_image_output = gr.Image(type="pil", label="Foreground Output", visible=False, elem_classes="image")
+                        foreground_output_button = gr.Button(value="Foreground Output", visible=False, elem_classes="new_button")
+                    with gr.Column(scale=2):
+                        alpha_image_output = gr.Image(type="pil", label="Alpha Output", visible=False, elem_classes="image")
+                        alpha_output_button = gr.Button(value="Alpha Mask Output", visible=False, elem_classes="new_button")
+
+            # first step: get the image information
+            extract_frames_button.click(
+                fn=get_frames_from_image,
+                inputs=[
+                    image_input, image_state
+                ],
+                outputs=[image_state, image_info, template_frame,
+                         image_selection_slider, track_pause_number_slider, point_prompt, clear_button_click, add_mask_button, matting_button, template_frame,
+                         foreground_image_output, alpha_image_output, foreground_output_button, alpha_output_button, mask_dropdown, step2_title]
+            )
+
+            # second step: select images from slider
+            image_selection_slider.release(fn=select_image_template,
+                                           inputs=[image_selection_slider, image_state, interactive_state],
+                                           outputs=[template_frame, image_state, interactive_state], api_name="select_image")
+            track_pause_number_slider.release(fn=get_end_number,
+                                              inputs=[track_pause_number_slider, image_state, interactive_state],
+                                              outputs=[template_frame, interactive_state], api_name="end_image")
+
+            # click select image to get mask using sam
+            template_frame.select(
+                fn=sam_refine,
+                inputs=[image_state, point_prompt, click_state, interactive_state],
+                outputs=[template_frame, image_state, interactive_state]
+            )
+
+            # add different mask
+            add_mask_button.click(
+                fn=add_multi_mask,
+                inputs=[image_state, interactive_state, mask_dropdown],
+                outputs=[interactive_state, mask_dropdown, template_frame, click_state]
+            )
+
+            remove_mask_button.click(
+                fn=remove_multi_mask,
+                inputs=[interactive_state, mask_dropdown],
+                outputs=[interactive_state, mask_dropdown]
+            )
+
+            # image matting
+            matting_button.click(
+                fn=image_matting,
+                inputs=[image_state, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size, image_selection_slider],
+                outputs=[foreground_image_output, alpha_image_output]
+            )
+
+            # click to get mask
+            mask_dropdown.change(
+                fn=show_mask,
+                inputs=[image_state, interactive_state, mask_dropdown],
+                outputs=[template_frame]
+            )
+
+            # clear input
+            image_input.change(
+                fn=restart,
+                inputs=[],
+                outputs=[
+                    image_state,
+                    interactive_state,
+                    click_state,
+                    foreground_image_output, alpha_image_output,
+                    template_frame,
+                    image_selection_slider, track_pause_number_slider, point_prompt, clear_button_click,
+                    add_mask_button, matting_button, template_frame, foreground_image_output, alpha_image_output, remove_mask_button, foreground_output_button, alpha_output_button, mask_dropdown, image_info, step2_title
+                ],
+                queue=False,
+                show_progress=False)
+
+            image_input.clear(
+                fn=restart,
+                inputs=[],
+                outputs=[
+                    image_state,
+                    interactive_state,
+                    click_state,
+                    foreground_image_output, alpha_image_output,
+                    template_frame,
+                    image_selection_slider, track_pause_number_slider, point_prompt, clear_button_click,
+                    add_mask_button, matting_button, template_frame, foreground_image_output, alpha_image_output, remove_mask_button, foreground_output_button, alpha_output_button, mask_dropdown, image_info, step2_title
+                ],
+                queue=False,
+                show_progress=False)
+
+            # points clear
+            clear_button_click.click(
+                fn = clear_click,
+                inputs = [image_state, click_state,],
+                outputs = [template_frame,click_state],
+            )
+
+            # set example
+            gr.Markdown("---")
+            gr.Markdown("## Examples")
+            gr.Examples(
+                examples=[os.path.join(os.path.dirname(__file__), "./test_sample/", test_sample) for test_sample in ["test-sample0.png", "test-sample1.png"]],
+                inputs=[image_input],
+            )
+
+    gr.Markdown(article)
+
+demo.queue()
+demo.launch(debug=True)
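Note (illustration, not part of the commit): `get_prompt` in `app.py` accumulates clicks across calls because it mutates `click_state` in place, which is how repeated positive/negative clicks refine a single SAM mask before `Add Mask` is pressed. A minimal standalone sketch of that behavior, using only the standard library:

```python
# Standalone sketch of app.py's click accumulation; mirrors get_prompt above.
import json

def get_prompt(click_state, click_input):
    # click_state is [[points], [labels]]; each call appends the new click in place.
    inputs = json.loads(click_input)
    points, labels = click_state
    for inp in inputs:
        points.append(inp[:2])
        labels.append(inp[2])
    return {
        "prompt_type": ["click"],
        "input_point": points,
        "input_label": labels,
        "multimask_output": "True",
    }

state = [[], []]
get_prompt(state, "[[10,20,1]]")           # positive click at (10, 20)
prompt = get_prompt(state, "[[30,40,0]]")  # negative click at (30, 40)
assert prompt["input_point"] == [[10, 20], [30, 40]]
assert prompt["input_label"] == [1, 0]
```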
hugging_face/matanyone_wrapper.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+import tqdm
+import torch
+from torchvision.transforms.functional import to_tensor
+import numpy as np
+import random
+import cv2
+
+def gen_dilate(alpha, min_kernel_size, max_kernel_size):
+    kernel_size = random.randint(min_kernel_size, max_kernel_size)
+    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
+    fg_and_unknown = np.array(np.not_equal(alpha, 0).astype(np.float32))
+    dilate = cv2.dilate(fg_and_unknown, kernel, iterations=1) * 255
+    return dilate.astype(np.float32)
+
+def gen_erosion(alpha, min_kernel_size, max_kernel_size):
+    kernel_size = random.randint(min_kernel_size, max_kernel_size)
+    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
+    fg = np.array(np.equal(alpha, 255).astype(np.float32))
+    erode = cv2.erode(fg, kernel, iterations=1) * 255
+    return erode.astype(np.float32)
+
+@torch.inference_mode()
+@torch.cuda.amp.autocast()
+def matanyone(processor, frames_np, mask, r_erode=0, r_dilate=0, n_warmup=10):
+    """
+    Args:
+        frames_np: [(H,W,C)]*n, uint8
+        mask: (H,W), uint8
+    Outputs:
+        com: [(H,W,C)]*n, uint8
+        pha: [(H,W,1)]*n, uint8
+    """
+
+    # print(f'===== [r_erode] {r_erode}; [r_dilate] {r_dilate} =====')
+    bgr = (np.array([120, 255, 155], dtype=np.float32) / 255).reshape((1, 1, 3))
+    objects = [1]
+
+    # [optional] erode & dilate the given seg mask
+    if r_dilate > 0:
+        mask = gen_dilate(mask, r_dilate, r_dilate)
+    if r_erode > 0:
+        mask = gen_erosion(mask, r_erode, r_erode)
+
+    mask = torch.from_numpy(mask).cuda()
+
+    # repeat the first frame n_warmup times to warm up the memory
+    frames_np = [frames_np[0]] * n_warmup + frames_np
+
+    frames = []
+    phas = []
+    for ti, frame_single in tqdm.tqdm(enumerate(frames_np)):
+        image = to_tensor(frame_single).cuda().float()
+
+        if ti == 0:
+            output_prob = processor.step(image, mask, objects=objects)      # encode given mask
+            output_prob = processor.step(image, first_frame_pred=True)     # clear past memory for warmup frames
+        else:
+            if ti <= n_warmup:
+                output_prob = processor.step(image, first_frame_pred=True) # clear past memory for warmup frames
+            else:
+                output_prob = processor.step(image)
+
+        # convert output probabilities to an object mask
+        mask = processor.output_prob_to_mask(output_prob)
+
+        pha = mask.unsqueeze(2).cpu().numpy()
+        com_np = frame_single / 255. * pha + bgr * (1 - pha)
+
+        # do NOT save the warmup frames
+        if ti > (n_warmup - 1):
+            frames.append((com_np * 255).astype(np.uint8))
+            phas.append((pha * 255).astype(np.uint8))
+
+    return frames, phas
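A minimal usage sketch for the wrapper above. `build_processor` is a hypothetical stand-in for however the app constructs its InferenceCore, the file paths are placeholders, and the import path simply follows this repo's layout:

    import cv2
    import numpy as np
    from hugging_face.matanyone_wrapper import matanyone

    processor = build_processor()          # hypothetical: returns an InferenceCore-like object
    cap = cv2.VideoCapture("input.mp4")    # placeholder path
    frames = []
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))           # (H,W,3) uint8 RGB
    first_mask = np.load("first_frame_mask.npy").astype(np.uint8)       # (H,W), 0/255, e.g. from SAM

    # per-frame composites on the fixed background color, plus alpha mattes
    coms, phas = matanyone(processor, frames, first_mask, r_erode=10, r_dilate=10)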
hugging_face/tools/__init__.py
ADDED
File without changes
|
hugging_face/tools/base_segmenter.py
ADDED
@@ -0,0 +1,129 @@
+import time
+import torch
+import cv2
+from PIL import Image, ImageDraw, ImageOps
+import numpy as np
+from typing import Union
+from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
+import matplotlib.pyplot as plt
+import PIL
+from .mask_painter import mask_painter
+
+
+class BaseSegmenter:
+    def __init__(self, SAM_checkpoint, model_type, device='cuda:0'):
+        """
+        device: model device
+        SAM_checkpoint: path of SAM checkpoint
+        model_type: vit_b, vit_l, vit_h
+        """
+        print(f"Initializing BaseSegmenter to {device}")
+        assert model_type in ['vit_b', 'vit_l', 'vit_h'], 'model_type must be vit_b, vit_l, or vit_h'
+
+        self.device = device
+        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
+        self.model = sam_model_registry[model_type](checkpoint=SAM_checkpoint)
+        self.model.to(device=self.device)
+        self.predictor = SamPredictor(self.model)
+        self.embedded = False
+
+    @torch.no_grad()
+    def set_image(self, image: np.ndarray):
+        # image: 3-channel RGB, e.g. from PIL.Image.open(image_path)
+        # cache the image embedding to avoid encoding the same image multiple times
+        self.orignal_image = image
+        if self.embedded:
+            print('repeated embedding; call reset_image first.')
+            return
+        self.predictor.set_image(image)
+        self.embedded = True
+        return
+
+    @torch.no_grad()
+    def reset_image(self):
+        # reset image embedding
+        self.predictor.reset_image()
+        self.embedded = False
+
+    def predict(self, prompts, mode, multimask=True):
+        """
+        image: numpy array, h, w, 3
+        prompts: dictionary, 3 keys: 'point_coords', 'point_labels', 'mask_input'
+        prompts['point_coords']: numpy array [N,2]
+        prompts['point_labels']: numpy array [N,]
+        prompts['mask_input']: numpy array [1,256,256]
+        mode: 'point' (points only), 'mask' (mask only), 'both' (consider both)
+        multimask: True (return 3 masks), False (return 1 mask only)
+        when multimask=True, the best mask can be fed back as mask_input=logits[np.argmax(scores), :, :][None, :, :]
+        """
+        assert self.embedded, 'prediction is called before set_image (feature embedding).'
+        assert mode in ['point', 'mask', 'both'], 'mode must be point, mask, or both'
+
+        if mode == 'point':
+            masks, scores, logits = self.predictor.predict(point_coords=prompts['point_coords'],
+                                                           point_labels=prompts['point_labels'],
+                                                           multimask_output=multimask)
+        elif mode == 'mask':
+            masks, scores, logits = self.predictor.predict(mask_input=prompts['mask_input'],
+                                                           multimask_output=multimask)
+        elif mode == 'both':
+            masks, scores, logits = self.predictor.predict(point_coords=prompts['point_coords'],
+                                                           point_labels=prompts['point_labels'],
+                                                           mask_input=prompts['mask_input'],
+                                                           multimask_output=multimask)
+        else:
+            raise NotImplementedError('mode not implemented!')
+        # masks (n, h, w), scores (n,), logits (n, 256, 256)
+        return masks, scores, logits
+
+
+if __name__ == "__main__":
+    # load and show an image
+    image = cv2.imread('/hhd3/gaoshang/truck.jpg')
+    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # numpy array (h, w, 3)
+
+    # initialise BaseSegmenter
+    SAM_checkpoint = '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
+    model_type = 'vit_h'
+    device = "cuda:4"
+    base_segmenter = BaseSegmenter(SAM_checkpoint=SAM_checkpoint, model_type=model_type, device=device)
+
+    # image embedding (once embedded, multiple prompts can be applied)
+    base_segmenter.set_image(image)
+
+    # examples
+    # point only ------------------------
+    mode = 'point'
+    prompts = {
+        'point_coords': np.array([[500, 375], [1125, 625]]),
+        'point_labels': np.array([1, 1]),
+    }
+    masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=False)  # masks (n, h, w), scores (n,), logits (n, 256, 256)
+    painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
+    painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR)  # numpy array (h, w, 3)
+    cv2.imwrite('/hhd3/gaoshang/truck_point.jpg', painted_image)
+
+    # both ------------------------
+    mode = 'both'
+    mask_input = logits[np.argmax(scores), :, :]
+    prompts = {
+        'point_coords': np.array([[500, 375], [1125, 625]]),
+        'point_labels': np.array([1, 0]),
+        'mask_input': mask_input[None, :, :]
+    }
+    masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True)  # masks (n, h, w), scores (n,), logits (n, 256, 256)
+    painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
+    painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR)  # numpy array (h, w, 3)
+    cv2.imwrite('/hhd3/gaoshang/truck_both.jpg', painted_image)
+
+    # mask only ------------------------
+    mode = 'mask'
+    mask_input = logits[np.argmax(scores), :, :]
+    prompts = {'mask_input': mask_input[None, :, :]}
+    masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True)  # masks (n, h, w), scores (n,), logits (n, 256, 256)
+    painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
+    painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR)  # numpy array (h, w, 3)
+    cv2.imwrite('/hhd3/gaoshang/truck_mask.jpg', painted_image)
hugging_face/tools/download_util.py
ADDED
@@ -0,0 +1,109 @@
+import math
+import os
+import requests
+from torch.hub import download_url_to_file, get_dir
+from tqdm import tqdm
+from urllib.parse import urlparse
+
+def sizeof_fmt(size, suffix='B'):
+    """Get human readable file size.
+
+    Args:
+        size (int): File size.
+        suffix (str): Suffix. Default: 'B'.
+
+    Return:
+        str: Formatted file size.
+    """
+    for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
+        if abs(size) < 1024.0:
+            return f'{size:3.1f} {unit}{suffix}'
+        size /= 1024.0
+    return f'{size:3.1f} Y{suffix}'
+
+
+def download_file_from_google_drive(file_id, save_path):
+    """Download files from google drive.
+    Ref:
+    https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive  # noqa E501
+    Args:
+        file_id (str): File id.
+        save_path (str): Save path.
+    """
+    session = requests.Session()
+    URL = 'https://docs.google.com/uc?export=download'
+    params = {'id': file_id}
+
+    response = session.get(URL, params=params, stream=True)
+    token = get_confirm_token(response)
+    if token:
+        params['confirm'] = token
+        response = session.get(URL, params=params, stream=True)
+
+    # get file size
+    response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
+    print(response_file_size)
+    if 'Content-Range' in response_file_size.headers:
+        file_size = int(response_file_size.headers['Content-Range'].split('/')[1])
+    else:
+        file_size = None
+
+    save_response_content(response, save_path, file_size)
+
+
+def get_confirm_token(response):
+    for key, value in response.cookies.items():
+        if key.startswith('download_warning'):
+            return value
+    return None
+
+
+def save_response_content(response, destination, file_size=None, chunk_size=32768):
+    if file_size is not None:
+        pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
+        readable_file_size = sizeof_fmt(file_size)
+    else:
+        pbar = None
+
+    with open(destination, 'wb') as f:
+        downloaded_size = 0
+        for chunk in response.iter_content(chunk_size):
+            downloaded_size += chunk_size
+            if pbar is not None:
+                pbar.update(1)
+                pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}')
+            if chunk:  # filter out keep-alive new chunks
+                f.write(chunk)
+        if pbar is not None:
+            pbar.close()
+
+
+def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
+    """Load file from http url; will download models if necessary.
+    Ref: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
+    Args:
+        url (str): URL to be downloaded.
+        model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
+            Default: None.
+        progress (bool): Whether to show the download progress. Default: True.
+        file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
+    Returns:
+        str: The path to the downloaded file.
+    """
+    if model_dir is None:  # use the pytorch hub_dir
+        hub_dir = get_dir()
+        model_dir = os.path.join(hub_dir, 'checkpoints')
+
+    os.makedirs(model_dir, exist_ok=True)
+
+    parts = urlparse(url)
+    filename = os.path.basename(parts.path)
+    if file_name is not None:
+        filename = file_name
+    cached_file = os.path.abspath(os.path.join(model_dir, filename))
+    if not os.path.exists(cached_file):
+        print(f'Downloading: "{url}" to {cached_file}\n')
+        download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
+    return cached_file
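A quick usage sketch for `load_file_from_url`; the URL below is a placeholder, not a link shipped with this commit:

    from hugging_face.tools.download_util import load_file_from_url

    ckpt_path = load_file_from_url(
        url='https://example.com/checkpoints/sam_vit_h_4b8939.pth',  # placeholder URL
        model_dir='pretrained_models',
        progress=True)
    print(ckpt_path)  # absolute path to the cached file; download is skipped if it already exists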
hugging_face/tools/interact_tools.py
ADDED
@@ -0,0 +1,99 @@
+import time
+import torch
+import cv2
+from PIL import Image, ImageDraw, ImageOps
+import numpy as np
+from typing import Union
+from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
+import matplotlib.pyplot as plt
+import PIL
+from .mask_painter import mask_painter as mask_painter2
+from .base_segmenter import BaseSegmenter
+from .painter import mask_painter, point_painter
+import os
+import requests
+import sys
+
+
+mask_color = 3
+mask_alpha = 0.7
+contour_color = 1
+contour_width = 5
+point_color_ne = 8
+point_color_ps = 50
+point_alpha = 0.9
+point_radius = 15
+# note: the redefinitions below override the values assigned above
+contour_color = 2
+contour_width = 5
+
+
+class SamControler():
+    def __init__(self, SAM_checkpoint, model_type, device):
+        '''
+        initialize sam controler
+        '''
+        self.sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)
+
+    # def seg_again(self, image: np.ndarray):
+    #     '''
+    #     used when interacting in video
+    #     '''
+    #     self.sam_controler.reset_image()
+    #     self.sam_controler.set_image(image)
+    #     return
+
+    def first_frame_click(self, image: np.ndarray, points: np.ndarray, labels: np.ndarray, multimask=True, mask_color=3):
+        '''
+        used on the first frame of a video
+        return: mask, logit, painted image (mask + points)
+        '''
+        # self.sam_controler.set_image(image)
+        origal_image = self.sam_controler.orignal_image
+        neg_flag = labels[-1]
+        if neg_flag == 1:
+            # last click is positive: refine with a second pass that feeds the best logits back in
+            prompts = {
+                'point_coords': points,
+                'point_labels': labels,
+            }
+            masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
+            mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
+            prompts = {
+                'point_coords': points,
+                'point_labels': labels,
+                'mask_input': logit[None, :, :]
+            }
+            masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask)
+            mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
+        else:
+            # last click is negative: a single point-prompt pass
+            prompts = {
+                'point_coords': points,
+                'point_labels': labels,
+            }
+            masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
+            mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
+
+        assert len(points) == len(labels)
+
+        painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
+        painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels > 0)], axis=1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
+        painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels < 1)], axis=1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
+        painted_image = Image.fromarray(painted_image)
+
+        return mask, logit, painted_image
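A minimal click-interaction sketch for the controller above; the checkpoint path, frame, and click coordinates are placeholders:

    import numpy as np
    from hugging_face.tools.interact_tools import SamControler

    controler = SamControler('pretrained_models/sam_vit_h_4b8939.pth', 'vit_h', 'cuda:0')  # placeholder path
    frame = np.zeros((480, 854, 3), dtype=np.uint8)   # stand-in for the first video frame (RGB)
    controler.sam_controler.set_image(frame)          # embed once; prompt many times
    points = np.array([[400, 240]])                   # one (x, y) click
    labels = np.array([1])                            # 1 = positive click
    mask, logit, painted = controler.first_frame_click(frame, points, labels)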
hugging_face/tools/mask_painter.py
ADDED
@@ -0,0 +1,288 @@
+import cv2
+import torch
+import numpy as np
+from PIL import Image
+import copy
+import time
+
+
+def colormap(rgb=True):
+    color_list = np.array(
+        [
+            0.000, 0.000, 0.000,
+            1.000, 1.000, 1.000,
+            1.000, 0.498, 0.313,
+            0.392, 0.581, 0.929,
+            0.000, 0.447, 0.741,
+            0.850, 0.325, 0.098,
+            0.929, 0.694, 0.125,
+            0.494, 0.184, 0.556,
+            0.466, 0.674, 0.188,
+            0.301, 0.745, 0.933,
+            0.635, 0.078, 0.184,
+            0.300, 0.300, 0.300,
+            0.600, 0.600, 0.600,
+            1.000, 0.000, 0.000,
+            1.000, 0.500, 0.000,
+            0.749, 0.749, 0.000,
+            0.000, 1.000, 0.000,
+            0.000, 0.000, 1.000,
+            0.667, 0.000, 1.000,
+            0.333, 0.333, 0.000,
+            0.333, 0.667, 0.000,
+            0.333, 1.000, 0.000,
+            0.667, 0.333, 0.000,
+            0.667, 0.667, 0.000,
+            0.667, 1.000, 0.000,
+            1.000, 0.333, 0.000,
+            1.000, 0.667, 0.000,
+            1.000, 1.000, 0.000,
+            0.000, 0.333, 0.500,
+            0.000, 0.667, 0.500,
+            0.000, 1.000, 0.500,
+            0.333, 0.000, 0.500,
+            0.333, 0.333, 0.500,
+            0.333, 0.667, 0.500,
+            0.333, 1.000, 0.500,
+            0.667, 0.000, 0.500,
+            0.667, 0.333, 0.500,
+            0.667, 0.667, 0.500,
+            0.667, 1.000, 0.500,
+            1.000, 0.000, 0.500,
+            1.000, 0.333, 0.500,
+            1.000, 0.667, 0.500,
+            1.000, 1.000, 0.500,
+            0.000, 0.333, 1.000,
+            0.000, 0.667, 1.000,
+            0.000, 1.000, 1.000,
+            0.333, 0.000, 1.000,
+            0.333, 0.333, 1.000,
+            0.333, 0.667, 1.000,
+            0.333, 1.000, 1.000,
+            0.667, 0.000, 1.000,
+            0.667, 0.333, 1.000,
+            0.667, 0.667, 1.000,
+            0.667, 1.000, 1.000,
+            1.000, 0.000, 1.000,
+            1.000, 0.333, 1.000,
+            1.000, 0.667, 1.000,
+            0.167, 0.000, 0.000,
+            0.333, 0.000, 0.000,
+            0.500, 0.000, 0.000,
+            0.667, 0.000, 0.000,
+            0.833, 0.000, 0.000,
+            1.000, 0.000, 0.000,
+            0.000, 0.167, 0.000,
+            0.000, 0.333, 0.000,
+            0.000, 0.500, 0.000,
+            0.000, 0.667, 0.000,
+            0.000, 0.833, 0.000,
+            0.000, 1.000, 0.000,
+            0.000, 0.000, 0.167,
+            0.000, 0.000, 0.333,
+            0.000, 0.000, 0.500,
+            0.000, 0.000, 0.667,
+            0.000, 0.000, 0.833,
+            0.000, 0.000, 1.000,
+            0.143, 0.143, 0.143,
+            0.286, 0.286, 0.286,
+            0.429, 0.429, 0.429,
+            0.571, 0.571, 0.571,
+            0.714, 0.714, 0.714,
+            0.857, 0.857, 0.857
+        ]
+    ).astype(np.float32)
+    color_list = color_list.reshape((-1, 3)) * 255
+    if not rgb:
+        color_list = color_list[:, ::-1]
+    return color_list
+
+
+color_list = colormap()
+color_list = color_list.astype('uint8').tolist()
+
+
+def vis_add_mask(image, background_mask, contour_mask, background_color, contour_color, background_alpha, contour_alpha):
+    background_color = np.array(background_color)
+    contour_color = np.array(contour_color)
+
+    # background_mask = 1 - background_mask
+    # contour_mask = 1 - contour_mask
+
+    # blend per channel: pixels where the mask is 0 move toward the given color
+    for i in range(3):
+        image[:, :, i] = image[:, :, i] * (1-background_alpha+background_mask*background_alpha) \
+            + background_color[i] * (background_alpha-background_mask*background_alpha)
+        image[:, :, i] = image[:, :, i] * (1-contour_alpha+contour_mask*contour_alpha) \
+            + contour_color[i] * (contour_alpha-contour_mask*contour_alpha)
+
+    return image.astype('uint8')
+
+
+def mask_generator_00(mask, background_radius, contour_radius):
+    # '00': no background softening, hard contour
+    # distance map
+    dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
+    dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
+    dist_map = dist_transform_fore - dist_transform_back
+    # ...:::!!!:::...
+    contour_radius += 2
+    contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
+    contour_mask = contour_mask / np.max(contour_mask)
+    contour_mask[contour_mask > 0.5] = 1.
+    return mask, contour_mask
+
+
+def mask_generator_01(mask, background_radius, contour_radius):
+    # '01': no background softening, soft (un-hardened) contour
+    # distance map
+    dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
+    dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
+    dist_map = dist_transform_fore - dist_transform_back
+    # ...:::!!!:::...
+    contour_radius += 2
+    contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
+    contour_mask = contour_mask / np.max(contour_mask)
+    return mask, contour_mask
+
+
+def mask_generator_10(mask, background_radius, contour_radius):
+    # '10': soft background transition, hard contour
+    # distance map
+    dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
+    dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
+    dist_map = dist_transform_fore - dist_transform_back
+    # .....:::::!!!!!
+    background_mask = np.clip(dist_map, -background_radius, background_radius)
+    background_mask = (background_mask - np.min(background_mask))
+    background_mask = background_mask / np.max(background_mask)
+    # ...:::!!!:::...
+    contour_radius += 2
+    contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
+    contour_mask = contour_mask / np.max(contour_mask)
+    contour_mask[contour_mask > 0.5] = 1.
+    return background_mask, contour_mask
+
+
+def mask_generator_11(mask, background_radius, contour_radius):
+    # '11': soft background transition, soft contour
+    # distance map
+    dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
+    dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
+    dist_map = dist_transform_fore - dist_transform_back
+    # .....:::::!!!!!
+    background_mask = np.clip(dist_map, -background_radius, background_radius)
+    background_mask = (background_mask - np.min(background_mask))
+    background_mask = background_mask / np.max(background_mask)
+    # ...:::!!!:::...
+    contour_radius += 2
+    contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
+    contour_mask = contour_mask / np.max(contour_mask)
+    return background_mask, contour_mask
+
+
+def mask_painter(input_image, input_mask, background_alpha=0.5, background_blur_radius=7, contour_width=3, contour_color=3, contour_alpha=1, mode='11'):
+    """
+    Input:
+        input_image: numpy array
+        input_mask: numpy array
+        background_alpha: transparency of background, [0, 1], 1: all black, 0: do nothing
+        background_blur_radius: radius of background blur, must be odd number
+        contour_width: width of mask contour, must be odd number
+        contour_color: color index (in color map) of mask contour, 0: black, 1: white, >1: others
+        contour_alpha: transparency of mask contour, [0, 1], if 0: no contour highlighted
+        mode: painting mode, '00': no blur, '01': only blur contour, '10': only blur background, '11': blur both
+
+    Output:
+        painted_image: numpy array
+    """
+    assert input_image.shape[:2] == input_mask.shape, 'different shape'
+    assert background_blur_radius % 2 * contour_width % 2 > 0, 'background_blur_radius and contour_width must be ODD'
+    assert mode in ['00', '01', '10', '11'], 'mode should be 00, 01, 10, or 11'
+
+    # downsample input image and mask (shape[0] is rows, shape[1] is cols)
+    width, height = input_image.shape[0], input_image.shape[1]
+    res = 1024
+    ratio = min(1.0 * res / max(width, height), 1.0)
+    input_image = cv2.resize(input_image, (int(height*ratio), int(width*ratio)))
+    input_mask = cv2.resize(input_mask, (int(height*ratio), int(width*ratio)))
+
+    # 0: background, 1: foreground
+    msk = np.clip(input_mask, 0, 1)
+
+    # generate masks for background and contour pixels
+    background_radius = (background_blur_radius - 1) // 2
+    contour_radius = (contour_width - 1) // 2
+    generator_dict = {'00': mask_generator_00, '01': mask_generator_01, '10': mask_generator_10, '11': mask_generator_11}
+    background_mask, contour_mask = generator_dict[mode](msk, background_radius, contour_radius)
+
+    # paint (black for background)
+    painted_image = vis_add_mask(
+        input_image, background_mask, contour_mask, color_list[0], color_list[contour_color], background_alpha, contour_alpha)
+
+    return painted_image
+
+
+if __name__ == '__main__':
+
+    background_alpha = 0.7          # transparency of background; 1: all black, 0: do nothing
+    background_blur_radius = 31     # radius of background blur, must be odd number
+    contour_width = 11              # contour width, must be odd number
+    contour_color = 3               # id in color map, 0: black, 1: white, >1: others
+    contour_alpha = 1               # transparency of contour; 0: no contour highlighted
+
+    # load input image and mask
+    input_image = np.array(Image.open('./test_img/painter_input_image.jpg').convert('RGB'))
+    input_mask = np.array(Image.open('./test_img/painter_input_mask.jpg').convert('P'))
+
+    # paint
+    overall_time_1 = 0
+    overall_time_2 = 0
+    overall_time_3 = 0
+    overall_time_4 = 0
+    overall_time_5 = 0
+
+    for i in range(50):
+        t2 = time.time()
+        painted_image_00 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='00')
+        e2 = time.time()
+
+        t3 = time.time()
+        painted_image_10 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='10')
+        e3 = time.time()
+
+        t1 = time.time()
+        painted_image = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha)
+        e1 = time.time()
+
+        t4 = time.time()
+        painted_image_01 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='01')
+        e4 = time.time()
+
+        t5 = time.time()
+        painted_image_11 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='11')
+        e5 = time.time()
+
+        overall_time_1 += (e1 - t1)
+        overall_time_2 += (e2 - t2)
+        overall_time_3 += (e3 - t3)
+        overall_time_4 += (e4 - t4)
+        overall_time_5 += (e5 - t5)
+
+    print(f'average time w gaussian: {overall_time_1/50}')
+    print(f'average time w/o gaussian00: {overall_time_2/50}')
+    print(f'average time w/o gaussian10: {overall_time_3/50}')
+    print(f'average time w/o gaussian01: {overall_time_4/50}')
+    print(f'average time w/o gaussian11: {overall_time_5/50}')
+
+    # save
+    painted_image_00 = Image.fromarray(painted_image_00)
+    painted_image_00.save('./test_img/painter_output_image_00.png')
+
+    painted_image_10 = Image.fromarray(painted_image_10)
+    painted_image_10.save('./test_img/painter_output_image_10.png')
+
+    painted_image_01 = Image.fromarray(painted_image_01)
+    painted_image_01.save('./test_img/painter_output_image_01.png')
+
+    painted_image_11 = Image.fromarray(painted_image_11)
+    painted_image_11.save('./test_img/painter_output_image_11.png')
hugging_face/tools/misc.py
ADDED
@@ -0,0 +1,131 @@
+import os
+import re
+import random
+import time
+import torch
+import torch.nn as nn
+import logging
+import numpy as np
+from os import path as osp
+
+def constant_init(module, val, bias=0):
+    if hasattr(module, 'weight') and module.weight is not None:
+        nn.init.constant_(module.weight, val)
+    if hasattr(module, 'bias') and module.bias is not None:
+        nn.init.constant_(module.bias, bias)
+
+initialized_logger = {}
+def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
+    """Get the root logger.
+    The logger will be initialized if it has not been initialized. By default a
+    StreamHandler will be added. If `log_file` is specified, a FileHandler will
+    also be added.
+    Args:
+        logger_name (str): root logger name. Default: 'basicsr'.
+        log_file (str | None): The log filename. If specified, a FileHandler
+            will be added to the root logger.
+        log_level (int): The root logger level. Note that only the process of
+            rank 0 is affected, while other processes will set the level to
+            "Error" and be silent most of the time.
+    Returns:
+        logging.Logger: The root logger.
+    """
+    logger = logging.getLogger(logger_name)
+    # if the logger has been initialized, just return it
+    if logger_name in initialized_logger:
+        return logger
+
+    format_str = '%(asctime)s %(levelname)s: %(message)s'
+    stream_handler = logging.StreamHandler()
+    stream_handler.setFormatter(logging.Formatter(format_str))
+    logger.addHandler(stream_handler)
+    logger.propagate = False
+
+    if log_file is not None:
+        logger.setLevel(log_level)
+        # add file handler
+        # file_handler = logging.FileHandler(log_file, 'w')
+        file_handler = logging.FileHandler(log_file, 'a')  # Shangchen: keep the previous log
+        file_handler.setFormatter(logging.Formatter(format_str))
+        file_handler.setLevel(log_level)
+        logger.addHandler(file_handler)
+    initialized_logger[logger_name] = True
+    return logger
+
+
+IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",
+                                                   torch.__version__)[0][:3])] >= [1, 12, 0]
+
+def gpu_is_available():
+    if IS_HIGH_VERSION:
+        if torch.backends.mps.is_available():
+            return True
+    return True if torch.cuda.is_available() and torch.backends.cudnn.is_available() else False
+
+def get_device(gpu_id=None):
+    if gpu_id is None:
+        gpu_str = ''
+    elif isinstance(gpu_id, int):
+        gpu_str = f':{gpu_id}'
+    else:
+        raise TypeError('Input should be int value.')
+
+    if IS_HIGH_VERSION:
+        if torch.backends.mps.is_available():
+            return torch.device('mps'+gpu_str)
+    return torch.device('cuda'+gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else 'cpu')
+
+
+def set_random_seed(seed):
+    """Set random seeds."""
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+
+def get_time_str():
+    return time.strftime('%Y%m%d_%H%M%S', time.localtime())
+
+
+def scandir(dir_path, suffix=None, recursive=False, full_path=False):
+    """Scan a directory to find the interested files.
+
+    Args:
+        dir_path (str): Path of the directory.
+        suffix (str | tuple(str), optional): File suffix that we are
+            interested in. Default: None.
+        recursive (bool, optional): If set to True, recursively scan the
+            directory. Default: False.
+        full_path (bool, optional): If set to True, include the dir_path.
+            Default: False.
+
+    Returns:
+        A generator for all the interested files with relative paths.
+    """
+    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
+        raise TypeError('"suffix" must be a string or tuple of strings')
+
+    root = dir_path
+
+    def _scandir(dir_path, suffix, recursive):
+        for entry in os.scandir(dir_path):
+            if not entry.name.startswith('.') and entry.is_file():
+                if full_path:
+                    return_path = entry.path
+                else:
+                    return_path = osp.relpath(entry.path, root)
+
+                if suffix is None:
+                    yield return_path
+                elif return_path.endswith(suffix):
+                    yield return_path
+            else:
+                if recursive:
+                    yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
+                else:
+                    continue
+
+    return _scandir(dir_path, suffix=suffix, recursive=recursive)
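Typical use of the helpers above, as a short sketch:

    from hugging_face.tools.misc import get_device, get_root_logger, set_random_seed

    set_random_seed(2025)                    # seeds python, numpy, and torch RNGs
    device = get_device()                    # cuda/mps if available, else cpu
    logger = get_root_logger(log_file=None)  # console-only logger
    logger.info(f'running on {device}')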
hugging_face/tools/painter.py
ADDED
@@ -0,0 +1,215 @@
+# paint masks, contours, or points on images, with specified colors
+import cv2
+import torch
+import numpy as np
+from PIL import Image
+import copy
+import time
+
+
+def colormap(rgb=True):
+    color_list = np.array(
+        [
+            0.000, 0.000, 0.000,
+            1.000, 1.000, 1.000,
+            1.000, 0.498, 0.313,
+            0.392, 0.581, 0.929,
+            0.000, 0.447, 0.741,
+            0.850, 0.325, 0.098,
+            0.929, 0.694, 0.125,
+            0.494, 0.184, 0.556,
+            0.466, 0.674, 0.188,
+            0.301, 0.745, 0.933,
+            0.635, 0.078, 0.184,
+            0.300, 0.300, 0.300,
+            0.600, 0.600, 0.600,
+            1.000, 0.000, 0.000,
+            1.000, 0.500, 0.000,
+            0.749, 0.749, 0.000,
+            0.000, 1.000, 0.000,
+            0.000, 0.000, 1.000,
+            0.667, 0.000, 1.000,
+            0.333, 0.333, 0.000,
+            0.333, 0.667, 0.000,
+            0.333, 1.000, 0.000,
+            0.667, 0.333, 0.000,
+            0.667, 0.667, 0.000,
+            0.667, 1.000, 0.000,
+            1.000, 0.333, 0.000,
+            1.000, 0.667, 0.000,
+            1.000, 1.000, 0.000,
+            0.000, 0.333, 0.500,
+            0.000, 0.667, 0.500,
+            0.000, 1.000, 0.500,
+            0.333, 0.000, 0.500,
+            0.333, 0.333, 0.500,
+            0.333, 0.667, 0.500,
+            0.333, 1.000, 0.500,
+            0.667, 0.000, 0.500,
+            0.667, 0.333, 0.500,
+            0.667, 0.667, 0.500,
+            0.667, 1.000, 0.500,
+            1.000, 0.000, 0.500,
+            1.000, 0.333, 0.500,
+            1.000, 0.667, 0.500,
+            1.000, 1.000, 0.500,
+            0.000, 0.333, 1.000,
+            0.000, 0.667, 1.000,
+            0.000, 1.000, 1.000,
+            0.333, 0.000, 1.000,
+            0.333, 0.333, 1.000,
+            0.333, 0.667, 1.000,
+            0.333, 1.000, 1.000,
+            0.667, 0.000, 1.000,
+            0.667, 0.333, 1.000,
+            0.667, 0.667, 1.000,
+            0.667, 1.000, 1.000,
+            1.000, 0.000, 1.000,
+            1.000, 0.333, 1.000,
+            1.000, 0.667, 1.000,
+            0.167, 0.000, 0.000,
+            0.333, 0.000, 0.000,
+            0.500, 0.000, 0.000,
+            0.667, 0.000, 0.000,
+            0.833, 0.000, 0.000,
+            1.000, 0.000, 0.000,
+            0.000, 0.167, 0.000,
+            0.000, 0.333, 0.000,
+            0.000, 0.500, 0.000,
+            0.000, 0.667, 0.000,
+            0.000, 0.833, 0.000,
+            0.000, 1.000, 0.000,
+            0.000, 0.000, 0.167,
+            0.000, 0.000, 0.333,
+            0.000, 0.000, 0.500,
+            0.000, 0.000, 0.667,
+            0.000, 0.000, 0.833,
+            0.000, 0.000, 1.000,
+            0.143, 0.143, 0.143,
+            0.286, 0.286, 0.286,
+            0.429, 0.429, 0.429,
+            0.571, 0.571, 0.571,
+            0.714, 0.714, 0.714,
+            0.857, 0.857, 0.857
+        ]
+    ).astype(np.float32)
+    color_list = color_list.reshape((-1, 3)) * 255
+    if not rgb:
+        color_list = color_list[:, ::-1]
+    return color_list
+
+
+color_list = colormap()
+color_list = color_list.astype('uint8').tolist()
+
+
+def vis_add_mask(image, mask, color, alpha):
+    color = np.array(color_list[color])
+    mask = mask > 0.5
+    image[mask] = image[mask] * (1-alpha) + color * alpha
+    return image.astype('uint8')
+
+def point_painter(input_image, input_points, point_color=5, point_alpha=0.9, point_radius=15, contour_color=2, contour_width=5):
+    h, w = input_image.shape[:2]
+    point_mask = np.zeros((h, w)).astype('uint8')
+    for point in input_points:
+        point_mask[point[1], point[0]] = 1
+
+    kernel = cv2.getStructuringElement(2, (point_radius, point_radius))  # 2 == cv2.MORPH_ELLIPSE
+    point_mask = cv2.dilate(point_mask, kernel)
+
+    contour_radius = (contour_width - 1) // 2
+    dist_transform_fore = cv2.distanceTransform(point_mask, cv2.DIST_L2, 3)
+    dist_transform_back = cv2.distanceTransform(1-point_mask, cv2.DIST_L2, 3)
+    dist_map = dist_transform_fore - dist_transform_back
+    # ...:::!!!:::...
+    contour_radius += 2
+    contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
+    contour_mask = contour_mask / np.max(contour_mask)
+    contour_mask[contour_mask > 0.5] = 1.
+
+    # paint mask
+    painted_image = vis_add_mask(input_image.copy(), point_mask, point_color, point_alpha)
+    # paint contour
+    painted_image = vis_add_mask(painted_image.copy(), 1-contour_mask, contour_color, 1)
+    return painted_image
+
+def mask_painter(input_image, input_mask, mask_color=5, mask_alpha=0.7, contour_color=1, contour_width=3):
+    assert input_image.shape[:2] == input_mask.shape, 'different shape between image and mask'
+    # 0: background, 1: foreground
+    mask = np.clip(input_mask, 0, 1)
+    contour_radius = (contour_width - 1) // 2
+
+    dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
+    dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
+    dist_map = dist_transform_fore - dist_transform_back
+    # ...:::!!!:::...
+    contour_radius += 2
+    contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
+    contour_mask = contour_mask / np.max(contour_mask)
+    contour_mask[contour_mask > 0.5] = 1.
+
+    # paint mask
+    painted_image = vis_add_mask(input_image.copy(), mask.copy(), mask_color, mask_alpha)
+    # paint contour
+    painted_image = vis_add_mask(painted_image.copy(), 1-contour_mask, contour_color, 1)
+
+    return painted_image
+
+def background_remover(input_image, input_mask):
+    """
+    input_image: H, W, 3, np.array
+    input_mask: H, W, np.array
+
+    image_wo_background: PIL.Image
+    """
+    assert input_image.shape[:2] == input_mask.shape, 'different shape between image and mask'
+    # 0: background, 1: foreground
+    mask = np.expand_dims(np.clip(input_mask, 0, 1), axis=2)*255
+    image_wo_background = np.concatenate([input_image, mask], axis=2)  # H, W, 4
+    image_wo_background = Image.fromarray(image_wo_background).convert('RGBA')
+
+    return image_wo_background
+
+if __name__ == '__main__':
+    input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB'))
+    input_mask = np.array(Image.open('images/painter_input_mask.jpg').convert('P'))
+
+    # example of mask painter
+    mask_color = 3
+    mask_alpha = 0.7
+    contour_color = 1
+    contour_width = 5
+
+    # save the untouched input for reference
+    painted_image = Image.fromarray(input_image)
+    painted_image.save('images/original.png')
+
+    painted_image = mask_painter(input_image, input_mask, mask_color, mask_alpha, contour_color, contour_width)
+    # save the painted result
+    painted_image = Image.fromarray(painted_image)
+    painted_image.save('images/original1.png')
+
+    # example of point painter
+    input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB'))
+    input_points = np.array([[500, 375], [70, 600]])  # x, y
+    point_color = 5
+    point_alpha = 0.9
+    point_radius = 15
+    contour_color = 2
+    contour_width = 5
+    painted_image_1 = point_painter(input_image, input_points, point_color, point_alpha, point_radius, contour_color, contour_width)
+    # save
+    painted_image = Image.fromarray(painted_image_1)
+    painted_image.save('images/point_painter_1.png')
+
+    input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB'))
+    painted_image_2 = point_painter(input_image, input_points, point_color=9, point_radius=20, contour_color=29)
+    # save
+    painted_image = Image.fromarray(painted_image_2)
+    painted_image.save('images/point_painter_2.png')
+
+    # example of background remover
+    input_image = np.array(Image.open('images/original.png').convert('RGB'))
+    image_wo_background = background_remover(input_image, input_mask)  # returns PIL.Image
+    image_wo_background.save('images/image_wo_background.png')
matanyone/config/__init__.py
ADDED
File without changes
|
matanyone/config/eval_matanyone_config.yaml
ADDED
@@ -0,0 +1,47 @@
+defaults:
+  - _self_
+  - model: base
+  - override hydra/job_logging: custom-no-rank.yaml
+
+hydra:
+  run:
+    dir: ../output/${exp_id}/${dataset}
+  output_subdir: ${now:%Y-%m-%d_%H-%M-%S}-hydra
+
+amp: False
+weights: pretrained_models/matanyone.pth # default (can be modified from outside)
+output_dir: null # defaults to run_dir; specify this to override
+flip_aug: False
+
+
+# maximum shortest side of the input; -1 means no resizing
+# With eval_vos.py, we usually just use the dataset's size (resizing done in dataloader)
+# this parameter is added solely for the GUI in the current codebase
+# InferenceCore will downsize the input and restore the output to the original size if needed
+# if you are using this code for some other project, you can also utilize this parameter
+max_internal_size: -1
+
+# these parameters, when set, override the dataset's default; useful for debugging
+save_all: True
+use_all_masks: False
+use_long_term: False
+mem_every: 5
+
+# only relevant when long_term is not enabled
+max_mem_frames: 5
+
+# only relevant when long_term is enabled
+long_term:
+  count_usage: True
+  max_mem_frames: 10
+  min_mem_frames: 5
+  num_prototypes: 128
+  max_num_tokens: 10000
+  buffer_tokens: 2000
+
+top_k: 30
+stagger_updates: 5
+chunk_size: -1 # number of objects to process in parallel; -1 means unlimited
+save_scores: False
+save_aux: False
+visualize: False
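A sketch of composing this config programmatically (assumes Hydra >= 1.2 for `version_base`; the override values and the relative `config_path` are illustrative):

    from hydra import initialize, compose

    with initialize(version_base=None, config_path="matanyone/config"):
        cfg = compose(config_name="eval_matanyone_config",
                      overrides=["weights=pretrained_models/matanyone.pth",
                                 "max_internal_size=512"])
    print(cfg.mem_every, cfg.long_term.max_mem_frames)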
matanyone/config/hydra/job_logging/custom-no-rank.yaml
ADDED
@@ -0,0 +1,22 @@
+# python logging configuration for tasks
+version: 1
+formatters:
+  simple:
+    format: '[%(asctime)s][%(levelname)s] - %(message)s'
+    datefmt: '%Y-%m-%d %H:%M:%S'
+handlers:
+  console:
+    class: logging.StreamHandler
+    formatter: simple
+    stream: ext://sys.stdout
+  file:
+    class: logging.FileHandler
+    formatter: simple
+    # absolute file path
+    filename: ${hydra.runtime.output_dir}/${now:%Y-%m-%d_%H-%M-%S}-eval.log
+    mode: w
+root:
+  level: INFO
+  handlers: [console, file]
+
+disable_existing_loggers: false
matanyone/config/hydra/job_logging/custom.yaml
ADDED
@@ -0,0 +1,22 @@
+# python logging configuration for tasks
+version: 1
+formatters:
+  simple:
+    format: '[%(asctime)s][%(levelname)s][r${oc.env:LOCAL_RANK}] - %(message)s'
+    datefmt: '%Y-%m-%d %H:%M:%S'
+handlers:
+  console:
+    class: logging.StreamHandler
+    formatter: simple
+    stream: ext://sys.stdout
+  file:
+    class: logging.FileHandler
+    formatter: simple
+    # absolute file path
+    filename: ${hydra.runtime.output_dir}/${now:%Y-%m-%d_%H-%M-%S}-rank${oc.env:LOCAL_RANK}.log
+    mode: w
+root:
+  level: INFO
+  handlers: [console, file]
+
+disable_existing_loggers: false
matanyone/config/model/base.yaml
ADDED
@@ -0,0 +1,58 @@
+pixel_mean: [0.485, 0.456, 0.406]
+pixel_std: [0.229, 0.224, 0.225]
+
+pixel_dim: 256
+key_dim: 64
+value_dim: 256
+sensory_dim: 256
+embed_dim: 256
+
+pixel_encoder:
+  type: resnet50
+  ms_dims: [1024, 512, 256, 64, 3] # f16, f8, f4, f2, f1
+
+mask_encoder:
+  type: resnet18
+  final_dim: 256
+
+pixel_pe_scale: 32
+pixel_pe_temperature: 128
+
+object_transformer:
+  embed_dim: ${model.embed_dim}
+  ff_dim: 2048
+  num_heads: 8
+  num_blocks: 3
+  num_queries: 16
+  read_from_pixel:
+    input_norm: False
+    input_add_pe: False
+    add_pe_to_qkv: [True, True, False]
+  read_from_past:
+    add_pe_to_qkv: [True, True, False]
+  read_from_memory:
+    add_pe_to_qkv: [True, True, False]
+  read_from_query:
+    add_pe_to_qkv: [True, True, False]
+    output_norm: False
+  query_self_attention:
+    add_pe_to_qkv: [True, True, False]
+  pixel_self_attention:
+    add_pe_to_qkv: [True, True, False]
+
+object_summarizer:
+  embed_dim: ${model.object_transformer.embed_dim}
+  num_summaries: ${model.object_transformer.num_queries}
+  add_pe: True
+
+aux_loss:
+  sensory:
+    enabled: True
+    weight: 0.01
+  query:
+    enabled: True
+    weight: 0.01
+
+mask_decoder:
+  # first value must equal embed_dim
+  up_dims: [256, 128, 128, 64, 16]
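The `${model.*}` interpolations above resolve once this file is mounted under a `model` key, which is what the `- model: base` entry in the eval config's defaults list does. A small OmegaConf sketch of the same mechanism:

    from omegaconf import OmegaConf

    model_cfg = OmegaConf.load("matanyone/config/model/base.yaml")
    cfg = OmegaConf.create({"model": model_cfg})
    print(cfg.model.object_transformer.embed_dim)     # 256, via ${model.embed_dim}
    print(cfg.model.object_summarizer.num_summaries)  # 16, via ${model.object_transformer.num_queries}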
matanyone/inference/__init__.py
ADDED
File without changes
|
matanyone/inference/image_feature_store.py
ADDED
@@ -0,0 +1,56 @@
+import warnings
+from typing import Iterable
+import torch
+from matanyone.model.matanyone import MatAnyone
+
+
+class ImageFeatureStore:
+    """
+    A cache for image features.
+    These features might be reused at different parts of the inference pipeline.
+    This class provides an interface for reusing these features.
+    It is the user's responsibility to delete redundant features.
+
+    Features of a frame should be associated with a unique index -- typically the frame id.
+    """
+    def __init__(self, network: MatAnyone, no_warning: bool = False):
+        self.network = network
+        self._store = {}
+        self.no_warning = no_warning
+
+    def _encode_feature(self, index: int, image: torch.Tensor, last_feats=None) -> None:
+        ms_features, pix_feat = self.network.encode_image(image, last_feats=last_feats)
+        key, shrinkage, selection = self.network.transform_key(ms_features[0])
+        self._store[index] = (ms_features, pix_feat, key, shrinkage, selection)
+
+    def get_all_features(self, images: torch.Tensor) -> (Iterable[torch.Tensor], torch.Tensor):
+        seq_length = images.shape[0]
+        ms_features, pix_feat = self.network.encode_image(images, seq_length)
+        key, shrinkage, selection = self.network.transform_key(ms_features[0])
+        for index in range(seq_length):
+            self._store[index] = ([f[index].unsqueeze(0) for f in ms_features],
+                                  pix_feat[index].unsqueeze(0), key[index].unsqueeze(0),
+                                  shrinkage[index].unsqueeze(0), selection[index].unsqueeze(0))
+
+    def get_features(self, index: int,
+                     image: torch.Tensor, last_feats=None) -> (Iterable[torch.Tensor], torch.Tensor):
+        if index not in self._store:
+            self._encode_feature(index, image, last_feats)
+
+        return self._store[index][:2]
+
+    def get_key(self, index: int,
+                image: torch.Tensor, last_feats=None) -> (torch.Tensor, torch.Tensor, torch.Tensor):
+        if index not in self._store:
+            self._encode_feature(index, image, last_feats)
+
+        return self._store[index][2:]
+
+    def delete(self, index: int) -> None:
+        if index in self._store:
+            del self._store[index]
+
+    def __len__(self):
+        return len(self._store)
+
+    def __del__(self):
+        if len(self._store) > 0 and not self.no_warning:
+            warnings.warn(f'Leaking {self._store.keys()} in the image feature store')
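A short sketch of the intended caching pattern; `network` stands in for a loaded MatAnyone model and `frame` for a (1, 3, H, W) image tensor:

    store = ImageFeatureStore(network)
    ms_features, pix_feat = store.get_features(0, frame)  # encodes frame 0 and caches it
    key, shrinkage, selection = store.get_key(0, frame)   # served from the cache, no re-encode
    store.delete(0)                                       # caller cleans up to avoid the leak warning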
matanyone/inference/inference_core.py
ADDED
@@ -0,0 +1,407 @@
from typing import List, Optional, Iterable, Dict
import logging
from omegaconf import DictConfig

import numpy as np
import torch
import torch.nn.functional as F

from matanyone.inference.memory_manager import MemoryManager
from matanyone.inference.object_manager import ObjectManager
from matanyone.inference.image_feature_store import ImageFeatureStore
from matanyone.model.matanyone import MatAnyone
from matanyone.utils.tensor_utils import pad_divide_by, unpad, aggregate

log = logging.getLogger()


class InferenceCore:

    def __init__(self,
                 network: MatAnyone,
                 cfg: DictConfig,
                 *,
                 image_feature_store: ImageFeatureStore = None):
        self.network = network
        self.cfg = cfg
        self.mem_every = cfg.mem_every
        stagger_updates = cfg.stagger_updates
        self.chunk_size = cfg.chunk_size
        self.save_aux = cfg.save_aux
        self.max_internal_size = cfg.max_internal_size
        self.flip_aug = cfg.flip_aug

        self.curr_ti = -1
        self.last_mem_ti = 0
        # at which time indices should we update the sensory memory
        if stagger_updates >= self.mem_every:
            self.stagger_ti = set(range(1, self.mem_every + 1))
        else:
            self.stagger_ti = set(
                np.round(np.linspace(1, self.mem_every, stagger_updates)).astype(int))
        self.object_manager = ObjectManager()
        self.memory = MemoryManager(cfg=cfg, object_manager=self.object_manager)

        if image_feature_store is None:
            self.image_feature_store = ImageFeatureStore(self.network)
        else:
            self.image_feature_store = image_feature_store

        self.last_mask = None
        self.last_pix_feat = None
        self.last_msk_value = None

    def clear_memory(self):
        self.curr_ti = -1
        self.last_mem_ti = 0
        self.memory = MemoryManager(cfg=self.cfg, object_manager=self.object_manager)

    def clear_non_permanent_memory(self):
        self.curr_ti = -1
        self.last_mem_ti = 0
        self.memory.clear_non_permanent_memory()

    def clear_sensory_memory(self):
        self.curr_ti = -1
        self.last_mem_ti = 0
        self.memory.clear_sensory_memory()

    def update_config(self, cfg):
        self.mem_every = cfg['mem_every']
        self.memory.update_config(cfg)

    def clear_temp_mem(self):
        self.memory.clear_work_mem()
        # self.object_manager = ObjectManager()
        self.memory.clear_obj_mem()
        # self.memory.clear_sensory_memory()

    def _add_memory(self,
                    image: torch.Tensor,
                    pix_feat: torch.Tensor,
                    prob: torch.Tensor,
                    key: torch.Tensor,
                    shrinkage: torch.Tensor,
                    selection: torch.Tensor,
                    *,
                    is_deep_update: bool = True,
                    force_permanent: bool = False) -> None:
        """
        Memorize the given segmentation in all memory stores.

        The batch dimension is 1 if flip augmentation is not used.
        image: RGB image, (1/2)*3*H*W
        pix_feat: from the key encoder, (1/2)*_*H*W
        prob: (1/2)*num_objects*H*W, in [0, 1]
        key/shrinkage/selection: for anisotropic l2, (1/2)*_*H*W
        selection can be None if not using long-term memory
        is_deep_update: whether to use deep update (e.g. with the mask encoder)
        force_permanent: whether to force the memory to be permanent
        """
        if prob.shape[1] == 0:
            # nothing to add
            log.warning('Trying to add an empty object mask to memory!')
            return

        if force_permanent:
            as_permanent = 'all'
        else:
            as_permanent = 'first'

        self.memory.initialize_sensory_if_needed(key, self.object_manager.all_obj_ids)
        msk_value, sensory, obj_value, _ = self.network.encode_mask(
            image,
            pix_feat,
            self.memory.get_sensory(self.object_manager.all_obj_ids),
            prob,
            deep_update=is_deep_update,
            chunk_size=self.chunk_size,
            need_weights=self.save_aux)
        self.memory.add_memory(key,
                               shrinkage,
                               msk_value,
                               obj_value,
                               self.object_manager.all_obj_ids,
                               selection=selection,
                               as_permanent=as_permanent)
        self.last_mem_ti = self.curr_ti
        if is_deep_update:
            self.memory.update_sensory(sensory, self.object_manager.all_obj_ids)
        self.last_msk_value = msk_value

    def _segment(self,
                 key: torch.Tensor,
                 selection: torch.Tensor,
                 pix_feat: torch.Tensor,
                 ms_features: Iterable[torch.Tensor],
                 update_sensory: bool = True) -> torch.Tensor:
        """
        Produce a segmentation using the given features and the memory

        The batch dimension is 1 if flip augmentation is not used.
        key/selection: for anisotropic l2: (1/2) * _ * H * W
        pix_feat: from the key encoder, (1/2) * _ * H * W
        ms_features: an iterable of multiscale features from the encoder, each is (1/2)*_*H*W
                     with strides 16, 8, and 4 respectively
        update_sensory: whether to update the sensory memory

        Returns: (num_objects+1)*H*W normalized probability; the first channel is the background
        """
        bs = key.shape[0]
        if self.flip_aug:
            assert bs == 2
        else:
            assert bs == 1

        if not self.memory.engaged:
            log.warning('Trying to segment without any memory!')
            return torch.zeros((1, key.shape[-2] * 16, key.shape[-1] * 16),
                               device=key.device,
                               dtype=key.dtype)

        uncert_output = None

        if self.curr_ti == 0:  # ONLY for the first frame for prediction
            memory_readout = self.memory.read_first_frame(self.last_msk_value,
                                                          pix_feat,
                                                          self.last_mask,
                                                          self.network,
                                                          uncert_output=uncert_output)
        else:
            memory_readout = self.memory.read(pix_feat,
                                              key,
                                              selection,
                                              self.last_mask,
                                              self.network,
                                              uncert_output=uncert_output,
                                              last_msk_value=self.last_msk_value,
                                              ti=self.curr_ti,
                                              last_pix_feat=self.last_pix_feat,
                                              last_pred_mask=self.last_mask)
        memory_readout = self.object_manager.realize_dict(memory_readout)

        sensory, _, pred_prob_with_bg = self.network.segment(ms_features,
                                                             memory_readout,
                                                             self.memory.get_sensory(
                                                                 self.object_manager.all_obj_ids),
                                                             chunk_size=self.chunk_size,
                                                             update_sensory=update_sensory)
        # remove batch dim
        if self.flip_aug:
            # average predictions of the non-flipped and flipped version
            pred_prob_with_bg = (pred_prob_with_bg[0] +
                                 torch.flip(pred_prob_with_bg[1], dims=[-1])) / 2
        else:
            pred_prob_with_bg = pred_prob_with_bg[0]
        if update_sensory:
            self.memory.update_sensory(sensory, self.object_manager.all_obj_ids)
        return pred_prob_with_bg

    def pred_all_flow(self, images):
        self.total_len = images.shape[0]
        images, self.pad = pad_divide_by(images, 16)
        images = images.unsqueeze(0)  # add the batch dimension: (1,t,c,h,w)

        self.flows_forward, self.flows_backward = self.network.pred_forward_backward_flow(images)

    def encode_all_images(self, images):
        images, self.pad = pad_divide_by(images, 16)
        self.image_feature_store.get_all_features(images)  # t c h w
        return images

    def step(self,
             image: torch.Tensor,
             mask: Optional[torch.Tensor] = None,
             objects: Optional[List[int]] = None,
             *,
             idx_mask: bool = False,
             end: bool = False,
             delete_buffer: bool = True,
             force_permanent: bool = False,
             matting: bool = True,
             first_frame_pred: bool = False) -> torch.Tensor:
        """
        Take a step with a new incoming image.
        If there is an incoming mask with new objects, we will memorize them.
        If there is no incoming mask, we will segment the image using the memory.
        In both cases, we will update the memory and return a segmentation.

        image: 3*H*W
        mask: H*W (if idx mask) or len(objects)*H*W or None
        objects: list of object ids that are valid in the mask Tensor.
            The ids themselves do not need to be consecutive/in order, but they need to be
            in the same position in the list as the corresponding mask
            in the tensor in non-idx-mask mode.
            objects is ignored if the mask is None.
            If idx_mask is False and objects is None, we sequentially infer the object ids.
        idx_mask: if True, mask is expected to contain an object id at every pixel.
            If False, mask should have multiple channels with each channel representing one object.
        end: if we are at the end of the sequence, we do not need to update memory;
            if unsure just set it to False
        delete_buffer: whether to delete the image feature buffer after this step
        force_permanent: the memory recorded this frame will be added to the permanent memory
        """
        if objects is None and mask is not None:
            assert not idx_mask
            objects = list(range(1, mask.shape[0] + 1))

        # resize input if needed -- currently only used for the GUI
        resize_needed = False
        if self.max_internal_size > 0:
            h, w = image.shape[-2:]
            min_side = min(h, w)
            if min_side > self.max_internal_size:
                resize_needed = True
                new_h = int(h / min_side * self.max_internal_size)
                new_w = int(w / min_side * self.max_internal_size)
                image = F.interpolate(image.unsqueeze(0),
                                      size=(new_h, new_w),
                                      mode='bilinear',
                                      align_corners=False)[0]
                if mask is not None:
                    if idx_mask:
                        # nearest modes do not accept align_corners
                        mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0).float(),
                                             size=(new_h, new_w),
                                             mode='nearest-exact')[0, 0].round().long()
                    else:
                        mask = F.interpolate(mask.unsqueeze(0),
                                             size=(new_h, new_w),
                                             mode='bilinear',
                                             align_corners=False)[0]

        self.curr_ti += 1

        image, self.pad = pad_divide_by(image, 16)  # already done for the 3DCNN!!
        image = image.unsqueeze(0)  # add the batch dimension
        if self.flip_aug:
            image = torch.cat([image, torch.flip(image, dims=[-1])], dim=0)

        # whether to update the working memory
        is_mem_frame = ((self.curr_ti - self.last_mem_ti >= self.mem_every) or
                        (mask is not None)) and (not end)
        # segment when there is no input mask or when the input mask is incomplete
        need_segment = (mask is None) or (self.object_manager.num_obj > 0
                                          and not self.object_manager.has_all(objects))
        update_sensory = ((self.curr_ti - self.last_mem_ti) in self.stagger_ti) and (not end)

        # reinit if it is the first frame for prediction
        if first_frame_pred:
            self.curr_ti = 0
            self.last_mem_ti = 0
            is_mem_frame = True
            need_segment = True
            update_sensory = True

        # encoding the image
        ms_feat, pix_feat = self.image_feature_store.get_features(self.curr_ti, image)
        key, shrinkage, selection = self.image_feature_store.get_key(self.curr_ti, image)

        # segmentation from memory if needed
        if need_segment:
            pred_prob_with_bg = self._segment(key,
                                              selection,
                                              pix_feat,
                                              ms_feat,
                                              update_sensory=update_sensory)

        # use the input mask if provided
        if mask is not None:
            # inform the manager of the new objects, and get a list of temporary ids
            # temporary ids -- indicate the position of objects in the tensor
            # (starts with 1 due to the background channel)
            corresponding_tmp_ids, _ = self.object_manager.add_new_objects(objects)

            mask, _ = pad_divide_by(mask, 16)
            if need_segment:
                # merge predicted mask with the incomplete input mask
                pred_prob_no_bg = pred_prob_with_bg[1:]
                # use the mutual exclusivity of segmentation
                if idx_mask:
                    pred_prob_no_bg[:, mask > 0] = 0
                else:
                    pred_prob_no_bg[:, mask.max(0).values > 0.5] = 0

                new_masks = []
                for mask_id, tmp_id in enumerate(corresponding_tmp_ids):
                    if idx_mask:
                        this_mask = (mask == objects[mask_id]).type_as(pred_prob_no_bg)
                    else:
                        this_mask = mask[tmp_id]
                    if tmp_id > pred_prob_no_bg.shape[0]:
                        new_masks.append(this_mask.unsqueeze(0))
                    else:
                        # +1 for padding the background channel
                        pred_prob_no_bg[tmp_id - 1] = this_mask
                # new_masks are always in the order of tmp_id
                mask = torch.cat([pred_prob_no_bg, *new_masks], dim=0)
            elif idx_mask:
                # simply convert cls to one-hot representation
                if len(objects) == 0:
                    if delete_buffer:
                        self.image_feature_store.delete(self.curr_ti)
                    log.warning('Trying to insert an empty mask as memory!')
                    return torch.zeros((1, key.shape[-2] * 16, key.shape[-1] * 16),
                                       device=key.device,
                                       dtype=key.dtype)
                mask = torch.stack(
                    [mask == objects[mask_id] for mask_id, _ in enumerate(corresponding_tmp_ids)],
                    dim=0)
            if matting:
                mask = mask.unsqueeze(0).float() / 255.
                pred_prob_with_bg = torch.cat([1 - mask, mask], 0)
            else:
                pred_prob_with_bg = aggregate(mask, dim=0)
                pred_prob_with_bg = torch.softmax(pred_prob_with_bg, dim=0)

        self.last_mask = pred_prob_with_bg[1:].unsqueeze(0)
        if self.flip_aug:
            self.last_mask = torch.cat(
                [self.last_mask, torch.flip(self.last_mask, dims=[-1])], dim=0)
        self.last_pix_feat = pix_feat

        # save as memory if needed
        if is_mem_frame or force_permanent:
            # clear the memory for the given mask and add the first predicted mask
            if first_frame_pred:
                self.clear_temp_mem()
            self._add_memory(image,
                             pix_feat,
                             self.last_mask,
                             key,
                             shrinkage,
                             selection,
                             force_permanent=force_permanent,
                             is_deep_update=True)
        else:  # compute self.last_msk_value for a non-memory frame
            msk_value, _, _, _ = self.network.encode_mask(
                image,
                pix_feat,
                self.memory.get_sensory(self.object_manager.all_obj_ids),
                self.last_mask,
                deep_update=False,
                chunk_size=self.chunk_size,
                need_weights=self.save_aux)
            self.last_msk_value = msk_value

        if delete_buffer:
            self.image_feature_store.delete(self.curr_ti)

        output_prob = unpad(pred_prob_with_bg, self.pad)
        if resize_needed:
            # restore output to the original size
            output_prob = F.interpolate(output_prob.unsqueeze(0),
                                        size=(h, w),
                                        mode='bilinear',
                                        align_corners=False)[0]

        return output_prob

    def delete_objects(self, objects: List[int]) -> None:
        """
        Delete the given objects from the memory.
        """
        self.object_manager.delete_objects(objects)
        self.memory.purge_except(self.object_manager.all_obj_ids)

    def output_prob_to_mask(self, output_prob: torch.Tensor, matting: bool = True) -> torch.Tensor:
        if matting:
            new_mask = output_prob[1:].squeeze(0)
        else:
            mask = torch.argmax(output_prob, dim=0)

            # index in tensor != object id -- remap the ids here
            new_mask = torch.zeros_like(mask)
            for tmp_id, obj in self.object_manager.tmp_id_to_obj.items():
                new_mask[mask == tmp_id] = obj.id

        return new_mask

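For orientation, here is a minimal driver sketch for the class above. It is not part of the commit: `network`, `cfg`, `frames`, and `first_mask` are all placeholder assumptions (a constructed MatAnyone model, a matching DictConfig, a list of 3*H*W float frame tensors on the right device, and an H*W first-frame mask in [0, 255] as the matting path expects).

    import torch
    from matanyone.inference.inference_core import InferenceCore

    # hypothetical inputs: `network`, `cfg`, `frames`, `first_mask` are assumptions
    processor = InferenceCore(network, cfg=cfg)

    with torch.no_grad():
        for ti, frame in enumerate(frames):
            if ti == 0:
                # memorize the user-supplied first-frame mask (matting mode expects [0, 255])
                output_prob = processor.step(frame, first_mask.unsqueeze(0),
                                             idx_mask=False, first_frame_pred=True)
            else:
                # propagate from memory alone
                output_prob = processor.step(frame)
            alpha = processor.output_prob_to_mask(output_prob)  # H*W alpha matte in [0, 1]
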
matanyone/inference/kv_memory_store.py
ADDED
@@ -0,0 +1,348 @@
from typing import Dict, List, Optional, Literal
from collections import defaultdict
import torch


def _add_last_dim(dictionary, key, new_value, prepend=False):
    # append/prepend a new value to the last dimension of a tensor in a dictionary
    # if the key does not exist, put the new value in
    # append by default; prepend is used when inserting into the permanent region,
    # which sits at the front of the concatenated tensor
    if key in dictionary:
        if prepend:
            dictionary[key] = torch.cat([new_value, dictionary[key]], -1)
        else:
            dictionary[key] = torch.cat([dictionary[key], new_value], -1)
    else:
        dictionary[key] = new_value


class KeyValueMemoryStore:
    """
    Works for key/value pair type storage
    e.g., working and long-term memory
    """
    def __init__(self, save_selection: bool = False, save_usage: bool = False):
        """
        We store keys and values of objects that first appear in the same frame in a bucket.
        Each bucket contains a set of object ids.
        Each bucket is associated with a single key tensor
        and a dictionary of value tensors indexed by object id.

        The keys and values are stored as the concatenation of a permanent part and a temporary part.
        """
        self.save_selection = save_selection
        self.save_usage = save_usage

        self.global_bucket_id = 0  # does not decrease even if buckets are removed
        self.buckets: Dict[int, List[int]] = {}  # indexed by bucket id
        self.k: Dict[int, torch.Tensor] = {}  # indexed by bucket id
        self.v: Dict[int, torch.Tensor] = {}  # indexed by object id

        # indexed by bucket id; the end point of permanent memory
        self.perm_end_pt: Dict[int, int] = defaultdict(int)

        # shrinkage and selection are just like the keys
        self.s = {}
        if self.save_selection:
            self.e = {}  # does not contain the permanent memory part

        # usage
        if self.save_usage:
            self.use_cnt = {}  # indexed by bucket id, does not contain the permanent memory part
            self.life_cnt = {}  # indexed by bucket id, does not contain the permanent memory part

    def add(self,
            key: torch.Tensor,
            values: Dict[int, torch.Tensor],
            shrinkage: torch.Tensor,
            selection: torch.Tensor,
            supposed_bucket_id: int = -1,
            as_permanent: Literal['no', 'first', 'all'] = 'no') -> None:
        """
        key: (1/2)*C*N
        values: dict of values ((1/2)*C*N), object ids are used as keys
        shrinkage: (1/2)*1*N
        selection: (1/2)*C*N

        supposed_bucket_id: used to sync the bucket id between working and long-term memory
            if provided, the input should all be in a single bucket indexed by this id
        as_permanent: whether to store the input as permanent memory
            'no': don't
            'first': only store it as permanent memory if the bucket is empty
            'all': always store it as permanent memory
        """
        bs = key.shape[0]
        ne = key.shape[-1]
        assert len(key.shape) == 3
        assert len(shrinkage.shape) == 3
        assert not self.save_selection or len(selection.shape) == 3
        assert as_permanent in ['no', 'first', 'all']

        # add the value and create new buckets if necessary
        if supposed_bucket_id >= 0:
            enabled_buckets = [supposed_bucket_id]
            bucket_exist = supposed_bucket_id in self.buckets
            for obj, value in values.items():
                if bucket_exist:
                    assert obj in self.v
                    assert obj in self.buckets[supposed_bucket_id]
                    _add_last_dim(self.v, obj, value, prepend=(as_permanent == 'all'))
                else:
                    assert obj not in self.v
                    self.v[obj] = value
            self.buckets[supposed_bucket_id] = list(values.keys())
        else:
            new_bucket_id = None
            enabled_buckets = set()
            for obj, value in values.items():
                assert len(value.shape) == 3
                if obj in self.v:
                    _add_last_dim(self.v, obj, value, prepend=(as_permanent == 'all'))
                    bucket_used = [
                        bucket_id for bucket_id, object_ids in self.buckets.items()
                        if obj in object_ids
                    ]
                    assert len(bucket_used) == 1  # each object should only be in one bucket
                    enabled_buckets.add(bucket_used[0])
                else:
                    self.v[obj] = value
                    if new_bucket_id is None:
                        # create a new bucket
                        new_bucket_id = self.global_bucket_id
                        self.global_bucket_id += 1
                        self.buckets[new_bucket_id] = []
                    # put the new object into the corresponding bucket
                    self.buckets[new_bucket_id].append(obj)
                    enabled_buckets.add(new_bucket_id)

        # increment the permanent size if necessary
        add_as_permanent = {}  # indexed by bucket id
        for bucket_id in enabled_buckets:
            add_as_permanent[bucket_id] = False
            if as_permanent == 'all':
                self.perm_end_pt[bucket_id] += ne
                add_as_permanent[bucket_id] = True
            elif as_permanent == 'first':
                if self.perm_end_pt[bucket_id] == 0:
                    self.perm_end_pt[bucket_id] = ne
                    add_as_permanent[bucket_id] = True

        # create new counters for usage if necessary
        if self.save_usage and as_permanent != 'all':
            new_count = torch.zeros((bs, ne), device=key.device, dtype=torch.float32)
            new_life = torch.zeros((bs, ne), device=key.device, dtype=torch.float32) + 1e-7

        # add the key to every enabled bucket
        for bucket_id in self.buckets:
            if bucket_id not in enabled_buckets:
                # if we are not adding new values to a bucket, we should skip it
                continue

            _add_last_dim(self.k, bucket_id, key, prepend=add_as_permanent[bucket_id])
            _add_last_dim(self.s, bucket_id, shrinkage, prepend=add_as_permanent[bucket_id])
            if not add_as_permanent[bucket_id]:
                if self.save_selection:
                    _add_last_dim(self.e, bucket_id, selection)
                if self.save_usage:
                    _add_last_dim(self.use_cnt, bucket_id, new_count)
                    _add_last_dim(self.life_cnt, bucket_id, new_life)

    def update_bucket_usage(self, bucket_id: int, usage: torch.Tensor) -> None:
        # increase all life counts by 1
        # increase the use counts of the indexed elements
        if not self.save_usage:
            return

        usage = usage[:, self.perm_end_pt[bucket_id]:]
        if usage.shape[-1] == 0:
            # if there is no temporary memory, we don't need to update
            return
        self.use_cnt[bucket_id] += usage.view_as(self.use_cnt[bucket_id])
        self.life_cnt[bucket_id] += 1

    def sieve_by_range(self, bucket_id: int, start: int, end: int, min_size: int) -> None:
        # keep only the temporary elements *outside* of this range (with some boundary conditions)
        # the permanent elements are ignored in this computation
        # i.e., concat (a[:start], a[end:])
        # buckets with size <= min_size are not modified

        assert start >= 0
        assert end <= 0

        object_ids = self.buckets[bucket_id]
        bucket_num_elements = self.k[bucket_id].shape[-1] - self.perm_end_pt[bucket_id]
        if bucket_num_elements <= min_size:
            return

        if end == 0:
            # negative 0 would not work as the end index!
            # effectively make the second part an empty slice
            end = self.k[bucket_id].shape[-1] + 1

        p_size = self.perm_end_pt[bucket_id]
        start = start + p_size

        k = self.k[bucket_id]
        s = self.s[bucket_id]
        if self.save_selection:
            e = self.e[bucket_id]
        if self.save_usage:
            use_cnt = self.use_cnt[bucket_id]
            life_cnt = self.life_cnt[bucket_id]

        self.k[bucket_id] = torch.cat([k[:, :, :start], k[:, :, end:]], -1)
        self.s[bucket_id] = torch.cat([s[:, :, :start], s[:, :, end:]], -1)
        if self.save_selection:
            self.e[bucket_id] = torch.cat([e[:, :, :start - p_size], e[:, :, end:]], -1)
        if self.save_usage:
            self.use_cnt[bucket_id] = torch.cat([use_cnt[:, :start - p_size], use_cnt[:, end:]], -1)
            self.life_cnt[bucket_id] = torch.cat([life_cnt[:, :start - p_size], life_cnt[:, end:]],
                                                 -1)
        for obj_id in object_ids:
            v = self.v[obj_id]
            self.v[obj_id] = torch.cat([v[:, :, :start], v[:, :, end:]], -1)

    def remove_old_memory(self, bucket_id: int, max_len: int) -> None:
        self.sieve_by_range(bucket_id, 0, -max_len, max_len)

    def remove_obsolete_features(self, bucket_id: int, max_size: int) -> None:
        # for long-term memory only
        object_ids = self.buckets[bucket_id]

        assert self.perm_end_pt[bucket_id] == 0  # permanent memory should be empty in LT memory

        # normalize with life duration
        usage = self.get_usage(bucket_id)
        bs = usage.shape[0]

        survivals = []

        for bi in range(bs):
            _, survived = torch.topk(usage[bi], k=max_size)
            survivals.append(survived.flatten())
            assert survived.shape[-1] == survivals[0].shape[-1]

        self.k[bucket_id] = torch.stack(
            [self.k[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0)
        self.s[bucket_id] = torch.stack(
            [self.s[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0)

        if self.save_selection:
            # long-term memory does not store selection so this should not be needed
            self.e[bucket_id] = torch.stack(
                [self.e[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0)
        for obj_id in object_ids:
            self.v[obj_id] = torch.stack(
                [self.v[obj_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0)

        self.use_cnt[bucket_id] = torch.stack(
            [self.use_cnt[bucket_id][bi, survived] for bi, survived in enumerate(survivals)], 0)
        self.life_cnt[bucket_id] = torch.stack(
            [self.life_cnt[bucket_id][bi, survived] for bi, survived in enumerate(survivals)], 0)

    def get_usage(self, bucket_id: int) -> torch.Tensor:
        # return normalized usage
        if not self.save_usage:
            raise RuntimeError('I did not count usage!')
        else:
            usage = self.use_cnt[bucket_id] / self.life_cnt[bucket_id]
            return usage

    def get_all_sliced(
        self, bucket_id: int, start: int, end: int
    ) -> (torch.Tensor, torch.Tensor, torch.Tensor, Dict[int, torch.Tensor], torch.Tensor):
        # return k, sk, ek, value, normalized usage in order, sliced by start and end
        # this only queries the temporary memory

        assert start >= 0
        assert end <= 0

        p_size = self.perm_end_pt[bucket_id]
        start = start + p_size

        if end == 0:
            # negative 0 would not work as the end index!
            k = self.k[bucket_id][:, :, start:]
            sk = self.s[bucket_id][:, :, start:]
            ek = self.e[bucket_id][:, :, start - p_size:] if self.save_selection else None
            value = {obj_id: self.v[obj_id][:, :, start:] for obj_id in self.buckets[bucket_id]}
            usage = self.get_usage(bucket_id)[:, start - p_size:] if self.save_usage else None
        else:
            k = self.k[bucket_id][:, :, start:end]
            sk = self.s[bucket_id][:, :, start:end]
            ek = self.e[bucket_id][:, :, start - p_size:end] if self.save_selection else None
            value = {obj_id: self.v[obj_id][:, :, start:end] for obj_id in self.buckets[bucket_id]}
            usage = self.get_usage(bucket_id)[:, start - p_size:end] if self.save_usage else None

        return k, sk, ek, value, usage

    def purge_except(self, obj_keep_idx: List[int]):
        # purge all objects from the memory except the ones listed
        obj_keep_idx = set(obj_keep_idx)

        # remove objects that are not in the keep list from the buckets
        buckets_to_remove = []
        for bucket_id, object_ids in self.buckets.items():
            self.buckets[bucket_id] = [obj_id for obj_id in object_ids if obj_id in obj_keep_idx]
            if len(self.buckets[bucket_id]) == 0:
                buckets_to_remove.append(bucket_id)

        # remove object values that are not in the keep list
        self.v = {k: v for k, v in self.v.items() if k in obj_keep_idx}

        # remove buckets that are empty
        for bucket_id in buckets_to_remove:
            del self.buckets[bucket_id]
            del self.k[bucket_id]
            del self.s[bucket_id]
            if self.save_selection:
                del self.e[bucket_id]
            if self.save_usage:
                del self.use_cnt[bucket_id]
                del self.life_cnt[bucket_id]

    def clear_non_permanent_memory(self):
        # clear all non-permanent memory
        for bucket_id in self.buckets:
            self.sieve_by_range(bucket_id, 0, 0, 0)

    def get_v_size(self, obj_id: int) -> int:
        return self.v[obj_id].shape[-1]

    def size(self, bucket_id: int) -> int:
        if bucket_id not in self.k:
            return 0
        else:
            return self.k[bucket_id].shape[-1]

    def perm_size(self, bucket_id: int) -> int:
        return self.perm_end_pt[bucket_id]

    def non_perm_size(self, bucket_id: int) -> int:
        return self.size(bucket_id) - self.perm_size(bucket_id)

    def engaged(self, bucket_id: Optional[int] = None) -> bool:
        if bucket_id is None:
            return len(self.buckets) > 0
        else:
            return bucket_id in self.buckets

    @property
    def num_objects(self) -> int:
        return len(self.v)

    @property
    def key(self) -> Dict[int, torch.Tensor]:
        return self.k

    @property
    def value(self) -> Dict[int, torch.Tensor]:
        return self.v

    @property
    def shrinkage(self) -> Dict[int, torch.Tensor]:
        return self.s

    @property
    def selection(self) -> Dict[int, torch.Tensor]:
        return self.e

    def __contains__(self, key):
        return key in self.v

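A toy sketch (not in the commit) of the permanent/temporary split above: the first `add` with as_permanent='first' claims the permanent region at the front of the bucket, and later additions stay temporary and hence eligible for FIFO removal or consolidation. All shapes here are arbitrary assumptions.

    import torch
    from matanyone.inference.kv_memory_store import KeyValueMemoryStore

    store = KeyValueMemoryStore(save_selection=False, save_usage=False)
    key = torch.randn(1, 4, 16)        # batch 1, key dim 4, N=16 tokens
    shrinkage = torch.rand(1, 1, 16)

    # first frame for object id 1 -> becomes the permanent part
    store.add(key, {1: torch.randn(1, 8, 16)}, shrinkage, selection=None,
              as_permanent='first')
    # a later frame -> appended to the temporary part
    store.add(torch.randn(1, 4, 16), {1: torch.randn(1, 8, 16)}, torch.rand(1, 1, 16),
              selection=None, as_permanent='no')

    bucket_id = next(iter(store.buckets))
    print(store.perm_size(bucket_id), store.non_perm_size(bucket_id))  # -> 16 16
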
matanyone/inference/memory_manager.py
ADDED
@@ -0,0 +1,457 @@
import logging
from omegaconf import DictConfig
from typing import List, Dict
import torch

from matanyone.inference.object_manager import ObjectManager
from matanyone.inference.kv_memory_store import KeyValueMemoryStore
from matanyone.model.matanyone import MatAnyone
from matanyone.model.utils.memory_utils import *

log = logging.getLogger()


class MemoryManager:
    """
    Manages all three memory stores and the transition between working/long-term memory
    """
    def __init__(self, cfg: DictConfig, object_manager: ObjectManager):
        self.object_manager = object_manager
        self.sensory_dim = cfg.model.sensory_dim
        self.top_k = cfg.top_k
        self.chunk_size = cfg.chunk_size

        self.save_aux = cfg.save_aux

        self.use_long_term = cfg.use_long_term
        self.count_long_term_usage = cfg.long_term.count_usage
        # subtract 1 because the first frame is now counted as "permanent memory"
        # and is not counted towards max_mem_frames,
        # but we want to keep the hyperparameters consistent as before for the same behavior
        if self.use_long_term:
            self.max_mem_frames = cfg.long_term.max_mem_frames - 1
            self.min_mem_frames = cfg.long_term.min_mem_frames - 1
            self.num_prototypes = cfg.long_term.num_prototypes
            self.max_long_tokens = cfg.long_term.max_num_tokens
            self.buffer_tokens = cfg.long_term.buffer_tokens
        else:
            self.max_mem_frames = cfg.max_mem_frames - 1

        # dimensions will be inferred from input later
        self.CK = self.CV = None
        self.H = self.W = None

        # The sensory memory is stored as a dictionary indexed by object ids
        # each of shape bs * C^h * H * W
        self.sensory = {}

        # a dictionary indexed by object ids, each of shape bs * T * Q * C
        self.obj_v = {}

        self.work_mem = KeyValueMemoryStore(save_selection=self.use_long_term,
                                            save_usage=self.use_long_term)
        if self.use_long_term:
            self.long_mem = KeyValueMemoryStore(save_usage=self.count_long_term_usage)

        self.config_stale = True
        self.engaged = False

    def update_config(self, cfg: DictConfig) -> None:
        self.config_stale = True
        self.top_k = cfg['top_k']

        assert self.use_long_term == cfg.use_long_term, 'cannot update this'
        assert self.count_long_term_usage == cfg.long_term.count_usage, 'cannot update this'

        self.use_long_term = cfg.use_long_term
        self.count_long_term_usage = cfg.long_term.count_usage
        if self.use_long_term:
            self.max_mem_frames = cfg.long_term.max_mem_frames - 1
            self.min_mem_frames = cfg.long_term.min_mem_frames - 1
            self.num_prototypes = cfg.long_term.num_prototypes
            self.max_long_tokens = cfg.long_term.max_num_tokens
            self.buffer_tokens = cfg.long_term.buffer_tokens
        else:
            self.max_mem_frames = cfg.max_mem_frames - 1

    def _readout(self, affinity, v, uncert_mask=None) -> torch.Tensor:
        # affinity: bs*N*HW
        # v: bs*C*N or bs*num_objects*C*N
        # returns bs*C*HW or bs*num_objects*C*HW
        if len(v.shape) == 3:
            # single object
            if uncert_mask is not None:
                return v @ affinity * uncert_mask
            else:
                return v @ affinity
        else:
            bs, num_objects, C, N = v.shape
            v = v.view(bs, num_objects * C, N)
            out = v @ affinity
            if uncert_mask is not None:
                uncert_mask = uncert_mask.flatten(start_dim=2).expand(-1, C, -1)
                out = out * uncert_mask
            return out.view(bs, num_objects, C, -1)

    def _get_mask_by_ids(self, mask: torch.Tensor, obj_ids: List[int]) -> torch.Tensor:
        # -1 because the mask does not contain the background channel
        return mask[:, [self.object_manager.find_tmp_by_id(obj) - 1 for obj in obj_ids]]

    def _get_sensory_by_ids(self, obj_ids: List[int]) -> torch.Tensor:
        return torch.stack([self.sensory[obj] for obj in obj_ids], dim=1)

    def _get_object_mem_by_ids(self, obj_ids: List[int]) -> torch.Tensor:
        return torch.stack([self.obj_v[obj] for obj in obj_ids], dim=1)

    def _get_visual_values_by_ids(self, obj_ids: List[int]) -> torch.Tensor:
        # All the values that the object ids refer to should have the same shape
        value = torch.stack([self.work_mem.value[obj] for obj in obj_ids], dim=1)
        if self.use_long_term and obj_ids[0] in self.long_mem.value:
            lt_value = torch.stack([self.long_mem.value[obj] for obj in obj_ids], dim=1)
            value = torch.cat([lt_value, value], dim=-1)

        return value

    def read_first_frame(self, last_msk_value, pix_feat: torch.Tensor,
                         last_mask: torch.Tensor, network: MatAnyone,
                         uncert_output=None) -> Dict[int, torch.Tensor]:
        """
        Read from all memory stores and return a single memory readout tensor for each object

        pix_feat: (1/2) x C x H x W
        last_mask: (1/2) x num_objects x H x W (at stride 16)
        return a dict of memory readouts, indexed by object indices. Each readout is C*H*W
        """
        h, w = pix_feat.shape[-2:]
        bs = pix_feat.shape[0]
        assert last_mask.shape[0] == bs

        uncert_mask = uncert_output["mask"] if uncert_output is not None else None

        """
        Compute affinity and perform readout
        """
        all_readout_mem = {}
        buckets = self.work_mem.buckets
        for bucket_id, bucket in buckets.items():

            if self.chunk_size < 1:
                object_chunks = [bucket]
            else:
                object_chunks = [
                    bucket[i:i + self.chunk_size] for i in range(0, len(bucket), self.chunk_size)
                ]

            for objects in object_chunks:
                this_sensory = self._get_sensory_by_ids(objects)
                this_last_mask = self._get_mask_by_ids(last_mask, objects)
                this_msk_value = self._get_visual_values_by_ids(objects)  # (1/2)*num_objects*C*N
                pixel_readout = network.pixel_fusion(pix_feat, last_msk_value, this_sensory,
                                                     this_last_mask)
                this_obj_mem = self._get_object_mem_by_ids(objects).unsqueeze(2)
                readout_memory, aux_features = network.readout_query(pixel_readout, this_obj_mem)
                for i, obj in enumerate(objects):
                    all_readout_mem[obj] = readout_memory[:, i]

                if self.save_aux:
                    aux_output = {
                        # 'sensory': this_sensory,
                        # 'pixel_readout': pixel_readout,
                        'q_logits': aux_features['logits'] if aux_features else None,
                        # 'q_weights': aux_features['q_weights'] if aux_features else None,
                        # 'p_weights': aux_features['p_weights'] if aux_features else None,
                        # 'attn_mask': aux_features['attn_mask'].float() if aux_features else None,
                    }
                    self.aux = aux_output

        return all_readout_mem

    def read(self, pix_feat: torch.Tensor, query_key: torch.Tensor, selection: torch.Tensor,
             last_mask: torch.Tensor, network: MatAnyone, uncert_output=None,
             last_msk_value=None, ti=None,
             last_pix_feat=None, last_pred_mask=None) -> Dict[int, torch.Tensor]:
        """
        Read from all memory stores and return a single memory readout tensor for each object

        pix_feat: (1/2) x C x H x W
        query_key: (1/2) x C^k x H x W
        selection: (1/2) x C^k x H x W
        last_mask: (1/2) x num_objects x H x W (at stride 16)
        return a dict of memory readouts, indexed by object indices. Each readout is C*H*W
        """
        h, w = pix_feat.shape[-2:]
        bs = pix_feat.shape[0]
        assert query_key.shape[0] == bs
        assert selection.shape[0] == bs
        assert last_mask.shape[0] == bs

        uncert_mask = uncert_output["mask"] if uncert_output is not None else None

        query_key = query_key.flatten(start_dim=2)  # bs*C^k*HW
        selection = selection.flatten(start_dim=2)  # bs*C^k*HW
        """
        Compute affinity and perform readout
        """
        all_readout_mem = {}
        buckets = self.work_mem.buckets
        for bucket_id, bucket in buckets.items():
            if self.use_long_term and self.long_mem.engaged(bucket_id):
                # use long-term memory
                long_mem_size = self.long_mem.size(bucket_id)
                memory_key = torch.cat(
                    [self.long_mem.key[bucket_id], self.work_mem.key[bucket_id]], -1)
                shrinkage = torch.cat(
                    [self.long_mem.shrinkage[bucket_id], self.work_mem.shrinkage[bucket_id]], -1)

                similarity = get_similarity(memory_key, shrinkage, query_key, selection)
                affinity, usage = do_softmax(similarity,
                                             top_k=self.top_k,
                                             inplace=True,
                                             return_usage=True)
                """
                Record memory usage for working and long-term memory
                """
                # ignore the index return for long-term memory
                work_usage = usage[:, long_mem_size:]
                self.work_mem.update_bucket_usage(bucket_id, work_usage)

                if self.count_long_term_usage:
                    # ignore the index return for working memory
                    long_usage = usage[:, :long_mem_size]
                    self.long_mem.update_bucket_usage(bucket_id, long_usage)
            else:
                # no long-term memory
                memory_key = self.work_mem.key[bucket_id]
                shrinkage = self.work_mem.shrinkage[bucket_id]
                similarity = get_similarity(memory_key, shrinkage, query_key, selection,
                                            uncert_mask=uncert_mask)

                if self.use_long_term:
                    affinity, usage = do_softmax(similarity,
                                                 top_k=self.top_k,
                                                 inplace=True,
                                                 return_usage=True)
                    self.work_mem.update_bucket_usage(bucket_id, usage)
                else:
                    affinity = do_softmax(similarity, top_k=self.top_k, inplace=True)

            if self.chunk_size < 1:
                object_chunks = [bucket]
            else:
                object_chunks = [
                    bucket[i:i + self.chunk_size] for i in range(0, len(bucket), self.chunk_size)
                ]

            for objects in object_chunks:
                this_sensory = self._get_sensory_by_ids(objects)
                this_last_mask = self._get_mask_by_ids(last_mask, objects)
                this_msk_value = self._get_visual_values_by_ids(objects)  # (1/2)*num_objects*C*N
                visual_readout = self._readout(affinity, this_msk_value,
                                               uncert_mask).view(bs, len(objects), self.CV, h, w)

                uncert_output = network.pred_uncertainty(
                    last_pix_feat, pix_feat, last_pred_mask,
                    visual_readout[:, 0] - last_msk_value[:, 0])

                if uncert_output is not None:
                    uncert_prob = uncert_output["prob"].unsqueeze(1)  # b n 1 h w
                    visual_readout = visual_readout * uncert_prob + last_msk_value * (1 - uncert_prob)

                pixel_readout = network.pixel_fusion(pix_feat, visual_readout, this_sensory,
                                                     this_last_mask)
                this_obj_mem = self._get_object_mem_by_ids(objects).unsqueeze(2)
                readout_memory, aux_features = network.readout_query(pixel_readout, this_obj_mem)
                for i, obj in enumerate(objects):
                    all_readout_mem[obj] = readout_memory[:, i]

                if self.save_aux:
                    aux_output = {
                        # 'sensory': this_sensory,
                        # 'pixel_readout': pixel_readout,
                        'q_logits': aux_features['logits'] if aux_features else None,
                        # 'q_weights': aux_features['q_weights'] if aux_features else None,
                        # 'p_weights': aux_features['p_weights'] if aux_features else None,
                        # 'attn_mask': aux_features['attn_mask'].float() if aux_features else None,
                    }
                    self.aux = aux_output

        return all_readout_mem

    def add_memory(self,
                   key: torch.Tensor,
                   shrinkage: torch.Tensor,
                   msk_value: torch.Tensor,
                   obj_value: torch.Tensor,
                   objects: List[int],
                   selection: torch.Tensor = None,
                   *,
                   as_permanent: bool = False) -> None:
        # key: (1/2)*C*H*W
        # msk_value: (1/2)*num_objects*C*H*W
        # obj_value: (1/2)*num_objects*Q*C
        # objects contains a list of object ids corresponding to the objects in msk_value/obj_value
        bs = key.shape[0]
        assert shrinkage.shape[0] == bs
        assert msk_value.shape[0] == bs
        assert obj_value.shape[0] == bs

        self.engaged = True
        if self.H is None or self.config_stale:
            self.config_stale = False
            self.H, self.W = msk_value.shape[-2:]
            self.HW = self.H * self.W
            # convert from num. frames to num. tokens
            self.max_work_tokens = self.max_mem_frames * self.HW
            if self.use_long_term:
                self.min_work_tokens = self.min_mem_frames * self.HW

        # key: bs*C*N
        # value: bs*num_objects*C*N
        key = key.flatten(start_dim=2)
        shrinkage = shrinkage.flatten(start_dim=2)
        self.CK = key.shape[1]

        msk_value = msk_value.flatten(start_dim=3)
        self.CV = msk_value.shape[2]

        if selection is not None:
            # not used in non-long-term mode
            selection = selection.flatten(start_dim=2)

        # insert object values into object memory
        for obj_id, obj in enumerate(objects):
            if obj in self.obj_v:
                """streaming average
                each self.obj_v[obj] is (1/2)*num_summaries*(embed_dim+1)
                first embed_dim keeps track of the sum of embeddings
                the last dim keeps the total count
                averaging is done inside the object transformer

                incoming obj_value is (1/2)*num_objects*num_summaries*(embed_dim+1)
                self.obj_v[obj] = torch.cat([self.obj_v[obj], obj_value[:, obj_id]], dim=0)
                """
                last_acc = self.obj_v[obj][:, :, -1]
                new_acc = last_acc + obj_value[:, obj_id, :, -1]

                self.obj_v[obj][:, :, :-1] = (self.obj_v[obj][:, :, :-1] +
                                              obj_value[:, obj_id, :, :-1])
                self.obj_v[obj][:, :, -1] = new_acc
            else:
                self.obj_v[obj] = obj_value[:, obj_id]

        # convert mask value tensor into a dict for insertion
        msk_values = {obj: msk_value[:, obj_id] for obj_id, obj in enumerate(objects)}
        self.work_mem.add(key,
                          msk_values,
                          shrinkage,
                          selection=selection,
                          as_permanent=as_permanent)

        for bucket_id in self.work_mem.buckets.keys():
            # long-term memory cleanup
            if self.use_long_term:
                # do memory compression if needed
                if self.work_mem.non_perm_size(bucket_id) >= self.max_work_tokens:
                    # remove obsolete features if needed
                    if self.long_mem.non_perm_size(bucket_id) >= (self.max_long_tokens -
                                                                  self.num_prototypes):
                        self.long_mem.remove_obsolete_features(
                            bucket_id,
                            self.max_long_tokens - self.num_prototypes - self.buffer_tokens)

                    self.compress_features(bucket_id)
            else:
                # FIFO
                self.work_mem.remove_old_memory(bucket_id, self.max_work_tokens)

    def purge_except(self, obj_keep_idx: List[int]) -> None:
        # purge all objects from the memory except the ones listed
        self.work_mem.purge_except(obj_keep_idx)
        if self.use_long_term and self.long_mem.engaged():
            self.long_mem.purge_except(obj_keep_idx)
        self.sensory = {k: v for k, v in self.sensory.items() if k in obj_keep_idx}

        if not self.work_mem.engaged():
            # everything is removed!
            self.engaged = False

    def compress_features(self, bucket_id: int) -> None:
        HW = self.HW

        # perform memory consolidation
        prototype_key, prototype_value, prototype_shrinkage = self.consolidation(
            *self.work_mem.get_all_sliced(bucket_id, 0, -self.min_work_tokens))

        # remove consolidated working memory
        self.work_mem.sieve_by_range(bucket_id,
                                     0,
                                     -self.min_work_tokens,
                                     min_size=self.min_work_tokens)

        # add to long-term memory
        self.long_mem.add(prototype_key,
                          prototype_value,
                          prototype_shrinkage,
                          selection=None,
                          supposed_bucket_id=bucket_id)

    def consolidation(self, candidate_key: torch.Tensor, candidate_shrinkage: torch.Tensor,
                      candidate_selection: torch.Tensor, candidate_value: Dict[int, torch.Tensor],
                      usage: torch.Tensor) -> (torch.Tensor, Dict[int, torch.Tensor], torch.Tensor):
        # find the indices with max usage
        bs = candidate_key.shape[0]
        assert bs in [1, 2]

        prototype_key = []
        prototype_selection = []
        for bi in range(bs):
            _, max_usage_indices = torch.topk(usage[bi], k=self.num_prototypes, dim=-1, sorted=True)
            prototype_indices = max_usage_indices.flatten()
            prototype_key.append(candidate_key[bi, :, prototype_indices])
            prototype_selection.append(candidate_selection[bi, :, prototype_indices])
        prototype_key = torch.stack(prototype_key, dim=0)
        prototype_selection = torch.stack(prototype_selection, dim=0)
        """
        Potentiation step
        """
        similarity = get_similarity(candidate_key, candidate_shrinkage, prototype_key,
                                    prototype_selection)
        affinity = do_softmax(similarity)

        # readout the values
        prototype_value = {k: self._readout(affinity, v) for k, v in candidate_value.items()}

        # readout the shrinkage term
        prototype_shrinkage = self._readout(affinity, candidate_shrinkage)

        return prototype_key, prototype_value, prototype_shrinkage

    def initialize_sensory_if_needed(self, sample_key: torch.Tensor, ids: List[int]):
        for obj in ids:
            if obj not in self.sensory:
                # also initializes the sensory memory
                bs, _, h, w = sample_key.shape
                self.sensory[obj] = torch.zeros((bs, self.sensory_dim, h, w),
                                                device=sample_key.device)

    def update_sensory(self, sensory: torch.Tensor, ids: List[int]):
        # sensory: 1*num_objects*C*H*W
        for obj_id, obj in enumerate(ids):
            self.sensory[obj] = sensory[:, obj_id]

    def get_sensory(self, ids: List[int]):
        # returns (1/2)*num_objects*C*H*W
        return self._get_sensory_by_ids(ids)

    def clear_non_permanent_memory(self):
        self.work_mem.clear_non_permanent_memory()
        if self.use_long_term:
            self.long_mem.clear_non_permanent_memory()

    def clear_sensory_memory(self):
        self.sensory = {}

    def clear_work_mem(self):
        self.work_mem = KeyValueMemoryStore(save_selection=self.use_long_term,
                                            save_usage=self.use_long_term)

    def clear_obj_mem(self):
        self.obj_v = {}

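`get_similarity` and `do_softmax` above come from matanyone/model/utils/memory_utils.py, added elsewhere in this commit. As a reading aid only, below is a naive sketch of the anisotropic L2 affinity that XMem-style memory readers compute; the exact scaling is an assumption to verify against memory_utils.py, and the real implementation expands the square instead of materializing the B*C*N*P difference tensor.

    import math
    import torch

    def anisotropic_l2_similarity(mk, ms, qk, qe):
        # mk: B*C*N memory keys, ms: B*1*N shrinkage,
        # qk: B*C*P query keys, qe: B*C*P query selection (or None)
        # returns B*N*P similarity (higher = closer)
        C = mk.shape[1]
        diff = mk.unsqueeze(3) - qk.unsqueeze(2)             # B*C*N*P (naive, memory-hungry)
        if qe is not None:
            dist2 = (diff.pow(2) * qe.unsqueeze(2)).sum(1)   # selection weights each channel
        else:
            dist2 = diff.pow(2).sum(1)
        # shrinkage sharpens confident memory entries; scale as in dot-product attention
        return -dist2 * ms.transpose(1, 2) / math.sqrt(C)
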
matanyone/inference/object_info.py
ADDED
@@ -0,0 +1,24 @@
class ObjectInfo:
    """
    Store meta information for an object
    """
    def __init__(self, id: int):
        self.id = id
        self.poke_count = 0  # count number of detections missed

    def poke(self) -> None:
        self.poke_count += 1

    def unpoke(self) -> None:
        self.poke_count = 0

    def __hash__(self):
        return hash(self.id)

    def __eq__(self, other):
        if type(other) == int:
            return self.id == other
        return self.id == other.id

    def __repr__(self):
        return f'(ID: {self.id})'

matanyone/inference/object_manager.py
ADDED
@@ -0,0 +1,149 @@
from typing import Union, List, Dict

import torch
from matanyone.inference.object_info import ObjectInfo


class ObjectManager:
    """
    Object IDs are immutable. The same ID always represents the same object.
    Temporary IDs are the positions of each object in the tensor. They change as objects get removed.
    Temporary IDs start from 1.
    """

    def __init__(self):
        self.obj_to_tmp_id: Dict[ObjectInfo, int] = {}
        self.tmp_id_to_obj: Dict[int, ObjectInfo] = {}
        self.obj_id_to_obj: Dict[int, ObjectInfo] = {}

        self.all_historical_object_ids: List[int] = []

    def _recompute_obj_id_to_obj_mapping(self) -> None:
        self.obj_id_to_obj = {obj.id: obj for obj in self.obj_to_tmp_id}

    def add_new_objects(
            self, objects: Union[List[ObjectInfo], ObjectInfo,
                                 List[int]]) -> (List[int], List[int]):
        if not isinstance(objects, list):
            objects = [objects]

        corresponding_tmp_ids = []
        corresponding_obj_ids = []
        for obj in objects:
            if isinstance(obj, int):
                obj = ObjectInfo(id=obj)

            if obj in self.obj_to_tmp_id:
                # old object
                corresponding_tmp_ids.append(self.obj_to_tmp_id[obj])
                corresponding_obj_ids.append(obj.id)
            else:
                # new object
                new_obj = ObjectInfo(id=obj.id)
                new_tmp_id = len(self.obj_to_tmp_id) + 1
                self.obj_to_tmp_id[new_obj] = new_tmp_id
                self.tmp_id_to_obj[new_tmp_id] = new_obj
                self.all_historical_object_ids.append(new_obj.id)
                corresponding_tmp_ids.append(new_tmp_id)
                corresponding_obj_ids.append(new_obj.id)

        self._recompute_obj_id_to_obj_mapping()
        assert corresponding_tmp_ids == sorted(corresponding_tmp_ids)
        return corresponding_tmp_ids, corresponding_obj_ids

    def delete_objects(self, obj_ids_to_remove: Union[int, List[int]]) -> None:
        # delete an object or a list of objects
        # re-sort the tmp ids
        if isinstance(obj_ids_to_remove, int):
            obj_ids_to_remove = [obj_ids_to_remove]

        new_tmp_id = 1
        total_num_id = len(self.obj_to_tmp_id)

        local_obj_to_tmp_id = {}
        local_tmp_to_obj_id = {}

        for tmp_iter in range(1, total_num_id + 1):
            obj = self.tmp_id_to_obj[tmp_iter]
            if obj.id not in obj_ids_to_remove:
                local_obj_to_tmp_id[obj] = new_tmp_id
                local_tmp_to_obj_id[new_tmp_id] = obj
                new_tmp_id += 1

        self.obj_to_tmp_id = local_obj_to_tmp_id
        self.tmp_id_to_obj = local_tmp_to_obj_id
        self._recompute_obj_id_to_obj_mapping()

    def purge_inactive_objects(self,
                               max_missed_detection_count: int) -> (bool, List[int], List[int]):
        # remove tmp ids of objects that are removed
        obj_id_to_be_deleted = []
        tmp_id_to_be_deleted = []
        tmp_id_to_keep = []
        obj_id_to_keep = []

        for obj in self.obj_to_tmp_id:
            if obj.poke_count > max_missed_detection_count:
                obj_id_to_be_deleted.append(obj.id)
                tmp_id_to_be_deleted.append(self.obj_to_tmp_id[obj])
            else:
                tmp_id_to_keep.append(self.obj_to_tmp_id[obj])
                obj_id_to_keep.append(obj.id)

        purge_activated = len(obj_id_to_be_deleted) > 0
        if purge_activated:
            self.delete_objects(obj_id_to_be_deleted)
        return purge_activated, tmp_id_to_keep, obj_id_to_keep

    def tmp_to_obj_cls(self, mask) -> torch.Tensor:
        # remap tmp id cls representation to the true object id representation
        new_mask = torch.zeros_like(mask)
        for tmp_id, obj in self.tmp_id_to_obj.items():
            new_mask[mask == tmp_id] = obj.id
        return new_mask

    def get_tmp_to_obj_mapping(self) -> Dict[int, ObjectInfo]:
        # returns the mapping in a dict format for saving it with pickle
        # note: unpack items as (tmp_id, obj) so the result matches the
        # annotated Dict[int, ObjectInfo]; the original unpacking was swapped
        return {tmp_id: obj for tmp_id, obj in self.tmp_id_to_obj.items()}

    def realize_dict(self, obj_dict, dim=1) -> torch.Tensor:
        # turns a dict indexed by obj id into a tensor, ordered by tmp IDs
        output = []
        for _, obj in self.tmp_id_to_obj.items():
            if obj.id not in obj_dict:
                raise NotImplementedError
            output.append(obj_dict[obj.id])
        output = torch.stack(output, dim=dim)
        return output

    def make_one_hot(self, cls_mask) -> torch.Tensor:
        output = []
        for _, obj in self.tmp_id_to_obj.items():
            output.append(cls_mask == obj.id)
        if len(output) == 0:
            output = torch.zeros((0, *cls_mask.shape), dtype=torch.bool, device=cls_mask.device)
        else:
            output = torch.stack(output, dim=0)
        return output

    @property
    def all_obj_ids(self) -> List[int]:
        return [k.id for k in self.obj_to_tmp_id]

    @property
    def num_obj(self) -> int:
        return len(self.obj_to_tmp_id)

    def has_all(self, objects: List[int]) -> bool:
        for obj in objects:
            if obj not in self.obj_to_tmp_id:
                return False
        return True

    def find_object_by_id(self, obj_id) -> ObjectInfo:
        return self.obj_id_to_obj[obj_id]

    def find_tmp_by_id(self, obj_id) -> int:
        return self.obj_to_tmp_id[self.obj_id_to_obj[obj_id]]
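A minimal usage sketch (not part of the commit; the integer IDs are arbitrary) of how ObjectManager keeps temporary IDs contiguous when objects are deleted:

import torch

from matanyone.inference.object_manager import ObjectManager

manager = ObjectManager()
# registering objects 3, 7, 9 assigns contiguous temporary IDs 1, 2, 3
tmp_ids, obj_ids = manager.add_new_objects([3, 7, 9])
assert tmp_ids == [1, 2, 3] and obj_ids == [3, 7, 9]

# deleting object 7 re-sorts the remaining temporary IDs: 3 -> 1, 9 -> 2
manager.delete_objects(7)
assert manager.find_tmp_by_id(9) == 2
assert manager.all_obj_ids == [3, 9]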
matanyone/inference/utils/__init__.py
ADDED
File without changes
matanyone/inference/utils/args_utils.py
ADDED
@@ -0,0 +1,30 @@
import logging
from omegaconf import DictConfig

log = logging.getLogger()


def get_dataset_cfg(cfg: DictConfig):
    dataset_name = cfg.dataset
    data_cfg = cfg.datasets[dataset_name]

    potential_overrides = [
        'image_directory',
        'mask_directory',
        'json_directory',
        'size',
        'save_all',
        'use_all_masks',
        'use_long_term',
        'mem_every',
    ]

    for override in potential_overrides:
        if cfg[override] is not None:
            log.info(f'Overriding config {override} from {data_cfg[override]} to {cfg[override]}')
            data_cfg[override] = cfg[override]
        # escalate all potential overrides to the top-level config
        if override in data_cfg:
            cfg[override] = data_cfg[override]

    return data_cfg
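A minimal sketch (illustrative values, assuming a config layout compatible with eval_matanyone_config.yaml) of how a non-None top-level value overrides the per-dataset value, and how dataset values are escalated back to the top level:

from omegaconf import OmegaConf

from matanyone.inference.utils.args_utils import get_dataset_cfg

cfg = OmegaConf.create({
    'dataset': 'd17',
    'size': 720,              # non-None top-level value wins
    'image_directory': None,  # None means: keep the dataset default
    'mask_directory': None, 'json_directory': None, 'save_all': None,
    'use_all_masks': None, 'use_long_term': None, 'mem_every': None,
    'datasets': {'d17': {'image_directory': '/data/d17', 'size': 480}},
})

data_cfg = get_dataset_cfg(cfg)
assert data_cfg.size == 720                  # overridden from the top level
assert cfg.image_directory == '/data/d17'    # escalated to the top level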
matanyone/model/__init__.py
ADDED
File without changes
matanyone/model/aux_modules.py
ADDED
@@ -0,0 +1,93 @@
"""
For computing auxiliary outputs for auxiliary losses
"""
from typing import Dict
from omegaconf import DictConfig
import torch
import torch.nn as nn

from matanyone.model.group_modules import GConv2d
from matanyone.utils.tensor_utils import aggregate


class LinearPredictor(nn.Module):
    def __init__(self, x_dim: int, pix_dim: int):
        super().__init__()
        self.projection = GConv2d(x_dim, pix_dim + 1, kernel_size=1)

    def forward(self, pix_feat: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # pix_feat: B*pix_dim*H*W
        # x: B*num_objects*x_dim*H*W
        num_objects = x.shape[1]
        x = self.projection(x)

        pix_feat = pix_feat.unsqueeze(1).expand(-1, num_objects, -1, -1, -1)
        logits = (pix_feat * x[:, :, :-1]).sum(dim=2) + x[:, :, -1]
        return logits


class DirectPredictor(nn.Module):
    def __init__(self, x_dim: int):
        super().__init__()
        self.projection = GConv2d(x_dim, 1, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: B*num_objects*x_dim*H*W
        logits = self.projection(x).squeeze(2)
        return logits


class AuxComputer(nn.Module):
    def __init__(self, cfg: DictConfig):
        super().__init__()

        use_sensory_aux = cfg.model.aux_loss.sensory.enabled
        self.use_query_aux = cfg.model.aux_loss.query.enabled
        self.use_sensory_aux = use_sensory_aux

        sensory_dim = cfg.model.sensory_dim
        embed_dim = cfg.model.embed_dim

        if use_sensory_aux:
            self.sensory_aux = LinearPredictor(sensory_dim, embed_dim)

    def _aggregate_with_selector(self, logits: torch.Tensor, selector: torch.Tensor) -> torch.Tensor:
        prob = torch.sigmoid(logits)
        if selector is not None:
            prob = prob * selector
        logits = aggregate(prob, dim=1)
        return logits

    def forward(self, pix_feat: torch.Tensor, aux_input: Dict[str, torch.Tensor],
                selector: torch.Tensor, seg_pass=False) -> Dict[str, torch.Tensor]:
        sensory = aux_input['sensory']
        q_logits = aux_input['q_logits']

        aux_output = {}
        aux_output['attn_mask'] = aux_input['attn_mask']

        if self.use_sensory_aux:
            # B*num_objects*H*W
            logits = self.sensory_aux(pix_feat, sensory)
            aux_output['sensory_logits'] = self._aggregate_with_selector(logits, selector)
        if self.use_query_aux:
            # B*num_objects*num_levels*H*W
            aux_output['q_logits'] = self._aggregate_with_selector(
                torch.stack(q_logits, dim=2),
                selector.unsqueeze(2) if selector is not None else None)

        return aux_output

    def compute_mask(self, aux_input: Dict[str, torch.Tensor],
                     selector: torch.Tensor) -> Dict[str, torch.Tensor]:
        # sensory = aux_input['sensory']
        q_logits = aux_input['q_logits']

        aux_output = {}

        # B*num_objects*num_levels*H*W
        aux_output['q_logits'] = self._aggregate_with_selector(
            torch.stack(q_logits, dim=2),
            selector.unsqueeze(2) if selector is not None else None)

        return aux_output
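A shape sketch (not part of the commit; dimensions are made up) of what LinearPredictor computes: the 1x1 projection emits per-object weight maps plus a bias map, and the logit is a per-pixel dot product against the shared pixel features:

import torch

from matanyone.model.aux_modules import LinearPredictor

B, num_objects, x_dim, pix_dim, H, W = 2, 3, 256, 64, 30, 54
pred = LinearPredictor(x_dim=x_dim, pix_dim=pix_dim)

pix_feat = torch.randn(B, pix_dim, H, W)        # shared between objects
x = torch.randn(B, num_objects, x_dim, H, W)    # per-object (e.g. sensory memory)

logits = pred(pix_feat, x)
assert logits.shape == (B, num_objects, H, W)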
matanyone/model/big_modules.py
ADDED
@@ -0,0 +1,358 @@
"""
big_modules.py - This file stores higher-level network blocks.

x - usually denotes features that are shared between objects.
g - usually denotes features that are not shared between objects
    with an extra "num_objects" dimension (batch_size * num_objects * num_channels * H * W).

The trailing number of a variable usually denotes the stride
"""

from omegaconf import DictConfig
import torch
import torch.nn as nn
import torch.nn.functional as F

from matanyone.model.group_modules import *
from matanyone.model.utils import resnet
from matanyone.model.modules import *


class UncertPred(nn.Module):
    def __init__(self, model_cfg: DictConfig):
        super().__init__()
        self.conv1x1_v2 = nn.Conv2d(model_cfg.pixel_dim * 2 + 1 + model_cfg.value_dim, 64, kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.conv3x3 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1, groups=1, bias=False, dilation=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3x3_out = nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1, groups=1, bias=False, dilation=1)

    def forward(self, last_frame_feat: torch.Tensor, cur_frame_feat: torch.Tensor,
                last_mask: torch.Tensor, mem_val_diff: torch.Tensor):
        last_mask = F.interpolate(last_mask, size=last_frame_feat.shape[-2:], mode='area')
        x = torch.cat([last_frame_feat, cur_frame_feat, last_mask, mem_val_diff], dim=1)
        x = self.conv1x1_v2(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv3x3(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.conv3x3_out(x)
        return x

    # override the default train() to freeze BN statistics
    def train(self, mode=True):
        self.training = False
        for module in self.children():
            module.train(False)
        return self


class PixelEncoder(nn.Module):
    def __init__(self, model_cfg: DictConfig):
        super().__init__()

        self.is_resnet = 'resnet' in model_cfg.pixel_encoder.type
        if self.is_resnet:
            if model_cfg.pixel_encoder.type == 'resnet18':
                network = resnet.resnet18(pretrained=True)
            elif model_cfg.pixel_encoder.type == 'resnet50':
                network = resnet.resnet50(pretrained=True)
            else:
                raise NotImplementedError
            self.conv1 = network.conv1
            self.bn1 = network.bn1
            self.relu = network.relu
            self.maxpool = network.maxpool

            self.res2 = network.layer1
            self.layer2 = network.layer2
            self.layer3 = network.layer3
        else:
            raise NotImplementedError

    def forward(self, x: torch.Tensor, seq_length=None) -> (torch.Tensor, torch.Tensor, torch.Tensor):
        f1 = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        f2 = x
        x = self.maxpool(x)
        f4 = self.res2(x)
        f8 = self.layer2(f4)
        f16 = self.layer3(f8)

        return f16, f8, f4, f2, f1

    # override the default train() to freeze BN statistics
    def train(self, mode=True):
        self.training = False
        for module in self.children():
            module.train(False)
        return self


class KeyProjection(nn.Module):
    def __init__(self, model_cfg: DictConfig):
        super().__init__()
        in_dim = model_cfg.pixel_encoder.ms_dims[0]
        mid_dim = model_cfg.pixel_dim
        key_dim = model_cfg.key_dim

        self.pix_feat_proj = nn.Conv2d(in_dim, mid_dim, kernel_size=1)
        self.key_proj = nn.Conv2d(mid_dim, key_dim, kernel_size=3, padding=1)
        # shrinkage
        self.d_proj = nn.Conv2d(mid_dim, 1, kernel_size=3, padding=1)
        # selection
        self.e_proj = nn.Conv2d(mid_dim, key_dim, kernel_size=3, padding=1)

        nn.init.orthogonal_(self.key_proj.weight.data)
        nn.init.zeros_(self.key_proj.bias.data)

    def forward(self, x: torch.Tensor, *, need_s: bool,
                need_e: bool) -> (torch.Tensor, torch.Tensor, torch.Tensor):
        x = self.pix_feat_proj(x)
        shrinkage = self.d_proj(x)**2 + 1 if need_s else None
        selection = torch.sigmoid(self.e_proj(x)) if need_e else None

        return self.key_proj(x), shrinkage, selection


class MaskEncoder(nn.Module):
    def __init__(self, model_cfg: DictConfig, single_object=False):
        super().__init__()
        pixel_dim = model_cfg.pixel_dim
        value_dim = model_cfg.value_dim
        sensory_dim = model_cfg.sensory_dim
        final_dim = model_cfg.mask_encoder.final_dim

        self.single_object = single_object
        extra_dim = 1 if single_object else 2

        if model_cfg.mask_encoder.type == 'resnet18':
            network = resnet.resnet18(pretrained=True, extra_dim=extra_dim)
        elif model_cfg.mask_encoder.type == 'resnet50':
            network = resnet.resnet50(pretrained=True, extra_dim=extra_dim)
        else:
            raise NotImplementedError
        self.conv1 = network.conv1
        self.bn1 = network.bn1
        self.relu = network.relu
        self.maxpool = network.maxpool

        self.layer1 = network.layer1
        self.layer2 = network.layer2
        self.layer3 = network.layer3

        self.distributor = MainToGroupDistributor()
        self.fuser = GroupFeatureFusionBlock(pixel_dim, final_dim, value_dim)

        self.sensory_update = SensoryDeepUpdater(value_dim, sensory_dim)

    def forward(self,
                image: torch.Tensor,
                pix_feat: torch.Tensor,
                sensory: torch.Tensor,
                masks: torch.Tensor,
                others: torch.Tensor,
                *,
                deep_update: bool = True,
                chunk_size: int = -1) -> (torch.Tensor, torch.Tensor):
        # ms_features are from the key encoder
        # we only use the first one (lowest resolution), following XMem
        if self.single_object:
            g = masks.unsqueeze(2)
        else:
            g = torch.stack([masks, others], dim=2)

        g = self.distributor(image, g)

        batch_size, num_objects = g.shape[:2]
        if chunk_size < 1 or chunk_size >= num_objects:
            chunk_size = num_objects
            fast_path = True
            new_sensory = sensory
        else:
            if deep_update:
                new_sensory = torch.empty_like(sensory)
            else:
                new_sensory = sensory
            fast_path = False

        # chunk-by-chunk inference
        all_g = []
        for i in range(0, num_objects, chunk_size):
            if fast_path:
                g_chunk = g
            else:
                g_chunk = g[:, i:i + chunk_size]
            actual_chunk_size = g_chunk.shape[1]
            g_chunk = g_chunk.flatten(start_dim=0, end_dim=1)

            g_chunk = self.conv1(g_chunk)
            g_chunk = self.bn1(g_chunk)  # 1/2, 64
            g_chunk = self.maxpool(g_chunk)  # 1/4, 64
            g_chunk = self.relu(g_chunk)

            g_chunk = self.layer1(g_chunk)  # 1/4
            g_chunk = self.layer2(g_chunk)  # 1/8
            g_chunk = self.layer3(g_chunk)  # 1/16

            g_chunk = g_chunk.view(batch_size, actual_chunk_size, *g_chunk.shape[1:])
            g_chunk = self.fuser(pix_feat, g_chunk)
            all_g.append(g_chunk)
            if deep_update:
                if fast_path:
                    new_sensory = self.sensory_update(g_chunk, sensory)
                else:
                    new_sensory[:, i:i + chunk_size] = self.sensory_update(
                        g_chunk, sensory[:, i:i + chunk_size])
        g = torch.cat(all_g, dim=1)

        return g, new_sensory

    # override the default train() to freeze BN statistics
    def train(self, mode=True):
        self.training = False
        for module in self.children():
            module.train(False)
        return self


class PixelFeatureFuser(nn.Module):
    def __init__(self, model_cfg: DictConfig, single_object=False):
        super().__init__()
        value_dim = model_cfg.value_dim
        sensory_dim = model_cfg.sensory_dim
        pixel_dim = model_cfg.pixel_dim
        embed_dim = model_cfg.embed_dim
        self.single_object = single_object

        self.fuser = GroupFeatureFusionBlock(pixel_dim, value_dim, embed_dim)
        if self.single_object:
            self.sensory_compress = GConv2d(sensory_dim + 1, value_dim, kernel_size=1)
        else:
            self.sensory_compress = GConv2d(sensory_dim + 2, value_dim, kernel_size=1)

    def forward(self,
                pix_feat: torch.Tensor,
                pixel_memory: torch.Tensor,
                sensory_memory: torch.Tensor,
                last_mask: torch.Tensor,
                last_others: torch.Tensor,
                *,
                chunk_size: int = -1) -> torch.Tensor:
        batch_size, num_objects = pixel_memory.shape[:2]

        if self.single_object:
            last_mask = last_mask.unsqueeze(2)
        else:
            last_mask = torch.stack([last_mask, last_others], dim=2)

        if chunk_size < 1:
            chunk_size = num_objects

        # chunk-by-chunk inference
        all_p16 = []
        for i in range(0, num_objects, chunk_size):
            sensory_readout = self.sensory_compress(
                torch.cat([sensory_memory[:, i:i + chunk_size], last_mask[:, i:i + chunk_size]], 2))
            p16 = pixel_memory[:, i:i + chunk_size] + sensory_readout
            p16 = self.fuser(pix_feat, p16)
            all_p16.append(p16)
        p16 = torch.cat(all_p16, dim=1)

        return p16


class MaskDecoder(nn.Module):
    def __init__(self, model_cfg: DictConfig):
        super().__init__()
        embed_dim = model_cfg.embed_dim
        sensory_dim = model_cfg.sensory_dim
        ms_image_dims = model_cfg.pixel_encoder.ms_dims
        up_dims = model_cfg.mask_decoder.up_dims

        assert embed_dim == up_dims[0]

        self.sensory_update = SensoryUpdater_fullscale(
            [up_dims[0], up_dims[1], up_dims[2], up_dims[3], up_dims[4] + 1],
            sensory_dim, sensory_dim)

        self.decoder_feat_proc = DecoderFeatureProcessor(ms_image_dims[1:], up_dims[:-1])
        self.up_16_8 = MaskUpsampleBlock(up_dims[0], up_dims[1])
        self.up_8_4 = MaskUpsampleBlock(up_dims[1], up_dims[2])
        # newly added for alpha matte
        self.up_4_2 = MaskUpsampleBlock(up_dims[2], up_dims[3])
        self.up_2_1 = MaskUpsampleBlock(up_dims[3], up_dims[4])

        self.pred_seg = nn.Conv2d(up_dims[-1], 1, kernel_size=3, padding=1)
        self.pred_mat = nn.Conv2d(up_dims[-1], 1, kernel_size=3, padding=1)

    def forward(self,
                ms_image_feat: Iterable[torch.Tensor],
                memory_readout: torch.Tensor,
                sensory: torch.Tensor,
                *,
                chunk_size: int = -1,
                update_sensory: bool = True,
                seg_pass: bool = False,
                last_mask=None,
                sigmoid_residual=False) -> (torch.Tensor, torch.Tensor):

        batch_size, num_objects = memory_readout.shape[:2]
        f8, f4, f2, f1 = self.decoder_feat_proc(ms_image_feat[1:])
        if chunk_size < 1 or chunk_size >= num_objects:
            chunk_size = num_objects
            fast_path = True
            new_sensory = sensory
        else:
            if update_sensory:
                new_sensory = torch.empty_like(sensory)
            else:
                new_sensory = sensory
            fast_path = False

        # chunk-by-chunk inference
        all_logits = []
        for i in range(0, num_objects, chunk_size):
            if fast_path:
                p16 = memory_readout
            else:
                p16 = memory_readout[:, i:i + chunk_size]
            actual_chunk_size = p16.shape[1]

            p8 = self.up_16_8(p16, f8)
            p4 = self.up_8_4(p8, f4)
            p2 = self.up_4_2(p4, f2)
            p1 = self.up_2_1(p2, f1)
            with torch.cuda.amp.autocast(enabled=False):
                if seg_pass:
                    if last_mask is not None:
                        res = self.pred_seg(F.relu(p1.flatten(start_dim=0, end_dim=1).float()))
                        if sigmoid_residual:
                            res = (torch.sigmoid(res) - 0.5) * 2  # regularization: (-1, 1) change on last mask
                        logits = last_mask + res
                    else:
                        logits = self.pred_seg(F.relu(p1.flatten(start_dim=0, end_dim=1).float()))
                else:
                    if last_mask is not None:
                        res = self.pred_mat(F.relu(p1.flatten(start_dim=0, end_dim=1).float()))
                        if sigmoid_residual:
                            res = (torch.sigmoid(res) - 0.5) * 2  # regularization: (-1, 1) change on last mask
                        logits = last_mask + res
                    else:
                        logits = self.pred_mat(F.relu(p1.flatten(start_dim=0, end_dim=1).float()))
            ## SensoryUpdater_fullscale
            if update_sensory:
                p1 = torch.cat(
                    [p1, logits.view(batch_size, actual_chunk_size, 1, *logits.shape[-2:])], 2)
                if fast_path:
                    new_sensory = self.sensory_update([p16, p8, p4, p2, p1], sensory)
                else:
                    new_sensory[:, i:i + chunk_size] = self.sensory_update(
                        [p16, p8, p4, p2, p1], sensory[:, i:i + chunk_size])
            all_logits.append(logits)
        logits = torch.cat(all_logits, dim=0)
        logits = logits.view(batch_size, num_objects, *logits.shape[-2:])

        return new_sensory, logits
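The chunk-by-chunk pattern shared by MaskEncoder, PixelFeatureFuser, and MaskDecoder, reduced to its core (a standalone sketch, not repo code): slicing the object dimension bounds peak memory while leaving the result identical to the one-pass fast path:

import torch

def process_in_chunks(g: torch.Tensor, fn, chunk_size: int = -1) -> torch.Tensor:
    # g: batch_size * num_objects * C * H * W; fn operates on a slice of objects
    num_objects = g.shape[1]
    if chunk_size < 1 or chunk_size >= num_objects:
        return fn(g)  # fast path: one pass over all objects
    outputs = [fn(g[:, i:i + chunk_size]) for i in range(0, num_objects, chunk_size)]
    return torch.cat(outputs, dim=1)

g = torch.randn(2, 5, 8, 16, 16)
double = lambda chunk: chunk * 2
assert torch.equal(process_in_chunks(g, double, chunk_size=2), double(g))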
matanyone/model/channel_attn.py
ADDED
@@ -0,0 +1,39 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class CAResBlock(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, residual: bool = True):
        super().__init__()
        self.residual = residual
        self.conv1 = nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1)

        t = int((abs(math.log2(out_dim)) + 1) // 2)
        k = t if t % 2 else t + 1
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False)

        if self.residual:
            if in_dim == out_dim:
                self.downsample = nn.Identity()
            else:
                self.downsample = nn.Conv2d(in_dim, out_dim, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r = x
        x = self.conv1(F.relu(x))
        x = self.conv2(F.relu(x))

        b, c = x.shape[:2]
        w = self.pool(x).view(b, 1, c)
        w = self.conv(w).transpose(-1, -2).unsqueeze(-1).sigmoid()  # B*C*1*1

        if self.residual:
            x = x * w + self.downsample(r)
        else:
            x = x * w

        return x
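The adaptive 1-D kernel size above follows the ECA-Net style heuristic: k grows with log2 of the channel count and is forced odd so the padded convolution stays centred. A quick standalone check (illustrative values, not repo code):

import math

def eca_kernel_size(out_dim: int) -> int:
    # same formula as in CAResBlock.__init__
    t = int((abs(math.log2(out_dim)) + 1) // 2)
    return t if t % 2 else t + 1

# wider feature maps get a wider (always odd) attention kernel
assert [eca_kernel_size(d) for d in (64, 256, 1024)] == [3, 5, 5]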
matanyone/model/group_modules.py
ADDED
@@ -0,0 +1,126 @@
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from matanyone.model.channel_attn import CAResBlock


def interpolate_groups(g: torch.Tensor, ratio: float, mode: str,
                       align_corners: bool) -> torch.Tensor:
    batch_size, num_objects = g.shape[:2]
    g = F.interpolate(g.flatten(start_dim=0, end_dim=1),
                      scale_factor=ratio,
                      mode=mode,
                      align_corners=align_corners)
    g = g.view(batch_size, num_objects, *g.shape[1:])
    return g


def upsample_groups(g: torch.Tensor,
                    ratio: float = 2,
                    mode: str = 'bilinear',
                    align_corners: bool = False) -> torch.Tensor:
    return interpolate_groups(g, ratio, mode, align_corners)


def downsample_groups(g: torch.Tensor,
                      ratio: float = 1 / 2,
                      mode: str = 'area',
                      align_corners: bool = None) -> torch.Tensor:
    return interpolate_groups(g, ratio, mode, align_corners)


class GConv2d(nn.Conv2d):
    def forward(self, g: torch.Tensor) -> torch.Tensor:
        batch_size, num_objects = g.shape[:2]
        g = super().forward(g.flatten(start_dim=0, end_dim=1))
        return g.view(batch_size, num_objects, *g.shape[1:])


class GroupResBlock(nn.Module):
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()

        if in_dim == out_dim:
            self.downsample = nn.Identity()
        else:
            self.downsample = GConv2d(in_dim, out_dim, kernel_size=1)

        self.conv1 = GConv2d(in_dim, out_dim, kernel_size=3, padding=1)
        self.conv2 = GConv2d(out_dim, out_dim, kernel_size=3, padding=1)

    def forward(self, g: torch.Tensor) -> torch.Tensor:
        out_g = self.conv1(F.relu(g))
        out_g = self.conv2(F.relu(out_g))

        g = self.downsample(g)

        return out_g + g


class MainToGroupDistributor(nn.Module):
    def __init__(self,
                 x_transform: Optional[nn.Module] = None,
                 g_transform: Optional[nn.Module] = None,
                 method: str = 'cat',
                 reverse_order: bool = False):
        super().__init__()

        self.x_transform = x_transform
        self.g_transform = g_transform
        self.method = method
        self.reverse_order = reverse_order

    def forward(self, x: torch.Tensor, g: torch.Tensor, skip_expand: bool = False) -> torch.Tensor:
        num_objects = g.shape[1]

        if self.x_transform is not None:
            x = self.x_transform(x)

        if self.g_transform is not None:
            g = self.g_transform(g)

        if not skip_expand:
            x = x.unsqueeze(1).expand(-1, num_objects, -1, -1, -1)
        if self.method == 'cat':
            if self.reverse_order:
                g = torch.cat([g, x], 2)
            else:
                g = torch.cat([x, g], 2)
        elif self.method == 'add':
            g = x + g
        elif self.method == 'mulcat':
            g = torch.cat([x * g, g], dim=2)
        elif self.method == 'muladd':
            g = x * g + g
        else:
            raise NotImplementedError

        return g


class GroupFeatureFusionBlock(nn.Module):
    def __init__(self, x_in_dim: int, g_in_dim: int, out_dim: int):
        super().__init__()

        x_transform = nn.Conv2d(x_in_dim, out_dim, kernel_size=1)
        g_transform = GConv2d(g_in_dim, out_dim, kernel_size=1)

        self.distributor = MainToGroupDistributor(x_transform=x_transform,
                                                  g_transform=g_transform,
                                                  method='add')
        self.block1 = CAResBlock(out_dim, out_dim)
        self.block2 = CAResBlock(out_dim, out_dim)

    def forward(self, x: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
        batch_size, num_objects = g.shape[:2]

        g = self.distributor(x, g)

        g = g.flatten(start_dim=0, end_dim=1)

        g = self.block1(g)
        g = self.block2(g)

        g = g.view(batch_size, num_objects, *g.shape[1:])

        return g
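A minimal sketch (illustrative shapes, not repo code) of the group-convolution convention used throughout: GConv2d folds the object dimension into the batch, runs a plain Conv2d with shared weights, and unfolds again:

import torch

from matanyone.model.group_modules import GConv2d

conv = GConv2d(16, 32, kernel_size=3, padding=1)
g = torch.randn(2, 4, 16, 24, 24)  # batch * num_objects * C * H * W
out = conv(g)
assert out.shape == (2, 4, 32, 24, 24)  # per-object outputs, one shared kernel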
matanyone/model/matanyone.py
ADDED
@@ -0,0 +1,323 @@
from typing import List, Dict
import logging
from omegaconf import DictConfig
import torch
import torch.nn as nn

from matanyone.model.modules import *
from matanyone.model.big_modules import *
from matanyone.model.aux_modules import AuxComputer
from matanyone.model.utils.memory_utils import *
from matanyone.model.transformer.object_transformer import QueryTransformer
from matanyone.model.transformer.object_summarizer import ObjectSummarizer
from matanyone.utils.tensor_utils import aggregate

log = logging.getLogger()


class MatAnyone(nn.Module):

    def __init__(self, cfg: DictConfig, *, single_object=False):
        super().__init__()
        self.cfg = cfg
        model_cfg = cfg.model
        self.ms_dims = model_cfg.pixel_encoder.ms_dims
        self.key_dim = model_cfg.key_dim
        self.value_dim = model_cfg.value_dim
        self.sensory_dim = model_cfg.sensory_dim
        self.pixel_dim = model_cfg.pixel_dim
        self.embed_dim = model_cfg.embed_dim
        self.single_object = single_object

        log.info(f'Single object: {self.single_object}')

        self.pixel_encoder = PixelEncoder(model_cfg)
        self.pix_feat_proj = nn.Conv2d(self.ms_dims[0], self.pixel_dim, kernel_size=1)
        self.key_proj = KeyProjection(model_cfg)
        self.mask_encoder = MaskEncoder(model_cfg, single_object=single_object)
        self.mask_decoder = MaskDecoder(model_cfg)
        self.pixel_fuser = PixelFeatureFuser(model_cfg, single_object=single_object)
        self.object_transformer = QueryTransformer(model_cfg)
        self.object_summarizer = ObjectSummarizer(model_cfg)
        self.aux_computer = AuxComputer(cfg)
        self.uncert_pred = UncertPred(model_cfg)

        self.register_buffer("pixel_mean", torch.Tensor(model_cfg.pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.Tensor(model_cfg.pixel_std).view(-1, 1, 1), False)

    def _get_others(self, masks: torch.Tensor) -> torch.Tensor:
        # for each object, return the sum of masks of all other objects
        if self.single_object:
            return None

        num_objects = masks.shape[1]
        if num_objects >= 1:
            others = (masks.sum(dim=1, keepdim=True) - masks).clamp(0, 1)
        else:
            others = torch.zeros_like(masks)
        return others

    def pred_uncertainty(self, last_pix_feat: torch.Tensor, cur_pix_feat: torch.Tensor,
                         last_mask: torch.Tensor, mem_val_diff: torch.Tensor):
        logits = self.uncert_pred(last_frame_feat=last_pix_feat,
                                  cur_frame_feat=cur_pix_feat,
                                  last_mask=last_mask,
                                  mem_val_diff=mem_val_diff)

        prob = torch.sigmoid(logits)
        mask = (prob > 0) + 0

        uncert_output = {"logits": logits,
                         "prob": prob,
                         "mask": mask}

        return uncert_output

    def encode_image(self, image: torch.Tensor, seq_length=None, last_feats=None) -> (Iterable[torch.Tensor], torch.Tensor):  # type: ignore
        image = (image - self.pixel_mean) / self.pixel_std
        ms_image_feat = self.pixel_encoder(image, seq_length)  # f16, f8, f4, f2, f1
        return ms_image_feat, self.pix_feat_proj(ms_image_feat[0])

    def encode_mask(
            self,
            image: torch.Tensor,
            ms_features: List[torch.Tensor],
            sensory: torch.Tensor,
            masks: torch.Tensor,
            *,
            deep_update: bool = True,
            chunk_size: int = -1,
            need_weights: bool = False) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
        image = (image - self.pixel_mean) / self.pixel_std
        others = self._get_others(masks)
        mask_value, new_sensory = self.mask_encoder(image,
                                                    ms_features,
                                                    sensory,
                                                    masks,
                                                    others,
                                                    deep_update=deep_update,
                                                    chunk_size=chunk_size)
        object_summaries, object_logits = self.object_summarizer(masks, mask_value, need_weights)
        return mask_value, new_sensory, object_summaries, object_logits

    def transform_key(self,
                      final_pix_feat: torch.Tensor,
                      *,
                      need_sk: bool = True,
                      need_ek: bool = True) -> (torch.Tensor, torch.Tensor, torch.Tensor):
        key, shrinkage, selection = self.key_proj(final_pix_feat, need_s=need_sk, need_e=need_ek)
        return key, shrinkage, selection

    # Used in training only.
    # This step is replaced by MemoryManager in test time
    def read_memory(self, query_key: torch.Tensor, query_selection: torch.Tensor,
                    memory_key: torch.Tensor, memory_shrinkage: torch.Tensor,
                    msk_value: torch.Tensor, obj_memory: torch.Tensor, pix_feat: torch.Tensor,
                    sensory: torch.Tensor, last_mask: torch.Tensor,
                    selector: torch.Tensor, uncert_output=None, seg_pass=False,
                    last_pix_feat=None, last_pred_mask=None) -> (torch.Tensor, Dict[str, torch.Tensor]):
        """
        query_key       : B * CK * H * W
        query_selection : B * CK * H * W
        memory_key      : B * CK * T * H * W
        memory_shrinkage: B * 1  * T * H * W
        msk_value       : B * num_objects * CV * T * H * W
        obj_memory      : B * num_objects * T * num_summaries * C
        pixel_feature   : B * C * H * W
        """
        batch_size, num_objects = msk_value.shape[:2]

        uncert_mask = uncert_output["mask"] if uncert_output is not None else None

        # read using visual attention
        with torch.cuda.amp.autocast(enabled=False):
            affinity = get_affinity(memory_key.float(), memory_shrinkage.float(), query_key.float(),
                                    query_selection.float(), uncert_mask=uncert_mask)

            msk_value = msk_value.flatten(start_dim=1, end_dim=2).float()

            # B * (num_objects*CV) * H * W
            pixel_readout = readout(affinity, msk_value, uncert_mask)
            pixel_readout = pixel_readout.view(batch_size, num_objects, self.value_dim,
                                               *pixel_readout.shape[-2:])

        uncert_output = self.pred_uncertainty(last_pix_feat, pix_feat, last_pred_mask,
                                              pixel_readout[:, 0] - msk_value[:, :, -1])
        uncert_prob = uncert_output["prob"].unsqueeze(1)  # b n 1 h w
        pixel_readout = pixel_readout * uncert_prob + msk_value[:, :, -1].unsqueeze(1) * (1 - uncert_prob)

        pixel_readout = self.pixel_fusion(pix_feat, pixel_readout, sensory, last_mask)

        # read from query transformer
        mem_readout, aux_features = self.readout_query(pixel_readout, obj_memory, selector=selector, seg_pass=seg_pass)

        aux_output = {
            'sensory': sensory,
            'q_logits': aux_features['logits'] if aux_features else None,
            'attn_mask': aux_features['attn_mask'] if aux_features else None,
        }

        return mem_readout, aux_output, uncert_output

    def read_first_frame_memory(self, pixel_readout,
                                obj_memory: torch.Tensor, pix_feat: torch.Tensor,
                                sensory: torch.Tensor, last_mask: torch.Tensor,
                                selector: torch.Tensor, seg_pass=False) -> (torch.Tensor, Dict[str, torch.Tensor]):
        """
        query_key       : B * CK * H * W
        query_selection : B * CK * H * W
        memory_key      : B * CK * T * H * W
        memory_shrinkage: B * 1  * T * H * W
        msk_value       : B * num_objects * CV * T * H * W
        obj_memory      : B * num_objects * T * num_summaries * C
        pixel_feature   : B * C * H * W
        """

        pixel_readout = self.pixel_fusion(pix_feat, pixel_readout, sensory, last_mask)

        # read from query transformer
        mem_readout, aux_features = self.readout_query(pixel_readout, obj_memory, selector=selector, seg_pass=seg_pass)

        aux_output = {
            'sensory': sensory,
            'q_logits': aux_features['logits'] if aux_features else None,
            'attn_mask': aux_features['attn_mask'] if aux_features else None,
        }

        return mem_readout, aux_output

    def pixel_fusion(self,
                     pix_feat: torch.Tensor,
                     pixel: torch.Tensor,
                     sensory: torch.Tensor,
                     last_mask: torch.Tensor,
                     *,
                     chunk_size: int = -1) -> torch.Tensor:
        last_mask = F.interpolate(last_mask, size=sensory.shape[-2:], mode='area')
        last_others = self._get_others(last_mask)
        fused = self.pixel_fuser(pix_feat,
                                 pixel,
                                 sensory,
                                 last_mask,
                                 last_others,
                                 chunk_size=chunk_size)
        return fused

    def readout_query(self,
                      pixel_readout,
                      obj_memory,
                      *,
                      selector=None,
                      need_weights=False,
                      seg_pass=False) -> (torch.Tensor, Dict[str, torch.Tensor]):
        return self.object_transformer(pixel_readout,
                                       obj_memory,
                                       selector=selector,
                                       need_weights=need_weights,
                                       seg_pass=seg_pass)

    def segment(self,
                ms_image_feat: List[torch.Tensor],
                memory_readout: torch.Tensor,
                sensory: torch.Tensor,
                *,
                selector: bool = None,
                chunk_size: int = -1,
                update_sensory: bool = True,
                seg_pass: bool = False,
                clamp_mat: bool = True,
                last_mask=None,
                sigmoid_residual=False,
                seg_mat=False) -> (torch.Tensor, torch.Tensor, torch.Tensor):
        """
        multi_scale_features is from the key encoder for skip-connection
        memory_readout is from working/long-term memory
        sensory is the sensory memory
        last_mask is the mask from the last frame, supplementing sensory memory
        selector is 1 if an object exists, and 0 otherwise. We use it to filter padded objects
        during training.
        """
        #### use mat head for seg data
        if seg_mat:
            assert seg_pass
            seg_pass = False
        ####
        sensory, logits = self.mask_decoder(ms_image_feat,
                                            memory_readout,
                                            sensory,
                                            chunk_size=chunk_size,
                                            update_sensory=update_sensory,
                                            seg_pass=seg_pass,
                                            last_mask=last_mask,
                                            sigmoid_residual=sigmoid_residual)
        if seg_pass:
            prob = torch.sigmoid(logits)
            if selector is not None:
                prob = prob * selector

            # Softmax over all objects
            logits = aggregate(prob, dim=1)
            prob = F.softmax(logits, dim=1)
        else:
            if clamp_mat:
                logits = logits.clamp(0.0, 1.0)
            logits = torch.cat([torch.prod(1 - logits, dim=1, keepdim=True), logits], 1)
            prob = logits

        return sensory, logits, prob

    def compute_aux(self, pix_feat: torch.Tensor, aux_inputs: Dict[str, torch.Tensor],
                    selector: torch.Tensor, seg_pass=False) -> Dict[str, torch.Tensor]:
        return self.aux_computer(pix_feat, aux_inputs, selector, seg_pass=seg_pass)

    def forward(self, *args, **kwargs):
        raise NotImplementedError

    def load_weights(self, src_dict, init_as_zero_if_needed=False) -> None:
        if not self.single_object:
            # Map single-object weight to multi-object weight (4->5 out channels in conv1)
            for k in list(src_dict.keys()):
                if k == 'mask_encoder.conv1.weight':
                    if src_dict[k].shape[1] == 4:
                        log.info(f'Converting {k} from single object to multiple objects.')
                        pads = torch.zeros((64, 1, 7, 7), device=src_dict[k].device)
                        if not init_as_zero_if_needed:
                            nn.init.orthogonal_(pads)
                            log.info(f'Randomly initialized padding for {k}.')
                        else:
                            log.info(f'Zero-initialized padding for {k}.')
                        src_dict[k] = torch.cat([src_dict[k], pads], 1)
                elif k == 'pixel_fuser.sensory_compress.weight':
                    if src_dict[k].shape[1] == self.sensory_dim + 1:
                        log.info(f'Converting {k} from single object to multiple objects.')
                        pads = torch.zeros((self.value_dim, 1, 1, 1), device=src_dict[k].device)
                        if not init_as_zero_if_needed:
                            nn.init.orthogonal_(pads)
                            log.info(f'Randomly initialized padding for {k}.')
                        else:
                            log.info(f'Zero-initialized padding for {k}.')
                        src_dict[k] = torch.cat([src_dict[k], pads], 1)
        elif self.single_object:
            """
            If the model is multiple-object and we are training in single-object,
            we strip the last channel of conv1.
            This is not supposed to happen in standard training except when users are trying to
            finetune a trained model with single object datasets.
            """
            if src_dict['mask_encoder.conv1.weight'].shape[1] == 5:
                log.warning('Converting mask_encoder.conv1.weight from multiple objects to single object. '
                            'This is not supposed to happen in standard training.')
                src_dict['mask_encoder.conv1.weight'] = src_dict['mask_encoder.conv1.weight'][:, :-1]
                src_dict['pixel_fuser.sensory_compress.weight'] = src_dict['pixel_fuser.sensory_compress.weight'][:, :-1]

        for k in src_dict:
            if k not in self.state_dict():
                log.info(f'Key {k} found in src_dict but not in self.state_dict()!!!')
        for k in self.state_dict():
            if k not in src_dict:
                log.info(f'Key {k} found in self.state_dict() but not in src_dict!!!')

        self.load_state_dict(src_dict, strict=False)

    @property
    def device(self) -> torch.device:
        return self.pixel_mean.device
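A small sketch (toy tensors, not repo code) of what _get_others produces on the multi-object path: each object sees the clamped union of every other object's mask, matching the 2-channel input expected by MaskEncoder:

import torch

masks = torch.tensor([[[[1., 0.]], [[0., 1.]], [[0., 1.]]]])  # B=1, 3 objects, 1x2 masks
others = (masks.sum(dim=1, keepdim=True) - masks).clamp(0, 1)
# object 0 sees objects 1+2; objects 1 and 2 each see the other two, clamped to 1
assert others.tolist() == [[[[0., 1.]], [[1., 1.]], [[1., 1.]]]]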
matanyone/model/modules.py
ADDED
@@ -0,0 +1,170 @@
from typing import List, Iterable
import torch
import torch.nn as nn

from matanyone.model.group_modules import *


class UpsampleBlock(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, scale_factor: int = 2):
        super().__init__()
        self.out_conv = ResBlock(in_dim, out_dim)
        self.scale_factor = scale_factor

    def forward(self, in_g: torch.Tensor, skip_f: torch.Tensor) -> torch.Tensor:
        g = F.interpolate(in_g,
                          scale_factor=self.scale_factor,
                          mode='bilinear')
        g = self.out_conv(g)
        g = g + skip_f
        return g


class MaskUpsampleBlock(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, scale_factor: int = 2):
        super().__init__()
        self.distributor = MainToGroupDistributor(method='add')
        self.out_conv = GroupResBlock(in_dim, out_dim)
        self.scale_factor = scale_factor

    def forward(self, in_g: torch.Tensor, skip_f: torch.Tensor) -> torch.Tensor:
        g = upsample_groups(in_g, ratio=self.scale_factor)
        g = self.distributor(skip_f, g)
        g = self.out_conv(g)
        return g


class DecoderFeatureProcessor(nn.Module):
    def __init__(self, decoder_dims: List[int], out_dims: List[int]):
        super().__init__()
        self.transforms = nn.ModuleList([
            nn.Conv2d(d_dim, p_dim, kernel_size=1) for d_dim, p_dim in zip(decoder_dims, out_dims)
        ])

    def forward(self, multi_scale_features: Iterable[torch.Tensor]) -> List[torch.Tensor]:
        outputs = [func(x) for x, func in zip(multi_scale_features, self.transforms)]
        return outputs


# @torch.jit.script
def _recurrent_update(h: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
    # h: batch_size * num_objects * hidden_dim * h * w
    # values: batch_size * num_objects * (hidden_dim*3) * h * w
    dim = values.shape[2] // 3
    forget_gate = torch.sigmoid(values[:, :, :dim])
    update_gate = torch.sigmoid(values[:, :, dim:dim * 2])
    new_value = torch.tanh(values[:, :, dim * 2:])
    new_h = forget_gate * h * (1 - update_gate) + update_gate * new_value
    return new_h


class SensoryUpdater_fullscale(nn.Module):
    # Used in the decoder, multi-scale feature + GRU
    def __init__(self, g_dims: List[int], mid_dim: int, sensory_dim: int):
        super().__init__()
        self.g16_conv = GConv2d(g_dims[0], mid_dim, kernel_size=1)
        self.g8_conv = GConv2d(g_dims[1], mid_dim, kernel_size=1)
        self.g4_conv = GConv2d(g_dims[2], mid_dim, kernel_size=1)
        self.g2_conv = GConv2d(g_dims[3], mid_dim, kernel_size=1)
        self.g1_conv = GConv2d(g_dims[4], mid_dim, kernel_size=1)

        self.transform = GConv2d(mid_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1)

        nn.init.xavier_normal_(self.transform.weight)

    def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        g = self.g16_conv(g[0]) + self.g8_conv(downsample_groups(g[1], ratio=1/2)) + \
            self.g4_conv(downsample_groups(g[2], ratio=1/4)) + \
            self.g2_conv(downsample_groups(g[3], ratio=1/8)) + \
            self.g1_conv(downsample_groups(g[4], ratio=1/16))

        with torch.cuda.amp.autocast(enabled=False):
            g = g.float()
            h = h.float()
            values = self.transform(torch.cat([g, h], dim=2))
            new_h = _recurrent_update(h, values)

        return new_h


class SensoryUpdater(nn.Module):
    # Used in the decoder, multi-scale feature + GRU
    def __init__(self, g_dims: List[int], mid_dim: int, sensory_dim: int):
        super().__init__()
        self.g16_conv = GConv2d(g_dims[0], mid_dim, kernel_size=1)
        self.g8_conv = GConv2d(g_dims[1], mid_dim, kernel_size=1)
        self.g4_conv = GConv2d(g_dims[2], mid_dim, kernel_size=1)

        self.transform = GConv2d(mid_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1)

        nn.init.xavier_normal_(self.transform.weight)

    def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        g = self.g16_conv(g[0]) + self.g8_conv(downsample_groups(g[1], ratio=1/2)) + \
            self.g4_conv(downsample_groups(g[2], ratio=1/4))

        with torch.cuda.amp.autocast(enabled=False):
            g = g.float()
            h = h.float()
            values = self.transform(torch.cat([g, h], dim=2))
            new_h = _recurrent_update(h, values)

        return new_h


class SensoryDeepUpdater(nn.Module):
    def __init__(self, f_dim: int, sensory_dim: int):
        super().__init__()
        self.transform = GConv2d(f_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1)

        nn.init.xavier_normal_(self.transform.weight)

    def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        with torch.cuda.amp.autocast(enabled=False):
            g = g.float()
            h = h.float()
            values = self.transform(torch.cat([g, h], dim=2))
            new_h = _recurrent_update(h, values)

        return new_h


class ResBlock(nn.Module):
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()

        if in_dim == out_dim:
            self.downsample = nn.Identity()
        else:
            self.downsample = nn.Conv2d(in_dim, out_dim, kernel_size=1)

        self.conv1 = nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1)

    def forward(self, g: torch.Tensor) -> torch.Tensor:
        out_g = self.conv1(F.relu(g))
        out_g = self.conv2(F.relu(out_g))

        g = self.downsample(g)

        return out_g + g


class PPM(nn.Module):  # class header restored; the diff shows the methods below without it, and they reference PPM
    def __init__(self, in_dim, reduction_dim, bins):
        super(PPM, self).__init__()
        self.features = []
        for bin in bins:
            self.features.append(nn.Sequential(
                nn.AdaptiveAvgPool2d(bin),
                nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
                nn.PReLU()
            ))
        self.features = nn.ModuleList(self.features)
        self.fuse = nn.Sequential(
            nn.Conv2d(in_dim + reduction_dim * 4, in_dim, kernel_size=3, padding=1, bias=False),
            nn.PReLU())

    def forward(self, x):
        x_size = x.size()
        out = [x]
        for f in self.features:
            out.append(F.interpolate(f(x), x_size[2:], mode='bilinear', align_corners=True))
        out_feat = self.fuse(torch.cat(out, 1))
        return out_feat
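A worked sketch (toy numbers, not repo code) of the GRU-style gate in _recurrent_update: with hidden_dim = 1, the three channel slices act as the forget, update, and candidate signals:

import torch

from matanyone.model.modules import _recurrent_update

# B=1, num_objects=1, hidden_dim=1, 1x1 spatial
h = torch.ones(1, 1, 1, 1, 1)
# large positive values: forget ~ 1, update ~ 1, candidate tanh(10) ~ 1
values = torch.tensor([10., 10., 10.]).view(1, 1, 3, 1, 1)
new_h = _recurrent_update(h, values)
# forget * h * (1 - update) + update * candidate  ->  ~0 + ~1
assert torch.allclose(new_h, torch.ones_like(new_h), atol=1e-3)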
matanyone/model/transformer/__init__.py
ADDED
File without changes
matanyone/model/transformer/object_summarizer.py
ADDED
@@ -0,0 +1,89 @@
from typing import Optional, Tuple

from omegaconf import DictConfig

import torch
import torch.nn as nn
import torch.nn.functional as F
from matanyone.model.transformer.positional_encoding import PositionalEncoding


# @torch.jit.script
def _weighted_pooling(masks: torch.Tensor, value: torch.Tensor,
                      logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # value: B*num_objects*H*W*value_dim
    # logits: B*num_objects*H*W*num_summaries
    # masks: B*num_objects*H*W*num_summaries: 1 if allowed
    weights = logits.sigmoid() * masks
    # B*num_objects*num_summaries*value_dim
    sums = torch.einsum('bkhwq,bkhwc->bkqc', weights, value)
    # B*num_objects*H*W*num_summaries -> B*num_objects*num_summaries*1
    area = weights.flatten(start_dim=2, end_dim=3).sum(2).unsqueeze(-1)

    # B*num_objects*num_summaries*value_dim
    return sums, area


class ObjectSummarizer(nn.Module):
    def __init__(self, model_cfg: DictConfig):
        super().__init__()

        this_cfg = model_cfg.object_summarizer
        self.value_dim = model_cfg.value_dim
        self.embed_dim = this_cfg.embed_dim
        self.num_summaries = this_cfg.num_summaries
        self.add_pe = this_cfg.add_pe
        self.pixel_pe_scale = model_cfg.pixel_pe_scale
        self.pixel_pe_temperature = model_cfg.pixel_pe_temperature

        if self.add_pe:
            self.pos_enc = PositionalEncoding(self.embed_dim,
                                              scale=self.pixel_pe_scale,
                                              temperature=self.pixel_pe_temperature)

        self.input_proj = nn.Linear(self.value_dim, self.embed_dim)
        self.feature_pred = nn.Sequential(
            nn.Linear(self.embed_dim, self.embed_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.embed_dim, self.embed_dim),
        )
        self.weights_pred = nn.Sequential(
            nn.Linear(self.embed_dim, self.embed_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.embed_dim, self.num_summaries),
        )

    def forward(self,
                masks: torch.Tensor,
                value: torch.Tensor,
                need_weights: bool = False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # masks: B*num_objects*(H0)*(W0)
        # value: B*num_objects*value_dim*H*W
        # -> B*num_objects*H*W*value_dim
        h, w = value.shape[-2:]
        masks = F.interpolate(masks, size=(h, w), mode='area')
        masks = masks.unsqueeze(-1)
        inv_masks = 1 - masks
        # half of the summaries pool inside the mask, the other half outside
        repeated_masks = torch.cat([
            masks.expand(-1, -1, -1, -1, self.num_summaries // 2),
            inv_masks.expand(-1, -1, -1, -1, self.num_summaries // 2),
        ],
                                   dim=-1)

        value = value.permute(0, 1, 3, 4, 2)
        value = self.input_proj(value)
        if self.add_pe:
            pe = self.pos_enc(value)
            value = value + pe

        with torch.cuda.amp.autocast(enabled=False):
            value = value.float()
            feature = self.feature_pred(value)
            logits = self.weights_pred(value)
            sums, area = _weighted_pooling(repeated_masks, feature, logits)

        summaries = torch.cat([sums, area], dim=-1)

        if need_weights:
            return summaries, logits
        else:
            return summaries, None
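
A minimal shape check for the summarizer above, as a hedged sketch; the config values are illustrative assumptions (the shipped values live in matanyone/config/model/base.yaml):

import torch
from omegaconf import OmegaConf
from matanyone.model.transformer.object_summarizer import ObjectSummarizer

# illustrative config values, not the shipped defaults
cfg = OmegaConf.create({
    'value_dim': 256, 'pixel_pe_scale': 32, 'pixel_pe_temperature': 128,
    'object_summarizer': {'embed_dim': 256, 'num_summaries': 16, 'add_pe': False},
})
summarizer = ObjectSummarizer(cfg)
masks = torch.rand(1, 2, 64, 64)        # B * num_objects * H0 * W0
value = torch.rand(1, 2, 256, 32, 32)   # B * num_objects * value_dim * H * W
summaries, _ = summarizer(masks, value)
print(summaries.shape)                  # (1, 2, 16, 257): embed_dim features + 1 area channel
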
matanyone/model/transformer/object_transformer.py
ADDED
@@ -0,0 +1,206 @@
from typing import Dict, Optional, Tuple

from omegaconf import DictConfig

import torch
import torch.nn as nn
from matanyone.model.group_modules import GConv2d
from matanyone.utils.tensor_utils import aggregate
from matanyone.model.transformer.positional_encoding import PositionalEncoding
from matanyone.model.transformer.transformer_layers import *


class QueryTransformerBlock(nn.Module):
    def __init__(self, model_cfg: DictConfig):
        super().__init__()

        this_cfg = model_cfg.object_transformer
        self.embed_dim = this_cfg.embed_dim
        self.num_heads = this_cfg.num_heads
        self.num_queries = this_cfg.num_queries
        self.ff_dim = this_cfg.ff_dim

        self.read_from_pixel = CrossAttention(self.embed_dim,
                                              self.num_heads,
                                              add_pe_to_qkv=this_cfg.read_from_pixel.add_pe_to_qkv)
        self.self_attn = SelfAttention(self.embed_dim,
                                       self.num_heads,
                                       add_pe_to_qkv=this_cfg.query_self_attention.add_pe_to_qkv)
        self.ffn = FFN(self.embed_dim, self.ff_dim)
        self.read_from_query = CrossAttention(self.embed_dim,
                                              self.num_heads,
                                              add_pe_to_qkv=this_cfg.read_from_query.add_pe_to_qkv,
                                              norm=this_cfg.read_from_query.output_norm)
        self.pixel_ffn = PixelFFN(self.embed_dim)

    def forward(
        self,
        x: torch.Tensor,
        pixel: torch.Tensor,
        query_pe: torch.Tensor,
        pixel_pe: torch.Tensor,
        attn_mask: torch.Tensor,
        need_weights: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # x: (bs*num_objects)*num_queries*embed_dim
        # pixel: bs*num_objects*C*H*W
        # query_pe: (bs*num_objects)*num_queries*embed_dim
        # pixel_pe: (bs*num_objects)*(H*W)*C
        # attn_mask: (bs*num_objects*num_heads)*num_queries*(H*W)

        # bs*num_objects*C*H*W -> (bs*num_objects)*(H*W)*C
        pixel_flat = pixel.flatten(3, 4).flatten(0, 1).transpose(1, 2).contiguous()
        x, q_weights = self.read_from_pixel(x,
                                            pixel_flat,
                                            query_pe,
                                            pixel_pe,
                                            attn_mask=attn_mask,
                                            need_weights=need_weights)
        x = self.self_attn(x, query_pe)
        x = self.ffn(x)

        pixel_flat, p_weights = self.read_from_query(pixel_flat,
                                                     x,
                                                     pixel_pe,
                                                     query_pe,
                                                     need_weights=need_weights)
        pixel = self.pixel_ffn(pixel, pixel_flat)

        if need_weights:
            bs, num_objects, _, h, w = pixel.shape
            q_weights = q_weights.view(bs, num_objects, self.num_heads, self.num_queries, h, w)
            p_weights = p_weights.transpose(2, 3).view(bs, num_objects, self.num_heads,
                                                       self.num_queries, h, w)

        return x, pixel, q_weights, p_weights


class QueryTransformer(nn.Module):
    def __init__(self, model_cfg: DictConfig):
        super().__init__()

        this_cfg = model_cfg.object_transformer
        self.value_dim = model_cfg.value_dim
        self.embed_dim = this_cfg.embed_dim
        self.num_heads = this_cfg.num_heads
        self.num_queries = this_cfg.num_queries

        # query initialization and embedding
        self.query_init = nn.Embedding(self.num_queries, self.embed_dim)
        self.query_emb = nn.Embedding(self.num_queries, self.embed_dim)

        # projection from object summaries to query initialization and embedding
        self.summary_to_query_init = nn.Linear(self.embed_dim, self.embed_dim)
        self.summary_to_query_emb = nn.Linear(self.embed_dim, self.embed_dim)

        self.pixel_pe_scale = model_cfg.pixel_pe_scale
        self.pixel_pe_temperature = model_cfg.pixel_pe_temperature
        self.pixel_init_proj = GConv2d(self.embed_dim, self.embed_dim, kernel_size=1)
        self.pixel_emb_proj = GConv2d(self.embed_dim, self.embed_dim, kernel_size=1)
        self.spatial_pe = PositionalEncoding(self.embed_dim,
                                             scale=self.pixel_pe_scale,
                                             temperature=self.pixel_pe_temperature,
                                             channel_last=False,
                                             transpose_output=True)

        # transformer blocks
        self.num_blocks = this_cfg.num_blocks
        self.blocks = nn.ModuleList(
            QueryTransformerBlock(model_cfg) for _ in range(self.num_blocks))
        self.mask_pred = nn.ModuleList(
            nn.Sequential(nn.ReLU(), GConv2d(self.embed_dim, 1, kernel_size=1))
            for _ in range(self.num_blocks + 1))

        self.act = nn.ReLU(inplace=True)

    def forward(self,
                pixel: torch.Tensor,
                obj_summaries: torch.Tensor,
                selector: Optional[torch.Tensor] = None,
                need_weights: bool = False,
                seg_pass: bool = False) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        # pixel: B*num_objects*embed_dim*H*W
        # obj_summaries: B*num_objects*T*num_queries*embed_dim
        T = obj_summaries.shape[2]
        bs, num_objects, _, H, W = pixel.shape

        # normalize object values
        # the last channel is the cumulative area of the object
        obj_summaries = obj_summaries.view(bs * num_objects, T, self.num_queries,
                                           self.embed_dim + 1)
        # sum over time
        # during inference, T=1 as we already did streaming average in memory_manager
        obj_sums = obj_summaries[:, :, :, :-1].sum(dim=1)
        obj_area = obj_summaries[:, :, :, -1:].sum(dim=1)
        obj_values = obj_sums / (obj_area + 1e-4)
        obj_init = self.summary_to_query_init(obj_values)
        obj_emb = self.summary_to_query_emb(obj_values)

        # positional embeddings for object queries
        query = self.query_init.weight.unsqueeze(0).expand(bs * num_objects, -1, -1) + obj_init
        query_emb = self.query_emb.weight.unsqueeze(0).expand(bs * num_objects, -1, -1) + obj_emb

        # positional embeddings for pixel features
        pixel_init = self.pixel_init_proj(pixel)
        pixel_emb = self.pixel_emb_proj(pixel)
        pixel_pe = self.spatial_pe(pixel.flatten(0, 1))
        pixel_emb = pixel_emb.flatten(3, 4).flatten(0, 1).transpose(1, 2).contiguous()
        pixel_pe = pixel_pe.flatten(1, 2) + pixel_emb

        pixel = pixel_init

        # run the transformer
        aux_features = {'logits': []}

        # first aux output
        aux_logits = self.mask_pred[0](pixel).squeeze(2)
        attn_mask = self._get_aux_mask(aux_logits, selector, seg_pass=seg_pass)
        aux_features['logits'].append(aux_logits)
        for i in range(self.num_blocks):
            query, pixel, q_weights, p_weights = self.blocks[i](query,
                                                                pixel,
                                                                query_emb,
                                                                pixel_pe,
                                                                attn_mask,
                                                                need_weights=need_weights)

            # note: i <= self.num_blocks - 1 always holds inside this loop,
            # so the aux prediction is refreshed after every block
            if self.training or i <= self.num_blocks - 1 or need_weights:
                aux_logits = self.mask_pred[i + 1](pixel).squeeze(2)
                attn_mask = self._get_aux_mask(aux_logits, selector, seg_pass=seg_pass)
                aux_features['logits'].append(aux_logits)

        aux_features['q_weights'] = q_weights  # last layer only
        aux_features['p_weights'] = p_weights  # last layer only

        if self.training:
            # no need to save all heads
            aux_features['attn_mask'] = attn_mask.view(bs, num_objects, self.num_heads,
                                                       self.num_queries, H, W)[:, :, 0]

        return pixel, aux_features

    def _get_aux_mask(self, logits: torch.Tensor, selector: Optional[torch.Tensor],
                      seg_pass: bool = False) -> torch.Tensor:
        # logits: batch_size*num_objects*H*W
        # selector: batch_size*num_objects*1*1
        # returns a mask of shape (batch_size*num_objects*num_heads)*num_queries*(H*W)
        # where True means the attention is blocked

        if selector is None:
            prob = logits.sigmoid()
        else:
            prob = logits.sigmoid() * selector
        logits = aggregate(prob, dim=1)

        is_foreground = (logits[:, 1:] >= logits.max(dim=1, keepdim=True)[0])
        foreground_mask = is_foreground.bool().flatten(start_dim=2)
        inv_foreground_mask = ~foreground_mask
        inv_background_mask = foreground_mask

        # the first half of the queries attend to the foreground,
        # the second half to the background
        aux_foreground_mask = inv_foreground_mask.unsqueeze(2).unsqueeze(2).repeat(
            1, 1, self.num_heads, self.num_queries // 2, 1).flatten(start_dim=0, end_dim=2)
        aux_background_mask = inv_background_mask.unsqueeze(2).unsqueeze(2).repeat(
            1, 1, self.num_heads, self.num_queries // 2, 1).flatten(start_dim=0, end_dim=2)

        aux_mask = torch.cat([aux_foreground_mask, aux_background_mask], dim=1)

        # unblock rows that would otherwise be fully masked out
        aux_mask[torch.where(aux_mask.sum(-1) == aux_mask.shape[-1])] = False

        return aux_mask
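
A hedged instantiation sketch for the transformer above; every config value below is an assumption chosen only to make the shapes consistent (the shipped config is matanyone/config/model/base.yaml):

import torch
from omegaconf import OmegaConf
from matanyone.model.transformer.object_transformer import QueryTransformer

cfg = OmegaConf.create({
    'value_dim': 256, 'pixel_pe_scale': 32, 'pixel_pe_temperature': 128,
    'object_transformer': {
        'embed_dim': 256, 'ff_dim': 2048, 'num_heads': 8,
        'num_queries': 16, 'num_blocks': 3,
        'read_from_pixel': {'add_pe_to_qkv': [True, True, False]},
        'query_self_attention': {'add_pe_to_qkv': [True, True, False]},
        'read_from_query': {'add_pe_to_qkv': [True, True, False], 'output_norm': False},
    },
})
qt = QueryTransformer(cfg).eval()
pixel = torch.rand(1, 2, 256, 30, 54)         # B * num_objects * embed_dim * H * W
obj_summaries = torch.rand(1, 2, 1, 16, 257)  # B * num_objects * T * num_queries * (embed_dim+1)
out, aux = qt(pixel, obj_summaries)
print(out.shape, len(aux['logits']))          # (1, 2, 256, 30, 54), num_blocks + 1 aux maps
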
matanyone/model/transformer/positional_encoding.py
ADDED
@@ -0,0 +1,108 @@
# Reference:
# https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/transformer_decoder/position_encoding.py
# https://github.com/tatp22/multidim-positional-encoding/blob/master/positional_encodings/torch_encodings.py

import math

import numpy as np
import torch
from torch import nn


def get_emb(sin_inp: torch.Tensor) -> torch.Tensor:
    """
    Gets a base embedding for one dimension with sin and cos intertwined
    """
    emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=-1)
    return torch.flatten(emb, -2, -1)


class PositionalEncoding(nn.Module):
    def __init__(self,
                 dim: int,
                 scale: float = math.pi * 2,
                 temperature: float = 10000,
                 normalize: bool = True,
                 channel_last: bool = True,
                 transpose_output: bool = False):
        super().__init__()
        dim = int(np.ceil(dim / 4) * 2)
        self.dim = dim
        inv_freq = 1.0 / (temperature**(torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.normalize = normalize
        self.scale = scale
        self.eps = 1e-6
        self.channel_last = channel_last
        self.transpose_output = transpose_output

        self.cached_penc = None  # the cache is irrespective of the number of objects

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        :param tensor: A 4/5d tensor of size
            channel_last=True: (batch_size, h, w, c) or (batch_size, k, h, w, c)
            channel_last=False: (batch_size, c, h, w) or (batch_size, k, c, h, w)
        :return: positional encoding tensor that has the same shape as the input if the input is 4d
                 if the input is 5d, the output is broadcastable along the k-dimension
        """
        if len(tensor.shape) != 4 and len(tensor.shape) != 5:
            raise RuntimeError(f'The input tensor has to be 4/5d, got {tensor.shape}!')

        if len(tensor.shape) == 5:
            # take a sample from the k dimension
            num_objects = tensor.shape[1]
            tensor = tensor[:, 0]
        else:
            num_objects = None

        if self.channel_last:
            batch_size, h, w, c = tensor.shape
        else:
            batch_size, c, h, w = tensor.shape

        if self.cached_penc is not None and self.cached_penc.shape == tensor.shape:
            if num_objects is None:
                return self.cached_penc
            else:
                return self.cached_penc.unsqueeze(1)

        self.cached_penc = None

        pos_y = torch.arange(h, device=tensor.device, dtype=self.inv_freq.dtype)
        pos_x = torch.arange(w, device=tensor.device, dtype=self.inv_freq.dtype)
        if self.normalize:
            pos_y = pos_y / (pos_y[-1] + self.eps) * self.scale
            pos_x = pos_x / (pos_x[-1] + self.eps) * self.scale

        sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq)
        sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
        emb_y = get_emb(sin_inp_y).unsqueeze(1)
        emb_x = get_emb(sin_inp_x)

        emb = torch.zeros((h, w, self.dim * 2), device=tensor.device, dtype=tensor.dtype)
        emb[:, :, :self.dim] = emb_x
        emb[:, :, self.dim:] = emb_y

        if not self.channel_last and self.transpose_output:
            # cancelled out
            pass
        elif (not self.channel_last) or (self.transpose_output):
            emb = emb.permute(2, 0, 1)

        self.cached_penc = emb.unsqueeze(0).repeat(batch_size, 1, 1, 1)
        if num_objects is None:
            return self.cached_penc
        else:
            return self.cached_penc.unsqueeze(1)


if __name__ == '__main__':
    pe = PositionalEncoding(8).cuda()
    input = torch.ones((1, 8, 8, 8)).cuda()
    output = pe(input)
    # print(output)
    print(output[0, :, 0, 0])
    print(output[0, :, 0, 5])
    print(output[0, 0, :, 0])
    print(output[0, 0, 0, :])
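
The __main__ check above covers the 4D path; this hedged sketch shows the 5D path, where the returned encoding is broadcastable along the object (k) dimension:

import torch
from matanyone.model.transformer.positional_encoding import PositionalEncoding

pe = PositionalEncoding(16)
x = torch.rand(2, 3, 8, 8, 16)  # batch * k objects * h * w * c (channel_last=True)
enc = pe(x)
print(enc.shape)                # (2, 1, 8, 8, 16): broadcasts over the k dimension
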
matanyone/model/transformer/transformer_layers.py
ADDED
@@ -0,0 +1,161 @@
# Modified from PyTorch nn.Transformer

from typing import Callable, List, Optional, Tuple

import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from matanyone.model.channel_attn import CAResBlock


class SelfAttention(nn.Module):
    def __init__(self,
                 dim: int,
                 nhead: int,
                 dropout: float = 0.0,
                 batch_first: bool = True,
                 add_pe_to_qkv: List[bool] = [True, True, False]):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(dim, nhead, dropout=dropout, batch_first=batch_first)
        self.norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)
        self.add_pe_to_qkv = add_pe_to_qkv

    def forward(self,
                x: torch.Tensor,
                pe: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None,
                key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        x = self.norm(x)
        if any(self.add_pe_to_qkv):
            x_with_pe = x + pe
            q = x_with_pe if self.add_pe_to_qkv[0] else x
            k = x_with_pe if self.add_pe_to_qkv[1] else x
            v = x_with_pe if self.add_pe_to_qkv[2] else x
        else:
            q = k = v = x

        r = x
        x = self.self_attn(q, k, v, attn_mask=attn_mask, key_padding_mask=key_padding_mask)[0]
        return r + self.dropout(x)


# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention
class CrossAttention(nn.Module):
    def __init__(self,
                 dim: int,
                 nhead: int,
                 dropout: float = 0.0,
                 batch_first: bool = True,
                 add_pe_to_qkv: List[bool] = [True, True, False],
                 residual: bool = True,
                 norm: bool = True):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(dim,
                                                nhead,
                                                dropout=dropout,
                                                batch_first=batch_first)
        if norm:
            self.norm = nn.LayerNorm(dim)
        else:
            self.norm = nn.Identity()
        self.dropout = nn.Dropout(dropout)
        self.add_pe_to_qkv = add_pe_to_qkv
        self.residual = residual

    def forward(self,
                x: torch.Tensor,
                mem: torch.Tensor,
                x_pe: torch.Tensor,
                mem_pe: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None,
                *,
                need_weights: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        x = self.norm(x)
        if self.add_pe_to_qkv[0]:
            q = x + x_pe
        else:
            q = x

        if any(self.add_pe_to_qkv[1:]):
            mem_with_pe = mem + mem_pe
            k = mem_with_pe if self.add_pe_to_qkv[1] else mem
            v = mem_with_pe if self.add_pe_to_qkv[2] else mem
        else:
            k = v = mem
        r = x
        x, weights = self.cross_attn(q,
                                     k,
                                     v,
                                     attn_mask=attn_mask,
                                     need_weights=need_weights,
                                     average_attn_weights=False)

        if self.residual:
            return r + self.dropout(x), weights
        else:
            return self.dropout(x), weights


class FFN(nn.Module):
    def __init__(self, dim_in: int, dim_ff: int, activation=F.relu):
        super().__init__()
        self.linear1 = nn.Linear(dim_in, dim_ff)
        self.linear2 = nn.Linear(dim_ff, dim_in)
        self.norm = nn.LayerNorm(dim_in)

        if isinstance(activation, str):
            self.activation = _get_activation_fn(activation)
        else:
            self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r = x
        x = self.norm(x)
        x = self.linear2(self.activation(self.linear1(x)))
        x = r + x
        return x


class PixelFFN(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim
        self.conv = CAResBlock(dim, dim)

    def forward(self, pixel: torch.Tensor, pixel_flat: torch.Tensor) -> torch.Tensor:
        # pixel: batch_size * num_objects * dim * H * W
        # pixel_flat: (batch_size*num_objects) * (H*W) * dim
        bs, num_objects, _, h, w = pixel.shape
        pixel_flat = pixel_flat.view(bs * num_objects, h, w, self.dim)
        pixel_flat = pixel_flat.permute(0, 3, 1, 2).contiguous()

        x = self.conv(pixel_flat)
        x = x.view(bs, num_objects, self.dim, h, w)
        return x


class OutputFFN(nn.Module):
    def __init__(self, dim_in: int, dim_out: int, activation=F.relu):
        super().__init__()
        self.linear1 = nn.Linear(dim_in, dim_out)
        self.linear2 = nn.Linear(dim_out, dim_out)

        if isinstance(activation, str):
            self.activation = _get_activation_fn(activation)
        else:
            self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear2(self.activation(self.linear1(x)))
        return x


def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu

    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
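
A hedged shape sketch for the attention layers above (all dimensions are illustrative):

import torch
from matanyone.model.transformer.transformer_layers import CrossAttention

attn = CrossAttention(dim=256, nhead=8)
x = torch.rand(2, 16, 256)       # queries: batch * num_queries * dim
mem = torch.rand(2, 1620, 256)   # memory: batch * (H*W) * dim
x_pe = torch.rand(2, 16, 256)    # positional encodings, same shapes as x / mem
mem_pe = torch.rand(2, 1620, 256)
out, weights = attn(x, mem, x_pe, mem_pe, need_weights=True)
print(out.shape, weights.shape)  # (2, 16, 256), (2, 8, 16, 1620): per-head attention maps
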
matanyone/model/utils/__init__.py
ADDED
File without changes
matanyone/model/utils/memory_utils.py
ADDED
@@ -0,0 +1,107 @@
import math
from typing import Optional, Tuple, Union

import torch


# @torch.jit.script
def get_similarity(mk: torch.Tensor,
                   ms: torch.Tensor,
                   qk: torch.Tensor,
                   qe: torch.Tensor,
                   add_batch_dim: bool = False,
                   uncert_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    # used for training/inference and memory reading/memory potentiation
    # mk: B x CK x [N]    - Memory keys
    # ms: B x 1  x [N]    - Memory shrinkage
    # qk: B x CK x [HW/P] - Query keys
    # qe: B x CK x [HW/P] - Query selection
    # Dimensions in [] are flattened
    # Return: B*N*HW
    if add_batch_dim:
        mk, ms = mk.unsqueeze(0), ms.unsqueeze(0)
        qk, qe = qk.unsqueeze(0), qe.unsqueeze(0)

    CK = mk.shape[1]

    mk = mk.flatten(start_dim=2)
    ms = ms.flatten(start_dim=1).unsqueeze(2) if ms is not None else None
    qk = qk.flatten(start_dim=2)
    qe = qe.flatten(start_dim=2) if qe is not None else None

    # query token selection based on temporal sparsity
    if uncert_mask is not None:
        uncert_mask = uncert_mask.flatten(start_dim=2)
        # the key dimension CK is assumed to be 64 here
        uncert_mask = uncert_mask.expand(-1, 64, -1)
        qk = qk * uncert_mask
        qe = qe * uncert_mask

    if qe is not None:
        # See XMem's appendix for derivation
        mk = mk.transpose(1, 2)
        a_sq = (mk.pow(2) @ qe)
        two_ab = 2 * (mk @ (qk * qe))
        b_sq = (qe * qk.pow(2)).sum(1, keepdim=True)
        similarity = (-a_sq + two_ab - b_sq)
    else:
        # similar to STCN if we don't have the selection term
        a_sq = mk.pow(2).sum(1).unsqueeze(2)
        two_ab = 2 * (mk.transpose(1, 2) @ qk)
        similarity = (-a_sq + two_ab)

    if ms is not None:
        similarity = similarity * ms / math.sqrt(CK)  # B*N*HW
    else:
        similarity = similarity / math.sqrt(CK)  # B*N*HW

    return similarity


def do_softmax(
        similarity: torch.Tensor,
        top_k: Optional[int] = None,
        inplace: bool = False,
        return_usage: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    # normalize similarity with top-k softmax
    # similarity: B x N x [HW/P]
    # use inplace with care
    if top_k is not None:
        values, indices = torch.topk(similarity, k=top_k, dim=1)

        x_exp = values.exp_()
        x_exp /= torch.sum(x_exp, dim=1, keepdim=True)
        if inplace:
            similarity.zero_().scatter_(1, indices, x_exp)  # B*N*HW
            affinity = similarity
        else:
            affinity = torch.zeros_like(similarity).scatter_(1, indices, x_exp)  # B*N*HW
    else:
        maxes = torch.max(similarity, dim=1, keepdim=True)[0]
        x_exp = torch.exp(similarity - maxes)
        x_exp_sum = torch.sum(x_exp, dim=1, keepdim=True)
        affinity = x_exp / x_exp_sum
        indices = None

    if return_usage:
        return affinity, affinity.sum(dim=2)

    return affinity


def get_affinity(mk: torch.Tensor, ms: torch.Tensor, qk: torch.Tensor,
                 qe: torch.Tensor, uncert_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    # shorthand used in training with no top-k
    similarity = get_similarity(mk, ms, qk, qe, uncert_mask=uncert_mask)
    affinity = do_softmax(similarity)
    return affinity


def readout(affinity: torch.Tensor, mv: torch.Tensor,
            uncert_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    B, CV, T, H, W = mv.shape

    mo = mv.view(B, CV, T * H * W)
    mem = torch.bmm(mo, affinity)
    if uncert_mask is not None:
        uncert_mask = uncert_mask.flatten(start_dim=2).expand(-1, CV, -1)
        mem = mem * uncert_mask
    mem = mem.view(B, CV, H, W)

    return mem
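
A hedged end-to-end sketch of the memory read path defined above (tensor sizes are illustrative):

import torch
from matanyone.model.utils.memory_utils import get_similarity, do_softmax, readout

B, CK, CV, T, H, W = 1, 64, 256, 4, 30, 54
mk = torch.rand(B, CK, T, H, W)   # memory keys
ms = torch.rand(B, 1, T, H, W)    # memory shrinkage
qk = torch.rand(B, CK, H, W)      # query keys
qe = torch.rand(B, CK, H, W)      # query selection
mv = torch.rand(B, CV, T, H, W)   # memory values

similarity = get_similarity(mk, ms, qk, qe)  # B x (T*H*W) x (H*W)
affinity = do_softmax(similarity, top_k=30)  # top-k softmax over the memory axis
mem = readout(affinity, mv)                  # B x CV x H x W
print(mem.shape)
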
matanyone/model/utils/parameter_groups.py
ADDED
@@ -0,0 +1,72 @@
import logging

log = logging.getLogger()


def get_parameter_groups(model, stage_cfg, print_log=False):
    """
    Assign different weight decays and learning rates to different parameters.
    Returns a parameter group which can be passed to the optimizer.
    """
    weight_decay = stage_cfg.weight_decay
    embed_weight_decay = stage_cfg.embed_weight_decay
    backbone_lr_ratio = stage_cfg.backbone_lr_ratio
    base_lr = stage_cfg.learning_rate

    backbone_params = []
    embed_params = []
    other_params = []

    embedding_names = ['summary_pos', 'query_init', 'query_emb', 'obj_pe']
    embedding_names = [e + '.weight' for e in embedding_names]

    # inspired by detectron2
    memo = set()
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # Avoid duplicating parameters
        if param in memo:
            continue
        memo.add(param)

        if name.startswith('module'):
            name = name[7:]

        inserted = False
        if name.startswith('pixel_encoder.'):
            backbone_params.append(param)
            inserted = True
            if print_log:
                log.info(f'{name} counted as a backbone parameter.')
        else:
            for e in embedding_names:
                if name.endswith(e):
                    embed_params.append(param)
                    inserted = True
                    if print_log:
                        log.info(f'{name} counted as an embedding parameter.')
                    break

        if not inserted:
            other_params.append(param)

    parameter_groups = [
        {
            'params': backbone_params,
            'lr': base_lr * backbone_lr_ratio,
            'weight_decay': weight_decay
        },
        {
            'params': embed_params,
            'lr': base_lr,
            'weight_decay': embed_weight_decay
        },
        {
            'params': other_params,
            'lr': base_lr,
            'weight_decay': weight_decay
        },
    ]

    return parameter_groups
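
A hedged sketch of feeding the groups to an optimizer; the stage_cfg fields mirror what the function reads, with assumed values:

import torch
from omegaconf import OmegaConf
from matanyone.model.utils.parameter_groups import get_parameter_groups

stage_cfg = OmegaConf.create({
    'weight_decay': 1e-4,
    'embed_weight_decay': 0.0,
    'backbone_lr_ratio': 0.1,
    'learning_rate': 1e-4,
})
model = torch.nn.Linear(8, 8)  # stand-in; normally the full MatAnyone network
optimizer = torch.optim.AdamW(get_parameter_groups(model, stage_cfg))
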
matanyone/model/utils/resnet.py
ADDED
@@ -0,0 +1,179 @@
"""
resnet.py - A modified ResNet structure
We append extra channels to the first conv by some network surgery
"""

from collections import OrderedDict
import math

import torch
import torch.nn as nn
from torch.utils import model_zoo


def load_weights_add_extra_dim(target, source_state, extra_dim=1):
    new_dict = OrderedDict()

    for k1, v1 in target.state_dict().items():
        if 'num_batches_tracked' not in k1:
            if k1 in source_state:
                tar_v = source_state[k1]

                if v1.shape != tar_v.shape:
                    # Init the new segmentation channel with zeros
                    # print(v1.shape, tar_v.shape)
                    c, _, w, h = v1.shape
                    pads = torch.zeros((c, extra_dim, w, h), device=tar_v.device)
                    nn.init.orthogonal_(pads)
                    tar_v = torch.cat([tar_v, pads], 1)

                new_dict[k1] = tar_v

    target.load_state_dict(new_dict)


model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
}


def conv3x3(in_planes, out_planes, stride=1, dilation=1):
    return nn.Conv2d(in_planes,
                     out_planes,
                     kernel_size=3,
                     stride=stride,
                     padding=dilation,
                     dilation=dilation,
                     bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride=stride, dilation=dilation)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes, stride=1, dilation=dilation)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes,
                               planes,
                               kernel_size=3,
                               stride=stride,
                               dilation=dilation,
                               padding=dilation,
                               bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    def __init__(self, block, layers=(3, 4, 23, 3), extra_dim=0):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3 + extra_dim, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes,
                          planes * block.expansion,
                          kernel_size=1,
                          stride=stride,
                          bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation))

        return nn.Sequential(*layers)


def resnet18(pretrained=True, extra_dim=0):
    model = ResNet(BasicBlock, [2, 2, 2, 2], extra_dim)
    if pretrained:
        load_weights_add_extra_dim(model, model_zoo.load_url(model_urls['resnet18']), extra_dim)
    return model


def resnet50(pretrained=True, extra_dim=0):
    model = ResNet(Bottleneck, [3, 4, 6, 3], extra_dim)
    if pretrained:
        load_weights_add_extra_dim(model, model_zoo.load_url(model_urls['resnet50']), extra_dim)
    return model
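
Note that this ResNet variant defines no forward() of its own; callers run the stem and stages directly. A hedged sketch with one extra mask channel appended to the first conv:

import torch
from matanyone.model.utils.resnet import resnet50

net = resnet50(pretrained=False, extra_dim=1)  # pretrained=True would download ImageNet weights
x = torch.rand(1, 4, 224, 224)                 # RGB + 1 extra channel
f = net.maxpool(net.relu(net.bn1(net.conv1(x))))
f = net.layer1(f)
print(f.shape)                                 # (1, 256, 56, 56)
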
matanyone/utils/__init__.py
ADDED
File without changes
matanyone/utils/get_default_model.py
ADDED
@@ -0,0 +1,23 @@
"""
A helper function to get a default model for quick testing
"""
from omegaconf import open_dict
from hydra import compose, initialize

import torch
from matanyone.model.matanyone import MatAnyone
from matanyone.inference.utils.args_utils import get_dataset_cfg


def get_matanyone_model(ckpt_path, device) -> MatAnyone:
    initialize(version_base='1.3.2', config_path="../config", job_name="eval_our_config")
    cfg = compose(config_name="eval_matanyone_config")

    with open_dict(cfg):
        cfg['weights'] = ckpt_path

    # Load the network weights
    matanyone = MatAnyone(cfg, single_object=True).to(device).eval()
    model_weights = torch.load(cfg.weights, map_location=device)
    matanyone.load_weights(model_weights)

    return matanyone
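
A hedged usage sketch; the checkpoint path below is a placeholder, not a file shipped with this commit, and hydra's initialize() may only be called once per process:

import torch
from matanyone.utils.get_default_model import get_matanyone_model

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = get_matanyone_model('pretrained_models/matanyone.pth', device)  # path is an assumption
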
matanyone/utils/tensor_utils.py
ADDED
@@ -0,0 +1,62 @@
from typing import Iterable, Tuple

import torch
import torch.nn.functional as F


# STM
def pad_divide_by(in_img: torch.Tensor, d: int) -> Tuple[torch.Tensor, Iterable[int]]:
    h, w = in_img.shape[-2:]

    if h % d > 0:
        new_h = h + d - h % d
    else:
        new_h = h
    if w % d > 0:
        new_w = w + d - w % d
    else:
        new_w = w
    lh, uh = int((new_h - h) / 2), int(new_h - h) - int((new_h - h) / 2)
    lw, uw = int((new_w - w) / 2), int(new_w - w) - int((new_w - w) / 2)
    pad_array = (int(lw), int(uw), int(lh), int(uh))
    out = F.pad(in_img, pad_array)
    return out, pad_array


def unpad(img: torch.Tensor, pad: Iterable[int]) -> torch.Tensor:
    if len(img.shape) == 4:
        if pad[2] + pad[3] > 0:
            img = img[:, :, pad[2]:-pad[3], :]
        if pad[0] + pad[1] > 0:
            img = img[:, :, :, pad[0]:-pad[1]]
    elif len(img.shape) == 3:
        if pad[2] + pad[3] > 0:
            img = img[:, pad[2]:-pad[3], :]
        if pad[0] + pad[1] > 0:
            img = img[:, :, pad[0]:-pad[1]]
    elif len(img.shape) == 5:
        if pad[2] + pad[3] > 0:
            img = img[:, :, :, pad[2]:-pad[3], :]
        if pad[0] + pad[1] > 0:
            img = img[:, :, :, :, pad[0]:-pad[1]]
    else:
        raise NotImplementedError
    return img


# @torch.jit.script
def aggregate(prob: torch.Tensor, dim: int) -> torch.Tensor:
    with torch.cuda.amp.autocast(enabled=False):
        prob = prob.float()
        new_prob = torch.cat([torch.prod(1 - prob, dim=dim, keepdim=True), prob],
                             dim).clamp(1e-7, 1 - 1e-7)
        logits = torch.log((new_prob / (1 - new_prob)))  # (0, 1) --> (-inf, inf)

    return logits


# @torch.jit.script
def cls_to_one_hot(cls_gt: torch.Tensor, num_objects: int) -> torch.Tensor:
    # cls_gt: B*1*H*W
    B, _, H, W = cls_gt.shape
    one_hot = torch.zeros(B, num_objects + 1, H, W, device=cls_gt.device).scatter_(1, cls_gt, 1)
    return one_hot
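
A hedged round-trip sketch for the padding helpers above:

import torch
from matanyone.utils.tensor_utils import pad_divide_by, unpad

img = torch.rand(1, 3, 241, 427)
padded, pads = pad_divide_by(img, 16)  # pad H and W up to multiples of 16
print(padded.shape)                    # (1, 3, 256, 432)
restored = unpad(padded, pads)
assert restored.shape == img.shape
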
requirements.txt
ADDED
@@ -0,0 +1,35 @@
1 |
+
progressbar2
|
2 |
+
gdown >= 4.7.1
|
3 |
+
gitpython >= 3.1
|
4 |
+
git+https://github.com/cheind/py-thin-plate-spline
|
5 |
+
hickle >= 5.0
|
6 |
+
tensorboard >= 2.11
|
7 |
+
numpy >= 1.21
|
8 |
+
git+https://github.com/facebookresearch/segment-anything.git
|
9 |
+
gradio==4.31.0
|
10 |
+
fastapi==0.111.0
|
11 |
+
pydantic==2.7.1
|
12 |
+
opencv-python >= 4.8
|
13 |
+
matplotlib
|
14 |
+
pyyaml
|
15 |
+
av >= 0.5.2
|
16 |
+
openmim
|
17 |
+
tqdm >= 4.66.1
|
18 |
+
psutil
|
19 |
+
ffmpeg-python
|
20 |
+
cython
|
21 |
+
Pillow >= 9.5
|
22 |
+
scipy >= 1.7
|
23 |
+
pycocotools >= 2.0.7
|
24 |
+
einops >= 0.6
|
25 |
+
hydra-core >= 1.3.2
|
26 |
+
PySide6 >= 6.2.0
|
27 |
+
charset-normalizer >= 3.1.0
|
28 |
+
netifaces >= 0.11.0
|
29 |
+
cchardet >= 2.1.7
|
30 |
+
easydict
|
31 |
+
requests
|
32 |
+
pyqtdarktheme
|
33 |
+
imageio == 2.25.0
|
34 |
+
imageio[ffmpeg]
|
35 |
+
ffmpeg-python
|