huzey commited on
Commit
d5680a3
1 Parent(s): 38c0aa5

add tree option

Browse files
Files changed (1) hide show
  1. fps_cluster.py +13 -5
fps_cluster.py CHANGED
@@ -3,7 +3,7 @@ import numpy as np
3
  import torch
4
 
5
 
6
- def build_tree(all_dots):
7
  num_sample = all_dots.shape[0]
8
  # center = all_dots.mean(axis=0)
9
  center = np.median(all_dots, axis=0)
@@ -11,10 +11,18 @@ def build_tree(all_dots):
11
  start_idx = np.argmin(distances_to_center)
12
  indices = [start_idx]
13
  distances = [114514,]
14
- A = all_dots[:, None] - all_dots[None, :]
15
- A = (A ** 2).sum(-1)
16
- A = np.sqrt(A)
17
- A = torch.tensor(A)
 
 
 
 
 
 
 
 
18
  for i in range(num_sample - 1):
19
  _A = A[indices]
20
  min_dist = _A.min(dim=0).values
 
3
  import torch
4
 
5
 
6
+ def build_tree(all_dots, dist='euclidean'):
7
  num_sample = all_dots.shape[0]
8
  # center = all_dots.mean(axis=0)
9
  center = np.median(all_dots, axis=0)
 
11
  start_idx = np.argmin(distances_to_center)
12
  indices = [start_idx]
13
  distances = [114514,]
14
+ if dist == 'euclidean':
15
+ A = all_dots[:, None] - all_dots[None, :]
16
+ A = (A ** 2).sum(-1)
17
+ A = np.sqrt(A)
18
+ A = torch.tensor(A)
19
+ elif dist == 'cosine':
20
+ # assume all_dots is normalized
21
+ A = all_dots @ all_dots.T
22
+ A = torch.tensor(A)
23
+ A = 1 - A
24
+ else:
25
+ raise ValueError('dist must be euclidean or cosine')
26
  for i in range(num_sample - 1):
27
  _A = A[indices]
28
  min_dist = _A.min(dim=0).values