Florence-2 + SAM2 + FLUX.1
- .gitattributes +1 -0
- app.py +71 -87
- configs/__init__.py +5 -0
- configs/sam2_hiera_b+.yaml +113 -0
- configs/sam2_hiera_l.yaml +117 -0
- configs/sam2_hiera_s.yaml +116 -0
- configs/sam2_hiera_t.yaml +118 -0
- requirements.txt +8 -1
- utils/florence.py +54 -0
- utils/sam.py +45 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+checkpoints/ filter=lfs diff=lfs merge=lfs -text
app.py
CHANGED
@@ -1,14 +1,18 @@
 from typing import Tuple
 
-import requests
+import supervision as sv
 import random
 import numpy as np
 import gradio as gr
 import spaces
 import torch
-from PIL import Image
+from PIL import Image, ImageFilter
 from diffusers import FluxInpaintPipeline
 
+from utils.florence import load_florence_model, run_florence_inference, \
+    FLORENCE_OPEN_VOCABULARY_DETECTION_TASK
+from utils.sam import load_sam_image_model, run_sam_inference
+
 MARKDOWN = """
 # FLUX.1 Inpainting 🔥
 
@@ -19,52 +23,16 @@ for taking it to the next level by enabling inpainting with the FLUX.
 
 MAX_SEED = np.iinfo(np.int32).max
 IMAGE_SIZE = 1024
-DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-
-
-
-
-
-
-
-
-
-            new_data.append((0, 0, 0, 0))
-        else:
-            new_data.append(item)
-
-    image.putdata(new_data)
-    return image
-
-
-EXAMPLES = [
-    [
-        {
-            "background": Image.open(requests.get("https://media.roboflow.com/spaces/doge-2-image.png", stream=True).raw),
-            "layers": [remove_background(Image.open(requests.get("https://media.roboflow.com/spaces/doge-2-mask-2.png", stream=True).raw))],
-            "composite": Image.open(requests.get("https://media.roboflow.com/spaces/doge-2-composite-2.png", stream=True).raw),
-        },
-        "little lion",
-        42,
-        False,
-        0.85,
-        30
-    ],
-    [
-        {
-            "background": Image.open(requests.get("https://media.roboflow.com/spaces/doge-2-image.png", stream=True).raw),
-            "layers": [remove_background(Image.open(requests.get("https://media.roboflow.com/spaces/doge-2-mask-3.png", stream=True).raw))],
-            "composite": Image.open(requests.get("https://media.roboflow.com/spaces/doge-2-composite-3.png", stream=True).raw),
-        },
-        "tattoos",
-        42,
-        False,
-        0.85,
-        30
-    ]
-]
-
-pipe = FluxInpaintPipeline.from_pretrained(
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
+if torch.cuda.get_device_properties(0).major >= 8:
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
+
+FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE)
+SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)
+FLUX_INPAINTING_PIPELINE = FluxInpaintPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to(DEVICE)
 
 
@@ -74,11 +42,6 @@ def resize_image_dimensions(
 ) -> Tuple[int, int]:
     width, height = original_resolution_wh
 
-    # if width <= maximum_dimension and height <= maximum_dimension:
-    #     width = width - (width % 32)
-    #     height = height - (height % 32)
-    #     return width, height
-
     if width > height:
         scaling_factor = maximum_dimension / width
     else:
@@ -93,17 +56,20 @@ def resize_image_dimensions(
     return new_width, new_height
 
 
-@spaces.GPU(duration=
+@spaces.GPU(duration=150)
+@torch.inference_mode()
+@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
 def process(
     input_image_editor: dict,
-
+    inpainting_prompt_text: str,
+    segmentation_prompt_text: str,
     seed_slicer: int,
     randomize_seed_checkbox: bool,
     strength_slider: float,
     num_inference_steps_slider: int,
     progress=gr.Progress(track_tqdm=True)
 ):
-    if not
+    if not inpainting_prompt_text:
         gr.Info("Please enter a text prompt.")
         return None, None
 
@@ -114,21 +80,50 @@ def process(
         gr.Info("Please upload an image.")
         return None, None
 
-    if not mask:
-        gr.Info("Please draw a mask
+    if not mask and not segmentation_prompt_text:
+        gr.Info("Please draw a mask or enter a segmentation prompt.")
+        return None, None
+
+    if mask and segmentation_prompt_text:
+        gr.Info("Both mask and segmentation prompt are provided. Please provide only "
+                "one.")
         return None, None
 
     width, height = resize_image_dimensions(original_resolution_wh=image.size)
-
-
+    image = image.resize((width, height), Image.LANCZOS)
+
+    if segmentation_prompt_text:
+        _, result = run_florence_inference(
+            model=FLORENCE_MODEL,
+            processor=FLORENCE_PROCESSOR,
+            device=DEVICE,
+            image=image,
+            task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
+            text=segmentation_prompt_text
+        )
+        detections = sv.Detections.from_lmm(
+            lmm=sv.LMM.FLORENCE_2,
+            result=result,
+            resolution_wh=image.size
+        )
+        detections = run_sam_inference(SAM_IMAGE_MODEL, image, detections)
+
+        if len(detections) == 0:
+            gr.Info(f"{segmentation_prompt_text} prompt did not return any detections.")
+            return None, None
+
+        mask = Image.fromarray((detections.mask[0].astype(np.uint8)) * 255)
+
+    mask = mask.resize((width, height), Image.LANCZOS)
+    mask = mask.filter(ImageFilter.GaussianBlur(radius=10))
 
     if randomize_seed_checkbox:
         seed_slicer = random.randint(0, MAX_SEED)
     generator = torch.Generator().manual_seed(seed_slicer)
-    result =
-        prompt=
-        image=
-        mask_image=
+    result = FLUX_INPAINTING_PIPELINE(
+        prompt=inpainting_prompt_text,
+        image=image,
+        mask_image=mask,
        width=width,
        height=height,
        strength=strength_slider,
@@ -136,7 +131,7 @@ def process(
         num_inference_steps=num_inference_steps_slider
     ).images[0]
     print('INFERENCE DONE')
-    return result,
+    return result, mask
 
 
 with gr.Blocks() as demo:
@@ -152,17 +147,24 @@ with gr.Blocks() as demo:
                 brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"))
 
             with gr.Row():
-
+                inpainting_prompt_text_component = gr.Text(
                     label="Prompt",
                     show_label=False,
                     max_lines=1,
-                    placeholder="Enter
+                    placeholder="Enter inpainting prompt",
                     container=False,
                 )
                 submit_button_component = gr.Button(
                     value='Submit', variant='primary', scale=0)
 
             with gr.Accordion("Advanced Settings", open=False):
+                segmentation_prompt_text_component = gr.Text(
+                    label="Prompt",
+                    show_label=False,
+                    max_lines=1,
+                    placeholder="Enter segmentation prompt",
+                    container=False,
+                )
                 seed_slicer_component = gr.Slider(
                     label="Seed",
                     minimum=0,
@@ -201,31 +203,13 @@ with gr.Blocks() as demo:
             with gr.Accordion("Debug", open=False):
                 output_mask_component = gr.Image(
                     type='pil', image_mode='RGB', label='Input mask', format="png")
-    with gr.Row():
-        gr.Examples(
-            fn=process,
-            examples=EXAMPLES,
-            inputs=[
-                input_image_editor_component,
-                input_text_component,
-                seed_slicer_component,
-                randomize_seed_checkbox_component,
-                strength_slider_component,
-                num_inference_steps_slider_component
-            ],
-            outputs=[
-                output_image_component,
-                output_mask_component
-            ],
-            run_on_click=True,
-            cache_examples=True
-        )
 
     submit_button_component.click(
         fn=process,
         inputs=[
             input_image_editor_component,
-
+            inpainting_prompt_text_component,
+            segmentation_prompt_text_component,
             seed_slicer_component,
             randomize_seed_checkbox_component,
             strength_slider_component,
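The rewritten process() chains three models where the old version only ran FLUX on a hand-drawn mask: Florence-2 open-vocabulary detection proposes boxes from the segmentation prompt, SAM2 refines them into a mask, and the blurred mask drives FLUX.1 inpainting. The sketch below restates that flow outside Gradio. It is a minimal sketch, not part of the commit, and assumes a CUDA device, the commit's utils/ helpers, and placeholder inputs ("example.jpg", "dog", "little lion").

# Minimal sketch of the text-prompted inpainting path introduced in this revision.
# Assumes a CUDA device, the FLUX.1-schnell weights, and the utils/ helpers added
# in this commit; the image path and both prompts are placeholders.
import numpy as np
import supervision as sv
import torch
from PIL import Image, ImageFilter
from diffusers import FluxInpaintPipeline

from utils.florence import load_florence_model, run_florence_inference, \
    FLORENCE_OPEN_VOCABULARY_DETECTION_TASK
from utils.sam import load_sam_image_model, run_sam_inference

DEVICE = torch.device("cuda")
florence_model, florence_processor = load_florence_model(device=DEVICE)
sam_model = load_sam_image_model(device=DEVICE)
pipe = FluxInpaintPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to(DEVICE)

image = Image.open("example.jpg").convert("RGB")
# Round to a multiple of 32; a simplification of resize_image_dimensions() above.
width, height = (image.width // 32) * 32, (image.height // 32) * 32
image = image.resize((width, height), Image.LANCZOS)

# 1. Florence-2 open-vocabulary detection: text prompt -> bounding boxes.
_, result = run_florence_inference(
    model=florence_model,
    processor=florence_processor,
    device=DEVICE,
    image=image,
    task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
    text="dog"
)
detections = sv.Detections.from_lmm(
    lmm=sv.LMM.FLORENCE_2, result=result, resolution_wh=image.size)

# 2. SAM2 refines the boxes into pixel masks; feather the first one like process()
#    does (this assumes at least one detection came back).
detections = run_sam_inference(sam_model, image, detections)
mask = Image.fromarray(detections.mask[0].astype(np.uint8) * 255)
mask = mask.filter(ImageFilter.GaussianBlur(radius=10))

# 3. FLUX.1 inpaints the masked region from the inpainting prompt.
result = pipe(
    prompt="little lion",
    image=image,
    mask_image=mask,
    width=width,
    height=height,
    strength=0.85,
    generator=torch.Generator().manual_seed(42),
    num_inference_steps=30
).images[0]
result.save("output.png")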
configs/__init__.py
CHANGED
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
configs/sam2_hiera_b+.yaml
ADDED
@@ -0,0 +1,113 @@
+# @package _global_
+
+# Model
+model:
+  _target_: sam2.modeling.sam2_base.SAM2Base
+  image_encoder:
+    _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+    scalp: 1
+    trunk:
+      _target_: sam2.modeling.backbones.hieradet.Hiera
+      embed_dim: 112
+      num_heads: 2
+    neck:
+      _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+      position_encoding:
+        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+        num_pos_feats: 256
+        normalize: true
+        scale: null
+        temperature: 10000
+      d_model: 256
+      backbone_channel_list: [896, 448, 224, 112]
+      fpn_top_down_levels: [2, 3]  # output level 0 and 1 directly use the backbone features
+      fpn_interp_model: nearest
+
+  memory_attention:
+    _target_: sam2.modeling.memory_attention.MemoryAttention
+    d_model: 256
+    pos_enc_at_input: true
+    layer:
+      _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+      activation: relu
+      dim_feedforward: 2048
+      dropout: 0.1
+      pos_enc_at_attn: false
+      self_attention:
+        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        rope_theta: 10000.0
+        feat_sizes: [32, 32]
+        embedding_dim: 256
+        num_heads: 1
+        downsample_rate: 1
+        dropout: 0.1
+      d_model: 256
+      pos_enc_at_cross_attn_keys: true
+      pos_enc_at_cross_attn_queries: false
+      cross_attention:
+        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        rope_theta: 10000.0
+        feat_sizes: [32, 32]
+        rope_k_repeat: True
+        embedding_dim: 256
+        num_heads: 1
+        downsample_rate: 1
+        dropout: 0.1
+        kv_in_dim: 64
+    num_layers: 4
+
+  memory_encoder:
+    _target_: sam2.modeling.memory_encoder.MemoryEncoder
+    out_dim: 64
+    position_encoding:
+      _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+      num_pos_feats: 64
+      normalize: true
+      scale: null
+      temperature: 10000
+    mask_downsampler:
+      _target_: sam2.modeling.memory_encoder.MaskDownSampler
+      kernel_size: 3
+      stride: 2
+      padding: 1
+    fuser:
+      _target_: sam2.modeling.memory_encoder.Fuser
+      layer:
+        _target_: sam2.modeling.memory_encoder.CXBlock
+        dim: 256
+        kernel_size: 7
+        padding: 3
+        layer_scale_init_value: 1e-6
+        use_dwconv: True  # depth-wise convs
+      num_layers: 2
+
+  num_maskmem: 7
+  image_size: 1024
+  # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+  sigmoid_scale_for_mem_enc: 20.0
+  sigmoid_bias_for_mem_enc: -10.0
+  use_mask_input_as_output_without_sam: true
+  # Memory
+  directly_add_no_mem_embed: true
+  # use high-resolution feature map in the SAM mask decoder
+  use_high_res_features_in_sam: true
+  # output 3 masks on the first click on initial conditioning frames
+  multimask_output_in_sam: true
+  # SAM heads
+  iou_prediction_use_sigmoid: True
+  # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+  use_obj_ptrs_in_encoder: true
+  add_tpos_enc_to_obj_ptrs: false
+  only_obj_ptrs_in_the_past_for_eval: true
+  # object occlusion prediction
+  pred_obj_scores: true
+  pred_obj_scores_mlp: true
+  fixed_no_obj_ptr: true
+  # multimask tracking settings
+  multimask_output_for_tracking: true
+  use_multimask_token_for_obj_ptr: true
+  multimask_min_pt_num: 0
+  multimask_max_pt_num: 1
+  use_mlp_for_obj_ptr_proj: true
+  # Compilation flag
+  compile_image_encoder: False
configs/sam2_hiera_l.yaml
ADDED
@@ -0,0 +1,117 @@
+# @package _global_
+
+# Model
+model:
+  _target_: sam2.modeling.sam2_base.SAM2Base
+  image_encoder:
+    _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+    scalp: 1
+    trunk:
+      _target_: sam2.modeling.backbones.hieradet.Hiera
+      embed_dim: 144
+      num_heads: 2
+      stages: [2, 6, 36, 4]
+      global_att_blocks: [23, 33, 43]
+      window_pos_embed_bkg_spatial_size: [7, 7]
+      window_spec: [8, 4, 16, 8]
+    neck:
+      _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+      position_encoding:
+        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+        num_pos_feats: 256
+        normalize: true
+        scale: null
+        temperature: 10000
+      d_model: 256
+      backbone_channel_list: [1152, 576, 288, 144]
+      fpn_top_down_levels: [2, 3]  # output level 0 and 1 directly use the backbone features
+      fpn_interp_model: nearest
+
+  memory_attention:
+    _target_: sam2.modeling.memory_attention.MemoryAttention
+    d_model: 256
+    pos_enc_at_input: true
+    layer:
+      _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+      activation: relu
+      dim_feedforward: 2048
+      dropout: 0.1
+      pos_enc_at_attn: false
+      self_attention:
+        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        rope_theta: 10000.0
+        feat_sizes: [32, 32]
+        embedding_dim: 256
+        num_heads: 1
+        downsample_rate: 1
+        dropout: 0.1
+      d_model: 256
+      pos_enc_at_cross_attn_keys: true
+      pos_enc_at_cross_attn_queries: false
+      cross_attention:
+        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        rope_theta: 10000.0
+        feat_sizes: [32, 32]
+        rope_k_repeat: True
+        embedding_dim: 256
+        num_heads: 1
+        downsample_rate: 1
+        dropout: 0.1
+        kv_in_dim: 64
+    num_layers: 4
+
+  memory_encoder:
+    _target_: sam2.modeling.memory_encoder.MemoryEncoder
+    out_dim: 64
+    position_encoding:
+      _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+      num_pos_feats: 64
+      normalize: true
+      scale: null
+      temperature: 10000
+    mask_downsampler:
+      _target_: sam2.modeling.memory_encoder.MaskDownSampler
+      kernel_size: 3
+      stride: 2
+      padding: 1
+    fuser:
+      _target_: sam2.modeling.memory_encoder.Fuser
+      layer:
+        _target_: sam2.modeling.memory_encoder.CXBlock
+        dim: 256
+        kernel_size: 7
+        padding: 3
+        layer_scale_init_value: 1e-6
+        use_dwconv: True  # depth-wise convs
+      num_layers: 2
+
+  num_maskmem: 7
+  image_size: 1024
+  # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+  sigmoid_scale_for_mem_enc: 20.0
+  sigmoid_bias_for_mem_enc: -10.0
+  use_mask_input_as_output_without_sam: true
+  # Memory
+  directly_add_no_mem_embed: true
+  # use high-resolution feature map in the SAM mask decoder
+  use_high_res_features_in_sam: true
+  # output 3 masks on the first click on initial conditioning frames
+  multimask_output_in_sam: true
+  # SAM heads
+  iou_prediction_use_sigmoid: True
+  # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+  use_obj_ptrs_in_encoder: true
+  add_tpos_enc_to_obj_ptrs: false
+  only_obj_ptrs_in_the_past_for_eval: true
+  # object occlusion prediction
+  pred_obj_scores: true
+  pred_obj_scores_mlp: true
+  fixed_no_obj_ptr: true
+  # multimask tracking settings
+  multimask_output_for_tracking: true
+  use_multimask_token_for_obj_ptr: true
+  multimask_min_pt_num: 0
+  multimask_max_pt_num: 1
+  use_mlp_for_obj_ptr_proj: true
+  # Compilation flag
+  compile_image_encoder: False
configs/sam2_hiera_s.yaml
ADDED
@@ -0,0 +1,116 @@
+# @package _global_
+
+# Model
+model:
+  _target_: sam2.modeling.sam2_base.SAM2Base
+  image_encoder:
+    _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+    scalp: 1
+    trunk:
+      _target_: sam2.modeling.backbones.hieradet.Hiera
+      embed_dim: 96
+      num_heads: 1
+      stages: [1, 2, 11, 2]
+      global_att_blocks: [7, 10, 13]
+      window_pos_embed_bkg_spatial_size: [7, 7]
+    neck:
+      _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+      position_encoding:
+        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+        num_pos_feats: 256
+        normalize: true
+        scale: null
+        temperature: 10000
+      d_model: 256
+      backbone_channel_list: [768, 384, 192, 96]
+      fpn_top_down_levels: [2, 3]  # output level 0 and 1 directly use the backbone features
+      fpn_interp_model: nearest
+
+  memory_attention:
+    _target_: sam2.modeling.memory_attention.MemoryAttention
+    d_model: 256
+    pos_enc_at_input: true
+    layer:
+      _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+      activation: relu
+      dim_feedforward: 2048
+      dropout: 0.1
+      pos_enc_at_attn: false
+      self_attention:
+        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        rope_theta: 10000.0
+        feat_sizes: [32, 32]
+        embedding_dim: 256
+        num_heads: 1
+        downsample_rate: 1
+        dropout: 0.1
+      d_model: 256
+      pos_enc_at_cross_attn_keys: true
+      pos_enc_at_cross_attn_queries: false
+      cross_attention:
+        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        rope_theta: 10000.0
+        feat_sizes: [32, 32]
+        rope_k_repeat: True
+        embedding_dim: 256
+        num_heads: 1
+        downsample_rate: 1
+        dropout: 0.1
+        kv_in_dim: 64
+    num_layers: 4
+
+  memory_encoder:
+    _target_: sam2.modeling.memory_encoder.MemoryEncoder
+    out_dim: 64
+    position_encoding:
+      _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+      num_pos_feats: 64
+      normalize: true
+      scale: null
+      temperature: 10000
+    mask_downsampler:
+      _target_: sam2.modeling.memory_encoder.MaskDownSampler
+      kernel_size: 3
+      stride: 2
+      padding: 1
+    fuser:
+      _target_: sam2.modeling.memory_encoder.Fuser
+      layer:
+        _target_: sam2.modeling.memory_encoder.CXBlock
+        dim: 256
+        kernel_size: 7
+        padding: 3
+        layer_scale_init_value: 1e-6
+        use_dwconv: True  # depth-wise convs
+      num_layers: 2
+
+  num_maskmem: 7
+  image_size: 1024
+  # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+  sigmoid_scale_for_mem_enc: 20.0
+  sigmoid_bias_for_mem_enc: -10.0
+  use_mask_input_as_output_without_sam: true
+  # Memory
+  directly_add_no_mem_embed: true
+  # use high-resolution feature map in the SAM mask decoder
+  use_high_res_features_in_sam: true
+  # output 3 masks on the first click on initial conditioning frames
+  multimask_output_in_sam: true
+  # SAM heads
+  iou_prediction_use_sigmoid: True
+  # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+  use_obj_ptrs_in_encoder: true
+  add_tpos_enc_to_obj_ptrs: false
+  only_obj_ptrs_in_the_past_for_eval: true
+  # object occlusion prediction
+  pred_obj_scores: true
+  pred_obj_scores_mlp: true
+  fixed_no_obj_ptr: true
+  # multimask tracking settings
+  multimask_output_for_tracking: true
+  use_multimask_token_for_obj_ptr: true
+  multimask_min_pt_num: 0
+  multimask_max_pt_num: 1
+  use_mlp_for_obj_ptr_proj: true
+  # Compilation flag
+  compile_image_encoder: False
configs/sam2_hiera_t.yaml
ADDED
@@ -0,0 +1,118 @@
+# @package _global_
+
+# Model
+model:
+  _target_: sam2.modeling.sam2_base.SAM2Base
+  image_encoder:
+    _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+    scalp: 1
+    trunk:
+      _target_: sam2.modeling.backbones.hieradet.Hiera
+      embed_dim: 96
+      num_heads: 1
+      stages: [1, 2, 7, 2]
+      global_att_blocks: [5, 7, 9]
+      window_pos_embed_bkg_spatial_size: [7, 7]
+    neck:
+      _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+      position_encoding:
+        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+        num_pos_feats: 256
+        normalize: true
+        scale: null
+        temperature: 10000
+      d_model: 256
+      backbone_channel_list: [768, 384, 192, 96]
+      fpn_top_down_levels: [2, 3]  # output level 0 and 1 directly use the backbone features
+      fpn_interp_model: nearest
+
+  memory_attention:
+    _target_: sam2.modeling.memory_attention.MemoryAttention
+    d_model: 256
+    pos_enc_at_input: true
+    layer:
+      _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+      activation: relu
+      dim_feedforward: 2048
+      dropout: 0.1
+      pos_enc_at_attn: false
+      self_attention:
+        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        rope_theta: 10000.0
+        feat_sizes: [32, 32]
+        embedding_dim: 256
+        num_heads: 1
+        downsample_rate: 1
+        dropout: 0.1
+      d_model: 256
+      pos_enc_at_cross_attn_keys: true
+      pos_enc_at_cross_attn_queries: false
+      cross_attention:
+        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        rope_theta: 10000.0
+        feat_sizes: [32, 32]
+        rope_k_repeat: True
+        embedding_dim: 256
+        num_heads: 1
+        downsample_rate: 1
+        dropout: 0.1
+        kv_in_dim: 64
+    num_layers: 4
+
+  memory_encoder:
+    _target_: sam2.modeling.memory_encoder.MemoryEncoder
+    out_dim: 64
+    position_encoding:
+      _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+      num_pos_feats: 64
+      normalize: true
+      scale: null
+      temperature: 10000
+    mask_downsampler:
+      _target_: sam2.modeling.memory_encoder.MaskDownSampler
+      kernel_size: 3
+      stride: 2
+      padding: 1
+    fuser:
+      _target_: sam2.modeling.memory_encoder.Fuser
+      layer:
+        _target_: sam2.modeling.memory_encoder.CXBlock
+        dim: 256
+        kernel_size: 7
+        padding: 3
+        layer_scale_init_value: 1e-6
+        use_dwconv: True  # depth-wise convs
+      num_layers: 2
+
+  num_maskmem: 7
+  image_size: 1024
+  # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+  # SAM decoder
+  sigmoid_scale_for_mem_enc: 20.0
+  sigmoid_bias_for_mem_enc: -10.0
+  use_mask_input_as_output_without_sam: true
+  # Memory
+  directly_add_no_mem_embed: true
+  # use high-resolution feature map in the SAM mask decoder
+  use_high_res_features_in_sam: true
+  # output 3 masks on the first click on initial conditioning frames
+  multimask_output_in_sam: true
+  # SAM heads
+  iou_prediction_use_sigmoid: True
+  # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+  use_obj_ptrs_in_encoder: true
+  add_tpos_enc_to_obj_ptrs: false
+  only_obj_ptrs_in_the_past_for_eval: true
+  # object occlusion prediction
+  pred_obj_scores: true
+  pred_obj_scores_mlp: true
+  fixed_no_obj_ptr: true
+  # multimask tracking settings
+  multimask_output_for_tracking: true
+  use_multimask_token_for_obj_ptr: true
+  multimask_min_pt_num: 0
+  multimask_max_pt_num: 1
+  use_mlp_for_obj_ptr_proj: true
+  # Compilation flag
+  # HieraT does not currently support compilation, should always be set to False
+  compile_image_encoder: False
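These four files mirror the Hydra model configs shipped with SAM2; utils/sam.py below only references sam2_hiera_s.yaml, but build_sam2 can resolve any of them by name. A hedged sketch of that mapping follows; the checkpoint filenames are assumed to follow the official SAM2 release naming, and only sam2_hiera_small.pt is actually bundled with this Space.

# Sketch: building a SAM2 image predictor from one of the shipped configs.
# The checkpoint filenames are an assumption based on the official SAM2 release;
# this Space only ships checkpoints/sam2_hiera_small.pt.
import torch
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

CONFIG_TO_CHECKPOINT = {
    "sam2_hiera_t.yaml": "checkpoints/sam2_hiera_tiny.pt",
    "sam2_hiera_s.yaml": "checkpoints/sam2_hiera_small.pt",
    "sam2_hiera_b+.yaml": "checkpoints/sam2_hiera_base_plus.pt",
    "sam2_hiera_l.yaml": "checkpoints/sam2_hiera_large.pt",
}

def build_predictor(config: str, device: str = "cuda") -> SAM2ImagePredictor:
    # build_sam2 looks the config up by name via Hydra, which is presumably why
    # the YAML files sit in a configs/ package with an __init__.py.
    model = build_sam2(config, CONFIG_TO_CHECKPOINT[config], device=torch.device(device))
    return SAM2ImagePredictor(sam_model=model)

predictor = build_predictor("sam2_hiera_s.yaml")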
requirements.txt
CHANGED
@@ -1,6 +1,13 @@
+tqdm
+einops
+timm
+samv2
+opencv-python
+pytest
 gradio
 spaces
 accelerate
 transformers==4.42.4
 sentencepiece
-
+supervision
+git+https://github.com/Gothos/diffusers.git@flux-inpaint
utils/florence.py
CHANGED
@@ -0,0 +1,54 @@
+import os
+from typing import Union, Any, Tuple, Dict
+from unittest.mock import patch
+
+import torch
+from PIL import Image
+from transformers import AutoModelForCausalLM, AutoProcessor
+from transformers.dynamic_module_utils import get_imports
+
+FLORENCE_CHECKPOINT = "microsoft/Florence-2-base"
+FLORENCE_OPEN_VOCABULARY_DETECTION_TASK = '<OPEN_VOCABULARY_DETECTION>'
+
+
+def fixed_get_imports(filename: Union[str, os.PathLike]) -> list[str]:
+    """Work around for https://huggingface.co/microsoft/phi-1_5/discussions/72."""
+    if not str(filename).endswith("/modeling_florence2.py"):
+        return get_imports(filename)
+    imports = get_imports(filename)
+    imports.remove("flash_attn")
+    return imports
+
+
+def load_florence_model(
+    device: torch.device, checkpoint: str = FLORENCE_CHECKPOINT
+) -> Tuple[Any, Any]:
+    with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
+        model = AutoModelForCausalLM.from_pretrained(
+            checkpoint, trust_remote_code=True).to(device).eval()
+        processor = AutoProcessor.from_pretrained(
+            checkpoint, trust_remote_code=True)
+        return model, processor
+
+
+def run_florence_inference(
+    model: Any,
+    processor: Any,
+    device: torch.device,
+    image: Image,
+    task: str,
+    text: str = ""
+) -> Tuple[str, Dict]:
+    prompt = task + text
+    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
+    generated_ids = model.generate(
+        input_ids=inputs["input_ids"],
+        pixel_values=inputs["pixel_values"],
+        max_new_tokens=1024,
+        num_beams=3
+    )
+    generated_text = processor.batch_decode(
+        generated_ids, skip_special_tokens=False)[0]
+    response = processor.post_process_generation(
+        generated_text, task=task, image_size=image.size)
+    return generated_text, response
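A minimal usage sketch for these helpers outside the Gradio app; the image path and text prompt are placeholders, and the task-keyed response layout reflects Florence-2's post_process_generation output.

# Sketch: running Florence-2 open-vocabulary detection with the helpers above.
# "example.jpg" and the "dog" prompt are placeholders, not part of the commit.
import torch
from PIL import Image

from utils.florence import (
    load_florence_model,
    run_florence_inference,
    FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, processor = load_florence_model(device=device)

image = Image.open("example.jpg").convert("RGB")
generated_text, response = run_florence_inference(
    model=model,
    processor=processor,
    device=device,
    image=image,
    task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
    text="dog",
)
# response is a dict keyed by the task token, holding the detected boxes and
# labels that sv.Detections.from_lmm(...) consumes downstream.
print(response[FLORENCE_OPEN_VOCABULARY_DETECTION_TASK])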
utils/sam.py
CHANGED
@@ -0,0 +1,45 @@
+from typing import Any
+
+import numpy as np
+import supervision as sv
+import torch
+from PIL import Image
+from sam2.build_sam import build_sam2, build_sam2_video_predictor
+from sam2.sam2_image_predictor import SAM2ImagePredictor
+
+SAM_CHECKPOINT = "checkpoints/sam2_hiera_small.pt"
+SAM_CONFIG = "sam2_hiera_s.yaml"
+
+
+def load_sam_image_model(
+    device: torch.device,
+    config: str = SAM_CONFIG,
+    checkpoint: str = SAM_CHECKPOINT
+) -> SAM2ImagePredictor:
+    model = build_sam2(config, checkpoint, device=device)
+    return SAM2ImagePredictor(sam_model=model)
+
+
+def load_sam_video_model(
+    device: torch.device,
+    config: str = SAM_CONFIG,
+    checkpoint: str = SAM_CHECKPOINT
+) -> Any:
+    return build_sam2_video_predictor(config, checkpoint, device=device)
+
+
+def run_sam_inference(
+    model: Any,
+    image: Image,
+    detections: sv.Detections
+) -> sv.Detections:
+    image = np.array(image.convert("RGB"))
+    model.set_image(image)
+    mask, score, _ = model.predict(box=detections.xyxy, multimask_output=False)
+
+    # dirty fix; remove this later
+    if len(mask.shape) == 4:
+        mask = np.squeeze(mask)
+
+    detections.mask = mask.astype(bool)
+    return detections
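run_sam_inference only reads detections.xyxy, so the boxes do not have to come from Florence-2. A small sketch with a hand-built box follows; the coordinates and image path are placeholders.

# Sketch: prompting the SAM2 image predictor with an arbitrary box.
# The box coordinates and image path are placeholders for illustration.
import numpy as np
import supervision as sv
import torch
from PIL import Image

from utils.sam import load_sam_image_model, run_sam_inference

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sam_model = load_sam_image_model(device=device)

image = Image.open("example.jpg").convert("RGB")
detections = sv.Detections(
    xyxy=np.array([[100.0, 100.0, 400.0, 400.0]]),  # one box, in pixel coordinates
    class_id=np.array([0]),
)
detections = run_sam_inference(sam_model, image, detections)

# detections.mask is a (N, H, W) boolean array; convert the first mask to PIL
# the same way app.py does before inpainting.
mask = Image.fromarray(detections.mask[0].astype(np.uint8) * 255)
mask.save("mask.png")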