huzey committed
Commit 4d21d31
Parent(s): 57f7a8d

add application

Files changed (1)
  1. app.py  +315 -13

app.py CHANGED
@@ -419,7 +419,7 @@ def segment_fg_bg(images):
         # transform the input images
         input_images = (input_images - means) / stds
         # output = model(input_images)[:, 5]
-        output = model(input_images)['attn'][6]
+        output = model(input_images)['attn'][6]  # [B, H=14, W=14, C]
         fg_act = output[:, 6, 6].mean(0)
         bg_act = output[:, 0, 0].mean(0)
         fg_acts.append(fg_act)
@@ -455,8 +455,8 @@ def segment_fg_bg(images):
         # output = model(input_images)[:, 5]
         output = model(input_images)['attn'][6]
         output = F.normalize(output, dim=-1)
-        heatmap_fg = output @ fg_act[:, None]
-        heatmap_bg = output @ bg_act[:, None]
+        heatmap_fg = output @ fg_act[:, None]  # [B, H, W, 1]
+        heatmap_bg = output @ bg_act[:, None]  # [B, H, W, 1]
         heatmap_fgs.append(heatmap_fg.cpu())
         heatmap_bgs.append(heatmap_bg.cpu())
     heatmap_fg = torch.cat(heatmap_fgs, dim=0)
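Note: the two hunks above only add shape comments, but they document the foreground/background probing scheme: one feature vector is read off a center patch, one off a corner patch, and each is matched against all patches by inner product. A minimal sketch of the same shapes, with a random tensor standing in for the model's 'attn' features (hypothetical sizes):

    import torch
    import torch.nn.functional as F

    B, H, W, C = 4, 14, 14, 768
    feats = torch.randn(B, H, W, C)       # stands in for model(input_images)['attn'][6]

    fg_act = feats[:, 6, 6].mean(0)       # [C], batch-mean feature at a center patch
    bg_act = feats[:, 0, 0].mean(0)       # [C], batch-mean feature at a corner patch

    feats = F.normalize(feats, dim=-1)    # unit-norm per patch, as in the second hunk
    heatmap_fg = feats @ fg_act[:, None]  # [B, H, W, 1]
    heatmap_bg = feats @ bg_act[:, None]  # [B, H, W, 1]
    assert heatmap_fg.shape == (B, H, W, 1)
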
@@ -498,8 +498,8 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
     left = F.normalize(left, dim=-1)
     right = F.normalize(right, dim=-1)
     heatmap = left @ right.T
-    heatmap = F.normalize(heatmap, dim=-1)
-    num_samples = clusters + 20
+    heatmap = F.normalize(heatmap, dim=-1)  # [300, N_pixel] PCA-> [300, 8]
+    num_samples = clusters + 20  # 100/120
     if num_samples > fps_idx.shape[0]:
         num_samples = fps_idx.shape[0]
     r2_fps_idx = farthest_point_sampling(heatmap, num_samples)
@@ -939,7 +939,7 @@ def ncut_run(
         return video_path, logging_str

     cluster_images = None
-    if plot_clusters:
+    if plot_clusters and kwargs.get("n_ret", 1) > 1:
         start = time.time()
         progress_start = 0.6
         progress(progress_start, desc="Plotting Clusters")
@@ -955,7 +955,7 @@ def ncut_run(
     logging_str += f"plot time: {time.time() - start:.2f}s\n"

     norm_images = None
-    if alignedcut_eig_norm_plot:
+    if alignedcut_eig_norm_plot and kwargs.get("n_ret", 1) > 1:
         norm_images = []
         # eig_magnitude = torch.clamp(eig_magnitude, 0, 1)
         vmin, vmax = eig_magnitude.min(), eig_magnitude.max()
@@ -977,7 +977,7 @@


 def _ncut_run(*args, **kwargs):
-    n_ret = kwargs.pop("n_ret", 1)
+    n_ret = kwargs.get("n_ret", 1)
     try:
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
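The pop-to-get change is load-bearing: the two new guards above read n_ret out of ncut_run's kwargs, and kwargs.pop in the wrapper would have consumed the key before the wrapped call ever saw it. A toy illustration (hypothetical function names):

    def inner(**kwargs):
        # mirrors the new guards: kwargs.get("n_ret", 1) > 1
        return kwargs.get("n_ret", 1) > 1

    def wrapper_pop(**kwargs):
        n_ret = kwargs.pop("n_ret", 1)  # removes the key from kwargs
        return inner(**kwargs)

    def wrapper_get(**kwargs):
        n_ret = kwargs.get("n_ret", 1)  # leaves the key in place
        return inner(**kwargs)

    print(wrapper_pop(n_ret=3))  # False -- inner falls back to the default
    print(wrapper_get(n_ret=3))  # True  -- inner still sees n_ret=3
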
@@ -1653,8 +1653,9 @@ def load_and_append(existing_images, *args, **kwargs):
     gr.Info(f"Total images: {len(existing_images)}")
     return existing_images

-def make_input_images_section(rows=1, cols=3, height="auto", advanced=False, is_random=False, allow_download=False):
-    gr.Markdown('### Input Images')
+def make_input_images_section(rows=1, cols=3, height="auto", advanced=False, is_random=False, allow_download=False, markdown=True):
+    if markdown:
+        gr.Markdown('### Input Images')
     input_gallery = gr.Gallery(value=None, label="Input images", show_label=True, elem_id="input_images", columns=[cols], rows=[rows], object_fit="contain", height=height, type="pil", show_share_button=False,
                                format="webp")
@@ -2020,10 +2021,12 @@ def add_download_button(gallery, filename_prefix="output"):
     return create_file_button, download_button


-def make_output_images_section():
-    gr.Markdown('### Output Images')
+def make_output_images_section(markdown=True, button=True):
+    if markdown:
+        gr.Markdown('### Output Images')
     output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=True, elem_id="ncut", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, interactive=False)
-    add_rotate_flip_buttons(output_gallery)
+    if button:
+        add_rotate_flip_buttons(output_gallery)
     return output_gallery

 def make_parameters_section(is_lisa=False, model_ratio=True):
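Both section builders gain opt-out flags so the new Application tab can embed them under its own step headings. The intended call pattern, mirroring the tab code below:

    # inside a gr.Column() of the Application tab:
    widgets = make_input_images_section(markdown=False)        # skip the '### Input Images' heading
    output_gallery = make_output_images_section(markdown=False, button=False)
    add_rotate_flip_buttons(output_gallery)                    # buttons re-added where wanted
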
@@ -2133,6 +2136,8 @@ demo = gr.Blocks(
     css=custom_css,
 )
 with demo:
+
+
     with gr.Tab('AlignedCut'):

         with gr.Row():
@@ -3081,7 +3086,304 @@ with demo:
         buttons[-1].click(fn=lambda x: gr.update(visible=True), outputs=rows[-1])
         buttons[-1].click(fn=lambda x: gr.update(visible=False), outputs=buttons[-1])

+    with gr.Tab('Application'):
+        gr.Markdown("Draw some points on the image to find corresponding segments in other images. E.g. click on one face to segment all the faces. [Video Tutorial (coming...)]()")
+        with gr.Row():
+            with gr.Column(scale=5, min_width=200):
+                gr.Markdown("### Step 0: Load Images")
+                input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section(markdown=False)
+                submit_button.visible = False
+                num_images_slider.value = 30
+                logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
+            with gr.Column(scale=5, min_width=200):
+                gr.Markdown("### Step 1: NCUT Embedding")
+                output_gallery = make_output_images_section(markdown=False, button=False)
+                submit_button = gr.Button("🔴 RUN", elem_id="submit_button", variant='primary')
+                add_rotate_flip_buttons(output_gallery)
+                [
+                    model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
+                    affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
+                    embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
+                    perplexity_slider, n_neighbors_slider, min_dist_slider,
+                    sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt
+                ] = make_parameters_section()
+
+                false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
+                no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
+
+                submit_button.click(
+                    partial(run_fn, n_ret=1),
+                    inputs=[
+                        input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
+                        positive_prompt, negative_prompt,
+                        false_placeholder, no_prompt, no_prompt, no_prompt,
+                        affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
+                        embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
+                        perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown
+                    ],
+                    outputs=[output_gallery, logging_text],
+                )
+
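The wiring reuses the shared run_fn via functools.partial, pinning n_ret=1 at bind time so (judging from the outputs list) only the embedding gallery plus the log come back for this tab; the Gradio components fill the remaining positional arguments. A toy illustration (run_fn_like is a stand-in, not the app's actual signature):

    from functools import partial

    def run_fn_like(images, model_name, n_ret=3):
        results = ["embedding", "clusters", "log"]  # stand-in outputs
        return results[:n_ret]

    click_fn = partial(run_fn_like, n_ret=1)  # keyword fixed at wiring time
    print(click_fn(["img.png"], "DiNO"))      # ['embedding']
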
+            with gr.Column(scale=5, min_width=200):
+                gr.Markdown("### Step 2a: Pick an Image")
+                from gradio_image_prompter import ImagePrompter
+                image_type_radio = gr.Radio(["Original", "NCUT"], label="Image Display Type", value="Original", elem_id="image_type_radio")
+                with gr.Row():
+                    image1_slider = gr.Slider(0, 100, step=1, label="Image#1 Index", value=0, elem_id="image1_slider", interactive=True)
+                    image2_slider = gr.Slider(0, 100, step=1, label="Image#2 Index", value=1, elem_id="image2_slider", interactive=True)
+                    image3_slider = gr.Slider(0, 100, step=1, label="Image#3 Index", value=2, elem_id="image3_slider", interactive=True)
+                load_one_image_button = gr.Button("🔴 Load", elem_id="load_one_image_button", variant='primary')
+                gr.Markdown("### Step 2b: Draw Points")
+                gr.Markdown("##### 🖱️ Left Click: Foreground")
+                gr.Markdown("##### 🖱️ Middle Click: Background")
+                gr.Markdown("""
+                <h5>
+                Top Right
+                <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none"
+                    stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"
+                    style="vertical-align: middle; height: 1em; width: 1em; display: inline;">
+                    <polyline points="1 4 1 10 7 10"></polyline>
+                    <path d="M3.51 15a9 9 0 1 0 2.13-9.36L1 10"></path>
+                </svg> :
+                Remove Last Point
+                </h5>
+                """)
+                prompt_image1 = ImagePrompter(show_label=False, elem_id="prompt_image", interactive=True)
+                prompt_image2 = ImagePrompter(show_label=False, elem_id="prompt_image", interactive=True)
+                prompt_image3 = ImagePrompter(show_label=False, elem_id="prompt_image", interactive=True)
+                # def update_number_of_images(images):
+                #     if images is None:
+                #         return gr.update(max=0, value=0)
+                #     return gr.update(max=len(images)-1, value=1)
+                # input_gallery.change(update_number_of_images, inputs=input_gallery, outputs=image1_slider)
+
+                def update_prompt_image(original_images, ncut_images, image_type, index):
+                    if image_type == "Original":
+                        images = original_images
+                    else:
+                        images = ncut_images
+                    if images is None:
+                        return
+                    total_len = len(images)
+                    if total_len == 0:
+                        return
+                    if index >= total_len:
+                        index = total_len - 1
+                    return gr.update(value={'image': images[index][0]})
+                load_one_image_button.click(update_prompt_image, inputs=[input_gallery, output_gallery, image_type_radio, image1_slider], outputs=[prompt_image1])
+                load_one_image_button.click(update_prompt_image, inputs=[input_gallery, output_gallery, image_type_radio, image2_slider], outputs=[prompt_image2])
+                load_one_image_button.click(update_prompt_image, inputs=[input_gallery, output_gallery, image_type_radio, image3_slider], outputs=[prompt_image3])
+
+                image3_slider.visible = False
+                prompt_image3.visible = False
+
+
+
+            with gr.Column(scale=5, min_width=200):
+                gr.Markdown("### Step 3: Segment and Crop")
+                mask_gallery = gr.Gallery(value=[], label="Segmentation Masks", show_label=True, elem_id="mask_gallery", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, interactive=False)
+                run_crop_button = gr.Button("🔴 RUN", elem_id="run_crop_button", variant='primary')
+                add_download_button(mask_gallery, "mask")
+                distance_threshold_slider = gr.Slider(0, 1, step=0.01, label="Mask Threshold", value=0.5, elem_id="distance_threshold", info="increase for smaller mask")
+                # filter_small_area_checkbox = gr.Checkbox(label="Noise Reduction", value=True, elem_id="filter_small_area_checkbox")
+                distance_power_slider = gr.Slider(-3, 3, step=0.01, label="Distance Power", value=0.5, elem_id="distance_power", info="d = d^p", visible=False)
+                crop_gallery = gr.Gallery(value=[], label="Cropped Images", show_label=True, elem_id="crop_gallery", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, interactive=False)
+                add_download_button(crop_gallery, "cropped")
+                crop_expand_slider = gr.Slider(1.0, 2.0, step=0.1, label="Crop bbox Expand Factor", value=1.0, elem_id="crop_expand", info="increase for larger crop", visible=True)
+                area_threshold_slider = gr.Slider(0, 100, step=0.1, label="Area Threshold (%)", value=3, elem_id="area_threshold", info="for noise filtering (area of connected components)", visible=False)
+
+                # logging_image = gr.Image(value=None, label="Logging Image", elem_id="logging_image", interactive=False)
+
+                # prompt_image.change(lambda x: gr.update(value=x.get('image', None)), inputs=prompt_image, outputs=[logging_image])
+
+                def relative_xy(prompts):
+                    image = prompts['image']
+                    points = np.asarray(prompts['points'])
+                    if points.shape[0] == 0:
+                        return [], []
+                    is_point = points[:, 5] == 4.0
+                    points = points[is_point]
+                    is_positive = points[:, 2] == 1.0
+                    is_negative = points[:, 2] == 0.0
+                    xy = points[:, :2].tolist()
+                    if isinstance(image, str):
+                        image = Image.open(image)
+                    image = np.array(image)
+                    h, w = image.shape[:2]
+                    new_xy = [(x/w, y/h) for x, y in xy]
+                    # print(new_xy)
+                    return new_xy, is_positive
+
+                def xy_rgb(prompts, image_idx, ncut_images):
+                    image = ncut_images[image_idx]
+                    xy, is_positive = relative_xy(prompts)
+                    rgbs = []
+                    for i, (x, y) in enumerate(xy):
+                        rgb = image.getpixel((int(x*image.width), int(y*image.height)))
+                        rgbs.append((rgb, is_positive[i]))
+                    return rgbs
+
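relative_xy unpacks the ImagePrompter value, a dict with 'image' and 'points'; going by how the code indexes it, each points row is six numbers where column 5 is a type code (4.0 marks a point prompt) and column 2 encodes the click (1.0 left/foreground, 0.0 middle/background), matching the Step 2b legend. A small check under that format assumption, with made-up coordinates:

    import numpy as np

    # hypothetical payload: one fg and one bg click on a 100x200 image
    prompts = {
        'image': np.zeros((100, 200, 3), dtype=np.uint8),
        'points': [[ 20.0, 50.0, 1.0, 0.0, 0.0, 4.0],   # left click   -> positive
                   [180.0, 10.0, 0.0, 0.0, 0.0, 4.0]],  # middle click -> negative
    }
    xy, is_positive = relative_xy(prompts)
    print(xy)           # [(0.1, 0.5), (0.9, 0.1)] -- normalized by width, height
    print(is_positive)  # [ True False]
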
+                def run_crop(original_images, ncut_images, prompts1, prompts2, prompts3, image_idx1, image_idx2, image_idx3,
+                             crop_expand, distance_threshold, distance_power, area_threshold):
+                    ncut_images = [image[0] for image in ncut_images]
+                    if len(ncut_images) == 0:
+                        return []
+                    if isinstance(ncut_images[0], str):
+                        ncut_images = [Image.open(image) for image in ncut_images]
+
+                    rgbs = xy_rgb(prompts1, image_idx1, ncut_images) + \
+                        xy_rgb(prompts2, image_idx2, ncut_images) + \
+                        xy_rgb(prompts3, image_idx3, ncut_images)
+                    # print(rgbs)
+
+
+                    ncut_images = [np.array(image).astype(np.float32) for image in ncut_images]
+                    ncut_pixels = [image.reshape(-1, 3) for image in ncut_images]
+                    h, w = ncut_images[0].shape[:2]
+                    ncut_pixels = torch.tensor(np.array(ncut_pixels).reshape(-1, 3)) / 255
+                    # normalized_ncut_pixels = F.normalize(ncut_pixels, p=2, dim=-1)
+                    positive_distances = []
+                    negative_distances = []
+                    for rgb, is_positive in rgbs:
+                        rgb = torch.tensor(rgb).float() / 255
+                        # rgb = F.normalize(rgb, p=2, dim=-1)
+                        distance = (ncut_pixels - rgb[None]).norm(dim=-1)
+                        distance = distance.squeeze(-1)
+                        if is_positive:
+                            positive_distances.append(distance)
+                        else:
+                            negative_distances.append(distance)
+                    if len(positive_distances) == 0:
+                        raise gr.Error("No prompt points. Please draw some points on the image.")
+                    positive_distances = torch.stack(positive_distances)
+                    negative_flag = len(negative_distances) > 0
+                    if len(negative_distances) == 0:
+                        negative_distances = positive_distances * 0  # dummy
+                    else:
+                        negative_distances = torch.stack(negative_distances)
+
+                    positive_distance = positive_distances.min(dim=0).values
+                    negative_distance = negative_distances.min(dim=0).values
+                    # positive_distance = positive_distances.mean(dim=0)
+                    # negative_distance = negative_distances.mean(dim=0)
+
+                    def to_mask(heatmap, threshold):
+                        heatmap = 1 / (heatmap + 1e-6)
+                        heatmap = heatmap.reshape(len(ncut_images), h, w)
+                        vmin, vmax = heatmap.quantile(0.01), heatmap.quantile(0.99)
+                        heatmap = (heatmap - vmin) / (vmax - vmin)
+                        mask = heatmap > threshold
+                        return mask
+
+                    positive_mask = to_mask(positive_distance, distance_threshold)
+                    if negative_flag:
+                        negative_mask = to_mask(negative_distance, distance_threshold)
+                        positive_mask = positive_mask & ~negative_mask
+
+
+                    # convert to PIL
+                    mask = positive_mask.cpu().numpy()
+                    mask = mask.astype(np.uint8) * 255
+                    mask = [Image.fromarray(mask[i]) for i in range(len(mask))]
+
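to_mask turns a per-pixel distance-to-nearest-prompt-color into a binary mask: invert so near means hot, rescale robustly with 1%/99% quantiles, then cut at the Mask Threshold slider value (a higher threshold gives a smaller mask, as the slider info says). A numeric sketch with the per-image reshaping omitted:

    import torch

    d = torch.tensor([0.05, 0.2, 0.8, 1.5])  # 4 pixels' distances to the nearest positive prompt RGB
    heat = 1 / (d + 1e-6)                     # invert: small distance -> large response
    vmin, vmax = heat.quantile(0.01), heat.quantile(0.99)
    heat = (heat - vmin) / (vmax - vmin)      # robust min-max rescaling
    print(heat > 0.5)                         # tensor([ True, False, False, False])
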
+                    import cv2
+                    def get_bboxes_and_clean_mask(pil_mask, min_area=500):
+                        """
+                        Args:
+                        - pil_mask: A Pillow image of a binary mask with 255 for the object and 0 for the background.
+                        - min_area: Minimum area for a connected component to be considered valid (default 500).
+
+                        Returns:
+                        - bounding_boxes: List of bounding boxes for valid objects (x, y, width, height).
+                        - cleaned_pil_mask: A Pillow image of the cleaned mask, with small components removed.
+                        """
+                        # Convert the Pillow image to a NumPy array
+                        mask = np.array(pil_mask)
+
+                        # Ensure the mask is binary (0 or 255)
+                        mask = np.where(mask > 127, 255, 0).astype(np.uint8)
+
+                        # Remove small noise using morphological operations (denoising)
+                        kernel = np.ones((5, 5), np.uint8)
+                        cleaned_mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
+
+                        # Find connected components in the cleaned mask
+                        num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(cleaned_mask, connectivity=8)
+
+                        # Initialize an empty mask to store the final cleaned mask
+                        final_cleaned_mask = np.zeros_like(cleaned_mask)
+
+                        # Collect bounding boxes for components that are larger than the threshold and update the cleaned mask
+                        bounding_boxes = []
+                        for i in range(1, num_labels):  # Skip label 0 (background)
+                            x, y, w, h, area = stats[i]
+                            if area >= min_area:
+                                # Add the bounding box of the valid component
+                                bounding_boxes.append((x, y, w, h))
+                                # Keep the valid components in the final cleaned mask
+                                final_cleaned_mask[labels == i] = 255
+
+                        # Convert the final cleaned mask back to a Pillow image
+                        cleaned_pil_mask = Image.fromarray(final_cleaned_mask)
+
+                        return bounding_boxes, cleaned_pil_mask
+
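A quick synthetic check of get_bboxes_and_clean_mask: the 5x5 morphological opening plus the min_area filter should drop a tiny blob and keep the large component's bounding box (hypothetical sizes):

    import numpy as np
    from PIL import Image

    m = np.zeros((64, 64), np.uint8)
    m[10:40, 10:40] = 255   # 30x30 component, area 900 >= min_area
    m[50:53, 50:53] = 255   # 3x3 blob, erased by the 5x5 opening
    boxes, cleaned = get_bboxes_and_clean_mask(Image.fromarray(m))
    print(boxes)            # expected: [(10, 10, 30, 30)]
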
+                    bboxs, filtered_masks = zip(*[get_bboxes_and_clean_mask(_mask) for _mask in mask])
+
+                    # combine the masks, also draw the bounding boxes
+                    combined_masks = []
+                    for i_image in range(len(mask)):
+                        noisy_mask = np.array(mask[i_image].convert("RGB"))
+                        bbox = bboxs[i_image]
+                        clean_mask = np.array(filtered_masks[i_image].convert("RGB"))
+                        combined_mask = noisy_mask * 0.4 + clean_mask
+                        combined_mask = np.clip(combined_mask, 0, 255).astype(np.uint8)
+                        for x, y, w, h in bbox:
+                            cv2.rectangle(combined_mask, (x-1, y-1), (x + w+2, y + h+2), (255, 0, 0), 2)
+                        combined_mask = Image.fromarray(combined_mask)
+                        combined_masks.append(combined_mask)
+
+                    def extend_the_mask(xywh, factor=1.5):
+                        x, y, w, h = xywh
+                        x -= w * (factor - 1) / 2
+                        y -= h * (factor - 1) / 2
+                        w *= factor
+                        h *= factor
+                        return x, y, w, h
+
+                    def resize_the_mask(xywh, original_size, target_size):
+                        x, y, w, h = xywh
+                        x *= target_size[0] / original_size[0]
+                        y *= target_size[1] / original_size[1]
+                        w *= target_size[0] / original_size[0]
+                        h *= target_size[1] / original_size[1]
+                        x, y, w, h = int(x), int(y), int(w), int(h)
+                        return x, y, w, h
+
+                    def crop_image(image, xywh, mask_h, mask_w, factor=1.0):
+                        x, y, w, h = xywh
+                        x, y, w, h = resize_the_mask((x, y, w, h), (mask_h, mask_w), image.size)
+                        _x, _y, _w, _h = extend_the_mask((x, y, w, h), factor=factor)
+                        crop = image.crop((_x, _y, _x + _w, _y + _h))
+                        return crop
+
+                    original_images = [image[0] for image in original_images]
+                    if isinstance(original_images[0], str):
+                        original_images = [Image.open(image) for image in original_images]
+
+                    mask_h, mask_w = filtered_masks[0].size
+                    cropped_images = []
+                    for _image, _bboxs in zip(original_images, bboxs):
+                        for _bbox in _bboxs:
+                            cropped_images.append(crop_image(_image, _bbox, mask_h, mask_w, factor=crop_expand))
+
+                    return combined_masks, cropped_images
+
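For reference, extend_the_mask grows a box about its center (the Crop bbox Expand Factor slider), and resize_the_mask maps mask-resolution boxes into full-image coordinates before cropping; note that PIL's .size is (width, height), so despite the names, (mask_h, mask_w) passes (width, height) into resize_the_mask, which is the order its x/y scale factors expect. A worked expansion:

    x, y, w, h = 10, 10, 30, 30
    factor = 1.5
    x -= w * (factor - 1) / 2   # 10 - 7.5 = 2.5
    y -= h * (factor - 1) / 2   # 2.5
    w *= factor                 # 45.0
    h *= factor                 # 45.0
    print((x, y, w, h))         # (2.5, 2.5, 45.0, 45.0) -- same center, 1.5x the size
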
+                run_crop_button.click(run_crop,
+                    inputs=[input_gallery, output_gallery, prompt_image1, prompt_image2, prompt_image3, image1_slider, image2_slider, image3_slider,
+                            crop_expand_slider, distance_threshold_slider, distance_power_slider, area_threshold_slider],
+                    outputs=[mask_gallery, crop_gallery])
+

     with gr.Tab('📄About'):
         with gr.Column():