Spaces: Running on Zero

add mlp align

app.py CHANGED
@@ -7,9 +7,9 @@ from functools import partial
 from io import BytesIO
 import json
 import os
-from pprint import pprint
 import uuid
 import zipfile
+import multiprocessing as mp
 
 from einops import rearrange
 from matplotlib import pyplot as plt
@@ -42,7 +42,6 @@ from ncut_pytorch.backbone import MODEL_DICT, LAYER_DICT, RES_DICT
 from ncut_pytorch import NCUT
 from ncut_pytorch import eigenvector_to_rgb, rotate_rgb_cube
 
-RUN_COUNT = 0
 
 DATASETS = {
     'Common': [
@@ -314,8 +313,7 @@ def blend_image_with_heatmap(image, heatmap, opacity1=0.5, opacity2=0.5):
     blended = (1 - opacity1) * image + opacity2 * heatmap
     return blended.astype(np.uint8)
 
-
-load_model("CLIP(ViT-B-16/openai)")
+
 def segment_fg_bg(images):
 
     images = F.interpolate(images, (224, 224), mode="bilinear")
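Note: with the module-level load_model("CLIP(ViT-B-16/openai)") call deleted, importing app.py no longer eagerly instantiates CLIP. A common lazy-load alternative, shown only as a hypothetical sketch with a stubbed load_model (not what this commit does):

from functools import lru_cache

def load_model(name):
    ...  # stand-in for the app's real load_model

@lru_cache(maxsize=1)
def get_fg_bg_model():
    # loaded once on first use instead of at import time
    return load_model("CLIP(ViT-B-16/openai)")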
@@ -388,6 +386,8 @@ def segment_fg_bg(images):
 
 
 def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=False, clusters=50, eig_idx=None, title='cluster'):
+    if clusters == 0:
+        return [], []
     progress = gr.Progress()
     progress(progess_start, desc="Finding Clusters by FPS")
     device = 'cuda' if torch.cuda.is_available() else 'cpu'
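Note: the new guard matters because make_cluster_plot later computes num_plots = clusters // 5 and divides by it, so clusters=0 (which make_cluster_plot_advanced now passes for the "other" group) would otherwise raise ZeroDivisionError. A minimal standalone repro with a hypothetical stub:

def make_cluster_plot_stub(clusters=50, progess_start=0.6):
    if clusters == 0:
        return [], []  # new guard: nothing to plot
    num_plots = clusters // 5
    plot_step_float = (1.0 - progess_start) / num_plots  # ZeroDivisionError without the guard
    return num_plots, plot_step_float

assert make_cluster_plot_stub(0) == ([], [])
print(make_cluster_plot_stub(20))  # (4, 0.1)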
@@ -462,7 +462,7 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
         # fps_heatmaps[idx.item()] = heatmap.cpu()
         fps_heatmaps[idx.item()] = heatmap[mask_sort_idx[:6]].cpu()
         top3_image_idx[idx.item()] = mask_sort_idx[:3]
-        top10_image_idx[idx.item()] = mask_sort_idx[:10]
+        top10_image_idx[idx.item()] = mask_sort_idx[:6]
     # do the sorting
     _sort_idx = torch.tensor(sort_values).argsort(descending=True)
     fps_idx = fps_idx[_sort_idx]
@@ -480,51 +480,46 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
     fps_idx = fps_idx[torch.randperm(fps_idx.shape[0])]
 
 
-
-
-    num_plots = clusters // 5
-    plot_step_float = (1.0 - progess_start) / num_plots
-    for i_fig in range(num_plots):
-        progress(progess_start + i_fig * plot_step_float, desc=f"Plotting {title}")
-        if not advanced:
-            fig, axs = plt.subplots(3, 5, figsize=(15, 9))
-        if advanced:
-            fig, axs = plt.subplots(5, 5, figsize=(15, 15))
+    def plot_cluster_images(fps_idx_chunk, chunk_idx):
+        fig, axs = plt.subplots(3, 5, figsize=(15, 9)) if not advanced else plt.subplots(6, 5, figsize=(15, 18))
         for ax in axs.flatten():
             ax.axis("off")
-        for j, idx in enumerate(fps_idx[i_fig*5:(i_fig+1)*5]):
+        for j, idx in enumerate(fps_idx_chunk):
             heatmap = fps_heatmaps[idx.item()]
-            # mask = (heatmap > 0.1).float()
-            # sorted_image_idxs = torch.argsort(mask.mean((1, 2)), descending=True)
             size = (images.shape[1], images.shape[2])
             heatmap = apply_reds_colormap(heatmap, size)
-            # for i, image_idx in enumerate(sorted_image_idxs[:3]):
             image_idxs = top3_image_idx[idx.item()] if not advanced else top10_image_idx[idx.item()]
             for i, image_idx in enumerate(image_idxs):
-                # _heatmap = blend_image_with_heatmap(images[image_idx], heatmap[image_idx])
                 _heatmap = blend_image_with_heatmap(images[image_idx], heatmap[i])
                 axs[i, j].imshow(_heatmap)
                 if i == 0:
-                    axs[i, j].set_title(f"{title} {i_cluster}", fontsize=24)
-            i_cluster += 1
+                    axs[i, j].set_title(f"{title} {chunk_idx * 5 + j + 1}", fontsize=24)
         plt.tight_layout(h_pad=0.5, w_pad=0.3)
-
-        filename = uuid.uuid4()
+        filename = f"{datetime.now():%Y%m%d%H%M%S%f}_{uuid.uuid4().hex}"
         tmp_path = f"/tmp/{filename}.png"
         plt.savefig(tmp_path, bbox_inches='tight', dpi=72)
-
-        img = Image.open(tmp_path)
-        img = img.convert("RGB")
-        img = copy.deepcopy(img)
-
+        img = Image.open(tmp_path).convert("RGB")
         os.remove(tmp_path)
-
-        fig_images.append(img)
         plt.close()
+        return img
+
+    fig_images = []
+    num_plots = clusters // 5
+    plot_step_float = (1.0 - progess_start) / num_plots
+    fps_idx_chunks = [fps_idx[i*5:(i+1)*5] for i in range(num_plots)]
+
+    # with mp.Pool(processes=mp.cpu_count()) as pool:
+    #     results = [pool.apply_async(plot_cluster_images, args=(chunk, i)) for i, chunk in enumerate(fps_idx_chunks)]
+    #     for i, result in enumerate(results):
+    #         progress(progess_start + i * plot_step_float, desc=f"Plotted {title}")
+    #         fig_images.append(result.get())
+    for i, chunk in enumerate(fps_idx_chunks):
+        progress(progess_start + i * plot_step_float, desc=f"Plotted {title}")
+        fig_images.append(plot_cluster_images(chunk, i))
 
     return fig_images, ret_magnitude
 
-def make_cluster_plot_advanced(eigvecs, images, h=64, w=64, num_fg=100, num_bg=1
+def make_cluster_plot_advanced(eigvecs, images, h=64, w=64):
     heatmap_fg, heatmap_bg = segment_fg_bg(images.clone())
     heatmap_bg = rearrange(heatmap_bg, 'b h w c -> b c h w')
     heatmap_fg = rearrange(heatmap_fg, 'b h w c -> b c h w')
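Note: the refactor pulls the per-figure rendering (5 clusters per figure) into a local plot_cluster_images so it could be parallelized, but the mp.Pool path stays commented out and the loop runs serially. A sketch of what a working process-parallel version would need, offered as an assumption rather than the commit's code: workers must pick the headless 'Agg' backend before importing pyplot, and the task must be picklable (CPU data, no closure over live objects):

import multiprocessing as mp

def render_chunk(args):
    import matplotlib
    matplotlib.use("Agg")  # fork-safe, headless backend
    import matplotlib.pyplot as plt
    chunk, chunk_idx = args  # picklable inputs only
    fig, axs = plt.subplots(3, 5, figsize=(15, 9))
    for ax in axs.flatten():
        ax.axis("off")
    fig.canvas.draw()  # render off-screen
    plt.close(fig)
    return chunk_idx

if __name__ == "__main__":
    chunks = [(list(range(i * 5, (i + 1) * 5)), i) for i in range(4)]
    with mp.Pool(processes=2) as pool:
        print(pool.map(render_chunk, chunks))  # [0, 1, 2, 3]

The nested plot_cluster_images in the hunk closes over fps_heatmaps and images, which is exactly what makes it hard to hand to mp.Pool as-is: Pool pickles the callable, and nested functions cannot be pickled.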
@@ -545,15 +540,9 @@ def make_cluster_plot_advanced(eigvecs, images, h=64, w=64, num_fg=100, num_bg=1
     bg_idx = torch.arange(heatmap_bg.shape[0])[bg_mask]
     other_idx = torch.arange(heatmap_fg.shape[0])[other_mask]
 
-    fg_images = []
-    if num_fg > 0:
-        fg_images, _ = make_cluster_plot(eigvecs, images, h=h, w=w, advanced=True, clusters=num_fg, eig_idx=fg_idx, title="fg" if small_title else "cluster")
-    bg_images = []
-    if num_bg > 0:
-        bg_images, _ = make_cluster_plot(eigvecs, images, h=h, w=w, advanced=True, clusters=num_bg, eig_idx=bg_idx, title="bg" if small_title else "cluster")
-    other_images = []
-    if num_other > 0:
-        other_images, _ = make_cluster_plot(eigvecs, images, h=h, w=w, advanced=True, clusters=num_other, eig_idx=other_idx, title="other" if small_title else "cluster")
+    fg_images, _ = make_cluster_plot(eigvecs, images, h=h, w=w, advanced=True, clusters=100, eig_idx=fg_idx, title="fg")
+    bg_images, _ = make_cluster_plot(eigvecs, images, h=h, w=w, advanced=True, clusters=20, eig_idx=bg_idx, title="bg")
+    other_images, _ = make_cluster_plot(eigvecs, images, h=h, w=w, advanced=True, clusters=0, eig_idx=other_idx, title="other")
 
     cluster_images = fg_images + bg_images + other_images
 
@@ -842,7 +831,7 @@ def ncut_run(
     if advanced:
         cluster_images, eig_magnitude = make_cluster_plot_advanced(eigvecs, _images, h=h, w=w)
     else:
-        cluster_images, eig_magnitude =
+        cluster_images, eig_magnitude = make_cluster_plot(eigvecs, _images, h=h, w=w, progess_start=progress_start, advanced=False)
     logging_str += f"plot time: {time.time() - start:.2f}s\n"
 
     norm_images = None
@@ -903,10 +892,6 @@ if USE_HUGGINGFACE_ZEROGPU:
     def longer_run(*args, **kwargs):
        return _ncut_run(*args, **kwargs)
 
-    @spaces.GPU(duration=90)
-    def quite_long_run(*args, **kwargs):
-        return _ncut_run(*args, **kwargs)
-
     @spaces.GPU(duration=120)
     def super_duper_long_run(*args, **kwargs):
         return _ncut_run(*args, **kwargs)
@@ -924,9 +909,6 @@ if not USE_HUGGINGFACE_ZEROGPU:
     def longer_run(*args, **kwargs):
         return _ncut_run(*args, **kwargs)
 
-    def quite_long_run(*args, **kwargs):
-        return _ncut_run(*args, **kwargs)
-
    def super_duper_long_run(*args, **kwargs):
        return _ncut_run(*args, **kwargs)
 
@@ -1241,10 +1223,7 @@ def run_fn(
         "alignedcut_eig_norm_plot": alignedcut_eig_norm_plot,
         "advanced": advanced,
     }
-
-    RUN_COUNT += 1
-    print(f"Run Count: {RUN_COUNT}")
-    pprint(kwargs)
+    # print(kwargs)
 
    try:
        # try to aquiare GPU, can fail if the user is out of GPU quota
@@ -1256,15 +1235,13 @@
             return super_duper_long_run(model, images, **kwargs)
 
         num_images = len(images)
-        if num_images
+        if num_images >= 100:
             return super_duper_long_run(model, images, **kwargs)
         if 'diffusion' in model_name.lower():
             return super_duper_long_run(model, images, **kwargs)
         if recursion:
             return longer_run(model, images, **kwargs)
-        if num_images
-            return quite_long_run(model, images, **kwargs)
-        if num_images > 30:
+        if num_images >= 50:
             return longer_run(model, images, **kwargs)
         if old_school_ncut:
             return longer_run(model, images, **kwargs)
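Note: with quite_long_run (duration=90) removed, dispatch collapses to two tiers: longer_run for mid-size jobs and super_duper_long_run (duration=120) for >= 100 images or diffusion models. A self-contained sketch of the tiered-reservation pattern, assuming the ZeroGPU spaces package and a stub in place of the app's _ncut_run:

import spaces

def _ncut_run_stub(*args, **kwargs):
    ...  # stand-in for the app's _ncut_run

@spaces.GPU(duration=60)
def longer_run(*args, **kwargs):
    return _ncut_run_stub(*args, **kwargs)

@spaces.GPU(duration=120)
def super_duper_long_run(*args, **kwargs):
    return _ncut_run_stub(*args, **kwargs)

def route(model, images, **kwargs):
    # reserve the longer ZeroGPU slot up front for heavy jobs,
    # so they are not killed when a short slot expires mid-run
    if len(images) >= 100:
        return super_duper_long_run(model, images, **kwargs)
    return longer_run(model, images, **kwargs)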
@@ -1284,7 +1261,7 @@
     except gr.Error as e:
         # I assume this is a GPU quota error
 
-        info1 = 'Running out of HuggingFace GPU Quota?</br> Please try <a style="white-space: nowrap;text-underline-offset: 2px;color: var(--body-text-color)" href="https://ncut-pytorch.readthedocs.io/en/latest/demo/">Demo hosted at UPenn</a>'
+        info1 = 'Running out of HuggingFace GPU Quota?</br> Please try <a style="white-space: nowrap;text-underline-offset: 2px;color: var(--body-text-color)" href="https://ncut-pytorch.readthedocs.io/en/latest/demo/">Demo hosted at UPenn</a></br>'
         info2 = 'Or try use the Python package that powers this app: <a style="white-space: nowrap;text-underline-offset: 2px;color: var(--body-text-color)" href="https://ncut-pytorch.readthedocs.io/en/latest/">ncut-pytorch</a>'
         info = info1 + info2
 
@@ -1292,6 +1269,165 @@
         raise gr.Error(message, duration=0)
 
 
+import torch
+from torch import nn
+from torch.utils.data import Dataset, DataLoader
+import pytorch_lightning as pl
+
+# Custom Dataset
+class TwoTensorDataset(Dataset):
+    def __init__(self, A, B):
+        self.A = A
+        self.B = B
+
+    def __len__(self):
+        return len(self.A)
+
+    def __getitem__(self, idx):
+        return self.A[idx], self.B[idx]
+
+# MLP model
+class MLP(pl.LightningModule):
+    def __init__(self, num_layer=3, width=512, lr=3e-4, fitting_steps=10000, seg_loss_lambda=1.0):
+        super().__init__()
+        layers = [nn.Linear(3, width), nn.GELU()]
+        for _ in range(num_layer - 1):
+            layers.append(nn.Linear(width, width))
+            layers.append(nn.GELU())
+        layers.append(nn.Linear(width, 3))
+        self.layers = nn.Sequential(*layers)
+        self.mse_loss = nn.MSELoss()
+        self.lr = lr
+        self.fitting_steps = fitting_steps
+        self.seg_loss_lambda = seg_loss_lambda
+        self.progress = gr.Progress()
+
+    def forward(self, x):
+        return self.layers(x)
+
+    def training_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self.forward(x)
+        loss = self.mse_loss(y_hat, y)
+        # loss = torch.nn.functional.mse_loss(torch.log(y_hat), torch.log(y))
+        self.log("train_loss", loss)
+
+        # add segmentation constraint
+        bsz = x.shape[0]
+        sample_size = 1000
+        if bsz > sample_size:
+            idx = torch.randperm(bsz)[:sample_size]
+            x = x[idx]
+            y_hat = y_hat[idx]
+
+        old_dist = torch.pdist(x, p=2)
+        new_dist = torch.pdist(y_hat, p=2)
+        # seg_loss = torch.log((old_dist - new_dist)).pow(2).mean()
+        seg_loss = self.mse_loss(old_dist, new_dist)
+        self.log("seg_loss", seg_loss)
+        loss += seg_loss * self.seg_loss_lambda
+
+        step = self.global_step
+        if step % 100 == 0:
+            self.progress(step / self.fitting_steps, desc="Fitting MLP")
+
+        return loss
+
+    def predict_step(self, batch, batch_idx, dataloader_idx=None):
+        x = batch[0]
+        y_hat = self.forward(x)
+        return y_hat
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
+        return optimizer
+
+
+def fit_trans(rgb1, rgb2, num_layer=3, width=512, batch_size=256, lr=3e-4, fitting_steps=10000, fps_sample=4096, seg_loss_lambda=1.0):
+    A = rgb1.clone()
+    B = rgb2.clone()
+
+    # FPS sample on the data
+    from ncut_pytorch.ncut_pytorch import farthest_point_sampling
+    A_idx = farthest_point_sampling(A, fps_sample)
+    B_idx = farthest_point_sampling(B, fps_sample)
+    A_B_idx = np.concatenate([A_idx, B_idx])
+    A = A[A_B_idx]
+    B = B[A_B_idx]
+
+    from torch.utils.data import DataLoader, TensorDataset
+    # Dataset and DataLoader
+    dataset = TwoTensorDataset(A, B)
+    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
+
+    # Initialize model and trainer
+    mlp = MLP(num_layer=num_layer, width=width, lr=lr, fitting_steps=fitting_steps, seg_loss_lambda=seg_loss_lambda)
+    trainer = pl.Trainer(
+        auto_scale_batch_size='power',
+        max_epochs=100000,
+        gpus=1,
+        max_steps=fitting_steps,
+        enable_checkpointing=False,
+        enable_progress_bar=False,
+        gradient_clip_val=1.0
+    )
+
+    # Create a DataLoader for tensor A
+    batch_size = 256  # Define your batch size
+    data_loader = DataLoader(TensorDataset(rgb1), batch_size=batch_size, shuffle=False)
+
+
+    # Train the model
+    trainer.fit(mlp, dataloader)
+
+
+    results = trainer.predict(mlp, data_loader)
+    A_transformed = torch.cat(results, dim=0)
+
+    return A_transformed
+
+if USE_HUGGINGFACE_ZEROGPU:
+    @spaces.GPU(duration=60)
+    def _run_mlp_fit(*args, **kwargs):
+        return fit_trans(*args, **kwargs)
+else:
+    def _run_mlp_fit(*args, **kwargs):
+        return fit_trans(*args, **kwargs)
+
+
+def run_mlp_fit(input_gallery, target_gallery, num_layer=3, width=512, batch_size=256, lr=3e-4, fitting_steps=10000, fps_sample=4096, seg_loss_lambda=1.0):
+    # print("Fitting MLP")
+    # print("Target Gallery Length:", len(target_gallery))
+    # print("Input Gallery Length:", len(input_gallery))
+    if target_gallery is None or len(target_gallery) == 0:
+        raise gr.Error("No target images selected. Please use the Mark button to select the target images.")
+    if input_gallery is None or len(input_gallery) == 0:
+        raise gr.Error("No input images selected.")
+    def gallery_to_rgb(gallery):
+        images = [tup[0] for tup in gallery]
+        rgb = []
+        for image in images:
+            if isinstance(image, str):
+                image = Image.open(image)
+            image = image.convert('RGB')
+            image = torch.tensor(np.array(image)).float()
+            image = image / 255
+            rgb.append(image)
+        rgb = torch.stack(rgb)
+        shape = rgb.shape
+        rgb = rgb.reshape(-1, 3)
+        return rgb, shape
+
+    target_rgb, target_shape = gallery_to_rgb(target_gallery)
+    input_rgb, input_shape = gallery_to_rgb(input_gallery)
+
+    input_transformed = _run_mlp_fit(input_rgb, target_rgb, num_layer=num_layer, width=width, batch_size=batch_size, lr=lr,
+                                     fitting_steps=fitting_steps, fps_sample=fps_sample, seg_loss_lambda=seg_loss_lambda)
+    input_transformed = input_transformed.reshape(*input_shape)
+    pil_images = to_pil_images(input_transformed, resize=False)
+    return pil_images
+
+
 def make_input_video_section():
     # gr.Markdown('### Input Video')
     input_gallery = gr.Video(value=None, label="Select video", elem_id="video-input", height="auto", show_share_button=False, interactive=True)
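Note: fit_trans trains a small MLP (a Linear/GELU stack, 3 -> width -> 3) to map one RGB colormap onto another, with an extra torch.pdist term so pairwise distances, and hence the apparent segmentation, survive the recoloring. Two portability caveats worth flagging: pl.Trainer(gpus=1, auto_scale_batch_size='power') uses pytorch_lightning 1.x arguments (removed in 2.x in favor of accelerator/devices), and the second batch_size = 256 assignment silently shadows the function's batch_size parameter for the prediction loader. A dependency-free sketch of the same objective in plain PyTorch, as a hypothetical reduction rather than the commit's code:

import torch
from torch import nn

def fit_color_map(A, B, width=64, steps=200, lr=3e-4, seg_lambda=1.0):
    # A, B: (N, 3) RGB tensors in [0, 1]
    mlp = nn.Sequential(nn.Linear(3, width), nn.GELU(), nn.Linear(width, 3))
    opt = torch.optim.Adam(mlp.parameters(), lr=lr)
    for _ in range(steps):
        pred = mlp(A)
        loss = nn.functional.mse_loss(pred, B)
        # segmentation-preserving term: pairwise distances before ~= after
        loss = loss + seg_lambda * nn.functional.mse_loss(torch.pdist(A), torch.pdist(pred))
        opt.zero_grad()
        loss.backward()
        opt.step()
    return mlp

A, B = torch.rand(256, 3), torch.rand(256, 3)
mlp = fit_color_map(A, B)
print(mlp(A).shape)  # torch.Size([256, 3])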
@@ -1426,7 +1562,7 @@ def make_input_images_section(rows=1, cols=3, height="auto", advanced=False, is_
     def make_example(name, images, dataset_name):
         with gr.Row():
             button = gr.Button("Load\n"+name, elem_id=f"example-{name}", elem_classes="small-button", variant='secondary', size="sm", scale=1, min_width=60)
-            gallery = gr.Gallery(value=images, label=name, show_label=True, columns=[3], rows=[1], interactive=False, height=80, scale=8, object_fit="cover", min_width=140)
+            gallery = gr.Gallery(value=images, label=name, show_label=True, columns=[3], rows=[1], interactive=False, height=80, scale=8, object_fit="cover", min_width=140, allow_preview=False)
         button.click(fn=lambda: gr.update(value=load_dataset_images(True, dataset_name, 100, is_random=True, seed=42)), outputs=[input_gallery])
         return gallery, button
     example_items = [
@@ -1641,7 +1777,7 @@ def flip_rgb_gallery(images, axis=0):
     images = to_pil_images(images, resize=False)
     return images
 
-def
+def add_rotate_flip_buttons(output_gallery):
     with gr.Row():
         rotate_button = gr.Button("🔄 Rotate", elem_id="rotate_button", variant='secondary')
         rotate_button.click(sequence_rotate_rgb_gallery, inputs=[output_gallery], outputs=[output_gallery])
@@ -1760,7 +1896,7 @@ def add_download_button(gallery, filename_prefix="output"):
 def make_output_images_section():
     gr.Markdown('### Output Images')
     output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=True, elem_id="ncut", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, interactive=False)
-
+    add_rotate_flip_buttons(output_gallery)
     return output_gallery
 
 def make_parameters_section(is_lisa=False, model_ratio=True):
@@ -1880,7 +2016,7 @@ with demo:
 
         with gr.Column(scale=5, min_width=200):
             output_gallery = make_output_images_section()
-            cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[2], object_fit="contain", height="auto", show_share_button=True, preview=True, interactive=False)
+            cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[2], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
             [
                 model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
                 affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
@@ -2024,15 +2160,15 @@
         with gr.Column(scale=5, min_width=200):
             gr.Markdown('### Output (Recursion #1)')
             l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=True, elem_id="ncut_l1", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
-
+            add_rotate_flip_buttons(l1_gallery)
         with gr.Column(scale=5, min_width=200):
             gr.Markdown('### Output (Recursion #2)')
             l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=True, elem_id="ncut_l2", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
-
+            add_rotate_flip_buttons(l2_gallery)
         with gr.Column(scale=5, min_width=200):
             gr.Markdown('### Output (Recursion #3)')
             l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=True, elem_id="ncut_l3", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
-
+            add_rotate_flip_buttons(l3_gallery)
     with gr.Row():
 
         with gr.Column(scale=5, min_width=200):
@@ -2089,7 +2225,7 @@
         with gr.Column(scale=5, min_width=200):
             gr.Markdown('### Output (Recursion #1)')
             l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=True, elem_id="ncut_l1", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
-
+            add_rotate_flip_buttons(l1_gallery)
             add_download_button(l1_gallery, "ncut_embed_recur1")
             l1_norm_gallery = gr.Gallery(value=[], label="Recursion #1 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
             add_download_button(l1_norm_gallery, "eig_norm_recur1")
@@ -2098,7 +2234,7 @@
         with gr.Column(scale=5, min_width=200):
             gr.Markdown('### Output (Recursion #2)')
             l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=True, elem_id="ncut_l2", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
-
+            add_rotate_flip_buttons(l2_gallery)
             add_download_button(l2_gallery, "ncut_embed_recur2")
             l2_norm_gallery = gr.Gallery(value=[], label="Recursion #2 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
             add_download_button(l2_norm_gallery, "eig_norm_recur2")
@@ -2107,7 +2243,7 @@
         with gr.Column(scale=5, min_width=200):
             gr.Markdown('### Output (Recursion #3)')
             l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=True, elem_id="ncut_l3", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
-
+            add_rotate_flip_buttons(l3_gallery)
             add_download_button(l3_gallery, "ncut_embed_recur3")
             l3_norm_gallery = gr.Gallery(value=[], label="Recursion #3 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
             add_download_button(l3_norm_gallery, "eig_norm_recur3")
@@ -2335,15 +2471,15 @@
             # add_output_images_buttons(l3_gallery)
             gr.Markdown('### Output (Recursion #1)')
             l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=True, elem_id="ncut_l1", columns=[100], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False, preview=True)
-
+            add_rotate_flip_buttons(l1_gallery)
             add_download_button(l1_gallery, "modelaligned_recur1")
             gr.Markdown('### Output (Recursion #2)')
             l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=True, elem_id="ncut_l2", columns=[100], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False, preview=True)
-
+            add_rotate_flip_buttons(l2_gallery)
             add_download_button(l2_gallery, "modelaligned_recur2")
             gr.Markdown('### Output (Recursion #3)')
             l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=True, elem_id="ncut_l3", columns=[100], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False, preview=True)
-
+            add_rotate_flip_buttons(l3_gallery)
             add_download_button(l3_gallery, "modelaligned_recur3")
 
         with gr.Row():
@@ -2413,7 +2549,7 @@
             gr.Markdown(f'### Output Images')
             output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=False, elem_id=f"ncut{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
             submit_button = gr.Button("🔴 RUN", elem_id=f"submit_button{i_model}", variant='primary')
-
+            add_rotate_flip_buttons(output_gallery)
             [
                 model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
                 affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
@@ -2479,13 +2615,52 @@
         buttons[-1].click(fn=lambda x: gr.update(visible=False), outputs=buttons[-1])
 
     with gr.Tab('Compare Models (Advanced)', visible=False) as tab_compare_models_advanced:
+
+        target_images = gr.State([])
+        input_images = gr.State([])
+        def add_mlp_fitting_buttons(output_gallery, mlp_gallery, target_images=target_images, input_images=input_images):
+            with gr.Row():
+                # mark_as_target_button = gr.Button("mark target", elem_id=f"mark_as_target_button_{output_gallery.elem_id}", variant='secondary')
+                # mark_as_input_button = gr.Button("mark input", elem_id=f"mark_as_input_button_{output_gallery.elem_id}", variant='secondary')
+                mark_as_target_button = gr.Button("🎯 Mark Target", elem_id=f"mark_as_target_button_{output_gallery.elem_id}", variant='secondary')
+                fit_to_target_button = gr.Button("🔴 [MLP] Fit", elem_id=f"fit_to_target_button_{output_gallery.elem_id}", variant='primary')
+            def mark_fn(images, text="target"):
+                if images is None:
+                    raise gr.Error("No images selected")
+                if len(images) == 0:
+                    raise gr.Error("No images selected")
+                num_images = len(images)
+                gr.Info(f"Marked {num_images} images as {text}")
+                images = [(Image.open(tup[0]), []) for tup in images]
+                return images
+            mark_as_target_button.click(partial(mark_fn, text="target"), inputs=[output_gallery], outputs=[target_images])
+            # mark_as_input_button.click(partial(mark_fn, text="input"), inputs=[output_gallery], outputs=[input_images])
+
+            with gr.Accordion("➡️ MLP Parameters", open=False):
+                num_layers_slider = gr.Slider(2, 10, step=1, label="Number of Layers", value=3, elem_id=f"num_layers_slider_{output_gallery.elem_id}")
+                width_slider = gr.Slider(128, 4096, step=128, label="Width", value=512, elem_id=f"width_slider_{output_gallery.elem_id}")
+                batch_size_slider = gr.Slider(32, 4096, step=32, label="Batch Size", value=128, elem_id=f"batch_size_slider_{output_gallery.elem_id}")
+                lr_slider = gr.Slider(1e-6, 1, step=1e-6, label="Learning Rate", value=3e-4, elem_id=f"lr_slider_{output_gallery.elem_id}")
+                fitting_steps_slider = gr.Slider(1000, 100000, step=1000, label="Fitting Steps", value=30000, elem_id=f"fitting_steps_slider_{output_gallery.elem_id}")
+                fps_sample_slider = gr.Slider(128, 50000, step=128, label="FPS Sample", value=10240, elem_id=f"fps_sample_slider_{output_gallery.elem_id}")
+                segmentation_loss_lambda_slider = gr.Slider(0, 100, step=0.01, label="Segmentation Preserving Loss Lambda", value=1, elem_id=f"segmentation_loss_lambda_slider_{output_gallery.elem_id}")
+
+            fit_to_target_button.click(
+                run_mlp_fit,
+                inputs=[output_gallery, target_images, num_layers_slider, width_slider, batch_size_slider, lr_slider, fitting_steps_slider, fps_sample_slider, segmentation_loss_lambda_slider],
+                outputs=[mlp_gallery],
+            )
+
         def add_one_model(i_model=1):
             with gr.Column(scale=5, min_width=200) as col:
                 gr.Markdown(f'### Output Images')
                 output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=True, elem_id=f"ncut{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
                 submit_button = gr.Button("🔴 RUN", elem_id=f"submit_button{i_model}", variant='primary')
-
+                add_rotate_flip_buttons(output_gallery)
                 add_download_button(output_gallery, f"ncut_embed")
+                mlp_gallery = gr.Gallery(format='png', value=[], label="MLP color align", show_label=True, elem_id=f"mlp{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
+                add_mlp_fitting_buttons(output_gallery, mlp_gallery)
+                add_download_button(mlp_gallery, f"mlp_color_align")
                 norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id=f"eig_norm{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
                 add_download_button(norm_gallery, f"eig_norm")
                 cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id=f"clusters{i_model}", columns=[2], rows=[4], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
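Note: the new Mark Target / [MLP] Fit pair works by snapshotting a gallery into a per-session gr.State (target_images) and later feeding gallery plus state into run_mlp_fit. A minimal self-contained sketch of that wiring with hypothetical names, assuming a recent gradio:

import gradio as gr

def mark_fn(images):
    if not images:
        raise gr.Error("No images selected")
    gr.Info(f"Marked {len(images)} images as target")
    return images  # snapshot lands in the State component

with gr.Blocks() as demo:
    gallery = gr.Gallery(value=[], label="NCUT Embedding")
    target_images = gr.State([])  # per-session storage, not a visible component
    mark_btn = gr.Button("🎯 Mark Target")
    mark_btn.click(mark_fn, inputs=[gallery], outputs=[target_images])

if __name__ == "__main__":
    demo.launch()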
@@ -2515,8 +2690,12 @@
                 outputs=[output_gallery, cluster_gallery, norm_gallery, logging_text]
             )
 
-
+            output_gallery.change(lambda x: gr.update(value=x), inputs=[output_gallery], outputs=[mlp_gallery])
+
+            return output_gallery
 
+        galleries = []
+
         with gr.Row():
             with gr.Column(scale=5, min_width=200):
                 input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section(allow_download=True)
@@ -2524,7 +2703,8 @@
 
 
         for i in range(3):
-            add_one_model()
+            g = add_one_model()
+            galleries.append(g)
 
         # Create rows and buttons in a loop
         rows = []
@@ -2537,7 +2717,8 @@
             with row:
                 for j in range(4):
                     with gr.Column(scale=5, min_width=200):
-                        add_one_model()
+                        g = add_one_model()
+                        galleries.append(g)
 
             button = gr.Button("➕ Add Compare", elem_id=f"add_button_{i}", visible=False if i > 0 else True, scale=3)
             buttons.append(button)
@@ -2553,7 +2734,11 @@
         # Last button only reveals the last row and hides itself
         buttons[-1].click(fn=lambda x: gr.update(visible=True), outputs=rows[-1])
         buttons[-1].click(fn=lambda x: gr.update(visible=False), outputs=buttons[-1])
-
+
+
+        # add MLP fitting buttons
+
+
 
     with gr.Tab('📄About'):
         with gr.Column():