SkalskiP committed
Commit b38c358
Parent: 16af4be

Florence-2 + SAM2 + FLUX.1

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+checkpoints/ filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,14 +1,18 @@
 from typing import Tuple
 
-import requests
+import supervision as sv
 import random
 import numpy as np
 import gradio as gr
 import spaces
 import torch
-from PIL import Image
+from PIL import Image, ImageFilter
 from diffusers import FluxInpaintPipeline
 
+from utils.florence import load_florence_model, run_florence_inference, \
+    FLORENCE_OPEN_VOCABULARY_DETECTION_TASK
+from utils.sam import load_sam_image_model, run_sam_inference
+
 MARKDOWN = """
 # FLUX.1 Inpainting 🔥
 
@@ -19,52 +23,16 @@ for taking it to the next level by enabling inpainting with the FLUX.
 
 MAX_SEED = np.iinfo(np.int32).max
 IMAGE_SIZE = 1024
-DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-
-
-def remove_background(image: Image.Image, threshold: int = 50) -> Image.Image:
-    image = image.convert("RGBA")
-    data = image.getdata()
-    new_data = []
-    for item in data:
-        avg = sum(item[:3]) / 3
-        if avg < threshold:
-            new_data.append((0, 0, 0, 0))
-        else:
-            new_data.append(item)
-
-    image.putdata(new_data)
-    return image
-
-
-EXAMPLES = [
-    [
-        {
-            "background": Image.open(requests.get("https://media.roboflow.com/spaces/doge-2-image.png", stream=True).raw),
-            "layers": [remove_background(Image.open(requests.get("https://media.roboflow.com/spaces/doge-2-mask-2.png", stream=True).raw))],
-            "composite": Image.open(requests.get("https://media.roboflow.com/spaces/doge-2-composite-2.png", stream=True).raw),
-        },
-        "little lion",
-        42,
-        False,
-        0.85,
-        30
-    ],
-    [
-        {
-            "background": Image.open(requests.get("https://media.roboflow.com/spaces/doge-2-image.png", stream=True).raw),
-            "layers": [remove_background(Image.open(requests.get("https://media.roboflow.com/spaces/doge-2-mask-3.png", stream=True).raw))],
-            "composite": Image.open(requests.get("https://media.roboflow.com/spaces/doge-2-composite-3.png", stream=True).raw),
-        },
-        "tattoos",
-        42,
-        False,
-        0.85,
-        30
-    ]
-]
-
-pipe = FluxInpaintPipeline.from_pretrained(
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
+if torch.cuda.get_device_properties(0).major >= 8:
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
+
+FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE)
+SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)
+FLUX_INPAINTING_PIPELINE = FluxInpaintPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to(DEVICE)
 
 
@@ -74,11 +42,6 @@ def resize_image_dimensions(
 ) -> Tuple[int, int]:
     width, height = original_resolution_wh
 
-    # if width <= maximum_dimension and height <= maximum_dimension:
-    #     width = width - (width % 32)
-    #     height = height - (height % 32)
-    #     return width, height
-
    if width > height:
         scaling_factor = maximum_dimension / width
     else:
@@ -93,17 +56,20 @@ def resize_image_dimensions(
     return new_width, new_height
 
 
-@spaces.GPU(duration=100)
+@spaces.GPU(duration=150)
+@torch.inference_mode()
+@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
 def process(
     input_image_editor: dict,
-    input_text: str,
+    inpainting_prompt_text: str,
+    segmentation_prompt_text: str,
     seed_slicer: int,
     randomize_seed_checkbox: bool,
     strength_slider: float,
     num_inference_steps_slider: int,
     progress=gr.Progress(track_tqdm=True)
 ):
-    if not input_text:
+    if not inpainting_prompt_text:
         gr.Info("Please enter a text prompt.")
         return None, None
 
@@ -114,21 +80,50 @@ def process(
         gr.Info("Please upload an image.")
         return None, None
 
-    if not mask:
-        gr.Info("Please draw a mask on the image.")
+    if not mask and not segmentation_prompt_text:
+        gr.Info("Please draw a mask or enter a segmentation prompt.")
+        return None, None
+
+    if mask and segmentation_prompt_text:
+        gr.Info("Both mask and segmentation prompt are provided. Please provide only "
+                "one.")
         return None, None
 
     width, height = resize_image_dimensions(original_resolution_wh=image.size)
-    resized_image = image.resize((width, height), Image.LANCZOS)
-    resized_mask = mask.resize((width, height), Image.LANCZOS)
+    image = image.resize((width, height), Image.LANCZOS)
+
+    if segmentation_prompt_text:
+        _, result = run_florence_inference(
+            model=FLORENCE_MODEL,
+            processor=FLORENCE_PROCESSOR,
+            device=DEVICE,
+            image=image,
+            task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
+            text=segmentation_prompt_text
+        )
+        detections = sv.Detections.from_lmm(
+            lmm=sv.LMM.FLORENCE_2,
+            result=result,
+            resolution_wh=image.size
+        )
+        detections = run_sam_inference(SAM_IMAGE_MODEL, image, detections)
+
+        if len(detections) == 0:
+            gr.Info(f"{segmentation_prompt_text} prompt did not return any detections.")
+            return None, None
+
+        mask = Image.fromarray((detections.mask[0].astype(np.uint8)) * 255)
+
+    mask = mask.resize((width, height), Image.LANCZOS)
+    mask = mask.filter(ImageFilter.GaussianBlur(radius=10))
 
     if randomize_seed_checkbox:
         seed_slicer = random.randint(0, MAX_SEED)
     generator = torch.Generator().manual_seed(seed_slicer)
-    result = pipe(
-        prompt=input_text,
-        image=resized_image,
-        mask_image=resized_mask,
+    result = FLUX_INPAINTING_PIPELINE(
+        prompt=inpainting_prompt_text,
+        image=image,
+        mask_image=mask,
         width=width,
         height=height,
         strength=strength_slider,
@@ -136,7 +131,7 @@ def process(
         num_inference_steps=num_inference_steps_slider
     ).images[0]
     print('INFERENCE DONE')
-    return result, resized_mask
+    return result, mask
 
 
 with gr.Blocks() as demo:
@@ -152,17 +147,24 @@ with gr.Blocks() as demo:
                 brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"))
 
             with gr.Row():
-                input_text_component = gr.Text(
+                inpainting_prompt_text_component = gr.Text(
                     label="Prompt",
                     show_label=False,
                     max_lines=1,
-                    placeholder="Enter your prompt",
+                    placeholder="Enter inpainting prompt",
                     container=False,
                 )
                 submit_button_component = gr.Button(
                     value='Submit', variant='primary', scale=0)
 
             with gr.Accordion("Advanced Settings", open=False):
+                segmentation_prompt_text_component = gr.Text(
+                    label="Prompt",
+                    show_label=False,
+                    max_lines=1,
+                    placeholder="Enter segmentation prompt",
+                    container=False,
+                )
                 seed_slicer_component = gr.Slider(
                     label="Seed",
                     minimum=0,
@@ -201,31 +203,13 @@ with gr.Blocks() as demo:
         with gr.Accordion("Debug", open=False):
            output_mask_component = gr.Image(
                 type='pil', image_mode='RGB', label='Input mask', format="png")
-    with gr.Row():
-        gr.Examples(
-            fn=process,
-            examples=EXAMPLES,
-            inputs=[
-                input_image_editor_component,
-                input_text_component,
-                seed_slicer_component,
-                randomize_seed_checkbox_component,
-                strength_slider_component,
-                num_inference_steps_slider_component
-            ],
-            outputs=[
-                output_image_component,
-                output_mask_component
-            ],
-            run_on_click=True,
-            cache_examples=True
-        )
 
     submit_button_component.click(
         fn=process,
        inputs=[
             input_image_editor_component,
-            input_text_component,
+            inpainting_prompt_text_component,
+            segmentation_prompt_text_component,
             seed_slicer_component,
             randomize_seed_checkbox_component,
             strength_slider_component,
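
For reference, a minimal sketch of the new text-prompted masking path added above, run outside of Gradio. It assumes the `utils/` helpers from this commit, a CUDA device, and a local test image; `input.png`, the `"dog"` segmentation prompt, the `"a golden retriever"` inpainting prompt, and `output.png` are illustrative names only. `app.py` additionally resizes the image and mask with `resize_image_dimensions` first, which the sketch skips.

```python
# Sketch only: mirrors the segmentation-prompt branch of process() in app.py.
import numpy as np
import supervision as sv
import torch
from PIL import Image, ImageFilter
from diffusers import FluxInpaintPipeline

from utils.florence import load_florence_model, run_florence_inference, \
    FLORENCE_OPEN_VOCABULARY_DETECTION_TASK
from utils.sam import load_sam_image_model, run_sam_inference

DEVICE = torch.device("cuda")

florence_model, florence_processor = load_florence_model(device=DEVICE)
sam_model = load_sam_image_model(device=DEVICE)
pipe = FluxInpaintPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to(DEVICE)

image = Image.open("input.png").convert("RGB")  # assumed local test image

# 1. Florence-2 turns the free-form text prompt into bounding boxes.
_, result = run_florence_inference(
    model=florence_model, processor=florence_processor, device=DEVICE,
    image=image, task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK, text="dog")
detections = sv.Detections.from_lmm(
    lmm=sv.LMM.FLORENCE_2, result=result, resolution_wh=image.size)

# 2. SAM2 turns the boxes into a pixel mask.
#    (app.py checks len(detections) == 0 before indexing.)
detections = run_sam_inference(sam_model, image, detections)
mask = Image.fromarray(detections.mask[0].astype(np.uint8) * 255)
mask = mask.filter(ImageFilter.GaussianBlur(radius=10))  # feather the mask edge, as in app.py

# 3. FLUX.1 inpaints inside the mask.
output = pipe(prompt="a golden retriever", image=image, mask_image=mask,
              strength=0.85, num_inference_steps=30).images[0]
output.save("output.png")
```

The Gaussian blur appears intended to soften the mask boundary so the inpainted region blends into the surrounding pixels instead of ending at a hard segmentation edge.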
configs/__init__.py CHANGED
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
configs/sam2_hiera_b+.yaml ADDED
@@ -0,0 +1,113 @@
+# @package _global_
+
+# Model
+model:
+  _target_: sam2.modeling.sam2_base.SAM2Base
+  image_encoder:
+    _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+    scalp: 1
+    trunk:
+      _target_: sam2.modeling.backbones.hieradet.Hiera
+      embed_dim: 112
+      num_heads: 2
+    neck:
+      _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+      position_encoding:
+        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+        num_pos_feats: 256
+        normalize: true
+        scale: null
+        temperature: 10000
+      d_model: 256
+      backbone_channel_list: [896, 448, 224, 112]
+      fpn_top_down_levels: [2, 3]  # output level 0 and 1 directly use the backbone features
+      fpn_interp_model: nearest
+
+  memory_attention:
+    _target_: sam2.modeling.memory_attention.MemoryAttention
+    d_model: 256
+    pos_enc_at_input: true
+    layer:
+      _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+      activation: relu
+      dim_feedforward: 2048
+      dropout: 0.1
+      pos_enc_at_attn: false
+      self_attention:
+        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        rope_theta: 10000.0
+        feat_sizes: [32, 32]
+        embedding_dim: 256
+        num_heads: 1
+        downsample_rate: 1
+        dropout: 0.1
+      d_model: 256
+      pos_enc_at_cross_attn_keys: true
+      pos_enc_at_cross_attn_queries: false
+      cross_attention:
+        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        rope_theta: 10000.0
+        feat_sizes: [32, 32]
+        rope_k_repeat: True
+        embedding_dim: 256
+        num_heads: 1
+        downsample_rate: 1
+        dropout: 0.1
+        kv_in_dim: 64
+    num_layers: 4
+
+  memory_encoder:
+    _target_: sam2.modeling.memory_encoder.MemoryEncoder
+    out_dim: 64
+    position_encoding:
+      _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+      num_pos_feats: 64
+      normalize: true
+      scale: null
+      temperature: 10000
+    mask_downsampler:
+      _target_: sam2.modeling.memory_encoder.MaskDownSampler
+      kernel_size: 3
+      stride: 2
+      padding: 1
+    fuser:
+      _target_: sam2.modeling.memory_encoder.Fuser
+      layer:
+        _target_: sam2.modeling.memory_encoder.CXBlock
+        dim: 256
+        kernel_size: 7
+        padding: 3
+        layer_scale_init_value: 1e-6
+        use_dwconv: True  # depth-wise convs
+      num_layers: 2
+
+  num_maskmem: 7
+  image_size: 1024
+  # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+  sigmoid_scale_for_mem_enc: 20.0
+  sigmoid_bias_for_mem_enc: -10.0
+  use_mask_input_as_output_without_sam: true
+  # Memory
+  directly_add_no_mem_embed: true
+  # use high-resolution feature map in the SAM mask decoder
+  use_high_res_features_in_sam: true
+  # output 3 masks on the first click on initial conditioning frames
+  multimask_output_in_sam: true
+  # SAM heads
+  iou_prediction_use_sigmoid: True
+  # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+  use_obj_ptrs_in_encoder: true
+  add_tpos_enc_to_obj_ptrs: false
+  only_obj_ptrs_in_the_past_for_eval: true
+  # object occlusion prediction
+  pred_obj_scores: true
+  pred_obj_scores_mlp: true
+  fixed_no_obj_ptr: true
+  # multimask tracking settings
+  multimask_output_for_tracking: true
+  use_multimask_token_for_obj_ptr: true
+  multimask_min_pt_num: 0
+  multimask_max_pt_num: 1
+  use_mlp_for_obj_ptr_proj: true
+  # Compilation flag
+  compile_image_encoder: False
configs/sam2_hiera_l.yaml ADDED
@@ -0,0 +1,117 @@
+# @package _global_
+
+# Model
+model:
+  _target_: sam2.modeling.sam2_base.SAM2Base
+  image_encoder:
+    _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+    scalp: 1
+    trunk:
+      _target_: sam2.modeling.backbones.hieradet.Hiera
+      embed_dim: 144
+      num_heads: 2
+      stages: [2, 6, 36, 4]
+      global_att_blocks: [23, 33, 43]
+      window_pos_embed_bkg_spatial_size: [7, 7]
+      window_spec: [8, 4, 16, 8]
+    neck:
+      _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+      position_encoding:
+        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+        num_pos_feats: 256
+        normalize: true
+        scale: null
+        temperature: 10000
+      d_model: 256
+      backbone_channel_list: [1152, 576, 288, 144]
+      fpn_top_down_levels: [2, 3]  # output level 0 and 1 directly use the backbone features
+      fpn_interp_model: nearest
+
+  memory_attention:
+    _target_: sam2.modeling.memory_attention.MemoryAttention
+    d_model: 256
+    pos_enc_at_input: true
+    layer:
+      _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+      activation: relu
+      dim_feedforward: 2048
+      dropout: 0.1
+      pos_enc_at_attn: false
+      self_attention:
+        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        rope_theta: 10000.0
+        feat_sizes: [32, 32]
+        embedding_dim: 256
+        num_heads: 1
+        downsample_rate: 1
+        dropout: 0.1
+      d_model: 256
+      pos_enc_at_cross_attn_keys: true
+      pos_enc_at_cross_attn_queries: false
+      cross_attention:
+        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        rope_theta: 10000.0
+        feat_sizes: [32, 32]
+        rope_k_repeat: True
+        embedding_dim: 256
+        num_heads: 1
+        downsample_rate: 1
+        dropout: 0.1
+        kv_in_dim: 64
+    num_layers: 4
+
+  memory_encoder:
+    _target_: sam2.modeling.memory_encoder.MemoryEncoder
+    out_dim: 64
+    position_encoding:
+      _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+      num_pos_feats: 64
+      normalize: true
+      scale: null
+      temperature: 10000
+    mask_downsampler:
+      _target_: sam2.modeling.memory_encoder.MaskDownSampler
+      kernel_size: 3
+      stride: 2
+      padding: 1
+    fuser:
+      _target_: sam2.modeling.memory_encoder.Fuser
+      layer:
+        _target_: sam2.modeling.memory_encoder.CXBlock
+        dim: 256
+        kernel_size: 7
+        padding: 3
+        layer_scale_init_value: 1e-6
+        use_dwconv: True  # depth-wise convs
+      num_layers: 2
+
+  num_maskmem: 7
+  image_size: 1024
+  # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+  sigmoid_scale_for_mem_enc: 20.0
+  sigmoid_bias_for_mem_enc: -10.0
+  use_mask_input_as_output_without_sam: true
+  # Memory
+  directly_add_no_mem_embed: true
+  # use high-resolution feature map in the SAM mask decoder
+  use_high_res_features_in_sam: true
+  # output 3 masks on the first click on initial conditioning frames
+  multimask_output_in_sam: true
+  # SAM heads
+  iou_prediction_use_sigmoid: True
+  # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+  use_obj_ptrs_in_encoder: true
+  add_tpos_enc_to_obj_ptrs: false
+  only_obj_ptrs_in_the_past_for_eval: true
+  # object occlusion prediction
+  pred_obj_scores: true
+  pred_obj_scores_mlp: true
+  fixed_no_obj_ptr: true
+  # multimask tracking settings
+  multimask_output_for_tracking: true
+  use_multimask_token_for_obj_ptr: true
+  multimask_min_pt_num: 0
+  multimask_max_pt_num: 1
+  use_mlp_for_obj_ptr_proj: true
+  # Compilation flag
+  compile_image_encoder: False
configs/sam2_hiera_s.yaml ADDED
@@ -0,0 +1,116 @@
+# @package _global_
+
+# Model
+model:
+  _target_: sam2.modeling.sam2_base.SAM2Base
+  image_encoder:
+    _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+    scalp: 1
+    trunk:
+      _target_: sam2.modeling.backbones.hieradet.Hiera
+      embed_dim: 96
+      num_heads: 1
+      stages: [1, 2, 11, 2]
+      global_att_blocks: [7, 10, 13]
+      window_pos_embed_bkg_spatial_size: [7, 7]
+    neck:
+      _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+      position_encoding:
+        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+        num_pos_feats: 256
+        normalize: true
+        scale: null
+        temperature: 10000
+      d_model: 256
+      backbone_channel_list: [768, 384, 192, 96]
+      fpn_top_down_levels: [2, 3]  # output level 0 and 1 directly use the backbone features
+      fpn_interp_model: nearest
+
+  memory_attention:
+    _target_: sam2.modeling.memory_attention.MemoryAttention
+    d_model: 256
+    pos_enc_at_input: true
+    layer:
+      _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+      activation: relu
+      dim_feedforward: 2048
+      dropout: 0.1
+      pos_enc_at_attn: false
+      self_attention:
+        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        rope_theta: 10000.0
+        feat_sizes: [32, 32]
+        embedding_dim: 256
+        num_heads: 1
+        downsample_rate: 1
+        dropout: 0.1
+      d_model: 256
+      pos_enc_at_cross_attn_keys: true
+      pos_enc_at_cross_attn_queries: false
+      cross_attention:
+        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        rope_theta: 10000.0
+        feat_sizes: [32, 32]
+        rope_k_repeat: True
+        embedding_dim: 256
+        num_heads: 1
+        downsample_rate: 1
+        dropout: 0.1
+        kv_in_dim: 64
+    num_layers: 4
+
+  memory_encoder:
+    _target_: sam2.modeling.memory_encoder.MemoryEncoder
+    out_dim: 64
+    position_encoding:
+      _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+      num_pos_feats: 64
+      normalize: true
+      scale: null
+      temperature: 10000
+    mask_downsampler:
+      _target_: sam2.modeling.memory_encoder.MaskDownSampler
+      kernel_size: 3
+      stride: 2
+      padding: 1
+    fuser:
+      _target_: sam2.modeling.memory_encoder.Fuser
+      layer:
+        _target_: sam2.modeling.memory_encoder.CXBlock
+        dim: 256
+        kernel_size: 7
+        padding: 3
+        layer_scale_init_value: 1e-6
+        use_dwconv: True  # depth-wise convs
+      num_layers: 2
+
+  num_maskmem: 7
+  image_size: 1024
+  # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+  sigmoid_scale_for_mem_enc: 20.0
+  sigmoid_bias_for_mem_enc: -10.0
+  use_mask_input_as_output_without_sam: true
+  # Memory
+  directly_add_no_mem_embed: true
+  # use high-resolution feature map in the SAM mask decoder
+  use_high_res_features_in_sam: true
+  # output 3 masks on the first click on initial conditioning frames
+  multimask_output_in_sam: true
+  # SAM heads
+  iou_prediction_use_sigmoid: True
+  # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+  use_obj_ptrs_in_encoder: true
+  add_tpos_enc_to_obj_ptrs: false
+  only_obj_ptrs_in_the_past_for_eval: true
+  # object occlusion prediction
+  pred_obj_scores: true
+  pred_obj_scores_mlp: true
+  fixed_no_obj_ptr: true
+  # multimask tracking settings
+  multimask_output_for_tracking: true
+  use_multimask_token_for_obj_ptr: true
+  multimask_min_pt_num: 0
+  multimask_max_pt_num: 1
+  use_mlp_for_obj_ptr_proj: true
+  # Compilation flag
+  compile_image_encoder: False
configs/sam2_hiera_t.yaml ADDED
@@ -0,0 +1,118 @@
+# @package _global_
+
+# Model
+model:
+  _target_: sam2.modeling.sam2_base.SAM2Base
+  image_encoder:
+    _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+    scalp: 1
+    trunk:
+      _target_: sam2.modeling.backbones.hieradet.Hiera
+      embed_dim: 96
+      num_heads: 1
+      stages: [1, 2, 7, 2]
+      global_att_blocks: [5, 7, 9]
+      window_pos_embed_bkg_spatial_size: [7, 7]
+    neck:
+      _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+      position_encoding:
+        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+        num_pos_feats: 256
+        normalize: true
+        scale: null
+        temperature: 10000
+      d_model: 256
+      backbone_channel_list: [768, 384, 192, 96]
+      fpn_top_down_levels: [2, 3]  # output level 0 and 1 directly use the backbone features
+      fpn_interp_model: nearest
+
+  memory_attention:
+    _target_: sam2.modeling.memory_attention.MemoryAttention
+    d_model: 256
+    pos_enc_at_input: true
+    layer:
+      _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+      activation: relu
+      dim_feedforward: 2048
+      dropout: 0.1
+      pos_enc_at_attn: false
+      self_attention:
+        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        rope_theta: 10000.0
+        feat_sizes: [32, 32]
+        embedding_dim: 256
+        num_heads: 1
+        downsample_rate: 1
+        dropout: 0.1
+      d_model: 256
+      pos_enc_at_cross_attn_keys: true
+      pos_enc_at_cross_attn_queries: false
+      cross_attention:
+        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        rope_theta: 10000.0
+        feat_sizes: [32, 32]
+        rope_k_repeat: True
+        embedding_dim: 256
+        num_heads: 1
+        downsample_rate: 1
+        dropout: 0.1
+        kv_in_dim: 64
+    num_layers: 4
+
+  memory_encoder:
+    _target_: sam2.modeling.memory_encoder.MemoryEncoder
+    out_dim: 64
+    position_encoding:
+      _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+      num_pos_feats: 64
+      normalize: true
+      scale: null
+      temperature: 10000
+    mask_downsampler:
+      _target_: sam2.modeling.memory_encoder.MaskDownSampler
+      kernel_size: 3
+      stride: 2
+      padding: 1
+    fuser:
+      _target_: sam2.modeling.memory_encoder.Fuser
+      layer:
+        _target_: sam2.modeling.memory_encoder.CXBlock
+        dim: 256
+        kernel_size: 7
+        padding: 3
+        layer_scale_init_value: 1e-6
+        use_dwconv: True  # depth-wise convs
+      num_layers: 2
+
+  num_maskmem: 7
+  image_size: 1024
+  # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+  # SAM decoder
+  sigmoid_scale_for_mem_enc: 20.0
+  sigmoid_bias_for_mem_enc: -10.0
+  use_mask_input_as_output_without_sam: true
+  # Memory
+  directly_add_no_mem_embed: true
+  # use high-resolution feature map in the SAM mask decoder
+  use_high_res_features_in_sam: true
+  # output 3 masks on the first click on initial conditioning frames
+  multimask_output_in_sam: true
+  # SAM heads
+  iou_prediction_use_sigmoid: True
+  # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+  use_obj_ptrs_in_encoder: true
+  add_tpos_enc_to_obj_ptrs: false
+  only_obj_ptrs_in_the_past_for_eval: true
+  # object occlusion prediction
+  pred_obj_scores: true
+  pred_obj_scores_mlp: true
+  fixed_no_obj_ptr: true
+  # multimask tracking settings
+  multimask_output_for_tracking: true
+  use_multimask_token_for_obj_ptr: true
+  multimask_min_pt_num: 0
+  multimask_max_pt_num: 1
+  use_mlp_for_obj_ptr_proj: true
+  # Compilation flag
+  # HieraT does not currently support compilation, should always be set to False
+  compile_image_encoder: False
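
These four YAML files are the stock SAM2 model definitions that ship with Meta's `sam2` package; `utils/sam.py` below pairs `sam2_hiera_s.yaml` with `checkpoints/sam2_hiera_small.pt`. A hedged sketch of how another variant could be selected is shown here — only the small checkpoint is part of this Space, and the other checkpoint filenames are assumptions that would need to be downloaded separately.

```python
# Sketch: pairing a SAM2 config name with a matching checkpoint via build_sam2.
import torch
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

SAM2_VARIANTS = {
    "tiny": ("sam2_hiera_t.yaml", "checkpoints/sam2_hiera_tiny.pt"),             # assumed filename
    "small": ("sam2_hiera_s.yaml", "checkpoints/sam2_hiera_small.pt"),            # used in utils/sam.py
    "base_plus": ("sam2_hiera_b+.yaml", "checkpoints/sam2_hiera_base_plus.pt"),   # assumed filename
    "large": ("sam2_hiera_l.yaml", "checkpoints/sam2_hiera_large.pt"),            # assumed filename
}


def load_variant(name: str, device: torch.device) -> SAM2ImagePredictor:
    config, checkpoint = SAM2_VARIANTS[name]
    # build_sam2 resolves the config name through the sam2 package's Hydra search path,
    # which is why the YAML filenames above must match these config names.
    model = build_sam2(config, checkpoint, device=device)
    return SAM2ImagePredictor(sam_model=model)
```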
requirements.txt CHANGED
@@ -1,6 +1,13 @@
+tqdm
+einops
+timm
+samv2
+opencv-python
+pytest
 gradio
 spaces
 accelerate
 transformers==4.42.4
 sentencepiece
-git+https://github.com/Gothos/diffusers.git@flux-inpaint
+supervision
+git+https://github.com/Gothos/diffusers.git@flux-inpaint
utils/florence.py CHANGED
@@ -0,0 +1,54 @@
+import os
+from typing import Union, Any, Tuple, Dict
+from unittest.mock import patch
+
+import torch
+from PIL import Image
+from transformers import AutoModelForCausalLM, AutoProcessor
+from transformers.dynamic_module_utils import get_imports
+
+FLORENCE_CHECKPOINT = "microsoft/Florence-2-base"
+FLORENCE_OPEN_VOCABULARY_DETECTION_TASK = '<OPEN_VOCABULARY_DETECTION>'
+
+
+def fixed_get_imports(filename: Union[str, os.PathLike]) -> list[str]:
+    """Work around for https://huggingface.co/microsoft/phi-1_5/discussions/72."""
+    if not str(filename).endswith("/modeling_florence2.py"):
+        return get_imports(filename)
+    imports = get_imports(filename)
+    imports.remove("flash_attn")
+    return imports
+
+
+def load_florence_model(
+    device: torch.device, checkpoint: str = FLORENCE_CHECKPOINT
+) -> Tuple[Any, Any]:
+    with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
+        model = AutoModelForCausalLM.from_pretrained(
+            checkpoint, trust_remote_code=True).to(device).eval()
+        processor = AutoProcessor.from_pretrained(
+            checkpoint, trust_remote_code=True)
+        return model, processor
+
+
+def run_florence_inference(
+    model: Any,
+    processor: Any,
+    device: torch.device,
+    image: Image,
+    task: str,
+    text: str = ""
+) -> Tuple[str, Dict]:
+    prompt = task + text
+    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
+    generated_ids = model.generate(
+        input_ids=inputs["input_ids"],
+        pixel_values=inputs["pixel_values"],
+        max_new_tokens=1024,
+        num_beams=3
+    )
+    generated_text = processor.batch_decode(
+        generated_ids, skip_special_tokens=False)[0]
+    response = processor.post_process_generation(
+        generated_text, task=task, image_size=image.size)
+    return generated_text, response
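
A small standalone usage sketch for the helpers above, assuming a local `photo.jpg`; the `"cat"` prompt is illustrative. The parsed response is whatever `processor.post_process_generation` returns for the chosen task, and `app.py` hands it straight to `supervision` rather than reading its keys directly.

```python
# Sketch: open-vocabulary detection with the Florence-2 helpers above.
import torch
from PIL import Image

from utils.florence import (
    load_florence_model,
    run_florence_inference,
    FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, processor = load_florence_model(device=device)

image = Image.open("photo.jpg").convert("RGB")  # assumed local image
generated_text, response = run_florence_inference(
    model=model,
    processor=processor,
    device=device,
    image=image,
    task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
    text="cat",  # free-form object prompt
)
# `response` is the task-specific dict from post_process_generation; app.py parses it with
# sv.Detections.from_lmm(lmm=sv.LMM.FLORENCE_2, result=response, resolution_wh=image.size)
print(response)
```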
utils/sam.py CHANGED
@@ -0,0 +1,45 @@
+from typing import Any
+
+import numpy as np
+import supervision as sv
+import torch
+from PIL import Image
+from sam2.build_sam import build_sam2, build_sam2_video_predictor
+from sam2.sam2_image_predictor import SAM2ImagePredictor
+
+SAM_CHECKPOINT = "checkpoints/sam2_hiera_small.pt"
+SAM_CONFIG = "sam2_hiera_s.yaml"
+
+
+def load_sam_image_model(
+    device: torch.device,
+    config: str = SAM_CONFIG,
+    checkpoint: str = SAM_CHECKPOINT
+) -> SAM2ImagePredictor:
+    model = build_sam2(config, checkpoint, device=device)
+    return SAM2ImagePredictor(sam_model=model)
+
+
+def load_sam_video_model(
+    device: torch.device,
+    config: str = SAM_CONFIG,
+    checkpoint: str = SAM_CHECKPOINT
+) -> Any:
+    return build_sam2_video_predictor(config, checkpoint, device=device)
+
+
+def run_sam_inference(
+    model: Any,
+    image: Image,
+    detections: sv.Detections
+) -> sv.Detections:
+    image = np.array(image.convert("RGB"))
+    model.set_image(image)
+    mask, score, _ = model.predict(box=detections.xyxy, multimask_output=False)
+
+    # dirty fix; remove this later
+    if len(mask.shape) == 4:
+        mask = np.squeeze(mask)
+
+    detections.mask = mask.astype(bool)
+    return detections
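
And a matching sketch for box-prompted segmentation with `run_sam_inference`, assuming the `checkpoints/sam2_hiera_small.pt` weights tracked by the `.gitattributes` change are present locally; `photo.jpg`, the box coordinates, and `mask.png` are illustrative.

```python
# Sketch: running SAM2 on a hand-specified box instead of Florence-2 detections.
import numpy as np
import supervision as sv
import torch
from PIL import Image

from utils.sam import load_sam_image_model, run_sam_inference

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sam_model = load_sam_image_model(device=device)

image = Image.open("photo.jpg").convert("RGB")  # assumed local image
box = np.array([[100, 100, 400, 300]])          # illustrative xyxy box
detections = sv.Detections(xyxy=box)

detections = run_sam_inference(sam_model, image, detections)
mask = detections.mask[0]                       # boolean HxW array for the first box
Image.fromarray(mask.astype(np.uint8) * 255).save("mask.png")
```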