Spaces:
Runtime error
Runtime error
[feature] add point-per-side config
Browse files
app.py
CHANGED
@@ -22,8 +22,12 @@ hourglass_args = {
|
|
22 |
},
|
23 |
}
|
24 |
|
25 |
-
def predict(image, speed_mode):
|
26 |
-
mask_generator = SamAutomaticMaskGenerator(
|
|
|
|
|
|
|
|
|
27 |
masks = mask_generator.generate(image)
|
28 |
|
29 |
if len(masks) == 0:
|
@@ -54,6 +58,11 @@ def main():
|
|
54 |
with gr.Row():
|
55 |
with gr.Column():
|
56 |
input_image = gr.Image(label="Input Image")
|
|
|
|
|
|
|
|
|
|
|
57 |
speed_mode = gr.Dropdown(
|
58 |
choices=list(hourglass_args.keys()),
|
59 |
value="baseline",
|
@@ -77,7 +86,7 @@ def main():
|
|
77 |
|
78 |
run_btn.click(
|
79 |
fn=predict,
|
80 |
-
inputs=[input_image, speed_mode],
|
81 |
outputs=output_image
|
82 |
)
|
83 |
clear_btn.click(
|
|
|
22 |
},
|
23 |
}
|
24 |
|
25 |
+
def predict(image, speed_mode, point_per_side):
|
26 |
+
mask_generator = SamAutomaticMaskGenerator(
|
27 |
+
build_sam(checkpoint="sam_vit_h_4b8939.pth", hourglass_kwargs=hourglass_args[speed_mode]),
|
28 |
+
point_per_side=point_per_side,
|
29 |
+
points_per_batch=64 if point_per_side > 12 else point_per_side * point_per_side
|
30 |
+
)
|
31 |
masks = mask_generator.generate(image)
|
32 |
|
33 |
if len(masks) == 0:
|
|
|
58 |
with gr.Row():
|
59 |
with gr.Column():
|
60 |
input_image = gr.Image(label="Input Image")
|
61 |
+
points_per_side = gr.Dropdown(
|
62 |
+
choices=[4, 6, 8, 12, 16, 32],
|
63 |
+
value=12,
|
64 |
+
label="Points per Side",
|
65 |
+
)
|
66 |
speed_mode = gr.Dropdown(
|
67 |
choices=list(hourglass_args.keys()),
|
68 |
value="baseline",
|
|
|
86 |
|
87 |
run_btn.click(
|
88 |
fn=predict,
|
89 |
+
inputs=[input_image, speed_mode, points_per_side],
|
90 |
outputs=output_image
|
91 |
)
|
92 |
clear_btn.click(
|