# %%
import torch
import torch.nn.functional as F

def affinity_from_features(
    features,
    features_B=None,
    affinity_focal_gamma=1.0,
    distance="cosine",
    normalize_features=False,
    fill_diagonal=False,
    n_features=1,
):
    """Compute affinity matrix from input features.

    Args:
        features (torch.Tensor): input features, shape (n_samples, n_features)
        feature_B (torch.Tensor, optional): optional, if not None, compute affinity between two features
        affinity_focal_gamma (float): affinity matrix parameter, lower t reduce the edge weights
            on weak connections, default 1.0
        distance (str): distance metric, 'cosine' (default) or 'euclidean'.
        apply_normalize (bool): normalize input features before computing affinity matrix,
            default True

    Returns:
        (torch.Tensor): affinity matrix, shape (n_samples, n_samples)
    """
    # compute affinity matrix from input features
    features = features.clone()
    if features_B is not None:
        features_B = features_B.clone()

    # if features_B is not provided, compute the affinity matrix on features x features;
    # if features_B is provided, compute the affinity matrix on features x features_B
    if features_B is not None:
        assert not fill_diagonal, "fill_diagonal should be False when features_B is provided"
    features_B = features if features_B is None else features_B

    if normalize_features:
        features = F.normalize(features, dim=-1)
        features_B = F.normalize(features_B, dim=-1)

    if distance == "cosine":
        # if not check_if_normalized(features):
        
        # TODO: make sure features are normalized within each head
        
        features = F.normalize(features, dim=-1)
        # if not check_if_normalized(features_B):
        features_B = F.normalize(features_B, dim=-1)
        A = 1 - (features @ features_B.T) / n_features
    elif distance == "euclidean":
        A = torch.cdist(features, features_B, p=2) / n_features
    else:
        raise ValueError("distance should be 'cosine' or 'euclidean'")

    if fill_diagonal:
        A[torch.arange(A.shape[0]), torch.arange(A.shape[0])] = 0

    # torch.exp makes the affinity matrix entrywise positive;
    # a lower affinity_focal_gamma suppresses the weak edge weights
    A = torch.exp(-A / affinity_focal_gamma)
    return A
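
# %%
# Minimal usage sketch for `affinity_from_features` (illustrative only; the
# random features below are placeholders, not data from this project).
_demo_feats = torch.randn(8, 32)  # 8 samples, 32-dim features
_demo_A = affinity_from_features(_demo_feats, affinity_focal_gamma=0.5)
# with cosine distance and exp(-d / gamma), entries lie in (0, 1] and the
# diagonal is exactly 1 (each sample has distance 0 to itself)
assert _demo_A.shape == (8, 8)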

import logging

from ncut_pytorch.ncut_pytorch import run_subgraph_sampling, propagate_knn, gram_schmidt

def ncut(
    A,
    num_eig=20,
    eig_solver="svd_lowrank",
    make_symmetric=True,
):
    """PyTorch implementation of Normalized cut without Nystrom-like approximation.

    Args:
        A (torch.Tensor): affinity matrix, shape (n_samples, n_samples)
        num_eig (int): number of eigenvectors to return
        eig_solver (str): eigen decompose solver, ['svd_lowrank', 'lobpcg', 'svd', 'eigh']

    Returns:
        (torch.Tensor): eigenvectors corresponding to the eigenvalues, shape (n_samples, num_eig)
        (torch.Tensor): eigenvalues of the eigenvectors, sorted in descending order
    """
    if make_symmetric:
        # make sure A is symmetric
        A = (A + A.T) / 2

    # symmetric normalization: A <- D^(-1/2) A D^(-1/2)
    # (note: the divisions below modify A in place)
    D_r = A.sum(dim=0).detach().clone()
    D_c = A.sum(dim=1).detach().clone()
    A /= torch.sqrt(D_r)[:, None]
    A /= torch.sqrt(D_c)[None, :]

    # compute eigenvectors
    if eig_solver == "svd_lowrank":  # default
        # only top q eigenvectors, fastest
        eigen_vector, eigen_value, _ = torch.svd_lowrank(A, q=num_eig)
    elif eig_solver == "lobpcg":
        # only top k eigenvectors, fast
        eigen_value, eigen_vector = torch.lobpcg(A, k=num_eig)
    elif eig_solver == "svd":
        # all eigenvectors, slow
        eigen_vector, eigen_value, _ = torch.svd(A)
    elif eig_solver == "eigh":
        # all eigenvectors, slow
        eigen_value, eigen_vector = torch.linalg.eigh(A)
    else:
        raise ValueError(
            "eigen_solver should be 'lobpcg', 'svd_lowrank', 'svd' or 'eigh'"
        )

    # sort eigenvectors by eigenvalues, take top (descending order)
    eigen_value = eigen_value.real
    eigen_vector = eigen_vector.real
    
    sort_order = torch.argsort(eigen_value, descending=True)[:num_eig]
    eigen_value = eigen_value[sort_order]
    eigen_vector = eigen_vector[:, sort_order]

    if eigen_value.min() < 0:
        logging.warning(
            "negative eigenvalues detected, please make sure the affinity matrix is positive semi-definite"
        )

    return eigen_vector, eigen_value
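
# %%
# Small self-contained check of `ncut` (illustrative; the affinity matrix is
# built from random placeholder features). The leading eigenvalue of the
# degree-normalized affinity is close to 1.
_demo_A = affinity_from_features(torch.randn(50, 16))
_demo_vecs, _demo_vals = ncut(_demo_A, num_eig=5)
assert _demo_vecs.shape == (50, 5) and _demo_vals.shape == (5,)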

def nystrom_ncut(
    features,
    features_B=None,
    num_eig=100,
    num_sample=10000,
    knn=10,
    sample_method="farthest",
    distance="cosine",
    affinity_focal_gamma=1.0,
    indirect_connection=False,
    indirect_pca_dim=100,
    device=None,
    eig_solver="svd_lowrank",
    normalize_features=False,
    matmul_chunk_size=8096,
    make_orthogonal=False,
    verbose=False,
    no_propagation=False,
    make_symmetric=False,
    n_features=1,
):
    """PyTorch implementation of Faster Nystrom Normalized cut.

    Args:
        features (torch.Tensor): feature matrix, shape (n_samples, n_features)
        features_2 (torch.Tensor): feature matrix 2, for asymmetric affinity matrix, shape (n_samples2, n_features)
        num_eig (int): default 20, number of top eigenvectors to return
        num_sample (int): default 30000, number of samples for Nystrom-like approximation
        knn (int): default 3, number of KNN for propagating eigenvectors from subgraph to full graph,
            smaller knn will result in more sharp eigenvectors,
        sample_method (str): sample method, 'farthest' (default) or 'random'
            'farthest' is recommended for better approximation
        distance (str): distance metric, 'cosine' (default) or 'euclidean'
        affinity_focal_gamma (float): affinity matrix parameter, lower t reduce the weak edge weights,
            resulting in more sharp eigenvectors, default 1.0
        indirect_connection (bool): include indirect connection in the subgraph, default True
        indirect_pca_dim (int): default 100, PCA dimension to reduce the node dimension, only applied to
            the not sampled nodes, not applied to the sampled nodes
        device (str): device to use for computation, if None, will not change device
            a good practice is to pass features by CPU since it's usually large,
            and move subgraph affinity to GPU to speed up eigenvector computation
        eig_solver (str): eigen decompose solver, 'svd_lowrank' (default), 'lobpcg', 'svd', 'eigh'
            'svd_lowrank' is recommended for large scale graph, it's the fastest
            they correspond to torch.svd_lowrank, torch.lobpcg, torch.svd, torch.linalg.eigh
        normalize_features (bool): normalize input features before computing affinity matrix,
            default True
        matmul_chunk_size (int): chunk size for matrix multiplication
            large matrix multiplication is chunked to reduce memory usage,
            smaller chunk size will reduce memory usage but slower computation, default 8096
        make_orthogonal (bool): make eigenvectors orthogonal after propagation, default True
        verbose (bool): show progress bar when propagating eigenvectors from subgraph to full graph
        no_propagation (bool): if True, skip the eigenvector propagation step, only return the subgraph eigenvectors

    Returns:
        (torch.Tensor): eigenvectors, shape (n_samples, num_eig)
        (torch.Tensor): eigenvalues, sorted in descending order, shape (num_eig,)
        (torch.Tensor): sampled_indices used by Nystrom-like approximation subgraph, shape (num_sample,)
    """

    # check that the number of nodes is large enough for the requested num_eig
    if eig_solver in ["svd_lowrank", "lobpcg"]:
        assert features.shape[0] > (
            num_eig * 2
        ), "number of nodes should be greater than 2*num_eig"
    if eig_solver in ["svd", "eigh"]:
        assert (
            features.shape[0] > num_eig
        ), "number of nodes should be greater than num_eig"

    features = features.clone()
    if normalize_features:
        # features need to be normalized for affinity matrix computation (cosine distance)
        features = torch.nn.functional.normalize(features, dim=-1)

    # features_B defaults to features (symmetric case)
    if features_B is None:
        features_B = features

    sampled_indices = run_subgraph_sampling(
        features,
        num_sample=num_sample,
        sample_method=sample_method,
    )

    sampled_indices_B = run_subgraph_sampling(
        features_B,
        num_sample=num_sample,
        sample_method=sample_method,
    )

    sampled_features = features[sampled_indices]
    sampled_features_B = features_B[sampled_indices_B]
    # move the subgraph to the compute device (e.g. GPU) to speed up
    original_device = sampled_features.device
    device = original_device if device is None else device
    sampled_features = sampled_features.to(device)
    sampled_features_B = sampled_features_B.to(device)

    # compute affinity matrix on subgraph
    A = affinity_from_features(
        sampled_features, features_B=sampled_features_B,
        affinity_focal_gamma=affinity_focal_gamma, distance=distance,
        n_features=n_features,
    )

    not_sampled = torch.tensor(
        list(set(range(features.shape[0])) - set(sampled_indices.tolist()))
    )

    if len(not_sampled) == 0:
        # all nodes were sampled; no Nystrom approximation needed
        eigen_vector, eigen_value = ncut(A, num_eig, eig_solver=eig_solver)
        return eigen_vector, eigen_value, sampled_indices

    # 1) PCA to reduce the node dimension for the not-sampled nodes
    # 2) compute indirect connections on the PCA-projected nodes
    if len(not_sampled) > 0 and indirect_connection:
        # NOTE: the code below the raise is kept as unreachable reference code
        raise NotImplementedError("indirect_connection is not implemented yet")
        indirect_pca_dim = min(indirect_pca_dim, min(*features.shape))
        U, S, V = torch.pca_lowrank(features[not_sampled].T, q=indirect_pca_dim)
        feature_B = (features[not_sampled].T @ V).T  # project to PCA space
        feature_B = feature_B.to(device)
        B = affinity_from_features(
            sampled_features,
            feature_B,
            affinity_focal_gamma=affinity_focal_gamma,
            distance=distance,
            fill_diagonal=False,
        )
        # P is 1-hop random walk matrix
        B_row = B / B.sum(axis=1, keepdim=True)
        B_col = B / B.sum(axis=0, keepdim=True)
        P = B_row @ B_col.T
        P = (P + P.T) / 2
        # fill diagonal with 0
        P[torch.arange(P.shape[0]), torch.arange(P.shape[0])] = 0
        A = A + P

    # compute normalized cut on the subgraph
    eigen_vector, eigen_value = ncut(A, num_eig, eig_solver=eig_solver, make_symmetric=make_symmetric)
    eigen_vector = eigen_vector.to(dtype=features.dtype, device=original_device)
    eigen_value = eigen_value.to(dtype=features.dtype, device=original_device)

    if no_propagation:
        return eigen_vector, eigen_value, sampled_indices

    # propagate eigenvectors from subgraph to full graph
    eigen_vector = propagate_knn(
        eigen_vector,
        features,
        sampled_features,
        knn,
        chunk_size=matmul_chunk_size,
        device=device,
        use_tqdm=verbose,
    )

    # post-hoc orthogonalization
    if make_orthogonal:
        eigen_vector = gram_schmidt(eigen_vector)

    return eigen_vector, eigen_value, sampled_indices
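
# %%
# End-to-end sketch of `nystrom_ncut` on toy data (illustrative; sizes are
# placeholders chosen so the solver assertions pass, i.e. n_samples > 2 * num_eig
# for the default 'svd_lowrank' solver).
_toy_feats = torch.randn(500, 64)
_toy_vecs, _toy_vals, _toy_idx = nystrom_ncut(
    _toy_feats, num_eig=10, num_sample=100, knn=3
)
assert _toy_vecs.shape == (500, 10) and _toy_vals.shape == (10,)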