ncut-pytorch / app.py
huzey's picture
improve UI
c78a50f
raw
history blame
43.4 kB
# Author: Huzheng Yang
# %%
# Use Hugging Face ZeroGPU (dynamic GPU allocation) when the `spaces` package
# is importable; otherwise fall back to the local-machine settings.
try:
    import spaces
except ImportError:
    HUGGINGFACE_SPACE = False  # run on local machine
    BATCH_SIZE = 1
else:
    HUGGINGFACE_SPACE = True  # huggingface ZeroGPU, dynamic GPU allocation
    BATCH_SIZE = 4
import os
import gradio as gr
import torch
import torch.nn.functional as F
from PIL import Image
import numpy as np
import time
import gradio as gr
from backbone import extract_features, download_all_models, get_model
from backbone import MODEL_DICT, LAYER_DICT, RES_DICT
from ncut_pytorch import NCUT, eigenvector_to_rgb
# Datasets offered in the "Load from dataset" panel, as (name, num_classes).
# num_classes is None for datasets that are used without class labels.
DATASET_TUPS = [
    ('UCSC-VLAA/Recap-COCO-30K', None),
    ('nateraw/pascal-voc-2012', None),
    ('johnowhitaker/imagenette2-320', 10),
    ('jainr3/diffusiondb-pixelart', None),
    ('nielsr/CelebA-faces', None),
    ('JapanDegitalMaterial/Places_in_Japan', None),
    ('Borismile/Anime-dataset', None),
    ('Multimodal-Fatima/CUB_train', 200),
    ('mrm8488/ImageNet1K-val', 1000),
]
# Parallel lists, index-aligned with DATASET_TUPS.
DATASET_NAMES = [name for name, _ in DATASET_TUPS]
DATASET_CLASSES = [num_classes for _, num_classes in DATASET_TUPS]
from datasets import load_dataset
def download_all_datasets():
    """Load every dataset in DATASET_NAMES once, printing progress.

    A failure for one dataset is printed and does not stop the remaining
    downloads.
    """
    for dataset_name in DATASET_NAMES:
        print(f"Downloading {dataset_name}")
        try:
            load_dataset(dataset_name, trust_remote_code=True)
        except Exception as err:
            print(f"Error downloading {dataset_name}: {err}")
def compute_ncut(
    features,
    num_eig=100,
    num_sample_ncut=10000,
    affinity_focal_gamma=0.3,
    knn_ncut=10,
    knn_tsne=10,
    embedding_method="UMAP",
    num_sample_tsne=300,
    perplexity=150,
    n_neighbors=150,
    min_dist=0.1,
    sampling_method="fps",
    metric="cosine",
):
    """Run NCUT on `features` and turn the eigenvectors into RGB colours.

    The first three axes of `features` are treated as the node grid and the
    last axis as the feature dimension: the input is flattened to
    (-1, features.shape[-1]) for NCUT and the colours are reshaped back to
    features.shape[:3] + (3,).

    Returns:
        (rgb, logging_str, eigvecs) where logging_str holds the timing lines
        and any num_eig adjustment message.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    log_parts = []

    num_nodes = np.prod(features.shape[:3])
    if num_eig > num_nodes / 2:
        # raise gr.Error("Number of eigenvectors should be less than half the number of nodes.")
        clamp_message = (
            "Number of eigenvectors should be less than half the number of nodes.\n"
            f"Setting num_eig to {num_nodes // 2 - 1}."
        )
        gr.Warning(clamp_message)
        num_eig = num_nodes // 2 - 1
        log_parts.append(clamp_message + "\n")

    ncut_started = time.time()
    solver = NCUT(
        num_eig=num_eig,
        num_sample=num_sample_ncut,
        device=device,
        affinity_focal_gamma=affinity_focal_gamma,
        knn=knn_ncut,
        sample_method=sampling_method,
        distance=metric,
        normalize_features=False,
    )
    flat_features = features.reshape(-1, features.shape[-1])
    eigvecs, eigvals = solver.fit_transform(flat_features)
    log_parts.append(f"NCUT time: {time.time() - ncut_started:.2f}s\n")

    embed_started = time.time()
    _, rgb = eigenvector_to_rgb(
        eigvecs,
        method=embedding_method,
        num_sample=num_sample_tsne,
        perplexity=perplexity,
        n_neighbors=n_neighbors,
        min_distance=min_dist,
        knn=knn_tsne,
        device=device,
    )
    log_parts.append(f"{embedding_method} time: {time.time() - embed_started:.2f}s\n")

    rgb = rgb.reshape(features.shape[:3] + (3,))
    return rgb, "".join(log_parts), eigvecs
def dont_use_too_much_green(image_rgb):
    """Reorder the colour channels so the central region is red-leading.

    The per-channel mean is taken over the central 30%..70% window of axes 1
    and 2 (across the whole first axis); the last axis is then permuted so
    channels appear in descending order of that mean.
    """
    size1, size2 = image_rgb.shape[1], image_rgb.shape[2]
    lo1, hi1 = int(size1 * 0.3), int(size1 * 0.7)
    lo2, hi2 = int(size2 * 0.3), int(size2 * 0.7)
    centre = image_rgb[:, lo1:hi1, lo2:hi2]
    channel_order = centre.mean((0, 1, 2)).argsort(descending=True)
    return image_rgb[:, :, :, channel_order]
def to_pil_images(images):
    """Convert each image to a 256x256 PIL image.

    Every item is scaled by 255, moved to CPU, cast to uint8 and resized with
    nearest-neighbour resampling.
    """
    pil_images = []
    for image in images:
        pixels = (image * 255).cpu().numpy().astype(np.uint8)
        pil_image = Image.fromarray(pixels)
        pil_images.append(pil_image.resize((256, 256), Image.Resampling.NEAREST))
    return pil_images
def pil_images_to_video(images, output_path, fps=5):
    """Encode a sequence of PIL images as an mp4v video.

    Args:
        images: non-empty sequence of RGB images; the frame size is taken
            from the first one.
        output_path: destination file path.
        fps: frames per second of the written video.

    Returns:
        output_path, unchanged.

    Raises:
        ValueError: if `images` is empty (there is no frame size to use).
    """
    if not images:
        raise ValueError("Cannot write a video from an empty list of images.")

    import cv2

    # from pil images to numpy
    frames = [np.array(image) for image in images]
    height, width, _ = frames[0].shape
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    try:
        for frame in frames:
            # OpenCV expects BGR channel order
            out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        # release the writer even when a frame fails to encode
        out.release()
    return output_path
# save up to 100 videos in disk
class VideoCache:
    """Bounded registry of video files on disk.

    When the registry is full, the oldest registered video is dropped and its
    file is deleted (best effort) before the new path is recorded.
    """

    def __init__(self, max_videos=100):
        self.max_videos = max_videos
        # dicts keep insertion order, so the first key is the oldest video
        self.videos = {}

    def add_video(self, video_path):
        """Register `video_path`, evicting the oldest entries if needed."""
        if video_path in self.videos:
            # already tracked; evicting here could delete the file we keep
            return
        while self.videos and len(self.videos) >= self.max_videos:
            # dict.popitem() would remove the *newest* entry, so pick the
            # first key explicitly to get first-in-first-out eviction
            oldest_path = next(iter(self.videos))
            del self.videos[oldest_path]
            try:
                os.remove(oldest_path)
            except OSError:
                # file already gone or not removable; nothing more to do
                pass
        self.videos[video_path] = video_path

    def get_video(self, video_path):
        """Return the registered path, or None if it is not tracked."""
        return self.videos.get(video_path, None)


video_cache = VideoCache()
def get_random_path(length=10):
    """Return '/tmp/<name>.mp4' where <name> is `length` random characters.

    The name is drawn from lowercase ASCII letters and digits.
    """
    import random
    import string

    alphabet = string.ascii_lowercase + string.digits
    stem = ''.join(random.choices(alphabet, k=length))
    return f'/tmp/{stem}.mp4'
# Example inputs are numbered 0, 1, 2, 3, 5 (there is no image_4 in the list).
_EXAMPLE_INPUT_IDS = (0, 1, 2, 3, 5)
default_images = [f'./images/image_{i}.jpg' for i in _EXAMPLE_INPUT_IDS]
# Example outputs: image-1..5 and, for the independent run, image-6..10.
default_outputs = [f'./images/image-{i}.webp' for i in range(1, 6)]
default_outputs_independent = [f'./images/image-{i}.webp' for i in range(6, 11)]
downscaled_images = [f'./images/image_{i}_small.jpg' for i in _EXAMPLE_INPUT_IDS]
downscaled_outputs = default_outputs
# Items shown in the example gallery: three inputs followed by three outputs.
example_items = downscaled_images[:3] + downscaled_outputs[:3]
def ncut_run(
    model,
    images,
    model_name="SAM(sam_vit_b)",
    layer=-1,
    num_eig=100,
    node_type="block",
    affinity_focal_gamma=0.3,
    num_sample_ncut=10000,
    knn_ncut=10,
    embedding_method="UMAP",
    num_sample_tsne=1000,
    knn_tsne=10,
    perplexity=500,
    n_neighbors=500,
    min_dist=0.1,
    sampling_method="fps",
    old_school_ncut=False,
    recursion=False,
    recursion_l2_n_eigs=50,
    recursion_l3_n_eigs=20,
    recursion_metric="euclidean",
    video_output=False,
):
    """Extract backbone features and colour them with NCUT.

    Three modes, chosen by the flags:
      * recursion: NCUT is run three times, each round taking the previous
        round's eigenvectors as input; returns (rgbs_1, rgbs_2, rgbs_3, log).
      * old_school_ncut: NCUT is solved separately for every image.
      * default: NCUT is solved jointly across all images.

    For the last two modes the result is (list of PIL images, log), or
    (video_path, log) when `video_output` is set.

    `node_type` may carry a description after a colon (as in the UI dropdown);
    only the part before the colon is used. `layer` is 1-based here and is
    passed to the backbone as layer - 1.
    """
    logging_str = ""
    if perplexity >= num_sample_tsne or n_neighbors >= num_sample_tsne:
        # raise gr.Error("Perplexity must be less than the number of samples for t-SNE.")
        gr.Warning("Perplexity/n_neighbors must be less than the number of samples.\n" f"Setting Perplexity to {num_sample_tsne-1}.")
        logging_str += f"Perplexity/n_neighbors must be less than the number of samples.\n" f"Setting Perplexity to {num_sample_tsne-1}.\n"
        perplexity = num_sample_tsne - 1
        n_neighbors = num_sample_tsne - 1

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    node_type = node_type.split(":")[0].strip()

    start = time.time()
    features = extract_features(
        images, model, model_name=model_name, node_type=node_type, layer=layer-1, batch_size=BATCH_SIZE
    )
    # print(f"Feature extraction time (gpu): {time.time() - start:.2f}s")
    logging_str += f"Backbone time: {time.time() - start:.2f}s\n"

    # Options shared by every compute_ncut call below; only the input,
    # num_eig and (for recursion) the metric differ between calls.
    shared_kwargs = dict(
        num_sample_ncut=num_sample_ncut,
        affinity_focal_gamma=affinity_focal_gamma,
        knn_ncut=knn_ncut,
        knn_tsne=knn_tsne,
        num_sample_tsne=num_sample_tsne,
        embedding_method=embedding_method,
        perplexity=perplexity,
        n_neighbors=n_neighbors,
        min_dist=min_dist,
        sampling_method=sampling_method,
    )

    if recursion:
        rgbs = []
        inp = features
        for i, n_eigs in enumerate([num_eig, recursion_l2_n_eigs, recursion_l3_n_eigs]):
            logging_str += f"Recursion #{i+1}\n"
            rgb, _logging_str, eigvecs = compute_ncut(
                inp,
                num_eig=n_eigs,
                # the first round works on backbone features, later rounds
                # on eigenvectors with the user-selected metric
                metric="cosine" if i == 0 else recursion_metric,
                **shared_kwargs,
            )
            logging_str += _logging_str
            rgb = dont_use_too_much_green(rgb)
            rgbs.append(to_pil_images(rgb))
            # eigenvectors of this round become the features of the next one
            inp = eigvecs.reshape(*features.shape[:3], -1)
            if recursion_metric == "cosine":
                inp = F.normalize(inp, dim=-1)
        return rgbs[0], rgbs[1], rgbs[2], logging_str

    if old_school_ncut:  # individual images
        logging_str += "Running NCut for each image independently\n"
        rgb = []
        for i_image in range(features.shape[0]):
            feature = features[i_image]
            _rgb, _logging_str, _ = compute_ncut(
                feature[None], num_eig=num_eig, **shared_kwargs
            )
            logging_str += _logging_str
            rgb.append(_rgb[0])
        # NOTE: `rgb` is a plain list here. dont_use_too_much_green() reads
        # `.shape`, so the channel reordering is only applied in the joint
        # branch below.
    else:  # joint across all images
        rgb, _logging_str, _ = compute_ncut(
            features, num_eig=num_eig, **shared_kwargs
        )
        logging_str += _logging_str
        rgb = dont_use_too_much_green(rgb)

    if video_output:
        video_path = get_random_path()
        video_cache.add_video(video_path)
        pil_images_to_video(to_pil_images(rgb), video_path)
        return video_path, logging_str
    return to_pil_images(rgb), logging_str
def _ncut_run(*args, **kwargs):
    """Call ncut_run, always freeing the CUDA cache and never raising.

    On failure the exception text is shown as a UI warning and returned in
    the log slot, together with empty placeholders whose count matches what
    ncut_run would have returned for the requested mode:
      * recursion:     ([], [], [], message)
      * video_output:  (None, message)
      * otherwise:     ([], message)
    """
    try:
        return ncut_run(*args, **kwargs)
    except Exception as e:
        # A gr.Error instance that is created but not raised shows nothing.
        # This wrapper must still return values for the outputs, so surface
        # the problem with gr.Warning, which is displayed when called.
        gr.Warning(str(e))
        message = "Error: " + str(e)
        if kwargs.get("recursion"):
            return [], [], [], message
        if kwargs.get("video_output"):
            return None, message
        return [], message
    finally:
        # runs on both the success and the failure path
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# Four entry points that all forward to _ncut_run. They differ only in the
# GPU `duration` requested through spaces.GPU when running on a Hugging Face
# Space; on a local machine they are plain pass-throughs.
def quick_run(*args, **kwargs):
    return _ncut_run(*args, **kwargs)


def long_run(*args, **kwargs):
    return _ncut_run(*args, **kwargs)


def longer_run(*args, **kwargs):
    return _ncut_run(*args, **kwargs)


def super_duper_long_run(*args, **kwargs):
    return _ncut_run(*args, **kwargs)


if HUGGINGFACE_SPACE:
    # Same effect as decorating each definition with @spaces.GPU(duration=...)
    quick_run = spaces.GPU(duration=20)(quick_run)
    long_run = spaces.GPU(duration=30)(long_run)
    longer_run = spaces.GPU(duration=60)(longer_run)
    super_duper_long_run = spaces.GPU(duration=120)(super_duper_long_run)
def extract_video_frames(video_path, max_frames=100):
    """Read up to `max_frames` frames from a video file.

    When the video is longer than `max_frames`, evenly spaced frames are
    taken and a UI warning is shown. Returns a list of (PIL image, "")
    tuples, one per selected frame.
    """
    from decord import VideoReader

    reader = VideoReader(video_path)
    total_frames = len(reader)
    if total_frames <= max_frames:
        selected = list(range(total_frames))
    else:
        gr.Warning(f"Video has {total_frames} frames. Only using {max_frames} frames. Evenly spaced.")
        selected = np.linspace(0, total_frames - 1, max_frames, dtype=int).tolist()
    batch = reader.get_batch(selected).asnumpy()
    # return as list of PIL images (paired with an empty caption)
    return [(Image.fromarray(frame), "") for frame in batch]
def transform_image(image, resolution=(1024, 1024)):
    """Turn a PIL image into a normalised float tensor.

    The image is converted to RGB, resized to `resolution` with LANCZOS
    resampling, moved to channel-first layout, scaled to [0, 1] and then
    normalised with mean 0.5 and std 0.5.
    """
    resized = image.convert('RGB').resize(resolution, Image.LANCZOS)
    channel_first = np.array(resized).transpose(2, 0, 1)
    scaled = torch.tensor(channel_first).float() / 255
    return (scaled - 0.5) / 0.5
def run_fn(
    images,
    model_name="SAM(sam_vit_b)",
    layer=-1,
    num_eig=100,
    node_type="block",
    affinity_focal_gamma=0.3,
    num_sample_ncut=10000,
    knn_ncut=10,
    embedding_method="UMAP",
    num_sample_tsne=1000,
    knn_tsne=10,
    perplexity=500,
    n_neighbors=500,
    min_dist=0.1,
    sampling_method="fps",
    old_school_ncut=False,
    max_frames=100,
    recursion=False,
    recursion_l2_n_eigs=50,
    recursion_l3_n_eigs=20,
    recursion_metric="euclidean",
):
    """UI entry point: prepare the inputs and dispatch to a GPU runner.

    `images` is either a list of (PIL image, caption) tuples or a video file
    path (a str), in which case frames are extracted and a video is produced.
    Images are resized on the CPU before a GPU runner is invoked; the runner
    (quick/long/longer/super_duper_long) is picked from the expected workload.

    Returns whatever the chosen runner returns: four values in recursion
    mode, two otherwise.
    """
    if not images:
        # covers None, an emptied gallery ([]) and an empty path
        gr.Warning("No images selected.")
        if recursion:
            # recursion mode feeds three galleries plus the log box
            return [], [], [], "No images selected."
        return [], "No images selected."

    video_output = isinstance(images, str)
    if video_output:
        images = extract_video_frames(images, max_frames=max_frames)

    if sampling_method == "fps":
        sampling_method = "farthest"

    # resize the images before acquiring GPU
    resolution = RES_DICT[model_name]
    images = [tup[0] for tup in images]
    images = [transform_image(image, resolution=resolution) for image in images]
    images = torch.stack(images)

    model = get_model(model_name)

    kwargs = {
        "model_name": model_name,
        "layer": layer,
        "num_eig": num_eig,
        "node_type": node_type,
        "affinity_focal_gamma": affinity_focal_gamma,
        "num_sample_ncut": num_sample_ncut,
        "knn_ncut": knn_ncut,
        "embedding_method": embedding_method,
        "num_sample_tsne": num_sample_tsne,
        "knn_tsne": knn_tsne,
        "perplexity": perplexity,
        "n_neighbors": n_neighbors,
        "min_dist": min_dist,
        "sampling_method": sampling_method,
        "old_school_ncut": old_school_ncut,
        "recursion": recursion,
        "recursion_l2_n_eigs": recursion_l2_n_eigs,
        "recursion_l3_n_eigs": recursion_l3_n_eigs,
        "recursion_metric": recursion_metric,
        "video_output": video_output,
    }

    # The UI sends names such as "tsne_3d" / "umap_3d"; the legacy names
    # "UMAP" / "t-SNE" are still recognised. Match on the family prefix so
    # the embedding-specific GPU budgets below actually apply.
    method = str(embedding_method).lower()
    is_umap = method.startswith("umap")
    is_tsne = method.startswith("tsne") or method.startswith("t-sne")
    heavy_embedding = perplexity >= 250 or num_sample_tsne >= 500

    num_images = len(images)
    if num_images > 100:
        runner = super_duper_long_run
    elif recursion or num_images > 50 or old_school_ncut:
        runner = longer_run
    elif num_images > 10:
        runner = long_run
    elif is_umap:
        runner = longer_run if heavy_embedding else long_run
    elif is_tsne:
        runner = long_run if heavy_embedding else quick_run
    else:
        runner = quick_run
    return runner(model, images, **kwargs)
def make_input_images_section():
    """Create the 'Input Images' heading, the input gallery and two buttons.

    Returns:
        (input_gallery, submit_button, clear_images_button)
    """
    gr.Markdown('### Input Images')
    gallery = gr.Gallery(
        value=None,
        label="Select images",
        show_label=False,
        elem_id="images",
        columns=[3],
        rows=[1],
        object_fit="contain",
        height="auto",
        type="pil",
        show_share_button=False,
    )
    run_button = gr.Button("🔴 RUN", elem_id="submit_button")
    clear_button = gr.Button("🗑️Clear", elem_id='clear_button')
    return gallery, run_button, clear_button
def make_input_video_section():
    """Create the 'Input Video' heading, video input, frame limit and buttons.

    Returns:
        (input_gallery, submit_button, clear_images_button, max_frames_number)
    """
    gr.Markdown('### Input Video')
    video_input = gr.Video(
        value=None,
        label="Select video",
        elem_id="video-input",
        height="auto",
        show_share_button=False,
    )
    gr.Markdown('_image backbone model is used to extract features from each frame, NCUT is computed on all frames_')
    # max_frames_number = gr.Number(100, label="Max frames", elem_id="max_frames")
    max_frames_slider = gr.Slider(
        1, 200, step=1, label="Max frames", value=100, elem_id="max_frames"
    )
    run_button = gr.Button("🔴 RUN", elem_id="submit_button")
    clear_button = gr.Button("🗑️Clear", elem_id='clear_button')
    return video_input, run_button, clear_button, max_frames_slider
def make_example_images_section():
    """Create the 'Load Images' heading, example gallery and its buttons.

    The 'Hide Example' button hides both the example gallery and itself.

    Returns:
        (load_images_button, example_gallery, hide_button)
    """
    gr.Markdown('### Load Images 👇')
    load_button = gr.Button("Load Example", elem_id="load-images-button")
    gallery = gr.Gallery(
        value=example_items,
        label="Example Set A",
        show_label=False,
        columns=[3],
        rows=[2],
        object_fit="scale-down",
        height="200px",
        show_share_button=False,
        elem_id="example-gallery",
    )
    hide_button = gr.Button("Hide Example", elem_id="hide-button")
    # one click listener per component to hide, registered in this order
    for target in (gallery, hide_button):
        hide_button.click(
            fn=lambda: gr.update(visible=False),
            outputs=target,
        )
    return load_button, gallery, hide_button
def make_example_video_section():
    """Create the 'Load Video' heading and return its 'Load Example' button."""
    gr.Markdown('### Load Video 👇')
    return gr.Button("Load Example", elem_id="load-video-button")
def make_dataset_images_section(open=False):
    """Create the 'Load from dataset' accordion and wire its controls.

    Args:
        open: whether the accordion starts expanded.

    Returns:
        (dataset_dropdown, num_images_slider, random_seed_slider,
         load_dataset_button)

    NOTE: the load button writes into `input_gallery`, which is neither a
    parameter nor a local of this function. It is resolved as a module-level
    name when this function runs, so the caller must have assigned it first.
    """
    with gr.Accordion("➡️ Click to expand: Load from dataset", open=open):
        dataset_names = DATASET_NAMES
        dataset_classes = DATASET_CLASSES
        dataset_dropdown = gr.Dropdown(dataset_names, label="Dataset name", value="mrm8488/ImageNet1K-val", elem_id="dataset")
        num_images_slider = gr.Number(10, label="Number of images", elem_id="num_images")
        filter_by_class_checkbox = gr.Checkbox(label="Filter by class", value=True, elem_id="filter_by_class_checkbox")
        filter_by_class_text = gr.Textbox(label="Class to select", value="0,33,99", elem_id="filter_by_class_text", info="e.g. `0,1,2`. (1000 classes)", visible=True)
        is_random_checkbox = gr.Checkbox(label="Random shuffle", value=False, elem_id="random_seed_checkbox")
        random_seed_slider = gr.Slider(0, 1000, step=1, label="Random seed", value=1, elem_id="random_seed", visible=False)
        load_dataset_button = gr.Button("Load Dataset", elem_id="load-dataset-button")

    def change_filter_options(dataset_name):
        # Show the class filter only for datasets with a known class count.
        idx = dataset_names.index(dataset_name)
        num_classes = dataset_classes[idx]
        if num_classes is None:
            return (gr.Checkbox(label="Filter by class", value=False, elem_id="filter_by_class_checkbox", visible=False),
                    gr.Textbox(label="Class to select", value="0,1,2", elem_id="filter_by_class_text", info="e.g. `0,1,2`. This dataset has no class label", visible=False))
        return (gr.Checkbox(label="Filter by class", value=True, elem_id="filter_by_class_checkbox", visible=True),
                gr.Textbox(label="Class to select", value="0,1,2", elem_id="filter_by_class_text", info=f"e.g. `0,1,2`. ({num_classes} classes)", visible=True))
    dataset_dropdown.change(fn=change_filter_options, inputs=dataset_dropdown, outputs=[filter_by_class_checkbox, filter_by_class_text])

    def change_filter_by_class(is_filter, dataset_name):
        # The class text box is only visible while the filter is ticked.
        idx = dataset_names.index(dataset_name)
        num_classes = dataset_classes[idx]
        return gr.Textbox(label="Class to select", value="0,1,2", elem_id="filter_by_class_text", info=f"e.g. `0,1,2`. ({num_classes} classes)", visible=is_filter)
    filter_by_class_checkbox.change(fn=change_filter_by_class, inputs=[filter_by_class_checkbox, dataset_dropdown], outputs=filter_by_class_text)

    def change_random_seed(is_random):
        # The seed slider is only visible while random shuffle is ticked.
        return gr.Slider(0, 1000, step=1, label="Random seed", value=1, elem_id="random_seed", visible=is_random)
    is_random_checkbox.change(fn=change_random_seed, inputs=is_random_checkbox, outputs=random_seed_slider)

    def load_dataset_images(dataset_name, num_images=10,
                            is_filter=True, filter_by_class_text="0,1,2",
                            is_random=False, seed=1):
        """Return a list of images picked from the first split of the dataset.

        Problems are reported by raising gr.Error: an error object that is
        only constructed is never shown to the user.
        """
        try:
            dataset = load_dataset(dataset_name, trust_remote_code=True)
            key = list(dataset.keys())[0]
            dataset = dataset[key]
        except Exception as e:
            raise gr.Error(f"Error loading dataset {dataset_name}: {e}") from e

        # the number box may deliver a float; slicing and range() need an int
        num_images = min(int(num_images), len(dataset))
        seed = int(seed)

        if not is_filter:
            if is_random:
                return_idx = np.random.RandomState(seed).choice(len(dataset), num_images, replace=False).tolist()
            else:
                return_idx = list(range(num_images))
            return [dataset[i]['image'] for i in return_idx]

        # membership must be tested against the column names, not the rows
        if 'label' not in dataset.column_names:
            raise gr.Error(f"Dataset {dataset_name} has no class label.")

        try:
            classes = [int(tok) for tok in filter_by_class_text.split(",") if tok.strip()]
        except ValueError as e:
            raise gr.Error(f"Classes must be comma-separated integers, got `{filter_by_class_text}`.") from e

        labels = np.array(dataset['label'])
        unique_labels = np.unique(labels)
        valid_classes = [i for i in classes if i in unique_labels]
        invalid_classes = [i for i in classes if i not in unique_labels]
        if len(invalid_classes) > 0:
            gr.Warning(f"Classes {invalid_classes} not found in the dataset.")
        if len(valid_classes) == 0:
            raise gr.Error(f"Classes {classes} not found in the dataset.")

        # shuffle each class
        chunk_size = num_images // len(valid_classes)
        image_idx = []
        for i in valid_classes:
            idx = np.where(labels == i)[0]
            if is_random:
                # sampling without replacement cannot exceed the class size
                take = min(chunk_size, len(idx))
                idx = np.random.RandomState(seed).choice(idx, take, replace=False)
            else:
                idx = idx[:chunk_size]
            image_idx.extend(idx.tolist())
        return [dataset[i]['image'] for i in image_idx]

    load_dataset_button.click(load_dataset_images,
                              inputs=[dataset_dropdown, num_images_slider,
                                      filter_by_class_checkbox, filter_by_class_text,
                                      is_random_checkbox, random_seed_slider],
                              outputs=[input_gallery])
    return dataset_dropdown, num_images_slider, random_seed_slider, load_dataset_button
def make_output_images_section():
    """Create the 'Output Images' heading and return the output gallery."""
    gr.Markdown('### Output Images')
    return gr.Gallery(
        value=[],
        label="NCUT Embedding",
        show_label=False,
        elem_id="ncut",
        columns=[3],
        rows=[1],
        object_fit="contain",
        height="auto",
    )
def make_parameters_section():
    """Create the 'Parameters' controls and return them as a list.

    Returns, in this order:
        [model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
         affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
         embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
         perplexity_slider, n_neighbors_slider, min_dist_slider,
         sampling_method_dropdown]
    """
    gr.Markdown('### Parameters')
    from backbone import get_all_model_names
    model_names = get_all_model_names()
    model_dropdown = gr.Dropdown(model_names, label="Backbone", value="DiNO(dino_vitb8)", elem_id="model_name")
    layer_slider = gr.Slider(1, 12, step=1, label="Backbone: Layer index", value=10, elem_id="layer")
    node_type_dropdown = gr.Dropdown(["attn: attention output", "mlp: mlp output", "block: sum of residual"], label="Backbone: Layer type", value="block: sum of residual", elem_id="node_type", info="which feature to take from each layer?")
    num_eig_slider = gr.Slider(1, 1000, step=1, label="NCUT: Number of eigenvectors", value=100, elem_id="num_eig", info='increase for more clusters')

    def change_layer_slider(model_name):
        # Layer count of the selected backbone; 12 when the model is unknown.
        num_layers = LAYER_DICT[model_name] if model_name in LAYER_DICT else 12
        # The return order must match outputs=[layer_slider, node_type_dropdown]
        # below: slider first, dropdown second, for known and unknown models.
        return (gr.Slider(1, num_layers, step=1, label="Backbone: Layer index", value=num_layers, elem_id="layer", visible=True),
                gr.Dropdown(["attn: attention output", "mlp: mlp output", "block: sum of residual"], label="Backbone: Layer type", value="block: sum of residual", elem_id="node_type", info="which feature to take from each layer?"))
    model_dropdown.change(fn=change_layer_slider, inputs=model_dropdown, outputs=[layer_slider, node_type_dropdown])

    with gr.Accordion("➡️ Click to expand: more parameters", open=False):
        affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="NCUT: Affinity focal gamma", value=0.5, elem_id="affinity_focal_gamma", info="decrease for shaper segmentation")
        num_sample_ncut_slider = gr.Slider(100, 50000, step=100, label="NCUT: num_sample", value=10000, elem_id="num_sample_ncut", info="Nyström approximation")
        sampling_method_dropdown = gr.Dropdown(["fps", "random"], label="NCUT: Sampling method", value="fps", elem_id="sampling_method", info="Nyström approximation")
        knn_ncut_slider = gr.Slider(1, 100, step=1, label="NCUT: KNN", value=10, elem_id="knn_ncut", info="Nyström approximation")
        embedding_method_dropdown = gr.Dropdown(["tsne_3d", "umap_3d", "umap_shpere", "tsne_2d", "umap_2d"], label="Coloring method", value="tsne_3d", elem_id="embedding_method")
        num_sample_tsne_slider = gr.Slider(100, 1000, step=100, label="t-SNE/UMAP: num_sample", value=300, elem_id="num_sample_tsne", info="Nyström approximation")
        knn_tsne_slider = gr.Slider(1, 100, step=1, label="t-SNE/UMAP: KNN", value=10, elem_id="knn_tsne", info="Nyström approximation")
        perplexity_slider = gr.Slider(10, 500, step=10, label="t-SNE: Perplexity", value=150, elem_id="perplexity")
        n_neighbors_slider = gr.Slider(10, 500, step=10, label="UMAP: n_neighbors", value=150, elem_id="n_neighbors")
        min_dist_slider = gr.Slider(0.1, 1, step=0.1, label="UMAP: min_dist", value=0.1, elem_id="min_dist")
    return [model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
            affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
            embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
            perplexity_slider, n_neighbors_slider, min_dist_slider,
            sampling_method_dropdown]
# Top-level UI. Each tab rebinds the same module-level names (input_gallery,
# submit_button, ...). The section builders and event registrations read those
# names when they are called, so the order of statements inside a tab matters.
demo = gr.Blocks(
    theme=gr.themes.Base(spacing_size='md', text_size='lg', primary_hue='blue', neutral_hue='slate', secondary_hue='pink'),
    fill_width=False,
    title="ncut-pytorch",
)
with demo:

    with gr.Tab('AlignedCut'):

        with gr.Row():
            with gr.Column(scale=5, min_width=200):
                input_gallery, submit_button, clear_images_button = make_input_images_section()
                load_images_button, example_gallery, hide_button = make_example_images_section()
                dataset_dropdown, num_images_slider, random_seed_slider, load_dataset_button = make_dataset_images_section()

            with gr.Column(scale=5, min_width=200):
                output_gallery = make_output_images_section()
                [
                    model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
                    affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
                    embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
                    perplexity_slider, n_neighbors_slider, min_dist_slider,
                    sampling_method_dropdown
                ] = make_parameters_section()
                # logging text box
                logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")

        # These handlers have no inputs, so they take no arguments.
        load_images_button.click(lambda: default_images, outputs=input_gallery)
        clear_images_button.click(lambda: ([], []), outputs=[input_gallery, output_gallery])
        submit_button.click(
            run_fn,
            inputs=[
                input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
                affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
                embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
                perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown
            ],
            outputs=[output_gallery, logging_text]
        )

    with gr.Tab('NCut'):
        gr.Markdown('#### NCut (Legacy), not aligned, no Nyström approximation')
        gr.Markdown('Each image is solved independently, <em>color is <b>not</b> aligned across images</em>')

        gr.Markdown('---')
        gr.Markdown('<p style="text-align: center;"><b>NCut vs. AlignedCut</b></p>')
        with gr.Row():
            with gr.Column(scale=5, min_width=200):
                gr.Markdown('#### Pros')
                gr.Markdown('- Easy Solution. Use less eigenvectors.')
                gr.Markdown('- Exact solution. No Nyström approximation.')
            with gr.Column(scale=5, min_width=200):
                gr.Markdown('#### Cons')
                gr.Markdown('- Not aligned. Distance is not preserved across images. No pseudo-labeling or correspondence.')
                gr.Markdown('- Poor complexity scaling. Unable to handle large number of pixels.')
        gr.Markdown('---')
        with gr.Row():
            with gr.Column(scale=5, min_width=200):
                gr.Markdown(' ')
            with gr.Column(scale=5, min_width=200):
                gr.Markdown('<em>color is <b>not</b> aligned across images</em> 👇')

        with gr.Row():
            with gr.Column(scale=5, min_width=200):
                input_gallery, submit_button, clear_images_button = make_input_images_section()
                load_images_button, example_gallery, hide_button = make_example_images_section()
                dataset_dropdown, num_images_slider, random_seed_slider, load_dataset_button = make_dataset_images_section()
                example_gallery.visible = False
                hide_button.visible = False

            with gr.Column(scale=5, min_width=200):
                output_gallery = make_output_images_section()
                [
                    model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
                    affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
                    embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
                    perplexity_slider, n_neighbors_slider, min_dist_slider,
                    sampling_method_dropdown
                ] = make_parameters_section()
                # Always-true hidden checkbox: this tab runs per-image NCut.
                old_school_ncut_checkbox = gr.Checkbox(label="Old school NCut", value=True, elem_id="old_school_ncut")
                invisible_list = [old_school_ncut_checkbox, num_sample_ncut_slider, knn_ncut_slider,
                                  num_sample_tsne_slider, knn_tsne_slider, sampling_method_dropdown]
                for item in invisible_list:
                    item.visible = False
                # logging text box
                logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")

        load_images_button.click(lambda: (default_images, default_outputs_independent), outputs=[input_gallery, output_gallery])
        clear_images_button.click(lambda: ([], []), outputs=[input_gallery, output_gallery])
        submit_button.click(
            run_fn,
            inputs=[
                input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
                affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
                embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
                perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown,
                old_school_ncut_checkbox
            ],
            outputs=[output_gallery, logging_text]
        )

    with gr.Tab('Recursive Cut'):
        gr.Markdown('NCUT can be applied recursively, the eigenvectors from previous iteration is the input for the next iteration NCUT. ')
        gr.Markdown('__Recursive NCUT__ amplifies small object parts, please see [Documentation](https://ncut-pytorch.readthedocs.io/en/latest/how_to_get_better_segmentation/#recursive-ncut)')

        gr.Markdown('---')

        with gr.Row():
            with gr.Column(scale=5, min_width=200):
                input_gallery, submit_button, clear_images_button = make_input_images_section()
                load_images_button, example_gallery, hide_button = make_example_images_section()
                load_images_button.click(lambda: default_images, outputs=[input_gallery])
                example_gallery.visible = False
                hide_button.visible = False
                dataset_dropdown, num_images_slider, random_seed_slider, load_dataset_button = make_dataset_images_section()
                num_images_slider.value = 100

            with gr.Column(scale=5, min_width=200):
                with gr.Accordion("➡️ Recursion config", open=True):
                    l1_num_eig_slider = gr.Slider(1, 1000, step=1, label="Recursion #1: N eigenvectors", value=100, elem_id="l1_num_eig")
                    l2_num_eig_slider = gr.Slider(1, 1000, step=1, label="Recursion #2: N eigenvectors", value=50, elem_id="l2_num_eig")
                    l3_num_eig_slider = gr.Slider(1, 1000, step=1, label="Recursion #3: N eigenvectors", value=25, elem_id="l3_num_eig")
                    metric_dropdown = gr.Dropdown(["euclidean", "cosine"], label="Recursion distance metric", value="cosine", elem_id="recursion_metric")

                [
                    model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
                    affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
                    embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
                    perplexity_slider, n_neighbors_slider, min_dist_slider,
                    sampling_method_dropdown
                ] = make_parameters_section()
                # the per-recursion sliders above replace the generic one
                num_eig_slider.visible = False
                # logging text box

        with gr.Row():
            with gr.Column(scale=5, min_width=200):
                gr.Markdown('### Output (Recursion #1)')
                l1_gallery = gr.Gallery(value=[], label="Recursion #1", show_label=False, elem_id="ncut_l1", columns=[3], rows=[5], object_fit="contain", height="auto")
            with gr.Column(scale=5, min_width=200):
                gr.Markdown('### Output (Recursion #2)')
                l2_gallery = gr.Gallery(value=[], label="Recursion #2", show_label=False, elem_id="ncut_l2", columns=[3], rows=[5], object_fit="contain", height="auto")
            with gr.Column(scale=5, min_width=200):
                gr.Markdown('### Output (Recursion #3)')
                l3_gallery = gr.Gallery(value=[], label="Recursion #3", show_label=False, elem_id="ncut_l3", columns=[3], rows=[5], object_fit="contain", height="auto")
        with gr.Row():
            with gr.Column(scale=5, min_width=200):
                gr.Markdown(' ')
            with gr.Column(scale=5, min_width=200):
                gr.Markdown(' ')
            with gr.Column(scale=5, min_width=200):
                logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")

        # Hidden constant inputs that fill run_fn's positional parameters
        # old_school_ncut (False), max_frames (0) and recursion (True).
        true_placeholder = gr.Checkbox(label="True placeholder", value=True, elem_id="true_placeholder")
        true_placeholder.visible = False
        false_placeholder = gr.Checkbox(label="False placeholder", value=False, elem_id="false_placeholder")
        false_placeholder.visible = False
        number_placeholder = gr.Number(0, label="Number placeholder", elem_id="number_placeholder")
        number_placeholder.visible = False
        clear_images_button.click(lambda: ([], [], [], []), outputs=[input_gallery, l1_gallery, l2_gallery, l3_gallery])
        submit_button.click(
            run_fn,
            inputs=[
                input_gallery, model_dropdown, layer_slider, l1_num_eig_slider, node_type_dropdown,
                affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
                embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
                perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown,
                false_placeholder, number_placeholder, true_placeholder,
                l2_num_eig_slider, l3_num_eig_slider, metric_dropdown,
            ],
            outputs=[l1_gallery, l2_gallery, l3_gallery, logging_text]
        )

    with gr.Tab('Video'):
        with gr.Row():
            with gr.Column(scale=5, min_width=200):
                video_input_gallery, submit_button, clear_images_button, max_frame_number = make_input_video_section()
                # load_video_button = make_example_video_section()
            with gr.Column(scale=5, min_width=200):
                video_output_gallery = gr.Video(value=None, label="NCUT Embedding", elem_id="ncut", height="auto", show_share_button=False)
                [
                    model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
                    affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
                    embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
                    perplexity_slider, n_neighbors_slider, min_dist_slider,
                    sampling_method_dropdown
                ] = make_parameters_section()
                num_sample_tsne_slider.value = 1000
                perplexity_slider.value = 500
                n_neighbors_slider.value = 500
                knn_tsne_slider.value = 20
                # logging text box
                logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")

        # This tab has no example-loading button of its own. The former
        # load_images_button.click(...) registered here used the button and
        # galleries left over from earlier tabs and has been removed.
        # None clears a Video component, so both outputs are reset with None.
        clear_images_button.click(lambda: (None, None), outputs=[video_input_gallery, video_output_gallery])
        # Hidden constant input for run_fn's old_school_ncut parameter.
        place_holder_false = gr.Checkbox(label="Place holder", value=False, elem_id="place_holder_false")
        place_holder_false.visible = False
        submit_button.click(
            run_fn,
            inputs=[
                video_input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
                affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
                embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
                perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown,
                place_holder_false, max_frame_number
            ],
            outputs=[video_output_gallery, logging_text]
        )

    with gr.Tab('Text'):
        gr.Markdown('=== under construction ===')
        gr.Markdown('Please see the [Documentation](https://ncut-pytorch.readthedocs.io/en/latest/gallery_llama3/) for example of NCUT on text input.')
        gr.Markdown('---')
        gr.Markdown('![ncut](https://ncut-pytorch.readthedocs.io/en/latest/images/gallery/llama3/llama3_layer_31.jpg)')

    with gr.Tab('Compare'):

        def add_one_model(i_model=1):
            # One comparison column: output gallery, RUN button, parameters
            # and log box. It reads the module-level `input_gallery` when it
            # is called, i.e. this tab's input gallery.
            with gr.Column(scale=5, min_width=200) as col:
                gr.Markdown(f'### Output Images')
                output_gallery = gr.Gallery(value=[], label="NCUT Embedding", show_label=False, elem_id=f"ncut{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto")
                submit_button = gr.Button("🔴 RUN", elem_id=f"submit_button{i_model}")
                [
                    model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
                    affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
                    embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
                    perplexity_slider, n_neighbors_slider, min_dist_slider,
                    sampling_method_dropdown
                ] = make_parameters_section()
                # logging text box
                logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
                submit_button.click(
                    run_fn,
                    inputs=[
                        input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
                        affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
                        embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
                        perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown
                    ],
                    outputs=[output_gallery, logging_text]
                )

            return col

        with gr.Row():
            with gr.Column(scale=5, min_width=200):
                input_gallery, submit_button, clear_images_button = make_input_images_section()
                # With a single output the return value is the component
                # value itself, so return [] rather than a 1-tuple.
                clear_images_button.click(lambda: [], outputs=[input_gallery])
                submit_button.visible = False
                load_images_button, example_gallery, hide_button = make_example_images_section()
                example_gallery.visible = False
                hide_button.visible = False
                dataset_dropdown, num_images_slider, random_seed_slider, load_dataset_button = make_dataset_images_section(open=True)
                load_images_button.click(lambda: default_images, outputs=input_gallery)

            for i in range(2):
                add_one_model()

        # Create rows and buttons in a loop
        rows = []
        buttons = []

        for i in range(4):
            row = gr.Row(visible=False)
            rows.append(row)

            with row:
                for j in range(3):
                    with gr.Column(scale=5, min_width=200):
                        add_one_model()

            button = gr.Button("Add Compare", elem_id=f"add_button_{i}", visible=False if i > 0 else True)
            buttons.append(button)

            if i > 0:
                # Button i-1 sits right below row i-1: it reveals that row
                # and the next button. (Revealing row i instead would leave
                # rows[0] permanently hidden.)
                buttons[i - 1].click(fn=lambda: gr.update(visible=True), outputs=rows[i - 1])
                buttons[i - 1].click(fn=lambda: gr.update(visible=True), outputs=button)

                # Hide the current button
                buttons[i - 1].click(fn=lambda: gr.update(visible=False), outputs=buttons[i - 1])

        # Last button only reveals the last row and hides itself
        buttons[-1].click(fn=lambda: gr.update(visible=True), outputs=rows[-1])
        buttons[-1].click(fn=lambda: gr.update(visible=False), outputs=buttons[-1])
# Local runs launch straight away with a public share link; on a Hugging Face
# Space all models and datasets are downloaded first.
if not HUGGINGFACE_SPACE:
    demo.launch(share=True)
else:
    download_all_models()
    download_all_datasets()
    demo.launch()
# %%