# %%
from collections import defaultdict, deque

import numpy as np
import torch


def build_tree(all_dots, dist='euclidean'):
    """Build a coarse-to-fine tree over a point set via farthest-point ordering.

    Points are ordered greedily: start at the point with the smallest mean
    distance to all others, then repeatedly pick the not-yet-selected point
    farthest from everything selected so far. Each non-root point gets the
    level ``floor(log2(d_1 / d_k)) + 1``, where ``d_k`` is its distance to the
    selected set at the moment it was picked and ``d_1`` is the first such
    distance; the root has level 0. Every non-root point's parent is its
    nearest point among those with a strictly lower level; the root is its
    own parent.

    Args:
        all_dots: (n, d) array-like of points. For ``dist='cosine'`` the rows
            are assumed to be L2-normalized (this is not checked).
        dist: ``'euclidean'`` or ``'cosine'`` (distance ``1 - dot product``).

    Returns:
        edges: (n, 2) int array; row k is ``[point, parent]`` for the k-th
            point in farthest-point order, so ``edges[0]`` is ``[root, root]``.
        levels: (n,) int array; ``levels[k]`` is the level of ``edges[k, 0]``.

    Raises:
        ValueError: if ``dist`` is not a supported metric.
    """
    A = _pairwise_distances(all_dots, dist)
    if A.shape[0] == 0:
        return np.empty((0, 2), dtype=np.int64), np.empty(0, dtype=int)
    indices, distances = _farthest_point_order(A)
    levels = _levels_from_distances(distances)
    parents = _assign_parents(A, indices, levels)
    edges = np.stack([indices, parents], axis=1)
    return edges, levels


def _pairwise_distances(all_dots, dist):
    """Return the (n, n) torch tensor of pairwise distances under ``dist``."""
    points = np.asarray(all_dots)
    if dist == 'euclidean':
        diff = points[:, None] - points[None, :]
        return torch.tensor(np.sqrt((diff ** 2).sum(-1)))
    if dist == 'cosine':
        # assumes rows are normalized, so the dot product is the cosine similarity
        return 1 - torch.tensor(points @ points.T)
    raise ValueError('dist must be euclidean or cosine')


def _farthest_point_order(A):
    """Greedy farthest-point ordering of all points; returns (indices, distances).

    ``distances[k]`` is the distance from ``indices[k]`` to its nearest point
    chosen earlier; ``distances[0]`` is NaN because the root has none.
    """
    num_sample = A.shape[0]
    start_idx = int(torch.argmin(A.mean(dim=1)))
    selected = torch.zeros(num_sample, dtype=torch.bool)
    selected[start_idx] = True
    # min_dist[j]: distance from point j to its nearest already-selected point.
    # Updated incrementally instead of re-reducing A[indices] on every step.
    min_dist = A[start_idx]
    indices = [start_idx]
    distances = [float('nan')]
    for _ in range(num_sample - 1):
        # Mask selected points explicitly: with duplicate points (or cosine
        # round-off) their min_dist is not guaranteed to be the unique minimum.
        candidates = min_dist.masked_fill(selected, float('-inf'))
        next_idx = int(torch.argmax(candidates))
        indices.append(next_idx)
        distances.append(min_dist[next_idx].item())
        selected[next_idx] = True
        min_dist = torch.minimum(min_dist, A[next_idx])
    return np.array(indices), np.array(distances)


def _levels_from_distances(distances):
    """Map farthest-point distances to integer levels; the root gets level 0.

    Non-root level is ``floor(log2(distances[1] / d)) + 1``. A distance <= 0
    (an exact duplicate of an already-selected point) would give an infinite
    level, so such points are placed one level deeper than the deepest finite
    level instead.
    """
    levels = np.zeros(len(distances), dtype=int)
    rest = distances[1:]
    if rest.size == 0:
        return levels
    positive = rest > 0
    rest_levels = np.zeros(rest.size, dtype=int)
    if positive.any():
        # distances are non-increasing, so rest[0] is the largest (and > 0 here)
        ratios = rest[0] / rest[positive]
        rest_levels[positive] = np.floor(np.log2(ratios)).astype(int) + 1
        deepest = rest_levels[positive].max()
    else:
        deepest = 0
    rest_levels[~positive] = deepest + 1
    levels[1:] = rest_levels
    return levels


def _assign_parents(A, indices, levels):
    """Return parents[k]: the nearest strictly-lower-level point to indices[k].

    The root (position 0) is its own parent.
    """
    parents = np.empty_like(indices)
    parents[0] = indices[0]
    for level in range(1, levels.max() + 1):
        in_level = levels == level
        if not in_level.any():
            continue  # log2 spacing can leave whole levels empty
        coarse = indices[levels < level]
        fine = indices[in_level]
        sub = A[torch.as_tensor(coarse)][:, torch.as_tensor(fine)]
        # .numpy() keeps the result an array even for a single fine point
        nearest = sub.min(dim=0).indices.numpy()
        parents[in_level] = coarse[nearest]
    return parents


def find_connected_component(edges, start_node):
    """Return the nodes reachable from ``start_node`` through undirected ``edges``.

    Args:
        edges: iterable of ``(a, b)`` pairs (e.g. the (n, 2) ``edges`` array
            returned by ``build_tree``); each pair is an undirected edge.
        start_node: node the breadth-first search starts from. It is always
            in the result, even when it appears in no edge.

    Returns:
        1-D numpy array of the component's nodes, sorted ascending.
    """
    adjacency = defaultdict(list)
    for a, b in edges:
        adjacency[a].append(b)
        adjacency[b].append(a)

    visited = {start_node}
    queue = deque([start_node])
    while queue:
        node = queue.popleft()
        for neighbour in adjacency.get(node, ()):
            # mark on enqueue so each node enters the queue at most once
            if neighbour not in visited:
                visited.add(neighbour)
                queue.append(neighbour)
    return np.array(sorted(visited))