Spaces:
Running
on
Zero
Running
on
Zero
# %% | |
import numpy as np | |
import torch | |
def _pairwise_distance_matrix(all_dots, dist):
    """Return the dense (n, n) distance matrix of ``all_dots`` as a torch tensor.

    ``dist='euclidean'`` gives L2 distances. ``dist='cosine'`` gives
    ``1 - <x_i, x_j>``, which is the cosine distance only when the rows of
    ``all_dots`` are already L2-normalized (no normalization is done here).
    """
    if dist == 'euclidean':
        # Broadcast to (n, n, d) differences; memory grows as n * n * d.
        diff = all_dots[:, None] - all_dots[None, :]
        return torch.as_tensor(np.sqrt((diff ** 2).sum(-1)))
    if dist == 'cosine':
        # assume all_dots is normalized
        return 1 - torch.as_tensor(all_dots @ all_dots.T)
    raise ValueError('dist must be euclidean or cosine')


def _farthest_point_sampling(A):
    """Order all points by greedy farthest point sampling on distance matrix ``A``.

    Returns ``(order, gaps)``. ``order`` is a permutation of ``range(n)``.
    ``gaps[k]`` (k >= 1) is the distance from ``order[k]`` to the nearest of
    ``order[:k]`` at the moment it was picked; ``gaps[0]`` is ``inf`` because the
    start point has no predecessor. ``gaps[1:]`` is non-increasing.
    """
    n = A.shape[0]
    # Start from the most central point: smallest mean distance to all others.
    start = int(torch.argmin(A.mean(dim=1)))
    order = [start]
    gaps = [np.inf]
    selected = torch.zeros(n, dtype=torch.bool, device=A.device)
    selected[start] = True
    # nearest[j] = distance from j to its closest already-selected point.
    # It is updated incrementally (O(n) per step) instead of re-reducing rows.
    nearest = A[start].clone()
    for _ in range(n - 1):
        # Mask selected points. Otherwise duplicates (all remaining distances 0)
        # let argmax re-pick a point that is already in `order`.
        candidates = nearest.masked_fill(selected, float('-inf'))
        nxt = int(torch.argmax(candidates))  # first maximum on ties
        order.append(nxt)
        gaps.append(nearest[nxt].item())
        selected[nxt] = True
        nearest = torch.minimum(nearest, A[nxt])
    return np.array(order), np.array(gaps, dtype=np.float64)


def _levels_from_gaps(gaps):
    """Map FPS gaps to integer tree levels.

    Level 0 is the start point. With ``g1 = gaps[1]``, level ``k >= 1`` holds
    the points whose gap lies in ``(g1 / 2**k, g1 / 2**(k-1)]``. Points with a
    non-positive gap (exact duplicates or cosine round-off) go one level below
    the deepest other level, so levels stay non-decreasing along ``gaps``.
    """
    levels = np.zeros(len(gaps), dtype=int)
    if len(gaps) < 2:
        return levels
    tail = gaps[1:]
    positive = tail > 0
    # tail[0] is the largest gap, so every ratio is >= 1 and every level >= 1.
    levels[1:][positive] = np.floor(np.log2(tail[0] / tail[positive])).astype(int) + 1
    # log2(g1 / 0) is inf and negative gaps give nan; casting those to int is
    # undefined, so such points get their own deepest level instead.
    levels[1:][~positive] = levels.max() + 1
    return levels


def build_tree(all_dots, dist='euclidean'):
    """Build a coarse-to-fine tree over points via farthest point sampling.

    Points are ordered by greedy farthest point sampling (FPS), starting from
    the point with the smallest mean distance to all others. Each point's level
    is derived from its FPS gap relative to the first gap ``g1``: level ``k >= 1``
    holds gaps in ``(g1 / 2**k, g1 / 2**(k-1)]``. The start point is level 0, and
    points with a non-positive gap (duplicates) are placed one level deeper than
    all others. Every non-root point's parent is its nearest point (first on
    ties) among all points on strictly coarser levels.

    Args:
        all_dots: array of shape (n, d), one point per row. For
            ``dist='cosine'`` the rows are assumed to be L2-normalized already.
        dist: ``'euclidean'`` or ``'cosine'`` (computed as ``1 - dot product``).

    Returns:
        edges: int array of shape (n, 2). Row k is ``[point, parent]`` for the
            k-th point in FPS order. Row 0 is the root, which is its own parent.
        levels: int array of shape (n,). ``levels[k]`` is the level of
            ``edges[k, 0]``; it is indexed by FPS rank, not by point index.

    Raises:
        ValueError: if ``dist`` is unsupported or ``all_dots`` has no points.
    """
    A = _pairwise_distance_matrix(all_dots, dist)
    if A.shape[0] == 0:
        raise ValueError('all_dots must contain at least one point')
    order, gaps = _farthest_point_sampling(A)
    levels = _levels_from_gaps(gaps)

    parents = np.empty_like(order)
    parents[0] = order[0]  # the root is its own parent
    for level in range(1, levels.max() + 1):
        members = np.flatnonzero(levels == level)
        if members.size == 0:
            continue  # a level is skipped when the gap shrinks by > 2x in one step
        coarser = order[levels < level]
        # For each member, pick the nearest coarser-level point.
        nearest = A[coarser][:, order[members]].min(dim=0).indices
        # Fill by position so each parent always lines up with its own point.
        parents[members] = coarser[nearest.cpu().numpy()]

    edges = np.stack([order, parents], axis=1)
    return edges, levels
def find_connected_component(edges, start_node):
    """Return every node reachable from ``start_node`` in an undirected graph.

    Args:
        edges: iterable of ``(a, b)`` pairs; each pair is an undirected edge.
            Self-loops are allowed and harmless.
        start_node: node to start from. It is always part of the result, even
            if it appears in no edge.

    Returns:
        numpy array of the nodes in the component (the order is not meaningful).
    """
    from collections import deque

    # Dictionary to store adjacency list
    adjacency_list = {}
    for a, b in edges:
        # Undirected graph: record the connection for both endpoints.
        adjacency_list.setdefault(a, []).append(b)
        adjacency_list.setdefault(b, []).append(a)

    # Breadth-first search. Nodes are marked when first discovered, so each is
    # enqueued at most once; deque.popleft is O(1) where list.pop(0) is O(n).
    connected_component = {start_node}
    queue = deque([start_node])
    while queue:
        node = queue.popleft()
        for neighbour in adjacency_list.get(node, ()):
            if neighbour not in connected_component:
                connected_component.add(neighbour)
                queue.append(neighbour)
    return np.array(list(connected_component))