RashiAgarwal committed on
Commit
8616324
·
1 Parent(s): f50a8fb

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +319 -0
app.py ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ultralytics import YOLO
2
+ import gradio as gr
3
+ import torch
4
+ from utils.tools_gradio import fast_process
5
+ from utils.tools import format_results, box_prompt, point_prompt, text_prompt
6
+ from PIL import ImageDraw
7
+ import numpy as np
8
+
9
# Load the pre-trained FastSAM weights.
model = YOLO('./weights/FastSAM.pt')

# Pick the best available inference backend: CUDA GPU, then Apple MPS, then CPU.
if torch.cuda.is_available():
    _backend = "cuda"
elif torch.backends.mps.is_available():
    _backend = "mps"
else:
    _backend = "cpu"
device = torch.device(_backend)

# Description
title = " # Fast Segment Anything"

description_p = """ # Acknowledgement
This demo has reference to the Github project [Fast Segment Anything Model](https://github.com/CASIA-IVA-Lab/FastSAM)

"""

# Example images for the "Everything" and "Points" tabs.
examples = [
    ["examples/sa_8776.jpg"],
    ["examples/sa_414.jpg"],
    ["examples/sa_1309.jpg"],
    ["examples/sa_11025.jpg"],
    ["examples/sa_561.jpg"],
    ["examples/sa_192.jpg"],
    ["examples/sa_10039.jpg"],
    ["examples/sa_862.jpg"],
]

default_example = examples[0]
32
+
33
+
34
def segment_everything(
    input,
    input_size=1024,
    iou_threshold=0.7,
    conf_threshold=0.25,
    better_quality=False,
    withContours=True,
    use_retina=True,
    text="",
    wider=False,
    mask_random_color=True,
):
    """Segment the whole image, optionally keeping only masks matching a text prompt.

    The input PIL image is resized so its longest side equals ``input_size``
    before being fed to the model. Returns the rendered figure produced by
    ``fast_process``.
    """
    input_size = int(input_size)
    width, height = input.size
    ratio = input_size / max(width, height)
    input = input.resize((int(width * ratio), int(height * ratio)))

    results = model(
        input,
        device=device,
        retina_masks=True,
        iou=iou_threshold,
        conf=conf_threshold,
        imgsz=input_size,
    )

    if len(text) > 0:
        # Filter the everything-mode masks down to those matching the prompt.
        formatted = format_results(results[0], 0)
        annotations, _ = text_prompt(formatted, text, input, device=device, wider=wider)
        annotations = np.array([annotations])
    else:
        annotations = results[0].masks.data

    return fast_process(
        annotations=annotations,
        image=input,
        device=device,
        scale=(1024 // input_size),
        better_quality=better_quality,
        mask_random_color=mask_random_color,
        bbox=None,
        use_retina=use_retina,
        withContours=withContours,
    )
77
+
78
+
79
def segment_with_points(
    input,
    input_size=1024,
    iou_threshold=0.7,
    conf_threshold=0.25,
    better_quality=False,
    withContours=True,
    use_retina=True,
    mask_random_color=True,
):
    """Segment the object(s) selected by the user's clicked points.

    Reads — and afterwards resets — the module-level ``global_points`` /
    ``global_point_label`` lists populated by ``get_points_with_draw``.
    Returns ``(figure, None)``; the ``None`` clears the annotated input image.
    """
    global global_points
    global global_point_label

    input_size = int(input_size)
    width, height = input.size
    ratio = input_size / max(width, height)
    new_w, new_h = int(width * ratio), int(height * ratio)
    input = input.resize((new_w, new_h))

    # Clicks were recorded on the original-size image; rescale to the resized one.
    scaled_points = [[int(coord * ratio) for coord in point] for point in global_points]

    results = model(
        input,
        device=device,
        retina_masks=True,
        iou=iou_threshold,
        conf=conf_threshold,
        imgsz=input_size,
    )

    formatted = format_results(results[0], 0)
    annotations, _ = point_prompt(formatted, scaled_points, global_point_label, new_h, new_w)
    annotations = np.array([annotations])

    fig = fast_process(
        annotations=annotations,
        image=input,
        device=device,
        scale=(1024 // input_size),
        better_quality=better_quality,
        mask_random_color=mask_random_color,
        bbox=None,
        use_retina=use_retina,
        withContours=withContours,
    )

    # Drop the accumulated clicks so the next interaction starts fresh.
    global_points = []
    global_point_label = []
    return fig, None
125
+
126
+
127
def get_points_with_draw(image, label, evt: gr.SelectData):
    """Record a clicked point and draw a colored dot on the image.

    Appends the click coordinates to ``global_points`` and a foreground (1)
    or background (0) flag to ``global_point_label``, then returns the image
    with the point drawn (yellow = add mask, magenta = remove area).
    """
    global global_points
    global global_point_label

    x, y = evt.index[0], evt.index[1]
    is_add = label == 'Add Mask'
    point_radius = 15
    point_color = (255, 255, 0) if is_add else (255, 0, 255)
    global_points.append([x, y])
    global_point_label.append(1 if is_add else 0)

    print(x, y, is_add)

    draw = ImageDraw.Draw(image)
    draw.ellipse(
        [(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)],
        fill=point_color,
    )
    return image
142
+
143
+
144
# Gradio components are constructed up front and `.render()`ed inside the layout.
# Suffixes: _e = "everything" tab, _p = "points" tab, _t = "text" tab.
cond_img_e = gr.Image(label="Input", value=default_example[0], type='pil')
cond_img_p = gr.Image(label="Input with points", value=default_example[0], type='pil')
cond_img_t = gr.Image(label="Input with text", value="examples/dogs.jpg", type='pil')

segm_img_e = gr.Image(label="Segmented Image", interactive=False, type='pil')
segm_img_p = gr.Image(label="Segmented Image with points", interactive=False, type='pil')
segm_img_t = gr.Image(label="Segmented Image with text", interactive=False, type='pil')

# Click state shared between get_points_with_draw and segment_with_points.
global_points = []
global_point_label = []

input_size_slider = gr.components.Slider(
    minimum=512,
    maximum=1024,
    value=1024,
    step=64,
    label='Input_size',
    info='Our model was trained on a size of 1024',
)
161
+
162
with gr.Blocks(title='Fast Segment Anything') as demo:
    with gr.Row():
        with gr.Column(scale=2):
            # Title
            gr.Markdown(title)

    with gr.Tab("Everything mode"):
        # Images
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                cond_img_e.render()

            with gr.Column(scale=1):
                segm_img_e.render()

        # Submit & Clear
        with gr.Row():
            with gr.Column():
                input_size_slider.render()

                with gr.Row():
                    contour_check = gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks')

                    with gr.Column():
                        segment_btn_e = gr.Button("Segment Everything", variant='primary')
                        clear_btn_e = gr.Button("Clear", variant="secondary")

                gr.Markdown("Try some of the examples below ⬇️")
                gr.Examples(examples=examples,
                            inputs=[cond_img_e],
                            outputs=segm_img_e,
                            fn=segment_everything,
                            cache_examples=True,
                            examples_per_page=4)

            with gr.Column():
                with gr.Accordion("Advanced options", open=False):
                    # NOTE: these widgets are also read by the "Text mode"
                    # click handler below (shared Python variables).
                    iou_threshold = gr.Slider(0.1, 0.9, 0.7, step=0.1, label='iou', info='iou threshold for filtering the annotations')
                    conf_threshold = gr.Slider(0.1, 0.9, 0.25, step=0.05, label='conf', info='object confidence threshold')
                    with gr.Row():
                        mor_check = gr.Checkbox(value=False, label='better_visual_quality', info='better quality using morphologyEx')
                        with gr.Column():
                            retina_check = gr.Checkbox(value=True, label='use_retina', info='draw high-resolution segmentation masks')

        segment_btn_e.click(segment_everything,
                            inputs=[
                                cond_img_e,
                                input_size_slider,
                                iou_threshold,
                                conf_threshold,
                                mor_check,
                                contour_check,
                                retina_check,
                            ],
                            outputs=segm_img_e)

    with gr.Tab("Points mode"):
        # Images
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                cond_img_p.render()

            with gr.Column(scale=1):
                segm_img_p.render()

        # Submit & Clear
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    add_or_remove = gr.Radio(["Add Mask", "Remove Area"], value="Add Mask", label="Point_label (foreground/background)")

                    with gr.Column():
                        segment_btn_p = gr.Button("Segment with points prompt", variant='primary')
                        clear_btn_p = gr.Button("Clear points", variant='secondary')

                gr.Markdown("Try some of the examples below ⬇️")
                gr.Examples(examples=examples,
                            inputs=[cond_img_p],
                            # outputs=segm_img_p,
                            # fn=segment_with_points,
                            # cache_examples=True,
                            examples_per_page=4)

            with gr.Column():
                # Description
                gr.Markdown(description_p)

        cond_img_p.select(get_points_with_draw, [cond_img_p, add_or_remove], cond_img_p)

        segment_btn_p.click(segment_with_points,
                            inputs=[cond_img_p],
                            outputs=[segm_img_p, cond_img_p])

    with gr.Tab("Text mode"):
        # Images
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                cond_img_t.render()

            with gr.Column(scale=1):
                segm_img_t.render()

        # Submit & Clear
        with gr.Row():
            with gr.Column():
                input_size_slider_t = gr.components.Slider(minimum=512,
                                                           maximum=1024,
                                                           value=1024,
                                                           step=64,
                                                           label='Input_size',
                                                           info='Our model was trained on a size of 1024')
                with gr.Row():
                    with gr.Column():
                        contour_check = gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks')
                        text_box = gr.Textbox(label="text prompt", value="a black dog")
                        # BUGFIX: `wider_check` was referenced in the click
                        # handler below but never created, which raised a
                        # NameError while building the Blocks graph.
                        wider_check = gr.Checkbox(value=False, label='wider', info='wider segmentation of the object based on the text prompt')

                    with gr.Column():
                        segment_btn_t = gr.Button("Segment with text", variant='primary')
                        clear_btn_t = gr.Button("Clear", variant="secondary")

                gr.Markdown("Try some of the examples below ⬇️")
                gr.Examples(examples=[["examples/dogs.jpg"], ["examples/fruits.jpg"], ["examples/flowers.jpg"]],
                            inputs=[cond_img_t],
                            # outputs=segm_img_e,
                            # fn=segment_everything,
                            # cache_examples=True,
                            examples_per_page=4)

        segment_btn_t.click(segment_everything,
                            inputs=[
                                cond_img_t,
                                input_size_slider_t,
                                iou_threshold,
                                conf_threshold,
                                mor_check,
                                contour_check,
                                retina_check,
                                text_box,
                                wider_check,
                            ],
                            outputs=segm_img_t)

    def clear():
        # Reset an (input image, segmented image) pair.
        return None, None

    def clear_text():
        # Reset the text-mode image pair plus the text prompt box.
        return None, None, None

    clear_btn_e.click(clear, outputs=[cond_img_e, segm_img_e])
    clear_btn_p.click(clear, outputs=[cond_img_p, segm_img_p])
    # BUGFIX: previously cleared the *points*-tab images (cond_img_p/segm_img_p)
    # instead of the text tab's own components.
    clear_btn_t.click(clear_text, outputs=[cond_img_t, segm_img_t, text_box])

demo.queue()
demo.launch()