File size: 2,676 Bytes
55debdb
 
 
 
 
d5680a3
55debdb
d99ae43
d5680a3
 
 
 
 
 
 
 
 
 
 
 
d99ae43
 
 
 
 
 
55debdb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d99ae43
55debdb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# %%
from collections import defaultdict, deque

import numpy as np
import torch


def build_tree(all_dots, dist='euclidean'):
    """Build a hierarchical tree over points via farthest-point traversal.

    The points are visited in a greedy farthest-point order: the traversal
    starts at the point with the smallest mean distance to all points, and
    each following point is the one whose distance to the already visited
    set is largest.  That distance (the "gap") never increases along the
    traversal.  A point's level is ``floor(log2(first_gap / gap)) + 1``, so
    each level covers one halving of the scale; the starting point is the
    root at level 0.  Every non-root point is attached to its nearest point
    among all points of strictly lower level.

    Args:
        all_dots: Array-like of shape ``(num_points, dim)``.  For
            ``dist='cosine'`` the rows are assumed to be L2-normalized.
        dist: ``'euclidean'`` or ``'cosine'`` (``1 - dot product``).

    Returns:
        A tuple ``(edges, levels)``:

        * ``edges`` -- integer array of shape ``(num_points, 2)``; row ``k``
          is ``[node, parent]`` for the ``k``-th point of the traversal.
          The root is paired with itself.
        * ``levels`` -- integer array of shape ``(num_points,)``;
          ``levels[k]`` is the level of ``edges[k, 0]`` (traversal order,
          not point-id order).

    Raises:
        ValueError: If ``dist`` is not supported, if ``all_dots`` is empty,
            or if a traversal gap is not finite and strictly positive
            (e.g. duplicate points), in which case no level can be
            assigned.
    """
    if dist not in ('euclidean', 'cosine'):
        raise ValueError('dist must be euclidean or cosine')

    all_dots = np.asarray(all_dots)
    num_sample = all_dots.shape[0]
    if num_sample == 0:
        raise ValueError('all_dots must contain at least one point')

    if dist == 'euclidean':
        diff = all_dots[:, None] - all_dots[None, :]
        A = torch.tensor(np.sqrt((diff ** 2).sum(-1)))
    else:
        # assume all_dots is normalized
        A = 1 - torch.tensor(all_dots @ all_dots.T)

    # Root: the point that is closest to all others on average.
    start_idx = torch.argmin(A.mean(dim=1)).item()
    order = [start_idx]
    gaps = []  # gaps[k] belongs to order[k + 1]

    # Running distance from every point to the visited set.  Keeping it
    # up to date costs O(n) per step instead of re-reducing all visited
    # rows.  Visited points are masked with -inf so they can never be
    # picked again, even if the matrix diagonal is not exactly zero.
    min_dist = A[start_idx].clone()
    min_dist[start_idx] = float('-inf')
    for _ in range(num_sample - 1):
        next_idx = torch.argmax(min_dist).item()
        gap = min_dist[next_idx].item()
        if not 0 < gap < float('inf'):
            raise ValueError(
                'cannot assign a level: traversal distance must be finite '
                'and positive (duplicate or invalid points?), got '
                f'{gap!r}'
            )
        order.append(next_idx)
        gaps.append(gap)
        min_dist = torch.minimum(min_dist, A[next_idx])
        min_dist[next_idx] = float('-inf')

    indices = np.array(order)

    # Level 0 is reserved for the root; the first gap defines level 1.
    levels = np.zeros(num_sample, dtype=int)
    if gaps:
        gaps = np.array(gaps)
        levels[1:] = np.floor(np.log2(gaps[0] / gaps)).astype(int) + 1

    # Parent of each point: nearest point among all strictly lower levels.
    parents = np.empty(num_sample, dtype=indices.dtype)
    parents[0] = indices[0]
    for i_level in range(1, int(levels.max()) + 1):
        is_current = levels == i_level
        if not is_current.any():
            continue  # a scale may be skipped entirely
        current = indices[is_current]
        previous = indices[levels < i_level]
        sub = A[torch.as_tensor(previous)][:, torch.as_tensor(current)]
        nearest = sub.min(dim=0).indices.numpy()
        parents[is_current] = previous[nearest]

    edges = np.stack([indices, parents], axis=1)
    return edges, levels


def find_connected_component(edges, start_node):
    """Return all nodes reachable from ``start_node`` in an undirected graph.

    Args:
        edges: Iterable of ``(a, b)`` pairs (e.g. a list of tuples or an
            ``(n, 2)`` array); each pair connects ``a`` and ``b`` in both
            directions.  Nodes must be hashable.
        start_node: Node whose connected component is wanted.  It does not
            have to appear in ``edges``.

    Returns:
        A NumPy array containing every node of the component, including
        ``start_node`` itself.  The order of the entries is the iteration
        order of a set and should not be relied upon.
    """
    # Undirected adjacency list.
    neighbours = defaultdict(list)
    for a, b in edges:
        neighbours[a].append(b)
        neighbours[b].append(a)

    # Breadth-first search.  Nodes are marked when they are enqueued, so
    # each node enters the queue at most once, and the deque makes every
    # pop O(1).
    connected_component = {start_node}
    queue = deque([start_node])
    while queue:
        node = queue.popleft()
        # .get() avoids inserting empty entries for isolated nodes.
        for other in neighbours.get(node, ()):
            if other not in connected_component:
                connected_component.add(other)
                queue.append(other)

    return np.array(list(connected_component))