# ncut-pytorch / fps_cluster.py
# Last commit by huzey: d5680a3, "add tree option" (file size 2.78 kB).
# %%
from collections import defaultdict, deque

import numpy as np
import torch
def _pairwise_distance(all_dots, dist):
    """Return the dense (n, n) distance matrix of ``all_dots`` as a float torch tensor.

    ``'euclidean'`` is the L2 distance. ``'cosine'`` is ``1 - x_i . x_j``, which
    is only a cosine distance when the rows are already L2-normalised; nothing
    is normalised here.
    """
    if dist == 'euclidean':
        # NOTE: broadcasting materialises an (n, n, d) temporary, so peak
        # memory is O(n^2 * d) for a 2-D input.
        diff = all_dots[:, None] - all_dots[None, :]
        A = torch.tensor(np.sqrt((diff ** 2).sum(-1)))
    elif dist == 'cosine':
        A = 1 - torch.tensor(all_dots @ all_dots.T)
    else:
        raise ValueError('dist must be euclidean or cosine')
    # -inf is used as a "visited" mask during sampling; integer tensors cannot hold it.
    if not torch.is_floating_point(A):
        A = A.double()
    return A


def _farthest_point_order(A, start_idx):
    """Visit every point by greedy farthest-point sampling (FPS) starting at ``start_idx``.

    Returns ``(order, radii)``: ``order[k]`` is the k-th visited point and, for
    k >= 1, ``radii[k]`` is its distance to the nearest earlier-visited point.
    ``radii[0]`` is only a placeholder; the root's level is fixed to 0 elsewhere.
    """
    n = A.shape[0]
    order = [start_idx]
    radii = [114514.0]
    # Running minimum distance from each point to the visited set. Taking a
    # running minimum gives exactly the same values as re-reducing over all
    # visited rows each step, in O(n) instead of O(n * k) per step.
    min_dist = A[start_idx].clone()
    # Visited points are masked with -inf so argmax never re-picks them, even
    # when unvisited duplicates also sit at distance 0.
    min_dist[start_idx] = -float('inf')
    for _ in range(n - 1):
        nxt = int(torch.argmax(min_dist))
        radii.append(min_dist[nxt].item())
        order.append(nxt)
        min_dist = torch.minimum(min_dist, A[nxt])
        min_dist[nxt] = -float('inf')
    return np.array(order, dtype=np.int64), np.array(radii)


def _assign_levels(radii):
    """Map FPS radii to integer tree levels (root = 0).

    Point k >= 1 gets ``floor(log2(radii[1] / radii[k])) + 1``, so the level
    grows by one each time the sampling radius halves. Non-positive radii
    (duplicates, or cosine round-off) have no finite log, so they are placed
    one level below the deepest finite level.
    """
    levels = np.zeros(len(radii), dtype=int)
    tail = radii[1:]
    if tail.size == 0:
        return levels
    positive = tail > 0
    tail_levels = np.empty(tail.size, dtype=int)
    if positive.any():
        # FPS radii are non-increasing, so tail[0] is the largest and is > 0 here.
        ratio = tail[0] / tail[positive]
        tail_levels[positive] = np.floor(np.log2(ratio)).astype(int) + 1
        deepest = tail_levels[positive].max()
    else:
        deepest = 0
    tail_levels[~positive] = deepest + 1
    levels[1:] = tail_levels
    return levels


def _assign_parents(A, order, levels):
    """Give each point its nearest point on a strictly shallower level.

    The result is aligned position-by-position with ``order``; the root
    (position 0) is its own parent.
    """
    parents = np.empty_like(order)
    parents[0] = order[0]
    for level in range(1, levels.max() + 1):
        is_current = levels == level
        if not is_current.any():
            continue  # log2 binning can leave a level empty
        shallower = order[levels < level]
        current = order[is_current]
        sub = A[torch.as_tensor(shallower)][:, torch.as_tensor(current)]
        nearest = sub.min(dim=0).indices.numpy()
        parents[is_current] = shallower[nearest]
    return parents


def build_tree(all_dots, dist='euclidean'):
    """Build a hierarchical tree over the points using farthest-point sampling.

    Sampling starts at the point closest to the coordinate-wise median and
    then repeatedly visits the point farthest from everything visited so far.
    Each visited point gets a level from how much the sampling radius has
    shrunk since the first step. Its parent is the nearest point on any
    shallower level.

    Args:
        all_dots: 2-D array-like of shape (n, d); one row per point.
        dist: ``'euclidean'`` or ``'cosine'``. ``'cosine'`` uses
            ``1 - x_i . x_j`` and assumes the rows are L2-normalised.

    Returns:
        np.ndarray of shape (n, 2) with rows ``[node, parent]`` in FPS visit
        order. Every point appears exactly once in column 0. The first row is
        the root, ``[root, root]``.

    Raises:
        ValueError: if ``dist`` is not recognised or ``all_dots`` is empty.
    """
    all_dots = np.asarray(all_dots)
    if all_dots.shape[0] == 0:
        raise ValueError('all_dots must contain at least one point')
    center = np.median(all_dots, axis=0)
    start_idx = np.argmin(np.linalg.norm(all_dots - center, axis=1))
    A = _pairwise_distance(all_dots, dist)
    order, radii = _farthest_point_order(A, start_idx)
    levels = _assign_levels(radii)
    parents = _assign_parents(A, order, levels)
    return np.stack([order, parents], axis=1)
def find_connected_component(edges, start_node):
    """Return every node reachable from ``start_node`` through ``edges``.

    Edges are treated as undirected.

    Args:
        edges: iterable of ``(a, b)`` node pairs, e.g. an (n, 2) integer array.
        start_node: node to start from. It is always part of the result, even
            when no edge mentions it.

    Returns:
        np.ndarray of the reachable nodes (including ``start_node``), sorted
        ascending so the output is deterministic.
    """
    # Undirected adjacency list: record each edge in both directions.
    adjacency = defaultdict(list)
    for a, b in edges:
        adjacency[a].append(b)
        adjacency[b].append(a)

    # Breadth-first search. deque.popleft is O(1), whereas list.pop(0) is O(n).
    # Nodes are marked visited when enqueued, so each node is queued at most once.
    visited = {start_node}
    queue = deque([start_node])
    while queue:
        node = queue.popleft()
        for neighbor in adjacency.get(node, ()):
            if neighbor not in visited:
                visited.add(neighbor)
                queue.append(neighbor)
    return np.array(sorted(visited))