Spaces:
Running
on
Zero
Running
on
Zero
add advanced tab for recursive
Browse files
app.py
CHANGED
@@ -325,9 +325,9 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
|
|
325 |
mask = mask[mask_sort_idx[:3]]
|
326 |
sort_values.append(mask.mean().item())
|
327 |
# fps_heatmaps[idx.item()] = heatmap.cpu()
|
328 |
-
fps_heatmaps[idx.item()] = heatmap[mask_sort_idx[:
|
329 |
top3_image_idx[idx.item()] = mask_sort_idx[:3]
|
330 |
-
top10_image_idx[idx.item()] = mask_sort_idx[:
|
331 |
# do the sorting
|
332 |
_sort_idx = torch.tensor(sort_values).argsort(descending=True)
|
333 |
fps_idx = fps_idx[_sort_idx]
|
@@ -427,6 +427,7 @@ def ncut_run(
|
|
427 |
alignedcut_eig_norm_plot=False,
|
428 |
**kwargs,
|
429 |
):
|
|
|
430 |
progress = gr.Progress()
|
431 |
progress(0.2, desc="Feature Extraction")
|
432 |
|
@@ -483,6 +484,7 @@ def ncut_run(
|
|
483 |
|
484 |
if recursion:
|
485 |
rgbs = []
|
|
|
486 |
recursion_gammas = [recursion_l1_gamma, recursion_l2_gamma, recursion_l3_gamma]
|
487 |
inp = features
|
488 |
progress_start = 0.4
|
@@ -509,6 +511,7 @@ def ncut_run(
|
|
509 |
progess_start=progress_start,
|
510 |
)
|
511 |
logging_str += _logging_str
|
|
|
512 |
|
513 |
|
514 |
if "AlignedThreeModelAttnNodes" == model_name:
|
@@ -528,8 +531,42 @@ def ncut_run(
|
|
528 |
inp = eigvecs.reshape(*features.shape[:-1], -1)
|
529 |
if recursion_metric == "cosine":
|
530 |
inp = F.normalize(inp, dim=-1)
|
531 |
-
|
532 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
533 |
if old_school_ncut: # individual images
|
534 |
logging_str += "Running NCut for each image independently\n"
|
535 |
rgb = []
|
@@ -643,7 +680,7 @@ def ncut_run(
|
|
643 |
norm_images = to_pil_images(norm_images)
|
644 |
logging_str += "Eigenvector Magnitude\n"
|
645 |
logging_str += f"Min: {vmin:.2f}, Max: {vmax:.2f}\n"
|
646 |
-
gr.Info(f"Eigenvector Magnitude:</br> Min: {vmin:.2f}, Max: {vmax:.2f}", duration=
|
647 |
|
648 |
return to_pil_images(rgb), cluster_images, norm_images, logging_str
|
649 |
|
@@ -651,26 +688,26 @@ def ncut_run(
|
|
651 |
|
652 |
def _ncut_run(*args, **kwargs):
|
653 |
n_ret = kwargs.pop("n_ret", 1)
|
654 |
-
|
655 |
-
|
656 |
-
|
657 |
|
658 |
-
|
659 |
|
660 |
-
|
661 |
-
|
662 |
|
663 |
-
|
664 |
-
|
665 |
-
|
666 |
-
|
667 |
-
|
668 |
-
|
669 |
-
|
670 |
-
|
671 |
-
ret = ncut_run(*args, **kwargs)
|
672 |
-
ret = list(ret)[:n_ret] + [ret[-1]]
|
673 |
-
return ret
|
674 |
|
675 |
if USE_HUGGINGFACE_ZEROGPU:
|
676 |
@spaces.GPU(duration=30)
|
@@ -1407,7 +1444,7 @@ with demo:
|
|
1407 |
scroll_to_output=True,
|
1408 |
)
|
1409 |
|
1410 |
-
with gr.Tab('AlignedCut (
|
1411 |
|
1412 |
with gr.Row():
|
1413 |
with gr.Column(scale=5, min_width=200):
|
@@ -1582,6 +1619,78 @@ with demo:
|
|
1582 |
outputs=[l1_gallery, l2_gallery, l3_gallery, logging_text],
|
1583 |
api_name="API_RecursiveCut"
|
1584 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1585 |
|
1586 |
|
1587 |
with gr.Tab('Video'):
|
@@ -1741,7 +1850,7 @@ with demo:
|
|
1741 |
outputs=[output_gallery, logging_text],
|
1742 |
)
|
1743 |
|
1744 |
-
with gr.Tab('Model Aligned (
|
1745 |
gr.Markdown('This page reproduce the results from the paper [AlignedCut](https://arxiv.org/abs/2406.18344)')
|
1746 |
gr.Markdown('---')
|
1747 |
gr.Markdown('**Features are aligned across models and layers.** A linear alignment transform is trained for each model/layer, learning signal comes from 1) fMRI brain activation and 2) segmentation preserving eigen-constraints.')
|
@@ -1965,8 +2074,9 @@ with demo:
|
|
1965 |
return gr.update()
|
1966 |
|
1967 |
hidden_button.change(update_smile, [n_smiles], [n_smiles, hidden_button])
|
1968 |
-
hidden_button.change(unlock_tabs_with_info, n_smiles,
|
1969 |
-
hidden_button.change(unlock_tabs, n_smiles,
|
|
|
1970 |
|
1971 |
with gr.Row():
|
1972 |
with gr.Column():
|
|
|
325 |
mask = mask[mask_sort_idx[:3]]
|
326 |
sort_values.append(mask.mean().item())
|
327 |
# fps_heatmaps[idx.item()] = heatmap.cpu()
|
328 |
+
fps_heatmaps[idx.item()] = heatmap[mask_sort_idx[:6]].cpu()
|
329 |
top3_image_idx[idx.item()] = mask_sort_idx[:3]
|
330 |
+
top10_image_idx[idx.item()] = mask_sort_idx[:6]
|
331 |
# do the sorting
|
332 |
_sort_idx = torch.tensor(sort_values).argsort(descending=True)
|
333 |
fps_idx = fps_idx[_sort_idx]
|
|
|
427 |
alignedcut_eig_norm_plot=False,
|
428 |
**kwargs,
|
429 |
):
|
430 |
+
advanced = kwargs.get("advanced", False)
|
431 |
progress = gr.Progress()
|
432 |
progress(0.2, desc="Feature Extraction")
|
433 |
|
|
|
484 |
|
485 |
if recursion:
|
486 |
rgbs = []
|
487 |
+
all_eigvecs = []
|
488 |
recursion_gammas = [recursion_l1_gamma, recursion_l2_gamma, recursion_l3_gamma]
|
489 |
inp = features
|
490 |
progress_start = 0.4
|
|
|
511 |
progess_start=progress_start,
|
512 |
)
|
513 |
logging_str += _logging_str
|
514 |
+
all_eigvecs.append(eigvecs.cpu().clone())
|
515 |
|
516 |
|
517 |
if "AlignedThreeModelAttnNodes" == model_name:
|
|
|
531 |
inp = eigvecs.reshape(*features.shape[:-1], -1)
|
532 |
if recursion_metric == "cosine":
|
533 |
inp = F.normalize(inp, dim=-1)
|
534 |
+
|
535 |
+
if not advanced:
|
536 |
+
return rgbs[0], rgbs[1], rgbs[2], logging_str
|
537 |
+
if advanced:
|
538 |
+
cluster_plots, norm_plots = [], []
|
539 |
+
for i in range(3):
|
540 |
+
eigvecs = all_eigvecs[i]
|
541 |
+
# add norm plot, cluster plot
|
542 |
+
start = time.time()
|
543 |
+
progress_start = 0.6
|
544 |
+
progress(progress_start, desc=f"Plotting Clusters Recursion #{i+1}")
|
545 |
+
h, w = features.shape[1], features.shape[2]
|
546 |
+
if torch.cuda.is_available():
|
547 |
+
images = images.cuda()
|
548 |
+
_images = reverse_transform_image(images, stablediffusion="stable" in model_name.lower())
|
549 |
+
cluster_images, eig_magnitude = make_cluster_plot(eigvecs, _images, h=h, w=w, progess_start=progress_start, advanced=advanced)
|
550 |
+
logging_str += f"Recursion #{i+1} plot time: {time.time() - start:.2f}s\n"
|
551 |
+
|
552 |
+
norm_images = []
|
553 |
+
vmin, vmax = eig_magnitude.min(), eig_magnitude.max()
|
554 |
+
eig_magnitude = (eig_magnitude - vmin) / (vmax - vmin)
|
555 |
+
eig_magnitude = eig_magnitude.cpu().numpy()
|
556 |
+
colormap = matplotlib.colormaps['Reds']
|
557 |
+
for i_image in range(eig_magnitude.shape[0]):
|
558 |
+
norm_image = colormap(eig_magnitude[i_image])
|
559 |
+
norm_images.append(torch.tensor(norm_image[..., :3]))
|
560 |
+
norm_images = to_pil_images(norm_images)
|
561 |
+
logging_str += f"Recursion #{i+1} Eigenvector Magnitude: [{vmin:.2f}, {vmax:.2f}]\n"
|
562 |
+
gr.Info(f"Recursion #{i+1} Eigenvector Magnitude:</br> Min: {vmin:.2f}, Max: {vmax:.2f}", duration=10)
|
563 |
+
|
564 |
+
cluster_plots.append(cluster_images)
|
565 |
+
norm_plots.append(norm_images)
|
566 |
+
|
567 |
+
return *rgbs, *norm_plots, *cluster_plots, logging_str
|
568 |
+
|
569 |
+
|
570 |
if old_school_ncut: # individual images
|
571 |
logging_str += "Running NCut for each image independently\n"
|
572 |
rgb = []
|
|
|
680 |
norm_images = to_pil_images(norm_images)
|
681 |
logging_str += "Eigenvector Magnitude\n"
|
682 |
logging_str += f"Min: {vmin:.2f}, Max: {vmax:.2f}\n"
|
683 |
+
gr.Info(f"Eigenvector Magnitude:</br> Min: {vmin:.2f}, Max: {vmax:.2f}", duration=10)
|
684 |
|
685 |
return to_pil_images(rgb), cluster_images, norm_images, logging_str
|
686 |
|
|
|
688 |
|
689 |
def _ncut_run(*args, **kwargs):
    """Run ``ncut_run`` defensively and trim its outputs for the UI.

    The ``n_ret`` keyword (default 1) is popped from ``kwargs`` before the
    remaining arguments are forwarded unchanged to ``ncut_run``.

    Returns:
        A list with the first ``n_ret`` items of ``ncut_run``'s result
        followed by its last item (the logging string). If ``ncut_run``
        raises, a list of ``n_ret`` ``None`` placeholders followed by an
        ``"Error: ..."`` message is returned instead, so the number of
        outputs stays the same on both paths.
    """
    n_ret = kwargs.pop("n_ret", 1)
    try:
        # Free cached GPU memory left over from a previous run before starting.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        ret = ncut_run(*args, **kwargs)
        return list(ret)[:n_ret] + [ret[-1]]
    except Exception as e:
        # gr.Error is an exception class: merely constructing it shows nothing
        # to the user. We still want to return placeholder outputs rather than
        # abort the event, so surface the message with the call-style alert.
        gr.Warning(str(e))
        return [None] * n_ret + ["Error: " + str(e)]
    finally:
        # Always release cached GPU memory, on success and on failure alike.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
|
711 |
|
712 |
if USE_HUGGINGFACE_ZEROGPU:
|
713 |
@spaces.GPU(duration=30)
|
|
|
1444 |
scroll_to_output=True,
|
1445 |
)
|
1446 |
|
1447 |
+
with gr.Tab('AlignedCut (Advanced)', visible=False) as tab_alignedcut_advanced:
|
1448 |
|
1449 |
with gr.Row():
|
1450 |
with gr.Column(scale=5, min_width=200):
|
|
|
1619 |
outputs=[l1_gallery, l2_gallery, l3_gallery, logging_text],
|
1620 |
api_name="API_RecursiveCut"
|
1621 |
)
|
1622 |
+
|
1623 |
+
with gr.Tab('Recursive Cut (Advanced)', visible=False) as tab_recursivecut_advanced:
|
1624 |
+
|
1625 |
+
with gr.Row():
|
1626 |
+
with gr.Column(scale=5, min_width=200):
|
1627 |
+
gr.Markdown('### Output (Recursion #1)')
|
1628 |
+
l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=True, elem_id="ncut_l1", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
|
1629 |
+
add_output_images_buttons(l1_gallery)
|
1630 |
+
l1_norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
|
1631 |
+
l1_cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height=600, show_share_button=True, preview=True, interactive=False)
|
1632 |
+
with gr.Column(scale=5, min_width=200):
|
1633 |
+
gr.Markdown('### Output (Recursion #2)')
|
1634 |
+
l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=True, elem_id="ncut_l2", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
|
1635 |
+
add_output_images_buttons(l2_gallery)
|
1636 |
+
l2_norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
|
1637 |
+
l2_cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height=600, show_share_button=True, preview=True, interactive=False)
|
1638 |
+
with gr.Column(scale=5, min_width=200):
|
1639 |
+
gr.Markdown('### Output (Recursion #3)')
|
1640 |
+
l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=True, elem_id="ncut_l3", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
|
1641 |
+
add_output_images_buttons(l3_gallery)
|
1642 |
+
l3_norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
|
1643 |
+
l3_cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height=600, show_share_button=True, preview=True, interactive=False)
|
1644 |
+
with gr.Row():
|
1645 |
+
with gr.Column(scale=5, min_width=200):
|
1646 |
+
input_gallery, submit_button, clear_images_button = make_input_images_section()
|
1647 |
+
dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_dataset_images_section(advanced=True)
|
1648 |
+
num_images_slider.value = 100
|
1649 |
+
clear_images_button.visible = False
|
1650 |
+
logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
|
1651 |
+
|
1652 |
+
with gr.Column(scale=5, min_width=200):
|
1653 |
+
with gr.Accordion("➡️ Recursion config", open=True):
|
1654 |
+
l1_num_eig_slider = gr.Slider(1, 1000, step=1, label="Recursion #1: N eigenvectors", value=100, elem_id="l1_num_eig")
|
1655 |
+
l2_num_eig_slider = gr.Slider(1, 1000, step=1, label="Recursion #2: N eigenvectors", value=50, elem_id="l2_num_eig")
|
1656 |
+
l3_num_eig_slider = gr.Slider(1, 1000, step=1, label="Recursion #3: N eigenvectors", value=50, elem_id="l3_num_eig")
|
1657 |
+
metric_dropdown = gr.Dropdown(["euclidean", "cosine"], label="Recursion distance metric", value="cosine", elem_id="recursion_metric")
|
1658 |
+
l1_affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="Recursion #1: Affinity focal gamma", value=0.5, elem_id="recursion_l1_gamma")
|
1659 |
+
l2_affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="Recursion #2: Affinity focal gamma", value=0.5, elem_id="recursion_l2_gamma")
|
1660 |
+
l3_affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="Recursion #3: Affinity focal gamma", value=0.5, elem_id="recursion_l3_gamma")
|
1661 |
+
[
|
1662 |
+
model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
|
1663 |
+
affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
|
1664 |
+
embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
1665 |
+
perplexity_slider, n_neighbors_slider, min_dist_slider,
|
1666 |
+
sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt
|
1667 |
+
] = make_parameters_section()
|
1668 |
+
num_eig_slider.visible = False
|
1669 |
+
affinity_focal_gamma_slider.visible = False
|
1670 |
+
true_placeholder = gr.Checkbox(label="True placeholder", value=True, elem_id="true_placeholder")
|
1671 |
+
true_placeholder.visible = False
|
1672 |
+
false_placeholder = gr.Checkbox(label="False placeholder", value=False, elem_id="false_placeholder")
|
1673 |
+
false_placeholder.visible = False
|
1674 |
+
number_placeholder = gr.Number(0, label="Number placeholder", elem_id="number_placeholder")
|
1675 |
+
number_placeholder.visible = False
|
1676 |
+
clear_images_button.click(lambda x: ([],), outputs=[input_gallery])
|
1677 |
+
no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
|
1678 |
+
|
1679 |
+
submit_button.click(
|
1680 |
+
partial(run_fn, n_ret=9, advanced=True),
|
1681 |
+
inputs=[
|
1682 |
+
input_gallery, model_dropdown, layer_slider, l1_num_eig_slider, node_type_dropdown,
|
1683 |
+
positive_prompt, negative_prompt,
|
1684 |
+
false_placeholder, no_prompt, no_prompt, no_prompt,
|
1685 |
+
affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
|
1686 |
+
embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
1687 |
+
perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown,
|
1688 |
+
false_placeholder, number_placeholder, true_placeholder,
|
1689 |
+
l2_num_eig_slider, l3_num_eig_slider, metric_dropdown,
|
1690 |
+
l1_affinity_focal_gamma_slider, l2_affinity_focal_gamma_slider, l3_affinity_focal_gamma_slider
|
1691 |
+
],
|
1692 |
+
outputs=[l1_gallery, l2_gallery, l3_gallery, l1_norm_gallery, l2_norm_gallery, l3_norm_gallery, l1_cluster_gallery, l2_cluster_gallery, l3_cluster_gallery, logging_text],
|
1693 |
+
)
|
1694 |
|
1695 |
|
1696 |
with gr.Tab('Video'):
|
|
|
1850 |
outputs=[output_gallery, logging_text],
|
1851 |
)
|
1852 |
|
1853 |
+
with gr.Tab('Model Aligned (Advanced)', visible=False) as tab_model_aligned_advanced:
|
1854 |
gr.Markdown('This page reproduce the results from the paper [AlignedCut](https://arxiv.org/abs/2406.18344)')
|
1855 |
gr.Markdown('---')
|
1856 |
gr.Markdown('**Features are aligned across models and layers.** A linear alignment transform is trained for each model/layer, learning signal comes from 1) fMRI brain activation and 2) segmentation preserving eigen-constraints.')
|
|
|
2074 |
return gr.update()
|
2075 |
|
2076 |
hidden_button.change(update_smile, [n_smiles], [n_smiles, hidden_button])
|
2077 |
+
hidden_button.change(unlock_tabs_with_info, n_smiles, tab_alignedcut_advanced)
|
2078 |
+
hidden_button.change(unlock_tabs, n_smiles, tab_model_aligned_advanced)
|
2079 |
+
hidden_button.change(unlock_tabs, n_smiles, tab_recursivecut_advanced)
|
2080 |
|
2081 |
with gr.Row():
|
2082 |
with gr.Column():
|