update app.py
app.py CHANGED
@@ -5,6 +5,8 @@ from PIL import Image
 import torchvision.transforms as transforms
 from torch import nn
 import numpy as np
+import os
+import time
 
 import gradio as gr
 
@@ -315,6 +317,7 @@ def compute_ncut(
 ):
     from ncut_pytorch import NCUT, rgb_from_tsne_3d
 
+    start = time.time()
     eigvecs, eigvals = NCUT(
         num_eig=num_eig,
         num_sample=num_sample_ncut,
@@ -322,12 +325,17 @@ def compute_ncut(
         affinity_focal_gamma=affinity_focal_gamma,
         knn=knn_ncut,
     ).fit_transform(features.reshape(-1, features.shape[-1]))
+    print(f"NCUT time: {time.time() - start:.2f}s")
+
+    start = time.time()
     X_3d, rgb = rgb_from_tsne_3d(
         eigvecs,
         num_sample=num_sample_tsne,
         perplexity=perplexity,
         knn=knn_tsne,
     )
+    print(f"t-SNE time: {time.time() - start:.2f}s")
+
     rgb = rgb.reshape(features.shape[:3] + (3,))
     return rgb
 
@@ -368,9 +376,13 @@ def main_fn(
     perplexity = num_sample_tsne - 1
 
     images = [image[0] for image in images]
+
+    start = time.time()
     features = extract_features(
         images, model_name=model_name, node_type=node_type, layer=layer
     )
+    print(f"Feature extraction time: {time.time() - start:.2f}s")
+
     rgb = compute_ncut(
         features,
         num_eig=num_eig,
@@ -391,7 +403,7 @@ demo = gr.Interface(
    main_fn,
    [
        gr.Gallery(value=default_images, label="Select images", show_label=False, elem_id="images", columns=[3], rows=[1], object_fit="contain", height="auto", type="pil"),
-       gr.Dropdown(["SAM(sam_vit_b)", "DiNO(dinov2_vitb14_reg)", "CLIP(openai/clip-vit-base-patch16"], label="Model", value="SAM(sam_vit_b)", elem_id="model_name"),
+       gr.Dropdown(["SAM(sam_vit_b)", "DiNO(dinov2_vitb14_reg)", "CLIP(openai/clip-vit-base-patch16)"], label="Model", value="SAM(sam_vit_b)", elem_id="model_name"),
        gr.Dropdown(["attn", "mlp", "block"], label="Node type", value="block", elem_id="node_type", info="attn: attention output, mlp: mlp output, block: sum of residual stream"),
        gr.Slider(0, 11, step=1, label="Layer", value=11, elem_id="layer", info="which layer of the image backbone features"),
        gr.Slider(1, 1000, step=1, label="Number of eigenvectors", value=100, elem_id="num_eig", info='increase for more object parts, decrease for whole object'),
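Note: the commit instruments three stages (feature extraction, NCUT, t-SNE coloring) with a repeated start = time.time() / print(...) pattern. A minimal sketch of how that pattern could be factored into a context manager, using only the standard library; the timed helper below is hypothetical and is not part of app.py:

import time
from contextlib import contextmanager

@contextmanager
def timed(label):
    # Measure wall-clock time of the wrapped block and print it
    # in the same "<label> time: X.XXs" format used in the diff above.
    start = time.time()
    try:
        yield
    finally:
        print(f"{label} time: {time.time() - start:.2f}s")

# Usage sketch, mirroring the instrumentation added in compute_ncut:
# with timed("NCUT"):
#     eigvecs, eigvals = NCUT(...).fit_transform(...)
# with timed("t-SNE"):
#     X_3d, rgb = rgb_from_tsne_3d(...)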