Spaces: Running on Zero

update aligned, fix z-score

app.py CHANGED
```diff
@@ -1,7 +1,10 @@
 # Author: Huzheng Yang
 # %%
 import copy
+from io import BytesIO
 import os
+
+from matplotlib import pyplot as plt
 USE_HUGGINGFACE_ZEROGPU = os.getenv("USE_HUGGINGFACE_ZEROGPU", "False").lower() in ["true", "1", "yes"]
 DOWNLOAD_ALL_MODELS_DATASETS = os.getenv("DOWNLOAD_ALL_MODELS_DATASETS", "False").lower() in ["true", "1", "yes"]
 
@@ -241,7 +244,7 @@ def ncut_run(
     logging_str = ""
     if "AlignedThreeModelAttnNodes" == model_name:
         # dirty patch for the alignedcut paper
-        resolution = (
+        resolution = (224, 224)
     else:
         resolution = RES_DICT[model_name]
     logging_str += f"Resolution: {resolution}\n"
@@ -357,11 +360,18 @@ def ncut_run(
 
     if "AlignedThreeModelAttnNodes" == model_name:
         # dirty patch for the alignedcut paper
-        galleries = []
-        for i_node in range(rgb.shape[1]):
-            _rgb = rgb[:, i_node]
-            galleries.append(to_pil_images(_rgb, target_size=56))
-        return *galleries, logging_str
+        # galleries = []
+        # for i_node in range(rgb.shape[1]):
+        #     _rgb = rgb[:, i_node]
+        #     galleries.append(to_pil_images(_rgb, target_size=56))
+        # return *galleries, logging_str
+        pil_images = []
+        for i_image in range(rgb.shape[0]):
+            _im = plot_one_image_36_grid(images[i_image], rgb[i_image])
+            pil_images.append(_im)
+        return pil_images, logging_str
+
+
 
     if is_lisa == True:
         # dirty patch for the LISA model
@@ -457,9 +467,78 @@ def transform_image(image, resolution=(1024, 1024)):
     image = torch.tensor(np.array(image).transpose(2, 0, 1)).float()
     image = image / 255
     # Normalize
-
+    mean = [0.485, 0.456, 0.406]
+    std = [0.229, 0.224, 0.225]
+    image = (image - torch.tensor(mean).view(3, 1, 1)) / torch.tensor(std).view(3, 1, 1)
     return image
 
+def plot_one_image_36_grid(original_image, tsne_rgb_images):
+    mean = [0.485, 0.456, 0.406]
+    std = [0.229, 0.224, 0.225]
+    original_image = original_image * torch.tensor(std).view(3, 1, 1) + torch.tensor(mean).view(3, 1, 1)
+    original_image = torch.clamp(original_image, 0, 1)
+
+    fig = plt.figure(figsize=(20, 4))
+    grid = plt.GridSpec(3, 14, hspace=0.1, wspace=0.1)
+
+    ax1 = fig.add_subplot(grid[0:2, 0:2])
+    img = original_image.cpu().float().numpy().transpose(1, 2, 0)
+
+    def convert_and_pad_image(np_array, pad_size=20):
+        """
+        Converts a NumPy array of shape (height, width, 3) to a PNG image
+        and pads the right and bottom sides with a transparent background.
+
+        Args:
+            np_array (numpy.ndarray): Input NumPy array of shape (height, width, 3)
+            pad_size (int, optional): Number of pixels to pad on the right and bottom sides. Default is 20.
+
+        Returns:
+            PIL.Image: Padded PNG image with transparent background
+        """
+        # Convert NumPy array to PIL Image
+        img = Image.fromarray(np_array)
+
+        # Get the original size
+        width, height = img.size
+
+        # Create a new image with padding and transparent background
+        new_width = width + pad_size
+        new_height = height + pad_size
+        padded_img = Image.new('RGBA', (new_width, new_height), color=(255, 255, 255, 0))
+
+        # Paste the original image onto the padded image
+        padded_img.paste(img, (0, 0))
+
+        return padded_img
+
+    img = convert_and_pad_image((img*255).astype(np.uint8))
+    ax1.imshow(img)
+    ax1.axis('off')
+
+    model_names = ['CLIP', 'DINO', 'MAE']
+
+    for i_model, model_name in enumerate(model_names):
+        for i_layer in range(12):
+            ax = fig.add_subplot(grid[i_model, i_layer+2])
+            ax.imshow(tsne_rgb_images[i_layer+12*i_model].cpu().float().numpy())
+            ax.axis('off')
+            if i_model == 0:
+                ax.set_title(f'Layer{i_layer}', fontsize=16)
+            if i_layer == 0:
+                ax.text(-0.1, 0.5, model_name, va="center", ha="center", fontsize=16, transform=ax.transAxes, rotation=90,)
+    plt.tight_layout()
+    buf = BytesIO()
+    plt.savefig(buf, bbox_inches='tight', pad_inches=0, dpi=100)
+
+    buf.seek(0)  # Move to the start of the BytesIO buffer
+    img = Image.open(buf)
+    img = img.convert("RGB")
+    img = copy.deepcopy(img)
+    buf.close()
+    plt.close()
+    return img
+
 def load_alignedthreemodel():
 
     os.system("git clone https://huggingface.co/huzey/alignedthreeattn >> /dev/null 2>&1")
@@ -687,10 +766,10 @@ def make_input_video_section():
     clear_images_button = gr.Button("🗑️Clear", elem_id='clear_button', variant='stop')
     return input_gallery, submit_button, clear_images_button, max_frames_number
 
-def make_dataset_images_section(advanced=False):
+def make_dataset_images_section(advanced=False, is_random=False):
 
     gr.Markdown('### Load Datasets')
-    load_images_button = gr.Button("Load", elem_id="load-images-button", variant='
+    load_images_button = gr.Button("🟢 Load Images", elem_id="load-images-button", variant='primary')
     advanced_radio = gr.Radio(["Basic", "Advanced"], label="Datasets", value="Advanced" if advanced else "Basic", elem_id="advanced-radio")
     with gr.Column() as basic_block:
         example_gallery = gr.Gallery(value=example_items, label="Example Set A", show_label=False, columns=[3], rows=[2], object_fit="scale-down", height="200px", show_share_button=False, elem_id="example-gallery")
@@ -700,10 +779,17 @@ def make_dataset_images_section(advanced=False):
     with gr.Row():
         dataset_dropdown = gr.Dropdown(dataset_names, label="Dataset name", value="mrm8488/ImageNet1K-val", elem_id="dataset", min_width=300)
         num_images_slider = gr.Number(10, label="Number of images", elem_id="num_images")
-        filter_by_class_checkbox = gr.Checkbox(label="Filter by class", value=True, elem_id="filter_by_class_checkbox")
-        filter_by_class_text = gr.Textbox(label="Class to select", value="0,33,99", elem_id="filter_by_class_text", info=f"e.g. `0,1,2`. (1000 classes)", visible=True)
-        is_random_checkbox = gr.Checkbox(label="Random shuffle", value=False, elem_id="random_seed_checkbox")
-        random_seed_slider = gr.Slider(0, 1000, step=1, label="Random seed", value=1, elem_id="random_seed", visible=False)
+        if not is_random:
+            filter_by_class_checkbox = gr.Checkbox(label="Filter by class", value=True, elem_id="filter_by_class_checkbox")
+            filter_by_class_text = gr.Textbox(label="Class to select", value="0,33,99", elem_id="filter_by_class_text", info=f"e.g. `0,1,2`. (1000 classes)", visible=True)
+            is_random_checkbox = gr.Checkbox(label="Random shuffle", value=False, elem_id="random_seed_checkbox")
+            random_seed_slider = gr.Slider(0, 1000, step=1, label="Random seed", value=1, elem_id="random_seed", visible=False)
+        if is_random:
+            filter_by_class_checkbox = gr.Checkbox(label="Filter by class", value=False, elem_id="filter_by_class_checkbox")
+            filter_by_class_text = gr.Textbox(label="Class to select", value="0,33,99", elem_id="filter_by_class_text", info=f"e.g. `0,1,2`. (1000 classes)", visible=False)
+            is_random_checkbox = gr.Checkbox(label="Random shuffle", value=True, elem_id="random_seed_checkbox")
+            random_seed_slider = gr.Slider(0, 1000, step=1, label="Random seed", value=42, elem_id="random_seed", visible=True)
+
 
     if advanced:
         advanced_block.visible = True
@@ -1168,12 +1254,18 @@ with demo:
         with gr.Column(scale=5, min_width=200):
             input_gallery, submit_button, clear_images_button = make_input_images_section()
 
-            dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_dataset_images_section(advanced=True)
+            dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_dataset_images_section(advanced=True, is_random=True)
             num_images_slider.value = 100
 
+
         with gr.Column(scale=5, min_width=200):
+            output_gallery = make_output_images_section()
+            gr.Markdown('### TIP1: use the `full-screen` button, and use `arrow keys` to navigate')
+            gr.Markdown('---')
             gr.Markdown('Model: CLIP(ViT-B-16/openai), DiNOv2reg(dinov2_vitb14_reg), MAE(vit_base)')
             gr.Markdown('Layer type: attention output (attn), without sum of residual')
+            gr.Markdown('### TIP2: for large image set, please increase the `num_sample` for t-SNE and NCUT')
+            gr.Markdown('---')
         [
             model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
             affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
@@ -1185,20 +1277,23 @@ with demo:
             model_dropdown.visible = False
             layer_slider.visible = False
             node_type_dropdown.visible = False
+            num_sample_ncut_slider.value = 10000
+            num_sample_tsne_slider.value = 1000
             # logging text box
             logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
 
-        galleries = []
-        for i_model, model_name in enumerate(["CLIP", "DINO", "MAE"]):
-            with gr.Row():
-                for i_layer in range(1, 13):
-                    with gr.Column(scale=5, min_width=200):
-                        gr.Markdown(f'### {model_name} Layer {i_layer}')
-                        output_gallery = gr.Gallery(value=[], label="NCUT Embedding", show_label=False, elem_id="ncut", columns=[3], rows=[1], object_fit="contain", height="auto")
-                        galleries.append(output_gallery)
+        # galleries = []
+        # for i_model, model_name in enumerate(["CLIP", "DINO", "MAE"]):
+        #     with gr.Row():
+        #         for i_layer in range(1, 13):
+        #             with gr.Column(scale=5, min_width=200):
+        #                 gr.Markdown(f'### {model_name} Layer {i_layer}')
+        #                 output_gallery = gr.Gallery(value=[], label="NCUT Embedding", show_label=False, elem_id="ncut", columns=[3], rows=[1], object_fit="contain", height="auto")
+        #                 galleries.append(output_gallery)
 
 
-        clear_images_button.click(lambda x: [] * (len(galleries) + 1), outputs=[input_gallery] + galleries)
+        # clear_images_button.click(lambda x: [] * (len(galleries) + 1), outputs=[input_gallery] + galleries)
+        clear_images_button.click(lambda x: ([], []), outputs=[input_gallery, output_gallery])
 
         false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
         no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
@@ -1213,7 +1308,8 @@ with demo:
                 embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
                 perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown
             ],
-            outputs=galleries + [logging_text],
+            # outputs=galleries + [logging_text],
+            outputs=[output_gallery, logging_text],
         )
 
         with gr.Tab('Compare Models'):
@@ -1320,4 +1416,4 @@ if DOWNLOAD_ALL_MODELS_DATASETS:
 demo.launch(share=True)
 
 
-# %%
+# %%
```
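A few notes on the changes above. The resolution patch pins `AlignedThreeModelAttnNodes` to 224x224 instead of reading `RES_DICT`, matching the fixed input size of the three aligned ViT-B backbones. A minimal sketch of the lookup; the `RES_DICT` entry here is an illustrative stand-in, not the app's real table:

```python
# Sketch of the special-cased resolution lookup. The RES_DICT entry is an
# illustrative stand-in, not the app's actual table.
RES_DICT = {"CLIP(ViT-B-16/openai)": (1024, 1024)}

def pick_resolution(model_name):
    if model_name == "AlignedThreeModelAttnNodes":
        # three aligned ViT-B backbones, all fed 224x224 inputs
        return (224, 224)
    return RES_DICT[model_name]

assert pick_resolution("AlignedThreeModelAttnNodes") == (224, 224)
```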
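The aligned-model output path now returns a single list of composite grid images (one figure per input image) plus the log string, instead of star-unpacking one gallery per node. A fixed return shape is easier to wire in Gradio, where a handler must return exactly one value per component in `outputs`. A minimal sketch, assuming `gradio` is installed; component names are illustrative:

```python
import gradio as gr

def run_fixed():
    pil_images = []       # one list feeds a single gr.Gallery
    logging_str = "done"
    return pil_images, logging_str  # arity is always 2

with gr.Blocks() as demo:
    gallery = gr.Gallery()
    logbox = gr.Textbox()
    # two outputs, so the handler returns exactly two values
    gr.Button("Run").click(run_fixed, outputs=[gallery, logbox])
```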
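The "fix z-score" half of the commit title: `transform_image` now standardizes inputs with the standard ImageNet channel statistics after scaling to [0, 1]. The added arithmetic matches `torchvision.transforms.Normalize`; a quick equivalence check, assuming `torch` and `torchvision` are available:

```python
import torch
from torchvision import transforms

mean = [0.485, 0.456, 0.406]  # ImageNet channel means
std = [0.229, 0.224, 0.225]   # ImageNet channel stds

image = torch.rand(3, 224, 224)  # stand-in for a [0, 1] RGB tensor

# z-score exactly as written in the diff
manual = (image - torch.tensor(mean).view(3, 1, 1)) / torch.tensor(std).view(3, 1, 1)

# the equivalent torchvision transform
assert torch.allclose(manual, transforms.Normalize(mean, std)(image), atol=1e-6)
```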
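The new `plot_one_image_36_grid` draws the padded original image next to a 3x12 grid of per-layer embeddings (CLIP/DINO/MAE by 12 layers), then rasterizes the Matplotlib figure to an in-memory PNG and reopens it as a PIL image. A stripped-down sketch of that round trip; `img.load()` stands in for the diff's `copy.deepcopy(img)` as a way to detach the pixels from the buffer before closing it:

```python
from io import BytesIO

import matplotlib
matplotlib.use("Agg")  # headless backend, as on a Space
import matplotlib.pyplot as plt
from PIL import Image

fig, ax = plt.subplots()
ax.plot([0, 1], [0, 1])
ax.axis("off")

buf = BytesIO()
fig.savefig(buf, bbox_inches="tight", pad_inches=0, dpi=100)
buf.seek(0)               # rewind before PIL parses the PNG bytes
img = Image.open(buf).convert("RGB")
img.load()                # force a full read so the buffer can be closed
buf.close()
plt.close(fig)
```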
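`make_dataset_images_section` gains an `is_random` flag that flips the default values and visibility of the class-filter and shuffle controls. Since the two branches differ only in values, the same wiring could be collapsed into parameterized defaults; a hypothetical condensed variant, not the app's actual code:

```python
import gradio as gr

def dataset_filter_controls(is_random=False):
    # is_random=True: hide class filtering, default to shuffling with seed 42
    filter_by_class = gr.Checkbox(label="Filter by class", value=not is_random)
    class_text = gr.Textbox(label="Class to select", value="0,33,99", visible=not is_random)
    shuffle = gr.Checkbox(label="Random shuffle", value=is_random)
    seed = gr.Slider(0, 1000, step=1, label="Random seed",
                     value=42 if is_random else 1, visible=is_random)
    return filter_by_class, class_text, shuffle, seed

with gr.Blocks() as demo:
    controls = dataset_filter_controls(is_random=True)
```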
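The clear-button rewiring also sidesteps a Python gotcha: `[] * (len(galleries) + 1)` always evaluates to `[]`, because repeating an empty list yields an empty list, so the old lambda never supplied one value per output. The replacement returns one empty value per component. A sketch of the pattern; the handler takes no arguments because no `inputs` are wired:

```python
import gradio as gr

assert [] * 37 == []  # why the old lambda could not clear 37 components

with gr.Blocks() as demo:
    input_gallery = gr.Gallery()
    output_gallery = gr.Gallery()
    clear = gr.Button("Clear")
    # one empty value per component listed in `outputs`
    clear.click(lambda: ([], []), outputs=[input_gallery, output_gallery])
```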