Spaces:
Running
on
Zero
Running
on
Zero
fix msg, faster eig norm
Browse files- app.py +32 -9
- fps_cluster.py +2 -1
app.py
CHANGED
@@ -2232,6 +2232,12 @@ with demo:
|
|
2232 |
def __run_fn(*args, **kwargs):
|
2233 |
eigvecs, rgb, logging_str = run_fn(*args, **kwargs)
|
2234 |
rgb_gallery = to_pil_images(rgb)
|
|
|
|
|
|
|
|
|
|
|
|
|
2235 |
return eigvecs, rgb, rgb_gallery, logging_str
|
2236 |
|
2237 |
submit_button.click(
|
@@ -2408,8 +2414,8 @@ with demo:
|
|
2408 |
gr.Markdown("Known Issue: Resize the browser window will break the clicking, please refresh the page.")
|
2409 |
with gr.Accordion("Outputs", open=True):
|
2410 |
gr.Markdown("""
|
2411 |
-
1. spectral-tSNE tree: ◆
|
2412 |
-
2. Cluster Heatmap: max cosine similarity to
|
2413 |
""")
|
2414 |
with gr.Column(scale=5, min_width=200):
|
2415 |
prompt_radio = gr.Radio(["Tree", "Image"], label="Where to click on?", value="Tree", elem_id="prompt_radio", show_label=True)
|
@@ -2427,6 +2433,7 @@ with demo:
|
|
2427 |
tsne_plot.change(updaste_tsne_plot_change_granularity,
|
2428 |
inputs=[granularity_slider, tsne_2d_points, edges, fps_tsne_rgb],
|
2429 |
outputs=[tsne_prompt_image])
|
|
|
2430 |
run_inspection_button = gr.Button("🔴 RUN Inspection", elem_id="run_inspection", variant='primary')
|
2431 |
inspect_logging_text = gr.Textbox("Logging information", lines=3, label="Logging", elem_id="inspect_logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
|
2432 |
# output_slot_radio = gr.Radio([1, 2, 3], label="Output Row", value=1, elem_id="output_slot", show_label=True)
|
@@ -2500,8 +2507,8 @@ with demo:
|
|
2500 |
x = int(x * w)
|
2501 |
y = int(y * h)
|
2502 |
eigvec = _eigvec[y, x]
|
2503 |
-
|
2504 |
-
closest_idx = np.
|
2505 |
return closest_idx
|
2506 |
|
2507 |
def find_closest_fps_point(prompt_radio, tsne_prompt, image_prompt, i_image, tsne2d_embed, eigvecs, fps_eigvecs):
|
@@ -2529,13 +2536,29 @@ with demo:
|
|
2529 |
output_tsne_plot = plot_tsne_tree(tsne2d_embed, edges, fps_tsne_rgb, granularity, closest_idx, highlight_connections=True)
|
2530 |
|
2531 |
# draw heatmap for the connected components
|
|
|
2532 |
connected_eigvecs = fps_eigvecs[connected_idxs]
|
2533 |
-
left =
|
2534 |
-
right =
|
2535 |
-
left = F.normalize(left, p=2, dim=-1)
|
2536 |
-
right = F.normalize(right, p=2, dim=-1)
|
|
|
2537 |
similarity = left @ right.T
|
2538 |
-
similarity = similarity.max(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2539 |
hot_map = matplotlib.colormaps['hot']
|
2540 |
heatmap = hot_map(similarity)[..., :3] # B H W 3
|
2541 |
heatmap_images = to_pil_images(torch.tensor(heatmap), target_size=256, force_size=True)
|
|
|
2232 |
def __run_fn(*args, **kwargs):
|
2233 |
eigvecs, rgb, logging_str = run_fn(*args, **kwargs)
|
2234 |
rgb_gallery = to_pil_images(rgb)
|
2235 |
+
# normalize the eigvecs
|
2236 |
+
eigvecs = torch.tensor(eigvecs)
|
2237 |
+
if torch.cuda.is_available():
|
2238 |
+
eigvecs = eigvecs.cuda()
|
2239 |
+
eigvecs = F.normalize(eigvecs, p=2, dim=-1)
|
2240 |
+
eigvecs = eigvecs.cpu().numpy()
|
2241 |
return eigvecs, rgb, rgb_gallery, logging_str
|
2242 |
|
2243 |
submit_button.click(
|
|
|
2414 |
gr.Markdown("Known Issue: Resize the browser window will break the clicking, please refresh the page.")
|
2415 |
with gr.Accordion("Outputs", open=True):
|
2416 |
gr.Markdown("""
|
2417 |
+
1. spectral-tSNE tree: ◆ marker is the N points, connected components to the clicked dot.
|
2418 |
+
2. Cluster Heatmap: max of N cosine similarity to N points in the connected components.
|
2419 |
""")
|
2420 |
with gr.Column(scale=5, min_width=200):
|
2421 |
prompt_radio = gr.Radio(["Tree", "Image"], label="Where to click on?", value="Tree", elem_id="prompt_radio", show_label=True)
|
|
|
2433 |
tsne_plot.change(updaste_tsne_plot_change_granularity,
|
2434 |
inputs=[granularity_slider, tsne_2d_points, edges, fps_tsne_rgb],
|
2435 |
outputs=[tsne_prompt_image])
|
2436 |
+
prompt_radio.change(update_image_prompt, inputs=[image_slider, output_gallery], outputs=[image_plot])
|
2437 |
run_inspection_button = gr.Button("🔴 RUN Inspection", elem_id="run_inspection", variant='primary')
|
2438 |
inspect_logging_text = gr.Textbox("Logging information", lines=3, label="Logging", elem_id="inspect_logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
|
2439 |
# output_slot_radio = gr.Radio([1, 2, 3], label="Output Row", value=1, elem_id="output_slot", show_label=True)
|
|
|
2507 |
x = int(x * w)
|
2508 |
y = int(y * h)
|
2509 |
eigvec = _eigvec[y, x]
|
2510 |
+
sim = fps_eigvecs @ eigvec
|
2511 |
+
closest_idx = np.argmax(sim)
|
2512 |
return closest_idx
|
2513 |
|
2514 |
def find_closest_fps_point(prompt_radio, tsne_prompt, image_prompt, i_image, tsne2d_embed, eigvecs, fps_eigvecs):
|
|
|
2536 |
output_tsne_plot = plot_tsne_tree(tsne2d_embed, edges, fps_tsne_rgb, granularity, closest_idx, highlight_connections=True)
|
2537 |
|
2538 |
# draw heatmap for the connected components
|
2539 |
+
## cosine distance
|
2540 |
connected_eigvecs = fps_eigvecs[connected_idxs]
|
2541 |
+
left = eigvecs.astype(np.float32) # B H W C
|
2542 |
+
right = connected_eigvecs.astype(np.float32) # N C
|
2543 |
+
# left = F.normalize(left, p=2, dim=-1)
|
2544 |
+
# right = F.normalize(right, p=2, dim=-1)
|
2545 |
+
# eigvec is already normalized when saved to gr.State
|
2546 |
similarity = left @ right.T
|
2547 |
+
similarity = similarity.max(axis=-1) # B H W N
|
2548 |
+
## euclidean distance
|
2549 |
+
# b, h, w = tsne3d_rgb.shape[:3]
|
2550 |
+
# tsne3d_rgb = tsne3d_rgb.reshape(b*h*w, 3)
|
2551 |
+
# connected_rgb = tsne3d_rgb[fps_indices][connected_idxs]
|
2552 |
+
# left = torch.tensor(tsne3d_rgb).float() # (B H W) 3
|
2553 |
+
# right = torch.tensor(connected_rgb).float() # N 3
|
2554 |
+
# # dist B H W N
|
2555 |
+
# dist = left[:, None] - right[None]
|
2556 |
+
# dist = torch.sqrt((dist ** 2).sum(dim=-1))
|
2557 |
+
# dist = dist.min(dim=-1).values # B H W
|
2558 |
+
# dist = dist.reshape(b, h, w)
|
2559 |
+
# gr.Info(f"dist: min={dist.min().item()}, max={dist.max().item()}, mean={dist.mean().item()}", 3)
|
2560 |
+
# similarity = 1 - dist
|
2561 |
+
|
2562 |
hot_map = matplotlib.colormaps['hot']
|
2563 |
heatmap = hot_map(similarity)[..., :3] # B H W 3
|
2564 |
heatmap_images = to_pil_images(torch.tensor(heatmap), target_size=256, force_size=True)
|
fps_cluster.py
CHANGED
@@ -5,7 +5,8 @@ import torch
|
|
5 |
|
6 |
def build_tree(all_dots):
|
7 |
num_sample = all_dots.shape[0]
|
8 |
-
center = all_dots.mean(axis=0)
|
|
|
9 |
distances_to_center = np.linalg.norm(all_dots - center, axis=1)
|
10 |
start_idx = np.argmin(distances_to_center)
|
11 |
indices = [start_idx]
|
|
|
5 |
|
6 |
def build_tree(all_dots):
|
7 |
num_sample = all_dots.shape[0]
|
8 |
+
# center = all_dots.mean(axis=0)
|
9 |
+
center = np.median(all_dots, axis=0)
|
10 |
distances_to_center = np.linalg.norm(all_dots - center, axis=1)
|
11 |
start_idx = np.argmin(distances_to_center)
|
12 |
indices = [start_idx]
|