huzey commited on
Commit
6375d85
1 Parent(s): 1838429

update app

Browse files
Files changed (1) hide show
  1. app.py +21 -8
app.py CHANGED
@@ -3196,8 +3196,10 @@ with demo:
3196
  mask_gallery = gr.Gallery(value=[], label="Segmentation Masks", show_label=True, elem_id="mask_gallery", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, interactive=False)
3197
  run_crop_button = gr.Button("🔴 RUN", elem_id="run_crop_button", variant='primary')
3198
  add_download_button(mask_gallery, "mask")
3199
- distance_threshold_slider = gr.Slider(0, 1, step=0.01, label="Mask Threshold (Foreground)", value=0.5, elem_id="distance_threshold", info="increase for smaller mask")
3200
- negative_distance_threshold_slider = gr.Slider(0, 1, step=0.01, label="Mask Threshold (Background)", value=0.5, elem_id="distance_threshold", info="increase for smaller mask")
 
 
3201
  overlay_image_checkbox = gr.Checkbox(label="Overlay Original Image", value=True, elem_id="overlay_image_checkbox")
3202
  # filter_small_area_checkbox = gr.Checkbox(label="Noise Reduction", value=True, elem_id="filter_small_area_checkbox")
3203
  distance_power_slider = gr.Slider(-3, 3, step=0.01, label="Distance Power", value=0.5, elem_id="distance_power", info="d = d^p", visible=False)
@@ -3238,7 +3240,8 @@ with demo:
3238
  return rgbs
3239
 
3240
  def run_crop(original_images, ncut_images, prompts1, prompts2, prompts3, image_idx1, image_idx2, image_idx3,
3241
- crop_expand, distance_threshold, distance_power, area_threshold, overlay_image, negative_distance_threshold):
 
3242
  ncut_images = [image[0] for image in ncut_images]
3243
  if len(ncut_images) == 0:
3244
  return []
@@ -3257,9 +3260,18 @@ with demo:
3257
  ncut_pixels = torch.tensor(np.array(ncut_pixels).reshape(-1, 3)) / 255
3258
  # normalized_ncut_pixels = F.normalize(ncut_pixels, p=2, dim=-1)
3259
 
3260
- def to_mask(heatmap, threshold):
3261
- heatmap = 1 / (heatmap + 1e-6)
 
 
 
 
 
 
 
 
3262
  if heatmap.shape[0] > 10000:
 
3263
  random_idx = np.random.choice(heatmap.shape[0], 10000, replace=False)
3264
  vmin, vmax = heatmap[random_idx].quantile(0.01), heatmap[random_idx].quantile(0.99)
3265
  else:
@@ -3275,9 +3287,9 @@ with demo:
3275
  distance = (ncut_pixels - rgb[None]).norm(dim=-1)
3276
  distance = distance.squeeze(-1)
3277
  if is_positive:
3278
- positive_masks.append(to_mask(distance, distance_threshold))
3279
  else:
3280
- negative_masks.append(to_mask(distance, negative_distance_threshold))
3281
  if len(positive_masks) == 0:
3282
  raise gr.Error("No prompt points. Please draw some points on the image.")
3283
  positive_masks = torch.stack(positive_masks)
@@ -3397,7 +3409,8 @@ with demo:
3397
  run_crop_button.click(run_crop,
3398
  inputs=[input_gallery, output_gallery, prompt_image1, prompt_image2, prompt_image3, image1_slider, image2_slider, image3_slider,
3399
  crop_expand_slider, distance_threshold_slider, distance_power_slider,
3400
- area_threshold_slider, overlay_image_checkbox, negative_distance_threshold_slider],
 
3401
  outputs=[mask_gallery, crop_gallery])
3402
 
3403
 
 
3196
  mask_gallery = gr.Gallery(value=[], label="Segmentation Masks", show_label=True, elem_id="mask_gallery", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, interactive=False)
3197
  run_crop_button = gr.Button("🔴 RUN", elem_id="run_crop_button", variant='primary')
3198
  add_download_button(mask_gallery, "mask")
3199
+ distance_threshold_slider = gr.Slider(0, 1, step=0.01, label="Mask Threshold (FG)", value=0.9, elem_id="distance_threshold", info="increase for smaller FG mask")
3200
+ fg_contrast_slider = gr.Slider(0, 2, step=0.01, label="Mask Scaling (FG)", value=1, elem_id="distance_focal", info="increase for smaller FG mask", visible=True)
3201
+ negative_distance_threshold_slider = gr.Slider(0, 1, step=0.01, label="Mask Threshold (BG)", value=0.9, elem_id="distance_threshold", info="increase for less BG removal")
3202
+ bg_contrast_slider = gr.Slider(0, 2, step=0.01, label="Mask Scaling (BG)", value=1, elem_id="distance_focal", info="increase for less BG removal", visible=True)
3203
  overlay_image_checkbox = gr.Checkbox(label="Overlay Original Image", value=True, elem_id="overlay_image_checkbox")
3204
  # filter_small_area_checkbox = gr.Checkbox(label="Noise Reduction", value=True, elem_id="filter_small_area_checkbox")
3205
  distance_power_slider = gr.Slider(-3, 3, step=0.01, label="Distance Power", value=0.5, elem_id="distance_power", info="d = d^p", visible=False)
 
3240
  return rgbs
3241
 
3242
  def run_crop(original_images, ncut_images, prompts1, prompts2, prompts3, image_idx1, image_idx2, image_idx3,
3243
+ crop_expand, distance_threshold, distance_power, area_threshold, overlay_image, negative_distance_threshold,
3244
+ fg_contrast, bg_contrast):
3245
  ncut_images = [image[0] for image in ncut_images]
3246
  if len(ncut_images) == 0:
3247
  return []
 
3260
  ncut_pixels = torch.tensor(np.array(ncut_pixels).reshape(-1, 3)) / 255
3261
  # normalized_ncut_pixels = F.normalize(ncut_pixels, p=2, dim=-1)
3262
 
3263
+
3264
+
3265
+ def to_mask(heatmap, threshold, gamma):
3266
+ heatmap = (heatmap - heatmap.mean()) / heatmap.std()
3267
+ heatmap = heatmap.double()
3268
+ heatmap = torch.exp(heatmap)
3269
+ # heatmap = 1 / (heatmap + 1e-6)
3270
+ heatmap = 1 / heatmap ** gamma
3271
+ # import math
3272
+ # heatmap = 1 / heatmap ** math.log(6.1 - gamma)
3273
  if heatmap.shape[0] > 10000:
3274
+ np.random.seed(0)
3275
  random_idx = np.random.choice(heatmap.shape[0], 10000, replace=False)
3276
  vmin, vmax = heatmap[random_idx].quantile(0.01), heatmap[random_idx].quantile(0.99)
3277
  else:
 
3287
  distance = (ncut_pixels - rgb[None]).norm(dim=-1)
3288
  distance = distance.squeeze(-1)
3289
  if is_positive:
3290
+ positive_masks.append(to_mask(distance, distance_threshold, fg_contrast))
3291
  else:
3292
+ negative_masks.append(to_mask(distance, negative_distance_threshold, bg_contrast))
3293
  if len(positive_masks) == 0:
3294
  raise gr.Error("No prompt points. Please draw some points on the image.")
3295
  positive_masks = torch.stack(positive_masks)
 
3409
  run_crop_button.click(run_crop,
3410
  inputs=[input_gallery, output_gallery, prompt_image1, prompt_image2, prompt_image3, image1_slider, image2_slider, image3_slider,
3411
  crop_expand_slider, distance_threshold_slider, distance_power_slider,
3412
+ area_threshold_slider, overlay_image_checkbox, negative_distance_threshold_slider,
3413
+ fg_contrast_slider, bg_contrast_slider],
3414
  outputs=[mask_gallery, crop_gallery])
3415
 
3416