Spaces: Running on Zero
add application
app.py CHANGED
@@ -419,7 +419,7 @@ def segment_fg_bg(images):
         # transform the input images
         input_images = (input_images - means) / stds
         # output = model(input_images)[:, 5]
-        output = model(input_images)['attn'][6]
+        output = model(input_images)['attn'][6]  # [B, H=14, W=14, C]
         fg_act = output[:, 6, 6].mean(0)
         bg_act = output[:, 0, 0].mean(0)
         fg_acts.append(fg_act)
@@ -455,8 +455,8 @@ def segment_fg_bg(images):
         # output = model(input_images)[:, 5]
         output = model(input_images)['attn'][6]
         output = F.normalize(output, dim=-1)
-        heatmap_fg = output @ fg_act[:, None]
-        heatmap_bg = output @ bg_act[:, None]
+        heatmap_fg = output @ fg_act[:, None]  # [B, H, W, 1]
+        heatmap_bg = output @ bg_act[:, None]  # [B, H, W, 1]
         heatmap_fgs.append(heatmap_fg.cpu())
         heatmap_bgs.append(heatmap_bg.cpu())
     heatmap_fg = torch.cat(heatmap_fgs, dim=0)
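The two hunks above only add shape comments, but they document the core of segment_fg_bg: a foreground and a background reference activation are read from fixed patch positions, and per-patch similarity heatmaps are computed against them after L2-normalizing the features. A minimal standalone sketch of that pattern, with dummy tensors standing in for the model output (the 14x14 grid is taken from the new comment; everything else here is illustrative, not the app's exact code):

import torch
import torch.nn.functional as F

B, H, W, C = 2, 14, 14, 768            # assumed batch and channel sizes; grid matches the [B, H=14, W=14, C] comment
feats = torch.randn(B, H, W, C)        # stand-in for model(input_images)['attn'][6]

fg_act = feats[:, 6, 6].mean(0)        # reference activation from a center patch (likely foreground)
bg_act = feats[:, 0, 0].mean(0)        # reference activation from a corner patch (likely background)

feats = F.normalize(feats, dim=-1)     # normalize per-patch features, as in the second hunk
heatmap_fg = feats @ fg_act[:, None]   # [B, H, W, 1] similarity to the foreground reference
heatmap_bg = feats @ bg_act[:, None]   # [B, H, W, 1] similarity to the background reference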
@@ -498,8 +498,8 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
     left = F.normalize(left, dim=-1)
     right = F.normalize(right, dim=-1)
     heatmap = left @ right.T
-    heatmap = F.normalize(heatmap, dim=-1)
-    num_samples = clusters + 20
+    heatmap = F.normalize(heatmap, dim=-1)  # [300, N_pixel] PCA-> [300, 8]
+    num_samples = clusters + 20  # 100/120
     if num_samples > fps_idx.shape[0]:
         num_samples = fps_idx.shape[0]
     r2_fps_idx = farthest_point_sampling(heatmap, num_samples)
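The farthest_point_sampling helper used above is defined elsewhere in the app (this diff does not show it), so its exact signature is an assumption. As a rough illustration of what such a sampler does on the row-normalized heatmap, a naive greedy version could look like this sketch:

import torch

def naive_farthest_point_sampling(points: torch.Tensor, num_samples: int) -> torch.Tensor:
    # points: [N, D] feature rows; returns indices of rows that are mutually far apart
    num_samples = min(num_samples, points.shape[0])
    selected = [0]                                     # start from an arbitrary row
    dist = torch.cdist(points[:1], points)[0]          # distance of every row to the selected set
    for _ in range(num_samples - 1):
        idx = int(dist.argmax())                       # pick the row farthest from everything chosen so far
        selected.append(idx)
        dist = torch.minimum(dist, torch.cdist(points[idx:idx + 1], points)[0])
    return torch.tensor(selected)

# e.g. picking clusters + 20 well-spread rows of the [300, N_pixel] heatmap before plotting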
@@ -939,7 +939,7 @@ def ncut_run(
         return video_path, logging_str

     cluster_images = None
-    if plot_clusters:
+    if plot_clusters and kwargs.get("n_ret", 1) > 1:
         start = time.time()
         progress_start = 0.6
         progress(progress_start, desc="Plotting Clusters")
@@ -955,7 +955,7 @@ def ncut_run(
         logging_str += f"plot time: {time.time() - start:.2f}s\n"

     norm_images = None
-    if alignedcut_eig_norm_plot:
+    if alignedcut_eig_norm_plot and kwargs.get("n_ret", 1) > 1:
         norm_images = []
         # eig_magnitude = torch.clamp(eig_magnitude, 0, 1)
         vmin, vmax = eig_magnitude.min(), eig_magnitude.max()
@@ -977,7 +977,7 @@ def ncut_run(


 def _ncut_run(*args, **kwargs):
-    n_ret = kwargs.
+    n_ret = kwargs.get("n_ret", 1)
     try:
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
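The three hunks above make the extra cluster and eigenvector-norm plots conditional on an n_ret value carried in kwargs (and read it with .get so it stays available for both checks), which lets a caller that only wants the embedding skip the slow plotting; the new Application tab later wires partial(run_fn, n_ret=1). A toy sketch of that keyword-threading pattern, using hypothetical stand-in functions rather than the app's own:

from functools import partial

def _run(*args, **kwargs):
    # mirror of the app's pattern: read n_ret from kwargs, build extras only when more returns are wanted
    n_ret = kwargs.get("n_ret", 1)
    outputs = ["embedding_gallery"]
    if n_ret > 1:
        outputs.append("cluster_plot")   # analogous to the plot_clusters / eig_norm_plot branches
    return outputs[:n_ret]

run_single = partial(_run, n_ret=1)      # like partial(run_fn, n_ret=1) in the Application tab
assert run_single() == ["embedding_gallery"]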
@@ -1653,8 +1653,9 @@ def load_and_append(existing_images, *args, **kwargs):
     gr.Info(f"Total images: {len(existing_images)}")
     return existing_images

-def make_input_images_section(rows=1, cols=3, height="auto", advanced=False, is_random=False, allow_download=False):
-    gr.Markdown('### Input Images')
+def make_input_images_section(rows=1, cols=3, height="auto", advanced=False, is_random=False, allow_download=False, markdown=True):
+    if markdown:
+        gr.Markdown('### Input Images')
     input_gallery = gr.Gallery(value=None, label="Input images", show_label=True, elem_id="input_images", columns=[cols], rows=[rows], object_fit="contain", height=height, type="pil", show_share_button=False,
                                format="webp")

@@ -2020,10 +2021,12 @@ def add_download_button(gallery, filename_prefix="output"):
     return create_file_button, download_button


-def make_output_images_section():
-    gr.Markdown('### Output Images')
+def make_output_images_section(markdown=True, button=True):
+    if markdown:
+        gr.Markdown('### Output Images')
     output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=True, elem_id="ncut", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, interactive=False)
-    add_rotate_flip_buttons(output_gallery)
+    if button:
+        add_rotate_flip_buttons(output_gallery)
     return output_gallery

 def make_parameters_section(is_lisa=False, model_ratio=True):
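Both section builders now take flags so a caller can opt out of the default heading and rotate/flip buttons, which is what the new Application tab does when it supplies its own "### Step ..." headings. A minimal sketch of the same optional-section pattern, with a stand-in helper rather than the app's full function:

import gradio as gr

def make_gallery_section(markdown=True, button=True):
    # stand-in for make_output_images_section: heading and extra controls are now optional
    if markdown:
        gr.Markdown('### Output Images')
    gallery = gr.Gallery(value=[], label="NCUT Embedding", columns=[3], interactive=False)
    if button:
        gr.Button("Rotate / Flip")  # placeholder for add_rotate_flip_buttons(gallery)
    return gallery

with gr.Blocks() as demo:
    with gr.Tab("AlignedCut"):
        make_gallery_section()                           # default layout, as before
    with gr.Tab("Application"):
        gr.Markdown("### Step 1: NCUT Embedding")        # the new tab writes its own heading
        make_gallery_section(markdown=False, button=False)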
@@ -2133,6 +2136,8 @@ demo = gr.Blocks(
     css=custom_css,
 )
 with demo:
+
+
     with gr.Tab('AlignedCut'):

         with gr.Row():
@@ -3081,7 +3086,304 @@ with demo:
         buttons[-1].click(fn=lambda x: gr.update(visible=True), outputs=rows[-1])
         buttons[-1].click(fn=lambda x: gr.update(visible=False), outputs=buttons[-1])

+    with gr.Tab('Application'):
+        gr.Markdown("Draw some points on the image to find corresponding segments in other images. E.g. click on one face to segment all the faces. [Video Tutorial (coming...)]()")
+        with gr.Row():
+            with gr.Column(scale=5, min_width=200):
+                gr.Markdown("### Step 0: Load Images")
+                input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section(markdown=False)
+                submit_button.visible = False
+                num_images_slider.value = 30
+                logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
+            with gr.Column(scale=5, min_width=200):
+                gr.Markdown("### Step 1: NCUT Embedding")
+                output_gallery = make_output_images_section(markdown=False, button=False)
+                submit_button = gr.Button("🔴 RUN", elem_id="submit_button", variant='primary')
+                add_rotate_flip_buttons(output_gallery)
+                [
+                    model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
+                    affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
+                    embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
+                    perplexity_slider, n_neighbors_slider, min_dist_slider,
+                    sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt
+                ] = make_parameters_section()
+
+                false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
+                no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
+
+                submit_button.click(
+                    partial(run_fn, n_ret=1),
+                    inputs=[
+                        input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
+                        positive_prompt, negative_prompt,
+                        false_placeholder, no_prompt, no_prompt, no_prompt,
+                        affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
+                        embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
+                        perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown
+                    ],
+                    outputs=[output_gallery, logging_text],
+                )
+
+            with gr.Column(scale=5, min_width=200):
+                gr.Markdown("### Step 2a: Pick an Image")
+                from gradio_image_prompter import ImagePrompter
+                image_type_radio = gr.Radio(["Original", "NCUT"], label="Image Display Type", value="Original", elem_id="image_type_radio")
+                with gr.Row():
+                    image1_slider = gr.Slider(0, 100, step=1, label="Image#1 Index", value=0, elem_id="image1_slider", interactive=True)
+                    image2_slider = gr.Slider(0, 100, step=1, label="Image#2 Index", value=1, elem_id="image2_slider", interactive=True)
+                    image3_slider = gr.Slider(0, 100, step=1, label="Image#3 Index", value=2, elem_id="image3_slider", interactive=True)
+                load_one_image_button = gr.Button("🔴 Load", elem_id="load_one_image_button", variant='primary')
+                gr.Markdown("### Step 2b: Draw Points")
+                gr.Markdown("##### 🖱️ Left Click: Foreground")
+                gr.Markdown("##### 🖱️ Middle Click: Background")
+                gr.Markdown("""
+                <h5>
+                Top Right
+                <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none"
+                    stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"
+                    style="vertical-align: middle; height: 1em; width: 1em; display: inline;">
+                    <polyline points="1 4 1 10 7 10"></polyline>
+                    <path d="M3.51 15a9 9 0 1 0 2.13-9.36L1 10"></path>
+                </svg> :
+                Remove Last Point
+                </h5>
+                """)
+                prompt_image1 = ImagePrompter(show_label=False, elem_id="prompt_image", interactive=True)
+                prompt_image2 = ImagePrompter(show_label=False, elem_id="prompt_image", interactive=True)
+                prompt_image3 = ImagePrompter(show_label=False, elem_id="prompt_image", interactive=True)
+                # def update_number_of_images(images):
+                #     if images is None:
+                #         return gr.update(max=0, value=0)
+                #     return gr.update(max=len(images)-1, value=1)
+                # input_gallery.change(update_number_of_images, inputs=input_gallery, outputs=image1_slider)
+
+                def update_prompt_image(original_images, ncut_images, image_type, index):
+                    if image_type == "Original":
+                        images = original_images
+                    else:
+                        images = ncut_images
+                    if images is None:
+                        return
+                    total_len = len(images)
+                    if total_len == 0:
+                        return
+                    if index >= total_len:
+                        index = total_len - 1
+                    return gr.update(value={'image': images[index][0]})
+                load_one_image_button.click(update_prompt_image, inputs=[input_gallery, output_gallery, image_type_radio, image1_slider], outputs=[prompt_image1])
+                load_one_image_button.click(update_prompt_image, inputs=[input_gallery, output_gallery, image_type_radio, image2_slider], outputs=[prompt_image2])
+                load_one_image_button.click(update_prompt_image, inputs=[input_gallery, output_gallery, image_type_radio, image3_slider], outputs=[prompt_image3])
+
+                image3_slider.visible = False
+                prompt_image3.visible = False
+
+
+
+            with gr.Column(scale=5, min_width=200):
+                gr.Markdown("### Step 3: Segment and Crop")
+                mask_gallery = gr.Gallery(value=[], label="Segmentation Masks", show_label=True, elem_id="mask_gallery", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, interactive=False)
+                run_crop_button = gr.Button("🔴 RUN", elem_id="run_crop_button", variant='primary')
+                add_download_button(mask_gallery, "mask")
+                distance_threshold_slider = gr.Slider(0, 1, step=0.01, label="Mask Threshold", value=0.5, elem_id="distance_threshold", info="increase for smaller mask")
+                # filter_small_area_checkbox = gr.Checkbox(label="Noise Reduction", value=True, elem_id="filter_small_area_checkbox")
+                distance_power_slider = gr.Slider(-3, 3, step=0.01, label="Distance Power", value=0.5, elem_id="distance_power", info="d = d^p", visible=False)
+                crop_gallery = gr.Gallery(value=[], label="Cropped Images", show_label=True, elem_id="crop_gallery", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, interactive=False)
+                add_download_button(crop_gallery, "cropped")
+                crop_expand_slider = gr.Slider(1.0, 2.0, step=0.1, label="Crop bbox Expand Factor", value=1.0, elem_id="crop_expand", info="increase for larger crop", visible=True)
+                area_threshold_slider = gr.Slider(0, 100, step=0.1, label="Area Threshold (%)", value=3, elem_id="area_threshold", info="for noise filtering (area of connected components)", visible=False)
+
+                # logging_image = gr.Image(value=None, label="Logging Image", elem_id="logging_image", interactive=False)
+
+                # prompt_image.change(lambda x: gr.update(value=x.get('image', None)), inputs=prompt_image, outputs=[logging_image])
+
+                def relative_xy(prompts):
+                    image = prompts['image']
+                    points = np.asarray(prompts['points'])
+                    if points.shape[0] == 0:
+                        return [], []
+                    is_point = points[:, 5] == 4.0
+                    points = points[is_point]
+                    is_positive = points[:, 2] == 1.0
+                    is_negative = points[:, 2] == 0.0
+                    xy = points[:, :2].tolist()
+                    if isinstance(image, str):
+                        image = Image.open(image)
+                        image = np.array(image)
+                    h, w = image.shape[:2]
+                    new_xy = [(x/w, y/h) for x, y in xy]
+                    # print(new_xy)
+                    return new_xy, is_positive
+
+                def xy_rgb(prompts, image_idx, ncut_images):
+                    image = ncut_images[image_idx]
+                    xy, is_positive = relative_xy(prompts)
+                    rgbs = []
+                    for i, (x, y) in enumerate(xy):
+                        rgb = image.getpixel((int(x*image.width), int(y*image.height)))
+                        rgbs.append((rgb, is_positive[i]))
+                    return rgbs
+
+                def run_crop(original_images, ncut_images, prompts1, prompts2, prompts3, image_idx1, image_idx2, image_idx3,
+                             crop_expand, distance_threshold, distance_power, area_threshold):
+                    ncut_images = [image[0] for image in ncut_images]
+                    if len(ncut_images) == 0:
+                        return []
+                    if isinstance(ncut_images[0], str):
+                        ncut_images = [Image.open(image) for image in ncut_images]
+
+                    rgbs = xy_rgb(prompts1, image_idx1, ncut_images) + \
+                           xy_rgb(prompts2, image_idx2, ncut_images) + \
+                           xy_rgb(prompts3, image_idx3, ncut_images)
+                    # print(rgbs)
+
+
+                    ncut_images = [np.array(image).astype(np.float32) for image in ncut_images]
+                    ncut_pixels = [image.reshape(-1, 3) for image in ncut_images]
+                    h, w = ncut_images[0].shape[:2]
+                    ncut_pixels = torch.tensor(np.array(ncut_pixels).reshape(-1, 3)) / 255
+                    # normalized_ncut_pixels = F.normalize(ncut_pixels, p=2, dim=-1)
+                    positive_distances = []
+                    negative_distances = []
+                    for rgb, is_positive in rgbs:
+                        rgb = torch.tensor(rgb).float() / 255
+                        # rgb = F.normalize(rgb, p=2, dim=-1)
+                        distance = (ncut_pixels - rgb[None]).norm(dim=-1)
+                        distance = distance.squeeze(-1)
+                        if is_positive:
+                            positive_distances.append(distance)
+                        else:
+                            negative_distances.append(distance)
+                    if len(positive_distances) == 0:
+                        raise gr.Error("No prompt points. Please draw some points on the image.")
+                    positive_distances = torch.stack(positive_distances)
+                    negative_flag = len(negative_distances) > 0
+                    if len(negative_distances) == 0:
+                        negative_distances = positive_distances * 0  # dummy
+                    else:
+                        negative_distances = torch.stack(negative_distances)
+
+                    positive_distance = positive_distances.min(dim=0).values
+                    negative_distance = negative_distances.min(dim=0).values
+                    # positive_distance = positive_distances.mean(dim=0)
+                    # negative_distance = negative_distances.mean(dim=0)
+
+                    def to_mask(heatmap, threshold):
+                        heatmap = 1 / (heatmap + 1e-6)
+                        heatmap = heatmap.reshape(len(ncut_images), h, w)
+                        vmin, vmax = heatmap.quantile(0.01), heatmap.quantile(0.99)
+                        heatmap = (heatmap - vmin) / (vmax - vmin)
+                        mask = heatmap > threshold
+                        return mask
+
+                    positive_mask = to_mask(positive_distance, distance_threshold)
+                    if negative_flag:
+                        negative_mask = to_mask(negative_distance, distance_threshold)
+                        positive_mask = positive_mask & ~negative_mask
+
+
+                    # convert to PIL
+                    mask = positive_mask.cpu().numpy()
+                    mask = mask.astype(np.uint8) * 255
+                    mask = [Image.fromarray(mask[i]) for i in range(len(mask))]
+
+                    import cv2
+                    def get_bboxes_and_clean_mask(pil_mask, min_area=500):
+                        """
+                        Args:
+                        - pil_mask: A Pillow image of a binary mask with 255 for the object and 0 for the background.
+                        - min_area: Minimum area for a connected component to be considered valid (default 500).
+
+                        Returns:
+                        - bounding_boxes: List of bounding boxes for valid objects (x, y, width, height).
+                        - cleaned_pil_mask: A Pillow image of the cleaned mask, with small components removed.
+                        """
+                        # Convert the Pillow image to a NumPy array
+                        mask = np.array(pil_mask)
+
+                        # Ensure the mask is binary (0 or 255)
+                        mask = np.where(mask > 127, 255, 0).astype(np.uint8)
+
+                        # Remove small noise using morphological operations (denoising)
+                        kernel = np.ones((5, 5), np.uint8)
+                        cleaned_mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
+
+                        # Find connected components in the cleaned mask
+                        num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(cleaned_mask, connectivity=8)
+
+                        # Initialize an empty mask to store the final cleaned mask
+                        final_cleaned_mask = np.zeros_like(cleaned_mask)
+
+                        # Collect bounding boxes for components that are larger than the threshold and update the cleaned mask
+                        bounding_boxes = []
+                        for i in range(1, num_labels):  # Skip label 0 (background)
+                            x, y, w, h, area = stats[i]
+                            if area >= min_area:
+                                # Add the bounding box of the valid component
+                                bounding_boxes.append((x, y, w, h))
+                                # Keep the valid components in the final cleaned mask
+                                final_cleaned_mask[labels == i] = 255
+
+                        # Convert the final cleaned mask back to a Pillow image
+                        cleaned_pil_mask = Image.fromarray(final_cleaned_mask)
+
+                        return bounding_boxes, cleaned_pil_mask
+
+                    bboxs, filtered_masks = zip(*[get_bboxes_and_clean_mask(_mask) for _mask in mask])
+
+                    # combine the masks, also draw the bounding boxes
+                    combined_masks = []
+                    for i_image in range(len(mask)):
+                        noisy_mask = np.array(mask[i_image].convert("RGB"))
+                        bbox = bboxs[i_image]
+                        clean_mask = np.array(filtered_masks[i_image].convert("RGB"))
+                        combined_mask = noisy_mask * 0.4 + clean_mask
+                        combined_mask = np.clip(combined_mask, 0, 255).astype(np.uint8)
+                        for x, y, w, h in bbox:
+                            cv2.rectangle(combined_mask, (x-1, y-1), (x + w+2, y + h+2), (255, 0, 0), 2)
+                        combined_mask = Image.fromarray(combined_mask)
+                        combined_masks.append(combined_mask)
+
+                    def extend_the_mask(xywh, factor=1.5):
+                        x, y, w, h = xywh
+                        x -= w * (factor - 1) / 2
+                        y -= h * (factor - 1) / 2
+                        w *= factor
+                        h *= factor
+                        return x, y, w, h
+
+                    def resize_the_mask(xywh, original_size, target_size):
+                        x, y, w, h = xywh
+                        x *= target_size[0] / original_size[0]
+                        y *= target_size[1] / original_size[1]
+                        w *= target_size[0] / original_size[0]
+                        h *= target_size[1] / original_size[1]
+                        x, y, w, h = int(x), int(y), int(w), int(h)
+                        return x, y, w, h
+
+                    def crop_image(image, xywh, mask_h, mask_w, factor=1.0):
+                        x, y, w, h = xywh
+                        x, y, w, h = resize_the_mask((x, y, w, h), (mask_h, mask_w), image.size)
+                        _x, _y, _w, _h = extend_the_mask((x, y, w, h), factor=factor)
+                        crop = image.crop((_x, _y, _x + _w, _y + _h))
+                        return crop
+
+                    original_images = [image[0] for image in original_images]
+                    if isinstance(original_images[0], str):
+                        original_images = [Image.open(image) for image in original_images]
+
+                    mask_h, mask_w = filtered_masks[0].size
+                    cropped_images = []
+                    for _image, _bboxs in zip(original_images, bboxs):
+                        for _bbox in _bboxs:
+                            cropped_images.append(crop_image(_image, _bbox, mask_h, mask_w, factor=crop_expand))
+
+                    return combined_masks, cropped_images
+
+                run_crop_button.click(run_crop,
+                                      inputs=[input_gallery, output_gallery, prompt_image1, prompt_image2, prompt_image3, image1_slider, image2_slider, image3_slider,
+                                              crop_expand_slider, distance_threshold_slider, distance_power_slider, area_threshold_slider],
+                                      outputs=[mask_gallery, crop_gallery])
+

     with gr.Tab('📄About'):
         with gr.Column():