Spaces:
Running
on
Zero
Running
on
Zero
update mask logic
Browse files
app.py
CHANGED
@@ -3134,10 +3134,10 @@ with demo:
|
|
3134 |
image3_slider = gr.Slider(0, 100, step=1, label="Image#3 Index", value=2, elem_id="image3_slider", interactive=True)
|
3135 |
load_one_image_button = gr.Button("🔴 Load", elem_id="load_one_image_button", variant='primary')
|
3136 |
gr.Markdown("### Step 2b: Draw Points")
|
3137 |
-
gr.Markdown("##### 🖱️ Left Click: Foreground")
|
3138 |
-
gr.Markdown("##### 🖱️ Middle Click: Background")
|
3139 |
gr.Markdown("""
|
3140 |
<h5>
|
|
|
|
|
3141 |
Top Right Buttons: </br>
|
3142 |
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none"
|
3143 |
stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"
|
@@ -3154,6 +3154,8 @@ with demo:
|
|
3154 |
<path stroke="currentColor" stroke-linecap="round" stroke-width="1.5" d="M9 21h12"></path></g>
|
3155 |
</svg> :
|
3156 |
Clear All Points
|
|
|
|
|
3157 |
</h5>
|
3158 |
""")
|
3159 |
prompt_image1 = ImagePrompter(show_label=False, elem_id="prompt_image1", interactive=False)
|
@@ -3194,7 +3196,9 @@ with demo:
|
|
3194 |
mask_gallery = gr.Gallery(value=[], label="Segmentation Masks", show_label=True, elem_id="mask_gallery", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, interactive=False)
|
3195 |
run_crop_button = gr.Button("🔴 RUN", elem_id="run_crop_button", variant='primary')
|
3196 |
add_download_button(mask_gallery, "mask")
|
3197 |
-
distance_threshold_slider = gr.Slider(0, 1, step=0.01, label="Mask Threshold", value=0.5, elem_id="distance_threshold", info="increase for smaller mask")
|
|
|
|
|
3198 |
# filter_small_area_checkbox = gr.Checkbox(label="Noise Reduction", value=True, elem_id="filter_small_area_checkbox")
|
3199 |
distance_power_slider = gr.Slider(-3, 3, step=0.01, label="Distance Power", value=0.5, elem_id="distance_power", info="d = d^p", visible=False)
|
3200 |
crop_gallery = gr.Gallery(value=[], label="Cropped Images", show_label=True, elem_id="crop_gallery", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, interactive=False)
|
@@ -3234,7 +3238,7 @@ with demo:
|
|
3234 |
return rgbs
|
3235 |
|
3236 |
def run_crop(original_images, ncut_images, prompts1, prompts2, prompts3, image_idx1, image_idx2, image_idx3,
|
3237 |
-
crop_expand, distance_threshold, distance_power, area_threshold):
|
3238 |
ncut_images = [image[0] for image in ncut_images]
|
3239 |
if len(ncut_images) == 0:
|
3240 |
return []
|
@@ -3252,30 +3256,6 @@ with demo:
|
|
3252 |
h, w = ncut_images[0].shape[:2]
|
3253 |
ncut_pixels = torch.tensor(np.array(ncut_pixels).reshape(-1, 3)) / 255
|
3254 |
# normalized_ncut_pixels = F.normalize(ncut_pixels, p=2, dim=-1)
|
3255 |
-
positive_distances = []
|
3256 |
-
negative_distances = []
|
3257 |
-
for rgb, is_positive in rgbs:
|
3258 |
-
rgb = torch.tensor(rgb).float() / 255
|
3259 |
-
# rgb = F.normalize(rgb, p=2, dim=-1)
|
3260 |
-
distance = (ncut_pixels - rgb[None]).norm(dim=-1)
|
3261 |
-
distance = distance.squeeze(-1)
|
3262 |
-
if is_positive:
|
3263 |
-
positive_distances.append(distance)
|
3264 |
-
else:
|
3265 |
-
negative_distances.append(distance)
|
3266 |
-
if len(positive_distances) == 0:
|
3267 |
-
raise gr.Error("No prompt points. Please draw some points on the image.")
|
3268 |
-
positive_distances = torch.stack(positive_distances)
|
3269 |
-
negative_flag = len(negative_distances) > 0
|
3270 |
-
if len(negative_distances) == 0:
|
3271 |
-
negative_distances = positive_distances * 0 # dummy
|
3272 |
-
else:
|
3273 |
-
negative_distances = torch.stack(negative_distances)
|
3274 |
-
|
3275 |
-
positive_distance = positive_distances.min(dim=0).values
|
3276 |
-
negative_distance = negative_distances.min(dim=0).values
|
3277 |
-
# positive_distance = positive_distances.mean(dim=0)
|
3278 |
-
# negative_distance = negative_distances.mean(dim=0)
|
3279 |
|
3280 |
def to_mask(heatmap, threshold):
|
3281 |
heatmap = 1 / (heatmap + 1e-6)
|
@@ -3289,13 +3269,26 @@ with demo:
|
|
3289 |
mask = heatmap > threshold
|
3290 |
return mask
|
3291 |
|
3292 |
-
|
3293 |
-
|
3294 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3295 |
positive_mask = positive_mask & ~negative_mask
|
3296 |
|
3297 |
|
3298 |
-
#convert to PIL
|
3299 |
mask = positive_mask.cpu().numpy()
|
3300 |
mask = mask.astype(np.uint8) * 255
|
3301 |
mask = [Image.fromarray(mask[i]) for i in range(len(mask))]
|
@@ -3343,7 +3336,12 @@ with demo:
|
|
3343 |
return bounding_boxes, cleaned_pil_mask
|
3344 |
|
3345 |
bboxs, filtered_masks = zip(*[get_bboxes_and_clean_mask(_mask) for _mask in mask])
|
|
|
|
|
|
|
|
|
3346 |
|
|
|
3347 |
# combine the masks, also draw the bounding boxes
|
3348 |
combined_masks = []
|
3349 |
for i_image in range(len(mask)):
|
@@ -3352,6 +3350,13 @@ with demo:
|
|
3352 |
clean_mask = np.array(filtered_masks[i_image].convert("RGB"))
|
3353 |
combined_mask = noisy_mask * 0.4 + clean_mask
|
3354 |
combined_mask = np.clip(combined_mask, 0, 255).astype(np.uint8)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3355 |
for x, y, w, h in bbox:
|
3356 |
cv2.rectangle(combined_mask, (x-1, y-1), (x + w+2, y + h+2), (255, 0, 0), 2)
|
3357 |
combined_mask = Image.fromarray(combined_mask)
|
@@ -3381,10 +3386,6 @@ with demo:
|
|
3381 |
crop = image.crop((_x, _y, _x + _w, _y + _h))
|
3382 |
return crop
|
3383 |
|
3384 |
-
original_images = [image[0] for image in original_images]
|
3385 |
-
if isinstance(original_images[0], str):
|
3386 |
-
original_images = [Image.open(image) for image in original_images]
|
3387 |
-
|
3388 |
mask_h, mask_w = filtered_masks[0].size
|
3389 |
cropped_images = []
|
3390 |
for _image, _bboxs in zip(original_images, bboxs):
|
@@ -3395,7 +3396,8 @@ with demo:
|
|
3395 |
|
3396 |
run_crop_button.click(run_crop,
|
3397 |
inputs=[input_gallery, output_gallery, prompt_image1, prompt_image2, prompt_image3, image1_slider, image2_slider, image3_slider,
|
3398 |
-
crop_expand_slider, distance_threshold_slider, distance_power_slider,
|
|
|
3399 |
outputs=[mask_gallery, crop_gallery])
|
3400 |
|
3401 |
|
|
|
3134 |
image3_slider = gr.Slider(0, 100, step=1, label="Image#3 Index", value=2, elem_id="image3_slider", interactive=True)
|
3135 |
load_one_image_button = gr.Button("🔴 Load", elem_id="load_one_image_button", variant='primary')
|
3136 |
gr.Markdown("### Step 2b: Draw Points")
|
|
|
|
|
3137 |
gr.Markdown("""
|
3138 |
<h5>
|
3139 |
+
🖱️ Left Click: Foreground </br>
|
3140 |
+
🖱️ Middle Click: Background </br></br>
|
3141 |
Top Right Buttons: </br>
|
3142 |
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none"
|
3143 |
stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"
|
|
|
3154 |
<path stroke="currentColor" stroke-linecap="round" stroke-width="1.5" d="M9 21h12"></path></g>
|
3155 |
</svg> :
|
3156 |
Clear All Points
|
3157 |
+
</br>
|
3158 |
+
(Known issue: please manually clear the points after loading new image)
|
3159 |
</h5>
|
3160 |
""")
|
3161 |
prompt_image1 = ImagePrompter(show_label=False, elem_id="prompt_image1", interactive=False)
|
|
|
3196 |
mask_gallery = gr.Gallery(value=[], label="Segmentation Masks", show_label=True, elem_id="mask_gallery", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, interactive=False)
|
3197 |
run_crop_button = gr.Button("🔴 RUN", elem_id="run_crop_button", variant='primary')
|
3198 |
add_download_button(mask_gallery, "mask")
|
3199 |
+
distance_threshold_slider = gr.Slider(0, 1, step=0.01, label="Mask Threshold (Foreground)", value=0.5, elem_id="distance_threshold", info="increase for smaller mask")
|
3200 |
+
negative_distance_threshold_slider = gr.Slider(0, 1, step=0.01, label="Mask Threshold (Background)", value=0.5, elem_id="distance_threshold", info="increase for smaller mask")
|
3201 |
+
overlay_image_checkbox = gr.Checkbox(label="Overlay Original Image", value=True, elem_id="overlay_image_checkbox")
|
3202 |
# filter_small_area_checkbox = gr.Checkbox(label="Noise Reduction", value=True, elem_id="filter_small_area_checkbox")
|
3203 |
distance_power_slider = gr.Slider(-3, 3, step=0.01, label="Distance Power", value=0.5, elem_id="distance_power", info="d = d^p", visible=False)
|
3204 |
crop_gallery = gr.Gallery(value=[], label="Cropped Images", show_label=True, elem_id="crop_gallery", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, interactive=False)
|
|
|
3238 |
return rgbs
|
3239 |
|
3240 |
def run_crop(original_images, ncut_images, prompts1, prompts2, prompts3, image_idx1, image_idx2, image_idx3,
|
3241 |
+
crop_expand, distance_threshold, distance_power, area_threshold, overlay_image, negative_distance_threshold):
|
3242 |
ncut_images = [image[0] for image in ncut_images]
|
3243 |
if len(ncut_images) == 0:
|
3244 |
return []
|
|
|
3256 |
h, w = ncut_images[0].shape[:2]
|
3257 |
ncut_pixels = torch.tensor(np.array(ncut_pixels).reshape(-1, 3)) / 255
|
3258 |
# normalized_ncut_pixels = F.normalize(ncut_pixels, p=2, dim=-1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3259 |
|
3260 |
def to_mask(heatmap, threshold):
|
3261 |
heatmap = 1 / (heatmap + 1e-6)
|
|
|
3269 |
mask = heatmap > threshold
|
3270 |
return mask
|
3271 |
|
3272 |
+
positive_masks, negative_masks = [], []
|
3273 |
+
for rgb, is_positive in rgbs:
|
3274 |
+
rgb = torch.tensor(rgb).float() / 255
|
3275 |
+
distance = (ncut_pixels - rgb[None]).norm(dim=-1)
|
3276 |
+
distance = distance.squeeze(-1)
|
3277 |
+
if is_positive:
|
3278 |
+
positive_masks.append(to_mask(distance, distance_threshold))
|
3279 |
+
else:
|
3280 |
+
negative_masks.append(to_mask(distance, negative_distance_threshold))
|
3281 |
+
if len(positive_masks) == 0:
|
3282 |
+
raise gr.Error("No prompt points. Please draw some points on the image.")
|
3283 |
+
positive_masks = torch.stack(positive_masks)
|
3284 |
+
positive_mask = positive_masks.any(dim=0)
|
3285 |
+
if len(negative_masks) > 0:
|
3286 |
+
negative_masks = torch.stack(negative_masks)
|
3287 |
+
negative_mask = negative_masks.any(dim=0)
|
3288 |
positive_mask = positive_mask & ~negative_mask
|
3289 |
|
3290 |
|
3291 |
+
# convert to PIL
|
3292 |
mask = positive_mask.cpu().numpy()
|
3293 |
mask = mask.astype(np.uint8) * 255
|
3294 |
mask = [Image.fromarray(mask[i]) for i in range(len(mask))]
|
|
|
3336 |
return bounding_boxes, cleaned_pil_mask
|
3337 |
|
3338 |
bboxs, filtered_masks = zip(*[get_bboxes_and_clean_mask(_mask) for _mask in mask])
|
3339 |
+
|
3340 |
+
original_images = [image[0] for image in original_images]
|
3341 |
+
if isinstance(original_images[0], str):
|
3342 |
+
original_images = [Image.open(image) for image in original_images]
|
3343 |
|
3344 |
+
|
3345 |
# combine the masks, also draw the bounding boxes
|
3346 |
combined_masks = []
|
3347 |
for i_image in range(len(mask)):
|
|
|
3350 |
clean_mask = np.array(filtered_masks[i_image].convert("RGB"))
|
3351 |
combined_mask = noisy_mask * 0.4 + clean_mask
|
3352 |
combined_mask = np.clip(combined_mask, 0, 255).astype(np.uint8)
|
3353 |
+
if overlay_image:
|
3354 |
+
combined_mask[:, :, 0] = 0 # remove red channel
|
3355 |
+
combined_mask[:, :, 1] = 0 # remove green channel
|
3356 |
+
_image = original_images[i_image].convert("RGB").resize((combined_mask.shape[1], combined_mask.shape[0]))
|
3357 |
+
_image = np.array(_image)
|
3358 |
+
combined_mask = 0.5 * combined_mask + 0.5 * _image
|
3359 |
+
combined_mask = np.clip(combined_mask, 0, 255).astype(np.uint8)
|
3360 |
for x, y, w, h in bbox:
|
3361 |
cv2.rectangle(combined_mask, (x-1, y-1), (x + w+2, y + h+2), (255, 0, 0), 2)
|
3362 |
combined_mask = Image.fromarray(combined_mask)
|
|
|
3386 |
crop = image.crop((_x, _y, _x + _w, _y + _h))
|
3387 |
return crop
|
3388 |
|
|
|
|
|
|
|
|
|
3389 |
mask_h, mask_w = filtered_masks[0].size
|
3390 |
cropped_images = []
|
3391 |
for _image, _bboxs in zip(original_images, bboxs):
|
|
|
3396 |
|
3397 |
run_crop_button.click(run_crop,
|
3398 |
inputs=[input_gallery, output_gallery, prompt_image1, prompt_image2, prompt_image3, image1_slider, image2_slider, image3_slider,
|
3399 |
+
crop_expand_slider, distance_threshold_slider, distance_power_slider,
|
3400 |
+
area_threshold_slider, overlay_image_checkbox, negative_distance_threshold_slider],
|
3401 |
outputs=[mask_gallery, crop_gallery])
|
3402 |
|
3403 |
|