kxqt commited on
Commit
6837ba9
1 Parent(s): 307c457

[feature] add point-per-side config

Browse files
Files changed (1) hide show
  1. app.py +12 -3
app.py CHANGED
@@ -22,8 +22,12 @@ hourglass_args = {
22
  },
23
  }
24
 
25
- def predict(image, speed_mode):
26
- mask_generator = SamAutomaticMaskGenerator(build_sam(checkpoint="sam_vit_h_4b8939.pth", hourglass_kwargs=hourglass_args[speed_mode]))
 
 
 
 
27
  masks = mask_generator.generate(image)
28
 
29
  if len(masks) == 0:
@@ -54,6 +58,11 @@ def main():
54
  with gr.Row():
55
  with gr.Column():
56
  input_image = gr.Image(label="Input Image")
 
 
 
 
 
57
  speed_mode = gr.Dropdown(
58
  choices=list(hourglass_args.keys()),
59
  value="baseline",
@@ -77,7 +86,7 @@ def main():
77
 
78
  run_btn.click(
79
  fn=predict,
80
- inputs=[input_image, speed_mode],
81
  outputs=output_image
82
  )
83
  clear_btn.click(
 
22
  },
23
  }
24
 
25
+ def predict(image, speed_mode, point_per_side):
26
+ mask_generator = SamAutomaticMaskGenerator(
27
+ build_sam(checkpoint="sam_vit_h_4b8939.pth", hourglass_kwargs=hourglass_args[speed_mode]),
28
+ point_per_side=point_per_side,
29
+ points_per_batch=64 if point_per_side > 12 else point_per_side * point_per_side
30
+ )
31
  masks = mask_generator.generate(image)
32
 
33
  if len(masks) == 0:
 
58
  with gr.Row():
59
  with gr.Column():
60
  input_image = gr.Image(label="Input Image")
61
+ points_per_side = gr.Dropdown(
62
+ choices=[4, 6, 8, 12, 16, 32],
63
+ value=12,
64
+ label="Points per Side",
65
+ )
66
  speed_mode = gr.Dropdown(
67
  choices=list(hourglass_args.keys()),
68
  value="baseline",
 
86
 
87
  run_btn.click(
88
  fn=predict,
89
+ inputs=[input_image, speed_mode, points_per_side],
90
  outputs=output_image
91
  )
92
  clear_btn.click(