huzey commited on
Commit
03a488b
1 Parent(s): 369bb85

update mask logic

Browse files
Files changed (1) hide show
  1. app.py +39 -37
app.py CHANGED
@@ -3134,10 +3134,10 @@ with demo:
3134
  image3_slider = gr.Slider(0, 100, step=1, label="Image#3 Index", value=2, elem_id="image3_slider", interactive=True)
3135
  load_one_image_button = gr.Button("🔴 Load", elem_id="load_one_image_button", variant='primary')
3136
  gr.Markdown("### Step 2b: Draw Points")
3137
- gr.Markdown("##### 🖱️ Left Click: Foreground")
3138
- gr.Markdown("##### 🖱️ Middle Click: Background")
3139
  gr.Markdown("""
3140
  <h5>
 
 
3141
  Top Right Buttons: </br>
3142
  <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none"
3143
  stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"
@@ -3154,6 +3154,8 @@ with demo:
3154
  <path stroke="currentColor" stroke-linecap="round" stroke-width="1.5" d="M9 21h12"></path></g>
3155
  </svg> :
3156
  Clear All Points
 
 
3157
  </h5>
3158
  """)
3159
  prompt_image1 = ImagePrompter(show_label=False, elem_id="prompt_image1", interactive=False)
@@ -3194,7 +3196,9 @@ with demo:
3194
  mask_gallery = gr.Gallery(value=[], label="Segmentation Masks", show_label=True, elem_id="mask_gallery", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, interactive=False)
3195
  run_crop_button = gr.Button("🔴 RUN", elem_id="run_crop_button", variant='primary')
3196
  add_download_button(mask_gallery, "mask")
3197
- distance_threshold_slider = gr.Slider(0, 1, step=0.01, label="Mask Threshold", value=0.5, elem_id="distance_threshold", info="increase for smaller mask")
 
 
3198
  # filter_small_area_checkbox = gr.Checkbox(label="Noise Reduction", value=True, elem_id="filter_small_area_checkbox")
3199
  distance_power_slider = gr.Slider(-3, 3, step=0.01, label="Distance Power", value=0.5, elem_id="distance_power", info="d = d^p", visible=False)
3200
  crop_gallery = gr.Gallery(value=[], label="Cropped Images", show_label=True, elem_id="crop_gallery", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, interactive=False)
@@ -3234,7 +3238,7 @@ with demo:
3234
  return rgbs
3235
 
3236
  def run_crop(original_images, ncut_images, prompts1, prompts2, prompts3, image_idx1, image_idx2, image_idx3,
3237
- crop_expand, distance_threshold, distance_power, area_threshold):
3238
  ncut_images = [image[0] for image in ncut_images]
3239
  if len(ncut_images) == 0:
3240
  return []
@@ -3252,30 +3256,6 @@ with demo:
3252
  h, w = ncut_images[0].shape[:2]
3253
  ncut_pixels = torch.tensor(np.array(ncut_pixels).reshape(-1, 3)) / 255
3254
  # normalized_ncut_pixels = F.normalize(ncut_pixels, p=2, dim=-1)
3255
- positive_distances = []
3256
- negative_distances = []
3257
- for rgb, is_positive in rgbs:
3258
- rgb = torch.tensor(rgb).float() / 255
3259
- # rgb = F.normalize(rgb, p=2, dim=-1)
3260
- distance = (ncut_pixels - rgb[None]).norm(dim=-1)
3261
- distance = distance.squeeze(-1)
3262
- if is_positive:
3263
- positive_distances.append(distance)
3264
- else:
3265
- negative_distances.append(distance)
3266
- if len(positive_distances) == 0:
3267
- raise gr.Error("No prompt points. Please draw some points on the image.")
3268
- positive_distances = torch.stack(positive_distances)
3269
- negative_flag = len(negative_distances) > 0
3270
- if len(negative_distances) == 0:
3271
- negative_distances = positive_distances * 0 # dummy
3272
- else:
3273
- negative_distances = torch.stack(negative_distances)
3274
-
3275
- positive_distance = positive_distances.min(dim=0).values
3276
- negative_distance = negative_distances.min(dim=0).values
3277
- # positive_distance = positive_distances.mean(dim=0)
3278
- # negative_distance = negative_distances.mean(dim=0)
3279
 
3280
  def to_mask(heatmap, threshold):
3281
  heatmap = 1 / (heatmap + 1e-6)
@@ -3289,13 +3269,26 @@ with demo:
3289
  mask = heatmap > threshold
3290
  return mask
3291
 
3292
- positive_mask = to_mask(positive_distance, distance_threshold)
3293
- if negative_flag:
3294
- negative_mask = to_mask(negative_distance, distance_threshold)
 
 
 
 
 
 
 
 
 
 
 
 
 
3295
  positive_mask = positive_mask & ~negative_mask
3296
 
3297
 
3298
- #convert to PIL
3299
  mask = positive_mask.cpu().numpy()
3300
  mask = mask.astype(np.uint8) * 255
3301
  mask = [Image.fromarray(mask[i]) for i in range(len(mask))]
@@ -3343,7 +3336,12 @@ with demo:
3343
  return bounding_boxes, cleaned_pil_mask
3344
 
3345
  bboxs, filtered_masks = zip(*[get_bboxes_and_clean_mask(_mask) for _mask in mask])
 
 
 
 
3346
 
 
3347
  # combine the masks, also draw the bounding boxes
3348
  combined_masks = []
3349
  for i_image in range(len(mask)):
@@ -3352,6 +3350,13 @@ with demo:
3352
  clean_mask = np.array(filtered_masks[i_image].convert("RGB"))
3353
  combined_mask = noisy_mask * 0.4 + clean_mask
3354
  combined_mask = np.clip(combined_mask, 0, 255).astype(np.uint8)
 
 
 
 
 
 
 
3355
  for x, y, w, h in bbox:
3356
  cv2.rectangle(combined_mask, (x-1, y-1), (x + w+2, y + h+2), (255, 0, 0), 2)
3357
  combined_mask = Image.fromarray(combined_mask)
@@ -3381,10 +3386,6 @@ with demo:
3381
  crop = image.crop((_x, _y, _x + _w, _y + _h))
3382
  return crop
3383
 
3384
- original_images = [image[0] for image in original_images]
3385
- if isinstance(original_images[0], str):
3386
- original_images = [Image.open(image) for image in original_images]
3387
-
3388
  mask_h, mask_w = filtered_masks[0].size
3389
  cropped_images = []
3390
  for _image, _bboxs in zip(original_images, bboxs):
@@ -3395,7 +3396,8 @@ with demo:
3395
 
3396
  run_crop_button.click(run_crop,
3397
  inputs=[input_gallery, output_gallery, prompt_image1, prompt_image2, prompt_image3, image1_slider, image2_slider, image3_slider,
3398
- crop_expand_slider, distance_threshold_slider, distance_power_slider, area_threshold_slider],
 
3399
  outputs=[mask_gallery, crop_gallery])
3400
 
3401
 
 
3134
  image3_slider = gr.Slider(0, 100, step=1, label="Image#3 Index", value=2, elem_id="image3_slider", interactive=True)
3135
  load_one_image_button = gr.Button("🔴 Load", elem_id="load_one_image_button", variant='primary')
3136
  gr.Markdown("### Step 2b: Draw Points")
 
 
3137
  gr.Markdown("""
3138
  <h5>
3139
+ 🖱️ Left Click: Foreground </br>
3140
+ 🖱️ Middle Click: Background </br></br>
3141
  Top Right Buttons: </br>
3142
  <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none"
3143
  stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"
 
3154
  <path stroke="currentColor" stroke-linecap="round" stroke-width="1.5" d="M9 21h12"></path></g>
3155
  </svg> :
3156
  Clear All Points
3157
+ </br>
3158
+ (Known issue: please manually clear the points after loading new image)
3159
  </h5>
3160
  """)
3161
  prompt_image1 = ImagePrompter(show_label=False, elem_id="prompt_image1", interactive=False)
 
3196
  mask_gallery = gr.Gallery(value=[], label="Segmentation Masks", show_label=True, elem_id="mask_gallery", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, interactive=False)
3197
  run_crop_button = gr.Button("🔴 RUN", elem_id="run_crop_button", variant='primary')
3198
  add_download_button(mask_gallery, "mask")
3199
+ distance_threshold_slider = gr.Slider(0, 1, step=0.01, label="Mask Threshold (Foreground)", value=0.5, elem_id="distance_threshold", info="increase for smaller mask")
3200
+ negative_distance_threshold_slider = gr.Slider(0, 1, step=0.01, label="Mask Threshold (Background)", value=0.5, elem_id="distance_threshold", info="increase for smaller mask")
3201
+ overlay_image_checkbox = gr.Checkbox(label="Overlay Original Image", value=True, elem_id="overlay_image_checkbox")
3202
  # filter_small_area_checkbox = gr.Checkbox(label="Noise Reduction", value=True, elem_id="filter_small_area_checkbox")
3203
  distance_power_slider = gr.Slider(-3, 3, step=0.01, label="Distance Power", value=0.5, elem_id="distance_power", info="d = d^p", visible=False)
3204
  crop_gallery = gr.Gallery(value=[], label="Cropped Images", show_label=True, elem_id="crop_gallery", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, interactive=False)
 
3238
  return rgbs
3239
 
3240
  def run_crop(original_images, ncut_images, prompts1, prompts2, prompts3, image_idx1, image_idx2, image_idx3,
3241
+ crop_expand, distance_threshold, distance_power, area_threshold, overlay_image, negative_distance_threshold):
3242
  ncut_images = [image[0] for image in ncut_images]
3243
  if len(ncut_images) == 0:
3244
  return []
 
3256
  h, w = ncut_images[0].shape[:2]
3257
  ncut_pixels = torch.tensor(np.array(ncut_pixels).reshape(-1, 3)) / 255
3258
  # normalized_ncut_pixels = F.normalize(ncut_pixels, p=2, dim=-1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3259
 
3260
  def to_mask(heatmap, threshold):
3261
  heatmap = 1 / (heatmap + 1e-6)
 
3269
  mask = heatmap > threshold
3270
  return mask
3271
 
3272
+ positive_masks, negative_masks = [], []
3273
+ for rgb, is_positive in rgbs:
3274
+ rgb = torch.tensor(rgb).float() / 255
3275
+ distance = (ncut_pixels - rgb[None]).norm(dim=-1)
3276
+ distance = distance.squeeze(-1)
3277
+ if is_positive:
3278
+ positive_masks.append(to_mask(distance, distance_threshold))
3279
+ else:
3280
+ negative_masks.append(to_mask(distance, negative_distance_threshold))
3281
+ if len(positive_masks) == 0:
3282
+ raise gr.Error("No prompt points. Please draw some points on the image.")
3283
+ positive_masks = torch.stack(positive_masks)
3284
+ positive_mask = positive_masks.any(dim=0)
3285
+ if len(negative_masks) > 0:
3286
+ negative_masks = torch.stack(negative_masks)
3287
+ negative_mask = negative_masks.any(dim=0)
3288
  positive_mask = positive_mask & ~negative_mask
3289
 
3290
 
3291
+ # convert to PIL
3292
  mask = positive_mask.cpu().numpy()
3293
  mask = mask.astype(np.uint8) * 255
3294
  mask = [Image.fromarray(mask[i]) for i in range(len(mask))]
 
3336
  return bounding_boxes, cleaned_pil_mask
3337
 
3338
  bboxs, filtered_masks = zip(*[get_bboxes_and_clean_mask(_mask) for _mask in mask])
3339
+
3340
+ original_images = [image[0] for image in original_images]
3341
+ if isinstance(original_images[0], str):
3342
+ original_images = [Image.open(image) for image in original_images]
3343
 
3344
+
3345
  # combine the masks, also draw the bounding boxes
3346
  combined_masks = []
3347
  for i_image in range(len(mask)):
 
3350
  clean_mask = np.array(filtered_masks[i_image].convert("RGB"))
3351
  combined_mask = noisy_mask * 0.4 + clean_mask
3352
  combined_mask = np.clip(combined_mask, 0, 255).astype(np.uint8)
3353
+ if overlay_image:
3354
+ combined_mask[:, :, 0] = 0 # remove red channel
3355
+ combined_mask[:, :, 1] = 0 # remove green channel
3356
+ _image = original_images[i_image].convert("RGB").resize((combined_mask.shape[1], combined_mask.shape[0]))
3357
+ _image = np.array(_image)
3358
+ combined_mask = 0.5 * combined_mask + 0.5 * _image
3359
+ combined_mask = np.clip(combined_mask, 0, 255).astype(np.uint8)
3360
  for x, y, w, h in bbox:
3361
  cv2.rectangle(combined_mask, (x-1, y-1), (x + w+2, y + h+2), (255, 0, 0), 2)
3362
  combined_mask = Image.fromarray(combined_mask)
 
3386
  crop = image.crop((_x, _y, _x + _w, _y + _h))
3387
  return crop
3388
 
 
 
 
 
3389
  mask_h, mask_w = filtered_masks[0].size
3390
  cropped_images = []
3391
  for _image, _bboxs in zip(original_images, bboxs):
 
3396
 
3397
  run_crop_button.click(run_crop,
3398
  inputs=[input_gallery, output_gallery, prompt_image1, prompt_image2, prompt_image3, image1_slider, image2_slider, image3_slider,
3399
+ crop_expand_slider, distance_threshold_slider, distance_power_slider,
3400
+ area_threshold_slider, overlay_image_checkbox, negative_distance_threshold_slider],
3401
  outputs=[mask_gallery, crop_gallery])
3402
 
3403