aakashv100 commited on
Commit
eb5db2c
·
1 Parent(s): 244f0a7
app.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from ultralytics import YOLO
3
+ import torch
4
+ import numpy as np
5
+ from utils.tools_gradio import fast_process
6
+ from utils.tools import format_results
7
+
8
+
9
+ # Load the FastSAM model
10
+ model = YOLO("./weights/FastSAM.pt")
11
+
12
+ device = torch.device("cpu")
13
+
14
+ model.to(device)
15
+
16
+
17
+ def get_input_scale(input, input_size=1024):
18
+
19
+ input_size = int(input_size)
20
+ w, h = input.size
21
+ scale = input_size / max(w, h)
22
+ new_w = int(w * scale)
23
+ new_h = int(h * scale)
24
+ input = input.resize((new_w, new_h))
25
+
26
+ return input, input_size
27
+
28
+
29
+ def segment_everything(
30
+ input,
31
+ iou_threshold=0.9,
32
+ confidence_threshold=0.4
33
+ ):
34
+
35
+ input, input_size = get_input_scale(input)
36
+
37
+ results = model(
38
+ input,
39
+ device=device,
40
+ retina_masks=True,
41
+ iou=iou_threshold,
42
+ conf=confidence_threshold,
43
+ imgsz=input_size,
44
+ )
45
+
46
+ annotations = results[0].masks.data
47
+
48
+ fig = fast_process(
49
+ annotations=annotations,
50
+ image=input,
51
+ device=device,
52
+ scale=(1024 // input_size),
53
+ better_quality=False,
54
+ mask_random_color=True,
55
+ bbox=None,
56
+ use_retina=True,
57
+ withContours=True,
58
+ )
59
+
60
+ return fig
61
+
62
+
63
+ title = "FastSAM: Fast Segment Anything"
64
+
65
+ description_e = "Demo project of FastSAM. Adapted from Ultralytics. CPU only."
66
+
67
+ examples = [
68
+ ["examples/sa_8776.jpg"],
69
+ ["examples/sa_414.jpg"],
70
+ ["examples/sa_1309.jpg"],
71
+ ["examples/sa_11025.jpg"],
72
+ ["examples/sa_561.jpg"],
73
+ ["examples/sa_192.jpg"],
74
+ ["examples/sa_10039.jpg"],
75
+ ["examples/sa_862.jpg"],
76
+ ]
77
+
78
+ default_example = examples[0]
79
+
80
+ cond_img_e = gr.Image(label="Input", value=default_example[0], type="pil")
81
+ segm_img_e = gr.Image(label="Segmented Image", interactive=False, type="pil")
82
+
83
+
84
+ css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
85
+
86
+
87
+ with gr.Blocks(css=css, title="Fast Segment Anything") as demo:
88
+ with gr.Row():
89
+ with gr.Column(scale=1):
90
+ # Title
91
+ gr.Markdown(title)
92
+
93
+ with gr.Column(scale=1):
94
+ # News
95
+ gr.Markdown(description_e)
96
+
97
+ with gr.Tab("Everything mode"):
98
+ # Images
99
+ with gr.Row(variant="panel"):
100
+ with gr.Column(scale=1):
101
+ cond_img_e.render()
102
+
103
+ with gr.Column(scale=1):
104
+ segm_img_e.render()
105
+
106
+ # Submit & Clear
107
+ with gr.Row():
108
+
109
+ with gr.Column():
110
+ segment_btn_e = gr.Button(
111
+ "Segment Everything", variant="primary"
112
+ )
113
+ clear_btn_e = gr.Button("Clear", variant="secondary")
114
+
115
+ gr.Markdown("Try some of the examples below ⬇️")
116
+ gr.Examples(
117
+ examples=examples,
118
+ inputs=[cond_img_e],
119
+ outputs=segm_img_e,
120
+ fn=segment_everything,
121
+ cache_examples=True,
122
+ examples_per_page=4,
123
+ )
124
+
125
+ with gr.Column():
126
+ with gr.Accordion("Advanced options", open=False):
127
+ iou_threshold = gr.Slider(
128
+ 0.1,
129
+ 0.9,
130
+ 0.7,
131
+ step=0.1,
132
+ label="iou",
133
+ info="iou threshold for filtering the annotations",
134
+ )
135
+ conf_threshold = gr.Slider(
136
+ 0.1,
137
+ 0.9,
138
+ 0.25,
139
+ step=0.05,
140
+ label="conf",
141
+ info="object confidence threshold",
142
+ )
143
+
144
+ # Description
145
+ gr.Markdown(description_e)
146
+
147
+ segment_btn_e.click(
148
+ segment_everything,
149
+ inputs=[cond_img_e, iou_threshold, conf_threshold],
150
+ outputs=segm_img_e,
151
+ )
152
+
153
+ def clear():
154
+ return None, None
155
+
156
+ clear_btn_e.click(clear, outputs=[cond_img_e, segm_img_e])
157
+
158
+
159
+ demo.queue()
160
+ demo.launch(debug=True)
examples/dogs.jpg ADDED

Git LFS Details

  • SHA256: 49b29517d3a6457bf8bd0b83a80cbeb24c2466bf3e5804bd503ebe60e430d784
  • Pointer size: 131 Bytes
  • Size of remote file: 449 kB
examples/sa_10039.jpg ADDED

Git LFS Details

  • SHA256: 4a9735583a997fa08e5eb36b3ba8bf17a31771bb2aea71e6d51ab9824c1d141e
  • Pointer size: 131 Bytes
  • Size of remote file: 391 kB
examples/sa_11025.jpg ADDED

Git LFS Details

  • SHA256: b7edd63aa5121414bc29a760770606d09387561ff990c89f9b82c35803bd20aa
  • Pointer size: 131 Bytes
  • Size of remote file: 988 kB
examples/sa_1309.jpg ADDED

Git LFS Details

  • SHA256: b1012cbfd3ffe4ee0da940dc45961fbd1ce7546bea566f650514ec56d72b0460
  • Pointer size: 132 Bytes
  • Size of remote file: 1.11 MB
examples/sa_192.jpg ADDED

Git LFS Details

  • SHA256: dcec4fce91382cbfeb2711fff3caeae183c23cb6d8a6c9e2ca0cd2e8eac39512
  • Pointer size: 132 Bytes
  • Size of remote file: 1.21 MB
examples/sa_414.jpg ADDED

Git LFS Details

  • SHA256: 69dbead40b43e54d3bb80fb372c2e241b0f3ff2159d32525433a75153e067c65
  • Pointer size: 132 Bytes
  • Size of remote file: 2.23 MB
examples/sa_561.jpg ADDED

Git LFS Details

  • SHA256: 837d725885e427534623dcc7d82ea846fffea046877c94e2e9c5b027d593796b
  • Pointer size: 131 Bytes
  • Size of remote file: 822 kB
examples/sa_862.jpg ADDED

Git LFS Details

  • SHA256: 06efc970f0d95faa6e8c69ee73f2032627569dde1c28bc783faebdaefa5eb2a8
  • Pointer size: 132 Bytes
  • Size of remote file: 1.56 MB
examples/sa_8776.jpg ADDED

Git LFS Details

  • SHA256: 7d71aea32d9f14122378a0707a4243de968d87b292a20a905351b5eacd924212
  • Pointer size: 131 Bytes
  • Size of remote file: 471 kB
requirements.txt ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Base-----------------------------------
2
+ matplotlib>=3.2.2
3
+ opencv-python>=4.6.0
4
+ Pillow>=7.1.2
5
+ PyYAML>=5.3.1
6
+ requests>=2.23.0
7
+ scipy>=1.4.1
8
+ torch>=1.7.0
9
+ torchvision>=0.8.1
10
+ tqdm>=4.64.0
11
+
12
+ pandas>=1.1.4
13
+ seaborn>=0.11.0
14
+
15
+ gradio==3.35.2
16
+
17
+ # Ultralytics-----------------------------------
18
+ ultralytics >= 8.0.120
19
+
20
+ git+https://github.com/openai/CLIP.git
utils/__pycache__/tools.cpython-311.pyc ADDED
Binary file (27.5 kB). View file
 
utils/__pycache__/tools_gradio.cpython-311.pyc ADDED
Binary file (9.32 kB). View file
 
utils/tools.py ADDED
@@ -0,0 +1,452 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from PIL import Image
3
+ import matplotlib.pyplot as plt
4
+ import cv2
5
+ import torch
6
+ import os
7
+ import sys
8
+
9
+
10
+ def convert_box_xywh_to_xyxy(box):
11
+ if len(box) == 4:
12
+ return [box[0], box[1], box[0] + box[2], box[1] + box[3]]
13
+ else:
14
+ result = []
15
+ for b in box:
16
+ b = convert_box_xywh_to_xyxy(b)
17
+ result.append(b)
18
+ return result
19
+
20
+
21
+ def segment_image(image, bbox):
22
+ image_array = np.array(image)
23
+ segmented_image_array = np.zeros_like(image_array)
24
+ x1, y1, x2, y2 = bbox
25
+ segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2]
26
+ segmented_image = Image.fromarray(segmented_image_array)
27
+ black_image = Image.new("RGB", image.size, (255, 255, 255))
28
+ # transparency_mask = np.zeros_like((), dtype=np.uint8)
29
+ transparency_mask = np.zeros(
30
+ (image_array.shape[0], image_array.shape[1]), dtype=np.uint8
31
+ )
32
+ transparency_mask[y1:y2, x1:x2] = 255
33
+ transparency_mask_image = Image.fromarray(transparency_mask, mode="L")
34
+ black_image.paste(segmented_image, mask=transparency_mask_image)
35
+ return black_image
36
+
37
+
38
+ def format_results(result, filter=0):
39
+ annotations = []
40
+ n = len(result.masks.data)
41
+ for i in range(n):
42
+ annotation = {}
43
+ mask = result.masks.data[i] == 1.0
44
+
45
+ if torch.sum(mask) < filter:
46
+ continue
47
+ annotation["id"] = i
48
+ annotation["segmentation"] = mask.cpu().numpy()
49
+ annotation["bbox"] = result.boxes.data[i]
50
+ annotation["score"] = result.boxes.conf[i]
51
+ annotation["area"] = annotation["segmentation"].sum()
52
+ annotations.append(annotation)
53
+ return annotations
54
+
55
+
56
+ def filter_masks(annotations): # filter the overlap mask
57
+ annotations.sort(key=lambda x: x["area"], reverse=True)
58
+ to_remove = set()
59
+ for i in range(0, len(annotations)):
60
+ a = annotations[i]
61
+ for j in range(i + 1, len(annotations)):
62
+ b = annotations[j]
63
+ if i != j and j not in to_remove:
64
+ # check if
65
+ if b["area"] < a["area"]:
66
+ if (a["segmentation"] & b["segmentation"]).sum() / b[
67
+ "segmentation"
68
+ ].sum() > 0.8:
69
+ to_remove.add(j)
70
+
71
+ return [a for i, a in enumerate(annotations) if i not in to_remove], to_remove
72
+
73
+
74
+ def get_bbox_from_mask(mask):
75
+ mask = mask.astype(np.uint8)
76
+ contours, hierarchy = cv2.findContours(
77
+ mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
78
+ )
79
+ x1, y1, w, h = cv2.boundingRect(contours[0])
80
+ x2, y2 = x1 + w, y1 + h
81
+ if len(contours) > 1:
82
+ for b in contours:
83
+ x_t, y_t, w_t, h_t = cv2.boundingRect(b)
84
+ # 将多个bbox合并成一个
85
+ x1 = min(x1, x_t)
86
+ y1 = min(y1, y_t)
87
+ x2 = max(x2, x_t + w_t)
88
+ y2 = max(y2, y_t + h_t)
89
+ h = y2 - y1
90
+ w = x2 - x1
91
+ return [x1, y1, x2, y2]
92
+
93
+
94
+ def fast_process(
95
+ annotations, args, mask_random_color, bbox=None, points=None, edges=False
96
+ ):
97
+ if isinstance(annotations[0], dict):
98
+ annotations = [annotation["segmentation"] for annotation in annotations]
99
+ result_name = os.path.basename(args.img_path)
100
+ image = cv2.imread(args.img_path)
101
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
102
+ original_h = image.shape[0]
103
+ original_w = image.shape[1]
104
+ if sys.platform == "darwin":
105
+ plt.switch_backend("TkAgg")
106
+ plt.figure(figsize=(original_w / 100, original_h / 100))
107
+ # Add subplot with no margin.
108
+ plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
109
+ plt.margins(0, 0)
110
+ plt.gca().xaxis.set_major_locator(plt.NullLocator())
111
+ plt.gca().yaxis.set_major_locator(plt.NullLocator())
112
+ plt.imshow(image)
113
+ if args.better_quality == True:
114
+ if isinstance(annotations[0], torch.Tensor):
115
+ annotations = np.array(annotations.cpu())
116
+ for i, mask in enumerate(annotations):
117
+ mask = cv2.morphologyEx(
118
+ mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8)
119
+ )
120
+ annotations[i] = cv2.morphologyEx(
121
+ mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8)
122
+ )
123
+ if args.device == "cpu":
124
+ annotations = np.array(annotations)
125
+ fast_show_mask(
126
+ annotations,
127
+ plt.gca(),
128
+ random_color=mask_random_color,
129
+ bbox=bbox,
130
+ points=points,
131
+ point_label=args.point_label,
132
+ retinamask=args.retina,
133
+ target_height=original_h,
134
+ target_width=original_w,
135
+ )
136
+ else:
137
+ if isinstance(annotations[0], np.ndarray):
138
+ annotations = torch.from_numpy(annotations)
139
+ fast_show_mask_gpu(
140
+ annotations,
141
+ plt.gca(),
142
+ random_color=args.randomcolor,
143
+ bbox=bbox,
144
+ points=points,
145
+ point_label=args.point_label,
146
+ retinamask=args.retina,
147
+ target_height=original_h,
148
+ target_width=original_w,
149
+ )
150
+ if isinstance(annotations, torch.Tensor):
151
+ annotations = annotations.cpu().numpy()
152
+ if args.withContours == True:
153
+ contour_all = []
154
+ temp = np.zeros((original_h, original_w, 1))
155
+ for i, mask in enumerate(annotations):
156
+ if type(mask) == dict:
157
+ mask = mask["segmentation"]
158
+ annotation = mask.astype(np.uint8)
159
+ if args.retina == False:
160
+ annotation = cv2.resize(
161
+ annotation,
162
+ (original_w, original_h),
163
+ interpolation=cv2.INTER_NEAREST,
164
+ )
165
+ contours, hierarchy = cv2.findContours(
166
+ annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
167
+ )
168
+ for contour in contours:
169
+ contour_all.append(contour)
170
+ cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
171
+ color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8])
172
+ contour_mask = temp / 255 * color.reshape(1, 1, -1)
173
+ plt.imshow(contour_mask)
174
+
175
+ save_path = args.output
176
+ if not os.path.exists(save_path):
177
+ os.makedirs(save_path)
178
+ plt.axis("off")
179
+ fig = plt.gcf()
180
+ plt.draw()
181
+
182
+ try:
183
+ buf = fig.canvas.tostring_rgb()
184
+ except AttributeError:
185
+ fig.canvas.draw()
186
+ buf = fig.canvas.tostring_rgb()
187
+
188
+ cols, rows = fig.canvas.get_width_height()
189
+ img_array = np.fromstring(buf, dtype=np.uint8).reshape(rows, cols, 3)
190
+ cv2.imwrite(
191
+ os.path.join(save_path, result_name), cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
192
+ )
193
+
194
+
195
+ # CPU post process
196
+ def fast_show_mask(
197
+ annotation,
198
+ ax,
199
+ random_color=False,
200
+ bbox=None,
201
+ points=None,
202
+ point_label=None,
203
+ retinamask=True,
204
+ target_height=960,
205
+ target_width=960,
206
+ ):
207
+ msak_sum = annotation.shape[0]
208
+ height = annotation.shape[1]
209
+ weight = annotation.shape[2]
210
+ # 将annotation 按照面积 排序
211
+ areas = np.sum(annotation, axis=(1, 2))
212
+ sorted_indices = np.argsort(areas)
213
+ annotation = annotation[sorted_indices]
214
+
215
+ index = (annotation != 0).argmax(axis=0)
216
+ if random_color == True:
217
+ color = np.random.random((msak_sum, 1, 1, 3))
218
+ else:
219
+ color = np.ones((msak_sum, 1, 1, 3)) * np.array(
220
+ [30 / 255, 144 / 255, 255 / 255]
221
+ )
222
+ transparency = np.ones((msak_sum, 1, 1, 1)) * 0.6
223
+ visual = np.concatenate([color, transparency], axis=-1)
224
+ mask_image = np.expand_dims(annotation, -1) * visual
225
+
226
+ show = np.zeros((height, weight, 4))
227
+ h_indices, w_indices = np.meshgrid(
228
+ np.arange(height), np.arange(weight), indexing="ij"
229
+ )
230
+ indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
231
+ # 使用向量化索引更新show的值
232
+ show[h_indices, w_indices, :] = mask_image[indices]
233
+ if bbox is not None:
234
+ x1, y1, x2, y2 = bbox
235
+ ax.add_patch(
236
+ plt.Rectangle(
237
+ (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
238
+ )
239
+ )
240
+ # draw point
241
+ if points is not None:
242
+ plt.scatter(
243
+ [point[0] for i, point in enumerate(points) if point_label[i] == 1],
244
+ [point[1] for i, point in enumerate(points) if point_label[i] == 1],
245
+ s=20,
246
+ c="y",
247
+ )
248
+ plt.scatter(
249
+ [point[0] for i, point in enumerate(points) if point_label[i] == 0],
250
+ [point[1] for i, point in enumerate(points) if point_label[i] == 0],
251
+ s=20,
252
+ c="m",
253
+ )
254
+
255
+ if retinamask == False:
256
+ show = cv2.resize(
257
+ show, (target_width, target_height), interpolation=cv2.INTER_NEAREST
258
+ )
259
+ ax.imshow(show)
260
+
261
+
262
+ def fast_show_mask_gpu(
263
+ annotation,
264
+ ax,
265
+ random_color=False,
266
+ bbox=None,
267
+ points=None,
268
+ point_label=None,
269
+ retinamask=True,
270
+ target_height=960,
271
+ target_width=960,
272
+ ):
273
+ msak_sum = annotation.shape[0]
274
+ height = annotation.shape[1]
275
+ weight = annotation.shape[2]
276
+ areas = torch.sum(annotation, dim=(1, 2))
277
+ sorted_indices = torch.argsort(areas, descending=False)
278
+ annotation = annotation[sorted_indices]
279
+ # 找每个位置第一个非零值下标
280
+ index = (annotation != 0).to(torch.long).argmax(dim=0)
281
+ if random_color == True:
282
+ color = torch.rand((msak_sum, 1, 1, 3)).to(annotation.device)
283
+ else:
284
+ color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor(
285
+ [30 / 255, 144 / 255, 255 / 255]
286
+ ).to(annotation.device)
287
+ transparency = torch.ones((msak_sum, 1, 1, 1)).to(annotation.device) * 0.6
288
+ visual = torch.cat([color, transparency], dim=-1)
289
+ mask_image = torch.unsqueeze(annotation, -1) * visual
290
+ # 按index取数,index指每个位置选哪个batch的数,把mask_image转成一个batch的形式
291
+ show = torch.zeros((height, weight, 4)).to(annotation.device)
292
+ h_indices, w_indices = torch.meshgrid(
293
+ torch.arange(height), torch.arange(weight), indexing="ij"
294
+ )
295
+ indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
296
+ # 使用向量化索引更新show的值
297
+ show[h_indices, w_indices, :] = mask_image[indices]
298
+ show_cpu = show.cpu().numpy()
299
+ if bbox is not None:
300
+ x1, y1, x2, y2 = bbox
301
+ ax.add_patch(
302
+ plt.Rectangle(
303
+ (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
304
+ )
305
+ )
306
+ # draw point
307
+ if points is not None:
308
+ plt.scatter(
309
+ [point[0] for i, point in enumerate(points) if point_label[i] == 1],
310
+ [point[1] for i, point in enumerate(points) if point_label[i] == 1],
311
+ s=20,
312
+ c="y",
313
+ )
314
+ plt.scatter(
315
+ [point[0] for i, point in enumerate(points) if point_label[i] == 0],
316
+ [point[1] for i, point in enumerate(points) if point_label[i] == 0],
317
+ s=20,
318
+ c="m",
319
+ )
320
+ if retinamask == False:
321
+ show_cpu = cv2.resize(
322
+ show_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST
323
+ )
324
+ ax.imshow(show_cpu)
325
+
326
+
327
+ # clip
328
+ @torch.no_grad()
329
+ def retriev(model, preprocess, elements: [Image.Image], search_text: str, device):
330
+ preprocessed_images = [preprocess(image).to(device) for image in elements]
331
+ import clip
332
+
333
+ tokenized_text = clip.tokenize([search_text]).to(device)
334
+ stacked_images = torch.stack(preprocessed_images)
335
+ image_features = model.encode_image(stacked_images)
336
+ text_features = model.encode_text(tokenized_text)
337
+ image_features /= image_features.norm(dim=-1, keepdim=True)
338
+ text_features /= text_features.norm(dim=-1, keepdim=True)
339
+ probs = 100.0 * image_features @ text_features.T
340
+ return probs[:, 0].softmax(dim=0)
341
+
342
+
343
+ def crop_image(annotations, image_like):
344
+ if isinstance(image_like, str):
345
+ image = Image.open(image_like)
346
+ else:
347
+ image = image_like
348
+ ori_w, ori_h = image.size
349
+ mask_h, mask_w = annotations[0]["segmentation"].shape
350
+ if ori_w != mask_w or ori_h != mask_h:
351
+ image = image.resize((mask_w, mask_h))
352
+ cropped_boxes = []
353
+ cropped_images = []
354
+ not_crop = []
355
+ origin_id = []
356
+ for _, mask in enumerate(annotations):
357
+ if np.sum(mask["segmentation"]) <= 100:
358
+ continue
359
+ origin_id.append(_)
360
+ bbox = get_bbox_from_mask(mask["segmentation"]) # mask 的 bbox
361
+ cropped_boxes.append(segment_image(image, bbox)) # 保存裁剪的图片
362
+ # cropped_boxes.append(segment_image(image,mask["segmentation"]))
363
+ cropped_images.append(bbox) # 保存裁剪的图片的bbox
364
+ return cropped_boxes, cropped_images, not_crop, origin_id, annotations
365
+
366
+
367
+ def box_prompt(masks, bbox, target_height, target_width):
368
+ h = masks.shape[1]
369
+ w = masks.shape[2]
370
+ if h != target_height or w != target_width:
371
+ bbox = [
372
+ int(bbox[0] * w / target_width),
373
+ int(bbox[1] * h / target_height),
374
+ int(bbox[2] * w / target_width),
375
+ int(bbox[3] * h / target_height),
376
+ ]
377
+ bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0
378
+ bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0
379
+ bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w
380
+ bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h
381
+
382
+ # IoUs = torch.zeros(len(masks), dtype=torch.float32)
383
+ bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
384
+
385
+ masks_area = torch.sum(masks[:, bbox[1] : bbox[3], bbox[0] : bbox[2]], dim=(1, 2))
386
+ orig_masks_area = torch.sum(masks, dim=(1, 2))
387
+
388
+ union = bbox_area + orig_masks_area - masks_area
389
+ IoUs = masks_area / union
390
+ max_iou_index = torch.argmax(IoUs)
391
+
392
+ return masks[max_iou_index].cpu().numpy(), max_iou_index
393
+
394
+
395
+ def point_prompt(masks, points, point_label, target_height, target_width): # numpy 处理
396
+ h = masks[0]["segmentation"].shape[0]
397
+ w = masks[0]["segmentation"].shape[1]
398
+ if h != target_height or w != target_width:
399
+ points = [
400
+ [int(point[0] * w / target_width), int(point[1] * h / target_height)]
401
+ for point in points
402
+ ]
403
+ onemask = np.zeros((h, w))
404
+ masks = sorted(masks, key=lambda x: x["area"], reverse=True)
405
+ for i, annotation in enumerate(masks):
406
+ if type(annotation) == dict:
407
+ mask = annotation["segmentation"]
408
+ else:
409
+ mask = annotation
410
+ for i, point in enumerate(points):
411
+ if mask[point[1], point[0]] == 1 and point_label[i] == 1:
412
+ onemask[mask] = 1
413
+ if mask[point[1], point[0]] == 1 and point_label[i] == 0:
414
+ onemask[mask] = 0
415
+ onemask = onemask >= 1
416
+ return onemask, 0
417
+
418
+
419
+ def text_prompt(annotations, text, img_path, device, wider=False, threshold=0.9):
420
+ cropped_boxes, cropped_images, not_crop, origin_id, annotations_ = crop_image(
421
+ annotations, img_path
422
+ )
423
+
424
+ import clip
425
+
426
+ clip_model, preprocess = clip.load("ViT-B/32", device=device)
427
+ scores = retriev(clip_model, preprocess, cropped_boxes, text, device=device)
428
+ max_idx = scores.argsort()
429
+ max_idx = max_idx[-1]
430
+ max_idx = origin_id[int(max_idx)]
431
+
432
+ # find the biggest mask which contains the mask with max score
433
+ if wider:
434
+ mask0 = annotations_[max_idx]["segmentation"]
435
+ area0 = np.sum(mask0)
436
+ areas = [
437
+ (i, np.sum(mask["segmentation"]))
438
+ for i, mask in enumerate(annotations_)
439
+ if i in origin_id
440
+ ]
441
+ areas = sorted(areas, key=lambda area: area[1], reverse=True)
442
+ indices = [area[0] for area in areas]
443
+ for index in indices:
444
+ if (
445
+ index == max_idx
446
+ or np.sum(annotations_[index]["segmentation"] & mask0) / area0
447
+ > threshold
448
+ ):
449
+ max_idx = index
450
+ break
451
+
452
+ return annotations_[max_idx]["segmentation"], max_idx
utils/tools_gradio.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from PIL import Image
3
+ import matplotlib.pyplot as plt
4
+ import cv2
5
+ import torch
6
+
7
+
8
+ def fast_process(
9
+ annotations,
10
+ image,
11
+ device,
12
+ scale,
13
+ better_quality=False,
14
+ mask_random_color=True,
15
+ bbox=None,
16
+ use_retina=True,
17
+ withContours=True,
18
+ ):
19
+ if isinstance(annotations[0], dict):
20
+ annotations = [annotation['segmentation'] for annotation in annotations]
21
+
22
+ original_h = image.height
23
+ original_w = image.width
24
+ if better_quality:
25
+ if isinstance(annotations[0], torch.Tensor):
26
+ annotations = np.array(annotations.cpu())
27
+ for i, mask in enumerate(annotations):
28
+ mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
29
+ annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
30
+ if device == 'cpu':
31
+ annotations = np.array(annotations)
32
+ inner_mask = fast_show_mask(
33
+ annotations,
34
+ plt.gca(),
35
+ random_color=mask_random_color,
36
+ bbox=bbox,
37
+ retinamask=use_retina,
38
+ target_height=original_h,
39
+ target_width=original_w,
40
+ )
41
+ else:
42
+ if isinstance(annotations[0], np.ndarray):
43
+ annotations = torch.from_numpy(annotations)
44
+ inner_mask = fast_show_mask_gpu(
45
+ annotations,
46
+ plt.gca(),
47
+ random_color=mask_random_color,
48
+ bbox=bbox,
49
+ retinamask=use_retina,
50
+ target_height=original_h,
51
+ target_width=original_w,
52
+ )
53
+ if isinstance(annotations, torch.Tensor):
54
+ annotations = annotations.cpu().numpy()
55
+
56
+ if withContours:
57
+ contour_all = []
58
+ temp = np.zeros((original_h, original_w, 1))
59
+ for i, mask in enumerate(annotations):
60
+ if type(mask) == dict:
61
+ mask = mask['segmentation']
62
+ annotation = mask.astype(np.uint8)
63
+ if use_retina == False:
64
+ annotation = cv2.resize(
65
+ annotation,
66
+ (original_w, original_h),
67
+ interpolation=cv2.INTER_NEAREST,
68
+ )
69
+ contours, _ = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
70
+ for contour in contours:
71
+ contour_all.append(contour)
72
+ cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2 // scale)
73
+ color = np.array([0 / 255, 0 / 255, 255 / 255, 0.9])
74
+ contour_mask = temp / 255 * color.reshape(1, 1, -1)
75
+
76
+ image = image.convert('RGBA')
77
+ overlay_inner = Image.fromarray((inner_mask * 255).astype(np.uint8), 'RGBA')
78
+ image.paste(overlay_inner, (0, 0), overlay_inner)
79
+
80
+ if withContours:
81
+ overlay_contour = Image.fromarray((contour_mask * 255).astype(np.uint8), 'RGBA')
82
+ image.paste(overlay_contour, (0, 0), overlay_contour)
83
+
84
+ return image
85
+
86
+
87
+ # CPU post process
88
+ def fast_show_mask(
89
+ annotation,
90
+ ax,
91
+ random_color=False,
92
+ bbox=None,
93
+ retinamask=True,
94
+ target_height=960,
95
+ target_width=960,
96
+ ):
97
+ mask_sum = annotation.shape[0]
98
+ height = annotation.shape[1]
99
+ weight = annotation.shape[2]
100
+ # 将annotation 按照面积 排序
101
+ areas = np.sum(annotation, axis=(1, 2))
102
+ sorted_indices = np.argsort(areas)[::1]
103
+ annotation = annotation[sorted_indices]
104
+
105
+ index = (annotation != 0).argmax(axis=0)
106
+ if random_color:
107
+ color = np.random.random((mask_sum, 1, 1, 3))
108
+ else:
109
+ color = np.ones((mask_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255])
110
+ transparency = np.ones((mask_sum, 1, 1, 1)) * 0.6
111
+ visual = np.concatenate([color, transparency], axis=-1)
112
+ mask_image = np.expand_dims(annotation, -1) * visual
113
+
114
+ mask = np.zeros((height, weight, 4))
115
+
116
+ h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(weight), indexing='ij')
117
+ indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
118
+
119
+ mask[h_indices, w_indices, :] = mask_image[indices]
120
+ if bbox is not None:
121
+ x1, y1, x2, y2 = bbox
122
+ ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
123
+
124
+ if not retinamask:
125
+ mask = cv2.resize(mask, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
126
+
127
+ return mask
128
+
129
+
130
+ def fast_show_mask_gpu(
131
+ annotation,
132
+ ax,
133
+ random_color=False,
134
+ bbox=None,
135
+ retinamask=True,
136
+ target_height=960,
137
+ target_width=960,
138
+ ):
139
+ device = annotation.device
140
+ mask_sum = annotation.shape[0]
141
+ height = annotation.shape[1]
142
+ weight = annotation.shape[2]
143
+ areas = torch.sum(annotation, dim=(1, 2))
144
+ sorted_indices = torch.argsort(areas, descending=False)
145
+ annotation = annotation[sorted_indices]
146
+ # 找每个位置第一个非零值下标
147
+ index = (annotation != 0).to(torch.long).argmax(dim=0)
148
+ if random_color:
149
+ color = torch.rand((mask_sum, 1, 1, 3)).to(device)
150
+ else:
151
+ color = torch.ones((mask_sum, 1, 1, 3)).to(device) * torch.tensor(
152
+ [30 / 255, 144 / 255, 255 / 255]
153
+ ).to(device)
154
+ transparency = torch.ones((mask_sum, 1, 1, 1)).to(device) * 0.6
155
+ visual = torch.cat([color, transparency], dim=-1)
156
+ mask_image = torch.unsqueeze(annotation, -1) * visual
157
+ # 按index取数,index指每个位置选哪个batch的数,把mask_image转成一个batch的形式
158
+ mask = torch.zeros((height, weight, 4)).to(device)
159
+ h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight))
160
+ indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
161
+ # 使用向量化索引更新show的值
162
+ mask[h_indices, w_indices, :] = mask_image[indices]
163
+ mask_cpu = mask.cpu().numpy()
164
+ if bbox is not None:
165
+ x1, y1, x2, y2 = bbox
166
+ ax.add_patch(
167
+ plt.Rectangle(
168
+ (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
169
+ )
170
+ )
171
+ if not retinamask:
172
+ mask_cpu = cv2.resize(
173
+ mask_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST
174
+ )
175
+ return mask_cpu
weights/FastSAM.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c0be4e7ddbe4c15333d15a859c676d053c486d0a746a3be6a7a9790d52a9b6d7
3
+ size 144943063