Spaces:
Running
on
Zero
Running
on
Zero
add tree+image
Browse files
app.py
CHANGED
@@ -16,6 +16,7 @@ import multiprocessing as mp
|
|
16 |
from einops import rearrange
|
17 |
from matplotlib import pyplot as plt
|
18 |
import matplotlib
|
|
|
19 |
USE_HUGGINGFACE_ZEROGPU = os.getenv("USE_HUGGINGFACE_ZEROGPU", "False").lower() in ["true", "1", "yes"]
|
20 |
DOWNLOAD_ALL_MODELS_DATASETS = os.getenv("DOWNLOAD_ALL_MODELS_DATASETS", "False").lower() in ["true", "1", "yes"]
|
21 |
|
@@ -2342,13 +2343,99 @@ with demo:
|
|
2342 |
|
2343 |
pil_image = Image.fromarray(image)
|
2344 |
return pil_image
|
2345 |
-
|
2346 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2347 |
if len(eigvecs) == 0:
|
2348 |
gr.Warning("Please run NCUT first.")
|
2349 |
return
|
|
|
|
|
|
|
|
|
2350 |
eigvecs = torch.tensor(eigvecs)
|
2351 |
-
|
2352 |
gr.Info("Running FPS, t-SNE, and Hierarchical Clustering...", 3)
|
2353 |
from ncut_pytorch.ncut_pytorch import farthest_point_sampling
|
2354 |
from sklearn.manifold import TSNE
|
@@ -2357,8 +2444,8 @@ with demo:
|
|
2357 |
torch.manual_seed(seed)
|
2358 |
np.random.seed(seed)
|
2359 |
|
2360 |
-
fps_idx = farthest_point_sampling(
|
2361 |
-
fps_eigvecs =
|
2362 |
fps_eigvecs = fps_eigvecs.numpy()
|
2363 |
|
2364 |
tsne3d_rgb = tsne3d_rgb.reshape(-1, 3)
|
@@ -2371,19 +2458,29 @@ with demo:
|
|
2371 |
metric='cosine',
|
2372 |
random_state=seed,
|
2373 |
).fit_transform(fps_eigvecs)
|
|
|
|
|
|
|
2374 |
|
2375 |
edges = build_tree(tsne_embed)
|
|
|
|
|
2376 |
|
2377 |
# Plot the t-SNE points
|
2378 |
pil_image = plot_tsne_tree(tsne_embed, edges, fps_tsne3d_rgb, 0)
|
2379 |
|
2380 |
-
|
|
|
|
|
|
|
2381 |
|
2382 |
-
|
2383 |
-
|
2384 |
-
|
2385 |
-
|
2386 |
-
|
|
|
|
|
2387 |
gr.Markdown('---')
|
2388 |
gr.Markdown('<h3 style="text-align: center;">↓ interactively inspect the hierarchical structure</h3>')
|
2389 |
gr.Markdown('---')
|
|
|
16 |
from einops import rearrange
|
17 |
from matplotlib import pyplot as plt
|
18 |
import matplotlib
|
19 |
+
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
|
20 |
USE_HUGGINGFACE_ZEROGPU = os.getenv("USE_HUGGINGFACE_ZEROGPU", "False").lower() in ["true", "1", "yes"]
|
21 |
DOWNLOAD_ALL_MODELS_DATASETS = os.getenv("DOWNLOAD_ALL_MODELS_DATASETS", "False").lower() in ["true", "1", "yes"]
|
22 |
|
|
|
2343 |
|
2344 |
pil_image = Image.fromarray(image)
|
2345 |
return pil_image
|
2346 |
+
|
2347 |
+
def get_top1_heatmap_for_each_dot(images, eigvecs, fps_eigvecs, max_display_dots, fps_tsne_rgb, tsne_embed):
    """Build one bordered heatmap thumbnail per displayed t-SNE dot.

    For every selected dot, the image with the highest mean similarity to
    that dot's eigenvector is picked, its similarity map is rendered with the
    'hot' colormap, blended 50/50 with the image, and framed with the dot's
    color.

    Args:
        images: Sequence of image objects supporting ``.convert("RGB")`` and
            ``.resize(...)`` (PIL-style).
        eigvecs: Per-pixel eigenvectors; assumed ``[B, H, W, C]`` as stated by
            the original inline comment -- TODO confirm against caller.
        fps_eigvecs: Eigenvectors of the sampled dots; assumed ``[N, C]``.
        max_display_dots: Upper bound on how many dots receive a thumbnail.
        fps_tsne_rgb: Per-dot colors, indexed as ``[dot, :3]`` and scaled by
            255 here (so presumably in ``[0, 1]`` -- verify against caller).
        tsne_embed: Per-dot embedding coordinates used to choose a spatially
            spread subset when there are more dots than ``max_display_dots``.

    Returns:
        tuple: ``(top1_image_blended, dots_idx)`` where ``top1_image_blended``
        is a list of uint8 arrays (blended 256x256 thumbnail plus border) and
        ``dots_idx`` holds the indices of the dots that were kept.
    """
    n_dots = fps_eigvecs.shape[0]
    if n_dots > max_display_dots:
        # Subsample with farthest point sampling on the embedding. The
        # original also drew a random subset first, but that result was
        # overwritten immediately, so the dead draw is dropped.
        import fpsample
        dots_idx = fpsample.bucket_fps_kdline_sampling(tsne_embed, max_display_dots, 5).astype(np.int64)
    else:
        dots_idx = np.arange(n_dots)
    fps_eigvecs = fps_eigvecs[dots_idx]
    fps_tsne_rgb = fps_tsne_rgb[dots_idx]

    heatmaps = eigvecs @ fps_eigvecs.T  # [B, H, W, C] @ [N, C] -> [B, H, W, N]
    value = heatmaps.mean(1).mean(1)  # [B, N]
    top1_image_idxs = value.argmax(axis=0)  # [N]

    def pad_image_with_border(image, border_color, border_width):
        """Return `image` surrounded by a solid frame of `border_color`."""
        height, width, channels = image.shape
        padded = np.empty(
            (height + 2 * border_width, width + 2 * border_width, channels),
            dtype=image.dtype,
        )
        padded[:, :] = border_color
        # Explicit end indices keep this correct for border_width == 0,
        # where a `[w:-w]` slice would select an empty region.
        padded[border_width:border_width + height, border_width:border_width + width] = image
        return padded

    thumbnail_size = (256, 256)
    border_width = 20
    cm = matplotlib.colormaps['hot']

    top1_image_blended = []
    for i_fps, image_idx in enumerate(top1_image_idxs):
        # Plain int so the index works for both the list and the array.
        image_idx = int(image_idx)
        heatmap = cm(heatmaps[image_idx, :, :, i_fps])
        # Drop the alpha channel returned by the colormap.
        heatmap = (heatmap[:, :, :3] * 255).astype(np.uint8)
        image = images[image_idx].convert("RGB").resize(thumbnail_size)
        heatmap = Image.fromarray(heatmap).resize(thumbnail_size).convert("RGB")
        blended = 0.5 * np.array(image) + 0.5 * np.array(heatmap)
        blended = np.clip(blended, 0, 255).astype(np.uint8)
        border_color = fps_tsne_rgb[i_fps, :3] * 255
        top1_image_blended.append(pad_image_with_border(blended, border_color, border_width))

    return top1_image_blended, dots_idx
|
2386 |
+
|
2387 |
+
|
2388 |
+
|
2389 |
+
def plot_tsne_with_image_heatmaps(images, eigvecs, fps_eigvecs, tsne_embed, fps_tsne_rgb, max_display_dots=100):
    """Render the 2D embedding with a heatmap thumbnail placed on selected dots.

    Args:
        images: Images forwarded to ``get_top1_heatmap_for_each_dot``.
        eigvecs: Per-pixel eigenvectors forwarded to the same helper.
        fps_eigvecs: Eigenvectors of the sampled dots, forwarded likewise.
        tsne_embed: Array indexed as ``[:, 0]`` / ``[:, 1]`` for the x / y
            coordinate of every dot.
        fps_tsne_rgb: Per-dot colors passed to ``scatter`` as ``c``.
        max_display_dots: Maximum number of dots that get a thumbnail.

    Returns:
        PIL.Image.Image: The rendered figure, decoded from an in-memory PNG.
    """
    top1_image_blended, dots_idx = get_top1_heatmap_for_each_dot(
        images, eigvecs, fps_eigvecs, max_display_dots, fps_tsne_rgb, tsne_embed
    )

    xs, ys = tsne_embed[:, 0], tsne_embed[:, 1]

    fig, ax = plt.subplots(1, 1, figsize=(15, 15))
    try:
        # Plot the t-SNE points
        ax.scatter(xs, ys, s=20, c=fps_tsne_rgb)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.axis('off')

        # Pad by 5% of the span on each side. For coordinates spanning
        # [-1, 1] this equals the previous `min*1.1 / max*1.1` limits, and it
        # still pads outward when the data does not straddle zero.
        x_pad = 0.05 * (xs.max() - xs.min())
        y_pad = 0.05 * (ys.max() - ys.min())
        ax.set_xlim(xs.min() - x_pad, xs.max() + x_pad)
        ax.set_ylim(ys.min() - y_pad, ys.max() + y_pad)

        # Add the top1_image_blended to the scatter plot
        for (x, y), img in zip(tsne_embed[dots_idx], top1_image_blended):
            imgbox = OffsetImage(img, zoom=0.15)
            ax.add_artist(AnnotationBbox(imgbox, (x, y), frameon=False))

        # Second scatter kept from the original, issued after the thumbnails.
        # NOTE(review): matplotlib orders drawing by zorder, not call order --
        # confirm this really places the dots above the thumbnails.
        ax.scatter(xs, ys, s=20, c=fps_tsne_rgb)

        # Remove the white space around the plot
        fig.tight_layout(pad=0)

        # Save this specific figure (not pyplot's "current" one) to an
        # in-memory buffer and decode it before the buffer is closed.
        with io.BytesIO() as buf:
            fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
            buf.seek(0)
            image = np.array(Image.open(buf))
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)

    return Image.fromarray(image)
|
2427 |
+
|
2428 |
+
|
2429 |
+
def run_fps_tsne_hierarchical(image_gallery, eigvecs, num_sample_fps, perplexity_tsne, tsne3d_rgb, seed=0, max_display_dots=300):
|
2430 |
if len(eigvecs) == 0:
|
2431 |
gr.Warning("Please run NCUT first.")
|
2432 |
return
|
2433 |
+
images = [image[0] for image in image_gallery]
|
2434 |
+
if isinstance(images[0], str):
|
2435 |
+
images = [Image.open(image) for image in images]
|
2436 |
+
|
2437 |
eigvecs = torch.tensor(eigvecs)
|
2438 |
+
_eigvecs = eigvecs.reshape(-1, eigvecs.shape[-1])
|
2439 |
gr.Info("Running FPS, t-SNE, and Hierarchical Clustering...", 3)
|
2440 |
from ncut_pytorch.ncut_pytorch import farthest_point_sampling
|
2441 |
from sklearn.manifold import TSNE
|
|
|
2444 |
torch.manual_seed(seed)
|
2445 |
np.random.seed(seed)
|
2446 |
|
2447 |
+
fps_idx = farthest_point_sampling(_eigvecs, num_sample_fps)
|
2448 |
+
fps_eigvecs = _eigvecs[fps_idx]
|
2449 |
fps_eigvecs = fps_eigvecs.numpy()
|
2450 |
|
2451 |
tsne3d_rgb = tsne3d_rgb.reshape(-1, 3)
|
|
|
2458 |
metric='cosine',
|
2459 |
random_state=seed,
|
2460 |
).fit_transform(fps_eigvecs)
|
2461 |
+
# normalize = [-1, 1]
|
2462 |
+
tsne_embed[:, 0] = (tsne_embed[:, 0] - tsne_embed[:, 0].min()) / (tsne_embed[:, 0].max() - tsne_embed[:, 0].min()) * 2 - 1
|
2463 |
+
tsne_embed[:, 1] = (tsne_embed[:, 1] - tsne_embed[:, 1].min()) / (tsne_embed[:, 1].max() - tsne_embed[:, 1].min()) * 2 - 1
|
2464 |
|
2465 |
edges = build_tree(tsne_embed)
|
2466 |
+
# edges = build_tree(fps_eigvecs, dist='cosine')
|
2467 |
+
# edges = build_tree(fps_tsne3d_rgb)
|
2468 |
|
2469 |
# Plot the t-SNE points
|
2470 |
pil_image = plot_tsne_tree(tsne_embed, edges, fps_tsne3d_rgb, 0)
|
2471 |
|
2472 |
+
# Plot the t-SNE points with image heatmaps
|
2473 |
+
big_pil_image = plot_tsne_with_image_heatmaps(images, eigvecs, fps_eigvecs, tsne_embed, fps_tsne3d_rgb, max_display_dots)
|
2474 |
+
|
2475 |
+
return tsne_embed, edges, fps_eigvecs, fps_tsne3d_rgb, fps_idx, pil_image, big_pil_image
|
2476 |
|
2477 |
+
big_tsne_plot = gr.Image(label="spectral-tSNE tree [+ Cluster Heatmap]", elem_id="big_tsne_plot", interactive=False, format='png')
|
2478 |
+
|
2479 |
+
run_hierarchical_button.click(
|
2480 |
+
run_fps_tsne_hierarchical,
|
2481 |
+
inputs=[input_gallery, eigvecs, num_sample_fps_slider, tsne_perplexity_slider, tsne3d_rgb, fps_hc_seed_slider],
|
2482 |
+
outputs=[tsne_2d_points, edges, fps_eigvecs, fps_tsne_rgb, fps_indices, tsne_plot, big_tsne_plot],
|
2483 |
+
)
|
2484 |
gr.Markdown('---')
|
2485 |
gr.Markdown('<h3 style="text-align: center;">↓ interactively inspect the hierarchical structure</h3>')
|
2486 |
gr.Markdown('---')
|