xqt committed
Commit 71139a9 · 0 parents

initial commit
.gitattributes ADDED
@@ -0,0 +1,36 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/cars.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
+ .tmp/
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: Segment Anything 2 Assist
+ emoji: 👁
+ colorFrom: indigo
+ colorTo: red
+ sdk: gradio
+ sdk_version: 4.42.0
+ app_file: SegmentAnything2AssistApp.py
+ pinned: true
+ license: bsd-3-clause
+ short_description: A tool to use SAM2 on images.
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
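The Space can also be run locally; a minimal sketch, assuming the pinned requirements install cleanly and a CUDA-capable GPU is available (the app defaults to a cuda device):

pip install -r requirements.txt
python SegmentAnything2AssistApp.py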
SegmentAnything2AssistApp.py ADDED
@@ -0,0 +1,312 @@
+ import gradio
+ import gradio_image_annotation
+ import gradio_imageslider
+ import spaces
+ import torch
+
+ import src.SegmentAnything2Assist as SegmentAnything2Assist
+
+ example_image_annotation = {
+     "image": "assets/cars.jpg",
+     "boxes": [
+         {'label': '+', 'color': (0, 255, 0), 'xmin': 886, 'ymin': 551, 'xmax': 886, 'ymax': 551},
+         {'label': '-', 'color': (255, 0, 0), 'xmin': 1239, 'ymin': 576, 'xmax': 1239, 'ymax': 576},
+         {'label': '-', 'color': (255, 0, 0), 'xmin': 610, 'ymin': 574, 'xmax': 610, 'ymax': 574},
+         {'label': '', 'color': (0, 0, 255), 'xmin': 254, 'ymin': 466, 'xmax': 1347, 'ymax': 1047}
+     ]
+ }
+
+ VERBOSE = True
+
+ segment_anything2assist = SegmentAnything2Assist.SegmentAnything2Assist(model_name = "sam2_hiera_tiny", device = torch.device("cuda"))
+ __image_point_coords = None
+ __image_point_labels = None
+ __image_box = None
+ __current_mask = None
+ __current_segment = None
+
+ def __change_base_model(model_name, device):
+     global segment_anything2assist
+     try:
+         segment_anything2assist = SegmentAnything2Assist.SegmentAnything2Assist(model_name = model_name, device = torch.device(device))
+         gradio.Info(f"Model changed to {model_name} on {device}", duration = 5)
+     except Exception:
+         # Raise so Gradio surfaces the failure in the UI; constructing the Error without raising does nothing.
+         raise gradio.Error("Model could not be changed", duration = 5)
+
+ def __post_process_annotator_inputs(value):
+     global __image_point_coords, __image_point_labels, __image_box
+     global __current_mask, __current_segment
+     if VERBOSE:
+         print("SegmentAnything2AssistApp::__post_process_annotator_inputs::Called.")
+     __current_mask, __current_segment = None, None
+     new_boxes = []
+     __image_point_coords = []
+     __image_point_labels = []
+     __image_box = []
+
+     b_has_box = False
+     for box in value["boxes"]:
+         if box['label'] == '':
+             # Only the first unlabeled box is kept; it becomes the XYXY box prompt.
+             if not b_has_box:
+                 new_box = box.copy()
+                 new_box['color'] = (0, 0, 255)
+                 new_boxes.append(new_box)
+                 b_has_box = True
+                 __image_box = [
+                     box['xmin'],
+                     box['ymin'],
+                     box['xmax'],
+                     box['ymax']
+                 ]
+         elif box['label'] == '+' or box['label'] == '-':
+             # '+'/'-' boxes collapse to their center point and become positive/negative point prompts.
+             new_box = box.copy()
+             new_box['color'] = (0, 255, 0) if box['label'] == '+' else (255, 0, 0)
+             new_box['xmin'] = int((box['xmin'] + box['xmax']) / 2)
+             new_box['ymin'] = int((box['ymin'] + box['ymax']) / 2)
+             new_box['xmax'] = new_box['xmin']
+             new_box['ymax'] = new_box['ymin']
+             new_boxes.append(new_box)
+
+             __image_point_coords.append([new_box['xmin'], new_box['ymin']])
+             __image_point_labels.append(1 if box['label'] == '+' else 0)
+
+     if len(__image_box) == 0:
+         __image_box = None
+     if len(__image_point_coords) == 0:
+         __image_point_coords = None
+     if len(__image_point_labels) == 0:
+         __image_point_labels = None
+
+     if VERBOSE:
+         print("SegmentAnything2AssistApp::__post_process_annotator_inputs::Done.")
+
+ @spaces.GPU(duration = 60)
+ def __generate_mask(value, mask_threshold, max_hole_area, max_sprinkle_area, image_output_mode):
+     global __current_mask, __current_segment
+     global __image_point_coords, __image_point_labels, __image_box
+     global segment_anything2assist
+     if VERBOSE:
+         print("SegmentAnything2AssistApp::__generate_mask::Called.")
+     mask_chw, mask_iou = segment_anything2assist.generate_masks_from_image(
+         value["image"],
+         __image_point_coords,
+         __image_point_labels,
+         __image_box,
+         mask_threshold,
+         max_hole_area,
+         max_sprinkle_area
+     )
+
+     if VERBOSE:
+         print("SegmentAnything2AssistApp::__generate_mask::Masks generated.")
+
+     __current_mask, __current_segment = segment_anything2assist.apply_mask_to_image(value["image"], mask_chw[0])
+
+     if VERBOSE:
+         print("SegmentAnything2AssistApp::__generate_mask::Masks and segments created.")
+
+     if image_output_mode == "Mask":
+         return [value["image"], __current_mask]
+     elif image_output_mode == "Segment":
+         return [value["image"], __current_segment]
+     else:
+         gradio.Warning("This is an issue, please report the problem!", duration = 5)
+         return gradio_imageslider.ImageSlider(render = True)
+
+ def __change_output_mode(image_input, radio):
+     global __current_mask, __current_segment
+     global __image_point_coords, __image_point_labels, __image_box
+     if VERBOSE:
+         print("SegmentAnything2AssistApp::__change_output_mode::Called.")
+     if __current_mask is None or __current_segment is None:
+         gradio.Warning("Configuration was changed, generate the mask again.", duration = 5)
+         return gradio_imageslider.ImageSlider(render = True)
+     if radio == "Mask":
+         return [image_input["image"], __current_mask]
+     elif radio == "Segment":
+         return [image_input["image"], __current_segment]
+     else:
+         gradio.Warning("This is an issue, please report the problem!", duration = 5)
+         return gradio_imageslider.ImageSlider(render = True)
+
+ def __generate_multi_mask_output(image, auto_list, auto_mode, auto_bbox_mode):
+     global segment_anything2assist
+     # `auto_list` holds 1-based mask numbers (matching the numbers drawn on the image); convert to 0-based indices.
+     image_with_bbox, mask, segment = segment_anything2assist.apply_auto_mask_to_image(image, [int(i) - 1 for i in auto_list])
+
+     output_1 = image_with_bbox if auto_bbox_mode else image
+     output_2 = mask if auto_mode == "Mask" else segment
+     return [output_1, output_2]
+
+ @spaces.GPU(duration = 60)
+ def __generate_auto_mask(
+     image,
+     points_per_side,
+     points_per_batch,
+     pred_iou_thresh,
+     stability_score_thresh,
+     stability_score_offset,
+     mask_threshold,
+     box_nms_thresh,
+     crop_n_layers,
+     crop_nms_thresh,
+     crop_overlap_ratio,
+     crop_n_points_downscale_factor,
+     min_mask_region_area,
+     use_m2m,
+     multimask_output,
+     output_mode
+ ):
+     global segment_anything2assist
+     if VERBOSE:
+         print("SegmentAnything2AssistApp::__generate_auto_mask::Called.")
+
+     __auto_masks = segment_anything2assist.generate_automatic_masks(
+         image,
+         points_per_side,
+         points_per_batch,
+         pred_iou_thresh,
+         stability_score_thresh,
+         stability_score_offset,
+         mask_threshold,
+         box_nms_thresh,
+         crop_n_layers,
+         crop_nms_thresh,
+         crop_overlap_ratio,
+         crop_n_points_downscale_factor,
+         min_mask_region_area,
+         use_m2m,
+         multimask_output
+     )
+
+     if len(__auto_masks) == 0:
+         gradio.Warning("No masks generated, please tweak the advanced parameters.", duration = 5)
+         return gradio_imageslider.ImageSlider(), \
+             gradio.CheckboxGroup([], value = [], label = "Mask List", interactive = False), \
+             gradio.Checkbox(value = False, label = "Show Bounding Box", interactive = False)
+     else:
+         # Choices are 1-based so they line up with the numbered bounding boxes drawn on the image
+         # (and with the `int(i) - 1` mapping in __generate_multi_mask_output).
+         choices = [str(i + 1) for i in range(len(__auto_masks))]
+         returning_image = __generate_multi_mask_output(image, ["1"], output_mode, False)
+         return returning_image, \
+             gradio.CheckboxGroup(choices, value = ["1"], label = "Mask List", interactive = True), \
+             gradio.Checkbox(value = False, label = "Show Bounding Box", interactive = True)
+
+ with gradio.Blocks() as base_app:
+     gradio.Markdown("# SegmentAnything2Assist")
+     with gradio.Row():
+         with gradio.Column():
+             base_model_choice = gradio.Dropdown(
+                 ['sam2_hiera_large', 'sam2_hiera_small', 'sam2_hiera_base_plus', 'sam2_hiera_tiny'],
+                 value = 'sam2_hiera_tiny',
+                 label = "Model Choice"
+             )
+         with gradio.Column():
+             base_gpu_choice = gradio.Dropdown(
+                 ['cpu', 'cuda'],
+                 value = 'cuda',
+                 label = "Device Choice"
+             )
+     base_model_choice.change(__change_base_model, inputs = [base_model_choice, base_gpu_choice])
+     base_gpu_choice.change(__change_base_model, inputs = [base_model_choice, base_gpu_choice])
+     with gradio.Tab(label = "Image Segmentation", id = "image_tab") as image_tab:
+         gradio.Markdown("Image Segmentation", render = True)
+         with gradio.Column():
+             with gradio.Accordion("Image Annotation Documentation", open = False):
+                 gradio.Markdown("""
+                     Image annotation lets you mark regions of an image with labels.
+                     In this app, you annotate by drawing boxes and assigning each one a label: '+' for a positive point, '-' for a negative point, or no label for a bounding box.
+                     Click and drag to draw a box around the desired region; you can add multiple boxes with different labels.
+                     Once the image is annotated, click 'Generate Mask' to produce a mask from the annotations.
+                     The output is either a binary mask or a segmented image, depending on the selected output mode; switch between the two with the radio buttons.
+                     If you change the annotations, regenerate the mask by clicking the button again.
+                     The advanced options expose the SAM mask threshold, maximum hole area, and maximum sprinkle area, which control the sensitivity and accuracy of the segmentation; experiment with different settings to get the result you want.
+                 """)
+             image_input = gradio_image_annotation.image_annotator(example_image_annotation)
+             with gradio.Accordion("Advanced Options", open = False):
+                 image_generate_SAM_mask_threshold = gradio.Slider(0.0, 1.0, 0.0, label = "SAM Mask Threshold")
+                 image_generate_SAM_max_hole_area = gradio.Slider(0, 1000, 0, label = "SAM Max Hole Area")
+                 image_generate_SAM_max_sprinkle_area = gradio.Slider(0, 1000, 0, label = "SAM Max Sprinkle Area")
+             image_generate_mask_button = gradio.Button("Generate Mask")
+             image_output = gradio_imageslider.ImageSlider()
+             image_output_mode = gradio.Radio(["Segment", "Mask"], value = "Segment", label = "Output Mode")
+
+         image_input.change(__post_process_annotator_inputs, inputs = [image_input])
+         image_generate_mask_button.click(
+             __generate_mask,
+             inputs = [
+                 image_input,
+                 image_generate_SAM_mask_threshold,
+                 image_generate_SAM_max_hole_area,
+                 image_generate_SAM_max_sprinkle_area,
+                 image_output_mode
+             ],
+             outputs = [image_output]
+         )
+         image_output_mode.change(__change_output_mode, inputs = [image_input, image_output_mode], outputs = [image_output])
+     with gradio.Tab(label = "Auto Segmentation", id = "auto_tab"):
+         gradio.Markdown("Auto Segmentation", render = True)
+         with gradio.Column():
+             with gradio.Accordion("Auto Annotation Documentation", open = False):
+                 gradio.Markdown("""
+                     Auto segmentation runs SAM2's automatic mask generator over the whole image: click 'Generate Auto Mask' to produce a set of candidate masks, then use the mask list to choose which masks to display and the checkbox to overlay numbered bounding boxes.
+                     The advanced options mirror the generator's parameters (points per side, IOU and stability thresholds, crop settings, and so on).
+                 """)
+             auto_input = gradio.Image("assets/cars.jpg")
+             with gradio.Accordion("Advanced Options", open = False):
+                 auto_generate_SAM_points_per_side = gradio.Slider(1, 64, 32, 1, label = "Points Per Side", interactive = True)
+                 auto_generate_SAM_points_per_batch = gradio.Slider(1, 64, 32, 1, label = "Points Per Batch", interactive = True)
+                 auto_generate_SAM_pred_iou_thresh = gradio.Slider(0.0, 1.0, 0.8, 0.01, label = "Pred IOU Threshold", interactive = True)
+                 auto_generate_SAM_stability_score_thresh = gradio.Slider(0.0, 1.0, 0.95, label = "Stability Score Threshold", interactive = True)
+                 auto_generate_SAM_stability_score_offset = gradio.Slider(0.0, 1.0, 1.0, label = "Stability Score Offset", interactive = True)
+                 auto_generate_SAM_mask_threshold = gradio.Slider(0.0, 1.0, 0.0, label = "Mask Threshold", interactive = True)
+                 auto_generate_SAM_box_nms_thresh = gradio.Slider(0.0, 1.0, 0.7, label = "Box NMS Threshold", interactive = True)
+                 auto_generate_SAM_crop_n_layers = gradio.Slider(0, 10, 0, 1, label = "Crop N Layers", interactive = True)
+                 auto_generate_SAM_crop_nms_thresh = gradio.Slider(0.0, 1.0, 0.7, label = "Crop NMS Threshold", interactive = True)
+                 auto_generate_SAM_crop_overlap_ratio = gradio.Slider(0.0, 1.0, 512 / 1500, label = "Crop Overlap Ratio", interactive = True)
+                 auto_generate_SAM_crop_n_points_downscale_factor = gradio.Slider(1, 10, 1, label = "Crop N Points Downscale Factor", interactive = True)
+                 auto_generate_SAM_min_mask_region_area = gradio.Slider(0, 1000, 0, label = "Min Mask Region Area", interactive = True)
+                 auto_generate_SAM_use_m2m = gradio.Checkbox(label = "Use M2M", interactive = True)
+                 auto_generate_SAM_multimask_output = gradio.Checkbox(value = True, label = "Multi Mask Output", interactive = True)
+             auto_generate_button = gradio.Button("Generate Auto Mask")
+             with gradio.Row():
+                 with gradio.Column():
+                     auto_output_mode = gradio.Radio(["Segment", "Mask"], value = "Segment", label = "Output Mode", interactive = True)
+                     auto_output_list = gradio.CheckboxGroup([], value = [], label = "Mask List", interactive = False)
+                     auto_output_bbox = gradio.Checkbox(value = False, label = "Show Bounding Box", interactive = False)
+                 with gradio.Column(scale = 3):
+                     auto_output = gradio_imageslider.ImageSlider()
+
+         auto_generate_button.click(
+             __generate_auto_mask,
+             inputs = [
+                 auto_input,
+                 auto_generate_SAM_points_per_side,
+                 auto_generate_SAM_points_per_batch,
+                 auto_generate_SAM_pred_iou_thresh,
+                 auto_generate_SAM_stability_score_thresh,
+                 auto_generate_SAM_stability_score_offset,
+                 auto_generate_SAM_mask_threshold,
+                 auto_generate_SAM_box_nms_thresh,
+                 auto_generate_SAM_crop_n_layers,
+                 auto_generate_SAM_crop_nms_thresh,
+                 auto_generate_SAM_crop_overlap_ratio,
+                 auto_generate_SAM_crop_n_points_downscale_factor,
+                 auto_generate_SAM_min_mask_region_area,
+                 auto_generate_SAM_use_m2m,
+                 auto_generate_SAM_multimask_output,
+                 auto_output_mode
+             ],
+             outputs = [
+                 auto_output,
+                 auto_output_list,
+                 auto_output_bbox
+             ]
+         )
+         auto_output_list.change(__generate_multi_mask_output, inputs = [auto_input, auto_output_list, auto_output_mode, auto_output_bbox], outputs = [auto_output])
+         auto_output_bbox.change(__generate_multi_mask_output, inputs = [auto_input, auto_output_list, auto_output_mode, auto_output_bbox], outputs = [auto_output])
+         auto_output_mode.change(__generate_multi_mask_output, inputs = [auto_input, auto_output_list, auto_output_mode, auto_output_bbox], outputs = [auto_output])
+
+
+ if __name__ == "__main__":
+     base_app.launch()
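For a sense of what the annotator wiring above produces, here is a minimal standalone sketch (with hypothetical values) of the prompt extraction that `__post_process_annotator_inputs` performs before the SAM2 call:

# Hypothetical annotator payload: two labeled point boxes and one unlabeled box.
boxes = [
    {"label": "+", "xmin": 886, "ymin": 551, "xmax": 886, "ymax": 551},
    {"label": "-", "xmin": 610, "ymin": 574, "xmax": 610, "ymax": 574},
    {"label": "", "xmin": 254, "ymin": 466, "xmax": 1347, "ymax": 1047},
]

point_coords, point_labels, box_xyxy = [], [], None
for b in boxes:
    if b["label"] in ("+", "-"):
        # '+'/'-' boxes collapse to their center point; 1 = foreground, 0 = background.
        point_coords.append([(b["xmin"] + b["xmax"]) // 2, (b["ymin"] + b["ymax"]) // 2])
        point_labels.append(1 if b["label"] == "+" else 0)
    elif box_xyxy is None:
        # The first unlabeled box becomes the XYXY box prompt.
        box_xyxy = [b["xmin"], b["ymin"], b["xmax"], b["ymax"]]

print(point_coords, point_labels, box_xyxy)
# [[886, 551], [610, 574]] [1, 0] [254, 466, 1347, 1047]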
assets/cars.jpg ADDED

Git LFS Details

  • SHA256: 76e496e8975c7f21955cbe73aaa027e541fccf5169d50744e14df780717ee52a
  • Pointer size: 132 Bytes
  • Size of remote file: 1.06 MB
requirements.txt ADDED
@@ -0,0 +1,72 @@
+ aiofiles==23.2.1
+ annotated-types==0.7.0
+ antlr4-python3-runtime==4.9.3
+ anyio==4.4.0
+ certifi==2024.7.4
+ charset-normalizer==3.3.2
+ click==8.1.7
+ contourpy==1.2.1
+ cycler==0.12.1
+ fastapi==0.112.2
+ ffmpy==0.4.0
+ filelock==3.15.4
+ fonttools==4.53.1
+ fsspec==2024.6.1
+ gradio==4.42.0
+ gradio_client==1.3.0
+ gradio_image_annotation==0.2.3
+ gradio_imageslider==0.0.20
+ h11==0.14.0
+ httpcore==1.0.5
+ httpx==0.27.0
+ huggingface-hub==0.24.6
+ hydra-core==1.3.2
+ idna==3.7
+ importlib_resources==6.4.4
+ iopath==0.1.10
+ Jinja2==3.1.4
+ kiwisolver==1.4.5
+ markdown-it-py==3.0.0
+ MarkupSafe==2.1.5
+ matplotlib==3.9.2
+ mdurl==0.1.2
+ mpmath==1.3.0
+ networkx==3.2.1
+ numpy==2.1.0
+ omegaconf==2.3.0
+ opencv-python==4.10.0.84
+ orjson==3.10.7
+ packaging==24.1
+ pandas==2.2.2
+ pillow==10.4.0
+ portalocker==2.10.1
+ psutil==5.9.8
+ pydantic==2.8.2
+ pydantic_core==2.20.1
+ pydub==0.25.1
+ Pygments==2.18.0
+ pyparsing==3.1.4
+ python-dateutil==2.9.0.post0
+ python-multipart==0.0.9
+ pytz==2024.1
+ PyYAML==6.0.2
+ requests==2.32.3
+ rich==13.7.1
+ ruff==0.6.2
+ SAM-2 @ git+https://github.com/facebookresearch/segment-anything-2.git@7e1596c0b6462eb1d1ba7e1492430fed95023598
+ semantic-version==2.10.0
+ setuptools==73.0.1
+ shellingham==1.5.4
+ six==1.16.0
+ sniffio==1.3.1
+ spaces==0.29.3
+ starlette==0.38.2
+ sympy==1.13.2
+ tomlkit==0.12.0
+ tqdm==4.66.5
+ typer==0.12.5
+ typing_extensions==4.12.2
+ tzdata==2024.1
+ urllib3==2.2.2
+ uvicorn==0.30.6
+ websockets==12.0
src/SegmentAnything2Assist.py ADDED
@@ -0,0 +1,217 @@
+ import typing
+ import os
+ import pickle
+
+ import requests
+ import tqdm
+ import torch
+ import numpy
+ import cv2
+
+ import sam2.build_sam
+ import sam2.sam2_image_predictor
+ import sam2.automatic_mask_generator
+
+ SAM2_MODELS = {
+     "sam2_hiera_tiny": {
+         "download_url": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt",
+         "model_path": ".tmp/checkpoints/sam2_hiera_tiny.pt",
+         "config_file": "sam2_hiera_t.yaml"
+     },
+     "sam2_hiera_small": {
+         "download_url": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt",
+         "model_path": ".tmp/checkpoints/sam2_hiera_small.pt",
+         "config_file": "sam2_hiera_s.yaml"
+     },
+     "sam2_hiera_base_plus": {
+         "download_url": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt",
+         "model_path": ".tmp/checkpoints/sam2_hiera_base_plus.pt",
+         "config_file": "sam2_hiera_b+.yaml"
+     },
+     "sam2_hiera_large": {
+         "download_url": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt",
+         "model_path": ".tmp/checkpoints/sam2_hiera_large.pt",
+         "config_file": "sam2_hiera_l.yaml"
+     },
+ }
+
+ class SegmentAnything2Assist:
+     def __init__(
+         self,
+         model_name: str | typing.Literal["sam2_hiera_tiny", "sam2_hiera_small", "sam2_hiera_base_plus", "sam2_hiera_large"] = "sam2_hiera_small",
+         configuration: str | typing.Literal["Automatic Mask Generator", "Image"] = "Automatic Mask Generator",
+         download_url: str | None = None,
+         model_path: str | None = None,
+         download: bool = True,
+         device: str | torch.device = torch.device("cpu"),
+         verbose: bool = True
+     ) -> None:
+         assert model_name in SAM2_MODELS.keys(), f"`model_name` should be one of {list(SAM2_MODELS.keys())}"
+         assert configuration in ["Automatic Mask Generator", "Image"]
+
+         self.model_name = model_name
+         self.configuration = configuration
+         self.config_file = SAM2_MODELS[model_name]["config_file"]
+         self.device = device
+
+         self.download_url = download_url if download_url is not None else SAM2_MODELS[model_name]["download_url"]
+         self.model_path = model_path if model_path is not None else SAM2_MODELS[model_name]["model_path"]
+         os.makedirs(os.path.dirname(self.model_path), exist_ok = True)
+         self.verbose = verbose
+
+         if self.verbose:
+             print(f"SegmentAnything2Assist::__init__::Model Name: {self.model_name}")
+             print(f"SegmentAnything2Assist::__init__::Configuration: {self.configuration}")
+             print(f"SegmentAnything2Assist::__init__::Download URL: {self.download_url}")
+             print(f"SegmentAnything2Assist::__init__::Model Path: {self.model_path}")
+             print(f"SegmentAnything2Assist::__init__::Configuration File: {self.config_file}")
+
+         if download:
+             self.download_model()
+
+         if self.is_model_available():
+             self.sam2 = sam2.build_sam.build_sam2(config_file = self.config_file, ckpt_path = self.model_path, device = self.device)
+             if self.verbose:
+                 print("SegmentAnything2Assist::__init__::SAM2 is loaded.")
+         else:
+             self.sam2 = None
+             if self.verbose:
+                 print("SegmentAnything2Assist::__init__::SAM2 is not loaded.")
+
+     def is_model_available(self) -> bool:
+         ret = os.path.exists(self.model_path)
+         if self.verbose:
+             print(f"SegmentAnything2Assist::is_model_available::{ret}")
+         return ret
+
+     def load_model(self) -> None:
+         if self.is_model_available():
+             # Mirror __init__: `sam2.build_sam` is a module, not a callable, so build from config and checkpoint.
+             self.sam2 = sam2.build_sam.build_sam2(config_file = self.config_file, ckpt_path = self.model_path, device = self.device)
+
+     def download_model(
+         self,
+         force: bool = False
+     ) -> None:
+         if not force and self.is_model_available():
+             print(f"{self.model_path} already exists. Skipping download.")
+             return
+
+         response = requests.get(self.download_url, stream = True)
+         total_size = int(response.headers.get('content-length', 0))
+
+         with open(self.model_path, 'wb') as file, tqdm.tqdm(total = total_size, unit = 'B', unit_scale = True) as progress_bar:
+             for data in response.iter_content(chunk_size = 1024):
+                 file.write(data)
+                 progress_bar.update(len(data))
+
+     def generate_automatic_masks(
+         self,
+         image,
+         points_per_side = 32,
+         points_per_batch = 32,
+         pred_iou_thresh = 0.8,
+         stability_score_thresh = 0.95,
+         stability_score_offset = 1.0,
+         mask_threshold = 0.0,
+         box_nms_thresh = 0.7,
+         crop_n_layers = 0,
+         crop_nms_thresh = 0.7,
+         crop_overlap_ratio = 512 / 1500,
+         crop_n_points_downscale_factor = 1,
+         min_mask_region_area = 0,
+         use_m2m = False,
+         multimask_output = True
+     ):
+         if self.sam2 is None:
+             print("SegmentAnything2Assist::generate_automatic_masks::SAM2 is not loaded.")
+             return None
+
+         generator = sam2.automatic_mask_generator.SAM2AutomaticMaskGenerator(
+             model = self.sam2,
+             points_per_side = points_per_side,
+             points_per_batch = points_per_batch,
+             pred_iou_thresh = pred_iou_thresh,
+             stability_score_thresh = stability_score_thresh,
+             stability_score_offset = stability_score_offset,
+             mask_threshold = mask_threshold,
+             box_nms_thresh = box_nms_thresh,
+             crop_n_layers = crop_n_layers,
+             crop_nms_thresh = crop_nms_thresh,
+             crop_overlap_ratio = crop_overlap_ratio,
+             crop_n_points_downscale_factor = crop_n_points_downscale_factor,
+             min_mask_region_area = min_mask_region_area,
+             use_m2m = use_m2m,
+             multimask_output = multimask_output
+         )
+         masks = generator.generate(image)
+
+         # Cache the masks so apply_auto_mask_to_image can re-render selections without regenerating.
+         os.makedirs(".tmp", exist_ok = True)
+         pickle.dump(masks, open(".tmp/auto_masks.pkl", "wb"))
+
+         return masks
+
+     def generate_masks_from_image(
+         self,
+         image,
+         point_coords,
+         point_labels,
+         box,
+         mask_threshold = 0.0,
+         max_hole_area = 0.0,
+         max_sprinkle_area = 0.0
+     ):
+         generator = sam2.sam2_image_predictor.SAM2ImagePredictor(
+             self.sam2,
+             mask_threshold = mask_threshold,
+             max_hole_area = max_hole_area,
+             max_sprinkle_area = max_sprinkle_area
+         )
+         generator.set_image(image)
+
+         masks_chw, mask_iou, mask_low_logits = generator.predict(
+             point_coords = numpy.array(point_coords) if point_coords is not None else None,
+             point_labels = numpy.array(point_labels) if point_labels is not None else None,
+             box = numpy.array(box) if box is not None else None,
+             multimask_output = False
+         )
+
+         return masks_chw, mask_iou
+
+     def apply_mask_to_image(
+         self,
+         image,
+         mask
+     ):
+         mask = numpy.array(mask)
+         mask = numpy.where(mask > 0, 255, 0).astype(numpy.uint8)
+         segment = cv2.bitwise_and(image, image, mask = mask)
+         return mask, segment
+
+     def apply_auto_mask_to_image(
+         self,
+         image,
+         auto_list
+     ):
+         # The cache is written by generate_automatic_masks; bail out if it is missing.
+         if not os.path.exists(".tmp/auto_masks.pkl"):
+             return
+
+         masks = pickle.load(open(".tmp/auto_masks.pkl", "rb"))
+
+         image_with_bounding_boxes = image.copy()
+         all_masks = None
+         for index in auto_list:
+             mask = numpy.array(masks[index]['segmentation'])
+             mask = numpy.where(mask, 255, 0).astype(numpy.uint8)
+             bbox = masks[index]["bbox"]
+             if all_masks is None:
+                 all_masks = mask
+             else:
+                 all_masks = cv2.bitwise_or(all_masks, mask)
+
+             random_color = numpy.random.randint(0, 255, size = 3)
+             image_with_bounding_boxes = cv2.rectangle(image_with_bounding_boxes, (int(bbox[0]), int(bbox[1])), (int(bbox[0] + bbox[2]), int(bbox[1] + bbox[3])), random_color.tolist(), 2)
+             image_with_bounding_boxes = cv2.putText(image_with_bounding_boxes, f"{index + 1}", (int(bbox[0]), int(bbox[1]) - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, random_color.tolist(), 2)
+
+         all_masks = numpy.where(all_masks > 0, 255, 0).astype(numpy.uint8)
+         image_with_segments = cv2.bitwise_and(image, image, mask = all_masks)
+         return image_with_bounding_boxes, all_masks, image_with_segments
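The helper class can also be used headlessly, outside the Gradio app. A minimal sketch, assuming the checkpoint download succeeds and `assets/cars.jpg` is present; CPU is used for portability, and the point prompt is taken from the example annotation above:

import cv2
import src.SegmentAnything2Assist as SegmentAnything2Assist

assistant = SegmentAnything2Assist.SegmentAnything2Assist(model_name = "sam2_hiera_tiny", device = "cpu")

# SAM2 expects RGB images; OpenCV loads BGR, so convert first.
image = cv2.cvtColor(cv2.imread("assets/cars.jpg"), cv2.COLOR_BGR2RGB)

masks_chw, mask_iou = assistant.generate_masks_from_image(
    image,
    point_coords = [[886, 551]],   # one positive point prompt
    point_labels = [1],
    box = None
)
mask, segment = assistant.apply_mask_to_image(image, masks_chw[0])
cv2.imwrite("mask.png", mask)
cv2.imwrite("segment.png", cv2.cvtColor(segment, cv2.COLOR_RGB2BGR))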