huzey commited on
Commit
d99ae43
·
1 Parent(s): 5d4d23f

fix fps_cluster

Browse files
Files changed (1) hide show
  1. fps_cluster.py +8 -7
fps_cluster.py CHANGED
@@ -5,12 +5,7 @@ import torch
5
 
6
  def build_tree(all_dots, dist='euclidean'):
7
  num_sample = all_dots.shape[0]
8
- # center = all_dots.mean(axis=0)
9
- center = np.median(all_dots, axis=0)
10
- distances_to_center = np.linalg.norm(all_dots - center, axis=1)
11
- start_idx = np.argmin(distances_to_center)
12
- indices = [start_idx]
13
- distances = [114514,]
14
  if dist == 'euclidean':
15
  A = all_dots[:, None] - all_dots[None, :]
16
  A = (A ** 2).sum(-1)
@@ -23,6 +18,12 @@ def build_tree(all_dots, dist='euclidean'):
23
  A = 1 - A
24
  else:
25
  raise ValueError('dist must be euclidean or cosine')
 
 
 
 
 
 
26
  for i in range(num_sample - 1):
27
  _A = A[indices]
28
  min_dist = _A.min(dim=0).values
@@ -55,7 +56,7 @@ def build_tree(all_dots, dist='euclidean'):
55
  pi_indices = np.array(pi_indices)
56
 
57
  edges = np.stack([indices, pi_indices], axis=1)
58
- return edges
59
 
60
 
61
  def find_connected_component(edges, start_node):
 
5
 
6
  def build_tree(all_dots, dist='euclidean'):
7
  num_sample = all_dots.shape[0]
8
+
 
 
 
 
 
9
  if dist == 'euclidean':
10
  A = all_dots[:, None] - all_dots[None, :]
11
  A = (A ** 2).sum(-1)
 
18
  A = 1 - A
19
  else:
20
  raise ValueError('dist must be euclidean or cosine')
21
+
22
+ d_sum = A.mean(dim=1)
23
+ start_idx = torch.argmin(d_sum).item()
24
+ indices = [start_idx]
25
+ distances = [114514,]
26
+
27
  for i in range(num_sample - 1):
28
  _A = A[indices]
29
  min_dist = _A.min(dim=0).values
 
56
  pi_indices = np.array(pi_indices)
57
 
58
  edges = np.stack([indices, pi_indices], axis=1)
59
+ return edges, levels
60
 
61
 
62
  def find_connected_component(edges, start_node):